diff --git a/mlx/backend/cuda/conv.cpp b/mlx/backend/cuda/conv.cpp
index 53ef6d06c..578b520a0 100644
--- a/mlx/backend/cuda/conv.cpp
+++ b/mlx/backend/cuda/conv.cpp
@@ -150,11 +150,8 @@ cudnn_frontend::EngineConfigList get_engine_configs(
           CUDNN_NUMERICAL_NOTE_DOWN_CONVERT_INPUTS>(c)) {
     return true;
   }
-  // In PyTorch, tf32 seems to be always turned off for convolution even with
-  // "torch.backends.cudnn.allow_tf32 = True", so we are disabling tf32 too to
-  // keep results same.
-  if (dtype == float32 &&
-      cudnn_frontend::hasNumericalNote<CUDNN_NUMERICAL_NOTE_TENSOR_CORE>(c)) {
+  if (cudnn_frontend::hasNumericalNote<CUDNN_NUMERICAL_NOTE_TENSOR_CORE>(c) &&
+      dtype == float32 && !env::enable_tf32()) {
     return true;
   }
   return false;
@@ -286,8 +283,7 @@ void Convolution::eval_gpu(const std::vector<array>& inputs, array& out) {
       get_alignment(in),
       get_alignment(wt),
       get_alignment(out)};
-  auto it = conv_cache().find(cache_key);
-  if (it != conv_cache().end()) {
+  if (auto it = conv_cache().find(cache_key); it != conv_cache().end()) {
     if (!execute_plan(encoder, it->second, in, wt, out)) {
       throw std::runtime_error("Cached convolution plan failed to execute.");
     }