Network training

First, we should create PyTorch data loader objects for the train and test datasets. The data loader object is responsible for sampling objects from the dataset and making mini-batches from them. This object can be configured as follows:

  1. First, we initialize the MNISTDataset type objects representing our datasets.
  2. Then, we use the torch::data::make_data_loader function to create a data loader object. This function takes a torch::data::DataLoaderOptions object with the configuration settings for the data loader. We set the mini-batch size to 256 items and the number of parallel data loading threads to 8. The sampler type can also be configured, but in this case we keep the default one, the random sampler.
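
To make the roles of the sampler and the mini-batch size more concrete, here is a minimal standard-library sketch of what a random sampler combined with batching does conceptually. It is not the PyTorch implementation and does not use the PyTorch API: the MakeRandomBatches function, its parameters, and the seed argument are hypothetical names introduced only for this illustration. It also works on indices alone and ignores data loading and worker threads.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <numeric>
#include <random>
#include <vector>

// Hypothetical helper (not part of PyTorch): returns the dataset indices
// grouped into mini-batches, visiting every index exactly once in random order.
std::vector<std::vector<std::size_t>> MakeRandomBatches(
    std::size_t dataset_size, std::size_t batch_size, unsigned seed) {
  assert(batch_size > 0);

  // Random sampler: a shuffled permutation of 0 .. dataset_size - 1.
  std::vector<std::size_t> indices(dataset_size);
  std::iota(indices.begin(), indices.end(), std::size_t{0});
  std::mt19937 rng(seed);
  std::shuffle(indices.begin(), indices.end(), rng);

  // Batching: cut the permutation into consecutive chunks of batch_size.
  // The last chunk may be smaller than batch_size.
  std::vector<std::vector<std::size_t>> batches;
  for (std::size_t start = 0; start < dataset_size; start += batch_size) {
    const std::size_t end = std::min(start + batch_size, dataset_size);
    batches.emplace_back(indices.begin() + static_cast<std::ptrdiff_t>(start),
                         indices.begin() + static_cast<std::ptrdiff_t>(end));
  }
  return batches;
}
```

For example, calling MakeRandomBatches(60000, 256, 42) for the 60,000 MNIST training images yields 235 mini-batches: 234 full batches of 256 indices and a final batch of 96.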

The following ...
