mirror of
https://github.com/ml-explore/mlx.git
synced 2025-09-18 18:28:12 +08:00
Add loc and scale to random.normal (#638)
* Add loc and scale to random.normal * Add tests for loc and scale for random.normal * Run pre-commit hooks * Fix code review
This commit is contained in:
@@ -153,14 +153,23 @@ array uniform(
|
||||
array normal(
|
||||
const std::vector<int>& shape,
|
||||
Dtype dtype,
|
||||
const float loc /* = 0.0 */,
|
||||
const float scale /* = 1.0 */,
|
||||
const std::optional<array>& key /*= nullopt */,
|
||||
StreamOrDevice s /* = {} */) {
|
||||
auto stream = to_stream(s);
|
||||
auto low = array(std::nextafter(-1.0f, 0.0f), dtype);
|
||||
auto high = array(1.0f, dtype);
|
||||
auto samples = uniform(low, high, shape, dtype, key, stream);
|
||||
return multiply(
|
||||
array(std::sqrt(2.0), dtype), erfinv(samples, stream), stream);
|
||||
samples =
|
||||
multiply(array(std::sqrt(2.0), dtype), erfinv(samples, stream), stream);
|
||||
if (scale != 1.0) {
|
||||
samples = multiply(array(scale, dtype), samples, stream);
|
||||
}
|
||||
if (loc != 0.0) {
|
||||
samples = add(array(loc, dtype), samples, stream);
|
||||
}
|
||||
return samples;
|
||||
}
|
||||
|
||||
array randint(
|
||||
|
19
mlx/random.h
19
mlx/random.h
@@ -95,13 +95,30 @@ inline array uniform(
|
||||
array normal(
|
||||
const std::vector<int>& shape,
|
||||
Dtype dtype,
|
||||
const float loc,
|
||||
const float scale,
|
||||
const std::optional<array>& key = std::nullopt,
|
||||
StreamOrDevice s = {});
|
||||
inline array normal(
|
||||
const std::vector<int>& shape,
|
||||
const float loc,
|
||||
const float scale,
|
||||
const std::optional<array>& key = std::nullopt,
|
||||
StreamOrDevice s = {}) {
|
||||
return normal(shape, float32, key, s);
|
||||
return normal(shape, float32, loc, scale, key, s);
|
||||
}
|
||||
inline array normal(
|
||||
const std::vector<int>& shape,
|
||||
const Dtype dtype,
|
||||
const std::optional<array>& key = std::nullopt,
|
||||
StreamOrDevice s = {}) {
|
||||
return normal(shape, dtype, 0.0, 1.0, key, s);
|
||||
}
|
||||
inline array normal(
|
||||
const std::vector<int>& shape,
|
||||
const std::optional<array>& key = std::nullopt,
|
||||
StreamOrDevice s = {}) {
|
||||
return normal(shape, float32, 0.0, 1.0, key, s);
|
||||
}
|
||||
|
||||
/** Generate integer samples uniformly at random */
|
||||
|
Reference in New Issue
Block a user