# Source: OmniThink/src/interface.py (last commit 80a598c "push" by ZekunXi)
import functools
import logging
import time
from abc import ABC, abstractmethod
from collections import OrderedDict
from typing import Dict, List, Optional, Union
# NOTE(review): configuring the root logger at import time affects every
# importer of this module; consider moving this call to the program entry point.
logging.basicConfig(
    format='%(name)s : %(levelname)-8s : %(message)s',
    level=logging.INFO,
)
# Module-level logger; records carry this module's dotted name.
logger = logging.getLogger(__name__)
class Information(ABC):
    """Abstract base class to represent basic information.

    Attributes:
        uuid (str): The unique identifier for the information.
        meta (dict): The meta information associated with the information.
    """

    def __init__(self, uuid, meta=None):
        """
        Args:
            uuid: Unique identifier for this piece of information.
            meta: Optional metadata dict. When omitted, a fresh empty dict is
                created for this instance, so instances never share metadata.
        """
        self.uuid = uuid
        # A literal ``{}`` default is evaluated once at definition time and
        # would be shared by every instance created without ``meta``.
        self.meta = {} if meta is None else meta
class InformationTable(ABC):
    """
    The InformationTable class serves as a data class to store the information
    collected during the KnowledgeCuration stage.

    Create a subclass to incorporate more information as needed. For example,
    in the STORM paper https://arxiv.org/pdf/2402.14207.pdf, additional information
    would be the perspective-guided dialogue history.
    """

    def __init__(self):
        pass

    @abstractmethod
    def retrieve_information(self, **kwargs):
        """
        Retrieve stored information.

        Subclasses define which keyword arguments are accepted and what is returned.
        """
        pass
class articleSectionNode(ABC):
    """
    Tree node holding one section of an article.

    Each node carries the section heading, its content, an ordered list of child
    section nodes, a writing ``preference`` slot (initially None) and a list of
    ``keywords`` (initially empty).
    """

    def __init__(self, section_name: str, content=None):
        """
        Args:
            section_name: Section heading as a string, e.g. "Introduction" or "History".
            content: Section content; its data structure is left to the caller.
        """
        self.section_name = section_name
        self.content = content
        self.children = []
        self.preference = None
        self.keywords = []

    def add_child(self, new_child_node, insert_to_front=False):
        """Attach ``new_child_node`` first if ``insert_to_front`` is truthy, otherwise last."""
        position = 0 if insert_to_front else len(self.children)
        self.children.insert(position, new_child_node)

    def remove_child(self, child):
        """Detach the first child equal to ``child``; raises ValueError if there is none."""
        self.children.remove(child)
class article(ABC):
    """
    Abstract article: a tree of ``articleSectionNode`` objects whose root node is
    named after the topic.

    Subclasses must implement ``to_string`` and ``from_string``.
    """

    def __init__(self, topic_name):
        self.root = articleSectionNode(topic_name)

    def find_section(self, node: articleSectionNode, name: str) -> Optional[articleSectionNode]:
        """
        Return the node of the section given the section name.

        The search is depth-first and visits a node before its children, so the
        first match in document order wins.

        Args:
            node: the node to use as the root of the search.
            name: the section name to look for.
        Returns:
            Reference to the matching node, or None if no section has that name.
        """
        if node.section_name == name:
            return node
        for child in node.children:
            result = self.find_section(child, name)
            # Compare against None explicitly: a node type defining __len__ or
            # __bool__ could be falsy while still being a valid match.
            if result is not None:
                return result
        return None

    @abstractmethod
    def to_string(self) -> str:
        """
        Export article object into string representation.
        """

    def get_outline_tree(self):
        """
        Generates a hierarchical tree structure representing the outline of the document.

        The root node (the topic) itself is not included; the returned dict starts at
        its children.

        Returns:
            Dict[str, Dict]: A nested dictionary representing the hierarchical structure of the document's outline.
            Each key is a section name, and the value is another dictionary representing the child sections,
            recursively forming the tree structure of the document's outline. If a section has no subsections,
            its value is an empty dictionary. If two siblings share a name, the later one wins.

        Example:
            Assuming a document with a structure like:
            - Introduction
                - Background
                - Objective
            - Methods
                - Data Collection
                - Analysis
            The method would return:
            {
                'Introduction': {
                    'Background': {},
                    'Objective': {}
                },
                'Methods': {
                    'Data Collection': {},
                    'Analysis': {}
                }
            }
        """

        def build_tree(node) -> Dict[str, Dict]:
            # A node without children naturally yields an empty dict.
            return {child.section_name: build_tree(child) for child in node.children}

        return build_tree(self.root)

    def get_first_level_section_names(self) -> List[str]:
        """
        Get first level section names (the direct children of the root node).
        """
        return [i.section_name for i in self.root.children]

    @classmethod
    @abstractmethod
    def from_string(cls, topic_name: str, article_text: str):
        """
        Create an instance of the article object from a string
        """
        pass

    def prune_empty_nodes(self, node=None):
        """
        Remove, in place, every section that has no content and no surviving children.

        A section counts as empty when its content is None or "" and, after its own
        subtree has been pruned, it has no children left.

        Args:
            node: subtree root to prune; defaults to ``self.root``.
        Returns:
            ``node`` if it still has content or children after pruning, else None.
            The node passed in is never detached from its own parent; only its
            descendants are removed.
        """
        if node is None:
            node = self.root
        # Slice assignment mutates the existing list object, so any other
        # reference to ``node.children`` observes the pruning.
        node.children[:] = [
            child for child in node.children if self.prune_empty_nodes(child) is not None
        ]
        if (node.content is None or node.content == "") and not node.children:
            return None
        return node
class Retriever(ABC):
    """
    Abstract base for retriever modules: a template for fetching information for a query.

    Extend this class and implement ``retrieve`` to provide a concrete retrieval strategy.
    Each retrieval model / search engine held by an instance should be stored in an
    attribute whose name carries the '_rm' marker, so that its usage can be collected
    by ``collect_and_reset_rm_usage``.
    """

    def __init__(self, search_top_k):
        self.search_top_k = search_top_k

    def update_search_top_k(self, k):
        """Overwrite ``search_top_k`` with ``k``."""
        self.search_top_k = k

    def collect_and_reset_rm_usage(self):
        """
        Gather (and reset) query counts from every retrieval-model attribute.

        An instance attribute participates when its name contains '_rm' and the
        object exposes ``get_usage_and_reset()``. Each such call is expected to return
        a mapping {model_name: query_count}; counts sharing a model name are summed.

        Returns:
            dict mapping model name to its summed query count.
        """
        usages = [
            getattr(self, name).get_usage_and_reset()
            for name in self.__dict__
            if '_rm' in name and hasattr(getattr(self, name), 'get_usage_and_reset')
        ]
        totals = {}
        for usage in usages:
            for model_name, query_cnt in usage.items():
                if model_name in totals:
                    totals[model_name] += query_cnt
                else:
                    totals[model_name] = query_cnt
        return totals

    @abstractmethod
    def retrieve(self, query: Union[str, List[str]], **kwargs) -> List[Information]:
        """
        Retrieve information for a query; subclasses must implement this.

        Args:
            query (Union[str, List[str]]): A single query or a list of queries.
            **kwargs: Implementation-specific options for the retrieval process.

        Returns:
            List[Information]: The Information objects found for the query.
        """
        pass
class KnowledgeCurationModule(ABC):
    """
    Interface for the knowledge curation stage: given a topic, return the information collected for it.
    """

    def __init__(self, retriever: Retriever):
        """
        Args:
            retriever: retrieval component, stored as ``self.retriever``.
        """
        self.retriever = retriever

    @abstractmethod
    def research(self, topic) -> InformationTable:
        """
        Gather information and knowledge about ``topic``.

        Args:
            topic: the topic of interest, in natural language.

        Returns:
            The collected information as an InformationTable.
        """
        pass
class OutlineGenerationModule(ABC):
    """
    The interface for the outline generation stage. Given a topic and the information
    collected during the knowledge curation stage, generate an outline for the article.
    """

    @abstractmethod
    def generate_outline(self, topic: str, information_table: InformationTable, **kwargs) -> article:
        """
        Generate the outline for the article.

        Args:
            topic: the topic of interest.
            information_table: knowledge curation data generated by KnowledgeCurationModule.
            **kwargs: implementation-specific extras, for example
                1. a draft outline, or
                2. a user-provided outline.

        Returns:
            An ``article`` object carrying the generated outline.
        """
        pass
class articleGenerationModule(ABC):
    """
    The interface for the article generation stage. Given a topic, the information
    collected during the knowledge curation stage, and the outline produced by the
    outline generation stage, generate the article.
    """

    @abstractmethod
    def generate_article(self,
                         topic: str,
                         information_table: InformationTable,
                         article_with_outline: article,
                         **kwargs) -> article:
        """
        Generate the article.

        Args:
            topic: the topic of interest.
            information_table: knowledge curation data generated by KnowledgeCurationModule.
            article_with_outline: article with the outline specified by OutlineGenerationModule.
            **kwargs: implementation-specific extras.

        Returns:
            The generated ``article``.
        """
        pass
class articlePolishingModule(ABC):
    """
    The interface for the article polishing stage. Given a topic and the draft article
    produced by the article generation stage, return a polished article.
    """

    @abstractmethod
    def polish_article(self, topic: str, draft_article: article, **kwargs) -> article:
        """
        Polish the article.

        Args:
            topic: the topic of interest.
            draft_article: draft article from articleGenerationModule.
            **kwargs: implementation-specific extras.

        Returns:
            The polished ``article``.
        """
        pass
def log_execution_time(func):
    """Decorator to log and record the execution time of a method.

    The wrapped callable is invoked as a method: its first positional argument
    must be an object with a ``time`` dict attribute (``Engine`` creates one),
    where the elapsed seconds are stored under ``func.__name__``. Nothing is
    logged or recorded if ``func`` raises.
    """
    @functools.wraps(func)
    def wrapper(self, *args, **kwargs):
        # perf_counter is monotonic and high-resolution; time.time() follows the
        # wall clock, which can jump (e.g. NTP adjustments) and corrupt durations.
        start_time = time.perf_counter()
        result = func(self, *args, **kwargs)
        execution_time = time.perf_counter() - start_time
        logger.info("%s executed in %.4f seconds", func.__name__, execution_time)
        self.time[func.__name__] = execution_time
        return result
    return wrapper
class LMConfigs(ABC):
    """Abstract base class for language model configurations of the knowledge curation engine.

    The language model used for each part should be declared with a suffix '_lm' in the
    attribute name. The helpers below select instance attributes whose name contains '_lm'.
    """

    def __init__(self):
        pass

    def init_check(self):
        """Log a warning for every '_lm' attribute that is still None."""
        for attr_name in self.__dict__:
            if '_lm' in attr_name and getattr(self, attr_name) is None:
                logger.warning(
                    "Language model for %s is not initialized. Please call set_%s()",
                    attr_name,
                    attr_name,
                )

    def collect_and_reset_lm_history(self):
        """Return the concatenated ``history`` lists of all '_lm' attributes, then clear each one."""
        history = []
        for attr_name in self.__dict__:
            if '_lm' in attr_name and hasattr(getattr(self, attr_name), 'history'):
                history.extend(getattr(self, attr_name).history)
                getattr(self, attr_name).history = []
        return history

    def collect_and_reset_lm_usage(self):
        """
        Gather (and reset) token usage from every '_lm' attribute.

        Each participating LM's ``get_usage_and_reset()`` is expected to return
        {model_name: {'prompt_tokens': ..., 'completion_tokens': ..., ...}}.
        When several LMs report the same model name, their prompt and completion
        token counts are summed.

        Returns:
            dict mapping model name to a usage dict owned by the caller; the dicts
            returned by the LMs are never modified.
        """
        combined_usage = []
        for attr_name in self.__dict__:
            if '_lm' in attr_name and hasattr(getattr(self, attr_name), 'get_usage_and_reset'):
                combined_usage.append(getattr(self, attr_name).get_usage_and_reset())
        model_name_to_usage = {}
        for usage in combined_usage:
            for model_name, tokens in usage.items():
                if model_name not in model_name_to_usage:
                    # Copy so the in-place sums below do not mutate the dict the LM returned.
                    model_name_to_usage[model_name] = dict(tokens)
                else:
                    model_name_to_usage[model_name]['prompt_tokens'] += tokens['prompt_tokens']
                    model_name_to_usage[model_name]['completion_tokens'] += tokens['completion_tokens']
        return model_name_to_usage

    def log(self):
        """Return an OrderedDict mapping each '_lm' attribute name to that LM's ``kwargs``."""
        return OrderedDict(
            {
                attr_name: getattr(self, attr_name).kwargs
                for attr_name in self.__dict__
                if '_lm' in attr_name and hasattr(getattr(self, attr_name), 'kwargs')
            }
        )
class Engine(ABC):
    """
    Abstract engine exposing one ``run_*`` hook per stage (knowledge curation, outline
    generation, article generation, article polishing) plus ``run``.

    Per-method bookkeeping is kept in ``self.time`` (elapsed seconds),
    ``self.lm_cost`` (language-model token usage) and ``self.rm_cost``
    (retriever query counts), all keyed by method name.
    """

    def __init__(self, lm_configs: LMConfigs):
        self.lm_configs = lm_configs
        self.time = {}
        self.lm_cost = {}  # Cost of language models measured by in/out tokens.
        self.rm_cost = {}  # Cost of retrievers measured by number of queries.

    def log_execution_time_and_lm_rm_usage(self, func):
        """Decorator to log the execution time, language model usage, and retrieval model usage of a function.

        After each successful call of ``func``, records under ``func.__name__``:
        the elapsed seconds in ``self.time``, the result of
        ``self.lm_configs.collect_and_reset_lm_usage()`` in ``self.lm_cost`` and,
        if this engine has a ``retriever`` attribute, the result of
        ``self.retriever.collect_and_reset_rm_usage()`` in ``self.rm_cost``.
        Nothing is recorded if ``func`` raises.
        """
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            # perf_counter is monotonic; time.time() can jump with wall-clock changes.
            start_time = time.perf_counter()
            result = func(*args, **kwargs)
            execution_time = time.perf_counter() - start_time
            self.time[func.__name__] = execution_time
            logger.info("%s executed in %.4f seconds", func.__name__, execution_time)
            self.lm_cost[func.__name__] = self.lm_configs.collect_and_reset_lm_usage()
            if hasattr(self, 'retriever'):
                self.rm_cost[func.__name__] = self.retriever.collect_and_reset_rm_usage()
            return result
        return wrapper

    def apply_decorators(self):
        """Apply decorators to methods that need them.

        Every callable attribute whose name starts with 'run_' is replaced, on this
        instance, by a version wrapped with ``log_execution_time_and_lm_rm_usage``.
        ``run`` itself does not match the prefix and is left untouched.
        NOTE: calling this more than once wraps the methods again.
        """
        # Filter on the name first so getattr (which may trigger properties) only
        # runs for candidate 'run_' attributes.
        methods_to_decorate = [
            method_name for method_name in dir(self)
            if method_name.startswith('run_') and callable(getattr(self, method_name))
        ]
        for method_name in methods_to_decorate:
            original_method = getattr(self, method_name)
            decorated_method = self.log_execution_time_and_lm_rm_usage(original_method)
            setattr(self, method_name, decorated_method)

    @abstractmethod
    def run_knowledge_curation_module(self, **kwargs) -> Optional[InformationTable]:
        pass

    @abstractmethod
    def run_outline_generation_module(self, **kwargs) -> article:
        pass

    @abstractmethod
    def run_article_generation_module(self, **kwargs) -> article:
        pass

    @abstractmethod
    def run_article_polishing_module(self, **kwargs) -> article:
        pass

    @abstractmethod
    def run(self, **kwargs):
        pass

    def summary(self):
        """Print the recorded execution times, LM token usage and retriever query counts."""
        print("***** Execution time *****")
        for k, v in self.time.items():
            print(f"{k}: {v:.4f} seconds")

        print("***** Token usage of language models: *****")
        for k, v in self.lm_cost.items():
            print(f"{k}")
            for model_name, tokens in v.items():
                print(f" {model_name}: {tokens}")

        print("***** Number of queries of retrieval models: *****")
        for k, v in self.rm_cost.items():
            print(f"{k}: {v}")

    def reset(self):
        """Clear all recorded execution times and usage statistics."""
        self.time = {}
        self.lm_cost = {}
        self.rm_cost = {}