mirror of
https://github.com/ml-explore/mlx.git
synced 2025-06-24 09:21:16 +08:00
* Implementation of mlx.random.multivariate_normal (#502) * Update python/src/random.cpp Co-authored-by: Awni Hannun <awni.hannun@gmail.com> * Update python/src/random.cpp Co-authored-by: Awni Hannun <awni.hannun@gmail.com> * Update python/src/random.cpp Co-authored-by: Awni Hannun <awni.hannun@gmail.com> * Updated typo in docstring * Restricted multivariate_normal to float32 * Generic mean and variance shapes * Review edits * Update mlx/random.cpp Co-authored-by: Awni Hannun <awni.hannun@gmail.com> * Update python/src/random.cpp Co-authored-by: Awni Hannun <awni.hannun@gmail.com> * Update python/src/random.cpp Co-authored-by: Awni Hannun <awni.hannun@gmail.com> * Update python/src/random.cpp Co-authored-by: Awni Hannun <awni.hannun@gmail.com> * Test for ndim of mean and cov * nits * smaller size for test * fix broadcasted sampling --------- Co-authored-by: Awni Hannun <awni.hannun@gmail.com> Co-authored-by: Awni Hannun <awni@apple.com>
This commit is contained in:
parent
a1a31eed27
commit
fffe072028
@ -38,6 +38,7 @@ we use a splittable version of Threefry, which is a counter-based PRNG.
|
|||||||
gumbel
|
gumbel
|
||||||
key
|
key
|
||||||
normal
|
normal
|
||||||
|
multivariate_normal
|
||||||
randint
|
randint
|
||||||
seed
|
seed
|
||||||
split
|
split
|
||||||
|
@ -1,8 +1,9 @@
|
|||||||
// Copyright © 2023 Apple Inc.
|
// Copyright © 2023-2024 Apple Inc.
|
||||||
|
|
||||||
#include <cmath>
|
#include <cmath>
|
||||||
#include <sstream>
|
#include <sstream>
|
||||||
|
|
||||||
|
#include "mlx/linalg.h"
|
||||||
#include "mlx/ops.h"
|
#include "mlx/ops.h"
|
||||||
#include "mlx/primitives.h"
|
#include "mlx/primitives.h"
|
||||||
#include "mlx/random.h"
|
#include "mlx/random.h"
|
||||||
@ -192,6 +193,74 @@ array normal(
|
|||||||
return samples;
|
return samples;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Draw samples from a multivariate normal distribution with the given mean
// and covariance.
//
// The last dimension of `mean` has size n and the last two dimensions of `cov`
// must both have size n. The leading (batch) dimensions of `cov` and `mean`
// are broadcast together, then broadcast with `shape`, and n is appended to
// form the shape of the generated samples.
//
// Throws std::invalid_argument if `dtype` is not float32, if `mean` has no
// dimensions, if `cov` has fewer than two dimensions, if the last two
// dimensions of `cov` differ, or if the last dimension of `mean` does not
// match the last dimension of `cov`.
array multivariate_normal(
|
||||||
|
const array& mean,
|
||||||
|
const array& cov,
|
||||||
|
const std::vector<int>& shape,
|
||||||
|
Dtype dtype,
|
||||||
|
const std::optional<array>& key /* = nullopt */,
|
||||||
|
StreamOrDevice s) {
|
||||||
|
auto stream = to_stream(s);
|
||||||
|
|
||||||
|
// Only float32 output is supported; every other dtype is rejected up front.
if (dtype != float32) {
|
||||||
|
throw std::invalid_argument("[multivariate_normal] dtype must be float32.");
|
||||||
|
}
|
||||||
|
|
||||||
|
if (mean.ndim() < 1) {
|
||||||
|
throw std::invalid_argument(
|
||||||
|
"[multivariate_normal] mean must have at least one dimension.");
|
||||||
|
}
|
||||||
|
|
||||||
|
if (cov.ndim() < 2) {
|
||||||
|
throw std::invalid_argument(
|
||||||
|
"[multivariate_normal] cov must have at least two dimensions.");
|
||||||
|
}
|
||||||
|
|
||||||
|
// n is the size of the last axis of mean (the dimension of each sample).
auto n = mean.shape(-1);
|
||||||
|
|
||||||
|
// Check shape compatibility of mean and cov
|
||||||
|
if (cov.shape(-1) != cov.shape(-2)) {
|
||||||
|
throw std::invalid_argument(
|
||||||
|
"[multivariate_normal] last two dimensions of cov must be equal.");
|
||||||
|
}
|
||||||
|
if (n != cov.shape(-1)) {
|
||||||
|
throw std::invalid_argument(
|
||||||
|
"[multivariate_normal] mean and cov must have compatible shapes.");
|
||||||
|
}
|
||||||
|
|
||||||
|
// Compute output shape
|
||||||
|
// NOTE(review): truncated_output_shape is declared but never read below.
std::vector<int> truncated_output_shape;
|
||||||
|
|
||||||
|
// Batch shape of mean: everything except the trailing n.
auto truncated_mean_shape =
|
||||||
|
std::vector<int>(mean.shape().begin(), mean.shape().end() - 1);
|
||||||
|
// Batch shape of cov: everything except the trailing (n, n).
auto truncated_cov_shape =
|
||||||
|
std::vector<int>(cov.shape().begin(), cov.shape().end() - 2);
|
||||||
|
auto output_shape =
|
||||||
|
broadcast_shapes(truncated_cov_shape, truncated_mean_shape);
|
||||||
|
output_shape = broadcast_shapes(output_shape, shape);
|
||||||
|
output_shape.push_back(n);
|
||||||
|
|
||||||
|
// Compute the square-root of the covariance matrix, using the SVD
|
||||||
|
// The factorization is always carried out in float32, then cast to dtype.
// Assuming linalg::svd returns {U, S, Vt} (TODO confirm), the expression
// below scales the columns of U by sqrt(S) and multiplies by Vt.
auto covariance = astype(cov, float32, stream);
|
||||||
|
auto SVD = linalg::svd(covariance, stream);
|
||||||
|
auto std = astype(
|
||||||
|
matmul(
|
||||||
|
multiply(
|
||||||
|
SVD[0], expand_dims(sqrt(SVD[1], stream), -2, stream), stream),
|
||||||
|
SVD[2],
|
||||||
|
stream),
|
||||||
|
dtype,
|
||||||
|
stream);
|
||||||
|
|
||||||
|
// Generate the standard normal samples
|
||||||
|
auto standard_normal = normal(output_shape, dtype, 0.0, 1.0, key, stream);
|
||||||
|
// Treat each sample as a (1, n) row vector, multiply it by the square-root
// matrix, then drop the temporary axis again.
auto scaled_out = squeeze(
|
||||||
|
matmul(expand_dims(standard_normal, -2, stream), std, stream),
|
||||||
|
-2,
|
||||||
|
stream);
|
||||||
|
// Shift the scaled samples by the mean.
return add(mean, scaled_out, stream);
|
||||||
|
}
|
||||||
|
|
||||||
array randint(
|
array randint(
|
||||||
const array& low,
|
const array& low,
|
||||||
const array& high,
|
const array& high,
|
||||||
|
@ -121,6 +121,15 @@ inline array normal(
|
|||||||
return normal(shape, float32, 0.0, 1.0, key, s);
|
return normal(shape, float32, 0.0, 1.0, key, s);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
 * Generate samples from a multivariate normal distribution.
 *
 * mean:  array whose last dimension has size n (at least one dimension).
 * cov:   array whose last two dimensions both have size n (at least two
 *        dimensions).
 * shape: batch shape; it is broadcast with the leading dimensions of mean
 *        and cov, and n is appended to form the sample shape.
 * dtype: type of the samples; the implementation accepts only float32.
 * key:   optional PRNG key, forwarded to the underlying normal sampler.
 * s:     stream or device on which to run the operations.
 *
 * The implementation throws std::invalid_argument when dtype is not float32
 * or when the dimension requirements on mean and cov listed above are not
 * met.
 */
|
||||||
|
array multivariate_normal(
|
||||||
|
const array& mean,
|
||||||
|
const array& cov,
|
||||||
|
const std::vector<int>& shape,
|
||||||
|
Dtype dtype,
|
||||||
|
const std::optional<array>& key = std::nullopt,
|
||||||
|
StreamOrDevice s = {});
|
||||||
|
|
||||||
/** Generate integer samples uniformly at random */
|
/** Generate integer samples uniformly at random */
|
||||||
array randint(
|
array randint(
|
||||||
const array& low,
|
const array& low,
|
||||||
|
@ -179,6 +179,48 @@ void init_random(nb::module_& parent_module) {
|
|||||||
array: The output array of random values.
|
array: The output array of random values.
|
||||||
)pbdoc");
|
)pbdoc");
|
||||||
m.def(
|
m.def(
|
||||||
|
"multivariate_normal",
|
||||||
|
[](const array& mean,
|
||||||
|
const array& cov,
|
||||||
|
const std::vector<int>& shape,
|
||||||
|
std::optional<Dtype> type,
|
||||||
|
const std::optional<array>& key_,
|
||||||
|
StreamOrDevice s) {
|
||||||
|
auto key = key_ ? key_.value() : default_key().next();
|
||||||
|
return multivariate_normal(
|
||||||
|
mean, cov, shape, type.value_or(float32), key, s);
|
||||||
|
},
|
||||||
|
"mean"_a,
|
||||||
|
"cov"_a,
|
||||||
|
"shape"_a = std::vector<int>{},
|
||||||
|
"dtype"_a.none() = float32,
|
||||||
|
"key"_a = nb::none(),
|
||||||
|
"stream"_a = nb::none(),
|
||||||
|
nb::sig(
|
||||||
|
"def multivariate_normal(mean: array, cov: array, shape: Sequence[int] = [], dtype: Optional[Dtype] = float32, key: Optional[array] = None, stream: Union[None, Stream, Device] = None) -> array"),
|
||||||
|
R"pbdoc(
|
||||||
|
Generate jointly-normal random samples given a mean and covariance.
|
||||||
|
|
||||||
|
The matrix ``cov`` must be positive semi-definite. The behavior is
|
||||||
|
undefined if it is not. The only supported ``dtype`` is ``float32``.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
mean (array): array of shape ``(..., n)``, the mean of the
|
||||||
|
distribution.
|
||||||
|
cov (array): array of shape ``(..., n, n)``, the covariance
|
||||||
|
matrix of the distribution. The batch shape ``...`` must be
|
||||||
|
broadcast-compatible with that of ``mean``.
|
||||||
|
shape (list(int), optional): The output shape must be
|
||||||
|
broadcast-compatible with ``mean.shape[:-1]`` and ``cov.shape[:-2]``.
|
||||||
|
If empty, the result shape is determined by broadcasting the batch
|
||||||
|
shapes of ``mean`` and ``cov``. Default: ``[]``.
|
||||||
|
dtype (Dtype, optional): The output type. Default: ``float32``.
|
||||||
|
key (array, optional): A PRNG key. Default: ``None``.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
array: The output array of random values.
|
||||||
|
)pbdoc");
|
||||||
|
m.def(
|
||||||
"randint",
|
"randint",
|
||||||
[](const ScalarOrArray& low,
|
[](const ScalarOrArray& low,
|
||||||
const ScalarOrArray& high,
|
const ScalarOrArray& high,
|
||||||
|
@ -1,5 +1,6 @@
|
|||||||
# Copyright © 2023 Apple Inc.
|
# Copyright © 2023 Apple Inc.
|
||||||
|
|
||||||
|
import math
|
||||||
import unittest
|
import unittest
|
||||||
|
|
||||||
import mlx.core as mx
|
import mlx.core as mx
|
||||||
@ -101,6 +102,96 @@ class TestRandom(mlx_tests.MLXTestCase):
|
|||||||
a = abs(mx.random.normal(shape=(10000,), loc=0, scale=1, dtype=hp))
|
a = abs(mx.random.normal(shape=(10000,), loc=0, scale=1, dtype=hp))
|
||||||
self.assertTrue(mx.all(a < mx.inf))
|
self.assertTrue(mx.all(a < mx.inf))
|
||||||
|
|
||||||
|
def test_multivariate_normal(self):
|
||||||
|
key = mx.random.key(0)
|
||||||
|
mean = mx.array([0, 0])
|
||||||
|
cov = mx.array([[1, 0], [0, 1]])
|
||||||
|
|
||||||
|
a = mx.random.multivariate_normal(mean, cov, key=key, stream=mx.cpu)
|
||||||
|
self.assertEqual(a.shape, (2,))
|
||||||
|
|
||||||
|
## Check dtypes
|
||||||
|
for t in [mx.float32]:
|
||||||
|
a = mx.random.multivariate_normal(
|
||||||
|
mean, cov, dtype=t, key=key, stream=mx.cpu
|
||||||
|
)
|
||||||
|
self.assertEqual(a.dtype, t)
|
||||||
|
for t in [
|
||||||
|
mx.int8,
|
||||||
|
mx.int32,
|
||||||
|
mx.int64,
|
||||||
|
mx.uint8,
|
||||||
|
mx.uint32,
|
||||||
|
mx.uint64,
|
||||||
|
mx.float16,
|
||||||
|
mx.bfloat16,
|
||||||
|
]:
|
||||||
|
with self.assertRaises(ValueError):
|
||||||
|
mx.random.multivariate_normal(
|
||||||
|
mean, cov, dtype=t, key=key, stream=mx.cpu
|
||||||
|
)
|
||||||
|
|
||||||
|
## Check incompatible shapes
|
||||||
|
with self.assertRaises(ValueError):
|
||||||
|
mean = mx.zeros((2, 2))
|
||||||
|
cov = mx.zeros((2, 2))
|
||||||
|
mx.random.multivariate_normal(mean, cov, shape=(3,), key=key, stream=mx.cpu)
|
||||||
|
|
||||||
|
with self.assertRaises(ValueError):
|
||||||
|
mean = mx.zeros((2))
|
||||||
|
cov = mx.zeros((2, 2, 2))
|
||||||
|
mx.random.multivariate_normal(mean, cov, shape=(3,), key=key, stream=mx.cpu)
|
||||||
|
|
||||||
|
with self.assertRaises(ValueError):
|
||||||
|
mean = mx.zeros((3,))
|
||||||
|
cov = mx.zeros((2, 2))
|
||||||
|
mx.random.multivariate_normal(mean, cov, key=key, stream=mx.cpu)
|
||||||
|
|
||||||
|
with self.assertRaises(ValueError):
|
||||||
|
mean = mx.zeros((2,))
|
||||||
|
cov = mx.zeros((2, 3))
|
||||||
|
mx.random.multivariate_normal(mean, cov, key=key, stream=mx.cpu)
|
||||||
|
|
||||||
|
## Different shape of mean and cov
|
||||||
|
mean = mx.array([[0, 7], [1, 2], [3, 4]])
|
||||||
|
cov = mx.array([[1, 0.5], [0.5, 1]])
|
||||||
|
a = mx.random.multivariate_normal(mean, cov, shape=(4, 3), stream=mx.cpu)
|
||||||
|
self.assertEqual(a.shape, (4, 3, 2))
|
||||||
|
|
||||||
|
## Check correctness of the mean and covariance
|
||||||
|
n_test = int(1e5)
|
||||||
|
|
||||||
|
def check_jointly_gaussian(data, mean, cov):
|
||||||
|
empirical_mean = mx.mean(data, axis=0)
|
||||||
|
empirical_cov = (
|
||||||
|
(data - empirical_mean).T @ (data - empirical_mean) / data.shape[0]
|
||||||
|
)
|
||||||
|
N = data.shape[1]
|
||||||
|
self.assertTrue(
|
||||||
|
mx.allclose(
|
||||||
|
empirical_mean, mean, rtol=0.0, atol=10 * N**2 / math.sqrt(n_test)
|
||||||
|
)
|
||||||
|
)
|
||||||
|
self.assertTrue(
|
||||||
|
mx.allclose(
|
||||||
|
empirical_cov, cov, rtol=0.0, atol=10 * N**2 / math.sqrt(n_test)
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
mean = mx.array([4.0, 7.0])
|
||||||
|
cov = mx.array([[2, 0.5], [0.5, 1]])
|
||||||
|
data = mx.random.multivariate_normal(
|
||||||
|
mean, cov, shape=(n_test,), key=key, stream=mx.cpu
|
||||||
|
)
|
||||||
|
check_jointly_gaussian(data, mean, cov)
|
||||||
|
|
||||||
|
mean = mx.arange(3)
|
||||||
|
cov = mx.array([[1, -1, 0.5], [-1, 1, -0.5], [0.5, -0.5, 1]])
|
||||||
|
data = mx.random.multivariate_normal(
|
||||||
|
mean, cov, shape=(n_test,), key=key, stream=mx.cpu
|
||||||
|
)
|
||||||
|
check_jointly_gaussian(data, mean, cov)
|
||||||
|
|
||||||
def test_randint(self):
|
def test_randint(self):
|
||||||
a = mx.random.randint(0, 1, [])
|
a = mx.random.randint(0, 1, [])
|
||||||
self.assertEqual(a.shape, ())
|
self.assertEqual(a.shape, ())
|
||||||
|
@ -420,6 +420,63 @@ TEST_CASE("test random normal") {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Exercises random::multivariate_normal: output shape and dtype for valid
// inputs, and std::invalid_argument for mismatched shapes or an unsupported
// dtype. The statistical properties of the samples are not checked here.
TEST_CASE("test random multivariate_normal") {
|
||||||
|
// Basic case: 3-dimensional mean, identity covariance, 1000 samples.
{
|
||||||
|
auto mean = zeros({3});
|
||||||
|
auto cov = eye(3);
|
||||||
|
auto x = random::multivariate_normal(mean, cov, {1000}, float32);
|
||||||
|
CHECK_EQ(x.shape(), std::vector<int>({1000, 3}));
|
||||||
|
CHECK_EQ(x.dtype(), float32);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Limit case
|
||||||
|
// NOTE(review): the four values below become a 2x2 matrix whose
// off-diagonal entries are -1 and -.1, i.e. not symmetric; confirm whether
// -.1 is intended. Only the shape and dtype of the result are checked.
{
|
||||||
|
auto mean = array({0, 0});
|
||||||
|
auto cov = array({1., -1, -.1, 1.});
|
||||||
|
cov = reshape(cov, {2, 2});
|
||||||
|
auto x = random::multivariate_normal(mean, cov, {1}, float32);
|
||||||
|
CHECK_EQ(x.shape(), std::vector<int>({1, 2}));
|
||||||
|
CHECK_EQ(x.dtype(), float32);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check wrong shapes
|
||||||
|
// mean has last dimension 1 while cov is 3x3, so the call must throw.
{
|
||||||
|
auto mean = zeros({3, 1});
|
||||||
|
auto cov = eye(3);
|
||||||
|
CHECK_THROWS_AS(
|
||||||
|
random::multivariate_normal(
|
||||||
|
mean,
|
||||||
|
cov,
|
||||||
|
{
|
||||||
|
1000,
|
||||||
|
},
|
||||||
|
float32),
|
||||||
|
std::invalid_argument);
|
||||||
|
}
|
||||||
|
// This case is expected to succeed: cov carries a (1, 2) batch shape that
// is combined with the requested shape {1000, 2}.
{
|
||||||
|
auto mean = zeros({3});
|
||||||
|
auto cov = zeros({1, 2, 3, 3});
|
||||||
|
auto x = random::multivariate_normal(mean, cov, {1000, 2}, float32);
|
||||||
|
CHECK_EQ(x.shape(), std::vector<int>({1000, 2, 3}));
|
||||||
|
}
|
||||||
|
// mean has 3 entries but cov is 4x4, so the call must throw.
{
|
||||||
|
auto mean = zeros({3});
|
||||||
|
auto cov = eye(4);
|
||||||
|
CHECK_THROWS_AS(
|
||||||
|
random::multivariate_normal(mean, cov, {1000, 3}, float32),
|
||||||
|
std::invalid_argument);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check wrong type
|
||||||
|
// float16 is requested here and the call is expected to throw.
{
|
||||||
|
auto mean = zeros({3});
|
||||||
|
auto cov = eye(3);
|
||||||
|
CHECK_THROWS_AS(
|
||||||
|
random::multivariate_normal(mean, cov, {1000, 3}, float16),
|
||||||
|
std::invalid_argument);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
TEST_CASE("test random randint") {
|
TEST_CASE("test random randint") {
|
||||||
CHECK_THROWS_AS(
|
CHECK_THROWS_AS(
|
||||||
random::randint(array(3), array(5), {1}, float32), std::invalid_argument);
|
random::randint(array(3), array(5), {1}, float32), std::invalid_argument);
|
||||||
|
Loading…
Reference in New Issue
Block a user