mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
Add random normal distribution for complex numbers (#2182)
This commit is contained in:
committed by
GitHub
parent
0751263dec
commit
130df35e1b
18
mlx/random.h
18
mlx/random.h
@@ -94,12 +94,24 @@ inline array uniform(
|
||||
|
||||
/** Generate samples from the standard normal distribution. */
|
||||
array normal(
|
||||
const Shape& shape,
|
||||
Dtype dtype,
|
||||
const std::optional<array>& loc,
|
||||
const std::optional<array>& scale,
|
||||
const std::optional<array>& key,
|
||||
StreamOrDevice s = {});
|
||||
inline array normal(
|
||||
const Shape& shape,
|
||||
Dtype dtype,
|
||||
const float loc,
|
||||
const float scale,
|
||||
const std::optional<array>& key = std::nullopt,
|
||||
StreamOrDevice s = {});
|
||||
StreamOrDevice s = {}) {
|
||||
auto loc_ = loc == 0 ? std::nullopt : std::make_optional(array(loc, dtype));
|
||||
auto scale_ =
|
||||
scale == 1 ? std::nullopt : std::make_optional(array(scale, dtype));
|
||||
return normal(shape, dtype, loc_, scale_, key, s);
|
||||
}
|
||||
inline array normal(
|
||||
const Shape& shape,
|
||||
const float loc,
|
||||
@@ -113,13 +125,13 @@ inline array normal(
|
||||
const Dtype dtype,
|
||||
const std::optional<array>& key = std::nullopt,
|
||||
StreamOrDevice s = {}) {
|
||||
return normal(shape, dtype, 0.0, 1.0, key, s);
|
||||
return normal(shape, dtype, std::nullopt, std::nullopt, key, s);
|
||||
}
|
||||
inline array normal(
|
||||
const Shape& shape,
|
||||
const std::optional<array>& key = std::nullopt,
|
||||
StreamOrDevice s = {}) {
|
||||
return normal(shape, float32, 0.0, 1.0, key, s);
|
||||
return normal(shape, float32, std::nullopt, std::nullopt, key, s);
|
||||
}
|
||||
|
||||
/** Generate samples from a multivariate normal distribution. **/
|
||||
|
||||
Reference in New Issue
Block a user