We develop deep Q-learning with a deep Q-network (DQN) as follows:
- Import all the necessary packages:
>>> import gym
>>> import torch
>>> from torch.autograd import Variable
>>> import random
The Variable class wraps a tensor and supports backpropagation. Note that since PyTorch 0.4, tensors track gradients directly and Variable is deprecated, though it still works as a thin wrapper.
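As a brief illustration of what this wrapping enables, a tensor created with requires_grad=True records the operations applied to it, so gradients can be computed with a single backward call (the values below are arbitrary):

```python
import torch

# A tensor with requires_grad=True tracks operations for backpropagation;
# this replaces the older Variable wrapper in modern PyTorch.
x = torch.tensor([2.0], requires_grad=True)
y = x ** 2
y.backward()
print(x.grad)  # dy/dx = 2x = 4.0 at x = 2
```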
- Let's start with the __init__ method of the DQN class:
>>> class DQN():
...     def __init__(self, n_state, n_action, n_hidden=50, lr=0.05):
...         self.criterion = torch.nn.MSELoss()
...         self.model = torch.nn.Sequential(
...                         torch.nn.Linear(n_state, n_hidden),
...                         torch.nn.ReLU(),
...                         torch.nn.Linear(n_hidden, n_action)
...         )
...         self.optimizer = torch.optim.Adam(
...                         self.model.parameters(), lr)
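To see the network's input/output shapes concretely, the class above can be instantiated with CartPole-style dimensions (4 state features, 2 actions; these numbers are assumed here for illustration) and fed a dummy state:

```python
import torch

class DQN():
    def __init__(self, n_state, n_action, n_hidden=50, lr=0.05):
        self.criterion = torch.nn.MSELoss()
        self.model = torch.nn.Sequential(
            torch.nn.Linear(n_state, n_hidden),
            torch.nn.ReLU(),
            torch.nn.Linear(n_hidden, n_action)
        )
        self.optimizer = torch.optim.Adam(self.model.parameters(), lr)

# CartPole-style dimensions, assumed for this example
dqn = DQN(n_state=4, n_action=2)
# A batch of one random state maps to one Q-value per action
q_values = dqn.model(torch.randn(1, 4))
print(q_values.shape)  # torch.Size([1, 2])
```

The network is a plain regressor: given a state, it outputs one estimated Q-value per action, and action selection later reduces to an argmax over that output.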
- We now develop the training method, which updates the neural network with a data point:
>>> def update(self, ...
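The listing above is cut off, so the following is only a sketch of how such an update method is commonly written, not the author's exact code: predict Q-values for the state, measure the MSE loss against the target values, and take one optimizer step.

```python
import torch

class DQN():
    def __init__(self, n_state, n_action, n_hidden=50, lr=0.05):
        self.criterion = torch.nn.MSELoss()
        self.model = torch.nn.Sequential(
            torch.nn.Linear(n_state, n_hidden),
            torch.nn.ReLU(),
            torch.nn.Linear(n_hidden, n_action)
        )
        self.optimizer = torch.optim.Adam(self.model.parameters(), lr)

    def update(self, s, y):
        # Assumed shape of the method: s is one state, y the target Q-values
        y_pred = self.model(torch.Tensor(s))
        loss = self.criterion(y_pred, torch.Tensor(y))
        # Standard gradient step: clear old gradients, backprop, update weights
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()
        return loss.item()

dqn = DQN(n_state=4, n_action=2)
loss = dqn.update([0.1, 0.0, -0.1, 0.0], [1.0, 0.0])
```

Each call trains on a single data point; in the full algorithm, the targets y come from the Bellman update using the reward and the discounted maximum Q-value of the next state.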