可以先定义两个函数:

import torch.nn.init as init

def xavier(param):
    init.xavier_uniform(param)
    # init.kaiming_uniform_()    # 可以选择其他的


def weights_init(m):
    if isinstance(m, nn.Conv2d):
        xavier(m.weight.data)
        m.bias.data.zero_()

初始化的时候可以直接调用这两个函数:

net.loc.apply(weights_init)
net.conf.apply(weights_init)

或者:

net.vgg[3:].apply(weights_init)

也可用其他的方法,看自己习惯

Logo

汇聚全球AI编程工具,助力开发者即刻编程。

更多推荐