Preface
Goal:
Background and motivation:
Conclusion:
Notes:
Keywords:
Future directions:
CHANGELOG
Score-matching (gradient-alignment) models fit the gradient of the log-probability over the sample space with an L2 loss, which formally resembles a Fisher divergence. In the literature, the ESM objective is frequently replaced by the DSM objective without comment, which can easily leave the reader lost.
\[\begin{split}\begin{align}
L_{ESM}(m) &= -\int p(x)\, ||s_m(x) - \nabla_x \log p(x)||^2 \,dx \\
&= - E_{p(x)}[||s_m(x) - \nabla_x \log p(x)||^2] \\
L_{DSM}(m) &= -E_{p(x,y)}[||s_m(x) - \nabla_x \log p(x|y)||^2] \end{align}
\end{split}\]
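To make the two objectives concrete, here is a minimal NumPy sketch under an assumed toy model (my illustration, not from the text): \(y \in \{0,1\}\) with \(x|y \sim N(\mu_y, \sigma^2)\), so that both the marginal score \(\nabla_x \log p(x)\) and the conditional score \(\nabla_x \log p(x|y)\) are available in closed form. The code estimates the two expected squared errors by Monte Carlo (the losses above are their negatives); note that the DSM estimate never touches the marginal score.

```python
import numpy as np

rng = np.random.default_rng(0)

# Assumed toy joint: y ~ Bernoulli(1/2), x | y ~ N(mu[y], sigma^2)
mu, sigma, py = np.array([-2.0, 2.0]), 1.0, np.array([0.5, 0.5])

def cond_score(x, y):
    # ∇_x log p(x|y) for a Gaussian: (mu_y - x) / sigma^2
    return (mu[y] - x) / sigma**2

def marg_score(x):
    # ∇_x log p(x) = Σ_y p(y|x) ∇_x log p(x|y)
    logw = np.log(py) - 0.5 * ((x[:, None] - mu) / sigma) ** 2
    w = np.exp(logw - logw.max(axis=1, keepdims=True))
    w /= w.sum(axis=1, keepdims=True)            # posterior p(y|x)
    return (w * (mu - x[:, None]) / sigma**2).sum(axis=1)

def s_model(x, m):
    # A deliberately crude model score: that of N(m, 1)
    return m - x

y = rng.integers(0, 2, size=200_000)             # y ~ p(y)
x = rng.normal(mu[y], sigma)                     # x ~ p(x|y)
m = 0.3
esm = np.mean((s_model(x, m) - marg_score(x)) ** 2)     # needs ∇ log p(x)
dsm = np.mean((s_model(x, m) - cond_score(x, y)) ** 2)  # needs only ∇ log p(x|y)
print(f"E||s-∇log p(x)||^2 = {esm:.4f}, E||s-∇log p(x|y)||^2 = {dsm:.4f}")
```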
At least since the work on denoising autoencoders, it has been known that this explicit score matching (ESM)
objective can be rewritten in a denoising score matching (DSM) form. Consider any differentiable conditional factorization \(p(x,y) = p(x|y)p(y)\); the ESM objective can then be rewritten:
\[\begin{split}\begin{align}
L_{ESM}(m) &= - E_{p(x)}[||s_m(x) - \nabla_x \log p(x)||^2] \\
&= - E_{p(x)}[||s_m(x) - \nabla_x \log \sum_y p(x,y)||^2] \\
&= - E_{p(x)}[||s_m(x) - \sum_{y}{p(x,y)\over \sum_{y'} p(x,y')} \nabla_x \log p(x,y)||^2] \\
&= - E_{p(x)}[||s_m(x) - \sum_{y}{p(y|x)} \nabla_x (\log p(x|y)+\log p(y))||^2] \\
&= - E_{p(x)}[||s_m(x) - \sum_{y}{p(y|x)} \nabla_x (\log p(x|y)+0)||^2] \\
&= - E_{p(x)}[||s_m(x)||^2] - E_{p(x)}[||\nabla_x \log p(x)||^2]
\\&+ 2E_{p(x)}[s_m(x)^T \sum_{y}{p(y|x)} \nabla_x \log p(x|y)] \\
&= - E_{p(x,y)}[||s_m(x)||^2] - E_{p(x)}[||\nabla_x \log p(x)||^2]
\\&+ 2E_{p(x)} [\sum_{y}{p(y|x)}\cdot s_m(x)^T \nabla_x \log p(x|y)] \\
L_{ESM}(m)&= - E_{p(x,y)}[||s_m(x)||^2] - E_{p(x)}[||\nabla_x \log p(x)||^2]
\\&+ 2E_{p(x)p(y|x)} [ s_m(x)^T \nabla_x \log p(x|y)]
\end{align}\end{split}\]
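The pivot of this rewrite is the identity \(\nabla_x \log p(x) = \sum_y p(y|x)\, \nabla_x \log p(x|y)\) used in the third through fifth lines. A quick finite-difference check on the same kind of toy mixture (again an illustrative assumption, redefined here so the snippet runs standalone):

```python
import numpy as np

mu, sigma, py = np.array([-2.0, 2.0]), 1.0, np.array([0.5, 0.5])

def log_marginal(x):
    # log p(x) = log Σ_y p(y) N(x; mu_y, sigma^2), computed stably
    lp = np.log(py) - 0.5 * ((x[:, None] - mu) / sigma) ** 2 \
         - 0.5 * np.log(2 * np.pi * sigma**2)
    c = lp.max(axis=1, keepdims=True)
    return (c + np.log(np.exp(lp - c).sum(axis=1, keepdims=True))).ravel()

xs = np.linspace(-4.0, 4.0, 9)

# Right-hand side: Σ_y p(y|x) ∇_x log p(x|y)
logw = np.log(py) - 0.5 * ((xs[:, None] - mu) / sigma) ** 2
w = np.exp(logw - logw.max(axis=1, keepdims=True))
w /= w.sum(axis=1, keepdims=True)                # posterior p(y|x)
rhs = (w * (mu - xs[:, None]) / sigma**2).sum(axis=1)

# Left-hand side: ∇_x log p(x) by central finite differences
h = 1e-5
lhs = (log_marginal(xs + h) - log_marginal(xs - h)) / (2 * h)
print(np.max(np.abs(lhs - rhs)))   # ≈ 0 up to finite-difference error
```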
Comparing with the DSM objective, only the second term survives when we take the difference, and it does not depend on the parameter \(m\); the two objectives therefore provide \(m\) with identical gradient information. This difference can be transformed into the variance, over \(y\), of the gradient of \(\log p(x|y)\): if many different \(y\) can give rise to the same \(x\) but the gradients of the conditional probability differ widely, the \(s_m(x)\)-independent gap between the two objectives becomes large.
\[\begin{split}\begin{align}
L_{DSM}(m) &= -E_{p(x,y)}[||s_m(x) - \nabla_x \log p(x|y)||^2] \\
&=-E_{p(x,y)}[||s_m(x)||^2] - E_{p(x,y)}[||\nabla_x \log p(x|y)||^2]\\&+ 2 E_{p(x,y)}[s_m(x)^T \nabla_x \log p(x|y)]
\\
L_{ESM}(m) - L_{DSM}(m) &= E_{p(x,y)}[||\nabla_x \log p(x|y)||^2] - E_{p(x)}[||\nabla_x \log p(x)||^2] \\
&= E_{p(x,y)}[||\nabla_x \log p(x|y)||^2] - E_{p(x,y)}[||\nabla_x \log p(x)||^2] \\
&= E_{p(x,y)}[||\nabla_x \log p(x|y)||^2 - ||\nabla_x \log \sum_{y'} p(x|y') p(y')||^2] \\
&= E_{p(x,y)}[||\nabla_x \log p(x|y)||^2 - ||\sum_{y'} p(y'|x)\nabla_x \log p(x|y')||^2] \\
&= E_{p(x)}\left[ E_{p(y|x)} [||\nabla_x \log p(x|y)||^2] - ||E_{p(y|x)}[\nabla_x \log p(x|y)]||^2 \right] \\
&= E_{p(x)}[ Var_{p(y|x)} [\nabla_x \log p(x|y)]] \\
&\geq 0 \\
L_{ESM}(m) &\geq L_{DSM}(m)
\end{align}\end{split}\]
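A sketch checking this gap formula numerically on the toy mixture: the Monte Carlo estimate of \(L_{ESM}(m) - L_{DSM}(m)\) (with the sign convention above, this is `dsm - esm` in positive-expectation terms) should match \(E_{p(x)}[Var_{p(y|x)}[\nabla_x \log p(x|y)]]\) for any model score, here an arbitrary fixed \(s_m(x) = 0.3 - x\). All of this is my illustrative setup, not the author's.

```python
import numpy as np

rng = np.random.default_rng(1)
mu, sigma, py = np.array([-2.0, 2.0]), 1.0, np.array([0.5, 0.5])

y = rng.integers(0, 2, size=500_000)        # y ~ p(y)
x = rng.normal(mu[y], sigma)                # x ~ p(x|y)

logw = np.log(py) - 0.5 * ((x[:, None] - mu) / sigma) ** 2
w = np.exp(logw - logw.max(axis=1, keepdims=True))
w /= w.sum(axis=1, keepdims=True)           # posterior p(y|x)

g = (mu - x[:, None]) / sigma**2            # ∇_x log p(x|y), for both y
gbar = (w * g).sum(axis=1)                  # ∇_x log p(x)
s = 0.3 - x                                 # an arbitrary fixed model score

esm = np.mean((s - gbar) ** 2)
dsm = np.mean((s - g[np.arange(len(x)), y]) ** 2)
gap_var = np.mean((w * (g - gbar[:, None]) ** 2).sum(axis=1))  # E_x Var_{y|x}
print(dsm - esm, gap_var)                   # the two agree up to MC noise
```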
From this perspective, although DSM and ESM provide the same gradient in expectation, their gradient estimators may differ in variance. Let us examine the variance of the gradient estimate.
\[\begin{split}\begin{align}
{\partial \over \partial m} L_{ESM} &= -{\partial \over \partial m} E_{p(x)}[||s_m(x)-\nabla_x \log p(x)||^2]\\
&= E_{p(y)}[ E_{p(x|y)}[-{\partial \over \partial m}||s_m(x)-\nabla_x \log p(x)||^2]]\\
&= 2E_{p(y)}[ E_{p(x|y)}[(\nabla_x \log p(x) - s_m(x) )^T {\partial \over \partial m} s_m(x)]]\\
{\partial \over \partial m} L_{DSM} &= -{\partial \over \partial m} E_{p(x,y)}[||s_m(x)-\nabla_x \log p(x|y)||^2]\\
&= E_{p(y)}[ E_{p(x|y)}[-{\partial \over \partial m}||s_m(x)-\nabla_x \log p(x|y)||^2]]\\
&= 2E_{p(y)}[ E_{p(x|y)}[(\nabla_x \log p(x|y) - s_m(x) )^T {\partial \over \partial m} s_m(x)]] \\
D(m) &= {\partial \over \partial m} (L_{ESM} - L_{DSM})\\
&= 2E_{p(x)}[ E_{p(y|x)}[ (\nabla_x\log p(x) - \nabla_x \log p(x|y) )^T {\partial \over \partial m} s_m(x)]] \\
&= 2E_{p(x)}[ E_{p(y|x)}[ (\nabla_x\log p(y) - \nabla_x \log p(y|x) )^T {\partial \over \partial m} s_m(x)]] \\
&= 2E_{p(x)}[ E_{p(y|x)}[ - \nabla_x \log p(y|x)^T {\partial \over \partial m} s_m(x)]] \\
&= 2E_{p(x)}[ E_{p(y|x)}[ - \nabla_x \log p(y|x)]^T {\partial \over \partial m} s_m(x)]\\
&= 2E_{p(x)}[ - \int \nabla_x^T p(y|x) \,dy\ {\partial \over \partial m} s_m(x)]\\
&= 2E_{p(x)}[ - \nabla_x^T \Big( \int p(y|x) \,dy \Big) {\partial \over \partial m} s_m(x)]\\
&= 2E_{p(x)}[ - \vec {0}^{\,T} {\partial \over \partial m} s_m(x)] = 0
\end{align}\end{split}\]
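The chain ends on the normalization fact \(\int p(y|x)\,dy = 1\), i.e. the conditional score \(\nabla_x \log p(y|x)\) has zero mean under \(p(y|x)\). A direct finite-difference check of this zero-mean property on the toy mixture (illustrative setup as before):

```python
import numpy as np

mu, sigma, py = np.array([-2.0, 2.0]), 1.0, np.array([0.5, 0.5])

def log_posterior(x):
    # log p(y|x) for both y, via Bayes' rule on the toy mixture
    lp = np.log(py) - 0.5 * ((x[:, None] - mu) / sigma) ** 2
    c = lp.max(axis=1, keepdims=True)
    return lp - (c + np.log(np.exp(lp - c).sum(axis=1, keepdims=True)))

xs = np.linspace(-4.0, 4.0, 9)
h = 1e-5
# ∇_x log p(y|x) for each y, by central finite differences
grad_log_post = (log_posterior(xs + h) - log_posterior(xs - h)) / (2 * h)
w = np.exp(log_posterior(xs))                         # p(y|x)
print(np.abs((w * grad_log_post).sum(axis=1)).max())  # ≈ 0 for every x
```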
We can see that the expected gradient difference is zero, and that this is guaranteed purely by the normalization of the conditional probability; in particular \(E_{p(y|x)}[\nabla_x \log p(y|x)] = 0\), so below the variance equals the raw second moment. Examining the variance of the gradient difference yields a covariance-like form. This variance is non-negative: the stronger the correlation between the conditional score \(\nabla_x \log p(y|x)\) and the model sensitivity \({\partial \over \partial m} s_m(x)\), the more variance DSM introduces; if the two are orthogonal everywhere, the extra variance vanishes.
\[\begin{split}\begin{align}
Dv(m) &= 4E_{p(x)}[Var_{p(y|x)}[ - \nabla_x \log p(y|x)^T {\partial \over \partial m} s_m(x)]] \\
&= 4E_{p(x)}[E_{p(y|x)}[ ||- \nabla_x \log p(y|x)^T {\partial \over \partial m} s_m(x) ||^2]]\\
&= 4E_{p(x)}[E_{p(y|x)}[ \left( {\partial \over \partial m} s_m(x)^T \left( \nabla_x \log p(y|x) - E_{p(y|x)}[ \nabla_x \log p(y|x)] \right) \right)^2]]\\
&= 4E_{p(x)}[E_{p(y|x)}[ {\partial \over \partial m} s_m(x) ^T \nabla_x \log p(y|x) \cdot \nabla_x \log p(y|x)^T {\partial \over \partial m} s_m(x) ]]\\
&= 4E_{p(x)}[ {\partial \over \partial m} s_m(x) ^T E_{p(y|x)}[ \nabla_x \log p(y|x) \cdot \nabla_x \log p(y|x)^T ] {\partial \over \partial m} s_m(x) ]\\
&= 4E_{p(x)}[ tr({\partial \over \partial m} s_m(x) {\partial \over \partial m} s_m(x) ^T E_{p(y|x)}[ \nabla_x \log p(y|x) \cdot \nabla_x \log p(y|x)^T ] ) ]\\
&= 4E_{p(x)}[
E_{p(y|x)}[
\sum_{i,j}
{\partial \over \partial m} s_m(x)_i
{\partial \over \partial m} s_m(x)_j
\nabla_{x_i} \log p(y|x) \cdot \nabla_{x_j} \log p(y|x) ] ]\\
&= 4E_{p(x)}[
\sum_{i,j}
{\partial \over \partial m} s_m(x)_i
{\partial \over \partial m} s_m(x)_j
E_{p(y|x)}[
\nabla_{x_i} \log p(y|x) \cdot \nabla_{x_j} \log p(y|x) ] ]
\end{align}
\end{split}\]
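As a sanity check of the final expression, the sketch below evaluates \(Dv(m)\) by Monte Carlo for the toy mixture with the scalar-parameter model \(s_m(x) = m - x\), for which \({\partial \over \partial m} s_m(x) \equiv 1\) and the quadratic form collapses to \(4E_{p(x)}[E_{p(y|x)}[(\nabla_x \log p(y|x))^2]]\) (toy setup assumed, as above):

```python
import numpy as np

rng = np.random.default_rng(2)
mu, sigma, py = np.array([-2.0, 2.0]), 1.0, np.array([0.5, 0.5])

y = rng.integers(0, 2, size=500_000)
x = rng.normal(mu[y], sigma)

logw = np.log(py) - 0.5 * ((x[:, None] - mu) / sigma) ** 2
w = np.exp(logw - logw.max(axis=1, keepdims=True))
w /= w.sum(axis=1, keepdims=True)                    # p(y|x)

g = (mu - x[:, None]) / sigma**2                     # ∇_x log p(x|y)
u = g - (w * g).sum(axis=1, keepdims=True)           # ∇_x log p(y|x), via Bayes

# For s_m(x) = m - x we have ∂s_m/∂m ≡ 1, so the quadratic form collapses:
Dv = 4 * np.mean((w * u**2).sum(axis=1))
print(Dv)  # strictly positive: the extra variance DSM pays at this s_m
```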
Writing \(V(x)\) for the variance of the conditional score under \(p(y|x)\), we see that the variance of DSM is no smaller than that of ESM: \(\nabla_x \log p(y|x)\) cannot vanish everywhere, since it points toward those \(x\) that make \(y\) more likely. This also suggests that estimators with lower variance than DSM may exist. In the derivation below, note that conditional on \(x\) the ESM integrand has no \(y\)-dependence, so \(V({\partial \over \partial m} L_{ESM}) = 0\), and the cross term vanishes because \(E_{p(y|x)}[\nabla_x \log p(y|x)] = 0\).
\[\begin{split}\begin{align}
V({\partial \over \partial m} L_{DSM})
&= E_{p(x)}[ Var_{p(y|x)}[{\partial \over \partial m}||s_m(x)-\nabla_x \log p(x|y)||^2]]\\
&= 4 E_{p(x)}[ Var_{p(y|x)}[{\nabla_x (\log p(x) + \log p(y|x))^T {\partial \over \partial m}s_m(x) }]]\\
&= \left( \begin{aligned}
&4 E_{p(x)}[ Var_{p(y|x)}[\nabla_x^T \log p(x)
{\partial \over \partial m}s_m(x) ]
\\
&+ Var_{p(y|x)}[ \nabla_x^T \log p(y|x) {\partial \over \partial m}s_m(x) ]
\\
&+2 Cov_{p(y|x)}[\nabla_x^T \log p(x)
{\partial \over \partial m}s_m(x) , \nabla_x^T \log p(y|x) {\partial \over \partial m}s_m(x)]]
\end{aligned} \right)
\\
&= \left( \begin{aligned} &
V({\partial\over \partial m} L_{ESM}) + 4E_{p(x)}[E_{p(y|x)}[(\nabla_x^T \log p(y|x) {\partial \over \partial m}s_m(x))^2]]\\
&+8E_{p(x)}[E_{p(y|x)}[\nabla_x^T \log p(x)
{\partial \over \partial m}s_m(x) \nabla_x^T \log p(y|x) {\partial \over \partial m}s_m(x)]] \\
\end{aligned} \right) \\
&= \left( \begin{aligned} &
V({\partial\over \partial m} L_{ESM}) + 4E_{p(x)}[E_{p(y|x)}[(\nabla_x^T \log p(y|x) {\partial \over \partial m}s_m(x))^2]]\\
&+8E_{p(x)}[\nabla_x^T \log p(x)
{\partial \over \partial m}s_m(x) E_{p(y|x)}[\nabla_x^T \log p(y|x)] {\partial \over \partial m}s_m(x)] \\
\end{aligned} \right) \\
&= \left( V({\partial\over \partial m} L_{ESM}) + 4E_{p(x)}[E_{p(y|x)}[(\nabla_x^T \log p(y|x) {\partial \over \partial m}s_m(x))^2]] +0\right) \\
V({\partial \over \partial m} L_{DSM}) &= V({\partial\over \partial m} L_{ESM}) + 4E_{p(x)}[E_{p(y|x)}[(\nabla_x^T \log p(y|x) {\partial \over \partial m}s_m(x))^2]]\\
V({\partial \over \partial m} L_{DSM}) &= 0 + Dv(m)\\
V({\partial \over \partial m} L_{DSM}) &\geq V({\partial\over \partial m} L_{ESM}) =0\\
Dv(m) &= 4E_{p(x,y)}[ ||\nabla_x \log p(y|x)^T {\partial \over \partial m} s_m(x) ||^2]\\
\end{align}\end{split}\]
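To close the loop, a sketch comparing the empirical variances of the two per-sample gradient estimators on the toy mixture (again my illustrative setup). The x-sampling noise is common to both estimators, so by the law of total variance the difference of the empirical variances isolates the extra y-randomness and should match \(Dv(m)\):

```python
import numpy as np

rng = np.random.default_rng(3)
mu, sigma, py = np.array([-2.0, 2.0]), 1.0, np.array([0.5, 0.5])
m = 0.3

y = rng.integers(0, 2, size=1_000_000)
x = rng.normal(mu[y], sigma)

logw = np.log(py) - 0.5 * ((x[:, None] - mu) / sigma) ** 2
w = np.exp(logw - logw.max(axis=1, keepdims=True))
w /= w.sum(axis=1, keepdims=True)                 # p(y|x)

g = (mu - x[:, None]) / sigma**2                  # ∇_x log p(x|y)
gbar = (w * g).sum(axis=1)                        # ∇_x log p(x)
s = m - x                                         # model score, ∂s_m/∂m ≡ 1

grad_esm = 2 * (gbar - s)                         # per-sample ESM gradient in m
grad_dsm = 2 * (g[np.arange(len(x)), y] - s)      # per-sample DSM gradient in m
Dv = 4 * np.mean((w * (g - gbar[:, None]) ** 2).sum(axis=1))
print(grad_esm.mean(), grad_dsm.mean())           # identical signal (up to noise)
print(grad_dsm.var() - grad_esm.var(), Dv)        # variance gap matches Dv(m)
```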