from transformers import AutoTokenizer, AutoModelForCausalLM
# Load the model and tokenizer
model_name = "Qwen/Qwen1.5-0.5B-Chat"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)
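
# Note (optional, not in the original): from_pretrained as written loads the
# weights in full precision on the CPU. On a GPU box you could additionally
# pass torch_dtype="auto" and device_map="auto" (the latter requires the
# accelerate package to be installed).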
def generate_text(text):
# Tokenize the input text, including attention mas
#input_ids = tokenizer(text, return_tensors="pt", padding=True)
messages = []
use_system_prompt = True
DEFAULT_SYSTEM_PROMPT = "you are helpfull assistant."
if use_system_prompt:
messages = [
{"role": "system", "content": DEFAULT_SYSTEM_PROMPT}
]
user_messages = [
{"role": "user", "content": text}
]
messages += user_messages
    # Render the conversation with the model's chat template
    prompt = tokenizer.apply_chat_template(
        conversation=messages,
        add_generation_prompt=True,
        tokenize=False
    )
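    # For Qwen1.5 this renders a ChatML-style string, roughly:
    #   <|im_start|>system\nYou are a helpful assistant.<|im_end|>\n
    #   <|im_start|>user\n{text}<|im_end|>\n
    #   <|im_start|>assistant\n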
    input_datas = tokenizer(
        prompt,
        add_special_tokens=True,
        return_tensors="pt"
    )
    # Generate text, passing the attention mask explicitly; note that
    # max_length counts the prompt tokens plus the new tokens, whereas
    # max_new_tokens would bound only the generated continuation
    generated_ids = model.generate(
        input_ids=input_datas.input_ids,
        attention_mask=input_datas.attention_mask,
        max_length=10000
    )

    # Decode only the newly generated tokens, skipping the prompt
    generated_text = tokenizer.decode(
        generated_ids[0][input_datas.input_ids.size(1):],
        skip_special_tokens=True
    )
    return generated_text
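
# Illustrative usage (output will vary; assumes the model weights have
# downloaded successfully):
#   generate_text("What is the capital of France?")
# returns a short answer such as "The capital of France is Paris."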

from flask import Flask, request

app = Flask(__name__)
@app.route('/')
def predict():
    param_value = request.args.get('param', '')
    # Run the model inference on the query parameter
    text = generate_text(param_value)
    return f"{text}"
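
# Example request once the server is running (port 7860 as configured below):
#   curl "http://localhost:7860/?param=Hello"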

if __name__ == '__main__':
    app.run(host='0.0.0.0', port=7860)