How to do it...

We develop the REINFORCE algorithm to solve the CartPole environment as follows:
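REINFORCE is a Monte Carlo policy-gradient method. After each episode it computes the discounted return G_t for every step, then adjusts the policy to raise the log-probability of each action taken in proportion to that return. The following is a minimal sketch of those two quantities in plain Python, with no gym or PyTorch so that it runs anywhere. The helper names `discounted_returns` and `reinforce_loss` and the discount factor `gamma=0.9` are hypothetical choices for illustration, not the recipe's code.

```python
import math

def discounted_returns(rewards, gamma=0.9):
    """Return G_t = r_t + gamma * G_{t+1} for every step t of one episode.

    gamma=0.9 is an illustrative assumption, not a value taken from the recipe.
    """
    returns = []
    g = 0.0
    # Walk the episode backwards so each return reuses the one after it
    for r in reversed(rewards):
        g = r + gamma * g
        returns.append(g)
    returns.reverse()
    return returns

def reinforce_loss(log_probs, returns):
    """Policy-gradient loss: -sum_t log pi(a_t | s_t) * G_t.

    Minimizing this loss is gradient ascent on the expected return.
    """
    return -sum(lp * g for lp, g in zip(log_probs, returns))

# CartPole gives a reward of +1 for every step the pole stays up
rewards = [1.0, 1.0, 1.0]
returns = discounted_returns(rewards, gamma=0.9)

# Hypothetical log-probabilities of the three actions the agent took
log_probs = [math.log(0.5)] * 3
loss = reinforce_loss(log_probs, returns)
```

With three +1 rewards and gamma = 0.9, the returns are approximately [2.71, 1.9, 1.0]. In a PyTorch version the log-probabilities would be tensors produced by the policy network, so calling `backward()` on the scalar loss yields the policy gradient.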

  1. Import all the necessary packages and create a CartPole instance:
>>> import gym
>>> import torch
>>> import torch.nn as nn
>>> env = gym.make('CartPole-v0')
  1. Let's start with the __init__ method of the PolicyNetwork class, which approximates the policy using a neural network:
>>> class PolicyNetwork():
...     def __init__(self, n_state, n_action, n_hidden=50, lr=0.001):
...         # Two-layer network mapping a state to one probability per action
...         self.model = nn.Sequential(
...                         nn.Linear(n_state, n_hidden),
...                         nn.ReLU(),
...                         nn.Linear(n_hidden, n_action),
...                         # Normalize over the action dimension explicitly
...                         nn.Softmax(dim=-1),
...                 )
...         self.optimizer = torch.optim.Adam(
...                         self.model.parameters(), lr)
  1. Next, add the predict method, which computes the estimated policy:
>>> ...
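Because the network ends in `nn.Softmax`, its output is a probability for each action, and the agent acts by sampling from that distribution. The recipe's own `predict` code is elided above, so what follows is only a hedged plain-Python sketch of the idea: `softmax` and `sample_action` are hypothetical helpers, and the two scores stand in for the network's pre-softmax outputs for CartPole's two actions (push the cart left or right).

```python
import math
import random

def softmax(scores):
    """Turn raw action scores into probabilities, as nn.Softmax does on one dimension."""
    m = max(scores)  # subtract the max for numerical stability
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]

def sample_action(probs, rng=random):
    """Draw action index i with probability probs[i]."""
    u = rng.random()  # uniform in [0, 1)
    cumulative = 0.0
    for action, p in enumerate(probs):
        cumulative += p
        if u < cumulative:
            return action
    return len(probs) - 1  # guard against rounding leaving u >= the final cumulative sum

# Hypothetical pre-softmax scores for action 0 (left) and action 1 (right)
probs = softmax([0.2, -0.1])
action = sample_action(probs, random.Random(0))
# REINFORCE later multiplies this log-probability by the step's return
log_prob = math.log(probs[action])
```

In PyTorch the sampling step is commonly written with `torch.multinomial(probs, 1)` or `torch.distributions.Categorical(probs).sample()`, and the log-probability of the sampled action is stored for the update at the end of the episode.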
