import streamlit as st
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
import os
from threading import Thread
import requests

# Local model cache path (avoids re-downloading the weights on every app restart)
MODEL_PATH = "/mnt/data/Phi-4-Hindi"
TOKEN = os.environ.get("HF_TOKEN")  # Hugging Face access token, read from the environment
MODEL_NAME = "DrishtiSharma/Phi-4-Hindi-quantized"

# Load Model & Tokenizer Once
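# st.cache_resource caches the returned objects across Streamlit reruns,
# so the expensive load below runs only once per server process.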
@st.cache_resource()
def load_model():
    with st.spinner("Loading model... Please wait ⏳"):
        try:
            if not os.path.exists(MODEL_PATH):
                model = AutoModelForCausalLM.from_pretrained(
                    MODEL_NAME, token=TOKEN, trust_remote_code=True, torch_dtype=torch.bfloat16
                )
                tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, token=TOKEN)
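                # Persist the weights and tokenizer locally so later runs load from disk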
                model.save_pretrained(MODEL_PATH)
                tokenizer.save_pretrained(MODEL_PATH)
            else:
                model = AutoModelForCausalLM.from_pretrained(MODEL_PATH)
                tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH)
        except requests.exceptions.ConnectionError:
            st.error("⚠️ Connection error! Unable to download the model. Please check your internet connection and try again.")
            return None, None
        except requests.exceptions.ReadTimeout:
            st.error("⚠️ Read Timeout! The request took too long. Please try again later.")
            return None, None
    
    return model, tokenizer
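
# NOTE (assumption, not part of the original app): on memory-constrained GPUs the
# model could instead be loaded in 4-bit via bitsandbytes, e.g.:
#   from transformers import BitsAndBytesConfig
#   model = AutoModelForCausalLM.from_pretrained(
#       MODEL_NAME, quantization_config=BitsAndBytesConfig(load_in_4bit=True)
#   )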

# Load and move model to appropriate device
model, tok = load_model()
if model is None or tok is None:
    st.stop()

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
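# Moving a large model onto the GPU can itself run out of memory, hence the CPU fallback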
try:
    model = model.to(device)
except torch.cuda.OutOfMemoryError:
    st.error("⚠️ CUDA Out of Memory! Running on CPU instead.")
    device = torch.device("cpu")
    model = model.to(device)

terminators = [tok.eos_token_id]  # stop generation at the end-of-sequence token

# Initialize session state if not set
if "chat_history" not in st.session_state:
    st.session_state.chat_history = []
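# chat_history stores {"role": ..., "content": ...} dicts, the exact format
# tokenizer.apply_chat_template expects below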

# Chat function
def chat(message, temperature, do_sample, max_tokens):
    """Processes chat input and generates a response using the model."""
    
    # Append new message to history
    st.session_state.chat_history.append({"role": "user", "content": message})

    # Convert chat history into model-friendly format
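    # apply_chat_template renders the history with the model's own chat template;
    # add_generation_prompt=True appends the assistant turn marker so the model replies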
    messages = tok.apply_chat_template(st.session_state.chat_history, tokenize=False, add_generation_prompt=True)
    model_inputs = tok([messages], return_tensors="pt").to(device)
    
    # Initialize streamer for token-wise response
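    # TextIteratorStreamer yields decoded text as generate() produces it; the
    # timeout bounds how long the consumer waits for the next chunk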
    streamer = TextIteratorStreamer(tok, timeout=20.0, skip_prompt=True, skip_special_tokens=True)
    
    # Define generation parameters
    generate_kwargs = {
        "inputs": model_inputs["input_ids"],
        "attention_mask": model_inputs["attention_mask"],  # avoids the missing-attention-mask warning
        "streamer": streamer,
        "max_new_tokens": max_tokens,
        "do_sample": do_sample,
        "temperature": temperature,
        "eos_token_id": terminators,
    }
    
    # Sampling with temperature 0 is ill-defined, so fall back to greedy decoding
    if temperature == 0:
        generate_kwargs["do_sample"] = False

    # Generate response asynchronously
    t = Thread(target=model.generate, kwargs=generate_kwargs)
    t.start()
    
    # Collect response as it streams
    response_text = ""
    for new_text in streamer:
        response_text += new_text
        yield response_text

    # Save the assistant's response to session history
    st.session_state.chat_history.append({"role": "assistant", "content": response_text})

# UI Setup
st.title("💬 Chat With Phi-4-Hindi")
st.success("✅ Model is READY to chat!")
st.markdown("Chat with [large-traversaal/Phi-4-Hindi](https://huggingface.co./large-traversaal/Phi-4-Hindi)")

# Sidebar Chat Settings
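# st.sidebar.slider arguments: label, min, max, default value, step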
temperature = st.sidebar.slider("Temperature", 0.0, 1.0, 0.3, 0.1)
do_sample = st.sidebar.checkbox("Use Sampling", value=True)
max_tokens = st.sidebar.slider("Max Tokens", 128, 4096, 512, 1)
text_color = st.sidebar.selectbox("Text Color", ["Red", "Black", "Blue", "Green", "Purple"], index=0)
dark_mode = st.sidebar.checkbox("🌙 Dark Mode", value=False)  # note: not currently wired to any styling

# Function to format chat messages
def get_html_text(text, color):
    return f'<p style="color: {color.lower()}; font-size: 16px;">{text}</p>'
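# e.g. get_html_text("Hello", "Red") -> '<p style="color: red; font-size: 16px;">Hello</p>'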

# Display chat history
for msg in st.session_state.chat_history:
    role = "👤" if msg["role"] == "user" else "🤖"
    # Use <b> rather than markdown ** because markdown is not parsed inside raw HTML
    st.markdown(get_html_text(f"<b>{role}:</b> {msg['content']}", text_color if role == "🤖" else "black"), unsafe_allow_html=True)

# User Input Handling
user_input = st.text_input("Type your message:", "")

if st.button("Send"):
    if user_input.strip():
        
        # Stream the chatbot response into a placeholder as tokens arrive.
        # Note: chat() itself appends the user message, so we must not append it
        # here as well, or it would appear twice in the history.
        with st.spinner("Generating response... 🤖💭"):
            response_generator = chat(user_input, temperature, do_sample, max_tokens)
            placeholder = st.empty()
            for output in response_generator:
                placeholder.markdown(get_html_text(f"<b>🤖:</b> {output}", text_color), unsafe_allow_html=True)

        st.success("✅ Response generated!")
        # chat() has already saved the response to session state; rerun to redraw history
        st.rerun()

if st.button("🧹 Clear Chat"):
    with st.spinner("Clearing chat history..."):
        st.session_state.chat_history = []
    st.success("✅ Chat history cleared!")
    st.rerun()