mirror of
https://github.com/ml-explore/mlx.git
synced 2025-06-25 01:41:17 +08:00
Cast loc and scale to requested dtype
This commit is contained in:
parent
488537e80b
commit
4a380d5a80
@ -215,11 +215,12 @@ array normal(
|
|||||||
auto samples = uniform(low, high, shape, dtype, key, stream);
|
auto samples = uniform(low, high, shape, dtype, key, stream);
|
||||||
auto applied_scale = array(std::sqrt(2.0), dtype);
|
auto applied_scale = array(std::sqrt(2.0), dtype);
|
||||||
if (scale.has_value()) {
|
if (scale.has_value()) {
|
||||||
applied_scale = multiply(applied_scale, *scale, stream);
|
applied_scale =
|
||||||
|
multiply(applied_scale, astype(*scale, dtype, stream), stream);
|
||||||
}
|
}
|
||||||
samples = multiply(applied_scale, erfinv(samples, stream), stream);
|
samples = multiply(applied_scale, erfinv(samples, stream), stream);
|
||||||
if (loc.has_value()) {
|
if (loc.has_value()) {
|
||||||
samples = add(*loc, samples, stream);
|
samples = add(astype(*loc, dtype, stream), samples, stream);
|
||||||
}
|
}
|
||||||
return samples;
|
return samples;
|
||||||
}
|
}
|
||||||
|
@ -381,6 +381,12 @@ class TestRandom(mlx_tests.MLXTestCase):
|
|||||||
b = mx.random.normal((10,))
|
b = mx.random.normal((10,))
|
||||||
sample = mx.random.normal((2, 10, 2), loc=b, scale=b)
|
sample = mx.random.normal((2, 10, 2), loc=b, scale=b)
|
||||||
|
|
||||||
|
b = mx.random.normal((3, 1, 2))
|
||||||
|
sample = mx.random.normal((3, 4, 2), dtype=mx.float16, loc=b, scale=b)
|
||||||
|
mx.eval(sample)
|
||||||
|
self.assertEqual(sample.shape, (3, 4, 2))
|
||||||
|
self.assertEqual(sample.dtype, mx.float16)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
unittest.main()
|
unittest.main()
|
||||||
|
Loading…
Reference in New Issue
Block a user