What I Do/프로그래밍

keras에서 layer의 weight를 업데이트하지 않고 고정하기

1Millisecond 2023. 1. 26. 22:23

keras(텐서플로)로 모델을 생성하고,

fine tunning을 위해서 일부 layer의 weight는 update하지 말고

마지막 layer만 업데이트 한다든지 모두 fix(freeze)한다음 마지막에 layer를 추가한다든지 할때

유용한 기법이다.

(fix나 non-update로 표현하지 않고 freeze한다는 표현을 쓰는 것 같다.)

 

방법은 매우 간단하다.

layer.trainable = False 해주면된다.

 

 

model = Sequential([
    ResNet50Base(input_shape=(32, 32, 3), weights='pretrained'),
    Dense(10),
])
model.layers[0].trainable = False  # Freeze ResNet50Base.

assert model.layers[0].trainable_weights == []  # ResNet50Base has no trainable weights.
assert len(model.trainable_weights) == 2  # Just the bias & kernel of the Dense layer.

model.compile(...)
model.fit(...)  # Train Dense while excluding ResNet50Base.

 

관련 내용은 아래에서 찾아 볼수 있다.

https://keras.io/getting_started/faq/#how-can-i-freeze-layers-and-do-finetuning

 

Keras documentation: Keras FAQ

» Getting started / Keras FAQ Keras FAQ A list of frequently Asked Keras Questions. General questions General questions How can I train a Keras model on multiple GPUs (on a single machine)? There are two ways to run a single model on multiple GPUs: data p

keras.io