mirror of
				https://github.com/ml-explore/mlx.git
				synced 2025-11-01 00:28:11 +08:00 
			
		
		
		
	Add loc and scale to random.normal (#638)
* Add loc and scale to random.normal * Add tests for loc and scale for random.normal * Run pre-commit hooks * Fix code review
This commit is contained in:
		| @@ -80,6 +80,20 @@ class TestRandom(mlx_tests.MLXTestCase): | ||||
|             a = mx.random.normal(dtype=t) | ||||
|             self.assertEqual(a.dtype, t) | ||||
|  | ||||
|         # Generate with a given mean and standard deviation | ||||
|         loc = 1.0 | ||||
|         scale = 2.0 | ||||
|  | ||||
|         a = mx.random.normal(shape=(3, 2), loc=loc, scale=scale, key=key) | ||||
|         b = scale * mx.random.normal(shape=(3, 2), key=key) + loc | ||||
|         self.assertTrue(mx.allclose(a, b)) | ||||
|  | ||||
|         a = mx.random.normal( | ||||
|             shape=(3, 2), loc=loc, scale=scale, dtype=mx.float16, key=key | ||||
|         ) | ||||
|         b = scale * mx.random.normal(shape=(3, 2), dtype=mx.float16, key=key) + loc | ||||
|         self.assertTrue(mx.allclose(a, b)) | ||||
|  | ||||
|         self.assertEqual(mx.random.normal().dtype, mx.random.normal(dtype=None).dtype) | ||||
|  | ||||
|     def test_randint(self): | ||||
|   | ||||
		Reference in New Issue
	
	Block a user
	 Noah Farr
					Noah Farr