1. 批归一化 Batch Normalization
批归一化(Batch Normalization)因其可以加速神经网络训练、使网络训练更稳定,而且还有一定的正则化效果,所以得到了非常广泛的应用。但是,在推理阶段,BN层一般是可以完全融合到前面的卷积层的,而且丝毫不影响性能。
Batch Normalization 的思想非常简单,一句话概括就是,对一个神经元 (或者一个卷积核)的输出减去统计得到的均值除以标准差,然后乘以一个可学习的系数,再加上一个偏置,这个过程就完成了。
在训练过程中,主要执行如下四个步骤:
\[\begin{aligned} \mu & =\frac{1}{m} \sum_{i=1}^m x_i, \\ \sigma & =\frac{1}{m} \sum_{i=1}^m\left(x_i-\mu\right)^2, \\ \hat{x_i} & =\frac{x_i-\mu}{\sqrt{\sigma^2+\epsilon}}, \\ y_i & =\gamma * \hat{x_i}+\beta \end{aligned}\]其中 $\epsilon$ 为一个非常小的常数,例如 0.0001 ,主要是为了避免除零错误。而 $\gamma$ 和 $\beta$ 则是可学习参数,在训练过程中,和其他卷积核的参数一样,通过梯度下降来学习。在训练过程中,为保持稳定,一般使用滑动平均法更新均值和方差,滑动平均就是在更新当前值的时候,以一定比例保存之前的数值,以均值 $\mu$ 为例,以一定比例 $\theta$ (例如这里0.99) 保存之前的均值,当前只更新 $(1-\theta)$ 倍(也就是0.001倍)的本Batch 的均值,计算方法如下:
\[\mu_i=\theta \mu_{i-1}+(1-\theta) \mu_i\]标准差的滑动平均计算方法也一样。
2. 融合计算公式
网络完成训练后,在 inference 阶段,为了加速运算,通常将卷积层和BN层进行融合: (1) 巻积层可以抽象为如下公式:
\[x_i=w * x_{i-1}+b\](2) bn 层可以抽象为如下公式:
\[y=\frac{x_i-\mu}{\sqrt{\sigma+\epsilon}} \cdot \gamma+\beta\]其中 $\mu$ 和 $\sigma$ 是整个训练集的均值和方差(滑动),这是两个常量。 $\gamma$ 和 $\beta$ 是学习完成的参数,也是两个常量。 (3) 融合两层: 将 conv 层的公式带入到 BN 层的公式
\[\begin{aligned} y & =\frac{x_i-\mu}{\sqrt{\sigma+\epsilon}} \cdot \gamma+\beta \\ & =\frac{w * x_i+b-\mu}{\sqrt{\sigma+\epsilon}} \cdot \gamma+\beta \\ & =\frac{w * \gamma}{\sqrt{\sigma+\epsilon}} \cdot x+\left(\frac{b-\mu}{\sqrt{\sigma+\epsilon}} \cdot \gamma+\beta\right) \end{aligned}\]融合后相当于:
\[\begin{aligned} & w_{\text {new }}=\frac{w * \gamma}{\sqrt{\sigma+\epsilon}}, \\ & b_{\text {new }}=\left(\frac{b-\mu}{\sqrt{\sigma+\epsilon}} \cdot \gamma+\beta\right) \end{aligned}\]通过公式可以看出,我们可以将 Batch Normalization 层融合到卷积层中,相当于对卷积核进行一定的修改,没有增加卷积的计算量,同时整个 Batch Normalization 层的计算量都省去了。
3. 实际融合过程
import torch
import torchvision
import copy
from typing import Optional, Tuple, TypeVar
def fuse_conv_bn_weights(
conv_w: torch.Tensor,
conv_b: Optional[torch.Tensor],
bn_rm: torch.Tensor,
bn_rv: torch.Tensor,
bn_eps: float,
bn_w: Optional[torch.Tensor],
bn_b: Optional[torch.Tensor],
transpose: bool = False
) -> Tuple[torch.nn.Parameter, torch.nn.Parameter]:
r"""Fuse convolutional module parameters and BatchNorm module parameters into new convolutional module parameters.
Args:
conv_w (torch.Tensor): Convolutional weight.
conv_b (Optional[torch.Tensor]): Convolutional bias.
bn_rm (torch.Tensor): BatchNorm running mean.
bn_rv (torch.Tensor): BatchNorm running variance.
bn_eps (float): BatchNorm epsilon.
bn_w (Optional[torch.Tensor]): BatchNorm weight.
bn_b (Optional[torch.Tensor]): BatchNorm bias.
transpose (bool, optional): If True, transpose the conv weight. Defaults to False.
Returns:
Tuple[torch.nn.Parameter, torch.nn.Parameter]: Fused convolutional weight and bias.
"""
conv_weight_dtype = conv_w.dtype
conv_bias_dtype = conv_b.dtype if conv_b is not None else conv_weight_dtype
if conv_b is None:
conv_b = torch.zeros_like(bn_rm)
if bn_w is None:
bn_w = torch.ones_like(bn_rm)
if bn_b is None:
bn_b = torch.zeros_like(bn_rm)
bn_var_rsqrt = torch.rsqrt(bn_rv + bn_eps)
if transpose:
shape = [1, -1] + [1] * (len(conv_w.shape) - 2)
else:
shape = [-1, 1] + [1] * (len(conv_w.shape) - 2)
# import pdb; pdb.set_trace()
fused_conv_w = (conv_w * (bn_w * bn_var_rsqrt).reshape(shape)).to(dtype=conv_weight_dtype)
fused_conv_b = ((conv_b - bn_rm) * bn_var_rsqrt * bn_w + bn_b).to(dtype=conv_bias_dtype)
return (
torch.nn.Parameter(fused_conv_w, conv_w.requires_grad), torch.nn.Parameter(fused_conv_b, conv_b.requires_grad)
)
resnet18 = torchvision.models.resnet18(pretrained=True)
resnet18.eval()
conv_bn = torch.nn.Sequential(
resnet18.conv1,
resnet18.bn1
)
bn_var, bn_mean = resnet18.bn1.running_var, resnet18.bn1.running_mean
bn_eps = resnet18.bn1.eps
bn_weight, bn_bias = resnet18.bn1.weight, resnet18.bn1.bias
conv_weight, conv_bias = resnet18.conv1.weight, resnet18.conv1.bias
fused_w, fused_b = fuse_conv_bn_weights(conv_weight, conv_bias, bn_mean, bn_var, bn_eps, bn_weight, bn_bias)
fused = torch.nn.Conv2d(
resnet18.conv1.in_channels,
resnet18.conv1.out_channels,
kernel_size=resnet18.conv1.kernel_size,
stride=resnet18.conv1.stride,
padding=resnet18.conv1.padding,
bias=True
)
fused.weigh = fused_w
fused.bias = fused_b
torch.set_grad_enabled(False)
x = torch.randn(16, 3, 256, 256)
res1 = conv_bn.forward(x)
res2 = fused(x)
torch.allclose(res1, res2)
这里需要对计算过程中的Tensor Shape进行一个梳理:
Input: (16,3,256,256)
output: (16,64,128,128)
这里channel的变化是因为第一个conv算子的输入channel是3,输出channel是64,由于stride是(2,2) 因此HW减半
shape of conv weight:
conv weight (64,3,7,7)
conv bias (64)
shape of BN weight:
bn running_var: (64)
bn running_mean: (64)
bn weight (64)
bn bias: (64)