pytorch学习系列(5):模型部分参数使用(迁移学习)
有两种方式:1.Net.load_state_dict(torch.load(model_path),strict=False)使用strict参数,如果为True,表明预训练模型的层和自己定义的网络结构层严格对应相等(比如层名和维度),这里选择为False,则不完全对等,会自动舍去多余的层和其参数。2.pretrained_dict=torch.load(model_path)mo...
·
有两种方式:
1.
Net.load_state_dict(torch.load(model_path),strict=False)
使用strict参数,如果为True,表明预训练模型的层和自己定义的网络结构层严格对应相等(比如层名和维度),这里选择为False,则不完全对等,会自动舍去多余的层和其参数。
2.
pretrained_dict=torch.load(model_path)
model_dict=Net.state_dict()
pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}#不必要的键去除掉
model_dict.update(pretrained_dict)#覆盖现有的字典里的条目
Net.load_state_dict(model_dict)
Net.load_state_dict(torch.load(model_path))
更多推荐




所有评论(0)