import gradio as gr
import os
import json
import uuid
import threading
import time
import re
from dotenv import load_dotenv
from openai import OpenAI
from realtime_transcriber import WebSocketClient, connections, WEBSOCKET_URI, WEBSOCKET_HEADERS
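
# `realtime_transcriber` is a local module (not shown here). From its use below it
# is expected to provide a WebSocketClient(uri, headers, client_id) with a .run()
# loop, an .enqueue_audio_chunk(sample_rate, samples) method and a .transcript
# string, plus a shared `connections` registry and websocket connection settings.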
# ------------------ Load Secrets ------------------
load_dotenv()
OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")
ASSISTANT_ID = os.getenv("ASSISTANT_ID")

if not OPENAI_API_KEY or not ASSISTANT_ID:
    raise ValueError("Missing OPENAI_API_KEY or ASSISTANT_ID")

client = OpenAI(api_key=OPENAI_API_KEY)

# Maps each Gradio session ID to its OpenAI Assistants thread ID.
session_threads = {}
# ------------------ Chat Logic ------------------
def reset_session():
    session_id = str(uuid.uuid4())
    session_threads[session_id] = client.beta.threads.create().id
    return session_id
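
# One Assistants API round trip: post the user's message to the session's thread,
# start a run, poll until it finishes, then read back the assistant's reply.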
def process_chat(message, history, session_id):
    thread_id = session_threads.get(session_id)
    if not thread_id:
        thread_id = client.beta.threads.create().id
        session_threads[session_id] = thread_id

    client.beta.threads.messages.create(thread_id=thread_id, role="user", content=message)
    run = client.beta.threads.runs.create(thread_id=thread_id, assistant_id=ASSISTANT_ID)

    # Poll until the run reaches a terminal state; waiting only for "completed"
    # would spin forever if the run fails, is cancelled, or expires.
    while True:
        status = client.beta.threads.runs.retrieve(thread_id=thread_id, run_id=run.id).status
        if status == "completed":
            break
        if status in ("failed", "cancelled", "expired", "incomplete"):
            return f"⚠️ Run ended with status: {status}"
        time.sleep(1)

    # messages.list returns newest-first, so the first assistant message
    # found is the reply to the run that just finished.
    messages = client.beta.threads.messages.list(thread_id=thread_id)
    for msg in messages.data:
        if msg.role == "assistant":
            return msg.content[0].text.value
    return "⚠️ Assistant did not respond."
def extract_image_url(text):
    match = re.search(r'https://raw\.githubusercontent\.com/[^\s"]+\.png', text)
    return match.group(0) if match else None
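
# Example (hypothetical URL, for illustration only):
#   extract_image_url("See https://raw.githubusercontent.com/acme/docs/main/page1.png")
#   -> "https://raw.githubusercontent.com/acme/docs/main/page1.png"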
def handle_chat(message, history, session_id):
    response = process_chat(message, history, session_id)
    history.append((message, response))
    image = extract_image_url(response)
    return history, image
# ------------------ Voice Logic ------------------
def create_websocket_client():
    client_id = str(uuid.uuid4())
    connections[client_id] = WebSocketClient(WEBSOCKET_URI, WEBSOCKET_HEADERS, client_id)
    # Run each client on a daemon thread so it never blocks the UI
    # and exits with the app process.
    threading.Thread(target=connections[client_id].run, daemon=True).start()
    return client_id
def clear_transcript(client_id):
    if client_id in connections:
        connections[client_id].transcript = ""
    return ""
def send_audio_chunk(audio, client_id):
    if client_id not in connections:
        return "Initializing connection..."
    sr, y = audio
    connections[client_id].enqueue_audio_chunk(sr, y)
    return connections[client_id].transcript
# ------------------ UI ------------------
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown("# 🧠 Document AI + 🎙️ Voice Assistant")

    # Per-visitor state, populated on page load (see demo.load below).
    session_id = gr.State()
    client_id = gr.State()

    with gr.Row():
        image_display = gr.Image(label="📄 Extracted Document Image", show_label=True, height=360)
        with gr.Column():
            chatbot = gr.Chatbot(label="💬 Document Assistant", height=360)
            text_input = gr.Textbox(label="Ask about the document", placeholder="e.g. What is clause 3.2?")
            send_btn = gr.Button("Send")

    send_btn.click(handle_chat, inputs=[text_input, chatbot, session_id], outputs=[chatbot, image_display])
    text_input.submit(handle_chat, inputs=[text_input, chatbot, session_id], outputs=[chatbot, image_display])
    # Toggle Section
    with gr.Accordion("🎤 Or Use Voice Instead", open=False):
        with gr.Row():
            transcript_box = gr.Textbox(label="Live Transcript", lines=7, interactive=False, autoscroll=True)
        with gr.Row():
            # Streaming audio input only works from the microphone source.
            mic_input = gr.Audio(sources=["microphone"], streaming=True)
            clear_button = gr.Button("Clear Transcript")

        mic_input.stream(fn=send_audio_chunk, inputs=[mic_input, client_id], outputs=transcript_box)
        clear_button.click(fn=clear_transcript, inputs=[client_id], outputs=transcript_box)

    # One chat session and one websocket client per page load; a gr.State default
    # created at import time would be shared by every visitor.
    demo.load(fn=reset_session, outputs=session_id)
    demo.load(fn=create_websocket_client, outputs=client_id)

demo.launch()
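
# Local run (a sketch of assumptions, not from the original file): put
# OPENAI_API_KEY and ASSISTANT_ID in a .env file next to this script, then
# start it with `python app.py` and open the URL Gradio prints.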