import html
import os
import re
import time
import uuid
from typing import Dict, List

import gradio as gr
import markdown
import requests
import torch
from pygments import highlight
from pygments.formatters import HtmlFormatter
from pygments.lexers import get_lexer_by_name
from transformers import AutoTokenizer, AutoModelForCausalLM
|
|
|
|
|
# Hugging Face model id to load; the AI_MODEL_ID environment variable overrides it.
AI_MODEL_ID = os.getenv("AI_MODEL_ID", "TheBloke/Mistral-7B-Instruct-v0.2-GPTQ")
# Optional token read from HUGGINGFACE_API_TOKEN; None when the variable is unset.
API_TOKEN = os.getenv("HUGGINGFACE_API_TOKEN")
# Per-session message cap consulted by process_message.
MAX_HISTORY_LENGTH = 10
# Use the GPU when torch reports one, otherwise fall back to the CPU.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# In-memory conversation store: session id -> list of {"role", "content"} dicts.
active_sessions: Dict[str, List[Dict]] = {}
|
|
|
|
|
@torch.inference_mode()
def initialize_model():
    """Load the causal language model and tokenizer named by ``AI_MODEL_ID``.

    Returns:
        ``(model, tokenizer)`` on success. If anything raises while loading,
        the error is printed and ``(None, None)`` is returned instead.
    """
    try:
        print(f"Loading model {AI_MODEL_ID} on {DEVICE}...")
        # Half precision on the GPU, full precision on the CPU.
        if DEVICE == "cuda":
            dtype = torch.float16
        else:
            dtype = torch.float32
        loaded_tokenizer = AutoTokenizer.from_pretrained(AI_MODEL_ID)
        loaded_model = AutoModelForCausalLM.from_pretrained(
            AI_MODEL_ID,
            device_map=DEVICE,
            torch_dtype=dtype,
        )
        print("Model loaded successfully!")
        return loaded_model, loaded_tokenizer
    except Exception as e:
        print(f"Error loading model: {e}")
        return None, None


# Loaded once when the module is imported; both names are None if loading failed.
model, tokenizer = initialize_model()
|
|
|
|
|
def format_code_blocks(text):
    """Replace fenced Markdown code blocks in *text* with HTML.

    Each fenced block (three backticks, optional language tag, newline, code,
    newline, three backticks) is syntax-highlighted with Pygments and wrapped
    in ``<div class="code-block">``. Blocks without a language tag are treated
    as Python. Text outside fenced blocks is returned unchanged.

    If highlighting fails for any reason (for example an unknown language),
    the block is emitted as a plain ``<pre><code>`` element with the code
    HTML-escaped, so model output cannot inject markup into the page.

    Args:
        text: String that may contain fenced code blocks.

    Returns:
        The string with every fenced block replaced by HTML.
    """
    def replace_code_block(match):
        language = match.group(1) or "python"
        code = match.group(2)
        try:
            lexer = get_lexer_by_name(language, stripall=True)
            try:
                formatter = HtmlFormatter(style="github", cssclass="syntax-highlight")
            except ValueError:
                # Pygments reports an unknown style name with ClassNotFound,
                # a ValueError subclass. Use the default style rather than
                # giving up on highlighting altogether.
                formatter = HtmlFormatter(cssclass="syntax-highlight")
            result = highlight(code, lexer, formatter)
            return f'<div class="code-block">{result}</div>'
        except Exception:
            # Best-effort fallback: show the code verbatim. `language` is
            # limited to \w characters by the pattern, but `code` is arbitrary
            # text and must be escaped before being placed in HTML.
            return f'<pre><code class="{language}">{html.escape(code)}</code></pre>'

    pattern = r'```(\w+)?\n([\s\S]+?)\n```'
    return re.sub(pattern, replace_code_block, text)
|
|
|
|
|
def _render_response(text):
    """Render model output as HTML: Markdown for prose, highlighting for fenced code."""
    # Fenced blocks are split out first. Running Markdown over the whole text
    # would consume the backtick fences, leaving nothing for
    # format_code_blocks to match.
    pieces = re.split(r'(```(?:\w+)?\n[\s\S]+?\n```)', text)
    rendered = []
    for index, piece in enumerate(pieces):
        if index % 2:
            # Odd indices hold the captured fenced blocks.
            rendered.append(format_code_blocks(piece))
        elif piece.strip():
            rendered.append(markdown.markdown(piece))
    return "\n".join(rendered)


def _trim_history(conversation):
    """Drop the oldest messages so at most MAX_HISTORY_LENGTH remain."""
    excess = len(conversation) - MAX_HISTORY_LENGTH
    if excess > 0:
        del conversation[:excess]


@torch.inference_mode()
def process_message(message, history, session_id):
    """Generate a reply to *message* and stream progress as strings.

    The message is appended to the session's stored conversation, the
    conversation is turned into a prompt, and the model's reply is appended
    to the conversation and yielded as HTML.

    Args:
        message: The user's text.
        history: Unused; the conversation is read from ``active_sessions``.
        session_id: Key into ``active_sessions``; created if missing.

    Yields:
        Two status strings, then either the rendered reply or an HTML error
        message. Everything is yielded (never returned) because a generator's
        return value is not delivered to the consumer iterating over it.
    """
    conversation = active_sessions.setdefault(session_id, [])
    conversation.append({"role": "user", "content": message})
    # Trim after every append so the stored history stays bounded.
    _trim_history(conversation)

    prompt = format_prompt(conversation)

    # NOTE(review): the leading glyph in the two status strings looks like a
    # mis-decoded emoji; kept as found because the original character cannot
    # be determined from this file.
    yield "β Thinking..."
    time.sleep(0.5)
    yield "β Generating response..."

    if model is None or tokenizer is None:
        # initialize_model returns (None, None) when loading failed.
        error_msg = "Error generating response: model is not loaded"
        print(error_msg)
        yield f"<span style='color: red;'>{error_msg}</span>"
        return

    try:
        inputs = tokenizer(prompt, return_tensors="pt").to(DEVICE)
        output = model.generate(
            inputs["input_ids"],
            attention_mask=inputs.get("attention_mask"),
            max_length=2048,
            temperature=0.7,
            top_p=0.9,
            do_sample=True,
            pad_token_id=tokenizer.eos_token_id
        )

        # Decode only the newly generated tokens. Decoding the full sequence
        # and stripping the prompt by string prefix depends on the tokenizer
        # reproducing the prompt text exactly.
        prompt_length = inputs["input_ids"].shape[-1]
        response = tokenizer.decode(output[0][prompt_length:], skip_special_tokens=True)

        # Still strips the prompt if the model echoed it verbatim.
        ai_response = extract_ai_response(response, prompt)

        conversation.append({"role": "assistant", "content": ai_response})
        _trim_history(conversation)

        yield _render_response(ai_response)
    except Exception as e:
        error_msg = f"Error generating response: {str(e)}"
        print(error_msg)
        yield f"<span style='color: red;'>{error_msg}</span>"
|
|
|
|
|
def format_prompt(messages):
    """Flatten a conversation into a single prompt string.

    Each message becomes one line: ``USER: ...`` when its role is ``"user"``
    and ``ASSISTANT: ...`` for any other role. The result ends with an open
    ``ASSISTANT: `` cue for the model to continue.

    Args:
        messages: Iterable of dicts with ``"role"`` and ``"content"`` keys.

    Returns:
        The assembled prompt.
    """
    def render(entry):
        if entry["role"] == "user":
            return f"USER: {entry['content']}\n"
        return f"ASSISTANT: {entry['content']}\n"

    return "".join(map(render, messages)) + "ASSISTANT: "
|
|
|
|
|
def extract_ai_response(full_response, prompt):
    """Return the model's reply with the echoed prompt removed.

    Args:
        full_response: Decoded model output.
        prompt: The prompt that was sent to the model.

    Returns:
        ``full_response`` without a leading ``prompt`` (when it starts with
        one), stripped of surrounding whitespace.
    """
    reply = full_response
    if reply.startswith(prompt):
        reply = reply[len(prompt):]
    return reply.strip()
|
|
|
|
|
def create_new_session():
    """Register a fresh, empty session in ``active_sessions``.

    Returns:
        A ``(session_id, history)`` pair: the new random UUID string and an
        empty list.
    """
    new_id = str(uuid.uuid4())
    active_sessions.update({new_id: []})
    empty_history = []
    return new_id, empty_history
|
|
|
|
|
def debug_code(code, session_id, delay_scale=1.0):
    """Syntax-check *code* and stream human-readable status messages.

    The code is only compiled, never executed. After a successful compile a
    few line-based heuristics look for suspicious lines. When the session
    exists in ``active_sessions``, a summary message is appended to its
    history on the no-issues path and on the syntax-error path.

    Args:
        code: Python source text to analyse.
        session_id: Session whose history receives the summary, if present.
        delay_scale: Multiplier applied to the pauses between status
            messages; ``0`` disables them. Defaults to the original timing.

    Yields:
        Status strings describing progress, detected issues, or errors.
    """
    try:
        # NOTE(review): the leading glyph below looks like a mis-decoded
        # emoji; kept as found because the original character is not certain.
        yield "π Analyzing code..."
        time.sleep(1 * delay_scale)

        compile(code, '<string>', 'exec')
        yield "✅ Syntax check passed"
        time.sleep(0.5 * delay_scale)

        lines = code.split('\n')
        issues = []

        for i, line in enumerate(lines):
            stripped = line.strip()
            if 'print(' in line and not stripped.endswith(')'):
                issues.append(f"Line {i+1}: Missing closing parenthesis in print statement")
            if '#' not in line and stripped.endswith(':'):
                # Compare against the next non-blank line and accept tabs as
                # indentation; a blank line or a tab-indented body after ':'
                # is valid Python and must not be reported.
                following = next((later for later in lines[i + 1:] if later.strip()), None)
                if following is not None and not following.startswith((' ', '\t')):
                    issues.append(f"Line {i+1}: Missing indentation after control statement")

        if issues:
            yield "🔴 Found potential issues:\n" + "\n".join(issues)
        else:
            yield "🟢 No obvious issues detected. Running code..."
            time.sleep(1 * delay_scale)

            if session_id in active_sessions:
                active_sessions[session_id].append({
                    "role": "assistant",
                    "content": f"I've analyzed your code and it looks good syntactically. Here are some tips for improvement:\n\n```python\n{code}\n```\n\nConsider adding more comments and error handling for better robustness."
                })

            yield "✅ Code analysis complete. The code appears to be valid Python code."
    except SyntaxError as e:
        error_msg = f"🔴 Syntax Error: {str(e)}"
        yield error_msg

        if session_id in active_sessions:
            active_sessions[session_id].append({
                "role": "assistant",
                "content": f"I found a syntax error in your code:\n\n```python\n{code}\n```\n\nError: {str(e)}\n\nPlease check your syntax and try again."
            })
    except Exception as e:
        error_msg = f"🔴 Error during analysis: {str(e)}"
        yield error_msg
|
|
|
|
|
# Stylesheet handed to gr.Blocks(css=...) in build_ui: page width, chat
# bubble colours, and the wrappers emitted by format_code_blocks
# (.code-block / .syntax-highlight).
custom_css = """
.container {max-width: 850px; margin: auto;}
.chat-message {padding: 12px; border-radius: 10px; margin-bottom: 10px; position: relative;}
.user-message {background-color: #e6f7ff; text-align: right; margin-left: 20%;}
.bot-message {background-color: #f2f2f2; margin-right: 20%;}
.timestamp {font-size: 0.7em; color: #888; position: absolute; bottom: 2px; right: 10px;}
.syntax-highlight {border-radius: 5px; padding: 10px !important; margin: 15px 0 !important; overflow-x: auto;}
.code-block {border: 1px solid #ddd; border-radius: 5px; margin: 10px 0;}
.typing-indicator {font-style: italic; color: #888;}
"""
|
|
|
|
|
def build_ui():
    """Build the Gradio Blocks interface and wire up its event handlers.

    Layout: a chat column (conversation, message box, Send / Clear Chat /
    Debug Code buttons) and a side column (New Session button, model info,
    temperature slider, status box).

    Returns:
        The configured ``gr.Blocks`` app, not yet launched.
    """
    with gr.Blocks(css=custom_css) as demo:
        gr.Markdown("# AI Chat with Code Capabilities")

        with gr.Row():
            with gr.Column(scale=3):
                chatbot = gr.Chatbot(
                    label="Conversation",
                    height=500,
                    elem_classes="container"
                )

                with gr.Row():
                    message_input = gr.Textbox(
                        label="Your message",
                        placeholder="Ask anything or paste code for debugging...",
                        lines=3
                    )

                with gr.Row():
                    submit_btn = gr.Button("Send", variant="primary")
                    clear_btn = gr.Button("Clear Chat")
                    debug_btn = gr.Button("Debug Code", variant="secondary")

            with gr.Column(scale=1):
                new_session_btn = gr.Button("New Session")
                session_info = gr.Textbox(label="Current Session", value="", visible=False)

                gr.Markdown(f"### Model Info\n- Using: {AI_MODEL_ID}\n- Device: {DEVICE}")

                # NOTE(review): this slider is displayed but not connected to
                # generation; process_message takes no temperature argument.
                temperature = gr.Slider(
                    minimum=0.1, maximum=1.5, value=0.7, step=0.1,
                    label="Temperature (Creativity)"
                )

                status_box = gr.Textbox(label="Status", value="Ready")

        session_id = gr.State(str(uuid.uuid4()))
        # Holds the submitted text after the textbox has been cleared, so the
        # generation step still receives the exact message the user typed.
        pending_message = gr.State("")

        # The State default above is evaluated once at build time and would be
        # shared by every visitor; assign a fresh id on each page load.
        demo.load(lambda: str(uuid.uuid4()), None, session_id)

        def on_submit(message, chat_history, sid):
            """Queue the user's message in the chat and clear the textbox."""
            chat_history = chat_history or []
            if not message.strip():
                return "", chat_history, "", "Ready"
            return "", chat_history + [[message, None]], message, "Generating response..."

        def stream_reply(message, chat_history, sid):
            """Stream process_message output into the last chat row."""
            chat_history = chat_history or []
            if not message or not chat_history:
                # Nothing was queued (blank submission); leave the chat as is.
                yield chat_history
                return
            for partial in process_message(message, chat_history, sid):
                chat_history[-1] = [chat_history[-1][0], partial]
                yield chat_history

        submit_btn.click(
            on_submit,
            [message_input, chatbot, session_id],
            [message_input, chatbot, pending_message, status_box]
        ).then(
            stream_reply,
            [pending_message, chatbot, session_id],
            chatbot
        ).then(
            lambda: "Ready",
            None,
            status_box
        )

        def on_debug(message, chat_history, sid):
            """Queue the code to debug in the chat."""
            chat_history = chat_history or []
            if not message.strip():
                return chat_history, "Please enter code to debug"

            return chat_history + [[message, None]], "Debugging code..."

        def stream_debug(code, chat_history, sid):
            """Stream debug_code output into the last chat row."""
            chat_history = chat_history or []
            if not code.strip() or not chat_history:
                yield chat_history
                return
            for update in debug_code(code, sid):
                chat_history[-1] = [chat_history[-1][0], update]
                yield chat_history

        debug_btn.click(
            on_debug,
            [message_input, chatbot, session_id],
            [chatbot, status_box]
        ).then(
            stream_debug,
            [message_input, chatbot, session_id],
            chatbot
        ).then(
            lambda: "Ready",
            None,
            status_box
        )

        def start_new_session():
            """Create a new session id with an empty stored history."""
            new_sid = str(uuid.uuid4())
            active_sessions[new_sid] = []
            return new_sid, [], f"New session started: {new_sid[:8]}...", "Ready"

        new_session_btn.click(
            start_new_session,
            None,
            [session_id, chatbot, session_info, status_box]
        )

        def clear_chat(sid):
            """Empty both the visible chat and the stored conversation."""
            # Without resetting the stored history, the "cleared" messages
            # would still be included in the next prompt.
            active_sessions[sid] = []
            return [], f"Session cleared: {sid[:8]}...", "Ready"

        clear_btn.click(
            clear_chat,
            [session_id],
            [chatbot, session_info, status_box]
        )

    return demo
|
|
|
|
|
def main():
    """Build the interface and serve it on every network interface, port 7860."""
    app = build_ui()

    # Public sharing and debug mode are both switched off.
    launch_options = {
        "server_name": "0.0.0.0",
        "server_port": 7860,
        "share": False,
        "debug": False,
    }
    app.queue(concurrency_count=5).launch(**launch_options)


if __name__ == "__main__":
    main()