helpers

class flax.nnx.Dict(self, *args, **kwargs)

class flax.nnx.Sequential(self, *fns)

class flax.nnx.TrainState(graphdef: 'graph.GraphDef[M]', params: 'State', opt_state: 'optax.OptState', step: 'jax.Array', tx: 'optax.GradientTransformation')

    replace(**updates)

        Returns a new object replacing the specified fields with new values.