#! https://zhuanlan.zhihu.com/p/523008403

1011: 噪音条件分数网络 noise conditional score network (NCSE)

CATSMILE-1011

前言

  • 目标: 理解score-based相关的操作,为理解DDPM的生成打下基础

  • 模型效果: 实现对于某一数据集所约束的分布上的采样

  • 总结:

    • NCSE is secretly a GMM

    • 基于分数的生成模型,可以认为是,对数据自然规定的非参数密度的梯度(gradient of a nonparametric distribution),做了一个 参数化估计(神经网络).

    • 这个构造从数学上来看非常优雅和巧妙,对数据分布进行了非常直接的拟合,

    • 可以用来复习用降噪思想采样的数学框架 而且巧妙地把噪声和采样联系在了一起.

  • 参考: Song and Ermon1

  • 优点: 模型的构造性和原则性都很强 (a highly principled approach to model the data distribution),通过引入噪声分布,对黑盒部分进行了 进一步的限制,让模型更容易被理解

  • 缺点: 慢

  • 后续: 看看DDPM等模型具体做了啥优化.

扰动分布: perturbed distribution

作者考虑的perturbed distribution甚为优雅,对于方差为 \(\sigma\) 的噪音,在抽象测度上, 随机的高斯噪音等价于热力学核下的扩散,于是对于扰动后的分布 \(\bar{x}_\sigma\)

\[\begin{split} \begin{align} x &\sim p_{data}(\cdot) \\ \bar{x}_{\sigma} &\sim x + N(0,\sigma^2 I) \\ p(\bar{x}_{\sigma}) &= \sum_{x} p(\bar{x}_{\sigma}| x) p_{data}(x) \\ &= \sum_{x} N(\bar{x} - x|0,\sigma^2I) p_{data}(x) \\ &= \int N(\bar{x} - x|0,\sigma^2I) p_{data}(x) dx \end{align} \end{split}\]

只要 \(\int p_{data}(x) dx = 1\) ,那么由于求边际分布的性质,可以导出 \(\int p(\bar{x}_\sigma) d \bar{x}_\sigma=1\)

损失函数: Fisher Divergence between guide function and guide estimator

从Data采样出corrupted sample的操作是比较好理解的,但是score函数的计算在论文各处并不是很清晰,常常用 \(\nabla_x \log p(x)\) 加以指代,从构造上讲,是一个把x引导到高分区域的一个有方向的函数. (个人认为翻译为score function不是很恰当, 因为score暗含了scalar标量的感觉,用flow estimator可能会好一些,或者guide function,下文中直接以guide function). 作者使用了L2损失函数如下, 形式上是一个Fisher散度

\[ E_{p(x)}[||\nabla_x \log p(x) - s_\theta(x) ||^2] = \int p(x) ||\nabla_x \log p(x) - s_\theta(x) ||^2 dx \]

由于没从论文里看出guide function的具体计算方式,根据形式猜测也是用采样的方式进行的估算,因此直接调查了源码.

def dsm_score_estimation(scorenet, samples, sigma=0.01):
    perturbed_samples = samples + torch.randn_like(samples) * sigma
    target = - 1 / (sigma ** 2) * (perturbed_samples - samples)
    scores = scorenet(perturbed_samples)
    target = target.view(target.shape[0], -1)
    scores = scores.view(scores.shape[0], -1)
    loss = 1 / 2. * ((scores - target) ** 2).sum(dim=-1).mean(dim=0)

    return loss

可以看出,实际计算的是

\[\begin{split} \begin{align} loss = {1 \over 2} E_{p(\bar{x})} [||s_\theta(\bar{x}) - {1\over \sigma^2}(x - \bar{x} ) ||^2] \\ \nabla_{\bar{x}} \log p(\bar{x}) \approx {1\over \sigma^2}(x - \bar{x} ) \end{align} \end{split}\]

也就是说只考虑了在热力学核上求导得出的梯度,具体在符号推导上是怎么构成的等价性, 需要进一步做一些推导澄清. 从直觉上来看,这个梯度的估计,在 \(p_{data}(x)\) 较为 稀疏时是成立的,因为在任意位置 \(\bar{x}\) 都可以认为是单峰的高斯分布. 在高斯组分之间有一些重合时,不同高斯给出的梯度之间,也通过期望算子得到了聚合, 因此这个估计可能是exact的.通过这种直觉,应该能给出一个具体的求导方式.

尝试直接对扰动分布的对数求导 \(\log p(\bar{x}_{\sigma})\), 由于存在 logsumexp 的混合形式

\[\begin{split} \begin{align} x &\sim p_{data}(\cdot) \\ \bar{x}_{\sigma} & \sim x + N(0,\sigma^2 I) \\ p(\bar{x}|x) &= e^ { g(\bar{x}|x) }\\ \nabla_{\bar{x}} \log p(\bar{x}_{\sigma}) &= \nabla_{\bar{x}} \log(\sum_{x} p(\bar{x}_{\sigma}| x) p_{data}(x) ) \\ &= {\sum_{x}\nabla_{\bar{x}} [p(\bar{x}|x) p_{data}(x)] \over {\sum_{x} p(\bar{x}|x) p_{data}(x)} }\\ &= {\sum_{x} p_{data}(x)\nabla_{\bar{x}} p(\bar{x}|x) \over {\sum_{x} p(\bar{x}|x) p_{data}(x)} }\\ &= \sum_{x}{ p_{data}(x) \over {\sum_{x} p(\bar{x}|x) p_{data}(x)} } \nabla_{\bar{x}} e^{g(\bar{x}|x)}\\ &= \sum_{x}{ p_{data}(x) \over {\sum_{x} p(\bar{x}|x) p_{data}(x)} } e^{g(\bar{x}|x)} \nabla_{\bar{x}} g(\bar{x}|x)\\ &= \sum_{x}{ p(\bar{x}|x) p_{data}(x) \over {\sum_{x} p(\bar{x}|x) p_{data}(x)} } \nabla_{\bar{x}} g(\bar{x}|x)\\ &= \sum_{x} q(x|\bar{x} ) \nabla_{\bar{x}} g(\bar{x}|x) \\ &= \sum_{x} q(x|\bar{x} ) \nabla_{\bar{x}} (-||\bar{x}-x||^2/\sigma^2 )\\ &= 2 \sum_{x} q(x|\bar{x} ) (x-\bar{x})/\sigma^2 \\ \end{align} \end{split}\]

注意到这里需要用 \(q(x|\bar{x})\) 对得到的梯度做reweight,具体 实现可能就是在采样里直接实现了?这个形式也是一个期望的形式,能否把期望套入到 损失函数的期望里,还需要具体求导确认.

视角

扰动分布的形式可以看成一个高斯混合模型,每一个数据点构成一个高斯组分, 这个分布的梯度是解析的,可以直接通过在数据集上采样估计. NCSN模型通过 用一套参数近似这个高斯混合模型的梯度,在采样时就不必依赖于原始数据, 直接使用模型提供的梯度,就能实现恢复出数据的效果.

离散情况

在离散的相空间里,扰动分布可以由随机的masking给出

\[\begin{split} \bar{x} \sim mask(x) \\ P(\bar{x}|x) = { 1\over n_{um}(x)}\delta(x,\bar{x}) \end{split}\]

降噪变成了反向填充mask的过程

\[ s(\bar{x}) = P(x|\bar{x}) ={ P(x) P(\bar{x}|x)\over \sum_x P(x) P(\bar{x}|x)} \]

\(P(x|\bar{x})\)应当可以分解为两步,选择填充的mask,再选择填充的token,任意选择 mask位置以后,应当按照原sequence进行填充.

在离散的unmasking开箱过程也构成一个特殊的reverse process. rev的起点是全部[mask][mask][mask]的一个序列, 逐步把mask替换成word的时候就是不断地进行采样的过程, 类比NCSN,我们希望把这个采样过程分解成逐步进行的过程,[mask]的数量可以用来区分解码的不同阶段.如果考虑一个简单的泊松masking过程,每次挑选一个位置替换成mask,那么一个序列进入自身的概率为

\[\begin{split} P(x_{t+1}= x_{t}|x_{t}) = {m(x_t)\over L(x_t)}\\ P(x_{t+1} \neq x_{t}|x_{t}) = 1 - {m(x_t)\over L(x_t)} \end{split}\]

让我们用\(f(x_t,l)\) 表示在\(x_t\)的l处放置mask后的序列,那么

\[ P(x_{t+1} | x_t)= {1 \over L(x)}\sum_i \delta(x_{t+1}, f(x_t,l)) \]

那么,对于给定了数据 \(P(x_0) = P_{data}(x_0)\)后的分布,它的逆分布由Bayes Theorem给出

\[\begin{split} P(x_{t} | x_{t+1}) = {P(x_{t+1}|x_t) P(x_t) \over \sum_{x_{t}}P(x_{t+1}|x_t) P(x_t) } \\ = {P(x_{t+1}|x_t) P(x_t) \over Z(x_{t+1}) } \end{split}\]

考虑通过最大化cross entropy近似

\[ loss = - \sum_{x_t} P(x_t | x_{t+1}) \log s(x_{t+1},\theta) \]

因为相空间是离散的,所以没有啥梯度需要计算,直接数数就行.对于任意的序列\(x_0\),我们可以采样出一个逐渐完全噪声化的 序列\(x_0,x_1,x_2,x_3,\dots,x_T\),在每次加入mask时,可以考虑所有的组合\(f^{-1}(x_{t+1})=\{y:\exist i,s.t. f(y,i)=x_{t+1}\}\). 但是这样似乎会有一些不平衡.最简单的办法还是保留一个反向指针,要求模型预测这个被mask掉的词语.但是这样 似乎有些浪费,毕竟没法按照mask指定词语,

那么我们给模型输入的就是不同的mask后的序列 \(x_{t+1}\),并且每个位置给出unmask的词语. 相当于P(x_t|x_{t+1})的一个组分, 但是这样有点奇怪,可能因为没法估计配分函数.可以想见,一个常用词旁边的mask的可能性是非常多的,也就有很大的\(Z(x_{t+1})\), 理论上这个是要通过非参数过程来估计的.那么由于全\(mask\)附近密度较高,稀有序列附近密度较低,我们得到的分数应当做适当reweight?

综合以上情况,我们把混合mask数量的样本给入模型,观察采样效果

参考

1

Yang Song and Stefano Ermon. Generative modeling by estimating gradients of the data distribution. CoRR, 2019. URL: http://arxiv.org/abs/1907.05600, arXiv:1907.05600.

https://yang-song.github.io/blog/2021/score/