mirror of
https://github.com/ml-explore/mlx.git
synced 2025-11-02 17:28:12 +08:00
Implement RNN, GRU, LSTM (#268)
* RNN base implementation * Address comments+format * nits in docs * add tests for prb * fix test * add a couple tests --------- Co-authored-by: Awni Hannun <awni@apple.com>
This commit is contained in:
committed by
GitHub
parent
0e95b64942
commit
8e5600022a
@@ -1416,7 +1416,7 @@ class TestOps(mlx_tests.MLXTestCase):
|
||||
# Sliced inputs
|
||||
y = mx.random.uniform(shape=(8, 4))
|
||||
out = mx.softmax(y[:, 0:2], axis=-1)
|
||||
self.assertAlmostEqual(out.sum().item(), 8.0)
|
||||
self.assertAlmostEqual(out.sum().item(), 8.0, 5)
|
||||
|
||||
def test_concatenate(self):
|
||||
a_npy = np.random.randn(32, 32, 32)
|
||||
|
||||
Reference in New Issue
Block a user