Index _ | A | B | C | D | E | F | G | H | I | J | K | L | M | N | O | P | R | S | T | U | V | W | X | Z _ __call__() (flax.nnx.BatchNorm method) (flax.nnx.bridge.NNXMeta method) (flax.nnx.bridge.ToLinen method) (flax.nnx.bridge.ToNNX method) (flax.nnx.Conv method) (flax.nnx.ConvTranspose method) (flax.nnx.Einsum method) (flax.nnx.Embed method) (flax.nnx.GroupNorm method) (flax.nnx.LayerNorm method) (flax.nnx.Linear method) (flax.nnx.LinearGeneral method) (flax.nnx.LoRA method) (flax.nnx.LoRALinear method) (flax.nnx.MultiHeadAttention method) (flax.nnx.nn.recurrent.Bidirectional method) (flax.nnx.nn.recurrent.GRUCell method) (flax.nnx.nn.recurrent.LSTMCell method) (flax.nnx.nn.recurrent.OptimizedLSTMCell method) (flax.nnx.nn.recurrent.RNN method) (flax.nnx.nn.recurrent.SimpleCell method) (flax.nnx.RMSNorm method) __init__() (flax.nnx.metrics.Average method) (flax.nnx.metrics.Metric method) (flax.nnx.metrics.MultiMetric method) (flax.nnx.metrics.Welford method) (flax.nnx.optimizer.Optimizer method) (flax.nnx.Rngs method) A Accuracy (class in flax.nnx.metrics) add_axis() (flax.nnx.bridge.NNXMeta method) All (class in flax.nnx) Any (class in flax.nnx) apply_gradients() (flax.training.train_state.TrainState method) attend() (flax.nnx.Embed method) Average (class in flax.nnx.metrics) B BatchNorm (class in flax.nnx) BatchStat (class in flax.nnx) Bidirectional (class in flax.nnx.nn.recurrent) C Cache (class in flax.nnx) cached_partial() (in module flax.nnx) call() (in module flax.nnx) canonicalize_dtype() (in module flax.nnx.nn.dtypes) celu() (in module flax.nnx) clone() (in module flax.nnx) combine_masks() (in module flax.nnx) compute() (flax.nnx.metrics.Average method) (flax.nnx.metrics.Metric method) (flax.nnx.metrics.MultiMetric method) (flax.nnx.metrics.Welford method) cond() (in module flax.nnx) constant() (in module flax.nnx.initializers) Conv (class in flax.nnx) ConvTranspose (class in flax.nnx) copy() (flax.core.frozen_dict.FrozenDict method) (in module flax.core.frozen_dict) create() (flax.training.train_state.TrainState class method) current_update_context() (in module flax.nnx) custom_vjp() (in module flax.nnx) D dataclass() (in module flax.nnx) (in module flax.struct) delta_orthogonal() (in module flax.nnx.initializers) Dict (class in flax.nnx) display() (in module flax.nnx) dot_product_attention() (in module flax.nnx) Dropout (class in flax.nnx) E Einsum (class in flax.nnx) elu() (in module flax.nnx) Embed (class in flax.nnx) eval() (flax.nnx.Module method) eval_shape() (in module flax.nnx) Everything (class in flax.nnx) F field() (in module flax.nnx) Filter flax.nnx module, [1], [2], [3], [4], [5], [6], [7], [8], [9], [10], [11], [12], [13], [14], [15], [16], [17] flax.nnx.bridge module flax.nnx.initializers module flax.nnx.metrics module flax.nnx.nn.dtypes module flax.nnx.nn.recurrent module flax.nnx.optimizer module flax.struct module flip_sequences() (in module flax.nnx.nn.recurrent) Folding in fori_loop() (in module flax.nnx) freeze() (in module flax.core.frozen_dict) FrozenDict (class in flax.core.frozen_dict) G gelu() (in module flax.nnx) get_named_sharding() (in module flax.nnx) get_partition_spec() (flax.nnx.bridge.NNXMeta method) (in module flax.nnx) glorot_normal() (in module flax.nnx.initializers) glorot_uniform() (in module flax.nnx.initializers) glu() (in module flax.nnx) grad() (in module flax.nnx) graph() (in module flax.nnx) GraphDef (class in flax.nnx) graphdef() (in module flax.nnx) GroupNorm (class in flax.nnx) GRUCell (class in flax.nnx.nn.recurrent) H hard_sigmoid() (in module flax.nnx) hard_silu() (in module flax.nnx) hard_swish() (in module flax.nnx) hard_tanh() (in module flax.nnx) he_normal() (in module flax.nnx.initializers) he_uniform() (in module flax.nnx.initializers) I init_cache() (flax.nnx.MultiHeadAttention method) initialize_carry() (flax.nnx.nn.recurrent.GRUCell method) (flax.nnx.nn.recurrent.LSTMCell method) (flax.nnx.nn.recurrent.OptimizedLSTMCell method) (flax.nnx.nn.recurrent.SimpleCell method) Intermediate (class in flax.nnx) iter_children() (flax.nnx.Module method) iter_graph() (in module flax.nnx) iter_modules() (flax.nnx.Module method) J jit() (in module flax.nnx) K kaiming_normal() (in module flax.nnx.initializers) kaiming_uniform() (in module flax.nnx.initializers) L LayerNorm (class in flax.nnx) lazy_init() (flax.nnx.bridge.ToNNX method) leaky_relu() (in module flax.nnx) lecun_normal() (in module flax.nnx.initializers) lecun_uniform() (in module flax.nnx.initializers) Linear (class in flax.nnx) LinearGeneral (class in flax.nnx) log_sigmoid() (in module flax.nnx) log_softmax() (in module flax.nnx) logsumexp() (in module flax.nnx) LoRA (class in flax.nnx) LoRALinear (class in flax.nnx) LSTMCell (class in flax.nnx.nn.recurrent) M make_attention_mask() (in module flax.nnx) make_causal_mask() (in module flax.nnx) Merge merge() (in module flax.nnx) Metric (class in flax.nnx.metrics) Module module (class in flax.nnx) flax.nnx, [1], [2], [3], [4], [5], [6], [7], [8], [9], [10], [11], [12], [13], [14], [15], [16], [17] flax.nnx.bridge flax.nnx.initializers flax.nnx.metrics flax.nnx.nn.dtypes flax.nnx.nn.recurrent flax.nnx.optimizer flax.struct MultiHeadAttention (class in flax.nnx) MultiMetric (class in flax.nnx.metrics) N NNXMeta (class in flax.nnx.bridge) normal() (in module flax.nnx.initializers) Not (class in flax.nnx) Nothing (class in flax.nnx) O OfType (class in flax.nnx) one_hot() (in module flax.nnx) ones() (in module flax.nnx.initializers) ones_init() (in module flax.nnx.initializers) OptimizedLSTMCell (class in flax.nnx.nn.recurrent) Optimizer (class in flax.nnx.optimizer) orthogonal() (in module flax.nnx.initializers) P Param (class in flax.nnx) Params / parameters PathContains (class in flax.nnx) perturb() (flax.nnx.Module method) pop() (flax.core.frozen_dict.FrozenDict method) (in module flax.core.frozen_dict) (in module flax.nnx) pretty_repr() (flax.core.frozen_dict.FrozenDict method) (in module flax.core.frozen_dict) PRNG states promote_dtype() (in module flax.nnx.nn.dtypes) PyTreeNode (class in flax.struct) R relu() (in module flax.nnx) remat() (in module flax.nnx) remove_axis() (flax.nnx.bridge.NNXMeta method) replace() (flax.nnx.bridge.NNXMeta method) (flax.nnx.TrainState method) replace_boxed() (flax.nnx.bridge.NNXMeta method) reseed() (in module flax.nnx) reset() (flax.nnx.metrics.Average method) (flax.nnx.metrics.Metric method) (flax.nnx.metrics.MultiMetric method) (flax.nnx.metrics.Welford method) RMSNorm (class in flax.nnx) Rngs (class in flax.nnx) RngStream (class in flax.nnx) RNN (class in flax.nnx.nn.recurrent) S scan() (in module flax.nnx) selu() (in module flax.nnx) Sequential (class in flax.nnx) set_attributes() (flax.nnx.Module method) shard_map() (in module flax.nnx) sigmoid() (in module flax.nnx) silu() (in module flax.nnx) SimpleCell (class in flax.nnx.nn.recurrent) soft_sign() (in module flax.nnx) softmax() (in module flax.nnx) softplus() (in module flax.nnx) sow() (flax.nnx.Module method) Split and merge split() (in module flax.nnx) standardize() (in module flax.nnx) State (class in flax.nnx) state() (in module flax.nnx) swish() (in module flax.nnx) switch() (in module flax.nnx) T tabulate() (in module flax.nnx) tanh() (in module flax.nnx) to_linen() (in module flax.nnx.bridge) to_nnx_variable() (flax.nnx.bridge.NNXMeta method) to_predicate() (in module flax.nnx.filterlib) ToLinen (class in flax.nnx.bridge) ToNNX (class in flax.nnx.bridge) train() (flax.nnx.Module method) TrainState (class in flax.nnx) (class in flax.training.train_state) Transformation truncated_normal() (in module flax.nnx.initializers) U unbox() (flax.nnx.bridge.NNXMeta method) unfreeze() (flax.core.frozen_dict.FrozenDict method) (in module flax.core.frozen_dict) uniform() (in module flax.nnx.initializers) update() (flax.nnx.metrics.Accuracy method) (flax.nnx.metrics.Average method) (flax.nnx.metrics.Metric method) (flax.nnx.metrics.MultiMetric method) (flax.nnx.metrics.Welford method) (flax.nnx.optimizer.Optimizer method) (in module flax.nnx) update_context() (in module flax.nnx) UpdateContext (class in flax.nnx) V value_and_grad() (in module flax.nnx) Variable (class in flax.nnx) Variable state variable_name_from_type() (in module flax.nnx) variable_type_from_name() (in module flax.nnx) VariableMetadata (class in flax.nnx) variables() (in module flax.nnx) VariableState (class in flax.nnx) variance_scaling() (in module flax.nnx.initializers) vmap() (in module flax.nnx) W Welford (class in flax.nnx.metrics) while_loop() (in module flax.nnx) with_metadata() (in module flax.nnx) with_partitioning() (in module flax.nnx) with_sharding_constraint() (in module flax.nnx) WithTag (class in flax.nnx) X xavier_normal() (in module flax.nnx.initializers) xavier_uniform() (in module flax.nnx.initializers) Z zeros() (in module flax.nnx.initializers) zeros_init() (in module flax.nnx.initializers)