mirror of
				https://github.com/ml-explore/mlx.git
				synced 2025-10-31 07:58:14 +08:00 
			
		
		
		
	Common neural network initializers nn.initializers (#456)
				
					
				
			* initial commit: constant, normal, uniform * identity, glorot and he initializers * docstrings * rm file * nits * nits * nits * testing suite * docs * nits in docs * more docs * remove unused template * rename packakge to nn.innit * docs, receptive field * more docs --------- Co-authored-by: Awni Hannun <awni@apple.com>
This commit is contained in:
		
							
								
								
									
										94
									
								
								python/tests/test_init.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										94
									
								
								python/tests/test_init.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,94 @@ | ||||
| # Copyright © 2023 Apple Inc. | ||||
| import unittest | ||||
|  | ||||
| import mlx.core as mx | ||||
| import mlx.nn.init as init | ||||
| import mlx_tests | ||||
| import numpy as np | ||||
|  | ||||
|  | ||||
| class TestInit(mlx_tests.MLXTestCase): | ||||
|     def test_constant(self): | ||||
|         value = 5.0 | ||||
|  | ||||
|         for dtype in [mx.float32, mx.float16]: | ||||
|             initializer = init.constant(value, dtype) | ||||
|             for shape in [[3], [3, 3], [3, 3, 3]]: | ||||
|                 result = initializer(mx.array(mx.zeros(shape))) | ||||
|                 with self.subTest(shape=shape): | ||||
|                     self.assertEqual(result.shape, shape) | ||||
|                     self.assertEqual(result.dtype, dtype) | ||||
|  | ||||
|     def test_normal(self): | ||||
|         mean = 0.0 | ||||
|         std = 1.0 | ||||
|         for dtype in [mx.float32, mx.float16]: | ||||
|             initializer = init.normal(mean, std, dtype=dtype) | ||||
|             for shape in [[3], [3, 3], [3, 3, 3]]: | ||||
|                 result = initializer(mx.array(np.empty(shape))) | ||||
|                 with self.subTest(shape=shape): | ||||
|                     self.assertEqual(result.shape, shape) | ||||
|                     self.assertEqual(result.dtype, dtype) | ||||
|  | ||||
|     def test_uniform(self): | ||||
|         low = -1.0 | ||||
|         high = 1.0 | ||||
|  | ||||
|         for dtype in [mx.float32, mx.float16]: | ||||
|             initializer = init.uniform(low, high, dtype) | ||||
|             for shape in [[3], [3, 3], [3, 3, 3]]: | ||||
|                 result = initializer(mx.array(np.empty(shape))) | ||||
|                 with self.subTest(shape=shape): | ||||
|                     self.assertEqual(result.shape, shape) | ||||
|                     self.assertEqual(result.dtype, dtype) | ||||
|                     self.assertTrue(mx.all(result >= low) and mx.all(result <= high)) | ||||
|  | ||||
|     def test_identity(self): | ||||
|         for dtype in [mx.float32, mx.float16]: | ||||
|             initializer = init.identity(dtype) | ||||
|             for shape in [[3], [3, 3], [3, 3, 3]]: | ||||
|                 result = initializer(mx.zeros((3, 3))) | ||||
|                 self.assertTrue(mx.array_equal(result, mx.eye(3))) | ||||
|                 self.assertEqual(result.dtype, dtype) | ||||
|                 with self.assertRaises(ValueError): | ||||
|                     result = initializer(mx.zeros((3, 2))) | ||||
|  | ||||
|     def test_glorot_normal(self): | ||||
|         for dtype in [mx.float32, mx.float16]: | ||||
|             initializer = init.glorot_normal(dtype) | ||||
|             for shape in [[3, 3], [3, 3, 3]]: | ||||
|                 result = initializer(mx.array(np.empty(shape))) | ||||
|                 with self.subTest(shape=shape): | ||||
|                     self.assertEqual(result.shape, shape) | ||||
|                     self.assertEqual(result.dtype, dtype) | ||||
|  | ||||
|     def test_glorot_uniform(self): | ||||
|         for dtype in [mx.float32, mx.float16]: | ||||
|             initializer = init.glorot_uniform(dtype) | ||||
|             for shape in [[3, 3], [3, 3, 3]]: | ||||
|                 result = initializer(mx.array(np.empty(shape))) | ||||
|                 with self.subTest(shape=shape): | ||||
|                     self.assertEqual(result.shape, shape) | ||||
|                     self.assertEqual(result.dtype, dtype) | ||||
|  | ||||
|     def test_he_normal(self): | ||||
|         for dtype in [mx.float32, mx.float16]: | ||||
|             initializer = init.he_normal(dtype) | ||||
|             for shape in [[3, 3], [3, 3, 3]]: | ||||
|                 result = initializer(mx.array(np.empty(shape))) | ||||
|                 with self.subTest(shape=shape): | ||||
|                     self.assertEqual(result.shape, shape) | ||||
|                     self.assertEqual(result.dtype, dtype) | ||||
|  | ||||
|     def test_he_uniform(self): | ||||
|         for dtype in [mx.float32, mx.float16]: | ||||
|             initializer = init.he_uniform(dtype) | ||||
|             for shape in [[3, 3], [3, 3, 3]]: | ||||
|                 result = initializer(mx.array(np.empty(shape))) | ||||
|                 with self.subTest(shape=shape): | ||||
|                     self.assertEqual(result.shape, shape) | ||||
|                     self.assertEqual(result.dtype, dtype) | ||||
|  | ||||
|  | ||||
| if __name__ == "__main__": | ||||
|     unittest.main() | ||||
		Reference in New Issue
	
	Block a user
	 LeonEricsson
					LeonEricsson