
558 Глава 16
... attention_mask = batch['attention_mask'].to(device)
... labels = batch['labels'].to(device)
... outputs = model(input_ids, attention_mask=attention_mask)
... logits = outputs['logits']
... predicted_labels = torch.argmax(logits, 1)
... num_examples += labels.size(0)
... correct_pred += (predicted_labels == labels).sum()
... return correct_pred.float()/num_examples * 100
В функции
compute_accuracy
мы загружаем текущий пакет, а затем получаем предсказан-
ные метки из выходных данных. При этом мы отслеживаем общее количество приме-
ров через
num_examples
. Точно так же мы отслеживаем количество правильных прогнозов
с помощь ...