From 8e5600022ad374f98551c449a08abef899eb4cee Mon Sep 17 00:00:00 2001
From: Justin Deschenaux <33008801+jdeschena@users.noreply.github.com>
Date: Tue, 12 Mar 2024 05:14:44 +0100
Subject: [PATCH] Implement RNN, GRU, LSTM (#268)

* RNN base implementation

* Address comments+format

* nits in docs

* add tests for prb

* fix test

* add a couple tests

---------

Co-authored-by: Awni Hannun
---
 .circleci/config.yml              |   3 +
 docs/src/python/nn/layers.rst     |   3 +
 python/mlx/nn/layers/__init__.py  |   1 +
 python/mlx/nn/layers/recurrent.py | 287 ++++++++++++++++++++++++++++++
 python/tests/test_nn.py           |  66 +++++++
 python/tests/test_ops.py          |   2 +-
 6 files changed, 361 insertions(+), 1 deletion(-)
 create mode 100644 python/mlx/nn/layers/recurrent.py

diff --git a/.circleci/config.yml b/.circleci/config.yml
index 26305ea2d..cd466fb72 100644
--- a/.circleci/config.yml
+++ b/.circleci/config.yml
@@ -275,6 +275,9 @@ workflows:
           context: pr-approval
       - mac_build_and_test:
           requires: [ hold ]
+          matrix:
+            parameters:
+              xcode_version: ["15.0.0", "15.2.0"]
       - linux_build_and_test:
           requires: [ hold ]
   nightly_build:
diff --git a/docs/src/python/nn/layers.rst b/docs/src/python/nn/layers.rst
index f6755e8fe..c0b59b6d4 100644
--- a/docs/src/python/nn/layers.rst
+++ b/docs/src/python/nn/layers.rst
@@ -21,9 +21,11 @@ Layers
    Embedding
    GELU
    GroupNorm
+   GRU
    InstanceNorm
    LayerNorm
    Linear
+   LSTM
    MaxPool1d
    MaxPool2d
    Mish
@@ -32,6 +34,7 @@ Layers
    QuantizedLinear
    RMSNorm
    ReLU
+   RNN
    RoPE
    SELU
    Sequential
diff --git a/python/mlx/nn/layers/__init__.py b/python/mlx/nn/layers/__init__.py
index 6d286220a..3b0856b30 100644
--- a/python/mlx/nn/layers/__init__.py
+++ b/python/mlx/nn/layers/__init__.py
@@ -62,6 +62,7 @@ from mlx.nn.layers.normalization import (
 from mlx.nn.layers.pooling import AvgPool1d, AvgPool2d, MaxPool1d, MaxPool2d
 from mlx.nn.layers.positional_encoding import ALiBi, RoPE, SinusoidalPositionalEncoding
 from mlx.nn.layers.quantized import QuantizedLinear
+from mlx.nn.layers.recurrent import GRU, LSTM, RNN
 from mlx.nn.layers.transformer import (
     MultiHeadAttention,
     Transformer,
diff --git a/python/mlx/nn/layers/recurrent.py b/python/mlx/nn/layers/recurrent.py
new file mode 100644
index 000000000..6f8a590fa
--- /dev/null
+++ b/python/mlx/nn/layers/recurrent.py
@@ -0,0 +1,287 @@
+# Copyright © 2024 Apple Inc.
+
+import math
+from typing import Callable, Optional
+
+import mlx.core as mx
+from mlx.nn.layers.activations import tanh
+from mlx.nn.layers.base import Module
+
+
+class RNN(Module):
+    r"""An Elman recurrent layer.
+
+    The input is a sequence of shape ``NLD`` or ``LD`` where:
+
+    * ``N`` is the optional batch dimension
+    * ``L`` is the sequence length
+    * ``D`` is the input's feature dimension
+
+    Concretely, for each element along the sequence length axis, this
+    layer applies the function:
+
+    .. math::
+
+        h_{t + 1} = \text{tanh} (W_{ih}x_t + W_{hh}h_t + b)
+
+    The hidden state :math:`h` has shape ``NH`` or ``H``, depending on
+    whether the input is batched or not. Returns the hidden state at each
+    time step, of shape ``NLH`` or ``LH``.
+
+    Args:
+        input_size (int): Dimension of the input, ``D``.
+        hidden_size (int): Dimension of the hidden state, ``H``.
+        bias (bool, optional): Whether to use a bias. Default: ``True``.
+        nonlinearity (callable, optional): Non-linearity to use. If ``None``,
+            then :func:`tanh` is used. Default: ``None``.
+    """
+
+    def __init__(
+        self,
+        input_size: int,
+        hidden_size: int,
+        bias: bool = True,
+        nonlinearity: Optional[Callable] = None,
+    ):
+        super().__init__()
+
+        self.nonlinearity = nonlinearity or tanh
+        if not callable(self.nonlinearity):
+            raise ValueError(
+                f"Nonlinearity must be callable. Current value: {nonlinearity}."
+            )
+
+        scale = 1.0 / math.sqrt(hidden_size)
+        self.hidden_size = hidden_size
+        self.Wxh = mx.random.uniform(
+            low=-scale, high=scale, shape=(input_size, hidden_size)
+        )
+        self.Whh = mx.random.uniform(
+            low=-scale, high=scale, shape=(hidden_size, hidden_size)
+        )
+        self.bias = (
+            mx.random.uniform(low=-scale, high=scale, shape=(hidden_size,))
+            if bias
+            else None
+        )
+
+    def _extra_repr(self):
+        return (
+            f"input_dims={self.Wxh.shape[0]}, "
+            f"hidden_size={self.hidden_size}, "
+            f"nonlinearity={self.nonlinearity}, bias={self.bias is not None}"
+        )
+
+    def __call__(self, x, hidden=None):
+        if self.bias is not None:
+            x = mx.addmm(self.bias, x, self.Wxh)
+        else:
+            x = x @ self.Wxh
+
+        all_hidden = []
+        for idx in range(x.shape[-2]):
+            if hidden is not None:
+                hidden = x[..., idx, :] + hidden @ self.Whh
+            else:
+                hidden = x[..., idx, :]
+            hidden = self.nonlinearity(hidden)
+            all_hidden.append(hidden)
+
+        return mx.stack(all_hidden, axis=-2)
+
+
+class GRU(Module):
+    r"""A gated recurrent unit (GRU) RNN layer.
+
+    The input has shape ``NLD`` or ``LD`` where:
+
+    * ``N`` is the optional batch dimension
+    * ``L`` is the sequence length
+    * ``D`` is the input's feature dimension
+
+    Concretely, for each element of the sequence, this layer computes:
+
+    .. math::
+
+        \begin{align*}
+        r_t &= \sigma (W_{xr}x_t + W_{hr}h_t + b_{r}) \\
+        z_t &= \sigma (W_{xz}x_t + W_{hz}h_t + b_{z}) \\
+        n_t &= \text{tanh}(W_{xn}x_t + b_{n} + r_t \odot (W_{hn}h_t + b_{hn})) \\
+        h_{t + 1} &= (1 - z_t) \odot n_t + z_t \odot h_t
+        \end{align*}
+
+    The hidden state :math:`h` has shape ``NH`` or ``H`` depending on
+    whether the input is batched or not. Returns the hidden state at each
+    time step of shape ``NLH`` or ``LH``.
+
+    Args:
+        input_size (int): Dimension of the input, ``D``.
+        hidden_size (int): Dimension of the hidden state, ``H``.
+        bias (bool): Whether to use biases or not. Default: ``True``.
+    """
+
+    def __init__(
+        self,
+        input_size: int,
+        hidden_size: int,
+        bias: bool = True,
+    ):
+        super().__init__()
+
+        self.hidden_size = hidden_size
+        scale = 1.0 / math.sqrt(hidden_size)
+        self.Wx = mx.random.uniform(
+            low=-scale, high=scale, shape=(input_size, 3 * hidden_size)
+        )
+        self.Wh = mx.random.uniform(
+            low=-scale, high=scale, shape=(hidden_size, 3 * hidden_size)
+        )
+        self.b = (
+            mx.random.uniform(low=-scale, high=scale, shape=(3 * hidden_size,))
+            if bias
+            else None
+        )
+        self.bhn = (
+            mx.random.uniform(low=-scale, high=scale, shape=(hidden_size,))
+            if bias
+            else None
+        )
+
+    def _extra_repr(self):
+        return (
+            f"input_dims={self.Wx.shape[0]}, "
+            f"hidden_size={self.hidden_size}, bias={self.b is not None}"
+        )
+
+    def __call__(self, x, hidden=None):
+        if self.b is not None:
+            x = mx.addmm(self.b, x, self.Wx)
+        else:
+            x = x @ self.Wx
+
+        x_rz = x[..., : -self.hidden_size]
+        x_n = x[..., -self.hidden_size :]
+
+        all_hidden = []
+
+        for idx in range(x.shape[-2]):
+            rz = x_rz[..., idx, :]
+            if hidden is not None:
+                h_proj = hidden @ self.Wh
+                h_proj_rz = h_proj[..., : -self.hidden_size]
+                h_proj_n = h_proj[..., -self.hidden_size :]
+
+                if self.bhn is not None:
+                    h_proj_n += self.bhn
+
+                rz = rz + h_proj_rz
+
+            rz = mx.sigmoid(rz)
+
+            r, z = mx.split(rz, 2, axis=-1)
+
+            n = x_n[..., idx, :]
+
+            if hidden is not None:
+                n = n + r * h_proj_n
+            n = mx.tanh(n)
+
+            if hidden is not None:
+                hidden = (1 - z) * n + z * hidden
+            else:
+                hidden = (1 - z) * n
+            all_hidden.append(hidden)
+
+        return mx.stack(all_hidden, axis=-2)
+
+
+class LSTM(Module):
+    r"""An LSTM recurrent layer.
+
+    The input has shape ``NLD`` or ``LD`` where:
+
+    * ``N`` is the optional batch dimension
+    * ``L`` is the sequence length
+    * ``D`` is the input's feature dimension
+
+    Concretely, for each element of the sequence, this layer computes:
+
+    .. math::
+        \begin{align*}
+        i_t &= \sigma (W_{xi}x_t + W_{hi}h_t + b_{i}) \\
+        f_t &= \sigma (W_{xf}x_t + W_{hf}h_t + b_{f}) \\
+        g_t &= \text{tanh} (W_{xg}x_t + W_{hg}h_t + b_{g}) \\
+        o_t &= \sigma (W_{xo}x_t + W_{ho}h_t + b_{o}) \\
+        c_{t + 1} &= f_t \odot c_t + i_t \odot g_t \\
+        h_{t + 1} &= o_t \odot \text{tanh}(c_{t + 1})
+        \end{align*}
+
+    The hidden state :math:`h` and cell state :math:`c` have shape ``NH``
+    or ``H``, depending on whether the input is batched or not.
+
+    The layer returns two arrays, the hidden state and the cell state at
+    each time step, both of shape ``NLH`` or ``LH``.
+
+    Args:
+        input_size (int): Dimension of the input, ``D``.
+        hidden_size (int): Dimension of the hidden state, ``H``.
+        bias (bool): Whether to use biases or not. Default: ``True``.
+    """
+
+    def __init__(
+        self,
+        input_size: int,
+        hidden_size: int,
+        bias: bool = True,
+    ):
+        super().__init__()
+
+        self.hidden_size = hidden_size
+        scale = 1.0 / math.sqrt(hidden_size)
+        self.Wx = mx.random.uniform(
+            low=-scale, high=scale, shape=(input_size, 4 * hidden_size)
+        )
+        self.Wh = mx.random.uniform(
+            low=-scale, high=scale, shape=(hidden_size, 4 * hidden_size)
+        )
+        self.bias = (
+            mx.random.uniform(low=-scale, high=scale, shape=(4 * hidden_size,))
+            if bias
+            else None
+        )
+
+    def _extra_repr(self):
+        return (
+            f"input_dims={self.Wx.shape[0]}, "
+            f"hidden_size={self.hidden_size}, bias={self.bias is not None}"
+        )
+
+    def __call__(self, x, hidden=None, cell=None):
+        if self.bias is not None:
+            x = mx.addmm(self.bias, x, self.Wx)
+        else:
+            x = x @ self.Wx
+
+        all_hidden = []
+        all_cell = []
+
+        for idx in range(x.shape[-2]):
+            ifgo = x[..., idx, :]
+            if hidden is not None:
+                ifgo = ifgo + hidden @ self.Wh
+            i, f, g, o = mx.split(ifgo, 4, axis=-1)
+
+            i = mx.sigmoid(i)
+            f = mx.sigmoid(f)
+            g = mx.tanh(g)
+            o = mx.sigmoid(o)
+
+            if cell is not None:
+                cell = f * cell + i * g
+            else:
+                cell = i * g
+            hidden = o * mx.tanh(cell)
+
+            all_cell.append(cell)
+            all_hidden.append(hidden)
+
+        return mx.stack(all_hidden, axis=-2), mx.stack(all_cell, axis=-2)
diff --git a/python/tests/test_nn.py b/python/tests/test_nn.py
index 678acfd5b..d704d4004 100644
--- a/python/tests/test_nn.py
+++ b/python/tests/test_nn.py
@@ -1480,6 +1480,72 @@ class TestLayers(mlx_tests.MLXTestCase):
             "AvgPool2d(kernel_size=(1, 2), stride=(2, 2), padding=(1, 2))",
         )
 
+    def test_rnn(self):
+        layer = nn.RNN(input_size=5, hidden_size=12, bias=True)
+        inp = mx.random.normal((2, 25, 5))
+
+        h_out = layer(inp)
+        self.assertEqual(h_out.shape, (2, 25, 12))
+
+        layer = nn.RNN(
+            5,
+            12,
+            bias=False,
+            nonlinearity=lambda x: mx.maximum(0, x),
+        )
+
+        h_out = layer(inp)
+        self.assertEqual(h_out.shape, (2, 25, 12))
+
+        with self.assertRaises(ValueError):
+            nn.RNN(5, 12, nonlinearity="tanh")
+
+        inp = mx.random.normal((44, 5))
+        h_out = layer(inp)
+        self.assertEqual(h_out.shape, (44, 12))
+
+        h_out = layer(inp, hidden=h_out[-1, :])
+        self.assertEqual(h_out.shape, (44, 12))
+
+    def test_gru(self):
+        layer = nn.GRU(5, 12, bias=True)
+        inp = mx.random.normal((2, 25, 5))
+
+        h_out = layer(inp)
+        self.assertEqual(h_out.shape, (2, 25, 12))
+
+        h_out = layer(inp, hidden=h_out[:, -1, :])
+        self.assertEqual(h_out.shape, (2, 25, 12))
+
+        inp = mx.random.normal((44, 5))
+        h_out = layer(inp)
+        self.assertEqual(h_out.shape, (44, 12))
+
+        h_out = layer(inp, h_out[-1, :])
+        self.assertEqual(h_out.shape, (44, 12))
+
+    def test_lstm(self):
+        layer = nn.LSTM(5, 12)
+        inp = mx.random.normal((2, 25, 5))
+
+        h_out, c_out = layer(inp)
+        self.assertEqual(h_out.shape, (2, 25, 12))
+        self.assertEqual(c_out.shape, (2, 25, 12))
+
+        h_out, c_out = layer(inp, hidden=h_out[:, -1, :], cell=c_out[:, -1, :])
+        self.assertEqual(h_out.shape, (2, 25, 12))
+        self.assertEqual(c_out.shape, (2, 25, 12))
+
+        inp = mx.random.normal((44, 5))
+        h_out, c_out = layer(inp)
+        self.assertEqual(h_out.shape, (44, 12))
+        self.assertEqual(c_out.shape, (44, 12))
+
+        inp = mx.random.normal((44, 5))
+        h_out, c_out = layer(inp, hidden=h_out[-1, :], cell=c_out[-1, :])
+        self.assertEqual(h_out.shape, (44, 12))
+        self.assertEqual(c_out.shape, (44, 12))
+
 
 if __name__ == "__main__":
     unittest.main()
diff --git a/python/tests/test_ops.py b/python/tests/test_ops.py
index fe935ebc8..cbb49e8ae 100644
--- a/python/tests/test_ops.py
+++ b/python/tests/test_ops.py
@@ -1416,7 +1416,7 @@ class TestOps(mlx_tests.MLXTestCase):
         # Sliced inputs
         y = mx.random.uniform(shape=(8, 4))
         out = mx.softmax(y[:, 0:2], axis=-1)
-        self.assertAlmostEqual(out.sum().item(), 8.0)
+        self.assertAlmostEqual(out.sum().item(), 8.0, 5)
 
     def test_concatenate(self):
         a_npy = np.random.randn(32, 32, 32)
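
A minimal usage sketch of the new ``RNN`` layer (not part of the patch), mirroring the shapes exercised in ``test_rnn`` above; the sizes and the ReLU-style nonlinearity are illustrative only:

import mlx.core as mx
import mlx.nn as nn

# Batched input of shape (N, L, D) = (2, 25, 5); the output collects the
# hidden state at every time step, so it has shape (N, L, H) = (2, 25, 12).
rnn = nn.RNN(input_size=5, hidden_size=12)
x = mx.random.normal((2, 25, 5))
h = rnn(x)

# Any callable can replace tanh as the nonlinearity.
relu_rnn = nn.RNN(5, 12, bias=False, nonlinearity=lambda v: mx.maximum(0, v))
h = relu_rnn(x)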
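
Likewise, a sketch of carrying the hidden state across calls with ``GRU``, as ``test_gru`` does; the slicing assumes batched input:

import mlx.core as mx
import mlx.nn as nn

gru = nn.GRU(5, 12)
x = mx.random.normal((2, 25, 5))

h = gru(x)                      # hidden state at every step, shape (2, 25, 12)
# Feed the last hidden state back in as the initial state for the next call.
h = gru(x, hidden=h[:, -1, :])  # initial hidden state has shape (2, 12)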
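
And a sketch of ``LSTM`` with an unbatched ``(L, D)`` input, resuming from the final hidden and cell states as in ``test_lstm``:

import mlx.core as mx
import mlx.nn as nn

lstm = nn.LSTM(5, 12)
x = mx.random.normal((44, 5))   # unbatched input of shape (L, D)

h, c = lstm(x)                  # per-step hidden and cell states, both (44, 12)
# Resume from the final hidden and cell states for the next segment.
h, c = lstm(x, hidden=h[-1, :], cell=c[-1, :])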