Search

GRU

대분류
인공지능/데이터
소분류
ML/DL 정리 노트
유형
딥 러닝
부유형
NLP Pytorch
최종 편집 일시
2024/10/27 15:22
생성 일시
2024/10/07 06:05
14 more properties

GRU(Gated Recurrent Unit)

LSTM에서 Cell이 빠진 모델

LSTM

GRU

순차 데이터를 처리하는 데 사용되는 모델
Reset Gate, Update Gate라는 두 개의 게이트를 사용하여 작동
LSTM(Long Short-Term Memory) 네트워크를 개선하기 위한 모델
Cell이 빠짐으로서 hidden layer에서 Cell역할을 같이 하고 있다.

GRU 구조

상태(State)

새로운 기억 상태 (gtg_t)
새로운 기억 상태는 리셋 게이트rtr_t를 이용해 이전 은닉 상태의 일부를 초기화하고, 이를 현재 입력과 결합해 새로운 정보를 만든다.
tanh\tanh함수를 사용해 비선형성을 추가한다.
gt=tanh(Whg(rtht1)+Wxgxt)g_t = \tanh(W_{hg} (r_t \odot h_{t-1}) + W_{xg} x_t)
여기서 \odot는 원소별 곱셈(Element-wise multiplication)을 의미
최종 은닉 상태 (hth_t)
최종 은닉 상태 hth_t는 업데이트 게이트 ztz_t를 이용해 이전 상태 ht1h_{t-1}와 새로운 기억 상태 gtg_t를 조합하여 결정된다.
업데이트 게이트는 새 정보와 이전 정보를 어떤 비율로 유지할지를 조정하는 역할을 한다.
ht=(1zt)gt+ztht1h_t = (1 - z_t) \odot g_t + z_t \odot h_{t-1}

게이트(Gate)

리셋 게이트 (r_t)
리셋 게이트는 이전의 은닉 상태 ht1h_{t-1}와 현재 입력 xtx_t을 얼마나 활용할지를 결정한다.
이 리셋 게이트는 주어진 시점의 입력과 이전 시점의 은닉 상태에 대해 시그모이드 함수(σ\sigma)를 적용하여 계산된다.
rt=σ(Wxrxt+Whrht1)r_t = \sigma(W_{xr} x_t + W_{hr} h_{t-1})
업데이트 게이트 (ztz_t)
업데이트 게이트는 새로 계산된 정보와 이전의 은닉 상태를 어떻게 조합할지를 결정한다.
이 게이트 역시 입력 xtx_t와 이전 은닉 상태 ht1h_{t-1}를 이용해서 계산되며, 시그모이드 함수를 사용해 ztz_t를 얻는다.
zt=σ(Wxzxt+Whzht1)z_t=σ(W_{xz}x_t+W_{hz}h_{t−1})
리셋 게이트는 이전 은닉 상태의 일부를 리셋해 특정 정보만 남기도록 하며,
업데이트 게이트는 이전 은닉 상태와 새로 계산된 상태 간의 균형을 맞춘다.

GRU 동작 과정

1. 업데이트 게이트 계산

업데이트 게이트는 현재 입력과 이전 숨겨진 상태를 기반으로 계산된다.
이 게이트는 이전 상태를 얼마나 유지할지를 결정한다.
zt=σ(Wzxt+Uzht1)z_t = \sigma(W_z \cdot x_t + U_z \cdot h_{t-1})
여기서 σ\sigma는 시그모이드 함수, WzW_zUzU_z는 학습 가능한 가중치 행렬, xtx_t는 현재 입력, ht1h_{t-1}는 이전 숨겨진 상태이다.

2. 리셋 게이트 계산

리셋 게이트는 현재 입력과 이전 숨겨진 상태를 기반으로 계산된다.
이 게이트는 이전 상태를 얼마나 잊을지를 결정한다.
rt=σ(Wrxt+Urht1)r_t = \sigma(W_r \cdot x_t + U_r \cdot h_{t-1})
여기서 WrW_rUrU_r는 학습 가능한 가중치 행렬이다.

3. 후보 숨겨진 상태 계산

리셋 게이트를 사용하여 이전 숨겨진 상태를 조정한 후, 후보 숨겨진 상태를 계산한다.
h~t=tanh(Wxt+U(rtht1))\tilde{h}_t = \tanh(W \cdot x_t + U \cdot (r_t \odot h_{t-1}))
여기서 \odot는 요소별 곱(element-wise multiplication), WWUU는 학습 가능한 가중치 행렬이다.

4. 최종 숨겨진 상태 계산

업데이트 게이트와 후보 숨겨진 상태를 결합하여 최종 숨겨진 상태를 업데이트한다.
ht=(1zt)ht1+zth~th_t = (1 - z_t) \odot h_{t-1} + z_t \odot \tilde{h}_t
이 식을 통해 GRU는 이전 상태를 유지할지, 새로운 정보를 반영할지를 결정하게 된다.