class EarlyStopper(object):
def __init__(self, num_trials, save_path):
self.num_trials = num_trials
self.trial_counter = 0
self.best_loss = np.inf
self.save_path = save_path
def is_continuable(self, model, loss):
if loss < self.best_loss: # 현재 loss가 최고 loss보다 더 낮은 경우
self.best_loss = loss # 최고 loss를 현재 loss로 업데이트
self.trial_counter = 0 # 초기화
torch.save(model, self.save_path) # 최고 loss를 갖은 모델 저장
return True
elif self.trial_counter + 1 < self.num_trials: # 현재 loss가 최고 loss보다 작은 경우 & max 시도횟수보다 현재 시도횟수가 작은 경우
self.trial_counter += 1 # 기존 시도횟수 + 1
return True
else: # 현재 정확도가 최고 정확도보다 작은 경우 & 현재 시도횟수가 max 시도횟수보다 큰 경우
return False
def get_best_model(self, device):
return torch.load(self.save_path).to(device)
Python
복사