Source code for flax.nnx.nn.lora

# Copyright 2024 The Flax Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations

import typing as tp

from flax.nnx import rnglib, variablelib
from flax.nnx.module import Module
from flax.nnx.nn import initializers
from flax.nnx.nn.linear import Linear
from flax.nnx.nn.dtypes import promote_dtype
from flax.typing import Dtype, Initializer
import jax
import jax.numpy as jnp

Array = jax.Array
Axis = int
Size = int
A = tp.TypeVar('A')

default_a_initializer = initializers.he_uniform()
default_b_initializer = initializers.zeros


class LoRAParam(variablelib.Param[A]):
  """Variable type for LoRA weights.

  Storing `lora_a`/`lora_b` as `LoRAParam` instead of plain `Param` lets them
  be selected with nnx variable-type filters, e.g. to train or checkpoint only
  the LoRA weights.
  """


class LoRA(Module):
  """A standalone LoRA layer.

  Example usage::

    >>> from flax import nnx
    >>> import jax, jax.numpy as jnp
    >>> layer = nnx.LoRA(3, 2, 4, rngs=nnx.Rngs(0))
    >>> layer.lora_a.value.shape
    (3, 2)
    >>> layer.lora_b.value.shape
    (2, 4)
    >>> # Wrap around existing layer
    >>> linear = nnx.Linear(3, 4, rngs=nnx.Rngs(0))
    >>> wrapper = nnx.LoRA(3, 2, 4, base_module=linear, rngs=nnx.Rngs(1))
    >>> assert wrapper.base_module == linear
    >>> wrapper.lora_a.value.shape
    (3, 2)
    >>> wrapper.lora_b.value.shape
    (2, 4)

    >>> y = layer(jnp.ones((16, 3)))
    >>> y.shape
    (16, 4)

  Args:
    in_features: the number of input features.
    lora_rank: the rank of the LoRA dimension.
    out_features: the number of output features.
    base_module: a base module to call and substitute, if possible.
    dtype: the dtype of the computation (default: infer from input and params).
    param_dtype: the dtype passed to parameter initializers (default: float32).
    a_initializer: initializer function for the fan-in matrices. Defaults to
      `he_uniform`.
    b_initializer: initializer function for the fan-out matrices. Defaults to
      the zero initializer.
    lora_param_type: the type of the LoRA params.
  """

  __data__ = ('lora_a', 'lora_b', 'base_module')

  def __init__(
    self,
    in_features: int,
    lora_rank: int,
    out_features: int,
    *,
    base_module: tp.Optional[Module] = None,
    dtype: tp.Optional[Dtype] = None,
    param_dtype: Dtype = jnp.float32,
    a_initializer: Initializer = default_a_initializer,
    b_initializer: Initializer = default_b_initializer,
    lora_param_type: tp.Type[variablelib.Variable] = LoRAParam,
    rngs: rnglib.Rngs,
  ):
    self.in_features = in_features
    self.out_features = out_features
    self.dtype = dtype
    self.param_dtype = param_dtype
    self.lora_param_type = lora_param_type
    self.base_module = base_module

    self.lora_a = lora_param_type(
      a_initializer(rngs.params(), (in_features, lora_rank), param_dtype)
    )
    self.lora_b = lora_param_type(
      b_initializer(rngs.params(), (lora_rank, out_features), param_dtype)
    )

  def __call__(self, x: jax.Array):
    x, lora_a, lora_b = promote_dtype(
      (x, self.lora_a[...], self.lora_b[...]), dtype=self.dtype
    )
    out = x @ lora_a @ lora_b
    if self.base_module is not None:
      if not callable(self.base_module):
        raise ValueError('`self.base_module` must be callable.')
      out += self.base_module(x)
    return out
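
# Note (illustrative sketch, not part of the original module): because the LoRA
# update computed above is `x @ lora_a @ lora_b`, a `Linear` base module can be
# merged for inference by folding the low-rank product into its kernel, after
# which calling the base layer alone reproduces the wrapper's output. The layer
# sizes below are arbitrary:
#
#   >>> from flax import nnx
#   >>> import jax.numpy as jnp
#   >>> linear = nnx.Linear(3, 4, rngs=nnx.Rngs(0))
#   >>> wrapper = nnx.LoRA(3, 2, 4, base_module=linear, rngs=nnx.Rngs(1))
#   >>> x = jnp.ones((16, 3))
#   >>> y = wrapper(x)
#   >>> # Fold the update into the base kernel: W <- W + A @ B.
#   >>> linear.kernel.value += wrapper.lora_a.value @ wrapper.lora_b.value
#   >>> assert jnp.allclose(linear(x), y, atol=1e-6)
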

class LoRALinear(Linear):
  """An `nnx.Linear` layer whose output is LoRAified.

  The model state structure is compatible with that of `Linear`.

  Example usage::

    >>> from flax import nnx
    >>> import jax, jax.numpy as jnp
    >>> linear = nnx.Linear(3, 4, rngs=nnx.Rngs(0))
    >>> lora_linear = nnx.LoRALinear(3, 4, lora_rank=2, rngs=nnx.Rngs(0))
    >>> linear.kernel.value.shape
    (3, 4)
    >>> lora_linear.kernel.value.shape
    (3, 4)
    >>> lora_linear.lora.lora_a.value.shape
    (3, 2)
    >>> jnp.allclose(linear.kernel.value, lora_linear.kernel.value)
    Array(True, dtype=bool)

    >>> y = lora_linear(jnp.ones((16, 3)))
    >>> y.shape
    (16, 4)

  Args:
    in_features: the number of input features.
    out_features: the number of output features.
    lora_rank: the rank of the LoRA dimension.
    lora_dtype: the dtype of the LoRA computation (default: infer from input
      and params).
    lora_param_dtype: the dtype passed to the LoRA parameter initializers
      (default: float32).
    a_initializer: initializer function for the fan-in matrices. Defaults to
      `he_uniform`.
    b_initializer: initializer function for the fan-out matrices. Defaults to
      the zero initializer.
    lora_param_type: the type of the LoRA params.
    **kwargs: additional keyword arguments (e.g. `dtype`, `param_dtype`,
      `precision`; see `jax.lax.Precision` for details) forwarded to the
      underlying `nnx.Linear`.
  """

  __data__ = ('lora',)  # type: ignore[assignment]

  def __init__(
    self,
    in_features: int,
    out_features: int,
    *,
    lora_rank: int,
    lora_dtype: tp.Optional[Dtype] = None,
    lora_param_dtype: Dtype = jnp.float32,
    a_initializer: Initializer = default_a_initializer,
    b_initializer: Initializer = default_b_initializer,
    lora_param_type: tp.Type[variablelib.Variable] = LoRAParam,
    rngs: rnglib.Rngs,
    **kwargs,
  ):
    super().__init__(in_features, out_features, rngs=rngs, **kwargs)
    self.lora = LoRA(
      in_features,
      lora_rank,
      out_features,
      dtype=lora_dtype,
      param_dtype=lora_param_dtype,
      a_initializer=a_initializer,
      b_initializer=b_initializer,
      lora_param_type=lora_param_type,
      rngs=rngs,
    )

  def __call__(self, x: jax.Array):
    y = super().__call__(x)
    y += self.lora(x)
    return y
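
# Note (illustrative sketch, not part of the original module): because LoRA
# weights are stored as `LoRAParam` rather than plain `Param`, they can be
# selected with nnx variable-type filters, e.g. to update or checkpoint only
# the LoRA weights while the pretrained kernel and bias stay frozen. The layer
# sizes below are arbitrary:
#
#   >>> from flax import nnx
#   >>> model = nnx.LoRALinear(3, 4, lora_rank=2, rngs=nnx.Rngs(0))
#   >>> lora_params = nnx.state(model, nnx.LoRAParam)  # only lora_a / lora_b
#   >>> # `lora_params` can now be passed to an optimizer or saved on its own.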