# Step-by-Step Training Tutorial
This tutorial will walk you through the complete process of training a pose estimation model from scratch.
## 📋 Prerequisites

Before starting, make sure you have `sleap-nn` installed (refer to the Installation docs).
**API-based Tutorial**

In this tutorial, we use the Python API for all steps, which is ideal for running in a notebook or Python script. The `uvx` workflow will not work with the API-based approach.

- **Installation**: Make sure you have installed `sleap-nn` using either pip or the `uv sync` workflow.
- **Command Line Interface (CLI)**: If you prefer using the CLI, or want to see all available CLI options, refer to the Training Guide and Inference Guide.
## 🚀 Step 1: Configuration Setup

The first step is to set up our configuration file, which specifies the parameters required to train a pose estimation model with sleap-nn.
### 1.1 Load a Sample Configuration

Start by loading a sample (`.yaml`) configuration:
```yaml
data_config:
  train_labels_path:
    - path/to/your/training_data.slp
  val_labels_path:
    - path/to/your/validation_data.slp
  validation_fraction: 0.1
  user_instances_only: true
  data_pipeline_fw: torch_dataset
  preprocessing:
    ensure_rgb: false
    ensure_grayscale: false
    scale: 1.0
    crop_size: null
    min_crop_size: 100
  use_augmentations_train: true
  augmentation_config:
    intensity:
      contrast_p: 0.5
      brightness_p: 0.5
    geometric:
      rotation_min: -15.0
      rotation_max: 15.0
      scale_min: 0.9
      scale_max: 1.1
      affine_p: 1.0
model_config:
  pretrained_backbone_weights: null
  pretrained_head_weights: null
  backbone_config:
    unet:
      in_channels: 1
      kernel_size: 3
      filters: 16
      filters_rate: 2.0
      max_stride: 16
      middle_block: true
      up_interpolate: true
      stacks: 1
      convs_per_block: 2
      output_stride: 2
    convnext: null
    swint: null
  head_configs:
    single_instance:
      confmaps:
        part_names: null
        sigma: 2.5
        output_stride: 2
    centroid: null
    centered_instance: null
    bottomup: null
    multi_class_bottomup: null
    multi_class_topdown: null
trainer_config:
  train_data_loader:
    batch_size: 4
    shuffle: true
    num_workers: 0
  val_data_loader:
    batch_size: 4
    shuffle: false
    num_workers: 0
  model_ckpt:
    save_top_k: 1
    save_last: false
  trainer_devices:
  trainer_device_indices:
  trainer_accelerator: auto
  min_train_steps_per_epoch: 200
  visualize_preds_during_training: true
  keep_viz: false
  max_epochs: 200
  use_wandb: false
  save_ckpt: true
  ckpt_dir: <ckpt_dir>
  run_name: <run_name>
  optimizer_name: Adam
  optimizer:
    lr: 0.0001
    amsgrad: false
  lr_scheduler:
    step_lr: null
    reduce_lr_on_plateau:
      threshold: 1.0e-06
      threshold_mode: rel
      cooldown: 3
      patience: 5
      factor: 0.5
      min_lr: 1.0e-08
  early_stopping:
    min_delta: 1.0e-08
    patience: 10
    stop_training_on_plateau: true
```
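Since this is an API-based tutorial, you can load and edit this YAML from Python rather than a text editor. A minimal sketch using OmegaConf (one convenient YAML loader; the filename `config.yaml` is a placeholder):

```python
from omegaconf import OmegaConf

# Load the sample configuration from disk.
config = OmegaConf.load("config.yaml")

# Dot-access mirrors the YAML structure.
print(config.data_config.train_labels_path)
print(config.trainer_config.max_epochs)
```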
### 1.2 Understanding the Configuration Structure

Your config file has three main sections:

```yaml
data_config:    # How to load and process your data
model_config:   # What model architecture to use
trainer_config: # How to train the model and set hyperparameters
```
### 1.3 Key Parameters to Modify

#### Data Configuration (`data_config`)
Download the sample `train.pkg.slp` and `val.pkg.slp` files.
Set `train_labels_path` to the path of your training `.slp` file, or a list of `.slp` files if you have multiple datasets. The `val_labels_path` is optional: if you leave it out, the training data will be automatically split into training and validation sets based on the `validation_fraction` parameter.
Choose the appropriate `data_pipeline_fw` based on your dataset size and hardware:

- Use `torch_dataset_cache_img_memory` for small datasets that fit comfortably in RAM. This caches all source images in memory for faster training.
- Use `torch_dataset_cache_img_disk` for larger datasets that don't fit in memory. This caches images to disk, enabling efficient loading even for very large datasets. You can reuse the disk cache across different model types, since only the raw source images are cached (not model-specific data).
You can customize data loading, preprocessing, and augmentation options in this section. For a full explanation of all available parameters and augmentation settings, see the Data config section of the Configuration Guide.
```yaml
data_config:
  train_labels_path:
    - path/to/your/training_data.slp
  val_labels_path:
    - path/to/your/validation_data.slp
  validation_fraction: 0.1
  data_pipeline_fw: torch_dataset
  preprocessing:
    ensure_rgb: false
    ensure_grayscale: false
    scale: 1.0
    crop_size: null     # only for centered-instance model
    min_crop_size: 100  # only for centered-instance model
  use_augmentations_train: true
  augmentation_config:
    intensity:
      contrast_p: 0.5
      brightness_p: 0.5
    geometric:
      rotation_min: -15.0
      rotation_max: 15.0
      scale_min: 0.9
      scale_max: 1.1
      affine_p: 1.0
```
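If you are editing the config from Python, the same keys can be set on the loaded object. A short sketch, continuing from the `config` loaded in step 1.1 (paths are placeholders):

```python
# Point the config at your own labels files (placeholder paths).
config.data_config.train_labels_path = ["train.pkg.slp"]
config.data_config.val_labels_path = ["val.pkg.slp"]

# Cache images on disk for large datasets; this also makes it safe to
# raise num_workers in trainer_config later.
config.data_config.data_pipeline_fw = "torch_dataset_cache_img_disk"
```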
#### Model Configuration (`model_config`)

When configuring your model, you'll need to select both a backbone architecture and a model type:

- Backbone options: `unet`, `swint`, or `convnext`
- Model type options: `single_instance`, `centroid`, `centered_instance`, or `bottomup`
For a detailed explanation of each backbone and model type, see the Model Architectures Guide.
Tips for configuring your model:

- **Input channels (`in_channels`)**: Set this to match your input image format (e.g., 1 for grayscale, 3 for RGB). The training pipeline will also infer and adjust this automatically.
- **Max stride (`max_stride`)**: This parameter controls the number of downsampling (encoder) blocks in the backbone, which directly affects the receptive field size. For a deeper dive into how the receptive field is affected, check out the Receptive Field Guide.
- **Special note for `convnext` and `swint`**: For these backbones, `max_stride` is determined by `stem_patch_stride * 16` and cannot be set arbitrarily.
For ready-to-use configuration examples for each backbone and model type, see the Model Config Guide.
```yaml
model_config:
  pretrained_backbone_weights: null
  pretrained_head_weights: null
  backbone_config:
    unet:
      in_channels: 1
      kernel_size: 3
      filters: 16
      filters_rate: 2.0
      max_stride: 16
      middle_block: true
      up_interpolate: true
      stacks: 1
      convs_per_block: 2
      output_stride: 2
    convnext: null
    swint: null
  head_configs:
    single_instance:
      confmaps:
        part_names: null
        sigma: 2.5
        output_stride: 2
    centroid: null
    centered_instance: null
    bottomup: null
    multi_class_bottomup: null
    multi_class_topdown: null
```
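The model section can be adjusted the same way from Python. An illustrative sketch (continuing from the loaded `config`; the values are examples, not recommendations):

```python
# Match the input image format: 1 for grayscale, 3 for RGB.
config.model_config.backbone_config.unet.in_channels = 1

# A larger max_stride adds encoder blocks and grows the receptive field
# (UNet only; for convnext/swint it is fixed at stem_patch_stride * 16).
config.model_config.backbone_config.unet.max_stride = 32

# Broader confidence-map peaks can help with coarser annotations.
config.model_config.head_configs.single_instance.confmaps.sigma = 5.0
```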
#### Trainer Configuration (`trainer_config`)

The `trainer_config` section controls the training process, including key hyperparameters and device settings.

Key tips for configuring `trainer_config`:
- **Data Loader Workers (`num_workers`)**:
    - For the default data pipeline (`torch_dataset`), set `num_workers: 0` because `.slp` video objects cannot be pickled for multiprocessing.
    - If you use a caching data pipeline (`torch_dataset_cache_img_memory` or `torch_dataset_cache_img_disk` for `data_config.data_pipeline_fw`), you can increase `num_workers` (>0) to speed up data loading.
- **Epochs and Checkpoints**:
    - Set `max_epochs` to control how many epochs to train for.
    - Use `ckpt_dir` and `run_name` to specify where model checkpoints are saved. If both are `None`, a default folder will be created in the working directory using a timestamp and model type.
    - For multi-GPU training, always set a static `run_name` so all workers write to the same location.
- **Device and Accelerator**:
    - `trainer_accelerator` can be `"cpu"`, `"gpu"`, `"mps"`, or `"auto"`. `"auto"` lets Lightning choose the best device based on your hardware.
    - `trainer_device_indices` is a list of ints used to set the device indices.
    - `trainer_devices` can be set to specify the number of devices (e.g., GPUs) to use. If `None`, the number of devices is inferred from the underlying hardware in the training workflow.
- **Other Tips**:
    - Adjust `batch_size` and learning rate (`optimizer.lr`) as needed for your dataset and hardware.
    - Enable `visualize_preds_during_training` to see predictions during training.
    - Use `use_wandb: true` to log training metrics to Weights & Biases (optional).
For a full list of options and explanations of the `trainer_config` parameters, see the Config Guide.
```yaml
trainer_config:
  train_data_loader:
    batch_size: 4
    shuffle: true
    num_workers: 0
  val_data_loader:
    batch_size: 4
    shuffle: false
    num_workers: 0
  model_ckpt:
    save_top_k: 1
    save_last: false
  trainer_devices:
  trainer_device_indices:
  trainer_accelerator: auto
  min_train_steps_per_epoch: 200
  visualize_preds_during_training: true
  keep_viz: false
  max_epochs: 200
  use_wandb: false
  save_ckpt: true
  ckpt_dir: my_model_ckpt_dir
  run_name: my_run_1
  optimizer_name: Adam
  optimizer:
    lr: 0.0001
    amsgrad: false
  lr_scheduler:
    step_lr: null
    reduce_lr_on_plateau:
      threshold: 1.0e-06
      threshold_mode: rel
      cooldown: 3
      patience: 5
      factor: 0.5
      min_lr: 1.0e-08
  early_stopping:
    min_delta: 1.0e-08
    patience: 10
    stop_training_on_plateau: true
```
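Trainer options follow the same pattern when editing from Python. A sketch of the overrides discussed above (illustrative values, continuing from the loaded `config`):

```python
# More workers are safe here only because we switched to a caching data
# pipeline earlier; keep num_workers: 0 with the plain torch_dataset.
config.trainer_config.train_data_loader.num_workers = 4

# Where checkpoints and logs are written.
config.trainer_config.ckpt_dir = "my_model_ckpt_dir"
config.trainer_config.run_name = "my_run_1"

# Let Lightning pick the best available device.
config.trainer_config.trainer_accelerator = "auto"
```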
## 🤖 Step 2: Training Your Model
Now that you have your configuration file, let's train your model!
### 2.1 Training with the Python API

With your configuration ready, you can launch training directly from Python. If you want to use custom `sleap_io.Labels` objects instead of the file paths in the config, you can pass them to the trainer directly, as shown in the sketch below.
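A minimal sketch of config-driven training. The `ModelTrainer` import path, its constructor method, and the `train_labels`/`val_labels` keywords are assumptions about the sleap-nn API; verify the exact interface against the Training Guide.

```python
import sleap_io as sio
from omegaconf import OmegaConf

# NOTE: this import path and the method names below are assumptions;
# see the Training Guide for the exact API in your installed version.
from sleap_nn.training.model_trainer import ModelTrainer

# Load the config we prepared in Step 1.
config = OmegaConf.load("config.yaml")

# Option A: train from the label paths given in the config.
trainer = ModelTrainer.get_model_trainer_from_config(config)
trainer.train()

# Option B (hypothetical keywords): pass custom in-memory
# sleap_io.Labels objects instead of file paths.
train_labels = sio.load_slp("train.pkg.slp")
val_labels = sio.load_slp("val.pkg.slp")
trainer = ModelTrainer.get_model_trainer_from_config(
    config, train_labels=[train_labels], val_labels=[val_labels]
)
trainer.train()
```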
For more details and advanced training options, see the Training Guide.
### 2.2 Training Output

After training, you'll find:

```
my_model_ckpt_dir/my_run_1
├── best.ckpt                 # Best model weights
├── initial_config.yaml       # Initial training configuration
├── training_config.yaml      # Final training configuration
├── labels_train_gt_0.slp     # Ground-truth train data split
├── labels_val_gt_0.slp       # Ground-truth val data split
├── pred_train_0.slp          # Predictions on training data
├── pred_val_0.slp            # Predictions on validation data
├── train_0_pred_metrics.npz  # Metrics on train preds
├── val_0_pred_metrics.npz    # Metrics on val preds
└── training_log.csv          # Tracks the train/val losses and epoch time
```
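To inspect the saved metrics from Python, the `.npz` files can be loaded with NumPy. A small sketch; the exact keys stored depend on the evaluator, so list them before indexing:

```python
import numpy as np

# Metrics are stored as pickled objects, hence allow_pickle=True.
metrics = np.load(
    "my_model_ckpt_dir/my_run_1/val_0_pred_metrics.npz", allow_pickle=True
)
print(metrics.files)  # list the stored keys before indexing into them
```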
## 🔍 Step 3: Running Inference

Now that you have a trained model, let's use it to make predictions on new data!

### 3.1 Inference

You can run inference on a `.slp` file, on specific frames of a video, or on a video with tracking enabled. The sketch below covers all three cases.
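The `run_inference` import path and keyword names here are assumptions that mirror the CLI parameters listed in section 3.2; check the Inference Guide for the exact API.

```python
# NOTE: import path and keywords are assumptions mirroring the CLI flags.
from sleap_nn.predict import run_inference

# 1. Inference on a labels (.slp) file.
run_inference(
    data_path="test.slp",
    model_paths=["my_model_ckpt_dir/my_run_1"],
    output_path="my_predictions.slp",
)

# 2. Inference on specific frames of a video.
run_inference(
    data_path="video.mp4",
    model_paths=["my_model_ckpt_dir/my_run_1"],
    frames="1-100,200-300",  # frame-range syntax as in the CLI
    output_path="video_predictions.slp",
)

# 3. Inference on a video with tracking.
run_inference(
    data_path="video.mp4",
    model_paths=["my_model_ckpt_dir/my_run_1"],
    tracking=True,
    output_path="tracked_predictions.slp",
)
```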
For more details and advanced inference options, see the Inference Guide.
### 3.2 Inference Parameters

#### Essential Parameters

- `--data_path`: Input video or labels file
- `--model_paths`: Path to your trained model directory
- `--output_path`: Where to save predictions
- `--batch_size`: Number of frames to process at once
- `--device`: Hardware to use (`cpu`, `cuda`, `mps`, `auto`)
- `--peak_threshold`: Confidence threshold for detections
- `--frames`: Specific frame ranges (e.g., `"1-100,200-300"`)
- `--tracking`: Enable tracking
## 📊 Step 4: Evaluation and Visualization

Let's evaluate how well your model performed and visualize the results!

### 4.1 Evaluating Model Performance
```python
from sleap_nn.evaluation import Evaluator
import sleap_io as sio

# Load ground-truth and predicted labels.
ground_truth = sio.load_slp("ground_truth.slp")
predictions = sio.load_slp("predictions.slp")

# Create evaluator.
evaluator = Evaluator(ground_truth, predictions)

# Run evaluation.
metrics = evaluator.evaluate()

# Print results.
print(f"OKS mAP: {metrics['voc_metrics']['oks_voc.mAP']:.3f}")
print(f"Dist p90: {metrics['distance_metrics']['p90']:.3f}")
```
### 4.2 Visualizing Results

```python
import sleap_io as sio
import matplotlib.pyplot as plt


def plot_preds(gt_labels, pred_labels, lf_index):
    _fig, _ax = plt.subplots(1, 1, figsize=(5, 5))

    # Get the labeled frame from each set of labels.
    gt_lf = gt_labels[lf_index]
    pred_lf = pred_labels[lf_index]

    # Ensure we're plotting keypoints for the same frame.
    assert (
        gt_lf.frame_idx == pred_lf.frame_idx
    ), f"Frame mismatch at {lf_index}: GT={gt_lf.frame_idx}, Pred={pred_lf.frame_idx}"

    _ax.imshow(gt_lf.image, cmap="gray")
    _ax.set_title(
        f"Frame {gt_lf.frame_idx} (lf idx: {lf_index})",
        fontsize=12,
        fontweight="bold",
    )

    # Plot ground truth instances.
    for idx, instance in enumerate(gt_lf.instances):
        if not instance.is_empty:
            gt_pts = instance.numpy()
            _ax.plot(
                gt_pts[:, 0],
                gt_pts[:, 1],
                "go",
                markersize=6,
                alpha=0.8,
                label="GT" if idx == 0 else "",
            )

    # Plot predicted instances.
    for idx, instance in enumerate(pred_lf.instances):
        if not instance.is_empty:
            pred_pts = instance.numpy()
            _ax.plot(
                pred_pts[:, 0],
                pred_pts[:, 1],
                "rx",
                markersize=6,
                alpha=0.8,
                label="Pred" if idx == 0 else "",
            )

    # Add legend.
    _ax.legend(loc="upper right", fontsize=8)
    _ax.axis("off")

    plt.suptitle("Ground Truth vs Predictions", fontsize=16, fontweight="bold", y=0.98)
    plt.tight_layout()
    plt.show()


# Overlay results.
gt_labels = sio.load_slp("groundtruth.slp")
pred_labels = sio.load_slp("my_predictions.slp")
plot_preds(gt_labels, pred_labels, lf_index=0)
```
### 4.3 Metrics Interpretation

#### Key Metrics to Understand

- **PCK (Percentage of Correct Keypoints)**: How many keypoints fall within a given distance threshold of the ground truth
- **OKS (Object Keypoint Similarity)**: How similar the predicted keypoints are to the ground truth
- **mAP (mean Average Precision)**: Mean of average precisions across match thresholds (where OKS or PCK can serve as the matching score)
- **Distance Metrics**: Average Euclidean distance between predicted and true keypoints
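To make PCK concrete, here is a tiny self-contained example (hypothetical pixel errors, not the sleap-nn API): PCK at a 5-pixel threshold is simply the fraction of keypoints whose error is under 5 pixels.

```python
import numpy as np

# Hypothetical per-keypoint distances between prediction and ground truth.
dists = np.array([1.2, 3.4, 0.8, 7.9, 2.1])  # pixels
threshold = 5.0

# PCK = fraction of keypoints within the threshold.
pck = np.mean(dists < threshold)
print(f"PCK@{threshold:.0f}px: {pck:.2f}")  # -> PCK@5px: 0.80
```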
## ✨ Next Steps

Now that you have the basics, you can:
- Experiment with different model architectures (UNet, ConvNeXt, SwinT)
- Try different detection methods (single instance, bottom-up, top-down)
- Optimize hyperparameters for better performance
- Use data augmentation to improve model robustness
## 📚 Additional Resources
- Configuration Guide: Detailed configuration options
- Training Documentation: Advanced training features
- Inference Guide: Complete inference options
- Model Architectures: Available model types
- Example Notebooks: Interactive tutorials
Happy SLEAPiNNg! 🐭🐭