Spestly committed (verified)
Commit 15b94f7 · 1 Parent(s): e7e2534

Update app.py

Files changed (1): app.py (+12 -31)
app.py CHANGED
@@ -14,11 +14,11 @@ login(token=HF_TOKEN)
 # Define models
 MODELS = {
     "athena-1": {
-        "name": "🦁 Atlas-Flash",
+        "name": " Atlas-Flash 1205",
         "sizes": {
             "1.5B": "Spestly/Atlas-R1-1.5B-Preview",
         },
-        "emoji": "🦁",
+        "emoji": "",
         "experimental": True,
         "is_vision": False,  # Enable vision support for this model
     },
@@ -103,12 +103,9 @@ class AtlasInferenceApp:
                 padding=True
             )
 
-            # Generate response with streaming
-            response_container = st.empty()  # Placeholder for streaming text
-            full_response = ""
-            generated_tokens = []  # Track generated tokens to avoid duplicates
+            # Generate response without streaming
             with torch.no_grad():
-                for chunk in st.session_state.current_model["model"].generate(
+                output = st.session_state.current_model["model"].generate(
                     input_ids=inputs.input_ids,
                     attention_mask=inputs.attention_mask,
                     max_new_tokens=max_tokens,
@@ -118,30 +115,14 @@
                     do_sample=True,
                     pad_token_id=st.session_state.current_model["tokenizer"].pad_token_id,
                     eos_token_id=st.session_state.current_model["tokenizer"].eos_token_id,
-                ):
-                    # Ensure chunk is 2D (batch size × sequence length)
-                    if chunk.dim() == 1:
-                        chunk = chunk.unsqueeze(0)  # Add batch dimension
-
-                    # Decode only the new tokens
-                    new_tokens = chunk[:, inputs.input_ids.shape[1]:]  # Exclude input tokens
-                    generated_tokens.extend(new_tokens[0].tolist())  # Add new tokens to the list
-                    chunk_text = st.session_state.current_model["tokenizer"].decode(generated_tokens, skip_special_tokens=True)
-
-                    # Remove the prompt from the response
-                    if prompt in chunk_text:
-                        chunk_text = chunk_text.replace(prompt, "").strip()
-
-                    # Update the response
-                    full_response = chunk_text
-                    response_container.markdown(full_response)
-
-                    # Stop if the response is too long or incomplete
-                    if len(full_response) >= max_tokens * 4:  # Approximate token-to-character ratio
-                        st.warning("⚠️ Response truncated due to length limit.")
-                        break
-
-            return full_response.strip()  # Return the cleaned response
+                )
+            response = st.session_state.current_model["tokenizer"].decode(output[0], skip_special_tokens=True)
+
+            # Remove the prompt from the response
+            if prompt in response:
+                response = response.replace(prompt, "").strip()
+
+            return response
         except Exception as e:
             return f"⚠️ Generation Error: {str(e)}"
         finally:
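
Context on the pattern in this diff: model.generate() in Transformers returns the completed output in a single call, so the removed "for chunk in ... generate(...)" loop was never yielding incremental chunks; the replacement makes one blocking call and decodes the result. If token-by-token output is wanted in a Streamlit app like this one, the usual route is transformers.TextIteratorStreamer. The sketch below is illustrative only, not code from this commit: it reuses the Spestly/Atlas-R1-1.5B-Preview checkpoint named in the diff, while the prompt, the max_new_tokens value, and the plain accumulation loop (standing in for an st.empty() placeholder) are assumptions.

# Illustrative sketch (not part of this commit): token-level streaming with
# transformers.TextIteratorStreamer. The model ID comes from the diff above;
# prompt and generation settings are placeholder values.
from threading import Thread

from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

model_id = "Spestly/Atlas-R1-1.5B-Preview"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id)

prompt = "Explain what a tokenizer does."
inputs = tokenizer(prompt, return_tensors="pt")

# skip_prompt=True makes the streamer emit only newly generated text,
# so no string-level prompt removal is needed afterwards.
streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)

generation_kwargs = dict(
    input_ids=inputs.input_ids,
    attention_mask=inputs.attention_mask,
    max_new_tokens=256,
    do_sample=True,
    streamer=streamer,
)

# generate() blocks until the sequence is finished, so it runs in a background
# thread while the main thread consumes decoded text chunks as they arrive.
thread = Thread(target=model.generate, kwargs=generation_kwargs)
thread.start()

response = ""
for chunk in streamer:   # yields decoded text pieces incrementally
    response += chunk    # a Streamlit app would update an st.empty() placeholder here
thread.join()

print(response.strip())

A simpler variant that keeps the non-streaming call from this commit but avoids the "prompt in response" string check is to decode only the newly generated tokens, e.g. tokenizer.decode(output[0][inputs.input_ids.shape[1]:], skip_special_tokens=True).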