Lord-Raven commited on
Commit
0974f51
·
1 Parent(s): 54d574f

Adding another allowed origin. Commenting out few-shot model that is unused.

Browse files
Files changed (1) hide show
  1. app.py +68 -71
app.py CHANGED
@@ -17,42 +17,42 @@ app = FastAPI()
17
 
18
  app.add_middleware(
19
  CORSMiddleware,
20
- allow_origins=["https://statosphere-3704059fdd7e.c5v4v4jx6pq5.win","https://crunchatize-77a78ffcc6a6.c5v4v4jx6pq5.win","https://lord-raven.github.io"],
21
  allow_credentials=True,
22
  allow_methods=["*"],
23
  allow_headers=["*"],
24
  )
25
 
26
- class OnnxSetFitModel:
27
- def __init__(self, ort_model, tokenizer, model_head):
28
- self.ort_model = ort_model
29
- self.tokenizer = tokenizer
30
- self.model_head = model_head
31
-
32
- def predict(self, inputs):
33
- encoded_inputs = self.tokenizer(
34
- inputs, padding=True, truncation=True, return_tensors="pt"
35
- ).to(self.ort_model.device)
36
-
37
- outputs = self.ort_model(**encoded_inputs)
38
- embeddings = mean_pooling(
39
- outputs["last_hidden_state"], encoded_inputs["attention_mask"]
40
- )
41
- return self.model_head.predict(embeddings.cpu())
42
-
43
- def predict_proba(self, inputs):
44
- encoded_inputs = self.tokenizer(
45
- inputs, padding=True, truncation=True, return_tensors="pt"
46
- ).to(self.ort_model.device)
47
-
48
- outputs = self.ort_model(**encoded_inputs)
49
- embeddings = mean_pooling(
50
- outputs["last_hidden_state"], encoded_inputs["attention_mask"]
51
- )
52
- return self.model_head.predict_proba(embeddings.cpu())
53
-
54
- def __call__(self, inputs):
55
- return self.predict(inputs)
56
 
57
  # "xenova/mobilebert-uncased-mnli" "typeform/mobilebert-uncased-mnli" Fast but small--same as bundled in Statosphere
58
  # "xenova/deberta-v3-base-tasksource-nli" Not impressed
@@ -67,31 +67,28 @@ model = ORTModelForSequenceClassification.from_pretrained(model_name, file_name=
67
  tokenizer = AutoTokenizer.from_pretrained(tokenizer_name, model_max_length=512)
68
  classifier = pipeline(task="zero-shot-classification", model=model, tokenizer=tokenizer)
69
 
70
- few_shot_tokenizer = AutoTokenizer.from_pretrained('BAAI/bge-small-en-v1.5', model_max_length=512) # 'BAAI/bge-small-en-v1.5'
71
- ort_model = ORTModelForFeatureExtraction.from_pretrained('BAAI/bge-small-en-v1.5', file_name="onnx/model.onnx") # 'BAAI/bge-small-en-v1.5'
72
- few_shot_model = SetFitModel.from_pretrained("moshew/bge-small-en-v1.5_setfit-sst2-english") # "moshew/bge-small-en-v1.5_setfit-sst2-english"
73
 
74
  # Train few_shot_model
75
- candidate_labels = ["supported", "refuted"]
76
- reference_dataset = load_dataset("SetFit/sst2")
77
- dummy_dataset = Dataset.from_dict({})
78
- train_dataset = get_templated_dataset(dummy_dataset, candidate_labels=candidate_labels, sample_size=8, template="The CONCLUSION is {} by the PASSAGE.")
79
- args = TrainingArguments(
80
- batch_size=32,
81
- num_epochs=1
82
- )
83
- trainer = Trainer(
84
- model=few_shot_model,
85
- args=args,
86
- train_dataset=train_dataset,
87
- eval_dataset=reference_dataset["test"]
88
- )
89
- trainer.train()
90
-
91
- # metrics = trainer.evaluate()
92
- # print(metrics)
93
-
94
- onnx_few_shot_model = OnnxSetFitModel(ort_model, few_shot_tokenizer, few_shot_model.model_head)
95
 
96
 
97
 
@@ -100,10 +97,10 @@ def classify(data_string, request: gradio.Request):
100
  if request.headers["origin"] not in ["https://statosphere-3704059fdd7e.c5v4v4jx6pq5.win", "https://crunchatize-77a78ffcc6a6.c5v4v4jx6pq5.win", "https://ravenok-statosphere-backend.hf.space", "https://lord-raven.github.io"]:
101
  return "{}"
102
  data = json.loads(data_string)
103
- if 'task' in data and data['task'] == 'few_shot_classification':
104
- return few_shot_classification(data)
105
- else:
106
- return zero_shot_classification(data)
107
 
108
  def zero_shot_classification(data):
109
  results = classifier(data['sequence'], candidate_labels=data['candidate_labels'], hypothesis_template=data['hypothesis_template'], multi_label=data['multi_label'])
@@ -114,22 +111,22 @@ def create_sequences(data):
114
  # return ['###Given:\n' + data['sequence'] + '\n###End Given\n###Hypothesis:\n' + data['hypothesis_template'].format(label) + "\n###End Hypothesis" for label in data['candidate_labels']]
115
  return [data['sequence'] + '\n' + data['hypothesis_template'].format(label) for label in data['candidate_labels']]
116
 
117
- def few_shot_classification(data):
118
- sequences = create_sequences(data)
119
- print(sequences)
120
- # results = onnx_few_shot_model(sequences)
121
- probs = onnx_few_shot_model.predict_proba(sequences)
122
- scores = [true[0] for true in probs]
123
 
124
- composite = list(zip(scores, data['candidate_labels']))
125
- composite = sorted(composite, key=lambda x: x[0], reverse=True)
126
 
127
- labels, scores = zip(*composite)
128
 
129
- response_dict = {'scores': scores, 'labels': labels}
130
- print(response_dict)
131
- response_string = json.dumps(response_dict)
132
- return response_string
133
 
134
  gradio_interface = gradio.Interface(
135
  fn = classify,
 
17
 
18
  app.add_middleware(
19
  CORSMiddleware,
20
+ allow_origins=["https://statosphere-3704059fdd7e.c5v4v4jx6pq5.win","https://crunchatize-77a78ffcc6a6.c5v4v4jx6pq5.win","https://tamabotchi-2dba63df3bf1.c5v4v4jx6pq5.win","https://lord-raven.github.io"],
21
  allow_credentials=True,
22
  allow_methods=["*"],
23
  allow_headers=["*"],
24
  )
25
 
26
+ # class OnnxSetFitModel:
27
+ # def __init__(self, ort_model, tokenizer, model_head):
28
+ # self.ort_model = ort_model
29
+ # self.tokenizer = tokenizer
30
+ # self.model_head = model_head
31
+
32
+ # def predict(self, inputs):
33
+ # encoded_inputs = self.tokenizer(
34
+ # inputs, padding=True, truncation=True, return_tensors="pt"
35
+ # ).to(self.ort_model.device)
36
+
37
+ # outputs = self.ort_model(**encoded_inputs)
38
+ # embeddings = mean_pooling(
39
+ # outputs["last_hidden_state"], encoded_inputs["attention_mask"]
40
+ # )
41
+ # return self.model_head.predict(embeddings.cpu())
42
+
43
+ # def predict_proba(self, inputs):
44
+ # encoded_inputs = self.tokenizer(
45
+ # inputs, padding=True, truncation=True, return_tensors="pt"
46
+ # ).to(self.ort_model.device)
47
+
48
+ # outputs = self.ort_model(**encoded_inputs)
49
+ # embeddings = mean_pooling(
50
+ # outputs["last_hidden_state"], encoded_inputs["attention_mask"]
51
+ # )
52
+ # return self.model_head.predict_proba(embeddings.cpu())
53
+
54
+ # def __call__(self, inputs):
55
+ # return self.predict(inputs)
56
 
57
  # "xenova/mobilebert-uncased-mnli" "typeform/mobilebert-uncased-mnli" Fast but small--same as bundled in Statosphere
58
  # "xenova/deberta-v3-base-tasksource-nli" Not impressed
 
67
  tokenizer = AutoTokenizer.from_pretrained(tokenizer_name, model_max_length=512)
68
  classifier = pipeline(task="zero-shot-classification", model=model, tokenizer=tokenizer)
69
 
70
+ # few_shot_tokenizer = AutoTokenizer.from_pretrained('BAAI/bge-small-en-v1.5', model_max_length=512) # 'BAAI/bge-small-en-v1.5'
71
+ # ort_model = ORTModelForFeatureExtraction.from_pretrained('BAAI/bge-small-en-v1.5', file_name="onnx/model.onnx") # 'BAAI/bge-small-en-v1.5'
72
+ # few_shot_model = SetFitModel.from_pretrained("moshew/bge-small-en-v1.5_setfit-sst2-english") # "moshew/bge-small-en-v1.5_setfit-sst2-english"
73
 
74
  # Train few_shot_model
75
+ # candidate_labels = ["supported", "refuted"]
76
+ # reference_dataset = load_dataset("SetFit/sst2")
77
+ # dummy_dataset = Dataset.from_dict({})
78
+ # train_dataset = get_templated_dataset(dummy_dataset, candidate_labels=candidate_labels, sample_size=8, template="The CONCLUSION is {} by the PASSAGE.")
79
+ # args = TrainingArguments(
80
+ # batch_size=32,
81
+ # num_epochs=1
82
+ # )
83
+ # trainer = Trainer(
84
+ # model=few_shot_model,
85
+ # args=args,
86
+ # train_dataset=train_dataset,
87
+ # eval_dataset=reference_dataset["test"]
88
+ # )
89
+ # trainer.train()
90
+
91
+ # onnx_few_shot_model = OnnxSetFitModel(ort_model, few_shot_tokenizer, few_shot_model.model_head)
 
 
 
92
 
93
 
94
 
 
97
  if request.headers["origin"] not in ["https://statosphere-3704059fdd7e.c5v4v4jx6pq5.win", "https://crunchatize-77a78ffcc6a6.c5v4v4jx6pq5.win", "https://ravenok-statosphere-backend.hf.space", "https://lord-raven.github.io"]:
98
  return "{}"
99
  data = json.loads(data_string)
100
+ # if 'task' in data and data['task'] == 'few_shot_classification':
101
+ # return few_shot_classification(data)
102
+ # else:
103
+ return zero_shot_classification(data)
104
 
105
  def zero_shot_classification(data):
106
  results = classifier(data['sequence'], candidate_labels=data['candidate_labels'], hypothesis_template=data['hypothesis_template'], multi_label=data['multi_label'])
 
111
  # return ['###Given:\n' + data['sequence'] + '\n###End Given\n###Hypothesis:\n' + data['hypothesis_template'].format(label) + "\n###End Hypothesis" for label in data['candidate_labels']]
112
  return [data['sequence'] + '\n' + data['hypothesis_template'].format(label) for label in data['candidate_labels']]
113
 
114
+ # def few_shot_classification(data):
115
+ # sequences = create_sequences(data)
116
+ # print(sequences)
117
+ # # results = onnx_few_shot_model(sequences)
118
+ # probs = onnx_few_shot_model.predict_proba(sequences)
119
+ # scores = [true[0] for true in probs]
120
 
121
+ # composite = list(zip(scores, data['candidate_labels']))
122
+ # composite = sorted(composite, key=lambda x: x[0], reverse=True)
123
 
124
+ # labels, scores = zip(*composite)
125
 
126
+ # response_dict = {'scores': scores, 'labels': labels}
127
+ # print(response_dict)
128
+ # response_string = json.dumps(response_dict)
129
+ # return response_string
130
 
131
  gradio_interface = gradio.Interface(
132
  fn = classify,