
DeepSORT


Overview

DeepSORT extends the original SORT algorithm by integrating appearance information through a deep association metric. While keeping SORT's core Kalman filtering and Hungarian-algorithm components, DeepSORT adds a convolutional neural network (CNN), trained on large-scale person re-identification datasets, that extracts appearance features from detected objects. These features let the tracker maintain object identities through longer periods of occlusion, significantly reducing identity switches compared to the original SORT.

DeepSORT uses a dual-metric approach, combining motion information (Mahalanobis distance) with appearance similarity (cosine distance in feature space) to improve data association. It also introduces a matching cascade that prioritizes recently seen tracks, improving robustness during occlusions. Most of the computational cost is shifted to an offline pre-training stage, so the online tracking component runs at roughly 20 Hz, fast enough for real-time applications while achieving competitive tracking performance with markedly better identity preservation.
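
To make the dual-metric idea concrete, the sketch below blends a motion cost and an appearance cost into one association matrix. This is an illustrative approximation, not the library's exact implementation; the function name is hypothetical, and appearance_weight mirrors the DeepSORTTracker parameter of the same name.

import numpy as np
from scipy.spatial.distance import cdist

def combined_cost(
    iou: np.ndarray,                 # (num_tracks, num_detections) IOU matrix
    track_features: np.ndarray,      # (num_tracks, embedding_dim)
    detection_features: np.ndarray,  # (num_detections, embedding_dim)
    appearance_weight: float = 0.5,
) -> np.ndarray:
    # IOU is a similarity (higher is better); flip it into a distance.
    motion_dist = 1.0 - iou
    # Cosine distance between track and detection appearance embeddings.
    appearance_dist = cdist(track_features, detection_features, metric="cosine")
    # Weight 0 is pure motion, weight 1 is pure appearance.
    return (1.0 - appearance_weight) * motion_dist + appearance_weight * appearance_dist

# The blended cost is then solved as a linear assignment (Hungarian
# algorithm), e.g. via scipy.optimize.linear_sum_assignment.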

Examples

With Roboflow Inference:

import supervision as sv
from trackers import DeepSORTTracker, ReIDModel
from inference import get_model

reid_model = ReIDModel.from_timm("resnetv2_50.a1h_in1k")
tracker = DeepSORTTracker(reid_model=reid_model)
model = get_model(model_id="yolov11m-640")
annotator = sv.LabelAnnotator(text_position=sv.Position.CENTER)

def callback(frame, _):
    result = model.infer(frame)[0]
    detections = sv.Detections.from_inference(result)
    detections = tracker.update(detections, frame)
    return annotator.annotate(frame, detections, labels=detections.tracker_id)

sv.process_video(
    source_path="<INPUT_VIDEO_PATH>",
    target_path="<OUTPUT_VIDEO_PATH>",
    callback=callback,
)

With RF-DETR:

import supervision as sv
from trackers import DeepSORTTracker, ReIDModel
from rfdetr import RFDETRBase

reid_model = ReIDModel.from_timm("resnetv2_50.a1h_in1k")
tracker = DeepSORTTracker(reid_model=reid_model)
model = RFDETRBase()
annotator = sv.LabelAnnotator(text_position=sv.Position.CENTER)

def callback(frame, _):
    detections = model.predict(frame)
    detections = tracker.update(detections, frame)
    return annotator.annotate(frame, detections, labels=detections.tracker_id)

sv.process_video(
    source_path="<INPUT_VIDEO_PATH>",
    target_path="<OUTPUT_VIDEO_PATH>",
    callback=callback,
)

With Ultralytics YOLO:

import supervision as sv
from trackers import DeepSORTTracker, ReIDModel
from ultralytics import YOLO

reid_model = ReIDModel.from_timm("resnetv2_50.a1h_in1k")
tracker = DeepSORTTracker(reid_model=reid_model)
model = YOLO("yolo11m.pt")
annotator = sv.LabelAnnotator(text_position=sv.Position.CENTER)

def callback(frame, _):
    result = model(frame)[0]
    detections = sv.Detections.from_ultralytics(result)
    detections = tracker.update(detections, frame)
    return annotator.annotate(frame, detections, labels=detections.tracker_id)

sv.process_video(
    source_path="<INPUT_VIDEO_PATH>",
    target_path="<OUTPUT_VIDEO_PATH>",
    callback=callback,
)

With Hugging Face Transformers:

import torch
import supervision as sv
from trackers import DeepSORTTracker, ReIDModel
from transformers import RTDetrV2ForObjectDetection, RTDetrImageProcessor

reid_model = ReIDModel.from_timm("resnetv2_50.a1h_in1k")
tracker = DeepSORTTracker(reid_model=reid_model)
processor = RTDetrImageProcessor.from_pretrained("PekingU/rtdetr_v2_r18vd")
model = RTDetrV2ForObjectDetection.from_pretrained("PekingU/rtdetr_v2_r18vd")
annotator = sv.LabelAnnotator(text_position=sv.Position.CENTER)

def callback(frame, _):
    inputs = processor(images=frame, return_tensors="pt")
    with torch.no_grad():
        outputs = model(**inputs)

    h, w, _ = frame.shape
    results = processor.post_process_object_detection(
        outputs,
        target_sizes=torch.tensor([(h, w)]),
        threshold=0.5
    )[0]

    detections = sv.Detections.from_transformers(
        transformers_results=results,
        id2label=model.config.id2label
    )

    detections = tracker.update(detections, frame)
    return annotator.annotate(frame, detections, labels=detections.tracker_id)

sv.process_video(
    source_path="<INPUT_VIDEO_PATH>",
    target_path="<OUTPUT_VIDEO_PATH>",
    callback=callback,
)

API

Install DeepSORT

Install the trackers package with the reid extra, picking the build that matches your compute backend:

pip install "trackers[reid,cpu]"      # CPU only
pip install "trackers[reid,cu118]"    # CUDA 11.8
pip install "trackers[reid,cu124]"    # CUDA 12.4
pip install "trackers[reid,cu126]"    # CUDA 12.6
pip install "trackers[reid,rocm61]"   # ROCm 6.1
pip install "trackers[reid,rocm624]"  # ROCm 6.2.4

trackers.core.deepsort.tracker.DeepSORTTracker

Bases: BaseTrackerWithFeatures

Implements DeepSORT (Deep Simple Online and Realtime Tracking).

DeepSORT extends SORT by integrating appearance information using a deep learning model, improving tracking through occlusions and reducing ID switches. It combines motion (Kalman filter) and appearance cues for data association.

Parameters:

reid_model (ReIDModel, required)
    An instance of a ReIDModel used to extract appearance features.

device (Optional[str], default: None)
    Device to run the feature extraction model on (e.g., 'cpu', 'cuda').

lost_track_buffer (int, default: 30)
    Number of frames to buffer when a track is lost. A larger buffer improves occlusion handling but may increase ID switches between similar objects.

frame_rate (float, default: 30.0)
    Frame rate of the video in frames per second. Used to calculate the maximum time a track can remain lost.

track_activation_threshold (float, default: 0.25)
    Detection confidence threshold for track activation. Higher values reduce false positives but may miss objects.

minimum_consecutive_frames (int, default: 3)
    Number of consecutive frames an object must be tracked before it is considered valid. Prevents spurious tracks but may miss short-lived objects.

minimum_iou_threshold (float, default: 0.3)
    IOU threshold used for gating in the matching cascade.

appearance_threshold (float, default: 0.7)
    Cosine distance threshold for appearance matching. Only matches below this threshold are considered valid.

appearance_weight (float, default: 0.5)
    Weight in [0, 1] balancing motion (IOU) and appearance distance in the combined matching cost.

distance_metric (str, default: 'cosine')
    Distance metric for appearance features (e.g., 'cosine', 'euclidean'). See scipy.spatial.distance.cdist.
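
For reference, a construction sketch passing every documented parameter at its default; the values and comments restate the table above, and the ReID checkpoint is the one used in the Examples:

from trackers import DeepSORTTracker, ReIDModel

reid_model = ReIDModel.from_timm("resnetv2_50.a1h_in1k")
tracker = DeepSORTTracker(
    reid_model=reid_model,
    device=None,                      # e.g. 'cpu' or 'cuda'
    lost_track_buffer=30,             # frames a lost track is kept alive
    frame_rate=30.0,                  # video FPS, scales the lost-track window
    track_activation_threshold=0.25,  # min detection confidence to start a track
    minimum_consecutive_frames=3,     # frames before a track counts as valid
    minimum_iou_threshold=0.3,        # IOU gate in the matching cascade
    appearance_threshold=0.7,         # max cosine distance for appearance matches
    appearance_weight=0.5,            # motion vs. appearance balance in the cost
    distance_metric="cosine",         # any metric accepted by scipy cdist
)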

update(detections, frame)

Updates the tracker state with new detections and appearance features.

Extracts appearance features, performs Kalman filter prediction, calculates IOU and appearance distance matrices, associates detections with tracks using a combined metric, updates matched tracks (position and appearance), and initializes new tracks for unmatched high-confidence detections.

Parameters:

detections (Detections, required)
    The latest set of object detections.

frame (ndarray, required)
    The current video frame, used for extracting appearance features from detections.

Returns:

sv.Detections
    A copy of the input detections, augmented with an assigned tracker_id for each successfully tracked object. Detections not associated with a track will not have a tracker_id.
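
The Examples above drive update via sv.process_video, but it works just as well in a raw frame loop. A minimal sketch using OpenCV and the Ultralytics detector from the Examples (the input path is a placeholder):

import cv2
import supervision as sv
from trackers import DeepSORTTracker, ReIDModel
from ultralytics import YOLO

reid_model = ReIDModel.from_timm("resnetv2_50.a1h_in1k")
tracker = DeepSORTTracker(reid_model=reid_model)
model = YOLO("yolo11m.pt")

cap = cv2.VideoCapture("<INPUT_VIDEO_PATH>")
while cap.isOpened():
    ok, frame = cap.read()
    if not ok:
        break
    detections = sv.Detections.from_ultralytics(model(frame)[0])
    detections = tracker.update(detections, frame)  # assigns tracker_id
    # consume detections.tracker_id here (draw, log, count, ...)
cap.release()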

reset()

Resets the tracker's internal state.

Clears all active tracks and resets the track ID counter.
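
When one tracker instance is reused across multiple inputs, call reset() between them so identities do not carry over from one video into the next. A minimal sketch, assuming the callback from the Examples above and hypothetical file paths:

clips = [("first.mp4", "first_out.mp4"), ("second.mp4", "second_out.mp4")]
for source_path, target_path in clips:
    sv.process_video(source_path=source_path, target_path=target_path, callback=callback)
    tracker.reset()  # clear tracks and restart the ID counter before the next clip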
