Description
In the approach @Technici4n proposed below in #990 (comment), an `frule` for `build_derived_rrule` is introduced so that forward-mode primitives don't get inlined. We therefore need to make sure that every call to `build_derived_rrule` is seen by the forward pass. This means the inner function must contain `prepare_pullback_cache`, so currently every Hessian-vector product rebuilds the derived rule, which is very slow.

My idea to fix this is to introduce `build_primitive_frule`, so that we can instantiate a callable struct once during the outer preparation, and that struct can then efficiently cache the inner preparation. This should make taking derivatives of functions which call `build_primitive_rule` reasonably efficient, and should also generalize to higher orders. In the future we could even provide a dedicated Hessian API for extra speed!
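The caching idea can be illustrated in plain Julia, independent of Mooncake. This is a minimal sketch under stated assumptions: `expensive_prepare`, `naive_apply`, `CachedRule` and `build_cached_rule` are hypothetical names, and `expensive_prepare` merely stands in for work such as building a derived rule or calling `prepare_pullback_cache`. It only shows the pattern: preparing inside every call repeats the work each time, while a callable struct built once during the outer preparation reuses its stored cache.

```julia
# Hypothetical sketch of "prepare once, reuse many times". Nothing here is Mooncake API:
# `expensive_prepare` stands in for building a derived rule or a pullback cache.

const PREP_COUNT = Ref(0)  # how many times the expensive preparation has run

function expensive_prepare(f, x)
    PREP_COUNT[] += 1              # record that preparation happened
    return (f = f, n = length(x))  # pretend this NamedTuple is a costly cache
end

# Naive: preparation is redone on every call (analogous to rebuilding the rule per HVP).
function naive_apply(f, x)
    cache = expensive_prepare(f, x)
    return cache.f.(x)
end

# Proposed: a callable struct built once during the outer preparation step.
struct CachedRule{F,C}
    f::F
    cache::C
end

# Separate builder so we do not clash with the default two-argument constructor.
build_cached_rule(f, x) = CachedRule(f, expensive_prepare(f, x))

# Calling the struct reuses the stored cache; no re-preparation happens.
(rule::CachedRule)(x) = rule.cache.f.(x)

x = [1.0, 2.0, 3.0]

PREP_COUNT[] = 0
for _ in 1:10
    naive_apply(sin, x)
end
naive_count = PREP_COUNT[]   # preparation ran on every call

PREP_COUNT[] = 0
rule = build_cached_rule(sin, x)
for _ in 1:10
    rule(x)
end
cached_count = PREP_COUNT[]  # preparation ran once, at construction
```

In a real implementation the cache would hold whatever the inner preparation produces; the sketch only counts how often that preparation runs.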
See also:
Lines 86 to 180 in d040047
```julia
"""
    PrimitiveRRule{Sig, Tmaybeinline}

Callable wrapper used for primitive rrules. Both variants route through `tuple_splat`
helpers to avoid vararg lowering to `_apply_iterate`. When `Tmaybeinline=true`, inlining
is permitted. When `Tmaybeinline=false`, a noinline boundary is forced around the
`rrule!!` call (useful for higher-order AD). Construct via `build_primitive_rrule`.
"""
struct PrimitiveRRule{Sig,Tmaybeinline} end

@inline function _primitive_rule_sig(sig::Type{<:Tuple})
    if sig isa DataType && isconcretetype(sig)
        return sig
    end
    nparams = length(Base.unwrap_unionall(sig).parameters)
    return Tuple{Vararg{Any,nparams}}
end

@inline function (rule::PrimitiveRRule{Sig,true})(args...) where {Sig}
    return tuple_splat(rrule!!, args)
end

@inline function (rule::PrimitiveRRule{Sig,false})(args...) where {Sig}
    return tuple_splat_noinline(rrule!!, args)
end

"""
    PrimitiveFRule{Sig, Tmaybeinline}

Callable wrapper used for primitive frules. Both variants route through `tuple_splat`
helpers to avoid vararg lowering to `_apply_iterate`. When `Tmaybeinline=true`, inlining
is permitted. When `Tmaybeinline=false`, a noinline boundary is forced around the
`frule!!` call (useful for higher-order AD). Construct via `build_primitive_frule`.
"""
struct PrimitiveFRule{Sig,Tmaybeinline} end

@inline function (rule::PrimitiveFRule{Sig,true})(args...) where {Sig}
    return tuple_splat(frule!!, args)
end

@inline function (rule::PrimitiveFRule{Sig,false})(args...) where {Sig}
    return tuple_splat_noinline(frule!!, args)
end

"""
    build_primitive_rrule(sig::Type{<:Tuple}; maybeinline_primitive=true)

Construct an rrule for signature `sig`. For this function to be called in `build_rrule`, you
must also ensure that a method of `_is_primitive(context_type, ReverseMode, sig)` exists,
preferably by using the [@is_primitive](@ref) macro.

The callable returned by this must obey the rrule interface, but there are no restrictions
on the type of callable itself. For example, you might return a callable `struct`. By
default, this function returns a wrapper that may inline the primitive rule; set
`maybeinline_primitive=false` to force a noinline boundary (useful for higher-order AD).

# Extended Help

The purpose of this function is to permit computation at rule construction time, which can
be re-used at runtime. For example, you might wish to derive some information from `sig`
which you use at runtime (e.g. the fdata type of one of the arguments). While constant
propagation will often optimise this kind of computation away, it will sometimes fail to do
so in hard-to-predict circumstances. Consequently, if you need certain computations not to
happen at runtime in order to guarantee good performance, you might wish to e.g. emit a
callable `struct` with type parameters which are the result of this computation. In this
context, the motivation for using this function is the same as that of using staged
programming (e.g. via `@generated` functions) more generally.
"""
function build_primitive_rrule(sig::Type{<:Tuple}; maybeinline_primitive::Bool=true)
    PrimitiveRRule{_primitive_rule_sig(sig),maybeinline_primitive}()
end

function build_primitive_rrule(sig::Type{<:Tuple}, maybeinline_primitive::Bool)
    PrimitiveRRule{_primitive_rule_sig(sig),maybeinline_primitive}()
end

"""
    build_primitive_frule(sig::Type{<:Tuple}; maybeinline_primitive=true)

Construct an frule for signature `sig`. For this function to be called in `build_frule`, you
must also ensure that a method of `_is_primitive(context_type, ForwardMode, sig)` exists,
preferably by using the [@is_primitive](@ref) macro.

By default, this function returns a wrapper that may inline the primitive rule; set
`maybeinline_primitive=false` to force a noinline boundary (useful for higher-order AD,
e.g., forward-over-reverse for Hessians).

See [`build_primitive_rrule`](@ref) for extended discussion of staged rule construction.
"""
function build_primitive_frule(sig::Type{<:Tuple}; maybeinline_primitive::Bool=true)
    PrimitiveFRule{_primitive_rule_sig(sig),maybeinline_primitive}()
end

function build_primitive_frule(sig::Type{<:Tuple}, maybeinline_primitive::Bool)
    PrimitiveFRule{_primitive_rule_sig(sig),maybeinline_primitive}()
end
```
from #962 (comment)
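The dispatch-on-type-parameter pattern in the excerpt can be illustrated without Mooncake. This is a minimal, self-contained sketch under stated assumptions: `my_rule` stands in for `rrule!!`, and `splat_inline` / `splat_noinline` are hypothetical stand-ins for the `tuple_splat` helpers, whose actual implementation may differ. `SketchRule`, `normalise_sig`, `signature` and `build_sketch_rule` are likewise illustrative names. The sketch shows how the `Bool` parameter selects the inlinable or noinline path at compile time, and how a value computed at construction time (the normalised signature) is stored as a type parameter, in the spirit of the "Extended Help" text.

```julia
# Hypothetical sketch mirroring the shape of `PrimitiveRRule` / `build_primitive_rrule`.
# None of these names are Mooncake API.

my_rule(x, y) = (x * y, x + y)  # stand-in for the primitive rule being wrapped

@inline splat_inline(f, args::Tuple) = f(args...)      # compiler may inline this call
@noinline splat_noinline(f, args::Tuple) = f(args...)  # forces a call boundary

# Same normalisation as `_primitive_rule_sig`: keep concrete signatures, otherwise
# collapse to a tuple of `Any` with the same number of parameters.
function normalise_sig(sig::Type{<:Tuple})
    if sig isa DataType && isconcretetype(sig)
        return sig
    end
    nparams = length(Base.unwrap_unionall(sig).parameters)
    return Tuple{Vararg{Any,nparams}}
end

struct SketchRule{Sig,Tmaybeinline} end

# The Bool type parameter selects the method at compile time; there is no runtime branch.
@inline (::SketchRule{Sig,true})(args...) where {Sig} = splat_inline(my_rule, args)
@inline (::SketchRule{Sig,false})(args...) where {Sig} = splat_noinline(my_rule, args)

# Staged computation: `Sig` is computed once at construction and lives in the type,
# so recovering it later needs no recomputation.
signature(::SketchRule{Sig}) where {Sig} = Sig

function build_sketch_rule(sig::Type{<:Tuple}; maybeinline::Bool=true)
    return SketchRule{normalise_sig(sig),maybeinline}()
end

inline_rule = build_sketch_rule(Tuple{Float64,Float64})
boundary_rule = build_sketch_rule(Tuple{Real,Real}; maybeinline=false)
```

Because `Tmaybeinline` is part of the type, `build_sketch_rule(...; maybeinline=false)` yields a distinct type whose call method always goes through the `@noinline` helper. Whether such a boundary helps a given higher-order AD setup is the question this issue is about, so treat the sketch as an illustration of the mechanism only.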