redfernstech committed on
Commit
f572b7e
·
verified ·
1 Parent(s): 1eede41

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +40 -6
app.py CHANGED
@@ -229,11 +229,13 @@ from fastapi import FastAPI, Request, HTTPException
229
  from fastapi.responses import HTMLResponse, JSONResponse
230
  from fastapi.staticfiles import StaticFiles
231
  from llama_index.core import StorageContext, load_index_from_storage, VectorStoreIndex, SimpleDirectoryReader, ChatPromptTemplate, Settings
 
 
232
  from llama_index.embeddings.huggingface import HuggingFaceEmbedding
233
- from llama_index.llms.huggingface import HuggingFaceInferenceAPI
234
  from pydantic import BaseModel
235
  from fastapi.middleware.cors import CORSMiddleware
236
  from fastapi.templating import Jinja2Templates
 
237
  import datetime
238
  from simple_salesforce import Salesforce, SalesforceLogin
239
 
@@ -241,6 +243,41 @@ from simple_salesforce import Salesforce, SalesforceLogin
241
  class MessageRequest(BaseModel):
242
  message: str
243
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
244
  # Validate environment variables
245
  required_env_vars = ["HF_TOKEN", "username", "password", "security_token", "domain"]
246
  for var in required_env_vars:
@@ -288,11 +325,9 @@ app.mount("/static", StaticFiles(directory="static"), name="static")
288
  templates = Jinja2Templates(directory="static")
289
 
290
  # LlamaIndex configuration
291
- Settings.llm = HuggingFaceInferenceAPI(
292
  model_name="meta-llama/Meta-Llama-3-8B-Instruct",
293
- token=os.getenv("HF_TOKEN"),
294
- max_new_tokens=512,
295
- temperature=0.1
296
  )
297
 
298
  Settings.embed_model = HuggingFaceEmbedding(
@@ -404,7 +439,6 @@ async def receive_form_data(request: Request):
404
  'Phone': form_data.get('phone', '').strip(),
405
  'Email': form_data.get('email', ''),
406
  }
407
-
408
  result = sf.Lead.create(data)
409
  return JSONResponse({"id": result['id']})
410
  except Exception as e:
 
229
  from fastapi.responses import HTMLResponse, JSONResponse
230
  from fastapi.staticfiles import StaticFiles
231
  from llama_index.core import StorageContext, load_index_from_storage, VectorStoreIndex, SimpleDirectoryReader, ChatPromptTemplate, Settings
232
+ from llama_index.core.base.llms.types import ChatMessage, MessageRole
233
+ from llama_index.core.llms import LLM
234
  from llama_index.embeddings.huggingface import HuggingFaceEmbedding
 
235
  from pydantic import BaseModel
236
  from fastapi.middleware.cors import CORSMiddleware
237
  from fastapi.templating import Jinja2Templates
238
+ from huggingface_hub import InferenceClient
239
  import datetime
240
  from simple_salesforce import Salesforce, SalesforceLogin
241
 
 
243
class MessageRequest(BaseModel):
    """Request body for the chat endpoint."""

    message: str  # raw user message text sent by the client
245
 
246
# Minimal LlamaIndex LLM adapter backed by the Hugging Face Inference API.
class HuggingFaceInferenceLLM(LLM):
    """Route LlamaIndex chat calls through ``huggingface_hub.InferenceClient``.

    NOTE(review): llama_index's ``LLM`` base class is a pydantic model and
    declares several abstract methods (``complete``, ``stream_chat``, ...).
    Depending on the pinned llama_index version, plain attribute assignment
    in ``__init__`` and instantiating this subclass may both be rejected —
    confirm against the installed version.
    """

    def __init__(self, model_name: str, token: str):
        super().__init__()
        self.client = InferenceClient(model=model_name, token=token)
        self.model_name = model_name

    def chat(self, messages: list[ChatMessage], **kwargs) -> str:
        """Generate a reply for a chat history.

        The history is flattened into a ``role: content`` transcript (any
        non-user role is rendered as "assistant").  API failures are
        returned as an error string instead of being raised, so callers
        always receive text.
        """
        transcript = "".join(
            f"{'user' if msg.role == MessageRole.USER else 'assistant'}: "
            f"{msg.content}\n"
            for msg in messages
        )
        try:
            return self.client.text_generation(
                transcript,
                max_new_tokens=512,
                temperature=0.1,
                do_sample=True,
                # "\n" separates turns in the transcript, so stopping at a
                # newline prevents the model from inventing the next turn —
                # it also limits every reply to a single line; confirm that
                # is intended.
                stop_sequences=["\n"],
            )
        except Exception as e:  # deliberate best-effort: surface error as text
            return f"Error in API call: {str(e)}"

    async def achat(self, messages: list[ChatMessage], **kwargs) -> str:
        """Async variant of :meth:`chat`.

        The underlying client call is blocking network I/O, so run it in a
        worker thread instead of blocking the event loop (the original
        called it synchronously from inside the coroutine).
        """
        import asyncio  # local import keeps this block self-contained

        return await asyncio.to_thread(self.chat, messages, **kwargs)

    @property
    def metadata(self):
        # NOTE(review): llama_index normally expects an ``LLMMetadata``
        # object here rather than a plain dict — verify against the pinned
        # llama_index version.
        return {
            "model_name": self.model_name,
            "context_window": 3000,
            "max_new_tokens": 512,
        }
280
+
281
  # Validate environment variables
282
  required_env_vars = ["HF_TOKEN", "username", "password", "security_token", "domain"]
283
  for var in required_env_vars:
 
325
  templates = Jinja2Templates(directory="static")
326
 
327
  # LlamaIndex configuration
328
+ Settings.llm = HuggingFaceInferenceLLM(
329
  model_name="meta-llama/Meta-Llama-3-8B-Instruct",
330
+ token=os.getenv("HF_TOKEN")
 
 
331
  )
332
 
333
  Settings.embed_model = HuggingFaceEmbedding(
 
439
  'Phone': form_data.get('phone', '').strip(),
440
  'Email': form_data.get('email', ''),
441
  }
 
442
  result = sf.Lead.create(data)
443
  return JSONResponse({"id": result['id']})
444
  except Exception as e: