sleap_nn.predict

Entry point for running inference.

Functions:

- `frame_list`: Converts an 'n-m' string to a list of ints.
- `run_inference`: Entry point to run inference on trained SLEAP-NN models.

frame_list(frame_str)

Converts an 'n-m' string to a list of ints.

Parameters:

- `frame_str` (`str`, required): String representing a range, e.g. `"1-200"`, or a comma-separated list of frame indices.

Returns:

- `Optional[List[int]]`: List of ints, or `None` if the string does not represent a valid range.

Source code in sleap_nn/predict.py
def frame_list(frame_str: str) -> Optional[List[int]]:
    """Converts 'n-m' string to list of ints.

    Args:
        frame_str: string representing range

    Returns:
        List of ints, or None if string does not represent valid range.
    """
    # Handle ranges of frames. Must be of the form "1-200" (or "1,-200")
    if "-" in frame_str:
        min_max = frame_str.split("-")
        min_frame = int(min_max[0].rstrip(","))
        max_frame = int(min_max[1])
        return list(range(min_frame, max_frame + 1))

    return [int(x) for x in frame_str.split(",")] if len(frame_str) else None
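
A few example calls, with behavior taken directly from the code above:

```python
frame_list("1-3")    # [1, 2, 3]
frame_list("1,-3")   # [1, 2, 3] (the trailing comma before the dash is stripped)
frame_list("2,5,8")  # [2, 5, 8]
frame_list("7")      # [7]
frame_list("")       # None
```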

run_inference(data_path=None, input_labels=None, input_video=None, model_paths=None, backbone_ckpt_path=None, head_ckpt_path=None, max_instances=None, max_width=None, max_height=None, ensure_rgb=None, input_scale=None, ensure_grayscale=None, anchor_part=None, only_labeled_frames=False, only_suggested_frames=False, batch_size=4, queue_maxsize=8, video_index=None, video_dataset=None, video_input_format='channels_last', frames=None, crop_size=None, peak_threshold=0.2, integral_refinement='integral', integral_patch_size=5, return_confmaps=False, return_pafs=False, return_paf_graph=False, max_edge_length_ratio=0.25, dist_penalty_weight=1.0, n_points=10, min_instance_peaks=0, min_line_scores=0.25, return_class_maps=False, return_class_vectors=False, make_labels=True, output_path=None, device='auto', tracking=False, tracking_window_size=5, min_new_track_points=0, candidates_method='fixed_window', min_match_points=0, features='keypoints', scoring_method='oks', scoring_reduction='mean', robust_best_instance=1.0, track_matching_method='hungarian', max_tracks=None, use_flow=False, of_img_scale=1.0, of_window_size=21, of_max_levels=3, post_connect_single_breaks=False)

Entry point to run inference on trained SLEAP-NN models.

Parameters:

- `data_path` (`Optional[str]`, default `None`): Path to the `.slp` file or `.mp4` to run inference on.
- `input_labels` (`Optional[Labels]`, default `None`): `sio.Labels` object to run inference on. An alternative to specifying `data_path`.
- `input_video` (`Optional[Video]`, default `None`): `sio.Video` to run inference on. An alternative to specifying `data_path`. If both `input_labels` and `input_video` are provided, `input_labels` is used.
- `model_paths` (`Optional[List[str]]`, default `None`): List of paths to the directories where `best.ckpt` and `training_config.yaml` are saved.
- `backbone_ckpt_path` (`Optional[str]`, default `None`): To run inference with any `.ckpt` other than `best.ckpt` from the `model_paths` dir, pass the path to that `.ckpt` file here.
- `head_ckpt_path` (`Optional[str]`, default `None`): Path to a `.ckpt` file if a different set of head-layer weights is to be used. If `None`, the `best.ckpt` from the `model_paths` dir is used (or the ckpt from `backbone_ckpt_path`, if provided).
- `max_instances` (`Optional[int]`, default `None`): Maximum number of instances to consider from the predictions.
- `max_width` (`Optional[int]`, default `None`): Maximum width the image should be padded to. If not provided, the value from the training config is used.
- `max_height` (`Optional[int]`, default `None`): Maximum height the image should be padded to. If not provided, the value from the training config is used.
- `input_scale` (`Optional[float]`, default `None`): Scale factor to apply to the input image. If not provided, the value from the training config is used.
- `ensure_rgb` (`Optional[bool]`, default `None`): `True` if the input image should have 3 channels (RGB). If the input has only one channel and this is `True`, the single channel is replicated along the channel axis. If the image has three channels and this is `False`, the three channels are retained. If not provided, the value from the training config is used.
- `ensure_grayscale` (`Optional[bool]`, default `None`): `True` if the input image should have only a single channel. If the input has three channels (RGB) and this is `True`, the image is converted to grayscale (single channel). If the source image has only one channel and this is `False`, the single channel is retained. If not provided, the value from the training config is used.
- `anchor_part` (`Optional[str]`, default `None`): The node name to use as the anchor for the centroid. If not provided, the anchor part from `training_config.yaml` is used.
- `only_labeled_frames` (`bool`, default `False`): `True` if inference should be run only on user-labeled frames.
- `only_suggested_frames` (`bool`, default `False`): `True` if inference should be run only on unlabeled suggested frames.
- `batch_size` (`int`, default `4`): Number of samples per batch.
- `queue_maxsize` (`int`, default `8`): Maximum size of the frame buffer queue.
- `video_index` (`Optional[int]`, default `None`): Integer index of the video in the `.slp` file to predict on. To be used with an `.slp` path as an alternative to specifying the video path.
- `video_dataset` (`Optional[str]`, default `None`): The dataset for HDF5 videos.
- `video_input_format` (`str`, default `'channels_last'`): The input format for HDF5 videos.
- `frames` (`Optional[list]`, default `None`): List of frame indices. If `None`, all frames in the video are used.
- `crop_size` (`Optional[int]`, default `None`): Crop size. If not provided, the crop size from `training_config.yaml` is used.
- `peak_threshold` (`Union[float, List[float]]`, default `0.2`): Minimum confidence threshold. Peaks with values below this will be ignored. This can also be a `List[float]` for top-down centroid and centered-instance models, where the first element is the peak-finding threshold for the centroid model and the second is for the centered-instance model.
- `integral_refinement` (`Optional[str]`, default `'integral'`): If `None`, returns the grid-aligned peaks with no refinement. If `"integral"`, peaks will be refined with integral regression.
- `integral_patch_size` (`int`, default `5`): Size of the patches to crop around each rough peak, as an integer scalar.
- `return_confmaps` (`bool`, default `False`): If `True`, predicted confidence maps will be returned along with the predicted peak values and points.
- `return_pafs` (`bool`, default `False`): If `True`, the part affinity fields will be returned together with the predicted instances. This results in slower inference since the data must be copied off of the GPU, but is useful for visualizing the raw output of the model.
- `return_class_vectors` (`bool`, default `False`): If `True`, the classification probabilities will be returned together with the predicted peaks. These will not line up with the grouped instances, for which the associated class probabilities are always returned in `"instance_scores"`.
- `return_paf_graph` (`bool`, default `False`): If `True`, the part affinity field graph will be returned together with the predicted instances. The graph is obtained by parsing the part affinity fields with the `paf_scorer` instance and is an intermediate representation used during instance grouping.
- `max_edge_length_ratio` (`float`, default `0.25`): The maximum expected length of a connected pair of points, as a fraction of the image size. Candidate connections longer than this will be penalized during matching.
- `dist_penalty_weight` (`float`, default `1.0`): A coefficient to scale the weight of the distance penalty. Set to values greater than 1.0 to enforce the distance penalty more strictly.
- `n_points` (`int`, default `10`): Number of points to sample along the line integral.
- `min_instance_peaks` (`Union[int, float]`, default `0`): Minimum number of peaks an instance should have to be considered a real instance. Instances with fewer peaks will be discarded (useful for filtering spurious detections).
- `min_line_scores` (`float`, default `0.25`): Minimum line score (between -1 and 1) required to form a match between candidate point pairs. Useful for rejecting spurious detections when there are no better ones.
- `return_class_maps` (`bool`, default `False`): If `True`, the class maps will be returned together with the predicted instances. This results in slower inference since the data must be copied off of the GPU, but is useful for visualizing the raw output of the model.
- `make_labels` (`bool`, default `True`): If `True` (the default), returns a `sio.Labels` instance with `sio.PredictedInstance`s. If `False`, returns a list of dictionaries containing the raw arrays produced by the inference model.
- `output_path` (`Optional[str]`, default `None`): Path to save the labels file if `make_labels` is `True`. Defaults to the current working directory.
- `device` (`str`, default `'auto'`): Device on which `torch.Tensor` will be allocated. One of `'cpu'`, `'cuda'`, `'mps'`, or `'auto'`. With `"auto"`, cuda, mps, or cpu is chosen based on the available backend. If cuda is available, a specific device such as `cuda:0` can also be used.
- `tracking` (`bool`, default `False`): If `True`, runs tracking on the predicted instances.
- `tracking_window_size` (`int`, default `5`): Number of frames of candidate instances to look through when matching against the current detections.
- `min_new_track_points` (`int`, default `0`): A new track will not be spawned for an instance with fewer than this many points.
- `candidates_method` (`str`, default `'fixed_window'`): One of `fixed_window` or `local_queues`. In the fixed-window method, candidates from the last `window_size` frames are considered. In local queues, the last `window_size` instances for each track ID are considered for matching against the current detections.
- `min_match_points` (`int`, default `0`): Minimum number of non-NaN points for match candidates.
- `features` (`str`, default `'keypoints'`): Feature representation used to match candidates against current detections. One of `keypoints`, `centroids`, `bboxes`, or `image`.
- `scoring_method` (`str`, default `'oks'`): Method to compute the association score between features from the current frame and the previous tracks. One of `oks`, `cosine_sim`, `iou`, or `euclidean_dist`.
- `scoring_reduction` (`str`, default `'mean'`): Method to aggregate and reduce multiple scores when several detections are associated with the same track. One of `mean`, `max`, or `robust_quantile`.
- `robust_best_instance` (`float`, default `1.0`): If the value is between 0 and 1 (exclusive), a robust quantile similarity score is used for the track. If the value is 1, the max similarity (non-robust) is used. For a robust score, 0.95 is a good value.
- `track_matching_method` (`str`, default `'hungarian'`): Track matching algorithm. One of `hungarian` or `greedy`.
- `max_tracks` (`Optional[int]`, default `None`): Maximum number of new tracks to create, to avoid redundant tracks (only for the local queues candidate method).
- `use_flow` (`bool`, default `False`): If `True`, `FlowShiftTracker` is used, where the poses are matched using optical-flow shifts.
- `of_img_scale` (`float`, default `1.0`): Factor to scale the images by when computing optical flow. Decrease this to increase performance at the cost of finer accuracy. Sometimes decreasing the image scale can improve performance with fast movements. Only used if `use_flow` is `True`.
- `of_window_size` (`int`, default `21`): Optical flow window size to consider at each pyramid scale level. Only used if `use_flow` is `True`.
- `of_max_levels` (`int`, default `3`): Number of pyramid scale levels to consider. This is different from the scale parameter, which determines the initial image scaling. Only used if `use_flow` is `True`.
- `post_connect_single_breaks` (`bool`, default `False`): If `True` and `max_tracks` is not `None` with the local queues candidate method, connects track breaks when exactly one track is lost and exactly one new track is spawned in a frame.

Returns:

- A `sio.Labels` object if `make_labels` is `True`; otherwise, a list of dictionaries with the predictions.
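
A minimal usage sketch for a top-down model pair. The model directories and file names below are hypothetical placeholders, not files shipped with sleap-nn:

```python
from sleap_nn.predict import run_inference

# Hypothetical paths: each model directory is assumed to contain
# best.ckpt and training_config.yaml from a prior training run.
labels = run_inference(
    data_path="session1.mp4",
    model_paths=["models/centroid", "models/centered_instance"],
    peak_threshold=[0.2, 0.3],  # [centroid, centered-instance] thresholds
    max_instances=2,
    device="auto",              # picks cuda, mps, or cpu automatically
    make_labels=True,           # return a sio.Labels object
    output_path="session1.predictions.slp",
)
```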

Source code in sleap_nn/predict.py
def run_inference(
    data_path: Optional[str] = None,
    input_labels: Optional[sio.Labels] = None,
    input_video: Optional[sio.Video] = None,
    model_paths: Optional[List[str]] = None,
    backbone_ckpt_path: Optional[str] = None,
    head_ckpt_path: Optional[str] = None,
    max_instances: Optional[int] = None,
    max_width: Optional[int] = None,
    max_height: Optional[int] = None,
    ensure_rgb: Optional[bool] = None,
    input_scale: Optional[float] = None,
    ensure_grayscale: Optional[bool] = None,
    anchor_part: Optional[str] = None,
    only_labeled_frames: bool = False,
    only_suggested_frames: bool = False,
    batch_size: int = 4,
    queue_maxsize: int = 8,
    video_index: Optional[int] = None,
    video_dataset: Optional[str] = None,
    video_input_format: str = "channels_last",
    frames: Optional[list] = None,
    crop_size: Optional[int] = None,
    peak_threshold: Union[float, List[float]] = 0.2,
    integral_refinement: Optional[str] = "integral",
    integral_patch_size: int = 5,
    return_confmaps: bool = False,
    return_pafs: bool = False,
    return_paf_graph: bool = False,
    max_edge_length_ratio: float = 0.25,
    dist_penalty_weight: float = 1.0,
    n_points: int = 10,
    min_instance_peaks: Union[int, float] = 0,
    min_line_scores: float = 0.25,
    return_class_maps: bool = False,
    return_class_vectors: bool = False,
    make_labels: bool = True,
    output_path: Optional[str] = None,
    device: str = "auto",
    tracking: bool = False,
    tracking_window_size: int = 5,
    min_new_track_points: int = 0,
    candidates_method: str = "fixed_window",
    min_match_points: int = 0,
    features: str = "keypoints",
    scoring_method: str = "oks",
    scoring_reduction: str = "mean",
    robust_best_instance: float = 1.0,
    track_matching_method: str = "hungarian",
    max_tracks: Optional[int] = None,
    use_flow: bool = False,
    of_img_scale: float = 1.0,
    of_window_size: int = 21,
    of_max_levels: int = 3,
    post_connect_single_breaks: bool = False,
):
    """Entry point to run inference on trained SLEAP-NN models.

    Args:
        data_path: (str) Path to `.slp` file or `.mp4` to run inference on.
        input_labels: (sio.Labels) Labels object to run inference on. This is an alternative to specifying the data_path.
        input_video: (sio.Video) Video to run inference on. This is an alternative to specifying the data_path. If both input_labels and input_video are provided, input_labels are used.
        model_paths: (List[str]) List of paths to the directory where the best.ckpt
                and training_config.yaml are saved.
        backbone_ckpt_path: (str) To run inference on any `.ckpt` other than `best.ckpt`
                from the `model_paths` dir, the path to the `.ckpt` file should be passed here.
        head_ckpt_path: (str) Path to `.ckpt` file if a different set of head layer weights
                are to be used. If `None`, the `best.ckpt` from `model_paths` dir is used (or the ckpt
                from `backbone_ckpt_path` if provided.)
        max_instances: (int) Max number of instances to consider from the predictions.
        max_width: (int) Maximum width the image should be padded to. If not provided, the
                values from the training config are used. Default: None.
        max_height: (int) Maximum height the image should be padded to. If not provided, the
                values from the training config are used. Default: None.
        input_scale: (float) Scale factor to apply to the input image. If not provided, the
                values from the training config are used. Default: None.
        ensure_rgb: (bool) True if the input image should have 3 channels (RGB image). If the input has only one
                channel when this is set to `True`, then the single-channel image
                is replicated along the channel axis. If the image has three channels and this is set to False, then we retain the three channels. If not provided, the
                values from the training config are used. Default: `None`.
        ensure_grayscale: (bool) True if the input image should only have a single channel. If input has three channels (RGB) and this
                is set to True, then we convert the image to grayscale (single-channel)
                image. If the source image has only one channel and this is set to False, then we retain the single channel input. If not provided, the
                values from the training config are used. Default: `None`.
        anchor_part: (str) The node name to use as the anchor for the centroid. If not
                provided, the anchor part in the `training_config.yaml` is used. Default: `None`.
        only_labeled_frames: (bool) `True` if inference should be run only on user-labeled frames. Default: `False`.
        only_suggested_frames: (bool) `True` if inference should be run only on unlabeled suggested frames. Default: `False`.
        batch_size: (int) Number of samples per batch. Default: 4.
        queue_maxsize: (int) Maximum size of the frame buffer queue. Default: 8.
        video_index: (int) Integer index of video in .slp file to predict on. To be used with
                an .slp path as an alternative to specifying the video path.
        video_dataset: (str) The dataset for HDF5 videos.
        video_input_format: (str) The input_format for HDF5 videos.
        frames: (list) List of frame indices. If `None`, all frames in the video are used. Default: None.
        crop_size: (int) Crop size. If not provided, the crop size from training_config.yaml is used.
                Default: None.
        peak_threshold: (float) Minimum confidence threshold. Peaks with values below
                this will be ignored. Default: 0.2. This can also be `List[float]` for topdown
                centroid and centered-instance model, where the first element corresponds
                to centroid model peak finding threshold and the second element is for
                centered-instance model peak finding.
        integral_refinement: (str) If `None`, returns the grid-aligned peaks with no refinement.
                If `"integral"`, peaks will be refined with integral regression.
                Default: `"integral"`.
        integral_patch_size: (int) Size of patches to crop around each rough peak as an
                integer scalar. Default: 5.
        return_confmaps: (bool) If `True`, predicted confidence maps will be returned
                along with the predicted peak values and points. Default: False.
        return_pafs: (bool) If `True`, the part affinity fields will be returned together with
                the predicted instances. This will result in slower inference times since
                the data must be copied off of the GPU, but is useful for visualizing the
                raw output of the model. Default: False.
        return_class_vectors: If `True`, the classification probabilities will be
                returned together with the predicted peaks. This will not line up with the
                grouped instances, for which the associated class probabilities will always
                be returned in `"instance_scores"`.
        return_paf_graph: (bool) If `True`, the part affinity field graph will be returned
                together with the predicted instances. The graph is obtained by parsing the
                part affinity fields with the `paf_scorer` instance and is an intermediate
                representation used during instance grouping. Default: False.
        max_edge_length_ratio: (float) The maximum expected length of a connected pair of points
                as a fraction of the image size. Candidate connections longer than this
                length will be penalized during matching. Default: 0.25.
        dist_penalty_weight: (float) A coefficient to scale weight of the distance penalty as
                a scalar float. Set to values greater than 1.0 to enforce the distance
                penalty more strictly. Default: 1.0.
        n_points: (int) Number of points to sample along the line integral. Default: 10.
        min_instance_peaks: Union[int, float] Minimum number of peaks the instance should
                have to be considered a real instance. Instances with fewer peaks than
                this will be discarded (useful for filtering spurious detections).
                Default: 0.
        min_line_scores: (float) Minimum line score (between -1 and 1) required to form a match
                between candidate point pairs. Useful for rejecting spurious detections when
                there are no better ones. Default: 0.25.
        return_class_maps: If `True`, the class maps will be returned together with
            the predicted instances. This will result in slower inference times since
            the data must be copied off of the GPU, but is useful for visualizing the
            raw output of the model.
        make_labels: (bool) If `True` (the default), returns a `sio.Labels` instance with
                `sio.PredictedInstance`s. If `False`, just return a list of
                dictionaries containing the raw arrays returned by the inference model.
                Default: True.
        output_path: (str) Path to save the labels file if `make_labels` is True.
                Default is current working directory.
        device: (str) Device on which torch.Tensor will be allocated. One of the
                ('cpu', 'cuda', 'mps', 'auto').
                Default: "auto" (based on available backend either cuda, mps or cpu is chosen). If `cuda` is available, you could also use `cuda:0` to specify the device.
        tracking: (bool) If True, runs tracking on the predicted instances.
        tracking_window_size: Number of frames to look for in the candidate instances to match
                with the current detections. Default: 5.
        min_new_track_points: We won't spawn a new track for an instance with
            fewer than this many points. Default: 0.
        candidates_method: Either of `fixed_window` or `local_queues`. In the fixed-window
            method, candidates from the last `window_size` frames are considered. In local
            queues, the last `window_size` instances for each track ID are considered for
            matching against the current detections. Default: `fixed_window`.
        min_match_points: Minimum non-NaN points for match candidates. Default: 0.
        features: Feature representation for the candidates to update current detections.
            One of [`keypoints`, `centroids`, `bboxes`, `image`]. Default: `keypoints`.
        scoring_method: Method to compute association score between features from the
            current frame and the previous tracks. One of [`oks`, `cosine_sim`, `iou`,
            `euclidean_dist`]. Default: `oks`.
        scoring_reduction: Method to aggregate and reduce multiple scores if there are
            several detections associated with the same track. One of [`mean`, `max`,
            `robust_quantile`]. Default: `mean`.
        robust_best_instance: If the value is between 0 and 1
            (excluded), use a robust quantile similarity score for the
            track. If the value is 1, use the max similarity (non-robust).
            For selecting a robust score, 0.95 is a good value.
        track_matching_method: Track matching algorithm. One of `hungarian` or `greedy`.
            Default: `hungarian`.
        max_tracks: Maximum number of new tracks to be created to avoid redundant tracks.
            (only for the local queues candidate method) Default: None.
        use_flow: If True, `FlowShiftTracker` is used, where the poses are matched using
            optical flow shifts. Default: `False`.
        of_img_scale: Factor to scale the images by when computing optical flow. Decrease
            this to increase performance at the cost of finer accuracy. Sometimes
            decreasing the image scale can improve performance with fast movements.
            Default: 1.0. (only if `use_flow` is True)
        of_window_size: Optical flow window size to consider at each pyramid scale
            level. Default: 21. (only if `use_flow` is True)
        of_max_levels: Number of pyramid scale levels to consider. This is different
            from the scale parameter, which determines the initial image scaling.
            Default: 3. (only if `use_flow` is True).
        post_connect_single_breaks: If True and `max_tracks` is not None with local queues candidate method,
            connects track breaks when exactly one track is lost and exactly one new track is spawned in the frame.

    Returns:
        Returns `sio.Labels` object if `make_labels` is True. Else this function returns
            a list of Dictionaries with the predictions.

    """
    preprocess_config = {  # if not given, then use from training config
        "ensure_rgb": ensure_rgb,
        "ensure_grayscale": ensure_grayscale,
        "crop_size": crop_size,
        "max_width": max_width,
        "max_height": max_height,
        "scale": input_scale,
    }

    if model_paths is None or not len(
        model_paths
    ):  # if model paths is not provided, run tracking-only pipeline.
        if not tracking:
            message = """Neither tracker nor path to trained models specified. Use `model_paths` to specify models to use. To retrack on predictions, set `tracking` to True."""
            logger.error(message)
            raise ValueError(message)

        else:
            if (data_path is not None and not data_path.endswith(".slp")) or (
                input_labels is not None and not isinstance(input_labels, sio.Labels)
            ):
                message = "Data path is not a .slp file. To run track-only pipeline, data path must be an .slp file."
                logger.error(message)
                raise ValueError(message)

            start_inf_time = time()
            start_timestamp = str(datetime.now())
            logger.info(f"Started tracking at: {start_timestamp}")

            labels = sio.load_slp(data_path) if input_labels is None else input_labels

            lf_frames = labels.labeled_frames

            # select video if video_index is provided
            if video_index is not None:
                lf_frames = labels.find(video=labels.videos[video_index])

            # sort frames before tracking
            lf_frames = sorted(lf_frames, key=lambda lf: lf.frame_idx)

            if frames is not None:
                filtered_frames = []
                for lf in lf_frames:
                    if lf.frame_idx in frames:
                        filtered_frames.append(lf)
                lf_frames = filtered_frames

            if post_connect_single_breaks:
                if max_tracks is None:
                    max_tracks = max_instances

            logger.info(f"Running tracking on {len(lf_frames)} frames...")

            tracked_frames = run_tracker(
                untracked_frames=lf_frames,
                window_size=tracking_window_size,
                min_new_track_points=min_new_track_points,
                candidates_method=candidates_method,
                min_match_points=min_match_points,
                features=features,
                scoring_method=scoring_method,
                scoring_reduction=scoring_reduction,
                robust_best_instance=robust_best_instance,
                track_matching_method=track_matching_method,
                max_tracks=max_tracks,
                use_flow=use_flow,
                of_img_scale=of_img_scale,
                of_window_size=of_window_size,
                of_max_levels=of_max_levels,
                post_connect_single_breaks=post_connect_single_breaks,
            )

            finish_timestamp = str(datetime.now())
            total_elapsed = time() - start_inf_time
            logger.info(f"Finished tracking at: {finish_timestamp}")
            logger.info(f"Total runtime: {total_elapsed} secs")

            output = sio.Labels(
                labeled_frames=tracked_frames,
                videos=labels.videos,
                skeletons=labels.skeletons,
            )

    else:
        start_inf_time = time()
        start_timestamp = str(datetime.now())
        logger.info(f"Started inference at: {start_timestamp}")

        if device == "auto":
            device = (
                "cuda"
                if torch.cuda.is_available()
                else "mps" if torch.backends.mps.is_available() else "cpu"
            )

        if integral_refinement is not None and device == "mps":  # TODO
            # kornia/geometry/transform/imgwarp.py:382: in get_perspective_transform. NotImplementedError: The operator 'aten::_linalg_solve_ex.result' is not currently implemented for the MPS device. If you want this op to be added in priority during the prototype phase of this feature, please comment on https://github.com/pytorch/pytorch/issues/77764. As a temporary fix, you can set the environment variable `PYTORCH_ENABLE_MPS_FALLBACK=1` to use the CPU as a fallback for this op. WARNING: this will be slower than running natively on MPS.
            logger.info(
                "Integral refinement is not supported with MPS device. Using CPU."
            )
            device = "cpu"  # not supported with mps

        logger.info(f"Using device: {device}")

        # initializes the inference model
        predictor = Predictor.from_model_paths(
            model_paths,
            backbone_ckpt_path=backbone_ckpt_path,
            head_ckpt_path=head_ckpt_path,
            peak_threshold=peak_threshold,
            integral_refinement=integral_refinement,
            integral_patch_size=integral_patch_size,
            batch_size=batch_size,
            max_instances=max_instances,
            return_confmaps=return_confmaps,
            device=device,
            preprocess_config=OmegaConf.create(preprocess_config),
            anchor_part=anchor_part,
        )

        if (
            tracking
            and not isinstance(predictor, BottomUpMultiClassPredictor)
            and not isinstance(predictor, TopDownMultiClassPredictor)
        ):
            predictor.tracker = Tracker.from_config(
                candidates_method=candidates_method,
                min_match_points=min_match_points,
                window_size=tracking_window_size,
                min_new_track_points=min_new_track_points,
                features=features,
                scoring_method=scoring_method,
                scoring_reduction=scoring_reduction,
                robust_best_instance=robust_best_instance,
                track_matching_method=track_matching_method,
                max_tracks=max_tracks,
                use_flow=use_flow,
                of_img_scale=of_img_scale,
                of_window_size=of_window_size,
                of_max_levels=of_max_levels,
            )

        if isinstance(predictor, BottomUpPredictor):
            predictor.inference_model.paf_scorer.max_edge_length_ratio = (
                max_edge_length_ratio
            )
            predictor.inference_model.paf_scorer.dist_penalty_weight = (
                dist_penalty_weight
            )
            predictor.inference_model.return_pafs = return_pafs
            predictor.inference_model.return_paf_graph = return_paf_graph
            predictor.inference_model.paf_scorer.max_edge_length_ratio = (
                max_edge_length_ratio
            )
            predictor.inference_model.paf_scorer.min_line_scores = min_line_scores
            predictor.inference_model.paf_scorer.min_instance_peaks = min_instance_peaks
            predictor.inference_model.paf_scorer.n_points = n_points

        if isinstance(predictor, BottomUpMultiClassPredictor):
            predictor.inference_model.return_class_maps = return_class_maps

        if isinstance(predictor, TopDownMultiClassPredictor):
            predictor.inference_model.instance_peaks.return_class_vectors = (
                return_class_vectors
            )

        # initialize make_pipeline function

        predictor.make_pipeline(
            inference_object=(
                input_labels
                if input_labels is not None
                else input_video if input_video is not None else data_path
            ),
            queue_maxsize=queue_maxsize,
            frames=frames,
            only_labeled_frames=only_labeled_frames,
            only_suggested_frames=only_suggested_frames,
            video_index=video_index,
            video_dataset=video_dataset,
            video_input_format=video_input_format,
        )

        # run predict
        output = predictor.predict(
            make_labels=make_labels,
        )

        if tracking and post_connect_single_breaks:
            if max_tracks is None:
                max_tracks = max_instances

            if max_tracks is None:
                message = "Max_tracks (and max instances) is None. To connect single breaks, max_tracks should be set to an integer."
                logger.error(message)
                raise ValueError(message)

            start_final_pass_time = time()
            start_fp_timestamp = str(datetime.now())
            logger.info(
                f"Started final-pass (connecting single breaks) at: {start_fp_timestamp}"
            )
            corrected_lfs = connect_single_breaks(
                lfs=[x for x in output], max_instances=max_tracks
            )
            finish_fp_timestamp = str(datetime.now())
            total_fp_elapsed = time() - start_final_pass_time
            logger.info(
                f"Finished final-pass (connecting single breaks) at: {finish_fp_timestamp}"
            )
            logger.info(f"Total runtime: {total_fp_elapsed} secs")

            output = sio.Labels(
                labeled_frames=corrected_lfs,
                videos=output.videos,
                skeletons=output.skeletons,
            )

        finish_timestamp = str(datetime.now())
        total_elapsed = time() - start_inf_time
        logger.info(f"Finished inference at: {finish_timestamp}")
        logger.info(
            f"Total runtime: {total_elapsed} secs"
        )  # TODO: add number of predicted frames

    if make_labels:
        if output_path is None:
            output_path = Path(
                data_path if data_path is not None else "results"
            ).with_suffix(".predictions.slp")
        output.save(Path(output_path).as_posix(), restore_original_videos=False)
    finish_timestamp = str(datetime.now())
    logger.info(f"Predictions output path: {output_path}")
    logger.info(f"Saved file at: {finish_timestamp}")

    return output
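
For reference, a sketch of the tracking-only branch above: when `model_paths` is omitted and `tracking=True`, `run_inference` re-tracks existing predictions in an `.slp` file without loading any models. The file name and tracker settings here are illustrative:

```python
from sleap_nn.predict import run_inference

# Hypothetical predictions file produced by an earlier inference run.
tracked = run_inference(
    data_path="session1.predictions.slp",  # must be an .slp file for track-only
    tracking=True,
    candidates_method="local_queues",
    max_tracks=2,                          # cap on the number of tracks
    post_connect_single_breaks=True,       # join single lost/spawned track pairs
)
```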