mirror of
				https://github.com/ml-explore/mlx.git
				synced 2025-10-31 07:58:14 +08:00 
			
		
		
		
	[Feature] Added Sparse Initialization (#1498)
Co-authored-by: Saanidhyavats <saanidhyavats@gmail.com>
This commit is contained in:
		 Venkata Naga Aditya Datta Chivukula
					Venkata Naga Aditya Datta Chivukula
				
			
				
					committed by
					
						 GitHub
						GitHub
					
				
			
			
				
	
			
			
			 GitHub
						GitHub
					
				
			
						parent
						
							3d17077187
						
					
				
				
					commit
					430ffef58a
				
			| @@ -348,3 +348,50 @@ def he_uniform( | ||||
|         return mx.random.uniform(-limit, limit, a.shape, dtype=dtype) | ||||
|  | ||||
|     return initializer | ||||
|  | ||||
|  | ||||
def sparse(
    sparsity: float,
    mean: float = 0.0,
    std: float = 1.0,
    dtype: mx.Dtype = mx.float32,
) -> Callable[[mx.array], mx.array]:
    r"""An initializer that returns a sparse matrix.

    The entries are sampled from a normal distribution and then, in every
    row, a randomly chosen subset of the entries is overwritten with zero.

    Args:
        sparsity (float): The fraction of elements in each row to be set to
          zero. Must be in the range ``[0, 1]``. The number of zeroed
          elements per row is ``ceil(sparsity * num_columns)``.
        mean (float, optional): Mean of the normal distribution. Default:
          ``0.0``.
        std (float, optional): Standard deviation of the normal distribution.
          Default: ``1.0``.
        dtype (Dtype, optional): The data type of the array. Default:
          ``float32``.

    Returns:
        Callable[[array], array]: An initializer that returns an array with the
        same shape as the input, filled with samples from a normal distribution
        in which a fraction ``sparsity`` of each row has been set to zero.

    Raises:
        ValueError: If ``sparsity`` is outside ``[0, 1]``, or if the returned
          initializer is called with an array that is not two-dimensional.

    Example:

        >>> init_fn = nn.init.sparse(sparsity=0.5)
        >>> init_fn(mx.zeros((2, 2)))
        array([[-1.91187, 0],
               [0, -0.117483]], dtype=float32)
    """
    # A negative fraction would turn ``order[:, :num_zeros]`` below into a
    # negative-stop slice and silently zero the wrong number of entries, so
    # reject out-of-range values up front (this also rejects NaN).
    if not 0.0 <= sparsity <= 1.0:
        raise ValueError(f"sparsity must be in [0, 1], but got {sparsity}")

    def initializer(a: mx.array) -> mx.array:
        if a.ndim != 2:
            raise ValueError("Only tensors with 2 dimensions are supported")

        rows, cols = a.shape
        num_zeros = int(mx.ceil(sparsity * cols))

        # Sorting uniform noise along axis 1 yields an independent random
        # ordering of the column indices for every row.
        order = mx.argsort(mx.random.uniform(shape=a.shape), axis=1)

        # Only the shape of the input is used; its values and dtype are
        # replaced by fresh normal samples of the requested ``dtype``.
        a = mx.random.normal(shape=a.shape, scale=std, loc=mean, dtype=dtype)

        # The (rows, 1) row index broadcasts against the (rows, num_zeros)
        # column indices, zeroing ``num_zeros`` entries in each row.
        a[mx.arange(rows).reshape(rows, 1), order[:, :num_zeros]] = 0

        return a

    return initializer
|   | ||||
| @@ -89,6 +89,23 @@ class TestInit(mlx_tests.MLXTestCase): | ||||
|                     self.assertEqual(result.shape, shape) | ||||
|                     self.assertEqual(result.dtype, dtype) | ||||
|  | ||||
|     def test_sparse(self): | ||||
|         mean = 0.0 | ||||
|         std = 1.0 | ||||
|         sparsity = 0.5 | ||||
|         for dtype in [mx.float32, mx.float16]: | ||||
|             initializer = init.sparse(sparsity, mean, std, dtype=dtype) | ||||
|             for shape in [(3, 2), (2, 2), (4, 3)]: | ||||
|                 result = initializer(mx.array(np.empty(shape))) | ||||
|                 with self.subTest(shape=shape): | ||||
|                     self.assertEqual(result.shape, shape) | ||||
|                     self.assertEqual(result.dtype, dtype) | ||||
|                     self.assertEqual( | ||||
|                         (mx.sum(result == 0) >= 0.5 * shape[0] * shape[1]), True | ||||
|                     ) | ||||
|             with self.assertRaises(ValueError): | ||||
|                 result = initializer(mx.zeros((1,))) | ||||
|  | ||||
|  | ||||
| if __name__ == "__main__": | ||||
|     unittest.main() | ||||
|   | ||||
		Reference in New Issue
	
	Block a user