mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
Add loc and scale to random.normal (#638)
* Add loc and scale to random.normal * Add tests for loc and scale for random.normal * Run pre-commit hooks * Fix code review
This commit is contained in:
19
mlx/random.h
19
mlx/random.h
@@ -95,13 +95,30 @@ inline array uniform(
|
||||
array normal(
|
||||
const std::vector<int>& shape,
|
||||
Dtype dtype,
|
||||
const float loc,
|
||||
const float scale,
|
||||
const std::optional<array>& key = std::nullopt,
|
||||
StreamOrDevice s = {});
|
||||
inline array normal(
|
||||
const std::vector<int>& shape,
|
||||
const float loc,
|
||||
const float scale,
|
||||
const std::optional<array>& key = std::nullopt,
|
||||
StreamOrDevice s = {}) {
|
||||
return normal(shape, float32, key, s);
|
||||
return normal(shape, float32, loc, scale, key, s);
|
||||
}
|
||||
inline array normal(
|
||||
const std::vector<int>& shape,
|
||||
const Dtype dtype,
|
||||
const std::optional<array>& key = std::nullopt,
|
||||
StreamOrDevice s = {}) {
|
||||
return normal(shape, dtype, 0.0, 1.0, key, s);
|
||||
}
|
||||
inline array normal(
|
||||
const std::vector<int>& shape,
|
||||
const std::optional<array>& key = std::nullopt,
|
||||
StreamOrDevice s = {}) {
|
||||
return normal(shape, float32, 0.0, 1.0, key, s);
|
||||
}
|
||||
|
||||
/** Generate integer samples uniformly at random */
|
||||
|
||||
Reference in New Issue
Block a user