Source code for flax.nnx.graph

# Copyright 2024 The Flax Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

import contextlib
import dataclasses
import functools
import threading
import typing as tp

from flax import config
from flax.nnx import filterlib, reprlib, traversals, variablelib
from flax.nnx import statelib
from flax.nnx.proxy_caller import (
  ApplyCaller,
  CallableProxy,
  DelayedAccessor,
)
from flax.nnx.statelib import EmptyState, FlatState, State
from flax.nnx.variablelib import Variable, VariableState
from flax.typing import Key, PathParts, is_key_like
import jax
import numpy as np
import treescope  # type: ignore[import-not-found,import-untyped]
import typing_extensions as tpe

A = tp.TypeVar('A')
B = tp.TypeVar('B')
C = tp.TypeVar('C')
F = tp.TypeVar('F', bound=tp.Callable)

HA = tp.TypeVar('HA', bound=tp.Hashable)
HB = tp.TypeVar('HB', bound=tp.Hashable)
KeyT = tp.TypeVar('KeyT', bound=Key)

Index = int
Names = tp.Sequence[int]
Node = tp.TypeVar('Node')
Leaf = tp.TypeVar('Leaf')
AuxData = tp.TypeVar('AuxData')


@jax.tree_util.register_static
@dataclasses.dataclass(frozen=True, slots=True)
class NoUpdate: ...


NO_UPDATE = NoUpdate()


@jax.tree_util.register_static
@dataclasses.dataclass(frozen=True, slots=True)
class Repeated: ...


REPEATED = Repeated()


@jax.tree_util.register_dataclass
@dataclasses.dataclass(frozen=True, slots=True, repr=False)
class MutableArrayOutput(reprlib.Representable):
  value: jax.Array | NoUpdate | Repeated

  def __nnx_repr__(self):
    yield reprlib.Object(type=type(self))
    yield reprlib.Attr('value', self.value)

  def __treescope_repr__(self, path, subtree_renderer):
    return treescope.repr_lib.render_object_constructor(
      object_type=type(self),
      attributes={
        'value': self.value,
      },
      path=path,
      subtree_renderer=subtree_renderer,
    )


LeafType = tp.Union[
  Variable,
  VariableState,
  jax.Array,
  np.ndarray,
  variablelib.MutableArray,
  MutableArrayOutput,
  NoUpdate,
]
GraphState = State[Key, LeafType]
GraphFlatState = FlatState[LeafType]


def is_node_leaf(x: tp.Any) -> tpe.TypeGuard[LeafType]:
  return isinstance(x, LeafType) or variablelib.is_mutable_array(x) # type: ignore[misc, arg-type]


class IndexMap(dict[Index, tp.Any]):
  @staticmethod
  def from_refmap(refmap: RefMap) -> IndexMap:
    return IndexMap((index, value) for value, index in refmap.items())


if config.flax_use_flaxlib:
  import flaxlib  # type: ignore[import]

  globals()['IndexMap'] = flaxlib.IndexMap


# RefMap = dict
class RefMap(tp.MutableMapping[tp.Any, int], reprlib.MappingReprMixin):
  """A mapping that hashes keys by their identity."""

  def __init__(
    self,
    mapping: tp.Mapping[tp.Any, int]
    | tp.Iterable[tuple[tp.Any, int]]
    | None = None,
    /,
  ):
    self._mapping: dict[int, tuple[tp.Any, int]] = dict()
    if mapping is not None:
      self.update(mapping)

  @staticmethod
  def from_indexmap(indexmap: IndexMap) -> RefMap:
    refmap = RefMap()
    refmap.update((value, index) for index, value in indexmap.items())
    return refmap

  def get(self, key: tp.Any, default: int | None = None) -> int | None:  # type: ignore[override]
    return self._mapping.get(id(key), (None, default))[1]

  def __getitem__(self, key: tp.Any) -> int:
    return self._mapping[id(key)][1]

  def __setitem__(self, key: tp.Any, value: int):
    self._mapping[id(key)] = (key, value)

  def __delitem__(self, key: tp.Any):
    del self._mapping[id(key)]

  def __len__(self) -> int:
    return len(self._mapping)

  def __contains__(self, key: tp.Any) -> bool:
    return id(key) in self._mapping

  def __iter__(self) -> tp.Iterator[tp.Any]:
    for key, _ in self._mapping.values():
      yield key

  def items(self) -> tp.ItemsView[tp.Any, int]:
    return self._mapping.values()  # type: ignore


# save python version
PythonRefMap = RefMap

if config.flax_use_flaxlib:
  import flaxlib  # type: ignore[import]

  globals()['RefMap'] = flaxlib.RefMap


@dataclasses.dataclass(frozen=True, slots=True)
class NodeImplBase(tp.Generic[Node, Leaf, AuxData]):
  type: type[Node]
  flatten: tp.Callable[[Node], tuple[tp.Sequence[tuple[Key, Leaf]], AuxData]]

  def node_dict(self, node: Node) -> dict[Key, Leaf]:
    nodes, _ = self.flatten(node)
    return dict(nodes)


@dataclasses.dataclass(frozen=True, slots=True)
class GraphNodeImpl(NodeImplBase[Node, Leaf, AuxData]):
  set_key: tp.Callable[[Node, Key, Leaf], None]
  pop_key: tp.Callable[[Node, Key], Leaf]
  create_empty: tp.Callable[[AuxData], Node]
  clear: tp.Callable[[Node], None]
  init: tp.Callable[[Node, tp.Iterable[tuple[Key, Leaf]]], None]


@dataclasses.dataclass(frozen=True, slots=True)
class PytreeNodeImpl(NodeImplBase[Node, Leaf, AuxData]):
  unflatten: tp.Callable[[tp.Sequence[tuple[Key, Leaf]], AuxData], Node]


NodeImpl = tp.Union[
  GraphNodeImpl[Node, Leaf, AuxData], PytreeNodeImpl[Node, Leaf, AuxData]
]


GRAPH_REGISTRY: dict[type, NodeImpl[tp.Any, tp.Any, tp.Any]] = {}
PYTREE_REGISTRY: dict[type, PytreeNodeImpl[tp.Any, tp.Any, tp.Any]] = {}


def register_graph_node_type(
  type: type,
  flatten: tp.Callable[[Node], tuple[tp.Sequence[tuple[Key, Leaf]], AuxData]],
  set_key: tp.Callable[[Node, Key, Leaf], None],
  pop_key: tp.Callable[[Node, Key], Leaf],
  create_empty: tp.Callable[[AuxData], Node],
  clear: tp.Callable[[Node], None],
  init: tp.Callable[[Node, tp.Iterable[tuple[Key, Leaf]]], None],
):
  if type in GRAPH_REGISTRY:
    raise ValueError(f'Node type {type} is already registered.')

  GRAPH_REGISTRY[type] = GraphNodeImpl(
    type=type,
    flatten=flatten,
    set_key=set_key,
    pop_key=pop_key,
    create_empty=create_empty,
    clear=clear,
    init=init,
  )


def register_pytree_node_type(
  type: type,
  flatten: tp.Callable[[Node], tuple[tp.Sequence[tuple[Key, Leaf]], AuxData]],
  unflatten: tp.Callable[[tp.Sequence[tuple[Key, Leaf]], AuxData], Node],
):
  if type in PYTREE_REGISTRY:
    raise ValueError(f'Node type {type} is already registered.')

  PYTREE_REGISTRY[type] = PytreeNodeImpl(
    type=type, flatten=flatten, unflatten=unflatten
  )


def is_node(x: tp.Any) -> bool:
  if isinstance(x, Variable):
    return False
  if type(x) in GRAPH_REGISTRY:
    return True
  return is_pytree_node(x)


def is_graph_node(x: tp.Any) -> bool:
  return type(x) in GRAPH_REGISTRY or variablelib.is_mutable_array(x)


def is_node_type(x: type[tp.Any]) -> bool:
  return x in GRAPH_REGISTRY or x in PYTREE_REGISTRY or x is GenericPytree


def get_node_impl(x: Node) -> NodeImpl[Node, tp.Any, tp.Any] | None:
  if isinstance(x, Variable):
    return None

  node_type = type(x)

  if node_type in GRAPH_REGISTRY:
    return GRAPH_REGISTRY[node_type]
  elif node_type in PYTREE_REGISTRY:
    return PYTREE_REGISTRY[node_type]
  elif node_type in JAX_PYTREE_REGISTRY or issubclass(node_type, tuple):
    return PYTREE_NODE_IMPL  # type: ignore
  else:
    return None


def get_node_impl_for_type(
  x: type[Node],
) -> NodeImpl[Node, tp.Any, tp.Any] | None:
  if x is GenericPytree:
    return PYTREE_NODE_IMPL  # type: ignore
  elif x in PYTREE_REGISTRY:
    return PYTREE_REGISTRY[x]
  elif x in GRAPH_REGISTRY:
    return GRAPH_REGISTRY[x]
  else:
    return None


class HashableMapping(tp.Mapping[HA, HB], tp.Hashable):
  _mapping: dict[HA, HB] | tp.Mapping[HA, HB]

  def __init__(self, mapping: tp.Mapping[HA, HB], copy: bool = True):
    self._mapping = dict(mapping) if copy else mapping

  def __contains__(self, key: object) -> bool:
    return key in self._mapping

  def __getitem__(self, key: HA) -> HB:
    return self._mapping[key]

  def __iter__(self) -> tp.Iterator[HA]:
    return iter(self._mapping)

  def __len__(self) -> int:
    return len(self._mapping)

  def __hash__(self) -> int:
    return hash(tuple(sorted(self._mapping.items())))

  def __eq__(self, other: tp.Any) -> bool:
    return (
      isinstance(other, HashableMapping) and self._mapping == other._mapping
    )

  def __repr__(self) -> str:
    return repr(self._mapping)


@jax.tree_util.register_static
@dataclasses.dataclass(frozen=True, repr=False)
class NodeRef(tp.Generic[Node], reprlib.Representable):
  index: int

  def __nnx_repr__(self):
    yield reprlib.Object(type=type(self))
    yield reprlib.Attr('index', self.index)

  def __treescope_repr__(self, path, subtree_renderer):
    return treescope.repr_lib.render_object_constructor(
      object_type=type(self),
      attributes={'index': self.index},
      path=path,
      subtree_renderer=subtree_renderer,
    )


if config.flax_use_flaxlib:
  import flaxlib  # type: ignore[import]

  jax.tree_util.register_static(flaxlib.NodeRef)
  globals()['NodeRef'] = flaxlib.NodeRef


@dataclasses.dataclass(frozen=True, repr=False)
class VariableDef(reprlib.Representable, tp.Generic[Node]):
  type: type[Node]
  index: int
  outer_index: int | None
  metadata: HashableMapping[str, tp.Any]
  mutable_arraydef: MutableArrayDef | NodeRef | None

  def with_no_outer_index(self) -> VariableDef:
    return VariableDef(
      type=self.type,
      index=self.index,
      outer_index=None,
      metadata=self.metadata,
      mutable_arraydef=self.mutable_arraydef.with_no_outer_index()
      if isinstance(self.mutable_arraydef, MutableArrayDef)
      else self.mutable_arraydef,
    )

  def with_same_outer_index(self) -> VariableDef:
    return VariableDef(
      type=self.type,
      index=self.index,
      outer_index=self.index,
      metadata=self.metadata,
      mutable_arraydef=self.mutable_arraydef.with_same_outer_index()
      if isinstance(self.mutable_arraydef, MutableArrayDef)
      else self.mutable_arraydef,
    )

  def __nnx_repr__(self):
    yield reprlib.Object(type=type(self))
    yield reprlib.Attr('type', self.type.__name__)
    yield reprlib.Attr('index', self.index)
    yield reprlib.Attr('outer_index', self.outer_index)
    yield reprlib.Attr('metadata', reprlib.PrettyMapping(self.metadata))

  def __treescope_repr__(self, path, subtree_renderer):
    return treescope.repr_lib.render_object_constructor(
      object_type=type(self),
      attributes={
        'type': self.type,
        'index': self.index,
        'outer_index': self.outer_index,
        'metadata': self.metadata,
      },
      path=path,
      subtree_renderer=subtree_renderer,
    )


if config.flax_use_flaxlib:
  import flaxlib  # type: ignore[import]

  jax.tree_util.register_static(flaxlib.VariableDef)
  globals()['VariableDef'] = flaxlib.VariableDef


@dataclasses.dataclass(frozen=True, repr=False)
class MutableArrayDef(reprlib.Representable):
  index: int
  outer_index: int | None

  def with_no_outer_index(self):
    return MutableArrayDef(
      index=self.index,
      outer_index=None,
    )

  def with_same_outer_index(self):
    return MutableArrayDef(
      index=self.index,
      outer_index=self.index,
    )

  def __nnx_repr__(self):
    yield reprlib.Object(type=type(self))
    yield reprlib.Attr('index', self.index)
    yield reprlib.Attr('outer_index', self.outer_index)

  def __treescope_repr__(self, path, subtree_renderer):
    return treescope.repr_lib.render_object_constructor(
      object_type=type(self),
      attributes={
        'index': self.index,
        'outer_index': self.outer_index,
      },
      path=path,
      subtree_renderer=subtree_renderer,
    )


@jax.tree_util.register_static
@dataclasses.dataclass(frozen=True, repr=False, slots=True)
class NodeDef(tp.Generic[Node], reprlib.Representable):
  """A dataclass that denotes the tree structure of a
  :class:`Module`. A ``GraphDef`` can be generated by either
  calling :func:`split` or :func:`graphdef` on the :class:`Module`."""

  type: tp.Type[Node]
  index: int | None
  outer_index: int | None
  num_attributes: int
  metadata: tp.Any

  def with_no_outer_index(self) -> NodeDef[Node]:
    return NodeDef(
      type=self.type,
      index=self.index,
      outer_index=None,
      num_attributes=self.num_attributes,
      metadata=self.metadata,
    )

  def with_same_outer_index(self) -> NodeDef[Node]:
    return NodeDef(
      type=self.type,
      index=self.index,
      outer_index=self.index,
      num_attributes=self.num_attributes,
      metadata=self.metadata,
    )

  def __nnx_repr__(self):
    yield reprlib.Object(type=type(self))

    yield reprlib.Attr('type', self.type.__name__)
    yield reprlib.Attr('index', self.index)
    yield reprlib.Attr('outer_index', self.outer_index)
    yield reprlib.Attr('num_attributes', self.num_attributes)
    yield reprlib.Attr('metadata', self.metadata)

  def __treescope_repr__(self, path, subtree_renderer):
    return treescope.repr_lib.render_object_constructor(
      object_type=type(self),
      attributes={
        'type': self.type,
        'index': self.index,
        'outer_index': self.outer_index,
        'num_attributes': self.num_attributes,
        'metadata': self.metadata,
      },
      path=path,
      subtree_renderer=subtree_renderer,
    )


if config.flax_use_flaxlib:
  import flaxlib  # type: ignore[import]

  jax.tree_util.register_static(flaxlib.NodeDef)
  globals()['NodeDef'] = flaxlib.NodeDef

NodeDefType = tp.Union[
  NodeDef[Node],
  NodeRef[Node],
  VariableDef[Node],
  MutableArrayDef,
]


@dataclasses.dataclass(frozen=True, slots=True)
class ArrayAttr:
  pass


ARRAY_ATTR = ArrayAttr()

@dataclasses.dataclass(frozen=True, slots=True)
class MutableArrayAttr:
  pass


MUTABLE_ARRAY_ATTR = MutableArrayAttr()


@dataclasses.dataclass(frozen=True, slots=True)
class NodeAttr:
  pass


NODE_ATTR = NodeAttr()

AttrType = tp.Union[
  NodeAttr,
  ArrayAttr,
  MutableArrayAttr,
  'Static[tp.Any]',
]

# GraphDef = tp.Union[NodeDef[Node], NodeRef[Node], VariableDef[Node]]
[docs]@jax.tree_util.register_static @dataclasses.dataclass(frozen=True, slots=True) class GraphDef(tp.Generic[Node]): nodes: list[NodeDefType[tp.Any]] attributes: list[tuple[Key, AttrType]] num_leaves: int def __hash__(self) -> int: return hash((tuple(self.nodes), tuple(self.attributes))) def with_no_outer_index(self) -> GraphDef[Node]: return GraphDef( nodes=[ node.with_no_outer_index() if not isinstance(node, NodeRef) else node for node in self.nodes ], attributes=self.attributes, num_leaves=self.num_leaves, ) def with_same_outer_index(self) -> GraphDef[Node]: return GraphDef( nodes=[ node.with_same_outer_index() if not isinstance(node, NodeRef) else node for node in self.nodes ], attributes=self.attributes, num_leaves=self.num_leaves, ) # TODO(cgarciae): remove this method def apply( self, state: GraphState, *states: GraphState ) -> ApplyCaller[tuple[GraphDef[Node], GraphState]]: accessor = DelayedAccessor() def _apply( accessor: DelayedAccessor, *args, **kwargs ) -> tuple[tp.Any, tuple[GraphDef[Node], GraphState]]: module = merge(self, state, *states) fn = accessor(module) out = fn(*args, **kwargs) graphdef, flat_state = flatten(module) state_ = statelib.from_flat_state(flat_state) return out, (graphdef, state_) return CallableProxy(_apply, accessor) # type: ignore
PureState = tuple[GraphDef[Node], GraphState] @tp.overload def flatten( # type: ignore[invalid-annotation] node: Node, /, *, ref_index: RefMap | None = None, ref_outer_index: RefMap | None = None, ) -> tuple[GraphDef[Node], FlatState[VariableState[tp.Any]]]: ... @tp.overload def flatten( # type: ignore[invalid-annotation] node: Node, /, *, with_paths: tp.Literal[True], return_variables: tp.Literal[True], ref_index: RefMap | None = None, ref_outer_index: RefMap | None = None, ) -> tuple[ GraphDef[Node], FlatState[Variable[tp.Any]], ]: ... @tp.overload def flatten( # type: ignore[invalid-annotation] node: Node, /, *, with_paths: tp.Literal[False], return_variables: tp.Literal[True], ref_index: RefMap | None = None, ref_outer_index: RefMap | None = None, ) -> tuple[ GraphDef[Node], list[Variable[tp.Any]], ]: ... @tp.overload def flatten( # type: ignore[invalid-annotation] node: Node, /, *, return_variables: tp.Literal[True], ref_index: RefMap | None = None, ref_outer_index: RefMap | None = None, ) -> tuple[ GraphDef[Node], FlatState[Variable[tp.Any]], ]: ... @tp.overload def flatten( # type: ignore[invalid-annotation] node: Node, /, *, with_paths: bool, ref_index: RefMap | None = None, ref_outer_index: RefMap | None = None, ) -> tuple[ GraphDef[Node], FlatState[VariableState[tp.Any]] | list[tp.Any], ]: ... def flatten( # type: ignore[invalid-annotation] node: Node, /, *, with_paths: bool = True, return_variables: bool = False, ref_index: RefMap | None = None, ref_outer_index: RefMap | None = None, ) -> tuple[ GraphDef[Node], FlatState[VariableState[tp.Any]] | FlatState[Variable[tp.Any]] | list[tp.Any], ]: """Flattens a graph node into a (graphdef, state) pair. Args: x: A graph node. ref_index: A mapping from nodes to indexes, defaults to None. If not provided, a new empty dictionary is created. This argument can be used to flatten a sequence of graph nodes that share references. with_paths: A boolean that indicates whether to return a FlatState object that includes the paths to VariableState objects, or just a list of the Variable's inner values. """ if ref_index is None: ref_index = RefMap() leaves: list[LeafType] = [] path: list[Key] | None = [] if with_paths else None paths: list[PathParts] | None = [] if with_paths else None nodes: list[NodeDefType[tp.Any]] = [] attributes: list[tuple[Key, AttrType]] = [] node_impl = get_node_impl(node) if node_impl is None and not ( isinstance(node, Variable) or variablelib.is_mutable_array(node) ): raise RuntimeError(f'Unsupported type: {type(node)}, this is a bug.') _graph_flatten( node, node_impl, path, ref_index, ref_outer_index, nodes, attributes, leaves, paths, return_variables, ) graphdef: GraphDef = GraphDef( nodes=nodes, attributes=attributes, num_leaves=len(leaves) ) if paths is not None: return graphdef, FlatState.from_sorted_keys_values(tuple(paths), leaves) # type: ignore[return-value] else: return graphdef, leaves def _graph_flatten( node: Node, node_impl: NodeImpl[Node, Leaf, AuxData] | None, path: list[Key] | None, ref_index: RefMap, ref_outer_index: RefMap | None, nodes: list[NodeDefType[tp.Any]], attributes: list[tuple[Key, AttrType]], leaves: list[LeafType], paths: list[PathParts] | None, return_variables: bool, ) -> None: is_pytree_node_ = type(node_impl) is PytreeNodeImpl index: int | None if not is_pytree_node_ and node in ref_index: nodes.append(NodeRef(index := ref_index[node])) return is_graph_node_ = type(node_impl) is GraphNodeImpl is_variable = isinstance(node, Variable) is_mutable_array = variablelib.is_mutable_array(node) # only cache graph nodes, we don't add mutable arrays here # as they are added in the make_mutable_arraydef function if is_graph_node_ or is_variable: index = len(ref_index) ref_index[node] = index else: index = None def make_mutable_arraydef(value: variablelib.MutableArray): if value in ref_index: index = ref_index[value] return NodeRef(index), REPEATED else: index = len(ref_index) ref_index[value] = index output_value: NoUpdate | MutableArrayOutput | variablelib.MutableArray if ref_outer_index is not None: if value in ref_outer_index: outer_index = ref_outer_index[value] output_value = NO_UPDATE mutable_arraydef = MutableArrayDef(index=index, outer_index=outer_index) else: output_value = MutableArrayOutput(value[...]) mutable_arraydef = MutableArrayDef(index=index, outer_index=None) else: output_value = value mutable_arraydef = MutableArrayDef(index=index, outer_index=None) return mutable_arraydef, output_value if is_variable: assert isinstance(node, Variable) assert index is not None inner_value = node.raw_value if variablelib.is_mutable_array(inner_value): mutable_arraydef, inner_value = make_mutable_arraydef(inner_value) else: mutable_arraydef = None if return_variables: leaf = node leaf.raw_value = inner_value elif path is None: leaf = inner_value else: leaf = node.to_state() # type: ignore[assignment] leaf.raw_value = inner_value variabledef = VariableDef( type=type(node), index=index, outer_index=ref_outer_index.get(node, None) if ref_outer_index else None, metadata=HashableMapping(node._var_metadata), mutable_arraydef=mutable_arraydef, ) if type(inner_value) is not Repeated: assert not isinstance(leaf, Repeated) leaves.append(leaf) if path is not None: assert paths is not None paths.append(tuple(path)) nodes.append(variabledef) return elif is_mutable_array: mutable_arraydef, leaf = make_mutable_arraydef(node) # type: ignore[arg-type] if not isinstance(leaf, Repeated): leaves.append(leaf) if path is not None: assert paths is not None paths.append(tuple(path)) nodes.append(mutable_arraydef) return if node_impl is None: raise RuntimeError(f'Unsupported type: {type(node)}, this is a bug.') values, metadata = node_impl.flatten(node) num_attributes = len(values) nodedef = NodeDef( node_impl.type, index, ref_outer_index[node] if is_graph_node_ and ref_outer_index and node in ref_outer_index else None, num_attributes, metadata, ) nodes.append(nodedef) for key, value in values: value_node_impl = get_node_impl(value) if path is not None: path.append(key) if value_node_impl is not None or isinstance(value, Variable): attributes.append((key, NODE_ATTR)) _graph_flatten( value, value_node_impl, path, ref_index, ref_outer_index, nodes, attributes, leaves, paths, return_variables, ) elif variablelib.is_mutable_array(value): attributes.append((key, MUTABLE_ARRAY_ATTR)) mutable_arraydef, leaf = make_mutable_arraydef(value) if not isinstance(leaf, Repeated): leaves.append(leaf) if paths is not None: paths.append(tuple(path)) # type: ignore nodes.append(mutable_arraydef) elif isinstance(value, (jax.Array, np.ndarray)): attributes.append((key, ARRAY_ATTR)) if paths is not None: paths.append(tuple(path)) # type: ignore leaves.append(value) else: attributes.append((key, Static(value))) if path is not None: path.pop() return @dataclasses.dataclass(slots=True) class FingerprintContext: next_index: int # TODO(cgarciae): the actual fingerprint object is not being used, # only the traversal process is still relevant def fingerprint( node, /, *, ref_index: RefMap | None = None, new_ref_index: RefMap | None = None, ) -> list[tp.Hashable]: """ """ if ref_index is None: ref_index = RefMap() if new_ref_index is None: new_ref_index = RefMap() node_impl = get_node_impl(node) if node_impl is None: raise RuntimeError(f'Unsupported type: {type(node)}, this is a bug.') ctx = FingerprintContext(len(ref_index) + len(new_ref_index)) fp: list[tp.Hashable] = [] _graph_fingerprint(ctx, fp.append, node, node_impl, ref_index, new_ref_index) return fp def _graph_fingerprint( ctx: FingerprintContext, append_fn: tp.Callable[[tp.Any], None], node, node_impl: NodeImpl[Node, Leaf, AuxData], ref_index: RefMap, new_ref_index: RefMap, ): is_pytree_node_ = type(node_impl) is PytreeNodeImpl is_graph_node_ = type(node_impl) is GraphNodeImpl append_fn(type(node)) if is_graph_node_: append_fn(id(node)) if node in ref_index: append_fn(ref_index[node]) return elif node in new_ref_index: append_fn(new_ref_index[node]) return index = new_ref_index[node] = ctx.next_index ctx.next_index += 1 else: index = -1 values, metadata = node_impl.flatten(node) append_fn(index) append_fn(metadata) for key, value in values: value_node_impl = get_node_impl(value) append_fn(key) if value_node_impl is not None: _graph_fingerprint( ctx, append_fn, value, value_node_impl, ref_index, new_ref_index, ) elif isinstance(value, Variable): append_fn(id(value)) append_fn(type(value)) if value in ref_index: append_fn(ref_index[value]) elif value in new_ref_index: append_fn(new_ref_index[value]) else: variable_index = new_ref_index[value] = ctx.next_index ctx.next_index += 1 append_fn(variable_index) for key_value in value._var_metadata.items(): append_fn(key_value) elif not isinstance(value, (jax.Array, np.ndarray)): append_fn(value) def check_fingerprint( node, fp: list[tp.Hashable], /, *, ref_index: RefMap | None = None, new_ref_index: RefMap | None = None, ) -> bool: """ """ if ref_index is None: ref_index = RefMap() if new_ref_index is None: new_ref_index = RefMap() node_impl = get_node_impl(node) if node_impl is None: raise RuntimeError(f'Unsupported type: {type(node)}, this is a bug.') ctx = FingerprintContext(len(ref_index) + len(new_ref_index)) fp_matches = _check_graph_fingerprint( ctx, iter(fp), node, node_impl, ref_index, new_ref_index ) return fp_matches def _check_graph_fingerprint( ctx: FingerprintContext, fp_iterator: tp.Iterator[tp.Hashable], node, node_impl: NodeImpl[Node, Leaf, AuxData], ref_index: RefMap, new_ref_index: RefMap, ) -> bool: is_pytree_node_ = type(node_impl) is PytreeNodeImpl is_graph_node_ = type(node_impl) is GraphNodeImpl if type(node) != next(fp_iterator): return False if is_graph_node_: # append_fn(id(node)) if id(node) != next(fp_iterator): return False if node in ref_index: # append_fn(ref_index[node]) return ref_index[node] == next(fp_iterator) elif node in new_ref_index: # append_fn(new_ref_index[node]) return new_ref_index[node] == next(fp_iterator) index = new_ref_index[node] = ctx.next_index ctx.next_index += 1 else: index = -1 values, metadata = node_impl.flatten(node) # append_fn(index) if index != next(fp_iterator): return False # append_fn(metadata) if metadata != next(fp_iterator): return False for key, value in values: value_node_impl = get_node_impl(value) # append_fn(key) if key != next(fp_iterator): return False if value_node_impl is not None: if not _check_graph_fingerprint( ctx, fp_iterator, value, value_node_impl, ref_index, new_ref_index, ): return False elif isinstance(value, Variable): # append_fn(id(value)) if id(value) != next(fp_iterator): return False # append_fn(type(value)) if type(value) != next(fp_iterator): return False if value in ref_index: # append_fn(ref_index[value]) if ref_index[value] != next(fp_iterator): return False elif value in new_ref_index: # append_fn(new_ref_index[value]) if new_ref_index[value] != next(fp_iterator): return False else: variable_index = new_ref_index[value] = ctx.next_index ctx.next_index += 1 # append_fn(variable_index) if variable_index != next(fp_iterator): return False for key_value in value._var_metadata.items(): # append_fn(key_value) if key_value != next(fp_iterator): return False else: if isinstance(value, (jax.Array, np.ndarray)): raise ValueError(f'Arrays leaves are not supported: {value}') # append_fn(value) if value != next(fp_iterator): return False return True def _get_sorted_leaves( xs: tp.Mapping[tp.Any, tp.Any], ) -> list[tp.Any]: if not isinstance(xs, tp.Mapping): # type: ignore raise TypeError(f'expected Mapping; got {type(xs).__qualname__}') leaves: list[tp.Any] = [] def _flatten(xs): if not isinstance(xs, tp.Mapping): leaves.append(xs) else: for _, value in sorted(xs.items()): _flatten(value) _flatten(xs) return leaves def unflatten( # type: ignore[invalid-annotation] graphdef: GraphDef[Node], state: State[Key, tp.Any] | FlatState[tp.Any] | list[tp.Any], /, *, index_ref: IndexMap | None = None, outer_index_outer_ref: IndexMap | None = None, ) -> Node: """Unflattens a graphdef into a node with the given state. Args: graphdef: A GraphDef instance. state: A State instance. index_ref: A mapping from indexes to nodes references found during the graph traversal, defaults to None. If not provided, a new empty dictionary is created. This argument can be used to unflatten a sequence of (graphdef, state) pairs that share the same index space. index_ref_cache: A mapping from indexes to existing nodes that can be reused. When an reference is reused, ``GraphNodeImpl.clear`` is called to leave the object in an empty state and then filled by the unflatten process, as a result existing graph nodes are mutated to have the new content/topology specified by the graphdef. """ if isinstance(state, (State, dict)): leaves = _get_sorted_leaves(state) elif isinstance(state, FlatState): leaves = state.leaves elif isinstance(state, list): # type: ignore leaves = state else: raise ValueError(f'Unsupported state type: {type(state)}') if index_ref is None: index_ref = IndexMap() if len(leaves) != graphdef.num_leaves: raise ValueError( f'Incorrect number of leaves, expected {graphdef.num_leaves} leaves, but got {len(leaves)}.' ) if isinstance(nodedef := graphdef.nodes[0], NodeRef): node = index_ref[nodedef.index] else: node_iter = iter(graphdef.nodes) attribute_iter = iter(graphdef.attributes) leaves_iter = iter(leaves) nodedef = next(node_iter) assert not isinstance(nodedef, NodeRef) if isinstance(nodedef, MutableArrayDef): node_impl = None else: node_impl = get_node_impl_for_type(nodedef.type) node = _graph_unflatten( nodedef, node_impl, node_iter, attribute_iter, leaves_iter, index_ref, outer_index_outer_ref, ) try: next(leaves_iter) except StopIteration: pass else: raise ValueError('Incorrect number of leaves in state.') return node def _graph_unflatten( nodedef: NodeDefType[Node], node_impl: NodeImpl[Node, Leaf, AuxData] | None, node_iter: tp.Iterator[NodeDefType[Node]], attribute_iter: tp.Iterator[tuple[Key, AttrType]], leaves_iter: tp.Iterator[tp.Any], index_ref: IndexMap, outer_index_outer_ref: IndexMap | None, ) -> Node: """Recursive helper for graph_unflatten. Args: nodedef: A GraphDef instance or an index to a node in the cache. state: A mapping from attribute names to variables or subgraphs. index_to_ref: A mapping from indexes to nodes that have been traversed. If a node is already in the cache, it won't be traversed again. index_ref_cache: A mapping from indexes to existing nodes that can be reused. When an reference is reused, ``GraphNodeImpl.clear`` is called to leave the object in an empty state and then filled by the unflatten process, as a result existing graph nodes are mutated to have the new content/topology specified by the nodedef. """ def get_mutable_array(mutable_arraydef: MutableArrayDef, leaf): assert type(mutable_arraydef) is MutableArrayDef if ( outer_index_outer_ref is not None and mutable_arraydef.outer_index is not None and mutable_arraydef.outer_index in outer_index_outer_ref ): # if mutable array exists, update it mutable_array = outer_index_outer_ref[mutable_arraydef.outer_index] if not variablelib.is_mutable_array(mutable_array): raise RuntimeError( f'Expected a MutableArray type but got {mutable_array}.' ) if type(leaf) is not NoUpdate: raise RuntimeError( f'Expected a no update for MutableArray but got {leaf}.' ) elif type(leaf) in (NoUpdate, Repeated): raise ValueError( 'Expected a MutableArrayOutput type but got ' f"'{leaf.value}.'" ) elif type(leaf) is MutableArrayOutput: mutable_array = variablelib.mutable_array(leaf.value) elif variablelib.is_mutable_array(leaf): mutable_array = leaf elif isinstance(leaf, jax.Array): # here we allow merging frozen arrays and will not create a new mutable array mutable_array = leaf else: raise ValueError(f'Found unexpected type for MutableArray, got {leaf}') index_ref[mutable_arraydef.index] = mutable_array return mutable_array if type(nodedef) is NodeRef: return index_ref[nodedef.index] if type(nodedef) is VariableDef: variabledef = tp.cast(VariableDef[Variable], nodedef) # its a unseen variable, create a new one if variabledef.mutable_arraydef is not None: if type(variabledef.mutable_arraydef) is NodeRef: value = index_ref[variabledef.mutable_arraydef.index] else: value = next(leaves_iter) assert type(variabledef.mutable_arraydef) is MutableArrayDef if isinstance(value, Variable | VariableState): inner_value = value.raw_value mutable_array = get_mutable_array( variabledef.mutable_arraydef, inner_value ) value.raw_value = mutable_array else: # if value is an array or mutable array, we need call get_mutable_array # to register it in the index_ref value = get_mutable_array(variabledef.mutable_arraydef, value) else: value = next(leaves_iter) # when idxmap is present, check if the Varable exists there # and update existing variables if it does if ( outer_index_outer_ref is not None and variabledef.outer_index is not None and variabledef.outer_index in outer_index_outer_ref ): # if variable exists, update it variable = outer_index_outer_ref[variabledef.outer_index] if not isinstance(variable, Variable): raise ValueError(f'Expected a Variable type but got {type(variable)}.') elif isinstance(value, Variable): raise ValueError( f'Cannot unflatten flat_state containing Variables when using `outer_index_outer_ref`. ' f'Got {value!r}' ) elif isinstance(value, VariableState): variable.update_from_state(value) else: variable.raw_value = value else: # variabledef.index not in index_ref_cache # variable reference does not exist outside, create a new one if isinstance(value, Variable): variable = value elif isinstance(value, VariableState): variable = value.to_variable() else: variable = variabledef.type.from_metadata( value, dict(variabledef.metadata) ) index_ref[variabledef.index] = variable return variable # type: ignore[return-value] if type(nodedef) is MutableArrayDef: leaf = next(leaves_iter) mutable_array = get_mutable_array(nodedef, leaf) return mutable_array # type: ignore[return-value] assert type(nodedef) is NodeDef if node_impl is None: raise RuntimeError(f'Unsupported type: {nodedef.type}, this is a bug.') if nodedef.index is not None and nodedef.index in index_ref: raise RuntimeError(f'GraphDef index {nodedef.index} already used.') def _get_children() -> list[tuple[Key, tp.Any]]: children: list[tuple[Key, LeafType | Node]] = [] # type: ignore[invalid-annotation] assert type(nodedef) is NodeDef for _ in range(nodedef.num_attributes): key, value = next(attribute_iter) if type(value) is Static: children.append((key, value.value)) # type: ignore[attribute-error] elif type(value) is MutableArrayAttr: mutable_arraydef = next(node_iter) assert ( type(mutable_arraydef) is MutableArrayDef or type(mutable_arraydef) is NodeRef ) if type(mutable_arraydef) is NodeRef: mutable_array = index_ref[mutable_arraydef.index] else: assert type(mutable_arraydef) is MutableArrayDef leaf = next(leaves_iter) mutable_array = get_mutable_array(mutable_arraydef, leaf) children.append((key, mutable_array)) elif type(value) is ArrayAttr: array = next(leaves_iter) children.append((key, array)) elif type(value) is NodeRef: children.append((key, index_ref[value.index])) # type: ignore[attribute-error] elif type(value) is NodeAttr: # if the key is a subgraph we create an empty node subgraphdef = next(node_iter) if type(subgraphdef) is NodeDef: value_node_impl = get_node_impl_for_type(subgraphdef.type) # type: ignore[attribute-error] else: value_node_impl = None subnode = _graph_unflatten( subgraphdef, value_node_impl, node_iter, attribute_iter, leaves_iter, index_ref, outer_index_outer_ref, ) children.append((key, subnode)) else: raise RuntimeError(f'Unknown static field: {key!r}') return children if isinstance(node_impl, GraphNodeImpl): # we create an empty node first and add it to the index # this avoids infinite recursion when there is a reference cycle assert type(nodedef) is NodeDef if ( outer_index_outer_ref is not None and nodedef.outer_index is not None and nodedef.outer_index in outer_index_outer_ref ): node = outer_index_outer_ref[nodedef.outer_index] if type(node) != nodedef.type: raise ValueError( f'Expected a node of type {nodedef.type} for index ' f'{nodedef.index}, but got a node of type {type(node)}.' ) node_impl.clear(node) else: node = node_impl.create_empty(nodedef.metadata) assert nodedef.index is not None index_ref[nodedef.index] = node node_impl.init(node, _get_children()) else: # if the node type does not support the creation of an empty object it means # that it cannot reference itself, so we can create its children first node = node_impl.unflatten(_get_children(), nodedef.metadata) return node def graph_pop( node: tp.Any, filters: tuple[filterlib.Filter, ...], ) -> tuple[GraphState, ...]: id_to_index: dict[int, Index] = {} path_parts: PathParts = () predicates = tuple(filterlib.to_predicate(filter) for filter in filters) flat_states: tuple[dict[PathParts, LeafType], ...] = tuple( {} for _ in predicates ) _graph_pop(node, id_to_index, path_parts, flat_states, predicates) return tuple( statelib.from_flat_state(flat_state) for flat_state in flat_states ) def _graph_pop( node: tp.Any, id_to_index: dict[int, Index], path_parts: PathParts, flat_states: tuple[dict[PathParts, LeafType], ...], predicates: tuple[filterlib.Predicate, ...], ) -> None: if not is_node(node): raise RuntimeError(f'Unsupported type: {type(node)}, this is a bug.') if id(node) in id_to_index: return id_to_index[id(node)] = len(id_to_index) node_impl = get_node_impl(node) if node_impl is None: raise TypeError(f'Unknown node type: {type(node)}') node_dict = node_impl.node_dict(node) for name, value in node_dict.items(): if is_node(value): _graph_pop( node=value, id_to_index=id_to_index, path_parts=(*path_parts, name), flat_states=flat_states, predicates=predicates, ) continue elif not is_node_leaf(value): continue elif id(value) in id_to_index: continue node_path = (*path_parts, name) node_impl = get_node_impl(node) if node_impl is None: raise TypeError(f'Unknown node type: {type(node)}') for state, predicate in zip(flat_states, predicates): if predicate(node_path, value): if isinstance(node_impl, PytreeNodeImpl): raise ValueError( f'Cannot pop key {name!r} from node of type {type(node).__name__}' ) id_to_index[id(value)] = len(id_to_index) node_impl.pop_key(node, name) if isinstance(value, Variable): value = value.to_state() state[node_path] = value # type: ignore[index] # mypy is wrong here? break else: # NOTE: should we raise an error here? pass def _graph_update_dynamic(node: tp.Any, state: tp.Mapping[KeyT, tp.Any]): def _update_variable(node: Variable, value): if isinstance(value, VariableState): # updated from VariableState node.update_from_state(value) else: # updated from raw value if isinstance(value, State) and not value: # NOTE: this is a special case when trying to update a Variable from state # created when flattening into a NodeRef, which creates an empty State. This # can happen when using standalone Variables with `grad` pass else: node.raw_value = value if isinstance(node, Variable): _update_variable(node, state) return if not is_node(node): raise RuntimeError(f'Unsupported type: {type(node)}') node_impl = get_node_impl(node) if node_impl is None: raise TypeError(f'Unknown node type: {type(node)}') node_dict = node_impl.node_dict(node) for key, value in state.items(): # case 1: new state is being added if key not in node_dict: if isinstance(node_impl, PytreeNodeImpl): raise ValueError( f'Cannot set key {key!r} on immutable node of ' f'type {type(node).__name__}' ) if isinstance(value, Variable): value = value.copy() node_impl.set_key(node, key, value) continue current_value = node_dict[key] # case 2: subgraph is being updated if is_node(current_value): if is_node_leaf(value): raise ValueError(f'Expected a subgraph for {key!r}, but got: {value!r}') _graph_update_dynamic(current_value, value) else: if isinstance(current_value, jax.Array | np.ndarray): if isinstance(node_impl, PytreeNodeImpl): raise ValueError( f'Cannot set key {key!r} on immutable node of ' f'type {type(node).__name__}' ) node_impl.set_key(node, key, value) continue elif not isinstance(current_value, Variable): # case 3: state leaf is being updated raise ValueError( f'Trying to update a non-Variable attribute {key!r} with a Variable: ' f'{value!r}' ) _update_variable(current_value, value) # -------------------------------------------------------- # UpdateContext # -------------------------------------------------------- class StaticCache(tp.NamedTuple): graphdef: GraphDef[tp.Any] final_graphdef: GraphDef[tp.Any] paths: tuple[PathParts, ...] variables: list[Variable[tp.Any]] new_ref_index: RefMap new_index_ref: IndexMap @staticmethod def create( graphdef: GraphDef[tp.Any], paths: tuple[PathParts, ...], variables: list[Variable[tp.Any]], new_ref_index: RefMap, ): new_index_ref = IndexMap.from_refmap(new_ref_index) final_graphdef: GraphDef[tp.Any] final_graphdef = graphdef.with_same_outer_index() return StaticCache( graphdef=graphdef, final_graphdef=final_graphdef, paths=paths, variables=variables, new_ref_index=new_ref_index, new_index_ref=new_index_ref, ) @dataclasses.dataclass class GraphContext(threading.local): update_context_stacks: dict[tp.Hashable, list[UpdateContext]] = ( dataclasses.field(default_factory=dict) ) ref_index_stack: list[SplitContext] = dataclasses.field(default_factory=list) index_ref_stack: list[MergeContext] = dataclasses.field(default_factory=list) tmp_static_cache: tp.MutableMapping[tp.Any, StaticCache] | None = None caching: bool = False GRAPH_CONTEXT = GraphContext() @contextlib.contextmanager def static_cache(static_cache: tp.MutableMapping[tp.Any, StaticCache]): if GRAPH_CONTEXT.caching: yield return GRAPH_CONTEXT.tmp_static_cache = static_cache try: yield finally: if GRAPH_CONTEXT.tmp_static_cache is not None: raise ValueError( 'GRAPH_CONTEXT.tmp_static_cache should be None, no context consumed it.' ) def _cached_partial(f: tp.Callable[..., tp.Any], *cached_args): """Create a partial from a NNX transformed function alog with some cached input arguments and reduces the python overhead by caching the traversal of NNX graph nodes. This is useful for speed up function that are called repeatedly with the same subset of inputs e.g. a ``train_step`` with a ``model`` and ``optimizer``:: >>> from flax import nnx >>> import jax.numpy as jnp >>> import optax ... >>> model = nnx.Linear(2, 3, rngs=nnx.Rngs(0)) >>> optimizer = nnx.Optimizer(model, optax.adamw(1e-3)) ... >>> @nnx.jit ... def train_step(model, optimizer, x, y): ... def loss_fn(model): ... return jnp.mean((model(x) - y) ** 2) ... ... loss, grads = nnx.value_and_grad(loss_fn)(model) ... optimizer.update(grads) ... return loss ... >>> cached_train_step = nnx.cached_partial(train_step, model, optimizer) ... >>> for step in range(total_steps:=2): ... x, y = jnp.ones((10, 2)), jnp.ones((10, 3)) ... # loss = train_step(model, optimizer, x, y) ... loss = cached_train_step(x, y) ... print(f'Step {step}: loss={loss:.3f}') Step 0: loss=2.669 Step 1: loss=2.660 Note that ``cached_partial`` will clone all cached graph nodes to gurantee the validity of the cache, and these clones will contain references to the same Variable objects which guarantees that state is propagated correctly back to the original graph nodes. Because of the previous, the final structure of all graph nodes must be the same after each call to the cached function, otherswise an error will be raised. Temporary mutations are allowed (e.g. the use of ``Module.sow``) as long as they are cleaned up before the function returns (e.g. via ``nnx.pop``). Args: f: A function to cache. *cached_args: A subset of the input arguments containing the graph nodes to cache. Returns: A partial function expecting the remaining arguments to the original function. """ cache: tp.MutableMapping[tp.Any, StaticCache] = PythonRefMap() # type: ignore original_ref_index: RefMap = RefMap() index_ref: IndexMap = IndexMap() cached_ref_index: RefMap = RefMap() def create_static_cache(x): # TODO(cgarciae): support Array attribute updates for graph nodes if is_graph_node(x) or isinstance(x, Variable): graphdef, flat_state = flatten( x, with_paths=True, return_variables=True, ref_index=original_ref_index ) paths = flat_state.paths variables = flat_state.leaves # clone but keep the same variable references node_cache = unflatten(graphdef, flat_state, index_ref=index_ref) cached_new_ref_index = RefMap() _fp = fingerprint( node_cache, ref_index=cached_ref_index, new_ref_index=cached_new_ref_index, ) cached_ref_index.update(cached_new_ref_index) cache[node_cache] = StaticCache.create( graphdef, paths, variables, cached_new_ref_index ) return node_cache return x cached_args = jax.tree.map(create_static_cache, cached_args) @functools.wraps(f) def cache_args_wrapper(*args, **kwargs): with static_cache(cache): return f(*cached_args, *args, **kwargs) return cache_args_wrapper if tp.TYPE_CHECKING: cached_partial = functools.partial else: cached_partial = _cached_partial @dataclasses.dataclass class SplitContext: ctxtag: tp.Hashable | None ref_index: RefMap is_inner: bool | None @tp.overload def split(self, graph_node: A, /) -> tuple[GraphDef[A], GraphState]: ... # type: ignore[invalid-annotation] @tp.overload def split( # type: ignore[invalid-annotation] self, graph_node: A, first: filterlib.Filter, / ) -> tuple[GraphDef[A], GraphState]: ... @tp.overload def split( self, graph_node: A, first: filterlib.Filter, second: filterlib.Filter, /, *filters: filterlib.Filter, ) -> tuple[GraphDef[A], GraphState, tpe.Unpack[tuple[GraphState, ...]]]: ... # type: ignore[not-supported-yet] def split( self, node: A, *filters: filterlib.Filter ) -> tuple[GraphDef[A], tpe.Unpack[tuple[GraphState, ...]]]: # type: ignore[not-supported-yet] ctx = ( current_update_context(self.ctxtag) if self.ctxtag is not None else None ) inner_ref_outer_index = ( ctx.inner_ref_outer_index if ctx and ctx.inner_ref_outer_index else None ) graphdef, flat_state = flatten( node, ref_index=self.ref_index, ref_outer_index=inner_ref_outer_index ) flat_states = _split_state(flat_state, filters) states = _to_nested_state(graphdef, flat_states) return graphdef, *states @tp.overload def flatten( # type: ignore[invalid-annotation] self, graph_node: A, /, *, with_paths: tp.Literal[False], ) -> tuple[GraphDef[A], list[tp.Any]]: ... @tp.overload def flatten( # type: ignore[invalid-annotation] self, graph_node: A, /, ) -> tuple[GraphDef[A], FlatState[VariableState[tp.Any]]]: ... @tp.overload def flatten( # type: ignore[invalid-annotation] self, graph_node: A, first: filterlib.Filter, /, ) -> tuple[GraphDef[A], FlatState[VariableState[tp.Any]]]: ... @tp.overload def flatten( # type: ignore[invalid-annotation] self, graph_node: A, first: filterlib.Filter, second: filterlib.Filter, /, *filters: filterlib.Filter, ) -> tuple[ GraphDef[A], FlatState[VariableState[tp.Any]], tpe.Unpack[tuple[FlatState[VariableState[tp.Any]], ...]], ]: ... def flatten( # type: ignore[invalid-annotation] self, node: A, *filters: filterlib.Filter, with_paths: bool = True, ) -> tuple[ GraphDef[A], FlatState[VariableState[tp.Any]] | list[tp.Any], tpe.Unpack[tuple[FlatState[VariableState[tp.Any]], ...]], ]: if not with_paths and filters: raise ValueError('Cannot use filters with with_paths=False') ctx = ( current_update_context(self.ctxtag) if self.ctxtag is not None else None ) static_cache = ( ctx.static_cache if ctx is not None and self.is_inner is False else None ) ref_outer_index = ( ctx.inner_ref_outer_index if ctx and ctx.inner_ref_outer_index else None ) flat_state: ( FlatState[VariableState[tp.Any]] | FlatState[Variable[tp.Any]] | list[tp.Any] ) leaves: list[tp.Any] if node in self.ref_index: # node is already in the ref_index, call flatten which will return a NodeRef graphdef, flat_state = flatten( node, ref_index=self.ref_index, ref_outer_index=ref_outer_index, with_paths=with_paths, ) if with_paths: assert isinstance(flat_state, FlatState) paths = flat_state.paths leaves = flat_state.leaves else: assert isinstance(flat_state, list) paths = None leaves = flat_state elif static_cache is not None and node in static_cache: node_static_cache = static_cache[node] graphdef = node_static_cache.graphdef # add the new references to the ref_index self.ref_index.update(node_static_cache.new_ref_index) if with_paths: paths = node_static_cache.paths leaves = [ variable.to_state() for variable in node_static_cache.variables ] else: paths = None leaves = [ variable.raw_value for variable in node_static_cache.variables ] else: graphdef, flat_state = flatten( node, ref_index=self.ref_index, ref_outer_index=ref_outer_index, with_paths=with_paths, ) if with_paths: assert isinstance(flat_state, FlatState) paths = flat_state.paths leaves = flat_state.leaves else: assert isinstance(flat_state, list) paths = None leaves = flat_state if with_paths: assert paths is not None flat_state = FlatState.from_sorted_keys_values(paths, leaves) flat_states = _split_state(flat_state, filters) return graphdef, *flat_states # type: ignore[bad-return-type] else: return graphdef, leaves @contextlib.contextmanager def split_context(ctxtag: tp.Hashable | None = None): ctx = current_update_context(ctxtag) if ctxtag is not None else None is_inner = ctx.outer_ref_outer_index is not None if ctx is not None else None GRAPH_CONTEXT.ref_index_stack.append(SplitContext(ctxtag, RefMap(), is_inner)) try: yield GRAPH_CONTEXT.ref_index_stack[-1] finally: flatten_ctx = GRAPH_CONTEXT.ref_index_stack.pop() if ctxtag is not None: ctx = current_update_context(ctxtag) ctx.flatten_end(flatten_ctx.ref_index) del flatten_ctx.ref_index del flatten_ctx.ctxtag @dataclasses.dataclass class MergeContext: ctxtag: tp.Hashable | None index_ref: IndexMap is_inner: bool | None def merge( # type: ignore[invalid-annotation] self, graphdef: GraphDef[A], state: GraphState | VariableState, /, *states: GraphState | VariableState, ) -> A: ctx = ( current_update_context(self.ctxtag) if self.ctxtag is not None else None ) outer_index_outer_ref = ( ctx.outer_index_outer_ref if ctx and ctx.outer_index_outer_ref else None ) _state = _merge_to_flat_state((state, *states)) node = unflatten( graphdef, _state, index_ref=self.index_ref, outer_index_outer_ref=outer_index_outer_ref, ) return node def unflatten( # type: ignore[invalid-annotation] self, graphdef: GraphDef[A], flat_state: GraphFlatState | list[tp.Any], /, *flat_states: GraphFlatState, ) -> A: ctx = ( current_update_context(self.ctxtag) if self.ctxtag is not None else None ) static_cache = ( ctx.static_cache if ctx is not None and self.is_inner is False else None ) state: FlatState[tp.Any] | list[tp.Any] if type(flat_state) is list: if flat_states: raise ValueError( 'Cannot use multiple flat_states when flat_state is a list, ' f'got flat_state: {flat_state!r}, flat_states: {flat_states!r}' ) state = flat_state else: state = FlatState.merge(flat_state, *flat_states) if type(graphdef.nodes[0]) is NodeRef: node = unflatten( graphdef, state, index_ref=self.index_ref, ) elif static_cache is not None: assert isinstance(graphdef.nodes[0], NodeDef) assert ctx is not None if (outer_index := graphdef.nodes[0].outer_index) is not None: outer_index_outer_ref = ctx.outer_index_outer_ref assert outer_index_outer_ref is not None node = outer_index_outer_ref[outer_index] if node in static_cache: static_cache_node = static_cache[node] if static_cache_node.final_graphdef != graphdef: raise ValueError( 'The graph structure of a node added to cached_partial was mutated inside the transformation, ' f'this is not allowed.\nNode: {node}\nOuput graphdef: {graphdef}\nExpected graphdef: {static_cache_node.final_graphdef}' ) if type(state) is list: leaves = state elif type(state) is FlatState: leaves = state.leaves else: raise ValueError(f'Unsupported state type: {type(state)}') if len(leaves) != len(static_cache_node.variables): raise ValueError( f'Incorrect number of leaves: expected {len(static_cache_node.variables)} ' f'leaves in the state, got {len(leaves)}' ) for variable, leaf in zip(static_cache_node.variables, leaves): if type(leaf) is VariableState: variable.update_from_state(leaf) else: variable.raw_value = leaf self.index_ref.update(static_cache_node.new_index_ref) else: # uncached node, create it node = unflatten( graphdef, state, index_ref=self.index_ref, outer_index_outer_ref=outer_index_outer_ref, ) else: # graphdef.outer_index is None # its a new node, create it node = unflatten( graphdef, state, index_ref=self.index_ref, ) else: outer_index_outer_ref = ( ctx.outer_index_outer_ref if ctx and ctx.outer_index_outer_ref else None ) node = unflatten( graphdef, state, index_ref=self.index_ref, outer_index_outer_ref=outer_index_outer_ref, ) return node @tp.overload @contextlib.contextmanager def merge_context() -> tp.Generator[MergeContext, None, None]: ... # type: ignore[bad-return-type] @tp.overload @contextlib.contextmanager def merge_context( ctxtag: tp.Hashable | None, inner: bool | None ) -> tp.Generator[MergeContext, None, None]: ... # type: ignore[bad-return-type] @contextlib.contextmanager def merge_context(ctxtag: tp.Hashable | None = None, inner: bool | None = None): GRAPH_CONTEXT.index_ref_stack.append(MergeContext(ctxtag, IndexMap(), inner)) try: yield GRAPH_CONTEXT.index_ref_stack[-1] finally: unflatten_ctx = GRAPH_CONTEXT.index_ref_stack.pop() index_ref = unflatten_ctx.index_ref if ctxtag is not None: if inner is None: raise ValueError('inner_merge must be specified when using ctxtag') ctx = current_update_context(ctxtag) ctx.unflatten_end(index_ref, inner) del unflatten_ctx.index_ref del unflatten_ctx.ctxtag
[docs]@jax.tree_util.register_static @dataclasses.dataclass class UpdateContext: """A context manager for handling complex state updates.""" tag: tp.Hashable outer_ref_outer_index: RefMap | None outer_index_inner_ref: IndexMap | None # reverse caches outer_index_outer_ref: IndexMap | None inner_ref_outer_index: RefMap | None static_cache: tp.MutableMapping[tp.Any, StaticCache] | None # define hash and eq to make this an opaque object def __hash__(self): return 0 def __eq__(self, other): return isinstance(other, UpdateContext) def flatten_end(self, ref_index: RefMap): if self.outer_ref_outer_index is None: # outer split (1), store the references self.outer_ref_outer_index = ref_index self.outer_index_outer_ref = IndexMap.from_refmap( self.outer_ref_outer_index ) else: # inner split (3), clear index_ref self.outer_index_inner_ref = None self.inner_ref_outer_index = None def unflatten_end(self, index_ref: IndexMap, inner_merge: bool): if inner_merge: # inner merge (2) self.outer_index_inner_ref = index_ref self.inner_ref_outer_index = RefMap.from_indexmap(index_ref)
@dataclasses.dataclass class UpdateContextManager: tag: tp.Hashable def __enter__(self): if GRAPH_CONTEXT.tmp_static_cache is not None: # take current static cache static_cache = GRAPH_CONTEXT.tmp_static_cache GRAPH_CONTEXT.tmp_static_cache = None else: static_cache = None ctx = UpdateContext( tag=self.tag, outer_ref_outer_index=None, outer_index_inner_ref=None, outer_index_outer_ref=None, inner_ref_outer_index=None, static_cache=static_cache, ) if self.tag not in GRAPH_CONTEXT.update_context_stacks: GRAPH_CONTEXT.update_context_stacks[self.tag] = [ctx] else: GRAPH_CONTEXT.update_context_stacks[self.tag].append(ctx) return ctx def __exit__(self, *args): if self.tag not in GRAPH_CONTEXT.update_context_stacks: raise RuntimeError( f'No update context found for tag {self.tag!r}, this is a bug.' ) stack = GRAPH_CONTEXT.update_context_stacks[self.tag] ctx = stack.pop() # clear references del ctx.outer_ref_outer_index del ctx.outer_index_inner_ref del ctx.outer_index_outer_ref del ctx.inner_ref_outer_index if not stack: del GRAPH_CONTEXT.update_context_stacks[self.tag] def __call__(self, f: F) -> F: @functools.wraps(f) def update_context_manager_wrapper(*args, **kwargs): with self: return f(*args, **kwargs) return update_context_manager_wrapper # type: ignore
[docs]def update_context(tag: tp.Hashable): """Creates an :class:`UpdateContext` context manager which can be used to handle more complex state updates beyond what ``nnx.update`` can handle, including updates to static properties and graph structure. UpdateContext exposes a ``split`` and ``merge`` API with the same signature as ``nnx.split`` / ``nnx.merge`` but performs some bookkeeping to have the necessary information in order to perfectly update the input objects based on the changes made inside the transform. The UpdateContext must call split and merge a total of 4 times, the first and last calls happen outside the transform and the second and third calls happen inside the transform as shown in the diagram below:: idxmap (2) merge ─────────────────────────────► split (3) ▲ │ │ inside │ │. . . . . . . . . . . . . . . . . . │ index_mapping │ outside │ │ ▼ (1) split──────────────────────────────► merge (4) refmap The first call to split ``(1)`` creates a ``refmap`` which keeps track of the outer references, and the first call to merge ``(2)`` creates an ``idxmap`` which keeps track of the inner references. The second call to split ``(3)`` combines the refmap and idxmap to produce the ``index_mapping`` which indicates how the outer references map to the inner references. Finally, the last call to merge ``(4)`` uses the index_mapping and the refmap to reconstruct the output of the transform while reusing/updating the inner references. To avoid memory leaks, the idxmap is cleared after ``(3)`` and the refmap is cleared after ``(4)``, and both are cleared after the context manager exits. Here is a simple example showing the use of ``update_context``:: >>> from flax import nnx ... >>> m1 = nnx.Dict({}) >>> with nnx.update_context('example'): ... with nnx.split_context('example') as ctx: ... graphdef, state = ctx.split(m1) ... @jax.jit ... def f(graphdef, state): ... with nnx.merge_context('example', inner=True) as ctx: ... m2 = ctx.merge(graphdef, state) ... m2.a = 1 ... m2.ref = m2 # create a reference cycle ... with nnx.split_context('example') as ctx: ... return ctx.split(m2) ... graphdef_out, state_out = f(graphdef, state) ... with nnx.merge_context('example', inner=False) as ctx: ... m3 = ctx.merge(graphdef_out, state_out) ... >>> assert m1 is m3 >>> assert m1.a == 1 >>> assert m1.ref is m1 Note that ``update_context`` takes in a ``tag`` argument which is used primarily as a safety mechanism reduce the risk of accidentally using the wrong UpdateContext when using :func:`current_update_context` to access the current active context. ``update_context`` can also be used as a decorator that creates/activates an UpdateContext context for the duration of the function:: >>> from flax import nnx ... >>> m1 = nnx.Dict({}) >>> @jax.jit ... def f(graphdef, state): ... with nnx.merge_context('example', inner=True) as ctx: ... m2 = ctx.merge(graphdef, state) ... m2.a = 1 # insert static attribute ... m2.ref = m2 # create a reference cycle ... with nnx.split_context('example') as ctx: ... return ctx.split(m2) ... >>> @nnx.update_context('example') ... def g(m1): ... with nnx.split_context('example') as ctx: ... graphdef, state = ctx.split(m1) ... graphdef_out, state_out = f(graphdef, state) ... with nnx.merge_context('example', inner=False) as ctx: ... return ctx.merge(graphdef_out, state_out) ... >>> m3 = g(m1) >>> assert m1 is m3 >>> assert m1.a == 1 >>> assert m1.ref is m1 The context can be accessed using :func:`current_update_context`. Args: tag: A string tag to identify the context. """ return UpdateContextManager(tag=tag)
[docs]def current_update_context(tag: tp.Hashable) -> UpdateContext: """Returns the current active :class:`UpdateContext` for the given tag.""" if tag not in GRAPH_CONTEXT.update_context_stacks: raise ValueError(f'No update context found for tag {tag!r}.') return GRAPH_CONTEXT.update_context_stacks[tag][-1]
# -------------------------------------------------------- # Functional API # -------------------------------------------------------- def _split_state( state: FlatState[tp.Any], filters: tuple[filterlib.Filter, ...], ) -> tuple[FlatState[tp.Any], tpe.Unpack[tuple[FlatState[tp.Any], ...]]]: if not filters: return (state,) # type: ignore[bad-return-type] states = state.split(*filters) if not isinstance(states, tuple): return (states,) # type: ignore[bad-return-type] assert len(states) > 0 return states # type: ignore[return-value] @tp.overload def split( # type: ignore[invalid-annotation] graph_node: A, / ) -> tuple[GraphDef[A], GraphState | VariableState]: ... @tp.overload def split( # type: ignore[invalid-annotation] graph_node: A, first: filterlib.Filter, / ) -> tuple[GraphDef[A], GraphState | VariableState]: ... @tp.overload def split( # type: ignore[invalid-annotation] graph_node: A, first: filterlib.Filter, second: filterlib.Filter, /, *filters: filterlib.Filter, ) -> tuple[ GraphDef[A], GraphState | VariableState, tpe.Unpack[tuple[GraphState | VariableState, ...]], ]: ...
[docs]def split( # type: ignore[invalid-annotation] node: A, *filters: filterlib.Filter ) -> tuple[ GraphDef[A], GraphState | VariableState, tpe.Unpack[tuple[GraphState | VariableState, ...]], ]: """Split a graph node into a :class:`GraphDef` and one or more :class:`State`s. State is a ``Mapping`` from strings or integers to ``Variables``, Arrays or nested States. GraphDef contains all the static information needed to reconstruct a ``Module`` graph, it is analogous to JAX’s ``PyTreeDef``. :func:`split` is used in conjunction with :func:`merge` to switch seamlessly between stateful and stateless representations of the graph. Example usage:: >>> from flax import nnx >>> import jax, jax.numpy as jnp ... >>> class Foo(nnx.Module): ... def __init__(self, rngs): ... self.batch_norm = nnx.BatchNorm(2, rngs=rngs) ... self.linear = nnx.Linear(2, 3, rngs=rngs) ... >>> node = Foo(nnx.Rngs(0)) >>> graphdef, params, batch_stats = nnx.split(node, nnx.Param, nnx.BatchStat) ... >>> jax.tree.map(jnp.shape, params) State({ 'batch_norm': { 'bias': VariableState( type=Param, value=(2,) ), 'scale': VariableState( type=Param, value=(2,) ) }, 'linear': { 'bias': VariableState( type=Param, value=(3,) ), 'kernel': VariableState( type=Param, value=(2, 3) ) } }) >>> jax.tree.map(jnp.shape, batch_stats) State({ 'batch_norm': { 'mean': VariableState( type=BatchStat, value=(2,) ), 'var': VariableState( type=BatchStat, value=(2,) ) } }) :func:`split` and :func:`merge` are primarily used to interact directly with JAX transformations, see `Functional API <https://flax.readthedocs.io/en/latest/nnx/nnx_basics.html#the-functional-api>`__ for more information. Arguments: node: graph node to split. *filters: some optional filters to group the state into mutually exclusive substates. Returns: ``GraphDef`` and one or more ``States`` equal to the number of filters passed. If no filters are passed, a single ``State`` is returned. """ graphdef, flat_state = flatten(node) flat_states = _split_state(flat_state, filters) states = _to_nested_state(graphdef, flat_states) return graphdef, *states # type: ignore[return-value]
def _to_nested_state( graphdef: GraphDef[A], flat_states: tp.Iterable[tp.Any] ) -> tuple[tp.Any, ...]: if type(graphdef.nodes[0]) in (VariableDef, MutableArrayDef): states = tuple( flat_state[0][1] if flat_state else EmptyState() for flat_state in flat_states ) else: states = tuple( statelib.from_flat_state(flat_state) for flat_state in flat_states ) return states def _merge_to_flat_state(states: tp.Iterable[tp.Any]): flat_state: list[tuple[PathParts, tp.Any]] = [] for state in states: if isinstance(state, dict | State): flat_state.extend(traversals.flatten_to_sequence(state)) elif isinstance(state, FlatState): flat_state.extend(state) else: flat_state.append(((), state)) flat_state.sort() return [value for _, value in flat_state]
[docs]def merge( # type: ignore[invalid-annotation] graphdef: GraphDef[A], state: tp.Any, /, *states: tp.Any, ) -> A: """The inverse of :func:`flax.nnx.split`. ``nnx.merge`` takes a :class:`flax.nnx.GraphDef` and one or more :class:`flax.nnx.State`'s and creates a new node with the same structure as the original node. Recall: :func:`flax.nnx.split` is used to represent a :class:`flax.nnx.Module` by: 1) a static ``nnx.GraphDef`` that captures its Pythonic static information; and 2) one or more :class:`flax.nnx.Variable` ``nnx.State``'(s) that capture its ``jax.Array``'s in the form of JAX pytrees. ``nnx.merge`` is used in conjunction with ``nnx.split`` to switch seamlessly between stateful and stateless representations of the graph. Example usage:: >>> from flax import nnx >>> import jax, jax.numpy as jnp ... >>> class Foo(nnx.Module): ... def __init__(self, rngs): ... self.batch_norm = nnx.BatchNorm(2, rngs=rngs) ... self.linear = nnx.Linear(2, 3, rngs=rngs) ... >>> node = Foo(nnx.Rngs(0)) >>> graphdef, params, batch_stats = nnx.split(node, nnx.Param, nnx.BatchStat) ... >>> new_node = nnx.merge(graphdef, params, batch_stats) >>> assert isinstance(new_node, Foo) >>> assert isinstance(new_node.batch_norm, nnx.BatchNorm) >>> assert isinstance(new_node.linear, nnx.Linear) ``nnx.split`` and ``nnx.merge`` are primarily used to interact directly with JAX transformations (refer to `Functional API <https://flax.readthedocs.io/en/latest/nnx_basics.html#the-flax-functional-api>`__ for more information. Args: graphdef: A :class:`flax.nnx.GraphDef` object. state: A :class:`flax.nnx.State` object. *states: Additional :class:`flax.nnx.State` objects. Returns: The merged :class:`flax.nnx.Module`. """ if isinstance(state, list): if len(states) != 0: raise ValueError( f'Only one state can be passed as a list.' ) _state = state else: _state = _merge_to_flat_state((state, *states)) node = unflatten(graphdef, _state) return node
[docs]def update(node, state: tp.Any, /, *states: tp.Any) -> None: """Update the given graph node with a new state(s) in-place. Example usage:: >>> from flax import nnx >>> import jax, jax.numpy as jnp >>> x = jnp.ones((1, 2)) >>> y = jnp.ones((1, 3)) >>> model = nnx.Linear(2, 3, rngs=nnx.Rngs(0)) >>> def loss_fn(model, x, y): ... return jnp.mean((y - model(x))**2) >>> prev_loss = loss_fn(model, x, y) >>> grads = nnx.grad(loss_fn)(model, x, y) >>> new_state = jax.tree.map(lambda p, g: p - 0.1*g, nnx.state(model), grads) >>> nnx.update(model, new_state) >>> assert loss_fn(model, x, y) < prev_loss Args: node: A graph node to update. state: A :class:`State` object. *states: Additional :class:`State` objects. """ if states: if isinstance(node, Variable): non_empty_states = [ _state for _state in (state, *states) if not isinstance(_state, tp.Mapping) or _state ] if len(non_empty_states) != 1: all_states = (state, *states) raise ValueError( f'Expected exactly one non-empty state, got: {all_states!r}' ) state = non_empty_states[0] else: state = statelib.merge_state(state, *states) _graph_update_dynamic(node, state)
def _variables_generator(node) -> tp.Iterable[tuple[PathParts, Variable]]: for path, value in iter_graph(node): if isinstance(value, Variable): yield path, value @tp.overload def variables(node, /) -> State[Key, Variable]: ... @tp.overload def variables(node, first: filterlib.Filter, /) -> State[Key, Variable]: ... @tp.overload def variables( node, first: filterlib.Filter, second: filterlib.Filter, /, *filters: filterlib.Filter, ) -> tuple[State[Key, Variable], ...]: ...
[docs]def variables( node, *filters: filterlib.Filter, ) -> tp.Union[State[Key, Variable], tuple[State[Key, Variable], ...]]: """Similar to :func:`state` but returns the current :class:`Variable` objects instead of new :class:`VariableState` instances. Example:: >>> from flax import nnx ... >>> model = nnx.Linear(2, 3, rngs=nnx.Rngs(0)) >>> params = nnx.variables(model, nnx.Param) ... >>> assert params['kernel'] is model.kernel >>> assert params['bias'] is model.bias Args: node: A graph node object. *filters: One or more :class:`Variable` objects to filter by. Returns: One or more :class:`State` mappings containing the :class:`Variable` objects. """ num_filters = len(filters) if num_filters == 0: filters = (..., ...) else: filters = (*filters, ...) variables_iterable = _variables_generator(node) flat_states = variablelib.split_flat_state( variables_iterable, (*filters, ...) ) states = tuple( statelib.from_flat_state(flat_state) for flat_state in flat_states ) if num_filters < 2: return states[0] return states
@tp.overload def state(node, /) -> GraphState: ... @tp.overload def state(node, first: filterlib.Filter, /) -> GraphState: ... @tp.overload def state( node, first: filterlib.Filter, second: filterlib.Filter, /, *filters: filterlib.Filter, ) -> tuple[GraphState, ...]: ...
[docs]def state( node, *filters: filterlib.Filter, ) -> tp.Union[GraphState, tuple[GraphState, ...]]: """Similar to :func:`split` but only returns the :class:`State`'s indicated by the filters. Example usage:: >>> from flax import nnx >>> class Model(nnx.Module): ... def __init__(self, rngs): ... self.batch_norm = nnx.BatchNorm(2, rngs=rngs) ... self.linear = nnx.Linear(2, 3, rngs=rngs) ... def __call__(self, x): ... return self.linear(self.batch_norm(x)) >>> model = Model(rngs=nnx.Rngs(0)) >>> # get the learnable parameters from the batch norm and linear layer >>> params = nnx.state(model, nnx.Param) >>> # get the batch statistics from the batch norm layer >>> batch_stats = nnx.state(model, nnx.BatchStat) >>> # get them separately >>> params, batch_stats = nnx.state(model, nnx.Param, nnx.BatchStat) >>> # get them together >>> state = nnx.state(model) Args: node: A graph node object. *filters: One or more :class:`Variable` objects to filter by. Returns: One or more :class:`State` mappings. """ _, flat_state = flatten(node) state = flat_state.to_nested_state() states: GraphState | tuple[GraphState, ...] if len(filters) == 0: states = state # type: ignore[assignment] elif len(filters) == 1: states = statelib.filter_state(state, filters[0]) else: states = statelib.filter_state(state, filters[0], filters[1], *filters[2:]) return states
[docs]def graphdef(node: tp.Any, /) -> GraphDef[tp.Any]: """Get the :class:`GraphDef` of the given graph node. Example usage:: >>> from flax import nnx >>> model = nnx.Linear(2, 3, rngs=nnx.Rngs(0)) >>> graphdef, _ = nnx.split(model) >>> assert graphdef == nnx.graphdef(model) Args: node: A graph node object. Returns: The :class:`GraphDef` of the :class:`Module` object. """ graphdef, _ = flatten(node) return graphdef
@tp.overload def pop( node, filter: filterlib.Filter, /, ) -> GraphState: ... @tp.overload def pop( node, filter: filterlib.Filter, filter2: filterlib.Filter, /, *filters: filterlib.Filter, ) -> tuple[GraphState, ...]: ...
[docs]def pop( node, *filters: filterlib.Filter ) -> tp.Union[GraphState, tuple[GraphState, ...]]: """Pop one or more :class:`Variable` types from the graph node. Example usage:: >>> from flax import nnx >>> import jax.numpy as jnp >>> class Model(nnx.Module): ... def __init__(self, rngs): ... self.linear1 = nnx.Linear(2, 3, rngs=rngs) ... self.linear2 = nnx.Linear(3, 4, rngs=rngs) ... def __call__(self, x): ... x = self.linear1(x) ... self.sow(nnx.Intermediate, 'i', x) ... x = self.linear2(x) ... return x >>> x = jnp.ones((1, 2)) >>> model = Model(rngs=nnx.Rngs(0)) >>> assert not hasattr(model, 'i') >>> y = model(x) >>> assert hasattr(model, 'i') >>> intermediates = nnx.pop(model, nnx.Intermediate) >>> assert intermediates['i'].value[0].shape == (1, 3) >>> assert not hasattr(model, 'i') Args: node: A graph node object. *filters: One or more :class:`Variable` objects to filter by. Returns: The popped :class:`State` containing the :class:`Variable` objects that were filtered for. """ if len(filters) == 0: raise ValueError('Expected at least one filter') id_to_index: dict[int, Index] = {} path_parts: PathParts = () predicates = tuple(filterlib.to_predicate(filter) for filter in filters) flat_states: tuple[dict[PathParts, LeafType], ...] = tuple( {} for _ in predicates ) _graph_pop( node=node, id_to_index=id_to_index, path_parts=path_parts, flat_states=flat_states, predicates=predicates, ) states = tuple( statelib.from_flat_state(flat_state) for flat_state in flat_states ) if len(states) == 1: return states[0] else: return states
[docs]def clone(node: Node) -> Node: """Create a deep copy of the given graph node. Example usage:: >>> from flax import nnx >>> model = nnx.Linear(2, 3, rngs=nnx.Rngs(0)) >>> cloned_model = nnx.clone(model) >>> model.bias.value += 1 >>> assert (model.bias.value != cloned_model.bias.value).all() Args: node: A graph node object. Returns: A deep copy of the :class:`Module` object. """ graphdef, state = split(node) return merge(graphdef, state)
def _mutable_like(path, x): return ( isinstance(x, Variable) and x.mutable ) or variablelib.is_mutable_array(x) def freeze(tree: A, /, only: filterlib.Filter = _mutable_like) -> A: """Converts a pytree of mutable arrays to regular arrays. Example:: >>> from flax import nnx >>> import jax >>> import jax.numpy as jnp ... >>> tree = [nnx.mutable_array(jnp.array(1.0)), jnp.array(2.0)] >>> assert nnx.is_mutable_array(tree[0]) ... >>> frozen_tree = nnx.freeze(tree) >>> assert isinstance(frozen_tree[0], jax.Array) If the tree contains duplicate mutable arrays, a ValueError is raised:: >>> shared_array = nnx.mutable_array(jnp.array(1.0)) >>> tree = [shared_array, shared_array] >>> try: ... nnx.freeze(tree) ... except ValueError as e: ... print(e) Found duplicate MutableArray found at path [1] and [0] ... ``only`` is a `Filter <https://flax.readthedocs.io/en/latest/guides/filters_guide.html>`__ that can be used to specify which mutable arrays to freeze:: >>> tree = [nnx.mutable_array(jnp.array(1.0)), nnx.mutable_array(jnp.array(2.0))] >>> frozen_tree = nnx.freeze(tree, only=lambda path, x: path[0] == 0) ... >>> assert isinstance(frozen_tree[0], jax.Array) >>> assert isinstance(frozen_tree[1], nnx.MutableArray) Args: tree: A pytree potentially containing mutable arrays. only: A Filter to specify which mutable arrays to freeze. Returns: A pytree with the frozen arrays. """ freeze_filter = filterlib.to_predicate(only) mutable_arrays: dict[int, str] = {} def check_mutable_array(path, x): m_array_id = id(x) if m_array_id in mutable_arrays: current_path_str = jax.tree_util.keystr(path) previous_path_str = mutable_arrays[m_array_id] raise ValueError( f'Found duplicate MutableArray found at path {current_path_str} ' f'and {previous_path_str} at object {x}.' ) mutable_arrays[m_array_id] = jax.tree_util.keystr(path) def _freeze_fn(jax_path, x): path = tuple(_key_path_to_key(part) for part in jax_path) if freeze_filter(path, x): if isinstance(x, Variable): check_mutable_array(jax_path, x.raw_value) return x.from_metadata(x[...], x.get_metadata().copy()) elif variablelib.is_mutable_array(x): check_mutable_array(jax_path, x) return x[...] return x tree = jax.tree.map_with_path( _freeze_fn, tree, is_leaf=lambda x: isinstance(x, Variable) ) return tree def _array_like(path, x): return ( isinstance(x, Variable) and isinstance(x.raw_value, jax.Array) ) or isinstance(x, jax.Array) def mutable(tree: A, /, only: filterlib.Filter = _array_like) -> A: """Converts a pytree of arrays to mutable arrays. Example:: >>> from flax import nnx >>> import jax >>> import jax.numpy as jnp ... >>> tree = [jnp.array(1.0), nnx.mutable_array(jnp.array(2.0))] >>> mutable_tree = nnx.mutable(tree) >>> assert nnx.is_mutable_array(mutable_tree[0]) >>> assert nnx.is_mutable_array(mutable_tree[1]) If the tree contains duplicate arrays a ValueError is raised:: >>> shared_array = jnp.array(1.0) >>> tree = [shared_array, shared_array] >>> try: ... nnx.mutable(tree) ... except ValueError as e: ... print(e) Found duplicate Array found at path [1] and [0] ... ``only`` is a `Filter <https://flax.readthedocs.io/en/latest/guides/filters_guide.html>`__ that can be used to specify which arrays to convert to mutable arrays. >>> tree = [jnp.array(1.0), jnp.array(2.0)] >>> mutable_tree = nnx.mutable(tree, only=lambda path, x: path[0] == 0) ... >>> assert isinstance(mutable_tree[0], nnx.MutableArray) >>> assert isinstance(mutable_tree[1], jax.Array) Args: tree: A pytree potentially containing arrays. only: A Filter to specify which arrays to convert to mutable arrays. Returns: A pytree with the mutable arrays. """ mutable_filter = filterlib.to_predicate(only) arrays: dict[int, str] = {} def check_array(path, x): m_array_id = id(x) if m_array_id in arrays: current_path_str = jax.tree_util.keystr(path) previous_path_str = arrays[m_array_id] raise ValueError( f'Found duplicate Array found at path {current_path_str} ' f'and {previous_path_str} at object {x}.' ) arrays[m_array_id] = jax.tree_util.keystr(path) def _mutable_fn(jax_path, x): path = tuple(_key_path_to_key(part) for part in jax_path) if mutable_filter(path, x): if isinstance(x, Variable) and isinstance(x.raw_value, jax.Array): check_array(jax_path, x.raw_value) mutable_array = variablelib.mutable_array(x.raw_value) return x.from_metadata(mutable_array, x.get_metadata().copy()) elif isinstance(x, jax.Array): check_array(jax_path, x) return variablelib.mutable_array(x) return x return jax.tree.map_with_path( _mutable_fn, tree, is_leaf=lambda x: isinstance(x, Variable) ) def pure(tree: A) -> A: """Returns a new tree with all ``Variable`` and ``VariableState`` objects replaced with inner values. This can be used to remove Variable metadata when its is not needed for tasks like serialization or exporting. Example:: >>> from flax import nnx >>> import jax >>> import jax.numpy as jnp ... >>> model = nnx.Linear(2, 3, rngs=nnx.Rngs(0)) >>> graphdef, state = nnx.split(model) >>> jax.tree.map(jnp.shape, state) State({ 'bias': VariableState( type=Param, value=(3,) ), 'kernel': VariableState( type=Param, value=(2, 3) ) }) >>> pure_state = nnx.pure(state) >>> jax.tree.map(jnp.shape, pure_state) State({ 'bias': (3,), 'kernel': (2, 3) }) Args: tree: A pytree potentially containing ``Variable`` and ``VariableState`` objects. Returns: A new pytree with all ``Variable`` and ``VariableState`` objects replaced with their inner values. """ def _pure_fn(x): if isinstance(x, Variable | VariableState): return x.raw_value return x return jax.tree.map( _pure_fn, tree, is_leaf=lambda x: isinstance(x, Variable | VariableState), )
[docs]def call( graphdef_state: tuple[GraphDef[A], GraphState], / ) -> ApplyCaller[tuple[GraphDef[A], GraphState]]: """Calls a method underlying graph node defined by a (GraphDef, State) pair. ``call`` takes a ``(GraphDef, State)`` pair and creates a proxy object that can be used to call methods on the underlying graph node. When a method is called, the output is returned along with a new (GraphDef, State) pair that represents the updated state of the graph node. ``call`` is equivalent to :func:`merge` > ``method`` > :func:`split`` but is more convenient to use in pure JAX functions. Example:: >>> from flax import nnx >>> import jax >>> import jax.numpy as jnp ... >>> class StatefulLinear(nnx.Module): ... def __init__(self, din, dout, rngs): ... self.w = nnx.Param(jax.random.uniform(rngs(), (din, dout))) ... self.b = nnx.Param(jnp.zeros((dout,))) ... self.count = Variable(jnp.array(0, dtype=jnp.uint32)) ... ... def increment(self): ... self.count += 1 ... ... def __call__(self, x): ... self.increment() ... return x @ self.w + self.b ... >>> linear = StatefulLinear(3, 2, nnx.Rngs(0)) >>> linear_state = nnx.split(linear) ... >>> @jax.jit ... def forward(x, linear_state): ... y, linear_state = nnx.call(linear_state)(x) ... return y, linear_state ... >>> x = jnp.ones((1, 3)) >>> y, linear_state = forward(x, linear_state) >>> y, linear_state = forward(x, linear_state) ... >>> linear = nnx.merge(*linear_state) >>> linear.count.value Array(2, dtype=uint32) The proxy object returned by ``call`` supports indexing and attribute access to access nested methods. In the example below, the ``increment`` method indexing is used to call the ``increment`` method of the ``StatefulLinear`` module at the ``b`` key of a ``nodes`` dictionary. >>> class StatefulLinear(nnx.Module): ... def __init__(self, din, dout, rngs): ... self.w = nnx.Param(jax.random.uniform(rngs(), (din, dout))) ... self.b = nnx.Param(jnp.zeros((dout,))) ... self.count = nnx.Variable(jnp.array(0, dtype=jnp.uint32)) ... ... def increment(self): ... self.count += 1 ... ... def __call__(self, x): ... self.increment() ... return x @ self.w + self.b ... >>> rngs = nnx.Rngs(0) >>> nodes = dict( ... a=StatefulLinear(3, 2, rngs), ... b=StatefulLinear(2, 1, rngs), ... ) ... >>> node_state = nnx.split(nodes) >>> # use attribute access >>> _, node_state = nnx.call(node_state)['b'].increment() ... >>> nodes = nnx.merge(*node_state) >>> nodes['a'].count.value Array(0, dtype=uint32) >>> nodes['b'].count.value Array(1, dtype=uint32) """ def pure_caller(accessor: DelayedAccessor, *args, **kwargs): node = merge(*graphdef_state) method = accessor(node) out = method(*args, **kwargs) return out, split(node) return CallableProxy(pure_caller) # type: ignore
[docs]def iter_graph(node: tp.Any, /) -> tp.Iterator[tuple[PathParts, tp.Any]]: """Iterates over all nested nodes and leaves of the given graph node, including the current node. ``iter_graph`` creates a generator that yields path and value pairs, where the path is a tuple of strings or integers representing the path to the value from the root. Repeated nodes are visited only once. Leaves include static values. Example:: >>> from flax import nnx >>> import jax.numpy as jnp ... >>> class Linear(nnx.Module): ... def __init__(self, din, dout, *, rngs: nnx.Rngs): ... self.din, self.dout = din, dout ... self.w = nnx.Param(jax.random.uniform(rngs.next(), (din, dout))) ... self.b = nnx.Param(jnp.zeros((dout,))) ... >>> module = Linear(3, 4, rngs=nnx.Rngs(0)) >>> graph = [module, module] ... >>> for path, value in nnx.iter_graph(graph): ... print(path, type(value).__name__) ... (0, '_object__state') ObjectState (0, 'b') Param (0, 'din') int (0, 'dout') int (0, 'w') Param (0,) Linear () list """ visited: set[int] = set() path_parts: PathParts = () yield from _iter_graph(node, visited, path_parts)
def _iter_graph( node: tp.Any, visited: set[int], path_parts: PathParts ) -> tp.Iterator[tuple[PathParts, tp.Any]]: if is_node(node): if id(node) in visited: return visited.add(id(node)) node_impl = get_node_impl(node) if node_impl is None and not ( isinstance(node, Variable) or variablelib.is_mutable_array(node) ): raise RuntimeError(f'Unsupported type: {type(node)}, this is a bug.') assert node_impl is not None node_dict = node_impl.node_dict(node) for key, value in node_dict.items(): yield from _iter_graph(value, visited, (*path_parts, key)) yield path_parts, node @jax.tree_util.register_static @dataclasses.dataclass(frozen=True, slots=True) class Static(tp.Generic[A]): """An empty pytree node that treats its inner value as static. ``value`` must define ``__eq__`` and ``__hash__``. """ value: A # --------------------------------------------------------- # Pytree # --------------------------------------------------------- class GenericPytree: ... from jax._src.tree_util import _registry as JAX_PYTREE_REGISTRY def is_pytree_node(x: tp.Any) -> bool: if isinstance(x, Variable): return False elif type(x) in JAX_PYTREE_REGISTRY: return True elif isinstance(x, tuple): return True else: return False def _key_path_to_key(key: tp.Any) -> Key: if isinstance(key, jax.tree_util.SequenceKey): return key.idx elif isinstance( key, (jax.tree_util.DictKey, jax.tree_util.FlattenedIndexKey) ): if not is_key_like(key.key): # type: ignore[not-supported-yet] raise ValueError( f'Invalid key: {key.key}. May be due to its type not being hashable or comparable.' ) return key.key elif isinstance(key, jax.tree_util.GetAttrKey): return key.name else: return str(key) class IndexesPytreeDef(tp.NamedTuple): key_index: HashableMapping[Key, int] treedef: jax.tree_util.PyTreeDef def _flatten_pytree(pytree: tp.Any): leaves, treedef = jax.tree_util.tree_flatten_with_path( pytree, is_leaf=lambda x: x is not pytree ) nodes = [(_key_path_to_key(path[0]), value) for path, value in leaves] key_index = HashableMapping( {key: i for i, (key, _) in enumerate(nodes)}, copy=False ) nodes.sort() # sort by key return nodes, IndexesPytreeDef(key_index, treedef) def _unflatten_pytree( nodes: tuple[tuple[Key, tp.Any], ...], metadata: IndexesPytreeDef ): # sort to original order sorted_nodes = sorted(nodes, key=lambda x: metadata.key_index[x[0]]) pytree = metadata.treedef.unflatten(value for _, value in sorted_nodes) return pytree PYTREE_NODE_IMPL = PytreeNodeImpl( type=GenericPytree, flatten=_flatten_pytree, unflatten=_unflatten_pytree, # type: ignore ) # common pytrees # list register_pytree_node_type( list, flatten=lambda x: (list(enumerate(x)), None), unflatten=lambda nodes, _: [value for _, value in nodes], # type: ignore ) # tuple register_pytree_node_type( tuple, flatten=lambda x: (list(enumerate(x)), None), unflatten=lambda nodes, _: tuple(value for _, value in nodes), # type: ignore ) # dict register_pytree_node_type( dict, flatten=lambda x: (sorted(x.items()), None), unflatten=lambda nodes, _: {key: value for key, value in nodes}, # type: ignore ) # None register_pytree_node_type( type(None), flatten=lambda x: ([], None), unflatten=lambda _, __: None, # type: ignore )