Natural Neural Networks

Guillaume Desjardins, Karen Simonyan, Razvan Pascanu, Koray Kavukcuoglu

We introduce Natural Neural Networks, a novel family of algorithms that speed up convergence by adapting their internal representation during training to improve conditioning of the Fisher matrix. In particular, we show a specific example that employs a simple and efficient reparametrization of the neural network weights by implicitly whitening the representation obtained at each layer, while preserving the feed-forward computation of the network. Such networks can be trained efficiently via the proposed Projected Natural Gradient Descent algorithm (PRONG), which amortizes the cost of these reparametrizations over many parameter updates and is closely related to the Mirror Descent online learning algorithm. We highlight the benefits of our method on both unsupervised and supervised learning tasks, and showcase its scalability by training on the large-scale ImageNet Challenge dataset.
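For concreteness, the whitened reparametrization the abstract describes can be written for a single layer as follows (layer indices are dropped, and the notation $V$, $U$, $c$, $d$ is an assumption chosen to match the discussion below, not a quote from the paper):

$$
h = f(Wx + b) \quad\longrightarrow\quad h = f\big(V\,U(x - c) + d\big),
$$

which leaves the feed-forward computation unchanged whenever $W = VU$ and $b = d - Wc$. Here $U$ and $c$ whiten and center the layer's input, $V$ and $d$ are the parameters updated by gradient descent, and the whitening statistics are refreshed only periodically, which is what amortizes the cost of the reparametrization over many parameter updates.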

  • Added 2 years ago by Marc'Aurelio Ranzato

  • 3 Comments

Discussion

  • Soumith Chintala wrote 2 years ago

  • Public

A request to the authors: please report the results of PRONG (and not PRONG+) on ImageNet, even if the results are not great. I would like to decorrelate the performance of PRONG from BatchNorm.

  • Olivier Grisel wrote 2 years ago

  • Public

I agree with Soumith's comment. In particular, footnote 6 says: "This instability may have been compounded by momentum, which was initially not reset after each model reparametrization when using standard PRONG." It would be very interesting if they could re-run the experiment with the "correct" handling of momentum and report whether that fixes the instability they observed with vanilla PRONG on ImageNet.

  • Tim Cooijmans wrote 2 years ago

  • Public

The algorithm description contains two errors:

  • On line 7, b_i = d_i - W_i c_i (that is, minus, not plus).

  • On line 10, it is not c_i that should be updated but d_i, and the update should read d_i <- b_i + V_i U_{i-1} c_i (i.e. plus, not minus, and c_i should be premultiplied by U_{i-1} first).
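To make the two corrections concrete, here is a minimal NumPy sketch of a whitened layer and its amortized reparametrization step, assuming the layer computes f(V U (x - c) + d) with U and c whitening and centering the input. The function and variable names are illustrative, and this is a sketch of the corrected updates, not a verbatim transcription of the paper's Algorithm 1.

```python
import numpy as np


def canonical_params(V, U, c, d):
    """Recover the canonical (unwhitened) layer parameters.
    Corrected line 7: b = d - W c (minus, not plus)."""
    W = V @ U
    b = d - W @ c
    return W, b


def reparametrize(V, U, c, d, acts, eps=1e-3):
    """Amortized reparametrization step: re-estimate whitening statistics
    from a batch of layer inputs `acts` (rows are samples), then update
    V and d so the layer's input-output function is unchanged."""
    W, b = canonical_params(V, U, c, d)

    # Fresh centering vector and whitening matrix U_new ~ cov^{-1/2}.
    c_new = acts.mean(axis=0)
    cov = np.cov(acts, rowvar=False) + eps * np.eye(acts.shape[1])
    evals, evecs = np.linalg.eigh(cov)
    U_new = evecs @ np.diag(evals ** -0.5) @ evecs.T

    # Keep V_new @ U_new equal to the old canonical weight W ...
    V_new = W @ np.linalg.inv(U_new)
    # ... and apply the corrected line 10: d_new = b + (V_new U_new) c_new,
    # i.e. plus, not minus, with c_new premultiplied by the new whitening
    # transform (V_new @ U_new == W by construction).
    d_new = b + W @ c_new
    return V_new, U_new, c_new, d_new


def forward(x, V, U, c, d):
    """Pre-activation of a whitened layer: V (U (x - c)) + d."""
    return V @ (U @ (x - c)) + d


# Sanity check: reparametrizing with fresh statistics preserves the function.
rng = np.random.default_rng(0)
n_in, n_out = 5, 3
V = rng.normal(size=(n_out, n_in))
U, c, d = np.eye(n_in), np.zeros(n_in), rng.normal(size=n_out)
acts = rng.normal(size=(100, n_in)) * 2.0 + 1.0
V2, U2, c2, d2 = reparametrize(V, U, c, d, acts)
x = rng.normal(size=n_in)
assert np.allclose(forward(x, V, U, c, d), forward(x, V2, U2, c2, d2))
```

The final assert checks exactly the property the two corrected lines are meant to guarantee: refreshing the whitening statistics and updating V and d leaves the layer's output on any input unchanged.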