Spaces:
Running
Running
File size: 9,374 Bytes
0aa8067 7871ca4 e737a65 7871ca4 0aa8067 4990331 0aa8067 4990331 0aa8067 3c57b86 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 |
import json
import gradio as gr
import os
import requests
from huggingface_hub import AsyncInferenceClient
# Credentials and endpoint URL are read from environment variables.
HF_TOKEN = os.getenv('HF_TOKEN')
api_url = os.getenv('API_URL')
# Only send an Authorization header when a token is configured; otherwise the
# batch request would carry the literal header value "Bearer None".
headers = {"Authorization": f"Bearer {HF_TOKEN}"} if HF_TOKEN else {}
# Give the streaming client the same token the batch path sends explicitly,
# so both tabs authenticate against the endpoint the same way.
client = AsyncInferenceClient(api_url, token=HF_TOKEN)
# Instruction text for a refactoring-oriented system prompt.
# NOTE(review): not referenced anywhere else in this file (the optional
# system-prompt textbox defaults to ""), so it currently has no effect —
# confirm whether it is meant to be wired in as the default system prompt.
system_message = """
Refactor the provided Python code to improve its maintainability and efficiency and reduce complexity. Include the refactored code along with the comments on the changes made for improving the metrics.
"""
# Passed as `title` to both chat interfaces.
title = "Python Refactoring"
# Passed as `description` to both chat interfaces.
description = """
Please allow 3 to 4 minutes for the model to load and run. Consider using Python code with fewer than 120 lines due to GPU constraints.
"""
# Hides elements with the `toast-wrap` class (presumably Gradio's toast pop-ups).
css = """.toast-wrap { display: none !important } """
# Example Python snippets offered in both chat tabs (passed as `examples=` to
# each gr.ChatInterface). Each inner list holds one example message.
# Raw strings keep backslashes in the sample code ('\n', regex escapes, escaped
# quotes) verbatim; a plain string would turn '\n' into a real line break and
# '\'' into a bare quote, corrupting the code shown to the user.
examples = [
    [r"""
import pandas as pd
import re
import ast
from code_bert_score import score
import numpy as np
def preprocess_code(source_text):
    def remove_comments_and_docstrings(source_code):
        source_code = re.sub(r'#.*', '', source_code)
        source_code = re.sub(r'(\'\'\'(.*?)\'\'\'|\"\"\"(.*?)\"\"\")', '', source_code, flags=re.DOTALL)
        return source_code
    pattern = r"```python\s+(.+?)\s+```"
    matches = re.findall(pattern, source_text, re.DOTALL)
    code_to_process = '\n'.join(matches) if matches else source_text
    cleaned_code = remove_comments_and_docstrings(code_to_process)
    return cleaned_code
def evaluate_dataframe(df):
    results = {'P': [], 'R': [], 'F1': [], 'F3': []}
    for index, row in df.iterrows():
        try:
            cands = [preprocess_code(row['generated_text'])]
            refs = [preprocess_code(row['output'])]
            P, R, F1, F3 = score(cands, refs, lang='python')
            results['P'].append(P[0])
            results['R'].append(R[0])
            results['F1'].append(F1[0])
            results['F3'].append(F3[0])
        except Exception as e:
            print(f"Error processing row {index}: {e}")
            for key in results.keys():
                results[key].append(None)
    df_metrics = pd.DataFrame(results)
    return df_metrics
def evaluate_dataframe_multiple_runs(df, runs=3):
    all_results = []
    for run in range(runs):
        df_metrics = evaluate_dataframe(df)
        all_results.append(df_metrics)
    # Calculate mean and std deviation of metrics across runs
    df_metrics_mean = pd.concat(all_results).groupby(level=0).mean()
    df_metrics_std = pd.concat(all_results).groupby(level=0).std()
    return df_metrics_mean, df_metrics_std
"""],
    [r"""
def analyze_sales_data(sales_records):
    active_sales = filter(lambda record: record['status'] == 'active', sales_records)
    sales_by_category = {}
    for record in active_sales:
        category = record['category']
        total_sales = record['units_sold'] * record['price_per_unit']
        if category not in sales_by_category:
            sales_by_category[category] = {'total_sales': 0, 'total_units': 0}
        sales_by_category[category]['total_sales'] += total_sales
        sales_by_category[category]['total_units'] += record['units_sold']
    average_sales_data = []
    for category, data in sales_by_category.items():
        average_sales = data['total_sales'] / data['total_units']
        sales_by_category[category]['average_sales'] = average_sales
        average_sales_data.append((category, average_sales))
    average_sales_data.sort(key=lambda x: x[1], reverse=True)
    for rank, (category, _) in enumerate(average_sales_data, start=1):
        sales_by_category[category]['rank'] = rank
    return sales_by_category
"""],
]
# Note: the default system prompt was removed at the request of the paper authors [Dated: 13/Oct/2023].
# Llama-2 chat prompt format (shown here without a system prompt):
# <s>[INST] {{ user_msg_1 }} [/INST] {{ model_answer_1 }} </s><s>[INST] {{ user_msg_2 }} [/INST]
# Streaming mode: tokens are streamed from TGI through the async InferenceClient.
# TGI rejects temperature <= 0, top_p outside the open interval (0, 1) and
# max_new_tokens == 0, so UI values are clamped into these bounds first.
_MIN_TEMPERATURE = 1e-2
_MIN_TOP_P = 1e-2
_MAX_TOP_P = 0.99
# Seconds to wait for the non-streaming endpoint. Generous on purpose: the UI
# warns that the model can take several minutes to load.
_REQUEST_TIMEOUT_S = 300
# Shown in the chat when the batch endpoint gives no usable answer.
_BATCH_FAILURE_MESSAGE = "The inference endpoint returned no usable response. Please refresh and try again."


def _build_prompt(message, chatbot, system_prompt=""):
    """Render the chat history plus the new message as a Llama-2 chat prompt.

    Args:
        message: The new user message.
        chatbot: Prior turns; each entry's [0] is the user text and [1] the
            model reply. ``None`` entries are rendered as empty text.
        system_prompt: Optional system prompt, wrapped in ``<<SYS>>`` tags
            inside the first ``[INST]`` block when non-empty.

    Returns:
        The prompt string, ending in ``" [/INST] "`` so the model replies next.
    """
    if system_prompt:
        parts = [f"<s>[INST] <<SYS>>\n{system_prompt}\n<</SYS>>\n\n "]
    else:
        parts = ["<s>[INST] "]
    for interaction in chatbot or ():
        # A reply can be None (e.g. a previous request produced nothing);
        # render it as "" rather than the literal text "None".
        user_text = "" if interaction[0] is None else str(interaction[0])
        bot_text = "" if interaction[1] is None else str(interaction[1])
        parts.append(user_text + " [/INST] " + bot_text + " </s><s>[INST] ")
    parts.append(str(message) + " [/INST] ")
    return "".join(parts)


def _sanitize_sampling(temperature, top_p, max_new_tokens):
    """Coerce UI values to the numeric types and ranges TGI accepts.

    Returns:
        ``(temperature, top_p, max_new_tokens)`` as ``(float, float, int)``.
    """
    temperature = max(float(temperature), _MIN_TEMPERATURE)
    top_p = min(max(float(top_p), _MIN_TOP_P), _MAX_TOP_P)
    # Sliders may deliver floats such as 256.0; TGI needs a positive int.
    max_new_tokens = max(int(max_new_tokens), 1)
    return temperature, top_p, max_new_tokens


async def predict(message, chatbot, system_prompt="", temperature=0.1, max_new_tokens=4096, top_p=0.6, repetition_penalty=1.1):
    """Stream a reply token by token from the TGI endpoint.

    Args:
        message: The new user message.
        chatbot: Prior (user, reply) turns.
        system_prompt: Optional system prompt ("" means none).
        temperature: Sampling temperature; raised to at least 0.01.
        max_new_tokens: Upper bound on generated tokens; at least 1.
        top_p: Nucleus-sampling mass; clamped into [0.01, 0.99].
        repetition_penalty: Passed through unchanged.

    Yields:
        The reply accumulated so far, after each received token.
    """
    input_prompt = _build_prompt(message, chatbot, system_prompt)
    temperature, top_p, max_new_tokens = _sanitize_sampling(temperature, top_p, max_new_tokens)
    token_stream = await client.text_generation(
        prompt=input_prompt,
        max_new_tokens=max_new_tokens,
        stream=True,
        best_of=1,
        temperature=temperature,
        top_p=top_p,
        do_sample=True,
        repetition_penalty=repetition_penalty,
    )
    partial_message = ""
    async for token in token_stream:
        partial_message += token
        yield partial_message


# No Stream - batch produce tokens using TGI inference endpoint
def predict_batch(message, chatbot, system_prompt="", temperature=0.1, max_new_tokens=4096, top_p=0.6, repetition_penalty=1.1):
    """Generate the whole reply with one blocking HTTP request to the TGI endpoint.

    Takes the same arguments as ``predict``.

    Returns:
        The generated text. If the response carries an ``error`` field, that
        text plus a retry hint. Otherwise a short failure message; failure
        details are also printed to the log.
    """
    input_prompt = _build_prompt(message, chatbot, system_prompt)
    temperature, top_p, max_new_tokens = _sanitize_sampling(temperature, top_p, max_new_tokens)
    print(f"input_prompt - {input_prompt}")
    data = {
        "inputs": input_prompt,
        "parameters": {
            "max_new_tokens": max_new_tokens,
            "temperature": temperature,
            "top_p": top_p,
            "repetition_penalty": repetition_penalty,
            "do_sample": True,
        },
    }
    try:
        # The timeout keeps a stalled endpoint from blocking this worker forever.
        response = requests.post(api_url, headers=headers, json=data, timeout=_REQUEST_TIMEOUT_S)
    except requests.RequestException as err:
        print(f"Request failed: {err}")
        return _BATCH_FAILURE_MESSAGE
    if response.status_code != 200:
        print(f"Request failed with status code {response.status_code}")
        return f"Request failed with status code {response.status_code}. Please refresh and try again."
    try:
        json_obj = response.json()
    except ValueError:
        # requests raises json's or simplejson's decode error depending on the
        # environment; both subclass ValueError.
        print(f"Failed to decode response as JSON: {response.text}")
        return _BATCH_FAILURE_MESSAGE
    # Expected shape is a list whose first element holds the result; tolerate
    # a bare object (e.g. {"error": ...}) instead of crashing on [0].
    result = json_obj[0] if isinstance(json_obj, list) and json_obj else json_obj
    if isinstance(result, dict):
        if result.get("generated_text"):
            return result["generated_text"]
        if "error" in result:
            return f"{result['error']} Please refresh and try again with smaller input prompt"
    print(f"Unexpected response: {result}")
    return _BATCH_FAILURE_MESSAGE
def vote(data: gr.LikeData):
    """Print a like/dislike event from one of the chatbots.

    Args:
        data: Gradio like-event payload; ``liked`` selects the up/down wording
            and ``value`` is the rated message.
    """
    # f-strings format any value type; "..." + data.value would raise
    # TypeError if the rated message is not a plain str.
    if data.liked:
        print(f"You upvoted this response: {data.value}")
    else:
        print(f"You downvoted this response: {data.value}")
# Extra controls attached to each chat interface. Their order matches the
# trailing parameters of predict / predict_batch:
# system_prompt, temperature, max_new_tokens, top_p, repetition_penalty.
additional_inputs = [
    gr.Textbox("", label="Optional system prompt"),
    gr.Slider(
        label="Temperature",
        value=0.9,
        # 0.0 is selectable here; the handlers raise it to 0.01 before sending.
        minimum=0.0,
        maximum=1.0,
        step=0.05,
        interactive=True,
        info="Higher values produce more diverse outputs",
    ),
    gr.Slider(
        label="Max new tokens",
        value=256,
        # TGI rejects max_new_tokens == 0, so the smallest choice is one step.
        minimum=64,
        maximum=4096,
        step=64,
        interactive=True,
        info="The maximum number of new tokens",
    ),
    gr.Slider(
        label="Top-p (nucleus sampling)",
        value=0.6,
        # TGI requires 0 < top_p < 1, so both ends stay strictly inside that range.
        minimum=0.05,
        maximum=0.95,
        step=0.05,
        interactive=True,
        info="Higher values sample more low-probability tokens",
    ),
    gr.Slider(
        label="Repetition penalty",
        value=1.2,
        minimum=1.0,
        maximum=2.0,
        step=0.05,
        interactive=True,
        info="Penalize repeated tokens",
    ),
]
# One chat pane per tab, distinguished by their avatar image pairs.
chatbot_stream = gr.Chatbot(avatar_images=('user.png', 'bot2.png'), bubble_full_width=False)
chatbot_batch = gr.Chatbot(avatar_images=('user1.png', 'bot1.png'), bubble_full_width=False)


def _make_chat_interface(handler, chat_pane):
    """Wrap *handler* in a ChatInterface that uses *chat_pane* plus the shared page settings.

    Both tabs share the title, description, CSS, examples and extra inputs;
    each gets its own fresh message textbox.
    """
    # NOTE(review): both interfaces receive the same component instances from
    # `additional_inputs` — confirm the Gradio version in use shows the
    # settings controls in both tabs.
    return gr.ChatInterface(
        handler,
        title=title,
        description=description,
        textbox=gr.Textbox(),
        chatbot=chat_pane,
        css=css,
        examples=examples,
        additional_inputs=additional_inputs,
    )


chat_interface_stream = _make_chat_interface(predict, chatbot_stream)
chat_interface_batch = _make_chat_interface(predict_batch, chatbot_batch)
# Gradio Demo: one tab per inference mode, each rendering its prebuilt ChatInterface.
with gr.Blocks() as demo:
    with gr.Tab("Streaming"):
        # streaming chatbot; like/dislike clicks are handled by vote()
        chatbot_stream.like(vote, None, None)
        chat_interface_stream.render()
    with gr.Tab("Batch"):
        # non-streaming chatbot; like/dislike clicks are handled by vote()
        chatbot_batch.like(vote, None, None)
        chat_interface_batch.render()
# Enable request queuing (max_size=2) and start the server.
demo.queue(max_size=2).launch()