Spaces:
Sleeping
Sleeping
from pymilvus import MilvusClient | |
from sentence_transformers import SentenceTransformer | |
from typing import List, Dict, Any, Optional, Union | |
import logging | |
from app.config import MILVUS_DB_URL, MILVUS_DB_TOKEN, EMBEDDING_MODEL, DATASET_ID | |
# Configure logging: root logger at INFO level with a timestamped format.
logging.basicConfig(level=logging.INFO, format='%(asctime)s %(levelname)s %(message)s')
logger = logging.getLogger(__name__)

# Fallback vector dimension, used only when the embedding model cannot report its own.
_DEFAULT_DIMENSION = 768


def _quote_filter_literal(value: str) -> str:
    """Return *value* as a single-quoted Milvus filter string literal.

    Backslashes and single quotes are escaped so the value cannot terminate the
    literal early and inject additional filter syntax.
    """
    escaped = value.replace("\\", "\\\\").replace("'", "\\'")
    return f"'{escaped}'"


class Database:
    """Database access layer that handles all interaction with Milvus."""

    def __init__(self):
        """Connect to Milvus and load the sentence-embedding model.

        Connection settings and the model name come from ``app.config``.
        """
        self.client = MilvusClient(
            uri=MILVUS_DB_URL,
            token=MILVUS_DB_TOKEN)
        # trust_remote_code lets SentenceTransformer load models that ship custom code.
        self.model = SentenceTransformer(EMBEDDING_MODEL, trust_remote_code=True)
        logger.info('初始化模型完成 %s', self.model)
        self.collection_name = "stickers"

    def _embedding_dimension(self) -> int:
        """Return the embedding model's output size, or 768 if it cannot be determined."""
        get_dim = getattr(self.model, "get_sentence_embedding_dimension", None)
        dim = get_dim() if callable(get_dim) else None
        return dim or _DEFAULT_DIMENSION

    def init_collection(self, dimension: Optional[int] = None) -> bool:
        """Create the sticker collection in Milvus if it does not exist yet.

        Args:
            dimension: Vector dimension used when the collection has to be
                created. Defaults to the embedding model's output dimension
                (768 if the model cannot report one). Ignored when the
                collection already exists.

        Returns:
            True if the collection exists afterwards, False if any Milvus call
            failed (the error is logged).
        """
        try:
            collections = self.client.list_collections()
            logger.info('初始化 Milvus 数据库 %s', collections)
            # Check for *this* collection by name: testing "are there any
            # collections at all" skips creation whenever an unrelated
            # collection exists on the server.
            if self.collection_name not in collections:
                if dimension is None:
                    dimension = self._embedding_dimension()
                # Quick-setup form (dimension given, no schema). Per the pymilvus
                # docs this creates the `id` primary key and `vector` field and
                # indexes/loads the collection itself. The metric is passed
                # explicitly so it matches the COSINE metric used in
                # search_stickers (older pymilvus defaulted to IP).
                # NOTE: a specific index type (e.g. IVF_SQ8, nlist=128) would
                # require an explicit schema plus prepare_index_params().
                self.client.create_collection(
                    collection_name=self.collection_name,
                    dimension=dimension,
                    primary_field_name="id",
                    metric_type="COSINE",
                    auto_id=True,
                )
            logger.info("Collection initialized: %s", self.collection_name)
            return True
        except Exception as e:
            logger.exception("Collection initialization failed: %s", e)
            return False

    def encode_text(self, text: str) -> List[float]:
        """Encode *text* into an embedding vector as a plain Python list."""
        return self.model.encode(text).tolist()

    def store_sticker(self, title: str, description: str, tags: Union[str, List[str]], file_path: str, image_hash: Optional[str] = None) -> bool:
        """Embed a sticker's description and insert the sticker into Milvus.

        Args:
            title: Sticker title.
            description: Text that is embedded to produce the stored vector.
            tags: A list of tags, or a single comma-separated string of tags.
            file_path: Stored under the ``file_name`` key.
            image_hash: Optional hash; ``check_image_exists`` looks stickers up by it.

        Returns:
            True on success, False if encoding or insertion failed (the error is logged).
        """
        try:
            vector = self.encode_text(description)
            if isinstance(tags, str):
                # "a, b,,c" -> ["a", "b", "c"]: strip the whitespace around each
                # tag and drop empty entries left by stray commas.
                tags = [tag.strip() for tag in tags.split(",") if tag.strip()]
            logger.info(
                "Storing to Milvus - title: %s, description: %s, file_path: %s, tags: %s, image_hash: %s",
                title, description, file_path, tags, image_hash,
            )
            self.client.insert(
                collection_name=self.collection_name,
                data=[{
                    "vector": vector,
                    "title": title,
                    "description": description,
                    "tags": tags,
                    "file_name": file_path,
                    "image_hash": image_hash,
                }],
            )
            logger.info("Storing to Milvus Success ✅")
            return True
        except Exception as e:
            logger.exception("Failed to store sticker: %s", e)
            return False

    def search_stickers(self, description: str, limit: int = 2) -> List[Dict[str, Any]]:
        """Return the stickers most similar to *description*.

        Args:
            description: Free-text query, embedded with the same model as the
                stored descriptions.
            limit: Maximum number of hits to return.

        Returns:
            The hit list for the single query vector, as returned by
            ``MilvusClient.search`` with the title/description/tags/file_name
            output fields. Returns [] for an empty query or on error (logged).
        """
        if not description:
            return []
        try:
            text_vector = self.encode_text(description)
            logger.info("Searching Milvus - query: %s, limit: %s", description, limit)
            results = self.client.search(
                collection_name=self.collection_name,
                data=[text_vector],
                limit=limit,
                search_params={
                    "metric_type": "COSINE",
                },
                output_fields=["title", "description", "tags", "file_name"],
            )
            logger.info("Search Result: %s", results)
            # Exactly one query vector was sent, so only the first result list matters.
            return results[0] if results else []
        except Exception as e:
            logger.exception("Search failed: %s", e)
            return []

    def get_all_stickers(self, limit: int = 1000) -> List[Dict[str, Any]]:
        """Return up to *limit* stored stickers (empty filter matches everything).

        Returns [] on error (logged).
        """
        try:
            results = self.client.query(
                collection_name=self.collection_name,
                filter="",
                limit=limit,
                output_fields=["title", "description", "tags", "file_name", "image_hash"],
            )
            logger.info("Query All Stickers - limit: %s, results count: %s", limit, len(results))
            return results
        except Exception as e:
            logger.exception("Failed to get all stickers: %s", e)
            return []

    def check_image_exists(self, image_hash: str) -> bool:
        """Return True if a sticker with this image hash is already stored.

        An empty or None hash never matches and returns False without querying.
        Query errors are logged and also reported as False.
        """
        if not image_hash:
            return False
        try:
            # Escape the value instead of interpolating it raw, so a quote in the
            # hash cannot alter the filter expression.
            hash_filter = f"image_hash == {_quote_filter_literal(str(image_hash))}"
            results = self.client.query(
                collection_name=self.collection_name,
                filter=hash_filter,
                limit=1,
                output_fields=["file_name", "image_hash"],
            )
            exists = len(results) > 0
            logger.info("Check file exists - hash: %s, exists: %s, results: %s", image_hash, exists, results)
            return exists
        except Exception as e:
            logger.exception("Failed to check file exists: %s", e)
            return False

    def delete_sticker(self, sticker_id: int) -> str:
        """Delete the sticker whose primary key is *sticker_id*.

        Returns:
            A human-readable status message. Failures are logged and described
            in the returned message rather than raised.
        """
        try:
            logger.info("Deleting sticker - id: %s", sticker_id)
            res = self.client.delete(
                collection_name=self.collection_name,
                ids=[sticker_id],
            )
            logger.info("Deleted sticker - id: %s", sticker_id)
            logger.info("Delete result: %s", res)
            # NOTE(review): the success message is returned even if no entity
            # matched sticker_id; inspect `res` if callers need to know that.
            return f"Sticker with ID {sticker_id} deleted successfully"
        except Exception as e:
            logger.exception("Failed to delete sticker: %s", e)
            return f"Failed to delete sticker: {str(e)}"
# Shared module-level Database instance, created at import time; constructing it
# connects a MilvusClient and loads the SentenceTransformer model.
# NOTE(review): the original comment mentioned initializing the Milvus database,
# but init_collection() is not called here — presumably a caller does; confirm.
db: Database = Database()