Encoder-decoderモデルとTeacher Forcing,それを拡張したScheduled Sampling,Professor Forcingについて簡単に書きました.

概要

Encoder-decoderモデルは,ソース系列をEncoderと呼ばれるLSTMを用いて固定長のベクトルに変換(Encode)し,Decoderと呼ばれる別のLSTMを用いてターゲット系列に近くなるように系列を生成するモデルです.もちろん,LSTMでなくてGRUでもいいです.機械翻訳のほか,文書要約や対話生成にも使われます.

Encoder-decoderモデルの概略図

Encoder

ソース系列$X = (x_1, x_2, \dots, x_S)$が与えられたとき,ソースの辞書サイズ$V_S$に基づく埋め込み重み$U_S \in \mathbb{R}^{E_S \times V_S}$により,埋め込み表現へと変換され,各タイムスタンプ毎に順にセル($\text{cell} \in \mathbb{R}^{h}$)と隠れ状態($\text{hidden} \in \mathbb{R}^{h}$)が計算される.

$$ \text{cell}_{t+1}, \text{hidden}_{t+1} = \text{LSTM}(U_Sx_t, \text{cell}_{t+1}, \text{hidden}_{t+1}) $$

ここで,$x_t$は$V_S$次元のone-hotベクトル,$h$は隠れ状態の次元,$E_S$はソース側の埋め込み次元を表す.なお,セルと隠れ状態は一般に0で初期化されることが多い. LSTMについて前回の記事を参照ください.

この段階でソース系列は固定長ベクトルであるセルと隠れ状態に変換(Encode)されました.

Decoder

Encoder側で出力されたセルと隠れ状態からDecoderを用いて系列を生成していきます. 初期状態は”BOS”記号や”null”記号を開始記号として下記のような数式で一単語目$w_1$を生成します.

$$ \text{cell}_{1}, \text{hidden}_{1} = \text{LSTM}(U_T w_{\text{null}}, \text{cell}_{0}, \text{hidden}_{0})\\ w_1 = \text{argmax}({W_{T} \text{hidden}_{1}}) $$

ここで,$U_T$はターゲットの辞書サイズ$V_T$に基づく埋め込み重み,$W_T$は$W_T \in \mathbb{R}^{V_T \times h}$である.

この調子で

$$ \text{cell}_{2}, \text{hidden}_{2} = \text{LSTM}(U_T w_1, \text{cell}_{1}, \text{hidden}_{1})\\ w_2 = \text{argmax}({W_{T} \text{hidden}_{2}}) $$

と,$(w_1, w_2, \dots)$を計算して,ターゲット系列$Y = (y_1, y_2, \dots, y_T)$との交差エントロピーを使って学習すると思われます.

前の出力をそのまま次の入力として使って学習を行う例

しかし,この方法だと連鎖的に誤差が大きくなっていき,学習が不安定になったり,収束が遅かったりしてしまうという問題があります.

Teacher Forcing

この問題に対して,Teacher Forcingという方法を取ることが多いですが,この方法にも評価時に問題が生じます(後述). Teacher Forcingとは,訓練時には下図のようにDecoder側の入力にはターゲット系列$Y$をそのまま使うというものです.

Encoder-decoderモデルにおけるTeacher Forcingの概略図

こうすることによって学習が安定し,収束が早くなるというメリットがありますが, 逆に,評価時はDecoderの入力が自動生成されたものが使われるため,学習時と分布が異なってしまうというデメリットもあります.

Teacher Forcingの拡張

Scheduled Sampling

Teacher Forcingの拡張として,Scheduled Sampling [Samy Bengio+ 2015]がある.

Scheduled Sampling

Schedule Samplingは上図のように,ターゲット$y_t$を入力とするか,生成された$w_t$を入力とするか確率的にサンプルするというもの.

Professor Forcing

Professor Forcing[Lamb+ 2016]とは,Free Running(Decoder側で自動生成されたものを入力とする)とTeacher Forcingで出力されるセルと隠れ状態の差を小さくするようにGANを用いたと言うもの.

  • Generator:Free RunningとTeacher Forcingの出力の区別がつかないようにする.
  • Discreminator:Free RunningとTeacher Forcingを正しく分類できるようにする.

Professor Forcing