mirror of
				https://github.com/ml-explore/mlx.git
				synced 2025-11-01 00:28:11 +08:00 
			
		
		
		
	Implement sampling from laplace distribution. (#1279)
This commit is contained in:
		| @@ -64,43 +64,50 @@ class TestRandom(mlx_tests.MLXTestCase): | ||||
|  | ||||
|         self.assertEqual(mx.random.uniform().dtype, mx.random.uniform(dtype=None).dtype) | ||||
|  | ||||
|     def test_normal(self): | ||||
|         key = mx.random.key(0) | ||||
|         a = mx.random.normal(key=key) | ||||
|         self.assertEqual(a.shape, ()) | ||||
|         self.assertEqual(a.dtype, mx.float32) | ||||
|     def test_normal_and_laplace(self): | ||||
|         # Same tests for normal and laplace. | ||||
|         for distribution_sampler in [mx.random.normal, mx.random.laplace]: | ||||
|             key = mx.random.key(0) | ||||
|             a = distribution_sampler(key=key) | ||||
|             self.assertEqual(a.shape, ()) | ||||
|             self.assertEqual(a.dtype, mx.float32) | ||||
|  | ||||
|         b = mx.random.normal(key=key) | ||||
|         self.assertEqual(a.item(), b.item()) | ||||
|             b = distribution_sampler(key=key) | ||||
|             self.assertEqual(a.item(), b.item()) | ||||
|  | ||||
|         a = mx.random.normal(shape=(2, 3)) | ||||
|         self.assertEqual(a.shape, (2, 3)) | ||||
|             a = distribution_sampler(shape=(2, 3)) | ||||
|             self.assertEqual(a.shape, (2, 3)) | ||||
|  | ||||
|         ## Generate in float16 or bfloat16 | ||||
|         for t in [mx.float16, mx.bfloat16]: | ||||
|             a = mx.random.normal(dtype=t) | ||||
|             self.assertEqual(a.dtype, t) | ||||
|             ## Generate in float16 or bfloat16 | ||||
|             for t in [mx.float16, mx.bfloat16]: | ||||
|                 a = distribution_sampler(dtype=t) | ||||
|                 self.assertEqual(a.dtype, t) | ||||
|  | ||||
|         # Generate with a given mean and standard deviation | ||||
|         loc = 1.0 | ||||
|         scale = 2.0 | ||||
|             # Generate with a given mean and standard deviation | ||||
|             loc = 1.0 | ||||
|             scale = 2.0 | ||||
|  | ||||
|         a = mx.random.normal(shape=(3, 2), loc=loc, scale=scale, key=key) | ||||
|         b = scale * mx.random.normal(shape=(3, 2), key=key) + loc | ||||
|         self.assertTrue(mx.allclose(a, b)) | ||||
|             a = distribution_sampler(shape=(3, 2), loc=loc, scale=scale, key=key) | ||||
|             b = scale * distribution_sampler(shape=(3, 2), key=key) + loc | ||||
|             self.assertTrue(mx.allclose(a, b)) | ||||
|  | ||||
|         a = mx.random.normal( | ||||
|             shape=(3, 2), loc=loc, scale=scale, dtype=mx.float16, key=key | ||||
|         ) | ||||
|         b = scale * mx.random.normal(shape=(3, 2), dtype=mx.float16, key=key) + loc | ||||
|         self.assertTrue(mx.allclose(a, b)) | ||||
|             a = distribution_sampler( | ||||
|                 shape=(3, 2), loc=loc, scale=scale, dtype=mx.float16, key=key | ||||
|             ) | ||||
|             b = ( | ||||
|                 scale * distribution_sampler(shape=(3, 2), dtype=mx.float16, key=key) | ||||
|                 + loc | ||||
|             ) | ||||
|             self.assertTrue(mx.allclose(a, b)) | ||||
|  | ||||
|         self.assertEqual(mx.random.normal().dtype, mx.random.normal(dtype=None).dtype) | ||||
|             self.assertEqual( | ||||
|                 distribution_sampler().dtype, distribution_sampler(dtype=None).dtype | ||||
|             ) | ||||
|  | ||||
|         # Test not getting -inf or inf with half precison | ||||
|         for hp in [mx.float16, mx.bfloat16]: | ||||
|             a = abs(mx.random.normal(shape=(10000,), loc=0, scale=1, dtype=hp)) | ||||
|             self.assertTrue(mx.all(a < mx.inf)) | ||||
|             # Test not getting -inf or inf with half precison | ||||
|             for hp in [mx.float16, mx.bfloat16]: | ||||
|                 a = abs(distribution_sampler(shape=(10000,), loc=0, scale=1, dtype=hp)) | ||||
|                 self.assertTrue(mx.all(a < mx.inf)) | ||||
|  | ||||
|     def test_multivariate_normal(self): | ||||
|         key = mx.random.key(0) | ||||
|   | ||||
		Reference in New Issue
	
	Block a user
	 fgranqvist
					fgranqvist