mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
Generalize gpu backend (#2138)
* generalize gpu backend * fix no_gpu build * fix no_gpu build * generalize gpu backend
This commit is contained in:
@@ -10,7 +10,7 @@
|
||||
#include <unordered_set>
|
||||
|
||||
#include "mlx/backend/cpu/eval.h"
|
||||
#include "mlx/backend/metal/metal_impl.h"
|
||||
#include "mlx/backend/gpu/eval.h"
|
||||
#include "mlx/fence.h"
|
||||
#include "mlx/memory.h"
|
||||
#include "mlx/ops.h"
|
||||
@@ -218,7 +218,7 @@ array eval_impl(std::vector<array> outputs, bool async) {
|
||||
}
|
||||
|
||||
if (arr.primitive().device() == Device::gpu) {
|
||||
metal::eval(arr);
|
||||
gpu::eval(arr);
|
||||
} else {
|
||||
cpu::eval(arr);
|
||||
}
|
||||
@@ -229,7 +229,7 @@ array eval_impl(std::vector<array> outputs, bool async) {
|
||||
// Commit any open streams
|
||||
for (auto& [_, e] : events) {
|
||||
if (e.stream().device == Device::gpu) {
|
||||
metal::finalize(e.stream());
|
||||
gpu::finalize(e.stream());
|
||||
}
|
||||
}
|
||||
scheduler::wait_for_one();
|
||||
@@ -267,7 +267,7 @@ array eval_impl(std::vector<array> outputs, bool async) {
|
||||
auto s = e.stream();
|
||||
e.signal(s);
|
||||
if (s.device == Device::gpu) {
|
||||
metal::finalize(s);
|
||||
gpu::finalize(s);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user