|
import json
import os
import pickle
import re
from typing import Union

import dspy

from mindmap import MindMap
from storm_dataclass import Article
|
|
|
|
|
class ArticleTextProcessing: |
|
@staticmethod |
|
def limit_word_count_preserve_newline(input_string, max_word_count): |
|
""" |
|
Limit the word count of an input string to a specified maximum, while preserving the integrity of complete lines. |
|
|
|
        The function truncates the input string at the word that reaches the maximum
        word count, preserving the newline structure of all fully retained lines.
        Words are defined as text separated by spaces, and lines are defined as text
        separated by newline characters.
|
|
|
Args: |
|
input_string (str): The string to be truncated. This string may contain multiple lines. |
|
max_word_count (int): The maximum number of words allowed in the truncated string. |
|
|
|
Returns: |
|
str: The truncated string with word count limited to `max_word_count`, preserving complete lines. |
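
        Example (illustrative):
            >>> ArticleTextProcessing.limit_word_count_preserve_newline('a b c d e', 3)
            'a b c'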
|
""" |
|
|
|
word_count = 0 |
|
limited_string = '' |
|
|
|
        for line in input_string.split('\n'):
            line_words = line.split()
            for lw in line_words:
                if word_count < max_word_count:
                    limited_string += lw + ' '
                    word_count += 1
                else:
                    break
            if word_count >= max_word_count:
                break
            # The newline is restored only for lines consumed before hitting the limit.
            limited_string = limited_string.strip() + '\n'
|
|
|
return limited_string.strip() |
|
|
|
@staticmethod |
|
def remove_citations(s): |
|
""" |
|
Removes all citations from a given string. Citations are assumed to be in the format |
|
of numbers enclosed in square brackets, such as [1], [2], or [1, 2], etc. This function searches |
|
for all occurrences of such patterns and removes them, returning the cleaned string. |
|
|
|
Args: |
|
s (str): The string from which citations are to be removed. |
|
|
|
Returns: |
|
str: The string with all citation patterns removed. |
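
        Example (illustrative):
            >>> ArticleTextProcessing.remove_citations('Transformers[1] scale well[2, 3].')
            'Transformers scale well.'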
|
""" |
|
|
|
return re.sub(r'\[\d+(?:,\s*\d+)*\]', '', s) |
|
|
|
@staticmethod |
|
def get_first_section_dict_and_list(s): |
|
""" |
|
""" |
|
text = s |
|
sections = text.strip().split('\n# ') |
|
titles = [] |
|
content_dict = {} |
|
|
|
for section in sections: |
|
if section: |
|
lines = section.split('\n', 1) |
|
title = lines[0].strip() |
|
content = lines[1].strip() if len(lines) > 1 else "" |
|
|
|
titles.append(title) |
|
content_dict[title] = content |
|
return content_dict, titles |
|
|
|
@staticmethod |
|
def parse_citation_indices(s): |
|
""" |
|
Extracts citation indexes from the provided content string and returns them as a list of integers. |
|
|
|
        Args:
            s (str): The content string containing citations in the format [number].

        Returns:
            List[int]: A list of citation indexes extracted from the content, in the
            order they appear; duplicates are preserved.
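
        Example (illustrative):
            >>> ArticleTextProcessing.parse_citation_indices('A[1] and B[2], again A[1].')
            [1, 2, 1]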
|
""" |
|
matches = re.findall(r'\[\d+\]', s) |
|
return [int(index[1:-1]) for index in matches] |
|
|
|
@staticmethod |
|
def remove_uncompleted_sentences_with_citations(text): |
|
""" |
|
Removes uncompleted sentences and standalone citations from the input text. Sentences are identified |
|
by their ending punctuation (.!?), optionally followed by a citation in square brackets (e.g., "[1]"). |
|
Grouped citations (e.g., "[1, 2]") are split into individual ones (e.g., "[1] [2]"). Only text up to |
|
and including the last complete sentence and its citation is retained. |
|
|
|
Args: |
|
text (str): The input text from which uncompleted sentences and their citations are to be removed. |
|
|
|
Returns: |
|
str: The processed string with uncompleted sentences and standalone citations removed, leaving only |
|
complete sentences and their associated citations if present. |
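
        Example (illustrative):
            >>> ArticleTextProcessing.remove_uncompleted_sentences_with_citations(
            ...     'It works well. [1] But the model also')
            'It works well. [1]'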
|
""" |
|
|
|
|
|
def replace_with_individual_brackets(match): |
|
numbers = match.group(1).split(', ') |
|
return ' '.join(f'[{n}]' for n in numbers) |
|
|
|
|
|
def deduplicate_group(match): |
|
citations = match.group(0) |
|
unique_citations = list(set(re.findall(r'\[\d+\]', citations))) |
|
sorted_citations = sorted(unique_citations, key=lambda x: int(x.strip('[]'))) |
|
|
|
return ''.join(sorted_citations) |
|
|
|
        # Split grouped citations like "[1, 2]" into "[1] [2]", then deduplicate
        # and sort adjacent citation runs.
        text = re.sub(r'\[([0-9, ]+)\]', replace_with_individual_brackets, text)
        text = re.sub(r'(\[\d+\])+', deduplicate_group, text)
|
eos_pattern = r'([.!?])\s*(\[\d+\])?\s*' |
|
matches = list(re.finditer(eos_pattern, text)) |
|
if matches: |
|
last_match = matches[-1] |
|
text = text[:last_match.end()].strip() |
|
|
|
return text |
|
|
|
@staticmethod |
|
def clean_up_citation(conv): |
|
        for turn in conv.dlg_history:
            # Truncate trailing 'References:'/'Sources:' blocks only when the
            # marker is actually present (str.find returns -1 when it is absent).
            for marker in ('References:', 'Sources:'):
                idx = turn.agent_utterance.find(marker)
                if idx != -1:
                    turn.agent_utterance = turn.agent_utterance[:idx]
            turn.agent_utterance = turn.agent_utterance.replace('Answer:', '').strip()
            try:
                max_ref_num = max(int(x) for x in re.findall(r'\[(\d+)\]', turn.agent_utterance))
            except ValueError:
                max_ref_num = 0
            # Remove citations that point past the available search results
            # (citation indices are 1-based).
            if max_ref_num > len(turn.search_results):
                for i in range(len(turn.search_results) + 1, max_ref_num + 1):
                    turn.agent_utterance = turn.agent_utterance.replace(f'[{i}]', '')
|
turn.agent_utterance = ArticleTextProcessing.remove_uncompleted_sentences_with_citations( |
|
turn.agent_utterance) |
|
|
|
return conv |
|
|
|
@staticmethod |
|
def clean_up_outline(outline, topic=""): |
|
        output_lines = []

        for line in outline.split('\n'):
            stripped_line = line.strip()

            # If the topic itself appears as a header, restart collection from
            # that point.
            if topic != "" and f"# {topic.lower()}" in stripped_line.lower():
                output_lines = []

            # Keep section headers ("#", "##", ...) and "@"-prefixed lines.
            if stripped_line.startswith('#') and stripped_line != '#':
                output_lines.append(stripped_line)
            elif stripped_line.startswith('@'):
                output_lines.append(stripped_line)

        outline = '\n'.join(output_lines)
|
|
|
|
|
outline = re.sub(r"#[#]? See also.*?(?=##|$)", '', outline, flags=re.DOTALL) |
|
outline = re.sub(r"#[#]? See Also.*?(?=##|$)", '', outline, flags=re.DOTALL) |
|
outline = re.sub(r"#[#]? Notes.*?(?=##|$)", '', outline, flags=re.DOTALL) |
|
outline = re.sub(r"#[#]? References.*?(?=##|$)", '', outline, flags=re.DOTALL) |
|
outline = re.sub(r"#[#]? External links.*?(?=##|$)", '', outline, flags=re.DOTALL) |
|
outline = re.sub(r"#[#]? External Links.*?(?=##|$)", '', outline, flags=re.DOTALL) |
|
outline = re.sub(r"#[#]? Bibliography.*?(?=##|$)", '', outline, flags=re.DOTALL) |
|
outline = re.sub(r"#[#]? Further reading*?(?=##|$)", '', outline, flags=re.DOTALL) |
|
outline = re.sub(r"#[#]? Further Reading*?(?=##|$)", '', outline, flags=re.DOTALL) |
|
outline = re.sub(r"#[#]? Summary.*?(?=##|$)", '', outline, flags=re.DOTALL) |
|
outline = re.sub(r"#[#]? Appendices.*?(?=##|$)", '', outline, flags=re.DOTALL) |
|
outline = re.sub(r"#[#]? Appendix.*?(?=##|$)", '', outline, flags=re.DOTALL) |
|
|
|
return outline |
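
    # Illustrative clean_up_outline behavior (hypothetical outline): given
    # "# History", a prose line, and "# References", only "# History" is kept;
    # prose lines are dropped and boilerplate sections are removed.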
|
|
|
|
|
@staticmethod |
|
def clean_up_section(text): |
|
"""Clean up a section: |
|
1. Remove uncompleted sentences (usually due to output token limitation). |
|
2. Deduplicate individual groups of citations. |
|
3. Remove unnecessary summary.""" |
|
|
|
paragraphs = text.split('\n') |
|
output_paragraphs = [] |
|
summary_sec_flag = False |
|
for p in paragraphs: |
|
p = p.strip() |
|
if len(p) == 0: |
|
continue |
|
if not p.startswith('#'): |
|
p = ArticleTextProcessing.remove_uncompleted_sentences_with_citations(p) |
|
if summary_sec_flag: |
|
if p.startswith('#'): |
|
summary_sec_flag = False |
|
else: |
|
continue |
|
if p.startswith('Overall') or p.startswith('In summary') or p.startswith('In conclusion'): |
|
continue |
|
if "# Summary" in p or '# Conclusion' in p: |
|
summary_sec_flag = True |
|
continue |
|
output_paragraphs.append(p) |
|
|
|
return '\n\n'.join(output_paragraphs) |
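
    # Illustrative clean_up_section behavior: paragraphs starting with
    # "Overall"/"In summary"/"In conclusion", any "# Summary" or "# Conclusion"
    # section, and a dangling unfinished sentence at the end of a paragraph are
    # all removed.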
|
|
|
@staticmethod |
|
def update_citation_index(s, citation_map): |
|
"""Update citation index in the string based on the citation map.""" |
|
for original_citation in citation_map: |
|
s = s.replace(f"[{original_citation}]", f"__PLACEHOLDER_{original_citation}__") |
|
for original_citation, unify_citation in citation_map.items(): |
|
s = s.replace(f"__PLACEHOLDER_{original_citation}__", f"[{unify_citation}]") |
|
|
|
return s |
|
|
|
@staticmethod |
|
def parse_article_into_dict(input_string): |
|
""" |
|
Parses a structured text into a nested dictionary. The structure of the text |
|
is defined by markdown-like headers (using '#' symbols) to denote sections |
|
and subsections. Each section can contain content and further nested subsections. |
|
|
|
The resulting dictionary captures the hierarchical structure of sections, where |
|
each section is represented as a key (the section's title) mapping to a value |
|
that is another dictionary. This dictionary contains two keys: |
|
        - 'content': the text content of the section
        - 'subsections': a dictionary mapping each nested subsection title to a
          dictionary of the same structure.
|
|
|
Args: |
|
input_string (str): A string containing the structured text to parse. |
|
|
|
        Returns:
            A dictionary whose keys are the top-level section titles and whose values
            are dictionaries with the 'content' and 'subsections' keys described above.
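
        Example (illustrative):
            >>> ArticleTextProcessing.parse_article_into_dict('# A\\nintro\\n## B\\ndetail')
            {'A': {'content': 'intro\\n', 'subsections': {'B': {'content': 'detail\\n', 'subsections': {}}}}}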
|
""" |
|
lines = input_string.split('\n') |
|
lines = [line for line in lines if line.strip()] |
|
root = {'content': '', 'subsections': {}} |
|
current_path = [(root, -1)] |
|
|
|
for line in lines: |
|
            if line.startswith('#'):
                # Count only the leading '#' characters to get the header level.
                level = len(line) - len(line.lstrip('#'))
                title = line.strip('# ').strip()
|
new_section = {'content': '', 'subsections': {}} |
|
|
|
|
|
while current_path and current_path[-1][1] >= level: |
|
current_path.pop() |
|
|
|
|
|
current_path[-1][0]['subsections'][title] = new_section |
|
current_path.append((new_section, level)) |
|
else: |
|
current_path[-1][0]['content'] += line + '\n' |
|
|
|
return root['subsections'] |
|
|
|
|
|
class FileIOHelper: |
|
@staticmethod |
|
def dump_json(obj, file_name, encoding="utf-8"): |
|
with open(file_name, 'w', encoding=encoding) as fw: |
|
json.dump(obj, fw, default=FileIOHelper.handle_non_serializable, ensure_ascii=False) |
|
|
|
@staticmethod |
|
def handle_non_serializable(obj): |
|
return "non-serializable contents" |
|
|
|
@staticmethod |
|
def load_json(file_name, encoding="utf-8"): |
|
with open(file_name, 'r', encoding=encoding) as fr: |
|
return json.load(fr) |
|
|
|
@staticmethod |
|
def write_str(s, path): |
|
        with open(path, 'w', encoding="utf-8") as f:
            f.write(s)
|
|
|
@staticmethod |
|
def load_str(path): |
|
        with open(path, 'r', encoding="utf-8") as f:
            return f.read()
|
|
|
@staticmethod |
|
def dump_pickle(obj, path): |
|
with open(path, 'wb') as f: |
|
pickle.dump(obj, f) |
|
|
|
@staticmethod |
|
def load_pickle(path): |
|
with open(path, 'rb') as f: |
|
return pickle.load(f) |
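

# Illustrative FileIOHelper round trip (hypothetical file names):
#     FileIOHelper.dump_json({"topic": "Taylor Hawkins"}, "article.json")
#     data = FileIOHelper.load_json("article.json")
#     FileIOHelper.dump_pickle(data, "article.pkl")
#     assert FileIOHelper.load_pickle("article.pkl") == data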
|
|
|
|
|
class OutlineGenerationModule:
    """
    Generate an article outline from the collected information, borrowing the
    approach of AutoSurvey.
    """
|
|
|
def __init__(self, |
|
outline_gen_lm: Union[dspy.dsp.LM, dspy.dsp.HFModel]): |
|
super().__init__() |
|
self.outline_gen_lm = outline_gen_lm |
|
self.write_outline = WriteOutline(engine=self.outline_gen_lm) |
|
|
|
def generate_outline(self, |
|
topic: str, |
|
mindmap: MindMap, |
|
): |
|
|
|
concepts = mindmap.export_categories_and_concepts() |
|
result = self.write_outline(topic=topic, concepts=concepts) |
|
|
|
return result |
|
|
|
class WriteOutline(dspy.Module): |
|
"""Generate the outline for the Wikipedia page.""" |
|
|
|
def __init__(self, engine: Union[dspy.dsp.LM, dspy.dsp.HFModel]): |
|
super().__init__() |
|
self.draft_page_outline = dspy.Predict(WritePageOutline) |
|
self.polish_page_outline = dspy.Predict(PolishPageOutline) |
|
self.engine = engine |
|
|
|
def forward(self, topic: str, concepts: str): |
|
|
|
with dspy.settings.context(lm=self.engine): |
|
outline = ArticleTextProcessing.clean_up_outline( |
|
self.draft_page_outline(topic=topic).outline) |
|
outline = ArticleTextProcessing.clean_up_outline( |
|
self.polish_page_outline(draft=outline, concepts=concepts).outline) |
|
|
|
return outline |
|
|
|
|
|
class PolishPageOutline(dspy.Signature): |
|
""" |
|
Improve an outline for a Wikipedia page. You already have a draft outline that covers the general information. Now you want to improve it based on the concept learned from an information-seeking to make it more informative. |
|
Here is the format of your writing: |
|
1. Use "#" Title" to indicate section title, "##" Title" to indicate subsection title, "###" Title" to indicate subsubsection title, and so on. |
|
2. Do not include other information. |
|
3. Do not include topic name itself in the outline. |
|
""" |
|
|
|
draft = dspy.InputField(prefix="Current outline:\n ", format=str) |
|
concepts = dspy.InputField(prefix="The information you learned from the conversation:\n", format=str) |
|
outline = dspy.OutputField(prefix='Write the page outline:\n', format=str) |
|
|
|
|
|
class WritePageOutline(dspy.Signature): |
|
""" |
|
Write an outline for a Wikipedia page. |
|
Here is the format of your writing: |
|
1. Use "#" Title" to indicate section title, "##" Title" to indicate subsection title, "###" Title" to indicate subsubsection title, and so on. |
|
2. Do not include other information. |
|
3. Do not include topic name itself in the outline. |
|
""" |
|
|
|
topic = dspy.InputField(prefix="The topic you want to write: ", format=str) |
|
outline = dspy.OutputField(prefix="Write the Wikipedia page outline:\n", format=str) |
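

# Illustrative shape of a cleaned outline produced by WriteOutline
# (hypothetical sections; actual content depends on the topic and the LM):
#     # Early life
#     # Career
#     ## Session work
#     # Legacy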
|
|
|
|
|
if __name__ == "__main__": |
|
import sys |
|
import os |
|
sys.path.append('/mnt/nas-alinlp/xizekun/project/DeepThink/src') |
|
|
|
from lm import OpenAIModel, OpenAIModel_New |
|
from rm import BingSearch, BingSearchAli |
|
from utils import load_api_key |
|
from storm_dataclass import Article |
|
|
|
load_api_key(toml_file_path='/mnt/nas-alinlp/xizekun/project/DeepThink/secrets.toml') |
|
openai_kwargs = { |
|
'api_key': os.getenv("OPENAI_API_KEY"), |
|
'api_provider': os.getenv('OPENAI_API_TYPE'), |
|
'temperature': 1.0, |
|
'top_p': 0.9, |
|
'api_base': os.getenv('AZURE_API_BASE'), |
|
'api_version': os.getenv('AZURE_API_VERSION'), |
|
} |
|
|
|
lm = OpenAIModel(model='gpt-4o', max_tokens=5000, **openai_kwargs) |
|
rm = BingSearchAli(ydc_api_key=os.getenv('BING_SEARCH_ALI_API_KEY'), k=3) |
|
|
|
retriever = rm |
|
gen_concept_lm = lm |
|
|
|
    mind_map = MindMap(
        retriever=retriever,
        gen_concept_lm=gen_concept_lm,
        search_top_k=3,
        deepth=3,
    )

    mind_map.load_map('/mnt/nas-alinlp/xizekun/project/DeepThink/src/DeepThink/modules/Taylor.json')
    module = OutlineGenerationModule(lm)
    outline = module.generate_outline(topic='Taylor Hawkins', mindmap=mind_map)
|
print(outline) |
|
|
|
article = Article.from_outline_str(topic='Taylor Hawkins', outline_str=outline) |
|
print(article.root.keywords) |