simondh committed
Commit 1bc76b5 · 0 Parent(s)

first commit

Files changed (6)
  1. .gitignore +2 -0
  2. app.py +586 -0
  3. classifiers.py +256 -0
  4. examples/sample_reviews.csv +11 -0
  5. requirements.txt +9 -0
  6. utils.py +188 -0
.gitignore ADDED
@@ -0,0 +1,2 @@
+ .env
+ *.pyc
app.py ADDED
@@ -0,0 +1,586 @@
+ import os
+ import json
+ import time
+ import traceback
+ import logging
+
+ import gradio as gr
+ from litellm import OpenAI
+
+ # Import local modules
+ from classifiers import TFIDFClassifier, LLMClassifier
+ from utils import load_data, visualize_results, validate_results
+
+ # Configure logging
+ logging.basicConfig(
+     level=logging.INFO,
+     format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
+ )
+
+ # Initialize the API key from the environment variable
+ OPENAI_API_KEY = os.environ.get("OPENAI_API_KEY", "")
+
+ # Only initialize the client if an API key is available
+ client = None
+ if OPENAI_API_KEY:
+     try:
+         client = OpenAI(api_key=OPENAI_API_KEY)
+         logging.info("OpenAI client initialized successfully")
+     except Exception as e:
+         logging.error(f"Failed to initialize OpenAI client: {str(e)}")
+
+
+ def update_api_key(api_key):
+     """Update the OpenAI API key"""
+     global OPENAI_API_KEY, client
+
+     if not api_key:
+         return "API Key cannot be empty"
+
+     OPENAI_API_KEY = api_key
+
+     try:
+         client = OpenAI(api_key=api_key)
+         # Test the connection with a simple request
+         client.chat.completions.create(
+             model="gpt-3.5-turbo",
+             messages=[{"role": "user", "content": "test"}],
+             max_tokens=5,
+         )
+         return "API Key updated and verified successfully"
+     except Exception as e:
+         error_msg = str(e)
+         logging.error(f"API key update failed: {error_msg}")
+         return f"Failed to update API Key: {error_msg}"
+
+
+ def process_file(file, text_columns, categories, classifier_type, show_explanations):
+     """Process the uploaded file and classify text data"""
+     try:
+         # Load data from the file
+         if isinstance(file, str):
+             df = load_data(file)
+         else:
+             df = load_data(file.name)
+
+         if not text_columns:
+             return None, "Please select at least one text column"
+
+         # Check that all selected columns exist
+         missing_columns = [col for col in text_columns if col not in df.columns]
+         if missing_columns:
+             return None, (
+                 f"Columns not found in the file: {', '.join(missing_columns)}. "
+                 f"Available columns: {', '.join(df.columns)}"
+             )
+
+         # Combine text from the selected columns
+         texts = []
+         for _, row in df.iterrows():
+             combined_text = " ".join(str(row[col]) for col in text_columns)
+             texts.append(combined_text)
+
+         # Parse categories if provided
+         category_list = []
+         if categories:
+             category_list = [cat.strip() for cat in categories.split(",")]
+
+         # Select the classifier based on data size and user choice
+         num_texts = len(texts)
+
+         # If no specific model is chosen, select the most appropriate one
+         if classifier_type == "auto":
+             if num_texts <= 500:
+                 classifier_type = "gpt4"
+             elif num_texts <= 1000:
+                 classifier_type = "gpt35"
+             elif num_texts <= 5000:
+                 classifier_type = "hybrid"
+             else:
+                 classifier_type = "tfidf"
+
+         # All LLM-backed modes need an initialized client
+         if classifier_type in ("gpt35", "gpt4", "hybrid") and client is None:
+             return None, "Error: The API client is not initialized. Please configure a valid API key in the 'Setup' tab."
+
+         # Initialize the appropriate classifier
+         if classifier_type == "tfidf":
+             classifier = TFIDFClassifier()
+             results = classifier.classify(texts, category_list)
+         elif classifier_type == "gpt35":
+             classifier = LLMClassifier(client=client, model="gpt-3.5-turbo")
+             results = classifier.classify(texts, category_list)
+         elif classifier_type == "gpt4":
+             classifier = LLMClassifier(client=client, model="gpt-4")
+             results = classifier.classify(texts, category_list)
+         else:  # hybrid
+             # First pass with TF-IDF
+             tfidf_classifier = TFIDFClassifier()
+             tfidf_results = tfidf_classifier.classify(texts, category_list)
+
+             # Second pass with the LLM for low-confidence results
+             llm_classifier = LLMClassifier(client=client, model="gpt-3.5-turbo")
+             results = []
+             for text, tfidf_result in zip(texts, tfidf_results):
+                 if tfidf_result["confidence"] < 70:  # Confidence below 70%
+                     llm_result = llm_classifier.classify([text], category_list)[0]
+                     results.append(llm_result)
+                 else:
+                     results.append(tfidf_result)
+
+         # Create the results dataframe
+         result_df = df.copy()
+         result_df["Category"] = [r["category"] for r in results]
+         result_df["Confidence"] = [r["confidence"] for r in results]
+
+         if show_explanations:
+             result_df["Explanation"] = [r["explanation"] for r in results]
+
+         # Validate results using the LLM
+         validation_report = validate_results(result_df, text_columns, client)
+
+         return result_df, validation_report
+
+     except Exception as e:
+         error_traceback = traceback.format_exc()
+         return None, f"Error: {str(e)}\n{error_traceback}"
+
+
+ def export_results(df, format_type):
+     """Export results to a file and return the file path for download"""
+     if df is None:
+         return None
+
+     # Create a temporary export directory if it doesn't exist
+     temp_dir = "temp_exports"
+     os.makedirs(temp_dir, exist_ok=True)
+
+     # Generate a unique filename
+     timestamp = time.strftime("%Y%m%d-%H%M%S")
+     filename = f"classification_results_{timestamp}"
+
+     if format_type == "excel":
+         file_path = os.path.join(temp_dir, f"{filename}.xlsx")
+         df.to_excel(file_path, index=False)
+     else:
+         file_path = os.path.join(temp_dir, f"{filename}.csv")
+         df.to_csv(file_path, index=False)
+
+     return file_path
+
+
+ # Create the Gradio interface
+ with gr.Blocks(title="Text Classification System") as demo:
+     gr.Markdown("# Text Classification System")
+     gr.Markdown("Upload your data file (Excel/CSV) and classify text using AI")
+
+     with gr.Tab("Setup"):
+         api_key_input = gr.Textbox(
+             label="OpenAI API Key",
+             placeholder="Enter your API key here",
+             type="password",
+             value=OPENAI_API_KEY,
+         )
+         api_key_button = gr.Button("Update API Key")
+         api_key_message = gr.Textbox(label="Status", interactive=False)
+
+         # Display the current API status
+         api_status = "API Key is set" if OPENAI_API_KEY else "No API Key found. Please set one."
+         gr.Markdown(f"**Current API Status**: {api_status}")
+
+         api_key_button.click(update_api_key, inputs=[api_key_input], outputs=[api_key_message])
+
+     with gr.Tab("Classify Data"):
+         with gr.Column():
+             file_input = gr.File(label="Upload Excel/CSV File")
+
+             # State holding the available columns
+             available_columns = gr.State([])
+
+             # Button to load the file and suggest categories
+             load_categories_button = gr.Button("Load File")
+
+             # Display the original dataframe
+             original_df = gr.Dataframe(
+                 label="Original Data",
+                 interactive=False,
+                 visible=False,
+             )
+
+             with gr.Row():
+                 with gr.Column():
+                     suggested_categories = gr.CheckboxGroup(
+                         label="Suggested Categories",
+                         choices=[],
+                         value=[],
+                         interactive=True,
+                         visible=False,
+                     )
+
+                     new_category = gr.Textbox(
+                         label="Add New Category",
+                         placeholder="Enter a new category name",
+                         visible=False,
+                     )
+                     with gr.Row():
+                         add_category_button = gr.Button("Add Category", visible=False)
+                         suggest_category_button = gr.Button("Suggest Category", visible=False)
+
+                     # Original categories input (hidden)
+                     categories = gr.Textbox(visible=False)
+
+                 with gr.Column():
+                     text_column = gr.CheckboxGroup(
+                         label="Select Text Columns",
+                         choices=[],
+                         interactive=True,
+                         visible=False,
+                     )
+
+                     classifier_type = gr.Dropdown(
+                         choices=[
+                             ("TF-IDF (fast, <1,000 rows)", "tfidf"),
+                             ("LLM GPT-3.5 (reliable, <1,000 rows)", "gpt35"),
+                             ("LLM GPT-4 (very reliable, <500 rows)", "gpt4"),
+                             ("TF-IDF + LLM (hybrid, >1,000 rows)", "hybrid"),
+                         ],
+                         label="Classification model",
+                         value="tfidf",
+                         visible=False,
+                     )
+                     show_explanations = gr.Checkbox(label="Show Explanations", value=True, visible=False)
+
+                     process_button = gr.Button("Process and Classify", visible=False)
+
+             results_df = gr.Dataframe(interactive=True, visible=False)
+
+             # Containers for the visualization and the validation report
+             with gr.Row(visible=False) as results_row:
+                 with gr.Column():
+                     visualization = gr.Plot(label="Classification Distribution")
+                     with gr.Row():
+                         csv_download = gr.File(label="Download CSV", visible=False)
+                         excel_download = gr.File(label="Download Excel", visible=False)
+                 with gr.Column():
+                     validation_output = gr.Textbox(label="Validation Report", interactive=False)
+                     improve_button = gr.Button("Improve Classification with Report", visible=False)
+
+     # Load the file and suggest categories
+     def load_file_and_suggest_categories(file):
+         # Components reset to their hidden state, returned when loading fails
+         hidden = (
+             [],
+             gr.CheckboxGroup(choices=[]),
+             gr.CheckboxGroup(choices=[], visible=False),
+             gr.Textbox(visible=False),
+             gr.Button(visible=False),
+             gr.Button(visible=False),
+             gr.CheckboxGroup(choices=[], visible=False),
+             gr.Dropdown(visible=False),
+             gr.Checkbox(visible=False),
+             gr.Button(visible=False),
+             gr.Dataframe(visible=False),
+         )
+         if not file:
+             return hidden
+         try:
+             df = load_data(file.name)
+             columns = list(df.columns)
+
+             # Analyze columns to suggest text columns
+             suggested_text_columns = []
+             for col in columns:
+                 # Only consider string columns
+                 if df[col].dtype == "object":
+                     # Check that the column contains mostly text (not just numbers or dates)
+                     sample = df[col].head(100).dropna()
+                     if len(sample) > 0:
+                         # Values containing spaces are a reasonable proxy for free text
+                         text_ratio = sum(" " in str(val) for val in sample) / len(sample)
+                         if text_ratio > 0.3:  # More than 30% of values contain spaces
+                             suggested_text_columns.append(col)
+
+             # If no columns were suggested, fall back to all object columns
+             if not suggested_text_columns:
+                 suggested_text_columns = [col for col in columns if df[col].dtype == "object"]
+
+             # Get a sample of text for category suggestion
+             sample_texts = []
+             for col in suggested_text_columns:
+                 sample_texts.extend(df[col].head(5).tolist())
+
+             # Use the LLM to suggest categories
+             if client:
+                 prompt = f"""
+                 Based on these example texts, suggest 5 appropriate categories for classification:
+
+                 {sample_texts[:5]}
+
+                 Return your answer as a comma-separated list of category names only.
+                 """
+                 try:
+                     response = client.chat.completions.create(
+                         model="gpt-3.5-turbo",
+                         messages=[{"role": "user", "content": prompt}],
+                         temperature=0.2,
+                         max_tokens=100,
+                     )
+                     suggested_cats = [cat.strip() for cat in response.choices[0].message.content.strip().split(",")]
+                 except Exception:
+                     suggested_cats = ["Positive", "Negative", "Neutral", "Mixed", "Other"]
+             else:
+                 suggested_cats = ["Positive", "Negative", "Neutral", "Mixed", "Other"]
+
+             return (
+                 columns,
+                 gr.CheckboxGroup(choices=columns, value=suggested_text_columns),
+                 gr.CheckboxGroup(choices=suggested_cats, value=suggested_cats, visible=True),
+                 gr.Textbox(visible=True),
+                 gr.Button(visible=True),
+                 gr.Button(visible=True),
+                 gr.CheckboxGroup(choices=columns, value=suggested_text_columns, visible=True),
+                 gr.Dropdown(visible=True),
+                 gr.Checkbox(visible=True),
+                 gr.Button(visible=True),
+                 gr.Dataframe(value=df, visible=True),
+             )
+         except Exception:
+             return hidden
+
+     # Add a new category to the checkbox group
+     def add_new_category(current_categories, new_category):
+         if not new_category or new_category.strip() == "":
+             return current_categories
+         new_categories = current_categories + [new_category.strip()]
+         return gr.CheckboxGroup(choices=new_categories, value=new_categories)
+
+     # Keep the hidden categories textbox in sync with the checkbox group
+     def update_categories_textbox(selected_categories):
+         return ", ".join(selected_categories)
+
+     # Show results after processing
+     def show_results(df, validation_report):
+         if df is None:
+             return gr.Row(visible=False), gr.File(visible=False), gr.File(visible=False), gr.Dataframe(visible=False)
+
+         # Export to both formats
+         csv_path = export_results(df, "csv")
+         excel_path = export_results(df, "excel")
+
+         return (
+             gr.Row(visible=True),
+             gr.File(value=csv_path, visible=True),
+             gr.File(value=excel_path, visible=True),
+             gr.Dataframe(value=df, visible=True),
+         )
+
+     # Suggest one additional category
+     def suggest_new_category(file, current_categories, text_columns):
+         if not file or not text_columns:
+             return gr.CheckboxGroup(choices=current_categories, value=current_categories)
+
+         try:
+             df = load_data(file.name)
+
+             # Get sample texts from the selected columns
+             sample_texts = []
+             for col in text_columns:
+                 sample_texts.extend(df[col].head(5).tolist())
+
+             if client:
+                 prompt = f"""
+                 Based on these example texts and the existing categories ({', '.join(current_categories)}),
+                 suggest one additional appropriate category for classification.
+
+                 Example texts:
+                 {sample_texts[:5]}
+
+                 Return only the suggested category name, nothing else.
+                 """
+                 try:
+                     response = client.chat.completions.create(
+                         model="gpt-3.5-turbo",
+                         messages=[{"role": "user", "content": prompt}],
+                         temperature=0.2,
+                         max_tokens=50,
+                     )
+                     new_cat = response.choices[0].message.content.strip()
+                     if new_cat and new_cat not in current_categories:
+                         current_categories.append(new_cat)
+                 except Exception:
+                     pass
+
+             return gr.CheckboxGroup(choices=current_categories, value=current_categories)
+         except Exception:
+             return gr.CheckboxGroup(choices=current_categories, value=current_categories)
+
+     # Export and reveal the download button (currently not wired to the UI)
+     def handle_export(df, format_type):
+         if df is None:
+             return gr.File(visible=False)
+         file_path = export_results(df, format_type)
+         return gr.File(value=file_path, visible=True)
+
+     # Improve classification based on the validation report
+     def improve_classification(df, validation_report, text_columns, categories, classifier_type, show_explanations, file):
+         """Improve classification based on the validation report"""
+         if df is None or not validation_report:
+             return df, validation_report, gr.Button(visible=False), gr.CheckboxGroup(choices=[], value=[])
+
+         # Parse the current categories up front so error paths can reuse them
+         current_categories = [cat.strip() for cat in categories.split(",")] if categories else []
+
+         try:
+             # Extract insights from the validation report
+             if client:
+                 prompt = f"""
+                 Based on this validation report, analyze the current classification and suggest improvements:
+
+                 {validation_report}
+
+                 Return your answer in JSON format with these fields:
+                 - suggested_categories: list of improved category names (must be different from current categories: {categories})
+                 - confidence_threshold: a number between 0 and 100 for minimum confidence
+                 - focus_areas: list of specific aspects to focus on during classification
+                 - analysis: a brief analysis of what needs improvement
+                 - new_categories_needed: boolean indicating if new categories should be added
+
+                 JSON response:
+                 """
+                 try:
+                     response = client.chat.completions.create(
+                         model="gpt-4",
+                         messages=[{"role": "user", "content": prompt}],
+                         temperature=0.2,
+                         max_tokens=300,
+                     )
+                     improvements = json.loads(response.choices[0].message.content.strip())
+
+                     # If new categories are needed, suggest them based on the data
+                     if improvements.get("new_categories_needed", False):
+                         # Get sample texts for category suggestion
+                         temp_df = load_data(file) if isinstance(file, str) else load_data(file.name)
+                         sample_texts = []
+                         for col in text_columns:
+                             sample_texts.extend(temp_df[col].head(5).tolist())
+
+                         category_prompt = f"""
+                         Based on these example texts and the current categories ({', '.join(current_categories)}),
+                         suggest new categories that would improve the classification. The validation report indicates:
+                         {improvements.get('analysis', '')}
+
+                         Example texts:
+                         {sample_texts[:5]}
+
+                         Return your answer as a comma-separated list of new category names only.
+                         """
+
+                         category_response = client.chat.completions.create(
+                             model="gpt-4",
+                             messages=[{"role": "user", "content": category_prompt}],
+                             temperature=0.2,
+                             max_tokens=100,
+                         )
+
+                         new_categories = [cat.strip() for cat in category_response.choices[0].message.content.strip().split(",")]
+                         # Combine the current and new categories
+                         current_categories = current_categories + new_categories
+                         categories = ",".join(current_categories)
+
+                     # Re-process with the improved parameters
+                     improved_df, new_validation = process_file(
+                         file,
+                         text_columns,
+                         categories,
+                         classifier_type,
+                         show_explanations,
+                     )
+
+                     return improved_df, new_validation, gr.Button(visible=True), gr.CheckboxGroup(choices=current_categories, value=current_categories)
+                 except Exception as e:
+                     print(f"Error in improvement process: {str(e)}")
+                     return df, validation_report, gr.Button(visible=True), gr.CheckboxGroup(choices=current_categories, value=current_categories)
+             else:
+                 return df, validation_report, gr.Button(visible=True), gr.CheckboxGroup(choices=current_categories, value=current_categories)
+         except Exception as e:
+             print(f"Error in improvement process: {str(e)}")
+             return df, validation_report, gr.Button(visible=True), gr.CheckboxGroup(choices=current_categories, value=current_categories)
+
+     # Connect functions
+     load_categories_button.click(
+         load_file_and_suggest_categories,
+         inputs=[file_input],
+         outputs=[
+             available_columns,
+             text_column,
+             suggested_categories,
+             new_category,
+             add_category_button,
+             suggest_category_button,
+             text_column,
+             classifier_type,
+             show_explanations,
+             process_button,
+             original_df,
+         ],
+     )
+
+     add_category_button.click(
+         add_new_category,
+         inputs=[suggested_categories, new_category],
+         outputs=[suggested_categories],
+     )
+
+     suggested_categories.change(
+         update_categories_textbox,
+         inputs=[suggested_categories],
+         outputs=[categories],
+     )
+
+     suggest_category_button.click(
+         suggest_new_category,
+         inputs=[file_input, suggested_categories, text_column],
+         outputs=[suggested_categories],
+     )
+
+     process_button.click(
+         process_file,
+         inputs=[file_input, text_column, categories, classifier_type, show_explanations],
+         outputs=[results_df, validation_output],
+     ).then(
+         show_results,
+         inputs=[results_df, validation_output],
+         outputs=[results_row, csv_download, excel_download, results_df],
+     ).then(
+         visualize_results,
+         inputs=[results_df, text_column],
+         outputs=[visualization],
+     ).then(
+         lambda: gr.Button(visible=True),
+         inputs=[],
+         outputs=[improve_button],
+     )
+
+     improve_button.click(
+         improve_classification,
+         inputs=[results_df, validation_output, text_column, categories, classifier_type, show_explanations, file_input],
+         outputs=[results_df, validation_output, improve_button, suggested_categories],
+     ).then(
+         show_results,
+         inputs=[results_df, validation_output],
+         outputs=[results_row, csv_download, excel_download, results_df],
+     ).then(
+         visualize_results,
+         inputs=[results_df, text_column],
+         outputs=[visualization],
+     )
+
+
+ def create_example_data():
+     """Create example data for demonstration"""
+     from utils import create_example_file
+
+     example_path = create_example_file()
+     return f"Example file created at: {example_path}"
+
+
+ if __name__ == "__main__":
+     # Create the examples directory and sample file if they don't exist
+     if not os.path.exists("examples"):
+         create_example_data()
+
+     # Launch the Gradio app
+     demo.launch()
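
For a quick check without the UI, process_file can be called directly on the bundled example data. A minimal sketch (not part of this commit), using the "tfidf" mode so no API key is needed; with no client configured, the validation step simply returns a failure string:

from app import process_file
from utils import create_example_file

path = create_example_file()  # writes examples/sample_reviews.csv
result_df, report = process_file(
    file=path,  # process_file also accepts a plain path
    text_columns=["text"],
    categories="Positive, Negative, Neutral",
    classifier_type="tfidf",  # runs locally, no API key required
    show_explanations=True,
)
print(result_df[["text", "Category", "Confidence"]])
print(report)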
classifiers.py ADDED
@@ -0,0 +1,256 @@
+ import json
+ import random
+
+ import numpy as np
+ from sklearn.cluster import KMeans
+ from sklearn.feature_extraction.text import TfidfVectorizer
+
+
+ class BaseClassifier:
+     """Base class for text classifiers"""
+
+     def __init__(self):
+         pass
+
+     def classify(self, texts, categories=None):
+         """
+         Classify a list of texts into categories
+
+         Args:
+             texts (list): List of text strings to classify
+             categories (list, optional): List of category names. If None, categories are auto-detected
+
+         Returns:
+             list: Classification results with categories, confidence scores, and explanations
+         """
+         raise NotImplementedError("Subclasses must implement this method")
+
+     def _generate_default_categories(self, texts, num_clusters=5):
+         """
+         Generate default categories based on text clustering
+
+         Args:
+             texts (list): List of text strings
+             num_clusters (int): Number of clusters to generate
+
+         Returns:
+             list: List of category names
+         """
+         # Simple implementation - a real system would name clusters from their contents
+         default_categories = [f"Category {i + 1}" for i in range(num_clusters)]
+         return default_categories
+
+
+ class TFIDFClassifier(BaseClassifier):
+     """Classifier using TF-IDF and clustering for fast classification"""
+
+     def __init__(self):
+         super().__init__()
+         self.vectorizer = TfidfVectorizer(
+             max_features=1000,
+             stop_words="english",
+             ngram_range=(1, 2),
+         )
+         self.model = None
+         self.feature_names = None
+         self.categories = None
+         self.centroids = None
+
+     def classify(self, texts, categories=None):
+         """Classify texts using TF-IDF and clustering"""
+         # Vectorize the texts
+         X = self.vectorizer.fit_transform(texts)
+         self.feature_names = self.vectorizer.get_feature_names_out()
+
+         # Auto-detect categories if not provided
+         if not categories:
+             num_clusters = min(5, len(texts))  # Don't create more clusters than texts
+             self.categories = self._generate_default_categories(texts, num_clusters)
+         else:
+             self.categories = categories
+             num_clusters = len(categories)
+
+         # Cluster the texts
+         self.model = KMeans(n_clusters=num_clusters, random_state=42)
+         clusters = self.model.fit_predict(X)
+         self.centroids = self.model.cluster_centers_
+
+         # Calculate distances to centroids for confidence
+         distances = self._calculate_distances(X)
+
+         # Prepare results
+         results = []
+         for i, text in enumerate(texts):
+             cluster_idx = clusters[i]
+
+             # Calculate confidence (inverse of distance, normalized)
+             confidence = self._calculate_confidence(distances[i])
+
+             # Create an explanation
+             explanation = self._generate_explanation(X[i], cluster_idx)
+
+             results.append({
+                 "category": self.categories[cluster_idx],
+                 "confidence": confidence,
+                 "explanation": explanation,
+             })
+
+         return results
+
+     def _calculate_distances(self, X):
+         """Calculate the Euclidean distance from each point to each centroid"""
+         return np.sqrt(((X.toarray()[:, np.newaxis, :] - self.centroids[np.newaxis, :, :]) ** 2).sum(axis=2))
+
+     def _calculate_confidence(self, distances):
+         """Convert a row of centroid distances to a confidence score (50-100)"""
+         min_dist = np.min(distances)
+         max_dist = np.max(distances)
+
+         if max_dist == min_dist:
+             return 70  # Default mid-range confidence when all distances are equal
+
+         # Ratio of the nearest to the farthest centroid distance, in [0, 1)
+         normalized_min = min_dist / max_dist
+
+         # Invert and scale to the 50-100 range (TF-IDF is never 100% confident)
+         confidence = 100 - (normalized_min * 50)
+         return round(confidence, 1)
+
+     def _generate_explanation(self, text_vector, cluster_idx):
+         """Generate an explanation for the classification"""
+         # Get indices of the top features for this text
+         text_array = text_vector.toarray()[0]
+         top_indices = text_array.argsort()[-5:][::-1]
+
+         # Get the feature names for these indices
+         top_features = [self.feature_names[i] for i in top_indices if text_array[i] > 0]
+
+         if not top_features:
+             return "No significant features identified for this classification."
+
+         return f"Classification based on key terms: {', '.join(top_features)}"
+
+
+ class LLMClassifier(BaseClassifier):
+     """Classifier using a Large Language Model for more accurate but slower classification"""
+
+     def __init__(self, client, model="gpt-3.5-turbo"):
+         super().__init__()
+         self.client = client
+         self.model = model
+
+     def classify(self, texts, categories=None):
+         """Classify texts using an LLM"""
+         if not categories:
+             # First, use the LLM to generate appropriate categories
+             categories = self._suggest_categories(texts)
+
+         results = []
+         for text in texts:
+             # Classify each text individually
+             result = self._classify_text(text, categories)
+             results.append(result)
+
+         return results
+
+     def _suggest_categories(self, texts, sample_size=20):
+         """Use the LLM to suggest appropriate categories for the dataset"""
+         # Take a sample of texts to avoid token limitations
+         if len(texts) > sample_size:
+             sample_texts = random.sample(texts, sample_size)
+         else:
+             sample_texts = texts
+
+         prompt = """
+         I have a collection of texts that I need to classify into categories. Here are some examples:
+
+         {}
+
+         Based on these examples, suggest 2 to 5 appropriate categories for classification.
+         Return your answer as a comma-separated list of category names only.
+         """.format("\n---\n".join(sample_texts))
+
+         try:
+             response = self.client.chat.completions.create(
+                 model=self.model,
+                 messages=[{"role": "user", "content": prompt}],
+                 temperature=0.2,
+                 max_tokens=100,
+             )
+
+             # Parse the response to get categories
+             categories_text = response.choices[0].message.content.strip()
+             categories = [cat.strip() for cat in categories_text.split(",")]
+
+             return categories
+         except Exception as e:
+             # Fall back to default categories on error
+             print(f"Error suggesting categories: {str(e)}")
+             return self._generate_default_categories(texts)
+
+     def _classify_text(self, text, categories):
+         """Use the LLM to classify a single text"""
+         categories_str = ", ".join(categories)
+
+         prompt = f"""
+         Classify the following text into one of these categories: {categories_str}
+
+         Text: {text}
+
+         Return your answer in JSON format with these fields:
+         - category: the chosen category from the list
+         - confidence: a value between 0 and 100 indicating your confidence in this classification (as a percentage)
+         - explanation: a brief explanation of why this category was chosen (1-2 sentences)
+
+         JSON response:
+         """
+
+         response = self.client.chat.completions.create(
+             model=self.model,
+             messages=[{"role": "user", "content": prompt}],
+             temperature=0,
+             max_tokens=200,
+         )
+         response_text = response.choices[0].message.content.strip()
+
+         try:
+             result = json.loads(response_text)
+
+             # Ensure all required fields are present
+             if not all(k in result for k in ["category", "confidence", "explanation"]):
+                 raise ValueError("Missing required fields in LLM response")
+
+             # Validate that the category is in the list
+             if result["category"] not in categories:
+                 result["category"] = categories[0]  # Default to the first category if invalid
+
+             # Validate that the confidence is a number between 0 and 100
+             try:
+                 result["confidence"] = float(result["confidence"])
+                 if not 0 <= result["confidence"] <= 100:
+                     result["confidence"] = 50
+             except (TypeError, ValueError):
+                 result["confidence"] = 50
+
+             return result
+         except (json.JSONDecodeError, ValueError):
+             # Fall back to simple keyword matching if structured parsing fails
+             category = categories[0]  # Default
+             for cat in categories:
+                 if cat.lower() in response_text.lower():
+                     category = cat
+                     break
+
+             return {
+                 "category": category,
+                 "confidence": 50,
+                 "explanation": "Classification based on language model analysis. (Note: structured response parsing failed)",
+             }
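
The classifiers are also usable on their own. A minimal sketch of TFIDFClassifier (no API key required); note that with TF-IDF the category names are assigned to clusters positionally rather than semantically, so the labels are only meaningful when the categories roughly match the natural clusters in the data:

from classifiers import TFIDFClassifier

texts = [
    "Great product, fast shipping",
    "Awful support, very rude staff",
    "Average quality for the price",
]

clf = TFIDFClassifier()
for result in clf.classify(texts, categories=["Positive", "Negative", "Neutral"]):
    print(result["category"], result["confidence"], "-", result["explanation"])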
examples/sample_reviews.csv ADDED
@@ -0,0 +1,11 @@
+ text
+ "I absolutely love this product! It exceeded all my expectations."
+ "The service was terrible and the staff was rude."
+ "The product arrived on time but was slightly damaged."
+ "I have mixed feelings about this. Some features are great, others not so much."
+ "This is a complete waste of money. Do not buy!"
+ "The customer service team was very helpful in resolving my issue."
+ "It's okay, nothing special but gets the job done."
+ "I'm extremely disappointed with the quality of this product."
+ "This is the best purchase I've made all year!"
+ "It's reasonably priced and works as expected."
requirements.txt ADDED
@@ -0,0 +1,9 @@
+ gradio>=4.0.0
+ litellm>=1.10.0
+ pandas>=2.0.0
+ numpy>=1.24.0
+ scikit-learn>=1.2.0
+ openpyxl>=3.1.0
+ torch>=2.0.0
+ transformers>=4.30.0
+ matplotlib>=3.7.0
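
Setup note: `pip install -r requirements.txt` installs everything above; the app reads the key from the OPENAI_API_KEY environment variable (or via the Setup tab) and starts with `python app.py`. torch and transformers are listed but not yet imported by any module in this commit.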
utils.py ADDED
@@ -0,0 +1,188 @@
+ import os
+
+ import matplotlib.pyplot as plt
+ import pandas as pd
+
+
+ def load_data(file_path):
+     """
+     Load data from an Excel or CSV file
+
+     Args:
+         file_path (str): Path to the file
+
+     Returns:
+         pd.DataFrame: Loaded data
+     """
+     file_ext = os.path.splitext(file_path)[1].lower()
+
+     if file_ext in ('.xlsx', '.xls'):
+         return pd.read_excel(file_path)
+     elif file_ext == '.csv':
+         return pd.read_csv(file_path)
+     else:
+         raise ValueError(f"Unsupported file format: {file_ext}. Please upload an Excel or CSV file.")
+
+
+ def export_data(df, file_name, format_type="excel"):
+     """
+     Export a dataframe to a file
+
+     Args:
+         df (pd.DataFrame): Dataframe to export
+         file_name (str): Name of the output file
+         format_type (str): "excel" or "csv"
+
+     Returns:
+         str: Path to the exported file
+     """
+     # Create the export directory if it doesn't exist
+     export_dir = "exports"
+     os.makedirs(export_dir, exist_ok=True)
+
+     # Full path for the export file
+     export_path = os.path.join(export_dir, file_name)
+
+     # Export based on the format type
+     if format_type == "excel":
+         df.to_excel(export_path, index=False)
+     else:
+         df.to_csv(export_path, index=False)
+
+     return export_path
+
+
+ def visualize_results(df, text_column, category_column="Category"):
+     """
+     Create a visualization of the classification results
+
+     Args:
+         df (pd.DataFrame): Dataframe with classification results
+         text_column (str): Name of the column containing text data
+         category_column (str): Name of the column containing categories
+
+     Returns:
+         matplotlib.figure.Figure: Visualization figure
+     """
+     # Get the categories and their counts
+     category_counts = df[category_column].value_counts()
+
+     # Create a new figure
+     fig, ax = plt.subplots(figsize=(10, 6))
+
+     # Create the bar chart
+     bars = ax.bar(category_counts.index, category_counts.values)
+
+     # Add value labels on top of each bar
+     for bar in bars:
+         height = bar.get_height()
+         ax.text(bar.get_x() + bar.get_width() / 2., height,
+                 f'{int(height)}',
+                 ha='center', va='bottom')
+
+     # Customize the plot
+     ax.set_xlabel('Categories')
+     ax.set_ylabel('Number of Texts')
+     ax.set_title('Distribution of Classified Texts')
+
+     # Rotate x-axis labels in case they are long
+     plt.xticks(rotation=45, ha='right')
+
+     # Add a grid
+     ax.grid(True, linestyle='--', alpha=0.7)
+
+     plt.tight_layout()
+
+     return fig
+
+
+ def validate_results(df, text_columns, client):
+     """
+     Use an LLM to validate the classification results
+
+     Args:
+         df (pd.DataFrame): Dataframe with classification results
+         text_columns (list): List of column names containing text data
+         client: OpenAI-compatible client (from litellm)
+
+     Returns:
+         str: Validation report
+     """
+     try:
+         # Sample a few rows for validation
+         sample_size = min(5, len(df))
+         sample_df = df.sample(n=sample_size, random_state=42)
+
+         # Build the validation prompt
+         validation_prompts = []
+         for _, row in sample_df.iterrows():
+             # Combine text from all selected columns
+             text = " ".join(str(row[col]) for col in text_columns)
+             assigned_category = row["Category"]
+             confidence = row["Confidence"]
+
+             validation_prompts.append(
+                 f"Text: {text}\nAssigned Category: {assigned_category}\nConfidence: {confidence}\n"
+             )
+
+         prompt = """
+         As a validation expert, review the following text classifications and provide feedback.
+         For each text, assess whether the assigned category seems appropriate:
+
+         {}
+
+         Provide a brief validation report with:
+         1. Overall accuracy assessment (0-100%)
+         2. Any potential misclassifications identified
+         3. Suggestions for improvement
+
+         Keep your response under 300 words.
+         """.format("\n---\n".join(validation_prompts))
+
+         # Call the LLM API
+         response = client.chat.completions.create(
+             model="gpt-3.5-turbo",
+             messages=[{"role": "user", "content": prompt}],
+             temperature=0.3,
+             max_tokens=400,
+         )
+
+         return response.choices[0].message.content.strip()
+
+     except Exception as e:
+         return f"Validation failed: {str(e)}"
+
+
+ def create_example_file():
+     """
+     Create an example CSV file for testing
+
+     Returns:
+         str: Path to the created file
+     """
+     # Example data
+     data = {
+         "text": [
+             "I absolutely love this product! It exceeded all my expectations.",
+             "The service was terrible and the staff was rude.",
+             "The product arrived on time but was slightly damaged.",
+             "I have mixed feelings about this. Some features are great, others not so much.",
+             "This is a complete waste of money. Do not buy!",
+             "The customer service team was very helpful in resolving my issue.",
+             "It's okay, nothing special but gets the job done.",
+             "I'm extremely disappointed with the quality of this product.",
+             "This is the best purchase I've made all year!",
+             "It's reasonably priced and works as expected."
+         ]
+     }
+
+     # Create the dataframe
+     df = pd.DataFrame(data)
+
+     # Save to a CSV file
+     example_dir = "examples"
+     os.makedirs(example_dir, exist_ok=True)
+     file_path = os.path.join(example_dir, "sample_reviews.csv")
+     df.to_csv(file_path, index=False)
+
+     return file_path
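
A minimal sketch of the helpers above, with a hypothetical "Category" column filled in by hand since the plot normally runs on classifier output:

from utils import create_example_file, load_data, visualize_results

path = create_example_file()
df = load_data(path)
df["Category"] = ["Positive", "Negative"] * 5  # hypothetical labels for illustration
fig = visualize_results(df, text_column="text")
fig.savefig("distribution.png")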