State is passed explicitly to functions; functions should not have side effects in JAX.
vmap is a vectorizing transform: it takes your function, written for a single input, and executes it on a whole batch of inputs in parallel.
jax.lax.map takes batch_size as an argument if you can't process the entire batch at once (it runs vmap over chunks sequentially).
jax.debug.print for printing runtime values, normal print for trace-time calls, and chex assertions or equinox.error_if for runtime error checks (sketch below).
There are also host callbacks: https://docs.jax.dev/en/latest/external-callbacks.html
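A minimal sketch of the above (illustrative toy functions; batch_size on jax.lax.map requires a recent JAX version):

import jax
import jax.numpy as jnp

# Pure: state goes in and comes out explicitly, no globals are mutated
def step(state, x):
  print("trace-time print, state is abstract:", state)  # fires once, during tracing
  jax.debug.print("runtime value: {}", state + x)        # fires on every call
  return state + x

jitted_step = jax.jit(step)
jitted_step(0.0, 1.0)  # both prints fire (first call traces, then runs)
jitted_step(0.0, 2.0)  # only debug.print fires (already compiled)

def f(x):  # written for a single input...
  return x ** 2 + 1.0

xs = jnp.arange(8.0)
ys = jax.vmap(f)(xs)                           # ...executed on the whole batch in parallel
ys_chunked = jax.lax.map(f, xs, batch_size=4)  # vmapped chunks of 4, run sequentially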

We can't JIT everything: JAX traces JITed functions with abstract values that carry only shape and dtype (so it rarely has to recompile), which means value-dependent control flow raises errors.
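A minimal failing example (in recent JAX versions the error is a TracerBoolConversionError):

@jax.jit
def sign_flip(x):
  if x > 0:   # x is an abstract tracer here; Python can't branch on its value
    return x
  return -x

# sign_flip(1.0)  # raises: the truth value of a traced array is unknown at trace time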
Easy workaround (and most of the time a sufficient solution): JIT the expensive parts, e.g. while some_condition: do_expensive_jitted_thing(x), instead of trying to cram the loop into the JIT as well.
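Sketched out (do_expensive_jitted_thing and the stopping condition are hypothetical placeholders):

@jax.jit
def do_expensive_jitted_thing(x):
  return jnp.tanh(x @ x.T)  # stand-in for the real work

x = jnp.eye(512)
while float(jnp.abs(x).max()) > 1e-3:  # value-dependent condition stays in Python
  x = do_expensive_jitted_thing(x)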
Though JAX does have functions that allow branching / dynamic predicates under JIT: jax.lax.{cond,select,switch} and jnp.{where,piecewise,select}, plus loop primitives: jax.lax.{scan,fori_loop,while_loop}. A sketch follows below.
cond is for a boolean condition
select is tensorized, like where
switch chooses one of several branches (functions) via an index
jnp.logical_{and,or} do not short-circuit (both operands are always evaluated) but can be used with jit
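A minimal sketch of the branching primitives (toy computation, all jit-compatible):

@jax.jit
def branchy(x, idx):
  a = jax.lax.cond(x.sum() > 0, jnp.sin, jnp.cos, x)        # boolean condition, two branches
  b = jnp.where(x > 0, x, -x)                               # elementwise select, like lax.select
  c = jax.lax.switch(idx, [jnp.sin, jnp.cos, jnp.tanh], x)  # pick a branch by index
  return a + b + c

print(branchy(jnp.arange(3.0), 1))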

| Construct | jit | grad |
|---|---|---|
| if | ✗ | ✓ |
| for | ✓* | ✓ |
| while | ✓* | ✓ |
| lax.cond | ✓ | ✓ |
| lax.while_loop | ✓ | fwd |
| lax.fori_loop | ✓ | fwd |
| lax.scan | ✓ | ✓ |

* = argument-value-independent loop condition (the loop is unrolled at trace time)

But there's no need (or use) in trying to JIT every Python loop.
You can also JIT with specific arguments treated as static (static_argnames=["myarg"]), but your function will be recompiled for each new value of that argument, so this only makes sense for infrequently changing values.
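For example (hypothetical function; the loop over n unrolls because n is a concrete Python int under static_argnames):

from functools import partial

@partial(jax.jit, static_argnames=["n"])
def power(x, n):
  for _ in range(n):   # n is concrete here, so this Python loop simply unrolls
    x = x * x
  return x

power(2.0, 3)  # compiles for n=3
power(3.0, 3)  # reuses the n=3 compilation
power(2.0, 4)  # recompiles for n=4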

Comparison of different loop types:

| Primitive | Memory (forward) | Memory (backward) | Differentiable? | Early stop? |
|---|---|---|---|---|
| fori_loop | O(1) | N/A | Fwd only* | No |
| while_loop | O(1) | N/A | Fwd only | Yes |
| scan | O(n) | O(n) | Both | No |
| Python loop | O(n) traced ops | O(n) traced ops | Both | Yes |

* With static bounds (start and stop known at compile time), fori_loop unrolls and becomes fully differentiable, but uses O(n) memory.
import jax
import jax.numpy as jnp

# Python loop: very slow to trace, creates new JAX operations each iteration
def python_loop(x, n=1000):
  carry = 0.0
  for i in range(n):
    carry = carry * x + 1  # each iteration adds new ops to the trace
  return carry

# scan: fast, single XLA compilation, but O(n) memory for gradients
def scan_version(x, n=1000):
  def body(carry, _):
    return carry * x + 1, None
  return jax.lax.scan(body, 0.0, jnp.arange(n))[0]

# fori_loop: fast, single XLA compilation, O(1) memory, but forward-mode gradients
# only with traced bounds (with static bounds it unrolls, see the footnote above)
def fori_version(x, n=1000):
  def body(i, carry):
    return carry * x + 1
  return jax.lax.fori_loop(0, n, body, 0.0)
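A quick check of the differentiability claims (scalar input, so jax.grad applies directly):

x = jnp.float32(0.999)
print(jax.grad(scan_version)(x))    # reverse mode through scan: fine
print(jax.jacfwd(fori_version)(x))  # forward mode through fori_loop: fine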

Checkpoint / Rematerialize

By default, intermediates are stored for reverse-mode autodiff: every operation's output is kept in memory until the backward pass completes (so for a function y = x + 2; z = y + 2, both y and z are stored).
If you decorate with @jax.checkpoint (aliased as @jax.remat, which is arguably the clearer name), JAX rematerializes intermediates during the backward pass instead of storing them.
So if you run OOM or have large loops/scans, checkpoint the bodies of those loops in the cases where JIT doesn't do it for you:

When differentiated functions are staged out to XLA for compilation — for example by applying jax.jit() to a function which contains a jax.grad() call — XLA will automatically optimize the computation, including decisions about when to compute or rematerialize values. As a result, jax.checkpoint() often isn’t needed for differentiated functions under a jax.jit(). XLA will optimize things for you.

One exception is when using staged-out control flow, like jax.lax.scan(). Automatic compiler optimizations across multiple control flow primitives (for example, across a forward-pass scan and the corresponding backward-pass scan), typically aren’t as thorough. As a result, it’s often a good idea to use jax.checkpoint() on the body function passed to jax.lax.scan().
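A minimal sketch of that advice (toy body function; jax.checkpoint wraps the scan body so the backward pass recomputes per-step intermediates instead of storing all of them):

def body(carry, x):
  return jnp.tanh(carry * x + 1.0), None  # stand-in for an expensive step

def loss(c0, xs):
  final, _ = jax.lax.scan(jax.checkpoint(body), c0, xs)
  return final

xs = jnp.linspace(0.0, 1.0, 1000)
g = jax.grad(loss)(jnp.float32(0.1), xs)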