remove duplicated qkv computation in na_vit_nested_tensor_3d.py (#341)

This commit is contained in:
JacobLinCool
2025-01-19 21:52:46 +08:00
committed by GitHub
parent c3018d1433
commit ab63fc9cc8

View File

@@ -83,17 +83,6 @@ class Attention(Module):
# split heads
def split_heads(t):
return t.unflatten(-1, (self.heads, self.dim_head)).transpose(1, 2).contiguous()
# queries, keys, values
query = self.to_queries(x)
key = self.to_keys(context)
value = self.to_values(context)
# split heads
def split_heads(t):
return t.unflatten(-1, (self.heads, self.dim_head))