
graph

flax.nnx.dataclass(cls=None, /, *, init=True, eq=True, order=False, unsafe_hash=False, match_args=True, kw_only=False, slots=False)

Turns an nnx.Object subclass into a dataclass and defines its pytree node attributes from type hints.

nnx.dataclass creates pytree dataclass types from type hints instead of the __data__ attribute. By default, every field is treated as a pytree node; to mark a field as static, annotate it with nnx.Static[T].

Example:

from flax import nnx
import jax

@nnx.dataclass
class Foo(nnx.Object):
  a: int              # pytree node
  b: jax.Array        # pytree node
  c: nnx.Static[int]  # static metadata, not a leaf

tree = Foo(a=1, b=jax.numpy.array(1), c=1)

assert len(jax.tree.leaves(tree)) == 2  # a and b
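To show how type hints alone can drive the split between nodes and static fields, here is a minimal, dependency-free sketch. It is not Flax's implementation: `Static`, `_StaticMarker`, `partition_fields`, and `leaves` are hypothetical names, and the sketch assumes that a static annotation is expressed as `typing.Annotated[T, marker]`. The real nnx.Static may be defined differently.

```python
import dataclasses
import typing
from typing import Annotated, Any


class _StaticMarker:
    """Hypothetical metadata tag that marks a field as static."""


class Static:
    """Hypothetical stand-in for nnx.Static: Static[T] -> Annotated[T, _StaticMarker]."""

    def __class_getitem__(cls, item):
        return Annotated[item, _StaticMarker]


def partition_fields(cls) -> tuple[list[str], list[str]]:
    """Split a dataclass's fields into (node_fields, static_fields) using type hints."""
    hints = typing.get_type_hints(cls, include_extras=True)
    nodes, static = [], []
    for f in dataclasses.fields(cls):
        hint = hints[f.name]
        # A field is static only if its hint is Annotated[..., _StaticMarker];
        # every other field defaults to being a node.
        is_static = (
            typing.get_origin(hint) is Annotated
            and _StaticMarker in typing.get_args(hint)[1:]
        )
        (static if is_static else nodes).append(f.name)
    return nodes, static


def leaves(obj: Any) -> list[Any]:
    """Return the values of node fields; static fields are excluded."""
    nodes, _ = partition_fields(type(obj))
    return [getattr(obj, name) for name in nodes]


@dataclasses.dataclass
class Foo:
    a: int
    b: float
    c: Static[int]


foo = Foo(a=1, b=2.0, c=3)
print(partition_fields(Foo))  # (['a', 'b'], ['c'])
print(leaves(foo))            # [1, 2.0]
```

As in the nnx example, `a` and `b` end up as leaves while `c` is kept out of the leaf list.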

nnx.dataclass raises a ValueError if the class does not derive from nnx.Object, if the parent Object has pytree set to anything other than 'strict', or if the class has a __data__ attribute.
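The three preconditions can be sketched in plain Python. Everything here is a hypothetical stand-in: `Object`, its `pytree` class keyword, the value `"lenient"`, and `checked_dataclass` are assumptions for illustration, not Flax's API or its actual validation code.

```python
import dataclasses


class Object:
    """Hypothetical stand-in for nnx.Object with a `pytree` class keyword."""

    _pytree_mode = "strict"

    def __init_subclass__(cls, pytree=None, **kwargs):
        super().__init_subclass__(**kwargs)
        if pytree is not None:  # otherwise inherit the parent's setting
            cls._pytree_mode = pytree


def checked_dataclass(cls):
    """Hypothetical decorator enforcing the three documented preconditions."""
    if not (isinstance(cls, type) and issubclass(cls, Object)):
        raise ValueError(f"{cls.__name__} must derive from Object")
    if cls._pytree_mode != "strict":
        raise ValueError(
            f"{cls.__name__} must use pytree='strict', got {cls._pytree_mode!r}"
        )
    if hasattr(cls, "__data__"):
        raise ValueError(f"{cls.__name__} must not define __data__")
    # Mirror the documented behavior: repr is always False.
    return dataclasses.dataclass(repr=False)(cls)


class Lenient(Object, pytree="lenient"):  # hypothetical non-strict value
    pass


class NotAnObject:               # violates rule 1
    a: int


class InheritsLenient(Lenient):  # violates rule 2 (inherited pytree setting)
    a: int


class HasData(Object):           # violates rule 3
    __data__ = ("a",)
    a: int


@checked_dataclass
class Good(Object):
    a: int


errors = []
for candidate in (NotAnObject, InheritsLenient, HasData):
    try:
        checked_dataclass(candidate)
    except ValueError as err:
        errors.append(str(err))

print(Good(a=1).a)  # 1
print(len(errors))  # 3
```

Checking the inherited `_pytree_mode` attribute means a subclass of a non-strict parent is rejected even if it does not set `pytree` itself, which matches the wording about the parent Object's setting.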

nnx.dataclass does not accept the repr argument; it always sets repr=False so that the generated dataclass does not override the default __repr__ inherited from Object.
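The effect of repr=False can be seen with the standard library alone. In this sketch, `Base` is a hypothetical stand-in for nnx.Object with its own __repr__. A dataclass created with repr=False keeps the inherited method, while the default repr=True generates a new __repr__ that replaces it.

```python
import dataclasses


class Base:
    """Hypothetical stand-in for nnx.Object with its own __repr__."""

    def __repr__(self) -> str:
        return f"{type(self).__name__}(<custom repr>)"


@dataclasses.dataclass(repr=False)  # no __repr__ generated; Base.__repr__ is inherited
class KeepsBaseRepr(Base):
    x: int


@dataclasses.dataclass  # repr=True by default; the generated __repr__ overrides Base's
class OverridesRepr(Base):
    x: int


print(repr(KeepsBaseRepr(1)))  # KeepsBaseRepr(<custom repr>)
print(repr(OverridesRepr(1)))  # OverridesRepr(x=1)
```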

flax.nnx.field(*, static=False, **kwargs)