tommytracx committed
Commit 9f5d5d3 · verified · 1 Parent(s): 3a0d58c

Update models/local_llm.py

Files changed (1)
  1. models/local_llm.py +191 -6
models/local_llm.py CHANGED
@@ -1,7 +1,192 @@
- import subprocess

- def run_llm(prompt: str) -> str:
-     result = subprocess.run([
-         "./main", "-m", "models/ggml-model.bin", "-p", prompt, "-n", "128"
-     ], capture_output=True, text=True)
-     return result.stdout
+ """
+ LLM implementation using Hugging Face Inference Endpoint with OpenAI compatibility.
+ """
+ import requests
+ import os
+ import json
+ import logging
+ from typing import Dict, List, Optional, Any

+ # Configure logging
+ logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
+ logger = logging.getLogger(__name__)
+
+ # Endpoint configuration
+ HF_API_KEY = os.environ.get("HF_API_KEY", "")
+ ENDPOINT_URL = os.environ.get("ENDPOINT_URL", "https://cg01ow7izccjx1b2.us-east-1.aws.endpoints.huggingface.cloud/v1/chat/completions")
+
+ # Verify configuration
+ if not HF_API_KEY:
+     logger.warning("HF_API_KEY environment variable not set")
+ if not ENDPOINT_URL:
+     logger.warning("ENDPOINT_URL environment variable not set")
+
+ # Memory store for conversation history
+ conversation_memory: Dict[str, List[Dict[str, str]]] = {}
+
+ def run_llm(input_text: str, max_tokens: int = 512, temperature: float = 0.7) -> str:
+     """
+     Process input text through the HF Inference Endpoint.
+
+     Args:
+         input_text: User input to process
+         max_tokens: Maximum tokens to generate
+         temperature: Temperature for sampling (higher = more random)
+
+     Returns:
+         Generated response text
+     """
+     headers = {
+         "Authorization": f"Bearer {HF_API_KEY}",
+         "Content-Type": "application/json"
+     }
+
+     # Format messages in OpenAI format
+     messages = [
+         {"role": "system", "content": "You are a helpful AI assistant for a telecom service. Answer questions clearly and concisely."},
+         {"role": "user", "content": input_text}
+     ]
+
+     payload = {
+         "model": "meta-llama/Meta-Llama-3.1-8B-Instruct",
+         "messages": messages,
+         "max_tokens": max_tokens,
+         "temperature": temperature
+     }
+
+     logger.debug(f"Sending request to endpoint with temperature={temperature}, max_tokens={max_tokens}")
+
+     try:
+         response = requests.post(ENDPOINT_URL, headers=headers, json=payload)
+         response.raise_for_status()
+
+         result = response.json()
+         response_text = result["choices"][0]["message"]["content"]
+         logger.debug(f"Generated response of {len(response_text)} characters")
+         return response_text
+
+     except requests.exceptions.RequestException as e:
+         error_msg = f"Error calling endpoint: {str(e)}"
+         if hasattr(e, 'response') and e.response is not None:
+             error_msg += f" - Status code: {e.response.status_code}, Response: {e.response.text}"
+         logger.error(error_msg)
+         return f"Error generating response: {str(e)}"
+
+ def run_llm_with_memory(input_text: str, session_id: str = "default", max_tokens: int = 512, temperature: float = 0.7) -> str:
+     """
+     Process input with conversation memory.
+
+     Args:
+         input_text: User input to process
+         session_id: Unique identifier for the conversation
+         max_tokens: Maximum tokens to generate
+         temperature: Temperature for sampling
+
+     Returns:
+         Generated response text
+     """
+     # Initialize memory if needed
+     if session_id not in conversation_memory:
+         conversation_memory[session_id] = [
+             {"role": "system", "content": "You are a helpful AI assistant for a telecom service. Answer questions clearly and concisely."}
+         ]
+
+     # Add current input to memory
+     conversation_memory[session_id].append({"role": "user", "content": input_text})
+
+     # Prepare the full conversation history
+     messages = conversation_memory[session_id].copy()
+
+     # Keep only the last 10 messages to avoid context length issues
+     if len(messages) > 10:
+         # Always keep the system message
+         messages = [messages[0]] + messages[-9:]
+
+     headers = {
+         "Authorization": f"Bearer {HF_API_KEY}",
+         "Content-Type": "application/json"
+     }
+
+     payload = {
+         "model": "meta-llama/Meta-Llama-3.1-8B-Instruct",
+         "messages": messages,
+         "max_tokens": max_tokens,
+         "temperature": temperature
+     }
+
+     logger.debug(f"Sending memory-based request for session {session_id}")
+
+     try:
+         response = requests.post(ENDPOINT_URL, headers=headers, json=payload)
+         response.raise_for_status()
+
+         result = response.json()
+         response_text = result["choices"][0]["message"]["content"]
+
+         # Save response to memory
+         conversation_memory[session_id].append({"role": "assistant", "content": response_text})
+
+         return response_text
+
+     except requests.exceptions.RequestException as e:
+         error_msg = f"Error calling endpoint: {str(e)}"
+         if hasattr(e, 'response') and e.response is not None:
+             error_msg += f" - Status code: {e.response.status_code}, Response: {e.response.text}"
+         logger.error(error_msg)
+         return f"Error generating response: {str(e)}"
+
+ def clear_memory(session_id: str = "default") -> bool:
+     """
+     Clear conversation memory for a specific session.
+
+     Args:
+         session_id: Unique identifier for the conversation
+
+     Returns:
+         True if the session existed and was reset, False otherwise
+     """
+     if session_id in conversation_memory:
+         conversation_memory[session_id] = [
+             {"role": "system", "content": "You are a helpful AI assistant for a telecom service. Answer questions clearly and concisely."}
+         ]
+         return True
+     return False
+
+ def get_memory_sessions() -> List[str]:
+     """
+     Get the list of active memory sessions.
+
+     Returns:
+         List of session IDs
+     """
+     return list(conversation_memory.keys())
+
+ def get_model_info() -> Dict[str, Any]:
+     """
+     Get information about the connected model endpoint.
+
+     Returns:
+         Dictionary with endpoint information
+     """
+     return {
+         "endpoint_url": ENDPOINT_URL,
+         "memory_sessions": len(conversation_memory),
+         "model_type": "Meta-Llama-3.1-8B-Instruct (Inference Endpoint)"
+     }
+
+ def test_endpoint() -> Dict[str, Any]:
+     """
+     Test the endpoint connection.
+
+     Returns:
+         Status information
+     """
+     try:
+         response = run_llm("Hello, this is a test message. Please respond with a short greeting.")
+         return {
+             "status": "connected",
+             "message": "Successfully connected to endpoint",
+             "sample_response": response[:50] + "..." if len(response) > 50 else response
+         }
+     except Exception as e:
+         return {
+             "status": "error",
+             "message": f"Failed to connect to endpoint: {str(e)}"
+         }
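
For reviewers who want to exercise the change, here is a minimal usage sketch of the new API. The import path models.local_llm, the placeholder token, and the session id are illustrative assumptions rather than part of the commit; HF_API_KEY (and optionally ENDPOINT_URL) must be set before the module is imported, since both are read at import time.

# Hypothetical usage sketch -- import path and env setup are assumptions,
# not part of the commit itself.
import os

# Both variables are read at module import time, so set them first.
os.environ.setdefault("HF_API_KEY", "hf_xxx")  # placeholder; a real token is required

from models.local_llm import run_llm, run_llm_with_memory, clear_memory, test_endpoint

# Stateless call: sends only the system prompt plus this one user message.
print(run_llm("Which plans include international roaming?", max_tokens=256))

# Memory-backed calls: history accumulates per session_id and is trimmed to the
# last 10 messages (the system prompt is always kept) before each request.
print(run_llm_with_memory("My router keeps rebooting.", session_id="user-42"))
print(run_llm_with_memory("What did I just ask about?", session_id="user-42"))

clear_memory("user-42")   # reset the session back to just the system prompt
print(test_endpoint())    # connectivity check; returns a status dict instead of raising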