Eu tenho duas variáveis, x
e theta
. Eu estou tentando minimizar a perda com relação a theta
só, mas como parte da minha perda de função que eu preciso, a derivada de uma função diferente (f
com respeito à x
. Este derivado em si não é relevante para a minimização de resíduos, apenas a sua saída. No entanto, aquando da aplicação do presente em PyTorch estou recebendo um erro de tempo de execução.
Um exemplo mínimo é como segue:
# minimal example of two different autograds
import torch
from torch.autograd.functional import jacobian
def f(theta, x):
return torch.sum(theta * x ** 2)
def df(theta, x):
J = jacobian(lambda x: f(theta, x), x)
return J
# example evaluations of the autograd gradient
x = torch.tensor([1., 2.])
theta = torch.tensor([1., 1.], requires_grad = True)
# derivative should be 2*theta*x (same as an analytical)
with torch.no_grad():
print(df(theta, x))
print(2*theta*x)
tensor([2., 4.])
tensor([2., 4.])
# define some arbitrary loss as a fn of theta
loss = torch.sum(df(theta, x)**2)
loss.backward()
dá o seguinte erro
RuntimeError: elemento 0 do tensores não requer pós-graduação e não ter um grad_fn
Se eu fornecer uma analítico derivado (2*theta*x
), ele funciona muito bem:
loss = torch.sum((2*theta*x)**2)
loss.backward()
Existe uma maneira de fazer isso no PyTorch? Ou eu estou limitado de alguma forma?
Deixe-me saber se alguém precisa de mais detalhes.
PS
Eu estou imaginando que a solução é algo semelhante à forma como o JAX não autograd, que é o que eu estou mais familiarizado. O que eu quero dizer é que aqui em JAX eu acredito que você só iria fazer:
from jax import grad
df = grad(lambda x: f(theta, x))
e, em seguida, df
seria apenas uma função que pode ser chamado em qualquer ponto. Mas é PyTorch o mesmo? Ou existe algum conflito dentro .backward()
que faz com que este erro?
create_graph
argumento, porque eu não quero ser incluída na minha.backward()
chamada. Neste caso, por que não dá-me um erro? Eu não entendo a mensagem de erro.