""" | |
Main FastAPI application integrating all components with Hugging Face Inference Endpoint. | |
""" | |
import gradio as gr | |
import fastapi | |
from fastapi.staticfiles import StaticFiles | |
from fastapi.responses import HTMLResponse, FileResponse, JSONResponse | |
from fastapi import FastAPI, Request, Form, UploadFile, File | |
import os | |
import time | |
import logging | |
import json | |
import shutil | |
import uvicorn | |
from pathlib import Path | |
from typing import Dict, List, Optional, Any | |
import io | |
import numpy as np | |
from scipy.io.wavfile import write | |
# Import our modules | |
from local_llm import run_llm, run_llm_with_memory, clear_memory, get_memory_sessions, get_model_info, test_endpoint | |
# Setup logging | |
logging.basicConfig(level=logging.INFO) | |
logger = logging.getLogger(__name__) | |
# Create the FastAPI app
app = FastAPI(title="AGI Telecom POC")

# Create static directory if it doesn't exist
static_dir = Path("static")
static_dir.mkdir(exist_ok=True)

# Copy index.html from templates to static if it doesn't exist
html_template = Path("templates/index.html")
static_html = static_dir / "index.html"
if html_template.exists() and not static_html.exists():
    shutil.copy(html_template, static_html)

# Mount static files
app.mount("/static", StaticFiles(directory="static"), name="static")
# Helper functions for mock implementations
def mock_transcribe(audio_bytes):
    """Mock function to simulate speech-to-text."""
    logger.info("Transcribing audio...")
    time.sleep(0.5)  # Simulate processing time
    return "This is a mock transcription of the audio."


def mock_synthesize_speech(text):
    """Mock function to simulate text-to-speech."""
    logger.info("Synthesizing speech...")
    time.sleep(0.5)  # Simulate processing time
    # Create a dummy audio file
    sample_rate = 22050
    duration = 2  # seconds
    t = np.linspace(0, duration, int(sample_rate * duration), endpoint=False)
    audio = np.sin(2 * np.pi * 440 * t) * 0.3
    output_file = "temp_audio.wav"
    write(output_file, sample_rate, audio.astype(np.float32))
    with open(output_file, "rb") as f:
        audio_bytes = f.read()
    return audio_bytes
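
# NOTE: the two mock_* helpers above are stand-ins for real speech-to-text and
# text-to-speech services; swap them out when actual STT/TTS backends are wired in.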
# Routes for the API
# NOTE: route paths other than "/" and "/audio/{filename}" (which the code below
# implies) are assumed here and may differ from the original deployment.
@app.get("/")
async def root():
    """Serve the main UI."""
    return FileResponse("static/index.html")


@app.get("/health")
async def health_check():
    """Health check endpoint."""
    endpoint_status = test_endpoint()
    return {
        "status": "ok",
        "endpoint": endpoint_status
    }
@app.post("/transcribe")
async def transcribe(file: UploadFile = File(...)):
    """Transcribe audio to text."""
    try:
        audio_bytes = await file.read()
        text = mock_transcribe(audio_bytes)
        return {"transcription": text}
    except Exception as e:
        logger.error(f"Transcription error: {str(e)}")
        return JSONResponse(
            status_code=500,
            content={"error": f"Failed to transcribe audio: {str(e)}"}
        )
@app.post("/query")
async def query_agent(input_text: str = Form(...), session_id: str = Form("default")):
    """Process a text query with the agent."""
    try:
        response = run_llm_with_memory(input_text, session_id=session_id)
        logger.info(f"Query: {input_text[:30]}... Response: {response[:30]}...")
        return {"response": response}
    except Exception as e:
        logger.error(f"Query error: {str(e)}")
        return JSONResponse(
            status_code=500,
            content={"error": f"Failed to process query: {str(e)}"}
        )
@app.post("/speak")
async def speak(text: str = Form(...)):
    """Convert text to speech."""
    try:
        # mock_synthesize_speech writes temp_audio.wav as a side effect
        mock_synthesize_speech(text)
        return FileResponse(
            "temp_audio.wav",
            media_type="audio/wav",
            filename="response.wav"
        )
    except Exception as e:
        logger.error(f"Speech synthesis error: {str(e)}")
        return JSONResponse(
            status_code=500,
            content={"error": f"Failed to synthesize speech: {str(e)}"}
        )
@app.post("/sessions")
async def create_session():
    """Create a new session."""
    import uuid
    session_id = str(uuid.uuid4())
    clear_memory(session_id)
    return {"session_id": session_id}


@app.delete("/sessions/{session_id}")
async def delete_session(session_id: str):
    """Delete a session."""
    success = clear_memory(session_id)
    if success:
        return {"message": f"Session {session_id} cleared"}
    else:
        return JSONResponse(
            status_code=404,
            content={"error": f"Session {session_id} not found"}
        )


@app.get("/sessions")
async def list_sessions():
    """List all active sessions."""
    return {"sessions": get_memory_sessions()}


@app.get("/model-info")
async def model_info():
    """Get information about the model."""
    return get_model_info()
@app.post("/complete-flow")
async def complete_flow(
    request: Request,
    audio_file: UploadFile = File(None),
    text_input: str = Form(None),
    session_id: str = Form("default")
):
    """
    Complete flow: audio to text to agent to speech.
    """
    try:
        # If an audio file is provided, transcribe it
        if audio_file:
            audio_bytes = await audio_file.read()
            text_input = mock_transcribe(audio_bytes)
            logger.info(f"Transcribed input: {text_input[:30]}...")
        # Process with agent
        if not text_input:
            return JSONResponse(
                status_code=400,
                content={"error": "No input provided"}
            )
        response = run_llm_with_memory(text_input, session_id=session_id)
        logger.info(f"Agent response: {response[:30]}...")
        # Synthesize speech
        audio_bytes = mock_synthesize_speech(response)
        # Save audio to a temporary file
        temp_file = tempfile.NamedTemporaryFile(delete=False, suffix=".wav")
        temp_file.write(audio_bytes)
        temp_file.close()
        # Generate URL for audio
        host = request.headers.get("host", "localhost")
        scheme = request.headers.get("x-forwarded-proto", "http")
        audio_url = f"{scheme}://{host}/audio/{os.path.basename(temp_file.name)}"
        return {
            "input": text_input,
            "response": response,
            "audio_url": audio_url
        }
    except Exception as e:
        logger.error(f"Complete flow error: {str(e)}")
        return JSONResponse(
            status_code=500,
            content={"error": f"Failed to process: {str(e)}"}
        )
@app.get("/audio/{filename}")
async def get_audio(filename: str):
    """
    Serve temporary audio files.
    """
    try:
        # Ensure filename only contains safe characters
        import re
        if not re.match(r'^[a-zA-Z0-9_.-]+$', filename):
            return JSONResponse(
                status_code=400,
                content={"error": "Invalid filename"}
            )
        temp_dir = tempfile.gettempdir()
        file_path = os.path.join(temp_dir, filename)
        if not os.path.exists(file_path):
            return JSONResponse(
                status_code=404,
                content={"error": "File not found"}
            )
        return FileResponse(
            file_path,
            media_type="audio/wav",
            filename=filename
        )
    except Exception as e:
        logger.error(f"Audio serving error: {str(e)}")
        return JSONResponse(
            status_code=500,
            content={"error": f"Failed to serve audio: {str(e)}"}
        )
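
# Example API usage for a local run (route paths are the assumed ones noted above):
#   curl http://localhost:8000/health
#   curl -X POST -d "input_text=Hello" -d "session_id=default" http://localhost:8000/query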
# Gradio interface
with gr.Blocks(title="AGI Telecom POC", css="footer {visibility: hidden}") as interface:
    gr.Markdown("# AGI Telecom POC Demo")
    gr.Markdown("This is a demonstration of the AGI Telecom Proof of Concept using a Hugging Face Inference Endpoint.")
    with gr.Row():
        with gr.Column():
            # Input components
            audio_input = gr.Audio(label="Voice Input", type="filepath")
            text_input = gr.Textbox(label="Text Input", placeholder="Type your message here...", lines=2)
            # Session management
            session_id = gr.Textbox(label="Session ID", value="default")
            new_session_btn = gr.Button("New Session")
            # Action buttons
            with gr.Row():
                transcribe_btn = gr.Button("Transcribe Audio")
                query_btn = gr.Button("Send Query")
                speak_btn = gr.Button("Speak Response")
        with gr.Column():
            # Output components
            transcription_output = gr.Textbox(label="Transcription", lines=2)
            response_output = gr.Textbox(label="Agent Response", lines=5)
            audio_output = gr.Audio(label="Voice Response", autoplay=True)
            # Status and info
            status_output = gr.Textbox(label="Status", value="Ready")
            endpoint_status = gr.Textbox(label="Endpoint Status", value="Checking endpoint connection...")

    # Link components with functions
    def update_session():
        import uuid
        new_id = str(uuid.uuid4())
        clear_memory(new_id)
        status = f"Created new session: {new_id}"
        return new_id, status

    new_session_btn.click(
        update_session,
        outputs=[session_id, status_output]
    )
    def process_audio(audio_path, session):
        if not audio_path:
            return "No audio provided", "", None, "Error: No audio input"
        try:
            with open(audio_path, "rb") as f:
                audio_bytes = f.read()
            # Transcribe
            text = mock_transcribe(audio_bytes)
            # Get response
            response = run_llm_with_memory(text, session)
            # Synthesize
            audio_bytes = mock_synthesize_speech(response)
            temp_file = "temp_response.wav"
            with open(temp_file, "wb") as f:
                f.write(audio_bytes)
            return text, response, temp_file, "Processed successfully"
        except Exception as e:
            logger.error(f"Error: {str(e)}")
            return "", "", None, f"Error: {str(e)}"

    transcribe_btn.click(
        lambda audio_path: mock_transcribe(open(audio_path, "rb").read()) if audio_path else "No audio provided",
        inputs=[audio_input],
        outputs=[transcription_output]
    )
    query_btn.click(
        lambda text, session: run_llm_with_memory(text, session),
        inputs=[text_input, session_id],
        outputs=[response_output]
    )

    speak_btn.click(
        # mock_synthesize_speech writes temp_audio.wav as a side effect, so return that path
        lambda text: "temp_audio.wav" if mock_synthesize_speech(text) else None,
        inputs=[response_output],
        outputs=[audio_output]
    )

    # Full process
    audio_input.change(
        process_audio,
        inputs=[audio_input, session_id],
        outputs=[transcription_output, response_output, audio_output, status_output]
    )
    # Check endpoint on load
    def check_endpoint():
        status = test_endpoint()
        if status["status"] == "connected":
            return f"✅ Connected to endpoint: {status['message']}"
        else:
            return f"❌ Error connecting to endpoint: {status['message']}"

    # Run the endpoint check when the UI loads
    interface.load(check_endpoint, outputs=[endpoint_status])
# Mount Gradio app
app = gr.mount_gradio_app(app, interface, path="/gradio")

# Run the app
if __name__ == "__main__":
    # Check if running on HF Spaces
    if os.environ.get("SPACE_ID"):
        # Running on HF Spaces - use their port
        port = int(os.environ.get("PORT", 7860))
        uvicorn.run(app, host="0.0.0.0", port=port)
    else:
        # Running locally
        uvicorn.run(app, host="0.0.0.0", port=8000)
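
# For a local run (assuming this file is saved as app.py): `python app.py`, then open
# http://localhost:8000/ for the HTML UI and http://localhost:8000/gradio for the Gradio demo.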