callbacks
sleap_nn.training.callbacks
Custom callback classes for the Lightning Trainer.
Classes:
Name | Description |
---|---|
CSVLoggerCallback | Callback for logging metrics to a CSV file. |
MatplotlibSaver | Callback for saving images rendered with matplotlib during training. |
ProgressReporterZMQ | Callback to publish training progress events to a ZMQ PUB socket. |
TrainingControllerZMQ | Lightning callback to receive control commands during training via ZMQ. |
WandBPredImageLogger | Callback for writing image predictions to WandB. |
CSVLoggerCallback
Bases: Callback
Callback for logging metrics to a CSV file.
Attributes:
Name | Description |
---|---|
filepath | Path to save the CSV file. |
keys | List of field names to be logged in the CSV file. |
Methods:
Name | Description |
---|---|
__init__ | Initialize attributes. |
on_validation_epoch_end | Log metrics to the CSV file at the end of each validation epoch. |
Source code in sleap_nn/training/callbacks.py
__init__(filepath, keys=['epoch', 'train_loss', 'val_loss', 'learning_rate'])
Initialize attributes.
on_validation_epoch_end(trainer, pl_module)
Log metrics to the CSV file at the end of each validation epoch.
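A minimal sketch of the append-one-row-per-epoch pattern described above, using only the standard library. The helper `append_metrics_row` is hypothetical and is not part of sleap-nn; the `metrics` mapping stands in for whatever values the callback collects at the end of a validation epoch, and the default `keys` mirror the documented `__init__` defaults.

```python
import csv
from pathlib import Path
from typing import Mapping, Sequence, Union

# Mirrors the documented default of CSVLoggerCallback.__init__ (as a tuple to
# avoid a mutable default argument).
DEFAULT_KEYS = ("epoch", "train_loss", "val_loss", "learning_rate")


def append_metrics_row(
    filepath: Union[str, Path],
    metrics: Mapping[str, object],
    keys: Sequence[str] = DEFAULT_KEYS,
) -> None:
    """Append one row of metrics to a CSV file (hypothetical helper).

    The header row is written only when the file does not exist yet or is empty.
    Keys missing from `metrics` are written as empty cells.
    """
    path = Path(filepath)
    write_header = not path.exists() or path.stat().st_size == 0
    with path.open("a", newline="") as f:
        writer = csv.DictWriter(f, fieldnames=list(keys))
        if write_header:
            writer.writeheader()
        writer.writerow({key: metrics.get(key, "") for key in keys})


# Example: call once per validation epoch.
# append_metrics_row("metrics.csv", {"epoch": 0, "train_loss": 0.52, "val_loss": 0.61, "learning_rate": 1e-3})
```

Writing the header only for a new or empty file lets a resumed run keep appending to the same file, and filling absent keys with empty cells avoids failing on an epoch where a metric was not computed.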
MatplotlibSaver
Bases: Callback
Callback for saving images rendered with matplotlib during training.
This is useful for saving visualizations of training progress to disk. It is called at the end of each training epoch.
Attributes:
Name | Description |
---|---|
plot_fn | Function with no arguments that returns a matplotlib figure handle. |
save_folder | Path to a directory to save images to. |
prefix | String that will be prepended to the filenames. This is useful for indicating which dataset the visualization was sampled from. |
Notes
Images are saved with the naming pattern "{save_folder}/{prefix}.{epoch}.png", or "{save_folder}/{epoch}.png" if no prefix is specified.
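The naming pattern in the note can be sketched as follows. `figure_path` and `save_epoch_figure` are hypothetical helpers, not sleap-nn functions; `plot_fn` is assumed to return a matplotlib `Figure` (or any object with a `savefig` method), and the epoch number is inserted without zero-padding, exactly as the pattern above is written.

```python
from pathlib import Path
from typing import Callable, Optional, Union


def figure_path(save_folder: Union[str, Path], epoch: int, prefix: Optional[str] = None) -> Path:
    """Return "{save_folder}/{prefix}.{epoch}.png", or "{save_folder}/{epoch}.png" without a prefix."""
    name = f"{prefix}.{epoch}.png" if prefix else f"{epoch}.png"
    return Path(save_folder) / name


def save_epoch_figure(
    plot_fn: Callable[[], object],
    save_folder: Union[str, Path],
    epoch: int,
    prefix: Optional[str] = None,
) -> Path:
    """Render a figure with `plot_fn` and save it for this epoch (hypothetical helper)."""
    path = figure_path(save_folder, epoch, prefix)
    path.parent.mkdir(parents=True, exist_ok=True)
    fig = plot_fn()  # assumed to return a matplotlib Figure (or anything with .savefig)
    fig.savefig(path, format="png")
    # With a real matplotlib Figure, also call matplotlib.pyplot.close(fig)
    # so figures do not accumulate in memory across epochs.
    return path
```

Keeping `plot_fn` argument-free, as the attribute table describes, means the saver does not need to know what is being plotted; the caller closes over whatever data the figure needs.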
Methods:
Name | Description |
---|---|
__init__ | Initialize callback. |
on_train_epoch_end | Save figure at the end of each epoch. |
__init__(save_folder, plot_fn, prefix=None)
Initialize callback.
on_train_epoch_end(trainer, pl_module)
Save figure at the end of each epoch.
ProgressReporterZMQ
Bases: Callback
Callback to publish training progress events to a ZMQ PUB socket.
This is used to publish training metrics to the given socket.
Attributes:
Name | Description |
---|---|
address | The ZMQ address to publish to, e.g., "tcp://127.0.0.1:9001". |
what | Identifier tag for the type of training job (e.g., model name or job type). |
Methods:
Name | Description |
---|---|
__del__ | Close the ZMQ socket and context when the callback is destroyed. |
__init__ | Initialize the progress reporter callback by connecting to the specified ZMQ PUB socket. |
on_train_batch_end | Called at the end of each training batch. |
on_train_batch_start | Called at the beginning of each training batch. |
on_train_end | Called at the end of the training process. |
on_train_epoch_end | Called at the end of each epoch. |
on_train_epoch_start | Called at the beginning of each epoch. |
on_train_start | Called at the beginning of the training process. |
send | Send a message over ZMQ. |
__del__()
Close the ZMQ socket and context when the callback is destroyed.
__init__(address='tcp://127.0.0.1:9001', what='')
Initialize the progress reporter callback by connecting to the specified ZMQ PUB socket.
on_train_batch_end(trainer, pl_module, outputs, batch, batch_idx)
Called at the end of each training batch.
on_train_batch_start(trainer, pl_module, batch, batch_idx)
Called at the beginning of each training batch.
on_train_end(trainer, pl_module)
Called at the end of the training process.
on_train_epoch_end(trainer, pl_module)
Called at the end of each epoch.
on_train_epoch_start(trainer, pl_module)
Called at the beginning of each epoch.
on_train_start(trainer, pl_module)
Called at the beginning of the training process.
send(event, logs=None, **kwargs)
Send a message over ZMQ.
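A sketch of how a progress event might be serialized before it is published. The JSON layout (`what`, `event`, `logs`, plus any extra keyword arguments) is an assumption inferred from the `send(event, logs=None, **kwargs)` signature and the `what` attribute; sleap-nn's actual wire format may differ, and the event names used here are illustrative only. The pyzmq calls in the trailing comment are shown for context and are not executed.

```python
import json
from typing import Any, Dict, Optional


def encode_progress_message(
    what: str, event: str, logs: Optional[Dict[str, Any]] = None, **kwargs: Any
) -> str:
    """Serialize one training progress event as a JSON string (assumed layout)."""
    # Values in `logs` must be JSON-serializable, e.g. tensors converted with float().
    msg = {"what": what, "event": event, "logs": logs, **kwargs}
    return json.dumps(msg)


# Publishing with pyzmq (not executed here):
#   import zmq
#   context = zmq.Context()
#   socket = context.socket(zmq.PUB)
#   socket.connect("tcp://127.0.0.1:9001")
#   socket.send_string(encode_progress_message("my_model", "epoch_end", {"val_loss": 0.61}, epoch=3))
```

Because ZMQ PUB/SUB is fire-and-forget, events published before a subscriber has connected are simply dropped, so a listener should be running before training starts.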
TrainingControllerZMQ
Bases: Callback
Lightning callback to receive control commands during training via ZMQ.
This is typically used to let the SLEAP GUI (SLEAP LossViewer) control the training process dynamically, e.g., stopping training early, by publishing commands to a ZMQ socket.
Attributes:
Name | Description |
---|---|
address | ZMQ socket address to subscribe to. |
topic | Topic filter for messages. |
timeout | Poll timeout in milliseconds when checking for new messages. |
Methods:
Name | Description |
---|---|
__del__ | Close the ZMQ socket and context when the callback is destroyed. |
__init__ | Initialize the controller callback by connecting to the specified ZMQ PUB socket. |
on_train_batch_end | Called at the end of each training batch. |
__del__()
Close the ZMQ socket and context when the callback is destroyed.
__init__(address='tcp://127.0.0.1:9000', topic='', poll_timeout=10)
Initialize the controller callback by connecting to the specified ZMQ PUB socket.
on_train_batch_end(trainer, pl_module, outputs, batch, batch_idx)
Called at the end of each training batch.
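A sketch of the receiving side: parse a control message and, for a stop command, set Lightning's `trainer.should_stop` flag, which asks the trainer to stop training. The `{"command": "stop"}` schema and the helper `handle_control_message` are assumptions for illustration, not the documented sleap-nn protocol; the commented pyzmq lines only show where `topic` and `poll_timeout` (in milliseconds) would plausibly fit.

```python
import json
from typing import Any


def handle_control_message(raw: str, trainer: Any) -> bool:
    """Apply one control message to a Lightning-style trainer (hypothetical helper).

    Assumes messages are JSON objects such as {"command": "stop"}.
    Returns True if the message was recognized and applied.
    """
    try:
        msg = json.loads(raw)
    except json.JSONDecodeError:
        return False
    if isinstance(msg, dict) and msg.get("command") == "stop":
        # Lightning checks this flag in its fit loop and stops training.
        trainer.should_stop = True
        return True
    return False


# Setup (e.g., in __init__) and polling (e.g., in on_train_batch_end) with pyzmq,
# not executed here:
#   import zmq
#   context = zmq.Context()
#   socket = context.socket(zmq.SUB)
#   socket.setsockopt_string(zmq.SUBSCRIBE, topic)  # "" receives every message
#   socket.connect(address)
#   if socket.poll(timeout=poll_timeout):  # milliseconds; 0 means nothing waiting
#       handle_control_message(socket.recv_string(), trainer)
```

Returning a flag instead of raising keeps malformed or unknown messages from interrupting training, and a short poll timeout keeps the per-batch check cheap.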
WandBPredImageLogger
Bases: Callback
Callback for writing image predictions to WandB.
Attributes:
Name | Description |
---|---|
viz_folder | Path to the visualization directory. |
wandb_run_name | WandB run name. |
is_bottomup | Whether the model is a bottom-up model. |
Methods:
Name | Description |
---|---|
__init__ | Initialize attributes. |
on_train_epoch_end | Called at the end of each epoch. |
__init__(viz_folder, wandb_run_name, is_bottomup=False)
Initialize attributes.
on_train_epoch_end(trainer, pl_module)
Called at the end of each epoch.