import concurrent.futures
import json
import logging
import os
import pickle
import re
import sys
from concurrent.futures import as_completed
from http import HTTPStatus
from typing import Any, Dict, List, Optional, Tuple, Union

import dashscope
import dspy
import httpx
import matplotlib.pyplot as plt
import networkx as nx
import numpy as np
import toml
from langchain_text_splitters import RecursiveCharacterTextSplitter
from sentence_transformers import SentenceTransformer
from sklearn.metrics.pairwise import cosine_similarity
from trafilatura import extract

try:
    from streamlit.runtime.scriptrunner import add_script_run_ctx
    streamlit_connection = True
except ImportError:
    streamlit_connection = False

script_dir = os.path.dirname(os.path.abspath(__file__))

class ArticleTextProcessing:
    @staticmethod
    def limit_word_count_preserve_newline(input_string, max_word_count):
        """
        Limit the word count of an input string to a specified maximum, while preserving the integrity of complete lines.

        The function truncates the input string at the nearest word that does not exceed the maximum word count,
        ensuring that no partial lines are included in the output. Words are defined as text separated by spaces,
        and lines are defined as text separated by newline characters.

        Args:
            input_string (str): The string to be truncated. This string may contain multiple lines.
            max_word_count (int): The maximum number of words allowed in the truncated string.

        Returns:
            str: The truncated string with word count limited to `max_word_count`, preserving complete lines.
        """
        word_count = 0
        limited_string = ''

        for line in input_string.split('\n'):
            line_words = line.split()
            for lw in line_words:
                if word_count < max_word_count:
                    limited_string += lw + ' '
                    word_count += 1
                else:
                    break
            if word_count >= max_word_count:
                break
            limited_string = limited_string.strip() + '\n'

        return limited_string.strip()

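    # Illustrative example (not part of the original module): with a budget of
    # three words, only the first complete line survives.
    #   >>> ArticleTextProcessing.limit_word_count_preserve_newline(
    #   ...     "one two three\nfour five six", 3)
    #   'one two three'
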
    @staticmethod
    def remove_citations(s):
        """
        Removes all citations from a given string. Citations are assumed to be in the format
        of numbers enclosed in square brackets, such as [1], [2], or [1, 2], etc. This function searches
        for all occurrences of such patterns and removes them, returning the cleaned string.

        Args:
            s (str): The string from which citations are to be removed.

        Returns:
            str: The string with all citation patterns removed.
        """
        return re.sub(r'\[\d+(?:,\s*\d+)*\]', '', s)

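    # Illustrative example (not part of the original module): only the bracketed
    # citations are removed; the surrounding whitespace is left untouched.
    #   >>> ArticleTextProcessing.remove_citations('A result [1, 2] was shown [3].')
    #   'A result  was shown .'
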
    @staticmethod
    def get_first_section_dict_and_list(s):
        """
        Split an article into its top-level sections.

        Sections are expected to be delimited by lines starting with '# '. Returns a
        dictionary mapping each section title to its content, together with the list
        of titles in their order of appearance. Note that the very first title keeps
        its leading '# ' marker, since the text is split on '\\n# '.
        """
        text = s
        sections = text.strip().split('\n# ')
        titles = []
        content_dict = {}

        for section in sections:
            if section:
                lines = section.split('\n', 1)
                title = lines[0].strip()
                content = lines[1].strip() if len(lines) > 1 else ""

                titles.append(title)
                content_dict[title] = content
        return content_dict, titles

    @staticmethod
    def parse_citation_indices(s):
        """
        Extracts citation indexes from the provided content string and returns them as a list of integers.

        Args:
            s (str): The content string containing citations in the format [number].

        Returns:
            List[int]: A list of citation indexes extracted from the content, in the order
                they appear (duplicates are preserved).
        """
        matches = re.findall(r'\[\d+\]', s)
        return [int(index[1:-1]) for index in matches]

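    # Illustrative example (not part of the original module):
    #   >>> ArticleTextProcessing.parse_citation_indices('A claim [2]. Another [1] [2].')
    #   [2, 1, 2]
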
    @staticmethod
    def remove_uncompleted_sentences_with_citations(text):
        """
        Removes uncompleted sentences and standalone citations from the input text. Sentences are identified
        by their ending punctuation (.!?), optionally followed by a citation in square brackets (e.g., "[1]").
        Grouped citations (e.g., "[1, 2]") are split into individual ones (e.g., "[1] [2]"). Only text up to
        and including the last complete sentence and its citation is retained.

        Args:
            text (str): The input text from which uncompleted sentences and their citations are to be removed.

        Returns:
            str: The processed string with uncompleted sentences and standalone citations removed, leaving only
                complete sentences and their associated citations if present.
        """

        def replace_with_individual_brackets(match):
            numbers = match.group(1).split(', ')
            return ' '.join(f'[{n}]' for n in numbers)

        def deduplicate_group(match):
            citations = match.group(0)
            unique_citations = list(set(re.findall(r'\[\d+\]', citations)))
            sorted_citations = sorted(unique_citations, key=lambda x: int(x.strip('[]')))
            return ''.join(sorted_citations)

        # Split grouped citations such as "[1, 2]" into "[1] [2]", then deduplicate
        # runs of adjacent citations such as "[1][1][2]".
        text = re.sub(r'\[([0-9, ]+)\]', replace_with_individual_brackets, text)
        text = re.sub(r'(\[\d+\])+', deduplicate_group, text)

        # Keep everything up to (and including) the last sentence-ending punctuation
        # mark and its optional trailing citation.
        eos_pattern = r'([.!?])\s*(\[\d+\])?\s*'
        matches = list(re.finditer(eos_pattern, text))
        if matches:
            last_match = matches[-1]
            text = text[:last_match.end()].strip()

        return text

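    # Illustrative example (not part of the original module): grouped citations are
    # split and the dangling fragment after the last full sentence is dropped.
    #   >>> ArticleTextProcessing.remove_uncompleted_sentences_with_citations(
    #   ...     "The effect was replicated in later work [1, 2]. It also sugg")
    #   'The effect was replicated in later work [1] [2].'
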
    @staticmethod
    def clean_up_citation(conv):
        for turn in conv.dlg_history:
            # Drop trailing "References:" / "Sources:" blocks only if present; an
            # unguarded str.find() would return -1 and silently chop the last character.
            if 'References:' in turn.agent_utterance:
                turn.agent_utterance = turn.agent_utterance[:turn.agent_utterance.find('References:')]
            if 'Sources:' in turn.agent_utterance:
                turn.agent_utterance = turn.agent_utterance[:turn.agent_utterance.find('Sources:')]
            turn.agent_utterance = turn.agent_utterance.replace('Answer:', '').strip()
            try:
                max_ref_num = max([int(x) for x in re.findall(r'\[(\d+)\]', turn.agent_utterance)])
            except Exception:
                max_ref_num = 0
            if max_ref_num > len(turn.search_results):
                for i in range(len(turn.search_results), max_ref_num + 1):
                    turn.agent_utterance = turn.agent_utterance.replace(f'[{i}]', '')
            turn.agent_utterance = ArticleTextProcessing.remove_uncompleted_sentences_with_citations(
                turn.agent_utterance)

        return conv

    @staticmethod
    def clean_up_outline(outline, topic=""):
        output_lines = []
        current_level = 0  # Track the level of the current heading.

        for line in outline.split('\n'):
            stripped_line = line.strip()

            # Restart collection if the topic itself appears as a heading.
            if topic != "" and f"# {topic.lower()}" in stripped_line.lower():
                output_lines = []

            # Keep heading lines.
            if stripped_line.startswith('#') and stripped_line != '#':
                current_level = stripped_line.count('#')
                output_lines.append(stripped_line)
            # Keep lines starting with '@' (query annotations).
            elif stripped_line.startswith('@'):
                output_lines.append(stripped_line)

        outline = '\n'.join(output_lines)

        # Remove sections that should not appear in a generated article.
        outline = re.sub(r"#[#]? See also.*?(?=##|$)", '', outline, flags=re.DOTALL)
        outline = re.sub(r"#[#]? See Also.*?(?=##|$)", '', outline, flags=re.DOTALL)
        outline = re.sub(r"#[#]? Notes.*?(?=##|$)", '', outline, flags=re.DOTALL)
        outline = re.sub(r"#[#]? References.*?(?=##|$)", '', outline, flags=re.DOTALL)
        outline = re.sub(r"#[#]? External links.*?(?=##|$)", '', outline, flags=re.DOTALL)
        outline = re.sub(r"#[#]? External Links.*?(?=##|$)", '', outline, flags=re.DOTALL)
        outline = re.sub(r"#[#]? Bibliography.*?(?=##|$)", '', outline, flags=re.DOTALL)
        outline = re.sub(r"#[#]? Further reading.*?(?=##|$)", '', outline, flags=re.DOTALL)
        outline = re.sub(r"#[#]? Further Reading.*?(?=##|$)", '', outline, flags=re.DOTALL)
        outline = re.sub(r"#[#]? Summary.*?(?=##|$)", '', outline, flags=re.DOTALL)
        outline = re.sub(r"#[#]? Appendices.*?(?=##|$)", '', outline, flags=re.DOTALL)
        outline = re.sub(r"#[#]? Appendix.*?(?=##|$)", '', outline, flags=re.DOTALL)

        return outline

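    # Illustrative example (not part of the original module): boilerplate sections
    # such as "References" and "Further reading" are pruned from the outline.
    #   >>> ArticleTextProcessing.clean_up_outline(
    #   ...     "# Topic\n## History\n## References\n## Further reading")
    #   '# Topic\n## History\n'
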
    @staticmethod
    def clean_up_section(text):
        """Clean up a section:
        1. Remove uncompleted sentences (usually due to output token limitation).
        2. Deduplicate individual groups of citations.
        3. Remove unnecessary summary."""

        paragraphs = text.split('\n')
        output_paragraphs = []
        summary_sec_flag = False
        for p in paragraphs:
            p = p.strip()
            if len(p) == 0:
                continue
            if not p.startswith('#'):
                p = ArticleTextProcessing.remove_uncompleted_sentences_with_citations(p)
            if summary_sec_flag:
                if p.startswith('#'):
                    summary_sec_flag = False
                else:
                    continue
            if p.startswith('Overall') or p.startswith('In summary') or p.startswith('In conclusion'):
                continue
            if "# Summary" in p or '# Conclusion' in p:
                summary_sec_flag = True
                continue
            output_paragraphs.append(p)

        return '\n\n'.join(output_paragraphs)

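    # Illustrative example (not part of the original module): trailing fragments and
    # the summary section are both removed.
    #   >>> ArticleTextProcessing.clean_up_section(
    #   ...     "The band toured widely [1] [2]. It also\n# Summary\nOverall, a success.")
    #   'The band toured widely [1] [2].'
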
    @staticmethod
    def update_citation_index(s, citation_map):
        """Update citation index in the string based on the citation map."""
        for original_citation in citation_map:
            s = s.replace(f"[{original_citation}]", f"__PLACEHOLDER_{original_citation}__")
        for original_citation, unify_citation in citation_map.items():
            s = s.replace(f"__PLACEHOLDER_{original_citation}__", f"[{unify_citation}]")

        return s

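    # Illustrative example (not part of the original module): the intermediate
    # placeholders make overlapping mappings safe (a direct replace of [1] -> [2]
    # followed by [2] -> [1] would clobber the first substitution).
    #   >>> ArticleTextProcessing.update_citation_index("A [1]. B [2].", {1: 2, 2: 1})
    #   'A [2]. B [1].'
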
    @staticmethod
    def parse_article_into_dict(input_string):
        """
        Parses a structured text into a nested dictionary. The structure of the text
        is defined by markdown-like headers (using '#' symbols) to denote sections
        and subsections. Each section can contain content and further nested subsections.

        The resulting dictionary captures the hierarchical structure of sections, where
        each section is represented as a key (the section's title) mapping to a value
        that is another dictionary. This dictionary contains two keys:
        - 'content': content of the section
        - 'subsections': a dictionary of nested subsections, each following the same structure.

        Args:
            input_string (str): A string containing the structured text to parse.

        Returns:
            A dictionary mapping each top-level section title to a dictionary with the
            'content' and 'subsections' keys as described above.
        """
        lines = input_string.split('\n')
        lines = [line for line in lines if line.strip()]
        root = {'content': '', 'subsections': {}}
        current_path = [(root, -1)]

        for line in lines:
            if line.startswith('#'):
                level = line.count('#')
                title = line.strip('# ').strip()
                new_section = {'content': '', 'subsections': {}}

                # Pop back out of deeper (or equal) levels before attaching the new section.
                while current_path and current_path[-1][1] >= level:
                    current_path.pop()

                current_path[-1][0]['subsections'][title] = new_section
                current_path.append((new_section, level))
            else:
                current_path[-1][0]['content'] += line + '\n'

        return root['subsections']

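    # Illustrative example (not part of the original module):
    #   >>> ArticleTextProcessing.parse_article_into_dict("# A\nintro\n## B\ndetail")
    #   {'A': {'content': 'intro\n',
    #          'subsections': {'B': {'content': 'detail\n', 'subsections': {}}}}}
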
class FileIOHelper:
    @staticmethod
    def dump_json(obj, file_name, encoding="utf-8"):
        with open(file_name, 'w', encoding=encoding) as fw:
            json.dump(obj, fw, default=FileIOHelper.handle_non_serializable, ensure_ascii=False)

    @staticmethod
    def handle_non_serializable(obj):
        return "non-serializable contents"

    @staticmethod
    def load_json(file_name, encoding="utf-8"):
        with open(file_name, 'r', encoding=encoding) as fr:
            return json.load(fr)

    @staticmethod
    def write_str(s, path):
        with open(path, 'w') as f:
            f.write(s)

    @staticmethod
    def load_str(path):
        with open(path, 'r') as f:
            # f.read() keeps the original line breaks; joining readlines() with '\n'
            # would double every newline.
            return f.read()

    @staticmethod
    def dump_pickle(obj, path):
        with open(path, 'wb') as f:
            pickle.dump(obj, f)

    @staticmethod
    def load_pickle(path):
        with open(path, 'rb') as f:
            return pickle.load(f)

class ConceptGenerator(dspy.Module):
    """Extract information and generate a list of concepts."""

    def __init__(self, lm: Union[dspy.dsp.LM, dspy.dsp.HFModel]):
        super().__init__()
        self.lm = lm
        self.concept_generator = dspy.Predict(GenConcept)

    def forward(self, infos: List[Dict]):
        # Collect all snippets from the retrieved information.
        snippets_list = []
        for info in infos:
            snippet = info.get('snippets', [])
            snippets_list.extend(snippet)

        snippets_list_str = "\n".join(f"{index + 1}. {snippet}" for index, snippet in enumerate(snippets_list))
        snippets_list_str = ArticleTextProcessing.limit_word_count_preserve_newline(snippets_list_str, 3000)

        with dspy.settings.context(lm=self.lm):
            concepts = self.concept_generator(info=snippets_list_str).concepts

        # The LM is asked to answer as a numbered list ("1. ...", "2. ..."); parse it back.
        pattern = r"\d+\.\s*(.*)"
        matches = re.findall(pattern, concepts)
        concept_list = [match.strip() for match in matches]

        return concept_list

class ExtendConcept(dspy.Signature):
    """
    You are an analytical robot. I will provide you with a subject, the information I have searched about it, and our preliminary concept of it. I need you to generate a detailed, in-depth, and insightful report based on it, further exploring our initial ideas.

    First, break down the subject into several broad categories, then create corresponding search engine keywords for each category.

    Note: The new categories should not repeat the previous ones.

    Your output format should be as follows:
    -[Category 1]
    --{Keyword 1}
    --{Keyword 2}
    -[Category 2]
    --{Keyword 1}
    --{Keyword 2}
    The number of categories should be less than 5, and the number of keywords for each category should be less than 3.
    """
    info = dspy.InputField(prefix='The information you have collected from the webpage:', format=str)
    concept = dspy.InputField(prefix='The summary of the previous concepts:', format=str)
    category = dspy.InputField(prefix='The broader categories you need to further expand:', format=str)
    keywords = dspy.OutputField(format=str)

class GenConcept(dspy.Signature):
    """
    Please analyze, summarize, and evaluate the following webpage information.
    Think like a person, distill the core point of each piece of information, and synthesize them into a comprehensive opinion.
    Present your comprehensive opinion in the format of 1. 2. ...
    """
    info = dspy.InputField(prefix='The webpage information you have collected:', format=str)
    concepts = dspy.OutputField(format=str)

class MindPoint():
    def __init__(self, retriever, lm: Union[dspy.dsp.LM, dspy.dsp.HFModel], root: bool = False,
                 children: Optional[Dict[str, 'MindPoint']] = None, concept: str = '',
                 info: Optional[List[Dict]] = None, category: str = ''):
        self.root = root
        self.category = category
        self.children = children if children is not None else {}
        self.concept = concept
        self.info = info if info is not None else []
        self.lm = lm
        self.retriever = retriever
        self.concept_generator = ConceptGenerator(lm=lm)

    def extend(self):
        extend_concept = dspy.Predict(ExtendConcept)
        with dspy.settings.context(lm=self.lm):
            info = '\n'.join([str(i) for i in self.info])
            keywords = extend_concept(info=info, concept=self.concept, category=self.category).keywords
        print(keywords)
        print('-----keywords------')

        # Parse the "-[Category]" / "--{Keyword}" format requested by ExtendConcept.
        categories = {}
        current_category = None
        for line in keywords.split('\n'):
            line = line.strip()
            if (line.startswith('-[') and line.endswith(']')) or (line.startswith('- [') and line.endswith(']')):
                current_category = line[line.index('[') + 1:-1]
                categories[current_category] = []
            elif line.startswith('--') and current_category:
                keyword = line.lstrip('-').strip().strip('{}').strip()
                if keyword:
                    categories[current_category].append(keyword)

        # Retrieve and summarize information for each new category, attaching the
        # result as a child node.
        for category, keywords_list in categories.items():
            new_info = self.retriever(keywords_list)
            new_concept = self.concept_generator.forward(new_info)
            new_node = MindPoint(concept=new_concept, info=new_info, lm=self.lm,
                                 retriever=self.retriever, category=category)
            self.children[category] = new_node

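# Illustrative sketch (not part of the original module): the keyword block emitted
# by ExtendConcept is plain text such as
#   -[Early life]
#   --{Taylor Hawkins childhood}
#   --{Taylor Hawkins education}
# which MindPoint.extend() parses into
#   {'Early life': ['Taylor Hawkins childhood', 'Taylor Hawkins education']}
# before retrieving and summarizing each category as a child node.
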
class MindMap():
    def __init__(self,
                 retriever,
                 gen_concept_lm: Union[dspy.dsp.LM, dspy.dsp.HFModel],
                 gen_concept_lm2: Union[dspy.dsp.LM, dspy.dsp.HFModel],
                 search_top_k: int,
                 depth: int
                 ):
        self.retriever = retriever
        self.gen_concept_lm = gen_concept_lm
        self.gen_concept_lm2 = gen_concept_lm2
        self.search_top_k = search_top_k
        self.depth = depth
        self.concept_generator = ConceptGenerator(lm=self.gen_concept_lm)
        self.root = None

    def build_map(self, topic: str):
        # Build the root node from a direct search on the topic.
        root_info = self.retriever(topic)
        root_concept = self.concept_generator(root_info)
        root = MindPoint(root=True, info=root_info, concept=root_concept, lm=self.gen_concept_lm2,
                         retriever=self.retriever, category=topic)
        self.root = root

        current_level = [root]

        # Breadth-first expansion up to self.depth levels. The current level is yielded
        # before and after each round of extension so callers can track progress.
        for count in range(self.depth):
            next_level = []

            yield current_level
            if count == self.depth - 1:
                break

            with concurrent.futures.ThreadPoolExecutor() as executor:
                futures = {executor.submit(node.extend): node for node in current_level}

                for future in concurrent.futures.as_completed(futures):
                    node = futures[future]
                    next_level.extend(node.children.values())

            yield current_level
            current_level = next_level

    def recursive_extend(self, node: MindPoint, count: int):
        if count >= self.depth:
            return
        node.extend()

        # Extend each child one level deeper; the executor's context manager waits
        # for all submitted futures before returning.
        with concurrent.futures.ThreadPoolExecutor() as executor:
            futures = [executor.submit(self.recursive_extend, child, count + 1)
                       for child in node.children.values()]

    def save_map(self, root: MindPoint, filename: str):
        def serialize_node(node: MindPoint):
            return {
                'category': node.category,
                'concept': node.concept,
                'children': {k: serialize_node(v) for k, v in node.children.items()},
                'info': node.info,
            }

        mind_map_dict = serialize_node(root)
        with open(filename, 'w', encoding='utf-8') as f:
            json.dump(mind_map_dict, f, ensure_ascii=False, indent=2)

    def load_map(self, filename: str):
        def deserialize_node(node_data):
            category = node_data['category']
            concept = node_data['concept']
            info = node_data['info']
            children_data = node_data['children']

            node = MindPoint(concept=concept, info=info, lm=self.gen_concept_lm,
                             retriever=self.retriever, category=category)
            node.children = {k: deserialize_node(v) for k, v in children_data.items()}
            return node

        with open(filename, 'r', encoding='utf-8') as f:
            mind_map_dict = json.load(f)

        self.root = deserialize_node(mind_map_dict)
        return self.root

    def export_categories_and_concepts(self) -> str:
        root = self.root
        output = []

        def traverse(node: MindPoint, indent=0):
            output.append(" " * indent + node.category)
            for concept in node.concept:
                output.append(" " * (indent + 2) + concept)
            for child in node.children.values():
                traverse(child, indent + 2)

        traverse(root)
        return "\n".join(output)

    def get_all_infos(self) -> List[Dict[str, Any]]:
        """
        Get all unique info from the MindMap, ensuring unique URLs.
        """
        all_infos = []
        seen_urls = set()

        def traverse(node: MindPoint):
            if node.info:
                for info in node.info:
                    url = info.get('url')
                    if url and url not in seen_urls:
                        seen_urls.add(url)
                        all_infos.append(info)
            for child in node.children.values():
                traverse(child)

        traverse(self.root)
        self.all_infos = all_infos
        return all_infos

    def get_web_number(self):
        self.collected_urls = []
        self.collected_snippets = []
        seen_urls = set()

        for info in self.get_all_infos():
            url = info.get('url')
            snippets = info.get('snippets', [])
            if url and url not in seen_urls:
                seen_urls.add(url)
                for snippet in snippets:
                    self.collected_urls.append(url)
                    self.collected_snippets.append(snippet)

        return len(self.collected_snippets)

    def prepare_table_for_retrieval(self):
        """
        Prepare collected snippets and URLs for retrieval by encoding the snippets using paraphrase-MiniLM-L6-v2.
        collected_urls and collected_snippets have corresponding indices.
        """
        self.encoder = SentenceTransformer('./model/paraphrase-MiniLM-L6-v2')
        self.collected_urls = []
        self.collected_snippets = []
        seen_urls = set()

        for info in self.get_all_infos():
            url = info.get('url')
            snippets = info.get('snippets', [])
            if url and url not in seen_urls:
                seen_urls.add(url)
                for snippet in snippets:
                    self.collected_urls.append(url)
                    self.collected_snippets.append(snippet)

        self.encoded_snippets = self.encoder.encode(self.collected_snippets, show_progress_bar=True)

    def retrieve_information(self, queries: Union[List[str], str], search_top_k) -> List[Dict[str, Any]]:
        """
        Retrieve relevant information based on the given queries.
        Returns a list of dictionaries containing 'url' and 'snippets'.
        """
        selected_urls = []
        selected_snippets = []
        if isinstance(queries, str):
            queries = [queries]
        for query in queries:
            # Rank all collected snippets by cosine similarity to the query and keep the top-k.
            encoded_query = self.encoder.encode(query, show_progress_bar=True)
            sim = cosine_similarity([encoded_query], self.encoded_snippets)[0]
            sorted_indices = np.argsort(sim)
            for i in sorted_indices[-search_top_k:][::-1]:
                selected_urls.append(self.collected_urls[i])
                selected_snippets.append(self.collected_snippets[i])

        # Group the selected snippets by their source URL.
        url_to_snippets = {}
        for url, snippet in zip(selected_urls, selected_snippets):
            if url not in url_to_snippets:
                url_to_snippets[url] = set()
            url_to_snippets[url].add(snippet)

        result = []
        for url, snippets in url_to_snippets.items():
            result.append({
                'url': url,
                'snippets': list(snippets)
            })

        return result

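    # Illustrative usage sketch (not part of the original module; the query string is
    # a hypothetical example): once the map is built, snippets are embedded once and
    # can then be queried repeatedly.
    #   mind_map.prepare_table_for_retrieval()
    #   hits = mind_map.retrieve_information("early drumming career", search_top_k=3)
    #   # hits -> [{'url': ..., 'snippets': [...]}, ...]
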
    def visualize_map(self, root: MindPoint):
        G = nx.DiGraph()

        def add_edges(node: MindPoint, parent=None):
            if parent is not None:
                G.add_edge(parent, node.category)
            for child in node.children.values():
                add_edges(child, node.category)

        add_edges(root)

        plt.figure(figsize=(12, 8))
        pos = nx.spring_layout(G)
        nx.draw(G, pos, with_labels=True, node_size=3000, node_color='skyblue',
                font_size=10, font_weight='bold', arrows=True)
        plt.title("MindMap Visualization", fontsize=15)
        plt.show()

if __name__ == "__main__":
    sys.path.append('/mnt/nas-alinlp/xizekun/project/DeepThink/src')

    from lm import OpenAIModel, OpenAIModel_New
    from rm import BingSearch, BingSearchAli
    from utils import load_api_key

    load_api_key(toml_file_path='/mnt/nas-alinlp/xizekun/project/DeepThink/secrets.toml')
    openai_kwargs = {
        'api_key': os.getenv("OPENAI_API_KEY"),
        'api_provider': os.getenv('OPENAI_API_TYPE'),
        'temperature': 1.0,
        'top_p': 0.9,
        'api_base': os.getenv('AZURE_API_BASE'),
        'api_version': os.getenv('AZURE_API_VERSION'),
    }

    lm = OpenAIModel(model='gpt-4-1106-preview', max_tokens=5000, **openai_kwargs)
    rm = BingSearchAli(ydc_api_key=os.getenv('BING_SEARCH_ALI_API_KEY'), k=3)

    retriever = rm
    gen_concept_lm = lm

    mind_map = MindMap(
        retriever=retriever,
        gen_concept_lm=gen_concept_lm,
        gen_concept_lm2=gen_concept_lm,  # reuse the same model for node extension
        search_top_k=3,
        depth=3,
    )

    # build_map is a generator that yields each level as it is expanded; exhaust it
    # to build the full map, then save the root node.
    for _ in mind_map.build_map('Taylor Hawkins'):
        pass
    mind_map.save_map(mind_map.root, '/mnt/nas-alinlp/xizekun/project/DeepThink/src/DeepThink/modules/Taylor.json')

    b = mind_map.get_all_infos()