mirror of
https://github.com/ml-explore/mlx.git
synced 2025-08-28 00:36:32 +08:00
[CUDA] Update cudaMemAdvise and cudaGraphAddDependencies for CUDA 13
These functions' signatures changed in CUDA 13, so we differentiate between CUDA 13 and preceding releases at compile time.
This commit is contained in:
parent
9392fc3f88
commit
9c3259cb5c
@ -30,8 +30,15 @@ SmallSizePool::SmallSizePool() {
|
||||
next_free_ = buffer_;
|
||||
|
||||
CHECK_CUDA_ERROR(cudaMallocManaged(&data_, small_pool_size));
|
||||
#if CUDART_VERSION >= 13000
|
||||
cudaMemLocation loc;
|
||||
loc.type = cudaMemLocationTypeDevice;
|
||||
loc.id = 0;
|
||||
#else
|
||||
int loc = 0;
|
||||
#endif // CUDART_VERSION >= 13000
|
||||
CHECK_CUDA_ERROR(
|
||||
cudaMemAdvise(data_, small_pool_size, cudaMemAdviseSetReadMostly, 0));
|
||||
cudaMemAdvise(data_, small_pool_size, cudaMemAdviseSetReadMostly, loc));
|
||||
|
||||
auto curr = next_free_;
|
||||
for (size_t i = 1; i < num_blocks; ++i) {
|
||||
|
@ -269,7 +269,13 @@ void CommandEncoder::commit() {
|
||||
if (node_count_ > 0) {
|
||||
if (!from_nodes_.empty()) {
|
||||
CHECK_CUDA_ERROR(cudaGraphAddDependencies(
|
||||
graph_, from_nodes_.data(), to_nodes_.data(), from_nodes_.size()));
|
||||
graph_,
|
||||
from_nodes_.data(),
|
||||
to_nodes_.data(),
|
||||
#if CUDART_VERSION >= 13000
|
||||
nullptr, // edgeData
|
||||
#endif // CUDART_VERSION >= 13000
|
||||
from_nodes_.size()));
|
||||
}
|
||||
|
||||
graph_key_ += ".";
|
||||
|
Loading…
Reference in New Issue
Block a user