# NekoAI-Lab / app/services.py
# nekoko — commit 1c2b077: "feat: Sticker DB"
import os
import logging
from typing import List, Dict, Any, Optional, Union
from PIL import Image
from app.api import get_chat_completion
import json
from app.config import (
STICKER_RERANKING_SYSTEM_PROMPT,
PUBLIC_URL
)
from app.database import db
from app.image_utils import (
save_image_temp,
upload_to_huggingface,
get_image_cdn_url,
get_image_description,
calculate_image_hash
)
from app.gradio_formatter import gradio_formatter
# Configure root logging (INFO level, timestamped records) and create the
# logger used throughout this module.
_LOG_FORMAT = '%(asctime)s %(levelname)s %(message)s'
logging.basicConfig(format=_LOG_FORMAT, level=logging.INFO)
logger = logging.getLogger(__name__)
class StickerService:
    """Sticker service: business logic for uploading, searching, listing,
    deleting and LLM-reranking stickers.

    Every method is a ``@staticmethod``; the class acts as a namespace.
    Public methods log failures and return a fallback value (an error
    message or an empty list) instead of raising.
    """

    @staticmethod
    def upload_sticker(image_file_path: str, title: str, description: str, tags: str) -> str:
        """Upload a sticker image and store its metadata.

        The image is hashed to reject duplicates, uploaded to HuggingFace,
        described via ``get_image_description`` when no description was
        supplied, and finally stored with ``db.store_sticker``.

        Args:
            image_file_path: Local path of the image file.
            title: Sticker title.
            description: Sticker description; when empty, one is generated
                from the image.
            tags: Tags string, stored as given.

        Returns:
            ``"Upload successful! <filename>"`` on success, otherwise
            ``"Upload failed: <reason>"`` (``"Upload failed: File_Exists"``
            when an image with the same hash is already stored).
        """
        try:
            # The context manager closes the file handle PIL keeps open
            # after Image.open(); only the hash is needed afterwards.
            with Image.open(image_file_path) as image:
                image_hash = calculate_image_hash(image)
            if db.check_image_exists(image_hash):
                logger.info("文件已存在 %s", image_hash)
                raise Exception('File_Exists')

            file_path, image_filename = upload_to_huggingface(image_file_path)

            if not description:
                if PUBLIC_URL:
                    # Build a URL through the app's public Gradio file endpoint.
                    image_cdn_url = f'{PUBLIC_URL}/gradio_api/file={image_file_path}'
                else:
                    image_cdn_url = get_image_cdn_url(file_path)
                logger.info("image_cdn_url %s", image_cdn_url)
                description = get_image_description(image_cdn_url)

            db.store_sticker(title, description, tags, file_path, image_hash)
            return f"Upload successful! {image_filename}"
        except Exception as e:
            logger.exception("Upload failed: %s", e)
            return f"Upload failed: {str(e)}"

    @staticmethod
    def search_stickers(description: str, limit: int = 2, reranking: bool = False) -> List[Dict[str, Any]]:
        """Search stickers matching a text description.

        Args:
            description: Query text; an empty value returns ``[]`` at once.
            limit: Maximum number of hits requested from ``db.search_stickers``.
            reranking: When true, reorder the hits with the LLM reranker.

        Returns:
            The hits from ``db.search_stickers`` (reranked when requested),
            or ``[]`` if the database search fails. If reranking fails, the
            hits are returned in their original search order rather than
            being discarded.
        """
        if not description:
            return []
        try:
            results = db.search_stickers(description, limit)
        except Exception as e:
            logger.exception("Search failed: %s", e)
            return []
        if not reranking:
            return results
        if not results:
            # Nothing to rerank: skip the LLM round-trip.
            return []
        try:
            return StickerService._rerank(description, results, limit)
        except Exception as e:
            # Reranking is best-effort; keep the search order on failure.
            logger.exception("Reranking failed: %s", e)
            return results

    @staticmethod
    def get_all_stickers(limit: int = 1000) -> List[List]:
        """Return up to ``limit`` stickers formatted by
        ``gradio_formatter.format_all_stickers``.

        Returns ``[]`` if the database query or the formatting fails.
        """
        try:
            results = db.get_all_stickers(limit)
            return gradio_formatter.format_all_stickers(results)
        except Exception as e:
            logger.exception("Failed to get all stickers: %s", e)
            return []

    @staticmethod
    def delete_sticker(sticker_id: str) -> str:
        """Delete a sticker by ID.

        The success message is returned whenever ``db.delete_sticker`` does
        not raise; this method does not itself check that the sticker existed.

        Returns:
            A success message, or ``"Delete failed: <reason>"`` on error.
        """
        try:
            db.delete_sticker(sticker_id)
        except Exception as e:
            logger.exception("Delete failed: %s", e)
            return f"Delete failed: {str(e)}"
        return f"Sticker with ID {sticker_id} deleted successfully"

    @staticmethod
    def rerank_search_results(query: str, sticker_list: List[Dict[str, Any]], limit: int = 5) -> List[Dict[str, Any]]:
        """Rerank search hits by LLM-judged relevance to ``query``.

        Each hit must provide ``hit["id"]`` and ``hit["entity"]["description"]``.
        Matched hits get ``entity["score"]`` and ``entity["reason"]`` set in
        place from the LLM output.

        Args:
            query: The user's search text.
            sticker_list: Hits to rerank.
            limit: Maximum number of hits returned; ``None`` or a value
                ``<= 0`` means no cap.

        Returns:
            Hits ordered by descending LLM score, or ``[]`` when
            ``sticker_list`` is empty or reranking fails.
        """
        if not sticker_list:
            return []
        try:
            return StickerService._rerank(query, sticker_list, limit)
        except Exception as e:
            logger.exception("Reranking failed: %s", e)
            return []

    @staticmethod
    def _rerank(query: str, sticker_list: List[Dict[str, Any]], limit: Optional[int]) -> List[Dict[str, Any]]:
        """Ask the LLM to score the hits and order them; raises on LLM/parse errors."""
        user_prompt = StickerService._build_rerank_prompt(query, sticker_list)
        logger.debug(">>> 使用 LLM 模型重新排序.... %s %s", user_prompt, STICKER_RERANKING_SYSTEM_PROMPT)
        response = get_chat_completion(user_prompt, STICKER_RERANKING_SYSTEM_PROMPT)
        # The response is parsed as a JSON list of objects carrying
        # "sticker_id", "score" and "reason" keys.
        ranking = json.loads(response)
        if not isinstance(ranking, list):
            raise ValueError("Invalid response format")
        logger.debug(">>> LLM 排序结果 %s", ranking)
        results = StickerService._apply_ranking(sticker_list, ranking, limit)
        logger.debug(">>> rerank_results %s", results)
        return results

    @staticmethod
    def _build_rerank_prompt(query: str, sticker_list: List[Dict[str, Any]]) -> str:
        """Build the user prompt listing each candidate's id and description."""
        candidates = [
            {"id": hit["id"], "description": hit["entity"]["description"]}
            for hit in sticker_list
        ]
        return f"请分析关键词 '{query}' 与以下表情包的相关性:\n{candidates}"

    @staticmethod
    def _score_of(item: Dict[str, Any]) -> float:
        """Return ``item["score"]`` as a float; 0.0 when missing or non-numeric."""
        try:
            return float(item.get("score", 0))
        except (TypeError, ValueError):
            return 0.0

    @staticmethod
    def _apply_ranking(sticker_list: List[Dict[str, Any]], ranking: List[Any], limit: Optional[int]) -> List[Dict[str, Any]]:
        """Map LLM ranking entries back onto hits, ordered by descending score.

        Non-dict entries, entries without a known ``sticker_id`` and repeated
        ids are skipped; ids are compared as strings. At most ``limit`` hits
        are returned when ``limit`` is a positive int.
        """
        hits_by_id: Dict[str, Dict[str, Any]] = {}
        for hit in sticker_list:
            # setdefault keeps the first hit when an id appears twice.
            hits_by_id.setdefault(str(hit["id"]), hit)

        entries = [item for item in ranking if isinstance(item, dict)]
        # list.sort is stable (also with reverse=True): ties keep LLM order.
        entries.sort(key=StickerService._score_of, reverse=True)

        cap = limit if limit is not None and limit > 0 else None
        results: List[Dict[str, Any]] = []
        seen = set()
        for item in entries:
            if cap is not None and len(results) >= cap:
                break
            sticker_id = item.get("sticker_id")
            if sticker_id is None:
                continue
            key = str(sticker_id)
            hit = hits_by_id.get(key)
            if hit is None or key in seen:
                continue
            seen.add(key)
            hit["entity"]["score"] = item.get("score", 0)
            hit["entity"]["reason"] = item.get("reason", "")
            results.append(hit)
        return results
# Shared module-level instance of the service; all of its methods are
# static, so this object carries no state of its own.
sticker_service: StickerService = StickerService()