Spaces:
Sleeping
Sleeping
import os | |
import logging | |
from abc import ABC, abstractmethod | |
from typing import List, Dict, Any | |
from sentence_transformers import SentenceTransformer | |
from pymilvus import MilvusClient, DataType | |
import time | |
import gradio as gr | |
# Configure root logging: every record shows its timestamp, level name and message.
logging.basicConfig(
    format='%(asctime)s %(levelname)s %(message)s',
    level=logging.INFO,
)
logger = logging.getLogger(__name__)
# Embedding models compared side by side: each one gets its own Milvus
# collection (StickerSearcher) and its own column in the results table.
models = [
    'shibing624/text2vec-base-chinese',
    'BAAI/bge-small-zh',
    'BAAI/bge-base-zh',
    'sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2',
    'all-MiniLM-L6-v2',
    'all-MiniLM-L12-v2',
    'multi-qa-mpnet-base-dot-v1',
    # 'bge-small-en-v1.5',  -- incompatible, disabled
    'all-mpnet-base-v2',
    'jinaai/jina-embeddings-v3',
]

# Model id -> StickerSearcher for that model. Filled by the preload in
# __main__ and replaced by init_collections() when data is refreshed.
searchers: Dict[str, 'StickerSearcher'] = {}
class BaseEmbeddingModel(ABC):
    """Interface for a text-embedding backend used by StickerSearcher.

    ``dimension`` and ``model_name`` are read-only properties: StickerSearcher
    reads them as attributes (``model.model_name.replace(...)``,
    ``dimension=self.model.dimension``), not as method calls.
    """

    @abstractmethod
    def encode(self, text: str) -> List[float]:
        """Return the embedding vector of ``text`` as a list of floats."""

    @property
    @abstractmethod
    def dimension(self) -> int:
        """Length of the vectors returned by :meth:`encode`."""

    @property
    @abstractmethod
    def model_name(self) -> str:
        """Identifier of the underlying model (the id it was loaded from)."""


class SentenceTransformerModel(BaseEmbeddingModel):
    """BaseEmbeddingModel backed by a ``sentence_transformers.SentenceTransformer``."""

    def __init__(self, model_name: str):
        # trust_remote_code=True allows repos that ship their own model code to
        # load; that code is executed, so only list trusted model ids.
        self.model = SentenceTransformer(model_name, trust_remote_code=True)
        self._model_name = model_name

    def encode(self, text: str) -> List[float]:
        """Embed ``text`` and convert the returned array to a plain Python list."""
        return self.model.encode(text).tolist()

    @property
    def dimension(self) -> int:
        """Embedding size reported by the loaded SentenceTransformer."""
        return self.model.get_sentence_embedding_dimension()

    @property
    def model_name(self) -> str:
        """The model id passed to the constructor."""
        return self._model_name
class StickerSearcher:
    """Stores and searches sticker descriptions for one embedding model.

    Each instance works on one Milvus collection in the local ``./sticker.db``
    file, named after the model. A row holds the embedding of a sticker's
    description plus its title, description, tags and file name.
    """

    def __init__(self, model: 'BaseEmbeddingModel'):
        self.model = model
        self.client = MilvusClient(uri='./sticker.db')
        self.collection_name = self._collection_name_for(model.model_name)

    @staticmethod
    def _collection_name_for(model_name: str) -> str:
        """Map a model id such as 'BAAI/bge-small-zh' to 'test_BAAI_bge_small_zh'.

        Every character other than an ASCII letter, digit or underscore is
        replaced with '_', because Milvus collection names are limited to those
        characters. This also covers ids containing '.', e.g. 'bge-small-en-v1.5'.
        Ids made only of letters, digits, '/' and '-' map to the same names as
        replacing just '/' and '-'.
        """
        safe = ''.join(
            ch if ch == '_' or (ch.isascii() and ch.isalnum()) else '_'
            for ch in model_name
        )
        return f'test_{safe}'

    def init_collection(self) -> bool:
        """Drop this model's collection if present, recreate it empty, and load it.

        Returns:
            True on success; False if any Milvus call raised (the error is logged).
        """
        try:
            self.client.drop_collection(collection_name=self.collection_name)
            # Quick-setup collection: auto-id primary key 'id' plus a vector field
            # of the model's dimension. The metric is given explicitly instead of
            # relying on the client's default. No separate create_index call is
            # made: quick setup indexes the vector field itself, and
            # MilvusClient.create_index expects an IndexParams object, not a dict.
            # NOTE(review): title/description/tags/file_name are inserted as
            # extra keys, which relies on the quick-setup schema accepting
            # dynamic fields - confirm for the pinned pymilvus version.
            self.client.create_collection(
                collection_name=self.collection_name,
                dimension=self.model.dimension,
                primary_field_name='id',
                metric_type='COSINE',
                auto_id=True
            )
            self.client.load_collection(self.collection_name)
            logger.info(f'Collection initialized: {self.collection_name}')
            return True
        except Exception as e:
            logger.error(f'Collection init failed: {str(e)}')
            return False

    def store_vector(self, title: str, description: str, tags: List[str], file_path: str):
        """Embed ``description`` and insert one sticker row into the collection.

        ``file_path`` is stored under the 'file_name' key, which the UI uses to
        build the sticker's image URL.
        """
        vector = self.model.encode(description)
        data = [{
            'vector': vector,
            'title': title,
            'description': description,
            'tags': tags,
            'file_name': file_path
        }]
        self.client.insert(self.collection_name, data)

    def search(self, query: str, limit: int = 5) -> List[Dict[str, Any]]:
        """Return the top ``limit`` hits for ``query`` and log encode/search timings.

        Each hit is a dict as returned by ``MilvusClient.search``; the UI reads
        ``hit['distance']`` and ``hit['entity']['file_name']`` from it.
        """
        encode_start = time.perf_counter()
        query_vector = self.model.encode(query)
        encode_time = time.perf_counter() - encode_start

        search_start = time.perf_counter()
        results = self.client.search(
            collection_name=self.collection_name,
            data=[query_vector],
            limit=limit,
            output_fields=['title', 'description', 'tags', 'file_name']
        )
        search_time = time.perf_counter() - search_start
        total_time = encode_time + search_time
        logger.info(f'模型 {self.model.model_name} Encoding耗时: {encode_time:.4f} 秒,搜索耗时: {search_time:.4f} 秒, 总耗时: {total_time:.4f} 秒')
        # Exactly one query vector was sent, so its hits are the first (only) entry.
        return results[0]
def create_gradio_ui():
    """Build the Gradio app that compares sticker search results across ``models``.

    The page has a query box and search button, a results table with a rank
    column plus one markdown column per model, and a "刷新数据" (refresh data)
    button that rebuilds every model's collection from the 'stickers'
    collection in ``./sticker.db``.

    Returns:
        The ``gr.Blocks`` app; the caller is responsible for launching it.
    """

    async def search_model(model_name: str, query: str):
        """Search ``query`` with one model; return its hits, or [] if unavailable or failing."""
        searcher = searchers.get(model_name)
        if searcher is None:
            logger.error(f'Model not loaded: {model_name}')
            return []
        try:
            # NOTE(review): search() is synchronous, so it blocks the event loop
            # while it runs; models are therefore searched one after another.
            return searcher.search(query)
        except Exception as e:
            logger.error(f'Search failed: {model_name} | Error: {str(e)}')
            return []

    async def search_all_models(query):
        """Run ``query`` against every model and lay the hits out as table rows.

        Row ``i`` holds the 1-based rank followed by, for each model, a markdown
        cell with the sticker image and its similarity score, or '-' when that
        model returned fewer than ``i + 1`` hits. An empty query yields [].
        """
        if not query:
            return []
        print(f'>>>> Searching From Models {query}')
        results = [await search_model(model_name, query) for model_name in models]

        # default=0 keeps an empty ``models`` list from raising ValueError.
        max_results = max((len(hits) for hits in results), default=0)
        formatted_results = []
        for i in range(max_results):
            row = [i + 1]
            for model_results in results:
                if i < len(model_results):
                    hit = model_results[i]
                    # NOTE(review): file_name is inserted into the URL as-is; names
                    # containing spaces or ')' would break the markdown image.
                    image_url = f'https://huggingface.co/datasets/Nekoko/StickerSet/resolve/main/{hit["entity"]["file_name"]}'
                    # Cells use the 'markdown' datatype, so the image renders above the score.
                    row.append(f'![]({image_url})\n相似度: {hit["distance"]:.4f}')
                else:
                    row.append('-')
            formatted_results.append(row)
        return formatted_results

    def init_model(model_name, stickers):
        """Recreate one model's collection, register its searcher and insert ``stickers``.

        Errors are logged, not raised, so one failing model does not stop the others.
        """
        try:
            searcher = StickerSearcher(SentenceTransformerModel(model_name))
            if not searcher.init_collection():
                # init_collection has already logged the Milvus error.
                return
            searchers[model_name] = searcher
            for sticker in stickers:
                searcher.store_vector(
                    sticker.get('title'),
                    sticker.get('description'),
                    sticker.get('tags'),
                    sticker.get('file_name')
                )
            logger.info(f'Model initialized: {model_name}')
        except Exception as e:
            logger.error(f'Model init failed: {model_name} | Error: {str(e)}')

    def init_collections():
        """Rebuild every model's collection from the 'stickers' collection.

        Returns:
            A status string for the UI: success, or the failure message if
            loading the stickers raised. Per-model failures are only logged.
        """
        try:
            client = MilvusClient(uri='./sticker.db')
            # Empty filter with limit=1000: read at most 1000 stickers.
            stickers = client.query(
                collection_name='stickers',
                filter='',
                limit=1000,
                output_fields=['title', 'description', 'tags', 'file_name']
            )
            logger.info(f'Stickers loaded: {len(stickers)}')
            for model_name in models:
                print(f'>>>> 初始化模型 {model_name}')
                start_time = time.time()
                init_model(model_name, stickers)
                print(f'>>>> 初始化模型 {model_name} 完成 ✅,耗时 {time.time() - start_time:.4f} 秒')
            print(f'>>>> 初始化所有模型完成 ✅')
            return '初始化成功!'
        except Exception as e:
            logger.error(f'Data init failed: {str(e)}')
            return f'初始化失败: {str(e)}'

    with gr.Blocks(title='Neko Sticker Search 🔍', css='.gradio-container img { width: 200px !important; height: 200px !important; object-fit: contain; }') as demo:
        with gr.Row():
            search_input = gr.Textbox(label='搜索关键词')
            search_button = gr.Button('搜索')
        # Rank column, then one column per model labelled with the last segment of its id.
        headers = ['序号'] + [f'🧊{model.split("/")[-1]}' for model in models]
        results_table = gr.Dataframe(
            headers=headers,
            datatype=['number'] + ['markdown'] * len(models),
            row_count=5,
            col_count=len(models) + 1
        )
        status_box = gr.Textbox(label='状态', interactive=False)
        refresh_button = gr.Button('刷新数据')
        refresh_button.click(fn=init_collections, outputs=status_box)
        # search_all_models is an async handler; Gradio accepts coroutine functions as event handlers.
        search_button.click(
            fn=search_all_models,
            inputs=[search_input],
            outputs=results_table
        )
    return demo
if __name__ == '__main__':
    # Preload every model up front so the first search does not pay the
    # model-loading cost. Only the model and the Milvus client are created here;
    # collections are rebuilt by the "刷新数据" button (init_collections), so
    # until then searches run against whatever ./sticker.db already contains.
    # A separate timer for the whole loop: the per-model timer is reset on each
    # iteration and would otherwise make the final total cover only the last model.
    preload_start = time.time()
    for model_name in models:
        try:
            model_start = time.time()
            searchers[model_name] = StickerSearcher(SentenceTransformerModel(model_name))
            print(f'>>>> 预加载模型 {model_name} 完成 ✅, 耗时 {time.time() - model_start:.4f} 秒')
        except Exception as e:
            logger.error(f'Model preload failed: {model_name} | Error: {str(e)}')
    logger.info(f'>>>> 预加载模型完成 ✅: {models}, 耗时 {time.time() - preload_start:.4f} 秒')
    demo = create_gradio_ui()
    demo.launch()