Implementation of mlx.random.multivariate_normal (#502) (#877)

* Implementation of mlx.random.multivariate_normal (#502)

* Update python/src/random.cpp

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>

* Update python/src/random.cpp

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>

* Update python/src/random.cpp

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>

* Updated typo in docstring

* Restricted multivariate_normal to float32

* Generic mean and variance shapes

* Review edits

* Update mlx/random.cpp

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>

* Update python/src/random.cpp

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>

* Update python/src/random.cpp

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>

* Update python/src/random.cpp

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>

* Test for ndim of mean and cov

* nits

* smaller size for test

* fix broadcasted sampling

---------

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>
Co-authored-by: Awni Hannun <awni@apple.com>
This commit is contained in:
Luca Arnaboldi
2024-04-09 22:50:12 +02:00
committed by GitHub
parent a1a31eed27
commit fffe072028
6 changed files with 270 additions and 1 deletions

View File

@@ -1,8 +1,9 @@
// Copyright © 2023 Apple Inc.
// Copyright © 2023-2024 Apple Inc.
#include <cmath>
#include <sstream>
#include "mlx/linalg.h"
#include "mlx/ops.h"
#include "mlx/primitives.h"
#include "mlx/random.h"
@@ -192,6 +193,74 @@ array normal(
return samples;
}
// Draw samples from a multivariate normal distribution.
//
// Parameters:
//   mean  - at least 1-D; the last axis (length n) holds the mean vector and
//           any leading axes are treated as batch dimensions.
//   cov   - at least 2-D; the last two axes hold an n x n matrix and any
//           leading axes are treated as batch dimensions.
//   shape - batch shape of the samples; it is broadcast against the batch
//           dimensions of mean and cov.
//   dtype - must be float32.
//   key   - optional PRNG key, forwarded unchanged to normal().
//   s     - stream or device on which every op is scheduled.
//
// Returns mean plus the transformed standard-normal draws; the draws have
// shape broadcast(batch(cov), batch(mean), shape) + [n].
//
// Throws std::invalid_argument when dtype is not float32, when mean has
// fewer than one dimension, when cov has fewer than two dimensions, when the
// last two dimensions of cov differ, or when the last dimension of mean does
// not match cov.
array multivariate_normal(
    const array& mean,
    const array& cov,
    const std::vector<int>& shape,
    Dtype dtype,
    const std::optional<array>& key /* = nullopt */,
    StreamOrDevice s) {
  auto stream = to_stream(s);

  if (dtype != float32) {
    throw std::invalid_argument("[multivariate_normal] dtype must be float32.");
  }

  if (mean.ndim() < 1) {
    throw std::invalid_argument(
        "[multivariate_normal] mean must have at least one dimension.");
  }

  if (cov.ndim() < 2) {
    throw std::invalid_argument(
        "[multivariate_normal] cov must have at least two dimensions.");
  }

  auto n = mean.shape(-1);

  // Check shape compatibility of mean and cov
  if (cov.shape(-1) != cov.shape(-2)) {
    throw std::invalid_argument(
        "[multivariate_normal] last two dimensions of cov must be equal.");
  }
  if (n != cov.shape(-1)) {
    throw std::invalid_argument(
        "[multivariate_normal] mean and cov must have compatible shapes.");
  }

  // Compute the output shape: broadcast the batch dimensions of cov and mean
  // with the requested shape, then append the event dimension n.
  auto mean_batch_shape =
      std::vector<int>(mean.shape().begin(), mean.shape().end() - 1);
  auto cov_batch_shape =
      std::vector<int>(cov.shape().begin(), cov.shape().end() - 2);
  auto output_shape = broadcast_shapes(cov_batch_shape, mean_batch_shape);
  output_shape = broadcast_shapes(output_shape, shape);
  output_shape.push_back(n);

  // Compute the square root of the covariance matrix using the SVD:
  // (factor0 * sqrt(factor1)) @ factor2.
  // NOTE(review): assumes linalg::svd returns {U, S, Vt} in that order, and
  // that cov is symmetric positive semi-definite (not checked here) so that
  // this product is a valid matrix square root.
  auto covariance = astype(cov, float32, stream);
  auto svd_factors = linalg::svd(covariance, stream);
  auto cov_sqrt = astype(
      matmul(
          multiply(
              svd_factors[0],
              expand_dims(sqrt(svd_factors[1], stream), -2, stream),
              stream),
          svd_factors[2],
          stream),
      dtype,
      stream);

  // Generate the standard normal samples
  auto standard_normal = normal(output_shape, dtype, 0.0, 1.0, key, stream);

  // Treat each sample as a 1 x n row vector for the (batched) matmul with the
  // covariance square root, then drop the helper axis again.
  auto scaled_out = squeeze(
      matmul(expand_dims(standard_normal, -2, stream), cov_sqrt, stream),
      -2,
      stream);

  return add(mean, scaled_out, stream);
}
array randint(
const array& low,
const array& high,

View File

@@ -121,6 +121,15 @@ inline array normal(
return normal(shape, float32, 0.0, 1.0, key, s);
}
/**
 * Generate samples from a multivariate normal distribution.
 *
 * `mean` must have at least one dimension and `cov` at least two; the last
 * two dimensions of `cov` must be equal to each other and to the last
 * dimension of `mean`. The leading (batch) dimensions of `mean` and `cov`
 * are broadcast together with `shape`, and the size of the last dimension
 * of `mean` is appended to form the shape of the drawn samples.
 *
 * `dtype` must be float32. `key` is an optional PRNG key and `s` selects
 * the stream or device to run on.
 *
 * Throws std::invalid_argument when any of the above requirements on
 * `dtype`, `mean` or `cov` is violated.
 **/
array multivariate_normal(
const array& mean,
const array& cov,
const std::vector<int>& shape,
Dtype dtype,
const std::optional<array>& key = std::nullopt,
StreamOrDevice s = {});
/** Generate integer samples uniformly at random */
array randint(
const array& low,