C++

Conv_BN融合

Posted by songlin on March 3, 2022

1. 批归一化 Batch Normalization

批归一化(Batch Normalization)因其可以加速神经网络训练、使网络训练更稳定,而且还有一定的正则化效果,所以得到了非常广泛的应用。但是,在推理阶段,BN层一般是可以完全融合到前面的卷积层的,而且丝毫不影响性能。

Batch Normalization 的思想非常简单,一句话概括就是,对一个神经元 (或者一个卷积核)的输出减去统计得到的均值除以标准差,然后乘以一个可学习的系数,再加上一个偏置,这个过程就完成了。

在训练过程中,主要执行如下四个步骤:

\[\begin{aligned} \mu & =\frac{1}{m} \sum_{i=1}^m x_i, \\ \sigma & =\frac{1}{m} \sum_{i=1}^m\left(x_i-\mu\right)^2, \\ \hat{x_i} & =\frac{x_i-\mu}{\sqrt{\sigma^2+\epsilon}}, \\ y_i & =\gamma * \hat{x_i}+\beta \end{aligned}\]

其中 $\epsilon$ 为一个非常小的常数,例如 0.0001 ,主要是为了避免除零错误。而 $\gamma$ 和 $\beta$ 则是可学习参数,在训练过程中,和其他卷积核的参数一样,通过梯度下降来学习。在训练过程中,为保持稳定,一般使用滑动平均法更新均值和方差,滑动平均就是在更新当前值的时候,以一定比例保存之前的数值,以均值 $\mu$ 为例,以一定比例 $\theta$ (例如这里0.99) 保存之前的均值,当前只更新 $(1-\theta)$ 倍(也就是0.001倍)的本Batch 的均值,计算方法如下:

\[\mu_i=\theta \mu_{i-1}+(1-\theta) \mu_i\]

标准差的滑动平均计算方法也一样。

2. 融合计算公式

网络完成训练后,在 inference 阶段,为了加速运算,通常将卷积层和BN层进行融合: (1) 巻积层可以抽象为如下公式:

\[x_i=w * x_{i-1}+b\]

(2) bn 层可以抽象为如下公式:

\[y=\frac{x_i-\mu}{\sqrt{\sigma+\epsilon}} \cdot \gamma+\beta\]

其中 $\mu$ 和 $\sigma$ 是整个训练集的均值和方差(滑动),这是两个常量。 $\gamma$ 和 $\beta$ 是学习完成的参数,也是两个常量。 (3) 融合两层: 将 conv 层的公式带入到 BN 层的公式

\[\begin{aligned} y & =\frac{x_i-\mu}{\sqrt{\sigma+\epsilon}} \cdot \gamma+\beta \\ & =\frac{w * x_i+b-\mu}{\sqrt{\sigma+\epsilon}} \cdot \gamma+\beta \\ & =\frac{w * \gamma}{\sqrt{\sigma+\epsilon}} \cdot x+\left(\frac{b-\mu}{\sqrt{\sigma+\epsilon}} \cdot \gamma+\beta\right) \end{aligned}\]

融合后相当于:

\[\begin{aligned} & w_{\text {new }}=\frac{w * \gamma}{\sqrt{\sigma+\epsilon}}, \\ & b_{\text {new }}=\left(\frac{b-\mu}{\sqrt{\sigma+\epsilon}} \cdot \gamma+\beta\right) \end{aligned}\]

通过公式可以看出,我们可以将 Batch Normalization 层融合到卷积层中,相当于对卷积核进行一定的修改,没有增加卷积的计算量,同时整个 Batch Normalization 层的计算量都省去了。

3. 实际融合过程

import torch
import torchvision
import copy
from typing import Optional, Tuple, TypeVar

def fuse_conv_bn_weights(
    conv_w: torch.Tensor,
    conv_b: Optional[torch.Tensor],
    bn_rm: torch.Tensor,
    bn_rv: torch.Tensor,
    bn_eps: float,
    bn_w: Optional[torch.Tensor],
    bn_b: Optional[torch.Tensor],
    transpose: bool = False
) -> Tuple[torch.nn.Parameter, torch.nn.Parameter]:
    r"""Fuse convolutional module parameters and BatchNorm module parameters into new convolutional module parameters.

    Args:
        conv_w (torch.Tensor): Convolutional weight.
        conv_b (Optional[torch.Tensor]): Convolutional bias.
        bn_rm (torch.Tensor): BatchNorm running mean.
        bn_rv (torch.Tensor): BatchNorm running variance.
        bn_eps (float): BatchNorm epsilon.
        bn_w (Optional[torch.Tensor]): BatchNorm weight.
        bn_b (Optional[torch.Tensor]): BatchNorm bias.
        transpose (bool, optional): If True, transpose the conv weight. Defaults to False.

    Returns:
        Tuple[torch.nn.Parameter, torch.nn.Parameter]: Fused convolutional weight and bias.
    """
    conv_weight_dtype = conv_w.dtype
    conv_bias_dtype = conv_b.dtype if conv_b is not None else conv_weight_dtype
    if conv_b is None:
        conv_b = torch.zeros_like(bn_rm)
    if bn_w is None:
        bn_w = torch.ones_like(bn_rm)
    if bn_b is None:
        bn_b = torch.zeros_like(bn_rm)
    bn_var_rsqrt = torch.rsqrt(bn_rv + bn_eps)

    if transpose:
        shape = [1, -1] + [1] * (len(conv_w.shape) - 2)
    else:
        shape = [-1, 1] + [1] * (len(conv_w.shape) - 2)
    # import pdb; pdb.set_trace()
    fused_conv_w = (conv_w * (bn_w * bn_var_rsqrt).reshape(shape)).to(dtype=conv_weight_dtype)
    fused_conv_b = ((conv_b - bn_rm) * bn_var_rsqrt * bn_w + bn_b).to(dtype=conv_bias_dtype)

    return (
        torch.nn.Parameter(fused_conv_w, conv_w.requires_grad), torch.nn.Parameter(fused_conv_b, conv_b.requires_grad)
    )

resnet18 = torchvision.models.resnet18(pretrained=True)
resnet18.eval()
conv_bn = torch.nn.Sequential(
    resnet18.conv1, 
    resnet18.bn1 
)                                                                                                                     
bn_var, bn_mean = resnet18.bn1.running_var, resnet18.bn1.running_mean
bn_eps = resnet18.bn1.eps
bn_weight, bn_bias = resnet18.bn1.weight, resnet18.bn1.bias
conv_weight, conv_bias = resnet18.conv1.weight, resnet18.conv1.bias
fused_w, fused_b = fuse_conv_bn_weights(conv_weight, conv_bias, bn_mean, bn_var, bn_eps, bn_weight, bn_bias)
fused = torch.nn.Conv2d(
             resnet18.conv1.in_channels,
             resnet18.conv1.out_channels,
             kernel_size=resnet18.conv1.kernel_size,
             stride=resnet18.conv1.stride,
             padding=resnet18.conv1.padding,
             bias=True
         )
fused.weigh = fused_w
fused.bias = fused_b
torch.set_grad_enabled(False) 
x = torch.randn(16, 3, 256, 256)
res1 = conv_bn.forward(x)
res2 = fused(x)
torch.allclose(res1, res2)

这里需要对计算过程中的Tensor Shape进行一个梳理:

Input: (16,3,256,256)

output: (16,64,128,128)

这里channel的变化是因为第一个conv算子的输入channel是3,输出channel是64,由于stride是(2,2) 因此HW减半

shape of conv weight:

conv weight (64,3,7,7)

conv bias (64)

shape of BN weight:

bn running_var: (64)

bn running_mean: (64)

bn weight (64)

bn bias: (64)