Source code for flax.nnx.spmd

# Copyright 2024 The Flax Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import functools
import typing as tp

import flax.core.spmd as core_spmd
from flax.nnx import variablelib
from flax.typing import (
  Array,
  ArrayPytree,  # pylint: disable=invalid-name
  PartitionSpecPytree,  # pylint: disable=invalid-name
  Sharding,
)
import jax
from jax.interpreters import pxla
from jax.sharding import PartitionSpec

A = tp.TypeVar('A')
F = tp.TypeVar('F', bound=tp.Callable[..., tp.Any])
PARTITION_NAME = 'partition_name'

class HasSharding(tp.Protocol):
  """Structural type for objects that expose a ``sharding`` attribute.

  ``sharding`` is either ``None`` or a tuple with one entry per axis, where
  each entry is an axis name (``str``) or ``None``.
  """

  sharding: tp.Optional[tuple[tp.Optional[str], ...]]


def _has_sharding(x: tp.Any) -> tp.TypeGuard[HasSharding]:
  """Returns True when ``x`` has a ``sharding`` attribute that is not None."""
  # A missing attribute and an explicit ``None`` are treated the same way.
  return getattr(x, 'sharding', None) is not None

def add_axis(tree: A, index: int, transform_metadata: tp.Mapping) -> A:
  """Records a new axis at ``index`` on every ``VariableState`` in ``tree``.

  For each ``VariableState`` leaf this:

  * inserts the partition name from ``transform_metadata`` into the leaf's
    ``sharding`` tuple (when it has a non-None one),
  * inserts each remaining ``transform_metadata`` value into the leaf
    attribute of the same name, when that attribute is a non-empty tuple,
  * calls the leaf's own ``add_axis(index, axis_name)``.

  Leaves are updated in place; other leaves are returned untouched.

  Args:
    tree: pytree that may contain ``VariableState`` leaves.
    index: position at which the new axis entry is inserted.
    transform_metadata: mapping that must contain ``'partition_name'``; all
      other entries are treated as per-axis metadata.

  Returns:
    The result of mapping the update over ``tree``.

  Raises:
    ValueError: if ``'partition_name'`` is missing from
      ``transform_metadata``.
  """
  axis_name, other_meta = _get_partition_name_and_metadata(transform_metadata)

  def _is_state(node):
    return isinstance(node, variablelib.VariableState)

  def _padded_insert(entries, position, item):
    # Pad with ``None`` so that ``position`` is reachable, then insert there.
    padded = list(entries)
    padded.extend([None] * (position - len(padded)))
    padded.insert(position, item)
    return tuple(padded)

  def _update_state(node: tp.Any):
    if not _is_state(node):
      return node

    if _has_sharding(node):
      node.sharding = _padded_insert(node.sharding, index, axis_name)

    for key, value in other_meta.items():
      current = getattr(node, key, None)
      # Only non-empty tuple attributes carry per-axis metadata.
      if current and isinstance(current, tuple):
        setattr(node, key, _padded_insert(current, index, value))

    node.add_axis(index, axis_name)
    return node

  return jax.tree.map(_update_state, tree, is_leaf=_is_state)


def remove_axis(tree: A, index: int, transform_metadata: tp.Mapping[tp.Any, tp.Any]) -> A:
  """Removes the axis entry at ``index`` from every ``VariableState`` in ``tree``.

  For each ``VariableState`` leaf this:

  * pops the entry at ``index`` from the leaf's ``sharding`` tuple (when it
    has a non-None one) and checks that it equals the partition name,
  * pops the entry at ``index`` from every leaf attribute named by one of the
    remaining ``transform_metadata`` keys, when that attribute is a non-empty
    tuple, and checks that it equals the corresponding metadata value,
  * calls the leaf's own ``remove_axis(index, axis_name)``.

  Leaves are updated in place; other leaves are returned untouched.

  Args:
    tree: pytree that may contain ``VariableState`` leaves.
    index: position of the axis entry to remove.
    transform_metadata: mapping that must contain ``'partition_name'``; all
      other entries are treated as per-axis metadata.

  Returns:
    The result of mapping the update over ``tree``.

  Raises:
    ValueError: if ``'partition_name'`` is missing from
      ``transform_metadata``.
    AssertionError: if a removed entry does not match the expected value
      (only when assertions are enabled).
    IndexError: if ``index`` is out of range for a tuple being updated.
  """
  axis_name, other_meta = _get_partition_name_and_metadata(transform_metadata)

  def remove_field(fields, index, value):
    iterable = list(fields)
    # The pop must happen outside the ``assert``: under ``python -O`` assert
    # statements are stripped, and the entry would otherwise never be removed.
    removed = iterable.pop(index)
    assert removed == value, (
      f'Expected to remove {value!r} at index {index}, found {removed!r}'
    )
    return tuple(iterable)

  def _remove_axis(x: tp.Any):
    if isinstance(x, variablelib.VariableState):
      if _has_sharding(x):
        x.sharding = remove_field(x.sharding, index, axis_name)

      for k, v in other_meta.items():
        # Only non-empty tuple attributes carry per-axis metadata.
        if hasattr(x, k) and (t := getattr(x, k)) and isinstance(t, tuple):
          setattr(x, k, remove_field(t, index, v))

      x.remove_axis(index, axis_name)
    return x

  return jax.tree.map(
    _remove_axis,
    tree,
    is_leaf=lambda x: isinstance(x, variablelib.VariableState),
  )


def _get_partition_name_and_metadata(
  transform_metadata: tp.Mapping[tp.Any, tp.Any]
) -> tuple[str, tp.Mapping[tp.Any, tp.Any]]:
  """Splits ``transform_metadata`` into the partition name and the rest.

  Args:
    transform_metadata: mapping expected to contain the ``PARTITION_NAME``
      key.

  Returns:
    A ``(partition_name, other_meta)`` pair, where ``other_meta`` is a new
    dict holding every entry except ``PARTITION_NAME``. The input mapping is
    not modified.

  Raises:
    ValueError: if ``PARTITION_NAME`` is not a key of ``transform_metadata``.
  """
  if PARTITION_NAME in transform_metadata:
    partition_name = transform_metadata[PARTITION_NAME]
    remaining = {
      key: value
      for key, value in transform_metadata.items()
      if key != PARTITION_NAME
    }
    return partition_name, remaining

  raise ValueError(
    'Trying to transform a Partitioned variable but "partition_name" '
    f'is not specified in transform_metadata: {transform_metadata}'
  )


def get_partition_spec(tree: A) -> A:
  """Extracts a PartitionSpec tree from a PyTree containing ``Variable`` values.

  Leaves are handled as follows:

  * A ``Variable`` / ``VariableState`` with a truthy ``sharding`` is returned
    as ``leaf.replace(PartitionSpec(...))``. When logical axis rules are
    active in the current context, or the leaf has a ``sharding_rules``
    attribute, the sharding names are first mapped through
    ``core_spmd.from_sharding_rules`` using the composite of both rule sets.
  * A ``Variable`` / ``VariableState`` without a truthy ``sharding`` is
    returned as ``leaf.replace(spec)``, where ``spec`` is an empty
    ``PartitionSpec()`` if ``leaf.raw_value`` has a ``shape`` attribute and
    ``None`` otherwise.
  * Any other leaf maps to an empty ``PartitionSpec()`` if it has a ``shape``
    attribute and to ``None`` otherwise.

  Args:
    tree: pytree whose leaves may be ``Variable`` or ``VariableState``
      objects.

  Returns:
    The result of mapping the rules above over ``tree``.
  """
  variable_types = (variablelib.VariableState, variablelib.Variable)

  def _replicated_or_none(value):
    # Anything with a ``shape`` gets an empty spec; everything else gets None.
    return PartitionSpec() if hasattr(value, 'shape') else None

  def _to_spec(node):
    if not isinstance(node, variable_types):
      return _replicated_or_none(node)

    if not (hasattr(node, 'sharding') and node.sharding):
      return node.replace(_replicated_or_none(node.raw_value))

    if core_spmd.get_logical_axis_rules() or hasattr(node, 'sharding_rules'):
      context_rules = core_spmd.get_logical_axis_rules()
      local_rules = getattr(node, 'sharding_rules', ())
      rules = core_spmd.composite_rules(context_rules, local_rules)
      mesh_axes = core_spmd.from_sharding_rules(node.sharding, rules)
      return node.replace(PartitionSpec(*mesh_axes))

    return node.replace(PartitionSpec(*node.sharding))

  return jax.tree.map(
    _to_spec,
    tree,
    is_leaf=lambda node: isinstance(node, variable_types),
  )


def get_named_sharding(tree: A, mesh: jax.sharding.Mesh) -> A:
  """Builds a tree of ``NamedSharding`` objects for ``tree`` on ``mesh``.

  The partition specs are obtained with ``get_partition_spec(tree)`` and each
  leaf of that result is wrapped as ``jax.sharding.NamedSharding(mesh, spec)``.

  Args:
    tree: pytree passed to ``get_partition_spec``.
    mesh: mesh that every ``NamedSharding`` is bound to.

  Returns:
    The partition-spec tree with each leaf replaced by a ``NamedSharding``.
  """

  def _to_named(partition_spec):
    return jax.sharding.NamedSharding(mesh, partition_spec)

  return jax.tree.map(_to_named, get_partition_spec(tree))


# Dynamic Axis Mapping Rngs
# ------------------------------------------------------------------------------


def _global_mesh_defined() -> bool:
  """Checks if global mesh resource environment is defined."""
  physical_mesh = pxla.thread_resources.env.physical_mesh
  # An empty device-array shape is taken to mean that no global mesh is set.
  return physical_mesh.devices.shape != ()  # pylint: disable=g-explicit-bool-comparison


def _with_sharding_constraint(
  x: Array,
  axis_resources: tp.Optional[jax.sharding.PartitionSpec],
  mesh: tp.Optional[jax.sharding.Mesh] = None,
):
  """Applies a sharding constraint to a single array.

  Args:
    x: the array to constrain.
    axis_resources: partition spec for ``x``; may be None.
    mesh: optional explicit mesh.

  Returns:
    ``x`` unchanged when no global mesh is defined and ``mesh`` is None.
    Otherwise the result of ``jax.lax.with_sharding_constraint``, called with
    ``NamedSharding(mesh, axis_resources)`` when both ``mesh`` and
    ``axis_resources`` are given, and with ``axis_resources`` as is in every
    other case.
  """
  nothing_to_constrain = not _global_mesh_defined() and mesh is None
  if nothing_to_constrain:
    return x

  constraint = axis_resources
  if mesh is not None and axis_resources is not None:
    # Bind the spec to the explicit mesh instead of relying on the global one.
    constraint = jax.sharding.NamedSharding(mesh, axis_resources)
  return jax.lax.with_sharding_constraint(x, constraint)


def _is_spec(x):
  """Returns True for None or a tuple whose entries are all ``str`` or None."""
  if x is None:
    return True
  return isinstance(x, tuple) and all(
    entry is None or isinstance(entry, str) for entry in x
  )


def with_sharding_constraint(
  x: ArrayPytree,
  axis_resources: PartitionSpecPytree,
  mesh: tp.Optional[jax.sharding.Mesh] = None,
):
  """Applies sharding constraints across a pytree of arrays.

  Args:
    x: pytree of arrays to constrain.
    axis_resources: pytree of partition specs matched against ``x``; when
      None the call is a no-op.
    mesh: optional mesh forwarded to every per-leaf constraint.

  Returns:
    ``x`` itself when ``axis_resources`` is None, otherwise the result of
    mapping ``_with_sharding_constraint`` over ``x`` and ``axis_resources``.
  """
  # Without any axis binding there is nothing to constrain.
  if axis_resources is None:
    return x

  constrain_leaf = functools.partial(_with_sharding_constraint, mesh=mesh)
  # ``_is_spec`` is the leaf predicate, so None and tuples of names are
  # handled as whole specs rather than traversed.
  return jax.tree.map(constrain_leaf, x, axis_resources, is_leaf=_is_spec)


def with_partitioning(
  initializer: F,
  sharding: Sharding,
  mesh: tp.Optional[jax.sharding.Mesh] = None,
  **metadata: tp.Any,
) -> F:
  """Attaches sharding metadata to an initializer.

  This forwards to ``variablelib.with_metadata``, passing ``sharding``,
  ``mesh`` and every extra keyword argument as metadata.

  Args:
    initializer: the initializer function to wrap.
    sharding: sharding annotation stored under the ``sharding`` key.
    mesh: optional mesh stored under the ``mesh`` key.
    **metadata: additional metadata entries forwarded unchanged.

  Returns:
    Whatever ``variablelib.with_metadata`` returns for ``initializer``.
  """
  # ``sharding`` and ``mesh`` come first, matching the original keyword order.
  all_metadata = dict(sharding=sharding, mesh=mesh, **metadata)
  return variablelib.with_metadata(initializer, **all_metadata)