Spaces:
Sleeping
Sleeping
import os | |
import logging | |
from typing import List, Dict, Any, Optional, Union | |
from PIL import Image | |
from app.api import get_chat_completion | |
import json | |
from app.config import ( | |
STICKER_RERANKING_SYSTEM_PROMPT, | |
PUBLIC_URL | |
) | |
from app.database import db | |
from app.image_utils import ( | |
save_image_temp, | |
upload_to_huggingface, | |
get_image_cdn_url, | |
get_image_description, | |
calculate_image_hash | |
) | |
from app.gradio_formatter import gradio_formatter | |
# Logging setup: configure the root logger with INFO level and a timestamped
# line format (logging.basicConfig is a no-op if the root logger already has
# handlers), then grab this module's own logger.
logging.basicConfig(
    format='%(asctime)s %(levelname)s %(message)s',
    level=logging.INFO,
)
logger = logging.getLogger(__name__)
class StickerService:
    """Sticker service: business logic for uploading, searching, listing and deleting stickers.

    Every method is a ``@staticmethod``, so it works both when called on the class
    (``StickerService.search_stickers(...)``) and when called through an instance
    such as a module-level ``StickerService()`` object. Without the decorator an
    instance call would bind the instance to the first parameter.
    """

    @staticmethod
    def upload_sticker(image_file_path: str, title: str, description: str, tags: str) -> str:
        """Upload a sticker image and store its metadata.

        Steps: hash the image and reject duplicates, upload the file to
        HuggingFace, generate a description from the image when none was given,
        then persist the record with ``db.store_sticker``.

        Args:
            image_file_path: Local path of the image file to upload.
            title: Sticker title.
            description: Sticker description; if empty, one is generated from the image.
            tags: Sticker tags, passed through to the database unchanged.

        Returns:
            ``"Upload successful! <filename>"`` on success, otherwise
            ``"Upload failed: <error>"`` (``"Upload failed: File_Exists"`` for a
            duplicate image). Errors are reported via the return value, not raised.
        """
        try:
            # Keep the image open only while hashing so the file handle is released.
            with Image.open(image_file_path) as image:
                image_hash = calculate_image_hash(image)
            if db.check_image_exists(image_hash):
                logger.info("File already exists: %s", image_hash)
                raise Exception('File_Exists')

            file_path, image_filename = upload_to_huggingface(image_file_path)

            if not description:
                # Prefer the app's own public file URL when configured; otherwise
                # use the CDN URL of the file that was just uploaded.
                if PUBLIC_URL:
                    image_cdn_url = f'{PUBLIC_URL}/gradio_api/file={image_file_path}'
                else:
                    image_cdn_url = get_image_cdn_url(file_path)
                logger.info("image_cdn_url %s", image_cdn_url)
                description = get_image_description(image_cdn_url)

            # Persist the sticker record (the backing store is Milvus per the original notes).
            db.store_sticker(title, description, tags, file_path, image_hash)
            return f"Upload successful! {image_filename}"
        except Exception as e:
            logger.exception("Upload failed: %s", e)
            return f"Upload failed: {str(e)}"

    @staticmethod
    def search_stickers(description: str, limit: int = 2, reranking: bool = False) -> List[Dict[str, Any]]:
        """Search stickers matching a text description.

        Args:
            description: Free-text query; an empty query returns ``[]`` immediately.
            limit: Maximum number of hits requested from the database.
            reranking: When true, re-order the hits with the LLM reranker.

        Returns:
            The list of hits, or ``[]`` if the search raised.
        """
        if not description:
            return []
        try:
            results = db.search_stickers(description, limit)
            if reranking:
                results = StickerService.rerank_search_results(description, results, limit)
            return results
        except Exception as e:
            logger.exception("Search failed: %s", e)
            return []

    @staticmethod
    def get_all_stickers(limit: int = 1000) -> List[List]:
        """Return up to ``limit`` stickers, formatted for the Gradio UI; ``[]`` on error."""
        try:
            results = db.get_all_stickers(limit)
            return gradio_formatter.format_all_stickers(results)
        except Exception as e:
            logger.exception("Failed to get all stickers: %s", e)
            return []

    @staticmethod
    def delete_sticker(sticker_id: str) -> str:
        """Delete a sticker by id and return a status message.

        NOTE(review): the success message is returned whenever
        ``db.delete_sticker`` does not raise; whether it signals "not found"
        some other way is not visible here — confirm against the db layer.
        """
        try:
            db.delete_sticker(sticker_id)
            return f"Sticker with ID {sticker_id} deleted successfully"
        except Exception as e:
            logger.exception("Delete failed: %s", e)
            return f"Delete failed: {str(e)}"

    @staticmethod
    def rerank_search_results(query: str, sticker_list: List[Dict[str, Any]], limit: int = 5) -> List[Dict[str, Any]]:
        """Re-order search hits by LLM-judged relevance to ``query``.

        Each hit must expose ``hit["id"]`` and ``hit["entity"]["description"]``.
        The LLM is expected to reply with a JSON list of objects carrying
        ``sticker_id``, ``score`` and ``reason``.

        Args:
            query: The user's search keywords.
            sticker_list: Hits from the vector search.
            limit: Maximum number of reranked hits to return.

        Returns:
            Matching hits in descending ``score`` order, at most ``limit`` of them,
            each annotated in place with ``entity["score"]`` and ``entity["reason"]``.
            If the LLM call or its reply is unusable, ``sticker_list`` is returned
            unchanged so a reranking failure never hides valid search results.
        """
        if not sticker_list:
            # Nothing to rank; skip the LLM round-trip.
            return []
        try:
            candidates = [
                {"id": hit["id"], "description": hit["entity"]["description"]}
                for hit in sticker_list
            ]
            user_prompt = f"请分析关键词 '{query}' 与以下表情包的相关性:\n{candidates}"
            logger.debug("Reranking with LLM: user_prompt=%s system_prompt=%s",
                         user_prompt, STICKER_RERANKING_SYSTEM_PROMPT)

            response = get_chat_completion(user_prompt, STICKER_RERANKING_SYSTEM_PROMPT)
            reranked = StickerService._parse_llm_json(response)
            if not isinstance(reranked, list):
                raise ValueError("Invalid response format")

            rerank_results = StickerService._merge_reranked(sticker_list, reranked, limit)
            logger.debug("Rerank results: %s", rerank_results)
            return rerank_results
        except Exception as e:
            logger.exception("Reranking failed: %s", e)
            return sticker_list

    @staticmethod
    def _parse_llm_json(response: str) -> Any:
        """Parse a JSON LLM reply, tolerating a surrounding Markdown code fence (```json ... ```)."""
        text = response.strip()
        if text.startswith("```"):
            text = text[3:]
            if text.endswith("```"):
                text = text[:-3]
            text = text.strip()
            # Drop an optional language tag right after the opening fence.
            if text.lower().startswith("json"):
                text = text[4:]
        return json.loads(text)

    @staticmethod
    def _merge_reranked(sticker_list: List[Dict[str, Any]], reranked: List[Any], limit: int) -> List[Dict[str, Any]]:
        """Map LLM rerank entries back onto the original hits, best score first.

        Entries that are not dicts, lack ``sticker_id``, reference an unknown id,
        or repeat an id already used are skipped. Scores are validated (via
        ``float``) before any hit is modified, so a bad score raises without
        leaving hits partially annotated.
        """
        # Index hits by stringified id; the first hit wins on duplicate ids,
        # matching a first-match linear scan.
        hits_by_id: Dict[str, Dict[str, Any]] = {}
        for hit in sticker_list:
            hits_by_id.setdefault(str(hit["id"]), hit)

        entries = [
            item for item in reranked
            if isinstance(item, dict) and item.get("sticker_id") is not None
        ]
        entries.sort(key=lambda item: float(item.get("score", 0)), reverse=True)

        merged: List[Dict[str, Any]] = []
        seen = set()
        for item in entries:
            if len(merged) >= limit:
                break
            key = str(item["sticker_id"])
            hit = hits_by_id.get(key)
            if hit is None or key in seen:
                continue
            seen.add(key)
            hit["entity"]["score"] = item.get("score", 0)
            hit["entity"]["reason"] = item.get("reason", "")
            merged.append(hit)
        return merged
# Shared service instance, created once when this module is imported.
sticker_service: StickerService = StickerService()