Adding padding option to autoencoder (#7068)
Fixes #7045.

### Description

Added "padding" option to `monai/network/nets/autoencoder.py` such that
the conv and residual units will be passed the padding option.
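
A minimal usage sketch of the new option (my illustration, not part of the PR; the hyperparameters are arbitrary, and with the default `kernel_size=3`, an explicit `padding=1` reproduces the previous implicit "same" padding):

```python
import torch
from monai.networks.nets import AutoEncoder

# Small 2D autoencoder; the new `padding` argument is forwarded to every
# Convolution and ResidualUnit that the network constructs internally.
net = AutoEncoder(
    spatial_dims=2,
    in_channels=1,
    out_channels=1,
    channels=(16, 32),
    strides=(2, 2),
    padding=1,  # for kernel_size=3 this matches the old implicit default
)

x = torch.randn(1, 1, 64, 64)
print(net(x).shape)  # torch.Size([1, 1, 64, 64])
```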

### Types of changes
- [x] Non-breaking change (fix or new feature that would not break existing functionality).
- [ ] New tests added to cover the changes.
- [ ] Integration tests passed locally by running `./runtests.sh -f -u --net --coverage`.
- [x] Quick tests passed locally by running `./runtests.sh --quick --unittests --disttests`.
- [x] In-line docstrings updated.
- [x] Documentation updated; tested with the `make html` command in the `docs/` folder.

### Notes
I haven't been able to run `./runtests.sh`; even `./runtests.sh -h` produced no output (no error, nothing), presumably because I don't have permission to execute `.sh` files on this machine. However, the changes are very small and the new argument defaults to the previous behaviour, so unless a caller explicitly passes `padding`, existing usage of the class is unaffected.
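
The backward-compatibility claim can also be sanity-checked: when `padding` is left as `None`, MONAI's `Convolution` computes the same implicit padding it always has, via `same_padding`. A small sketch (my illustration, not part of the PR):

```python
from monai.networks.layers.convutils import same_padding

# The implicit "same" padding used when padding=None, for the default
# kernel_size=3: identical to what the network used before this change.
print(same_padding(3))  # 1
```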

---------

Signed-off-by: Jupilogy <[email protected]>
JupiLogy authored Sep 30, 2023
1 parent 317ef1f commit 2b39067
Showing 1 changed file with 10 additions and 0 deletions.
monai/networks/nets/autoencoder.py:

```diff
@@ -58,6 +58,8 @@ class AutoEncoder(nn.Module):
         bias: whether to have a bias term in convolution blocks. Defaults to True.
             According to `Performance Tuning Guide <https://pytorch.org/tutorials/recipes/recipes/tuning_guide.html>`_,
             if a conv layer is directly followed by a batch norm layer, bias should be False.
+        padding: controls the amount of implicit zero-paddings on both sides for padding number of points
+            for each dimension in convolution blocks. Defaults to None.
 
     Examples::
@@ -104,6 +106,7 @@ def __init__(
         norm: tuple | str = Norm.INSTANCE,
         dropout: tuple | str | float | None = None,
         bias: bool = True,
+        padding: Sequence[int] | int | None = None,
     ) -> None:
         super().__init__()
         self.dimensions = spatial_dims
@@ -118,6 +121,7 @@ def __init__(
         self.norm = norm
         self.dropout = dropout
         self.bias = bias
+        self.padding = padding
         self.num_inter_units = num_inter_units
         self.inter_channels = inter_channels if inter_channels is not None else []
         self.inter_dilations = list(inter_dilations or [1] * len(self.inter_channels))
@@ -178,6 +182,7 @@ def _get_intermediate_module(self, in_channels: int, num_inter_units: int) -> tu
                         dropout=self.dropout,
                         dilation=di,
                         bias=self.bias,
+                        padding=self.padding,
                     )
                 else:
                     unit = Convolution(
@@ -191,6 +196,7 @@ def _get_intermediate_module(self, in_channels: int, num_inter_units: int) -> tu
                         dropout=self.dropout,
                         dilation=di,
                         bias=self.bias,
+                        padding=self.padding,
                     )
 
                 intermediate.add_module("inter_%i" % i, unit)
@@ -231,6 +237,7 @@ def _get_encode_layer(self, in_channels: int, out_channels: int, strides: int, i
                 norm=self.norm,
                 dropout=self.dropout,
                 bias=self.bias,
+                padding=self.padding,
                 last_conv_only=is_last,
             )
             return mod
@@ -244,6 +251,7 @@ def _get_encode_layer(self, in_channels: int, out_channels: int, strides: int, i
             norm=self.norm,
             dropout=self.dropout,
             bias=self.bias,
+            padding=self.padding,
             conv_only=is_last,
         )
         return mod
@@ -264,6 +272,7 @@ def _get_decode_layer(self, in_channels: int, out_channels: int, strides: int, i
             norm=self.norm,
             dropout=self.dropout,
             bias=self.bias,
+            padding=self.padding,
             conv_only=is_last and self.num_res_units == 0,
             is_transposed=True,
         )
@@ -282,6 +291,7 @@ def _get_decode_layer(self, in_channels: int, out_channels: int, strides: int, i
                 norm=self.norm,
                 dropout=self.dropout,
                 bias=self.bias,
+                padding=self.padding,
                 last_conv_only=is_last,
             )
```
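
One way to confirm that the option really reaches the underlying PyTorch layers is to inspect a built network (again an illustrative sketch, not part of the PR; it assumes the default `num_res_units=0`, so the only `Conv2d` modules are the encoder convolutions):

```python
import torch
from monai.networks.nets import AutoEncoder

net = AutoEncoder(
    spatial_dims=2,
    in_channels=1,
    out_channels=1,
    channels=(16, 32),
    strides=(2, 2),
    padding=2,
)

# Every forward (non-transposed) convolution should report the padding.
paddings = {m.padding for m in net.modules() if isinstance(m, torch.nn.Conv2d)}
print(paddings)  # {(2, 2)}
```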
