Spaces:
Running
on
Zero
Running
on
Zero
Lord-Raven
commited on
Commit
·
0974f51
1
Parent(s):
54d574f
Adding another allowed origin. Commenting out few-shot model that is unused.
Browse files
app.py
CHANGED
@@ -17,42 +17,42 @@ app = FastAPI()
|
|
17 |
|
18 |
app.add_middleware(
|
19 |
CORSMiddleware,
|
20 |
-
allow_origins=["https://statosphere-3704059fdd7e.c5v4v4jx6pq5.win","https://crunchatize-77a78ffcc6a6.c5v4v4jx6pq5.win","https://lord-raven.github.io"],
|
21 |
allow_credentials=True,
|
22 |
allow_methods=["*"],
|
23 |
allow_headers=["*"],
|
24 |
)
|
25 |
|
26 |
-
class OnnxSetFitModel:
|
27 |
-
|
28 |
-
|
29 |
-
|
30 |
-
|
31 |
-
|
32 |
-
|
33 |
-
|
34 |
-
|
35 |
-
|
36 |
-
|
37 |
-
|
38 |
-
|
39 |
-
|
40 |
-
|
41 |
-
|
42 |
-
|
43 |
-
|
44 |
-
|
45 |
-
|
46 |
-
|
47 |
-
|
48 |
-
|
49 |
-
|
50 |
-
|
51 |
-
|
52 |
-
|
53 |
-
|
54 |
-
|
55 |
-
|
56 |
|
57 |
# "xenova/mobilebert-uncased-mnli" "typeform/mobilebert-uncased-mnli" Fast but small--same as bundled in Statosphere
|
58 |
# "xenova/deberta-v3-base-tasksource-nli" Not impressed
|
@@ -67,31 +67,28 @@ model = ORTModelForSequenceClassification.from_pretrained(model_name, file_name=
|
|
67 |
tokenizer = AutoTokenizer.from_pretrained(tokenizer_name, model_max_length=512)
|
68 |
classifier = pipeline(task="zero-shot-classification", model=model, tokenizer=tokenizer)
|
69 |
|
70 |
-
few_shot_tokenizer = AutoTokenizer.from_pretrained('BAAI/bge-small-en-v1.5', model_max_length=512) # 'BAAI/bge-small-en-v1.5'
|
71 |
-
ort_model = ORTModelForFeatureExtraction.from_pretrained('BAAI/bge-small-en-v1.5', file_name="onnx/model.onnx") # 'BAAI/bge-small-en-v1.5'
|
72 |
-
few_shot_model = SetFitModel.from_pretrained("moshew/bge-small-en-v1.5_setfit-sst2-english") # "moshew/bge-small-en-v1.5_setfit-sst2-english"
|
73 |
|
74 |
# Train few_shot_model
|
75 |
-
candidate_labels = ["supported", "refuted"]
|
76 |
-
reference_dataset = load_dataset("SetFit/sst2")
|
77 |
-
dummy_dataset = Dataset.from_dict({})
|
78 |
-
train_dataset = get_templated_dataset(dummy_dataset, candidate_labels=candidate_labels, sample_size=8, template="The CONCLUSION is {} by the PASSAGE.")
|
79 |
-
args = TrainingArguments(
|
80 |
-
|
81 |
-
|
82 |
-
)
|
83 |
-
trainer = Trainer(
|
84 |
-
|
85 |
-
|
86 |
-
|
87 |
-
|
88 |
-
)
|
89 |
-
trainer.train()
|
90 |
-
|
91 |
-
#
|
92 |
-
# print(metrics)
|
93 |
-
|
94 |
-
onnx_few_shot_model = OnnxSetFitModel(ort_model, few_shot_tokenizer, few_shot_model.model_head)
|
95 |
|
96 |
|
97 |
|
@@ -100,10 +97,10 @@ def classify(data_string, request: gradio.Request):
|
|
100 |
if request.headers["origin"] not in ["https://statosphere-3704059fdd7e.c5v4v4jx6pq5.win", "https://crunchatize-77a78ffcc6a6.c5v4v4jx6pq5.win", "https://ravenok-statosphere-backend.hf.space", "https://lord-raven.github.io"]:
|
101 |
return "{}"
|
102 |
data = json.loads(data_string)
|
103 |
-
if 'task' in data and data['task'] == 'few_shot_classification':
|
104 |
-
|
105 |
-
else:
|
106 |
-
|
107 |
|
108 |
def zero_shot_classification(data):
|
109 |
results = classifier(data['sequence'], candidate_labels=data['candidate_labels'], hypothesis_template=data['hypothesis_template'], multi_label=data['multi_label'])
|
@@ -114,22 +111,22 @@ def create_sequences(data):
|
|
114 |
# return ['###Given:\n' + data['sequence'] + '\n###End Given\n###Hypothesis:\n' + data['hypothesis_template'].format(label) + "\n###End Hypothesis" for label in data['candidate_labels']]
|
115 |
return [data['sequence'] + '\n' + data['hypothesis_template'].format(label) for label in data['candidate_labels']]
|
116 |
|
117 |
-
def few_shot_classification(data):
|
118 |
-
|
119 |
-
|
120 |
-
|
121 |
-
|
122 |
-
|
123 |
|
124 |
-
|
125 |
-
|
126 |
|
127 |
-
|
128 |
|
129 |
-
|
130 |
-
|
131 |
-
|
132 |
-
|
133 |
|
134 |
gradio_interface = gradio.Interface(
|
135 |
fn = classify,
|
|
|
17 |
|
18 |
app.add_middleware(
|
19 |
CORSMiddleware,
|
20 |
+
allow_origins=["https://statosphere-3704059fdd7e.c5v4v4jx6pq5.win","https://crunchatize-77a78ffcc6a6.c5v4v4jx6pq5.win","https://tamabotchi-2dba63df3bf1.c5v4v4jx6pq5.win","https://lord-raven.github.io"],
|
21 |
allow_credentials=True,
|
22 |
allow_methods=["*"],
|
23 |
allow_headers=["*"],
|
24 |
)
|
25 |
|
26 |
+
# class OnnxSetFitModel:
|
27 |
+
# def __init__(self, ort_model, tokenizer, model_head):
|
28 |
+
# self.ort_model = ort_model
|
29 |
+
# self.tokenizer = tokenizer
|
30 |
+
# self.model_head = model_head
|
31 |
+
|
32 |
+
# def predict(self, inputs):
|
33 |
+
# encoded_inputs = self.tokenizer(
|
34 |
+
# inputs, padding=True, truncation=True, return_tensors="pt"
|
35 |
+
# ).to(self.ort_model.device)
|
36 |
+
|
37 |
+
# outputs = self.ort_model(**encoded_inputs)
|
38 |
+
# embeddings = mean_pooling(
|
39 |
+
# outputs["last_hidden_state"], encoded_inputs["attention_mask"]
|
40 |
+
# )
|
41 |
+
# return self.model_head.predict(embeddings.cpu())
|
42 |
+
|
43 |
+
# def predict_proba(self, inputs):
|
44 |
+
# encoded_inputs = self.tokenizer(
|
45 |
+
# inputs, padding=True, truncation=True, return_tensors="pt"
|
46 |
+
# ).to(self.ort_model.device)
|
47 |
+
|
48 |
+
# outputs = self.ort_model(**encoded_inputs)
|
49 |
+
# embeddings = mean_pooling(
|
50 |
+
# outputs["last_hidden_state"], encoded_inputs["attention_mask"]
|
51 |
+
# )
|
52 |
+
# return self.model_head.predict_proba(embeddings.cpu())
|
53 |
+
|
54 |
+
# def __call__(self, inputs):
|
55 |
+
# return self.predict(inputs)
|
56 |
|
57 |
# "xenova/mobilebert-uncased-mnli" "typeform/mobilebert-uncased-mnli" Fast but small--same as bundled in Statosphere
|
58 |
# "xenova/deberta-v3-base-tasksource-nli" Not impressed
|
|
|
67 |
tokenizer = AutoTokenizer.from_pretrained(tokenizer_name, model_max_length=512)
|
68 |
classifier = pipeline(task="zero-shot-classification", model=model, tokenizer=tokenizer)
|
69 |
|
70 |
+
# few_shot_tokenizer = AutoTokenizer.from_pretrained('BAAI/bge-small-en-v1.5', model_max_length=512) # 'BAAI/bge-small-en-v1.5'
|
71 |
+
# ort_model = ORTModelForFeatureExtraction.from_pretrained('BAAI/bge-small-en-v1.5', file_name="onnx/model.onnx") # 'BAAI/bge-small-en-v1.5'
|
72 |
+
# few_shot_model = SetFitModel.from_pretrained("moshew/bge-small-en-v1.5_setfit-sst2-english") # "moshew/bge-small-en-v1.5_setfit-sst2-english"
|
73 |
|
74 |
# Train few_shot_model
|
75 |
+
# candidate_labels = ["supported", "refuted"]
|
76 |
+
# reference_dataset = load_dataset("SetFit/sst2")
|
77 |
+
# dummy_dataset = Dataset.from_dict({})
|
78 |
+
# train_dataset = get_templated_dataset(dummy_dataset, candidate_labels=candidate_labels, sample_size=8, template="The CONCLUSION is {} by the PASSAGE.")
|
79 |
+
# args = TrainingArguments(
|
80 |
+
# batch_size=32,
|
81 |
+
# num_epochs=1
|
82 |
+
# )
|
83 |
+
# trainer = Trainer(
|
84 |
+
# model=few_shot_model,
|
85 |
+
# args=args,
|
86 |
+
# train_dataset=train_dataset,
|
87 |
+
# eval_dataset=reference_dataset["test"]
|
88 |
+
# )
|
89 |
+
# trainer.train()
|
90 |
+
|
91 |
+
# onnx_few_shot_model = OnnxSetFitModel(ort_model, few_shot_tokenizer, few_shot_model.model_head)
|
|
|
|
|
|
|
92 |
|
93 |
|
94 |
|
|
|
97 |
if request.headers["origin"] not in ["https://statosphere-3704059fdd7e.c5v4v4jx6pq5.win", "https://crunchatize-77a78ffcc6a6.c5v4v4jx6pq5.win", "https://ravenok-statosphere-backend.hf.space", "https://lord-raven.github.io"]:
|
98 |
return "{}"
|
99 |
data = json.loads(data_string)
|
100 |
+
# if 'task' in data and data['task'] == 'few_shot_classification':
|
101 |
+
# return few_shot_classification(data)
|
102 |
+
# else:
|
103 |
+
return zero_shot_classification(data)
|
104 |
|
105 |
def zero_shot_classification(data):
|
106 |
results = classifier(data['sequence'], candidate_labels=data['candidate_labels'], hypothesis_template=data['hypothesis_template'], multi_label=data['multi_label'])
|
|
|
111 |
# return ['###Given:\n' + data['sequence'] + '\n###End Given\n###Hypothesis:\n' + data['hypothesis_template'].format(label) + "\n###End Hypothesis" for label in data['candidate_labels']]
|
112 |
return [data['sequence'] + '\n' + data['hypothesis_template'].format(label) for label in data['candidate_labels']]
|
113 |
|
114 |
+
# def few_shot_classification(data):
|
115 |
+
# sequences = create_sequences(data)
|
116 |
+
# print(sequences)
|
117 |
+
# # results = onnx_few_shot_model(sequences)
|
118 |
+
# probs = onnx_few_shot_model.predict_proba(sequences)
|
119 |
+
# scores = [true[0] for true in probs]
|
120 |
|
121 |
+
# composite = list(zip(scores, data['candidate_labels']))
|
122 |
+
# composite = sorted(composite, key=lambda x: x[0], reverse=True)
|
123 |
|
124 |
+
# labels, scores = zip(*composite)
|
125 |
|
126 |
+
# response_dict = {'scores': scores, 'labels': labels}
|
127 |
+
# print(response_dict)
|
128 |
+
# response_string = json.dumps(response_dict)
|
129 |
+
# return response_string
|
130 |
|
131 |
gradio_interface = gradio.Interface(
|
132 |
fn = classify,
|