import torch
import requests

from .config import TAG_NAMES, SPACE_URL
from .globals import global_model, global_tokenizer


def predict_single(text, hf_repo, backend="local", hf_token=None):
    if backend == "local":
        return _predict_local(text, hf_repo)
    elif backend == "hf":
        return _predict_hf_api(text, hf_token)
    else:
        raise ValueError(f"Unknown backend: {backend}")


def _predict_local(text, hf_repo):
    # Lazy-load the model and tokenizer on first use to avoid slow startup.
    # Note: `global` rebinds this module's copies of the names, not the
    # attributes on the .globals module itself.
    global global_model, global_tokenizer
    if global_model is None:
        from .model import QwenClassifier
        from transformers import AutoTokenizer
        global_model = QwenClassifier.from_pretrained(hf_repo).eval()
        global_tokenizer = AutoTokenizer.from_pretrained(hf_repo)

    inputs = global_tokenizer(text, return_tensors="pt", truncation=True, max_length=512)
    with torch.no_grad():
        logits = global_model(**inputs)
    return _process_output(logits)


def _predict_hf_api(text, hf_token=None):
    headers = {"Content-Type": "application/json"}
    if hf_token:
        headers["Authorization"] = f"Bearer {hf_token}"
    try:
        response = requests.post(
            f"{SPACE_URL}/predict",
            json={"text": text},  # This matches the Pydantic model
            headers=headers,
            timeout=60,
        )
        response.raise_for_status()  # Raise HTTP errors
        return response.json()
    except requests.exceptions.RequestException as e:
        # e.response is None for connection errors and timeouts, so guard
        # before reading it (hasattr alone is not enough here).
        body = e.response.text if getattr(e, "response", None) is not None else ""
        raise ValueError(f"API Error: {e}\nResponse: {body}")


def _process_output(logits):
    # Multi-label classification: an independent sigmoid per tag, keeping
    # tags whose probability exceeds 0.5.
    probs = torch.sigmoid(logits)
    tags = [f"{tag}({prob:.2f})" for tag, prob in zip(TAG_NAMES, probs[0]) if prob > 0.5]
    return ", ".join(tags)
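
# Minimal usage sketch. Run as a module (e.g. `python -m <package>.<this_module>`)
# so the relative imports resolve; "my-org/qwen-tagger" is a hypothetical Hub
# repo id used purely for illustration.
if __name__ == "__main__":
    # Local inference: downloads the model from the Hub on the first call,
    # then reuses the cached global_model/global_tokenizer.
    print(predict_single("An example input text", hf_repo="my-org/qwen-tagger"))

    # Remote inference against the hosted Space's /predict endpoint; hf_repo
    # is unused for this backend, and hf_token is only needed if the Space
    # requires authentication.
    # print(predict_single("An example input text", hf_repo=None,
    #                      backend="hf", hf_token="hf_..."))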