atolat30 committed
Commit e15199a · 1 Parent(s): 63073b9

Implement conservative token limits and better context management

Files changed (2)
  1. aimakerspace/openai_utils/chatmodel.py +30 -4
  2. app.py +3 -3
aimakerspace/openai_utils/chatmodel.py CHANGED
@@ -11,8 +11,10 @@ class ChatOpenAI:
         self.openai_api_key = os.getenv("OPENAI_API_KEY")
         if self.openai_api_key is None:
             raise ValueError("OPENAI_API_KEY is not set")
-        self.max_tokens = 8192  # Maximum tokens for response
-        self.max_total_tokens = 16384  # Maximum total tokens (prompt + response)
+        # More conservative token limits
+        self.max_tokens = 4000  # Reduced from 8192 to leave room for context
+        self.max_total_tokens = 8000  # Reduced from 16384 to be safe
+        self.temperature = 0.7
 
     def run(self, messages, text_only: bool = True, **kwargs):
         if not isinstance(messages, list):
@@ -20,11 +22,23 @@ class ChatOpenAI:
 
         client = OpenAI()
         try:
+            # Estimate tokens in messages (rough estimate: 1 token ≈ 4 characters)
+            total_chars = sum(len(str(msg.get('content', ''))) for msg in messages)
+            estimated_tokens = total_chars // 4
+
+            if estimated_tokens > 4000:  # If messages are too long
+                print(f"Warning: Messages too long ({estimated_tokens} estimated tokens). Truncating...")
+                # Keep only the most recent messages that fit
+                while estimated_tokens > 4000 and len(messages) > 2:  # Keep system message and last user message
+                    messages.pop(1)  # Remove oldest message after system message
+                    total_chars = sum(len(str(msg.get('content', ''))) for msg in messages)
+                    estimated_tokens = total_chars // 4
+
             response = client.chat.completions.create(
                 model=self.model_name,
                 messages=messages,
                 max_tokens=self.max_tokens,
-                temperature=0.7,  # Add some creativity while maintaining accuracy
+                temperature=self.temperature,
                 **kwargs
             )
             if text_only:
@@ -41,11 +55,23 @@ class ChatOpenAI:
         client = AsyncOpenAI()
 
         try:
+            # Estimate tokens in messages (rough estimate: 1 token ≈ 4 characters)
+            total_chars = sum(len(str(msg.get('content', ''))) for msg in messages)
+            estimated_tokens = total_chars // 4
+
+            if estimated_tokens > 4000:  # If messages are too long
+                print(f"Warning: Messages too long ({estimated_tokens} estimated tokens). Truncating...")
+                # Keep only the most recent messages that fit
+                while estimated_tokens > 4000 and len(messages) > 2:  # Keep system message and last user message
+                    messages.pop(1)  # Remove oldest message after system message
+                    total_chars = sum(len(str(msg.get('content', ''))) for msg in messages)
+                    estimated_tokens = total_chars // 4
+
             stream = await client.chat.completions.create(
                 model=self.model_name,
                 messages=messages,
                 max_tokens=self.max_tokens,
-                temperature=0.7,  # Add some creativity while maintaining accuracy
+                temperature=self.temperature,
                 stream=True,
                 **kwargs
             )
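The estimate-and-truncate block added above is duplicated verbatim in the synchronous run path and the async streaming path. The snippet below is a minimal standalone sketch of that logic; the helper names estimate_tokens and truncate_messages and the default 4000-token budget are illustrative only (they are not part of this commit), and the 1 token ≈ 4 characters rule is the same rough heuristic the diff relies on.

def estimate_tokens(messages: list[dict]) -> int:
    """Rough estimate: 1 token is about 4 characters of message content."""
    total_chars = sum(len(str(msg.get("content", ""))) for msg in messages)
    return total_chars // 4


def truncate_messages(messages: list[dict], budget: int = 4000) -> list[dict]:
    """Drop the oldest non-system messages until the estimate fits the budget."""
    trimmed = list(messages)  # copy so the caller's history is not mutated
    # Keep trimmed[0] (assumed system message) and the most recent message,
    # mirroring the while-loop added in this commit.
    while estimate_tokens(trimmed) > budget and len(trimmed) > 2:
        trimmed.pop(1)  # remove the oldest message after the system message
    return trimmed


if __name__ == "__main__":
    history = [
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": "x" * 20000},       # old, long turn
        {"role": "assistant", "content": "y" * 20000},  # old, long turn
        {"role": "user", "content": "What is the capital of France?"},
    ]
    trimmed = truncate_messages(history)
    print(len(history), "->", len(trimmed), "messages,",
          estimate_tokens(trimmed), "estimated tokens")

Working on a copy also avoids the in-place messages.pop(1) mutation of the caller's history that the committed code performs.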
app.py CHANGED
@@ -32,12 +32,12 @@ class RetrievalAugmentedQAPipeline:
 
     async def arun_pipeline(self, user_query: str):
         # Get more contexts but limit the total length
-        context_list = self.vector_db_retriever.search_by_text(user_query, k=6)
+        context_list = self.vector_db_retriever.search_by_text(user_query, k=3)  # Reduced from 6 to 3
 
-        # Limit total context length to approximately 6000 tokens (24000 characters)
+        # Limit total context length to approximately 3000 tokens (12000 characters)
         context_prompt = ""
         total_length = 0
-        max_length = 24000  # Rough estimate: 1 token ≈ 4 characters
+        max_length = 12000  # Reduced from 24000 to 12000
 
         for context in context_list:
             if total_length + len(context[0]) > max_length:
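For reference, a self-contained sketch of the capped context assembly in arun_pipeline, assuming the retriever returns (text, score) tuples as the context[0] indexing suggests. FakeRetriever, build_context, and the sample data are hypothetical stand-ins for the real vector_db_retriever, and the break on overflow is an assumption about how the truncated loop continues.

class FakeRetriever:
    """Hypothetical stand-in for vector_db_retriever; returns (text, score) pairs."""

    def __init__(self, docs):
        self.docs = docs

    def search_by_text(self, query, k=3):
        # Stand-in for the real vector search: return the first k (text, score) pairs.
        return self.docs[:k]


def build_context(retriever, user_query: str, k: int = 3, max_length: int = 12000) -> str:
    """Assemble retrieved chunks until the ~12000-character (~3000-token) cap is reached."""
    context_list = retriever.search_by_text(user_query, k=k)
    context_prompt = ""
    total_length = 0
    for context in context_list:
        if total_length + len(context[0]) > max_length:
            break  # assumption: stop adding chunks once the cap would be exceeded
        context_prompt += context[0] + "\n"
        total_length += len(context[0])
    return context_prompt


if __name__ == "__main__":
    docs = [("chunk about topic A " * 200, 0.9),
            ("chunk about topic B " * 200, 0.8),
            ("chunk about topic C " * 200, 0.7)]
    ctx = build_context(FakeRetriever(docs), "What is topic A?")
    print(f"{len(ctx)} characters of context kept")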