[Feature] Added Sparse Initialization (#1498)

Co-authored-by: Saanidhyavats <saanidhyavats@gmail.com>
This commit is contained in:
Venkata Naga Aditya Datta Chivukula
2024-10-24 12:31:24 -07:00
committed by GitHub
parent 3d17077187
commit 430ffef58a
2 changed files with 64 additions and 0 deletions

View File

@@ -348,3 +348,50 @@ def he_uniform(
return mx.random.uniform(-limit, limit, a.shape, dtype=dtype)
return initializer
def sparse(
    sparsity: float,
    mean: float = 0.0,
    std: float = 1.0,
    dtype: mx.Dtype = mx.float32,
) -> Callable[[mx.array], mx.array]:
    r"""An initializer that returns a sparse matrix.

    Every entry is first sampled from a normal distribution with the given
    ``mean`` and ``std``. Then, independently for each row,
    ``ceil(sparsity * num_columns)`` randomly chosen entries are set to zero.

    Args:
        sparsity (float): The fraction of elements in each row to be set to
            zero. Must be in the range ``[0, 1]``.
        mean (float, optional): Mean of the normal distribution. Default:
            ``0.0``.
        std (float, optional): Standard deviation of the normal distribution.
            Default: ``1.0``.
        dtype (Dtype, optional): The data type of the array. Default:
            ``float32``.

    Returns:
        Callable[[array], array]: An initializer that returns an array with the
        same shape as the input, filled with samples from a normal distribution
        in which the selected fraction of each row has been zeroed.

    Raises:
        ValueError: If ``sparsity`` is not in ``[0, 1]``. The returned
            initializer also raises ``ValueError`` when its input is not
            2-dimensional.

    Example:

        >>> init_fn = nn.init.sparse(sparsity=0.5)
        >>> init_fn(mx.zeros((2, 2)))  # values are random; one zero per row
        array([[-1.91187, 0],
               [0, -0.117483]], dtype=float32)
    """
    # Reject out-of-range fractions (including NaN) up front: a negative value
    # would otherwise become a negative slice bound below and silently zero
    # the wrong number of entries.
    if not (0.0 <= sparsity <= 1.0):
        raise ValueError(
            f"The sparsity must be in the range [0, 1] but got {sparsity}"
        )

    def initializer(a: mx.array) -> mx.array:
        if a.ndim != 2:
            raise ValueError("Only tensors with 2 dimensions are supported")

        rows, cols = a.shape
        num_zeros = int(mx.ceil(sparsity * cols))

        # Sorting uniform noise along axis 1 yields an independent random
        # permutation of the column indices for every row.
        order = mx.argsort(mx.random.uniform(shape=a.shape), axis=1)
        a = mx.random.normal(shape=a.shape, scale=std, loc=mean, dtype=dtype)

        # The (rows, 1) row index broadcasts against the (rows, num_zeros)
        # column indices, zeroing the first ``num_zeros`` permuted columns of
        # each row.
        a[mx.arange(rows).reshape(rows, 1), order[:, :num_zeros]] = 0

        return a

    return initializer