Epic Fail #2: Assuming the function that throws the exception is where the error is in my code
For my neural networks class, a major component of the second assignment was to build and train a recurrent architecture. At that point, I hadn’t had a tonne of experience with the TensorFlow framework. Although I’d implemented some simple networks, I hadn’t built anything involving recurrent architectures before, so I consulted various tutorials.
Working from these tutorials, I started with a rough sketch in code, an approach I think makes sense when learning a new framework. You build a rough version that seems to represent what you aim to do, then iron it out into your first super-basic running version. Figure out what it’s doing, then iteratively build on top of it. Ideally, this means going back to the framework documentation and checking that all the pieces work as you expect.
Unfortunately, somewhere along the way, I had the brilliant idea of blindly using a TensorFlow function without reading its documentation, and then simply assuming I had used it correctly. Testing raised an exception.
While debugging that exception, I made a second assumption: if a function call doesn’t raise an exception, then I must have used the function correctly.
Working from these assumptions, I wasted a lot of time attempting various things and reading different tutorials. With nothing working, I then moved on to reviewing all the relevant theory from the lectures to ensure I completely understood everything. But even after that, I couldn’t figure out what was wrong with the code, so I spent more time reading various articles on the design of the TensorFlow framework. Nothing gave me an answer, which left me beyond frustrated.
As it turned out, the function call raising the exception was fine. The problem was that I had used another function incorrectly in an earlier part of the code. Even though that earlier call was wrong, it raised no exception; the error only surfaced later on.
The Problematic Code
In the version before fixing the error, I had the following two lines as part of the graph definition code:
lstm = tf.contrib.rnn.BasicLSTMCell([BATCH_SIZE, WORD_COUNT])
rnn_outputs, states = tf.nn.dynamic_rnn(lstm, input_data, dtype=tf.float32)
When run, the first line appears to succeed, while the second line raises an exception that looks unrelated to my own code and more like a bug in TensorFlow:
Traceback (most recent call last):
File "train.py", line 41, in <module>
imp.define_graph(glove_array)
File "/home/simshadows/git/cs9444_coursework/asst2stage2-sentiment-classifier/implementation.py", line 154, in define_graph
rnn_outputs, states = tf.nn.dynamic_rnn(lstm, input_data_expanded, dtype=tf.float32) # , time_major=False
File "/home/simshadows/.local/lib/python3.6/site-packages/tensorflow/python/ops/rnn.py", line 598, in dynamic_rnn
dtype=dtype)
File "/home/simshadows/.local/lib/python3.6/site-packages/tensorflow/python/ops/rnn.py", line 761, in _dynamic_rnn_loop
swap_memory=swap_memory)
File "/home/simshadows/.local/lib/python3.6/site-packages/tensorflow/python/ops/control_flow_ops.py", line 2775, in while_loop
result = context.BuildLoop(cond, body, loop_vars, shape_invariants)
File "/home/simshadows/.local/lib/python3.6/site-packages/tensorflow/python/ops/control_flow_ops.py", line 2604, in BuildLoop
pred, body, original_loop_vars, loop_vars, shape_invariants)
File "/home/simshadows/.local/lib/python3.6/site-packages/tensorflow/python/ops/control_flow_ops.py", line 2554, in _BuildLoop
body_result = body(*packed_vars_for_body)
File "/home/simshadows/.local/lib/python3.6/site-packages/tensorflow/python/ops/rnn.py", line 746, in _time_step
(output, new_state) = call_cell()
File "/home/simshadows/.local/lib/python3.6/site-packages/tensorflow/python/ops/rnn.py", line 732, in <lambda>
call_cell = lambda: cell(input_t, state)
File "/home/simshadows/.local/lib/python3.6/site-packages/tensorflow/python/ops/rnn_cell_impl.py", line 180, in __call__
return super(RNNCell, self).__call__(inputs, state)
File "/home/simshadows/.local/lib/python3.6/site-packages/tensorflow/python/layers/base.py", line 450, in __call__
outputs = self.call(inputs, *args, **kwargs)
File "/home/simshadows/.local/lib/python3.6/site-packages/tensorflow/python/ops/rnn_cell_impl.py", line 401, in call
concat = _linear([inputs, h], 4 * self._num_units, True)
File "/home/simshadows/.local/lib/python3.6/site-packages/tensorflow/python/ops/rnn_cell_impl.py", line 1021, in _linear
shapes = [a.get_shape() for a in args]
File "/home/simshadows/.local/lib/python3.6/site-packages/tensorflow/python/ops/rnn_cell_impl.py", line 1021, in <listcomp>
shapes = [a.get_shape() for a in args]
AttributeError: 'list' object has no attribute 'get_shape'
The problem in the code above is that tf.contrib.rnn.BasicLSTMCell() takes an int (num_units, the number of units in the LSTM cell) as its first positional argument, not a list. (See the __init__ arguments in this page of the documentation.)
The version after fixing the error (simplified) looks like the following:
lstm = tf.contrib.rnn.BasicLSTMCell(HIDDEN_SIZE)
rnn_outputs, states = tf.nn.dynamic_rnn(lstm, input_data, dtype=tf.float32)
An Explanation
Because of duck typing in Python, the bad BasicLSTMCell() call didn’t raise an exception: whatever operations the constructor performed on the bad argument happened to be supported by the list I passed in.
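To make the duck-typing point concrete, here is a toy sketch in plain Python. It is not TensorFlow’s actual code: ToyCell, gate_width and first_indices are hypothetical names. The 4 * x expression is only an illustration of an operation that both an int and a list happen to support. The constructor never checks its argument, so a bad value sails through and only fails later:

```python
class ToyCell:
    """Hypothetical stand-in for an RNN cell. num_units is documented as an int."""

    def __init__(self, num_units):
        # No validation: whatever is passed in gets stored.
        self.num_units = num_units
        # `4 * x` is supported by an int (multiplication) AND by a list
        # (repetition), so this line succeeds either way, with very
        # different meanings.
        self.gate_width = 4 * num_units

    def first_indices(self):
        # The bad value only causes a failure here: range() needs an int.
        return list(range(self.gate_width))[:3]


good = ToyCell(3)
print(good.gate_width)        # 12
print(good.first_indices())   # [0, 1, 2]

bad = ToyCell([2, 5])         # no exception at the bad call...
print(bad.gate_width)         # [2, 5, 2, 5, 2, 5, 2, 5]
try:
    bad.first_indices()       # ...the TypeError surfaces here instead
except TypeError as e:
    print("TypeError:", e)
```

The traceback from first_indices() says nothing about the constructor call, which is exactly where the real mistake was made.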
However, the fact that the bad argument was accepted doesn’t mean we can get away with it. A bad function call can break a class invariant, meaning that the state of the object is no longer guaranteed to be valid.
With the object’s state now potentially invalid, the behaviour of any later operation on or with that object is undefined, and it can raise strange exceptions that are difficult to trace back to the bad function call.
To illustrate what this means, let’s look at a simple example.
Example: Incrementor
Consider the following class definition:
class Incrementor:
    def __init__(self, value):
        # value is an int.
        self.value = value

    def change_value(self, value):
        # value is an int.
        self.value = value

    def inc_and_print(self):
        self.value += 1
        print(str(self.value))
The comments above document the intended usage of the methods. In particular, __init__() and change_value() require their argument to be an int.
And now consider the following sequence of calls:
>>> x = Incrementor(4)
>>> x.inc_and_print()
5
>>> x.inc_and_print()
6
Looks good so far. Note that the constructor call Incrementor(4) followed the documented usage by passing 4 as the argument, which is an int.
Now consider the following sequence of calls, continuing on from above:
>>> x.change_value("twelve")
>>> x.inc_and_print()
Traceback (most recent call last):
File "<stdin>", line 1, in <module>
File "/home/simshadows/work/tmp/myincrementor.py", line 10, in inc_and_print
self.value += 1
TypeError: must be str, not int
The first statement, x.change_value("twelve"), was a method call that passed a str instead of an int. Despite deviating from the documentation, the call succeeded because the only operation it performed, the assignment self.value = value, works just as well with a str.
However, x.change_value("twelve") broke an undocumented invariant that the author of the class definition relied on: namely, that self.value is always an int.
The next statement, x.inc_and_print(), then raised an exception because self.value += 1 relies on exactly this invariant: Python cannot add the int 1 to a str.
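One way to make the error surface at the bad call itself is to check the documented precondition on entry and raise immediately. Here is a minimal sketch; CheckedIncrementor is a hypothetical name, not part of the original example:

```python
class CheckedIncrementor:
    """Like Incrementor, but enforces the 'value is an int' invariant on entry."""

    def __init__(self, value):
        # Reuse the same check so the constructor can't break the invariant either.
        self.change_value(value)

    def change_value(self, value):
        # Fail fast: reject a non-int before it is ever stored.
        if not isinstance(value, int):
            raise TypeError(f"value must be an int, got {type(value).__name__}")
        self.value = value

    def inc_and_print(self):
        self.value += 1
        print(str(self.value))


x = CheckedIncrementor(4)
x.inc_and_print()              # prints 5
try:
    x.change_value("twelve")   # rejected immediately, at the bad call
except TypeError as e:
    print("TypeError:", e)     # TypeError: value must be an int, got str
```

Now the traceback points at change_value("twelve"), the call that actually broke the rules, rather than at inc_and_print(). The trade-off is extra checking code and giving up some of the flexibility that duck typing offers.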
What about statically typed languages?
Since I mentioned duck typing earlier, it may seem that this problem could be solved by using a statically typed language such as Java instead. And indeed, that would help: because the compiler checks at compile time that types are as expected, a great many errors like these are caught before the program ever runs.
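Python can get part of this benefit through optional type annotations checked by a separate static type checker such as mypy. Here is a minimal sketch, assuming such a checker is run over the file. The annotations themselves are ignored at runtime, so the bad call still executes under plain Python, but the checker would flag it before the program runs:

```python
class TypedIncrementor:
    """Incrementor with type annotations; Python does not enforce them at runtime."""

    def __init__(self, value: int) -> None:
        self.value = value

    def change_value(self, value: int) -> None:
        self.value = value

    def inc_and_print(self) -> None:
        self.value += 1
        print(str(self.value))


x = TypedIncrementor(4)
# A static type checker (e.g. mypy) reports an incompatible argument type
# on the next line. Running the file with plain `python` still stores the str.
x.change_value("twelve")
```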
However, it’s not a perfect solution, because a function’s preconditions can go beyond simple type checks. For example, a function might expect to be passed a list with more than 3 elements.
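As a sketch of a precondition that a type checker cannot express, consider a hypothetical function that skips three warm-up samples and averages the rest (mean_after_first_three is an illustrative name, not from any library). The type List[float] says nothing about length, so both calls below type-check. Only an explicit runtime check turns the bad call into a clear error at the call site:

```python
from typing import List


def mean_after_first_three(samples: List[float]) -> float:
    """Hypothetical: average samples[3:]. Precondition: len(samples) > 3."""
    # A type checker accepts any List[float] here, whatever its length.
    # Without this check, a short list would instead fail with a confusing
    # ZeroDivisionError that doesn't name the real problem.
    if len(samples) <= 3:
        raise ValueError(f"expected more than 3 samples, got {len(samples)}")
    rest = samples[3:]
    return sum(rest) / len(rest)


print(mean_after_first_three([0.0, 0.0, 0.0, 2.0, 4.0]))  # 3.0
try:
    mean_after_first_three([1.0, 2.0])  # type-correct, but breaks the precondition
except ValueError as e:
    print("ValueError:", e)  # ValueError: expected more than 3 samples, got 2
```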
Perhaps I’ll write another post someday going into more depth on this, since it’s a rabbit hole of software design that goes beyond the intended scope of this blog post.