mirror of
https://github.com/ml-explore/mlx.git
synced 2025-09-19 02:38:09 +08:00
* Implementation of mlx.random.multivariate_normal (#502) * Update python/src/random.cpp Co-authored-by: Awni Hannun <awni.hannun@gmail.com> * Update python/src/random.cpp Co-authored-by: Awni Hannun <awni.hannun@gmail.com> * Update python/src/random.cpp Co-authored-by: Awni Hannun <awni.hannun@gmail.com> * Updated typo in docstring * Restricted multivariate_normal to float32 * Generic mean and variance shapes * Review edits * Update mlx/random.cpp Co-authored-by: Awni Hannun <awni.hannun@gmail.com> * Update python/src/random.cpp Co-authored-by: Awni Hannun <awni.hannun@gmail.com> * Update python/src/random.cpp Co-authored-by: Awni Hannun <awni.hannun@gmail.com> * Update python/src/random.cpp Co-authored-by: Awni Hannun <awni.hannun@gmail.com> * Test for ndim of mean and cov * nits * smaller size for test * fix broadcasted sampling --------- Co-authored-by: Awni Hannun <awni.hannun@gmail.com> Co-authored-by: Awni Hannun <awni@apple.com>
This commit is contained in:
@@ -1,8 +1,9 @@
|
||||
// Copyright © 2023 Apple Inc.
|
||||
// Copyright © 2023-2024 Apple Inc.
|
||||
|
||||
#include <cmath>
|
||||
#include <sstream>
|
||||
|
||||
#include "mlx/linalg.h"
|
||||
#include "mlx/ops.h"
|
||||
#include "mlx/primitives.h"
|
||||
#include "mlx/random.h"
|
||||
@@ -192,6 +193,74 @@ array normal(
|
||||
return samples;
|
||||
}
|
||||
|
||||
array multivariate_normal(
|
||||
const array& mean,
|
||||
const array& cov,
|
||||
const std::vector<int>& shape,
|
||||
Dtype dtype,
|
||||
const std::optional<array>& key /* = nullopt */,
|
||||
StreamOrDevice s) {
|
||||
auto stream = to_stream(s);
|
||||
|
||||
if (dtype != float32) {
|
||||
throw std::invalid_argument("[multivariate_normal] dtype must be float32.");
|
||||
}
|
||||
|
||||
if (mean.ndim() < 1) {
|
||||
throw std::invalid_argument(
|
||||
"[multivariate_normal] mean must have at least one dimension.");
|
||||
}
|
||||
|
||||
if (cov.ndim() < 2) {
|
||||
throw std::invalid_argument(
|
||||
"[multivariate_normal] cov must have at least two dimensions.");
|
||||
}
|
||||
|
||||
auto n = mean.shape(-1);
|
||||
|
||||
// Check shapes comatibility of mean and cov
|
||||
if (cov.shape(-1) != cov.shape(-2)) {
|
||||
throw std::invalid_argument(
|
||||
"[multivariate_normal] last two dimensions of cov must be equal.");
|
||||
}
|
||||
if (n != cov.shape(-1)) {
|
||||
throw std::invalid_argument(
|
||||
"[multivariate_normal] mean and cov must have compatible shapes.");
|
||||
}
|
||||
|
||||
// Compute output shape
|
||||
std::vector<int> truncated_output_shape;
|
||||
|
||||
auto truncated_mean_shape =
|
||||
std::vector<int>(mean.shape().begin(), mean.shape().end() - 1);
|
||||
auto truncated_cov_shape =
|
||||
std::vector<int>(cov.shape().begin(), cov.shape().end() - 2);
|
||||
auto output_shape =
|
||||
broadcast_shapes(truncated_cov_shape, truncated_mean_shape);
|
||||
output_shape = broadcast_shapes(output_shape, shape);
|
||||
output_shape.push_back(n);
|
||||
|
||||
// Compute the square-root of the covariance matrix, using the SVD
|
||||
auto covariance = astype(cov, float32, stream);
|
||||
auto SVD = linalg::svd(covariance, stream);
|
||||
auto std = astype(
|
||||
matmul(
|
||||
multiply(
|
||||
SVD[0], expand_dims(sqrt(SVD[1], stream), -2, stream), stream),
|
||||
SVD[2],
|
||||
stream),
|
||||
dtype,
|
||||
stream);
|
||||
|
||||
// Generate standard the samples
|
||||
auto standard_normal = normal(output_shape, dtype, 0.0, 1.0, key, stream);
|
||||
auto scaled_out = squeeze(
|
||||
matmul(expand_dims(standard_normal, -2, stream), std, stream),
|
||||
-2,
|
||||
stream);
|
||||
return add(mean, scaled_out, stream);
|
||||
}
|
||||
|
||||
array randint(
|
||||
const array& low,
|
||||
const array& high,
|
||||
|
@@ -121,6 +121,15 @@ inline array normal(
|
||||
return normal(shape, float32, 0.0, 1.0, key, s);
|
||||
}
|
||||
|
||||
/** Generate samples from a multivariate normal distribution. **/
|
||||
array multivariate_normal(
|
||||
const array& mean,
|
||||
const array& cov,
|
||||
const std::vector<int>& shape,
|
||||
Dtype dtype,
|
||||
const std::optional<array>& key = std::nullopt,
|
||||
StreamOrDevice s = {});
|
||||
|
||||
/** Generate integer samples uniformly at random */
|
||||
array randint(
|
||||
const array& low,
|
||||
|
Reference in New Issue
Block a user