How To Set Parameters In Keras To Be Non-trainable?
Solution 1:
You can simply assign a boolean value to the layer's trainable property:
model.layers[n].trainable = False
You can check which layers are trainable:
for l in model.layers:
    print(l.name, l.trainable)
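A minimal sketch of freezing one layer and listing the flags. It assumes TensorFlow 2.x with tf.keras; the two-layer toy model and the layer names "hidden" and "head" are arbitrary illustrations, not from the question:

```python
import tensorflow as tf

# Toy model (sizes and names are arbitrary, chosen only for illustration).
model = tf.keras.Sequential([
    tf.keras.Input(shape=(4,)),
    tf.keras.layers.Dense(8, name="hidden"),
    tf.keras.layers.Dense(1, name="head"),
])

# Freeze the first layer (the implicit Input layer is not part of model.layers).
model.layers[0].trainable = False

for l in model.layers:
    print(l.name, l.trainable)
# hidden False
# head True

# The frozen layer's kernel and bias now count as non-trainable weights.
print(len(model.trainable_weights), len(model.non_trainable_weights))  # 2 2
```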
You can also set it in the layer definition:
from tensorflow.keras.layers import Dense
frozen_layer = Dense(32, trainable=False)
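A quick way to confirm the constructor argument works, sketched under the assumption of TensorFlow 2.x. Keras creates weights lazily, so the layer is called once on a dummy batch; the 16-feature input shape is an arbitrary choice:

```python
import tensorflow as tf

frozen_layer = tf.keras.layers.Dense(32, trainable=False)

# Weights only exist after the layer is built, so call it on dummy input
# (the (1, 16) shape is an arbitrary assumption).
_ = frozen_layer(tf.zeros((1, 16)))

print(frozen_layer.trainable)                   # False
print(len(frozen_layer.trainable_weights))      # 0
print(len(frozen_layer.non_trainable_weights))  # 2 (kernel and bias)
```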
From Keras documentation:
To "freeze" a layer means to exclude it from training, i.e. its weights will never be updated. This is useful in the context of fine-tuning a model, or using fixed embeddings for a text input. You can pass a trainable argument (boolean) to a layer constructor to set a layer to be non-trainable. Additionally, you can set the trainable property of a layer to True or False after instantiation. For this to take effect, you will need to call compile() on your model after modifying the trainable property.
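To see the effect the quoted documentation describes, the following sketch freezes one layer, compiles, trains for one epoch, and checks that only the unfrozen layer's weights moved. It assumes TensorFlow 2.x; the optimizer, loss, and random data are arbitrary placeholders used only to produce non-zero gradients:

```python
import numpy as np
import tensorflow as tf

model = tf.keras.Sequential([
    tf.keras.Input(shape=(4,)),
    tf.keras.layers.Dense(8, name="frozen"),
    tf.keras.layers.Dense(1, name="head"),
])

# Freeze first, then compile so the change takes effect during training.
model.layers[0].trainable = False
model.compile(optimizer="sgd", loss="mse")

# Snapshot the weights, copying explicitly so later updates cannot alias them.
frozen_before = [w.copy() for w in model.layers[0].get_weights()]
head_before = [w.copy() for w in model.layers[1].get_weights()]

# Random toy data (placeholder).
rng = np.random.default_rng(0)
x = rng.random((32, 4)).astype("float32")
y = rng.random((32, 1)).astype("float32")
model.fit(x, y, epochs=1, verbose=0)

frozen_unchanged = all(
    np.array_equal(a, b) for a, b in zip(frozen_before, model.layers[0].get_weights())
)
head_changed = any(
    not np.array_equal(a, b) for a, b in zip(head_before, model.layers[1].get_weights())
)
print(frozen_unchanged, head_changed)
```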
Solution 2:
There is a typo in the word "trainble" (missing an "a"). Sadly, Keras doesn't warn me that the model doesn't have a property called "trainble"; the assignment just creates a new, unused attribute. The question could be closed.
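The silent failure is easy to reproduce. This sketch assumes TensorFlow 2.x; set_existing_attr is a hypothetical helper (not part of Keras) that refuses to create attributes that don't already exist, which would have caught the typo:

```python
import tensorflow as tf

layer = tf.keras.layers.Dense(4)

# The typo creates a brand-new attribute instead of raising an error,
# so the real flag is left untouched.
layer.trainble = False
print(layer.trainable)  # True

def set_existing_attr(obj, name, value):
    """Hypothetical helper: set an attribute only if it already exists."""
    if not hasattr(obj, name):
        raise AttributeError(f"{type(obj).__name__} has no attribute {name!r}")
    setattr(obj, name, value)

other = tf.keras.layers.Dense(4)
set_existing_attr(other, "trainable", False)  # works
try:
    set_existing_attr(other, "trainble", False)  # typo is caught
except AttributeError as err:
    print(err)  # Dense has no attribute 'trainble'
```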
Solution 3:
Although the solution to the original question is a typo fix, let me add some information on trainable state in Keras.
Modern Keras contains the following facilities to view and manipulate trainable state:
tf.keras.layers.Layer._get_trainable_state() - a private method that returns a dictionary whose keys are the layer itself and each of its sublayers and whose values are their trainable booleans. Note that tf.keras.Model is also a tf.keras.layers.Layer, so this works on whole models.
tf.keras.layers.Layer.trainable - a property for manipulating the trainable state of individual layers.
So typical usage looks like the following:
# Print the current trainable map:
print(model._get_trainable_state())
# Set every layer to be non-trainable:
for layer in model._get_trainable_state():
    layer.trainable = False
# Don't forget to re-compile the model
model.compile(...)
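Because a Model is itself a Layer, the loop over the trainable map can likely be replaced by setting the flag on the model, which propagates to every sublayer and uses only public attributes. A minimal sketch, assuming TensorFlow 2.x and a toy model; the optimizer and loss passed to compile() are arbitrary placeholders:

```python
import tensorflow as tf

model = tf.keras.Sequential([
    tf.keras.Input(shape=(4,)),
    tf.keras.layers.Dense(8),
    tf.keras.layers.Dense(1),
])

# Setting the flag on the model sets it on all of its sublayers too.
model.trainable = False
print([l.trainable for l in model.layers])  # [False, False]
print(len(model.trainable_weights))         # 0

# Re-compile so training picks up the new state.
model.compile(optimizer="adam", loss="mse")
```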
Solution 4:
Change the last 3 lines of your code to:
last_few_layers = 20  # number of final layers to leave trainable; all earlier layers are frozen
self.domain_regressor = Model(img_inputs, domain_label)
for layer in model.layers[:-last_few_layers]:
    layer.trainable = False
self.domain_regressor.compile(optimizer=opt, loss='binary_crossentropy', metrics=['accuracy'])
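The slice model.layers[:-last_few_layers] selects every layer except the last last_few_layers, so this loop freezes the early layers and leaves the final ones trainable. A minimal sketch of the same pattern as a standalone function, assuming TensorFlow 2.x; freeze_all_but_last is a hypothetical name and the five-layer toy model is arbitrary:

```python
import tensorflow as tf

def freeze_all_but_last(model, n):
    """Hypothetical helper: freeze every layer except the last n."""
    # Compute the stop index explicitly: model.layers[:-0] is an empty
    # list, so with n=0 a plain [:-n] slice would freeze nothing.
    stop = max(len(model.layers) - n, 0)
    for layer in model.layers[:stop]:
        layer.trainable = False

model = tf.keras.Sequential(
    [tf.keras.Input(shape=(4,))] + [tf.keras.layers.Dense(4) for _ in range(5)]
)
freeze_all_but_last(model, 2)
print([l.trainable for l in model.layers])  # [False, False, False, True, True]

# Compile after changing the flags (optimizer and loss are placeholders).
model.compile(optimizer="adam", loss="binary_crossentropy")
```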