Generalize gpu backend (#2138)

* generalize gpu backend

* fix no_gpu build

* fix no_gpu build

* generalize gpu backend
This commit is contained in:
Awni Hannun
2025-04-30 09:08:17 -07:00
committed by GitHub
parent 87720a8908
commit f1606486d2
33 changed files with 275 additions and 200 deletions

View File

@@ -10,7 +10,7 @@
#include <unordered_set>
#include "mlx/backend/cpu/eval.h"
#include "mlx/backend/metal/metal_impl.h"
#include "mlx/backend/gpu/eval.h"
#include "mlx/fence.h"
#include "mlx/memory.h"
#include "mlx/ops.h"
@@ -218,7 +218,7 @@ array eval_impl(std::vector<array> outputs, bool async) {
}
if (arr.primitive().device() == Device::gpu) {
metal::eval(arr);
gpu::eval(arr);
} else {
cpu::eval(arr);
}
@@ -229,7 +229,7 @@ array eval_impl(std::vector<array> outputs, bool async) {
// Commit any open streams
for (auto& [_, e] : events) {
if (e.stream().device == Device::gpu) {
metal::finalize(e.stream());
gpu::finalize(e.stream());
}
}
scheduler::wait_for_one();
@@ -267,7 +267,7 @@ array eval_impl(std::vector<array> outputs, bool async) {
auto s = e.stream();
e.signal(s);
if (s.device == Device::gpu) {
metal::finalize(s);
gpu::finalize(s);
}
}