Wednesday, September 30, 2015

Symbolic differentiation to the rescue! ... Or not.

I'm still playing with my LSTM networks, inspired by a few blog posts about their "unreasonable effectiveness". I spent a lot of time messing with genetic algorithms, but then I came back to more predictable methods, namely gradient descent. I was trying to optimize the performance of the cost function I use via automatic differentiation (AD), even asking for help on Stack Overflow, when I stumbled across a couple of blog posts on symbolic differentiation (here and here). The last one, combining automatic and symbolic differentiation, struck a chord. If my derivative was taking so much time to compute, could I not just compute it once as a symbolic expression, then evaluate that expression against my variables (my LSTM network) repeatedly while applying the gradients? I should only pay the price for the differentiation once!
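To make the idea concrete, here is a minimal sketch of "differentiate once, evaluate many times". It is written in Python for illustration and is not my actual code: the `Const`, `Var`, `Add` and `Mul` types, the `diff` and `evaluate` helpers, and the toy cost `(3w - 6)^2` are all assumptions made up for this example, and a real cost function needs many more operations.

```python
from dataclasses import dataclass
from typing import Dict, Union


@dataclass(frozen=True)
class Const:
    value: float


@dataclass(frozen=True)
class Var:
    name: str


@dataclass(frozen=True)
class Add:
    left: "Expr"
    right: "Expr"


@dataclass(frozen=True)
class Mul:
    left: "Expr"
    right: "Expr"


Expr = Union[Const, Var, Add, Mul]


def diff(e: Expr, wrt: str) -> Expr:
    """Build the symbolic derivative of e with respect to the variable wrt."""
    if isinstance(e, Const):
        return Const(0.0)
    if isinstance(e, Var):
        return Const(1.0 if e.name == wrt else 0.0)
    if isinstance(e, Add):
        return Add(diff(e.left, wrt), diff(e.right, wrt))
    if isinstance(e, Mul):
        # Product rule: (a * b)' = a' * b + a * b'
        return Add(Mul(diff(e.left, wrt), e.right),
                   Mul(e.left, diff(e.right, wrt)))
    raise TypeError(f"unsupported expression: {e!r}")


def evaluate(e: Expr, env: Dict[str, float]) -> float:
    """Evaluate an expression given concrete values for its variables."""
    if isinstance(e, Const):
        return e.value
    if isinstance(e, Var):
        return env[e.name]
    if isinstance(e, Add):
        return evaluate(e.left, env) + evaluate(e.right, env)
    if isinstance(e, Mul):
        return evaluate(e.left, env) * evaluate(e.right, env)
    raise TypeError(f"unsupported expression: {e!r}")


def gradient_descent(cost: Expr, name: str, start: float,
                     rate: float, steps: int) -> float:
    grad = diff(cost, name)  # symbolic derivative, built only once
    w = start
    for _ in range(steps):
        # Each step only evaluates the precomputed expression.
        w -= rate * evaluate(grad, {name: w})
    return w


# Toy cost (3w - 6)^2, whose minimum is at w = 2.
residual = Add(Mul(Var("w"), Const(3.0)), Const(-6.0))
cost = Mul(residual, residual)
w_final = gradient_descent(cost, "w", start=0.0, rate=0.05, steps=20)
print(w_final)
```

Note that `diff` does no simplification at all, so terms like `w * 0` stay in the derivative. That is harmless for a toy cost, but it matters a lot once the expression gets big.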

So I extended the data type suggested in the blog post to include all the operations I was using in my function, managed to sort out all the types, and verified with a few tests that it would work. I had great hopes! But as soon as I started testing on a real LSTM, the code slowed to a crawl. I even thought I had an infinite loop, maybe some recursion in the expression, but more thorough testing showed that the sheer size of the generated expression was the issue. An LSTM of 2 cells is represented in the cost function used by AD as an array of 44 doubles, so for 2 cells I have 44 variables in my gradient expression. My simple test, which tries to use an LSTM to generate the string "hello world!", uses 9 cells (one per distinct character in the string), which means 702 variables. Even printing the expression takes forever, and so does running it through a simplification step. So my idea was not as good as it first looked, but it was fun testing it!
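A toy illustration of the size problem, again as a Python sketch rather than my real code: the nested `tanh(w * ...)` chain below is only a made-up stand-in for an unrolled recurrent computation, and the tuple encoding of expressions is an assumption chosen to keep the example short. It counts the nodes of an expression and of its naive symbolic derivative with respect to a single variable.

```python
def tanh_chain(depth):
    """Build tanh(w * tanh(w * ... tanh(w * x))) nested `depth` times."""
    expr = ("var", "x")
    for _ in range(depth):
        expr = ("tanh", ("mul", ("var", "w"), expr))
    return expr


def diff(e, wrt):
    """Naive symbolic derivative, with no simplification and no sharing."""
    tag = e[0]
    if tag == "const":
        return ("const", 0.0)
    if tag == "var":
        return ("const", 1.0 if e[1] == wrt else 0.0)
    if tag == "add":
        return ("add", diff(e[1], wrt), diff(e[2], wrt))
    if tag == "mul":
        # Product rule duplicates both operands in the result.
        return ("add",
                ("mul", diff(e[1], wrt), e[2]),
                ("mul", e[1], diff(e[2], wrt)))
    if tag == "tanh":
        # tanh(u)' = (1 - tanh(u)^2) * u', which copies tanh(u) twice.
        one_minus_sq = ("add", ("const", 1.0),
                        ("mul", ("const", -1.0), ("mul", e, e)))
        return ("mul", one_minus_sq, diff(e[1], wrt))
    raise ValueError(f"unsupported expression tag: {tag}")


def size(e):
    """Count the nodes of an expression tree."""
    if e[0] in ("var", "const"):
        return 1
    return 1 + sum(size(child) for child in e[1:])


# Map each depth to (size of the expression, size of its derivative in w).
sizes = {}
for depth in (1, 5, 10, 20):
    f = tanh_chain(depth)
    sizes[depth] = (size(f), size(diff(f, "w")))
    print(depth, sizes[depth])
```

In this sketch the expression grows linearly with the nesting depth while its derivative grows much faster, because every rule copies whole sub-expressions. And that is the derivative for just one variable: a full gradient needs one such expression per variable, so 702 of them for my 9-cell network.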

The code for my expressions can be found here, and the code doing the gradient descent via symbolic differentiation is here. All of that probably looks very naive to you calculus and machine learning experts, but hey, I'm a human learning... If anybody has any ideas to speed up my code, I'll happily listen!