mirror of
https://github.com/ml-explore/mlx.git
synced 2025-10-19 08:38:09 +08:00
Make shape a tuple (#591)
* shape tuple * also remove simplify from docs * rebase
This commit is contained in:
@@ -13,7 +13,7 @@ class TestInit(mlx_tests.MLXTestCase):
|
||||
|
||||
for dtype in [mx.float32, mx.float16]:
|
||||
initializer = init.constant(value, dtype)
|
||||
for shape in [[3], [3, 3], [3, 3, 3]]:
|
||||
for shape in [(3,), (3, 3), (3, 3, 3)]:
|
||||
result = initializer(mx.array(mx.zeros(shape)))
|
||||
with self.subTest(shape=shape):
|
||||
self.assertEqual(result.shape, shape)
|
||||
@@ -24,7 +24,7 @@ class TestInit(mlx_tests.MLXTestCase):
|
||||
std = 1.0
|
||||
for dtype in [mx.float32, mx.float16]:
|
||||
initializer = init.normal(mean, std, dtype=dtype)
|
||||
for shape in [[3], [3, 3], [3, 3, 3]]:
|
||||
for shape in [(3,), (3, 3), (3, 3, 3)]:
|
||||
result = initializer(mx.array(np.empty(shape)))
|
||||
with self.subTest(shape=shape):
|
||||
self.assertEqual(result.shape, shape)
|
||||
@@ -36,7 +36,7 @@ class TestInit(mlx_tests.MLXTestCase):
|
||||
|
||||
for dtype in [mx.float32, mx.float16]:
|
||||
initializer = init.uniform(low, high, dtype)
|
||||
for shape in [[3], [3, 3], [3, 3, 3]]:
|
||||
for shape in [(3,), (3, 3), (3, 3, 3)]:
|
||||
result = initializer(mx.array(np.empty(shape)))
|
||||
with self.subTest(shape=shape):
|
||||
self.assertEqual(result.shape, shape)
|
||||
@@ -46,7 +46,7 @@ class TestInit(mlx_tests.MLXTestCase):
|
||||
def test_identity(self):
|
||||
for dtype in [mx.float32, mx.float16]:
|
||||
initializer = init.identity(dtype)
|
||||
for shape in [[3], [3, 3], [3, 3, 3]]:
|
||||
for shape in [(3,), (3, 3), (3, 3, 3)]:
|
||||
result = initializer(mx.zeros((3, 3)))
|
||||
self.assertTrue(mx.array_equal(result, mx.eye(3)))
|
||||
self.assertEqual(result.dtype, dtype)
|
||||
@@ -56,7 +56,7 @@ class TestInit(mlx_tests.MLXTestCase):
|
||||
def test_glorot_normal(self):
|
||||
for dtype in [mx.float32, mx.float16]:
|
||||
initializer = init.glorot_normal(dtype)
|
||||
for shape in [[3, 3], [3, 3, 3]]:
|
||||
for shape in [(3, 3), (3, 3, 3)]:
|
||||
result = initializer(mx.array(np.empty(shape)))
|
||||
with self.subTest(shape=shape):
|
||||
self.assertEqual(result.shape, shape)
|
||||
@@ -65,7 +65,7 @@ class TestInit(mlx_tests.MLXTestCase):
|
||||
def test_glorot_uniform(self):
|
||||
for dtype in [mx.float32, mx.float16]:
|
||||
initializer = init.glorot_uniform(dtype)
|
||||
for shape in [[3, 3], [3, 3, 3]]:
|
||||
for shape in [(3, 3), (3, 3, 3)]:
|
||||
result = initializer(mx.array(np.empty(shape)))
|
||||
with self.subTest(shape=shape):
|
||||
self.assertEqual(result.shape, shape)
|
||||
@@ -74,7 +74,7 @@ class TestInit(mlx_tests.MLXTestCase):
|
||||
def test_he_normal(self):
|
||||
for dtype in [mx.float32, mx.float16]:
|
||||
initializer = init.he_normal(dtype)
|
||||
for shape in [[3, 3], [3, 3, 3]]:
|
||||
for shape in [(3, 3), (3, 3, 3)]:
|
||||
result = initializer(mx.array(np.empty(shape)))
|
||||
with self.subTest(shape=shape):
|
||||
self.assertEqual(result.shape, shape)
|
||||
@@ -83,7 +83,7 @@ class TestInit(mlx_tests.MLXTestCase):
|
||||
def test_he_uniform(self):
|
||||
for dtype in [mx.float32, mx.float16]:
|
||||
initializer = init.he_uniform(dtype)
|
||||
for shape in [[3, 3], [3, 3, 3]]:
|
||||
for shape in [(3, 3), (3, 3, 3)]:
|
||||
result = initializer(mx.array(np.empty(shape)))
|
||||
with self.subTest(shape=shape):
|
||||
self.assertEqual(result.shape, shape)
|
||||
|
Reference in New Issue
Block a user