tommytracx committed · Commit 46f013b · verified · 1 Parent(s): db8e1eb

Update models/local_llm.py

Files changed (1)
  1. models/local_llm.py +6 -130
models/local_llm.py CHANGED
@@ -5,7 +5,6 @@ import requests
 import os
 import json
 import logging
-from typing import Dict, List, Optional, Any
 
 # Configure logging
 logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
@@ -21,17 +20,14 @@ if not HF_API_KEY:
 if not ENDPOINT_URL:
     logger.warning("ENDPOINT_URL environment variable not set")
 
-# Memory store for conversation history
-conversation_memory: Dict[str, List[Dict[str, str]]] = {}
-
-def run_llm(input_text: str, max_tokens: int = 512, temperature: float = 0.7) -> str:
+def run_llm(prompt, max_tokens=512, temperature=0.7):
     """
     Process input text through HF Inference Endpoint.
 
     Args:
-        input_text: User input to process
+        prompt: Input prompt to process
         max_tokens: Maximum tokens to generate
-        temperature: Temperature for sampling (higher = more random)
+        temperature: Temperature for sampling
 
     Returns:
         Generated response text
@@ -44,7 +40,7 @@ def run_llm(input_text: str, max_tokens: int = 512, temperature: float = 0.7) ->
     # Format messages in OpenAI format
     messages = [
         {"role": "system", "content": "You are a helpful AI assistant for a telecom service. Answer questions clearly and concisely."},
-        {"role": "user", "content": input_text}
+        {"role": "user", "content": prompt}
     ]
 
     payload = {
@@ -54,7 +50,7 @@ def run_llm(input_text: str, max_tokens: int = 512, temperature: float = 0.7) ->
         "temperature": temperature
     }
 
-    logger.debug(f"Sending request to endpoint with temperature={temperature}, max_tokens={max_tokens}")
+    logger.info(f"Sending request to endpoint: {ENDPOINT_URL[:30]}...")
 
     try:
         response = requests.post(ENDPOINT_URL, headers=headers, json=payload)
@@ -62,7 +58,6 @@ def run_llm(input_text: str, max_tokens: int = 512, temperature: float = 0.7) ->
 
         result = response.json()
         response_text = result["choices"][0]["message"]["content"]
-        logger.debug(f"Generated response of {len(response_text)} characters")
         return response_text
 
     except requests.exceptions.RequestException as e:
@@ -70,123 +65,4 @@ def run_llm(input_text: str, max_tokens: int = 512, temperature: float = 0.7) ->
         if hasattr(e, 'response') and e.response is not None:
             error_msg += f" - Status code: {e.response.status_code}, Response: {e.response.text}"
         logger.error(error_msg)
-        return f"Error generating response: {str(e)}"
-
-def run_llm_with_memory(input_text: str, session_id: str = "default", max_tokens: int = 512, temperature: float = 0.7) -> str:
-    """
-    Process input with conversation memory.
-
-    Args:
-        input_text: User input to process
-        session_id: Unique identifier for conversation
-        max_tokens: Maximum tokens to generate
-        temperature: Temperature for sampling
-
-    Returns:
-        Generated response text
-    """
-    # Initialize memory if needed
-    if session_id not in conversation_memory:
-        conversation_memory[session_id] = [
-            {"role": "system", "content": "You are a helpful AI assistant for a telecom service. Answer questions clearly and concisely."}
-        ]
-
-    # Add current input to memory
-    conversation_memory[session_id].append({"role": "user", "content": input_text})
-
-    # Prepare the full conversation history
-    messages = conversation_memory[session_id].copy()
-
-    # Keep only the last 10 messages to avoid context length issues
-    if len(messages) > 10:
-        # Always keep the system message
-        messages = [messages[0]] + messages[-9:]
-
-    headers = {
-        "Authorization": f"Bearer {HF_API_KEY}",
-        "Content-Type": "application/json"
-    }
-
-    payload = {
-        "model": "meta-llama/Meta-Llama-3.1-8B-Instruct",
-        "messages": messages,
-        "max_tokens": max_tokens,
-        "temperature": temperature
-    }
-
-    logger.debug(f"Sending memory-based request for session {session_id}")
-
-    try:
-        response = requests.post(ENDPOINT_URL, headers=headers, json=payload)
-        response.raise_for_status()
-
-        result = response.json()
-        response_text = result["choices"][0]["message"]["content"]
-
-        # Save response to memory
-        conversation_memory[session_id].append({"role": "assistant", "content": response_text})
-
-        return response_text
-
-    except requests.exceptions.RequestException as e:
-        error_msg = f"Error calling endpoint: {str(e)}"
-        if hasattr(e, 'response') and e.response is not None:
-            error_msg += f" - Status code: {e.response.status_code}, Response: {e.response.text}"
-        logger.error(error_msg)
-        return f"Error generating response: {str(e)}"
-
-def clear_memory(session_id: str = "default") -> bool:
-    """
-    Clear conversation memory for a specific session.
-
-    Args:
-        session_id: Unique identifier for conversation
-    """
-    if session_id in conversation_memory:
-        conversation_memory[session_id] = [
-            {"role": "system", "content": "You are a helpful AI assistant for a telecom service. Answer questions clearly and concisely."}
-        ]
-        return True
-    return False
-
-def get_memory_sessions() -> List[str]:
-    """
-    Get list of active memory sessions.
-
-    Returns:
-        List of session IDs
-    """
-    return list(conversation_memory.keys())
-
-def get_model_info() -> Dict[str, Any]:
-    """
-    Get information about the connected model endpoint.
-
-    Returns:
-        Dictionary with endpoint information
-    """
-    return {
-        "endpoint_url": ENDPOINT_URL,
-        "memory_sessions": len(conversation_memory),
-        "model_type": "Meta-Llama-3.1-8B-Instruct (Inference Endpoint)"
-    }
-
-def test_endpoint() -> Dict[str, Any]:
-    """
-    Test the endpoint connection.
-
-    Returns:
-        Status information
-    """
-    try:
-        response = run_llm("Hello, this is a test message. Please respond with a short greeting.")
-        return {
-            "status": "connected",
-            "message": "Successfully connected to endpoint",
-            "sample_response": response[:50] + "..." if len(response) > 50 else response
-        }
-    except Exception as e:
-        return {
-            "status": "error",
-            "message": f"Failed to connect to endpoint: {str(e)}"
-        }
+        return f"Error generating response: {str(e)}"
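After this commit only the single-turn run_llm helper remains; conversation state is no longer stored in the module. The following is a rough usage sketch, not part of the commit: it assumes the module is importable as models.local_llm, that HF_API_KEY and ENDPOINT_URL are set before import (the values below are placeholders), and that the Inference Endpoint returns an OpenAI-style chat completion, which is the shape run_llm parses via result["choices"][0]["message"]["content"].

# Usage sketch only -- env values are hypothetical placeholders, not real credentials.
import os

os.environ.setdefault("HF_API_KEY", "hf_xxx")  # placeholder token
os.environ.setdefault("ENDPOINT_URL", "https://example.endpoints.huggingface.cloud/v1/chat/completions")  # placeholder URL

# Import after setting the env vars, since the module checks them at import time.
from models.local_llm import run_llm

# Single-turn question; with run_llm_with_memory and conversation_memory removed,
# any multi-turn history now has to be assembled and passed by the caller.
reply = run_llm("Which plans include international roaming?", max_tokens=256, temperature=0.5)
print(reply)

Note that if the endpoint returns anything other than an OpenAI-style completion object, the choices[0].message.content lookup inside run_llm raises a KeyError that the requests exception handler does not catch, so the returned error string only covers transport and HTTP status failures.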