import logging
import os
import urllib.parse
from typing import Callable, Union, List
import json
import re
import dspy
import pandas as pd
import requests
from langchain_huggingface import HuggingFaceEmbeddings
from langchain_core.documents import Document
from langchain_qdrant import Qdrant
from qdrant_client import QdrantClient, models
from tqdm import tqdm
from utils import WebPageHelper
def clean_text(res):
    # Remove markdown-style links of the form [text](url).
    pattern = r'\[.*?\]\(.*?\)'
    result = re.sub(pattern, '', res)
    # Remove bare URLs.
    url_pattern = r'http[s]?://(?:[a-zA-Z]|[0-9]|[$-_@.&+]|[!*\\(\\),]|(?:%[0-9a-fA-F][0-9a-fA-F]))+'
    result = re.sub(url_pattern, '', result)
    # Collapse runs of blank lines into a single newline.
    result = re.sub(r"\n\n+", "\n", result)
    return result
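# A quick illustration of clean_text (assumed input, for reference only):
#   clean_text("See [Wiki](https://en.wikipedia.org/wiki/NLP) or https://example.com\n\n\nEnd")
#   -> "See  or \nEnd"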
def get_jina(url, max_retries=3):
    # Fetch page content through the Jina Reader API; the Accept header requests JSON output.
    # Assumption: no API key is set here; add an Authorization header if your deployment requires one.
    jina_url = "https://r.jina.ai/" + url
    headers = {"Accept": "application/json"}
    text = ""
    for _ in range(max_retries):
        try:
            response = requests.get(jina_url, headers=headers)
            if response.ok:
                text = clean_text(response.json()['data']['content'])
                break
        except Exception:
            print("retrying!")
            continue
    return {"url": url, "text": str(text)}
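# Usage sketch (assumed URL, for illustration only): fetch and clean one page via Jina Reader.
#   page = get_jina("https://en.wikipedia.org/wiki/Information_retrieval")
#   print(page["text"][:200])  # cleaned content, or "" if every retry failed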
class BingSearchAli(dspy.Retrieve):
def __init__(self, bing_search_api_key=None, k=3, is_valid_source: Callable = None,
min_char_count: int = 150, snippet_chunk_size: int = 1000, webpage_helper_max_threads=10,
mkt='en-US', language='en', **kwargs):
"""
Params:
min_char_count: Minimum character count for the article to be considered valid.
snippet_chunk_size: Maximum character count for each snippet.
webpage_helper_max_threads: Maximum number of threads to use for webpage helper.
mkt, language, **kwargs: Bing search API parameters.
- Reference: https://learn.microsoft.com/en-us/bing/search-apis/bing-web-search/reference/query-parameters
"""
super().__init__(k=k)
if not bing_search_api_key and not os.environ.get("BING_SEARCH_ALI_API_KEY"):
raise RuntimeError(
"You must supply bing_search_api_key or set environment variable BING_SEARCH_ALI_API_KEY")
elif bing_search_api_key:
self.bing_api_key = bing_search_api_key
else:
self.bing_api_key = os.environ["BING_SEARCH_ALI_API_KEY"]
self.endpoint = "https://idealab.alibaba-inc.com/api/v1/search/search"
# self.endpoint = "http://47.88.77.118:8080/api/search_web"
self.count = k
self.params = {
'mkt': mkt,
"setLang": language,
"count": k,
**kwargs
}
self.webpage_helper = WebPageHelper(
min_char_count=min_char_count,
snippet_chunk_size=snippet_chunk_size,
max_thread_num=webpage_helper_max_threads
)
self.usage = 0
# If not None, is_valid_source shall be a function that takes a URL and returns a boolean.
if is_valid_source:
self.is_valid_source = is_valid_source
else:
self.is_valid_source = lambda x: True
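    # Example is_valid_source (assumed, for illustration only): restrict results to Wikipedia pages.
    #   rm = BingSearchAli(bing_search_api_key="...", k=3,
    #                      is_valid_source=lambda url: "wikipedia.org" in url)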
def get_usage_and_reset(self):
usage = self.usage
self.usage = 0
return {'BingSearch': usage}
    def forward(self, query_or_queries: Union[str, List[str]], exclude_urls: List[str] = None):
        """Search with Bing for self.k top passages for the query or queries.

        Args:
            query_or_queries (Union[str, List[str]]): The query or queries to search for.
            exclude_urls (List[str]): A list of URLs to exclude from the search results.

        Returns:
            A list of dicts; each dict has keys 'description', 'snippets' (a list of strings), 'title', and 'url'.
        """
        exclude_urls = exclude_urls or []
queries = (
[query_or_queries]
if isinstance(query_or_queries, str)
else query_or_queries
)
self.usage += len(queries)
url_to_results = {}
payload_template = {
"query": "pleaceholder",
"num": self.count,
"platformInput": {
"model": "bing-search",
"instanceVersion": "S1"
}
}
header = {"X-AK": self.bing_api_key, "Content-Type": "application/json"}
for query in queries:
try:
payload_template["query"] = query
response = requests.post(
self.endpoint,
headers=header,
json=payload_template,
).json()
                logging.debug(response)
search_results = response['data']['originalOutput']['webPages']['value']
for result in search_results:
if self.is_valid_source(result['url']) and result['url'] not in exclude_urls:
url = result['url']
encoded_url = urllib.parse.quote(url, safe='')
file_path = os.path.join('/mnt/nas-alinlp/xizekun/project/storm/wikipages', f"{encoded_url}.json")
try:
with open(file_path, 'r') as data_file:
data = json.load(data_file)
result['snippet'] = data.get('snippet', result.get('snippet', ''))
print('Local wiki page found and used.')
except FileNotFoundError:
print('Local wiki page not found, using API result.')
url_to_results[result['url']] = {
'url': result['url'],
'title': result['name'],
'description': result.get('snippet', '')
}
except Exception as e:
                logging.error(f'Error occurred while searching query {query}: {e}')
valid_url_to_snippets = self.webpage_helper.urls_to_snippets(list(url_to_results.keys()))
collected_results = []
for url in valid_url_to_snippets:
r = url_to_results[url]
r['snippets'] = valid_url_to_snippets[url]['snippets']
collected_results.append(r)
return collected_results
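if __name__ == "__main__":
    # Minimal usage sketch (assumed key and query, for illustration only).
    rm = BingSearchAli(bing_search_api_key="YOUR_KEY", k=3)
    results = rm.forward("large language models")
    for r in results:
        print(r["url"], r["title"], len(r["snippets"]))
    print(rm.get_usage_and_reset())  # e.g. {'BingSearch': 1}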