File size: 9,877 Bytes
0e34d9a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a94cab5
0e34d9a
 
 
a94cab5
0e34d9a
 
 
ddc3bda
0e34d9a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
import gradio as gr
import os
import json
from openai import OpenAI
import sys # Added for flushing output in case of direct printing

# --- Environment-backed settings ---
def _require_env(name, missing_message):
    """Return environment variable *name*; raise ValueError(missing_message) if it is unset or empty."""
    value = os.getenv(name)
    if not value:
        raise ValueError(missing_message)
    return value


# Secrets are read from the environment and checked immediately, so a missing
# setting stops the script here instead of on the first model request.
RUNPOD_API_KEY = _require_env(
    'RUNPOD_API_KEY',
    "RunPod API key not found. Please set the RUNPOD_API_KEY environment variable.",
)
RUNPOD_ENDPOINT_ID = _require_env(
    'RUNPOD_ENDPOINT_ID',
    "RunPod Endpoint ID not found. Please set the RUNPOD_ENDPOINT_ID environment variable.",
)

# OpenAI-compatible route of the RunPod endpoint identified by RUNPOD_ENDPOINT_ID.
BASE_URL = f"https://api.runpod.ai/v2/{RUNPOD_ENDPOINT_ID}/openai/v1"
# Model identifier sent with every completion request.
MODEL_NAME = "karths/coder_commit_32B"
# Upper bound on the length of a single model response, in tokens.
MAX_TOKENS = 4096

# --- OpenAI Client Initialization ---
# One shared client, pointed at the RunPod base URL, is reused for all requests.
client = OpenAI(base_url=BASE_URL, api_key=RUNPOD_API_KEY)

# --- Gradio App Configuration ---
title = "Python Maintainability Refactoring demo"
description = """
## Instructions for Using the Model
### Model Loading Time:
- Please allow time for the model on GPU server to initialize if it's starting fresh ("Cold Start"). The response will appear token by token.
### Code Submission:
- You can enter or paste your Python code you wish to have refactored, or use the provided example.
### Python Code Constraints:
- Keep the code reasonably sized. Large code blocks might face limitations depending on the GPU instance and model constraints. Max response length is set to {} tokens.
### Understanding Changes:
- It's important to read the "Changes made" section (if provided by the model) in the refactored code response. This will help in understanding what modifications have been made.
### Usage Recommendation:
- Intended for research and evaluation purposes.
""".format(MAX_TOKENS)

system_prompt = """### Instruction:
Refactor the provided Python code to improve its maintainability and efficiency and reduce complexity. Include the refactored code along with comments on the changes made for improving the metrics.
### Input:
"""

css = """.toast-wrap { display: none !important } """

# Sample inputs passed to gr.ChatInterface as `examples`. Each entry is a
# one-element list whose only item is the Python source to be refactored.
examples = [
    ["""def analyze_sales_data(sales_records):
    active_sales = filter(lambda record: record['status'] == 'active', sales_records)
    sales_by_category = {}
    for record in active_sales:
        category = record['category']
        total_sales = record['units_sold'] * record['price_per_unit']
        if category not in sales_by_category:
            sales_by_category[category] = {'total_sales': 0, 'total_units': 0}
        sales_by_category[category]['total_sales'] += total_sales
        sales_by_category[category]['total_units'] += record['units_sold']
    average_sales_data = []
    for category, data in sales_by_category.items():
        average_sales = data['total_sales'] / data['total_units'] if data['total_units'] > 0 else 0 # Avoid division by zero
        sales_by_category[category]['average_sales'] = average_sales
        average_sales_data.append((category, average_sales))
    average_sales_data.sort(key=lambda x: x[1], reverse=True)
    for rank, (category, _) in enumerate(average_sales_data, start=1):
        sales_by_category[category]['rank'] = rank
    return sales_by_category"""],
    # This sample contains backslash escapes (`\n`, `\s`, escaped quotes) that
    # belong to the sample code itself. The literal is raw so they reach the
    # user and the model unchanged; as a normal string Python would interpret
    # them and turn `'\n'.join(...)` into a string literal split over two lines.
    [r"""import pandas as pd
import re
import ast
from code_bert_score import score # Assuming this library is available in the environment
import numpy as np

def preprocess_code(source_text):
    def remove_comments_and_docstrings(source_code):
        # Remove single-line comments
        source_code = re.sub(r'#.*', '', source_code)
        # Remove multi-line strings (docstrings)
        source_code = re.sub(r'(\'\'\'(.*?)\'\'\'|\"\"\"(.*?)\"\"\")', '', source_code, flags=re.DOTALL)
        return source_code.strip() # Added strip

    # Pattern to extract code specifically from markdown blocks if present
    pattern = r"```python\s+(.+?)\s+```"
    matches = re.findall(pattern, source_text, re.DOTALL)
    code_to_process = '\n'.join(matches) if matches else source_text

    cleaned_code = remove_comments_and_docstrings(code_to_process)
    return cleaned_code

def evaluate_dataframe(df):
    results = {'P': [], 'R': [], 'F1': [], 'F3': []}
    for index, row in df.iterrows():
        try:
            # Ensure inputs are lists of strings
            cands = [preprocess_code(str(row['generated_text']))] # Added str() conversion
            refs = [preprocess_code(str(row['output']))] # Added str() conversion

            # Ensure code_bert_score.score returns four values
            score_results = score(cands, refs, lang='python')
            if len(score_results) == 4:
                P, R, F1, F3 = score_results
                results['P'].append(P.item() if hasattr(P, 'item') else P) # Handle potential tensor output
                results['R'].append(R.item() if hasattr(R, 'item') else R)
                results['F1'].append(F1.item() if hasattr(F1, 'item') else F1)
                results['F3'].append(F3.item() if hasattr(F3, 'item') else F3) # Assuming F3 is returned
            else:
                print(f"Warning: Unexpected number of return values from score function for row {index}. Got {len(score_results)} values.")
                for key in results.keys():
                   results[key].append(np.nan) # Append NaN for unexpected format

        except Exception as e:
            print(f"Error processing row {index}: {e}")
            for key in results.keys():
                results[key].append(np.nan) # Use NaN for errors

    df_metrics = pd.DataFrame(results)
    return df_metrics

def evaluate_dataframe_multiple_runs(df, runs=3):
    all_results = []
    print(f"Starting evaluation for {runs} runs...")
    for run in range(runs):
        print(f"Run {run + 1}/{runs}")
        df_metrics = evaluate_dataframe(df.copy()) # Use a copy to avoid side effects if df is modified
        all_results.append(df_metrics)
        print(f"Run {run + 1} completed.")

    if not all_results:
        print("No results collected.")
        return pd.DataFrame(), pd.DataFrame()

    # Concatenate results and calculate statistics
    try:
        concatenated_results = pd.concat(all_results)
        df_metrics_mean = concatenated_results.groupby(level=0).mean()
        df_metrics_std = concatenated_results.groupby(level=0).std()
        print("Mean and standard deviation calculated.")
    except Exception as e:
        print(f"Error calculating statistics: {e}")
        # Return empty DataFrames or handle as appropriate
        return pd.DataFrame(), pd.DataFrame()

    return df_metrics_mean, df_metrics_std"""]
]

# --- Core Logic (Modified for Streaming) ---
def gen_solution_stream(prompt):
    """
    Stream the model's reply to *prompt* as it is produced.

    The prompt is sent as a single user message through the module-level
    OpenAI-compatible client with streaming enabled.

    Parameters:
    - prompt (str): Full text to send (instruction preamble plus user code).

    Yields:
    - str: Each non-empty text fragment, in arrival order.
    - str: One error message if any exception is raised while requesting or
      reading the stream; the same message is also printed to stderr.
    """
    try:
        response_chunks = client.chat.completions.create(
            model=MODEL_NAME,
            messages=[{"role": "user", "content": prompt}],
            temperature=0.1,  # low temperature keeps refactorings close to deterministic
            top_p=1.0,
            max_tokens=MAX_TOKENS,
            stream=True,
        )
        for piece in response_chunks:
            # Skip chunks that carry no text (no choices, no delta, or empty content).
            if not piece.choices:
                continue
            delta = piece.choices[0].delta
            if not delta:
                continue
            text = delta.content
            if text:
                yield text
    except Exception as exc:
        # Report the failure in the UI as well as in the server log.
        failure = f"Error: Could not get streaming response from the model. Details: {exc}"
        print(failure, file=sys.stderr)
        yield failure

# --- Gradio Interface Function (Modified for Streaming) ---
def predict(message, history):
    """
    Chat handler: stream the refactoring of the submitted code to the UI.

    Parameters:
    - message: The text submitted by the user; converted with str() before use.
    - history: Conversation history supplied by gr.ChatInterface. It is not
      used; every request is answered independently.

    Yields:
    - str: The reply accumulated so far. Each yield contains everything
      received up to that point, not just the newest fragment.
    - str: A short hint, with no model call, when the submission is empty,
      whitespace-only or None.
    """
    code = "" if message is None else str(message)

    # An empty submission gives the model nothing to refactor, so answer
    # locally instead of sending a request to the remote endpoint.
    if not code.strip():
        yield "Please enter some Python code to refactor."
        return

    # The instruction preamble ends with "### Input:", so the code follows it directly.
    input_prompt = system_prompt + code

    buffer = ""
    for chunk in gen_solution_stream(input_prompt):
        buffer += chunk
        yield buffer  # cumulative text so far

# --- Launch Gradio Interface ---
# The chat UI is assembled from named pieces, then queued and served.

# Output pane for the model's reply, with a copy button.
response_panel = gr.Chatbot(
    height=500,
    label="Refactored Code and Explanation",
    show_copy_button=True,
)

# Multi-line input box for the code to refactor.
code_box = gr.Textbox(
    lines=10,
    label="Python Code",
    placeholder="Enter or Paste your Python code here...",
)

# NOTE(review): the retry_btn / undo_btn / clear_btn keywords depend on the
# installed Gradio version; confirm the pinned release still accepts them.
demo = gr.ChatInterface(
    predict,  # generator function; its yields update the reply incrementally
    chatbot=response_panel,
    textbox=code_box,
    title=title,
    description=description,
    theme="abidlabs/Lime",
    examples=examples,
    cache_examples=False,
    submit_btn="Submit Code",
    retry_btn="Retry",
    undo_btn="Undo",
    clear_btn="Clear",
    css=css,
)

# share=True requests a public link, so anyone with the URL can reach the app.
demo.queue().launch(share=True)