Matt
commited on
Commit
·
6813a8d
1
Parent(s):
24d011b
Tie weights correctly
Browse files- modeling_florence2.py +8 -2
modeling_florence2.py
CHANGED
@@ -2066,6 +2066,12 @@ class Florence2LanguageForConditionalGeneration(Florence2LanguagePreTrainedModel
|
|
2066 |
# Initialize weights and apply final processing
|
2067 |
self.post_init()
|
2068 |
|
|
|
|
|
|
|
|
|
|
|
|
|
2069 |
def get_encoder(self):
|
2070 |
return self.model.get_encoder()
|
2071 |
|
@@ -2523,6 +2529,8 @@ class Florence2VisionModelWithProjection(Florence2PreTrainedModel):
|
|
2523 |
FLORENCE2_START_DOCSTRING,
|
2524 |
)
|
2525 |
class Florence2ForConditionalGeneration(Florence2PreTrainedModel):
|
|
|
|
|
2526 |
def __init__(self, config: Florence2Config):
|
2527 |
super().__init__(config)
|
2528 |
assert config.vision_config.model_type == 'davit', 'only DaViT is supported for now'
|
@@ -2537,8 +2545,6 @@ class Florence2ForConditionalGeneration(Florence2PreTrainedModel):
|
|
2537 |
|
2538 |
language_model = Florence2LanguageForConditionalGeneration(config=config.text_config)
|
2539 |
|
2540 |
-
if language_model._tied_weights_keys is not None:
|
2541 |
-
self._tied_weights_keys = [f"language_model.{k}" for k in language_model._tied_weights_keys]
|
2542 |
self.language_model = language_model
|
2543 |
|
2544 |
self.pad_token_id = self.config.pad_token_id if self.config.pad_token_id is not None else -1
|
|
|
2066 |
# Initialize weights and apply final processing
|
2067 |
self.post_init()
|
2068 |
|
2069 |
+
def _tie_weights(self):
|
2070 |
+
if self.config.tie_word_embeddings:
|
2071 |
+
self._tie_or_clone_weights(self.model.encoder.embed_tokens, self.model.shared)
|
2072 |
+
self._tie_or_clone_weights(self.model.decoder.embed_tokens, self.model.shared)
|
2073 |
+
self._tie_or_clone_weights(self.lm_head, self.model.shared)
|
2074 |
+
|
2075 |
def get_encoder(self):
|
2076 |
return self.model.get_encoder()
|
2077 |
|
|
|
2529 |
FLORENCE2_START_DOCSTRING,
|
2530 |
)
|
2531 |
class Florence2ForConditionalGeneration(Florence2PreTrainedModel):
|
2532 |
+
_tied_weights_keys = ["language_model.encoder.embed_tokens.weight", "language_model.decoder.embed_tokens.weight", "language_model.lm_head.weight"]
|
2533 |
+
|
2534 |
def __init__(self, config: Florence2Config):
|
2535 |
super().__init__(config)
|
2536 |
assert config.vision_config.model_type == 'davit', 'only DaViT is supported for now'
|
|
|
2545 |
|
2546 |
language_model = Florence2LanguageForConditionalGeneration(config=config.text_config)
|
2547 |
|
|
|
|
|
2548 |
self.language_model = language_model
|
2549 |
|
2550 |
self.pad_token_id = self.config.pad_token_id if self.config.pad_token_id is not None else -1
|