diff --git a/codegen_sources/model/src/model/__init__.py b/codegen_sources/model/src/model/__init__.py index 6c95b733..6f371d58 100644 --- a/codegen_sources/model/src/model/__init__.py +++ b/codegen_sources/model/src/model/__init__.py @@ -343,7 +343,8 @@ def reload_transformer( clean_model_state_dict(reloaded, model_type, model_number) reload_word_embeddings(reloaded, dico, model_type) reload_lang_embeddings(reloaded, params, model_type) - reload_position_embeddings(reloaded, model, model_type) + if not params.cape_embeddings: + reload_position_embeddings(reloaded, model, model_type) # if the model is a decoder if hasattr(model, "encoder_attn"):