
Моделирование последовательных данных с помощью рекуррентных нейронных сетей
515
... encoded_input = torch.tensor(
... [char2int[s] for s in starting_str]
... )
... encoded_input = torch.reshape(
... encoded_input, (1, -1)
... )
... generated_str = starting_str
...
... model.eval()
... hidden, cell = model.init_hidden(1)
... for c in range(len(starting_str)-1):
... _, hidden, cell = model(
... encoded_input[:, c].view(1), hidden, cell
... )
...
... last_char = encoded_input[:, -1]
... for i in range(len_generated_text):
... logits, hidden, cell = model(
... last_char.view(1), hidden, cell
... )
... logits = torch.squeeze(logits, 0) ...