import logging
import os
import re
import urllib.parse
from typing import Callable, Union, List
import json
import dspy
import pandas as pd
import requests
from langchain_huggingface import HuggingFaceEmbeddings
from langchain_core.documents import Document
from langchain_qdrant import Qdrant
from qdrant_client import QdrantClient, models
from tqdm import tqdm

from utils import WebPageHelper


def clean_text(res):
    """Strip markdown links, bare URLs, and redundant blank lines from fetched text."""
    # Remove markdown-style links of the form [text](url).
    pattern = r'\[.*?\]\(.*?\)'
    result = re.sub(pattern, '', res)
    # Remove bare URLs.
    url_pattern = r'http[s]?://(?:[a-zA-Z]|[0-9]|[$-_@.&+]|[!*\\(\\),]|(?:%[0-9a-fA-F][0-9a-fA-F]))+'
    result = re.sub(url_pattern, '', result)
    # Collapse runs of blank lines into a single newline.
    result = re.sub(r"\n\n+", "\n", result)
    return result
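
# Illustrative example (not in the original module):
#   clean_text("See [docs](https://example.com) for details\n\n\n")
# removes the markdown link and collapses the blank lines, yielding roughly
# "See  for details\n".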

def get_jina(url, max_retries=3, headers=None):
    """Fetch a page's main content through the Jina Reader proxy (https://r.jina.ai).

    Note: the original code referenced an undefined module-level ``headers``;
    here it is an optional parameter (e.g. an Authorization header for Jina).
    """
    jina_url = "https://r.jina.ai/" + url
    text = ""
    for _ in range(max_retries):
        try:
            response = requests.get(jina_url, headers=headers or {})
            if response.ok:
                text = clean_text(response.json()['data']['content'])
                break
        except Exception:
            print("retrying!")
            continue
    return {"url": jina_url, "text": str(text)}


class BingSearchAli(dspy.Retrieve):
    def __init__(self, bing_search_api_key=None, k=3, is_valid_source: Callable = None,
                 min_char_count: int = 150, snippet_chunk_size: int = 1000, webpage_helper_max_threads=10,
                 mkt='en-US', language='en', **kwargs):
        """
        Params:
            min_char_count: Minimum character count for the article to be considered valid.
            snippet_chunk_size: Maximum character count for each snippet.
            webpage_helper_max_threads: Maximum number of threads to use for webpage helper.
            mkt, language, **kwargs: Bing search API parameters.
            - Reference: https://learn.microsoft.com/en-us/bing/search-apis/bing-web-search/reference/query-parameters
        """
        super().__init__(k=k)
        if not bing_search_api_key and not os.environ.get("BING_SEARCH_ALI_API_KEY"):
            raise RuntimeError(
                "You must supply bing_search_api_key or set environment variable BING_SEARCH_ALI_API_KEY")
        elif bing_search_api_key:
            self.bing_api_key = bing_search_api_key
        else:
            self.bing_api_key = os.environ["BING_SEARCH_ALI_API_KEY"]
        self.endpoint = "https://idealab.alibaba-inc.com/api/v1/search/search"
        # self.endpoint = "http://47.88.77.118:8080/api/search_web"

        
        self.count = k
        self.params = {
            'mkt': mkt,
            "setLang": language,
            "count": k,
            **kwargs
        }
        self.webpage_helper = WebPageHelper(
            min_char_count=min_char_count,
            snippet_chunk_size=snippet_chunk_size,
            max_thread_num=webpage_helper_max_threads
        )
        self.usage = 0

        # If not None, is_valid_source shall be a function that takes a URL and returns a boolean.
        if is_valid_source:
            self.is_valid_source = is_valid_source
        else:
            self.is_valid_source = lambda x: True

    def get_usage_and_reset(self):
        usage = self.usage
        self.usage = 0

        return {'BingSearch': usage}

    def forward(self, query_or_queries: Union[str, List[str]], exclude_urls: List[str] = []):
        """Search with Bing for self.k top passages for query or queries

        Args:
            query_or_queries (Union[str, List[str]]): The query or queries to search for.
            exclude_urls (List[str]): A list of urls to exclude from the search results.

        Returns:
            A list of dicts, each with the keys 'description', 'snippets' (a list of strings), 'title', and 'url'.
        """
        queries = (
            [query_or_queries]
            if isinstance(query_or_queries, str)
            else query_or_queries
        )
        self.usage += len(queries)

        url_to_results = {}

        payload_template = {
            "query": "pleaceholder",
            "num": self.count,
            "platformInput": {
                "model": "bing-search",
                "instanceVersion": "S1"
            }
        }
        header = {"X-AK": self.bing_api_key, "Content-Type": "application/json"}

        for query in queries:
            try:
                payload_template["query"] = query
                response = requests.post(
                    self.endpoint,
                    headers=header,
                    json=payload_template,
                ).json()
                print(response)
                search_results = response['data']['originalOutput']['webPages']['value']

                for result in search_results:
                    if self.is_valid_source(result['url']) and result['url'] not in exclude_urls:
                        url = result['url']
                        encoded_url = urllib.parse.quote(url, safe='')
                        file_path = os.path.join('/mnt/nas-alinlp/xizekun/project/storm/wikipages', f"{encoded_url}.json")
                        try:
                            with open(file_path, 'r') as data_file:
                                data = json.load(data_file)
                                result['snippet'] = data.get('snippet', result.get('snippet', ''))
                                print('Local wiki page found and used.')
                        except FileNotFoundError:
                            print('Local wiki page not found, using API result.')
                        
                        url_to_results[result['url']] = {
                            'url': result['url'],
                            'title': result['name'],
                            'description': result.get('snippet', '')
                        }
            except Exception as e:
                logging.error(f'Error occurs when searching query {query}: {e}')

        valid_url_to_snippets = self.webpage_helper.urls_to_snippets(list(url_to_results.keys()))
        collected_results = []
        for url in valid_url_to_snippets:
            r = url_to_results[url]
            r['snippets'] = valid_url_to_snippets[url]['snippets']
            collected_results.append(r)

        return collected_results
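

if __name__ == "__main__":
    # Minimal usage sketch, not part of the original module. It assumes the
    # BING_SEARCH_ALI_API_KEY environment variable is set and that the idealab
    # endpoint configured above is reachable from this machine.
    rm = BingSearchAli(k=3)
    results = rm.forward("large language models", exclude_urls=[])
    for r in results:
        print(r["title"], "-", r["url"])
    print(rm.get_usage_and_reset())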