# OmniThink/src/graph.py
import logging
import os
import urllib.parse
from typing import Callable, Union, List
import json
import re
import dspy
import pandas as pd
import requests
from langchain_huggingface import HuggingFaceEmbeddings
from langchain_core.documents import Document
from langchain_qdrant import Qdrant
from qdrant_client import QdrantClient, models
from tqdm import tqdm
from utils import WebPageHelper
def clean_text(res):
    """Strip Markdown links, bare URLs, and extra blank lines from text."""
    # Regex: match Markdown-style links of the form [text](url)
    pattern = r'\[.*?\]\(.*?\)'
    # Remove the matched links with re.sub()
    result = re.sub(pattern, '', res)
    # Remove any bare URLs left in the text
    url_pattern = r'http[s]?://(?:[a-zA-Z]|[0-9]|[$-_@.&+]|[!*\\(\\),]|(?:%[0-9a-fA-F][0-9a-fA-F]))+'
    result = re.sub(url_pattern, '', result)
    # Collapse runs of blank lines into single newlines
    result = re.sub(r"\n\n+", "\n", result)
    return result
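# Example of what clean_text removes (illustrative sketch):
#   clean_text("See [docs](https://example.com) at https://example.com/page")
#   -> 'See  at '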
# NOTE: `headers` is used below but was never defined in the original file.
# A minimal assumption: ask the Jina Reader API for JSON output; an
# Authorization header carrying a Jina API key can be added here if required.
headers = {"Accept": "application/json"}

def get_jina(url, max_retries=3):
    """Fetch a cleaned, readable version of `url` via the Jina Reader proxy."""
    url = "https://r.jina.ai/" + url
    response = ""
    for _ in range(max_retries):
        try:
            response = requests.get(url, headers=headers)
            if response.ok:
                response = response.json()['data']['content']
                response = clean_text(response)
                break
        except Exception:
            print("retrying!")
            continue
    return {"url": url, "text": str(response)}
class BingSearchAli(dspy.Retrieve):
def __init__(self, bing_search_api_key=None, k=3, is_valid_source: Callable = None,
min_char_count: int = 150, snippet_chunk_size: int = 1000, webpage_helper_max_threads=10,
mkt='en-US', language='en', **kwargs):
"""
Params:
min_char_count: Minimum character count for the article to be considered valid.
snippet_chunk_size: Maximum character count for each snippet.
webpage_helper_max_threads: Maximum number of threads to use for webpage helper.
mkt, language, **kwargs: Bing search API parameters.
- Reference: https://learn.microsoft.com/en-us/bing/search-apis/bing-web-search/reference/query-parameters
"""
super().__init__(k=k)
if not bing_search_api_key and not os.environ.get("BING_SEARCH_ALI_API_KEY"):
raise RuntimeError(
"You must supply bing_search_api_key or set environment variable BING_SEARCH_ALI_API_KEY")
elif bing_search_api_key:
self.bing_api_key = bing_search_api_key
else:
self.bing_api_key = os.environ["BING_SEARCH_ALI_API_KEY"]
self.endpoint = "https://idealab.alibaba-inc.com/api/v1/search/search"
# self.endpoint = "http://47.88.77.118:8080/api/search_web"
self.count = k
self.params = {
'mkt': mkt,
"setLang": language,
"count": k,
**kwargs
}
self.webpage_helper = WebPageHelper(
min_char_count=min_char_count,
snippet_chunk_size=snippet_chunk_size,
max_thread_num=webpage_helper_max_threads
)
self.usage = 0
# If not None, is_valid_source shall be a function that takes a URL and returns a boolean.
if is_valid_source:
self.is_valid_source = is_valid_source
else:
self.is_valid_source = lambda x: True
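        # Example of a custom validity filter (illustrative only):
        #   rm = BingSearchAli(k=5, is_valid_source=lambda u: 'wikipedia.org' in u)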
def get_usage_and_reset(self):
usage = self.usage
self.usage = 0
return {'BingSearch': usage}
def forward(self, query_or_queries: Union[str, List[str]], exclude_urls: List[str] = []):
"""Search with Bing for self.k top passages for query or queries
Args:
query_or_queries (Union[str, List[str]]): The query or queries to search for.
exclude_urls (List[str]): A list of urls to exclude from the search results.
        Returns:
            A list of dicts, each with keys 'description', 'snippets' (a list
            of strings), 'title', and 'url'.
        """
queries = (
[query_or_queries]
if isinstance(query_or_queries, str)
else query_or_queries
)
self.usage += len(queries)
url_to_results = {}
payload_template = {
"query": "pleaceholder",
"num": self.count,
"platformInput": {
"model": "bing-search",
"instanceVersion": "S1"
}
}
header = {"X-AK": self.bing_api_key, "Content-Type": "application/json"}
for query in queries:
try:
payload_template["query"] = query
response = requests.post(
self.endpoint,
headers=header,
json=payload_template,
).json()
                # Log the raw response for debugging instead of printing it
                logging.debug(response)
search_results = response['data']['originalOutput']['webPages']['value']
for result in search_results:
if self.is_valid_source(result['url']) and result['url'] not in exclude_urls:
url = result['url']
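                        # Prefer a locally cached copy of the page's snippet,
                        # stored under a URL-encoded filename on a shared NAS path.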
encoded_url = urllib.parse.quote(url, safe='')
file_path = os.path.join('/mnt/nas-alinlp/xizekun/project/storm/wikipages', f"{encoded_url}.json")
try:
with open(file_path, 'r') as data_file:
data = json.load(data_file)
result['snippet'] = data.get('snippet', result.get('snippet', ''))
print('Local wiki page found and used.')
except FileNotFoundError:
print('Local wiki page not found, using API result.')
url_to_results[result['url']] = {
'url': result['url'],
'title': result['name'],
'description': result.get('snippet', '')
}
except Exception as e:
                logging.error(f'Error occurred when searching query {query}: {e}')
valid_url_to_snippets = self.webpage_helper.urls_to_snippets(list(url_to_results.keys()))
collected_results = []
for url in valid_url_to_snippets:
r = url_to_results[url]
r['snippets'] = valid_url_to_snippets[url]['snippets']
collected_results.append(r)
return collected_results
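

if __name__ == '__main__':
    # Quick smoke test, a sketch only: it assumes BING_SEARCH_ALI_API_KEY is
    # set in the environment and that the idealab endpoint is reachable.
    rm = BingSearchAli(k=3)
    results = rm.forward('graph neural networks survey')
    for r in results:
        print(r['title'], r['url'], len(r['snippets']))
    print(rm.get_usage_and_reset())  # e.g. {'BingSearch': 1}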