버츄얼유튜버

CLTR 분석

두원공대88학번뚜뚜 2023. 2. 22. 13:59

misc.py에 이런 게 있음.

import torch
from typing import Optional, List
from torch import Tensor


def _max_by_axis(the_list):
    # type: (List[List[int]]) -> List[int]
    maxes = the_list[0]
    for sublist in the_list[1:]:
        for index, item in enumerate(sublist):
            maxes[index] = max(maxes[index], item)

    return maxes

class NestedTensor(object):
    def __init__(self, tensors, mask: Optional[Tensor]):
        self.tensors = tensors
        self.mask = mask

    def to(self, device):
        # type: (Device) -> NestedTensor # noqa
        cast_tensor = self.tensors.to(device)
        mask = self.mask
        if mask is not None:
            assert mask is not None
            cast_mask = mask.to(device)
        else:
            cast_mask = None
        return NestedTensor(cast_tensor, cast_mask)

    def decompose(self):
        return self.tensors, self.mask

    def __repr__(self):
        return str(self.tensors)

# TODO make it support different-sized images
# [1, 3, 4, 4]
# tensor_list = torch.arange(0, 192, 1)
tensor_list = torch.randn(4, 3, 4, 4)
# tensor_list = tensor_list.view(4, 3, 4, 4)
print(tensor_list)

max_size = _max_by_axis([list(img.shape) for img in tensor_list])   # 각 [3, 256, 256]를 list화 시켜, 각 채널 중 가장 큰 값을 담은 [256, 256] 반환
print(max_size)
# min_size = tuple(min(s) for s in zip(*[img.shape for img in tensor_list]))
batch_shape = [len(tensor_list)] + max_size # batch_shape = [width x height, 3, max(여기선 256), max(여기선 256)]
print(batch_shape)
b, c, h, w = batch_shape
dtype = tensor_list[0].dtype    # torch.float32
device = tensor_list[0].device  # [3, 256, 256].device()
tensor = torch.zeros(batch_shape, dtype=dtype, device=device)   # tensor = torch.zeros([width x height, 3, 512, 512]) and is for pad_img
mask = torch.ones((b, h, w), dtype=torch.bool, device=device)   # mask = [1, 256, 256]--> all is True and is mask

for img, pad_img, m in zip(tensor_list, tensor, mask):
    print(pad_img)
    print(m)
    print('----------------------------')
    pad_img[: img.shape[0], : img.shape[1], : img.shape[2]].copy_(img)  # pad_img = [3, 256, 256]
    print(pad_img)
    m[: img.shape[1], :img.shape[2]] = False    # m = [256, 256] and all is False
    print(m)
    print(tensor)
    print('---------------- end --------------')

output = NestedTensor(tensor, mask)
print('--------output--------------')
print(output)

a, b = output.decompose()
print(a)
print(b)

일단은 달라지는 건 없음.

그냥 NestedTensor라는, 인풋 tensor(이미지)를 작성자 지정 타입으로 만들어주는 것인듯.

작성자 지정 타입이기에, len()같은 일반적인 것은 활용 불가.

decompose 처럼, mask와 tensor을 분리하는 함수만 활용 가능

 

굳이 이렇게 바꿔주는 이유가 분명 있을 테니, 향후 더 코드를 분석해가며 알아갈 예정.