From ddd132ca26d680a13df36890cb305342f6d04989 Mon Sep 17 00:00:00 2001 From: Awni Hannun Date: Fri, 11 Jul 2025 13:39:37 -0700 Subject: [PATCH] lower memory uniform --- mlx/random.cpp | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/mlx/random.cpp b/mlx/random.cpp index 6c6d1eb95..5367e7ca9 100644 --- a/mlx/random.cpp +++ b/mlx/random.cpp @@ -150,18 +150,18 @@ array uniform( case float16: return std::make_pair( array(below_one(), float16), - array(std::numeric_limits::max(), float32)); + array(std::numeric_limits::max(), float32)); case bfloat16: return std::make_pair( array(below_one(), bfloat16), - array(std::numeric_limits::max(), float32)); + array(std::numeric_limits::max(), bfloat16)); default: throw std::runtime_error("[uniform] Unsupported type."); } }; auto [upper, maxval] = get_limits(); - auto out = bits(shape, size_of(dtype), key, stream); + auto out = bits(shape, size_of(maxval.dtype()), key, stream); out = astype(divide(out, maxval, stream), dtype, stream); out = minimum(out, upper, stream); return add(multiply(range, out, stream), lo, stream);