Spaces:
Sleeping
Sleeping
File size: 5,800 Bytes
0f1938f 36183d4 156898c e5c1bae 535a3a5 e5c1bae 156898c 0f1938f 720c911 0f1938f 535a3a5 0f1938f 535a3a5 0f1938f 156898c 36183d4 535a3a5 36183d4 0f1938f e5c1bae 535a3a5 e5c1bae 0f1938f 535a3a5 36183d4 0f1938f 36183d4 0f1938f 535a3a5 0f1938f 36183d4 535a3a5 36183d4 535a3a5 36183d4 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 |
import numpy as np
import pandas as pd
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.cluster import KMeans
from sklearn.metrics.pairwise import cosine_similarity
import random
import json
import asyncio
from typing import List, Dict, Any, Optional, Union, Tuple
import sys
import os
from litellm import OpenAI
# Add the project root to the Python path
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from prompts import CATEGORY_SUGGESTION_PROMPT, TEXT_CLASSIFICATION_PROMPT, ADDITIONAL_CATEGORY_PROMPT
from .base import BaseClassifier
class LLMClassifier(BaseClassifier):
    """Classifier using a Large Language Model for more accurate but slower classification.

    The underlying client is synchronous, so every API call is pushed onto the
    default thread-pool executor to keep the event loop responsive.
    """

    def __init__(self, client: OpenAI, model: str = "gpt-3.5-turbo") -> None:
        super().__init__()
        self.client: OpenAI = client  # OpenAI-compatible client used for all completions
        self.model: str = model       # model name sent with every request

    async def _suggest_categories_async(self, texts: List[str], sample_size: int = 20) -> List[str]:
        """Async version of category suggestion.

        Sends a sample of *texts* to the LLM and parses the comma-separated
        category list from its reply. Falls back to default categories on any
        error rather than raising.
        """
        # Take a sample of texts to avoid token limitations
        if len(texts) > sample_size:
            sample_texts: List[str] = random.sample(texts, sample_size)
        else:
            sample_texts = texts

        prompt: str = CATEGORY_SUGGESTION_PROMPT.format("\n---\n".join(sample_texts))
        try:
            # Use the synchronous client method but run it in a thread pool.
            # get_running_loop() replaces the deprecated get_event_loop()
            # (deprecated inside coroutines since Python 3.10).
            loop: asyncio.AbstractEventLoop = asyncio.get_running_loop()
            response: Any = await loop.run_in_executor(
                None,
                lambda: self.client.chat.completions.create(
                    model=self.model,
                    messages=[{"role": "user", "content": prompt}],
                    temperature=0.2,
                    max_tokens=100,
                ),
            )
            # Parse response to get categories (expected as a comma-separated list)
            categories_text: str = response.choices[0].message.content.strip()
            categories: List[str] = [cat.strip() for cat in categories_text.split(",")]
            return categories
        except Exception as e:
            # Fallback to default categories on error
            print(f"Error suggesting categories: {str(e)}")
            return self._generate_default_categories(texts)

    def _generate_default_categories(self, texts: List[str]) -> List[str]:
        """Generate default categories if LLM suggestion fails."""
        return ["Positive", "Negative", "Neutral", "Mixed", "Other"]

    async def _classify_text_async(self, text: str, categories: List[str]) -> Dict[str, Any]:
        """Async version of text classification.

        Returns a dict with keys "category", "confidence" (float, 0-100) and
        "explanation". Never raises: every failure path degrades to a
        best-effort result with confidence 50.
        """
        prompt: str = TEXT_CLASSIFICATION_PROMPT.format(
            categories=", ".join(categories),
            text=text,
        )
        # Bound up-front so the JSONDecodeError handler below can always read it,
        # regardless of which statement inside `try` raised.
        response_text: str = ""
        try:
            # Use the synchronous client method but run it in a thread pool
            loop: asyncio.AbstractEventLoop = asyncio.get_running_loop()
            response: Any = await loop.run_in_executor(
                None,
                lambda: self.client.chat.completions.create(
                    model=self.model,
                    messages=[{"role": "user", "content": prompt}],
                    temperature=0,
                    max_tokens=200,
                ),
            )
            # Parse JSON response
            response_text = response.choices[0].message.content.strip()
            result: Dict[str, Any] = json.loads(response_text)

            # Ensure all required fields are present
            if not all(k in result for k in ("category", "confidence", "explanation")):
                raise ValueError("Missing required fields in LLM response")

            # Validate category is in the list
            if result["category"] not in categories:
                result["category"] = categories[0]  # Default to first category if invalid

            # Validate confidence is a number between 0 and 100
            try:
                result["confidence"] = float(result["confidence"])
                if not 0 <= result["confidence"] <= 100:
                    result["confidence"] = 50
            except (TypeError, ValueError):
                # Was a bare `except:`; narrowed to the conversion errors
                # float() can actually raise so real bugs aren't swallowed.
                result["confidence"] = 50
            return result
        except json.JSONDecodeError:
            # Fall back to simple substring matching if JSON parsing fails
            category: str = categories[0]  # Default
            for cat in categories:
                if cat.lower() in response_text.lower():
                    category = cat
                    break
            return {
                "category": category,
                "confidence": 50,
                "explanation": "Classification based on language model analysis. (Note: Structured response parsing failed)",
            }
        except Exception as e:
            # API errors, missing fields, etc. — degrade gracefully.
            return {
                "category": categories[0],
                "confidence": 50,
                "explanation": f"Error during classification: {str(e)}",
            }

    async def classify_async(
        self, texts: List[str], categories: Optional[List[str]] = None
    ) -> List[Dict[str, Any]]:
        """Async method to classify texts.

        When *categories* is None or empty, asks the LLM to suggest them first.
        All texts are classified concurrently.
        """
        if not categories:
            categories = await self._suggest_categories_async(texts)
        # One coroutine per text (annotation fixed: these are coroutines,
        # not asyncio.Task objects — gather schedules them itself).
        coros: List[Any] = [self._classify_text_async(text, categories) for text in texts]
        # Gather all results, preserving input order
        results: List[Dict[str, Any]] = await asyncio.gather(*coros)
        return results

    def classify(
        self, texts: List[str], categories: Optional[List[str]] = None
    ) -> List[Dict[str, Any]]:
        """Synchronous wrapper for backwards compatibility."""
        return asyncio.run(self.classify_async(texts, categories))
|