File size: 1,803 Bytes
a96c72f c585309 75eb9ca a96c72f 6d7b830 22986a6 75eb9ca a96c72f 6d7b830 a96c72f 6d7b830 a96c72f 9000ced a96c72f 75eb9ca 56843e1 9000ced 75eb9ca 9000ced 0cf9b7f 9000ced e22ba0b 9000ced 0cf9b7f 9417eab 0cf9b7f 14e602c 0cf9b7f 6d7b830 75eb9ca 56843e1 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 |
import streamlit as st
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
import torch
@st.cache_resource
def load_model(model_name: str = "google/mt5-base"):
    """Load the tokenizer and seq2seq model used by the chat bot.

    ``st.cache_resource`` caches the result per distinct ``model_name`` and
    shares it across reruns and sessions, so weights are loaded only once.

    Args:
        model_name: Hugging Face Hub id or local path of a sequence-to-sequence
            checkpoint. Defaults to ``"google/mt5-base"``.

    Returns:
        A ``(tokenizer, model)`` tuple.
    """
    # NOTE(review): per the Hugging Face mT5 docs, google/mt5-base is only
    # pre-trained (no supervised fine-tuning), so its replies are unlikely to be
    # meaningful. Pass a dialogue fine-tuned seq2seq checkpoint here instead.
    tokenizer = AutoTokenizer.from_pretrained(model_name, padding_side="left", use_fast=False)
    model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
    return tokenizer, model
st.title("Український Чат-бот")

# Seed per-session state only on the first run; later reruns keep whatever
# the session already holds. "history" holds alternating user/bot messages,
# "user_input" is the pending message consumed by send_message().
for _state_key, _initial_value in (("history", []), ("user_input", "")):
    if _state_key not in st.session_state:
        st.session_state[_state_key] = _initial_value

# load_model() is wrapped in st.cache_resource, so reruns reuse the loaded pair.
tokenizer, model = load_model()
def send_message(max_history_turns: int = 6) -> None:
    """Generate a bot reply to the pending user message and record the exchange.

    Reads ``st.session_state.user_input``; when it is empty nothing happens.
    Otherwise the most recent history entries and the new message are joined
    into ONE prompt string, run through ``model.generate``, and the user
    message plus decoded reply are appended to ``st.session_state.history``.
    The pending input is cleared afterwards.

    Args:
        max_history_turns: number of earlier history entries (user and bot
            messages both count) to include as context. Use an even number to
            start the context on a user message. ``0`` sends the new message
            alone. Bounding the context keeps the tokenizer's right-side
            truncation from cutting off the newest message.
    """
    user_text = st.session_state.user_input
    if not user_text:
        return

    # history[-0:] would return the WHOLE list, so handle 0 explicitly.
    history = st.session_state.history
    context = history[-max_history_turns:] if max_history_turns > 0 else []

    # Pass a single string: a list would be tokenized as a batch of independent
    # sequences, generate() would return one output per element, and
    # outputs[0] would be the reply to the first element rather than to
    # user_text.
    prompt = "\n".join([*context, user_text])
    inputs = tokenizer(prompt, return_tensors="pt", truncation=True)
    with torch.no_grad():
        outputs = model.generate(**inputs, max_length=100)
    response = tokenizer.decode(outputs[0], skip_special_tokens=True)

    history.extend([user_text, response])
    st.session_state.user_input = ""
def update_user_input():
    """Copy the text box value into the pending-message slot.

    The text box is the widget with key ``temp_user_input``. Its value is
    stored in ``user_input``, which is the value ``send_message`` reads.
    """
    state = st.session_state
    state["user_input"] = state["temp_user_input"]
def _submit_input():
    """Send whatever is in the text box, then clear the box.

    Used as the callback of both the text input (Enter / focus loss) and the
    send button. Streamlit runs callbacks before the script body reruns, which
    is the only point where a widget's own key may be reassigned, so clearing
    ``temp_user_input`` here is allowed. If both callbacks fire in the same
    rerun, the second finds an empty box. send_message() ignores empty input,
    so a message is never sent twice. Repeating the previous message is also
    possible, because the box was emptied after the first send.
    """
    update_user_input()
    send_message()
    st.session_state.temp_user_input = ""


st.text_input("Ви:", key="temp_user_input", on_change=_submit_input)
st.button("Надіслати", on_click=_submit_input)

# Render the conversation. History alternates user message, bot reply; the
# last user message may not have a reply yet.
_history = st.session_state.history
for _turn, _user_msg in enumerate(_history[0::2]):
    st.write(f"Ви: {_user_msg}")
    _reply_idx = 2 * _turn + 1
    if _reply_idx < len(_history):
        st.write(f"Бот: {_history[_reply_idx]}")