// Mirrored from https://github.com/ml-explore/mlx.git (synced 2025-12-16).
// Upstream change summary: use async cudaMallocManaged with CUDA 13; add a
// pool threshold; refactor for regular cudaMalloc; load eval_gpu for CUDA;
// remove use of the CUDA pool in favor of cudaFreeAsync.
// Copyright © 2025 Apple Inc.
#include <stdexcept>

#include "mlx/distributed/primitives.h"
#include "mlx/fast_primitives.h"
#include "mlx/primitives.h"
namespace mlx::core {
|
|
|
|
#define NO_GPU_MULTI(func) \
|
|
void func::eval_gpu( \
|
|
const std::vector<array>& inputs, std::vector<array>& outputs) { \
|
|
throw std::runtime_error(#func " has no CUDA implementation."); \
|
|
}
|
|
|
|
#define NO_GPU_USE_FALLBACK(func) \
|
|
bool func::use_fallback(Stream s) { \
|
|
return true; \
|
|
} \
|
|
NO_GPU_MULTI(func)
|
|
|
|
#define NO_GPU(func) \
|
|
void func::eval_gpu(const std::vector<array>& inputs, array& out) { \
|
|
throw std::runtime_error(#func " has no CUDA implementation."); \
|
|
}
|
|
|
|
NO_GPU(BlockMaskedMM)
|
|
NO_GPU(FFT)
|
|
NO_GPU(GatherMM)
|
|
NO_GPU(GatherQMM)
|
|
NO_GPU(Hadamard)
|
|
NO_GPU_MULTI(LUF)
|
|
NO_GPU_MULTI(QRF)
|
|
NO_GPU(QuantizedMatmul)
|
|
NO_GPU(SegmentedMM)
|
|
NO_GPU_MULTI(SVD)
|
|
NO_GPU(Inverse)
|
|
NO_GPU(Cholesky)
|
|
NO_GPU_MULTI(Eig)
|
|
NO_GPU_MULTI(Eigh)
|
|
|
|
namespace distributed {
|
|
NO_GPU_MULTI(Send)
|
|
NO_GPU_MULTI(Recv)
|
|
} // namespace distributed
|
|
|
|
} // namespace mlx::core
|