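State is passed to functions explicitly: functions should be pure (no side effects) in JAX, because transformations like jit trace them once and then reuse the compiled result.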
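A minimal sketch of explicit state passing. The running-mean `update` function and its `(count, total)` state tuple are hypothetical, made up for illustration. The function takes the old state and returns the new one instead of mutating anything. The PRNG key is threaded through the same way.

```python
import jax
import jax.numpy as jnp

# Hypothetical pure update: takes the old state, returns (new_state, output); nothing is mutated.
def update(state, x):
    count, total = state
    new_state = (count + 1, total + x)
    return new_state, new_state[1] / new_state[0]  # running mean after this step

update_jit = jax.jit(update)
state = (jnp.array(0), jnp.array(0.0))
for x in [1.0, 2.0, 3.0]:
    state, mean = update_jit(state, jnp.array(x))  # the caller threads the state through

# Randomness works the same way: the PRNG key is explicit state that you split, not a global seed.
key = jax.random.PRNGKey(0)
key, subkey = jax.random.split(key)
sample = jax.random.normal(subkey, (3,))
```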
vmap is a vectorizing transform: it turns a function written for a single input into one that processes a whole batch of inputs at once, by adding a batch dimension to the underlying operations.
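A small sketch of vmap on a hypothetical `dot` function written for one pair of vectors. `in_axes` selects which arguments carry the batch dimension (`None` means the argument is shared across the batch).

```python
import jax
import jax.numpy as jnp

# Written for a single pair of vectors: returns a scalar.
def dot(a, b):
    return jnp.sum(a * b)

A = jnp.arange(12.0).reshape(4, 3)
B = jnp.ones((4, 3))

batched = jax.vmap(dot)(A, B)                         # maps over axis 0 of both inputs -> shape (4,)
shared = jax.vmap(dot, in_axes=(0, None))(A, jnp.ones(3))  # b is not batched, reused for every row
```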
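jax.lax.map applies a function sequentially over the leading axis (it is implemented with scan). If you can’t process the entire batch at once, pass batch_size: it then vmaps over chunks of that size and runs the chunks sequentially.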
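jax.debug.print prints runtime values. A normal print only runs at trace time, so it shows when (and with which abstract shapes/dtypes) a function is traced. For other runtime checks, use the error utilities from chex or equinox (e.g. equinox’s eqx.error_if).
There are also external (host) callbacks: https://docs.jax.dev/en/latest/external-callbacks.html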
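A minimal sketch of the difference between `print` and `jax.debug.print` under jit. The `trace_count` list is a hypothetical helper that only records how often tracing happens.

```python
import jax
import jax.numpy as jnp

trace_count = []

@jax.jit
def f(x):
    trace_count.append(1)                    # Python side effect: only happens while tracing
    print("tracing f, x =", x)               # prints a tracer (shape/dtype), not a value
    y = jnp.sin(x) * 2.0
    jax.debug.print("runtime y = {y}", y=y)  # prints the concrete value on every call
    return y

out1 = f(jnp.array(1.0))  # traces, compiles, runs
out2 = f(jnp.array(2.0))  # same shape/dtype: cached, only the debug.print fires
```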
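We can’t jit everything: under jit, JAX traces your function with abstract values (tracers) that only carry shape and dtype, not concrete values, so the compiled function can be reused and recompilation is rare. As a consequence, Python control flow that depends on argument values raises an error at trace time.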
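A sketch of the failure mode, using a hypothetical `relu_python_if`. The Python `if` needs a concrete bool, so jit raises a concretization error (TracerBoolConversionError, a subclass of ConcretizationTypeError). Expressing the choice as data flow with `jnp.where` works.

```python
import jax
import jax.numpy as jnp

@jax.jit
def relu_python_if(x):
    if x > 0:  # needs a concrete bool, but under jit x is an abstract tracer
        return x
    return jnp.zeros_like(x)

try:
    relu_python_if(jnp.array(1.0))
    raised = False
except jax.errors.ConcretizationTypeError:  # TracerBoolConversionError subclasses this
    raised = True

# The same value-dependent choice expressed as data flow instead of Python control flow.
relu_ok = jax.jit(lambda x: jnp.where(x > 0, x, 0.0))
```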
Easy workaround (and most of the time a sufficient one): jit only the expensive parts, e.g. while some_condition: do_expensive_jitted_thing(x), instead of trying to cram the loop into jit as well.
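A minimal sketch of that pattern. A hypothetical Newton iteration for sqrt(a) stands in for the expensive jitted step. The `while` loop stays in plain Python and reads a concrete float each iteration. That costs a device-to-host sync per step, but keeps the jitted function free of value-dependent control flow.

```python
import jax
import jax.numpy as jnp

@jax.jit
def newton_step(x, a):
    # Hypothetical "expensive" jitted step: one Newton iteration for sqrt(a).
    return 0.5 * (x + a / x)

def sqrt_newton(a, tol=1e-6):
    x = jnp.asarray(a, dtype=jnp.float32)
    delta = jnp.inf
    while delta > tol:                      # plain Python loop, never traced
        x_new = newton_step(x, a)
        delta = float(jnp.abs(x_new - x))   # concrete value pulled to the host
        x = x_new
    return x
```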
That said, JAX has functions that allow branching on dynamic predicates under jit, namely jax.lax.{cond,select,switch} and jnp.{where,piecewise,select}, and for loops: jax.lax.{scan,fori_loop,while_loop}.
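cond takes a scalar boolean predicate and runs one of two functions (under vmap with a batched predicate it is turned into a select, so both branches run)
select is elementwise (tensorized), like where: both inputs are always computed
switch chooses one of several branches (functions) via an integer index
jnp.logical_{and,or} do not short-circuit (both operands are always evaluated), unlike Python’s and/or, but they work on traced values under jit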
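A compact sketch of these constructs inside one jitted function. `branchy` is a hypothetical name. All branches of `cond`/`switch` must return the same shape and dtype.

```python
import jax
import jax.numpy as jnp

@jax.jit
def branchy(x, idx):
    # lax.cond: scalar boolean predicate, picks one of two functions.
    a = jax.lax.cond(x.sum() > 0, lambda v: v * 2.0, lambda v: -v, x)
    # lax.switch: integer index picks one of N functions.
    b = jax.lax.switch(idx, [jnp.sin, jnp.cos, jnp.tanh], x)
    # jnp.where / lax.select: elementwise choice, both inputs fully computed.
    c = jnp.where(x > 0, x, 0.0)
    e = jax.lax.select(x > 0, x, jnp.zeros_like(x))
    # jnp.logical_and: elementwise, no short-circuit (Python `and` would fail on tracers).
    d = jnp.logical_and(x > -1.0, x < 1.0)
    return a, b, c, d, e

x = jnp.array([-2.0, 0.5, 3.0])
a, b, c, d, e = branchy(x, 1)
```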
Construct | jit | grad |
---|---|---|
if | ❌ | ✓ |
for | ✓* | ✓ |
while | ✓* | ✓ |
lax.cond | ✓ | ✓ |
lax.while_loop | ✓ | fwd |
lax.fori_loop | ✓ | fwd |
lax.scan | ✓ | ✓ |
* = argument-value-independent loop condition; unrolls the loop
But there’s no need (or benefit) in trying to jit every Python loop.
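A sketch of the two loop rows of the table, with hypothetical function names. A Python `for` over a static `range` is fine under jit because it is unrolled at trace time. A value-dependent condition needs `lax.while_loop`, which jits but only supports forward-mode differentiation.

```python
import jax
import jax.numpy as jnp

# Python `for` over a static range: fine under jit, the body is unrolled 4 times while tracing.
@jax.jit
def poly_unrolled(x):
    acc = jnp.zeros_like(x)
    for _ in range(4):
        acc = acc * x + 1.0
    return acc

# Value-dependent loop condition: use lax.while_loop.
@jax.jit
def halvings_until_below(x, threshold):
    def cond_fun(state):
        value, _ = state
        return value >= threshold
    def body_fun(state):
        value, count = state
        return value / 2.0, count + 1
    _, count = jax.lax.while_loop(cond_fun, body_fun, (x, 0))
    return count
```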
You can also mark arguments as static (jax.jit(f, static_argnames=["myarg"])) so that their concrete Python values are visible during tracing. But your function will be recompiled for each new value (and static values must be hashable), so this only makes sense for infrequently changing values.
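A minimal sketch using a hypothetical `power` function. Because `n` is static, `range(n)` sees a concrete int. The `trace_log` list, a Python side effect used only for demonstration, shows that a new `n` triggers a retrace while a new `x` of the same shape/dtype does not.

```python
import functools
import jax
import jax.numpy as jnp

trace_log = []

@functools.partial(jax.jit, static_argnames=["n"])
def power(x, n):
    trace_log.append(n)        # runs only when (re)tracing
    result = jnp.ones_like(x)
    for _ in range(n):         # n is a concrete Python int here
        result = result * x
    return result

r1 = power(jnp.array(2.0), n=3)    # traces + compiles for n=3
r2 = power(jnp.array(5.0), n=3)    # cache hit: same n, same shape/dtype
r3 = power(jnp.array(2.0), n=10)   # new n -> retrace + recompile
```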
Comparison of different loop types:
Primitive | Memory (forward) | Memory (backward) | Differentiable? | Early stop? |
---|---|---|---|---|
fori_loop | O(1) | N/A | Fwd only* | No |
while_loop | O(1) | N/A | Fwd only | Yes |
scan | O(1) carry (+ O(n) if per-step outputs are stacked) | O(n) | Fwd + rev | No |
Python loop | O(n) traced ops | O(n) traced ops | Fwd + rev | Yes (concrete condition only) |
* With static bounds (start, stop known at trace time), fori_loop is lowered to scan, so it becomes reverse-mode differentiable but, like scan, stores O(n) residuals for the backward pass.
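```python
import jax
import jax.numpy as jnp

# Python loop: very slow. Under jit it is unrolled into n copies of the body,
# so trace/compile time and program size grow with n
def python_loop(x, n=1000):
    carry = jnp.zeros_like(x)
    for i in range(n):
        carry = carry * x + 1  # each iteration adds new JAX ops to the trace
    return carry

# scan: body traced once, single XLA loop; reverse-mode stores O(n) residuals
def scan_version(x, n=1000):
    def body(carry, _):
        return carry * x + 1, None
    return jax.lax.scan(body, jnp.zeros_like(x), None, length=n)[0]

# fori_loop: body traced once. With Python-int bounds (as here) it is lowered to scan
# and supports reverse-mode (O(n) memory); with traced bounds it becomes a
# while_loop: O(1) memory, but forward-mode only
def fori_version(x, n=1000):
    def body(i, carry):
        return carry * x + 1
    return jax.lax.fori_loop(0, n, body, jnp.zeros_like(x))
```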
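A sketch checking the "Differentiable?" column on a small version of the same recurrence (N = 5 is an arbitrary choice, and the function names are hypothetical). After N steps the carry is 1 + x + ... + x^(N-1), so at x = 2 the derivative is 1 + 2·2 + 3·4 + 4·8 = 49. scan and static-bound fori_loop support `jax.grad`. `while_loop` only supports forward mode, so `jax.jacfwd` works but `jax.grad` raises.

```python
import jax
import jax.numpy as jnp

N = 5

def scan_version(x):
    def body(carry, _):
        return carry * x + 1.0, None
    return jax.lax.scan(body, jnp.zeros_like(x), None, length=N)[0]

def fori_static(x):
    # Python-int bounds: lowered to scan, so reverse-mode works.
    return jax.lax.fori_loop(0, N, lambda i, c: c * x + 1.0, jnp.zeros_like(x))

def while_version(x):
    def body(state):
        i, c = state
        return i + 1, c * x + 1.0
    return jax.lax.while_loop(lambda s: s[0] < N, body, (0, jnp.zeros_like(x)))[1]

x = jnp.array(2.0)
g_scan = jax.grad(scan_version)(x)
g_fori = jax.grad(fori_static)(x)
g_while_fwd = jax.jacfwd(while_version)(x)  # forward mode through while_loop is fine

try:
    jax.grad(while_version)(x)               # reverse mode through while_loop is not supported
    while_grad_ok = True
except Exception:
    while_grad_ok = False
```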
Checkpoint / Rematerialize
By default, reverse-mode autodiff stores the intermediates (residuals) the backward pass needs, and they stay in memory until the bwd pass completes. For example, for y = sin(x); z = sin(y), values derived from both x and y (cos(x) and cos(y)) are kept.
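If you wrap a function with jax.checkpoint (jax.remat is an alias for the same function), JAX rematerializes (recomputes) those intermediates during the backward pass instead of storing them, trading extra compute for less memory.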
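A minimal sketch with a hypothetical `block` function. Checkpointing does not change the gradient, only what is stored versus recomputed. The analytic gradient of sin(sin(x)) is cos(sin(x))·cos(x).

```python
import jax
import jax.numpy as jnp

def block(x):
    # Two chained sin calls: normally the residuals for both are stored for the backward pass.
    return jnp.sin(jnp.sin(x))

# jax.checkpoint (alias jax.remat): recompute block's internals in the backward pass instead.
block_remat = jax.checkpoint(block)

x = jnp.linspace(0.0, 1.0, 5)
g_plain = jax.grad(lambda v: jnp.sum(block(v)))(x)
g_remat = jax.grad(lambda v: jnp.sum(block_remat(v)))(x)  # same values, different memory/compute trade-off
```

To inspect what actually gets saved, recent JAX versions provide jax.ad_checkpoint.print_saved_residuals(f, *args).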
So if you run out of memory, e.g. with large loops/scans, checkpoint the bodies of those loops in cases where jit doesn’t handle it for you. From the JAX gradient-checkpointing docs:
When differentiated functions are staged out to XLA for compilation (for example by applying jax.jit() to a function which contains a jax.grad() call), XLA will automatically optimize the computation, including decisions about when to compute or rematerialize values. As a result, jax.checkpoint() often isn’t needed for differentiated functions under a jax.jit(). XLA will optimize things for you.
One exception is when using staged-out control flow, like jax.lax.scan(). Automatic compiler optimizations across multiple control flow primitives (for example, across a forward-pass scan and the corresponding backward-pass scan), typically aren’t as thorough. As a result, it’s often a good idea to use jax.checkpoint() on the body function passed to jax.lax.scan().
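A sketch of checkpointing a scan body. The `step` function is a hypothetical "layer" applied once per stacked weight matrix, and the random weights are arbitrary test data. The gradients are identical either way. With `jax.checkpoint(step)`, the backward scan recomputes each step’s internals instead of keeping them.

```python
import jax
import jax.numpy as jnp

def step(carry, w):
    # Hypothetical per-step "layer": matmul + tanh.
    return jnp.tanh(carry @ w), None

def forward(ws, x, remat):
    body = jax.checkpoint(step) if remat else step
    out, _ = jax.lax.scan(body, x, ws)  # scans over the leading axis of ws
    return jnp.sum(out)

key = jax.random.PRNGKey(0)
ws = jax.random.normal(key, (8, 4, 4)) * 0.5  # 8 stacked 4x4 weight matrices
x = jnp.ones(4)

g_plain = jax.jit(jax.grad(lambda ws, x: forward(ws, x, remat=False)))(ws, x)
g_remat = jax.jit(jax.grad(lambda ws, x: forward(ws, x, remat=True)))(ws, x)
```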