Matt commited on
Commit
b5c7055
·
1 Parent(s): f285b2c

Tie weights correctly

Browse files
Files changed (1) hide show
  1. modeling_florence2.py +8 -2
modeling_florence2.py CHANGED
@@ -2066,6 +2066,12 @@ class Florence2LanguageForConditionalGeneration(Florence2LanguagePreTrainedModel
2066
  # Initialize weights and apply final processing
2067
  self.post_init()
2068
 
 
 
 
 
 
 
2069
  def get_encoder(self):
2070
  return self.model.get_encoder()
2071
 
@@ -2523,6 +2529,8 @@ class Florence2VisionModelWithProjection(Florence2PreTrainedModel):
2523
  FLORENCE2_START_DOCSTRING,
2524
  )
2525
  class Florence2ForConditionalGeneration(Florence2PreTrainedModel):
 
 
2526
  def __init__(self, config: Florence2Config):
2527
  super().__init__(config)
2528
  assert config.vision_config.model_type == 'davit', 'only DaViT is supported for now'
@@ -2537,8 +2545,6 @@ class Florence2ForConditionalGeneration(Florence2PreTrainedModel):
2537
 
2538
  language_model = Florence2LanguageForConditionalGeneration(config=config.text_config)
2539
 
2540
- if language_model._tied_weights_keys is not None:
2541
- self._tied_weights_keys = [f"language_model.{k}" for k in language_model._tied_weights_keys]
2542
  self.language_model = language_model
2543
 
2544
  self.pad_token_id = self.config.pad_token_id if self.config.pad_token_id is not None else -1
 
2066
  # Initialize weights and apply final processing
2067
  self.post_init()
2068
 
2069
+ def _tie_weights(self):
2070
+ if self.config.tie_word_embeddings:
2071
+ self._tie_or_clone_weights(self.model.encoder.embed_tokens, self.model.shared)
2072
+ self._tie_or_clone_weights(self.model.decoder.embed_tokens, self.model.shared)
2073
+ self._tie_or_clone_weights(self.lm_head, self.model.shared)
2074
+
2075
  def get_encoder(self):
2076
  return self.model.get_encoder()
2077
 
 
2529
  FLORENCE2_START_DOCSTRING,
2530
  )
2531
  class Florence2ForConditionalGeneration(Florence2PreTrainedModel):
2532
+ _tied_weights_keys = ["language_model.encoder.embed_tokens.weight", "language_model.decoder.embed_tokens.weight", "language_model.lm_head.weight"]
2533
+
2534
  def __init__(self, config: Florence2Config):
2535
  super().__init__(config)
2536
  assert config.vision_config.model_type == 'davit', 'only DaViT is supported for now'
 
2545
 
2546
  language_model = Florence2LanguageForConditionalGeneration(config=config.text_config)
2547
 
 
 
2548
  self.language_model = language_model
2549
 
2550
  self.pad_token_id = self.config.pad_token_id if self.config.pad_token_id is not None else -1