Spaces:
Sleeping
Sleeping
import base64
import hmac
import logging
import os
import time
import uuid
from functools import wraps

import jwt
import requests
from flask import Flask, request, jsonify
from flask_cors import CORS
# Logging is configured first so that configuration problems detected below
# can be reported through the module logger.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

app = Flask(__name__)

# ALLOWED_ORIGINS is a comma-separated list. Surrounding whitespace and empty
# entries are ignored so that "https://a.com, https://b.com" works as expected.
_allowed_origins = [
    origin.strip()
    for origin in os.getenv('ALLOWED_ORIGINS', 'https://cybercity.top').split(',')
    if origin.strip()
]
CORS(app, origins=_allowed_origins)

# Configuration read from environment variables.
CLIENT_ID = os.getenv('COZE_CLIENT_ID', '1243934778935')
KID = os.getenv('COZE_KID', 'tlrohMMZyKMrrpP3GtxF_3_cerDhVIMINs0LOW91m7w')

_raw_private_key = os.getenv('COZE_PRIVATE_KEY')
if not _raw_private_key:
    # Fail fast with an explicit message. Without this check, calling
    # .replace() on None raises an opaque AttributeError at import time.
    raise RuntimeError("COZE_PRIVATE_KEY environment variable is required")
# Turn literal "\n" escape sequences (as found in a PEM key squeezed onto a
# single env-var line) back into real newlines.
PRIVATE_KEY = _raw_private_key.replace('\\n', '\n')

CLIENT_SECRET = os.getenv('COZE_CLIENT_SECRET', 'your_client_secret')
if CLIENT_SECRET == 'your_client_secret':
    # The placeholder default is publicly visible in this source, so anyone
    # could pass Basic auth with it. Make that loud in the logs.
    logger.warning("COZE_CLIENT_SECRET is not set; the placeholder default secret is in use")

# Simple module-level, in-memory JWT cache: the signed token and its expiry
# (Unix seconds).
jwt_cache = {'token': None, 'exp': 0}
def validate_basic_auth(auth_header, *, expected_id=None, expected_secret=None):
    """Check an HTTP Basic ``Authorization`` header against the configured client.

    The header must have the form ``Basic <base64(client_id:client_secret)>``
    (RFC 7617). The scheme name is matched case-insensitively (RFC 7235).

    Args:
        auth_header: Raw value of the ``Authorization`` request header, or
            ``None`` when the header is absent.
        expected_id: Client id to accept. Defaults to the module-level
            ``CLIENT_ID``.
        expected_secret: Client secret to accept. Defaults to the module-level
            ``CLIENT_SECRET``.

    Returns:
        ``True`` only when both the id and the secret match. ``False`` for a
        missing, malformed or non-matching header; bad input never raises.
    """
    if expected_id is None:
        expected_id = CLIENT_ID
    if expected_secret is None:
        expected_secret = CLIENT_SECRET
    # Fail closed: with no configured credentials nothing may authenticate.
    if not expected_id or not expected_secret:
        return False
    if not auth_header:
        return False

    scheme, _, encoded = auth_header.partition(' ')
    if scheme.lower() != 'basic':
        return False

    try:
        decoded = base64.b64decode(encoded.strip(), validate=True).decode('utf-8')
        given_id, given_secret = decoded.split(':', 1)
    except ValueError as exc:
        # binascii.Error (bad base64), UnicodeDecodeError (bad UTF-8) and the
        # unpacking failure (no ':' present) are all ValueError subclasses.
        logger.warning("Basic auth validation failed: %s", exc)
        return False

    # Constant-time comparisons, so response timing does not reveal how much
    # of the id or secret matched. Both are evaluated before combining to
    # avoid short-circuit timing differences. Bytes are compared because
    # compare_digest rejects non-ASCII str.
    id_ok = hmac.compare_digest(given_id.encode('utf-8'), expected_id.encode('utf-8'))
    secret_ok = hmac.compare_digest(given_secret.encode('utf-8'), expected_secret.encode('utf-8'))
    return id_ok and secret_ok
def generate_jwt():
    """Build and sign the RS256 JWT used for Coze's JWT-bearer OAuth grant.

    The token is issued by ``CLIENT_ID`` and carries ``KID`` in its header.
    It is valid for one hour from creation, and every call uses a fresh random
    ``jti``.

    Returns:
        The encoded token exactly as produced by ``jwt.encode``.

    Raises:
        jwt.PyJWTError: signing failed. The error is logged before re-raising.
    """
    issued_at = int(time.time())
    claims = dict(
        iss=CLIENT_ID,
        sub=CLIENT_ID,
        # NOTE(review): confirm the exact `aud` value Coze expects (possibly
        # the bare host "api.coze.cn" rather than a scheme-qualified URL).
        aud="https://api.coze.cn",
        iat=issued_at,
        exp=issued_at + 3600,  # one-hour lifetime
        jti=uuid.uuid4().hex,
        # Non-registered (private) claims, both set to the client id.
        connector_id=CLIENT_ID,
        user_id=CLIENT_ID,
    )
    jose_header = {"alg": "RS256", "typ": "JWT", "kid": KID}

    try:
        signed = jwt.encode(claims, PRIVATE_KEY, algorithm="RS256", headers=jose_header)
    except jwt.PyJWTError as err:
        logger.error(f"JWT generation failed: {str(err)}")
        raise
    return signed
def get_access_token(jwt_token):
    """Exchange a signed JWT for a Coze OAuth access token.

    Sends a single POST (no retries) to Coze's OAuth token endpoint using the
    JWT-bearer grant type.

    Args:
        jwt_token: Signed JWT (see ``generate_jwt``), sent as a Bearer token.

    Returns:
        The decoded JSON object on success. Returns
        ``{"error": "coze_api_error"}`` when any of these happens:
        the request fails, the status is non-2xx, the body is not valid JSON,
        or the JSON is not an object. Every failure is logged.
    """
    url = "https://api.coze.cn/api/permission/oauth2/token"
    payload = {
        "grant_type": "urn:ietf:params:oauth:grant-type:jwt-bearer",
        # Presumably the requested token lifetime in seconds
        # (86399 s = 24 h - 1 s). TODO confirm against the Coze docs.
        "duration_seconds": 86399,
    }
    headers = {
        "Content-Type": "application/json",
        "Authorization": f"Bearer {jwt_token}",
    }

    try:
        response = requests.post(url, json=payload, headers=headers, timeout=10)
        response.raise_for_status()
        data = response.json()
    except requests.exceptions.RequestException as e:
        # Log (truncated) the upstream body for HTTP errors, because it is
        # usually the only clue to why Coze rejected the JWT. Use
        # `is not None`: a Response with status >= 400 is falsy.
        upstream = getattr(e, 'response', None)
        body = upstream.text if upstream is not None else ''
        logger.error("Access token request failed: %s; response body: %.500s", e, body)
        return {"error": "coze_api_error"}
    except ValueError as e:
        # requests < 2.27 raises a plain ValueError for a non-JSON body.
        logger.error("Access token response was not valid JSON: %s", e)
        return {"error": "coze_api_error"}

    if not isinstance(data, dict):
        logger.error("Access token response was not a JSON object: %r", type(data).__name__)
        return {"error": "coze_api_error"}
    return data
# Health-check endpoint.
# NOTE(review): no route decorator is attached to this function, so Flask
# never dispatches to it; confirm the intended @app.route(...) registration.
def health_check():
    """Report liveness together with the current Unix time in whole seconds."""
    body = {"status": "healthy", "timestamp": int(time.time())}
    return jsonify(body), 200
# Token endpoint.
# NOTE(review): no route decorator is attached to this function, so Flask
# never dispatches to it; confirm the intended @app.route(...) registration.
def get_coze_token():
    """Authenticate the caller and return a Coze access token.

    Steps:
      1. Require HTTP Basic credentials accepted by ``validate_basic_auth``.
      2. Obtain a JWT, reusing the cached one when possible
         (see ``_get_cached_jwt``).
      3. Exchange the JWT for an access token via ``get_access_token``.

    Returns:
        A Flask response (optionally with a status code):
        * 200 with ``access_token``, ``expires_in`` and ``token_type``;
        * 401 ``invalid_client`` when Basic auth fails;
        * 500 ``jwt_generation_failed`` when the JWT cannot be signed;
        * 502 ``coze_oauth_error`` when Coze reports an error or the response
          has no ``access_token``.
    """
    if not validate_basic_auth(request.headers.get('Authorization')):
        return jsonify({"error": "invalid_client"}), 401

    try:
        signed_jwt = _get_cached_jwt()
    except Exception:
        # Boundary handler: log the full traceback instead of swallowing it,
        # then return a generic error to the client.
        logger.exception("JWT generation failed")
        return jsonify({"error": "jwt_generation_failed"}), 500

    token_response = get_access_token(signed_jwt)
    # A success-shaped body without an access_token is treated as an upstream
    # failure as well. Indexing it below would otherwise raise KeyError and
    # surface as a generic 500.
    if (not isinstance(token_response, dict)
            or 'error' in token_response
            or 'access_token' not in token_response):
        details = token_response.get('error_description') if isinstance(token_response, dict) else None
        return jsonify({
            "error": "coze_oauth_error",
            "details": details,
        }), 502

    return jsonify({
        "access_token": token_response['access_token'],
        "expires_in": token_response.get('expires_in'),
        "token_type": "Bearer",
    })


def _get_cached_jwt():
    """Return the cached JWT, signing and caching a new one when needed.

    The cached token is reused while it has more than 5 minutes (300 s) of
    its assumed one-hour lifetime left. Otherwise a new token is generated
    and stored in ``jwt_cache``. Exceptions from ``generate_jwt`` propagate
    unchanged.
    """
    # NOTE(review): reuse means the same JWT (and jti) is sent to Coze for up
    # to ~55 minutes. If Coze treats jti as single-use, consider caching the
    # access token instead of the JWT.
    now = time.time()
    if jwt_cache['exp'] > now + 300:
        return jwt_cache['token']
    token = generate_jwt()
    jwt_cache.update({'token': token, 'exp': now + 3600})
    return token
# Error handlers: return JSON bodies instead of Flask's default HTML pages.
# Registration with the app is required; without it Flask never calls them.
@app.errorhandler(404)
def not_found(error):
    """Return a JSON body for any 404 raised by Flask (e.g. an unknown URL).

    Args:
        error: The exception Flask passes to the handler (unused).
    """
    return jsonify({"error": "endpoint_not_found"}), 404
@app.errorhandler(500)
def internal_error(error):
    """Return a JSON body for internal server errors instead of Flask's HTML page.

    Args:
        error: The exception Flask passes to the handler (unused).
    """
    return jsonify({"error": "internal_server_error"}), 500
if __name__ == '__main__':
    # Direct-run entry point: listen on all interfaces on $PORT (default 7860).
    # NOTE: DEBUG=true enables Flask's debug mode (interactive debugger) while
    # bound to 0.0.0.0. Never enable it on a publicly reachable host.
    listen_port = int(os.getenv('PORT', 7860))
    debug_mode = os.getenv('DEBUG', 'false').lower() == 'true'
    app.run(host='0.0.0.0', port=listen_port, debug=debug_mode)