mirror of
https://github.com/lucidrains/vit-pytorch.git
synced 2025-12-30 08:02:29 +00:00
add option to set frame padding for 3D CCT (#339)
This commit is contained in:
@@ -167,8 +167,10 @@ class Tokenizer(nn.Module):
|
|||||||
stride,
|
stride,
|
||||||
padding,
|
padding,
|
||||||
frame_stride=1,
|
frame_stride=1,
|
||||||
|
frame_padding=None,
|
||||||
frame_pooling_stride=1,
|
frame_pooling_stride=1,
|
||||||
frame_pooling_kernel_size=1,
|
frame_pooling_kernel_size=1,
|
||||||
|
frame_pooling_padding=None,
|
||||||
pooling_kernel_size=3,
|
pooling_kernel_size=3,
|
||||||
pooling_stride=2,
|
pooling_stride=2,
|
||||||
pooling_padding=1,
|
pooling_padding=1,
|
||||||
@@ -188,16 +190,22 @@ class Tokenizer(nn.Module):
|
|||||||
|
|
||||||
n_filter_list_pairs = zip(n_filter_list[:-1], n_filter_list[1:])
|
n_filter_list_pairs = zip(n_filter_list[:-1], n_filter_list[1:])
|
||||||
|
|
||||||
|
if frame_padding is None:
|
||||||
|
frame_padding = frame_kernel_size // 2
|
||||||
|
|
||||||
|
if frame_pooling_padding is None:
|
||||||
|
frame_pooling_padding = frame_pooling_kernel_size // 2
|
||||||
|
|
||||||
self.conv_layers = nn.Sequential(
|
self.conv_layers = nn.Sequential(
|
||||||
*[nn.Sequential(
|
*[nn.Sequential(
|
||||||
nn.Conv3d(chan_in, chan_out,
|
nn.Conv3d(chan_in, chan_out,
|
||||||
kernel_size=(frame_kernel_size, kernel_size, kernel_size),
|
kernel_size=(frame_kernel_size, kernel_size, kernel_size),
|
||||||
stride=(frame_stride, stride, stride),
|
stride=(frame_stride, stride, stride),
|
||||||
padding=(frame_kernel_size // 2, padding, padding), bias=conv_bias),
|
padding=(frame_padding, padding, padding), bias=conv_bias),
|
||||||
nn.Identity() if not exists(activation) else activation(),
|
nn.Identity() if not exists(activation) else activation(),
|
||||||
nn.MaxPool3d(kernel_size=(frame_pooling_kernel_size, pooling_kernel_size, pooling_kernel_size),
|
nn.MaxPool3d(kernel_size=(frame_pooling_kernel_size, pooling_kernel_size, pooling_kernel_size),
|
||||||
stride=(frame_pooling_stride, pooling_stride, pooling_stride),
|
stride=(frame_pooling_stride, pooling_stride, pooling_stride),
|
||||||
padding=(frame_pooling_kernel_size // 2, pooling_padding, pooling_padding)) if max_pool else nn.Identity()
|
padding=(frame_pooling_padding, pooling_padding, pooling_padding)) if max_pool else nn.Identity()
|
||||||
)
|
)
|
||||||
for chan_in, chan_out in n_filter_list_pairs
|
for chan_in, chan_out in n_filter_list_pairs
|
||||||
])
|
])
|
||||||
@@ -324,8 +332,10 @@ class CCT(nn.Module):
|
|||||||
n_conv_layers=1,
|
n_conv_layers=1,
|
||||||
frame_stride=1,
|
frame_stride=1,
|
||||||
frame_kernel_size=3,
|
frame_kernel_size=3,
|
||||||
|
frame_padding=None,
|
||||||
frame_pooling_kernel_size=1,
|
frame_pooling_kernel_size=1,
|
||||||
frame_pooling_stride=1,
|
frame_pooling_stride=1,
|
||||||
|
frame_pooling_padding=None,
|
||||||
kernel_size=7,
|
kernel_size=7,
|
||||||
stride=2,
|
stride=2,
|
||||||
padding=3,
|
padding=3,
|
||||||
@@ -342,8 +352,10 @@ class CCT(nn.Module):
|
|||||||
n_output_channels=embedding_dim,
|
n_output_channels=embedding_dim,
|
||||||
frame_stride=frame_stride,
|
frame_stride=frame_stride,
|
||||||
frame_kernel_size=frame_kernel_size,
|
frame_kernel_size=frame_kernel_size,
|
||||||
|
frame_padding=frame_padding,
|
||||||
frame_pooling_stride=frame_pooling_stride,
|
frame_pooling_stride=frame_pooling_stride,
|
||||||
frame_pooling_kernel_size=frame_pooling_kernel_size,
|
frame_pooling_kernel_size=frame_pooling_kernel_size,
|
||||||
|
frame_pooling_padding=frame_pooling_padding,
|
||||||
kernel_size=kernel_size,
|
kernel_size=kernel_size,
|
||||||
stride=stride,
|
stride=stride,
|
||||||
padding=padding,
|
padding=padding,
|
||||||
|
|||||||
Reference in New Issue
Block a user