allow channels to be customizable for cvt

This commit is contained in:
lucidrains
2023-10-25 14:47:58 -07:00
parent 92b69321f4
commit 0ad09c4cbc
2 changed files with 4 additions and 3 deletions

View File

@@ -140,12 +140,13 @@ class CvT(nn.Module):
s3_heads = 6,
s3_depth = 10,
s3_mlp_mult = 4,
dropout = 0.
dropout = 0.,
channels = 3
):
super().__init__()
kwargs = dict(locals())
dim = 3
dim = channels
layers = []
for prefix in ('s1', 's2', 's3'):