Search

Early Stopping

대분류
인공지능/데이터
소분류
ML/DL 정리 노트
유형
딥 러닝
부유형
모듈
최종 편집 일시
2024/10/27 15:15
생성 일시
2024/10/25 11:10
14 more properties
class EarlyStopper(object): def __init__(self, num_trials, save_path): self.num_trials = num_trials self.trial_counter = 0 self.best_loss = np.inf self.save_path = save_path def is_continuable(self, model, loss): if loss < self.best_loss: # 현재 loss가 최고 loss보다 더 낮은 경우 self.best_loss = loss # 최고 loss를 현재 loss로 업데이트 self.trial_counter = 0 # 초기화 torch.save(model, self.save_path) # 최고 loss를 갖은 모델 저장 return True elif self.trial_counter + 1 < self.num_trials: # 현재 loss가 최고 loss보다 작은 경우 & max 시도횟수보다 현재 시도횟수가 작은 경우 self.trial_counter += 1 # 기존 시도횟수 + 1 return True else: # 현재 정확도가 최고 정확도보다 작은 경우 & 현재 시도횟수가 max 시도횟수보다 큰 경우 return False def get_best_model(self, device): return torch.load(self.save_path).to(device)
Python
복사