Technical background
In Jax's JIT on-the-fly compilation, the Shape of every Tensor is tracked. If any Tensor in the computation has a dynamic Shape (a Shape whose size depends on the input data), then Jax's JIT cannot be used for compilation optimization. The most common such operation is jnp.where: it returns the indices that satisfy a given condition, and the length of that index output generally differs from input to input, so the operation cannot be compiled by Jax's JIT. To be precise, this operation has two usages. One is
jnp.where(condition, 1, 0)
which masks the input directly. The other usage is
jnp.where(condition)
which returns a sequence of indices; this is the application scenario we need to discuss.
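To make the difference concrete, here is a minimal sketch (the small array x below is made up for illustration and is not part of the original example):

import jax.numpy as jnp

x = jnp.array([0.1, 0.5, 0.3])

# Usage 1: three-argument form; the output Shape equals the input Shape (static).
print(jnp.where(x <= 0.3, 1, 0))   # [1 0 1]

# Usage 2: single-argument form; it returns the matching indices (here 0 and 2),
# so the output length depends on the values in x (dynamic Shape).
print(jnp.where(x <= 0.3))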
Normal mode
For testing purposes, we consider a simple Toy Model:
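Concretely, reading the model off the code below (writing the cutoff as \(r_c\); both indices run over all particles, so the \(i=j\) terms are included), the quantity computed is

\[E = \sum_{i,j\,:\,|r_i - r_j| \le r_c} q_i q_j\]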
In scenarios where on-the-fly compilation is not used, the Jax code can be written like this:
import os
os.environ['XLA_PYTHON_CLIENT_PREALLOCATE'] = 'false'

import numpy as np
np.random.seed(0)

import jax
from jax import numpy as jnp

def func(r, q, cutoff=0.2):
    # Pairwise distance matrix of Shape (N, N)
    dis = jnp.abs(r[:, None] - r[None])
    # Indices of all pairs within the cutoff -- a dynamic Shape
    maski, maskj = jnp.where(dis <= cutoff)
    qi = q[maski]
    qj = q[maskj]
    return jnp.sum(qi * qj)

N = 100
r = jnp.array(np.random.random(N), jnp.float32)
q = jnp.array(np.random.random(N), jnp.float32)

print(func(r, q))
# 1035.7422
Keep this output in mind: since the random seed is fixed, we can compare it directly with the JIT output in a moment.
JIT mode
There are three common ways to use Jax's JIT mode: adding the jax.jit decorator above the function definition, wrapping a function reference via jax.jit(function), and combining it with functools.partial (a partial function). None of them is complicated. A minimal sketch of all three call styles is shown below, after which we demonstrate on-the-fly compilation on our toy model using the decorator form.
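As a quick reference, here is a sketch of the three forms (the trivial functions f1, f2 and f3 are placeholders, not part of the original example):

import jax
from jax import numpy as jnp
from functools import partial

# Form 1: decorator
@jax.jit
def f1(x):
    return x + 1

# Form 2: wrap an existing function reference
def f2(x):
    return x + 1
f2_jit = jax.jit(f2)

# Form 3: combine with functools.partial, e.g. to pass extra options to jit
@partial(jax.jit, static_argnames='n')
def f3(x, n=2):
    return x * n

print(f1(jnp.ones(3)), f2_jit(jnp.ones(3)), f3(jnp.ones(3)))

With those forms in mind, here is the decorator-based version of our toy model: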
import os
os.environ['XLA_PYTHON_CLIENT_PREALLOCATE'] = 'false'

import numpy as np
np.random.seed(0)

import jax
from jax import numpy as jnp

@jax.jit
def func(r, q, cutoff=0.2):
    dis = jnp.abs(r[:, None] - r[None])
    maski, maskj = jnp.where(dis <= cutoff)
    qi = q[maski]
    qj = q[maskj]
    return jnp.sum(qi * qj)

N = 100
r = jnp.array(np.random.random(N), jnp.float32)
q = jnp.array(np.random.random(N), jnp.float32)

print(func(r, q))
As stated earlier, because the output of jnp.where has a dynamic Shape, an error is reported during the compilation phase. The error message is as follows:
Traceback (most recent call last):
File "/home/dechin/projects/gitee/dechin/tests/jax_mask.py", line 21, in <module>
print (func(r, q))
File "/home/dechin/anaconda3/envs/jax/lib/python3.10/site-packages/jax/_src/traceback_util.py", line 162, in reraise_with_filtered_traceback
return fun(*args, **kwargs)
File "/home/dechin/anaconda3/envs/jax/lib/python3.10/site-packages/jax/_src/", line 622, in cache_miss
execute = dispatch._xla_call_impl_lazy(fun_, *tracers, **params)
File "/home/dechin/anaconda3/envs/jax/lib/python3.10/site-packages/jax/_src/", line 236, in _xla_call_impl_lazy
return xla_callable(fun, device, backend, name, donated_invars, keep_unused,
File "/home/dechin/anaconda3/envs/jax/lib/python3.10/site-packages/jax/linear_util.py", line 303, in memoized_fun
ans = call(fun, *args)
File "/home/dechin/anaconda3/envs/jax/lib/python3.10/site-packages/jax/_src/", line 359, in _xla_callable_uncached
return lower_xla_callable(fun, device, backend, name, donated_invars, False,
File "/home/dechin/anaconda3/envs/jax/lib/python3.10/site-packages/jax/_src/", line 314, in wrapper
return func(*args, **kwargs)
File "/home/dechin/anaconda3/envs/jax/lib/python3.10/site-packages/jax/_src/", line 445, in lower_xla_callable
jaxpr, out_type, consts = pe.trace_to_jaxpr_final2(
File "/home/dechin/anaconda3/envs/jax/lib/python3.10/site-packages/jax/_src/", line 314, in wrapper
return func(*args, **kwargs)
File "/home/dechin/anaconda3/envs/jax/lib/python3.10/site-packages/jax/interpreters/partial_eval.py", line 2077, in trace_to_jaxpr_final2
jaxpr, out_type, consts = trace_to_subjaxpr_dynamic2(fun, main, debug_info)
File "/home/dechin/anaconda3/envs/jax/lib/python3.10/site-packages/jax/interpreters/partial_eval.py", line 2027, in trace_to_subjaxpr_dynamic2
ans = fun.call_wrapped(*in_tracers_)
File "/home/dechin/anaconda3/envs/jax/lib/python3.10/site-packages/jax/linear_util.py", line 167, in call_wrapped
ans = self.f(*args, **dict(self.params, **kwargs))
File "/home/dechin/projects/gitee/dechin/tests/jax_mask.py", line 12, in func
maski, maskj = jnp.where(dis<=cutoff)
File "/home/dechin/anaconda3/envs/jax/lib/python3.10/site-packages/jax/_src/numpy/lax_numpy.py", line 1077, in where
return nonzero(condition, size=size, fill_value=fill_value)
File "/home/dechin/anaconda3/envs/jax/lib/python3.10/site-packages/jax/_src/numpy/lax_numpy.py", line 1332, in nonzero
size = core.concrete_or_error(operator.index, size,
File "/home/dechin/anaconda3/envs/jax/lib/python3.10/site-packages/jax/", line 1278, in concrete_or_error
raise ConcretizationTypeError(val, context)
jax._src.traceback_util.UnfilteredStackTrace: jax._src.errors.ConcretizationTypeError: Abstract tracer value encountered where concrete value is expected: Traced<ShapedArray(int32[])>with<DynamicJaxprTrace(level=0/1)>
The size argument of jnp.nonzero must be statically specified to use jnp.nonzero within JAX transformations.
The error occurred while tracing the function func at /home/dechin/projects/gitee/dechin/tests/jax_mask.py:9 for jit. This concrete value was not available in Python because it depends on the value of the argument 'r'.
See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.ConcretizationTypeError
The stack trace below excludes JAX-internal frames.
The preceding is the original exception that occurred, unmodified.
--------------------
The above exception was the direct cause of the following exception:
Traceback (most recent call last):
File "/home/dechin/projects/gitee/dechin/tests/jax_mask.py", line 21, in <module>
print (func(r, q))
File "/home/dechin/projects/gitee/dechin/tests/jax_mask.py", line 12, in func
maski, maskj = jnp.where(dis<=cutoff)
File "/home/dechin/anaconda3/envs/jax/lib/python3.10/site-packages/jax/_src/numpy/lax_numpy.py", line 1077, in where
return nonzero(condition, size=size, fill_value=fill_value)
File "/home/dechin/anaconda3/envs/jax/lib/python3.10/site-packages/jax/_src/numpy/lax_numpy.py", line 1332, in nonzero
size = core.concrete_or_error(operator.index, size,
jax._src.errors.ConcretizationTypeError: Abstract tracer value encountered where concrete value is expected: Traced<ShapedArray(int32[])>with<DynamicJaxprTrace(level=0/1)>
The size argument of jnp.nonzero must be statically specified to use jnp.nonzero within JAX transformations.
The error occurred while tracing the function func at /home/dechin/projects/gitee/dechin/tests/jax_mask.py:9 for jit. This concrete value was not available in Python because it depends on the value of the argument 'r'.
See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.ConcretizationTypeError
To avoid this error, you can either leave the function uncompiled (sacrificing performance), write your own CUDA operator (increasing the workload), or use the fixed-length-output NonZero method adopted here (which requires some preprocessing of the inputs).
Use of NonZero
There is one thing to keep in mind when using Jax's NonZero function: although NonZero supports fixed-length output, that fixed length is itself passed in through a function argument named size, which means the output Shape of NonZero would still depend on an input parameter. Jax was designed with this in mind, and it provides a mechanism for marking arguments as static at compile time: static_argnames. In our case, we declare the argument named size as static, so that Jax's on-the-fly compilation can be used:
import os
os.environ['XLA_PYTHON_CLIENT_PREALLOCATE'] = 'false'

import numpy as np
np.random.seed(0)

import jax
from jax import numpy as jnp
from functools import partial

@partial(jax.jit, static_argnames='size')
def func(r, q, cutoff=0.2, size=5000):
    if q.shape[0] != r.shape[0] + 1:
        raise ValueError("The q.shape[0] should be equal to r.shape[0]+1")
    dis = jnp.abs(r[:, None] - r[None])
    # Fixed-length output: the index arrays are padded to the given size with -1
    maski, maskj = jnp.nonzero(jnp.where(dis <= cutoff, 1, 0), size=size, fill_value=-1)
    qi = q[maski]
    qj = q[maskj]
    return jnp.sum(qi * qj)

N = 100
r = jnp.array(np.random.random(N), jnp.float32)
q = jnp.array(np.random.random(N), jnp.float32)
# Append a zero to q so that the padded index -1 picks up a value of 0
pader = jnp.array([0.], jnp.float32)
q = jnp.append(q, pader)

print(func(r, q))
# 1035.7422
As you can see, the function is successfully compiled by Jax's JIT, and the output is consistent with the earlier uncompiled result. There is also a small trick used here: when the NonZero output is shorter than the requested length, it is automatically padded up to that length, with the padding value given by fill_value. Since NonZero outputs indices, we can set the padded indices to -1 and then append a 0 at the end of the parameter \(q\), which ensures that the computed output is directly correct.
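To see the trick in isolation, here is a minimal sketch (the small arrays are made up for illustration):

import jax.numpy as jnp

# With a static size, jnp.nonzero pads the missing indices with fill_value.
idx, = jnp.nonzero(jnp.array([1, 0, 1, 0]), size=6, fill_value=-1)
print(idx)   # [ 0  2 -1 -1 -1 -1]

# Index -1 refers to the last element, so appending a 0 to q makes every
# padded entry contribute nothing to the final sum.
q = jnp.append(jnp.array([0.5, 0.6, 0.7, 0.8]), 0.)
print(q[idx])            # [0.5 0.7 0.  0.  0.  0. ]
print(jnp.sum(q[idx]))   # 1.2 = 0.5 + 0.7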
Summary
When using Jax, we sometimes encounter functions whose output has a dynamic Shape. In such cases it is difficult to take advantage of Jax's on-the-fly compilation feature, so performance cannot be maximized; this is one of the trade-offs of computing with the Tensor data structure. This article introduced the Jax function NonZero, whose fixed-length output mode allows us to compile functions that would otherwise produce dynamically shaped results.
Copyright statement
The original link of this article: /dechinphy/p/
Author ID: DechinPhy
More original articles: /dechinphy/
Buy the blogger a coffee: /dechinphy/gallery/image/