mirror of
https://github.com/ml-explore/mlx.git
synced 2025-09-01 12:49:44 +08:00
Turn off tf32
This commit is contained in:
@@ -80,7 +80,8 @@ inline cudnn_frontend::Tensor build_tensor(int64_t id, const array& x) {
|
||||
}
|
||||
|
||||
cudnn_frontend::EngineConfigList get_engine_configs(
|
||||
const cudnnBackendDescriptorType_t& backend_type,
|
||||
cudnnBackendDescriptorType_t backend_type,
|
||||
Dtype dtype,
|
||||
cudnn_frontend::OperationGraph& op_graph,
|
||||
bool use_fallback = false) {
|
||||
cudnn_frontend::GeneratorSource source;
|
||||
@@ -103,7 +104,24 @@ cudnn_frontend::EngineConfigList get_engine_configs(
|
||||
}
|
||||
|
||||
cudnn_frontend::EngineConfigGenerator generator(1, &source);
|
||||
return generator.generate_engine_config(op_graph);
|
||||
auto configs = generator.generate_engine_config(op_graph);
|
||||
|
||||
cudnn_frontend::EngineConfigList filtered_configs;
|
||||
cudnn_frontend::filter(configs, filtered_configs, [dtype](auto c) {
|
||||
if (cudnn_frontend::hasNumericalNote<
|
||||
CUDNN_NUMERICAL_NOTE_DOWN_CONVERT_INPUTS>(c)) {
|
||||
return true;
|
||||
}
|
||||
// In PyTorch, tf32 seems to be always turned off for convolution even with
|
||||
// "torch.backends.cudnn.allow_tf32 = True", so we are disabling tf32 too to
|
||||
// keep results same.
|
||||
if (dtype == float32 &&
|
||||
cudnn_frontend::hasNumericalNote<CUDNN_NUMERICAL_NOTE_TENSOR_CORE>(c)) {
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
});
|
||||
return filtered_configs;
|
||||
}
|
||||
|
||||
bool execute_plan(
|
||||
@@ -238,12 +256,12 @@ void Convolution::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
.build();
|
||||
|
||||
// Try to run plans based on heuristics.
|
||||
auto configs = get_engine_configs(backend_type, op_graph);
|
||||
auto configs = get_engine_configs(backend_type, in.dtype(), op_graph);
|
||||
if (execute_plans(encoder, configs, op_graph.getTag(), in, wt, out)) {
|
||||
return;
|
||||
}
|
||||
// Then try fallback plans.
|
||||
configs = get_engine_configs(backend_type, op_graph, /* use_fallback */ true);
|
||||
configs = get_engine_configs(backend_type, in.dtype(), op_graph);
|
||||
if (execute_plans(encoder, configs, op_graph.getTag(), in, wt, out)) {
|
||||
return;
|
||||
}
|
||||
|
@@ -15,16 +15,11 @@ cuda_skip = {
|
||||
"TestOps.test_hadamard_grad_vmap",
|
||||
# Convolutions NYI
|
||||
"TestConv.test_1d_conv_with_2d",
|
||||
"TestConv.test_asymmetric_padding",
|
||||
"TestConv.test_conv_1d_groups_flipped",
|
||||
"TestConv.test_conv_general_flip_grad",
|
||||
"TestConv.test_conv_groups_grad",
|
||||
"TestConv.test_numpy_conv",
|
||||
"TestConv.test_torch_conv_1D",
|
||||
"TestConv.test_torch_conv_1D_grad",
|
||||
"TestConv.test_torch_conv_2D",
|
||||
"TestConv.test_torch_conv_2D_grad",
|
||||
"TestConv.test_torch_conv_3D",
|
||||
"TestConv.test_torch_conv_3D_grad",
|
||||
"TestConv.test_torch_conv_depthwise",
|
||||
"TestConv.test_torch_conv_general",
|
||||
|
Reference in New Issue
Block a user