From 1123063a5e6ef7d5cf7556e3d59a2774b256dfc4 Mon Sep 17 00:00:00 2001 From: Baraa sameeh <64686299+Heterochromi@users.noreply.github.com> Date: Mon, 18 Aug 2025 09:07:25 +0800 Subject: [PATCH] Make all CCT regularization parameters user-configurable. (#346) --- vit_pytorch/cct.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/vit_pytorch/cct.py b/vit_pytorch/cct.py index 4b37699..4231156 100644 --- a/vit_pytorch/cct.py +++ b/vit_pytorch/cct.py @@ -316,6 +316,9 @@ class CCT(nn.Module): pooling_kernel_size=3, pooling_stride=2, pooling_padding=1, + dropout_rate=0., + attention_dropout=0.1, + stochastic_depth_rate=0.1, *args, **kwargs ): super().__init__() @@ -340,9 +343,9 @@ class CCT(nn.Module): width=img_width), embedding_dim=embedding_dim, seq_pool=True, - dropout_rate=0., - attention_dropout=0.1, - stochastic_depth=0.1, + dropout_rate=dropout_rate, + attention_dropout=attention_dropout, + stochastic_depth_rate=stochastic_depth_rate, *args, **kwargs) def forward(self, x):