from openai import OpenAI, AsyncOpenAI
from dotenv import load_dotenv
import os

load_dotenv()


class ChatOpenAI:
    """Thin wrapper around the OpenAI chat-completions API.

    Offers a synchronous `run` and an asynchronous streaming `astream`
    entry point. Over-long conversations are truncated with a crude
    character-based token estimate before being sent.
    """

    # Rough heuristic: 1 token ~= 4 characters of content.
    _CHARS_PER_TOKEN = 4
    # Estimated-prompt-size threshold at which truncation kicks in.
    _PROMPT_TOKEN_BUDGET = 4000

    def __init__(self, model_name: str = "gpt-4"):
        """Store model settings and validate that an API key is present.

        Args:
            model_name: chat model to use for all requests.

        Raises:
            ValueError: if OPENAI_API_KEY is not set in the environment.
        """
        self.model_name = model_name
        self.openai_api_key = os.getenv("OPENAI_API_KEY")
        if self.openai_api_key is None:
            raise ValueError("OPENAI_API_KEY is not set")
        # More conservative token limits:
        # leave room for context / stay safely inside the model window.
        self.max_tokens = 4000        # reduced from 8192
        self.max_total_tokens = 8000  # reduced from 16384
        self.temperature = 0.7

    @classmethod
    def _estimate_tokens(cls, messages) -> int:
        """Rough token estimate for *messages*: total content chars // 4."""
        total_chars = sum(len(str(msg.get('content', ''))) for msg in messages)
        return total_chars // cls._CHARS_PER_TOKEN

    def _truncate_messages(self, messages: list) -> list:
        """Return a copy of *messages* truncated to fit the token budget.

        Repeatedly drops the oldest non-system message (index 1, so the
        system message and the most recent messages are kept) until the
        estimate fits or only two messages remain. Operates on a copy:
        the caller's list is never mutated (the original popped in place).
        """
        messages = list(messages)
        estimated_tokens = self._estimate_tokens(messages)
        if estimated_tokens > self._PROMPT_TOKEN_BUDGET:
            print(f"Warning: Messages too long ({estimated_tokens} estimated tokens). Truncating...")
            # Keep system message and last user message
            while estimated_tokens > self._PROMPT_TOKEN_BUDGET and len(messages) > 2:
                messages.pop(1)  # Remove oldest message after system message
                estimated_tokens = self._estimate_tokens(messages)
        return messages

    def run(self, messages, text_only: bool = True, **kwargs):
        """Run a chat completion synchronously.

        Args:
            messages: list of ``{"role": ..., "content": ...}`` dicts.
            text_only: if True, return only the assistant's text;
                otherwise return the full response object.
            **kwargs: forwarded verbatim to ``chat.completions.create``.

        Returns:
            The assistant message content (str) or the raw response.

        Raises:
            ValueError: if *messages* is not a list.
        """
        if not isinstance(messages, list):
            raise ValueError("messages must be a list")
        # Pass the key validated in __init__ rather than relying on the
        # client re-reading the environment.
        client = OpenAI(api_key=self.openai_api_key)
        try:
            messages = self._truncate_messages(messages)
            response = client.chat.completions.create(
                model=self.model_name,
                messages=messages,
                max_tokens=self.max_tokens,
                temperature=self.temperature,
                **kwargs,
            )
            if text_only:
                return response.choices[0].message.content
            return response
        except Exception as e:
            print(f"Error in chat completion: {str(e)}")
            raise

    async def astream(self, messages, **kwargs):
        """Stream a chat completion, yielding text deltas as they arrive.

        Args:
            messages: list of ``{"role": ..., "content": ...}`` dicts.
            **kwargs: forwarded verbatim to ``chat.completions.create``.

        Yields:
            str: non-empty content chunks from the streamed response.

        Raises:
            ValueError: if *messages* is not a list.
        """
        if not isinstance(messages, list):
            raise ValueError("messages must be a list")
        client = AsyncOpenAI(api_key=self.openai_api_key)
        try:
            messages = self._truncate_messages(messages)
            stream = await client.chat.completions.create(
                model=self.model_name,
                messages=messages,
                max_tokens=self.max_tokens,
                temperature=self.temperature,
                stream=True,
                **kwargs,
            )
            async for chunk in stream:
                content = chunk.choices[0].delta.content
                if content is not None:
                    yield content
        except Exception as e:
            print(f"Error in chat completion stream: {str(e)}")
            raise