미연시리뷰

getattr in HNCT

두원공대88학번뚜뚜 2024. 3. 2. 19:40
# This is a simple wrapper function for ConcatDataset
class MyConcatDataset(ConcatDataset):
    def __init__(self, datasets):
        super(MyConcatDataset, self).__init__(datasets)
        self.train = datasets[0].train

    def set_scale(self, idx_scale):
        for d in self.datasets:
            if hasattr(d, 'set_scale'): d.set_scale(idx_scale)

class Data:
    def __init__(self, args):
        self.loader_train = None

        if not args.test_only:
            datasets = []
            data_train = args.data_train
            for d in args.data_train:
                module_name = d if d.find('DIV2K-Q') < 0 else 'DIV2KJPEG'
                m = import_module('data.' + module_name.lower())
                datasets.append(getattr(m, module_name)(args, name=d))

            self.loader_train = dataloader.DataLoader(
                MyConcatDataset(datasets),
                batch_size=args.batch_size,
                shuffle=True,
                pin_memory=not args.cpu,
                num_workers=args.n_threads,
            )

        self.loader_test = []
        for d in args.data_test:
            if d in ['Set5', 'Set14', 'B100', 'Urban100','Manga109']:
                m = import_module('data.benchmark')
                testset = getattr(m, 'Benchmark')(args, train=False, name=d)
            else:
                module_name = d if d.find('DIV2K-Q') < 0 else 'DIV2KJPEG'
                m = import_module('data.' + module_name.lower())
                testset = getattr(m, module_name)(args, train=False, name=d)

            self.loader_test.append(
                dataloader.DataLoader(
                    testset,
                    batch_size=1,
                    shuffle=False,
                    pin_memory=not args.cpu,
                    num_workers=args.n_threads,
                )
            )

 

데이터셋 디렉토리를 찾기 위해 보다 보니 dataloader을 짜는 함수는 src/data/__init__.py에 있었다.

그런데 args.dir_data를 처리하는 코드는 없음. 어째서일까? 해서 보니 getattr에서 처리한다고 한다.

class Benchmark(srdata.SRData):
    def __init__(self, args, name='', train=True, benchmark=True):
        super(Benchmark, self).__init__(
            args, name=name, train=train, benchmark=True
        )

    def _set_filesystem(self, dir_data):
        self.apath = os.path.join(dir_data, 'benchmark', self.name)
        self.dir_hr = os.path.join(self.apath, 'HR')
        if self.input_large:
            self.dir_lr = os.path.join(self.apath, 'LR_bicubicL')
        else:
            self.dir_lr = os.path.join(self.apath, 'LR_bicubic')
        self.ext = ('', '.png')

getattr은 m과 module_name을 사용하며, 이후로 parameter로 args를 받는데....

여기서 m = 패키지 이름, module_name = 패키지 내의 내장클래스임.

즉, 패키지 m = data.benchmark 내에 있는 BenchMark 클래스를 가져오라는 뜻.

 

[Python] 내장함수 getattr()를 활용해 코드 간소화 시키기 : 네이버 블로그 (naver.com)

 

[Python] 내장함수 getattr()를 활용해 코드 간소화 시키기

getattr()의 기능 getattr(object, 'name') 이라는 함수는 object라는 오브젝트 내부의 name이라...

blog.naver.com

 

위의 블로그의 예시:

# 직접 짠 my_models.py를 임포트

import my_models as M

# my_model에 구현된 모델을 주어진 이름에 맞춰 반환
def build_neural_network(model_name):
  if model_name == 'googlenet':
    model = M.googlenet(args)
  elif model_name == 'vgg':
    model = M.vgg(args)
  elif model_name == 'resnet':
    model = M.resnet(args)
  ..
  ..
  return model

########################################################################

import my_models as M
def build_neural_network(model_name):
  return getattr(M, model_name)(args)

 

위아래의 두 코드는 같음.