"""Prediction helpers: classify text with a local checkpoint or a hosted HF Space."""
import torch
import requests

from .config import TAG_NAMES

# Local model setup (lazy-loaded on first use so importing this module stays fast)
local_model = None
local_tokenizer = None
def predict_single(text, hf_repo, backend="local", hf_token=None):
    """Classify a single text via the requested backend ("local" or "hf")."""
    if backend == "local":
        return _predict_local(text, hf_repo)
    elif backend == "hf":
        return _predict_hf_api(text, hf_token)
    else:
        raise ValueError(f"Unknown backend: {backend}")
def _predict_local(text, hf_repo):
    global local_model, local_tokenizer
    # Lazy-loading to avoid slow startup
    if local_model is None:
        from .model import QwenClassifier
        from transformers import AutoTokenizer
        local_model = QwenClassifier.from_pretrained(hf_repo).eval()
        local_tokenizer = AutoTokenizer.from_pretrained(hf_repo)
    inputs = local_tokenizer(text, return_tensors="pt", truncation=True, max_length=512)
    with torch.no_grad():
        logits = local_model(**inputs)
    return _process_output(logits)
def _predict_hf_api(text, hf_token=None):
    # Use the Space endpoint instead of the direct model API.
    # Direct Space URLs follow the https://<owner>-<space-name>.hf.space pattern.
    SPACE_URL = "https://keivanr-qwen-classifier-demo.hf.space"
    try:
        response = requests.post(
            f"{SPACE_URL}/predict",
            json={"text": text},
            headers={"Authorization": f"Bearer {hf_token}"} if hf_token else {},
            timeout=30,
        )
        response.raise_for_status()  # surface HTTP 4xx/5xx instead of parsing an error body
        return response.json()
    except requests.RequestException as e:
        raise ValueError(f"Space API Error: {e}") from e
def _process_output(logits):
    # Multi-label classification: independent sigmoid per tag, thresholded at 0.5
    probs = torch.sigmoid(logits)
    tags = [f"{tag}({prob:.2f})" for tag, prob in zip(TAG_NAMES, probs[0]) if prob > 0.5]
    return ", ".join(tags)
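
# Usage sketch (illustrative, not part of the library): the repo id below is a
# placeholder; point it at whichever checkpoint actually hosts the QwenClassifier
# weights and tokenizer. Because this module uses relative imports, run it as
# part of its package, e.g. `python -m <package>.predict`.
if __name__ == "__main__":
    sample_text = "Given an array of integers, return the indices of two numbers that sum to a target."
    # Local backend: the first call lazily loads the model and tokenizer (see above).
    print(predict_single(sample_text, "owner/qwen-classifier-checkpoint", backend="local"))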