Use tf32 for conv

This commit is contained in:
Cheng
2025-07-22 18:16:24 -07:00
parent f189face9d
commit ada7d518da

View File

@@ -150,11 +150,8 @@ cudnn_frontend::EngineConfigList get_engine_configs(
             CUDNN_NUMERICAL_NOTE_DOWN_CONVERT_INPUTS>(c)) {
       return true;
     }
-    // In PyTorch, tf32 seems to be always turned off for convolution even with
-    // "torch.backends.cudnn.allow_tf32 = True", so we are disabling tf32 too to
-    // keep results same.
-    if (dtype == float32 &&
-        cudnn_frontend::hasNumericalNote<CUDNN_NUMERICAL_NOTE_TENSOR_CORE>(c)) {
+    if (cudnn_frontend::hasNumericalNote<CUDNN_NUMERICAL_NOTE_TENSOR_CORE>(c) &&
+        dtype == float32 && !env::enable_tf32()) {
       return true;
     }
     return false;
@@ -286,8 +283,7 @@ void Convolution::eval_gpu(const std::vector<array>& inputs, array& out) {
       get_alignment(in),
       get_alignment(wt),
       get_alignment(out)};
-  auto it = conv_cache().find(cache_key);
-  if (it != conv_cache().end()) {
+  if (auto it = conv_cache().find(cache_key); it != conv_cache().end()) {
     if (!execute_plan(encoder, it->second, in, wt, out)) {
       throw std::runtime_error("Cached convolution plan failed to execute.");
     }