identity
sleap_nn.inference.identity
Utilities for models that learn identity.
These functions implement the inference logic for classifying peaks using class maps or classification vectors.
Functions:
Name | Description |
---|---|
classify_peaks_from_maps | Classify and group local peaks by their class map probability. |
get_class_inds_from_vectors | Get class indices from the probability scores. |
group_class_peaks | Group local peaks using class probabilities, matching peaks to classes using the Hungarian algorithm, per (sample, channel) pair. |
classify_peaks_from_maps(class_maps, peak_points, peak_vals, peak_sample_inds, peak_channel_inds, n_channels)
Classify and group local peaks by their class map probability.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
class_maps | Tensor | Class maps with shape | required |
peak_points | Tensor | Local peak coordinates of shape | required |
peak_vals | Tensor | Confidence map value with shape | required |
peak_sample_inds | Tensor | Sample index for each peak with shape | required |
peak_channel_inds | Tensor | Channel index for each peak with shape | required |
n_channels | int | Integer number of channels (nodes) the instances should have. | required |
Returns:
Type | Description |
---|---|
Tuple[Tensor, Tensor, Tensor] | A tuple of |
See also: group_class_peaks
Source code in sleap_nn/inference/identity.py
get_class_inds_from_vectors(peak_class_probs)
Get class indices from the probability scores.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
peak_class_probs | Tensor | (n_samples, n_classes) softmax output for each sample | required |
Returns:
Name | Type | Description |
---|---|---|
class_inds | | (n_samples,) class index assigned to each sample |
class_probs | | (n_samples,) the probability of the assigned class |
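To make the input and output shapes concrete, here is a minimal sketch that assumes the simplest possible rule, a per-row argmax. The text above does not say how the function resolves duplicate assignments, so the real implementation may differ. The name `class_inds_from_vectors_sketch` is hypothetical, and plain Python lists stand in for tensors.

```python
def class_inds_from_vectors_sketch(peak_class_probs):
    """Pick the most probable class for each sample (assumed argmax rule).

    peak_class_probs: list of n_samples rows, each with n_classes probabilities.
    Returns (class_inds, class_probs), each of length n_samples.
    """
    class_inds, class_probs = [], []
    for row in peak_class_probs:
        # Index of the largest probability in this row
        best = max(range(len(row)), key=row.__getitem__)
        class_inds.append(best)
        class_probs.append(row[best])
    return class_inds, class_probs


probs = [[0.1, 0.9], [0.6, 0.4]]
class_inds, class_probs = class_inds_from_vectors_sketch(probs)
```

Both outputs have one entry per sample, matching the (n_samples,) shapes listed in the returns table.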
group_class_peaks(peak_class_probs, peak_sample_inds, peak_channel_inds, n_samples, n_channels)
Group local peaks using class probabilities, matching peaks to classes using the Hungarian algorithm, per (sample, channel) pair.
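The per-(sample, channel) matching described above can be illustrated with a small standard-library sketch. It is not the library's implementation: an exhaustive search over permutations stands in for a Hungarian solver (a real implementation would more plausibly call something like `scipy.optimize.linear_sum_assignment`), plain lists stand in for tensors, and the names `best_assignment` and `group_peaks_sketch` and the returned pair of index lists are assumptions made for illustration. The `n_samples` and `n_channels` arguments are omitted because the sketch derives the groups from the index lists directly.

```python
from collections import defaultdict
from itertools import permutations


def best_assignment(probs):
    """Return (row, col) pairs that maximize the summed probability.

    Exhaustive stand-in for a Hungarian solver; only practical for small inputs.
    """
    n_rows, n_cols = len(probs), len(probs[0])
    if n_rows <= n_cols:
        # Every peak (row) gets a distinct class (column)
        candidates = (
            list(zip(range(n_rows), cols))
            for cols in permutations(range(n_cols), n_rows)
        )
    else:
        # More peaks than classes: every class gets a distinct peak
        candidates = (
            list(zip(rows, range(n_cols)))
            for rows in permutations(range(n_rows), n_cols)
        )
    return max(candidates, key=lambda pairs: sum(probs[r][c] for r, c in pairs))


def group_peaks_sketch(peak_class_probs, peak_sample_inds, peak_channel_inds):
    """Match peaks to classes independently within each (sample, channel) group."""
    groups = defaultdict(list)
    for i, key in enumerate(zip(peak_sample_inds, peak_channel_inds)):
        groups[key].append(i)

    peak_inds, class_inds = [], []
    for key in sorted(groups):
        members = groups[key]
        sub_probs = [peak_class_probs[i] for i in members]
        for row, col in best_assignment(sub_probs):
            peak_inds.append(members[row])  # index into the original peak list
            class_inds.append(col)
    return peak_inds, class_inds


# Peaks 0 and 1 share (sample 0, channel 0) and both prefer class 0
probs = [[0.9, 0.1], [0.8, 0.2], [0.3, 0.7]]
peak_inds, class_inds = group_peaks_sketch(probs, [0, 0, 0], [0, 0, 1])
```

In this example a per-peak argmax would assign both peaks 0 and 1 to class 0. The one-to-one matching instead gives class 0 to peak 0 and class 1 to peak 1, because that pairing has the higher total probability within the group.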