Lord-Raven committed on
Commit
3b3af39
·
1 Parent(s): 9afecae

More cleanup.

Browse files
Files changed (1) hide show
  1. app.py +5 -12
app.py CHANGED
@@ -9,7 +9,7 @@ from transformers import pipeline
9
  from fastapi import FastAPI
10
  from fastapi.middleware.cors import CORSMiddleware
11
 
12
- # CORS Config
13
  app = FastAPI()
14
 
15
  app.add_middleware(
@@ -24,11 +24,6 @@ print(f"Is CUDA available: {torch.cuda.is_available()}")
24
  print(f"CUDA device: {torch.cuda.get_device_name(torch.cuda.current_device())}")
25
 
26
  # "xenova/mobilebert-uncased-mnli" "typeform/mobilebert-uncased-mnli" Fast but small--same as bundled in Statosphere
27
- # "xenova/deberta-v3-base-tasksource-nli" Not impressed
28
- # "Xenova/bart-large-mnli" A bit slow
29
- # "Xenova/distilbert-base-uncased-mnli" "typeform/distilbert-base-uncased-mnli" Bad answers
30
- # "Xenova/deBERTa-v3-base-mnli" "MoritzLaurer/DeBERTa-v3-base-mnli" Still a bit slow and not great answers
31
- # "xenova/nli-deberta-v3-small" "cross-encoder/nli-deberta-v3-small" Was using this for a good while and it was...okay
32
 
33
  model_name = "MoritzLaurer/deberta-v3-base-zeroshot-v2.0"
34
  tokenizer_name = "MoritzLaurer/deberta-v3-base-zeroshot-v2.0"
@@ -42,13 +37,10 @@ def classify(data_string, request: gradio.Request):
42
  return "{}"
43
  data = json.loads(data_string)
44
 
45
- # Prevent batch suggestion warning in log.
46
  classifier_cpu.call_count = 0
47
  classifier_gpu.call_count = 0
48
 
49
- # if 'task' in data and data['task'] == 'few_shot_classification':
50
- # return few_shot_classification(data)
51
- # else:
52
  start_time = time.time()
53
  result = {}
54
  try:
@@ -75,10 +67,11 @@ def create_sequences(data):
75
  gradio_interface = gradio.Interface(
76
  fn = classify,
77
  inputs = gradio.Textbox(label="JSON Input"),
78
- outputs = gradio.Textbox()
 
 
79
  )
80
 
81
  app.mount("/gradio", gradio_interface)
82
 
83
- # app = gradio.mount_gradio_app(app, gradio_interface, path="/gradio")
84
  gradio_interface.launch()
 
9
  from fastapi import FastAPI
10
  from fastapi.middleware.cors import CORSMiddleware
11
 
12
+ # CORS Config - This isn't actually working; instead, I am taking a gross approach to origin whitelisting within the service.
13
  app = FastAPI()
14
 
15
  app.add_middleware(
 
24
  print(f"CUDA device: {torch.cuda.get_device_name(torch.cuda.current_device())}")
25
 
26
  # "xenova/mobilebert-uncased-mnli" "typeform/mobilebert-uncased-mnli" Fast but small--same as bundled in Statosphere
 
 
 
 
 
27
 
28
  model_name = "MoritzLaurer/deberta-v3-base-zeroshot-v2.0"
29
  tokenizer_name = "MoritzLaurer/deberta-v3-base-zeroshot-v2.0"
 
37
  return "{}"
38
  data = json.loads(data_string)
39
 
40
+ # Try to prevent batch suggestion warning in log.
41
  classifier_cpu.call_count = 0
42
  classifier_gpu.call_count = 0
43
 
 
 
 
44
  start_time = time.time()
45
  result = {}
46
  try:
 
67
  gradio_interface = gradio.Interface(
68
  fn = classify,
69
  inputs = gradio.Textbox(label="JSON Input"),
70
+ outputs = gradio.Textbox(),
71
+ title = "Statosphere Backend",
72
+ description = "This Space is a classification service for a set of chub.ai stages and not really intended for use through this UI."
73
  )
74
 
75
  app.mount("/gradio", gradio_interface)
76
 
 
77
  gradio_interface.launch()