mynuddin committed · verified
Commit a0ad636 · 1 Parent(s): a3d4021

Update app.py

Files changed (1): app.py (+19 −9)
app.py CHANGED
@@ -1,25 +1,35 @@
 import os
+from fastapi import FastAPI
+from transformers import AutoModelForCausalLM, AutoTokenizer
+from peft import PeftModel
+import torch
 
 # Set writable cache directory inside the container
 os.environ['SENTENCE_TRANSFORMERS_HOME'] = '/app/hf_home'
 os.environ['TRANSFORMERS_CACHE'] = '/app/hf_home'
 
-from fastapi import FastAPI
-from transformers import AutoModelForCausalLM, AutoTokenizer
-
 # Ensure the directory exists
 os.makedirs(os.environ['TRANSFORMERS_CACHE'], exist_ok=True)
 
-# Load model
-model_name = "mynuddin/chatbot"
-tokenizer = AutoTokenizer.from_pretrained(model_name)
-model = AutoModelForCausalLM.from_pretrained(model_name).to("cpu")
+# Define base model and adapter model
+base_model_name = "facebook/opt-2.7b"
+adapter_name = "mynuddin/chatbot"
+
+# Load base model and tokenizer
+tokenizer = AutoTokenizer.from_pretrained(base_model_name)
+base_model = AutoModelForCausalLM.from_pretrained(base_model_name, torch_dtype=torch.float16)
+
+# Load PEFT adapter
+model = PeftModel.from_pretrained(base_model, adapter_name)
+model = model.to("cpu")  # Change to "cuda" if running on GPU
+model.eval()
 
 app = FastAPI()
 
 @app.post("/generate")
 def generate_text(prompt: str):
     inputs = tokenizer(prompt, return_tensors="pt")
-    output = model.generate(**inputs, max_length=128)
+    with torch.no_grad():
+        output = model.generate(**inputs, max_length=128)
     generated_text = tokenizer.decode(output[0], skip_special_tokens=True)
-    return {"generated_query": generated_text}
+    return {"generated_query": generated_text}