File size: 5,800 Bytes
0f1938f
 
 
 
 
 
 
36183d4
156898c
e5c1bae
 
535a3a5
e5c1bae
 
 
156898c
0f1938f
720c911
0f1938f
 
 
 
 
535a3a5
0f1938f
535a3a5
 
0f1938f
156898c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
36183d4
 
535a3a5
36183d4
 
0f1938f
 
 
e5c1bae
535a3a5
 
e5c1bae
 
 
 
 
 
 
0f1938f
 
 
535a3a5
 
36183d4
0f1938f
 
 
 
 
 
36183d4
0f1938f
 
 
 
 
 
 
 
 
 
 
 
535a3a5
0f1938f
 
 
 
 
 
 
 
 
 
36183d4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
535a3a5
36183d4
 
535a3a5
36183d4
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
import numpy as np
import pandas as pd
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.cluster import KMeans
from sklearn.metrics.pairwise import cosine_similarity
import random
import json
import asyncio
from typing import List, Dict, Any, Optional, Union, Tuple
import sys
import os
from litellm import OpenAI

# Add the project root to the Python path
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from prompts import CATEGORY_SUGGESTION_PROMPT, TEXT_CLASSIFICATION_PROMPT, ADDITIONAL_CATEGORY_PROMPT

from .base import BaseClassifier


class LLMClassifier(BaseClassifier):
    """Classifier using a Large Language Model for more accurate but slower classification."""

    def __init__(self, client: OpenAI, model: str = "gpt-3.5-turbo") -> None:
        """Store the LLM client and model name used for every completion request.

        Args:
            client: OpenAI-compatible client exposing ``chat.completions.create``.
            model: Model identifier sent with each request.
        """
        super().__init__()
        self.client: OpenAI = client
        self.model: str = model

    async def _create_completion(self, prompt: str, *, temperature: float, max_tokens: int) -> Any:
        """Run the synchronous chat-completion call in the default thread pool.

        Shared by category suggestion and per-text classification so the
        executor plumbing lives in one place.
        """
        # get_event_loop() is deprecated inside coroutines; use the running loop.
        loop: asyncio.AbstractEventLoop = asyncio.get_running_loop()
        return await loop.run_in_executor(
            None,
            lambda: self.client.chat.completions.create(
                model=self.model,
                messages=[{"role": "user", "content": prompt}],
                temperature=temperature,
                max_tokens=max_tokens,
            ),
        )

    async def _suggest_categories_async(self, texts: List[str], sample_size: int = 20) -> List[str]:
        """Ask the LLM to propose category labels from a sample of *texts*.

        Args:
            texts: Corpus to derive categories from.
            sample_size: Maximum number of texts included in the prompt,
                to avoid token limitations.

        Returns:
            Suggested category names; defaults from
            ``_generate_default_categories`` on any API or parsing error.
        """
        # Sample only when the corpus exceeds the budget; otherwise use all texts.
        sample_texts: List[str] = (
            random.sample(texts, sample_size) if len(texts) > sample_size else texts
        )

        prompt: str = CATEGORY_SUGGESTION_PROMPT.format("\n---\n".join(sample_texts))

        try:
            response: Any = await self._create_completion(prompt, temperature=0.2, max_tokens=100)
            # The prompt asks for a comma-separated list; drop empty fragments
            # produced by stray/trailing commas.
            categories_text: str = response.choices[0].message.content.strip()
            return [cat.strip() for cat in categories_text.split(",") if cat.strip()]
        except Exception as e:
            # Best-effort: category suggestion must never break the pipeline.
            print(f"Error suggesting categories: {str(e)}")
            return self._generate_default_categories(texts)

    def _generate_default_categories(self, texts: List[str]) -> List[str]:
        """Return the fallback sentiment-style labels used when LLM suggestion fails."""
        return ["Positive", "Negative", "Neutral", "Mixed", "Other"]

    async def _classify_text_async(self, text: str, categories: List[str]) -> Dict[str, Any]:
        """Classify a single *text* into one of *categories* via the LLM.

        Returns:
            Dict with keys ``category`` (always one of *categories*),
            ``confidence`` (float in [0, 100]) and ``explanation``.
            Never raises: every failure mode degrades to a usable result.
        """
        prompt: str = TEXT_CLASSIFICATION_PROMPT.format(
            categories=", ".join(categories),
            text=text,
        )

        try:
            response: Any = await self._create_completion(prompt, temperature=0, max_tokens=200)
        except Exception as e:
            # API failure: report the error but keep the pipeline running.
            return {
                "category": categories[0],
                "confidence": 50,
                "explanation": f"Error during classification: {str(e)}",
            }

        response_text: str = response.choices[0].message.content.strip()

        try:
            result: Dict[str, Any] = json.loads(response_text)
        except json.JSONDecodeError:
            # Model didn't emit JSON: fall back to substring matching against
            # the raw response, defaulting to the first category.
            category: str = categories[0]
            for cat in categories:
                if cat.lower() in response_text.lower():
                    category = cat
                    break
            return {
                "category": category,
                "confidence": 50,
                "explanation": "Classification based on language model analysis. (Note: Structured response parsing failed)",
            }

        try:
            # Ensure all required fields are present.
            if not all(k in result for k in ("category", "confidence", "explanation")):
                raise ValueError("Missing required fields in LLM response")

            # Clamp the category to the allowed set.
            if result["category"] not in categories:
                result["category"] = categories[0]  # default to first category if invalid

            # Coerce confidence to a float in [0, 100]; 50 on any bad value.
            try:
                result["confidence"] = float(result["confidence"])
            except (TypeError, ValueError):
                result["confidence"] = 50
            else:
                if not 0 <= result["confidence"] <= 100:
                    result["confidence"] = 50

            return result
        except Exception as e:
            # Covers non-dict JSON (e.g. a bare number) and missing fields.
            return {
                "category": categories[0],
                "confidence": 50,
                "explanation": f"Error during classification: {str(e)}",
            }

    async def classify_async(
        self, texts: List[str], categories: Optional[List[str]] = None
    ) -> List[Dict[str, Any]]:
        """Classify *texts* concurrently.

        Args:
            texts: Texts to classify.
            categories: Allowed labels; suggested from the corpus when falsy.

        Returns:
            One result dict per input text, in input order.
        """
        if not categories:
            categories = await self._suggest_categories_async(texts)

        # gather() accepts coroutines directly and preserves input order.
        coros = [self._classify_text_async(text, categories) for text in texts]
        results: List[Dict[str, Any]] = await asyncio.gather(*coros)
        return results

    def classify(
        self, texts: List[str], categories: Optional[List[str]] = None
    ) -> List[Dict[str, Any]]:
        """Synchronous wrapper around :meth:`classify_async` for backwards compatibility."""
        return asyncio.run(self.classify_async(texts, categories))