Mirror of https://github.com/ml-explore/mlx.git (synced 2025-09-01 04:24:36 +08:00).
Commit: "Use tf32 for conv"
This commit is contained in:
@@ -150,11 +150,8 @@ cudnn_frontend::EngineConfigList get_engine_configs(
           CUDNN_NUMERICAL_NOTE_DOWN_CONVERT_INPUTS>(c)) {
     return true;
   }
-  // In PyTorch, tf32 seems to be always turned off for convolution even with
-  // "torch.backends.cudnn.allow_tf32 = True", so we are disabling tf32 too to
-  // keep results same.
-  if (dtype == float32 &&
-      cudnn_frontend::hasNumericalNote<CUDNN_NUMERICAL_NOTE_TENSOR_CORE>(c)) {
+  if (cudnn_frontend::hasNumericalNote<CUDNN_NUMERICAL_NOTE_TENSOR_CORE>(c) &&
+      dtype == float32 && !env::enable_tf32()) {
     return true;
   }
   return false;
@@ -286,8 +283,7 @@ void Convolution::eval_gpu(const std::vector<array>& inputs, array& out) {
       get_alignment(in),
       get_alignment(wt),
       get_alignment(out)};
-  auto it = conv_cache().find(cache_key);
-  if (it != conv_cache().end()) {
+  if (auto it = conv_cache().find(cache_key); it != conv_cache().end()) {
     if (!execute_plan(encoder, it->second, in, wt, out)) {
       throw std::runtime_error("Cached convolution plan failed to execute.");
     }
Reference in New Issue
Block a user