Cast loc and scale to requested dtype

This commit is contained in:
Angelos Katharopoulos 2025-05-13 21:13:44 -07:00
parent 488537e80b
commit 4a380d5a80
2 changed files with 9 additions and 2 deletions

View File

@ -215,11 +215,12 @@ array normal(
auto samples = uniform(low, high, shape, dtype, key, stream);
auto applied_scale = array(std::sqrt(2.0), dtype);
if (scale.has_value()) {
applied_scale = multiply(applied_scale, *scale, stream);
applied_scale =
multiply(applied_scale, astype(*scale, dtype, stream), stream);
}
samples = multiply(applied_scale, erfinv(samples, stream), stream);
if (loc.has_value()) {
samples = add(*loc, samples, stream);
samples = add(astype(*loc, dtype, stream), samples, stream);
}
return samples;
}

View File

@ -381,6 +381,12 @@ class TestRandom(mlx_tests.MLXTestCase):
b = mx.random.normal((10,))
sample = mx.random.normal((2, 10, 2), loc=b, scale=b)
b = mx.random.normal((3, 1, 2))
sample = mx.random.normal((3, 4, 2), dtype=mx.float16, loc=b, scale=b)
mx.eval(sample)
self.assertEqual(sample.shape, (3, 4, 2))
self.assertEqual(sample.dtype, mx.float16)
if __name__ == "__main__":
unittest.main()