add option to set frame padding for 3D CCT (#339)

This commit is contained in:
Kale Kundert
2025-01-04 10:55:27 -05:00
committed by GitHub
parent e7cba9ba6d
commit b7ed6bad28

View File

@@ -167,8 +167,10 @@ class Tokenizer(nn.Module):
stride, stride,
padding, padding,
frame_stride=1, frame_stride=1,
frame_padding=None,
frame_pooling_stride=1, frame_pooling_stride=1,
frame_pooling_kernel_size=1, frame_pooling_kernel_size=1,
frame_pooling_padding=None,
pooling_kernel_size=3, pooling_kernel_size=3,
pooling_stride=2, pooling_stride=2,
pooling_padding=1, pooling_padding=1,
@@ -188,16 +190,22 @@ class Tokenizer(nn.Module):
n_filter_list_pairs = zip(n_filter_list[:-1], n_filter_list[1:]) n_filter_list_pairs = zip(n_filter_list[:-1], n_filter_list[1:])
if frame_padding is None:
frame_padding = frame_kernel_size // 2
if frame_pooling_padding is None:
frame_pooling_padding = frame_pooling_kernel_size // 2
self.conv_layers = nn.Sequential( self.conv_layers = nn.Sequential(
*[nn.Sequential( *[nn.Sequential(
nn.Conv3d(chan_in, chan_out, nn.Conv3d(chan_in, chan_out,
kernel_size=(frame_kernel_size, kernel_size, kernel_size), kernel_size=(frame_kernel_size, kernel_size, kernel_size),
stride=(frame_stride, stride, stride), stride=(frame_stride, stride, stride),
padding=(frame_kernel_size // 2, padding, padding), bias=conv_bias), padding=(frame_padding, padding, padding), bias=conv_bias),
nn.Identity() if not exists(activation) else activation(), nn.Identity() if not exists(activation) else activation(),
nn.MaxPool3d(kernel_size=(frame_pooling_kernel_size, pooling_kernel_size, pooling_kernel_size), nn.MaxPool3d(kernel_size=(frame_pooling_kernel_size, pooling_kernel_size, pooling_kernel_size),
stride=(frame_pooling_stride, pooling_stride, pooling_stride), stride=(frame_pooling_stride, pooling_stride, pooling_stride),
padding=(frame_pooling_kernel_size // 2, pooling_padding, pooling_padding)) if max_pool else nn.Identity() padding=(frame_pooling_padding, pooling_padding, pooling_padding)) if max_pool else nn.Identity()
) )
for chan_in, chan_out in n_filter_list_pairs for chan_in, chan_out in n_filter_list_pairs
]) ])
@@ -324,8 +332,10 @@ class CCT(nn.Module):
n_conv_layers=1, n_conv_layers=1,
frame_stride=1, frame_stride=1,
frame_kernel_size=3, frame_kernel_size=3,
frame_padding=None,
frame_pooling_kernel_size=1, frame_pooling_kernel_size=1,
frame_pooling_stride=1, frame_pooling_stride=1,
frame_pooling_padding=None,
kernel_size=7, kernel_size=7,
stride=2, stride=2,
padding=3, padding=3,
@@ -342,8 +352,10 @@ class CCT(nn.Module):
n_output_channels=embedding_dim, n_output_channels=embedding_dim,
frame_stride=frame_stride, frame_stride=frame_stride,
frame_kernel_size=frame_kernel_size, frame_kernel_size=frame_kernel_size,
frame_padding=frame_padding,
frame_pooling_stride=frame_pooling_stride, frame_pooling_stride=frame_pooling_stride,
frame_pooling_kernel_size=frame_pooling_kernel_size, frame_pooling_kernel_size=frame_pooling_kernel_size,
frame_pooling_padding=frame_pooling_padding,
kernel_size=kernel_size, kernel_size=kernel_size,
stride=stride, stride=stride,
padding=padding, padding=padding,