State is passed explicitly to functions; functions should not have side effects in JAX.
vmap is a vectorizing transform, i.e. it takes your function, written for a single input, and executes it over a whole batch of inputs in parallel
jax.lax.map takes batch_size as an argument if you can't process the entire batch at once (it applies vmap to chunks of that size, sequentially)
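A minimal sketch of both (the function, shapes, and batch_size are made up; batch_size in jax.lax.map assumes a recent JAX version):

```python
import jax
import jax.numpy as jnp

def affine(x):                     # written for a single input vector
    return jnp.dot(x, jnp.ones(4)) + 1.0

xs = jnp.zeros((32, 4))            # a batch of 32 inputs
ys = jax.vmap(affine)(xs)          # vectorized over the leading axis -> shape (32,)
ys_chunked = jax.lax.map(affine, xs, batch_size=8)  # vmap within chunks of 8, run sequentially
```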
jax.debug.print for printing runtime values, normal Python prints for seeing trace calls (they only fire at trace time), and chex assertions or equinox's eqx.error_if for runtime error checks.
There’s also host callbacks https://docs.jax.dev/en/latest/external-callbacks.html
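Quick illustration of the print difference (toy function):

```python
import jax
import jax.numpy as jnp

@jax.jit
def f(x):
    print("tracing f, x =", x)               # runs at trace time, prints the abstract tracer
    jax.debug.print("runtime value: {}", x)  # runs on every call, prints the concrete value
    return x * 2

f(jnp.arange(3))  # first call: both prints fire (function gets traced/compiled)
f(jnp.arange(3))  # second call: only debug.print fires (compilation is cached)
```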
The key is effectively a stand-in for NumPy's hidden state object, but we pass it explicitly to jax.random functions. Importantly, random functions consume the key, but do not modify it: feeding the same key object to a random function will always result in the same sample being generated.
from jax import random
key = random.key(0)  # any seed works; the two prints below always match because the key is unchanged

print(random.normal(key))
print(random.normal(key))
-0.028304616
-0.028304616

Note: if you print the repr it shows a rounded precision…
for i in range(3):
    key, subkey = random.split(key)  # The old key is consumed by split() -- we must never use it again.
    val = random.normal(subkey)      # The subkey is consumed by normal() -- we must never use it again.
    print(f"draw {i}: {val}")
jax.random.split() is a deterministic function that converts one key into several independent (in the pseudorandomness sense) keys. We keep one of the outputs as the new_key, and can safely use the unique extra key (called subkey) as input into a random function, and then discard it forever. If you wanted to get another sample from the normal distribution, you would split key again, and so on: the crucial point is that you never use the same key twice.
Calling the outputs key, subkey is a matter of convention (where subkey is used for immediate consumption, key for future splitting). They are just two new random keys:
>>> random.split(key)
Array((2,), dtype=key<fry>) overlaying:
[[1832780943 270669613]
[ 64467757 2916123636]]
>>> random.split(key)
Array((2,), dtype=key<fry>) overlaying:
[[1832780943 270669613]
[ 64467757 2916123636]]

NOTE: No sequential equivalence: using n subkeys to generate n samples sequentially produces a different result than sampling normal with a single key and shape (n,).
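For example (seed and shape arbitrary), these two do not give the same numbers:

```python
from jax import random
import jax.numpy as jnp

key = random.key(0)

# n samples drawn sequentially from n subkeys
subkeys = random.split(key, 3)
sequential = jnp.stack([random.normal(k) for k in subkeys])

# n samples drawn at once from the single key
batched = random.normal(key, shape=(3,))

print(jnp.allclose(sequential, batched))  # False (in general)
```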
JIT related stuff:
We can't JIT everything: JAX traces JITed functions with abstract values that only carry shape and dtype (so it rarely has to recompile), which means Python control flow that depends on runtime values raises errors.
Easy workaround (and most of the time a sufficient solution): JIT only the expensive parts, e.g. while some_condition: do_expensive_jitted_thing(x); and_then_some_logging(), instead of trying to cram the loop into the JIT here too.
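Sketch of the failure mode and the workaround (function names are made up):

```python
import jax
import jax.numpy as jnp

@jax.jit
def bad(x):
    if x.sum() > 0:   # fails under jit: Python `if` on a traced (abstract) value
        return x * 2
    return x

@jax.jit
def expensive_step(x):            # JIT only the hot computation
    return x @ x.T + 1.0

def run(x, threshold=100.0):
    while float(x.sum()) < threshold:   # plain Python control flow, outside jit
        x = expensive_step(x)
        print("step done")              # logging also stays outside
    return x
```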
Though JAX has functions that allow branching / dynamic predicates with JIT: jax.lax.{cond,select,switch}, jnp.{where,piecewise,select}, or for loops: jax.lax.{scan,fori_loop,while_loop} (small example after the list below).
cond is for a boolean condition
select is tensorized, like where
switch chooses one of several branches (functions) via an index
jnp.logical_{and,or} do not short-circuit (both operands are always evaluated) but can be used with jit
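A jittable sketch of the branching primitives (values and branch functions are arbitrary):

```python
import jax
import jax.numpy as jnp
from jax import lax

@jax.jit
def f(x, branch_idx):
    # cond: choose one of two functions from a boolean; only the taken branch runs
    y = lax.cond(x.sum() > 0, lambda v: v * 2, lambda v: v - 1, x)
    # where/select: elementwise choice; both sides are computed
    y = jnp.where(y > 0, y, 0.0)
    # switch: choose one of several functions via an integer index
    return lax.switch(branch_idx, [jnp.sin, jnp.cos, jnp.tanh], y)

f(jnp.arange(3.0), 1)
```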
| Construct | jit | grad |
|---|---|---|
| if | ❌ | ✓ |
| for | ✓* | ✓ |
| while | ✓* | ✓ |
| lax.cond | ✓ | ✓ |
| lax.while_loop | ✓ | fwd |
| lax.fori_loop | ✓ | fwd |
| lax.scan | ✓ | ✓ |
* = argument-value-independent loop condition - unrolls the loop |
But there's no need (or use) in trying to JIT every Python loop.
You can also JIT with arguments marked static (static_argnames=["myarg"]), but your function will be recompiled for each new value of that argument, so this only makes sense for infrequently changing values.
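Sketch (names made up):

```python
import jax
import jax.numpy as jnp
from functools import partial

@partial(jax.jit, static_argnames=["n_steps"])
def rollout(x, n_steps):
    for _ in range(n_steps):   # n_steps is static, so this Python loop unrolls at trace time
        x = x * 0.5 + 1.0
    return x

rollout(jnp.ones(3), 10)   # compiles for n_steps=10
rollout(jnp.ones(3), 10)   # reuses the cached compilation
rollout(jnp.ones(3), 20)   # recompiles for n_steps=20
```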
Comparison of different loop types:
| Primitive | Memory (forward) | Memory (backward) | Differentiable? | Early stop? |
|---|---|---|---|---|
| fori_loop | O(1) | N/A | Fwd only* | No |
| while_loop | O(1) | N/A | Fwd only | Yes |
| scan | O(n) | O(n) | Both | No |
| Python loop | O(n) traced ops | O(n) traced ops | Both | Yes |
* With static bounds (start, stop known at compile time), fori_loop unrolls and becomes fully differentiable but uses O(n) memory |
import jax
import jax.numpy as jnp

# Python loop: very slow to trace, creates new JAX operations each iteration
def python_loop(x, n=1000):
    carry = 0.0
    for i in range(n):
        carry = carry * x + 1  # each iteration adds new ops to the trace
    return carry

# scan: fast, single XLA compilation, but O(n) memory for gradients
def scan_version(x, n=1000):
    def body(carry, _):
        return carry * x + 1, None
    carry0 = jnp.zeros_like(x)  # carry must keep a fixed shape/dtype across iterations
    return jax.lax.scan(body, carry0, None, length=n)[0]

# fori_loop: fast, single XLA compilation, O(1) memory, but no reverse-mode gradients (fwd only)
def fori_version(x, n=1000):
    def body(i, carry):
        return carry * x + 1
    return jax.lax.fori_loop(0, n, body, jnp.zeros_like(x))

Checkpoint / Rematerialize
By default intermediates are stored for reverse-mode autodiff. Every operation's output is kept in memory until the bwd pass completes (so for a function y = x + 2; z = y + 2, both y and z are stored).
If you decorate with @checkpoint (or its alias @remat, which arguably makes more sense as a name) JAX rematerializes instead of storing intermediates.
So if you run OOM / have large loops or scans, checkpoint the bodies of those loops, in the cases where JIT doesn't do it for you: When differentiated functions are staged out to XLA for compilation (for example by applying jax.jit() to a function which contains a jax.grad() call), XLA will automatically optimize the computation, including decisions about when to compute or rematerialize values. As a result, jax.checkpoint() often isn't needed for differentiated functions under a jax.jit(); XLA will optimize things for you.
One exception is when using staged-out control flow, like jax.lax.scan(). Automatic compiler optimizations across multiple control flow primitives (for example, across a forward-pass scan and the corresponding backward-pass scan) typically aren't as thorough. As a result, it's often a good idea to use jax.checkpoint() on the body function passed to jax.lax.scan().
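A sketch of that scan case (toy step function; the point is wrapping the scan body in jax.checkpoint):

```python
import jax
import jax.numpy as jnp

def step(carry, x):
    # pretend this is an expensive per-step computation
    carry = jnp.tanh(carry * x + 1.0)
    return carry, carry

@jax.jit
def loss(init, xs):
    # rematerialize each step's intermediates in the backward pass instead of storing them all
    final, ys = jax.lax.scan(jax.checkpoint(step), init, xs)
    return jnp.sum(ys)

grads = jax.grad(loss)(jnp.float32(0.1), jnp.linspace(0.0, 1.0, 16))
```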
Resources
some faqs: https://chatgpt.com/c/68c46b6c-29c0-832f-afdf-4de0a044ffc7
https://github.com/luchris429/purejaxrl/tree/main (uses flax, very nice reference)
https://github.com/ponseko/jymkit/tree/main (uses equinox, considerably less stars, frameworky, but looks clean)
Transclude of visualization#^db74f2