Mirror of https://github.com/ml-explore/mlx.git
Synced 2025-12-16 01:49:05 +08:00

Compare commits
8 commits: v0.26.0 ... c6a20b427a
| SHA1 |
|---|
| c6a20b427a |
| a5ac9244c4 |
| c763fe1be0 |
| 52dc8c8cd5 |
| aede70e81d |
| 85a8beb5e4 |
| 0bb89e9e5f |
| 5685ceb3c7 |
@@ -1,8 +1,7 @@
// Copyright © 2023-2024 Apple Inc.

#include "mlx/backend/common/compiled.h"
#include "mlx/graph_utils.h"
#include "mlx/primitives.h"
#include "mlx/backend/common/utils.h"
#include "mlx/utils.h"

namespace mlx::core {

@@ -79,55 +78,6 @@ std::string get_type_string(Dtype d) {
  }
}

std::string build_lib_name(
    const std::vector<array>& inputs,
    const std::vector<array>& outputs,
    const std::vector<array>& tape,
    const std::unordered_set<uintptr_t>& constant_ids) {
  NodeNamer namer;
  std::ostringstream os;
  std::ostringstream constant_hasher;

  // Fill the input names. This is not really necessary, I just like having A,
  // B, C, ... as the inputs.
  for (auto& x : inputs) {
    namer.get_name(x);
  }

  // The primitives describing the tape. For unary and binary primitives this
  // must be enough to describe the full computation.
  for (auto& a : tape) {
    // name and type of output
    os << namer.get_name(a) << kindof(a.dtype()) << a.itemsize();
    // computation performed
    a.primitive().print(os);
    // name of inputs to the function
    for (auto& inp : a.inputs()) {
      os << namer.get_name(inp);
    }
  }
  os << "_";

  for (auto& x : inputs) {
    if (constant_ids.find(x.id()) != constant_ids.end()) {
      os << "C";
      print_constant(constant_hasher, x);
    } else {
      os << (is_scalar(x) ? "S" : "V");
    }
  }
  os << "_";
  for (auto& x : inputs) {
    if (constant_ids.find(x.id()) != constant_ids.end()) {
      continue;
    }
    os << kindof(x.dtype()) << x.itemsize();
  }
  os << "_" << std::hash<std::string>{}(constant_hasher.str());

  return os.str();
}

bool compiled_check_contiguity(
    const std::vector<array>& inputs,
    const Shape& shape) {
@@ -159,8 +109,7 @@ bool compiled_check_contiguity(
void compiled_allocate_outputs(
    const std::vector<array>& inputs,
    std::vector<array>& outputs,
    const std::vector<array>& inputs_,
    const std::unordered_set<uintptr_t>& constant_ids_,
    const std::function<bool(size_t)>& is_constant,
    bool contiguous) {
  if (contiguous) {
    int o = 0;
@@ -175,8 +124,7 @@ void compiled_allocate_outputs(
      // - Donatable
      // - Not a constant
      if (in.itemsize() == outputs[o].itemsize() && !is_scalar(in) &&
          in.is_donatable() &&
          constant_ids_.find(inputs_[i].id()) == constant_ids_.end()) {
          in.is_donatable() && is_constant(i)) {
        outputs[o++].copy_shared_buffer(in);
      }
      // Get representative input flags to properly set non-donated outputs
@@ -204,7 +152,7 @@ void compiled_allocate_outputs(
      // - Not a constant
      if (in.flags().row_contiguous && in.size() == outputs[o].size() &&
          in.itemsize() == outputs[o].itemsize() && in.is_donatable() &&
          constant_ids_.find(inputs_[i].id()) == constant_ids_.end()) {
          is_constant(i)) {
        outputs[o].copy_shared_buffer(
            in, outputs[o].strides(), in.flags(), in.data_size());
        o++;
@@ -216,4 +164,74 @@ void compiled_allocate_outputs(
  }
}

std::tuple<bool, Shape, std::vector<Strides>> compiled_collapse_contiguous_dims(
    const std::vector<array>& inputs,
    const array& out,
    const std::function<bool(size_t)>& is_constant) {
  const Shape& shape = out.shape();
  bool contiguous = compiled_check_contiguity(inputs, shape);
  if (contiguous) {
    return {true, shape, {}};
  }

  std::vector<Strides> strides_vec{out.strides()};
  for (size_t i = 0; i < inputs.size(); ++i) {
    // Skip constants.
    if (is_constant(i)) {
      continue;
    }

    // Skip scalar inputs.
    const auto& x = inputs[i];
    if (is_scalar(x)) {
      continue;
    }

    // Broadcast the inputs to the output shape.
    Strides xstrides;
    size_t j = 0;
    for (; j < shape.size() - x.ndim(); ++j) {
      if (shape[j] == 1) {
        xstrides.push_back(out.strides()[j]);
      } else {
        xstrides.push_back(0);
      }
    }
    for (size_t i = 0; i < x.ndim(); ++i, ++j) {
      if (x.shape(i) == 1) {
        if (shape[j] == 1) {
          xstrides.push_back(out.strides()[j]);
        } else {
          xstrides.push_back(0);
        }
      } else {
        xstrides.push_back(x.strides()[i]);
      }
    }
    strides_vec.push_back(std::move(xstrides));
  }

  auto tup = collapse_contiguous_dims(shape, strides_vec, INT32_MAX);
  return {false, std::move(std::get<0>(tup)), std::move(std::get<1>(tup))};
}

bool compiled_use_large_index(
    const std::vector<array>& inputs,
    const std::vector<array>& outputs,
    bool contiguous) {
  if (contiguous) {
    size_t max_size = 0;
    for (const auto& in : inputs) {
      max_size = std::max(max_size, in.data_size());
    }
    return max_size > UINT32_MAX;
  } else {
    size_t max_size = 0;
    for (const auto& o : outputs) {
      max_size = std::max(max_size, o.size());
    }
    return max_size > UINT32_MAX;
  }
}

} // namespace mlx::core
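Taken together, these hunks replace the `constant_ids` set plumbing with an index-based `is_constant` callback and move the broadcast-and-collapse logic into shared helpers. A minimal sketch of the call pattern the backends adopt further down (only the three helper calls come from this diff; the wrapper function is illustrative):

```cpp
// Sketch only: how a backend drives the new helpers from
// mlx/backend/common/compiled.h.
void example_backend_dispatch(
    const std::vector<array>& inputs,
    std::vector<array>& outputs,
    const std::function<bool(size_t)>& is_constant_) {
  // One call now handles broadcasting and collapsing of contiguous dims.
  auto [contiguous, shape, strides] =
      compiled_collapse_contiguous_dims(inputs, outputs[0], is_constant_);

  // Decide between 32-bit and 64-bit indexing for the kernel.
  bool large = compiled_use_large_index(inputs, outputs, contiguous);

  // Allocate or donate output buffers using the same constant predicate.
  compiled_allocate_outputs(inputs, outputs, is_constant_, contiguous);

  // ... select a kernel variant from `contiguous`, `shape.size()` and `large`,
  // then bind `strides` (entry 0 is the collapsed output strides).
  (void)large;
}
```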
@@ -1,9 +1,8 @@
// Copyright © 2023-2024 Apple Inc.
#pragma once

#include <functional>
#include <iomanip>
#include <sstream>
#include <unordered_set>

#include "mlx/array.h"
#include "mlx/primitives.h"
@@ -14,12 +13,6 @@ inline bool is_static_cast(const Primitive& p) {
  return (typeid(p) == typeid(Broadcast) || typeid(p) == typeid(AsType));
}

std::string build_lib_name(
    const std::vector<array>& inputs,
    const std::vector<array>& outputs,
    const std::vector<array>& tape,
    const std::unordered_set<uintptr_t>& constant_ids);

std::string get_type_string(Dtype d);

template <typename T>
@@ -60,8 +53,19 @@ bool compiled_check_contiguity(
void compiled_allocate_outputs(
    const std::vector<array>& inputs,
    std::vector<array>& outputs,
    const std::vector<array>& inputs_,
    const std::unordered_set<uintptr_t>& constant_ids_,
    const std::function<bool(size_t)>& is_constant,
    bool contiguous);

// Collapse contiguous dims ignoring scalars and constants.
std::tuple<bool, Shape, std::vector<Strides>> compiled_collapse_contiguous_dims(
    const std::vector<array>& inputs,
    const array& out,
    const std::function<bool(size_t)>& is_constant);

// Return whether the kernel should use large index.
bool compiled_use_large_index(
    const std::vector<array>& inputs,
    const std::vector<array>& outputs,
    bool contiguous);

} // namespace mlx::core
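The header now describes constants with a `std::function<bool(size_t)>` over input indices instead of a set of array ids. If an adapter from the old representation were ever needed, it would be a one-liner (hypothetical helper, not part of this diff):

```cpp
// Hypothetical adapter: build the new index-based predicate from the old
// constant_ids set, mirroring the lambda that was removed from build_kernel.
std::function<bool(size_t)> make_is_constant(
    const std::vector<array>& inputs,
    const std::unordered_set<uintptr_t>& constant_ids) {
  return [&inputs, &constant_ids](size_t i) {
    return constant_ids.find(inputs[i].id()) != constant_ids.end();
  };
}
```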
@@ -146,18 +146,9 @@ inline void build_kernel(
    const std::vector<array>& inputs,
    const std::vector<array>& outputs,
    const std::vector<array>& tape,
    const std::unordered_set<uintptr_t>& constant_ids,
    const std::function<bool(size_t)>& is_constant,
    bool contiguous,
    int ndim) {
  // All outputs should have the exact same shape and will be row contiguous
  auto output_shape = outputs[0].shape();
  auto output_strides = outputs[0].strides();

  // Constants are scalars that are captured by value and cannot change
  auto is_constant = [&constant_ids](const array& x) {
    return constant_ids.find(x.id()) != constant_ids.end();
  };

  NodeNamer namer;

#ifdef _MSC_VER
@@ -170,14 +161,15 @@ inline void build_kernel(

  // Add the input arguments
  int cnt = 0;
  for (auto& x : inputs) {
    auto& xname = namer.get_name(x);

  for (size_t i = 0; i < inputs.size(); ++i) {
    // Skip constants from the input list
    if (is_constant(x)) {
    if (is_constant(i)) {
      continue;
    }

    const auto& x = inputs[i];
    auto& xname = namer.get_name(x);

    auto tstr = get_type_string(x.dtype());
    os << " " << tstr << "* " << xname << " = (" << tstr << "*)args[" << cnt++
       << "];" << std::endl;
@@ -211,10 +203,11 @@ inline void build_kernel(
  }

  // Read the inputs in tmps
  for (auto& x : inputs) {
  for (size_t i = 0; i < inputs.size(); ++i) {
    const auto& x = inputs[i];
    auto& xname = namer.get_name(x);

    if (is_constant(x)) {
    if (is_constant(i)) {
      os << " " << get_type_string(x.dtype()) << " tmp_" << xname << " = ";
      print_constant(os, x);
      os << ";" << std::endl;
@@ -264,8 +257,9 @@ inline void build_kernel(
  } else {
    for (int d = ndim - 1; d >= 0; --d) {
      // Update pointers
      for (auto& x : inputs) {
        if (is_constant(x) || is_scalar(x)) {
      for (size_t i = 0; i < inputs.size(); ++i) {
        const auto& x = inputs[i];
        if (is_constant(i) || is_scalar(x)) {
          continue;
        }
        auto& xname = namer.get_name(x);
@@ -287,65 +281,37 @@ inline void build_kernel(
void Compiled::eval_cpu(
    const std::vector<array>& inputs,
    std::vector<array>& outputs) {
  if (kernel_lib_.empty()) {
    kernel_lib_ = build_lib_name(inputs_, outputs_, tape_, constant_ids_);
  }

  // Figure out which kernel we are using
  auto& shape = outputs[0].shape();
  auto contiguous = compiled_check_contiguity(inputs, shape);
  auto& encoder = cpu::get_command_encoder(stream());

  // Handle all broadcasting and collect function input arguments
  // Collapse contiguous dims to route to a faster kernel if possible. Also
  // handle all broadcasting.
  auto [contiguous, shape, strides] =
      compiled_collapse_contiguous_dims(inputs, outputs[0], is_constant_);

  // Collect function input arguments.
  std::vector<void*> args;
  std::vector<std::vector<size_t>> strides;
  for (int i = 0; i < inputs.size(); i++) {
    // Skip constants.
    if (constant_ids_.find(inputs_[i].id()) != constant_ids_.end()) {
  int strides_index = 1;
  for (size_t i = 0; i < inputs.size(); ++i) {
    if (is_constant_(i)) {
      continue;
    }
    auto& x = inputs[i];
    const auto& x = inputs[i];
    encoder.set_input_array(x);
    args.push_back((void*)x.data<void>());

    if (contiguous || is_scalar(x)) {
      continue;
    if (!contiguous && !is_scalar(x)) {
      args.push_back(strides[strides_index++].data());
    }

    // Broadcast the input to the output shape.
    std::vector<size_t> xstrides;
    int j = 0;
    for (; j < shape.size() - x.ndim(); j++) {
      if (shape[j] == 1) {
        xstrides.push_back(outputs[0].strides()[j]);
      } else {
        xstrides.push_back(0);
      }
    }
    for (int i = 0; i < x.ndim(); i++, j++) {
      if (x.shape(i) == 1) {
        if (shape[j] == 1) {
          xstrides.push_back(outputs[0].strides()[j]);
        } else {
          xstrides.push_back(0);
        }
      } else {
        xstrides.push_back(x.strides()[i]);
      }
    }
    strides.push_back(std::move(xstrides));
    args.push_back(strides.back().data());
  }

  // Get the kernel name from the lib
  int ndim = shape.size();
  auto kernel_name = kernel_lib_ + (contiguous ? "_contiguous" : "_strided_");
  if (!contiguous) {
    kernel_name += std::to_string(shape.size());
    kernel_name += std::to_string(ndim);
  }

  // Get the function
  auto fn_ptr = compile(kernel_name, [&]() {
  auto fn_ptr = compile(kernel_name, [&, contiguous = contiguous]() {
    std::ostringstream kernel;
    kernel << get_kernel_preamble() << std::endl;
    kernel << "extern \"C\" {" << std::endl;
@@ -355,7 +321,7 @@ void Compiled::eval_cpu(
        inputs_,
        outputs_,
        tape_,
        constant_ids_,
        is_constant_,
        contiguous,
        ndim);
    // Close extern "C"
@@ -363,26 +329,22 @@ void Compiled::eval_cpu(
    return kernel.str();
  });

  compiled_allocate_outputs(
      inputs, outputs, inputs_, constant_ids_, contiguous);
  compiled_allocate_outputs(inputs, outputs, is_constant_, contiguous);

  for (auto& x : outputs) {
    args.push_back(x.data<void>());
    encoder.set_output_array(x);
  }
  Shape out_shape;
  if (!contiguous) {
    out_shape = outputs[0].shape();
    args.push_back((void*)out_shape.data());
    args.push_back((void*)shape.data());
  } else {
    args.push_back((void*)outputs[0].data_size());
  }
  auto fun = (void (*)(void**))fn_ptr;
  encoder.dispatch(
      [fun,
       args = std::move(args),
       strides = std::move(strides),
       out_shape = std::move(out_shape)]() mutable { fun(args.data()); });
  encoder.dispatch([fun,
                    args = std::move(args),
                    strides = std::move(strides),
                    shape = std::move(shape)]() mutable { fun(args.data()); });
}

} // namespace mlx::core
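For reference, the argument vector handed to the JIT-compiled CPU kernel after this change looks as follows; this is reconstructed from the packing code above and is not spelled out anywhere else in the diff:

```cpp
// Contiguous case:
//   args = { in0_ptr, in1_ptr, ..., out0_ptr, ..., (void*)outputs[0].data_size() }
// Strided case:
//   args = { in0_ptr, in0_strides, in1_ptr, in1_strides, ...,
//            out0_ptr, ..., (void*)shape.data() }
//
// Constant inputs are skipped entirely and scalar inputs carry no strides
// entry. strides[0] holds the collapsed output strides returned by
// compiled_collapse_contiguous_dims; only entries from index 1 onward are
// bound here, which is why strides_index starts at 1.
```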
@@ -10,13 +10,15 @@ target_sources(
  ${CMAKE_CURRENT_SOURCE_DIR}/device.cpp
  ${CMAKE_CURRENT_SOURCE_DIR}/eval.cpp
  ${CMAKE_CURRENT_SOURCE_DIR}/event.cu
  ${CMAKE_CURRENT_SOURCE_DIR}/fence.cu
  ${CMAKE_CURRENT_SOURCE_DIR}/fence.cpp
  ${CMAKE_CURRENT_SOURCE_DIR}/kernel_utils.cu
  ${CMAKE_CURRENT_SOURCE_DIR}/primitives.cu
  ${CMAKE_CURRENT_SOURCE_DIR}/slicing.cpp
  ${CMAKE_CURRENT_SOURCE_DIR}/utils.cpp
  ${CMAKE_CURRENT_SOURCE_DIR}/worker.cpp)

target_compile_definitions(mlx PRIVATE MLX_USE_CUDA)

# Enable defining device lambda functions.
target_compile_options(mlx
                       PRIVATE "$<$<COMPILE_LANGUAGE:CUDA>:--extended-lambda>")
@@ -18,7 +18,10 @@ CudaAllocator::CudaAllocator()
    : buffer_cache_(
          getpagesize(),
          [](CudaBuffer* buf) { return buf->size; },
          [this](CudaBuffer* buf) { cuda_free(buf); }) {
          [this](CudaBuffer* buf) {
            cuda_free(buf->data);
            delete buf;
          }) {
  // TODO: Set memory limit for multi-device.
  size_t free, total;
  CHECK_CUDA_ERROR(cudaMemGetInfo(&free, &total));
@@ -70,7 +73,8 @@ void CudaAllocator::free(Buffer buffer) {
    buffer_cache_.recycle_to_cache(buf);
  } else {
    lock.unlock();
    cuda_free(buf);
    cuda_free(buf->data);
    delete buf;
  }
}

@@ -87,6 +91,25 @@ void CudaAllocator::register_this_thread() {
  allowed_threads_.insert(std::this_thread::get_id());
}

void CudaAllocator::cuda_free(void* buf) {
  // If cuda_free() is called from a unregistered thread, reschedule the call to
  // worker.
  {
    std::lock_guard lock(worker_mutex_);
    if (allowed_threads_.count(std::this_thread::get_id()) == 0) {
      if (!worker_) {
        worker_.reset(new Worker);
      }
      worker_->add_task([this, buf]() { this->cuda_free(buf); });
      worker_->end_batch();
      worker_->commit();
      return;
    }
  }

  cudaFree(buf);
}

size_t CudaAllocator::get_active_memory() const {
  return active_memory_;
}
@@ -125,26 +148,6 @@ void CudaAllocator::clear_cache() {
  buffer_cache_.clear();
}

void CudaAllocator::cuda_free(CudaBuffer* buf) {
  // If cuda_free() is called from a unregistered thread, reschedule the call to
  // worker.
  {
    std::lock_guard lock(worker_mutex_);
    if (allowed_threads_.count(std::this_thread::get_id()) == 0) {
      if (!worker_) {
        worker_.reset(new Worker);
      }
      worker_->add_task([this, buf]() { this->cuda_free(buf); });
      worker_->end_batch();
      worker_->commit();
      return;
    }
  }

  cudaFree(buf->data);
  delete buf;
}

CudaAllocator& allocator() {
  // By creating the |allocator_| on heap, the destructor of CudaAllocator
  // will not be called on exit and buffers in the cache will be leaked. This
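The allocator change frees the raw device pointer instead of the `CudaBuffer` wrapper and funnels every `cudaFree` through a thread check. Stripped of the MLX specifics, the pattern is the usual "request a free from any thread, execute it on the thread that owns the CUDA context"; a self-contained sketch of that idea (this is not the MLX `Worker` API):

```cpp
#include <mutex>
#include <queue>
#include <thread>

class DeferredFree {
 public:
  void request(void* ptr) {
    std::lock_guard<std::mutex> lock(mutex_);
    if (std::this_thread::get_id() == owner_) {
      do_free(ptr);        // safe: we are already on the owning thread
    } else {
      pending_.push(ptr);  // hand the pointer off to the owning thread
    }
  }
  void drain() {           // called periodically on the owning thread
    std::lock_guard<std::mutex> lock(mutex_);
    while (!pending_.empty()) {
      do_free(pending_.front());
      pending_.pop();
    }
  }

 private:
  static void do_free(void* ptr) { /* cudaFree(ptr) in the real code */ }
  std::thread::id owner_ = std::this_thread::get_id();
  std::mutex mutex_;
  std::queue<void*> pending_;
};
```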
@@ -34,6 +34,9 @@ class CudaAllocator : public allocator::Allocator {
  // buffers there would result in dead lock.
  void register_this_thread();

  // Call cudaFree in the safe thread.
  void cuda_free(void* buf);

  size_t get_active_memory() const;
  size_t get_peak_memory() const;
  void reset_peak_memory();
@@ -47,8 +50,6 @@ class CudaAllocator : public allocator::Allocator {
  CudaAllocator();
  friend CudaAllocator& allocator();

  void cuda_free(CudaBuffer* buf);

  std::mutex worker_mutex_;
  std::unique_ptr<Worker> worker_;
  std::set<std::thread::id> allowed_threads_;
@@ -1,5 +1,6 @@
// Copyright © 2024 Apple Inc.

#include "mlx/backend/cuda/allocator.h"
#include "mlx/backend/cuda/device.h"
#include "mlx/backend/cuda/event.h"
#include "mlx/backend/cuda/utils.h"
@@ -111,12 +112,12 @@ __global__ void event_signal_kernel(SharedEvent::Atomic* ac, uint64_t value) {

SharedEvent::SharedEvent() {
  // Allocate cuda::atomic on managed memory.
  allocator::Buffer buffer = allocator::malloc(sizeof(Atomic));
  Atomic* ac = static_cast<Atomic*>(buffer.raw_ptr());
  Atomic* ac;
  CHECK_CUDA_ERROR(cudaMallocManaged(&ac, sizeof(Atomic)));
  new (ac) Atomic(0);
  ac_ = std::shared_ptr<Atomic>(ac, [buffer](Atomic* ptr) {
  ac_ = std::shared_ptr<Atomic>(ac, [](Atomic* ptr) {
    ptr->~Atomic();
    allocator::free(buffer);
    allocator().cuda_free(ptr);
  });
}

@@ -155,7 +156,10 @@ void SharedEvent::signal(cudaStream_t stream, uint64_t value) {
void SharedEvent::signal(Stream s, uint64_t value) {
  nvtx3::scoped_range r("cu::SharedEvent::signal(s)");
  if (s.device == mlx::core::Device::cpu) {
    scheduler::enqueue(s, [*this, value]() mutable { signal(value); });
    // Signal through a GPU stream so the atomic is updated in GPU - updating
    // the atomic in CPU sometimes does not get GPU notified.
    static CudaStream stream(device(mlx::core::Device::gpu));
    scheduler::enqueue(s, [*this, value]() mutable { signal(stream, value); });
  } else {
    auto& encoder = get_command_encoder(s);
    encoder.launch_kernel(
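The `SharedEvent` constructor now allocates its atomic with `cudaMallocManaged` and releases it through the allocator's new `cuda_free`. The underlying pattern, managed memory plus placement new plus a `shared_ptr` deleter, looks like this in isolation (error handling and the worker indirection are elided; this is a sketch, not the MLX code path):

```cpp
#include <cuda_runtime.h>
#include <cuda/atomic>
#include <cstdint>
#include <memory>
#include <new>

using Atomic = cuda::atomic<uint64_t>;

std::shared_ptr<Atomic> make_managed_atomic(uint64_t initial) {
  Atomic* ac = nullptr;
  cudaMallocManaged(&ac, sizeof(Atomic));  // visible to host and device
  new (ac) Atomic(initial);                // construct in place
  return std::shared_ptr<Atomic>(ac, [](Atomic* ptr) {
    ptr->~Atomic();  // destroy before releasing the allocation
    cudaFree(ptr);   // the real code defers this to a registered thread
  });
}
```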
mlx/backend/cuda/fence.cpp (new file, 29 lines)
@@ -0,0 +1,29 @@
// Copyright © 2025 Apple Inc.

#include "mlx/fence.h"
#include "mlx/backend/cuda/event.h"

namespace mlx::core {

struct FenceImpl {
  uint32_t count;
  cu::SharedEvent event;
};

Fence::Fence(Stream s) {
  fence_ = std::shared_ptr<void>(
      new FenceImpl{0}, [](void* ptr) { delete static_cast<FenceImpl*>(ptr); });
}

void Fence::wait(Stream s, const array&) {
  auto* fence = static_cast<FenceImpl*>(fence_.get());
  fence->event.wait(fence->count);
}

void Fence::update(Stream s, const array&) {
  auto* fence = static_cast<FenceImpl*>(fence_.get());
  fence->count++;
  fence->event.signal(s, fence->count);
}

} // namespace mlx::core
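The new CUDA fence is a thin counter on top of `SharedEvent`. A hedged usage sketch (the stream and array values are placeholders, not taken from this diff):

```cpp
// Illustrative only: each update() bumps the count and signals the event with
// it; wait() returns once the event has reached the last signaled value.
Fence fence(stream);
fence.update(stream, arr);  // signal with count = 1
fence.wait(stream, arr);    // blocks until the event value reaches 1
```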
@@ -1,70 +0,0 @@
// Copyright © 2025 Apple Inc.

#include "mlx/backend/cuda/device.h"
#include "mlx/backend/cuda/event.h"
#include "mlx/fence.h"
#include "mlx/scheduler.h"

#include <nvtx3/nvtx3.hpp>

namespace mlx::core {

namespace {

__host__ __device__ void busy_wait(cuda::atomic<uint64_t>* ac, uint64_t value) {
  while (true) {
    // In theory the atomic_thread_fence is not needed, but for CUDA 11 without
    // it the load() may never return new value.
    cuda::atomic_thread_fence(cuda::memory_order_seq_cst);
    uint64_t current = ac->load();
    if (current >= value) {
      break;
    }
  }
}

__global__ void busy_wait_kernel(cuda::atomic<uint64_t>* ac, uint64_t value) {
  busy_wait(ac, value);
}

} // namespace

struct FenceImpl {
  uint32_t count;
  cu::SharedEvent event;
};

Fence::Fence(Stream s) {
  fence_ = std::shared_ptr<void>(
      new FenceImpl{0}, [](void* ptr) { delete static_cast<FenceImpl*>(ptr); });
}

void Fence::wait(Stream s, const array&) {
  auto* fence = static_cast<FenceImpl*>(fence_.get());
  // We can't use SharedEvent::wait because it could hang in CUDA 11, see also:
  // https://github.com/ml-explore/mlx/issues/2137
  const auto& ac = fence->event.atomic();
  if (s.device == mlx::core::Device::cpu) {
    scheduler::enqueue(s, [ac, count = fence->count]() {
      nvtx3::scoped_range r("Fence::wait()");
      busy_wait(ac.get(), count);
    });
  } else {
    nvtx3::scoped_range r("Fence::wait(s)");
    auto& encoder = cu::get_command_encoder(s);
    encoder.launch_kernel(
        encoder.stream().last_cuda_stream(), [&](cudaStream_t stream) {
          busy_wait_kernel<<<1, 1, 0>>>(ac.get(), fence->count);
        });
    encoder.add_completed_handler([ac]() {});
    encoder.end_encoding();
  }
}

void Fence::update(Stream s, const array&) {
  auto* fence = static_cast<FenceImpl*>(fence_.get());
  fence->count++;
  fence->event.signal(s, fence->count);
}

} // namespace mlx::core
@@ -5,9 +5,17 @@
#include "mlx/backend/gpu/copy.h"
#include "mlx/backend/gpu/slicing.h"

#if defined(MLX_USE_CUDA)
#include <nvtx3/nvtx3.hpp>
#endif

#include <cassert>

#if defined(MLX_USE_CUDA)
#define MLX_PROFILER_RANGE(message) nvtx3::scoped_range r(message)
#else
#define MLX_PROFILER_RANGE(message)
#endif

namespace mlx::core {
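With this macro, the shared GPU primitives can emit NVTX ranges only when built with CUDA; a hypothetical call site:

```cpp
// Hypothetical caller: the macro expands to an nvtx3::scoped_range under
// MLX_USE_CUDA and to nothing otherwise, so Metal builds compile unchanged.
void some_gpu_primitive() {
  MLX_PROFILER_RANGE("SomePrimitive::eval_gpu");
  // ... dispatch work ...
}
```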
@@ -31,13 +31,13 @@ std::string get_kernel_name(
      kname = "ss";
      break;
    case BinaryOpType::ScalarVector:
      kname = (large ? "sv2" : "sv");
      kname = "sv";
      break;
    case BinaryOpType::VectorScalar:
      kname = (large ? "vs2" : "vs");
      kname = "vs";
      break;
    case BinaryOpType::VectorVector:
      kname = (large ? "vv2" : "vv");
      kname = "vv";
      break;
    case BinaryOpType::General:
      kname = "g";
@@ -51,6 +51,13 @@ std::string get_kernel_name(
      }
      break;
  }
  if (bopt != BinaryOpType::General && bopt != BinaryOpType::ScalarScalar) {
    if (large) {
      kname += "2";
    } else if (work_per_thread > 1) {
      kname += "n";
    }
  }
  concatenate(kname, "_", op, type_to_name(a));
  return kname;
}
@@ -90,7 +97,7 @@ void binary_op_gpu_inplace(
    work_per_thread = large ? 4 : 2;
  } else {
    large = out.data_size() > UINT32_MAX;
    work_per_thread = get_work_per_thread(a.dtype());
    work_per_thread = get_work_per_thread(a.dtype(), out.data_size());
  }
  std::string kernel_name =
      get_kernel_name(bopt, op, a, large, shape.size(), work_per_thread);

@@ -11,8 +11,6 @@
#include "mlx/primitives.h"
#include "mlx/utils.h"

using namespace fmt::literals;

namespace mlx::core {

inline void build_kernel(
@@ -21,21 +19,12 @@ inline void build_kernel(
    const std::vector<array>& inputs,
    const std::vector<array>& outputs,
    const std::vector<array>& tape,
    const std::unordered_set<uintptr_t>& constant_ids,
    const std::function<bool(size_t)>& is_constant,
    bool contiguous,
    int ndim,
    bool dynamic_dims,
    bool use_big_index = false,
    int work_per_thread = 1) {
  // All outputs should have the exact same shape and will be row contiguous
  auto output_shape = outputs[0].shape();
  auto output_strides = outputs[0].strides();

  // Constants are scalars that are captured by value and cannot change
  auto is_constant = [&constant_ids](const array& x) {
    return constant_ids.find(x.id()) != constant_ids.end();
  };

  NodeNamer namer;
  bool add_indices = false;
  int cnt = 0;
@@ -45,14 +34,15 @@ inline void build_kernel(
      "[[host_name(\"{0}\")]]\n[[kernel]] void {0}(\n", kernel_name);

  // Add the input arguments
  for (auto& x : inputs) {
    auto& xname = namer.get_name(x);

  for (size_t i = 0; i < inputs.size(); ++i) {
    // Skip constants from the input list
    if (is_constant(x)) {
    if (is_constant(i)) {
      continue;
    }

    const auto& x = inputs[i];
    auto& xname = namer.get_name(x);

    // Scalars and contiguous need no strides
    if (!is_scalar(x) && !contiguous) {
      add_indices = true;
@@ -80,8 +70,6 @@ inline void build_kernel(
  }
  // Add output strides and shape to extract the indices.
  if (!contiguous) {
    os += fmt::format(
        " constant const int64_t* output_strides [[buffer({0})]],\n", cnt++);
    os += fmt::format(
        " constant const int* output_shape [[buffer({0})]],\n", cnt++);
  } else {
@@ -125,7 +113,7 @@ inline void build_kernel(
    auto& x = inputs[i];
    auto& xname = namer.get_name(x);

    if (is_constant(x)) {
    if (is_constant(i)) {
      auto type_str = get_type_string(x.dtype());
      std::ostringstream ss;
      print_constant(ss, x);
@@ -271,11 +259,6 @@ inline void build_kernel(
void Compiled::eval_gpu(
    const std::vector<array>& inputs,
    std::vector<array>& outputs) {
  // Make the name for the kernel library
  if (kernel_lib_.empty()) {
    kernel_lib_ = build_lib_name(inputs_, outputs_, tape_, constant_ids_);
  }

  // Get the kernel if someone else built it already
  auto& s = stream();
  auto& d = metal::device(s.device);
@@ -290,19 +273,33 @@ void Compiled::eval_gpu(
        inputs_,
        outputs_,
        tape_,
        constant_ids_,
        is_constant_,
        /* contiguous = */ true,
        /* ndim = */ 0,
        /* dynamic_dims = */ false,
        /* use_big_index = */ false,
        /* work_per_thread = */ work_per_thread);
        /* work_per_thread = */ 1);
    if (work_per_thread > 1) {
      build_kernel(
          kernel,
          kernel_lib_ + "_contiguous_n",
          inputs_,
          outputs_,
          tape_,
          is_constant_,
          /* contiguous = */ true,
          /* ndim = */ 0,
          /* dynamic_dims = */ false,
          /* use_big_index = */ false,
          /* work_per_thread = */ work_per_thread);
    }
    build_kernel(
        kernel,
        kernel_lib_ + "_contiguous_large",
        inputs_,
        outputs_,
        tape_,
        constant_ids_,
        is_constant_,
        /* contiguous = */ true,
        /* ndim = */ 0,
        /* dynamic_dims = */ false,
@@ -315,7 +312,7 @@ void Compiled::eval_gpu(
        inputs_,
        outputs_,
        tape_,
        constant_ids_,
        is_constant_,
        /* contiguous = */ false,
        /* ndim = */ i,
        /* dynamic_dims = */ false,
@@ -328,7 +325,7 @@ void Compiled::eval_gpu(
          inputs_,
          outputs_,
          tape_,
          constant_ids_,
          is_constant_,
          /* contiguous = */ false,
          /* ndim = */ i,
          /* dynamic_dims = */ false,
@@ -342,7 +339,7 @@ void Compiled::eval_gpu(
        inputs_,
        outputs_,
        tape_,
        constant_ids_,
        is_constant_,
        /* contiguous = */ false,
        /* ndim = */ 0,
        /* dynamic_dims = */ true,
@@ -354,7 +351,7 @@ void Compiled::eval_gpu(
          inputs_,
          outputs_,
          tape_,
          constant_ids_,
          is_constant_,
          /* contiguous = */ false,
          /* ndim = */ 0,
          /* dynamic_dims = */ true,
@@ -363,81 +360,32 @@ void Compiled::eval_gpu(
    return kernel;
  });

  // Figure out which kernel we are using
  auto& output_shape = outputs[0].shape();
  auto contiguous = compiled_check_contiguity(inputs, output_shape);

  // Collapse contiguous dims to route to a faster kernel if possible. Also
  // handle all broadcasting.
  std::vector<Strides> initial_strides;
  initial_strides.push_back(outputs[0].strides());
  Shape shape;
  std::vector<Strides> strides;
  if (!contiguous) {
    for (int i = 0; i < inputs.size(); i++) {
      // Skip constants.
      if (constant_ids_.find(inputs_[i].id()) != constant_ids_.end()) {
        continue;
      }
      auto& x = inputs[i];
  auto [contiguous, shape, strides] =
      compiled_collapse_contiguous_dims(inputs, outputs[0], is_constant_);

      // Skip scalar inputs.
      if (is_scalar(x)) {
        continue;
      }

      // Broadcast the inputs to the output shape.
      Strides xstrides;
      int j = 0;
      for (; j < output_shape.size() - x.ndim(); j++) {
        if (output_shape[j] == 1) {
          xstrides.push_back(outputs[0].strides()[j]);
        } else {
          xstrides.push_back(0);
        }
      }
      for (int i = 0; i < x.ndim(); i++, j++) {
        if (x.shape(i) == 1) {
          if (output_shape[j] == 1) {
            xstrides.push_back(outputs[0].strides()[j]);
          } else {
            xstrides.push_back(0);
          }
        } else {
          xstrides.push_back(x.strides()[i]);
        }
      }
      initial_strides.push_back(std::move(xstrides));
    }
    std::tie(shape, strides) =
        collapse_contiguous_dims(output_shape, initial_strides, INT32_MAX);
  }

  bool large;
  if (contiguous) {
    size_t max_size = 0;
    for (auto& in : inputs) {
      max_size = std::max(max_size, in.data_size());
    }
    large = (max_size > UINT32_MAX);
  } else {
    size_t max_size = 0;
    for (auto& o : outputs) {
      max_size = std::max(max_size, o.size());
    }
    large = (max_size > UINT32_MAX);
  }
  // Whether to use large index.
  bool large = compiled_use_large_index(inputs, outputs, contiguous);

  // Get the kernel from the lib
  int ndim = shape.size();
  bool dynamic = ndim >= 8;
  auto kernel_name = kernel_lib_ + (contiguous ? "_contiguous" : "_strided_");
  int work_per_thread = 1;
  if (!contiguous) {
    if (dynamic) {
      kernel_name += "dynamic";
    } else {
      kernel_name += std::to_string(shape.size());
    }
    work_per_thread = ndim > 3 ? (large ? 4 : 2) : 1;
  } else {
    work_per_thread =
        get_work_per_thread(outputs[0].dtype(), outputs[0].data_size());
    if (work_per_thread > 1 && !large) {
      kernel_name += "_n";
    }
  }
  if (large) {
    kernel_name += "_large";
@@ -451,7 +399,7 @@ void Compiled::eval_gpu(
  int stride_idx = 1; // idx 0 is the output strides
  Strides in_strides;
  for (int i = 0; i < inputs.size(); i++) {
    if (constant_ids_.find(inputs_[i].id()) != constant_ids_.end()) {
    if (is_constant_(i)) {
      continue;
    }
    auto& x = inputs[i];
@@ -468,8 +416,7 @@ void Compiled::eval_gpu(
    compute_encoder.set_vector_bytes(in_strides, cnt++);
  }

  compiled_allocate_outputs(
      inputs, outputs, inputs_, constant_ids_, contiguous);
  compiled_allocate_outputs(inputs, outputs, is_constant_, contiguous);

  // Put the outputs in
  for (auto& x : outputs) {
@@ -478,7 +425,6 @@ void Compiled::eval_gpu(

  // Put the output shape and strides in
  if (!contiguous) {
    compute_encoder.set_vector_bytes(strides[0], cnt++);
    compute_encoder.set_vector_bytes(shape, cnt++);
  } else {
    auto size = outputs[0].data_size();
@@ -496,7 +442,6 @@ void Compiled::eval_gpu(

  // Launch the kernel
  if (contiguous) {
    int work_per_thread = get_work_per_thread(outputs[0].dtype());
    size_t nthreads = ceildiv(outputs[0].data_size(), work_per_thread);
    MTL::Size group_dims(
        std::min(nthreads, kernel->maxTotalThreadsPerThreadgroup()), 1, 1);
@@ -509,7 +454,6 @@ void Compiled::eval_gpu(
    size_t dim0 = ndim > 0 ? shape[ndim - 1] : 1;
    size_t dim1 = ndim > 1 ? shape[ndim - 2] : 1;
    size_t rest = outputs[0].size() / (dim0 * dim1);
    int work_per_thread = ndim > 3 ? (large ? 4 : 2) : 1;
    dim0 = (dim0 + work_per_thread - 1) / work_per_thread;
    NS::UInteger thread_group_size = kernel->maxTotalThreadsPerThreadgroup();
    int pow2;
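The kernel-name suffixes chosen at dispatch time must line up with the variants built into the library above. A sketch of the selection logic implied by this hunk, extracted for readability (this standalone function does not exist in the source):

```cpp
#include <string>

// Sketch: kernel_lib_, ndim, large and work_per_thread are as in the diff.
std::string pick_kernel_name(
    const std::string& kernel_lib_,
    bool contiguous,
    bool dynamic,
    int ndim,
    bool large,
    int work_per_thread) {
  std::string name = kernel_lib_ + (contiguous ? "_contiguous" : "_strided_");
  if (!contiguous) {
    name += dynamic ? "dynamic" : std::to_string(ndim);
  } else if (work_per_thread > 1 && !large) {
    name += "_n";
  }
  if (large) {
    name += "_large";
  }
  return name;
}
// e.g. "<lib>_contiguous_n", "<lib>_contiguous_large", "<lib>_strided_3_large".
```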
@@ -55,10 +55,10 @@ void copy_gpu_inplace(
  std::string kernel_name;
  switch (ctype) {
    case CopyType::Scalar:
      kernel_name = (large ? "s2" : "s");
      kernel_name = large ? "s2" : "s";
      break;
    case CopyType::Vector:
      kernel_name = (large ? "v2" : "v");
      kernel_name = large ? "v2" : "v";
      break;
    case CopyType::General:
      kernel_name = "g";
@@ -85,7 +85,10 @@ void copy_gpu_inplace(
      }
    }
  } else {
    work_per_thread = get_work_per_thread(in.dtype());
    work_per_thread = get_work_per_thread(out.dtype(), out.data_size());
    if (work_per_thread > 1) {
      kernel_name += "n";
    }
  }
  concatenate(kernel_name, "_copy", type_to_name(in), type_to_name(out));
  auto kernel = dynamic ? get_dynamic_copy_kernel(d, kernel_name, in, out)
@@ -170,9 +173,10 @@ void fill_gpu(const array& val, array& out, const Stream& s) {
  }
  out.set_data(allocator::malloc(out.nbytes()));
  bool large = out.data_size() > UINT32_MAX;
  int work_per_thread = get_work_per_thread(out.dtype(), out.data_size());
  auto& d = metal::device(s.device);
  std::string kernel_name = std::string(large ? "s2" : "s") + "_copy" +
      type_to_name(val) + type_to_name(out);
  std::string kernel_name = large ? "s2" : (work_per_thread > 1 ? "sn" : "s");
  concatenate(kernel_name, "_copy", type_to_name(val), type_to_name(out));
  auto kernel = get_copy_kernel(d, kernel_name, val, out);
  auto& compute_encoder = d.get_command_encoder(s.index);
  compute_encoder.set_compute_pipeline_state(kernel);
@@ -180,7 +184,6 @@ void fill_gpu(const array& val, array& out, const Stream& s) {
  compute_encoder.set_input_array(val, 0);
  compute_encoder.set_output_array(out, 1);

  int work_per_thread = get_work_per_thread(val.dtype());
  auto thread_group_size = kernel->maxTotalThreadsPerThreadgroup();
  size_t nthreads = ceildiv(out.data_size(), work_per_thread);
  if (thread_group_size > nthreads) {

@@ -41,7 +41,11 @@ MTL::ComputePipelineState* get_unary_kernel(
  std::string kernel_source = metal::utils();
  concatenate(kernel_source, metal::unary_ops(), metal::unary());
  kernel_source +=
      get_template_definition("v_" + lib_name, "unary_v", in_t, out_t, op);
      get_template_definition("v_" + lib_name, "unary_v", in_t, out_t, op, 1);
  if (get_work_per_thread(in_type) > 1) {
    kernel_source +=
        get_template_definition("vn_" + lib_name, "unary_v", in_t, out_t, op);
  }
  kernel_source +=
      get_template_definition("v2_" + lib_name, "unary_v2", in_t, out_t, op);
  kernel_source += get_template_definition(
@@ -59,11 +63,8 @@ void append_binary_kernels(
    Dtype out_type,
    const std::string op,
    std::string& kernel_source) {
  const std::array<std::pair<std::string, std::string>, 10> kernel_types = {{
  const std::array<std::pair<std::string, std::string>, 7> kernel_types = {{
      {"ss", "binary_ss"},
      {"vs", "binary_vs"},
      {"sv", "binary_sv"},
      {"vv", "binary_vv"},
      {"vs2", "binary_vs2"},
      {"sv2", "binary_sv2"},
      {"vv2", "binary_vv2"},
@@ -78,6 +79,22 @@ void append_binary_kernels(
    kernel_source +=
        get_template_definition(name + "_" + lib_name, func, in_t, out_t, op);
  }
  kernel_source += get_template_definition(
      "vs_" + lib_name, "binary_vs", in_t, out_t, op, 1);
  kernel_source += get_template_definition(
      "sv_" + lib_name, "binary_sv", in_t, out_t, op, 1);
  kernel_source += get_template_definition(
      "vv_" + lib_name, "binary_vv", in_t, out_t, op, 1);

  if (get_work_per_thread(in_type) > 1) {
    kernel_source += get_template_definition(
        "vsn_" + lib_name, "binary_vs", in_t, out_t, op);
    kernel_source += get_template_definition(
        "svn_" + lib_name, "binary_sv", in_t, out_t, op);
    kernel_source += get_template_definition(
        "vvn_" + lib_name, "binary_vv", in_t, out_t, op);
  }

  kernel_source += get_template_definition(
      "g1_" + lib_name, "binary_g_nd1", in_t, out_t, op, "int");
  kernel_source += get_template_definition(
@@ -133,8 +150,7 @@ MTL::ComputePipelineState* get_ternary_kernel(
  auto t_str = get_type_string(type);
  std::string kernel_source = metal::utils();
  concatenate(kernel_source, metal::ternary_ops(), metal::ternary());
  const std::array<std::pair<std::string, std::string>, 5> kernel_types = {{
      {"v", "ternary_v"},
  const std::array<std::pair<std::string, std::string>, 4> kernel_types = {{
      {"v2", "ternary_v2"},
      {"g1large", "ternary_g_nd1"},
      {"g2large", "ternary_g_nd2"},
@@ -144,6 +160,13 @@ MTL::ComputePipelineState* get_ternary_kernel(
    kernel_source +=
        get_template_definition(name + "_" + lib_name, func, t_str, op);
  }
  if (get_work_per_thread(type) > 1) {
    kernel_source +=
        get_template_definition("vn_" + lib_name, "ternary_v", t_str, op);
  }

  kernel_source +=
      get_template_definition("v_" + lib_name, "ternary_v", t_str, op, 1);
  kernel_source += get_template_definition(
      "g1_" + lib_name, "ternary_g_nd1", t_str, op, "int");
  kernel_source += get_template_definition(
@@ -170,15 +193,22 @@ MTL::ComputePipelineState* get_copy_kernel(
  kernel_source += metal::copy();
  auto in_type = get_type_string(in.dtype());
  auto out_type = get_type_string(out.dtype());
  kernel_source +=
      get_template_definition("s_" + lib_name, "copy_s", in_type, out_type);
  kernel_source += get_template_definition(
      "s_" + lib_name, "copy_s", in_type, out_type, 1);
  kernel_source +=
      get_template_definition("s2_" + lib_name, "copy_s2", in_type, out_type);
  kernel_source +=
      get_template_definition("v_" + lib_name, "copy_v", in_type, out_type);
  kernel_source += get_template_definition(
      "v_" + lib_name, "copy_v", in_type, out_type, 1);
  kernel_source +=
      get_template_definition("v2_" + lib_name, "copy_v2", in_type, out_type);

  if (get_work_per_thread(out.dtype()) > 1) {
    kernel_source += get_template_definition(
        "sn_" + lib_name, "copy_s", in_type, out_type);
    kernel_source += get_template_definition(
        "vn_" + lib_name, "copy_v", in_type, out_type);
  }

  kernel_source += get_template_definition(
      "g1_" + lib_name, "copy_g_nd1", in_type, out_type, "int");
  kernel_source += get_template_definition(
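For one input/output type pair, the JIT path now emits the following copy-kernel instantiations; a small sketch of the resulting name set, mirroring `get_copy_kernel` above (`lib` stands for the library name):

```cpp
#include <string>
#include <vector>

// "s_"/"v_" are now the explicit N = 1 instantiations; "sn_"/"vn_" use the
// default WorkPerThread<U>::n and are only generated when that n > 1;
// "s2_"/"v2_" are the 64-bit (large) indexing variants.
std::vector<std::string> copy_kernel_names(
    const std::string& lib, int work_per_thread) {
  std::vector<std::string> names = {
      "s_" + lib, "s2_" + lib, "v_" + lib, "v2_" + lib};
  if (work_per_thread > 1) {
    names.push_back("sn_" + lib);
    names.push_back("vn_" + lib);
  }
  return names;
}
```

The runtime side (copy.cpp above) only asks for "s", "sn" or "s2" style names, so every name it can request has a matching instantiation.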
@@ -17,8 +17,14 @@ template <typename T, typename U, typename Op, int N = WorkPerThread<T>::n>
|
||||
constant uint& size,
|
||||
uint index [[thread_position_in_grid]]) {
|
||||
index *= N;
|
||||
for (int i = 0; i < N && (index + i) < size; ++i) {
|
||||
c[index + i] = Op()(a[0], b[index + i]);
|
||||
if (N > 1 && index + N > size) {
|
||||
for (int i = 0; index + i < size; ++i) {
|
||||
c[index + i] = Op()(a[0], b[index + i]);
|
||||
}
|
||||
} else {
|
||||
for (int i = 0; i < N; ++i) {
|
||||
c[index + i] = Op()(a[0], b[index + i]);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -30,8 +36,14 @@ template <typename T, typename U, typename Op, int N = WorkPerThread<T>::n>
|
||||
constant uint& size,
|
||||
uint index [[thread_position_in_grid]]) {
|
||||
index *= N;
|
||||
for (int i = 0; i < N && (index + i) < size; ++i) {
|
||||
c[index + i] = Op()(a[index + i], b[0]);
|
||||
if (N > 1 && index + N > size) {
|
||||
for (int i = 0; index + i < size; ++i) {
|
||||
c[index + i] = Op()(a[index + i], b[0]);
|
||||
}
|
||||
} else {
|
||||
for (int i = 0; i < N; ++i) {
|
||||
c[index + i] = Op()(a[index + i], b[0]);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -43,8 +55,14 @@ template <typename T, typename U, typename Op, int N = WorkPerThread<T>::n>
|
||||
constant uint& size,
|
||||
uint index [[thread_position_in_grid]]) {
|
||||
index *= N;
|
||||
for (int i = 0; i < N && (index + i) < size; ++i) {
|
||||
c[index + i] = Op()(a[index + i], b[index + i]);
|
||||
if (N > 1 && index + N > size) {
|
||||
for (int i = 0; index + i < size; ++i) {
|
||||
c[index + i] = Op()(a[index + i], b[index + i]);
|
||||
}
|
||||
} else {
|
||||
for (int i = 0; i < N; ++i) {
|
||||
c[index + i] = Op()(a[index + i], b[index + i]);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -57,8 +75,14 @@ template <typename T, typename U, typename Op, int N = WorkPerThread<T>::n>
|
||||
uint2 index [[thread_position_in_grid]],
|
||||
uint2 grid_dim [[threads_per_grid]]) {
|
||||
int64_t offset = N * (index.x + grid_dim.x * int64_t(index.y));
|
||||
for (int i = 0; i < N && (offset + i) < size; ++i) {
|
||||
c[offset + i] = Op()(a[0], b[offset + i]);
|
||||
if (N > 1 && offset + N > size) {
|
||||
for (int i = 0; offset + i < size; ++i) {
|
||||
c[offset + i] = Op()(a[0], b[offset + i]);
|
||||
}
|
||||
} else {
|
||||
for (int i = 0; i < N; ++i) {
|
||||
c[offset + i] = Op()(a[0], b[offset + i]);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -71,8 +95,14 @@ template <typename T, typename U, typename Op, int N = WorkPerThread<T>::n>
|
||||
uint2 index [[thread_position_in_grid]],
|
||||
uint2 grid_dim [[threads_per_grid]]) {
|
||||
int64_t offset = N * (index.x + grid_dim.x * int64_t(index.y));
|
||||
for (int i = 0; i < N && (offset + i) < size; ++i) {
|
||||
c[offset + i] = Op()(a[offset + i], b[0]);
|
||||
if (N > 1 && offset + N > size) {
|
||||
for (int i = 0; offset + i < size; ++i) {
|
||||
c[offset + i] = Op()(a[offset + i], b[0]);
|
||||
}
|
||||
} else {
|
||||
for (int i = 0; i < N; ++i) {
|
||||
c[offset + i] = Op()(a[offset + i], b[0]);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -85,8 +115,14 @@ template <typename T, typename U, typename Op, int N = WorkPerThread<T>::n>
|
||||
uint2 index [[thread_position_in_grid]],
|
||||
uint2 grid_dim [[threads_per_grid]]) {
|
||||
int64_t offset = N * (index.x + grid_dim.x * int64_t(index.y));
|
||||
for (int i = 0; i < N && (offset + i) < size; ++i) {
|
||||
c[offset + i] = Op()(a[offset + i], b[offset + i]);
|
||||
if (N > 1 && offset + N > size) {
|
||||
for (int i = 0; offset + i < size; ++i) {
|
||||
c[offset + i] = Op()(a[offset + i], b[offset + i]);
|
||||
}
|
||||
} else {
|
||||
for (int i = 0; i < N; ++i) {
|
||||
c[offset + i] = Op()(a[offset + i], b[offset + i]);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -9,11 +9,16 @@
|
||||
#include "mlx/backend/metal/kernels/binary_ops.h"
|
||||
#include "mlx/backend/metal/kernels/binary.h"
|
||||
|
||||
#define instantiate_binary_all(op, tname, itype, otype) \
|
||||
#define instantiate_binary_work_per_thread(op, tname, itype, otype) \
|
||||
instantiate_kernel("svn_" #op #tname, binary_sv, itype, otype, op) \
|
||||
instantiate_kernel("vsn_" #op #tname, binary_vs, itype, otype, op) \
|
||||
instantiate_kernel("vvn_" #op #tname, binary_vv, itype, otype, op) \
|
||||
|
||||
#define instantiate_binary_base(op, tname, itype, otype) \
|
||||
instantiate_kernel("ss_" #op #tname, binary_ss, itype, otype, op) \
|
||||
instantiate_kernel("sv_" #op #tname, binary_sv, itype, otype, op) \
|
||||
instantiate_kernel("vs_" #op #tname, binary_vs, itype, otype, op) \
|
||||
instantiate_kernel("vv_" #op #tname, binary_vv, itype, otype, op) \
|
||||
instantiate_kernel("sv_" #op #tname, binary_sv, itype, otype, op, 1) \
|
||||
instantiate_kernel("vs_" #op #tname, binary_vs, itype, otype, op, 1) \
|
||||
instantiate_kernel("vv_" #op #tname, binary_vv, itype, otype, op, 1) \
|
||||
instantiate_kernel("sv2_" #op #tname, binary_sv2, itype, otype, op) \
|
||||
instantiate_kernel("vs2_" #op #tname, binary_vs2, itype, otype, op) \
|
||||
instantiate_kernel("vv2_" #op #tname, binary_vv2, itype, otype, op) \
|
||||
@@ -26,15 +31,19 @@
|
||||
instantiate_kernel("g3_" #op #tname, binary_g_nd3, itype, otype, op, int) \
|
||||
instantiate_kernel("g3large_" #op #tname, binary_g_nd3, itype, otype, op)
|
||||
|
||||
#define instantiate_binary_integer(op) \
|
||||
instantiate_binary_all(op, uint8, uint8_t, uint8_t) \
|
||||
instantiate_binary_all(op, uint16, uint16_t, uint16_t) \
|
||||
instantiate_binary_all(op, uint32, uint32_t, uint32_t) \
|
||||
instantiate_binary_all(op, uint64, uint64_t, uint64_t) \
|
||||
instantiate_binary_all(op, int8, int8_t, int8_t) \
|
||||
instantiate_binary_all(op, int16, int16_t, int16_t) \
|
||||
instantiate_binary_all(op, int32, int32_t, int32_t) \
|
||||
instantiate_binary_all(op, int64, int64_t, int64_t)
|
||||
#define instantiate_binary_all(op, tname, itype, otype) \
|
||||
instantiate_binary_base(op, tname, itype, otype) \
|
||||
instantiate_binary_work_per_thread(op, tname, itype, otype)
|
||||
|
||||
#define instantiate_binary_integer(op) \
|
||||
instantiate_binary_all(op, uint8, uint8_t, uint8_t) \
|
||||
instantiate_binary_all(op, uint16, uint16_t, uint16_t) \
|
||||
instantiate_binary_all(op, uint32, uint32_t, uint32_t) \
|
||||
instantiate_binary_base(op, uint64, uint64_t, uint64_t) \
|
||||
instantiate_binary_all(op, int8, int8_t, int8_t) \
|
||||
instantiate_binary_all(op, int16, int16_t, int16_t) \
|
||||
instantiate_binary_all(op, int32, int32_t, int32_t) \
|
||||
instantiate_binary_base(op, int64, int64_t, int64_t)
|
||||
|
||||
#define instantiate_binary_float(op) \
|
||||
instantiate_binary_all(op, float16, half, half) \
|
||||
@@ -44,7 +53,7 @@
|
||||
#define instantiate_binary_types(op) \
|
||||
instantiate_binary_all(op, bool_, bool, bool) \
|
||||
instantiate_binary_integer(op) \
|
||||
instantiate_binary_all(op, complex64, complex64_t, complex64_t) \
|
||||
instantiate_binary_base(op, complex64, complex64_t, complex64_t)\
|
||||
instantiate_binary_float(op)
|
||||
|
||||
#define instantiate_binary_types_bool(op) \
|
||||
@@ -52,15 +61,15 @@
|
||||
instantiate_binary_all(op, uint8, uint8_t, bool) \
|
||||
instantiate_binary_all(op, uint16, uint16_t, bool) \
|
||||
instantiate_binary_all(op, uint32, uint32_t, bool) \
|
||||
instantiate_binary_all(op, uint64, uint64_t, bool) \
|
||||
instantiate_binary_base(op, uint64, uint64_t, bool) \
|
||||
instantiate_binary_all(op, int8, int8_t, bool) \
|
||||
instantiate_binary_all(op, int16, int16_t, bool) \
|
||||
instantiate_binary_all(op, int32, int32_t, bool) \
|
||||
instantiate_binary_all(op, int64, int64_t, bool) \
|
||||
instantiate_binary_base(op, int64, int64_t, bool) \
|
||||
instantiate_binary_all(op, float16, half, bool) \
|
||||
instantiate_binary_all(op, float32, float, bool) \
|
||||
instantiate_binary_all(op, bfloat16, bfloat16_t, bool) \
|
||||
instantiate_binary_all(op, complex64, complex64_t, bool)
|
||||
instantiate_binary_base(op, complex64, complex64_t, bool)
|
||||
|
||||
instantiate_binary_types(Add)
|
||||
instantiate_binary_types(Divide)
|
||||
@@ -71,7 +80,7 @@ instantiate_binary_types_bool(Less)
|
||||
instantiate_binary_types_bool(LessEqual)
|
||||
instantiate_binary_types_bool(NotEqual)
|
||||
instantiate_binary_float(LogAddExp)
|
||||
instantiate_binary_all(LogAddExp, complex64, complex64_t, complex64_t)
|
||||
instantiate_binary_base(LogAddExp, complex64, complex64_t, complex64_t)
|
||||
instantiate_binary_types(Maximum)
|
||||
instantiate_binary_types(Minimum)
|
||||
instantiate_binary_types(Multiply)
|
||||
@@ -84,7 +93,7 @@ instantiate_binary_float(ArcTan2)
|
||||
instantiate_binary_all(NaNEqual, float16, half, bool)
|
||||
instantiate_binary_all(NaNEqual, float32, float, bool)
|
||||
instantiate_binary_all(NaNEqual, bfloat16, bfloat16_t, bool)
|
||||
instantiate_binary_all(NaNEqual, complex64, complex64_t, bool)
|
||||
instantiate_binary_base(NaNEqual, complex64, complex64_t, bool)
|
||||
|
||||
instantiate_binary_all(LogicalOr, bool_, bool, bool)
|
||||
instantiate_binary_all(LogicalAnd, bool_, bool, bool)
|
||||
|
||||
@@ -21,10 +21,18 @@ template <typename T, typename U, typename Op, int N = WorkPerThread<T>::n>
|
||||
constant uint& size,
|
||||
uint index [[thread_position_in_grid]]) {
|
||||
index *= N;
|
||||
for (int i = 0; i < N && (index + i) < size; ++i) {
|
||||
auto out = Op()(a[0], b[index + i]);
|
||||
c[index + i] = out[0];
|
||||
d[index + i] = out[1];
|
||||
if (N > 1 && index + N > size) {
|
||||
for (int i = 0; index + i < size; ++i) {
|
||||
auto out = Op()(a[0], b[index + i]);
|
||||
c[index + i] = out[0];
|
||||
d[index + i] = out[1];
|
||||
}
|
||||
} else {
|
||||
for (int i = 0; i < N; ++i) {
|
||||
auto out = Op()(a[0], b[index + i]);
|
||||
c[index + i] = out[0];
|
||||
d[index + i] = out[1];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -37,10 +45,18 @@ template <typename T, typename U, typename Op, int N = WorkPerThread<T>::n>
|
||||
constant uint& size,
|
||||
uint index [[thread_position_in_grid]]) {
|
||||
index *= N;
|
||||
for (int i = 0; i < N && (index + i) < size; ++i) {
|
||||
auto out = Op()(a[index + i], b[0]);
|
||||
c[index + i] = out[0];
|
||||
d[index + i] = out[1];
|
||||
if (N > 1 && index + N > size) {
|
||||
for (int i = 0; index + i < size; ++i) {
|
||||
auto out = Op()(a[index + i], b[0]);
|
||||
c[index + i] = out[0];
|
||||
d[index + i] = out[1];
|
||||
}
|
||||
} else {
|
||||
for (int i = 0; i < N; ++i) {
|
||||
auto out = Op()(a[index + i], b[0]);
|
||||
c[index + i] = out[0];
|
||||
d[index + i] = out[1];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -53,10 +69,18 @@ template <typename T, typename U, typename Op, int N = WorkPerThread<T>::n>
|
||||
constant uint& size,
|
||||
uint index [[thread_position_in_grid]]) {
|
||||
index *= N;
|
||||
for (int i = 0; i < N && (index + i) < size; ++i) {
|
||||
auto out = Op()(a[index + i], b[index + i]);
|
||||
c[index + i] = out[0];
|
||||
d[index + i] = out[1];
|
||||
if (N > 1 && index + N > size) {
|
||||
for (int i = 0; index + i < size; ++i) {
|
||||
auto out = Op()(a[index + i], b[index + i]);
|
||||
c[index + i] = out[0];
|
||||
d[index + i] = out[1];
|
||||
}
|
||||
} else {
|
||||
for (int i = 0; i < N; ++i) {
|
||||
auto out = Op()(a[index + i], b[index + i]);
|
||||
c[index + i] = out[0];
|
||||
d[index + i] = out[1];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -69,11 +93,19 @@ template <typename T, typename U, typename Op, int N = WorkPerThread<T>::n>
|
||||
constant int64_t& size,
|
||||
uint2 index [[thread_position_in_grid]],
|
||||
uint2 grid_dim [[threads_per_grid]]) {
|
||||
auto offset = N * (index.x + grid_dim.x * int64_t(index.y));
|
||||
for (int i = 0; i < N && (offset + i) < size; ++i) {
|
||||
auto out = Op()(a[0], b[offset + i]);
|
||||
c[offset + i] = out[0];
|
||||
d[offset + i] = out[1];
|
||||
int64_t offset = N * (index.x + grid_dim.x * int64_t(index.y));
|
||||
if (N > 1 && offset + N > size) {
|
||||
for (int i = 0; offset + i < size; ++i) {
|
||||
auto out = Op()(a[0], b[offset + i]);
|
||||
c[offset + i] = out[0];
|
||||
d[offset + i] = out[1];
|
||||
}
|
||||
} else {
|
||||
for (int i = 0; i < N; ++i) {
|
||||
auto out = Op()(a[0], b[offset + i]);
|
||||
c[offset + i] = out[0];
|
||||
d[offset + i] = out[1];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -86,11 +118,19 @@ template <typename T, typename U, typename Op, int N = WorkPerThread<T>::n>
|
||||
constant int64_t& size,
|
||||
uint2 index [[thread_position_in_grid]],
|
||||
uint2 grid_dim [[threads_per_grid]]) {
|
||||
auto offset = N * (index.x + grid_dim.x * int64_t(index.y));
|
||||
for (int i = 0; i < N && (offset + i) < size; ++i) {
|
||||
auto out = Op()(a[offset + i], b[0]);
|
||||
c[offset + i] = out[0];
|
||||
d[offset + i] = out[1];
|
||||
int64_t offset = N * (index.x + grid_dim.x * int64_t(index.y));
|
||||
if (N > 1 && offset + N > size) {
|
||||
for (int i = 0; offset + i < size; ++i) {
|
||||
auto out = Op()(a[offset + i], b[0]);
|
||||
c[offset + i] = out[0];
|
||||
d[offset + i] = out[1];
|
||||
}
|
||||
} else {
|
||||
for (int i = 0; i < N; ++i) {
|
||||
auto out = Op()(a[offset + i], b[0]);
|
||||
c[offset + i] = out[0];
|
||||
d[offset + i] = out[1];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -103,11 +143,19 @@ template <typename T, typename U, typename Op, int N = WorkPerThread<T>::n>
|
||||
constant int64_t& size,
|
||||
uint2 index [[thread_position_in_grid]],
|
||||
uint2 grid_dim [[threads_per_grid]]) {
|
||||
auto offset = N * (index.x + grid_dim.x * int64_t(index.y));
|
||||
for (int i = 0; i < N && (offset + i) < size; ++i) {
|
||||
auto out = Op()(a[offset + i], b[offset + i]);
|
||||
c[offset + i] = out[0];
|
||||
d[offset + i] = out[1];
|
||||
int64_t offset = N * (index.x + grid_dim.x * int64_t(index.y));
|
||||
if (N > 1 && offset + N > size) {
|
||||
for (int i = 0; offset + i < size; ++i) {
|
||||
auto out = Op()(a[offset + i], b[offset + i]);
|
||||
c[offset + i] = out[0];
|
||||
d[offset + i] = out[1];
|
||||
}
|
||||
} else {
|
||||
for (int i = 0; i < N; ++i) {
|
||||
auto out = Op()(a[offset + i], b[offset + i]);
|
||||
c[offset + i] = out[0];
|
||||
d[offset + i] = out[1];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -7,11 +7,16 @@
#include "mlx/backend/metal/kernels/binary_ops.h"
#include "mlx/backend/metal/kernels/binary_two.h"

#define instantiate_binary_all(op, tname, itype, otype) \
#define instantiate_binary_work_per_thread(op, tname, itype, otype) \
instantiate_kernel("svn_" #op #tname, binary_sv, itype, otype, op) \
instantiate_kernel("vsn_" #op #tname, binary_vs, itype, otype, op) \
instantiate_kernel("vvn_" #op #tname, binary_vv, itype, otype, op)

#define instantiate_binary_base(op, tname, itype, otype) \
instantiate_kernel("ss_" #op #tname, binary_ss, itype, otype, op) \
instantiate_kernel("sv_" #op #tname, binary_sv, itype, otype, op) \
instantiate_kernel("vs_" #op #tname, binary_vs, itype, otype, op) \
instantiate_kernel("vv_" #op #tname, binary_vv, itype, otype, op) \
instantiate_kernel("sv_" #op #tname, binary_sv, itype, otype, op, 1) \
instantiate_kernel("vs_" #op #tname, binary_vs, itype, otype, op, 1) \
instantiate_kernel("vv_" #op #tname, binary_vv, itype, otype, op, 1) \
instantiate_kernel("sv2_" #op #tname, binary_sv2, itype, otype, op) \
instantiate_kernel("vs2_" #op #tname, binary_vs2, itype, otype, op) \
instantiate_kernel("vv2_" #op #tname, binary_vv2, itype, otype, op) \
@@ -24,22 +29,26 @@
instantiate_kernel("g2large_" #op #tname, binary_g_nd2, itype, otype, op) \
instantiate_kernel("g3large_" #op #tname, binary_g_nd3, itype, otype, op)

#define instantiate_binary_all(op, tname, itype, otype) \
instantiate_binary_base(op, tname, itype, otype) \
instantiate_binary_work_per_thread(op, tname, itype, otype)

#define instantiate_binary_float(op) \
instantiate_binary_all(op, float16, half, half) \
instantiate_binary_all(op, float32, float, float) \
instantiate_binary_all(op, bfloat16, bfloat16_t, bfloat16_t)

#define instantiate_binary_types(op) \
instantiate_binary_all(op, bool_, bool, bool) \
instantiate_binary_all(op, uint8, uint8_t, uint8_t) \
instantiate_binary_all(op, uint16, uint16_t, uint16_t) \
instantiate_binary_all(op, uint32, uint32_t, uint32_t) \
instantiate_binary_all(op, uint64, uint64_t, uint64_t) \
instantiate_binary_all(op, int8, int8_t, int8_t) \
instantiate_binary_all(op, int16, int16_t, int16_t) \
instantiate_binary_all(op, int32, int32_t, int32_t) \
instantiate_binary_all(op, int64, int64_t, int64_t) \
instantiate_binary_all(op, complex64, complex64_t, complex64_t) \
#define instantiate_binary_types(op) \
instantiate_binary_all(op, bool_, bool, bool) \
instantiate_binary_all(op, uint8, uint8_t, uint8_t) \
instantiate_binary_all(op, uint16, uint16_t, uint16_t) \
instantiate_binary_all(op, uint32, uint32_t, uint32_t) \
instantiate_binary_base(op, uint64, uint64_t, uint64_t) \
instantiate_binary_all(op, int8, int8_t, int8_t) \
instantiate_binary_all(op, int16, int16_t, int16_t) \
instantiate_binary_all(op, int32, int32_t, int32_t) \
instantiate_binary_base(op, int64, int64_t, int64_t) \
instantiate_binary_base(op, complex64, complex64_t, complex64_t) \
instantiate_binary_float(op)

instantiate_binary_types(DivMod) // clang-format on

@@ -1,52 +1,76 @@
// Copyright © 2024 Apple Inc.

template <typename T, typename U, int N = WorkPerThread<T>::n>
template <typename T, typename U, int N = WorkPerThread<U>::n>
[[kernel]] void copy_s(
device const T* src [[buffer(0)]],
device U* dst [[buffer(1)]],
constant uint& size,
uint index [[thread_position_in_grid]]) {
index *= N;
for (int i = 0; i < N && (index + i) < size; ++i) {
dst[index + i] = static_cast<U>(src[0]);
if (N > 1 && index + N > size) {
for (int i = 0; index + i < size; ++i) {
dst[index + i] = static_cast<U>(src[0]);
}
} else {
for (int i = 0; i < N; ++i) {
dst[index + i] = static_cast<U>(src[0]);
}
}
}

template <typename T, typename U, int N = WorkPerThread<T>::n>
template <typename T, typename U, int N = WorkPerThread<U>::n>
[[kernel]] void copy_v(
device const T* src [[buffer(0)]],
device U* dst [[buffer(1)]],
constant uint& size,
uint index [[thread_position_in_grid]]) {
index *= N;
for (int i = 0; i < N && (index + i) < size; ++i) {
dst[index + i] = static_cast<U>(src[index + i]);
if (N > 1 && index + N > size) {
for (int i = 0; index + i < size; ++i) {
dst[index + i] = static_cast<U>(src[index + i]);
}
} else {
for (int i = 0; i < N; ++i) {
dst[index + i] = static_cast<U>(src[index + i]);
}
}
}

template <typename T, typename U, int N = WorkPerThread<T>::n>
template <typename T, typename U, int N = WorkPerThread<U>::n>
[[kernel]] void copy_s2(
device const T* src [[buffer(0)]],
device U* dst [[buffer(1)]],
constant int64_t& size,
uint2 index [[thread_position_in_grid]],
uint2 grid_dim [[threads_per_grid]]) {
auto offset = N * (index.x + grid_dim.x * int64_t(index.y));
for (int i = 0; i < N && (offset + i) < size; ++i) {
dst[offset + i] = static_cast<U>(src[0]);
int64_t offset = N * (index.x + grid_dim.x * int64_t(index.y));
if (N > 1 && offset + N > size) {
for (int i = 0; offset + i < size; ++i) {
dst[offset + i] = static_cast<U>(src[0]);
}
} else {
for (int i = 0; i < N; ++i) {
dst[offset + i] = static_cast<U>(src[0]);
}
}
}

template <typename T, typename U, int N = WorkPerThread<T>::n>
template <typename T, typename U, int N = WorkPerThread<U>::n>
[[kernel]] void copy_v2(
device const T* src [[buffer(0)]],
device U* dst [[buffer(1)]],
constant int64_t& size,
uint2 index [[thread_position_in_grid]],
uint2 grid_dim [[threads_per_grid]]) {
auto offset = N * (index.x + grid_dim.x * int64_t(index.y));
for (int i = 0; i < N && (offset + i) < size; ++i) {
dst[offset + i] = static_cast<U>(src[offset + i]);
int64_t offset = N * (index.x + grid_dim.x * int64_t(index.y));
if (N > 1 && offset + N > size) {
for (int i = 0; offset + i < size; ++i) {
dst[offset + i] = static_cast<U>(src[offset + i]);
}
} else {
for (int i = 0; i < N; ++i) {
dst[offset + i] = static_cast<U>(src[offset + i]);
}
}
}

@@ -4,9 +4,13 @@
#include "mlx/backend/metal/kernels/utils.h"
#include "mlx/backend/metal/kernels/copy.h"

#define instantiate_copy_all(tname, itype, otype) \
instantiate_kernel("s_copy" #tname, copy_s, itype, otype) \
instantiate_kernel("v_copy" #tname, copy_v, itype, otype) \
#define instantiate_copy_work_per_thread(tname, itype, otype) \
instantiate_kernel("sn_copy" #tname, copy_s, itype, otype) \
instantiate_kernel("vn_copy" #tname, copy_v, itype, otype)

#define instantiate_copy_base(tname, itype, otype) \
instantiate_kernel("s_copy" #tname, copy_s, itype, otype, 1) \
instantiate_kernel("v_copy" #tname, copy_v, itype, otype, 1) \
instantiate_kernel("s2_copy" #tname, copy_s2, itype, otype) \
instantiate_kernel("v2_copy" #tname, copy_v2, itype, otype) \
instantiate_kernel("g1_copy" #tname, copy_g_nd1, itype, otype, int) \
@@ -18,6 +22,10 @@
instantiate_kernel("g3large_copy" #tname, copy_g_nd3, itype, otype) \
instantiate_kernel("gn4large_copy" #tname, copy_g, itype, otype, 4)

#define instantiate_copy_all(tname, itype, otype) \
instantiate_copy_base(tname, itype, otype) \
instantiate_copy_work_per_thread(tname, itype, otype)

#define instantiate_copy_same(tname, type) \
instantiate_kernel("gg1_copy" #tname, copy_gg_nd1, type, type, int) \
instantiate_kernel("gg2_copy" #tname, copy_gg_nd2, type, type, int) \
@@ -42,15 +50,15 @@
instantiate_copy_all(itname ##uint8, itype, uint8_t) \
instantiate_copy_all(itname ##uint16, itype, uint16_t) \
instantiate_copy_all(itname ##uint32, itype, uint32_t) \
instantiate_copy_all(itname ##uint64, itype, uint64_t) \
instantiate_copy_base(itname ##uint64, itype, uint64_t) \
instantiate_copy_all(itname ##int8, itype, int8_t) \
instantiate_copy_all(itname ##int16, itype, int16_t) \
instantiate_copy_all(itname ##int32, itype, int32_t) \
instantiate_copy_all(itname ##int64, itype, int64_t) \
instantiate_copy_base(itname ##int64, itype, int64_t) \
instantiate_copy_all(itname ##float16, itype, half) \
instantiate_copy_all(itname ##float32, itype, float) \
instantiate_copy_all(itname ##bfloat16, itype, bfloat16_t) \
instantiate_copy_all(itname ##complex64, itype, complex64_t)
instantiate_copy_base(itname ##complex64, itype, complex64_t)

instantiate_copy_itype(bool_, bool)
instantiate_copy_itype(uint8, uint8_t)

@@ -9,8 +9,14 @@ template <typename T, typename Op, int N = WorkPerThread<T>::n>
constant uint& size,
uint index [[thread_position_in_grid]]) {
index *= N;
for (int i = 0; i < N && (index + i) < size; ++i) {
d[index + i] = Op()(a[index + i], b[index + i], c[index + i]);
if (N > 1 && index + N > size) {
for (int i = 0; index + i < size; ++i) {
d[index + i] = Op()(a[index + i], b[index + i], c[index + i]);
}
} else {
for (int i = 0; i < N; ++i) {
d[index + i] = Op()(a[index + i], b[index + i], c[index + i]);
}
}
}

@@ -23,9 +29,15 @@ template <typename T, typename Op, int N = WorkPerThread<T>::n>
constant int64_t& size,
uint2 index [[thread_position_in_grid]],
uint2 grid_dim [[threads_per_grid]]) {
auto offset = N * (index.x + grid_dim.x * int64_t(index.y));
for (int i = 0; i < N && (offset + i) < size; ++i) {
d[offset + i] = Op()(a[offset + i], b[offset + i], c[offset + i]);
int64_t offset = N * (index.x + grid_dim.x * int64_t(index.y));
if (N > 1 && offset + N > size) {
for (int i = 0; offset + i < size; ++i) {
d[offset + i] = Op()(a[offset + i], b[offset + i], c[offset + i]);
}
} else {
for (int i = 0; i < N; ++i) {
d[offset + i] = Op()(a[offset + i], b[offset + i], c[offset + i]);
}
}
}

@@ -8,8 +8,8 @@
#include "mlx/backend/metal/kernels/ternary_ops.h"
#include "mlx/backend/metal/kernels/ternary.h"

#define instantiate_ternary_all(op, tname, type) \
instantiate_kernel("v_" #op #tname, ternary_v, type, op) \
#define instantiate_ternary_base(op, tname, type) \
instantiate_kernel("v_" #op #tname, ternary_v, type, op, 1) \
instantiate_kernel("v2_" #op #tname, ternary_v2, type, op) \
instantiate_kernel("gn2_" #op #tname, ternary_g, type, op, 2, int) \
instantiate_kernel("g1_" #op #tname, ternary_g_nd1, type, op, int) \
@@ -20,19 +20,23 @@
instantiate_kernel("g3large_" #op #tname, ternary_g_nd3, type, op) \
instantiate_kernel("gn4large_" #op #tname, ternary_g, type, op, 4) \

#define instantiate_ternary_all(op, tname, type) \
instantiate_kernel("vn_" #op #tname, ternary_v, type, op) \
instantiate_ternary_base(op, tname, type)

#define instantiate_ternary_types(op) \
instantiate_ternary_all(op, bool_, bool) \
instantiate_ternary_all(op, uint8, uint8_t) \
instantiate_ternary_all(op, uint16, uint16_t) \
instantiate_ternary_all(op, uint32, uint32_t) \
instantiate_ternary_all(op, uint64, uint64_t) \
instantiate_ternary_base(op, uint64, uint64_t) \
instantiate_ternary_all(op, int8, int8_t) \
instantiate_ternary_all(op, int16, int16_t) \
instantiate_ternary_all(op, int32, int32_t) \
instantiate_ternary_all(op, int64, int64_t) \
instantiate_ternary_base(op, int64, int64_t) \
instantiate_ternary_all(op, float16, half) \
instantiate_ternary_all(op, float32, float) \
instantiate_ternary_all(op, bfloat16, bfloat16_t) \
instantiate_ternary_all(op, complex64, complex64_t) // clang-format on
instantiate_ternary_base(op, complex64, complex64_t) // clang-format on

instantiate_ternary_types(Select)

@@ -7,8 +7,14 @@ template <typename T, typename U, typename Op, int N = WorkPerThread<T>::n>
constant uint& size,
uint index [[thread_position_in_grid]]) {
index *= N;
for (int i = 0; i < N && (index + i) < size; ++i) {
out[index + i] = Op()(in[index + i]);
if (N > 1 && index + N > size) {
for (int i = 0; index + i < size; ++i) {
out[index + i] = Op()(in[index + i]);
}
} else {
for (int i = 0; i < N; ++i) {
out[index + i] = Op()(in[index + i]);
}
}
}

@@ -19,9 +25,15 @@ template <typename T, typename U, typename Op, int N = WorkPerThread<T>::n>
constant int64_t& size,
uint2 index [[thread_position_in_grid]],
uint2 grid_dim [[threads_per_grid]]) {
auto offset = N * (index.x + grid_dim.x * int64_t(index.y));
for (int i = 0; i < N && (offset + i) < size; ++i) {
out[offset + i] = Op()(in[offset + i]);
int64_t offset = N * (index.x + grid_dim.x * int64_t(index.y));
if (N > 1 && offset + N > size) {
for (int i = 0; offset + i < size; ++i) {
out[offset + i] = Op()(in[offset + i]);
}
} else {
for (int i = 0; i < N; ++i) {
out[offset + i] = Op()(in[offset + i]);
}
}
}

@@ -5,31 +5,41 @@
#include "mlx/backend/metal/kernels/unary_ops.h"
#include "mlx/backend/metal/kernels/unary.h"

#define instantiate_unary_all(op, in_tname, out_tname, in_type, out_type) \
instantiate_kernel("v_" #op #in_tname #out_tname, unary_v, in_type, out_type, op) \
instantiate_kernel("v2_" #op #in_tname #out_tname, unary_v2, in_type, out_type, op) \
instantiate_kernel( \
"gn1_" #op #in_tname #out_tname, unary_g, in_type, out_type, op, 1, int) \
instantiate_kernel( \
#define instantiate_unary_work_per_thread(op, in_tname, out_tname, in_type, out_type) \
instantiate_kernel("vn_" #op #in_tname #out_tname, unary_v, in_type, out_type, op)

#define instantiate_unary_base(op, in_tname, out_tname, in_type, out_type) \
instantiate_kernel("v_" #op #in_tname #out_tname, unary_v, in_type, out_type, op, 1) \
instantiate_kernel("v2_" #op #in_tname #out_tname, unary_v2, in_type, out_type, op) \
instantiate_kernel( \
"gn1_" #op #in_tname #out_tname, unary_g, in_type, out_type, op, 1, int) \
instantiate_kernel( \
"gn4large_" #op #in_tname #out_tname, unary_g, in_type, out_type, op, 4)

#define instantiate_unary_all(op, in_tname, out_tname, in_type, out_type) \
instantiate_unary_base(op, in_tname, out_tname, in_type, out_type) \
instantiate_unary_work_per_thread(op, in_tname, out_tname, in_type, out_type)

#define instantiate_unary_all_same(op, tname, type) \
instantiate_unary_all(op, tname, tname, type, type)

#define instantiate_unary_base_same(op, tname, type) \
instantiate_unary_base(op, tname, tname, type, type)

#define instantiate_unary_float(op) \
instantiate_unary_all_same(op, float16, half) \
instantiate_unary_all_same(op, float32, float) \
instantiate_unary_all_same(op, bfloat16, bfloat16_t)

#define instantiate_unary_int(op) \
instantiate_unary_all_same(op, uint8, uint8_t) \
instantiate_unary_all_same(op, uint16, uint16_t) \
instantiate_unary_all_same(op, uint32, uint32_t) \
instantiate_unary_all_same(op, uint64, uint64_t) \
instantiate_unary_all_same(op, int8, int8_t) \
instantiate_unary_all_same(op, int16, int16_t) \
instantiate_unary_all_same(op, int32, int32_t) \
instantiate_unary_all_same(op, int64, int64_t)
#define instantiate_unary_int(op) \
instantiate_unary_all_same(op, uint8, uint8_t) \
instantiate_unary_all_same(op, uint16, uint16_t) \
instantiate_unary_all_same(op, uint32, uint32_t) \
instantiate_unary_base_same(op, uint64, uint64_t) \
instantiate_unary_all_same(op, int8, int8_t) \
instantiate_unary_all_same(op, int16, int16_t) \
instantiate_unary_all_same(op, int32, int32_t) \
instantiate_unary_base_same(op, int64, int64_t)

#define instantiate_unary_types(op) \
instantiate_unary_all_same(op, bool_, bool) \
@@ -68,29 +78,29 @@ instantiate_unary_float(Tanh)
instantiate_unary_float(Round)
instantiate_unary_int(BitwiseInvert)

instantiate_unary_all_same(Abs, complex64, complex64_t)
instantiate_unary_all_same(ArcCos, complex64, complex64_t)
instantiate_unary_all_same(ArcSin, complex64, complex64_t)
instantiate_unary_all_same(ArcTan, complex64, complex64_t)
instantiate_unary_all_same(Conjugate, complex64, complex64_t)
instantiate_unary_all_same(Cos, complex64, complex64_t)
instantiate_unary_all_same(Cosh, complex64, complex64_t)
instantiate_unary_all_same(Exp, complex64, complex64_t)
instantiate_unary_all_same(Log, complex64, complex64_t)
instantiate_unary_all_same(Log1p, complex64, complex64_t)
instantiate_unary_all_same(Log2, complex64, complex64_t)
instantiate_unary_all_same(Log10, complex64, complex64_t)
instantiate_unary_all_same(Negative, complex64, complex64_t)
instantiate_unary_all_same(Sign, complex64, complex64_t)
instantiate_unary_all_same(Sin, complex64, complex64_t)
instantiate_unary_all_same(Sinh, complex64, complex64_t)
instantiate_unary_all_same(Square, complex64, complex64_t)
instantiate_unary_all_same(Sqrt, complex64, complex64_t)
instantiate_unary_all_same(Rsqrt, complex64, complex64_t)
instantiate_unary_all_same(Tan, complex64, complex64_t)
instantiate_unary_all_same(Tanh, complex64, complex64_t)
instantiate_unary_all_same(Round, complex64, complex64_t)
instantiate_unary_all(Real, complex64, float32, complex64_t, float)
instantiate_unary_all(Imag, complex64, float32, complex64_t, float)
instantiate_unary_base_same(Abs, complex64, complex64_t)
instantiate_unary_base_same(ArcCos, complex64, complex64_t)
instantiate_unary_base_same(ArcSin, complex64, complex64_t)
instantiate_unary_base_same(ArcTan, complex64, complex64_t)
instantiate_unary_base_same(Conjugate, complex64, complex64_t)
instantiate_unary_base_same(Cos, complex64, complex64_t)
instantiate_unary_base_same(Cosh, complex64, complex64_t)
instantiate_unary_base_same(Exp, complex64, complex64_t)
instantiate_unary_base_same(Log, complex64, complex64_t)
instantiate_unary_base_same(Log1p, complex64, complex64_t)
instantiate_unary_base_same(Log2, complex64, complex64_t)
instantiate_unary_base_same(Log10, complex64, complex64_t)
instantiate_unary_base_same(Negative, complex64, complex64_t)
instantiate_unary_base_same(Sign, complex64, complex64_t)
instantiate_unary_base_same(Sin, complex64, complex64_t)
instantiate_unary_base_same(Sinh, complex64, complex64_t)
instantiate_unary_base_same(Square, complex64, complex64_t)
instantiate_unary_base_same(Sqrt, complex64, complex64_t)
instantiate_unary_base_same(Rsqrt, complex64, complex64_t)
instantiate_unary_base_same(Tan, complex64, complex64_t)
instantiate_unary_base_same(Tanh, complex64, complex64_t)
instantiate_unary_base_same(Round, complex64, complex64_t)
instantiate_unary_base(Real, complex64, float32, complex64_t, float)
instantiate_unary_base(Imag, complex64, float32, complex64_t, float)

instantiate_unary_all_same(LogicalNot, bool_, bool) // clang-format on

@@ -45,7 +45,7 @@ void ternary_op_gpu_inplace(
work_per_thread = large ? 4 : 2;
} else {
large = out.data_size() > INT32_MAX;
work_per_thread = get_work_per_thread(b.dtype());
work_per_thread = get_work_per_thread(b.dtype(), out.data_size());
}
std::string kernel_name;
if (topt == TernaryOpType::General) {
@@ -60,6 +60,8 @@ void ternary_op_gpu_inplace(
}
} else if (large) {
kernel_name = "v2";
} else if (work_per_thread > 1) {
kernel_name = "vn";
} else {
kernel_name = "v";
}

@@ -43,8 +43,8 @@ void unary_op_gpu_inplace(
int work_per_thread;
std::string kernel_name;
if (contig) {
work_per_thread = get_work_per_thread(in.dtype());
kernel_name = (large ? "v2" : "v");
work_per_thread = get_work_per_thread(in.dtype(), in.data_size());
kernel_name = (large ? "v2" : (work_per_thread > 1 ? "vn" : "v"));
} else {
work_per_thread = large ? 4 : 1;
kernel_name = "gn" + std::to_string(work_per_thread);

@@ -72,6 +72,10 @@ void concatenate(std::string& acc, T first, Args... args) {
inline int get_work_per_thread(Dtype dtype) {
return std::max(1, 8 / dtype.size());
}
inline int get_work_per_thread(Dtype dtype, size_t size) {
constexpr size_t wpt_threshold = 1 << 16;
return size < wpt_threshold ? 1 : std::max(1, 8 / dtype.size());
}

inline size_t ceildiv(size_t n, size_t m) {
return (n + m - 1) / m;

@@ -1,16 +1,20 @@
// Copyright © 2023-2024 Apple Inc.
#include <cstdlib>
#include <map>
#include <sstream>
#include <unordered_map>
#include <unordered_set>

#include "mlx/allocator.h"
#include "mlx/backend/common/compiled.h"
#include "mlx/compile.h"
#include "mlx/compile_impl.h"
#include "mlx/fast_primitives.h"
#include "mlx/graph_utils.h"
#include "mlx/primitives.h"
#include "mlx/transforms.h"
#include "mlx/transforms_impl.h"
#include "mlx/utils.h"

namespace mlx::core {

@@ -82,7 +86,54 @@ Compiled::Compiled(
inputs_(std::move(inputs)),
outputs_(std::move(outputs)),
tape_(std::move(tape)),
constant_ids_(std::move(constant_ids)) {}
constant_ids_(std::move(constant_ids)),
is_constant_([this](size_t i) {
return constant_ids_.find(inputs_[i].id()) != constant_ids_.end();
}) {
// Build the kernel name.
NodeNamer namer;
std::ostringstream os;
std::ostringstream constant_hasher;

// Fill the input names. This is not really necessary, I just like having A,
// B, C, ... as the inputs.
for (const auto& x : inputs_) {
namer.get_name(x);
}

// The primitives describing the tape. For unary and binary primitives this
// must be enough to describe the full computation.
for (const auto& a : tape_) {
// name and type of output
os << namer.get_name(a) << kindof(a.dtype()) << a.itemsize();
// computation performed
a.primitive().print(os);
// name of inputs to the function
for (auto& inp : a.inputs()) {
os << namer.get_name(inp);
}
}
os << "_";

for (const auto& x : inputs_) {
if (constant_ids_.find(x.id()) != constant_ids_.end()) {
os << "C";
print_constant(constant_hasher, x);
} else {
os << (is_scalar(x) ? "S" : "V");
}
}
os << "_";
for (const auto& x : inputs) {
if (constant_ids.find(x.id()) != constant_ids.end()) {
continue;
}
os << kindof(x.dtype()) << x.itemsize();
}
os << "_" << std::hash<std::string>{}(constant_hasher.str());

kernel_lib_ = os.str();
}

std::vector<array> Compiled::vjp(
const std::vector<array>&,

@@ -627,6 +627,7 @@ class Compiled : public Primitive {
const std::vector<array> outputs_;
const std::vector<array> tape_;
const std::unordered_set<uintptr_t> constant_ids_;
const std::function<bool(size_t)> is_constant_;

std::string kernel_lib_;
};

@@ -208,7 +208,9 @@ array eval_impl(std::vector<array> outputs, bool async) {
// output arrays stream
fences[it->second].wait(stream, in);
} else if (in.event().valid()) {
if (in.event().stream() != stream) {
if (in.event().is_signaled()) {
in.detach_event();
} else if (in.event().stream() != stream) {
// Use event to wait across async eval
in.event().wait(stream);
}

@@ -4,7 +4,7 @@

#define MLX_VERSION_MAJOR 0
#define MLX_VERSION_MINOR 26
#define MLX_VERSION_PATCH 0
#define MLX_VERSION_PATCH 1
#define MLX_VERSION_NUMERIC \
(100000 * MLX_VERSION_MAJOR + 1000 * MLX_VERSION_MINOR + MLX_VERSION_PATCH)

@@ -193,7 +193,7 @@ class Module(dict):
)

if len(weights) != 0:
self.update(tree_unflatten(weights))
self.update(tree_unflatten(weights), strict=False)
return self

def save_weights(self, file: str):
@@ -291,7 +291,7 @@ class Module(dict):

return self.filter_and_map(self.valid_child_filter, is_leaf_fn=_is_leaf_module)

def update(self, parameters: dict) -> Module:
def update(self, parameters: dict, strict: bool = True) -> Module:
"""Replace the parameters of this Module with the provided ones in the
dict of dicts and lists.

@@ -305,7 +305,9 @@ class Module(dict):

Args:
parameters (dict): A complete or partial dictionary of the modules
parameters.
parameters.
strict (bool): If ``True`` checks that ``parameters`` is a
subset of the module's parameters. Default: ``True``.
Returns:
The module instance after updating the parameters.
"""
@@ -317,21 +319,29 @@ class Module(dict):
current_value = dst[k]
new_value = parameters[k]
if isinstance(current_value, mx.array):
if strict and not isinstance(new_value, mx.array):
raise ValueError(
f"Received invalid type: {type(new_value).__name__}."
)
dst[k] = new_value
elif isinstance(current_value, Module):
current_value.update(new_value)
elif isinstance(current_value, (dict, list)):
else:
apply(current_value, new_value)
elif strict:
raise ValueError(f'Module does not have parameter named "{k}".')
elif isinstance(parameters, list):
for i in range(len(parameters)):
current_value = dst[i]
new_value = parameters[i]
if isinstance(current_value, mx.array):
if strict and not isinstance(new_value, mx.array):
raise ValueError(
f"Received invalid type: {type(new_value).__name__}."
)
dst[i] = new_value
elif isinstance(current_value, Module):
current_value.update(new_value)
elif isinstance(current_value, (dict, list)):
else:
apply(current_value, new_value)
elif strict:
raise ValueError(f"Received invalid type: {type(parameters).__name__}.")

apply(self, parameters)
return self
@@ -359,7 +369,7 @@ class Module(dict):
self.update(self.filter_and_map(filter_fn, map_fn))
return self

def update_modules(self, modules: dict) -> Module:
def update_modules(self, modules: dict, strict: bool = True) -> Module:
"""Replace the child modules of this :class:`Module` instance with the
provided ones in the dict of dicts and lists.

@@ -368,12 +378,14 @@ class Module(dict):
programmatically swapping layers.

The passed in parameters dictionary need not be a full dictionary
similar to :meth:`parameters`. Only the provided locations will be
similar to :meth:`modules`. Only the provided locations will be
updated.

Args:
modules (dict): A complete or partial dictionary of the modules
modules (dict): A complete or partial dictionary of the module's
submodules.
strict (bool): If ``True`` checks that ``modules`` is a
subset of the child modules of this instance. Default: ``True``.
Returns:
The module instance after updating the submodules.
"""
@@ -388,6 +400,14 @@ class Module(dict):
dst[k] = new_value
elif isinstance(current_value, (dict, list)):
apply(current_value, new_value)
elif strict:
raise ValueError(
f"Received invalid type: {type(new_value).__name__}."
)
elif strict:
raise ValueError(
f'Module does not have sub-module named "{k}".'
)
elif isinstance(modules, list):
for i in range(len(dst)):
current_value = dst[i]
@@ -396,6 +416,12 @@ class Module(dict):
dst[i] = new_value
elif isinstance(current_value, (dict, list)):
apply(current_value, new_value)
elif strict:
raise ValueError(
f"Received invalid type: {type(new_value).__name__}."
)
elif strict:
raise ValueError(f"Received invalid type: {type(modules).__name__}.")

apply(self, modules)
return self

@@ -54,5 +54,9 @@ target_link_libraries(core PRIVATE mlx)
target_compile_definitions(core PRIVATE _VERSION_=${MLX_VERSION})

if(BUILD_SHARED_LIBS)
target_link_options(core PRIVATE -Wl,-rpath,@loader_path/lib)
if(${CMAKE_SYSTEM_NAME} MATCHES "Darwin")
target_link_options(core PRIVATE -Wl,-rpath,@loader_path/lib)
else()
target_link_options(core PRIVATE -Wl,-rpath,\$ORIGIN/lib)
endif()
endif()

@@ -219,6 +219,46 @@ class TestBase(mlx_tests.MLXTestCase):
x = mx.zeros((3,))
mx.grad(loss_fn)(model)

def test_update(self):
m = nn.Sequential(nn.Linear(3, 3), nn.Linear(3, 3))

# Updating non-existent parameters
with self.assertRaises(ValueError):
updates = {"layers": [{"value": 0}]}
m.update(updates)

with self.assertRaises(ValueError):
updates = {"layers": ["hello"]}
m.update(updates)

# Wrong type
with self.assertRaises(ValueError):
updates = {"layers": [{"weight": "hi"}]}
m.update(updates)

def test_update_modules(self):
m = nn.Sequential(nn.Linear(3, 3), nn.Linear(3, 3))

# Updating non-existent modules should not be allowed by default
with self.assertRaises(ValueError):
m = m.update_modules({"values": [0, 1]})

# Update wrong types
with self.assertRaises(ValueError):
m = m.update_modules({"layers": [0, 1]})

class MyModule(nn.Module):
def __init__(self):
super().__init__()
self.test = mx.array(1.0)
self.list = [mx.array(1.0), mx.array(2.0)]

m = MyModule()
with self.assertRaises(ValueError):
m = m.update_modules({"test": "hi"})
with self.assertRaises(ValueError):
m = m.update_modules({"list": ["hi"]})

class TestLayers(mlx_tests.MLXTestCase):
def test_identity(self):