Compare commits


8 Commits

Author SHA1 Message Date
Awni Hannun
c6a20b427a Improve metal elementwise kernels (#2247)
* improve metal elementwise kernels

* compile and copy

* fix jit
2025-06-06 11:37:40 -07:00
Awni Hannun
a5ac9244c4 fix linux linking error (#2248) 2025-06-06 10:41:51 -07:00
Awni Hannun
c763fe1be0 default strict mode for module update and update_modules (#2239) 2025-06-05 15:27:02 -07:00
Cheng
52dc8c8cd5 Add profiler annotations in common primitives for CUDA backend (#2244) 2025-06-04 19:55:12 -07:00
Angelos Katharopoulos
aede70e81d Perf regression fix (#2243) 2025-06-03 17:55:12 -07:00
Cheng
85a8beb5e4 Avoid atomic updates across CPU/GPU in CUDA event (#2231) 2025-06-03 16:49:06 -07:00
Cheng
0bb89e9e5f Share more common code in Compiled (#2240)
* Share more common code in Compiled

* Remove build_lib_name
2025-06-03 16:48:50 -07:00
Cheng
5685ceb3c7 Avoid invoking allocator::malloc when creating CUDA event (#2232) 2025-06-03 16:48:40 -07:00
34 changed files with 769 additions and 522 deletions

mlx/backend/common/compiled.cpp

@@ -1,8 +1,7 @@
 // Copyright © 2023-2024 Apple Inc.
 
 #include "mlx/backend/common/compiled.h"
-#include "mlx/graph_utils.h"
-#include "mlx/primitives.h"
+#include "mlx/backend/common/utils.h"
 #include "mlx/utils.h"
 
 namespace mlx::core {
@@ -79,55 +78,6 @@ std::string get_type_string(Dtype d) {
   }
 }
 
-std::string build_lib_name(
-    const std::vector<array>& inputs,
-    const std::vector<array>& outputs,
-    const std::vector<array>& tape,
-    const std::unordered_set<uintptr_t>& constant_ids) {
-  NodeNamer namer;
-  std::ostringstream os;
-  std::ostringstream constant_hasher;
-
-  // Fill the input names. This is not really necessary, I just like having A,
-  // B, C, ... as the inputs.
-  for (auto& x : inputs) {
-    namer.get_name(x);
-  }
-
-  // The primitives describing the tape. For unary and binary primitives this
-  // must be enough to describe the full computation.
-  for (auto& a : tape) {
-    // name and type of output
-    os << namer.get_name(a) << kindof(a.dtype()) << a.itemsize();
-    // computation performed
-    a.primitive().print(os);
-    // name of inputs to the function
-    for (auto& inp : a.inputs()) {
-      os << namer.get_name(inp);
-    }
-  }
-  os << "_";
-
-  for (auto& x : inputs) {
-    if (constant_ids.find(x.id()) != constant_ids.end()) {
-      os << "C";
-      print_constant(constant_hasher, x);
-    } else {
-      os << (is_scalar(x) ? "S" : "V");
-    }
-  }
-  os << "_";
-  for (auto& x : inputs) {
-    if (constant_ids.find(x.id()) != constant_ids.end()) {
-      continue;
-    }
-    os << kindof(x.dtype()) << x.itemsize();
-  }
-  os << "_" << std::hash<std::string>{}(constant_hasher.str());
-
-  return os.str();
-}
-
 bool compiled_check_contiguity(
     const std::vector<array>& inputs,
     const Shape& shape) {
@@ -159,8 +109,7 @@ bool compiled_check_contiguity(
 void compiled_allocate_outputs(
     const std::vector<array>& inputs,
     std::vector<array>& outputs,
-    const std::vector<array>& inputs_,
-    const std::unordered_set<uintptr_t>& constant_ids_,
+    const std::function<bool(size_t)>& is_constant,
     bool contiguous) {
   if (contiguous) {
     int o = 0;
@@ -175,8 +124,7 @@ void compiled_allocate_outputs(
       // - Donatable
       // - Not a constant
       if (in.itemsize() == outputs[o].itemsize() && !is_scalar(in) &&
-          in.is_donatable() &&
-          constant_ids_.find(inputs_[i].id()) == constant_ids_.end()) {
+          in.is_donatable() && !is_constant(i)) {
        outputs[o++].copy_shared_buffer(in);
       }
       // Get representative input flags to properly set non-donated outputs
@@ -204,7 +152,7 @@ void compiled_allocate_outputs(
       // - Not a constant
       if (in.flags().row_contiguous && in.size() == outputs[o].size() &&
           in.itemsize() == outputs[o].itemsize() && in.is_donatable() &&
-          constant_ids_.find(inputs_[i].id()) == constant_ids_.end()) {
+          !is_constant(i)) {
         outputs[o].copy_shared_buffer(
             in, outputs[o].strides(), in.flags(), in.data_size());
         o++;
@@ -216,4 +164,74 @@ void compiled_allocate_outputs(
     }
   }
 }
+
+std::tuple<bool, Shape, std::vector<Strides>> compiled_collapse_contiguous_dims(
+    const std::vector<array>& inputs,
+    const array& out,
+    const std::function<bool(size_t)>& is_constant) {
+  const Shape& shape = out.shape();
+  bool contiguous = compiled_check_contiguity(inputs, shape);
+  if (contiguous) {
+    return {true, shape, {}};
+  }
+
+  std::vector<Strides> strides_vec{out.strides()};
+  for (size_t i = 0; i < inputs.size(); ++i) {
+    // Skip constants.
+    if (is_constant(i)) {
+      continue;
+    }
+
+    // Skip scalar inputs.
+    const auto& x = inputs[i];
+    if (is_scalar(x)) {
+      continue;
+    }
+
+    // Broadcast the inputs to the output shape.
+    Strides xstrides;
+    size_t j = 0;
+    for (; j < shape.size() - x.ndim(); ++j) {
+      if (shape[j] == 1) {
+        xstrides.push_back(out.strides()[j]);
+      } else {
+        xstrides.push_back(0);
+      }
+    }
+    for (size_t i = 0; i < x.ndim(); ++i, ++j) {
+      if (x.shape(i) == 1) {
+        if (shape[j] == 1) {
+          xstrides.push_back(out.strides()[j]);
+        } else {
+          xstrides.push_back(0);
+        }
+      } else {
+        xstrides.push_back(x.strides()[i]);
+      }
+    }
+    strides_vec.push_back(std::move(xstrides));
+  }
+
+  auto tup = collapse_contiguous_dims(shape, strides_vec, INT32_MAX);
+  return {false, std::move(std::get<0>(tup)), std::move(std::get<1>(tup))};
+}
+
+bool compiled_use_large_index(
+    const std::vector<array>& inputs,
+    const std::vector<array>& outputs,
+    bool contiguous) {
+  if (contiguous) {
+    size_t max_size = 0;
+    for (const auto& in : inputs) {
+      max_size = std::max(max_size, in.data_size());
+    }
+    return max_size > UINT32_MAX;
+  } else {
+    size_t max_size = 0;
+    for (const auto& o : outputs) {
+      max_size = std::max(max_size, o.size());
+    }
+    return max_size > UINT32_MAX;
+  }
+}
+
 } // namespace mlx::core
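
The refactor in this file replaces the `inputs_` / `constant_ids_` pair with a single `std::function<bool(size_t)>` predicate, so the shared helpers no longer depend on how `Compiled` tracks its captured constants. A minimal sketch of driving the new helpers with such a predicate, assuming a caller that keeps an id set the way the old code did (`constant_ids`, `inputs`, `outputs`, and `out` are placeholders; only the three helper functions are the real API):

  std::unordered_set<uintptr_t> constant_ids = /* ids of captured constants */;
  auto is_constant = [&](size_t i) {
    return constant_ids.find(inputs[i].id()) != constant_ids.end();
  };
  // One predicate now serves all three helpers.
  auto [contiguous, shape, strides] =
      compiled_collapse_contiguous_dims(inputs, out, is_constant);
  bool large = compiled_use_large_index(inputs, outputs, contiguous);
  compiled_allocate_outputs(inputs, outputs, is_constant, contiguous);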

mlx/backend/common/compiled.h

@@ -1,9 +1,8 @@
 // Copyright © 2023-2024 Apple Inc.
 
 #pragma once
 
+#include <functional>
 #include <iomanip>
-#include <sstream>
-#include <unordered_set>
 
 #include "mlx/array.h"
 #include "mlx/primitives.h"
@@ -14,12 +13,6 @@ inline bool is_static_cast(const Primitive& p) {
   return (typeid(p) == typeid(Broadcast) || typeid(p) == typeid(AsType));
 }
 
-std::string build_lib_name(
-    const std::vector<array>& inputs,
-    const std::vector<array>& outputs,
-    const std::vector<array>& tape,
-    const std::unordered_set<uintptr_t>& constant_ids);
-
 std::string get_type_string(Dtype d);
 
 template <typename T>
@@ -60,8 +53,19 @@ bool compiled_check_contiguity(
 void compiled_allocate_outputs(
     const std::vector<array>& inputs,
     std::vector<array>& outputs,
-    const std::vector<array>& inputs_,
-    const std::unordered_set<uintptr_t>& constant_ids_,
+    const std::function<bool(size_t)>& is_constant,
     bool contiguous);
 
+// Collapse contiguous dims ignoring scalars and constants.
+std::tuple<bool, Shape, std::vector<Strides>> compiled_collapse_contiguous_dims(
+    const std::vector<array>& inputs,
+    const array& out,
+    const std::function<bool(size_t)>& is_constant);
+
+// Return whether the kernel should use large index.
+bool compiled_use_large_index(
+    const std::vector<array>& inputs,
+    const std::vector<array>& outputs,
+    bool contiguous);
+
 } // namespace mlx::core

mlx/backend/cpu/compiled.cpp

@@ -146,18 +146,9 @@ inline void build_kernel(
     const std::vector<array>& inputs,
     const std::vector<array>& outputs,
     const std::vector<array>& tape,
-    const std::unordered_set<uintptr_t>& constant_ids,
+    const std::function<bool(size_t)>& is_constant,
     bool contiguous,
     int ndim) {
-  // All outputs should have the exact same shape and will be row contiguous
-  auto output_shape = outputs[0].shape();
-  auto output_strides = outputs[0].strides();
-
-  // Constants are scalars that are captured by value and cannot change
-  auto is_constant = [&constant_ids](const array& x) {
-    return constant_ids.find(x.id()) != constant_ids.end();
-  };
-
   NodeNamer namer;
 
 #ifdef _MSC_VER
@@ -170,14 +161,15 @@ inline void build_kernel(
 
   // Add the input arguments
   int cnt = 0;
-  for (auto& x : inputs) {
-    auto& xname = namer.get_name(x);
-
+  for (size_t i = 0; i < inputs.size(); ++i) {
     // Skip constants from the input list
-    if (is_constant(x)) {
+    if (is_constant(i)) {
       continue;
     }
+    const auto& x = inputs[i];
+    auto& xname = namer.get_name(x);
     auto tstr = get_type_string(x.dtype());
     os << "  " << tstr << "* " << xname << " = (" << tstr << "*)args[" << cnt++
        << "];" << std::endl;
@@ -211,10 +203,11 @@ inline void build_kernel(
   }
 
   // Read the inputs in tmps
-  for (auto& x : inputs) {
+  for (size_t i = 0; i < inputs.size(); ++i) {
+    const auto& x = inputs[i];
     auto& xname = namer.get_name(x);
 
-    if (is_constant(x)) {
+    if (is_constant(i)) {
       os << "  " << get_type_string(x.dtype()) << " tmp_" << xname << " = ";
       print_constant(os, x);
       os << ";" << std::endl;
@@ -264,8 +257,9 @@ inline void build_kernel(
   } else {
     for (int d = ndim - 1; d >= 0; --d) {
       // Update pointers
-      for (auto& x : inputs) {
-        if (is_constant(x) || is_scalar(x)) {
+      for (size_t i = 0; i < inputs.size(); ++i) {
+        const auto& x = inputs[i];
+        if (is_constant(i) || is_scalar(x)) {
          continue;
         }
         auto& xname = namer.get_name(x);
@@ -287,65 +281,37 @@ inline void build_kernel(
 void Compiled::eval_cpu(
     const std::vector<array>& inputs,
     std::vector<array>& outputs) {
-  if (kernel_lib_.empty()) {
-    kernel_lib_ = build_lib_name(inputs_, outputs_, tape_, constant_ids_);
-  }
-
-  // Figure out which kernel we are using
-  auto& shape = outputs[0].shape();
-  auto contiguous = compiled_check_contiguity(inputs, shape);
-
   auto& encoder = cpu::get_command_encoder(stream());
-  // Handle all broadcasting and collect function input arguments
+  // Collapse contiguous dims to route to a faster kernel if possible. Also
+  // handle all broadcasting.
+  auto [contiguous, shape, strides] =
+      compiled_collapse_contiguous_dims(inputs, outputs[0], is_constant_);
+
+  // Collect function input arguments.
   std::vector<void*> args;
-  std::vector<std::vector<size_t>> strides;
-  for (int i = 0; i < inputs.size(); i++) {
-    // Skip constants.
-    if (constant_ids_.find(inputs_[i].id()) != constant_ids_.end()) {
+  int strides_index = 1;
+  for (size_t i = 0; i < inputs.size(); ++i) {
+    if (is_constant_(i)) {
       continue;
     }
-    auto& x = inputs[i];
+    const auto& x = inputs[i];
     encoder.set_input_array(x);
     args.push_back((void*)x.data<void>());
-
-    if (contiguous || is_scalar(x)) {
-      continue;
+    if (!contiguous && !is_scalar(x)) {
+      args.push_back(strides[strides_index++].data());
     }
-
-    // Broadcast the input to the output shape.
-    std::vector<size_t> xstrides;
-    int j = 0;
-    for (; j < shape.size() - x.ndim(); j++) {
-      if (shape[j] == 1) {
-        xstrides.push_back(outputs[0].strides()[j]);
-      } else {
-        xstrides.push_back(0);
-      }
-    }
-    for (int i = 0; i < x.ndim(); i++, j++) {
-      if (x.shape(i) == 1) {
-        if (shape[j] == 1) {
-          xstrides.push_back(outputs[0].strides()[j]);
-        } else {
-          xstrides.push_back(0);
-        }
-      } else {
-        xstrides.push_back(x.strides()[i]);
-      }
-    }
-    strides.push_back(std::move(xstrides));
-    args.push_back(strides.back().data());
   }
 
   // Get the kernel name from the lib
   int ndim = shape.size();
   auto kernel_name = kernel_lib_ + (contiguous ? "_contiguous" : "_strided_");
   if (!contiguous) {
-    kernel_name += std::to_string(shape.size());
+    kernel_name += std::to_string(ndim);
   }
 
   // Get the function
-  auto fn_ptr = compile(kernel_name, [&]() {
+  auto fn_ptr = compile(kernel_name, [&, contiguous = contiguous]() {
     std::ostringstream kernel;
     kernel << get_kernel_preamble() << std::endl;
     kernel << "extern \"C\" {" << std::endl;
@@ -355,7 +321,7 @@ void Compiled::eval_cpu(
         inputs_,
         outputs_,
         tape_,
-        constant_ids_,
+        is_constant_,
         contiguous,
         ndim);
     // Close extern "C"
@@ -363,26 +329,22 @@ void Compiled::eval_cpu(
     return kernel.str();
   });
 
-  compiled_allocate_outputs(
-      inputs, outputs, inputs_, constant_ids_, contiguous);
+  compiled_allocate_outputs(inputs, outputs, is_constant_, contiguous);
 
   for (auto& x : outputs) {
     args.push_back(x.data<void>());
     encoder.set_output_array(x);
   }
-  Shape out_shape;
   if (!contiguous) {
-    out_shape = outputs[0].shape();
-    args.push_back((void*)out_shape.data());
+    args.push_back((void*)shape.data());
   } else {
     args.push_back((void*)outputs[0].data_size());
   }
   auto fun = (void (*)(void**))fn_ptr;
-  encoder.dispatch(
-      [fun,
-       args = std::move(args),
-       strides = std::move(strides),
-       out_shape = std::move(out_shape)]() mutable { fun(args.data()); });
+  encoder.dispatch([fun,
+                    args = std::move(args),
+                    strides = std::move(strides),
+                    shape = std::move(shape)]() mutable { fun(args.data()); });
 }
 
 } // namespace mlx::core
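
One subtlety in the dispatch above: `args` stores raw pointers into `strides` and `shape`, so both containers are moved into the closure even though its body only calls `fun(args.data())`. Moving a `std::vector` keeps the same heap buffer alive, so the captured pointers stay valid until the task actually runs. A standalone sketch of that lifetime trick (plain C++, not the MLX encoder API):

  #include <cstdint>
  #include <cstdio>
  #include <functional>
  #include <vector>

  int main() {
    std::vector<int64_t> strides{4, 1};
    std::vector<void*> args{strides.data()};
    // Moving `strides` into the closure transfers ownership of its heap
    // buffer, so args[0] still points at live memory when the task runs.
    std::function<void()> task = [args, strides = std::move(strides)]() {
      std::printf("%lld\n", (long long)*static_cast<int64_t*>(args[0]));
    };
    task(); // prints 4
  }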

mlx/backend/cuda/CMakeLists.txt

@@ -10,13 +10,15 @@ target_sources(
   ${CMAKE_CURRENT_SOURCE_DIR}/device.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/eval.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/event.cu
-  ${CMAKE_CURRENT_SOURCE_DIR}/fence.cu
+  ${CMAKE_CURRENT_SOURCE_DIR}/fence.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/kernel_utils.cu
   ${CMAKE_CURRENT_SOURCE_DIR}/primitives.cu
   ${CMAKE_CURRENT_SOURCE_DIR}/slicing.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/utils.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/worker.cpp)
 
+target_compile_definitions(mlx PRIVATE MLX_USE_CUDA)
+
 # Enable defining device lambda functions.
 target_compile_options(mlx
   PRIVATE "$<$<COMPILE_LANGUAGE:CUDA>:--extended-lambda>")

mlx/backend/cuda/allocator.cpp

@@ -18,7 +18,10 @@ CudaAllocator::CudaAllocator()
     : buffer_cache_(
           getpagesize(),
           [](CudaBuffer* buf) { return buf->size; },
-          [this](CudaBuffer* buf) { cuda_free(buf); }) {
+          [this](CudaBuffer* buf) {
+            cuda_free(buf->data);
+            delete buf;
+          }) {
   // TODO: Set memory limit for multi-device.
   size_t free, total;
   CHECK_CUDA_ERROR(cudaMemGetInfo(&free, &total));
@@ -70,7 +73,8 @@ void CudaAllocator::free(Buffer buffer) {
     buffer_cache_.recycle_to_cache(buf);
   } else {
     lock.unlock();
-    cuda_free(buf);
+    cuda_free(buf->data);
+    delete buf;
   }
 }
 
@@ -87,6 +91,25 @@ void CudaAllocator::register_this_thread() {
   allowed_threads_.insert(std::this_thread::get_id());
 }
 
+void CudaAllocator::cuda_free(void* buf) {
+  // If cuda_free() is called from an unregistered thread, reschedule the call
+  // to the worker.
+  {
+    std::lock_guard lock(worker_mutex_);
+    if (allowed_threads_.count(std::this_thread::get_id()) == 0) {
+      if (!worker_) {
+        worker_.reset(new Worker);
+      }
+      worker_->add_task([this, buf]() { this->cuda_free(buf); });
+      worker_->end_batch();
+      worker_->commit();
+      return;
+    }
+  }
+  cudaFree(buf);
+}
+
 size_t CudaAllocator::get_active_memory() const {
   return active_memory_;
 }
@@ -125,26 +148,6 @@ void CudaAllocator::clear_cache() {
   buffer_cache_.clear();
 }
 
-void CudaAllocator::cuda_free(CudaBuffer* buf) {
-  // If cuda_free() is called from an unregistered thread, reschedule the call
-  // to the worker.
-  {
-    std::lock_guard lock(worker_mutex_);
-    if (allowed_threads_.count(std::this_thread::get_id()) == 0) {
-      if (!worker_) {
-        worker_.reset(new Worker);
-      }
-      worker_->add_task([this, buf]() { this->cuda_free(buf); });
-      worker_->end_batch();
-      worker_->commit();
-      return;
-    }
-  }
-  cudaFree(buf->data);
-  delete buf;
-}
-
 CudaAllocator& allocator() {
   // By creating the |allocator_| on heap, the destructor of CudaAllocator
   // will not be called on exit and buffers in the cache will be leaked. This
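
The new `cuda_free(void*)` overload implements a bounce-to-a-safe-thread pattern: if the calling thread was never registered, the free is queued onto the worker and retried there, so `cudaFree` only ever executes on threads allowed to touch the driver. A reduced sketch of the pattern in plain C++ (the queue stands in for the MLX `Worker`; `std::free` stands in for `cudaFree`):

  #include <cstdlib>
  #include <functional>
  #include <mutex>
  #include <queue>
  #include <set>
  #include <thread>

  std::mutex m;
  std::set<std::thread::id> allowed;        // registered threads
  std::queue<std::function<void()>> tasks;  // drained by a registered thread

  void safe_free(void* p) {
    {
      std::lock_guard lock(m);
      if (allowed.count(std::this_thread::get_id()) == 0) {
        tasks.push([p] { safe_free(p); });  // reschedule and return
        return;
      }
    }
    std::free(p);                           // stands in for cudaFree(p)
  }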

mlx/backend/cuda/allocator.h

@@ -34,6 +34,9 @@ class CudaAllocator : public allocator::Allocator {
   // buffers there would result in dead lock.
   void register_this_thread();
 
+  // Call cudaFree in the safe thread.
+  void cuda_free(void* buf);
+
   size_t get_active_memory() const;
   size_t get_peak_memory() const;
   void reset_peak_memory();
@@ -47,8 +50,6 @@ class CudaAllocator : public allocator::Allocator {
   CudaAllocator();
   friend CudaAllocator& allocator();
 
-  void cuda_free(CudaBuffer* buf);
-
   std::mutex worker_mutex_;
   std::unique_ptr<Worker> worker_;
   std::set<std::thread::id> allowed_threads_;

mlx/backend/cuda/event.cu

@@ -1,5 +1,6 @@
 // Copyright © 2024 Apple Inc.
 
+#include "mlx/backend/cuda/allocator.h"
 #include "mlx/backend/cuda/device.h"
 #include "mlx/backend/cuda/event.h"
 #include "mlx/backend/cuda/utils.h"
@@ -111,12 +112,12 @@ __global__ void event_signal_kernel(SharedEvent::Atomic* ac, uint64_t value) {
 
 SharedEvent::SharedEvent() {
   // Allocate cuda::atomic on managed memory.
-  allocator::Buffer buffer = allocator::malloc(sizeof(Atomic));
-  Atomic* ac = static_cast<Atomic*>(buffer.raw_ptr());
+  Atomic* ac;
+  CHECK_CUDA_ERROR(cudaMallocManaged(&ac, sizeof(Atomic)));
   new (ac) Atomic(0);
-  ac_ = std::shared_ptr<Atomic>(ac, [buffer](Atomic* ptr) {
+  ac_ = std::shared_ptr<Atomic>(ac, [](Atomic* ptr) {
     ptr->~Atomic();
-    allocator::free(buffer);
+    allocator().cuda_free(ptr);
   });
 }
 
@@ -155,7 +156,10 @@ void SharedEvent::signal(cudaStream_t stream, uint64_t value) {
 void SharedEvent::signal(Stream s, uint64_t value) {
   nvtx3::scoped_range r("cu::SharedEvent::signal(s)");
   if (s.device == mlx::core::Device::cpu) {
-    scheduler::enqueue(s, [*this, value]() mutable { signal(value); });
+    // Signal through a GPU stream so the atomic is updated on the GPU;
+    // updating the atomic from the CPU sometimes does not get the GPU
+    // notified.
+    static CudaStream stream(device(mlx::core::Device::gpu));
+    scheduler::enqueue(s, [*this, value]() mutable { signal(stream, value); });
   } else {
     auto& encoder = get_command_encoder(s);
     encoder.launch_kernel(
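
The fix above replaces a host-side store to the managed atomic with a kernel launch on a private GPU stream, since a CPU store sometimes fails to wake device-side waiters. A reduced sketch of device-side signaling on a managed atomic, assuming libcu++'s `cuda::atomic` as used by `SharedEvent` (names here are illustrative):

  #include <cuda/atomic>

  using Atomic = cuda::atomic<uint64_t>; // system scope by default

  __global__ void signal_kernel(Atomic* ac, uint64_t value) {
    // A store issued from device code reliably becomes visible to
    // kernels spinning on `ac`.
    ac->store(value);
  }

  void signal_from_host(Atomic* ac, uint64_t value, cudaStream_t stream) {
    signal_kernel<<<1, 1, 0, stream>>>(ac, value);
  }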

mlx/backend/cuda/fence.cpp

@@ -0,0 +1,29 @@
+// Copyright © 2025 Apple Inc.
+
+#include "mlx/fence.h"
+#include "mlx/backend/cuda/event.h"
+
+namespace mlx::core {
+
+struct FenceImpl {
+  uint32_t count;
+  cu::SharedEvent event;
+};
+
+Fence::Fence(Stream s) {
+  fence_ = std::shared_ptr<void>(
+      new FenceImpl{0}, [](void* ptr) { delete static_cast<FenceImpl*>(ptr); });
+}
+
+void Fence::wait(Stream s, const array&) {
+  auto* fence = static_cast<FenceImpl*>(fence_.get());
+  fence->event.wait(fence->count);
+}
+
+void Fence::update(Stream s, const array&) {
+  auto* fence = static_cast<FenceImpl*>(fence_.get());
+  fence->count++;
+  fence->event.signal(s, fence->count);
+}
+
+} // namespace mlx::core
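
The rewritten fence is just a monotonic counter layered on `cu::SharedEvent`: `update` bumps the count and signals, and `wait` blocks until the event value reaches the fence's current count. A hypothetical usage sketch (the streams and array below are placeholders):

  Fence fence(s_gpu);
  fence.update(s_gpu, arr);  // count -> 1, signal issued on s_gpu
  fence.wait(s_other, arr);  // returns once the event reaches 1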

mlx/backend/cuda/fence.cu

@@ -1,70 +0,0 @@
// Copyright © 2025 Apple Inc.
#include "mlx/backend/cuda/device.h"
#include "mlx/backend/cuda/event.h"
#include "mlx/fence.h"
#include "mlx/scheduler.h"
#include <nvtx3/nvtx3.hpp>
namespace mlx::core {
namespace {
__host__ __device__ void busy_wait(cuda::atomic<uint64_t>* ac, uint64_t value) {
while (true) {
// In theory the atomic_thread_fence is not needed, but for CUDA 11 without
// it the load() may never return new value.
cuda::atomic_thread_fence(cuda::memory_order_seq_cst);
uint64_t current = ac->load();
if (current >= value) {
break;
}
}
}
__global__ void busy_wait_kernel(cuda::atomic<uint64_t>* ac, uint64_t value) {
busy_wait(ac, value);
}
} // namespace
struct FenceImpl {
uint32_t count;
cu::SharedEvent event;
};
Fence::Fence(Stream s) {
fence_ = std::shared_ptr<void>(
new FenceImpl{0}, [](void* ptr) { delete static_cast<FenceImpl*>(ptr); });
}
void Fence::wait(Stream s, const array&) {
auto* fence = static_cast<FenceImpl*>(fence_.get());
// We can't use SharedEvent::wait because it could hang in CUDA 11, see also:
// https://github.com/ml-explore/mlx/issues/2137
const auto& ac = fence->event.atomic();
if (s.device == mlx::core::Device::cpu) {
scheduler::enqueue(s, [ac, count = fence->count]() {
nvtx3::scoped_range r("Fence::wait()");
busy_wait(ac.get(), count);
});
} else {
nvtx3::scoped_range r("Fence::wait(s)");
auto& encoder = cu::get_command_encoder(s);
encoder.launch_kernel(
encoder.stream().last_cuda_stream(), [&](cudaStream_t stream) {
busy_wait_kernel<<<1, 1, 0>>>(ac.get(), fence->count);
});
encoder.add_completed_handler([ac]() {});
encoder.end_encoding();
}
}
void Fence::update(Stream s, const array&) {
auto* fence = static_cast<FenceImpl*>(fence_.get());
fence->count++;
fence->event.signal(s, fence->count);
}
} // namespace mlx::core

mlx/backend/gpu/primitives.cpp

@@ -5,9 +5,17 @@
 #include "mlx/backend/gpu/copy.h"
 #include "mlx/backend/gpu/slicing.h"
 
+#if defined(MLX_USE_CUDA)
+#include <nvtx3/nvtx3.hpp>
+#endif
+
 #include <cassert>
 
+#if defined(MLX_USE_CUDA)
+#define MLX_PROFILER_RANGE(message) nvtx3::scoped_range r(message)
+#else
 #define MLX_PROFILER_RANGE(message)
+#endif
 
 namespace mlx::core {
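
With `MLX_USE_CUDA` defined by the CMake change above, `MLX_PROFILER_RANGE` expands to an nvtx3 scoped range; on other backends it compiles away to nothing. A sketch of how a primitive's eval might use it (hypothetical primitive; the macro is the real API):

  void MyPrimitive::eval_gpu(const std::vector<array>& inputs, array& out) {
    MLX_PROFILER_RANGE("MyPrimitive::eval_gpu"); // shows up in Nsight timelines
    // ... launch kernels ...
  }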

mlx/backend/metal/binary.cpp

@@ -31,13 +31,13 @@ std::string get_kernel_name(
       kname = "ss";
       break;
     case BinaryOpType::ScalarVector:
-      kname = (large ? "sv2" : "sv");
+      kname = "sv";
       break;
     case BinaryOpType::VectorScalar:
-      kname = (large ? "vs2" : "vs");
+      kname = "vs";
       break;
     case BinaryOpType::VectorVector:
-      kname = (large ? "vv2" : "vv");
+      kname = "vv";
       break;
     case BinaryOpType::General:
       kname = "g";
@@ -51,6 +51,13 @@ std::string get_kernel_name(
       }
       break;
   }
+  if (bopt != BinaryOpType::General && bopt != BinaryOpType::ScalarScalar) {
+    if (large) {
+      kname += "2";
+    } else if (work_per_thread > 1) {
+      kname += "n";
+    }
+  }
   concatenate(kname, "_", op, type_to_name(a));
   return kname;
 }
@@ -90,7 +97,7 @@ void binary_op_gpu_inplace(
     work_per_thread = large ? 4 : 2;
   } else {
     large = out.data_size() > UINT32_MAX;
-    work_per_thread = get_work_per_thread(a.dtype());
+    work_per_thread = get_work_per_thread(a.dtype(), out.data_size());
   }
   std::string kernel_name =
       get_kernel_name(bopt, op, a, large, shape.size(), work_per_thread);
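
The suffix logic above gives each non-general, non-scalar-scalar kernel one of three indexing variants: a trailing "2" selects 64-bit ("large") indexing, a trailing "n" selects the work-per-thread loop, and no suffix is the plain one-element-per-thread kernel. For example, with op `Add` on float32 inputs the VectorVector case resolves to `vv2_Addfloat32`, `vvn_Addfloat32`, or `vv_Addfloat32` respectively (naming per the `instantiate_kernel` macros later in this diff). A sketch of the resolution for that case:

  std::string kname = "vv";          // from the BinaryOpType switch
  if (large) {
    kname += "2";                    // 64-bit indexing
  } else if (work_per_thread > 1) {
    kname += "n";                    // work-per-thread loop
  }
  concatenate(kname, "_", op, type_to_name(a)); // e.g. "vvn_Addfloat32"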

mlx/backend/metal/compiled.cpp

@@ -11,8 +11,6 @@
 #include "mlx/primitives.h"
 #include "mlx/utils.h"
 
-using namespace fmt::literals;
-
 namespace mlx::core {
 
 inline void build_kernel(
@@ -21,21 +19,12 @@ inline void build_kernel(
     const std::vector<array>& inputs,
     const std::vector<array>& outputs,
     const std::vector<array>& tape,
-    const std::unordered_set<uintptr_t>& constant_ids,
+    const std::function<bool(size_t)>& is_constant,
     bool contiguous,
     int ndim,
     bool dynamic_dims,
     bool use_big_index = false,
     int work_per_thread = 1) {
-  // All outputs should have the exact same shape and will be row contiguous
-  auto output_shape = outputs[0].shape();
-  auto output_strides = outputs[0].strides();
-
-  // Constants are scalars that are captured by value and cannot change
-  auto is_constant = [&constant_ids](const array& x) {
-    return constant_ids.find(x.id()) != constant_ids.end();
-  };
-
   NodeNamer namer;
   bool add_indices = false;
   int cnt = 0;
@@ -45,14 +34,15 @@ inline void build_kernel(
       "[[host_name(\"{0}\")]]\n[[kernel]] void {0}(\n", kernel_name);
 
   // Add the input arguments
-  for (auto& x : inputs) {
-    auto& xname = namer.get_name(x);
-
+  for (size_t i = 0; i < inputs.size(); ++i) {
     // Skip constants from the input list
-    if (is_constant(x)) {
+    if (is_constant(i)) {
      continue;
     }
+    const auto& x = inputs[i];
+    auto& xname = namer.get_name(x);
     // Scalars and contiguous need no strides
     if (!is_scalar(x) && !contiguous) {
       add_indices = true;
@@ -80,8 +70,6 @@ inline void build_kernel(
   }
   // Add output strides and shape to extract the indices.
   if (!contiguous) {
-    os += fmt::format(
-        "  constant const int64_t* output_strides [[buffer({0})]],\n", cnt++);
     os += fmt::format(
         "  constant const int* output_shape [[buffer({0})]],\n", cnt++);
   } else {
@@ -125,7 +113,7 @@ inline void build_kernel(
     auto& x = inputs[i];
     auto& xname = namer.get_name(x);
-    if (is_constant(x)) {
+    if (is_constant(i)) {
       auto type_str = get_type_string(x.dtype());
       std::ostringstream ss;
       print_constant(ss, x);
@@ -271,11 +259,6 @@ inline void build_kernel(
 void Compiled::eval_gpu(
     const std::vector<array>& inputs,
     std::vector<array>& outputs) {
-  // Make the name for the kernel library
-  if (kernel_lib_.empty()) {
-    kernel_lib_ = build_lib_name(inputs_, outputs_, tape_, constant_ids_);
-  }
-
   // Get the kernel if someone else built it already
   auto& s = stream();
   auto& d = metal::device(s.device);
@@ -290,19 +273,33 @@ void Compiled::eval_gpu(
         inputs_,
         outputs_,
         tape_,
-        constant_ids_,
+        is_constant_,
         /* contiguous = */ true,
         /* ndim = */ 0,
         /* dynamic_dims = */ false,
         /* use_big_index = */ false,
-        /* work_per_thread = */ work_per_thread);
+        /* work_per_thread = */ 1);
+    if (work_per_thread > 1) {
+      build_kernel(
+          kernel,
+          kernel_lib_ + "_contiguous_n",
+          inputs_,
+          outputs_,
+          tape_,
+          is_constant_,
+          /* contiguous = */ true,
+          /* ndim = */ 0,
+          /* dynamic_dims = */ false,
+          /* use_big_index = */ false,
+          /* work_per_thread = */ work_per_thread);
+    }
     build_kernel(
         kernel,
         kernel_lib_ + "_contiguous_large",
         inputs_,
         outputs_,
         tape_,
-        constant_ids_,
+        is_constant_,
         /* contiguous = */ true,
         /* ndim = */ 0,
         /* dynamic_dims = */ false,
@@ -315,7 +312,7 @@ void Compiled::eval_gpu(
         inputs_,
         outputs_,
         tape_,
-        constant_ids_,
+        is_constant_,
         /* contiguous = */ false,
         /* ndim = */ i,
         /* dynamic_dims = */ false,
@@ -328,7 +325,7 @@ void Compiled::eval_gpu(
         inputs_,
         outputs_,
         tape_,
-        constant_ids_,
+        is_constant_,
         /* contiguous = */ false,
         /* ndim = */ i,
         /* dynamic_dims = */ false,
@@ -342,7 +339,7 @@ void Compiled::eval_gpu(
         inputs_,
         outputs_,
         tape_,
-        constant_ids_,
+        is_constant_,
         /* contiguous = */ false,
         /* ndim = */ 0,
         /* dynamic_dims = */ true,
@@ -354,7 +351,7 @@ void Compiled::eval_gpu(
         inputs_,
         outputs_,
         tape_,
-        constant_ids_,
+        is_constant_,
         /* contiguous = */ false,
         /* ndim = */ 0,
         /* dynamic_dims = */ true,
@@ -363,81 +360,32 @@ void Compiled::eval_gpu(
     return kernel;
   });
 
-  // Figure out which kernel we are using
-  auto& output_shape = outputs[0].shape();
-  auto contiguous = compiled_check_contiguity(inputs, output_shape);
-
   // Collapse contiguous dims to route to a faster kernel if possible. Also
   // handle all broadcasting.
-  std::vector<Strides> initial_strides;
-  initial_strides.push_back(outputs[0].strides());
-  Shape shape;
-  std::vector<Strides> strides;
-  if (!contiguous) {
-    for (int i = 0; i < inputs.size(); i++) {
-      // Skip constants.
-      if (constant_ids_.find(inputs_[i].id()) != constant_ids_.end()) {
-        continue;
-      }
-      auto& x = inputs[i];
-      // Skip scalar inputs.
-      if (is_scalar(x)) {
-        continue;
-      }
-      // Broadcast the inputs to the output shape.
-      Strides xstrides;
-      int j = 0;
-      for (; j < output_shape.size() - x.ndim(); j++) {
-        if (output_shape[j] == 1) {
-          xstrides.push_back(outputs[0].strides()[j]);
-        } else {
-          xstrides.push_back(0);
-        }
-      }
-      for (int i = 0; i < x.ndim(); i++, j++) {
-        if (x.shape(i) == 1) {
-          if (output_shape[j] == 1) {
-            xstrides.push_back(outputs[0].strides()[j]);
-          } else {
-            xstrides.push_back(0);
-          }
-        } else {
-          xstrides.push_back(x.strides()[i]);
-        }
-      }
-      initial_strides.push_back(std::move(xstrides));
-    }
-    std::tie(shape, strides) =
-        collapse_contiguous_dims(output_shape, initial_strides, INT32_MAX);
-  }
+  auto [contiguous, shape, strides] =
+      compiled_collapse_contiguous_dims(inputs, outputs[0], is_constant_);
 
-  bool large;
-  if (contiguous) {
-    size_t max_size = 0;
-    for (auto& in : inputs) {
-      max_size = std::max(max_size, in.data_size());
-    }
-    large = (max_size > UINT32_MAX);
-  } else {
-    size_t max_size = 0;
-    for (auto& o : outputs) {
-      max_size = std::max(max_size, o.size());
-    }
-    large = (max_size > UINT32_MAX);
-  }
+  // Whether to use large index.
+  bool large = compiled_use_large_index(inputs, outputs, contiguous);
 
   // Get the kernel from the lib
   int ndim = shape.size();
   bool dynamic = ndim >= 8;
   auto kernel_name = kernel_lib_ + (contiguous ? "_contiguous" : "_strided_");
+  int work_per_thread = 1;
   if (!contiguous) {
     if (dynamic) {
       kernel_name += "dynamic";
     } else {
       kernel_name += std::to_string(shape.size());
     }
+    work_per_thread = ndim > 3 ? (large ? 4 : 2) : 1;
+  } else {
+    work_per_thread =
+        get_work_per_thread(outputs[0].dtype(), outputs[0].data_size());
+    if (work_per_thread > 1 && !large) {
+      kernel_name += "_n";
+    }
   }
   if (large) {
     kernel_name += "_large";
@@ -451,7 +399,7 @@ void Compiled::eval_gpu(
   int stride_idx = 1; // idx 0 is the output strides
   Strides in_strides;
   for (int i = 0; i < inputs.size(); i++) {
-    if (constant_ids_.find(inputs_[i].id()) != constant_ids_.end()) {
+    if (is_constant_(i)) {
      continue;
     }
     auto& x = inputs[i];
@@ -468,8 +416,7 @@ void Compiled::eval_gpu(
     compute_encoder.set_vector_bytes(in_strides, cnt++);
   }
 
-  compiled_allocate_outputs(
-      inputs, outputs, inputs_, constant_ids_, contiguous);
+  compiled_allocate_outputs(inputs, outputs, is_constant_, contiguous);
 
   // Put the outputs in
   for (auto& x : outputs) {
@@ -478,7 +425,6 @@ void Compiled::eval_gpu(
 
   // Put the output shape and strides in
   if (!contiguous) {
-    compute_encoder.set_vector_bytes(strides[0], cnt++);
     compute_encoder.set_vector_bytes(shape, cnt++);
   } else {
     auto size = outputs[0].data_size();
@@ -496,7 +442,6 @@ void Compiled::eval_gpu(
 
   // Launch the kernel
   if (contiguous) {
-    int work_per_thread = get_work_per_thread(outputs[0].dtype());
     size_t nthreads = ceildiv(outputs[0].data_size(), work_per_thread);
     MTL::Size group_dims(
         std::min(nthreads, kernel->maxTotalThreadsPerThreadgroup()), 1, 1);
@@ -509,7 +454,6 @@ void Compiled::eval_gpu(
     size_t dim0 = ndim > 0 ? shape[ndim - 1] : 1;
     size_t dim1 = ndim > 1 ? shape[ndim - 2] : 1;
     size_t rest = outputs[0].size() / (dim0 * dim1);
-    int work_per_thread = ndim > 3 ? (large ? 4 : 2) : 1;
     dim0 = (dim0 + work_per_thread - 1) / work_per_thread;
     NS::UInteger thread_group_size = kernel->maxTotalThreadsPerThreadgroup();
     int pow2;

mlx/backend/metal/copy.cpp

@@ -55,10 +55,10 @@ void copy_gpu_inplace(
   std::string kernel_name;
   switch (ctype) {
     case CopyType::Scalar:
-      kernel_name = (large ? "s2" : "s");
+      kernel_name = large ? "s2" : "s";
       break;
     case CopyType::Vector:
-      kernel_name = (large ? "v2" : "v");
+      kernel_name = large ? "v2" : "v";
       break;
     case CopyType::General:
       kernel_name = "g";
@@ -85,7 +85,10 @@ void copy_gpu_inplace(
       }
     }
   } else {
-    work_per_thread = get_work_per_thread(in.dtype());
+    work_per_thread = get_work_per_thread(out.dtype(), out.data_size());
+    if (work_per_thread > 1) {
+      kernel_name += "n";
+    }
   }
   concatenate(kernel_name, "_copy", type_to_name(in), type_to_name(out));
   auto kernel = dynamic ? get_dynamic_copy_kernel(d, kernel_name, in, out)
@@ -170,9 +173,10 @@ void fill_gpu(const array& val, array& out, const Stream& s) {
   }
   out.set_data(allocator::malloc(out.nbytes()));
   bool large = out.data_size() > UINT32_MAX;
+  int work_per_thread = get_work_per_thread(out.dtype(), out.data_size());
   auto& d = metal::device(s.device);
-  std::string kernel_name = std::string(large ? "s2" : "s") + "_copy" +
-      type_to_name(val) + type_to_name(out);
+  std::string kernel_name = large ? "s2" : (work_per_thread > 1 ? "sn" : "s");
+  concatenate(kernel_name, "_copy", type_to_name(val), type_to_name(out));
   auto kernel = get_copy_kernel(d, kernel_name, val, out);
   auto& compute_encoder = d.get_command_encoder(s.index);
   compute_encoder.set_compute_pipeline_state(kernel);
@@ -180,7 +184,6 @@ void fill_gpu(const array& val, array& out, const Stream& s) {
   compute_encoder.set_input_array(val, 0);
   compute_encoder.set_output_array(out, 1);
 
-  int work_per_thread = get_work_per_thread(val.dtype());
   auto thread_group_size = kernel->maxTotalThreadsPerThreadgroup();
   size_t nthreads = ceildiv(out.data_size(), work_per_thread);
   if (thread_group_size > nthreads) {

mlx/backend/metal/jit_kernels.cpp

@@ -41,7 +41,11 @@ MTL::ComputePipelineState* get_unary_kernel(
   std::string kernel_source = metal::utils();
   concatenate(kernel_source, metal::unary_ops(), metal::unary());
   kernel_source +=
-      get_template_definition("v_" + lib_name, "unary_v", in_t, out_t, op);
+      get_template_definition("v_" + lib_name, "unary_v", in_t, out_t, op, 1);
+  if (get_work_per_thread(in_type) > 1) {
+    kernel_source +=
+        get_template_definition("vn_" + lib_name, "unary_v", in_t, out_t, op);
+  }
   kernel_source +=
       get_template_definition("v2_" + lib_name, "unary_v2", in_t, out_t, op);
   kernel_source += get_template_definition(
@@ -59,11 +63,8 @@ void append_binary_kernels(
     Dtype out_type,
     const std::string op,
     std::string& kernel_source) {
-  const std::array<std::pair<std::string, std::string>, 10> kernel_types = {{
+  const std::array<std::pair<std::string, std::string>, 7> kernel_types = {{
       {"ss", "binary_ss"},
-      {"vs", "binary_vs"},
-      {"sv", "binary_sv"},
-      {"vv", "binary_vv"},
       {"vs2", "binary_vs2"},
       {"sv2", "binary_sv2"},
       {"vv2", "binary_vv2"},
@@ -78,6 +79,22 @@ void append_binary_kernels(
     kernel_source +=
         get_template_definition(name + "_" + lib_name, func, in_t, out_t, op);
   }
+  kernel_source += get_template_definition(
+      "vs_" + lib_name, "binary_vs", in_t, out_t, op, 1);
+  kernel_source += get_template_definition(
+      "sv_" + lib_name, "binary_sv", in_t, out_t, op, 1);
+  kernel_source += get_template_definition(
+      "vv_" + lib_name, "binary_vv", in_t, out_t, op, 1);
+  if (get_work_per_thread(in_type) > 1) {
+    kernel_source += get_template_definition(
+        "vsn_" + lib_name, "binary_vs", in_t, out_t, op);
+    kernel_source += get_template_definition(
+        "svn_" + lib_name, "binary_sv", in_t, out_t, op);
+    kernel_source += get_template_definition(
+        "vvn_" + lib_name, "binary_vv", in_t, out_t, op);
+  }
   kernel_source += get_template_definition(
       "g1_" + lib_name, "binary_g_nd1", in_t, out_t, op, "int");
   kernel_source += get_template_definition(
@@ -133,8 +150,7 @@ MTL::ComputePipelineState* get_ternary_kernel(
   auto t_str = get_type_string(type);
   std::string kernel_source = metal::utils();
   concatenate(kernel_source, metal::ternary_ops(), metal::ternary());
-  const std::array<std::pair<std::string, std::string>, 5> kernel_types = {{
-      {"v", "ternary_v"},
+  const std::array<std::pair<std::string, std::string>, 4> kernel_types = {{
       {"v2", "ternary_v2"},
       {"g1large", "ternary_g_nd1"},
       {"g2large", "ternary_g_nd2"},
@@ -144,6 +160,13 @@ MTL::ComputePipelineState* get_ternary_kernel(
     kernel_source +=
         get_template_definition(name + "_" + lib_name, func, t_str, op);
   }
+  if (get_work_per_thread(type) > 1) {
+    kernel_source +=
+        get_template_definition("vn_" + lib_name, "ternary_v", t_str, op);
+  }
+  kernel_source +=
+      get_template_definition("v_" + lib_name, "ternary_v", t_str, op, 1);
   kernel_source += get_template_definition(
       "g1_" + lib_name, "ternary_g_nd1", t_str, op, "int");
   kernel_source += get_template_definition(
@@ -170,15 +193,22 @@ MTL::ComputePipelineState* get_copy_kernel(
   kernel_source += metal::copy();
   auto in_type = get_type_string(in.dtype());
   auto out_type = get_type_string(out.dtype());
-  kernel_source +=
-      get_template_definition("s_" + lib_name, "copy_s", in_type, out_type);
+  kernel_source += get_template_definition(
+      "s_" + lib_name, "copy_s", in_type, out_type, 1);
   kernel_source +=
       get_template_definition("s2_" + lib_name, "copy_s2", in_type, out_type);
-  kernel_source +=
-      get_template_definition("v_" + lib_name, "copy_v", in_type, out_type);
+  kernel_source += get_template_definition(
+      "v_" + lib_name, "copy_v", in_type, out_type, 1);
   kernel_source +=
       get_template_definition("v2_" + lib_name, "copy_v2", in_type, out_type);
+  if (get_work_per_thread(out.dtype()) > 1) {
+    kernel_source += get_template_definition(
+        "sn_" + lib_name, "copy_s", in_type, out_type);
+    kernel_source += get_template_definition(
+        "vn_" + lib_name, "copy_v", in_type, out_type);
+  }
   kernel_source += get_template_definition(
       "g1_" + lib_name, "copy_g_nd1", in_type, out_type, "int");
   kernel_source += get_template_definition(

mlx/backend/metal/kernels/binary.h

@@ -17,8 +17,14 @@ template <typename T, typename U, typename Op, int N = WorkPerThread<T>::n>
     constant uint& size,
     uint index [[thread_position_in_grid]]) {
   index *= N;
-  for (int i = 0; i < N && (index + i) < size; ++i) {
-    c[index + i] = Op()(a[0], b[index + i]);
+  if (N > 1 && index + N > size) {
+    for (int i = 0; index + i < size; ++i) {
+      c[index + i] = Op()(a[0], b[index + i]);
+    }
+  } else {
+    for (int i = 0; i < N; ++i) {
+      c[index + i] = Op()(a[0], b[index + i]);
+    }
   }
 }
 
@@ -30,8 +36,14 @@ template <typename T, typename U, typename Op, int N = WorkPerThread<T>::n>
     constant uint& size,
     uint index [[thread_position_in_grid]]) {
   index *= N;
-  for (int i = 0; i < N && (index + i) < size; ++i) {
-    c[index + i] = Op()(a[index + i], b[0]);
+  if (N > 1 && index + N > size) {
+    for (int i = 0; index + i < size; ++i) {
+      c[index + i] = Op()(a[index + i], b[0]);
+    }
+  } else {
+    for (int i = 0; i < N; ++i) {
+      c[index + i] = Op()(a[index + i], b[0]);
+    }
   }
 }
 
@@ -43,8 +55,14 @@ template <typename T, typename U, typename Op, int N = WorkPerThread<T>::n>
     constant uint& size,
     uint index [[thread_position_in_grid]]) {
   index *= N;
-  for (int i = 0; i < N && (index + i) < size; ++i) {
-    c[index + i] = Op()(a[index + i], b[index + i]);
+  if (N > 1 && index + N > size) {
+    for (int i = 0; index + i < size; ++i) {
+      c[index + i] = Op()(a[index + i], b[index + i]);
+    }
+  } else {
+    for (int i = 0; i < N; ++i) {
+      c[index + i] = Op()(a[index + i], b[index + i]);
+    }
   }
 }
 
@@ -57,8 +75,14 @@ template <typename T, typename U, typename Op, int N = WorkPerThread<T>::n>
     uint2 index [[thread_position_in_grid]],
     uint2 grid_dim [[threads_per_grid]]) {
   int64_t offset = N * (index.x + grid_dim.x * int64_t(index.y));
-  for (int i = 0; i < N && (offset + i) < size; ++i) {
-    c[offset + i] = Op()(a[0], b[offset + i]);
+  if (N > 1 && offset + N > size) {
+    for (int i = 0; offset + i < size; ++i) {
+      c[offset + i] = Op()(a[0], b[offset + i]);
+    }
+  } else {
+    for (int i = 0; i < N; ++i) {
+      c[offset + i] = Op()(a[0], b[offset + i]);
+    }
   }
 }
 
@@ -71,8 +95,14 @@ template <typename T, typename U, typename Op, int N = WorkPerThread<T>::n>
     uint2 index [[thread_position_in_grid]],
     uint2 grid_dim [[threads_per_grid]]) {
   int64_t offset = N * (index.x + grid_dim.x * int64_t(index.y));
-  for (int i = 0; i < N && (offset + i) < size; ++i) {
-    c[offset + i] = Op()(a[offset + i], b[0]);
+  if (N > 1 && offset + N > size) {
+    for (int i = 0; offset + i < size; ++i) {
+      c[offset + i] = Op()(a[offset + i], b[0]);
+    }
+  } else {
+    for (int i = 0; i < N; ++i) {
+      c[offset + i] = Op()(a[offset + i], b[0]);
+    }
   }
 }
 
@@ -85,8 +115,14 @@ template <typename T, typename U, typename Op, int N = WorkPerThread<T>::n>
     uint2 index [[thread_position_in_grid]],
     uint2 grid_dim [[threads_per_grid]]) {
   int64_t offset = N * (index.x + grid_dim.x * int64_t(index.y));
-  for (int i = 0; i < N && (offset + i) < size; ++i) {
-    c[offset + i] = Op()(a[offset + i], b[offset + i]);
+  if (N > 1 && offset + N > size) {
+    for (int i = 0; offset + i < size; ++i) {
+      c[offset + i] = Op()(a[offset + i], b[offset + i]);
+    }
+  } else {
+    for (int i = 0; i < N; ++i) {
+      c[offset + i] = Op()(a[offset + i], b[offset + i]);
+    }
   }
 }
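
The rewrite in this header hoists the bounds check out of the hot loop: only a thread whose window overruns `size` takes the element-guarded path, while every other thread runs a fixed-trip-count loop the Metal compiler can fully unroll (and `N == 1` statically selects the plain path). The same transform in miniature, using an addition in place of `Op`:

  // Before: the combined bound check runs every iteration and
  // defeats unrolling.
  for (int i = 0; i < N && (index + i) < size; ++i) {
    c[index + i] = a[index + i] + b[index + i];
  }

  // After: one test routes the partial tail to a guarded loop and
  // everything else to an unrollable fixed-count loop.
  if (N > 1 && index + N > size) {
    for (int i = 0; index + i < size; ++i) {
      c[index + i] = a[index + i] + b[index + i];
    }
  } else {
    for (int i = 0; i < N; ++i) {
      c[index + i] = a[index + i] + b[index + i];
    }
  }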

mlx/backend/metal/kernels/binary.metal

@@ -9,11 +9,16 @@
 #include "mlx/backend/metal/kernels/binary_ops.h"
 #include "mlx/backend/metal/kernels/binary.h"
 
-#define instantiate_binary_all(op, tname, itype, otype) \
+#define instantiate_binary_work_per_thread(op, tname, itype, otype) \
+  instantiate_kernel("svn_" #op #tname, binary_sv, itype, otype, op) \
+  instantiate_kernel("vsn_" #op #tname, binary_vs, itype, otype, op) \
+  instantiate_kernel("vvn_" #op #tname, binary_vv, itype, otype, op)
+
+#define instantiate_binary_base(op, tname, itype, otype) \
   instantiate_kernel("ss_" #op #tname, binary_ss, itype, otype, op) \
-  instantiate_kernel("sv_" #op #tname, binary_sv, itype, otype, op) \
-  instantiate_kernel("vs_" #op #tname, binary_vs, itype, otype, op) \
-  instantiate_kernel("vv_" #op #tname, binary_vv, itype, otype, op) \
+  instantiate_kernel("sv_" #op #tname, binary_sv, itype, otype, op, 1) \
+  instantiate_kernel("vs_" #op #tname, binary_vs, itype, otype, op, 1) \
+  instantiate_kernel("vv_" #op #tname, binary_vv, itype, otype, op, 1) \
   instantiate_kernel("sv2_" #op #tname, binary_sv2, itype, otype, op) \
   instantiate_kernel("vs2_" #op #tname, binary_vs2, itype, otype, op) \
   instantiate_kernel("vv2_" #op #tname, binary_vv2, itype, otype, op) \
@@ -26,15 +31,19 @@
   instantiate_kernel("g3_" #op #tname, binary_g_nd3, itype, otype, op, int) \
   instantiate_kernel("g3large_" #op #tname, binary_g_nd3, itype, otype, op)
 
-#define instantiate_binary_integer(op) \
-  instantiate_binary_all(op, uint8, uint8_t, uint8_t) \
-  instantiate_binary_all(op, uint16, uint16_t, uint16_t) \
-  instantiate_binary_all(op, uint32, uint32_t, uint32_t) \
-  instantiate_binary_all(op, uint64, uint64_t, uint64_t) \
-  instantiate_binary_all(op, int8, int8_t, int8_t) \
-  instantiate_binary_all(op, int16, int16_t, int16_t) \
-  instantiate_binary_all(op, int32, int32_t, int32_t) \
-  instantiate_binary_all(op, int64, int64_t, int64_t)
+#define instantiate_binary_all(op, tname, itype, otype) \
+  instantiate_binary_base(op, tname, itype, otype) \
+  instantiate_binary_work_per_thread(op, tname, itype, otype)
+
+#define instantiate_binary_integer(op) \
+  instantiate_binary_all(op, uint8, uint8_t, uint8_t) \
+  instantiate_binary_all(op, uint16, uint16_t, uint16_t) \
+  instantiate_binary_all(op, uint32, uint32_t, uint32_t) \
+  instantiate_binary_base(op, uint64, uint64_t, uint64_t) \
+  instantiate_binary_all(op, int8, int8_t, int8_t) \
+  instantiate_binary_all(op, int16, int16_t, int16_t) \
+  instantiate_binary_all(op, int32, int32_t, int32_t) \
+  instantiate_binary_base(op, int64, int64_t, int64_t)
 
 #define instantiate_binary_float(op) \
   instantiate_binary_all(op, float16, half, half) \
@@ -44,7 +53,7 @@
 #define instantiate_binary_types(op) \
   instantiate_binary_all(op, bool_, bool, bool) \
   instantiate_binary_integer(op) \
-  instantiate_binary_all(op, complex64, complex64_t, complex64_t) \
+  instantiate_binary_base(op, complex64, complex64_t, complex64_t) \
   instantiate_binary_float(op)
 
@@ -52,15 +61,15 @@
   instantiate_binary_all(op, uint8, uint8_t, bool) \
   instantiate_binary_all(op, uint16, uint16_t, bool) \
   instantiate_binary_all(op, uint32, uint32_t, bool) \
-  instantiate_binary_all(op, uint64, uint64_t, bool) \
+  instantiate_binary_base(op, uint64, uint64_t, bool) \
   instantiate_binary_all(op, int8, int8_t, bool) \
   instantiate_binary_all(op, int16, int16_t, bool) \
   instantiate_binary_all(op, int32, int32_t, bool) \
-  instantiate_binary_all(op, int64, int64_t, bool) \
+  instantiate_binary_base(op, int64, int64_t, bool) \
   instantiate_binary_all(op, float16, half, bool) \
   instantiate_binary_all(op, float32, float, bool) \
   instantiate_binary_all(op, bfloat16, bfloat16_t, bool) \
-  instantiate_binary_all(op, complex64, complex64_t, bool)
+  instantiate_binary_base(op, complex64, complex64_t, bool)
 
 instantiate_binary_types(Add)
 instantiate_binary_types(Divide)
@@ -71,7 +80,7 @@ instantiate_binary_types_bool(Less)
 instantiate_binary_types_bool(LessEqual)
 instantiate_binary_types_bool(NotEqual)
 instantiate_binary_float(LogAddExp)
-instantiate_binary_all(LogAddExp, complex64, complex64_t, complex64_t)
+instantiate_binary_base(LogAddExp, complex64, complex64_t, complex64_t)
 instantiate_binary_types(Maximum)
 instantiate_binary_types(Minimum)
 instantiate_binary_types(Multiply)
@@ -84,7 +93,7 @@ instantiate_binary_float(ArcTan2)
 instantiate_binary_all(NaNEqual, float16, half, bool)
 instantiate_binary_all(NaNEqual, float32, float, bool)
 instantiate_binary_all(NaNEqual, bfloat16, bfloat16_t, bool)
-instantiate_binary_all(NaNEqual, complex64, complex64_t, bool)
+instantiate_binary_base(NaNEqual, complex64, complex64_t, bool)
 
 instantiate_binary_all(LogicalOr, bool_, bool, bool)
 instantiate_binary_all(LogicalAnd, bool_, bool, bool)

mlx/backend/metal/kernels/binary_two.h

@@ -21,10 +21,18 @@ template <typename T, typename U, typename Op, int N = WorkPerThread<T>::n>
     constant uint& size,
     uint index [[thread_position_in_grid]]) {
   index *= N;
-  for (int i = 0; i < N && (index + i) < size; ++i) {
-    auto out = Op()(a[0], b[index + i]);
-    c[index + i] = out[0];
-    d[index + i] = out[1];
+  if (N > 1 && index + N > size) {
+    for (int i = 0; index + i < size; ++i) {
+      auto out = Op()(a[0], b[index + i]);
+      c[index + i] = out[0];
+      d[index + i] = out[1];
+    }
+  } else {
+    for (int i = 0; i < N; ++i) {
+      auto out = Op()(a[0], b[index + i]);
+      c[index + i] = out[0];
+      d[index + i] = out[1];
+    }
   }
 }
 
@@ -37,10 +45,18 @@ template <typename T, typename U, typename Op, int N = WorkPerThread<T>::n>
     constant uint& size,
     uint index [[thread_position_in_grid]]) {
   index *= N;
-  for (int i = 0; i < N && (index + i) < size; ++i) {
-    auto out = Op()(a[index + i], b[0]);
-    c[index + i] = out[0];
-    d[index + i] = out[1];
+  if (N > 1 && index + N > size) {
+    for (int i = 0; index + i < size; ++i) {
+      auto out = Op()(a[index + i], b[0]);
+      c[index + i] = out[0];
+      d[index + i] = out[1];
+    }
+  } else {
+    for (int i = 0; i < N; ++i) {
+      auto out = Op()(a[index + i], b[0]);
+      c[index + i] = out[0];
+      d[index + i] = out[1];
+    }
   }
 }
 
@@ -53,10 +69,18 @@ template <typename T, typename U, typename Op, int N = WorkPerThread<T>::n>
     constant uint& size,
     uint index [[thread_position_in_grid]]) {
   index *= N;
-  for (int i = 0; i < N && (index + i) < size; ++i) {
-    auto out = Op()(a[index + i], b[index + i]);
-    c[index + i] = out[0];
-    d[index + i] = out[1];
+  if (N > 1 && index + N > size) {
+    for (int i = 0; index + i < size; ++i) {
+      auto out = Op()(a[index + i], b[index + i]);
+      c[index + i] = out[0];
+      d[index + i] = out[1];
+    }
+  } else {
+    for (int i = 0; i < N; ++i) {
+      auto out = Op()(a[index + i], b[index + i]);
+      c[index + i] = out[0];
+      d[index + i] = out[1];
+    }
   }
 }
 
@@ -69,11 +93,19 @@ template <typename T, typename U, typename Op, int N = WorkPerThread<T>::n>
     constant int64_t& size,
     uint2 index [[thread_position_in_grid]],
     uint2 grid_dim [[threads_per_grid]]) {
-  auto offset = N * (index.x + grid_dim.x * int64_t(index.y));
-  for (int i = 0; i < N && (offset + i) < size; ++i) {
-    auto out = Op()(a[0], b[offset + i]);
-    c[offset + i] = out[0];
-    d[offset + i] = out[1];
+  int64_t offset = N * (index.x + grid_dim.x * int64_t(index.y));
+  if (N > 1 && offset + N > size) {
+    for (int i = 0; offset + i < size; ++i) {
+      auto out = Op()(a[0], b[offset + i]);
+      c[offset + i] = out[0];
+      d[offset + i] = out[1];
+    }
+  } else {
+    for (int i = 0; i < N; ++i) {
+      auto out = Op()(a[0], b[offset + i]);
+      c[offset + i] = out[0];
+      d[offset + i] = out[1];
+    }
   }
 }
 
@@ -86,11 +118,19 @@ template <typename T, typename U, typename Op, int N = WorkPerThread<T>::n>
     constant int64_t& size,
     uint2 index [[thread_position_in_grid]],
     uint2 grid_dim [[threads_per_grid]]) {
-  auto offset = N * (index.x + grid_dim.x * int64_t(index.y));
-  for (int i = 0; i < N && (offset + i) < size; ++i) {
-    auto out = Op()(a[offset + i], b[0]);
-    c[offset + i] = out[0];
-    d[offset + i] = out[1];
+  int64_t offset = N * (index.x + grid_dim.x * int64_t(index.y));
+  if (N > 1 && offset + N > size) {
+    for (int i = 0; offset + i < size; ++i) {
+      auto out = Op()(a[offset + i], b[0]);
+      c[offset + i] = out[0];
+      d[offset + i] = out[1];
+    }
+  } else {
+    for (int i = 0; i < N; ++i) {
+      auto out = Op()(a[offset + i], b[0]);
+      c[offset + i] = out[0];
+      d[offset + i] = out[1];
+    }
   }
 }
 
@@ -103,11 +143,19 @@ template <typename T, typename U, typename Op, int N = WorkPerThread<T>::n>
     constant int64_t& size,
     uint2 index [[thread_position_in_grid]],
     uint2 grid_dim [[threads_per_grid]]) {
-  auto offset = N * (index.x + grid_dim.x * int64_t(index.y));
-  for (int i = 0; i < N && (offset + i) < size; ++i) {
-    auto out = Op()(a[offset + i], b[offset + i]);
-    c[offset + i] = out[0];
-    d[offset + i] = out[1];
+  int64_t offset = N * (index.x + grid_dim.x * int64_t(index.y));
+  if (N > 1 && offset + N > size) {
+    for (int i = 0; offset + i < size; ++i) {
+      auto out = Op()(a[offset + i], b[offset + i]);
+      c[offset + i] = out[0];
+      d[offset + i] = out[1];
+    }
+  } else {
+    for (int i = 0; i < N; ++i) {
+      auto out = Op()(a[offset + i], b[offset + i]);
+      c[offset + i] = out[0];
+      d[offset + i] = out[1];
+    }
   }
 }
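Every kernel in this file gets the same treatment: the single loop whose condition fused the unroll bound and the bounds check (`i < N && index + i < size`) is split into an explicit tail case and a full-tile fast path. Only threads touching the end of the buffer take the checked loop, so the common case becomes a fixed-trip-count loop the compiler can unroll without per-element compares. A minimal CPU-side sketch of the control flow, with hypothetical names (`tile_apply`, `base`) and no claim to match the real kernels beyond the shape of the logic:

#include <cstddef>

// Sketch only: N elements per thread, with the bounds check hoisted out of
// the hot loop. `base` plays the role of the kernel's scaled thread index.
template <typename T, typename Op, int N>
void tile_apply(const T* a, const T* b, T* c, size_t base, size_t size) {
  if (N > 1 && base + N > size) {
    // Tail: only the final, possibly partial tile pays a per-element check.
    for (size_t i = 0; base + i < size; ++i) {
      c[base + i] = Op{}(a[base + i], b[base + i]);
    }
  } else {
    // Fast path: the trip count is the compile-time constant N, so the loop
    // can be unrolled and vectorized with no compares against `size`.
    for (int i = 0; i < N; ++i) {
      c[base + i] = Op{}(a[base + i], b[base + i]);
    }
  }
}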

View File

@@ -7,11 +7,16 @@
 #include "mlx/backend/metal/kernels/binary_ops.h"
 #include "mlx/backend/metal/kernels/binary_two.h"
 
-#define instantiate_binary_all(op, tname, itype, otype) \
+#define instantiate_binary_work_per_thread(op, tname, itype, otype) \
+  instantiate_kernel("svn_" #op #tname, binary_sv, itype, otype, op) \
+  instantiate_kernel("vsn_" #op #tname, binary_vs, itype, otype, op) \
+  instantiate_kernel("vvn_" #op #tname, binary_vv, itype, otype, op)
+
+#define instantiate_binary_base(op, tname, itype, otype) \
   instantiate_kernel("ss_" #op #tname, binary_ss, itype, otype, op) \
-  instantiate_kernel("sv_" #op #tname, binary_sv, itype, otype, op) \
-  instantiate_kernel("vs_" #op #tname, binary_vs, itype, otype, op) \
-  instantiate_kernel("vv_" #op #tname, binary_vv, itype, otype, op) \
+  instantiate_kernel("sv_" #op #tname, binary_sv, itype, otype, op, 1) \
+  instantiate_kernel("vs_" #op #tname, binary_vs, itype, otype, op, 1) \
+  instantiate_kernel("vv_" #op #tname, binary_vv, itype, otype, op, 1) \
   instantiate_kernel("sv2_" #op #tname, binary_sv2, itype, otype, op) \
   instantiate_kernel("vs2_" #op #tname, binary_vs2, itype, otype, op) \
   instantiate_kernel("vv2_" #op #tname, binary_vv2, itype, otype, op) \
@@ -24,22 +29,26 @@
   instantiate_kernel("g2large_" #op #tname, binary_g_nd2, itype, otype, op) \
   instantiate_kernel("g3large_" #op #tname, binary_g_nd3, itype, otype, op)
 
+#define instantiate_binary_all(op, tname, itype, otype) \
+  instantiate_binary_base(op, tname, itype, otype) \
+  instantiate_binary_work_per_thread(op, tname, itype, otype)
+
 #define instantiate_binary_float(op) \
   instantiate_binary_all(op, float16, half, half) \
   instantiate_binary_all(op, float32, float, float) \
   instantiate_binary_all(op, bfloat16, bfloat16_t, bfloat16_t)
 
 #define instantiate_binary_types(op) \
   instantiate_binary_all(op, bool_, bool, bool) \
   instantiate_binary_all(op, uint8, uint8_t, uint8_t) \
   instantiate_binary_all(op, uint16, uint16_t, uint16_t) \
   instantiate_binary_all(op, uint32, uint32_t, uint32_t) \
-  instantiate_binary_all(op, uint64, uint64_t, uint64_t) \
+  instantiate_binary_base(op, uint64, uint64_t, uint64_t) \
   instantiate_binary_all(op, int8, int8_t, int8_t) \
   instantiate_binary_all(op, int16, int16_t, int16_t) \
   instantiate_binary_all(op, int32, int32_t, int32_t) \
-  instantiate_binary_all(op, int64, int64_t, int64_t) \
-  instantiate_binary_all(op, complex64, complex64_t, complex64_t) \
+  instantiate_binary_base(op, int64, int64_t, int64_t) \
+  instantiate_binary_base(op, complex64, complex64_t, complex64_t) \
   instantiate_binary_float(op)
 
 instantiate_binary_types(DivMod) // clang-format on

View File

@@ -1,52 +1,76 @@
 // Copyright © 2024 Apple Inc.
 
-template <typename T, typename U, int N = WorkPerThread<T>::n>
+template <typename T, typename U, int N = WorkPerThread<U>::n>
 [[kernel]] void copy_s(
     device const T* src [[buffer(0)]],
     device U* dst [[buffer(1)]],
     constant uint& size,
     uint index [[thread_position_in_grid]]) {
   index *= N;
-  for (int i = 0; i < N && (index + i) < size; ++i) {
-    dst[index + i] = static_cast<U>(src[0]);
+  if (N > 1 && index + N > size) {
+    for (int i = 0; index + i < size; ++i) {
+      dst[index + i] = static_cast<U>(src[0]);
+    }
+  } else {
+    for (int i = 0; i < N; ++i) {
+      dst[index + i] = static_cast<U>(src[0]);
+    }
   }
 }
 
-template <typename T, typename U, int N = WorkPerThread<T>::n>
+template <typename T, typename U, int N = WorkPerThread<U>::n>
 [[kernel]] void copy_v(
     device const T* src [[buffer(0)]],
     device U* dst [[buffer(1)]],
     constant uint& size,
     uint index [[thread_position_in_grid]]) {
   index *= N;
-  for (int i = 0; i < N && (index + i) < size; ++i) {
-    dst[index + i] = static_cast<U>(src[index + i]);
+  if (N > 1 && index + N > size) {
+    for (int i = 0; index + i < size; ++i) {
+      dst[index + i] = static_cast<U>(src[index + i]);
+    }
+  } else {
+    for (int i = 0; i < N; ++i) {
+      dst[index + i] = static_cast<U>(src[index + i]);
+    }
   }
 }
 
-template <typename T, typename U, int N = WorkPerThread<T>::n>
+template <typename T, typename U, int N = WorkPerThread<U>::n>
 [[kernel]] void copy_s2(
     device const T* src [[buffer(0)]],
     device U* dst [[buffer(1)]],
     constant int64_t& size,
     uint2 index [[thread_position_in_grid]],
     uint2 grid_dim [[threads_per_grid]]) {
-  auto offset = N * (index.x + grid_dim.x * int64_t(index.y));
-  for (int i = 0; i < N && (offset + i) < size; ++i) {
-    dst[offset + i] = static_cast<U>(src[0]);
+  int64_t offset = N * (index.x + grid_dim.x * int64_t(index.y));
+  if (N > 1 && offset + N > size) {
+    for (int i = 0; offset + i < size; ++i) {
+      dst[offset + i] = static_cast<U>(src[0]);
+    }
+  } else {
+    for (int i = 0; i < N; ++i) {
+      dst[offset + i] = static_cast<U>(src[0]);
+    }
   }
 }
 
-template <typename T, typename U, int N = WorkPerThread<T>::n>
+template <typename T, typename U, int N = WorkPerThread<U>::n>
 [[kernel]] void copy_v2(
     device const T* src [[buffer(0)]],
     device U* dst [[buffer(1)]],
     constant int64_t& size,
     uint2 index [[thread_position_in_grid]],
     uint2 grid_dim [[threads_per_grid]]) {
-  auto offset = N * (index.x + grid_dim.x * int64_t(index.y));
-  for (int i = 0; i < N && (offset + i) < size; ++i) {
-    dst[offset + i] = static_cast<U>(src[offset + i]);
+  int64_t offset = N * (index.x + grid_dim.x * int64_t(index.y));
+  if (N > 1 && offset + N > size) {
+    for (int i = 0; offset + i < size; ++i) {
+      dst[offset + i] = static_cast<U>(src[offset + i]);
+    }
+  } else {
+    for (int i = 0; i < N; ++i) {
+      dst[offset + i] = static_cast<U>(src[offset + i]);
+    }
   }
 }
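A detail worth flagging in `copy.h` beyond the tail-splitting: the default width changes from `WorkPerThread<T>::n` to `WorkPerThread<U>::n`, so `N` is now derived from the destination element type rather than the source's. For casting copies `T` and `U` differ; keying the device-side width off the output plausibly keeps it consistent with whatever the host computes when sizing the grid, though that rationale is an inference from the diff rather than anything stated in it.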

View File

@@ -4,9 +4,13 @@
 #include "mlx/backend/metal/kernels/utils.h"
 #include "mlx/backend/metal/kernels/copy.h"
 
-#define instantiate_copy_all(tname, itype, otype) \
-  instantiate_kernel("s_copy" #tname, copy_s, itype, otype) \
-  instantiate_kernel("v_copy" #tname, copy_v, itype, otype) \
+#define instantiate_copy_work_per_thread(tname, itype, otype) \
+  instantiate_kernel("sn_copy" #tname, copy_s, itype, otype) \
+  instantiate_kernel("vn_copy" #tname, copy_v, itype, otype)
+
+#define instantiate_copy_base(tname, itype, otype) \
+  instantiate_kernel("s_copy" #tname, copy_s, itype, otype, 1) \
+  instantiate_kernel("v_copy" #tname, copy_v, itype, otype, 1) \
   instantiate_kernel("s2_copy" #tname, copy_s2, itype, otype) \
   instantiate_kernel("v2_copy" #tname, copy_v2, itype, otype) \
   instantiate_kernel("g1_copy" #tname, copy_g_nd1, itype, otype, int) \
@@ -18,6 +22,10 @@
   instantiate_kernel("g3large_copy" #tname, copy_g_nd3, itype, otype) \
   instantiate_kernel("gn4large_copy" #tname, copy_g, itype, otype, 4)
 
+#define instantiate_copy_all(tname, itype, otype) \
+  instantiate_copy_base(tname, itype, otype) \
+  instantiate_copy_work_per_thread(tname, itype, otype)
+
 #define instantiate_copy_same(tname, type) \
   instantiate_kernel("gg1_copy" #tname, copy_gg_nd1, type, type, int) \
   instantiate_kernel("gg2_copy" #tname, copy_gg_nd2, type, type, int) \
@@ -42,15 +50,15 @@
   instantiate_copy_all(itname ##uint8, itype, uint8_t) \
   instantiate_copy_all(itname ##uint16, itype, uint16_t) \
   instantiate_copy_all(itname ##uint32, itype, uint32_t) \
-  instantiate_copy_all(itname ##uint64, itype, uint64_t) \
+  instantiate_copy_base(itname ##uint64, itype, uint64_t) \
   instantiate_copy_all(itname ##int8, itype, int8_t) \
   instantiate_copy_all(itname ##int16, itype, int16_t) \
   instantiate_copy_all(itname ##int32, itype, int32_t) \
-  instantiate_copy_all(itname ##int64, itype, int64_t) \
+  instantiate_copy_base(itname ##int64, itype, int64_t) \
   instantiate_copy_all(itname ##float16, itype, half) \
   instantiate_copy_all(itname ##float32, itype, float) \
   instantiate_copy_all(itname ##bfloat16, itype, bfloat16_t) \
-  instantiate_copy_all(itname ##complex64, itype, complex64_t)
+  instantiate_copy_base(itname ##complex64, itype, complex64_t)
 
 instantiate_copy_itype(bool_, bool)
 instantiate_copy_itype(uint8, uint8_t)

View File

@@ -9,8 +9,14 @@ template <typename T, typename Op, int N = WorkPerThread<T>::n>
     constant uint& size,
     uint index [[thread_position_in_grid]]) {
   index *= N;
-  for (int i = 0; i < N && (index + i) < size; ++i) {
-    d[index + i] = Op()(a[index + i], b[index + i], c[index + i]);
+  if (N > 1 && index + N > size) {
+    for (int i = 0; index + i < size; ++i) {
+      d[index + i] = Op()(a[index + i], b[index + i], c[index + i]);
+    }
+  } else {
+    for (int i = 0; i < N; ++i) {
+      d[index + i] = Op()(a[index + i], b[index + i], c[index + i]);
+    }
   }
 }
 
@@ -23,9 +29,15 @@ template <typename T, typename Op, int N = WorkPerThread<T>::n>
     constant int64_t& size,
     uint2 index [[thread_position_in_grid]],
     uint2 grid_dim [[threads_per_grid]]) {
-  auto offset = N * (index.x + grid_dim.x * int64_t(index.y));
-  for (int i = 0; i < N && (offset + i) < size; ++i) {
-    d[offset + i] = Op()(a[offset + i], b[offset + i], c[offset + i]);
+  int64_t offset = N * (index.x + grid_dim.x * int64_t(index.y));
+  if (N > 1 && offset + N > size) {
+    for (int i = 0; offset + i < size; ++i) {
+      d[offset + i] = Op()(a[offset + i], b[offset + i], c[offset + i]);
+    }
+  } else {
+    for (int i = 0; i < N; ++i) {
+      d[offset + i] = Op()(a[offset + i], b[offset + i], c[offset + i]);
+    }
   }
 }

View File

@@ -8,8 +8,8 @@
 #include "mlx/backend/metal/kernels/ternary_ops.h"
 #include "mlx/backend/metal/kernels/ternary.h"
 
-#define instantiate_ternary_all(op, tname, type) \
-  instantiate_kernel("v_" #op #tname, ternary_v, type, op) \
+#define instantiate_ternary_base(op, tname, type) \
+  instantiate_kernel("v_" #op #tname, ternary_v, type, op, 1) \
   instantiate_kernel("v2_" #op #tname, ternary_v2, type, op) \
   instantiate_kernel("gn2_" #op #tname, ternary_g, type, op, 2, int) \
   instantiate_kernel("g1_" #op #tname, ternary_g_nd1, type, op, int) \
@@ -20,19 +20,23 @@
   instantiate_kernel("g3large_" #op #tname, ternary_g_nd3, type, op) \
   instantiate_kernel("gn4large_" #op #tname, ternary_g, type, op, 4) \
 
+#define instantiate_ternary_all(op, tname, type) \
+  instantiate_kernel("vn_" #op #tname, ternary_v, type, op) \
+  instantiate_ternary_base(op, tname, type)
+
 #define instantiate_ternary_types(op) \
   instantiate_ternary_all(op, bool_, bool) \
   instantiate_ternary_all(op, uint8, uint8_t) \
   instantiate_ternary_all(op, uint16, uint16_t) \
   instantiate_ternary_all(op, uint32, uint32_t) \
-  instantiate_ternary_all(op, uint64, uint64_t) \
+  instantiate_ternary_base(op, uint64, uint64_t) \
   instantiate_ternary_all(op, int8, int8_t) \
   instantiate_ternary_all(op, int16, int16_t) \
   instantiate_ternary_all(op, int32, int32_t) \
-  instantiate_ternary_all(op, int64, int64_t) \
+  instantiate_ternary_base(op, int64, int64_t) \
   instantiate_ternary_all(op, float16, half) \
   instantiate_ternary_all(op, float32, float) \
   instantiate_ternary_all(op, bfloat16, bfloat16_t) \
-  instantiate_ternary_all(op, complex64, complex64_t) // clang-format on
+  instantiate_ternary_base(op, complex64, complex64_t) // clang-format on
 
 instantiate_ternary_types(Select)

View File

@@ -7,8 +7,14 @@ template <typename T, typename U, typename Op, int N = WorkPerThread<T>::n>
     constant uint& size,
     uint index [[thread_position_in_grid]]) {
   index *= N;
-  for (int i = 0; i < N && (index + i) < size; ++i) {
-    out[index + i] = Op()(in[index + i]);
+  if (N > 1 && index + N > size) {
+    for (int i = 0; index + i < size; ++i) {
+      out[index + i] = Op()(in[index + i]);
+    }
+  } else {
+    for (int i = 0; i < N; ++i) {
+      out[index + i] = Op()(in[index + i]);
+    }
   }
 }
 
@@ -19,9 +25,15 @@ template <typename T, typename U, typename Op, int N = WorkPerThread<T>::n>
    constant int64_t& size,
     uint2 index [[thread_position_in_grid]],
     uint2 grid_dim [[threads_per_grid]]) {
-  auto offset = N * (index.x + grid_dim.x * int64_t(index.y));
-  for (int i = 0; i < N && (offset + i) < size; ++i) {
-    out[offset + i] = Op()(in[offset + i]);
+  int64_t offset = N * (index.x + grid_dim.x * int64_t(index.y));
+  if (N > 1 && offset + N > size) {
+    for (int i = 0; offset + i < size; ++i) {
+      out[offset + i] = Op()(in[offset + i]);
+    }
+  } else {
+    for (int i = 0; i < N; ++i) {
+      out[offset + i] = Op()(in[offset + i]);
+    }
   }
 }

View File

@@ -5,31 +5,41 @@
 #include "mlx/backend/metal/kernels/unary_ops.h"
 #include "mlx/backend/metal/kernels/unary.h"
 
-#define instantiate_unary_all(op, in_tname, out_tname, in_type, out_type) \
-  instantiate_kernel("v_" #op #in_tname #out_tname, unary_v, in_type, out_type, op) \
+#define instantiate_unary_work_per_thread(op, in_tname, out_tname, in_type, out_type) \
+  instantiate_kernel("vn_" #op #in_tname #out_tname, unary_v, in_type, out_type, op)
+
+#define instantiate_unary_base(op, in_tname, out_tname, in_type, out_type) \
+  instantiate_kernel("v_" #op #in_tname #out_tname, unary_v, in_type, out_type, op, 1) \
   instantiate_kernel("v2_" #op #in_tname #out_tname, unary_v2, in_type, out_type, op) \
   instantiate_kernel( \
       "gn1_" #op #in_tname #out_tname, unary_g, in_type, out_type, op, 1, int) \
   instantiate_kernel( \
       "gn4large_" #op #in_tname #out_tname, unary_g, in_type, out_type, op, 4)
 
+#define instantiate_unary_all(op, in_tname, out_tname, in_type, out_type) \
+  instantiate_unary_base(op, in_tname, out_tname, in_type, out_type) \
+  instantiate_unary_work_per_thread(op, in_tname, out_tname, in_type, out_type)
+
 #define instantiate_unary_all_same(op, tname, type) \
   instantiate_unary_all(op, tname, tname, type, type)
 
+#define instantiate_unary_base_same(op, tname, type) \
+  instantiate_unary_base(op, tname, tname, type, type)
+
 #define instantiate_unary_float(op) \
   instantiate_unary_all_same(op, float16, half) \
   instantiate_unary_all_same(op, float32, float) \
   instantiate_unary_all_same(op, bfloat16, bfloat16_t)
 
 #define instantiate_unary_int(op) \
   instantiate_unary_all_same(op, uint8, uint8_t) \
   instantiate_unary_all_same(op, uint16, uint16_t) \
   instantiate_unary_all_same(op, uint32, uint32_t) \
-  instantiate_unary_all_same(op, uint64, uint64_t) \
+  instantiate_unary_base_same(op, uint64, uint64_t) \
   instantiate_unary_all_same(op, int8, int8_t) \
   instantiate_unary_all_same(op, int16, int16_t) \
   instantiate_unary_all_same(op, int32, int32_t) \
-  instantiate_unary_all_same(op, int64, int64_t)
+  instantiate_unary_base_same(op, int64, int64_t)
 
 #define instantiate_unary_types(op) \
   instantiate_unary_all_same(op, bool_, bool) \
@@ -68,29 +78,29 @@ instantiate_unary_float(Tanh)
 instantiate_unary_float(Round)
 instantiate_unary_int(BitwiseInvert)
 
-instantiate_unary_all_same(Abs, complex64, complex64_t)
-instantiate_unary_all_same(ArcCos, complex64, complex64_t)
-instantiate_unary_all_same(ArcSin, complex64, complex64_t)
-instantiate_unary_all_same(ArcTan, complex64, complex64_t)
-instantiate_unary_all_same(Conjugate, complex64, complex64_t)
-instantiate_unary_all_same(Cos, complex64, complex64_t)
-instantiate_unary_all_same(Cosh, complex64, complex64_t)
-instantiate_unary_all_same(Exp, complex64, complex64_t)
-instantiate_unary_all_same(Log, complex64, complex64_t)
-instantiate_unary_all_same(Log1p, complex64, complex64_t)
-instantiate_unary_all_same(Log2, complex64, complex64_t)
-instantiate_unary_all_same(Log10, complex64, complex64_t)
-instantiate_unary_all_same(Negative, complex64, complex64_t)
-instantiate_unary_all_same(Sign, complex64, complex64_t)
-instantiate_unary_all_same(Sin, complex64, complex64_t)
-instantiate_unary_all_same(Sinh, complex64, complex64_t)
-instantiate_unary_all_same(Square, complex64, complex64_t)
-instantiate_unary_all_same(Sqrt, complex64, complex64_t)
-instantiate_unary_all_same(Rsqrt, complex64, complex64_t)
-instantiate_unary_all_same(Tan, complex64, complex64_t)
-instantiate_unary_all_same(Tanh, complex64, complex64_t)
-instantiate_unary_all_same(Round, complex64, complex64_t)
-instantiate_unary_all(Real, complex64, float32, complex64_t, float)
-instantiate_unary_all(Imag, complex64, float32, complex64_t, float)
+instantiate_unary_base_same(Abs, complex64, complex64_t)
+instantiate_unary_base_same(ArcCos, complex64, complex64_t)
+instantiate_unary_base_same(ArcSin, complex64, complex64_t)
+instantiate_unary_base_same(ArcTan, complex64, complex64_t)
+instantiate_unary_base_same(Conjugate, complex64, complex64_t)
+instantiate_unary_base_same(Cos, complex64, complex64_t)
+instantiate_unary_base_same(Cosh, complex64, complex64_t)
+instantiate_unary_base_same(Exp, complex64, complex64_t)
+instantiate_unary_base_same(Log, complex64, complex64_t)
+instantiate_unary_base_same(Log1p, complex64, complex64_t)
+instantiate_unary_base_same(Log2, complex64, complex64_t)
+instantiate_unary_base_same(Log10, complex64, complex64_t)
+instantiate_unary_base_same(Negative, complex64, complex64_t)
+instantiate_unary_base_same(Sign, complex64, complex64_t)
+instantiate_unary_base_same(Sin, complex64, complex64_t)
+instantiate_unary_base_same(Sinh, complex64, complex64_t)
+instantiate_unary_base_same(Square, complex64, complex64_t)
+instantiate_unary_base_same(Sqrt, complex64, complex64_t)
+instantiate_unary_base_same(Rsqrt, complex64, complex64_t)
+instantiate_unary_base_same(Tan, complex64, complex64_t)
+instantiate_unary_base_same(Tanh, complex64, complex64_t)
+instantiate_unary_base_same(Round, complex64, complex64_t)
+instantiate_unary_base(Real, complex64, float32, complex64_t, float)
+instantiate_unary_base(Imag, complex64, float32, complex64_t, float)
 
 instantiate_unary_all_same(LogicalNot, bool_, bool) // clang-format on

View File

@@ -45,7 +45,7 @@ void ternary_op_gpu_inplace(
     work_per_thread = large ? 4 : 2;
   } else {
     large = out.data_size() > INT32_MAX;
-    work_per_thread = get_work_per_thread(b.dtype());
+    work_per_thread = get_work_per_thread(b.dtype(), out.data_size());
   }
   std::string kernel_name;
   if (topt == TernaryOpType::General) {
@@ -60,6 +60,8 @@ void ternary_op_gpu_inplace(
     }
   } else if (large) {
     kernel_name = "v2";
+  } else if (work_per_thread > 1) {
+    kernel_name = "vn";
   } else {
     kernel_name = "v";
  }

View File

@@ -43,8 +43,8 @@ void unary_op_gpu_inplace(
   int work_per_thread;
   std::string kernel_name;
   if (contig) {
-    work_per_thread = get_work_per_thread(in.dtype());
-    kernel_name = (large ? "v2" : "v");
+    work_per_thread = get_work_per_thread(in.dtype(), in.data_size());
+    kernel_name = (large ? "v2" : (work_per_thread > 1 ? "vn" : "v"));
   } else {
     work_per_thread = large ? 4 : 1;
     kernel_name = "gn" + std::to_string(work_per_thread);
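Together with the ternary hunk above, the contiguous dispatch now chooses between three kernel families. A hedged restatement in plain C++ (the function name `contiguous_kernel_variant` is mine, not the library's):

#include <string>

// Mirrors the dispatch logic visible in the two hunks above: "v2" for
// arrays that need 64-bit indexing, "vn" when the launch uses a
// work-per-thread width greater than one, and "v" otherwise.
std::string contiguous_kernel_variant(bool large, int work_per_thread) {
  if (large) {
    return "v2";
  } else if (work_per_thread > 1) {
    return "vn";
  } else {
    return "v";
  }
}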

View File

@@ -72,6 +72,10 @@ void concatenate(std::string& acc, T first, Args... args) {
 inline int get_work_per_thread(Dtype dtype) {
   return std::max(1, 8 / dtype.size());
 }
 
+inline int get_work_per_thread(Dtype dtype, size_t size) {
+  constexpr size_t wpt_threshold = 1 << 16;
+  return size < wpt_threshold ? 1 : std::max(1, 8 / dtype.size());
+}
+
 inline size_t ceildiv(size_t n, size_t m) {
   return (n + m - 1) / m;
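This overload is what protects small launches from the extra variants: below 2^16 elements the width collapses to 1 and the plain kernels are used. Working the arithmetic through (these values follow directly from the code above):

// Values implied by get_work_per_thread(dtype, size):
//   any dtype, size <  (1 << 16)  -> 1  (small launch: plain "v"/"s" kernels)
//   float32,   size >= (1 << 16)  -> 2  (std::max(1, 8 / 4))
//   float16,   size >= (1 << 16)  -> 4  (std::max(1, 8 / 2))
//   int64,     size >= (1 << 16)  -> 1  (std::max(1, 8 / 8)), so the "vn"
//     variants can never be picked for 8-byte types, matching the
//     base-only instantiations earlier in the diff.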

View File

@@ -1,16 +1,20 @@
 // Copyright © 2023-2024 Apple Inc.
 
 #include <cstdlib>
 #include <map>
+#include <sstream>
 #include <unordered_map>
 #include <unordered_set>
 
 #include "mlx/allocator.h"
+#include "mlx/backend/common/compiled.h"
 #include "mlx/compile.h"
 #include "mlx/compile_impl.h"
 #include "mlx/fast_primitives.h"
+#include "mlx/graph_utils.h"
 #include "mlx/primitives.h"
 #include "mlx/transforms.h"
 #include "mlx/transforms_impl.h"
+#include "mlx/utils.h"
 
 namespace mlx::core {
@@ -82,7 +86,54 @@ Compiled::Compiled(
       inputs_(std::move(inputs)),
       outputs_(std::move(outputs)),
       tape_(std::move(tape)),
-      constant_ids_(std::move(constant_ids)) {}
+      constant_ids_(std::move(constant_ids)),
+      is_constant_([this](size_t i) {
+        return constant_ids_.find(inputs_[i].id()) != constant_ids_.end();
+      }) {
+  // Build the kernel name.
+  NodeNamer namer;
+  std::ostringstream os;
+  std::ostringstream constant_hasher;
+
+  // Fill the input names. This is not really necessary, I just like having A,
+  // B, C, ... as the inputs.
+  for (const auto& x : inputs_) {
+    namer.get_name(x);
+  }
+
+  // The primitives describing the tape. For unary and binary primitives this
+  // must be enough to describe the full computation.
+  for (const auto& a : tape_) {
+    // name and type of output
+    os << namer.get_name(a) << kindof(a.dtype()) << a.itemsize();
+    // computation performed
+    a.primitive().print(os);
+    // name of inputs to the function
+    for (auto& inp : a.inputs()) {
+      os << namer.get_name(inp);
+    }
+  }
+
+  os << "_";
+  for (const auto& x : inputs_) {
+    if (constant_ids_.find(x.id()) != constant_ids_.end()) {
+      os << "C";
+      print_constant(constant_hasher, x);
+    } else {
+      os << (is_scalar(x) ? "S" : "V");
+    }
+  }
+  os << "_";
+  for (const auto& x : inputs_) {
+    if (constant_ids_.find(x.id()) != constant_ids_.end()) {
+      continue;
+    }
+    os << kindof(x.dtype()) << x.itemsize();
+  }
+  os << "_" << std::hash<std::string>{}(constant_hasher.str());
+  kernel_lib_ = os.str();
+}
 
 std::vector<array> Compiled::vjp(
     const std::vector<array>&,
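Two things happen in this constructor change: the kernel library name is now built once, when the `Compiled` primitive is constructed, and cached in `kernel_lib_`, rather than being regenerated on demand; and the `is_constant_` predicate wraps the `constant_ids_` lookup so backends can ask "is input i a constant?" through a single `std::function` instead of carrying the id set and the input list around separately. Note that the loops must iterate the members `inputs_` and `constant_ids_`, not the constructor parameters, since those were moved from in the initializer list.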

View File

@@ -627,6 +627,7 @@ class Compiled : public Primitive {
   const std::vector<array> outputs_;
   const std::vector<array> tape_;
   const std::unordered_set<uintptr_t> constant_ids_;
+  const std::function<bool(size_t)> is_constant_;
 
   std::string kernel_lib_;
 };

View File

@@ -208,7 +208,9 @@ array eval_impl(std::vector<array> outputs, bool async) {
         // output arrays stream
         fences[it->second].wait(stream, in);
       } else if (in.event().valid()) {
-        if (in.event().stream() != stream) {
+        if (in.event().is_signaled()) {
+          in.detach_event();
+        } else if (in.event().stream() != stream) {
           // Use event to wait across async eval
           in.event().wait(stream);
         }
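The reworked branch first asks the event, on the host, whether it has already been signaled; if so, the event is simply detached from the array and no wait is recorded anywhere. Presumably this keeps already-completed dependencies from costing a cross-stream wait, or any further event traffic, on later evaluations; the diff shows only the mechanism, so that motivation is an inference.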

View File

@@ -4,7 +4,7 @@
 #define MLX_VERSION_MAJOR 0
 #define MLX_VERSION_MINOR 26
-#define MLX_VERSION_PATCH 0
+#define MLX_VERSION_PATCH 1
 
 #define MLX_VERSION_NUMERIC \
   (100000 * MLX_VERSION_MAJOR + 1000 * MLX_VERSION_MINOR + MLX_VERSION_PATCH)

View File

@@ -193,7 +193,7 @@ class Module(dict):
         )
 
         if len(weights) != 0:
-            self.update(tree_unflatten(weights))
+            self.update(tree_unflatten(weights), strict=False)
 
         return self
 
     def save_weights(self, file: str):
@@ -291,7 +291,7 @@ class Module(dict):
 
         return self.filter_and_map(self.valid_child_filter, is_leaf_fn=_is_leaf_module)
 
-    def update(self, parameters: dict) -> Module:
+    def update(self, parameters: dict, strict: bool = True) -> Module:
         """Replace the parameters of this Module with the provided ones in the
         dict of dicts and lists.
@@ -305,7 +305,9 @@ class Module(dict):
         Args:
             parameters (dict): A complete or partial dictionary of the modules
                 parameters.
+            strict (bool): If ``True`` checks that ``parameters`` is a
+                subset of the module's parameters. Default: ``True``.
 
         Returns:
             The module instance after updating the parameters.
         """
@@ -317,21 +319,29 @@ class Module(dict):
                         current_value = dst[k]
                         new_value = parameters[k]
                         if isinstance(current_value, mx.array):
+                            if strict and not isinstance(new_value, mx.array):
+                                raise ValueError(
+                                    f"Received invalid type: {type(new_value).__name__}."
+                                )
                             dst[k] = new_value
-                        elif isinstance(current_value, Module):
-                            current_value.update(new_value)
-                        elif isinstance(current_value, (dict, list)):
+                        else:
                             apply(current_value, new_value)
+                    elif strict:
+                        raise ValueError(f'Module does not have parameter named "{k}".')
             elif isinstance(parameters, list):
                 for i in range(len(parameters)):
                     current_value = dst[i]
                     new_value = parameters[i]
                     if isinstance(current_value, mx.array):
+                        if strict and not isinstance(new_value, mx.array):
+                            raise ValueError(
+                                f"Received invalid type: {type(new_value).__name__}."
+                            )
                         dst[i] = new_value
-                    elif isinstance(current_value, Module):
-                        current_value.update(new_value)
-                    elif isinstance(current_value, (dict, list)):
+                    else:
                        apply(current_value, new_value)
+            elif strict:
+                raise ValueError(f"Received invalid type: {type(parameters).__name__}.")
 
         apply(self, parameters)
         return self
@@ -359,7 +369,7 @@ class Module(dict):
         self.update(self.filter_and_map(filter_fn, map_fn))
         return self
 
-    def update_modules(self, modules: dict) -> Module:
+    def update_modules(self, modules: dict, strict: bool = True) -> Module:
         """Replace the child modules of this :class:`Module` instance with the
         provided ones in the dict of dicts and lists.
@@ -368,12 +378,14 @@ class Module(dict):
         programmatically swapping layers.
 
         The passed in parameters dictionary need not be a full dictionary
-        similar to :meth:`parameters`. Only the provided locations will be
+        similar to :meth:`modules`. Only the provided locations will be
         updated.
 
         Args:
-            modules (dict): A complete or partial dictionary of the modules
+            modules (dict): A complete or partial dictionary of the module's
                 submodules.
+            strict (bool): If ``True`` checks that ``modules`` is a
+                subset of the child modules of this instance. Default: ``True``.
 
         Returns:
             The module instance after updating the submodules.
         """
@@ -388,6 +400,14 @@ class Module(dict):
                             dst[k] = new_value
                         elif isinstance(current_value, (dict, list)):
                             apply(current_value, new_value)
+                        elif strict:
+                            raise ValueError(
+                                f"Received invalid type: {type(new_value).__name__}."
+                            )
+                    elif strict:
+                        raise ValueError(
+                            f'Module does not have sub-module named "{k}".'
+                        )
             elif isinstance(modules, list):
                 for i in range(len(dst)):
                     current_value = dst[i]
@@ -396,6 +416,12 @@ class Module(dict):
                         dst[i] = new_value
                     elif isinstance(current_value, (dict, list)):
                         apply(current_value, new_value)
+                    elif strict:
+                        raise ValueError(
+                            f"Received invalid type: {type(new_value).__name__}."
+                        )
+            elif strict:
+                raise ValueError(f"Received invalid type: {type(modules).__name__}.")
 
         apply(self, modules)
         return self
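Net effect of the `base.py` changes: `update` and `update_modules` now validate by default, raising `ValueError` both for keys that do not exist on the module and for replacement values of the wrong type. `load_weights` opts out with `strict=False`, presumably because it performs its own key and shape validation before calling `update`; the tests at the end of this diff exercise the new failure modes directly.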

View File

@@ -54,5 +54,9 @@ target_link_libraries(core PRIVATE mlx)
 target_compile_definitions(core PRIVATE _VERSION_=${MLX_VERSION})
 
 if(BUILD_SHARED_LIBS)
-  target_link_options(core PRIVATE -Wl,-rpath,@loader_path/lib)
+  if(${CMAKE_SYSTEM_NAME} MATCHES "Darwin")
+    target_link_options(core PRIVATE -Wl,-rpath,@loader_path/lib)
+  else()
+    target_link_options(core PRIVATE -Wl,-rpath,\$ORIGIN/lib)
+  endif()
 endif()
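`@loader_path` is a Mach-O linker token that only the Apple toolchain understands; the ELF equivalent is `$ORIGIN`. Branching on `CMAKE_SYSTEM_NAME` gives each platform its own spelling, so the shared library's rpath resolves relative to the loading binary on Linux as well, and the escaped `\$` keeps CMake from expanding `$ORIGIN` as one of its own variables before it reaches the linker.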

View File

@@ -219,6 +219,46 @@ class TestBase(mlx_tests.MLXTestCase):
         x = mx.zeros((3,))
         mx.grad(loss_fn)(model)
 
+    def test_update(self):
+        m = nn.Sequential(nn.Linear(3, 3), nn.Linear(3, 3))
+
+        # Updating non-existent parameters
+        with self.assertRaises(ValueError):
+            updates = {"layers": [{"value": 0}]}
+            m.update(updates)
+
+        with self.assertRaises(ValueError):
+            updates = {"layers": ["hello"]}
+            m.update(updates)
+
+        # Wrong type
+        with self.assertRaises(ValueError):
+            updates = {"layers": [{"weight": "hi"}]}
+            m.update(updates)
+
+    def test_update_modules(self):
+        m = nn.Sequential(nn.Linear(3, 3), nn.Linear(3, 3))
+
+        # Updating non-existent modules should not be allowed by default
+        with self.assertRaises(ValueError):
+            m = m.update_modules({"values": [0, 1]})
+
+        # Updating with wrong types
+        with self.assertRaises(ValueError):
+            m = m.update_modules({"layers": [0, 1]})
+
+        class MyModule(nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.test = mx.array(1.0)
+                self.list = [mx.array(1.0), mx.array(2.0)]
+
+        m = MyModule()
+        with self.assertRaises(ValueError):
+            m = m.update_modules({"test": "hi"})
+
+        with self.assertRaises(ValueError):
+            m = m.update_modules({"list": ["hi"]})
+
 
 class TestLayers(mlx_tests.MLXTestCase):
     def test_identity(self):