summary

flax.nnx.tabulate(obj, *input_args, depth=None, method='__call__', row_filter=<function filter_rng_streams>, table_kwargs=mappingproxy({}), column_kwargs=mappingproxy({}), console_kwargs=mappingproxy({}), **input_kwargs)

Creates a table that summarizes a graph object.

The table summarizes the object’s state and metadata, and is structured as follows:

  • The first column represents the path of the object in the graph.

  • The second column represents the type of the object.

  • The third column represents the input arguments passed to the object’s method.

  • The fourth column represents the output of the object’s method.

  • The remaining columns provide information about the object’s state, with one column per Variable type (for example, BatchStat, Param and RngState in the example below).

Example:

>>> from flax import nnx
...
>>> class Block(nnx.Module):
...   def __init__(self, din, dout, rngs: nnx.Rngs):
...     self.linear = nnx.Linear(din, dout, rngs=rngs)
...     self.bn = nnx.BatchNorm(dout, rngs=rngs)
...     self.dropout = nnx.Dropout(0.2, rngs=rngs)
...
...   def __call__(self, x):
...     return nnx.relu(self.dropout(self.bn(self.linear(x))))
...
>>> class Foo(nnx.Module):
...   def __init__(self, rngs: nnx.Rngs):
...     self.block1 = Block(32, 128, rngs=rngs)
...     self.block2 = Block(128, 10, rngs=rngs)
...
...   def __call__(self, x):
...     return self.block2(self.block1(x))
...
>>> foo = Foo(nnx.Rngs(0))
>>> import jax.numpy as jnp
>>> print(nnx.tabulate(foo, jnp.ones((1, 32))))

                                                      Foo Summary
┏━━━━━━━━━━━━━━━━┳━━━━━━━━━━━┳━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━┓
┃ path           ┃ type      ┃ inputs         ┃ outputs        ┃ BatchStat          ┃ Param                   ┃ RngState ┃
┡━━━━━━━━━━━━━━━━╇━━━━━━━━━━━╇━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━┩
│                │ Foo       │ float32[1,32]  │ float32[1,10]  │ 276 (1.1 KB)       │ 5,790 (23.2 KB)         │ 2 (12 B) │
├────────────────┼───────────┼────────────────┼────────────────┼────────────────────┼─────────────────────────┼──────────┤
│ block1         │ Block     │ float32[1,32]  │ float32[1,128] │ 256 (1.0 KB)       │ 4,480 (17.9 KB)         │ 2 (12 B) │
├────────────────┼───────────┼────────────────┼────────────────┼────────────────────┼─────────────────────────┼──────────┤
│ block1/linear  │ Linear    │ float32[1,32]  │ float32[1,128] │                    │ bias: float32[128]      │          │
│                │           │                │                │                    │ kernel: float32[32,128] │          │
│                │           │                │                │                    │                         │          │
│                │           │                │                │                    │ 4,224 (16.9 KB)         │          │
├────────────────┼───────────┼────────────────┼────────────────┼────────────────────┼─────────────────────────┼──────────┤
│ block1/bn      │ BatchNorm │ float32[1,128] │ float32[1,128] │ mean: float32[128] │ bias: float32[128]      │          │
│                │           │                │                │ var: float32[128]  │ scale: float32[128]     │          │
│                │           │                │                │                    │                         │          │
│                │           │                │                │ 256 (1.0 KB)       │ 256 (1.0 KB)            │          │
├────────────────┼───────────┼────────────────┼────────────────┼────────────────────┼─────────────────────────┼──────────┤
│ block1/dropout │ Dropout   │ float32[1,128] │ float32[1,128] │                    │                         │ 2 (12 B) │
├────────────────┼───────────┼────────────────┼────────────────┼────────────────────┼─────────────────────────┼──────────┤
│ block2         │ Block     │ float32[1,128] │ float32[1,10]  │ 20 (80 B)          │ 1,310 (5.2 KB)          │          │
├────────────────┼───────────┼────────────────┼────────────────┼────────────────────┼─────────────────────────┼──────────┤
│ block2/linear  │ Linear    │ float32[1,128] │ float32[1,10]  │                    │ bias: float32[10]       │          │
│                │           │                │                │                    │ kernel: float32[128,10] │          │
│                │           │                │                │                    │                         │          │
│                │           │                │                │                    │ 1,290 (5.2 KB)          │          │
├────────────────┼───────────┼────────────────┼────────────────┼────────────────────┼─────────────────────────┼──────────┤
│ block2/bn      │ BatchNorm │ float32[1,10]  │ float32[1,10]  │ mean: float32[10]  │ bias: float32[10]       │          │
│                │           │                │                │ var: float32[10]   │ scale: float32[10]      │          │
│                │           │                │                │                    │                         │          │
│                │           │                │                │ 20 (80 B)          │ 20 (80 B)               │          │
├────────────────┼───────────┼────────────────┼────────────────┼────────────────────┼─────────────────────────┼──────────┤
│ block2/dropout │ Dropout   │ float32[1,10]  │ float32[1,10]  │                    │                         │          │
├────────────────┼───────────┼────────────────┼────────────────┼────────────────────┼─────────────────────────┼──────────┤
│                │           │                │          Total │ 276 (1.1 KB)       │ 5,790 (23.2 KB)         │ 2 (12 B) │
└────────────────┴───────────┴────────────────┴────────────────┴────────────────────┴─────────────────────────┴──────────┘

                                            Total Parameters: 6,068 (24.3 KB)

Note that the RngState column is empty for block2/dropout because it shares its RngState with block1/dropout; the shared state is counted only once.
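The counts in the table can be checked by hand. The sketch below is a plain-Python illustration, not part of the Flax API. It assumes that Linear holds a kernel of shape (din, dout) and a bias of shape (dout,), that BatchNorm holds scale and bias Params plus mean and var BatchStats of shape (dout,), that every Param and BatchStat value is float32 (4 bytes), and that the table's KB means 1,000 bytes, which matches the sizes shown. The RngState figures are read directly off the table. All helper names are hypothetical.

```python
# Hypothetical helpers that re-derive the counts in the Foo Summary table.
BYTES_PER_FLOAT32 = 4


def linear_params(din: int, dout: int) -> int:
    # kernel: float32[din, dout] plus bias: float32[dout]
    return din * dout + dout


def batchnorm_counts(dout: int) -> tuple[int, int]:
    # (Param, BatchStat): scale + bias, and mean + var, each of shape (dout,)
    return 2 * dout, 2 * dout


def block_counts(din: int, dout: int) -> tuple[int, int]:
    # Dropout holds no Param or BatchStat, so only Linear and BatchNorm count.
    bn_params, bn_stats = batchnorm_counts(dout)
    return linear_params(din, dout) + bn_params, bn_stats


def fmt_size(n_values: int) -> str:
    # Assumes 1 KB = 1,000 bytes, which matches the sizes in the table.
    n_bytes = n_values * BYTES_PER_FLOAT32
    if n_bytes < 1000:
        return f"{n_values:,} ({n_bytes} B)"
    return f"{n_values:,} ({n_bytes / 1000:.1f} KB)"


p1, s1 = block_counts(32, 128)   # block1: Param, BatchStat
p2, s2 = block_counts(128, 10)   # block2: Param, BatchStat
print("block1:", fmt_size(p1), fmt_size(s1))
print("block2:", fmt_size(p2), fmt_size(s2))
print("Foo:   ", fmt_size(p1 + p2), fmt_size(s1 + s2))

# RngState: 2 values taking 12 B, read off the table. Both Dropout layers
# share it, so it is added only once.
rng_values, rng_bytes = 2, 12
total_values = p1 + p2 + s1 + s2 + rng_values
total_bytes = (p1 + p2 + s1 + s2) * BYTES_PER_FLOAT32 + rng_bytes
print(f"Total Parameters: {total_values:,} ({total_bytes / 1000:.1f} KB)")
```

Running it reproduces the block and total rows of the table, including the final line `Total Parameters: 6,068 (24.3 KB)`.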

Parameters:
  • obj – The object to summarize. It can be a pytree or a graph object such as nnx.Module or nnx.Optimizer.

  • *input_args – Positional arguments passed to the object’s method.

  • **input_kwargs – Keyword arguments passed to the object’s method.

  • depth – The maximum nesting depth of the rows shown in the table. If None (the default), all levels are shown.

  • method – The name of the method to call on the object. Default is '__call__'.

  • row_filter – A callable that filters the rows to be displayed in the table. By default, it filters out rows with type nnx.RngStream.

  • table_kwargs – An optional dictionary with additional keyword arguments that are passed to the rich.table.Table constructor.

  • column_kwargs – An optional dictionary with additional keyword arguments that are passed to rich.table.Table.add_column when adding columns to the table.

  • console_kwargs – An optional dictionary with additional keyword arguments that are passed to rich.console.Console when rendering the table. By default, 'force_terminal' is True, and 'force_jupyter' is True when the code is running in a Jupyter notebook and False otherwise.
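The way depth and row_filter narrow down the rows can be pictured with a plain-Python model of the example table. This is a hypothetical sketch, not Flax's implementation. It assumes that each row is a (path components, type name) pair, that depth limits how many path components a row may have (None keeps every row), and that the filter is applied after the depth check. The names Row, ROWS, select_rows and drop_dropout are made up for illustration.

```python
from typing import Callable, Optional

# Hypothetical row model: (path components, type name), taken from the Foo example.
Row = tuple[tuple[str, ...], str]

ROWS: list[Row] = [
    ((), "Foo"),
    (("block1",), "Block"),
    (("block1", "linear"), "Linear"),
    (("block1", "bn"), "BatchNorm"),
    (("block1", "dropout"), "Dropout"),
    (("block2",), "Block"),
    (("block2", "linear"), "Linear"),
    (("block2", "bn"), "BatchNorm"),
    (("block2", "dropout"), "Dropout"),
]


def select_rows(
    rows: list[Row],
    depth: Optional[int] = None,
    row_filter: Callable[[Row], bool] = lambda row: True,
) -> list[str]:
    """Return the '/'-joined paths of rows that pass the depth limit and the filter."""
    kept = []
    for row in rows:
        path, _ = row
        if depth is not None and len(path) > depth:
            continue  # deeper than requested: hide the row
        if row_filter(row):
            kept.append("/".join(path))
    return kept


def drop_dropout(row: Row) -> bool:
    # Analogous in spirit to the default filter, which hides rows by type
    # (nnx.RngStream); here the hidden type is Dropout.
    return row[1] != "Dropout"


print(select_rows(ROWS, depth=1))
print(select_rows(ROWS, row_filter=drop_dropout))
```

With depth=1 only the root and the two Block rows remain; with the filter, the two Dropout rows disappear while every other row is kept.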

Returns:

A string summarizing the object.