Implement native hessian interface and build_primitive_frule #1014

@yebai

Description

@Technici4n proposed the following in #990 (comment):

An frule for build_derived_rrule is introduced such that forward-mode primitives don't get inlined. We thus need to make sure that all calls to build_derived_rrule get seen by the forward pass. This means that the inner function must contain prepare_pullback_cache, so every Hessian-vector product currently builds the derived rule, which is very slow. My idea to fix this is to introduce build_primitive_frule, so that we can instantiate a callable struct once during the outer preparation and have it efficiently cache the inner preparation. This should make taking derivatives of functions which call build_primitive_rule reasonably efficient, and should also generalize to higher orders. In the future we could even provide a dedicated Hessian API for extra speed!
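The "instantiate a callable struct once, reuse the cached preparation" idea can be illustrated with a minimal, self-contained sketch. Every name here (`expensive_build`, `CachedRule`, `prepare`, `BUILD_COUNT`) is hypothetical and not part of Mooncake, and a finite-difference closure stands in for whatever the real derived rule would be. The sketch only shows the caching pattern, not the proposed implementation.

```julia
# Counts how often the (hypothetical) expensive construction step runs.
const BUILD_COUNT = Ref(0)

# Hypothetical stand-in for an expensive rule-construction step.
# A central finite difference is used purely as a placeholder "rule".
function expensive_build(f)
    BUILD_COUNT[] += 1
    return x -> (f(x + 1e-6) - f(x - 1e-6)) / 2e-6
end

# Callable struct that holds the prepared rule, so calls never rebuild it.
struct CachedRule{R}
    rule::R
end

# Outer preparation: runs the expensive step exactly once.
prepare(f) = CachedRule(expensive_build(f))

# Calling the struct reuses the cached rule.
(c::CachedRule)(x) = c.rule(x)

cached = prepare(sin)
vals = [cached(x) for x in (0.0, 0.5, 1.0)]
```

Here the construction cost is paid in `prepare`, and each later call to `cached` only evaluates the stored closure.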

See also:

"""
    PrimitiveRRule{Sig, Tmaybeinline}

Callable wrapper used for primitive rrules. Both variants route through `tuple_splat`
helpers to avoid vararg lowering to `_apply_iterate`. When `Tmaybeinline=true`, inlining
is permitted. When `Tmaybeinline=false`, a noinline boundary is forced around the
`rrule!!` call (useful for higher-order AD). Construct via `build_primitive_rrule`.
"""
struct PrimitiveRRule{Sig,Tmaybeinline} end

@inline function _primitive_rule_sig(sig::Type{<:Tuple})
    if sig isa DataType && isconcretetype(sig)
        return sig
    end
    nparams = length(Base.unwrap_unionall(sig).parameters)
    return Tuple{Vararg{Any,nparams}}
end

@inline function (rule::PrimitiveRRule{Sig,true})(args...) where {Sig}
    return tuple_splat(rrule!!, args)
end

@inline function (rule::PrimitiveRRule{Sig,false})(args...) where {Sig}
    return tuple_splat_noinline(rrule!!, args)
end

"""
    PrimitiveFRule{Sig, Tmaybeinline}

Callable wrapper used for primitive frules. Both variants route through `tuple_splat`
helpers to avoid vararg lowering to `_apply_iterate`. When `Tmaybeinline=true`, inlining
is permitted. When `Tmaybeinline=false`, a noinline boundary is forced around the
`frule!!` call (useful for higher-order AD). Construct via `build_primitive_frule`.
"""
struct PrimitiveFRule{Sig,Tmaybeinline} end

@inline function (rule::PrimitiveFRule{Sig,true})(args...) where {Sig}
    return tuple_splat(frule!!, args)
end

@inline function (rule::PrimitiveFRule{Sig,false})(args...) where {Sig}
    return tuple_splat_noinline(frule!!, args)
end

"""
    build_primitive_rrule(sig::Type{<:Tuple}; maybeinline_primitive=true)

Construct an rrule for signature `sig`. For this function to be called in `build_rrule`, you
must also ensure that a method of `_is_primitive(context_type, ReverseMode, sig)` exists,
preferably by using the [@is_primitive](@ref) macro.

The callable returned by this must obey the rrule interface, but there are no restrictions
on the type of callable itself. For example, you might return a callable `struct`. By
default, this function returns a wrapper that may inline the primitive rule; set
`maybeinline_primitive=false` to force a noinline boundary (useful for higher-order AD).

# Extended Help

The purpose of this function is to permit computation at rule construction time, which can
be re-used at runtime. For example, you might wish to derive some information from `sig`
which you use at runtime (e.g. the fdata type of one of the arguments). While constant
propagation will often optimise this kind of computation away, it will sometimes fail to do
so in hard-to-predict circumstances. Consequently, if you need certain computations not to
happen at runtime in order to guarantee good performance, you might wish to e.g. emit a
callable `struct` with type parameters which are the result of this computation. In this
context, the motivation for using this function is the same as that of using staged
programming (e.g. via `@generated` functions) more generally.
"""
function build_primitive_rrule(sig::Type{<:Tuple}; maybeinline_primitive::Bool=true)
    PrimitiveRRule{_primitive_rule_sig(sig),maybeinline_primitive}()
end

function build_primitive_rrule(sig::Type{<:Tuple}, maybeinline_primitive::Bool)
    PrimitiveRRule{_primitive_rule_sig(sig),maybeinline_primitive}()
end

"""
    build_primitive_frule(sig::Type{<:Tuple}; maybeinline_primitive=true)

Construct an frule for signature `sig`. For this function to be called in `build_frule`, you
must also ensure that a method of `_is_primitive(context_type, ForwardMode, sig)` exists,
preferably by using the [@is_primitive](@ref) macro.

By default, this function returns a wrapper that may inline the primitive rule; set
`maybeinline_primitive=false` to force a noinline boundary (useful for higher-order AD,
e.g., forward-over-reverse for Hessians).

See [`build_primitive_rrule`](@ref) for extended discussion of staged rule construction.
"""
function build_primitive_frule(sig::Type{<:Tuple}; maybeinline_primitive::Bool=true)
    PrimitiveFRule{_primitive_rule_sig(sig),maybeinline_primitive}()
end

function build_primitive_frule(sig::Type{<:Tuple}, maybeinline_primitive::Bool)
    PrimitiveFRule{_primitive_rule_sig(sig),maybeinline_primitive}()
end

from #962 (comment)
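The quoted wrappers choose between an inlinable and a noinline call path by dispatching on a `Bool` type parameter. The sketch below reproduces that pattern in isolation so it can be run without Mooncake. `ToyRule`, `toy_rule`, `call_inline`, `call_noinline`, and `build_toy_rule` are hypothetical stand-ins for `PrimitiveFRule`, `frule!!`, `tuple_splat`, `tuple_splat_noinline`, and `build_primitive_frule`. Their bodies are assumptions and do not reflect the real helpers.

```julia
# Hypothetical stand-in for a primitive rule such as `frule!!`.
toy_rule(x, y) = x * y

# Hypothetical stand-ins for the `tuple_splat` helpers: same behaviour,
# differing only in the inlining hint given to the compiler.
@inline call_inline(f, args::Tuple) = f(args...)
@noinline call_noinline(f, args::Tuple) = f(args...)

# Empty callable struct; the Bool type parameter selects the call path.
struct ToyRule{Sig,MaybeInline} end

@inline function (rule::ToyRule{Sig,true})(args...) where {Sig}
    return call_inline(toy_rule, args)
end

@inline function (rule::ToyRule{Sig,false})(args...) where {Sig}
    return call_noinline(toy_rule, args)
end

# Constructor mirroring the keyword form of the quoted builders.
function build_toy_rule(sig::Type{<:Tuple}; maybeinline::Bool=true)
    return ToyRule{sig,maybeinline}()
end

sig = Tuple{typeof(toy_rule),Float64,Float64}
inlinable = build_toy_rule(sig)
bounded = build_toy_rule(sig; maybeinline=false)
```

Both `inlinable(2.0, 3.0)` and `bounded(2.0, 3.0)` compute the same product. The two objects differ only in type, so the choice of call path is resolved by dispatch at compile time rather than by a runtime branch.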

Labels: enhancement