代码拉取完成,页面将自动刷新
2232
Lml/vhp
已合并
PR types
New features
PR changes
APIs
Describe
Add autograd functional API: vhp
This function computes the product between a vector v
and the
Hessian matrix of func
with respect to inputs
.
Parameters:
func (function): a Python function that takes a Tensor or a Tensor
list/tuple as inputs and returns a Tensor with a single element.
inputs (Tensor|list(Tensor)|tuple(Tensor)): the input Tensor or
Tensor list/tuple of the function ``func``.
v (Tensor|list(Tensor)|tuple(Tensor)|None, optional): the vector used
to compute vector hessian product. ``v`` should have same shape
and dtype with ``inputs``. If ``v`` is None, it will be set as
Tensor|list(Tensor) with all elements 1. Defaults to "None".
create_graph (bool, optional): whether to create the gradient graphs
of the computing process. When it is True, higher order derivatives
are supported to compute; when it is False, the gradient graphs of
the computing process would be discarded. Defaults to ``False``.
allow_unused (bool, optional): whether to raise error or return None if
some Tensors of `inputs` are unreachable in the graph. Error would
be raised if allow_unused=False, and None would be returned as
their gradients if allow_unused=True. Default False.
Returns:
output (tuple): tuple with:
func_output (Tensor): output of ``func(inputs)``
vhp (list(Tensor)): result of the vector hessian product
with the same shape and dtype as the inputs.