mirror of
https://github.com/lucidrains/vit-pytorch.git
synced 2025-12-30 08:02:29 +00:00
Handle channel-first embeddings in the accept video wrapper
This commit is contained in:
2
setup.py
2
setup.py
@@ -6,7 +6,7 @@ with open('README.md') as f:
|
|||||||
setup(
|
setup(
|
||||||
name = 'vit-pytorch',
|
name = 'vit-pytorch',
|
||||||
packages = find_packages(exclude=['examples']),
|
packages = find_packages(exclude=['examples']),
|
||||||
version = '1.11.4',
|
version = '1.11.5',
|
||||||
license='MIT',
|
license='MIT',
|
||||||
description = 'Vision Transformer (ViT) - Pytorch',
|
description = 'Vision Transformer (ViT) - Pytorch',
|
||||||
long_description = long_description,
|
long_description = long_description,
|
||||||
|
|||||||
@@ -25,6 +25,7 @@ class AcceptVideoWrapper(Module):
|
|||||||
add_time_pos_emb = False,
|
add_time_pos_emb = False,
|
||||||
dim_emb = None,
|
dim_emb = None,
|
||||||
time_seq_len = None,
|
time_seq_len = None,
|
||||||
|
embed_is_channel_first = False,
|
||||||
output_pos_add_pos_emb = 0 # defaults to first output position to add embedding
|
output_pos_add_pos_emb = 0 # defaults to first output position to add embedding
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
@@ -40,6 +41,8 @@ class AcceptVideoWrapper(Module):
|
|||||||
|
|
||||||
self.pos_emb = Parameter(randn(time_seq_len, dim_emb) * 1e-2)
|
self.pos_emb = Parameter(randn(time_seq_len, dim_emb) * 1e-2)
|
||||||
|
|
||||||
|
self.embed_is_channel_first = embed_is_channel_first
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
video, # (b c t h w)
|
video, # (b c t h w)
|
||||||
@@ -89,9 +92,16 @@ class AcceptVideoWrapper(Module):
|
|||||||
|
|
||||||
dims_to_unsqueeze = embed.ndim - pos_emb.ndim
|
dims_to_unsqueeze = embed.ndim - pos_emb.ndim
|
||||||
|
|
||||||
pos_emb = pos_emb.reshape(*pos_emb.shape[:2], *((1,) * dims_to_unsqueeze) , pos_emb.shape[-1])
|
one_dims = ((1,) * dims_to_unsqueeze)
|
||||||
|
|
||||||
embed = embed + pos_emb[:, :embed.shape[1]]
|
if self.embed_is_channel_first:
|
||||||
|
pos_emb = pos_emb.reshape(*pos_emb.shape, *one_dims)
|
||||||
|
else:
|
||||||
|
pos_emb = pos_emb.reshape(*pos_emb.shape[:2], *one_dims, pos_emb.shape[-1])
|
||||||
|
|
||||||
|
pos_emb = pos_emb[:, :embed.shape[1]]
|
||||||
|
|
||||||
|
embed = embed + pos_emb
|
||||||
|
|
||||||
outputs[self.output_pos_add_pos_emb] = embed
|
outputs[self.output_pos_add_pos_emb] = embed
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user