Update app.py

app.py CHANGED
@@ -14,11 +14,11 @@ login(token=HF_TOKEN)
 # Define models
 MODELS = {
     "athena-1": {
-        "name": "
+        "name": "⚡ Atlas-Flash 1205",
         "sizes": {
             "1.5B": "Spestly/Atlas-R1-1.5B-Preview",
         },
-        "emoji": "
+        "emoji": "⚡",
         "experimental": True,
         "is_vision": False,  # Enable vision support for this model
     },
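The hunk above restores the display strings in the model registry: each entry maps a model key to a display name, an emoji, feature flags, and a "sizes" table of Hugging Face Hub repo ids. A minimal sketch of how such an entry can be resolved to a checkpoint id follows; the resolve_repo helper is hypothetical and not part of this commit.

# Hypothetical helper (not in app.py): resolve a Hub repo id from the
# MODELS registry defined in the hunk above.
def resolve_repo(models: dict, model_key: str, size: str) -> str:
    # Raises KeyError if the model key or size is not registered.
    return models[model_key]["sizes"][size]

# Example: resolve_repo(MODELS, "athena-1", "1.5B")
# -> "Spestly/Atlas-R1-1.5B-Preview"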
@@ -103,12 +103,9 @@ class AtlasInferenceApp:
                 padding=True
             )

-            # Generate response
-            response_container = st.empty()  # Placeholder for streaming text
-            full_response = ""
-            generated_tokens = []  # Track generated tokens to avoid duplicates
+            # Generate response without streaming
             with torch.no_grad():
-
+                output = st.session_state.current_model["model"].generate(
                     input_ids=inputs.input_ids,
                     attention_mask=inputs.attention_mask,
                     max_new_tokens=max_tokens,
@@ -118,30 +115,14 @@
                     do_sample=True,
                     pad_token_id=st.session_state.current_model["tokenizer"].pad_token_id,
                     eos_token_id=st.session_state.current_model["tokenizer"].eos_token_id,
-                    )
-
-
-
-
-
-
-
-                chunk_text = st.session_state.current_model["tokenizer"].decode(generated_tokens, skip_special_tokens=True)
-
-                # Remove the prompt from the response
-                if prompt in chunk_text:
-                    chunk_text = chunk_text.replace(prompt, "").strip()
-
-                # Update the response
-                full_response = chunk_text
-                response_container.markdown(full_response)
-
-                # Stop if the response is too long or incomplete
-                if len(full_response) >= max_tokens * 4:  # Approximate token-to-character ratio
-                    st.warning("⚠️ Response truncated due to length limit.")
-                    break
-
-            return full_response.strip()  # Return the cleaned response
+                )
+            response = st.session_state.current_model["tokenizer"].decode(output[0], skip_special_tokens=True)
+
+            # Remove the prompt from the response
+            if prompt in response:
+                response = response.replace(prompt, "").strip()
+
+            return response
         except Exception as e:
             return f"⚠️ Generation Error: {str(e)}"
         finally:
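Pulled out of the Streamlit session state, the new generation path reduces to roughly the following sketch. It assumes model and tokenizer are the objects stored in st.session_state.current_model and that the prompt is tokenized with the padding=True call shown in the context lines above; the generate_reply name and its signature are illustrative, not part of the commit.

import torch

def generate_reply(model, tokenizer, prompt: str, max_tokens: int) -> str:
    # Tokenize the prompt (mirrors the tokenizer call preceding the hunk).
    inputs = tokenizer(prompt, return_tensors="pt", padding=True)

    # One blocking generate() call replaces the old streaming loop.
    with torch.no_grad():
        output = model.generate(
            input_ids=inputs.input_ids,
            attention_mask=inputs.attention_mask,
            max_new_tokens=max_tokens,
            do_sample=True,
            pad_token_id=tokenizer.pad_token_id,
            eos_token_id=tokenizer.eos_token_id,
        )

    # For decoder-only models, output[0] contains the prompt tokens followed
    # by the newly generated ones, which is why the prompt text is stripped.
    response = tokenizer.decode(output[0], skip_special_tokens=True)
    if prompt in response:
        response = response.replace(prompt, "").strip()
    return response

As a design note, the replace-based cleanup only works when the decoded text reproduces the prompt verbatim; decoding only the newly generated tokens, e.g. tokenizer.decode(output[0][inputs.input_ids.shape[-1]:], skip_special_tokens=True), is a common alternative that avoids that dependence.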
|