lower memory uniform

This commit is contained in:
Awni Hannun
2025-07-11 13:39:37 -07:00
parent e7d2ebadd2
commit ddd132ca26

View File

@@ -150,18 +150,18 @@ array uniform(
case float16: case float16:
return std::make_pair( return std::make_pair(
array(below_one<float16_t>(), float16), array(below_one<float16_t>(), float16),
array(std::numeric_limits<uint16_t>::max(), float32)); array(std::numeric_limits<uint32_t>::max(), float32));
case bfloat16: case bfloat16:
return std::make_pair( return std::make_pair(
array(below_one<bfloat16_t>(), bfloat16), array(below_one<bfloat16_t>(), bfloat16),
array(std::numeric_limits<uint16_t>::max(), float32)); array(std::numeric_limits<uint16_t>::max(), bfloat16));
default: default:
throw std::runtime_error("[uniform] Unsupported type."); throw std::runtime_error("[uniform] Unsupported type.");
} }
}; };
auto [upper, maxval] = get_limits(); auto [upper, maxval] = get_limits();
auto out = bits(shape, size_of(dtype), key, stream); auto out = bits(shape, size_of(maxval.dtype()), key, stream);
out = astype(divide(out, maxval, stream), dtype, stream); out = astype(divide(out, maxval, stream), dtype, stream);
out = minimum(out, upper, stream); out = minimum(out, upper, stream);
return add(multiply(range, out, stream), lo, stream); return add(multiply(range, out, stream), lo, stream);