"""Gradio front end for the text classification system.

The module builds a two-tab ``gr.Blocks`` app: a "Setup" tab for the OpenAI
API key and a "Classify Data" tab that loads a file, proposes categories,
runs ``process_file`` and shows/export the results.
"""
import os
import gradio as gr
from litellm import OpenAI
import json
from sklearn.cluster import KMeans
from sklearn.decomposition import PCA
import matplotlib.pyplot as plt
import logging
from dotenv import load_dotenv
from process import update_api_key, process_file, export_results

# Load environment variables from .env file
load_dotenv()

# Import local modules
from utils import load_data, visualize_results
from prompts import (
    CATEGORY_SUGGESTION_PROMPT,
    ADDITIONAL_CATEGORY_PROMPT,
    VALIDATION_ANALYSIS_PROMPT,
    CATEGORY_IMPROVEMENT_PROMPT,
)

# Configure logging
logging.basicConfig(
    level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s"
)

# Initialize API key from environment variable
OPENAI_API_KEY = os.environ.get("OPENAI_API_KEY", "")

# Only initialize client if API key is available
# NOTE(review): nothing in this file rebinds `client` after import, so the
# handlers below keep using the client created here even after the
# "Update API Key" button calls `update_api_key` -- confirm this is intended.
client = None
if OPENAI_API_KEY:
    try:
        client = OpenAI(api_key=OPENAI_API_KEY)
        logging.info("OpenAI client initialized successfully")
    except Exception as e:
        logging.error(f"Failed to initialize OpenAI client: {str(e)}")

# Categories offered when no LLM suggestion is available.
DEFAULT_CATEGORIES = ["Positive", "Negative", "Neutral", "Mixed", "Other"]


def _file_path(file):
    """Return the path of an uploaded file given either a path string or an object with ``.name``."""
    return file if isinstance(file, str) else file.name


def _parse_category_list(text):
    """Split a comma-separated string into stripped, non-empty category names."""
    return [part.strip() for part in (text or "").split(",") if part.strip()]


def _collect_sample_texts(df, columns, rows_per_column):
    """Return up to ``rows_per_column`` non-null values per column, as strings.

    Values are stringified and missing values dropped so the result can be
    passed to ``str.join`` without raising on NaN or numeric cells.
    """
    texts = []
    for col in columns:
        texts.extend(str(value) for value in df[col].dropna().head(rows_per_column))
    return texts


def _guess_text_columns(df, columns):
    """Pick the columns that look like free text.

    A column qualifies when its dtype is ``object`` and more than 30% of the
    first 100 non-null values contain a space. When nothing qualifies, every
    ``object`` column is returned instead.
    """
    object_columns = [col for col in columns if df[col].dtype == "object"]
    likely_text = []
    for col in object_columns:
        sample = df[col].head(100).dropna()
        if len(sample) == 0:
            continue
        text_ratio = sum(" " in str(val) for val in sample) / len(sample)
        if text_ratio > 0.3:
            likely_text.append(col)
    return likely_text or object_columns


def _suggest_initial_categories(sample_texts):
    """Ask the LLM for an initial category list, falling back to the defaults.

    The defaults are returned when no client is configured, when the request
    or prompt formatting fails, or when the reply contains no usable name.
    """
    if not client:
        return list(DEFAULT_CATEGORIES)
    try:
        prompt = CATEGORY_SUGGESTION_PROMPT.format("\n---\n".join(sample_texts[:5]))
        response = client.chat.completions.create(
            model="gpt-3.5-turbo",
            messages=[{"role": "user", "content": prompt}],
            temperature=0,
            max_tokens=100,
        )
        suggested = _parse_category_list(response.choices[0].message.content)
    except Exception:
        logging.exception("Category suggestion failed; using default categories")
        return list(DEFAULT_CATEGORIES)
    return suggested or list(DEFAULT_CATEGORIES)


def _hidden_load_outputs():
    """Return the 11 outputs of the "Load File" handler with every control hidden.

    The order matches the ``outputs`` list wired to ``load_categories_button``.
    """
    return (
        [],
        gr.CheckboxGroup(choices=[]),
        gr.CheckboxGroup(choices=[], visible=False),
        gr.Textbox(visible=False),
        gr.Button(visible=False),
        gr.Button(visible=False),
        gr.CheckboxGroup(choices=[], visible=False),
        gr.Dropdown(visible=False),
        gr.Checkbox(visible=False),
        gr.Button(visible=False),
        gr.Dataframe(visible=False),
    )


# Function to load file and suggest categories
def load_file_and_suggest_categories(file):
    """Load the uploaded file and reveal the classification controls.

    Args:
        file: Value of the file upload component (falsy when nothing is uploaded).

    Returns:
        An 11-tuple matching the outputs wired to ``load_categories_button``:
        the column list, updates for the text-column selector, suggested
        categories, the category controls, model options, the process button
        and the preview dataframe. Every control is hidden when no file is
        given or loading fails.
    """
    if not file:
        return _hidden_load_outputs()

    try:
        df = load_data(_file_path(file))
        columns = list(df.columns)
        suggested_text_columns = _guess_text_columns(df, columns)
        sample_texts = _collect_sample_texts(df, suggested_text_columns, 5)
        suggested_cats = _suggest_initial_categories(sample_texts)

        return (
            columns,
            gr.CheckboxGroup(choices=columns, value=suggested_text_columns),
            gr.CheckboxGroup(
                choices=suggested_cats, value=suggested_cats, visible=True
            ),
            gr.Textbox(visible=True),
            gr.Button(visible=True),
            gr.Button(visible=True),
            gr.CheckboxGroup(
                choices=columns, value=suggested_text_columns, visible=True
            ),
            gr.Dropdown(visible=True),
            gr.Checkbox(visible=True),
            gr.Button(visible=True),
            gr.Dataframe(value=df, visible=True),
        )
    except Exception:
        # Keep the UI usable, but leave a trace of what went wrong.
        logging.exception("Failed to load file and suggest categories")
        return _hidden_load_outputs()


# Function to add a new category
def add_new_category(current_categories, new_category):
    """Append a user-entered category to the selected categories.

    Args:
        current_categories: Currently selected categories.
        new_category: Text typed by the user.

    Returns:
        The unchanged selection when the text is blank or already present,
        otherwise a ``gr.CheckboxGroup`` update whose choices and value are
        the selection plus the new name.
    """
    selected = list(current_categories or [])
    name = (new_category or "").strip()
    if not name or name in selected:
        return selected

    new_categories = selected + [name]
    return gr.CheckboxGroup(choices=new_categories, value=new_categories)


# Function to update categories textbox
def update_categories_textbox(selected_categories):
    """Return the selected categories joined with ``", "`` for the hidden textbox."""
    return ", ".join(selected_categories or [])


# Function to show results after processing
def show_results(df, validation_report):
    """Show the results after processing.

    Args:
        df: Classified dataframe, or ``None`` when processing produced nothing.
        validation_report: Not used here; accepted because the handler is
            wired with both processing outputs as inputs.

    Returns:
        Updates for the results row, the CSV download, the Excel download and
        the results dataframe (all hidden when ``df`` is ``None``).
    """
    if df is None:
        return (
            gr.Row(visible=False),
            gr.File(visible=False),
            gr.File(visible=False),
            gr.Dataframe(visible=False),
        )

    # Export to both formats
    csv_path = export_results(df, "csv")
    excel_path = export_results(df, "excel")

    return (
        gr.Row(visible=True),
        gr.File(value=csv_path, visible=True),
        gr.File(value=excel_path, visible=True),
        gr.Dataframe(value=df, visible=True),
    )


# Function to suggest a new category
def suggest_new_category(file, current_categories, text_columns):
    """Ask the LLM for one additional category based on sample texts.

    Args:
        file: Value of the file upload component.
        current_categories: Currently selected categories (not modified).
        text_columns: Columns to draw sample texts from.

    Returns:
        A ``gr.CheckboxGroup`` update containing the current categories, plus
        the suggested one when the request succeeds and the name is new.
    """
    # Work on a copy so the caller's list is never mutated.
    updated = list(current_categories or [])

    if file and text_columns and client:
        try:
            df = load_data(_file_path(file))
            sample_texts = _collect_sample_texts(df, text_columns, 5)
            prompt = ADDITIONAL_CATEGORY_PROMPT.format(
                existing_categories=", ".join(updated),
                sample_texts="\n---\n".join(sample_texts[:10]),
            )
            response = client.chat.completions.create(
                model="gpt-3.5-turbo",
                messages=[{"role": "user", "content": prompt}],
                temperature=0,
                max_tokens=50,
            )
            new_cat = (response.choices[0].message.content or "").strip()
            if new_cat and new_cat not in updated:
                updated.append(new_cat)
        except Exception:
            # Best effort: a failed suggestion leaves the categories as they were.
            logging.exception("Failed to suggest an additional category")

    return gr.CheckboxGroup(choices=updated, value=updated)


# Function to handle export and show download button
def handle_export(df, format_type):
    """Export ``df`` in ``format_type`` and return a visible download component.

    Returns a hidden ``gr.File`` when ``df`` is ``None``.
    """
    if df is None:
        return gr.File(visible=False)
    file_path = export_results(df, format_type)
    return gr.File(value=file_path, visible=True)


# Function to improve classification based on validation report
def improve_classification(
    df,
    validation_report,
    text_columns,
    categories,
    classifier_type,
    show_explanations,
    file,
):
    """Improve classification based on validation report.

    The validation report is analysed by the LLM; when the analysis asks for
    new categories, they are requested from the LLM and merged (without
    duplicates) into the current ones before the file is processed again.

    Args:
        df: Current classification results.
        validation_report: Text of the current validation report.
        text_columns: Columns containing the text to classify.
        categories: Comma-separated category names.
        classifier_type: Classifier identifier passed through to ``process_file``.
        show_explanations: Flag passed through to ``process_file``.
        file: Uploaded file (path string or object with ``.name``).

    Returns:
        A 4-tuple: results dataframe, validation report, an update for the
        improve button and an update for the category selector. The inputs are
        returned unchanged when no client is configured or any step fails.
    """
    if df is None or not validation_report:
        return (
            df,
            validation_report,
            gr.Button(visible=False),
            gr.CheckboxGroup(choices=[], value=[]),
        )

    # Parsed up front so every fallback path below can rely on it.
    current_categories = _parse_category_list(categories)

    def unchanged():
        return (
            df,
            validation_report,
            gr.Button(visible=True),
            gr.CheckboxGroup(choices=current_categories, value=current_categories),
        )

    if not client:
        return unchanged()

    try:
        # Extract insights from validation report
        prompt = VALIDATION_ANALYSIS_PROMPT.format(
            validation_report=validation_report,
            current_categories=categories,
        )
        response = client.chat.completions.create(
            model="gpt-4",
            messages=[{"role": "user", "content": prompt}],
            temperature=0,
            max_tokens=300,
        )
        improvements = json.loads(
            (response.choices[0].message.content or "").strip()
        )

        all_categories = list(current_categories)

        # If new categories are needed, suggest them based on the data
        if improvements.get("new_categories_needed", False):
            # Load the file once rather than once per text column.
            temp_df = load_data(_file_path(file))
            sample_texts = _collect_sample_texts(temp_df, text_columns, 10)

            category_prompt = CATEGORY_IMPROVEMENT_PROMPT.format(
                current_categories=", ".join(current_categories),
                analysis=improvements.get("analysis", ""),
                sample_texts="\n---\n".join(sample_texts[:10]),
            )
            category_response = client.chat.completions.create(
                model="gpt-4",
                messages=[{"role": "user", "content": category_prompt}],
                temperature=0,
                max_tokens=100,
            )

            # Combine current and new categories, skipping duplicates
            for cat in _parse_category_list(
                category_response.choices[0].message.content
            ):
                if cat not in all_categories:
                    all_categories.append(cat)
            categories = ",".join(all_categories)

        # Process with improved parameters
        improved_df, new_validation = process_file(
            file,
            text_columns,
            categories,
            classifier_type,
            show_explanations,
        )
        return (
            improved_df,
            new_validation,
            gr.Button(visible=True),
            gr.CheckboxGroup(choices=all_categories, value=all_categories),
        )
    except Exception as e:
        logging.exception("Error in improvement process: %s", e)
        return unchanged()


# Create Gradio interface
# NOTE(review): the container nesting below was reconstructed from a
# whitespace-collapsed source; confirm it matches the intended layout.
with gr.Blocks(title="Text Classification System") as demo:
    gr.Markdown("# Text Classification System")
    gr.Markdown("Upload your data file (Excel/CSV) and classify text using AI")

    with gr.Tab("Setup"):
        api_key_input = gr.Textbox(
            label="OpenAI API Key",
            placeholder="Enter your API key here",
            type="password",
            value=OPENAI_API_KEY,
        )
        api_key_button = gr.Button("Update API Key")
        api_key_message = gr.Textbox(label="Status", interactive=False)

        # Display current API status
        api_status = (
            "API Key is set" if OPENAI_API_KEY else "No API Key found. Please set one."
        )
        gr.Markdown(f"**Current API Status**: {api_status}")

        api_key_button.click(
            update_api_key, inputs=[api_key_input], outputs=[api_key_message]
        )

    with gr.Tab("Classify Data"):
        with gr.Column():
            file_input = gr.File(label="Upload Excel/CSV File")

            # Variable to store available columns
            available_columns = gr.State([])

            # Button to load file and suggest categories
            load_categories_button = gr.Button("Load File")

            # Display original dataframe
            original_df = gr.Dataframe(
                label="Original Data", interactive=False, visible=False
            )

            with gr.Row():
                with gr.Column():
                    suggested_categories = gr.CheckboxGroup(
                        label="Suggested Categories",
                        choices=[],
                        value=[],
                        interactive=True,
                        visible=False,
                    )

                    new_category = gr.Textbox(
                        label="Add New Category",
                        placeholder="Enter a new category name",
                        visible=False,
                    )

                    with gr.Row():
                        add_category_button = gr.Button("Add Category", visible=False)
                        suggest_category_button = gr.Button(
                            "Suggest Category", visible=False
                        )

                    # Original categories input (hidden)
                    categories = gr.Textbox(visible=False)

                with gr.Column():
                    text_column = gr.CheckboxGroup(
                        label="Select Text Columns",
                        choices=[],
                        interactive=True,
                        visible=False,
                    )

                    classifier_type = gr.Dropdown(
                        choices=[
                            ("TF-IDF (Rapide, <1000 lignes)", "tfidf"),
                            ("LLM GPT-3.5 (Fiable, <1000 lignes)", "gpt35"),
                            ("LLM GPT-4 (Très fiable, <500 lignes)", "gpt4"),
                            ("TF-IDF + LLM (Hybride, >1000 lignes)", "hybrid"),
                        ],
                        label="Modèle de classification",
                        value="gpt35",
                        visible=False,
                    )
                    show_explanations = gr.Checkbox(
                        label="Show Explanations", value=True, visible=False
                    )

            process_button = gr.Button("Process and Classify", visible=False)

            results_df = gr.Dataframe(interactive=True, visible=False)

            # Create containers for visualization and validation report
            with gr.Row(visible=False) as results_row:
                with gr.Column():
                    visualization = gr.Plot(label="Classification Distribution")
                    with gr.Row():
                        csv_download = gr.File(label="Download CSV", visible=False)
                        excel_download = gr.File(label="Download Excel", visible=False)
                with gr.Column():
                    validation_output = gr.Textbox(
                        label="Validation Report", interactive=False
                    )
                    improve_button = gr.Button(
                        "Improve Classification with Report", visible=False
                    )

    # Connect functions
    load_categories_button.click(
        load_file_and_suggest_categories,
        inputs=[file_input],
        outputs=[
            available_columns,
            text_column,
            suggested_categories,
            new_category,
            add_category_button,
            suggest_category_button,
            text_column,
            classifier_type,
            show_explanations,
            process_button,
            original_df,
        ],
    )

    add_category_button.click(
        add_new_category,
        inputs=[suggested_categories, new_category],
        outputs=[suggested_categories],
    )

    suggested_categories.change(
        update_categories_textbox,
        inputs=[suggested_categories],
        outputs=[categories],
    )

    suggest_category_button.click(
        suggest_new_category,
        inputs=[file_input, suggested_categories, text_column],
        outputs=[suggested_categories],
    )

    process_button.click(
        lambda: gr.Dataframe(visible=True), inputs=[], outputs=[results_df]
    ).then(
        process_file,
        inputs=[
            file_input,
            text_column,
            categories,
            classifier_type,
            show_explanations,
        ],
        outputs=[results_df, validation_output],
    ).then(
        show_results,
        inputs=[results_df, validation_output],
        outputs=[results_row, csv_download, excel_download, results_df],
    ).then(
        visualize_results, inputs=[results_df, text_column], outputs=[visualization]
    ).then(
        # No inputs are wired to this step, so the callback must take no arguments.
        lambda: gr.Button(visible=True),
        inputs=[],
        outputs=[improve_button],
    )

    improve_button.click(
        improve_classification,
        inputs=[
            results_df,
            validation_output,
            text_column,
            categories,
            classifier_type,
            show_explanations,
            file_input,
        ],
        outputs=[
            results_df,
            validation_output,
            improve_button,
            suggested_categories,
        ],
    ).then(
        show_results,
        inputs=[results_df, validation_output],
        outputs=[results_row, csv_download, excel_download, results_df],
    ).then(
        visualize_results, inputs=[results_df, text_column], outputs=[visualization]
    )


def create_example_data():
    """Create example data for demonstration"""
    from utils import create_example_file

    example_path = create_example_file()
    return f"Example file created at: {example_path}"


if __name__ == "__main__":
    # Create examples directory and sample file if it doesn't exist
    if not os.path.exists("examples"):
        create_example_data()

    # Launch the Gradio app
    demo.launch()