贝叶斯神经网络变分近似
date
Dec 28, 2021
slug
BayesianNN _Variational_Approximation
status
Published
tags
Neural Network
summary
贝叶斯神经网络中参数的变分近似
type
Post
本文脉络:首先回顾一下之前对贝叶斯神经网络的介绍,然后简述如何在贝叶斯神经网络进行反向传播。
回顾
之前讲到。普通的神经网络具有如下缺陷:
- 容易过拟合
- 无法表达不确定性
所以为了解决普通神经网络的以上缺陷,学者将贝叶斯引入到了神经网络中来,通过神经网络中权重的不确定性来使神经网络得以表达不确定性。
具体的做法就是,贝叶斯神经网络中的权重不再是一个固定的值,而是一个完整的分布。

那我们应该如何求解贝叶斯神经网络中的参数呢?反向传播对于贝叶斯神经网络还适用吗?
我们先来回顾一下普通神经网络中参数的求解。
我们将一个神经网络视为这样一个概率模型:
在给定一个输入 的情况下,神经网络会使用一组参数 为每个可能的输出分配一个概率 。
- 对于分类模型: 为多分类分布,对应的是交叉熵损失
- 对于回归模型:为高斯分布,对应的是均方差损失
模型中的参数通过极大似然估计(MLE)求解:
如果引入正则化后,则对应的极大后验估计:
且在极大后验估计中,如果先验 为高斯分布,那么对应的是正则化,如果先验 为拉普拉斯分布,那么对应的是正则化。
具体参数通过梯度下降来求解。
贝叶斯神经网络中的变分近似
贝叶斯神经网络
贝叶斯神经网络中的参数均为一个个随机变量,我们想要的是这个随机变量的完整的分布。根据贝叶斯公式:
现在想算出后验,我们做如下尝试:
- ☹️计算后验解析解。由全概率公式可知,分母 为在 的取值空间上进行积分,我们知道神经网络的单个权重的取值空间为实数集,而这些权重一起构成的空间将相当复杂,所以想要通过积分计算 解析解,基本是不可能的。
- ☹️通过采样计算 。通过蒙特卡洛采样来计算分母 的积分。然而着要求从一个相当高的维度分布中采样。理论上可行,但实践中太过于消耗时间了。
- 😄变分推断。采用一些简单的分布去近似后验分布
变分引入
普通神经网络中,每个权重都是一个固定值,为贝叶斯神经网络中,权重为一个分布。现在我们用高斯分布去近似每个权重的分布。

对于权重 的分布, 我们用高斯分布 来近似,形式化表达如下:
将新引进来的所有参数表示为 :
此时,我们就定义以下的变分分布,该变分分布包括了神经网络中所有权重的分布:
此时我们的目的就很清晰了,选择最佳的 ,使得我们构造出来的变分分布 与 目标后验分布最接近。

那么就有一个问题,应该如何衡量这两个分布之间的差异呢?在这里,选择采用KL散度度量量分布之间的差异:
变分学习
这个时候我们的问题已经很清晰了,找到一组合适的参数\theta, 使得变分分布 与后验分布 的 散度最小。形式化表达如下:
散度即可以看作是我们普通神经网络中的损失函数,为了简单,我们将其用 表示:
精确的计算这个损失函数的最小化是相当困难的,实际上,我们采用梯度下降来求解。
但是这就有哟个问题,我们该如何对一个函数的期望求梯度呢?
论文中提出这样一个命题:
Proposition 1. Let 𝜖 be a random variable having a probability density given by 𝑞(𝜖) and let 𝑤 = 𝑡(𝜃,𝜖) where 𝑡(𝜃,𝜖) is a deterministic function. Suppose further that the marginal probability density of 𝑤, 𝑞(𝑤∣𝜃), is such that 𝑞(𝜖)𝑑𝜖= 𝑞(𝑤∣𝜃)𝑑𝑤. Then for a function 𝑓 with derivatives in 𝑤:
简单来说,就是在满足一定的条件下,期望的导数可以被表示为导数的期望。恰好,我们的目标函数 是满足这样条件的。所以我们就可以直接对 求梯度,然后再对梯度求期望。其中,对梯度求期望可以通过多次采样求平均来实现。
所以现在我们的目标函数转化为 :
我们可以通过求 的多次采样的均值来估计
至此,优化问题就转化为了
我们通过梯度下降来求解 :
- 采样得到 ,
采样的过程不可导,所以这里采用了一个重参数化的技巧,另
此时,对 采样就转化为了对 采样,同时参数转化为了
- 计算损失函数值
- 计算似然
- 计算损失值:
- 计算梯度:
- 更新参数
小结
贝叶斯神经网络中的参数W为随机变量,我们想要求得这个随机变量的完整分布,然后在前向传播的时候,通过采样得到的权重值具有不确定度,进而使得整个神经网络可以表达不确定度。
参数W完整的分布称为后验分布,由贝叶斯公式:
可知,想要计算后验的解析解是异常困难的,所以这里我们通过一些简单的分布来近似这个后验分布,这个过程称之为变分近似,这些个简单的分布称为变分分布。
我们通过\text{KL} 散度来度量变分分布于后验分布的差异,所以有了如下的优化问题:
在解这个优化问题的过程中,使用了一个命题和一个重参数化的技巧:
- 命题:在满足一定的条件下,期望的导数可以表示为导数的期望
- 重参数技巧:将对 的采用转化为了对 的采样:
最后,优化问题转化为了:
最后
最后我们看一下该如何在Pytorch中编写这样一个贝叶斯神经网络全连接层。
首先来看一下Pytorch中的普通神经网络全连接层:
class Linear(nn.Module): def __init__(self, in_features, out_features): super(Linear, self).__init__() self.in_features = in_features self.out_features = out_features # 定义两个参数 self.w = nn.Parameter(torch.empty((out_features, in_features))) self.b = nn.Parameter(torch.empty((out_features))) self.w.data.normal_(mean = 0, std = 0.01) self.b.data.fill_(0.0) def forward(self, input): return torch.matmul(input, self.w.T) + self.b
那如果要实现贝叶斯神经网络全连接层,就需要这么写:
class BayesLinear(nn.Module): def __init__(self, in_feature, out_feature, prior_var=1.0): super(BayesLinear,self).__init__() self.in_feature = in_feature self.out_feature = out_feature # 权重的分布 self.w_mu = nn.Parameter(torch.zeros(self.out_feature, self.in_feature)) self.w_rho = nn.Parameter(torch.zeros(self.out_feature, self.in_feature)) # 偏置的分布 self.b_mu = nn.Parameter(torch.zeros(self.out_feature)) self.b_rho = nn.Parameter(torch.zeros(self.out_feature)) # 权重 和偏置 self.w = None self.b = None # 初始化先验分布 self.prior = torch.distributions.Normal(0, prior_var) def forward(self, input): # 采样得到 权重和偏执 # 这里 epsion shape 和mu.shape 是一样的,是一对一的 w_epsilon = torch.distributions.Normal(0, 1).sample(self.w_mu.shape) self.w = self.w_mu + torch.log(1+torch.exp(self.w_rho)) * w_epsilon b_epsilon = Normal(0,1).sample(self.b_mu.shape) self.b = self.b_mu + torch.log(1+torch.exp(self.b_rho)) * b_epsilon # 计算先验概率 w_log_prior = self.prior.log_prob(self.w) b_log_prior = self.prior.log_prob(self.b) self.log_prior = torch.sum(w_log_prior) + torch.sum(b_log_prior) # 计算变分分布 self.w_post = torch.distributions.Normal(self.w_mu.data, torch.log(1+torch.exp(self.w_rho))) self.b_post = torch.distributions.Normal(self.b_mu.data, torch.log(1+torch.exp(self.b_rho))) # 计算变分后验概率 self.log_post = self.w_post.log_prob(self.w).sum() + self.b_post.log_prob(self.b).sum() return torch.matmul(input, self.w.T) + self.b
可见,贝叶斯神经网络全连接层中的参数数量为普通神经网络参数数量的二倍。
参考
视频:贝叶斯神经网络