What is the Appropriate Loss Function for Modeling Neural Responses?

This question has bothered me occasionally over the years, and I’ve never really arrived at what I felt was a satisfactory answer. The choice of a loss function is really a choice about the assumed noise distribution. A mean squared error loss function, for example, corresponds to an assumption that the noise in the system being modeled is Gaussian. Noise in the firing rate of neurons is often assumed to follow a Poisson distribution:

P\left( x;\lambda \right) =\dfrac {\lambda ^{x}e^{-\lambda }} {x!}

In the Bayesian framework, the noise distribution is also known as the likelihood. Converting a likelihood to a loss function is just a matter of taking its negative log: loss=-\log\left( P\left( x;\lambda \right) \right)

Writing this in terms of the measured firing rate, r, and the predicted firing rate, \hat {r}, and dropping the \log\left( r!\right) term, which does not depend on the prediction, we get the loss function typically used to fit models to neural data:

r=x,\hat {r}=\lambda
loss\left( r,\hat {r}\right) =\hat {r}-r\log \left(\hat {r}\right)
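A minimal Python sketch of this loss, assuming r is a spike count in one time bin and the prediction \hat {r} is strictly positive (the function names are illustrative, not from any library). It also checks that the dropped term, \log\left( r!\right), is the same for every prediction and therefore cannot move the minimum:

```python
import math


def poisson_loss(r, r_hat):
    """Poisson loss with the log(r!) term dropped: r_hat - r*log(r_hat).

    r     -- observed spike count in a time bin (non-negative integer)
    r_hat -- model's predicted rate for that bin (must be > 0)
    """
    return r_hat - r * math.log(r_hat)


def poisson_nll(r, r_hat):
    """Full negative log of P(r; r_hat) = r_hat**r * exp(-r_hat) / r!."""
    return r_hat - r * math.log(r_hat) + math.lgamma(r + 1)


# The two differ only by log(r!), which does not depend on r_hat,
# so both are minimised by the same prediction (r_hat = r).
r = 3
for r_hat in (0.5, 2.0, 3.0, 7.5):
    print(f"{poisson_nll(r, r_hat) - poisson_loss(r, r_hat):.4f}")  # 1.7918 = log(3!) each time
```

Setting the derivative 1 - r/\hat {r} to zero gives \hat {r}=r, which is why the Poisson loss is unbiased in the same sense as squared error.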

The use of the Poisson distribution as a noise model carries with it the assumption that, within a given time window, the variance of the firing rate is equal to its mean. The parameter that determines both the mean and the variance in the Poisson distribution is \lambda. So the Poisson distribution says that the higher the average firing rate in response to a given stimulus, the higher the variance of the firing rate in response to that stimulus. However, because the signal-to-noise ratio (SNR) is the ratio of the mean to the standard deviation, SNR still increases as the mean firing rate \lambda increases:

SNR=\dfrac {\mu } {\sigma }=\dfrac {\lambda } {\sqrt {\lambda }}=\sqrt {\lambda }

This all makes sense: we would want and expect higher firing rates to have higher SNR, since otherwise it would be hard to transmit information and high firing rates would be a tremendous waste of energy. But the relevant questions are:
1) Are the mean and variance of the firing rates of real neurons coupled as predicted by the Poisson distribution?
2) And if not, what are the consequences of the difference and how can we fix it?

The answer to the first question is a definite NO. For neurons driven by natural stimuli, responses are clearly sub-Poisson: the variance of the firing rate is lower than its mean. For example, here are plots of the mean vs. the variance of the firing rate in each of 2,000 time bins (16.6 ms each) for two neurons in V1. The mean and variance seem to be linearly related, but the slope is clearly less than 1. The slope in each plot is essentially the Fano factor for each neuron. The parameter \alpha, which we will make use of shortly, is the inverse of the Fano factor, estimated by taking the median of \dfrac{mean}{variance} across all bins.
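The \alpha estimate described above might be computed as follows. This is a sketch under stated assumptions: responses are stored in a hypothetical trials × bins array of spike counts to repeated presentations of the same stimulus, the across-trial variance uses ddof=1, and bins with zero variance are skipped because \dfrac{mean}{variance} is undefined there. The toy counts are made up purely to exercise the code:

```python
import numpy as np


def estimate_alpha(counts):
    """Estimate alpha (inverse Fano factor) as the median of mean/variance across bins.

    counts -- array of shape (n_trials, n_bins); counts[i, j] is the spike
              count on trial i in time bin j (assumed layout).
    """
    counts = np.asarray(counts, dtype=float)
    mean = counts.mean(axis=0)             # across-trial mean per bin
    var = counts.var(axis=0, ddof=1)       # unbiased across-trial variance per bin
    ok = var > 0                           # mean/var is undefined where var == 0
    return float(np.median(mean[ok] / var[ok]))


# Hypothetical toy data: 3 trials x 4 bins (the last bin has zero variance).
counts = [[2, 0, 5, 1],
          [4, 1, 5, 1],
          [3, 2, 8, 1]]
print(estimate_alpha(counts))  # 2.0, i.e. a Fano factor of 0.5
```

Using the median rather than the mean of the per-bin ratios keeps a few low-variance bins, where the ratio blows up, from dominating the estimate.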

So, higher firing rates have even higher SNR than predicted by the Poisson distribution. This is good for energy use in the brain, but it seems to be bad for the standard Poisson loss function, which may give too little weight to high-firing-rate time bins given how reliable they actually are. Models fit with the Poisson loss function could thus be more influenced by low-firing-rate time bins, and less influenced by high-firing-rate time bins, than they should be, given that the noise is actually sub-Poisson.

So how can we fix this? Can we modify the Poisson distribution with an extra parameter,  \alpha, so that the variance is coupled to the mean like:
E\left( X\right) =\lambda
Var\left( X\right) =\dfrac {\lambda}{\alpha}
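Carrying this coupling through the SNR calculation from earlier shows what it buys: for \alpha>1, the SNR at every rate is larger than the Poisson value \sqrt {\lambda } by a factor of \sqrt {\alpha }.

SNR=\dfrac {\mu } {\sigma }=\dfrac {\lambda } {\sqrt {\lambda /\alpha }}=\sqrt {\alpha \lambda }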

After playing with the equation for the Poisson distribution for a while, I worked out this approximate distribution function:

P\left( x;\lambda ,\alpha \right) \approx \dfrac {\alpha \left( \alpha \lambda \right) ^{\alpha x}e^{-\alpha \lambda }} {\Gamma \left( \alpha x+1\right) }

This function isn’t an exact distribution function: it doesn’t sum to exactly 1 for every setting of \alpha and \lambda. It is, however, a very good approximation, and it has the properties E\left( X\right)\approx\lambda and Var\left( X\right) \approx \dfrac {\lambda}{\alpha}. Furthermore, when \alpha=1 this function reduces to the Poisson distribution, since \Gamma \left( x+1\right) =x! for all non-negative integers x.
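These properties can be checked numerically. The sketch below assumes the function is evaluated at x = 0, 1, 2, … (as for the Poisson) and truncated at a hypothetical cutoff x_max; working in log space with math.lgamma keeps \Gamma \left( \alpha x+1\right) from overflowing:

```python
import math


def approx_pmf(x, lam, alpha):
    """Approximate PMF alpha*(alpha*lam)**(alpha*x)*exp(-alpha*lam)/Gamma(alpha*x+1).

    Evaluated in log space (math.lgamma) so large x does not overflow.
    Assumes lam > 0 and alpha > 0.
    """
    log_p = (math.log(alpha) + alpha * x * math.log(alpha * lam)
             - alpha * lam - math.lgamma(alpha * x + 1))
    return math.exp(log_p)


def moments(lam, alpha, x_max=200):
    """Total mass, mean and variance over x = 0..x_max (x_max is an assumed cutoff)."""
    p = [approx_pmf(x, lam, alpha) for x in range(x_max + 1)]
    total = sum(p)
    mean = sum(x * px for x, px in enumerate(p)) / total
    var = sum((x - mean) ** 2 * px for x, px in enumerate(p)) / total
    return total, mean, var


print(moments(5.0, 2.0))  # roughly (1.0, 5.0, 2.5): mass ~1, mean ~lambda, variance ~lambda/alpha
print(moments(4.0, 1.0))  # alpha = 1 is the ordinary Poisson: (1.0, 4.0, 4.0) up to rounding
```

For \alpha=2 the match is very close because the function is then 2 times a Poisson(2\lambda) PMF restricted to even counts, which sums to 1+e^{-4\lambda}.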

This is nice, but it seemed a bit contrived to me. While trying to find some connection to the literature, I came across this paper which generalizes the Poisson distribution using the Mittag-Leffler function:

E_{\alpha ,\beta }\left( \lambda \right) =\sum _{k=0}^{\infty }\dfrac {\lambda ^{k}} {\Gamma \left( \alpha k+\beta \right) } 

The Mittag-Leffler function generalizes the exponential function; the exponential corresponds to setting \alpha = \beta =1. And so to generalize the Poisson distribution, the authors proposed:

P\left( x;\lambda, \alpha, \beta \right) =\dfrac {\lambda ^{x}} {E_{\alpha, \beta}(\lambda)\Gamma(\alpha x + \beta)}
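To make the two definitions above concrete, here is a naive sketch that sums a truncated Mittag-Leffler series and evaluates the paper’s generalized Poisson PMF. The truncation length and the restriction to moderate positive arguments are assumptions of this sketch, not a robust numerical method. With \alpha = \beta =1 it reproduces the exponential and the ordinary Poisson PMF; with \alpha=2 the mean clearly departs from \lambda:

```python
import math


def mittag_leffler(z, alpha, beta=1.0, n_terms=400):
    """Truncated series E_{alpha,beta}(z) = sum_k z**k / Gamma(alpha*k + beta).

    Naive sketch: assumes z > 0 and a moderate z so the terms neither
    overflow nor need more than n_terms to converge.
    """
    return sum(math.exp(k * math.log(z) - math.lgamma(alpha * k + beta))
               for k in range(n_terms))


def ml_poisson_pmf(x, lam, alpha, beta=1.0):
    """The paper's generalized Poisson: lam**x / (E_{alpha,beta}(lam) * Gamma(alpha*x + beta))."""
    log_num = x * math.log(lam) - math.lgamma(alpha * x + beta)
    return math.exp(log_num) / mittag_leffler(lam, alpha, beta)


# alpha = beta = 1 gives the exponential and hence the ordinary Poisson PMF.
print(mittag_leffler(2.0, 1.0), math.exp(2.0))
print(ml_poisson_pmf(3, 2.5, 1.0), 2.5**3 * math.exp(-2.5) / 6)

# With alpha = 2 the mean is no longer lam: about 1.09 for lam = 5.
mean = sum(x * ml_poisson_pmf(x, 5.0, 2.0) for x in range(100))
print(mean)
```

For \alpha=2 the mean has the closed form \dfrac{\sqrt {\lambda }}{2}\tanh \left( \sqrt {\lambda }\right), using E_{2,1}\left( z\right) =\cosh \left( \sqrt {z}\right).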

This formulation, however, breaks the connection between \lambda and E\left( X\right), as well as any simple connection to Var\left( X\right). I found that making the replacement \lambda \rightarrow \left( \alpha \lambda \right) ^{\alpha } restores the connections E\left( X\right)\approx\lambda and Var\left( X\right) \approx \dfrac {\lambda}{\alpha}, at least when \beta=1. I therefore propose the following distribution as a more user-friendly generalization of the Poisson distribution based on the Mittag-Leffler function:

P\left( x;\lambda, \alpha \right) =\dfrac {(\alpha \lambda)^{\alpha x}} {E_{\alpha, 1}((\alpha \lambda)^{\alpha})\Gamma(\alpha x + 1)}
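A sketch of this exact distribution, reusing the same naive truncated series. That series is an assumption which only holds up for moderate \alpha \lambda, since its terms grow to roughly e^{\alpha \lambda }. The code checks the claimed moments numerically:

```python
import math


def mittag_leffler(z, alpha, beta=1.0, n_terms=400):
    """Truncated series E_{alpha,beta}(z) = sum_k z**k / Gamma(alpha*k + beta).

    Naive sketch: assumes z > 0 and moderate alpha*lam so terms stay finite.
    """
    return sum(math.exp(k * math.log(z) - math.lgamma(alpha * k + beta))
               for k in range(n_terms))


def exact_pmf(x, lam, alpha):
    """(alpha*lam)**(alpha*x) / (E_{alpha,1}((alpha*lam)**alpha) * Gamma(alpha*x + 1))."""
    log_num = alpha * x * math.log(alpha * lam) - math.lgamma(alpha * x + 1)
    return math.exp(log_num) / mittag_leffler((alpha * lam) ** alpha, alpha)


def mean_var(lam, alpha, x_max=200):
    """Total mass, mean and variance of exact_pmf over x = 0..x_max."""
    p = [exact_pmf(x, lam, alpha) for x in range(x_max + 1)]
    mean = sum(x * px for x, px in enumerate(p))
    var = sum((x - mean) ** 2 * px for x, px in enumerate(p))
    return sum(p), mean, var


print(mean_var(5.0, 2.0))  # mass ~1, mean ~5 (= lambda), variance ~2.5 (= lambda/alpha)
```

Unlike the approximate function, the mass here is 1 by construction; the moments are what remain approximate.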

This exact distribution is related to the approximate distribution function presented earlier because, as it turns out, the Mittag-Leffler function can in this case be well approximated as:

E_{\alpha ,1}\left( \left( \alpha \lambda \right) ^{\alpha }\right) \approx \dfrac {e^{\alpha \lambda }} {\alpha }
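The quality of this approximation is easy to check numerically by dividing the truncated series by \dfrac {e^{\alpha \lambda }} {\alpha }. For \alpha=2 there is also a closed form: since E_{2,1}\left( z\right) =\cosh \left( \sqrt {z}\right), the ratio is exactly 1+e^{-4\lambda}. The sketch sticks to \alpha \geq 1, the sub-Poisson case of interest here, and again assumes moderate arguments:

```python
import math


def mittag_leffler(z, alpha, beta=1.0, n_terms=400):
    """Truncated series E_{alpha,beta}(z); assumes z > 0 and moderate."""
    return sum(math.exp(k * math.log(z) - math.lgamma(alpha * k + beta))
               for k in range(n_terms))


def normaliser_ratio(lam, alpha):
    """E_{alpha,1}((alpha*lam)**alpha) divided by its approximation exp(alpha*lam)/alpha."""
    exact = mittag_leffler((alpha * lam) ** alpha, alpha)
    return exact / (math.exp(alpha * lam) / alpha)


for alpha in (1.0, 1.5, 2.0):
    print(alpha, normaliser_ratio(5.0, alpha))  # each ratio is very close to 1
```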

Making this substitution returns the approximate distribution function above. Now that this approximation has a little more mathematical grounding, we can return to the question of appropriate loss functions for neurons. For reference, let’s first plot the squared error loss function. We see that this loss function is minimized when a model’s prediction \lambda and the data x are equal.

[Figure: squared error loss]

We can turn both the exact distribution function and the approximate distribution function described above into loss functions by taking their negative log. Let’s now plot the exact distribution function’s negative log likelihood when \alpha=2. The 1-to-1 line where a model’s prediction \lambda and the data x are equal is plotted in orange, and the minimum of the function, given the data, is plotted in red. If the loss function were unbiased, the red line would be hidden behind the orange line, but that is not the case for small values: given small values of x, the loss function is actually minimized by slightly higher values of \lambda. This loss function also contains an infinite sum, so it is rather inconvenient to work with.

[Figure: negative log likelihood of the exact distribution, \alpha=2]

Now let’s plot the approximate distribution function’s negative log likelihood when \alpha=2.

[Figure: negative log likelihood of the approximate distribution, \alpha=2]

This loss function is nearly identical, but it is unbiased. It’s also much easier to deal with, since there’s no infinite summation. However, after dropping constants, scaling factors, and terms that don’t depend on the model parameters, we get:

loss\left( r,\hat {r}\right) =\hat {r}-r\log \left(\hat {r}\right)

which is just the standard Poisson loss function!
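The reduction can be verified directly: the negative log of the approximate distribution equals \alpha times the Poisson loss plus a term that depends only on x and \alpha. The sketch below uses illustrative function names and treats \alpha as a fixed constant rather than a fitted parameter; it prints that leftover term for several values of \lambda:

```python
import math


def approx_nll(x, lam, alpha):
    """-log of the approximate PMF alpha*(alpha*lam)**(alpha*x)*exp(-alpha*lam)/Gamma(alpha*x+1)."""
    return -(math.log(alpha) + alpha * x * math.log(alpha * lam)
             - alpha * lam - math.lgamma(alpha * x + 1))


def poisson_loss(r, r_hat):
    """Standard Poisson loss r_hat - r*log(r_hat)."""
    return r_hat - r * math.log(r_hat)


# Subtracting alpha times the Poisson loss leaves a term that depends on x and
# alpha but not on lam, so both losses have the same minimiser in lam.
x, alpha = 4, 2.0
for lam in (1.0, 3.0, 4.0, 9.0):
    print(f"{approx_nll(x, lam, alpha) - alpha * poisson_loss(x, lam):.6f}")  # 4.366278 each time
```

Because \alpha only rescales the loss and the leftover term is constant in \lambda, the minimizer over the model’s predictions is the same as with the standard Poisson loss.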

This was surprising to me at first, so I investigated a related model with a similar linear coupling between the mean and variance: the quasi-Poisson model. (The quasi-Poisson model also has a corresponding approximate distribution function, which I found is actually less accurate in most situations than the one proposed here, but as we’ll see, that’s not really worth getting into.) The mean and variance in the quasi-Poisson model are defined as:

E\left( Y\right) =\mu
Var\left( Y\right) =\theta \mu

A quasi-Poisson GLM is then defined as:

\mu =g^{-1}\left( X\beta \right)

with X as the independent variables, \beta as the weights, and g as the link function; here g is the natural log, so its inverse g^{-1} is the exponential function. The weights of the quasi-Poisson model are found using iteratively weighted least squares (IWLS) with the update rule:

\hat {\beta }^{\left[ j+1\right] }=\left( X'W^{\left[ j\right] }X\right)^{-1}X'W^{\left[ j\right] }\overline {y}^{\left[ j\right] }

where \overline {y}^{\left[ j\right] } is the working (adjusted) response at iteration j, and the weighting function is:

W=diag\left( \dfrac {\mu_{1}} {\theta }\ldots \dfrac {\mu_{n}} {\theta }\right)=\dfrac {1}{\theta} diag\left(\mu_{1}\ldots\mu_{n}\right)

However, notice that when you substitute W into the update rule, all the \theta terms cancel, because the factor \dfrac {1}{\theta} appears once inside \left( X'WX\right)^{-1} and once in X'W. Thus the weights \beta found in a quasi-Poisson model are independent of \theta and identical to those found with a standard Poisson model! So again it seems that as long as the variance is proportional to the mean, the standard Poisson loss function is appropriate. Given that the mean and variance of the firing rate seem linearly related in neurons, we can continue using the standard Poisson loss function when fitting models, but with a little more confidence that we’re using the appropriate loss function for the job.
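The cancellation of \theta can be demonstrated with a minimal IWLS implementation for a log-link quasi-Poisson GLM. This is a sketch, not a production fitter: the starting values (\mu = y + 0.5), the fixed iteration count, and the toy data are all assumptions for illustration. Fitting the same data with two different values of \theta gives the same coefficients:

```python
import numpy as np


def quasi_poisson_iwls(X, y, theta, n_iter=50):
    """Fit a log-link quasi-Poisson GLM, mu = exp(X @ beta), by IWLS.

    Var(Y) = theta * mu, so the weights are W = diag(mu / theta).
    """
    y = np.asarray(y, dtype=float)
    mu = y + 0.5                      # common starting values (avoids log(0))
    eta = np.log(mu)
    for _ in range(n_iter):
        z = eta + (y - mu) / mu       # working response (y-bar in the update rule)
        w = mu / theta                # diagonal of W
        XtW = X.T * w                 # X'W: scales column i of X' by w[i]
        beta = np.linalg.solve(XtW @ X, XtW @ z)
        eta = X @ beta
        mu = np.exp(eta)
    return beta


# Hypothetical toy data: an intercept plus one covariate.
X = np.column_stack([np.ones(6), np.arange(6.0)])
y = np.array([1, 2, 2, 4, 6, 9])
b1 = quasi_poisson_iwls(X, y, theta=1.0)   # ordinary Poisson fit
b2 = quasi_poisson_iwls(X, y, theta=0.4)   # sub-Poisson variance
print(b1, b2, np.allclose(b1, b2))          # identical coefficients, so the last value is True
```

\theta still matters for standard errors (they scale with \sqrt {\theta }), just not for the point estimates.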

 

Interpreting Deep Residual Learning Blocks as Locally Recurrent Connections

Microsoft Research Asia (MSRA) recently blew everyone away with their results on the ImageNet and COCO datasets. If you haven’t yet seen the work, check out the paper here, and especially some of the examples in the presentation here (note: I’ve borrowed some of their figures below). Computer vision has come a long way, color me impressed!

Their basic claim is that a deeper network should in principle be able to learn anything a shallower network can learn: if the additional layers simply performed an identity mapping, the deeper network would be functionally identical to a shallower network. However, they show empirically that increasing network depth beyond a certain point makes deeper networks perform worse than shallower ones. Despite modern optimization techniques like batch normalization and Adam, it seems ultra-deep networks with a standard architecture are fundamentally hard to train with gradient methods.

[Figure: CIFAR results]

It would seem that, for standard architectures and training methods, we’ve passed the point of diminishing returns and started to regress. The theoretical benefits of increased depth will never be realized unless we do something differently. We must go deeper!

The MSRA paper makes a simple proposal based on their insight that the additional layers in deeper networks need only perform identity transforms to perform as well as shallower networks. Because the deeper networks seem to have a hard time discovering the identity transform on their own, they simply build the identity transform in! The layers in the neural network now learn a “residual” function F(x) to add to the identity transform. To perform only an identity transform, the network just needs to drive the weights in F(x) to zero. The basic building block of their Deep Residual Learning network is:

[Figure: residual unit]
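In code, the building block reduces to adding the block’s input to the output of a small learned function. This simplified sketch is an assumption-laden stand-in for the paper’s block: dense matrices replace convolutions, and batch normalization and the ReLU applied after the addition are left out. It shows that zeroing the output weights of F yields an exact identity mapping:

```python
import numpy as np


def relu(v):
    return np.maximum(v, 0.0)


def residual_block(x, W1, W2):
    """Simplified residual unit: output = x + F(x), with F(x) = W2 @ relu(W1 @ x).

    Dense layers stand in for the paper's convolutions; batch normalization and
    the ReLU applied after the addition are omitted to keep the identity path explicit.
    """
    return x + W2 @ relu(W1 @ x)


rng = np.random.default_rng(0)
x = rng.standard_normal(8)
W1 = rng.standard_normal((8, 8))
W_zero = np.zeros((8, 8))

# Driving F's output weights to zero turns the block into the identity mapping.
print(np.allclose(residual_block(x, W1, W_zero), x))  # True
```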

A similar line of reasoning also led to the recently proposed Highway networks. The main difference is that Highway networks have an additional set of weights that control the switching between, or mixing of, x and F(x).
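For comparison, a Highway layer computes y = H(x)·T(x) + x·(1 - T(x)), where H(x) plays the role of F(x) and the transform gate T(x) = sigmoid(W_T x + b_T) supplies the additional weights. The sketch is illustrative only: the ReLU in H and the toy parameters are choices made here, not taken from either paper. Pushing the gate bias strongly negative or positive makes the layer pass through x or H(x), respectively:

```python
import numpy as np


def sigmoid(v):
    return 1.0 / (1.0 + np.exp(-v))


def highway_layer(x, W_H, b_H, W_T, b_T):
    """Highway layer: y = H(x) * T(x) + x * (1 - T(x)).

    T(x) is the learned transform gate; (W_T, b_T) are the extra weights that
    set how much of H(x) versus the untouched input x is passed on.
    """
    H = np.maximum(W_H @ x + b_H, 0.0)   # candidate transform (ReLU chosen for illustration)
    T = sigmoid(W_T @ x + b_T)           # gate in (0, 1), one value per unit
    return H * T + x * (1.0 - T)


x = np.array([1.0, -2.0, 0.5])
W_H, b_H = np.eye(3), np.zeros(3)
W_T = np.zeros((3, 3))
print(highway_layer(x, W_H, b_H, W_T, np.full(3, -30.0)))  # gate ~0: close to x
print(highway_layer(x, W_H, b_H, W_T, np.full(3, 30.0)))   # gate ~1: close to relu(x)
```

A residual block corresponds to always passing x through and always adding F(x), with no learned gate deciding the mix.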

My first reaction to the residual learning framework was “that’s an interesting hack, I’m amazed it seems to work as well as it does”. Now don’t get me wrong, I love a simple and useful hack (*cough* dropout *cough* batch normalization *cough*) as much as the next neural net aficionado. But on further consideration, it occurred to me that there is an interesting way to look at what is going on in the residual learning blocks in terms of theoretical neuroscience.

Below is a cartoon model of some of the basic computations believed to take place within a cortical area of the brain. (A couple of examples of research in this area can be found here and here.) The responses of pyramidal cells, the main output cells of cortex (shown as triangles), are determined by their input as well as by modulations due to locally recurrent interactions with inhibitory cells (shown as circles) and with each other. I hope the analogy I’m making to the components of the residual learning block is made clear by the color coding. Basically, the initial activation, shown in red and indicated by x, is due to the input. This initial activation triggers the recurrent connections, which are a nonlinear function of the initial activation, shown in blue and represented by F(x). The final output, shown in purple, is simply the sum of the input-driven activity, x, and the recurrently driven activity, F(x).

[Figure: cartoon of locally recurrent processing within a cortical area]

The residual learning blocks can thus be thought of as implementing locally recurrent processing, perhaps analogously to how it happens in the brain! The input to a brain area/residual learning block is processed recurrently before being passed to the next brain area/residual learning block. Obviously, the usual caveats apply: processing in the brain is dynamic in time and much more complicated and nonlinear, there is no account of feedback here, etc. However, I think this analogy might be a useful, and biologically plausible, way to understand the success of deep residual learning.