https://www.youtube.com/watch?v=28QbrkRkHlo&t=5526s
대부분 위의 영상을 참고했다.
모델의 저장과 복원
- save() 메소드를 이용하여 저장
- load_model() 을 이용해 복원
- Sequencial API, 함수형 API에서는 모델의 저장 및 로드가 가능하지만 서브클래싱 방식으로는 할 수 없음
- 서브클래싱 방식은 save_weights()와 load_weights()를 이용해 모델의 파라미터만 저장 및 로드
- JSON 형식
model.to_json() (저장)
tf.keras.models.model_from_json(file_path) (복원)
- YAML로 직렬화
model.to_yaml() (저장)
tf.keras.models.model_from_yaml(file_path) (복원)
model.save('mnist_model.h5')
loaded_model = models.load_model('mnist_model.h5')
loaded_model.summary()
콜백 (callbacks)
- 학습시에 callbacks 변수를 통해 훈련의 시작이나 끝에 호출할 객체 지정
- 여러개 같이 사용가능
from tensorflow.keras.callbacks import ModelCheckpoint, EarlyStopping, LearningRateScheduler, TensorBoard
1. ModelCheckpoint
-정기적으로 모델의 체크포인트를 저장하고 문제가 발생시 복구하는데에 사용
check_point_cb = ModelCheckpoint('keras_mnist_model.h5', save_best_only=True)
history = model.fit(x_train, y_train, epochs=10,
validation_data=(x_val, y_val),
callbacks=[check_point_cb])
save_best_only=True 시 최상의 모델만을 저장 (validation 의 성능을 보고 판단)
2. EarlyStopping
- val_set이 한동안 개선되지 않을 경우 학습을 중단할 때 사용
- 일정 patience 동안 val_set의 점수가 오르지 않으면 학습을 중단함
- 모델이 좋아지지않으면 학습이 자동으로 중단되므로 epochs를 크게 가져가도 무방하다 (과적합에 주의)
- 학습이 끝난 후 최상의 가중치를 복원한다.
check_point_cb = ModelCheckpoint('keras_mnist_model.h5', save_best_only=True)
earlystopping_cb = EarlyStopping(patience=3, monitor='val_loss',
restore_best_weights=True)
history = model.fit(x_train, y_train, epochs=10,
validation_data=(x_val, y_val),
callbacks=[check_point_cb, earlystopping_cb])
3. LearnigRateScheduler
- 학습중에 학습률을 변경시키기 위해 사용
def scheduler(epoch, leaning_rate):
if epoch < 10:
return leaning_rate
else:
return leaning_rate * tf.math.exp(-0.1)
print(round(model.optimizer.lr.numpy(),5))
lr_scheduler_cb = LearningRateScheduler(scheduler)
history = model.fit(x_train, y_train, epochs=15,
callbacks=[lr_scheduler_cb], verbose=0)
print(round(model.optimizer.lr.numpy(),5))
0.01
0.00607
함수를 정의하여 lr을 어떻게 변화시킬지 정의해야함
4. Tensorboard
- 학습과정 모니터링하기 위함
-
log_dir = './logs'
tensor_board_cb = [TensorBoard(log_dir=log_dir, histogram_freq=1, write_graph=True, write_images=True,)]
model.fit(x_train, y_train, batch_size=32, validation_data=(x_val, y_val),
epochs=30, callbacks=tensor_board_cb)
%load_ext tensorboard
#%tensorboard --logdir {log_dir} port 8000
%tensorboard --logdir {log_dir}
이런식으로 동적으로 확인할 수 있는 창이 뜸
'TF' 카테고리의 다른 글
[TF] CNN 컨볼루션 신경망 (0) | 2022.04.13 |
---|---|
[TF] 딥러닝 학습기술 (0) | 2022.04.12 |
[TF] 모델 컴파일 및 학습 mnist (0) | 2022.04.11 |
[TF] Layer, Model, 모델구성 (0) | 2022.04.11 |