utils
sleap_nn.training.utils
Miscellaneous utility functions for training.
Functions:
Name | Description |
---|---|
get_dist_rank | Return the rank of the current process if torch.distributed is initialized. |
get_gpu_memory | Get the available memory on each GPU. |
imgfig | Create a tight figure for image plotting. |
is_distributed_initialized | Check if distributed processes are initialized. |
plot_confmaps | Plot confidence maps reduced over channels. |
plot_img | Plot an image in a tight figure. |
plot_peaks | Plot ground truth and detected peaks. |
xavier_init_weights | Initialize the model weights with the Xavier initialization method. |
get_dist_rank()
Return the rank of the current process if torch.distributed is initialized.
get_gpu_memory()
Get the available memory on each GPU.
Returns:
Type | Description |
---|---|
List[int] | A list of the available memory on each GPU in MiB. |
Source code in sleap_nn/training/utils.py
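This page does not show how the per-GPU values are collected. The following is a minimal sketch, assuming they are read from the nvidia-smi command-line tool; `parse_free_memory` and `query_free_gpu_memory` are hypothetical helper names and are not part of sleap_nn.

```python
import subprocess
from typing import List


def parse_free_memory(csv_text: str) -> List[int]:
    """Parse one free-memory value (in MiB) per line into a list of ints."""
    return [int(line.strip()) for line in csv_text.splitlines() if line.strip()]


def query_free_gpu_memory() -> List[int]:
    """Ask nvidia-smi for the free memory on each GPU.

    Returns an empty list if nvidia-smi is missing, fails, or prints
    something that is not a plain integer.
    """
    cmd = [
        "nvidia-smi",
        "--query-gpu=memory.free",
        "--format=csv,noheader,nounits",  # bare numbers, one GPU per line
    ]
    try:
        result = subprocess.run(cmd, capture_output=True, text=True, check=True)
        return parse_free_memory(result.stdout)
    except (OSError, subprocess.CalledProcessError, ValueError):
        return []


# Parsing a hypothetical two-GPU output:
print(parse_free_memory("10240\n8192\n"))  # [10240, 8192]
```

Keeping the parsing step separate from the subprocess call makes the sketch easy to check on a machine without a GPU.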
imgfig(size=6, dpi=72, scale=1.0)
Create a tight figure for image plotting.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
size | float \| tuple | Scalar or 2-tuple specifying the (width, height) of the figure in inches. If scalar, will assume equal width and height. | 6 |
dpi | int | Dots per inch, controlling the resolution of the image. | 72 |
scale | float | Factor to scale the size of the figure by. This is a convenience for increasing the size of the plot at the same DPI. | 1.0 |
Returns:
Type | Description |
---|---|
Figure | A matplotlib.figure.Figure to use for plotting. |
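To make the interaction of size, dpi and scale concrete, here is a small arithmetic sketch. It assumes that scale multiplies the size in inches and that the pixel dimensions are inches times dpi; `figure_size` is a hypothetical helper written for illustration, not the library's implementation.

```python
from typing import Tuple, Union

Size = Union[float, Tuple[float, float]]


def figure_size(size: Size = 6, dpi: int = 72, scale: float = 1.0):
    """Return ((width_in, height_in), (width_px, height_px))."""
    # A scalar size means equal width and height.
    if isinstance(size, (int, float)):
        size = (size, size)
    # Assumption: scale multiplies the size in inches, leaving dpi unchanged.
    width_in, height_in = (float(s) * scale for s in size)
    pixels = (round(width_in * dpi), round(height_in * dpi))
    return (width_in, height_in), pixels


print(figure_size())                 # ((6.0, 6.0), (432, 432))
print(figure_size((8, 4), scale=2))  # ((16.0, 8.0), (1152, 576))
```

With the defaults, a 6 inch figure at 72 DPI corresponds to 432 by 432 pixels. Doubling scale doubles both dimensions without changing the DPI.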
is_distributed_initialized()
Check if distributed processes are initialized.
plot_confmaps(confmaps, output_scale=1.0)
Plot confidence maps reduced over channels.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
confmaps | ndarray | Confidence maps to plot with shape (height, width, channel). | required |
output_scale | float | Factor to scale the size of the figure by. | 1.0 |
Returns:
Type | Description |
---|---|
A matplotlib.figure.Figure to use for plotting. |
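The description does not say which reduction is applied over the channel axis. As a rough sketch, assuming a per-pixel maximum (a common choice for overlaying confidence maps), the reduction turns a (height, width, channel) array into a single 2D map. The array contents below are made up for illustration.

```python
import numpy as np

# Hypothetical confidence maps: 4 x 5 pixels, 3 channels (one per node).
confmaps = np.zeros((4, 5, 3), dtype=np.float32)
confmaps[1, 2, 0] = 0.9  # peak for node 0
confmaps[3, 4, 1] = 0.7  # peak for node 1
confmaps[0, 0, 2] = 0.5  # peak for node 2

# Assumption: "reduced over channels" is a max over the last axis.
reduced = confmaps.max(axis=-1)

print(reduced.shape)  # (4, 5)
```

Every node's peak survives in the reduced map, so one image can show all nodes at once.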
plot_img(img, dpi=72, scale=1.0)
Plot an image in a tight figure.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
img | ndarray | Image to plot of shape (height, width, channel). | required |
dpi | int | Dots per inch, controlling the resolution of the image. | 72 |
scale | float | Factor to scale the size of the figure by. This is a convenience for increasing the size of the plot at the same DPI. | 1.0 |
Returns:
Type | Description |
---|---|
Figure | A matplotlib.figure.Figure to use for plotting. |
plot_peaks(pts_gt, pts_pr=None, paired=False)
Plot ground truth and detected peaks.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
pts_gt | ndarray | Ground-truth keypoints of shape (num_instances, nodes, 2). To plot centroids, use shape (1, num_instances, 2). | required |
pts_pr | ndarray \| None | Predicted keypoints of shape (num_instances, nodes, 2). To plot centroids, use shape (1, num_instances, 2). | None |
paired | bool | True to plot error lines, False otherwise. | False |
Returns:
Type | Description |
---|---|
A matplotlib.figure.Figure to use for plotting. |
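The shape conventions above can be illustrated with a short sketch. It assumes that an error line joins each ground-truth keypoint to its paired prediction, so its length is the Euclidean distance between the two. All coordinates are made up for illustration.

```python
import numpy as np

# Hypothetical data: 2 instances, 3 nodes each, (x, y) coordinates.
pts_gt = np.array(
    [
        [[10.0, 10.0], [20.0, 15.0], [30.0, 20.0]],
        [[50.0, 40.0], [60.0, 45.0], [70.0, 50.0]],
    ]
)
# Every prediction is offset from its ground truth by (3, 4).
pts_pr = pts_gt + np.array([3.0, 4.0])

# Assumption: with paired points, the error for each keypoint is the
# distance between the ground-truth and predicted locations.
errors = np.linalg.norm(pts_pr - pts_gt, axis=-1)
print(errors.shape)  # (2, 3)

# Centroids are one (x, y) point per instance, i.e. (num_instances, 2).
# Adding a leading axis gives the documented (1, num_instances, 2) shape.
centroids = pts_gt.mean(axis=1)
centroids_for_plot = centroids[None, ...]
print(centroids_for_plot.shape)  # (1, 2, 2)
```

Here the instance mean stands in for a centroid only to produce an array of the right shape. How centroids are actually computed is outside the scope of this page.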
xavier_init_weights(x)
Initialize the model weights with the Xavier initialization method.
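For reference, Xavier (Glorot) initialization scales the initial weights by the layer's fan-in and fan-out. This page does not say whether the uniform or normal variant is used, so the sketch below shows the uniform variant in plain Python as an illustration only. `xavier_uniform_bound` and `xavier_uniform` are hypothetical names; in PyTorch the equivalent built-in is `torch.nn.init.xavier_uniform_`.

```python
import math
import random
from typing import List


def xavier_uniform_bound(fan_in: int, fan_out: int, gain: float = 1.0) -> float:
    """Bound a of the U(-a, a) distribution used by Xavier uniform init."""
    return gain * math.sqrt(6.0 / (fan_in + fan_out))


def xavier_uniform(fan_in: int, fan_out: int, rng: random.Random) -> List[List[float]]:
    """Sample a (fan_out x fan_in) weight matrix from U(-a, a)."""
    bound = xavier_uniform_bound(fan_in, fan_out)
    return [
        [rng.uniform(-bound, bound) for _ in range(fan_in)]
        for _ in range(fan_out)
    ]


rng = random.Random(0)  # seeded for reproducibility
weights = xavier_uniform(fan_in=4, fan_out=2, rng=rng)

print(xavier_uniform_bound(4, 2))  # 1.0
```

Wider layers give a smaller bound, which keeps the variance of activations roughly constant from layer to layer.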