File size: 4,319 Bytes
70b5e18
5825e6d
93643d5
040c521
afbe2ec
5ff29bf
f63295c
3f34143
fd79eb2
 
 
 
 
 
 
 
59c51e8
fd79eb2
 
 
 
b0d2a02
5825e6d
 
1602927
30d670a
6c40a85
dc02763
cfd4b0d
f4c9eb8
39080c2
83433fb
3f34143
 
83433fb
f63295c
 
5825e6d
2bf9da4
50b814c
a822923
1602927
f3bcef9
47a0109
b3cb286
 
f63295c
 
b3cb286
0974f51
 
 
5ff29bf
f63295c
 
 
 
 
 
5ff29bf
50b814c
f63295c
 
 
b3cb286
f63295c
 
5071704
8ec85f2
d2e06fa
8ec85f2
0974f51
 
 
 
 
 
fb2ea03
0974f51
 
fb2ea03
0974f51
fb2ea03
0974f51
 
 
19a483c
93643d5
50b814c
daac94f
0686401
93643d5
19a483c
2b2a5e4
8ec27b2
cdc6ff7
2b2a5e4
a426b5f
277a03c
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
import spaces
import torch
import gradio
import json
import onnxruntime
import time
from datetime import datetime
from transformers import pipeline
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware

# CORS config: browsers may call this FastAPI app only from the front-end
# origins listed here; credentials, every HTTP method and every request header
# are allowed for those origins.
app = FastAPI()

app.add_middleware(
    CORSMiddleware,
    allow_origins=[
        "https://statosphere-3704059fdd7e.c5v4v4jx6pq5.win",
        "https://crunchatize-77a78ffcc6a6.c5v4v4jx6pq5.win",
        "https://crunchatize-2-2b4f5b1479a6.c5v4v4jx6pq5.win",
        "https://tamabotchi-2dba63df3bf1.c5v4v4jx6pq5.win",
    ],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Report the accelerator at startup. The device name is only queried when CUDA
# is reported as available: calling torch.cuda.current_device() /
# get_device_name() without a usable CUDA device raises, which would abort the
# whole module import just for a log line.
print(f"Is CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"CUDA device: {torch.cuda.get_device_name(torch.cuda.current_device())}")

# Zero-shot models tried before settling on the current one (kept for reference):
# "xenova/mobilebert-uncased-mnli" "typeform/mobilebert-uncased-mnli" Fast but small--same as bundled in Statosphere
# "xenova/deberta-v3-base-tasksource-nli" Not impressed
# "Xenova/bart-large-mnli" A bit slow
# "Xenova/distilbert-base-uncased-mnli" "typeform/distilbert-base-uncased-mnli" Bad answers
# "Xenova/deBERTa-v3-base-mnli" "MoritzLaurer/DeBERTa-v3-base-mnli" Still a bit slow and not great answers
# "xenova/nli-deberta-v3-small" "cross-encoder/nli-deberta-v3-small" Was using this for a good while and it was...okay

# The same checkpoint supplies both the model weights and the tokenizer.
model_name = "MoritzLaurer/deberta-v3-base-zeroshot-v2.0"
tokenizer_name = "MoritzLaurer/deberta-v3-base-zeroshot-v2.0"

# Two pipelines over one checkpoint; `classify` chooses between them per request.
# classifier_cpu is given no `device`, so its placement follows the transformers
# default (NOTE(review): expected to be CPU — confirm for the pinned version);
# classifier_gpu is placed explicitly on cuda:0.
_pipeline_kwargs = {
    "task": "zero-shot-classification",
    "model": model_name,
    "tokenizer": tokenizer_name,
}
classifier_cpu = pipeline(**_pipeline_kwargs)
classifier_gpu = pipeline(**_pipeline_kwargs, device="cuda:0")
# classifier = pipeline(task="zero-shot-classification", model=model_name, tokenizer=tokenizer_name)

# Origins whose requests `classify` will answer; anything else gets "{}".
# NOTE(review): overlaps the CORS allow_origins list configured on the FastAPI
# app — keep the two in sync when adding a front end.
_ALLOWED_ORIGINS = frozenset({
    "https://statosphere-3704059fdd7e.c5v4v4jx6pq5.win",
    "https://crunchatize-77a78ffcc6a6.c5v4v4jx6pq5.win",
    "https://crunchatize-2-2b4f5b1479a6.c5v4v4jx6pq5.win",
    "https://tamabotchi-2dba63df3bf1.c5v4v4jx6pq5.win",
    "https://ravenok-statosphere-backend.hf.space",
    "https://lord-raven.github.io",
})


def classify(data_string, request: gradio.Request):
    """Run zero-shot classification on a JSON-encoded request.

    Args:
        data_string: JSON text. ``sequence``, ``candidate_labels``,
            ``hypothesis_template`` and ``multi_label`` are forwarded to the
            zero-shot pipeline; the optional ``cpu`` flag selects the CPU
            pipeline (when absent or falsy, the GPU pipeline is used).
        request: The incoming Gradio request. When one is provided, its
            ``Origin`` header must be in ``_ALLOWED_ORIGINS``.

    Returns:
        The pipeline result serialised with ``json.dumps``, or ``"{}"`` when
        the request's origin is missing or not allowed.

    Raises:
        json.JSONDecodeError: if ``data_string`` is not valid JSON.
        KeyError: if a key required by the pipeline call is missing.
    """
    if request:
        # .get(): a request without an Origin header is rejected rather than
        # crashing with KeyError.
        origin = request.headers.get("origin")
        if origin not in _ALLOWED_ORIGINS:
            return "{}"
    data = json.loads(data_string)

    # Prevent batch suggestion warning in log.
    classifier_cpu.call_count = 0
    classifier_gpu.call_count = 0

    # if 'task' in data and data['task'] == 'few_shot_classification':
    #     return few_shot_classification(data)
    start_time = time.time()
    if data.get('cpu', False):
        result = zero_shot_classification_cpu(data)
    else:
        result = zero_shot_classification_gpu(data)
    print(f"Classification @ [{datetime.now().strftime('%Y-%m-%d %H:%M:%S')}] took {time.time() - start_time}.")
    return json.dumps(result)

def zero_shot_classification_cpu(data):
    """Classify ``data['sequence']`` with the CPU zero-shot pipeline.

    ``candidate_labels``, ``hypothesis_template`` and ``multi_label`` are read
    from ``data`` and passed through unchanged; the pipeline's output is
    returned as-is.
    """
    sequence = data['sequence']
    options = {
        'candidate_labels': data['candidate_labels'],
        'hypothesis_template': data['hypothesis_template'],
        'multi_label': data['multi_label'],
    }
    return classifier_cpu(sequence, **options)

@spaces.GPU(duration=3)
def zero_shot_classification_gpu(data):
    """Classify ``data['sequence']`` with the GPU zero-shot pipeline.

    Decorated with ``spaces.GPU(duration=3)``. ``candidate_labels``,
    ``hypothesis_template`` and ``multi_label`` come straight from ``data``;
    the pipeline's output is returned unchanged.
    """
    text = data['sequence']
    labels = data['candidate_labels']
    template = data['hypothesis_template']
    multi = data['multi_label']
    return classifier_gpu(
        text,
        candidate_labels=labels,
        hypothesis_template=template,
        multi_label=multi,
    )

def create_sequences(data):
    """Build one "premise\\nhypothesis" string per candidate label.

    Each entry is ``data['sequence']``, a newline, then
    ``data['hypothesis_template']`` formatted with the label, in the order of
    ``data['candidate_labels']``. Only referenced from commented-out code in
    the visible file.
    """
    sequences = []
    for label in data['candidate_labels']:
        prefix = data['sequence'] + '\n'
        sequences.append(prefix + data['hypothesis_template'].format(label))
    return sequences

# def few_shot_classification(data):
#     sequences = create_sequences(data)
#     print(sequences)
#     # results = onnx_few_shot_model(sequences)
#     probs = onnx_few_shot_model.predict_proba(sequences)
#     scores = [true[0] for true in probs]

#     composite = list(zip(scores, data['candidate_labels']))
#     composite = sorted(composite, key=lambda x: x[0], reverse=True)

#     labels, scores = zip(*composite)

#     response_dict = {'scores': scores, 'labels': labels}
#     print(response_dict)
#     response_string = json.dumps(response_dict)
#     return response_string
# Text-in/text-out UI: the input textbox carries the JSON request understood by
# `classify`, and the output textbox shows the JSON string it returns.
gradio_interface = gradio.Interface(
    fn=classify,
    inputs=gradio.Textbox(label="JSON Input"),
    outputs=gradio.Textbox(),
)

# Expose the Gradio UI under /gradio on the FastAPI app. Starlette's
# `app.mount` expects an ASGI application, which a gradio.Interface is not;
# gradio.mount_gradio_app wraps the interface in one and returns `app`.
# NOTE(review): `app` (and its CORS middleware) is only served if an ASGI
# server runs it (see the commented-out uvicorn entry point below);
# `launch()` starts Gradio's own server instead — confirm which one the
# deployment actually uses.
app = gradio.mount_gradio_app(app, gradio_interface, path="/gradio")

gradio_interface.launch()

# if __name__ == "__main__":
#     import uvicorn
#     uvicorn.run(app, host="0.0.0.0", port=8000)