diff --git a/vit_pytorch/cct_3d.py b/vit_pytorch/cct_3d.py index e14fda1..7c93c26 100644 --- a/vit_pytorch/cct_3d.py +++ b/vit_pytorch/cct_3d.py @@ -167,8 +167,10 @@ class Tokenizer(nn.Module): stride, padding, frame_stride=1, + frame_padding=None, frame_pooling_stride=1, frame_pooling_kernel_size=1, + frame_pooling_padding=None, pooling_kernel_size=3, pooling_stride=2, pooling_padding=1, @@ -188,16 +190,22 @@ class Tokenizer(nn.Module): n_filter_list_pairs = zip(n_filter_list[:-1], n_filter_list[1:]) + if frame_padding is None: + frame_padding = frame_kernel_size // 2 + + if frame_pooling_padding is None: + frame_pooling_padding = frame_pooling_kernel_size // 2 + self.conv_layers = nn.Sequential( *[nn.Sequential( nn.Conv3d(chan_in, chan_out, kernel_size=(frame_kernel_size, kernel_size, kernel_size), stride=(frame_stride, stride, stride), - padding=(frame_kernel_size // 2, padding, padding), bias=conv_bias), + padding=(frame_padding, padding, padding), bias=conv_bias), nn.Identity() if not exists(activation) else activation(), nn.MaxPool3d(kernel_size=(frame_pooling_kernel_size, pooling_kernel_size, pooling_kernel_size), stride=(frame_pooling_stride, pooling_stride, pooling_stride), - padding=(frame_pooling_kernel_size // 2, pooling_padding, pooling_padding)) if max_pool else nn.Identity() + padding=(frame_pooling_padding, pooling_padding, pooling_padding)) if max_pool else nn.Identity() ) for chan_in, chan_out in n_filter_list_pairs ]) @@ -324,8 +332,10 @@ class CCT(nn.Module): n_conv_layers=1, frame_stride=1, frame_kernel_size=3, + frame_padding=None, frame_pooling_kernel_size=1, frame_pooling_stride=1, + frame_pooling_padding=None, kernel_size=7, stride=2, padding=3, @@ -342,8 +352,10 @@ class CCT(nn.Module): n_output_channels=embedding_dim, frame_stride=frame_stride, frame_kernel_size=frame_kernel_size, + frame_padding=frame_padding, frame_pooling_stride=frame_pooling_stride, frame_pooling_kernel_size=frame_pooling_kernel_size, + frame_pooling_padding=frame_pooling_padding, kernel_size=kernel_size, stride=stride, padding=padding,