Spaces: Running on CPU Upgrade
import asyncio
import json
import logging
import mimetypes
import os
import pathlib
import time
from typing import Dict, Any

from aiohttp import web, WSMsgType

from api_core import VideoGenerationAPI
from api_config import *
# Configure root logging at import time: INFO and above, with a timestamp,
# logger name and level on every record.
_LOG_FORMAT = '%(asctime)s - %(name)s - %(levelname)s - %(message)s'
logging.basicConfig(format=_LOG_FORMAT, level=logging.INFO)

# Module-level logger shared by the request handlers in this file.
logger = logging.getLogger(__name__)
async def process_generic_request(data: dict, ws: web.WebSocketResponse, api) -> None:
    """Handle general requests that don't fit into specialized queues.

    Supported actions:

    * ``heartbeat`` / ``get_user_role``: reply with the connection's
      ``user_role`` (``'anon'`` when the websocket has none set).
    * ``generate_caption`` / ``generate_thumbnail``: require non-empty
      ``params.title`` and ``params.description``. Reply with ``caption``
      (from ``api.generate_caption``) or ``thumbnailUrl`` (from
      ``api.generate_thumbnail``).

    Any other action gets an ``Unknown action`` error. Each call sends exactly
    one JSON reply on *ws* that echoes ``action`` and ``requestId``. Failures
    are logged and reported to the client rather than raised.

    Args:
        data: Decoded client message (a dict).
        ws: Connection to reply on.
        api: Object providing the async generation coroutines named above.
    """
    request_id = data.get('requestId')
    action = data.get('action')

    def error_response(message: str) -> dict:
        return {
            'action': action,
            'requestId': request_id,
            'success': False,
            'error': message
        }

    # action -> (name of the api coroutine to await, reply key for its result)
    generators = {
        'generate_caption': ('generate_caption', 'caption'),
        'generate_thumbnail': ('generate_thumbnail', 'thumbnailUrl'),
    }

    try:
        if action in ('heartbeat', 'get_user_role'):
            await ws.send_json({
                'action': action,
                'requestId': request_id,
                'success': True,
                'user_role': getattr(ws, 'user_role', 'anon')
            })
        # isinstance guard: a JSON list/object "action" is unhashable and
        # must fall through to "Unknown action", not raise TypeError.
        elif isinstance(action, str) and action in generators:
            params = data.get('params')
            # A missing, null or non-object "params" is the same client error
            # as missing fields, not an internal server error.
            if not isinstance(params, dict):
                params = {}
            title = params.get('title')
            description = params.get('description')
            if not title or not description:
                await ws.send_json(error_response('Missing title or description'))
                return
            method_name, result_key = generators[action]
            value = await getattr(api, method_name)(title, description)
            await ws.send_json({
                'action': action,
                'requestId': request_id,
                'success': True,
                result_key: value
            })
        else:
            await ws.send_json(error_response(f'Unknown action: {action}'))
    except Exception as e:
        # logger.exception keeps the traceback, which logger.error dropped.
        logger.exception(f"Error processing generic request: {str(e)}")
        try:
            await ws.send_json(error_response(f'Internal server error: {str(e)}'))
        except Exception as send_error:
            logger.error(f"Error sending error response: {send_error}")
async def process_search_queue(queue: asyncio.Queue, ws: web.WebSocketResponse, api):
    """Medium priority queue for search operations.

    Drains *queue* forever, one request at a time, until cancelled. Each
    request dict may carry ``requestId``, ``query``, ``searchCount`` and
    ``attemptCount``. A non-empty query is passed to ``api.search_video``.
    Exactly one ``'search'`` reply is sent on *ws* per request: either
    ``success: True`` with ``result``, or ``success: False`` with ``error``.
    Every dequeued item is acknowledged with ``queue.task_done()``.
    """
    while True:
        # Dequeue outside the try/finally. If we are cancelled while waiting,
        # there is no item to acknowledge, so task_done() must not run.
        data = await queue.get()
        try:
            request_id = data.get('requestId')
            raw_query = data.get('query')
            # Non-string queries (null, numbers, ...) count as empty instead of
            # failing on .strip().
            query = raw_query.strip() if isinstance(raw_query, str) else ''
            search_count = data.get('searchCount', 0)
            attempt_count = data.get('attemptCount', 0)

            logger.info(f"Processing search request: query='{query}', search_count={search_count}, attempt={attempt_count}")

            if not query:
                logger.warning(f"Empty query received in request: {data}")
                result = {
                    'action': 'search',
                    'requestId': request_id,
                    'success': False,
                    'error': 'No search query provided'
                }
            else:
                try:
                    search_result = await api.search_video(
                        query,
                        search_count=search_count,
                        attempt_count=attempt_count
                    )
                except Exception as e:
                    logger.error(f"Search error for query '{query}' (#{search_count}, attempt {attempt_count}): {str(e)}")
                    result = {
                        'action': 'search',
                        'requestId': request_id,
                        'success': False,
                        'error': f'Search error: {str(e)}'
                    }
                else:
                    if search_result:
                        logger.info(f"Search successful for query '{query}' (#{search_count})")
                        result = {
                            'action': 'search',
                            'requestId': request_id,
                            'success': True,
                            'result': search_result
                        }
                    else:
                        logger.warning(f"No results found for query '{query}' (#{search_count})")
                        result = {
                            'action': 'search',
                            'requestId': request_id,
                            'success': False,
                            'error': 'No results found'
                        }
            await ws.send_json(result)
        except Exception as e:
            logger.exception(f"Error in search queue processor: {str(e)}")
            try:
                # `data` always belongs to the current iteration here.
                await ws.send_json({
                    'action': 'search',
                    'requestId': data.get('requestId'),
                    'success': False,
                    'error': f'Internal server error: {str(e)}'
                })
            except Exception as send_error:
                logger.error(f"Error sending error response: {send_error}")
        finally:
            queue.task_done()
async def process_chat_queue(queue: asyncio.Queue, ws: web.WebSocketResponse):
    """High priority queue for chat operations.

    Drains *queue* forever, one request at a time, until cancelled. The
    ``join_chat``, ``chat_message`` and ``leave_chat`` actions are forwarded
    to the matching ``handle_*`` coroutine on ``ws.app['api']``, and its
    return value is sent back on *ws*. Any other action gets an explicit
    ``success: False`` reply. Handler failures are logged and reported as
    ``Chat error: ...``. Every dequeued item is acknowledged with
    ``queue.task_done()``.
    """
    # action -> name of the api coroutine that handles it
    handler_names = {
        'join_chat': 'handle_join_chat',
        'chat_message': 'handle_chat_message',
        'leave_chat': 'handle_leave_chat',
    }
    while True:
        data = await queue.get()
        try:
            api = ws.app['api']
            action = data.get('action')
            handler_name = handler_names.get(action) if isinstance(action, str) else None
            if handler_name is None:
                # Reply explicitly. Otherwise `result` would be unbound, or
                # would still hold the previous message's reply and be re-sent.
                await ws.send_json({
                    'action': action,
                    'requestId': data.get('requestId'),
                    'success': False,
                    'error': f'Unknown chat action: {action}'
                })
                continue
            result = await getattr(api, handler_name)(data, ws)
            await ws.send_json(result)
        except Exception as e:
            logger.exception(f"Error processing chat request: {e}")
            try:
                await ws.send_json({
                    'action': data.get('action'),
                    'requestId': data.get('requestId'),
                    'success': False,
                    'error': f'Chat error: {str(e)}'
                })
            except Exception as send_error:
                logger.error(f"Error sending error response: {send_error}")
        finally:
            queue.task_done()
async def process_video_queue(queue: asyncio.Queue, ws: web.WebSocketResponse):
    """Process multiple video generation requests in parallel.

    Runs until cancelled. Concurrency is capped at one request per entry in
    ``VIDEO_ROUND_ROBIN_ENDPOINT_URLS``, with a minimum of one. The loop blocks
    on the queue instead of polling it. Each request's result (or error) is
    sent on *ws* as a ``'generate_video'`` reply. When this coroutine is
    cancelled (connection closed), every in-flight generation task is
    cancelled and awaited so no work outlives the client.
    """
    # Match client's max concurrent generations. At least one slot, so an
    # empty endpoint list yields error replies instead of a silently stuck queue.
    max_concurrent = max(1, len(VIDEO_ROUND_ROBIN_ENDPOINT_URLS))
    slots = asyncio.Semaphore(max_concurrent)
    active_tasks = set()

    async def process_single_request(data):
        try:
            api = ws.app['api']
            title = data.get('title', '')
            description = data.get('description', '')
            video_prompt_prefix = data.get('video_prompt_prefix', '')
            options = data.get('options', {})
            # Get the user role from the websocket
            user_role = getattr(ws, 'user_role', 'anon')
            # Pass the user role to generate_video
            video_data = await api.generate_video(title, description, video_prompt_prefix, options, user_role)
            await ws.send_json({
                'action': 'generate_video',
                'requestId': data.get('requestId'),
                'success': True,
                'video': video_data,
            })
        except Exception as e:
            logger.error(f"Error processing video request: {e}")
            try:
                await ws.send_json({
                    'action': 'generate_video',
                    'requestId': data.get('requestId'),
                    'success': False,
                    'error': f'Video generation error: {str(e)}'
                })
            except Exception as send_error:
                logger.error(f"Error sending error response: {send_error}")

    def on_task_done(task):
        # Runs even if the task was cancelled before it started, so the slot
        # and the queue item are always released exactly once.
        active_tasks.discard(task)
        slots.release()
        queue.task_done()
        if not task.cancelled() and task.exception() is not None:
            logger.error(f"Task failed with error: {task.exception()}")

    try:
        while True:
            await slots.acquire()
            try:
                data = await queue.get()
            except BaseException:
                # Cancelled (or failed) while waiting: give the slot back.
                slots.release()
                raise
            task = asyncio.create_task(process_single_request(data))
            active_tasks.add(task)
            task.add_done_callback(on_task_done)
    finally:
        # Don't leave generations running for a connection that is gone.
        pending = list(active_tasks)
        for task in pending:
            task.cancel()
        if pending:
            await asyncio.gather(*pending, return_exceptions=True)
async def status_handler(request: web.Request) -> web.Response:
    """Handler for API status endpoint.

    Returns JSON with product name/version, maintenance flag, the number of
    configured endpoints, a per-endpoint status list and ``active_endpoints``.
    ``active_endpoints`` counts endpoints that are not busy and whose
    ``error_until`` deadline has passed or is unset.
    """
    api = request.app['api']
    # Snapshot the current state of every endpoint.
    endpoint_statuses = [
        {
            'id': ep.id,
            'url': ep.url,
            'busy': ep.busy,
            'last_used': ep.last_used,
            'error_count': ep.error_count,
            'error_until': ep.error_until,
        }
        for ep in api.endpoint_manager.endpoints
    ]

    now = time.time()

    def is_active(status: dict) -> bool:
        # Every status dict above carries 'error_until'. A falsy value
        # (presumably None/0 meaning "no back-off"; confirm against the
        # endpoint manager) counts as no back-off instead of failing on
        # `None < float`.
        error_until = status['error_until']
        return not status['busy'] and (not error_until or error_until < now)

    return web.json_response({
        'product': PRODUCT_NAME,
        'version': PRODUCT_VERSION,
        'maintenance_mode': MAINTENANCE_MODE,
        'available_endpoints': len(VIDEO_ROUND_ROBIN_ENDPOINT_URLS),
        'endpoint_status': endpoint_statuses,
        'active_endpoints': sum(1 for status in endpoint_statuses if is_active(status))
    })
async def websocket_handler(request: web.Request) -> web.WebSocketResponse:
    """Serve one client websocket connection.

    While ``MAINTENANCE_MODE`` is on, the request is answered with a 503 JSON
    body instead of being upgraded. Otherwise the ``hf_token`` query parameter
    is validated via ``api.validate_user_token`` and the resulting role is
    stored on ``ws.user_role``.

    Text messages must be JSON objects and are routed by ``action``:

    * chat actions, ``generate_video`` and ``search`` go to per-connection
      queues drained by background workers;
    * everything else is handled inline by ``process_generic_request``.

    The workers are cancelled when the connection ends.
    """
    # Check if maintenance mode is enabled
    if MAINTENANCE_MODE:
        return web.json_response({
            'error': 'Server is in maintenance mode',
            'maintenance': True
        }, status=503)  # 503 Service Unavailable

    ws = web.WebSocketResponse(
        max_msg_size=1024*1024*20,  # 20MB max message size
        timeout=30.0  # we want to keep things tight and short
    )
    # The queue workers reach the shared API through ws.app['api'].
    ws.app = request.app
    await ws.prepare(request)

    api = request.app['api']
    # NOTE(review): a token in the query string can end up in access/proxy
    # logs; consider a header or a first-message handshake instead.
    hf_token = request.query.get('hf_token', '')
    user_role = await api.validate_user_token(hf_token)
    logger.info(f"User connected with role: {user_role}")
    ws.user_role = user_role

    # Separate queues so slow video jobs never delay chat or search replies.
    chat_queue = asyncio.Queue()
    video_queue = asyncio.Queue()
    search_queue = asyncio.Queue()
    background_tasks = [
        asyncio.create_task(process_chat_queue(chat_queue, ws)),
        asyncio.create_task(process_video_queue(video_queue, ws)),
        asyncio.create_task(process_search_queue(search_queue, ws, api))
    ]

    try:
        async for msg in ws:
            if msg.type == WSMsgType.TEXT:
                # Reset per message so an error reply never echoes the action
                # of an earlier, unrelated message.
                data = None
                try:
                    data = json.loads(msg.data)
                    if not isinstance(data, dict):
                        raise ValueError('message must be a JSON object')
                    action = data.get('action')
                    # Route requests to appropriate queues
                    if action in ('join_chat', 'leave_chat', 'chat_message'):
                        await chat_queue.put(data)
                    elif action == 'generate_video':
                        await video_queue.put(data)
                    elif action == 'search':
                        await search_queue.put(data)
                    else:
                        await process_generic_request(data, ws, api)
                except Exception as e:
                    logger.exception(f"Error processing WebSocket message: {str(e)}")
                    try:
                        await ws.send_json({
                            'action': data.get('action') if isinstance(data, dict) else 'unknown',
                            'success': False,
                            'error': f'Error processing message: {str(e)}'
                        })
                    except Exception as send_error:
                        logger.error(f"Error sending error response: {send_error}")
            elif msg.type in (WSMsgType.ERROR, WSMsgType.CLOSE):
                break
    finally:
        # Stop the per-connection workers and wait for them to wind down.
        for task in background_tasks:
            task.cancel()
        try:
            await asyncio.gather(*background_tasks, return_exceptions=True)
        except asyncio.CancelledError:
            pass
    return ws
# Explicit Content-Type overrides for static files. They are checked before
# the mimetypes database, so these mappings stay stable regardless of the
# Python version or the host's mime.types file.
_STATIC_CONTENT_TYPES = {
    '.html': 'text/html',
    '.js': 'application/javascript',
    '.css': 'text/css',
    '.jpg': 'image/jpeg',
    '.jpeg': 'image/jpeg',
    '.png': 'image/png',
    '.gif': 'image/gif',
    '.svg': 'image/svg+xml',
    '.json': 'application/json',
    # WebAssembly streaming compilation requires the application/wasm type.
    '.wasm': 'application/wasm',
}


def _guess_content_type(path: pathlib.Path) -> str:
    """Return the Content-Type to serve *path* with.

    Extensions listed in ``_STATIC_CONTENT_TYPES`` (case-insensitive) win.
    Anything else is looked up with ``mimetypes.guess_type``, falling back to
    ``text/plain`` when the type is unknown.
    """
    override = _STATIC_CONTENT_TYPES.get(path.suffix.lower())
    if override is not None:
        return override
    guessed, _encoding = mimetypes.guess_type(path.name)
    return guessed or 'text/plain'


def _resolve_static_path(public_root: pathlib.Path, request_path: str) -> pathlib.Path:
    """Map a URL path to the file under *public_root* that should be served.

    A directory maps to its ``index.html``. Anything that is not an existing
    file falls back to the root ``index.html`` (SPA routing).

    Raises:
        PermissionError: the path resolves outside *public_root*, whether via
            ``..`` segments or symlinks.
        FileNotFoundError: the path cannot be resolved, or there is no root
            ``index.html`` to fall back to.
    """
    root = public_root.resolve()
    try:
        candidate = root.joinpath(*request_path.lstrip('/').split('/')).resolve()
    except (ValueError, OSError, RuntimeError) as exc:
        # e.g. a NUL byte decoded from %00, or a symlink loop.
        raise FileNotFoundError(request_path) from exc
    try:
        # Component-wise containment. A string-prefix check would also accept
        # sibling directories such as "<root>-secret".
        candidate.relative_to(root)
    except ValueError:
        raise PermissionError(request_path) from None
    if candidate.is_dir():
        candidate = candidate / 'index.html'
    if not candidate.is_file():
        candidate = root / 'index.html'
        if not candidate.is_file():
            raise FileNotFoundError(request_path)
    return candidate


async def init_app() -> web.Application:
    """Build the aiohttp application.

    The application:

    * stores a shared ``VideoGenerationAPI`` as ``app['api']``;
    * registers ``/ws`` and ``/api/status``;
    * adds a catch-all route that serves the web build from ``build/web``
      next to this file, creating that directory if it is missing.
    """
    app = web.Application(
        client_max_size=1024**2*20  # 20MB max size
    )

    # Create API instance
    api = VideoGenerationAPI()
    app['api'] = api

    async def cleanup_api(app):
        # Placeholder: nothing is released on shutdown yet.
        pass

    app.on_shutdown.append(cleanup_api)

    # Add routes
    app.router.add_get('/ws', websocket_handler)
    app.router.add_get('/api/status', status_handler)

    # Define the path to the public directory
    public_path = pathlib.Path(__file__).parent / 'build' / 'web'
    if not public_path.exists():
        public_path.mkdir(parents=True, exist_ok=True)

    async def static_file_handler(request: web.Request) -> web.FileResponse:
        """Serve a file from the web build, falling back to index.html."""
        try:
            file_path = _resolve_static_path(public_path, request.path)
        except PermissionError:
            # HTTP exceptions are raised; returning them is deprecated in aiohttp.
            raise web.HTTPForbidden(text="Access denied") from None
        except FileNotFoundError:
            raise web.HTTPNotFound() from None
        return web.FileResponse(
            file_path, headers={'Content-Type': _guess_content_type(file_path)}
        )

    # Add catch-all route for static files (registered last so the API routes
    # above take priority)
    app.router.add_get('/{path:.*}', static_file_handler)
    return app
if __name__ == '__main__':
    # Pass the coroutine to run_app instead of awaiting it under asyncio.run():
    # asyncio.run() would build the app in an event loop that is closed before
    # run_app starts its own. Anything init_app() binds to that loop (e.g. if
    # VideoGenerationAPI creates sessions or locks; not visible here) would
    # then be stranded.
    web.run_app(init_app(), host='0.0.0.0', port=8080)