[ViT-PE] RPE

import torch
import torch.nn as nn

class RelativePosition(nn.Module):
    # num_units = C // n_heads = 6, max_relative_position = 2
    def __init__(self, num_units, max_relative_position):
        super().__init__()
        self.num_units = num_units
        self.max_relative_position = max_relative_position

        # embeddings_table shape: [2*max_relative_position + 1, num_units] = [5, 6]
        self.embeddings_table = nn.Parameter(torch.Tensor(max_relative_position * 2 + 1, num_units))
        nn.init.xavier_uniform_(self.embeddings_table)

    # length_q, length_k = number of patches = 8
    def forward(self, length_q, length_k):
        # arange over the number of patches
        range_vec_q = torch.arange(length_q)
        range_vec_k = torch.arange(length_k)

        # range_vec_k[None, :] = [[0, 1, ... N-1]], range_vec_q[:, None] = [[0], [1], ... [N-1]]
        # distance_mat = [[0, 1, ... N-1], [-1, 0, ... N-2], ... [-(N-1), -(N-2), ... 0]]
        distance_mat = range_vec_k[None, :] - range_vec_q[:, None]
        distance_mat_clipped = torch.clamp(distance_mat, -self.max_relative_position, self.max_relative_position)
        final_mat = distance_mat_clipped + self.max_relative_position
        final_mat = final_mat.long().to(self.embeddings_table.device)
        embeddings = self.embeddings_table[final_mat]
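        # embeddings: [length_q, length_k, num_units] = [8, 8, 6] in the example below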

        return embeddings

class MultiHeadAttentionLayer(nn.Module):
    def __init__(self, hid_dim, n_heads, dropout, device):
        super().__init__()
        
        assert hid_dim % n_heads == 0
        
        self.hid_dim = hid_dim
        self.n_heads = n_heads
        self.head_dim = hid_dim // n_heads
        self.max_relative_position = 2

        self.relative_position_k = RelativePosition(self.head_dim, self.max_relative_position)
        self.relative_position_v = RelativePosition(self.head_dim, self.max_relative_position)

        self.fc_q = nn.Linear(hid_dim, hid_dim)
        self.fc_k = nn.Linear(hid_dim, hid_dim)
        self.fc_v = nn.Linear(hid_dim, hid_dim)
        
        self.fc_o = nn.Linear(hid_dim, hid_dim)
        
        self.dropout = nn.Dropout(dropout)
        
        self.scale = torch.sqrt(torch.FloatTensor([self.head_dim])).to(device)
        
    def forward(self, query, key, value, mask = None):
        #query = [batch size, query len, hid dim]
        #key = [batch size, key len, hid dim]
        #value = [batch size, value len, hid dim]
        batch_size = query.shape[0]
        len_k = key.shape[1]
        len_q = query.shape[1]
        len_v = value.shape[1]

        query = self.fc_q(query)
        key = self.fc_k(key)
        value = self.fc_v(value)

        # [B, N, C] -> [B, N, head, C//head] -> [B, head, N, C//head]
        r_q1 = query.view(batch_size, -1, self.n_heads, self.head_dim).permute(0, 2, 1, 3)  # B, head, N, C//head
        r_k1 = key.view(batch_size, -1, self.n_heads, self.head_dim).permute(0, 2, 1, 3)
        attn1 = torch.matmul(r_q1, r_k1.permute(0, 1, 3, 2))  # B, head, N, N

        # len_q : number of patches
        # B,N,C = [2,8,12] -> [8,2,12] -> [8, 2*2, 12//2] = [N, B*head, C//head]
        r_q2 = query.permute(1, 0, 2).contiguous().view(len_q, batch_size*self.n_heads, self.head_dim)
        r_k2 = self.relative_position_k(len_q, len_k)
        attn2 = torch.matmul(r_q2, r_k2.transpose(1, 2)).transpose(0, 1)
        attn2 = attn2.contiguous().view(batch_size, self.n_heads, len_q, len_k)
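        # (equivalently, attn2 could be written in one step as
        #  torch.einsum('bhqd,qkd->bhqk', r_q1, r_k2))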
        attn = (attn1 + attn2) / self.scale

        if mask is not None:
            attn = attn.masked_fill(mask == 0, -1e10)

        attn = self.dropout(torch.softmax(attn, dim = -1))

        #attn = [batch size, n heads, query len, key len]
        r_v1 = value.view(batch_size, -1, self.n_heads, self.head_dim).permute(0, 2, 1, 3)
        weight1 = torch.matmul(attn, r_v1)

        r_v2 = self.relative_position_v(len_q, len_v)

        weight2 = attn.permute(2, 0, 1, 3).contiguous().view(len_q, batch_size*self.n_heads, len_k)
        weight2 = torch.matmul(weight2, r_v2)
        weight2 = weight2.transpose(0, 1).contiguous().view(batch_size, self.n_heads, len_q, self.head_dim)
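        # (equivalently, weight2 could be written in one step as
        #  torch.einsum('bhqk,qkd->bhqd', attn, r_v2))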

        x = weight1 + weight2
        
        #x = [batch size, n heads, query len, head dim]
        
        x = x.permute(0, 2, 1, 3).contiguous()
        
        #x = [batch size, query len, n heads, head dim]
        
        x = x.view(batch_size, -1, self.hid_dim)
        
        #x = [batch size, query len, hid dim]
        
        x = self.fc_o(x)
        
        #x = [batch size, query len, hid dim]

        return x


if __name__ == "__main__":
    # Sample example
    x = torch.randn(2, 8, 12).cuda()  # B, N, C
    model = MultiHeadAttentionLayer(hid_dim=12, n_heads=2, dropout=0.1, device='cuda')
    model.cuda()  
    output = model(x, x, x)  
    print(output.shape)  # expected: torch.Size([2, 8, 12]) = [B, N, C]


'''
0) self.embeddings_table
Parameter containing:
tensor([[-0.1178, -0.1312, -0.2145, -0.1451,  0.3676, -0.7189],
        [ 0.3236,  0.1328, -0.3711,  0.7122, -0.0044,  0.2933],
        [ 0.4080,  0.0636,  0.3981,  0.4614,  0.4127, -0.5199],
        [ 0.1467, -0.5804, -0.4419, -0.2332, -0.3515, -0.4658],
        [-0.5506,  0.5444, -0.0597, -0.6033,  0.2171,  0.0147]],
       device='cuda:0', requires_grad=True)

1) distance_mat : shape [N, N] = [8, 8]
tensor([[ 0,  1,  2,  3,  4,  5,  6,  7],
        [-1,  0,  1,  2,  3,  4,  5,  6],
        [-2, -1,  0,  1,  2,  3,  4,  5],
        [-3, -2, -1,  0,  1,  2,  3,  4],
        [-4, -3, -2, -1,  0,  1,  2,  3],
        [-5, -4, -3, -2, -1,  0,  1,  2],
        [-6, -5, -4, -3, -2, -1,  0,  1],
        [-7, -6, -5, -4, -3, -2, -1,  0]])

2) after torch.clamp(distance_mat, -2, 2)
Values inside the range -self.max_relative_position(=-2) to self.max_relative_position(=2) are kept; values outside it are clipped to -2 or 2.
tensor([[ 0,  1,  2,  2,  2,  2,  2,  2],
        [-1,  0,  1,  2,  2,  2,  2,  2],
        [-2, -1,  0,  1,  2,  2,  2,  2],
        [-2, -2, -1,  0,  1,  2,  2,  2],
        [-2, -2, -2, -1,  0,  1,  2,  2],
        [-2, -2, -2, -2, -1,  0,  1,  2],
        [-2, -2, -2, -2, -2, -1,  0,  1],
        [-2, -2, -2, -2, -2, -2, -1,  0]])

3) final_mat = distance_mat_clipped + self.max_relative_position
self.max_relative_position(=2) is added to every element.
tensor([[2, 3, 4, 4, 4, 4, 4, 4],
        [1, 2, 3, 4, 4, 4, 4, 4],
        [0, 1, 2, 3, 4, 4, 4, 4],
        [0, 0, 1, 2, 3, 4, 4, 4],
        [0, 0, 0, 1, 2, 3, 4, 4],
        [0, 0, 0, 0, 1, 2, 3, 4],
        [0, 0, 0, 0, 0, 1, 2, 3],
        [0, 0, 0, 0, 0, 0, 1, 2]])

4) embeddings = self.embeddings_table[final_mat]
Each element of final_mat is used as an index into self.embeddings_table to gather the corresponding row.
For example, the row of final_mat at index 2 is [0, 1, 2, 3, 4, 4, 4, 4], so the gathered rows are
[[-0.1178, -0.1312, -0.2145, -0.1451,  0.3676, -0.7189],
 [ 0.3236,  0.1328, -0.3711,  0.7122, -0.0044,  0.2933],
 [ 0.4080,  0.0636,  0.3981,  0.4614,  0.4127, -0.5199],
 [ 0.1467, -0.5804, -0.4419, -0.2332, -0.3515, -0.4658],
 [-0.5506,  0.5444, -0.0597, -0.6033,  0.2171,  0.0147],
 [-0.5506,  0.5444, -0.0597, -0.6033,  0.2171,  0.0147],
 [-0.5506,  0.5444, -0.0597, -0.6033,  0.2171,  0.0147],
 [-0.5506,  0.5444, -0.0597, -0.6033,  0.2171,  0.0147]],
Since final_mat.shape = [8, 8] and each index pulls one row of length 6, the final embeddings have shape [8, 8, 6].
'''
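
As a quick sanity check, the same lookup can be reproduced step by step with a small standalone sketch (example settings assumed: 8 patches, num_units = 6; it reuses the RelativePosition class above):

rp = RelativePosition(num_units=6, max_relative_position=2)

length = 8
distance_mat = torch.arange(length)[None, :] - torch.arange(length)[:, None]  # [8, 8], values -7 .. 7
clipped = torch.clamp(distance_mat, -2, 2)                                    # clip to -2 .. 2
final_mat = clipped + 2                                                       # shift to 0 .. 4 (valid table indices)
emb = rp.embeddings_table[final_mat]                                          # gather rows: [8, 8, 6]
print(final_mat.shape, emb.shape)  # torch.Size([8, 8]) torch.Size([8, 8, 6])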


RPE in SwinT

Relative Position Bias (+ PyTorch Implementation) (youtube.com)
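
Swin Transformer applies the same idea as an additive bias instead of a key/value embedding: a learnable table with (2M-1)×(2M-1) entries per head is indexed by a precomputed 2D relative-position index and added to the attention logits before the softmax. Below is a simplified standalone sketch of that mechanism (not the official Swin code; window_size=7 and num_heads=3 are example values):

import torch
import torch.nn as nn

class SwinRelativePositionBias(nn.Module):
    # Simplified sketch of Swin-style relative position bias for an M x M window.
    def __init__(self, window_size, num_heads):
        super().__init__()
        self.window_size = window_size  # M
        # one learnable bias per relative offset per head: [(2M-1)*(2M-1), heads]
        self.bias_table = nn.Parameter(torch.zeros((2 * window_size - 1) ** 2, num_heads))
        nn.init.trunc_normal_(self.bias_table, std=0.02)

        # pairwise relative coordinates between the M*M positions in the window
        coords = torch.stack(torch.meshgrid(torch.arange(window_size),
                                            torch.arange(window_size), indexing='ij'))  # [2, M, M]
        coords_flat = coords.flatten(1)                               # [2, M*M]
        rel = coords_flat[:, :, None] - coords_flat[:, None, :]       # [2, M*M, M*M]
        rel = rel.permute(1, 2, 0).contiguous()                       # [M*M, M*M, 2]
        rel[:, :, 0] += window_size - 1                               # shift row offsets to start from 0
        rel[:, :, 1] += window_size - 1                               # shift col offsets to start from 0
        rel[:, :, 0] *= 2 * window_size - 1                           # flatten (dy, dx) into a single index
        self.register_buffer("relative_position_index", rel.sum(-1))  # [M*M, M*M]

    def forward(self):
        # gather one bias per (query, key) pair and reshape to [heads, M*M, M*M]
        bias = self.bias_table[self.relative_position_index.view(-1)]
        bias = bias.view(self.window_size ** 2, self.window_size ** 2, -1)
        return bias.permute(2, 0, 1).contiguous()

rpb = SwinRelativePositionBias(window_size=7, num_heads=3)
print(rpb().shape)  # torch.Size([3, 49, 49])

The returned [num_heads, M*M, M*M] tensor is broadcast-added to the [B, num_heads, M*M, M*M] attention logits inside each window before the softmax.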
