mirror of https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00

Compare commits: v0.29.2...e42e06046e (9 commits)

- e42e06046e
- 17432e7885
- e88f2d4a8e
- 9cee557423
- bbf1423953
- eb24267b56
- dc371ae7a5
- e76a8dd5c5
- b466dea982

@@ -173,7 +173,7 @@ if(MLX_BUILD_CPU)
     message(STATUS "Accelerate found ${ACCELERATE_LIBRARY}")
     set(MLX_BUILD_ACCELERATE ON)
   else()
-    message(STATUS "Accelerate or arm neon not found, using default backend.")
+    message(STATUS "Accelerate not found, using default backend.")
     set(MLX_BUILD_ACCELERATE OFF)
   endif()

@@ -9,7 +9,7 @@

 #include "mlx/backend/cpu/simd/base_simd.h"

-// There seems to be a bug in sims/base.h
+// There seems to be a bug in simd/base_simd.h
 // __XROS_2_0 is not defined, the expression evaluates
 // to true instead of false setting the SIMD library
 // higher than it should be even on macOS < 15

@@ -86,7 +86,7 @@ CudaAllocator::CudaAllocator()
   // TODO: Set memory limit for multi-device.
   size_t free, total;
   CHECK_CUDA_ERROR(cudaMemGetInfo(&free, &total));
-  memory_limit_ = total * 0.8;
+  memory_limit_ = total * 0.95;
   max_pool_size_ = memory_limit_;
 }

@@ -332,9 +332,9 @@ void Compiled::eval_gpu(
     encoder.set_output_array(out);
   }

-  auto kernel = mod.get_kernel(kernel_name);
+  auto [kernel, max_block_dims] = mod.get_kernel_and_dims(kernel_name);
   auto [num_blocks, block_dims] =
-      get_launch_args(outputs[0], large, work_per_thread);
+      get_launch_args(outputs[0], large, work_per_thread, max_block_dims);
   encoder.add_kernel_node(kernel, num_blocks, block_dims, 0, args.args());
 }

@@ -14,10 +14,6 @@ namespace mlx::core::cu {

 namespace {

-// Can be tuned with MLX_MAX_OPS_PER_BUFFER
-// This should be less than 255
-constexpr int default_max_nodes_per_graph = 20;
-
 #define CHECK_CUDNN_ERROR(cmd) check_cudnn_error(#cmd, (cmd))

 void check_cudnn_error(const char* name, cudnnStatus_t err) {

@@ -68,8 +64,8 @@ Device::~Device() {

 void Device::make_current() {
   // We need to set/get current CUDA device very frequently, cache it to reduce
-  // actual calls of CUDA APIs. This function assumes single-thread in host.
-  static int current = 0;
+  // actual calls of CUDA APIs.
+  static thread_local int current = 0;
   if (current != device_) {
     CHECK_CUDA_ERROR(cudaSetDevice(device_));
     current = device_;

@@ -95,6 +91,7 @@ CommandEncoder::CaptureContext::CaptureContext(CommandEncoder& enc) : enc(enc) {

 CommandEncoder::CaptureContext::~CaptureContext() {
   if (!use_cuda_graphs()) {
+    enc.node_count_++;
     return;
   }

@@ -196,6 +193,7 @@ CommandEncoder::CommandEncoder(Device& d)
     : device_(d),
       stream_(d),
       graph_(d),
+      worker_(d),
       graph_cache_("MLX_CUDA_GRAPH_CACHE_SIZE", /* default_capacity */ 400) {}

 void CommandEncoder::add_completed_handler(std::function<void()> task) {

@@ -220,12 +218,6 @@ void CommandEncoder::set_output_array(const array& arr) {
   active_outputs_.push_back(id);
 }

-void CommandEncoder::maybe_commit() {
-  if (node_count_ >= env::max_ops_per_buffer(default_max_nodes_per_graph)) {
-    commit();
-  }
-}
-
 void CommandEncoder::add_kernel_node(
     void* func,
     dim3 grid_dim,

@@ -233,6 +225,7 @@ void CommandEncoder::add_kernel_node(
     uint32_t smem_bytes,
     void** params) {
   if (!use_cuda_graphs()) {
+    node_count_++;
     CHECK_CUDA_ERROR(cudaLaunchKernel(
         func, grid_dim, block_dim, params, smem_bytes, stream()));
     return;

@@ -253,6 +246,7 @@ void CommandEncoder::add_kernel_node(
     uint32_t smem_bytes,
     void** params) {
   if (!use_cuda_graphs()) {
+    node_count_++;
     CHECK_CUDA_ERROR(cuLaunchKernel(
         func,
         grid_dim.x,

@@ -295,6 +289,7 @@ void CommandEncoder::add_kernel_node(const CUDA_KERNEL_NODE_PARAMS& params) {

 void CommandEncoder::add_graph_node(cudaGraph_t child) {
   if (!use_cuda_graphs()) {
+    node_count_++;
     CudaGraphExec graph_exec;
     graph_exec.instantiate(child);
     device_.make_current();

@@ -306,12 +301,16 @@ void CommandEncoder::add_graph_node(cudaGraph_t child) {
   insert_graph_dependencies(GraphNode{node, 'G'});
 }

+int CommandEncoder::get_num_ops() {
+  return node_count_;
+}
+
 void CommandEncoder::commit() {
   nvtx3::scoped_range r("CommandEncoder::commit");
   if (!temporaries_.empty()) {
     add_completed_handler([temporaries = std::move(temporaries_)]() {});
   }
-  if (node_count_ > 0) {
+  if (use_cuda_graphs() && node_count_ > 0) {
     if (!from_nodes_.empty()) {
       CHECK_CUDA_ERROR(cudaGraphAddDependencies(
           graph_,

@@ -354,7 +353,6 @@ void CommandEncoder::commit() {
     CHECK_CUDA_ERROR(cudaGraphLaunch(graph_exec, stream_));

     // Reset state
-    node_count_ = 0;
     graph_node_count_ = 0;
     empty_node_count_ = 0;
     from_nodes_.clear();

@@ -366,6 +364,7 @@ void CommandEncoder::commit() {

   // Put completion handlers in a batch.
   worker_.commit(stream_);
+  node_count_ = 0;
 }

 void CommandEncoder::synchronize() {

@@ -83,7 +83,7 @@ class CommandEncoder {
   }

   void add_completed_handler(std::function<void()> task);
-  void maybe_commit();
+  int get_num_ops();
   void commit();

   Device& device() {

@@ -140,7 +140,7 @@ class Device {
   Device(const Device&) = delete;
   Device& operator=(const Device&) = delete;

-  // Make this device the current cuda device, required by some cuda calls.
+  // Make this device the current cuda device, this method is thread-safe.
   void make_current();

   CommandEncoder& get_command_encoder(Stream s);

@@ -5,19 +5,24 @@
 #include "mlx/backend/cuda/device.h"
 #include "mlx/backend/gpu/available.h"
 #include "mlx/primitives.h"
+#include "mlx/scheduler.h"

 #include <nvtx3/nvtx3.hpp>

 namespace mlx::core::gpu {

+// Can be tuned with MLX_MAX_OPS_PER_BUFFER
+constexpr int default_max_nodes_per_graph = 20;
+
 bool is_available() {
   return true;
 }

 void new_stream(Stream s) {
-  // Force initalization of CUDA by creating an event, so the CUDA runtime and
-  // our CUDA event pool get destroyed last.
-  cu::CudaEvent(cudaEventDefault);
+  // Force initalization of CUDA, so CUDA runtime get destroyed at last.
+  cudaFree(nullptr);
+  // Make sure CUDA event pool get destroyed after device and stream.
+  cu::CudaEvent::init_pool();
   // Ensure the static stream objects get created.
   cu::get_command_encoder(s);
 }

@@ -35,7 +40,8 @@ void eval(array& arr) {
     arr.primitive().eval_gpu(arr.inputs(), outputs);
   }

-  auto& encoder = cu::get_command_encoder(arr.primitive().stream());
+  auto& stream = arr.primitive().stream();
+  auto& encoder = cu::get_command_encoder(stream);
   // Keep used buffers alive until kernel finishes running.
   for (auto& in : arr.inputs()) {
     // Except for the donated one.

@@ -46,7 +52,14 @@ void eval(array& arr) {
   for (auto& s : arr.siblings()) {
     encoder.add_temporary(s);
   }
-  encoder.maybe_commit();
+
+  if (encoder.get_num_ops() >=
+      env::max_ops_per_buffer(default_max_nodes_per_graph)) {
+    scheduler::notify_new_task(stream);
+    encoder.add_completed_handler(
+        [stream]() { scheduler::notify_task_completion(stream); });
+    encoder.commit();
+  }
 }

 void finalize(Stream s) {

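Note (not part of the diff): the threshold above defaults to 20 nodes per graph and, per the comment in the hunk, can be tuned through the MLX_MAX_OPS_PER_BUFFER environment variable. A minimal sketch, assuming the variable is read from the process environment via env::max_ops_per_buffer, so setting it before importing mlx is the safe way to have it picked up:

# Sketch only: raise the per-buffer op limit (the default shown in the diff is 20).
import os
os.environ["MLX_MAX_OPS_PER_BUFFER"] = "40"  # assumption: read by env::max_ops_per_buffer at runtime

import mlx.core as mx
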
@@ -22,11 +22,15 @@ namespace cu {
 namespace {

 // Manage cached cudaEvent_t objects.
-struct CudaEventPool {
-  static CudaEventHandle create(int flags) {
-    auto& cache = cache_for(flags);
+class CudaEventPool {
+ public:
+  CudaEventHandle create(Device& d, int flags) {
+    if (!on_creation_thread()) {
+      return CudaEventHandle(d, flags);
+    }
+    auto& cache = cache_for(d, flags);
     if (cache.empty()) {
-      return CudaEventHandle(flags);
+      return CudaEventHandle(d, flags);
     } else {
       CudaEventHandle ret = std::move(cache.back());
       cache.pop_back();

@@ -34,54 +38,89 @@ struct CudaEventPool {
     }
   }

-  static void release(CudaEventHandle event) {
-    cache_for(event.flags).push_back(std::move(event));
+  void release(CudaEventHandle event) {
+    if (!on_creation_thread()) {
+      // Event will be destroyed directly instead of getting moved to cache.
+      return;
+    }
+    cache_for(event.device, event.flags).push_back(std::move(event));
   }

-  static std::vector<CudaEventHandle>& cache_for(int flags) {
-    static std::map<int, std::vector<CudaEventHandle>> cache;
-    return cache[flags];
+ private:
+  std::vector<CudaEventHandle>& cache_for(Device& d, int flags) {
+    return cache_[d.cuda_device()][flags];
   }

+  bool on_creation_thread() {
+    return std::this_thread::get_id() == thread_id_;
+  }
+
+  // The CudaEvent may be created and destroyed on different threads (for
+  // example when waiting on GPU work in CPU stream), we don't want to make
+  // the cache thread-safe as it adds overhead, so we just skip cache when
+  // using events in worker threads.
+  std::thread::id thread_id_{std::this_thread::get_id()};
+
+  // {device: {flags: [events]}}
+  std::map<int, std::map<int, std::vector<CudaEventHandle>>> cache_;
 };

+CudaEventPool& cuda_event_pool() {
+  static CudaEventPool pool;
+  return pool;
+}
+
 } // namespace

-CudaEventHandle::CudaEventHandle(int flags) : flags(flags) {
+CudaEventHandle::CudaEventHandle(Device& d, int flags)
+    : device(d), flags(flags) {
+  device.make_current();
   CHECK_CUDA_ERROR(cudaEventCreateWithFlags(&handle_, flags));
   assert(handle_ != nullptr);
 }

-CudaEvent::CudaEvent(int flags) : event_(CudaEventPool::create(flags)) {}
+CudaEvent::CudaEvent(Device& d, int flags)
+    : event_(cuda_event_pool().create(d, flags)) {}

 CudaEvent::~CudaEvent() {
-  CudaEventPool::release(std::move(event_));
+  cuda_event_pool().release(std::move(event_));
 }

 void CudaEvent::wait() {
   nvtx3::scoped_range r("cu::CudaEvent::wait");
+  event_.device.make_current();
   cudaEventSynchronize(event_);
 }

 void CudaEvent::wait(cudaStream_t stream) {
+  event_.device.make_current();
   cudaStreamWaitEvent(stream, event_);
 }

 void CudaEvent::record(cudaStream_t stream) {
+  event_.device.make_current();
   cudaEventRecord(event_, stream);
 }

 bool CudaEvent::completed() const {
+  // Note: cudaEventQuery can be safely called from any device.
   return cudaEventQuery(event_) == cudaSuccess;
 }

+// static
+void CudaEvent::init_pool() {
+  cuda_event_pool();
+}
+
 // Wraps CudaEvent with a few features:
 // 1. The class can be copied.
 // 2. Make wait/record work with CPU streams.
 // 3. Add checks for waiting on un-recorded event.
 class CopyableCudaEvent {
  public:
-  CopyableCudaEvent()
+  explicit CopyableCudaEvent(Device& d)
       : event_(std::make_shared<CudaEvent>(
+            d,
             cudaEventDisableTiming | cudaEventBlockingSync)) {}

   void wait() {

@@ -245,7 +284,7 @@ struct EventImpl {
     nvtx3::mark("Using slow AtomicEvent");
     atomic = std::make_unique<cu::AtomicEvent>();
   } else {
-    cuda = std::make_unique<cu::CopyableCudaEvent>();
+    cuda = std::make_unique<cu::CopyableCudaEvent>(cu::device(s.device));
   }
 }
 };

@@ -13,9 +13,12 @@

 namespace mlx::core::cu {

+class Device;
+
 // RAII-managed move-only wrapper of cudaEvent_t.
 struct CudaEventHandle : public CudaHandle<cudaEvent_t, cudaEventDestroy> {
-  CudaEventHandle(int flags);
+  CudaEventHandle(Device& d, int flags);
+  Device& device;
   int flags;
 };

@@ -23,7 +26,7 @@ struct CudaEventHandle : public CudaHandle<cudaEvent_t, cudaEventDestroy> {
 // on GPU stream in CPU stream, but can not wait on CPU stream.
 class CudaEvent {
  public:
-  explicit CudaEvent(int flags);
+  CudaEvent(Device& d, int flags);
   ~CudaEvent();

   CudaEvent(CudaEvent&&) = default;

@@ -40,6 +43,9 @@ class CudaEvent {
   // returns true if record() has not been called.
   bool completed() const;

+  // Internal: make sure event pool is initialized.
+  static void init_pool();
+
  private:
   CudaEventHandle event_;
 };

@@ -297,7 +297,8 @@ void load_module(
     const std::string& ptx,
     const std::vector<std::pair<std::string, std::string>>& ptx_kernels,
     CUmodule& module_,
-    std::unordered_map<std::string, std::pair<CUfunction, bool>>& kernels) {
+    std::unordered_map<std::string, std::tuple<CUfunction, bool, uint>>&
+        kernels) {
   // Load module.
   char jit_log[4089] = {};
   CUjit_option options[] = {

@@ -314,7 +315,7 @@ void load_module(
   for (const auto& [name, mangled] : ptx_kernels) {
     CUfunction kernel;
     CHECK_CUDA_ERROR(cuModuleGetFunction(&kernel, module_, mangled.c_str()));
-    kernels[name] = std::make_pair(kernel, false);
+    kernels[name] = std::make_tuple(kernel, false, 0);
   }
 }

@@ -358,7 +359,7 @@ JitModule::~JitModule() {
   CHECK_CUDA_ERROR(cuModuleUnload(module_));
 }

-CUfunction JitModule::get_kernel(
+std::pair<CUfunction, uint> JitModule::get_kernel_and_dims(
     const std::string& kernel_name,
     std::function<void(CUfunction)> configure_kernel) {
   auto it = kernels_.find(kernel_name);

@@ -369,14 +370,22 @@ CUfunction JitModule::get_kernel(

   // If it is the first time we run this kernel then configure it. Do it only
   // once!
-  if (!it->second.second) {
+  auto kernel = std::get<0>(it->second);
+  if (!std::get<1>(it->second)) {
     if (configure_kernel) {
-      configure_kernel(it->second.first);
+      configure_kernel(kernel);
     }
-    it->second.second = true;
+    std::get<1>(it->second) = true;
+    std::get<2>(it->second) = max_occupancy_block_dim(kernel);
   }

-  return it->second.first;
+  return {kernel, std::get<2>(it->second)};
 }

+CUfunction JitModule::get_kernel(
+    const std::string& kernel_name,
+    std::function<void(CUfunction)> configure_kernel) {
+  return get_kernel_and_dims(kernel_name, std::move(configure_kernel)).first;
+}
+
 std::unordered_map<std::string, JitModule>& get_jit_module_cache() {

@@ -99,10 +99,13 @@ class JitModule {
   CUfunction get_kernel(
       const std::string& kernel_name,
       std::function<void(CUfunction)> configure_kernel = nullptr);
+  std::pair<CUfunction, uint> get_kernel_and_dims(
+      const std::string& kernel_name,
+      std::function<void(CUfunction)> configure_kernel = nullptr);

  private:
   CUmodule module_{nullptr};
-  std::unordered_map<std::string, std::pair<CUfunction, bool>> kernels_;
+  std::unordered_map<std::string, std::tuple<CUfunction, bool, uint>> kernels_;
 };

 std::unordered_map<std::string, JitModule>& get_jit_module_cache();

@@ -35,12 +35,10 @@ std::tuple<dim3, uint> get_launch_args(
     const Shape& shape,
     const Strides& strides,
     bool large,
-    int work_per_thread) {
+    int work_per_thread /* = 1 */,
+    uint max_block_dim /* = 1024 */) {
   size_t nthreads = cuda::ceil_div(size, work_per_thread);
-  uint block_dim = 1024;
-  if (block_dim > nthreads) {
-    block_dim = nthreads;
-  }
+  uint block_dim = max_block_dim < nthreads ? max_block_dim : nthreads;
   dim3 num_blocks;
   if (large) {
     num_blocks = get_2d_grid_dims(shape, strides, work_per_thread);

@@ -120,19 +120,28 @@ dim3 get_2d_grid_dims(
     size_t divisor);
 std::pair<dim3, dim3> get_grid_and_block(int dim0, int dim1, int dim2);

-// Get the num_blocks and block_dims that maximize occupancy for |kernel|,
-// assuming each thread handles |work_per_thread| elements of |arr|.
+// Get the num_blocks and block_dims assuming each thread handles
+// |work_per_thread| elements of |arr|.
 std::tuple<dim3, uint> get_launch_args(
     size_t size,
     const Shape& shape,
     const Strides& strides,
     bool large,
-    int work_per_thread = 1);
+    int work_per_thread = 1,
+    uint max_block_dim = 1024);

-inline std::tuple<dim3, uint>
-get_launch_args(const array& arr, bool large, int work_per_thread = 1) {
+inline std::tuple<dim3, uint> get_launch_args(
+    const array& arr,
+    bool large,
+    int work_per_thread = 1,
+    uint max_block_dim = 1024) {
   return get_launch_args(
-      arr.size(), arr.shape(), arr.strides(), large, work_per_thread);
+      arr.size(),
+      arr.shape(),
+      arr.strides(),
+      large,
+      work_per_thread,
+      max_block_dim);
 }

 } // namespace mlx::core

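For illustration (not part of the diff), the block-size selection performed by the updated get_launch_args can be sketched in Python, assuming ceil_div rounds up as in the CUDA helper:

# Sketch of the new launch-dimension arithmetic.
def launch_block_dim(size, work_per_thread=1, max_block_dim=1024):
    nthreads = (size + work_per_thread - 1) // work_per_thread  # ceil_div
    # Cap the block size by the occupancy-derived max_block_dim instead of
    # the previous hard-coded 1024.
    return min(max_block_dim, nthreads)

# e.g. launch_block_dim(1_000_000, work_per_thread=4, max_block_dim=768) -> 768
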
@@ -181,6 +181,47 @@ col_reduce_looped(T* in, U* out, const __grid_constant__ ColReduceArgs args) {
   }
 }

+template <typename T, typename U, typename Op, int N_READS = 4>
+__global__ void col_reduce_small(
+    const T* in,
+    U* out,
+    const __grid_constant__ ColReduceArgs args,
+    size_t total) {
+  Op op;
+  auto grid = cg::this_grid();
+  auto block = cg::this_thread_block();
+
+  const auto idx = grid.thread_rank() * N_READS;
+  const auto before_axis = idx / args.reduction_stride;
+  const auto after_axis = idx % args.reduction_stride;
+  const auto offset =
+      before_axis * args.reduction_stride * args.reduction_size + after_axis;
+
+  if (idx >= total) {
+    return;
+  }
+
+  in += offset;
+  out += idx;
+
+  AlignedVector<U, N_READS> accumulator;
+  for (int i = 0; i < N_READS; i++) {
+    accumulator[i] = ReduceInit<Op, T>::value();
+  }
+
+  for (int i = 0; i < args.reduction_size; i++) {
+    auto values = load_vector<N_READS>(in, 0);
+
+    for (int j = 0; j < N_READS; j++) {
+      accumulator[j] = op(accumulator[j], cast_to<U>(values[j]));
+    }
+
+    in += args.reduction_stride;
+  }
+
+  store_vector(out, 0, accumulator);
+}
+
 } // namespace cu

 inline auto output_grid_for_col_reduce(

@@ -236,6 +277,43 @@ void col_reduce_looped(
   });
 }

+void col_reduce_small(
+    cu::CommandEncoder& encoder,
+    const array& in,
+    array& out,
+    Reduce::ReduceType reduce_type,
+    const std::vector<int>& axes,
+    const ReductionPlan& plan,
+    cu::ColReduceArgs args) {
+  // Allocate data for the output using in's layout to access them as
+  // contiguously as possible.
+  allocate_same_layout(out, in, axes);
+
+  encoder.set_input_array(in);
+  encoder.set_output_array(out);
+  dispatch_all_types(in.dtype(), [&](auto type_tag) {
+    dispatch_reduce_ops(reduce_type, [&](auto reduce_type_tag) {
+      using OP = MLX_GET_TYPE(reduce_type_tag);
+      using T = cuda_type_t<MLX_GET_TYPE(type_tag)>;
+      using U = typename cu::ReduceResult<OP, T>::type;
+
+      constexpr int N_READS = 16 / sizeof(T);
+      auto tmp_grid = get_2d_grid_dims(out.shape(), out.strides());
+      auto [grid, block] = get_grid_and_block(tmp_grid.x, tmp_grid.y, 1);
+      auto kernel = cu::col_reduce_small<T, U, OP, N_READS>;
+      encoder.add_kernel_node(
+          kernel,
+          grid,
+          block,
+          0,
+          in.data<T>(),
+          out.data<U>(),
+          args,
+          out.size());
+    });
+  });
+}
+
 void col_reduce(
     cu::CommandEncoder& encoder,
     const array& in,

@@ -258,8 +336,16 @@ void col_reduce(
   // Make the args struct to help route to the best kernel
   cu::ColReduceArgs args(in, plan, axes);

+  // Small col reduce with a single or contiguous reduction axis
+  if (args.non_col_reductions == 1 && args.reduction_size <= 32 &&
+      args.reduction_stride % (16 / in.itemsize()) == 0) {
+    col_reduce_small(
+        encoder, in, out, reduce_type, axes, plan, std::move(args));
+    return;
+  }
+
   // Fallback col reduce
-  col_reduce_looped(encoder, in, out, reduce_type, axes, plan, args);
+  col_reduce_looped(encoder, in, out, reduce_type, axes, plan, std::move(args));
 }

 } // namespace mlx::core

@@ -12,6 +12,7 @@ namespace mlx::core {

 namespace cu {
+class Device;

 }

 struct Dtype;

@@ -86,4 +87,17 @@ class CudaStream : public CudaHandle<cudaStream_t, cudaStreamDestroy> {
   explicit CudaStream(cu::Device& device);
 };

+template <typename T>
+inline uint max_occupancy_block_dim(T kernel) {
+  int _, block_dim;
+  if constexpr (std::is_same_v<T, CUfunction>) {
+    CHECK_CUDA_ERROR(
+        cuOccupancyMaxPotentialBlockSize(&_, &block_dim, kernel, 0, 0, 0));
+  } else {
+    CHECK_CUDA_ERROR(
+        cudaOccupancyMaxPotentialBlockSize(&_, &block_dim, kernel));
+  }
+  return block_dim;
+}
+
 } // namespace mlx::core

@@ -5,9 +5,9 @@

 namespace mlx::core::cu {

-Worker::Worker()
-    : signal_stream_(device(mlx::core::Device::gpu)),
-      signal_event_(cudaEventDisableTiming | cudaEventBlockingSync),
+Worker::Worker(Device& d)
+    : signal_stream_(d),
+      signal_event_(d, cudaEventDisableTiming | cudaEventBlockingSync),
       worker_(&Worker::thread_fn, this) {}

 Worker::~Worker() {

@@ -15,7 +15,7 @@ namespace mlx::core::cu {
 // Run tasks in worker thread, synchronized with cuda stream.
 class Worker {
  public:
-  Worker();
+  explicit Worker(Device& d);
   ~Worker();

   Worker(const Worker&) = delete;

@@ -378,7 +378,7 @@ void Scatter::eval_gpu(const std::vector<array>& inputs, array& out) {
   }

   if (upd_ndim == 0) {
-    // Need placeholders so Metal doesn't compalain
+    // Need placeholders so Metal doesn't complain
     int shape_ = 0;
     int64_t stride_ = 0;
     compute_encoder.set_bytes(shape_, 3);

@@ -393,7 +393,7 @@ void Scatter::eval_gpu(const std::vector<array>& inputs, array& out) {
   // Set output info
   size_t out_ndim = out.ndim();
   if (out_ndim == 0) {
-    // Need placeholders so Metal doesn't compalain
+    // Need placeholders so Metal doesn't complain
     int shape_ = 0;
     int64_t stride_ = 0;
     compute_encoder.set_bytes(shape_, 7);

@@ -296,6 +296,7 @@ class CompilerCache {
     std::vector<array> tape;
     bool empty{true};
     std::vector<uint64_t> constants;
+    std::shared_ptr<void> extra;
   };

   // Returns a reference to a CacheEntry which can be updated

@@ -376,8 +377,9 @@ CompilerCache& compiler_cache() {
   return compiler_cache_;
 }

-std::pair<std::vector<array>, std::vector<array>> compile_trace(
-    const std::function<std::vector<array>(const std::vector<array>&)>& fun,
+std::tuple<std::vector<array>, std::vector<array>, std::shared_ptr<void>>
+compile_trace(
+    const ArrayFnWithExtra& fun,
     const std::vector<array>& inputs,
     bool shapeless) {
   // Set the global tracing flag.

@@ -391,7 +393,9 @@ std::pair<std::vector<array>, std::vector<array>> compile_trace(
     in.set_tracer(true);
     tracer_inputs.push_back(std::move(in));
   }
-  return {tracer_inputs, fun(tracer_inputs)};
+
+  auto output = fun(tracer_inputs);
+  return {tracer_inputs, output.first, output.second};
 }

 // Traverses the graph to build a tape and a map of array ids to their parents

@@ -932,8 +936,8 @@ bool skip_compile() {
       !(compile_available_for_device(default_device()));
 }

-std::function<std::vector<array>(const std::vector<array>&)> compile(
-    std::function<std::vector<array>(const std::vector<array>&)> fun,
+ArrayFnWithExtra compile(
+    ArrayFnWithExtra fun,
     std::uintptr_t fun_id,
     bool shapeless /* = false */,
     std::vector<uint64_t> constants /* = {} */) {

@@ -966,7 +970,7 @@ std::function<std::vector<array>(const std::vector<array>&)> compile(
     // Set the constants
     entry.constants = std::move(constants);
     // Trace to build the graph
-    std::tie(entry.inputs, entry.outputs) =
+    std::tie(entry.inputs, entry.outputs, entry.extra) =
        compile_trace(fun, inputs, shapeless);

     // DFS the graph and get a tape, and a map of array id to (parent,

@@ -991,8 +995,37 @@ std::function<std::vector<array>(const std::vector<array>&)> compile(

     // At this point we must have a tape, now replace the placeholders
     // with real arrays that can be evaluated
-    return compile_replace(
-        entry.tape, entry.inputs, entry.outputs, inputs, shapeless);
+    return ArraysAndExtra{
+        compile_replace(
+            entry.tape, entry.inputs, entry.outputs, inputs, shapeless),
+        entry.extra};
   };
 }

+std::function<std::vector<array>(const std::vector<array>&)> compile(
+    std::function<std::vector<array>(const std::vector<array>&)> fun,
+    std::uintptr_t fun_id,
+    bool shapeless /* = false */,
+    std::vector<uint64_t> constants /* = {} */) {
+  if (skip_compile()) {
+    return fun;
+  }
+  if (!fun) {
+    throw std::invalid_argument(
+        "[compile] Cannot compile a function without a target.");
+  }
+
+  ArrayFnWithExtra fun_with_extra =
+      [fun = std::move(fun)](const std::vector<array>& inputs) {
+        return ArraysAndExtra{fun(inputs), nullptr};
+      };
+
+  auto compiled_fun = compile(
+      std::move(fun_with_extra), fun_id, shapeless, std::move(constants));
+
+  return [compiled_fun =
+              std::move(compiled_fun)](const std::vector<array>& inputs) {
+    return compiled_fun(inputs).first;
+  };
+}
+

@@ -8,6 +8,10 @@

 namespace mlx::core::detail {

+using ArraysAndExtra = std::pair<std::vector<array>, std::shared_ptr<void>>;
+using ArrayFnWithExtra =
+    std::function<ArraysAndExtra(const std::vector<array>&)>;
+
 // This is not part of the general C++ API as calling with a bad id is a bad
 // idea.
 std::function<std::vector<array>(const std::vector<array>&)> compile(

@@ -16,6 +20,12 @@ std::function<std::vector<array>(const std::vector<array>&)> compile(
     bool shapeless = false,
     std::vector<uint64_t> constants = {});

+ArrayFnWithExtra compile(
+    ArrayFnWithExtra fun,
+    std::uintptr_t fun_id,
+    bool shapeless,
+    std::vector<uint64_t> constants);
+
 // Erase cached compile functions
 void compile_erase(std::uintptr_t fun_id);

@@ -25,8 +35,9 @@ void compile_clear_cache();

 bool compile_available_for_device(const Device& device);

-std::pair<std::vector<array>, std::vector<array>> compile_trace(
-    const std::function<std::vector<array>(const std::vector<array>&)>& fun,
+std::tuple<std::vector<array>, std::vector<array>, std::shared_ptr<void>>
+compile_trace(
+    const ArrayFnWithExtra& fun,
     const std::vector<array>& inputs,
     bool shapeless);

@@ -579,11 +579,11 @@ void FunctionExporter::export_function(const Args& args, const Kwargs& kwargs) {
     for (auto& k : kwarg_keys) {
       kwargs.insert({k, *it++});
     }
-    return fun(args, kwargs);
+    return detail::ArraysAndExtra{fun(args, kwargs), nullptr};
   };

   // Trace to build the graph
-  auto [trace_inputs, trace_outputs] =
+  auto [trace_inputs, trace_outputs, extra] =
       detail::compile_trace(flat_fun, inputs, ftable->shapeless);

   // DFS the graph and get the tape

@@ -86,7 +86,9 @@ def cross_entropy(
     if targets_as_probs:
         score = mx.sum(logits * targets, axis=axis)
     else:
-        score = mx.take_along_axis(logits, targets[..., None], axis).squeeze(-1)
+        score = mx.take_along_axis(logits, mx.expand_dims(targets, axis), axis).squeeze(
+            axis
+        )

     logsumexp_logits = mx.logsumexp(logits, axis=axis)
     if label_smoothing > 0:

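A usage sketch (not part of the diff) of the behavior this change enables, mirroring the loss test further down in this compare: the class axis of the logits no longer has to be the last one.

import mlx.core as mx
import mlx.nn as nn

logits = mx.random.normal((4, 8))
targets = mx.array([1, 2, 3, 0])
# Classes along axis 0 (transposed logits) should match the default last-axis result.
loss_axis0 = nn.losses.cross_entropy(logits.T, targets, axis=0)
loss_last = nn.losses.cross_entropy(logits, targets, axis=-1)
assert mx.allclose(loss_axis0, loss_last)
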
@@ -971,10 +971,6 @@ def clip_grad_norm(grads, max_norm):
     """
    norm_squared = tree_reduce(lambda acc, g: acc + g.square().sum(), grads, 0.0)
    total_norm = mx.sqrt(norm_squared)
-    normalizer = max_norm / (total_norm + 1e-6)
-
-    def clipper(g):
-        return mx.where(total_norm < max_norm, g, g * normalizer)
-
-    clipped_grads = tree_map(clipper, grads)
+    normalizer = mx.minimum(max_norm / (total_norm + 1e-6), 1.0)
+    clipped_grads = tree_map(lambda g: g * normalizer, grads)
    return clipped_grads, total_norm

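A minimal usage sketch of the helper shown above, assuming it is exposed as mlx.optimizers.clip_grad_norm: with the new normalizer every leaf is scaled by min(max_norm / total_norm, 1), so gradients whose norm is already below max_norm pass through essentially unchanged.

import mlx.core as mx
import mlx.optimizers as optim

# Gradients with global norm 5.0 (a 3-4-5 triangle); clipping to 1.0 scales
# every leaf by roughly 1/5.
grads = {"w": mx.array([3.0, 4.0]), "b": mx.array([0.0])}
clipped, total_norm = optim.clip_grad_norm(grads, max_norm=1.0)
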
@@ -389,19 +389,22 @@ auto py_vmap(
   };
 }

-std::unordered_map<std::uintptr_t, nb::object>& tree_cache() {
-  // This map is used to Cache the tree structure of the outputs
-  static std::unordered_map<std::uintptr_t, nb::object> tree_cache_;
-  return tree_cache_;
-}
-
 struct PyCompiledFun {
   nb::callable fun;
   std::uintptr_t fun_id;
   nb::object captured_inputs;
   nb::object captured_outputs;
   bool shapeless;
-  mutable size_t num_outputs{0};
+
+  // Data to attach to the compiled function that contains the python output
+  // structure and the number of arrays in said structure.
+  struct AttachedData {
+    nb::object output_structure;
+    int num_outputs;
+
+    AttachedData(nb::object output_structure_, int num_outputs_)
+        : output_structure(output_structure_), num_outputs(num_outputs_) {}
+  };

   PyCompiledFun(
       const nb::callable& fun,

@@ -424,7 +427,6 @@ struct PyCompiledFun {
     captured_inputs = std::move(other.captured_inputs);
     captured_outputs = std::move(other.captured_outputs);
     shapeless = other.shapeless;
-    num_outputs = other.num_outputs;
   };

   nb::object call_impl(const nb::args& args, const nb::kwargs& kwargs) {

@@ -508,9 +510,9 @@ struct PyCompiledFun {
       auto [outputs, py_outputs] =
           tree_flatten_with_structure(std::move(tree_outputs), false);

-      tree_cache().insert({fun_id, py_outputs});
+      std::shared_ptr<void> extra_data =
+          std::make_shared<AttachedData>(py_outputs, outputs.size());

-      num_outputs = outputs.size();
       if (!captured_outputs.is_none()) {
         auto flat_out_captures = tree_flatten(captured_outputs, false);
         outputs.insert(

@@ -523,7 +525,7 @@ struct PyCompiledFun {
       if (!captured_inputs.is_none()) {
         tree_replace(captured_inputs, trace_captures, flat_in_captures);
       }
-      return outputs;
+      return mx::detail::ArraysAndExtra{outputs, extra_data};
     };

     if (!captured_inputs.is_none()) {

@@ -535,8 +537,14 @@ struct PyCompiledFun {
     }

     // Compile and call
-    auto outputs =
+    auto [outputs, extra_data] =
         mx::detail::compile(compile_fun, fun_id, shapeless, constants)(inputs);
+
+    int num_outputs =
+        reinterpret_cast<AttachedData*>(extra_data.get())->num_outputs;
+    nb::object py_outputs =
+        reinterpret_cast<AttachedData*>(extra_data.get())->output_structure;
+
     if (!captured_outputs.is_none()) {
       std::vector<mx::array> captures(
           std::make_move_iterator(outputs.begin() + num_outputs),

@@ -545,8 +553,7 @@ struct PyCompiledFun {
     }

     // Put the outputs back in the container
-    nb::object py_outputs = tree_cache().at(fun_id);
-    return tree_unflatten_from_structure(py_outputs, outputs);
+    return tree_unflatten_from_structure(std::move(py_outputs), outputs);
   }

   nb::object operator()(const nb::args& args, const nb::kwargs& kwargs) const {

@@ -556,7 +563,6 @@ struct PyCompiledFun {
   ~PyCompiledFun() {
     nb::gil_scoped_acquire gil;

-    tree_cache().erase(fun_id);
     mx::detail::compile_erase(fun_id);
     fun.reset();
     captured_inputs.reset();

@@ -1479,8 +1485,6 @@ void init_transforms(nb::module_& m) {

   // Register static Python object cleanup before the interpreter exits
   auto atexit = nb::module_::import_("atexit");
-  atexit.attr("register")(nb::cpp_function([]() {
-    tree_cache().clear();
-    mx::detail::compile_clear_cache();
-  }));
+  atexit.attr("register")(
+      nb::cpp_function([]() { mx::detail::compile_clear_cache(); }));
 }

@@ -828,6 +828,19 @@ class TestCompile(mlx_tests.MLXTestCase):
         out = fun(arrs)
         self.assertTrue(mx.allclose(out, mx.array([64.0, 128.0])))

+        inputs = [mx.arange(16384).astype(mx.float16) for _ in range(8)]
+
+        def fun(inputs):
+            a = inputs[0] + inputs[1]
+            b = inputs[2] + inputs[3]
+            c = inputs[4] + inputs[5]
+            d = inputs[6] + inputs[7]
+            return a * b * c * d
+
+        out = mx.compile(fun)(inputs)
+        expected = fun(inputs)
+        self.assertTrue(mx.allclose(out, expected))
+
     def test_compile_many_outputs(self):

         @mx.compile

@@ -1051,6 +1064,57 @@ class TestCompile(mlx_tests.MLXTestCase):
         out = fun(mx.array(1.0), mx.array(2.0))
         self.assertEqual(out.item(), 3.0)

+    def test_compile_changing_outputs(self):
+        @mx.compile
+        def fun(x, y):
+            if y is None:
+                return 2 * x
+            elif (
+                isinstance(x, mx.array)
+                and isinstance(y, mx.array)
+                and x.dtype == y.dtype == mx.float32
+            ):
+                return [x + y]
+            elif y.dtype == mx.bool_:
+                return {"a": x, "b": y * x}
+            else:
+                return None
+
+        a = fun(mx.array(1.0), mx.array(2.0))
+        self.assertTrue(isinstance(a, list))
+        self.assertEqual(a[0].item(), 3.0)
+
+        b = fun(mx.array(1.0), mx.array(True))
+        self.assertTrue(isinstance(b, dict))
+        self.assertEqual(b["a"].item(), 1.0)
+        self.assertEqual(b["b"].item(), 1.0)
+
+        c = fun(mx.array(1.0), None)
+        self.assertTrue(isinstance(c, mx.array))
+        self.assertEqual(c.item(), 2.0)
+
+        d = fun(False, mx.array(1.0))
+        self.assertTrue(d is None)
+
+    def test_compile_changing_outputs_with_state(self):
+        state = [mx.array(1.0)]
+
+        @partial(mx.compile, inputs=state, outputs=state)
+        def fun(y):
+            x = state[0]
+            if y.dtype == mx.float32:
+                state[0] = 2 * y
+                return [x, y, x + y]
+            elif y.dtype == mx.int32:
+                state[0] *= 2
+                return x + y
+
+        for i in range(10):
+            fun(mx.array(1.0))
+            fun(mx.array(1))
+
+        self.assertEqual(state[0].item(), 4)
+

 if __name__ == "__main__":
     mlx_tests.MLXTestRunner()

@@ -60,9 +60,19 @@ class TestLosses(mlx_tests.MLXTestCase):
         )
         self.assertTrue(mx.allclose(loss, expected))

-        probs = mx.array([[1.0, 0.0], [0.0, 1.0]])
+        # Test a different axis
+        logits = mx.random.normal((4, 8))
+        targets = mx.array([1, 2, 3, 0])
         loss = nn.losses.cross_entropy(
-            logits, probs, weights=weights, label_smoothing=0.3, reduction="none"
+            logits.T,
+            targets,
+            axis=0,
         )
-        targets = mx.array([1, 2, 3, 0])
+        expected = nn.losses.cross_entropy(
+            logits,
+            targets,
+            axis=-1,
+        )
         self.assertTrue(mx.allclose(loss, expected))