기존에 학습해두었던 모델을 다른 학습 모델의 부분으로 초기화하려고 한다. 보통 파이토치에서 모델을 구현할 때, 다음과 같은 구성을 가진다.
class MyModel(nn.Module):
def __init__(self):
super().__init__()
self.features = ~
def forward(self, x):
x = self.features(x)
return x
여기서 내가 직접 학습한 pretrained 모델을 사용하기 위해서 함수를 하나 추가해주려고 한다.
1) pretrained로 사용할 모델은 다음과 같이 이루어져 있다.
A - B - C
2) 학습할 모델은 다음과 같이 이루어져 있다.
A - B - D
나는 1번의 A, B만을 부분적으로 가져와서 모델을 초기화 하려고 한다. initmodel이라는 함수를 추가해준다.
class MyModel(nn.Module):
def __init__(self):
super().__init__()
self.features = nn.Conv2d(~)
self.initmodel()
def initmodel(self):
pretrained_dict = torch.load(path)['model_state_dict']
model_dict = self.state_dict()
pretrained_dict = {k[7:]: v for k, v in pretrained_dict.items() if k[7:] in model_dict}
self.load_state_dict(pretrained_dict, strict=False)
def forward(self, x):
x = self.features(x)
return x
코드는 위와 같다. 하나하나 살펴보자. 4단계로 이루어진다
1) pretrained_dict에 학습된 모델 weight를 torch.load를 통해서 불러온다.
2) model_dict에 현재 모델의 state_dict를 넣어주고
3) pretrained_dict와 model_dict는 다른 모델에서 파생되었기 때문에 key값이 다르다. 이걸 key값이 맞는 A, B 부분에만 적용해주기 위해서 pretrained_dict를 model_dict와 key값이 같을 때만을 고려해서 바꾸어준다. 여기서 k[7:]은 pretrain 모델에서 key값에 module.cnn2 이런식으로 module.(7글자) 가 추가가 되어있어서 이걸 빼줘서 model_dict와 상응하게 하였다.
4) 만들어진 pretrained_dict를 strict=False로 예외상황을 제거해주고 load한다.
+ load_state_dict 전과 후에 하나의 key를 지정해서 비교해주면 제대로 업데이트가 되었는지 확인할 수 있다.
ex) print(self.state_dict()['features_1.bias'])
'컴퓨터공학' 카테고리의 다른 글
트랜스포머로 시계열 데이터 예측하기 (0) | 2022.12.21 |
---|---|
리눅스에 CMake 설치하기 (0) | 2022.12.20 |
Perceiver IO: Optical Flow (0) | 2022.07.10 |
Perceiver IO: Image Classification (0) | 2022.07.10 |
Perceiver IO: Text Classification (0) | 2022.07.10 |