Como faço para usar autograd para uma função separada e independente de backpropagate em PyTorch?

0

Pergunta

Eu tenho duas variáveis, x e theta. Eu estou tentando minimizar a perda com relação a theta só, mas como parte da minha perda de função que eu preciso, a derivada de uma função diferente (fcom respeito à x. Este derivado em si não é relevante para a minimização de resíduos, apenas a sua saída. No entanto, aquando da aplicação do presente em PyTorch estou recebendo um erro de tempo de execução.

Um exemplo mínimo é como segue:

# minimal example of two different autograds
import torch

from torch.autograd.functional import jacobian
def f(theta, x):
    return torch.sum(theta * x ** 2)

def df(theta, x):
    J = jacobian(lambda x: f(theta, x), x)
    return J

# example evaluations of the autograd gradient
x = torch.tensor([1., 2.])
theta = torch.tensor([1., 1.], requires_grad = True)

# derivative should be 2*theta*x (same as an analytical)
with torch.no_grad():
    print(df(theta, x))
    print(2*theta*x)

tensor([2., 4.])

tensor([2., 4.])

# define some arbitrary loss as a fn of theta
loss = torch.sum(df(theta, x)**2)
loss.backward()

dá o seguinte erro

RuntimeError: elemento 0 do tensores não requer pós-graduação e não ter um grad_fn

Se eu fornecer uma analítico derivado (2*theta*x), ele funciona muito bem:

loss = torch.sum((2*theta*x)**2)
loss.backward()

Existe uma maneira de fazer isso no PyTorch? Ou eu estou limitado de alguma forma?

Deixe-me saber se alguém precisa de mais detalhes.

PS

Eu estou imaginando que a solução é algo semelhante à forma como o JAX não autograd, que é o que eu estou mais familiarizado. O que eu quero dizer é que aqui em JAX eu acredito que você só iria fazer:

from jax import grad
df = grad(lambda x: f(theta, x))

e, em seguida, df seria apenas uma função que pode ser chamado em qualquer ponto. Mas é PyTorch o mesmo? Ou existe algum conflito dentro .backward() que faz com que este erro?

autograd gradient python pytorch
2021-11-12 11:56:03
1

Melhor resposta

0

PyTorch do jacobian não criar um cálculo gráfico, a menos que você explicitamente pedir para ele

J = jacobian(lambda x: f(theta, x), x, create_graph=True)

.. com create_graph argumento.

A documentação é muito clara sobre isso

create_graph (bool, opcional) – Se for Verdade, a Razão será computada em uma diferenciáveis forma

2021-11-12 14:55:57

Sim, eu tenho olhado para isso, mas eu realmente não entendo. Eu acho que isso significa que eu não queira usar o create_graph argumento, porque eu não quero ser incluída na minha .backward() chamada. Neste caso, por que não dá-me um erro? Eu não entendo a mensagem de erro.
Danny Williams

O que você está tentando fazer é, basicamente, a diferenciação através de operações de uma razão. Nesse caso, você precisa tê-los criado como um gráfico. Sem create_graph, razão não criar um gráfico de suspensão de loss (você pode verificar loss.grad_fn é vazio, daí o erro)
ayandas

Em outros idiomas

Esta página está em outros idiomas

Русский
..................................................................................................................
Italiano
..................................................................................................................
Polski
..................................................................................................................
Română
..................................................................................................................
한국어
..................................................................................................................
हिन्दी
..................................................................................................................
Français
..................................................................................................................
Türk
..................................................................................................................
Česk
..................................................................................................................
ไทย
..................................................................................................................
中文
..................................................................................................................
Español
..................................................................................................................
Slovenský
..................................................................................................................