Implementation of mlx.random.multivariate_normal (#502) (#877)

* Implementation of mlx.random.multivariate_normal (#502)

* Update python/src/random.cpp

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>

* Update python/src/random.cpp

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>

* Update python/src/random.cpp

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>

* Updated typo in docstring

* Restricted multivariate_normal to  float32

* Generic mean and variance shapes

* Review edits

* Update mlx/random.cpp

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>

* Update python/src/random.cpp

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>

* Update python/src/random.cpp

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>

* Update python/src/random.cpp

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>

* Test for ndim of mean and cov

* nits

* smaller size for test

* fix broadcasted sampling

---------

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>
Co-authored-by: Awni Hannun <awni@apple.com>
This commit is contained in:
Luca Arnaboldi
2024-04-09 22:50:12 +02:00
committed by GitHub
parent a1a31eed27
commit fffe072028
6 changed files with 270 additions and 1 deletions

View File

@@ -1,5 +1,6 @@
# Copyright © 2023 Apple Inc.
import math
import unittest
import mlx.core as mx
@@ -101,6 +102,96 @@ class TestRandom(mlx_tests.MLXTestCase):
a = abs(mx.random.normal(shape=(10000,), loc=0, scale=1, dtype=hp))
self.assertTrue(mx.all(a < mx.inf))
def test_multivariate_normal(self):
key = mx.random.key(0)
mean = mx.array([0, 0])
cov = mx.array([[1, 0], [0, 1]])
a = mx.random.multivariate_normal(mean, cov, key=key, stream=mx.cpu)
self.assertEqual(a.shape, (2,))
## Check dtypes
for t in [mx.float32]:
a = mx.random.multivariate_normal(
mean, cov, dtype=t, key=key, stream=mx.cpu
)
self.assertEqual(a.dtype, t)
for t in [
mx.int8,
mx.int32,
mx.int64,
mx.uint8,
mx.uint32,
mx.uint64,
mx.float16,
mx.bfloat16,
]:
with self.assertRaises(ValueError):
mx.random.multivariate_normal(
mean, cov, dtype=t, key=key, stream=mx.cpu
)
## Check incompatible shapes
with self.assertRaises(ValueError):
mean = mx.zeros((2, 2))
cov = mx.zeros((2, 2))
mx.random.multivariate_normal(mean, cov, shape=(3,), key=key, stream=mx.cpu)
with self.assertRaises(ValueError):
mean = mx.zeros((2))
cov = mx.zeros((2, 2, 2))
mx.random.multivariate_normal(mean, cov, shape=(3,), key=key, stream=mx.cpu)
with self.assertRaises(ValueError):
mean = mx.zeros((3,))
cov = mx.zeros((2, 2))
mx.random.multivariate_normal(mean, cov, key=key, stream=mx.cpu)
with self.assertRaises(ValueError):
mean = mx.zeros((2,))
cov = mx.zeros((2, 3))
mx.random.multivariate_normal(mean, cov, key=key, stream=mx.cpu)
## Different shape of mean and cov
mean = mx.array([[0, 7], [1, 2], [3, 4]])
cov = mx.array([[1, 0.5], [0.5, 1]])
a = mx.random.multivariate_normal(mean, cov, shape=(4, 3), stream=mx.cpu)
self.assertEqual(a.shape, (4, 3, 2))
## Check correcteness of the mean and covariance
n_test = int(1e5)
def check_jointly_gaussian(data, mean, cov):
empirical_mean = mx.mean(data, axis=0)
empirical_cov = (
(data - empirical_mean).T @ (data - empirical_mean) / data.shape[0]
)
N = data.shape[1]
self.assertTrue(
mx.allclose(
empirical_mean, mean, rtol=0.0, atol=10 * N**2 / math.sqrt(n_test)
)
)
self.assertTrue(
mx.allclose(
empirical_cov, cov, rtol=0.0, atol=10 * N**2 / math.sqrt(n_test)
)
)
mean = mx.array([4.0, 7.0])
cov = mx.array([[2, 0.5], [0.5, 1]])
data = mx.random.multivariate_normal(
mean, cov, shape=(n_test,), key=key, stream=mx.cpu
)
check_jointly_gaussian(data, mean, cov)
mean = mx.arange(3)
cov = mx.array([[1, -1, 0.5], [-1, 1, -0.5], [0.5, -0.5, 1]])
data = mx.random.multivariate_normal(
mean, cov, shape=(n_test,), key=key, stream=mx.cpu
)
check_jointly_gaussian(data, mean, cov)
def test_randint(self):
a = mx.random.randint(0, 1, [])
self.assertEqual(a.shape, ())