Boosting Machine Learning Performance With Rust (Part 2)
Experimenting with Convolutional Neural Networks (CNNs) from scratch in Rust.
In my previous article (Part 1), I started an experiment to develop a machine learning framework in Rust from scratch. The main aim was to gauge the model training speed improvements that can be attained by using Rust in conjunction with PyTorch, compared with a Python equivalent. Results were very encouraging for feedforward networks. In this article I build on that work, with the main objective of being able to define and train Convolutional Neural Networks (CNNs). As in the previous article, I continue to use the tch-rs Rust crate as a wrapper around LibTorch, the PyTorch C++ library, primarily to access its tensor, linear algebra, and autograd functions; the rest is developed from scratch. The code for Parts 1 and 2 is now available on GitHub (Link).
Because the model definition is kept as close as possible to its Python equivalent, Listing 1 above should be quite intuitive for Python-PyTorch users. In the MyModel struct we can now add Conv2D layers and then instantiate them in the associated function new. In the Compute trait implementation, the forward function is defined; it passes the input through all the layers, including the intermediate MaxPooling function. In the main function (Listing 2), as in the previous article, we train the model and apply it to the MNIST dataset.
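To make the forward pass more concrete, the following is a minimal sketch in plain Rust (no tch-rs calls) of the two pieces of arithmetic behind it: the standard output-size formula for a convolution or pooling window, and a naive 2x2 max pooling over a single-channel image. The function names out_size and max_pool_2x2 are hypothetical and are not taken from the article's framework, and the 5x5 kernel size used in the usage note is an assumption, since the article does not state the layer hyperparameters.

```rust
/// Output length along one spatial axis for a convolution or pooling window.
/// Standard formula: floor((input + 2 * padding - kernel) / stride) + 1.
/// Assumes kernel <= input + 2 * padding and stride >= 1 (hypothetical helper).
fn out_size(input: usize, kernel: usize, stride: usize, padding: usize) -> usize {
    (input + 2 * padding - kernel) / stride + 1
}

/// Naive 2x2, stride-2 max pooling over one channel stored in row-major order.
/// Assumes `data.len() == height * width`; odd trailing rows/columns are dropped.
fn max_pool_2x2(data: &[f32], height: usize, width: usize) -> Vec<f32> {
    let (out_h, out_w) = (height / 2, width / 2);
    let mut out = Vec::with_capacity(out_h * out_w);
    for r in 0..out_h {
        for c in 0..out_w {
            // Index of the top-left element of the current 2x2 window.
            let top = 2 * r * width + 2 * c;
            // Index of the element directly below it.
            let bottom = top + width;
            let m = data[top]
                .max(data[top + 1])
                .max(data[bottom])
                .max(data[bottom + 1]);
            out.push(m);
        }
    }
    out
}
```

For a 28x28 MNIST image and an assumed 5x5 convolution with stride 1 and no padding, out_size(28, 5, 1, 0) gives 24; a 2x2 max pooling with stride 2 then halves that to 12. Tracing shapes this way helps when sizing the first fully connected layer that follows the convolutional layers.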