jbilcke-hf HF Staff commited on
Commit
f407698
·
1 Parent(s): d36662f

added support for presets + automatic FPS

Browse files
Files changed (3) hide show
  1. app.py +27 -13
  2. config.py +1 -1
  3. training_service.py +13 -3
app.py CHANGED
@@ -661,25 +661,39 @@ class VideoTrainerUI:
661
  training_dataset
662
  )
663
 
664
- def update_training_params(self, preset_name: str) -> Dict:
665
  """Update UI components based on selected preset"""
666
  preset = TRAINING_PRESETS[preset_name]
667
-
 
 
 
 
 
 
668
  # Get preset description for display
669
  description = preset.get("description", "")
670
- bucket_info = f"\nBucket configuration: {len(preset['training_buckets'])} buckets"
 
 
 
 
 
 
 
671
  info_text = f"{description}{bucket_info}"
672
 
673
- return {
674
- "model_type": gr.Dropdown(value=MODEL_TYPES[preset["model_type"]]),
675
- "lora_rank": gr.Dropdown(value=preset["lora_rank"]),
676
- "lora_alpha": gr.Dropdown(value=preset["lora_alpha"]),
677
- "num_epochs": gr.Number(value=preset["num_epochs"]),
678
- "batch_size": gr.Number(value=preset["batch_size"]),
679
- "learning_rate": gr.Number(value=preset["learning_rate"]),
680
- "save_iterations": gr.Number(value=preset["save_iterations"]),
681
- "preset_info": gr.Markdown(value=info_text)
682
- }
 
683
 
684
  def create_ui(self):
685
  """Create Gradio interface"""
 
661
  training_dataset
662
  )
663
 
664
+ def update_training_params(self, preset_name: str) -> Tuple:
665
  """Update UI components based on selected preset"""
666
  preset = TRAINING_PRESETS[preset_name]
667
+
668
+ # Find the display name that maps to our model type
669
+ model_display_name = next(
670
+ key for key, value in MODEL_TYPES.items()
671
+ if value == preset["model_type"]
672
+ )
673
+
674
  # Get preset description for display
675
  description = preset.get("description", "")
676
+
677
+ # Get max values from buckets
678
+ buckets = preset["training_buckets"]
679
+ max_frames = max(frames for frames, _, _ in buckets)
680
+ max_height = max(height for _, height, _ in buckets)
681
+ max_width = max(width for _, _, width in buckets)
682
+ bucket_info = f"\nMaximum video size: {max_frames} frames at {max_width}x{max_height} resolution"
683
+
684
  info_text = f"{description}{bucket_info}"
685
 
686
+ # Return values in the same order as the output components
687
+ return (
688
+ model_display_name,
689
+ preset["lora_rank"],
690
+ preset["lora_alpha"],
691
+ preset["num_epochs"],
692
+ preset["batch_size"],
693
+ preset["learning_rate"],
694
+ preset["save_iterations"],
695
+ info_text
696
+ )
697
 
698
  def create_ui(self):
699
  """Create Gradio interface"""
config.py CHANGED
@@ -262,7 +262,7 @@ class TrainingConfig:
262
  data_root=data_path,
263
  output_dir=output_path,
264
  batch_size=1,
265
- train_epochs=70,
266
  lr=3e-5,
267
  gradient_checkpointing=True,
268
  id_token="BW_STYLE",
 
262
  data_root=data_path,
263
  output_dir=output_path,
264
  batch_size=1,
265
+ train_epochs=40,
266
  lr=3e-5,
267
  gradient_checkpointing=True,
268
  id_token="BW_STYLE",
training_service.py CHANGED
@@ -19,7 +19,7 @@ import select
19
  from typing import Any, Optional, Dict, List, Union, Tuple
20
 
21
  from huggingface_hub import upload_folder, create_repo
22
- from config import TrainingConfig,TRAINING_PRESETS, LOG_FILE_PATH, TRAINING_VIDEOS_PATH, STORAGE_PATH, TRAINING_PATH, MODEL_PATH, OUTPUT_PATH, HF_API_TOKEN, MODEL_TYPES
23
  from utils import make_archive, parse_training_log, is_image_file, is_video_file
24
  from finetrainers_utils import prepare_finetrainers_dataset, copy_files_to_training_dir
25
 
@@ -214,8 +214,18 @@ class TrainingService:
214
  return f"Configuration validation failed: {str(e)}"
215
 
216
 
217
- def start_training(self, model_type: str, lora_rank: str, lora_alpha: str, num_epochs: int, batch_size: int,
218
- learning_rate: float, save_iterations: int, repo_id: str) -> Tuple[str, str]:
 
 
 
 
 
 
 
 
 
 
219
  """Start training with finetrainers"""
220
 
221
  self.clear_logs()
 
19
  from typing import Any, Optional, Dict, List, Union, Tuple
20
 
21
  from huggingface_hub import upload_folder, create_repo
22
+ from config import TrainingConfig, TRAINING_PRESETS, LOG_FILE_PATH, TRAINING_VIDEOS_PATH, STORAGE_PATH, TRAINING_PATH, MODEL_PATH, OUTPUT_PATH, HF_API_TOKEN, MODEL_TYPES
23
  from utils import make_archive, parse_training_log, is_image_file, is_video_file
24
  from finetrainers_utils import prepare_finetrainers_dataset, copy_files_to_training_dir
25
 
 
214
  return f"Configuration validation failed: {str(e)}"
215
 
216
 
217
+ def start_training(
218
+ self,
219
+ model_type: str,
220
+ lora_rank: str,
221
+ lora_alpha: str,
222
+ num_epochs: int,
223
+ batch_size: int,
224
+ learning_rate: float,
225
+ save_iterations: int,
226
+ repo_id: str,
227
+ preset_name: str,
228
+ ) -> Tuple[str, str]:
229
  """Start training with finetrainers"""
230
 
231
  self.clear_logs()