mirror of
https://github.com/ml-explore/mlx.git
synced 2025-09-18 01:50:16 +08:00
Add random normal distribution for complex numbers (#2182)
This commit is contained in:

committed by
GitHub

parent
0751263dec
commit
130df35e1b
@@ -176,24 +176,51 @@ array uniform(
|
||||
array(0.0, dtype), array(1.0, dtype), shape, dtype, key, to_stream(s));
|
||||
}
|
||||
|
||||
inline array complex_normal(
|
||||
Shape shape,
|
||||
const std::optional<array>& loc,
|
||||
const std::optional<array>& scale,
|
||||
const std::optional<array>& key,
|
||||
StreamOrDevice s) {
|
||||
auto stream = to_stream(s);
|
||||
auto low = above_minus_one_with_default(float32);
|
||||
auto high = array(1.0f, float32);
|
||||
shape.push_back(2);
|
||||
auto samples =
|
||||
erfinv(uniform(low, high, shape, float32, key, stream), stream);
|
||||
samples = squeeze(view(samples, complex64, stream), -1, stream);
|
||||
if (scale.has_value()) {
|
||||
samples = multiply(*scale, samples, stream);
|
||||
}
|
||||
if (loc.has_value()) {
|
||||
samples = add(*loc, samples, stream);
|
||||
}
|
||||
return samples;
|
||||
}
|
||||
|
||||
array normal(
|
||||
const Shape& shape,
|
||||
Dtype dtype,
|
||||
const float loc /* = 0.0 */,
|
||||
const float scale /* = 1.0 */,
|
||||
const std::optional<array>& key /*= nullopt */,
|
||||
const std::optional<array>& loc,
|
||||
const std::optional<array>& scale,
|
||||
const std::optional<array>& key,
|
||||
StreamOrDevice s /* = {} */) {
|
||||
if (dtype == complex64) {
|
||||
return complex_normal(shape, loc, scale, key, s);
|
||||
}
|
||||
|
||||
auto stream = to_stream(s);
|
||||
auto low = above_minus_one_with_default(dtype);
|
||||
auto high = array(1.0f, dtype);
|
||||
auto samples = uniform(low, high, shape, dtype, key, stream);
|
||||
samples =
|
||||
multiply(array(std::sqrt(2.0), dtype), erfinv(samples, stream), stream);
|
||||
if (scale != 1.0) {
|
||||
samples = multiply(array(scale, dtype), samples, stream);
|
||||
auto applied_scale = array(std::sqrt(2.0), dtype);
|
||||
if (scale.has_value()) {
|
||||
applied_scale =
|
||||
multiply(applied_scale, astype(*scale, dtype, stream), stream);
|
||||
}
|
||||
if (loc != 0.0) {
|
||||
samples = add(array(loc, dtype), samples, stream);
|
||||
samples = multiply(applied_scale, erfinv(samples, stream), stream);
|
||||
if (loc.has_value()) {
|
||||
samples = add(astype(*loc, dtype, stream), samples, stream);
|
||||
}
|
||||
return samples;
|
||||
}
|
||||
|
18
mlx/random.h
18
mlx/random.h
@@ -94,12 +94,24 @@ inline array uniform(
|
||||
|
||||
/** Generate samples from the standard normal distribution. */
|
||||
array normal(
|
||||
const Shape& shape,
|
||||
Dtype dtype,
|
||||
const std::optional<array>& loc,
|
||||
const std::optional<array>& scale,
|
||||
const std::optional<array>& key,
|
||||
StreamOrDevice s = {});
|
||||
inline array normal(
|
||||
const Shape& shape,
|
||||
Dtype dtype,
|
||||
const float loc,
|
||||
const float scale,
|
||||
const std::optional<array>& key = std::nullopt,
|
||||
StreamOrDevice s = {});
|
||||
StreamOrDevice s = {}) {
|
||||
auto loc_ = loc == 0 ? std::nullopt : std::make_optional(array(loc, dtype));
|
||||
auto scale_ =
|
||||
scale == 1 ? std::nullopt : std::make_optional(array(scale, dtype));
|
||||
return normal(shape, dtype, loc_, scale_, key, s);
|
||||
}
|
||||
inline array normal(
|
||||
const Shape& shape,
|
||||
const float loc,
|
||||
@@ -113,13 +125,13 @@ inline array normal(
|
||||
const Dtype dtype,
|
||||
const std::optional<array>& key = std::nullopt,
|
||||
StreamOrDevice s = {}) {
|
||||
return normal(shape, dtype, 0.0, 1.0, key, s);
|
||||
return normal(shape, dtype, std::nullopt, std::nullopt, key, s);
|
||||
}
|
||||
inline array normal(
|
||||
const Shape& shape,
|
||||
const std::optional<array>& key = std::nullopt,
|
||||
StreamOrDevice s = {}) {
|
||||
return normal(shape, float32, 0.0, 1.0, key, s);
|
||||
return normal(shape, float32, std::nullopt, std::nullopt, key, s);
|
||||
}
|
||||
|
||||
/** Generate samples from a multivariate normal distribution. **/
|
||||
|
Reference in New Issue
Block a user