diff --git a/mlx/random.cpp b/mlx/random.cpp index 6c6d1eb95..5367e7ca9 100644 --- a/mlx/random.cpp +++ b/mlx/random.cpp @@ -150,18 +150,18 @@ array uniform( case float16: return std::make_pair( array(below_one(), float16), - array(std::numeric_limits::max(), float32)); + array(std::numeric_limits::max(), float32)); case bfloat16: return std::make_pair( array(below_one(), bfloat16), - array(std::numeric_limits::max(), float32)); + array(std::numeric_limits::max(), bfloat16)); default: throw std::runtime_error("[uniform] Unsupported type."); } }; auto [upper, maxval] = get_limits(); - auto out = bits(shape, size_of(dtype), key, stream); + auto out = bits(shape, size_of(maxval.dtype()), key, stream); out = astype(divide(out, maxval, stream), dtype, stream); out = minimum(out, upper, stream); return add(multiply(range, out, stream), lo, stream);