diff --git a/mlx/backend/metal/binary.cpp b/mlx/backend/metal/binary.cpp
index cf2e1c83f..1c0dac55e 100644
--- a/mlx/backend/metal/binary.cpp
+++ b/mlx/backend/metal/binary.cpp
@@ -21,10 +21,43 @@ namespace mlx::core {
 
 constexpr int MAX_BINARY_SPECIALIZED_DIMS = 5;
 
+std::string get_kernel_name(
+    BinaryOpType bopt,
+    const std::string& op,
+    const array& a,
+    bool use_2d,
+    int ndim) {
+  std::ostringstream kname;
+  switch (bopt) {
+    case BinaryOpType::ScalarScalar:
+      kname << "ss";
+      break;
+    case BinaryOpType::ScalarVector:
+      kname << (use_2d ? "sv2" : "sv");
+      break;
+    case BinaryOpType::VectorScalar:
+      kname << (use_2d ? "vs2" : "vs");
+      break;
+    case BinaryOpType::VectorVector:
+      kname << (use_2d ? "vv2" : "vv");
+      break;
+    case BinaryOpType::General:
+      kname << "g";
+      if (ndim <= MAX_BINARY_SPECIALIZED_DIMS) {
+        kname << ndim;
+      } else {
+        kname << "n";
+      }
+      break;
+  }
+  kname << op << type_to_name(a);
+  return kname.str();
+}
+
 void binary_op_gpu_inplace(
     const std::vector<array>& inputs,
     std::vector<array>& outputs,
-    const std::string op,
+    const std::string& op,
     const Stream& s) {
   auto& a = inputs[0];
   auto& b = inputs[1];
@@ -41,35 +74,8 @@ void binary_op_gpu_inplace(
   auto& strides_b = strides[1];
   auto& strides_out = strides[2];
 
-  std::string kernel_name;
-  {
-    std::ostringstream kname;
-    switch (bopt) {
-      case BinaryOpType::ScalarScalar:
-        kname << "ss";
-        break;
-      case BinaryOpType::ScalarVector:
-        kname << "sv";
-        break;
-      case BinaryOpType::VectorScalar:
-        kname << "vs";
-        break;
-      case BinaryOpType::VectorVector:
-        kname << "vv";
-        break;
-      case BinaryOpType::General:
-        kname << "g";
-        if (shape.size() <= MAX_BINARY_SPECIALIZED_DIMS) {
-          kname << shape.size();
-        } else {
-          kname << "n";
-        }
-        break;
-    }
-    kname << op << type_to_name(a);
-    kernel_name = kname.str();
-  }
-
+  bool use_2d = out.data_size() > UINT32_MAX;
+  std::string kernel_name = get_kernel_name(bopt, op, a, use_2d, shape.size());
   auto& d = metal::device(s.device);
   auto kernel =
@@ -117,9 +123,11 @@ void binary_op_gpu_inplace(
     MTL::Size grid_dims = MTL::Size(dim0, dim1, rest);
     compute_encoder.dispatchThreads(grid_dims, group_dims);
   } else {
-    // Launch a 1D grid of threads
+    // Launch a 1D or 2D grid of threads
     size_t nthreads = out.data_size();
-    MTL::Size grid_dims = MTL::Size(nthreads, 1, 1);
+    MTL::Size grid_dims = use_2d
+        ? get_2d_grid_dims(outputs[0].shape(), outputs[0].strides())
+        : MTL::Size(nthreads, 1, 1);
     NS::UInteger thread_group_size = kernel->maxTotalThreadsPerThreadgroup();
     if (thread_group_size > nthreads) {
       thread_group_size = nthreads;
@@ -132,7 +140,7 @@ void binary_op_gpu_inplace(
 void binary_op_gpu(
     const std::vector<array>& inputs,
     std::vector<array>& outputs,
-    const std::string op,
+    const std::string& op,
     const Stream& s) {
   assert(inputs.size() == 2);
   auto& a = inputs[0];
@@ -146,7 +154,7 @@ void binary_op_gpu(
 void binary_op_gpu(
     const std::vector<array>& inputs,
     std::vector<array>& outputs,
-    const std::string op) {
+    const std::string& op) {
   auto& s = outputs[0].primitive().stream();
   binary_op_gpu(inputs, outputs, op, s);
 }
@@ -154,7 +162,7 @@ void binary_op_gpu(
 void binary_op_gpu_inplace(
     const std::vector<array>& inputs,
     array& out,
-    const std::string op,
+    const std::string& op,
     const Stream& s) {
   auto& a = inputs[0];
   auto& b = inputs[1];
@@ -169,35 +177,8 @@ void binary_op_gpu_inplace(
   auto& strides_b = strides[1];
   auto& strides_out = strides[2];
 
-  std::string kernel_name;
-  {
-    std::ostringstream kname;
-    switch (bopt) {
-      case BinaryOpType::ScalarScalar:
-        kname << "ss";
-        break;
-      case BinaryOpType::ScalarVector:
-        kname << "sv";
-        break;
-      case BinaryOpType::VectorScalar:
-        kname << "vs";
-        break;
-      case BinaryOpType::VectorVector:
-        kname << "vv";
-        break;
-      case BinaryOpType::General:
-        kname << "g";
-        if (shape.size() <= MAX_BINARY_SPECIALIZED_DIMS) {
-          kname << shape.size();
-        } else {
-          kname << "n";
-        }
-        break;
-    }
-    kname << op << type_to_name(a);
-    kernel_name = kname.str();
-  }
-
+  bool use_2d = out.data_size() > UINT32_MAX;
+  std::string kernel_name = get_kernel_name(bopt, op, a, use_2d, shape.size());
   auto& d = metal::device(s.device);
   auto kernel = get_binary_kernel(d, kernel_name, a.dtype(), out.dtype(), op);
@@ -237,10 +218,11 @@ void binary_op_gpu_inplace(
     MTL::Size grid_dims = MTL::Size(dim0, dim1, rest);
     compute_encoder.dispatchThreads(grid_dims, group_dims);
   } else {
-    // Launch a 1D grid of threads
-    size_t nthreads =
-        bopt == BinaryOpType::General ? out.size() : out.data_size();
-    MTL::Size grid_dims = MTL::Size(nthreads, 1, 1);
+    // Launch a 1D or 2D grid of threads
+
+    size_t nthreads = out.data_size();
+    MTL::Size grid_dims = use_2d ? get_2d_grid_dims(out.shape(), out.strides())
+                                 : MTL::Size(nthreads, 1, 1);
     NS::UInteger thread_group_size = kernel->maxTotalThreadsPerThreadgroup();
     if (thread_group_size > nthreads) {
       thread_group_size = nthreads;
@@ -253,7 +235,7 @@ void binary_op_gpu_inplace(
 void binary_op_gpu(
     const std::vector<array>& inputs,
     array& out,
-    const std::string op,
+    const std::string& op,
     const Stream& s) {
   assert(inputs.size() == 2);
   auto& a = inputs[0];
@@ -266,7 +248,7 @@ void binary_op_gpu(
 void binary_op_gpu(
     const std::vector<array>& inputs,
     array& out,
-    const std::string op) {
+    const std::string& op) {
   auto& s = out.primitive().stream();
   binary_op_gpu(inputs, out, op, s);
 }
diff --git a/mlx/backend/metal/binary.h b/mlx/backend/metal/binary.h
index 40d414291..8552c1e07 100644
--- a/mlx/backend/metal/binary.h
+++ b/mlx/backend/metal/binary.h
@@ -9,25 +9,25 @@ namespace mlx::core {
 void binary_op_gpu(
     const std::vector<array>& inputs,
     std::vector<array>& outputs,
-    const std::string op,
+    const std::string& op,
     const Stream& s);
 
 void binary_op_gpu(
     const std::vector<array>& inputs,
     array& out,
-    const std::string op,
+    const std::string& op,
     const Stream& s);
 
 void binary_op_gpu_inplace(
     const std::vector<array>& inputs,
     std::vector<array>& outputs,
-    const std::string op,
+    const std::string& op,
     const Stream& s);
 
 void binary_op_gpu_inplace(
     const std::vector<array>& inputs,
     array& out,
-    const std::string op,
+    const std::string& op,
     const Stream& s);
 
 } // namespace mlx::core
diff --git a/mlx/backend/metal/copy.cpp b/mlx/backend/metal/copy.cpp
index f2f98da06..b57d254e8 100644
--- a/mlx/backend/metal/copy.cpp
+++ b/mlx/backend/metal/copy.cpp
@@ -64,16 +64,17 @@ void copy_gpu_inplace(
   auto& strides_in_ = strides[0];
   auto& strides_out_ = strides[1];
 
+  bool use_2d = out.data_size() > UINT32_MAX;
   auto& d = metal::device(s.device);
   std::string kernel_name;
   {
     std::ostringstream kname;
     switch (ctype) {
       case CopyType::Scalar:
-        kname << "s";
+        kname << (use_2d ? "s2" : "s");
         break;
       case CopyType::Vector:
-        kname << "v";
+        kname << (use_2d ? "v2" : "v");
         break;
       case CopyType::General:
         kname << "g";
@@ -139,7 +140,8 @@ void copy_gpu_inplace(
     compute_encoder.dispatchThreads(grid_dims, group_dims);
   } else {
     size_t nthreads = out.data_size();
-    MTL::Size grid_dims = MTL::Size(nthreads, 1, 1);
+    MTL::Size grid_dims = use_2d ? get_2d_grid_dims(out.shape(), out.strides())
+                                 : MTL::Size(nthreads, 1, 1);
     NS::UInteger thread_group_size = kernel->maxTotalThreadsPerThreadgroup();
     if (thread_group_size > nthreads) {
       thread_group_size = nthreads;
diff --git a/mlx/backend/metal/kernels/binary.h b/mlx/backend/metal/kernels/binary.h
index ca55bdebf..2e668621b 100644
--- a/mlx/backend/metal/kernels/binary.h
+++ b/mlx/backend/metal/kernels/binary.h
@@ -36,6 +36,39 @@ template <typename T, typename U, typename Op>
   c[index] = Op()(a[index], b[index]);
 }
 
+template <typename T, typename U, typename Op>
+[[kernel]] void binary_sv2(
+    device const T* a,
+    device const T* b,
+    device U* c,
+    uint2 index [[thread_position_in_grid]],
+    uint2 grid_dim [[threads_per_grid]]) {
+  size_t offset = index.x + grid_dim.x * size_t(index.y);
+  c[offset] = Op()(a[0], b[offset]);
+}
+
+template <typename T, typename U, typename Op>
+[[kernel]] void binary_vs2(
+    device const T* a,
+    device const T* b,
+    device U* c,
+    uint2 index [[thread_position_in_grid]],
+    uint2 grid_dim [[threads_per_grid]]) {
+  size_t offset = index.x + grid_dim.x * size_t(index.y);
+  c[offset] = Op()(a[offset], b[0]);
+}
+
+template <typename T, typename U, typename Op>
+[[kernel]] void binary_vv2(
+    device const T* a,
+    device const T* b,
+    device U* c,
+    uint2 index [[thread_position_in_grid]],
+    uint2 grid_dim [[threads_per_grid]]) {
+  size_t offset = index.x + grid_dim.x * size_t(index.y);
+  c[offset] = Op()(a[offset], b[offset]);
+}
+
 template <typename T, typename U, typename Op>
 [[kernel]] void binary_g_nd1(
     device const T* a,
diff --git a/mlx/backend/metal/kernels/binary.metal b/mlx/backend/metal/kernels/binary.metal
index 11e4f7a6d..2c302c20b 100644
--- a/mlx/backend/metal/kernels/binary.metal
+++ b/mlx/backend/metal/kernels/binary.metal
@@ -14,6 +14,9 @@
   instantiate_kernel("sv" #op #tname, binary_sv, itype, otype, op)    \
   instantiate_kernel("vs" #op #tname, binary_vs, itype, otype, op)    \
   instantiate_kernel("vv" #op #tname, binary_vv, itype, otype, op)    \
+  instantiate_kernel("sv2" #op #tname, binary_sv2, itype, otype, op)  \
+  instantiate_kernel("vs2" #op #tname, binary_vs2, itype, otype, op)  \
+  instantiate_kernel("vv2" #op #tname, binary_vv2, itype, otype, op)  \
   instantiate_kernel("gn" #op #tname, binary_g, itype, otype, op)     \
   instantiate_kernel("g1" #op #tname, binary_g_nd1, itype, otype, op) \
   instantiate_kernel("g2" #op #tname, binary_g_nd2, itype, otype, op) \
diff --git a/mlx/backend/metal/kernels/binary_two.h b/mlx/backend/metal/kernels/binary_two.h
index 3890adbce..08ff876ca 100644
--- a/mlx/backend/metal/kernels/binary_two.h
+++ b/mlx/backend/metal/kernels/binary_two.h
@@ -48,6 +48,48 @@ template <typename T, typename U, typename Op>
   d[index] = out[1];
 }
 
+template <typename T, typename U, typename Op>
+[[kernel]] void binary_sv2(
+    device const T* a,
+    device const T* b,
+    device U* c,
+    device U* d,
+    uint2 index [[thread_position_in_grid]],
+    uint2 grid_dim [[threads_per_grid]]) {
+  size_t offset = index.x + grid_dim.x * size_t(index.y);
+  auto out = Op()(a[0], b[offset]);
+  c[offset] = out[0];
+  d[offset] = out[1];
+}
+
+template <typename T, typename U, typename Op>
+[[kernel]] void binary_vs2(
+    device const T* a,
+    device const T* b,
+    device U* c,
+    device U* d,
+    uint2 index [[thread_position_in_grid]],
+    uint2 grid_dim [[threads_per_grid]]) {
+  size_t offset = index.x + grid_dim.x * size_t(index.y);
+  auto out = Op()(a[offset], b[0]);
+  c[offset] = out[0];
+  d[offset] = out[1];
+}
+
+template <typename T, typename U, typename Op>
+[[kernel]] void binary_vv2(
+    device const T* a,
+    device const T* b,
+    device U* c,
+    device U* d,
+    uint2 index [[thread_position_in_grid]],
+    uint2 grid_dim [[threads_per_grid]]) {
+  size_t offset = index.x + grid_dim.x * size_t(index.y);
+  auto out = Op()(a[offset], b[offset]);
+  c[offset] = out[0];
+  d[offset] = out[1];
+}
+
 template <typename T, typename U, typename Op>
 [[kernel]] void binary_g_nd1(
     device const T* a,
diff --git a/mlx/backend/metal/kernels/binary_two.metal b/mlx/backend/metal/kernels/binary_two.metal
index 275f6a0d8..fb1bd785b 100644
--- a/mlx/backend/metal/kernels/binary_two.metal
+++ b/mlx/backend/metal/kernels/binary_two.metal
@@ -12,6 +12,9 @@
   instantiate_kernel("sv" #op #tname, binary_sv, itype, otype, op)    \
   instantiate_kernel("vs" #op #tname, binary_vs, itype, otype, op)    \
   instantiate_kernel("vv" #op #tname, binary_vv, itype, otype, op)    \
+  instantiate_kernel("sv2" #op #tname, binary_sv2, itype, otype, op)  \
+  instantiate_kernel("vs2" #op #tname, binary_vs2, itype, otype, op)  \
+  instantiate_kernel("vv2" #op #tname, binary_vv2, itype, otype, op)  \
   instantiate_kernel("gn" #op #tname, binary_g, itype, otype, op)     \
   instantiate_kernel("g1" #op #tname, binary_g_nd1, itype, otype, op) \
   instantiate_kernel("g2" #op #tname, binary_g_nd2, itype, otype, op) \
diff --git a/mlx/backend/metal/kernels/copy.h b/mlx/backend/metal/kernels/copy.h
index 451b7bb4c..6ba5ed741 100644
--- a/mlx/backend/metal/kernels/copy.h
+++ b/mlx/backend/metal/kernels/copy.h
@@ -16,6 +16,26 @@ template <typename T, typename U>
   dst[index] = static_cast<U>(src[index]);
 }
 
+template <typename T, typename U>
+[[kernel]] void copy_s2(
+    device const T* src [[buffer(0)]],
+    device U* dst [[buffer(1)]],
+    uint2 index [[thread_position_in_grid]],
+    uint2 grid_dim [[threads_per_grid]]) {
+  size_t offset = index.x + grid_dim.x * size_t(index.y);
+  dst[offset] = static_cast<U>(src[0]);
+}
+
+template <typename T, typename U>
+[[kernel]] void copy_v2(
+    device const T* src [[buffer(0)]],
+    device U* dst [[buffer(1)]],
+    uint2 index [[thread_position_in_grid]],
+    uint2 grid_dim [[threads_per_grid]]) {
+  size_t offset = index.x + grid_dim.x * size_t(index.y);
+  dst[offset] = static_cast<U>(src[offset]);
+}
+
 template <typename T, typename U>
 [[kernel]] void copy_g_nd1(
     device const T* src [[buffer(0)]],
diff --git a/mlx/backend/metal/kernels/copy.metal b/mlx/backend/metal/kernels/copy.metal
index df21c75e0..a121197e5 100644
--- a/mlx/backend/metal/kernels/copy.metal
+++ b/mlx/backend/metal/kernels/copy.metal
@@ -5,95 +5,23 @@
 #include "mlx/backend/metal/kernels/bf16.h"
 #include "mlx/backend/metal/kernels/copy.h"
 
-#define instantiate_copy(name, itype, otype, ctype) \
-  template [[host_name(name)]] [[kernel]] void copy_##ctype<itype, otype>( \
-      device const itype* src [[buffer(0)]], \
-      device otype* dst [[buffer(1)]], \
-      uint index [[thread_position_in_grid]]);
-
-#define instantiate_copy_g_dim(name, itype, otype, dims) \
-  template [[host_name("g" #dims "_" name)]] [[kernel]] void \
-  copy_g_nd<itype, otype, dims>( \
-      device const itype* src [[buffer(0)]], \
-      device otype* dst [[buffer(1)]], \
-      constant const int* src_shape [[buffer(2)]], \
-      constant const int64_t* src_strides [[buffer(3)]], \
-      uint3 index [[thread_position_in_grid]], \
-      uint3 grid_dim [[threads_per_grid]]); \
-  template [[host_name("gg" #dims "_" name)]] [[kernel]] void \
-  copy_gg_nd<itype, otype, dims>( \
-      device const itype* src [[buffer(0)]], \
-      device otype* dst [[buffer(1)]], \
-      constant const int* src_shape [[buffer(2)]], \
-      constant const int64_t* src_strides [[buffer(3)]], \
-      constant const int64_t* dst_strides [[buffer(4)]], \
-      uint3 index [[thread_position_in_grid]]);
-
-#define instantiate_copy_g_nd(name, itype, otype) \
-  template [[host_name("g1_" name)]] [[kernel]] void copy_g_nd1<itype, otype>( \
-      device const itype* src [[buffer(0)]], \
-      device otype* dst [[buffer(1)]], \
-      constant const int64_t& src_stride [[buffer(3)]], \
-      uint index [[thread_position_in_grid]]); \
-  template [[host_name("g2_" name)]] [[kernel]] void copy_g_nd2<itype, otype>( \
-      device const itype* src [[buffer(0)]], \
-      device otype* dst [[buffer(1)]], \
-      constant const int64_t* src_strides [[buffer(3)]], \
-      uint2 index [[thread_position_in_grid]], \
-      uint2 grid_dim [[threads_per_grid]]); \
-  template [[host_name("g3_" name)]] [[kernel]] void copy_g_nd3<itype, otype>( \
-      device const itype* src [[buffer(0)]], \
-      device otype* dst [[buffer(1)]], \
-      constant const int64_t* src_strides [[buffer(3)]], \
-      uint3 index [[thread_position_in_grid]], \
-      uint3 grid_dim [[threads_per_grid]]); \
-  template [[host_name("gg1_" name)]] [[kernel]] void \
-  copy_gg_nd1<itype, otype>( \
-      device const itype* src [[buffer(0)]], \
-      device otype* dst [[buffer(1)]], \
-      constant const int64_t& src_stride [[buffer(3)]], \
-      constant const int64_t& dst_stride [[buffer(4)]], \
-      uint index [[thread_position_in_grid]]); \
-  template [[host_name("gg2_" name)]] [[kernel]] void \
-  copy_gg_nd2<itype, otype>( \
-      device const itype* src [[buffer(0)]], \
-      device otype* dst [[buffer(1)]], \
-      constant const int64_t* src_strides [[buffer(3)]], \
-      constant const int64_t* dst_strides [[buffer(4)]], \
-      uint2 index [[thread_position_in_grid]]); \
-  template [[host_name("gg3_" name)]] [[kernel]] void \
-  copy_gg_nd3<itype, otype>( \
-      device const itype* src [[buffer(0)]], \
-      device otype* dst [[buffer(1)]], \
-      constant const int64_t* src_strides [[buffer(3)]], \
-      constant const int64_t* dst_strides [[buffer(4)]], \
-      uint3 index [[thread_position_in_grid]]); \
-  instantiate_copy_g_dim(name, itype, otype, 4) \
-  instantiate_copy_g_dim(name, itype, otype, 5)
-
-#define instantiate_copy_g(name, itype, otype) \
-  template [[host_name("g_" name)]] [[kernel]] void copy_g<itype, otype>( \
-      device const itype* src [[buffer(0)]], \
-      device otype* dst [[buffer(1)]], \
-      constant const int* src_shape [[buffer(2)]], \
-      constant const int64_t* src_strides [[buffer(3)]], \
-      constant const int& ndim [[buffer(5)]], \
-      uint3 index [[thread_position_in_grid]], \
-      uint3 grid_dim [[threads_per_grid]]); \
-  template [[host_name("gg_" name)]] [[kernel]] void copy_gg<itype, otype>( \
-      device const itype* src [[buffer(0)]], \
-      device otype* dst [[buffer(1)]], \
-      constant const int* src_shape [[buffer(2)]], \
-      constant const int64_t* src_strides [[buffer(3)]], \
-      constant const int64_t* dst_strides [[buffer(4)]], \
-      constant const int& ndim [[buffer(5)]], \
-      uint3 index [[thread_position_in_grid]]);
-
 #define instantiate_copy_all(tname, itype, otype)                    \
-  instantiate_copy("s_copy" #tname, itype, otype, s)                 \
-  instantiate_copy("v_copy" #tname, itype, otype, v)                 \
-  instantiate_copy_g("copy" #tname, itype, otype)                    \
-  instantiate_copy_g_nd("copy" #tname, itype, otype)
+  instantiate_kernel("s_copy" #tname, copy_s, itype, otype)          \
+  instantiate_kernel("v_copy" #tname, copy_v, itype, otype)          \
+  instantiate_kernel("s2_copy" #tname, copy_s2, itype, otype)        \
+  instantiate_kernel("v2_copy" #tname, copy_v2, itype, otype)        \
+  instantiate_kernel("g1_copy" #tname, copy_g_nd1, itype, otype)     \
+  instantiate_kernel("g2_copy" #tname, copy_g_nd2, itype, otype)     \
+  instantiate_kernel("g3_copy" #tname, copy_g_nd3, itype, otype)     \
+  instantiate_kernel("gg1_copy" #tname, copy_gg_nd1, itype, otype)   \
+  instantiate_kernel("gg2_copy" #tname, copy_gg_nd2, itype, otype)   \
+  instantiate_kernel("gg3_copy" #tname, copy_gg_nd3, itype, otype)   \
+  instantiate_kernel("g4_copy" #tname, copy_g_nd, itype, otype, 4)   \
+  instantiate_kernel("g5_copy" #tname, copy_g_nd, itype, otype, 5)   \
+  instantiate_kernel("gg4_copy" #tname, copy_gg_nd, itype, otype, 4) \
+  instantiate_kernel("gg5_copy" #tname, copy_gg_nd, itype, otype, 5) \
instantiate_kernel("g_copy" #tname, copy_g, itype, otype) \ + instantiate_kernel("gg_copy" #tname, copy_gg, itype, otype) #define instantiate_copy_itype(itname, itype) \ instantiate_copy_all(itname ##bool_, itype, bool) \ diff --git a/mlx/backend/metal/kernels/layer_norm.metal b/mlx/backend/metal/kernels/layer_norm.metal index 4c0dc7346..53dc89eb0 100644 --- a/mlx/backend/metal/kernels/layer_norm.metal +++ b/mlx/backend/metal/kernels/layer_norm.metal @@ -34,7 +34,7 @@ template threadgroup float local_mean[1]; threadgroup float local_normalizer[1]; - x += gid * axis_size + lid * N_READS; + x += gid * size_t(axis_size) + lid * N_READS; w += w_stride * lid * N_READS; b += b_stride * lid * N_READS; @@ -89,7 +89,7 @@ template float normalizer = local_normalizer[0]; // Write the outputs - out += gid * axis_size + lid * N_READS; + out += gid * size_t(axis_size) + lid * N_READS; if (lid * N_READS + N_READS <= axis_size) { for (int i = 0; i < N_READS; i++) { thread_x[i] = (thread_x[i] - mean) * normalizer; @@ -131,7 +131,7 @@ template threadgroup float local_mean[1]; threadgroup float local_normalizer[1]; - x += gid * axis_size + lid * N_READS; + x += gid * size_t(axis_size) + lid * N_READS; w += w_stride * lid * N_READS; b += b_stride * lid * N_READS; @@ -188,7 +188,7 @@ template float normalizer = local_normalizer[0]; // Write the outputs - out += gid * axis_size + lid * N_READS; + out += gid * size_t(axis_size) + lid * N_READS; for (uint r = 0; r < axis_size; r += lsize * N_READS) { if (r + lid * N_READS + N_READS <= axis_size) { for (int i = 0; i < N_READS; i++) { @@ -223,8 +223,8 @@ template uint simd_lane_id [[thread_index_in_simdgroup]], uint simd_group_id [[simdgroup_index_in_threadgroup]]) { // Advance the input pointers - x += gid * axis_size + lid * N_READS; - g += gid * axis_size + lid * N_READS; + x += gid * size_t(axis_size) + lid * N_READS; + g += gid * size_t(axis_size) + lid * N_READS; w += w_stride * lid * N_READS; // Allocate registers for the computation and accumulators @@ -321,8 +321,8 @@ template float normalizer2 = normalizer * normalizer; // Write the outputs - gx += gid * axis_size + lid * N_READS; - gw += gid * axis_size + lid * N_READS; + gx += gid * size_t(axis_size) + lid * N_READS; + gw += gid * size_t(axis_size) + lid * N_READS; if (lid * N_READS + N_READS <= axis_size) { for (int i = 0; i < N_READS; i++) { thread_x[i] = (thread_x[i] - mean) * normalizer; @@ -360,8 +360,8 @@ template uint simd_lane_id [[thread_index_in_simdgroup]], uint simd_group_id [[simdgroup_index_in_threadgroup]]) { // Advance the input pointers - x += gid * axis_size + lid * N_READS; - g += gid * axis_size + lid * N_READS; + x += gid * size_t(axis_size) + lid * N_READS; + g += gid * size_t(axis_size) + lid * N_READS; w += w_stride * lid * N_READS; // Allocate registers for the accumulators @@ -457,8 +457,8 @@ template float normalizer2 = normalizer * normalizer; // Write the outputs - gx += gid * axis_size + lid * N_READS; - gw += gid * axis_size + lid * N_READS; + gx += gid * size_t(axis_size) + lid * N_READS; + gw += gid * size_t(axis_size) + lid * N_READS; for (uint r = 0; r < axis_size; r += lsize * N_READS) { if (r + lid * N_READS + N_READS <= axis_size) { for (int i = 0; i < N_READS; i++) { diff --git a/mlx/backend/metal/kernels/rms_norm.metal b/mlx/backend/metal/kernels/rms_norm.metal index b38aaa4da..c79bbae10 100644 --- a/mlx/backend/metal/kernels/rms_norm.metal +++ b/mlx/backend/metal/kernels/rms_norm.metal @@ -24,7 +24,7 @@ template uint simd_lane_id [[thread_index_in_simdgroup]], 
     uint simd_group_id [[simdgroup_index_in_threadgroup]]) {
   float acc = 0;
-  x += gid * axis_size + lid * N_READS;
+  x += gid * size_t(axis_size) + lid * N_READS;
   w += w_stride * lid * N_READS;
   if (lid * N_READS + N_READS <= axis_size) {
     for (int i = 0; i < N_READS; i++) {
@@ -62,7 +62,7 @@
   threadgroup_barrier(mem_flags::mem_threadgroup);
 
   // Write the outputs
-  out += gid * axis_size + lid * N_READS;
+  out += gid * size_t(axis_size) + lid * N_READS;
   if (lid * N_READS + N_READS <= axis_size) {
     for (int i = 0; i < N_READS; i++) {
       out[i] = w[w_stride * i] * static_cast<T>(x[i] * local_inv_mean[0]);
@@ -92,7 +92,7 @@
     uint simd_lane_id [[thread_index_in_simdgroup]],
     uint simd_group_id [[simdgroup_index_in_threadgroup]]) {
   float acc = 0;
-  x += gid * axis_size + lid * N_READS;
+  x += gid * size_t(axis_size) + lid * N_READS;
   w += w_stride * lid * N_READS;
   for (uint r = 0; r < axis_size; r += lsize * N_READS) {
     if (r + lid * N_READS + N_READS <= axis_size) {
@@ -132,7 +132,7 @@
   threadgroup_barrier(mem_flags::mem_threadgroup);
 
   // Write the outputs
-  out += gid * axis_size + lid * N_READS;
+  out += gid * size_t(axis_size) + lid * N_READS;
   for (uint r = 0; r < axis_size; r += lsize * N_READS) {
     if (r + lid * N_READS + N_READS <= axis_size) {
       for (int i = 0; i < N_READS; i++) {
@@ -165,8 +165,8 @@
     uint simd_lane_id [[thread_index_in_simdgroup]],
     uint simd_group_id [[simdgroup_index_in_threadgroup]]) {
   // Advance the input pointers
-  x += gid * axis_size + lid * N_READS;
-  g += gid * axis_size + lid * N_READS;
+  x += gid * size_t(axis_size) + lid * N_READS;
+  g += gid * size_t(axis_size) + lid * N_READS;
   w += w_stride * lid * N_READS;
 
   // Allocate registers for the computation and accumulators
@@ -233,8 +233,8 @@
   float normalizer3 = normalizer * normalizer * normalizer;
 
   // Write the outputs
-  gx += gid * axis_size + lid * N_READS;
-  gw += gid * axis_size + lid * N_READS;
+  gx += gid * size_t(axis_size) + lid * N_READS;
+  gw += gid * size_t(axis_size) + lid * N_READS;
   if (lid * N_READS + N_READS <= axis_size) {
     for (int i = 0; i < N_READS; i++) {
       gx[i] = static_cast<T>(
@@ -270,8 +270,8 @@
     uint simd_lane_id [[thread_index_in_simdgroup]],
     uint simd_group_id [[simdgroup_index_in_threadgroup]]) {
   // Advance the input pointers
-  x += gid * axis_size + lid * N_READS;
-  g += gid * axis_size + lid * N_READS;
+  x += gid * size_t(axis_size) + lid * N_READS;
+  g += gid * size_t(axis_size) + lid * N_READS;
   w += w_stride * lid * N_READS;
 
   // Allocate registers for the accumulators
@@ -337,8 +337,8 @@
   float normalizer3 = normalizer * normalizer * normalizer;
 
   // Write the outputs
-  gx += gid * axis_size + lid * N_READS;
-  gw += gid * axis_size + lid * N_READS;
+  gx += gid * size_t(axis_size) + lid * N_READS;
+  gw += gid * size_t(axis_size) + lid * N_READS;
   for (uint r = 0; r < axis_size; r += lsize * N_READS) {
     if (r + lid * N_READS + N_READS <= axis_size) {
       for (int i = 0; i < N_READS; i++) {
diff --git a/mlx/backend/metal/kernels/softmax.h b/mlx/backend/metal/kernels/softmax.h
index 031ec128e..a455dabd2 100644
--- a/mlx/backend/metal/kernels/softmax.h
+++ b/mlx/backend/metal/kernels/softmax.h
@@ -25,7 +25,7 @@
 
   AccT ld[N_READS];
 
-  in += gid * axis_size + lid * N_READS;
+  in += gid * size_t(axis_size) + lid * N_READS;
   if (lid * N_READS + N_READS <= axis_size) {
     for (int i = 0; i < N_READS; i++) {
       ld[i] = AccT(in[i]);
@@ -83,7 +83,7 @@
   normalizer = 1 / local_normalizer[0];
 
   // Normalize and write to the output
-  out += gid * axis_size + lid * N_READS;
+  out += gid * size_t(axis_size) + lid * N_READS;
   if (lid * N_READS + N_READS <= axis_size) {
     for (int i = 0; i < N_READS; i++) {
       out[i] = T(ld[i] * normalizer);
@@ -107,7 +107,7 @@
     uint lsize [[threads_per_threadgroup]],
     uint simd_lane_id [[thread_index_in_simdgroup]],
     uint simd_group_id [[simdgroup_index_in_threadgroup]]) {
-  in += gid * axis_size;
+  in += gid * size_t(axis_size);
 
   constexpr int SIMD_SIZE = 32;
 
@@ -170,7 +170,7 @@
 
   // Finally given the normalizer and max value we can directly write the
   // softmax output
-  out += gid * axis_size;
+  out += gid * size_t(axis_size);
   for (int r = 0; r < static_cast<int>(ceildiv(axis_size, N_READS * lsize));
        r++) {
     int offset = r * lsize * N_READS + lid * N_READS;
diff --git a/mlx/backend/metal/kernels/ternary.h b/mlx/backend/metal/kernels/ternary.h
index 312a73207..370480d35 100644
--- a/mlx/backend/metal/kernels/ternary.h
+++ b/mlx/backend/metal/kernels/ternary.h
@@ -10,6 +10,18 @@ template <typename T, typename Op>
   d[index] = Op()(a[index], b[index], c[index]);
 }
 
+template <typename T, typename Op>
+[[kernel]] void ternary_v2(
+    device const bool* a,
+    device const T* b,
+    device const T* c,
+    device T* d,
+    uint2 index [[thread_position_in_grid]],
+    uint2 grid_dim [[threads_per_grid]]) {
+  size_t offset = index.x + grid_dim.x * size_t(index.y);
+  d[offset] = Op()(a[offset], b[offset], c[offset]);
+}
+
 template <typename T, typename Op>
 [[kernel]] void ternary_g_nd1(
     device const bool* a,
diff --git a/mlx/backend/metal/kernels/ternary.metal b/mlx/backend/metal/kernels/ternary.metal
index 97bfdf81c..4101229b9 100644
--- a/mlx/backend/metal/kernels/ternary.metal
+++ b/mlx/backend/metal/kernels/ternary.metal
@@ -11,6 +11,7 @@
 
 #define instantiate_ternary_all(op, tname, type)                \
   instantiate_kernel("v_" #op #tname, ternary_v, type, op)      \
+  instantiate_kernel("v2_" #op #tname, ternary_v2, type, op)    \
   instantiate_kernel("g_" #op #tname, ternary_g, type, op)      \
   instantiate_kernel("g1_" #op #tname, ternary_g_nd1, type, op) \
   instantiate_kernel("g2_" #op #tname, ternary_g_nd2, type, op) \
diff --git a/mlx/backend/metal/kernels/unary.h b/mlx/backend/metal/kernels/unary.h
index 80b17121c..e904e1629 100644
--- a/mlx/backend/metal/kernels/unary.h
+++ b/mlx/backend/metal/kernels/unary.h
@@ -8,6 +8,16 @@ template <typename T, typename Op>
   out[index] = Op()(in[index]);
 }
 
+template <typename T, typename Op>
+[[kernel]] void unary_v2(
+    device const T* in,
+    device T* out,
+    uint2 index [[thread_position_in_grid]],
+    uint2 grid_dim [[threads_per_grid]]) {
+  size_t offset = index.x + grid_dim.x * size_t(index.y);
+  out[offset] = Op()(in[offset]);
+}
+
 template <typename T, typename Op>
 [[kernel]] void unary_g(
     device const T* in,
diff --git a/mlx/backend/metal/kernels/unary.metal b/mlx/backend/metal/kernels/unary.metal
index 5ad57a15d..ea5c05177 100644
--- a/mlx/backend/metal/kernels/unary.metal
+++ b/mlx/backend/metal/kernels/unary.metal
@@ -5,8 +5,9 @@
 #include "mlx/backend/metal/kernels/unary_ops.h"
 #include "mlx/backend/metal/kernels/unary.h"
 
-#define instantiate_unary_all(op, tname, type)            \
-  instantiate_kernel("v" #op #tname, unary_v, type, op)   \
+#define instantiate_unary_all(op, tname, type)            \
+  instantiate_kernel("v" #op #tname, unary_v, type, op)   \
+  instantiate_kernel("v2" #op #tname, unary_v2, type, op) \
   instantiate_kernel("g" #op #tname, unary_g, type, op)
 
 #define instantiate_unary_float(op) \
diff --git a/mlx/backend/metal/softmax.cpp b/mlx/backend/metal/softmax.cpp
index 22ab3cf4b..cc25ac7f3 100644
--- a/mlx/backend/metal/softmax.cpp
+++ b/mlx/backend/metal/softmax.cpp
@@ -1,5 +1,4 @@
 // Copyright © 2023-2024 Apple Inc.
-
 #include <algorithm>
 
 #include "mlx/backend/metal/copy.h"
diff --git a/mlx/backend/metal/ternary.cpp b/mlx/backend/metal/ternary.cpp
index 1f12ecf39..66364e7de 100644
--- a/mlx/backend/metal/ternary.cpp
+++ b/mlx/backend/metal/ternary.cpp
@@ -32,6 +32,7 @@ void ternary_op_gpu_inplace(
   auto& strides_c = strides[2];
   auto& strides_out = strides[3];
 
+  bool use_2d = out.data_size() > UINT32_MAX;
   std::string kernel_name;
   {
     std::ostringstream kname;
@@ -40,6 +41,8 @@ void ternary_op_gpu_inplace(
       if (shape.size() <= MAX_TERNARY_SPECIALIZED_DIMS) {
         kname << shape.size();
       }
+    } else if (use_2d) {
+      kname << "v2";
     } else {
       kname << "v";
     }
diff --git a/mlx/backend/metal/unary.cpp b/mlx/backend/metal/unary.cpp
index 2ac01c490..17ff1f7b3 100644
--- a/mlx/backend/metal/unary.cpp
+++ b/mlx/backend/metal/unary.cpp
@@ -25,11 +25,14 @@ void unary_op_gpu_inplace(
 
   auto& d = metal::device(s.device);
 
-  std::string kernel_name = (contig ? "v" : "g") + op + type_to_name(out);
+  size_t nthreads = contig ? in.data_size() : in.size();
+  bool use_2d = nthreads > UINT32_MAX;
+  std::string kernel_name =
+      (contig ? (use_2d ? "v2" : "v") : "g") + op + type_to_name(out);
   auto kernel = get_unary_kernel(d, kernel_name, out.dtype(), op);
 
-  size_t nthreads = contig ? in.data_size() : in.size();
-  MTL::Size grid_dims = MTL::Size(nthreads, 1, 1);
+  MTL::Size grid_dims = use_2d ? get_2d_grid_dims(in.shape(), in.strides())
+                               : MTL::Size(nthreads, 1, 1);
   NS::UInteger thread_group_size = kernel->maxTotalThreadsPerThreadgroup();
   if (thread_group_size > nthreads) {
     thread_group_size = nthreads;
diff --git a/mlx/backend/metal/utils.h b/mlx/backend/metal/utils.h
index fc43079b4..0a24e431a 100644
--- a/mlx/backend/metal/utils.h
+++ b/mlx/backend/metal/utils.h
@@ -104,6 +104,35 @@ MTL::Size get_block_dims(int dim0, int dim1, int dim2) {
   return MTL::Size{1ul << pows[0], 1ul << pows[1], 1ul << pows[2]};
 }
 
+// Computes a 2D grid where each element is < UINT32_MAX
+// Assumes:
+// - overall size (product of non-broadcasted dimensions) is < UINT32_MAX^2
+// - shape and strides correspond to a contiguous (no holes) but
+//   possibly broadcasted array
+MTL::Size get_2d_grid_dims(
+    const std::vector<int>& shape,
+    const std::vector<size_t>& strides) {
+  // Dims with strides of 0 are ignored as they
+  // correspond to broadcasted dimensions
+  size_t grid_x = 1;
+  size_t grid_y = 1;
+  for (int i = 0; i < shape.size(); ++i) {
+    if (strides[i] == 0) {
+      continue;
+    }
+    if (grid_x * shape[i] < UINT32_MAX) {
+      grid_x *= shape[i];
+    } else {
+      grid_y *= shape[i];
+    }
+  }
+  if (grid_y > UINT32_MAX || grid_x > UINT32_MAX) {
+    throw std::runtime_error("Unable to safely factor shape.");
+  }
+  return MTL::Size(
+      static_cast<uint32_t>(grid_x), static_cast<uint32_t>(grid_y), 1);
+}
+
 inline NS::String* make_string(std::ostringstream& os) {
   std::string string = os.str();
   return NS::String::string(string.c_str(), NS::UTF8StringEncoding);
diff --git a/python/src/indexing.cpp b/python/src/indexing.cpp
index 7e8130502..91c4ed9e3 100644
--- a/python/src/indexing.cpp
+++ b/python/src/indexing.cpp
@@ -273,7 +273,7 @@ array mlx_get_item_nd(array src, const nb::tuple& entries) {
   // Check for the number of indices passed
   if (non_none_indices > src.ndim()) {
     std::ostringstream msg;
-    msg << "Too many indices for array with " << src.ndim() << "dimensions.";
+    msg << "Too many indices for array with " << src.ndim() << " dimensions.";
    throw std::invalid_argument(msg.str());
   }
 
@@ -585,7 +585,7 @@ std::tuple<std::vector<array>, array, std::vector<int>> mlx_scatter_args_nd(
 
   if (non_none_indices > src.ndim()) {
     std::ostringstream msg;
-    msg << "Too many indices for array with " << src.ndim() << "dimensions.";
+    msg << "Too many indices for array with " << src.ndim() << " dimensions.";
     throw std::invalid_argument(msg.str());
   }
 
@@ -840,7 +840,7 @@ auto mlx_slice_update(
   // Dimension check
   if (non_none_indices > src.ndim()) {
     std::ostringstream msg;
-    msg << "Too many indices for array with " << src.ndim() << "dimensions.";
+    msg << "Too many indices for array with " << src.ndim() << " dimensions.";
     throw std::invalid_argument(msg.str());
   }
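
Background on the 2D dispatch path in this patch: Metal grid dimensions are 32-bit, so a flat 1D dispatch cannot address arrays with more than UINT32_MAX elements. The patch therefore adds "*2" kernel variants that take a uint2 thread position and rebuild the flat element offset in 64 bits, plus a host-side helper (get_2d_grid_dims) that factors the element count into a 2D grid whose sides each fit in 32 bits. Below is a minimal standalone sketch of that factorization and the matching offset arithmetic; it is illustrative only (factor_grid and the example shape are not MLX code), assuming int shapes and size_t strides as in utils.h:

    // Host-side sketch of the greedy 2D factorization and the flat-offset
    // arithmetic used by the new *2 kernels. C++17 for structured bindings.
    #include <cstdint>
    #include <cstdio>
    #include <stdexcept>
    #include <utility>
    #include <vector>

    std::pair<uint32_t, uint32_t> factor_grid(
        const std::vector<int>& shape,
        const std::vector<size_t>& strides) {
      size_t grid_x = 1;
      size_t grid_y = 1;
      for (size_t i = 0; i < shape.size(); ++i) {
        if (strides[i] == 0) {
          continue; // broadcasted dim: contributes no distinct elements
        }
        // Greedily pack dims into grid_x until the next one would push it
        // past 32 bits, then spill the remaining dims into grid_y.
        if (grid_x * shape[i] < UINT32_MAX) {
          grid_x *= shape[i];
        } else {
          grid_y *= shape[i];
        }
      }
      if (grid_x > UINT32_MAX || grid_y > UINT32_MAX) {
        throw std::runtime_error("Unable to safely factor shape.");
      }
      return {static_cast<uint32_t>(grid_x), static_cast<uint32_t>(grid_y)};
    }

    int main() {
      // A contiguous 2^20 x 2^13 array has 2^33 elements -- too many for a
      // 1D grid. The first dim fits in grid_x, but multiplying in the second
      // would exceed 32 bits, so it spills into grid_y.
      auto [gx, gy] = factor_grid({1 << 20, 1 << 13}, {1 << 13, 1});
      std::printf("grid = %u x %u\n", gx, gy); // grid = 1048576 x 8192
      // A kernel thread at (index.x, index.y) then addresses element
      //   offset = index.x + grid_dim.x * size_t(index.y);
      // which is exactly the indexing in binary_vv2, copy_v2, ternary_v2,
      // and unary_v2 above.
      return 0;
    }

Because use_2d is keyed off data_size() rather than size(), broadcasted inputs that merely expand to a huge logical size still take the cheap 1D path; the 2D path engages only when the output buffer itself has more than UINT32_MAX distinct elements.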