mirror of https://github.com/ml-explore/mlx.git
Implement RNN, GRU, LSTM (#268)

* RNN base implementation
* Address comments + format
* nits in docs
* add tests for prb
* fix test
* add a couple tests

Co-authored-by: Awni Hannun <awni@apple.com>
committed by GitHub
parent 0e95b64942
commit 8e5600022a
@@ -1480,6 +1480,72 @@ class TestLayers(mlx_tests.MLXTestCase):
             "AvgPool2d(kernel_size=(1, 2), stride=(2, 2), padding=(1, 2))",
         )
 
+    def test_rnn(self):
+        layer = nn.RNN(input_size=5, hidden_size=12, bias=True)
+        inp = mx.random.normal((2, 25, 5))
+
+        h_out = layer(inp)
+        self.assertEqual(h_out.shape, (2, 25, 12))
+
+        layer = nn.RNN(
+            5,
+            12,
+            bias=False,
+            nonlinearity=lambda x: mx.maximum(0, x),
+        )
+
+        h_out = layer(inp)
+        self.assertEqual(h_out.shape, (2, 25, 12))
+
+        with self.assertRaises(ValueError):
+            nn.RNN(5, 12, nonlinearity="tanh")
+
+        inp = mx.random.normal((44, 5))
+        h_out = layer(inp)
+        self.assertEqual(h_out.shape, (44, 12))
+
+        h_out = layer(inp, hidden=h_out[-1, :])
+        self.assertEqual(h_out.shape, (44, 12))
+
+    def test_gru(self):
+        layer = nn.GRU(5, 12, bias=True)
+        inp = mx.random.normal((2, 25, 5))
+
+        h_out = layer(inp)
+        self.assertEqual(h_out.shape, (2, 25, 12))
+
+        h_out = layer(inp, hidden=h_out[:, -1, :])
+        self.assertEqual(h_out.shape, (2, 25, 12))
+
+        inp = mx.random.normal((44, 5))
+        h_out = layer(inp)
+        self.assertEqual(h_out.shape, (44, 12))
+
+        h_out = layer(inp, h_out[-1, :])
+        self.assertEqual(h_out.shape, (44, 12))
+
+    def test_lstm(self):
+        layer = nn.LSTM(5, 12)
+        inp = mx.random.normal((2, 25, 5))
+
+        h_out, c_out = layer(inp)
+        self.assertEqual(h_out.shape, (2, 25, 12))
+        self.assertEqual(c_out.shape, (2, 25, 12))
+
+        h_out, c_out = layer(inp, hidden=h_out[:, -1, :], cell=c_out[:, -1, :])
+        self.assertEqual(h_out.shape, (2, 25, 12))
+        self.assertEqual(c_out.shape, (2, 25, 12))
+
+        inp = mx.random.normal((44, 5))
+        h_out, c_out = layer(inp)
+        self.assertEqual(h_out.shape, (44, 12))
+        self.assertEqual(c_out.shape, (44, 12))
+
+        inp = mx.random.normal((44, 5))
+        h_out, c_out = layer(inp, hidden=h_out[-1, :], cell=c_out[-1, :])
+        self.assertEqual(h_out.shape, (44, 12))
+        self.assertEqual(c_out.shape, (44, 12))
+
 
 if __name__ == "__main__":
     unittest.main()
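For reference, a minimal usage sketch of the new recurrent layers, pieced together from the call signatures exercised by the tests above. The constructor arguments, keyword names (hidden=, cell=, nonlinearity=), and output shapes come straight from the test code; the import aliases (mlx.core as mx, mlx.nn as nn) are assumed to match the rest of the test suite.

import mlx.core as mx
import mlx.nn as nn

x = mx.random.normal((2, 25, 5))  # (batch, sequence, input_size)

# Elman RNN: returns the hidden state for every time step.
rnn = nn.RNN(input_size=5, hidden_size=12, bias=True)
h = rnn(x)  # shape (2, 25, 12)

# A custom nonlinearity must be a callable; passing a string such as
# "tanh" raises a ValueError, as the test above checks.
relu_rnn = nn.RNN(5, 12, bias=False, nonlinearity=lambda v: mx.maximum(0, v))
h = relu_rnn(x)  # shape (2, 25, 12)

# GRU: the hidden state of the last step can be fed back as the initial state.
gru = nn.GRU(5, 12, bias=True)
h = gru(x)                      # (2, 25, 12)
h = gru(x, hidden=h[:, -1, :])  # continue from the final hidden state

# LSTM: returns both hidden and cell states for every time step.
lstm = nn.LSTM(5, 12)
h, c = lstm(x)  # each (2, 25, 12)
h, c = lstm(x, hidden=h[:, -1, :], cell=c[:, -1, :])

# Unbatched input also works: (sequence, input_size) -> (sequence, hidden_size).
h, c = lstm(mx.random.normal((44, 5)))  # each (44, 12)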