Skip to content

Architectures

Network architectures and the vertebra label definitions used by the models.

spineps.architectures.read_labels

spineps.architectures.read_labels

Vertebra label definitions, enums and mappings used for vertebra classification targets.

VertRegion

Bases: Enum

Spinal region of a vertebra: cervical (HWS), thoracic (BWS) or lumbar (LWS).

Source code in spineps/architectures/read_labels.py
class VertRegion(Enum):
    """Spinal region that a vertebra belongs to.

    The member names are the German abbreviations HWS (cervical), BWS (thoracic) and LWS (lumbar). Each value is
    the class index used when the region is one-hot encoded.
    """

    HWS = 0  # cervical spine
    BWS = 1  # thoracic spine
    LWS = 2  # lumbar spine

VertRel

Bases: Enum

Relative position of a vertebra at a region boundary (e.g. last cervical, first thoracic).

Source code in spineps/architectures/read_labels.py
class VertRel(Enum):
    """Position of a vertebra relative to the region boundaries of the spine.

    ``NOTHING`` marks a vertebra that is not at a boundary. The other members name the last cervical (HWK),
    the first/last thoracic (BWK) and the first/last lumbar (LWK) vertebra.
    """

    NOTHING = 0
    # cervical / thoracic transition
    LAST_HWK = 1
    FIRST_BWK = 2
    # thoracic / lumbar transition
    LAST_BWK = 3
    FIRST_LWK = 4
    # caudal end of the lumbar region
    LAST_LWK = 5

VertExact

Bases: Enum

Exact vertebra identity from C1 to L5 (T12 absorbs a potential T13), with 24 classes (0-23).

Source code in spineps/architectures/read_labels.py
class VertExact(Enum):
    """Exact vertebra identity from C1 to L5 as 24 consecutive class indices (0-23).

    There is no separate T13 class: a supernumerary T13 shares index 18 with T12.
    """

    # cervical: C1-C7 -> 0-6
    C1 = 0
    C2 = 1
    C3 = 2
    C4 = 3
    C5 = 4
    C6 = 5
    C7 = 6
    # thoracic: T1-T12 -> 7-18 (a T13, when present, is folded into T12 = 18)
    T1 = 7
    T2 = 8
    T3 = 9
    T4 = 10
    T5 = 11
    T6 = 12
    T7 = 13
    T8 = 14
    T9 = 15
    T10 = 16
    T11 = 17
    T12 = 18
    # lumbar: L1-L5 -> 19-23
    L1 = 19
    L2 = 20
    L3 = 21
    L4 = 22
    L5 = 23

VertExactClass

Bases: Enum

Exact vertebra identity including an explicit T13 and L6, with 26 classes (0-25).

Source code in spineps/architectures/read_labels.py
class VertExactClass(Enum):
    """Exact vertebra identity with dedicated T13 and L6 classes: 26 consecutive class indices (0-25)."""

    # cervical: C1-C7 -> 0-6
    C1 = 0
    C2 = 1
    C3 = 2
    C4 = 3
    C5 = 4
    C6 = 5
    C7 = 6
    # thoracic: T1-T13 -> 7-19 (T13 has its own slot here)
    T1 = 7
    T2 = 8
    T3 = 9
    T4 = 10
    T5 = 11
    T6 = 12
    T7 = 13
    T8 = 14
    T9 = 15
    T10 = 16
    T11 = 17
    T12 = 18
    T13 = 19
    # lumbar: L1-L6 -> 20-25 (L6 has its own slot here)
    L1 = 20
    L2 = 21
    L3 = 22
    L4 = 23
    L5 = 24
    L6 = 25

VertT13

Bases: Enum

Whether a vertebra is a (rare) supernumerary T13 or a normal vertebra.

Source code in spineps/architectures/read_labels.py
class VertT13(Enum):
    """Binary class telling a (rare) supernumerary T13 vertebra apart from every other vertebra."""

    Normal = 0  # any vertebra that is not a T13
    T13 = 1  # supernumerary thirteenth thoracic vertebra

VertGroup

Bases: Enum

Coarse vertebra grouping that buckets neighbouring vertebrae into shared classes (12 groups).

Source code in spineps/architectures/read_labels.py
class VertGroup(Enum):
    """Coarse vertebra grouping: neighbouring vertebrae share one of 12 classes (0-11).

    Member names appear to concatenate the vertebra numbers in the group (e.g. ``C345``). The exact membership of
    each group is defined by a separate mapping, not by this enum.
    """

    # cervical groups
    C12 = 0
    C345 = 1
    C67 = 2
    # thoracic groups (``T12`` presumably T1-T2, ``T123`` presumably T12-T13 -- confirm against the group mapping)
    T12 = 3
    T34 = 4
    T567 = 5
    T89 = 6
    T1011 = 7
    T123 = 8
    # lumbar groups
    L12 = 9
    L34 = 10
    L56 = 11

LabelType

Bases: ABC

Abstract base for converting one or more columns of a data entry into a model target label.

Source code in spineps/architectures/read_labels.py
class LabelType(ABC):
    """Abstract base for converting one or more columns of a data entry into a model target label."""

    def __init__(self, column_name: str | list[str], *args, **kwargs) -> None:  # noqa: ARG002
        """Initialize the label type with the source column name(s).

        Args:
            column_name (str | list[str]): Single column name or list of column names to read from each entry; a single
                string is wrapped into a one-element list.
        """
        if not isinstance(column_name, list):
            column_name = [column_name]
        self.column_name = column_name

    def __call__(self, entry_dict: dict) -> object:
        """Read the configured columns from ``entry_dict`` and convert them into a label.

        Args:
            entry_dict (dict): Mapping of column names to values for a single sample.

        Returns:
            The label produced by :meth:`convert_to_label`.
        """
        entry = self.get_entry(entry_dict)
        return self.convert_to_label(entry)

    def get_entry(self, entry: dict) -> str | int | list[str | int]:
        """Extract the configured column value(s) from a data entry.

        Args:
            entry (dict): Mapping of column names to values.

        Returns:
            str | int | list[str | int]: The single value if only one column is configured, otherwise a list of values.
        """
        entries = [entry[c] for c in self.column_name]
        if len(entries) == 1:
            return entries[0]
        return entries

    @property
    @abstractmethod
    def number_of_channel(self) -> str | int | list[str | int]:
        """Number of output channels (label vector length) produced by this label type."""

    @abstractmethod
    def convert_to_label(self, entry: str):
        """Convert an extracted entry value into the label representation for this label type.

        Args:
            entry (str): The value extracted from the data entry.
        """

number_of_channel abstractmethod property

number_of_channel: str | int | list[str | int]

Number of output channels (label vector length) produced by this label type.

__init__

__init__(
    column_name: str | list[str], *args, **kwargs
) -> None

Initialize the label type with the source column name(s).

Parameters:

Name Type Description Default
column_name str | list[str]

Single column name or list of column names to read from each entry; a single string is wrapped into a one-element list.

required
Source code in spineps/architectures/read_labels.py
def __init__(self, column_name: str | list[str], *args, **kwargs) -> None:  # noqa: ARG002
    """Initialize the label type with the source column name(s).

    Args:
        column_name (str | list[str]): Single column name or list of column names to read from each entry; a single
            string is wrapped into a one-element list.
    """
    if not isinstance(column_name, list):
        column_name = [column_name]
    self.column_name = column_name

__call__

__call__(entry_dict: dict) -> object

Read the configured columns from entry_dict and convert them into a label.

Parameters:

Name Type Description Default
entry_dict dict

Mapping of column names to values for a single sample.

required

Returns:

Type Description
object

The label produced by `convert_to_label`.

Source code in spineps/architectures/read_labels.py
def __call__(self, entry_dict: dict) -> object:
    """Turn one data entry into a label.

    The configured column value(s) are pulled out with :meth:`get_entry` and handed straight to
    :meth:`convert_to_label`.

    Args:
        entry_dict (dict): Column-name -> value mapping for one sample.

    Returns:
        object: Whatever :meth:`convert_to_label` returns for the extracted value(s).
    """
    return self.convert_to_label(self.get_entry(entry_dict))

get_entry

get_entry(entry: dict) -> str | int | list[str | int]

Extract the configured column value(s) from a data entry.

Parameters:

Name Type Description Default
entry dict

Mapping of column names to values.

required

Returns:

Type Description
str | int | list[str | int]

The single value if only one column is configured, otherwise a list of values.

Source code in spineps/architectures/read_labels.py
def get_entry(self, entry: dict) -> str | int | list[str | int]:
    """Extract the configured column value(s) from a data entry.

    Args:
        entry (dict): Mapping of column names to values.

    Returns:
        str | int | list[str | int]: The single value if only one column is configured, otherwise a list of values.
    """
    entries = [entry[c] for c in self.column_name]
    if len(entries) == 1:
        return entries[0]
    return entries

convert_to_label abstractmethod

convert_to_label(entry: str)

Convert an extracted entry value into the label representation for this label type.

Parameters:

Name Type Description Default
entry str

The value extracted from the data entry.

required
Source code in spineps/architectures/read_labels.py
@abstractmethod
def convert_to_label(self, entry: str):
    """Abstract hook: map a value extracted from a data entry to this label type's representation.

    Args:
        entry (str): Value returned by ``get_entry`` for one sample.
    """
    ...

EnumLabelType

Bases: LabelType

Label type that one-hot encodes an `Enum` value into a multi-class target vector.

Source code in spineps/architectures/read_labels.py
class EnumLabelType(LabelType):
    """Label type that one-hot encodes an :class:`~enum.Enum` value into a multi-class target vector."""

    def __init__(self, enum: type[Enum], column_name: str, *args, **kwargs) -> None:  # noqa: ARG002
        """Initialize the enum label type.

        Args:
            enum (type[Enum]): Enum class whose members define the classes; its length sets the number of channels.
            column_name (str): Column name holding the enum value for each entry.
        """
        super().__init__(column_name)
        self.enum = enum
        self.n_channel = len(enum)

    @property
    def number_of_channel(self) -> int:
        """Number of channels, equal to the number of members in the configured enum."""
        return self.n_channel

    def convert_to_label(self, entry: Enum) -> list[int]:
        """One-hot encode an enum member into a label vector.

        Args:
            entry (Enum): Enum member whose ``value`` indexes the hot position.

        Returns:
            list[int]: A list of plain Python ints, all zero except a single 1 at the index ``entry.value``.

        Raises:
            AttributeError: If ``entry`` has no ``value`` attribute (e.g. a NaN read from a table); the objectives
                bundle catches this to skip the entry.
        """
        # Plain Python ints (not numpy scalars) so the label matches the documented type and concatenates cleanly
        # with the list labels produced by other label types.
        label = [0] * self.number_of_channel
        label[entry.value] = 1
        return label

number_of_channel property

number_of_channel: int

Number of channels, equal to the number of members in the configured enum.

__init__

__init__(
    enum: Enum, column_name: str, *args, **kwargs
) -> None

Initialize the enum label type.

Parameters:

Name Type Description Default
enum Enum

Enum class whose members define the classes; its length sets the number of channels.

required
column_name str

Column name holding the enum value for each entry.

required
Source code in spineps/architectures/read_labels.py
def __init__(self, enum: type[Enum], column_name: str, *args, **kwargs) -> None:  # noqa: ARG002
    """Initialize the enum label type.

    Args:
        enum (type[Enum]): Enum class (not a member) whose members define the classes; its length sets the number
            of channels.
        column_name (str): Column name holding the enum value for each entry.
        *args: Ignored.
        **kwargs: Ignored.
    """
    super().__init__(column_name)
    self.enum = enum
    # len() of an Enum class counts its (non-alias) members.
    self.n_channel = len(enum)

convert_to_label

convert_to_label(entry: Enum) -> list[int]

One-hot encode an enum member into a label vector.

Parameters:

Name Type Description Default
entry Enum

Enum member whose value indexes the hot position.

required

Returns:

Type Description
list[int]

A list of zeros with a single 1 at the index given by entry.value.

Source code in spineps/architectures/read_labels.py
def convert_to_label(self, entry: Enum) -> list[int]:
    """One-hot encode an enum member into a label vector.

    Args:
        entry (Enum): Enum member whose ``value`` indexes the hot position.

    Returns:
        list[int]: A list of plain Python ints, all zero except a single 1 at the index ``entry.value``.

    Raises:
        AttributeError: If ``entry`` has no ``value`` attribute (e.g. a NaN read from a table); the objectives
            bundle catches this to skip the entry.
    """
    # Plain Python ints (not numpy scalars) so the label matches the documented type and concatenates cleanly
    # with the list labels produced by other label types.
    label = [0] * self.number_of_channel
    label[entry.value] = 1
    return label

BinaryLabelTypeDummy

Bases: LabelType

Label type for a binary attribute, one-hot encoded into two channels (true/false).

Source code in spineps/architectures/read_labels.py
class BinaryLabelTypeDummy(LabelType):
    """Label type for a binary attribute, one-hot encoded into two channels (true/false)."""

    def __init__(self, column_name: str | list[str], *args, **kwargs) -> None:
        """Initialize the binary label type.

        Args:
            column_name (str | list[str]): Column name(s) holding the binary value.
            *args: Forwarded to the base initializer.
            **kwargs: Forwarded to the base initializer.
        """
        super().__init__(column_name, *args, **kwargs)

    @property
    def number_of_channel(self) -> int:
        """Number of channels, always 2 (true and false)."""
        return 2

    def convert_to_label(self, entry: str | int) -> list[int]:
        """Convert a truthy/falsy entry into a two-channel one-hot label.

        Args:
            entry (str | int): A value contained in ``TRUE_KEYS`` or ``FALSE_KEYS``.

        Returns:
            list[int]: ``[1, 0]`` for true values and ``[0, 1]`` for false values.

        Raises:
            AssertionError: If ``entry`` is a list, or is not recognised as a true or false value. The objectives
                bundle catches AssertionError to skip such entries.
        """
        # Explicit raise instead of ``assert`` so the check survives ``python -O``; the exception type is kept because
        # callers catch AssertionError.
        if isinstance(entry, list):
            raise AssertionError(entry)
        if entry in TRUE_KEYS:
            return [1, 0]
        if entry in FALSE_KEYS:
            return [0, 1]
        raise AssertionError(f"entry {entry} not defined as BinaryLabel")

number_of_channel property

number_of_channel: int

Number of channels, always 2 (true and false).

__init__

__init__(
    column_name: str | list[str], *args, **kwargs
) -> None

Initialize the binary label type.

Parameters:

Name Type Description Default
column_name str | list[str]

Column name(s) holding the binary value.

required
Source code in spineps/architectures/read_labels.py
def __init__(self, column_name: str | list[str], *args, **kwargs) -> None:
    """Initialize the binary label type.

    The binary type has no configuration of its own; every argument is forwarded unchanged to the base
    label type.

    Args:
        column_name (str | list[str]): Column name(s) holding the binary value.
        *args: Forwarded to the base initializer.
        **kwargs: Forwarded to the base initializer.
    """
    super().__init__(column_name, *args, **kwargs)

convert_to_label

convert_to_label(entry: str | int) -> int

Convert a truthy/falsy entry into a two-channel one-hot label.

Parameters:

Name Type Description Default
entry str | int

A value contained in TRUE_KEYS or FALSE_KEYS.

required

Returns:

Type Description
list[int]

[1, 0] for true values and [0, 1] for false values.

Raises:

Type Description
AssertionError

If entry is a list, or is not recognised as a true or false value.

Source code in spineps/architectures/read_labels.py
def convert_to_label(self, entry: str | int) -> list[int]:
    """Convert a truthy/falsy entry into a two-channel one-hot label.

    Args:
        entry (str | int): A value contained in ``TRUE_KEYS`` or ``FALSE_KEYS``.

    Returns:
        list[int]: ``[1, 0]`` for true values and ``[0, 1]`` for false values.

    Raises:
        AssertionError: If ``entry`` is a list, or is not recognised as a true or false value. The objectives
            bundle catches AssertionError to skip such entries.
    """
    # Explicit raise instead of ``assert`` so the check survives ``python -O``; the exception type is kept because
    # callers catch AssertionError.
    if isinstance(entry, list):
        raise AssertionError(entry)
    if entry in TRUE_KEYS:
        return [1, 0]
    if entry in FALSE_KEYS:
        return [0, 1]
    raise AssertionError(f"entry {entry} not defined as BinaryLabel")

Target

Bases: Enum

Available classification targets, each mapping to a (label-type, enum/column, column-name) configuration tuple.

Source code in spineps/architectures/read_labels.py
class Target(Enum):
    """Classification targets the models can be trained on.

    Each member's value is a ``(label-type class, enum class or column, column name)`` tuple. ``Objectives``
    unpacks it to instantiate the matching label type.
    """

    # Multi-class targets: each is an independent one-hot group, so softmax is applied per target.
    REGION = (EnumLabelType, VertRegion, "vert_region")  # cervical / thoracic / lumbar region
    VERT = (EnumLabelType, VertExact, "vert_exact")  # exact vertebra, T13 folded into T12
    VERTEX = (EnumLabelType, VertExactClass, "vert_exact2")  # exact vertebra incl. explicit T13 and L6
    VT13 = (EnumLabelType, VertT13, "vert_t13")  # whether the vertebra is a T13
    VERTREL = (EnumLabelType, VertRel, "vert_rel")  # relative label (nothing, last HWK, first BWK, ...)
    VERTGRP = (EnumLabelType, VertGroup, "vert_group")  # coarse vertebra group
    # Binary target read from the ``vert_cut`` column.
    FULLYVISIBLE = (BinaryLabelTypeDummy, "vert_cut", "vert_cut")

Objectives

Bundle of classification targets that builds and combines their label vectors for a single data entry.

Source code in spineps/architectures/read_labels.py
class Objectives:
    """Bundle of classification targets that builds and combines their label vectors for a single data entry."""

    def __init__(
        self,
        objectives: list[Target],
        as_group: bool = True,
    ) -> None:
        """Initialize the objectives and instantiate the label type for each target.

        Args:
            objectives (list[Target]): Targets to compute labels for, in order.
            as_group (bool): If True, ``__call__`` returns labels grouped per target name; if False, a flat concatenated list.
        """
        self.__as_group = as_group
        self.targets: list[Target] = objectives
        self.__objective_labels: list[LabelType] = []
        for o in objectives:
            # A Target value is (label-type class, enum class or column, column name); build the label type from it.
            label_type = o.value[0](o.value[1], o.value[2])
            self.__objective_labels.append(label_type)

        self.__n_channel_p_group = [o.number_of_channel for o in self.__objective_labels]
        self.__n_channel = sum(self.__n_channel_p_group)
        self.__required_dict_keys = list(set(flatten([o.value[2] for o in objectives])))

    @property
    def n_channel_p_group(self):
        """List of channel counts, one per target objective."""
        return self.__n_channel_p_group

    @property
    def n_channel(self):
        """Total number of channels across all target objectives."""
        return self.__n_channel

    @property
    def group_2_n_channel(self) -> dict[str, int]:
        """Mapping from each target name to its number of channels."""
        return {target.name: n for target, n in zip(self.targets, self.n_channel_p_group)}

    @property
    def required_dict_keys(self):
        """Unique set of data-entry column names required to compute all objectives."""
        return self.__required_dict_keys

    def __call__(
        self,
        entry: dict,
    ) -> list[int] | dict[str, list[int]] | None:
        """Compute the label(s) for all objectives from a single data entry.

        Args:
            entry (dict): Data entry containing at least every key in :attr:`required_dict_keys`.

        Returns:
            list[int] | dict | None: A flat concatenated label list when ``as_group`` is False, a per-target-name dict of label
            lists when ``as_group`` is True, or None (in either mode) if a label could not be produced (e.g. a NaN
            binary/pathology value).

        Raises:
            AssertionError: If a required key is missing from ``entry``.
        """
        entry_keys = entry.keys()
        for r in self.required_dict_keys:
            assert r in entry_keys, f"required {r} not in entry_keys, got {entry_keys}"

        labels: list[int] = []
        labels_grouped: list[list[int]] = []
        try:
            for labeltype in self.__objective_labels:
                labeladd = labeltype(entry)
                if not isinstance(labeladd, list):
                    labeladd = [labeladd]
                labels += labeladd
                labels_grouped.append(labeladd)
        except (AssertionError, AttributeError):
            # AssertionError: NaN/unknown binary label; AttributeError: NaN enum/pathology label (no ``.value``).
            # Return None in both modes; building the grouped dict from the partially filled list would fail.
            return None
        if not self.__as_group:
            return labels
        return {target.name: group for target, group in zip(self.targets, labels_grouped)}

n_channel_p_group property

n_channel_p_group

List of channel counts, one per target objective.

n_channel property

n_channel

Total number of channels across all target objectives.

group_2_n_channel property

group_2_n_channel: dict[str, int]

Mapping from each target name to its number of channels.

required_dict_keys property

required_dict_keys

Unique set of data-entry column names required to compute all objectives.

__init__

__init__(
    objectives: list[Target], as_group: bool = True
) -> None

Initialize the objectives and instantiate the label type for each target.

Parameters:

Name Type Description Default
objectives list[Target]

Targets to compute labels for, in order.

required
as_group bool

If True, __call__ returns labels grouped per target name; if False, a flat concatenated list.

True
Source code in spineps/architectures/read_labels.py
def __init__(
    self,
    objectives: list[Target],
    as_group: bool = True,
) -> None:
    """Set up the objectives by creating one label-type instance per target.

    Args:
        objectives (list[Target]): Targets to build labels for; their order fixes the label order.
        as_group (bool): When True, ``__call__`` returns a dict keyed by target name; when False, one flat list.
    """
    self.__as_group = as_group
    self.targets: list[Target] = objectives
    # A Target value is (label-type class, enum class or column, column name); build the label type from it.
    self.__objective_labels: list[LabelType] = [
        target.value[0](target.value[1], target.value[2]) for target in objectives
    ]

    channel_counts = [label_type.number_of_channel for label_type in self.__objective_labels]
    self.__n_channel_p_group = channel_counts
    self.__n_channel = sum(channel_counts)
    column_names = [target.value[2] for target in objectives]
    self.__required_dict_keys = list(set(flatten(column_names)))

__call__

__call__(entry: dict) -> list[int]

Compute the label(s) for all objectives from a single data entry.

Parameters:

Name Type Description Default
entry dict

Data entry containing at least every key in `required_dict_keys`.

required

Returns:

Type Description
list[int] | dict | None

A flat concatenated label list when as_group is False, a per-target-name dict of label lists when as_group is True, or None if a label could not be produced (e.g. a NaN binary/pathology value).

Raises:

Type Description
AssertionError

If a required key is missing from entry.

Source code in spineps/architectures/read_labels.py
def __call__(
    self,
    entry: dict,
) -> list[int]:
    """Compute the label(s) for all objectives from a single data entry.

    Args:
        entry (dict): Data entry containing at least every key in :attr:`required_dict_keys`.

    Returns:
        list[int] | dict | None: A flat concatenated label list when ``as_group`` is False, a per-target-name dict of label
        lists when ``as_group`` is True, or None if a label could not be produced (e.g. a NaN binary/pathology value).

    Raises:
        AssertionError: If a required key is missing from ``entry``.
    """
    entry_keys = entry.keys()
    for r in self.required_dict_keys:
        assert r in entry_keys, f"required {r} not in entry_keys, got {entry_keys}"

    #
    labels = []
    labels_grouped = []
    try:
        list_of_ordered_objectives = self.__objective_labels

        for labeltype in list_of_ordered_objectives:
            labeladd = labeltype(entry)
            if not isinstance(labeladd, list):
                labeladd = [labeladd]
            labels += labeladd
            labels_grouped.append(labeladd)
    except AssertionError:  # nan binary label
        labels = None
    except AttributeError:  # nan Pathology label
        labels = None
    return labels if not self.__as_group else {self.targets[idx].name: labels_grouped[idx] for idx in range(len(self.targets))}

SubjectInfo dataclass

Per-subject vertebra labelling metadata, including anomalies, the resolved label map and region boundaries.

Source code in spineps/architectures/read_labels.py
@dataclass
class SubjectInfo:
    """Per-subject vertebra labelling metadata, including anomalies, the resolved label map and region boundaries."""

    subject_name: int
    has_anomaly_entry: bool
    anomaly_entry: dict
    deleted_label: list[int]
    labelmap: dict
    is_remove: bool
    actual_labels: list[int]
    last_lwk: int
    last_bwk: int
    last_hwk: int = 7
    first_bwk: int = 8
    first_lwk: int = 20
    double_entries: list[int] = field(default_factory=list)

    @property
    def has_tea(self) -> bool:
        """Whether the subject has a transitional anomaly (a T11 or T13 anomaly entry).

        Returns:
            bool | None: True/False based on the T11/T13 anomaly flags, or None if the subject has no anomaly entry.
        """
        if not self.has_anomaly_entry:
            return None
        return self.anomaly_entry["T11"] or self.anomaly_entry["T13"]

    @property
    def block(self) -> int:
        """Dataset block identifier, taken from the first three digits of the subject name.

        Returns:
            int: The integer formed by the first three characters of ``subject_name``.
        """
        return int(str(self.subject_name)[:3])

has_tea property

has_tea: bool

Whether the subject has a transitional anomaly (a T11 or T13 anomaly entry).

Returns:

Type Description
bool | None

True/False based on the T11/T13 anomaly flags, or None if the subject has no anomaly entry.

block property

block: int

Dataset block identifier, taken from the first three digits of the subject name.

Returns:

Name Type Description
int int

The integer formed by the first three characters of subject_name.

vert_label_to_vertrel

vert_label_to_vertrel(
    vertlabel: int,
    last_bwk: int | None,
    last_lwk: int | None,
    last_hwk=7,
    first_bwk=8,
    first_lwk=20,
) -> VertRel

Map a numeric vertebra label to its region-boundary relation.

Note that last_hwk, first_bwk and first_lwk are reset to their fixed defaults (7, 8, 20) inside the function.

Parameters:

Name Type Description Default
vertlabel int

Numeric vertebra label to classify.

required
last_bwk int | None

Label of the last thoracic vertebra, or None if unknown.

required
last_lwk int | None

Label of the last lumbar vertebra, or None if unknown.

required
last_hwk int

Label of the last cervical vertebra (overwritten to 7).

7
first_bwk int

Label of the first thoracic vertebra (overwritten to 8).

8
first_lwk int

Label of the first lumbar vertebra (overwritten to 20).

20

Returns:

Name Type Description
VertRel VertRel

The boundary relation of the given label (NOTHING if it is not a boundary vertebra).

Source code in spineps/architectures/read_labels.py
def vert_label_to_vertrel(
    vertlabel: int,
    last_bwk: int | None,
    last_lwk: int | None,
    last_hwk=7,
    first_bwk=8,
    first_lwk=20,
) -> VertRel:
    """Classify a numeric vertebra label by its position relative to the region boundaries.

    ``last_hwk``, ``first_bwk`` and ``first_lwk`` are accepted for interface compatibility but always replaced by
    the fixed values 7, 8 and 20. Only ``last_bwk`` and ``last_lwk`` are taken from the caller.

    Args:
        vertlabel (int): Numeric vertebra label to classify.
        last_bwk: Label of the last thoracic vertebra, or None if unknown (then never matched).
        last_lwk: Label of the last lumbar vertebra, or None if unknown (then never matched).
        last_hwk (int): Ignored; overwritten with 7.
        first_bwk (int): Ignored; overwritten with 8.
        first_lwk (int): Ignored; overwritten with 20.

    Returns:
        VertRel: The first matching boundary relation, checked in the order last cervical, first thoracic,
        last thoracic, first lumbar, last lumbar; ``VertRel.NOTHING`` if none matches.
    """
    last_hwk, first_bwk, first_lwk = 7, 8, 20

    # Order matters: when two boundaries coincide, the earlier entry wins.
    boundaries = (
        (last_hwk, VertRel.LAST_HWK),
        (first_bwk, VertRel.FIRST_BWK),
        (last_bwk, VertRel.LAST_BWK),
        (first_lwk, VertRel.FIRST_LWK),
        (last_lwk, VertRel.LAST_LWK),
    )
    for boundary_label, relation in boundaries:
        if boundary_label is not None and vertlabel == boundary_label:
            return relation
    return VertRel.NOTHING

vert_class_to_region

vert_class_to_region(vert_exact: VertExact) -> VertRegion

Map an exact vertebra class to its spinal region.

Parameters:

Name Type Description Default
vert_exact VertExact

Exact vertebra identity.

required

Returns:

Name Type Description
VertRegion VertRegion

HWS for C1-C7, BWS for T1-T12 and LWS for the lumbar vertebrae.

Source code in spineps/architectures/read_labels.py
def vert_class_to_region(vert_exact: VertExact) -> VertRegion:
    """Return the spinal region of an exact vertebra class.

    Args:
        vert_exact (VertExact): Exact vertebra identity.

    Returns:
        VertRegion: HWS for values 0-6 (C1-C7), BWS for 7-18 (T1-T12) and LWS for anything from 19 upwards.
    """
    index = vert_exact.value
    if index < 7:
        return VertRegion.HWS
    if index < 19:
        return VertRegion.BWS
    return VertRegion.LWS

vert_label_to_class

vert_label_to_class(vertlabel: int) -> VertExact

Map a numeric vertebra label to a `VertExact` class.

Label 28 (T13) is folded into T12; all other labels map to vertlabel - 1 capped at 23 (L5).

Parameters:

Name Type Description Default
vertlabel int

Numeric vertebra label.

required

Returns:

Name Type Description
VertExact VertExact

The corresponding exact vertebra class.

Source code in spineps/architectures/read_labels.py
def vert_label_to_class(vertlabel: int) -> VertExact:
    """Convert a numeric vertebra label into a :class:`VertExact` member.

    Label 28 (T13) is mapped onto T12. Every other label becomes ``vertlabel - 1``, with anything above L5
    clamped to 23 (L5).

    Args:
        vertlabel (int): Numeric vertebra label.

    Returns:
        VertExact: The matching exact vertebra class.
    """
    if vertlabel == 28:
        return VertExact.T12
    return VertExact(min(vertlabel - 1, 23))

vert_label_to_exactclass

vert_label_to_exactclass(vertlabel: int) -> VertExactClass

Map a numeric vertebra label to a `VertExactClass` class.

Label 28 maps to the explicit T13 class; labels up to 19 map to vertlabel - 1 (capped at 24) and higher labels to vertlabel (capped at 25), accounting for the extra T13/L6 slots.

Parameters:

Name Type Description Default
vertlabel int

Numeric vertebra label.

required

Returns:

Name Type Description
VertExactClass VertExactClass

The corresponding exact vertebra class.

Source code in spineps/architectures/read_labels.py
def vert_label_to_exactclass(vertlabel: int) -> VertExactClass:
    """Convert a numeric vertebra label into a :class:`VertExactClass` member.

    Label 28 is the explicit T13 class. Labels up to 19 (C1-T12) shift down by one. From 20 (L1) on, the label is
    used as-is (because T13 occupies index 19), clamped to 25 (L6).

    Args:
        vertlabel (int): Numeric vertebra label.

    Returns:
        VertExactClass: The matching exact vertebra class.
    """
    if vertlabel == 28:
        return VertExactClass.T13
    if vertlabel <= 19:
        return VertExactClass(min(24, vertlabel - 1))
    return VertExactClass(min(25, vertlabel))

vert_class_to_group

vert_class_to_group(vert_exact: VertExact) -> VertGroup

Map an exact vertebra class to its coarse `VertGroup`.

Parameters:

Name Type Description Default
vert_exact VertExact

Exact vertebra identity.

required

Returns:

Name Type Description
VertGroup VertGroup

The group that contains the given vertebra.

Source code in spineps/architectures/read_labels.py
def vert_class_to_group(vert_exact: VertExact) -> VertGroup:
    """Look up the coarse :class:`VertGroup` an exact vertebra class belongs to.

    Args:
        vert_exact (VertExact): Exact vertebra identity.

    Returns:
        VertGroup: The group containing ``vert_exact``, as given by the module-level ``vert_exact_to_group_dict``.

    Raises:
        KeyError: If ``vert_exact`` has no entry in that mapping.
    """
    group = vert_exact_to_group_dict[vert_exact]
    return group

vertgrp_sequence_to_class

vertgrp_sequence_to_class(
    vertgrp: list[VertGroup],
) -> list[VertExact]

Resolve a top-to-bottom sequence of vertebra groups into exact vertebra classes.

For each group, if every member of the group is present the assignment is trivial; otherwise the neighbouring group before or after the partial run determines whether the members align from the top or from the bottom of the group.

Parameters:

Name Type Description Default
vertgrp list[VertGroup]

Vertebra groups ordered from top (cranial) to bottom (caudal).

required

Returns:

Type Description
list[VertExact]

Exact vertebra classes for each position in the input sequence.

Raises:

Type Description
AssertionError

If a partial group has neighbours on both sides, which cannot be resolved unambiguously.

Source code in spineps/architectures/read_labels.py
def vertgrp_sequence_to_class(vertgrp: list[VertGroup]) -> list[VertExact]:
    """Turn a cranial-to-caudal sequence of vertebra groups into exact vertebra classes.

    Every group listed in ``vert_group_to_exact_dict`` is handled on its own:

    * If the group occurs in ``vertgrp`` exactly as often as it has members, the members are assigned to those
      positions in order.
    * Otherwise the group is only partially visible. If any entry precedes the run, the run is taken to be cut at
      its caudal end and is aligned with the first members of the group. If any entry follows the run, it is taken
      to be cut at its cranial end and is aligned with the last members of the group.

    Positions that no rule resolves stay ``None``. This happens, for example, when the sequence consists solely of
    one partial group.

    Args:
        vertgrp (list[VertGroup]): Vertebra groups ordered from top (cranial) to bottom (caudal).

    Returns:
        list[VertExact]: The exact vertebra class for each position of ``vertgrp``.

    Raises:
        AssertionError: If a partial group has entries both before and after it, which cannot be resolved unambiguously.
    """
    # The caller must already have sorted the input from cranial to caudal.
    n_entries = len(vertgrp)
    resolved: list[VertExact] = [None] * n_entries  # type: ignore

    for group, members in vert_group_to_exact_dict.items():
        if group not in vertgrp:
            continue
        positions = [pos for pos, entry in enumerate(vertgrp) if entry == group]

        if len(positions) == len(members):
            # Complete group: one-to-one assignment in order.
            for pos, member in zip(positions, members):
                resolved[pos] = member
            continue

        # Partial group: a neighbour on exactly one side decides the alignment.
        has_entry_above = positions[0] > 0
        has_entry_below = positions[-1] < n_entries - 1
        assert not (has_entry_above and has_entry_below), "partial grp sequence not possible"

        if has_entry_above:
            # Cut off caudally, so align with the first members of the group.
            for offset, pos in enumerate(positions):
                resolved[pos] = members[offset]
        elif has_entry_below:
            # Cut off cranially, so align with the last members of the group.
            for offset, pos in enumerate(reversed(positions)):
                resolved[pos] = members[-(offset + 1)]
    return resolved

flatten

flatten(
    a: list[str | int | list[str] | list[int]],
) -> Iterator[str | int]

Recursively flatten an arbitrarily nested list of strings and integers.

Parameters:

Name Type Description Default
a list[str | int | list[str] | list[int]]

A value or (nested) list of strings and integers.

required

Yields:

Type Description
str | int

str | int: Each leaf string or integer in depth-first order.

Source code in spineps/architectures/read_labels.py
def flatten(a: list[str | int | list[str] | list[int]]) -> Iterator[str | int]:
    """Recursively flatten an arbitrarily nested list of strings and integers.

    Args:
        a (list[str | int | list[str] | list[int]]): A value or (nested) list of strings and integers.

    Yields:
        str | int: Each leaf string or integer in depth-first order.
    """
    # a = itertools.chain(*a)
    if isinstance(a, (str, int)):
        yield a
    else:
        for b in a:
            yield from flatten(b)

get_subject_info

get_subject_info(
    subject_name: str | int,
    anomaly_dict: dict,
    vert_subfolders_int: list[int],
    subject_name_int: bool = True,
) -> SubjectInfo

Build a SubjectInfo from a subject's raw vertebra labels and any anomaly overrides.

Applies anomaly handling (label deletion, removal flags, T11/T13 remapping and explicit label overrides), derives the actual labels, the expected double-entry labels and the last thoracic/lumbar vertebra labels.

Parameters:

Name Type Description Default
subject_name str | int

Subject identifier.

required
anomaly_dict dict

Mapping of subject names to anomaly entries; empty if no anomalies are known.

required
vert_subfolders_int list[int]

Raw numeric vertebra labels present for the subject.

required
subject_name_int bool

If True, cast subject_name to int before lookup.

True

Returns:

Name Type Description
SubjectInfo SubjectInfo

The assembled per-subject labelling metadata.

Raises:

Type Description
AssertionError

If a LabelOverride length does not match the number of vertebra labels.

Source code in spineps/architectures/read_labels.py
def get_subject_info(
    subject_name: str | int,
    anomaly_dict: dict,
    vert_subfolders_int: list[int],
    subject_name_int: bool = True,
) -> SubjectInfo:
    """Build a :class:`SubjectInfo` from a subject's raw vertebra labels and any anomaly overrides.

    Applies anomaly handling (label deletion, removal flags, T11/T13 remapping and explicit label overrides), derives the actual
    labels, the expected double-entry labels and the last thoracic/lumbar vertebra labels.

    Label 28 denotes T13. It is sorted between 19 (T12) and 20 (L1) when applying ``LabelOverride`` and is counted as thoracic.

    Args:
        subject_name (str | int): Subject identifier.
        anomaly_dict (dict): Mapping of subject names to anomaly entries; empty if no anomalies are known. The entries read
            here are ``"DeleteLabel"``, ``"Remove"``, ``"T11"``, ``"T13"`` and the optional ``"LabelOverride"``.
        vert_subfolders_int (list[int]): Raw numeric vertebra labels present for the subject.
        subject_name_int (bool): If True, cast ``subject_name`` to int before lookup.

    Returns:
        SubjectInfo: The assembled per-subject labelling metadata. ``last_bwk`` / ``last_lwk`` are ``None`` when the labels
        do not reach far enough caudally, including when no labels are given at all.

    Raises:
        AssertionError: If a ``LabelOverride`` length does not match the number of vertebra labels.
    """
    if subject_name_int:
        subject_name = int(subject_name)
    labelmap = {}
    has_anomaly_entry = False
    anomaly_entry = {}
    deleted_label = []
    is_remove = False
    if subject_name in anomaly_dict:
        anomaly_entry = anomaly_dict[subject_name]
        has_anomaly_entry = True
        if anomaly_entry["DeleteLabel"] is not None:
            deleted_label = [anomaly_entry["DeleteLabel"]]
        if bool(anomaly_entry["Remove"]):
            is_remove = True

        if bool(anomaly_entry["T11"]):
            # Shift labels 19..25 down by one position (label + 1).
            labelmap = {i: i + 1 for i in range(19, 26)}
        elif bool(anomaly_entry["T13"]):
            # Raw 20 becomes T13 (28); the following labels move up by one.
            labelmap = {20: 28, 21: 20, 22: 21, 23: 22, 24: 23, 25: 24}

    if "LabelOverride" in anomaly_entry and anomaly_entry["LabelOverride"] is not None:
        assert len(anomaly_entry["LabelOverride"]) == len(vert_subfolders_int), (
            f"len({anomaly_entry['LabelOverride']}) != len({vert_subfolders_int})"
        )
        # Order the raw labels anatomically (T13 = 28 between 19 and 20) so they pair with the override list.
        vert_subfolders_sorted = sorted(vert_subfolders_int, key=lambda x: x if x != 28 else 19.5)
        labelmap = dict(zip(vert_subfolders_sorted, anomaly_entry["LabelOverride"], strict=False))

    actual_labels = [labelmap.get(v, v) for v in vert_subfolders_int]

    if 28 in actual_labels and 19 not in actual_labels:
        # A T13 without a T12 is inconsistent, so flag the subject for removal.
        print(f"{subject_name}: 28 in {actual_labels} but no 19")
        is_remove = True

    # Which labels count as double entries depends on the thoracolumbar junction layout.
    if 18 in actual_labels and 19 not in actual_labels and 20 in actual_labels:
        # T11 (18) is followed directly by L1 (20), with no T12 (19) label.
        double_entries = [17, 18, 20, 21]
    elif 28 in actual_labels:
        double_entries = [19, 28, 20, 21]
    else:
        double_entries = [18, 19, 20, 21]

    if len(anomaly_dict) == 0:
        double_entries = []

    # default=0 keeps an empty label list from raising ValueError; both boundaries then resolve to None.
    highest_label = max(actual_labels, default=0)
    # Thoracic labels are 8..19 plus T13 (28); the last one is only trusted if the labels reach at least 18.
    bwks = [v for v in actual_labels if 7 < v <= 19 or v == 28]
    last_bwk = max(bwks) if highest_label >= 18 and len(bwks) > 0 else None
    # Candidate last-lumbar labels are 23..25; only trusted if the labels reach at least 23.
    lwks = [v for v in actual_labels if 22 < v < 26]
    last_lwk = max(lwks) if highest_label >= 23 and len(lwks) > 0 else None
    return SubjectInfo(
        subject_name=subject_name,
        has_anomaly_entry=has_anomaly_entry,
        anomaly_entry=anomaly_entry,
        actual_labels=actual_labels,
        deleted_label=deleted_label,
        is_remove=is_remove,
        labelmap=labelmap,
        last_lwk=last_lwk,
        last_bwk=last_bwk,
        double_entries=double_entries,
    )

get_vert_entry

get_vert_entry(
    v: int, subject_info: SubjectInfo
) -> tuple[int, dict]

Build the per-vertebra target entry dict for a single vertebra label.

Applies the subject's label map to v and fills in all derived targets (relative position, exact class, exact2 class, group, region and T13 flag).

Parameters:

Name Type Description Default
v int

Raw numeric vertebra label.

required
subject_info SubjectInfo

Subject metadata providing the label map and region boundaries.

required

Returns:

Type Description
tuple[int, dict]

tuple[int, dict]: The remapped actual label and a dict of target values keyed by their column names.

Source code in spineps/architectures/read_labels.py
def get_vert_entry(v: int, subject_info: SubjectInfo) -> tuple[int, dict]:
    """Assemble all classification targets for one vertebra of a subject.

    The raw label ``v`` is first translated through ``subject_info.labelmap``; labels without a mapping are kept as-is.
    Every target is then derived from the translated label.

    Args:
        v (int): Raw numeric vertebra label.
        subject_info (SubjectInfo): Per-subject metadata supplying the label map and the region boundary labels.

    Returns:
        tuple[int, dict]: The translated label, plus a dict with the keys ``subject_name``, ``vert_rel``, ``vert_cut``
        (always ``False`` here), ``vert_exact``, ``vert_exact2``, ``vert_group``, ``vert_region`` and ``vert_t13``.
    """
    label = subject_info.labelmap.get(v, v)

    relative = vert_label_to_vertrel(
        label,
        subject_info.last_bwk,
        subject_info.last_lwk,
        last_hwk=subject_info.last_hwk,
        first_bwk=subject_info.first_bwk,
        first_lwk=subject_info.first_lwk,
    )
    exact = vert_label_to_class(label)
    exact2 = vert_label_to_exactclass(label)

    entry: dict = {
        "subject_name": subject_info.subject_name,
        "vert_rel": relative,
        "vert_cut": False,
        "vert_exact": exact,
        "vert_exact2": exact2,
        "vert_group": vert_class_to_group(exact),
        "vert_region": vert_class_to_region(exact),
        # Label 28 marks a T13 vertebra.
        "vert_t13": VertT13.T13 if label == 28 else VertT13.Normal,
    }
    return label, entry

spineps.architectures.pl_densenet

spineps.architectures.pl_densenet

DenseNet/ResNet-based classifier (PLClassifier) and model configuration for vertebra labeling.

MODEL

Bases: Enum

Selectable backbone architectures (DenseNet and ResNet variants) for the vertebra classifier.

Source code in spineps/architectures/pl_densenet.py
class MODEL(Enum):
    """Backbone architectures (DenseNet and ResNet variants) that can be selected for the vertebra classifier.

    DenseNet members store the constructor class itself. ResNet members store the network depth, which ``__call__``
    maps to the matching factory function.
    """

    DENSENET169 = DenseNet169
    DENSENET121 = DenseNet121
    # ResNet members: value = depth of the corresponding resnet<depth> factory
    RESNET10 = 10
    RESNET18 = 18
    RESNET34 = 34
    RESNET50 = 50
    RESNET101 = 101
    RESNET152 = 152
    RESNET2 = 2

    def __call__(
        self,
        opt: ARGS_MODEL,
        remove_classification_head: bool = True,
    ) -> tuple[nn.Module, int]:
        """Build the backbone network this member stands for.

        Args:
            opt (ARGS_MODEL): Model configuration. The DenseNet branch reads ``in_channel``, ``num_classes`` and
                ``not_pretrained`` from it.
            remove_classification_head (bool): If True, the backbone's final classification layer is stripped so the
                network acts as a feature extractor.

        Returns:
            tuple: ``(model, linear_in_features)``, where ``linear_in_features`` is the input width of the final head.

        Raises:
            ValueError: If the member name contains neither ``DENSENET`` nor ``RESNET``.
        """
        if "DENSENET" in self.name:
            return get_densenet_architecture(
                self.value,
                in_channel=opt.in_channel,
                out_channel=opt.num_classes,
                pretrained=not opt.not_pretrained,
                remove_classification_head=remove_classification_head,
            )
        if "RESNET" not in self.name:
            raise ValueError(f"Model {self.name} not supported.")

        # NOTE(review): the ResNet branch does not forward anything from opt (in_channel, num_classes, pretraining).
        factory_by_depth = {
            2: resnet2,
            10: resnet10,
            18: resnet18,
            34: resnet34,
            50: resnet50,
            101: resnet101,
            152: resnet152,
        }
        return get_resnet_architecture(
            factory_by_depth[self.value],
            remove_classification_head=remove_classification_head,
        )

__call__

__call__(
    opt: ARGS_MODEL, remove_classification_head: bool = True
) -> tuple[nn.Module, int]

Instantiate the selected backbone network.

Parameters:

Name Type Description Default
opt ARGS_MODEL

Model configuration providing channels, class count and pretraining flag.

required
remove_classification_head bool

If True, strip the backbone's final classification layer so it acts as a feature extractor.

True

Returns:

Name Type Description
tuple tuple[Module, int]

(model, linear_in_features) where linear_in_features is the input feature size of the removed head.

Raises:

Type Description
ValueError

If the enum member is neither a DenseNet nor a ResNet variant.

Source code in spineps/architectures/pl_densenet.py
def __call__(
    self,
    opt: ARGS_MODEL,
    remove_classification_head: bool = True,
) -> tuple[nn.Module, int]:
    """Build the backbone network this member stands for.

    Args:
        opt (ARGS_MODEL): Model configuration. The DenseNet branch reads ``in_channel``, ``num_classes`` and
            ``not_pretrained`` from it.
        remove_classification_head (bool): If True, the backbone's final classification layer is stripped so the
            network acts as a feature extractor.

    Returns:
        tuple: ``(model, linear_in_features)``, where ``linear_in_features`` is the input width of the final head.

    Raises:
        ValueError: If the member name contains neither ``DENSENET`` nor ``RESNET``.
    """
    if "DENSENET" in self.name:
        return get_densenet_architecture(
            self.value,
            in_channel=opt.in_channel,
            out_channel=opt.num_classes,
            pretrained=not opt.not_pretrained,
            remove_classification_head=remove_classification_head,
        )
    if "RESNET" not in self.name:
        raise ValueError(f"Model {self.name} not supported.")

    # NOTE(review): the ResNet branch does not forward anything from opt (in_channel, num_classes, pretraining).
    factory_by_depth = {
        2: resnet2,
        10: resnet10,
        18: resnet18,
        34: resnet34,
        50: resnet50,
        101: resnet101,
        152: resnet152,
    }
    return get_resnet_architecture(
        factory_by_depth[self.value],
        remove_classification_head=remove_classification_head,
    )

ARGS_MODEL dataclass

Bases: Class_to_ArgParse

Configuration (and argparse schema) for the vertebra labeling classifier, covering backbone, heads and training options.

Source code in spineps/architectures/pl_densenet.py
@dataclass
class ARGS_MODEL(Class_to_ArgParse):
    """Configuration (and argparse schema) for the vertebra labeling classifier, covering backbone, heads and training options."""

    # Name of a MODEL member; PLClassifier resolves it with MODEL[opt.backbone].
    # NOTE(review): annotated as MODEL but the default is the member *name* (a str); confirm how Class_to_ArgParse treats it.
    backbone: MODEL = MODEL.DENSENET169.name
    # Passed to build_classification_heads as convolution_first (prepends a Conv3d to every head).
    classification_conv: bool = False
    # Passed to build_classification_heads as fully_connected (adds a hidden Linear+ReLU to every head).
    classification_linear: bool = True
    #
    n_epoch: int = 100
    lr: float = 1e-4
    # Stored by PLClassifier as l2_reg_w.
    l2_regularization_w: float = 1e-6  # 1e-5 was ok
    scheduler_endfactor: float = 1e-3
    #
    # Forwarded as in_channel to the DenseNet constructor; the ResNet branch of MODEL.__call__ does not use it.
    in_channel: int = 1  # 1 for img, will be set elsewhere
    # Inverted before use (pretrained=not opt.not_pretrained), so the default True means no pretrained weights.
    not_pretrained: bool = True
    #
    # Stored by PLClassifier as mse_weighting.
    mse_weighting: float = 0.0
    dropout: float = 0.05
    weight_decay: float = 0  # 1e-4
    #
    # Must be an int before PLClassifier is built (asserted there); also passed as out_channel to the DenseNet constructor.
    num_classes: int | None = None  # Filled elsewhere
    n_channel_p_group: int | None = None  # Filled elsewhere

PLClassifier

Bases: LightningModule

LightningModule that classifies vertebrae using a shared backbone with one classification head per target group.

The configured backbone (DenseNet/ResNet) acts as a feature extractor, and a separate head is built for each entry in group_2_n_channel to produce that group's class logits.

Source code in spineps/architectures/pl_densenet.py
class PLClassifier(pl.LightningModule):
    """LightningModule that classifies vertebrae using a shared backbone with one classification head per target group.

    The configured backbone (DenseNet/ResNet) acts as a feature extractor, and a separate head is built for each entry in
    ``group_2_n_channel`` to produce that group's class logits.
    """

    def __init__(self, opt: ARGS_MODEL, group_2_n_channel: dict[str, int]):
        """Build the backbone, classification heads and loss/activation modules.

        Args:
            opt (ARGS_MODEL): Model configuration; ``opt.num_classes`` must be an int.
            group_2_n_channel (dict[str, int]): Mapping from each target group name to its number of output channels.

        Raises:
            AssertionError: If ``opt.num_classes`` is not an int.
        """
        super().__init__()
        self.opt = opt
        assert isinstance(opt.num_classes, int), opt.num_classes
        self.num_classes: int = opt.num_classes
        self.group_2_n_channel = group_2_n_channel
        # save hyperparameter, everything below not visible
        self.save_hyperparameters()

        # opt.backbone is looked up by member name. Calling the member returns (network without its head, head input width).
        self.backbone = MODEL[opt.backbone]
        self.net, linear_in = self.backbone(opt, remove_classification_head=True)
        self.classification_heads = self.build_classification_heads(linear_in, opt.classification_conv, opt.classification_linear)
        self.classification_keys = list(self.classification_heads.keys())
        self.mse_weighting = opt.mse_weighting

        # Metric names, presumably consumed by the training/validation logging (not shown in this class body).
        self.metrics_to_log = ["f1", "mcc", "acc", "auroc", "f1_avg"]
        self.metrics_to_log_overall = ["f1", "f1_avg"]

        # Step-output buffers and activation/loss modules, presumably used by training/validation steps defined elsewhere.
        self.train_step_outputs = []
        self.val_step_outputs = []
        self.softmax = nn.Softmax(dim=1)  # use this group-wise?
        self.sigmoid = nn.Sigmoid()
        self.cross_entropy = nn.CrossEntropyLoss()
        # reduction="none" returns per-element losses instead of a scalar.
        self.mse = nn.MSELoss(reduction="none")
        self.l2_reg_w = opt.l2_regularization_w

    def forward(self, x) -> dict[str, torch.Tensor]:
        """Extract features with the backbone and apply every classification head.

        Args:
            x (torch.Tensor): Input image batch fed to the backbone.

        Returns:
            dict[str, torch.Tensor]: Mapping from each group name to that head's output logits.
        """
        features = self.net(x)
        return {k: v(features) for k, v in self.classification_heads.items()}

    def build_classification_heads(self, linear_in: int, convolution_first: bool, fully_connected: bool) -> nn.ModuleDict:
        """Build one classification head per target group as a :class:`~torch.nn.ModuleDict`.

        Args:
            linear_in (int): Number of input features coming from the backbone.
            convolution_first (bool): If True, prepend a 3x3x3 convolution that halves the channels before the linear layers.
            fully_connected (bool): If True, insert a hidden linear+ReLU layer (halving channels) before the output layer.

        Returns:
            nn.ModuleDict: Mapping from each group name to its head, each ending in a linear layer with that group's class count.
        """

        def construct_one_head(output_classes: int):
            """Build a single classification head producing ``output_classes`` logits.

            Args:
                output_classes (int): Number of output classes for this head.

            Returns:
                nn.Sequential: The assembled head modules.
            """
            modules = []
            n_channel = linear_in
            n_channel_next = linear_in
            if convolution_first:
                # NOTE(review): the Conv3d output goes straight into nn.Linear with no pooling/flatten in between;
                # confirm the backbone's feature shape actually supports classification_conv=True.
                n_channel_next = n_channel // 2
                modules.append(nn.Conv3d(n_channel, n_channel_next, kernel_size=(3, 3, 3), device=self.device))
                n_channel = n_channel_next
            if fully_connected:
                n_channel_next = n_channel // 2
                modules.append(nn.Linear(n_channel, n_channel_next, device=self.device))
                modules.append(nn.ReLU())
                n_channel = n_channel_next
            modules.append(nn.Linear(n_channel, output_classes, device=self.device))

            return nn.Sequential(*modules)

        return nn.ModuleDict({k: construct_one_head(v) for k, v in self.group_2_n_channel.items()})

    def __str__(self) -> str:
        """Return the model name.

        Returns:
            str: The fixed name ``"VertebraLabelingModel"``.
        """
        return "VertebraLabelingModel"

__init__

__init__(
    opt: ARGS_MODEL, group_2_n_channel: dict[str, int]
)

Build the backbone, classification heads and loss/activation modules.

Parameters:

Name Type Description Default
opt ARGS_MODEL

Model configuration; opt.num_classes must be an int.

required
group_2_n_channel dict[str, int]

Mapping from each target group name to its number of output channels.

required

Raises:

Type Description
AssertionError

If opt.num_classes is not an int.

Source code in spineps/architectures/pl_densenet.py
def __init__(self, opt: ARGS_MODEL, group_2_n_channel: dict[str, int]):
    """Build the backbone, classification heads and loss/activation modules.

    Args:
        opt (ARGS_MODEL): Model configuration; ``opt.num_classes`` must be an int.
        group_2_n_channel (dict[str, int]): Mapping from each target group name to its number of output channels.

    Raises:
        AssertionError: If ``opt.num_classes`` is not an int.
    """
    super().__init__()
    self.opt = opt
    assert isinstance(opt.num_classes, int), opt.num_classes
    self.num_classes: int = opt.num_classes
    self.group_2_n_channel = group_2_n_channel
    # save hyperparameter, everything below not visible
    self.save_hyperparameters()

    # opt.backbone is looked up by member name. Calling the member returns (network without its head, head input width).
    self.backbone = MODEL[opt.backbone]
    self.net, linear_in = self.backbone(opt, remove_classification_head=True)
    self.classification_heads = self.build_classification_heads(linear_in, opt.classification_conv, opt.classification_linear)
    self.classification_keys = list(self.classification_heads.keys())
    self.mse_weighting = opt.mse_weighting

    # Metric names, presumably consumed by the training/validation logging (not shown here).
    self.metrics_to_log = ["f1", "mcc", "acc", "auroc", "f1_avg"]
    self.metrics_to_log_overall = ["f1", "f1_avg"]

    # Step-output buffers and activation/loss modules, presumably used by training/validation steps defined elsewhere.
    self.train_step_outputs = []
    self.val_step_outputs = []
    self.softmax = nn.Softmax(dim=1)  # use this group-wise?
    self.sigmoid = nn.Sigmoid()
    self.cross_entropy = nn.CrossEntropyLoss()
    # reduction="none" returns per-element losses instead of a scalar.
    self.mse = nn.MSELoss(reduction="none")
    self.l2_reg_w = opt.l2_regularization_w

forward

forward(x) -> dict[str, torch.Tensor]

Extract features with the backbone and apply every classification head.

Parameters:

Name Type Description Default
x Tensor

Input image batch fed to the backbone.

required

Returns:

Type Description
dict[str, Tensor]

dict[str, torch.Tensor]: Mapping from each group name to that head's output logits.

Source code in spineps/architectures/pl_densenet.py
def forward(self, x) -> dict[str, torch.Tensor]:
    """Run the shared backbone once and feed its features to every classification head.

    Args:
        x (torch.Tensor): Input image batch for the backbone.

    Returns:
        dict[str, torch.Tensor]: Head output (logits) per group name, in the iteration order of ``classification_heads``.
    """
    shared = self.net(x)
    logits: dict[str, torch.Tensor] = {}
    for group_name, head in self.classification_heads.items():
        logits[group_name] = head(shared)
    return logits

build_classification_heads

build_classification_heads(
    linear_in: int,
    convolution_first: bool,
    fully_connected: bool,
) -> nn.ModuleDict

Build one classification head per target group as a torch.nn.ModuleDict.

Parameters:

Name Type Description Default
linear_in int

Number of input features coming from the backbone.

required
convolution_first bool

If True, prepend a 3x3x3 convolution that halves the channels before the linear layers.

required
fully_connected bool

If True, insert a hidden linear+ReLU layer (halving channels) before the output layer.

required

Returns:

Type Description
ModuleDict

nn.ModuleDict: Mapping from each group name to its head, each ending in a linear layer with that group's class count.

Source code in spineps/architectures/pl_densenet.py
def build_classification_heads(self, linear_in: int, convolution_first: bool, fully_connected: bool) -> nn.ModuleDict:
    """Create a separate classification head for every entry of ``self.group_2_n_channel``.

    Each head is an ``nn.Sequential``. It optionally starts with a 3x3x3 ``Conv3d`` that halves the width, then
    optionally a hidden ``Linear`` + ``ReLU`` that halves the width again, and always ends with a ``Linear`` layer
    emitting that group's number of outputs. All layers are created on ``self.device``.

    Args:
        linear_in (int): Feature width delivered by the backbone.
        convolution_first (bool): Whether to put the halving ``Conv3d`` at the front of each head.
        fully_connected (bool): Whether to add the halving hidden ``Linear`` + ``ReLU`` before the output layer.

    Returns:
        nn.ModuleDict: Head per group name, in the order of ``self.group_2_n_channel``.
    """

    def _make_head(n_outputs: int) -> nn.Sequential:
        """Assemble one head whose final layer produces ``n_outputs`` values."""
        layers = []
        width = linear_in
        if convolution_first:
            layers.append(nn.Conv3d(width, width // 2, kernel_size=(3, 3, 3), device=self.device))
            width //= 2
        if fully_connected:
            layers += [nn.Linear(width, width // 2, device=self.device), nn.ReLU()]
            width //= 2
        layers.append(nn.Linear(width, n_outputs, device=self.device))
        return nn.Sequential(*layers)

    heads = {group_name: _make_head(n_outputs) for group_name, n_outputs in self.group_2_n_channel.items()}
    return nn.ModuleDict(heads)

__str__

__str__() -> str

Return the model name.

Returns:

Name Type Description
str str

The fixed name "VertebraLabelingModel".

Source code in spineps/architectures/pl_densenet.py
def __str__(self) -> str:
    """Give the human-readable model name.

    Returns:
        str: Always ``"VertebraLabelingModel"``, independent of the configuration.
    """
    model_name = "VertebraLabelingModel"
    return model_name

resnet2

resnet2(
    layers: list[int] | None = None, **kwargs
) -> ResNet

Build a very small 2-stage MONAI ResNet variant ("resnet2").

Parameters:

Name Type Description Default
layers list[int] | None

Number of blocks per stage; defaults to [1, 1].

None
**kwargs

Additional keyword arguments forwarded to the MONAI _resnet factory.

{}

Returns:

Name Type Description
ResNet ResNet

The constructed ResNet model.

Source code in spineps/architectures/pl_densenet.py
def resnet2(
    layers: list[int] | None = None,
    **kwargs,
) -> ResNet:
    """Construct the tiny two-stage "resnet2" variant through the ``_resnet`` factory.

    Args:
        layers (list[int] | None): Blocks per stage; ``None`` selects ``[1, 1]``.
        **kwargs: Passed through unchanged to ``_resnet``.

    Returns:
        ResNet: The model returned by ``_resnet``.
    """
    # NOTE(review): MONAI's ResNet builds layer1..layer4 from layers[0..3]; confirm the _resnet/ResNet in use
    # accepts a two-entry list.
    stage_blocks = [1, 1] if layers is None else layers
    return _resnet(
        "resnet2",
        ResNetBlock,
        stage_blocks,
        get_inplanes(),
        False,
        False,
        **kwargs,
    )

get_densenet_architecture

get_densenet_architecture(
    model: object,
    in_channel: int = 1,
    out_channel: int = 1,
    pretrained: bool = True,
    remove_classification_head: bool = True,
) -> tuple[nn.Module, int]

Instantiate a 3D MONAI DenseNet and optionally remove its final classification layer.

Parameters:

Name Type Description Default
model object

A MONAI DenseNet constructor (e.g. DenseNet121 or DenseNet169).

required
in_channel int

Number of input channels.

1
out_channel int

Number of output channels for the original classification layer.

1
pretrained bool

Whether to load pretrained weights.

True
remove_classification_head bool

If True, drop the final classification layer to use the model as a feature extractor.

True

Returns:

Name Type Description
tuple tuple[Module, int]

(model, linear_infeatures) where linear_infeatures is the input feature size of the removed head.

Source code in spineps/architectures/pl_densenet.py
def get_densenet_architecture(
    model: object,
    in_channel: int = 1,
    out_channel: int = 1,
    pretrained: bool = True,
    remove_classification_head: bool = True,
) -> tuple[nn.Module, int]:
    """Create a 3D DenseNet from the given constructor and optionally drop its last classification layer.

    Args:
        model: DenseNet constructor (e.g. MONAI ``DenseNet121`` / ``DenseNet169``), called with ``spatial_dims=3``.
        in_channel (int): Input channels passed as ``in_channels``.
        out_channel (int): Output channels passed as ``out_channels`` (size of the original head).
        pretrained (bool): Forwarded as ``pretrained``.
        remove_classification_head (bool): If True, the final entry of ``class_layers`` is removed so the network acts
            as a feature extractor.

    Returns:
        tuple: ``(model, linear_infeatures)``, where ``linear_infeatures`` is the ``in_features`` of the final
        ``class_layers`` entry, read before any removal.
    """
    # NOTE(review): MONAI's DenseNet pretrained weights are 2D-only; with spatial_dims=3, pretrained=True is
    # expected to raise. Confirm before relying on this default.
    network = model(
        spatial_dims=3,
        in_channels=in_channel,
        out_channels=out_channel,
        pretrained=pretrained,
    )
    n_features = network.class_layers[-1].in_features
    if remove_classification_head:
        # Keep every class layer except the final output layer.
        network.class_layers = network.class_layers[:-1]
    return network, n_features

get_resnet_architecture

get_resnet_architecture(
    model: object, remove_classification_head: bool = True
) -> tuple[nn.Module, int]

Instantiate a 3D MONAI ResNet and optionally remove its fully connected head.

Parameters:

Name Type Description Default
model object

A MONAI ResNet constructor (e.g. resnet18 or resnet50).

required
remove_classification_head bool

If True, set the final fully connected layer to None to use the model as a feature extractor.

True

Returns:

Name Type Description
tuple tuple[Module, int]

(model, linear_infeatures) where linear_infeatures is the input feature size of the removed head.

Source code in spineps/architectures/pl_densenet.py
def get_resnet_architecture(
    model: object,
    remove_classification_head: bool = True,
    *,
    in_channel: int = 1,
) -> tuple[nn.Module, int]:
    """Instantiate a 3D MONAI ResNet and optionally remove its fully connected head.

    Args:
        model: A MONAI ResNet constructor (e.g. ``resnet18`` or ``resnet50``).
        remove_classification_head (bool): If True, set the final fully connected layer to None to use the model as a
            feature extractor.
        in_channel (int): Number of input channels, passed as ``n_input_channels``. The default of 1 matches the value
            that was previously hard-coded, so existing callers are unaffected.

    Returns:
        tuple: ``(model, linear_infeatures)``, where ``linear_infeatures`` is the ``in_features`` of ``fc``, read before
        any removal.
    """
    model = model(
        spatial_dims=3,
        n_input_channels=in_channel,
    )
    linear_infeatures = model.fc.in_features
    if remove_classification_head:
        model.fc = None
    return model, linear_infeatures

spineps.architectures.pl_unet

spineps.architectures.pl_unet

PyTorch Lightning wrapper around the 3D U-Net used for spine segmentation training and inference.

PLNet

Bases: LightningModule

LightningModule wrapping a spineps.architectures.unet3D.Unet3D for multi-class segmentation.

Configures a 4-class 3D U-Net with 10 input channels and provides shared loss/metric helpers (Dice scores) and softmax-based class prediction.

Source code in spineps/architectures/pl_unet.py
class PLNet(pl.LightningModule):
    """Lightning wrapper that owns a :class:`~spineps.architectures.unet3D.Unet3D` segmentation network.

    The network is fixed to 10 input channels and 4 output classes. Helper methods compute the loss together with the
    predicted class map (argmax of the softmax), Dice metrics (overall, foreground-only and per class) and the
    aggregation of accumulated per-step metrics.

    NOTE(review): ``_shared_step`` calls ``self.loss``, which this class never defines; it has to be provided elsewhere
    (e.g. by a subclass or training code) before that helper can be used.
    """

    def __init__(self, opt: Any = None, do2D: bool = False, *args: Any, **kwargs: Any) -> None:  # noqa: ARG002
        """Create the U-Net and record training settings.

        Args:
            opt: Optional options object; when given, its ``n_epoch`` attribute is stored as ``self.n_epoch``
                (otherwise 0 is stored).
            do2D (bool): Flag stored as ``self.do2D``; within this class it only changes the ``__str__`` output.
            *args (Any): Accepted and ignored.
            **kwargs (Any): Accepted and ignored.
        """
        super().__init__()
        # Record the constructor arguments as Lightning hyperparameters.
        self.save_hyperparameters()

        n_classes = 4
        # Base width 16 over four resolution levels; 10 input channels in, one logit channel per class out.
        self.network = Unet3D(
            dim=16,
            dim_mults=(1, 2, 4, 8),
            out_dim=n_classes,
            channels=10,
        )

        self.opt = opt
        self.do2D = do2D
        self.n_epoch = 0 if opt is None else opt.n_epoch
        # Optimisation settings; they are only stored here, none of the methods of this class read them.
        self.start_lr = 0.0001
        self.linear_end_factor = 0.01
        self.l2_reg_w = 0.0001
        self.n_classes = n_classes

        # Accumulators mapping metric name -> list of per-step values.
        self.train_step_outputs = {}
        self.val_step_outputs = {}

        self.softmax = nn.Softmax(dim=1)

    def forward(self, x) -> torch.Tensor:
        """Compute class logits for ``x`` with the wrapped U-Net.

        Args:
            x (torch.Tensor): Input batch, expected shape ``(B, 10, D, H, W)``.

        Returns:
            torch.Tensor: Unnormalised logits, expected shape ``(B, 4, D, H, W)``.
        """
        logits = self.network(x)
        return logits

    def _shared_step(self, target, gt, detach2cpu: bool = False):
        """Forward a batch, evaluate the loss and derive the predicted class map.

        Args:
            target (torch.Tensor): Network input batch.
            gt (torch.Tensor): Ground-truth class labels.
            detach2cpu (bool): When True, ``gt``, ``logits`` and ``pred_cls`` are detached and moved to the CPU.

        Returns:
            tuple: ``(loss, logits, gt, pred_cls)``; ``pred_cls`` holds, per position, the index of the largest
            softmax probability along dimension 1.
        """
        logits = self.forward(target)
        loss = self.loss(logits, gt)

        with torch.no_grad():
            pred_cls = torch.max(self.softmax(logits), dim=1).indices
            if detach2cpu:
                gt, logits, pred_cls = (t.detach().cpu() for t in (gt, logits, pred_cls))
        return loss, logits, gt, pred_cls

    def _shared_metric_step(self, loss, _, gt, pred_cls):
        """Evaluate Dice scores for one batch.

        Args:
            loss (torch.Tensor): Batch loss; stored detached on the CPU.
            _: Logits placeholder (unused).
            gt (torch.Tensor): Ground-truth class labels.
            pred_cls (torch.Tensor): Predicted class labels.

        Returns:
            dict: ``loss``, ``dice`` (all classes), ``diceFG`` (class 0 ignored) and ``dice_p_cls`` (one value per
            class).
        """
        n_cls = self.n_classes
        dice_all = mF.dice(pred_cls, gt, num_classes=n_cls)
        dice_fg = mF.dice(pred_cls, gt, num_classes=n_cls, ignore_index=0)
        dice_per_class = mF.dice(pred_cls, gt, average=None, num_classes=n_cls)
        return {
            "loss": loss.detach().cpu(),
            "dice": dice_all,
            "diceFG": dice_fg,
            "dice_p_cls": dice_per_class,
        }

    def _shared_metric_append(self, metrics, outputs):
        """Add one step's metric values to the accumulator, in place.

        Args:
            metrics (dict): Metric name -> value for a single step.
            outputs (dict): Accumulator metric name -> list of values; missing keys are created.
        """
        for name, value in metrics.items():
            outputs.setdefault(name, []).append(value)

    def _shared_cat_metrics(self, outputs):
        """Average accumulated per-step metrics.

        Args:
            outputs (dict): Metric name -> list of per-step tensors.

        Returns:
            dict: Metric name -> mean; ``dice_p_cls`` is averaged over steps only, so one value per class remains.
        """
        stacked = {name: torch.stack(values) for name, values in outputs.items()}
        return {
            name: tensor.mean(dim=0) if name == "dice_p_cls" else tensor.mean()
            for name, tensor in stacked.items()
        }

    def __str__(self) -> str:
        """Name the model by its spatial mode.

        Returns:
            str: ``"Unet_2D"`` when ``do2D`` is truthy, otherwise ``"Unet_3D"``.
        """
        return f"Unet_{'2D' if self.do2D else '3D'}"

__init__

__init__(
    opt: Any = None,
    do2D: bool = False,
    *args: Any,
    **kwargs: Any,
) -> None

Build the wrapped U-Net and store training hyperparameters.

Parameters:

Name Type Description Default
opt Any

Options object; opt.n_epoch sets the number of epochs when provided.

None
do2D bool

Whether the model operates in 2D mode (affects only the string representation).

False
*args Any

Ignored extra positional arguments.

()
**kwargs Any

Ignored extra keyword arguments.

{}
Source code in spineps/architectures/pl_unet.py
def __init__(self, opt: Any = None, do2D: bool = False, *args: Any, **kwargs: Any) -> None:  # noqa: ARG002
    """Create the U-Net and record training settings.

    Args:
        opt: Optional options object; when given, its ``n_epoch`` attribute is stored as ``self.n_epoch``
            (otherwise 0 is stored).
        do2D (bool): Flag stored as ``self.do2D``; within this class it only changes the ``__str__`` output.
        *args (Any): Accepted and ignored.
        **kwargs (Any): Accepted and ignored.
    """
    super().__init__()
    # Record the constructor arguments as Lightning hyperparameters.
    self.save_hyperparameters()

    n_classes = 4
    # Base width 16 over four resolution levels; 10 input channels in, one logit channel per class out.
    self.network = Unet3D(
        dim=16,
        dim_mults=(1, 2, 4, 8),
        out_dim=n_classes,
        channels=10,
    )

    self.opt = opt
    self.do2D = do2D
    self.n_epoch = 0 if opt is None else opt.n_epoch
    # Optimisation settings; they are only stored here, none of the methods of this class read them.
    self.start_lr = 0.0001
    self.linear_end_factor = 0.01
    self.l2_reg_w = 0.0001
    self.n_classes = n_classes

    # Accumulators mapping metric name -> list of per-step values.
    self.train_step_outputs = {}
    self.val_step_outputs = {}

    self.softmax = nn.Softmax(dim=1)

forward

forward(x) -> torch.Tensor

Run the wrapped U-Net on an input batch.

Parameters:

Name Type Description Default
x Tensor

Input tensor of shape (B, 10, D, H, W).

required

Returns:

Type Description
Tensor

torch.Tensor: Raw class logits of shape (B, 4, D, H, W).

Source code in spineps/architectures/pl_unet.py
def forward(self, x) -> torch.Tensor:
    """Compute class logits for ``x`` with the wrapped U-Net.

    Args:
        x (torch.Tensor): Input batch, expected shape ``(B, 10, D, H, W)``.

    Returns:
        torch.Tensor: Unnormalised logits, expected shape ``(B, 4, D, H, W)``.
    """
    logits = self.network(x)
    return logits

__str__

__str__() -> str

Return a short model name including the spatial mode.

Returns:

Name Type Description
str str

"Unet_2D" or "Unet_3D" depending on do2D.

Source code in spineps/architectures/pl_unet.py
def __str__(self) -> str:
    """Name the model by its spatial mode.

    Returns:
        str: ``"Unet_2D"`` when ``do2D`` is truthy, otherwise ``"Unet_3D"``.
    """
    return f"Unet_{'2D' if self.do2D else '3D'}"

softmax_helper_dim1

softmax_helper_dim1(x: Tensor) -> torch.Tensor

Apply softmax along dimension 1 (the channel/class dimension).

Parameters:

Name Type Description Default
x Tensor

Input tensor with classes on dimension 1.

required

Returns:

Type Description
Tensor

torch.Tensor: Tensor of the same shape with a softmax applied over dimension 1.

Source code in spineps/architectures/pl_unet.py
def softmax_helper_dim1(x: torch.Tensor) -> torch.Tensor:
    """Normalise ``x`` with a softmax taken over dimension 1 (the class/channel axis).

    Args:
        x (torch.Tensor): Tensor whose dimension 1 holds the class scores.

    Returns:
        torch.Tensor: Tensor of identical shape whose entries along dimension 1 sum to one.
    """
    return x.softmax(dim=1)

spineps.architectures.unet3D

spineps.architectures.unet3D

3D U-Net architecture with residual blocks used for volumetric spine segmentation.

Unet3D

Bases: Module

A 3D U-Net with residual (ResNet) blocks, a symmetric encoder/decoder and skip connections.

The encoder repeatedly applies two residual blocks followed by a strided convolution that halves each spatial dimension; a bottleneck of two residual blocks follows; the decoder mirrors the encoder with transposed convolutions and averages in the matching encoder skip features before a final residual block and 1x1x1 output convolution.

Source code in spineps/architectures/unet3D.py
class Unet3D(nn.Module):
    """A 3D U-Net with residual (ResNet) blocks, a symmetric encoder/decoder and skip connections.

    The encoder repeatedly applies two residual blocks followed by a strided convolution that halves each spatial dimension; a
    bottleneck of two residual blocks follows; the decoder mirrors the encoder with transposed convolutions and averages in the
    matching encoder skip features before a final residual block and 1x1x1 output convolution.
    """

    def __init__(
        self,
        dim,
        init_dim=None,
        out_dim=None,
        dim_mults=(1, 2, 4, 8),
        channels=1,
        conditional_dimensions=0,
        resnet_block_groups=8,
        learned_variance=False,
        conditional_label_size=0,
    ):
        """Build the 3D U-Net layers.

        Args:
            dim (int): Base feature dimension used to derive per-level channel counts.
            init_dim (int | None): Channels after the initial convolution; defaults to ``dim``.
            out_dim (int | None): Number of output channels; defaults to ``channels`` (doubled if ``learned_variance``).
            dim_mults (tuple[int, ...]): Per-resolution multipliers of ``dim`` defining encoder/decoder depth and widths.
            channels (int): Number of input image channels.
            conditional_dimensions (int): Extra conditioning channels concatenated to the input at the first convolution.
            resnet_block_groups (int): Number of groups for the GroupNorm inside each residual block.
            learned_variance (bool): If True, doubles the default output channels to also predict variance.
            conditional_label_size (int): Size of an optional conditional label vector (stored but unused in ``forward``).
        """
        super().__init__()

        self.learned_variance = learned_variance

        self.conditional_label_size = conditional_label_size
        # determine dimensions

        self.channels = channels

        init_dim = default(init_dim, dim)
        self.init_conv = nn.Conv3d((channels + conditional_dimensions), init_dim, 7, padding=3)

        dims = [init_dim, *(int(dim * m) for m in dim_mults)]
        in_out = list(zip(dims[:-1], dims[1:]))  # noqa: RUF007

        block_klass = partial(ResnetBlock3D, groups=resnet_block_groups)
        time_dim = None

        self.downs = nn.ModuleList([])
        self.ups = nn.ModuleList([])
        num_resolutions = len(in_out)

        for ind, (dim_in, dim_out) in enumerate(in_out):
            is_last = ind >= (num_resolutions - 1)

            self.downs.append(
                nn.ModuleList(
                    [
                        block_klass(dim_in, dim_out, time_emb_dim=time_dim),
                        block_klass(dim_out, dim_out, time_emb_dim=time_dim),
                        nn.Conv3d(dim_out, dim_out, 4, 2, 1) if not is_last else nn.Identity(),
                    ]
                )
            )

        mid_dim = dims[-1]
        self.mid_block1 = block_klass(mid_dim, mid_dim, time_emb_dim=time_dim)
        self.mid_block2 = block_klass(mid_dim, mid_dim, time_emb_dim=time_dim)

        for ind, (dim_in, dim_out) in enumerate(reversed(in_out)):
            is_last = ind == (len(in_out) - 1)

            self.ups.append(
                nn.ModuleList(
                    [
                        block_klass(dim_out, dim_in, time_emb_dim=time_dim),
                        block_klass(dim_in, dim_in, time_emb_dim=time_dim),
                        nn.ConvTranspose3d(dim_in, dim_in, 4, 2, 1) if not is_last else nn.Identity(),
                    ]
                )
            )

        default_out_dim = channels * (1 if not learned_variance else 2)
        self.out_dim = default(out_dim, default_out_dim) * 1

        self.final_res_block = block_klass(dim * 2, dim, time_emb_dim=time_dim)
        self.final_conv = nn.Conv3d(dim, self.out_dim, 1)
        self.first_forward = False

    def forward(
        self,
        x,
        time: torch.Tensor | None = None,
        label: torch.Tensor | None = None,  # noqa: ARG002
        embedding: torch.Tensor | None = None,  # noqa: ARG002
    ) -> torch.Tensor:  # time
        """Run the U-Net forward pass on a 5D input volume.

        Args:
            x (torch.Tensor): Input tensor of shape ``(B, channels, D, H, W)``; each spatial dimension must be divisible by
                ``2 ** (num_downsampling_levels)``.
            time: Unused timestep input; replaced by a constant if None.
            label: Unused optional conditioning label.
            embedding: Unused optional conditioning embedding.

        Returns:
            torch.Tensor: Output tensor of shape ``(B, out_dim, D, H, W)`` with the same spatial size as the input.

        Raises:
            AssertionError: If any spatial dimension of ``x`` is not divisible by the total downsampling factor.
        """
        down_factor = 2 ** (len(self.downs) - 1)
        shape = x.shape
        assert shape[-1] % down_factor == 0, f"dimensions are not dividable by {down_factor}, {shape}, {shape[-1]}"
        assert shape[-2] % down_factor == 0, f"dimensions are not dividable by {down_factor}, {shape}, {shape[-2]}"
        assert shape[-3] % down_factor == 0, f"dimensions are not dividable by {down_factor}, {shape}, {shape[-3]}"
        if self.first_forward:
            print("|", x.shape)
        if self.first_forward:
            print("|", x.shape)

        # time = None
        if time is None:
            time = torch.ones((1,), device=x.device)
        x = self.init_conv(x)
        r = x.clone()
        if self.first_forward:
            print("-", x.shape)

        t = None

        h = []
        o = "-"
        for block1, block2, downsample in self.downs:  # type: ignore
            x = block1(x, t)
            x = block2(x, t)
            h.append(x)
            x = downsample(x)
            if self.first_forward:
                o += "-"
                print(o, x.shape, "\t")

        x = self.mid_block1(x, t)
        x = self.mid_block2(x, t)
        if self.first_forward:
            print(o, x.shape)
            o = o[:-1]

        for block1, block2, upsample in self.ups:  # type: ignore
            x = 0.5 * (x + h.pop())
            x = block1(x, t)
            x = block2(x, t)
            x = upsample(x)
            if self.first_forward:
                print(o, x.shape, "\t")
                o = o[:-1]

        x = torch.cat((x, r), dim=1)

        x = self.final_res_block(x, t)
        if self.first_forward:
            print("|", x.shape)

        x = self.final_conv(x)

        if self.first_forward:
            print("|", x.shape)

        if self.first_forward:
            print("|", x.shape)

        self.first_forward = False

        return x

__init__

__init__(
    dim,
    init_dim=None,
    out_dim=None,
    dim_mults=(1, 2, 4, 8),
    channels=1,
    conditional_dimensions=0,
    resnet_block_groups=8,
    learned_variance=False,
    conditional_label_size=0,
)

Build the 3D U-Net layers.

Parameters:

Name Type Description Default
dim int

Base feature dimension used to derive per-level channel counts.

required
init_dim int | None

Channels after the initial convolution; defaults to dim.

None
out_dim int | None

Number of output channels; defaults to channels (doubled if learned_variance).

None
dim_mults tuple[int, ...]

Per-resolution multipliers of dim defining encoder/decoder depth and widths.

(1, 2, 4, 8)
channels int

Number of input image channels.

1
conditional_dimensions int

Extra conditioning channels concatenated to the input at the first convolution.

0
resnet_block_groups int

Number of groups for the GroupNorm inside each residual block.

8
learned_variance bool

If True, doubles the default output channels to also predict variance.

False
conditional_label_size int

Size of an optional conditional label vector (stored but unused in forward).

0
Source code in spineps/architectures/unet3D.py
def __init__(
    self,
    dim,
    init_dim=None,
    out_dim=None,
    dim_mults=(1, 2, 4, 8),
    channels=1,
    conditional_dimensions=0,
    resnet_block_groups=8,
    learned_variance=False,
    conditional_label_size=0,
):
    """Build the 3D U-Net layers.

    Args:
        dim (int): Base feature dimension used to derive per-level channel counts.
        init_dim (int | None): Channels after the initial convolution; defaults to ``dim``.
        out_dim (int | None): Number of output channels; defaults to ``channels`` (doubled if ``learned_variance``).
        dim_mults (tuple[int, ...]): Per-resolution multipliers of ``dim`` defining encoder/decoder depth and widths.
        channels (int): Number of input image channels.
        conditional_dimensions (int): Extra conditioning channels accepted by the first convolution in addition to
            ``channels``.
        resnet_block_groups (int): Number of groups for the GroupNorm inside each residual block.
        learned_variance (bool): If True, doubles the default output channels to also predict variance.
        conditional_label_size (int): Size of an optional conditional label vector (stored but unused in ``forward``).
    """
    super().__init__()

    self.learned_variance = learned_variance
    self.conditional_label_size = conditional_label_size
    self.channels = channels

    # determine dimensions
    init_dim = default(init_dim, dim)
    # forward() does not append conditioning channels itself; presumably callers concatenate them to the input.
    self.init_conv = nn.Conv3d((channels + conditional_dimensions), init_dim, 7, padding=3)

    dims = [init_dim, *(int(dim * m) for m in dim_mults)]
    in_out = list(zip(dims[:-1], dims[1:]))  # noqa: RUF007

    block_klass = partial(ResnetBlock3D, groups=resnet_block_groups)
    time_dim = None  # no time embedding: residual blocks are built without their scale/shift MLP

    self.downs = nn.ModuleList([])
    self.ups = nn.ModuleList([])
    num_resolutions = len(in_out)

    for ind, (dim_in, dim_out) in enumerate(in_out):
        is_last = ind >= (num_resolutions - 1)

        self.downs.append(
            nn.ModuleList(
                [
                    block_klass(dim_in, dim_out, time_emb_dim=time_dim),
                    block_klass(dim_out, dim_out, time_emb_dim=time_dim),
                    # kernel 4 / stride 2 / padding 1 halves every even spatial size; the deepest level keeps it.
                    nn.Conv3d(dim_out, dim_out, 4, 2, 1) if not is_last else nn.Identity(),
                ]
            )
        )

    mid_dim = dims[-1]
    self.mid_block1 = block_klass(mid_dim, mid_dim, time_emb_dim=time_dim)
    self.mid_block2 = block_klass(mid_dim, mid_dim, time_emb_dim=time_dim)

    for ind, (dim_in, dim_out) in enumerate(reversed(in_out)):
        is_last = ind == (len(in_out) - 1)

        self.ups.append(
            nn.ModuleList(
                [
                    block_klass(dim_out, dim_in, time_emb_dim=time_dim),
                    block_klass(dim_in, dim_in, time_emb_dim=time_dim),
                    # kernel 4 / stride 2 / padding 1 exactly doubles every spatial size.
                    nn.ConvTranspose3d(dim_in, dim_in, 4, 2, 1) if not is_last else nn.Identity(),
                ]
            )
        )

    default_out_dim = channels * (1 if not learned_variance else 2)
    self.out_dim = default(out_dim, default_out_dim)

    # The decoder ends with init_dim channels and is concatenated with the init_conv output (also init_dim
    # channels), so the final block receives 2 * init_dim channels. (dim * 2 only matched when init_dim == dim.)
    self.final_res_block = block_klass(init_dim * 2, dim, time_emb_dim=time_dim)
    self.final_conv = nn.Conv3d(dim, self.out_dim, 1)
    self.first_forward = False

forward

forward(
    x,
    time: Tensor | None = None,
    label: Tensor | None = None,
    embedding: Tensor | None = None,
) -> torch.Tensor

Run the U-Net forward pass on a 5D input volume.

Parameters:

Name Type Description Default
x Tensor

Input tensor of shape (B, channels, D, H, W); each spatial dimension must be divisible by the total downsampling factor 2 ** (len(dim_mults) - 1).

required
time Tensor | None

Unused timestep input; replaced by a constant if None.

None
label Tensor | None

Unused optional conditioning label.

None
embedding Tensor | None

Unused optional conditioning embedding.

None

Returns:

Type Description
Tensor

torch.Tensor: Output tensor of shape (B, out_dim, D, H, W) with the same spatial size as the input.

Raises:

Type Description
AssertionError

If any spatial dimension of x is not divisible by the total downsampling factor.

Source code in spineps/architectures/unet3D.py
def forward(
    self,
    x,
    time: torch.Tensor | None = None,
    label: torch.Tensor | None = None,  # noqa: ARG002
    embedding: torch.Tensor | None = None,  # noqa: ARG002
) -> torch.Tensor:  # time
    """Run the U-Net forward pass on a 5D input volume.

    Args:
        x (torch.Tensor): Input tensor of shape ``(B, channels, D, H, W)``; each spatial dimension must be divisible by
            ``2 ** (num_downsampling_levels)``.
        time: Unused timestep input; replaced by a constant if None.
        label: Unused optional conditioning label.
        embedding: Unused optional conditioning embedding.

    Returns:
        torch.Tensor: Output tensor of shape ``(B, out_dim, D, H, W)`` with the same spatial size as the input.

    Raises:
        AssertionError: If any spatial dimension of ``x`` is not divisible by the total downsampling factor.
    """
    down_factor = 2 ** (len(self.downs) - 1)
    shape = x.shape
    assert shape[-1] % down_factor == 0, f"dimensions are not dividable by {down_factor}, {shape}, {shape[-1]}"
    assert shape[-2] % down_factor == 0, f"dimensions are not dividable by {down_factor}, {shape}, {shape[-2]}"
    assert shape[-3] % down_factor == 0, f"dimensions are not dividable by {down_factor}, {shape}, {shape[-3]}"
    if self.first_forward:
        print("|", x.shape)
    if self.first_forward:
        print("|", x.shape)

    # time = None
    if time is None:
        time = torch.ones((1,), device=x.device)
    x = self.init_conv(x)
    r = x.clone()
    if self.first_forward:
        print("-", x.shape)

    t = None

    h = []
    o = "-"
    for block1, block2, downsample in self.downs:  # type: ignore
        x = block1(x, t)
        x = block2(x, t)
        h.append(x)
        x = downsample(x)
        if self.first_forward:
            o += "-"
            print(o, x.shape, "\t")

    x = self.mid_block1(x, t)
    x = self.mid_block2(x, t)
    if self.first_forward:
        print(o, x.shape)
        o = o[:-1]

    for block1, block2, upsample in self.ups:  # type: ignore
        x = 0.5 * (x + h.pop())
        x = block1(x, t)
        x = block2(x, t)
        x = upsample(x)
        if self.first_forward:
            print(o, x.shape, "\t")
            o = o[:-1]

    x = torch.cat((x, r), dim=1)

    x = self.final_res_block(x, t)
    if self.first_forward:
        print("|", x.shape)

    x = self.final_conv(x)

    if self.first_forward:
        print("|", x.shape)

    if self.first_forward:
        print("|", x.shape)

    self.first_forward = False

    return x

Block3D

Bases: Module

Basic 3D conv block: 3x3x3 convolution, group normalization, optional FiLM-style scale/shift and LeakyReLU.

Source code in spineps/architectures/unet3D.py
class Block3D(nn.Module):
    """Conv3d (3x3x3, padding 1) -> GroupNorm -> optional ``x * (scale + 1) + shift`` -> LeakyReLU."""

    def __init__(self, dim, dim_out, groups=8):
        """Create the convolution, normalisation and activation layers.

        Args:
            dim (int): Input channel count.
            dim_out (int): Output channel count.
            groups (int): Number of GroupNorm groups; must divide ``dim_out``.
        """
        super().__init__()
        # Submodule names become state_dict keys, so they must stay stable for checkpoint loading.
        self.proj = nn.Conv3d(dim, dim_out, kernel_size=3, padding=1)
        self.norm = nn.GroupNorm(num_groups=groups, num_channels=dim_out)
        self.act = nn.LeakyReLU()

    def forward(self, x, scale_shift=None) -> torch.Tensor:
        """Convolve, normalise, optionally modulate, then activate.

        Args:
            x (torch.Tensor): Tensor of shape ``(B, dim, D, H, W)``.
            scale_shift (tuple[torch.Tensor, torch.Tensor] | None): Optional ``(scale, shift)`` pair applied to the
                normalised features as ``x * (scale + 1) + shift``.

        Returns:
            torch.Tensor: Tensor of shape ``(B, dim_out, D, H, W)``.
        """
        normed = self.norm(self.proj(x))
        if scale_shift is None:
            return self.act(normed)
        scale, shift = scale_shift
        return self.act(normed * (scale + 1) + shift)

__init__

__init__(dim, dim_out, groups=8)

Build the conv block.

Parameters:

Name Type Description Default
dim int

Number of input channels.

required
dim_out int

Number of output channels.

required
groups int

Number of groups for the GroupNorm.

8
Source code in spineps/architectures/unet3D.py
def __init__(self, dim, dim_out, groups=8):
    """Create the convolution, normalisation and activation layers.

    Args:
        dim (int): Input channel count.
        dim_out (int): Output channel count.
        groups (int): Number of GroupNorm groups; must divide ``dim_out``.
    """
    super().__init__()
    # Submodule names become state_dict keys, so they must stay stable for checkpoint loading.
    self.proj = nn.Conv3d(dim, dim_out, kernel_size=3, padding=1)
    self.norm = nn.GroupNorm(num_groups=groups, num_channels=dim_out)
    self.act = nn.LeakyReLU()

forward

forward(x, scale_shift=None) -> torch.Tensor

Apply convolution, normalization, optional scale/shift modulation and activation.

Parameters:

Name Type Description Default
x Tensor

Input tensor of shape (B, dim, D, H, W).

required
scale_shift tuple[Tensor, Tensor] | None

Optional (scale, shift) tensors applied as x * (scale + 1) + shift after normalization.

None

Returns:

Type Description
Tensor

torch.Tensor: Output tensor of shape (B, dim_out, D, H, W).

Source code in spineps/architectures/unet3D.py
def forward(self, x, scale_shift=None) -> torch.Tensor:
    """Convolve, normalise, optionally modulate, then activate.

    Args:
        x (torch.Tensor): Tensor of shape ``(B, dim, D, H, W)``.
        scale_shift (tuple[torch.Tensor, torch.Tensor] | None): Optional ``(scale, shift)`` pair applied to the
            normalised features as ``x * (scale + 1) + shift``.

    Returns:
        torch.Tensor: Tensor of shape ``(B, dim_out, D, H, W)``.
    """
    normed = self.norm(self.proj(x))
    if scale_shift is None:
        return self.act(normed)
    scale, shift = scale_shift
    return self.act(normed * (scale + 1) + shift)

ResnetBlock3D

Bases: Module

Residual block of two Block3D layers with a skip connection and optional time-embedding modulation.

Source code in spineps/architectures/unet3D.py
class ResnetBlock3D(nn.Module):
    """Two stacked :class:`Block3D` layers plus a residual shortcut, optionally modulated by a time embedding."""

    def __init__(self, dim, dim_out, *, time_emb_dim=None, groups=8):
        """Create the inner blocks, the shortcut and (optionally) the time-embedding projection.

        Args:
            dim (int): Input channel count.
            dim_out (int): Output channel count.
            time_emb_dim (int | None): Size of the time embedding; when given, an SiLU + Linear head maps it to
                ``2 * dim_out`` values (scale and shift). When None, no such head is created.
            groups (int): Number of GroupNorm groups in each inner block.
        """
        super().__init__()
        if time_emb_dim is None:
            self.mlp = None
        else:
            self.mlp = nn.Sequential(nn.SiLU(), nn.Linear(time_emb_dim, dim_out * 2))

        self.block1 = Block3D(dim, dim_out, groups=groups)
        self.block2 = Block3D(dim_out, dim_out, groups=groups)
        # A 1x1x1 convolution is only needed to match channel counts on the shortcut.
        self.res_conv = nn.Identity() if dim == dim_out else nn.Conv3d(dim, dim_out, 1)

    def forward(self, x, time_emb=None) -> torch.Tensor:
        """Apply both inner blocks and add the shortcut.

        Args:
            x (torch.Tensor): Tensor of shape ``(B, dim, D, H, W)``.
            time_emb (torch.Tensor | None): Optional ``(B, time_emb_dim)`` embedding; it is only used when the block
                was built with ``time_emb_dim``, in which case it yields the scale/shift for the first inner block.

        Returns:
            torch.Tensor: Tensor of shape ``(B, dim_out, D, H, W)``.
        """
        modulation = None
        if time_emb is not None and self.mlp is not None:
            projected = rearrange(self.mlp(time_emb), "b c -> b c 1 1 1")
            modulation = projected.chunk(2, dim=1)

        out = self.block2(self.block1(x, scale_shift=modulation))
        return out + self.res_conv(x)

__init__

__init__(dim, dim_out, *, time_emb_dim=None, groups=8)

Build the residual block.

Parameters:

Name Type Description Default
dim int

Number of input channels.

required
dim_out int

Number of output channels.

required
time_emb_dim int | None

If given, size of a time embedding mapped to per-channel scale and shift parameters.

None
groups int

Number of groups for the GroupNorm in each inner block.

8
Source code in spineps/architectures/unet3D.py
def __init__(self, dim, dim_out, *, time_emb_dim=None, groups=8):
    """Create the inner blocks, the shortcut and (optionally) the time-embedding projection.

    Args:
        dim (int): Input channel count.
        dim_out (int): Output channel count.
        time_emb_dim (int | None): Size of the time embedding; when given, an SiLU + Linear head maps it to
            ``2 * dim_out`` values (scale and shift). When None, no such head is created.
        groups (int): Number of GroupNorm groups in each inner block.
    """
    super().__init__()
    if time_emb_dim is None:
        self.mlp = None
    else:
        self.mlp = nn.Sequential(nn.SiLU(), nn.Linear(time_emb_dim, dim_out * 2))

    self.block1 = Block3D(dim, dim_out, groups=groups)
    self.block2 = Block3D(dim_out, dim_out, groups=groups)
    # A 1x1x1 convolution is only needed to match channel counts on the shortcut.
    self.res_conv = nn.Identity() if dim == dim_out else nn.Conv3d(dim, dim_out, 1)

forward

forward(x, time_emb=None) -> torch.Tensor

Apply the two conv blocks plus residual connection, optionally modulated by a time embedding.

Parameters:

Name Type Description Default
x Tensor

Input tensor of shape (B, dim, D, H, W).

required
time_emb Tensor | None

Optional time embedding of shape (B, time_emb_dim) used to produce the scale/shift applied in the first inner block.

None

Returns:

Type Description
Tensor

torch.Tensor: Output tensor of shape (B, dim_out, D, H, W).

Source code in spineps/architectures/unet3D.py
def forward(self, x, time_emb=None) -> torch.Tensor:
    """Apply both inner blocks and add the shortcut.

    Args:
        x (torch.Tensor): Tensor of shape ``(B, dim, D, H, W)``.
        time_emb (torch.Tensor | None): Optional ``(B, time_emb_dim)`` embedding; it is only used when the block
            was built with ``time_emb_dim``, in which case it yields the scale/shift for the first inner block.

    Returns:
        torch.Tensor: Tensor of shape ``(B, dim_out, D, H, W)``.
    """
    modulation = None
    if time_emb is not None and self.mlp is not None:
        projected = rearrange(self.mlp(time_emb), "b c -> b c 1 1 1")
        modulation = projected.chunk(2, dim=1)

    out = self.block2(self.block1(x, scale_shift=modulation))
    return out + self.res_conv(x)

default

default(val: object, d: object) -> object

Return val if it is not None, otherwise a default value.

Parameters:

Name Type Description Default
val object

The candidate value.

required
d object

The fallback value, or a zero-argument callable that produces it.

required

Returns:

Type Description
object

val if not None; otherwise d() when d is callable, else d.

Source code in spineps/architectures/unet3D.py
def default(val: object, d: object) -> object:
    """Return ``val`` unless it is None, in which case fall back to ``d``.

    Only plain Python functions (those for which ``inspect.isfunction`` is true, e.g. ``def`` functions and lambdas)
    are treated as lazy factories and called with no arguments. Any other fallback, including classes, builtins and
    ``functools.partial`` objects, is returned unchanged.

    Args:
        val: The candidate value; any value other than None (including falsy ones such as 0) is returned as-is.
        d: The fallback value, or a plain zero-argument function producing it.

    Returns:
        ``val`` if it is not None; otherwise ``d()`` for a plain function ``d``, else ``d`` itself.
    """
    if val is None:
        if isfunction(d):
            return d()
        return d
    return val