Architectures
Network architectures and the vertebra label definitions used by the models.
spineps.architectures.read_labels
Vertebra label definitions, enums and mappings used for vertebra classification targets.
VertRegion
VertRel
Bases: Enum
Relative position of a vertebra at a region boundary (e.g. last cervical, first thoracic).
Source code in spineps/architectures/read_labels.py
VertExact
Bases: Enum
Exact vertebra identity from C1 to L5 (T12 absorbs a potential T13), with 24 classes (0-23).
VertExactClass
Bases: Enum
Exact vertebra identity with explicit T13 and L6 classes, giving 26 classes (0-25).
VertT13
VertGroup
Bases: Enum
Coarse vertebra grouping that buckets neighbouring vertebrae into shared classes (12 groups).
LabelType
Bases: ABC
Abstract base class for converting one or more columns of a data entry into a model target label.
number_of_channel (abstractmethod, property)
Number of output channels (label vector length) produced by this label type.
__init__
Initialize the label type with the source column name(s).
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| column_name | str \| list[str] | Single column name or list of column names to read from each entry; a single string is wrapped into a one-element list. | required |
__call__
Read the configured columns from entry_dict and convert them into a label.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| entry_dict | dict | Mapping of column names to values for a single sample. | required |
Returns:
| Type | Description |
|---|---|
| object | The label produced by :meth: |
get_entry
Extract the configured column value(s) from a data entry.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| entry | dict | Mapping of column names to values. | required |
Returns:
| Type | Description |
|---|---|
| str \| int \| list[str \| int] | The single value if only one column is configured, otherwise a list of values. |
convert_to_label (abstractmethod)
Convert an extracted entry value into the label representation for this label type.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| entry | str | The value extracted from the data entry. | required |
EnumLabelType
Bases: LabelType
Label type that one-hot encodes an enum.Enum value into a multi-class target vector.
number_of_channel (property)
Number of channels, equal to the number of members in the configured enum.
__init__
Initialize the enum label type.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| enum | Enum | Enum class whose members define the classes; its length sets the number of channels. | required |
| column_name | str | Column name holding the enum value for each entry. | required |
convert_to_label
One-hot encode an enum member into a label vector.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| entry | Enum | Enum member whose | required |
Returns:
| Type | Description |
|---|---|
| list[int] | A list of zeros with a single 1 at the index given by |
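EnumLabelType's encoding can be illustrated with a small enum. The index of the 1 is cut off in the description above; the sketch assumes it is member.value, with values numbered from 0. The Region enum and one_hot_enum are hypothetical (Region only reuses the HWS/BWS/LWS region names used elsewhere on this page).

```python
from enum import Enum


class Region(Enum):
    # Hypothetical enum; values are assumed to run 0..len(Region) - 1.
    HWS = 0
    BWS = 1
    LWS = 2


def one_hot_enum(member: Enum) -> list[int]:
    """One-hot encode an enum member (sketch; assumes member.value is the class index)."""
    n_channels = len(type(member))  # number of enum members = number of channels
    label = [0] * n_channels
    label[member.value] = 1
    return label


print(one_hot_enum(Region.BWS))  # [0, 1, 0]
```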
BinaryLabelTypeDummy
Bases: LabelType
Label type for a binary attribute, one-hot encoded into two channels (true/false).
__init__
Initialize the binary label type.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| column_name | str \| list[str] | Column name(s) holding the binary value. | required |
convert_to_label
Convert a truthy/falsy entry into a two-channel one-hot label.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| entry | str \| int | A value contained in | required |
Returns:
| Type | Description |
|---|---|
| int | list[int]: |
Raises:
| Type | Description |
|---|---|
| AssertionError | If |
Target
Bases: Enum
Available classification targets, each mapping to a (label type, enum/column, column name) configuration tuple.
Objectives
Bundle of classification targets that builds and combines their label vectors for a single data entry.
group_2_n_channel (property)
Mapping from each target name to its number of channels.
required_dict_keys (property)
Unique set of data-entry column names required to compute all objectives.
__init__
Initialize the objectives and instantiate the label type for each target.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| objectives | list[Target] | Targets to compute labels for, in order. | required |
| as_group | bool | If True, | True |
__call__
Compute the label(s) for all objectives from a single data entry.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| entry | dict | Data entry containing at least every key in :attr: | required |
Returns:
| Type | Description |
|---|---|
| list[int] | list[int] \| dict \| None: A flat concatenated label list when |
| list[int] | lists when |
Raises:
| Type | Description |
|---|---|
| AssertionError | If a required key is missing from |
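The description of __call__ is truncated, so the following is only a sketch of how per-target label vectors could be combined. It assumes (not confirmed above) that as_group=True returns a dict keyed by target name and as_group=False returns one flat concatenated list; the None return case is not modelled, and build_labels, one_hot and the target definitions are hypothetical.

```python
from typing import Callable

LabelFn = Callable[[object], list[int]]


def one_hot(index: int, n: int) -> list[int]:
    return [1 if i == index else 0 for i in range(n)]


def build_labels(entry: dict, targets: dict[str, tuple[str, LabelFn]], as_group: bool):
    """Hypothetical Objectives-style call; targets maps name -> (column, label function)."""
    # Unique set of columns needed by all targets (cf. required_dict_keys).
    required = {column for column, _ in targets.values()}
    for key in required:
        assert key in entry, f"required key {key!r} missing from entry"
    per_target = {name: fn(entry[column]) for name, (column, fn) in targets.items()}
    if as_group:
        return per_target  # assumed: one label list per target name
    flat: list[int] = []
    for label in per_target.values():  # concatenate in target order
        flat.extend(label)
    return flat


targets = {
    "region": ("region", lambda v: one_hot(v, 3)),
    "group": ("group", lambda v: one_hot(v, 4)),
}
entry = {"region": 2, "group": 1, "subject": "001"}
print(build_labels(entry, targets, as_group=True))   # {'region': [0, 0, 1], 'group': [0, 1, 0, 0]}
print(build_labels(entry, targets, as_group=False))  # [0, 0, 1, 0, 1, 0, 0]
```

Extra keys in the entry (here "subject") are ignored; only missing required keys fail.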
SubjectInfo (dataclass)
Per-subject vertebra labelling metadata, including anomalies, the resolved label map and region boundaries.
has_tea (property)
Whether the subject has a transitional anomaly (a T11 or T13 anomaly entry).
Returns:
| Type | Description |
|---|---|
| bool | bool \| None: True/False based on the T11/T13 anomaly flags, or None if the subject has no anomaly entry. |
block (property)
Dataset block identifier, taken from the first three digits of the subject name.
Returns:
| Name | Type | Description |
|---|---|---|
| int | int | The integer formed by the first three characters of |
vert_label_to_vertrel
vert_label_to_vertrel(
vertlabel: int,
last_bwk: int | None,
last_lwk: int | None,
last_hwk=7,
first_bwk=8,
first_lwk=20,
) -> VertRel
Map a numeric vertebra label to its region-boundary relation.
Note that last_hwk, first_bwk and first_lwk are reset to their fixed defaults (7, 8, 20) inside the function, so values passed for these three parameters have no effect.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| vertlabel | int | Numeric vertebra label to classify. | required |
| last_bwk | int \| None | Label of the last thoracic vertebra, or None if unknown. | required |
| last_lwk | int \| None | Label of the last lumbar vertebra, or None if unknown. | required |
| last_hwk | int | Label of the last cervical vertebra (overwritten to 7). | 7 |
| first_bwk | int | Label of the first thoracic vertebra (overwritten to 8). | 8 |
| first_lwk | int | Label of the first lumbar vertebra (overwritten to 20). | 20 |
Returns:
| Name | Type | Description |
|---|---|---|
| VertRel | VertRel | The boundary relation of the given label (NOTHING if it is not a boundary vertebra). |
vert_class_to_region
Map an exact vertebra class to its spinal region.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| vert_exact | VertExact | Exact vertebra identity. | required |
Returns:
| Name | Type | Description |
|---|---|---|
| VertRegion | VertRegion | HWS for C1-C7, BWS for T1-T12 and LWS for the lumbar vertebrae. |
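Assuming VertExact numbers C1 to L5 consecutively from 0 (C1-C7 = 0-6, T1-T12 = 7-18, L1-L5 = 19-23, which agrees with the vertlabel - 1 rule of vert_label_to_class and the default boundaries 7/8/20 of vert_label_to_vertrel on this page), the region lookup reduces to two thresholds. The function is a hypothetical illustration that returns region names as strings rather than VertRegion members.

```python
def exact_index_to_region(index: int) -> str:
    """Sketch: region name for a VertExact index, assuming C1..L5 map to 0..23 in order."""
    if not 0 <= index <= 23:
        raise ValueError(f"expected a VertExact index in 0..23, got {index}")
    if index <= 6:    # C1-C7
        return "HWS"
    if index <= 18:   # T1-T12
        return "BWS"
    return "LWS"      # L1-L5


print([exact_index_to_region(i) for i in (0, 6, 7, 18, 19, 23)])
# ['HWS', 'HWS', 'BWS', 'BWS', 'LWS', 'LWS']
```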
vert_label_to_class
Map a numeric vertebra label to a VertExact class.
Label 28 (T13) is folded into T12; all other labels map to vertlabel - 1, capped at 23 (L5).
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| vertlabel | int | Numeric vertebra label. | required |
Returns:
| Name | Type | Description |
|---|---|---|
| VertExact | VertExact | The corresponding exact vertebra class. |
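The rule can be traced with plain integers. The sketch returns indices instead of VertExact members and assumes the label numbering implied on this page (C1 = 1, T1 = 8, L1 = 20, L6 = 25, T13 = 28); label_to_exact_index and the constant names are hypothetical.

```python
T13_LABEL = 28  # numeric label used for T13 (from the description above)
T12_INDEX = 18  # assumed VertExact index of T12 (C1 = 0)
L5_INDEX = 23   # last VertExact class


def label_to_exact_index(vertlabel: int) -> int:
    """Sketch of the vert_label_to_class rule, returning a plain index."""
    if vertlabel == T13_LABEL:
        return T12_INDEX  # T13 is folded into T12
    return min(vertlabel - 1, L5_INDEX)  # labels beyond L5 (e.g. L6 = 25) collapse to L5


print([label_to_exact_index(v) for v in (1, 7, 8, 19, 20, 24, 25, 28)])
# [0, 6, 7, 18, 19, 23, 23, 18]
```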
vert_label_to_exactclass
Map a numeric vertebra label to a VertExactClass class.
Label 28 maps to the explicit T13 class; labels up to 19 map to vertlabel - 1 (capped at 24) and higher labels to vertlabel (capped at 25), accounting for the extra T13/L6 slots.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| vertlabel | int | Numeric vertebra label. | required |
Returns:
| Name | Type | Description |
|---|---|---|
| VertExactClass | VertExactClass | The corresponding exact vertebra class. |
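Tracing labels through this rule shows why lumbar labels keep their numeric value: the T13 slot pushes them up by one class. The sketch assumes T13 occupies index 19, between T12 (18) and L1 (20), which is consistent with the description but not stated explicitly; label_to_exactclass_index is hypothetical.

```python
def label_to_exactclass_index(vertlabel: int) -> int:
    """Sketch of the vert_label_to_exactclass rule, returning a plain index."""
    if vertlabel == 28:
        return 19  # explicit T13 slot, assumed to sit between T12 (18) and L1 (20)
    if vertlabel <= 19:
        return min(vertlabel - 1, 24)  # C1..T12 -> 0..18
    return min(vertlabel, 25)  # L1..L6 -> 20..25; the T13 slot shifts lumbar labels up by one


print([label_to_exactclass_index(v) for v in (1, 19, 28, 20, 24, 25, 26)])
# [0, 18, 19, 20, 24, 25, 25]
```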
vert_class_to_group
Map an exact vertebra class to its coarse VertGroup.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| vert_exact | VertExact | Exact vertebra identity. | required |
Returns:
| Name | Type | Description |
|---|---|---|
| VertGroup | VertGroup | The group that contains the given vertebra. |
vertgrp_sequence_to_class
Resolve a top-to-bottom sequence of vertebra groups into exact vertebra classes.
For each group, the assignment is trivial if every member of the group is present; otherwise, the neighbouring group before or after the partial run determines whether its members are aligned with the top or the bottom of the group.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| vertgrp | list[VertGroup] | Vertebra groups ordered from top (cranial) to bottom (caudal). | required |
Returns:
| Type | Description |
|---|---|
| list[VertExact] | Exact vertebra classes for each position in the input sequence. |
Raises:
| Type | Description |
|---|---|
| AssertionError | If a partial group has neighbours on both sides, which cannot be resolved unambiguously. |
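The run-based resolution can be sketched with a toy grouping. This is an illustration under assumptions: the real VertGroup buckets are not listed on this page, so the three groups below are hypothetical, groups are plain strings, and a partial group with no neighbour at all is assumed to align from the top.

```python
from itertools import groupby

# Hypothetical grouping: each group lists its members from top (cranial) to bottom (caudal).
GROUP_MEMBERS = {
    "G1": ["C1", "C2"],
    "G2": ["C3", "C4", "C5"],
    "G3": ["C6", "C7"],
}


def resolve_groups(sequence: list[str], members: dict[str, list[str]]) -> list[str]:
    """Turn a cranial-to-caudal sequence of group ids into exact vertebra names."""
    runs = [(group, len(list(items))) for group, items in groupby(sequence)]
    resolved: list[str] = []
    for i, (group, length) in enumerate(runs):
        group_members = members[group]
        assert length <= len(group_members), f"too many vertebrae for group {group!r}"
        if length == len(group_members):
            resolved.extend(group_members)  # complete group: trivial assignment
            continue
        has_before = i > 0
        has_after = i < len(runs) - 1
        assert not (has_before and has_after), f"partial group {group!r} is ambiguous"
        if has_after:
            # The run ends where the next group starts, so it is the bottom of its group.
            resolved.extend(group_members[-length:])
        else:
            # The run continues from the group above (or stands alone): top of its group.
            resolved.extend(group_members[:length])
    return resolved


print(resolve_groups(["G2", "G2", "G3", "G3"], GROUP_MEMBERS))  # ['C4', 'C5', 'C6', 'C7']
print(resolve_groups(["G1", "G1", "G2"], GROUP_MEMBERS))        # ['C1', 'C2', 'C3']
```

A sequence such as G1, G1, G2, G3, G3 fails the second assertion: with only one G2 vertebra between two neighbours, it is unclear which of C3-C5 is missing.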
flatten
Recursively flatten an arbitrarily nested list of strings and integers.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| a | list[str \| int \| list[str] \| list[int]] | A value or (nested) list of strings and integers. | required |
Yields:
| Type | Description |
|---|---|
| str \| int | Each leaf string or integer in depth-first order. |
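A minimal recursive generator with the documented behaviour is shown below. It is a sketch, not necessarily the library's implementation: it only recurses into list instances, so tuples or other iterables would be yielded as leaves, and strings are yielded whole rather than split into characters.

```python
from collections.abc import Iterator


def flatten(a: object) -> Iterator[object]:
    """Depth-first flattening of arbitrarily nested lists (minimal sketch)."""
    if isinstance(a, list):
        for item in a:
            yield from flatten(item)
    else:
        yield a  # a leaf (string or integer)


print(list(flatten([1, ["a", [2, 3]], "b"])))  # [1, 'a', 2, 3, 'b']
print(list(flatten(7)))                        # [7]
```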
get_subject_info
get_subject_info(
subject_name: str | int,
anomaly_dict: dict,
vert_subfolders_int: list[int],
subject_name_int: bool = True,
) -> SubjectInfo
Build a SubjectInfo from a subject's raw vertebra labels and any anomaly overrides.
Applies anomaly handling (label deletion, removal flags, T11/T13 remapping and explicit label overrides), then derives the actual labels, the expected double-entry labels and the last thoracic/lumbar vertebra labels.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| subject_name | str \| int | Subject identifier. | required |
| anomaly_dict | dict | Mapping of subject names to anomaly entries; empty if no anomalies are known. | required |
| vert_subfolders_int | list[int] | Raw numeric vertebra labels present for the subject. | required |
| subject_name_int | bool | If True, cast | True |
Returns:
| Name | Type | Description |
|---|---|---|
| SubjectInfo | SubjectInfo | The assembled per-subject labelling metadata. |
Raises:
| Type | Description |
|---|---|
| AssertionError | If a |
get_vert_entry
Build the per-vertebra target entry dict for a single vertebra label.
Applies the subject's label map to v and fills in all derived targets (relative position, exact class, exact2 class, group, region and T13 flag).
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| v | int | Raw numeric vertebra label. | required |
| subject_info | SubjectInfo | Subject metadata providing the label map and region boundaries. | required |
Returns:
| Type | Description |
|---|---|
| tuple[int, dict] | The remapped actual label and a dict of target values keyed by their column names. |
spineps.architectures.pl_densenet
DenseNet/ResNet-based classifier (PLClassifier) and model configuration for vertebra labeling.
MODEL
Bases: Enum
Selectable backbone architectures (DenseNet and ResNet variants) for the vertebra classifier.
Source code in spineps/architectures/pl_densenet.py
__call__
Instantiate the selected backbone network.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| opt | ARGS_MODEL | Model configuration providing channels, class count and pretraining flag. | required |
| remove_classification_head | bool | If True, strip the backbone's final classification layer so it acts as a feature extractor. | True |
Returns:
| Name | Type | Description |
|---|---|---|
| tuple | tuple[Module, int] | |
Raises:
| Type | Description |
|---|---|
| ValueError | If the enum member is neither a DenseNet nor a ResNet variant. |
ARGS_MODEL (dataclass)
Bases: Class_to_ArgParse
Configuration (and argparse schema) for the vertebra labeling classifier, covering backbone, heads and training options.
PLClassifier
Bases: LightningModule
LightningModule that classifies vertebrae using a shared backbone with one classification head per target group.
The configured backbone (DenseNet/ResNet) acts as a feature extractor, and a separate head is built for each entry in group_2_n_channel to produce that group's class logits.
__init__
Build the backbone, classification heads and loss/activation modules.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| opt | ARGS_MODEL | Model configuration; | required |
| group_2_n_channel | dict[str, int] | Mapping from each target group name to its number of output channels. | required |
Raises:
| Type | Description |
|---|---|
| AssertionError | If |
forward
Extract features with the backbone and apply every classification head.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| x | Tensor | Input image batch fed to the backbone. | required |
Returns:
| Type | Description |
|---|---|
| dict[str, Tensor] | Mapping from each group name to that head's output logits. |
build_classification_heads
build_classification_heads(
linear_in: int,
convolution_first: bool,
fully_connected: bool,
) -> nn.ModuleDict
Build one classification head per target group as a torch.nn.ModuleDict.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| linear_in | int | Number of input features coming from the backbone. | required |
| convolution_first | bool | If True, prepend a 3x3x3 convolution that halves the channels before the linear layers. | required |
| fully_connected | bool | If True, insert a hidden linear+ReLU layer (halving channels) before the output layer. | required |
Returns:
| Type | Description |
|---|---|
| ModuleDict | Mapping from each group name to its head, each ending in a linear layer with that group's class count. |
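The channel arithmetic of one head can be traced without building any layers. The sketch assumes integer halving (widths[-1] // 2) and an input width of 1024; the group names and head_widths are hypothetical, and any pooling or flattening between the convolution and the linear layers is not modelled.

```python
def head_widths(linear_in: int, n_classes: int, convolution_first: bool, fully_connected: bool) -> list[int]:
    """Hypothetical helper listing the channel widths along one classification head."""
    widths = [linear_in]
    if convolution_first:
        widths.append(widths[-1] // 2)  # 3x3x3 convolution halves the channels
    if fully_connected:
        widths.append(widths[-1] // 2)  # hidden Linear + ReLU halves them again
    widths.append(n_classes)  # final Linear layer outputs the group's class count
    return widths


group_2_n_channel = {"exact": 24, "region": 3}  # hypothetical groups
for name, n in group_2_n_channel.items():
    print(name, head_widths(1024, n, convolution_first=True, fully_connected=True))
# exact [1024, 512, 256, 24]
# region [1024, 512, 256, 3]
```

Only the last width differs between heads, which is why a single shared backbone output can feed every group.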
__str__
Return the model name.
Returns:
| Name | Type | Description |
|---|---|---|
| str | str | The fixed name |
resnet2
Build a very small 2-stage MONAI ResNet variant ("resnet2").
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| layers | list[int] \| None | Number of blocks per stage; defaults to | None |
| **kwargs | | Additional keyword arguments forwarded to the MONAI | {} |
Returns:
| Name | Type | Description |
|---|---|---|
| ResNet | ResNet | The constructed ResNet model. |
get_densenet_architecture
get_densenet_architecture(
model: object,
in_channel: int = 1,
out_channel: int = 1,
pretrained: bool = True,
remove_classification_head: bool = True,
) -> tuple[nn.Module, int]
Instantiate a 3D MONAI DenseNet and optionally remove its final classification layer.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| model | object | A MONAI DenseNet constructor (e.g. | required |
| in_channel | int | Number of input channels. | 1 |
| out_channel | int | Number of output channels for the original classification layer. | 1 |
| pretrained | bool | Whether to load pretrained weights. | True |
| remove_classification_head | bool | If True, drop the final classification layer to use the model as a feature extractor. | True |
Returns:
| Name | Type | Description |
|---|---|---|
| tuple | tuple[Module, int] | |
get_resnet_architecture
get_resnet_architecture(
model: object, remove_classification_head: bool = True
) -> tuple[nn.Module, int]
Instantiate a 3D MONAI ResNet and optionally remove its fully connected head.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| model | object | A MONAI ResNet constructor (e.g. | required |
| remove_classification_head | bool | If True, set the final fully connected layer to None to use the model as a feature extractor. | True |
Returns:
| Name | Type | Description |
|---|---|---|
| tuple | tuple[Module, int] | |
spineps.architectures.pl_unet
PyTorch Lightning wrapper around the 3D U-Net used for spine segmentation training and inference.
PLNet
Bases: LightningModule
LightningModule wrapping a spineps.architectures.unet3D.Unet3D for multi-class segmentation.
It configures a 4-class 3D U-Net with 10 input channels and provides shared loss/metric helpers (Dice scores) and softmax-based class prediction.
Source code in spineps/architectures/pl_unet.py
__init__
Build the wrapped U-Net and store training hyperparameters.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| opt | Any | Options object; | None |
| do2D | bool | Whether the model operates in 2D mode (affects only the string representation). | False |
| *args | Any | Ignored extra positional arguments. | () |
| **kwargs | Any | Ignored extra keyword arguments. | {} |
forward
Run the wrapped U-Net on an input batch.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| x | Tensor | Input tensor of shape | required |
Returns:
| Type | Description |
|---|---|
| Tensor | Raw class logits of shape |
__str__
Return a short model name including the spatial mode.
Returns:
| Name | Type | Description |
|---|---|---|
| str | str | |
softmax_helper_dim1
Apply softmax along dimension 1 (the channel/class dimension).
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| x | Tensor | Input tensor with classes on dimension 1. | required |
Returns:
| Type | Description |
|---|---|
| Tensor | Tensor of the same shape with a softmax applied over dimension 1. |
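For readers without PyTorch at hand, the same operation can be sketched in NumPy; in PyTorch the equivalent call is torch.softmax(x, dim=1). Subtracting the per-voxel maximum before exponentiating leaves the result unchanged but avoids overflow. The argmax over the class axis illustrates softmax-based class prediction in general; how PLNet does this internally is not shown here, and softmax_dim1 is a hypothetical name.

```python
import numpy as np


def softmax_dim1(x: np.ndarray) -> np.ndarray:
    """NumPy analogue of a softmax over dimension 1 (the class axis)."""
    shifted = x - x.max(axis=1, keepdims=True)  # subtract the per-voxel max for numerical stability
    exp = np.exp(shifted)
    return exp / exp.sum(axis=1, keepdims=True)


rng = np.random.default_rng(0)
logits = rng.normal(size=(2, 4, 3, 3, 3))  # (batch, classes, D, H, W), 4 classes as in PLNet
probs = softmax_dim1(logits)
pred = probs.argmax(axis=1)  # per-voxel class index, shape (2, 3, 3, 3)
print(np.allclose(probs.sum(axis=1), 1.0), pred.shape)  # True (2, 3, 3, 3)
```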
spineps.architectures.unet3D
3D U-Net architecture with residual blocks used for volumetric spine segmentation.
Unet3D
Bases: Module
A 3D U-Net with residual (ResNet) blocks, a symmetric encoder/decoder and skip connections.
The encoder repeatedly applies two residual blocks followed by a strided convolution that halves each spatial dimension, and a bottleneck of two residual blocks follows. The decoder mirrors the encoder with transposed convolutions and averages in the matching encoder skip features; a final residual block and a 1x1x1 output convolution produce the output.
Source code in spineps/architectures/unet3D.py
__init__
__init__(
dim,
init_dim=None,
out_dim=None,
dim_mults=(1, 2, 4, 8),
channels=1,
conditional_dimensions=0,
resnet_block_groups=8,
learned_variance=False,
conditional_label_size=0,
)
Build the 3D U-Net layers.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| dim | int | Base feature dimension used to derive per-level channel counts. | required |
| init_dim | int \| None | Channels after the initial convolution; defaults to | None |
| out_dim | int \| None | Number of output channels; defaults to | None |
| dim_mults | tuple[int, ...] | Per-resolution multipliers of | (1, 2, 4, 8) |
| channels | int | Number of input image channels. | 1 |
| conditional_dimensions | int | Extra conditioning channels concatenated to the input at the first convolution. | 0 |
| resnet_block_groups | int | Number of groups for the GroupNorm inside each residual block. | 8 |
| learned_variance | bool | If True, doubles the default output channels to also predict variance. | False |
| conditional_label_size | int | Size of an optional conditional label vector (stored but unused in | 0 |
forward
forward(
x,
time: Tensor | None = None,
label: Tensor | None = None,
embedding: Tensor | None = None,
) -> torch.Tensor
Run the U-Net forward pass on a 5D input volume.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| x | Tensor | Input tensor of shape | required |
| time | Tensor \| None | Unused timestep input; replaced by a constant if None. | None |
| label | Tensor \| None | Unused optional conditioning label. | None |
| embedding | Tensor \| None | Unused optional conditioning embedding. | None |
Returns:
| Type | Description |
|---|---|
| Tensor | Output tensor of shape |
Raises:
| Type | Description |
|---|---|
| AssertionError | If any spatial dimension of |
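The exact divisibility condition behind the AssertionError is cut off above. Encoder-decoder networks whose stages each halve the spatial size usually need every spatial dimension to be divisible by 2 ** k for k halvings, so inputs are commonly padded first. The helper below is a hypothetical sketch of that padding computation; the number of halvings n_down is an assumption that would have to be matched to the configured dim_mults.

```python
def pad_to_multiple(spatial_shape: tuple[int, ...], n_down: int) -> list[tuple[int, int]]:
    """Hypothetical helper: (before, after) padding per axis so that every size
    becomes divisible by 2 ** n_down."""
    factor = 2 ** n_down
    pads = []
    for size in spatial_shape:
        total = (-size) % factor  # voxels missing to reach the next multiple of factor
        pads.append((total // 2, total - total // 2))  # split as evenly as possible
    return pads


print(pad_to_multiple((97, 128, 30), n_down=3))  # [(3, 4), (0, 0), (1, 1)]
```

Prefixed with (0, 0) pairs for the batch and channel axes, the result has the (before, after) form that numpy.pad accepts as pad_width.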
Block3D
Bases: Module
Basic 3D conv block: 3x3x3 convolution, group normalization, optional FiLM-style scale/shift and LeakyReLU.
__init__
Build the conv block.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| dim | int | Number of input channels. | required |
| dim_out | int | Number of output channels. | required |
| groups | int | Number of groups for the GroupNorm. | 8 |
forward
Apply convolution, normalization, optional scale/shift modulation and activation.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| x | Tensor | Input tensor of shape | required |
| scale_shift | tuple[Tensor, Tensor] \| None | Optional | None |
Returns:
| Type | Description |
|---|---|
| Tensor | Output tensor of shape |
ResnetBlock3D
Bases: Module
Residual block of two Block3D layers with a skip connection and optional time-embedding modulation.
__init__
Build the residual block.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| dim | int | Number of input channels. | required |
| dim_out | int | Number of output channels. | required |
| time_emb_dim | int \| None | If given, size of a time embedding mapped to per-channel scale and shift parameters. | None |
| groups | int | Number of groups for the GroupNorm in each inner block. | 8 |
forward
Apply the two conv blocks plus the residual connection, optionally modulated by a time embedding.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| x | Tensor | Input tensor of shape | required |
| time_emb | Tensor \| None | Optional time embedding of shape | None |
Returns:
| Type | Description |
|---|---|
| Tensor | Output tensor of shape |
default
Return val if it is not None, otherwise a default value.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| val | object | The candidate value. | required |
| d | object | The fallback value, or a zero-argument callable that produces it. | required |
Returns:
| Type | Description |
|---|---|
| object | |
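A minimal sketch of the documented behaviour follows, assuming callables are detected with the built-in callable(). Only None triggers the fallback, so falsy values such as 0 are kept. A consequence of the callable check is that a function cannot itself be returned as a plain fallback value.

```python
from typing import Any


def default(val: Any, d: Any) -> Any:
    """Return val unless it is None; otherwise the fallback d (called if callable)."""
    if val is not None:
        return val
    return d() if callable(d) else d


print(default(0, 5))        # 0   (0 is not None, so it is kept)
print(default(None, 5))     # 5
print(default(None, list))  # []  (callable fallback is invoked)
```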