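"""LLM-backed text classifier.

Uses a chat-completions client (OpenAI-style API) to suggest categories for a
corpus and to classify each text in parallel via a thread pool.
"""
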
import json
import random
from concurrent.futures import ThreadPoolExecutor, as_completed
from typing import Any, Dict, List, Optional

from base import BaseClassifier
from prompts import CATEGORY_SUGGESTION_PROMPT, TEXT_CLASSIFICATION_PROMPT


class LLMClassifier(BaseClassifier):
    """Classifier using a Large Language Model for more accurate but slower classification"""

    def __init__(self, client: Any, model: str = "gpt-3.5-turbo"):
        super().__init__()
        self.client = client
        self.model = model

    def classify(
        self, texts: List[str], categories: Optional[List[str]] = None
    ) -> List[Dict[str, Any]]:
        """Classify texts using an LLM with parallel processing"""
        if not categories:
            # First, use LLM to generate appropriate categories
            categories = self._suggest_categories(texts)

        # Process texts in parallel
        with ThreadPoolExecutor(max_workers=10) as executor:
            # Submit all tasks with their original indices
            future_to_index = {
                executor.submit(self._classify_text, text, categories): idx
                for idx, text in enumerate(texts)
            }

            # Initialize results list with None values
            results = [None] * len(texts)

            # Collect results as they complete
            for future in as_completed(future_to_index):
                original_idx = future_to_index[future]
                try:
                    result = future.result()
                    results[original_idx] = result
                except Exception as e:
                    print(f"Error processing text: {e}")
                    results[original_idx] = {
                        "category": categories[0],
                        "confidence": 50,
                        "explanation": f"Error during classification: {e}",
                    }

        return results

    def _suggest_categories(self, texts: List[str], sample_size: int = 20) -> List[str]:
        """Use LLM to suggest appropriate categories for the dataset"""
        # Take a sample of texts to avoid token limitations
        if len(texts) > sample_size:
            sample_texts = random.sample(texts, sample_size)
        else:
            sample_texts = texts

        prompt = CATEGORY_SUGGESTION_PROMPT.format("\n---\n".join(sample_texts))

        try:
            response = self.client.chat.completions.create(
                model=self.model,
                messages=[{"role": "user", "content": prompt}],
                temperature=0.2,
                max_tokens=100,
            )

            # Parse response to get categories
            categories_text = response.choices[0].message.content.strip()
            categories = [cat.strip() for cat in categories_text.split(",")]

            return categories
        except Exception as e:
            # Fallback to default categories on error
            print(f"Error suggesting categories: {str(e)}")
            return self._generate_default_categories(texts)

    def _classify_text(self, text: str, categories: List[str]) -> Dict[str, Any]:
        """Use LLM to classify a single text"""
        prompt = TEXT_CLASSIFICATION_PROMPT.format(
            categories=", ".join(categories), text=text
        )

        try:
            response = self.client.chat.completions.create(
                model=self.model,
                messages=[{"role": "user", "content": prompt}],
                temperature=0,
                max_tokens=200,
            )

            # Parse JSON response
            response_text = response.choices[0].message.content.strip()

            result = json.loads(response_text)
            # Ensure all required fields are present
            if not all(k in result for k in ["category", "confidence", "explanation"]):
                raise ValueError("Missing required fields in LLM response")

            # Validate category is in the list
            if result["category"] not in categories:
                # Fall back to the first category if the LLM returned an invalid one
                result["category"] = categories[0]

            # Validate confidence is a number between 0 and 100
            try:
                result["confidence"] = float(result["confidence"])
                if not 0 <= result["confidence"] <= 100:
                    result["confidence"] = 50
            except (TypeError, ValueError):
                result["confidence"] = 50

            return result
        except json.JSONDecodeError:
            # Fall back to simple parsing if JSON fails
            category = categories[0]  # Default
            for cat in categories:
                if cat.lower() in response_text.lower():
                    category = cat
                    break

            return {
                "category": category,
                "confidence": 50,
                "explanation": f"Classification based on language model analysis. (Note: Structured response parsing failed)",
            }
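

# Usage sketch (illustrative, not part of the classifier itself): assumes the
# openai v1 client library and a working `base.BaseClassifier`, with
# OPENAI_API_KEY set in the environment. The sample texts below are
# hypothetical placeholders.
if __name__ == "__main__":
    from openai import OpenAI

    classifier = LLMClassifier(OpenAI(), model="gpt-3.5-turbo")
    demo_texts = [
        "The battery drains within two hours of moderate use.",
        "Quarterly revenue beat analyst expectations this year.",
    ]
    # No categories supplied, so the classifier asks the LLM to suggest them
    for row in classifier.classify(demo_texts):
        print(row["category"], row["confidence"], row["explanation"])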