#! https://zhuanlan.zhihu.com/p/523008403
1011: 噪音条件分数网络 noise conditional score network (NCSE)
前言
目标: 理解score-based相关的操作,为理解DDPM的生成打下基础
模型效果: 实现对于某一数据集所约束的分布上的采样
总结:
NCSE is secretly a GMM
基于分数的生成模型,可以认为是,对数据自然规定的非参数密度的梯度(gradient of a nonparametric distribution),做了一个 参数化估计(神经网络).
这个构造从数学上来看非常优雅和巧妙,对数据分布进行了非常直接的拟合,
可以用来复习用降噪思想采样的数学框架 而且巧妙地把噪声和采样联系在了一起.
参考: Song and Ermon1
优点: 模型的构造性和原则性都很强 (a highly principled approach to model the data distribution),通过引入噪声分布,对黑盒部分进行了 进一步的限制,让模型更容易被理解
缺点: 慢
后续: 看看DDPM等模型具体做了啥优化.
扰动分布: perturbed distribution
作者考虑的perturbed distribution甚为优雅,对于方差为 \(\sigma\) 的噪音,在抽象测度上, 随机的高斯噪音等价于热力学核下的扩散,于是对于扰动后的分布 \(\bar{x}_\sigma\) 有
只要 \(\int p_{data}(x) dx = 1\) ,那么由于求边际分布的性质,可以导出 \(\int p(\bar{x}_\sigma) d \bar{x}_\sigma=1\)
损失函数: Fisher Divergence between guide function and guide estimator
从Data采样出corrupted sample的操作是比较好理解的,但是score函数的计算在论文各处并不是很清晰,常常用 \(\nabla_x \log p(x)\) 加以指代,从构造上讲,是一个把x引导到高分区域的一个有方向的函数. (个人认为翻译为score function不是很恰当, 因为score暗含了scalar标量的感觉,用flow estimator可能会好一些,或者guide function,下文中直接以guide function). 作者使用了L2损失函数如下, 形式上是一个Fisher散度
由于没从论文里看出guide function的具体计算方式,根据形式猜测也是用采样的方式进行的估算,因此直接调查了源码.
def dsm_score_estimation(scorenet, samples, sigma=0.01):
perturbed_samples = samples + torch.randn_like(samples) * sigma
target = - 1 / (sigma ** 2) * (perturbed_samples - samples)
scores = scorenet(perturbed_samples)
target = target.view(target.shape[0], -1)
scores = scores.view(scores.shape[0], -1)
loss = 1 / 2. * ((scores - target) ** 2).sum(dim=-1).mean(dim=0)
return loss
可以看出,实际计算的是
也就是说只考虑了在热力学核上求导得出的梯度,具体在符号推导上是怎么构成的等价性, 需要进一步做一些推导澄清. 从直觉上来看,这个梯度的估计,在 \(p_{data}(x)\) 较为 稀疏时是成立的,因为在任意位置 \(\bar{x}\) 都可以认为是单峰的高斯分布. 在高斯组分之间有一些重合时,不同高斯给出的梯度之间,也通过期望算子得到了聚合, 因此这个估计可能是exact的.通过这种直觉,应该能给出一个具体的求导方式.
尝试直接对扰动分布的对数求导 \(\log p(\bar{x}_{\sigma})\), 由于存在 logsumexp 的混合形式
注意到这里需要用 \(q(x|\bar{x})\) 对得到的梯度做reweight,具体 实现可能就是在采样里直接实现了?这个形式也是一个期望的形式,能否把期望套入到 损失函数的期望里,还需要具体求导确认.
视角
扰动分布的形式可以看成一个高斯混合模型,每一个数据点构成一个高斯组分, 这个分布的梯度是解析的,可以直接通过在数据集上采样估计. NCSN模型通过 用一套参数近似这个高斯混合模型的梯度,在采样时就不必依赖于原始数据, 直接使用模型提供的梯度,就能实现恢复出数据的效果.
离散情况
在离散的相空间里,扰动分布可以由随机的masking给出
降噪变成了反向填充mask的过程
\(P(x|\bar{x})\)应当可以分解为两步,选择填充的mask,再选择填充的token,任意选择 mask位置以后,应当按照原sequence进行填充.
在离散的unmasking开箱过程也构成一个特殊的reverse process. rev的起点是全部[mask][mask][mask]
的一个序列,
逐步把mask替换成word的时候就是不断地进行采样的过程, 类比NCSN,我们希望把这个采样过程分解成逐步进行的过程,[mask]
的数量可以用来区分解码的不同阶段.如果考虑一个简单的泊松masking过程,每次挑选一个位置替换成mask,那么一个序列进入自身的概率为
让我们用\(f(x_t,l)\) 表示在\(x_t\)的l处放置mask后的序列,那么
那么,对于给定了数据 \(P(x_0) = P_{data}(x_0)\)后的分布,它的逆分布由Bayes Theorem给出
考虑通过最大化cross entropy近似
因为相空间是离散的,所以没有啥梯度需要计算,直接数数就行.对于任意的序列\(x_0\),我们可以采样出一个逐渐完全噪声化的 序列\(x_0,x_1,x_2,x_3,\dots,x_T\),在每次加入mask时,可以考虑所有的组合\(f^{-1}(x_{t+1})=\{y:\exist i,s.t. f(y,i)=x_{t+1}\}\). 但是这样似乎会有一些不平衡.最简单的办法还是保留一个反向指针,要求模型预测这个被mask掉的词语.但是这样 似乎有些浪费,毕竟没法按照mask指定词语,
那么我们给模型输入的就是不同的mask后的序列 \(x_{t+1}\),并且每个位置给出unmask的词语. 相当于P(x_t|x_{t+1})的一个组分, 但是这样有点奇怪,可能因为没法估计配分函数.可以想见,一个常用词旁边的mask的可能性是非常多的,也就有很大的\(Z(x_{t+1})\), 理论上这个是要通过非参数过程来估计的.那么由于全\(mask\)附近密度较高,稀有序列附近密度较低,我们得到的分数应当做适当reweight?
综合以上情况,我们把混合mask数量的样本给入模型,观察采样效果
参考
- 1
Yang Song and Stefano Ermon. Generative modeling by estimating gradients of the data distribution. CoRR, 2019. URL: http://arxiv.org/abs/1907.05600, arXiv:1907.05600.
https://yang-song.github.io/blog/2021/score/