이 글은 파이토치로 배우는 자연어처리(O'REILLY, 한빛미디어)를 공부한 내용을 바탕으로 작성하였습니다.
훈련 반복
첫번째 for문 : epoch 수 만큼 반복 (전체 데이터셋)
두번째 for문 : batch 수 만큼 반복
train 데이터에 대해 param 학습 후 val 데이터로 검증
accuracy, loss 정보 저장
epoch_bar = tqdm.notebook.tqdm(desc='training routine',
total=args.num_epochs,
position=0)
dataset.set_split('train')
train_bar = tqdm.notebook.tqdm(desc='split=train',
total=dataset.get_num_batches(args.batch_size),
position=1,
leave=True)
dataset.set_split('val')
val_bar = tqdm.notebook.tqdm(desc='split=val',
total=dataset.get_num_batches(args.batch_size),
position=1,
leave=True)
try:
for epoch_index in range(args.num_epochs):
print(epoch_index)
train_state['epoch_index'] = epoch_index
# 훈련 세트에 대한 순회
# 훈련 세트와 배치 제너레이터 준비, 손실과 정확도를 0으로 설정
dataset.set_split('train')
batch_generator = generate_batches(dataset,
batch_size=args.batch_size,
device=args.device)
running_loss = 0.0
running_acc = 0.0
classifier.train()
for batch_index, batch_dict in enumerate(batch_generator):
# 훈련 과정은 5단계로 이루어집니다
# --------------------------------------
# 단계 1. 그레이디언트를 0으로 초기화합니다
optimizer.zero_grad()
# 단계 2. 출력을 계산합니다
y_pred = classifier(x_in=batch_dict['x_data'].float()) # .forward() 안해도 자동 실행
# 단계 3. 손실을 계산합니다
loss = loss_func(y_pred, batch_dict['y_target'].float())
loss_t = loss.item() # tensor에서 스칼라값 뽑기
running_loss += (loss_t - running_loss) / (batch_index + 1)
# 단계 4. 손실을 사용해 그레이디언트를 계산합니다
loss.backward()
# 단계 5. 옵티마이저로 가중치를 업데이트합니다
optimizer.step()
# -----------------------------------------
# 정확도를 계산합니다
acc_t = compute_accuracy(y_pred, batch_dict['y_target'])
running_acc += (acc_t - running_acc) / (batch_index + 1)
# 진행 바 업데이트
train_bar.set_postfix(loss=running_loss,
acc=running_acc,
epoch=epoch_index)
train_bar.update()
train_state['train_loss'].append(running_loss)
train_state['train_acc'].append(running_acc)
# 검증 세트에 대한 순회
# 검증 세트와 배치 제너레이터 준비, 손실과 정확도를 0으로 설정
dataset.set_split('val')
batch_generator = generate_batches(dataset,
batch_size=args.batch_size,
device=args.device)
running_loss = 0.
running_acc = 0.
classifier.eval()
for batch_index, batch_dict in enumerate(batch_generator):
# 단계 1. 출력을 계산합니다
y_pred = classifier(x_in=batch_dict['x_data'].float())
# 단계 2. 손실을 계산합니다
loss = loss_func(y_pred, batch_dict['y_target'].float())
loss_t = loss.item()
running_loss += (loss_t - running_loss) / (batch_index + 1)
# 단계 3. 정확도를 계산합니다
acc_t = compute_accuracy(y_pred, batch_dict['y_target'])
running_acc += (acc_t - running_acc) / (batch_index + 1)
val_bar.set_postfix(loss=running_loss,
acc=running_acc,
epoch=epoch_index)
val_bar.update()
train_state['val_loss'].append(running_loss)
train_state['val_acc'].append(running_acc)
train_state = update_train_state(args=args, model=classifier,
train_state=train_state)
scheduler.step(train_state['val_loss'][-1])
train_bar.n = 0
val_bar.n = 0
epoch_bar.update()
if train_state['stop_early']:
break
train_bar.n = 0
val_bar.n = 0
epoch_bar.update()
print(epoch_index)
except KeyboardInterrupt:
print("Exiting loop")
테스트 데이터 평가
# 가장 좋은 모델을 사용해 테스트 세트의 손실과 정확도를 계산합니다
classifier.load_state_dict(torch.load(train_state['model_filename']))
classifier = classifier.to(args.device)
dataset.set_split('test')
batch_generator = generate_batches(dataset,
batch_size=args.batch_size,
device=args.device)
running_loss = 0.
running_acc = 0.
classifier.eval()
for batch_index, batch_dict in enumerate(batch_generator):
# 출력을 계산합니다
y_pred = classifier(x_in=batch_dict['x_data'].float())
# 손실을 계산합니다
loss = loss_func(y_pred, batch_dict['y_target'].float())
loss_t = loss.item()
running_loss += (loss_t - running_loss) / (batch_index + 1)
# 정확도를 계산합니다
acc_t = compute_accuracy(y_pred, batch_dict['y_target'])
running_acc += (acc_t - running_acc) / (batch_index + 1)
train_state['test_loss'] = running_loss
train_state['test_acc'] = running_acc
print("테스트 손실: {:.3f}".format(train_state['test_loss']))
print("테스트 정확도: {:.2f}".format(train_state['test_acc']))
'AI 인공지능' 카테고리의 다른 글
기계 학습(머신러닝)을 시작하는 방법: 단계별 가이드 (0) | 2023.03.31 |
---|---|
yelp 리뷰 감성 분류 (2) (0) | 2023.03.30 |
yelp 리뷰 감성 분류 (1) (0) | 2023.03.30 |
파이토치 신경망 구성하기 (0) | 2023.03.24 |
자연어처리(NLP) 기본 용어 정리 (0) | 2023.03.21 |