import torch
import torch.nn as nn

class RelativePosition(nn.Module):
    # num_units = C // n_heads = 6, max_relative_position = 2
    def __init__(self, num_units, max_relative_position):
        super().__init__()
        self.num_units = num_units
        self.max_relative_position = max_relative_position
        # shape: [5, 6] = nn.Parameter(torch.Tensor(2*2 + 1, C // n_heads))
        self.embeddings_table = nn.Parameter(torch.Tensor(max_relative_position * 2 + 1, num_units))
        nn.init.xavier_uniform_(self.embeddings_table)

    # length_q, length_k = number of patches = 8
    def forward(self, length_q, length_k):
        # arange over the number of patches
        range_vec_q = torch.arange(length_q)
        range_vec_k = torch.arange(length_k)
        # range_vec_k[None, :] = [[0, 1, ... N-1]], range_vec_q[:, None] = [[0], [1], ... [N-1]]
        # distance_mat = [[0, 1, ... N-1], [-1, 0, ... N-2], ... [-(N-1), -(N-2), ... 0]]
        distance_mat = range_vec_k[None, :] - range_vec_q[:, None]
        distance_mat_clipped = torch.clamp(distance_mat, -self.max_relative_position, self.max_relative_position)
        final_mat = distance_mat_clipped + self.max_relative_position
        final_mat = final_mat.long().cuda()
        embeddings = self.embeddings_table[final_mat].cuda()
        return embeddings
class MultiHeadAttentionLayer(nn.Module):
    def __init__(self, hid_dim, n_heads, dropout, device):
        super().__init__()
        assert hid_dim % n_heads == 0
        self.hid_dim = hid_dim
        self.n_heads = n_heads
        self.head_dim = hid_dim // n_heads
        self.max_relative_position = 2
        self.relative_position_k = RelativePosition(self.head_dim, self.max_relative_position)
        self.relative_position_v = RelativePosition(self.head_dim, self.max_relative_position)
        self.fc_q = nn.Linear(hid_dim, hid_dim)
        self.fc_k = nn.Linear(hid_dim, hid_dim)
        self.fc_v = nn.Linear(hid_dim, hid_dim)
        self.fc_o = nn.Linear(hid_dim, hid_dim)
        self.dropout = nn.Dropout(dropout)
        self.scale = torch.sqrt(torch.FloatTensor([self.head_dim])).to(device)

    def forward(self, query, key, value, mask=None):
        # query = [batch size, query len, hid dim]
        # key   = [batch size, key len, hid dim]
        # value = [batch size, value len, hid dim]
        batch_size = query.shape[0]
        len_k = key.shape[1]
        len_q = query.shape[1]
        len_v = value.shape[1]

        query = self.fc_q(query)
        key = self.fc_k(key)
        value = self.fc_v(value)

        # [B, N, C] -> [B, N, head, C//head] -> [B, head, N, C//head]
        r_q1 = query.view(batch_size, -1, self.n_heads, self.head_dim).permute(0, 2, 1, 3)  # B, head, N, C//head
        r_k1 = key.view(batch_size, -1, self.n_heads, self.head_dim).permute(0, 2, 1, 3)
        attn1 = torch.matmul(r_q1, r_k1.permute(0, 1, 3, 2))  # B, head, N, N

        # len_q: number of patches
        # B, N, C = [2, 8, 12] -> [8, 2, 12] -> [8, 2*2, 12//2] = [N, B*head, C//head]
        r_q2 = query.permute(1, 0, 2).contiguous().view(len_q, batch_size * self.n_heads, self.head_dim)
        r_k2 = self.relative_position_k(len_q, len_k)
        attn2 = torch.matmul(r_q2, r_k2.transpose(1, 2)).transpose(0, 1)
        attn2 = attn2.contiguous().view(batch_size, self.n_heads, len_q, len_k)
        attn = (attn1 + attn2) / self.scale

        if mask is not None:
            attn = attn.masked_fill(mask == 0, -1e10)
        attn = self.dropout(torch.softmax(attn, dim=-1))
        # attn = [batch size, n heads, query len, key len]

        r_v1 = value.view(batch_size, -1, self.n_heads, self.head_dim).permute(0, 2, 1, 3)
        weight1 = torch.matmul(attn, r_v1)
        r_v2 = self.relative_position_v(len_q, len_v)
        weight2 = attn.permute(2, 0, 1, 3).contiguous().view(len_q, batch_size * self.n_heads, len_k)
        weight2 = torch.matmul(weight2, r_v2)
        weight2 = weight2.transpose(0, 1).contiguous().view(batch_size, self.n_heads, len_q, self.head_dim)

        x = weight1 + weight2
        # x = [batch size, n heads, query len, head dim]
        x = x.permute(0, 2, 1, 3).contiguous()
        # x = [batch size, query len, n heads, head dim]
        x = x.view(batch_size, -1, self.hid_dim)
        # x = [batch size, query len, hid dim]
        x = self.fc_o(x)
        # x = [batch size, query len, hid dim]
        return x
if __name__ == "__main__":
    # Sample example (requires a CUDA device, since RelativePosition moves its indices with .cuda())
    x = torch.randn(2, 8, 12).cuda()  # B, N, C
    model = MultiHeadAttentionLayer(hid_dim=12, n_heads=2, dropout=0.1, device='cuda')
    model.cuda()
    output = model(x, x, x)
    print(output.shape)  # torch.Size([2, 8, 12])
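For reference, the reshape-and-matmul route used for attn2 and weight2 in forward() can each be written as a single einsum. Below is a minimal sketch of that equivalence only; the random tensors are stand-ins whose shapes mirror the sample run (B=2, heads=2, N=8, head_dim=6), not the layer's actual activations.

import torch

# Stand-in shapes matching the sample run.
B, n_heads, N, head_dim = 2, 2, 8, 6
r_q1 = torch.randn(B, n_heads, N, head_dim)   # per-head queries
r_k2 = torch.randn(N, N, head_dim)            # relative key embeddings [len_q, len_k, head_dim]
attn = torch.randn(B, n_heads, N, N)          # attention weights
r_v2 = torch.randn(N, N, head_dim)            # relative value embeddings

# Route used in forward(): fold (B, head) into one axis, matmul per query position.
r_q2 = r_q1.permute(2, 0, 1, 3).contiguous().view(N, B * n_heads, head_dim)
attn2_ref = torch.matmul(r_q2, r_k2.transpose(1, 2)).transpose(0, 1).contiguous().view(B, n_heads, N, N)
w2_ref = attn.permute(2, 0, 1, 3).contiguous().view(N, B * n_heads, N)
w2_ref = torch.matmul(w2_ref, r_v2).transpose(0, 1).contiguous().view(B, n_heads, N, head_dim)

# Same computation written as einsums: the relative term is indexed by the query position q.
attn2_ein = torch.einsum('bhqd,qkd->bhqk', r_q1, r_k2)
w2_ein = torch.einsum('bhqk,qkd->bhqd', attn, r_v2)
print(torch.allclose(attn2_ref, attn2_ein), torch.allclose(w2_ref, w2_ein))  # True True

The quoted block below traces the intermediate tensors inside RelativePosition.forward for a run like the one above.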
'''
0) self.embeddings_table
Parameter containing:
tensor([[-0.1178, -0.1312, -0.2145, -0.1451, 0.3676, -0.7189],
[ 0.3236, 0.1328, -0.3711, 0.7122, -0.0044, 0.2933],
[ 0.4080, 0.0636, 0.3981, 0.4614, 0.4127, -0.5199],
[ 0.1467, -0.5804, -0.4419, -0.2332, -0.3515, -0.4658],
[-0.5506, 0.5444, -0.0597, -0.6033, 0.2171, 0.0147]],
device='cuda:0', requires_grad=True)
1) distance_mat, shape [N, N]
tensor([[ 0, 1, 2, 3, 4, 5, 6, 7],
[-1, 0, 1, 2, 3, 4, 5, 6],
[-2, -1, 0, 1, 2, 3, 4, 5],
[-3, -2, -1, 0, 1, 2, 3, 4],
[-4, -3, -2, -1, 0, 1, 2, 3],
[-5, -4, -3, -2, -1, 0, 1, 2],
[-6, -5, -4, -3, -2, -1, 0, 1],
[-7, -6, -5, -4, -3, -2, -1, 0]])
2) after torch.clamp(distance_mat, -2, 2)
Only values between -self.max_relative_position (=-2) and self.max_relative_position (=2) are kept; everything outside this range is replaced with -2 or 2.
tensor([[ 0, 1, 2, 2, 2, 2, 2, 2],
[-1, 0, 1, 2, 2, 2, 2, 2],
[-2, -1, 0, 1, 2, 2, 2, 2],
[-2, -2, -1, 0, 1, 2, 2, 2],
[-2, -2, -2, -1, 0, 1, 2, 2],
[-2, -2, -2, -2, -1, 0, 1, 2],
[-2, -2, -2, -2, -2, -1, 0, 1],
[-2, -2, -2, -2, -2, -2, -1, 0]])
3) final_mat = distance_mat_clipped + self.max_relative_position
self.max_relative_position (=2) is added to every element, so all entries become valid (non-negative) table indices.
tensor([[2, 3, 4, 4, 4, 4, 4, 4],
[1, 2, 3, 4, 4, 4, 4, 4],
[0, 1, 2, 3, 4, 4, 4, 4],
[0, 0, 1, 2, 3, 4, 4, 4],
[0, 0, 0, 1, 2, 3, 4, 4],
[0, 0, 0, 0, 1, 2, 3, 4],
[0, 0, 0, 0, 0, 1, 2, 3],
[0, 0, 0, 0, 0, 0, 1, 2]])
4) embeddings = self.embeddings_table[final_mat].cuda()
Each element of final_mat is used as a row index into self.embeddings_table to gather the corresponding embedding.
For example, the row of final_mat at index 2 is [0, 1, 2, 3, 4, 4, 4, 4], so the gathered rows for that position are
[[-0.1178, -0.1312, -0.2145, -0.1451, 0.3676, -0.7189],
[ 0.3236, 0.1328, -0.3711, 0.7122, -0.0044, 0.2933],
[ 0.4080, 0.0636, 0.3981, 0.4614, 0.4127, -0.5199],
[ 0.1467, -0.5804, -0.4419, -0.2332, -0.3515, -0.4658],
[-0.5506, 0.5444, -0.0597, -0.6033, 0.2171, 0.0147],
[-0.5506, 0.5444, -0.0597, -0.6033, 0.2171, 0.0147],
[-0.5506, 0.5444, -0.0597, -0.6033, 0.2171, 0.0147],
[-0.5506, 0.5444, -0.0597, -0.6033, 0.2171, 0.0147]],
Since final_mat.shape = [8, 8] and each index gathers a length-6 row, the final embeddings tensor has shape [8, 8, 6].
'''
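The index-map construction traced above (steps 1 to 4) can be reproduced on CPU without the class. A minimal sketch, assuming the same settings as the sample run (N = 8 patches, max_relative_position = 2, num_units = 6); the random table is only a stand-in for embeddings_table:

import torch

N, max_rel, num_units = 8, 2, 6
table = torch.randn(2 * max_rel + 1, num_units)           # stand-in for embeddings_table, shape [5, 6]

range_vec = torch.arange(N)
distance_mat = range_vec[None, :] - range_vec[:, None]    # step 1): [N, N], values in [-(N-1), N-1]
clipped = torch.clamp(distance_mat, -max_rel, max_rel)    # step 2): clamp to [-2, 2]
final_mat = clipped + max_rel                             # step 3): shift to [0, 4], valid row indices
embeddings = table[final_mat]                             # step 4): gather rows -> [N, N, num_units]
print(embeddings.shape)                                   # torch.Size([8, 8, 6])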
References:
RPE in SwinT
Relative Position Bias (+ PyTorch Implementation) (youtube.com)