# Copyright 2024 The Flax Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# pytype: skip-file
from __future__ import annotations
import dataclasses
import functools
from functools import partial
import typing as tp
from typing import Any
from flax import config
import jax
import treescope # type: ignore[import-untyped]
from flax import errors
from flax.nnx import filterlib, reprlib, tracers, visualization
from flax.typing import Missing, PathParts, SizeBytes
import jax.tree_util as jtu
import jax.numpy as jnp
from jax._src.core import mutable_array, MutableArray
from jax._src.state.types import AbstractRef
# Generic type variables used throughout this module.
A = tp.TypeVar('A')
B = tp.TypeVar('B')
F = tp.TypeVar('F', bound=tp.Callable[..., tp.Any])
# Bound to ``Variable`` so that hooks can be typed against a Variable subclass.
V = tp.TypeVar('V', bound='Variable[Any]')
# Value hooks: called as ``hook(variable, value)`` and return a value.
GetValueHook = tp.Callable[['Variable[A]', A], A]
SetValueHook = tp.Callable[['Variable[A]', A], A]
CreateValueHook = tp.Callable[['Variable[A]', A], A]
AxisName = str
AxisIndex = int
# Axis hooks: called as ``hook(variable, axis_index, axis_name)``.
AddAxisHook = tp.Callable[[V, AxisIndex, AxisName | None], None]
RemoveAxisHook = tp.Callable[[V, AxisIndex, AxisName | None], None]
# Default for ``Variable(mutable=...)``: ``True`` when the
# ``flax_mutable_array`` config flag is set, otherwise ``None``.
MUTABLE_DEFAULT = True if config.flax_mutable_array else None
def is_mutable_array(x) -> tp.TypeGuard[MutableArray]:
  """Return ``True`` when ``x`` is an array whose JAX type is a mutable ref.

  ``x`` must be an instance of ``jax.Array``, ``AbstractRef`` or
  ``MutableArray`` *and* ``jax.typeof(x)`` must be an ``AbstractRef`` or
  ``MutableArray``; anything else yields ``False``.
  """
  # Cheap instance check first; ``jax.typeof`` is only consulted for
  # array-like candidates.
  if not isinstance(x, (jax.Array, AbstractRef, MutableArray)):
    return False
  return isinstance(jax.typeof(x), (AbstractRef, MutableArray))
class Variable(tp.Generic[A], reprlib.Representable):
  """The base class for all ``Variable`` types. Create custom ``Variable``
  types by subclassing this class. Numerous NNX graph functions can filter
  for specific ``Variable`` types, for example, :func:`split`, :func:`state`,
  :func:`pop`, and :func:`State.filter`.

  Example usage::

    >>> from flax import nnx
    >>> import jax, jax.numpy as jnp

    >>> class CustomVariable(nnx.Variable):
    ...   pass

    >>> class Model(nnx.Module):
    ...   def __init__(self, rngs):
    ...     self.linear = nnx.Linear(2, 3, rngs=rngs)
    ...     self.custom_variable = CustomVariable(jnp.ones((1, 3)))
    ...   def __call__(self, x):
    ...     return self.linear(x) + self.custom_variable
    >>> model = Model(rngs=nnx.Rngs(0))

    >>> linear_variables = nnx.state(model, nnx.Param)
    >>> jax.tree.map(jnp.shape, linear_variables)
    State({
      'linear': {
        'bias': VariableState(
          type=Param,
          value=(3,)
        ),
        'kernel': VariableState(
          type=Param,
          value=(2, 3)
        )
      }
    })

    >>> custom_variable = nnx.state(model, CustomVariable)
    >>> jax.tree.map(jnp.shape, custom_variable)
    State({
      'custom_variable': VariableState(
        type=CustomVariable,
        value=(1, 3)
      )
    })

    >>> variables = nnx.state(model)
    >>> jax.tree.map(jnp.shape, variables)
    State({
      'custom_variable': VariableState(
        type=CustomVariable,
        value=(1, 3)
      ),
      'linear': {
        'bias': VariableState(
          type=Param,
          value=(3,)
        ),
        'kernel': VariableState(
          type=Param,
          value=(2, 3)
        )
      }
    })
  """

  __slots__ = ('raw_value', '_trace_state', '_var_metadata')

  # The stored value, before any ``on_get_value`` hook is applied.
  raw_value: A
  # Trace state captured at construction; checked before mutation.
  _trace_state: tracers.TraceState
  # Arbitrary metadata, including the optional ``on_*`` hooks.
  _var_metadata: dict[str, tp.Any]

  def __init__(
    self,
    value: tp.Union[A, VariableMetadata[A]],
    *,
    mutable: bool | None = MUTABLE_DEFAULT,
    **metadata: tp.Any,
  ):
    """Create a variable holding ``value`` with the given ``metadata``.

    Args:
      value: the value to store, or a ``VariableMetadata`` whose
        ``raw_value`` is stored and whose ``metadata`` is merged into
        ``metadata``.
      mutable: ``None`` stores ``value`` as given; ``True`` stores a mutable
        array (wrapping ``jnp.asarray(value)`` unless it already is one);
        ``False`` stores ``jnp.asarray(value)``.
      **metadata: extra attributes kept in the metadata dict.
    """
    var_t = type(self)
    object.__setattr__(self, '_trace_state', tracers.TraceState())

    if isinstance(value, VariableMetadata):
      metadata.update(value.metadata)
      value = tp.cast(A, value.raw_value)

    if mutable is None:
      _value = value
    elif mutable:
      if is_mutable_array(value):
        _value = tp.cast(A, value)
      else:
        _value = mutable_array(jnp.asarray(value))
    else:
      _value = tp.cast(A, jnp.asarray(value))

    object.__setattr__(self, 'raw_value', _value)

    # Hooks defined on the class become metadata entries unless the caller
    # supplied an explicit entry of the same name.
    for hook_name in (
      'on_get_value',
      'on_set_value',
      'on_create_value',
      'on_add_axis',
      'on_remove_axis',
    ):
      if hasattr(var_t, hook_name) and hook_name not in metadata:
        metadata[hook_name] = getattr(var_t, hook_name)

    object.__setattr__(self, '_var_metadata', metadata)
    # run create_value hooks
    object.__setattr__(self, 'raw_value', self.create_value(self.raw_value))

  def __getattr__(self, name: str) -> tp.Any:
    """Resolve unknown attributes from metadata, then from ``raw_value``."""
    if name in object.__getattribute__(self, '_var_metadata'):
      return self._var_metadata[name]
    return getattr(self.raw_value, name)

  def __setattr__(self, name: str, value: tp.Any):
    """Set a slot attribute, or store ``name`` as a metadata entry.

    Raises:
      TraceContextError: if the trace state is no longer valid, unless the
        attribute is ``value`` and the variable is mutable.
    """
    if not self._trace_state.is_valid() and (
      name != 'value' or not self.mutable
    ):
      raise errors.TraceContextError(
        f'Cannot mutate {type(self).__name__} from a different trace level'
      )
    if name in ('value', 'raw_value', '_var_metadata', '_trace_state'):
      object.__setattr__(self, name, value)
    else:
      self._var_metadata[name] = value

  def __delattr__(self, name: str):
    """Delete a slot attribute, or remove ``name`` from the metadata.

    Raises:
      TraceContextError: if the trace state is no longer valid.
    """
    if not self._trace_state.is_valid():
      raise errors.TraceContextError(
        f'Cannot mutate {type(self).__name__} from a different trace level'
      )
    if name in ('value', 'raw_value', '_var_metadata', '_trace_state'):
      object.__delattr__(self, name)
    else:
      del self._var_metadata[name]

  @classmethod
  def state(cls, value: A, **metadata) -> VariableState[A]:
    """Build a variable of this type and return its ``VariableState``."""
    return cls(value, **metadata).to_state()

  @property
  def mutable(self) -> bool | None:
    """``True`` for a mutable array, ``False`` for a ``jax.Array``, else ``None``."""
    if is_mutable_array(self.raw_value):
      return True
    elif isinstance(self.raw_value, jax.Array):
      return False
    else:
      return None

  def get_metadata(self):
    """Return the metadata dict itself (not a copy)."""
    return self._var_metadata

  def copy_from(self, other: Variable[A]) -> None:
    """Copy ``raw_value`` and metadata from ``other`` into this variable.

    Raises:
      ValueError: if ``other`` is not exactly the same type as ``self``.
    """
    if type(self) is not type(other):
      raise ValueError(
        f'Cannot copy from incompatible container, '
        f'expected {type(self).__name__}, got {type(other).__name__}'
      )
    if self is other:
      return
    self.raw_value = other.raw_value
    self._var_metadata.clear()
    self._var_metadata.update(other.get_metadata())

  def update_from_state(self, variable_state: VariableState[A]):
    """Overwrite value and metadata from ``variable_state``.

    Uses ``object.__setattr__`` directly, so the trace-state check in
    ``__setattr__`` is not applied.
    """
    object.__setattr__(self, 'raw_value', variable_state.raw_value)
    object.__setattr__(
      self, '_var_metadata', variable_state._var_metadata.copy()
    )

  @property
  def value(self) -> A:
    """The stored value, dereferenced if mutable and passed through ``on_get_value``."""
    value = self.raw_value
    if is_mutable_array(value):
      value = value[...]
    if 'on_get_value' in self._var_metadata:
      value = self._var_metadata['on_get_value'](self, value)
    return value

  @value.setter
  def value(self, value: A):
    if isinstance(value, Variable):
      raise ValueError(
        'Cannot set value to a Variable, ' 'use `copy_from` method instead'
      )
    if 'on_set_value' in self._var_metadata:
      value = self._var_metadata['on_set_value'](self, value)
    # NOTE(review): this branches on the global config flag rather than on
    # whether ``raw_value`` is actually a mutable array -- confirm this is
    # intended when ``mutable`` was passed explicitly to ``__init__``.
    if config.flax_mutable_array:
      self.raw_value[...] = value  # type: ignore
    else:
      object.__setattr__(self, 'raw_value', value)

  def create_value(self, value: A):
    """Apply the ``on_create_value`` hook, if present, and return the result."""
    if 'on_create_value' in self._var_metadata:
      value = self._var_metadata['on_create_value'](self, value)
    return value

  def add_axis(self, axis_index: AxisIndex, axis_name: AxisName | None):
    """Invoke the ``on_add_axis`` hook, if present."""
    if 'on_add_axis' in self._var_metadata:
      self._var_metadata['on_add_axis'](self, axis_index, axis_name)

  def remove_axis(self, axis_index: AxisIndex, axis_name: AxisName | None):
    """Invoke the ``on_remove_axis`` hook, if present."""
    if 'on_remove_axis' in self._var_metadata:
      self._var_metadata['on_remove_axis'](self, axis_index, axis_name)

  @tp.overload
  def replace(self, value: B, **kwargs) -> Variable[B]: ...

  @tp.overload
  def replace(self, **kwargs) -> Variable[A]: ...

  def replace(self, value: tp.Any = Missing, **kwargs) -> Variable[tp.Any]:
    """Return a new variable with the value and/or metadata replaced.

    Args:
      value: the new raw value. May also be passed as the ``value`` or
        ``raw_value`` keyword. If it is itself a ``Variable`` of the same
        type, that variable is returned (after applying any remaining
        ``kwargs`` to it).
      **kwargs: metadata entries to override.

    Raises:
      ValueError: if the replacement is a ``Variable`` of a different type.
    """
    if value is not Missing:
      kwargs['raw_value'] = value

    # rename `value` to `raw_value`
    if 'value' in kwargs:
      kwargs['raw_value'] = kwargs.pop('value')

    # return `value` if it is a Variable
    if 'raw_value' in kwargs and isinstance(
      value := kwargs['raw_value'], Variable
    ):
      # remove value from kwargs
      kwargs.pop('raw_value')
      if type(self) is not type(value):
        raise ValueError(
          'Cannot replace value from incompatible container, '
          f'expected {type(self).__name__}, got {type(value).__name__}'
        )
      # if kwargs aren't empty, recursively call replace
      # else return variable value
      if kwargs:
        return value.replace(**kwargs)
      else:
        return value

    # return new instance with updated attributes; a metadata-only call
    # (no new value supplied) keeps the current raw value instead of failing
    obj = object.__new__(type(self))
    object.__setattr__(obj, '_trace_state', self._trace_state)
    object.__setattr__(
      obj, 'raw_value', kwargs.pop('raw_value', self.raw_value)
    )
    object.__setattr__(obj, '_var_metadata', self.get_metadata() | kwargs)
    return obj

  @classmethod
  def from_metadata(cls, value: A, attributes: dict[str, tp.Any]):
    """Build an instance directly from a value and a metadata dict.

    ``__init__`` is bypassed, so no hooks run and ``attributes`` is stored
    as-is (not copied).
    """
    obj = object.__new__(cls)
    object.__setattr__(obj, '_trace_state', tracers.TraceState())
    object.__setattr__(obj, 'raw_value', value)
    object.__setattr__(obj, '_var_metadata', attributes)
    return obj

  def copy(self: Variable[A]) -> Variable[A]:
    """Return a new variable sharing ``raw_value`` with a copied metadata dict."""
    obj = object.__new__(type(self))
    object.__setattr__(obj, '_trace_state', self._trace_state)
    object.__setattr__(obj, 'raw_value', self.raw_value)
    object.__setattr__(obj, '_var_metadata', self.get_metadata().copy())
    return obj

  def to_state(self: Variable[A]) -> VariableState[A]:
    """Return a ``VariableState`` carrying this type, raw value and metadata."""
    return VariableState(type(self), self.raw_value, **self._var_metadata)

  def __nnx_repr__(self):
    stats = SizeBytes.from_any(self.raw_value)
    if stats:
      comment = f' # {stats}'
    else:
      comment = ''

    yield reprlib.Object(type=type(self).__name__, comment=comment)
    yield reprlib.Attr('value', self.raw_value)
    for name, value in self._var_metadata.items():
      yield reprlib.Attr(name, repr(value))

  def __treescope_repr__(self, path, subtree_renderer):
    size_bytes = SizeBytes.from_any(self.value)
    if size_bytes:
      stats_repr = f' # {size_bytes}'
      first_line_annotation = treescope.rendering_parts.comment_color(
        treescope.rendering_parts.text(f'{stats_repr}')
      )
    else:
      first_line_annotation = None

    children = {'value': self.raw_value, **self._var_metadata}
    return visualization.render_object_constructor(
      object_type=type(self),
      attributes=children,
      path=path,
      subtree_renderer=subtree_renderer,
      first_line_annotation=first_line_annotation,
    )

  # hooks API (declared for type checkers only; subclasses may define them)
  if tp.TYPE_CHECKING:

    def on_get_value(self, value: A) -> A: ...

    def on_set_value(self, value: A) -> A: ...

    def on_create_value(self, value: A) -> A: ...

    def on_add_axis(
      self: V, axis_index: AxisIndex, axis_name: AxisName | None
    ) -> V: ...

    def on_remove_axis(
      self: V, axis_index: AxisIndex, axis_name: AxisName | None
    ) -> V: ...

  def __jax_array__(self):
    return self.value

  # pickle support
  def __getstate__(self):
    return {
      'raw_value': self.raw_value,
      '_trace_state': self._trace_state,
      '_var_metadata': self._var_metadata,
    }

  def __setstate__(self, state):
    object.__setattr__(self, 'raw_value', state['raw_value'])
    object.__setattr__(self, '_trace_state', state['_trace_state'])
    object.__setattr__(self, '_var_metadata', state['_var_metadata'])

  # --------------------------------------------
  # proxy methods
  # --------------------------------------------
  def __getitem__(self, key) -> jax.Array:
    return self.value[key]  # type: ignore

  def __setitem__(self, key, value) -> None:
    # NOTE(review): the nesting below was reconstructed from a source whose
    # indentation was lost -- confirm against the upstream file.
    if config.flax_mutable_array:
      self.raw_value[key] = value  # type: ignore
    else:
      if not is_mutable_array(self.raw_value):
        if not self._trace_state.is_valid():
          raise errors.TraceContextError(
            f'Cannot mutate {type(self).__name__} from a different trace level'
          )
      if isinstance(self.raw_value, jax.Array):
        self.raw_value = self.raw_value.at[key].set(value)  # type: ignore
      else:
        self.raw_value[key] = value  # type: ignore

  def __call__(self, *args, **kwargs) -> tp.Any:
    return self.value(*args, **kwargs)  # type: ignore

  def __len__(self) -> int:
    return len(self.value)  # type: ignore

  def __iter__(self) -> tp.Iterator:
    return iter(self.value)  # type: ignore

  def __contains__(self, item) -> bool:
    return item in self.value  # type: ignore

  def _binary_op(self, other, op_name: str):
    """Unwrap ``other`` if it is a Variable and call ``self.value.<op_name>``."""
    if isinstance(other, Variable):
      other = other.value
    return getattr(self.value, op_name)(other)

  def _inplace_op(self: V, other, inplace_name: str, fallback_name: str) -> V:
    """Apply an augmented assignment to the value and return ``self``.

    Calls the value's own in-place method when it has one; otherwise computes
    ``self.value.<fallback_name>(other)`` and assigns it back through the
    ``value`` setter.
    """
    if isinstance(other, Variable):
      other = other.value
    value = self.value
    if hasattr(value, inplace_name):
      getattr(value, inplace_name)(other)
    else:
      self.value = getattr(value, fallback_name)(other)
    return self

  # binary operators: delegate to the corresponding method of ``self.value``
  def __add__(self, other) -> A:
    return self._binary_op(other, '__add__')  # type: ignore

  def __sub__(self, other) -> A:
    return self._binary_op(other, '__sub__')  # type: ignore

  def __mul__(self, other) -> A:
    return self._binary_op(other, '__mul__')  # type: ignore

  def __matmul__(self, other) -> A:
    return self._binary_op(other, '__matmul__')  # type: ignore

  def __truediv__(self, other) -> A:
    return self._binary_op(other, '__truediv__')  # type: ignore

  def __floordiv__(self, other) -> A:
    return self._binary_op(other, '__floordiv__')  # type: ignore

  def __mod__(self, other) -> A:
    return self._binary_op(other, '__mod__')  # type: ignore

  def __divmod__(self, other) -> A:
    return self._binary_op(other, '__divmod__')  # type: ignore

  def __pow__(self, other) -> A:
    return self._binary_op(other, '__pow__')  # type: ignore

  def __lshift__(self, other) -> A:
    return self._binary_op(other, '__lshift__')  # type: ignore

  def __rshift__(self, other) -> A:
    return self._binary_op(other, '__rshift__')  # type: ignore

  def __and__(self, other) -> A:
    return self._binary_op(other, '__and__')  # type: ignore

  def __xor__(self, other) -> A:
    return self._binary_op(other, '__xor__')  # type: ignore

  def __or__(self, other) -> A:
    return self._binary_op(other, '__or__')  # type: ignore

  # reflected binary operators
  def __radd__(self, other) -> A:
    return self._binary_op(other, '__radd__')  # type: ignore

  def __rsub__(self, other) -> A:
    return self._binary_op(other, '__rsub__')  # type: ignore

  def __rmul__(self, other) -> A:
    return self._binary_op(other, '__rmul__')  # type: ignore

  def __rmatmul__(self, other) -> A:
    return self._binary_op(other, '__rmatmul__')  # type: ignore

  def __rtruediv__(self, other) -> A:
    return self._binary_op(other, '__rtruediv__')  # type: ignore

  def __rfloordiv__(self, other) -> A:
    return self._binary_op(other, '__rfloordiv__')  # type: ignore

  def __rmod__(self, other) -> A:
    return self._binary_op(other, '__rmod__')  # type: ignore

  def __rdivmod__(self, other) -> A:
    return self._binary_op(other, '__rdivmod__')  # type: ignore

  def __rpow__(self, other) -> A:
    return self._binary_op(other, '__rpow__')  # type: ignore

  def __rlshift__(self, other) -> A:
    return self._binary_op(other, '__rlshift__')  # type: ignore

  def __rrshift__(self, other) -> A:
    return self._binary_op(other, '__rrshift__')  # type: ignore

  def __rand__(self, other) -> A:
    return self._binary_op(other, '__rand__')  # type: ignore

  def __rxor__(self, other) -> A:
    return self._binary_op(other, '__rxor__')  # type: ignore

  def __ror__(self, other) -> A:
    return self._binary_op(other, '__ror__')  # type: ignore

  # augmented assignments
  def __iadd__(self: V, other) -> V:
    return self._inplace_op(other, '__iadd__', '__add__')

  def __isub__(self: V, other) -> V:
    return self._inplace_op(other, '__isub__', '__sub__')

  def __imul__(self: V, other) -> V:
    return self._inplace_op(other, '__imul__', '__mul__')

  def __imatmul__(self: V, other) -> V:
    return self._inplace_op(other, '__imatmul__', '__matmul__')

  def __itruediv__(self: V, other) -> V:
    return self._inplace_op(other, '__itruediv__', '__truediv__')

  def __ifloordiv__(self: V, other) -> V:
    return self._inplace_op(other, '__ifloordiv__', '__floordiv__')

  def __imod__(self: V, other) -> V:
    return self._inplace_op(other, '__imod__', '__mod__')

  def __ipow__(self: V, other) -> V:
    return self._inplace_op(other, '__ipow__', '__pow__')

  def __ilshift__(self: V, other) -> V:
    return self._inplace_op(other, '__ilshift__', '__lshift__')

  def __irshift__(self: V, other) -> V:
    return self._inplace_op(other, '__irshift__', '__rshift__')

  def __iand__(self: V, other) -> V:
    return self._inplace_op(other, '__iand__', '__and__')

  def __ixor__(self: V, other) -> V:
    return self._inplace_op(other, '__ixor__', '__xor__')

  def __ior__(self: V, other) -> V:
    return self._inplace_op(other, '__ior__', '__or__')

  # unary operators and numeric conversions
  def __neg__(self) -> A:
    return self.value.__neg__()  # type: ignore

  def __pos__(self) -> A:
    return self.value.__pos__()  # type: ignore

  def __abs__(self) -> A:
    return self.value.__abs__()  # type: ignore

  def __invert__(self) -> A:
    return self.value.__invert__()  # type: ignore

  def __complex__(self) -> A:
    return self.value.__complex__()  # type: ignore

  def __int__(self) -> A:
    return self.value.__int__()  # type: ignore

  def __float__(self) -> A:
    return self.value.__float__()  # type: ignore

  def __index__(self) -> A:
    return self.value.__index__()  # type: ignore

  def __round__(self, ndigits: int | None = None) -> A:
    # ``round(x)`` calls ``__round__`` without arguments, so ``ndigits`` has
    # to be optional; forward it only when it was actually given.
    if ndigits is None:
      return self.value.__round__()  # type: ignore
    return self.value.__round__(ndigits)  # type: ignore

  def __trunc__(self) -> A:
    return self.value.__trunc__()  # type: ignore

  def __floor__(self) -> A:
    return self.value.__floor__()  # type: ignore

  def __ceil__(self) -> A:
    return self.value.__ceil__()  # type: ignore

  # --------------------------------------------

  def __init_subclass__(cls) -> None:
    """Register every subclass as a JAX pytree node."""
    super().__init_subclass__()

    jax.tree_util.register_pytree_with_keys(
      cls,
      flatten_with_keys=_variable_flatten_with_keys,
      unflatten_func=partial(_variable_unflatten, cls),  # type: ignore
      flatten_func=_variable_flatten,
    )
def _variable_flatten_with_keys(x: Variable[tp.Any]):
  """Flatten ``x`` into one keyed child plus hashable static metadata.

  Returns a pair ``(children, static)`` where ``children`` holds a single
  ``(GetAttrKey('value'), raw_value)`` entry and ``static`` is the metadata
  as a tuple of ``(name, value)`` pairs.
  """
  static = (*x.get_metadata().items(),)
  keyed_child = (jtu.GetAttrKey('value'), x.raw_value)
  return (keyed_child,), static
def _variable_flatten(x: Variable[tp.Any]):
  """Flatten ``x`` into ``((raw_value,), metadata_items)`` without keys."""
  static = (*x.get_metadata().items(),)
  children = (x.raw_value,)
  return children, static
def _variable_unflatten(
  cls: type[Variable[tp.Any]],
  static: tuple[tuple[str, tp.Any], ...],
  children: tuple[tp.Any],
):
  """Rebuild a ``cls`` instance from flattened pytree data.

  ``static`` is the tuple of metadata ``(name, value)`` pairs and
  ``children`` holds the single raw value; both are handed to
  ``cls.from_metadata``.
  """
  (raw_value, *_) = children
  attributes = dict(static)
  return cls.from_metadata(value=raw_value, attributes=attributes)
# Register the base ``Variable`` class itself as a pytree node; subclasses are
# registered through ``Variable.__init_subclass__``.
jax.tree_util.register_pytree_with_keys(
Variable,
flatten_with_keys=_variable_flatten_with_keys,
unflatten_func=partial(_variable_unflatten, Variable), # type: ignore
flatten_func=_variable_flatten,
)
class Param(Variable[A]):
  """The canonical learnable parameter. All learnable parameters
  in NNX layer modules will have the ``Param`` :class:`Variable`
  type::

    >>> from flax import nnx
    >>> import jax, jax.numpy as jnp

    >>> layer = nnx.Linear(2, 3, rngs=nnx.Rngs(0))
    >>> jax.tree.map(jnp.shape, nnx.state(layer))
    State({
      'bias': VariableState(
        type=Param,
        value=(3,)
      ),
      'kernel': VariableState(
        type=Param,
        value=(2, 3)
      )
    })
  """

  pass
class BatchStat(Variable[A]):
  """The mean and variance batch statistics stored in
  the :class:`BatchNorm` layer. Note, these are not the
  learnable scale and bias parameters, but rather the
  running average statistics that are typically used
  during post-training inference::

    >>> from flax import nnx
    >>> import jax, jax.numpy as jnp

    >>> layer = nnx.BatchNorm(3, rngs=nnx.Rngs(0))
    >>> jax.tree.map(jnp.shape, nnx.state(layer))
    State({
      'bias': VariableState(
        type=Param,
        value=(3,)
      ),
      'mean': VariableState(
        type=BatchStat,
        value=(3,)
      ),
      'scale': VariableState(
        type=Param,
        value=(3,)
      ),
      'var': VariableState(
        type=BatchStat,
        value=(3,)
      )
    })
  """

  pass
class Cache(Variable[A]):
  """Autoregressive cache in :class:`MultiHeadAttention`::

    >>> from flax import nnx
    >>> import jax, jax.numpy as jnp

    >>> layer = nnx.MultiHeadAttention(
    ...   num_heads=2,
    ...   in_features=3,
    ...   qkv_features=6,
    ...   out_features=6,
    ...   decode=True,
    ...   rngs=nnx.Rngs(0),
    ... )

    >>> layer.init_cache((1, 3))
    >>> jax.tree.map(jnp.shape, nnx.state(layer, nnx.Cache))
    State({
      'cache_index': VariableState(
        type=Cache,
        value=()
      ),
      'cached_key': VariableState(
        type=Cache,
        value=(1, 2, 3)
      ),
      'cached_value': VariableState(
        type=Cache,
        value=(1, 2, 3)
      )
    })
  """

  pass
# NOTE(review): ``Intermediate`` is not defined in the visible part of this
# file -- confirm it is declared before this class.
class Perturbation(Intermediate[A]):
  """:class:`Variable` type that is typically used for
  :func:`Module.perturb`::

    >>> from flax import nnx
    >>> import jax, jax.numpy as jnp

    >>> class Model(nnx.Module):
    ...   def __init__(self, rngs):
    ...     self.linear1 = nnx.Linear(2, 3, rngs=rngs)
    ...     self.linear2 = nnx.Linear(3, 4, rngs=rngs)
    ...   def __call__(self, x):
    ...     x = self.linear1(x)
    ...     x = self.perturb('i', x)
    ...     x = self.linear2(x)
    ...     return x
    >>> model = Model(rngs=nnx.Rngs(0))

    >>> x = jnp.ones((1, 2))
    >>> y = model(x)
    >>> jax.tree.map(jnp.shape, nnx.state(model, nnx.Perturbation))
    State({
      'i': VariableState(
        type=Perturbation,
        value=(1, 3)
      )
    })
  """

  pass
class VariableState(tp.Generic[A], reprlib.Representable):
  """A ``Variable`` subclass reference paired with a value and metadata.

  Instances are produced by ``Variable.to_state`` and can be turned back into
  a variable with :meth:`to_variable`. Attribute names other than ``type``,
  ``value``, ``raw_value`` and ``_var_metadata`` are read from and written to
  the metadata dict.
  """

  __slots__ = ('type', 'value', '_var_metadata')

  # The ``Variable`` subclass this state was taken from.
  type: type[Variable[A]]
  # The stored value; reading it dereferences mutable arrays.
  value: A
  _var_metadata: dict[str, tp.Any]

  def __init__(
    self,
    type: type[Variable[A]],  # type: ignore [valid-type]
    value: A,
    **metadata,
  ):
    object.__setattr__(self, 'type', type)
    object.__setattr__(self, 'value', value)
    object.__setattr__(self, '_var_metadata', metadata)

  @property
  def raw_value(self) -> A:
    """The stored value exactly as held, without dereferencing."""
    return object.__getattribute__(self, 'value')

  @raw_value.setter
  def raw_value(self, value: A) -> None:
    object.__setattr__(self, 'value', value)

  def __getattribute__(self, name: str) -> tp.Any:
    # ``value`` is intercepted so that a mutable array is read out with
    # ``[...]``; every other name takes the default lookup path.
    if name == 'value':
      value = object.__getattribute__(self, 'value')
      if is_mutable_array(value):
        value = value[...]
      return value
    return object.__getattribute__(self, name)

  def __getattr__(self, name: str) -> tp.Any:
    """Look up ``name`` in the metadata when normal lookup fails.

    Raises:
      AttributeError: if ``name`` is not a metadata entry.
    """
    var_metadata = object.__getattribute__(self, '_var_metadata')
    if name not in var_metadata:
      raise AttributeError(f"'VariableState' object has no attribute '{name}'")
    return var_metadata[name]

  def __setattr__(self, name: str, value: Any) -> None:
    if name in ('type', 'value', '_var_metadata', 'raw_value'):
      object.__setattr__(self, name, value)
    else:
      self._var_metadata[name] = value

  def __delattr__(self, name: str) -> None:
    if name in ('type', 'value', '_var_metadata', 'raw_value'):
      object.__delattr__(self, name)
    else:
      del self._var_metadata[name]

  def __getitem__(self, key: Any) -> jax.Array:
    return self.raw_value[key]  # type: ignore

  def __setitem__(self, key: Any, value: Any) -> None:
    self.raw_value[key] = value  # type: ignore

  def __nnx_repr__(self):
    stats = SizeBytes.from_any(self.raw_value)
    if stats:
      comment = f' # {stats}'
    else:
      comment = ''

    yield reprlib.Object(type=type(self), comment=comment)
    yield reprlib.Attr('type', self.type)
    yield reprlib.Attr('value', self.raw_value)

    for name, value in self._var_metadata.items():
      yield reprlib.Attr(name, value)

  def __treescope_repr__(self, path, subtree_renderer):
    size_bytes = SizeBytes.from_any(self.raw_value)
    if size_bytes:
      stats_repr = f' # {size_bytes}'
      first_line_annotation = treescope.rendering_parts.comment_color(
        treescope.rendering_parts.text(f'{stats_repr}')
      )
    else:
      first_line_annotation = None

    children = {'type': self.type, 'value': self.value, **self._var_metadata}
    return visualization.render_object_constructor(
      object_type=type(self),
      attributes=children,
      path=path,
      subtree_renderer=subtree_renderer,
      first_line_annotation=first_line_annotation,
    )

  def replace(self, value: B) -> VariableState[B]:
    """Return a new state with the same type and metadata but a new value."""
    return VariableState(self.type, value, **self.get_metadata())

  def to_variable(self) -> Variable[A]:
    """Build a variable of ``self.type`` from this state's value and metadata."""
    # we use object.__new__ to avoid calling __init__ and bypass the
    # __init__ logic which should not be called twice
    variable = object.__new__(self.type)
    object.__setattr__(variable, '_trace_state', tracers.TraceState())
    object.__setattr__(variable, 'raw_value', self.raw_value)
    object.__setattr__(variable, '_var_metadata', self.get_metadata().copy())
    return variable

  def copy(self: VariableState[A]) -> VariableState[A]:
    """Return a copy produced by an identity ``jax.tree.map`` over this state."""
    return jax.tree.map(lambda x: x, self)

  def get_metadata(self) -> dict[str, tp.Any]:
    """Return the metadata dict itself (not a copy)."""
    return self._var_metadata

  def add_axis(self, axis_index: AxisIndex, axis_name: AxisName | None):
    """Invoke the ``on_add_axis`` hook, if present."""
    if 'on_add_axis' in self._var_metadata:
      self._var_metadata['on_add_axis'](self, axis_index, axis_name)

  def remove_axis(self, axis_index: AxisIndex, axis_name: AxisName | None):
    """Invoke the ``on_remove_axis`` hook, if present."""
    if 'on_remove_axis' in self._var_metadata:
      self._var_metadata['on_remove_axis'](self, axis_index, axis_name)
# Alias for a ``VariableState`` whose value is itself a ``VariableState``.
GraphVariableState = VariableState[VariableState[tp.Any]]
def _variable_state_flatten(x: VariableState[tp.Any], *, with_keys: bool):
  """Flatten ``x`` into one child plus ``(type, metadata_items)`` static data.

  When ``with_keys`` is true the child is paired with
  ``GetAttrKey('value')``; otherwise the raw value is returned on its own.
  """
  static_metadata = (*x.get_metadata().items(),)
  raw = x.raw_value
  child = (jtu.GetAttrKey('value'), raw) if with_keys else raw
  return (child,), (x.type, static_metadata)
def _variable_state_unflatten(
  static: tuple[type[Variable[A]], tuple[tuple[str, tp.Any], ...]],
  children: tuple[A],
) -> VariableState[A]:
  """Rebuild a ``VariableState`` from its flattened pytree form.

  ``static`` is ``(variable_type, metadata_items)`` and ``children`` holds
  the single value.
  """
  variable_type = static[0]
  metadata_items = static[1]
  return VariableState(
    type=variable_type,
    value=children[0],
    **dict(metadata_items),
  )
# Register ``VariableState`` as a pytree node, with both a keyed and an
# unkeyed flatten function.
jtu.register_pytree_with_keys(
VariableState,
partial(_variable_state_flatten, with_keys=True), # type: ignore
_variable_state_unflatten, # type: ignore
flatten_func=partial(_variable_state_flatten, with_keys=False), # type: ignore
)
def split_flat_state(
  flat_state: tp.Iterable[tuple[PathParts, Variable | VariableState]],
  filters: tuple[filterlib.Filter, ...],
) -> tuple[list[tuple[PathParts, Variable | VariableState]], ...]:
  """Partition ``flat_state`` into one list per filter.

  Each ``(path, value)`` pair goes to the list of the first filter whose
  predicate accepts it.

  Raises:
    ValueError: if a pair is accepted by none of the filters.
  """
  predicates = filterlib.filters_to_predicates(filters)

  # one output bucket per predicate, in filter order
  buckets: tuple[list[tuple[PathParts, Variable | VariableState]], ...] = (
    tuple([] for _ in predicates)
  )

  for path, value in flat_state:
    # first bucket whose predicate matches, or None when nothing matches
    target = next(
      (
        bucket
        for predicate, bucket in zip(predicates, buckets)
        if predicate(path, value)
      ),
      None,
    )
    if target is None:
      raise ValueError(
        'Non-exhaustive filters, got a non-empty remainder: '
        f'{path} -> {value}.'
        '\nUse `...` to match all remaining elements.'
      )
    target.append((path, value))

  return buckets
###################################################
### Variable type/class <-> string name mapping ###
###################################################
# Assumption: the mapping is 1-1 and unique.
# Module-level registry mapping a collection name to its ``Variable`` type.
VariableTypeCache: dict[str, tp.Type[Variable[tp.Any]]] = {}
def variable_type_from_name(
  name: str,
  /,
  *,
  base: type[Variable[tp.Any]] = Variable,
  allow_register: bool = False,
) -> tp.Type[Variable[tp.Any]]:
  """Given a Linen-style collection name, get or create its NNX Variable class.

  Args:
    name: the collection name to look up in ``VariableTypeCache``.
    base: base class for the new type created when ``name`` is unknown and
      ``allow_register`` is true.
    allow_register: whether an unknown ``name`` may be registered on the fly.

  Returns:
    The ``Variable`` type registered under ``name``.

  Raises:
    ValueError: if ``name`` is not registered and ``allow_register`` is false.
  """
  if name not in VariableTypeCache:
    if not allow_register:
      raise ValueError(
        f'Name {name} is not registered in the registry. '
        'To register a new name, use register_variable_name() '
        'or set allow_register=True.'
      )
    # create a fresh subclass of ``base`` named after the collection
    VariableTypeCache[name] = type(name, (base,), {})
  return VariableTypeCache[name]
def variable_name_from_type(
  typ: tp.Type[Variable[tp.Any]], /, *, allow_register: bool = False
) -> str:
  """Given an NNX Variable type, get its Linen-style collection name.

  Should output the exact inversed result of `variable_type_from_name()`.

  Args:
    typ: the ``Variable`` type to look up in ``VariableTypeCache``.
    allow_register: whether an unknown type may be registered under its
      ``__name__``.

  Returns:
    The name under which ``typ`` is registered.

  Raises:
    ValueError: if ``typ`` is not registered and ``allow_register`` is false,
      or if ``typ.__name__`` is already mapped to a different type.
  """
  for name, t in VariableTypeCache.items():
    if typ == t:
      return name

  if not allow_register:
    raise ValueError(
      f'Type {typ} is not registered in the registry. '
      'To register a new type, use register_variable_name() '
      'or set allow_register=True.'
    )
  name = typ.__name__
  if name in VariableTypeCache:
    # f-strings so the offending name and types actually appear in the message
    raise ValueError(
      f'Name {name} is already registered in the registry as {VariableTypeCache[name]}. '
      f'It cannot be linked with this type {typ}.'
    )
  register_variable_name(name, typ)
  return name
class _Missing:
  """Marker type for the 'argument not supplied' sentinel."""


# Singleton used as the default for optional positional arguments.
_MISSING = _Missing()
@tp.overload
def register_variable_name(
  name: str,
  typ: type[Variable[tp.Any]],
  *,
  overwrite: bool = False,
) -> type[Variable[tp.Any]]:
  ...


@tp.overload
def register_variable_name(
  name: str,
  *,
  overwrite: bool = False,
) -> tp.Callable[[type[Variable[tp.Any]]], type[Variable[tp.Any]]]:
  ...


def register_variable_name(
  name: str,
  typ: type[Variable[A]] | _Missing = _MISSING,
  *,
  overwrite=False,
) -> type[Variable[A]] | tp.Callable[[type[Variable[A]]], type[Variable[A]]]:
  """Register a pair of Linen collection name and its NNX type.

  Can be called directly with a type, or with only ``name`` to obtain a
  decorator that registers the decorated class.

  Args:
    name: the collection name.
    typ: the ``Variable`` type to map ``name`` to; omit to get a decorator.
    overwrite: whether an existing mapping for ``name`` may be replaced.

  Returns:
    ``typ`` itself, or a one-argument callable when ``typ`` is omitted.

  Raises:
    ValueError: if ``name`` is already registered and ``overwrite`` is false.
  """
  if typ is _MISSING:
    # decorator form: defer until the class is supplied
    return partial(register_variable_name, name, overwrite=overwrite)
  typ = tp.cast(type[Variable[A]], typ)
  if not overwrite and name in VariableTypeCache:
    raise ValueError(
      f'Name {name} already mapped to type {VariableTypeCache[name]}. '
      'To overwrite, call register_variable_name() with `overwrite=True`.'
    )
  VariableTypeCache[name] = typ
  return typ
# Register the built-in collection names against their Variable types at
# import time.
register_variable_name('params', Param)
register_variable_name('batch_stats', BatchStat)
register_variable_name('cache', Cache)
register_variable_name('intermediates', Intermediate)
register_variable_name('perturbations', Perturbation)