import gradio as gr
import os
import json
from openai import OpenAI
import sys  # Used for logging errors to stderr
# Load sensitive information from environment variables
RUNPOD_API_KEY = os.getenv('RUNPOD_API_KEY')
RUNPOD_ENDPOINT_ID = os.getenv('RUNPOD_ENDPOINT_ID')
# --- Basic Input Validation ---
if not RUNPOD_API_KEY:
raise ValueError("RunPod API key not found. Please set the RUNPOD_API_KEY environment variable.")
if not RUNPOD_ENDPOINT_ID:
raise ValueError("RunPod Endpoint ID not found. Please set the RUNPOD_ENDPOINT_ID environment variable.")
BASE_URL = f"https://api.runpod.ai/v2/{RUNPOD_ENDPOINT_ID}/openai/v1"
MODEL_NAME = "karths/coder_commit_32B" # The specific model hosted on RunPod
MAX_TOKENS = 4096 # Max tokens for the model response
# --- OpenAI Client Initialization ---
client = OpenAI(
api_key=RUNPOD_API_KEY,
base_url=BASE_URL,
)
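# Optional connectivity check (illustrative sketch, not part of the original app):
# if the RunPod OpenAI-compatible route exposes the standard /models listing
# (vLLM-style servers generally do), a quick call here surfaces credential or
# endpoint-ID problems before the UI launches. Left commented out because a
# cold endpoint may take a while to answer.
# try:
#     _models = client.models.list()
#     print("Endpoint reachable; models:", [m.id for m in _models.data])
# except Exception as exc:
#     print(f"Warning: could not reach RunPod endpoint yet: {exc}", file=sys.stderr)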
# --- Gradio App Configuration ---
title = "Python Maintainability Refactoring Demo"
description = """
## Instructions for Using the Model
### Model Loading Time:
- Please allow time for the model on the GPU server to initialize if it is starting fresh (a "cold start"). The response will appear token by token.
### Code Submission:
- Enter or paste the Python code you wish to have refactored, or use one of the provided examples.
### Python Code Constraints:
- Keep the code reasonably sized; large code blocks may hit limits imposed by the GPU instance and the model. The maximum response length is set to {} tokens.
### Understanding Changes:
- Read the "Changes made" section (if the model provides one) in the refactored response to understand which modifications were applied.
### Usage Recommendation:
- This demo is intended for research and evaluation purposes.
""".format(MAX_TOKENS)
system_prompt = """### Instruction:
Refactor the provided Python code to improve its maintainability and efficiency and reduce complexity. Include the refactored code along with comments on the changes made for improving the metrics.
### Input:
"""
css = """.toast-wrap { display: none !important } """
examples = [
["""def analyze_sales_data(sales_records):
active_sales = filter(lambda record: record['status'] == 'active', sales_records)
sales_by_category = {}
for record in active_sales:
category = record['category']
total_sales = record['units_sold'] * record['price_per_unit']
if category not in sales_by_category:
sales_by_category[category] = {'total_sales': 0, 'total_units': 0}
sales_by_category[category]['total_sales'] += total_sales
sales_by_category[category]['total_units'] += record['units_sold']
average_sales_data = []
for category, data in sales_by_category.items():
average_sales = data['total_sales'] / data['total_units'] if data['total_units'] > 0 else 0 # Avoid division by zero
sales_by_category[category]['average_sales'] = average_sales
average_sales_data.append((category, average_sales))
average_sales_data.sort(key=lambda x: x[1], reverse=True)
for rank, (category, _) in enumerate(average_sales_data, start=1):
sales_by_category[category]['rank'] = rank
return sales_by_category"""],
["""import pandas as pd
import re
import ast
from code_bert_score import score # Assuming this library is available in the environment
import numpy as np
def preprocess_code(source_text):
def remove_comments_and_docstrings(source_code):
# Remove single-line comments
source_code = re.sub(r'#.*', '', source_code)
# Remove multi-line strings (docstrings)
source_code = re.sub(r'(\'\'\'(.*?)\'\'\'|\"\"\"(.*?)\"\"\")', '', source_code, flags=re.DOTALL)
return source_code.strip() # Added strip
# Pattern to extract code specifically from markdown blocks if present
pattern = r"```python\s+(.+?)\s+```"
matches = re.findall(pattern, source_text, re.DOTALL)
code_to_process = '\n'.join(matches) if matches else source_text
cleaned_code = remove_comments_and_docstrings(code_to_process)
return cleaned_code
def evaluate_dataframe(df):
results = {'P': [], 'R': [], 'F1': [], 'F3': []}
for index, row in df.iterrows():
try:
# Ensure inputs are lists of strings
cands = [preprocess_code(str(row['generated_text']))] # Added str() conversion
refs = [preprocess_code(str(row['output']))] # Added str() conversion
# Ensure code_bert_score.score returns four values
score_results = score(cands, refs, lang='python')
if len(score_results) == 4:
P, R, F1, F3 = score_results
results['P'].append(P.item() if hasattr(P, 'item') else P) # Handle potential tensor output
results['R'].append(R.item() if hasattr(R, 'item') else R)
results['F1'].append(F1.item() if hasattr(F1, 'item') else F1)
results['F3'].append(F3.item() if hasattr(F3, 'item') else F3) # Assuming F3 is returned
else:
print(f"Warning: Unexpected number of return values from score function for row {index}. Got {len(score_results)} values.")
for key in results.keys():
results[key].append(np.nan) # Append NaN for unexpected format
except Exception as e:
print(f"Error processing row {index}: {e}")
for key in results.keys():
results[key].append(np.nan) # Use NaN for errors
df_metrics = pd.DataFrame(results)
return df_metrics
def evaluate_dataframe_multiple_runs(df, runs=3):
all_results = []
print(f"Starting evaluation for {runs} runs...")
for run in range(runs):
print(f"Run {run + 1}/{runs}")
df_metrics = evaluate_dataframe(df.copy()) # Use a copy to avoid side effects if df is modified
all_results.append(df_metrics)
print(f"Run {run + 1} completed.")
if not all_results:
print("No results collected.")
return pd.DataFrame(), pd.DataFrame()
# Concatenate results and calculate statistics
try:
concatenated_results = pd.concat(all_results)
df_metrics_mean = concatenated_results.groupby(level=0).mean()
df_metrics_std = concatenated_results.groupby(level=0).std()
print("Mean and standard deviation calculated.")
except Exception as e:
print(f"Error calculating statistics: {e}")
# Return empty DataFrames or handle as appropriate
return pd.DataFrame(), pd.DataFrame()
return df_metrics_mean, df_metrics_std"""]
]
# --- Core Logic (Modified for Streaming) ---
def gen_solution_stream(prompt):
"""
Generates a solution for a given problem prompt by calling the LLM via RunPod
and yielding the response chunks as they arrive (streaming).
Parameters:
- prompt (str): The problem prompt including the system message and user input.
Yields:
- str: Chunks of the generated solution text.
- str: An error message if an exception occurs.
"""
try:
# Call the OpenAI compatible endpoint on RunPod with streaming enabled
stream = client.chat.completions.create(
model=MODEL_NAME,
messages=[{"role": "user", "content": prompt}],
temperature=0.1, # Keep temperature low for deterministic refactoring
top_p=1.0,
max_tokens=MAX_TOKENS,
stream=True # Enable streaming
)
# Yield content chunks from the stream
for chunk in stream:
if chunk.choices and chunk.choices[0].delta and chunk.choices[0].delta.content:
content = chunk.choices[0].delta.content
yield content
# Optional: Handle finish reason if needed
# if chunk.choices and chunk.choices[0].finish_reason:
# print(f"\nStream finished with reason: {chunk.choices[0].finish_reason}")
except Exception as e:
error_message = f"Error: Could not get streaming response from the model. Details: {str(e)}"
print(error_message, file=sys.stderr) # Log error to stderr
yield error_message # Yield the error message to be displayed in the UI
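# Non-streaming reference variant (sketch only; the UI below uses the streaming
# path above). The same endpoint call without stream=True returns the whole
# completion at once, which can be convenient when testing the RunPod endpoint
# outside Gradio.
def gen_solution(prompt):
    """Return the full (non-streamed) model response, or an error string."""
    try:
        response = client.chat.completions.create(
            model=MODEL_NAME,
            messages=[{"role": "user", "content": prompt}],
            temperature=0.1,
            top_p=1.0,
            max_tokens=MAX_TOKENS,
        )
        return response.choices[0].message.content
    except Exception as e:
        return f"Error: Could not get response from the model. Details: {str(e)}"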
# --- Gradio Interface Function (Modified for Streaming) ---
def predict(message, history):
"""
Handles the user input, calls the backend model stream, and yields the response chunks.
The 'history' parameter is required by gr.ChatInterface but is not used here.
"""
# Construct the full prompt
input_prompt = system_prompt + str(message)
# Get the refactored code stream from the backend
response_stream = gen_solution_stream(input_prompt)
# Yield each chunk received from the stream generator
# Gradio's ChatInterface handles accumulating these yields into the chatbot output
buffer = ""
for chunk in response_stream:
buffer += chunk
yield buffer # Yield the accumulated buffer to update the UI incrementally
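# Minimal usage sketch outside Gradio (assumes a hypothetical `sample_code`
# string): each value yielded by predict() is the cumulative text so far, so a
# plain Python caller only needs to keep the last one.
# final_text = ""
# for partial in predict(sample_code, history=[]):
#     final_text = partial
# print(final_text)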
# --- Launch Gradio Interface ---
# Use gr.ChatInterface for a chat-like experience
gr.ChatInterface(
predict, # Pass the generator function
chatbot=gr.Chatbot(height=500, label="Refactored Code and Explanation", show_copy_button=True), # Added copy button
textbox=gr.Textbox(lines=10, label="Python Code", placeholder="Enter or Paste your Python code here..."),
title=title,
description=description,
theme="abidlabs/Lime", # Or choose another theme e.g., gr.themes.Default()
examples=examples,
cache_examples=False, # Consider enabling caching if examples don't change often
submit_btn="Submit Code",
retry_btn="Retry",
undo_btn="Undo",
clear_btn="Clear",
css=css # Apply custom CSS if needed
).queue().launch(share=True)  # share=True creates a public link (use with caution)
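# Deployment note (sketch, assuming a setting where a public gradio.live link is
# not wanted, e.g. a hosted Space or a private network): share=False (the default)
# with an explicit bind address is usually sufficient, for example:
# .queue().launch(server_name="0.0.0.0", server_port=7860)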