EViT
Paper: Liang2022
Code: evit
Scoring Method: Attention between spatial tokens and the [CLS] token.
Overview
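The core token-selection step, taken from the EViT codebase (shape comments inline):

# Assumed context: x is the [B, N, C] token sequence with [CLS] at index 0,
# attn is the [B, H, N, N] attention map, and left_tokens is the number of
# spatial tokens to keep.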
cls_attn = attn[:, :, 0, 1:] # [B, H, N-1]
cls_attn = cls_attn.mean(dim=1) # [B, N-1]
_, idx = torch.topk(cls_attn, left_tokens, dim=1, largest=True, sorted=True) # [B, left_tokens]
index = idx.unsqueeze(-1).expand(-1, -1, C) # [B, left_tokens, C]
non_cls = x[:, 1:]
x_others = torch.gather(non_cls, dim=1, index=index) # [B, left_tokens, C]
compl = complement_idx(idx, N - 1) # [B, N-1-left_tokens]
non_topk = torch.gather(non_cls, dim=1, index=compl.unsqueeze(-1).expand(-1, -1, C)) # [B, N-1-left_tokens, C]
non_topk_attn = torch.gather(cls_attn, dim=1, index=compl) # [B, N-1-left_tokens]
extra_token = torch.sum(non_topk * non_topk_attn.unsqueeze(-1), dim=1, keepdim=True) # [B, 1, C]
x = torch.cat([x[:, 0:1], x_others, extra_token], dim=1)
torch.topk: Returns the k largest elements of the given input tensor along a given dimension. The function returns a namedtuple (values, indices), where values holds the k largest elements and indices holds their positions in the original input tensor.

torch.gather: Gathers values along an axis specified by dim. For a 3-D tensor with dim = 1, each output value equals input[i][index[i][j][k]][k], where i, j, k index the output tensor.
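A toy run of this topk-then-gather indexing pattern (all sizes and tensors here are made up for illustration):

import torch

B, N, C, left_tokens = 2, 5, 4, 3
cls_attn = torch.rand(B, N)        # stand-in for the head-averaged [CLS] attention
non_cls = torch.rand(B, N, C)      # stand-in for the spatial token embeddings

_, idx = torch.topk(cls_attn, left_tokens, dim=1, largest=True, sorted=True)  # [B, 3]
index = idx.unsqueeze(-1).expand(-1, -1, C)                                   # [B, 3, C]
x_others = torch.gather(non_cls, dim=1, index=index)                          # [B, 3, C]
assert x_others.shape == (B, left_tokens, C)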
In addition to the plain TopK method, EViT appends one extra token: a weighted average of the non-topk tokens, where each weight is that token's attention with the [CLS] token.
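The snippet also relies on complement_idx from the EViT codebase, which returns, per row, the indices not selected by topk. A minimal sketch of the behavior it needs (the actual implementation in the repo may differ):

import torch

def complement_idx(idx: torch.Tensor, dim: int) -> torch.Tensor:
    # idx: [B, k] topk indices; returns [B, dim - k] holding the remaining indices.
    B, k = idx.shape
    all_idx = torch.arange(dim, device=idx.device).expand(B, dim)  # [B, dim]
    keep = torch.ones(B, dim, dtype=torch.bool, device=idx.device)
    keep.scatter_(1, idx, False)                # mask out the selected indices
    return all_idx[keep].reshape(B, dim - k)    # complement, in ascending order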
Formalization for Linearization
The code above can be simplified into the following, using a sort_by_key helper similar to the one in TopK.
# N is the number of non-class tokens, C is the feature dimension, and k is the number of tokens to keep.
def topk_pruning(X: Tensor[N+1, C], attn: Tensor[N+1, N+1], k: int) -> Tensor[k+2, C]:
X_non_cls = X[1:] # [N, C]
scores = attn[0, 1:] # [N]
X_new, scores_new = sort_by_key(X_non_cls, scores, descending=True) # [N, C], [N]
extra_token = sum(scores_new[i] * X_new[i] for i in range(k, N)).unsqueeze(0) # [1, C]
return torch.cat([X[0:1], X_new[:k], extra_token], dim=0)
def sort_by_key(X: Tensor[N, C], key: Tensor[N], descending: bool = False) -> (Tensor[N, C], Tensor[N]):
sorted_indices = torch.argsort(key, dim=0, descending=descending) # [N]
return X[sorted_indices], key[sorted_indices] # [N, C], [N]
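A quick shape check of this simplified version (assuming the functions above are transcribed with plain Tensor annotations; the shaped Tensor[...] annotations are informal):

import torch

N, C, k = 6, 4, 3
X = torch.rand(N + 1, C)
attn = torch.softmax(torch.rand(N + 1, N + 1), dim=-1)  # stand-in attention map
out = topk_pruning(X, attn, k)
assert out.shape == (k + 2, C)  # [CLS] + k kept tokens + 1 fused token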
For linearization, I rewrite the code above as follows, using the Heaviside step function.
def heaviside(x: Tensor) -> Tensor:
    return (x > 0).float()  # 1 where x > 0, else 0 (strict step: 0 at x == 0)
def topk_pruning(X: Tensor[N+1, C], attn: Tensor[N+1, N+1], k: int) -> Tensor[k+2, C]:
    X_non_cls = X[1:]    # [N, C]
    scores = attn[0, 1:] # [N]
    # num_greater[i] is the 0-indexed rank of token i (how many tokens score strictly higher).
    for i in range(N):
        num_greater[i] = sum<j>(heaviside(scores[j] - scores[i]))
    # Token i is non-topk iff num_greater[i] >= k, i.e. heaviside(num_greater[i] - (k - 1)) = 1
    # (assuming distinct scores); weight it by its score, as in the sorted version above.
    extra_token = sum<i>(heaviside(num_greater[i] - (k - 1)) * scores[i] * X_non_cls[i]) # [C]
    X_new = sort_by_key(X_non_cls, scores, descending=True) # [N, C]
    return torch.cat([X[0:1], X_new[:k], extra_token.unsqueeze(0)], dim=0)
def sort_by_key(X: Tensor[N, C], key: Tensor[N], descending: bool = False) -> Tensor[N, C]:
    if descending:
        return torch.stack([kth_largest(X, key, i) for i in range(N)])
    else:
        # kth_smallest is analogous to kth_largest, with key[j] - key[i] reversed.
        return torch.stack([kth_smallest(X, key, i) for i in range(N)])
def kth_largest(X: Tensor[N, C], key: Tensor[N], k: int) -> Tensor[C]:
    for i in range(N):
        num_greater[i] = sum<j>(heaviside(key[j] - key[i]))  # 0-indexed rank of key[i]
    # delta1(x) = 1 iff x == 0 (for integer x, delta1(x) = heaviside(x + 1) - heaviside(x)),
    # so only the row whose rank equals k survives (assuming distinct keys).
    return sum<i>(delta1(num_greater[i] - k) * X[i])
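To sanity-check the heaviside formulation, here is a directly runnable rendering (heaviside_t and kth_largest_check are hypothetical names, and delta1 is realized as an exact zero test on the integer-valued ranks); it agrees with torch.topk whenever the keys are distinct:

import torch

def heaviside_t(x):                       # tensor version of heaviside above
    return (x > 0).float()

def kth_largest_check(X, key, k):         # runnable rendering of kth_largest
    N = key.shape[0]
    num_greater = torch.stack([heaviside_t(key - key[i]).sum() for i in range(N)])  # ranks
    delta = (num_greater == k).float()    # delta1(num_greater - k)
    return (delta.unsqueeze(-1) * X).sum(dim=0)

torch.manual_seed(0)
N, C = 6, 3
key = torch.randperm(N).float()           # distinct keys
X = torch.rand(N, C)
_, idx = torch.topk(key, N)               # indices in descending key order
for i in range(N):
    assert torch.allclose(kth_largest_check(X, key, i), X[idx[i]])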