NekoAI-Lab / app /database.py
nekoko
feat: Sticker DB
1c2b077
from pymilvus import MilvusClient
from sentence_transformers import SentenceTransformer
from typing import List, Dict, Any, Optional, Union
import logging
from app.config import MILVUS_DB_URL, MILVUS_DB_TOKEN, EMBEDDING_MODEL, DATASET_ID
# Configure root logging for the app: INFO level, "time level message" lines.
# (logging.basicConfig is a no-op if the root logger already has handlers.)
logging.basicConfig(
    format="%(asctime)s %(levelname)s %(message)s",
    level=logging.INFO,
)
# Module-scoped logger used by the Database class below.
logger = logging.getLogger(__name__)
class Database:
    """Data-access layer for stickers: wraps a Milvus client and an embedding model.

    Each sticker is stored as one row in the ``self.collection_name`` collection,
    with a vector computed from its description text plus metadata fields
    (title, description, tags, file_name, image_hash).
    """

    def __init__(self):
        # Connection settings come from app.config.
        self.client = MilvusClient(uri=MILVUS_DB_URL, token=MILVUS_DB_TOKEN)
        # trust_remote_code=True allows the model repository to run its own
        # modelling code when loading.
        self.model = SentenceTransformer(EMBEDDING_MODEL, trust_remote_code=True)
        logger.info("初始化模型完成 %s", self.model)
        self.collection_name = "stickers"

    def init_collection(self) -> bool:
        """Create the sticker collection if it does not exist yet.

        Returns:
            True if the collection exists afterwards (already present or newly
            created), False if any Milvus call raised (the error is logged).
        """
        try:
            collections = self.client.list_collections()
            logger.info("初始化 Milvus 数据库 %s", collections)
            # Look for *this* collection by name. The previous check only created
            # it when the server had no collections at all.
            if self.collection_name not in collections:
                # Size the vector field from the model so the schema matches the
                # embeddings; fall back to the previous hard-coded 768.
                dimension = self.model.get_sentence_embedding_dimension() or 768
                # Quick-setup mode (dimension=...): pymilvus builds the "id" primary
                # key and "vector" field, indexes the vector field and loads the
                # collection. Passing metric_type explicitly keeps the index metric
                # equal to the COSINE metric used by search_stickers (some pymilvus
                # versions default to IP).
                # No separate create_index call: the old one passed an empty
                # index_params dict and never built an index. To use a specific index
                # (e.g. IVF_SQ8, nlist=128), create the collection from an explicit
                # schema plus client.prepare_index_params() instead.
                self.client.create_collection(
                    collection_name=self.collection_name,
                    dimension=dimension,
                    primary_field_name="id",
                    metric_type="COSINE",
                    auto_id=True,
                )
                logger.info(f"Collection initialized: {self.collection_name}")
            logger.info("初始化 Milvus 数据库成功 %s", self.client.list_collections())
            return True
        except Exception as e:
            logger.error(f"Collection initialization failed: {str(e)}")
            return False

    def encode_text(self, text: str) -> List[float]:
        """Encode ``text`` with the embedding model and return it as a list of floats."""
        return self.model.encode(text).tolist()

    def store_sticker(
        self,
        title: str,
        description: str,
        tags: Union[str, List[str]],
        file_path: str,
        image_hash: Optional[str] = None,
    ) -> bool:
        """Embed a sticker's description and insert the sticker into Milvus.

        Args:
            title: Sticker title.
            description: Text that is embedded into the search vector.
            tags: Either a list of tags or a comma-separated string ("cat, cute").
            file_path: Stored under the "file_name" field.
            image_hash: Optional hash used later by ``check_image_exists``.

        Returns:
            True on success, False if encoding or insertion raised (error logged).
        """
        try:
            vector = self.encode_text(description)
            if isinstance(tags, str):
                # "cat, cute," -> ["cat", "cute"]: trim whitespace, drop empty entries.
                tags = [tag.strip() for tag in tags.split(",") if tag.strip()]
            logger.info(f"Storing to Milvus - title: {title}, description: {description}, file_path: {file_path}, tags: {tags}, image_hash: {image_hash}")
            self.client.insert(
                collection_name=self.collection_name,
                data=[{
                    "vector": vector,
                    "title": title,
                    "description": description,
                    "tags": tags,
                    "file_name": file_path,
                    "image_hash": image_hash,
                }],
            )
            logger.info("Storing to Milvus Success ✅")
            return True
        except Exception as e:
            logger.error(f"Failed to store sticker: {str(e)}")
            return False

    def search_stickers(self, description: str, limit: int = 2) -> List[Dict[str, Any]]:
        """Return up to ``limit`` stickers whose vectors are closest to ``description``.

        Returns:
            The hit list for the single query vector (hit layout as produced by
            pymilvus ``search`` — presumably id/distance/entity; verify), or [] for
            an empty query or on error (error logged).
        """
        if not description:
            return []
        try:
            text_vector = self.encode_text(description)
            logger.info(f"Searching Milvus - query: {description}, limit: {limit}")
            results = self.client.search(
                collection_name=self.collection_name,
                data=[text_vector],
                limit=limit,
                search_params={
                    "metric_type": "COSINE",
                },
                output_fields=["title", "description", "tags", "file_name"],
            )
            logger.info(f"Search Result: {results}")
            # One query vector was sent, so the first entry holds its hits.
            return results[0] if results else []
        except Exception as e:
            logger.error(f"Search failed: {str(e)}")
            return []

    def get_all_stickers(self, limit: int = 1000) -> List[Dict[str, Any]]:
        """Return up to ``limit`` stored stickers using a query with no filter.

        Returns [] on error (error logged).
        """
        try:
            results = self.client.query(
                collection_name=self.collection_name,
                filter="",
                limit=limit,
                output_fields=["title", "description", "tags", "file_name", "image_hash"],
            )
            logger.info(f"Query All Stickers - limit: {limit}, results count: {len(results)}")
            return results
        except Exception as e:
            logger.error(f"Failed to get all stickers: {str(e)}")
            return []

    def check_image_exists(self, image_hash: str) -> bool:
        """Return True if a sticker with ``image_hash`` is already stored.

        An empty or None hash never matches and returns False without querying.
        Query errors are logged and reported as False (best-effort check).
        """
        if not image_hash:
            return False
        try:
            # Escape backslashes and single quotes so the value cannot terminate the
            # string literal and inject extra terms into the filter expression.
            escaped = image_hash.replace("\\", "\\\\").replace("'", "\\'")
            results = self.client.query(
                collection_name=self.collection_name,
                filter=f"image_hash == '{escaped}'",
                limit=1,
                output_fields=["file_name", "image_hash"],
            )
            exists = len(results) > 0
            logger.info(f"Check file exists - hash: {image_hash}, exists: {exists}, results: {results}")
            return exists
        except Exception as e:
            logger.error(f"Failed to check file exists: {str(e)}")
            return False

    def delete_sticker(self, sticker_id: int) -> str:
        """Delete the sticker with primary key ``sticker_id``.

        Returns:
            A human-readable status message; on error the message contains the
            exception text (error also logged). Success is reported even if no
            row matched the id.
        """
        try:
            logger.info(f"Deleting sticker - id: {sticker_id}")
            res = self.client.delete(
                collection_name=self.collection_name,
                ids=[sticker_id],
            )
            logger.info(f"Deleted sticker - id: {sticker_id}")
            logger.info("Delete result: %s", res)
            return f"Sticker with ID {sticker_id} deleted successfully"
        except Exception as e:
            logger.error(f"Failed to delete sticker: {str(e)}")
            return f"Failed to delete sticker: {str(e)}"
# Shared module-level Database instance, created at import time.
# Constructing it creates the Milvus client and loads the SentenceTransformer
# model (see Database.__init__). init_collection() is NOT called here —
# NOTE(review): presumably invoked elsewhere at startup; verify.
db: Database = Database()