Implementation of mlx.random.multivariate_normal (#502) (#877)

* Implementation of mlx.random.multivariate_normal (#502)

* Update python/src/random.cpp

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>

* Update python/src/random.cpp

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>

* Update python/src/random.cpp

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>

* Updated typo in docstring

* Restricted multivariate_normal to  float32

* Generic mean and variance shapes

* Review edits

* Update mlx/random.cpp

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>

* Update python/src/random.cpp

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>

* Update python/src/random.cpp

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>

* Update python/src/random.cpp

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>

* Test for ndim of mean and cov

* nits

* smaller size for test

* fix broadcasted sampling

---------

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>
Co-authored-by: Awni Hannun <awni@apple.com>
This commit is contained in:
Luca Arnaboldi
2024-04-09 22:50:12 +02:00
committed by GitHub
parent a1a31eed27
commit fffe072028
6 changed files with 270 additions and 1 deletions

View File

@@ -179,6 +179,48 @@ void init_random(nb::module_& parent_module) {
array: The output array of random values.
)pbdoc");
m.def(
"multivariate_normal",
[](const array& mean,
const array& cov,
const std::vector<int>& shape,
std::optional<Dtype> type,
const std::optional<array>& key_,
StreamOrDevice s) {
auto key = key_ ? key_.value() : default_key().next();
return multivariate_normal(
mean, cov, shape, type.value_or(float32), key, s);
},
"mean"_a,
"cov"_a,
"shape"_a = std::vector<int>{},
"dtype"_a.none() = float32,
"key"_a = nb::none(),
"stream"_a = nb::none(),
nb::sig(
"def multivariate_normal(mean: array, cov: array, shape: Sequence[int] = [], dtype: Optional[Dtype] = float32, key: Optional[array] = None, stream: Union[None, Stream, Device] = None) -> array"),
R"pbdoc(
Generate jointly-normal random samples given a mean and covariance.
The matrix ``cov`` must be positive semi-definite. The behavior is
undefined if it is not. The only supported ``dtype`` is ``float32``.
Args:
mean (array): array of shape ``(..., n)``, the mean of the
distribution.
cov (array): array of shape ``(..., n, n)``, the covariance
matrix of the distribution. The batch shape ``...`` must be
broadcast-compatible with that of ``mean``.
shape (list(int), optional): The output shape must be
broadcast-compatible with ``mean.shape[:-1]`` and ``cov.shape[:-2]``.
If empty, the result shape is determined by broadcasting the batch
shapes of ``mean`` and ``cov``. Default: ``[]``.
dtype (Dtype, optional): The output type. Default: ``float32``.
key (array, optional): A PRNG key. Default: ``None``.
Returns:
array: The output array of random values.
)pbdoc");
m.def(
"randint",
[](const ScalarOrArray& low,
const ScalarOrArray& high,

View File

@@ -1,5 +1,6 @@
# Copyright © 2023 Apple Inc.
import math
import unittest
import mlx.core as mx
@@ -101,6 +102,96 @@ class TestRandom(mlx_tests.MLXTestCase):
a = abs(mx.random.normal(shape=(10000,), loc=0, scale=1, dtype=hp))
self.assertTrue(mx.all(a < mx.inf))
def test_multivariate_normal(self):
key = mx.random.key(0)
mean = mx.array([0, 0])
cov = mx.array([[1, 0], [0, 1]])
a = mx.random.multivariate_normal(mean, cov, key=key, stream=mx.cpu)
self.assertEqual(a.shape, (2,))
## Check dtypes
for t in [mx.float32]:
a = mx.random.multivariate_normal(
mean, cov, dtype=t, key=key, stream=mx.cpu
)
self.assertEqual(a.dtype, t)
for t in [
mx.int8,
mx.int32,
mx.int64,
mx.uint8,
mx.uint32,
mx.uint64,
mx.float16,
mx.bfloat16,
]:
with self.assertRaises(ValueError):
mx.random.multivariate_normal(
mean, cov, dtype=t, key=key, stream=mx.cpu
)
## Check incompatible shapes
with self.assertRaises(ValueError):
mean = mx.zeros((2, 2))
cov = mx.zeros((2, 2))
mx.random.multivariate_normal(mean, cov, shape=(3,), key=key, stream=mx.cpu)
with self.assertRaises(ValueError):
mean = mx.zeros((2))
cov = mx.zeros((2, 2, 2))
mx.random.multivariate_normal(mean, cov, shape=(3,), key=key, stream=mx.cpu)
with self.assertRaises(ValueError):
mean = mx.zeros((3,))
cov = mx.zeros((2, 2))
mx.random.multivariate_normal(mean, cov, key=key, stream=mx.cpu)
with self.assertRaises(ValueError):
mean = mx.zeros((2,))
cov = mx.zeros((2, 3))
mx.random.multivariate_normal(mean, cov, key=key, stream=mx.cpu)
## Different shape of mean and cov
mean = mx.array([[0, 7], [1, 2], [3, 4]])
cov = mx.array([[1, 0.5], [0.5, 1]])
a = mx.random.multivariate_normal(mean, cov, shape=(4, 3), stream=mx.cpu)
self.assertEqual(a.shape, (4, 3, 2))
## Check correcteness of the mean and covariance
n_test = int(1e5)
def check_jointly_gaussian(data, mean, cov):
empirical_mean = mx.mean(data, axis=0)
empirical_cov = (
(data - empirical_mean).T @ (data - empirical_mean) / data.shape[0]
)
N = data.shape[1]
self.assertTrue(
mx.allclose(
empirical_mean, mean, rtol=0.0, atol=10 * N**2 / math.sqrt(n_test)
)
)
self.assertTrue(
mx.allclose(
empirical_cov, cov, rtol=0.0, atol=10 * N**2 / math.sqrt(n_test)
)
)
mean = mx.array([4.0, 7.0])
cov = mx.array([[2, 0.5], [0.5, 1]])
data = mx.random.multivariate_normal(
mean, cov, shape=(n_test,), key=key, stream=mx.cpu
)
check_jointly_gaussian(data, mean, cov)
mean = mx.arange(3)
cov = mx.array([[1, -1, 0.5], [-1, 1, -0.5], [0.5, -0.5, 1]])
data = mx.random.multivariate_normal(
mean, cov, shape=(n_test,), key=key, stream=mx.cpu
)
check_jointly_gaussian(data, mean, cov)
def test_randint(self):
a = mx.random.randint(0, 1, [])
self.assertEqual(a.shape, ())