akashmadisetty committed on
Commit
c1d34f4
·
1 Parent(s): 704e1a6
Files changed (1) hide show
  1. app.py +90 -45
app.py CHANGED
@@ -23,26 +23,55 @@ def load_model(hf_token):
23
  global global_model, global_tokenizer, model_loaded
24
 
25
  if not hf_token:
26
- return False, "Please enter your Hugging Face token to use the model."
 
27
 
28
- model_name = "google/gemma-3-4b-pt"
29
  try:
30
- global_tokenizer = AutoTokenizer.from_pretrained(model_name, token=hf_token)
31
- global_model = AutoModelForCausalLM.from_pretrained(
32
- model_name,
33
- torch_dtype=torch.float16,
34
- device_map="auto",
35
- token=hf_token
36
- )
37
- model_loaded = True
38
- return True, "Model loaded successfully!"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
39
  except Exception as e:
40
  model_loaded = False
41
  error_msg = str(e)
42
  if "401 Client Error" in error_msg:
43
- return False, "Authentication failed. Please check your token and make sure you've accepted the model license on Hugging Face."
44
  else:
45
- return False, f"Error loading model: {error_msg}"
46
 
47
  def generate_prompt(task_type, **kwargs):
48
  """Generate appropriate prompts based on task type and parameters"""
@@ -135,21 +164,37 @@ def generate_text(prompt, max_length=1024, temperature=0.7, top_p=0.95):
135
  try:
136
  inputs = global_tokenizer(prompt, return_tensors="pt").to(global_model.device)
137
 
138
- # Generate text
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
139
  outputs = global_model.generate(
140
  **inputs,
141
- max_length=max_length,
142
- temperature=temperature,
143
- top_p=top_p,
144
- do_sample=True,
145
- pad_token_id=global_tokenizer.eos_token_id,
146
  )
147
 
148
  # Decode and return the generated text
149
  generated_text = global_tokenizer.decode(outputs[0], skip_special_tokens=True)
150
  return generated_text
151
  except Exception as e:
152
- return f"Error generating text: {str(e)}"
 
 
 
 
153
 
154
  # Create parameters UI component
155
  def create_parameter_ui():
@@ -162,16 +207,16 @@ def create_parameter_ui():
162
  label="Maximum Length"
163
  )
164
  temperature = gr.Slider(
165
- minimum=0.1,
166
  maximum=1.5,
167
- value=0.7,
168
  step=0.1,
169
  label="Temperature"
170
  )
171
  top_p = gr.Slider(
172
  minimum=0.5,
173
- maximum=1.0,
174
- value=0.95,
175
  step=0.05,
176
  label="Top-p"
177
  )
@@ -330,10 +375,10 @@ with gr.Blocks(theme=gr.themes.Soft(primary_hue="blue")) as demo:
330
  # Examples for text generation
331
  gr.Examples(
332
  [
333
- ["Creative Writing", "short story", "a robot learning to paint", "article", "artificial intelligence", "", 1024, 0.7, 0.95],
334
- ["Creative Writing", "poem", "the beauty of mathematics", "article", "artificial intelligence", "", 768, 0.8, 0.95],
335
- ["Informational Writing", "short story", "a robot discovering emotions", "article", "quantum computing", "", 1024, 0.5, 0.95],
336
- ["Custom Prompt", "short story", "a robot discovering emotions", "article", "artificial intelligence", "Write a marketing email for a new smartphone with innovative AI features", 1024, 0.7, 0.95]
337
  ],
338
  fn=text_generation_handler,
339
  inputs=[
@@ -390,9 +435,9 @@ with gr.Blocks(theme=gr.themes.Soft(primary_hue="blue")) as demo:
390
  # Examples for brainstorming
391
  gr.Examples(
392
  [
393
- ["project", "educational app for children", 1024, 0.8, 0.95],
394
- ["business", "eco-friendly food packaging", 1024, 0.8, 0.95],
395
- ["solution", "reducing urban traffic congestion", 1024, 0.8, 0.95],
396
  ],
397
  fn=brainstorm_handler,
398
  inputs=[brainstorm_category, brainstorm_topic, *brainstorm_params],
@@ -452,9 +497,9 @@ with gr.Blocks(theme=gr.themes.Soft(primary_hue="blue")) as demo:
452
  # Examples for content creation
453
  gr.Examples(
454
  [
455
- ["blog post", "sustainable living tips", "environmentally conscious consumers", 1536, 0.7, 0.95],
456
- ["social media post", "product launch announcement", "existing customers", 512, 0.7, 0.95],
457
- ["marketing copy", "new fitness app", "health-focused individuals", 1024, 0.7, 0.95],
458
  ],
459
  fn=content_creation_handler,
460
  inputs=[content_type, content_topic, content_audience, *content_params],
@@ -509,9 +554,9 @@ with gr.Blocks(theme=gr.themes.Soft(primary_hue="blue")) as demo:
509
  # Examples for email drafting
510
  gr.Examples(
511
  [
512
- ["job application", "Applying for a marketing specialist position at ABC Marketing. I have 5 years of experience in digital marketing.", 1024, 0.7, 0.95],
513
- ["business proposal", "Proposing a collaboration between our companies for a joint product development effort.", 1024, 0.7, 0.95],
514
- ["follow-up", "Following up after our meeting last Thursday about the project timeline and resources.", 1024, 0.7, 0.95],
515
  ],
516
  fn=email_draft_handler,
517
  inputs=[email_type, email_context, *email_params],
@@ -609,9 +654,9 @@ with gr.Blocks(theme=gr.themes.Soft(primary_hue="blue")) as demo:
609
  # Examples for explanation
610
  gr.Examples(
611
  [
612
- ["blockchain technology", "beginner", 1024, 0.7, 0.95],
613
- ["photosynthesis", "child", 1024, 0.7, 0.95],
614
- ["machine learning", "college student", 1024, 0.7, 0.95],
615
  ],
616
  fn=explanation_handler,
617
  inputs=[explain_topic, explain_level, *explain_params],
@@ -666,9 +711,9 @@ with gr.Blocks(theme=gr.themes.Soft(primary_hue="blue")) as demo:
666
  # Examples for classification
667
  gr.Examples(
668
  [
669
- ["The stock market saw significant gains today as tech companies reported strong quarterly earnings.", "technology, health, finance, entertainment, education, sports", 256, 0.1, 0.95],
670
- ["The team scored in the final minutes to secure their victory in the championship game.", "technology, health, finance, entertainment, education, sports", 256, 0.1, 0.95],
671
- ["The new educational app helps students master complex math concepts through interactive exercises.", "technology, health, finance, entertainment, education, sports", 256, 0.1, 0.95],
672
  ],
673
  fn=classification_handler,
674
  inputs=[classify_text, classify_categories, *classify_params],
@@ -723,8 +768,8 @@ with gr.Blocks(theme=gr.themes.Soft(primary_hue="blue")) as demo:
723
  # Examples for data extraction
724
  gr.Examples(
725
  [
726
- ["Sarah Johnson is the CEO of Green Innovations, founded in 2012. The company reported $8.5 million in revenue for 2023. Contact her at [email protected].", "name, position, company, founding year, revenue, contact", 768, 0.3, 0.95],
727
- ["The new iPhone 15 Pro features a 6.1-inch display, A17 Pro chip, 48MP camera, and starts at $999 for the 128GB model.", "product name, screen size, processor, camera, price, storage capacity", 768, 0.3, 0.95],
728
  ],
729
  fn=data_extraction_handler,
730
  inputs=[extract_text, extract_data_points, *extract_params],
 
23
  global global_model, global_tokenizer, model_loaded
24
 
25
  if not hf_token:
26
+ model_loaded = False
27
+ return "⚠️ Please enter your Hugging Face token to use the model."
28
 
 
29
  try:
30
+ # Try both model versions
31
+ model_options = [
32
+ "google/gemma-3-4b-pt", # Try the Gemma 3 4B pretrained ("pt") version first
33
+ "google/gemma-2b", # Fallback to 2b model
34
+ ]
35
+
36
+ # Try to load models in order until one works
37
+ for model_name in model_options:
38
+ try:
39
+ print(f"Attempting to load model: {model_name}")
40
+
41
+ # Load tokenizer
42
+ global_tokenizer = AutoTokenizer.from_pretrained(
43
+ model_name,
44
+ token=hf_token
45
+ )
46
+
47
+ # Load model with safe configuration
48
+ global_model = AutoModelForCausalLM.from_pretrained(
49
+ model_name,
50
+ torch_dtype=torch.float16,
51
+ device_map="auto",
52
+ token=hf_token,
53
+ use_cache=True,
54
+ low_cpu_mem_usage=True,
55
+ attn_implementation="flash_attention_2" if torch.cuda.is_available() else "eager"
56
+ )
57
+
58
+ model_loaded = True
59
+ return f"✅ Model {model_name} loaded successfully!"
60
+ except Exception as specific_e:
61
+ print(f"Failed to load {model_name}: {specific_e}")
62
+ continue
63
+
64
+ # If we get here, all model options failed
65
+ model_loaded = False
66
+ return "❌ Could not load any model version. Please check your token and try again."
67
+
68
  except Exception as e:
69
  model_loaded = False
70
  error_msg = str(e)
71
  if "401 Client Error" in error_msg:
72
+ return "❌ Authentication failed. Please check your token and make sure you've accepted the model license on Hugging Face."
73
  else:
74
+ return f"Error loading model: {error_msg}"
75
 
76
  def generate_prompt(task_type, **kwargs):
77
  """Generate appropriate prompts based on task type and parameters"""
 
164
  try:
165
  inputs = global_tokenizer(prompt, return_tensors="pt").to(global_model.device)
166
 
167
+ generation_config = {
168
+ "max_length": max_length,
169
+ "do_sample": True,
170
+ "pad_token_id": global_tokenizer.eos_token_id,
171
+ }
172
+
173
+ # Clamp temperature to a 0.2 floor (very low values can cause "probability tensor" errors)
174
+ if temperature >= 0.2:
175
+ generation_config["temperature"] = temperature
176
+ else:
177
+ generation_config["temperature"] = 0.2
178
+
179
+ # Only add top_p if it's valid
180
+ if 0 < top_p < 1:
181
+ generation_config["top_p"] = top_p
182
+
183
+ # Generate text with safer parameters
184
  outputs = global_model.generate(
185
  **inputs,
186
+ **generation_config
 
 
 
 
187
  )
188
 
189
  # Decode and return the generated text
190
  generated_text = global_tokenizer.decode(outputs[0], skip_special_tokens=True)
191
  return generated_text
192
  except Exception as e:
193
+ error_msg = str(e)
194
+ if "probability tensor" in error_msg:
195
+ return "Error: There was a problem with the generation parameters. Try using higher temperature (0.5+) and top_p values (0.9+)."
196
+ else:
197
+ return f"Error generating text: {error_msg}"
198
 
199
  # Create parameters UI component
200
  def create_parameter_ui():
 
207
  label="Maximum Length"
208
  )
209
  temperature = gr.Slider(
210
+ minimum=0.3,
211
  maximum=1.5,
212
+ value=0.8,
213
  step=0.1,
214
  label="Temperature"
215
  )
216
  top_p = gr.Slider(
217
  minimum=0.5,
218
+ maximum=0.99,
219
+ value=0.9,
220
  step=0.05,
221
  label="Top-p"
222
  )
 
375
  # Examples for text generation
376
  gr.Examples(
377
  [
378
+ ["Creative Writing", "short story", "a robot learning to paint", "article", "artificial intelligence", "", 1024, 0.8, 0.9],
379
+ ["Creative Writing", "poem", "the beauty of mathematics", "article", "artificial intelligence", "", 768, 0.8, 0.9],
380
+ ["Informational Writing", "short story", "a robot discovering emotions", "article", "quantum computing", "", 1024, 0.7, 0.9],
381
+ ["Custom Prompt", "short story", "a robot discovering emotions", "article", "artificial intelligence", "Write a marketing email for a new smartphone with innovative AI features", 1024, 0.8, 0.9]
382
  ],
383
  fn=text_generation_handler,
384
  inputs=[
 
435
  # Examples for brainstorming
436
  gr.Examples(
437
  [
438
+ ["project", "educational app for children", 1024, 0.8, 0.9],
439
+ ["business", "eco-friendly food packaging", 1024, 0.8, 0.9],
440
+ ["solution", "reducing urban traffic congestion", 1024, 0.8, 0.9],
441
  ],
442
  fn=brainstorm_handler,
443
  inputs=[brainstorm_category, brainstorm_topic, *brainstorm_params],
 
497
  # Examples for content creation
498
  gr.Examples(
499
  [
500
+ ["blog post", "sustainable living tips", "environmentally conscious consumers", 1536, 0.8, 0.9],
501
+ ["social media post", "product launch announcement", "existing customers", 512, 0.8, 0.9],
502
+ ["marketing copy", "new fitness app", "health-focused individuals", 1024, 0.8, 0.9],
503
  ],
504
  fn=content_creation_handler,
505
  inputs=[content_type, content_topic, content_audience, *content_params],
 
554
  # Examples for email drafting
555
  gr.Examples(
556
  [
557
+ ["job application", "Applying for a marketing specialist position at ABC Marketing. I have 5 years of experience in digital marketing.", 1024, 0.8, 0.9],
558
+ ["business proposal", "Proposing a collaboration between our companies for a joint product development effort.", 1024, 0.8, 0.9],
559
+ ["follow-up", "Following up after our meeting last Thursday about the project timeline and resources.", 1024, 0.8, 0.9],
560
  ],
561
  fn=email_draft_handler,
562
  inputs=[email_type, email_context, *email_params],
 
654
  # Examples for explanation
655
  gr.Examples(
656
  [
657
+ ["blockchain technology", "beginner", 1024, 0.8, 0.9],
658
+ ["photosynthesis", "child", 1024, 0.8, 0.9],
659
+ ["machine learning", "college student", 1024, 0.8, 0.9],
660
  ],
661
  fn=explanation_handler,
662
  inputs=[explain_topic, explain_level, *explain_params],
 
711
  # Examples for classification
712
  gr.Examples(
713
  [
714
+ ["The stock market saw significant gains today as tech companies reported strong quarterly earnings.", "technology, health, finance, entertainment, education, sports", 256, 0.5, 0.9],
715
+ ["The team scored in the final minutes to secure their victory in the championship game.", "technology, health, finance, entertainment, education, sports", 256, 0.5, 0.9],
716
+ ["The new educational app helps students master complex math concepts through interactive exercises.", "technology, health, finance, entertainment, education, sports", 256, 0.5, 0.9],
717
  ],
718
  fn=classification_handler,
719
  inputs=[classify_text, classify_categories, *classify_params],
 
768
  # Examples for data extraction
769
  gr.Examples(
770
  [
771
+ ["Sarah Johnson is the CEO of Green Innovations, founded in 2012. The company reported $8.5 million in revenue for 2023. Contact her at [email protected].", "name, position, company, founding year, revenue, contact", 768, 0.5, 0.9],
772
+ ["The new iPhone 15 Pro features a 6.1-inch display, A17 Pro chip, 48MP camera, and starts at $999 for the 128GB model.", "product name, screen size, processor, camera, price, storage capacity", 768, 0.5, 0.9],
773
  ],
774
  fn=data_extraction_handler,
775
  inputs=[extract_text, extract_data_points, *extract_params],