Add loc and scale to random.normal (#638)

* Add loc and scale to random.normal

* Add tests for loc and scale for random.normal

* Run pre-commit hooks

* Fix code review
This commit is contained in:
Noah Farr
2024-02-07 20:49:59 +01:00
committed by GitHub
parent ef73393a19
commit 5fd11c347d
4 changed files with 50 additions and 4 deletions

View File

@@ -153,14 +153,23 @@ array uniform(
array normal(
const std::vector<int>& shape,
Dtype dtype,
const float loc /* = 0.0 */,
const float scale /* = 1.0 */,
const std::optional<array>& key /*= nullopt */,
StreamOrDevice s /* = {} */) {
auto stream = to_stream(s);
auto low = array(std::nextafter(-1.0f, 0.0f), dtype);
auto high = array(1.0f, dtype);
auto samples = uniform(low, high, shape, dtype, key, stream);
return multiply(
array(std::sqrt(2.0), dtype), erfinv(samples, stream), stream);
samples =
multiply(array(std::sqrt(2.0), dtype), erfinv(samples, stream), stream);
if (scale != 1.0) {
samples = multiply(array(scale, dtype), samples, stream);
}
if (loc != 0.0) {
samples = add(array(loc, dtype), samples, stream);
}
return samples;
}
array randint(

View File

@@ -95,13 +95,30 @@ inline array uniform(
array normal(
const std::vector<int>& shape,
Dtype dtype,
const float loc,
const float scale,
const std::optional<array>& key = std::nullopt,
StreamOrDevice s = {});
inline array normal(
const std::vector<int>& shape,
const float loc,
const float scale,
const std::optional<array>& key = std::nullopt,
StreamOrDevice s = {}) {
return normal(shape, float32, key, s);
return normal(shape, float32, loc, scale, key, s);
}
inline array normal(
const std::vector<int>& shape,
const Dtype dtype,
const std::optional<array>& key = std::nullopt,
StreamOrDevice s = {}) {
return normal(shape, dtype, 0.0, 1.0, key, s);
}
inline array normal(
const std::vector<int>& shape,
const std::optional<array>& key = std::nullopt,
StreamOrDevice s = {}) {
return normal(shape, float32, 0.0, 1.0, key, s);
}
/** Generate integer samples uniformly at random */