RNN classification

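Here, we will look at an example of how to build an RNN to identify handwritten digits from the MNIST database. Each 28x28 image is fed to the RNN one row at a time, so the image height becomes the number of time steps and the image width becomes the input size at each step: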

import torch
from torch import nn
import torchvision.datasets as dsets
import torchvision.transforms as transforms
import matplotlib.pyplot as plt

# torch.manual_seed(1)    # reproducible

# Hyper Parameters
EPOCH = 1               # number of passes over the training data; to save time, we train for only 1 epoch
BATCH_SIZE = 64
TIME_STEP = 28          # rnn time step / image height
INPUT_SIZE = 28         # rnn input size / image width
LR = 0.01               # learning rate
DOWNLOAD_MNIST = True   # set to True if you haven't downloaded the data yet

# MNIST digit dataset
train_data = dsets.MNIST(
    root='./mnist/',
    train=True,                         # this is training data
    transform=transforms.ToTensor(),    # Converts ...
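The listing above stops partway through loading the data, so the model itself is not shown. What follows is a minimal sketch of how such a classifier could look, not the book's own model: it assumes a single `nn.LSTM` layer followed by a linear output layer, and the hidden size of 64, the `RNN` class, the `train_step` helper and the choice of the Adam optimizer are all hypothetical, chosen for illustration. Random tensors stand in for a real MNIST batch so the sketch runs without downloading anything.

```python
import torch
from torch import nn

# Redefined here so the sketch is self-contained; they mirror the values above.
TIME_STEP = 28    # each image row is one time step (image height)
INPUT_SIZE = 28   # each row holds 28 pixel values (image width)
LR = 0.01         # learning rate
HIDDEN_SIZE = 64  # hypothetical hidden size, chosen for illustration
NUM_CLASSES = 10  # digits 0-9


class RNN(nn.Module):
    """Hypothetical LSTM classifier that reads an image row by row."""

    def __init__(self):
        super().__init__()
        self.rnn = nn.LSTM(
            input_size=INPUT_SIZE,
            hidden_size=HIDDEN_SIZE,
            num_layers=1,
            batch_first=True,  # tensors are shaped (batch, time_step, input_size)
        )
        self.out = nn.Linear(HIDDEN_SIZE, NUM_CLASSES)

    def forward(self, x):
        # x: (batch, time_step, input_size); None gives a zero initial state
        r_out, (h_n, c_n) = self.rnn(x, None)
        # Classify from the output at the last time step only
        return self.out(r_out[:, -1, :])


def train_step(model, optimizer, images, labels):
    """Hypothetical helper: one optimization step on (batch, 1, 28, 28) images."""
    loss_fn = nn.CrossEntropyLoss()
    # Drop the channel dimension: (batch, 1, 28, 28) -> (batch, 28, 28)
    x = images.view(-1, TIME_STEP, INPUT_SIZE)
    logits = model(x)
    loss = loss_fn(logits, labels)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss.item()


if __name__ == "__main__":
    torch.manual_seed(1)
    model = RNN()
    optimizer = torch.optim.Adam(model.parameters(), lr=LR)
    # Random tensors stand in for one real MNIST batch of 64 images
    images = torch.rand(64, 1, 28, 28)
    labels = torch.randint(0, NUM_CLASSES, (64,))
    print(train_step(model, optimizer, images, labels))
```

In a real run, `images` and `labels` would come from iterating over a `torch.utils.data.DataLoader` built on `train_data` with `batch_size=BATCH_SIZE`. Using only the last time step's output is a common design choice for sequence classification: by then the LSTM has seen every row of the image.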
