콜백 사용하기
훈련이 몇 시간동안 지속되는 경우에는
훈련 마지막에 모델을 저장하는 것 뿐만 아니라
훈련 도중 일정 간격으로 체크포인트를 저장할 필요가 있습니다.
fit() 메소드의 callbacks 매개변수를 사용해서 케라스가 훈련의 시작이나 끝에 호출할 객체 리스트를 지정할 수 있습니다.
또는 에포크의 시작이나 끝, 각 배치 처리 전후에 호출할 수도 있습니다.
예를 들어 ModelCheckPoint는 훈련하는 동안 일정한 간격으로 모델의 체크포인트를 저장합니다.
기본적으로는 매 에포크의 끝에서 호출됩니다.
[...] # 모델을 만들고 컴파일하기
checkpoint_cb = keras.callbacks.ModelCheckPoint('filename.h5')
history = model.fit(X_train, y_train, epochs=10, callbacks=[checkpoint_cb])
훈련하는 동안 검증 세트를 사용하면 ModelCheckPoint를 만들 때 save_best_only=True로 지정할 수 있습니다.
이렇게 하면 최상의 검증 세트 점수에서만 모델을 저장합니다.
오랜 훈련 시간으로 과대적합될 걱정을 하지 않아도 됩니다.
훈련이 끝난 후 마지막에 저장된 모델을 복원하기만 하면 됩니다.
이 방법은 조기 종료를 구현하는 방법 중 하나입니다.
checkpoint_cb = keras.callbacks.ModelCheckPoint('filename.h5', save_best_only=True)
history = model.fit(X_train, y_train, epochs=10,
validation_data=(X_valid, y_valid),
callbacks=[checkpoint_cb])
model = keras.models.load_model('filename.h5') # 최상의 모델로 복원
조기 종료를 구현하는 또 다른 방법은 EarlyStopping 콜백을 사용하는 것입니다.
일정 에포크(patience 매개변수로 지정)동안 검증 세트에 대한 점수가 향상되지 않으면 훈련을 멈춥니다.
선택적으로 최상의 모델을 복원할 수도 있습니다.
체크포인트 저장 콜백과 진전이 없는 경우 훈련을 일찍 멈추는 콜백을 함께 사용 가능합니다.
checkpoint_cb = keras.callbacks.ModelCheckPoint('filename.h5', save_best_only=True)
early_stopping_cb = keras.callbacks.EarlyStopping(patience=10, restore_best_weights=True)
history = model.fit(X_train, y_train, epochs=100,
validation_data=(X_valid, y_valid),
callbacks=[checkpoint_cb, early_stopping_cb])
모델이 향상되지 않으면 훈련이 자동으로 중지되므로 에포크의 숫자를 크게 지정해도 됩니다.
EarlyStopping 콜백에 restore_best_weights=True를 설정해줬기 때문에 최상의 가중치를 복원해주기 때문에 저장된 모델을 따로 복원할 필요가 없습니다.
더 많은 제어를 원하면 사용자 정의 콜백을 만들 수 있습니다.
아래와 같이 만든 사용자 정의 콜백은 훈련하는 동안 검증 손실과 훈련 손실의 비율을 출력합니다.
class PrintValTrainRatioCallback(keras.callbacks.Callback):
def on_epoch_end(self, epoch, logs):
print('\nval/train: {:.2f}'.format(logs['val_loss'] / logs['loss'])
이 외에도 아래와 같은 메소드들을 더 구현할 수 있습니다.
- on_train_begin
- on_train_end
- on_epoch_begin
- on_epoch_end
- on_batch_begin
- on_batch_end
- on_test_begin
- on_test_end
- on_test_batch_begin
- on_test_batch_end
- on_predict_begin
- on_predict_end
- on_predict_batch_begin
- on_predict_batch_end
'DATA > 머신 러닝' 카테고리의 다른 글
[머신 러닝] 신경망 하이퍼파라미터 튜닝하기 (0) | 2022.01.23 |
---|---|
[머신 러닝] 텐서보드를 이용해 시각화하기 (0) | 2022.01.23 |
[머신 러닝] 모델 저장과 복원 (0) | 2022.01.22 |
[머신 러닝] 다층 퍼셉트론 (0) | 2022.01.22 |
[머신 러닝] 퍼셉트론 (0) | 2022.01.21 |