Make all CCT regularization parameters user-configurable. (#346)

This commit is contained in:
Baraa sameeh
2025-08-18 09:07:25 +08:00
committed by GitHub
parent f8bec5ede2
commit 1123063a5e

View File

@@ -316,6 +316,9 @@ class CCT(nn.Module):
pooling_kernel_size=3,
pooling_stride=2,
pooling_padding=1,
dropout_rate=0.,
attention_dropout=0.1,
stochastic_depth_rate=0.1,
*args, **kwargs
):
super().__init__()
@@ -340,9 +343,9 @@ class CCT(nn.Module):
width=img_width),
embedding_dim=embedding_dim,
seq_pool=True,
dropout_rate=0.,
attention_dropout=0.1,
stochastic_depth=0.1,
dropout_rate=dropout_rate,
attention_dropout=attention_dropout,
stochastic_depth_rate=stochastic_depth_rate,
*args, **kwargs)
def forward(self, x):