Compare commits


11 Commits

Author SHA1 Message Date
Angelos Katharopoulos
ca7970a4f1 Make args references but ensure copy to kernel 2025-10-02 11:39:21 -07:00
Angelos Katharopoulos
214b1c1a06 Remove moves 2025-10-02 11:16:17 -07:00
Angelos Katharopoulos
e42e06046e Fix the check 2025-10-01 21:13:07 -07:00
Angelos Katharopoulos
17432e7885 Add a small column specialization to reduce 2025-10-01 21:11:35 -07:00
Awni Hannun
e88f2d4a8e fix cross entropy axis param (#2641)
* fix cross entropy axis param

* faster grad clipping
2025-10-01 16:49:55 -07:00
Angelos Katharopoulos
9cee557423 Fix status message (#2638) 2025-10-01 16:43:45 -07:00
Awni Hannun
bbf1423953 wait for tasks in cuda (#2636) 2025-09-30 16:08:46 -07:00
Angelos Katharopoulos
eb24267b56 Compile now can attach arbitrary data to an entry (#2634) 2025-09-30 13:33:27 -07:00
Awni Hannun
dc371ae7a5 fix for max block dim (#2631) 2025-09-29 08:59:25 -07:00
AN Long
e76a8dd5c5 Fix incorrect path and typos (#2630) 2025-09-28 06:03:04 -07:00
Cheng
b466dea982 [CUDA] Make CudaEvent work with multi-device (#2614)
* Set current device when creating cuda event

* Separate cuda events by device

* Avoid race condition in pool
2025-09-27 11:27:17 +09:00
26 changed files with 410 additions and 109 deletions

View File

@@ -173,7 +173,7 @@ if(MLX_BUILD_CPU)
     message(STATUS "Accelerate found ${ACCELERATE_LIBRARY}")
     set(MLX_BUILD_ACCELERATE ON)
   else()
-    message(STATUS "Accelerate or arm neon not found, using default backend.")
+    message(STATUS "Accelerate not found, using default backend.")
     set(MLX_BUILD_ACCELERATE OFF)
   endif()

View File

@@ -9,7 +9,7 @@
 #include "mlx/backend/cpu/simd/base_simd.h"
-// There seems to be a bug in sims/base.h
+// There seems to be a bug in simd/base_simd.h
 // __XROS_2_0 is not defined, the expression evaluates
 // to true instead of false setting the SIMD library
 // higher than it should be even on macOS < 15

View File

@@ -86,7 +86,7 @@ CudaAllocator::CudaAllocator()
   // TODO: Set memory limit for multi-device.
   size_t free, total;
   CHECK_CUDA_ERROR(cudaMemGetInfo(&free, &total));
-  memory_limit_ = total * 0.8;
+  memory_limit_ = total * 0.95;
   max_pool_size_ = memory_limit_;
 }

View File

@@ -332,9 +332,9 @@ void Compiled::eval_gpu(
     encoder.set_output_array(out);
   }
-  auto kernel = mod.get_kernel(kernel_name);
+  auto [kernel, max_block_dims] = mod.get_kernel_and_dims(kernel_name);
   auto [num_blocks, block_dims] =
-      get_launch_args(outputs[0], large, work_per_thread);
+      get_launch_args(outputs[0], large, work_per_thread, max_block_dims);
   encoder.add_kernel_node(kernel, num_blocks, block_dims, 0, args.args());
 }

View File

@@ -14,10 +14,6 @@ namespace mlx::core::cu {
 namespace {
-// Can be tuned with MLX_MAX_OPS_PER_BUFFER
-// This should be less than 255
-constexpr int default_max_nodes_per_graph = 20;
 #define CHECK_CUDNN_ERROR(cmd) check_cudnn_error(#cmd, (cmd))
 void check_cudnn_error(const char* name, cudnnStatus_t err) {
@@ -68,8 +64,8 @@ Device::~Device() {
 void Device::make_current() {
   // We need to set/get current CUDA device very frequently, cache it to reduce
-  // actual calls of CUDA APIs. This function assumes single-thread in host.
-  static int current = 0;
+  // actual calls of CUDA APIs.
+  static thread_local int current = 0;
   if (current != device_) {
     CHECK_CUDA_ERROR(cudaSetDevice(device_));
     current = device_;
@@ -95,6 +91,7 @@ CommandEncoder::CaptureContext::CaptureContext(CommandEncoder& enc) : enc(enc) {
 CommandEncoder::CaptureContext::~CaptureContext() {
   if (!use_cuda_graphs()) {
+    enc.node_count_++;
     return;
   }
@@ -196,6 +193,7 @@ CommandEncoder::CommandEncoder(Device& d)
     : device_(d),
       stream_(d),
       graph_(d),
+      worker_(d),
       graph_cache_("MLX_CUDA_GRAPH_CACHE_SIZE", /* default_capacity */ 400) {}
 void CommandEncoder::add_completed_handler(std::function<void()> task) {
@@ -220,12 +218,6 @@ void CommandEncoder::set_output_array(const array& arr) {
   active_outputs_.push_back(id);
 }
-void CommandEncoder::maybe_commit() {
-  if (node_count_ >= env::max_ops_per_buffer(default_max_nodes_per_graph)) {
-    commit();
-  }
-}
 void CommandEncoder::add_kernel_node(
     void* func,
     dim3 grid_dim,
@@ -233,6 +225,7 @@ void CommandEncoder::add_kernel_node(
     uint32_t smem_bytes,
     void** params) {
   if (!use_cuda_graphs()) {
+    node_count_++;
     CHECK_CUDA_ERROR(cudaLaunchKernel(
         func, grid_dim, block_dim, params, smem_bytes, stream()));
     return;
@@ -253,6 +246,7 @@ void CommandEncoder::add_kernel_node(
     uint32_t smem_bytes,
     void** params) {
   if (!use_cuda_graphs()) {
+    node_count_++;
     CHECK_CUDA_ERROR(cuLaunchKernel(
         func,
         grid_dim.x,
@@ -295,6 +289,7 @@ void CommandEncoder::add_kernel_node(const CUDA_KERNEL_NODE_PARAMS& params) {
 void CommandEncoder::add_graph_node(cudaGraph_t child) {
   if (!use_cuda_graphs()) {
+    node_count_++;
     CudaGraphExec graph_exec;
     graph_exec.instantiate(child);
     device_.make_current();
@@ -306,12 +301,16 @@ void CommandEncoder::add_graph_node(cudaGraph_t child) {
   insert_graph_dependencies(GraphNode{node, 'G'});
 }
+int CommandEncoder::get_num_ops() {
+  return node_count_;
+}
 void CommandEncoder::commit() {
   nvtx3::scoped_range r("CommandEncoder::commit");
   if (!temporaries_.empty()) {
     add_completed_handler([temporaries = std::move(temporaries_)]() {});
   }
-  if (node_count_ > 0) {
+  if (use_cuda_graphs() && node_count_ > 0) {
     if (!from_nodes_.empty()) {
       CHECK_CUDA_ERROR(cudaGraphAddDependencies(
           graph_,
@@ -354,7 +353,6 @@ void CommandEncoder::commit() {
     CHECK_CUDA_ERROR(cudaGraphLaunch(graph_exec, stream_));
     // Reset state
-    node_count_ = 0;
     graph_node_count_ = 0;
     empty_node_count_ = 0;
     from_nodes_.clear();
@@ -366,6 +364,7 @@ void CommandEncoder::commit() {
   // Put completion handlers in a batch.
   worker_.commit(stream_);
+  node_count_ = 0;
 }
 void CommandEncoder::synchronize() {
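
Note on the make_current change above: cudaSetDevice sets a per-thread current device, so a process-wide static cache would let one thread's update mask another thread's actual state. A minimal sketch of the thread_local caching pattern (illustrative, not MLX code):

    #include <cuda_runtime.h>

    void set_current_device(int device) {
      // Each host thread caches the device it last set; the current device
      // is per-thread state in CUDA, so the cache must be thread_local too.
      static thread_local int current = 0;
      if (current != device) {
        cudaSetDevice(device);  // only hit the driver on an actual change
        current = device;
      }
    }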

View File

@@ -83,7 +83,7 @@ class CommandEncoder {
   }
   void add_completed_handler(std::function<void()> task);
-  void maybe_commit();
+  int get_num_ops();
   void commit();
   Device& device() {
@@ -140,7 +140,7 @@ class Device {
   Device(const Device&) = delete;
   Device& operator=(const Device&) = delete;
-  // Make this device the current cuda device, required by some cuda calls.
+  // Make this device the current cuda device, this method is thread-safe.
   void make_current();
   CommandEncoder& get_command_encoder(Stream s);

View File

@@ -5,19 +5,24 @@
 #include "mlx/backend/cuda/device.h"
 #include "mlx/backend/gpu/available.h"
 #include "mlx/primitives.h"
+#include "mlx/scheduler.h"
 #include <nvtx3/nvtx3.hpp>
 namespace mlx::core::gpu {
+// Can be tuned with MLX_MAX_OPS_PER_BUFFER
+constexpr int default_max_nodes_per_graph = 20;
 bool is_available() {
   return true;
 }
 void new_stream(Stream s) {
-  // Force initalization of CUDA by creating an event, so the CUDA runtime and
-  // our CUDA event pool get destroyed last.
-  cu::CudaEvent(cudaEventDefault);
+  // Force initalization of CUDA, so CUDA runtime get destroyed at last.
+  cudaFree(nullptr);
+  // Make sure CUDA event pool get destroyed after device and stream.
+  cu::CudaEvent::init_pool();
   // Ensure the static stream objects get created.
   cu::get_command_encoder(s);
 }
@@ -35,7 +40,8 @@ void eval(array& arr) {
     arr.primitive().eval_gpu(arr.inputs(), outputs);
   }
-  auto& encoder = cu::get_command_encoder(arr.primitive().stream());
+  auto& stream = arr.primitive().stream();
+  auto& encoder = cu::get_command_encoder(stream);
   // Keep used buffers alive until kernel finishes running.
   for (auto& in : arr.inputs()) {
     // Except for the donated one.
@@ -46,7 +52,14 @@ void eval(array& arr) {
   for (auto& s : arr.siblings()) {
     encoder.add_temporary(s);
   }
-  encoder.maybe_commit();
+  if (encoder.get_num_ops() >=
+      env::max_ops_per_buffer(default_max_nodes_per_graph)) {
+    scheduler::notify_new_task(stream);
+    encoder.add_completed_handler(
+        [stream]() { scheduler::notify_task_completion(stream); });
+    encoder.commit();
+  }
 }
 void finalize(Stream s) {

View File

@@ -22,11 +22,15 @@ namespace cu {
 namespace {
 // Manage cached cudaEvent_t objects.
-struct CudaEventPool {
-  static CudaEventHandle create(int flags) {
-    auto& cache = cache_for(flags);
+class CudaEventPool {
+ public:
+  CudaEventHandle create(Device& d, int flags) {
+    if (!on_creation_thread()) {
+      return CudaEventHandle(d, flags);
+    }
+    auto& cache = cache_for(d, flags);
     if (cache.empty()) {
-      return CudaEventHandle(flags);
+      return CudaEventHandle(d, flags);
     } else {
       CudaEventHandle ret = std::move(cache.back());
       cache.pop_back();
@@ -34,54 +38,89 @@ struct CudaEventPool {
     }
   }
-  static void release(CudaEventHandle event) {
-    cache_for(event.flags).push_back(std::move(event));
+  void release(CudaEventHandle event) {
+    if (!on_creation_thread()) {
+      // Event will be destroyed directly instead of getting moved to cache.
+      return;
+    }
+    cache_for(event.device, event.flags).push_back(std::move(event));
   }
-  static std::vector<CudaEventHandle>& cache_for(int flags) {
-    static std::map<int, std::vector<CudaEventHandle>> cache;
-    return cache[flags];
+ private:
+  std::vector<CudaEventHandle>& cache_for(Device& d, int flags) {
+    return cache_[d.cuda_device()][flags];
   }
+  bool on_creation_thread() {
+    return std::this_thread::get_id() == thread_id_;
+  }
+  // The CudaEvent may be created and destroyed on different threads (for
+  // example when waiting on GPU work in CPU stream), we don't want to make
+  // the cache thread-safe as it adds overhead, so we just skip cache when
+  // using events in worker threads.
+  std::thread::id thread_id_{std::this_thread::get_id()};
+  // {device: {flags: [events]}}
+  std::map<int, std::map<int, std::vector<CudaEventHandle>>> cache_;
 };
+CudaEventPool& cuda_event_pool() {
+  static CudaEventPool pool;
+  return pool;
+}
 } // namespace
-CudaEventHandle::CudaEventHandle(int flags) : flags(flags) {
+CudaEventHandle::CudaEventHandle(Device& d, int flags)
+    : device(d), flags(flags) {
+  device.make_current();
   CHECK_CUDA_ERROR(cudaEventCreateWithFlags(&handle_, flags));
   assert(handle_ != nullptr);
 }
-CudaEvent::CudaEvent(int flags) : event_(CudaEventPool::create(flags)) {}
+CudaEvent::CudaEvent(Device& d, int flags)
+    : event_(cuda_event_pool().create(d, flags)) {}
 CudaEvent::~CudaEvent() {
-  CudaEventPool::release(std::move(event_));
+  cuda_event_pool().release(std::move(event_));
 }
 void CudaEvent::wait() {
   nvtx3::scoped_range r("cu::CudaEvent::wait");
+  event_.device.make_current();
   cudaEventSynchronize(event_);
 }
 void CudaEvent::wait(cudaStream_t stream) {
+  event_.device.make_current();
   cudaStreamWaitEvent(stream, event_);
 }
 void CudaEvent::record(cudaStream_t stream) {
+  event_.device.make_current();
   cudaEventRecord(event_, stream);
 }
 bool CudaEvent::completed() const {
+  // Note: cudaEventQuery can be safely called from any device.
   return cudaEventQuery(event_) == cudaSuccess;
 }
+// static
+void CudaEvent::init_pool() {
+  cuda_event_pool();
+}
 // Wraps CudaEvent with a few features:
 // 1. The class can be copied.
 // 2. Make wait/record work with CPU streams.
 // 3. Add checks for waiting on un-recorded event.
 class CopyableCudaEvent {
  public:
-  CopyableCudaEvent()
+  explicit CopyableCudaEvent(Device& d)
       : event_(std::make_shared<CudaEvent>(
+            d,
             cudaEventDisableTiming | cudaEventBlockingSync)) {}
   void wait() {
@@ -245,7 +284,7 @@ struct EventImpl {
     nvtx3::mark("Using slow AtomicEvent");
     atomic = std::make_unique<cu::AtomicEvent>();
   } else {
-    cuda = std::make_unique<cu::CopyableCudaEvent>();
+    cuda = std::make_unique<cu::CopyableCudaEvent>(cu::device(s.device));
   }
 }
};
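
The pool above stays lock-free by confining its cache to the thread that created it; events released from any other thread are simply destroyed rather than recycled. A generic sketch of this thread-confinement pattern (illustrative, not MLX code; assumes a default-constructible resource type):

    #include <thread>
    #include <vector>

    // Objects are only cached by the owning thread; other threads fall back
    // to plain construction/destruction, trading reuse for zero locking.
    template <typename T>
    class ThreadConfinedPool {
     public:
      T acquire() {
        if (std::this_thread::get_id() != owner_ || cache_.empty()) {
          return T{};  // fresh object off-thread or on a cache miss
        }
        T obj = std::move(cache_.back());
        cache_.pop_back();
        return obj;
      }
      void release(T obj) {
        if (std::this_thread::get_id() == owner_) {
          cache_.push_back(std::move(obj));
        }  // else: obj destroys itself here, no synchronization needed
      }
     private:
      std::thread::id owner_{std::this_thread::get_id()};
      std::vector<T> cache_;
    };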

View File

@@ -13,9 +13,12 @@
 namespace mlx::core::cu {
+class Device;
 // RAII-managed move-only wrapper of cudaEvent_t.
 struct CudaEventHandle : public CudaHandle<cudaEvent_t, cudaEventDestroy> {
-  CudaEventHandle(int flags);
+  CudaEventHandle(Device& d, int flags);
+  Device& device;
   int flags;
 };
@@ -23,7 +26,7 @@ struct CudaEventHandle : public CudaHandle<cudaEvent_t, cudaEventDestroy> {
 // on GPU stream in CPU stream, but can not wait on CPU stream.
 class CudaEvent {
  public:
-  explicit CudaEvent(int flags);
+  CudaEvent(Device& d, int flags);
   ~CudaEvent();
   CudaEvent(CudaEvent&&) = default;
@@ -40,6 +43,9 @@ class CudaEvent {
   // returns true if record() has not been called.
   bool completed() const;
+  // Internal: make sure event pool is initialized.
+  static void init_pool();
  private:
   CudaEventHandle event_;
 };

View File

@@ -297,7 +297,8 @@ void load_module(
     const std::string& ptx,
     const std::vector<std::pair<std::string, std::string>>& ptx_kernels,
     CUmodule& module_,
-    std::unordered_map<std::string, std::pair<CUfunction, bool>>& kernels) {
+    std::unordered_map<std::string, std::tuple<CUfunction, bool, uint>>&
+        kernels) {
   // Load module.
   char jit_log[4089] = {};
   CUjit_option options[] = {
@@ -314,7 +315,7 @@ void load_module(
   for (const auto& [name, mangled] : ptx_kernels) {
     CUfunction kernel;
     CHECK_CUDA_ERROR(cuModuleGetFunction(&kernel, module_, mangled.c_str()));
-    kernels[name] = std::make_pair(kernel, false);
+    kernels[name] = std::make_tuple(kernel, false, 0);
   }
 }
@@ -358,7 +359,7 @@ JitModule::~JitModule() {
   CHECK_CUDA_ERROR(cuModuleUnload(module_));
 }
-CUfunction JitModule::get_kernel(
+std::pair<CUfunction, uint> JitModule::get_kernel_and_dims(
     const std::string& kernel_name,
     std::function<void(CUfunction)> configure_kernel) {
   auto it = kernels_.find(kernel_name);
@@ -369,14 +370,22 @@ CUfunction JitModule::get_kernel(
   // If it is the first time we run this kernel then configure it. Do it only
   // once!
-  if (!it->second.second) {
+  auto kernel = std::get<0>(it->second);
+  if (!std::get<1>(it->second)) {
     if (configure_kernel) {
-      configure_kernel(it->second.first);
+      configure_kernel(kernel);
     }
-    it->second.second = true;
+    std::get<1>(it->second) = true;
+    std::get<2>(it->second) = max_occupancy_block_dim(kernel);
   }
-  return it->second.first;
+  return {kernel, std::get<2>(it->second)};
 }
+CUfunction JitModule::get_kernel(
+    const std::string& kernel_name,
+    std::function<void(CUfunction)> configure_kernel) {
+  return get_kernel_and_dims(kernel_name, std::move(configure_kernel)).first;
+}
 std::unordered_map<std::string, JitModule>& get_jit_module_cache() {

View File

@@ -99,10 +99,13 @@ class JitModule {
   CUfunction get_kernel(
       const std::string& kernel_name,
       std::function<void(CUfunction)> configure_kernel = nullptr);
+  std::pair<CUfunction, uint> get_kernel_and_dims(
+      const std::string& kernel_name,
+      std::function<void(CUfunction)> configure_kernel = nullptr);
  private:
   CUmodule module_{nullptr};
-  std::unordered_map<std::string, std::pair<CUfunction, bool>> kernels_;
+  std::unordered_map<std::string, std::tuple<CUfunction, bool, uint>> kernels_;
 };
 std::unordered_map<std::string, JitModule>& get_jit_module_cache();

View File

@@ -35,12 +35,10 @@ std::tuple<dim3, uint> get_launch_args(
     const Shape& shape,
     const Strides& strides,
     bool large,
-    int work_per_thread) {
+    int work_per_thread /* = 1 */,
+    uint max_block_dim /* = 1024 */) {
   size_t nthreads = cuda::ceil_div(size, work_per_thread);
-  uint block_dim = 1024;
-  if (block_dim > nthreads) {
-    block_dim = nthreads;
-  }
+  uint block_dim = max_block_dim < nthreads ? max_block_dim : nthreads;
   dim3 num_blocks;
   if (large) {
     num_blocks = get_2d_grid_dims(shape, strides, work_per_thread);

View File

@@ -120,19 +120,28 @@ dim3 get_2d_grid_dims(
     size_t divisor);
 std::pair<dim3, dim3> get_grid_and_block(int dim0, int dim1, int dim2);
-// Get the num_blocks and block_dims that maximize occupancy for |kernel|,
-// assuming each thread handles |work_per_thread| elements of |arr|.
+// Get the num_blocks and block_dims assuming each thread handles
+// |work_per_thread| elements of |arr|.
 std::tuple<dim3, uint> get_launch_args(
     size_t size,
     const Shape& shape,
     const Strides& strides,
     bool large,
-    int work_per_thread = 1);
+    int work_per_thread = 1,
+    uint max_block_dim = 1024);
-inline std::tuple<dim3, uint>
-get_launch_args(const array& arr, bool large, int work_per_thread = 1) {
+inline std::tuple<dim3, uint> get_launch_args(
+    const array& arr,
+    bool large,
+    int work_per_thread = 1,
+    uint max_block_dim = 1024) {
   return get_launch_args(
-      arr.size(), arr.shape(), arr.strides(), large, work_per_thread);
+      arr.size(),
+      arr.shape(),
+      arr.strides(),
+      large,
+      work_per_thread,
+      max_block_dim);
 }
} // namespace mlx::core

View File

@@ -181,6 +181,47 @@ col_reduce_looped(T* in, U* out, const __grid_constant__ ColReduceArgs args) {
   }
 }
+template <typename T, typename U, typename Op, int N_READS = 4>
+__global__ void col_reduce_small(
+    const T* in,
+    U* out,
+    const __grid_constant__ ColReduceArgs args,
+    size_t total) {
+  Op op;
+  auto grid = cg::this_grid();
+  auto block = cg::this_thread_block();
+  const auto idx = grid.thread_rank() * N_READS;
+  const auto before_axis = idx / args.reduction_stride;
+  const auto after_axis = idx % args.reduction_stride;
+  const auto offset =
+      before_axis * args.reduction_stride * args.reduction_size + after_axis;
+  if (idx >= total) {
+    return;
+  }
+  in += offset;
+  out += idx;
+  AlignedVector<U, N_READS> accumulator;
+  for (int i = 0; i < N_READS; i++) {
+    accumulator[i] = ReduceInit<Op, T>::value();
+  }
+  for (int i = 0; i < args.reduction_size; i++) {
+    auto values = load_vector<N_READS>(in, 0);
+    for (int j = 0; j < N_READS; j++) {
+      accumulator[j] = op(accumulator[j], cast_to<U>(values[j]));
+    }
+    in += args.reduction_stride;
+  }
+  store_vector(out, 0, accumulator);
+}
 } // namespace cu
 inline auto output_grid_for_col_reduce(
@@ -206,7 +247,7 @@ void col_reduce_looped(
     Reduce::ReduceType reduce_type,
     const std::vector<int>& axes,
     const ReductionPlan& plan,
-    cu::ColReduceArgs args) {
+    const cu::ColReduceArgs& args) {
   // Allocate data for the output using in's layout to access them as
   // contiguously as possible.
   allocate_same_layout(out, in, axes);
@@ -230,12 +271,55 @@ void col_reduce_looped(
       auto kernel =
           cu::col_reduce_looped<T, U, OP, reduce_ndim(), BM, BN, N_READS>;
       encoder.add_kernel_node(
-          kernel, grid, blocks, 0, indata, out.data<U>(), args);
+          kernel,
+          grid,
+          blocks,
+          0,
+          indata,
+          out.data<U>(),
+          static_cast<cu::ColReduceArgs>(args));
     });
   });
 });
 }
+void col_reduce_small(
+    cu::CommandEncoder& encoder,
+    const array& in,
+    array& out,
+    Reduce::ReduceType reduce_type,
+    const std::vector<int>& axes,
+    const ReductionPlan& plan,
+    const cu::ColReduceArgs& args) {
+  // Allocate data for the output using in's layout to access them as
+  // contiguously as possible.
+  allocate_same_layout(out, in, axes);
+  encoder.set_input_array(in);
+  encoder.set_output_array(out);
+  dispatch_all_types(in.dtype(), [&](auto type_tag) {
+    dispatch_reduce_ops(reduce_type, [&](auto reduce_type_tag) {
+      using OP = MLX_GET_TYPE(reduce_type_tag);
+      using T = cuda_type_t<MLX_GET_TYPE(type_tag)>;
+      using U = typename cu::ReduceResult<OP, T>::type;
+      constexpr int N_READS = 16 / sizeof(T);
+      auto tmp_grid = get_2d_grid_dims(out.shape(), out.strides());
+      auto [grid, block] = get_grid_and_block(tmp_grid.x, tmp_grid.y, 1);
+      auto kernel = cu::col_reduce_small<T, U, OP, N_READS>;
+      encoder.add_kernel_node(
+          kernel,
+          grid,
+          block,
+          0,
+          in.data<T>(),
+          out.data<U>(),
+          static_cast<cu::ColReduceArgs>(args),
+          out.size());
+    });
+  });
+}
 void col_reduce(
     cu::CommandEncoder& encoder,
     const array& in,
@@ -258,6 +342,13 @@ void col_reduce(
   // Make the args struct to help route to the best kernel
   cu::ColReduceArgs args(in, plan, axes);
+  // Small col reduce with a single or contiguous reduction axis
+  if (args.non_col_reductions == 1 && args.reduction_size <= 32 &&
+      args.reduction_stride % (16 / in.itemsize()) == 0) {
+    col_reduce_small(encoder, in, out, reduce_type, axes, plan, args);
+    return;
+  }
+  // Fallback col reduce
   col_reduce_looped(encoder, in, out, reduce_type, axes, plan, args);
 }
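
To make the new dispatch condition concrete: N_READS is 16 / itemsize, so a float32 reduction takes the small path only when there is a single (non-column) reduction of at most 32 elements and the reduction stride is a multiple of 4, keeping the 16-byte vectorized loads aligned. Restated as a standalone predicate (illustrative only, not MLX code):

    // Eligibility test for the small column-reduce specialization, written
    // out for clarity. For float32, itemsize == 4, so n_reads == 4.
    bool use_small_col_reduce(
        int non_col_reductions,
        int reduction_size,
        int reduction_stride,
        int itemsize) {
      int n_reads = 16 / itemsize;  // elements per 16-byte vector load
      return non_col_reductions == 1 && reduction_size <= 32 &&
          reduction_stride % n_reads == 0;
    }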

View File

@@ -12,6 +12,7 @@ namespace mlx::core {
 namespace cu {
+class Device;
 }
 struct Dtype;
@@ -86,4 +87,17 @@ class CudaStream : public CudaHandle<cudaStream_t, cudaStreamDestroy> {
   explicit CudaStream(cu::Device& device);
 };
+template <typename T>
+inline uint max_occupancy_block_dim(T kernel) {
+  int _, block_dim;
+  if constexpr (std::is_same_v<T, CUfunction>) {
+    CHECK_CUDA_ERROR(
+        cuOccupancyMaxPotentialBlockSize(&_, &block_dim, kernel, 0, 0, 0));
+  } else {
+    CHECK_CUDA_ERROR(
+        cudaOccupancyMaxPotentialBlockSize(&_, &block_dim, kernel));
+  }
+  return block_dim;
+}
} // namespace mlx::core
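
The helper added above defers the block-size choice to CUDA's occupancy calculator instead of hard-coding 1024. A standalone sketch of the same query (the kernel here is a placeholder):

    #include <cuda_runtime.h>

    __global__ void example_kernel(float* out) { /* ... */ }

    int suggested_block_dim() {
      // cudaOccupancyMaxPotentialBlockSize returns the block size that
      // maximizes occupancy, plus the minimum grid size needed to saturate
      // the device; kernels with heavy register or shared-memory use can
      // get a value well below the 1024 default.
      int min_grid_size = 0, block_dim = 0;
      cudaOccupancyMaxPotentialBlockSize(
          &min_grid_size, &block_dim, example_kernel);
      return block_dim;
    }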

View File

@@ -5,9 +5,9 @@
 namespace mlx::core::cu {
-Worker::Worker()
-    : signal_stream_(device(mlx::core::Device::gpu)),
-      signal_event_(cudaEventDisableTiming | cudaEventBlockingSync),
+Worker::Worker(Device& d)
+    : signal_stream_(d),
+      signal_event_(d, cudaEventDisableTiming | cudaEventBlockingSync),
       worker_(&Worker::thread_fn, this) {}
 Worker::~Worker() {

View File

@@ -15,7 +15,7 @@ namespace mlx::core::cu {
 // Run tasks in worker thread, synchronized with cuda stream.
 class Worker {
  public:
-  Worker();
+  explicit Worker(Device& d);
   ~Worker();
   Worker(const Worker&) = delete;

View File

@@ -378,7 +378,7 @@ void Scatter::eval_gpu(const std::vector<array>& inputs, array& out) {
   }
   if (upd_ndim == 0) {
-    // Need placeholders so Metal doesn't compalain
+    // Need placeholders so Metal doesn't complain
     int shape_ = 0;
     int64_t stride_ = 0;
     compute_encoder.set_bytes(shape_, 3);
@@ -393,7 +393,7 @@ void Scatter::eval_gpu(const std::vector<array>& inputs, array& out) {
   // Set output info
   size_t out_ndim = out.ndim();
   if (out_ndim == 0) {
-    // Need placeholders so Metal doesn't compalain
+    // Need placeholders so Metal doesn't complain
     int shape_ = 0;
     int64_t stride_ = 0;
     compute_encoder.set_bytes(shape_, 7);

View File

@@ -296,6 +296,7 @@ class CompilerCache {
     std::vector<array> tape;
     bool empty{true};
     std::vector<uint64_t> constants;
+    std::shared_ptr<void> extra;
   };
   // Returns a reference to a CacheEntry which can be updated
@@ -376,8 +377,9 @@ CompilerCache& compiler_cache() {
   return compiler_cache_;
 }
-std::pair<std::vector<array>, std::vector<array>> compile_trace(
-    const std::function<std::vector<array>(const std::vector<array>&)>& fun,
+std::tuple<std::vector<array>, std::vector<array>, std::shared_ptr<void>>
+compile_trace(
+    const ArrayFnWithExtra& fun,
     const std::vector<array>& inputs,
     bool shapeless) {
   // Set the global tracing flag.
@@ -391,7 +393,9 @@ std::pair<std::vector<array>, std::vector<array>> compile_trace(
     in.set_tracer(true);
     tracer_inputs.push_back(std::move(in));
   }
-  return {tracer_inputs, fun(tracer_inputs)};
+  auto output = fun(tracer_inputs);
+  return {tracer_inputs, output.first, output.second};
 }
 // Traverses the graph to build a tape and a map of array ids to their parents
@@ -932,8 +936,8 @@ bool skip_compile() {
       !(compile_available_for_device(default_device()));
 }
-std::function<std::vector<array>(const std::vector<array>&)> compile(
-    std::function<std::vector<array>(const std::vector<array>&)> fun,
+ArrayFnWithExtra compile(
+    ArrayFnWithExtra fun,
     std::uintptr_t fun_id,
     bool shapeless /* = false */,
     std::vector<uint64_t> constants /* = {} */) {
@@ -966,7 +970,7 @@ std::function<std::vector<array>(const std::vector<array>&)> compile(
     // Set the constants
     entry.constants = std::move(constants);
     // Trace to build the graph
-    std::tie(entry.inputs, entry.outputs) =
+    std::tie(entry.inputs, entry.outputs, entry.extra) =
         compile_trace(fun, inputs, shapeless);
     // DFS the graph and get a tape, and a map of array id to (parent,
@@ -991,8 +995,37 @@ std::function<std::vector<array>(const std::vector<array>&)> compile(
     // At this point we must have a tape, now replace the placeholders
     // with real arrays that can be evaluated
-    return compile_replace(
-        entry.tape, entry.inputs, entry.outputs, inputs, shapeless);
+    return ArraysAndExtra{
+        compile_replace(
+            entry.tape, entry.inputs, entry.outputs, inputs, shapeless),
+        entry.extra};
   };
 }
+std::function<std::vector<array>(const std::vector<array>&)> compile(
+    std::function<std::vector<array>(const std::vector<array>&)> fun,
+    std::uintptr_t fun_id,
+    bool shapeless /* = false */,
+    std::vector<uint64_t> constants /* = {} */) {
+  if (skip_compile()) {
+    return fun;
+  }
+  if (!fun) {
+    throw std::invalid_argument(
+        "[compile] Cannot compile a function without a target.");
+  }
+  ArrayFnWithExtra fun_with_extra =
+      [fun = std::move(fun)](const std::vector<array>& inputs) {
+        return ArraysAndExtra{fun(inputs), nullptr};
+      };
+  auto compiled_fun = compile(
+      std::move(fun_with_extra), fun_id, shapeless, std::move(constants));
+  return [compiled_fun =
+              std::move(compiled_fun)](const std::vector<array>& inputs) {
+    return compiled_fun(inputs).first;
+  };
+}

View File

@@ -8,6 +8,10 @@
 namespace mlx::core::detail {
+using ArraysAndExtra = std::pair<std::vector<array>, std::shared_ptr<void>>;
+using ArrayFnWithExtra =
+    std::function<ArraysAndExtra(const std::vector<array>&)>;
 // This is not part of the general C++ API as calling with a bad id is a bad
 // idea.
 std::function<std::vector<array>(const std::vector<array>&)> compile(
@@ -16,6 +20,12 @@ std::function<std::vector<array>(const std::vector<array>&)> compile(
     bool shapeless = false,
     std::vector<uint64_t> constants = {});
+ArrayFnWithExtra compile(
+    ArrayFnWithExtra fun,
+    std::uintptr_t fun_id,
+    bool shapeless,
+    std::vector<uint64_t> constants);
 // Erase cached compile functions
 void compile_erase(std::uintptr_t fun_id);
@@ -25,8 +35,9 @@ void compile_clear_cache();
 bool compile_available_for_device(const Device& device);
-std::pair<std::vector<array>, std::vector<array>> compile_trace(
-    const std::function<std::vector<array>(const std::vector<array>&)>& fun,
+std::tuple<std::vector<array>, std::vector<array>, std::shared_ptr<void>>
+compile_trace(
+    const ArrayFnWithExtra& fun,
     const std::vector<array>& inputs,
     bool shapeless);
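
A minimal sketch of how a caller might use the new overload to attach per-trace data (the payload struct, fun_id, and header path are made up for illustration; python/src/transforms.cpp uses this mechanism to stash the Python output structure):

    #include "mlx/compile_impl.h"  // header path assumed

    using namespace mlx::core;

    // Hypothetical payload; compile() stores the shared_ptr in the cache
    // entry and hands the same pointer back on every subsequent call.
    struct TraceInfo {
      int num_outputs;
    };

    std::vector<array> run(const std::vector<array>& inputs) {
      detail::ArrayFnWithExtra fn =
          [](const std::vector<array>& ins) -> detail::ArraysAndExtra {
        std::vector<array> outs = ins;  // stand-in for real traced work
        return {
            outs, std::make_shared<TraceInfo>(TraceInfo{(int)outs.size()})};
      };
      auto compiled = detail::compile(
          std::move(fn), /* fun_id */ 0x1234, /* shapeless */ false, {});
      auto [outs, extra] = compiled(inputs);
      auto* info = static_cast<TraceInfo*>(extra.get());
      (void)info;  // e.g. recover the output structure before returning
      return outs;
    }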

View File

@@ -579,11 +579,11 @@ void FunctionExporter::export_function(const Args& args, const Kwargs& kwargs) {
     for (auto& k : kwarg_keys) {
       kwargs.insert({k, *it++});
     }
-    return fun(args, kwargs);
+    return detail::ArraysAndExtra{fun(args, kwargs), nullptr};
   };
   // Trace to build the graph
-  auto [trace_inputs, trace_outputs] =
+  auto [trace_inputs, trace_outputs, extra] =
       detail::compile_trace(flat_fun, inputs, ftable->shapeless);
   // DFS the graph and get the tape

View File

@@ -86,7 +86,9 @@ def cross_entropy(
     if targets_as_probs:
         score = mx.sum(logits * targets, axis=axis)
     else:
-        score = mx.take_along_axis(logits, targets[..., None], axis).squeeze(-1)
+        score = mx.take_along_axis(logits, mx.expand_dims(targets, axis), axis).squeeze(
+            axis
+        )
     logsumexp_logits = mx.logsumexp(logits, axis=axis)
     if label_smoothing > 0:

View File

@@ -971,10 +971,6 @@ def clip_grad_norm(grads, max_norm):
     """
     norm_squared = tree_reduce(lambda acc, g: acc + g.square().sum(), grads, 0.0)
     total_norm = mx.sqrt(norm_squared)
-    normalizer = max_norm / (total_norm + 1e-6)
-    def clipper(g):
-        return mx.where(total_norm < max_norm, g, g * normalizer)
-    clipped_grads = tree_map(clipper, grads)
+    normalizer = mx.minimum(max_norm / (total_norm + 1e-6), 1.0)
+    clipped_grads = tree_map(lambda g: g * normalizer, grads)
     return clipped_grads, total_norm
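
The rewritten clip_grad_norm is equivalent to the old branchy version (up to the 1e-6 guard): clamping the scale factor at 1 leaves small gradients untouched and rescales large ones,

    \hat{g} \;=\; g \cdot \min\!\left(1,\; \frac{c}{\lVert g \rVert + \epsilon}\right)
    \;=\;
    \begin{cases}
      g, & \lVert g \rVert < c \\[4pt]
      g \cdot \dfrac{c}{\lVert g \rVert + \epsilon}, & \text{otherwise}
    \end{cases}

where c is max_norm and ‖g‖ is the total norm. The single multiply applies uniformly to every leaf, replacing a per-leaf mx.where select, which is what the commit message above means by "faster grad clipping".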

View File

@@ -389,19 +389,22 @@ auto py_vmap(
   };
 }
-std::unordered_map<std::uintptr_t, nb::object>& tree_cache() {
-  // This map is used to Cache the tree structure of the outputs
-  static std::unordered_map<std::uintptr_t, nb::object> tree_cache_;
-  return tree_cache_;
-}
 struct PyCompiledFun {
   nb::callable fun;
   std::uintptr_t fun_id;
   nb::object captured_inputs;
   nb::object captured_outputs;
   bool shapeless;
-  mutable size_t num_outputs{0};
+  // Data to attach to the compiled function that contains the python output
+  // structure and the number of arrays in said structure.
+  struct AttachedData {
+    nb::object output_structure;
+    int num_outputs;
+    AttachedData(nb::object output_structure_, int num_outputs_)
+        : output_structure(output_structure_), num_outputs(num_outputs_) {}
+  };
   PyCompiledFun(
       const nb::callable& fun,
@@ -424,7 +427,6 @@ struct PyCompiledFun {
     captured_inputs = std::move(other.captured_inputs);
     captured_outputs = std::move(other.captured_outputs);
     shapeless = other.shapeless;
-    num_outputs = other.num_outputs;
   };
   nb::object call_impl(const nb::args& args, const nb::kwargs& kwargs) {
@@ -508,9 +510,9 @@ struct PyCompiledFun {
       auto [outputs, py_outputs] =
           tree_flatten_with_structure(std::move(tree_outputs), false);
-      tree_cache().insert({fun_id, py_outputs});
+      std::shared_ptr<void> extra_data =
+          std::make_shared<AttachedData>(py_outputs, outputs.size());
-      num_outputs = outputs.size();
       if (!captured_outputs.is_none()) {
         auto flat_out_captures = tree_flatten(captured_outputs, false);
         outputs.insert(
@@ -523,7 +525,7 @@ struct PyCompiledFun {
       if (!captured_inputs.is_none()) {
         tree_replace(captured_inputs, trace_captures, flat_in_captures);
       }
-      return outputs;
+      return mx::detail::ArraysAndExtra{outputs, extra_data};
     };
     if (!captured_inputs.is_none()) {
@@ -535,8 +537,14 @@ struct PyCompiledFun {
     }
     // Compile and call
-    auto outputs =
+    auto [outputs, extra_data] =
         mx::detail::compile(compile_fun, fun_id, shapeless, constants)(inputs);
+    int num_outputs =
+        reinterpret_cast<AttachedData*>(extra_data.get())->num_outputs;
+    nb::object py_outputs =
+        reinterpret_cast<AttachedData*>(extra_data.get())->output_structure;
     if (!captured_outputs.is_none()) {
       std::vector<mx::array> captures(
          std::make_move_iterator(outputs.begin() + num_outputs),
@@ -545,8 +553,7 @@ struct PyCompiledFun {
     }
     // Put the outputs back in the container
-    nb::object py_outputs = tree_cache().at(fun_id);
-    return tree_unflatten_from_structure(py_outputs, outputs);
+    return tree_unflatten_from_structure(std::move(py_outputs), outputs);
   }
   nb::object operator()(const nb::args& args, const nb::kwargs& kwargs) const {
@@ -556,7 +563,6 @@ struct PyCompiledFun {
   ~PyCompiledFun() {
     nb::gil_scoped_acquire gil;
-    tree_cache().erase(fun_id);
     mx::detail::compile_erase(fun_id);
     fun.reset();
     captured_inputs.reset();
@@ -1479,8 +1485,6 @@ void init_transforms(nb::module_& m) {
   // Register static Python object cleanup before the interpreter exits
   auto atexit = nb::module_::import_("atexit");
-  atexit.attr("register")(nb::cpp_function([]() {
-    tree_cache().clear();
-    mx::detail::compile_clear_cache();
-  }));
+  atexit.attr("register")(
+      nb::cpp_function([]() { mx::detail::compile_clear_cache(); }));
 }

View File

@@ -828,6 +828,19 @@ class TestCompile(mlx_tests.MLXTestCase):
         out = fun(arrs)
         self.assertTrue(mx.allclose(out, mx.array([64.0, 128.0])))
+        inputs = [mx.arange(16384).astype(mx.float16) for _ in range(8)]
+
+        def fun(inputs):
+            a = inputs[0] + inputs[1]
+            b = inputs[2] + inputs[3]
+            c = inputs[4] + inputs[5]
+            d = inputs[6] + inputs[7]
+            return a * b * c * d
+
+        out = mx.compile(fun)(inputs)
+        expected = fun(inputs)
+        self.assertTrue(mx.allclose(out, expected))
     def test_compile_many_outputs(self):
         @mx.compile
@@ -1051,6 +1064,57 @@ class TestCompile(mlx_tests.MLXTestCase):
         out = fun(mx.array(1.0), mx.array(2.0))
         self.assertEqual(out.item(), 3.0)
+    def test_compile_changing_outputs(self):
+        @mx.compile
+        def fun(x, y):
+            if y is None:
+                return 2 * x
+            elif (
+                isinstance(x, mx.array)
+                and isinstance(y, mx.array)
+                and x.dtype == y.dtype == mx.float32
+            ):
+                return [x + y]
+            elif y.dtype == mx.bool_:
+                return {"a": x, "b": y * x}
+            else:
+                return None
+
+        a = fun(mx.array(1.0), mx.array(2.0))
+        self.assertTrue(isinstance(a, list))
+        self.assertEqual(a[0].item(), 3.0)
+
+        b = fun(mx.array(1.0), mx.array(True))
+        self.assertTrue(isinstance(b, dict))
+        self.assertEqual(b["a"].item(), 1.0)
+        self.assertEqual(b["b"].item(), 1.0)
+
+        c = fun(mx.array(1.0), None)
+        self.assertTrue(isinstance(c, mx.array))
+        self.assertEqual(c.item(), 2.0)
+
+        d = fun(False, mx.array(1.0))
+        self.assertTrue(d is None)
+
+    def test_compile_changing_outputs_with_state(self):
+        state = [mx.array(1.0)]
+
+        @partial(mx.compile, inputs=state, outputs=state)
+        def fun(y):
+            x = state[0]
+            if y.dtype == mx.float32:
+                state[0] = 2 * y
+                return [x, y, x + y]
+            elif y.dtype == mx.int32:
+                state[0] *= 2
+                return x + y
+
+        for i in range(10):
+            fun(mx.array(1.0))
+            fun(mx.array(1))
+        self.assertEqual(state[0].item(), 4)
 if __name__ == "__main__":
     mlx_tests.MLXTestRunner()

View File

@@ -60,9 +60,19 @@ class TestLosses(mlx_tests.MLXTestCase):
         )
         self.assertTrue(mx.allclose(loss, expected))
-        probs = mx.array([[1.0, 0.0], [0.0, 1.0]])
+        # Test a different axis
+        logits = mx.random.normal((4, 8))
+        targets = mx.array([1, 2, 3, 0])
         loss = nn.losses.cross_entropy(
-            logits, probs, weights=weights, label_smoothing=0.3, reduction="none"
+            logits.T,
+            targets,
+            axis=0,
         )
-        targets = mx.array([1, 2, 3, 0])
+        expected = nn.losses.cross_entropy(
+            logits,
+            targets,
+            axis=-1,
+        )
         self.assertTrue(mx.allclose(loss, expected))