mirror of
https://github.com/ml-explore/mlx.git
synced 2025-09-17 17:28:10 +08:00
* Implementation of mlx.random.multivariate_normal (#502) * Update python/src/random.cpp Co-authored-by: Awni Hannun <awni.hannun@gmail.com> * Update python/src/random.cpp Co-authored-by: Awni Hannun <awni.hannun@gmail.com> * Update python/src/random.cpp Co-authored-by: Awni Hannun <awni.hannun@gmail.com> * Updated typo in docstring * Restricted multivariate_normal to float32 * Generic mean and variance shapes * Review edits * Update mlx/random.cpp Co-authored-by: Awni Hannun <awni.hannun@gmail.com> * Update python/src/random.cpp Co-authored-by: Awni Hannun <awni.hannun@gmail.com> * Update python/src/random.cpp Co-authored-by: Awni Hannun <awni.hannun@gmail.com> * Update python/src/random.cpp Co-authored-by: Awni Hannun <awni.hannun@gmail.com> * Test for ndim of mean and cov * nits * smaller size for test * fix broadcasted sampling --------- Co-authored-by: Awni Hannun <awni.hannun@gmail.com> Co-authored-by: Awni Hannun <awni@apple.com>
This commit is contained in:
@@ -179,6 +179,48 @@ void init_random(nb::module_& parent_module) {
|
||||
array: The output array of random values.
|
||||
)pbdoc");
|
||||
m.def(
|
||||
"multivariate_normal",
|
||||
[](const array& mean,
|
||||
const array& cov,
|
||||
const std::vector<int>& shape,
|
||||
std::optional<Dtype> type,
|
||||
const std::optional<array>& key_,
|
||||
StreamOrDevice s) {
|
||||
auto key = key_ ? key_.value() : default_key().next();
|
||||
return multivariate_normal(
|
||||
mean, cov, shape, type.value_or(float32), key, s);
|
||||
},
|
||||
"mean"_a,
|
||||
"cov"_a,
|
||||
"shape"_a = std::vector<int>{},
|
||||
"dtype"_a.none() = float32,
|
||||
"key"_a = nb::none(),
|
||||
"stream"_a = nb::none(),
|
||||
nb::sig(
|
||||
"def multivariate_normal(mean: array, cov: array, shape: Sequence[int] = [], dtype: Optional[Dtype] = float32, key: Optional[array] = None, stream: Union[None, Stream, Device] = None) -> array"),
|
||||
R"pbdoc(
|
||||
Generate jointly-normal random samples given a mean and covariance.
|
||||
|
||||
The matrix ``cov`` must be positive semi-definite. The behavior is
|
||||
undefined if it is not. The only supported ``dtype`` is ``float32``.
|
||||
|
||||
Args:
|
||||
mean (array): array of shape ``(..., n)``, the mean of the
|
||||
distribution.
|
||||
cov (array): array of shape ``(..., n, n)``, the covariance
|
||||
matrix of the distribution. The batch shape ``...`` must be
|
||||
broadcast-compatible with that of ``mean``.
|
||||
shape (list(int), optional): The output shape must be
|
||||
broadcast-compatible with ``mean.shape[:-1]`` and ``cov.shape[:-2]``.
|
||||
If empty, the result shape is determined by broadcasting the batch
|
||||
shapes of ``mean`` and ``cov``. Default: ``[]``.
|
||||
dtype (Dtype, optional): The output type. Default: ``float32``.
|
||||
key (array, optional): A PRNG key. Default: ``None``.
|
||||
|
||||
Returns:
|
||||
array: The output array of random values.
|
||||
)pbdoc");
|
||||
m.def(
|
||||
"randint",
|
||||
[](const ScalarOrArray& low,
|
||||
const ScalarOrArray& high,
|
||||
|
Reference in New Issue
Block a user