diff --git a/mlx/random.cpp b/mlx/random.cpp index 207a0a5e1..6c6d1eb95 100644 --- a/mlx/random.cpp +++ b/mlx/random.cpp @@ -215,11 +215,12 @@ array normal( auto samples = uniform(low, high, shape, dtype, key, stream); auto applied_scale = array(std::sqrt(2.0), dtype); if (scale.has_value()) { - applied_scale = multiply(applied_scale, *scale, stream); + applied_scale = + multiply(applied_scale, astype(*scale, dtype, stream), stream); } samples = multiply(applied_scale, erfinv(samples, stream), stream); if (loc.has_value()) { - samples = add(*loc, samples, stream); + samples = add(astype(*loc, dtype, stream), samples, stream); } return samples; } diff --git a/python/tests/test_random.py b/python/tests/test_random.py index 96a62de87..2fc768651 100644 --- a/python/tests/test_random.py +++ b/python/tests/test_random.py @@ -381,6 +381,12 @@ class TestRandom(mlx_tests.MLXTestCase): b = mx.random.normal((10,)) sample = mx.random.normal((2, 10, 2), loc=b, scale=b) + b = mx.random.normal((3, 1, 2)) + sample = mx.random.normal((3, 4, 2), dtype=mx.float16, loc=b, scale=b) + mx.eval(sample) + self.assertEqual(sample.shape, (3, 4, 2)) + self.assertEqual(sample.dtype, mx.float16) + if __name__ == "__main__": unittest.main()