[Optical Flow] RAFT 이해를 위한 여러 함수 정리글
torch.meshgrid 정리 - raft.py의
coords0, coords1 = self.initialize_flow(image1)
를 실행할 경우 행해지는 함수들
width, height = 128, 64
coords = torch.meshgrid(torch.arange(width//8), torch.arange(height//8)) # 16, 8
print(coords)
print('------------------------------')
print(coords[::-1])
print('------------------------------')
coords = torch.stack(coords[::-1], dim=0).float()
# print(coords)
print('------------------------------')
# print(coords[None].repeat(1, 1, 1, 1))
위를 진행하면, 우선 torch.meshgrid는 다음과 같은 결과를 냄.
(tensor([[ 0, 0, 0, 0, 0, 0, 0, 0],
[ 1, 1, 1, 1, 1, 1, 1, 1],
[ 2, 2, 2, 2, 2, 2, 2, 2],
[ 3, 3, 3, 3, 3, 3, 3, 3],
[ 4, 4, 4, 4, 4, 4, 4, 4],
[ 5, 5, 5, 5, 5, 5, 5, 5],
[ 6, 6, 6, 6, 6, 6, 6, 6],
[ 7, 7, 7, 7, 7, 7, 7, 7],
[ 8, 8, 8, 8, 8, 8, 8, 8],
[ 9, 9, 9, 9, 9, 9, 9, 9],
[10, 10, 10, 10, 10, 10, 10, 10],
[11, 11, 11, 11, 11, 11, 11, 11],
[12, 12, 12, 12, 12, 12, 12, 12],
[13, 13, 13, 13, 13, 13, 13, 13],
[14, 14, 14, 14, 14, 14, 14, 14],
[15, 15, 15, 15, 15, 15, 15, 15]]), tensor([[0, 1, 2, 3, 4, 5, 6, 7],
[0, 1, 2, 3, 4, 5, 6, 7],
[0, 1, 2, 3, 4, 5, 6, 7],
[0, 1, 2, 3, 4, 5, 6, 7],
[0, 1, 2, 3, 4, 5, 6, 7],
[0, 1, 2, 3, 4, 5, 6, 7],
[0, 1, 2, 3, 4, 5, 6, 7],
[0, 1, 2, 3, 4, 5, 6, 7],
[0, 1, 2, 3, 4, 5, 6, 7],
[0, 1, 2, 3, 4, 5, 6, 7],
[0, 1, 2, 3, 4, 5, 6, 7],
[0, 1, 2, 3, 4, 5, 6, 7],
[0, 1, 2, 3, 4, 5, 6, 7],
[0, 1, 2, 3, 4, 5, 6, 7],
[0, 1, 2, 3, 4, 5, 6, 7],
[0, 1, 2, 3, 4, 5, 6, 7]]))
즉, (16, 8)짜리 행이 같은 숫자인 배열, 열이 같은 숫자인 배열 두 개를 만들어냄. shape는 (16, 8)
torch.stack([::-1], dim=0)은 아래의 결과를 도출.
tensor([[[ 0., 1., 2., 3., 4., 5., 6., 7.],
[ 0., 1., 2., 3., 4., 5., 6., 7.],
[ 0., 1., 2., 3., 4., 5., 6., 7.],
[ 0., 1., 2., 3., 4., 5., 6., 7.],
[ 0., 1., 2., 3., 4., 5., 6., 7.],
[ 0., 1., 2., 3., 4., 5., 6., 7.],
[ 0., 1., 2., 3., 4., 5., 6., 7.],
[ 0., 1., 2., 3., 4., 5., 6., 7.],
[ 0., 1., 2., 3., 4., 5., 6., 7.],
[ 0., 1., 2., 3., 4., 5., 6., 7.],
[ 0., 1., 2., 3., 4., 5., 6., 7.],
[ 0., 1., 2., 3., 4., 5., 6., 7.],
[ 0., 1., 2., 3., 4., 5., 6., 7.],
[ 0., 1., 2., 3., 4., 5., 6., 7.],
[ 0., 1., 2., 3., 4., 5., 6., 7.],
[ 0., 1., 2., 3., 4., 5., 6., 7.]],
[[ 0., 0., 0., 0., 0., 0., 0., 0.],
[ 1., 1., 1., 1., 1., 1., 1., 1.],
[ 2., 2., 2., 2., 2., 2., 2., 2.],
[ 3., 3., 3., 3., 3., 3., 3., 3.],
[ 4., 4., 4., 4., 4., 4., 4., 4.],
[ 5., 5., 5., 5., 5., 5., 5., 5.],
[ 6., 6., 6., 6., 6., 6., 6., 6.],
[ 7., 7., 7., 7., 7., 7., 7., 7.],
[ 8., 8., 8., 8., 8., 8., 8., 8.],
[ 9., 9., 9., 9., 9., 9., 9., 9.],
[10., 10., 10., 10., 10., 10., 10., 10.],
[11., 11., 11., 11., 11., 11., 11., 11.],
[12., 12., 12., 12., 12., 12., 12., 12.],
[13., 13., 13., 13., 13., 13., 13., 13.],
[14., 14., 14., 14., 14., 14., 14., 14.],
[15., 15., 15., 15., 15., 15., 15., 15.]]])
-1, 즉 역순(reverse)으로 해당 텐서를 stack함.
마지막으로 coords[None].repeat(batch, 1, 1, 1)을 실행하면
tensor([[[[ 0., 1., 2., 3., 4., 5., 6., 7.],
[ 0., 1., 2., 3., 4., 5., 6., 7.],
[ 0., 1., 2., 3., 4., 5., 6., 7.],
[ 0., 1., 2., 3., 4., 5., 6., 7.],
[ 0., 1., 2., 3., 4., 5., 6., 7.],
[ 0., 1., 2., 3., 4., 5., 6., 7.],
[ 0., 1., 2., 3., 4., 5., 6., 7.],
[ 0., 1., 2., 3., 4., 5., 6., 7.],
[ 0., 1., 2., 3., 4., 5., 6., 7.],
[ 0., 1., 2., 3., 4., 5., 6., 7.],
[ 0., 1., 2., 3., 4., 5., 6., 7.],
[ 0., 1., 2., 3., 4., 5., 6., 7.],
[ 0., 1., 2., 3., 4., 5., 6., 7.],
[ 0., 1., 2., 3., 4., 5., 6., 7.],
[ 0., 1., 2., 3., 4., 5., 6., 7.],
[ 0., 1., 2., 3., 4., 5., 6., 7.]],
[[ 0., 0., 0., 0., 0., 0., 0., 0.],
[ 1., 1., 1., 1., 1., 1., 1., 1.],
[ 2., 2., 2., 2., 2., 2., 2., 2.],
[ 3., 3., 3., 3., 3., 3., 3., 3.],
[ 4., 4., 4., 4., 4., 4., 4., 4.],
[ 5., 5., 5., 5., 5., 5., 5., 5.],
[ 6., 6., 6., 6., 6., 6., 6., 6.],
[ 7., 7., 7., 7., 7., 7., 7., 7.],
[ 8., 8., 8., 8., 8., 8., 8., 8.],
[ 9., 9., 9., 9., 9., 9., 9., 9.],
[10., 10., 10., 10., 10., 10., 10., 10.],
[11., 11., 11., 11., 11., 11., 11., 11.],
[12., 12., 12., 12., 12., 12., 12., 12.],
[13., 13., 13., 13., 13., 13., 13., 13.],
[14., 14., 14., 14., 14., 14., 14., 14.],
[15., 15., 15., 15., 15., 15., 15., 15.]]]])
그냥 배치만큼 스택하고 끝임. 이거로 flow map을 초기화하는듯.
[None] == np.newaxis
그럼 여기서 궁금증이 생긴다.
1) coords[None]의 의미 >> 실험결과, coords.repeat(...)와 똑같음.
이 때 주의할 점 :
flow_loss += i_weight * (valid[:, None] * i_loss).mean()
이 경우, 두 번째 차원에 추가됨. 즉,
n[:, None] == n.unsqueeze(dim=1)
2) repeat(a, b, c, d) >> 각 차원을 얼마나 반복하는지. 즉 (1,1,1,3)이면 [0 1 2 3 4 5 6 7 0 1 2 3 4 5 6 7 0 1 2 3 4 5 6 7]이 됨.
아무튼 이렇게 똑같은 두 개의 배열을 만들어냄(각각 coords0, coords1).
2개를 만드는 이유는, 각각 x방향, y방향이라.
그러면 다음으론, 아래가 실행되는데, 우선 coor_fn(coords1)을 살펴보겠다. 이건 아무래도 논문의 3.3.Upsampling으로 보임.
for itr in range(iters):
coords1 = coords1.detach()
corr = corr_fn(coords1) # index correlation volume, def __call__ in CorrBlock class
flow = coords1 - coords0
with autocast(enabled=self.args.mixed_precision):
net, up_mask, delta_flow = self.update_block(net, inp, corr, flow)
# F(t+1) = F(t) + \Delta(t)
coords1 = coords1 + delta_flow
# upsample predictions
if up_mask is None:
flow_up = upflow8(coords1 - coords0)
else:
flow_up = self.upsample_flow(coords1 - coords0, up_mask)
flow_predictions.append(flow_up)
if test_mode:
return coords1 - coords0, flow_up
return flow_predictions
coor_fn을 실행하면 아래와 같은 것을 수행한다.
def __call__(self, coords):
r = self.radius # 4
# batch, 2, w//8, h//8 --> batch, w//8, h//8, 2
coords = coords.permute(0, 2, 3, 1)
batch, h1, w1, _ = coords.shape
out_pyramid = []
for i in range(self.num_levels):
corr = self.corr_pyramid[i] # get i-th pyramid got from __init__, correlation map
# [-4 ~ 4], size = 9. so, [-4, -3, -2, .... 3, 4]
# maybe it would has concerned with RAFT 3.3.Upsampling(use 9 neighbors to make original size predict map)
dx = torch.linspace(-r, r, 2*r+1, device=coords.device)
dy = torch.linspace(-r, r, 2*r+1, device=coords.device)
delta = torch.stack(torch.meshgrid(dy, dx), axis=-1)
centroid_lvl = coords.reshape(batch*h1*w1, 1, 1, 2) / 2**i
delta_lvl = delta.view(1, 2*r+1, 2*r+1, 2)
coords_lvl = centroid_lvl + delta_lvl
corr = bilinear_sampler(corr, coords_lvl)
corr = corr.view(batch, h1, w1, -1)
out_pyramid.append(corr)
out = torch.cat(out_pyramid, dim=-1)
return out.permute(0, 3, 1, 2).contiguous().float()
delta는 다음과 같다. axis = -1이기에, 아래와 같은 형태가 됨. shape는 (9, 9, 2)
tensor([[[-4., -4.],
[-4., -3.],
[-4., -2.],
[-4., -1.],
[-4., 0.],
[-4., 1.],
[-4., 2.],
[-4., 3.],
[-4., 4.]],
[[-3., -4.],
[-3., -3.],
[-3., -2.],
[-3., -1.],
[-3., 0.],
[-3., 1.],
[-3., 2.],
[-3., 3.],
[-3., 4.]],
[[-2., -4.],
[-2., -3.],
[-2., -2.],
[-2., -1.],
[-2., 0.],
[-2., 1.],
[-2., 2.],
[-2., 3.],
[-2., 4.]],
[[-1., -4.],
[-1., -3.],
[-1., -2.],
[-1., -1.],
[-1., 0.],
[-1., 1.],
[-1., 2.],
[-1., 3.],
[-1., 4.]],
[[ 0., -4.],
[ 0., -3.],
[ 0., -2.],
[ 0., -1.],
[ 0., 0.],
[ 0., 1.],
[ 0., 2.],
[ 0., 3.],
[ 0., 4.]],
[[ 1., -4.],
[ 1., -3.],
[ 1., -2.],
[ 1., -1.],
[ 1., 0.],
[ 1., 1.],
[ 1., 2.],
[ 1., 3.],
[ 1., 4.]],
[[ 2., -4.],
[ 2., -3.],
[ 2., -2.],
[ 2., -1.],
[ 2., 0.],
[ 2., 1.],
[ 2., 2.],
[ 2., 3.],
[ 2., 4.]],
[[ 3., -4.],
[ 3., -3.],
[ 3., -2.],
[ 3., -1.],
[ 3., 0.],
[ 3., 1.],
[ 3., 2.],
[ 3., 3.],
[ 3., 4.]],
[[ 4., -4.],
[ 4., -3.],
[ 4., -2.],
[ 4., -1.],
[ 4., 0.],
[ 4., 1.],
[ 4., 2.],
[ 4., 3.],
[ 4., 4.]]])
centroid_lvl는 기존의 (batch, w//8, h//8, 2) 배열을 (batch * w//8 * h//8, 1, 1, 2)로 바꿔준다.
다음으로 2**i로 나누는데, 이는 각 피라미드에 적합하게 하기 위함으로 보인다.
corr = pyramid[i]이므로, 이 pyramid에 적합한 사이즈로 만들기 위해, 각 요소를 2**i로 나눠줌.
이를 바탕으로 coords_lvl을 구하면, (1920/8 * 1080/8 * batch, 9, 2)의 shape를 가진 tensor가 나옴.
정확히 어떤 값이 나오는지 알아보기 위해, 아래와 같이 코드를 짜봄.
import torch
import torch.nn.functional as F
# utils.py - coords_grid & raft.py - initialize_flow
H, W = 16, 24
coords = torch.meshgrid(torch.arange(H//8), torch.arange(W//8))
coords = torch.stack(coords[::-1], dim=0).float()
coords = coords[None].repeat(1, 1, 1, 1) # batch = 1이라 가정
print(coords)
print('---------------------------')
# corr_fn - __call__
coords = coords.permute(0, 2, 3, 1)
batch, h1, w1, _ = coords.shape
centroid_lvl = coords.reshape(1*h1*w1, 1, 1, 2)
print(centroid_lvl.shape)
print(centroid_lvl)
print('---------------------------')
r = 4
dx = torch.linspace(-r, r, 2*r+1)
dy = torch.linspace(-r, r, 2*r+1)
delta = torch.stack(torch.meshgrid(dy, dx), axis=-1)
delta_lvl = delta.view(1, 2*r+1, 2*r+1, 2)
print(delta_lvl.shape)
print(delta_lvl)
print('---------------------------')
coords_lvl = centroid_lvl + delta_lvl
print(coords_lvl)
print(coords_lvl.shape)
# bilinear_sampler
print('---------------------------')
xgrid, ygrid = coords_lvl.split([1, 1], dim=-1) # xgrid/ygrid : (bhw x 9 x 9 x 1) cause each (bhw x 9 x 9) is x/y coor
xgrid = 2 * xgrid / (W-1) - 1
ygrid = 2 * ygrid / (H-1) - 1
grid = torch.cat([xgrid, ygrid], dim = -1)
# print(xgrid)
# print(ygrid)
# img = F.grid_sample(img, grid, align_corners = True)
그 결과는 아래와 같다.
torch.Size([6, 1, 1, 2])
tensor([[[[0., 0.]]],
[[[1., 0.]]],
[[[2., 0.]]],
[[[0., 1.]]],
[[[1., 1.]]],
[[[2., 1.]]]])
----------------------------------------------------------
torch.Size([1, 9, 9, 2])
tensor([[[[-4., -4.],
[-4., -3.],
[-4., -2.],
[-4., -1.],
[-4., 0.],
[-4., 1.],
[-4., 2.],
[-4., 3.],
[-4., 4.]],
[[-3., -4.],
[-3., -3.],
[-3., -2.],
[-3., -1.],
[-3., 0.],
[-3., 1.],
[-3., 2.],
[-3., 3.],
[-3., 4.]],
[[-2., -4.],
[-2., -3.],
[-2., -2.],
[-2., -1.],
[-2., 0.],
[-2., 1.],
[-2., 2.],
[-2., 3.],
[-2., 4.]],
[[-1., -4.],
[-1., -3.],
[-1., -2.],
[-1., -1.],
[-1., 0.],
[-1., 1.],
[-1., 2.],
[-1., 3.],
[-1., 4.]],
[[ 0., -4.],
[ 0., -3.],
[ 0., -2.],
[ 0., -1.],
[ 0., 0.],
[ 0., 1.],
[ 0., 2.],
[ 0., 3.],
[ 0., 4.]],
[[ 1., -4.],
[ 1., -3.],
[ 1., -2.],
[ 1., -1.],
[ 1., 0.],
[ 1., 1.],
[ 1., 2.],
[ 1., 3.],
[ 1., 4.]],
[[ 2., -4.],
[ 2., -3.],
[ 2., -2.],
[ 2., -1.],
[ 2., 0.],
[ 2., 1.],
[ 2., 2.],
[ 2., 3.],
[ 2., 4.]],
[[ 3., -4.],
[ 3., -3.],
[ 3., -2.],
[ 3., -1.],
[ 3., 0.],
[ 3., 1.],
[ 3., 2.],
[ 3., 3.],
[ 3., 4.]],
[[ 4., -4.],
[ 4., -3.],
[ 4., -2.],
[ 4., -1.],
[ 4., 0.],
[ 4., 1.],
[ 4., 2.],
[ 4., 3.],
[ 4., 4.]]]])
----------------------------------------------------------
tensor([[[[-4., -4.],
[-4., -3.],
[-4., -2.],
[-4., -1.],
[-4., 0.],
[-4., 1.],
[-4., 2.],
[-4., 3.],
[-4., 4.]],
[[-3., -4.],
[-3., -3.],
[-3., -2.],
[-3., -1.],
[-3., 0.],
[-3., 1.],
[-3., 2.],
[-3., 3.],
[-3., 4.]],
[[-2., -4.],
[-2., -3.],
[-2., -2.],
[-2., -1.],
[-2., 0.],
[-2., 1.],
[-2., 2.],
[-2., 3.],
[-2., 4.]],
[[-1., -4.],
[-1., -3.],
[-1., -2.],
[-1., -1.],
[-1., 0.],
[-1., 1.],
[-1., 2.],
[-1., 3.],
[-1., 4.]],
[[ 0., -4.],
[ 0., -3.],
[ 0., -2.],
[ 0., -1.],
[ 0., 0.],
[ 0., 1.],
[ 0., 2.],
[ 0., 3.],
[ 0., 4.]],
[[ 1., -4.],
[ 1., -3.],
[ 1., -2.],
[ 1., -1.],
[ 1., 0.],
[ 1., 1.],
[ 1., 2.],
[ 1., 3.],
[ 1., 4.]],
[[ 2., -4.],
[ 2., -3.],
[ 2., -2.],
[ 2., -1.],
[ 2., 0.],
[ 2., 1.],
[ 2., 2.],
[ 2., 3.],
[ 2., 4.]],
[[ 3., -4.],
[ 3., -3.],
[ 3., -2.],
[ 3., -1.],
[ 3., 0.],
[ 3., 1.],
[ 3., 2.],
[ 3., 3.],
[ 3., 4.]],
[[ 4., -4.],
[ 4., -3.],
[ 4., -2.],
[ 4., -1.],
[ 4., 0.],
[ 4., 1.],
[ 4., 2.],
[ 4., 3.],
[ 4., 4.]]],
[[[-3., -4.],
[-3., -3.],
[-3., -2.],
[-3., -1.],
[-3., 0.],
[-3., 1.],
[-3., 2.],
[-3., 3.],
[-3., 4.]],
[[-2., -4.],
[-2., -3.],
[-2., -2.],
[-2., -1.],
[-2., 0.],
[-2., 1.],
[-2., 2.],
[-2., 3.],
[-2., 4.]],
[[-1., -4.],
[-1., -3.],
[-1., -2.],
[-1., -1.],
[-1., 0.],
[-1., 1.],
[-1., 2.],
[-1., 3.],
[-1., 4.]],
[[ 0., -4.],
[ 0., -3.],
[ 0., -2.],
[ 0., -1.],
[ 0., 0.],
[ 0., 1.],
[ 0., 2.],
[ 0., 3.],
[ 0., 4.]],
[[ 1., -4.],
[ 1., -3.],
[ 1., -2.],
[ 1., -1.],
[ 1., 0.],
[ 1., 1.],
[ 1., 2.],
[ 1., 3.],
[ 1., 4.]],
[[ 2., -4.],
[ 2., -3.],
[ 2., -2.],
[ 2., -1.],
[ 2., 0.],
[ 2., 1.],
[ 2., 2.],
[ 2., 3.],
[ 2., 4.]],
[[ 3., -4.],
[ 3., -3.],
[ 3., -2.],
[ 3., -1.],
[ 3., 0.],
[ 3., 1.],
[ 3., 2.],
[ 3., 3.],
[ 3., 4.]],
[[ 4., -4.],
[ 4., -3.],
[ 4., -2.],
[ 4., -1.],
[ 4., 0.],
[ 4., 1.],
[ 4., 2.],
[ 4., 3.],
[ 4., 4.]],
[[ 5., -4.],
[ 5., -3.],
[ 5., -2.],
[ 5., -1.],
[ 5., 0.],
[ 5., 1.],
[ 5., 2.],
[ 5., 3.],
[ 5., 4.]]],
[[[-2., -4.],
[-2., -3.],
[-2., -2.],
[-2., -1.],
[-2., 0.],
[-2., 1.],
[-2., 2.],
[-2., 3.],
[-2., 4.]],
[[-1., -4.],
[-1., -3.],
[-1., -2.],
[-1., -1.],
[-1., 0.],
[-1., 1.],
[-1., 2.],
[-1., 3.],
[-1., 4.]],
[[ 0., -4.],
[ 0., -3.],
[ 0., -2.],
[ 0., -1.],
[ 0., 0.],
[ 0., 1.],
[ 0., 2.],
[ 0., 3.],
[ 0., 4.]],
[[ 1., -4.],
[ 1., -3.],
[ 1., -2.],
[ 1., -1.],
[ 1., 0.],
[ 1., 1.],
[ 1., 2.],
[ 1., 3.],
[ 1., 4.]],
[[ 2., -4.],
[ 2., -3.],
[ 2., -2.],
[ 2., -1.],
[ 2., 0.],
[ 2., 1.],
[ 2., 2.],
[ 2., 3.],
[ 2., 4.]],
[[ 3., -4.],
[ 3., -3.],
[ 3., -2.],
[ 3., -1.],
[ 3., 0.],
[ 3., 1.],
[ 3., 2.],
[ 3., 3.],
[ 3., 4.]],
[[ 4., -4.],
[ 4., -3.],
[ 4., -2.],
[ 4., -1.],
[ 4., 0.],
[ 4., 1.],
[ 4., 2.],
[ 4., 3.],
[ 4., 4.]],
[[ 5., -4.],
[ 5., -3.],
[ 5., -2.],
[ 5., -1.],
[ 5., 0.],
[ 5., 1.],
[ 5., 2.],
[ 5., 3.],
[ 5., 4.]],
[[ 6., -4.],
[ 6., -3.],
[ 6., -2.],
[ 6., -1.],
[ 6., 0.],
[ 6., 1.],
[ 6., 2.],
[ 6., 3.],
[ 6., 4.]]],
[[[-4., -3.],
[-4., -2.],
[-4., -1.],
[-4., 0.],
[-4., 1.],
[-4., 2.],
[-4., 3.],
[-4., 4.],
[-4., 5.]],
[[-3., -3.],
[-3., -2.],
[-3., -1.],
[-3., 0.],
[-3., 1.],
[-3., 2.],
[-3., 3.],
[-3., 4.],
[-3., 5.]],
[[-2., -3.],
[-2., -2.],
[-2., -1.],
[-2., 0.],
[-2., 1.],
[-2., 2.],
[-2., 3.],
[-2., 4.],
[-2., 5.]],
[[-1., -3.],
[-1., -2.],
[-1., -1.],
[-1., 0.],
[-1., 1.],
[-1., 2.],
[-1., 3.],
[-1., 4.],
[-1., 5.]],
[[ 0., -3.],
[ 0., -2.],
[ 0., -1.],
[ 0., 0.],
[ 0., 1.],
[ 0., 2.],
[ 0., 3.],
[ 0., 4.],
[ 0., 5.]],
[[ 1., -3.],
[ 1., -2.],
[ 1., -1.],
[ 1., 0.],
[ 1., 1.],
[ 1., 2.],
[ 1., 3.],
[ 1., 4.],
[ 1., 5.]],
[[ 2., -3.],
[ 2., -2.],
[ 2., -1.],
[ 2., 0.],
[ 2., 1.],
[ 2., 2.],
[ 2., 3.],
[ 2., 4.],
[ 2., 5.]],
[[ 3., -3.],
[ 3., -2.],
[ 3., -1.],
[ 3., 0.],
[ 3., 1.],
[ 3., 2.],
[ 3., 3.],
[ 3., 4.],
[ 3., 5.]],
[[ 4., -3.],
[ 4., -2.],
[ 4., -1.],
[ 4., 0.],
[ 4., 1.],
[ 4., 2.],
[ 4., 3.],
[ 4., 4.],
[ 4., 5.]]],
[[[-3., -3.],
[-3., -2.],
[-3., -1.],
[-3., 0.],
[-3., 1.],
[-3., 2.],
[-3., 3.],
[-3., 4.],
[-3., 5.]],
[[-2., -3.],
[-2., -2.],
[-2., -1.],
[-2., 0.],
[-2., 1.],
[-2., 2.],
[-2., 3.],
[-2., 4.],
[-2., 5.]],
[[-1., -3.],
[-1., -2.],
[-1., -1.],
[-1., 0.],
[-1., 1.],
[-1., 2.],
[-1., 3.],
[-1., 4.],
[-1., 5.]],
[[ 0., -3.],
[ 0., -2.],
[ 0., -1.],
[ 0., 0.],
[ 0., 1.],
[ 0., 2.],
[ 0., 3.],
[ 0., 4.],
[ 0., 5.]],
[[ 1., -3.],
[ 1., -2.],
[ 1., -1.],
[ 1., 0.],
[ 1., 1.],
[ 1., 2.],
[ 1., 3.],
[ 1., 4.],
[ 1., 5.]],
[[ 2., -3.],
[ 2., -2.],
[ 2., -1.],
[ 2., 0.],
[ 2., 1.],
[ 2., 2.],
[ 2., 3.],
[ 2., 4.],
[ 2., 5.]],
[[ 3., -3.],
[ 3., -2.],
[ 3., -1.],
[ 3., 0.],
[ 3., 1.],
[ 3., 2.],
[ 3., 3.],
[ 3., 4.],
[ 3., 5.]],
[[ 4., -3.],
[ 4., -2.],
[ 4., -1.],
[ 4., 0.],
[ 4., 1.],
[ 4., 2.],
[ 4., 3.],
[ 4., 4.],
[ 4., 5.]],
[[ 5., -3.],
[ 5., -2.],
[ 5., -1.],
[ 5., 0.],
[ 5., 1.],
[ 5., 2.],
[ 5., 3.],
[ 5., 4.],
[ 5., 5.]]],
[[[-2., -3.],
[-2., -2.],
[-2., -1.],
[-2., 0.],
[-2., 1.],
[-2., 2.],
[-2., 3.],
[-2., 4.],
[-2., 5.]],
[[-1., -3.],
[-1., -2.],
[-1., -1.],
[-1., 0.],
[-1., 1.],
[-1., 2.],
[-1., 3.],
[-1., 4.],
[-1., 5.]],
[[ 0., -3.],
[ 0., -2.],
[ 0., -1.],
[ 0., 0.],
[ 0., 1.],
[ 0., 2.],
[ 0., 3.],
[ 0., 4.],
[ 0., 5.]],
[[ 1., -3.],
[ 1., -2.],
[ 1., -1.],
[ 1., 0.],
[ 1., 1.],
[ 1., 2.],
[ 1., 3.],
[ 1., 4.],
[ 1., 5.]],
[[ 2., -3.],
[ 2., -2.],
[ 2., -1.],
[ 2., 0.],
[ 2., 1.],
[ 2., 2.],
[ 2., 3.],
[ 2., 4.],
[ 2., 5.]],
[[ 3., -3.],
[ 3., -2.],
[ 3., -1.],
[ 3., 0.],
[ 3., 1.],
[ 3., 2.],
[ 3., 3.],
[ 3., 4.],
[ 3., 5.]],
[[ 4., -3.],
[ 4., -2.],
[ 4., -1.],
[ 4., 0.],
[ 4., 1.],
[ 4., 2.],
[ 4., 3.],
[ 4., 4.],
[ 4., 5.]],
[[ 5., -3.],
[ 5., -2.],
[ 5., -1.],
[ 5., 0.],
[ 5., 1.],
[ 5., 2.],
[ 5., 3.],
[ 5., 4.],
[ 5., 5.]],
[[ 6., -3.],
[ 6., -2.],
[ 6., -1.],
[ 6., 0.],
[ 6., 1.],
[ 6., 2.],
[ 6., 3.],
[ 6., 4.],
[ 6., 5.]]]])
torch.Size([6, 9, 9, 2])
[머신러닝] 파이토치 텐서 (tensor)의 연산 — 무하지 (tistory.com)
[머신러닝] 파이토치 텐서 (tensor)의 연산
1. 텐서 기초 a. 텐서 선언하기 1차원 텐서 tensor = torch.FloatTensor([0., 1., 2., 3.]) 2차원 텐서 tensor = torch.FloatTensor([[1., 2., 3.], [4., 5., 6.], [7., 8., 9.], [10., 11., 12.]]) print(tensor.dim()) print(tensor.size()) print(tensor
lwamuhaji.tistory.com
텐서의 덧셈에 대해선 위에서 잘 나와있음.
요약하자면, (6, 1, 1, 2)와 (1, 9, 9, 2)를 더하는 경우, 각 shape element의 max만큼 맞춰진다는 것(=(6, 9, 9, 2))
따라서 (6, 1, 1, 2)는 똑같은 (2)가 8개가 더 추가되고, 이렇게 만들어진 (9, 2)가 8개 추가된다.
(1, 9, 9, 2)는 (9, 9, 2)가 5개 추가된다.
다음으로는 bilinear_sampler을 실행한다. 이 함수는 아래와 같음(여기서 coords 파라미터는 위의 (6, 9, 9, 2)짜리 coords_lvl이고, img는 corr 파라미터, 즉 pyramid[i])
def bilinear_sampler(img, coords, mode='bilinear', mask=False):
""" Wrapper for grid_sample, uses pixel coordinates """
H, W = img.shape[-2:] # h2, w2
xgrid, ygrid = coords.split([1,1], dim=-1)
xgrid = 2*xgrid/(W-1) - 1
ygrid = 2*ygrid/(H-1) - 1
grid = torch.cat([xgrid, ygrid], dim=-1)
img = F.grid_sample(img, grid, align_corners=True)
if mask:
mask = (xgrid > -1) & (ygrid > -1) & (xgrid < 1) & (ygrid < 1)
return img, mask.float()
return img
xgrid와 ygrid의 size는, split([1, 1], dim=-1) 실행후에 [32400, 9, 9, 1]이 된다. 마지막의 2를 하나씩 나눠가지기 때문에.
다음으로 정규화를 해주는데, 이 때 delta_lvl로 인해 (W-1)로 나눠서 정규화를 시도해도 (-1 ~ 1)의 값이 나오지 않을 지도 모른다. 이건 직접 코드를 돌려봐야 할 듯.
중요한 건 F.grid_sample임. 이는 일종의 보간법임.
pytorch 기본 문법 및 코드, 팁 snippets - gaussian37
in_data = torch.arange(4*4).view(1, 1, 4, 4).float()
print(in_data)
# tensor([[[[ 0., 1., 2., 3.],
# [ 4., 5., 6., 7.],
# [ 8., 9., 10., 11.],
# [12., 13., 14., 15.]]]])
d = torch.linspace(-1, 1, 4)
meshx, meshy = torch.meshgrid((d, d))
grid = torch.stack((meshy, meshx), 2)
grid = grid.unsqueeze(0)
print(grid.shape)
# torch.Size([1, 4, 4, 2])
print(grid)
# tensor([[[[-1.0000, -1.0000],
# [-0.3333, -1.0000],
# [ 0.3333, -1.0000],
# [ 1.0000, -1.0000]],
# [[-1.0000, -0.3333],
# [-0.3333, -0.3333],
# [ 0.3333, -0.3333],
# [ 1.0000, -0.3333]],
# [[-1.0000, 0.3333],
# [-0.3333, 0.3333],
# [ 0.3333, 0.3333],
# [ 1.0000, 0.3333]],
# [[-1.0000, 1.0000],
# [-0.3333, 1.0000],
# [ 0.3333, 1.0000],
# [ 1.0000, 1.0000]]]])
output1 = torch.nn.functional.grid_sample(in_data, grid)
print(output1)
# tensor([[[[ 0.0000, 0.4167, 1.0833, 0.7500],
# [ 1.6667, 4.1667, 5.5000, 3.1667],
# [ 4.3333, 9.5000, 10.8333, 5.8333],
# [ 3.0000, 6.4167, 7.0833, 3.7500]]]])
output2 = torch.nn.functional.grid_sample(in_data, grid, align_corners=True)
print(output2)
# tensor([[[[ 0.0000, 1.0000, 2.0000, 3.0000],
# [ 4.0000, 5.0000, 6.0000, 7.0000],
# [ 8.0000, 9.0000, 10.0000, 11.0000],
# [12.0000, 13.0000, 14.0000, 15.0000]]]])
grid는 x y 순으로 살핌. 즉, [-1, -0.3]은 x축 맨 왼쪽, y축 두번째 것을 살피므로 원본 in_data의 [1, 0]인 4를 탐색.
그렇다면 grid가 (-1, 1)의 값이 아닌 그 미만/초과 값이 들어가면
tensor([[[[-1.5000, -1.5000],
[-0.5000, -1.5000],
[ 0.5000, -1.5000],
[ 1.5000, -1.5000]],
[[-1.5000, -0.5000],
[-0.5000, -0.5000],
[ 0.5000, -0.5000],
[ 1.5000, -0.5000]],
[[-1.5000, 0.5000],
[-0.5000, 0.5000],
[ 0.5000, 0.5000],
[ 1.5000, 0.5000]],
[[-1.5000, 1.5000],
[-0.5000, 1.5000],
[ 0.5000, 1.5000],
[ 1.5000, 1.5000]]]])
---------------------------------------
tensor([[[[ 0.0000, 0.1875, 0.5625, 0.1875],
[ 0.7500, 3.7500, 5.2500, 1.5000],
[ 2.2500, 9.7500, 11.2500, 3.0000],
[ 0.7500, 3.1875, 3.5625, 0.9375]]]])
[-1.5, -0.5]는 0.75가 나옴.
update.py의 class SepConvGRU를 보면 (5,5) 컨볼루션을 분해해 (1,5)와 (5,1)로 함.
이는 성능의 감소를 일으키지 않으면서 연산량이 감소된다. 그런데 다들 (5,5)를 대신 쓰는 이유는 아래와 같다.
(5,5) 커널 또는 (1,5) 및 (5,1) 커널의 조합을 사용하는 것 중에서 선택하는 것은 특정 작업, 입력 데이터의 특성 및 모델의 전체 아키텍처에 따라 달라집니다.
분리 가능한 컨볼루션(예: (1,5) 및 (5,1) 커널의 조합)을 사용하면 매개 변수와 계산 수를 줄일 수 있지만 모든 작업 또는 모델에 대해 항상 최선의 선택은 아닙니다. 다음은 (5,5) 커널을 대신 사용하는 몇 가지 이유입니다:
1. **다른 기능 추출**: (5,5) 커널과 (1,5) 및 (5,1) 커널의 조합은 정확히 동일한 기능을 추출하지 않습니다. (5,5) 커널은 (1,5) 및 (5,1) 커널의 조합이 수평 및 수직 패턴을 개별적으로 캡처하는 동안 단일 단계에서 2D 공간 패턴을 캡처할 수 있습니다. 작업 및 데이터에 따라 한 가지 접근 방식이 다른 방식보다 더 효과적일 수 있습니다.
2. **모델 단순성**: (5,5) 커널을 사용하면 모델을 더 간단하고 쉽게 이해할 수 있습니다. 분리 가능한 컨볼루션을 사용하면 모델을 더 효율적으로 만들 수 있지만 모델을 더 복잡하고 해석하기 어렵게 만들 수도 있습니다.
3. **구현 세부 사항**: 일부 딥 러닝 프레임워크 및 하드웨어 가속기는 표준 컨볼루션에 최적화되어 있으며 분리 가능한 컨볼루션을 효율적으로 지원하지 않을 수 있습니다.
4. **경험적 결과**: 궁극적으로 모델 아키텍처의 선택은 경험적 결과를 기반으로 하는 경우가 많습니다. 연구자가 (5,5) 커널이 있는 모델이 작업에서 더 나은 성능을 발휘한다는 사실을 알게 되면, 효율성이 떨어지더라도 해당 모델을 사용하기로 선택할 수 있습니다.
결국, (5,5) 커널과 (1,5) 및 (5,1) 커널의 조합 사이의 선택은 효율성과 작업 및 모델의 특정 요구 사항 간의 균형입니다. 그것은 시행착오와 당면한 과제의 구체적인 세부 사항에 기초한 경험적인 결정입니다.
즉, optical flow는 x와 y의 방향을 캡처하는 것이 목표이므로 이게 나을지도 모르지만, object detection은 정사각형의 커널이 더 나을지도 모름.
torch.meshgrid를 통해 만들어진 centroid_lvl의 모양새(corr.py 참조).
이렇게 (x, y) 좌표를 갖고 있음. 즉, [14260, 1, 1, 2]의 형태는 각각 14260(46 * 62 * batch) 개의 픽셀이 [1, 1, 2] 형태의 좌표를 갖고 있는 것.
그리고 이 [1, 1] 부분을 [9, 9]로 확장시킬 것임. 즉, [0, 0]의 좌표가 다음 프레임에 존재할 수 있는 위치좌표를 넣는 것.
논문에서는 하나의 픽셀은 한 프레임 당 최대 4픽셀을 초과해 움직일 수 없다고 봤다. 즉, 이전 프레임의 (0,0) 픽셀이 다음 프레임에서 있을 수 있는 위치는 (-4, -4) 부터 (4, 4)까지 총 9x9개다.
따라서 [14260,9, 9, 2]는 하나의 픽셀이 존재할 가능성이 있는 91개의 좌표를 들고 있는 형태라 할 수 있다.