동방프로젝트

[Optical Flow] RAFT 이해를 위한 여러 함수 정리글

두원공대88학번뚜뚜 2023. 3. 24. 20:47

torch.meshgrid 정리 - raft.py의 

coords0, coords1 = self.initialize_flow(image1)

를 실행할 경우 행해지는 함수들

width, height = 128, 64
coords = torch.meshgrid(torch.arange(width//8), torch.arange(height//8))    # 16, 8
print(coords)
print('------------------------------')
print(coords[::-1])
print('------------------------------')
coords = torch.stack(coords[::-1], dim=0).float()
# print(coords)
print('------------------------------')
# print(coords[None].repeat(1, 1, 1, 1))

위를 진행하면, 우선 torch.meshgrid는 다음과 같은 결과를 냄.

(tensor([[ 0,  0,  0,  0,  0,  0,  0,  0],
        [ 1,  1,  1,  1,  1,  1,  1,  1],
        [ 2,  2,  2,  2,  2,  2,  2,  2],
        [ 3,  3,  3,  3,  3,  3,  3,  3],
        [ 4,  4,  4,  4,  4,  4,  4,  4],
        [ 5,  5,  5,  5,  5,  5,  5,  5],
        [ 6,  6,  6,  6,  6,  6,  6,  6],
        [ 7,  7,  7,  7,  7,  7,  7,  7],
        [ 8,  8,  8,  8,  8,  8,  8,  8],
        [ 9,  9,  9,  9,  9,  9,  9,  9],
        [10, 10, 10, 10, 10, 10, 10, 10],
        [11, 11, 11, 11, 11, 11, 11, 11],
        [12, 12, 12, 12, 12, 12, 12, 12],
        [13, 13, 13, 13, 13, 13, 13, 13],
        [14, 14, 14, 14, 14, 14, 14, 14],
        [15, 15, 15, 15, 15, 15, 15, 15]]), tensor([[0, 1, 2, 3, 4, 5, 6, 7],
        [0, 1, 2, 3, 4, 5, 6, 7],
        [0, 1, 2, 3, 4, 5, 6, 7],
        [0, 1, 2, 3, 4, 5, 6, 7],
        [0, 1, 2, 3, 4, 5, 6, 7],
        [0, 1, 2, 3, 4, 5, 6, 7],
        [0, 1, 2, 3, 4, 5, 6, 7],
        [0, 1, 2, 3, 4, 5, 6, 7],
        [0, 1, 2, 3, 4, 5, 6, 7],
        [0, 1, 2, 3, 4, 5, 6, 7],
        [0, 1, 2, 3, 4, 5, 6, 7],
        [0, 1, 2, 3, 4, 5, 6, 7],
        [0, 1, 2, 3, 4, 5, 6, 7],
        [0, 1, 2, 3, 4, 5, 6, 7],
        [0, 1, 2, 3, 4, 5, 6, 7],
        [0, 1, 2, 3, 4, 5, 6, 7]]))

즉, (16, 8)짜리 행이 같은 숫자인 배열, 열이 같은 숫자인 배열 두 개를 만들어냄. shape는 (16, 8)

 

torch.stack([::-1], dim=0)은 아래의 결과를 도출.

tensor([[[ 0.,  1.,  2.,  3.,  4.,  5.,  6.,  7.],
         [ 0.,  1.,  2.,  3.,  4.,  5.,  6.,  7.],
         [ 0.,  1.,  2.,  3.,  4.,  5.,  6.,  7.],
         [ 0.,  1.,  2.,  3.,  4.,  5.,  6.,  7.],
         [ 0.,  1.,  2.,  3.,  4.,  5.,  6.,  7.],
         [ 0.,  1.,  2.,  3.,  4.,  5.,  6.,  7.],
         [ 0.,  1.,  2.,  3.,  4.,  5.,  6.,  7.],
         [ 0.,  1.,  2.,  3.,  4.,  5.,  6.,  7.],
         [ 0.,  1.,  2.,  3.,  4.,  5.,  6.,  7.],
         [ 0.,  1.,  2.,  3.,  4.,  5.,  6.,  7.],
         [ 0.,  1.,  2.,  3.,  4.,  5.,  6.,  7.],
         [ 0.,  1.,  2.,  3.,  4.,  5.,  6.,  7.],
         [ 0.,  1.,  2.,  3.,  4.,  5.,  6.,  7.],
         [ 0.,  1.,  2.,  3.,  4.,  5.,  6.,  7.],
         [ 0.,  1.,  2.,  3.,  4.,  5.,  6.,  7.],
         [ 0.,  1.,  2.,  3.,  4.,  5.,  6.,  7.]],

        [[ 0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.],
         [ 1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.],
         [ 2.,  2.,  2.,  2.,  2.,  2.,  2.,  2.],
         [ 3.,  3.,  3.,  3.,  3.,  3.,  3.,  3.],
         [ 4.,  4.,  4.,  4.,  4.,  4.,  4.,  4.],
         [ 5.,  5.,  5.,  5.,  5.,  5.,  5.,  5.],
         [ 6.,  6.,  6.,  6.,  6.,  6.,  6.,  6.],
         [ 7.,  7.,  7.,  7.,  7.,  7.,  7.,  7.],
         [ 8.,  8.,  8.,  8.,  8.,  8.,  8.,  8.],
         [ 9.,  9.,  9.,  9.,  9.,  9.,  9.,  9.],
         [10., 10., 10., 10., 10., 10., 10., 10.],
         [11., 11., 11., 11., 11., 11., 11., 11.],
         [12., 12., 12., 12., 12., 12., 12., 12.],
         [13., 13., 13., 13., 13., 13., 13., 13.],
         [14., 14., 14., 14., 14., 14., 14., 14.],
         [15., 15., 15., 15., 15., 15., 15., 15.]]])

-1, 즉 역순(reverse)으로 해당 텐서를 stack함.

마지막으로 coords[None].repeat(batch, 1, 1, 1)을 실행하면

tensor([[[[ 0.,  1.,  2.,  3.,  4.,  5.,  6.,  7.],
          [ 0.,  1.,  2.,  3.,  4.,  5.,  6.,  7.],
          [ 0.,  1.,  2.,  3.,  4.,  5.,  6.,  7.],
          [ 0.,  1.,  2.,  3.,  4.,  5.,  6.,  7.],
          [ 0.,  1.,  2.,  3.,  4.,  5.,  6.,  7.],
          [ 0.,  1.,  2.,  3.,  4.,  5.,  6.,  7.],
          [ 0.,  1.,  2.,  3.,  4.,  5.,  6.,  7.],
          [ 0.,  1.,  2.,  3.,  4.,  5.,  6.,  7.],
          [ 0.,  1.,  2.,  3.,  4.,  5.,  6.,  7.],
          [ 0.,  1.,  2.,  3.,  4.,  5.,  6.,  7.],
          [ 0.,  1.,  2.,  3.,  4.,  5.,  6.,  7.],
          [ 0.,  1.,  2.,  3.,  4.,  5.,  6.,  7.],
          [ 0.,  1.,  2.,  3.,  4.,  5.,  6.,  7.],
          [ 0.,  1.,  2.,  3.,  4.,  5.,  6.,  7.],
          [ 0.,  1.,  2.,  3.,  4.,  5.,  6.,  7.],
          [ 0.,  1.,  2.,  3.,  4.,  5.,  6.,  7.]],

         [[ 0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.],
          [ 1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.],
          [ 2.,  2.,  2.,  2.,  2.,  2.,  2.,  2.],
          [ 3.,  3.,  3.,  3.,  3.,  3.,  3.,  3.],
          [ 4.,  4.,  4.,  4.,  4.,  4.,  4.,  4.],
          [ 5.,  5.,  5.,  5.,  5.,  5.,  5.,  5.],
          [ 6.,  6.,  6.,  6.,  6.,  6.,  6.,  6.],
          [ 7.,  7.,  7.,  7.,  7.,  7.,  7.,  7.],
          [ 8.,  8.,  8.,  8.,  8.,  8.,  8.,  8.],
          [ 9.,  9.,  9.,  9.,  9.,  9.,  9.,  9.],
          [10., 10., 10., 10., 10., 10., 10., 10.],
          [11., 11., 11., 11., 11., 11., 11., 11.],
          [12., 12., 12., 12., 12., 12., 12., 12.],
          [13., 13., 13., 13., 13., 13., 13., 13.],
          [14., 14., 14., 14., 14., 14., 14., 14.],
          [15., 15., 15., 15., 15., 15., 15., 15.]]]])

그냥 배치만큼 스택하고 끝임. 이거로 flow map을 초기화하는듯.

[None] == np.newaxis

 

그럼 여기서 궁금증이 생긴다.

1) coords[None]의 의미 >> 실험결과, coords.repeat(...)와 똑같음.

이 때 주의할 점 :

flow_loss += i_weight * (valid[:, None] * i_loss).mean()

이 경우, 두 번째 차원에 추가됨. 즉,

n[:, None] == n.unsqueeze(dim=1)

 

2) repeat(a, b, c, d) >> 각 차원을 얼마나 반복하는지. 즉 (1,1,1,3)이면 [0 1 2 3 4 5 6 7 0 1 2 3 4 5 6 7 0 1 2 3 4 5 6 7]이 됨.

 

아무튼 이렇게 똑같은 두 개의 배열을 만들어냄(각각 coords0, coords1).

2개를 만드는 이유는, 각각 x방향, y방향이라.

 

그러면 다음으론, 아래가 실행되는데, 우선 coor_fn(coords1)을 살펴보겠다. 이건 아무래도 논문의 3.3.Upsampling으로 보임.

for itr in range(iters):
    coords1 = coords1.detach()
    corr = corr_fn(coords1) # index correlation volume, def __call__ in CorrBlock class

    flow = coords1 - coords0
    with autocast(enabled=self.args.mixed_precision):
        net, up_mask, delta_flow = self.update_block(net, inp, corr, flow)

    # F(t+1) = F(t) + \Delta(t)
    coords1 = coords1 + delta_flow

    # upsample predictions
    if up_mask is None:
        flow_up = upflow8(coords1 - coords0)
    else:
        flow_up = self.upsample_flow(coords1 - coords0, up_mask)
            
    flow_predictions.append(flow_up)
    
if test_mode:
    return coords1 - coords0, flow_up
            
return flow_predictions

 

coor_fn을 실행하면 아래와 같은 것을 수행한다.

    def __call__(self, coords):
        r = self.radius # 4
        # batch, 2, w//8, h//8 --> batch, w//8, h//8, 2
        coords = coords.permute(0, 2, 3, 1)
        batch, h1, w1, _ = coords.shape

        out_pyramid = []
        for i in range(self.num_levels):
            corr = self.corr_pyramid[i] # get i-th pyramid got from __init__, correlation map
            # [-4 ~ 4], size = 9. so, [-4, -3, -2, .... 3, 4]
            # maybe it would has concerned with RAFT 3.3.Upsampling(use 9 neighbors to make original size predict map)
            dx = torch.linspace(-r, r, 2*r+1, device=coords.device)
            dy = torch.linspace(-r, r, 2*r+1, device=coords.device)
            delta = torch.stack(torch.meshgrid(dy, dx), axis=-1)

            centroid_lvl = coords.reshape(batch*h1*w1, 1, 1, 2) / 2**i
            delta_lvl = delta.view(1, 2*r+1, 2*r+1, 2)
            coords_lvl = centroid_lvl + delta_lvl

            corr = bilinear_sampler(corr, coords_lvl)
            corr = corr.view(batch, h1, w1, -1)
            out_pyramid.append(corr)

        out = torch.cat(out_pyramid, dim=-1)
        return out.permute(0, 3, 1, 2).contiguous().float()

delta는 다음과 같다. axis = -1이기에, 아래와 같은 형태가 됨. shape는 (9, 9, 2)

tensor([[[-4., -4.],
         [-4., -3.],
         [-4., -2.],
         [-4., -1.],
         [-4.,  0.],
         [-4.,  1.],
         [-4.,  2.],
         [-4.,  3.],
         [-4.,  4.]],

        [[-3., -4.],
         [-3., -3.],
         [-3., -2.],
         [-3., -1.],
         [-3.,  0.],
         [-3.,  1.],
         [-3.,  2.],
         [-3.,  3.],
         [-3.,  4.]],

        [[-2., -4.],
         [-2., -3.],
         [-2., -2.],
         [-2., -1.],
         [-2.,  0.],
         [-2.,  1.],
         [-2.,  2.],
         [-2.,  3.],
         [-2.,  4.]],

        [[-1., -4.],
         [-1., -3.],
         [-1., -2.],
         [-1., -1.],
         [-1.,  0.],
         [-1.,  1.],
         [-1.,  2.],
         [-1.,  3.],
         [-1.,  4.]],

        [[ 0., -4.],
         [ 0., -3.],
         [ 0., -2.],
         [ 0., -1.],
         [ 0.,  0.],
         [ 0.,  1.],
         [ 0.,  2.],
         [ 0.,  3.],
         [ 0.,  4.]],

        [[ 1., -4.],
         [ 1., -3.],
         [ 1., -2.],
         [ 1., -1.],
         [ 1.,  0.],
         [ 1.,  1.],
         [ 1.,  2.],
         [ 1.,  3.],
         [ 1.,  4.]],

        [[ 2., -4.],
         [ 2., -3.],
         [ 2., -2.],
         [ 2., -1.],
         [ 2.,  0.],
         [ 2.,  1.],
         [ 2.,  2.],
         [ 2.,  3.],
         [ 2.,  4.]],

        [[ 3., -4.],
         [ 3., -3.],
         [ 3., -2.],
         [ 3., -1.],
         [ 3.,  0.],
         [ 3.,  1.],
         [ 3.,  2.],
         [ 3.,  3.],
         [ 3.,  4.]],

        [[ 4., -4.],
         [ 4., -3.],
         [ 4., -2.],
         [ 4., -1.],
         [ 4.,  0.],
         [ 4.,  1.],
         [ 4.,  2.],
         [ 4.,  3.],
         [ 4.,  4.]]])

centroid_lvl는 기존의 (batch, w//8, h//8, 2) 배열을 (batch * w//8 * h//8, 1, 1, 2)로 바꿔준다.

다음으로 2**i로 나누는데, 이는 각 피라미드에 적합하게 하기 위함으로 보인다. 

corr = pyramid[i]이므로, 이 pyramid에 적합한 사이즈로 만들기 위해, 각 요소를 2**i로 나눠줌.

 

이를 바탕으로 coords_lvl을 구하면, (1920/8 * 1080/8 * batch, 9, 2)의 shape를 가진 tensor가 나옴.

정확히 어떤 값이 나오는지 알아보기 위해, 아래와 같이 코드를 짜봄.

import torch
import torch.nn.functional as F

# utils.py - coords_grid & raft.py - initialize_flow
H, W = 16, 24
coords = torch.meshgrid(torch.arange(H//8), torch.arange(W//8))
coords = torch.stack(coords[::-1], dim=0).float()
coords = coords[None].repeat(1, 1, 1, 1) # batch = 1이라 가정
print(coords)
print('---------------------------')

# corr_fn - __call__
coords = coords.permute(0, 2, 3, 1)
batch, h1, w1, _ = coords.shape
centroid_lvl = coords.reshape(1*h1*w1, 1, 1, 2)

print(centroid_lvl.shape)
print(centroid_lvl)
print('---------------------------')

r = 4
dx = torch.linspace(-r, r, 2*r+1)
dy = torch.linspace(-r, r, 2*r+1)

delta = torch.stack(torch.meshgrid(dy, dx), axis=-1)
delta_lvl = delta.view(1, 2*r+1, 2*r+1, 2)
print(delta_lvl.shape)
print(delta_lvl)
print('---------------------------')

coords_lvl = centroid_lvl + delta_lvl
print(coords_lvl)
print(coords_lvl.shape)

# bilinear_sampler
print('---------------------------')
xgrid, ygrid = coords_lvl.split([1, 1], dim=-1)   # xgrid/ygrid : (bhw x 9 x 9 x 1) cause each (bhw x 9 x 9) is x/y coor

xgrid = 2 * xgrid / (W-1) - 1
ygrid = 2 * ygrid / (H-1) - 1
grid = torch.cat([xgrid, ygrid], dim = -1)
# print(xgrid)
# print(ygrid)
# img = F.grid_sample(img, grid, align_corners = True)

그 결과는 아래와 같다.

torch.Size([6, 1, 1, 2])
tensor([[[[0., 0.]]],
        [[[1., 0.]]],
        [[[2., 0.]]],
        [[[0., 1.]]],
        [[[1., 1.]]],
        [[[2., 1.]]]])
----------------------------------------------------------
torch.Size([1, 9, 9, 2])
tensor([[[[-4., -4.],
          [-4., -3.],
          [-4., -2.],
          [-4., -1.],
          [-4.,  0.],
          [-4.,  1.],
          [-4.,  2.],
          [-4.,  3.],
          [-4.,  4.]],

         [[-3., -4.],
          [-3., -3.],
          [-3., -2.],
          [-3., -1.],
          [-3.,  0.],
          [-3.,  1.],
          [-3.,  2.],
          [-3.,  3.],
          [-3.,  4.]],

         [[-2., -4.],
          [-2., -3.],
          [-2., -2.],
          [-2., -1.],
          [-2.,  0.],
          [-2.,  1.],
          [-2.,  2.],
          [-2.,  3.],
          [-2.,  4.]],

         [[-1., -4.],
          [-1., -3.],
          [-1., -2.],
          [-1., -1.],
          [-1.,  0.],
          [-1.,  1.],
          [-1.,  2.],
          [-1.,  3.],
          [-1.,  4.]],

         [[ 0., -4.],
          [ 0., -3.],
          [ 0., -2.],
          [ 0., -1.],
          [ 0.,  0.],
          [ 0.,  1.],
          [ 0.,  2.],
          [ 0.,  3.],
          [ 0.,  4.]],

         [[ 1., -4.],
          [ 1., -3.],
          [ 1., -2.],
          [ 1., -1.],
          [ 1.,  0.],
          [ 1.,  1.],
          [ 1.,  2.],
          [ 1.,  3.],
          [ 1.,  4.]],

         [[ 2., -4.],
          [ 2., -3.],
          [ 2., -2.],
          [ 2., -1.],
          [ 2.,  0.],
          [ 2.,  1.],
          [ 2.,  2.],
          [ 2.,  3.],
          [ 2.,  4.]],

         [[ 3., -4.],
          [ 3., -3.],
          [ 3., -2.],
          [ 3., -1.],
          [ 3.,  0.],
          [ 3.,  1.],
          [ 3.,  2.],
          [ 3.,  3.],
          [ 3.,  4.]],

         [[ 4., -4.],
          [ 4., -3.],
          [ 4., -2.],
          [ 4., -1.],
          [ 4.,  0.],
          [ 4.,  1.],
          [ 4.,  2.],
          [ 4.,  3.],
          [ 4.,  4.]]]])
----------------------------------------------------------
tensor([[[[-4., -4.],
          [-4., -3.],
          [-4., -2.],
          [-4., -1.],
          [-4.,  0.],
          [-4.,  1.],
          [-4.,  2.],
          [-4.,  3.],
          [-4.,  4.]],

         [[-3., -4.],
          [-3., -3.],
          [-3., -2.],
          [-3., -1.],
          [-3.,  0.],
          [-3.,  1.],
          [-3.,  2.],
          [-3.,  3.],
          [-3.,  4.]],

         [[-2., -4.],
          [-2., -3.],
          [-2., -2.],
          [-2., -1.],
          [-2.,  0.],
          [-2.,  1.],
          [-2.,  2.],
          [-2.,  3.],
          [-2.,  4.]],

         [[-1., -4.],
          [-1., -3.],
          [-1., -2.],
          [-1., -1.],
          [-1.,  0.],
          [-1.,  1.],
          [-1.,  2.],
          [-1.,  3.],
          [-1.,  4.]],

         [[ 0., -4.],
          [ 0., -3.],
          [ 0., -2.],
          [ 0., -1.],
          [ 0.,  0.],
          [ 0.,  1.],
          [ 0.,  2.],
          [ 0.,  3.],
          [ 0.,  4.]],

         [[ 1., -4.],
          [ 1., -3.],
          [ 1., -2.],
          [ 1., -1.],
          [ 1.,  0.],
          [ 1.,  1.],
          [ 1.,  2.],
          [ 1.,  3.],
          [ 1.,  4.]],

         [[ 2., -4.],
          [ 2., -3.],
          [ 2., -2.],
          [ 2., -1.],
          [ 2.,  0.],
          [ 2.,  1.],
          [ 2.,  2.],
          [ 2.,  3.],
          [ 2.,  4.]],

         [[ 3., -4.],
          [ 3., -3.],
          [ 3., -2.],
          [ 3., -1.],
          [ 3.,  0.],
          [ 3.,  1.],
          [ 3.,  2.],
          [ 3.,  3.],
          [ 3.,  4.]],

         [[ 4., -4.],
          [ 4., -3.],
          [ 4., -2.],
          [ 4., -1.],
          [ 4.,  0.],
          [ 4.,  1.],
          [ 4.,  2.],
          [ 4.,  3.],
          [ 4.,  4.]]],


        [[[-3., -4.],
          [-3., -3.],
          [-3., -2.],
          [-3., -1.],
          [-3.,  0.],
          [-3.,  1.],
          [-3.,  2.],
          [-3.,  3.],
          [-3.,  4.]],

         [[-2., -4.],
          [-2., -3.],
          [-2., -2.],
          [-2., -1.],
          [-2.,  0.],
          [-2.,  1.],
          [-2.,  2.],
          [-2.,  3.],
          [-2.,  4.]],

         [[-1., -4.],
          [-1., -3.],
          [-1., -2.],
          [-1., -1.],
          [-1.,  0.],
          [-1.,  1.],
          [-1.,  2.],
          [-1.,  3.],
          [-1.,  4.]],

         [[ 0., -4.],
          [ 0., -3.],
          [ 0., -2.],
          [ 0., -1.],
          [ 0.,  0.],
          [ 0.,  1.],
          [ 0.,  2.],
          [ 0.,  3.],
          [ 0.,  4.]],

         [[ 1., -4.],
          [ 1., -3.],
          [ 1., -2.],
          [ 1., -1.],
          [ 1.,  0.],
          [ 1.,  1.],
          [ 1.,  2.],
          [ 1.,  3.],
          [ 1.,  4.]],

         [[ 2., -4.],
          [ 2., -3.],
          [ 2., -2.],
          [ 2., -1.],
          [ 2.,  0.],
          [ 2.,  1.],
          [ 2.,  2.],
          [ 2.,  3.],
          [ 2.,  4.]],

         [[ 3., -4.],
          [ 3., -3.],
          [ 3., -2.],
          [ 3., -1.],
          [ 3.,  0.],
          [ 3.,  1.],
          [ 3.,  2.],
          [ 3.,  3.],
          [ 3.,  4.]],

         [[ 4., -4.],
          [ 4., -3.],
          [ 4., -2.],
          [ 4., -1.],
          [ 4.,  0.],
          [ 4.,  1.],
          [ 4.,  2.],
          [ 4.,  3.],
          [ 4.,  4.]],

         [[ 5., -4.],
          [ 5., -3.],
          [ 5., -2.],
          [ 5., -1.],
          [ 5.,  0.],
          [ 5.,  1.],
          [ 5.,  2.],
          [ 5.,  3.],
          [ 5.,  4.]]],


        [[[-2., -4.],
          [-2., -3.],
          [-2., -2.],
          [-2., -1.],
          [-2.,  0.],
          [-2.,  1.],
          [-2.,  2.],
          [-2.,  3.],
          [-2.,  4.]],

         [[-1., -4.],
          [-1., -3.],
          [-1., -2.],
          [-1., -1.],
          [-1.,  0.],
          [-1.,  1.],
          [-1.,  2.],
          [-1.,  3.],
          [-1.,  4.]],

         [[ 0., -4.],
          [ 0., -3.],
          [ 0., -2.],
          [ 0., -1.],
          [ 0.,  0.],
          [ 0.,  1.],
          [ 0.,  2.],
          [ 0.,  3.],
          [ 0.,  4.]],

         [[ 1., -4.],
          [ 1., -3.],
          [ 1., -2.],
          [ 1., -1.],
          [ 1.,  0.],
          [ 1.,  1.],
          [ 1.,  2.],
          [ 1.,  3.],
          [ 1.,  4.]],

         [[ 2., -4.],
          [ 2., -3.],
          [ 2., -2.],
          [ 2., -1.],
          [ 2.,  0.],
          [ 2.,  1.],
          [ 2.,  2.],
          [ 2.,  3.],
          [ 2.,  4.]],

         [[ 3., -4.],
          [ 3., -3.],
          [ 3., -2.],
          [ 3., -1.],
          [ 3.,  0.],
          [ 3.,  1.],
          [ 3.,  2.],
          [ 3.,  3.],
          [ 3.,  4.]],

         [[ 4., -4.],
          [ 4., -3.],
          [ 4., -2.],
          [ 4., -1.],
          [ 4.,  0.],
          [ 4.,  1.],
          [ 4.,  2.],
          [ 4.,  3.],
          [ 4.,  4.]],

         [[ 5., -4.],
          [ 5., -3.],
          [ 5., -2.],
          [ 5., -1.],
          [ 5.,  0.],
          [ 5.,  1.],
          [ 5.,  2.],
          [ 5.,  3.],
          [ 5.,  4.]],

         [[ 6., -4.],
          [ 6., -3.],
          [ 6., -2.],
          [ 6., -1.],
          [ 6.,  0.],
          [ 6.,  1.],
          [ 6.,  2.],
          [ 6.,  3.],
          [ 6.,  4.]]],


        [[[-4., -3.],
          [-4., -2.],
          [-4., -1.],
          [-4.,  0.],
          [-4.,  1.],
          [-4.,  2.],
          [-4.,  3.],
          [-4.,  4.],
          [-4.,  5.]],

         [[-3., -3.],
          [-3., -2.],
          [-3., -1.],
          [-3.,  0.],
          [-3.,  1.],
          [-3.,  2.],
          [-3.,  3.],
          [-3.,  4.],
          [-3.,  5.]],

         [[-2., -3.],
          [-2., -2.],
          [-2., -1.],
          [-2.,  0.],
          [-2.,  1.],
          [-2.,  2.],
          [-2.,  3.],
          [-2.,  4.],
          [-2.,  5.]],

         [[-1., -3.],
          [-1., -2.],
          [-1., -1.],
          [-1.,  0.],
          [-1.,  1.],
          [-1.,  2.],
          [-1.,  3.],
          [-1.,  4.],
          [-1.,  5.]],

         [[ 0., -3.],
          [ 0., -2.],
          [ 0., -1.],
          [ 0.,  0.],
          [ 0.,  1.],
          [ 0.,  2.],
          [ 0.,  3.],
          [ 0.,  4.],
          [ 0.,  5.]],

         [[ 1., -3.],
          [ 1., -2.],
          [ 1., -1.],
          [ 1.,  0.],
          [ 1.,  1.],
          [ 1.,  2.],
          [ 1.,  3.],
          [ 1.,  4.],
          [ 1.,  5.]],

         [[ 2., -3.],
          [ 2., -2.],
          [ 2., -1.],
          [ 2.,  0.],
          [ 2.,  1.],
          [ 2.,  2.],
          [ 2.,  3.],
          [ 2.,  4.],
          [ 2.,  5.]],

         [[ 3., -3.],
          [ 3., -2.],
          [ 3., -1.],
          [ 3.,  0.],
          [ 3.,  1.],
          [ 3.,  2.],
          [ 3.,  3.],
          [ 3.,  4.],
          [ 3.,  5.]],

         [[ 4., -3.],
          [ 4., -2.],
          [ 4., -1.],
          [ 4.,  0.],
          [ 4.,  1.],
          [ 4.,  2.],
          [ 4.,  3.],
          [ 4.,  4.],
          [ 4.,  5.]]],


        [[[-3., -3.],
          [-3., -2.],
          [-3., -1.],
          [-3.,  0.],
          [-3.,  1.],
          [-3.,  2.],
          [-3.,  3.],
          [-3.,  4.],
          [-3.,  5.]],

         [[-2., -3.],
          [-2., -2.],
          [-2., -1.],
          [-2.,  0.],
          [-2.,  1.],
          [-2.,  2.],
          [-2.,  3.],
          [-2.,  4.],
          [-2.,  5.]],

         [[-1., -3.],
          [-1., -2.],
          [-1., -1.],
          [-1.,  0.],
          [-1.,  1.],
          [-1.,  2.],
          [-1.,  3.],
          [-1.,  4.],
          [-1.,  5.]],

         [[ 0., -3.],
          [ 0., -2.],
          [ 0., -1.],
          [ 0.,  0.],
          [ 0.,  1.],
          [ 0.,  2.],
          [ 0.,  3.],
          [ 0.,  4.],
          [ 0.,  5.]],

         [[ 1., -3.],
          [ 1., -2.],
          [ 1., -1.],
          [ 1.,  0.],
          [ 1.,  1.],
          [ 1.,  2.],
          [ 1.,  3.],
          [ 1.,  4.],
          [ 1.,  5.]],

         [[ 2., -3.],
          [ 2., -2.],
          [ 2., -1.],
          [ 2.,  0.],
          [ 2.,  1.],
          [ 2.,  2.],
          [ 2.,  3.],
          [ 2.,  4.],
          [ 2.,  5.]],

         [[ 3., -3.],
          [ 3., -2.],
          [ 3., -1.],
          [ 3.,  0.],
          [ 3.,  1.],
          [ 3.,  2.],
          [ 3.,  3.],
          [ 3.,  4.],
          [ 3.,  5.]],

         [[ 4., -3.],
          [ 4., -2.],
          [ 4., -1.],
          [ 4.,  0.],
          [ 4.,  1.],
          [ 4.,  2.],
          [ 4.,  3.],
          [ 4.,  4.],
          [ 4.,  5.]],

         [[ 5., -3.],
          [ 5., -2.],
          [ 5., -1.],
          [ 5.,  0.],
          [ 5.,  1.],
          [ 5.,  2.],
          [ 5.,  3.],
          [ 5.,  4.],
          [ 5.,  5.]]],


        [[[-2., -3.],
          [-2., -2.],
          [-2., -1.],
          [-2.,  0.],
          [-2.,  1.],
          [-2.,  2.],
          [-2.,  3.],
          [-2.,  4.],
          [-2.,  5.]],

         [[-1., -3.],
          [-1., -2.],
          [-1., -1.],
          [-1.,  0.],
          [-1.,  1.],
          [-1.,  2.],
          [-1.,  3.],
          [-1.,  4.],
          [-1.,  5.]],

         [[ 0., -3.],
          [ 0., -2.],
          [ 0., -1.],
          [ 0.,  0.],
          [ 0.,  1.],
          [ 0.,  2.],
          [ 0.,  3.],
          [ 0.,  4.],
          [ 0.,  5.]],

         [[ 1., -3.],
          [ 1., -2.],
          [ 1., -1.],
          [ 1.,  0.],
          [ 1.,  1.],
          [ 1.,  2.],
          [ 1.,  3.],
          [ 1.,  4.],
          [ 1.,  5.]],

         [[ 2., -3.],
          [ 2., -2.],
          [ 2., -1.],
          [ 2.,  0.],
          [ 2.,  1.],
          [ 2.,  2.],
          [ 2.,  3.],
          [ 2.,  4.],
          [ 2.,  5.]],

         [[ 3., -3.],
          [ 3., -2.],
          [ 3., -1.],
          [ 3.,  0.],
          [ 3.,  1.],
          [ 3.,  2.],
          [ 3.,  3.],
          [ 3.,  4.],
          [ 3.,  5.]],

         [[ 4., -3.],
          [ 4., -2.],
          [ 4., -1.],
          [ 4.,  0.],
          [ 4.,  1.],
          [ 4.,  2.],
          [ 4.,  3.],
          [ 4.,  4.],
          [ 4.,  5.]],

         [[ 5., -3.],
          [ 5., -2.],
          [ 5., -1.],
          [ 5.,  0.],
          [ 5.,  1.],
          [ 5.,  2.],
          [ 5.,  3.],
          [ 5.,  4.],
          [ 5.,  5.]],

         [[ 6., -3.],
          [ 6., -2.],
          [ 6., -1.],
          [ 6.,  0.],
          [ 6.,  1.],
          [ 6.,  2.],
          [ 6.,  3.],
          [ 6.,  4.],
          [ 6.,  5.]]]])
torch.Size([6, 9, 9, 2])

[머신러닝] 파이토치 텐서 (tensor)의 연산 — 무하지 (tistory.com)

 

[머신러닝] 파이토치 텐서 (tensor)의 연산

1. 텐서 기초 a. 텐서 선언하기 1차원 텐서 tensor = torch.FloatTensor([0., 1., 2., 3.]) 2차원 텐서 tensor = torch.FloatTensor([[1., 2., 3.], [4., 5., 6.], [7., 8., 9.], [10., 11., 12.]]) print(tensor.dim()) print(tensor.size()) print(tensor

lwamuhaji.tistory.com

텐서의 덧셈에 대해선 위에서 잘 나와있음.

요약하자면, (6, 1, 1, 2)와 (1, 9, 9, 2)를 더하는 경우, 각 shape element의 max만큼 맞춰진다는 것(=(6, 9, 9, 2))

따라서 (6, 1, 1, 2)는 똑같은 (2)가 8개가 더 추가되고, 이렇게 만들어진 (9, 2)가 8개 추가된다.

(1, 9, 9, 2)는 (9, 9, 2)가 5개 추가된다.

 

다음으로는 bilinear_sampler을 실행한다. 이 함수는 아래와 같음(여기서 coords 파라미터는 위의 (6, 9, 9, 2)짜리 coords_lvl이고, img는 corr 파라미터, 즉 pyramid[i])

def bilinear_sampler(img, coords, mode='bilinear', mask=False):
    """ Wrapper for grid_sample, uses pixel coordinates """
    H, W = img.shape[-2:]   # h2, w2
    xgrid, ygrid = coords.split([1,1], dim=-1)
    xgrid = 2*xgrid/(W-1) - 1
    ygrid = 2*ygrid/(H-1) - 1

    grid = torch.cat([xgrid, ygrid], dim=-1)
    img = F.grid_sample(img, grid, align_corners=True)

    if mask:
        mask = (xgrid > -1) & (ygrid > -1) & (xgrid < 1) & (ygrid < 1)
        return img, mask.float()

    return img

 

xgrid와 ygrid의 size는, split([1, 1], dim=-1) 실행후에 [32400, 9, 9, 1]이 된다. 마지막의 2를 하나씩 나눠가지기 때문에.

다음으로 정규화를 해주는데, 이 때 delta_lvl로 인해 (W-1)로 나눠서 정규화를 시도해도 (-1 ~ 1)의 값이 나오지 않을 지도 모른다. 이건 직접 코드를 돌려봐야 할 듯.

 

중요한 건 F.grid_sample임. 이는 일종의 보간법임. 

pytorch 기본 문법 및 코드, 팁 snippets - gaussian37

in_data = torch.arange(4*4).view(1, 1, 4, 4).float()
print(in_data)
# tensor([[[[ 0.,  1.,  2.,  3.],
#           [ 4.,  5.,  6.,  7.],
#           [ 8.,  9., 10., 11.],
#           [12., 13., 14., 15.]]]])

d = torch.linspace(-1, 1, 4)
meshx, meshy = torch.meshgrid((d, d))
grid = torch.stack((meshy, meshx), 2)
grid = grid.unsqueeze(0)
print(grid.shape)
# torch.Size([1, 4, 4, 2])
print(grid)
# tensor([[[[-1.0000, -1.0000],
#           [-0.3333, -1.0000],
#           [ 0.3333, -1.0000],
#           [ 1.0000, -1.0000]],

#          [[-1.0000, -0.3333],
#           [-0.3333, -0.3333],
#           [ 0.3333, -0.3333],
#           [ 1.0000, -0.3333]],

#          [[-1.0000,  0.3333],
#           [-0.3333,  0.3333],
#           [ 0.3333,  0.3333],
#           [ 1.0000,  0.3333]],

#          [[-1.0000,  1.0000],
#           [-0.3333,  1.0000],
#           [ 0.3333,  1.0000],
#           [ 1.0000,  1.0000]]]])

output1 = torch.nn.functional.grid_sample(in_data, grid)
print(output1)
# tensor([[[[ 0.0000,  0.4167,  1.0833,  0.7500],
#           [ 1.6667,  4.1667,  5.5000,  3.1667],
#           [ 4.3333,  9.5000, 10.8333,  5.8333],
#           [ 3.0000,  6.4167,  7.0833,  3.7500]]]])

output2 = torch.nn.functional.grid_sample(in_data, grid, align_corners=True)
print(output2)
# tensor([[[[ 0.0000,  1.0000,  2.0000,  3.0000],
#           [ 4.0000,  5.0000,  6.0000,  7.0000],
#           [ 8.0000,  9.0000, 10.0000, 11.0000],
#           [12.0000, 13.0000, 14.0000, 15.0000]]]])

grid는 x y 순으로 살핌. 즉, [-1, -0.3]은 x축 맨 왼쪽, y축 두번째 것을 살피므로 원본 in_data의 [1, 0]인 4를 탐색.

그렇다면 grid가 (-1, 1)의 값이 아닌 그 미만/초과 값이 들어가면

tensor([[[[-1.5000, -1.5000],
          [-0.5000, -1.5000],
          [ 0.5000, -1.5000],
          [ 1.5000, -1.5000]],

         [[-1.5000, -0.5000],
          [-0.5000, -0.5000],
          [ 0.5000, -0.5000],
          [ 1.5000, -0.5000]],

         [[-1.5000,  0.5000],
          [-0.5000,  0.5000],
          [ 0.5000,  0.5000],
          [ 1.5000,  0.5000]],

         [[-1.5000,  1.5000],
          [-0.5000,  1.5000],
          [ 0.5000,  1.5000],
          [ 1.5000,  1.5000]]]])
---------------------------------------
tensor([[[[ 0.0000,  0.1875,  0.5625,  0.1875],
          [ 0.7500,  3.7500,  5.2500,  1.5000],
          [ 2.2500,  9.7500, 11.2500,  3.0000],
          [ 0.7500,  3.1875,  3.5625,  0.9375]]]])

[-1.5, -0.5]는 0.75가 나옴. 

 

update.py의 class SepConvGRU를 보면 (5,5) 컨볼루션을 분해해 (1,5)와 (5,1)로 함.

이는 성능의 감소를 일으키지 않으면서 연산량이 감소된다. 그런데 다들 (5,5)를 대신 쓰는 이유는 아래와 같다.

(5,5) 커널 또는 (1,5) 및 (5,1) 커널의 조합을 사용하는 것 중에서 선택하는 것은 특정 작업, 입력 데이터의 특성 및 모델의 전체 아키텍처에 따라 달라집니다.

분리 가능한 컨볼루션(예: (1,5) 및 (5,1) 커널의 조합)을 사용하면 매개 변수와 계산 수를 줄일 수 있지만 모든 작업 또는 모델에 대해 항상 최선의 선택은 아닙니다. 다음은 (5,5) 커널을 대신 사용하는 몇 가지 이유입니다:

1. **다른 기능 추출**: (5,5) 커널과 (1,5) 및 (5,1) 커널의 조합은 정확히 동일한 기능을 추출하지 않습니다. (5,5) 커널은 (1,5) 및 (5,1) 커널의 조합이 수평 및 수직 패턴을 개별적으로 캡처하는 동안 단일 단계에서 2D 공간 패턴을 캡처할 수 있습니다. 작업 및 데이터에 따라 한 가지 접근 방식이 다른 방식보다 더 효과적일 수 있습니다.

2. **모델 단순성**: (5,5) 커널을 사용하면 모델을 더 간단하고 쉽게 이해할 수 있습니다. 분리 가능한 컨볼루션을 사용하면 모델을 더 효율적으로 만들 수 있지만 모델을 더 복잡하고 해석하기 어렵게 만들 수도 있습니다.

3. **구현 세부 사항**: 일부 딥 러닝 프레임워크 및 하드웨어 가속기는 표준 컨볼루션에 최적화되어 있으며 분리 가능한 컨볼루션을 효율적으로 지원하지 않을 수 있습니다.

4. **경험적 결과**: 궁극적으로 모델 아키텍처의 선택은 경험적 결과를 기반으로 하는 경우가 많습니다. 연구자가 (5,5) 커널이 있는 모델이 작업에서 더 나은 성능을 발휘한다는 사실을 알게 되면, 효율성이 떨어지더라도 해당 모델을 사용하기로 선택할 수 있습니다.

결국, (5,5) 커널과 (1,5) 및 (5,1) 커널의 조합 사이의 선택은 효율성과 작업 및 모델의 특정 요구 사항 간의 균형입니다. 그것은 시행착오와 당면한 과제의 구체적인 세부 사항에 기초한 경험적인 결정입니다.

즉, optical flow는 x와 y의 방향을 캡처하는 것이 목표이므로 이게 나을지도 모르지만, object detection은 정사각형의 커널이 더 나을지도 모름.

 

 


torch.meshgrid를 통해 만들어진 centroid_lvl의 모양새(corr.py 참조).

이렇게 (x, y) 좌표를 갖고 있음. 즉, [14260, 1, 1, 2]의 형태는 각각 14260(46 * 62 * batch) 개의 픽셀이 [1, 1, 2] 형태의 좌표를 갖고 있는 것.

 

그리고 이 [1, 1] 부분을 [9, 9]로 확장시킬 것임. 즉, [0, 0]의 좌표가 다음 프레임에 존재할 수 있는 위치좌표를 넣는 것.

논문에서는 하나의 픽셀은 한 프레임 당 최대 4픽셀을 초과해 움직일 수 없다고 봤다. 즉, 이전 프레임의 (0,0) 픽셀이 다음 프레임에서 있을 수 있는 위치는 (-4, -4) 부터 (4, 4)까지 총 9x9개다.

 

따라서 [14260,9, 9, 2]는 하나의 픽셀이 존재할 가능성이 있는 91개의 좌표를 들고 있는 형태라 할 수 있다.