from __future__ import annotations

import json
import logging
import os

from typing import List, Optional, Union

import backoff
import openai

from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
from langchain.chat_models.base import BaseChatModel
from langchain.schema import (
    AIMessage,
    HumanMessage,
    SystemMessage,
    messages_from_dict,
    messages_to_dict,
)
from langchain_openai import AzureChatOpenAI, ChatOpenAI

from gpt_engineer.core.token_usage import TokenUsageLog

# Type hint for a chat message
Message = Union[AIMessage, HumanMessage, SystemMessage]

# Set up logging
logger = logging.getLogger(__name__)


class AI:
    def __init__(
        self,
        model_name="gpt-4-1106-preview",
        temperature=0.1,
        azure_endpoint="",
        streaming=True,
    ):
        """
        Initialize the AI class.

        Parameters
        ----------
        model_name : str, optional
            The name of the model to use, by default "gpt-4-1106-preview".
        temperature : float, optional
            The temperature to use for the model, by default 0.1.
        azure_endpoint : str, optional
            The Azure OpenAI endpoint to use; if empty, the standard OpenAI API is used.
        streaming : bool, optional
            Whether to stream the model's output, by default True.
        """
        self.temperature = temperature
        self.azure_endpoint = azure_endpoint
        self.model_name = model_name
        self.streaming = streaming
        self.llm = self._create_chat_model()
        self.token_usage_log = TokenUsageLog(model_name)

        logger.debug(f"Using model {self.model_name}")

    def start(self, system: str, user: str, step_name: str) -> List[Message]:
        """
        Start the conversation with a system message and a user message.

        Parameters
        ----------
        system : str
            The content of the system message.
        user : str
            The content of the user message.
        step_name : str
            The name of the step.

        Returns
        -------
        List[Message]
            The list of messages in the conversation.
        """
        messages: List[Message] = [
            SystemMessage(content=system),
            HumanMessage(content=user),
        ]
        return self.next(messages, step_name=step_name)

    def next(
        self,
        messages: List[Message],
        prompt: Optional[str] = None,
        *,
        step_name: str,
    ) -> List[Message]:
        """
        Advance the conversation by sending the message history to the LLM and
        appending the response.

        Parameters
        ----------
        messages : List[Message]
            The list of messages in the conversation.
        prompt : Optional[str], optional
            An additional user prompt to append before querying the model, by default None.
        step_name : str
            The name of the step.

        Returns
        -------
        List[Message]
            The updated list of messages in the conversation.
        """
        if prompt:
            messages.append(HumanMessage(content=prompt))

        logger.debug(f"Creating a new chat completion: {messages}")

        response = self.backoff_inference(messages)

        self.token_usage_log.update_log(
            messages=messages, answer=response.content, step_name=step_name
        )
        messages.append(response)
        logger.debug(f"Chat completion finished: {messages}")

        return messages

    @backoff.on_exception(backoff.expo, openai.RateLimitError, max_tries=7, max_time=45)
    def backoff_inference(self, messages):
        """
        Perform inference using the language model while implementing an exponential
        backoff strategy.

        This function will retry the inference in case of a rate limit error from the
        OpenAI API. It uses an exponential backoff strategy, meaning the wait time
        between retries increases exponentially. The function will attempt to retry
        up to 7 times within a span of 45 seconds.

        Parameters
        ----------
        messages : List[Message]
            A list of chat messages which will be passed to the language model for
            processing.
        Returns
        -------
        Any
            The output from the language model after processing the provided messages.

        Raises
        ------
        openai.RateLimitError
            If the number of retries exceeds the maximum or if the rate limit persists
            beyond the allotted time, the function will ultimately raise a
            RateLimitError.

        Example
        -------
        >>> messages = [SystemMessage(content="Hello"), HumanMessage(content="How's the weather?")]
        >>> response = ai.backoff_inference(messages)
        """
        return self.llm.invoke(messages)  # type: ignore

    @staticmethod
    def serialize_messages(messages: List[Message]) -> str:
        """
        Serialize a list of messages to a JSON string.

        Parameters
        ----------
        messages : List[Message]
            The list of messages to serialize.

        Returns
        -------
        str
            The serialized messages as a JSON string.
        """
        return json.dumps(messages_to_dict(messages))

    @staticmethod
    def deserialize_messages(jsondictstr: str) -> List[Message]:
        """
        Deserialize a JSON string to a list of messages.

        Parameters
        ----------
        jsondictstr : str
            The JSON string to deserialize.

        Returns
        -------
        List[Message]
            The deserialized list of messages.
        """
        data = json.loads(jsondictstr)

        # Modify implicit is_chunk property to ALWAYS false
        # since Langchain's Message schema is stricter
        prevalidated_data = [
            {**item, "tools": {**item.get("tools", {}), "is_chunk": False}}
            for item in data
        ]

        return list(messages_from_dict(prevalidated_data))  # type: ignore

    def _create_chat_model(self) -> BaseChatModel:
        """
        Create a chat model using the instance's model name, temperature, and
        streaming settings. If an Azure endpoint is configured, an Azure chat
        model is returned instead of the standard OpenAI one.

        Returns
        -------
        BaseChatModel
            The created chat model.
        """
        if self.azure_endpoint:
            return AzureChatOpenAI(
                openai_api_base=self.azure_endpoint,
                openai_api_version=os.getenv("OPENAI_API_VERSION", "2023-05-15"),
                deployment_name=self.model_name,
                openai_api_type="azure",
                streaming=self.streaming,
                callbacks=[StreamingStdOutCallbackHandler()],
            )

        return ChatOpenAI(
            model=self.model_name,
            temperature=self.temperature,
            streaming=self.streaming,
            callbacks=[StreamingStdOutCallbackHandler()],
        )


def serialize_messages(messages: List[Message]) -> str:
    return AI.serialize_messages(messages)
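

# Minimal usage sketch (illustrative only, not part of the module's public API).
# It assumes OPENAI_API_KEY is set in the environment, network access to the
# OpenAI API, and that the chosen model is available to the account; the
# system/user prompts and step name below are hypothetical examples.
if __name__ == "__main__":
    ai = AI(model_name="gpt-4-1106-preview", temperature=0.1)
    conversation = ai.start(
        system="You are a helpful coding assistant.",
        user="Write a Python function that reverses a string.",
        step_name="example",
    )
    # The assistant's reply is appended as the last message in the conversation.
    print(conversation[-1].content)
    # Conversations can be persisted as JSON and restored between steps.
    saved = serialize_messages(conversation)
    restored = AI.deserialize_messages(saved)
    print(f"Restored {len(restored)} messages from JSON")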