# OmniThink/src/rm.py — last commit 80a598c ("push") by ZekunXi
import logging
import os
import urllib.parse
from typing import Callable, Union, List
import json
import dspy
import pandas as pd
import requests
import re
import uuid
import os
from tqdm import tqdm
from .utils import WebPageHelper
# Markdown-style links of the form ``[text](target)``. Both parts are
# non-greedy, so several links on one line are removed individually; ``.``
# does not match newlines, so a link must sit on a single line.
_MARKDOWN_LINK_RE = re.compile(r'\[.*?\]\(.*?\)')
# Bare http(s) URLs. NOTE: ``$-_`` inside the character class is a *range*
# (U+0024..U+005F), not three literal characters. That range is what lets the
# pattern match ``/``, ``:``, ``?``, ``=``, digits and upper-case letters, so it
# must not be "fixed" to literal ``$``, ``-``, ``_`` without rewriting the rest
# of the character set. Whitespace is not in the set, so a URL ends at a space.
_BARE_URL_RE = re.compile(
    r'http[s]?://(?:[a-zA-Z]|[0-9]|[$-_@.&+]|[!*\\(\\),]|(?:%[0-9a-fA-F][0-9a-fA-F]))+'
)
# Runs of two or more consecutive newlines.
_BLANK_LINES_RE = re.compile(r"\n\n+")


def clean_text(res):
    """Strip Markdown links and bare URLs from *res* and collapse blank lines.

    Args:
        res: The text to clean.

    Returns:
        A copy of *res* with every ``[text](target)`` link removed, every
        ``http://`` / ``https://`` URL removed, and each run of two or more
        newlines replaced by a single newline.
    """
    # Links are removed first so their targets are not left behind as
    # half-stripped URLs.
    result = _MARKDOWN_LINK_RE.sub('', res)
    result = _BARE_URL_RE.sub('', result)
    return _BLANK_LINES_RE.sub("\n", result)
class GoogleSearchAli_new(dspy.Retrieve):
    """Retriever backed by Alibaba Cloud's ``qwen-search`` web-search endpoint.

    Each query is POSTed once to :attr:`SEARCH_URL`. The returned documents are
    de-duplicated by URL and filtered through ``is_valid_source`` and the
    caller's ``exclude_urls``. They are then expanded into page snippets by
    ``WebPageHelper``.

    The API key is read from the ``searchkey`` environment variable.
    ``bing_search_api_key``, ``mkt``, ``language`` and ``**kwargs`` are accepted
    for signature compatibility but are not used. ``k`` is passed to
    ``dspy.Retrieve`` but does not affect the request, which always asks for
    ``rows`` = 10 results.
    """

    # Alibaba Cloud NLP gateway endpoint that serves the search requests.
    SEARCH_URL = "https://nlp-cn-beijing.aliyuncs.com/gw/v1/api/msearch-sp/qwen-search"

    def __init__(self, bing_search_api_key=None, k=3, is_valid_source: Callable = None,
                 min_char_count: int = 150, snippet_chunk_size: int = 1000, webpage_helper_max_threads=10,
                 mkt='en-US', language='en-US', *, timeout: float = 30.0, **kwargs):
        """Configure the request template and the snippet helper.

        Args:
            bing_search_api_key: Unused; the key comes from ``$searchkey``.
            k: Passed to ``dspy.Retrieve``.
            is_valid_source: Optional predicate ``url -> bool``. URLs for which
                it returns a falsy value are dropped. Defaults to accepting
                every URL.
            min_char_count: Forwarded to ``WebPageHelper``.
            snippet_chunk_size: Forwarded to ``WebPageHelper``.
            webpage_helper_max_threads: Forwarded to ``WebPageHelper`` as
                ``max_thread_num``.
            mkt: Unused.
            language: Unused.
            timeout: Seconds to wait for the search endpoint, passed to
                ``requests.post``. ``None`` waits indefinitely.
            **kwargs: Ignored.
        """
        super().__init__(k=k)
        key = os.environ.get('searchkey')
        if key is None:
            # Previously this fell back silently, which hid a misconfiguration.
            logging.warning("Environment variable 'searchkey' is not set; "
                            "falling back to the placeholder key 'default_value'.")
            key = 'default_value'
        self.header = {
            "Content-Type": "application/json",
            "Accept-Encoding": "utf-8",
            "Authorization": f"Bearer lm-/{key}== ",
        }
        # Base request body. ``uq`` (the query) is filled in per request on a
        # copy, so this dict is never mutated after construction. ``rid`` is
        # generated once per instance.
        self.template = {
            "rid": str(uuid.uuid4()),
            "scene": "dolphin_search_bing_nlp",
            "uq": "",
            "debug": True,
            "fields": [],
            "page": 1,
            "rows": 10,
            "customConfigInfo": {
                "multiSearch": False,
                "qpMultiQuery": False,
                "qpMultiQueryHistory": [],
                "qpSpellcheck": False,
                "qpEmbedding": False,
                "knnWithScript": False,
                "qpTermsWeight": False,
                "pluginServiceConfig": {"qp": "mvp_search_qp_qwen"},  # v3 rewrite
            },
            "headers": {"__d_head_qto": 5000},
        }
        self.timeout = timeout
        self.webpage_helper = WebPageHelper(
            min_char_count=min_char_count,
            snippet_chunk_size=snippet_chunk_size,
            max_thread_num=webpage_helper_max_threads
        )
        # Number of queries issued since the last get_usage_and_reset().
        self.usage = 0
        # If not None, is_valid_source takes a URL and returns a boolean.
        if is_valid_source:
            self.is_valid_source = is_valid_source
        else:
            self.is_valid_source = lambda x: True

    def get_usage_and_reset(self):
        """Return the query count since the last reset, then zero it.

        The count is reported under the ``'BingSearch'`` key.
        """
        usage = self.usage
        self.usage = 0
        return {'BingSearch': usage}

    def _search(self, query: str) -> List[dict]:
        """POST one query to the search endpoint and return its ``data.docs`` list.

        Raises whatever ``requests`` raises for network errors, timeouts and
        non-2xx statuses. Raises ``ValueError``/``KeyError`` when the body is
        not the expected JSON shape.
        """
        # Build a per-request payload instead of writing into the shared template.
        payload = {**self.template, "uq": query}
        response = requests.post(
            self.SEARCH_URL,
            data=json.dumps(payload),
            headers=self.header,
            timeout=self.timeout,
        )
        response.raise_for_status()
        return response.json()['data']['docs']

    def forward(self, query_or_queries: Union[str, List[str]], exclude_urls: Union[List[str], None] = None):
        """Search for one or more queries and return snippet-enriched results.

        Args:
            query_or_queries: A single query string or a list of queries.
            exclude_urls: URLs to leave out of the results. ``None`` excludes
                nothing.

        Returns:
            A list of dicts with keys ``'url'``, ``'title'``, ``'description'``
            and ``'snippets'``. There is one dict per URL for which
            ``WebPageHelper.urls_to_snippets`` returned snippets. A query that
            fails is logged and skipped rather than raised.
        """
        queries = (
            [query_or_queries]
            if isinstance(query_or_queries, str)
            else query_or_queries
        )
        self.usage += len(queries)
        excluded = set(exclude_urls or ())
        url_to_results = {}
        for query in queries:
            try:
                for doc in self._search(query):
                    url = doc['url']
                    # Apply the caller's filters; the original code accepted
                    # both parameters but never consulted them.
                    if url in excluded or not self.is_valid_source(url):
                        continue
                    url_to_results[url] = {
                        'url': url,
                        'title': doc['title'],
                        'description': doc.get('snippet', ''),
                    }
            except Exception as e:
                # Best effort: one failing query must not abort the others.
                logging.error(f'Error occurs when searching query {query}: {e}')
        valid_url_to_snippets = self.webpage_helper.urls_to_snippets(list(url_to_results))
        collected_results = []
        for url, page in valid_url_to_snippets.items():
            result = url_to_results[url]
            result['snippets'] = page['snippets']
            collected_results.append(result)
        logging.info('length of collected_results: %d', len(collected_results))
        return collected_results
if __name__ == "__main__":
    # Manual smoke test: run a single query ("中国", i.e. "China") and print the
    # results, which the original script discarded. Because this module uses
    # a relative import (``from .utils``), it has to be run as part of its
    # package, e.g. ``python -m src.rm`` from the repository root (TODO confirm
    # package layout).
    logging.basicConfig(level=logging.INFO)
    retrieval = GoogleSearchAli_new()
    results = retrieval("中国")
    print(results)