import spaces  # ZeroGPU helper; import before any CUDA-touching libraries
import json

import gradio
import onnxruntime
import torch
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from optimum.onnxruntime import ORTModelForSequenceClassification
from optimum.pipelines import pipeline
from transformers import AutoTokenizer

# CORS config: restrict the API to known front-end origins.
app = FastAPI()

allowed_origins = [
    "https://statosphere-3704059fdd7e.c5v4v4jx6pq5.win",
    "https://crunchatize-77a78ffcc6a6.c5v4v4jx6pq5.win",
    "https://crunchatize-2-2b4f5b1479a6.c5v4v4jx6pq5.win",
    "https://tamabotchi-2dba63df3bf1.c5v4v4jx6pq5.win",
    "https://lord-raven.github.io",
]

app.add_middleware(
    CORSMiddleware,
    allow_origins=allowed_origins,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

print(f"Is CUDA available: {torch.cuda.is_available()}")
print(f"CUDA device: {torch.cuda.get_device_name(torch.cuda.current_device())}")

# "xenova/mobilebert-uncased-mnli" "typeform/mobilebert-uncased-mnli" Fast but small--same as bundled in Statosphere
# "xenova/deberta-v3-base-tasksource-nli" Not impressed
# "Xenova/bart-large-mnli" A bit slow
# "Xenova/distilbert-base-uncased-mnli" "typeform/distilbert-base-uncased-mnli" Bad answers
# "Xenova/deBERTa-v3-base-mnli" "MoritzLaurer/DeBERTa-v3-base-mnli" Still a bit slow and not great answers
# "xenova/nli-deberta-v3-small" "cross-encoder/nli-deberta-v3-small" Was using this for a good while and it was...okay

# model_name = "MoritzLaurer/deberta-v3-base-zeroshot-v2.0"
# file_name = "onnx/model.onnx"
# tokenizer_name = "MoritzLaurer/deberta-v3-base-zeroshot-v2.0"

# model = ORTModelForSequenceClassification.from_pretrained(model_name, export=True, provider="CUDAExecutionProvider")
# tokenizer = AutoTokenizer.from_pretrained(tokenizer_name, model_max_length=512)

session_options = onnxruntime.SessionOptions()
session_options.log_severity_level = 0  # 0 = VERBOSE; raise once provider setup is stable

print(f"ORTModelForSequenceClassification.from_pretrained")
model = ORTModelForSequenceClassification.from_pretrained(
    "philschmid/tiny-bert-sst2-distilled",
    export=True,
    provider="CUDAExecutionProvider",
    session_options=session_options
)
print(f"AutoTokenizer.from_pretrained")
tokenizer = AutoTokenizer.from_pretrained("philschmid/tiny-bert-sst2-distilled")

# Note: tiny-bert-sst2-distilled looks like a load-testing placeholder; zero-shot
# classification needs an NLI model (e.g. the deberta-v3 zeroshot config above)
# to produce meaningful scores.
classifier = pipeline(task="zero-shot-classification", model=model, tokenizer=tokenizer, device="cuda:0")

print("Testing 1")
@spaces.GPU()
def classify(data_string, request: gradio.Request):
    # Reject calls from origins that are not on the allow list; .get avoids a
    # KeyError when the Origin header is absent.
    if request and request.headers.get("origin") not in allowed_origins + ["https://ravenok-statosphere-backend.hf.space"]:
        return "{}"
    data = json.loads(data_string)
    # if 'task' in data and data['task'] == 'few_shot_classification':
    #     return few_shot_classification(data)
    # else:
    return zero_shot_classification(data)
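
# A hypothetical example of the JSON string classify expects (field names are taken
# from zero_shot_classification below; the values are invented for illustration):
# {
#     "sequence": "I could go for a burger.",
#     "candidate_labels": ["food", "weather", "sports"],
#     "hypothesis_template": "This message is about {}.",
#     "multi_label": false
# }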

print(f"Testing 2")
def zero_shot_classification(data):
    results = []
    # classifier(data['sequence'], candidate_labels=data['candidate_labels'], hypothesis_template=data['hypothesis_template'], multi_label=data['multi_label'])
    response_string = json.dumps(results)
    return response_string
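
# A hypothetical example of the response JSON for the payload sketched above
# (labels come back sorted by score; the scores here are invented):
# {"labels": ["food", "sports", "weather"], "scores": [0.91, 0.06, 0.03]}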

def create_sequences(data):
    # return ['###Given:\n' + data['sequence'] + '\n###End Given\n###Hypothesis:\n' + data['hypothesis_template'].format(label) + "\n###End Hypothesis" for label in data['candidate_labels']]
    return [data['sequence'] + '\n' + data['hypothesis_template'].format(label) for label in data['candidate_labels']]
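
# For example (hypothetical values):
#   create_sequences({'sequence': 'It started to rain.',
#                     'hypothesis_template': 'This text is about {}.',
#                     'candidate_labels': ['weather', 'food']})
# returns ['It started to rain.\nThis text is about weather.',
#          'It started to rain.\nThis text is about food.']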

print(f"Testing 3")
# def few_shot_classification(data):
#     sequences = create_sequences(data)
#     print(sequences)
#     # results = onnx_few_shot_model(sequences)
#     probs = onnx_few_shot_model.predict_proba(sequences)
#     scores = [true[0] for true in probs]

#     composite = list(zip(scores, data['candidate_labels']))
#     composite = sorted(composite, key=lambda x: x[0], reverse=True)

#     labels, scores = zip(*composite)

#     response_dict = {'scores': scores, 'labels': labels}
#     print(response_dict)
#     response_string = json.dumps(response_dict)
#     return response_string

print(f"Testing 4")
gradio_interface = gradio.Interface(
    fn = classify,
    inputs = gradio.Textbox(label="JSON Input"),
    outputs = gradio.Textbox()
)
print(f"Testing 5")
gradio_interface.launch()
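
# A minimal client-side sketch using gradio_client. The Space id
# "Ravenok/statosphere-backend" is inferred from the allowed origins above and
# may differ:
#
#   from gradio_client import Client
#   client = Client("Ravenok/statosphere-backend")
#   result = client.predict(
#       '{"sequence": "I could go for a burger.", "candidate_labels": ["food", "weather"], '
#       '"hypothesis_template": "This message is about {}.", "multi_label": false}',
#       api_name="/predict",
#   )
#   print(result)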