버츄얼유튜버
CLTR 분석
두원공대88학번뚜뚜
2023. 2. 22. 13:59
misc.py에 이런 게 있음.
import torch
from typing import Optional, List
from torch import Tensor
def _max_by_axis(the_list):
# type: (List[List[int]]) -> List[int]
maxes = the_list[0]
for sublist in the_list[1:]:
for index, item in enumerate(sublist):
maxes[index] = max(maxes[index], item)
return maxes
class NestedTensor:
    """Container pairing a tensor with an optional mask.

    Both values are stored exactly as passed in; nothing is validated.
    """

    def __init__(self, tensors, mask: Optional[Tensor]):
        self.tensors = tensors
        self.mask = mask

    def to(self, device):
        # type: (Device) -> NestedTensor # noqa
        """Return a new NestedTensor with ``.to(device)`` applied to its contents.

        The mask is converted only when present; a ``None`` mask stays ``None``.
        """
        moved_tensors = self.tensors.to(device)
        moved_mask = None if self.mask is None else self.mask.to(device)
        return NestedTensor(moved_tensors, moved_mask)

    def decompose(self):
        """Return the stored ``(tensors, mask)`` pair."""
        return (self.tensors, self.mask)

    def __repr__(self):
        # Only the tensor part is shown; the mask is not included.
        return str(self.tensors)
# Demo: pad a batch of images into one tensor and build the matching mask.
# TODO make it support different-sized images
# Four random "images"; iterating the batch yields items of shape [3, 4, 4].
tensor_list = torch.randn(4, 3, 4, 4)
print(tensor_list)

# Position-wise maximum over every image's shape list. All images share one
# shape here, so the result is [3, 4, 4].
shapes = [list(img.shape) for img in tensor_list]
max_size = _max_by_axis(shapes)
print(max_size)

# Prepend the number of images: [4, 3, 4, 4] for this input.
batch_shape = [len(tensor_list)] + max_size
print(batch_shape)
b, c, h, w = batch_shape

# Take dtype and device from the first image so the buffers match the input.
first_img = tensor_list[0]
dtype = first_img.dtype
device = first_img.device

# Zero-filled buffer that receives the images, and a mask that starts all True.
tensor = torch.zeros(batch_shape, dtype=dtype, device=device)
mask = torch.ones((b, h, w), dtype=torch.bool, device=device)

for img, pad_img, m in zip(tensor_list, tensor, mask):
    print(pad_img)
    print(m)
    print('----------------------------')
    channels, height, width = img.shape
    # Copy the image into the top-left corner of its slot in the buffer.
    pad_img[:channels, :height, :width].copy_(img)
    print(pad_img)
    # Positions covered by the image become False; the rest stays True.
    m[:height, :width] = False
    print(m)

print(tensor)
print('---------------- end --------------')

output = NestedTensor(tensor, mask)
print('--------output--------------')
print(output)

# NOTE: this rebinds `b`, which held the batch size above, to the mask.
a, b = output.decompose()
print(a)
print(b)
일단은 달라지는 건 없음.
그냥 인풋 tensor(이미지)와 mask를 NestedTensor라는 사용자 정의 타입으로 묶어 주는 것인 듯.
사용자 정의 타입이고 __len__이 정의되어 있지 않기에, len() 같은 일반적인 것은 활용 불가.
decompose처럼 mask와 tensor를 분리하는 함수나 to()처럼 클래스에 직접 정의된 메서드만 활용 가능
굳이 이렇게 바꿔주는 이유가 분명 있을 테니, 향후 더 코드를 분석해가며 알아갈 예정.