Utilities

Image processing, the vertebra-labeling path solver, disc labeling and other helpers.

spineps.utils.proc_functions

Segmentation post-processing helpers: N4 bias correction, connected-component cleaning, and instance fixes.

n4_bias

n4_bias(
    nii: NII,
    threshold: int = 60,
    spline_param: int = 100,
    dtype2nii: bool = False,
    norm: int = -1,
) -> tuple[NII, NII]

Apply N4 bias field correction to a NIfTI image.

Builds a foreground mask by thresholding (and filling its bounding box), runs N4 correction restricted to that mask, optionally rescales the result to a target maximum and optionally casts back to the input dtype.

Parameters:

Name Type Description Default
nii NII

Input image to correct.

required
threshold int

Intensity threshold for the foreground mask; voxels below it are excluded. Defaults to 60.

60
spline_param int

Spline distance parameter passed to the N4 correction. Defaults to 100.

100
dtype2nii bool

If True, cast the corrected image back to the input image's dtype. Defaults to False.

False
norm int

If not -1, rescale the corrected image so its maximum equals this value. Defaults to -1.

-1

Returns:

Type Description
tuple[NII, NII]

The bias-corrected image and the binary foreground mask used for correction.

Source code in spineps/utils/proc_functions.py
def n4_bias(
    nii: NII,
    threshold: int = 60,
    spline_param: int = 100,
    dtype2nii: bool = False,
    norm: int = -1,
) -> tuple[NII, NII]:
    """Apply N4 bias field correction to a NIfTI image.

    Builds a foreground mask by thresholding (and filling its bounding box), runs N4 correction restricted to
    that mask, optionally rescales the result to a target maximum and optionally casts back to the input dtype.

    Args:
        nii (NII): Input image to correct.
        threshold (int, optional): Intensity threshold for the foreground mask; voxels below it are excluded.
            Defaults to 60.
        spline_param (int, optional): Spline distance parameter passed to the N4 correction. Defaults to 100.
        dtype2nii (bool, optional): If True, cast the corrected image back to the input image's dtype. Defaults
            to False.
        norm (int, optional): If not -1, rescale the corrected image so its maximum equals this value. Defaults
            to -1.

    Returns:
        tuple[NII, NII]: The bias-corrected image and the binary foreground mask used for correction.
    """
    from ants.utils.convert_nibabel import from_nibabel  # they keep renaming that thing. (version 0.4.2)

    # print("n4 bias", nii.dtype)
    mask = nii.get_array()
    mask[mask < threshold] = 0
    mask[mask != 0] = 1
    slices = np_bbox_binary(mask)
    mask[slices] = 1
    mask_nii = nii.set_array(mask)
    mask_nii.seg = True
    n4: NII = nii.n4_bias_field_correction(threshold=0, mask=from_nibabel(mask_nii.nii), spline_param=spline_param)
    if norm != -1:
        n4 *= norm / n4.max()
    if dtype2nii:
        n4.set_dtype_(nii.dtype)
    return n4, mask_nii
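
The mask construction and the optional rescaling described above can be illustrated on a plain NumPy array. This is a minimal sketch, not the library code: `foreground_mask` and `rescale_to_max` are hypothetical names, and the bounding-box step assumes that `np_bbox_binary` returns one slice per axis spanning the nonzero voxels.

```python
import numpy as np


def foreground_mask(arr: np.ndarray, threshold: float = 60) -> np.ndarray:
    """Binary mask: threshold, then fill the bounding box of the surviving voxels."""
    # Voxels below the threshold (and zero voxels) are excluded.
    mask = ((arr >= threshold) & (arr != 0)).astype(np.uint8)
    if not mask.any():
        return mask  # nothing above the threshold, so there is no box to fill
    # Per axis, take the first and last index that holds a foreground voxel
    # (assumed to match what np_bbox_binary returns).
    box = tuple(slice(int(idx.min()), int(idx.max()) + 1) for idx in np.nonzero(mask))
    mask[box] = 1
    return mask


def rescale_to_max(arr: np.ndarray, norm: float) -> np.ndarray:
    """Counterpart of the `norm` step: scale so that the maximum equals `norm`."""
    return arr * (norm / arr.max())


image = np.zeros((5, 5))
image[1, 1] = 100
image[3, 2] = 80
mask = foreground_mask(image)
scaled = rescale_to_max(image, 1.0)
```

Filling the bounding box means that dark voxels lying between bright structures are still included in the region handed to the N4 correction.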

clean_cc_artifacts

clean_cc_artifacts(
    mask: NII | ndarray,
    logger: Logger_Interface,
    labels: list[int] = [1, 2, 3],
    cc_size_threshold: int | list[int] = 100,
    neighbor_factor_2_delete: float = 0.1,
    verbose: bool = True,
    only_delete: bool = False,
    ignore_missing_labels: bool = False,
) -> np.ndarray

Clean small connected-component artifacts in a segmentation mask.

For each requested label, finds connected components below the size threshold and either deletes them or, if they border enough other foreground voxels, relabels them by majority vote of their dilated neighborhood.

Parameters:

Name Type Description Default
mask NII | ndarray

Input segmentation mask.

required
logger Logger_Interface

Logger for progress and cleaning reports.

required
labels list[int]

Labels to analyze. Defaults to [1, 2, 3].

[1, 2, 3]
cc_size_threshold int | list[int]

Minimum component size in voxels; a single value applies to all labels, or one value per label. Defaults to 100.

100
neighbor_factor_2_delete float

Fraction of neighboring foreground voxels below which a component is deleted instead of relabeled. Defaults to 0.1.

0.1
verbose bool

If True, log per-component details and show a progress bar. Defaults to True.

True
only_delete bool

If True, delete every analyzed component without majority-vote relabeling. Defaults to False.

False
ignore_missing_labels bool

If True, skip labels not present instead of asserting. Defaults to False.

False

Returns:

Type Description
ndarray

The cleaned segmentation array.

Raises:

Type Description
AssertionError

If requested labels are missing (when ignore_missing_labels is False) or the length of cc_size_threshold does not match the number of labels.

Source code in spineps/utils/proc_functions.py
def clean_cc_artifacts(
    mask: NII | np.ndarray,
    logger: Logger_Interface,
    labels: list[int] = [1, 2, 3],  # noqa: B006
    cc_size_threshold: int | list[int] = 100,
    neighbor_factor_2_delete: float = 0.1,
    verbose: bool = True,
    only_delete: bool = False,
    ignore_missing_labels: bool = False,
) -> np.ndarray:
    """Clean small connected-component artifacts in a segmentation mask.

    For each requested label, finds connected components below the size threshold and either deletes them or, if
    they border enough other foreground voxels, relabels them by majority vote of their dilated neighborhood.

    Args:
        mask (NII | np.ndarray): Input segmentation mask.
        logger (Logger_Interface): Logger for progress and cleaning reports.
        labels (list[int], optional): Labels to analyze. Defaults to [1, 2, 3].
        cc_size_threshold (int | list[int], optional): Minimum component size in voxels; a single value applies to
            all labels, or one value per label. Defaults to 100.
        neighbor_factor_2_delete (float, optional): Fraction of neighboring foreground voxels below which a
            component is deleted instead of relabeled. Defaults to 0.1.
        verbose (bool, optional): If True, log per-component details and show a progress bar. Defaults to True.
        only_delete (bool, optional): If True, delete every analyzed component without majority-vote relabeling.
            Defaults to False.
        ignore_missing_labels (bool, optional): If True, skip labels not present instead of asserting. Defaults to
            False.

    Returns:
        np.ndarray: The cleaned segmentation array.

    Raises:
        AssertionError: If requested labels are missing (when ``ignore_missing_labels`` is False) or the length of
            ``cc_size_threshold`` does not match the number of labels.
    """
    mask_arr = mask.get_seg_array() if isinstance(mask, NII) else mask.copy()
    result_arr = mask_arr.copy()

    mask_labels = np_unique(result_arr)
    if 0 not in mask_labels:
        logger.print("No zero in mask? Skip cleaning")
        return mask_arr

    if not ignore_missing_labels:
        assert np.all([l in mask_labels for l in labels]), (
            f"specified labels not found in mask, got labels {labels} and mask labels {mask_labels}"
        )
    else:
        labelsnew = []
        sizes = []
        for idx, l in enumerate(labels):
            if l in mask_labels:
                labelsnew.append(l)
                sizes.append(cc_size_threshold[idx] if isinstance(cc_size_threshold, list) else cc_size_threshold)
        labels = labelsnew
        cc_size_threshold = sizes

    if not isinstance(cc_size_threshold, list):
        cc_size_threshold = [cc_size_threshold for i in range(len(labels))]
    assert len(cc_size_threshold) == len(labels), (
        f"cc_size_threshold size does not match number of given labels to clean, got {len(labels)} and {len(cc_size_threshold)}. Specify only an int for cc_size_threshold to use it for all labels"
    )

    subreg_cc, subreg_cc_stats = connected_components_3d(result_arr, connectivity=1)

    cc_to_clean = {}
    for lidx, label in enumerate(tqdm(labels, desc=f"{logger._get_logger_prefix()} cleaning...", disable=not verbose)):
        # print(l, subreg_cc_stats[l]["voxel_counts"])
        idx = [i for i, v in enumerate(subreg_cc_stats[label]["voxel_counts"]) if v < cc_size_threshold[lidx] and v > 0]
        if len(idx) > 0:
            cc_to_clean[label] = idx

        for cc_idx in idx:
            # extract cc label
            mask_cc = subreg_cc[label]
            mask_cc_l = mask_cc.copy()
            mask_cc_l[mask_cc_l != cc_idx] = 0
            log_string = ""
            if verbose:
                cc_volume = np_count_nonzero(mask_cc_l)
                cc_centroid = center_of_mass(mask_cc_l)
                cc_centroid = [int(c) + 1 for c in cc_centroid]  # type: ignore
                log_string = f"Label {label}, cc{cc_idx}, at {cc_centroid}, volume {cc_volume}: "
            if only_delete:
                logger.print(log_string + "deleted") if verbose else None
                # dilated mask nothing in original mask, just delete it
                result_arr[mask_cc_l != 0] = 0
                continue
            dilated_m = np_dilate_msk(mask_cc_l, n_pixel=1)
            dilated_m[mask_cc_l != 0] = 0
            neighbor_voxel_count = np_count_nonzero(dilated_m)
            # print(subreg_cc_stats[label])

            mult = mask_arr * dilated_m
            if np_count_nonzero(mult) <= int(neighbor_voxel_count * neighbor_factor_2_delete):
                logger.print(log_string + "deleted") if verbose else None
                # dilated mask nothing in original mask, just delete it
                result_arr[mask_cc_l != 0] = 0
            else:
                # majority voting
                dilated_m[dilated_m != 0] = 1
                mult = mask_arr * dilated_m
                volumes = np_volume(mult)
                nlabels = list(volumes.keys())
                volumes_values = list(volumes.values())
                newlabel = nlabels[np.argmax(volumes_values)]  # type: ignore
                result_arr[mask_cc_l != 0] = newlabel
                logger.print(log_string + f"labeled as {newlabel}") if verbose else None
                # print(labels, count)
    n_to_clean = {k: len(v) for k, v in cc_to_clean.items()}
    # Cleaning: look at the surrounding neighbor voxels. If too few are labeled, remove the cc; otherwise, do majority voting
    if len(n_to_clean) != 0:
        logger.print(f"Cleaned (label, n_components) {n_to_clean}")
    return result_arr
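
The delete-or-relabel decision for a single small component can be reduced to a few lines. The sketch below is a simplification under stated assumptions: `resolve_small_component` is a hypothetical helper, and it receives the labels of the one-voxel shell around the component as a flat list instead of working on dilated arrays.

```python
from collections import Counter


def resolve_small_component(shell_labels: list[int], neighbor_factor_2_delete: float = 0.1) -> int:
    """Return the label a small component should receive (0 means delete).

    shell_labels holds the label of every voxel in the one-voxel shell
    around the component, with 0 for background.
    """
    foreground = [label for label in shell_labels if label != 0]
    # Too few labeled neighbors: treat the component as an isolated artifact.
    if len(foreground) <= int(len(shell_labels) * neighbor_factor_2_delete):
        return 0
    # Otherwise adopt the most common neighboring label.
    return Counter(foreground).most_common(1)[0][0]


isolated = resolve_small_component([0] * 19 + [2])
embedded = resolve_small_component([0, 0, 1, 2, 2, 2])
```

Here `isolated` is deleted because only one of twenty shell voxels is labeled, while `embedded` takes label 2 by majority vote.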

connected_components_3d

connected_components_3d(
    mask_image: ndarray,
    connectivity: int = 3,
    verbose: bool = False,
) -> tuple[dict, dict]

Compute 3D connected components per label together with their statistics.

Parameters:

Name Type Description Default
mask_image ndarray

Input (multi-label) mask.

required
connectivity int

Voxel connectivity in range [1, 3]. For 2D images 2 and 3 are equivalent. Defaults to 3.

3
verbose bool

Currently unused. Defaults to False.

False

Returns:

Type Description
tuple[dict, dict]

A dict mapping each label to its connected-component array, and a dict mapping each label to its cc3d component statistics.

Source code in spineps/utils/proc_functions.py
def connected_components_3d(mask_image: np.ndarray, connectivity: int = 3, verbose: bool = False) -> tuple[dict, dict]:  # noqa: ARG001
    """Compute 3D connected components per label together with their statistics.

    Args:
        mask_image (np.ndarray): Input (multi-label) mask.
        connectivity (int, optional): Voxel connectivity in range [1, 3]. For 2D images 2 and 3 are equivalent.
            Defaults to 3.
        verbose (bool, optional): Currently unused. Defaults to False.

    Returns:
        tuple[dict, dict]: A dict mapping each label to its connected-component array, and a dict mapping each
        label to its ``cc3d`` component statistics.
    """
    subreg_cc = np_connected_components_per_label(
        mask_image,
        connectivity=connectivity,
    )
    subreg_cc_stats = {k: cc3d.statistics(v) for k, v in subreg_cc.items()}
    return subreg_cc, subreg_cc_stats
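
To make the shape of the two returned dicts concrete, here is a dependency-free 2D sketch of per-label component labeling. It is an illustration only: the real function delegates to `np_connected_components_per_label` and `cc3d.statistics`, the helper names below are hypothetical, and the sketch assumes that connectivity 1 means face neighbors only and that index 0 of the voxel counts refers to the background.

```python
from collections import deque


def components_per_label(grid: list[list[int]]) -> dict[int, list[list[int]]]:
    """Label face-connected components separately for every nonzero label."""
    rows, cols = len(grid), len(grid[0])
    result: dict[int, list[list[int]]] = {}
    counters: dict[int, int] = {}
    for r in range(rows):
        for c in range(cols):
            label = grid[r][c]
            if label == 0:
                continue
            out = result.setdefault(label, [[0] * cols for _ in range(rows)])
            if out[r][c] != 0:
                continue  # already assigned to a component
            counters[label] = counters.get(label, 0) + 1
            comp_id = counters[label]
            out[r][c] = comp_id
            queue = deque([(r, c)])
            while queue:  # breadth-first flood fill over face neighbors
                y, x = queue.popleft()
                for ny, nx in ((y - 1, x), (y + 1, x), (y, x - 1), (y, x + 1)):
                    if 0 <= ny < rows and 0 <= nx < cols and grid[ny][nx] == label and out[ny][nx] == 0:
                        out[ny][nx] = comp_id
                        queue.append((ny, nx))
    return result


def voxel_counts(component_map: list[list[int]]) -> list[int]:
    """Size of every component id; index 0 counts the background."""
    flat = [v for row in component_map for v in row]
    return [flat.count(i) for i in range(max(flat) + 1)]


grid = [
    [1, 1, 0, 2],
    [0, 0, 0, 2],
    [1, 0, 0, 0],
]
cc = components_per_label(grid)
stats = {label: voxel_counts(comp) for label, comp in cc.items()}
```

Label 1 splits into two components (sizes 2 and 1), which is the situation `clean_cc_artifacts` looks for when it compares component sizes against `cc_size_threshold`.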

fix_wrong_posterior_instance_label

fix_wrong_posterior_instance_label(
    seg_sem: NII, seg_inst: NII, logger: Logger_Interface
) -> NII

Reassign misattributed posterior vertebra fragments to the correct instance label.

For every vertebra instance that splits into multiple connected components, each extra component consisting only of posterior elements (arcus vertebrae and/or spinous process) is relabeled to the single neighboring instance it touches, if any. Operates on copies and restores the original orientation before returning.

Parameters:

Name Type Description Default
seg_sem NII

Semantic segmentation (subregion labels) aligned with seg_inst.

required
seg_inst NII

Vertebra instance segmentation to correct.

required
logger Logger_Interface

Logger used to report each relabeling decision.

required

Returns:

Name Type Description
NII NII

The corrected instance segmentation in the original orientation.

Raises:

Type Description
AssertionError

If seg_sem and seg_inst do not share the same affine.

Source code in spineps/utils/proc_functions.py
def fix_wrong_posterior_instance_label(seg_sem: NII, seg_inst: NII, logger: Logger_Interface) -> NII:
    """Reassign misattributed posterior vertebra fragments to the correct instance label.

    For every vertebra instance that splits into multiple connected components, each extra component consisting
    only of posterior elements (arcus vertebrae and/or spinous process) is relabeled to the single neighboring
    instance it touches, if any. Operates on copies and restores the original orientation before returning.

    Args:
        seg_sem (NII): Semantic segmentation (subregion labels) aligned with ``seg_inst``.
        seg_inst (NII): Vertebra instance segmentation to correct.
        logger: Logger used to report each relabeling decision.

    Returns:
        NII: The corrected instance segmentation in the original orientation.

    Raises:
        AssertionError: If ``seg_sem`` and ``seg_inst`` do not share the same affine.
    """
    seg_sem = seg_sem.copy()
    seg_inst = seg_inst.copy()
    orientation = seg_sem.orientation
    seg_sem.assert_affine(other=seg_inst)
    seg_sem.reorient_()
    seg_inst.reorient_()

    seg_inst_arr_proc = seg_inst.get_seg_array()

    instance_labels = [i for i in seg_inst.unique() if 1 <= i <= MAX_VERTEBRA_INSTANCE_LABEL]

    for vert in instance_labels:
        inst_vert = seg_inst.extract_label(vert)
        # sem_vert = seg_sem.apply_mask(inst_vert)

        # Check if multiple CC exist
        inst_vert_cc: NII = inst_vert.filter_connected_components(max_count_component=3, keep_label=False)
        inst_vert_cc_n = int(inst_vert_cc.max())
        if inst_vert_cc_n == 1:
            continue
        #
        # inst_vert_cc is labeled 1 to 3
        for i in range(2, inst_vert_cc_n + 1):
            inst_vert_cc_i = inst_vert_cc.extract_label(i)

            crop = inst_vert_cc_i.compute_crop(dist=1)
            inst_vert_cc_i_c = inst_vert_cc_i.apply_crop(crop)

            cc_sem_vert = seg_sem.apply_crop(crop).apply_mask(inst_vert_cc_i_c)
            # cc_vert is semantic mask of only that cc of instance

            cc_sem_vert_labels = cc_sem_vert.unique()
            # is that cc only arcus and spinosus?
            if len(cc_sem_vert_labels) <= 2 and np.all(
                [i in [Location.Arcus_Vertebrae.value, Location.Spinosus_Process.value] for i in cc_sem_vert_labels]
            ):
                # neighbor that have non arcus/spinosus label?
                neighbor_instance_labels = seg_inst.apply_crop(crop).get_seg_array()
                neighbor_instance_labels[inst_vert_cc_i_c.get_seg_array() == 1] = 0
                neighbor_instance_labels = np_unique_withoutzero(neighbor_instance_labels)
                # which instance labels does it touch
                logger.print(f"vert {vert}, cc_k {i} has instance neighbors {neighbor_instance_labels}")
                # is it touching only one other instance label?
                if len(neighbor_instance_labels) == 1 and neighbor_instance_labels[0] != vert:
                    to_label = neighbor_instance_labels[0]
                    logger.print(f"vert {vert}, cc_k {i} relabel to instance {to_label}")
                    seg_inst_arr_proc[inst_vert_cc_i.get_seg_array() == 1] = to_label

    seg_inst_proc = seg_inst.set_array(seg_inst_arr_proc).reorient_(orientation)
    return seg_inst_proc
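
The per-fragment decision can be summarized as a small predicate over label sets. This is a sketch under assumptions: `relabel_target` is a hypothetical helper, and the numeric values used for the arcus and spinous-process labels are placeholders standing in for `Location.Arcus_Vertebrae.value` and `Location.Spinosus_Process.value`.

```python
from __future__ import annotations

# Placeholder values (assumption) for the two posterior subregion labels.
ARCUS, SPINOSUS = 41, 42


def relabel_target(
    fragment_sem_labels: set[int],
    neighbor_instances: set[int],
    vert: int,
    posterior_labels: frozenset[int] = frozenset({ARCUS, SPINOSUS}),
) -> int | None:
    """Return the instance a fragment should move to, or None to leave it unchanged."""
    # The fragment must consist of posterior elements only.
    if not fragment_sem_labels <= posterior_labels:
        return None
    # It must have exactly one neighboring instance, and that instance must differ from its own.
    others = neighbor_instances - {0}
    if len(others) == 1:
        (target,) = others
        if target != vert:
            return target
    return None


moved = relabel_target({ARCUS}, {0, 21}, vert=20)
kept_mixed = relabel_target({ARCUS, 49}, {21}, vert=20)
kept_ambiguous = relabel_target({SPINOSUS}, {21, 22}, vert=20)
```

Requiring exactly one neighboring instance keeps the fix conservative: a fragment wedged between two vertebrae is left untouched rather than assigned by guesswork.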

spineps.utils.find_min_cost_path

Min-cost path solver that assigns the most probable vertebra label sequence from a per-vertebra cost matrix.

argmin

argmin(lst: list) -> tuple[int, ...]

Return the index and value of the smallest element in a list.

Parameters:

Name Type Description Default
lst list

A non-empty sequence supporting min and index.

required

Returns:

Name Type Description
tuple tuple[int, ...]

(index, value) of the minimum element.

Source code in spineps/utils/find_min_cost_path.py
def argmin(lst: list) -> tuple[int, ...]:
    """Return the index and value of the smallest element in a list.

    Args:
        lst: A non-empty sequence supporting ``min`` and ``index``.

    Returns:
        tuple: ``(index, value)`` of the minimum element.
    """
    m = min(lst)
    return lst.index(m), m
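
A short usage sketch, with the function body repeated so the snippet runs on its own. The input values are made up for illustration.

```python
def argmin(lst: list) -> tuple:
    """Same logic as above: index and value of the smallest element."""
    m = min(lst)
    return lst.index(m), m


costs = [0.7, 0.2, 0.9, 0.2]
index, value = argmin(costs)
```

Because `list.index` returns the first match, ties resolve to the lowest index: here `index` is 1 even though position 3 holds the same value.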

softmax_T

softmax_T(x, temp)

Compute temperature-scaled softmax values for each set of scores in x, normalized along axis 0.

Source code in spineps/utils/find_min_cost_path.py
def softmax_T(x, temp):
    """Compute softmax values for each set of scores in x."""
    return np.exp(np.divide(x, temp)) / np.sum(np.exp(np.divide(x, temp)), axis=0)
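
The effect of the temperature is easiest to see on a tiny example. The sketch below repeats the function so it is self-contained; the score values are made up. Note that the normalization runs along axis 0, so each column of the result sums to 1.

```python
import numpy as np


def softmax_T(x, temp):
    """Temperature-scaled softmax along axis 0 (same formula as above)."""
    scaled = np.divide(x, temp)
    return np.exp(scaled) / np.sum(np.exp(scaled), axis=0)


scores = np.array([[1.0, 2.0], [3.0, 2.0]])
sharp = softmax_T(scores, 0.5)  # low temperature: close to a hard maximum
flat = softmax_T(scores, 10.0)  # high temperature: close to uniform
```

For large inputs, subtracting the column maximum before exponentiating is a common way to avoid overflow; the function shown in the source does not do this.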

c_to_region_idx

c_to_region_idx(c: int, regions: list[int]) -> int

Map a class index to the index of the spinal region it falls into.

Parameters:

Name Type Description Default
c int

Class (label) index along the cost matrix's class axis.

required
regions list[int]

Sorted region start indices (e.g. cervical, thoracic, lumbar).

required

Returns:

Name Type Description
int int

Index of the region containing class c.

Source code in spineps/utils/find_min_cost_path.py
def c_to_region_idx(c: int, regions: list[int]) -> int:
    """Map a class index to the index of the spinal region it falls into.

    Args:
        c (int): Class (label) index along the cost matrix's class axis.
        regions (list[int]): Sorted region start indices (e.g. cervical, thoracic, lumbar).

    Returns:
        int: Index of the region containing class ``c``.
    """
    for idx, r in enumerate(regions):
        if c < r:
            return idx - 1
    return len(regions) - 1
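
A usage sketch with the function repeated for self-containment. The region starts `[0, 7, 19]` are an assumption for illustration (seven cervical classes, then twelve thoracic, then lumbar); the actual `DEFAULT_REGION_STARTS` value is not shown on this page.

```python
def c_to_region_idx(c: int, regions: list[int]) -> int:
    """Same logic as above: index of the region whose range contains c."""
    for idx, r in enumerate(regions):
        if c < r:
            return idx - 1
    return len(regions) - 1


ASSUMED_REGION_STARTS = [0, 7, 19]  # hypothetical: cervical, thoracic, lumbar
region_of = {c: c_to_region_idx(c, ASSUMED_REGION_STARTS) for c in (0, 6, 7, 18, 19, 24)}
before_first = c_to_region_idx(-1, ASSUMED_REGION_STARTS)
```

A class below the first region start maps to -1, which is what the solver's start sentinel `c == -1` produces.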

internal_to_real_path

internal_to_real_path(p: list) -> list

Convert an internal (row, class) path into the ordered list of class indices.

Parameters:

Name Type Description Default
p list

Iterable of (row, class) tuples representing path nodes.

required

Returns:

Name Type Description
list list

Class indices ordered by ascending row (vertebra) index.

Source code in spineps/utils/find_min_cost_path.py
def internal_to_real_path(p: list) -> list:
    """Convert an internal ``(row, class)`` path into the ordered list of class indices.

    Args:
        p: Iterable of ``(row, class)`` tuples representing path nodes.

    Returns:
        list: Class indices ordered by ascending row (vertebra) index.
    """
    pat = sorted(p, key=lambda x: x[0])
    pat = [x[1] for x in pat]
    return pat
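
A usage sketch with the function repeated so it runs standalone. The solver builds its paths from the last row backwards, so the internal list typically arrives in descending row order; the tuples below are made-up values.

```python
def internal_to_real_path(p: list) -> list:
    """Same logic as above: sort nodes by row, then keep only the class index."""
    ordered = sorted(p, key=lambda node: node[0])
    return [node[1] for node in ordered]


internal_path = [(2, 9), (1, 8), (0, 7)]  # (row, class), last row first
labels = internal_to_real_path(internal_path)
```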

find_most_probably_sequence

find_most_probably_sequence(
    cost: ndarray | list[int],
    min_start_class: int = 0,
    region_rel_cost: ndarray | list[int] | None = None,
    vertt13_cost: ndarray | list[int] | None = None,
    regions: list[int] | None = None,
    invert_cost: bool = True,
    softmax_cost: bool = False,
    softmax_temp: float = DEFAULT_SOFTMAX_TEMP,
    allow_multiple_at_class: list[int] | None = None,
    punish_multiple_sequence: float = 0.0,
    allow_skip_at_class: list[int] | None = None,
    punish_skip_sequence: float = 0.0,
    allow_skip_at_region: list[int] | None = None,
    punish_skip_at_region_sequence: float = 0.2,
    verbose: bool = False,
) -> tuple[float, list[int], list]

Find the most probable vertebra-label sequence as a min-cost monotone path through a cost matrix.

Each matrix row corresponds to a detected vertebra (top to bottom) and each column to a candidate label class. The path moves one row down per step, normally advancing one class (diagonal). Special constraints model spinal anatomy: certain transitional classes (e.g. T12, L5) may repeat, certain classes/regions allow a single skip, and optional region- and T13-related transition costs adjust the path. Extra moves incur the configured penalties; classes flagged as repeatable may appear at most MAX_REPEATS_PER_CLASS times.

Parameters:

Name Type Description Default
cost ndarray | list[int]

2D cost matrix of shape (n_vertebrae, n_classes).

required
min_start_class int

Smallest class index the path may start at. Defaults to 0.

0
region_rel_cost ndarray | list[int] | None

Per-vertebra costs for being the first/last vertebra of each region; enables region-transition costs when given. Defaults to None.

None
vertt13_cost ndarray | list[int] | None

Per-vertebra cost contribution for the T13/T12 (class 18) repeat case. Defaults to None.

None
regions list[int] | None

Region start indices along the class axis. Defaults to DEFAULT_REGION_STARTS.

None
invert_cost bool

Negate the cost so that high input scores are preferred. Defaults to True.

True
softmax_cost bool

Apply a softmax over the cost columns (deprecated path). Defaults to False.

False
softmax_temp float

Temperature for the softmax. Defaults to DEFAULT_SOFTMAX_TEMP.

DEFAULT_SOFTMAX_TEMP
allow_multiple_at_class list[int] | None

Classes allowed to repeat (e.g. T12 and L5). Defaults to [T12_CLASS_IDX, L5_CLASS_IDX].

None
punish_multiple_sequence float

Extra cost added for repeating a class. Defaults to 0.0.

0.0
allow_skip_at_class list[int] | None

Classes after which a single class may be skipped (e.g. T11). Defaults to [T11_CLASS_IDX].

None
punish_skip_sequence float

Extra cost added for a class-level skip. Defaults to 0.0.

0.0
allow_skip_at_region list[int] | None

Regions in which a single skip is permitted. Defaults to [0].

None
punish_skip_at_region_sequence float

Extra cost added for a region-level skip. Defaults to 0.2.

0.2
verbose bool

Enable verbose logging of the recursion. Defaults to False.

False

Returns:

Type Description
tuple[float, list[int], list]

The total path cost, the chosen class index per vertebra (top to bottom), and the internal memoization table of best (cost, path) per (row, class) cell.

Raises:

Type Description
AssertionError

If min_start_class is not less than the number of classes, or if a provided region_rel_cost does not have the expected number of columns.
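
The core idea, a memoized search for the cheapest monotone path, can be shown in a much smaller form. The following is a simplified sketch, not the library implementation: `best_label_path` is a hypothetical function that supports only the diagonal move and repeats at selected classes, and it omits skips, region and T13 costs, penalties, and the limit on repeats.

```python
from functools import lru_cache


def best_label_path(cost: list[list[float]], repeatable: frozenset = frozenset()) -> tuple[float, list[int]]:
    """Cheapest path taking one class per row, advancing one class per row
    (or staying on the same class where that class is repeatable)."""
    n_rows, n_classes = len(cost), len(cost[0])

    @lru_cache(maxsize=None)
    def solve(r: int, c: int) -> tuple:
        if c >= n_classes:
            return float("inf"), ()  # stepped past the last class
        if r == n_rows - 1:
            return cost[r][c], (c,)  # last row ends the path
        moves = [solve(r + 1, c + 1)]  # normal diagonal step
        if c in repeatable:
            moves.append(solve(r + 1, c))  # same class again
        tail_cost, tail = min(moves, key=lambda m: m[0])
        return cost[r][c] + tail_cost, (c, *tail)

    total, path = min((solve(0, c) for c in range(n_classes)), key=lambda m: m[0])
    return total, list(path)


# Scores where higher is better are negated first, as invert_cost=True does.
scores = [[0.9, 0.1, 0.0, 0.0], [0.1, 0.8, 0.1, 0.0], [0.0, 0.2, 0.7, 0.1]]
total, path = best_label_path([[-s for s in row] for row in scores])

scores_rep = [[0.9, 0.0, 0.0], [0.8, 0.1, 0.0], [0.0, 0.9, 0.0]]
total_rep, path_rep = best_label_path([[-s for s in row] for row in scores_rep], frozenset({0}))
```

In the second call, allowing class 0 to repeat lets the path stay on the strong class for two rows before advancing, which mirrors how the full solver handles transitional vertebrae.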

Source code in spineps/utils/find_min_cost_path.py
def find_most_probably_sequence(  # noqa: C901
    cost: np.ndarray | list[int],
    #
    min_start_class: int = 0,
    region_rel_cost: np.ndarray | list[int] | None = None,
    vertt13_cost: np.ndarray | list[int] | None = None,
    regions: list[int] | None = None,
    #
    invert_cost: bool = True,
    #
    softmax_cost: bool = False,
    softmax_temp: float = DEFAULT_SOFTMAX_TEMP,
    #
    allow_multiple_at_class: list[int] | None = None,  # T12 and L5
    punish_multiple_sequence: float = 0.0,
    #
    allow_skip_at_class: list[int] | None = None,  # T11
    punish_skip_sequence: float = 0.0,
    #
    allow_skip_at_region: list[int] | None = None,
    punish_skip_at_region_sequence: float = 0.2,
    #
    verbose: bool = False,
) -> tuple[float, list[int], list]:
    """Find the most probable vertebra-label sequence as a min-cost monotone path through a cost matrix.

    Each matrix row corresponds to a detected vertebra (top to bottom) and each column to a candidate label
    class. The path moves one row down per step, normally advancing one class (diagonal). Special constraints
    model spinal anatomy: certain transitional classes (e.g. T12, L5) may repeat, certain classes/regions allow
    a single skip, and optional region- and T13-related transition costs adjust the path. Extra moves incur the
    configured penalties; classes flagged as repeatable may appear at most ``MAX_REPEATS_PER_CLASS`` times.

    Args:
        cost (np.ndarray | list[int]): 2D cost matrix of shape ``(n_vertebrae, n_classes)``.
        min_start_class (int, optional): Smallest class index the path may start at. Defaults to 0.
        region_rel_cost (np.ndarray | list[int] | None, optional): Per-vertebra costs for being the first/last
            vertebra of each region; enables region-transition costs when given. Defaults to None.
        vertt13_cost (np.ndarray | list[int] | None, optional): Per-vertebra cost contribution for the T13/T12
            (class 18) repeat case. Defaults to None.
        regions (list[int] | None, optional): Region start indices along the class axis. Defaults to
            ``DEFAULT_REGION_STARTS``.
        invert_cost (bool, optional): Negate the cost so that high input scores are preferred. Defaults to True.
        softmax_cost (bool, optional): Apply a softmax over the cost columns (deprecated path). Defaults to False.
        softmax_temp (float, optional): Temperature for the softmax. Defaults to ``DEFAULT_SOFTMAX_TEMP``.
        allow_multiple_at_class (list[int] | None, optional): Classes allowed to repeat (e.g. T12 and L5).
            Defaults to ``[T12_CLASS_IDX, L5_CLASS_IDX]``.
        punish_multiple_sequence (float, optional): Extra cost added for repeating a class. Defaults to 0.0.
        allow_skip_at_class (list[int] | None, optional): Classes after which a single class may be skipped (e.g.
            T11). Defaults to ``[T11_CLASS_IDX]``.
        punish_skip_sequence (float, optional): Extra cost added for a class-level skip. Defaults to 0.0.
        allow_skip_at_region (list[int] | None, optional): Regions in which a single skip is permitted. Defaults
            to ``[0]``.
        punish_skip_at_region_sequence (float, optional): Extra cost added for a region-level skip. Defaults to 0.2.
        verbose (bool, optional): Enable verbose logging of the recursion. Defaults to False.

    Returns:
        tuple[float, list[int], list]: The total path cost, the chosen class index per vertebra (top to bottom),
        and the internal memoization table of best ``(cost, path)`` per ``(row, class)`` cell.

    Raises:
        AssertionError: If ``min_start_class`` is not less than the number of classes, or if a provided
            ``region_rel_cost`` does not have the expected number of columns.
    """
    logger = No_Logger()
    logger.default_verbose = verbose
    # default mutable arguments
    if allow_skip_at_region is None:
        allow_skip_at_region = [0]
    if allow_skip_at_class is None:
        allow_skip_at_class = [T11_CLASS_IDX]
    if allow_multiple_at_class is None:
        allow_multiple_at_class = [T12_CLASS_IDX, L5_CLASS_IDX]
    if regions is None:
        regions = list(DEFAULT_REGION_STARTS)
    # convert to np arrays
    if isinstance(cost, list):
        cost = np.asarray(cost)
    if region_rel_cost is not None and isinstance(region_rel_cost, list):
        region_rel_cost = np.asanyarray(region_rel_cost)
    if vertt13_cost is not None and isinstance(vertt13_cost, list):
        vertt13_cost = np.asanyarray(vertt13_cost)
    # safety assert
    assert isinstance(cost, np.ndarray)
    shape = cost.shape

    # define regions
    n_classes = shape[1]
    assert min_start_class < n_classes
    regions_ranges = None
    if region_rel_cost is not None:
        if n_classes < regions[-1]:
            warn(f"n_classes < defined regions, got {n_classes} and {regions}", stacklevel=3)
        regions = [*regions, n_classes]  # extend a copy so a caller-supplied list is not mutated
        regions_ranges = [(regions[i], regions[i + 1] - 1) for i in range(len(regions) - 1)]
        region_rel_shape = region_rel_cost.shape
        assert region_rel_shape[1] == ((len(regions) - 1) * 2), (
            f"expected region_rel_cost with shape {((len(regions) - 1) * 2)}, but got {region_rel_shape[1]}"
        )

    # softmax (deprecated, handled elsewhere)
    if softmax_cost:
        cost = softmax_T(cost, softmax_temp)
    # invert cost so high numbers are actually preferred instead of repelled
    if invert_cost:
        cost = -cost

    # make costs a list
    costlist = cost.tolist()
    # init memory
    min_costs_path = [[(None, None) for y in range(shape[1])] for x in range(shape[0])]

    # Adds edges with an extra cost beyond the cost matrix
    def add_option_path(options, r, c, extracost):
        options.append(minCostAlgo(r, c))
        options[-1] = (
            options[-1][0] + extracost,
            options[-1][1],
        )
        return options

    # main recursive loop
    def minCostAlgo(r, c):
        logger.print(f"Called vert {r}, label {c}")
        # get current region
        region_cur = c_to_region_idx(c, regions)
        # start point
        if c == -1 and r == -1:
            # go over each possible start column
            options = []
            for cc in range(min_start_class, n_classes):
                with logger:
                    # logger.default_verbose = cc in [7, 8, 9]
                    add_option_path(options, 0, cc, 0)
                # options.append(minCostAlgo(r=0, c=cc))
            minidx, minval = argmin([o[0] for o in options])
            return minval, options[minidx][1]
        # stepped over the line
        elif c < 0 or r < 0 or c >= shape[1] or r >= shape[0]:
            logger.print(f"Out of bounds vert {r}, label {c}")
            return sys.maxsize, [(r, c)]
        # last row, path end
        elif r == shape[0] - 1:
            # logger.print(f"End of path vert {r}, label {c}")
            # path_tothis.append((r, c))
            cost_value = costlist[r][c]
            p = [(r, c)]
            # transition cost of vertrel
            cost_value += rel_cost(r, c, p, region_cur)
            if cost_value < 0:
                logger.print(f"End of path vert {r}, label {c} to {cost_value}, {internal_to_real_path(p)}")
            return (cost_value, p)
        # check min of move directions
        else:
            if min_costs_path[r][c][0] is not None:
                return min_costs_path[r][c]

            # rel_costadd = rel_cost(r, c, [(r, c)], region_cur)
            options = []
            # normal diagonal edge
            with logger:
                add_option_path(options, r + 1, c + 1, 0)
            # allow two subsequent of same class
            if c in allow_multiple_at_class:
                cost_add = punish_multiple_sequence
                if c == T12_CLASS_IDX:
                    cost_add += t13_cost_single(r + 1, c)
                with logger:
                    add_option_path(options, r + 1, c, cost_add)
            # Allow skips at certain classes
            if c in allow_skip_at_class:
                cost_add = punish_skip_sequence
                with logger:
                    add_option_path(options, r + 1, c + 2, cost_add)
            # Allow skips in certain regions
            if region_cur in allow_skip_at_region and c != regions_ranges[region_cur][1] - 1:
                cost_add = punish_skip_at_region_sequence
                with logger:
                    add_option_path(options, r + 1, c + 2, cost_add)
            # find min
            minidx, minval = argmin([o[0] for o in options])
            pnext = options[minidx][1]
            p = [*pnext, (r, c)]
            cnt = Counter([l[1] for l in p])
            #
            cost_value = minval + costlist[r][c]
            # transition cost of vertrel
            cost_value += rel_cost(r, c, pnext, region_cur)
            # constraint: cannot have more than 2 T12 and L5
            for amac in allow_multiple_at_class:
                if amac in cnt and cnt[amac] > MAX_REPEATS_PER_CLASS:
                    cost_value = sys.maxsize
                    break
            # setting to memory
            min_costs_path[r][c] = (cost_value, p)
            if cost_value < 0:
                logger.print(f"Setting vert {r}, label {c} to {cost_value}, {internal_to_real_path(p)}")
            return min_costs_path[r][c]

    # def t13_cost(r, c, pnext, p, region_cur):
    #    cost_add = 0
    #    if vertt13_cost is not None:
    #        vt13_cost = vertt13_cost[r][1]
    #        # print(r, c, p[-1][1], p[-2][1], internal_to_real_path(p))
    #        if p[-1][1] == 18 and p[-2][1] == 18:
    #            print(f"Added F {vt13_cost} to {r}, {c}, {internal_to_real_path(p)}")
    #            cost_add += vt13_cost
    #    return cost_add

    def t13_cost_single(r, c):
        cost_add = 0
        if vertt13_cost is not None:
            vt13_cost = vertt13_cost[r][1]
            if c == 18:
                # print(f"Added F {vt13_cost} to {r}, {c}")
                cost_add += vt13_cost
        return cost_add

    def rel_cost(r, c, pnext, region_cur):
        # transition cost of vertrel
        # first is just equal to that specific vertebra
        # last is dependent on next in path
        # classes are always first, last in order of regions
        cost_add = 0
        if region_rel_cost is not None:
            # for ridx in range(len(regions) - 1):
            for last in [0, 1]:
                if region_cur + last == 0:
                    continue
                region_cls = (region_cur * 2) + last  # 0 is nothing
                rel_cost = region_rel_cost[r][region_cls]
                if rel_cost == 0:
                    continue
                if last == 0 and c == regions_ranges[region_cur][0]:
                    logger.print(f"Added F {rel_cost} to vert {r}, label {c}, {internal_to_real_path(pnext)}")
                    cost_add += rel_cost
                    # break
                elif last == 1 and (c_to_region_idx(pnext[-1][1], regions) >= region_cur + 1):  # or pnext[-1][1] == c):
                    logger.print(f"Added L {rel_cost} to vert {r}, label {c}, {internal_to_real_path(pnext)}")
                    cost_add += rel_cost
        return cost_add

    fcost, fpath = minCostAlgo(-1, -1)
    fpath.reverse()
    fpath = [f[1] for f in fpath]
    return fcost, fpath, min_costs_path
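
The recursion above is easier to follow on a reduced problem. The following is a minimal sketch of the same memoised minimum-cost search, assuming a plain cost matrix and only three moves (advance one label, repeat a label, skip a label). The names `min_cost_labeling`, `repeat_penalty` and `skip_penalty` are hypothetical, and the region-relative and T13 costs of the real solver are deliberately left out.

```python
from functools import lru_cache


def min_cost_labeling(cost, repeat_penalty=1.0, skip_penalty=1.0):
    """Assign one label per row so that labels never decrease, minimising total cost.

    cost[r][c] is the cost of giving row r (one vertebra) the label c.
    From (r, c) the next row may take c + 1 (free), c (repeat) or c + 2 (skip).
    """
    n_rows, n_cols = len(cost), len(cost[0])

    @lru_cache(maxsize=None)
    def best(r, c):
        # Returns (total cost, labels for rows r..end) when row r gets label c.
        if r == n_rows - 1:
            return cost[r][c], (c,)
        options = []
        for step, penalty in ((1, 0.0), (0, repeat_penalty), (2, skip_penalty)):
            nc = c + step
            if nc < n_cols:
                sub_cost, sub_path = best(r + 1, nc)
                options.append((sub_cost + penalty, sub_path))
        sub_cost, sub_path = min(options)
        return cost[r][c] + sub_cost, (c, *sub_path)

    # Try every start label and keep the cheapest complete path.
    total, path = min(best(0, c) for c in range(n_cols))
    return total, list(path)


total, path = min_cost_labeling([[0, 5, 5], [5, 0, 5], [5, 5, 0]])
print(total, path)
```

Memoising on `(r, c)` plays the role of `min_costs_path` in the source: each cell is solved once and then reused by every path that passes through it.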

spineps.utils.generate_disc_labels

spineps.utils.generate_disc_labels

This script generates disc labels from the SPINEPS vertebrae segmentation.

Author: Nathan Molinier

get_parser

get_parser() -> argparse.ArgumentParser

Build the command-line argument parser for disc-label generation.

Returns:

Type Description
ArgumentParser

argparse.ArgumentParser: Parser accepting the input vertebrae label path and the optional output path.

Source code in spineps/utils/generate_disc_labels.py
def get_parser() -> argparse.ArgumentParser:
    """Build the command-line argument parser for disc-label generation.

    Returns:
        argparse.ArgumentParser: Parser accepting the input vertebrae label path and the optional output path.
    """
    # parse command line arguments
    parser = argparse.ArgumentParser(description="Generate discs labels from spineps' vertebrae segmentation.")
    parser.add_argument(
        "--path-vert",
        type=str,
        required=True,
        help='Path to the SPINEPS vertebrae labels. Example: "/<data_path>/sub-amuALT_T2w_label-vert_dseg.nii.gz" (Required)',
    )
    parser.add_argument(
        "--path-out",
        type=str,
        default="",
        help="Output path of the discs label. "
        'Example: "/<data_path>/sub-amuALT_T2w_label-discs_dlabel.nii.gz". '
        'By default, the structure "_label-discs_dlabel" will be used.',
    )
    return parser
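
As a usage illustration, the sketch below rebuilds a stand-in parser with the same two options and parses an explicit argument list. The stand-in and the example path are assumptions made so the snippet runs without SPINEPS installed; with the package available, `get_parser()` would be used instead.

```python
import argparse

# Hypothetical stand-in mirroring the two options defined in get_parser().
parser = argparse.ArgumentParser(description="Generate disc labels (sketch).")
parser.add_argument("--path-vert", type=str, required=True)
parser.add_argument("--path-out", type=str, default="")

# Passing a list to parse_args avoids reading sys.argv, which is handy in tests.
args = parser.parse_args(["--path-vert", "/data/sub-01_T2w_label-vert_dseg.nii.gz"])
print(args.path_vert)        # dashes in option names become underscores
print(repr(args.path_out))   # '' when --path-out is not given
```

An empty `path_out` is what makes `main()` fall back to `default_name_discs`.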

main

main()

Run the disc-label generation CLI.

Parses arguments, loads the SPINEPS vertebrae segmentation, derives single-voxel disc labels from it and writes the result to the chosen (or default) output path.

Source code in spineps/utils/generate_disc_labels.py
def main():
    """Run the disc-label generation CLI.

    Parses arguments, loads the SPINEPS vertebrae segmentation, derives single-voxel disc labels from it and
    writes the result to the chosen (or default) output path.
    """
    # Load parser
    parser = get_parser()
    args = parser.parse_args()

    # Fetch paths
    path_in = Path(args.path_vert).resolve()
    path_out = Path(args.path_out).resolve() if args.path_out else default_name_discs(path_in)

    # Check if output folder exists
    if not path_out.parent.exists():
        path_out.parent.mkdir(parents=True)

    # Extract discs labels
    vert_image = Image(str(path_in))
    print("-" * 80)
    print(f"Creating discs label using SPINEPS prediction: {path_in}")
    print("-" * 80)
    discs_nii_clean = extract_discs_label(vert_image, mapping=DISCS_MAP)

    # Save discs labels
    discs_nii_clean.save(str(path_out))
    print("-" * 80)
    print(f"Discs label: {path_out} was created.")
    print("-" * 80)

default_name_discs

default_name_discs(
    path_in: Path | str, suffix="_label-discs_dlabel"
) -> Path

Derive the default output path for disc labels by inserting a disc suffix before the file extension.

Parameters:

Name Type Description Default
path_in Path | str

Path to the input vertebrae label file (may include compound extensions like .nii.gz).

required
suffix str

Suffix inserted before the extension. Defaults to "_label-discs_dlabel".

'_label-discs_dlabel'

Returns:

Name Type Description
Path Path

The default output path with the disc suffix applied.

Source code in spineps/utils/generate_disc_labels.py
def default_name_discs(path_in: Path | str, suffix="_label-discs_dlabel") -> Path:
    """Derive the default output path for disc labels by inserting a disc suffix before the file extension.

    Args:
        path_in: Path to the input vertebrae label file (may include compound extensions like ``.nii.gz``).
        suffix (str, optional): Suffix inserted before the extension. Defaults to ``"_label-discs_dlabel"``.

    Returns:
        Path: The default output path with the disc suffix applied.
    """
    # Fetch suffixes
    path_obj = Path(path_in)
    ext = "".join(path_obj.suffixes)

    # Add suffix
    path_out = Path(str(path_in).replace(ext, suffix + ext))
    return path_out
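
To make the naming rule concrete, here is a small sketch of the same idea. `insert_suffix` is a hypothetical helper, not part of SPINEPS; it edits only the file name (rather than calling `str.replace` on the whole path) and uses `PurePosixPath` so the printed result is the same on every platform.

```python
from pathlib import PurePosixPath


def insert_suffix(path_in, suffix="_label-discs_dlabel"):
    """Insert `suffix` between the file stem and its (possibly compound) extension."""
    path = PurePosixPath(path_in)
    ext = "".join(path.suffixes)  # e.g. ".nii.gz"
    stem = path.name[: len(path.name) - len(ext)]
    return path.with_name(stem + suffix + ext)


print(insert_suffix("/data/sub-01_T2w_label-vert_dseg.nii.gz"))
# /data/sub-01_T2w_label-vert_dseg_label-discs_dlabel.nii.gz
```

Note that the existing `label-vert_dseg` part of the name is kept; the suffix is added in front of the extension rather than replacing anything.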

extract_discs_label

extract_discs_label(label: Image, mapping: dict) -> Image

Derive single-voxel disc labels from a vertebrae segmentation.

Remaps vertebra label values to disc values, locates each disc's posterior tip by shifting a centerline (interpolated through the disc centroids) posteriorly and picking the closest segmented voxel, inserts disc 2 between discs 1 and 3 when both are present, and writes one labeled voxel per disc into the image.

Parameters:

Name Type Description Default
label Image

Vertebrae segmentation image; its data is replaced in place with the disc labels.

required
mapping dict

Mapping from vertebra label values to disc label values.

required

Returns:

Name Type Description
Image Image

The image holding the disc labels, restored to its original orientation.

Source code in spineps/utils/generate_disc_labels.py
def extract_discs_label(label: Image, mapping: dict) -> Image:
    """Derive single-voxel disc labels from a vertebrae segmentation.

    Remaps vertebra label values to disc values, locates each disc's posterior tip by shifting a centerline
    (interpolated through the disc centroids) posteriorly and picking the closest segmented voxel, inserts disc 2
    between discs 1 and 3 when both are present, and writes one labeled voxel per disc into the image.

    Args:
        label (Image): Vertebrae segmentation image; its data is replaced in place with the disc labels.
        mapping (dict): Mapping from vertebra label values to disc label values.

    Returns:
        Image: The image holding the disc labels, restored to its original orientation.
    """
    # Store input orientation
    orig_orientation = label.orientation

    # Use RSP orientation
    label.change_orientation("RSP")

    # Extract only discs segmentations based on mapping
    data = label.data
    data_discs_seg = np.zeros_like(data)
    for seg_value, discs_value in mapping.items():
        data_discs_seg[np.where(data == seg_value)] = discs_value

    # Deal with disc 1 obtained with the first vertebrae (Highest vertical coordinate)
    if 1 in data_discs_seg:
        # If the first vertebrae is present identify label disc 1 at the top
        vert1_seg = np.array(np.where(data_discs_seg == 1))
        disc1_idx = np.argmin(vert1_seg[1])  # find min on the S-I axis
        disc1_coord = vert1_seg[:, disc1_idx]
        data_discs_seg[np.where(data_discs_seg == 1)] = 0
        data_discs_seg[disc1_coord[0], disc1_coord[1], disc1_coord[2]] = 1

    ## Identify the posterior tip of the disc
    # Extract the center of mass of every discs segmentation to create discs labels
    # Centroids are sorted based on the vertical axis
    discs_centroids, discs_bb = extract_centroids_3d(data_discs_seg)

    # Generate a centerline between the discs by doing linear interpolation
    yvals = np.linspace(discs_centroids[0, 1], discs_centroids[-1, 1], round(8 * len(discs_centroids)))
    xvals = np.interp(yvals, discs_centroids[:, 1], discs_centroids[:, 0])
    zvals = np.interp(yvals, discs_centroids[:, 1], discs_centroids[:, 2])
    centerline = np.concatenate((np.expand_dims(xvals, axis=1), np.expand_dims(yvals, axis=1), np.expand_dims(zvals, axis=1)), axis=1)

    # Shift the centerline to the posterior direction until there is no intersection with the
    # discs segmentations
    # Find the min coordinate of the discs segmentation on the A-P axis
    min_seg_ap = np.min(np.where(data_discs_seg > 0)[2])
    max_centroid_ap = np.max(discs_centroids[:, 2])
    offset = 5
    shift = (max_centroid_ap - min_seg_ap + offset) if min_seg_ap >= offset else (max_centroid_ap - min_seg_ap)

    centerline_shifted = np.copy(centerline)
    centerline_shifted[:, 2] = centerline_shifted[:, 2] - shift

    # For each segmented disc, identify the closest voxel to this shifted centerline
    discs_list = closest_point_seg_to_line(data_discs_seg, centerline_shifted, discs_bb)

    # Add disc 2 between disc 1 and 3
    if 1 in discs_list[:, -1] and 3 in discs_list[:, -1]:
        disc1_coord = discs_list[discs_list[:, -1] == 1]
        disc2_coord = discs_list[discs_list[:, -1] == 3]
        disc2_coord[0, 1] = (disc2_coord[0, 1] + disc1_coord[0, 1]) // 2
        disc2_coord[0, -1] = 2
        discs_list = np.insert(discs_list, 1, disc2_coord, axis=0)

    # Create output Image
    data_discs = np.zeros_like(data)
    for x, y, z, v in discs_list:
        data_discs[x, y, z] = v
    label.data = data_discs
    return label.change_orientation(orig_orientation)
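
The centerline step can be tried in isolation. The sketch below uses made-up centroid coordinates and a made-up shift of 6 voxels; only the interpolation pattern (`np.linspace` along axis 1, `np.interp` for the other two axes) follows the source above.

```python
import numpy as np

# Hypothetical disc centroids as (x, y, z) rows, already sorted along axis 1 (S-I).
centroids = np.array([[10, 0, 30], [12, 10, 34], [11, 20, 32]])

# Sample the S-I axis densely (8 samples per centroid, as in the source above).
yvals = np.linspace(centroids[0, 1], centroids[-1, 1], round(8 * len(centroids)))
# Interpolate the other two coordinates piecewise-linearly as functions of y.
xvals = np.interp(yvals, centroids[:, 1], centroids[:, 0])
zvals = np.interp(yvals, centroids[:, 1], centroids[:, 2])
centerline = np.stack((xvals, yvals, zvals), axis=1)

# Shift the whole line along the A-P axis (axis 2) by a fixed amount.
shift = 6
centerline_shifted = centerline.copy()
centerline_shifted[:, 2] -= shift

print(centerline.shape)  # (24, 3)
```

`np.interp` expects its sample positions to be increasing, which is why the centroids must be sorted along axis 1 before this step.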

extract_centroids_3d

extract_centroids_3d(
    arr: ndarray,
) -> tuple[np.ndarray, np.ndarray]

Extract connected-component centroids and bounding boxes from a 3D array, sorted along the vertical axis.

Parameters:

Name Type Description Default
arr ndarray

3D label array (assumed RSP orientation, so axis 1 is the superior-inferior axis).

required

Returns:

Type Description
tuple[ndarray, ndarray]

tuple[np.ndarray, np.ndarray]: Integer centroid coordinates and the matching bounding boxes, both sorted by the vertical (axis-1) coordinate, with the background component removed.

Source code in spineps/utils/generate_disc_labels.py
def extract_centroids_3d(arr: np.ndarray) -> tuple[np.ndarray, np.ndarray]:
    """Extract connected-component centroids and bounding boxes from a 3D array, sorted along the vertical axis.

    Args:
        arr (np.ndarray): 3D label array (assumed RSP orientation, so axis 1 is the superior-inferior axis).

    Returns:
        tuple[np.ndarray, np.ndarray]: Integer centroid coordinates and the matching bounding boxes, both sorted
        by the vertical (axis-1) coordinate, with the background component removed.
    """
    stats = cc3d.statistics(cc3d.connected_components(arr))
    centroids = stats["centroids"][1:]  # Remove background <0>
    bounding_boxes = stats["bounding_boxes"][1:]

    # Sort according to the vertical axis because RSP orientation
    sort_args = np.argsort(centroids[:, 1])

    centroids_sorted = centroids[sort_args]
    bb_sorted = np.array(bounding_boxes)[sort_args]
    return centroids_sorted.astype(int), bb_sorted

project_point_on_line

project_point_on_line(
    point: ndarray, line: ndarray
) -> tuple[np.ndarray, float]

Project a point onto a polyline by finding the closest line point.

Copied from https://github.com/spinalcordtoolbox/spinalcordtoolbox.

Parameters:

Name Type Description Default
point ndarray

Coordinates of the point, numpy.array([x, y, z]).

required
line ndarray

Coordinates of the points composing the line.

required

Returns:

Type Description
tuple[ndarray, float]

tuple[np.ndarray, float]: The closest point on the line and the squared distance to it.

Source code in spineps/utils/generate_disc_labels.py
def project_point_on_line(point: np.ndarray, line: np.ndarray) -> tuple[np.ndarray, float]:
    """Project a point onto a polyline by finding the closest line point.

    Copied from https://github.com/spinalcordtoolbox/spinalcordtoolbox.

    Args:
        point (np.ndarray): Coordinates of the point, ``numpy.array([x, y, z])``.
        line (np.ndarray): Coordinates of the points composing the line.

    Returns:
        tuple[np.ndarray, float]: The closest point on the line and the squared distance to it.
    """
    # Calculate distances between the referenced point and the line then keep the closest point
    dist = np.sum((line - point) ** 2, axis=1)

    return line[np.argmin(dist)], np.min(dist)
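
A worked example with small integer coordinates (chosen here purely for illustration) shows what the function computes. It selects the nearest sample point of the line, not an orthogonal projection onto the segments between samples.

```python
import numpy as np

line = np.array([[0, 0, 0], [1, 0, 0], [2, 0, 0]])
point = np.array([1, 2, 0])

# Broadcasting subtracts `point` from every row of `line`.
dist = np.sum((line - point) ** 2, axis=1)  # squared distances: [5, 4, 5]
closest = line[np.argmin(dist)]
print(closest, dist.min())
```

Because the distances are squared, they are suitable for comparing candidates but need a square root before being reported in voxel units.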

closest_point_seg_to_line

closest_point_seg_to_line(
    discs_seg: ndarray,
    centerline: ndarray,
    bounding_boxes: ndarray,
) -> np.ndarray

Find, per disc, the segmented voxel closest to a reference centerline.

Parameters:

Name Type Description Default
discs_seg ndarray

Disc-labeled segmentation array.

required
centerline ndarray

Coordinates of the points composing the reference line.

required
bounding_boxes ndarray

Bounding box (slice tuple) for each disc, used to isolate it.

required

Returns:

Type Description
ndarray

np.ndarray: Array of [x, y, z, disc_value] rows, one per disc, giving the closest voxel and its label.

Source code in spineps/utils/generate_disc_labels.py
def closest_point_seg_to_line(discs_seg: np.ndarray, centerline: np.ndarray, bounding_boxes: np.ndarray) -> np.ndarray:
    """Find, per disc, the segmented voxel closest to a reference centerline.

    Args:
        discs_seg (np.ndarray): Disc-labeled segmentation array.
        centerline (np.ndarray): Coordinates of the points composing the reference line.
        bounding_boxes (np.ndarray): Bounding box (slice tuple) for each disc, used to isolate it.

    Returns:
        np.ndarray: Array of ``[x, y, z, disc_value]`` rows, one per disc, giving the closest voxel and its label.
    """
    discs_list = []
    for x, y, z in bounding_boxes:
        zer = np.zeros_like(discs_seg)
        zer[x, y, z] = discs_seg[x, y, z]  # isolate disc
        # Loop on all the pixels of the segmentation
        min_dist = np.inf
        nonzero = np.where(zer > 0)
        for u, v, w in zip_strict(nonzero[0], nonzero[1], nonzero[2]):
            _, dist = project_point_on_line(np.array([u, v, w]), centerline)
            if dist < min_dist:
                min_dist = dist
                min_point = np.array([u, v, w, discs_seg[u, v, w]])
        discs_list.append(min_point)
    return np.array(discs_list)

spineps.utils.filepaths

spineps.utils.filepaths

File-path helpers for locating the SPINEPS model weights directory and individual model folders.

get_mri_segmentor_models_dir

get_mri_segmentor_models_dir() -> Path

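Return the path to the model weights directory. The override path is used if it is set and exists; otherwise the SPINEPS_SEGMENTOR_MODELS environment variable is read, and the backup path is used if that variable is not defined.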

Returns:

Name Type Description
Path Path

Path to the overall models folder

Source code in spineps/utils/filepaths.py
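def get_mri_segmentor_models_dir() -> Path:
    """Return the path to the model weights directory.

    The override path is used if it is set and exists; otherwise the SPINEPS_SEGMENTOR_MODELS environment
    variable is read, and the backup path is used if that variable is not defined.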

    Returns:
        Path: Path to the overall models folder
    """
    folder_path = (
        os.environ.get("SPINEPS_SEGMENTOR_MODELS")
        if spineps_environment_path_override is None or not spineps_environment_path_override.exists()
        else spineps_environment_path_override
    )
    if folder_path is None and spineps_environment_path_backup is not None:
        folder_path = spineps_environment_path_backup

    assert folder_path is not None, (
        "Environment variable 'SPINEPS_SEGMENTOR_MODELS' is not defined. Setup the environment variable as stated in the readme or set the override in utils.filepaths.py"
    )
    folder_path = Path(folder_path)
    assert folder_path.exists(), f"Environment variable 'SPINEPS_SEGMENTOR_MODELS' = {folder_path} does not exist"
    return folder_path
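
The lookup order can be sketched as a standalone function. `resolve_models_dir` and its parameters are hypothetical, the demo uses a deliberately unset variable name, and the sketch raises `FileNotFoundError` where the source uses `assert`.

```python
import os
import tempfile
from pathlib import Path


def resolve_models_dir(override=None, backup=None, env_var="SPINEPS_SEGMENTOR_MODELS"):
    """Pick the models folder: existing override, else env variable, else backup."""
    if override is not None and Path(override).exists():
        candidate = override
    else:
        candidate = os.environ.get(env_var)
    if candidate is None and backup is not None:
        candidate = backup
    if candidate is None:
        raise FileNotFoundError(f"Environment variable {env_var!r} is not defined")
    candidate = Path(candidate)
    if not candidate.exists():
        raise FileNotFoundError(f"{env_var} = {candidate} does not exist")
    return candidate


# Demo: with no override and the variable unset, the backup folder is used.
demo_dir = tempfile.mkdtemp()
resolved = resolve_models_dir(backup=demo_dir, env_var="SPINEPS_SKETCH_UNSET_VAR")
print(resolved == Path(demo_dir))
```

An override that is set but does not exist on disk is silently ignored, which matches the condition in the source above.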

filepath_model

filepath_model(
    model_folder_name: str,
    model_dir: str | Path | None = None,
) -> Path

Returns the path to a model folder with specified model id name

Parameters:

Name Type Description Default
model_folder_name str

Name of the model (corresponds to its folder name)

required
model_dir str | Path | None

Base path to the models directory. If none, will calculate that itself. Defaults to None.

None

Returns:

Name Type Description
Path Path

Path to the model specified by name

Source code in spineps/utils/filepaths.py
def filepath_model(model_folder_name: str, model_dir: str | Path | None = None) -> Path:
    """Returns the path to a model folder with specified model id name

    Args:
        model_folder_name (str): Name of the model (corresponds to its folder name)
        model_dir (str | Path | None, optional): Base path to the models directory. If none, will calculate that itself. Defaults to None.

    Returns:
        Path: Path to the model specified by name
    """
    if model_dir is None:
        model_dir = get_mri_segmentor_models_dir()

    if isinstance(model_dir, str):
        model_dir = Path(model_dir)

    path = model_dir.joinpath(model_folder_name)
    if not path.exists():
        paths = search_path(Path(model_dir), query=f"**/{model_folder_name}")
        if len(paths) == 1:
            return paths[0]
    return model_dir.joinpath(model_folder_name)
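
The fallback behaviour is worth seeing on a concrete folder tree. The sketch below reimplements the logic with `Path.glob` under a hypothetical name, `find_model_folder`; the folder names are invented for the demo.

```python
import tempfile
from pathlib import Path


def find_model_folder(model_dir, name):
    """Return model_dir/name, or a uniquely matching nested folder if one exists."""
    model_dir = Path(model_dir)
    direct = model_dir / name
    if not direct.exists():
        matches = sorted(model_dir.glob(f"**/{name}"))
        if len(matches) == 1:
            return matches[0]
    # Missing or ambiguous: fall back to the direct path without checking it.
    return direct


root = Path(tempfile.mkdtemp())
(root / "mri" / "model_x").mkdir(parents=True)
unique = find_model_folder(root, "model_x")      # found via recursive search
(root / "ct" / "model_x").mkdir(parents=True)
ambiguous = find_model_folder(root, "model_x")   # two matches -> direct path
print(unique.relative_to(root), ambiguous.relative_to(root))
```

The returned path is therefore not guaranteed to exist; callers that need the folder should check it themselves.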

search_path

search_path(
    basepath: str | Path,
    query: str,
    verbose: bool = False,
    suppress: bool = False,
) -> list[Path]

Search basepath for paths matching a glob query.

Parameters:

Name Type Description Default
basepath str | Path

Base path to search in.

required
query str

Glob query; can contain wildcards like *.npz or **/*.npz

required
verbose bool

If True, print the search pattern before searching.

False
suppress bool

If True, do not emit a warning when nothing is found.

False

Returns:

Type Description
list[Path]

All found paths

Source code in spineps/utils/filepaths.py
def search_path(basepath: str | Path, query: str, verbose: bool = False, suppress: bool = False) -> list[Path]:
    """Search basepath for paths matching a glob query.

    Args:
        basepath: Base path to search in.
        query: Glob query; can contain wildcards like *.npz or **/*.npz
        verbose: If True, print the search pattern before searching.
        suppress: If True, do not emit a warning when nothing is found.

    Returns:
        All found paths
    """
    basepath = str(basepath)
    if not basepath.endswith("/"):
        basepath += "/"
    print(f"search_path: in {basepath}{query}") if verbose else None
    paths = sorted(chain(list(Path(f"{basepath}").glob(f"{query}"))))
    if len(paths) == 0 and not suppress:
        warnings.warn(f"did not find any paths in {basepath}{query}", UserWarning, stacklevel=1)
    return paths
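
The two wildcard forms mentioned in the docstring behave differently. The following sketch uses `Path.glob` directly on a throwaway folder (the file names are invented) to show the difference between a single-level and a recursive pattern.

```python
import tempfile
from pathlib import Path

# Build a throwaway folder: one .npz at the top level, one in a subfolder.
base = Path(tempfile.mkdtemp())
(base / "sub").mkdir()
(base / "a.npz").touch()
(base / "sub" / "b.npz").touch()

top_level = sorted(base.glob("*.npz"))      # direct children only
recursive = sorted(base.glob("**/*.npz"))   # this folder and all subfolders

print([p.name for p in top_level])   # ['a.npz']
print([p.name for p in recursive])   # ['a.npz', 'b.npz']
```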

spineps.utils.auto_download

spineps.utils.auto_download

Automatic download and extraction of pretrained SPINEPS model weights from the GitHub releases.

download_if_missing

download_if_missing(
    key: str, url: Union[Path, str], phase: SpinepsPhase
) -> Path

Return the local model folder for a model, downloading and extracting its weights if absent.

The target folder name combines the model's download name with the version resolved for its phase (and the phase/key-specific override when one exists, e.g. CT models).

Parameters:

Name Type Description Default
key str

Model key identifying the model within its phase (e.g. "t2w", "instance").

required
url Union[Path, str]

Release URL of the model's weights zip archive.

required
phase SpinepsPhase

Pipeline phase the model belongs to.

required

Returns:

Name Type Description
Path Path

Path to the local model folder containing the (possibly just downloaded) weights.

Source code in spineps/utils/auto_download.py
def download_if_missing(key: str, url: Union[Path, str], phase: SpinepsPhase) -> Path:
    """Return the local model folder for a model, downloading and extracting its weights if absent.

    The target folder name combines the model's download name with the version resolved for its phase (and the
    phase/key-specific override when one exists, e.g. CT models).

    Args:
        key: Model key identifying the model within its phase (e.g. ``"t2w"``, ``"instance"``).
        url: Release URL of the model's weights zip archive.
        phase (SpinepsPhase): Pipeline phase the model belongs to.

    Returns:
        Path: Path to the local model folder containing the (possibly just downloaded) weights.
    """
    version = phase_to_version.get(f"{phase}_{key}", phase_to_version[phase.name])
    out_path = Path(get_mri_segmentor_models_dir(), download_names[key] + "_" + version)
    if not out_path.exists():
        download_weights(url, out_path)

    return out_path

download_weights

download_weights(
    weights_url: Union[Path, str],
    out_path: Union[Path, str],
) -> None

Download a weights zip archive, extract it into out_path and remove the archive.

Shows a progress bar during download. If the extracted archive nests its contents in an extra subfolder (no inference_config.json at the top level), the inner contents are moved up one level. Returns early without raising if the initial size request fails.

Parameters:

Name Type Description Default
weights_url Union[Path, str]

URL of the weights zip archive to download.

required
out_path Union[Path, str]

Destination folder for the extracted weights (the archive is first downloaded next to it, at the same path with .zip appended).

required

Raises:

Type Description
AssertionError

If the nested archive layout is detected but the extra entry is not a directory.

Source code in spineps/utils/auto_download.py
def download_weights(weights_url: Union[Path, str], out_path: Union[Path, str]) -> None:
    """Download a weights zip archive, extract it into ``out_path`` and remove the archive.

    Shows a progress bar during download. If the extracted archive nests its contents in an extra subfolder
    (no ``inference_config.json`` at the top level), the inner contents are moved up one level. Returns early
    without raising if the initial size request fails.

    Args:
        weights_url: URL of the weights zip archive to download.
        out_path: Destination folder for the extracted weights (the archive is downloaded next to it as ``.zip``).

    Raises:
        AssertionError: If the nested archive layout is detected but the extra entry is not a directory.
    """
    out_path = Path(out_path)
    logger = Print_Logger()
    try:
        # Retrieve file size
        with urllib.request.urlopen(str(weights_url)) as response:
            file_size = int(response.info().get("Content-Length", -1))
    except Exception:
        logger.on_fail("Download attempt failed:", weights_url)
        return
    logger.print("Downloading pretrained weights...")

    with tqdm(total=file_size, unit="B", unit_scale=True, unit_divisor=1024, desc=Path(weights_url).name) as pbar:

        def update_progress(block_num: int, block_size: int, total_size: int) -> None:
            if pbar.total != total_size:
                pbar.total = total_size
            pbar.update(block_num * block_size - pbar.n)

        zip_path = Path(str(out_path) + ".zip")
        # Download the file
        urllib.request.urlretrieve(str(weights_url), zip_path, reporthook=update_progress)

    logger.print("Extracting pretrained weights...")

    with zipfile.ZipFile(zip_path, "r") as zip_ref:
        zip_ref.extractall(out_path)
    # Test if there is an additional folder and move the content on up.
    if not Path(out_path, "inference_config.json").exists():
        source = next(out_path.iterdir())
        assert source.is_dir()
        for i in source.iterdir():
            shutil.move(i, out_path)

    zip_path.unlink()
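
The extraction and flattening step can be exercised without any network access. The sketch below is an illustration under assumptions: `extract_and_flatten` is a hypothetical helper, the archive is created locally, and unlike the source it raises `ValueError` instead of asserting and also removes the emptied wrapper folder.

```python
import shutil
import tempfile
import zipfile
from pathlib import Path


def extract_and_flatten(zip_path, out_path, marker="inference_config.json"):
    """Extract zip_path into out_path; lift contents if wrapped in one subfolder."""
    out_path = Path(out_path)
    with zipfile.ZipFile(zip_path, "r") as zip_ref:
        zip_ref.extractall(out_path)
    if not (out_path / marker).exists():
        inner = next(out_path.iterdir())
        if not inner.is_dir():
            raise ValueError(f"expected a folder inside {out_path}, found {inner.name}")
        for item in inner.iterdir():
            shutil.move(str(item), str(out_path))
        inner.rmdir()  # extra step in this sketch: drop the now-empty wrapper
    return out_path


# Demo with a throwaway archive whose contents sit in a "wrapper/" folder.
tmp = Path(tempfile.mkdtemp())
with zipfile.ZipFile(tmp / "model.zip", "w") as zf:
    zf.writestr("wrapper/inference_config.json", "{}")
    zf.writestr("wrapper/weights.bin", "x")
out = extract_and_flatten(tmp / "model.zip", tmp / "model")
print(sorted(p.name for p in out.iterdir()))  # ['inference_config.json', 'weights.bin']
```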

spineps.utils.citation_reminder

spineps.utils.citation_reminder

Citation reminder utilities that prompt users to cite SPINEPS when the package is used.

citation_reminder

citation_reminder(func)

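Decorator to remind users to cite SPINEPS. The reminder is printed at most once per process, before the first decorated call, and is skipped when the environment variable SPINEPS_TURN_OF_CITATION_REMINDER is set to TRUE.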

Source code in spineps/utils/citation_reminder.py
def citation_reminder(func):
    """Decorator to remind users to cite SPINEPS."""

    def wrapper(*args, **kwargs):
        global has_reminded_citation  # noqa: PLW0603
        if not has_reminded_citation and os.environ.get("SPINEPS_TURN_OF_CITATION_REMINDER", "FALSE") != "TRUE":
            print_citation_reminder()
            has_reminded_citation = True
        return func(*args, **kwargs)

    return wrapper

print_citation_reminder

print_citation_reminder()

Print a formatted reminder with the SPINEPS GitHub and ArXiv links asking users to cite the work.

Source code in spineps/utils/citation_reminder.py
def print_citation_reminder():
    """Print a formatted reminder with the SPINEPS GitHub and ArXiv links asking users to cite the work."""
    console = Console()
    console.rule("Thank you for using [bold]SPINEPS[/bold]")
    console.print(
        "Please support our development by citing",
        justify="center",
    )
    console.print(
        f"GitHub: {GITHUB_LINK}\nArXiv: {ARXIV_LINK}\n Thank you!",
        justify="center",
    )
    console.rule()
    console.line()

spineps.utils.compat

spineps.utils.compat

zip_strict

zip_strict(*iterables: Iterable) -> zip

A strict version of zip that raises a ValueError if the input iterables have different lengths.

Converts each iterable to a list to check lengths. This assumes all iterables are finite.

Parameters:

Name Type Description Default
*iterables Iterable

Finite iterables to be zipped together.

()

Returns:

Type Description
zip

An iterator of tuples, where the i-th tuple contains the i-th element from each iterable.

Raises:

Type Description
ValueError

If the input iterables have different lengths.

Source code in spineps/utils/compat.py
def zip_strict(*iterables: Iterable) -> zip:
    """
    A strict version of zip that raises a ValueError if the input iterables have different lengths.

    Converts each iterable to a list to check lengths. This assumes all iterables are finite.

    Args:
        *iterables: Finite iterables to be zipped together.

    Returns:
        An iterator of tuples, where the i-th tuple contains the i-th element from each iterable.

    Raises:
        ValueError: If the input iterables have different lengths.
    """
    lists = [list(it) for it in iterables]
    lengths = [len(lst) for lst in lists]
    if len(set(lengths)) != 1:
        raise ValueError(f"Length mismatch: {lengths}")
    return zip(*lists)
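
A short usage example follows. The function body is repeated from the source above only so that the snippet runs on its own.

```python
def zip_strict(*iterables):
    """Same logic as the source above, repeated so the example is self-contained."""
    lists = [list(it) for it in iterables]
    lengths = [len(lst) for lst in lists]
    if len(set(lengths)) != 1:
        raise ValueError(f"Length mismatch: {lengths}")
    return zip(*lists)


print(list(zip_strict([1, 2], "ab")))  # [(1, 'a'), (2, 'b')]

try:
    list(zip_strict([1, 2, 3], "ab"))
except ValueError as err:
    print(err)  # Length mismatch: [3, 2]
```

Two details follow from the implementation: the length check happens eagerly, at call time, and calling `zip_strict()` with no arguments also raises, because the set of lengths is then empty. On Python 3.10 and later, the built-in `zip(..., strict=True)` offers a similar check, but it raises lazily, while iterating.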