"""
LLM implementation using Hugging Face Inference Endpoint with OpenAI compatibility.
"""
import logging
import os

import requests

# Configure logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)

# Endpoint configuration
HF_API_KEY = os.environ.get("HF_API_KEY", "")
ENDPOINT_URL = os.environ.get(
    "ENDPOINT_URL",
    "https://cg01ow7izccjx1b2.us-east-1.aws.endpoints.huggingface.cloud/v1/chat/completions",
)

# Verify configuration
if not HF_API_KEY:
    logger.warning("HF_API_KEY environment variable not set")
if not ENDPOINT_URL:
    logger.warning("ENDPOINT_URL environment variable not set")

# Memory store for conversation history
conversation_memory = {}


def run_llm(input_text, max_tokens=512, temperature=0.7):
    """
    Process input text through the HF Inference Endpoint.

    Args:
        input_text: User input to process
        max_tokens: Maximum number of tokens to generate
        temperature: Sampling temperature (higher = more random)

    Returns:
        Generated response text, or an error message on failure
    """
    headers = {
        "Authorization": f"Bearer {HF_API_KEY}",
        "Content-Type": "application/json"
    }

    # Format messages in OpenAI chat-completions format
    messages = [
        {"role": "system", "content": "You are a helpful AI assistant for a telecom service. Answer questions clearly and concisely."},
        {"role": "user", "content": input_text}
    ]

    payload = {
        "model": "meta-llama/Meta-Llama-3.1-8B-Instruct",
        "messages": messages,
        "max_tokens": max_tokens,
        "temperature": temperature
    }

    logger.info(f"Sending request to endpoint: {ENDPOINT_URL[:30]}...")

    try:
        # A timeout keeps the call from hanging indefinitely if the endpoint is unresponsive
        response = requests.post(ENDPOINT_URL, headers=headers, json=payload, timeout=120)
        response.raise_for_status()
        result = response.json()
        response_text = result["choices"][0]["message"]["content"]
        return response_text
    except requests.exceptions.RequestException as e:
        error_msg = f"Error calling endpoint: {str(e)}"
        if hasattr(e, 'response') and e.response is not None:
            error_msg += f" - Status code: {e.response.status_code}, Response: {e.response.text}"
        logger.error(error_msg)
        return f"Error generating response: {str(e)}"


def run_llm_with_memory(input_text, session_id="default", max_tokens=512, temperature=0.7):
    """
    Process input text with conversation memory.

    Args:
        input_text: User input to process
        session_id: Unique identifier for the conversation
        max_tokens: Maximum number of tokens to generate
        temperature: Sampling temperature

    Returns:
        Generated response text, or an error message on failure
    """
    # Initialize memory if needed
    if session_id not in conversation_memory:
        conversation_memory[session_id] = [
            {"role": "system", "content": "You are a helpful AI assistant for a telecom service. Answer questions clearly and concisely."}
        ]

    # Add current input to memory
    conversation_memory[session_id].append({"role": "user", "content": input_text})

    # Prepare the full conversation history
    messages = conversation_memory[session_id].copy()

    # Keep only the last 10 messages to avoid context length issues,
    # always preserving the system message at index 0
    if len(messages) > 10:
        messages = [messages[0]] + messages[-9:]

    headers = {
        "Authorization": f"Bearer {HF_API_KEY}",
        "Content-Type": "application/json"
    }

    payload = {
        "model": "meta-llama/Meta-Llama-3.1-8B-Instruct",
        "messages": messages,
        "max_tokens": max_tokens,
        "temperature": temperature
    }

    logger.info(f"Sending memory-based request for session {session_id}")

    try:
        # A timeout keeps the call from hanging indefinitely if the endpoint is unresponsive
        response = requests.post(ENDPOINT_URL, headers=headers, json=payload, timeout=120)
        response.raise_for_status()
        result = response.json()
        response_text = result["choices"][0]["message"]["content"]

        # Save the assistant response to memory
        conversation_memory[session_id].append({"role": "assistant", "content": response_text})
        return response_text
    except requests.exceptions.RequestException as e:
        error_msg = f"Error calling endpoint: {str(e)}"
        if hasattr(e, 'response') and e.response is not None:
            error_msg += f" - Status code: {e.response.status_code}, Response: {e.response.text}"
        logger.error(error_msg)
        return f"Error generating response: {str(e)}"


def clear_memory(session_id="default"):
    """
    Clear conversation memory for a specific session.

    Args:
        session_id: Unique identifier for the conversation

    Returns:
        True if the session existed and was reset, False otherwise
    """
    if session_id in conversation_memory:
        conversation_memory[session_id] = [
            {"role": "system", "content": "You are a helpful AI assistant for a telecom service. Answer questions clearly and concisely."}
        ]
        return True
    return False


def get_memory_sessions():
    """
    Get the list of active memory sessions.

    Returns:
        List of session IDs
    """
    return list(conversation_memory.keys())


def get_model_info():
    """
    Get information about the connected model endpoint.

    Returns:
        Dictionary with endpoint information
    """
    return {
        "endpoint_url": ENDPOINT_URL,
        "memory_sessions": len(conversation_memory),
        "model_type": "Meta-Llama-3.1-8B-Instruct (Inference Endpoint)"
    }


def test_endpoint():
    """
    Test the endpoint connection.

    Returns:
        Status information
    """
    try:
        response = run_llm("Hello, this is a test message. Please respond with a short greeting.")
        # run_llm catches request errors and returns an error string instead of raising,
        # so treat that string as a failed connection rather than a success
        if response.startswith("Error generating response"):
            return {
                "status": "error",
                "message": response
            }
        return {
            "status": "connected",
            "message": "Successfully connected to endpoint",
            "sample_response": response[:50] + "..." if len(response) > 50 else response
        }
    except Exception as e:
        return {
            "status": "error",
            "message": f"Failed to connect to endpoint: {str(e)}"
        }
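

# A minimal usage sketch, assuming HF_API_KEY (and optionally ENDPOINT_URL) are set
# in the environment before running this module directly; the prompts and the
# "demo" session id below are illustrative only, not part of the original module.
if __name__ == "__main__":
    import pprint

    # Verify connectivity first, then exercise the memory-based path
    pprint.pprint(test_endpoint())
    print(run_llm_with_memory("What plans do you offer?", session_id="demo"))
    print(run_llm_with_memory("Which of those is cheapest?", session_id="demo"))
    print("Active sessions:", get_memory_sessions())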