# NekoAI-Lab / embedding_test.py
import logging
import time
from abc import ABC, abstractmethod
from typing import List, Dict, Any

import gradio as gr
from pymilvus import MilvusClient
from sentence_transformers import SentenceTransformer
# Configure logging
logging.basicConfig(
level=logging.INFO,
format='%(asctime)s %(levelname)s %(message)s'
)
logger = logging.getLogger(__name__)
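# Candidate embedding models benchmarked side by side for Chinese/multilingual sticker retrieval.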
models = [
'shibing624/text2vec-base-chinese',
'BAAI/bge-small-zh',
'BAAI/bge-base-zh',
'sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2',
'all-MiniLM-L6-v2',
'all-MiniLM-L12-v2',
'multi-qa-mpnet-base-dot-v1',
    # 'bge-small-en-v1.5',  # incompatible
'all-mpnet-base-v2',
'jinaai/jina-embeddings-v3',
]
searchers = {}
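# Abstract interface for an embedding backend: encode text and expose the vector dimension and model name.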
class BaseEmbeddingModel(ABC):
@abstractmethod
def encode(self, text: str) -> List[float]:
pass
@property
@abstractmethod
def dimension(self) -> int:
pass
@property
@abstractmethod
def model_name(self) -> str:
pass
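# BaseEmbeddingModel implementation backed by sentence-transformers (trust_remote_code is needed by models such as jinaai/jina-embeddings-v3).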
class SentenceTransformerModel(BaseEmbeddingModel):
def __init__(self, model_name: str):
self.model = SentenceTransformer(model_name, trust_remote_code=True)
self._model_name = model_name
def encode(self, text: str) -> List[float]:
result = self.model.encode(text).tolist()
return result
@property
def dimension(self) -> int:
return self.model.get_sentence_embedding_dimension()
@property
def model_name(self) -> str:
return self._model_name
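# Indexes sticker descriptions in a per-model collection inside the local Milvus Lite file (./sticker.db).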
class StickerSearcher:
def __init__(self, model: BaseEmbeddingModel):
self.model = model
self.client = MilvusClient(uri='./sticker.db')
self.collection_name = f'test_{model.model_name.replace("/", "_").replace("-", "_")}'
    def init_collection(self) -> bool:
        try:
            if self.client.has_collection(self.collection_name):
                self.client.drop_collection(collection_name=self.collection_name)
            # Quick-setup create_collection builds an index on the vector field and
            # loads the collection automatically; a custom index (e.g. IVF_SQ8 with
            # nlist=128) would need prepare_index_params() on a full Milvus server.
            self.client.create_collection(
                collection_name=self.collection_name,
                dimension=self.model.dimension,
                primary_field_name='id',
                auto_id=True,
                metric_type='COSINE'
            )
            self.client.load_collection(self.collection_name)
            logger.info(f'Collection initialized: {self.collection_name}')
            return True
        except Exception as e:
            logger.error(f'Collection init failed: {str(e)}')
            return False
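    # Embed the description and insert it, together with its metadata, into this model's collection.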
def store_vector(self, title: str, description: str, tags: List[str], file_path: str):
vector = self.model.encode(description)
data = [{
'vector': vector,
'title': title,
'description': description,
'tags': tags,
'file_name': file_path
}]
self.client.insert(self.collection_name, data)
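    # Encode the query, run a vector search, and log the encode/search timings.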
def search(self, query: str, limit: int = 5) -> List[Dict[str, Any]]:
start_time = time.time()
query_vector = self.model.encode(query)
encode_time = time.time() - start_time
start_search_time = time.time()
results = self.client.search(
collection_name=self.collection_name,
data=[query_vector],
limit=limit,
output_fields=['title', 'description', 'tags', 'file_name']
)
search_time = time.time() - start_search_time
total_time = encode_time + search_time
        logger.info(f'Model {self.model.model_name} encode time: {encode_time:.4f}s, search time: {search_time:.4f}s, total: {total_time:.4f}s')
return results[0]
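# Builds a Gradio UI that runs the same query against every loaded model and shows the results side by side.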
def create_gradio_ui():
async def search_model(model_name: str, query: str):
try:
if model_name in searchers:
return searchers[model_name].search(query)
logger.error(f'Model not loaded: {model_name}')
return []
except Exception as e:
logger.error(f'Search failed: {model_name} | Error: {str(e)}')
return []
async def search_all_models(query):
if not query:
return []
        print(f'>>>> Searching all models for query: {query}')
results = []
for model_name in models:
result = await search_model(model_name, query)
results.append(result)
formatted_results = []
max_results = max(len(r) for r in results)
for i in range(max_results):
row = [i + 1]
for model_results in results:
if i < len(model_results):
result = model_results[i]
image_url = f'https://huggingface.co./datasets/Nekoko/StickerSet/resolve/main/{result["entity"]["file_name"]}'
                        row.append(f'![Sticker]({image_url})\nSimilarity: {result["distance"]:.4f}')
else:
row.append('-')
formatted_results.append(row)
return formatted_results
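    # Loads stickers from the pre-populated 'stickers' collection and (re)indexes them with every model.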
def init_collections():
try:
client = MilvusClient(uri='./sticker.db')
stickers = client.query(
collection_name='stickers',
filter='',
limit=1000,
output_fields=['title', 'description', 'tags', 'file_name']
)
logger.info(f'Stickers loaded: {len(stickers)}')
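        # Create the per-model collection and store every sticker description embedded with that model.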
def init_model(model_name):
try:
searcher = StickerSearcher(SentenceTransformerModel(model_name))
if searcher.init_collection():
searchers[model_name] = searcher
for sticker in stickers:
searcher.store_vector(
sticker.get('title'),
sticker.get('description'),
sticker.get('tags'),
sticker.get('file_name')
)
logger.info(f'Model initialized: {model_name}')
except Exception as e:
logger.error(f'Model init failed: {model_name} | Error: {str(e)}')
for model_name in models:
            print(f'>>>> Initializing model {model_name}')
            start_time = time.time()
            init_model(model_name)
            print(f'>>>> Model {model_name} initialized ✅, took {time.time() - start_time:.4f}s')
        print('>>>> All models initialized ✅')
        return 'Initialization complete!'
except Exception as e:
logger.error(f'Data init failed: {str(e)}')
            return f'Initialization failed: {str(e)}'
with gr.Blocks(title='Neko Sticker Search 🔍', css='.gradio-container img { width: 200px !important; height: 200px !important; object-fit: contain; }') as demo:
with gr.Row():
            search_input = gr.Textbox(label='Search query')
            search_button = gr.Button('Search')
        headers = ['#'] + [f'🧊{model.split("/")[-1]}' for model in models]
results_table = gr.Dataframe(
headers=headers,
datatype=['number'] + ['markdown'] * len(models),
row_count=5,
col_count=len(models) + 1
)
        status_box = gr.Textbox(label='Status', interactive=False)
        refresh_button = gr.Button('Refresh data')
refresh_button.click(fn=init_collections, outputs=status_box)
        # Since this is just a simple search operation, the click handler can be wired up directly.
search_button.click(
fn=search_all_models,
inputs=[search_input],
outputs=results_table
)
return demo
if __name__ == '__main__':
    # Preload all models up front
    total_start = time.time()
    for model_name in models:
        try:
            start_time = time.time()
            searchers[model_name] = StickerSearcher(SentenceTransformerModel(model_name))
            print(f'>>>> Preloaded model {model_name} ✅, took {time.time() - start_time:.4f}s')
        except Exception as e:
            logger.error(f'Model preload failed: {model_name} | Error: {str(e)}')
    logger.info(f'>>>> All models preloaded ✅: {models}, total {time.time() - total_start:.4f}s')
demo = create_gradio_ui()
demo.launch()