본문 바로가기

컴퓨터공학

[Pytorch] RNN 예제

반응형

해당 포스트는 아래 글의 번역 & 요약글입니다. 공부하다가 이해에 도움이 많이 되서 번역해보았습니다. https://www.kaggle.com/code/andradaolteanu/pytorch-rnns-and-lstms-explained-acc-0-99/notebook

 

🔥PyTorch RNNs and LSTMs Explained (Acc 0.99)

Explore and run machine learning code with Kaggle Notebooks | Using data from Digit Recognizer

www.kaggle.com

 

 

RNN with 1 Layer

RNN은 텍스트, 비디오, 음성과 같은 순차적 데이터를 다루는 것에 유용합니다. 젤 간단한 1 layer rnn을 한번 살펴보도록 할게요. 

RNN은 간단하게 다음과 같이 작동합니다.

 

1. 이전의 정보를 반영하여 미래의 값들을 예측합니다. 

2. 중요한 3개 요소를 기억해두세요! Input, Output, Hidden 

3. 순차적으로 input forward를 거치고, 해당 값들을 저장해둡니다.

4. hidden state에 해당 정보가 저장된다고 생각하시면 쉽습니다.

5. 위의 그림에서처럼 (U, V, W)의 파라미터가 있습니다. 

 

RNN with 1 Layer and 1 Neuron

RNN의 뉴런의 개수를 맘대로 설정할 수 있는데요, 가장 간단한 예로 1개의 뉴런을 가지는 네트워크를 봅시다. 2개의 timestep 0과 1이 있습니다. 

원본 링크에 하나하나 구현한 것이 있으니 참고하시면 좋을 것 같습니다. 저는 파이토치에서 제공하는 nn.RNN()을 사용할 거여서 이건 생략하겠습니다. 

 

Vanilla RNN for MNIST Classification

이번에는 Image classification 구현을 해보도록 하겠다. 왜 image classification을 RNN으로 다룰려고 하냐? 에 대해서는 논리적인 이유를 찾기 힘듭니다. CNN 처럼 이미지에 필터를 씌워서 feature을 추출하는것도 아니구요. 어려우니까 그냥 컴퓨터가 해당 input number들을 학습하고 패턴을 이해하는 수학적 방법이라고 이해하고 넘어갑시다. 

 

# The Neural Network
class VanillaRNN_MNIST(nn.Module):
    def __init__(self, batch_size, input_size, hidden_size, output_size):
        super(VanillaRNN_MNIST, self).__init__()
        self.batch_size, self.input_size, self.hidden_size, self.output_size = batch_size, input_size, hidden_size, output_size
        
        # RNN Layer
        self.rnn = nn.RNN(input_size, hidden_size)
        # Fully Connected Layer
        self.layer = nn.Linear(hidden_size, self.output_size)
    
    def forward(self, images, prints=False):        
        images = images.permute(1, 0, 2)
        
        # Initialize hidden state with zeros
        hidden_state = torch.zeros(1, self.batch_size, self.hidden_size)
        
        hidden_outputs, hidden_state = self.rnn(images, hidden_state)
        out = self.layer(hidden_state)
        out = out.view(-1, self.output_size)
        
        return out

 

반응형