# OmniThink / app.py
# Last commit 80a598c ("push") by ZekunXi.
import time
import json
import pandas as pd
import streamlit as st
from dotenv import load_dotenv
from http import HTTPStatus
from src.lm import QwenModel
from src.rm import GoogleSearchAli_new
import sys
sys.path.append('./src/DeepThink/modules')
from mindmap import MindMap
from storm_dataclass import Article
from article_generation import ArticleGenerationModule
from article_polish import ArticlePolishingModule
from outline_generation import OutlineGenerationModule
import os
import subprocess
# Upgrade pip in a background process at start-up.
# - The argument list (no shell) targets the pip of the interpreter actually
#   running this app, rather than whichever `pip` happens to be on PATH.
# - Output is discarded instead of piped: nothing ever reads the pipes, and an
#   unread pipe can fill up and leave the child blocked forever.
# NOTE(review): Streamlit re-executes this script on every rerun, so a new pip
# process is spawned each time the page reruns — consider moving this to the
# deployment's build step.
bash_command = "pip install --upgrade pip"  # human-readable form of the command below
process = subprocess.Popen(
    [sys.executable, "-m", "pip", "install", "--upgrade", "pip"],
    stdout=subprocess.DEVNULL,
    stderr=subprocess.DEVNULL,
)
# Settings shared by every Qwen client below. load_dotenv() is deliberately not
# called, so these values are read straight from the process environment; any
# variable that is unset is passed through as None.
openai_kwargs = dict(
    api_key=os.getenv("OPENAI_API_KEY"),
    api_provider=os.getenv('OPENAI_API_TYPE'),
    temperature=1.0,
    top_p=0.9,
    api_base=os.getenv('AZURE_API_BASE'),
    api_version=os.getenv('AZURE_API_VERSION'),
)


def _make_qwen(max_tokens):
    """Build a 'qwen-plus' client with the shared settings and the given token budget."""
    return QwenModel(model='qwen-plus', max_tokens=max_tokens, **openai_kwargs)


# One client per role; they differ only in their max_tokens budget.
lm = _make_qwen(1000)             # outline writing and article generation
lm4outline = _make_qwen(1000)     # concept generation inside the mind map
lm4gensection = _make_qwen(2000)  # larger budget, per its name meant for section writing
lm4polish = _make_qwen(4000)      # largest budget, per its name meant for polishing

# Web-search retriever shared by the mind map and the article generator.
rm = GoogleSearchAli_new(k=5)
# Page chrome. set_page_config stays the first st.* call made by the script.
st.set_page_config(page_title='OmniThink', layout="wide")

# Banner explaining the error users see when the search API quota runs out.
_QUOTA_NOTICE = (
    "Announcement: Due to the recent high volume of visitors, search API quota limitations, you may encounter an error: "
    "'ValueError: Expected 2D array, got 1D array instead: array=[]. "
    "Reshape your data either using array.reshape(-1, 1) if your data has a single feature "
    "or array.reshape(1, -1) if it contains a single sample.' "
    "If this error occurs, please try again in a few hours."
)
st.warning(_QUOTA_NOTICE)

st.title('🤔 OmniThink')
st.markdown('_OmniThink is a tool that helps you think deeply about a topic, generate an outline, and write an article._')
# Sidebar for configuration and examples
with st.sidebar:
    st.header('Configuration')
    # Number of mind-map expansion rounds; passed to MindMap as `depth`.
    MAX_ROUNDS = st.number_input('Retrieval Depth', min_value=0, max_value=10, value=2, step=1)

    # Model and search-engine pickers. Each has only one real option, and
    # nothing in this file reads the choice yet. They get their own names so
    # they no longer overwrite `selected_example`, as the original did.
    models = ['Qwen-Plus', 'Coming Soon']
    selected_model = st.selectbox('LLM:', models)
    searchers = ['GoogleSearch', 'Coming Soon']
    selected_searcher = st.selectbox('Search engine', searchers)

    # Pages fetched per search call; passed to MindMap as `search_top_k`.
    n_max_doc = st.number_input('Number of web pages retrieved in single search', min_value=1, max_value=50, value=10, step=5)

    st.header('Examples')
    examples = ['AlphaFold', '2024 Hualien City Earthquake', 'Taylor Swift', 'Yoon Seok-youl']
    # The chosen example pre-fills the topic text box of the form.
    selected_example = st.selectbox('case', examples)
# Single-line status slot that the pipeline stages overwrite with progress text.
status_placeholder = st.empty()

# Mind map shared by every pipeline stage, configured from the sidebar values.
# Both concept-generation roles use the same client.
mind_map = MindMap(
    retriever=rm,
    depth=MAX_ROUNDS,
    search_top_k=n_max_doc,
    gen_concept_lm=lm4outline,
    gen_concept_lm2=lm4outline,
)
def Think(input_topic):
generator = mind_map.build_map(input_topic)
st.markdown(f'Performing an in-depth search on the content related to {input_topic}...')
for idx, layer in enumerate(generator):
print(layer)
print('layer!!!')
st.markdown(f'Deep Thinking Retrieval at Level {idx + 1}...')
status_placeholder.text(f"Currently conducting the {idx + 1}th level deep thinking retrieval, estimated to take {(idx+1)*3} minutes.")
for node in layer:
category = node.category
print(f'category: {category}')
with st.expander(f'{category}'):
st.markdown(f'### The concept of {node.category}')
print(node.concept)
for concept in node.concept:
st.markdown(f'* {concept}')
st.markdown(f'### The web of {node.category}')
for idx, info in enumerate(node.info):
st.markdown(f'{idx + 1}. {info["title"]} \n {info["snippets"]}')
st.markdown(f'Constructing an index table for the {mind_map.get_web_number()} retrieved web pages...')
mind_map.prepare_table_for_retrieval()
return '__finish__', '__finish__'
def GenOutline(input_topic):
    """Draft an outline for *input_topic* from the shared mind map.

    Expects ``Think`` to have populated ``mind_map`` first; the form handler
    calls them in that order. Updates the status line, then asks an
    ``OutlineGenerationModule`` backed by the ``lm`` client for the outline.

    Returns:
        The value of ``OutlineGenerationModule.generate_outline``. The form
        handler shows it with ``st.text`` and passes it to ``GenArticle`` as
        the outline string.
    """
    status_placeholder.text("The outline writing is in progress and is expected to take 1 minute.")
    return OutlineGenerationModule(lm).generate_outline(topic=input_topic, mindmap=mind_map)
def GenArticle(input_topic, outline):
    """Write and polish an article on *input_topic* that follows *outline*.

    Args:
        input_topic: Topic of the article.
        outline: Outline text, parsed with ``Article.from_outline_str``.

    Returns:
        The polished article rendered with ``Article.to_string()``.
    """
    status_placeholder.text("The article writing is in progress and is expected to take 3 minutes.")
    article_with_outline = Article.from_outline_str(topic=input_topic, outline_str=outline)
    # NOTE(review): this file builds lm4gensection (2000 tokens) and lm4polish
    # (4000 tokens), but the original passed the 1000-token `lm` to every step.
    # The dedicated clients are wired in here, where their names indicate they
    # belong, so long sections and the polishing pass are not capped at 1000 tokens.
    ag = ArticleGenerationModule(retriever=rm, article_gen_lm=lm4gensection, retrieve_top_k=3, max_thread_num=10)
    # Bug fix: use the `input_topic` argument. The original read the
    # module-level form variable `topic`, ignoring the caller's argument.
    draft = ag.generate_article(topic=input_topic, mindmap=mind_map, article_with_outline=article_with_outline)
    ap = ArticlePolishingModule(article_gen_lm=lm4gensection, article_polish_lm=lm4polish)
    polished = ap.polish_article(topic=input_topic, draft_article=draft)
    return polished.to_string()
# Topic form: on submit, run Think -> GenOutline -> GenArticle and render each
# stage's output.
# NOTE(review): the source's indentation was lost; nesting the submit handler
# inside the form context is a reconstruction — confirm against upstream.
with st.form('my_form'):
    topic = st.text_input('Please enter the topic you are interested in.', value=selected_example, placeholder='Please enter the topic you are interested in.')
    submit_button = st.form_submit_button('Generate!')
    if submit_button:
        # Strip surrounding whitespace so a blank entry ("   ") is rejected
        # instead of launching a multi-minute search on an empty topic.
        topic = topic.strip()
        if topic:
            st.markdown('### Thought process')
            summary, news_timeline = Think(topic)
            st.session_state.summary = summary
            st.session_state.news_timeline = news_timeline

            st.markdown('### Outline generation')
            with st.expander("Outline generation", expanded=True):
                outline = GenOutline(topic)
                st.text(outline)

            st.markdown('### article generation')
            with st.expander("article generation", expanded=True):
                article = GenArticle(topic, outline)
                st.markdown(article)
        else:
            st.error('Please enter the subject.')