Turn off tf32

This commit is contained in:
Cheng
2025-07-19 19:49:29 -07:00
parent 6444b29651
commit 0430a6a74a
2 changed files with 22 additions and 9 deletions

View File

@@ -80,7 +80,8 @@ inline cudnn_frontend::Tensor build_tensor(int64_t id, const array& x) {
}
cudnn_frontend::EngineConfigList get_engine_configs(
const cudnnBackendDescriptorType_t& backend_type,
cudnnBackendDescriptorType_t backend_type,
Dtype dtype,
cudnn_frontend::OperationGraph& op_graph,
bool use_fallback = false) {
cudnn_frontend::GeneratorSource source;
@@ -103,7 +104,24 @@ cudnn_frontend::EngineConfigList get_engine_configs(
}
cudnn_frontend::EngineConfigGenerator generator(1, &source);
return generator.generate_engine_config(op_graph);
auto configs = generator.generate_engine_config(op_graph);
cudnn_frontend::EngineConfigList filtered_configs;
cudnn_frontend::filter(configs, filtered_configs, [dtype](auto c) {
if (cudnn_frontend::hasNumericalNote<
CUDNN_NUMERICAL_NOTE_DOWN_CONVERT_INPUTS>(c)) {
return true;
}
// In PyTorch, tf32 seems to always be turned off for convolution, even with
// "torch.backends.cudnn.allow_tf32 = True", so we disable tf32 here as well to
// keep the results the same.
if (dtype == float32 &&
cudnn_frontend::hasNumericalNote<CUDNN_NUMERICAL_NOTE_TENSOR_CORE>(c)) {
return true;
}
return false;
});
return filtered_configs;
}
bool execute_plan(
@@ -238,12 +256,12 @@ void Convolution::eval_gpu(const std::vector<array>& inputs, array& out) {
.build();
// Try to run plans based on heuristics.
auto configs = get_engine_configs(backend_type, op_graph);
auto configs = get_engine_configs(backend_type, in.dtype(), op_graph);
if (execute_plans(encoder, configs, op_graph.getTag(), in, wt, out)) {
return;
}
// Then try fallback plans.
configs = get_engine_configs(backend_type, op_graph, /* use_fallback */ true);
// use_fallback must be passed explicitly: it defaults to false, which would
// just regenerate the heuristic configs that were already tried above.
configs = get_engine_configs(
    backend_type, in.dtype(), op_graph, /* use_fallback */ true);
if (execute_plans(encoder, configs, op_graph.getTag(), in, wt, out)) {
return;
}

View File

@@ -15,16 +15,11 @@ cuda_skip = {
"TestOps.test_hadamard_grad_vmap",
# Convolutions NYI
"TestConv.test_1d_conv_with_2d",
"TestConv.test_asymmetric_padding",
"TestConv.test_conv_1d_groups_flipped",
"TestConv.test_conv_general_flip_grad",
"TestConv.test_conv_groups_grad",
"TestConv.test_numpy_conv",
"TestConv.test_torch_conv_1D",
"TestConv.test_torch_conv_1D_grad",
"TestConv.test_torch_conv_2D",
"TestConv.test_torch_conv_2D_grad",
"TestConv.test_torch_conv_3D",
"TestConv.test_torch_conv_3D_grad",
"TestConv.test_torch_conv_depthwise",
"TestConv.test_torch_conv_general",