본문 바로가기

컴퓨터공학

파이토치 부분 pretrained 모델 쉽게 만들기

반응형

기존에 학습해두었던 모델을 다른 학습 모델의 부분으로 초기화하려고 한다. 보통 파이토치에서 모델을 구현할 때, 다음과 같은 구성을 가진다. 

class MyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.features = ~ 

    def forward(self, x):
        x = self.features(x)
        return x

 

여기서 내가 직접 학습한 pretrained 모델을 사용하기 위해서 함수를 하나 추가해주려고 한다. 

1) pretrained로 사용할 모델은 다음과 같이 이루어져 있다. 

A - B - C

2) 학습할 모델은 다음과 같이 이루어져 있다.

A - B - D

 

 

나는 1번의 A, B만을 부분적으로 가져와서 모델을 초기화 하려고 한다. initmodel이라는 함수를 추가해준다.

class MyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.features = nn.Conv2d(~)
	self.initmodel()
        
    def initmodel(self):
        pretrained_dict  = torch.load(path)['model_state_dict']
        model_dict = self.state_dict() 
        pretrained_dict = {k[7:]: v for k, v in pretrained_dict.items() if k[7:] in model_dict}
        self.load_state_dict(pretrained_dict, strict=False)

    def forward(self, x):
        x = self.features(x)
        return x

코드는 위와 같다. 하나하나 살펴보자. 4단계로 이루어진다

 

1) pretrained_dict에 학습된 모델 weight를 torch.load를 통해서 불러온다.

2) model_dict에 현재 모델의 state_dict를 넣어주고

3) pretrained_dict와 model_dict는 다른 모델에서 파생되었기 때문에 key값이 다르다. 이걸 key값이 맞는 A, B 부분에만 적용해주기 위해서 pretrained_dict를 model_dict와 key값이 같을 때만을 고려해서 바꾸어준다. 여기서 k[7:]은 pretrain 모델에서 key값에 module.cnn2 이런식으로 module.(7글자) 가 추가가 되어있어서 이걸 빼줘서 model_dict와 상응하게 하였다.

4) 만들어진 pretrained_dict를 strict=False로 예외상황을 제거해주고 load한다.

 

+ load_state_dict 전과 후에 하나의 key를 지정해서 비교해주면 제대로 업데이트가 되었는지 확인할 수 있다.

ex) print(self.state_dict()['features_1.bias'])

반응형