# gpt_researcher/skills/context_manager.py
import asyncio
from typing import List, Dict, Optional, Set
from ..context.compression import ContextCompressor, WrittenContentCompressor, VectorstoreCompressor
from ..actions.utils import stream_output


class ContextManager:
    """Manages context for the researcher agent."""

    def __init__(self, researcher):
        self.researcher = researcher

    async def get_similar_content_by_query(self, query, pages):
        """Compress scraped pages down to the content most relevant to the query."""
        if self.researcher.verbose:
            await stream_output(
                "logs",
                "fetching_query_content",
                f"πŸ“š Getting relevant content based on query: {query}...",
                self.researcher.websocket,
            )
        context_compressor = ContextCompressor(
            documents=pages, embeddings=self.researcher.memory.get_embeddings()
        )
        return await context_compressor.async_get_context(
            query=query, max_results=10, cost_callback=self.researcher.add_costs
        )

    async def get_similar_content_by_query_with_vectorstore(self, query, filter):
        """Compress documents retrieved from the vector store down to the content most relevant to the query."""
        if self.researcher.verbose:
            await stream_output(
                "logs",
                "fetching_query_format",
                f"πŸ“š Getting relevant content based on query: {query}...",
                self.researcher.websocket,
            )
        vectorstore_compressor = VectorstoreCompressor(self.researcher.vector_store, filter)
        return await vectorstore_compressor.async_get_context(query=query, max_results=8)

    async def get_similar_written_contents_by_draft_section_titles(
        self,
        current_subtopic: str,
        draft_section_titles: List[str],
        written_contents: List[Dict],
        max_results: int = 10
    ) -> List[str]:
        """Gather previously written content relevant to the current subtopic and its draft section titles.

        Each draft section title (plus the subtopic itself) is used as a query;
        the deduplicated union of the results is returned, capped at max_results.
        """
        all_queries = [current_subtopic] + draft_section_titles

        async def process_query(query: str) -> Set[str]:
            return set(await self.__get_similar_written_contents_by_query(query, written_contents))

        # Run all queries concurrently and merge the results into a deduplicated set
        results = await asyncio.gather(*[process_query(query) for query in all_queries])
        relevant_contents = set().union(*results)
        relevant_contents = list(relevant_contents)[:max_results]

        if relevant_contents and self.researcher.verbose:
            prettier_contents = "\n".join(relevant_contents)
            await stream_output(
                "logs", "relevant_contents_context", f"πŸ“ƒ {prettier_contents}", self.researcher.websocket
            )

        return relevant_contents

    async def __get_similar_written_contents_by_query(self,
                                                      query: str,
                                                      written_contents: List[Dict],
                                                      similarity_threshold: float = 0.5,
                                                      max_results: int = 10
                                                      ) -> List[str]:
        """Return written contents whose similarity to the query meets the given threshold."""
        if self.researcher.verbose:
            await stream_output(
                "logs",
                "fetching_relevant_written_content",
                f"πŸ”Ž Getting relevant written content based on query: {query}...",
                self.researcher.websocket,
            )
        written_content_compressor = WrittenContentCompressor(
            documents=written_contents,
            embeddings=self.researcher.memory.get_embeddings(),
            similarity_threshold=similarity_threshold
        )
        return await written_content_compressor.async_get_context(
            query=query, max_results=max_results, cost_callback=self.researcher.add_costs
        )
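

# ---------------------------------------------------------------------------
# Illustrative usage sketch (not part of the library's API; never called here).
# ContextManager expects a `researcher` object exposing `verbose`, `websocket`,
# `vector_store`, `add_costs`, and `memory.get_embeddings()` -- in the wider
# GPT-Researcher codebase that object is the agent which instantiates this
# class. The query strings, section titles, and the empty `written_contents`
# placeholder below are hypothetical values used only to show the call pattern.
# ---------------------------------------------------------------------------
async def _example_usage(researcher, scraped_pages: List[Dict]) -> None:
    context_manager = ContextManager(researcher)

    # Compress raw scraped pages down to the passages relevant to a query.
    query_context = await context_manager.get_similar_content_by_query(
        "history of transformer models", scraped_pages
    )

    # Re-use previously written sections when drafting a new subtopic.
    written_context = await context_manager.get_similar_written_contents_by_draft_section_titles(
        current_subtopic="Attention mechanisms",
        draft_section_titles=["Scaled dot-product attention", "Multi-head attention"],
        written_contents=[],  # placeholder; normally the report sections written so far
    )

    print(query_context, written_context)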