import time
import json
import os
import sys
import subprocess
from http import HTTPStatus

import pandas as pd
import streamlit as st
from dotenv import load_dotenv

from src.lm import QwenModel
from src.rm import GoogleSearchAli_new

sys.path.append('./src/DeepThink/modules')
from mindmap import MindMap
from storm_dataclass import Article
from article_generation import ArticleGenerationModule
from article_polish import ArticlePolishingModule
from outline_generation import OutlineGenerationModule

# Make sure pip itself is up to date in the deployment environment before anything else runs.
subprocess.run(
    "pip install --upgrade pip",
    shell=True,
    stdout=subprocess.PIPE,
    stderr=subprocess.PIPE,
    text=True,
)

# Load API credentials and endpoint settings from the local .env file.
load_dotenv()
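# Shared generation settings for every Qwen call below. The key, provider,
# endpoint, and API version are read from the environment populated by load_dotenv().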
openai_kwargs = {
    'api_key': os.getenv("OPENAI_API_KEY"),
    'api_provider': os.getenv('OPENAI_API_TYPE'),
    'temperature': 1.0,
    'top_p': 0.9,
    'api_base': os.getenv('AZURE_API_BASE'),
    'api_version': os.getenv('AZURE_API_VERSION'),
}
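# Optional sanity check (an addition, not part of the original flow): warn early
# with a readable message if a credential the Qwen models are expected to need is missing.
for _var in ("OPENAI_API_KEY", "OPENAI_API_TYPE", "AZURE_API_BASE", "AZURE_API_VERSION"):
    if not os.getenv(_var):
        print(f"Warning: environment variable {_var} is not set; API calls may fail.")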
# One QwenModel instance per pipeline stage, with a max_tokens budget sized to the task:
# short outputs for outlining, longer ones for section writing and polishing.
lm = QwenModel(model='qwen-plus', max_tokens=1000, **openai_kwargs)
lm4outline = QwenModel(model='qwen-plus', max_tokens=1000, **openai_kwargs)
lm4gensection = QwenModel(model='qwen-plus', max_tokens=2000, **openai_kwargs)
lm4polish = QwenModel(model='qwen-plus', max_tokens=4000, **openai_kwargs)

# Retriever used for web search; k is the number of results kept per query.
rm = GoogleSearchAli_new(k=5)
st.set_page_config(page_title='OmniThink', layout="wide")

st.warning("Announcement: Due to the recent high volume of visitors and search API quota limitations, "
           "you may encounter the error: "
           "'ValueError: Expected 2D array, got 1D array instead: array=[]. "
           "Reshape your data either using array.reshape(-1, 1) if your data has a single feature "
           "or array.reshape(1, -1) if it contains a single sample.' "
           "If this error occurs, please try again in a few hours.")

st.title('🤔 OmniThink')
st.markdown('_OmniThink is a tool that helps you think deeply about a topic, generate an outline, and write an article._')
with st.sidebar:
    st.header('Configuration')
    MAX_ROUNDS = st.number_input('Retrieval Depth', min_value=0, max_value=10, value=2, step=1)
    models = ['Qwen-Plus', 'Coming Soon']
    selected_model = st.selectbox('LLM:', models)
    searchers = ['GoogleSearch', 'Coming Soon']
    selected_searcher = st.selectbox('Search engine:', searchers)
    n_max_doc = st.number_input('Number of web pages retrieved in a single search', min_value=1, max_value=50, value=10, step=5)

    st.header('Examples')
    examples = ['AlphaFold', '2024 Hualien City Earthquake', 'Taylor Swift', 'Yoon Suk-yeol']
    selected_example = st.selectbox('Case', examples)
    status_placeholder = st.empty()
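# The mind map drives the "thinking" phase: starting from the topic, it retrieves
# web pages and expands related concepts layer by layer, up to MAX_ROUNDS levels deep.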
mind_map = MindMap(
    retriever=rm,
    gen_concept_lm=lm4outline,
    gen_concept_lm2=lm4outline,
    search_top_k=n_max_doc,
    depth=MAX_ROUNDS,
)
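# Stage 1: deep retrieval. Expand the mind map layer by layer and stream the
# discovered concepts and web sources into the UI as each level completes.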
def Think(input_topic):
    generator = mind_map.build_map(input_topic)

    st.markdown(f'Performing an in-depth search on the content related to {input_topic}...')

    for idx, layer in enumerate(generator):
        st.markdown(f'Deep Thinking Retrieval at Level {idx + 1}...')
        status_placeholder.text(f"Currently conducting deep thinking retrieval at level {idx + 1}; estimated to take {(idx + 1) * 3} minutes.")
        for node in layer:
            category = node.category
            with st.expander(f'{category}'):
                st.markdown(f'### The concept of {node.category}')
                for concept in node.concept:
                    st.markdown(f'* {concept}')
                st.markdown(f'### The web of {node.category}')
                for info_idx, info in enumerate(node.info):
                    st.markdown(f'{info_idx + 1}. {info["title"]} \n {info["snippets"]}')

    st.markdown(f'Constructing an index table for the {mind_map.get_web_number()} retrieved web pages...')
    mind_map.prepare_table_for_retrieval()
    return '__finish__', '__finish__'
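# Stage 2: outline generation. Condense the mind map into a hierarchical article outline.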
def GenOutline(input_topic):
    status_placeholder.text("The outline writing is in progress and is expected to take 1 minute.")
    ogm = OutlineGenerationModule(lm)
    outline = ogm.generate_outline(topic=input_topic, mindmap=mind_map)
    return outline
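# Stage 3: article generation and polishing, written section by section along the outline.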
def GenArticle(input_topic, outline):
    status_placeholder.text("The article writing is in progress and is expected to take 3 minutes.")

    article_with_outline = Article.from_outline_str(topic=input_topic, outline_str=outline)
    ag = ArticleGenerationModule(retriever=rm, article_gen_lm=lm, retrieve_top_k=3, max_thread_num=10)
    article = ag.generate_article(topic=input_topic, mindmap=mind_map, article_with_outline=article_with_outline)
    ap = ArticlePolishingModule(article_gen_lm=lm, article_polish_lm=lm)
    article = ap.polish_article(topic=input_topic, draft_article=article)
    return article.to_string()
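# Main form: take a topic from the user and run the three stages in order
# (Think -> GenOutline -> GenArticle), rendering each stage's output as it finishes.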
with st.form('my_form'):
    topic = st.text_input('Please enter the topic you are interested in.', value=selected_example, placeholder='Please enter the topic you are interested in.')
    submit_button = st.form_submit_button('Generate!')

if submit_button:
    if topic:
        st.markdown('### Thought process')
        summary, news_timeline = Think(topic)
        st.session_state.summary = summary
        st.session_state.news_timeline = news_timeline

        st.markdown('### Outline generation')
        with st.expander("Outline generation", expanded=True):
            outline = GenOutline(topic)
            st.text(outline)

        st.markdown('### Article generation')
        with st.expander("Article generation", expanded=True):
            article = GenArticle(topic, outline)
            st.markdown(article)
    else:
        st.error('Please enter a topic.')