In this section, we'll implement an LSTM cell with PyTorch 1.3.1. PyTorch already provides an LSTM implementation as torch.nn.LSTM, but our goal is to understand how the LSTM cell works, so we'll implement our own version from scratch. The cell will be a subclass of torch.nn.Module, and we'll use it as a building block for larger models. The source code for this example is available at https://github.com/PacktPublishing/Advanced-Deep-Learning-with-Python/tree/master/Chapter07/lstm_cell.py. Let's get started:
- First, we'll do the imports:
import math
import typing
import torch
- Next, we'll implement the class and the __init__ method:
class LSTMCell(torch.nn.Module):
    def
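To make the structure of such a cell concrete, here is a minimal sketch of an LSTM cell written as a torch.nn.Module subclass. It is not the listing from the repository linked above: the class name SketchLSTMCell, the layer names x_fc and h_fc, the fused four-gate layout, and the uniform initialization bound of 1/sqrt(hidden_size) are all assumptions chosen for illustration. It follows the standard LSTM equations, with forget, input, and output gates and a tanh candidate cell state.

```python
import math
import typing

import torch


class SketchLSTMCell(torch.nn.Module):
    """Illustrative LSTM cell (hypothetical; not the repository's LSTMCell)."""

    def __init__(self, input_size: int, hidden_size: int):
        super().__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size

        # Each layer produces the pre-activations of all four gates at once
        # (forget, input, candidate, output), hence 4 * hidden_size outputs.
        self.x_fc = torch.nn.Linear(input_size, 4 * hidden_size)
        self.h_fc = torch.nn.Linear(hidden_size, 4 * hidden_size, bias=False)

        self.reset_parameters()

    def reset_parameters(self):
        # Assumed initialization: uniform in [-1/sqrt(H), 1/sqrt(H)]
        bound = 1.0 / math.sqrt(self.hidden_size)
        for param in self.parameters():
            torch.nn.init.uniform_(param, -bound, bound)

    def forward(
        self,
        x: torch.Tensor,
        state: typing.Tuple[torch.Tensor, torch.Tensor],
    ) -> typing.Tuple[torch.Tensor, torch.Tensor]:
        h_prev, c_prev = state  # each of shape (batch, hidden_size)

        # Sum the input and hidden contributions, then split per gate
        gates = self.x_fc(x) + self.h_fc(h_prev)
        f, i, g, o = gates.chunk(4, dim=1)

        f = torch.sigmoid(f)  # forget gate
        i = torch.sigmoid(i)  # input gate
        g = torch.tanh(g)     # candidate cell state
        o = torch.sigmoid(o)  # output gate

        c = f * c_prev + i * g   # new cell state
        h = o * torch.tanh(c)    # new hidden state
        return h, c


# Usage: one step on a batch of 2 samples with 3 input features
torch.manual_seed(0)
cell = SketchLSTMCell(input_size=3, hidden_size=5)
x = torch.randn(2, 3)
h0 = torch.zeros(2, 5)
c0 = torch.zeros(2, 5)
h1, c1 = cell(x, (h0, c0))
```

Computing all four gates with one matrix multiplication per input is a common design choice because it replaces eight small matrix products with two larger ones. The cell processes a single time step, so a full sequence is handled by calling it in a loop and passing the returned (h, c) pair back in as the next state.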