
identity

sleap_nn.data.identity

Utilities for generating data for track identity models.

Functions:

| Name | Description |
| --- | --- |
| generate_class_maps | Generate class maps from track indices. |
| make_class_maps | Generate identity class maps using instance-wise confidence maps. |
| make_class_vectors | Make binary class vectors from class indices. |

generate_class_maps(instances, img_hw, num_instances, class_inds, num_tracks, class_map_threshold=0.2, sigma=1.5, output_stride=2, is_centroids=False)

Generate class maps from track indices.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| instances | Tensor | Input keypoints of shape (n_samples=1, n_instances, n_nodes, 2), or (n_samples=1, n_instances, 2) for centroids. | required |
| img_hw | Tuple[int] | Image size as a tuple (height, width). | required |
| num_instances | int | Original number of instances in the frame; only the first num_instances entries of instances are used. | required |
| class_inds | Tensor | Class indices as a torch.int32 tensor of shape (n_instances,). | required |
| num_tracks | int | Total number of tracks in the dataset. | required |
| class_map_threshold | float | Minimum confidence map value; map values at or below it are replaced with zeros. | 0.2 |
| sigma | float | Standard deviation of the Gaussian distribution used to generate the confidence maps. | 1.5 |
| output_stride | int | Relative stride used when generating confidence maps. A larger stride produces smaller confidence maps. | 2 |
| is_centroids | bool | True if confidence maps should be generated for centroids, else False. | False |
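A rough way to see how sigma, output_stride, and class_map_threshold interact: generate_class_maps passes sigma * output_stride to make_multi_confmaps, so the Gaussian width grows with the stride. The sketch below assumes the confidence maps are unnormalized Gaussians exp(-d² / (2s²)) with s = sigma * output_stride in full-resolution pixels, and that the output grid samples every output_stride pixels starting at 0. Both are assumptions about the helper functions, not statements from this page, and the helper names are hypothetical. Under these assumptions a pixel survives the threshold only within a radius of s * sqrt(-2 ln t) around a keypoint.

```python
import math


def mask_radius_px(sigma: float = 1.5, output_stride: int = 2, threshold: float = 0.2) -> float:
    """Hypothetical helper: radius (full-resolution pixels) inside which an
    assumed unnormalized Gaussian exp(-d^2 / (2 s^2)), s = sigma * output_stride,
    stays above `threshold`."""
    s = sigma * output_stride
    return s * math.sqrt(-2.0 * math.log(threshold))


def grid_shape(height: int, width: int, output_stride: int = 2) -> tuple:
    """Hypothetical helper: output grid size, assuming samples at
    0, stride, 2 * stride, ... (the same count as an arange with that step)."""
    return (math.ceil(height / output_stride), math.ceil(width / output_stride))


radius = mask_radius_px()
print(f"mask radius: {radius:.2f} px, about {radius / 2:.2f} grid cells at stride 2")
print("grid for a 384x384 image:", grid_shape(384, 384))
```

With the defaults the radius comes out to roughly 5.4 pixels, or about 2.7 cells on the stride-2 grid, so each instance's class region is a small disk (or a union of disks, one per node) around its keypoints.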
Source code in sleap_nn/data/identity.py
def generate_class_maps(
    instances: torch.Tensor,
    img_hw: Tuple[int],
    num_instances: int,
    class_inds: torch.Tensor,
    num_tracks: int,
    class_map_threshold: float = 0.2,
    sigma: float = 1.5,
    output_stride: int = 2,
    is_centroids: bool = False,
):
    """Generate class maps from track indices.

    Args:
        instances: Input keypoints of shape `(n_samples=1, n_instances, n_nodes, 2)`,
            or `(n_samples=1, n_instances, 2)` for centroids.
        img_hw: Image size as a tuple `(height, width)`.
        num_instances: Original number of instances in the frame.
        class_inds: Class indices as a `torch.int32` tensor of shape `(n_instances,)`.
        num_tracks: Total number of tracks in the dataset.
        class_map_threshold: Minimum confidence map value; map values at or below it
            are replaced with zeros.
        sigma: The standard deviation of the Gaussian distribution that is used to
            generate confidence maps. Default: 1.5.
        output_stride: The relative stride to use when generating confidence maps.
            A larger stride will generate smaller confidence maps. Default: 2.
        is_centroids: True if confidence maps should be generated for centroids,
            else False. Default: False.

    Returns:
        Class maps of shape
        `(n_samples=1, num_tracks, height / output_stride, width / output_stride)`.

    """
    height, width = img_hw
    xv, yv = make_grid_vectors(height, width, output_stride)

    if is_centroids:
        points = instances[:, :num_instances, :].unsqueeze(dim=-3)
        # (n_samples=1, 1, n_instances, 2)
    else:
        points = instances[:, :num_instances, :, :].permute(
            0, 2, 1, 3
        )  # (n_samples=1, n_nodes, n_instances, 2)

    # Generate confidence maps for masking.
    cms = make_multi_confmaps(
        points, xv, yv, sigma * output_stride
    )  # (n_samples=1, n_instances, height/ output_stride, width/ output_stride).

    class_maps = make_class_maps(
        cms,
        class_inds=class_inds,
        n_classes=num_tracks,
        threshold=class_map_threshold,
    )  # (n_samples=1, n_classes, height/ output_stride, width/ output_stride)
    return class_maps
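The two is_centroids branches differ only in how the points are laid out before make_multi_confmaps. Both produce a tensor of shape (n_samples, k, n_instances, 2) with instances on the third axis, and according to the shape comments in the source the resulting confidence maps then have one channel per instance. The sketch below reproduces only the slicing and permutation on random tensors; the sizes (including four padded instance slots of which two are used) are arbitrary illustrative values.

```python
import torch

n_samples, max_instances, n_nodes = 1, 4, 3
num_instances = 2  # illustrative: only the first 2 slots hold real animals

# Pose branch (is_centroids=False):
# (1, n_instances, n_nodes, 2) -> (1, n_nodes, n_instances, 2)
instances = torch.rand(n_samples, max_instances, n_nodes, 2)
pose_points = instances[:, :num_instances, :, :].permute(0, 2, 1, 3)

# Centroid branch (is_centroids=True):
# (1, n_instances, 2) -> (1, 1, n_instances, 2)
centroids = torch.rand(n_samples, max_instances, 2)
centroid_points = centroids[:, :num_instances, :].unsqueeze(dim=-3)

print(tuple(pose_points.shape))      # (1, 3, 2, 2)
print(tuple(centroid_points.shape))  # (1, 1, 2, 2)
```

Slicing with num_instances before permuting is what drops the padded slots, so class_inds is expected to line up with the first num_instances instances.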

make_class_maps(confmaps, class_inds, n_classes, threshold=0.2)

Generate identity class maps using instance-wise confidence maps.

This is useful for making class maps defined on local neighborhoods around the peaks.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| confmaps | Tensor | Instance-wise confidence maps as a torch.Tensor of shape (n_samples=1, n_instances, grid_height, grid_width). These can be generated by sleap_nn.data.confidence_maps.make_confmaps. | required |
| class_inds | Tensor | Class indices as a torch.int32 tensor of shape (n_instances,). | required |
| n_classes | int | Maximum number of classes. | required |
| threshold | float | Minimum confidence map value; map values at or below it are replaced with zeros. | 0.2 |

Returns:

| Type | Description |
| --- | --- |
| Tensor | The class maps with shape (n_samples=1, n_classes, grid_height, grid_width) and dtype torch.float32. Each channel is an (approximately) binary mask: close to 1 where the confidence map of an instance of that class exceeds the threshold, and 0 elsewhere (see Notes for pixels shared by several instances). |

Notes

Pixels that have confidence map values from more than one animal will have the class vectors weighted by the relative contribution of each instance.
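A small numeric illustration of that weighting, reproducing only the normalization and thresholding steps of make_class_maps on two made-up confidence values per pixel. Because the denominator sums the raw confidence maps before thresholding, a nearby instance that falls below the threshold still lowers the weight of the instance that passes it.

```python
import torch

# Made-up confidence values for 2 instances at 2 pixels,
# shaped (n_samples=1, n_instances=2, H=1, W=2).
# Pixel 0: both instances above threshold. Pixel 1: instance 1 below it.
confmaps = torch.tensor([[0.6, 0.6], [0.3, 0.15]]).reshape(1, 2, 1, 2)
threshold = 0.2

# Same steps as make_class_maps: normalize over instances, then zero
# out values whose raw confidence is at or below the threshold.
weights = confmaps / confmaps.sum(dim=-3, keepdim=True)
weights = torch.where(confmaps > threshold, weights, torch.zeros_like(weights))

print(weights[0, :, 0, 0])  # approximately [0.667, 0.333]
print(weights[0, :, 0, 1])  # approximately [0.8, 0.0]
```

At pixel 1 the surviving instance gets 0.6 / 0.75 = 0.8 rather than 1, which is why the class maps are only approximately binary near other animals.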

Source code in sleap_nn/data/identity.py
def make_class_maps(
    confmaps: torch.Tensor,
    class_inds: torch.Tensor,
    n_classes: int,
    threshold: float = 0.2,
) -> torch.Tensor:
    """Generate identity class maps using instance-wise confidence maps.

    This is useful for making class maps defined on local neighborhoods around the
    peaks.

    Args:
        confmaps: Instance-wise confidence maps as a `torch.Tensor` of shape
            `(n_samples=1, n_instances, grid_height, grid_width)`. These can be
            generated by `sleap_nn.data.confidence_maps.make_confmaps`.
        class_inds: Class indices as a `torch.int32` tensor of shape `(n_instances,)`.
        n_classes: Maximum number of classes.
        threshold: Minimum confidence map value; map values at or below it are
            replaced with zeros.

    Returns:
        The class maps with shape `(n_samples=1, n_classes, grid_height, grid_width)`
        and dtype `torch.float32`, where each channel is an (approximately) binary
        mask with 1 where the instance confidence maps were higher than the threshold.

    Notes:
        Pixels that have confidence map values from more than one animal will have
        the class vectors weighted by the relative contribution of each instance.

    """
    n_instances = confmaps.shape[-3]
    class_vectors = make_class_vectors(class_inds, n_classes)
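    # make_class_vectors returns (n_instances, n_classes). Transpose it (a plain
    # reshape would scramble the instance/class pairing), then add singleton
    # spatial dims: (n_classes, n_instances, 1, 1) broadcasts against the masks.
    class_vectors = (
        class_vectors.to(torch.float32)
        .transpose(0, 1)
        .reshape(n_classes, n_instances, 1, 1)
    )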

    # Normalize instance mask
    mask = confmaps / torch.sum(confmaps, dim=-3, keepdim=True)
    mask = torch.where(
        confmaps > threshold,
        mask,
        torch.tensor(0.0, dtype=mask.dtype, device=mask.device),
    )  # (1, num_instances, H, W)

    # Apply mask to vectors and reduce over instances
    class_maps = torch.max(mask * class_vectors, dim=-3).values

    return class_maps.unsqueeze(0)  # (n_samples=1, n_classes, H, W)
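make_class_vectors returns an (n_instances, n_classes) tensor, while the class maps need the class-major layout (n_classes, n_instances) so that each instance mask is multiplied by its own class vector. Reshaping to that shape is not the same as transposing: torch.reshape keeps the row-major element order and only reinterprets it, so the two agree only in special cases such as a single instance. The sketch compares both on a made-up assignment of two instances to three classes.

```python
import torch
import torch.nn.functional as F

# Made-up assignment: instance 0 -> class 1, instance 1 -> class 0.
class_inds = torch.tensor([1, 0])
n_classes = 3
vectors = F.one_hot(class_inds, num_classes=n_classes)  # (n_instances=2, n_classes=3)

reshaped = vectors.reshape(n_classes, 2)  # reinterprets memory order only
transposed = vectors.transpose(0, 1)      # row c lists which instances have class c

print(reshaped.tolist())    # [[0, 1], [0, 1], [0, 0]]  (class 0 -> instance 1 only by accident)
print(transposed.tolist())  # [[0, 1], [1, 0], [0, 0]]  (class 0 -> instance 1, class 1 -> instance 0)
```

In the reshaped version class 1 is paired with instance 1, which has class 0, so its map would light up around the wrong animal.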

make_class_vectors(class_inds, n_classes)

Make binary class vectors from class indices.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| class_inds | Tensor | Class indices as a torch.Tensor of dtype torch.int32 and shape (n_instances,). Indices of -1 (or any negative value) are interpreted as having no class. | required |
| n_classes | int | Maximum number of classes. | required |

Returns:

| Type | Description |
| --- | --- |
| Tensor | A tensor of binary class vectors with shape (n_instances, n_classes) and dtype torch.int32. Instances with no class have all zeros in their row. |

Notes: A class index can be used to represent a track index.
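The masking in make_class_vectors is needed because torch.nn.functional.one_hot rejects negative indices, so -1 entries are temporarily replaced with a valid index and their rows zeroed afterwards. The sketch below mirrors those steps, using torch.where instead of an in-place clone; class_vectors_sketch is a hypothetical name, not part of sleap_nn.

```python
import torch
import torch.nn.functional as F


def class_vectors_sketch(class_inds: torch.Tensor, n_classes: int) -> torch.Tensor:
    """Hypothetical mirror of make_class_vectors: one-hot encode, zero rows for negatives."""
    valid = class_inds >= 0
    # Replace negative (no-class) entries with 0 so one_hot accepts them.
    safe = torch.where(valid, class_inds, torch.zeros_like(class_inds))
    one_hot = F.one_hot(safe.long(), num_classes=n_classes)
    one_hot[~valid] = 0  # rows of unassigned instances become all zeros
    return one_hot.to(torch.int32)


vecs = class_vectors_sketch(torch.tensor([2, -1, 0], dtype=torch.int32), n_classes=3)
print(vecs.tolist())  # [[0, 0, 1], [0, 0, 0], [1, 0, 0]]
```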

Source code in sleap_nn/data/identity.py
def make_class_vectors(class_inds: torch.Tensor, n_classes: int) -> torch.Tensor:
    """Make binary class vectors from class indices.

    Args:
        class_inds: Class indices as `torch.Tensor` of dtype `torch.int32` and shape
            `(n_instances,)`. Indices of `-1` will be interpreted as having no class.
        n_classes: Maximum number of classes.

    Returns:
        A tensor with binary class vectors of shape `(n_instances, n_classes)` of dtype
        `torch.int32`. Instances with no class will have all zeros in their row.

    Notes: A class index can be used to represent a track index.
    """
    # Create mask of valid IDs
    mask = class_inds >= 0
    class_inds_masked = class_inds.clone()
    class_inds_masked[~mask] = 0

    one_hot = F.one_hot(class_inds_masked.long(), num_classes=n_classes)
    one_hot[~mask] = 0  # zero out invalids
    return one_hot.to(torch.int32)
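To see the pieces end to end, the sketch below re-implements the centroid path (is_centroids=True) in a single hypothetical function, toy_class_maps, without importing sleap_nn. It assumes unnormalized Gaussian confidence maps exp(-d² / (2s²)) with s = sigma * stride, evaluated on a grid sampled every stride pixels from 0; the real pipeline goes through make_grid_vectors and make_multi_confmaps, whose exact behavior may differ. It also drops the leading n_samples axis. Two centroids with swapped class indices show that each class channel lights up only around its own instance.

```python
import torch
import torch.nn.functional as F


def toy_class_maps(centroids, class_inds, n_classes, img_hw, sigma=1.5, stride=2, threshold=0.2):
    """Hypothetical sketch of generate_class_maps -> make_class_maps for centroids.

    centroids: (n_instances, 2) as (x, y) in full-resolution pixels.
    Returns: (n_classes, grid_height, grid_width).
    """
    h, w = img_hw
    xv = torch.arange(0, w, stride, dtype=torch.float32)  # assumed grid sampling
    yv = torch.arange(0, h, stride, dtype=torch.float32)
    s = sigma * stride
    dx = xv[None, None, :] - centroids[:, 0, None, None]  # (n_inst, 1, W')
    dy = yv[None, :, None] - centroids[:, 1, None, None]  # (n_inst, H', 1)
    cms = torch.exp(-(dx**2 + dy**2) / (2 * s**2))        # (n_inst, H', W')

    # Normalize over instances, then threshold on the raw confidence.
    weights = cms / cms.sum(dim=0, keepdim=True)
    weights = torch.where(cms > threshold, weights, torch.zeros_like(weights))

    # One-hot class vectors in class-major layout (n_classes, n_inst, 1, 1).
    valid = class_inds >= 0
    safe = torch.where(valid, class_inds, torch.zeros_like(class_inds))
    onehot = F.one_hot(safe.long(), n_classes).float()
    onehot[~valid] = 0
    vectors = onehot.transpose(0, 1)[:, :, None, None]

    # Reduce over instances with max, as make_class_maps does.
    return (weights[None] * vectors).amax(dim=1)


maps = toy_class_maps(
    centroids=torch.tensor([[2.0, 2.0], [12.0, 12.0]]),
    class_inds=torch.tensor([1, 0]),  # instance 0 -> class 1, instance 1 -> class 0
    n_classes=2,
    img_hw=(16, 16),
)
print(tuple(maps.shape))  # (2, 8, 8)
```

Grid cell (1, 1) sits on the first centroid, so maps[1, 1, 1] is close to 1 and maps[0, 1, 1] is 0; cell (6, 6) shows the reverse for the second centroid.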