File size: 3,663 Bytes
2b61f9d
 
 
 
 
 
 
 
63073b9
2b61f9d
 
 
 
e15199a
 
 
 
2b61f9d
 
 
 
 
 
63073b9
e15199a
 
 
 
 
 
 
 
 
 
 
 
63073b9
 
 
 
e15199a
63073b9
 
 
 
 
 
 
 
2b61f9d
 
 
 
 
 
 
63073b9
e15199a
 
 
 
 
 
 
 
 
 
 
 
63073b9
 
 
 
e15199a
63073b9
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
from openai import OpenAI, AsyncOpenAI
from dotenv import load_dotenv
import os

# Load variables from a local .env file (e.g. OPENAI_API_KEY) into the
# process environment before any of the classes below read it.
load_dotenv()


class ChatOpenAI:
    """Small wrapper around the OpenAI chat-completions API with naive
    prompt-size guarding.

    Token counts are *estimated* (roughly 4 characters per token). When the
    estimate exceeds the prompt budget, the oldest non-system messages are
    dropped until the conversation fits; the system message (index 0) and
    the final message are always kept.
    """

    # Rough heuristic used throughout: 1 token ~= 4 characters.
    _CHARS_PER_TOKEN = 4
    # Estimated-token budget for the prompt before truncation kicks in.
    _MAX_PROMPT_TOKENS = 4000

    def __init__(self, model_name: str = "gpt-4"):
        """Store the model name and validate that an API key is configured.

        Raises:
            ValueError: if OPENAI_API_KEY is absent from the environment.
        """
        self.model_name = model_name
        self.openai_api_key = os.getenv("OPENAI_API_KEY")
        if self.openai_api_key is None:
            raise ValueError("OPENAI_API_KEY is not set")
        # Conservative completion-size limits to leave room for context.
        self.max_tokens = 4000  # Reduced from 8192 to leave room for context
        self.max_total_tokens = 8000  # Reduced from 16384 to be safe
        self.temperature = 0.7

    def _estimate_tokens(self, messages) -> int:
        """Rough token estimate for *messages* (1 token ~= 4 characters)."""
        total_chars = sum(len(str(msg.get("content", ""))) for msg in messages)
        return total_chars // self._CHARS_PER_TOKEN

    def _truncate_messages(self, messages):
        """Return a truncated *copy* of messages fitting the prompt budget.

        Fix: the original implementation popped from the caller's list in
        place, silently destroying the caller's conversation history. We
        copy first, so the caller's list is never mutated. This logic was
        also duplicated verbatim in ``run`` and ``astream``; it now lives
        in one place.
        """
        messages = list(messages)
        estimated_tokens = self._estimate_tokens(messages)
        if estimated_tokens > self._MAX_PROMPT_TOKENS:
            print(f"Warning: Messages too long ({estimated_tokens} estimated tokens). Truncating...")
            # Keep system message (index 0) and the last user message.
            while estimated_tokens > self._MAX_PROMPT_TOKENS and len(messages) > 2:
                messages.pop(1)  # drop the oldest message after the system message
                estimated_tokens = self._estimate_tokens(messages)
        return messages

    def run(self, messages, text_only: bool = True, **kwargs):
        """Synchronous chat completion.

        Args:
            messages: list of chat messages (dicts with a 'content' key).
            text_only: when True return just the reply text, otherwise the
                full response object.
            **kwargs: extra parameters forwarded to
                ``chat.completions.create``.

        Raises:
            ValueError: if *messages* is not a list.
            Exception: any API error is logged and re-raised unchanged.
        """
        if not isinstance(messages, list):
            raise ValueError("messages must be a list")

        # Pass the key validated in __init__ explicitly instead of relying
        # on the client silently re-reading the environment.
        client = OpenAI(api_key=self.openai_api_key)
        try:
            messages = self._truncate_messages(messages)
            response = client.chat.completions.create(
                model=self.model_name,
                messages=messages,
                max_tokens=self.max_tokens,
                temperature=self.temperature,
                **kwargs,
            )
            if text_only:
                return response.choices[0].message.content
            return response
        except Exception as e:
            print(f"Error in chat completion: {str(e)}")
            raise

    async def astream(self, messages, **kwargs):
        """Async generator yielding completion text chunks as they arrive.

        Args:
            messages: list of chat messages (dicts with a 'content' key).
            **kwargs: extra parameters forwarded to
                ``chat.completions.create``.

        Yields:
            Non-empty text deltas from the streaming response.

        Raises:
            ValueError: if *messages* is not a list.
            Exception: any API error is logged and re-raised unchanged.
        """
        if not isinstance(messages, list):
            raise ValueError("messages must be a list")

        client = AsyncOpenAI(api_key=self.openai_api_key)
        try:
            messages = self._truncate_messages(messages)
            stream = await client.chat.completions.create(
                model=self.model_name,
                messages=messages,
                max_tokens=self.max_tokens,
                temperature=self.temperature,
                stream=True,
                **kwargs,
            )
            async for chunk in stream:
                content = chunk.choices[0].delta.content
                if content is not None:
                    yield content
        except Exception as e:
            print(f"Error in chat completion stream: {str(e)}")
            raise