diff --git a/benchmarks/python/layer_norm_bench.py b/benchmarks/python/layer_norm_bench.py index 8b9635315..69263835a 100644 --- a/benchmarks/python/layer_norm_bench.py +++ b/benchmarks/python/layer_norm_bench.py @@ -10,7 +10,12 @@ def layer_norm(x, w, b, eps): x = x.astype(mx.float32) mu = mx.mean(x, -1, keepdims=True) v = mx.var(x, -1, keepdims=True) - return (x - mu) * mx.rsqrt(v + eps) * w + b + y = (x - mu) * mx.rsqrt(v + eps) + if w is not None: + y = y * w + if b is not None: + y = y + b + return y def time_layer_norm(): @@ -36,6 +41,28 @@ def time_layer_norm(): time_fn(layer_norm_loop, mx.compile(g1), x, w, b) time_fn(layer_norm_loop, mx.compile(g2), x, w, b) + f1 = lambda x, y: (layer_norm(x, None, None, 1e-5) * y).sum() + f2 = lambda x, y: (mx.fast.layer_norm(x, None, None, 1e-5) * y).sum() + g1 = mx.grad(f1, argnums=(0,)) + g2 = mx.grad(f2, argnums=(0,)) + + x = mx.random.uniform(shape=(8, 1024, 4096)).astype(mx.float16) + w = mx.random.uniform(shape=(4096,)).astype(mx.float16) + b = mx.random.uniform(shape=(4096,)).astype(mx.float16) + y = mx.random.uniform(shape=(8, 1024, 4096)).astype(mx.float16) + mx.eval(x, w, b, y) + + def layer_norm_loop(g, x): + gx = x + for _ in range(32): + gx = g(gx, y) + return gx + + time_fn(layer_norm_loop, g1, x) + time_fn(layer_norm_loop, g2, x) + time_fn(layer_norm_loop, mx.compile(g1), x) + time_fn(layer_norm_loop, mx.compile(g2), x) + if __name__ == "__main__": time_layer_norm() diff --git a/benchmarks/python/rms_norm_bench.py b/benchmarks/python/rms_norm_bench.py index a54dfe697..50f3a40b0 100644 --- a/benchmarks/python/rms_norm_bench.py +++ b/benchmarks/python/rms_norm_bench.py @@ -9,7 +9,10 @@ def rms_norm(x, w, eps): ot = x.dtype x = x.astype(mx.float32) n = mx.rsqrt(x.square().mean(-1, keepdims=True) + eps) - return (x * n).astype(ot) * w + y = (x * n).astype(ot) + if w is not None: + y = y * w + return y def time_rms_norm(): @@ -34,6 +37,27 @@ def time_rms_norm(): time_fn(rms_norm_loop, mx.compile(g1), x, w) time_fn(rms_norm_loop, mx.compile(g2), x, w) + f1 = lambda x, y: (rms_norm(x, None, 1e-5) * y).sum() + f2 = lambda x, y: (mx.fast.rms_norm(x, None, 1e-5) * y).sum() + g1 = mx.grad(f1, argnums=(0,)) + g2 = mx.grad(f2, argnums=(0,)) + + x = mx.random.uniform(shape=(8, 1024, 4096)).astype(mx.float16) + w = mx.random.uniform(shape=(4096,)).astype(mx.float16) + y = mx.random.uniform(shape=(8, 1024, 4096)).astype(mx.float16) + mx.eval(x, w, y) + + def rms_norm_loop(g, x): + gx = x + for _ in range(32): + gx = g(gx, y) + return gx + + time_fn(rms_norm_loop, g1, x) + time_fn(rms_norm_loop, g2, x) + time_fn(rms_norm_loop, mx.compile(g1), x) + time_fn(rms_norm_loop, mx.compile(g2), x) + if __name__ == "__main__": time_rms_norm() diff --git a/mlx/backend/common/slicing.cpp b/mlx/backend/common/slicing.cpp index c446ff948..93b9d480e 100644 --- a/mlx/backend/common/slicing.cpp +++ b/mlx/backend/common/slicing.cpp @@ -14,6 +14,10 @@ std::tuple prepare_slice( data_offset += start_indices[i] * in.strides()[i]; inp_strides[i] = in.strides()[i] * strides[i]; } + // Normalize the offset + if (data_offset < 0) { + data_offset += in.data_size(); + } return std::make_tuple(data_offset, inp_strides); } @@ -54,9 +58,10 @@ void slice( data_end += end_idx * in.strides()[i]; } } - // data_end can be -1 - size_t data_size = - data_end < 0 ? 
(data_offset - data_end) : (data_end - data_offset); + if (data_end < 0) { + data_end += in.data_size(); + } + size_t data_size = (data_end - data_offset); shared_buffer_slice(in, inp_strides, data_offset, data_size, out); } diff --git a/mlx/backend/cpu/quantized.cpp b/mlx/backend/cpu/quantized.cpp index 38ca9f371..da7788971 100644 --- a/mlx/backend/cpu/quantized.cpp +++ b/mlx/backend/cpu/quantized.cpp @@ -543,8 +543,8 @@ void quantize( T* scales = scales_.data(); T* biases = biases_.data(); - T n_bins = (1 << bits) - 1; - T eps = 1e-7; + float n_bins = (1 << bits) - 1; + float eps = 1e-7; bool power_of_2_bits = is_power_of_2(bits); int el_per_int = bits == 3 ? 8 : bits == 6 ? 4 : 32 / bits; // For 3/6 bits we read 3 uint8s at a time instead of 1 uint32 @@ -554,32 +554,30 @@ void quantize( for (size_t i = 0; i < n_groups; ++i) { size_t w_idx = i * group_size; - T w_min = std::numeric_limits::infinity(); - T w_max = -w_min; + float w_min = std::numeric_limits::infinity(); + float w_max = -w_min; for (int j = 0; j < group_size; ++j) { - w_max = std::max(w_max, w[w_idx + j]); - w_min = std::min(w_min, w[w_idx + j]); + w_max = std::max(w_max, (float)w[w_idx + j]); + w_min = std::min(w_min, (float)w[w_idx + j]); } bool mask = std::abs(w_min) > std::abs(w_max); - T scale = std::max(T((w_max - w_min) / n_bins), eps); + float scale = std::max((w_max - w_min) / n_bins, eps); scale = mask ? scale : -scale; - auto edge = mask ? w_min : w_max; - auto q0 = std::rint(edge / scale); - if (q0 == 0) { - scales[i] = scale; - biases[i] = 0; - } else { - scales[i] = edge / q0; - biases[i] = edge; + float edge = mask ? w_min : w_max; + float q0 = std::rint(edge / scale); + float bias = 0; + if (q0 != 0) { + scale = edge / q0; + bias = edge; } size_t out_idx = i * int_per_group; for (int j = 0; j < int_per_group / bytes_per_pack; ++j) { uint32_t out_el = 0; for (int k = 0; k < el_per_int; ++k) { - T w_el = w[w_idx + j * el_per_int + k]; - w_el = std::rint((w_el - biases[i]) / scales[i]); - w_el = std::min(std::max(w_el, T(0)), n_bins); + float w_el = w[w_idx + j * el_per_int + k]; + w_el = std::rint((w_el - bias) / scale); + w_el = std::min(std::max(w_el, 0.0f), n_bins); out_el |= static_cast(w_el) << (k * bits); } if (power_of_2_bits) { @@ -590,6 +588,8 @@ void quantize( out[out_idx + bytes_per_pack * j + 2] = (out_el & 0xff0000) >> 16; } } + scales[i] = static_cast(scale); + biases[i] = static_cast(bias); } } diff --git a/mlx/backend/cpu/svd.cpp b/mlx/backend/cpu/svd.cpp index 33a30d843..88b127bce 100644 --- a/mlx/backend/cpu/svd.cpp +++ b/mlx/backend/cpu/svd.cpp @@ -8,7 +8,7 @@ namespace mlx::core { template -void svd_impl(const array& a, array& u, array& s, array& vt) { +void svd_impl(const array& a, T* u_data, T* s_data, T* vt_data) { // Lapack uses the column-major convention. To avoid having to transpose // the input and then transpose the outputs, we swap the indices/sizes of the // matrices and take advantage of the following identity (see @@ -35,13 +35,8 @@ void svd_impl(const array& a, array& u, array& s, array& vt) { array in(a.shape(), a.dtype(), nullptr, {}); copy(a, in, a.flags().row_contiguous ? CopyType::Vector : CopyType::General); - // Allocate outputs. - u.set_data(allocator::malloc_or_wait(u.nbytes())); - s.set_data(allocator::malloc_or_wait(s.nbytes())); - vt.set_data(allocator::malloc_or_wait(vt.nbytes())); - - static constexpr auto job_u = "V"; - static constexpr auto job_vt = "V"; + auto job_u = (u_data && vt_data) ? "V" : "N"; + auto job_vt = (u_data && vt_data) ? 
"V" : "N"; static constexpr auto range = "A"; // Will contain the number of singular values after the call has returned. @@ -56,6 +51,7 @@ void svd_impl(const array& a, array& u, array& s, array& vt) { static const int ignored_int = 0; static const T ignored_float = 0; + static T ignored_output = 0; int info; @@ -109,12 +105,12 @@ void svd_impl(const array& a, array& u, array& s, array& vt) { /* il = */ &ignored_int, /* iu = */ &ignored_int, /* ns = */ &ns, - /* s = */ s.data() + K * i, + /* s = */ s_data + K * i, // According to the identity above, lapack will write Vᵀᵀ as U. - /* u = */ vt.data() + N * N * i, + /* u = */ vt_data ? vt_data + N * N * i : &ignored_output, /* ldu = */ &ldu, // According to the identity above, lapack will write Uᵀ as Vᵀ. - /* vt = */ u.data() + M * M * i, + /* vt = */ u_data ? u_data + M * M * i : &ignored_output, /* ldvt = */ &ldvt, /* work = */ static_cast(scratch.buffer.raw_ptr()), /* lwork = */ &lwork, @@ -136,15 +132,36 @@ void svd_impl(const array& a, array& u, array& s, array& vt) { } } +template +void compute_svd(const array& a, bool compute_uv, std::vector& outputs) { + if (compute_uv) { + array& u = outputs[0]; + array& s = outputs[1]; + array& vt = outputs[2]; + + u.set_data(allocator::malloc_or_wait(u.nbytes())); + s.set_data(allocator::malloc_or_wait(s.nbytes())); + vt.set_data(allocator::malloc_or_wait(vt.nbytes())); + + svd_impl(a, u.data(), s.data(), vt.data()); + } else { + array& s = outputs[0]; + + s.set_data(allocator::malloc_or_wait(s.nbytes())); + + svd_impl(a, nullptr, s.data(), nullptr); + } +} + void SVD::eval_cpu( const std::vector& inputs, std::vector& outputs) { switch (inputs[0].dtype()) { case float32: - svd_impl(inputs[0], outputs[0], outputs[1], outputs[2]); + compute_svd(inputs[0], compute_uv_, outputs); break; case float64: - svd_impl(inputs[0], outputs[0], outputs[1], outputs[2]); + compute_svd(inputs[0], compute_uv_, outputs); break; default: throw std::runtime_error( diff --git a/mlx/backend/metal/allocator.cpp b/mlx/backend/metal/allocator.cpp index 7d34fc815..f2c95be20 100644 --- a/mlx/backend/metal/allocator.cpp +++ b/mlx/backend/metal/allocator.cpp @@ -10,6 +10,9 @@ namespace mlx::core { +constexpr size_t resource_options = + MTL::ResourceStorageModeShared | MTL::ResourceHazardTrackingModeUntracked; + namespace allocator { Allocator& allocator() { @@ -150,15 +153,34 @@ MetalAllocator::MetalAllocator() : device_(device(mlx::core::Device::gpu).mtl_device()), residency_set_(device_), buffer_cache_(device_) { - auto memsize = std::get(device_info()["memory_size"]); + auto pool = metal::new_scoped_memory_pool(); + auto memsize = std::get(device_info().at("memory_size")); auto max_rec_size = - std::get(device_info()["max_recommended_working_set_size"]); - resource_limit_ = std::get(device_info()["resource_limit"]); + std::get(device_info().at("max_recommended_working_set_size")); + resource_limit_ = std::get(device_info().at("resource_limit")); block_limit_ = std::min(1.5 * max_rec_size, 0.95 * memsize); gc_limit_ = std::min(static_cast(0.95 * max_rec_size), block_limit_); max_pool_size_ = block_limit_; device(mlx::core::Device::gpu) .set_residency_set(residency_set_.mtl_residency_set()); + bool is_vm = std::get(device_info().at("device_name")) == + "Apple Paravirtual device"; + if (is_vm) { + return; + } + auto heap_desc = MTL::HeapDescriptor::alloc()->init(); + heap_desc->setResourceOptions(resource_options); + heap_desc->setSize(heap_size_); + heap_ = device_->newHeap(heap_desc); + heap_desc->release(); + 
residency_set_.insert(heap_); +} + +MetalAllocator::~MetalAllocator() { + auto pool = metal::new_scoped_memory_pool(); + if (heap_) { + heap_->release(); + } } size_t MetalAllocator::set_cache_limit(size_t limit) { @@ -226,8 +248,6 @@ Buffer MetalAllocator::malloc(size_t size, bool allow_swap /* = false */) { } // Allocate new buffer if needed - size_t res_opt = MTL::ResourceStorageModeShared; - res_opt |= MTL::ResourceHazardTrackingModeUntracked; if (num_resources_ >= resource_limit_) { std::ostringstream msg; msg << "[metal::malloc] Resource limit (" << resource_limit_ @@ -235,7 +255,12 @@ Buffer MetalAllocator::malloc(size_t size, bool allow_swap /* = false */) { throw std::runtime_error(msg.str()); } lk.unlock(); - buf = device_->newBuffer(size, res_opt); + if (size < small_size_ && heap_) { + buf = heap_->newBuffer(size, resource_options); + } + if (!buf) { + buf = device_->newBuffer(size, resource_options); + } lk.lock(); if (buf) { num_resources_++; @@ -246,13 +271,15 @@ Buffer MetalAllocator::malloc(size_t size, bool allow_swap /* = false */) { peak_memory_ = std::max(peak_memory_, active_memory_); // Maintain the cache below the requested limit - if (get_cache_memory() >= max_pool_size_) { + if (get_cache_memory() > max_pool_size_) { auto pool = metal::new_scoped_memory_pool(); num_resources_ -= buffer_cache_.release_cached_buffers( get_cache_memory() - max_pool_size_); } - residency_set_.insert(buf); + if (!buf->heap()) { + residency_set_.insert(buf); + } return Buffer{static_cast(buf)}; } @@ -269,7 +296,9 @@ void MetalAllocator::free(Buffer buffer) { return; } std::unique_lock lk(mutex_); - residency_set_.erase(buf); + if (!buf->heap()) { + residency_set_.erase(buf); + } active_memory_ -= buf->length(); if (get_cache_memory() < max_pool_size_) { buffer_cache_.recycle_to_cache(buf); @@ -301,7 +330,7 @@ size_t set_memory_limit(size_t limit, bool relaxed /* = true */) { } size_t set_wired_limit(size_t limit) { if (limit > - std::get(device_info()["max_recommended_working_set_size"])) { + std::get(device_info().at("max_recommended_working_set_size"))) { throw std::invalid_argument( "[metal::set_wired_limit] Setting a wired limit larger than " "the maximum working set size is not allowed."); diff --git a/mlx/backend/metal/allocator.h b/mlx/backend/metal/allocator.h index 4e662ab56..df301f55e 100644 --- a/mlx/backend/metal/allocator.h +++ b/mlx/backend/metal/allocator.h @@ -43,6 +43,7 @@ class BufferCache { void remove_from_list(BufferHolder* to_remove); MTL::Device* device_; + MTL::Heap* heap_{nullptr}; std::multimap buffer_pool_; BufferHolder* head_; @@ -78,7 +79,15 @@ class MetalAllocator : public allocator::Allocator { private: MTL::Device* device_; + + // The size of allocations which go on the heap until it is full. This size + // is chosen because it is the actual minimum size of a buffer allocated from + // the heap, a heap can have at most heap.size() / 256 buffers. 
+ static constexpr int small_size_ = 256; + static constexpr int heap_size_ = 1 << 20; + MTL::Heap* heap_; MetalAllocator(); + ~MetalAllocator(); friend MetalAllocator& allocator(); // Caching allocator diff --git a/mlx/backend/metal/device.cpp b/mlx/backend/metal/device.cpp index 7a82e2fb3..06681c458 100644 --- a/mlx/backend/metal/device.cpp +++ b/mlx/backend/metal/device.cpp @@ -692,12 +692,13 @@ void new_stream(Stream stream) { } } -std::unordered_map> +const std::unordered_map>& device_info() { auto init_device_info = []() -> std::unordered_map> { auto pool = new_scoped_memory_pool(); auto raw_device = device(default_device()).mtl_device(); + auto name = std::string(raw_device->name()->utf8String()); auto arch = std::string(raw_device->architecture()->name()->utf8String()); size_t memsize = 0; @@ -711,6 +712,7 @@ device_info() { } return { + {"device_name", name}, {"architecture", arch}, {"max_buffer_length", raw_device->maxBufferLength()}, {"max_recommended_working_set_size", diff --git a/mlx/backend/metal/kernels/layer_norm.metal b/mlx/backend/metal/kernels/layer_norm.metal index 462eb3b94..4674a4228 100644 --- a/mlx/backend/metal/kernels/layer_norm.metal +++ b/mlx/backend/metal/kernels/layer_norm.metal @@ -7,6 +7,8 @@ using namespace metal; +constant bool has_w [[function_constant(20)]]; + template [[kernel]] void layer_norm_single_row( const device T* x, @@ -327,7 +329,9 @@ template gx[i] = static_cast( normalizer * (thread_w[i] * thread_g[i] - meanwg) - thread_x[i] * meanwgxc * normalizer2); - gw[i] = static_cast(thread_g[i] * thread_x[i]); + if (has_w) { + gw[i] = static_cast(thread_g[i] * thread_x[i]); + } } } else { for (int i = 0; i < N_READS; i++) { @@ -336,7 +340,9 @@ template gx[i] = static_cast( normalizer * (thread_w[i] * thread_g[i] - meanwg) - thread_x[i] * meanwgxc * normalizer2); - gw[i] = static_cast(thread_g[i] * thread_x[i]); + if (has_w) { + gw[i] = static_cast(thread_g[i] * thread_x[i]); + } } } } @@ -465,7 +471,9 @@ template float gi = g[i + r]; gx[i + r] = static_cast( normalizer * (wi * gi - meanwg) - xi * meanwgxc * normalizer2); - gw[i + r] = static_cast(gi * xi); + if (has_w) { + gw[i + r] = static_cast(gi * xi); + } } } else { for (int i = 0; i < N_READS; i++) { @@ -475,7 +483,9 @@ template float gi = g[i + r]; gx[i + r] = static_cast( normalizer * (wi * gi - meanwg) - xi * meanwgxc * normalizer2); - gw[i + r] = static_cast(gi * xi); + if (has_w) { + gw[i + r] = static_cast(gi * xi); + } } } } diff --git a/mlx/backend/metal/kernels/quantized.h b/mlx/backend/metal/kernels/quantized.h index 1652207e3..3af3c971f 100644 --- a/mlx/backend/metal/kernels/quantized.h +++ b/mlx/backend/metal/kernels/quantized.h @@ -2015,9 +2015,9 @@ template device T* biases [[buffer(3)]], uint2 index [[thread_position_in_grid]], uint2 grid_dim [[threads_per_grid]]) { - constexpr T eps = T(1e-7); + constexpr float eps = 1e-7; constexpr int simd_size = 32; - constexpr T n_bins = (1 << bits) - 1; + constexpr float n_bins = (1 << bits) - 1; constexpr int packs_per_int = bits == 3 ? 8 : bits == 6 ? 4 : 8 / bits; constexpr int values_per_reduce = group_size / simd_size; constexpr int writes_per_reduce = packs_per_int / values_per_reduce; @@ -2036,13 +2036,13 @@ template ? 
offset * writes_per_pack : offset * bytes_per_pack / writes_per_reduce; - T w_thread[values_per_reduce]; - T w_min = Limits::max; - T w_max = 0; + float w_thread[values_per_reduce]; + float w_min = Limits::max; + float w_max = 0; #pragma clang loop unroll(full) for (int i = 0; i < values_per_reduce; i++) { - T val = w[in_index + i]; + float val = w[in_index + i]; w_thread[i] = val; w_min = min(w_min, val); w_max = max(w_max, val); @@ -2051,20 +2051,20 @@ template w_min = simd_min(w_min); w_max = simd_max(w_max); - T scale = max((w_max - w_min) / n_bins, eps); + float scale = max((w_max - w_min) / n_bins, eps); bool side = abs(w_min) > abs(w_max); scale = side ? scale : -scale; - T edge = side ? w_min : w_max; - T q0 = round(edge / scale); + float edge = side ? w_min : w_max; + float q0 = round(edge / scale); bool at_zero = q0 == 0.0f; scale = at_zero ? scale : edge / q0; - T bias = at_zero ? T(0) : edge; + float bias = at_zero ? 0 : edge; // Write out the scales and biases size_t gindex = in_index / group_size; if (in_index % group_size == 0) { - scales[gindex] = scale; - biases[gindex] = bias; + scales[gindex] = static_cast(scale); + biases[gindex] = static_cast(bias); } // We accumulate 3 bytes worth for 3/6 bit so we need a uint32_t diff --git a/mlx/backend/metal/kernels/rms_norm.metal b/mlx/backend/metal/kernels/rms_norm.metal index f8fb53dd5..f4c1536de 100644 --- a/mlx/backend/metal/kernels/rms_norm.metal +++ b/mlx/backend/metal/kernels/rms_norm.metal @@ -7,6 +7,8 @@ using namespace metal; +constant bool has_w [[function_constant(20)]]; + template [[kernel]] void rms_single_row( const device T* x, @@ -243,7 +245,9 @@ template gx[i] = static_cast( thread_g[i] * thread_w[i] * normalizer - thread_x[i] * meangwx * normalizer3); - gw[i] = static_cast(thread_g[i] * thread_x[i] * normalizer); + if (has_w) { + gw[i] = static_cast(thread_g[i] * thread_x[i] * normalizer); + } } } else { for (int i = 0; i < N_READS; i++) { @@ -251,7 +255,9 @@ template gx[i] = static_cast( thread_g[i] * thread_w[i] * normalizer - thread_x[i] * meangwx * normalizer3); - gw[i] = static_cast(thread_g[i] * thread_x[i] * normalizer); + if (has_w) { + gw[i] = static_cast(thread_g[i] * thread_x[i] * normalizer); + } } } } @@ -351,7 +357,9 @@ template gx[i + r] = static_cast(gi * wi * normalizer - xi * meangwx * normalizer3); - gw[i + r] = static_cast(gi * xi * normalizer); + if (has_w) { + gw[i + r] = static_cast(gi * xi * normalizer); + } } } else { for (int i = 0; i < N_READS; i++) { @@ -362,7 +370,9 @@ template gx[i + r] = static_cast(gi * wi * normalizer - xi * meangwx * normalizer3); - gw[i + r] = static_cast(gi * xi * normalizer); + if (has_w) { + gw[i + r] = static_cast(gi * xi * normalizer); + } } } } diff --git a/mlx/backend/metal/kernels/sdpa_vector.h b/mlx/backend/metal/kernels/sdpa_vector.h index f5fe88f4a..1c3d23fc4 100644 --- a/mlx/backend/metal/kernels/sdpa_vector.h +++ b/mlx/backend/metal/kernels/sdpa_vector.h @@ -5,6 +5,7 @@ using namespace metal; constant bool has_mask [[function_constant(20)]]; +constant bool query_transposed [[function_constant(21)]]; template [[kernel]] void sdpa_vector( @@ -18,9 +19,11 @@ template const constant size_t& v_stride, const constant float& scale, const device bool* mask [[function_constant(has_mask)]], - const constant int& mask_seq_stride [[function_constant(has_mask)]], + const constant int& mask_kv_seq_stride [[function_constant(has_mask)]], + const constant int& mask_q_seq_stride [[function_constant(has_mask)]], const constant int& mask_head_stride 
[[function_constant(has_mask)]], uint3 tid [[threadgroup_position_in_grid]], + uint3 tpg [[threadgroups_per_grid]], uint simd_gid [[simdgroup_index_in_threadgroup]], uint simd_lid [[thread_index_in_simdgroup]]) { constexpr int BN = 32; @@ -41,15 +44,21 @@ template threadgroup U sum_exp_scores[BN]; // Adjust positions - const int head_idx = tid.y; + const int head_idx = tid.x; + const int q_seq_idx = tid.y; const int kv_head_idx = head_idx / gqa_factor; - queries += head_idx * D + simd_lid * qk_per_thread; + const int o_offset = tpg.x * q_seq_idx + head_idx; + const int q_offset = + query_transposed ? o_offset : head_idx * tpg.y + q_seq_idx; + queries += q_offset * D + simd_lid * qk_per_thread; keys += kv_head_idx * k_stride + simd_gid * D + simd_lid * qk_per_thread; values += kv_head_idx * v_stride + simd_gid * V + simd_lid * v_per_thread; if (has_mask) { - mask += head_idx * mask_head_stride + simd_gid * mask_seq_stride; + mask += head_idx * mask_head_stride + simd_gid * mask_kv_seq_stride + + q_seq_idx * mask_q_seq_stride; } - out += head_idx * V + simd_gid * v_per_thread; + + out += o_offset * V + simd_gid * v_per_thread; // Read the query and 0 the output accumulator for (int i = 0; i < qk_per_thread; i++) { @@ -95,7 +104,7 @@ template keys += inner_k_stride; values += inner_v_stride; if (has_mask) { - mask += BN * mask_seq_stride; + mask += BN * mask_kv_seq_stride; } } @@ -142,9 +151,11 @@ template const constant size_t& v_stride, const constant float& scale, const device bool* mask [[function_constant(has_mask)]], - const constant int& mask_seq_stride [[function_constant(has_mask)]], + const constant int& mask_kv_seq_stride [[function_constant(has_mask)]], + const constant int& mask_q_seq_stride [[function_constant(has_mask)]], const constant int& mask_head_stride [[function_constant(has_mask)]], uint3 tid [[threadgroup_position_in_grid]], + uint3 tpg [[threadgroups_per_grid]], uint simd_gid [[simdgroup_index_in_threadgroup]], uint simd_lid [[thread_index_in_simdgroup]]) { constexpr int BN = 8; @@ -167,20 +178,26 @@ template // Adjust positions const int block_idx = tid.z; - const int head_idx = tid.y; + const int head_idx = tid.x; + const int q_seq_idx = tid.y; + const int o_offset = tpg.x * q_seq_idx + head_idx; + const int q_offset = + query_transposed ? 
o_offset : head_idx * tpg.y + q_seq_idx; const int kv_head_idx = head_idx / gqa_factor; - queries += head_idx * D + simd_lid * qk_per_thread; + + queries += q_offset * D + simd_lid * qk_per_thread; keys += kv_head_idx * k_stride + (block_idx * BN + simd_gid) * D + simd_lid * qk_per_thread; values += kv_head_idx * v_stride + (block_idx * BN + simd_gid) * V + simd_lid * v_per_thread; - out += head_idx * blocks * V + block_idx * V + simd_lid * v_per_thread; + out += o_offset * blocks * V + block_idx * V + simd_lid * v_per_thread; if (has_mask) { mask += head_idx * mask_head_stride + - (block_idx * BN + simd_gid) * mask_seq_stride; + (block_idx * BN + simd_gid) * mask_kv_seq_stride + + q_seq_idx * mask_q_seq_stride; } - sums += head_idx * blocks + block_idx; - maxs += head_idx * blocks + block_idx; + sums += o_offset * blocks + block_idx; + maxs += o_offset * blocks + block_idx; // Read the query and 0 the output accumulator for (int i = 0; i < qk_per_thread; i++) { @@ -226,7 +243,7 @@ template keys += blocks * inner_k_stride; values += blocks * inner_v_stride; if (has_mask) { - mask += BN * blocks * mask_seq_stride; + mask += BN * blocks * mask_kv_seq_stride; } } @@ -275,6 +292,7 @@ template const device float* maxs [[buffer(2)]], device T* out [[buffer(3)]], uint3 tid [[threadgroup_position_in_grid]], + uint3 tpg [[threadgroups_per_grid]], uint simd_gid [[simdgroup_index_in_threadgroup]], uint simd_lid [[thread_index_in_simdgroup]]) { constexpr int BN = 32; @@ -288,11 +306,14 @@ template threadgroup U outputs[BN * BD]; // Adjust positions - const int head_idx = tid.y; - partials += head_idx * blocks * D + simd_gid * D + simd_lid * elem_per_thread; - sums += head_idx * blocks; - maxs += head_idx * blocks; - out += head_idx * D + simd_gid * elem_per_thread; + const int head_idx = tid.x; + const int q_seq_idx = tid.y; + const int n_heads = tpg.x; + const int q_offset = n_heads * q_seq_idx + head_idx; + partials += q_offset * blocks * D + simd_gid * D + simd_lid * elem_per_thread; + sums += q_offset * blocks; + maxs += q_offset * blocks; + out += q_offset * D + simd_gid * elem_per_thread; // First everybody reads the max and sum_exp U max_score = maxs[simd_lid]; diff --git a/mlx/backend/metal/metal.h b/mlx/backend/metal/metal.h index e5cb65afd..d5c64f79d 100644 --- a/mlx/backend/metal/metal.h +++ b/mlx/backend/metal/metal.h @@ -82,7 +82,7 @@ void start_capture(std::string path = ""); void stop_capture(); /** Get information about the GPU and system settings. */ -std::unordered_map> +const std::unordered_map>& device_info(); } // namespace mlx::core::metal diff --git a/mlx/backend/metal/normalization.cpp b/mlx/backend/metal/normalization.cpp index 7fa7e8646..3f5216392 100644 --- a/mlx/backend/metal/normalization.cpp +++ b/mlx/backend/metal/normalization.cpp @@ -77,7 +77,7 @@ void RMSNorm::eval_gpu( group_dims = MTL::Size(threadgroup_size, 1, 1); } - uint32_t w_stride = w.strides()[0]; + uint32_t w_stride = (w.ndim() == 1) ? w.strides()[0] : 0; compute_encoder.set_compute_pipeline_state(kernel); compute_encoder.set_input_array( x.data_shared_ptr() == nullptr ? out : x, 0); @@ -101,20 +101,16 @@ void RMSNormVJP::eval_gpu( // Ensure row contiguity. We could relax this step by checking that the array // is contiguous (no broadcasts or holes) and that the input strides are the // same as the cotangent strides but for now this is simpler. 
- std::vector copies; - auto check_input = [&copies, &s](const array& x) -> const array& { + auto check_input = [&d, &s](const array& x) -> array { if (x.flags().row_contiguous) { return x; } - // Make sure we 'll only ever allocate once. The point of that goes beyond - // the minor optimization. We need to ensure that there will be no - // reallocation such that the references won't change when we - // push_back(...). So tl;dr 3 possible copies x, g and gw_temp. - copies.reserve(3); - copies.push_back(array(x.shape(), x.dtype(), nullptr, {})); - copy_gpu(x, copies.back(), CopyType::General, s); - return copies.back(); + array x_copy(x.shape(), x.dtype(), nullptr, {}); + copy_gpu(x, x_copy, CopyType::General, s); + d.add_temporary(x_copy, s.index); + + return x_copy; }; const array& x = check_input(inputs[0]); const array& w = inputs[1]; @@ -122,6 +118,9 @@ void RMSNormVJP::eval_gpu( array& gx = outputs[0]; array& gw = outputs[1]; + // Check whether we had a weight + bool has_w = w.ndim() != 0; + // Allocate space for the outputs bool x_in_gx = false; bool g_in_gx = false; @@ -140,15 +139,18 @@ void RMSNormVJP::eval_gpu( // Allocate the gradient accumulator gw and a temporary to store the // gradients before they are accumulated. - array gw_temp({n_rows, x.shape().back()}, gw.dtype(), nullptr, {}); + array gw_temp = + (has_w) ? array({n_rows, x.shape().back()}, gw.dtype(), nullptr, {}) : w; bool g_in_gw = false; - if (!g_in_gx && g.is_donatable()) { - gw_temp.move_shared_buffer(g); - g_in_gw = true; - } else { - gw_temp.set_data(allocator::malloc_or_wait(gw_temp.nbytes())); + if (has_w) { + if (!g_in_gx && g.is_donatable()) { + gw_temp.move_shared_buffer(g); + g_in_gw = true; + } else { + gw_temp.set_data(allocator::malloc_or_wait(gw_temp.nbytes())); + d.add_temporary(gw_temp, s.index); + } } - copies.push_back(gw_temp); gw.set_data(allocator::malloc_or_wait(gw.nbytes())); const int simd_size = 32; @@ -159,9 +161,15 @@ void RMSNormVJP::eval_gpu( op_name += "_looped"; } op_name += type_to_name(gx); + + std::string hash_name = op_name + ((has_w) ? "_w" : "_now"); + metal::MTLFCList func_consts = { + {&has_w, MTL::DataType::DataTypeBool, 20}, + }; + auto& compute_encoder = d.get_command_encoder(s.index); { - auto kernel = d.get_kernel(op_name); + auto kernel = d.get_kernel(op_name, "mlx", hash_name, func_consts); MTL::Size grid_dims, group_dims; if (axis_size <= looped_limit) { @@ -179,7 +187,7 @@ void RMSNormVJP::eval_gpu( group_dims = MTL::Size(threadgroup_size, 1, 1); } - uint32_t w_stride = w.strides()[0]; + uint32_t w_stride = (w.ndim() == 1) ? w.strides()[0] : 0; compute_encoder.set_compute_pipeline_state(kernel); compute_encoder.set_input_array(x_in_gx ? gx : x, 0); compute_encoder.set_input_array(w, 1); @@ -192,12 +200,12 @@ void RMSNormVJP::eval_gpu( compute_encoder.dispatch_threads(grid_dims, group_dims); } - ReductionPlan plan( - ReductionOpType::ContiguousStridedReduce, {n_rows}, {axis_size}); - strided_reduce_general_dispatch( - gw_temp, gw, "sum", plan, {0}, compute_encoder, d, s); - - d.add_temporaries(std::move(copies), s.index); + if (has_w) { + ReductionPlan plan( + ReductionOpType::ContiguousStridedReduce, {n_rows}, {axis_size}); + strided_reduce_general_dispatch( + gw_temp, gw, "sum", plan, {0}, compute_encoder, d, s); + } } void LayerNorm::eval_gpu( @@ -295,20 +303,16 @@ void LayerNormVJP::eval_gpu( // Ensure row contiguity. 
We could relax this step by checking that the array // is contiguous (no broadcasts or holes) and that the input strides are the // same as the cotangent strides but for now this is simpler. - std::vector copies; - auto check_input = [&copies, &s](const array& x) -> const array& { + auto check_input = [&d, &s](const array& x) -> array { if (x.flags().row_contiguous) { return x; } - // Make sure we 'll only ever allocate once. The point of that goes beyond - // the minor optimization. We need to ensure that there will be no - // reallocation such that the references won't change when we - // push_back(...). So tl;dr 3 possible copies x, g and gw_temp. - copies.reserve(3); - copies.push_back(array(x.shape(), x.dtype(), nullptr, {})); - copy_gpu(x, copies.back(), CopyType::General, s); - return copies.back(); + array x_copy(x.shape(), x.dtype(), nullptr, {}); + copy_gpu(x, x_copy, CopyType::General, s); + d.add_temporary(x_copy, s.index); + + return x_copy; }; const array& x = check_input(inputs[0]); const array& w = inputs[1]; @@ -318,6 +322,9 @@ void LayerNormVJP::eval_gpu( array& gw = outputs[1]; array& gb = outputs[2]; + // Check whether we had a weight + bool has_w = w.ndim() != 0; + // Allocate space for the outputs bool x_in_gx = false; bool g_in_gx = false; @@ -336,15 +343,18 @@ void LayerNormVJP::eval_gpu( // Allocate a temporary to store the gradients for w and allocate the output // gradient accumulators. - array gw_temp({n_rows, x.shape().back()}, gw.dtype(), nullptr, {}); + array gw_temp = + (has_w) ? array({n_rows, x.shape().back()}, gw.dtype(), nullptr, {}) : w; bool g_in_gw = false; - if (!g_in_gx && g.is_donatable()) { - gw_temp.move_shared_buffer(g); - g_in_gw = true; - } else { - gw_temp.set_data(allocator::malloc_or_wait(gw_temp.nbytes())); + if (has_w) { + if (!g_in_gx && g.is_donatable()) { + gw_temp.move_shared_buffer(g); + g_in_gw = true; + } else { + gw_temp.set_data(allocator::malloc_or_wait(gw_temp.nbytes())); + } + d.add_temporary(gw_temp, s.index); } - copies.push_back(gw_temp); gw.set_data(allocator::malloc_or_wait(gw.nbytes())); gb.set_data(allocator::malloc_or_wait(gb.nbytes())); @@ -372,8 +382,14 @@ void LayerNormVJP::eval_gpu( op_name += "_looped"; } op_name += type_to_name(gx); + + std::string hash_name = op_name + ((has_w) ? 
"_w" : "_now"); + metal::MTLFCList func_consts = { + {&has_w, MTL::DataType::DataTypeBool, 20}, + }; + { - auto kernel = d.get_kernel(op_name); + auto kernel = d.get_kernel(op_name, "mlx", hash_name, func_consts); MTL::Size grid_dims, group_dims; if (axis_size <= looped_limit) { @@ -404,14 +420,12 @@ void LayerNormVJP::eval_gpu( compute_encoder.dispatch_threads(grid_dims, group_dims); } - if (gw.ndim() == 1 && gw.size() == axis_size) { + if (has_w) { ReductionPlan plan( ReductionOpType::ContiguousStridedReduce, {n_rows}, {axis_size}); strided_reduce_general_dispatch( gw_temp, gw, "sum", plan, {0}, compute_encoder, d, s); } - - d.add_temporaries(std::move(copies), s.index); } } // namespace mlx::core::fast diff --git a/mlx/backend/metal/rope.cpp b/mlx/backend/metal/rope.cpp index 1ca3597e5..c6da9278a 100644 --- a/mlx/backend/metal/rope.cpp +++ b/mlx/backend/metal/rope.cpp @@ -25,6 +25,10 @@ void RoPE::eval_gpu( size_t out_strides[3]; bool donated = false; int ndim = in.ndim(); + int dispatch_ndim = in.ndim(); + while (in.shape(-dispatch_ndim) == 1 && dispatch_ndim > 3) { + dispatch_ndim--; + } size_t mat_size = in.shape(-2) * in.shape(-1); if (dims_ < in.shape(-1)) { donated = true; @@ -44,12 +48,12 @@ void RoPE::eval_gpu( strides[0] = mat_size; strides[1] = in.strides()[ndim - 2]; strides[2] = in.strides()[ndim - 1]; - } else if (ndim == 3) { + } else if (dispatch_ndim == 3) { // Handle non-contiguous 3D inputs out.set_data(allocator::malloc_or_wait(out.nbytes())); - strides[0] = in.strides()[0]; - strides[1] = in.strides()[1]; - strides[2] = in.strides()[2]; + strides[0] = in.strides()[ndim - 3]; + strides[1] = in.strides()[ndim - 2]; + strides[2] = in.strides()[ndim - 1]; } else { // Copy non-contiguous > 3D inputs into the output and treat // input as donated diff --git a/mlx/backend/metal/scaled_dot_product_attention.cpp b/mlx/backend/metal/scaled_dot_product_attention.cpp index 47bf7f22a..a349fd031 100644 --- a/mlx/backend/metal/scaled_dot_product_attention.cpp +++ b/mlx/backend/metal/scaled_dot_product_attention.cpp @@ -134,14 +134,17 @@ void sdpa_vector( size_t k_stride = k.strides()[1]; size_t v_stride = v.strides()[1]; MTL::Size group_dims(1024, 1, 1); - MTL::Size grid_dims(1, B, 1); + MTL::Size grid_dims(B, q.shape(2), 1); bool has_mask = mask.has_value(); + bool query_transposed = !q.flags().row_contiguous; metal::MTLFCList func_consts = { {&has_mask, MTL::DataType::DataTypeBool, 20}, + {&query_transposed, MTL::DataType::DataTypeBool, 21}, }; std::string hash_name = kname; hash_name += has_mask ? "_mask" : "_nomask"; + hash_name += query_transposed ? "_qt" : "_qnt"; // Get the kernel auto& compute_encoder = d.get_command_encoder(s.index); @@ -161,10 +164,14 @@ void sdpa_vector( if (has_mask) { auto& m = *mask; compute_encoder.set_input_array(m, 9); - int32_t seq_stride = m.ndim() >= 1 ? m.strides().back() : 0; - int32_t head_stride = m.ndim() >= 3 ? *(m.strides().end() - 3) : 0; - compute_encoder.set_bytes(seq_stride, 10); - compute_encoder.set_bytes(head_stride, 11); + auto nd = m.ndim(); + int32_t kv_seq_stride = + nd >= 1 && m.shape(-1) > 1 ? m.strides()[nd - 1] : 0; + int32_t q_seq_stride = nd >= 2 && m.shape(-2) > 1 ? m.strides()[nd - 2] : 0; + int32_t head_stride = nd >= 3 && m.shape(-3) > 1 ? 
m.strides()[nd - 3] : 0; + compute_encoder.set_bytes(kv_seq_stride, 10); + compute_encoder.set_bytes(q_seq_stride, 11); + compute_encoder.set_bytes(head_stride, 12); } // Launch @@ -198,7 +205,7 @@ void sdpa_vector_2pass( auto k_stride = k.strides()[1]; auto v_stride = v.strides()[1]; MTL::Size group_dims(8 * 32, 1, 1); - MTL::Size grid_dims(1, B, blocks); + MTL::Size grid_dims(B, q.shape(2), blocks); // Allocate the intermediates Shape intermediate_shape; @@ -219,11 +226,14 @@ void sdpa_vector_2pass( d.add_temporary(maxs, s.index); bool has_mask = mask.has_value(); + bool query_transposed = !q.flags().row_contiguous; metal::MTLFCList func_consts = { {&has_mask, MTL::DataType::DataTypeBool, 20}, + {&query_transposed, MTL::DataType::DataTypeBool, 21}, }; std::string hash_name = kname; hash_name += has_mask ? "_mask" : "_nomask"; + hash_name += query_transposed ? "_qt" : "_qnt"; // Get the kernel auto& compute_encoder = d.get_command_encoder(s.index); @@ -246,10 +256,14 @@ void sdpa_vector_2pass( if (has_mask) { auto& m = *mask; compute_encoder.set_input_array(m, 11); - int32_t seq_stride = m.ndim() >= 1 ? m.strides().back() : 0; - int32_t head_stride = m.ndim() >= 3 ? *(m.strides().end() - 3) : 0; - compute_encoder.set_bytes(seq_stride, 12); - compute_encoder.set_bytes(head_stride, 13); + auto nd = m.ndim(); + int32_t kv_seq_stride = + nd >= 1 && m.shape(-1) > 1 ? m.strides()[nd - 1] : 0; + int32_t q_seq_stride = nd >= 2 && m.shape(-2) > 1 ? m.strides()[nd - 2] : 0; + int32_t head_stride = nd >= 3 && m.shape(-3) > 1 ? m.strides()[nd - 3] : 0; + compute_encoder.set_bytes(kv_seq_stride, 12); + compute_encoder.set_bytes(q_seq_stride, 13); + compute_encoder.set_bytes(head_stride, 14); } // Launch @@ -274,7 +288,7 @@ void sdpa_vector_2pass( // Launch group_dims = MTL::Size(1024, 1, 1); - grid_dims = MTL::Size(1, B, 1); + grid_dims = MTL::Size(B, q.shape(2), 1); compute_encoder.dispatch_threadgroups(grid_dims, group_dims); } @@ -301,16 +315,23 @@ void ScaledDotProductAttention::eval_gpu( if (!predicate(arr)) { array arr_copy(arr.shape(), arr.dtype(), nullptr, {}); copy_gpu(arr, arr_copy, CopyType::General, s); - copies.push_back(arr_copy); + copies.push_back(std::move(arr_copy)); return copies.back(); } else { return arr; } }; - // Checks if arr is fully row contiguous - auto is_contiguous = [](const array& arr) { - return arr.flags().row_contiguous; + // Checks if arr is row contiguous or the sequence and head dimension are + // transposed + auto is_contiguous_or_head_seq_transposed = [](const array& arr) { + if (arr.flags().row_contiguous) { + return true; + } + auto& strides = arr.strides(); + auto& shape = arr.shape(); + return (strides[3] == 1) && (strides[2] == shape[3] * shape[1]) && + (strides[1] == shape[3]) && (strides[0] == strides[2] * shape[2]); }; // Returns true if the array is row contiguous except the sequence length @@ -328,18 +349,30 @@ void ScaledDotProductAttention::eval_gpu( }; // We are in vector mode ie single query - if (q_pre.shape(2) == 1) { - const auto& q = copy_unless(is_contiguous, q_pre); - // 1, heads, seq_len, head_dim - // mask [1, query_heads, 1, seq_len] + if (q_pre.shape(2) <= 8) { + const auto& q = copy_unless(is_contiguous_or_head_seq_transposed, q_pre); const auto& k = copy_unless(is_contiguous_except_seq_len, k_pre); const auto& v = copy_unless(is_contiguous_except_seq_len, v_pre); // Donate the query if possible - if (q.is_donatable() && q.size() == o.size()) { + if (q.is_donatable() && (q.shape(2) == 1 || !q.flags().row_contiguous) && + q.size() == 
o.size()) { o.move_shared_buffer(q); } else { - o.set_data(allocator::malloc_or_wait(o.nbytes())); + if (o.shape(2) == 1) { + o.set_data(allocator::malloc_or_wait(o.nbytes())); + } else { + auto strides = o.strides(); + strides[2] = o.shape(1) * o.shape(3); + strides[1] = o.shape(3); + auto flags = q.flags(); + flags.row_contiguous = q.shape(1) == 1; + o.set_data( + allocator::malloc_or_wait(o.nbytes()), + o.size(), + std::move(strides), + flags); + } } auto mask = diff --git a/mlx/backend/metal/scan.cpp b/mlx/backend/metal/scan.cpp index 46c2a9bea..b31fdf4f9 100644 --- a/mlx/backend/metal/scan.cpp +++ b/mlx/backend/metal/scan.cpp @@ -17,10 +17,10 @@ void Scan::eval_gpu(const std::vector& inputs, array& out) { auto& s = stream(); auto& d = metal::device(s.device); - std::vector copies; + bool donate = inputs[0].is_donatable(); auto in = inputs[0]; if (in.flags().contiguous && in.strides()[axis_] != 0) { - if (in.is_donatable() && in.itemsize() == out.itemsize()) { + if (donate && in.itemsize() == out.itemsize()) { out.move_shared_buffer(in); } else { out.set_data( @@ -32,8 +32,7 @@ void Scan::eval_gpu(const std::vector& inputs, array& out) { } else { array arr_copy(in.shape(), in.dtype(), nullptr, {}); copy_gpu(in, arr_copy, CopyType::General, s); - copies.push_back(arr_copy); - in = arr_copy; + in = std::move(arr_copy); out.move_shared_buffer(in); } @@ -127,8 +126,6 @@ void Scan::eval_gpu(const std::vector& inputs, array& out) { MTL::Size group_dims(thread_group_size, 1, 1); compute_encoder.dispatch_threads(grid_dims, group_dims); } - - d.add_temporaries(std::move(copies), s.index); } } // namespace mlx::core diff --git a/mlx/backend/no_metal/metal.cpp b/mlx/backend/no_metal/metal.cpp index d23f8d33a..9ae9800a2 100644 --- a/mlx/backend/no_metal/metal.cpp +++ b/mlx/backend/no_metal/metal.cpp @@ -54,7 +54,7 @@ void start_capture(std::string) {} void stop_capture() {} void clear_cache() {} -std::unordered_map> +const std::unordered_map>& device_info() { throw std::runtime_error( "[metal::device_info] Cannot get device info without metal backend"); diff --git a/mlx/fast.cpp b/mlx/fast.cpp index 1967c018f..136c7796a 100644 --- a/mlx/fast.cpp +++ b/mlx/fast.cpp @@ -54,30 +54,34 @@ std::pair, std::vector> Custom::vmap( array rms_norm( const array& x, - const array& weight, + const std::optional& weight, float eps, StreamOrDevice s_ /* = {} */) { + bool has_weight = weight.has_value(); + if (x.ndim() == 0) { std::ostringstream msg; msg << "[rms_norm] Input must have at least 1 dimension but got input with " "0 dimensions."; throw std::invalid_argument(msg.str()); } - if (weight.ndim() != 1) { - std::ostringstream msg; - msg << "[rms_norm] weight must have 1 dimension but has " << weight.ndim() - << " dimensions."; - throw std::invalid_argument(msg.str()); - } - if (weight.size() != x.shape(-1)) { - std::ostringstream msg; - msg << "[rms_norm] weight must have the same size as the last dimension of" - " x but has " - << weight.size() << " elements."; - throw std::invalid_argument(msg.str()); + if (has_weight) { + if ((*weight).ndim() != 1) { + std::ostringstream msg; + msg << "[rms_norm] (*weight) must have 1 dimension but has " + << (*weight).ndim() << " dimensions."; + throw std::invalid_argument(msg.str()); + } + if ((*weight).size() != x.shape(-1)) { + std::ostringstream msg; + msg << "[rms_norm] (*weight) must have the same size as the last dimension of" + " x but has " + << (*weight).size() << " elements."; + throw std::invalid_argument(msg.str()); + } } - auto out_type = result_type(x, 
weight); + auto out_type = (weight.has_value()) ? result_type(x, (*weight)) : x.dtype(); if (!issubdtype(out_type, floating)) { std::ostringstream msg; msg << "[rms_norm] Received unsupported type " << out_type << "."; @@ -85,27 +89,36 @@ array rms_norm( } auto s = to_stream(s_); - auto fallback = [eps, out_type, s](const std::vector& inputs) { - auto x = astype(inputs[0], float32, s); - x = multiply( - x, - rsqrt( - add(mean(square(x, s), -1, /* keepdims */ true, s), - array(eps, float32), + auto fallback = + [has_weight, eps, out_type, s](const std::vector& inputs) { + auto x = astype(inputs[0], float32, s); + x = multiply( + x, + rsqrt( + add(mean(square(x, s), -1, /* keepdims */ true, s), + array(eps, float32), + s), s), - s), - s); - x = astype(x, out_type, s); - return std::vector{multiply(inputs[1], x, s)}; - }; + s); + x = astype(x, out_type, s); + + if (has_weight) { + x = multiply(x, inputs[1], s); + } + + return std::vector{x}; + }; + + auto passed_weight = + (has_weight) ? astype(*weight, out_type, s) : array(1, out_type); if (s.device == Device::gpu) { return array( x.shape(), out_type, std::make_shared(s, fallback, eps), - {astype(x, out_type, s), astype(weight, out_type, s)}); + {astype(x, out_type, s), passed_weight}); } - return fallback({x, weight})[0]; + return fallback({x, passed_weight})[0]; } std::vector RMSNorm::vjp( @@ -141,8 +154,12 @@ std::vector RMSNorm::vjp( // df/dw std::vector axes(g.ndim() - 1); std::iota(axes.begin(), axes.end(), 0); - vjps.push_back( - sum(multiply(g, multiply(x, n, s), s), axes, /* keepdims= */ false, s)); + if (w.ndim() == 0) { + vjps.push_back(zeros_like(w, s)); + } else { + vjps.push_back(sum( + multiply(g, multiply(x, n, s), s), axes, /* keepdims= */ false, s)); + } return vjps; }; @@ -177,28 +194,30 @@ array layer_norm( const std::optional& bias, float eps, StreamOrDevice s_ /* = {} */) { + bool has_weight = weight.has_value(); + bool has_bias = bias.has_value(); + if (x.ndim() == 0) { std::ostringstream msg; msg << "[layer_norm] Input must have at least 1 dimension but got input with " "0 dimensions."; throw std::invalid_argument(msg.str()); } - if (weight.has_value() && (*weight).ndim() != 1) { + if (has_weight && (*weight).ndim() != 1) { std::ostringstream msg; msg << "[layer_norm] weight must have 1 dimension but has " << (*weight).ndim() << " dimensions."; throw std::invalid_argument(msg.str()); } - if (bias.has_value() && (*bias).ndim() != 1) { + if (has_bias && (*bias).ndim() != 1) { std::ostringstream msg; msg << "[layer_norm] bias must have 1 dimension but has " << (*bias).ndim() << " dimensions."; throw std::invalid_argument(msg.str()); } - auto out_type = (weight.has_value()) - ? ((bias.has_value()) ? result_type(x, *weight, *bias) - : result_type(x, *weight)) + auto out_type = (has_weight) + ? ((has_bias) ? result_type(x, *weight, *bias) : result_type(x, *weight)) : x.dtype(); if (!issubdtype(out_type, floating)) { std::ostringstream msg; @@ -207,8 +226,6 @@ array layer_norm( } auto s = to_stream(s_); - bool has_weight = weight.has_value(); - bool has_bias = bias.has_value(); auto fallback = [has_weight, has_bias, eps, out_type, s]( const std::vector& inputs) { auto x = astype(inputs[0], float32, s); @@ -234,9 +251,9 @@ array layer_norm( }; auto passed_weight = - astype((weight.has_value()) ? *weight : array(1, out_type), out_type); + (has_weight) ? astype(*weight, out_type, s) : array(1, out_type); auto passed_bias = - astype((bias.has_value()) ? *bias : array(0, out_type), out_type); + (has_bias) ? 
astype(*bias, out_type, s) : array(0, out_type); if (s.device == Device::gpu) { return array( @@ -698,7 +715,8 @@ array scaled_dot_product_attention( const bool supports_sdpa_full = query_sequence_length >= threshold && !mask && sdpa_full_supported_head_dim && stream.device == Device::gpu; - const bool supports_sdpa_vector = query_sequence_length == 1 && + const bool supports_sdpa_vector = (query_sequence_length <= 8) && + (query_sequence_length <= k.shape(-2)) && (!mask || mask->dtype() == bool_) && sdpa_vector_supported_head_dim && stream.device == Device::gpu; @@ -809,14 +827,17 @@ affine_quantize(const array& w, int group_size, int bits, StreamOrDevice s_) { auto wshape = w.shape(); wshape.back() = -1; - array zero(0, w.dtype()); - array n_bins((1 << bits) - 1, w.dtype()); // 2**bits - 1 - array eps(1e-7, w.dtype()); + array zero(0, float32); + array n_bins((1 << bits) - 1, float32); // 2**bits - 1 + array eps(1e-7, float32); array packed_w = reshape(w, {-1, w.shape(-1) / group_size, group_size}, s); array w_max = max(packed_w, /* axis= */ -1, /* keepdims= */ true, s); array w_min = min(packed_w, /* axis= */ -1, /* keepdims= */ true, s); + w_max = astype(w_max, float32, s); + w_min = astype(w_min, float32, s); + array mask = greater(abs(w_min, s), abs(w_max, s), s); array scales = maximum(divide(subtract(w_max, w_min, s), n_bins, s), eps, s); @@ -827,6 +848,9 @@ affine_quantize(const array& w, int group_size, int bits, StreamOrDevice s_) { array biases = where(equal(q0, zero, s), zero, edge, s); packed_w = pack_and_quantize(packed_w, scales, biases, bits, s); + + scales = astype(scales, w.dtype(), s); + biases = astype(biases, w.dtype(), s); return { reshape(packed_w, wshape, s), reshape(scales, wshape, s), diff --git a/mlx/fast.h b/mlx/fast.h index 9e6586cf6..fe93de85e 100644 --- a/mlx/fast.h +++ b/mlx/fast.h @@ -10,7 +10,7 @@ namespace mlx::core::fast { array rms_norm( const array& x, - const array& weight, + const std::optional& weight, float eps, StreamOrDevice s = {}); diff --git a/mlx/linalg.cpp b/mlx/linalg.cpp index 356d39626..5b9b51ad3 100644 --- a/mlx/linalg.cpp +++ b/mlx/linalg.cpp @@ -102,8 +102,21 @@ inline array matrix_norm( dtype, s); } else if (ord == 2.0 || ord == -2.0) { - throw std::runtime_error( - "[linalg::norm] Singular value norms are not implemented."); + row_axis = (axis[0] < 0) ? axis[0] + a.ndim() : axis[0]; + col_axis = (axis[1] < 0) ? axis[1] + a.ndim() : axis[1]; + auto a_matrix = (row_axis > col_axis) + ? moveaxis(moveaxis(a, row_axis, -1, s), col_axis, -1, s) + : moveaxis(moveaxis(a, col_axis, -1, s), row_axis, -2, s); + a_matrix = svd(a_matrix, false, s).at(0); + a_matrix = (ord == 2.0) ? max(a_matrix, -1, false, s) + : min(a_matrix, -1, false, s); + if (keepdims) { + std::vector sorted_axes = (row_axis < col_axis) + ? std::vector{row_axis, col_axis} + : std::vector{col_axis, row_axis}; + a_matrix = expand_dims(a_matrix, sorted_axes, s); + } + return astype(a_matrix, dtype, s); } else { std::ostringstream msg; msg << "[linalg::norm] Invalid ord " << ord << " for matrix norm."; @@ -120,8 +133,19 @@ inline array matrix_norm( if (ord == "f" || ord == "fro") { return l2_norm(a, axis, keepdims, s); } else if (ord == "nuc") { - throw std::runtime_error( - "[linalg::norm] Nuclear norm not yet implemented."); + int row_axis = (axis[0] < 0) ? axis[0] + a.ndim() : axis[0]; + int col_axis = (axis[1] < 0) ? axis[1] + a.ndim() : axis[1]; + auto a_matrix = (row_axis > col_axis) + ? 
moveaxis(moveaxis(a, row_axis, -1, s), col_axis, -1, s) + : moveaxis(moveaxis(a, col_axis, -1, s), row_axis, -2, s); + a_matrix = sum(svd(a_matrix, false, s).at(0), -1, false, s); + if (keepdims) { + std::vector sorted_axes = (row_axis < col_axis) + ? std::vector{row_axis, col_axis} + : std::vector{col_axis, row_axis}; + a_matrix = expand_dims(a_matrix, sorted_axes, s); + } + return a_matrix; } else { std::ostringstream msg; msg << "[linalg::norm] Invalid ord value '" << ord << "' for matrix norm."; @@ -214,7 +238,8 @@ std::pair qr(const array& a, StreamOrDevice s /* = {} */) { return std::make_pair(out[0], out[1]); } -std::vector svd(const array& a, StreamOrDevice s /* = {} */) { +std::vector +svd(const array& a, bool compute_uv, StreamOrDevice s /* = {} */) { check_cpu_stream(s, "[linalg::svd]"); check_float(a.dtype(), "[linalg::svd]"); @@ -230,14 +255,22 @@ std::vector svd(const array& a, StreamOrDevice s /* = {} */) { const auto n = a.shape(-1); const auto rank = a.ndim(); - auto u_shape = a.shape(); - u_shape[rank - 2] = m; - u_shape[rank - 1] = m; - auto s_shape = a.shape(); s_shape.pop_back(); s_shape[rank - 2] = std::min(m, n); + if (!compute_uv) { + return {array( + std::move(s_shape), + std::move(a.dtype()), + std::make_shared(to_stream(s), compute_uv), + {a})}; + } + + auto u_shape = a.shape(); + u_shape[rank - 2] = m; + u_shape[rank - 1] = m; + auto vt_shape = a.shape(); vt_shape[rank - 2] = n; vt_shape[rank - 1] = n; @@ -245,7 +278,7 @@ std::vector svd(const array& a, StreamOrDevice s /* = {} */) { return array::make_arrays( {u_shape, s_shape, vt_shape}, {a.dtype(), a.dtype(), a.dtype()}, - std::make_shared(to_stream(s)), + std::make_shared(to_stream(s), compute_uv), {a}); } @@ -323,7 +356,7 @@ array pinv(const array& a, StreamOrDevice s /* = {} */) { int m = a.shape(-2); int n = a.shape(-1); int k = std::min(m, n); - auto outs = linalg::svd(a, s); + auto outs = linalg::svd(a, true, s); array U = outs[0]; array S = outs[1]; array V = outs[2]; diff --git a/mlx/linalg.h b/mlx/linalg.h index 9fe4dbf60..8c3a2070a 100644 --- a/mlx/linalg.h +++ b/mlx/linalg.h @@ -62,7 +62,11 @@ norm(const array& a, int axis, bool keepdims = false, StreamOrDevice s = {}) { std::pair qr(const array& a, StreamOrDevice s = {}); -std::vector svd(const array& a, StreamOrDevice s = {}); +std::vector +svd(const array& a, bool compute_uv, StreamOrDevice s /* = {} */); +inline std::vector svd(const array& a, StreamOrDevice s = {}) { + return svd(a, true, s); +} array inv(const array& a, StreamOrDevice s = {}); diff --git a/mlx/primitives.cpp b/mlx/primitives.cpp index ac4c17938..60c13b2c9 100644 --- a/mlx/primitives.cpp +++ b/mlx/primitives.cpp @@ -4940,7 +4940,8 @@ std::pair, std::vector> SVD::vmap( const std::vector& axes) { auto ax = axes[0] >= 0 ? 0 : -1; auto a = axes[0] > 0 ? moveaxis(inputs[0], axes[0], 0, stream()) : inputs[0]; - return {{linalg::svd(a, stream())}, {ax, ax, ax}}; + std::vector new_axes(compute_uv_ ? 3 : 1, ax); + return {linalg::svd(a, compute_uv_, stream()), std::move(new_axes)}; } std::pair, std::vector> Inverse::vmap( diff --git a/mlx/primitives.h b/mlx/primitives.h index a73ddef96..c2c0576aa 100644 --- a/mlx/primitives.h +++ b/mlx/primitives.h @@ -2287,7 +2287,8 @@ class QRF : public Primitive { /* SVD primitive. 
*/ class SVD : public Primitive { public: - explicit SVD(Stream stream) : Primitive(stream) {} + explicit SVD(Stream stream, bool compute_uv) + : Primitive(stream), compute_uv_(compute_uv) {} void eval_cpu(const std::vector& inputs, std::vector& outputs) override; @@ -2296,6 +2297,12 @@ class SVD : public Primitive { DEFINE_VMAP() DEFINE_PRINT(SVD) + auto state() const { + return compute_uv_; + } + + private: + bool compute_uv_; }; /* Matrix inversion primitive. */ diff --git a/mlx/random.cpp b/mlx/random.cpp index a4755605c..d6ce5bb0e 100644 --- a/mlx/random.cpp +++ b/mlx/random.cpp @@ -244,7 +244,7 @@ array multivariate_normal( // Compute the square-root of the covariance matrix, using the SVD auto covariance = astype(cov, float32, stream); - auto SVD = linalg::svd(covariance, stream); + auto SVD = linalg::svd(covariance, true, stream); auto std = astype( matmul( multiply( diff --git a/python/mlx/distributed_run.py b/python/mlx/distributed_run.py index 2c1d27c6e..1a749beed 100644 --- a/python/mlx/distributed_run.py +++ b/python/mlx/distributed_run.py @@ -14,6 +14,8 @@ import time from collections import Counter from dataclasses import dataclass from pathlib import Path +from queue import Empty as QueueEmpty +from queue import Queue from select import select from subprocess import PIPE, Popen, run from typing import Optional @@ -185,46 +187,54 @@ def parse_hostlist(parser, hostlist, repeats): def make_monitor_script(rank, hostfile, cwd, env, command, verbose): + # Imports that are used throughout script = "" + script += "import os\n" + script += "import sys\n" + script += "import tempfile\n" + script += "from pathlib import Path\n" # Write the PID to a file so we can kill the process if needed - script += "pidfile=$(mktemp)\n" - script += "echo $$ >$pidfile\n" - script += "echo $pidfile\n" + script += "_, pidfile = tempfile.mkstemp() \n" + script += "open(pidfile, 'w').write(str(os.getpid()))\n" + script += "print(pidfile, flush=True)\n" # Change the working directory if one was requested. Otherwise attempt to - # change to change to the current one but don't fail if it wasn't possible. + # change to the current one but don't fail if it wasn't possible. 
d = cwd or os.getcwd() - script += f"if [ -d {shlex.quote(d)} ]; then\n" - script += f" cd {shlex.quote(d)}\n" + script += f"if Path({repr(d)}).exists():\n" + script += f" os.chdir({repr(d)})\n" if cwd is not None: - script += "else\n" - script += f" echo Failed to change directory to {shlex.quote(d)} 1>&2\n" - script += f" exit 1\n" - script += "fi\n" + script += "else:\n" + script += ( + f" print('Failed to change directory to', {repr(d)}, file=sys.stderr)\n" + ) + script += f" sys.exit(1)\n" # Add the environment variables that were given to us + script += "env = dict(os.environ)\n" for e in env: key, *value = e.split("=", maxsplit=1) value = shlex.quote(value[0]) if len(value) > 0 else "" if not all(c.isalnum() or c == "_" for c in key): log_warning(f"'{e}' is an invalid environment variable so it is ignored") continue - script += f"export {key}={value}\n" + script += f"env[{repr(key)}] = {repr(value)}\n" # Add the environment variables to enable the ring distributed backend if hostfile != "": - script += "tmpfile=$(mktemp)\n" - script += f"echo {shlex.quote(hostfile)} >$tmpfile\n" + script += "_, hostfile = tempfile.mkstemp()\n" + script += "with open(hostfile, 'w') as f:\n" + script += f" f.write({repr(hostfile)})\n" if verbose: - script += "export MLX_RING_VERBOSE=1\n" - script += "export MLX_HOSTFILE=$tmpfile\n" - script += f"export MLX_RANK={rank}\n" + script += "env['MLX_RING_VERBOSE'] = '1'\n" + script += "env['MLX_HOSTFILE'] = hostfile\n" + script += f"env['MLX_RANK'] = '{rank}'\n" script += "\n" # Replace the process with the script - script += shlex.join(["exec", *command]) - script += "\n" + script += f"command = [{','.join(map(repr, command))}]\n" + script += "os.execve(command[0], command, env)\n" return script @@ -233,28 +243,37 @@ def launch_ring(parser, hosts, args, command): stop = False exit_codes = [None] * len(hosts) - def node_thread(rank, host, hostfile): + def node_thread(rank, host, hostfile, input_queue): is_local = host == "127.0.0.1" script = make_monitor_script( rank, hostfile, args.cwd, args.env, command, args.verbose ) script_b64 = base64.b64encode(script.encode()).decode() - cmd = f'echo "{script_b64}" | base64 -d | /bin/bash' + cmd = f'{sys.executable} -c "import base64; exec(base64.b64decode(\\"{script_b64}\\"));"' if not is_local: cmd = f"ssh {host} '{cmd}'" p = Popen( cmd, shell=True, + stdin=PIPE, stdout=PIPE, stderr=PIPE, ) os.set_blocking(p.stdout.fileno(), False) os.set_blocking(p.stderr.fileno(), False) + os.set_blocking(p.stdin.fileno(), False) # Repeat the stdout and stderr to the local machine + to_read = [p.stdout.fileno(), p.stderr.fileno()] + to_write = [p.stdin.fileno()] pidfile = "" + stdin_buffer = b"" while p.poll() is None: - rlist, _, _ = select([p.stdout.fileno(), p.stderr.fileno()], [], [], 1.0) + try: + stdin_buffer += input_queue.get_nowait() + except QueueEmpty: + pass + rlist, wlist, _ = select(to_read, to_write, [], 1.0) for fd in rlist: is_stdout = fd == p.stdout.fileno() outfile = sys.stdout if is_stdout else sys.stderr @@ -266,6 +285,11 @@ def launch_ring(parser, hosts, args, command): msg = msg[0] if msg else "" outfile.write(msg) + outfile.flush() + for fd in wlist: + if len(stdin_buffer) > 0: + n = os.write(fd, stdin_buffer) + stdin_buffer = stdin_buffer[n:] if stop: p.terminate() break @@ -310,16 +334,25 @@ def launch_ring(parser, hosts, args, command): log(args.verbose, "Running", shlex.join(command)) + input_queues = [] threads = [] for i, h in enumerate(hosts): if i + 1 == len(hosts): time.sleep(1.0) - t = 
threading.Thread(target=node_thread, args=(i, h.ssh_hostname, hostfile))
+        input_queues.append(Queue())
+        t = threading.Thread(
+            target=node_thread, args=(i, h.ssh_hostname, hostfile, input_queues[-1])
+        )
         t.start()
         threads.append(t)
 
+    os.set_blocking(sys.stdin.fileno(), False)
     while not stop:
-        time.sleep(1.0)
+        rlist, _, _ = select([sys.stdin.fileno()], [], [], 1.0)
+        for fd in rlist:
+            stdin_buffer = os.read(fd, 8192)
+            for q in input_queues:
+                q.put(stdin_buffer)
         if any(t.is_alive() for t in threads):
             for i, t in enumerate(threads):
                 if not t.is_alive():
@@ -730,6 +763,8 @@ def main():
     if len(rest) == 0:
         parser.error("No script is provided")
+    if rest[0] == "--":
+        rest.pop(0)
 
     # Try to extract a list of hosts and corresponding ips
     if args.hostfile is not None:
diff --git a/python/mlx/optimizers/optimizers.py b/python/mlx/optimizers/optimizers.py
index b3b701a72..9d0772a57 100644
--- a/python/mlx/optimizers/optimizers.py
+++ b/python/mlx/optimizers/optimizers.py
@@ -5,7 +5,7 @@ from typing import Callable, List, Optional, Tuple, Union
 
 import mlx.core as mx
 from mlx.nn import Module
-from mlx.utils import tree_map, tree_reduce
+from mlx.utils import tree_flatten, tree_map, tree_merge, tree_reduce, tree_unflatten
 
 
 class Optimizer:
@@ -154,6 +154,79 @@ class Optimizer:
         self.state[name] = param
 
+
+class MultiOptimizer(Optimizer):
+    """Wraps a list of optimizers with corresponding weight predicates/filters
+    to make it easy to use different optimizers for different weights.
+
+    The predicates take the full "path" of the weight and the weight itself and
+    return True if that weight should be handled by the corresponding optimizer.
+    The last optimizer in the list is a fallback optimizer and no predicate
+    should be given for it.
+
+    Args:
+        optimizers (list[Optimizer]): A list of optimizers to delegate to.
+        filters (list[Callable[[str, array], bool]]): A list of predicates that
+            should contain one fewer entry than ``optimizers``.
+    """
+
+    def __init__(self, optimizers, filters: list = []):
+        super().__init__()
+        self._state = {}
+
+        if len(filters) != len(optimizers) - 1:
+            raise ValueError(
+                f"Given {len(filters)} filters but {len(optimizers)-1} needed."
+ ) + + self.optimizers = optimizers + self.filters = filters + [lambda *args, **kwargs: True] + + def _split_dictionary(self, gradients: dict): + if len(self.optimizers) == 1: + return [gradients] + + parts = [[] for _ in range(len(self.optimizers))] + flat_gradients = tree_flatten(gradients) + for k, g in flat_gradients: + for i, fn in enumerate(self.filters): + if fn(k, g): + parts[i].append((k, g)) + break + + return [tree_unflatten(p) for p in parts] + + def init(self, parameters: dict): + for o, p in zip(self.optimizers, self._split_dictionary(parameters)): + o.init(p) + + def apply_gradients(self, gradients: dict, parameters: dict): + tree = {} + for o, g in zip(self.optimizers, self._split_dictionary(gradients)): + tree = tree_merge(tree, o.apply_gradients(g, parameters)) + return tree + + @property + def state(self): + return {"states": [o.state for o in self.optimizers]} + + @state.setter + def state(self, state: dict): + if "states" not in state or len(state["states"]) != len(self.optimizers): + raise ValueError("Invalid state provided") + + for o, s in zip(self.optimizers, state["states"]): + o.state = s + + @property + def learning_rate(self): + return self.optimizers[0].learning_rate + + @learning_rate.setter + def learning_rate(self, learning_rate: Union[float, mx.array]): + for o in self.optimizers: + o.learning_rate = learning_rate + + class SGD(Optimizer): r"""The stochastic gradient descent optimizer. diff --git a/python/mlx/utils.py b/python/mlx/utils.py index 6754232a6..2a3c1e660 100644 --- a/python/mlx/utils.py +++ b/python/mlx/utils.py @@ -1,5 +1,6 @@ # Copyright © 2023 Apple Inc. from collections import defaultdict +from itertools import zip_longest from typing import Any, Callable, List, Optional, Tuple @@ -244,3 +245,46 @@ def tree_reduce(fn, tree, initializer=None, is_leaf=None): return tree if accumulator is None else fn(accumulator, tree) return accumulator + + +def tree_merge(tree_a, tree_b, merge_fn=None): + """Merge two Python trees in one containing the values of both. It can be + thought of as a deep dict.update method. + + Args: + tree_a (Any): The first Python tree. + tree_b (Any): The second Python tree. + merge_fn (callable, optional): A function to merge leaves. + + Returns: + The Python tree containing the values of both ``tree_a`` and + ``tree_b``. 
+ """ + if isinstance(tree_a, (dict, list, tuple)) and len(tree_a) == 0: + tree_a = None + if isinstance(tree_b, (dict, list, tuple)) and len(tree_b) == 0: + tree_b = None + if tree_a is None and tree_b is not None: + return tree_b + if tree_a is not None and tree_b is None: + return tree_a + + if isinstance(tree_a, (list, tuple)) and isinstance(tree_b, (list, tuple)): + TreeType = type(tree_a) + return TreeType( + tree_merge(a, b, merge_fn) for a, b in zip_longest(tree_a, tree_b) + ) + elif isinstance(tree_a, dict) and isinstance(tree_b, dict): + return { + k: tree_merge(tree_a.get(k, None), tree_b.get(k, None), merge_fn) + for k in set(tree_a.keys()) | set(tree_b.keys()) + } + else: + if merge_fn is None: + raise ValueError( + ( + "Trees contain elements at the same locations but no merge " + "function was provided" + ) + ) + return merge_fn(tree_a, tree_b) diff --git a/python/src/fast.cpp b/python/src/fast.cpp index d7ccc000b..fc2cbd41d 100644 --- a/python/src/fast.cpp +++ b/python/src/fast.cpp @@ -25,12 +25,12 @@ void init_fast(nb::module_& parent_module) { "rms_norm", &mx::fast::rms_norm, "x"_a, - "weight"_a, + "weight"_a.none(), "eps"_a, nb::kw_only(), "stream"_a = nb::none(), nb::sig( - "def rms_norm(x: array, weight: array, eps: float, *, stream: Union[None, Stream, Device] = None) -> array"), + "def rms_norm(x: array, weight: Optional[array], eps: float, *, stream: Union[None, Stream, Device] = None) -> array"), R"pbdoc( Root Mean Square normalization (RMS norm). @@ -38,9 +38,9 @@ void init_fast(nb::module_& parent_module) { Args: x (array): Input array. - weight (array): A multiplicative weight to scale the result by. + weight (array, optional): A multiplicative weight to scale the result by. The ``weight`` should be one-dimensional with the same size - as the last axis of ``x``. + as the last axis of ``x``. If set to ``None`` then no scaling happens. eps (float): A small additive constant for numerical stability. Returns: diff --git a/python/src/linalg.cpp b/python/src/linalg.cpp index a43cebbe7..3bc0e5b1b 100644 --- a/python/src/linalg.cpp +++ b/python/src/linalg.cpp @@ -92,6 +92,7 @@ void init_linalg(nb::module_& parent_module) { ===== ============================ ========================== None Frobenius norm 2-norm 'fro' Frobenius norm -- + 'nuc' nuclear norm -- inf max(sum(abs(x), axis=1)) max(abs(x)) -inf min(sum(abs(x), axis=1)) min(abs(x)) 0 -- sum(x != 0) @@ -102,9 +103,6 @@ void init_linalg(nb::module_& parent_module) { other -- sum(abs(x)**ord)**(1./ord) ===== ============================ ========================== - .. warning:: - Nuclear norm and norms based on singular values are not yet implemented. 
- The Frobenius norm is given by [1]_: :math:`||A||_F = [\sum_{i,j} abs(a_{i,j})^2]^{1/2}` @@ -206,15 +204,22 @@ void init_linalg(nb::module_& parent_module) { )pbdoc"); m.def( "svd", - [](const mx::array& a, mx::StreamOrDevice s /* = {} */) { - const auto result = mx::linalg::svd(a, s); - return nb::make_tuple(result.at(0), result.at(1), result.at(2)); + [](const mx::array& a, + bool compute_uv /* = true */, + mx::StreamOrDevice s /* = {} */) -> nb::object { + const auto result = mx::linalg::svd(a, compute_uv, s); + if (result.size() == 1) { + return nb::cast(result.at(0)); + } else { + return nb::make_tuple(result.at(0), result.at(1), result.at(2)); + } }, "a"_a, + "compute_uv"_a = true, nb::kw_only(), "stream"_a = nb::none(), nb::sig( - "def svd(a: array, *, stream: Union[None, Stream, Device] = None) -> Tuple[array, array, array]"), + "def svd(a: array, compute_uv: bool = True, *, stream: Union[None, Stream, Device] = None) -> Tuple[array, array, array]"), R"pbdoc( The Singular Value Decomposition (SVD) of the input matrix. @@ -224,12 +229,15 @@ void init_linalg(nb::module_& parent_module) { Args: a (array): Input array. + compute_uv (bool, optional): If ``True``, return the ``U``, ``S``, and ``Vt`` components. + If ``False``, return only the ``S`` array. Default: ``True``. stream (Stream, optional): Stream or device. Defaults to ``None`` in which case the default stream of the default device is used. Returns: - tuple(array, array, array): The ``U``, ``S``, and ``Vt`` matrices, such that - ``A = U @ diag(S) @ Vt`` + Union[tuple(array, ...), array]: + If compute_uv is ``True`` returns the ``U``, ``S``, and ``Vt`` matrices, such that + ``A = U @ diag(S) @ Vt``. If compute_uv is ``False`` returns singular values array ``S``. )pbdoc"); m.def( "inv", diff --git a/python/tests/test_fast.py b/python/tests/test_fast.py index 2aa8b067c..2c90a3755 100644 --- a/python/tests/test_fast.py +++ b/python/tests/test_fast.py @@ -298,6 +298,9 @@ class TestFast(mlx_tests.MLXTestCase): rx = rms_norm(x, weight, eps) rx_fast = mx.fast.rms_norm(x, weight, eps) self.assertLess(mx.abs(rx - rx_fast).max(), tolerances[dtype]) + rx = rms_norm(x, mx.ones_like(weight), eps) + rx_fast = mx.fast.rms_norm(x, None, eps) + self.assertLess(mx.abs(rx - rx_fast).max(), tolerances[dtype]) for eps in epss: dtype, _, dims = defaults @@ -306,6 +309,9 @@ class TestFast(mlx_tests.MLXTestCase): rx = rms_norm(x, weight, eps) rx_fast = mx.fast.rms_norm(x, weight, eps) self.assertLess(mx.abs(rx - rx_fast).max(), tolerances[dtype]) + rx = rms_norm(x, mx.ones_like(weight), eps) + rx_fast = mx.fast.rms_norm(x, None, eps) + self.assertLess(mx.abs(rx - rx_fast).max(), tolerances[dtype]) for dims in dimss: dtype, eps, _ = defaults @@ -314,6 +320,9 @@ class TestFast(mlx_tests.MLXTestCase): rx = rms_norm(x, weight, eps) rx_fast = mx.fast.rms_norm(x, weight, eps) self.assertLess(mx.abs(rx - rx_fast).max(), tolerances[dtype]) + rx = rms_norm(x, mx.ones_like(weight), eps) + rx_fast = mx.fast.rms_norm(x, None, eps) + self.assertLess(mx.abs(rx - rx_fast).max(), tolerances[dtype]) # Test > 4096 dims, dtype, eps = 4099, mx.float32, 1e-5 @@ -333,6 +342,8 @@ class TestFast(mlx_tests.MLXTestCase): eps = 1e-5 f1 = lambda x, w, y: (rms_norm(x, w, eps) * y).sum() f2 = lambda x, w, y: (mx.fast.rms_norm(x, w, eps) * y).sum() + f3 = lambda x, y: (rms_norm(x, mx.ones((x.shape[-1],)), eps) * y).sum() + f4 = lambda x, y: (mx.fast.rms_norm(x, None, eps) * y).sum() x = mx.random.uniform(shape=(8, 100, D)) w = mx.random.uniform(shape=(D,)) @@ -341,6 +352,9 @@ 
class TestFast(mlx_tests.MLXTestCase): gx2, gw2 = mx.grad(f2, argnums=(0, 1))(x, w, y) self.assertLess(mx.abs(gx1 - gx2).max(), 1e-5) self.assertLess(mx.abs(gw1 - gw2).max() / mx.abs(gw1).mean(), 1e-5) + gx1 = mx.grad(f3, argnums=(0,))(x, y) + gx2 = mx.grad(f4, argnums=(0,))(x, y) + self.assertLess(mx.abs(gx1 - gx2).max(), 1e-5) D = 8192 x = mx.random.uniform(shape=(2, 2, D)) @@ -350,6 +364,9 @@ class TestFast(mlx_tests.MLXTestCase): gx2, gw2 = mx.grad(f2, argnums=(0, 1))(x, w, y) self.assertLess(mx.abs(gx1 - gx2).max(), 1e-5) self.assertLess(mx.abs(gw1 - gw2).max() / mx.abs(gw1).mean(), 1e-5) + gx1 = mx.grad(f3, argnums=(0,))(x, y) + gx2 = mx.grad(f4, argnums=(0,))(x, y) + self.assertLess(mx.abs(gx1 - gx2).max(), 1e-5) def gf(f): def inner(x, w, y): diff --git a/python/tests/test_fast_sdpa.py b/python/tests/test_fast_sdpa.py index 348ba4c88..9baee4fb1 100644 --- a/python/tests/test_fast_sdpa.py +++ b/python/tests/test_fast_sdpa.py @@ -262,6 +262,61 @@ class TestFastSDPA(mlx_tests.MLXTestCase): ) self.assertTrue(mx.allclose(ref, out, atol=1e-4, rtol=1e-4)) + def test_fast_sdpa_few_query(self): + D = 64 + L = 43 + Lq = 4 + Nq = 8 + Nkv = 1 + scale = 1.0 + mx.random.seed(0) + q = 5e-1 * mx.random.normal(shape=(1, Lq, Nq, D)) + q = q.swapaxes(1, 2) + k = 5e-1 * mx.random.normal(shape=(1, Nkv, L, D)) + v = 5e-1 * mx.random.normal(shape=(1, Nkv, L, D)) + + masks = [ + mx.array(True), + mx.array([True] * (L - 10) + [False] * 10), + mx.random.uniform(shape=(Nq, 1, L)) > 0.2, + mx.random.uniform(shape=(L, 1, Nq)).T > 0.2, + ] + for m in masks: + ref = mlx_primitives_sdpa(q, k, v, scale, mask=m) + out = mx.fast.scaled_dot_product_attention( + q, + k, + v, + scale=scale, + mask=m, + ) + self.assertTrue(mx.allclose(ref, out, atol=1e-4, rtol=1e-4)) + + return + L = 4096 + scale = 1.0 + mx.random.seed(0) + q = 5e-1 * mx.random.normal(shape=(1, Nq, Lq, D)) + k = 5e-1 * mx.random.normal(shape=(1, Nkv, L, D)) + v = 5e-1 * mx.random.normal(shape=(1, Nkv, L, D)) + + masks = [ + mx.array(True), + mx.array([True] * (L - 10) + [False] * 10), + mx.random.uniform(shape=(Nq, 1, L)) > 0.2, + mx.random.uniform(shape=(L, 1, Nq)).T > 0.2, + ] + for m in masks: + ref = mlx_primitives_sdpa(q, k, v, scale, mask=m) + out = mx.fast.scaled_dot_product_attention( + q, + k, + v, + scale=scale, + mask=m, + ) + self.assertTrue(mx.allclose(ref, out, atol=1e-4, rtol=1e-4)) + @unittest.skip("Different head and value dims is not enabled") def test_fast_sdpa_vector_value_dims(self): D = 192 diff --git a/python/tests/test_linalg.py b/python/tests/test_linalg.py index adc365c62..ffa355c10 100644 --- a/python/tests/test_linalg.py +++ b/python/tests/test_linalg.py @@ -12,11 +12,11 @@ import numpy as np class TestLinalg(mlx_tests.MLXTestCase): def test_norm(self): vector_ords = [None, 0.5, 0, 1, 2, 3, -1, float("inf"), -float("inf")] - matrix_ords = [None, "fro", -1, 1, float("inf"), -float("inf")] + matrix_ords = [None, "fro", "nuc", -1, 1, -2, 2, float("inf"), -float("inf")] for shape in [(3,), (2, 3), (2, 3, 3)]: - x_mx = mx.arange(1, math.prod(shape) + 1).reshape(shape) - x_np = np.arange(1, math.prod(shape) + 1).reshape(shape) + x_mx = mx.arange(1, math.prod(shape) + 1, dtype=mx.float32).reshape(shape) + x_np = np.arange(1, math.prod(shape) + 1, dtype=np.float32).reshape(shape) # Test when at least one axis is provided for num_axes in range(1, len(shape)): if num_axes == 1: @@ -26,11 +26,14 @@ class TestLinalg(mlx_tests.MLXTestCase): for axis in itertools.combinations(range(len(shape)), num_axes): for keepdims in [True, False]: for 
o in ords: + stream = ( + mx.cpu if o in ["nuc", -2, 2] else mx.default_device() + ) out_np = np.linalg.norm( x_np, ord=o, axis=axis, keepdims=keepdims ) out_mx = mx.linalg.norm( - x_mx, ord=o, axis=axis, keepdims=keepdims + x_mx, ord=o, axis=axis, keepdims=keepdims, stream=stream ) with self.subTest( shape=shape, ord=o, axis=axis, keepdims=keepdims @@ -133,20 +136,38 @@ class TestLinalg(mlx_tests.MLXTestCase): def test_svd_decomposition(self): A = mx.array([[1, 2, 3], [4, 5, 6], [7, 8, 9], [10, 11, 12]], dtype=mx.float32) - U, S, Vt = mx.linalg.svd(A, stream=mx.cpu) + U, S, Vt = mx.linalg.svd(A, compute_uv=True, stream=mx.cpu) self.assertTrue( mx.allclose(U[:, : len(S)] @ mx.diag(S) @ Vt, A, rtol=1e-5, atol=1e-7) ) + S = mx.linalg.svd(A, compute_uv=False, stream=mx.cpu) + self.assertTrue( + mx.allclose( + mx.linalg.norm(S), mx.linalg.norm(A, ord="fro"), rtol=1e-5, atol=1e-7 + ) + ) + # Multiple matrices B = A + 10.0 AB = mx.stack([A, B]) - Us, Ss, Vts = mx.linalg.svd(AB, stream=mx.cpu) + Us, Ss, Vts = mx.linalg.svd(AB, compute_uv=True, stream=mx.cpu) for M, U, S, Vt in zip([A, B], Us, Ss, Vts): self.assertTrue( mx.allclose(U[:, : len(S)] @ mx.diag(S) @ Vt, M, rtol=1e-5, atol=1e-7) ) + Ss = mx.linalg.svd(AB, compute_uv=False, stream=mx.cpu) + for M, S in zip([A, B], Ss): + self.assertTrue( + mx.allclose( + mx.linalg.norm(S), + mx.linalg.norm(M, ord="fro"), + rtol=1e-5, + atol=1e-7, + ) + ) + def test_inverse(self): A = mx.array([[1, 2, 3], [6, -5, 4], [-9, 8, 7]], dtype=mx.float32) A_inv = mx.linalg.inv(A, stream=mx.cpu) diff --git a/python/tests/test_ops.py b/python/tests/test_ops.py index f5078afc0..8e1cd8efd 100644 --- a/python/tests/test_ops.py +++ b/python/tests/test_ops.py @@ -1894,6 +1894,22 @@ class TestOps(mlx_tests.MLXTestCase): expected = mx.repeat(expected[:, None], 2, axis=1) self.assertTrue(mx.array_equal(expected, out)) + # Test donation + def fn(its): + x = mx.ones((32,)) + for _ in range(its): + x = mx.cumsum(x) + return x + + mx.synchronize(mx.default_stream(mx.default_device())) + mx.eval(fn(2)) + mx.synchronize(mx.default_stream(mx.default_device())) + mem2 = mx.metal.get_peak_memory() + mx.eval(fn(4)) + mx.synchronize(mx.default_stream(mx.default_device())) + mem4 = mx.metal.get_peak_memory() + self.assertEqual(mem2, mem4) + def test_squeeze_expand(self): a = mx.zeros((2, 1, 2, 1)) self.assertEqual(mx.squeeze(a).shape, (2, 2)) @@ -2846,6 +2862,11 @@ class TestOps(mlx_tests.MLXTestCase): b[::2] = 0 self.assertTrue(mx.array_equal(b, mx.array([0, 3, 0, 1]))) + def test_slice_with_negative_stride(self): + a = mx.random.uniform(shape=(128, 4)) + out = a[::-1] + self.assertTrue(mx.array_equal(out[-1, :], a[0, :])) + if __name__ == "__main__": unittest.main() diff --git a/python/tests/test_optimizers.py b/python/tests/test_optimizers.py index cf3a2b4fa..ebfe97d80 100644 --- a/python/tests/test_optimizers.py +++ b/python/tests/test_optimizers.py @@ -39,6 +39,7 @@ def tree_equal(fn, *args): optimizers_dict = get_all_optimizers() +del optimizers_dict["MultiOptimizer"] class TestOptimizers(mlx_tests.MLXTestCase): @@ -500,6 +501,30 @@ class TestSchedulers(unittest.TestCase): grads = model.trainable_parameters() optimizer.update(model, grads) + def test_multi_optimizer(self): + class Model(nn.Module): + def __init__(self): + super().__init__() + self.l1 = nn.Linear(2, 2) + self.drop = nn.Dropout(p=0.5) + self.l2 = nn.Linear(2, 2) + self.vals = [nn.Linear(2, 2), nn.ReLU(), nn.ReLU()] + + model = Model() + optimizer = opt.MultiOptimizer( + [opt.Adam(learning_rate=0.001), 
opt.SGD(learning_rate=0.1)], + [lambda name, weight: weight.ndim > 1], + ) + optimizer.init(model.trainable_parameters()) + + self.assertEqual(len(optimizer.state["states"]), 2) + + adam_states = tree_flatten(optimizer.state["states"][0]) + sgd_states = tree_flatten(optimizer.state["states"][1]) + self.assertEqual((len(sgd_states) - 2) * 2, len(adam_states) - 2) + self.assertFalse(any("bias" in k for k, v in adam_states)) + self.assertFalse(any("weight" in k for k, v in sgd_states)) + if __name__ == "__main__": unittest.main() diff --git a/python/tests/test_tree.py b/python/tests/test_tree.py index cab137b78..63018fdae 100644 --- a/python/tests/test_tree.py +++ b/python/tests/test_tree.py @@ -3,6 +3,7 @@ import unittest import mlx.core as mx +import mlx.nn as nn import mlx.utils import mlx_tests @@ -22,6 +23,29 @@ class TestTreeUtils(mlx_tests.MLXTestCase): self.assertEqual(list(zip(*flat_tree))[1], vals) self.assertEqual(mlx.utils.tree_unflatten(flat_tree), tree) + def test_merge(self): + t1 = {"a": 0} + t2 = {"b": 1} + t = mlx.utils.tree_merge(t1, t2) + self.assertEqual({"a": 0, "b": 1}, t) + with self.assertRaises(ValueError): + mlx.utils.tree_merge(t1, t1) + with self.assertRaises(ValueError): + mlx.utils.tree_merge(t, t1) + + mod1 = nn.Sequential(nn.Linear(2, 2), nn.Linear(2, 2)) + mod2 = nn.Sequential(nn.Linear(2, 2), nn.Linear(2, 2)) + mod = nn.Sequential(mod1, mod2) + + params1 = {"layers": [mod1.parameters()]} + params2 = {"layers": [None, mod2.parameters()]} + params = mlx.utils.tree_merge(params1, params2) + for (k1, v1), (k2, v2) in zip( + mlx.utils.tree_flatten(params), mlx.utils.tree_flatten(mod.parameters()) + ): + self.assertEqual(k1, k2) + self.assertTrue(mx.array_equal(v1, v2)) + if __name__ == "__main__": unittest.main() diff --git a/python/tests/test_vmap.py b/python/tests/test_vmap.py index 2d38bc457..2eee33b5c 100644 --- a/python/tests/test_vmap.py +++ b/python/tests/test_vmap.py @@ -316,35 +316,59 @@ class TestVmap(mlx_tests.MLXTestCase): def test_vmap_svd(self): a = mx.random.uniform(shape=(3, 4, 2)) - cpu_svd = lambda x: mx.linalg.svd(x, stream=mx.cpu) + cpu_svd_full = lambda x: mx.linalg.svd(x, compute_uv=True, stream=mx.cpu) + cpu_svd_singular = lambda x: mx.linalg.svd(x, compute_uv=False, stream=mx.cpu) # Vmap over the first axis (this is already supported natively by the primitive). - Us, Ss, Vts = mx.vmap(cpu_svd, in_axes=(0,))(a) + Us, Ss, Vts = mx.vmap(cpu_svd_full, in_axes=(0,))(a) self.assertEqual(Us.shape, (a.shape[0], a.shape[1], a.shape[1])) self.assertEqual(Ss.shape, (a.shape[0], a.shape[2])) self.assertEqual(Vts.shape, (a.shape[0], a.shape[2], a.shape[2])) + Sv = mx.vmap(cpu_svd_singular, in_axes=(0,))(a) + self.assertEqual(Sv.shape, (a.shape[0], a.shape[2])) + for i in range(a.shape[0]): M = a[i] U, S, Vt = Us[i], Ss[i], Vts[i] self.assertTrue( mx.allclose(U[:, : len(S)] @ mx.diag(S) @ Vt, M, rtol=1e-5, atol=1e-7) ) + self.assertTrue( + mx.allclose( + mx.linalg.norm(Sv[i]), + mx.linalg.norm(M, ord="fro"), + rtol=1e-5, + atol=1e-7, + ) + ) # Vmap over the second axis. 
- Us, Ss, Vts = mx.vmap(cpu_svd, in_axes=(1,))(a) + Us, Ss, Vts = mx.vmap(cpu_svd_full, in_axes=(1,))(a) self.assertEqual(Us.shape, (a.shape[1], a.shape[0], a.shape[0])) self.assertEqual(Ss.shape, (a.shape[1], a.shape[2])) self.assertEqual(Vts.shape, (a.shape[1], a.shape[2], a.shape[2])) + Sv = mx.vmap(cpu_svd_singular, in_axes=(1,))(a) + self.assertEqual(Sv.shape, (a.shape[1], a.shape[2])) + for i in range(a.shape[1]): M = a[:, i, :] U, S, Vt = Us[i], Ss[i], Vts[i] self.assertTrue( mx.allclose(U[:, : len(S)] @ mx.diag(S) @ Vt, M, rtol=1e-5, atol=1e-7) ) + self.assertTrue( + mx.allclose( + mx.linalg.norm(Sv[i]), + mx.linalg.norm(M, ord="fro"), + rtol=1e-5, + atol=1e-7, + ) + ) def test_vmap_inverse(self): + mx.random.seed(42) a = mx.random.uniform(shape=(3, 4, 4)) cpu_inv = lambda x: mx.linalg.inv(x, stream=mx.cpu) diff --git a/tests/linalg_tests.cpp b/tests/linalg_tests.cpp index b2465c29a..0660a69fe 100644 --- a/tests/linalg_tests.cpp +++ b/tests/linalg_tests.cpp @@ -100,7 +100,7 @@ TEST_CASE("[mlx.core.linalg.norm] double ord") { norm(x, -std::numeric_limits::infinity()).item(), doctest::Approx(expected)); - x = reshape(arange(9), {3, 3}); + x = reshape(arange(9, float32), {3, 3}); CHECK(allclose( norm(x, 2.0, 0, false), @@ -129,10 +129,34 @@ TEST_CASE("[mlx.core.linalg.norm] double ord") { CHECK_EQ( norm(x, -1.0, std::vector{1, 0}).item(), doctest::Approx(3.0)); + CHECK_EQ( + norm(x, 2.0, std::vector{0, 1}, false, Device::cpu).item(), + doctest::Approx(14.226707)); + CHECK_EQ( + norm(x, 2.0, std::vector{1, 0}, false, Device::cpu).item(), + doctest::Approx(14.226707)); + CHECK_EQ( + norm(x, -2.0, std::vector{0, 1}, false, Device::cpu).item(), + doctest::Approx(0.0)); + CHECK_EQ( + norm(x, -2.0, std::vector{1, 0}, false, Device::cpu).item(), + doctest::Approx(0.0)); CHECK_EQ(norm(x, 1.0, std::vector{0, 1}, true).shape(), Shape{1, 1}); CHECK_EQ(norm(x, 1.0, std::vector{1, 0}, true).shape(), Shape{1, 1}); CHECK_EQ(norm(x, -1.0, std::vector{0, 1}, true).shape(), Shape{1, 1}); CHECK_EQ(norm(x, -1.0, std::vector{1, 0}, true).shape(), Shape{1, 1}); + CHECK_EQ( + norm(x, 2.0, std::vector{0, 1}, true, Device::cpu).shape(), + Shape{1, 1}); + CHECK_EQ( + norm(x, 2.0, std::vector{1, 0}, true, Device::cpu).shape(), + Shape{1, 1}); + CHECK_EQ( + norm(x, -2.0, std::vector{0, 1}, true, Device::cpu).shape(), + Shape{1, 1}); + CHECK_EQ( + norm(x, -2.0, std::vector{1, 0}, true, Device::cpu).shape(), + Shape{1, 1}); CHECK_EQ( norm(x, -1.0, std::vector{-2, -1}, false).item(), @@ -140,8 +164,14 @@ TEST_CASE("[mlx.core.linalg.norm] double ord") { CHECK_EQ( norm(x, 1.0, std::vector{-2, -1}, false).item(), doctest::Approx(15.0)); + CHECK_EQ( + norm(x, -2.0, std::vector{-2, -1}, false, Device::cpu).item(), + doctest::Approx(0.0)); + CHECK_EQ( + norm(x, 2.0, std::vector{-2, -1}, false, Device::cpu).item(), + doctest::Approx(14.226707)); - x = reshape(arange(18), {2, 3, 3}); + x = reshape(arange(18, float32), {2, 3, 3}); CHECK_THROWS(norm(x, 2.0, std::vector{0, 1, 2})); CHECK(allclose( norm(x, 3.0, 0), @@ -199,13 +229,31 @@ TEST_CASE("[mlx.core.linalg.norm] double ord") { .item()); CHECK(allclose(norm(x, -1.0, std::vector{1, 2}), array({9, 36})) .item()); + CHECK(allclose( + norm(x, 2.0, std::vector{0, 1}, false, Device::cpu), + array({22.045408, 24.155825, 26.318918})) + .item()); + CHECK(allclose( + norm(x, 2.0, std::vector{1, 2}, false, Device::cpu), + array({14.226707, 39.759212})) + .item()); + CHECK(allclose( + norm(x, -2.0, std::vector{0, 1}, false, Device::cpu), + array({3, 2.7378995, 2.5128777})) + 
.item()); + CHECK(allclose( + norm(x, -2.0, std::vector{1, 2}, false, Device::cpu), + array({4.979028e-16, 7.009628e-16}), + /* rtol = */ 1e-5, + /* atol = */ 1e-6) + .item()); } TEST_CASE("[mlx.core.linalg.norm] string ord") { array x({1, 2, 3}); CHECK_THROWS(norm(x, "fro")); - x = reshape(arange(9), {3, 3}); + x = reshape(arange(9, float32), {3, 3}); CHECK_THROWS(norm(x, "bad ord")); CHECK_EQ( @@ -214,8 +262,11 @@ TEST_CASE("[mlx.core.linalg.norm] string ord") { CHECK_EQ( norm(x, "fro", std::vector{0, 1}).item(), doctest::Approx(14.2828568570857)); + CHECK_EQ( + norm(x, "nuc", std::vector{0, 1}, false, Device::cpu).item(), + doctest::Approx(15.491934)); - x = reshape(arange(18), {2, 3, 3}); + x = reshape(arange(18, float32), {2, 3, 3}); CHECK(allclose( norm(x, "fro", std::vector{0, 1}), array({22.24859546, 24.31049156, 26.43860813})) @@ -240,6 +291,18 @@ TEST_CASE("[mlx.core.linalg.norm] string ord") { norm(x, "f", std::vector{2, 1}), array({14.28285686, 39.7617907})) .item()); + CHECK(allclose( + norm(x, "nuc", std::vector{0, 1}, false, Device::cpu), + array({25.045408, 26.893724, 28.831797})) + .item()); + CHECK(allclose( + norm(x, "nuc", std::vector{1, 2}, false, Device::cpu), + array({15.491934, 40.211937})) + .item()); + CHECK(allclose( + norm(x, "nuc", std::vector{-2, -1}, false, Device::cpu), + array({15.491934, 40.211937})) + .item()); } TEST_CASE("test QR factorization") { @@ -271,7 +334,7 @@ TEST_CASE("test SVD factorization") { const auto prng_key = random::key(42); const auto A = mlx::core::random::normal({5, 4}, prng_key); - const auto outs = linalg::svd(A, Device::cpu); + const auto outs = linalg::svd(A, true, Device::cpu); CHECK_EQ(outs.size(), 3); const auto& U = outs[0]; @@ -291,6 +354,15 @@ TEST_CASE("test SVD factorization") { CHECK_EQ(U.dtype(), float32); CHECK_EQ(S.dtype(), float32); CHECK_EQ(Vt.dtype(), float32); + + // Test singular values + const auto& outs_sv = linalg::svd(A, false, Device::cpu); + const auto SV = outs_sv[0]; + + CHECK_EQ(SV.shape(), Shape{4}); + CHECK_EQ(SV.dtype(), float32); + + CHECK(allclose(norm(SV), norm(A, "fro")).item()); } TEST_CASE("test matrix inversion") { diff --git a/tests/vmap_tests.cpp b/tests/vmap_tests.cpp index 38011b942..2a2a28571 100644 --- a/tests/vmap_tests.cpp +++ b/tests/vmap_tests.cpp @@ -466,15 +466,19 @@ TEST_CASE("test vmap scatter") { } TEST_CASE("test vmap SVD") { - auto fun = [](std::vector inputs) { - return linalg::svd(inputs.at(0), Device::cpu); + auto svd_full = [](std::vector inputs) { + return linalg::svd(inputs.at(0), true, Device::cpu); + }; + + auto svd_singular = [](std::vector inputs) { + return linalg::svd(inputs.at(0), false, Device::cpu); }; auto a = astype(reshape(arange(24), {3, 4, 2}), float32); // vmap over the second axis. { - auto out = vmap(fun, /* in_axes = */ {1})({a}); + auto out = vmap(svd_full, /* in_axes = */ {1})({a}); const auto& U = out.at(0); const auto& S = out.at(1); const auto& Vt = out.at(2); @@ -486,7 +490,7 @@ TEST_CASE("test vmap SVD") { // vmap over the third axis. 
{ - auto out = vmap(fun, /* in_axes = */ {2})({a}); + auto out = vmap(svd_full, /* in_axes = */ {2})({a}); const auto& U = out.at(0); const auto& S = out.at(1); const auto& Vt = out.at(2); @@ -495,6 +499,21 @@ TEST_CASE("test vmap SVD") { CHECK_EQ(S.shape(), Shape{a.shape(2), a.shape(0)}); CHECK_EQ(Vt.shape(), Shape{a.shape(2), a.shape(1), a.shape(1)}); } + + // test singular values + { + auto out = vmap(svd_singular, /* in_axes = */ {1})({a}); + const auto& S = out.at(0); + + CHECK_EQ(S.shape(), Shape{a.shape(1), a.shape(2)}); + } + + { + auto out = vmap(svd_singular, /* in_axes = */ {2})({a}); + const auto& S = out.at(0); + + CHECK_EQ(S.shape(), Shape{a.shape(2), a.shape(0)}); + } } TEST_CASE("test vmap dynamic slices") {
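Below are a few usage sketches for the APIs this patch introduces; they are illustrative only and not part of the diff. First, the new MultiOptimizer from python/mlx/optimizers/optimizers.py: the toy model, shapes, and learning rates are made up, and the filter mirrors the one used in the new optimizer test (matrix-shaped weights go to Adam, everything else falls through to the SGD fallback).

import mlx.core as mx
import mlx.nn as nn
import mlx.optimizers as opt

# A small model with both matrix weights and 1-D biases.
model = nn.Sequential(nn.Linear(4, 8), nn.Linear(8, 2))

def loss_fn(model, x, y):
    return nn.losses.mse_loss(model(x), y)

x = mx.random.uniform(shape=(16, 4))
y = mx.random.uniform(shape=(16, 2))

# Matrix-shaped weights are handled by Adam; the last optimizer (SGD) is the
# fallback and receives everything the single filter rejects, i.e. the biases.
optimizer = opt.MultiOptimizer(
    [opt.Adam(learning_rate=1e-3), opt.SGD(learning_rate=1e-1)],
    [lambda name, weight: weight.ndim > 1],
)

loss, grads = nn.value_and_grad(model, loss_fn)(model, x, y)
optimizer.update(model, grads)
mx.eval(model.parameters(), optimizer.state)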
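A sketch of the new mlx.utils.tree_merge helper. The parameter trees below are hypothetical; the key points are that disjoint leaves are combined directly, while overlapping leaves require an explicit merge_fn (without one a ValueError is raised, as the new test_merge test checks).

import mlx.core as mx
from mlx.utils import tree_merge

# Two partial trees covering disjoint leaves; missing entries are None.
part_a = {"layers": [{"weight": mx.zeros((2, 2))}, None]}
part_b = {"layers": [None, {"weight": mx.ones((2, 2))}]}
merged = tree_merge(part_a, part_b)
# merged == {"layers": [{"weight": <zeros>}, {"weight": <ones>}]}

# Overlapping leaves need a merge_fn; here the colliding arrays are summed.
summed = tree_merge({"a": mx.ones((3,))}, {"a": mx.ones((3,))}, merge_fn=lambda l, r: l + r)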
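The fast.rms_norm binding now accepts weight=None, which skips the scaling step. A quick sketch of the equivalence the new tests rely on; the shapes and tolerance are illustrative.

import mlx.core as mx

x = mx.random.uniform(shape=(2, 8, 512))

# Passing None for the weight should match an explicit all-ones weight.
y_ones = mx.fast.rms_norm(x, mx.ones((512,)), 1e-5)
y_none = mx.fast.rms_norm(x, None, 1e-5)
assert mx.allclose(y_ones, y_none, atol=1e-6).item()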
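linalg.svd now takes a compute_uv flag; with compute_uv=False only the singular values are returned. This sketch mirrors the checks in the new test_svd_decomposition; the matrix is arbitrary.

import mlx.core as mx

A = mx.random.normal(shape=(6, 4))

U, S, Vt = mx.linalg.svd(A, compute_uv=True, stream=mx.cpu)
S_only = mx.linalg.svd(A, compute_uv=False, stream=mx.cpu)

# The singular values alone still determine the Frobenius norm.
assert mx.allclose(
    mx.linalg.norm(S_only), mx.linalg.norm(A, ord="fro"), rtol=1e-5, atol=1e-7
).item()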
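Finally, linalg.norm gains the nuclear norm (alongside the SVD-based ord=2 and ord=-2). Like the tests above, this sketch pins the computation to the CPU stream since these ords go through the SVD; the example matrix is arbitrary.

import mlx.core as mx

A = mx.random.normal(shape=(5, 3))

# The nuclear norm is the sum of the singular values.
nuc = mx.linalg.norm(A, ord="nuc", stream=mx.cpu)
S = mx.linalg.svd(A, compute_uv=False, stream=mx.cpu)
assert mx.allclose(nuc, S.sum(), rtol=1e-5, atol=1e-6).item()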