CUDA backend: backbone (#2075)

This commit is contained in:
Cheng
2025-05-07 13:26:46 +09:00
committed by GitHub
parent 5a1a5d5ed1
commit 0cae0bdac8
22 changed files with 1582 additions and 2 deletions

36
mlx/backend/cuda/utils.h Normal file
View File

@@ -0,0 +1,36 @@
// Copyright © 2025 Apple Inc.
#pragma once
#include <cuda_runtime.h>
namespace mlx::core {
namespace cu {
class Device;
}
// Cuda stream managed with RAII.
class CudaStream {
public:
explicit CudaStream(cu::Device& device);
~CudaStream();
CudaStream(const CudaStream&) = delete;
CudaStream& operator=(const CudaStream&) = delete;
operator cudaStream_t() const {
return stream_;
}
private:
cudaStream_t stream_;
};
// Throw exception if the cuda API does not succeed.
void check_cuda_error(const char* name, cudaError_t err);
// The macro version that prints the command that failed.
#define CHECK_CUDA_ERROR(cmd) check_cuda_error(#cmd, (cmd))
} // namespace mlx::core