mirror of https://github.com/ml-explore/mlx.git (synced 2025-11-01 08:38:12 +08:00)
Use async cuda malloc managed with cuda 13
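With CUDA 13 the runtime exposes a default memory pool for managed allocations. This commit fetches that pool when the allocator is constructed and serves stream-aware requests with cudaMallocFromPoolAsync, falling back to the synchronous cudaMallocManaged when no stream (or no pool) is available. Call sites across the CUDA backend switch from allocator::malloc to the new cu::malloc_async, passing the command encoder's stream. A minimal standalone sketch of the allocation pattern this commit adopts (size, stream setup, and error handling are illustrative placeholders, not MLX code):

    #include <cuda_runtime.h>
    #include <cstdio>

    int main() {
      cudaMemPool_t pool = nullptr;
    #if CUDART_VERSION >= 13000
      // CUDA 13: fetch the default pool for *managed* allocations.
      cudaMemLocation loc;
      loc.id = 0;
      loc.type = cudaMemLocationTypeNone;
      cudaMemGetDefaultMemPool(&pool, &loc, cudaMemAllocationTypeManaged);
    #endif
      cudaStream_t stream;
      cudaStreamCreate(&stream);

      void* data = nullptr;
      cudaError_t err;
      if (pool != nullptr) {
        // Stream-ordered allocation from the managed pool.
        err = cudaMallocFromPoolAsync(&data, 1 << 20, pool, stream);
      } else {
        // Pre-CUDA-13 fallback: synchronous managed allocation.
        err = cudaMallocManaged(&data, 1 << 20);
      }
      printf("alloc: %s\n", cudaGetErrorString(err));

      if (pool != nullptr) {
        cudaFreeAsync(data, stream); // the free is stream-ordered too
      } else {
        cudaFree(data);
      }
      cudaStreamSynchronize(stream);
      cudaStreamDestroy(stream);
      return 0;
    }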
@@ -1,6 +1,7 @@
 // Copyright © 2025 Apple Inc.
 
 #include "mlx/backend/cuda/allocator.h"
+#include "mlx/backend/cuda/device.h"
 #include "mlx/backend/cuda/utils.h"
 #include "mlx/utils.h"
 
@@ -93,9 +94,17 @@ CudaAllocator::CudaAllocator()
   CHECK_CUDA_ERROR(cudaMemGetInfo(&free, &total));
   memory_limit_ = total * 0.95;
   max_pool_size_ = memory_limit_;
+#if CUDART_VERSION >= 13000
+  cudaMemLocation loc;
+  loc.id = 0;
+  loc.type = cudaMemLocationTypeNone;
+  cudaMemGetDefaultMemPool(&cuda_pool_, &loc, cudaMemAllocationTypeManaged);
+  // TODO set that.
+  // uint64_t threshold = UINT64_MAX;
+#endif
 }
 
-Buffer CudaAllocator::malloc(size_t size) {
+Buffer CudaAllocator::malloc_impl(size_t size, cudaStream_t stream) {
   // Find available buffer from cache.
   auto orig_size = size;
   std::unique_lock lock(mutex_);
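The commented-out threshold above presumably refers to the pool's release threshold, which controls how much freed memory the pool caches rather than returning to the driver. If the TODO were completed it would plausibly read as follows (an assumption, not part of this commit):

    // Assumed completion of the TODO: keep freed blocks cached in the
    // pool instead of eagerly releasing them back to the OS.
    uint64_t threshold = UINT64_MAX;
    cudaMemPoolSetAttribute(
        cuda_pool_, cudaMemPoolAttrReleaseThreshold, &threshold);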
@@ -123,7 +132,12 @@ Buffer CudaAllocator::malloc(size_t size) {
   lock.unlock();
   if (!buf) {
     buf = new CudaBuffer{nullptr, size};
-    cudaError_t err = cudaMallocManaged(&buf->data, size);
+    cudaError_t err;
+    if (stream != nullptr && cuda_pool_ != nullptr) {
+      err = cudaMallocFromPoolAsync(&buf->data, size, cuda_pool_, stream);
+    } else {
+      err = cudaMallocManaged(&buf->data, size);
+    }
     if (err != cudaSuccess && err != cudaErrorMemoryAllocation) {
       throw std::runtime_error(fmt::format(
           "cudaMallocManaged failed: {}.", cudaGetErrorString(err)));
@@ -141,6 +155,14 @@ Buffer CudaAllocator::malloc(size_t size) {
   return Buffer{buf};
 }
 
+Buffer CudaAllocator::malloc_async(size_t size, cudaStream_t stream) {
+  return malloc_impl(size, stream);
+}
+
+Buffer CudaAllocator::malloc(size_t size) {
+  return malloc_impl(size, nullptr);
+}
+
 void CudaAllocator::free(Buffer buffer) {
   auto* buf = static_cast<CudaBuffer*>(buffer.ptr());
   if (!buf) {
@@ -220,6 +242,16 @@ CudaAllocator& allocator() {
   return *allocator_;
 }
 
+Buffer malloc_async(size_t size, cudaStream_t stream) {
+  auto buffer = allocator().malloc_async(size, stream);
+  if (size && !buffer.ptr()) {
+    std::ostringstream msg;
+    msg << "[malloc_async] Unable to allocate " << size << " bytes.";
+    throw std::runtime_error(msg.str());
+  }
+  return buffer;
+}
+
 } // namespace cu
 
 namespace allocator {
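The free function cu::malloc_async mirrors the throwing behavior of allocator::malloc, so call sites keep their error handling unchanged. A hedged usage sketch for a temporary buffer, following the pattern the later hunks apply in the GEMM and sort files (names are from this diff; the surrounding context is assumed):

    auto& encoder = cu::get_command_encoder(s);
    array temp(
        cu::malloc_async(nbytes, encoder.stream()),
        {static_cast<int>(nbytes)},
        uint8);
    // The encoder keeps the temporary alive until the enqueued work
    // that uses it has been submitted on its stream.
    encoder.add_temporary(temp);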

@@ -5,6 +5,7 @@
 #include "mlx/allocator.h"
 #include "mlx/backend/common/buffer_cache.h"
 
+#include <cuda_runtime.h>
 #include <mutex>
 #include <set>
 #include <utility>
@@ -45,6 +46,7 @@ class SmallSizePool {
 class CudaAllocator : public allocator::Allocator {
  public:
   Buffer malloc(size_t size) override;
+  Buffer malloc_async(size_t size, cudaStream_t stream);
   void free(Buffer buffer) override;
   size_t size(Buffer buffer) const override;
 
@@ -58,6 +60,7 @@ class CudaAllocator : public allocator::Allocator {
   void clear_cache();
 
  private:
+  Buffer malloc_impl(size_t size, cudaStream_t stream);
   void cuda_free(CudaBuffer* buf);
 
   CudaAllocator();
@@ -70,8 +73,11 @@ class CudaAllocator : public allocator::Allocator {
   size_t active_memory_{0};
   size_t peak_memory_{0};
   SmallSizePool scalar_pool_;
+  cudaMemPool_t cuda_pool_{nullptr};
 };
 
 CudaAllocator& allocator();
 
+Buffer malloc_async(size_t size, cudaStream_t stream);
+
 } // namespace mlx::core::cu
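Condensed, the header changes add one public entry point, one private helper, and the pool handle (a summary sketch, not the verbatim header):

    class CudaAllocator : public allocator::Allocator {
     public:
      Buffer malloc(size_t size) override;                   // malloc_impl(size, nullptr)
      Buffer malloc_async(size_t size, cudaStream_t stream); // malloc_impl(size, stream)

     private:
      Buffer malloc_impl(size_t size, cudaStream_t stream);
      cudaMemPool_t cuda_pool_{nullptr}; // stays null before CUDA 13
    };

    // Throwing convenience wrapper, analogous to allocator::malloc.
    Buffer malloc_async(size_t size, cudaStream_t stream);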

@@ -41,9 +41,8 @@ void Arange::eval_gpu(const std::vector<array>& inputs, array& out) {
   if (out.size() == 0) {
     return;
   }
-  out.set_data(allocator::malloc(out.nbytes()));
-
   auto& encoder = cu::get_command_encoder(stream());
+  out.set_data(cu::malloc_async(out.nbytes(), encoder.stream()));
   encoder.set_output_array(out);
 
   dispatch_int_float_types(out.dtype(), "Arange", [&](auto type_tag) {
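The Arange hunk above shows the call-site pattern repeated through the rest of the commit: create the command encoder first so its stream is available at allocation time, then allocate on that stream so the allocation is ordered before the kernels that consume the buffer:

    auto& encoder = cu::get_command_encoder(stream());
    out.set_data(cu::malloc_async(out.nbytes(), encoder.stream()));
    encoder.set_output_array(out);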

@@ -140,8 +140,10 @@ void ArgReduce::eval_gpu(const std::vector<array>& inputs, array& out) {
   nvtx3::scoped_range r("ArgReduce::eval_gpu");
   assert(inputs.size() == 1);
   auto& in = inputs[0];
-  out.set_data(allocator::malloc(out.nbytes()));
   auto& s = stream();
+  auto& encoder = cu::get_command_encoder(s);
+  out.set_data(cu::malloc_async(out.nbytes(), encoder.stream()));
 
   // Prepare the shapes, strides and axis arguments.
   Shape shape = remove_index(in.shape(), axis_);
@@ -154,7 +156,6 @@ void ArgReduce::eval_gpu(const std::vector<array>& inputs, array& out) {
   int32_t ndim = shape.size();
 
   // ArgReduce.
-  auto& encoder = cu::get_command_encoder(s);
   encoder.set_input_array(in);
   encoder.set_output_array(out);
   dispatch_real_types(in.dtype(), "ArgReduce", [&](auto type_tag) {

@@ -87,8 +87,8 @@ void fill_gpu(const array& in, array& out, const Stream& s) {
   if (out.size() == 0) {
     return;
   }
-  out.set_data(allocator::malloc(out.nbytes()));
   auto& encoder = cu::get_command_encoder(s);
+  out.set_data(cu::malloc_async(out.nbytes(), encoder.stream()));
   encoder.set_input_array(in);
   encoder.set_output_array(out);
   copy_contiguous(encoder, CopyType::Scalar, in, out, 0, 0);

@@ -3,6 +3,7 @@
 #pragma once
 
 #include "mlx/array.h"
+#include "mlx/backend/cuda/allocator.h"
 #include "mlx/backend/cuda/lru_cache.h"
 #include "mlx/backend/cuda/worker.h"
 #include "mlx/stream.h"

@@ -370,7 +370,7 @@ void CublasGemm::execute(
   // Ensure workspace is 256-byte aligned
   int nbytes = cuda::ceil_div(heuristic_.workspaceSize, 256) * 256;
   array workspace(
-      allocator::malloc(nbytes),
+      cu::malloc_async(nbytes, encoder.stream()),
       {static_cast<int>(heuristic_.workspaceSize)},
       int8);
   encoder.add_temporary(workspace);

@@ -163,7 +163,7 @@ void CublasGemm::run_batched(
 
   // Launch kernel to set device offsets
   auto pointers = array(
-      allocator::malloc(batch_count * sizeof(void*) * 3),
+      cu::malloc_async(batch_count * sizeof(void*) * 3, encoder.stream()),
       {batch_count * 3},
       uint64);
 
@@ -251,7 +251,7 @@ void CublasGemm::run_batched(
 
   // Launch kernel to set device offsets
   auto pointers = array(
-      allocator::malloc(batch_count * sizeof(uint64_t) * 4),
+      cu::malloc_async(batch_count * sizeof(uint64_t) * 4, encoder.stream()),
       {batch_count * 4},
       uint64);
 

@@ -59,7 +59,9 @@ void Gather::eval_gpu(const std::vector<array>& inputs, array& out) {
   assert(inputs.size() > 0);
   const auto& src = inputs[0];
 
-  out.set_data(allocator::malloc(out.nbytes()));
+  auto& s = stream();
+  auto& encoder = cu::get_command_encoder(s);
+  out.set_data(cu::malloc_async(out.nbytes(), encoder.stream()));
   if (out.size() == 0) {
     return;
   }
@@ -80,7 +82,6 @@ void Gather::eval_gpu(const std::vector<array>& inputs, array& out) {
       dtype_to_string(idx_dtype),
       nidx);
 
-  auto& s = stream();
   cu::JitModule& mod = cu::get_jit_module(s.device, module_name, [&]() {
     std::vector<std::string> kernel_names;
     for (int ndim = 0; ndim <= MAX_NDIM; ++ndim) {
@@ -121,7 +122,6 @@ void Gather::eval_gpu(const std::vector<array>& inputs, array& out) {
       idx_ndim,
       large ? "int64_t" : "int32_t");
 
-  auto& encoder = cu::get_command_encoder(s);
   for (const auto& in : inputs) {
     encoder.set_input_array(in);
   }
@@ -239,7 +239,9 @@ void GatherAxis::eval_gpu(const std::vector<array>& inputs, array& out) {
   const auto& src = inputs[0];
   const auto& idx = inputs[1];
 
-  out.set_data(allocator::malloc(out.nbytes()));
+  auto& s = stream();
+  auto& encoder = cu::get_command_encoder(s);
+  out.set_data(cu::malloc_async(out.nbytes(), encoder.stream()));
   if (out.size() == 0) {
     return;
   }
@@ -251,7 +253,6 @@ void GatherAxis::eval_gpu(const std::vector<array>& inputs, array& out) {
       dtype_to_string(out.dtype()),
       dtype_to_string(idx.dtype()));
 
-  auto& s = stream();
   cu::JitModule& mod = cu::get_jit_module(s.device, module_name, [&]() {
     std::vector<std::string> kernel_names;
     for (int ndim = 0; ndim <= MAX_NDIM; ++ndim) {
@@ -312,7 +313,6 @@ void GatherAxis::eval_gpu(const std::vector<array>& inputs, array& out) {
       idx.flags().row_contiguous,
       large ? "int64_t" : "int32_t");
 
-  auto& encoder = cu::get_command_encoder(s);
   for (const auto& in : inputs) {
     encoder.set_input_array(in);
   }

@@ -230,9 +230,10 @@ void LayerNorm::eval_gpu(
   nvtx3::scoped_range r("LayerNorm::eval_gpu");
   auto& s = stream();
   auto& out = outputs[0];
+  auto& encoder = cu::get_command_encoder(s);
 
   // Make sure that the last dimension is contiguous.
-  auto set_output = [&s, &out](const array& x) {
+  auto set_output = [&s, &out, &encoder](const array& x) {
     bool no_copy = x.flags().contiguous && x.strides()[x.ndim() - 1] == 1;
     if (no_copy && x.ndim() > 1) {
       auto s = x.strides()[x.ndim() - 2];
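The norm and softmax hunks share one more detail: set_output now captures the encoder so the allocation inside the lambda is also stream-ordered. In miniature (body trimmed to the allocating branch):

    auto set_output = [&s, &out, &encoder](const array& x) {
      // ...contiguity checks elided...
      out.set_data(
          cu::malloc_async(x.data_size() * x.itemsize(), encoder.stream()),
          x.data_size(),
          x.strides(),
          x.flags());
    };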
@@ -243,7 +244,7 @@ void LayerNorm::eval_gpu(
       out.copy_shared_buffer(x);
     } else {
       out.set_data(
-          allocator::malloc(x.data_size() * x.itemsize()),
+          cu::malloc_async(x.data_size() * x.itemsize(), encoder.stream()),
           x.data_size(),
           x.strides(),
           x.flags());
@@ -265,7 +266,6 @@ void LayerNorm::eval_gpu(
   int64_t w_stride = (w.ndim() == 1) ? w.strides()[0] : 0;
   int64_t b_stride = (b.ndim() == 1) ? b.strides()[0] : 0;
 
-  auto& encoder = cu::get_command_encoder(s);
   encoder.set_input_array(x);
   encoder.set_input_array(w);
   encoder.set_input_array(b);
@@ -335,7 +335,7 @@ void LayerNormVJP::eval_gpu(
     gx.copy_shared_buffer(g);
     g_in_gx = true;
   } else {
-    gx.set_data(allocator::malloc(gx.nbytes()));
+    gx.set_data(cu::malloc_async(gx.nbytes(), encoder.stream()));
   }
   if (g_copied && !g_in_gx) {
     encoder.add_temporary(g);
@@ -355,7 +355,7 @@ void LayerNormVJP::eval_gpu(
       g_in_gw = true;
       gw_temp.copy_shared_buffer(g);
     } else {
-      gw_temp.set_data(allocator::malloc(gw_temp.nbytes()));
+      gw_temp.set_data(cu::malloc_async(gw_temp.nbytes(), encoder.stream()));
       encoder.add_temporary(gw_temp);
     }
   }

@@ -115,7 +115,7 @@ void LogSumExp::eval_gpu(const std::vector<array>& inputs, array& out) {
 
   auto in = ensure_contiguous(inputs[0]);
   if (in.flags().row_contiguous) {
-    out.set_data(allocator::malloc(out.nbytes()));
+    out.set_data(cu::malloc_async(out.nbytes(), encoder.stream()));
   } else {
     auto n = in.shape(-1);
     auto flags = in.flags();
@@ -130,7 +130,7 @@ void LogSumExp::eval_gpu(const std::vector<array>& inputs, array& out) {
     }
     flags.col_contiguous = col_contig;
     out.set_data(
-        allocator::malloc(in.nbytes() / n),
+        cu::malloc_async(in.nbytes() / n, encoder.stream()),
        in.data_size() / n,
         std::move(strides),
         flags);

@@ -121,7 +121,7 @@ void Matmul::eval_gpu(const std::vector<array>& inputs, array& out) {
     return;
   }
 
-  out.set_data(allocator::malloc(out.nbytes()));
+  out.set_data(cu::malloc_async(out.nbytes(), encoder.stream()));
 
   int M = a_pre.shape(-2);
   int N = b_pre.shape(-1);
@@ -163,7 +163,7 @@ void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {
 
   if (beta_ == 1 && a.dtype() != complex64 && c.strides(-1) == 1 &&
       c.data_size() == out.shape(-1)) {
-    out.set_data(allocator::malloc(out.nbytes()));
+    out.set_data(cu::malloc_async(out.nbytes(), encoder.stream()));
     gemm_and_bias(
         encoder,
         M,
@@ -187,10 +187,10 @@ void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {
   auto sty = c.strides()[c.ndim() - 1];
   if (sty == 1 && stx == c.shape(-1)) {
     ldc = stx;
-    out.set_data(allocator::malloc(out.nbytes()));
+    out.set_data(cu::malloc_async(out.nbytes(), encoder.stream()));
   } else if (sty == 1 && stx == 0) {
     ldc = 0;
-    out.set_data(allocator::malloc(out.nbytes()));
+    out.set_data(cu::malloc_async(out.nbytes(), encoder.stream()));
   } else {
     // Copy C into out and set C to out
     ldc = c.shape(-1);

@@ -176,9 +176,10 @@ void RMSNorm::eval_gpu(
   nvtx3::scoped_range r("RMSNorm::eval_gpu");
   auto& s = stream();
   auto& out = outputs[0];
+  auto& encoder = cu::get_command_encoder(s);
 
   // Make sure that the last dimension is contiguous.
-  auto set_output = [&s, &out](const array& x) {
+  auto set_output = [&s, &out, &encoder](const array& x) {
     bool no_copy = x.flags().contiguous && x.strides()[x.ndim() - 1] == 1;
     if (no_copy && x.ndim() > 1) {
       auto s = x.strides()[x.ndim() - 2];
@@ -189,7 +190,7 @@ void RMSNorm::eval_gpu(
       out.copy_shared_buffer(x);
     } else {
       out.set_data(
-          allocator::malloc(x.data_size() * x.itemsize()),
+          cu::malloc_async(x.data_size() * x.itemsize(), encoder.stream()),
           x.data_size(),
           x.strides(),
           x.flags());
@@ -209,7 +210,6 @@ void RMSNorm::eval_gpu(
   int32_t n_rows = x.data_size() / axis_size;
   int64_t w_stride = (w.ndim() == 1) ? w.strides()[0] : 0;
 
-  auto& encoder = cu::get_command_encoder(s);
   encoder.set_input_array(x);
   encoder.set_input_array(w);
   encoder.set_output_array(out);
@@ -274,7 +274,7 @@ void RMSNormVJP::eval_gpu(
     gx.copy_shared_buffer(g);
     g_in_gx = true;
   } else {
-    gx.set_data(allocator::malloc(gx.nbytes()));
+    gx.set_data(cu::malloc_async(gx.nbytes(), encoder.stream()));
   }
   if (g_copied && !g_in_gx) {
     encoder.add_temporary(g);
@@ -292,7 +292,7 @@ void RMSNormVJP::eval_gpu(
     if (!g_in_gx && donate_g) {
       gw_temp.copy_shared_buffer(g);
     } else {
-      gw_temp.set_data(allocator::malloc(gw_temp.nbytes()));
+      gw_temp.set_data(cu::malloc_async(gw_temp.nbytes(), encoder.stream()));
       encoder.add_temporary(gw_temp);
     }
   }

@@ -250,6 +250,7 @@ void RoPE::eval_gpu(
   nvtx3::scoped_range r("RoPE::eval_gpu");
 
   auto& s = stream();
+  auto& encoder = cu::get_command_encoder(s);
   auto& in = inputs[0];
   auto& offset = inputs[1];
   auto& out = outputs[0];
@@ -291,14 +292,14 @@ void RoPE::eval_gpu(
       donated = true;
       out.copy_shared_buffer(in);
     } else {
-      out.set_data(allocator::malloc(out.nbytes()));
+      out.set_data(cu::malloc_async(out.nbytes(), encoder.stream()));
     }
     strides[0] = mat_size;
     strides[1] = in.strides()[ndim - 2];
     strides[2] = in.strides()[ndim - 1];
   } else if (dispatch_ndim == 3) {
     // Handle non-contiguous 3D inputs
-    out.set_data(allocator::malloc(out.nbytes()));
+    out.set_data(cu::malloc_async(out.nbytes(), encoder.stream()));
     strides[0] = in.strides()[ndim - 3];
     strides[1] = in.strides()[ndim - 2];
     strides[2] = in.strides()[ndim - 1];
@@ -319,7 +320,6 @@ void RoPE::eval_gpu(
   bool single = in.flags().row_contiguous && B == 1 && T == 1;
   bool with_freqs = inputs.size() == 3;
 
-  auto& encoder = cu::get_command_encoder(s);
   encoder.set_input_array(donated ? out : in);
   encoder.set_input_array(offset);
   if (with_freqs) {

@@ -565,9 +565,10 @@ void sdpa_vector_2pass_fallback(
   array sums(intermediate_shape, float32, nullptr, {});
   array maxs(std::move(intermediate_shape), float32, nullptr, {});
 
-  intermediate.set_data(allocator::malloc(intermediate.nbytes()));
-  sums.set_data(allocator::malloc(sums.nbytes()));
-  maxs.set_data(allocator::malloc(maxs.nbytes()));
+  intermediate.set_data(
+      cu::malloc_async(intermediate.nbytes(), encoder.stream()));
+  sums.set_data(cu::malloc_async(sums.nbytes(), encoder.stream()));
+  maxs.set_data(cu::malloc_async(maxs.nbytes(), encoder.stream()));
 
   encoder.add_temporary(intermediate);
   encoder.add_temporary(sums);
@@ -787,7 +788,7 @@ void ScaledDotProductAttention::eval_gpu(
   };
 
   o.set_data(
-      allocator::malloc(o.nbytes()),
+      cu::malloc_async(o.nbytes(), encoder.stream()),
       o.size(),
       {str_oB, str_oH, str_oL, str_oD},
       flags);

@@ -367,13 +367,14 @@ void Scan::eval_gpu(const std::vector<array>& inputs, array& out) {
   assert(inputs.size() == 1);
   auto in = inputs[0];
   auto& s = stream();
+  auto& encoder = cu::get_command_encoder(s);
 
   if (in.flags().contiguous && in.strides()[axis_] != 0) {
     if (in.is_donatable() && in.itemsize() == out.itemsize()) {
       out.copy_shared_buffer(in);
     } else {
       out.set_data(
-          allocator::malloc(in.data_size() * out.itemsize()),
+          cu::malloc_async(in.data_size() * out.itemsize(), encoder.stream()),
           in.data_size(),
           in.strides(),
           in.flags());
@@ -387,7 +388,6 @@ void Scan::eval_gpu(const std::vector<array>& inputs, array& out) {
   int32_t axis_size = in.shape(axis_);
   bool contiguous = in.strides()[axis_] == 1;
 
-  auto& encoder = cu::get_command_encoder(s);
   encoder.set_input_array(in);
   encoder.set_output_array(out);
 

@@ -23,14 +23,15 @@ void concatenate_gpu(
   }
   std::partial_sum(sizes.cbegin(), sizes.cend(), sizes.begin());
 
-  out.set_data(allocator::malloc(out.nbytes()));
+  auto& encoder = cu::get_command_encoder(s);
+  out.set_data(cu::malloc_async(out.nbytes(), encoder.stream()));
 
   auto strides = out.strides();
   auto flags = out.flags();
   flags.row_contiguous = false;
   flags.col_contiguous = false;
   flags.contiguous = false;
-  auto concurrent = cu::get_command_encoder(s).concurrent_context();
+  auto concurrent = encoder.concurrent_context();
   for (int i = 0; i < inputs.size(); i++) {
     array out_slice(inputs[i].shape(), out.dtype(), nullptr, {});
     size_t data_offset = strides[axis] * sizes[i];
@@ -80,6 +81,7 @@ array compute_dynamic_offset(
     return std::make_tuple(false, std::move(source), std::vector{kernel_name});
   });
 
+  auto& encoder = cu::get_command_encoder(s);
   // Prepare output.
   array offset({1}, int64, nullptr, {});
   bool donate = indices.is_donatable() &&
@@ -87,10 +89,9 @@ array compute_dynamic_offset(
   if (donate) {
     offset.copy_shared_buffer(indices);
   } else {
-    offset.set_data(allocator::malloc(offset.itemsize()));
+    offset.set_data(cu::malloc_async(offset.itemsize(), encoder.stream()));
   }
 
-  auto& encoder = cu::get_command_encoder(s);
   encoder.add_temporary(offset);
   encoder.set_input_array(indices);
   encoder.set_output_array(offset);

@@ -109,15 +109,16 @@ void Softmax::eval_gpu(const std::vector<array>& inputs, array& out) {
   nvtx3::scoped_range r("Softmax::eval_gpu");
   assert(inputs.size() == 1);
   auto& s = stream();
+  auto& encoder = cu::get_command_encoder(s);
 
   // Make sure that the last dimension is contiguous.
-  auto set_output = [&s, &out](const array& x) {
+  auto set_output = [&s, &out, &encoder](const array& x) {
     if (x.flags().contiguous && x.strides()[x.ndim() - 1] == 1) {
       if (x.is_donatable()) {
         out.copy_shared_buffer(x);
       } else {
         out.set_data(
-            allocator::malloc(x.data_size() * x.itemsize()),
+            cu::malloc_async(x.data_size() * x.itemsize(), encoder.stream()),
             x.data_size(),
             x.strides(),
             x.flags());
@@ -136,7 +137,6 @@ void Softmax::eval_gpu(const std::vector<array>& inputs, array& out) {
   int axis_size = in.shape().back();
   int n_rows = in.data_size() / axis_size;
 
-  auto& encoder = cu::get_command_encoder(s);
   encoder.set_input_array(in);
   encoder.set_output_array(out);
   dispatch_float_types(out.dtype(), "softmax", [&](auto type_tag) {

@@ -49,11 +49,14 @@ void gpu_sort(const Stream& s, array in, array& out_, int axis, bool argsort) {
     array trans = swapaxes_in_eval(in, axis, last_dim);
     in = contiguous_copy_gpu(trans, s);
     encoder.add_temporary(in);
-    out = array(allocator::malloc(out.nbytes()), in.shape(), out.dtype());
+    out = array(
+        cu::malloc_async(out.nbytes(), encoder.stream()),
+        in.shape(),
+        out.dtype());
     encoder.add_temporary(out);
   } else {
     out.set_data(
-        allocator::malloc(in.data_size() * out.itemsize()),
+        cu::malloc_async(in.data_size() * out.itemsize(), encoder.stream()),
         in.data_size(),
         in.strides(),
         in.flags());
@@ -70,12 +73,18 @@ void gpu_sort(const Stream& s, array in, array& out_, int axis, bool argsort) {
       thrust::make_counting_iterator(0), OffsetTransform{nsort});
   if (argsort) {
     // Indices in the sorted dimension.
-    array indices(allocator::malloc(out.nbytes()), in.shape(), out.dtype());
+    array indices(
+        cu::malloc_async(out.nbytes(), encoder.stream()),
+        in.shape(),
+        out.dtype());
     encoder.add_temporary(indices);
 
     // In argsort though we don't need the result of sorted values, the
     // API requires us to provide an array to store it.
-    array discard(allocator::malloc(in.nbytes()), in.shape(), in.dtype());
+    array discard(
+        cu::malloc_async(in.nbytes(), encoder.stream()),
+        in.shape(),
+        in.dtype());
     encoder.add_temporary(discard);
 
     size_t size;
@@ -94,7 +103,10 @@ void gpu_sort(const Stream& s, array in, array& out_, int axis, bool argsort) {
         sizeof(Type) * 8,
         stream));
 
-    array temp(allocator::malloc(size), {static_cast<int>(size)}, uint8);
+    array temp(
+        cu::malloc_async(size, encoder.stream()),
+        {static_cast<int>(size)},
+        uint8);
     encoder.add_temporary(temp);
 
     // Start capturing after allocations
@@ -135,7 +147,10 @@ void gpu_sort(const Stream& s, array in, array& out_, int axis, bool argsort) {
         sizeof(Type) * 8,
         stream));
 
-    array temp(allocator::malloc(size), {static_cast<int>(size)}, uint8);
+    array temp(
+        cu::malloc_async(size, encoder.stream()),
+        {static_cast<int>(size)},
+        uint8);
     encoder.add_temporary(temp);
 
     // Start capturing after allocations
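Note the "// Start capturing after allocations" context in the sort hunks: every cu::malloc_async stays ahead of the capture region, so only the kernel work is recorded. A sketch of that ordering with the stream-capture API (the capture calls are inferred from the comment, not shown in this diff; pool and size are placeholders):

    #include <cuda_runtime.h>

    void sort_with_capture(cudaStream_t stream, cudaMemPool_t pool, size_t size) {
      // Stream-ordered allocation happens before capture begins.
      void* temp = nullptr;
      cudaMallocFromPoolAsync(&temp, size, pool, stream);

      cudaGraph_t graph;
      cudaStreamBeginCapture(stream, cudaStreamCaptureModeGlobal);
      // ...enqueue the sort passes that read/write temp...
      cudaStreamEndCapture(stream, &graph);

      cudaGraphExec_t exec;
      cudaGraphInstantiate(&exec, graph, 0);
      cudaGraphLaunch(exec, stream);

      // Release the temporary in stream order, after the graph's work.
      cudaFreeAsync(temp, stream);
      cudaGraphExecDestroy(exec);
      cudaGraphDestroy(graph);
    }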