Source code for flax.nnx.dataclasses

# Copyright 2024 The Flax Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

import dataclasses
import inspect
import sys
import typing as tp

import typing_extensions as tpe

from flax import struct
from flax.nnx import object as objectlib

# Generic value type used by the ``Static`` alias and the ``field``/``static``
# overload signatures.
A = tp.TypeVar('A')
# Class type accepted and returned by ``dataclass``; bound to (the type of)
# ``objectlib.Object`` subclasses.
T = tp.TypeVar('T', bound=type[objectlib.Object])


class StaticTag:
  """Marker class placed in ``typing.Annotated`` metadata by ``Static``.

  ``_is_static`` looks for this class object itself (not an instance of it)
  in an annotation's ``__metadata__`` to decide that an attribute is static.
  """


# ``Static[X]`` expands to ``Annotated[X, StaticTag]``: the underlying type X is
# unchanged for type checkers, while ``_is_static`` recognises the ``StaticTag``
# marker in an evaluated annotation and ``dataclass`` then leaves the attribute
# out of the pytree node set.
Static = tp.Annotated[A, StaticTag]  # type: ignore[invalid-typevar]


def _is_static(annotation: tp.Any, cls_attr: tp.Any) -> bool:
  """Returns whether an annotated class attribute should be treated as static.

  An attribute is static (i.e. not a pytree node) when either:

  * its annotation is a ``typing.Annotated`` type whose metadata contains the
    ``StaticTag`` marker (as produced by ``Static[...]``), or
  * its class-level value is a :class:`dataclasses.Field` whose metadata has a
    truthy ``'static'`` entry or a falsy ``'pytree_node'`` entry.

  Args:
    annotation: the attribute's annotation. Unevaluated (string) annotations
      have no ``__metadata__`` and so are never detected by the first rule.
    cls_attr: the attribute's class-level value (``dataclass`` passes ``None``
      when the class has no such attribute).

  Returns:
    ``True`` if the attribute is static, ``False`` otherwise.
  """
  # The bare ``ClassVar`` annotation is excluded from the ``Static`` check.
  if annotation != tp.ClassVar:
    metadata = getattr(annotation, '__metadata__', None)
    # ``Annotated`` flattens nested metadata: ``Annotated[Static[int], 'doc']``
    # has ``__metadata__ == (StaticTag, 'doc')``. Look for the tag anywhere
    # (by identity) instead of requiring the metadata to be exactly
    # ``(StaticTag,)``, which would silently turn such fields into nodes.
    if isinstance(metadata, tuple) and any(m is StaticTag for m in metadata):
      return True
  if isinstance(cls_attr, dataclasses.Field):
    return bool(
      cls_attr.metadata.get('static', False)
      or not cls_attr.metadata.get('pytree_node', True)
    )
  return False


# Sentinel meaning "no value supplied", taken from ``dataclasses`` and used as
# the default for ``default`` / ``default_factory`` in the ``field`` and
# ``static`` overload signatures, mirroring :func:`dataclasses.field`.
MISSING = dataclasses.MISSING


# Typing-only overload of ``field``: exposes the keyword parameters of
# :func:`dataclasses.field` plus ``static`` so type checkers can validate
# calls. The runtime implementation is the un-decorated ``field`` definition.
# NOTE(review): ``kw_only`` defaults to ``False`` here while
# :func:`dataclasses.field` uses ``MISSING``; this stub is never executed, so
# only the displayed signature is affected.
@tp.overload # type: ignore[misc]
def field(  # type: ignore[misc]
  *,
  default: tp.Any = MISSING,
  default_factory: tp.Callable[[], A] | tp.Any = MISSING,
  init: bool = True,
  repr: bool = True,
  hash: bool | None = None,
  compare: bool = True,
  metadata: tp.Mapping[str, tp.Any] | None = None,
  kw_only: bool = False,
  static: bool = False,
) -> tp.Any: ...


def field(
  *,
  static: bool = False,
  **kwargs,
):
  """Creates a :class:`dataclasses.Field` carrying a ``'static'`` metadata flag.

  Args:
    static: whether the field is static (not a pytree node). The value is
      stored under the ``'static'`` key of the field metadata, which
      ``_is_static`` reads.
    **kwargs: forwarded to :func:`dataclasses.field`. A ``metadata`` mapping,
      if given, is copied (the caller's mapping is not mutated) and the
      ``'static'`` entry is added to the copy.

  Returns:
    The :class:`dataclasses.Field` built by :func:`dataclasses.field`.

  Raises:
    ValueError: if ``metadata`` already has a ``'static'`` entry that differs
      from ``static``.
  """
  metadata = kwargs.pop('metadata', None)
  # Work on a fresh dict so the caller's mapping is never modified.
  metadata = {} if not metadata else dict(metadata)
  if 'static' in metadata and metadata['static'] != static:
    # NOTE: the local must stay named ``metadata``; the f-string ``=`` form
    # embeds the expression text in the message.
    raise ValueError(
      f'Inconsistent static metadata, field specified {static=} '
      f'but also got {metadata["static"]=}'
    )
  metadata['static'] = static
  return dataclasses.field(metadata=metadata, **kwargs)  # type: ignore[wrong-arg-type]
# Typing-only overload of ``static``: same keywords as ``field`` minus
# ``static`` itself.
@tp.overload  # type: ignore[misc]
def static(  # type: ignore[misc]
  *,
  default: tp.Any = MISSING,
  default_factory: tp.Callable[[], A] | tp.Any = MISSING,
  init: bool = True,
  repr: bool = True,
  hash: bool | None = None,
  compare: bool = True,
  metadata: tp.Mapping[str, tp.Any] | None = None,
  kw_only: bool = False,
) -> tp.Any: ...


def static(**kwargs):
  """Declares a static dataclass field; shorthand for ``field(static=True, ...)``.

  Every keyword argument is forwarded to ``field``. Passing ``static``
  explicitly raises ``TypeError`` (duplicate keyword argument).
  """
  return field(**kwargs, static=True)


# Overload for bare decorator use: ``@dataclass`` applied directly to a class.
@tp.overload
def dataclass(
  cls: T,
  /,
  *,
  init: bool = True,
  eq: bool = True,
  order: bool = False,
  unsafe_hash: bool = False,
  match_args: bool = True,
  kw_only: bool = False,
  slots: bool = False,
) -> T: ...


# Overload for parameterized use: ``@dataclass(...)`` returns a decorator.
@tp.overload
def dataclass(
  *,
  init: bool = True,
  eq: bool = True,
  order: bool = False,
  unsafe_hash: bool = False,
  match_args: bool = True,
  kw_only: bool = False,
  slots: bool = False,
) -> tp.Callable[[T], T]: ...
@tpe.dataclass_transform(  # type: ignore[not-supported-yet]
  field_specifiers=(field, static, dataclasses.field, struct.field),
)
def dataclass(
  cls: T | None = None,
  /,
  *,
  init: bool = True,
  eq: bool = True,
  order: bool = False,
  unsafe_hash: bool = False,
  match_args: bool = True,
  kw_only: bool = False,
  slots: bool = False,
) -> T | tp.Callable[[T], T]:
  """Turns an ``nnx.Object`` subclass into a dataclass with type-hinted nodes.

  ``nnx.dataclass`` can be used to create pytree dataclass types using type
  hints instead of the ``__data__`` attribute. By default, every annotated
  attribute is considered a node. To mark an attribute as static, annotate it
  with ``nnx.Static[T]``, or give it a ``field(static=True)`` / ``static()``
  default (a ``dataclasses.Field`` whose metadata sets ``pytree_node`` to
  ``False`` is also treated as static).

  Example::

    from flax import nnx
    import jax

    @nnx.dataclass
    class Foo(nnx.Object):
      a: int
      b: jax.Array
      c: nnx.Static[int]

    tree = Foo(a=1, b=jax.numpy.array(1), c=1)
    assert len(jax.tree.leaves(tree)) == 2  # a and b

  Only the annotations declared directly on the decorated class are inspected.
  String annotations (e.g. in modules using ``from __future__ import
  annotations``) are evaluated in the class's module namespace so that
  ``nnx.Static[...]`` is still recognized; annotations that cannot be
  evaluated are treated as nodes. If the class already has an explicit
  collection of node names, the nodes found in the annotations are added to
  it.

  ``nnx.dataclass`` doesn't accept ``repr`` and always passes ``repr=False``
  to :func:`dataclasses.dataclass` to avoid overwriting the default
  ``__repr__`` method from ``Object``.

  Args:
    cls: the class to transform. When omitted, a decorator is returned.
    init, eq, order, unsafe_hash, match_args, kw_only, slots: forwarded
      unchanged to :func:`dataclasses.dataclass`.

  Returns:
    The class produced by :func:`dataclasses.dataclass`, or a decorator that
    produces it when ``cls`` is ``None``.

  Raises:
    ValueError: if the decorated object is not a class derived from
      ``nnx.Object``, or if its ``__data__`` is set to ``'auto'`` or ``'all'``.
  """

  def _dataclass(cls: T):
    # ``issubclass`` raises TypeError for non-class arguments; check for a
    # class first so callers always get the documented ValueError.
    if not isinstance(cls, type) or not issubclass(cls, objectlib.Object):
      raise ValueError(
        'dataclass can only be used with a class derived from nnx.Object'
      )
    if cls._object__nodes in ('auto', 'all'):
      raise ValueError(
        "dataclass cannot be used with a class that has __data__ set to 'auto' or 'all', "
        f'got {cls._object__nodes}'
      )
    # Redefine _object__nodes from the type hints, keeping any node names the
    # class (or its parents) already declared.
    if cls._object__nodes is None:
      all_nodes = set()
    else:
      all_nodes = set(cls._object__nodes)
    for name, typ in _own_annotations(cls).items():
      class_attr = getattr(cls, name, None)
      if not _is_static(typ, class_attr):
        all_nodes.add(name)
    cls._object__nodes = frozenset(all_nodes)
    cls = dataclasses.dataclass(  # type: ignore
      cls,
      init=init,
      repr=False,
      eq=eq,
      order=order,
      unsafe_hash=unsafe_hash,
      match_args=match_args,
      kw_only=kw_only,
      slots=slots,
    )
    return cls

  if cls is None:
    return _dataclass
  return _dataclass(cls)


def _own_annotations(cls: type) -> dict[str, tp.Any]:
  """Returns the annotations declared directly on ``cls``, evaluated if possible.

  String annotations (e.g. produced by ``from __future__ import annotations``
  in the module that defines ``cls``) are evaluated one at a time against that
  module's globals and the class namespace, so aliases such as ``Static[...]``
  resolve to their ``typing.Annotated`` form. An annotation that fails to
  evaluate (for example a forward reference to a not-yet-defined name) is
  kept as the original string.
  """
  module = sys.modules.get(cls.__module__)
  globalns = vars(module) if module is not None else {}
  localns = dict(vars(cls))
  resolved: dict[str, tp.Any] = {}
  for name, annotation in inspect.get_annotations(cls).items():
    if isinstance(annotation, str):
      try:
        # The strings come from the class's own source annotations (the same
        # input typing.get_type_hints evaluates), not from external data.
        annotation = eval(annotation, globalns, localns)  # noqa: S307
      except Exception:  # pylint: disable=broad-except
        # Best effort: unresolvable annotations stay strings and are treated
        # as nodes, matching the behaviour for raw string annotations.
        pass
    resolved[name] = annotation
  return resolved