tommytracx committed · Commit c6552d6 · verified · 1 Parent(s): 5b11611

Update app.py

Files changed (1):
  1. app.py (+236, -34)
app.py CHANGED
@@ -1,41 +1,243 @@
  import gradio as gr
- from fastapi import FastAPI, UploadFile, File, Request
- from fastapi.middleware.cors import CORSMiddleware
- from fastapi.responses import StreamingResponse
  from fastapi.staticfiles import StaticFiles
- from app.agent import process_text
- from app.speech_to_text import transcribe_audio
- from app.text_to_speech import synthesize_speech
- import io

- app = FastAPI()

- app.add_middleware(
-     CORSMiddleware,
-     allow_origins=["*"],
-     allow_methods=["*"],
-     allow_headers=["*"],
- )

- app.mount("/", StaticFiles(directory="frontend", html=True), name="frontend")

- @app.post("/transcribe")
  async def transcribe(file: UploadFile = File(...)):
-     audio_bytes = await file.read()
-     text = transcribe_audio(audio_bytes)
-     return {"transcription": text}
-
- @app.post("/query")
- async def query_agent(request: Request):
-     data = await request.json()
-     input_text = data.get("input_text", "")
-     response = process_text(input_text)
-     return {"response": response}
-
- @app.get("/speak")
- async def speak(text: str):
-     audio = synthesize_speech(text)
-     return StreamingResponse(io.BytesIO(audio), media_type="audio/wav")
-
- # Required for Hugging Face Spaces
- gradio_app = gr.mount_gradio_app(app, None)
  import gradio as gr
+ import fastapi
  from fastapi.staticfiles import StaticFiles
+ from fastapi.responses import HTMLResponse, FileResponse
+ from fastapi import FastAPI, Request, Form, UploadFile, File
+ import os
+ import time
+ import logging
+ import json
+ import shutil
+ import uvicorn
+ from pathlib import Path

+ # Setup logging
+ logging.basicConfig(level=logging.INFO)
+ logger = logging.getLogger(__name__)

+ # Create the FastAPI app
+ app = FastAPI(title="AGI Telecom POC")

+ # Create static directory if it doesn't exist
+ static_dir = Path("static")
+ static_dir.mkdir(exist_ok=True)

+ # Copy index.html from templates to static if it doesn't exist
+ html_template = Path("templates/index.html")
+ static_html = static_dir / "index.html"
+ if html_template.exists() and not static_html.exists():
+     shutil.copy(html_template, static_html)
+
+ # Mount static files
+ app.mount("/static", StaticFiles(directory="static"), name="static")
+
+ # Mock data and functions to simulate the real implementation
+ SESSIONS = {}
+
+ def generate_session_id():
+     """Generate a unique session ID."""
+     import uuid
+     return str(uuid.uuid4())
+
+ def mock_transcribe(audio_bytes):
+     """Mock function to simulate speech-to-text."""
+     # In production, this would use Whisper
+     logger.info("Transcribing audio...")
+     time.sleep(1)  # Simulate processing time
+     return "This is a mock transcription of the audio."
+
+ def mock_agent_response(text, session_id="default"):
+     """Mock function to simulate agent reasoning."""
+     # In production, this would use a real LLM
+     logger.info(f"Processing query: {text}")
+     time.sleep(1.5)  # Simulate processing time
+
+     # Simple keyword-based responses
+     if "5g" in text.lower():
+         return "5G is the fifth generation of cellular networks, offering higher speeds, lower latency, and more capacity than previous generations."
+     elif "telecom" in text.lower():
+         return "Telecommunications (telecom) refers to the exchange of information over significant distances by electronic means."
+     elif "webrtc" in text.lower():
+         return "WebRTC (Web Real-Time Communication) is a free, open-source project that enables web browsers and mobile applications to have real-time communication via simple APIs."
+     else:
+         return "I'm an AI assistant specialized in telecom topics. Feel free to ask me about 5G, network technologies, or telecommunications in general."
+
+ def mock_synthesize_speech(text):
+     """Mock function to simulate text-to-speech."""
+     # In production, this would use a real TTS engine
+     logger.info("Synthesizing speech...")
+     time.sleep(0.5)  # Simulate processing time
+
+     # Create a dummy audio file
+     import numpy as np
+     from scipy.io.wavfile import write
+
+     sample_rate = 22050
+     duration = 2  # seconds
+     t = np.linspace(0, duration, int(sample_rate * duration), endpoint=False)
+     audio = np.sin(2 * np.pi * 440 * t) * 0.3
+
+     output_file = "temp_audio.wav"
+     write(output_file, sample_rate, audio.astype(np.float32))
+
+     with open(output_file, "rb") as f:
+         audio_bytes = f.read()
+
+     # Clean up
+     os.remove(output_file)
+
+     return audio_bytes
+
+ # Routes for the API
+ @app.get("/", response_class=HTMLResponse)
+ async def root():
+     """Serve the main UI."""
+     return FileResponse("static/index.html")
+
+ @app.post("/api/transcribe")
  async def transcribe(file: UploadFile = File(...)):
+     """Transcribe audio to text."""
+     try:
+         audio_bytes = await file.read()
+         text = mock_transcribe(audio_bytes)
+         return {"transcription": text}
+     except Exception as e:
+         logger.error(f"Transcription error: {str(e)}")
+         return {"error": f"Failed to transcribe audio: {str(e)}"}
+
+ @app.post("/api/query")
+ async def query_agent(input_text: str = Form(...), session_id: str = Form("default")):
+     """Process a text query with the agent."""
+     try:
+         response = mock_agent_response(input_text, session_id)
+         return {"response": response}
+     except Exception as e:
+         logger.error(f"Query error: {str(e)}")
+         return {"error": f"Failed to process query: {str(e)}"}
+
+ @app.post("/api/speak")
+ async def speak(text: str = Form(...)):
+     """Convert text to speech."""
+     try:
+         audio_bytes = mock_synthesize_speech(text)
+         # mock_synthesize_speech has already removed its temp file, so
+         # return the audio bytes directly rather than serving the deleted file
+         return fastapi.Response(content=audio_bytes, media_type="audio/wav")
+     except Exception as e:
+         logger.error(f"Speech synthesis error: {str(e)}")
+         return {"error": f"Failed to synthesize speech: {str(e)}"}
+
+ @app.post("/api/session")
+ async def create_session():
+     """Create a new session."""
+     session_id = generate_session_id()
+     SESSIONS[session_id] = {"created_at": time.time()}
+     return {"session_id": session_id}
+
+ # Gradio interface
+ with gr.Blocks(title="AGI Telecom POC", css="footer {visibility: hidden}") as interface:
+     gr.Markdown("# AGI Telecom POC Demo")
+     gr.Markdown("This is a demonstration of the AGI Telecom Proof of Concept. The full interface is available via the direct API.")
+
+     with gr.Row():
+         with gr.Column():
+             # Input components
+             audio_input = gr.Audio(label="Voice Input", type="filepath")
+             text_input = gr.Textbox(label="Text Input", placeholder="Type your message here...", lines=2)
+
+             # Session management
+             session_id = gr.Textbox(label="Session ID", value="default")
+             new_session_btn = gr.Button("New Session")
+
+             # Action buttons
+             with gr.Row():
+                 transcribe_btn = gr.Button("Transcribe Audio")
+                 query_btn = gr.Button("Send Query")
+                 speak_btn = gr.Button("Speak Response")
+
+         with gr.Column():
+             # Output components
+             transcription_output = gr.Textbox(label="Transcription", lines=2)
+             response_output = gr.Textbox(label="Agent Response", lines=5)
+             audio_output = gr.Audio(label="Voice Response", autoplay=True)
+
+             # Status and info
+             status_output = gr.Textbox(label="Status", value="Ready")
+
+     # Link components with functions
+     def update_session():
+         new_id = generate_session_id()
+         status = f"Created new session: {new_id}"
+         return new_id, status
+
+     new_session_btn.click(
+         update_session,
+         outputs=[session_id, status_output]
+     )
+
+     def process_audio(audio_path, session):
+         if not audio_path:
+             return "No audio provided", "", None, "Error: No audio input"
+
+         try:
+             with open(audio_path, "rb") as f:
+                 audio_bytes = f.read()
+
+             # Transcribe
+             text = mock_transcribe(audio_bytes)
+
+             # Get response
+             response = mock_agent_response(text, session)
+
+             # Synthesize
+             audio_bytes = mock_synthesize_speech(response)
+
+             temp_file = "temp_response.wav"
+             with open(temp_file, "wb") as f:
+                 f.write(audio_bytes)
+
+             return text, response, temp_file, "Processed successfully"
+         except Exception as e:
+             logger.error(f"Error: {str(e)}")
+             return "", "", None, f"Error: {str(e)}"
+
+     transcribe_btn.click(
+         lambda audio_path: mock_transcribe(open(audio_path, "rb").read()) if audio_path else "No audio provided",
+         inputs=[audio_input],
+         outputs=[transcription_output]
+     )
+
+     query_btn.click(
+         lambda text, session: mock_agent_response(text, session),
+         inputs=[text_input, session_id],
+         outputs=[response_output]
+     )
+
+     def speak_response(text):
+         # Write the synthesized audio to a file Gradio can play back;
+         # the previous lambda returned a path that may not exist yet
+         with open("temp_response.wav", "wb") as f:
+             f.write(mock_synthesize_speech(text))
+         return "temp_response.wav"
+
+     speak_btn.click(
+         speak_response,
+         inputs=[response_output],
+         outputs=[audio_output]
+     )
+
+     # Full process
+     audio_input.change(
+         process_audio,
+         inputs=[audio_input, session_id],
+         outputs=[transcription_output, response_output, audio_output, status_output]
+     )
+
+ # Mount Gradio app
+ app = gr.mount_gradio_app(app, interface, path="/gradio")
+
+ # Run the app
+ if __name__ == "__main__":
+     # Check if running on HF Spaces
+     if os.environ.get("SPACE_ID"):
+         # Running on HF Spaces - use their port
+         port = int(os.environ.get("PORT", 7860))
+         uvicorn.run(app, host="0.0.0.0", port=port)
+     else:
+         # Running locally
+         uvicorn.run(app, host="0.0.0.0", port=8000)
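
For reference, the new /api endpoints can be exercised with a short client script. This is a minimal sketch, not part of the commit: it assumes the app is running locally on port 8000 (the non-Spaces branch above) and that the `requests` package is installed; "sample.wav" is a placeholder file name.

import requests

BASE = "http://localhost:8000"  # local, non-Spaces run as configured above

# Create a session, then send a text query as form data
session_id = requests.post(f"{BASE}/api/session").json()["session_id"]
reply = requests.post(
    f"{BASE}/api/query",
    data={"input_text": "What is 5G?", "session_id": session_id},
).json()
print(reply["response"])

# Upload an audio file for (mock) transcription
with open("sample.wav", "rb") as f:  # placeholder audio file
    print(requests.post(f"{BASE}/api/transcribe", files={"file": f}).json())

# Request (mock) speech synthesis and save the returned WAV bytes
audio = requests.post(f"{BASE}/api/speak", data={"text": reply["response"]})
with open("reply.wav", "wb") as out:
    out.write(audio.content)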