Xenova (HF Staff) committed
Commit 0e7154d · verified · 1 Parent(s): e255b4e

Update README.md

Files changed (1)
  1. README.md +245 -0
README.md CHANGED
@@ -8,6 +8,251 @@ tags: []
  <!-- Provide a quick summary of what the model is/does. -->


+ ## ONNX export code
+
+ ```py
+ import os
+ import torch
+ from transformers import (
+     AutoProcessor,
+     Qwen2VLForConditionalGeneration,
+     DynamicCache,
+ )
+
+
+ class PatchedQwen2VLForConditionalGeneration(Qwen2VLForConditionalGeneration):
+     def forward(self, *args):
+         inputs_embeds, attention_mask, position_ids, *past_key_values_args = args
+
+         # Convert past_key_values list to DynamicCache
+         if len(past_key_values_args) == 0:
+             past_key_values = None
+         else:
+             past_key_values = DynamicCache(self.config.num_hidden_layers)
+             for i in range(self.config.num_hidden_layers):
+                 key = past_key_values_args.pop(0)
+                 value = past_key_values_args.pop(0)
+                 past_key_values.update(key_states=key, value_states=value, layer_idx=i)
+
+         o = super().forward(
+             inputs_embeds=inputs_embeds,
+             attention_mask=attention_mask,
+             position_ids=position_ids,
+             past_key_values=past_key_values,
+         )
+
+         flattened_past_key_values_outputs = {
+             "logits": o.logits,
+         }
+         output_past_key_values: DynamicCache = o.past_key_values
+         for i, (key, value) in enumerate(
+             zip(output_past_key_values.key_cache, output_past_key_values.value_cache)
+         ):
+             flattened_past_key_values_outputs[f"present.{i}.key"] = key
+             flattened_past_key_values_outputs[f"present.{i}.value"] = value
+
+         return flattened_past_key_values_outputs
+
+
+ # Constants
+ OUTPUT_FOLDER = "output"
+ EMBEDDING_MODEL_NAME = "embed_tokens.onnx"
+ TEXT_MODEL_NAME = "decoder_model_merged.onnx"
+ VISION_MODEL_NAME = "vision_encoder.onnx"
+ TEMP_MODEL_OUTPUT_FOLDER = os.path.join(OUTPUT_FOLDER, "temp")
+ FINAL_MODEL_OUTPUT_FOLDER = os.path.join(OUTPUT_FOLDER, "onnx")
+
+
+ # Load model and processor
+ model_id = "hf-internal-testing/tiny-random-Qwen2VLForConditionalGeneration"
+ model = PatchedQwen2VLForConditionalGeneration.from_pretrained(model_id).eval()
+ processor = AutoProcessor.from_pretrained(model_id)
+
+
+ # Save model configs and processor
+ model.config.save_pretrained(OUTPUT_FOLDER)
+ model.generation_config.save_pretrained(OUTPUT_FOLDER)
+ processor.save_pretrained(OUTPUT_FOLDER)
+ os.makedirs(TEMP_MODEL_OUTPUT_FOLDER, exist_ok=True)
+
+
+ # Configuration values
+ ## Text model
+ text_config = model.config
+ num_heads = text_config.num_attention_heads
+ num_key_value_heads = text_config.num_key_value_heads
+ head_dim = text_config.hidden_size // num_heads
+ num_layers = text_config.num_hidden_layers
+ hidden_size = text_config.hidden_size
+
+ ## Vision model
+ vision_config = model.config.vision_config
+ channel = vision_config.in_chans
+ temporal_patch_size = vision_config.temporal_patch_size
+ patch_size = vision_config.spatial_patch_size
+
+
+ # Dummy input sizes
+ grid_t, grid_h, grid_w = [1, 16, 16]
+ batch_size = 1
+ sequence_length = 16
+ num_channels = 3
+ past_sequence_length = 0
+
+ image_batch_size = 1  # TODO: Add support for > 1 images
+ assert image_batch_size == 1
+
+
+ # Dummy inputs
+ ## Embedding inputs
+ input_ids = torch.randint(
+     0, model.config.vocab_size, (batch_size, sequence_length), dtype=torch.int64
+ )
+
+ ## Text inputs
+ dummy_past_key_values_kwargs = {
+     f"past_key_values.{i}.{key}": torch.zeros(
+         batch_size,
+         num_key_value_heads,
+         past_sequence_length,
+         head_dim,
+         dtype=torch.float32,
+     )
+     for i in range(num_layers)
+     for key in ["key", "value"]
+ }
+ inputs_embeds = torch.ones(
+     batch_size, sequence_length, hidden_size, dtype=torch.float32
+ )
+ attention_mask = torch.ones(batch_size, sequence_length, dtype=torch.int64)
+ position_ids = torch.ones(3, batch_size, sequence_length, dtype=torch.int64)
+
+ ## Vision inputs
+ grid_thw = torch.tensor(
+     [[grid_t, grid_h, grid_w]] * image_batch_size, dtype=torch.int64
+ )
+ pixel_values = torch.randn(
+     image_batch_size * grid_t * grid_h * grid_w,
+     channel * temporal_patch_size * patch_size * patch_size,
+     dtype=torch.float32,
+ )
+
+
+ # ONNX Exports
+ ## Embedding model
+ embedding_inputs = dict(input_ids=input_ids)
+ embedding_inputs_positional = tuple(embedding_inputs.values())
+ model.model.embed_tokens(*embedding_inputs_positional)  # Test forward pass
+ EMBED_TOKENS_OUTPUT_PATH = os.path.join(TEMP_MODEL_OUTPUT_FOLDER, EMBEDDING_MODEL_NAME)
+ torch.onnx.export(
+     model.model.embed_tokens,
+     args=embedding_inputs_positional,
+     f=EMBED_TOKENS_OUTPUT_PATH,
+     export_params=True,
+     opset_version=14,
+     do_constant_folding=True,
+     input_names=list(embedding_inputs.keys()),
+     output_names=["inputs_embeds"],
+     dynamic_axes={
+         "input_ids": {0: "batch_size", 1: "sequence_length"},
+         "inputs_embeds": {0: "batch_size", 1: "sequence_length"},
+     },
+ )
+
+ ## Text model
+ text_inputs = dict(
+     inputs_embeds=inputs_embeds,
+     attention_mask=attention_mask,
+     position_ids=position_ids,
+     **dummy_past_key_values_kwargs,
+ )
+ text_inputs_positional = tuple(text_inputs.values())
+ text_outputs = model.forward(*text_inputs_positional)  # Test forward pass
+ TEXT_MODEL_OUTPUT_PATH = os.path.join(TEMP_MODEL_OUTPUT_FOLDER, TEXT_MODEL_NAME)
+ torch.onnx.export(
+     model,
+     args=text_inputs_positional,
+     f=TEXT_MODEL_OUTPUT_PATH,
+     export_params=True,
+     opset_version=14,
+     do_constant_folding=True,
+     input_names=list(text_inputs.keys()),
+     output_names=["logits"]
+     + [f"present.{i}.{key}" for i in range(num_layers) for key in ["key", "value"]],
+     dynamic_axes={
+         "inputs_embeds": {0: "batch_size", 1: "sequence_length"},
+         "attention_mask": {0: "batch_size", 1: "sequence_length"},
+         "position_ids": {1: "batch_size", 2: "sequence_length"},
+         **{
+             f"past_key_values.{i}.{key}": {0: "batch_size", 2: "past_sequence_length"}
+             for i in range(num_layers)
+             for key in ["key", "value"]
+         },
+         "logits": {0: "batch_size", 1: "sequence_length"},
+         **{
+             f"present.{i}.{key}": {0: "batch_size", 2: "past_sequence_length + 1"}
+             for i in range(num_layers)
+             for key in ["key", "value"]
+         },
+     },
+ )
+
+ ## Vision model
+ vision_inputs = dict(
+     pixel_values=pixel_values,
+     grid_thw=grid_thw,
+ )
+ vision_inputs_positional = tuple(vision_inputs.values())
+ vision_outputs = model.visual.forward(*vision_inputs_positional)  # Test forward pass
+ VISION_ENCODER_OUTPUT_PATH = os.path.join(TEMP_MODEL_OUTPUT_FOLDER, VISION_MODEL_NAME)
+ torch.onnx.export(
+     model.visual,
+     args=vision_inputs_positional,
+     f=VISION_ENCODER_OUTPUT_PATH,
+     export_params=True,
+     opset_version=14,
+     do_constant_folding=True,
+     input_names=list(vision_inputs.keys()),
+     output_names=["image_features"],
+     dynamic_axes={
+         "pixel_values": {
+             0: "batch_size * grid_t * grid_h * grid_w",
+             1: "channel * temporal_patch_size * patch_size * patch_size",
+         },
+         "grid_thw": {0: "batch_size"},
+         "image_features": {0: "batch_size * grid_t * grid_h * grid_w"},
+     },
+ )
+
+
+ # Post-processing
+ import onnx
+ import onnxslim
+ from optimum.onnx.graph_transformations import check_and_save_model
+
+ os.makedirs(FINAL_MODEL_OUTPUT_FOLDER, exist_ok=True)
+ for name in (EMBEDDING_MODEL_NAME, TEXT_MODEL_NAME, VISION_MODEL_NAME):
+     temp_model_path = os.path.join(TEMP_MODEL_OUTPUT_FOLDER, name)
+
+     ## Shape inference (especially needed by the vision encoder)
+     onnx.shape_inference.infer_shapes_path(temp_model_path, check_type=True, strict_mode=True)
+
+     ## Attempt to optimize the model with onnxslim
+     try:
+         model = onnxslim.slim(temp_model_path)
+     except Exception as e:
+         print(f"Failed to slim {model}: {e}")
+         model = onnx.load(temp_model_path)
+
+     ## Save model
+     final_model_path = os.path.join(FINAL_MODEL_OUTPUT_FOLDER, name)
+     check_and_save_model(model, final_model_path)
+
+ ## Cleanup
+ import shutil
+ shutil.rmtree(TEMP_MODEL_OUTPUT_FOLDER)
+ ```
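For a quick sanity check of the exported files, something like the sketch below can be used. It assumes `onnxruntime` and `numpy` are installed, that the files are under the default `output/onnx/` folder produced by the script above, and it reads the layer/head dimensions back from the `config.json` the script saves; it then runs a single prefill step of the text decoder with an empty KV cache.

```py
import json
import numpy as np
import onnxruntime as ort

# Dimensions come from the config.json written by model.config.save_pretrained(...)
with open("output/config.json") as f:
    cfg = json.load(f)
num_layers = cfg["num_hidden_layers"]
num_kv_heads = cfg["num_key_value_heads"]
head_dim = cfg["hidden_size"] // cfg["num_attention_heads"]

embed = ort.InferenceSession("output/onnx/embed_tokens.onnx")
decoder = ort.InferenceSession("output/onnx/decoder_model_merged.onnx")

batch_size, sequence_length = 1, 4
input_ids = np.random.randint(
    0, cfg["vocab_size"], (batch_size, sequence_length), dtype=np.int64
)

# input_ids -> inputs_embeds
(inputs_embeds,) = embed.run(None, {"input_ids": input_ids})

# One decoder pass with an empty (zero-length) KV cache, mirroring the export's dummy inputs
decoder_inputs = {
    "inputs_embeds": inputs_embeds,
    "attention_mask": np.ones((batch_size, sequence_length), dtype=np.int64),
    "position_ids": np.ones((3, batch_size, sequence_length), dtype=np.int64),
    **{
        f"past_key_values.{i}.{k}": np.zeros(
            (batch_size, num_kv_heads, 0, head_dim), dtype=np.float32
        )
        for i in range(num_layers)
        for k in ("key", "value")
    },
}
logits, *present = decoder.run(None, decoder_inputs)
print("logits:", logits.shape)          # (batch_size, sequence_length, vocab_size)
print("present.0.key:", present[0].shape)
```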
+
 
  ## Model Details