Avinash109 committed on
Commit 088f906 · verified · 1 Parent(s): 4897d60

Update app.py

Files changed (1): app.py +129 -128
app.py CHANGED
@@ -1,165 +1,166 @@
  import streamlit as st
- from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
  import torch
  import datetime

- # Set Streamlit page configuration
  st.set_page_config(
      page_title="Qwen2.5-Coder Chat",
      page_icon="💬",
-     layout="wide",
  )

- # Title of the app
- st.title("💬 Qwen2.5-Coder Chat Interface")

- # Initialize session state for messages (store conversation history)
- st.session_state.setdefault('messages', [])
-
- # Load the model and tokenizer
  @st.cache_resource
- def load_model():
-     model_name = "Qwen/Qwen2.5-Coder-32B-Instruct"  # Replace with the correct model path

-     # Define BitsAndBytesConfig for 8-bit quantization
-     quantization_config = BitsAndBytesConfig(
-         load_in_8bit=True,  # Enable 8-bit loading
-         llm_int8_enable_fp32_cpu_offload=True  # Optional: Enables offloading to CPU
      )

-     tokenizer = AutoTokenizer.from_pretrained(model_name)
      model = AutoModelForCausalLM.from_pretrained(
          model_name,
-         quantization_config=quantization_config,
          torch_dtype=torch.float16,
-         device_map="auto"
      )
      return tokenizer, model

- # Load tokenizer and model with error handling
- try:
-     with st.spinner("Loading model... This may take a while..."):
-         tokenizer, model = load_model()
- except Exception as e:
-     st.error(f"Error loading model: {e}")
-     st.stop()
-
- # Function to generate model response
- def generate_response(messages, tokenizer, model, max_tokens=150, temperature=0.7, top_p=0.9):
-     """
-     Generates a response from the model based on the conversation history.
-
-     Args:
-         messages (list): List of message dictionaries containing 'role' and 'content'.
-         tokenizer: The tokenizer instance.
-         model: The language model instance.
-         max_tokens (int): Maximum number of tokens for the response.
-         temperature (float): Sampling temperature.
-         top_p (float): Nucleus sampling probability.
-
-     Returns:
-         str: The generated response text.
-     """
-     # Concatenate all previous messages
-     conversation = ""
-     for message in messages:
-         role = "You" if message['role'] == 'user' else "Qwen2.5-Coder"
-         conversation += f"**{role}:** {message['content']}\n"
-
-     # Append the latest user input
-     conversation += f"**You:** {messages[-1]['content']}\n**Qwen2.5-Coder:**"
-
-     # Tokenize the conversation
-     inputs = tokenizer.encode(conversation, return_tensors="pt").to(model.device)
-
-     # Generate a response
-     with torch.no_grad():
-         outputs = model.generate(
-             inputs,
-             max_length=inputs.shape[1] + max_tokens,
-             temperature=temperature,
-             top_p=top_p,
-             do_sample=True,
-             num_return_sequences=1,
-             pad_token_id=tokenizer.eos_token_id
-         )
-
-     # Decode the response
-     response = tokenizer.decode(outputs[0], skip_special_tokens=True)
-
-     # Extract the generated response after the conversation
-     generated_response = response.split("Qwen2.5-Coder:")[-1].strip()
-     return generated_response

- # Layout: Two columns for the main chat and sidebar
- chat_col, sidebar_col = st.columns([4, 1])
-
- with chat_col:
-     st.markdown("### Chat")
-     chat_container = st.container()
-     with chat_container:
-         for message in st.session_state['messages']:
-             time = message.get('timestamp', '')
-             if message['role'] == 'user':
-                 st.markdown(f"**You:** {message['content']} _({time})_")
-             else:
-                 st.markdown(f"**Qwen2.5-Coder:** {message['content']} _({time})_")
-
-     # Input area for user message
-     with st.form(key='chat_form', clear_on_submit=True):
-         user_input = st.text_area("You:", height=100)
-         submit_button = st.form_submit_button(label='Send')
-
-     if submit_button and user_input.strip():
-         timestamp = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
-         # Append the user's message to the chat history
-         st.session_state['messages'].append({'role': 'user', 'content': user_input, 'timestamp': timestamp})
-
-         # Generate and append the model's response
-         try:
-             with st.spinner("Qwen2.5-Coder is typing..."):
-                 response = generate_response(
-                     st.session_state['messages'],
-                     tokenizer,
-                     model,
-                     max_tokens=max_tokens,
-                     temperature=temperature,
-                     top_p=top_p
-                 )
-             timestamp = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
-             st.session_state['messages'].append({'role': 'assistant', 'content': response, 'timestamp': timestamp})
-         except Exception as e:
-             st.error(f"Error generating response: {e}")
-
- with sidebar_col:
-     st.sidebar.header("Settings")
-     max_tokens = st.sidebar.slider(
-         "Maximum Tokens",
-         min_value=50,
          max_value=4096,
          value=512,
-         step=256,
-         help="Set the maximum number of tokens for the model's response."
      )

-     temperature = st.sidebar.slider(
          "Temperature",
          min_value=0.1,
-         max_value=1.0,
          value=0.7,
          step=0.1,
-         help="Controls the randomness of the model's output."
      )

-     top_p = st.sidebar.slider(
-         "Top-p (Nucleus Sampling)",
          min_value=0.1,
          max_value=1.0,
          value=0.9,
          step=0.1,
-         help="Controls the diversity of the model's output."
      )

-     if st.sidebar.button("Clear Chat"):
-         st.session_state['messages'] = []
-         st.success("Chat history cleared.")
  import streamlit as st
  import torch
+ from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
  import datetime

+ # Page configuration
  st.set_page_config(
      page_title="Qwen2.5-Coder Chat",
      page_icon="💬",
+     layout="wide"
  )

+ # Initialize session state for conversation history
+ if 'messages' not in st.session_state:
+     st.session_state.messages = []

+ # Cache the model loading
  @st.cache_resource
+ def load_model_and_tokenizer():
+     model_name = "Qwen/Qwen2.5-Coder-32B-Instruct"

+     # Configure quantization
+     bnb_config = BitsAndBytesConfig(
+         load_in_8bit=True,
+         bnb_4bit_quant_type="nf4",
+         bnb_4bit_compute_dtype=torch.float16,
+         bnb_4bit_use_double_quant=False,
      )

+     # Load tokenizer and model
+     tokenizer = AutoTokenizer.from_pretrained(
+         model_name,
+         trust_remote_code=True
+     )
      model = AutoModelForCausalLM.from_pretrained(
          model_name,
+         quantization_config=bnb_config,
          torch_dtype=torch.float16,
+         device_map="auto",
+         trust_remote_code=True
      )
+
      return tokenizer, model

+ # Main title
+ st.title("💬 Qwen2.5-Coder Chat")

+ # Sidebar settings
+ with st.sidebar:
+     st.header("Settings")
+
+     max_length = st.slider(
+         "Maximum Length",
+         min_value=64,
          max_value=4096,
          value=512,
+         step=64,
+         help="Maximum number of tokens to generate"
      )

+     temperature = st.slider(
          "Temperature",
          min_value=0.1,
+         max_value=2.0,
          value=0.7,
          step=0.1,
+         help="Higher values make output more random, lower values more deterministic"
      )

+     top_p = st.slider(
+         "Top P",
          min_value=0.1,
          max_value=1.0,
          value=0.9,
          step=0.1,
+         help="Nucleus sampling: higher values consider more tokens, lower values are more focused"
      )
+
+     if st.button("Clear Conversation"):
+         st.session_state.messages = []
+         st.rerun()

+ # Load model with error handling
+ try:
+     with st.spinner("Loading model... Please wait..."):
+         tokenizer, model = load_model_and_tokenizer()
+ except Exception as e:
+     st.error(f"Error loading model: {str(e)}")
+     st.stop()
+
+ def generate_response(prompt, max_new_tokens=512, temperature=0.7, top_p=0.9):
+     """Generate response from the model"""
+     try:
+         # Tokenize input
+         inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
+
+         # Generate response
+         with torch.no_grad():
+             outputs = model.generate(
+                 **inputs,
+                 max_new_tokens=max_new_tokens,
+                 temperature=temperature,
+                 top_p=top_p,
+                 do_sample=True,
+                 pad_token_id=tokenizer.pad_token_id,
+                 eos_token_id=tokenizer.eos_token_id,
+             )
+
+         # Decode and return response
+         response = tokenizer.decode(outputs[0], skip_special_tokens=True)
+         # Extract only the model's response (after the prompt)
+         response = response[len(prompt):].strip()
+         return response
+
+     except Exception as e:
+         st.error(f"Error generating response: {str(e)}")
+         return None
+
+ # Display chat history
+ for message in st.session_state.messages:
+     with st.chat_message(message["role"]):
+         st.write(f"{message['content']}\n\n_{message['timestamp']}_")
+
+ # Chat input
+ if prompt := st.chat_input("Ask me anything about coding..."):
+     # Add user message to chat
+     timestamp = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
+     st.session_state.messages.append({
+         "role": "user",
+         "content": prompt,
+         "timestamp": timestamp
+     })
+
+     # Display user message
+     with st.chat_message("user"):
+         st.write(f"{prompt}\n\n_{timestamp}_")
+
+     # Generate and display response
+     with st.chat_message("assistant"):
+         with st.spinner("Thinking..."):
+             # Prepare conversation history
+             conversation = ""
+             for msg in st.session_state.messages:
+                 if msg["role"] == "user":
+                     conversation += f"Human: {msg['content']}\n"
+                 else:
+                     conversation += f"Assistant: {msg['content']}\n"
+             conversation += "Assistant:"
+
+             response = generate_response(
+                 conversation,
+                 max_new_tokens=max_length,
+                 temperature=temperature,
+                 top_p=top_p
+             )
+
+             if response:
+                 timestamp = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
+                 st.write(f"{response}\n\n_{timestamp}_")
+
+                 # Add assistant response to chat history
+                 st.session_state.messages.append({
+                     "role": "assistant",
+                     "content": response,
+                     "timestamp": timestamp
+                 })
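
Note: the updated app builds its prompt by concatenating "Human:"/"Assistant:" turns by hand. Qwen2.5-Coder-Instruct tokenizers also ship a chat template, so the same stored history could be formatted with tokenizer.apply_chat_template instead. A minimal sketch, assuming the tokenizer, generate_response, and sidebar values defined in app.py above (this is not part of the commit):

# Sketch only, not part of this commit: format the stored history with the
# tokenizer's built-in chat template instead of manual string concatenation.
chat = [{"role": m["role"], "content": m["content"]}
        for m in st.session_state.messages]
prompt = tokenizer.apply_chat_template(
    chat,
    tokenize=False,              # return the prompt as a string
    add_generation_prompt=True,  # append the assistant-turn marker
)
response = generate_response(
    prompt,
    max_new_tokens=max_length,
    temperature=temperature,
    top_p=top_p,
)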