mirror of
https://github.com/ml-explore/mlx.git
synced 2025-09-12 15:24:57 +08:00
Use tf32 for conv
This commit is contained in:
@@ -150,11 +150,8 @@ cudnn_frontend::EngineConfigList get_engine_configs(
|
|||||||
CUDNN_NUMERICAL_NOTE_DOWN_CONVERT_INPUTS>(c)) {
|
CUDNN_NUMERICAL_NOTE_DOWN_CONVERT_INPUTS>(c)) {
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
// In PyTorch, tf32 seems to be always turned off for convolution even with
|
if (cudnn_frontend::hasNumericalNote<CUDNN_NUMERICAL_NOTE_TENSOR_CORE>(c) &&
|
||||||
// "torch.backends.cudnn.allow_tf32 = True", so we are disabling tf32 too to
|
dtype == float32 && !env::enable_tf32()) {
|
||||||
// keep results same.
|
|
||||||
if (dtype == float32 &&
|
|
||||||
cudnn_frontend::hasNumericalNote<CUDNN_NUMERICAL_NOTE_TENSOR_CORE>(c)) {
|
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
return false;
|
return false;
|
||||||
@@ -286,8 +283,7 @@ void Convolution::eval_gpu(const std::vector<array>& inputs, array& out) {
|
|||||||
get_alignment(in),
|
get_alignment(in),
|
||||||
get_alignment(wt),
|
get_alignment(wt),
|
||||||
get_alignment(out)};
|
get_alignment(out)};
|
||||||
auto it = conv_cache().find(cache_key);
|
if (auto it = conv_cache().find(cache_key); it != conv_cache().end()) {
|
||||||
if (it != conv_cache().end()) {
|
|
||||||
if (!execute_plan(encoder, it->second, in, wt, out)) {
|
if (!execute_plan(encoder, it->second, in, wt, out)) {
|
||||||
throw std::runtime_error("Cached convolution plan failed to execute.");
|
throw std::runtime_error("Cached convolution plan failed to execute.");
|
||||||
}
|
}
|
||||||
|
Reference in New Issue
Block a user