mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
lower memory uniform
This commit is contained in:
@@ -150,18 +150,18 @@ array uniform(
|
||||
case float16:
|
||||
return std::make_pair(
|
||||
array(below_one<float16_t>(), float16),
|
||||
array(std::numeric_limits<uint16_t>::max(), float32));
|
||||
array(std::numeric_limits<uint32_t>::max(), float32));
|
||||
case bfloat16:
|
||||
return std::make_pair(
|
||||
array(below_one<bfloat16_t>(), bfloat16),
|
||||
array(std::numeric_limits<uint16_t>::max(), float32));
|
||||
array(std::numeric_limits<uint16_t>::max(), bfloat16));
|
||||
default:
|
||||
throw std::runtime_error("[uniform] Unsupported type.");
|
||||
}
|
||||
};
|
||||
|
||||
auto [upper, maxval] = get_limits();
|
||||
auto out = bits(shape, size_of(dtype), key, stream);
|
||||
auto out = bits(shape, size_of(maxval.dtype()), key, stream);
|
||||
out = astype(divide(out, maxval, stream), dtype, stream);
|
||||
out = minimum(out, upper, stream);
|
||||
return add(multiply(range, out, stream), lo, stream);
|
||||
|
||||
Reference in New Issue
Block a user