mirror of
https://github.com/ml-explore/mlx.git
synced 2025-06-24 09:21:16 +08:00
Implement RNN, GRU, LSTM (#268)
* RNN base implementation * Address comments+format * nits in docs * add tests for prb * fix test * add a couple tests --------- Co-authored-by: Awni Hannun <awni@apple.com>
This commit is contained in:
parent
0e95b64942
commit
8e5600022a
@ -275,6 +275,9 @@ workflows:
|
||||
context: pr-approval
|
||||
- mac_build_and_test:
|
||||
requires: [ hold ]
|
||||
matrix:
|
||||
parameters:
|
||||
xcode_version: ["15.0.0", "15.2.0"]
|
||||
- linux_build_and_test:
|
||||
requires: [ hold ]
|
||||
nightly_build:
|
||||
|
@ -21,9 +21,11 @@ Layers
|
||||
Embedding
|
||||
GELU
|
||||
GroupNorm
|
||||
GRU
|
||||
InstanceNorm
|
||||
LayerNorm
|
||||
Linear
|
||||
LSTM
|
||||
MaxPool1d
|
||||
MaxPool2d
|
||||
Mish
|
||||
@ -32,6 +34,7 @@ Layers
|
||||
QuantizedLinear
|
||||
RMSNorm
|
||||
ReLU
|
||||
RNN
|
||||
RoPE
|
||||
SELU
|
||||
Sequential
|
||||
|
@ -62,6 +62,7 @@ from mlx.nn.layers.normalization import (
|
||||
from mlx.nn.layers.pooling import AvgPool1d, AvgPool2d, MaxPool1d, MaxPool2d
|
||||
from mlx.nn.layers.positional_encoding import ALiBi, RoPE, SinusoidalPositionalEncoding
|
||||
from mlx.nn.layers.quantized import QuantizedLinear
|
||||
from mlx.nn.layers.recurrent import GRU, LSTM, RNN
|
||||
from mlx.nn.layers.transformer import (
|
||||
MultiHeadAttention,
|
||||
Transformer,
|
||||
|
287
python/mlx/nn/layers/recurrent.py
Normal file
287
python/mlx/nn/layers/recurrent.py
Normal file
@ -0,0 +1,287 @@
|
||||
# Copyright © 2024 Apple Inc.
|
||||
|
||||
import math
|
||||
from typing import Callable, Optional
|
||||
|
||||
import mlx.core as mx
|
||||
from mlx.nn.layers.activations import tanh
|
||||
from mlx.nn.layers.base import Module
|
||||
|
||||
|
||||
class RNN(Module):
|
||||
r"""An Elman recurrent layer.
|
||||
|
||||
The input is a sequence of shape ``NLD`` or ``LD`` where:
|
||||
|
||||
* ``N`` is the optional batch dimension
|
||||
* ``L`` is the sequence length
|
||||
* ``D`` is the input's feature dimension
|
||||
|
||||
Concretely, for each element along the sequence length axis, this
|
||||
layer applies the function:
|
||||
|
||||
.. math::
|
||||
|
||||
h_{t + 1} = \text{tanh} (W_{ih}x_t + W_{hh}h_t + b)
|
||||
|
||||
The hidden state :math:`h` has shape ``NH`` or ``H``, depending on
|
||||
whether the input is batched or not. Returns the hidden state at each
|
||||
time step, of shape ``NLH`` or ``LH``.
|
||||
|
||||
Args:
|
||||
input_size (int): Dimension of the input, ``D``.
|
||||
hidden_size (int): Dimension of the hidden state, ``H``.
|
||||
bias (bool, optional): Whether to use a bias. Default: ``True``.
|
||||
nonlinearity (callable, optional): Non-linearity to use. If ``None``,
|
||||
then func:`tanh` is used. Default: ``None``.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
input_size: int,
|
||||
hidden_size: int,
|
||||
bias: bool = True,
|
||||
nonlinearity: Optional[Callable] = None,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.nonlinearity = nonlinearity or tanh
|
||||
if not callable(self.nonlinearity):
|
||||
raise ValueError(
|
||||
f"Nonlinearity must be callable. Current value: {nonlinearity}."
|
||||
)
|
||||
|
||||
scale = 1.0 / math.sqrt(hidden_size)
|
||||
self.hidden_size = hidden_size
|
||||
self.Wxh = mx.random.uniform(
|
||||
low=-scale, high=scale, shape=(input_size, hidden_size)
|
||||
)
|
||||
self.Whh = mx.random.uniform(
|
||||
low=-scale, high=scale, shape=(hidden_size, hidden_size)
|
||||
)
|
||||
self.bias = (
|
||||
mx.random.uniform(low=-scale, high=scale, shape=(hidden_size,))
|
||||
if bias
|
||||
else None
|
||||
)
|
||||
|
||||
def _extra_repr(self):
|
||||
return (
|
||||
f"input_dims={self.Wxh.shape[0]}, "
|
||||
f"hidden_size={self.hidden_size}, "
|
||||
f"nonlinearity={self.nonlinearity}, bias={self.bias is not None}"
|
||||
)
|
||||
|
||||
def __call__(self, x, hidden=None):
|
||||
if self.bias is not None:
|
||||
x = mx.addmm(self.bias, x, self.Wxh)
|
||||
else:
|
||||
x = x @ self.Wxh
|
||||
|
||||
all_hidden = []
|
||||
for idx in range(x.shape[-2]):
|
||||
if hidden is not None:
|
||||
hidden = x[..., idx, :] + hidden @ self.Whh
|
||||
else:
|
||||
hidden = x[..., idx, :]
|
||||
hidden = self.nonlinearity(hidden)
|
||||
all_hidden.append(hidden)
|
||||
|
||||
return mx.stack(all_hidden, axis=-2)
|
||||
|
||||
|
||||
class GRU(Module):
|
||||
r"""A gated recurrent unit (GRU) RNN layer.
|
||||
|
||||
The input has shape ``NLD`` or ``LD`` where:
|
||||
|
||||
* ``N`` is the optional batch dimension
|
||||
* ``L`` is the sequence length
|
||||
* ``D`` is the input's feature dimension
|
||||
|
||||
Concretely, for each element of the sequence, this layer computes:
|
||||
|
||||
.. math::
|
||||
|
||||
\begin{align*}
|
||||
r_t &= \sigma (W_{xr}x_t + W_{hr}h_t + b_{r}) \\
|
||||
z_t &= \sigma (W_{xz}x_t + W_{hz}h_t + b_{z}) \\
|
||||
n_t &= \text{tanh}(W_{xn}x_t + b_{n} + r_t \odot (W_{hn}h_t + b_{hn})) \\
|
||||
h_{t + 1} &= (1 - z_t) \odot n_t + z_t \odot h_t
|
||||
\end{align*}
|
||||
|
||||
The hidden state :math:`h` has shape ``NH`` or ``H`` depending on
|
||||
whether the input is batched or not. Returns the hidden state at each
|
||||
time step of shape ``NLH`` or ``LH``.
|
||||
|
||||
Args:
|
||||
input_size (int): Dimension of the input, ``D``.
|
||||
hidden_size (int): Dimension of the hidden state, ``H``.
|
||||
bias (bool): Whether to use biases or not. Default: ``True``.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
input_size: int,
|
||||
hidden_size: int,
|
||||
bias: bool = True,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.hidden_size = hidden_size
|
||||
scale = 1.0 / math.sqrt(hidden_size)
|
||||
self.Wx = mx.random.uniform(
|
||||
low=-scale, high=scale, shape=(input_size, 3 * hidden_size)
|
||||
)
|
||||
self.Wh = mx.random.uniform(
|
||||
low=-scale, high=scale, shape=(hidden_size, 3 * hidden_size)
|
||||
)
|
||||
self.b = (
|
||||
mx.random.uniform(low=-scale, high=scale, shape=(3 * hidden_size,))
|
||||
if bias
|
||||
else None
|
||||
)
|
||||
self.bhn = (
|
||||
mx.random.uniform(low=-scale, high=scale, shape=(hidden_size,))
|
||||
if bias
|
||||
else None
|
||||
)
|
||||
|
||||
def _extra_repr(self):
|
||||
return (
|
||||
f"input_dims={self.Wx.shape[0]}, "
|
||||
f"hidden_size={self.hidden_size}, bias={self.b is not None}"
|
||||
)
|
||||
|
||||
def __call__(self, x, hidden=None):
|
||||
if self.b is not None:
|
||||
x = mx.addmm(self.b, x, self.Wx)
|
||||
else:
|
||||
x = x @ self.Wx
|
||||
|
||||
x_rz = x[..., : -self.hidden_size]
|
||||
x_n = x[..., -self.hidden_size :]
|
||||
|
||||
all_hidden = []
|
||||
|
||||
for idx in range(x.shape[-2]):
|
||||
rz = x_rz[..., idx, :]
|
||||
if hidden is not None:
|
||||
h_proj = hidden @ self.Wh
|
||||
h_proj_rz = h_proj[..., : -self.hidden_size]
|
||||
h_proj_n = h_proj[..., -self.hidden_size :]
|
||||
|
||||
if self.bhn is not None:
|
||||
h_proj_n += self.bhn
|
||||
|
||||
rz = rz + h_proj_rz
|
||||
|
||||
rz = mx.sigmoid(rz)
|
||||
|
||||
r, z = mx.split(rz, 2, axis=-1)
|
||||
|
||||
n = x_n[..., idx, :]
|
||||
|
||||
if hidden is not None:
|
||||
n = n + r * h_proj_n
|
||||
n = mx.tanh(n)
|
||||
|
||||
hidden = (1 - z) * n
|
||||
if hidden is not None:
|
||||
hidden = hidden + z * hidden
|
||||
all_hidden.append(hidden)
|
||||
|
||||
return mx.stack(all_hidden, axis=-2)
|
||||
|
||||
|
||||
class LSTM(Module):
|
||||
r"""An LSTM recurrent layer.
|
||||
|
||||
The input has shape ``NLD`` or ``LD`` where:
|
||||
|
||||
* ``N`` is the optional batch dimension
|
||||
* ``L`` is the sequence length
|
||||
* ``D`` is the input's feature dimension
|
||||
|
||||
Concretely, for each element of the sequence, this layer computes:
|
||||
|
||||
.. math::
|
||||
\begin{align*}
|
||||
i_t &= \sigma (W_{xi}x_t + W_{hi}h_t + b_{i}) \\
|
||||
f_t &= \sigma (W_{xf}x_t + W_{hf}h_t + b_{f}) \\
|
||||
g_t &= \text{tanh} (W_{xg}x_t + W_{hg}h_t + b_{g}) \\
|
||||
o_t &= \sigma (W_{xo}x_t + W_{ho}h_t + b_{o}) \\
|
||||
c_{t + 1} &= f_t \odot c_t + i_t \odot g_t \\
|
||||
h_{t + 1} &= o_t \text{tanh}(c_{t + 1})
|
||||
\end{align*}
|
||||
|
||||
The hidden state :math:`h` and cell state :math:`c` have shape ``NH``
|
||||
or ``H``, depending on whether the input is batched or not.
|
||||
|
||||
The layer returns two arrays, the hidden state and the cell state at
|
||||
each time step, both of shape ``NLH`` or ``LH``.
|
||||
|
||||
Args:
|
||||
input_size (int): Dimension of the input, ``D``.
|
||||
hidden_size (int): Dimension of the hidden state, ``H``.
|
||||
bias (bool): Whether to use biases or not. Default: ``True``.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
input_size: int,
|
||||
hidden_size: int,
|
||||
bias: bool = True,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.hidden_size = hidden_size
|
||||
scale = 1.0 / math.sqrt(hidden_size)
|
||||
self.Wx = mx.random.uniform(
|
||||
low=-scale, high=scale, shape=(input_size, 4 * hidden_size)
|
||||
)
|
||||
self.Wh = mx.random.uniform(
|
||||
low=-scale, high=scale, shape=(hidden_size, 4 * hidden_size)
|
||||
)
|
||||
self.bias = (
|
||||
mx.random.uniform(low=-scale, high=scale, shape=(4 * hidden_size,))
|
||||
if bias
|
||||
else None
|
||||
)
|
||||
|
||||
def _extra_repr(self):
|
||||
return (
|
||||
f"input_dims={self.Wx.shape[0]}, "
|
||||
f"hidden_size={self.hidden_size}, bias={self.bias is not None}"
|
||||
)
|
||||
|
||||
def __call__(self, x, hidden=None, cell=None):
|
||||
if self.bias is not None:
|
||||
x = mx.addmm(self.bias, x, self.Wx)
|
||||
else:
|
||||
x = x @ self.Wx
|
||||
|
||||
all_hidden = []
|
||||
all_cell = []
|
||||
|
||||
for idx in range(x.shape[-2]):
|
||||
ifgo = x[..., idx, :]
|
||||
if hidden is not None:
|
||||
ifgo = ifgo + hidden @ self.Wh
|
||||
i, f, g, o = mx.split(ifgo, 4, axis=-1)
|
||||
|
||||
i = mx.sigmoid(i)
|
||||
f = mx.sigmoid(f)
|
||||
g = mx.tanh(g)
|
||||
o = mx.sigmoid(o)
|
||||
|
||||
if cell is not None:
|
||||
cell = f * cell + i * g
|
||||
else:
|
||||
cell = i * g
|
||||
hidden = o * mx.tanh(cell)
|
||||
|
||||
all_cell.append(cell)
|
||||
all_hidden.append(hidden)
|
||||
|
||||
return mx.stack(all_hidden, axis=-2), mx.stack(all_cell, axis=-2)
|
@ -1480,6 +1480,72 @@ class TestLayers(mlx_tests.MLXTestCase):
|
||||
"AvgPool2d(kernel_size=(1, 2), stride=(2, 2), padding=(1, 2))",
|
||||
)
|
||||
|
||||
def test_rnn(self):
|
||||
layer = nn.RNN(input_size=5, hidden_size=12, bias=True)
|
||||
inp = mx.random.normal((2, 25, 5))
|
||||
|
||||
h_out = layer(inp)
|
||||
self.assertEqual(h_out.shape, (2, 25, 12))
|
||||
|
||||
layer = nn.RNN(
|
||||
5,
|
||||
12,
|
||||
bias=False,
|
||||
nonlinearity=lambda x: mx.maximum(0, x),
|
||||
)
|
||||
|
||||
h_out = layer(inp)
|
||||
self.assertEqual(h_out.shape, (2, 25, 12))
|
||||
|
||||
with self.assertRaises(ValueError):
|
||||
nn.RNN(5, 12, nonlinearity="tanh")
|
||||
|
||||
inp = mx.random.normal((44, 5))
|
||||
h_out = layer(inp)
|
||||
self.assertEqual(h_out.shape, (44, 12))
|
||||
|
||||
h_out = layer(inp, hidden=h_out[-1, :])
|
||||
self.assertEqual(h_out.shape, (44, 12))
|
||||
|
||||
def test_gru(self):
|
||||
layer = nn.GRU(5, 12, bias=True)
|
||||
inp = mx.random.normal((2, 25, 5))
|
||||
|
||||
h_out = layer(inp)
|
||||
self.assertEqual(h_out.shape, (2, 25, 12))
|
||||
|
||||
h_out = layer(inp, hidden=h_out[:, -1, :])
|
||||
self.assertEqual(h_out.shape, (2, 25, 12))
|
||||
|
||||
inp = mx.random.normal((44, 5))
|
||||
h_out = layer(inp)
|
||||
self.assertEqual(h_out.shape, (44, 12))
|
||||
|
||||
h_out = layer(inp, h_out[-1, :])
|
||||
self.assertEqual(h_out.shape, (44, 12))
|
||||
|
||||
def test_lstm(self):
|
||||
layer = nn.LSTM(5, 12)
|
||||
inp = mx.random.normal((2, 25, 5))
|
||||
|
||||
h_out, c_out = layer(inp)
|
||||
self.assertEqual(h_out.shape, (2, 25, 12))
|
||||
self.assertEqual(c_out.shape, (2, 25, 12))
|
||||
|
||||
h_out, c_out = layer(inp, hidden=h_out[:, -1, :], cell=c_out[:, -1, :])
|
||||
self.assertEqual(h_out.shape, (2, 25, 12))
|
||||
self.assertEqual(c_out.shape, (2, 25, 12))
|
||||
|
||||
inp = mx.random.normal((44, 5))
|
||||
h_out, c_out = layer(inp)
|
||||
self.assertEqual(h_out.shape, (44, 12))
|
||||
self.assertEqual(c_out.shape, (44, 12))
|
||||
|
||||
inp = mx.random.normal((44, 5))
|
||||
h_out, c_out = layer(inp, hidden=h_out[-1, :], cell=c_out[-1, :])
|
||||
self.assertEqual(h_out.shape, (44, 12))
|
||||
self.assertEqual(c_out.shape, (44, 12))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
@ -1416,7 +1416,7 @@ class TestOps(mlx_tests.MLXTestCase):
|
||||
# Sliced inputs
|
||||
y = mx.random.uniform(shape=(8, 4))
|
||||
out = mx.softmax(y[:, 0:2], axis=-1)
|
||||
self.assertAlmostEqual(out.sum().item(), 8.0)
|
||||
self.assertAlmostEqual(out.sum().item(), 8.0, 5)
|
||||
|
||||
def test_concatenate(self):
|
||||
a_npy = np.random.randn(32, 32, 32)
|
||||
|
Loading…
Reference in New Issue
Block a user