mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
Turn off tf32
This commit is contained in:
@@ -80,7 +80,8 @@ inline cudnn_frontend::Tensor build_tensor(int64_t id, const array& x) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
cudnn_frontend::EngineConfigList get_engine_configs(
|
cudnn_frontend::EngineConfigList get_engine_configs(
|
||||||
const cudnnBackendDescriptorType_t& backend_type,
|
cudnnBackendDescriptorType_t backend_type,
|
||||||
|
Dtype dtype,
|
||||||
cudnn_frontend::OperationGraph& op_graph,
|
cudnn_frontend::OperationGraph& op_graph,
|
||||||
bool use_fallback = false) {
|
bool use_fallback = false) {
|
||||||
cudnn_frontend::GeneratorSource source;
|
cudnn_frontend::GeneratorSource source;
|
||||||
@@ -103,7 +104,24 @@ cudnn_frontend::EngineConfigList get_engine_configs(
|
|||||||
}
|
}
|
||||||
|
|
||||||
cudnn_frontend::EngineConfigGenerator generator(1, &source);
|
cudnn_frontend::EngineConfigGenerator generator(1, &source);
|
||||||
return generator.generate_engine_config(op_graph);
|
auto configs = generator.generate_engine_config(op_graph);
|
||||||
|
|
||||||
|
cudnn_frontend::EngineConfigList filtered_configs;
|
||||||
|
cudnn_frontend::filter(configs, filtered_configs, [dtype](auto c) {
|
||||||
|
if (cudnn_frontend::hasNumericalNote<
|
||||||
|
CUDNN_NUMERICAL_NOTE_DOWN_CONVERT_INPUTS>(c)) {
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
// In PyTorch, tf32 seems to be always turned off for convolution even with
|
||||||
|
// "torch.backends.cudnn.allow_tf32 = True", so we are disabling tf32 too to
|
||||||
|
// keep results same.
|
||||||
|
if (dtype == float32 &&
|
||||||
|
cudnn_frontend::hasNumericalNote<CUDNN_NUMERICAL_NOTE_TENSOR_CORE>(c)) {
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
return false;
|
||||||
|
});
|
||||||
|
return filtered_configs;
|
||||||
}
|
}
|
||||||
|
|
||||||
bool execute_plan(
|
bool execute_plan(
|
||||||
@@ -238,12 +256,12 @@ void Convolution::eval_gpu(const std::vector<array>& inputs, array& out) {
|
|||||||
.build();
|
.build();
|
||||||
|
|
||||||
// Try to run plans based on heuristics.
|
// Try to run plans based on heuristics.
|
||||||
auto configs = get_engine_configs(backend_type, op_graph);
|
auto configs = get_engine_configs(backend_type, in.dtype(), op_graph);
|
||||||
if (execute_plans(encoder, configs, op_graph.getTag(), in, wt, out)) {
|
if (execute_plans(encoder, configs, op_graph.getTag(), in, wt, out)) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
// Then try fallback plans.
|
// Then try fallback plans.
|
||||||
configs = get_engine_configs(backend_type, op_graph, /* use_fallback */ true);
|
configs = get_engine_configs(backend_type, in.dtype(), op_graph);
|
||||||
if (execute_plans(encoder, configs, op_graph.getTag(), in, wt, out)) {
|
if (execute_plans(encoder, configs, op_graph.getTag(), in, wt, out)) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -15,16 +15,11 @@ cuda_skip = {
|
|||||||
"TestOps.test_hadamard_grad_vmap",
|
"TestOps.test_hadamard_grad_vmap",
|
||||||
# Convolutions NYI
|
# Convolutions NYI
|
||||||
"TestConv.test_1d_conv_with_2d",
|
"TestConv.test_1d_conv_with_2d",
|
||||||
"TestConv.test_asymmetric_padding",
|
|
||||||
"TestConv.test_conv_1d_groups_flipped",
|
"TestConv.test_conv_1d_groups_flipped",
|
||||||
"TestConv.test_conv_general_flip_grad",
|
"TestConv.test_conv_general_flip_grad",
|
||||||
"TestConv.test_conv_groups_grad",
|
"TestConv.test_conv_groups_grad",
|
||||||
"TestConv.test_numpy_conv",
|
|
||||||
"TestConv.test_torch_conv_1D",
|
|
||||||
"TestConv.test_torch_conv_1D_grad",
|
"TestConv.test_torch_conv_1D_grad",
|
||||||
"TestConv.test_torch_conv_2D",
|
|
||||||
"TestConv.test_torch_conv_2D_grad",
|
"TestConv.test_torch_conv_2D_grad",
|
||||||
"TestConv.test_torch_conv_3D",
|
|
||||||
"TestConv.test_torch_conv_3D_grad",
|
"TestConv.test_torch_conv_3D_grad",
|
||||||
"TestConv.test_torch_conv_depthwise",
|
"TestConv.test_torch_conv_depthwise",
|
||||||
"TestConv.test_torch_conv_general",
|
"TestConv.test_torch_conv_general",
|
||||||
|
|||||||
Reference in New Issue
Block a user