如果因变量是标量的元组,自变量是tensor矩阵,如批量十次推理,得到某个词的logit是10个数,10次推理中的某个hidden_states构成的矩阵为10行d列,且每个logit只与该次推理过程的hidden_states有关,即如图所示关系:

import torch
from torch import autograd
import random
random.seed(3)
x = torch.rand(3, 4)
x.requires_grad_()
print(x)
y1 = sum(x[0, :])
y2 = sum(x[1, :] * x[1, :])
grad = torch.autograd.grad((y1, y2), x)
print(grad)

另:如果因变量是标量的元组,自变量是tensor矩阵([1, d]),则结果是每个标量对自变量求导的和。

Logo

汇聚全球AI编程工具,助力开发者即刻编程。

更多推荐