graph
- flax.nnx.dataclass(cls=None, /, *, init=True, eq=True, order=False, unsafe_hash=False, match_args=True, kw_only=False, slots=False)
Turns an `nnx.Object` subclass into a dataclass and defines its pytree node attributes from its type hints.
`nnx.dataclass` can be used to create pytree dataclass types from type hints instead of the `__data__` attribute. By default, every field is treated as a pytree node. To mark a field as static, annotate it with `nnx.Static[T]`.

Example:
```python
from flax import nnx
import jax

@nnx.dataclass
class Foo(nnx.Object):
  a: int
  b: jax.Array
  c: nnx.Static[int]

tree = Foo(a=1, b=jax.numpy.array(1), c=1)
assert len(jax.tree.leaves(tree)) == 2  # a and b
```
`nnx.dataclass` raises a `ValueError` in three cases: the class does not derive from `nnx.Object`, the parent `Object` has `pytree` set to anything other than `'strict'`, or the class has a `__data__` attribute. `nnx.dataclass` does not accept a `repr` argument. It always sets `repr=False` so that the default `__repr__` method inherited from `Object` is not overwritten.
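The validation rules and the `repr=False` choice described above can be mimicked with a plain `dataclasses` decorator. This is a hypothetical sketch, not Flax's code. `Base` stands in for `nnx.Object`, `checked_dataclass` is an invented name, and the `pytree='strict'` check is omitted because it depends on Flax internals. The sketch shows why `repr=False` matters: the generated dataclass keeps the `__repr__` inherited from its base instead of replacing it.

```python
import dataclasses


class Base:
    """Hypothetical stand-in for nnx.Object; defines its own __repr__."""

    def __repr__(self):
        return f"{type(self).__name__}(custom repr)"


def checked_dataclass(cls):
    """Hypothetical decorator mirroring the documented checks (not Flax's code)."""
    if not issubclass(cls, Base):
        raise ValueError(f"{cls.__name__} must derive from Base")
    if hasattr(cls, "__data__"):
        raise ValueError(f"{cls.__name__} must not define __data__")
    # repr=False: no __repr__ is generated, so Base.__repr__ is inherited.
    return dataclasses.dataclass(cls, repr=False)


@checked_dataclass
class Point(Base):
    x: int
    y: int


print(repr(Point(1, 2)))  # Point(custom repr)


class NotBase:
    x: int


class HasData(Base):
    __data__ = ("x",)


# Both classes violate one of the checks and are rejected with ValueError.
errors = []
for bad in (NotBase, HasData):
    try:
        checked_dataclass(bad)
    except ValueError as err:
        errors.append(str(err))
print(errors)  # ['NotBase must derive from Base', 'HasData must not define __data__']
```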