mirror of https://github.com/ml-explore/mlx.git (synced 2025-12-16 01:49:05 +08:00)

refactor for regular cuda malloc
@@ -64,7 +64,7 @@ array array::unsafe_weak_copy(const array& other) {
       other.strides(),
       other.flags(),
       [](auto) {});
-  cpy.array_desc_->data_ptr = other.array_desc_->data_ptr;
+  cpy.array_desc_->offset = other.array_desc_->offset;
   return cpy;
 }
 
@@ -141,7 +141,7 @@ bool array::is_tracer() const {
 
 void array::set_data(allocator::Buffer buffer, Deleter d) {
   array_desc_->data = std::make_shared<Data>(buffer, d);
-  array_desc_->data_ptr = buffer.raw_ptr();
+  array_desc_->offset = 0;
   array_desc_->data_size = size();
   array_desc_->flags.contiguous = true;
   array_desc_->flags.row_contiguous = true;
@@ -156,7 +156,7 @@ void array::set_data(
     Flags flags,
     Deleter d) {
   array_desc_->data = std::make_shared<Data>(buffer, d);
-  array_desc_->data_ptr = buffer.raw_ptr();
+  array_desc_->offset = 0;
   array_desc_->data_size = data_size;
   array_desc_->strides = std::move(strides);
   array_desc_->flags = flags;
@@ -172,9 +172,8 @@ void array::copy_shared_buffer(
   array_desc_->strides = strides;
   array_desc_->flags = flags;
   array_desc_->data_size = data_size;
-  auto char_offset = sizeof(char) * itemsize() * offset;
-  array_desc_->data_ptr = static_cast<void*>(
-      static_cast<char*>(other.array_desc_->data_ptr) + char_offset);
+  array_desc_->offset =
+      sizeof(char) * itemsize() * offset + other.array_desc_->offset;
 }
 
 void array::copy_shared_buffer(const array& other) {
@@ -352,12 +352,13 @@ class array {
   // Return a raw pointer to the arrays data
   template <typename T>
   T* data() {
-    return static_cast<T*>(array_desc_->data_ptr);
+    return reinterpret_cast<T*>(
+        (static_cast<char*>(buffer().raw_ptr()) + array_desc_->offset));
   }
 
   template <typename T>
   const T* data() const {
-    return static_cast<T*>(array_desc_->data_ptr);
+    return const_cast<array&>(*this).data<T>();
   }
 
   enum Status {
@@ -461,8 +462,8 @@ class array {
     // can share the underlying data buffer.
     std::shared_ptr<Data> data;
 
-    // Properly offset data pointer
-    void* data_ptr{nullptr};
+    // Offset from beginning of data pointer
+    int64_t offset{0};
 
     // The size in elements of the data buffer the array accesses
     size_t data_size;
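The array.cpp/array.h hunks above replace the cached data_ptr with a byte offset into the shared buffer, and data<T>() recomputes the host pointer from buffer().raw_ptr() on every call. A minimal self-contained sketch of that pattern (the Buffer/Desc names here are illustrative stand-ins, not the mlx types):

    // Sketch only: keep a byte offset instead of caching a raw data pointer.
    #include <cstdint>
    #include <cstdio>
    #include <vector>

    struct Buffer {
      std::vector<char> storage;
      void* raw_ptr() { return storage.data(); }
    };

    struct Desc {
      Buffer buffer;
      int64_t offset{0}; // byte offset into the buffer
    };

    template <typename T>
    T* data(Desc& d) {
      // Recomputed on every access, so the backing allocation can be replaced
      // (e.g. migrated) without invalidating previously handed-out pointers.
      return reinterpret_cast<T*>(static_cast<char*>(d.buffer.raw_ptr()) + d.offset);
    }

    int main() {
      Desc d{Buffer{std::vector<char>(16, 0)}, 4};
      *data<int>(d) = 42;
      std::printf("%d\n", *data<int>(d)); // prints 42
    }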
@@ -38,20 +38,20 @@ inline void set_binary_op_output_data(
     const array& a,
     const array& b,
     array& out,
-    BinaryOpType bopt) {
+    BinaryOpType bopt,
+    std::function<allocator::Buffer(size_t)> mallocfn = allocator::malloc) {
   bool b_donatable = is_donatable(b, out);
   bool a_donatable = is_donatable(a, out);
   switch (bopt) {
     case BinaryOpType::ScalarScalar:
-      out.set_data(
-          allocator::malloc(out.itemsize()), 1, a.strides(), a.flags());
+      out.set_data(mallocfn(out.itemsize()), 1, a.strides(), a.flags());
       break;
     case BinaryOpType::ScalarVector:
       if (b_donatable) {
        out.copy_shared_buffer(b);
       } else {
         out.set_data(
-            allocator::malloc(b.data_size() * out.itemsize()),
+            mallocfn(b.data_size() * out.itemsize()),
             b.data_size(),
             b.strides(),
             b.flags());
@@ -62,7 +62,7 @@ inline void set_binary_op_output_data(
         out.copy_shared_buffer(a);
       } else {
         out.set_data(
-            allocator::malloc(a.data_size() * out.itemsize()),
+            mallocfn(a.data_size() * out.itemsize()),
             a.data_size(),
             a.strides(),
             a.flags());
@@ -75,7 +75,7 @@ inline void set_binary_op_output_data(
         out.copy_shared_buffer(b);
       } else {
         out.set_data(
-            allocator::malloc(a.data_size() * out.itemsize()),
+            mallocfn(a.data_size() * out.itemsize()),
             a.data_size(),
             a.strides(),
             a.flags());
@@ -88,7 +88,7 @@ inline void set_binary_op_output_data(
           b_donatable && b.flags().row_contiguous && b.size() == out.size()) {
         out.copy_shared_buffer(b);
       } else {
-        out.set_data(allocator::malloc(out.nbytes()));
+        out.set_data(mallocfn(out.nbytes()));
       }
       break;
   }
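The extra mallocfn parameter keeps the default behavior (allocator::malloc) for existing callers while letting a backend inject its own allocation routine at the call site. A hedged, self-contained sketch of the same default-argument hook (names here are invented for illustration, not the mlx API):

    #include <cstdlib>
    #include <functional>

    struct DemoBuffer {
      void* ptr;
    };

    inline DemoBuffer demo_malloc(size_t n) {
      return DemoBuffer{std::malloc(n)};
    }

    // Callers that pass nothing get the default allocator; a GPU backend can
    // pass a lambda that allocates on a particular stream instead.
    inline DemoBuffer allocate_output(
        size_t nbytes,
        std::function<DemoBuffer(size_t)> mallocfn = demo_malloc) {
      return mallocfn(nbytes);
    }

    int main() {
      DemoBuffer a = allocate_output(64); // default path
      DemoBuffer b = allocate_output(64, [](size_t n) {
        return DemoBuffer{std::calloc(1, n)}; // injected allocator
      });
      std::free(a.ptr);
      std::free(b.ptr);
    }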
@@ -114,7 +114,9 @@ void compiled_allocate_outputs(
     const std::vector<array>& inputs,
     std::vector<array>& outputs,
     const std::function<bool(size_t)>& is_constant,
-    bool contiguous) {
+    bool contiguous,
+    const std::function<allocator::Buffer(size_t)>&
+        mallocfn /* = allocator::malloc */) {
   if (contiguous) {
     int o = 0;
     Strides strides;
@@ -140,7 +142,7 @@ void compiled_allocate_outputs(
     }
     for (; o < outputs.size(); ++o) {
       outputs[o].set_data(
-          allocator::malloc(data_size * outputs[o].itemsize()),
+          mallocfn(data_size * outputs[o].itemsize()),
           data_size,
           strides,
           flags);
@@ -163,7 +165,7 @@ void compiled_allocate_outputs(
       }
     }
     for (; o < outputs.size(); ++o) {
-      outputs[o].set_data(allocator::malloc(outputs[o].nbytes()));
+      outputs[o].set_data(mallocfn(outputs[o].nbytes()));
     }
   }
 }
@@ -58,7 +58,9 @@ void compiled_allocate_outputs(
     const std::vector<array>& inputs,
     std::vector<array>& outputs,
     const std::function<bool(size_t)>& is_constant,
-    bool contiguous);
+    bool contiguous,
+    const std::function<allocator::Buffer(size_t)>& mallocfn =
+        allocator::malloc);
 
 // Collapse contiguous dims ignoring scalars and constants.
 std::tuple<bool, Shape, std::vector<Strides>> compiled_collapse_contiguous_dims(
@@ -22,7 +22,11 @@ enum class CopyType {
   GeneralGeneral
 };
 
-inline bool set_copy_output_data(const array& in, array& out, CopyType ctype) {
+inline bool set_copy_output_data(
+    const array& in,
+    array& out,
+    CopyType ctype,
+    std::function<allocator::Buffer(size_t)> mallocfn = allocator::malloc) {
   if (ctype == CopyType::Vector) {
     // If the input is donateable, we are doing a vector copy and the types
     // have the same size, then the input buffer can hold the output.
@@ -31,14 +35,14 @@ inline bool set_copy_output_data(const array& in, array& out, CopyType ctype) {
       return true;
     } else {
       out.set_data(
-          allocator::malloc(in.data_size() * out.itemsize()),
+          mallocfn(in.data_size() * out.itemsize()),
          in.data_size(),
           in.strides(),
           in.flags());
       return false;
     }
   } else {
-    out.set_data(allocator::malloc(out.nbytes()));
+    out.set_data(mallocfn(out.nbytes()));
     return false;
   }
 }
@@ -46,7 +46,8 @@ inline void set_ternary_op_output_data(
     const array& b,
     const array& c,
     array& out,
-    TernaryOpType topt) {
+    TernaryOpType topt,
+    std::function<allocator::Buffer(size_t)> mallocfn = allocator::malloc) {
   auto maybe_donate = [&out](const array& x) {
     if (is_donatable(x, out)) {
       out.copy_shared_buffer(x);
@@ -57,13 +58,12 @@ inline void set_ternary_op_output_data(
 
   switch (topt) {
     case TernaryOpType::ScalarScalarScalar:
-      out.set_data(
-          allocator::malloc(out.itemsize()), 1, b.strides(), b.flags());
+      out.set_data(mallocfn(out.itemsize()), 1, b.strides(), b.flags());
       break;
     case TernaryOpType::VectorVectorVector:
       if (!(maybe_donate(a) || maybe_donate(b) || maybe_donate(c))) {
         out.set_data(
-            allocator::malloc(out.itemsize() * b.data_size()),
+            mallocfn(out.itemsize() * b.data_size()),
             b.data_size(),
             b.strides(),
             b.flags());
@@ -76,7 +76,7 @@ inline void set_ternary_op_output_data(
       if (!((a.flags().row_contiguous && maybe_donate(a)) ||
             (b.flags().row_contiguous && maybe_donate(b)) ||
             (c.flags().row_contiguous && maybe_donate(c)))) {
-        out.set_data(allocator::malloc(out.nbytes()));
+        out.set_data(mallocfn(out.nbytes()));
       }
       break;
   }
@@ -7,19 +7,22 @@
 
 namespace mlx::core {
 
-inline void set_unary_output_data(const array& in, array& out) {
+inline void set_unary_output_data(
+    const array& in,
+    array& out,
+    std::function<allocator::Buffer(size_t)> mallocfn = allocator::malloc) {
   if (in.flags().contiguous) {
     if (is_donatable(in, out)) {
       out.copy_shared_buffer(in);
     } else {
       out.set_data(
-          allocator::malloc(in.data_size() * out.itemsize()),
+          mallocfn(in.data_size() * out.itemsize()),
           in.data_size(),
           in.strides(),
           in.flags());
     }
   } else {
-    out.set_data(allocator::malloc(out.nbytes()));
+    out.set_data(mallocfn(out.nbytes()));
   }
 }
 
@@ -68,6 +68,7 @@ CudaBuffer* SmallSizePool::malloc() {
   next_free_ = next_free_->next;
   b->buf.data = static_cast<char*>(data_) + i * small_block_size;
   b->buf.size = small_block_size;
+  b->buf.managed = true;
   return &b->buf;
 }
 
@@ -94,16 +95,13 @@ CudaAllocator::CudaAllocator()
   CHECK_CUDA_ERROR(cudaMemGetInfo(&free, &total));
   memory_limit_ = total * 0.95;
   max_pool_size_ = memory_limit_;
-#if CUDART_VERSION >= 13000
-  cudaMemLocation loc;
-  loc.id = 0;
-  loc.type = cudaMemLocationTypeNone;
-  cudaMemGetDefaultMemPool(&cuda_pool_, &loc, cudaMemAllocationTypeManaged);
+  int loc = 0;
+  cudaDeviceGetDefaultMemPool(&cuda_pool_, loc);
   // TODO need a strategy for that
   uint64_t threshold = UINT64_MAX;
   cudaMemPoolSetAttribute(
       cuda_pool_, cudaMemPoolAttrReleaseThreshold, &threshold);
-#endif
 }
 
 Buffer CudaAllocator::malloc_impl(size_t size, cudaStream_t stream) {
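The constructor change above switches from querying a managed-memory pool to grabbing the device's default pool via cudaDeviceGetDefaultMemPool, and raises its release threshold so freed blocks stay cached. A standalone CUDA sketch of that stream-ordered allocation pattern (error handling trimmed; this is not the mlx allocator itself):

    #include <cassert>
    #include <cstdint>
    #include <cuda_runtime.h>

    int main() {
      cudaMemPool_t pool;
      assert(cudaDeviceGetDefaultMemPool(&pool, /*device=*/0) == cudaSuccess);

      // Keep freed blocks cached in the pool rather than returning them to the OS.
      uint64_t threshold = UINT64_MAX;
      cudaMemPoolSetAttribute(pool, cudaMemPoolAttrReleaseThreshold, &threshold);

      cudaStream_t stream;
      cudaStreamCreate(&stream);

      void* p = nullptr;
      // Allocation and free are ordered with the rest of the work on `stream`.
      assert(cudaMallocFromPoolAsync(&p, 1 << 20, pool, stream) == cudaSuccess);
      cudaFreeAsync(p, stream);
      cudaStreamSynchronize(stream);
      cudaStreamDestroy(stream);
      return 0;
    }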
@@ -133,12 +131,13 @@ Buffer CudaAllocator::malloc_impl(size_t size, cudaStream_t stream) {
   }
   lock.unlock();
   if (!buf) {
-    buf = new CudaBuffer{nullptr, size};
+    bool managed = stream == nullptr;
+    buf = new CudaBuffer{nullptr, size, managed};
     cudaError_t err;
-    if (stream != nullptr && cuda_pool_ != nullptr) {
-      err = cudaMallocFromPoolAsync(&buf->data, size, cuda_pool_, stream);
-    } else {
+    if (managed) {
       err = cudaMallocManaged(&buf->data, size);
+    } else {
+      err = cudaMallocFromPoolAsync(&buf->data, size, cuda_pool_, stream);
     }
     if (err != cudaSuccess && err != cudaErrorMemoryAllocation) {
       throw std::runtime_error(fmt::format(
@@ -266,7 +265,17 @@ void* Buffer::raw_ptr() {
   if (!ptr_) {
     return nullptr;
   }
-  return static_cast<cu::CudaBuffer*>(ptr_)->data;
+  auto& cbuf = *static_cast<cu::CudaBuffer*>(ptr_);
+  if (!cbuf.managed) {
+    void* new_data;
+    CHECK_CUDA_ERROR(cudaMallocManaged(&new_data, cbuf.size));
+    cbuf.managed = true;
+    CHECK_CUDA_ERROR(
+        cudaMemcpy(new_data, cbuf.data, cbuf.size, cudaMemcpyDefault));
+    CHECK_CUDA_ERROR(cudaFree(cbuf.data));
+    cbuf.data = new_data;
+  }
+  return cbuf.data;
 }
 
 } // namespace allocator
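With the new managed flag, Buffer::raw_ptr() migrates a pool-allocated (device-only) buffer to managed memory the first time a host pointer is requested: allocate with cudaMallocManaged, copy the old contents, free the pool block. A reduced sketch of just that migration step (the DemoBuffer struct and CHECK macro are stand-ins, not the mlx code):

    #include <cstdlib>
    #include <cuda_runtime.h>

    #define CHECK(x)            \
      do {                      \
        if ((x) != cudaSuccess) \
          std::abort();         \
      } while (0)

    struct DemoBuffer {
      void* data{nullptr};
      size_t size{0};
      bool managed{false};
    };

    // Returns a pointer that is dereferenceable from the CPU, migrating the
    // allocation to managed memory on first use if necessary.
    void* host_ptr(DemoBuffer& buf) {
      if (!buf.managed) {
        void* new_data = nullptr;
        CHECK(cudaMallocManaged(&new_data, buf.size));
        // cudaMemcpyDefault lets the runtime infer the copy direction.
        CHECK(cudaMemcpy(new_data, buf.data, buf.size, cudaMemcpyDefault));
        CHECK(cudaFree(buf.data));
        buf.data = new_data;
        buf.managed = true;
      }
      return buf.data;
    }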
@@ -18,8 +18,14 @@ using allocator::Buffer;
 struct CudaBuffer {
   void* data;
   size_t size;
+  bool managed;
 };
 
+template <typename T>
+T* gpu_ptr(Buffer buf) {
+  return static_cast<T*>(static_cast<cu::CudaBuffer*>(buf.ptr())->data);
+}
+
 class SmallSizePool {
  private:
   union Block {
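gpu_ptr<T> reads the device pointer straight out of the CudaBuffer, so kernel launch paths never call Buffer::raw_ptr() and therefore never trigger the managed-memory migration above. The kernels later in this diff call an array-taking overload of gpu_ptr whose definition is not shown here; under that assumption it presumably just forwards the array's buffer, roughly:

    // Hypothetical forwarding overload, assuming the Buffer overload defined in
    // the hunk above; the overload actually used by the kernels is not part of
    // this diff and may also fold in the array's byte offset.
    template <typename T>
    T* gpu_ptr(const array& arr) {
      return gpu_ptr<T>(arr.buffer());
    }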
@@ -57,7 +57,7 @@ void Arange::eval_gpu(const std::vector<array>& inputs, array& out) {
         num_blocks,
         block_dims,
         0,
-        out.data<OutType>(),
+        gpu_ptr<OutType>(out),
         out.data_size(),
         static_cast<CTYPE>(start_),
         static_cast<CTYPE>(start_ + step_) - static_cast<CTYPE>(start_));
@@ -173,8 +173,8 @@ void ArgReduce::eval_gpu(const std::vector<array>& inputs, array& out) {
           num_blocks,
           block_dim(),
           0,
-          in.data<T>(),
-          out.data<uint32_t>(),
+          gpu_ptr<T>(in),
+          gpu_ptr<uint32_t>(out),
           out.size(),
           const_param(shape),
           const_param(in_strides),
@@ -292,9 +292,9 @@ void binary_op_gpu_inplace(
             {num_blocks_x, num_blocks_y},
             block_dims,
             0,
-            a.data<InType>(),
-            b.data<InType>(),
-            out.data<OutType>(),
+            gpu_ptr<InType>(a),
+            gpu_ptr<InType>(b),
+            gpu_ptr<OutType>(out),
             rest,
             const_param<dims_constant()>(shape),
             const_param<dims_constant()>(a_strides),
@@ -310,9 +310,9 @@ void binary_op_gpu_inplace(
             {num_blocks_x, num_blocks_y},
             block_dims,
             0,
-            a.data<InType>(),
-            b.data<InType>(),
-            out.data<OutType>(),
+            gpu_ptr<InType>(a),
+            gpu_ptr<InType>(b),
+            gpu_ptr<OutType>(out),
             rest,
             const_param(shape),
             const_param(a_strides),
@@ -339,9 +339,9 @@ void binary_op_gpu_inplace(
           num_blocks,
           block_dims,
           0,
-          a.data<InType>(),
-          b.data<InType>(),
-          out.data<OutType>(),
+          gpu_ptr<InType>(a),
+          gpu_ptr<InType>(b),
+          gpu_ptr<OutType>(out),
           out.data_size());
     });
   }
@@ -365,7 +365,11 @@ void binary_op_gpu(
   auto& a = inputs[0];
   auto& b = inputs[1];
   auto bopt = get_binary_op_type(a, b);
-  set_binary_op_output_data(a, b, out, bopt);
+  auto& encoder = cu::get_command_encoder(s);
+
+  set_binary_op_output_data(a, b, out, bopt, [&](auto n) {
+    return cu::malloc_async(n, encoder.stream());
+  });
   binary_op_gpu_inplace<Op>(inputs, out, op, s);
 }
 
@@ -245,14 +245,18 @@ void binary_two_op_gpu_inplace(
   auto& out_a = outputs[0];
   auto& out_b = outputs[1];
   auto bopt = get_binary_op_type(a, b);
-  set_binary_op_output_data(a, b, out_a, bopt);
-  set_binary_op_output_data(a, b, out_b, bopt);
+  auto& encoder = cu::get_command_encoder(s);
+  set_binary_op_output_data(a, b, out_a, bopt, [&](auto n) {
+    return cu::malloc_async(n, encoder.stream());
+  });
+  set_binary_op_output_data(a, b, out_b, bopt, [&](auto n) {
+    return cu::malloc_async(n, encoder.stream());
+  });
 
   if (out_a.size() == 0) {
     return;
   }
 
-  auto& encoder = cu::get_command_encoder(s);
   encoder.set_input_array(a);
   encoder.set_input_array(b);
   encoder.set_output_array(out_a);
@@ -313,10 +317,10 @@ void binary_two_op_gpu_inplace(
             {num_blocks_x, num_blocks_y},
             block_dims,
             0,
-            a.data<InType>(),
-            b.data<InType>(),
-            out_a.data<OutType>(),
-            out_b.data<OutType>(),
+            gpu_ptr<InType>(a),
+            gpu_ptr<InType>(b),
+            gpu_ptr<OutType>(out_a),
+            gpu_ptr<OutType>(out_b),
             rest,
             const_param<dims_constant()>(shape),
             const_param<dims_constant()>(a_strides),
@@ -332,10 +336,10 @@ void binary_two_op_gpu_inplace(
             {num_blocks_x, num_blocks_y},
             block_dims,
             0,
-            a.data<InType>(),
-            b.data<InType>(),
-            out_a.data<OutType>(),
-            out_b.data<OutType>(),
+            gpu_ptr<InType>(a),
+            gpu_ptr<InType>(b),
+            gpu_ptr<OutType>(out_a),
+            gpu_ptr<OutType>(out_b),
             rest,
             const_param(shape),
             const_param(a_strides),
@@ -366,10 +370,10 @@ void binary_two_op_gpu_inplace(
           num_blocks,
           block_dims,
           0,
-          a.data<InType>(),
-          b.data<InType>(),
-          out_a.data<OutType>(),
-          out_b.data<OutType>(),
+          gpu_ptr<InType>(a),
+          gpu_ptr<InType>(b),
+          gpu_ptr<OutType>(out_a),
+          gpu_ptr<OutType>(out_b),
           out_a.data_size());
     });
   }
@@ -270,17 +270,16 @@ void Convolution::eval_gpu(const std::vector<array>& inputs, array& out_) {
   if (out_.size() == 0) {
     return;
   }
+  auto& s = stream();
+  auto& encoder = cu::get_command_encoder(s);
+
   assert(inputs.size() == 2);
   array in = inputs[0];
   array wt = inputs[1];
   array out = out_;
-  out.set_data(allocator::malloc(out.nbytes()));
+  out.set_data(cu::malloc_async(out.nbytes(), encoder.stream()));
   Dtype dtype = out.dtype();
 
-  auto& s = stream();
-  auto& encoder = cu::get_command_encoder(s);
-
   // Search cache.
   ConvCacheKey cache_key{
       encoder.device().cuda_device(),
@@ -86,7 +86,7 @@ array unfold_inputs_nd(
     int mat_N,
     ConvParams<NDIM>& params) {
   array unfolded({mat_M, mat_K}, in.dtype(), nullptr, {});
-  unfolded.set_data(allocator::malloc(unfolded.nbytes()));
+  unfolded.set_data(cu::malloc_async(unfolded.nbytes(), encoder.stream()));
   encoder.add_temporary(unfolded);
 
   int filter_size = params.C;
@@ -118,8 +118,8 @@ array unfold_inputs_nd(
       num_blocks,
       block_dims,
       0,
-      in.data<DataType>(),
-      unfolded.data<DataType>(),
+      gpu_ptr<DataType>(in),
+      gpu_ptr<DataType>(unfolded),
      filter_size,
       out_pixels,
       params);
@@ -89,7 +89,7 @@ array grouped_unfold_transpose_inputs_nd(
     int mat_N,
     ConvParams<NDIM>& params) {
   array unfolded({mat_M, mat_K * params.groups}, in.dtype(), nullptr, {});
-  unfolded.set_data(allocator::malloc(unfolded.nbytes()));
+  unfolded.set_data(cu::malloc_async(unfolded.nbytes(), encoder.stream()));
   encoder.add_temporary(unfolded);
 
   int filter_size = params.C;
@@ -121,8 +121,8 @@ array grouped_unfold_transpose_inputs_nd(
       num_blocks,
       block_dims,
      0,
-      in.data<DataType>(),
-      unfolded.data<DataType>(),
+      gpu_ptr<DataType>(in),
+      gpu_ptr<DataType>(unfolded),
       filter_size,
       out_pixels,
       params);
@@ -5,6 +5,22 @@
 
 namespace mlx::core {
 
+void copy_gpu(const array& in, array& out, CopyType ctype, const Stream& s) {
+  auto& encoder = cu::get_command_encoder(s);
+  bool donated = set_copy_output_data(in, out, ctype, [&](auto n) {
+    return cu::malloc_async(n, encoder.stream());
+  });
+  if (donated && in.dtype() == out.dtype()) {
+    // If the output has the same type as the input then there is nothing to
+    // copy, just use the buffer.
+    return;
+  }
+  if (ctype == CopyType::GeneralGeneral) {
+    ctype = CopyType::General;
+  }
+  copy_gpu_inplace(in, out, ctype, s);
+}
+
 void copy_gpu_inplace(
     const array& in,
     array& out,
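The new copy_gpu overload above first lets set_copy_output_data try to donate the input buffer to the output (allocating through cu::malloc_async otherwise), and only falls through to copy_gpu_inplace when an actual copy is needed. A self-contained sketch of that donate-or-copy control flow (simplified types, not the mlx API):

    #include <cstdio>

    struct DemoArray {
      int dtype;
      bool donatable; // input buffer can be reused for the output
    };

    // Returns true when `out` took over `in`'s buffer instead of allocating.
    bool set_output(const DemoArray& in, DemoArray& out) {
      if (in.donatable) {
        out = in; // share the buffer
        return true;
      }
      // ... allocate a fresh buffer for `out` here ...
      return false;
    }

    void copy(const DemoArray& in, DemoArray& out) {
      bool donated = set_output(in, out);
      if (donated && in.dtype == out.dtype) {
        return; // nothing to copy: the output already holds the data
      }
      std::puts("launch copy kernel");
    }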
@@ -77,8 +77,8 @@ void copy_contiguous(
         num_blocks,
         block_dims,
         0,
-        in.data<InType>() + in_offset,
-        out.data<OutType>() + out_offset,
+        gpu_ptr<InType>(in) + in_offset,
+        gpu_ptr<OutType>(out) + out_offset,
         out.data_size());
     });
   });
@@ -106,8 +106,8 @@ void copy_general(
     using InType = cuda_type_t<MLX_GET_TYPE(in_type_tag)>;
     using OutType = cuda_type_t<MLX_GET_TYPE(out_type_tag)>;
     using IdxT = std::conditional_t<large(), int64_t, int32_t>;
-    const InType* in_ptr = in.data<InType>() + offset_in;
-    OutType* out_ptr = out.data<OutType>() + offset_out;
+    const InType* in_ptr = gpu_ptr<InType>(in) + offset_in;
+    OutType* out_ptr = gpu_ptr<OutType>(out) + offset_out;
     int ndim = shape.size();
     size_t data_size = 1;
     for (auto& s : shape)
@@ -69,8 +69,8 @@ void copy_general_dynamic(
     using InType = cuda_type_t<MLX_GET_TYPE(in_type_tag)>;
     using OutType = cuda_type_t<MLX_GET_TYPE(out_type_tag)>;
     using IdxT = std::conditional_t<large(), int64_t, int32_t>;
-    const InType* in_ptr = in.data<InType>() + offset_in;
-    OutType* out_ptr = out.data<OutType>() + offset_out;
+    const InType* in_ptr = gpu_ptr<InType>(in) + offset_in;
+    OutType* out_ptr = gpu_ptr<OutType>(out) + offset_out;
     int ndim = shape.size();
     if (ndim <= 3) {
       dispatch_1_2_3(ndim, [&](auto dims_constant) {
@@ -90,8 +90,8 @@ void copy_general_dynamic(
             const_param<dims_constant()>(shape),
             const_param<dims_constant()>(strides_in),
             const_param<dims_constant()>(strides_out),
-            dynamic_offset_in.data<int64_t>(),
-            dynamic_offset_out.data<int64_t>());
+            gpu_ptr<int64_t>(dynamic_offset_in),
+            gpu_ptr<int64_t>(dynamic_offset_out));
       });
     } else { // ndim >= 4
       auto [num_blocks, block_dims] = get_launch_args(out, large());
@@ -107,8 +107,8 @@ void copy_general_dynamic(
           const_param(strides_in),
           const_param(strides_out),
           ndim,
-          dynamic_offset_in.data<int64_t>(),
-          dynamic_offset_out.data<int64_t>());
+          gpu_ptr<int64_t>(dynamic_offset_in),
+          gpu_ptr<int64_t>(dynamic_offset_out));
     }
   });
 });
@@ -92,8 +92,8 @@ void copy_general_input(
     using InType = cuda_type_t<MLX_GET_TYPE(in_type_tag)>;
     using OutType = cuda_type_t<MLX_GET_TYPE(out_type_tag)>;
     using IdxT = std::conditional_t<large(), int64_t, int32_t>;
-    const InType* in_ptr = in.data<InType>() + offset_in;
-    OutType* out_ptr = out.data<OutType>() + offset_out;
+    const InType* in_ptr = gpu_ptr<InType>(in) + offset_in;
+    OutType* out_ptr = gpu_ptr<OutType>(out) + offset_out;
     int ndim = shape.size();
     int work_per_thread = 1;
     auto dim0 = ndim > 0 ? shape.back() : 1;
@@ -133,13 +133,13 @@ bool prepare_cudnn_plan(
     F&& execute) {
   int workspace_size = plan.getWorkspaceSize();
   array workspace(
-      workspace_size > 0 ? allocator::malloc(workspace_size)
+      workspace_size > 0 ? cu::malloc_async(workspace_size, encoder.stream())
                          : allocator::Buffer(nullptr),
       {workspace_size},
       uint8);
 
   auto args = cudnn_frontend::VariantPackBuilder()
-                  .setWorkspacePointer(workspace.data<void>())
+                  .setWorkspacePointer(gpu_ptr<void>(workspace))
                   .setDataPointers(num_args, data_ptrs)
                   .setUids(num_args, uids)
                   .build();
@@ -3,6 +3,7 @@
 #pragma once
 
 #include "mlx/array.h"
+#include "mlx/backend/cuda/allocator.h"
 #include "mlx/backend/cuda/device/config.h"
 #include "mlx/backend/cuda/utils.h"
 #include "mlx/dtype_utils.h"
@@ -23,7 +24,7 @@ class CommandEncoder;
 // Return pointer alignment of |x|'s data.
 inline uint8_t get_alignment(const array& x) {
   uint8_t alignment = 1;
-  uintptr_t address = reinterpret_cast<uintptr_t>(x.data<void>());
+  uintptr_t address = reinterpret_cast<uintptr_t>(gpu_ptr<void>(x));
   for (; alignment < 32; alignment *= 2) {
     if (address % (alignment * 2)) {
       return alignment;
@@ -56,7 +57,7 @@ inline std::array<T, MAX_NDIM> vector_key(const Vec<T>& vec) {
 
 // Helpers used by get_data_ptrs to get pointers.
 inline void* get_data_ptr(const array& arr) {
-  return const_cast<void*>(arr.data<void>());
+  return const_cast<void*>(gpu_ptr<void>(arr));
 }
 
 template <typename T, typename = std::enable_if_t<std::is_scalar_v<T>>>
@@ -279,6 +279,7 @@ void CustomKernel::eval_gpu(
     std::vector<array>& outputs) {
   nvtx3::scoped_range r("CustomKernel::eval_gpu");
   auto& s = stream();
+  auto& encoder = cu::get_command_encoder(s);
 
   std::vector<array> copies;
 
@@ -288,7 +289,7 @@ void CustomKernel::eval_gpu(
       copies.emplace_back(init_value_.value(), out.dtype());
       fill_gpu(copies.back(), out, s);
     } else {
-      out.set_data(allocator::malloc(out.nbytes()));
+      out.set_data(cu::malloc_async(out.nbytes(), encoder.stream()));
     }
   }
 
@@ -356,7 +357,6 @@ void CustomKernel::eval_gpu(
   dim3 grid((gx + tx - 1) / tx, (gy + ty - 1) / ty, (gz + tz - 1) / tz);
 
   // Call the kernel
-  auto& encoder = cu::get_command_encoder(s);
   for (const auto& in : checked_inputs) {
     encoder.set_input_array(in);
   }
@@ -15,8 +15,10 @@ void AllReduce::eval_gpu(
   assert(inputs.size() == 1);
   assert(outputs.size() == 1);
 
-  auto set_input_output =
-      [s = stream()](const array& in, array& out) -> std::pair<array, array> {
+  auto& s = stream();
+  auto& encoder = cu::get_command_encoder(s);
+  auto set_input_output = [&](const array& in,
+                              array& out) -> std::pair<array, array> {
     if (!in.flags().row_contiguous) {
       copy_gpu(in, out, CopyType::General, s);
       return {out, out};
@@ -24,19 +26,17 @@ void AllReduce::eval_gpu(
       out.copy_shared_buffer(in);
       return {in, out};
     } else {
-      out.set_data(allocator::malloc(out.nbytes()));
+      out.set_data(cu::malloc_async(out.nbytes(), encoder.stream()));
       return {in, out};
     }
   };
 
   auto [input, output] = set_input_output(inputs[0], outputs[0]);
 
-  auto& encoder = cu::get_command_encoder(stream());
   encoder.set_input_array(input);
   encoder.set_output_array(output);
 
   auto capture = encoder.capture_context();
-  auto& s = stream();
 
   switch (reduce_type_) {
     case Sum:
@@ -241,7 +241,7 @@ void CublasGemm::set_bias(cu::CommandEncoder& encoder, const array& bias) {
       CUBLASLT_MATMUL_DESC_EPILOGUE,
       &epilogue,
       sizeof(epilogue)));
-  auto* bias_ptr = bias.data<void>();
+  auto* bias_ptr = gpu_ptr<void>(bias);
   CHECK_CUBLAS_ERROR(cublasLtMatmulDescSetAttribute(
       matmul_desc_,
       CUBLASLT_MATMUL_DESC_BIAS_POINTER,
@@ -278,9 +278,9 @@ void CublasGemm::run(
 
   execute(
       encoder,
-      out.data<void>(),
-      a.data<void>(),
-      b.data<void>(),
+      gpu_ptr<void>(out),
+      gpu_ptr<void>(a),
+      gpu_ptr<void>(b),
       nullptr,
       alpha);
 }
@@ -321,10 +321,10 @@ void CublasGemm::run(
 
   execute(
       encoder,
-      out.data<void>(),
-      a.data<void>(),
-      b.data<void>(),
-      c.data<void>(),
+      gpu_ptr<void>(out),
+      gpu_ptr<void>(a),
+      gpu_ptr<void>(b),
+      gpu_ptr<void>(c),
      alpha,
       beta);
 }
@@ -374,7 +374,7 @@ void CublasGemm::execute(
         {static_cast<int>(heuristic_.workspaceSize)},
         int8);
     encoder.add_temporary(workspace);
-    workspace_ptr = workspace.data<void>();
+    workspace_ptr = gpu_ptr<void>(workspace);
   }
 
   auto capture = encoder.capture_context();
@@ -25,9 +25,10 @@ void CublasGemm::run_batched(
   for (size_t i = 0; i < nbatch; ++i) {
     execute(
         encoder,
-        out.data<int8_t>() + out.itemsize() * i * batch_shape.back() * M_ * N_,
-        a.data<int8_t>() + a.itemsize() * a_it.loc,
-        b.data<int8_t>() + b.itemsize() * b_it.loc,
+        gpu_ptr<int8_t>(out) +
+            out.itemsize() * i * batch_shape.back() * M_ * N_,
+        gpu_ptr<int8_t>(a) + a.itemsize() * a_it.loc,
+        gpu_ptr<int8_t>(b) + b.itemsize() * b_it.loc,
         nullptr,
         alpha);
     a_it.step();
@@ -60,10 +61,11 @@ void CublasGemm::run_batched(
   for (size_t i = 0; i < nbatch; ++i) {
     execute(
         encoder,
-        out.data<int8_t>() + out.itemsize() * i * batch_shape.back() * M_ * N_,
-        a.data<int8_t>() + a.itemsize() * a_it.loc,
-        b.data<int8_t>() + b.itemsize() * b_it.loc,
-        c.data<int8_t>() + c.itemsize() * c_it.loc,
+        gpu_ptr<int8_t>(out) +
+            out.itemsize() * i * batch_shape.back() * M_ * N_,
+        gpu_ptr<int8_t>(a) + a.itemsize() * a_it.loc,
+        gpu_ptr<int8_t>(b) + b.itemsize() * b_it.loc,
+        gpu_ptr<int8_t>(c) + c.itemsize() * c_it.loc,
         alpha,
         beta);
     a_it.step();
@@ -183,10 +183,10 @@ void CublasGemm::run_batched(
           num_blocks,
           block_dims,
          0,
-          pointers.data<int8_t*>(),
-          a.data<int8_t>(),
-          b.data<int8_t>(),
-          out.data<int8_t>(),
+          gpu_ptr<int8_t*>(pointers),
+          gpu_ptr<int8_t>(a),
+          gpu_ptr<int8_t>(b),
+          gpu_ptr<int8_t>(out),
           item_size,
           const_param<ndim_constant()>(batch_shape),
           const_param<ndim_constant()>(a_batch_strides),
@@ -200,10 +200,10 @@ void CublasGemm::run_batched(
          num_blocks,
          block_dims,
          0,
-          pointers.data<int8_t*>(),
-          a.data<int8_t>(),
-          b.data<int8_t>(),
-          out.data<int8_t>(),
+          gpu_ptr<int8_t*>(pointers),
+          gpu_ptr<int8_t>(a),
+          gpu_ptr<int8_t>(b),
+          gpu_ptr<int8_t>(out),
          item_size,
          const_param(batch_shape),
          const_param(a_batch_strides),
@@ -219,7 +219,7 @@ void CublasGemm::run_batched(
   encoder.set_input_array(b);
   encoder.set_output_array(out);
 
-  auto a_pointers = pointers.data<int8_t*>();
+  auto a_pointers = gpu_ptr<int8_t*>(pointers);
   auto b_pointers = a_pointers + batch_count;
   auto out_pointers = b_pointers + batch_count;
   execute(
@@ -271,11 +271,11 @@ void CublasGemm::run_batched(
          num_blocks,
          block_dims,
          0,
-          pointers.data<int8_t*>(),
-          a.data<int8_t>(),
-          b.data<int8_t>(),
-          c.data<int8_t>(),
-          out.data<int8_t>(),
+          gpu_ptr<int8_t*>(pointers),
+          gpu_ptr<int8_t>(a),
+          gpu_ptr<int8_t>(b),
+          gpu_ptr<int8_t>(c),
+          gpu_ptr<int8_t>(out),
          item_size,
          const_param<ndim_constant()>(batch_shape),
          const_param<ndim_constant()>(a_batch_strides),
@@ -290,11 +290,11 @@ void CublasGemm::run_batched(
          num_blocks,
          block_dims,
          0,
-          pointers.data<int8_t*>(),
-          a.data<int8_t>(),
-          b.data<int8_t>(),
-          c.data<int8_t>(),
-          out.data<int8_t>(),
+          gpu_ptr<int8_t*>(pointers),
+          gpu_ptr<int8_t>(a),
+          gpu_ptr<int8_t>(b),
+          gpu_ptr<int8_t>(c),
+          gpu_ptr<int8_t>(out),
          item_size,
          const_param(batch_shape),
          const_param(a_batch_strides),
@@ -312,7 +312,7 @@ void CublasGemm::run_batched(
   encoder.set_input_array(c);
   encoder.set_output_array(out);
 
-  auto a_pointers = pointers.data<int8_t*>();
+  auto a_pointers = gpu_ptr<int8_t*>(pointers);
   auto b_pointers = a_pointers + batch_count;
   auto c_pointers = b_pointers + batch_count;
   auto out_pointers = c_pointers + batch_count;
@@ -149,13 +149,13 @@ void gemv(
     auto vec_strides = const_param(b_batch_strides);
 
     if (M == 1) {
-      mat = b.data<DataType>();
-      vec = a.data<DataType>();
+      mat = gpu_ptr<DataType>(b);
+      vec = gpu_ptr<DataType>(a);
       rows = N;
       std::swap(mat_strides, vec_strides);
     } else {
-      mat = a.data<DataType>();
-      vec = b.data<DataType>();
+      mat = gpu_ptr<DataType>(a);
+      vec = gpu_ptr<DataType>(b);
       rows = M;
     }
     uint32_t num_blocks_x = (rows + rows_per_block - 1) / rows_per_block;
@@ -177,7 +177,7 @@ void gemv(
           0,
           mat,
           vec,
-          out.data<DataType>(),
+          gpu_ptr<DataType>(out),
          rows,
           cols);
     } else {
@@ -189,7 +189,7 @@ void gemv(
          0,
          mat,
          vec,
-          out.data<DataType>(),
+          gpu_ptr<DataType>(out),
          rows,
          cols,
          const_param(batch_shape),
@@ -31,7 +31,7 @@ void append_indices_arg(
     int idx_ndim) {
   SmallVector<const void*> indices(nidx);
   for (int i = 0; i < nidx; ++i) {
-    indices[i] = inputs[i + 1].data<void>();
+    indices[i] = gpu_ptr<void>(inputs[i + 1]);
   }
   args.append(std::move(indices));
   SmallVector<int32_t> indices_shape(nidx * idx_ndim);
@@ -31,7 +31,7 @@ struct KernelArgs {
   }
 
   void append(const array& a) {
-    append(reinterpret_cast<CUdeviceptr>(a.data<void>()));
+    append(reinterpret_cast<CUdeviceptr>(gpu_ptr<void>(a.buffer())));
   }
 
   template <typename T>
@@ -9,6 +9,7 @@
 #include <type_traits>
 
 #include "mlx/array.h"
+#include "mlx/backend/cuda/allocator.h"
 #include "mlx/backend/cuda/device/utils.cuh"
 
 #include <cuda.h>
@@ -280,10 +280,10 @@ void LayerNorm::eval_gpu(
         n_rows,
         block_dim(),
         0,
-        x.data<DataType>(),
-        w.data<DataType>(),
-        b.data<DataType>(),
-        out.data<DataType>(),
+        gpu_ptr<DataType>(x),
+        gpu_ptr<DataType>(w),
+        gpu_ptr<DataType>(b),
+        gpu_ptr<DataType>(out),
         eps_,
         axis_size,
         w_stride,
@@ -393,11 +393,11 @@ void LayerNormVJP::eval_gpu(
             n_rows,
             block_dim(),
             0,
-            x.data<DataType>(),
-            w.data<DataType>(),
-            g.data<DataType>(),
-            gx.data<DataType>(),
-            gw_temp.data<DataType>(),
+            gpu_ptr<DataType>(x),
+            gpu_ptr<DataType>(w),
+            gpu_ptr<DataType>(g),
+            gpu_ptr<DataType>(gx),
+            gpu_ptr<DataType>(gw_temp),
             eps_,
             axis_size,
             w_stride);
@@ -151,8 +151,8 @@ void LogSumExp::eval_gpu(const std::vector<array>& inputs, array& out) {
         n_rows,
         block_dim(),
        0,
-        in.data<DataType>(),
-        out.data<DataType>(),
+        gpu_ptr<DataType>(in),
+        gpu_ptr<DataType>(out),
         axis_size);
     });
   });
@@ -262,10 +262,10 @@ void affine_quantize(
         num_blocks,
         block_dims,
         0,
-        w.data<T>(),
-        wq.data<uint8_t>(),
-        scales.data<T>(),
-        biases.data<T>(),
+        gpu_ptr<T>(w),
+        gpu_ptr<uint8_t>(wq),
+        gpu_ptr<T>(scales),
+        gpu_ptr<T>(biases),
         w.size());
     });
   });
@@ -318,10 +318,10 @@ void affine_dequantize(
         num_blocks,
         block_dims,
         0,
-        wq.data<uint8_t>(),
-        scales.data<T>(),
-        biases.data<T>(),
-        w.data<T>(),
+        gpu_ptr<uint8_t>(wq),
+        gpu_ptr<T>(scales),
+        gpu_ptr<T>(biases),
+        gpu_ptr<T>(w),
         w.size());
     });
   });
@@ -156,9 +156,9 @@ void fp_quantize(
         num_blocks,
         block_dims,
         0,
-        w.data<T>(),
-        wq.data<uint8_t>(),
-        scales.data<uint8_t>(),
+        gpu_ptr<T>(w),
+        gpu_ptr<uint8_t>(wq),
+        gpu_ptr<uint8_t>(scales),
         w.size());
   } else {
     throw std::runtime_error(
@@ -202,9 +202,9 @@ void fp_dequantize(
         num_blocks,
         block_dims,
         0,
-        wq.data<uint8_t>(),
-        scales.data<T>(),
-        w.data<T>(),
+        gpu_ptr<uint8_t>(wq),
+        gpu_ptr<uint8_t>(scales),
+        gpu_ptr<T>(w),
         w.size());
   } else {
     throw std::runtime_error(
@@ -59,7 +59,7 @@ void fast::Quantize::eval_gpu(
     auto scales = ensure_row_contiguous(inputs[1], enc, s);
     auto& w = outputs[0];
 
-    w.set_data(allocator::malloc(w.nbytes()));
+    w.set_data(cu::malloc_async(w.nbytes(), enc.stream()));
 
     if (mode_ == QuantizationMode::Affine) {
       auto biases = ensure_row_contiguous(inputs[2], enc, s);
@@ -72,8 +72,8 @@ void fast::Quantize::eval_gpu(
     auto& wq = outputs[0];
     auto& scales = outputs[1];
 
-    wq.set_data(allocator::malloc(wq.nbytes()));
-    scales.set_data(allocator::malloc(scales.nbytes()));
+    wq.set_data(cu::malloc_async(wq.nbytes(), enc.stream()));
+    scales.set_data(cu::malloc_async(scales.nbytes(), enc.stream()));
     if (mode_ == QuantizationMode::Affine) {
       auto& biases = outputs[2];
       biases.set_data(allocator::malloc(biases.nbytes()));
@@ -143,7 +143,9 @@ void RandomBits::eval_gpu(const std::vector<array>& inputs, array& out) {
 
   uint32_t elems_per_key = out.size() / num_keys;
   uint32_t bytes_per_key = out.itemsize() * elems_per_key;
-  out.set_data(allocator::malloc(out.nbytes()));
+  auto& s = stream();
+  auto& encoder = cu::get_command_encoder(s);
+  out.set_data(cu::malloc_async(out.nbytes(), encoder.stream()));
   if (out.size() == 0) {
     return;
   }
@@ -152,8 +154,6 @@ void RandomBits::eval_gpu(const std::vector<array>& inputs, array& out) {
   uint32_t half_size = out_per_key / 2;
   bool odd = out_per_key % 2;
 
-  auto& s = stream();
-  auto& encoder = cu::get_command_encoder(s);
   encoder.set_input_array(keys);
   encoder.set_output_array(out);
   dim3 grid_dims{num_keys, half_size + odd};
@@ -171,8 +171,8 @@ void RandomBits::eval_gpu(const std::vector<array>& inputs, array& out) {
         grid,
         block,
         0,
-        keys.data<uint32_t>(),
-        out.data<uint8_t>(),
+        gpu_ptr<uint32_t>(keys),
+        gpu_ptr<uint8_t>(out),
         grid_dims,
         odd,
         bytes_per_key);
@@ -182,8 +182,8 @@ void RandomBits::eval_gpu(const std::vector<array>& inputs, array& out) {
        grid,
        block,
        0,
-        keys.data<uint32_t>(),
-        out.data<uint8_t>(),
+        gpu_ptr<uint32_t>(keys),
+        gpu_ptr<uint8_t>(out),
        grid_dims,
        odd,
        bytes_per_key,
@@ -66,7 +66,7 @@ void all_reduce(
     Reduce::ReduceType reduce_type) {
   constexpr int N_READS = 8;
 
-  out.set_data(allocator::malloc(out.nbytes()));
+  out.set_data(cu::malloc_async(out.nbytes(), encoder.stream()));
 
   auto get_args = [](size_t size, int N) {
     int threads = std::min(512UL, (size + N - 1) / N);
@@ -100,14 +100,15 @@ void all_reduce(
   Dtype dt = in.dtype();
 
   // Cub doesn't like const pointers for load (sigh).
-  void* indata = const_cast<void*>(in.data<void>());
+  void* indata = const_cast<void*>(gpu_ptr<void>(in));
 
   // Large array so allocate an intermediate and accumulate there
   std::tie(blocks, threads, block_step) = get_args(insize, N_READS);
   encoder.set_input_array(in);
   if (blocks > 1) {
     array intermediate({blocks}, out.dtype(), nullptr, {});
-    intermediate.set_data(allocator::malloc(intermediate.nbytes()));
+    intermediate.set_data(
+        cu::malloc_async(intermediate.nbytes(), encoder.stream()));
     encoder.add_temporary(intermediate);
     encoder.set_output_array(intermediate);
     dispatch_all_types(dt, [&](auto type_tag) {
@@ -122,14 +123,14 @@ void all_reduce(
           threads,
           0,
           static_cast<T*>(indata),
-          intermediate.data<U>(),
+          gpu_ptr<U>(intermediate),
           block_step,
           insize);
      });
    });
 
    // Set the input for the next step and recalculate the blocks
-    indata = intermediate.data<void>();
+    indata = gpu_ptr<void>(intermediate);
    dt = intermediate.dtype();
    insize = intermediate.size();
    std::tie(blocks, threads, block_step) = get_args(insize, N_READS);
@@ -149,7 +150,7 @@ void all_reduce(
         threads,
         0,
         static_cast<T*>(indata),
-        out.data<U>(),
+        gpu_ptr<U>(out),
         block_step,
         insize);
   });
@@ -250,7 +250,7 @@ void col_reduce_looped(
     const cu::ColReduceArgs& args) {
   // Allocate data for the output using in's layout to access them as
   // contiguously as possible.
-  allocate_same_layout(out, in, axes);
+  allocate_same_layout(out, in, axes, encoder);
 
   encoder.set_input_array(in);
   encoder.set_output_array(out);
@@ -261,7 +261,7 @@ void col_reduce_looped(
     using T = cuda_type_t<MLX_GET_TYPE(type_tag)>;
     using U = typename cu::ReduceResult<OP, T>::type;
     // Cub doesn't like const pointers for vectorized loads. (sigh)
-    T* indata = const_cast<T*>(in.data<T>());
+    T* indata = const_cast<T*>(gpu_ptr<T>(in));
 
     constexpr int N_READS = 4;
     constexpr int BM = 32;
@@ -276,7 +276,7 @@ void col_reduce_looped(
         blocks,
         0,
         indata,
-        out.data<U>(),
+        gpu_ptr<U>(out),
         static_cast<cu::ColReduceArgs>(args));
    });
  });
@@ -293,7 +293,7 @@ void col_reduce_small(
     const cu::ColReduceArgs& args) {
   // Allocate data for the output using in's layout to access them as
   // contiguously as possible.
-  allocate_same_layout(out, in, axes);
+  allocate_same_layout(out, in, axes, encoder);
 
   encoder.set_input_array(in);
   encoder.set_output_array(out);
@@ -312,8 +312,8 @@ void col_reduce_small(
         grid,
         block,
         0,
-        in.data<T>(),
-        out.data<U>(),
+        gpu_ptr<T>(in),
+        gpu_ptr<U>(out),
         static_cast<cu::ColReduceArgs>(args),
         out.size());
   });
@@ -28,7 +28,7 @@ void init_reduce(
     Reduce::ReduceType reduce_type) {
   // Allocate if needed
   if (out.data_shared_ptr() == nullptr) {
-    out.set_data(allocator::malloc(out.nbytes()));
+    out.set_data(cu::malloc_async(out.nbytes(), encoder.stream()));
   }
 
   encoder.set_output_array(out);
@@ -42,7 +42,7 @@ void init_reduce(
       dim3 block(grid.x < 1024 ? grid.x : 1024, 1, 1);
       grid.x = (grid.x + 1023) / 1024;
       encoder.add_kernel_node(
-          kernel, grid, block, 0, out.data<U>(), out.size());
+          kernel, grid, block, 0, gpu_ptr<U>(out), out.size());
     });
   });
 }
@@ -5,6 +5,7 @@
 #include <numeric>
 
 #include "mlx/backend/common/utils.h"
+#include "mlx/backend/cuda/device.h"
 #include "mlx/backend/cuda/device/utils.cuh"
 
 #include <cooperative_groups.h>
@@ -92,9 +93,10 @@ block_reduce(Block block, Warp warp, T (&vals)[N], T* smem, Op op, T init) {
 inline void allocate_same_layout(
     array& out,
     const array& in,
-    const std::vector<int>& axes) {
+    const std::vector<int>& axes,
+    cu::CommandEncoder& encoder) {
   if (in.flags().row_contiguous) {
-    out.set_data(allocator::malloc(out.nbytes()));
+    out.set_data(cu::malloc_async(out.nbytes(), encoder.stream()));
     return;
   }
 
@@ -133,7 +135,7 @@ inline void allocate_same_layout(
   fl.col_contiguous = cc;
   fl.contiguous = true;
   out.set_data(
-      allocator::malloc(out.nbytes()),
+      cu::malloc_async(out.nbytes(), encoder.stream()),
       data_size,
       final_strides,
       fl,
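For context, a rough sketch of what a stream-ordered allocation helper like the cu::malloc_async used above could look like. Only the wrapper name comes from this diff; the cudaMallocAsync / cudaFreeAsync plumbing and the error handling below are assumptions for illustration, not the actual MLX implementation.

// Sketch only: a plausible stream-ordered allocation helper.
// cudaMallocAsync / cudaFreeAsync are real CUDA runtime calls (CUDA 11.2+);
// the names and error handling here are illustrative.
#include <cuda_runtime.h>
#include <cstddef>
#include <stdexcept>

inline void* malloc_async_sketch(size_t nbytes, cudaStream_t stream) {
  void* ptr = nullptr;
  // Stream-ordered: the allocation is valid for work enqueued on `stream`
  // after this call, without a device-wide synchronization.
  if (cudaMallocAsync(&ptr, nbytes, stream) != cudaSuccess) {
    throw std::runtime_error("cudaMallocAsync failed");
  }
  return ptr;
}

inline void free_async_sketch(void* ptr, cudaStream_t stream) {
  cudaFreeAsync(ptr, stream); // freed in stream order as well
}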
@@ -238,7 +238,7 @@ void row_reduce_simple(
     const ReductionPlan& plan) {
   // Allocate data for the output using in's layout to avoid elem_to_loc in the
   // kernel.
-  allocate_same_layout(out, in, axes);
+  allocate_same_layout(out, in, axes, encoder);
 
   // TODO: If out.size() < 1024 which will be a common case then write this in
   // 2 passes. Something like 32 * out.size() and then do a warp reduce.
@@ -268,10 +268,10 @@ void row_reduce_simple(
         kernel = cu::row_reduce_simple<T, U, OP, N_READS, 2>;
       }
 
-      T* indata = const_cast<T*>(in.data<T>());
+      T* indata = const_cast<T*>(gpu_ptr<T>(in));
       int size = plan.shape.back();
       encoder.add_kernel_node(
-          kernel, grid, block, 0, indata, out.data<U>(), out.size(), size);
+          kernel, grid, block, 0, indata, gpu_ptr<U>(out), out.size(), size);
     });
   });
 }
@@ -286,7 +286,7 @@ void row_reduce_looped(
     cu::RowReduceArgs args) {
   // Allocate data for the output using in's layout to access them as
   // contiguously as possible.
-  allocate_same_layout(out, in, axes);
+  allocate_same_layout(out, in, axes, encoder);
 
   encoder.set_input_array(in);
   encoder.set_output_array(out);
@@ -315,7 +315,7 @@ void row_reduce_looped(
       });
 
       encoder.add_kernel_node(
-          kernel, grid, block, 0, in.data<T>(), out.data<U>(), args);
+          kernel, grid, block, 0, gpu_ptr<T>(in), gpu_ptr<U>(out), args);
     });
   });
 }
@@ -223,9 +223,9 @@ void RMSNorm::eval_gpu(
           n_rows,
           block_dim(),
           0,
-          x.data<DataType>(),
-          w.data<DataType>(),
-          out.data<DataType>(),
+          gpu_ptr<DataType>(x),
+          gpu_ptr<DataType>(w),
+          gpu_ptr<DataType>(out),
           eps_,
           axis_size,
           w_stride);
@@ -318,11 +318,11 @@ void RMSNormVJP::eval_gpu(
           n_rows,
           block_dim(),
           0,
-          x.data<DataType>(),
-          w.data<DataType>(),
-          g.data<DataType>(),
-          gx.data<DataType>(),
-          gw_temp.data<DataType>(),
+          gpu_ptr<DataType>(x),
+          gpu_ptr<DataType>(w),
+          gpu_ptr<DataType>(g),
+          gpu_ptr<DataType>(gx),
+          gpu_ptr<DataType>(gw_temp),
           eps_,
           axis_size,
           w_stride);
@@ -340,9 +340,9 @@ void RoPE::eval_gpu(
             grid,
             block,
             0,
-            (donated ? out : in).data<DataType>(),
-            out.data<DataType>(),
-            offset.data<int32_t>(),
+            gpu_ptr<DataType>(donated ? out : in),
+            gpu_ptr<DataType>(out),
+            gpu_ptr<int32_t>(offset),
             scale_,
             std::log2(base_),
             mat_size,
@@ -357,10 +357,10 @@ void RoPE::eval_gpu(
             grid,
             block,
             0,
-            (donated ? out : in).data<DataType>(),
-            out.data<DataType>(),
-            offset.data<int32_t>(),
-            inputs[2].data<float>(),
+            gpu_ptr<DataType>(donated ? out : in),
+            gpu_ptr<DataType>(out),
+            gpu_ptr<int32_t>(offset),
+            gpu_ptr<float>(inputs[2]),
             scale_,
             mat_size,
             dims,
@@ -381,10 +381,10 @@ void RoPE::eval_gpu(
             grid,
             block,
             0,
-            (donated ? out : in).data<DataType>(),
-            out.data<DataType>(),
-            offset.data<int32_t>(),
-            inputs[2].data<float>(),
+            gpu_ptr<DataType>(donated ? out : in),
+            gpu_ptr<DataType>(out),
+            gpu_ptr<int32_t>(offset),
+            gpu_ptr<float>(inputs[2]),
             scale_,
             std::log2(base_),
             strides,
@@ -408,9 +408,9 @@ void RoPE::eval_gpu(
             grid,
             block,
             0,
-            (donated ? out : in).data<DataType>(),
-            out.data<DataType>(),
-            offset.data<int32_t>(),
+            gpu_ptr<DataType>(donated ? out : in),
+            gpu_ptr<DataType>(out),
+            gpu_ptr<int32_t>(offset),
             scale_,
             std::log2(base_),
             strides,
@@ -513,11 +513,11 @@ void sdpa_vector_1pass_fallback(
         grid_dim,
         block_dim,
         0,
-        q.data<DataType>(),
-        k.data<DataType>(),
-        v.data<DataType>(),
-        o.data<DataType>(),
-        sinks ? (*sinks).data<DataType>() : nullptr,
+        gpu_ptr<DataType>(q),
+        gpu_ptr<DataType>(k),
+        gpu_ptr<DataType>(v),
+        gpu_ptr<DataType>(o),
+        sinks ? gpu_ptr<DataType>(*sinks) : nullptr,
         params);
     });
   });
@@ -602,13 +602,13 @@ void sdpa_vector_2pass_fallback(
         grid_dim,
         block_dim,
         0,
-        q.data<DataType>(),
-        k.data<DataType>(),
-        v.data<DataType>(),
-        sinks ? (*sinks).data<DataType>() : nullptr,
-        intermediate.data<float>(),
-        sums.data<float>(),
-        maxs.data<float>(),
+        gpu_ptr<DataType>(q),
+        gpu_ptr<DataType>(k),
+        gpu_ptr<DataType>(v),
+        sinks ? gpu_ptr<DataType>(*sinks) : nullptr,
+        gpu_ptr<float>(intermediate),
+        gpu_ptr<float>(sums),
+        gpu_ptr<float>(maxs),
         params);
     }
 
@@ -629,10 +629,10 @@ void sdpa_vector_2pass_fallback(
         grid_dim,
         block_dim,
         0,
-        intermediate.data<float>(),
-        sums.data<float>(),
-        maxs.data<float>(),
-        o.data<DataType>(),
+        gpu_ptr<float>(intermediate),
+        gpu_ptr<float>(sums),
+        gpu_ptr<float>(maxs),
+        gpu_ptr<DataType>(o),
         params);
     }
   });
@@ -415,8 +415,8 @@ void Scan::eval_gpu(const std::vector<array>& inputs, array& out) {
             in.data_size() / axis_size,
             block_dim,
             0,
-            in.data<T>(),
-            out.data<U>(),
+            gpu_ptr<T>(in),
+            gpu_ptr<U>(out),
             axis_size);
       } else {
         constexpr int BM = WARP_SIZE;
@@ -445,8 +445,8 @@ void Scan::eval_gpu(const std::vector<array>& inputs, array& out) {
             num_blocks,
             block_dim,
             0,
-            in.data<T>(),
-            out.data<U>(),
+            gpu_ptr<T>(in),
+            gpu_ptr<U>(out),
             axis_size,
             stride,
             stride_blocks);
@@ -152,8 +152,8 @@ void Softmax::eval_gpu(const std::vector<array>& inputs, array& out) {
         n_rows,
         block_dim(),
         0,
-        in.data<DataType>(),
-        out.data<DataType>(),
+        gpu_ptr<DataType>(in),
+        gpu_ptr<DataType>(out),
         axis_size);
     });
   });
@@ -91,10 +91,10 @@ void gpu_sort(const Stream& s, array in, array& out_, int axis, bool argsort) {
       CHECK_CUDA_ERROR(cub::DeviceSegmentedRadixSort::SortPairs(
           nullptr,
           size,
-          in.data<Type>(),
-          discard.data<Type>(),
-          indices.data<uint32_t>(),
-          out.data<uint32_t>(),
+          gpu_ptr<Type>(in),
+          gpu_ptr<Type>(discard),
+          gpu_ptr<uint32_t>(indices),
+          gpu_ptr<uint32_t>(out),
           in.data_size(),
           in.data_size() / nsort,
           offsets,
@@ -115,16 +115,16 @@ void gpu_sort(const Stream& s, array in, array& out_, int axis, bool argsort) {
           cu::thrust_policy(stream),
           thrust::counting_iterator<uint32_t>(0),
           thrust::counting_iterator<uint32_t>(indices.data_size()),
-          thrust::device_pointer_cast(indices.data<uint32_t>()),
+          thrust::device_pointer_cast(gpu_ptr<uint32_t>(indices)),
           ModOp<uint32_t>{static_cast<uint32_t>(nsort)});
 
       CHECK_CUDA_ERROR(cub::DeviceSegmentedRadixSort::SortPairs(
-          temp.data<void>(),
+          gpu_ptr<void>(temp),
           size,
-          in.data<Type>(),
-          discard.data<Type>(),
-          indices.data<uint32_t>(),
-          out.data<uint32_t>(),
+          gpu_ptr<Type>(in),
+          gpu_ptr<Type>(discard),
+          gpu_ptr<uint32_t>(indices),
+          gpu_ptr<uint32_t>(out),
           in.data_size(),
           in.data_size() / nsort,
           offsets,
@@ -137,8 +137,8 @@ void gpu_sort(const Stream& s, array in, array& out_, int axis, bool argsort) {
      CHECK_CUDA_ERROR(cub::DeviceSegmentedRadixSort::SortKeys(
           nullptr,
           size,
-          in.data<Type>(),
-          out.data<Type>(),
+          gpu_ptr<Type>(in),
+          gpu_ptr<Type>(out),
           in.data_size(),
           in.data_size() / nsort,
           offsets,
@@ -156,10 +156,10 @@ void gpu_sort(const Stream& s, array in, array& out_, int axis, bool argsort) {
       // Start capturing after allocations
       auto capture = encoder.capture_context();
       CHECK_CUDA_ERROR(cub::DeviceSegmentedRadixSort::SortKeys(
-          temp.data<void>(),
+          gpu_ptr<void>(temp),
           size,
-          in.data<Type>(),
-          out.data<Type>(),
+          gpu_ptr<Type>(in),
+          gpu_ptr<Type>(out),
           in.data_size(),
           in.data_size() / nsort,
           offsets,
@@ -168,10 +168,10 @@ void ternary_op_gpu_inplace(
           num_blocks,
           block_dims,
           0,
-          a.data<bool>(),
-          b.data<DType>(),
-          c.data<DType>(),
-          out.data<DType>(),
+          gpu_ptr<bool>(a),
+          gpu_ptr<DType>(b),
+          gpu_ptr<DType>(c),
+          gpu_ptr<DType>(out),
           out.data_size());
     });
   } else {
@@ -211,10 +211,10 @@ void ternary_op_gpu_inplace(
             {num_blocks_x, num_blocks_y},
             block_dims,
             0,
-            a.data<bool>(),
-            b.data<DType>(),
-            c.data<DType>(),
-            out.data<DType>(),
+            gpu_ptr<bool>(a),
+            gpu_ptr<DType>(b),
+            gpu_ptr<DType>(c),
+            gpu_ptr<DType>(out),
             rest,
             const_param<dims_constant()>(shape),
             const_param<dims_constant()>(a_strides),
@@ -231,10 +231,10 @@ void ternary_op_gpu_inplace(
             {num_blocks_x, num_blocks_y},
             block_dims,
             0,
-            a.data<bool>(),
-            b.data<DType>(),
-            c.data<DType>(),
-            out.data<DType>(),
+            gpu_ptr<bool>(a),
+            gpu_ptr<DType>(b),
+            gpu_ptr<DType>(c),
+            gpu_ptr<DType>(out),
             rest,
             const_param(shape),
             const_param(a_strides),
@@ -256,7 +256,10 @@ void ternary_op_gpu(
   auto& b = inputs[1];
   auto& c = inputs[2];
   auto topt = get_ternary_op_type(a, b, c);
-  set_ternary_op_output_data(a, b, c, out, topt);
+  auto& encoder = cu::get_command_encoder(s);
+  set_ternary_op_output_data(a, b, c, out, topt, [&](auto n) {
+    return cu::malloc_async(n, encoder.stream());
+  });
   ternary_op_gpu_inplace<Op>(inputs, out, s);
 }
 
@@ -158,8 +158,8 @@ void unary_op_gpu_inplace(
           num_blocks,
           block_dims,
           0,
-          in.data<InType>(),
-          out.data<OutType>(),
+          gpu_ptr<InType>(in),
+          gpu_ptr<OutType>(out),
           out.data_size());
     } else {
       using IdxT = std::conditional_t<large(), int64_t, int32_t>;
@@ -182,8 +182,8 @@ void unary_op_gpu_inplace(
           {num_blocks_x, num_blocks_y},
           block_dims,
           0,
-          in.data<InType>(),
-          out.data<OutType>(),
+          gpu_ptr<InType>(in),
+          gpu_ptr<OutType>(out),
           rest,
           const_param(shape),
           const_param(strides),
@@ -207,7 +207,10 @@ void unary_op_gpu(
     array& out,
     const char* op,
     const Stream& s) {
-  set_unary_output_data(inputs[0], out);
+  auto& encoder = cu::get_command_encoder(s);
+  set_unary_output_data(inputs[0], out, [&](auto n) {
+    return cu::malloc_async(n, encoder.stream());
+  });
   unary_op_gpu_inplace<Op>(inputs, out, op, s);
 }
 
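Both the ternary and unary entry points above now hand their output-data helpers an allocation callback rather than allocating directly. A self-contained sketch of that callback pattern follows; FakeBuffer, Output, and set_output_data are illustrative stand-ins, not MLX types, and the lambda here uses plain malloc where the call sites above use cu::malloc_async(n, encoder.stream()).

// Self-contained sketch of the allocation-callback pattern used above.
#include <cstddef>
#include <cstdio>
#include <cstdlib>

struct FakeBuffer {
  void* ptr;
};

struct Output {
  size_t nbytes;
  FakeBuffer buf{nullptr};
};

// The helper no longer decides how to allocate; it defers to the callback.
template <typename Alloc>
void set_output_data(Output& out, Alloc&& alloc) {
  out.buf = alloc(out.nbytes);
}

int main() {
  Output out{256};
  // At the CUDA call sites above this lambda wraps the stream-ordered
  // allocator; here it is plain malloc so the sketch stands alone.
  set_output_data(out, [](size_t n) { return FakeBuffer{std::malloc(n)}; });
  std::printf("allocated %zu bytes at %p\n", out.nbytes, out.buf.ptr);
  std::free(out.buf.ptr);
  return 0;
}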
@@ -7,6 +7,8 @@
 #include <cublasLt.h>
 #include <cuda.h>
 #include <cuda_runtime.h>
+#include "mlx/array.h"
+#include "mlx/backend/cuda/allocator.h"
 
 namespace mlx::core {
 
@@ -15,6 +17,16 @@ class Device;
 
 }
 
+template <typename T>
+T* gpu_ptr(array& arr) {
+  return cu::gpu_ptr<T>(arr.buffer());
+}
+
+template <typename T>
+const T* gpu_ptr(const array& arr) {
+  return cu::gpu_ptr<T>(arr.buffer());
+}
+
 struct Dtype;
 
 // Throw exception if the cuda API does not succeed.
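The pair of gpu_ptr overloads added above is a standard const/non-const overload pair routed through a single buffer-level helper. One common way to implement such a pattern, shown here as a self-contained illustration with stand-in types rather than the real MLX array and allocator:

// Illustration only: Buffer and device_ptr are stand-ins, not MLX types.
#include <cstdio>

struct Buffer {
  void* data;
};

template <typename T>
T* device_ptr(Buffer& buf) {
  return static_cast<T*>(buf.data);
}

template <typename T>
const T* device_ptr(const Buffer& buf) {
  // Route through the non-const overload; constness is reapplied on return.
  return device_ptr<T>(const_cast<Buffer&>(buf));
}

int main() {
  float storage[4] = {0.f, 1.f, 2.f, 3.f};
  Buffer buf{storage};
  const Buffer& cbuf = buf;
  // Exercises both overloads: mutable and const access to the same storage.
  std::printf("%f %f\n", device_ptr<float>(buf)[1], device_ptr<float>(cbuf)[2]);
  return 0;
}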
@@ -7,18 +7,7 @@
 
 namespace mlx::core {
 
-void copy_gpu(const array& in, array& out, CopyType ctype, const Stream& s) {
-  bool donated = set_copy_output_data(in, out, ctype);
-  if (donated && in.dtype() == out.dtype()) {
-    // If the output has the same type as the input then there is nothing to
-    // copy, just use the buffer.
-    return;
-  }
-  if (ctype == CopyType::GeneralGeneral) {
-    ctype = CopyType::General;
-  }
-  copy_gpu_inplace(in, out, ctype, s);
-}
+void copy_gpu(const array& in, array& out, CopyType ctype, const Stream& s);
 
 void copy_gpu(const array& in, array& out, CopyType ctype) {
   copy_gpu(in, out, ctype, out.primitive().stream());
@@ -10,6 +10,19 @@ namespace mlx::core {
 
 constexpr int MAX_COPY_SPECIALIZED_DIMS = 3;
 
+void copy_gpu(const array& in, array& out, CopyType ctype, const Stream& s) {
+  bool donated = set_copy_output_data(in, out, ctype);
+  if (donated && in.dtype() == out.dtype()) {
+    // If the output has the same type as the input then there is nothing to
+    // copy, just use the buffer.
+    return;
+  }
+  if (ctype == CopyType::GeneralGeneral) {
+    ctype = CopyType::General;
+  }
+  copy_gpu_inplace(in, out, ctype, s);
+}
+
 void copy_gpu_inplace(
     const array& in,
     array& out,