Spaces:
Sleeping
Sleeping
File size: 5,218 Bytes
0f1938f |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 |
import numpy as np
import pandas as pd
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.cluster import KMeans
from sklearn.metrics.pairwise import cosine_similarity
import random
import json
from concurrent.futures import ThreadPoolExecutor, as_completed
from typing import List, Dict, Any, Optional
from prompts import CATEGORY_SUGGESTION_PROMPT, TEXT_CLASSIFICATION_PROMPT
from base import BaseClassifier
class LLMClassifier(BaseClassifier):
    """Classifier using a Large Language Model for more accurate but slower classification.

    Every text is classified with its own chat-completion request issued
    through ``client.chat.completions.create``; requests run concurrently in
    a thread pool. When no categories are supplied, the model is first asked
    to propose some from a sample of the texts.
    """

    # Upper bound on concurrent classification requests.
    _MAX_WORKERS = 10
    # Confidence reported when the model's own value is missing or unusable.
    _FALLBACK_CONFIDENCE = 50

    def __init__(self, client, model="gpt-3.5-turbo"):
        """Store the chat client and the model name used for every request.

        Args:
            client: Object exposing ``chat.completions.create(...)``.
            model: Model identifier passed through to each request.
        """
        super().__init__()
        self.client = client
        self.model = model

    def classify(
        self, texts: List[str], categories: Optional[List[str]] = None
    ) -> List[Dict[str, Any]]:
        """Classify texts using an LLM with parallel processing.

        Args:
            texts: Texts to classify.
            categories: Allowed category names. When empty or ``None`` the
                categories are obtained from ``_suggest_categories``.

        Returns:
            One dict per input text, in input order, with the keys
            ``"category"``, ``"confidence"`` and ``"explanation"``. A text
            whose classification raised gets the first category, confidence
            50 and the error message as its explanation.
        """
        # Nothing to classify: skip the category-suggestion request entirely.
        # ``len`` (rather than truthiness) keeps this safe for any sized
        # sequence a caller may pass.
        if len(texts) == 0:
            return []

        if not categories:
            # First, use the LLM to generate appropriate categories.
            categories = self._suggest_categories(texts)

        # Pre-sized so results can be stored by original position even
        # though futures complete in arbitrary order.
        results = [None] * len(texts)

        with ThreadPoolExecutor(max_workers=self._MAX_WORKERS) as executor:
            future_to_index = {
                executor.submit(self._classify_text, text, categories): idx
                for idx, text in enumerate(texts)
            }
            for future in as_completed(future_to_index):
                original_idx = future_to_index[future]
                try:
                    results[original_idx] = future.result()
                except Exception as e:
                    # Best effort: one failing text must not abort the batch.
                    print(f"Error processing text: {str(e)}")
                    results[original_idx] = {
                        "category": categories[0],
                        "confidence": self._FALLBACK_CONFIDENCE,
                        "explanation": f"Error during classification: {str(e)}",
                    }
        return results

    def _suggest_categories(self, texts: List[str], sample_size: int = 20) -> List[str]:
        """Use the LLM to suggest appropriate categories for the dataset.

        Args:
            texts: Texts the categories should describe.
            sample_size: Maximum number of texts included in the prompt.

        Returns:
            Category names parsed from the model's comma-separated reply.
            If the request fails or yields no usable name, the result of
            ``self._generate_default_categories(texts)`` is returned
            (not defined in this class; presumably inherited from
            ``BaseClassifier`` - confirm).
        """
        # Take a sample of texts to avoid token limitations.
        if len(texts) > sample_size:
            sample_texts = random.sample(texts, sample_size)
        else:
            sample_texts = texts

        prompt = CATEGORY_SUGGESTION_PROMPT.format("\n---\n".join(sample_texts))

        try:
            response = self.client.chat.completions.create(
                model=self.model,
                messages=[{"role": "user", "content": prompt}],
                temperature=0.2,
                max_tokens=100,
            )
            categories_text = response.choices[0].message.content.strip()
            # Drop empty names (blank reply, trailing or doubled commas): an
            # empty category would match every response in the substring
            # fallback of ``_classify_text``.
            categories = [
                name
                for name in (part.strip() for part in categories_text.split(","))
                if name
            ]
            if not categories:
                raise ValueError("LLM returned no usable categories")
            return categories
        except Exception as e:
            # Fall back to default categories on error.
            print(f"Error suggesting categories: {str(e)}")
            return self._generate_default_categories(texts)

    def _classify_text(self, text: str, categories: List[str]) -> Dict[str, Any]:
        """Use the LLM to classify a single text.

        Args:
            text: Text to classify.
            categories: Allowed category names; the first one is the default
                when the model's answer cannot be matched.

        Returns:
            Dict with ``"category"`` (always one of ``categories``),
            ``"confidence"`` (a float in [0, 100] from the model, or 50 when
            the model's value is unusable) and ``"explanation"``.

        Raises:
            ValueError: If the reply is valid JSON but not an object with
                all three required fields.
            Exception: Anything raised by the client request propagates.
        """
        prompt = TEXT_CLASSIFICATION_PROMPT.format(
            categories=", ".join(categories), text=text
        )
        response = self.client.chat.completions.create(
            model=self.model,
            messages=[{"role": "user", "content": prompt}],
            temperature=0,
            max_tokens=200,
        )
        response_text = response.choices[0].message.content.strip()

        try:
            # Models frequently wrap JSON in a Markdown code fence.
            result = json.loads(self._strip_code_fence(response_text))
        except json.JSONDecodeError:
            # Fall back to simple parsing: first category named in the reply.
            lowered_response = response_text.lower()
            category = next(
                (cat for cat in categories if cat.lower() in lowered_response),
                categories[0],
            )
            return {
                "category": category,
                "confidence": self._FALLBACK_CONFIDENCE,
                "explanation": "Classification based on language model analysis. (Note: Structured response parsing failed)",
            }

        # Ensure the reply is an object carrying all required fields.
        required = ("category", "confidence", "explanation")
        if not isinstance(result, dict) or not all(k in result for k in required):
            raise ValueError("Missing required fields in LLM response")

        result["category"] = self._match_category(result["category"], categories)

        # Validate confidence is a number between 0 and 100.
        try:
            confidence = float(result["confidence"])
        except (TypeError, ValueError, OverflowError):
            confidence = self._FALLBACK_CONFIDENCE
        else:
            # NaN fails this comparison too and is replaced by the fallback.
            if not 0 <= confidence <= 100:
                confidence = self._FALLBACK_CONFIDENCE
        result["confidence"] = confidence

        return result

    @staticmethod
    def _strip_code_fence(text: str) -> str:
        """Return *text* without a surrounding Markdown ``` fence, if any."""
        if len(text) < 6 or not (text.startswith("```") and text.endswith("```")):
            return text
        inner = text[3:-3]
        # Drop an optional language tag (e.g. "json") on the opening line.
        first_line, newline, rest = inner.partition("\n")
        if newline and not first_line.strip().startswith(("{", "[")):
            inner = rest
        return inner.strip()

    @staticmethod
    def _match_category(candidate: Any, categories: List[str]) -> str:
        """Map the model's category onto one of *categories*.

        An exact match wins; otherwise a match ignoring case and surrounding
        whitespace is accepted; otherwise the first category is the default.
        """
        if candidate in categories:
            return candidate
        if isinstance(candidate, str):
            wanted = candidate.strip().lower()
            for cat in categories:
                if cat.strip().lower() == wanted:
                    return cat
        return categories[0]
|