본문 바로가기

애니리뷰

n-gram swinT

 

swin-T의 attn_mask는...

[DNN] Swin Transformer 리뷰 및 구현 (ICCV 2021) (tistory.com)

 

[DNN] Swin Transformer 리뷰 및 구현 (ICCV 2021)

안녕하세요 pulluper 입니다. 이번 포스팅에서는 ICCV2021 발표 후 많은 비전 모델의 백본으로 사용되고 있는 swin transformer 논문에 대하여 알아보겠습니다. https://arxiv.org/abs/2103.14030 Swin Transformer: Hierar

csm-kr.tistory.com

 

 

example = torch.tensor([[1,3,5],[1,4,7], [1,5,9]])

일 때, 

attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)

은 

tensor([[[ 0,  2,  4],
         [-2,  0,  2],
         [-4, -2,  0]],

        [[ 0,  3,  6],
         [-3,  0,  3],
         [-6, -3,  0]],

        [[ 0,  4,  8],
         [-4,  0,  4],
         [-8, -4,  0]]])

mask_windows가 [3,3]일 경우, 그 output은 [3,3,3]으로 나타남.

[i,j]는 mask_windows의 [i,j] 값에 대해, 이 [i,j]값과 같은 윈도우에 속해 있는 [i]의 element 전부에 대한 값의 차이를 나타냄.

 

예를들어,

example[0]의 1, 3, 5는 첫 element인 1보다 각각 0, 2, 4 크므로, attn_mask[0,0]은 [0,2,4]

두 번째 element인 3에 비해서 1,3,5는 -2, 0, 2만큼 크므로 attn_mask[0,1]은 [-2,0,2]

세 번째 element인 5에 비해서 1,3,5는 -4, -2, 0만큼 크므로 attn_mask[0,2]는 [-4,-2,0] 

 

또는 transpose시켜 보면,

example[0][0]인 1은, 자신과 같은 행과는 0, -2, -4 차이가 나며, 이것이 attn_mask[0,:,0]으로 나타남

마찬가지로 [0][1]인 3은 자신과 같은 행과는 -2, 0, 2 차이가 나며, 이것이 attn_mask[0,:,1]로 나타남.

 

이를 SRFormer의 permuted_windows를 포함할 시,

a = torch.Tensor([[0,1,2,3], [0,1,3,5], [0,1,4,6], [0,1,5,7]])
b = torch.Tensor([[0,1], [0,2], [0,3], [0,4]])
c = a.unsqueeze(2) - b.unsqueeze(1)

의 답은

tensor([[[ 0., -1.],
         [ 1.,  0.],
         [ 2.,  1.],
         [ 3.,  2.]],

        [[ 0., -2.],
         [ 1., -1.],
         [ 3.,  1.],
         [ 5.,  3.]],

        [[ 0., -3.],
         [ 1., -2.],
         [ 4.,  1.],
         [ 6.,  3.]],

        [[ 0., -4.],
         [ 1., -3.],
         [ 5.,  1.],
         [ 7.,  3.]]])

로, c[0]인 [0, -1]부터 [3, 2]는 각각 [0,1,2,3] - [0], [0,1,2,3] - [1]임.

c[3]인 [0, -4]부터 [7,3]까지는 각각 [0,1,5,7] - [0], [0,1,5,7] - [4]임.