Implement RNN, GRU, LSTM (#268)

* RNN base implementation

* Address comments+format

* nits in docs

* add tests for prb

* fix test

* add a couple tests

---------

Co-authored-by: Awni Hannun <awni@apple.com>
This commit is contained in:
Justin Deschenaux
2024-03-12 05:14:44 +01:00
committed by GitHub
parent 0e95b64942
commit 8e5600022a
6 changed files with 361 additions and 1 deletions

View File

@@ -1416,7 +1416,7 @@ class TestOps(mlx_tests.MLXTestCase):
# Sliced inputs
y = mx.random.uniform(shape=(8, 4))
out = mx.softmax(y[:, 0:2], axis=-1)
self.assertAlmostEqual(out.sum().item(), 8.0)
self.assertAlmostEqual(out.sum().item(), 8.0, 5)
def test_concatenate(self):
a_npy = np.random.randn(32, 32, 32)