Lord-Raven committed on
Commit
f63295c
·
1 Parent(s): 2b2a5e4

Messing with FastAPI.

Browse files
Files changed (1) hide show
  1. app.py +16 -6
app.py CHANGED
@@ -4,6 +4,7 @@ import gradio
4
  import json
5
  import onnxruntime
6
  import time
 
7
  from transformers import pipeline
8
  from fastapi import FastAPI
9
  from fastapi.middleware.cors import CORSMiddleware
@@ -32,7 +33,8 @@ print(f"CUDA device: {torch.cuda.get_device_name(torch.cuda.current_device())}")
32
  model_name = "MoritzLaurer/deberta-v3-base-zeroshot-v2.0"
33
  tokenizer_name = "MoritzLaurer/deberta-v3-base-zeroshot-v2.0"
34
 
35
- classifier = pipeline(task="zero-shot-classification", model=model_name, tokenizer=tokenizer_name, device="cuda:0")
 
36
  # classifier = pipeline(task="zero-shot-classification", model=model_name, tokenizer=tokenizer_name)
37
 
38
  def classify(data_string, request: gradio.Request):
@@ -42,19 +44,27 @@ def classify(data_string, request: gradio.Request):
42
  data = json.loads(data_string)
43
 
44
  # Prevent batch suggestion warning in log.
45
- classifier.call_count = 0
 
46
 
47
  # if 'task' in data and data['task'] == 'few_shot_classification':
48
  # return few_shot_classification(data)
49
  # else:
50
  start_time = time.time()
51
- result = zero_shot_classification(data)
52
- print(f"classification took {time.time() - start_time}.")
 
 
 
 
53
  return json.dumps(result)
54
 
 
 
 
55
  @spaces.GPU(duration=3)
56
- def zero_shot_classification(data):
57
- return classifier(data['sequence'], candidate_labels=data['candidate_labels'], hypothesis_template=data['hypothesis_template'], multi_label=data['multi_label'])
58
 
59
  def create_sequences(data):
60
  return [data['sequence'] + '\n' + data['hypothesis_template'].format(label) for label in data['candidate_labels']]
 
4
  import json
5
  import onnxruntime
6
  import time
7
+ from datetime import datetime
8
  from transformers import pipeline
9
  from fastapi import FastAPI
10
  from fastapi.middleware.cors import CORSMiddleware
 
33
  model_name = "MoritzLaurer/deberta-v3-base-zeroshot-v2.0"
34
  tokenizer_name = "MoritzLaurer/deberta-v3-base-zeroshot-v2.0"
35
 
36
+ classifier_cpu = pipeline(task="zero-shot-classification", model=model_name, tokenizer=tokenizer_name)
37
+ classifier_gpu = pipeline(task="zero-shot-classification", model=model_name, tokenizer=tokenizer_name, device="cuda:0")
38
  # classifier = pipeline(task="zero-shot-classification", model=model_name, tokenizer=tokenizer_name)
39
 
40
  def classify(data_string, request: gradio.Request):
 
44
  data = json.loads(data_string)
45
 
46
  # Prevent batch suggestion warning in log.
47
+ classifier_cpu.call_count = 0
48
+ classifier_gpu.call_count = 0
49
 
50
  # if 'task' in data and data['task'] == 'few_shot_classification':
51
  # return few_shot_classification(data)
52
  # else:
53
  start_time = time.time()
54
+ result = {}
55
+ if (data['cpu']):
56
+ result = zero_shot_classification_cpu(data)
57
+ else:
58
+ result = zero_shot_classification_gpu(data)
59
+ print(f"Classification @ [{datetime.now().strftime('%Y-%m-%d %H:%M:%S')}] took {time.time() - start_time}.")
60
  return json.dumps(result)
61
 
62
+ def zero_shot_classification_cpu(data):
63
+ return classifier_cpu(data['sequence'], candidate_labels=data['candidate_labels'], hypothesis_template=data['hypothesis_template'], multi_label=data['multi_label'])
64
+
65
  @spaces.GPU(duration=3)
66
+ def zero_shot_classification_gpu(data):
67
+ return classifier_gpu(data['sequence'], candidate_labels=data['candidate_labels'], hypothesis_template=data['hypothesis_template'], multi_label=data['multi_label'])
68
 
69
  def create_sequences(data):
70
  return [data['sequence'] + '\n' + data['hypothesis_template'].format(label) for label in data['candidate_labels']]