Update app.py
app.py CHANGED
@@ -538,6 +538,7 @@ def format_to_message_dict(history):
     messages = []
     for item in history:
         if isinstance(item, dict) and "role" in item and "content" in item:
+            # Already in the correct format
             messages.append(item)
         elif isinstance(item, list) and len(item) == 2:
             # Convert from old format [user_msg, ai_msg]
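For context, format_to_message_dict normalizes Gradio chat history into OpenAI-style message dicts. A minimal sketch of the full helper, assuming the pair-splitting body that the hunk cuts off after the comment:

def format_to_message_dict(history):
    messages = []
    for item in history or []:
        if isinstance(item, dict) and "role" in item and "content" in item:
            # Already in the correct format
            messages.append(item)
        elif isinstance(item, list) and len(item) == 2:
            # Convert from old format [user_msg, ai_msg]
            user_msg, ai_msg = item
            if user_msg:
                messages.append({"role": "user", "content": user_msg})
            if ai_msg:
                messages.append({"role": "assistant", "content": ai_msg})
    return messages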
@@ -1262,7 +1263,6 @@ def ask_ai(message, history, provider, model_choice, temperature, max_tokens, to
         model_id, _ = get_model_info(provider, model_choice)
         if not model_id:
             error_message = f"Error: Model '{model_choice}' not found in OpenRouter"
-            # Use proper message format
             return history + [
                 {"role": "user", "content": message},
                 {"role": "assistant", "content": error_message}
@@ -1312,12 +1312,11 @@ def ask_ai(message, history, provider, model_choice, temperature, max_tokens, to

         # Handle streaming response
         if stream_output and response.status_code == 200:
-            # Add message to history
-            updated_history = history + [{"role": "user", "content": message}]
-
             # Set up generator for streaming updates
             def streaming_generator():
+                updated_history = history + [{"role": "user", "content": message}]
                 assistant_response = ""
+
                 for line in response.iter_lines():
                     if not line:
                         continue
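The point of this hunk is that updated_history moves inside streaming_generator(), so the user turn is captured when the stream actually runs and every yield hands Gradio a complete messages-format history. A sketch of the whole generator, assuming the standard SSE "data:" framing of the chat-completions stream for the elided parsing lines (history, message, response, json, and logger come from the enclosing ask_ai scope):

def streaming_generator():
    # Build the base history here so the user turn is attached at stream time
    updated_history = history + [{"role": "user", "content": message}]
    assistant_response = ""
    for line in response.iter_lines():
        if not line:
            continue
        data = line.decode("utf-8")
        if data.startswith("data: "):
            data = data[len("data: "):]
        if data == "[DONE]":
            break
        try:
            chunk = json.loads(data)
            delta = chunk["choices"][0].get("delta", {})
            if "content" in delta and delta["content"]:
                assistant_response += delta["content"]
                # Return updated history with current response
                yield updated_history + [{"role": "assistant", "content": assistant_response}]
        except json.JSONDecodeError:
            logger.error(f"Failed to parse JSON from chunk: {data}")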
@@ -1337,7 +1336,7 @@ def ask_ai(message, history, provider, model_choice, temperature, max_tokens, to
                     if "content" in delta and delta["content"]:
                         # Update the current response
                         assistant_response += delta["content"]
-                        #
+                        # Return updated history with current response
                         yield updated_history + [{"role": "assistant", "content": assistant_response}]
             except json.JSONDecodeError:
                 logger.error(f"Failed to parse JSON from chunk: {data}")
@@ -1374,6 +1373,7 @@ def ask_ai(message, history, provider, model_choice, temperature, max_tokens, to
             ]

     elif provider == "OpenAI":
+        # Process OpenAI similarly as above...
         # Get model ID from registry
         model_id, _ = get_model_info(provider, model_choice)
         if not model_id:
@@ -1402,12 +1402,11 @@ def ask_ai(message, history, provider, model_choice, temperature, max_tokens, to

         # Handle streaming response
         if stream_output:
-            # Add message to history
-            updated_history = history + [{"role": "user", "content": message}]
-
             # Set up generator for streaming updates
             def streaming_generator():
+                updated_history = history + [{"role": "user", "content": message}]
                 assistant_response = ""
+
                 for chunk in response:
                     if hasattr(chunk.choices[0].delta, "content") and chunk.choices[0].delta.content is not None:
                         content = chunk.choices[0].delta.content
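The OpenAI branch gets the identical relocation. For comparison, a condensed sketch of this generator under the openai>=1.x SDK, where stream chunks are typed objects exposing .choices[0].delta.content rather than raw SSE lines:

def streaming_generator():
    updated_history = history + [{"role": "user", "content": message}]
    assistant_response = ""
    for chunk in response:
        delta = chunk.choices[0].delta
        if getattr(delta, "content", None) is not None:
            assistant_response += delta.content
            yield updated_history + [{"role": "assistant", "content": assistant_response}]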
@@ -1432,7 +1431,6 @@ def ask_ai(message, history, provider, model_choice, temperature, max_tokens, to
             ]

     elif provider == "HuggingFace":
-        # Get model ID from registry
         model_id, _ = get_model_info(provider, model_choice)
         if not model_id:
             error_message = f"Error: Model '{model_choice}' not found in HuggingFace"
@@ -2392,23 +2390,23 @@ def create_app():
     def get_current_model(provider, openrouter_model, openai_model, hf_model, groq_model, cohere_model, together_model, ovh_model, cerebras_model, googleai_model):
         """Get the currently selected model based on provider"""
         if provider == "OpenRouter":
-            return openrouter_model
+            return openrouter_model if openrouter_model else OPENROUTER_ALL_MODELS[0][0] if OPENROUTER_ALL_MODELS else None
         elif provider == "OpenAI":
-            return openai_model
+            return openai_model if openai_model else "gpt-3.5-turbo" if "gpt-3.5-turbo" in OPENAI_MODELS else None
         elif provider == "HuggingFace":
-            return hf_model
+            return hf_model if hf_model else "mistralai/Mistral-7B-Instruct-v0.3" if "mistralai/Mistral-7B-Instruct-v0.3" in HUGGINGFACE_MODELS else None
         elif provider == "Groq":
-            return groq_model
+            return groq_model if groq_model else "llama-3.1-8b-instant" if "llama-3.1-8b-instant" in GROQ_MODELS else None
         elif provider == "Cohere":
-            return cohere_model
+            return cohere_model if cohere_model else "command-r-plus" if "command-r-plus" in COHERE_MODELS else None
         elif provider == "Together":
-            return together_model
+            return together_model if together_model else "meta-llama/Llama-3.1-8B-Instruct" if "meta-llama/Llama-3.1-8B-Instruct" in TOGETHER_MODELS else None
         elif provider == "OVH":
-            return ovh_model
+            return ovh_model if ovh_model else "ovh/llama-3.1-8b-instruct" if "ovh/llama-3.1-8b-instruct" in OVH_MODELS else None
         elif provider == "Cerebras":
-            return cerebras_model
+            return cerebras_model if cerebras_model else "cerebras/llama-3.1-8b" if "cerebras/llama-3.1-8b" in CEREBRAS_MODELS else None
         elif provider == "GoogleAI":
-            return googleai_model
+            return googleai_model if googleai_model else "gemini-1.5-pro" if "gemini-1.5-pro" in GOOGLEAI_MODELS else None
         return None

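One readability note on the new return lines: Python's conditional expression is right-associative, so each line parses as selected if selected else (default if default in REGISTRY else None). A table-driven equivalent that avoids repeating the pattern nine times; the registry names come from this diff, while _DEFAULTS and fallback_model are hypothetical helpers:

# Hypothetical refactor; OPENAI_MODELS etc. are the module's registry lists
_DEFAULTS = {
    "OpenAI": ("gpt-3.5-turbo", OPENAI_MODELS),
    "HuggingFace": ("mistralai/Mistral-7B-Instruct-v0.3", HUGGINGFACE_MODELS),
    "Groq": ("llama-3.1-8b-instant", GROQ_MODELS),
    # ... remaining providers follow the same (default, registry) shape
}

def fallback_model(provider, selected):
    if selected:
        return selected
    if provider == "OpenRouter":
        # OpenRouter falls back to the first registry entry, not a fixed name
        return OPENROUTER_ALL_MODELS[0][0] if OPENROUTER_ALL_MODELS else None
    default, registry = _DEFAULTS.get(provider, (None, []))
    return default if default in registry else None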