How to do it...

We develop deep Q-learning with a DQN as follows:

  1. Import all the necessary packages:
>>> import gym
>>> import torch
>>> from torch.autograd import Variable
>>> import random

A Variable wraps a tensor and supports backpropagation.
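To illustrate this (a minimal check, not part of the recipe), a Variable created with requires_grad=True accumulates gradients when backward() is called:

>>> x = Variable(torch.Tensor([1.0, 2.0]), requires_grad=True)  # gradient tracking on
>>> y = (x ** 2).sum()
>>> y.backward()        # backpropagate through y = x1^2 + x2^2
>>> x.grad              # gradient is 2 * x
tensor([2., 4.])

In PyTorch 1.x, a plain tensor created with requires_grad=True behaves the same way; Variable is kept here to match the import above.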

  2. Let's start with the __init__ method of the DQN class:
>>> class DQN():
...     def __init__(self, n_state, n_action, n_hidden=50, lr=0.05):
...         self.criterion = torch.nn.MSELoss()
...         self.model = torch.nn.Sequential(
...                         torch.nn.Linear(n_state, n_hidden),
...                         torch.nn.ReLU(),
...                         torch.nn.Linear(n_hidden, n_action)
...                 )
...         self.optimizer = torch.optim.Adam(
...                             self.model.parameters(), lr)
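As a quick usage sketch (not from the recipe itself), the estimator for Gym's CartPole-v0 environment, which has a 4-dimensional state and 2 actions, could be created as follows; the hidden size and learning rate simply reuse the defaults above:

>>> env = gym.make('CartPole-v0')
>>> n_state = env.observation_space.shape[0]   # 4 for CartPole
>>> n_action = env.action_space.n              # 2 for CartPole
>>> dqn = DQN(n_state, n_action, n_hidden=50, lr=0.05)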
  3. We now develop the training method, which updates the neural network with a data point:
>>> def update(self, ...
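The method body is cut off in this excerpt. As a hedged sketch of what such an update typically does (the parameter names s for the state and y for the target values are assumptions, not the book's verbatim code), it predicts the Q-values for s, measures the MSE loss against y, and performs one optimization step:

>>> def update(self, s, y):
...     # Sketch of a single-sample update (assumed implementation)
...     y_pred = self.model(torch.Tensor(s))                       # predicted Q-values
...     loss = self.criterion(y_pred, Variable(torch.Tensor(y)))   # MSE against targets
...     self.optimizer.zero_grad()                                 # clear old gradients
...     loss.backward()                                            # backpropagate the loss
...     self.optimizer.step()                                      # update the weights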
