Implement conservative token limits and better context management
aimakerspace/openai_utils/chatmodel.py  +30 -4
app.py  +3 -3
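Before the diffs, a rough sense of how the new numbers are meant to fit together, using the same 1 token ≈ 4 characters heuristic the changed code adopts below. This is an illustrative snippet, not part of the commit; the 4000-token prompt budget, 4000-token max_tokens, and 8000-token max_total_tokens are the values introduced in chatmodel.py.

# Rough budget check for the new limits (illustrative only, not in the commit).
chars_per_token = 4            # heuristic used by the truncation code below
prompt_budget_tokens = 4000    # messages are truncated down to roughly this many tokens
max_tokens = 4000              # completion budget (self.max_tokens)
max_total_tokens = 8000        # overall ceiling (self.max_total_tokens)

# The prompt budget plus the completion budget should not exceed the total ceiling.
assert prompt_budget_tokens + max_tokens <= max_total_tokens
print(f"~{prompt_budget_tokens * chars_per_token} characters of prompt fit under the cap")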
aimakerspace/openai_utils/chatmodel.py
CHANGED
@@ -11,8 +11,10 @@ class ChatOpenAI:
         self.openai_api_key = os.getenv("OPENAI_API_KEY")
         if self.openai_api_key is None:
             raise ValueError("OPENAI_API_KEY is not set")
-        self.max_tokens = 8192
-        self.max_total_tokens = 16384
+        # More conservative token limits
+        self.max_tokens = 4000 # Reduced from 8192 to leave room for context
+        self.max_total_tokens = 8000 # Reduced from 16384 to be safe
+        self.temperature = 0.7
 
     def run(self, messages, text_only: bool = True, **kwargs):
         if not isinstance(messages, list):
@@ -20,11 +22,23 @@
 
         client = OpenAI()
         try:
+            # Estimate tokens in messages (rough estimate: 1 token ≈ 4 characters)
+            total_chars = sum(len(str(msg.get('content', ''))) for msg in messages)
+            estimated_tokens = total_chars // 4
+
+            if estimated_tokens > 4000: # If messages are too long
+                print(f"Warning: Messages too long ({estimated_tokens} estimated tokens). Truncating...")
+                # Keep only the most recent messages that fit
+                while estimated_tokens > 4000 and len(messages) > 2: # Keep system message and last user message
+                    messages.pop(1) # Remove oldest message after system message
+                    total_chars = sum(len(str(msg.get('content', ''))) for msg in messages)
+                    estimated_tokens = total_chars // 4
+
             response = client.chat.completions.create(
                 model=self.model_name,
                 messages=messages,
                 max_tokens=self.max_tokens,
-                temperature=
+                temperature=self.temperature,
                 **kwargs
             )
             if text_only:
@@ -41,11 +55,23 @@
         client = AsyncOpenAI()
 
         try:
+            # Estimate tokens in messages (rough estimate: 1 token ≈ 4 characters)
+            total_chars = sum(len(str(msg.get('content', ''))) for msg in messages)
+            estimated_tokens = total_chars // 4
+
+            if estimated_tokens > 4000: # If messages are too long
+                print(f"Warning: Messages too long ({estimated_tokens} estimated tokens). Truncating...")
+                # Keep only the most recent messages that fit
+                while estimated_tokens > 4000 and len(messages) > 2: # Keep system message and last user message
+                    messages.pop(1) # Remove oldest message after system message
+                    total_chars = sum(len(str(msg.get('content', ''))) for msg in messages)
+                    estimated_tokens = total_chars // 4
+
             stream = await client.chat.completions.create(
                 model=self.model_name,
                 messages=messages,
                 max_tokens=self.max_tokens,
-                temperature=
+                temperature=self.temperature,
                 stream=True,
                 **kwargs
             )
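Read in isolation, the truncation logic that both run() and astream() now share can be written as a standalone helper. This is a minimal sketch, not code from the repository: truncate_messages is a hypothetical name, and it assumes OpenAI-style message dicts with the system prompt at index 0, mirroring the 4-characters-per-token estimate and 4000-token budget from the diff above.

# Minimal sketch of the truncation heuristic added to run()/astream() above.
# Hypothetical helper, not part of the commit.
def truncate_messages(messages, token_budget=4000, chars_per_token=4):
    def estimated_tokens(msgs):
        return sum(len(str(m.get("content", ""))) for m in msgs) // chars_per_token

    messages = list(messages)  # copy so the caller's history is not mutated
    while estimated_tokens(messages) > token_budget and len(messages) > 2:
        messages.pop(1)  # drop the oldest message after the system prompt
    return messages

# Example: a long history shrinks to the system prompt plus the newest turns.
history = [{"role": "system", "content": "You are helpful."}] + [
    {"role": "user", "content": "x" * 5000} for _ in range(5)
]
print(len(truncate_messages(history)))  # prints 4 for this input

One difference worth noting: the diff truncates with messages.pop(1) in place, which also shortens the list the caller passed in, while the sketch above works on a copy to avoid that side effect.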
app.py
CHANGED
@@ -32,12 +32,12 @@ class RetrievalAugmentedQAPipeline:
 
     async def arun_pipeline(self, user_query: str):
         # Get more contexts but limit the total length
-        context_list = self.vector_db_retriever.search_by_text(user_query, k=6)
+        context_list = self.vector_db_retriever.search_by_text(user_query, k=3) # Reduced from 6 to 3
 
-        # Limit total context length to approximately
+        # Limit total context length to approximately 3000 tokens (12000 characters)
         context_prompt = ""
         total_length = 0
-        max_length = 24000
+        max_length = 12000 # Reduced from 24000 to 12000
 
         for context in context_list:
             if total_length + len(context[0]) > max_length: