poltextlab committed
Commit c9436a5 · verified · 1 Parent(s): ba9275b

create cap_media_demo

Files changed (1):
  interfaces/cap_media_demo.py  +73 -0
interfaces/cap_media_demo.py ADDED
@@ -0,0 +1,73 @@
+ import gradio as gr
+
+ import os
+ import torch
+ import numpy as np
+ import pandas as pd
+ from transformers import AutoModelForSequenceClassification
+ from transformers import AutoTokenizer
+ from huggingface_hub import HfApi
+
+ from label_dicts import CAP_MEDIA_NUM_DICT, CAP_MEDIA_LABEL_NAMES
+
+ from .utils import is_disk_full
+
+ # Read-only Hub access token, taken from the environment.
+ HF_TOKEN = os.environ["hf_read"]
+
+ languages = [
+     "Multilingual",
+ ]
+
+ domains = {
+     "media": "media"
+ }
+ def check_huggingface_path(checkpoint_path: str):
+     # Return True if the checkpoint exists and is reachable on the Hub.
+     try:
+         hf_api = HfApi(token=HF_TOKEN)
+         hf_api.model_info(checkpoint_path, token=HF_TOKEN)
+         return True
+     except Exception:
+         return False
+
+ def build_huggingface_path(language: str, domain: str):
+     # A single multilingual media model currently serves every language/domain choice.
+     return "poltextlab/xlm-roberta-large-pooled-cap-media"
+ def predict(text, model_id, tokenizer_id):
+     device = torch.device("cpu")
+     model = AutoModelForSequenceClassification.from_pretrained(
+         model_id, low_cpu_mem_usage=True, device_map="auto",
+         offload_folder="offload", token=HF_TOKEN)
+     tokenizer = AutoTokenizer.from_pretrained(tokenizer_id)
+
+     # Single-text inference, so no padding is needed.
+     inputs = tokenizer(text,
+                        max_length=256,
+                        truncation=True,
+                        padding="do_not_pad",
+                        return_tensors="pt").to(device)
+     model.eval()
+
+     with torch.no_grad():
+         logits = model(**inputs).logits
+
+     probs = torch.nn.functional.softmax(logits, dim=1).cpu().numpy().flatten()
+     # Map class indices to "[code] label" keys, sorted by descending probability.
+     output_pred = {f"[{CAP_MEDIA_NUM_DICT[i]}] {CAP_MEDIA_LABEL_NAMES[CAP_MEDIA_NUM_DICT[i]]}": probs[i] for i in np.argsort(probs)[::-1]}
+     output_info = f'<p style="text-align: center; display: block">Prediction was made using the <a href="https://huggingface.co/{model_id}">{model_id}</a> model.</p>'
+     return output_pred, output_info
+
+ def predict_cap(text, language, domain):
+     domain = domains[domain]
+     model_id = build_huggingface_path(language, domain)
+     tokenizer_id = "xlm-roberta-large"
+
+     if is_disk_full():
+         # Evict cached model weights before downloading a fresh copy.
+         os.system('rm -rf /data/models*')
+         os.system('rm -r ~/.cache/huggingface/hub')
+
+     return predict(text, model_id, tokenizer_id)
+
+ demo = gr.Interface(
+     title="CAP Media Topics Babel Demo",
+     fn=predict_cap,
+     inputs=[gr.Textbox(lines=6, label="Input"),
+             gr.Dropdown(languages, label="Language"),
+             gr.Dropdown(list(domains.keys()), label="Domain")],
+     outputs=[gr.Label(num_top_classes=5, label="Output"), gr.Markdown()])
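
The module defines `demo` but never launches it, and the relative import `from .utils import is_disk_full` suggests it is loaded as part of a larger package. A minimal sketch of how a top-level app might serve it — the `app.py` filename and the `demo.launch()` call are assumptions, not part of this commit:

    # app.py — hypothetical top-level launcher, not included in this commit
    from interfaces.cap_media_demo import demo

    if __name__ == "__main__":
        demo.launch()  # serves the Gradio interface on the default port (7860)

Note that `predict` reloads the model on every request, which keeps resident memory low between calls at the cost of per-request latency.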