官方文档: https://pytorch.org/docs/stable/generated/torch.nn.BatchNorm1d.html

class torch.nn.BatchNorm1d(num_features, eps=1e-05, momentum=0.1, affine=True,
	  track_running_stats=True)

输入: (N,C)(N,C,L)
输出: (N,C)(N,C,L) (与输入相同的形状)

参数:

  • num_features: 输入 (N,C,L) 中的特征数C 或 输入 (N,L) 中的L
  • eps: 用于稳定分母的极小值ε, 默认为1E-5
  • momentum: 动量, 用于累计移动平均值 (下面细说) , 默认为0.1
  • affine: 仿射, 标识该模块是否具有可学习的仿射参数, 默认为True
  • track_running_stats:
    True时此模块跟踪运行的均值和方差.
    False时此模块不跟踪此类统计信息, 而是初始化统计信息缓冲区.
    Nonerunning_meanrunning_var皆为None, 使得这个模块无论是训练模式还是评估模式总是使用批处理统计.
    默认为 True

Pytorch 中 BatchNorm 中参数更新规则为:

mean = momentum * new + (1 - momentum) * mean

值得注意的是, 在 Tensorflow 中, batch_normalization 的更新规则为:

mean = momentum * mean + (1 - momentum) * new

或用 decay (衰减率) 表示亦是如此:

mean = decay * mean + (1 - decay) * new

这意味着 PyTorch 中的 momentum 等于 Tensorflow 中 (1 - momentum)

Logo

汇聚全球AI编程工具,助力开发者即刻编程。

更多推荐