From 4a380d5a803a641e9088d223cb4508da94e6f552 Mon Sep 17 00:00:00 2001 From: Angelos Katharopoulos Date: Tue, 13 May 2025 21:13:44 -0700 Subject: [PATCH] Cast loc and scale to requested dtype --- mlx/random.cpp | 5 +++-- python/tests/test_random.py | 6 ++++++ 2 files changed, 9 insertions(+), 2 deletions(-) diff --git a/mlx/random.cpp b/mlx/random.cpp index 207a0a5e1..6c6d1eb95 100644 --- a/mlx/random.cpp +++ b/mlx/random.cpp @@ -215,11 +215,12 @@ array normal( auto samples = uniform(low, high, shape, dtype, key, stream); auto applied_scale = array(std::sqrt(2.0), dtype); if (scale.has_value()) { - applied_scale = multiply(applied_scale, *scale, stream); + applied_scale = + multiply(applied_scale, astype(*scale, dtype, stream), stream); } samples = multiply(applied_scale, erfinv(samples, stream), stream); if (loc.has_value()) { - samples = add(*loc, samples, stream); + samples = add(astype(*loc, dtype, stream), samples, stream); } return samples; } diff --git a/python/tests/test_random.py b/python/tests/test_random.py index 96a62de87..2fc768651 100644 --- a/python/tests/test_random.py +++ b/python/tests/test_random.py @@ -381,6 +381,12 @@ class TestRandom(mlx_tests.MLXTestCase): b = mx.random.normal((10,)) sample = mx.random.normal((2, 10, 2), loc=b, scale=b) + b = mx.random.normal((3, 1, 2)) + sample = mx.random.normal((3, 4, 2), dtype=mx.float16, loc=b, scale=b) + mx.eval(sample) + self.assertEqual(sample.shape, (3, 4, 2)) + self.assertEqual(sample.dtype, mx.float16) + if __name__ == "__main__": unittest.main()