import json
import logging
import os
import re
import urllib.parse
from typing import Callable, List, Union

import dspy
import pandas as pd
import requests
from langchain_core.documents import Document
from langchain_huggingface import HuggingFaceEmbeddings
from langchain_qdrant import Qdrant
from qdrant_client import QdrantClient, models
from tqdm import tqdm

from utils import WebPageHelper
|
|
def clean_text(res):
    """Strip markdown links, bare URLs, and runs of blank lines from scraped text."""
    # Remove markdown-style links: [anchor text](url)
    link_pattern = r'\[.*?\]\(.*?\)'
    result = re.sub(link_pattern, '', res)
    # Remove bare URLs
    url_pattern = r'http[s]?://(?:[a-zA-Z]|[0-9]|[$-_@.&+]|[!*\\(\\),]|(?:%[0-9a-fA-F][0-9a-fA-F]))+'
    result = re.sub(url_pattern, '', result)
    # Collapse consecutive blank lines into a single newline
    result = re.sub(r"\n\n+", "\n", result)
    return result
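# Illustrative behavior of clean_text on a made-up input (shown as a comment,
# since this module has no test harness):
#   clean_text("See [docs](https://example.com) at https://example.com\n\n\nEnd")
#   returns "See  at \nEnd"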
|
|
def get_jina(url, max_retries=3):
    """Fetch a page as cleaned text through the Jina Reader endpoint (r.jina.ai)."""
    # Ask the Reader endpoint for a JSON response; the page text is expected
    # under data.content (assumed Reader API behavior).
    headers = {"Accept": "application/json"}
    reader_url = "https://r.jina.ai/" + url
    text = ""
    for _ in range(max_retries):
        try:
            response = requests.get(reader_url, headers=headers)
            if response.ok:
                text = clean_text(response.json()['data']['content'])
                break
        except Exception:
            logging.warning("get_jina request for %s failed, retrying", url)
            continue
    return {"url": url, "text": text}
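# Example use of get_jina (illustrative URL; requires network access to r.jina.ai):
#   page = get_jina("https://en.wikipedia.org/wiki/Information_retrieval")
#   print(page["text"][:200])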
|
|
class BingSearchAli(dspy.Retrieve):
    def __init__(self, bing_search_api_key=None, k=3, is_valid_source: Callable = None,
                 min_char_count: int = 150, snippet_chunk_size: int = 1000, webpage_helper_max_threads=10,
                 mkt='en-US', language='en', **kwargs):
        """
        Params:
            min_char_count: Minimum character count for the article to be considered valid.
            snippet_chunk_size: Maximum character count for each snippet.
            webpage_helper_max_threads: Maximum number of threads to use for the webpage helper.
            mkt, language, **kwargs: Bing search API parameters.
            - Reference: https://learn.microsoft.com/en-us/bing/search-apis/bing-web-search/reference/query-parameters
        """
        super().__init__(k=k)
        if not bing_search_api_key and not os.environ.get("BING_SEARCH_ALI_API_KEY"):
            raise RuntimeError(
                "You must supply bing_search_api_key or set environment variable BING_SEARCH_ALI_API_KEY")
        elif bing_search_api_key:
            self.bing_api_key = bing_search_api_key
        else:
            self.bing_api_key = os.environ["BING_SEARCH_ALI_API_KEY"]
        self.endpoint = "https://idealab.alibaba-inc.com/api/v1/search/search"

        self.count = k
        self.params = {
            'mkt': mkt,
            "setLang": language,
            "count": k,
            **kwargs
        }
        self.webpage_helper = WebPageHelper(
            min_char_count=min_char_count,
            snippet_chunk_size=snippet_chunk_size,
            max_thread_num=webpage_helper_max_threads
        )
        self.usage = 0

        # Default to accepting every source unless a validator is provided.
        if is_valid_source:
            self.is_valid_source = is_valid_source
        else:
            self.is_valid_source = lambda x: True
|
    def get_usage_and_reset(self):
        usage = self.usage
        self.usage = 0
        return {'BingSearch': usage}
|
    def forward(self, query_or_queries: Union[str, List[str]], exclude_urls: List[str] = None):
        """Search with Bing for self.k top passages for the query or queries.

        Args:
            query_or_queries (Union[str, List[str]]): The query or queries to search for.
            exclude_urls (List[str]): A list of urls to exclude from the search results.

        Returns:
            A list of dicts, each with keys 'description', 'snippets' (list of strings), 'title', and 'url'.
        """
        exclude_urls = exclude_urls or []
        queries = (
            [query_or_queries]
            if isinstance(query_or_queries, str)
            else query_or_queries
        )
        self.usage += len(queries)

        url_to_results = {}

        payload_template = {
            "query": "placeholder",
            "num": self.count,
            "platformInput": {
                "model": "bing-search",
                "instanceVersion": "S1"
            }
        }
        header = {"X-AK": self.bing_api_key, "Content-Type": "application/json"}

        for query in queries:
            try:
                payload_template["query"] = query
                response = requests.post(
                    self.endpoint,
                    headers=header,
                    json=payload_template,
                ).json()
                logging.debug(response)
                search_results = response['data']['originalOutput']['webPages']['value']

                for result in search_results:
                    if self.is_valid_source(result['url']) and result['url'] not in exclude_urls:
                        url = result['url']
                        # Prefer a locally cached copy of the page snippet when one exists.
                        encoded_url = urllib.parse.quote(url, safe='')
                        file_path = os.path.join('/mnt/nas-alinlp/xizekun/project/storm/wikipages', f"{encoded_url}.json")
                        try:
                            with open(file_path, 'r') as data_file:
                                data = json.load(data_file)
                                result['snippet'] = data.get('snippet', result.get('snippet', ''))
                            logging.info('Local wiki page found and used.')
                        except FileNotFoundError:
                            logging.info('Local wiki page not found, using API result.')

                        url_to_results[url] = {
                            'url': url,
                            'title': result['name'],
                            'description': result.get('snippet', '')
                        }
            except Exception as e:
                logging.error(f'Error occurs when searching query {query}: {e}')

        valid_url_to_snippets = self.webpage_helper.urls_to_snippets(list(url_to_results.keys()))
        collected_results = []
        for url in valid_url_to_snippets:
            r = url_to_results[url]
            r['snippets'] = valid_url_to_snippets[url]['snippets']
            collected_results.append(r)

        return collected_results
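
# A minimal usage sketch, assuming BING_SEARCH_ALI_API_KEY is set and the idealab
# endpoint is reachable from your network; the query below is made up for illustration.
if __name__ == "__main__":
    rm = BingSearchAli(k=3)
    results = rm.forward("large language model agents")
    for r in results:
        print(r['title'], r['url'], len(r['snippets']))
    print(rm.get_usage_and_reset())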