mirror of
https://github.com/ml-explore/mlx.git
synced 2025-10-18 07:18:29 +08:00

committed by
GitHub

parent
db487e6b1a
commit
7546fdb100
@@ -14,9 +14,9 @@ class TestVersion(mlx_tests.MLXTestCase):
|
||||
def test_version(self):
|
||||
v = mx.__version__
|
||||
vnums = v.split(".")
|
||||
self.assertEqual(len(vnums), 3)
|
||||
v = ".".join(str(int(vn)) for vn in vnums)
|
||||
self.assertEqual(v, mx.__version__)
|
||||
self.assertGreaterEqual(len(vnums), 3)
|
||||
v = ".".join(str(int(vn)) for vn in vnums[:3])
|
||||
self.assertEqual(v, mx.__version__[: len(v)])
|
||||
|
||||
|
||||
class TestDtypes(mlx_tests.MLXTestCase):
|
||||
@@ -905,7 +905,6 @@ class TestArray(mlx_tests.MLXTestCase):
|
||||
)
|
||||
|
||||
def test_slice_negative_step(self):
|
||||
|
||||
a_np = np.arange(20)
|
||||
a_mx = mx.array(a_np)
|
||||
|
||||
|
@@ -1084,7 +1084,7 @@ class TestOps(mlx_tests.MLXTestCase):
|
||||
a[-1] = 0.0
|
||||
a = mx.softmax(mx.array(a))
|
||||
self.assertFalse(np.any(np.isnan(a)))
|
||||
self.assertTrue((a[:-1] == 0).all())
|
||||
self.assertTrue((a[:-1] < 1e-9).all())
|
||||
self.assertEqual(a[-1], 1)
|
||||
|
||||
def test_concatenate(self):
|
||||
|
Reference in New Issue
Block a user