diff --git a/ACKNOWLEDGMENTS.md b/ACKNOWLEDGMENTS.md index c9969f8d6..9e3b27532 100644 --- a/ACKNOWLEDGMENTS.md +++ b/ACKNOWLEDGMENTS.md @@ -8,7 +8,7 @@ with a short description of your contribution(s) below. For example: MLX was developed with contributions from the following individuals: - Juarez Bochi: Fixed bug in cross attention. -- Justin Deschenaux: Sine, Cosine, arange, randint, truncated normal, bernoulli, lion optimizer, linear and logistic regression python example. +- Justin Deschenaux: Sine, Cosine, arange, randint, truncated normal, bernoulli, lion optimizer, Dropout2d, linear and logistic regression python example. # Third-Party Software diff --git a/CMakeLists.txt b/CMakeLists.txt index 2ea908981..70293ebba 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -18,7 +18,7 @@ option(MLX_BUILD_METAL "Build metal backend" ON) option(BUILD_SHARED_LIBS "Build mlx as a shared library" OFF) if(NOT MLX_VERSION) - set(MLX_VERSION 0.0.3) + set(MLX_VERSION 0.0.6) endif() # --------------------- Processor tests ------------------------- @@ -221,4 +221,4 @@ install( install( DIRECTORY ${CMAKE_MODULE_PATH}/ DESTINATION ${MLX_CMAKE_INSTALL_MODULE_DIR} -) \ No newline at end of file +) diff --git a/README.md b/README.md index 72022b0a1..0276e5006 100644 --- a/README.md +++ b/README.md @@ -53,7 +53,7 @@ variety of examples, including: - [Transformer language model](https://github.com/ml-explore/mlx-examples/tree/main/transformer_lm) training. - Large-scale text generation with - [LLaMA](https://github.com/ml-explore/mlx-examples/tree/main/llama) and + [LLaMA](https://github.com/ml-explore/mlx-examples/tree/main/llms/llama) and finetuning with [LoRA](https://github.com/ml-explore/mlx-examples/tree/main/lora). - Generating images with [Stable Diffusion](https://github.com/ml-explore/mlx-examples/tree/main/stable_diffusion). - Speech recognition with [OpenAI's Whisper](https://github.com/ml-explore/mlx-examples/tree/main/whisper). 
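The `compare.py` additions in the next hunk benchmark multi-axis reductions, including reductions over transposed (and therefore non-contiguous) inputs, which are exactly the cases targeted by the reduction-kernel rewrite later in this patch. As a rough standalone illustration (not part of the patch; it assumes `mlx` and `numpy` are importable and that MLX arrays convert to NumPy via the buffer protocol), one of the new configurations corresponds to:

```python
# Sketch of the "sum_axis --size 16x128x1024 --axis 0,1 --transpose 0,2,1"
# case added below: a multi-axis sum over a transposed, non-contiguous
# array, checked against numpy as a reference.
import mlx.core as mx
import numpy as np

x = mx.random.normal((16, 128, 1024))
xt = mx.transpose(x, [0, 2, 1])  # permuted axes -> non-contiguous layout

out = mx.sum(xt, axis=[0, 1])
mx.eval(out)  # force evaluation so the reduction kernel actually runs

ref = np.array(xt).sum(axis=(0, 1))
print(np.allclose(np.array(out), ref, rtol=1e-3, atol=1e-2))
```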
diff --git a/benchmarks/python/comparative/compare.py b/benchmarks/python/comparative/compare.py index c54af3a46..4adde50bc 100644 --- a/benchmarks/python/comparative/compare.py +++ b/benchmarks/python/comparative/compare.py @@ -125,6 +125,14 @@ if __name__ == "__main__": compare_filtered("sum_axis --size 16x128x1024 --axis 1") compare_filtered("sum_axis --size 16x128x1024 --axis 0 --cpu") compare_filtered("sum_axis --size 16x128x1024 --axis 0") + compare_filtered("sum_axis --size 16x128x1024 --axis 0,1 --cpu") + compare_filtered("sum_axis --size 16x128x1024 --axis 0,1") + compare_filtered("sum_axis --size 16x128x1024 --axis 0,2 --cpu") + compare_filtered("sum_axis --size 16x128x1024 --axis 0,2") + compare_filtered("sum_axis --size 16x128x1024 --axis 0,1 --transpose 0,2,1 --cpu") + compare_filtered("sum_axis --size 16x128x1024 --axis 0,1 --transpose 0,2,1") + compare_filtered("sum_axis --size 16x128x1024 --axis 0,2 --transpose 0,2,1 --cpu") + compare_filtered("sum_axis --size 16x128x1024 --axis 0,2 --transpose 0,2,1") compare_filtered("argmax --size 10x1024x128 --axis 1 --cpu") compare_filtered("argmax --size 10x1024x128 --axis 1") compare_filtered("argmax --size 10x1024x128 --axis 2 --cpu") diff --git a/docs/src/conf.py b/docs/src/conf.py index a5fbf5b16..d38d3424f 100644 --- a/docs/src/conf.py +++ b/docs/src/conf.py @@ -10,8 +10,8 @@ import subprocess project = "MLX" copyright = "2023, MLX Contributors" author = "MLX Contributors" -version = "0.0.5" -release = "0.0.5" +version = "0.0.6" +release = "0.0.6" # -- General configuration --------------------------------------------------- diff --git a/docs/src/index.rst b/docs/src/index.rst index ac4932f10..207238f37 100644 --- a/docs/src/index.rst +++ b/docs/src/index.rst @@ -57,6 +57,7 @@ are the CPU and GPU. python/random python/transforms python/fft + python/linalg python/nn python/optimizers python/tree_utils diff --git a/docs/src/python/linalg.rst b/docs/src/python/linalg.rst new file mode 100644 index 000000000..27746441e --- /dev/null +++ b/docs/src/python/linalg.rst @@ -0,0 +1,11 @@ +.. _linalg: + +Linear Algebra +============== + +.. currentmodule:: mlx.core.linalg + +.. autosummary:: + :toctree: _autosummary + + norm diff --git a/docs/src/python/nn.rst b/docs/src/python/nn.rst index bc19a8162..4c9868171 100644 --- a/docs/src/python/nn.rst +++ b/docs/src/python/nn.rst @@ -123,7 +123,7 @@ To get more detailed information on the arrays in a :class:`Module` you can use all the parameters in a :class:`Module` do: .. code-block:: python - + from mlx.utils import tree_map shapes = tree_map(lambda p: p.shape, mlp.parameters()) @@ -131,7 +131,7 @@ As another example, you can count the number of parameters in a :class:`Module` with: .. 
code-block:: python - + from mlx.utils import tree_flatten num_params = sum(v.size for _, v in tree_flatten(mlp.parameters())) diff --git a/docs/src/python/nn/layers.rst b/docs/src/python/nn/layers.rst index fab3ff785..5ef45d60d 100644 --- a/docs/src/python/nn/layers.rst +++ b/docs/src/python/nn/layers.rst @@ -20,6 +20,7 @@ Layers Linear Conv1d Conv2d + BatchNorm LayerNorm RMSNorm GroupNorm @@ -27,3 +28,6 @@ Layers MultiHeadAttention Sequential QuantizedLinear + Dropout + Dropout2d + diff --git a/docs/src/python/nn/losses.rst b/docs/src/python/nn/losses.rst index b6a202d4a..3fb7589f8 100644 --- a/docs/src/python/nn/losses.rst +++ b/docs/src/python/nn/losses.rst @@ -16,4 +16,7 @@ Loss Functions mse_loss nll_loss smooth_l1_loss - triplet_loss \ No newline at end of file + triplet_loss + hinge_loss + huber_loss + log_cosh_loss \ No newline at end of file diff --git a/mlx/CMakeLists.txt b/mlx/CMakeLists.txt index bd28537f1..e004fc3d9 100644 --- a/mlx/CMakeLists.txt +++ b/mlx/CMakeLists.txt @@ -14,6 +14,7 @@ target_sources( ${CMAKE_CURRENT_SOURCE_DIR}/scheduler.cpp ${CMAKE_CURRENT_SOURCE_DIR}/transforms.cpp ${CMAKE_CURRENT_SOURCE_DIR}/utils.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/linalg.cpp ${CMAKE_CURRENT_SOURCE_DIR}/backend/metal/metal.h ) diff --git a/mlx/backend/common/reduce.h b/mlx/backend/common/reduce.h index 740f54a48..da1d1658a 100644 --- a/mlx/backend/common/reduce.h +++ b/mlx/backend/common/reduce.h @@ -126,7 +126,7 @@ struct ReductionPlan { ReductionPlan get_reduction_plan(const array& x, const std::vector axes) { // The data is all there and we are reducing over everything if (x.size() == x.data_size() && axes.size() == x.ndim() && - (x.flags().row_contiguous || x.flags().col_contiguous)) { + x.flags().contiguous) { return ContiguousAllReduce; } diff --git a/mlx/backend/metal/device.cpp b/mlx/backend/metal/device.cpp index 68662763a..c48f2908f 100644 --- a/mlx/backend/metal/device.cpp +++ b/mlx/backend/metal/device.cpp @@ -19,6 +19,9 @@ namespace mlx::core::metal { namespace { +// Catch things related to the main-thread static variables +static std::shared_ptr global_memory_pool = new_scoped_memory_pool(); + // TODO nicer way to set this or possibly expose as an environment variable static constexpr int MAX_BUFFERS_PER_QUEUE = 12; @@ -110,15 +113,22 @@ MTL::Library* load_library( } // namespace -Device::Device() - : pool_(NS::AutoreleasePool::alloc()->init()), - device_(load_device()), - library_map_({{"mlx", load_library(device_)}}) {} +Device::Device() { + auto pool = new_scoped_memory_pool(); + device_ = load_device(); + library_map_ = {{"mlx", load_library(device_)}}; +} Device::~Device() { for (auto& q : queue_map_) { q.second->release(); } + for (auto& b : buffer_map_) { + b.second.second->release(); + } + for (auto& e : encoder_map_) { + e.second->release(); + } for (auto& k : kernel_map_) { k.second->release(); } @@ -126,7 +136,6 @@ Device::~Device() { l.second->release(); } device_->release(); - pool_->release(); } void Device::new_queue(int index) { @@ -235,6 +244,7 @@ void Device::register_library( MTL::ComputePipelineState* Device::get_kernel( const std::string& name, const std::string& lib_name /* = "mlx" */) { + auto pool = new_scoped_memory_pool(); // Look for cached kernel if (auto it = kernel_map_.find(name); it != kernel_map_.end()) { return it->second; @@ -277,18 +287,18 @@ MTL::ComputePipelineState* Device::get_kernel( } Device& device(mlx::core::Device) { - static Device metal_device_; - return metal_device_; + static Device metal_device; + return metal_device; } 
-NS::AutoreleasePool*& thread_autorelease_pool() { - static thread_local NS::AutoreleasePool* p = - NS::AutoreleasePool::alloc()->init(); - return p; +std::shared_ptr new_scoped_memory_pool() { + auto dtor = [](void* ptr) { + static_cast(ptr)->release(); + }; + return std::shared_ptr(NS::AutoreleasePool::alloc()->init(), dtor); } void new_stream(Stream stream) { - thread_autorelease_pool(); if (stream.device == mlx::core::Device::gpu) { device(stream.device).new_queue(stream.index); } diff --git a/mlx/backend/metal/device.h b/mlx/backend/metal/device.h index 62675d430..45449a332 100644 --- a/mlx/backend/metal/device.h +++ b/mlx/backend/metal/device.h @@ -67,7 +67,6 @@ class Device { const std::vector& arg_descs) const; private: - NS::AutoreleasePool* pool_; MTL::Device* device_; std::unordered_map queue_map_; std::unordered_map> buffer_map_; @@ -78,6 +77,5 @@ class Device { }; Device& device(mlx::core::Device); -NS::AutoreleasePool*& thread_autorelease_pool(); } // namespace mlx::core::metal diff --git a/mlx/backend/metal/kernels/reduce.metal b/mlx/backend/metal/kernels/reduce.metal index 25bf1ee1f..85ff41f44 100644 --- a/mlx/backend/metal/kernels/reduce.metal +++ b/mlx/backend/metal/kernels/reduce.metal @@ -112,88 +112,33 @@ template uint simd_group_id [[simdgroup_index_in_threadgroup]]); -/////////////////////////////////////////////////////////////////////////////// -// General reduce -/////////////////////////////////////////////////////////////////////////////// - -template -[[kernel]] void general_reduce( - const device T *in [[buffer(0)]], - device mlx_atomic *out [[buffer(1)]], - const device int *in_shape [[buffer(2)]], - const device size_t *in_strides [[buffer(3)]], - const device size_t *out_strides [[buffer(4)]], - const device size_t& ndim [[buffer(5)]], - uint gid [[thread_position_in_grid]]) { - Op op; - auto in_idx = elem_to_loc(gid, in_shape, in_strides, ndim); - auto out_idx = elem_to_loc(gid, in_shape, out_strides, ndim); - op.atomic_update(out, static_cast(in[in_idx]), out_idx); -} - -template -[[kernel]] void general_reduce( - const device T *in [[buffer(0)]], - device mlx_atomic *out [[buffer(1)]], - const device int *in_shape [[buffer(2)]], - const device size_t *in_strides [[buffer(3)]], - const device size_t *out_strides [[buffer(4)]], - uint gid [[thread_position_in_grid]]) { - Op op; - auto in_idx = elem_to_loc_nd(gid, in_shape, in_strides); - auto out_idx = elem_to_loc_nd(gid, in_shape, out_strides); - op.atomic_update(out, static_cast(in[in_idx]), out_idx); -} - -#define instantiate_general_reduce_helper(name, itype, otype, op) \ - template [[host_name("general_reduce_" #name)]] \ - [[kernel]] void general_reduce( \ - const device itype *in [[buffer(0)]], \ - device mlx_atomic *out [[buffer(1)]], \ - const device int *in_shape [[buffer(2)]], \ - const device size_t *in_strides [[buffer(3)]], \ - const device size_t *out_strides [[buffer(4)]], \ - const device size_t& ndim [[buffer(5)]], \ - uint gid [[thread_position_in_grid]]); - -#define instantiate_general_reduce_helper_nd(name, itype, otype, op, n) \ - template [[host_name("general_reduce_" #name "_dim_" #n)]] \ - [[kernel]] void general_reduce( \ - const device itype *in [[buffer(0)]], \ - device mlx_atomic *out [[buffer(1)]], \ - const device int *in_shape [[buffer(2)]], \ - const device size_t *in_strides [[buffer(3)]], \ - const device size_t *out_strides [[buffer(4)]], \ - uint gid [[thread_position_in_grid]]); - -#define instantiate_general_reduce(name, itype, otype, op) \ - 
instantiate_general_reduce_helper(name, itype, otype, op) \ - instantiate_general_reduce_helper_nd(name, itype, otype, op, 1) \ - instantiate_general_reduce_helper_nd(name, itype, otype, op, 2) \ - instantiate_general_reduce_helper_nd(name, itype, otype, op, 3) \ - instantiate_general_reduce_helper_nd(name, itype, otype, op, 4) - - /////////////////////////////////////////////////////////////////////////////// // Row atomics /////////////////////////////////////////////////////////////////////////////// template -[[kernel]] void row_reduce( +[[kernel]] void row_reduce_general( const device T *in [[buffer(0)]], - device U *out [[buffer(1)]], - const device size_t& reduction_size [[buffer(2)]], - uint lid [[thread_position_in_threadgroup]], - uint lsize [[threads_per_threadgroup]], - uint tid [[threadgroup_position_in_grid]], + device mlx_atomic *out [[buffer(1)]], + const constant size_t& reduction_size [[buffer(2)]], + const constant size_t& out_size [[buffer(3)]], + const constant int* shape [[buffer(4)]], + const constant size_t* strides [[buffer(5)]], + const constant int& ndim [[buffer(6)]], + uint3 lid [[thread_position_in_threadgroup]], + uint3 lsize [[threads_per_threadgroup]], + uint3 tid [[threadgroup_position_in_grid]], uint simd_lane_id [[thread_index_in_simdgroup]], uint simd_per_group [[simdgroups_per_threadgroup]], uint simd_group_id [[simdgroup_index_in_threadgroup]]) { Op op; - // Each threadgroup handles 1 reduction - in += tid * reduction_size + lid * N_READS; + // Each threadgroup handles 1 reduction + // TODO: Specializing elem_to_loc would be slightly faster + int idx = tid.y * out_size + tid.x; + int extra_offset = elem_to_loc(idx, shape, strides, ndim); + in += extra_offset + lid.x * N_READS; // The reduction is accumulated here U total_val = Op::init; @@ -201,7 +146,7 @@ template // Loop over the reduction size within thread group int r = 0; - for (; r < (int)ceildiv(reduction_size, N_READS*lsize) - 1; r++) { + for (; r < (int)ceildiv(reduction_size, N_READS*lsize.x) - 1; r++) { T vals[N_READS]; for(int i = 0; i < N_READS; i++) { vals[i] = in[i]; @@ -210,11 +155,11 @@ template total_val = op(static_cast(vals[i]), total_val); } - in += lsize * N_READS; + in += lsize.x * N_READS; } - // Sepate case for the last set as we close the reduction size - size_t reduction_index = (lid + (size_t)lsize * r) * N_READS; + // Separate case for the last set as we close the reduction size + size_t reduction_index = (lid.x + (size_t)lsize.x * r) * N_READS; if(reduction_index < reduction_size) { int max_reads = reduction_size - reduction_index; @@ -240,26 +185,30 @@ template // Reduction within thread group // Only needed if multiple simd groups if(reduction_size > simd_size) { - total_val = lid < simd_per_group ? local_vals[lid] : op.init; + total_val = lid.x < simd_per_group ? 
local_vals[lid.x] : op.init; total_val = op.simd_reduce(total_val); } // Update output - if (lid == 0) { - out[tid] = total_val; + if (lid.x == 0) { + op.atomic_update(out, total_val, tid.x); } } -#define instantiate_row_reduce(name, itype, otype, op) \ - template [[host_name("row_reduce_" #name)]] \ - [[kernel]] void row_reduce( \ - const device itype *in [[buffer(0)]], \ - device otype *out [[buffer(1)]], \ - const device size_t& reduction_size [[buffer(2)]], \ - uint lid [[thread_position_in_threadgroup]], \ - uint lsize [[threads_per_threadgroup]], \ - uint tid [[threadgroup_position_in_grid]], \ - uint simd_lane_id [[thread_index_in_simdgroup]], \ - uint simd_per_group [[simdgroups_per_threadgroup]], \ +#define instantiate_row_reduce_general(name, itype, otype, op) \ + template [[host_name("row_reduce_general_" #name)]] \ + [[kernel]] void row_reduce_general( \ + const device itype *in [[buffer(0)]], \ + device mlx_atomic *out [[buffer(1)]], \ + const constant size_t& reduction_size [[buffer(2)]], \ + const constant size_t& out_size [[buffer(3)]], \ + const constant int* shape [[buffer(4)]], \ + const constant size_t* strides [[buffer(5)]], \ + const constant int& ndim [[buffer(6)]], \ + uint3 lid [[thread_position_in_threadgroup]], \ + uint3 lsize [[threads_per_threadgroup]], \ + uint3 tid [[threadgroup_position_in_grid]], \ + uint simd_lane_id [[thread_index_in_simdgroup]], \ + uint simd_per_group [[simdgroups_per_threadgroup]], \ uint simd_group_id [[simdgroup_index_in_threadgroup]]); @@ -311,148 +260,57 @@ inline void _contiguous_strided_reduce( } template -[[kernel]] void col_reduce( +[[kernel]] void col_reduce_general( const device T *in [[buffer(0)]], device mlx_atomic *out [[buffer(1)]], const constant size_t& reduction_size [[buffer(2)]], const constant size_t& reduction_stride [[buffer(3)]], const constant size_t& out_size [[buffer(4)]], + const constant int* shape [[buffer(5)]], + const constant size_t* strides [[buffer(6)]], + const constant int& ndim [[buffer(7)]], threadgroup U *local_data [[threadgroup(0)]], - uint2 tid [[threadgroup_position_in_grid]], - uint2 lid [[thread_position_in_threadgroup]], - uint2 lsize [[threads_per_threadgroup]]) { - auto out_idx = tid.x * lsize.x + lid.x; - + uint3 tid [[threadgroup_position_in_grid]], + uint3 lid [[thread_position_in_threadgroup]], + uint3 lsize [[threads_per_threadgroup]]) { + auto out_idx = tid.x * lsize.x + lid.x; + auto in_idx = elem_to_loc( + out_idx + tid.z * out_size, + shape, + strides, + ndim + ); + if(out_idx < out_size) { _contiguous_strided_reduce( in, out, local_data, - out_idx, + in_idx, out_idx, reduction_size, reduction_stride, - tid, - lid, - lsize); + tid.xy, + lid.xy, + lsize.xy); } } -#define instantiate_col_reduce(name, itype, otype, op) \ - template [[host_name("col_reduce_" #name)]] \ - [[kernel]] void col_reduce( \ +#define instantiate_col_reduce_general(name, itype, otype, op) \ + template [[host_name("col_reduce_general_" #name)]] \ + [[kernel]] void col_reduce_general( \ const device itype *in [[buffer(0)]], \ device mlx_atomic *out [[buffer(1)]], \ const constant size_t& reduction_size [[buffer(2)]], \ const constant size_t& reduction_stride [[buffer(3)]], \ const constant size_t& out_size [[buffer(4)]], \ + const constant int* shape [[buffer(5)]], \ + const constant size_t* strides [[buffer(6)]], \ + const constant int& ndim [[buffer(7)]], \ threadgroup otype *local_data [[threadgroup(0)]], \ - uint2 tid [[threadgroup_position_in_grid]], \ - uint2 lid [[thread_position_in_threadgroup]], \ - 
uint2 lsize [[threads_per_threadgroup]]); - -template -[[kernel]] void contiguous_strided_reduce( - const device T *in [[buffer(0)]], - device mlx_atomic *out [[buffer(1)]], - const constant size_t& reduction_size [[buffer(2)]], - const constant size_t& reduction_stride [[buffer(3)]], - const constant size_t& out_size [[buffer(4)]], - const device int* in_shape [[buffer(5)]], - const device size_t* in_strides [[buffer(6)]], - threadgroup U *local_data [[threadgroup(0)]], - uint2 tid [[threadgroup_position_in_grid]], - uint2 lid [[thread_position_in_threadgroup]], - uint2 lsize [[threads_per_threadgroup]]) { - - auto out_idx = tid.x * lsize.x + lid.x; - auto in_idx = elem_to_loc_nd(out_idx, in_shape, in_strides); - - if(out_idx < out_size) { - _contiguous_strided_reduce( - in, - out, - local_data, - in_idx, - out_idx, - reduction_size, - reduction_stride, - tid, - lid, - lsize); - } -} - -template -[[kernel]] void contiguous_strided_reduce( - const device T *in [[buffer(0)]], - device mlx_atomic *out [[buffer(1)]], - const constant size_t& reduction_size [[buffer(2)]], - const constant size_t& reduction_stride [[buffer(3)]], - const constant size_t& out_size [[buffer(4)]], - const device int* in_shape [[buffer(5)]], - const device size_t* in_strides [[buffer(6)]], - const device size_t& in_dim [[buffer(7)]], - threadgroup U *local_data [[threadgroup(0)]], - uint2 tid [[threadgroup_position_in_grid]], - uint2 lid [[thread_position_in_threadgroup]], - uint2 lsize [[threads_per_threadgroup]]) { - - auto out_idx = tid.x * lsize.x + lid.x; - auto in_idx = elem_to_loc(out_idx, in_shape, in_strides, in_dim); - - if(out_idx < out_size) { - _contiguous_strided_reduce( - in, - out, - local_data, - in_idx, - out_idx, - reduction_size, - reduction_stride, - tid, - lid, - lsize); - } -} - -#define instantiate_contiguous_strided_helper(name, itype, otype, op) \ - template [[host_name("contiguous_strided_reduce_" #name)]] \ - [[kernel]] void contiguous_strided_reduce( \ - const device itype *in [[buffer(0)]], \ - device mlx_atomic *out [[buffer(1)]], \ - const constant size_t& reduction_size [[buffer(2)]], \ - const constant size_t& reduction_stride [[buffer(3)]], \ - const constant size_t& out_size [[buffer(4)]], \ - const device int* in_shape [[buffer(5)]], \ - const device size_t* in_strides [[buffer(6)]], \ - const device size_t& in_dim [[buffer(7)]], \ - threadgroup otype *local_data [[threadgroup(0)]], \ - uint2 tid [[threadgroup_position_in_grid]], \ - uint2 lid [[thread_position_in_threadgroup]], \ - uint2 lsize [[threads_per_threadgroup]]); - -#define instantiate_contiguous_strided_helper_nd(name, itype, otype, op, n) \ - template [[host_name("contiguous_strided_reduce_" #name "_dim_" #n)]] \ - [[kernel]] void contiguous_strided_reduce( \ - const device itype *in [[buffer(0)]], \ - device mlx_atomic *out [[buffer(1)]], \ - const constant size_t& reduction_size [[buffer(2)]], \ - const constant size_t& reduction_stride [[buffer(3)]], \ - const constant size_t& out_size [[buffer(4)]], \ - const device int* in_shape [[buffer(5)]], \ - const device size_t* in_strides [[buffer(6)]], \ - threadgroup otype *local_data [[threadgroup(0)]], \ - uint2 tid [[threadgroup_position_in_grid]], \ - uint2 lid [[thread_position_in_threadgroup]], \ - uint2 lsize [[threads_per_threadgroup]]); - -#define instantiate_contiguous_strided(name, itype, otype, op) \ - instantiate_contiguous_strided_helper(name, itype, otype, op) \ - instantiate_contiguous_strided_helper_nd(name, itype, otype, op, 1) \ - 
instantiate_contiguous_strided_helper_nd(name, itype, otype, op, 2) \ - instantiate_contiguous_strided_helper_nd(name, itype, otype, op, 3) \ - instantiate_contiguous_strided_helper_nd(name, itype, otype, op, 4) + uint3 tid [[threadgroup_position_in_grid]], \ + uint3 lid [[thread_position_in_threadgroup]], \ + uint3 lsize [[threads_per_threadgroup]]); /////////////////////////////////////////////////////////////////////////////// @@ -461,10 +319,8 @@ template #define instantiate_reduce(name, itype, otype, op) \ instantiate_all_reduce(name, itype, otype, op) \ - instantiate_row_reduce(name, itype, otype, op) \ - instantiate_col_reduce(name, itype, otype, op) \ - instantiate_contiguous_strided(name, itype, otype, op) \ - instantiate_general_reduce(name, itype, otype, op) + instantiate_row_reduce_general(name, itype, otype, op) \ + instantiate_col_reduce_general(name, itype, otype, op) #define instantiate_same_reduce(name, tname, type, op) \ instantiate_init_reduce(name ##tname, type, op) \ @@ -535,4 +391,4 @@ instantiate_same_reduce(max_, float16, half, Max) instantiate_same_reduce(max_, float32, float, Max) instantiate_same_reduce(min_, bfloat16, bfloat16_t, Min) -instantiate_same_reduce(max_, bfloat16, bfloat16_t, Max) \ No newline at end of file +instantiate_same_reduce(max_, bfloat16, bfloat16_t, Max) diff --git a/mlx/backend/metal/metal.cpp b/mlx/backend/metal/metal.cpp index f63ad55a3..478e57c73 100644 --- a/mlx/backend/metal/metal.cpp +++ b/mlx/backend/metal/metal.cpp @@ -50,6 +50,7 @@ std::function make_task( bool retain_graph) { auto task = [retain_graph, arr, deps = std::move(deps), p = std::move(p)]() mutable { + auto pool = new_scoped_memory_pool(); for (auto& d : deps) { d.wait(); } @@ -66,12 +67,6 @@ std::function make_task( arr.detach(); } p->set_value(); - // Signal this thread to clear the pool on a synchroniztion. 
- scheduler::enqueue(s, []() { - thread_autorelease_pool()->release(); - thread_autorelease_pool() = - NS::AutoreleasePool::alloc()->init(); - }); scheduler::notify_task_completion(s); }); metal::device(s.device).commit_command_buffer(s.index); diff --git a/mlx/backend/metal/metal.h b/mlx/backend/metal/metal.h index f1f7ede44..99f400956 100644 --- a/mlx/backend/metal/metal.h +++ b/mlx/backend/metal/metal.h @@ -20,6 +20,7 @@ constexpr bool is_available() { } void new_stream(Stream stream); +std::shared_ptr new_scoped_memory_pool(); std::function make_task( array& arr, diff --git a/mlx/backend/metal/reduce.cpp b/mlx/backend/metal/reduce.cpp index 532f18353..6a2ce084b 100644 --- a/mlx/backend/metal/reduce.cpp +++ b/mlx/backend/metal/reduce.cpp @@ -2,9 +2,11 @@ #include #include +#include #include #include "mlx/backend/common/reduce.h" +#include "mlx/backend/metal/copy.h" #include "mlx/backend/metal/device.h" #include "mlx/backend/metal/kernels/defines.h" #include "mlx/backend/metal/utils.h" @@ -61,22 +63,47 @@ void all_reduce_dispatch( compute_encoder->dispatchThreads(grid_dims, group_dims); } -void row_reduce_dispatch( +void row_reduce_general_dispatch( const array& in, array& out, const std::string& op_name, - const std::vector& axes_, + const ReductionPlan& plan, + const std::vector& axes, MTL::ComputeCommandEncoder* compute_encoder, metal::Device& d) { - auto kernel = d.get_kernel("row_reduce_" + op_name + type_to_name(in)); + auto kernel = + d.get_kernel("row_reduce_general_" + op_name + type_to_name(in)); + // Prepare the arguments for the kernel int n_reads = REDUCE_N_READS; - size_t reduction_size = in.size() / out.size(); + size_t reduction_size = plan.shape.back(); + size_t out_size = out.size(); + auto shape = plan.shape; + auto strides = plan.strides; + shape.pop_back(); + strides.pop_back(); + size_t non_row_reductions = 1; + for (auto s : shape) { + non_row_reductions *= static_cast(s); + } + auto [rem_shape, rem_strides] = shapes_without_reduction_axes(in, axes); + for (auto s : rem_shape) { + shape.push_back(s); + } + for (auto s : rem_strides) { + strides.push_back(s); + } + int ndim = shape.size(); + // Set the arguments for the kernel compute_encoder->setComputePipelineState(kernel); set_array_buffer(compute_encoder, in, 0); set_array_buffer(compute_encoder, out, 1); compute_encoder->setBytes(&reduction_size, sizeof(size_t), 2); + compute_encoder->setBytes(&out_size, sizeof(size_t), 3); + compute_encoder->setBytes(shape.data(), shape.size() * sizeof(int), 4); + compute_encoder->setBytes(strides.data(), strides.size() * sizeof(size_t), 5); + compute_encoder->setBytes(&ndim, sizeof(int), 6); // Each thread group is responsible for 1 output NS::UInteger thread_group_size = kernel->maxTotalThreadsPerThreadgroup(); @@ -91,92 +118,54 @@ void row_reduce_dispatch( // Launch enough thread groups for each output size_t n_threads = out.size() * thread_group_size; - MTL::Size grid_dims = MTL::Size(n_threads, 1, 1); + MTL::Size grid_dims = MTL::Size(n_threads, non_row_reductions, 1); MTL::Size group_dims = MTL::Size(thread_group_size, 1, 1); compute_encoder->dispatchThreads(grid_dims, group_dims); } -void col_reduce_dispatch( +void strided_reduce_general_dispatch( const array& in, array& out, const std::string& op_name, - const std::vector& axes_, + const ReductionPlan& plan, + const std::vector& axes, MTL::ComputeCommandEncoder* compute_encoder, metal::Device& d) { - std::ostringstream kernel_name; + auto kernel = + d.get_kernel("col_reduce_general_" + op_name + type_to_name(in)); - 
bool encode_in_shape = false; - bool encode_ndim = false; - - // If the slowest moving axis can be merged into the reductions, - // we call the column reduce kernel - // In this case, a linear index in the output corresponds to the - // linear index in the input where the reduction starts - if (axes_[axes_.size() - 1] == (axes_.size() - 1)) { - kernel_name << "col_reduce_" << op_name << type_to_name(in); - } - // Otherwise, while all the reduction axes can be merged, the mapping between - // indices in the output and input require resolving using shapes and strides - else { - kernel_name << "contiguous_strided_reduce_" << op_name << type_to_name(in); - encode_in_shape = true; - - // We check for a viable template with the required number of dimensions - // we only care about encoding non-reduced shapes and strides in the input - size_t non_reducing_dims = in.ndim() - axes_.size(); - if (non_reducing_dims >= 1 && - non_reducing_dims <= MAX_REDUCE_SPECIALIZED_DIMS) { - kernel_name << "_dim_" << non_reducing_dims; - } else { - encode_ndim = true; - } - } - - auto kernel = d.get_kernel(kernel_name.str()); - size_t in_size = in.size(); + // Prepare the arguments for the kernel + size_t reduction_size = plan.shape.back(); + size_t reduction_stride = plan.strides.back(); size_t out_size = out.size(); + auto shape = plan.shape; + auto strides = plan.strides; + shape.pop_back(); + strides.pop_back(); + size_t non_col_reductions = 1; + for (auto s : shape) { + non_col_reductions *= static_cast(s); + } + auto [rem_shape, rem_strides] = shapes_without_reduction_axes(in, axes); + for (auto s : rem_shape) { + shape.push_back(s); + } + for (auto s : rem_strides) { + strides.push_back(s); + } + int ndim = shape.size(); + // Set the arguments for the kernel compute_encoder->setComputePipelineState(kernel); set_array_buffer(compute_encoder, in, 0); set_array_buffer(compute_encoder, out, 1); - - // Calculate the number of inputs to reduce and the stride b/w them - size_t reduction_size = 1; - size_t in_ndim = in.ndim(); - size_t reduction_stride = in_size; - - for (int i : axes_) { - reduction_size *= in.shape(i); - reduction_stride = std::min(reduction_stride, in.strides()[i]); - } - compute_encoder->setBytes(&reduction_size, sizeof(size_t), 2); compute_encoder->setBytes(&reduction_stride, sizeof(size_t), 3); compute_encoder->setBytes(&out_size, sizeof(size_t), 4); - if (encode_in_shape) { - // Obtain the non-reducing shape and strides of the input to encode - std::vector inp_shape_mod; - std::vector inp_strides_mod; - - for (size_t i = 0, j = 0; i < in.ndim(); i++) { - if (j < axes_.size() && axes_[j] == i) { - j++; - } else { - inp_shape_mod.push_back(in.shape(i)); - inp_strides_mod.push_back(in.strides()[i]); - } - } - - size_t ndim = inp_shape_mod.size(); - - compute_encoder->setBytes(inp_shape_mod.data(), ndim * sizeof(int), 5); - compute_encoder->setBytes(inp_strides_mod.data(), ndim * sizeof(size_t), 6); - - if (encode_ndim) { - compute_encoder->setBytes(&ndim, sizeof(size_t), 7); - } - } + compute_encoder->setBytes(shape.data(), shape.size() * sizeof(int), 5); + compute_encoder->setBytes(strides.data(), strides.size() * sizeof(size_t), 6); + compute_encoder->setBytes(&ndim, sizeof(int), 7); // Select block dimensions @@ -200,7 +189,8 @@ void col_reduce_dispatch( (n_threads_per_output + threadgroup_dim_y - 1) / threadgroup_dim_y; // Launch enough thread groups for each output - MTL::Size grid_dims = MTL::Size(n_threadgroups_x, n_threadgroups_y, 1); + MTL::Size grid_dims = + 
MTL::Size(n_threadgroups_x, n_threadgroups_y, non_col_reductions); MTL::Size group_dims = MTL::Size(threadgroup_dim_x, threadgroup_dim_y, 1); // We set shared memory to be exploited here for reductions within a @@ -216,60 +206,6 @@ void col_reduce_dispatch( compute_encoder->dispatchThreadgroups(grid_dims, group_dims); } -void general_reduce_dispatch( - const array& in, - array& out, - const std::string& op_name, - const std::vector& axes_, - MTL::ComputeCommandEncoder* compute_encoder, - metal::Device& d) { - bool encode_ndim = true; - std::ostringstream kernel_name; - kernel_name << "general_reduce_" << op_name << type_to_name(in); - - // Check for specialzed kernels for input ndim - if (in.ndim() >= 1 && in.ndim() <= MAX_REDUCE_SPECIALIZED_DIMS) { - kernel_name << "_dim_" << in.ndim(); - encode_ndim = false; - } - auto kernel = d.get_kernel(kernel_name.str()); - size_t in_size = in.size(); - size_t ndim = in.ndim(); - - // We set the reducing strides to 0 to induce collisions for the reduction - std::vector out_strides(ndim); - size_t stride = 1; - for (int i = ndim - 1, j = axes_.size() - 1; i >= 0; --i) { - if (j >= 0 && axes_[j] == i) { - out_strides[i] = 0; - --j; - } else { - out_strides[i] = stride; - stride *= in.shape(i); - } - } - - compute_encoder->setComputePipelineState(kernel); - set_array_buffer(compute_encoder, in, 0); - set_array_buffer(compute_encoder, out, 1); - compute_encoder->setBytes(in.shape().data(), ndim * sizeof(int), 2); - compute_encoder->setBytes(in.strides().data(), ndim * sizeof(size_t), 3); - compute_encoder->setBytes(out_strides.data(), ndim * sizeof(size_t), 4); - if (encode_ndim) { - compute_encoder->setBytes(&ndim, sizeof(size_t), 5); - } - - NS::UInteger thread_group_size = kernel->maxTotalThreadsPerThreadgroup(); - if (thread_group_size > in_size) { - thread_group_size = in_size; - } - size_t nthreads = in_size; - - MTL::Size group_dims = MTL::Size(thread_group_size, 1, 1); - MTL::Size grid_dims = MTL::Size(nthreads, 1, 1); - compute_encoder->dispatchThreads(grid_dims, group_dims); -} - } // namespace ////////////////////////////////////////////////////////////////////// @@ -278,7 +214,7 @@ void general_reduce_dispatch( void Reduce::eval_gpu(const std::vector& inputs, array& out) { assert(inputs.size() == 1); - auto& in = inputs[0]; + array in = inputs[0]; // TODO: Allow specific row and column reductions with types disabled // due to atomics ? 
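The hunk below replaces the string-assembled kernel selection above (`col_reduce_...`, `contiguous_strided_reduce_...`, `general_reduce_...`) with dispatch on a `ReductionPlan`. A rough Python-style sketch of the new control flow, with names mirroring the C++ but heavily simplified (this is not the actual implementation):

```python
# Simplified sketch of the plan-based dispatch in Reduce::eval_gpu (next hunk).
# A GeneralReduce plan is first lowered to one of the specialized cases by
# copying the input into a contiguous array and recomputing the plan.
def eval_gpu_reduce(in_arr, out_arr, axes):
    plan = get_reduction_plan(in_arr, axes)

    if plan.type == "GeneralReduce":
        in_arr = contiguous_copy(in_arr)         # copy_gpu(..., CopyType::General)
        plan = get_reduction_plan(in_arr, axes)  # now one of the cases below

    if plan.type == "ContiguousAllReduce":
        all_reduce_dispatch(in_arr, out_arr)     # reduce over everything at once
    elif plan.type in ("ContiguousReduce", "GeneralContiguousReduce"):
        # the last dimension is row-contiguous and is being reduced
        row_reduce_general_dispatch(in_arr, out_arr, plan, axes)
    elif plan.type in ("ContiguousStridedReduce", "GeneralStridedReduce"):
        # the last two dimensions are contiguous; strided reduce over them
        strided_reduce_general_dispatch(in_arr, out_arr, plan, axes)
```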
@@ -335,36 +271,46 @@ void Reduce::eval_gpu(const std::vector& inputs, array& out) { // Reduce { - // Check for contiguous data - if (in.size() == in.data_size() && - (in.flags().row_contiguous || in.flags().col_contiguous)) { - // Go to all reduce if reducing over all axes - if (axes_.size() == in.ndim()) { - all_reduce_dispatch(in, out, op_name, compute_encoder, d); - return; - } - // Use specialized kernels if the input is row contiguous and - // the reducing axes can be merged into one - else if ( - in.flags().row_contiguous && in.strides().back() == 1 && - (axes_.back() - axes_.front()) == axes_.size() - 1) { - // If the fastest moving axis is being reduced, go to row reduce - if (axes_[0] == (in.ndim() - axes_.size())) { - row_reduce_dispatch(in, out, op_name, axes_, compute_encoder, d); - return; - } - // Otherwise go to to generalized strided reduce - // Note: bool isn't support here yet due to the use of atomics - // once that is updated, this should be the else condition of this - // branch - else if (in.dtype() != bool_) { - col_reduce_dispatch(in, out, op_name, axes_, compute_encoder, d); - return; - } - } + std::vector copies; + ReductionPlan plan = get_reduction_plan(in, axes_); + + // If it is a general reduce then copy the input to a contiguous array and + // recompute the plan. + if (plan.type == GeneralReduce) { + array in_copy(in.shape(), in.dtype(), nullptr, {}); + copy_gpu(in, in_copy, CopyType::General, s); + copies.push_back(in_copy); + in = in_copy; + plan = get_reduction_plan(in, axes_); + } + + // Reducing over everything and the data is all there no broadcasting or + // slicing etc. + if (plan.type == ContiguousAllReduce) { + all_reduce_dispatch(in, out, op_name, compute_encoder, d); + } + + // At least the last dimension is row contiguous and we are reducing over + // the last dim. + else if ( + plan.type == ContiguousReduce || plan.type == GeneralContiguousReduce) { + row_reduce_general_dispatch( + in, out, op_name, plan, axes_, compute_encoder, d); + } + + // At least the last two dimensions are contiguous and we are doing a + // strided reduce over these. + else if ( + plan.type == ContiguousStridedReduce || + plan.type == GeneralStridedReduce) { + strided_reduce_general_dispatch( + in, out, op_name, plan, axes_, compute_encoder, d); + } + + if (!copies.empty()) { + d.get_command_buffer(s.index)->addCompletedHandler( + [copies](MTL::CommandBuffer*) mutable { copies.clear(); }); } - // Fall back to the general case - general_reduce_dispatch(in, out, op_name, axes_, compute_encoder, d); } } diff --git a/mlx/backend/no_metal/metal.cpp b/mlx/backend/no_metal/metal.cpp index accfc4c8a..212ca2839 100644 --- a/mlx/backend/no_metal/metal.cpp +++ b/mlx/backend/no_metal/metal.cpp @@ -7,6 +7,9 @@ namespace mlx::core::metal { void new_stream(Stream) {} +std::shared_ptr new_scoped_memory_pool() { + return nullptr; +} std::function make_task( array& arr, diff --git a/mlx/linalg.cpp b/mlx/linalg.cpp new file mode 100644 index 000000000..7e7264e3f --- /dev/null +++ b/mlx/linalg.cpp @@ -0,0 +1,175 @@ +// Copyright © 2023 Apple Inc. + +#include +#include +#include + +#include "mlx/dtype.h" +#include "mlx/linalg.h" + +namespace mlx::core::linalg { + +Dtype at_least_float(const Dtype& d) { + return is_floating_point(d) ? 
d : promote_types(d, float32); +} + +inline array l2_norm( + const array& a, + const std::vector& axis, + bool keepdims, + StreamOrDevice s) { + if (is_complex(a.dtype())) { + return sqrt(sum(abs(a, s) * abs(a, s), axis, keepdims, s), s); + } else { + return sqrt(sum(square(a, s), axis, keepdims, s), s); + } +} + +inline array vector_norm( + const array& a, + const double ord, + const std::vector& axis, + bool keepdims, + StreamOrDevice s) { + auto dtype = at_least_float(a.dtype()); + if (ord == 0.0) { + return astype(sum(not_equal(a, array(0), s), axis, keepdims, s), dtype, s); + } else if (ord == 1.0) { + return astype(sum(abs(a, s), axis, keepdims, s), dtype, s); + } else if (ord == 2.0) { + return l2_norm(a, axis, keepdims, s); + } else if (ord == std::numeric_limits::infinity()) { + return astype(max(abs(a, s), axis, keepdims, s), dtype, s); + } else if (ord == -std::numeric_limits::infinity()) { + return astype(min(abs(a, s), axis, keepdims, s), dtype, s); + } else { + return power( + sum(power(abs(a, s), array(ord, dtype), s), axis, keepdims, s), + array(1.0 / ord, dtype), + s); + } +} + +inline array matrix_norm( + const array& a, + const double ord, + const std::vector& axis, + bool keepdims, + StreamOrDevice s) { + auto dtype = at_least_float(a.dtype()); + auto row_axis = axis[0]; + auto col_axis = axis[1]; + if (ord == -1.0) { + col_axis -= (!keepdims && col_axis > row_axis && col_axis > 0); + return astype( + min(sum(abs(a, s), row_axis, keepdims, s), col_axis, keepdims, s), + dtype, + s); + } else if (ord == 1.0) { + col_axis -= (!keepdims && col_axis > row_axis && col_axis > 0); + return astype( + max(sum(abs(a, s), row_axis, keepdims, s), col_axis, keepdims, s), + dtype, + s); + } else if (ord == std::numeric_limits::infinity()) { + row_axis -= (!keepdims && row_axis > col_axis && row_axis > 0); + return astype( + max(sum(abs(a, s), col_axis, keepdims, s), row_axis, keepdims, s), + dtype, + s); + } else if (ord == -std::numeric_limits::infinity()) { + row_axis -= (!keepdims && row_axis > col_axis && row_axis > 0); + return astype( + min(sum(abs(a, s), col_axis, keepdims, s), row_axis, keepdims, s), + dtype, + s); + } else if (ord == 2.0 || ord == -2.0) { + throw std::runtime_error( + "[linalg::norm] Singular value norms are not implemented."); + } else { + std::ostringstream msg; + msg << "[linalg::norm] Invalid ord " << ord << " for matrix norm."; + throw std::invalid_argument(msg.str()); + } +} + +inline array matrix_norm( + const array& a, + const std::string& ord, + const std::vector& axis, + bool keepdims, + StreamOrDevice s) { + if (ord == "f" || ord == "fro") { + return l2_norm(a, axis, keepdims, s); + } else if (ord == "nuc") { + throw std::runtime_error( + "[linalg::norm] Nuclear norm not yet implemented."); + } else { + std::ostringstream msg; + msg << "[linalg::norm] Invalid ord value '" << ord << "' for matrix norm."; + throw std::invalid_argument(msg.str()); + } +} + +array norm( + const array& a, + const std::optional>& axis /* = std::nullopt */, + bool keepdims /* = false */, + StreamOrDevice s /* = {} */) { + if (!axis) { + return norm(flatten(a, s), std::vector{0}, keepdims, s); + } + + if (axis.value().size() > 2) { + throw std::invalid_argument( + "[linalg::norm] Received too many axes for norm."); + } + return l2_norm(a, axis.value(), keepdims, s); +} + +array norm( + const array& a, + const double ord, + const std::optional>& axis /* = std::nullopt */, + bool keepdims /* = false */, + StreamOrDevice s /* = {} */) { + std::vector ax; + if (!axis) { + 
ax.resize(a.ndim()); + std::iota(ax.begin(), ax.end(), 0); + } else { + ax = axis.value(); + } + if (ax.size() == 1) { + return vector_norm(a, ord, ax, keepdims, s); + } else if (ax.size() == 2) { + return matrix_norm(a, ord, ax, keepdims, s); + } else { + throw std::invalid_argument( + "[linalg::norm] Received too many axes for norm."); + } +} + +array norm( + const array& a, + const std::string& ord, + const std::optional>& axis /* = std::nullopt */, + bool keepdims /* = false */, + StreamOrDevice s /* = {} */) { + std::vector ax; + if (!axis) { + ax.resize(a.ndim()); + std::iota(ax.begin(), ax.end(), 0); + } else { + ax = axis.value(); + } + if (ax.size() != 2) { + std::ostringstream msg; + msg << "[linalg::norm] Norm '" << ord << "' only supported for matrices," + << " but received " << ax.size() << " axis/axes."; + throw std::invalid_argument(msg.str()); + } + return matrix_norm(a, ord, ax, keepdims, s); +} + +} // namespace mlx::core::linalg diff --git a/mlx/linalg.h b/mlx/linalg.h new file mode 100644 index 000000000..80e484eb5 --- /dev/null +++ b/mlx/linalg.h @@ -0,0 +1,63 @@ +// Copyright © 2023 Apple Inc. + +#pragma once + +#include + +#include "mlx/array.h" +#include "mlx/device.h" +#include "mlx/ops.h" +#include "mlx/stream.h" + +namespace mlx::core::linalg { + +/** + * Compute vector or matrix norms. + * + * - If axis and ord are both unspecified, computes the 2-norm of flatten(x). + * - If axis is not provided but ord is, then x must be either 1D or 2D. + * - If axis is provided, but ord is not, then the 2-norm (or Frobenius norm + * for matrices) is computed along the given axes. At most 2 axes can be + * specified. + * - If both axis and ord are provided, then the corresponding matrix or vector + * norm is computed. At most 2 axes can be specified. 
+ */ +array norm( + const array& a, + const double ord, + const std::optional>& axis = std::nullopt, + bool keepdims = false, + StreamOrDevice s = {}); +inline array norm( + const array& a, + const double ord, + int axis, + bool keepdims = false, + StreamOrDevice s = {}) { + return norm(a, ord, std::vector{axis}, keepdims, s); +} +array norm( + const array& a, + const std::string& ord, + const std::optional>& axis = std::nullopt, + bool keepdims = false, + StreamOrDevice s = {}); +inline array norm( + const array& a, + const std::string& ord, + int axis, + bool keepdims = false, + StreamOrDevice s = {}) { + return norm(a, ord, std::vector{axis}, keepdims, s); +} +array norm( + const array& a, + const std::optional>& axis = std::nullopt, + bool keepdims = false, + StreamOrDevice s = {}); +inline array +norm(const array& a, int axis, bool keepdims = false, StreamOrDevice s = {}) { + return norm(a, std::vector{axis}, keepdims, s); +} + +} // namespace mlx::core::linalg diff --git a/mlx/mlx.h b/mlx/mlx.h index 102d2dde9..8d785c39f 100644 --- a/mlx/mlx.h +++ b/mlx/mlx.h @@ -6,6 +6,7 @@ #include "mlx/backend/metal/metal.h" #include "mlx/device.h" #include "mlx/fft.h" +#include "mlx/linalg.h" #include "mlx/ops.h" #include "mlx/random.h" #include "mlx/stream.h" diff --git a/mlx/random.cpp b/mlx/random.cpp index 232c458f9..ef11f8c65 100644 --- a/mlx/random.cpp +++ b/mlx/random.cpp @@ -103,7 +103,9 @@ array uniform( } auto stream = to_stream(s); - auto range = subtract(high, low, stream); + auto lo = astype(low, dtype, stream); + auto hi = astype(high, dtype, stream); + auto range = subtract(hi, lo, stream); auto out_shape = broadcast_shapes(shape, range.shape()); if (out_shape != shape) { std::ostringstream msg; @@ -136,7 +138,7 @@ array uniform( auto out = bits(shape, size_of(dtype), key, stream); out = astype(divide(out, maxval, stream), dtype, stream); out = minimum(out, upper, stream); - return add(multiply(range, out, stream), low, stream); + return add(multiply(range, out, stream), lo, stream); } array uniform( diff --git a/mlx/scheduler.h b/mlx/scheduler.h index 6506b20ab..150cc96db 100644 --- a/mlx/scheduler.h +++ b/mlx/scheduler.h @@ -35,6 +35,7 @@ struct StreamThread { } void thread_fn() { + auto thread_pool = metal::new_scoped_memory_pool(); metal::new_stream(stream); while (true) { std::function task; diff --git a/python/mlx/nn/layers/__init__.py b/python/mlx/nn/layers/__init__.py index 5aaedd348..98261d9d0 100644 --- a/python/mlx/nn/layers/__init__.py +++ b/python/mlx/nn/layers/__init__.py @@ -33,10 +33,16 @@ from mlx.nn.layers.activations import ( from mlx.nn.layers.base import Module from mlx.nn.layers.containers import Sequential from mlx.nn.layers.convolution import Conv1d, Conv2d -from mlx.nn.layers.dropout import Dropout +from mlx.nn.layers.dropout import Dropout, Dropout2d from mlx.nn.layers.embedding import Embedding from mlx.nn.layers.linear import Linear -from mlx.nn.layers.normalization import GroupNorm, InstanceNorm, LayerNorm, RMSNorm +from mlx.nn.layers.normalization import ( + BatchNorm, + GroupNorm, + InstanceNorm, + LayerNorm, + RMSNorm, +) from mlx.nn.layers.positional_encoding import ALiBi, RoPE, SinusoidalPositionalEncoding from mlx.nn.layers.quantized import QuantizedLinear from mlx.nn.layers.transformer import ( diff --git a/python/mlx/nn/layers/dropout.py b/python/mlx/nn/layers/dropout.py index 3193cdbd7..caa7a6452 100644 --- a/python/mlx/nn/layers/dropout.py +++ b/python/mlx/nn/layers/dropout.py @@ -5,7 +5,7 @@ from mlx.nn.layers.base import Module class 
Dropout(Module):
-    """Randomly zero a portion of the elements during training.
+    r"""Randomly zero a portion of the elements during training.
 
     The remaining elements are multiplied with :math:`\frac{1}{1-p}` where
     :math:`p` is the probability of zeroing an element. This is done so the
@@ -32,4 +32,57 @@ class Dropout(Module):
 
         mask = mx.random.bernoulli(self._p_1, x.shape)
 
-        return (1 / self._p_1) * mask.astype(x.dtype) * x
+        return (1 / self._p_1) * mask * x
+
+
+class Dropout2d(Module):
+    r"""Apply 2D channel-wise dropout during training.
+
+    Randomly zero out entire channels independently with probability :math:`p`.
+    This layer expects the channels to be last, i.e. the input shape should be
+    ``NWHC`` or ``WHC`` where ``N`` is the batch dimension, ``H`` is the input
+    image height, ``W`` is the input image width, and ``C`` is the number of
+    input channels.
+
+    The remaining channels are scaled by :math:`\frac{1}{1-p}` to
+    maintain the expected value of each element. Unlike traditional dropout,
+    which zeros individual entries, this layer zeros entire channels. This is
+    beneficial for early convolution layers where adjacent pixels are
+    correlated. In such cases, traditional dropout may not effectively
+    regularize activations. For more details, see [1].
+
+    [1]: Tompson, J., Goroshin, R., Jain, A., LeCun, Y. and Bregler, C., 2015.
+    Efficient Object Localization Using Convolutional Networks. CVPR 2015.
+
+    Args:
+        p (float): Probability of zeroing a channel during training.
+    """
+
+    def __init__(self, p: float = 0.5):
+        super().__init__()
+
+        if p < 0 or p >= 1:
+            raise ValueError("The dropout probability should be in [0, 1)")
+
+        self._p_1 = 1 - p
+
+    def _extra_repr(self):
+        return f"p={1-self._p_1}"
+
+    def __call__(self, x):
+        if x.ndim not in (3, 4):
+            raise ValueError(
+                f"Received input with {x.ndim} dimensions. Expected 3 or 4 dimensions."
+            )
+
+        if self._p_1 == 1 or not self.training:
+            return x
+
+        # Dropout is applied on the whole channel
+        # 3D input: (1, 1, C)
+        # 4D input: (B, 1, 1, C)
+        mask_shape = x.shape
+        mask_shape[-2] = mask_shape[-3] = 1
+
+        mask = mx.random.bernoulli(p=self._p_1, shape=mask_shape)
+        return (1 / self._p_1) * mask * x
diff --git a/python/mlx/nn/layers/normalization.py b/python/mlx/nn/layers/normalization.py
index f87eddbc6..92d2ab3ce 100644
--- a/python/mlx/nn/layers/normalization.py
+++ b/python/mlx/nn/layers/normalization.py
@@ -1,5 +1,7 @@
 # Copyright © 2023 Apple Inc.
 
+from typing import Tuple
+
 import mlx.core as mx
 from mlx.nn.layers.base import Module
 
@@ -252,3 +254,121 @@ class GroupNorm(Module):
         )
         x = group_norm(x)
         return (self.weight * x + self.bias) if "weight" in self else x
+
+
+class BatchNorm(Module):
+    r"""Applies Batch Normalization over a 2D, 3D, or 4D input.
+
+    Computes
+
+    .. math::
+
+        y = \frac{x - E[x]}{\sqrt{Var[x] + \epsilon}} \gamma + \beta,
+
+    where :math:`\gamma` and :math:`\beta` are learned per feature dimension
+    parameters initialized at 1 and 0 respectively.
+
+    The input shape is specified as ``NC`` or ``NLC``, where ``N`` is the
+    batch, ``C`` is the number of features or channels, and ``L`` is the
+    sequence length. The output has the same shape as the input. For
+    four-dimensional arrays, the shape is ``NHWC``, where ``H`` and ``W`` are
+    the height and width respectively.
+
+    For more information on Batch Normalization, see the original paper `Batch
+    Normalization: Accelerating Deep Network Training by Reducing Internal
+    Covariate Shift <https://arxiv.org/abs/1502.03167>`_.
+
+    Args:
+        num_features (int): The feature dimension to normalize over.
+ eps (float, optional): A small additive constant for numerical + stability. Default: ``1e-5``. + momentum (float, optional): The momentum for updating the running + mean and variance. Default: ``0.1``. + affine (bool, optional): If ``True``, apply a learned affine + transformation after the normalization. Default: ``True``. + track_running_stats (bool, optional): If ``True``, track the + running mean and variance. Default: ``True``. + + Examples: + >>> import mlx.core as mx + >>> import mlx.nn as nn + >>> x = mx.random.normal((5, 4)) + >>> bn = nn.BatchNorm(num_features=4, affine=True) + >>> output = bn(x) + """ + + def __init__( + self, + num_features: int, + eps: float = 1e-5, + momentum: float = 0.1, + affine: bool = True, + track_running_stats: bool = True, + ): + super().__init__() + + self.num_features = num_features + self.eps = eps + self.momentum = momentum + self.track_running_stats = track_running_stats + + if affine: + self.weight = mx.ones((num_features,)) + self.bias = mx.zeros((num_features,)) + + if self.track_running_stats: + self._running_mean = mx.zeros((num_features,)) + self._running_var = mx.ones((num_features,)) + + def _extra_repr(self): + return ( + f"{self.num_features}, eps={self.eps}, " + f"momentum={self.momentum}, affine={'weight' in self}, " + f"track_running_stats={self.track_running_stats}" + ) + + def _calc_stats(self, x: mx.array) -> Tuple[mx.array, mx.array]: + """ + Calculate the mean and variance of the input tensor. + + Args: + x (mx.array): Input tensor. + + Returns: + tuple: Tuple containing mean and variance. + """ + reduction_axes = tuple(range(0, x.ndim - 1)) + means = mx.mean(x, axis=reduction_axes, keepdims=True) + var = mx.var(x, axis=reduction_axes, keepdims=True) + + if self.track_running_stats and self.training: + self._running_mean = ( + 1 - self.momentum + ) * self._running_mean + self.momentum * means + self._running_var = ( + 1 - self.momentum + ) * self._running_var + self.momentum * var + return means, var + + def __call__(self, x: mx.array) -> mx.array: + """ + Forward pass of BatchNorm. + + Args: + x (mx.array): Input tensor. + + Returns: + mx.array: Output tensor. + """ + + if x.ndim < 2 or x.ndim > 4: + raise ValueError( + f"Expected input tensor to have 2, 3 or 4 dimensions, but got {x.ndim}" + ) + + if self.training or not self.track_running_stats: + means, var = self._calc_stats(x) + else: + means, var = self._running_mean, self._running_var + x = (x - means) * mx.rsqrt(var + self.eps) + return (self.weight * x + self.bias) if "weight" in self else x diff --git a/python/mlx/nn/losses.py b/python/mlx/nn/losses.py index 755656e4f..91316fd04 100644 --- a/python/mlx/nn/losses.py +++ b/python/mlx/nn/losses.py @@ -1,5 +1,7 @@ # Copyright © 2023 Apple Inc. +import math + import mlx.core as mx from mlx.nn.layers.base import Module @@ -131,10 +133,6 @@ def mse_loss( f"targets shape {targets.shape}." ) - assert ( - predictions.shape == targets.shape - ), f"Shape of predictions {predictions.shape} and targets {targets.shape} must match" - loss = mx.square(predictions - targets) return _reduce(loss, reduction) @@ -283,3 +281,94 @@ def _reduce(loss: mx.array, reduction: str = "none"): return loss else: raise ValueError("Invalid reduction. Must be 'none', 'mean', or 'sum'.") + + +def hinge_loss( + inputs: mx.array, targets: mx.array, reduction: str = "none" +) -> mx.array: + r""" + Computes the hinge loss between inputs and targets. + + .. 
math:: + + \text{hinge}(y, y_{\text{pred}}) = \max(0, 1 - y \cdot y_{\text{pred}}) + + + Args: + inputs (array): The predicted values. + targets (array): The target values. They should be -1 or 1. + reduction (str, optional): Specifies the reduction to apply to the output: + ``'none'`` | ``'mean'`` | ``'sum'``. Default: ``'none'``. + + Returns: + array: The computed hinge loss. + """ + loss = mx.maximum(1 - inputs * targets, 0) + + return _reduce(loss, reduction) + + +def huber_loss( + inputs: mx.array, targets: mx.array, delta: float = 1.0, reduction: str = "none" +) -> mx.array: + r""" + Computes the Huber loss between inputs and targets. + + .. math:: + + L_{\delta}(a) = + \left\{ \begin{array}{ll} + \frac{1}{2} a^2 & \text{for } |a| \leq \delta, \\ + \delta \left( |a| - \frac{1}{2} \delta \right) & \text{otherwise.} + \end{array} \right. + + Args: + inputs (array): The predicted values. + targets (array): The target values. + delta (float, optional): The threshold at which to change between L1 and L2 loss. + Default: ``1.0``. + reduction (str, optional): Specifies the reduction to apply to the output: + ``'none'`` | ``'mean'`` | ``'sum'``. Default: ``'none'``. + + Returns: + array: The computed Huber loss. + """ + errors = inputs - targets + abs_errors = mx.abs(errors) + quadratic = mx.minimum(abs_errors, delta) + linear = abs_errors - quadratic + loss = 0.5 * quadratic**2 + delta * linear + + return _reduce(loss, reduction) + + +def log_cosh_loss( + inputs: mx.array, targets: mx.array, reduction: str = "none" +) -> mx.array: + r""" + Computes the log cosh loss between inputs and targets. + + Logcosh acts like L2 loss for small errors, ensuring stable gradients, + and like the L1 loss for large errors, reducing sensitivity to outliers. This + dual behavior offers a balanced, robust approach for regression tasks. + + .. math:: + + \text{logcosh}(y_{\text{true}}, y_{\text{pred}}) = + \frac{1}{n} \sum_{i=1}^{n} + \log(\cosh(y_{\text{pred}}^{(i)} - y_{\text{true}}^{(i)})) + + + Args: + inputs (array): The predicted values. + targets (array): The target values. + reduction (str, optional): Specifies the reduction to apply to the output: + ``'none'`` | ``'mean'`` | ``'sum'``. Default: ``'none'``. + + Returns: + array: The computed log cosh loss. 
+ """ + errors = inputs - targets + loss = mx.logaddexp(errors, -errors) - math.log(2) + + return _reduce(loss, reduction) diff --git a/python/src/CMakeLists.txt b/python/src/CMakeLists.txt index 5ab8a50bf..1ad9d207d 100644 --- a/python/src/CMakeLists.txt +++ b/python/src/CMakeLists.txt @@ -11,6 +11,7 @@ pybind11_add_module( ${CMAKE_CURRENT_SOURCE_DIR}/stream.cpp ${CMAKE_CURRENT_SOURCE_DIR}/transforms.cpp ${CMAKE_CURRENT_SOURCE_DIR}/random.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/linalg.cpp ) if (NOT MLX_PYTHON_BINDINGS_OUTPUT_DIRECTORY) diff --git a/python/src/array.cpp b/python/src/array.cpp index 74223cdda..2580de0ea 100644 --- a/python/src/array.cpp +++ b/python/src/array.cpp @@ -510,6 +510,14 @@ void init_array(py::module_& m) { "size", &array::size, R"pbdoc(Number of elments in the array.)pbdoc") .def_property_readonly( "ndim", &array::ndim, R"pbdoc(The array's dimension.)pbdoc") + .def_property_readonly( + "itemsize", + &array::itemsize, + R"pbdoc(The size of the array's datatype in bytes.)pbdoc") + .def_property_readonly( + "nbytes", + &array::nbytes, + R"pbdoc(The number of bytes in the array.)pbdoc") // TODO, this makes a deep copy of the shape // implement alternatives to use reference // https://pybind11.readthedocs.io/en/stable/advanced/cast/stl.html diff --git a/python/src/linalg.cpp b/python/src/linalg.cpp new file mode 100644 index 000000000..ea5474a70 --- /dev/null +++ b/python/src/linalg.cpp @@ -0,0 +1,180 @@ +// Copyright © 2023 Apple Inc. + +#include + +#include +#include + +#include "mlx/linalg.h" + +#include "python/src/load.h" +#include "python/src/utils.h" + +namespace py = pybind11; +using namespace py::literals; + +using namespace mlx::core; +using namespace mlx::core::linalg; + +void init_linalg(py::module_& parent_module) { + py::options options; + options.disable_function_signatures(); + + auto m = parent_module.def_submodule( + "linalg", "mlx.core.linalg: linear algebra routines."); + + m.def( + "norm", + [](const array& a, + const std::variant& ord_, + const std::variant>& axis_, + const bool keepdims, + const StreamOrDevice stream) { + std::optional> axis = std::nullopt; + if (auto pv = std::get_if(&axis_); pv) { + axis = std::vector{*pv}; + } else if (auto pv = std::get_if>(&axis_); pv) { + axis = *pv; + } + + if (std::holds_alternative(ord_)) { + return norm(a, axis, keepdims, stream); + } else { + if (auto pv = std::get_if(&ord_); pv) { + return norm(a, *pv, axis, keepdims, stream); + } + double ord; + if (auto pv = std::get_if(&ord_); pv) { + ord = *pv; + } else { + ord = std::get(ord_); + } + return norm(a, ord, axis, keepdims, stream); + } + }, + "a"_a, + py::pos_only(), + "ord"_a = none, + "axis"_a = none, + "keepdims"_a = false, + py::kw_only(), + "stream"_a = none, + R"pbdoc( + norm(a: array, /, ord: Union[None, scalar, str] = None, axis: Union[None, int, List[int]] = None, keepdims: bool = False, *, stream: Union[None, Stream, Device] = None) -> array + + Matrix or vector norm. + + This function computes vector or matrix norms depending on the value of + the ``ord`` and ``axis`` parameters. + + Args: + a (array): Input array. If ``axis`` is ``None``, ``a`` must be 1-D or 2-D, + unless ``ord`` is ``None``. If both ``axis`` and ``ord`` are ``None``, the + 2-norm of ``a.flatten`` will be returned. + ord (scalar or str, optional): Order of the norm (see table under ``Notes``). + If ``None``, the 2-norm (or Frobenius norm for matrices) will be computed + along the given ``axis``. Default: ``None``. 
+          axis (int or list(int), optional): If ``axis`` is an integer, it specifies the
+            axis of ``a`` along which to compute the vector norms. If ``axis`` is a
+            2-tuple, it specifies the axes that hold 2-D matrices, and the matrix
+            norms of these matrices are computed. If ``axis`` is ``None`` then
+            either a vector norm (when ``a`` is 1-D) or a matrix norm (when ``a`` is
+            2-D) is returned. Default: ``None``.
+          keepdims (bool, optional): If ``True``, the axes which are normed over are
+            left in the result as dimensions with size one. Default: ``False``.
+
+        Returns:
+          array: The output containing the norm(s).
+
+        Notes:
+          For values of ``ord < 1``, the result is, strictly speaking, not a
+          mathematical norm, but it may still be useful for various numerical
+          purposes.
+
+          The following norms can be calculated:
+
+          =====  ============================  ==========================
+          ord    norm for matrices             norm for vectors
+          =====  ============================  ==========================
+          None   Frobenius norm                2-norm
+          'fro'  Frobenius norm                --
+          inf    max(sum(abs(x), axis=1))      max(abs(x))
+          -inf   min(sum(abs(x), axis=1))      min(abs(x))
+          0      --                            sum(x != 0)
+          1      max(sum(abs(x), axis=0))      as below
+          -1     min(sum(abs(x), axis=0))      as below
+          2      2-norm (largest sing. value)  as below
+          -2     smallest singular value       as below
+          other  --                            sum(abs(x)**ord)**(1./ord)
+          =====  ============================  ==========================
+
+          .. warning::
+            Nuclear norm and norms based on singular values are not yet implemented.
+
+          The Frobenius norm is given by [1]_:
+
+          :math:`||A||_F = [\sum_{i,j} abs(a_{i,j})^2]^{1/2}`
+
+          The nuclear norm is the sum of the singular values.
+
+          Both the Frobenius and nuclear norm orders are only defined for
+          matrices and raise a ``ValueError`` when ``a.ndim != 2``.
+
+        References:
+          .. [1] G. H. Golub and C. F. Van Loan, *Matrix Computations*,
+             Baltimore, MD, Johns Hopkins University Press, 1985, pg. 15
+
+        Examples:
+          >>> import mlx.core as mx
+          >>> from mlx.core import linalg as la
+          >>> a = mx.arange(9) - 4
+          >>> a
+          array([-4, -3, -2, ..., 2, 3, 4], dtype=int32)
+          >>> b = a.reshape((3,3))
+          >>> b
+          array([[-4, -3, -2],
+                 [-1,  0,  1],
+                 [ 2,  3,  4]], dtype=int32)
+          >>> la.norm(a)
+          array(7.74597, dtype=float32)
+          >>> la.norm(b)
+          array(7.74597, dtype=float32)
+          >>> la.norm(b, 'fro')
+          array(7.74597, dtype=float32)
+          >>> la.norm(a, float("inf"))
+          array(4, dtype=float32)
+          >>> la.norm(b, float("inf"))
+          array(9, dtype=float32)
+          >>> la.norm(a, -float("inf"))
+          array(0, dtype=float32)
+          >>> la.norm(b, -float("inf"))
+          array(2, dtype=float32)
+          >>> la.norm(a, 1)
+          array(20, dtype=float32)
+          >>> la.norm(b, 1)
+          array(7, dtype=float32)
+          >>> la.norm(a, -1)
+          array(0, dtype=float32)
+          >>> la.norm(b, -1)
+          array(6, dtype=float32)
+          >>> la.norm(a, 2)
+          array(7.74597, dtype=float32)
+          >>> la.norm(a, 3)
+          array(5.84804, dtype=float32)
+          >>> la.norm(a, -3)
+          array(0, dtype=float32)
+          >>> c = mx.array([[ 1, 2, 3],
+          ...               [-1, 1, 4]])
+          >>> la.norm(c, axis=0)
+          array([1.41421, 2.23607, 5], dtype=float32)
+          >>> la.norm(c, axis=1)
+          array([3.74166, 4.24264], dtype=float32)
+          >>> la.norm(c, ord=1, axis=1)
+          array([6, 6], dtype=float32)
+          >>> m = mx.arange(8).reshape(2,2,2)
+          >>> la.norm(m, axis=(1,2))
+          array([3.74166, 11.225], dtype=float32)
+          >>> la.norm(m[0, :, :]), la.norm(m[1, :, :])
+          (array(3.74166, dtype=float32), array(11.225, dtype=float32))
+      )pbdoc");
+}
diff --git a/python/src/mlx.cpp b/python/src/mlx.cpp
index ebadf767d..d7cf15751 100644
--- a/python/src/mlx.cpp
+++ b/python/src/mlx.cpp
@@ -15,6 +15,7 @@ void init_ops(py::module_&);
 void init_transforms(py::module_&);
 void init_random(py::module_&);
 void init_fft(py::module_&);
+void init_linalg(py::module_&);
 
 PYBIND11_MODULE(core, m) {
   m.doc() = "mlx: A framework for machine learning on Apple silicon.";
@@ -29,5 +30,6 @@ PYBIND11_MODULE(core, m) {
   init_transforms(m);
   init_random(m);
   init_fft(m);
+  init_linalg(m);
   m.attr("__version__") = TOSTRING(_VERSION_);
 }
diff --git a/python/src/ops.cpp b/python/src/ops.cpp
index 23a6ec2c6..277ef596b 100644
--- a/python/src/ops.cpp
+++ b/python/src/ops.cpp
@@ -2129,7 +2129,7 @@ void init_ops(py::module_& m) {
             singleton dimensions, defaults to `False`.
 
         Returns:
-            array: The output array with the indices of the minimum values.
+            array: The output array with the indices of the maximum values.
       )pbdoc");
   m.def(
       "sort",
diff --git a/python/src/transforms.cpp b/python/src/transforms.cpp
index a592b4458..096d5a486 100644
--- a/python/src/transforms.cpp
+++ b/python/src/transforms.cpp
@@ -569,7 +569,7 @@ void init_transforms(py::module_& m) {
                 return lvalue
 
             # Returns lvalue, dlvalue/dparams
-            lvalue, grads = mx.value_and_grad(mse)
+            lvalue, grads = mx.value_and_grad(mse)(params, inputs, targets)
 
             def lasso(params, inputs, targets, a=1.0, b=1.0):
                 outputs = forward(params, inputs)
@@ -580,7 +580,7 @@ void init_transforms(py::module_& m) {
 
                 return loss, mse, l1
 
-            (loss, mse, l1), grads = mx.value_and_grad(lasso)
+            (loss, mse, l1), grads = mx.value_and_grad(lasso)(params, inputs, targets)
 
         Args:
             fun (function): A function which takes a variable number of
diff --git a/python/tests/test_array.py b/python/tests/test_array.py
index fb6a24cbc..847d6d142 100644
--- a/python/tests/test_array.py
+++ b/python/tests/test_array.py
@@ -84,6 +84,8 @@ class TestArray(mlx_tests.MLXTestCase):
         x = mx.array(1)
         self.assertEqual(x.size, 1)
         self.assertEqual(x.ndim, 0)
+        self.assertEqual(x.itemsize, 4)
+        self.assertEqual(x.nbytes, 4)
         self.assertEqual(x.shape, [])
         self.assertEqual(x.dtype, mx.int32)
         self.assertEqual(x.item(), 1)
diff --git a/python/tests/test_linalg.py b/python/tests/test_linalg.py
new file mode 100644
index 000000000..ac86c1e11
--- /dev/null
+++ b/python/tests/test_linalg.py
@@ -0,0 +1,94 @@
+# Copyright © 2023 Apple Inc.
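+#
+# These tests use NumPy's np.linalg.norm as the reference implementation and
+# compare mx.linalg.norm against it across vector and matrix orders.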
+
+import itertools
+import math
+import unittest
+
+import mlx.core as mx
+import mlx_tests
+import numpy as np
+
+
+class TestLinalg(mlx_tests.MLXTestCase):
+    def test_norm(self):
+        vector_ords = [None, 0.5, 0, 1, 2, 3, -1, float("inf"), -float("inf")]
+        matrix_ords = [None, "fro", -1, 1, float("inf"), -float("inf")]
+
+        for shape in [(3,), (2, 3), (2, 3, 3)]:
+            x_mx = mx.arange(1, math.prod(shape) + 1).reshape(shape)
+            x_np = np.arange(1, math.prod(shape) + 1).reshape(shape)
+            # Test when at least one axis is provided
+            for num_axes in range(1, len(shape)):
+                if num_axes == 1:
+                    ords = vector_ords
+                else:
+                    ords = matrix_ords
+                for axis in itertools.combinations(range(len(shape)), num_axes):
+                    for keepdims in [True, False]:
+                        for o in ords:
+                            out_np = np.linalg.norm(
+                                x_np, ord=o, axis=axis, keepdims=keepdims
+                            )
+                            out_mx = mx.linalg.norm(
+                                x_mx, ord=o, axis=axis, keepdims=keepdims
+                            )
+                            with self.subTest(
+                                shape=shape, ord=o, axis=axis, keepdims=keepdims
+                            ):
+                                self.assertTrue(
+                                    np.allclose(out_np, out_mx, atol=1e-5, rtol=1e-6)
+                                )
+
+        # Test only ord provided
+        for shape in [(3,), (2, 3)]:
+            x_mx = mx.arange(1, math.prod(shape) + 1).reshape(shape)
+            x_np = np.arange(1, math.prod(shape) + 1).reshape(shape)
+            for o in [None, 1, -1, float("inf"), -float("inf")]:
+                for keepdims in [True, False]:
+                    out_np = np.linalg.norm(x_np, ord=o, keepdims=keepdims)
+                    out_mx = mx.linalg.norm(x_mx, ord=o, keepdims=keepdims)
+                    with self.subTest(shape=shape, ord=o, keepdims=keepdims):
+                        self.assertTrue(
+                            np.allclose(out_np, out_mx, atol=1e-5, rtol=1e-6)
+                        )
+
+        # Test no ord and no axis provided
+        for shape in [(3,), (2, 3), (2, 3, 3)]:
+            x_mx = mx.arange(1, math.prod(shape) + 1).reshape(shape)
+            x_np = np.arange(1, math.prod(shape) + 1).reshape(shape)
+            for keepdims in [True, False]:
+                out_np = np.linalg.norm(x_np, keepdims=keepdims)
+                out_mx = mx.linalg.norm(x_mx, keepdims=keepdims)
+                with self.subTest(shape=shape, keepdims=keepdims):
+                    self.assertTrue(np.allclose(out_np, out_mx, atol=1e-5, rtol=1e-6))
+
+    def test_complex_norm(self):
+        for shape in [(3,), (2, 3), (2, 3, 3)]:
+            x_np = np.random.uniform(size=shape).astype(
+                np.float32
+            ) + 1j * np.random.uniform(size=shape).astype(np.float32)
+            x_mx = mx.array(x_np)
+            out_np = np.linalg.norm(x_np)
+            out_mx = mx.linalg.norm(x_mx)
+            with self.subTest(shape=shape):
+                self.assertTrue(np.allclose(out_np, out_mx, atol=1e-5, rtol=1e-6))
+            for num_axes in range(1, len(shape)):
+                for axis in itertools.combinations(range(len(shape)), num_axes):
+                    out_np = np.linalg.norm(x_np, axis=axis)
+                    out_mx = mx.linalg.norm(x_mx, axis=axis)
+                    with self.subTest(shape=shape, axis=axis):
+                        self.assertTrue(
+                            np.allclose(out_np, out_mx, atol=1e-5, rtol=1e-6)
+                        )
+
+        x_np = np.random.uniform(size=(4, 4)).astype(
+            np.float32
+        ) + 1j * np.random.uniform(size=(4, 4)).astype(np.float32)
+        x_mx = mx.array(x_np)
+        out_np = np.linalg.norm(x_np, ord="fro")
+        out_mx = mx.linalg.norm(x_mx, ord="fro")
+        self.assertTrue(np.allclose(out_np, out_mx, atol=1e-5, rtol=1e-6))
+
+
+if __name__ == "__main__":
+    unittest.main()
diff --git a/python/tests/test_nn.py b/python/tests/test_nn.py
index ae0d0fc56..070b15523 100644
--- a/python/tests/test_nn.py
+++ b/python/tests/test_nn.py
@@ -511,6 +511,143 @@ class TestNN(mlx_tests.MLXTestCase):
         self.assertTrue(x.shape == y.shape)
         self.assertTrue(np.allclose(y, expected_y, atol=1e-5))
 
+    def test_batch_norm(self):
+        mx.random.seed(42)
+        x = mx.random.normal((5, 4), dtype=mx.float32)
+
+        # Batch norm
+        bn = nn.BatchNorm(num_features=4, affine=True)
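+        # A fresh BatchNorm starts with zero running mean and unit running
+        # variance; the next two assertions check that initial state.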
+        self.assertTrue(mx.allclose(bn._running_mean, mx.zeros_like(bn._running_mean)))
+        self.assertTrue(mx.allclose(bn._running_var, mx.ones_like(bn._running_var)))
+        y = bn(x)
+        expected_y = mx.array(
+            [
+                [-0.439520, 1.647328, -0.955515, 1.966031],
+                [-1.726690, -1.449826, -0.234026, -0.723364],
+                [0.938414, -0.349603, -0.354470, -0.175369],
+                [0.305006, 0.234914, -0.393017, -0.459385],
+                [0.922789, -0.082813, 1.937028, -0.607913],
+            ],
+        )
+        expected_mean = mx.array([0.008929, 0.005680, -0.016092, 0.027778])
+        expected_var = mx.array([0.928435, 1.00455, 1.04117, 0.94258])
+        self.assertTrue(x.shape == y.shape)
+        self.assertTrue(mx.allclose(y, expected_y, atol=1e-5))
+        self.assertTrue(mx.allclose(bn._running_mean, expected_mean, atol=1e-5))
+        self.assertTrue(mx.allclose(bn._running_var, expected_var, atol=1e-5))
+
+        # test eval mode
+        bn.eval()
+        y = bn(x)
+        expected_y = mx.array(
+            [
+                [-0.15984, 1.73159, -1.25456, 1.57891],
+                [-0.872193, -1.4281, -0.414439, -0.228678],
+                [0.602743, -0.30566, -0.554687, 0.139639],
+                [0.252199, 0.29066, -0.599572, -0.0512532],
+                [0.594096, -0.0334829, 2.11359, -0.151081],
+            ]
+        )
+
+        self.assertTrue(x.shape == y.shape)
+        self.assertTrue(mx.allclose(y, expected_y, atol=1e-5))
+
+        # test_no_affine
+        bn = nn.BatchNorm(num_features=4, affine=False)
+        y = bn(x)
+        expected_y = mx.array(
+            [
+                [-0.439520, 1.647328, -0.955515, 1.966031],
+                [-1.726690, -1.449826, -0.234026, -0.723364],
+                [0.938414, -0.349603, -0.354470, -0.175369],
+                [0.305006, 0.234914, -0.393017, -0.459385],
+                [0.922789, -0.082813, 1.937028, -0.607913],
+            ]
+        )
+        self.assertTrue(x.shape == y.shape)
+        self.assertTrue(mx.allclose(y, expected_y, atol=1e-5))
+
+        # test with 3D input
+        mx.random.seed(42)
+        N = 2
+        L = 4
+        C = 5
+        x = mx.random.normal((N, L, C), dtype=mx.float32)
+
+        # Batch norm
+        bn = nn.BatchNorm(num_features=C, affine=True)
+        self.assertTrue(mx.allclose(bn._running_mean, mx.zeros_like(bn._running_mean)))
+        self.assertTrue(mx.allclose(bn._running_var, mx.ones_like(bn._running_var)))
+        y = bn(x)
+        self.assertTrue(x.shape == y.shape)
+        expected_y = mx.array(
+            [
+                [
+                    [-0.335754, 0.342054, 1.02653, 0.628588, -1.63899],
+                    [1.92092, 0.432319, 0.343043, 1.95489, 1.0696],
+                    [-0.853748, 1.3661, 0.868569, 0.0199196, -0.887284],
+                    [0.459206, -0.684822, -0.706354, -0.271531, 0.566341],
+                ],
+                [
+                    [-0.921179, 0.684951, -0.77466, -0.490372, -0.247032],
+                    [1.10839, -2.13179, 0.628924, -1.62639, -0.539708],
+                    [-0.348943, 0.412194, -2.03818, 0.524972, 1.64568],
+                    [-1.02889, -0.421, 0.652127, -0.740079, 0.0313996],
+                ],
+            ]
+        )
+        self.assertTrue(mx.allclose(y, expected_y, atol=1e-5))
+        expected_mean = mx.array(
+            [[[0.00207845, -5.3259e-05, 0.04755, -0.0697296, 0.0236228]]]
+        )
+        expected_var = mx.array([[[0.968415, 1.05322, 0.96913, 0.932305, 0.967224]]])
+        self.assertTrue(mx.allclose(bn._running_mean, expected_mean, atol=1e-5))
+        self.assertTrue(mx.allclose(bn._running_var, expected_var, atol=1e-5))
+
+        x = mx.random.normal((N, L, C, L, C), dtype=mx.float32)
+        with self.assertRaises(ValueError):
+            y = bn(x)
+
+    def test_batch_norm_stats(self):
+        batch_size = 2
+        num_features = 4
+        h = 3
+        w = 3
+        momentum = 0.1
+
+        batch_norm = nn.BatchNorm(num_features)
+
+        batch_norm.train()
+        running_mean = np.array(batch_norm._running_mean)
+        running_var = np.array(batch_norm._running_var)
+
+        data = mx.random.normal((batch_size, num_features))
+
+        normalized_data = batch_norm(data)
+        np_data = np.array(data)
+        means = np.mean(np_data, axis=0)
+        variances = np.var(np_data, axis=0)
+        running_mean = (1 - momentum) * running_mean + momentum * means
+        running_var = (1 - momentum) * running_var + momentum * variances
+        self.assertTrue(np.allclose(batch_norm._running_mean, running_mean, atol=1e-5))
+        self.assertTrue(np.allclose(batch_norm._running_var, running_var, atol=1e-5))
+
+        batch_norm = nn.BatchNorm(num_features)
+
+        batch_norm.train()
+        running_mean = np.array(batch_norm._running_mean)
+        running_var = np.array(batch_norm._running_var)
+        data = mx.random.normal((batch_size, h, w, num_features))
+
+        normalized_data = batch_norm(data)
+        np_data = np.array(data)
+        means = np.mean(np_data, axis=(0, 1, 2))
+        variances = np.var(np_data, axis=(0, 1, 2))
+        running_mean = (1 - momentum) * running_mean + momentum * means
+        running_var = (1 - momentum) * running_var + momentum * variances
+        self.assertTrue(np.allclose(batch_norm._running_mean, running_mean, atol=1e-5))
+        self.assertTrue(np.allclose(batch_norm._running_var, running_var, atol=1e-5))
+
     def test_conv1d(self):
         N = 5
         L = 12
@@ -772,6 +909,24 @@ class TestNN(mlx_tests.MLXTestCase):
         y = alibi(x.astype(mx.float16))
         self.assertTrue(y.dtype, mx.float16)
 
+    def test_hinge_loss(self):
+        inputs = mx.ones((2, 4))
+        targets = mx.zeros((2, 4))
+        loss = nn.losses.hinge_loss(inputs, targets, reduction="mean")
+        self.assertEqual(loss, 1.0)
+
+    def test_huber_loss(self):
+        inputs = mx.ones((2, 4))
+        targets = mx.zeros((2, 4))
+        loss = nn.losses.huber_loss(inputs, targets, reduction="mean")
+        self.assertEqual(loss, 0.5)
+
+    def test_log_cosh_loss(self):
+        inputs = mx.ones((2, 4))
+        targets = mx.zeros((2, 4))
+        loss = nn.losses.log_cosh_loss(inputs, targets, reduction="mean")
+        self.assertAlmostEqual(loss.item(), 0.433781, places=6)
+
 
 if __name__ == "__main__":
     unittest.main()
diff --git a/python/tests/test_quantized.py b/python/tests/test_quantized.py
index 5fcc882a5..049f92fdb 100644
--- a/python/tests/test_quantized.py
+++ b/python/tests/test_quantized.py
@@ -13,7 +13,8 @@ class TestQuantized(mlx_tests.MLXTestCase):
         w_q, scales, biases = mx.quantize(w, 64, b)
         w_hat = mx.dequantize(w_q, scales, biases, 64, b)
         errors = (w - w_hat).abs().reshape(*scales.shape, -1)
-        self.assertTrue((errors <= scales[..., None] / 2).all())
+        eps = 1e-6
+        self.assertTrue((errors <= (scales[..., None] / 2 + eps)).all())
 
     def test_qmm(self):
         key = mx.random.key(0)
diff --git a/python/tests/test_random.py b/python/tests/test_random.py
index 1603371b3..aa01339f4 100644
--- a/python/tests/test_random.py
+++ b/python/tests/test_random.py
@@ -58,6 +58,9 @@ class TestRandom(mlx_tests.MLXTestCase):
         a = mx.random.uniform(shape=(1000,), low=mx.array(-1), high=5)
         self.assertTrue(mx.all((a > -1) < 5).item())
 
+        a = mx.random.uniform(low=-0.1, high=0.1, shape=(1,), dtype=mx.bfloat16)
+        self.assertEqual(a.dtype, mx.bfloat16)
+
     def test_normal(self):
         key = mx.random.key(0)
         a = mx.random.normal(key=key)
diff --git a/setup.py b/setup.py
index 4711a87c9..f5d04e959 100644
--- a/setup.py
+++ b/setup.py
@@ -165,7 +165,7 @@ if __name__ == "__main__":
 
     setup(
         name="mlx",
-        version=get_version("0.0.5"),
+        version=get_version("0.0.6"),
         author="MLX Contributors",
         author_email="mlx@group.apple.com",
         description="A framework for machine learning on Apple silicon.",
diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt
index 0879aa0f6..dbc499205 100644
--- a/tests/CMakeLists.txt
+++ b/tests/CMakeLists.txt
@@ -31,6 +31,7 @@ target_sources(tests PRIVATE
   scheduler_tests.cpp
   utils_tests.cpp
   vmap_tests.cpp
+  linalg_tests.cpp
   ${METAL_TEST_SOURCES}
 )
 
diff --git a/tests/linalg_tests.cpp b/tests/linalg_tests.cpp
new file mode 100644
index 000000000..1d8ee43d9
--- /dev/null
+++ b/tests/linalg_tests.cpp
@@ -0,0 +1,250 @@
+// Copyright © 2023 Apple Inc.
+
+#include "doctest/doctest.h"
+
+#include <cmath>
+
+#include "mlx/mlx.h"
+
+using namespace mlx::core;
+using namespace mlx::core::linalg;
+
+TEST_CASE("[mlx.core.linalg.norm] no ord") {
+  // Zero dimensions
+  array x(2.0);
+  CHECK_EQ(norm(x).item<float>(), 2.0f);
+  CHECK_THROWS(norm(x, 0));
+
+  x = array({1, 2, 3});
+  float expected = std::sqrt(1 + 4 + 9);
+  CHECK_EQ(norm(x).item<float>(), doctest::Approx(expected));
+  CHECK_EQ(norm(x, 0, false).item<float>(), doctest::Approx(expected));
+  CHECK_EQ(norm(x, -1, false).item<float>(), doctest::Approx(expected));
+  CHECK_EQ(norm(x, -1, true).ndim(), 1);
+  CHECK_THROWS(norm(x, 1));
+
+  x = reshape(arange(9), {3, 3});
+  expected =
+      std::sqrt(0 + 1 + 2 * 2 + 3 * 3 + 4 * 4 + 5 * 5 + 6 * 6 + 7 * 7 + 8 * 8);
+
+  CHECK_EQ(norm(x).item<float>(), doctest::Approx(expected));
+  CHECK_EQ(
+      norm(x, std::vector<int>{0, 1}).item<float>(),
+      doctest::Approx(expected));
+  CHECK(array_equal(
+            norm(x, 0, false),
+            array(
+                {std::sqrt(0 + 3 * 3 + 6 * 6),
+                 std::sqrt(1 + 4 * 4 + 7 * 7),
+                 std::sqrt(2 * 2 + 5 * 5 + 8 * 8)}))
+            .item<bool>());
+  CHECK(allclose(
+            norm(x, 1, false),
+            array(
+                {std::sqrt(0 + 1 + 2 * 2),
+                 std::sqrt(3 * 3 + 4 * 4 + 5 * 5),
+                 std::sqrt(6 * 6 + 7 * 7 + 8 * 8)}))
+            .item<bool>());
+
+  x = reshape(arange(18), {2, 3, 3});
+  CHECK(allclose(
+            norm(x, 2, false),
+            array(
+                {
+                    std::sqrt(0 + 1 + 2 * 2),
+                    std::sqrt(3 * 3 + 4 * 4 + 5 * 5),
+                    std::sqrt(6 * 6 + 7 * 7 + 8 * 8),
+                    std::sqrt(9 * 9 + 10 * 10 + 11 * 11),
+                    std::sqrt(12 * 12 + 13 * 13 + 14 * 14),
+                    std::sqrt(15 * 15 + 16 * 16 + 17 * 17),
+                },
+                {2, 3}))
+            .item<bool>());
+  CHECK(allclose(
+            norm(x, std::vector<int>{1, 2}, false),
+            array(
+                {std::sqrt(
+                     0 + 1 + 2 * 2 + 3 * 3 + 4 * 4 + 5 * 5 + 6 * 6 + 7 * 7 +
+                     8 * 8),
+                 std::sqrt(
+                     9 * 9 + 10 * 10 + 11 * 11 + 12 * 12 + 13 * 13 + 14 * 14 +
+                     15 * 15 + 16 * 16 + 17 * 17)},
+                {2}))
+            .item<bool>());
+  CHECK_THROWS(norm(x, std::vector<int>{0, 1, 2}));
+}
+
+TEST_CASE("[mlx.core.linalg.norm] double ord") {
+  CHECK_THROWS(norm(array(0), 2.0));
+
+  array x({1, 2, 3});
+
+  float expected = std::sqrt(1 + 4 + 9);
+  CHECK_EQ(norm(x, 2.0).item<float>(), doctest::Approx(expected));
+  CHECK_EQ(norm(x, 2.0, 0).item<float>(), doctest::Approx(expected));
+  CHECK_THROWS(norm(x, 2.0, 1));
+
+  expected = 1 + 2 + 3;
+  CHECK_EQ(norm(x, 1.0).item<float>(), doctest::Approx(expected));
+
+  expected = 3;
+  CHECK_EQ(norm(x, 0.0).item<float>(), doctest::Approx(expected));
+
+  expected = 3;
+  CHECK_EQ(
+      norm(x, std::numeric_limits<double>::infinity()).item<float>(),
+      doctest::Approx(expected));
+
+  expected = 1;
+  CHECK_EQ(
+      norm(x, -std::numeric_limits<double>::infinity()).item<float>(),
+      doctest::Approx(expected));
+
+  x = reshape(arange(9), {3, 3});
+
+  CHECK(allclose(
+            norm(x, 2.0, 0, false),
+            array(
+                {std::sqrt(0 + 3 * 3 + 6 * 6),
+                 std::sqrt(1 + 4 * 4 + 7 * 7),
+                 std::sqrt(2 * 2 + 5 * 5 + 8 * 8)}))
+            .item<bool>());
+  CHECK(allclose(
+            norm(x, 2.0, 1, false),
+            array(
+                {sqrt(0 + 1 + 2 * 2),
+                 sqrt(3 * 3 + 4 * 4 + 5 * 5),
+                 sqrt(6 * 6 + 7 * 7 + 8 * 8)}))
+            .item<bool>());
+
+  CHECK_EQ(
+      norm(x, 1.0, std::vector<int>{0, 1}).item<float>(),
+      doctest::Approx(15.0));
+  CHECK_EQ(
+      norm(x, 1.0, std::vector<int>{1, 0}).item<float>(),
+      doctest::Approx(21.0));
+  CHECK_EQ(
+      norm(x, -1.0, std::vector<int>{0, 1}).item<float>(),
+      doctest::Approx(9.0));
+  CHECK_EQ(
+      norm(x, -1.0, std::vector<int>{1, 0}).item<float>(),
+      doctest::Approx(3.0));
+  CHECK_EQ(
+      norm(x, 1.0, std::vector<int>{0, 1}, true).shape(),
+      std::vector<int>{1, 1});
+  CHECK_EQ(
+      norm(x, 1.0, std::vector<int>{1, 0}, true).shape(),
+      std::vector<int>{1, 1});
+  CHECK_EQ(
+      norm(x, -1.0, std::vector<int>{0, 1}, true).shape(),
+      std::vector<int>{1, 1});
+  CHECK_EQ(
+      norm(x, -1.0, std::vector<int>{1, 0}, true).shape(),
+      std::vector<int>{1, 1});
+
+  CHECK_EQ(
+      norm(x, -1.0, std::vector<int>{-2, -1}, false).item<float>(),
+      doctest::Approx(9.0));
+  CHECK_EQ(
+      norm(x, 1.0, std::vector<int>{-2, -1}, false).item<float>(),
+      doctest::Approx(15.0));
+
+  x = reshape(arange(18), {2, 3, 3});
+  CHECK_THROWS(norm(x, 2.0, std::vector<int>{0, 1, 2}));
+  CHECK(allclose(
+            norm(x, 3.0, 0),
+            array(
+                {9.,
+                 10.00333222,
+                 11.02199456,
+                 12.06217728,
+                 13.12502645,
+                 14.2094363,
+                 15.31340617,
+                 16.43469751,
+                 17.57113899},
+                {3, 3}))
+            .item<bool>());
+  CHECK(allclose(
+            norm(x, 3.0, 2),
+            array(
+                {2.08008382,
+                 6.,
+                 10.23127655,
+                 14.5180117,
+                 18.82291607,
+                 23.13593104},
+                {2, 3}))
+            .item<bool>());
+  CHECK(
+      allclose(
+          norm(x, 0.0, 0), array({1., 2., 2., 2., 2., 2., 2., 2., 2.}, {3, 3}))
+          .item<bool>());
+  CHECK(allclose(norm(x, 0.0, 1), array({2., 3., 3., 3., 3., 3.}, {2, 3}))
+            .item<bool>());
+  CHECK(allclose(norm(x, 0.0, 2), array({2., 3., 3., 3., 3., 3.}, {2, 3}))
+            .item<bool>());
+  CHECK(allclose(
+            norm(x, 1.0, 0),
+            array({9., 11., 13., 15., 17., 19., 21., 23., 25.}, {3, 3}))
+            .item<bool>());
+  CHECK(allclose(norm(x, 1.0, 1), array({9., 12., 15., 36., 39., 42.}, {2, 3}))
+            .item<bool>());
+  CHECK(allclose(norm(x, 1.0, 2), array({3., 12., 21., 30., 39., 48.}, {2, 3}))
+            .item<bool>());
+
+  CHECK(allclose(norm(x, 1.0, std::vector<int>{0, 1}), array({21., 23., 25.}))
+            .item<bool>());
+  CHECK(allclose(norm(x, 1.0, std::vector<int>{1, 2}), array({15., 42.}))
+            .item<bool>());
+  CHECK(allclose(norm(x, -1.0, std::vector<int>{0, 1}), array({9., 11., 13.}))
+            .item<bool>());
+  CHECK(allclose(norm(x, -1.0, std::vector<int>{1, 2}), array({9., 36.}))
+            .item<bool>());
+  CHECK(allclose(norm(x, -1.0, std::vector<int>{1, 0}), array({9., 12., 15.}))
+            .item<bool>());
+  CHECK(allclose(norm(x, -1.0, std::vector<int>{2, 1}), array({3, 30}))
+            .item<bool>());
+  CHECK(allclose(norm(x, -1.0, std::vector<int>{1, 2}), array({9, 36}))
+            .item<bool>());
+}
+
+TEST_CASE("[mlx.core.linalg.norm] string ord") {
+  array x({1, 2, 3});
+  CHECK_THROWS(norm(x, "fro"));
+
+  x = reshape(arange(9), {3, 3});
+  CHECK_THROWS(norm(x, "bad ord"));
+
+  CHECK_EQ(
+      norm(x, "f", std::vector<int>{0, 1}).item<float>(),
+      doctest::Approx(14.2828568570857));
+  CHECK_EQ(
+      norm(x, "fro", std::vector<int>{0, 1}).item<float>(),
+      doctest::Approx(14.2828568570857));
+
+  x = reshape(arange(18), {2, 3, 3});
+  CHECK(allclose(
+            norm(x, "fro", std::vector<int>{0, 1}),
+            array({22.24859546, 24.31049156, 26.43860813}))
+            .item<bool>());
+  CHECK(allclose(
+            norm(x, "fro", std::vector<int>{1, 2}),
+            array({14.28285686, 39.7617907}))
+            .item<bool>());
+  CHECK(allclose(
+            norm(x, "f", std::vector<int>{0, 1}),
+            array({22.24859546, 24.31049156, 26.43860813}))
+            .item<bool>());
+  CHECK(allclose(
+            norm(x, "f", std::vector<int>{1, 0}),
+            array({22.24859546, 24.31049156, 26.43860813}))
+            .item<bool>());
+  CHECK(allclose(
+            norm(x, "f", std::vector<int>{1, 2}),
+            array({14.28285686, 39.7617907}))
+            .item<bool>());
+  CHECK(allclose(
+            norm(x, "f", std::vector<int>{2, 1}),
+            array({14.28285686, 39.7617907}))
+            .item<bool>());
+}
diff --git a/tests/random_tests.cpp b/tests/random_tests.cpp
index 1a387febc..b7793e41c 100644
--- a/tests/random_tests.cpp
+++ b/tests/random_tests.cpp
@@ -260,6 +260,10 @@ TEST_CASE("test random uniform") {
   // Non float type throws
   CHECK_THROWS_AS(random::uniform({}, int32), std::invalid_argument);
 
+  // dtype respected
+  x = random::uniform(-.1, .1, {0}, bfloat16);
+  CHECK_EQ(x.dtype(), bfloat16);
+
   // Check broadcasting
   x = random::uniform(zeros({3, 1}), ones({1, 3}), {3, 3});
   CHECK_EQ(x.shape(), std::vector<int>{3, 3});
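
A quick usage sketch of the user-facing APIs this patch adds (mx.linalg.norm, the array.itemsize and array.nbytes properties, and dtype-respecting random.uniform). This is illustrative only, not part of the patch, and assumes an MLX build that already includes the changes above:

    import mlx.core as mx

    # Vector and matrix norms from the new mx.linalg submodule.
    a = mx.arange(9, dtype=mx.float32) - 4
    print(mx.linalg.norm(a))                         # 2-norm of the vector
    print(mx.linalg.norm(a.reshape((3, 3)), "fro"))  # Frobenius norm

    # The new itemsize/nbytes properties (an int32 scalar is 4 bytes).
    x = mx.array(1)
    print(x.itemsize, x.nbytes)

    # random.uniform now honors a low-precision dtype.
    u = mx.random.uniform(low=-0.1, high=0.1, shape=(1,), dtype=mx.bfloat16)
    print(u.dtype)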