DrishtiSharma committed
Commit
c1786ee
verified
1 Parent(s): b2219cf

Create app.py

Files changed (1)
app.py +101 -0
app.py ADDED
@@ -0,0 +1,101 @@
+ import streamlit as st
+ import torch
+ from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
+ import os
+ from threading import Thread
+
+ # Load model and tokenizer
+ token = os.environ.get("HF_TOKEN")
+ model_name = "large-traversaal/Phi-4-Hindi"
+
+ @st.cache_resource()
+ def load_model():
+     model = AutoModelForCausalLM.from_pretrained(
+         model_name,
+         token=token,
+         trust_remote_code=True,
+         torch_dtype=torch.bfloat16,
+     )
+     tok = AutoTokenizer.from_pretrained(model_name, token=token)
+     return model, tok
+
+ model, tok = load_model()
+ terminators = [tok.eos_token_id]
+
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+ model = model.to(device)
+
+ # Initialize session state if not set
+ if "chat_history" not in st.session_state:
+     st.session_state.chat_history = []
+
+ # Chat function: streams the reply token by token. The caller has already
+ # appended the user message to the session history, so the copied log is
+ # used as-is rather than appending the message a second time.
+ def chat(message, temperature, do_sample, max_tokens):
+     chat_log = st.session_state.chat_history.copy()
+     messages = tok.apply_chat_template(chat_log, tokenize=False, add_generation_prompt=True)
+
+     model_inputs = tok([messages], return_tensors="pt").to(device)
+     streamer = TextIteratorStreamer(tok, timeout=20.0, skip_prompt=True, skip_special_tokens=True)
+
+     generate_kwargs = {
+         "inputs": model_inputs["input_ids"],
+         "attention_mask": model_inputs["attention_mask"],
+         "streamer": streamer,
+         "max_new_tokens": max_tokens,
+         "do_sample": do_sample,
+         "temperature": temperature,
+         "eos_token_id": terminators,
+     }
+
+     # Fall back to greedy decoding when temperature is zero.
+     if temperature == 0:
+         generate_kwargs["do_sample"] = False
+
+     # generate() blocks until done, so run it on a worker thread and
+     # consume the streamer here.
+     t = Thread(target=model.generate, kwargs=generate_kwargs)
+     t.start()
+
+     partial_text = ""
+     for new_text in streamer:
+         partial_text += new_text
+         yield partial_text
+
+     st.session_state.chat_history.append({"role": "assistant", "content": partial_text})
+
+ # Streamlit UI
+ st.title("💬 Chat With Phi-4-Hindi")
+ st.markdown("Chat with [large-traversaal/Phi-4-Hindi](https://huggingface.co/large-traversaal/Phi-4-Hindi)")
+
+ # Sidebar controls
+ temperature = st.sidebar.slider("Temperature", 0.0, 1.0, 0.3, 0.1)
+ do_sample = st.sidebar.checkbox("Use Sampling", value=True)
+ max_tokens = st.sidebar.slider("Max Tokens", 128, 4096, 512, 1)
+ text_color = st.sidebar.selectbox("Text Color", ["Red", "Black", "Blue", "Green", "Purple"], index=0)
+ dark_mode = st.sidebar.checkbox("🌙 Dark Mode", value=False)  # not yet wired to the theme
+
+ def get_html_text(text, color):
+     # Rendered with unsafe_allow_html, so raw HTML in messages passes through.
+     return f'<p style="color: {color.lower()}; font-size: 16px;">{text}</p>'
+
+ for msg in st.session_state.chat_history:
+     if msg["role"] == "user":
+         st.markdown(get_html_text("👀 " + msg["content"], "black"), unsafe_allow_html=True)
+     else:
+         st.markdown(get_html_text("🤖 " + msg["content"], text_color), unsafe_allow_html=True)
+
+ user_input = st.text_input("Type your message:", "")
+ if st.button("Send"):
+     if user_input.strip():
+         st.session_state.chat_history.append({"role": "user", "content": user_input})
+         with st.spinner("Generating response..."):
+             # Drain the generator; the final reply is saved to chat_history.
+             for _ in chat(user_input, temperature, do_sample, max_tokens):
+                 pass
+         st.rerun()  # replaces the deprecated st.experimental_rerun()
+
+ if st.button("🧹 Clear Chat"):
+     st.session_state.chat_history = []
+     st.rerun()  # replaces the deprecated st.experimental_rerun()
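
For reference (not part of the commit): a minimal standalone sketch of the same TextIteratorStreamer-plus-worker-thread streaming pattern that app.py relies on, runnable without Streamlit. It assumes HF_TOKEN is exported, enough memory to load the model in bfloat16, and an illustrative prompt; everything here is an assumption layered on the commit, not code from it.

import os
from threading import Thread

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

model_name = "large-traversaal/Phi-4-Hindi"
token = os.environ.get("HF_TOKEN")  # assumes a valid token is exported

tok = AutoTokenizer.from_pretrained(model_name, token=token)
model = AutoModelForCausalLM.from_pretrained(
    model_name, token=token, trust_remote_code=True, torch_dtype=torch.bfloat16
)
device = "cuda" if torch.cuda.is_available() else "cpu"
model = model.to(device)

# Build the prompt the same way app.py does, from a one-turn chat log.
prompt = tok.apply_chat_template(
    [{"role": "user", "content": "Introduce yourself in one sentence."}],  # illustrative prompt
    tokenize=False,
    add_generation_prompt=True,
)
inputs = tok([prompt], return_tensors="pt").to(device)
streamer = TextIteratorStreamer(tok, skip_prompt=True, skip_special_tokens=True)

# generate() blocks until finished, so it runs on a worker thread while
# the main thread prints tokens as the streamer yields them.
Thread(target=model.generate, kwargs={
    "inputs": inputs["input_ids"],
    "attention_mask": inputs["attention_mask"],
    "streamer": streamer,
    "max_new_tokens": 128,
    "do_sample": False,
}).start()

for chunk in streamer:
    print(chunk, end="", flush=True)
print()

Draining the streamer on the main thread mirrors the app's chat() generator; the only Streamlit-specific parts of the original are the session-state bookkeeping and the rerun calls.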