Implement RNN, GRU, LSTM (#268)

* RNN base implementation

* Address comments+format

* nits in docs

* add tests for prb

* fix test

* add a couple tests

---------

Co-authored-by: Awni Hannun <awni@apple.com>
This commit is contained in:
Justin Deschenaux
2024-03-12 05:14:44 +01:00
committed by GitHub
parent 0e95b64942
commit 8e5600022a
6 changed files with 361 additions and 1 deletions

View File

@@ -1480,6 +1480,72 @@ class TestLayers(mlx_tests.MLXTestCase):
"AvgPool2d(kernel_size=(1, 2), stride=(2, 2), padding=(1, 2))",
)
def test_rnn(self):
layer = nn.RNN(input_size=5, hidden_size=12, bias=True)
inp = mx.random.normal((2, 25, 5))
h_out = layer(inp)
self.assertEqual(h_out.shape, (2, 25, 12))
layer = nn.RNN(
5,
12,
bias=False,
nonlinearity=lambda x: mx.maximum(0, x),
)
h_out = layer(inp)
self.assertEqual(h_out.shape, (2, 25, 12))
with self.assertRaises(ValueError):
nn.RNN(5, 12, nonlinearity="tanh")
inp = mx.random.normal((44, 5))
h_out = layer(inp)
self.assertEqual(h_out.shape, (44, 12))
h_out = layer(inp, hidden=h_out[-1, :])
self.assertEqual(h_out.shape, (44, 12))
def test_gru(self):
layer = nn.GRU(5, 12, bias=True)
inp = mx.random.normal((2, 25, 5))
h_out = layer(inp)
self.assertEqual(h_out.shape, (2, 25, 12))
h_out = layer(inp, hidden=h_out[:, -1, :])
self.assertEqual(h_out.shape, (2, 25, 12))
inp = mx.random.normal((44, 5))
h_out = layer(inp)
self.assertEqual(h_out.shape, (44, 12))
h_out = layer(inp, h_out[-1, :])
self.assertEqual(h_out.shape, (44, 12))
def test_lstm(self):
layer = nn.LSTM(5, 12)
inp = mx.random.normal((2, 25, 5))
h_out, c_out = layer(inp)
self.assertEqual(h_out.shape, (2, 25, 12))
self.assertEqual(c_out.shape, (2, 25, 12))
h_out, c_out = layer(inp, hidden=h_out[:, -1, :], cell=c_out[:, -1, :])
self.assertEqual(h_out.shape, (2, 25, 12))
self.assertEqual(c_out.shape, (2, 25, 12))
inp = mx.random.normal((44, 5))
h_out, c_out = layer(inp)
self.assertEqual(h_out.shape, (44, 12))
self.assertEqual(c_out.shape, (44, 12))
inp = mx.random.normal((44, 5))
h_out, c_out = layer(inp, hidden=h_out[-1, :], cell=c_out[-1, :])
self.assertEqual(h_out.shape, (44, 12))
self.assertEqual(c_out.shape, (44, 12))
if __name__ == "__main__":
unittest.main()