import os
from typing import Any

from openai import OpenAI
from rag_demo.rag.base.query import Query
from rag_demo.rag.base.template_factory import RAGStep
from transformers import AutoTokenizer, AutoModelForSequenceClassification
from loguru import logger
import torch

# NOTE: this checkpoint was trained on English greetings only, so predictions
# for other languages or non-greeting text may be unreliable.
model_name = "AdrienB134/greetings-classifier"

# Loaded once at import time and shared by every QueryClassifier instance.
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSequenceClassification.from_pretrained(model_name)


class QueryClassifier(RAGStep):
    """RAG step that labels a query using the greetings-classifier model."""

    def generate(self, query: Query) -> Any:
        """Classify ``query.content`` and return the predicted label.

        Args:
            query: The query whose ``content`` text is classified.

        Returns:
            The entry of ``model.config.id2label`` for the highest-scoring
            class. In mock mode (``self._mock`` truthy) the model is not
            called and ``"Sources_needed"`` is returned instead — presumably
            one of the classifier's labels; verify against the model config.
        """
        if self._mock:
            return "Sources_needed"

        # truncation=True clips the input to the tokenizer's maximum length;
        # without it an over-long query can exceed the model's position
        # embeddings and make the forward pass raise.
        inputs = tokenizer(query.content, return_tensors="pt", truncation=True)
        # Inference only: no autograd graph is needed for the forward pass.
        with torch.no_grad():
            logits = model(**inputs).logits
        # Batch of one: argmax over the class dimension gives a 1-element tensor.
        predicted_id = logits.argmax(dim=-1).item()
        return model.config.id2label[predicted_id]