We develop the REINFORCE algorithm to solve the CartPole environment as follows:
- Import all the necessary packages and create a CartPole instance:
>>> import gym
>>> import torch
>>> import torch.nn as nn
>>> env = gym.make('CartPole-v0')
- Let's start with the __init__ method of the PolicyNetwork class, which approximates the policy using a neural network:
>>> class PolicyNetwork():
...     def __init__(self, n_state, n_action, n_hidden=50, lr=0.001):
...         # Two-layer network: state -> hidden -> action probabilities
...         self.model = nn.Sequential(
...             nn.Linear(n_state, n_hidden),
...             nn.ReLU(),
...             nn.Linear(n_hidden, n_action),
...             nn.Softmax(dim=-1),  # explicit dim avoids the implicit-dim warning
...         )
...         self.optimizer = torch.optim.Adam(self.model.parameters(), lr)
- Next, add the predict method, which computes the estimated policy:
>>> ...
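As a standalone sanity check, the network architecture from the __init__ step can be run on a single state to confirm that it outputs a valid probability distribution over actions. This is a minimal sketch and not the book's predict method: the sizes (4 state values, 2 actions) are assumed to match CartPole, the all-zeros state is a placeholder rather than an observation from the environment, and `dim=-1` on the softmax is our choice.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)  # reproducible weights for the illustration

# Assumed sizes: CartPole has 4 observation values and 2 discrete actions.
n_state, n_action, n_hidden = 4, 2, 50

# Same layer layout as the PolicyNetwork model above.
model = nn.Sequential(
    nn.Linear(n_state, n_hidden),
    nn.ReLU(),
    nn.Linear(n_hidden, n_action),
    nn.Softmax(dim=-1),
)

# Placeholder state (hypothetical); a real one would come from env.reset().
state = torch.zeros(n_state)

# Forward pass without tracking gradients: yields one probability per action.
with torch.no_grad():
    probs = model(state)

# Sample an action index according to the predicted probabilities.
action = torch.multinomial(probs, 1).item()
```

Sampling from the distribution, rather than always taking the most probable action, is what lets a stochastic policy explore during training.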