#! https://zhuanlan.zhihu.com/p/535306685

1025: Denoising Score Matching (DSM)

CATSMILE-1025

Preface

  • Goal:

  • Background and motivation:

    • Score matching is a special kind of probabilistic model based on the Fisher divergence.

    • Score matching (SM) is prerequisite knowledge for DDPM.

  • Conclusions:

    • The variance of the DSM gradient estimate is never smaller than that of the ESM gradient estimate.

    • Learning the corruption noise, as proposed in Kingma2021, may reduce this gap.

  • Remarks:

  • Keywords:

  • Outlook:

  • CHANGELOG

Score-matching models fit the gradient of the log-probability over phase space with an L2 loss, which formally resembles a Fisher divergence. In the literature, ESM is frequently replaced by DSM without explanation, which can easily leave the reader disoriented.

\[\begin{split}\begin{align} L_{ESM}(m) &= -\int p(x) ||s_m(x) - \nabla_x \log p(x)||^2 \,dx \\ &= - E_{p(x)}[||s_m(x) - \nabla_x \log p(x)||^2] \\ L_{DSM}(m) &= -E_{p(x,y)}[||s_m(x) - \nabla_x \log p(x|y)||^2] \end{align} \end{split}\]
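
To make the two objectives concrete, here is a minimal Monte-Carlo sketch on an assumed 1-D Gaussian toy model (not part of the original text): the clean variable is \(y \sim N(0,1)\), the corrupted one is \(x = y + \sigma\varepsilon\), so both scores are available in closed form, and \(s_m(x) = mx\) is a one-parameter score model.

```python
# Minimal Monte-Carlo sketch of the two objectives, following the text's sign
# convention L = -E[||.||^2]. Assumed toy model (not from the text): clean
# y ~ N(0, 1), corrupted x = y + sigma*eps, linear score model s_m(x) = m*x.
import numpy as np

rng = np.random.default_rng(0)
sigma, m, n = 0.5, -0.3, 200_000

y = rng.normal(size=n)                  # y ~ p(y) = N(0, 1)
x = y + sigma * rng.normal(size=n)      # x | y ~ p(x|y) = N(y, sigma^2)

s_model = m * x                               # s_m(x)
score_marginal = -x / (1 + sigma**2)          # grad_x log p(x), p(x) = N(0, 1+sigma^2)
score_conditional = -(x - y) / sigma**2       # grad_x log p(x|y)

L_esm = -np.mean((s_model - score_marginal) ** 2)
L_dsm = -np.mean((s_model - score_conditional) ** 2)
# L_esm - L_dsm ~ 1/sigma^2 - 1/(1+sigma^2), an m-independent constant
print(L_esm, L_dsm)
```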

At least as early as the Denoising Auto Encoder literature, it was discovered that this explicit score matching (ESM) objective can be rewritten into a denoising score matching (DSM) form. Consider an arbitrary differentiable conditional factorization \(p(x,y) = p(x|y)p(y)\); the ESM objective can then be rewritten as

\[\begin{split}\begin{align} L_{ESM}(m) &= - E_{p(x)}[||s_m(x) - \nabla_x \log p(x)||^2] \\ &= - E_{p(x)}[||s_m(x) - \nabla_x \log \sum_y p(x,y)||^2] \\ &= - E_{p(x)}[||s_m(x) - \sum_{y}{p(x,y)\over \sum_y p(x,y)} \nabla_x \log p(x,y)||^2] \\ &= - E_{p(x)}[||s_m(x) - \sum_{y}{p(y|x)} \nabla_x (\log p(x|y)+\log p(y))||^2] \\ &= - E_{p(x)}[||s_m(x) - \sum_{y}{p(y|x)} \nabla_x (\log p(x|y)+0)||^2] \\ &= - E_{p(x)}[||s_m(x)||^2] - E_{p(x)}[||\nabla_x \log p(x)||^2] \\&+ 2E_{p(x)}[s_m(x)^T \sum_{y}{p(y|x)} \nabla_x \log p(x|y)] \\ &= - E_{p(x,y)}[||s_m(x)||^2] - E_{p(x)}[||\nabla_x \log p(x)||^2] \\&+ 2E_{p(x)} [\sum_{y}{p(y|x)}\cdot s_m(x)^T \nabla_x \log p(x|y)] \\ L_{ESM}(m)&= - E_{p(x,y)}[||s_m(x)||^2] - E_{p(x)}[||\nabla_x \log p(x)||^2] \\&+ 2E_{p(x)p(y|x)} [ s_m(x)^T \nabla_x \log p(x|y)] \end{align}\end{split}\]
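
As a numerical sanity check of this rewrite (a sketch on the same assumed toy model), the only m-dependent term involving the score, the cross term, gives the same value whether the marginal or the conditional score is used:

```python
# Sketch checking the key step of the rewrite: the m-dependent cross term can
# use either the marginal or the conditional score,
#   E_{p(x)}[ s_m(x) grad_x log p(x) ] = E_{p(x,y)}[ s_m(x) grad_x log p(x|y) ].
# Same assumed 1-D Gaussian toy model as above.
import numpy as np

rng = np.random.default_rng(4)
sigma, m, n = 0.5, -0.3, 500_000

y = rng.normal(size=n)
x = y + sigma * rng.normal(size=n)

cross_esm = np.mean(m * x * (-x / (1 + sigma**2)))    # marginal score
cross_dsm = np.mean(m * x * (-(x - y) / sigma**2))    # conditional score
print(cross_esm, cross_dsm)   # agree up to Monte-Carlo error
```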

Comparing with the DSM objective (expanded below), only the second term survives the subtraction, and it does not depend on the parameter \(m\), so the two objectives provide identical gradient information for \(m\). This difference can be rewritten as the variance, across y, of the gradient of \(\log p(x|y)\): in other words, when many different y can produce the same x but with very different conditional gradients, the \(s_m(x)\)-independent gap between the two objectives grows.

\[\begin{split}\begin{align} L_{DSM}(m) &= -E_{p(x,y)}[||s_m(x) - \nabla_x \log p(x|y)||^2] \\ &=-E_{p(x,y)}[||s_m(x)||^2] - E_{p(x,y)}[||\nabla_x \log p(x|y)||^2]\\&+ 2 E_{p(x,y)}[s_m(x)^T \nabla_x \log p(x|y)] \\ L_{ESM}(m) - L_{DSM}(m) &= E_{p(x,y)}[||\nabla_x \log p(x|y)||^2] - E_{p(x)}[||\nabla_x \log p(x)||^2] \\ &= E_{p(x,y)}[||\nabla_x \log p(x|y)||^2] - E_{p(x,y)}[||\nabla_x \log p(x)||^2] \\ &= E_{p(x,y)}[||\nabla_x \log p(x|y)||^2 - ||\nabla_x \log \sum_y p(x|y) p(y)||^2] \\ &= E_{p(x,y)}[||\nabla_x \log p(x|y)||^2 - ||\sum_y p(y|x)\nabla_x \log p(x|y)||^2] \\ &= E_{p(x)}[ E_{p(y|x)} [||\nabla_x \log p(x|y)||^2] - ||E_{p(y|x)}[\nabla_x \log p(x|y)]||^2] \\ &= E_{p(x)}[ Var_{p(y|x)} [\nabla_x \log p(x|y)]] \\ &\geq 0 \\ L_{ESM}(m) &\geq L_{DSM}(m) \end{align}\end{split}\]
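
This variance identity can be verified numerically; the sketch below estimates both sides on the assumed toy model, where the posterior \(p(y|x)\) happens to be Gaussian and can be sampled exactly:

```python
# Sketch checking L_ESM - L_DSM = E_x[ Var_{p(y|x)}[ grad_x log p(x|y) ] ]
# on the assumed 1-D Gaussian toy model (the posterior p(y|x) is Gaussian here).
import numpy as np

rng = np.random.default_rng(1)
sigma, m, n, k = 0.5, -0.3, 50_000, 64

y = rng.normal(size=n)
x = y + sigma * rng.normal(size=n)

# left-hand side: difference of the two (negated) objectives at any m
gap = np.mean((m * x + (x - y) / sigma**2) ** 2) \
    - np.mean((m * x + x / (1 + sigma**2)) ** 2)

# right-hand side: E_x[ Var_{p(y|x)}[ -(x - y')/sigma^2 ] ] with k posterior draws
post_mean = x / (1 + sigma**2)
post_std = sigma / np.sqrt(1 + sigma**2)
y_post = post_mean + post_std * rng.normal(size=(k, n))   # y' ~ p(y|x)
evar = np.mean(np.var(-(x - y_post) / sigma**2, axis=0))

print(gap, evar)   # both ~ 1/(sigma^2 (1 + sigma^2)) = 3.2 for sigma = 0.5
```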

From this perspective, although DSM and ESM provide the same gradient in expectation, they may differ in variance. Let us examine the variance of the gradient estimator.

\[\begin{split}\begin{align} {\partial \over \partial m} L_{ESM} &= -{\partial \over \partial m} E_{p(x)}[||s_m(x)-\nabla_x \log p(x)||^2]\\ &= E_{p(y)}[ E_{p(x|y)}[-{\partial \over \partial m}||s_m(x)-\nabla_x \log p(x)||^2]]\\ &= 2E_{p(y)}[ E_{p(x|y)}[(\nabla_x \log p(x) - s_m(x) )^T {\partial \over \partial m} s_m(x)]]\\ {\partial \over \partial m} L_{DSM} &= -{\partial \over \partial m} E_{p(x,y)}[||s_m(x)-\nabla_x \log p(x|y)||^2]\\ &= E_{p(y)}[ E_{p(x|y)}[-{\partial \over \partial m}||s_m(x)-\nabla_x \log p(x|y)||^2]]\\ &= 2E_{p(y)}[ E_{p(x|y)}[(\nabla_x \log p(x|y) - s_m(x) )^T {\partial \over \partial m} s_m(x)]] \\ D(m) &= {\partial \over \partial m} (L_{ESM} - L_{DSM})\\ &= 2E_{p(x)}[ E_{p(y|x)}[ (\nabla_x\log p(x) - \nabla_x \log p(x|y) )^T {\partial \over \partial m} s_m(x)]] \\ &= 2E_{p(x)}[ E_{p(y|x)}[ (\nabla_x\log p(y) - \nabla_x \log p(y|x) )^T {\partial \over \partial m} s_m(x)]] \\ &= 2E_{p(x)}[ E_{p(y|x)}[ - \nabla_x \log p(y|x)^T {\partial \over \partial m} s_m(x)]] \\ &= 2E_{p(x)}[ E_{p(y|x)}[ - \nabla_x \log p(y|x)]^T {\partial \over \partial m} s_m(x)]\\ &= 2E_{p(x)}[ - \int_y \nabla_x^T p(y|x) \,dy \ {\partial \over \partial m} s_m(x)]\\ &= 2E_{p(x)}[ - \left( \nabla_x^T \int_y p(y|x) \,dy \right) {\partial \over \partial m} s_m(x)]\\ &= 2E_{p(x)}[ - \vec {0}^T {\partial \over \partial m} s_m(x)] = 0 \end{align}\end{split}\]
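
The normalization argument in the last steps can be checked directly: sampling y from \(p(y|x)\) at a fixed x, the average of \(\nabla_x \log p(y|x)\) should vanish. A sketch on the same assumed toy model:

```python
# Sketch checking that E_{p(y|x)}[ grad_x log p(y|x) ] = 0 at a fixed x,
# which is the normalization argument that makes D(m) vanish.
# Same assumed 1-D Gaussian toy model as above.
import numpy as np

rng = np.random.default_rng(2)
sigma, k = 0.5, 1_000_000
x = 1.7   # any fixed point; only the normalization of p(y|x) matters

# posterior y | x is Gaussian for this toy model
post_mean = x / (1 + sigma**2)
post_std = sigma / np.sqrt(1 + sigma**2)
y = post_mean + post_std * rng.normal(size=k)   # y ~ p(y|x)

# grad_x log p(y|x) = grad_x log p(x|y) - grad_x log p(x)
grad = -(x - y) / sigma**2 + x / (1 + sigma**2)
print(grad.mean())   # ~ 0 up to Monte-Carlo error
```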

We can see that the gradient difference has zero expectation, guaranteed by the normalization of the conditional probability. Examining the variance of this gradient difference yields a covariance form. The variance is non-negative: the more correlated the conditional-probability gradient \(\nabla_x \log p(y|x)\) is with the model gradient \({\partial \over \partial m} s_m(x)\), the more variance DSM introduces; if the two were orthogonal everywhere, the extra variance would vanish.

\[\begin{split}\begin{align} Dv(m) &= 4E_{p(x)}[Var_{p(y|x)}[ - \nabla_x \log p(y|x)^T {\partial \over \partial m} s_m(x)]] \\ &= 4E_{p(x)}[E_{p(y|x)}[ \left( {\partial \over \partial m} s_m(x)^T \left( \nabla_x \log p(y|x) - E_{p(y|x)}[ \nabla_x \log p(y|x)] \right) \right)^2]]\\ &= 4E_{p(x)}[E_{p(y|x)}[ ||\nabla_x \log p(y|x)^T {\partial \over \partial m} s_m(x) ||^2]]\\ &= 4E_{p(x)}[E_{p(y|x)}[ {\partial \over \partial m} s_m(x) ^T \nabla_x \log p(y|x) \cdot \nabla_x \log p(y|x)^T {\partial \over \partial m} s_m(x) ]]\\ &= 4E_{p(x)}[ {\partial \over \partial m} s_m(x) ^T E_{p(y|x)}[ \nabla_x \log p(y|x) \cdot \nabla_x \log p(y|x)^T ] {\partial \over \partial m} s_m(x) ]\\ &= 4E_{p(x)}[ tr({\partial \over \partial m} s_m(x) {\partial \over \partial m} s_m(x) ^T E_{p(y|x)}[ \nabla_x \log p(y|x) \cdot \nabla_x \log p(y|x)^T ] ) ]\\ &= 4E_{p(x)}[ E_{p(y|x)}[ \sum_{i,j} {\partial \over \partial m} s_m(x)_i {\partial \over \partial m} s_m(x)_j \nabla_{x_i} \log p(y|x) \cdot \nabla_{x_j} \log p(y|x) ] ]\\ &= 4E_{p(x)}[ \sum_{i,j} {\partial \over \partial m} s_m(x)_i {\partial \over \partial m} s_m(x)_j E_{p(y|x)}[ \nabla_{x_i} \log p(y|x) \cdot \nabla_{x_j} \log p(y|x) ] ] \end{align} \end{split}\]

Defining V(x) as the variance of this gradient under \(p(y|x)\), we can see that the variance of the DSM estimator is no smaller than that of the ESM estimator, because \(\nabla_x \log p(y|x)\) cannot vanish everywhere: it points toward values of x that make y more likely. This also means it may be possible to find an estimator with lower variance than DSM.

\[\begin{split}\begin{align} V({\partial \over \partial m} L_{DSM}) &= E_{p(x)}[ Var_{p(y|x)}[{\partial \over \partial m}||s_m(x)-\nabla_x \log p(x|y)||^2]]\\ &= 4 E_{p(x)}[ Var_{p(y|x)}[{\nabla_x (\log p(x) + \log p(y|x))^T {\partial \over \partial m}s_m(x) }]]\\ &= \left( \begin{aligned} &4 E_{p(x)}[ Var_{p(y|x)}[\nabla_x^T \log p(x) {\partial \over \partial m}s_m(x) ] \\ &+ Var_{p(y|x)}[ \nabla_x^T \log p(y|x) {\partial \over \partial m}s_m(x) ] \\ &+2 Cov_{p(y|x)}[\nabla_x^T \log p(x) {\partial \over \partial m}s_m(x) , \nabla_x^T \log p(y|x) {\partial \over \partial m}s_m(x)]] \end{aligned} \right) \\ &= \left( \begin{aligned} & V({\partial\over \partial m} L_{ESM}) + 4E_{p(x)}[E_{p(y|x)}[(\nabla_x^T \log p(y|x) {\partial \over \partial m}s_m(x))^2]]\\ &+8E_{p(x)}[E_{p(y|x)}[\nabla_x^T \log p(x) {\partial \over \partial m}s_m(x) \nabla_x^T \log p(y|x) {\partial \over \partial m}s_m(x)]] \\ \end{aligned} \right) \\ &= \left( \begin{aligned} & V({\partial\over \partial m} L_{ESM}) + 4E_{p(x)}[E_{p(y|x)}[(\nabla_x^T \log p(y|x) {\partial \over \partial m}s_m(x))^2]]\\ &+8E_{p(x)}[\nabla_x^T \log p(x) {\partial \over \partial m}s_m(x) E_{p(y|x)}[\nabla_x^T \log p(y|x)] {\partial \over \partial m}s_m(x)] \\ \end{aligned} \right) \\ &= \left( V({\partial\over \partial m} L_{ESM}) + 4E_{p(x)}[E_{p(y|x)}[(\nabla_x^T \log p(y|x) {\partial \over \partial m}s_m(x))^2]] +0\right) \\ V({\partial \over \partial m} L_{DSM}) &= V({\partial\over \partial m} L_{ESM}) + 4E_{p(x)}[E_{p(y|x)}[(\nabla_x^T \log p(y|x) {\partial \over \partial m}s_m(x))^2]]\\ V({\partial \over \partial m} L_{DSM}) &= 0 + Dv(m)\\ V({\partial \over \partial m} L_{DSM}) &\geq V({\partial\over \partial m} L_{ESM}) =0\\ Dv(m) &= 4E_{p(x,y)}[ ||\nabla_x \log p(y|x)^T {\partial \over \partial m} s_m(x) ||^2]\\ \end{align}\end{split}\]
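
To close the loop, the sketch below (same assumed toy model, with \(s_m(x)=mx\) so that \({\partial \over \partial m}s_m(x)=x\)) measures the y-induced variance of the per-sample DSM gradient and compares it with the \(Dv(m)\) expression; given x, the ESM gradient has no y-dependence, so its corresponding variance is zero:

```python
# Sketch comparing the y-induced variance of the two gradient estimators,
# i.e. V(dL/dm) = E_x[ Var_{p(y|x)}[ per-sample gradient ] ] from the text.
# Same assumed 1-D Gaussian toy model; s_m(x) = m*x, so ds_m/dm = x.
import numpy as np

rng = np.random.default_rng(3)
sigma, m, n, k = 0.5, -0.3, 20_000, 64

y = rng.normal(size=n)
x = y + sigma * rng.normal(size=n)

# posterior samples y' ~ p(y|x), available in closed form for this toy model
post_mean = x / (1 + sigma**2)
post_std = sigma / np.sqrt(1 + sigma**2)
y_post = post_mean + post_std * rng.normal(size=(k, n))

# per-sample gradient of -(s_m(x) - grad_x log p(x|y))^2 with respect to m
g_dsm = 2 * (-(x - y_post) / sigma**2 - m * x) * x
v_dsm = np.mean(np.var(g_dsm, axis=0))   # E_x[Var_y[.]]; the ESM analogue is 0

# closed-form extra-variance term Dv(m) = 4 E[ (grad_x log p(y|x) * ds/dm)^2 ]
score_yx = -(x - y_post) / sigma**2 + x / (1 + sigma**2)
dv = 4 * np.mean((score_yx * x) ** 2)
print(v_dsm, dv)   # the two estimates agree up to Monte-Carlo error
```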

References