chex
Chex: Testing made fun, in JAX!
Chex
Chex is a library of utilities for helping to write reliable JAX code.
This includes utilities that help you:
- Instrument your code (e.g. assertions, warnings).
- Debug (e.g. transforming pmaps into vmaps within a context manager).
- Test JAX code across many variants (e.g. jitted vs non-jitted).
Installation
You can install the latest released version of Chex from PyPI via:
pip install chex
or you can install the latest development version from GitHub:
pip install git+https://github.com/deepmind/chex.git
Modules Overview
Dataclass (dataclass.py)
Dataclasses are a popular construct introduced in Python 3.7 that make it easy to specify typed data structures with minimal boilerplate code. They are not, however, compatible with JAX and dm-tree out of the box.
In Chex we provide a JAX-friendly dataclass implementation that reuses Python dataclasses.
Chex's dataclass implementation registers dataclasses as internal PyTree nodes to ensure compatibility with JAX data structures.
In addition, we provide a class wrapper that exposes dataclasses as collections.abc.Mapping descendants. This lets dm-tree methods process them (e.g. flatten and unflatten them) like ordinary Python dictionaries.
See the @mappable_dataclass docstring for more details.
Example:
import chex
import jax
import jax.numpy as jnp
import tree  # dm-tree

@chex.dataclass
class Parameters:
  x: chex.ArrayDevice
  y: chex.ArrayDevice

parameters = Parameters(
    x=jnp.ones((2, 2)),
    y=jnp.ones((1, 2)),
)

# Dataclasses can be treated as JAX pytrees
jax.tree_util.tree_map(lambda x: 2.0 * x, parameters)

# and as mappings by dm-tree
tree.flatten(parameters)
NOTE: Unlike standard Python 3.7 dataclasses, Chex dataclasses cannot be constructed using positional arguments. They accept constructor arguments in the same format as the Python dict constructor. If necessary, dataclasses can be converted to and from tuples with the to_tuple and from_tuple methods.
parameters = Parameters(
    jnp.ones((2, 2)),
    jnp.ones((1, 2)),
)
# ValueError: Mappable dataclass constructor doesn't support positional args.
Assertions (asserts.py)
One limitation of PyType annotations for JAX is that they do not support the
specification of DeviceArray ranks, shapes or dtypes. Chex includes a number
of functions that allow flexible and concise specification of these properties.
E.g. suppose you want to ensure that all tensors t1, t2, t3 have the same
shape, and that tensors t4, t5 have rank 2 and (3 or 4), respectively.
chex.assert_equal_shape([t1, t2, t3])
chex.assert_rank([t4, t5], [2, {3, 4}])
More examples:
from chex import assert_shape, assert_rank, ...
assert_shape(x, (2, 3)) # x has shape (2, 3)
assert_shape([x, y], [(), (2,3)]) # x is scalar and y has shape (2, 3)
assert_rank(x, 0) # x is scalar
assert_rank([x, y], [0, 2]) # x is scalar and y is a rank-2 array
assert_rank([x, y], {0, 2}) # x and y are scalar OR rank-2 arrays
assert_type(x, int) # x has type `int` (x can be an array)
assert_type([x, y], [int, float]) # x has type `int` and y has type `float`
assert_equal_shape([x, y, z]) # x, y, and z have equal shapes
assert_trees_all_close(tree_x, tree_y) # values and structure of trees match
assert_tree_all_finite(tree_x) # all tree_x leaves are finite
assert_devices_available(2, 'gpu') # 2 GPUs available
assert_tpu_available() # at least 1 TPU available
assert_numerical_grads(f, (x, y), j) # f^{(j)}(x, y) matches numerical grads
See the asserts.py documentation to find all supported assertions.
If you cannot find a specific assertion, please consider making a pull request or opening an issue on the bug tracker.
Optional Arguments
All Chex assertions support the following optional kwargs for customizing the emitted exception messages:
- custom_message: a string to include in the emitted exception messages.
- include_default_message: whether to include the default Chex message in the emitted exception messages.
- exception_type: the exception type to raise (AssertionError by default).
For example, the following code:
dataset = load_dataset()
params = init_params()
for i in range(num_steps):
  params = update_params(params, dataset.sample())
  chex.assert_tree_all_finite(params,
                              custom_message=f'Failed at iteration {i}.',
                              exception_type=ValueError)
will raise a ValueError that includes the step number when params get polluted
with NaNs or infinities.
Static and Value (aka Runtime) Assertions
Chex divides all assertions into 2 classes: static and value assertions.
- static assertions use anything except concrete values of tensors. Examples: assert_shape, assert_trees_all_equal_dtypes, assert_max_traces.
- value assertions require access to tensor values, which are not available during JAX tracing (see How JAX primitives work), so such assertions need special treatment in jitted code.
To enable value assertions in a jitted function, decorate it with the
chex.chexify() wrapper. Example:
import chex
import jax
import jax.numpy as jnp

@chex.chexify
@jax.jit
def logp1_abs_safe(x: chex.Array) -> chex.Array:
  chex.assert_tree_all_finite(x)
  return jnp.log(jnp.abs(x) + 1)

logp1_abs_safe(jnp.ones(2))  # OK
logp1_abs_safe(jnp.array([jnp.nan, 3]))  # FAILS (in async mode)

# The error will be raised either at the next line OR at the next
# `logp1_abs_safe` call. See the docs for more detail on async mode.
logp1_abs_safe.wait_checks()  # Wait for the (async) computation to complete.
See this docstring for more detail on chex.chexify().
JAX Tracing Assertions
JAX re-traces a jitted function every time the structure, shapes, or dtypes of
its arguments change. This is often inadvertent and leads to a significant
performance drop that is hard to debug. The @chex.assert_max_traces
decorator asserts that the function is traced at most n times during
program execution.
The global trace counter can be cleared by calling
chex.clear_trace_counter(). This function can be used to isolate unit tests that rely
on @chex.assert_max_traces.
Examples:
@jax.jit
@chex.assert_max_traces(n=1)
def fn_sum_jitted(x, y):
  return x + y

fn_sum_jitted(jnp.zeros(3), jnp.zeros(3))  # tracing for the 1st time - OK
fn_sum_jitted(jnp.zeros([6, 7]), jnp.zeros([6, 7]))  # AssertionError!
Can be used with jax.pmap() as well:
def fn_sub(x, y):
  return x - y

fn_sub_pmapped = jax.pmap(chex.assert_max_traces(fn_sub, n=10))
See the How JAX primitives work section for more information about tracing.
Warnings (warnings.py)
In addition to hard assertions, Chex also offers utilities to add common warnings, such as specific types of deprecation warnings.
Test variants (variants.py)
JAX relies extensively on code transformation and compilation, which can make it hard to ensure that code is properly tested. For instance, just testing a Python function that uses JAX will not cover the code path that is actually executed when the function is jitted, and that path also differs depending on whether the code is jitted for CPU, GPU, or TPU. This has been a source of obscure, hard-to-catch bugs where XLA changes led to undesirable behaviours that only manifested under one specific code transformation.
Variants make it easy to ensure that unit tests cover different ‘variations’ of a function, by providing a simple decorator that can be used to repeat any test under all (or a subset) of the relevant code transformations.
E.g. suppose you want to test the output of a function fn with or without jit.
You can use chex.variants to run the test with both the jitted and non-jitted
version of the function by simply decorating a test method with
@chex.variants, and then using self.variant(fn) in place of fn in the body
of the test.
def fn(x, y):
  return x + y

...

class ExampleTest(chex.TestCase):

  @chex.variants(with_jit=True, without_jit=True)
  def test(self):
    var_fn = self.variant(fn)
    self.assertEqual(fn(1, 2), 3)
    self.assertEqual(var_fn(1, 2), fn(1, 2))
If you define the function in the test method, you may also use self.variant
as a decorator in the function definition. For example:
class ExampleTest(chex.TestCase):

  @chex.variants(with_jit=True, without_jit=True)
  def test(self):
    @self.variant
    def var_fn(x, y):
      return x + y

    self.assertEqual(var_fn(1, 2), 3)
Example of parameterized test:
from absl.testing import parameterized
# Could also be:
# `class ExampleParameterizedTest(chex.TestCase, parameterized.Test