import gradio as gr
import os
from openai import OpenAI

# Load sensitive information from environment variables
RUNPOD_API_KEY = os.getenv('RUNPOD_API_KEY')
RUNPOD_ENDPOINT_ID = os.getenv('RUNPOD_ENDPOINT_ID')
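
# Both variables must be set in the environment before running. The values
# below are placeholders for illustration only, not real credentials:
#
#   export RUNPOD_API_KEY="your-runpod-api-key"
#   export RUNPOD_ENDPOINT_ID="your-endpoint-id"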


BASE_URL = f"https://api.runpod.ai/v2/{RUNPOD_ENDPOINT_ID}/openai/v1"
MODEL_NAME = "karths/coder_commit_32B"  # The specific model hosted on RunPod
MAX_TOKENS = 4096  # Max tokens for the model response

# --- OpenAI Client Initialization ---
# Check if the API key is provided
if not RUNPOD_API_KEY:
    raise ValueError("RunPod API key not found. Please set the RUNPOD_API_KEY environment variable.")

# Initialize the OpenAI client to connect to the RunPod endpoint
client = OpenAI(
    api_key=RUNPOD_API_KEY,
    base_url=BASE_URL,
)
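
# Optional sanity check, commented out to avoid an extra API call on startup.
# This is a minimal sketch that assumes the RunPod serverless endpoint also
# implements the standard OpenAI-compatible /models route:
#
#   for model in client.models.list():
#       print(model.id)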

# --- Gradio App Configuration ---
title = "Python Maintainability Refactoring (RunPod)"
description = """
## Instructions for Using the Model
### Model Loading Time:
- Please allow time for the model on RunPod to initialize if it is starting fresh (a "cold start").
### Code Submission:
- Enter or paste the Python code you want refactored, or use one of the provided examples.
### Python Code Constraints:
- Keep the code reasonably sized. The 120-line limit applied to the previous setup, but large code blocks may still hit limits depending on the RunPod instance and model. The maximum response length is {} tokens.
### Understanding Changes:
- Read the "Changes made" section (when the model provides one) in the refactored response to understand what modifications were made.
### Usage Recommendation:
- Intended for research and evaluation purposes.
""".format(MAX_TOKENS)

system_prompt = """### Instruction:
Refactor the provided Python code to improve its maintainability and efficiency and to reduce its complexity. Include the refactored code along with comments explaining the changes made to improve these metrics.
### Input:
"""

css = """.toast-wrap { display: none !important } """

examples = [
    ["""def analyze_sales_data(sales_records):
    active_sales = filter(lambda record: record['status'] == 'active', sales_records)
    sales_by_category = {}
    for record in active_sales:
        category = record['category']
        total_sales = record['units_sold'] * record['price_per_unit']
        if category not in sales_by_category:
            sales_by_category[category] = {'total_sales': 0, 'total_units': 0}
        sales_by_category[category]['total_sales'] += total_sales
        sales_by_category[category]['total_units'] += record['units_sold']
    average_sales_data = []
    for category, data in sales_by_category.items():
        average_sales = data['total_sales'] / data['total_units'] if data['total_units'] > 0 else 0 # Avoid division by zero
        sales_by_category[category]['average_sales'] = average_sales
        average_sales_data.append((category, average_sales))
    average_sales_data.sort(key=lambda x: x[1], reverse=True)
    for rank, (category, _) in enumerate(average_sales_data, start=1):
        sales_by_category[category]['rank'] = rank
    return sales_by_category"""],
    ["""import pandas as pd
import re
import ast
from code_bert_score import score # Assuming this library is available in the environment
import numpy as np

def preprocess_code(source_text):
    def remove_comments_and_docstrings(source_code):
        # Remove single-line comments
        source_code = re.sub(r'#.*', '', source_code)
        # Remove multi-line strings (docstrings)
        source_code = re.sub(r'(\'\'\'(.*?)\'\'\'|\"\"\"(.*?)\"\"\")', '', source_code, flags=re.DOTALL)
        return source_code.strip() # Added strip

    # Pattern to extract code specifically from markdown blocks if present
    pattern = r"```python\s+(.+?)\s+```"
    matches = re.findall(pattern, source_text, re.DOTALL)
    code_to_process = '\n'.join(matches) if matches else source_text

    cleaned_code = remove_comments_and_docstrings(code_to_process)
    return cleaned_code

def evaluate_dataframe(df):
    results = {'P': [], 'R': [], 'F1': [], 'F3': []}
    for index, row in df.iterrows():
        try:
            # Ensure inputs are lists of strings
            cands = [preprocess_code(str(row['generated_text']))] # Added str() conversion
            refs = [preprocess_code(str(row['output']))] # Added str() conversion

            # Ensure code_bert_score.score returns four values
            score_results = score(cands, refs, lang='python')
            if len(score_results) == 4:
                P, R, F1, F3 = score_results
                results['P'].append(P.item() if hasattr(P, 'item') else P)  # Handle potential tensor output
                results['R'].append(R.item() if hasattr(R, 'item') else R)
                results['F1'].append(F1.item() if hasattr(F1, 'item') else F1)
                results['F3'].append(F3.item() if hasattr(F3, 'item') else F3)  # Assuming F3 is returned
            else:
                print(f"Warning: Unexpected number of return values from score function for row {index}. Got {len(score_results)} values.")
                for key in results.keys():
                    results[key].append(np.nan)  # Append NaN for unexpected format

        except Exception as e:
            print(f"Error processing row {index}: {e}")
            for key in results.keys():
                results[key].append(np.nan) # Use NaN for errors

    df_metrics = pd.DataFrame(results)
    return df_metrics

def evaluate_dataframe_multiple_runs(df, runs=3):
    all_results = []
    print(f"Starting evaluation for {runs} runs...")
    for run in range(runs):
        print(f"Run {run + 1}/{runs}")
        df_metrics = evaluate_dataframe(df.copy()) # Use a copy to avoid side effects if df is modified
        all_results.append(df_metrics)
        print(f"Run {run + 1} completed.")

    if not all_results:
        print("No results collected.")
        return pd.DataFrame(), pd.DataFrame()

    # Concatenate results and calculate statistics
    try:
        concatenated_results = pd.concat(all_results)
        df_metrics_mean = concatenated_results.groupby(level=0).mean()
        df_metrics_std = concatenated_results.groupby(level=0).std()
        print("Mean and standard deviation calculated.")
    except Exception as e:
        print(f"Error calculating statistics: {e}")
        # Return empty DataFrames or handle as appropriate
        return pd.DataFrame(), pd.DataFrame()

    return df_metrics_mean, df_metrics_std"""]
]

# --- Core Logic ---
def gen_solution(prompt):
    """
    Generates a solution for a given problem prompt by calling the LLM via RunPod.

    Parameters:
    - prompt (str): The problem prompt including the system message and user input.

    Returns:
    - str: The generated solution text, or an error message.
    """
    try:
        # Call the OpenAI compatible endpoint on RunPod
        completion = client.chat.completions.create(
            model=MODEL_NAME,
            messages=[{"role": "user", "content": prompt}],
            temperature=0.1, # Keep temperature low for more deterministic refactoring
            top_p=1.0,
            max_tokens=MAX_TOKENS,
            # stream=False # Explicitly setting stream to False (default)
        )
        # Extract the response content
        response_content = completion.choices[0].message.content
        return response_content

    except Exception as e:
        print(f"Error calling RunPod API: {e}")
        # Provide a user-friendly error message
        return f"Error: Could not get response from the model. Details: {str(e)}"
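
# Hypothetical smoke test for gen_solution outside the Gradio UI, commented out
# so that running this module has no extra side effects (assumes the RunPod
# endpoint is already warm):
#
#   sample_code = "def add(a, b):\n    return a+b"
#   print(gen_solution(system_prompt + sample_code))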

# --- Gradio Interface Function ---
def predict(message, history):
    """
    Handles the user input, calls the backend model, and returns the response.
    The 'history' parameter is required by gr.ChatInterface's signature but is not used here.
    """
    # Construct the full prompt
    input_prompt = system_prompt + str(message)  # Prepend the refactoring instruction to the user's code

    # Get the refactored code from the backend
    refactored_code_response = gen_solution(input_prompt)

    # The response is returned directly to the ChatInterface
    return refactored_code_response

# --- Launch Gradio Interface ---
# Use gr.ChatInterface for a chat-like experience
gr.ChatInterface(
    predict,
    chatbot=gr.Chatbot(height=500, label="Refactored Code and Explanation"),
    textbox=gr.Textbox(lines=10, label="Python Code", placeholder="Enter or Paste your Python code here..."),
    title=title,
    description=description,
    theme="abidlabs/Lime", # Or choose another theme e.g., gr.themes.Default()
    examples=examples,
    cache_examples=False, # Consider enabling caching if examples don't change often
    submit_btn="Submit Code",
    retry_btn="Retry",
    undo_btn="Undo",
    clear_btn="Clear",
    css=css # Apply custom CSS if needed
).queue().launch(share=True) # share=True creates a public link (use with caution)