From 386bf306fad762a178baeacabc1a860d68548d8d Mon Sep 17 00:00:00 2001 From: Aditya Singh Date: Thu, 21 May 2026 17:05:22 -0700 Subject: [PATCH] Document LSTMCell carry as the (c, h) tuple it actually is The LSTMCell.__call__ docstring described 'carry' as 'the hidden state of the LSTM cell', leaving users to infer that it is actually a tuple of cell state and hidden state both of shape (*batch, features), typically created via LSTMCell.initialize_carry. Spell that contract out in both the Linen and NNX twins, and mirror the same docstring for OptimizedLSTMCell. Fixes #4124 --- flax/linen/recurrent.py | 8 +++++--- flax/nnx/nn/recurrent.py | 8 +++++--- 2 files changed, 10 insertions(+), 6 deletions(-) diff --git a/flax/linen/recurrent.py b/flax/linen/recurrent.py index b55b39ba7..086417e2c 100644 --- a/flax/linen/recurrent.py +++ b/flax/linen/recurrent.py @@ -136,7 +136,8 @@ def __call__(self, carry, inputs): r"""A long short-term memory (LSTM) cell. Args: - carry: the hidden state of the LSTM cell, + carry: a tuple ``(c, h)`` of the cell state ``c`` and the hidden + state ``h``, both of shape ``(*batch, features)``. Typically initialized using ``LSTMCell.initialize_carry``. inputs: an ndarray with the input for the current time step. All dimensions except the final are considered batch dimensions. @@ -285,8 +286,9 @@ def __call__( r"""An optimized long short-term memory (LSTM) cell. Args: - carry: the hidden state of the LSTM cell, initialized using - ``LSTMCell.initialize_carry``. + carry: a tuple ``(c, h)`` of the cell state ``c`` and the hidden + state ``h``, both of shape ``(*batch, features)``. Typically + initialized using ``OptimizedLSTMCell.initialize_carry``. inputs: an ndarray with the input for the current time step. All dimensions except the final are considered batch dimensions. diff --git a/flax/nnx/nn/recurrent.py b/flax/nnx/nn/recurrent.py index 9df3f065e..1bb1cf837 100644 --- a/flax/nnx/nn/recurrent.py +++ b/flax/nnx/nn/recurrent.py @@ -201,7 +201,8 @@ def __call__( r"""A long short-term memory (LSTM) cell. Args: - carry: the hidden state of the LSTM cell, + carry: a tuple ``(c, h)`` of the cell state ``c`` and the hidden + state ``h``, both of shape ``(*batch, features)``. Typically initialized using ``LSTMCell.initialize_carry``. inputs: an ndarray with the input for the current time step. All dimensions except the final are considered batch dimensions. @@ -382,8 +383,9 @@ def __call__( r"""An optimized long short-term memory (LSTM) cell. Args: - carry: the hidden state of the LSTM cell, initialized using - ``LSTMCell.initialize_carry``. + carry: a tuple ``(c, h)`` of the cell state ``c`` and the hidden + state ``h``, both of shape ``(*batch, features)``. Typically + initialized using ``OptimizedLSTMCell.initialize_carry``. inputs: an ndarray with the input for the current time step. All dimensions except the final are considered batch dimensions.