Spaces:
Running
on
Zero
Running
on
Zero
Lord-Raven
committed on
Commit
·
1f33968
1
Parent(s):
2bf9da4
Messing with configuration.
Browse files- app.py +2 -59
- requirements.txt +1 -1
app.py
CHANGED
@@ -5,7 +5,6 @@ import torch
|
|
5 |
from transformers import AutoTokenizer
|
6 |
from transformers import pipeline
|
7 |
from optimum.onnxruntime import ORTModelForSequenceClassification
|
8 |
-
from optimum.onnxruntime import ORTModelForFeatureExtraction
|
9 |
from fastapi import FastAPI
|
10 |
from fastapi.middleware.cors import CORSMiddleware
|
11 |
from setfit import SetFitModel, SetFitTrainer, Trainer, TrainingArguments
|
@@ -27,37 +26,6 @@ app.add_middleware(
|
|
27 |
print(f"Is CUDA available: {torch.cuda.is_available()}")
|
28 |
print(f"CUDA device: {torch.cuda.get_device_name(torch.cuda.current_device())}")
|
29 |
|
30 |
-
# class OnnxSetFitModel:
|
31 |
-
# def __init__(self, ort_model, tokenizer, model_head):
|
32 |
-
# self.ort_model = ort_model
|
33 |
-
# self.tokenizer = tokenizer
|
34 |
-
# self.model_head = model_head
|
35 |
-
|
36 |
-
# def predict(self, inputs):
|
37 |
-
# encoded_inputs = self.tokenizer(
|
38 |
-
# inputs, padding=True, truncation=True, return_tensors="pt"
|
39 |
-
# ).to(self.ort_model.device)
|
40 |
-
|
41 |
-
# outputs = self.ort_model(**encoded_inputs)
|
42 |
-
# embeddings = mean_pooling(
|
43 |
-
# outputs["last_hidden_state"], encoded_inputs["attention_mask"]
|
44 |
-
# )
|
45 |
-
# return self.model_head.predict(embeddings.cpu())
|
46 |
-
|
47 |
-
# def predict_proba(self, inputs):
|
48 |
-
# encoded_inputs = self.tokenizer(
|
49 |
-
# inputs, padding=True, truncation=True, return_tensors="pt"
|
50 |
-
# ).to(self.ort_model.device)
|
51 |
-
|
52 |
-
# outputs = self.ort_model(**encoded_inputs)
|
53 |
-
# embeddings = mean_pooling(
|
54 |
-
# outputs["last_hidden_state"], encoded_inputs["attention_mask"]
|
55 |
-
# )
|
56 |
-
# return self.model_head.predict_proba(embeddings.cpu())
|
57 |
-
|
58 |
-
# def __call__(self, inputs):
|
59 |
-
# return self.predict(inputs)
|
60 |
-
|
61 |
# "xenova/mobilebert-uncased-mnli" "typeform/mobilebert-uncased-mnli" Fast but small--same as bundled in Statosphere
|
62 |
# "xenova/deberta-v3-base-tasksource-nli" Not impressed
|
63 |
# "Xenova/bart-large-mnli" A bit slow
|
@@ -67,34 +35,9 @@ print(f"CUDA device: {torch.cuda.get_device_name(torch.cuda.current_device())}")
|
|
67 |
model_name = "MoritzLaurer/deberta-v3-base-zeroshot-v2.0"
|
68 |
file_name = "onnx/model.onnx"
|
69 |
tokenizer_name = "MoritzLaurer/deberta-v3-base-zeroshot-v2.0"
|
70 |
-
model = ORTModelForSequenceClassification.from_pretrained(model_name, file_name=file_name)
|
71 |
tokenizer = AutoTokenizer.from_pretrained(tokenizer_name, model_max_length=512)
|
72 |
-
classifier = pipeline(task="zero-shot-classification", model=model, tokenizer=tokenizer)
|
73 |
-
|
74 |
-
# few_shot_tokenizer = AutoTokenizer.from_pretrained('BAAI/bge-small-en-v1.5', model_max_length=512) # 'BAAI/bge-small-en-v1.5'
|
75 |
-
# ort_model = ORTModelForFeatureExtraction.from_pretrained('BAAI/bge-small-en-v1.5', file_name="onnx/model.onnx") # 'BAAI/bge-small-en-v1.5'
|
76 |
-
# few_shot_model = SetFitModel.from_pretrained("moshew/bge-small-en-v1.5_setfit-sst2-english") # "moshew/bge-small-en-v1.5_setfit-sst2-english"
|
77 |
-
|
78 |
-
# Train few_shot_model
|
79 |
-
# candidate_labels = ["supported", "refuted"]
|
80 |
-
# reference_dataset = load_dataset("SetFit/sst2")
|
81 |
-
# dummy_dataset = Dataset.from_dict({})
|
82 |
-
# train_dataset = get_templated_dataset(dummy_dataset, candidate_labels=candidate_labels, sample_size=8, template="The CONCLUSION is {} by the PASSAGE.")
|
83 |
-
# args = TrainingArguments(
|
84 |
-
# batch_size=32,
|
85 |
-
# num_epochs=1
|
86 |
-
# )
|
87 |
-
# trainer = Trainer(
|
88 |
-
# model=few_shot_model,
|
89 |
-
# args=args,
|
90 |
-
# train_dataset=train_dataset,
|
91 |
-
# eval_dataset=reference_dataset["test"]
|
92 |
-
# )
|
93 |
-
# trainer.train()
|
94 |
-
|
95 |
-
# onnx_few_shot_model = OnnxSetFitModel(ort_model, few_shot_tokenizer, few_shot_model.model_head)
|
96 |
-
|
97 |
-
|
98 |
|
99 |
def classify(data_string, request: gradio.Request):
|
100 |
if request:
|
|
|
5 |
from transformers import AutoTokenizer
|
6 |
from transformers import pipeline
|
7 |
from optimum.onnxruntime import ORTModelForSequenceClassification
|
|
|
8 |
from fastapi import FastAPI
|
9 |
from fastapi.middleware.cors import CORSMiddleware
|
10 |
from setfit import SetFitModel, SetFitTrainer, Trainer, TrainingArguments
|
|
|
26 |
print(f"Is CUDA available: {torch.cuda.is_available()}")
|
27 |
print(f"CUDA device: {torch.cuda.get_device_name(torch.cuda.current_device())}")
|
28 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
29 |
# "xenova/mobilebert-uncased-mnli" "typeform/mobilebert-uncased-mnli" Fast but small--same as bundled in Statosphere
|
30 |
# "xenova/deberta-v3-base-tasksource-nli" Not impressed
|
31 |
# "Xenova/bart-large-mnli" A bit slow
|
|
|
35 |
model_name = "MoritzLaurer/deberta-v3-base-zeroshot-v2.0"
|
36 |
file_name = "onnx/model.onnx"
|
37 |
tokenizer_name = "MoritzLaurer/deberta-v3-base-zeroshot-v2.0"
|
38 |
+
model = ORTModelForSequenceClassification.from_pretrained(model_name, file_name=file_name, provider="CUDAExecutionProvider")
|
39 |
tokenizer = AutoTokenizer.from_pretrained(tokenizer_name, model_max_length=512)
|
40 |
+
classifier = pipeline(task="zero-shot-classification", model=model, tokenizer=tokenizer, device="cuda:0")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
41 |
|
42 |
def classify(data_string, request: gradio.Request):
|
43 |
if request:
|
requirements.txt
CHANGED
@@ -2,7 +2,7 @@ fastapi==0.88.0
|
|
2 |
huggingface_hub==0.23.5
|
3 |
json5==0.9.25
|
4 |
numpy<2.0
|
5 |
-
optimum[exporters,onnxruntime]==1.21.4
|
6 |
setfit==1.0.3
|
7 |
transformers==4.40.2
|
8 |
sentence-transformers==3.0.1
|
|
|
2 |
huggingface_hub==0.23.5
|
3 |
json5==0.9.25
|
4 |
numpy<2.0
|
5 |
+
optimum[exporters,onnxruntime-gpu]==1.21.4
|
6 |
setfit==1.0.3
|
7 |
transformers==4.40.2
|
8 |
sentence-transformers==3.0.1
|