[Feature] Added Sparse Initialization (#1498)

Co-authored-by: Saanidhyavats <saanidhyavats@gmail.com>
This commit is contained in:
Venkata Naga Aditya Datta Chivukula
2024-10-24 12:31:24 -07:00
committed by GitHub
parent 3d17077187
commit 430ffef58a
2 changed files with 64 additions and 0 deletions

View File

@@ -89,6 +89,23 @@ class TestInit(mlx_tests.MLXTestCase):
self.assertEqual(result.shape, shape)
self.assertEqual(result.dtype, dtype)
def test_sparse(self):
mean = 0.0
std = 1.0
sparsity = 0.5
for dtype in [mx.float32, mx.float16]:
initializer = init.sparse(sparsity, mean, std, dtype=dtype)
for shape in [(3, 2), (2, 2), (4, 3)]:
result = initializer(mx.array(np.empty(shape)))
with self.subTest(shape=shape):
self.assertEqual(result.shape, shape)
self.assertEqual(result.dtype, dtype)
self.assertEqual(
(mx.sum(result == 0) >= 0.5 * shape[0] * shape[1]), True
)
with self.assertRaises(ValueError):
result = initializer(mx.zeros((1,)))
if __name__ == "__main__":
unittest.main()