import json
import operator
from datetime import datetime
from pathlib import Path
from typing import List, Dict, Any, TypedDict, Annotated, Optional

from dotenv import load_dotenv
from langchain_core.language_models import BaseLanguageModel
from langchain_core.messages import AnyMessage, SystemMessage, ToolMessage
from langchain_core.tools import BaseTool
from langgraph.graph import StateGraph, END

_ = load_dotenv()


class ToolCallLog(TypedDict):
    """
    A TypedDict representing a log entry for a tool call.

    Attributes:
        timestamp (str): The timestamp of when the tool call was made.
        tool_call_id (str): The unique identifier for the tool call.
        name (str): The name of the tool that was called.
        args (Any): The arguments passed to the tool.
        content (str): The content or result of the tool call.
    """

    timestamp: str
    tool_call_id: str
    name: str
    args: Any
    content: str


class AgentState(TypedDict):
    """
    A TypedDict representing the state of an agent.

    Attributes:
        messages (Annotated[List[AnyMessage], operator.add]): A list of messages
            representing the conversation history. The operator.add annotation
            indicates that new messages should be appended to this list.
    """

    messages: Annotated[List[AnyMessage], operator.add]


class Agent:
    """
    An agent that processes requests and executes tools based on language model
    responses.

    Attributes:
        model: The language model, with the tools bound to it, used for processing.
        tools (Dict[str, BaseTool]): The available tools, keyed by tool name.
        system_prompt (str): The system instructions for the agent.
        workflow: The compiled LangGraph workflow that drives the agent's processing.
        log_tools (bool): Whether to log tool calls.
        log_path (Path): Path to the directory where tool call logs are saved.
    """

    def __init__(
        self,
        model: BaseLanguageModel,
        tools: List[BaseTool],
        checkpointer: Any = None,
        system_prompt: str = "",
        log_tools: bool = True,
        log_dir: Optional[str] = "logs",
    ):
        """
        Initialize the Agent.

        Args:
            model (BaseLanguageModel): The language model to use.
            tools (List[BaseTool]): A list of available tools.
            checkpointer (Any, optional): State persistence manager. Defaults to None.
            system_prompt (str, optional): System instructions. Defaults to "".
            log_tools (bool, optional): Whether to log tool calls. Defaults to True.
            log_dir (str, optional): Directory to save logs. Defaults to "logs".
        """
        self.system_prompt = system_prompt
        self.log_tools = log_tools

        if self.log_tools:
            self.log_path = Path(log_dir or "logs")
            self.log_path.mkdir(exist_ok=True)

        workflow = StateGraph(AgentState)
        workflow.add_node("process", self.process_request)
        workflow.add_node("execute", self.execute_tools)
        workflow.add_conditional_edges(
            "process", self.has_tool_calls, {True: "execute", False: END}
        )
        workflow.add_edge("execute", "process")
        workflow.set_entry_point("process")

        self.workflow = workflow.compile(checkpointer=checkpointer)
        self.tools = {t.name: t for t in tools}
        self.model = model.bind_tools(tools)

    def process_request(self, state: AgentState) -> Dict[str, List[AnyMessage]]:
        """
        Process the request using the language model.

        Args:
            state (AgentState): The current state of the agent.

        Returns:
            Dict[str, List[AnyMessage]]: A dictionary containing the model's response.
        """
        messages = state["messages"]
        if self.system_prompt:
            messages = [SystemMessage(content=self.system_prompt)] + messages
        response = self.model.invoke(messages)
        return {"messages": [response]}

    def has_tool_calls(self, state: AgentState) -> bool:
        """
        Check if the response contains any tool calls.

        Args:
            state (AgentState): The current state of the agent.

        Returns:
            bool: True if tool calls exist, False otherwise.
        """
        response = state["messages"][-1]
        return len(response.tool_calls) > 0

    def execute_tools(self, state: AgentState) -> Dict[str, List[ToolMessage]]:
        """
        Execute tool calls from the model's response.

        Args:
            state (AgentState): The current state of the agent.

        Returns:
            Dict[str, List[ToolMessage]]: A dictionary containing tool execution results.
        """
        tool_calls = state["messages"][-1].tool_calls
        results = []

        for call in tool_calls:
            print(f"Executing tool: {call}")
            if call["name"] not in self.tools:
                print("\n....invalid tool....")
                result = "invalid tool, please retry"
            else:
                result = self.tools[call["name"]].invoke(call["args"])

            # Keep the raw args on the message so _save_tool_calls can log them.
            results.append(
                ToolMessage(
                    tool_call_id=call["id"],
                    name=call["name"],
                    args=call["args"],
                    content=str(result),
                )
            )

        self._save_tool_calls(results)
        print("Returning to model processing!")

        return {"messages": results}

    def _save_tool_calls(self, tool_calls: List[ToolMessage]) -> None:
        """
        Save tool calls to a JSON file with timestamp-based naming.

        Args:
            tool_calls (List[ToolMessage]): List of tool calls to save.
        """
        if not self.log_tools:
            return

        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        filename = self.log_path / f"tool_calls_{timestamp}.json"

        logs: List[ToolCallLog] = []
        for call in tool_calls:
            log_entry = {
                "tool_call_id": call.tool_call_id,
                "name": call.name,
                "args": call.args,
                "content": call.content,
                "timestamp": datetime.now().isoformat(),
            }
            logs.append(log_entry)

        with open(filename, "w") as f:
            json.dump(logs, f, indent=4)
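

# Example usage: a minimal sketch, not part of the agent implementation above.
# It assumes `langchain-openai` is installed and an OPENAI_API_KEY is available
# via the .env file loaded at import time; the `multiply` tool, the model name,
# and the thread id below are illustrative placeholders only.
if __name__ == "__main__":
    from langchain_core.messages import HumanMessage
    from langchain_core.tools import tool
    from langchain_openai import ChatOpenAI
    from langgraph.checkpoint.memory import MemorySaver

    @tool
    def multiply(a: int, b: int) -> int:
        """Multiply two integers."""
        return a * b

    agent = Agent(
        model=ChatOpenAI(model="gpt-4o-mini"),
        tools=[multiply],
        checkpointer=MemorySaver(),
        system_prompt="You are a helpful assistant. Use tools when needed.",
    )

    # The MemorySaver checkpointer keys saved state by thread_id, so pass one in the config.
    config = {"configurable": {"thread_id": "demo-1"}}
    result = agent.workflow.invoke(
        {"messages": [HumanMessage(content="What is 6 times 7?")]}, config
    )
    print(result["messages"][-1].content)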