标题:GANS for Sequences of Discrete Elements with the Gumbel-softmax Distribution
作者:Matt J. Kusner, José Miguel Hernández-Lobato
发表:2016
GAN在生成文本等离散元素序列时有着局限性,因为从如多项式这样的离散分布中取样是无法对参数求导的。对此,论文提出了使用Gumbel-softmax分布来对多项式分布作连续的拟合。
Gumbel-softmax分布
用one-hot编码的离散数据,可以从以softmax函数输出作为概率的多项式分布中取样得来。对 \(d\) 维one-hot编码向量 \(\mathbf{y}\),可以用连续的 \(d\) 维向量 \(\mathbf{h}\),通过softmax函数得到表示多项式分布 \(\mathbf{y}\) 概率的向量 \(\mathbf{p}\),其中 \(p_i = p(y_i = 1),i = 1,...,d.\),即: \[\mathbf{p} = softmax(\mathbf{h}) \tag{1}\]
这里softmax(·)使用softmax函数返回一个 \(d\) 维向量: \[ [\operatorname{softmax}(\mathbf{h})]_{i}=\frac{\exp \left(\mathbf{h}_{i}\right)}{\sum_{j=1}^{K} \exp \left(\mathbf{h}_{j}\right)}, \quad \text { for } \quad i=1, \ldots, d \tag{2} \] 可以证明从 \((1)\) 中概率表示的多项式中采样 \(\mathbf{y}\) 等价于从以下公式中取样 \(\mathbf{y}\): \[ \mathbf{y}=\text { one_hot }(\underset{i}{\arg \max }(h_{i}+g_{i})) \tag{3}, \]
其中 \(g_i\) 独立服从 \(\mu = 0, \beta = 1\) 的Gumbel分布。
从 \((3)\) 中取得的样本对 \(\mathbf{h}\) 的梯度为0,因为 \(\text { one_hot }(\underset{i}{\arg \max(·)) }\) 不可微。因此论文提出通过soft-max变换,用一个可微的函数来替换,具体的,使用: \[ \mathbf{y}=softmax(1/\tau(\mathbf{h}+\mathbf{g})) \tag{4}\]
其中 \(\tau\) 是一个inverse temperature参数,当 \(\tau \to 0\) 时,由 \((4)\) 产生的样本和 \((3)\) 有着相同的分布;当 \(\tau \to \infty\) 时,样本是从均匀分布中得来;当 \(\tau\) 取正值时,由 \((4)\) 产生的样本是与 \(\mathbf{h}\) 有关的光滑可导的分布。
\((4)\) 表示的概率分布就叫做Gumbel-softmax分布,包含参数 \(\tau\) 和 \(\mathbf{h}\)。用于生成离散数据的GAN模型可以使用该分布进行训练,初始设定较大的 \(\tau\) 值,在训练过程中逐步减小到0。
LSTM生成模型
使用Gumbel-softmax分布
生成器 \(G\) 和判别器 \(D\) 都是LSTM网络,分别有参数 \(\Theta\) 和 \(\Phi\)。训练目标是通过取样输入 \(\mathbf{x}\) 和生成的数据 \(\mathbf{z}\) 来学习 \(G\) 和 \(D\),最小化 \(G\) 和 \(D\) 的损失函数来更新 \(\Theta\) 和 \(\Phi\)。但是从softmax分布中取样的数据 \(\mathbf{z}\) 是不可导的,也就无法更新参数。在这里使用上文介绍的Gumbel-softmax分布进行替换,就可以使用梯度方法更新参数。