|
import functools |
|
import logging |
|
import time |
|
from abc import ABC, abstractmethod |
|
from collections import OrderedDict |
|
from typing import Dict, List, Optional, Union |
|
|
|
logging.basicConfig(level=logging.INFO, format='%(name)s : %(levelname)-8s : %(message)s') |
|
logger = logging.getLogger(__name__) |
|
|
|
|
|
class Information(ABC): |
|
"""Abstract base class to represent basic information. |
|
|
|
Attributes: |
|
uuid (str): The unique identifier for the information. |
|
meta (dict): The meta information associated with the information. |
|
""" |
|
|
|
def __init__(self, uuid, meta={}): |
|
self.uuid = uuid |
|
self.meta = meta |
|
|
|
|
|
class InformationTable(ABC): |
|
""" |
|
The InformationTable class serves as data class to store the information |
|
collected during KnowledgeCuration stage. |
|
|
|
Create subclass to incorporate more information as needed. For example, |
|
in STORM paper https://arxiv.org/pdf/2402.14207.pdf, additional information |
|
would be perspective guided dialogue history. |
|
""" |
|
|
|
def __init__(self): |
|
pass |
|
|
|
@abstractmethod |
|
def retrieve_information(**kwargs): |
|
pass |
|
|
|
|
|
class articleSectionNode(ABC): |
|
""" |
|
The articleSectionNode is the dataclass for handling the section of the article. |
|
The content storage, section writing preferences are defined in this node. |
|
""" |
|
|
|
def __init__(self, section_name: str, content=None): |
|
""" |
|
section_name: section heading in string format. E.g. Introduction, History, etc. |
|
content: content of the section. Up to you for design choice of the data structure. |
|
""" |
|
self.section_name = section_name |
|
self.content = content |
|
self.children = [] |
|
self.preference = None |
|
self.keywords = [] |
|
|
|
def add_child(self, new_child_node, insert_to_front=False): |
|
if insert_to_front: |
|
self.children.insert(0, new_child_node) |
|
else: |
|
self.children.append(new_child_node) |
|
|
|
def remove_child(self, child): |
|
self.children.remove(child) |
|
|
|
|
|
class article(ABC): |
|
def __init__(self, topic_name): |
|
self.root = articleSectionNode(topic_name) |
|
|
|
def find_section(self, node: articleSectionNode, name: str) -> Optional[articleSectionNode]: |
|
""" |
|
Return the node of the section given the section name. |
|
|
|
Args: |
|
node: the node as the root to find. |
|
name: the name of node as section name |
|
|
|
Return: |
|
reference of the node or None if section name has no match |
|
""" |
|
if node.section_name == name: |
|
return node |
|
for child in node.children: |
|
result = self.find_section(child, name) |
|
if result: |
|
return result |
|
return None |
|
|
|
@abstractmethod |
|
def to_string(self) -> str: |
|
""" |
|
Export article object into string representation. |
|
""" |
|
|
|
def get_outline_tree(self): |
|
""" |
|
Generates a hierarchical tree structure representing the outline of the document. |
|
|
|
Returns: |
|
Dict[str, Dict]: A nested dictionary representing the hierarchical structure of the document's outline. |
|
Each key is a section name, and the value is another dictionary representing the child sections, |
|
recursively forming the tree structure of the document's outline. If a section has no subsections, |
|
its value is an empty dictionary. |
|
|
|
Example: |
|
Assuming a document with a structure like: |
|
- Introduction |
|
- Background |
|
- Objective |
|
- Methods |
|
- Data Collection |
|
- Analysis |
|
The method would return: |
|
{ |
|
'Introduction': { |
|
'Background': {}, |
|
'Objective': {} |
|
}, |
|
'Methods': { |
|
'Data Collection': {}, |
|
'Analysis': {} |
|
} |
|
} |
|
""" |
|
|
|
def build_tree(node) -> Dict[str, Dict]: |
|
tree = {} |
|
for child in node.children: |
|
tree[child.section_name] = build_tree(child) |
|
return tree if tree else {} |
|
|
|
return build_tree(self.root) |
|
|
|
def get_first_level_section_names(self) -> List[str]: |
|
""" |
|
Get first level section names |
|
""" |
|
return [i.section_name for i in self.root.children] |
|
|
|
@classmethod |
|
@abstractmethod |
|
def from_string(cls, topic_name: str, article_text: str): |
|
""" |
|
Create an instance of the article object from a string |
|
""" |
|
pass |
|
|
|
def prune_empty_nodes(self, node=None): |
|
if node is None: |
|
node = self.root |
|
|
|
node.children[:] = [child for child in node.children if self.prune_empty_nodes(child)] |
|
|
|
if (node.content is None or node.content == "") and not node.children: |
|
return None |
|
else: |
|
return node |
|
|
|
|
|
class Retriever(ABC): |
|
""" |
|
An abstract base class for retriever modules. It provides a template for retrieving information based on a query. |
|
|
|
This class should be extended to implement specific retrieval functionalities. |
|
Users can design their retriever modules as needed by implementing the retrieve method. |
|
The retrieval model/search engine used for each part should be declared with a suffix '_rm' in the attribute name. |
|
""" |
|
|
|
def __init__(self, search_top_k): |
|
self.search_top_k = search_top_k |
|
|
|
def update_search_top_k(self, k): |
|
self.search_top_k = k |
|
|
|
def collect_and_reset_rm_usage(self): |
|
combined_usage = [] |
|
for attr_name in self.__dict__: |
|
if '_rm' in attr_name and hasattr(getattr(self, attr_name), 'get_usage_and_reset'): |
|
combined_usage.append(getattr(self, attr_name).get_usage_and_reset()) |
|
|
|
name_to_usage = {} |
|
for usage in combined_usage: |
|
for model_name, query_cnt in usage.items(): |
|
if model_name not in name_to_usage: |
|
name_to_usage[model_name] = query_cnt |
|
else: |
|
name_to_usage[model_name] += query_cnt |
|
|
|
return name_to_usage |
|
|
|
@abstractmethod |
|
def retrieve(self, query: Union[str, List[str]], **kwargs) -> List[Information]: |
|
""" |
|
Retrieves information based on a query. |
|
|
|
This method must be implemented by subclasses to specify how information is retrieved. |
|
|
|
Args: |
|
query (Union[str, List[str]]): The query or list of queries to retrieve information for. |
|
**kwargs: Additional keyword arguments that might be necessary for the retrieval process. |
|
|
|
Returns: |
|
List[Information]: A list of Information objects retrieved based on the query. |
|
""" |
|
pass |
|
|
|
|
|
class KnowledgeCurationModule(ABC): |
|
""" |
|
The interface for knowledge curation stage. Given topic, return collected information. |
|
""" |
|
|
|
def __init__(self, retriever: Retriever): |
|
""" |
|
Store args and finish initialization. |
|
""" |
|
self.retriever = retriever |
|
|
|
@abstractmethod |
|
def research(self, topic) -> InformationTable: |
|
""" |
|
Curate information and knowledge for the given topic |
|
|
|
Args: |
|
topic: topic of interest in natural language. |
|
|
|
Returns: |
|
collected_information: collected information in InformationTable type. |
|
""" |
|
pass |
|
|
|
|
|
class OutlineGenerationModule(ABC): |
|
""" |
|
The interface for outline generation stage. Given topic, collected information from knowledge |
|
curation stage, generate outline for the article. |
|
""" |
|
|
|
@abstractmethod |
|
def generate_outline(self, topic: str, information_table: InformationTable, **kwargs) -> article: |
|
""" |
|
Generate outline for the article. Required arguments include: |
|
topic: the topic of interest |
|
information_table: knowledge curation data generated from KnowledgeCurationModule |
|
|
|
More arguments could be |
|
1. draft outline |
|
2. user provided outline |
|
|
|
Returns: |
|
article_outline of type articleOutline |
|
""" |
|
pass |
|
|
|
|
|
class articleGenerationModule(ABC): |
|
""" |
|
The interface for article generation stage. Given topic, collected information from |
|
knowledge curation stage, generated outline from outline generation stage, |
|
""" |
|
|
|
@abstractmethod |
|
def generate_article(self, |
|
topic: str, |
|
information_table: InformationTable, |
|
article_with_outline: article, |
|
**kwargs) -> article: |
|
""" |
|
Generate article. Required arguments include: |
|
topic: the topic of interest |
|
information_table: knowledge curation data generated from KnowledgeCurationModule |
|
article_with_outline: article with specified outline from OutlineGenerationModule |
|
""" |
|
pass |
|
|
|
|
|
class articlePolishingModule(ABC): |
|
""" |
|
The interface for article generation stage. Given topic, collected information from |
|
knowledge curation stage, generated outline from outline generation stage, |
|
""" |
|
|
|
@abstractmethod |
|
def polish_article(self, topic: str, draft_article: article, **kwargs) -> article: |
|
""" |
|
Polish article. Required arguments include: |
|
topic: the topic of interest |
|
draft_article: draft article from articleGenerationModule. |
|
""" |
|
pass |
|
|
|
|
|
def log_execution_time(func): |
|
"""Decorator to log the execution time of a function.""" |
|
|
|
@functools.wraps(func) |
|
def wrapper(self, *args, **kwargs): |
|
start_time = time.time() |
|
result = func(self, *args, **kwargs) |
|
end_time = time.time() |
|
execution_time = end_time - start_time |
|
logger.info(f"{func.__name__} executed in {execution_time:.4f} seconds") |
|
self.time[func.__name__] = execution_time |
|
return result |
|
|
|
return wrapper |
|
|
|
|
|
class LMConfigs(ABC): |
|
"""Abstract base class for language model configurations of the knowledge curation engine. |
|
|
|
The language model used for each part should be declared with a suffix '_lm' in the attribute name.""" |
|
|
|
def __init__(self): |
|
pass |
|
|
|
def init_check(self): |
|
for attr_name in self.__dict__: |
|
if '_lm' in attr_name and getattr(self, attr_name) is None: |
|
logging.warning( |
|
f"Language model for {attr_name} is not initialized. Please call set_{attr_name}()" |
|
) |
|
|
|
def collect_and_reset_lm_history(self): |
|
history = [] |
|
for attr_name in self.__dict__: |
|
if '_lm' in attr_name and hasattr(getattr(self, attr_name), 'history'): |
|
history.extend(getattr(self, attr_name).history) |
|
getattr(self, attr_name).history = [] |
|
|
|
return history |
|
|
|
def collect_and_reset_lm_usage(self): |
|
combined_usage = [] |
|
for attr_name in self.__dict__: |
|
if '_lm' in attr_name and hasattr(getattr(self, attr_name), 'get_usage_and_reset'): |
|
combined_usage.append(getattr(self, attr_name).get_usage_and_reset()) |
|
|
|
model_name_to_usage = {} |
|
for usage in combined_usage: |
|
for model_name, tokens in usage.items(): |
|
if model_name not in model_name_to_usage: |
|
model_name_to_usage[model_name] = tokens |
|
else: |
|
model_name_to_usage[model_name]['prompt_tokens'] += tokens['prompt_tokens'] |
|
model_name_to_usage[model_name]['completion_tokens'] += tokens['completion_tokens'] |
|
|
|
return model_name_to_usage |
|
|
|
def log(self): |
|
|
|
return OrderedDict( |
|
{ |
|
attr_name: getattr(self, attr_name).kwargs for attr_name in self.__dict__ if |
|
'_lm' in attr_name and hasattr(getattr(self, attr_name), 'kwargs') |
|
} |
|
) |
|
|
|
|
|
class Engine(ABC): |
|
def __init__(self, lm_configs: LMConfigs): |
|
self.lm_configs = lm_configs |
|
self.time = {} |
|
self.lm_cost = {} |
|
self.rm_cost = {} |
|
|
|
def log_execution_time_and_lm_rm_usage(self, func): |
|
"""Decorator to log the execution time, language model usage, and retrieval model usage of a function.""" |
|
|
|
@functools.wraps(func) |
|
def wrapper(*args, **kwargs): |
|
start_time = time.time() |
|
result = func(*args, **kwargs) |
|
end_time = time.time() |
|
execution_time = end_time - start_time |
|
self.time[func.__name__] = execution_time |
|
logger.info(f"{func.__name__} executed in {execution_time:.4f} seconds") |
|
self.lm_cost[func.__name__] = self.lm_configs.collect_and_reset_lm_usage() |
|
if hasattr(self, 'retriever'): |
|
self.rm_cost[func.__name__] = self.retriever.collect_and_reset_rm_usage() |
|
return result |
|
|
|
return wrapper |
|
|
|
def apply_decorators(self): |
|
"""Apply decorators to methods that need them.""" |
|
methods_to_decorate = [method_name for method_name in dir(self) |
|
if callable(getattr(self, method_name)) and method_name.startswith('run_')] |
|
for method_name in methods_to_decorate: |
|
original_method = getattr(self, method_name) |
|
decorated_method = self.log_execution_time_and_lm_rm_usage(original_method) |
|
setattr(self, method_name, decorated_method) |
|
|
|
@abstractmethod |
|
def run_knowledge_curation_module(self, **kwargs) -> Optional[InformationTable]: |
|
pass |
|
|
|
@abstractmethod |
|
def run_outline_generation_module(self, **kwarg) -> article: |
|
pass |
|
|
|
@abstractmethod |
|
def run_article_generation_module(self, **kwarg) -> article: |
|
pass |
|
|
|
@abstractmethod |
|
def run_article_polishing_module(self, **kwarg) -> article: |
|
pass |
|
|
|
@abstractmethod |
|
def run(self, **kwargs): |
|
pass |
|
|
|
def summary(self): |
|
print("***** Execution time *****") |
|
for k, v in self.time.items(): |
|
print(f"{k}: {v:.4f} seconds") |
|
|
|
print("***** Token usage of language models: *****") |
|
for k, v in self.lm_cost.items(): |
|
print(f"{k}") |
|
for model_name, tokens in v.items(): |
|
print(f" {model_name}: {tokens}") |
|
|
|
print("***** Number of queries of retrieval models: *****") |
|
for k, v in self.rm_cost.items(): |
|
print(f"{k}: {v}") |
|
|
|
def reset(self): |
|
self.time = {} |
|
self.lm_cost = {} |
|
self.rm_cost = {} |
|
|