mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
lower memory uniform
This commit is contained in:
@@ -150,18 +150,18 @@ array uniform(
|
|||||||
case float16:
|
case float16:
|
||||||
return std::make_pair(
|
return std::make_pair(
|
||||||
array(below_one<float16_t>(), float16),
|
array(below_one<float16_t>(), float16),
|
||||||
array(std::numeric_limits<uint16_t>::max(), float32));
|
array(std::numeric_limits<uint32_t>::max(), float32));
|
||||||
case bfloat16:
|
case bfloat16:
|
||||||
return std::make_pair(
|
return std::make_pair(
|
||||||
array(below_one<bfloat16_t>(), bfloat16),
|
array(below_one<bfloat16_t>(), bfloat16),
|
||||||
array(std::numeric_limits<uint16_t>::max(), float32));
|
array(std::numeric_limits<uint16_t>::max(), bfloat16));
|
||||||
default:
|
default:
|
||||||
throw std::runtime_error("[uniform] Unsupported type.");
|
throw std::runtime_error("[uniform] Unsupported type.");
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
auto [upper, maxval] = get_limits();
|
auto [upper, maxval] = get_limits();
|
||||||
auto out = bits(shape, size_of(dtype), key, stream);
|
auto out = bits(shape, size_of(maxval.dtype()), key, stream);
|
||||||
out = astype(divide(out, maxval, stream), dtype, stream);
|
out = astype(divide(out, maxval, stream), dtype, stream);
|
||||||
out = minimum(out, upper, stream);
|
out = minimum(out, upper, stream);
|
||||||
return add(multiply(range, out, stream), lo, stream);
|
return add(multiply(range, out, stream), lo, stream);
|
||||||
|
|||||||
Reference in New Issue
Block a user