mirror of
https://github.com/ml-explore/mlx.git
synced 2025-06-25 01:41:17 +08:00
Cast loc and scale to requested dtype
This commit is contained in:
parent
488537e80b
commit
4a380d5a80
@ -215,11 +215,12 @@ array normal(
|
||||
auto samples = uniform(low, high, shape, dtype, key, stream);
|
||||
auto applied_scale = array(std::sqrt(2.0), dtype);
|
||||
if (scale.has_value()) {
|
||||
applied_scale = multiply(applied_scale, *scale, stream);
|
||||
applied_scale =
|
||||
multiply(applied_scale, astype(*scale, dtype, stream), stream);
|
||||
}
|
||||
samples = multiply(applied_scale, erfinv(samples, stream), stream);
|
||||
if (loc.has_value()) {
|
||||
samples = add(*loc, samples, stream);
|
||||
samples = add(astype(*loc, dtype, stream), samples, stream);
|
||||
}
|
||||
return samples;
|
||||
}
|
||||
|
@ -381,6 +381,12 @@ class TestRandom(mlx_tests.MLXTestCase):
|
||||
b = mx.random.normal((10,))
|
||||
sample = mx.random.normal((2, 10, 2), loc=b, scale=b)
|
||||
|
||||
b = mx.random.normal((3, 1, 2))
|
||||
sample = mx.random.normal((3, 4, 2), dtype=mx.float16, loc=b, scale=b)
|
||||
mx.eval(sample)
|
||||
self.assertEqual(sample.shape, (3, 4, 2))
|
||||
self.assertEqual(sample.dtype, mx.float16)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
Loading…
Reference in New Issue
Block a user