Training Models¶
Overview¶
SLEAP-NN leverages a flexible, configuration-driven training workflow built on Hydra and OmegaConf. This guide will walk you through the essential steps for training pose estimation models using SLEAP-NN, whether you prefer the command-line interface or Python APIs.
Using uv workflow

This section assumes you have sleap-nn installed. If not, refer to the installation guide.

- If you're using the `uvx` workflow, you do not need to install anything. (See installation using uvx for more details.)
- If you are using the `uv sync` or `uv pip` installation methods, add `uv run` as a prefix to all CLI commands shown below, for example: `uv run sleap-nn train ...`
This section explains how to train a model using an existing configuration file. If you need help creating or editing a config, see the configuration guide.
Using CLI¶
To train a model using the CLI, pass the following options to `sleap-nn train`:

- `config-name` or `-c`: Name of the config file
- `config-dir` or `-d`: Path to the directory containing the config file

If your config file is at `/path/to/config_dir/config.yaml`, then `config-name` would be `config.yaml` and `config-dir` would be `/path/to/config_dir`.
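For example, with the layout above, a basic training run looks like this:

```bash
sleap-nn train -c config -d /path/to/config_dir/ "data_config.train_labels_path=[labels.pkg.slp]"
```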
You can also override any configuration option from the command line:
# Train on list of .slp files
sleap-nn train -c config -d /path/to/config_dir/ "data_config.train_labels_path=[labels.pkg.slp,labels.pkg.slp]"
# Change batch size
sleap-nn train -c config -d /path/to/config_dir/ trainer_config.train_data_loader.batch_size=8 trainer_config.val_data_loader.batch_size=8 "data_config.train_labels_path=[labels.pkg.slp]"
# Set number of GPUs to be used
sleap-nn train -c config -d /path/to/config_dir/ trainer_config.trainer_devices=1 "data_config.train_labels_path=[labels.pkg.slp]"
# Change learning rate
sleap-nn train -c config -d /path/to/config_dir/ trainer_config.optimizer.lr=5e-4 "data_config.train_labels_path=[labels.pkg.slp]"
Training TopDown Model
For top-down inference, you need to train two models: a centroid model and a centered-instance model (centroid → instance). To learn more about top-down models, refer to Model types.
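A sketch of the two-step workflow using the CLI; the config names `centroid_config` and `centered_instance_config` are placeholders for two configs you have prepared, one per model type:

```bash
# 1. Train the centroid model.
sleap-nn train -c centroid_config -d /path/to/config_dir/ "data_config.train_labels_path=[labels.pkg.slp]"

# 2. Train the centered-instance model.
sleap-nn train -c centered_instance_config -d /path/to/config_dir/ "data_config.train_labels_path=[labels.pkg.slp]"
```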
Using ModelTrainer API¶
To train a model using the sleap-nn APIs:
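A minimal sketch of the config-driven workflow is shown below; the import path and the `ModelTrainer` classmethod used here are assumptions based on the class name above, so check the ModelTrainer API reference for the exact signatures.

```python
from omegaconf import OmegaConf

# Assumed import path for the ModelTrainer class; check the API reference.
from sleap_nn.training.model_trainer import ModelTrainer

# Load an existing training config (see the configuration guide).
config = OmegaConf.load("/path/to/config_dir/config.yaml")

# Optionally tweak values programmatically before training.
config.trainer_config.train_data_loader.batch_size = 8

# Construct the trainer from the config and start training.
# (The classmethod name is an assumption; see the API reference.)
trainer = ModelTrainer.get_model_trainer_from_config(config)
trainer.train()
```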
If you have a custom labels object that is not stored in a `.slp` file:
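A sketch for in-memory labels, assuming `sleap_io.Labels` objects and assuming the trainer accepts them via `train_labels` / `val_labels` keyword arguments (check the ModelTrainer API reference for the exact parameter names):

```python
import sleap_io as sio
from omegaconf import OmegaConf

# Assumed import path; check the API reference.
from sleap_nn.training.model_trainer import ModelTrainer

# Build or load Labels objects in memory instead of pointing the config at .slp paths.
train_labels = sio.load_slp("labels.pkg.slp")
val_labels = sio.load_slp("val_labels.pkg.slp")

config = OmegaConf.load("/path/to/config_dir/config.yaml")

# Passing labels objects directly (and these keyword names) is an assumption here.
trainer = ModelTrainer.get_model_trainer_from_config(
    config,
    train_labels=[train_labels],
    val_labels=[val_labels],
)
trainer.train()
```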
Training without Config¶
If you prefer not to create a large custom config file, you can quickly train a model by calling the `train()` function directly and passing your desired parameters as arguments.

This approach is much simpler than manually specifying every parameter for each model component. For example, instead of defining all the details for a UNet backbone, you can just set `backbone_config="unet_medium_rf"` or `"unet_large_rf"`, and the appropriate preset values will be used automatically. The same applies to head configurations: just specify the desired preset (e.g., `"bottomup"`), and the defaults are handled for you. To look at the preset values for each of the backbones and heads, refer to the model configs.

For a full list of available arguments and their descriptions, see the `train()` API reference in the documentation.

Applying data augmentation is also much simpler: you can specify the augmentation names directly (as a string or list) instead of writing out a full configuration.
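A minimal sketch of this workflow is shown below. Only `backbone_config`, the preset names, and the idea of passing augmentation names come from this guide; the import path and the other keyword names (`train_labels_path`, `head_configs`, `intensity_aug`, `batch_size`) are assumptions, so check the `train()` API reference for the exact arguments.

```python
from sleap_nn.train import train  # assumed import path

train(
    train_labels_path=["labels.pkg.slp"],               # assumed keyword name
    backbone_config="unet_medium_rf",                    # preset backbone described above
    head_configs="bottomup",                             # preset head described above
    intensity_aug=["uniform_noise", "gaussian_noise"],   # assumed augmentation names
    batch_size=8,                                        # assumed keyword name
)
```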
Monitoring Training¶
Weights & Biases (WandB) Integration¶
If you set `trainer_config.use_wandb = True` and provide a valid `trainer_config.wandb_config`, all key training metrics (losses, training/validation times, and visualizations if `wandb_config.save_viz_imgs_wandb` is set to `True`) are automatically logged to your WandB project. This makes it easy to monitor progress and compare runs.
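A sketch of the relevant config section; only `use_wandb`, `wandb_config`, and `save_viz_imgs_wandb` come from this guide, while the `entity` and `project` field names are assumptions (see the configuration guide for the full `wandb_config` schema):

```yaml
trainer_config:
  use_wandb: true
  wandb_config:
    entity: my-team            # assumed field name
    project: sleap-nn-runs     # assumed field name
    save_viz_imgs_wandb: true  # also log prediction visualizations
```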
Checkpointing & Artifacts¶
For every training run, a dedicated checkpoint directory is created. This directory contains:
- The original user-provided config (`initial_config.yaml`)
- The full training config with computed values (`training_config.yaml`)
- The best model weights (`best.ckpt`), saved when `trainer_config.save_ckpt` is set to `True`
- The training and validation SLP files used
- A CSV log tracking train/validation loss, times, and learning rate across epochs
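For reference, a checkpoint directory might look roughly like the sketch below; the specific names for the SLP copies and the CSV log are assumptions, as only the two config files and `best.ckpt` are named above:

```
models/my_run/
├── initial_config.yaml     # config as provided by the user
├── training_config.yaml    # full config with computed values
├── best.ckpt               # best model weights (when save_ckpt is true)
├── labels_train.slp        # assumed name: training labels used
├── labels_val.slp          # assumed name: validation labels used
└── training_log.csv        # assumed name: per-epoch loss/time/lr log
```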
Visualizing training performance¶
To help understand model performance, SLEAP-NN can generate visualizations of model predictions (e.g., confidence maps) after each epoch when `trainer_config.visualize_preds_during_training` is set to `True`. By default, these images are saved temporarily (and deleted after training completes), but you can configure the system to keep them by setting `trainer_config.keep_viz` to `True`.
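For example, using the same CLI override pattern as shown earlier:

```bash
sleap-nn train -c config -d /path/to/config_dir/ \
  trainer_config.visualize_preds_during_training=true \
  trainer_config.keep_viz=true \
  "data_config.train_labels_path=[labels.pkg.slp]"
```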
Advanced Options¶
Fine-tuning / Transfer Learning¶
SLEAP-NN makes it easy to fine-tune or transfer-learn from existing models. To initialize your model with pre-trained weights, simply set the following options in your configuration:
- `model_config.pretrained_backbone_weights`: Path to a checkpoint file (or a `.h5` file from SLEAP ≤ 1.4; only the UNet backbone is supported) containing the backbone weights you want to load. This will initialize the backbone (e.g., UNet, Swin Transformer) with the specified weights.
- `model_config.pretrained_head_weights`: Path to a checkpoint file (or a `.h5` file from SLEAP ≤ 1.4; only the UNet backbone is supported) to initialize the model's head weights (e.g., for bottomup or topdown heads). The head and backbone weights are usually the same checkpoint, but you can specify a different file here if you want to use separate weights for the head (for example, when adapting a model to a new output head or architecture).
By specifying these options, your model will be initialized with the provided weights, allowing you to fine-tune on new data or adapt to a new task. You can use this for transfer learning from a model trained on a different dataset, or to continue training with a modified head or backbone.
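For example, to initialize both the backbone and head from an existing checkpoint via CLI overrides (the paths here are placeholders):

```bash
sleap-nn train -c config -d /path/to/config_dir/ \
  model_config.pretrained_backbone_weights=/path/to/pretrained/best.ckpt \
  model_config.pretrained_head_weights=/path/to/pretrained/best.ckpt \
  "data_config.train_labels_path=[new_labels.pkg.slp]"
```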
Resume Training¶
To resume training from a previous checkpoint (restoring both model weights and optimizer state), simply provide the path to your previous checkpoint file using the `trainer_config.resume_ckpt_path` option. This allows you to continue training seamlessly from where you left off.
sleap-nn train \
-c config \
-d /path/to/config_dir/ \
trainer_config.resume_ckpt_path=/path/to/prv_trained/checkpoint.ckpt \
"data_config.train_labels_path=[labels.pkg.slp]"
Multi-GPU Training¶
To automatically configure the accelerator and number of devices, set:
trainer_config:
ckpt_dir: models
run_name: multi_gpu_training_1
trainer_accelerator: "auto"
trainer_devices:
trainer_device_indices:
trainer_strategy: "auto"
To set the number of GPUs to be used and the accelerator:
trainer_config:
ckpt_dir: models
run_name: multi_gpu_training_1
trainer_accelerator: "gpu"
trainer_devices: 4
trainer_device_indices:
trainer_strategy: "ddp"
To select which devices to use (here, the first and third GPUs):
trainer_config:
ckpt_dir: models
run_name: multi_gpu_training_1
trainer_accelerator: "gpu"
trainer_devices: 2
trainer_device_indices:
- 0
- 2
trainer_strategy: "ddp"
Training steps in multi-gpu setting
- In a multi-GPU training setup, the effective number of steps per epoch is `config.trainer_config.trainer_steps_per_epoch` / `config.trainer_config.trainer_devices`. For example, with 200 steps per epoch on 4 devices, each device runs 50 steps per epoch.
- If validation labels are not provided in a multi-GPU training setup, the labels are split deterministically into train/val sets by seeding with 42 (when no seed is given). This prevents each GPU worker from producing a different split. To generate a different train/val split, set a custom seed via `config.trainer_config.seed`.
Multi-node training
Multi-node training has not been validated and should be considered experimental.
Best Practices¶
- Start Simple: Begin with default configurations
- Cache Data: For faster training, consider caching images in memory (or on disk) by setting the relevant `data_config.data_pipeline_fw`. (`num_workers` can be set > 0 if caching frameworks are used; see the sketch after this list.)
- Monitor Overfitting: Watch validation metrics
- Adjust Learning Rate: Use learning rate scheduling
- Data Augmentation: Enable augmentations for better generalization
- Early Stopping: Prevent overfitting with the early stopping callback
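A sketch of the corresponding CLI override for caching; the value `torch_dataset_cache_img_memory` and the `num_workers` key path are assumptions, so check the configuration guide for the valid `data_pipeline_fw` options:

```bash
sleap-nn train -c config -d /path/to/config_dir/ \
  data_config.data_pipeline_fw=torch_dataset_cache_img_memory \
  trainer_config.train_data_loader.num_workers=4 \
  "data_config.train_labels_path=[labels.pkg.slp]"
```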
Troubleshooting¶
Out of Memory¶
For large models or datasets:
- Reduce `batch_size`
- Reduce model size (fewer filters/layers)
- Reduce the number of workers
Slow Training¶
- Use caching methods (`data_config.data_pipeline_fw`)
- Increase `num_workers` for data loading
- Check GPU utilization
Poor Performance¶
- Increase training data
- Adjust augmentation parameters
- Try different architectures
- Tune hyperparameters