# Prediction interface for Cog ⚙️
# https://github.com/replicate/cog/blob/main/docs/python.md
from cog import BasePredictor, Input
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline


class Predictor(BasePredictor):
    def setup(self) -> None:
        """Load the model into memory to make running multiple predictions efficient"""
        model_name = "defog/sqlcoder-34b-alpha"
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = AutoModelForCausalLM.from_pretrained(
            model_name,
            torch_dtype=torch.float16,  # half precision to fit the 34B model
            device_map="auto",  # shard across available GPUs, spilling to CPU if needed
            use_cache=True,
            offload_folder="./.cache",  # where offloaded weights are written
        )

    def predict(
        self,
        prompt: str = Input(description="Prompt to generate from"),
    ) -> str:
        """Run a single prediction on the model"""
        # Stop generation at the end-of-sequence token. To stop at the triple
        # backticks that close the ```sql block instead, use:
        # eos_token_id = self.tokenizer.convert_tokens_to_ids(["```"])[0]
        eos_token_id = self.tokenizer.eos_token_id
        # Build the generation pipeline (recreated on every call; it could
        # also be constructed once in setup()).
        pipe = pipeline(
            "text-generation",
            model=self.model,
            tokenizer=self.tokenizer,
            max_length=300,  # budget for prompt plus generated tokens combined
            do_sample=False,
            num_beams=5,  # do beam search with 5 beams for high quality results
        )
        generated_query = (
            pipe(
                prompt,
                num_return_sequences=1,
                eos_token_id=eos_token_id,
                pad_token_id=eos_token_id,
            )[0]["generated_text"]
            .split("```sql")[-1]  # keep the text after the opening ```sql fence
            .split("```")[0]  # drop anything after the closing fence
            .split(";")[0]  # keep only the first SQL statement
            .strip()
            + ";"  # re-append the terminating semicolon
        )
        return generated_query
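

# A minimal local smoke test, as a sketch: Cog normally drives this class via
# `cog predict` rather than running this file directly, and the example prompt
# below only approximates sqlcoder's task/schema prompt format (both are
# assumptions, not something this repo defines).
if __name__ == "__main__":
    predictor = Predictor()
    predictor.setup()
    example_prompt = (
        "### Task\n"
        "Generate a SQL query to answer the question: "
        "How many users signed up last week?\n"
        "### Database Schema\n"
        "CREATE TABLE users (id INT, signup_date DATE);\n"
        "### SQL\n"
        "```sql\n"
    )
    print(predictor.predict(prompt=example_prompt))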