mirror of
https://github.com/ml-explore/mlx.git
synced 2025-10-20 01:18:12 +08:00
bug fix (#658)
This commit is contained in:
@@ -1386,6 +1386,11 @@ class TestOps(mlx_tests.MLXTestCase):
|
||||
self.assertTrue((a[:-1] < 1e-9).all())
|
||||
self.assertEqual(a[-1], 1)
|
||||
|
||||
# Sliced inputs
|
||||
y = mx.random.uniform(shape=(8, 4))
|
||||
out = mx.softmax(y[:, 0:2], axis=-1)
|
||||
self.assertAlmostEqual(out.sum().item(), 8.0)
|
||||
|
||||
def test_concatenate(self):
|
||||
a_npy = np.random.randn(32, 32, 32)
|
||||
b_npy = np.random.randn(32, 32, 32)
|
||||
|
Reference in New Issue
Block a user