mlx/python/mlx/nn/layers/recurrent.py
Justin Deschenaux 8e5600022a
Implement RNN, GRU, LSTM (#268)
* RNN base implementation

* Address comments+format

* nits in docs

* add tests for prb

* fix test

* add a couple tests

---------

Co-authored-by: Awni Hannun <awni@apple.com>
2024-03-11 21:14:44 -07:00

288 lines
8.4 KiB
Python

# Copyright © 2024 Apple Inc.
import math
from typing import Callable, Optional
import mlx.core as mx
from mlx.nn.layers.activations import tanh
from mlx.nn.layers.base import Module
class RNN(Module):
    r"""An Elman recurrent layer.

    The layer consumes a sequence of shape ``NLD`` or ``LD`` where:

    * ``N`` is the optional batch dimension
    * ``L`` is the sequence length
    * ``D`` is the input's feature dimension

    At every position along the sequence axis it evaluates:

    .. math::

        h_{t + 1} = \text{tanh} (W_{xh}x_t + W_{hh}h_t + b)

    The hidden state :math:`h` has shape ``NH`` or ``H``, matching whether
    the input carries a batch dimension. The output gathers the hidden state
    of every time step and has shape ``NLH`` or ``LH``.

    Args:
        input_size (int): Dimension of the input, ``D``.
        hidden_size (int): Dimension of the hidden state, ``H``.
        bias (bool, optional): Whether to use a bias. Default: ``True``.
        nonlinearity (callable, optional): Non-linearity to use. If ``None``,
            then :func:`tanh` is used. Default: ``None``.
    """

    def __init__(
        self,
        input_size: int,
        hidden_size: int,
        bias: bool = True,
        nonlinearity: Optional[Callable] = None,
    ):
        super().__init__()

        # Any falsy ``nonlinearity`` (not only ``None``) selects tanh.
        self.nonlinearity = nonlinearity or tanh
        if not callable(self.nonlinearity):
            raise ValueError(
                f"Nonlinearity must be callable. Current value: {nonlinearity}."
            )

        self.hidden_size = hidden_size

        # All parameters are drawn from U(-1/sqrt(H), 1/sqrt(H)).
        bound = 1.0 / math.sqrt(hidden_size)

        def _uniform(*shape):
            return mx.random.uniform(low=-bound, high=bound, shape=shape)

        self.Wxh = _uniform(input_size, hidden_size)
        self.Whh = _uniform(hidden_size, hidden_size)
        self.bias = _uniform(hidden_size) if bias else None

    def _extra_repr(self):
        input_dims = self.Wxh.shape[0]
        has_bias = self.bias is not None
        return (
            f"input_dims={input_dims}, "
            f"hidden_size={self.hidden_size}, "
            f"nonlinearity={self.nonlinearity}, bias={has_bias}"
        )

    def __call__(self, x, hidden=None):
        """Run the recurrence over the second-to-last axis of ``x``.

        Args:
            x: Input sequence whose last axis is the feature dimension and
                whose second-to-last axis is the sequence dimension.
            hidden: Optional initial hidden state. When ``None`` the first
                step applies the non-linearity to the input projection only.

        Returns:
            The hidden states of all time steps stacked along axis ``-2``.
        """
        # Project the whole sequence once, outside of the time loop.
        if self.bias is None:
            projected = x @ self.Wxh
        else:
            projected = mx.addmm(self.bias, x, self.Wxh)

        states = []
        for step in range(projected.shape[-2]):
            pre_activation = projected[..., step, :]
            if hidden is not None:
                pre_activation = pre_activation + hidden @ self.Whh
            hidden = self.nonlinearity(pre_activation)
            states.append(hidden)

        return mx.stack(states, axis=-2)
class GRU(Module):
    r"""A gated recurrent unit (GRU) RNN layer.

    The input has shape ``NLD`` or ``LD`` where:

    * ``N`` is the optional batch dimension
    * ``L`` is the sequence length
    * ``D`` is the input's feature dimension

    Concretely, for each element of the sequence, this layer computes:

    .. math::

        \begin{align*}
        r_t &= \sigma (W_{xr}x_t + W_{hr}h_t + b_{r}) \\
        z_t &= \sigma (W_{xz}x_t + W_{hz}h_t + b_{z}) \\
        n_t &= \text{tanh}(W_{xn}x_t + b_{n} + r_t \odot (W_{hn}h_t + b_{hn})) \\
        h_{t + 1} &= (1 - z_t) \odot n_t + z_t \odot h_t
        \end{align*}

    The hidden state :math:`h` has shape ``NH`` or ``H`` depending on
    whether the input is batched or not. Returns the hidden state at each
    time step of shape ``NLH`` or ``LH``.

    Args:
        input_size (int): Dimension of the input, ``D``.
        hidden_size (int): Dimension of the hidden state, ``H``.
        bias (bool): Whether to use biases or not. Default: ``True``.
    """

    def __init__(
        self,
        input_size: int,
        hidden_size: int,
        bias: bool = True,
    ):
        super().__init__()

        self.hidden_size = hidden_size
        # All parameters are drawn from U(-1/sqrt(H), 1/sqrt(H)).
        scale = 1.0 / math.sqrt(hidden_size)
        # The r, z and n projections are fused along the last axis, in that
        # order: the first 2H columns feed the gates, the last H the candidate.
        self.Wx = mx.random.uniform(
            low=-scale, high=scale, shape=(input_size, 3 * hidden_size)
        )
        self.Wh = mx.random.uniform(
            low=-scale, high=scale, shape=(hidden_size, 3 * hidden_size)
        )
        self.b = (
            mx.random.uniform(low=-scale, high=scale, shape=(3 * hidden_size,))
            if bias
            else None
        )
        # Separate bias for the recurrent part of the candidate, which is
        # multiplied by the reset gate and so cannot be folded into ``b``.
        self.bhn = (
            mx.random.uniform(low=-scale, high=scale, shape=(hidden_size,))
            if bias
            else None
        )

    def _extra_repr(self):
        return (
            f"input_dims={self.Wx.shape[0]}, "
            f"hidden_size={self.hidden_size}, bias={self.b is not None}"
        )

    def __call__(self, x, hidden=None):
        """Run the GRU recurrence over the second-to-last axis of ``x``.

        Args:
            x: Input sequence whose last axis is the feature dimension and
                whose second-to-last axis is the sequence dimension.
            hidden: Optional initial hidden state. When ``None``, the first
                step skips every recurrent term (including ``bhn``) and
                produces ``(1 - z) * n`` from the input projection alone.

        Returns:
            The hidden states of all time steps stacked along axis ``-2``.
        """
        # Project the whole sequence once, outside of the time loop.
        if self.b is not None:
            x = mx.addmm(self.b, x, self.Wx)
        else:
            x = x @ self.Wx

        x_rz = x[..., : -self.hidden_size]
        x_n = x[..., -self.hidden_size :]

        all_hidden = []

        for idx in range(x.shape[-2]):
            rz = x_rz[..., idx, :]
            if hidden is not None:
                h_proj = hidden @ self.Wh
                h_proj_rz = h_proj[..., : -self.hidden_size]
                h_proj_n = h_proj[..., -self.hidden_size :]

                if self.bhn is not None:
                    h_proj_n = h_proj_n + self.bhn

                rz = rz + h_proj_rz

            rz = mx.sigmoid(rz)
            r, z = mx.split(rz, 2, axis=-1)

            n = x_n[..., idx, :]
            if hidden is not None:
                n = n + r * h_proj_n
            n = mx.tanh(n)

            # Blend the candidate with the *previous* hidden state. The
            # previous state must be read before ``hidden`` is rebound,
            # otherwise the update degenerates to (1 - z) * n * (1 + z).
            if hidden is not None:
                hidden = (1 - z) * n + z * hidden
            else:
                hidden = (1 - z) * n

            all_hidden.append(hidden)

        return mx.stack(all_hidden, axis=-2)
class LSTM(Module):
    r"""An LSTM recurrent layer.

    The layer consumes a sequence of shape ``NLD`` or ``LD`` where:

    * ``N`` is the optional batch dimension
    * ``L`` is the sequence length
    * ``D`` is the input's feature dimension

    For each element of the sequence it evaluates:

    .. math::

        \begin{align*}
        i_t &= \sigma (W_{xi}x_t + W_{hi}h_t + b_{i}) \\
        f_t &= \sigma (W_{xf}x_t + W_{hf}h_t + b_{f}) \\
        g_t &= \text{tanh} (W_{xg}x_t + W_{hg}h_t + b_{g}) \\
        o_t &= \sigma (W_{xo}x_t + W_{ho}h_t + b_{o}) \\
        c_{t + 1} &= f_t \odot c_t + i_t \odot g_t \\
        h_{t + 1} &= o_t \odot \text{tanh}(c_{t + 1})
        \end{align*}

    The hidden state :math:`h` and the cell state :math:`c` have shape
    ``NH`` or ``H``, matching whether the input carries a batch dimension.

    Two arrays are returned, the hidden state and the cell state at every
    time step, both of shape ``NLH`` or ``LH``.

    Args:
        input_size (int): Dimension of the input, ``D``.
        hidden_size (int): Dimension of the hidden state, ``H``.
        bias (bool): Whether to use biases or not. Default: ``True``.
    """

    def __init__(
        self,
        input_size: int,
        hidden_size: int,
        bias: bool = True,
    ):
        super().__init__()

        self.hidden_size = hidden_size

        # All parameters are drawn from U(-1/sqrt(H), 1/sqrt(H)).
        bound = 1.0 / math.sqrt(hidden_size)
        fused_size = 4 * hidden_size  # i, f, g, o projections side by side

        def _uniform(*shape):
            return mx.random.uniform(low=-bound, high=bound, shape=shape)

        self.Wx = _uniform(input_size, fused_size)
        self.Wh = _uniform(hidden_size, fused_size)
        self.bias = _uniform(fused_size) if bias else None

    def _extra_repr(self):
        input_dims = self.Wx.shape[0]
        has_bias = self.bias is not None
        return (
            f"input_dims={input_dims}, "
            f"hidden_size={self.hidden_size}, bias={has_bias}"
        )

    def __call__(self, x, hidden=None, cell=None):
        """Run the LSTM recurrence over the second-to-last axis of ``x``.

        Args:
            x: Input sequence whose last axis is the feature dimension and
                whose second-to-last axis is the sequence dimension.
            hidden: Optional initial hidden state. When ``None`` the first
                step uses the input projection only.
            cell: Optional initial cell state. When ``None`` the first step
                omits the forget-gate term.

        Returns:
            A tuple ``(hidden_states, cell_states)``, each stacked along
            axis ``-2``.
        """
        # Project the whole sequence once, outside of the time loop.
        if self.bias is None:
            projected = x @ self.Wx
        else:
            projected = mx.addmm(self.bias, x, self.Wx)

        hidden_states = []
        cell_states = []

        for step in range(projected.shape[-2]):
            gates = projected[..., step, :]
            if hidden is not None:
                gates = gates + hidden @ self.Wh

            in_gate, forget_gate, candidate, out_gate = mx.split(gates, 4, axis=-1)
            in_gate = mx.sigmoid(in_gate)
            forget_gate = mx.sigmoid(forget_gate)
            candidate = mx.tanh(candidate)
            out_gate = mx.sigmoid(out_gate)

            if cell is None:
                cell = in_gate * candidate
            else:
                cell = forget_gate * cell + in_gate * candidate
            hidden = out_gate * mx.tanh(cell)

            cell_states.append(cell)
            hidden_states.append(hidden)

        return mx.stack(hidden_states, axis=-2), mx.stack(cell_states, axis=-2)