File size: 874 Bytes
cc3f1e1 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 |
import os
from typing import Any
from openai import OpenAI
from rag_demo.rag.base.query import Query
from rag_demo.rag.base.template_factory import RAGStep
from transformers import AutoTokenizer, AutoModelForSequenceClassification
from loguru import logger
import torch
# Hugging Face Hub checkpoint used for query classification.
# Model trained on English greetings only.
model_name = "AdrienB134/greetings-classifier"

# Tokenizer and model are loaded once, when this module is imported, and are
# reused by every call to QueryClassifier.generate. from_pretrained may need
# network access to fetch the checkpoint if it is not already cached.
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSequenceClassification.from_pretrained(
    model_name
)
class QueryClassifier(RAGStep):
    """RAG step that labels a query using the module-level greetings classifier.

    Inference uses the module-level ``tokenizer`` and ``model`` loaded from
    ``model_name``; the label string is looked up in ``model.config.id2label``.
    """

    def generate(self, query: Query) -> Any:
        """Classify ``query.content`` and return the predicted label.

        Args:
            query: The query whose ``content`` text is classified.

        Returns:
            ``"Sources_needed"`` when mock mode is enabled (``self._mock``,
            presumably configured by ``RAGStep``); otherwise the label string
            that ``model.config.id2label`` maps the highest-scoring class to.
        """
        if self._mock:
            # Mock mode skips model inference and returns a fixed label.
            return "Sources_needed"

        with torch.no_grad():
            # truncation=True clips over-long queries to the tokenizer's
            # model_max_length; without it, inputs longer than the model's
            # position-embedding table make the forward pass fail.
            inputs = tokenizer(
                query.content, return_tensors="pt", truncation=True
            )
            logits = model(**inputs).logits
            # Single-item batch: pick the best class along the label axis.
            predicted_id = logits.argmax(dim=-1).item()
        return model.config.id2label[predicted_id]
|