mirror of
https://github.com/ml-explore/mlx.git
synced 2025-09-12 07:14:34 +08:00
Compare commits
4 Commits
Author | SHA1 | Date | |
---|---|---|---|
![]() |
aede70e81d | ||
![]() |
85a8beb5e4 | ||
![]() |
0bb89e9e5f | ||
![]() |
5685ceb3c7 |
@@ -1,8 +1,7 @@
|
|||||||
// Copyright © 2023-2024 Apple Inc.
|
// Copyright © 2023-2024 Apple Inc.
|
||||||
|
|
||||||
#include "mlx/backend/common/compiled.h"
|
#include "mlx/backend/common/compiled.h"
|
||||||
#include "mlx/graph_utils.h"
|
#include "mlx/backend/common/utils.h"
|
||||||
#include "mlx/primitives.h"
|
|
||||||
#include "mlx/utils.h"
|
#include "mlx/utils.h"
|
||||||
|
|
||||||
namespace mlx::core {
|
namespace mlx::core {
|
||||||
@@ -79,55 +78,6 @@ std::string get_type_string(Dtype d) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
std::string build_lib_name(
|
|
||||||
const std::vector<array>& inputs,
|
|
||||||
const std::vector<array>& outputs,
|
|
||||||
const std::vector<array>& tape,
|
|
||||||
const std::unordered_set<uintptr_t>& constant_ids) {
|
|
||||||
NodeNamer namer;
|
|
||||||
std::ostringstream os;
|
|
||||||
std::ostringstream constant_hasher;
|
|
||||||
|
|
||||||
// Fill the input names. This is not really necessary, I just like having A,
|
|
||||||
// B, C, ... as the inputs.
|
|
||||||
for (auto& x : inputs) {
|
|
||||||
namer.get_name(x);
|
|
||||||
}
|
|
||||||
|
|
||||||
// The primitives describing the tape. For unary and binary primitives this
|
|
||||||
// must be enough to describe the full computation.
|
|
||||||
for (auto& a : tape) {
|
|
||||||
// name and type of output
|
|
||||||
os << namer.get_name(a) << kindof(a.dtype()) << a.itemsize();
|
|
||||||
// computation performed
|
|
||||||
a.primitive().print(os);
|
|
||||||
// name of inputs to the function
|
|
||||||
for (auto& inp : a.inputs()) {
|
|
||||||
os << namer.get_name(inp);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
os << "_";
|
|
||||||
|
|
||||||
for (auto& x : inputs) {
|
|
||||||
if (constant_ids.find(x.id()) != constant_ids.end()) {
|
|
||||||
os << "C";
|
|
||||||
print_constant(constant_hasher, x);
|
|
||||||
} else {
|
|
||||||
os << (is_scalar(x) ? "S" : "V");
|
|
||||||
}
|
|
||||||
}
|
|
||||||
os << "_";
|
|
||||||
for (auto& x : inputs) {
|
|
||||||
if (constant_ids.find(x.id()) != constant_ids.end()) {
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
os << kindof(x.dtype()) << x.itemsize();
|
|
||||||
}
|
|
||||||
os << "_" << std::hash<std::string>{}(constant_hasher.str());
|
|
||||||
|
|
||||||
return os.str();
|
|
||||||
}
|
|
||||||
|
|
||||||
bool compiled_check_contiguity(
|
bool compiled_check_contiguity(
|
||||||
const std::vector<array>& inputs,
|
const std::vector<array>& inputs,
|
||||||
const Shape& shape) {
|
const Shape& shape) {
|
||||||
@@ -159,8 +109,7 @@ bool compiled_check_contiguity(
|
|||||||
void compiled_allocate_outputs(
|
void compiled_allocate_outputs(
|
||||||
const std::vector<array>& inputs,
|
const std::vector<array>& inputs,
|
||||||
std::vector<array>& outputs,
|
std::vector<array>& outputs,
|
||||||
const std::vector<array>& inputs_,
|
const std::function<bool(size_t)>& is_constant,
|
||||||
const std::unordered_set<uintptr_t>& constant_ids_,
|
|
||||||
bool contiguous) {
|
bool contiguous) {
|
||||||
if (contiguous) {
|
if (contiguous) {
|
||||||
int o = 0;
|
int o = 0;
|
||||||
@@ -175,8 +124,7 @@ void compiled_allocate_outputs(
|
|||||||
// - Donatable
|
// - Donatable
|
||||||
// - Not a constant
|
// - Not a constant
|
||||||
if (in.itemsize() == outputs[o].itemsize() && !is_scalar(in) &&
|
if (in.itemsize() == outputs[o].itemsize() && !is_scalar(in) &&
|
||||||
in.is_donatable() &&
|
in.is_donatable() && is_constant(i)) {
|
||||||
constant_ids_.find(inputs_[i].id()) == constant_ids_.end()) {
|
|
||||||
outputs[o++].copy_shared_buffer(in);
|
outputs[o++].copy_shared_buffer(in);
|
||||||
}
|
}
|
||||||
// Get representative input flags to properly set non-donated outputs
|
// Get representative input flags to properly set non-donated outputs
|
||||||
@@ -204,7 +152,7 @@ void compiled_allocate_outputs(
|
|||||||
// - Not a constant
|
// - Not a constant
|
||||||
if (in.flags().row_contiguous && in.size() == outputs[o].size() &&
|
if (in.flags().row_contiguous && in.size() == outputs[o].size() &&
|
||||||
in.itemsize() == outputs[o].itemsize() && in.is_donatable() &&
|
in.itemsize() == outputs[o].itemsize() && in.is_donatable() &&
|
||||||
constant_ids_.find(inputs_[i].id()) == constant_ids_.end()) {
|
is_constant(i)) {
|
||||||
outputs[o].copy_shared_buffer(
|
outputs[o].copy_shared_buffer(
|
||||||
in, outputs[o].strides(), in.flags(), in.data_size());
|
in, outputs[o].strides(), in.flags(), in.data_size());
|
||||||
o++;
|
o++;
|
||||||
@@ -216,4 +164,74 @@ void compiled_allocate_outputs(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
std::tuple<bool, Shape, std::vector<Strides>> compiled_collapse_contiguous_dims(
|
||||||
|
const std::vector<array>& inputs,
|
||||||
|
const array& out,
|
||||||
|
const std::function<bool(size_t)>& is_constant) {
|
||||||
|
const Shape& shape = out.shape();
|
||||||
|
bool contiguous = compiled_check_contiguity(inputs, shape);
|
||||||
|
if (contiguous) {
|
||||||
|
return {true, shape, {}};
|
||||||
|
}
|
||||||
|
|
||||||
|
std::vector<Strides> strides_vec{out.strides()};
|
||||||
|
for (size_t i = 0; i < inputs.size(); ++i) {
|
||||||
|
// Skip constants.
|
||||||
|
if (is_constant(i)) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Skip scalar inputs.
|
||||||
|
const auto& x = inputs[i];
|
||||||
|
if (is_scalar(x)) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Broadcast the inputs to the output shape.
|
||||||
|
Strides xstrides;
|
||||||
|
size_t j = 0;
|
||||||
|
for (; j < shape.size() - x.ndim(); ++j) {
|
||||||
|
if (shape[j] == 1) {
|
||||||
|
xstrides.push_back(out.strides()[j]);
|
||||||
|
} else {
|
||||||
|
xstrides.push_back(0);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
for (size_t i = 0; i < x.ndim(); ++i, ++j) {
|
||||||
|
if (x.shape(i) == 1) {
|
||||||
|
if (shape[j] == 1) {
|
||||||
|
xstrides.push_back(out.strides()[j]);
|
||||||
|
} else {
|
||||||
|
xstrides.push_back(0);
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
xstrides.push_back(x.strides()[i]);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
strides_vec.push_back(std::move(xstrides));
|
||||||
|
}
|
||||||
|
|
||||||
|
auto tup = collapse_contiguous_dims(shape, strides_vec, INT32_MAX);
|
||||||
|
return {false, std::move(std::get<0>(tup)), std::move(std::get<1>(tup))};
|
||||||
|
}
|
||||||
|
|
||||||
|
bool compiled_use_large_index(
|
||||||
|
const std::vector<array>& inputs,
|
||||||
|
const std::vector<array>& outputs,
|
||||||
|
bool contiguous) {
|
||||||
|
if (contiguous) {
|
||||||
|
size_t max_size = 0;
|
||||||
|
for (const auto& in : inputs) {
|
||||||
|
max_size = std::max(max_size, in.data_size());
|
||||||
|
}
|
||||||
|
return max_size > UINT32_MAX;
|
||||||
|
} else {
|
||||||
|
size_t max_size = 0;
|
||||||
|
for (const auto& o : outputs) {
|
||||||
|
max_size = std::max(max_size, o.size());
|
||||||
|
}
|
||||||
|
return max_size > UINT32_MAX;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace mlx::core
|
} // namespace mlx::core
|
||||||
|
@@ -1,9 +1,8 @@
|
|||||||
// Copyright © 2023-2024 Apple Inc.
|
// Copyright © 2023-2024 Apple Inc.
|
||||||
#pragma once
|
#pragma once
|
||||||
|
|
||||||
|
#include <functional>
|
||||||
#include <iomanip>
|
#include <iomanip>
|
||||||
#include <sstream>
|
|
||||||
#include <unordered_set>
|
|
||||||
|
|
||||||
#include "mlx/array.h"
|
#include "mlx/array.h"
|
||||||
#include "mlx/primitives.h"
|
#include "mlx/primitives.h"
|
||||||
@@ -14,12 +13,6 @@ inline bool is_static_cast(const Primitive& p) {
|
|||||||
return (typeid(p) == typeid(Broadcast) || typeid(p) == typeid(AsType));
|
return (typeid(p) == typeid(Broadcast) || typeid(p) == typeid(AsType));
|
||||||
}
|
}
|
||||||
|
|
||||||
std::string build_lib_name(
|
|
||||||
const std::vector<array>& inputs,
|
|
||||||
const std::vector<array>& outputs,
|
|
||||||
const std::vector<array>& tape,
|
|
||||||
const std::unordered_set<uintptr_t>& constant_ids);
|
|
||||||
|
|
||||||
std::string get_type_string(Dtype d);
|
std::string get_type_string(Dtype d);
|
||||||
|
|
||||||
template <typename T>
|
template <typename T>
|
||||||
@@ -60,8 +53,19 @@ bool compiled_check_contiguity(
|
|||||||
void compiled_allocate_outputs(
|
void compiled_allocate_outputs(
|
||||||
const std::vector<array>& inputs,
|
const std::vector<array>& inputs,
|
||||||
std::vector<array>& outputs,
|
std::vector<array>& outputs,
|
||||||
const std::vector<array>& inputs_,
|
const std::function<bool(size_t)>& is_constant,
|
||||||
const std::unordered_set<uintptr_t>& constant_ids_,
|
bool contiguous);
|
||||||
|
|
||||||
|
// Collapse contiguous dims ignoring scalars and constants.
|
||||||
|
std::tuple<bool, Shape, std::vector<Strides>> compiled_collapse_contiguous_dims(
|
||||||
|
const std::vector<array>& inputs,
|
||||||
|
const array& out,
|
||||||
|
const std::function<bool(size_t)>& is_constant);
|
||||||
|
|
||||||
|
// Return whether the kernel should use large index.
|
||||||
|
bool compiled_use_large_index(
|
||||||
|
const std::vector<array>& inputs,
|
||||||
|
const std::vector<array>& outputs,
|
||||||
bool contiguous);
|
bool contiguous);
|
||||||
|
|
||||||
} // namespace mlx::core
|
} // namespace mlx::core
|
||||||
|
@@ -146,18 +146,9 @@ inline void build_kernel(
|
|||||||
const std::vector<array>& inputs,
|
const std::vector<array>& inputs,
|
||||||
const std::vector<array>& outputs,
|
const std::vector<array>& outputs,
|
||||||
const std::vector<array>& tape,
|
const std::vector<array>& tape,
|
||||||
const std::unordered_set<uintptr_t>& constant_ids,
|
const std::function<bool(size_t)>& is_constant,
|
||||||
bool contiguous,
|
bool contiguous,
|
||||||
int ndim) {
|
int ndim) {
|
||||||
// All outputs should have the exact same shape and will be row contiguous
|
|
||||||
auto output_shape = outputs[0].shape();
|
|
||||||
auto output_strides = outputs[0].strides();
|
|
||||||
|
|
||||||
// Constants are scalars that are captured by value and cannot change
|
|
||||||
auto is_constant = [&constant_ids](const array& x) {
|
|
||||||
return constant_ids.find(x.id()) != constant_ids.end();
|
|
||||||
};
|
|
||||||
|
|
||||||
NodeNamer namer;
|
NodeNamer namer;
|
||||||
|
|
||||||
#ifdef _MSC_VER
|
#ifdef _MSC_VER
|
||||||
@@ -170,14 +161,15 @@ inline void build_kernel(
|
|||||||
|
|
||||||
// Add the input arguments
|
// Add the input arguments
|
||||||
int cnt = 0;
|
int cnt = 0;
|
||||||
for (auto& x : inputs) {
|
for (size_t i = 0; i < inputs.size(); ++i) {
|
||||||
auto& xname = namer.get_name(x);
|
|
||||||
|
|
||||||
// Skip constants from the input list
|
// Skip constants from the input list
|
||||||
if (is_constant(x)) {
|
if (is_constant(i)) {
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
const auto& x = inputs[i];
|
||||||
|
auto& xname = namer.get_name(x);
|
||||||
|
|
||||||
auto tstr = get_type_string(x.dtype());
|
auto tstr = get_type_string(x.dtype());
|
||||||
os << " " << tstr << "* " << xname << " = (" << tstr << "*)args[" << cnt++
|
os << " " << tstr << "* " << xname << " = (" << tstr << "*)args[" << cnt++
|
||||||
<< "];" << std::endl;
|
<< "];" << std::endl;
|
||||||
@@ -211,10 +203,11 @@ inline void build_kernel(
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Read the inputs in tmps
|
// Read the inputs in tmps
|
||||||
for (auto& x : inputs) {
|
for (size_t i = 0; i < inputs.size(); ++i) {
|
||||||
|
const auto& x = inputs[i];
|
||||||
auto& xname = namer.get_name(x);
|
auto& xname = namer.get_name(x);
|
||||||
|
|
||||||
if (is_constant(x)) {
|
if (is_constant(i)) {
|
||||||
os << " " << get_type_string(x.dtype()) << " tmp_" << xname << " = ";
|
os << " " << get_type_string(x.dtype()) << " tmp_" << xname << " = ";
|
||||||
print_constant(os, x);
|
print_constant(os, x);
|
||||||
os << ";" << std::endl;
|
os << ";" << std::endl;
|
||||||
@@ -264,8 +257,9 @@ inline void build_kernel(
|
|||||||
} else {
|
} else {
|
||||||
for (int d = ndim - 1; d >= 0; --d) {
|
for (int d = ndim - 1; d >= 0; --d) {
|
||||||
// Update pointers
|
// Update pointers
|
||||||
for (auto& x : inputs) {
|
for (size_t i = 0; i < inputs.size(); ++i) {
|
||||||
if (is_constant(x) || is_scalar(x)) {
|
const auto& x = inputs[i];
|
||||||
|
if (is_constant(i) || is_scalar(x)) {
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
auto& xname = namer.get_name(x);
|
auto& xname = namer.get_name(x);
|
||||||
@@ -287,65 +281,37 @@ inline void build_kernel(
|
|||||||
void Compiled::eval_cpu(
|
void Compiled::eval_cpu(
|
||||||
const std::vector<array>& inputs,
|
const std::vector<array>& inputs,
|
||||||
std::vector<array>& outputs) {
|
std::vector<array>& outputs) {
|
||||||
if (kernel_lib_.empty()) {
|
|
||||||
kernel_lib_ = build_lib_name(inputs_, outputs_, tape_, constant_ids_);
|
|
||||||
}
|
|
||||||
|
|
||||||
// Figure out which kernel we are using
|
|
||||||
auto& shape = outputs[0].shape();
|
|
||||||
auto contiguous = compiled_check_contiguity(inputs, shape);
|
|
||||||
auto& encoder = cpu::get_command_encoder(stream());
|
auto& encoder = cpu::get_command_encoder(stream());
|
||||||
|
|
||||||
// Handle all broadcasting and collect function input arguments
|
// Collapse contiguous dims to route to a faster kernel if possible. Also
|
||||||
|
// handle all broadcasting.
|
||||||
|
auto [contiguous, shape, strides] =
|
||||||
|
compiled_collapse_contiguous_dims(inputs, outputs[0], is_constant_);
|
||||||
|
|
||||||
|
// Collect function input arguments.
|
||||||
std::vector<void*> args;
|
std::vector<void*> args;
|
||||||
std::vector<std::vector<size_t>> strides;
|
int strides_index = 1;
|
||||||
for (int i = 0; i < inputs.size(); i++) {
|
for (size_t i = 0; i < inputs.size(); ++i) {
|
||||||
// Skip constants.
|
if (is_constant_(i)) {
|
||||||
if (constant_ids_.find(inputs_[i].id()) != constant_ids_.end()) {
|
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
auto& x = inputs[i];
|
const auto& x = inputs[i];
|
||||||
encoder.set_input_array(x);
|
encoder.set_input_array(x);
|
||||||
args.push_back((void*)x.data<void>());
|
args.push_back((void*)x.data<void>());
|
||||||
|
if (!contiguous && !is_scalar(x)) {
|
||||||
if (contiguous || is_scalar(x)) {
|
args.push_back(strides[strides_index++].data());
|
||||||
continue;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Broadcast the input to the output shape.
|
|
||||||
std::vector<size_t> xstrides;
|
|
||||||
int j = 0;
|
|
||||||
for (; j < shape.size() - x.ndim(); j++) {
|
|
||||||
if (shape[j] == 1) {
|
|
||||||
xstrides.push_back(outputs[0].strides()[j]);
|
|
||||||
} else {
|
|
||||||
xstrides.push_back(0);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
for (int i = 0; i < x.ndim(); i++, j++) {
|
|
||||||
if (x.shape(i) == 1) {
|
|
||||||
if (shape[j] == 1) {
|
|
||||||
xstrides.push_back(outputs[0].strides()[j]);
|
|
||||||
} else {
|
|
||||||
xstrides.push_back(0);
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
xstrides.push_back(x.strides()[i]);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
strides.push_back(std::move(xstrides));
|
|
||||||
args.push_back(strides.back().data());
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Get the kernel name from the lib
|
// Get the kernel name from the lib
|
||||||
int ndim = shape.size();
|
int ndim = shape.size();
|
||||||
auto kernel_name = kernel_lib_ + (contiguous ? "_contiguous" : "_strided_");
|
auto kernel_name = kernel_lib_ + (contiguous ? "_contiguous" : "_strided_");
|
||||||
if (!contiguous) {
|
if (!contiguous) {
|
||||||
kernel_name += std::to_string(shape.size());
|
kernel_name += std::to_string(ndim);
|
||||||
}
|
}
|
||||||
|
|
||||||
// Get the function
|
// Get the function
|
||||||
auto fn_ptr = compile(kernel_name, [&]() {
|
auto fn_ptr = compile(kernel_name, [&, contiguous = contiguous]() {
|
||||||
std::ostringstream kernel;
|
std::ostringstream kernel;
|
||||||
kernel << get_kernel_preamble() << std::endl;
|
kernel << get_kernel_preamble() << std::endl;
|
||||||
kernel << "extern \"C\" {" << std::endl;
|
kernel << "extern \"C\" {" << std::endl;
|
||||||
@@ -355,7 +321,7 @@ void Compiled::eval_cpu(
|
|||||||
inputs_,
|
inputs_,
|
||||||
outputs_,
|
outputs_,
|
||||||
tape_,
|
tape_,
|
||||||
constant_ids_,
|
is_constant_,
|
||||||
contiguous,
|
contiguous,
|
||||||
ndim);
|
ndim);
|
||||||
// Close extern "C"
|
// Close extern "C"
|
||||||
@@ -363,26 +329,22 @@ void Compiled::eval_cpu(
|
|||||||
return kernel.str();
|
return kernel.str();
|
||||||
});
|
});
|
||||||
|
|
||||||
compiled_allocate_outputs(
|
compiled_allocate_outputs(inputs, outputs, is_constant_, contiguous);
|
||||||
inputs, outputs, inputs_, constant_ids_, contiguous);
|
|
||||||
|
|
||||||
for (auto& x : outputs) {
|
for (auto& x : outputs) {
|
||||||
args.push_back(x.data<void>());
|
args.push_back(x.data<void>());
|
||||||
encoder.set_output_array(x);
|
encoder.set_output_array(x);
|
||||||
}
|
}
|
||||||
Shape out_shape;
|
|
||||||
if (!contiguous) {
|
if (!contiguous) {
|
||||||
out_shape = outputs[0].shape();
|
args.push_back((void*)shape.data());
|
||||||
args.push_back((void*)out_shape.data());
|
|
||||||
} else {
|
} else {
|
||||||
args.push_back((void*)outputs[0].data_size());
|
args.push_back((void*)outputs[0].data_size());
|
||||||
}
|
}
|
||||||
auto fun = (void (*)(void**))fn_ptr;
|
auto fun = (void (*)(void**))fn_ptr;
|
||||||
encoder.dispatch(
|
encoder.dispatch([fun,
|
||||||
[fun,
|
args = std::move(args),
|
||||||
args = std::move(args),
|
strides = std::move(strides),
|
||||||
strides = std::move(strides),
|
shape = std::move(shape)]() mutable { fun(args.data()); });
|
||||||
out_shape = std::move(out_shape)]() mutable { fun(args.data()); });
|
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace mlx::core
|
} // namespace mlx::core
|
||||||
|
@@ -10,7 +10,7 @@ target_sources(
|
|||||||
${CMAKE_CURRENT_SOURCE_DIR}/device.cpp
|
${CMAKE_CURRENT_SOURCE_DIR}/device.cpp
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/eval.cpp
|
${CMAKE_CURRENT_SOURCE_DIR}/eval.cpp
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/event.cu
|
${CMAKE_CURRENT_SOURCE_DIR}/event.cu
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/fence.cu
|
${CMAKE_CURRENT_SOURCE_DIR}/fence.cpp
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/kernel_utils.cu
|
${CMAKE_CURRENT_SOURCE_DIR}/kernel_utils.cu
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/primitives.cu
|
${CMAKE_CURRENT_SOURCE_DIR}/primitives.cu
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/slicing.cpp
|
${CMAKE_CURRENT_SOURCE_DIR}/slicing.cpp
|
||||||
|
@@ -18,7 +18,10 @@ CudaAllocator::CudaAllocator()
|
|||||||
: buffer_cache_(
|
: buffer_cache_(
|
||||||
getpagesize(),
|
getpagesize(),
|
||||||
[](CudaBuffer* buf) { return buf->size; },
|
[](CudaBuffer* buf) { return buf->size; },
|
||||||
[this](CudaBuffer* buf) { cuda_free(buf); }) {
|
[this](CudaBuffer* buf) {
|
||||||
|
cuda_free(buf->data);
|
||||||
|
delete buf;
|
||||||
|
}) {
|
||||||
// TODO: Set memory limit for multi-device.
|
// TODO: Set memory limit for multi-device.
|
||||||
size_t free, total;
|
size_t free, total;
|
||||||
CHECK_CUDA_ERROR(cudaMemGetInfo(&free, &total));
|
CHECK_CUDA_ERROR(cudaMemGetInfo(&free, &total));
|
||||||
@@ -70,7 +73,8 @@ void CudaAllocator::free(Buffer buffer) {
|
|||||||
buffer_cache_.recycle_to_cache(buf);
|
buffer_cache_.recycle_to_cache(buf);
|
||||||
} else {
|
} else {
|
||||||
lock.unlock();
|
lock.unlock();
|
||||||
cuda_free(buf);
|
cuda_free(buf->data);
|
||||||
|
delete buf;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -87,6 +91,25 @@ void CudaAllocator::register_this_thread() {
|
|||||||
allowed_threads_.insert(std::this_thread::get_id());
|
allowed_threads_.insert(std::this_thread::get_id());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void CudaAllocator::cuda_free(void* buf) {
|
||||||
|
// If cuda_free() is called from a unregistered thread, reschedule the call to
|
||||||
|
// worker.
|
||||||
|
{
|
||||||
|
std::lock_guard lock(worker_mutex_);
|
||||||
|
if (allowed_threads_.count(std::this_thread::get_id()) == 0) {
|
||||||
|
if (!worker_) {
|
||||||
|
worker_.reset(new Worker);
|
||||||
|
}
|
||||||
|
worker_->add_task([this, buf]() { this->cuda_free(buf); });
|
||||||
|
worker_->end_batch();
|
||||||
|
worker_->commit();
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
cudaFree(buf);
|
||||||
|
}
|
||||||
|
|
||||||
size_t CudaAllocator::get_active_memory() const {
|
size_t CudaAllocator::get_active_memory() const {
|
||||||
return active_memory_;
|
return active_memory_;
|
||||||
}
|
}
|
||||||
@@ -125,26 +148,6 @@ void CudaAllocator::clear_cache() {
|
|||||||
buffer_cache_.clear();
|
buffer_cache_.clear();
|
||||||
}
|
}
|
||||||
|
|
||||||
void CudaAllocator::cuda_free(CudaBuffer* buf) {
|
|
||||||
// If cuda_free() is called from a unregistered thread, reschedule the call to
|
|
||||||
// worker.
|
|
||||||
{
|
|
||||||
std::lock_guard lock(worker_mutex_);
|
|
||||||
if (allowed_threads_.count(std::this_thread::get_id()) == 0) {
|
|
||||||
if (!worker_) {
|
|
||||||
worker_.reset(new Worker);
|
|
||||||
}
|
|
||||||
worker_->add_task([this, buf]() { this->cuda_free(buf); });
|
|
||||||
worker_->end_batch();
|
|
||||||
worker_->commit();
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
cudaFree(buf->data);
|
|
||||||
delete buf;
|
|
||||||
}
|
|
||||||
|
|
||||||
CudaAllocator& allocator() {
|
CudaAllocator& allocator() {
|
||||||
// By creating the |allocator_| on heap, the destructor of CudaAllocator
|
// By creating the |allocator_| on heap, the destructor of CudaAllocator
|
||||||
// will not be called on exit and buffers in the cache will be leaked. This
|
// will not be called on exit and buffers in the cache will be leaked. This
|
||||||
|
@@ -34,6 +34,9 @@ class CudaAllocator : public allocator::Allocator {
|
|||||||
// buffers there would result in dead lock.
|
// buffers there would result in dead lock.
|
||||||
void register_this_thread();
|
void register_this_thread();
|
||||||
|
|
||||||
|
// Call cudaFree in the safe thread.
|
||||||
|
void cuda_free(void* buf);
|
||||||
|
|
||||||
size_t get_active_memory() const;
|
size_t get_active_memory() const;
|
||||||
size_t get_peak_memory() const;
|
size_t get_peak_memory() const;
|
||||||
void reset_peak_memory();
|
void reset_peak_memory();
|
||||||
@@ -47,8 +50,6 @@ class CudaAllocator : public allocator::Allocator {
|
|||||||
CudaAllocator();
|
CudaAllocator();
|
||||||
friend CudaAllocator& allocator();
|
friend CudaAllocator& allocator();
|
||||||
|
|
||||||
void cuda_free(CudaBuffer* buf);
|
|
||||||
|
|
||||||
std::mutex worker_mutex_;
|
std::mutex worker_mutex_;
|
||||||
std::unique_ptr<Worker> worker_;
|
std::unique_ptr<Worker> worker_;
|
||||||
std::set<std::thread::id> allowed_threads_;
|
std::set<std::thread::id> allowed_threads_;
|
||||||
|
@@ -1,5 +1,6 @@
|
|||||||
// Copyright © 2024 Apple Inc.
|
// Copyright © 2024 Apple Inc.
|
||||||
|
|
||||||
|
#include "mlx/backend/cuda/allocator.h"
|
||||||
#include "mlx/backend/cuda/device.h"
|
#include "mlx/backend/cuda/device.h"
|
||||||
#include "mlx/backend/cuda/event.h"
|
#include "mlx/backend/cuda/event.h"
|
||||||
#include "mlx/backend/cuda/utils.h"
|
#include "mlx/backend/cuda/utils.h"
|
||||||
@@ -111,12 +112,12 @@ __global__ void event_signal_kernel(SharedEvent::Atomic* ac, uint64_t value) {
|
|||||||
|
|
||||||
SharedEvent::SharedEvent() {
|
SharedEvent::SharedEvent() {
|
||||||
// Allocate cuda::atomic on managed memory.
|
// Allocate cuda::atomic on managed memory.
|
||||||
allocator::Buffer buffer = allocator::malloc(sizeof(Atomic));
|
Atomic* ac;
|
||||||
Atomic* ac = static_cast<Atomic*>(buffer.raw_ptr());
|
CHECK_CUDA_ERROR(cudaMallocManaged(&ac, sizeof(Atomic)));
|
||||||
new (ac) Atomic(0);
|
new (ac) Atomic(0);
|
||||||
ac_ = std::shared_ptr<Atomic>(ac, [buffer](Atomic* ptr) {
|
ac_ = std::shared_ptr<Atomic>(ac, [](Atomic* ptr) {
|
||||||
ptr->~Atomic();
|
ptr->~Atomic();
|
||||||
allocator::free(buffer);
|
allocator().cuda_free(ptr);
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -155,7 +156,10 @@ void SharedEvent::signal(cudaStream_t stream, uint64_t value) {
|
|||||||
void SharedEvent::signal(Stream s, uint64_t value) {
|
void SharedEvent::signal(Stream s, uint64_t value) {
|
||||||
nvtx3::scoped_range r("cu::SharedEvent::signal(s)");
|
nvtx3::scoped_range r("cu::SharedEvent::signal(s)");
|
||||||
if (s.device == mlx::core::Device::cpu) {
|
if (s.device == mlx::core::Device::cpu) {
|
||||||
scheduler::enqueue(s, [*this, value]() mutable { signal(value); });
|
// Signal through a GPU stream so the atomic is updated in GPU - updating
|
||||||
|
// the atomic in CPU sometimes does not get GPU notified.
|
||||||
|
static CudaStream stream(device(mlx::core::Device::gpu));
|
||||||
|
scheduler::enqueue(s, [*this, value]() mutable { signal(stream, value); });
|
||||||
} else {
|
} else {
|
||||||
auto& encoder = get_command_encoder(s);
|
auto& encoder = get_command_encoder(s);
|
||||||
encoder.launch_kernel(
|
encoder.launch_kernel(
|
||||||
|
29
mlx/backend/cuda/fence.cpp
Normal file
29
mlx/backend/cuda/fence.cpp
Normal file
@@ -0,0 +1,29 @@
|
|||||||
|
// Copyright © 2025 Apple Inc.
|
||||||
|
|
||||||
|
#include "mlx/fence.h"
|
||||||
|
#include "mlx/backend/cuda/event.h"
|
||||||
|
|
||||||
|
namespace mlx::core {
|
||||||
|
|
||||||
|
struct FenceImpl {
|
||||||
|
uint32_t count;
|
||||||
|
cu::SharedEvent event;
|
||||||
|
};
|
||||||
|
|
||||||
|
Fence::Fence(Stream s) {
|
||||||
|
fence_ = std::shared_ptr<void>(
|
||||||
|
new FenceImpl{0}, [](void* ptr) { delete static_cast<FenceImpl*>(ptr); });
|
||||||
|
}
|
||||||
|
|
||||||
|
void Fence::wait(Stream s, const array&) {
|
||||||
|
auto* fence = static_cast<FenceImpl*>(fence_.get());
|
||||||
|
fence->event.wait(fence->count);
|
||||||
|
}
|
||||||
|
|
||||||
|
void Fence::update(Stream s, const array&) {
|
||||||
|
auto* fence = static_cast<FenceImpl*>(fence_.get());
|
||||||
|
fence->count++;
|
||||||
|
fence->event.signal(s, fence->count);
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace mlx::core
|
@@ -1,70 +0,0 @@
|
|||||||
// Copyright © 2025 Apple Inc.
|
|
||||||
|
|
||||||
#include "mlx/backend/cuda/device.h"
|
|
||||||
#include "mlx/backend/cuda/event.h"
|
|
||||||
#include "mlx/fence.h"
|
|
||||||
#include "mlx/scheduler.h"
|
|
||||||
|
|
||||||
#include <nvtx3/nvtx3.hpp>
|
|
||||||
|
|
||||||
namespace mlx::core {
|
|
||||||
|
|
||||||
namespace {
|
|
||||||
|
|
||||||
__host__ __device__ void busy_wait(cuda::atomic<uint64_t>* ac, uint64_t value) {
|
|
||||||
while (true) {
|
|
||||||
// In theory the atomic_thread_fence is not needed, but for CUDA 11 without
|
|
||||||
// it the load() may never return new value.
|
|
||||||
cuda::atomic_thread_fence(cuda::memory_order_seq_cst);
|
|
||||||
uint64_t current = ac->load();
|
|
||||||
if (current >= value) {
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
__global__ void busy_wait_kernel(cuda::atomic<uint64_t>* ac, uint64_t value) {
|
|
||||||
busy_wait(ac, value);
|
|
||||||
}
|
|
||||||
|
|
||||||
} // namespace
|
|
||||||
|
|
||||||
struct FenceImpl {
|
|
||||||
uint32_t count;
|
|
||||||
cu::SharedEvent event;
|
|
||||||
};
|
|
||||||
|
|
||||||
Fence::Fence(Stream s) {
|
|
||||||
fence_ = std::shared_ptr<void>(
|
|
||||||
new FenceImpl{0}, [](void* ptr) { delete static_cast<FenceImpl*>(ptr); });
|
|
||||||
}
|
|
||||||
|
|
||||||
void Fence::wait(Stream s, const array&) {
|
|
||||||
auto* fence = static_cast<FenceImpl*>(fence_.get());
|
|
||||||
// We can't use SharedEvent::wait because it could hang in CUDA 11, see also:
|
|
||||||
// https://github.com/ml-explore/mlx/issues/2137
|
|
||||||
const auto& ac = fence->event.atomic();
|
|
||||||
if (s.device == mlx::core::Device::cpu) {
|
|
||||||
scheduler::enqueue(s, [ac, count = fence->count]() {
|
|
||||||
nvtx3::scoped_range r("Fence::wait()");
|
|
||||||
busy_wait(ac.get(), count);
|
|
||||||
});
|
|
||||||
} else {
|
|
||||||
nvtx3::scoped_range r("Fence::wait(s)");
|
|
||||||
auto& encoder = cu::get_command_encoder(s);
|
|
||||||
encoder.launch_kernel(
|
|
||||||
encoder.stream().last_cuda_stream(), [&](cudaStream_t stream) {
|
|
||||||
busy_wait_kernel<<<1, 1, 0>>>(ac.get(), fence->count);
|
|
||||||
});
|
|
||||||
encoder.add_completed_handler([ac]() {});
|
|
||||||
encoder.end_encoding();
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
void Fence::update(Stream s, const array&) {
|
|
||||||
auto* fence = static_cast<FenceImpl*>(fence_.get());
|
|
||||||
fence->count++;
|
|
||||||
fence->event.signal(s, fence->count);
|
|
||||||
}
|
|
||||||
|
|
||||||
} // namespace mlx::core
|
|
@@ -11,8 +11,6 @@
|
|||||||
#include "mlx/primitives.h"
|
#include "mlx/primitives.h"
|
||||||
#include "mlx/utils.h"
|
#include "mlx/utils.h"
|
||||||
|
|
||||||
using namespace fmt::literals;
|
|
||||||
|
|
||||||
namespace mlx::core {
|
namespace mlx::core {
|
||||||
|
|
||||||
inline void build_kernel(
|
inline void build_kernel(
|
||||||
@@ -21,21 +19,12 @@ inline void build_kernel(
|
|||||||
const std::vector<array>& inputs,
|
const std::vector<array>& inputs,
|
||||||
const std::vector<array>& outputs,
|
const std::vector<array>& outputs,
|
||||||
const std::vector<array>& tape,
|
const std::vector<array>& tape,
|
||||||
const std::unordered_set<uintptr_t>& constant_ids,
|
const std::function<bool(size_t)>& is_constant,
|
||||||
bool contiguous,
|
bool contiguous,
|
||||||
int ndim,
|
int ndim,
|
||||||
bool dynamic_dims,
|
bool dynamic_dims,
|
||||||
bool use_big_index = false,
|
bool use_big_index = false,
|
||||||
int work_per_thread = 1) {
|
int work_per_thread = 1) {
|
||||||
// All outputs should have the exact same shape and will be row contiguous
|
|
||||||
auto output_shape = outputs[0].shape();
|
|
||||||
auto output_strides = outputs[0].strides();
|
|
||||||
|
|
||||||
// Constants are scalars that are captured by value and cannot change
|
|
||||||
auto is_constant = [&constant_ids](const array& x) {
|
|
||||||
return constant_ids.find(x.id()) != constant_ids.end();
|
|
||||||
};
|
|
||||||
|
|
||||||
NodeNamer namer;
|
NodeNamer namer;
|
||||||
bool add_indices = false;
|
bool add_indices = false;
|
||||||
int cnt = 0;
|
int cnt = 0;
|
||||||
@@ -45,14 +34,15 @@ inline void build_kernel(
|
|||||||
"[[host_name(\"{0}\")]]\n[[kernel]] void {0}(\n", kernel_name);
|
"[[host_name(\"{0}\")]]\n[[kernel]] void {0}(\n", kernel_name);
|
||||||
|
|
||||||
// Add the input arguments
|
// Add the input arguments
|
||||||
for (auto& x : inputs) {
|
for (size_t i = 0; i < inputs.size(); ++i) {
|
||||||
auto& xname = namer.get_name(x);
|
|
||||||
|
|
||||||
// Skip constants from the input list
|
// Skip constants from the input list
|
||||||
if (is_constant(x)) {
|
if (is_constant(i)) {
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
const auto& x = inputs[i];
|
||||||
|
auto& xname = namer.get_name(x);
|
||||||
|
|
||||||
// Scalars and contiguous need no strides
|
// Scalars and contiguous need no strides
|
||||||
if (!is_scalar(x) && !contiguous) {
|
if (!is_scalar(x) && !contiguous) {
|
||||||
add_indices = true;
|
add_indices = true;
|
||||||
@@ -80,8 +70,6 @@ inline void build_kernel(
|
|||||||
}
|
}
|
||||||
// Add output strides and shape to extract the indices.
|
// Add output strides and shape to extract the indices.
|
||||||
if (!contiguous) {
|
if (!contiguous) {
|
||||||
os += fmt::format(
|
|
||||||
" constant const int64_t* output_strides [[buffer({0})]],\n", cnt++);
|
|
||||||
os += fmt::format(
|
os += fmt::format(
|
||||||
" constant const int* output_shape [[buffer({0})]],\n", cnt++);
|
" constant const int* output_shape [[buffer({0})]],\n", cnt++);
|
||||||
} else {
|
} else {
|
||||||
@@ -125,7 +113,7 @@ inline void build_kernel(
|
|||||||
auto& x = inputs[i];
|
auto& x = inputs[i];
|
||||||
auto& xname = namer.get_name(x);
|
auto& xname = namer.get_name(x);
|
||||||
|
|
||||||
if (is_constant(x)) {
|
if (is_constant(i)) {
|
||||||
auto type_str = get_type_string(x.dtype());
|
auto type_str = get_type_string(x.dtype());
|
||||||
std::ostringstream ss;
|
std::ostringstream ss;
|
||||||
print_constant(ss, x);
|
print_constant(ss, x);
|
||||||
@@ -271,11 +259,6 @@ inline void build_kernel(
|
|||||||
void Compiled::eval_gpu(
|
void Compiled::eval_gpu(
|
||||||
const std::vector<array>& inputs,
|
const std::vector<array>& inputs,
|
||||||
std::vector<array>& outputs) {
|
std::vector<array>& outputs) {
|
||||||
// Make the name for the kernel library
|
|
||||||
if (kernel_lib_.empty()) {
|
|
||||||
kernel_lib_ = build_lib_name(inputs_, outputs_, tape_, constant_ids_);
|
|
||||||
}
|
|
||||||
|
|
||||||
// Get the kernel if someone else built it already
|
// Get the kernel if someone else built it already
|
||||||
auto& s = stream();
|
auto& s = stream();
|
||||||
auto& d = metal::device(s.device);
|
auto& d = metal::device(s.device);
|
||||||
@@ -290,7 +273,7 @@ void Compiled::eval_gpu(
|
|||||||
inputs_,
|
inputs_,
|
||||||
outputs_,
|
outputs_,
|
||||||
tape_,
|
tape_,
|
||||||
constant_ids_,
|
is_constant_,
|
||||||
/* contiguous = */ true,
|
/* contiguous = */ true,
|
||||||
/* ndim = */ 0,
|
/* ndim = */ 0,
|
||||||
/* dynamic_dims = */ false,
|
/* dynamic_dims = */ false,
|
||||||
@@ -302,7 +285,7 @@ void Compiled::eval_gpu(
|
|||||||
inputs_,
|
inputs_,
|
||||||
outputs_,
|
outputs_,
|
||||||
tape_,
|
tape_,
|
||||||
constant_ids_,
|
is_constant_,
|
||||||
/* contiguous = */ true,
|
/* contiguous = */ true,
|
||||||
/* ndim = */ 0,
|
/* ndim = */ 0,
|
||||||
/* dynamic_dims = */ false,
|
/* dynamic_dims = */ false,
|
||||||
@@ -315,7 +298,7 @@ void Compiled::eval_gpu(
|
|||||||
inputs_,
|
inputs_,
|
||||||
outputs_,
|
outputs_,
|
||||||
tape_,
|
tape_,
|
||||||
constant_ids_,
|
is_constant_,
|
||||||
/* contiguous = */ false,
|
/* contiguous = */ false,
|
||||||
/* ndim = */ i,
|
/* ndim = */ i,
|
||||||
/* dynamic_dims = */ false,
|
/* dynamic_dims = */ false,
|
||||||
@@ -328,7 +311,7 @@ void Compiled::eval_gpu(
|
|||||||
inputs_,
|
inputs_,
|
||||||
outputs_,
|
outputs_,
|
||||||
tape_,
|
tape_,
|
||||||
constant_ids_,
|
is_constant_,
|
||||||
/* contiguous = */ false,
|
/* contiguous = */ false,
|
||||||
/* ndim = */ i,
|
/* ndim = */ i,
|
||||||
/* dynamic_dims = */ false,
|
/* dynamic_dims = */ false,
|
||||||
@@ -342,7 +325,7 @@ void Compiled::eval_gpu(
|
|||||||
inputs_,
|
inputs_,
|
||||||
outputs_,
|
outputs_,
|
||||||
tape_,
|
tape_,
|
||||||
constant_ids_,
|
is_constant_,
|
||||||
/* contiguous = */ false,
|
/* contiguous = */ false,
|
||||||
/* ndim = */ 0,
|
/* ndim = */ 0,
|
||||||
/* dynamic_dims = */ true,
|
/* dynamic_dims = */ true,
|
||||||
@@ -354,7 +337,7 @@ void Compiled::eval_gpu(
|
|||||||
inputs_,
|
inputs_,
|
||||||
outputs_,
|
outputs_,
|
||||||
tape_,
|
tape_,
|
||||||
constant_ids_,
|
is_constant_,
|
||||||
/* contiguous = */ false,
|
/* contiguous = */ false,
|
||||||
/* ndim = */ 0,
|
/* ndim = */ 0,
|
||||||
/* dynamic_dims = */ true,
|
/* dynamic_dims = */ true,
|
||||||
@@ -363,70 +346,13 @@ void Compiled::eval_gpu(
|
|||||||
return kernel;
|
return kernel;
|
||||||
});
|
});
|
||||||
|
|
||||||
// Figure out which kernel we are using
|
|
||||||
auto& output_shape = outputs[0].shape();
|
|
||||||
auto contiguous = compiled_check_contiguity(inputs, output_shape);
|
|
||||||
|
|
||||||
// Collapse contiguous dims to route to a faster kernel if possible. Also
|
// Collapse contiguous dims to route to a faster kernel if possible. Also
|
||||||
// handle all broadcasting.
|
// handle all broadcasting.
|
||||||
std::vector<Strides> initial_strides;
|
auto [contiguous, shape, strides] =
|
||||||
initial_strides.push_back(outputs[0].strides());
|
compiled_collapse_contiguous_dims(inputs, outputs[0], is_constant_);
|
||||||
Shape shape;
|
|
||||||
std::vector<Strides> strides;
|
|
||||||
if (!contiguous) {
|
|
||||||
for (int i = 0; i < inputs.size(); i++) {
|
|
||||||
// Skip constants.
|
|
||||||
if (constant_ids_.find(inputs_[i].id()) != constant_ids_.end()) {
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
auto& x = inputs[i];
|
|
||||||
|
|
||||||
// Skip scalar inputs.
|
// Whether to use large index.
|
||||||
if (is_scalar(x)) {
|
bool large = compiled_use_large_index(inputs, outputs, contiguous);
|
||||||
continue;
|
|
||||||
}
|
|
||||||
|
|
||||||
// Broadcast the inputs to the output shape.
|
|
||||||
Strides xstrides;
|
|
||||||
int j = 0;
|
|
||||||
for (; j < output_shape.size() - x.ndim(); j++) {
|
|
||||||
if (output_shape[j] == 1) {
|
|
||||||
xstrides.push_back(outputs[0].strides()[j]);
|
|
||||||
} else {
|
|
||||||
xstrides.push_back(0);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
for (int i = 0; i < x.ndim(); i++, j++) {
|
|
||||||
if (x.shape(i) == 1) {
|
|
||||||
if (output_shape[j] == 1) {
|
|
||||||
xstrides.push_back(outputs[0].strides()[j]);
|
|
||||||
} else {
|
|
||||||
xstrides.push_back(0);
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
xstrides.push_back(x.strides()[i]);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
initial_strides.push_back(std::move(xstrides));
|
|
||||||
}
|
|
||||||
std::tie(shape, strides) =
|
|
||||||
collapse_contiguous_dims(output_shape, initial_strides, INT32_MAX);
|
|
||||||
}
|
|
||||||
|
|
||||||
bool large;
|
|
||||||
if (contiguous) {
|
|
||||||
size_t max_size = 0;
|
|
||||||
for (auto& in : inputs) {
|
|
||||||
max_size = std::max(max_size, in.data_size());
|
|
||||||
}
|
|
||||||
large = (max_size > UINT32_MAX);
|
|
||||||
} else {
|
|
||||||
size_t max_size = 0;
|
|
||||||
for (auto& o : outputs) {
|
|
||||||
max_size = std::max(max_size, o.size());
|
|
||||||
}
|
|
||||||
large = (max_size > UINT32_MAX);
|
|
||||||
}
|
|
||||||
|
|
||||||
// Get the kernel from the lib
|
// Get the kernel from the lib
|
||||||
int ndim = shape.size();
|
int ndim = shape.size();
|
||||||
@@ -451,7 +377,7 @@ void Compiled::eval_gpu(
|
|||||||
int stride_idx = 1; // idx 0 is the output strides
|
int stride_idx = 1; // idx 0 is the output strides
|
||||||
Strides in_strides;
|
Strides in_strides;
|
||||||
for (int i = 0; i < inputs.size(); i++) {
|
for (int i = 0; i < inputs.size(); i++) {
|
||||||
if (constant_ids_.find(inputs_[i].id()) != constant_ids_.end()) {
|
if (is_constant_(i)) {
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
auto& x = inputs[i];
|
auto& x = inputs[i];
|
||||||
@@ -468,8 +394,7 @@ void Compiled::eval_gpu(
|
|||||||
compute_encoder.set_vector_bytes(in_strides, cnt++);
|
compute_encoder.set_vector_bytes(in_strides, cnt++);
|
||||||
}
|
}
|
||||||
|
|
||||||
compiled_allocate_outputs(
|
compiled_allocate_outputs(inputs, outputs, is_constant_, contiguous);
|
||||||
inputs, outputs, inputs_, constant_ids_, contiguous);
|
|
||||||
|
|
||||||
// Put the outputs in
|
// Put the outputs in
|
||||||
for (auto& x : outputs) {
|
for (auto& x : outputs) {
|
||||||
@@ -478,7 +403,6 @@ void Compiled::eval_gpu(
|
|||||||
|
|
||||||
// Put the output shape and strides in
|
// Put the output shape and strides in
|
||||||
if (!contiguous) {
|
if (!contiguous) {
|
||||||
compute_encoder.set_vector_bytes(strides[0], cnt++);
|
|
||||||
compute_encoder.set_vector_bytes(shape, cnt++);
|
compute_encoder.set_vector_bytes(shape, cnt++);
|
||||||
} else {
|
} else {
|
||||||
auto size = outputs[0].data_size();
|
auto size = outputs[0].data_size();
|
||||||
|
@@ -1,16 +1,20 @@
|
|||||||
// Copyright © 2023-2024 Apple Inc.
|
// Copyright © 2023-2024 Apple Inc.
|
||||||
#include <cstdlib>
|
#include <cstdlib>
|
||||||
#include <map>
|
#include <map>
|
||||||
|
#include <sstream>
|
||||||
#include <unordered_map>
|
#include <unordered_map>
|
||||||
#include <unordered_set>
|
#include <unordered_set>
|
||||||
|
|
||||||
#include "mlx/allocator.h"
|
#include "mlx/allocator.h"
|
||||||
|
#include "mlx/backend/common/compiled.h"
|
||||||
#include "mlx/compile.h"
|
#include "mlx/compile.h"
|
||||||
#include "mlx/compile_impl.h"
|
#include "mlx/compile_impl.h"
|
||||||
#include "mlx/fast_primitives.h"
|
#include "mlx/fast_primitives.h"
|
||||||
|
#include "mlx/graph_utils.h"
|
||||||
#include "mlx/primitives.h"
|
#include "mlx/primitives.h"
|
||||||
#include "mlx/transforms.h"
|
#include "mlx/transforms.h"
|
||||||
#include "mlx/transforms_impl.h"
|
#include "mlx/transforms_impl.h"
|
||||||
|
#include "mlx/utils.h"
|
||||||
|
|
||||||
namespace mlx::core {
|
namespace mlx::core {
|
||||||
|
|
||||||
@@ -82,7 +86,54 @@ Compiled::Compiled(
|
|||||||
inputs_(std::move(inputs)),
|
inputs_(std::move(inputs)),
|
||||||
outputs_(std::move(outputs)),
|
outputs_(std::move(outputs)),
|
||||||
tape_(std::move(tape)),
|
tape_(std::move(tape)),
|
||||||
constant_ids_(std::move(constant_ids)) {}
|
constant_ids_(std::move(constant_ids)),
|
||||||
|
is_constant_([this](size_t i) {
|
||||||
|
return constant_ids_.find(inputs_[i].id()) != constant_ids_.end();
|
||||||
|
}) {
|
||||||
|
// Build the kernel name.
|
||||||
|
NodeNamer namer;
|
||||||
|
std::ostringstream os;
|
||||||
|
std::ostringstream constant_hasher;
|
||||||
|
|
||||||
|
// Fill the input names. This is not really necessary, I just like having A,
|
||||||
|
// B, C, ... as the inputs.
|
||||||
|
for (const auto& x : inputs_) {
|
||||||
|
namer.get_name(x);
|
||||||
|
}
|
||||||
|
|
||||||
|
// The primitives describing the tape. For unary and binary primitives this
|
||||||
|
// must be enough to describe the full computation.
|
||||||
|
for (const auto& a : tape_) {
|
||||||
|
// name and type of output
|
||||||
|
os << namer.get_name(a) << kindof(a.dtype()) << a.itemsize();
|
||||||
|
// computation performed
|
||||||
|
a.primitive().print(os);
|
||||||
|
// name of inputs to the function
|
||||||
|
for (auto& inp : a.inputs()) {
|
||||||
|
os << namer.get_name(inp);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
os << "_";
|
||||||
|
|
||||||
|
for (const auto& x : inputs_) {
|
||||||
|
if (constant_ids_.find(x.id()) != constant_ids_.end()) {
|
||||||
|
os << "C";
|
||||||
|
print_constant(constant_hasher, x);
|
||||||
|
} else {
|
||||||
|
os << (is_scalar(x) ? "S" : "V");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
os << "_";
|
||||||
|
for (const auto& x : inputs) {
|
||||||
|
if (constant_ids.find(x.id()) != constant_ids.end()) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
os << kindof(x.dtype()) << x.itemsize();
|
||||||
|
}
|
||||||
|
os << "_" << std::hash<std::string>{}(constant_hasher.str());
|
||||||
|
|
||||||
|
kernel_lib_ = os.str();
|
||||||
|
}
|
||||||
|
|
||||||
std::vector<array> Compiled::vjp(
|
std::vector<array> Compiled::vjp(
|
||||||
const std::vector<array>&,
|
const std::vector<array>&,
|
||||||
|
@@ -627,6 +627,7 @@ class Compiled : public Primitive {
|
|||||||
const std::vector<array> outputs_;
|
const std::vector<array> outputs_;
|
||||||
const std::vector<array> tape_;
|
const std::vector<array> tape_;
|
||||||
const std::unordered_set<uintptr_t> constant_ids_;
|
const std::unordered_set<uintptr_t> constant_ids_;
|
||||||
|
const std::function<bool(size_t)> is_constant_;
|
||||||
|
|
||||||
std::string kernel_lib_;
|
std::string kernel_lib_;
|
||||||
};
|
};
|
||||||
|
@@ -208,7 +208,9 @@ array eval_impl(std::vector<array> outputs, bool async) {
|
|||||||
// output arrays stream
|
// output arrays stream
|
||||||
fences[it->second].wait(stream, in);
|
fences[it->second].wait(stream, in);
|
||||||
} else if (in.event().valid()) {
|
} else if (in.event().valid()) {
|
||||||
if (in.event().stream() != stream) {
|
if (in.event().is_signaled()) {
|
||||||
|
in.detach_event();
|
||||||
|
} else if (in.event().stream() != stream) {
|
||||||
// Use event to wait across async eval
|
// Use event to wait across async eval
|
||||||
in.event().wait(stream);
|
in.event().wait(stream);
|
||||||
}
|
}
|
||||||
|
@@ -4,7 +4,7 @@
|
|||||||
|
|
||||||
#define MLX_VERSION_MAJOR 0
|
#define MLX_VERSION_MAJOR 0
|
||||||
#define MLX_VERSION_MINOR 26
|
#define MLX_VERSION_MINOR 26
|
||||||
#define MLX_VERSION_PATCH 0
|
#define MLX_VERSION_PATCH 1
|
||||||
#define MLX_VERSION_NUMERIC \
|
#define MLX_VERSION_NUMERIC \
|
||||||
(100000 * MLX_VERSION_MAJOR + 1000 * MLX_VERSION_MINOR + MLX_VERSION_PATCH)
|
(100000 * MLX_VERSION_MAJOR + 1000 * MLX_VERSION_MINOR + MLX_VERSION_PATCH)
|
||||||
|
|
||||||
|
Reference in New Issue
Block a user