[Feature] Added Sparse Initialization (#1498)

Co-authored-by: Saanidhyavats <saanidhyavats@gmail.com>
This commit is contained in:
Venkata Naga Aditya Datta Chivukula 2024-10-24 12:31:24 -07:00 committed by GitHub
parent 3d17077187
commit 430ffef58a
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 64 additions and 0 deletions

View File

@ -348,3 +348,50 @@ def he_uniform(
return mx.random.uniform(-limit, limit, a.shape, dtype=dtype) return mx.random.uniform(-limit, limit, a.shape, dtype=dtype)
return initializer return initializer
def sparse(
sparsity: float,
mean: float = 0.0,
std: float = 1.0,
dtype: mx.Dtype = mx.float32,
) -> Callable[[mx.array], mx.array]:
r"""An initializer that returns a sparse matrix.
Args:
sparsity (float): The fraction of elements in each column to be set to
zero.
mean (float, optional): Mean of the normal distribution. Default:
``0.0``.
std (float, optional): Standard deviation of the normal distribution.
Default: ``1.0``.
dtype (Dtype, optional): The data type of the array. Default:
``float32``.
Returns:
Callable[[array], array]: An initializer that returns an array with the
same shape as the input, filled with samples from a normal distribution.
Example:
>>> init_fn = nn.init.sparse(sparsity=0.5)
>>> init_fn(mx.zeros((2, 2)))
array([[-1.91187, -0.117483],
[0, 0]], dtype=float32)
"""
def initializer(a: mx.array) -> mx.array:
if a.ndim != 2:
raise ValueError("Only tensors with 2 dimensions are supported")
rows, cols = a.shape
num_zeros = int(mx.ceil(sparsity * cols))
order = mx.argsort(mx.random.uniform(shape=a.shape), axis=1)
a = mx.random.normal(shape=a.shape, scale=std, loc=mean, dtype=dtype)
a[mx.arange(rows).reshape(rows, 1), order[:, :num_zeros]] = 0
return a
return initializer

View File

@ -89,6 +89,23 @@ class TestInit(mlx_tests.MLXTestCase):
self.assertEqual(result.shape, shape) self.assertEqual(result.shape, shape)
self.assertEqual(result.dtype, dtype) self.assertEqual(result.dtype, dtype)
def test_sparse(self):
mean = 0.0
std = 1.0
sparsity = 0.5
for dtype in [mx.float32, mx.float16]:
initializer = init.sparse(sparsity, mean, std, dtype=dtype)
for shape in [(3, 2), (2, 2), (4, 3)]:
result = initializer(mx.array(np.empty(shape)))
with self.subTest(shape=shape):
self.assertEqual(result.shape, shape)
self.assertEqual(result.dtype, dtype)
self.assertEqual(
(mx.sum(result == 0) >= 0.5 * shape[0] * shape[1]), True
)
with self.assertRaises(ValueError):
result = initializer(mx.zeros((1,)))
if __name__ == "__main__": if __name__ == "__main__":
unittest.main() unittest.main()