# aitube2 / api.py
# Last commit: 59e0f6a ("up") by jbilcke-hf (HF Staff)
import asyncio
import json
import logging
import os
import pathlib
import time
from aiohttp import web, WSMsgType
from typing import Dict, Any
from api_core import VideoGenerationAPI
from api_config import *
# Process-wide logging setup: records at INFO and above, each prefixed with
# a timestamp, the logger name and the level.
_LOG_FORMAT = '%(asctime)s - %(name)s - %(levelname)s - %(message)s'
logging.basicConfig(format=_LOG_FORMAT, level=logging.INFO)

# Module logger used by the request handlers and queue workers in this module.
logger = logging.getLogger(__name__)
async def process_generic_request(data: dict, ws: web.WebSocketResponse, api) -> None:
    """Handle WebSocket requests that are not routed to a dedicated queue.

    Supported actions:

    * ``heartbeat`` and ``get_user_role``: reply with the connection's
      ``user_role`` (``'anon'`` if the socket has none).
    * ``generate_caption``: reply with ``caption`` from
      ``api.generate_caption(title, description)``.
    * ``generate_thumbnail``: reply with ``thumbnailUrl`` from
      ``api.generate_thumbnail(title, description)``.

    The two generation actions read ``title`` and ``description`` from
    ``data['params']`` and reject the request if either is missing or empty.

    Every reply echoes ``action`` and ``requestId``. Failures (unknown
    action, validation, or an exception raised by ``api``) are sent to the
    client as ``{'success': False, 'error': ...}`` rather than raised.

    Args:
        data: Decoded JSON message from the client.
        ws: The client's WebSocket; ``ws.user_role`` is read if present.
        api: Service object exposing the ``generate_caption`` and
            ``generate_thumbnail`` coroutines.
    """
    # Pre-bound so the error path below can always build a reply.
    request_id = None
    action = None

    def error_response(message: str) -> dict:
        return {
            'action': action,
            'requestId': request_id,
            'success': False,
            'error': message,
        }

    try:
        request_id = data.get('requestId')
        action = data.get('action')

        if action in ('heartbeat', 'get_user_role'):
            # Both actions report the role resolved when the socket connected.
            await ws.send_json({
                'action': action,
                'requestId': request_id,
                'success': True,
                'user_role': getattr(ws, 'user_role', 'anon'),
            })
        elif action in ('generate_caption', 'generate_thumbnail'):
            params = data.get('params')
            # A missing, null or non-object ``params`` is a validation error;
            # previously ``None.get`` surfaced as an internal server error.
            if not isinstance(params, dict):
                params = {}
            title = params.get('title')
            description = params.get('description')
            if not title or not description:
                await ws.send_json(error_response('Missing title or description'))
                return
            if action == 'generate_caption':
                result_key = 'caption'
                result_value = await api.generate_caption(title, description)
            else:
                result_key = 'thumbnailUrl'
                result_value = await api.generate_thumbnail(title, description)
            await ws.send_json({
                'action': action,
                'requestId': request_id,
                'success': True,
                result_key: result_value,
            })
        else:
            await ws.send_json(error_response(f'Unknown action: {action}'))
    except Exception as e:
        # logger.exception keeps the traceback that logger.error dropped.
        logger.exception("Error processing generic request: %s", e)
        try:
            await ws.send_json(error_response(f'Internal server error: {str(e)}'))
        except Exception as send_error:
            logger.error("Error sending error response: %s", send_error)
async def process_search_queue(queue: asyncio.Queue, ws: web.WebSocketResponse, api):
    """Consume search requests from ``queue`` and reply on ``ws`` until cancelled.

    Each queued item is a decoded client message carrying ``requestId``,
    ``query`` and optional ``searchCount`` / ``attemptCount`` (default 0),
    which are forwarded to ``api.search_video``. Exactly one
    ``action='search'`` reply is sent per item: ``success`` True with
    ``result``, or False with ``error``. Errors never stop the loop.
    """

    def failure(request_id, message):
        """Build the ``success: False`` reply for a search request."""
        return {
            'action': 'search',
            'requestId': request_id,
            'success': False,
            'error': message,
        }

    while True:
        # Dequeue outside the try so the ``task_done()`` in ``finally`` only
        # runs for an item that was actually received. Before, a cancellation
        # while waiting here still reached ``task_done()``.
        data = await queue.get()
        try:
            request_id = data.get('requestId')
            raw_query = data.get('query')
            # A null or non-string query counts as "no query" rather than
            # crashing on ``.strip()``.
            query = raw_query.strip() if isinstance(raw_query, str) else ''
            search_count = data.get('searchCount', 0)
            attempt_count = data.get('attemptCount', 0)
            logger.info(
                "Processing search request: query='%s', search_count=%s, attempt=%s",
                query, search_count, attempt_count,
            )

            if not query:
                logger.warning("Empty query received in request: %s", data)
                result = failure(request_id, 'No search query provided')
            else:
                try:
                    search_result = await api.search_video(
                        query,
                        search_count=search_count,
                        attempt_count=attempt_count,
                    )
                except Exception as e:
                    logger.exception(
                        "Search error for query '%s' (#%s, attempt %s): %s",
                        query, search_count, attempt_count, e,
                    )
                    result = failure(request_id, f'Search error: {str(e)}')
                else:
                    if search_result:
                        logger.info("Search successful for query '%s' (#%s)", query, search_count)
                        result = {
                            'action': 'search',
                            'requestId': request_id,
                            'success': True,
                            'result': search_result,
                        }
                    else:
                        logger.warning("No results found for query '%s' (#%s)", query, search_count)
                        result = failure(request_id, 'No results found')

            await ws.send_json(result)
        except Exception as e:
            logger.exception("Error in search queue processor: %s", e)
            try:
                request_id = data.get('requestId') if isinstance(data, dict) else None
                await ws.send_json(failure(request_id, f'Internal server error: {str(e)}'))
            except Exception as send_error:
                logger.error("Error sending error response: %s", send_error)
        finally:
            queue.task_done()
async def process_chat_queue(queue: asyncio.Queue, ws: web.WebSocketResponse):
    """Consume chat requests from ``queue`` and reply on ``ws`` until cancelled.

    ``join_chat``, ``chat_message`` and ``leave_chat`` are dispatched to
    ``handle_join_chat`` / ``handle_chat_message`` / ``handle_leave_chat`` on
    ``ws.app['api']`` (called as ``handler(data, ws)``). The handler's return
    value is sent to the client unchanged. Any other action, or an exception,
    produces a ``success: False`` reply instead.
    """
    # Map each action to the name of the api coroutine that handles it. The
    # lookup is by name so only the handler actually needed is resolved.
    handler_names = {
        'join_chat': 'handle_join_chat',
        'chat_message': 'handle_chat_message',
        'leave_chat': 'handle_leave_chat',
    }
    while True:
        data = await queue.get()
        try:
            action = data.get('action')
            handler_name = handler_names.get(action)
            if handler_name is None:
                # Without this guard an unexpected action left ``result``
                # unbound, or re-sent the previous message's result.
                raise ValueError(f'Unsupported chat action: {action}')
            api = ws.app['api']
            result = await getattr(api, handler_name)(data, ws)
            await ws.send_json(result)
        except Exception as e:
            logger.exception("Error processing chat request: %s", e)
            try:
                await ws.send_json({
                    'action': data.get('action'),
                    'requestId': data.get('requestId'),
                    'success': False,
                    'error': f'Chat error: {str(e)}',
                })
            except Exception as send_error:
                logger.error("Error sending error response: %s", send_error)
        finally:
            queue.task_done()
async def process_video_queue(queue: asyncio.Queue, ws: web.WebSocketResponse):
    """Run queued ``generate_video`` requests concurrently, replying on ``ws``.

    Up to ``len(VIDEO_ROUND_ROBIN_ENDPOINT_URLS)`` requests (at least one)
    are in flight at a time, matching the client's max concurrent
    generations. Each request calls ``api.generate_video`` with the
    connection's user role and sends exactly one ``generate_video`` reply:
    ``success`` True with ``video``, or False with ``error``.

    Runs until cancelled. On cancellation the in-flight generations are
    cancelled too, because their replies could no longer reach this socket.
    """
    active_tasks = set()
    # Clamp to 1. With zero configured endpoints the old ``len(...)`` limit
    # meant nothing was ever dequeued, so clients waited forever for a reply.
    max_concurrent = max(1, len(VIDEO_ROUND_ROBIN_ENDPOINT_URLS))

    async def process_single_request(data):
        """Generate one video for ``data`` and send the success/error reply."""
        try:
            api = ws.app['api']
            title = data.get('title', '')
            description = data.get('description', '')
            video_prompt_prefix = data.get('video_prompt_prefix', '')
            options = data.get('options', {})
            # Role resolved when the socket connected; 'anon' if unset.
            user_role = getattr(ws, 'user_role', 'anon')
            video_data = await api.generate_video(
                title, description, video_prompt_prefix, options, user_role
            )
            await ws.send_json({
                'action': 'generate_video',
                'requestId': data.get('requestId'),
                'success': True,
                'video': video_data,
            })
        except Exception as e:
            logger.exception("Error processing video request: %s", e)
            try:
                await ws.send_json({
                    'action': 'generate_video',
                    'requestId': data.get('requestId'),
                    'success': False,
                    'error': f'Video generation error: {str(e)}',
                })
            except Exception as send_error:
                logger.error("Error sending error response: %s", send_error)
        finally:
            queue.task_done()

    try:
        while True:
            # Block (instead of polling) until a slot frees up. Each task's
            # done callback removes it from ``active_tasks``.
            while len(active_tasks) >= max_concurrent:
                await asyncio.wait(set(active_tasks), return_when=asyncio.FIRST_COMPLETED)
            data = await queue.get()
            task = asyncio.create_task(process_single_request(data))
            active_tasks.add(task)
            task.add_done_callback(active_tasks.discard)
    finally:
        # Previously these tasks kept running (and tried to send on a closed
        # socket) after the connection's workers were cancelled.
        pending = list(active_tasks)
        for task in pending:
            task.cancel()
        if pending:
            await asyncio.gather(*pending, return_exceptions=True)
async def status_handler(request: web.Request) -> web.Response:
    """Return product info and per-endpoint state as JSON (``/api/status``).

    ``active_endpoints`` counts endpoints that are not busy and whose
    ``error_until`` timestamp is already in the past.
    """
    api = request.app['api']
    # One timestamp for the whole snapshot instead of one per endpoint.
    now = time.time()

    endpoint_statuses = [
        {
            'id': ep.id,
            'url': ep.url,
            'busy': ep.busy,
            'last_used': ep.last_used,
            'error_count': ep.error_count,
            'error_until': ep.error_until,
        }
        for ep in api.endpoint_manager.endpoints
    ]

    # ``error_until`` is always present in the dicts above, so the old
    # ``'error_until' not in ep`` test was dead code. ``or 0`` treats an
    # unset/None value as "no error window" instead of raising TypeError.
    # NOTE(review): confirm whether endpoint_manager ever stores None here.
    active_count = sum(
        1
        for status in endpoint_statuses
        if not status['busy'] and (status['error_until'] or 0) < now
    )

    return web.json_response({
        'product': PRODUCT_NAME,
        'version': PRODUCT_VERSION,
        'maintenance_mode': MAINTENANCE_MODE,
        'available_endpoints': len(VIDEO_ROUND_ROBIN_ENDPOINT_URLS),
        'endpoint_status': endpoint_statuses,
        'active_endpoints': active_count,
    })
async def websocket_handler(request: web.Request) -> web.WebSocketResponse:
    """Serve one client WebSocket connection.

    In maintenance mode a 503 JSON response is returned instead of
    upgrading. Otherwise the ``hf_token`` query parameter is validated with
    ``api.validate_user_token`` to set ``ws.user_role``, and each text
    message is routed by its ``action``:

    * ``join_chat`` / ``leave_chat`` / ``chat_message``: chat queue
    * ``generate_video``: video queue (processed concurrently)
    * ``search``: search queue
    * anything else: handled inline by ``process_generic_request``

    The three queue workers live only as long as the connection and are
    cancelled when it ends.
    """
    if MAINTENANCE_MODE:
        return web.json_response({
            'error': 'Server is in maintenance mode',
            'maintenance': True
        }, status=503)  # 503 Service Unavailable

    ws = web.WebSocketResponse(
        max_msg_size=1024*1024*20,  # 20MB max message size
        timeout=30.0  # we want to keep things tight and short
    )
    ws.app = request.app
    await ws.prepare(request)

    api = request.app['api']
    hf_token = request.query.get('hf_token', '')
    user_role = await api.validate_user_token(hf_token)
    logger.info("User connected with role: %s", user_role)
    # Read later by the queue workers and generic handler via getattr.
    ws.user_role = user_role

    chat_queue = asyncio.Queue()
    video_queue = asyncio.Queue()
    search_queue = asyncio.Queue()
    background_tasks = [
        asyncio.create_task(process_chat_queue(chat_queue, ws)),
        asyncio.create_task(process_video_queue(video_queue, ws)),
        asyncio.create_task(process_search_queue(search_queue, ws, api)),
    ]

    try:
        async for msg in ws:
            if msg.type == WSMsgType.TEXT:
                # Reset per message so an error reply never reports the action
                # of an earlier message. The old ``'data' in locals()`` check
                # picked up stale data after a JSON decode error.
                data = None
                try:
                    data = json.loads(msg.data)
                    if not isinstance(data, dict):
                        raise ValueError('expected a JSON object')
                    action = data.get('action')
                    if action in ('join_chat', 'leave_chat', 'chat_message'):
                        await chat_queue.put(data)
                    elif action == 'generate_video':
                        await video_queue.put(data)
                    elif action == 'search':
                        await search_queue.put(data)
                    else:
                        await process_generic_request(data, ws, api)
                except Exception as e:
                    logger.exception("Error processing WebSocket message: %s", e)
                    try:
                        await ws.send_json({
                            'action': data.get('action') if isinstance(data, dict) else 'unknown',
                            'success': False,
                            'error': f'Error processing message: {str(e)}'
                        })
                    except Exception as send_error:
                        # The socket may be closing; let the receive loop end
                        # normally instead of propagating from here.
                        logger.error("Error sending error response: %s", send_error)
            elif msg.type in (WSMsgType.ERROR, WSMsgType.CLOSE):
                break
    finally:
        # Stop this connection's queue workers.
        for task in background_tasks:
            task.cancel()
        try:
            await asyncio.gather(*background_tasks, return_exceptions=True)
        except asyncio.CancelledError:
            pass

    return ws
async def init_app() -> web.Application:
    """Create and wire up the aiohttp application.

    Registers the following:

    * the ``VideoGenerationAPI`` instance as ``app['api']``;
    * the ``/ws`` WebSocket route;
    * the ``/api/status`` route;
    * a catch-all static file handler for ``build/web`` next to this module
      (the directory is created if missing).

    The static handler falls back to ``index.html`` so client-side (SPA)
    routes resolve.
    """
    import mimetypes  # only needed by the static file handler below

    app = web.Application(
        client_max_size=1024**2*20  # 20MB max size
    )

    api = VideoGenerationAPI()
    app['api'] = api

    async def cleanup_api(app):
        # Placeholder: the API currently needs no explicit shutdown work.
        pass

    app.on_shutdown.append(cleanup_api)

    # API routes are registered before the catch-all so they take priority.
    app.router.add_get('/ws', websocket_handler)
    app.router.add_get('/api/status', status_handler)

    public_path = pathlib.Path(__file__).parent / 'build' / 'web'
    if not public_path.exists():
        public_path.mkdir(parents=True, exist_ok=True)

    # Explicit types for common web assets. Other files are guessed from the
    # name, with 'text/plain' as the last resort. Before, every unlisted
    # extension (.wasm, fonts, icons, ...) was forced to text/plain.
    content_types = {
        '.html': 'text/html',
        '.js': 'application/javascript',
        '.css': 'text/css',
        '.jpg': 'image/jpeg',
        '.jpeg': 'image/jpeg',
        '.png': 'image/png',
        '.gif': 'image/gif',
        '.svg': 'image/svg+xml',
        '.json': 'application/json',
        # WebAssembly streaming compilation requires application/wasm.
        '.wasm': 'application/wasm',
    }

    def content_type_for(path: pathlib.Path) -> str:
        """Return the Content-Type to send for ``path``."""
        explicit = content_types.get(path.suffix.lower())
        if explicit:
            return explicit
        guessed, _encoding = mimetypes.guess_type(path.name)
        return guessed or 'text/plain'

    async def static_file_handler(request):
        """Serve a file under ``public_path``, or ``index.html`` as a fallback.

        Raises:
            web.HTTPForbidden: the resolved path lies outside ``public_path``.
            web.HTTPNotFound: the path cannot be resolved, or there is no
                ``index.html`` to fall back to.
        """
        root = public_path.resolve()
        path_parts = request.path.lstrip('/').split('/')
        try:
            safe_path = public_path.joinpath(*path_parts).resolve()
        except (ValueError, OSError):
            raise web.HTTPNotFound()
        # Check containment with relative_to() instead of a string prefix.
        # str.startswith() also accepted sibling directories such as
        # ``build/web-private`` that merely share the ``build/web`` prefix.
        try:
            safe_path.relative_to(root)
        except ValueError:
            raise web.HTTPForbidden(text="Access denied")

        if safe_path.is_dir():
            safe_path = safe_path / 'index.html'
        if not safe_path.is_file():
            # Unknown path: serve the app shell so the SPA router handles it.
            safe_path = public_path / 'index.html'
            if not safe_path.exists():
                raise web.HTTPNotFound()

        # HTTP errors above are raised, not returned: aiohttp deprecates
        # returning an HTTPException from a handler.
        return web.FileResponse(
            safe_path, headers={'Content-Type': content_type_for(safe_path)}
        )

    # Catch-all route for static files (lower priority than API routes).
    app.router.add_get('/{path:.*}', static_file_handler)

    return app
if __name__ == '__main__':
    # Hand run_app the coroutine instead of asyncio.run()-ing it first.
    # asyncio.run() closes its event loop on return, so anything init_app()
    # tied to that loop would be stranded before the server loop started.
    # run_app awaits init_app() on the loop it actually serves from.
    web.run_app(init_app(), host='0.0.0.0', port=8080)