有两种方式:
1.

Net.load_state_dict(torch.load(model_path),strict=False)

使用strict参数,如果为True,表明预训练模型的层和自己定义的网络结构层严格对应相等(比如层名和维度),这里选择为False,则不完全对等,会自动舍去多余的层和其参数。
2.

pretrained_dict=torch.load(model_path)
model_dict=Net.state_dict()
pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}#不必要的键去除掉
model_dict.update(pretrained_dict)#覆盖现有的字典里的条目
Net.load_state_dict(model_dict)
Net.load_state_dict(torch.load(model_path))
Logo

汇聚全球AI编程工具,助力开发者即刻编程。

更多推荐