Lord-Raven committed
Commit b0d2a02 · Parent: d184de8

Trying to use ONNX model.
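
Note: app.py below loads "onnx/model.onnx" from disk, but this commit does not show where that file comes from. For reference, a minimal sketch of how such a file could be produced from the same checkpoint — this is an assumption, not part of the commit: it uses Hugging Face Optimum, presumes the checkpoint ships PyTorch weights, and the output directory simply mirrors the path app.py reads.

# Hypothetical export step, not part of this commit.
# Requires: pip install optimum[onnxruntime]
from optimum.onnxruntime import ORTModelForSequenceClassification
from transformers import AutoTokenizer

model_name = "xenova/mobilebert-uncased-mnli"

# export=True converts the PyTorch weights to ONNX on the fly
ort_model = ORTModelForSequenceClassification.from_pretrained(model_name, export=True)
tokenizer = AutoTokenizer.from_pretrained(model_name)

# Writes model.onnx (plus configs) into the onnx/ directory that app.py reads
ort_model.save_pretrained("onnx")
tokenizer.save_pretrained("onnx")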

Files changed (1): app.py (+73 -1)
app.py CHANGED
@@ -1,12 +1,70 @@
 import gradio
 import json
+import torch
 from transformers import pipeline
 from transformers import AutoTokenizer
 from fastapi import FastAPI
 from fastapi.middleware.cors import CORSMiddleware
+from onnxruntime import (
+    InferenceSession, SessionOptions, GraphOptimizationLevel
+)
+from transformers import (
+    TokenClassificationPipeline, AutoTokenizer, AutoModelForTokenClassification
+)
+
+class OnnxTokenClassificationPipeline(TokenClassificationPipeline):
+
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+
+
+    def _forward(self, model_inputs):
+        """
+        Forward pass through the model. This method is not called by the user directly; it is
+        only used by the pipeline to perform the actual predictions.
+        This is where we define the actual inference against the ONNX model, using the session
+        created below.
+        """
+
+        # This comes from the original implementation of the pipeline
+        special_tokens_mask = model_inputs.pop("special_tokens_mask")
+        offset_mapping = model_inputs.pop("offset_mapping", None)
+        sentence = model_inputs.pop("sentence")
+
+        inputs = {k: v.cpu().detach().numpy() for k, v in model_inputs.items()}  # dict of numpy arrays
+        outputs_name = session.get_outputs()[0].name  # get the name of the output tensor
+
+        logits = session.run(output_names=[outputs_name], input_feed=inputs)[0]  # run the session
+        logits = torch.tensor(logits)  # convert back to a torch tensor to stay compatible with the original implementation
+
+        return {
+            "logits": logits,
+            "special_tokens_mask": special_tokens_mask,
+            "offset_mapping": offset_mapping,
+            "sentence": sentence,
+            **model_inputs,
+        }
 
-classifier = pipeline(task='zero-shot-classification', model='xenova/mobilebert-uncased-mnli')
+    # We need to override the preprocess method because the ONNX model expects the attention mask
+    # as an input alongside the embeddings.
+    def preprocess(self, sentence, offset_mapping=None):
+        truncation = True if self.tokenizer.model_max_length and self.tokenizer.model_max_length > 0 else False
+        model_inputs = self.tokenizer(
+            sentence,
+            return_attention_mask=True,  # this is the only difference from the original implementation
+            return_tensors=self.framework,
+            truncation=truncation,
+            return_special_tokens_mask=True,
+            return_offsets_mapping=self.tokenizer.is_fast,
+        )
+        if offset_mapping:
+            model_inputs["offset_mapping"] = offset_mapping
 
+        model_inputs["sentence"] = sentence
+
+        return model_inputs
67
+ # CORS Config
68
  app = FastAPI()
69
 
70
  app.add_middleware(
 
@@ -17,6 +75,20 @@ app.add_middleware(
     allow_headers=["*"],
 )
 
+options = SessionOptions()
+options.graph_optimization_level = GraphOptimizationLevel.ORT_ENABLE_ALL
+
+session = InferenceSession("onnx/model.onnx", sess_options=options, providers=["CPUExecutionProvider"])
+
+session.disable_fallback()
+
+model_name = "xenova/mobilebert-uncased-mnli"
+
+tokenizer = AutoTokenizer.from_pretrained(model_name)
+model = AutoModelForTokenClassification.from_pretrained(model_name)
+
+classifier = OnnxTokenClassificationPipeline(task="zero-shot-classification", model=model, tokenizer=tokenizer, framework="pt", aggregation_strategy="simple")
+
 def zero_shot_classification(data_string):
     print(data_string)
     data = json.loads(data_string)
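
Note: the pipeline above subclasses TokenClassificationPipeline yet is constructed with task="zero-shot-classification", which transformers may reject or mis-route. A minimal sketch for sanity-checking the ONNX session independently of the pipeline — an assumption, not part of this commit: it presumes the usual MNLI premise/hypothesis setup for zero-shot, and the example strings are illustrative.

# Hypothetical smoke test, not part of this commit.
import numpy as np
from onnxruntime import InferenceSession
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("xenova/mobilebert-uncased-mnli")
session = InferenceSession("onnx/model.onnx", providers=["CPUExecutionProvider"])

premise = "A user is chatting with a bot about video games."
hypothesis = "This text is about gaming."

# Encode as an NLI pair; NumPy tensors feed straight into ONNX Runtime
encoded = tokenizer(premise, hypothesis, return_tensors="np")
valid_names = {i.name for i in session.get_inputs()}
feed = {k: v for k, v in encoded.items() if k in valid_names}

logits = session.run(None, feed)[0]  # shape (1, 3) for an MNLI head
probs = np.exp(logits) / np.exp(logits).sum(axis=-1, keepdims=True)
print(probs)  # label order follows the model's id2label (e.g. contradiction/neutral/entailment)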