diff --git a/mlx/backend/common/utils.cpp b/mlx/backend/common/utils.cpp index 35bba9c635..08df53a8ed 100644 --- a/mlx/backend/common/utils.cpp +++ b/mlx/backend/common/utils.cpp @@ -1,9 +1,16 @@ // Copyright © 2023-2024 Apple Inc. #include "mlx/backend/common/utils.h" +#include "mlx/primitives.h" namespace mlx::core { +std::string get_primitive_string(Primitive* primitive) { + std::ostringstream op_t; + primitive->print(op_t); + return op_t.str(); +} + std::tuple> collapse_contiguous_dims( const Shape& shape, const std::vector& strides, @@ -101,4 +108,105 @@ std::pair collapse_contiguous_dims( return collapse_contiguous_dims(a.shape(), a.strides(), size_cap); } +Dims get_block_dims_common(int dim0, int dim1, int dim2, int pow2 /* = 10 */) { + int pows[3] = {0, 0, 0}; + int sum = 0; + while (true) { + int presum = sum; + // Check all the pows + if (dim0 >= (1 << (pows[0] + 1))) { + pows[0]++; + sum++; + } + if (sum == 10) { + break; + } + if (dim1 >= (1 << (pows[1] + 1))) { + pows[1]++; + sum++; + } + if (sum == 10) { + break; + } + if (dim2 >= (1 << (pows[2] + 1))) { + pows[2]++; + sum++; + } + if (sum == presum || sum == pow2) { + break; + } + } + return std::make_tuple(1ul << pows[0], 1ul << pows[1], 1ul << pows[2]); +} + +Dims get_2d_grid_dims_common(const Shape& shape, const Strides& strides) { + // Dims with strides of 0 are ignored as they + // correspond to broadcasted dimensions + size_t grid_x = 1; + size_t grid_y = 1; + for (int i = 0; i < shape.size(); ++i) { + if (strides[i] == 0) { + continue; + } + if (grid_x * shape[i] < UINT32_MAX) { + grid_x *= shape[i]; + } else { + grid_y *= shape[i]; + } + } + if (grid_y > UINT32_MAX || grid_x > UINT32_MAX) { + throw std::runtime_error("Unable to safely factor shape."); + } + if (grid_y > grid_x) { + std::swap(grid_x, grid_y); + } + return std::make_tuple( + static_cast(grid_x), static_cast(grid_y), 1); +} + +Dims get_2d_grid_dims_common( + const Shape& shape, + const Strides& strides, + size_t divisor) { + // 
Compute the 2d grid dimensions such that the total size of the grid is + // divided by divisor. + size_t grid_x = 1; + size_t grid_y = 1; + for (int i = 0; i < shape.size(); ++i) { + if (strides[i] == 0) { + continue; + } + + // No need to add this shape we can just remove it from the divisor. + if (divisor % shape[i] == 0) { + divisor /= shape[i]; + continue; + } + + if (grid_x * shape[i] < UINT32_MAX) { + grid_x *= shape[i]; + } else { + grid_y *= shape[i]; + } + + if (divisor > 1) { + if (grid_x % divisor == 0) { + grid_x /= divisor; + divisor = 1; + } else if (grid_y % divisor == 0) { + grid_y /= divisor; + divisor = 1; + } + } + } + if (grid_y > UINT32_MAX || grid_x > UINT32_MAX || divisor > 1) { + throw std::runtime_error("Unable to safely factor shape."); + } + if (grid_y > grid_x) { + std::swap(grid_x, grid_y); + } + return std::make_tuple( + static_cast(grid_x), static_cast(grid_y), 1); +} + } // namespace mlx::core diff --git a/mlx/backend/common/utils.h b/mlx/backend/common/utils.h index a4bdaa5ca1..40bc3efe4a 100644 --- a/mlx/backend/common/utils.h +++ b/mlx/backend/common/utils.h @@ -2,12 +2,15 @@ #pragma once +#include #include #include "mlx/array.h" namespace mlx::core { +std::string get_primitive_string(Primitive* primitive); + inline int64_t elem_to_loc(int elem, const Shape& shape, const Strides& strides) { int64_t loc = 0; @@ -70,6 +73,28 @@ std::pair collapse_contiguous_dims( const array& a, int64_t size_cap = std::numeric_limits::max()); +// Compute the thread block dimensions which fit the given +// input dimensions. 
+// - The thread block dimensions will be powers of two +// - The thread block size will be at most 2^pow2 (with the default pow2) +using Dims = std::tuple; +Dims get_block_dims_common(int dim0, int dim1, int dim2, int pow2 = 10); + +// Computes a 2D grid where each element is < UINT_MAX +// Assumes: +// - overall size (product of non-broadcasted dimensions) is < UINT_MAX^2 +// - shape and strides correspond to a contiguous (no holes) but +// possibly broadcasted array +Dims get_2d_grid_dims_common(const Shape& shape, const Strides& strides); + +// Same as above but we do an implicit division with divisor. +// Basically, equivalent to factorizing +// Prod(s \forall s in shape if strides[s] > 0) / divisor. +Dims get_2d_grid_dims_common( + const Shape& shape, + const Strides& strides, + size_t divisor); + struct ContiguousIterator { inline void step() { int dims = shape_.size(); diff --git a/mlx/backend/cuda/CMakeLists.txt b/mlx/backend/cuda/CMakeLists.txt index 7ebe68324c..2a8ef99635 100644 --- a/mlx/backend/cuda/CMakeLists.txt +++ b/mlx/backend/cuda/CMakeLists.txt @@ -11,6 +11,7 @@ target_sources( ${CMAKE_CURRENT_SOURCE_DIR}/eval.cpp ${CMAKE_CURRENT_SOURCE_DIR}/event.cu ${CMAKE_CURRENT_SOURCE_DIR}/fence.cu + ${CMAKE_CURRENT_SOURCE_DIR}/kernel_utils.cu ${CMAKE_CURRENT_SOURCE_DIR}/primitives.cu ${CMAKE_CURRENT_SOURCE_DIR}/slicing.cpp ${CMAKE_CURRENT_SOURCE_DIR}/utils.cpp diff --git a/mlx/backend/cuda/kernel_utils.cu b/mlx/backend/cuda/kernel_utils.cu new file mode 100644 index 0000000000..575af7cf65 --- /dev/null +++ b/mlx/backend/cuda/kernel_utils.cu @@ -0,0 +1,26 @@ +// Copyright © 2025 Apple Inc. 
+ +#include "mlx/backend/common/utils.h" +#include "mlx/backend/cuda/kernel_utils.cuh" + +namespace mlx::core { + +dim3 get_block_dims(int dim0, int dim1, int dim2, int pow2) { + Dims dims = get_block_dims_common(dim0, dim1, dim2, pow2); + return dim3(std::get<0>(dims), std::get<1>(dims), std::get<2>(dims)); +} + +dim3 get_2d_grid_dims(const Shape& shape, const Strides& strides) { + Dims dims = get_2d_grid_dims_common(shape, strides); + return dim3(std::get<0>(dims), std::get<1>(dims), std::get<2>(dims)); +} + +dim3 get_2d_grid_dims( + const Shape& shape, + const Strides& strides, + size_t divisor) { + Dims dims = get_2d_grid_dims_common(shape, strides, divisor); + return dim3(std::get<0>(dims), std::get<1>(dims), std::get<2>(dims)); +} + +} // namespace mlx::core diff --git a/mlx/backend/cuda/dtype_utils.cuh b/mlx/backend/cuda/kernel_utils.cuh similarity index 53% rename from mlx/backend/cuda/dtype_utils.cuh rename to mlx/backend/cuda/kernel_utils.cuh index 9b7f8ba652..67ac47449e 100644 --- a/mlx/backend/cuda/dtype_utils.cuh +++ b/mlx/backend/cuda/kernel_utils.cuh @@ -1,7 +1,13 @@ // Copyright © 2025 Apple Inc. +// This file includes host-only utilities for writing CUDA kernels; the +// difference from backend/cuda/kernels/utils.cuh is that the latter file only +// includes device-only code. + #pragma once +#include "mlx/array.h" + #include #include #include @@ -32,4 +38,12 @@ struct CTypeToCudaType { template using cuda_type_t = typename CTypeToCudaType::type; +// Compute the grid and block dimensions, check backend/common/utils.h for docs. 
+dim3 get_block_dims(int dim0, int dim1, int dim2, int pow2 = 10); +dim3 get_2d_grid_dims(const Shape& shape, const Strides& strides); +dim3 get_2d_grid_dims( + const Shape& shape, + const Strides& strides, + size_t divisor); + } // namespace mlx::core diff --git a/mlx/backend/cuda/primitives.cu b/mlx/backend/cuda/primitives.cu index defdc746a1..d105a242b7 100644 --- a/mlx/backend/cuda/primitives.cu +++ b/mlx/backend/cuda/primitives.cu @@ -1,7 +1,7 @@ // Copyright © 2025 Apple Inc. #include "mlx/backend/cuda/device.h" -#include "mlx/backend/cuda/dtype_utils.cuh" +#include "mlx/backend/cuda/kernel_utils.cuh" #include "mlx/backend/cuda/kernels/arange.cuh" #include "mlx/backend/cuda/kernels/fp16_math.cuh" #include "mlx/distributed/primitives.h" diff --git a/mlx/backend/cuda/utils.h b/mlx/backend/cuda/utils.h index 58d5087657..6eaec8984f 100644 --- a/mlx/backend/cuda/utils.h +++ b/mlx/backend/cuda/utils.h @@ -1,5 +1,7 @@ // Copyright © 2025 Apple Inc. +// This file includes utilities that are used by C++ code (i.e. .cpp files). + #pragma once #include diff --git a/mlx/backend/metal/utils.cpp b/mlx/backend/metal/utils.cpp index 329d250dfc..9785018357 100644 --- a/mlx/backend/metal/utils.cpp +++ b/mlx/backend/metal/utils.cpp @@ -1,8 +1,7 @@ // Copyright © 2023-2024 Apple Inc. 
#include "mlx/backend/metal/utils.h" - -using namespace mlx; +#include "mlx/backend/common/utils.h" namespace mlx::core { @@ -59,109 +58,20 @@ std::string type_to_name(const array& a) { return type_to_name(a.dtype()); } -MTL::Size get_block_dims(int dim0, int dim1, int dim2, int pow2 /* = 10 */) { - int pows[3] = {0, 0, 0}; - int sum = 0; - while (true) { - int presum = sum; - // Check all the pows - if (dim0 >= (1 << (pows[0] + 1))) { - pows[0]++; - sum++; - } - if (sum == 10) { - break; - } - if (dim1 >= (1 << (pows[1] + 1))) { - pows[1]++; - sum++; - } - if (sum == 10) { - break; - } - if (dim2 >= (1 << (pows[2] + 1))) { - pows[2]++; - sum++; - } - if (sum == presum || sum == pow2) { - break; - } - } - return MTL::Size{1ul << pows[0], 1ul << pows[1], 1ul << pows[2]}; +MTL::Size get_block_dims(int dim0, int dim1, int dim2, int pow2) { + Dims dims = get_block_dims_common(dim0, dim1, dim2, pow2); + return MTL::Size(std::get<0>(dims), std::get<1>(dims), std::get<2>(dims)); } MTL::Size get_2d_grid_dims(const Shape& shape, const Strides& strides) { - // Dims with strides of 0 are ignored as they - // correspond to broadcasted dimensions - size_t grid_x = 1; - size_t grid_y = 1; - for (int i = 0; i < shape.size(); ++i) { - if (strides[i] == 0) { - continue; - } - if (grid_x * shape[i] < UINT32_MAX) { - grid_x *= shape[i]; - } else { - grid_y *= shape[i]; - } - } - if (grid_y > UINT32_MAX || grid_x > UINT32_MAX) { - throw std::runtime_error("Unable to safely factor shape."); - } - if (grid_y > grid_x) { - std::swap(grid_x, grid_y); - } - return MTL::Size( - static_cast(grid_x), static_cast(grid_y), 1); + Dims dims = get_2d_grid_dims_common(shape, strides); + return MTL::Size(std::get<0>(dims), std::get<1>(dims), std::get<2>(dims)); } MTL::Size get_2d_grid_dims(const Shape& shape, const Strides& strides, size_t divisor) { - // Compute the 2d grid dimensions such that the total size of the grid is - // divided by divisor. 
- size_t grid_x = 1; - size_t grid_y = 1; - for (int i = 0; i < shape.size(); ++i) { - if (strides[i] == 0) { - continue; - } - - // No need to add this shape we can just remove it from the divisor. - if (divisor % shape[i] == 0) { - divisor /= shape[i]; - continue; - } - - if (grid_x * shape[i] < UINT32_MAX) { - grid_x *= shape[i]; - } else { - grid_y *= shape[i]; - } - - if (divisor > 1) { - if (grid_x % divisor == 0) { - grid_x /= divisor; - divisor = 1; - } else if (grid_y % divisor == 0) { - grid_y /= divisor; - divisor = 1; - } - } - } - if (grid_y > UINT32_MAX || grid_x > UINT32_MAX || divisor > 1) { - throw std::runtime_error("Unable to safely factor shape."); - } - if (grid_y > grid_x) { - std::swap(grid_x, grid_y); - } - return MTL::Size( - static_cast(grid_x), static_cast(grid_y), 1); -} - -std::string get_primitive_string(Primitive* primitive) { - std::ostringstream op_t; - primitive->print(op_t); - return op_t.str(); + Dims dims = get_2d_grid_dims_common(shape, strides, divisor); + return MTL::Size(std::get<0>(dims), std::get<1>(dims), std::get<2>(dims)); } } // namespace mlx::core diff --git a/mlx/backend/metal/utils.h b/mlx/backend/metal/utils.h index f9245a6d60..576fb91078 100644 --- a/mlx/backend/metal/utils.h +++ b/mlx/backend/metal/utils.h @@ -13,22 +13,9 @@ namespace mlx::core { std::string type_to_name(const Dtype& t); std::string type_to_name(const array& a); -// Compute the thread block dimensions which fit the given -// input dimensions. -// - The thread block dimensions will be powers of two -// - The thread block size will be less than 2^pow2 +// Compute the grid and block dimensions, check backend/common/utils.h for docs. 
MTL::Size get_block_dims(int dim0, int dim1, int dim2, int pow2 = 10); - -// Computes a 2D grid where each element is < UINT_MAX -// Assumes: -// - overall size (product of non-broadcasted dimensions) is < UINT_MAX^2 -// - shape and strides correspond to a contiguous (no holes) but -// possibly broadcasted array MTL::Size get_2d_grid_dims(const Shape& shape, const Strides& strides); - -// Same as above but we do an implicit division with divisor. -// Basically, equivalent to factorizing -// Prod(s \forall s in shape if strides[s] > 0) / divisor. MTL::Size get_2d_grid_dims(const Shape& shape, const Strides& strides, size_t divisor); @@ -58,8 +45,6 @@ inline void debug_set_primitive_buffer_label( #endif } -std::string get_primitive_string(Primitive* primitive); - template constexpr bool is_numeric_except_char = std::is_arithmetic_v && !std::is_same_v && !std::is_same_v &&