함께하는 데이터 분석

[Pytorch] 순환 신경망 모델 학습 본문

데이터분석 공부/ML | DL

[Pytorch] 순환 신경망 모델 학습

JEONGHEON 2022. 9. 20. 20:21

모델 구현

import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np

from tqdm.notebook import tqdm

 

n_hidden = 35 # 순환 신경망의 노드 수
lr = 0.01
epochs = 1000

string = "hello pytorch. how long can a rnn cell remember? show me your limit!"
chars =  "abcdefghijklmnopqrstuvwxyz ?!.,:;01"
char_list = [i for i in chars]
n_letters = len(char_list)

예시에서 사용할 문장은 'hello pytorch. how long can a rnn cell remeber?'라는 문장이고

 

사용할 문자들은 알파벳 소문자와 특숨누자 몇 개로 한정

 

순환 신경망의 노드 수는 n_hidden이라는 변수에 지정

 

def string_to_onehot(string):
    
    start = np.zeros(shape=len(char_list), dtype=int)
    end = np.zeros(shape=len(char_list), dtype=int)
    start[-2] = 1
    end[-1] = 1
    
    for i in string:
        idx = char_list.index(i)
        zero = np.zeros(shape=n_letters, dtype=int)
        zero[idx]=1
        
        start = np.vstack([start,zero])
    
    output = np.vstack([start,end])
    
    return output

위는 문장이 들어왔을 때 이것을 연산 가능한 One-Hot 벡터로 바꾸는 함수를 만든 것

 

어떤 문장이 들어왔을 때 맨 앞에 시작 토큰과 맨 뒤 끝 토큰을 붙이고 One-Hot 벡터로 변환하여 전달하는 함수

 

def onehot_to_word(onehot_1):
    onehot = torch.Tensor.numpy(onehot_1)
    return char_list[onehot.argmax()]

이는 One-Hot 벡터를 다시 문자로 바꾸는 부분도 함수로 만들어 놓은 것

 

토치 텐서를 입력으로 받아서 이를 넘파이 배열로 변환하고 거기서 1인 지점을 인덱스로 잡아 char_list에서 뽑아내는 함수

 

class RNN(nn.Module):
    def __init__(self, input_size, hidden_size, output_size):
        super(RNN, self).__init__()
        
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.output_size = output_size
        
        self.i2h = nn.Linear(input_size, hidden_size)
        self.h2h = nn.Linear(hidden_size, hidden_size)
        self.i2o = nn.Linear(hidden_size, output_size)
        self.act_fn = nn.Tanh()
        
    def forward(self, input, hidden):
        hidden = self.act_fn(self.i2h(input) + self.h2h(hidden))
        output = self.i2o(hidden)
        return output, hidden
    
    def init_hidden(self):
        return torch.zeros(1, self.hidden_size)
    
rnn = RNN(n_letters, n_hidden, n_letters)

이는 순환 신경망 클래스이고 해당 클래스는 One-Hot 벡터로 변환한 단어 하나를 입력값으로 받고 은닉층 하나를 통과시켜 결괏값을 내는 구조를 가지고 있음

 

입력값이 들어오면 이전 시간의 은닉층 값과의 조합으로 새로운 은닉층 값을 생성하고 은닉층에서 결괏값을 내는 부분의 연산을 한 번 더 통과해 결괏값이 나오게 됨

 

그리고 이전 시간의 은닉층 연산값이 없는 초기의 은닉층 값은 0으로 초기화해야 하기 때문에 init_hidden이라는 함수를 만들어 놓음

 

loss_func = nn.MSELoss()
optimizer = torch.optim.Adam(rnn.parameters(), lr=lr)

이제 손실 함수와 최적화 함수를 정의

 

이번에는 MSE를 이용한 L2 손실함수를 사용할 것임

 

 

 

모델 학습

## training
one_hot = torch.from_numpy(string_to_onehot(string)).type_as(torch.FloatTensor())

for i in tqdm(range(epochs)):
    rnn.zero_grad()
    total_loss = 0
    hidden = rnn.init_hidden()

    for j in range(one_hot.size()[0]-1):
        input_ = one_hot[j:j+1,:]
        target = one_hot[j+1]
        
        output, hidden = rnn.forward(input_, hidden)
        loss = loss_func(output.view(-1),target.view(-1))
        total_loss += loss
        input_ = output

    total_loss.backward()
    optimizer.step()
    
    if i % 100 == 0:
        print(f'Epoch {i}\'s loss: {total_loss.item()}')

우선 우리가 학습하고자 했던 문장을 One-Hot 벡터로 변환한 넘파이 배열을 다시 토치 텐서 형태로 바꿔주고 이때 자료형은 연산에 기본적으로 사용되는 torch.FloatTensor로 지정

 

이렇게 하면 앞서 만든 함수대로 start_token + 문장 + end_token 이렇게 구성된 매트릭스가 생성되고, 학습할 때 시작 토큰이 들어오면 결괏값으로 p가 나오고, p가 들어오면 y, y가 들어오면 t가 나오면 됨

 

현재 One-Hot 벡터는 문장에 있는 단어 순서대로 배열되어 있기 때문에 j번째 인덱스에 해당하는 값이 입력으로 들어오면 j+1번째 인덱스에 해당하는 값이 target이 되면 됨

 

문장 전체를 학습하는 과정은 epochs에 지정한 1000만큼 반복하고 이 때 내부적으로 입력값과 목푯값의 차이를 계산하여 문장 전체에 대한 손실을 계산해야 함

 

그런데 문장에 대해 학습할 때 매번 손실을 초기화해야 하기 때문에 total_loss 변수는 0으로 초기화 했고, 또한 학습을 시작하려면 순환 신경망 은닉층의 초깃값을 지정해야 하기 때문에 rnn.init_hidden() 함수를 통해 0으로 초기화함

 

 

 

결과 확인

start = torch.zeros(1,len(char_list))
start[:,-2] = 1

with torch.no_grad():
    hidden = rnn.init_hidden()
    input_ = start
    output_string = ""
    for i in range(len(string)):
        output, hidden = rnn.forward(input_, hidden)
        output_string += onehot_to_word(output.data)
        input_ = output
        
print(output_string)

>>> hello pytoroh. gtce yyooonneeellronmellllobemlllbbbllblbblbblbblbblb

단순한 버전인 만큼 썩 만족스러운 결과는 아닌 결과가 만들어 짐


https://www.hanbit.co.kr/store/books/look.php?p_code=B7818450418 

 

파이토치 첫걸음

딥러닝 구현 복잡도가 증가함에 따라 ‘파이써닉’하고 사용이 편리한 파이토치가 주목받고 있다. 파이토치 코리아 운영진인 저자는 다년간 딥러닝을 공부하고 강의한 경험을 살려 딥러닝의

www.hanbit.co.kr