pytorch中网络参数初始化
可以先定义两个函数:import torch.nn.init as initdef xavier(param):init.xavier_uniform(param)# init.kaiming_uniform_()# 可以选择其他的def weights_init(m):if isinstance(m, nn.Conv2d):xavier(m.weight.data)m.bias.data.zer
·
可以先定义两个函数:
import torch.nn.init as init
def xavier(param):
init.xavier_uniform(param)
# init.kaiming_uniform_() # 可以选择其他的
def weights_init(m):
if isinstance(m, nn.Conv2d):
xavier(m.weight.data)
m.bias.data.zero_()
初始化的时候可以直接调用这两个函数:
net.loc.apply(weights_init)
net.conf.apply(weights_init)
或者:
net.vgg[3:].apply(weights_init)
也可用其他的方法,看自己习惯
更多推荐


所有评论(0)