mynuddin committed (verified)
Commit f31fa73
1 Parent(s): 191f973

Update app.py

Files changed (1)
  app.py +18 -5
app.py CHANGED

@@ -3,6 +3,7 @@ from fastapi import FastAPI
 from transformers import AutoModelForCausalLM, AutoTokenizer
 from peft import PeftModel
 import torch
+from pydantic import BaseModel
 
 # Set writable cache directory inside the container
 os.environ['SENTENCE_TRANSFORMERS_HOME'] = '/app/hf_home'
@@ -13,7 +14,7 @@ os.makedirs(os.environ['TRANSFORMERS_CACHE'], exist_ok=True)
 
 # Define base model and adapter model
 base_model_name = "facebook/opt-2.7b"
-adapter_name = "mynuddin/chatbot"
+adapter_name = "mynuddin/chatbot"  # Adapter model path or name
 
 # Load base model and tokenizer
 tokenizer = AutoTokenizer.from_pretrained(base_model_name)
@@ -21,15 +22,27 @@ base_model = AutoModelForCausalLM.from_pretrained(base_model_name, torch_dtype=t
 
 # Load PEFT adapter
 model = PeftModel.from_pretrained(base_model, adapter_name)
-model = model.to("cpu")  # Change to "cuda" if running on GPU
+model = model.to("cuda" if torch.cuda.is_available() else "cpu")  # Use GPU if available
 model.eval()
 
 app = FastAPI()
 
+# Define Pydantic model for input
+class PromptInput(BaseModel):
+    prompt: str
+
 @app.post("/generate")
-def generate_text(prompt: str):
-    inputs = tokenizer(prompt, return_tensors="pt")
+def generate_text(input: PromptInput):
+    prompt = input.prompt  # Access prompt from the request body
+
+    # Format the prompt with the specific style used for the fine-tuned model
+    inputs = tokenizer(f"### Prompt: {prompt}\n### Completion:", return_tensors="pt").to("cuda" if torch.cuda.is_available() else "cpu")
+
+    # Generate the output
     with torch.no_grad():
         output = model.generate(**inputs, max_length=128)
+
+    # Decode the output and remove special tokens
     generated_text = tokenizer.decode(output[0], skip_special_tokens=True)
+
+    return {"generated_query": generated_text}