Make shape a tuple (#591)

* shape tuple

* also remove simplify from docs

* rebase
This commit is contained in:
Awni Hannun
2024-01-30 13:11:01 -08:00
committed by GitHub
parent d3a9005454
commit 09b9275027
13 changed files with 141 additions and 140 deletions

View File

@@ -13,7 +13,7 @@ class TestInit(mlx_tests.MLXTestCase):
for dtype in [mx.float32, mx.float16]:
initializer = init.constant(value, dtype)
for shape in [[3], [3, 3], [3, 3, 3]]:
for shape in [(3,), (3, 3), (3, 3, 3)]:
result = initializer(mx.array(mx.zeros(shape)))
with self.subTest(shape=shape):
self.assertEqual(result.shape, shape)
@@ -24,7 +24,7 @@ class TestInit(mlx_tests.MLXTestCase):
std = 1.0
for dtype in [mx.float32, mx.float16]:
initializer = init.normal(mean, std, dtype=dtype)
for shape in [[3], [3, 3], [3, 3, 3]]:
for shape in [(3,), (3, 3), (3, 3, 3)]:
result = initializer(mx.array(np.empty(shape)))
with self.subTest(shape=shape):
self.assertEqual(result.shape, shape)
@@ -36,7 +36,7 @@ class TestInit(mlx_tests.MLXTestCase):
for dtype in [mx.float32, mx.float16]:
initializer = init.uniform(low, high, dtype)
for shape in [[3], [3, 3], [3, 3, 3]]:
for shape in [(3,), (3, 3), (3, 3, 3)]:
result = initializer(mx.array(np.empty(shape)))
with self.subTest(shape=shape):
self.assertEqual(result.shape, shape)
@@ -46,7 +46,7 @@ class TestInit(mlx_tests.MLXTestCase):
def test_identity(self):
for dtype in [mx.float32, mx.float16]:
initializer = init.identity(dtype)
for shape in [[3], [3, 3], [3, 3, 3]]:
for shape in [(3,), (3, 3), (3, 3, 3)]:
result = initializer(mx.zeros((3, 3)))
self.assertTrue(mx.array_equal(result, mx.eye(3)))
self.assertEqual(result.dtype, dtype)
@@ -56,7 +56,7 @@ class TestInit(mlx_tests.MLXTestCase):
def test_glorot_normal(self):
for dtype in [mx.float32, mx.float16]:
initializer = init.glorot_normal(dtype)
for shape in [[3, 3], [3, 3, 3]]:
for shape in [(3, 3), (3, 3, 3)]:
result = initializer(mx.array(np.empty(shape)))
with self.subTest(shape=shape):
self.assertEqual(result.shape, shape)
@@ -65,7 +65,7 @@ class TestInit(mlx_tests.MLXTestCase):
def test_glorot_uniform(self):
for dtype in [mx.float32, mx.float16]:
initializer = init.glorot_uniform(dtype)
for shape in [[3, 3], [3, 3, 3]]:
for shape in [(3, 3), (3, 3, 3)]:
result = initializer(mx.array(np.empty(shape)))
with self.subTest(shape=shape):
self.assertEqual(result.shape, shape)
@@ -74,7 +74,7 @@ class TestInit(mlx_tests.MLXTestCase):
def test_he_normal(self):
for dtype in [mx.float32, mx.float16]:
initializer = init.he_normal(dtype)
for shape in [[3, 3], [3, 3, 3]]:
for shape in [(3, 3), (3, 3, 3)]:
result = initializer(mx.array(np.empty(shape)))
with self.subTest(shape=shape):
self.assertEqual(result.shape, shape)
@@ -83,7 +83,7 @@ class TestInit(mlx_tests.MLXTestCase):
def test_he_uniform(self):
for dtype in [mx.float32, mx.float16]:
initializer = init.he_uniform(dtype)
for shape in [[3, 3], [3, 3, 3]]:
for shape in [(3, 3), (3, 3, 3)]:
result = initializer(mx.array(np.empty(shape)))
with self.subTest(shape=shape):
self.assertEqual(result.shape, shape)