diff --git a/docs/src/python/random.rst b/docs/src/python/random.rst index 706378f9d..d08d5a7df 100644 --- a/docs/src/python/random.rst +++ b/docs/src/python/random.rst @@ -38,6 +38,7 @@ we use a splittable version of Threefry, which is a counter-based PRNG. gumbel key normal + multivariate_normal randint seed split diff --git a/mlx/random.cpp b/mlx/random.cpp index fae2e592c..05405acbb 100644 --- a/mlx/random.cpp +++ b/mlx/random.cpp @@ -1,8 +1,9 @@ -// Copyright © 2023 Apple Inc. +// Copyright © 2023-2024 Apple Inc. #include #include +#include "mlx/linalg.h" #include "mlx/ops.h" #include "mlx/primitives.h" #include "mlx/random.h" @@ -192,6 +193,74 @@ array normal( return samples; } +array multivariate_normal( + const array& mean, + const array& cov, + const std::vector& shape, + Dtype dtype, + const std::optional& key /* = nullopt */, + StreamOrDevice s) { + auto stream = to_stream(s); + + if (dtype != float32) { + throw std::invalid_argument("[multivariate_normal] dtype must be float32."); + } + + if (mean.ndim() < 1) { + throw std::invalid_argument( + "[multivariate_normal] mean must have at least one dimension."); + } + + if (cov.ndim() < 2) { + throw std::invalid_argument( + "[multivariate_normal] cov must have at least two dimensions."); + } + + auto n = mean.shape(-1); + + // Check shape compatibility of mean and cov + if (cov.shape(-1) != cov.shape(-2)) { + throw std::invalid_argument( + "[multivariate_normal] last two dimensions of cov must be equal."); + } + if (n != cov.shape(-1)) { + throw std::invalid_argument( + "[multivariate_normal] mean and cov must have compatible shapes."); + } + + // Compute output shape + std::vector truncated_output_shape; + + auto truncated_mean_shape = + std::vector(mean.shape().begin(), mean.shape().end() - 1); + auto truncated_cov_shape = + std::vector(cov.shape().begin(), cov.shape().end() - 2); + auto output_shape = + broadcast_shapes(truncated_cov_shape, truncated_mean_shape); + output_shape = 
broadcast_shapes(output_shape, shape); + output_shape.push_back(n); + + // Compute the square-root of the covariance matrix, using the SVD + auto covariance = astype(cov, float32, stream); + auto SVD = linalg::svd(covariance, stream); + auto std = astype( + matmul( + multiply( + SVD[0], expand_dims(sqrt(SVD[1], stream), -2, stream), stream), + SVD[2], + stream), + dtype, + stream); + + // Generate the standard normal samples + auto standard_normal = normal(output_shape, dtype, 0.0, 1.0, key, stream); + auto scaled_out = squeeze( + matmul(expand_dims(standard_normal, -2, stream), std, stream), + -2, + stream); + return add(mean, scaled_out, stream); +} + array randint( const array& low, const array& high, diff --git a/mlx/random.h b/mlx/random.h index 1397b32d7..5f6b9e0d6 100644 --- a/mlx/random.h +++ b/mlx/random.h @@ -121,6 +121,15 @@ inline array normal( return normal(shape, float32, 0.0, 1.0, key, s); } +/** Generate samples from a multivariate normal distribution. **/ +array multivariate_normal( + const array& mean, + const array& cov, + const std::vector& shape, + Dtype dtype, + const std::optional& key = std::nullopt, + StreamOrDevice s = {}); + /** Generate integer samples uniformly at random */ array randint( const array& low, diff --git a/python/src/random.cpp b/python/src/random.cpp index dde8469d4..3f082e15d 100644 --- a/python/src/random.cpp +++ b/python/src/random.cpp @@ -179,6 +179,48 @@ void init_random(nb::module_& parent_module) { array: The output array of random values. )pbdoc"); m.def( + "multivariate_normal", + [](const array& mean, + const array& cov, + const std::vector& shape, + std::optional type, + const std::optional& key_, + StreamOrDevice s) { + auto key = key_ ? 
key_.value() : default_key().next(); + return multivariate_normal( + mean, cov, shape, type.value_or(float32), key, s); + }, + "mean"_a, + "cov"_a, + "shape"_a = std::vector{}, + "dtype"_a.none() = float32, + "key"_a = nb::none(), + "stream"_a = nb::none(), + nb::sig( + "def multivariate_normal(mean: array, cov: array, shape: Sequence[int] = [], dtype: Optional[Dtype] = float32, key: Optional[array] = None, stream: Union[None, Stream, Device] = None) -> array"), + R"pbdoc( + Generate jointly-normal random samples given a mean and covariance. + + The matrix ``cov`` must be positive semi-definite. The behavior is + undefined if it is not. The only supported ``dtype`` is ``float32``. + + Args: + mean (array): array of shape ``(..., n)``, the mean of the + distribution. + cov (array): array of shape ``(..., n, n)``, the covariance + matrix of the distribution. The batch shape ``...`` must be + broadcast-compatible with that of ``mean``. + shape (list(int), optional): The output shape must be + broadcast-compatible with ``mean.shape[:-1]`` and ``cov.shape[:-2]``. + If empty, the result shape is determined by broadcasting the batch + shapes of ``mean`` and ``cov``. Default: ``[]``. + dtype (Dtype, optional): The output type. Default: ``float32``. + key (array, optional): A PRNG key. Default: ``None``. + + Returns: + array: The output array of random values. + )pbdoc"); + m.def( "randint", [](const ScalarOrArray& low, const ScalarOrArray& high, diff --git a/python/tests/test_random.py b/python/tests/test_random.py index 7515cf468..4ddef837b 100644 --- a/python/tests/test_random.py +++ b/python/tests/test_random.py @@ -1,5 +1,6 @@ # Copyright © 2023 Apple Inc. 
+import math import unittest import mlx.core as mx @@ -101,6 +102,96 @@ class TestRandom(mlx_tests.MLXTestCase): a = abs(mx.random.normal(shape=(10000,), loc=0, scale=1, dtype=hp)) self.assertTrue(mx.all(a < mx.inf)) + def test_multivariate_normal(self): + key = mx.random.key(0) + mean = mx.array([0, 0]) + cov = mx.array([[1, 0], [0, 1]]) + + a = mx.random.multivariate_normal(mean, cov, key=key, stream=mx.cpu) + self.assertEqual(a.shape, (2,)) + + ## Check dtypes + for t in [mx.float32]: + a = mx.random.multivariate_normal( + mean, cov, dtype=t, key=key, stream=mx.cpu + ) + self.assertEqual(a.dtype, t) + for t in [ + mx.int8, + mx.int32, + mx.int64, + mx.uint8, + mx.uint32, + mx.uint64, + mx.float16, + mx.bfloat16, + ]: + with self.assertRaises(ValueError): + mx.random.multivariate_normal( + mean, cov, dtype=t, key=key, stream=mx.cpu + ) + + ## Check incompatible shapes + with self.assertRaises(ValueError): + mean = mx.zeros((2, 2)) + cov = mx.zeros((2, 2)) + mx.random.multivariate_normal(mean, cov, shape=(3,), key=key, stream=mx.cpu) + + with self.assertRaises(ValueError): + mean = mx.zeros((2)) + cov = mx.zeros((2, 2, 2)) + mx.random.multivariate_normal(mean, cov, shape=(3,), key=key, stream=mx.cpu) + + with self.assertRaises(ValueError): + mean = mx.zeros((3,)) + cov = mx.zeros((2, 2)) + mx.random.multivariate_normal(mean, cov, key=key, stream=mx.cpu) + + with self.assertRaises(ValueError): + mean = mx.zeros((2,)) + cov = mx.zeros((2, 3)) + mx.random.multivariate_normal(mean, cov, key=key, stream=mx.cpu) + + ## Different shape of mean and cov + mean = mx.array([[0, 7], [1, 2], [3, 4]]) + cov = mx.array([[1, 0.5], [0.5, 1]]) + a = mx.random.multivariate_normal(mean, cov, shape=(4, 3), stream=mx.cpu) + self.assertEqual(a.shape, (4, 3, 2)) + + ## Check correctness of the mean and covariance + n_test = int(1e5) + + def check_jointly_gaussian(data, mean, cov): + empirical_mean = mx.mean(data, axis=0) + empirical_cov = ( + (data - empirical_mean).T @ (data - 
empirical_mean) / data.shape[0] + ) + N = data.shape[1] + self.assertTrue( + mx.allclose( + empirical_mean, mean, rtol=0.0, atol=10 * N**2 / math.sqrt(n_test) + ) + ) + self.assertTrue( + mx.allclose( + empirical_cov, cov, rtol=0.0, atol=10 * N**2 / math.sqrt(n_test) + ) + ) + + mean = mx.array([4.0, 7.0]) + cov = mx.array([[2, 0.5], [0.5, 1]]) + data = mx.random.multivariate_normal( + mean, cov, shape=(n_test,), key=key, stream=mx.cpu + ) + check_jointly_gaussian(data, mean, cov) + + mean = mx.arange(3) + cov = mx.array([[1, -1, 0.5], [-1, 1, -0.5], [0.5, -0.5, 1]]) + data = mx.random.multivariate_normal( + mean, cov, shape=(n_test,), key=key, stream=mx.cpu + ) + check_jointly_gaussian(data, mean, cov) + def test_randint(self): a = mx.random.randint(0, 1, []) self.assertEqual(a.shape, ()) diff --git a/tests/random_tests.cpp b/tests/random_tests.cpp index 7ce057319..42259e065 100644 --- a/tests/random_tests.cpp +++ b/tests/random_tests.cpp @@ -420,6 +420,63 @@ TEST_CASE("test random normal") { } } +TEST_CASE("test random multivariate_normal") { + { + auto mean = zeros({3}); + auto cov = eye(3); + auto x = random::multivariate_normal(mean, cov, {1000}, float32); + CHECK_EQ(x.shape(), std::vector({1000, 3})); + CHECK_EQ(x.dtype(), float32); + } + + // Limit case + { + auto mean = array({0, 0}); + auto cov = array({1., -1, -.1, 1.}); + cov = reshape(cov, {2, 2}); + auto x = random::multivariate_normal(mean, cov, {1}, float32); + CHECK_EQ(x.shape(), std::vector({1, 2})); + CHECK_EQ(x.dtype(), float32); + } + + // Check wrong shapes + { + auto mean = zeros({3, 1}); + auto cov = eye(3); + CHECK_THROWS_AS( + random::multivariate_normal( + mean, + cov, + { + 1000, + }, + float32), + std::invalid_argument); + } + { + auto mean = zeros({3}); + auto cov = zeros({1, 2, 3, 3}); + auto x = random::multivariate_normal(mean, cov, {1000, 2}, float32); + CHECK_EQ(x.shape(), std::vector({1000, 2, 3})); + } + { + auto mean = zeros({3}); + auto cov = eye(4); + CHECK_THROWS_AS( + 
random::multivariate_normal(mean, cov, {1000, 3}, float32), + std::invalid_argument); + } + + // Check wrong type + { + auto mean = zeros({3}); + auto cov = eye(3); + CHECK_THROWS_AS( + random::multivariate_normal(mean, cov, {1000, 3}, float16), + std::invalid_argument); + } +} + TEST_CASE("test random randint") { CHECK_THROWS_AS( random::randint(array(3), array(5), {1}, float32), std::invalid_argument);