mjwong commited on
Commit
42d4264
·
verified ·
1 Parent(s): d5b6595

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +19 -3
app.py CHANGED
@@ -1,8 +1,14 @@
1
  import gradio as gr
2
  import torch
3
- from transformers import pipeline
4
  from typing import Dict
5
 
 
 
 
 
 
 
6
  # Available models for zero-shot classification
7
  AVAILABLE_MODELS = [
8
  "mjwong/multilingual-e5-large-instruct-xnli-anli",
@@ -10,7 +16,7 @@ AVAILABLE_MODELS = [
10
  "mjwong/multilingual-e5-large-xnli-anli",
11
  "mjwong/mcontriever-msmarco-xnli",
12
  "mjwong/mcontriever-xnli"
13
- ]
14
 
15
  def classify_text(
16
  model_name: str,
@@ -38,7 +44,17 @@ def classify_text(
38
  # Set device: 0 if GPU available, else -1 for CPU
39
  device = 0 if torch.cuda.is_available() else -1
40
 
41
- classifier = pipeline("zero-shot-classification", model=model_name, device=device)
 
 
 
 
 
 
 
 
 
 
42
  labels_list = [label.strip() for label in labels.split(",")]
43
  result = classifier(text, candidate_labels=labels_list, multi_label=multi_label)
44
  return {label: score for label, score in zip(result["labels"], result["scores"])}
 
1
  import gradio as gr
2
  import torch
3
+ from transformers import AutoTokenizer, pipeline
4
  from typing import Dict
5
 
6
+ # Custom models for zero-shot classification requiring trust_remote_code=True
7
+ CUSTOM_MODELS = [
8
+ "mjwong/gte-multilingual-base-xnli",
9
+ "mjwong/gte-multilingual-base-xnli-anli"
10
+ ]
11
+
12
  # Available models for zero-shot classification
13
  AVAILABLE_MODELS = [
14
  "mjwong/multilingual-e5-large-instruct-xnli-anli",
 
16
  "mjwong/multilingual-e5-large-xnli-anli",
17
  "mjwong/mcontriever-msmarco-xnli",
18
  "mjwong/mcontriever-xnli"
19
+ ] + CUSTOM_MODELS
20
 
21
  def classify_text(
22
  model_name: str,
 
44
  # Set device: 0 if GPU available, else -1 for CPU
45
  device = 0 if torch.cuda.is_available() else -1
46
 
47
+ if model_name in CUSTOM_MODELS:
48
+ tokenizer = AutoTokenizer.from_pretrained(model_name)
49
+ classifier = pipeline("zero-shot-classification",
50
+ model=model_name,
51
+ tokenizer=tokenizer,
52
+ trust_remote_code=True
53
+ )
54
+
55
+ else:
56
+ classifier = pipeline("zero-shot-classification", model=model_name, device=device)
57
+
58
  labels_list = [label.strip() for label in labels.split(",")]
59
  result = classifier(text, candidate_labels=labels_list, multi_label=multi_label)
60
  return {label: score for label, score in zip(result["labels"], result["scores"])}