mirror of https://github.com/ml-explore/mlx.git
	Fixes for large arrays with a few ops (#1299)
* fixes for large arrays with a few ops
* fix bug
* fix all of copy
		| @@ -21,10 +21,43 @@ namespace mlx::core { | ||||
|  | ||||
| constexpr int MAX_BINARY_SPECIALIZED_DIMS = 5; | ||||
|  | ||||
| std::string get_kernel_name( | ||||
|     BinaryOpType bopt, | ||||
|     const std::string& op, | ||||
|     const array& a, | ||||
|     bool use_2d, | ||||
|     int ndim) { | ||||
|   std::ostringstream kname; | ||||
|   switch (bopt) { | ||||
|     case BinaryOpType::ScalarScalar: | ||||
|       kname << "ss"; | ||||
|       break; | ||||
|     case BinaryOpType::ScalarVector: | ||||
|       kname << (use_2d ? "sv2" : "sv"); | ||||
|       break; | ||||
|     case BinaryOpType::VectorScalar: | ||||
|       kname << (use_2d ? "vs2" : "vs"); | ||||
|       break; | ||||
|     case BinaryOpType::VectorVector: | ||||
|       kname << (use_2d ? "vv2" : "vv"); | ||||
|       break; | ||||
|     case BinaryOpType::General: | ||||
|       kname << "g"; | ||||
|       if (ndim <= MAX_BINARY_SPECIALIZED_DIMS) { | ||||
|         kname << ndim; | ||||
|       } else { | ||||
|         kname << "n"; | ||||
|       } | ||||
|       break; | ||||
|   } | ||||
|   kname << op << type_to_name(a); | ||||
|   return kname.str(); | ||||
| } | ||||
|  | ||||
| void binary_op_gpu_inplace( | ||||
|     const std::vector<array>& inputs, | ||||
|     std::vector<array>& outputs, | ||||
|     const std::string op, | ||||
|     const std::string& op, | ||||
|     const Stream& s) { | ||||
|   auto& a = inputs[0]; | ||||
|   auto& b = inputs[1]; | ||||
| @@ -41,35 +74,8 @@ void binary_op_gpu_inplace( | ||||
|   auto& strides_b = strides[1]; | ||||
|   auto& strides_out = strides[2]; | ||||
|  | ||||
|   std::string kernel_name; | ||||
|   { | ||||
|     std::ostringstream kname; | ||||
|     switch (bopt) { | ||||
|       case BinaryOpType::ScalarScalar: | ||||
|         kname << "ss"; | ||||
|         break; | ||||
|       case BinaryOpType::ScalarVector: | ||||
|         kname << "sv"; | ||||
|         break; | ||||
|       case BinaryOpType::VectorScalar: | ||||
|         kname << "vs"; | ||||
|         break; | ||||
|       case BinaryOpType::VectorVector: | ||||
|         kname << "vv"; | ||||
|         break; | ||||
|       case BinaryOpType::General: | ||||
|         kname << "g"; | ||||
|         if (shape.size() <= MAX_BINARY_SPECIALIZED_DIMS) { | ||||
|           kname << shape.size(); | ||||
|         } else { | ||||
|           kname << "n"; | ||||
|         } | ||||
|         break; | ||||
|     } | ||||
|     kname << op << type_to_name(a); | ||||
|     kernel_name = kname.str(); | ||||
|   } | ||||
|  | ||||
|   bool use_2d = out.data_size() > UINT32_MAX; | ||||
|   std::string kernel_name = get_kernel_name(bopt, op, a, use_2d, shape.size()); | ||||
|   auto& d = metal::device(s.device); | ||||
|  | ||||
|   auto kernel = | ||||
| @@ -117,9 +123,11 @@ void binary_op_gpu_inplace( | ||||
|     MTL::Size grid_dims = MTL::Size(dim0, dim1, rest); | ||||
|     compute_encoder.dispatchThreads(grid_dims, group_dims); | ||||
|   } else { | ||||
|     // Launch a 1D grid of threads | ||||
|     // Launch a 1D or 2D grid of threads | ||||
|     size_t nthreads = out.data_size(); | ||||
|     MTL::Size grid_dims = MTL::Size(nthreads, 1, 1); | ||||
|     MTL::Size grid_dims = use_2d | ||||
|         ? get_2d_grid_dims(outputs[0].shape(), outputs[0].strides()) | ||||
|         : MTL::Size(nthreads, 1, 1); | ||||
|     NS::UInteger thread_group_size = kernel->maxTotalThreadsPerThreadgroup(); | ||||
|     if (thread_group_size > nthreads) { | ||||
|       thread_group_size = nthreads; | ||||
| @@ -132,7 +140,7 @@ void binary_op_gpu_inplace( | ||||
| void binary_op_gpu( | ||||
|     const std::vector<array>& inputs, | ||||
|     std::vector<array>& outputs, | ||||
|     const std::string op, | ||||
|     const std::string& op, | ||||
|     const Stream& s) { | ||||
|   assert(inputs.size() == 2); | ||||
|   auto& a = inputs[0]; | ||||
| @@ -146,7 +154,7 @@ void binary_op_gpu( | ||||
| void binary_op_gpu( | ||||
|     const std::vector<array>& inputs, | ||||
|     std::vector<array>& outputs, | ||||
|     const std::string op) { | ||||
|     const std::string& op) { | ||||
|   auto& s = outputs[0].primitive().stream(); | ||||
|   binary_op_gpu(inputs, outputs, op, s); | ||||
| } | ||||
| @@ -154,7 +162,7 @@ void binary_op_gpu( | ||||
| void binary_op_gpu_inplace( | ||||
|     const std::vector<array>& inputs, | ||||
|     array& out, | ||||
|     const std::string op, | ||||
|     const std::string& op, | ||||
|     const Stream& s) { | ||||
|   auto& a = inputs[0]; | ||||
|   auto& b = inputs[1]; | ||||
| @@ -169,35 +177,8 @@ void binary_op_gpu_inplace( | ||||
|   auto& strides_b = strides[1]; | ||||
|   auto& strides_out = strides[2]; | ||||
|  | ||||
|   std::string kernel_name; | ||||
|   { | ||||
|     std::ostringstream kname; | ||||
|     switch (bopt) { | ||||
|       case BinaryOpType::ScalarScalar: | ||||
|         kname << "ss"; | ||||
|         break; | ||||
|       case BinaryOpType::ScalarVector: | ||||
|         kname << "sv"; | ||||
|         break; | ||||
|       case BinaryOpType::VectorScalar: | ||||
|         kname << "vs"; | ||||
|         break; | ||||
|       case BinaryOpType::VectorVector: | ||||
|         kname << "vv"; | ||||
|         break; | ||||
|       case BinaryOpType::General: | ||||
|         kname << "g"; | ||||
|         if (shape.size() <= MAX_BINARY_SPECIALIZED_DIMS) { | ||||
|           kname << shape.size(); | ||||
|         } else { | ||||
|           kname << "n"; | ||||
|         } | ||||
|         break; | ||||
|     } | ||||
|     kname << op << type_to_name(a); | ||||
|     kernel_name = kname.str(); | ||||
|   } | ||||
|  | ||||
|   bool use_2d = out.data_size() > UINT32_MAX; | ||||
|   std::string kernel_name = get_kernel_name(bopt, op, a, use_2d, shape.size()); | ||||
|   auto& d = metal::device(s.device); | ||||
|  | ||||
|   auto kernel = get_binary_kernel(d, kernel_name, a.dtype(), out.dtype(), op); | ||||
| @@ -237,10 +218,11 @@ void binary_op_gpu_inplace( | ||||
|     MTL::Size grid_dims = MTL::Size(dim0, dim1, rest); | ||||
|     compute_encoder.dispatchThreads(grid_dims, group_dims); | ||||
|   } else { | ||||
|     // Launch a 1D grid of threads | ||||
|     size_t nthreads = | ||||
|         bopt == BinaryOpType::General ? out.size() : out.data_size(); | ||||
|     MTL::Size grid_dims = MTL::Size(nthreads, 1, 1); | ||||
|     // Launch a 1D or 2D grid of threads | ||||
|  | ||||
|     size_t nthreads = out.data_size(); | ||||
|     MTL::Size grid_dims = use_2d ? get_2d_grid_dims(out.shape(), out.strides()) | ||||
|                                  : MTL::Size(nthreads, 1, 1); | ||||
|     NS::UInteger thread_group_size = kernel->maxTotalThreadsPerThreadgroup(); | ||||
|     if (thread_group_size > nthreads) { | ||||
|       thread_group_size = nthreads; | ||||
| @@ -253,7 +235,7 @@ void binary_op_gpu_inplace( | ||||
| void binary_op_gpu( | ||||
|     const std::vector<array>& inputs, | ||||
|     array& out, | ||||
|     const std::string op, | ||||
|     const std::string& op, | ||||
|     const Stream& s) { | ||||
|   assert(inputs.size() == 2); | ||||
|   auto& a = inputs[0]; | ||||
| @@ -266,7 +248,7 @@ void binary_op_gpu( | ||||
| void binary_op_gpu( | ||||
|     const std::vector<array>& inputs, | ||||
|     array& out, | ||||
|     const std::string op) { | ||||
|     const std::string& op) { | ||||
|   auto& s = out.primitive().stream(); | ||||
|   binary_op_gpu(inputs, out, op, s); | ||||
| } | ||||
|   | ||||
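Why the new use_2d path exists: inside a kernel each grid axis is indexed with a 32-bit uint, so the old flat launch MTL::Size(nthreads, 1, 1) stops addressing elements correctly once out.data_size() exceeds UINT32_MAX. The dispatch above now routes such outputs onto a 2D grid and picks the matching *2 kernel. A minimal standalone sketch of that decision (not MLX code; shape and sizes are hypothetical):

    #include <cstdint>
    #include <cstdio>

    int main() {
      // A contiguous output with 2^33 elements, e.g. shape {1 << 17, 1 << 16}.
      uint64_t data_size = (1ull << 17) * (1ull << 16);
      bool use_2d = data_size > UINT32_MAX; // same threshold as the host code
      if (use_2d) {
        // Factor the element count over two grid axes so each stays < 2^32.
        uint32_t grid_x = 1u << 17;
        uint32_t grid_y = 1u << 16;
        std::printf("2D grid: %u x %u\n", grid_x, grid_y);
      } else {
        std::printf("1D grid: %llu threads\n", (unsigned long long)data_size);
      }
      return 0;
    }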
| @@ -9,25 +9,25 @@ namespace mlx::core { | ||||
| void binary_op_gpu( | ||||
|     const std::vector<array>& inputs, | ||||
|     std::vector<array>& outputs, | ||||
|     const std::string op, | ||||
|     const std::string& op, | ||||
|     const Stream& s); | ||||
|  | ||||
| void binary_op_gpu( | ||||
|     const std::vector<array>& inputs, | ||||
|     array& out, | ||||
|     const std::string op, | ||||
|     const std::string& op, | ||||
|     const Stream& s); | ||||
|  | ||||
| void binary_op_gpu_inplace( | ||||
|     const std::vector<array>& inputs, | ||||
|     std::vector<array>& outputs, | ||||
|     const std::string op, | ||||
|     const std::string& op, | ||||
|     const Stream& s); | ||||
|  | ||||
| void binary_op_gpu_inplace( | ||||
|     const std::vector<array>& inputs, | ||||
|     array& out, | ||||
|     const std::string op, | ||||
|     const std::string& op, | ||||
|     const Stream& s); | ||||
|  | ||||
| } // namespace mlx::core | ||||
|   | ||||
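A small note on the signature change above: op now goes by const std::string& instead of by value, so each dispatch binds to the caller's string rather than copying it. A minimal sketch of the difference (function names are hypothetical):

    #include <cstdio>
    #include <string>

    static void by_value(std::string op) { std::printf("%zu\n", op.size()); }      // copies the argument
    static void by_ref(const std::string& op) { std::printf("%zu\n", op.size()); } // aliases it, no copy

    int main() {
      std::string op = "add";
      by_value(op); // copies "add" into the parameter
      by_ref(op);   // no copy; the parameter refers to op
      return 0;
    }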
| @@ -64,16 +64,17 @@ void copy_gpu_inplace( | ||||
|   auto& strides_in_ = strides[0]; | ||||
|   auto& strides_out_ = strides[1]; | ||||
|  | ||||
|   bool use_2d = out.data_size() > UINT32_MAX; | ||||
|   auto& d = metal::device(s.device); | ||||
|   std::string kernel_name; | ||||
|   { | ||||
|     std::ostringstream kname; | ||||
|     switch (ctype) { | ||||
|       case CopyType::Scalar: | ||||
|         kname << "s"; | ||||
|         kname << (use_2d ? "s2" : "s"); | ||||
|         break; | ||||
|       case CopyType::Vector: | ||||
|         kname << "v"; | ||||
|         kname << (use_2d ? "v2" : "v"); | ||||
|         break; | ||||
|       case CopyType::General: | ||||
|         kname << "g"; | ||||
| @@ -139,7 +140,8 @@ void copy_gpu_inplace( | ||||
|     compute_encoder.dispatchThreads(grid_dims, group_dims); | ||||
|   } else { | ||||
|     size_t nthreads = out.data_size(); | ||||
|     MTL::Size grid_dims = MTL::Size(nthreads, 1, 1); | ||||
|     MTL::Size grid_dims = use_2d ? get_2d_grid_dims(out.shape(), out.strides()) | ||||
|                                  : MTL::Size(nthreads, 1, 1); | ||||
|     NS::UInteger thread_group_size = kernel->maxTotalThreadsPerThreadgroup(); | ||||
|     if (thread_group_size > nthreads) { | ||||
|       thread_group_size = nthreads; | ||||
|   | ||||
| @@ -36,6 +36,39 @@ template <typename T, typename U, typename Op> | ||||
|   c[index] = Op()(a[index], b[index]); | ||||
| } | ||||
|  | ||||
| template <typename T, typename U, typename Op> | ||||
| [[kernel]] void binary_sv2( | ||||
|     device const T* a, | ||||
|     device const T* b, | ||||
|     device U* c, | ||||
|     uint2 index [[thread_position_in_grid]], | ||||
|     uint2 grid_dim [[threads_per_grid]]) { | ||||
|   size_t offset = index.x + grid_dim.x * size_t(index.y); | ||||
|   c[offset] = Op()(a[0], b[offset]); | ||||
| } | ||||
|  | ||||
| template <typename T, typename U, typename Op> | ||||
| [[kernel]] void binary_vs2( | ||||
|     device const T* a, | ||||
|     device const T* b, | ||||
|     device U* c, | ||||
|     uint2 index [[thread_position_in_grid]], | ||||
|     uint2 grid_dim [[threads_per_grid]]) { | ||||
|   size_t offset = index.x + grid_dim.x * size_t(index.y); | ||||
|   c[offset] = Op()(a[offset], b[0]); | ||||
| } | ||||
|  | ||||
| template <typename T, typename U, typename Op> | ||||
| [[kernel]] void binary_vv2( | ||||
|     device const T* a, | ||||
|     device const T* b, | ||||
|     device U* c, | ||||
|     uint2 index [[thread_position_in_grid]], | ||||
|     uint2 grid_dim [[threads_per_grid]]) { | ||||
|   size_t offset = index.x + grid_dim.x * size_t(index.y); | ||||
|   c[offset] = Op()(a[offset], b[offset]); | ||||
| } | ||||
|  | ||||
| template <typename T, typename U, typename Op> | ||||
| [[kernel]] void binary_g_nd1( | ||||
|     device const T* a, | ||||
|   | ||||
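Each *2 kernel rebuilds the flat element offset from its 2D thread position as index.x + grid_dim.x * size_t(index.y). The size_t cast is the important part: with both operands left as 32-bit uints the product would wrap at 2^32. A standalone demonstration of the wraparound being avoided (values are hypothetical):

    #include <cstdint>
    #include <cstdio>

    int main() {
      uint32_t grid_dim_x = 1u << 20; // hypothetical grid width
      uint32_t index_y = 1u << 13;    // hypothetical second-axis position
      uint32_t index_x = 7;

      uint32_t wrapped = index_x + grid_dim_x * index_y;                  // 32-bit: wraps to 7
      size_t exact = index_x + grid_dim_x * static_cast<size_t>(index_y); // 64-bit: 8589934599

      std::printf("32-bit offset: %u\n", wrapped);
      std::printf("64-bit offset: %zu\n", exact);
      return 0;
    }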
| @@ -14,6 +14,9 @@ | ||||
|   instantiate_kernel("sv" #op #tname, binary_sv, itype, otype, op)      \ | ||||
|   instantiate_kernel("vs" #op #tname, binary_vs, itype, otype, op)      \ | ||||
|   instantiate_kernel("vv" #op #tname, binary_vv, itype, otype, op)      \ | ||||
|   instantiate_kernel("sv2" #op #tname, binary_sv2, itype, otype, op)    \ | ||||
|   instantiate_kernel("vs2" #op #tname, binary_vs2, itype, otype, op)    \ | ||||
|   instantiate_kernel("vv2" #op #tname, binary_vv2, itype, otype, op)    \ | ||||
|   instantiate_kernel("gn" #op #tname, binary_g, itype, otype, op)       \ | ||||
|   instantiate_kernel("g1" #op #tname, binary_g_nd1, itype, otype, op)   \ | ||||
|   instantiate_kernel("g2" #op #tname, binary_g_nd2, itype, otype, op)   \ | ||||
|   | ||||
| @@ -48,6 +48,48 @@ template <typename T, typename U, typename Op> | ||||
|   d[index] = out[1]; | ||||
| } | ||||
|  | ||||
| template <typename T, typename U, typename Op> | ||||
| [[kernel]] void binary_sv2( | ||||
|     device const T* a, | ||||
|     device const T* b, | ||||
|     device U* c, | ||||
|     device U* d, | ||||
|     uint2 index [[thread_position_in_grid]], | ||||
|     uint2 grid_dim [[threads_per_grid]]) { | ||||
|   size_t offset = index.x + grid_dim.x * size_t(index.y); | ||||
|   auto out = Op()(a[0], b[offset]); | ||||
|   c[offset] = out[0]; | ||||
|   d[offset] = out[1]; | ||||
| } | ||||
|  | ||||
| template <typename T, typename U, typename Op> | ||||
| [[kernel]] void binary_vs2( | ||||
|     device const T* a, | ||||
|     device const T* b, | ||||
|     device U* c, | ||||
|     device U* d, | ||||
|     uint2 index [[thread_position_in_grid]], | ||||
|     uint2 grid_dim [[threads_per_grid]]) { | ||||
|   size_t offset = index.x + grid_dim.x * size_t(index.y); | ||||
|   auto out = Op()(a[offset], b[0]); | ||||
|   c[offset] = out[0]; | ||||
|   d[offset] = out[1]; | ||||
| } | ||||
|  | ||||
| template <typename T, typename U, typename Op> | ||||
| [[kernel]] void binary_vv2( | ||||
|     device const T* a, | ||||
|     device const T* b, | ||||
|     device U* c, | ||||
|     device U* d, | ||||
|     uint2 index [[thread_position_in_grid]], | ||||
|     uint2 grid_dim [[threads_per_grid]]) { | ||||
|   size_t offset = index.x + grid_dim.x * size_t(index.y); | ||||
|   auto out = Op()(a[offset], b[offset]); | ||||
|   c[offset] = out[0]; | ||||
|   d[offset] = out[1]; | ||||
| } | ||||
|  | ||||
| template <typename T, typename U, typename Op> | ||||
| [[kernel]] void binary_g_nd1( | ||||
|     device const T* a, | ||||
|   | ||||
| @@ -12,6 +12,9 @@ | ||||
|   instantiate_kernel("sv" #op #tname, binary_sv, itype, otype, op)      \ | ||||
|   instantiate_kernel("vs" #op #tname, binary_vs, itype, otype, op)      \ | ||||
|   instantiate_kernel("vv" #op #tname, binary_vv, itype, otype, op)      \ | ||||
|   instantiate_kernel("sv2" #op #tname, binary_sv2, itype, otype, op)    \ | ||||
|   instantiate_kernel("vs2" #op #tname, binary_vs2, itype, otype, op)    \ | ||||
|   instantiate_kernel("vv2" #op #tname, binary_vv2, itype, otype, op)    \ | ||||
|   instantiate_kernel("gn" #op #tname, binary_g, itype, otype, op)       \ | ||||
|   instantiate_kernel("g1" #op #tname, binary_g_nd1, itype, otype, op)   \ | ||||
|   instantiate_kernel("g2" #op #tname, binary_g_nd2, itype, otype, op)   \ | ||||
|   | ||||
| @@ -16,6 +16,26 @@ template <typename T, typename U> | ||||
|   dst[index] = static_cast<U>(src[index]); | ||||
| } | ||||
|  | ||||
| template <typename T, typename U> | ||||
| [[kernel]] void copy_s2( | ||||
|     device const T* src [[buffer(0)]], | ||||
|     device U* dst [[buffer(1)]], | ||||
|     uint2 index [[thread_position_in_grid]], | ||||
|     uint2 grid_dim [[threads_per_grid]]) { | ||||
|   size_t offset = index.x + grid_dim.x * size_t(index.y); | ||||
|   dst[offset] = static_cast<U>(src[0]); | ||||
| } | ||||
|  | ||||
| template <typename T, typename U> | ||||
| [[kernel]] void copy_v2( | ||||
|     device const T* src [[buffer(0)]], | ||||
|     device U* dst [[buffer(1)]], | ||||
|     uint2 index [[thread_position_in_grid]], | ||||
|     uint2 grid_dim [[threads_per_grid]]) { | ||||
|   size_t offset = index.x + grid_dim.x * size_t(index.y); | ||||
|   dst[offset] = static_cast<U>(src[offset]); | ||||
| } | ||||
|  | ||||
| template <typename T, typename U> | ||||
| [[kernel]] void copy_g_nd1( | ||||
|     device const T* src [[buffer(0)]], | ||||
|   | ||||
| @@ -5,95 +5,23 @@ | ||||
| #include "mlx/backend/metal/kernels/bf16.h" | ||||
| #include "mlx/backend/metal/kernels/copy.h" | ||||
|  | ||||
| #define instantiate_copy(name, itype, otype, ctype)                        \ | ||||
|   template [[host_name(name)]] [[kernel]] void copy_##ctype<itype, otype>( \ | ||||
|       device const itype* src [[buffer(0)]],                               \ | ||||
|       device otype* dst [[buffer(1)]],                                     \ | ||||
|       uint index [[thread_position_in_grid]]); | ||||
|  | ||||
| #define instantiate_copy_g_dim(name, itype, otype, dims)      \ | ||||
|   template [[host_name("g" #dims "_" name)]] [[kernel]] void  \ | ||||
|   copy_g_nd<itype, otype, dims>(                              \ | ||||
|       device const itype* src [[buffer(0)]],                  \ | ||||
|       device otype* dst [[buffer(1)]],                        \ | ||||
|       constant const int* src_shape [[buffer(2)]],            \ | ||||
|       constant const int64_t* src_strides [[buffer(3)]],      \ | ||||
|       uint3 index [[thread_position_in_grid]],                \ | ||||
|       uint3 grid_dim [[threads_per_grid]]);                   \ | ||||
|   template [[host_name("gg" #dims "_" name)]] [[kernel]] void \ | ||||
|   copy_gg_nd<itype, otype, dims>(                             \ | ||||
|       device const itype* src [[buffer(0)]],                  \ | ||||
|       device otype* dst [[buffer(1)]],                        \ | ||||
|       constant const int* src_shape [[buffer(2)]],            \ | ||||
|       constant const int64_t* src_strides [[buffer(3)]],      \ | ||||
|       constant const int64_t* dst_strides [[buffer(4)]],      \ | ||||
|       uint3 index [[thread_position_in_grid]]); | ||||
|  | ||||
| #define instantiate_copy_g_nd(name, itype, otype)                              \ | ||||
|   template [[host_name("g1_" name)]] [[kernel]] void copy_g_nd1<itype, otype>( \ | ||||
|       device const itype* src [[buffer(0)]],                                   \ | ||||
|       device otype* dst [[buffer(1)]],                                         \ | ||||
|       constant const int64_t& src_stride [[buffer(3)]],                        \ | ||||
|       uint index [[thread_position_in_grid]]);                                 \ | ||||
|   template [[host_name("g2_" name)]] [[kernel]] void copy_g_nd2<itype, otype>( \ | ||||
|       device const itype* src [[buffer(0)]],                                   \ | ||||
|       device otype* dst [[buffer(1)]],                                         \ | ||||
|       constant const int64_t* src_strides [[buffer(3)]],                       \ | ||||
|       uint2 index [[thread_position_in_grid]],                                 \ | ||||
|       uint2 grid_dim [[threads_per_grid]]);                                    \ | ||||
|   template [[host_name("g3_" name)]] [[kernel]] void copy_g_nd3<itype, otype>( \ | ||||
|       device const itype* src [[buffer(0)]],                                   \ | ||||
|       device otype* dst [[buffer(1)]],                                         \ | ||||
|       constant const int64_t* src_strides [[buffer(3)]],                       \ | ||||
|       uint3 index [[thread_position_in_grid]],                                 \ | ||||
|       uint3 grid_dim [[threads_per_grid]]);                                    \ | ||||
|   template [[host_name("gg1_" name )]] [[kernel]] void                         \ | ||||
|   copy_gg_nd1<itype, otype>(                                                   \ | ||||
|       device const itype* src [[buffer(0)]],                                   \ | ||||
|       device otype* dst [[buffer(1)]],                                         \ | ||||
|       constant const int64_t& src_stride [[buffer(3)]],                        \ | ||||
|       constant const int64_t& dst_stride [[buffer(4)]],                        \ | ||||
|       uint index [[thread_position_in_grid]]);                                 \ | ||||
|   template [[host_name("gg2_" name)]] [[kernel]] void                          \ | ||||
|   copy_gg_nd2<itype, otype>(                                                   \ | ||||
|       device const itype* src [[buffer(0)]],                                   \ | ||||
|       device otype* dst [[buffer(1)]],                                         \ | ||||
|       constant const int64_t* src_strides [[buffer(3)]],                       \ | ||||
|       constant const int64_t* dst_strides [[buffer(4)]],                       \ | ||||
|       uint2 index [[thread_position_in_grid]]);                                \ | ||||
|   template [[host_name("gg3_" name)]] [[kernel]] void                          \ | ||||
|   copy_gg_nd3<itype, otype>(                                                   \ | ||||
|       device const itype* src [[buffer(0)]],                                   \ | ||||
|       device otype* dst [[buffer(1)]],                                         \ | ||||
|       constant const int64_t* src_strides [[buffer(3)]],                       \ | ||||
|       constant const int64_t* dst_strides [[buffer(4)]],                       \ | ||||
|       uint3 index [[thread_position_in_grid]]);                                \ | ||||
|   instantiate_copy_g_dim(name, itype, otype, 4)                                \ | ||||
|   instantiate_copy_g_dim(name, itype, otype, 5) | ||||
|  | ||||
| #define instantiate_copy_g(name, itype, otype)                              \ | ||||
|   template [[host_name("g_" name)]] [[kernel]] void copy_g<itype, otype>(   \ | ||||
|       device const itype* src [[buffer(0)]],                                \ | ||||
|       device otype* dst [[buffer(1)]],                                      \ | ||||
|       constant const int* src_shape [[buffer(2)]],                          \ | ||||
|       constant const int64_t* src_strides [[buffer(3)]],                    \ | ||||
|       constant const int& ndim [[buffer(5)]],                               \ | ||||
|       uint3 index [[thread_position_in_grid]],                              \ | ||||
|       uint3 grid_dim [[threads_per_grid]]);                                 \ | ||||
|   template [[host_name("gg_" name)]] [[kernel]] void copy_gg<itype, otype>( \ | ||||
|       device const itype* src [[buffer(0)]],                                \ | ||||
|       device otype* dst [[buffer(1)]],                                      \ | ||||
|       constant const int* src_shape [[buffer(2)]],                          \ | ||||
|       constant const int64_t* src_strides [[buffer(3)]],                    \ | ||||
|       constant const int64_t* dst_strides [[buffer(4)]],                    \ | ||||
|       constant const int& ndim [[buffer(5)]],                               \ | ||||
|       uint3 index [[thread_position_in_grid]]); | ||||
|  | ||||
| #define instantiate_copy_all(tname, itype, otype)    \ | ||||
|   instantiate_copy("s_copy" #tname, itype, otype, s) \ | ||||
|   instantiate_copy("v_copy" #tname, itype, otype, v) \ | ||||
|   instantiate_copy_g("copy" #tname, itype, otype)    \ | ||||
|   instantiate_copy_g_nd("copy" #tname, itype, otype) | ||||
|   instantiate_kernel("s_copy" #tname, copy_s, itype, otype) \ | ||||
|   instantiate_kernel("v_copy" #tname, copy_v, itype, otype) \ | ||||
|   instantiate_kernel("s2_copy" #tname, copy_s2, itype, otype) \ | ||||
|   instantiate_kernel("v2_copy" #tname, copy_v2, itype, otype) \ | ||||
|   instantiate_kernel("g1_copy" #tname, copy_g_nd1, itype, otype) \ | ||||
|   instantiate_kernel("g2_copy" #tname, copy_g_nd2, itype, otype) \ | ||||
|   instantiate_kernel("g3_copy" #tname, copy_g_nd3, itype, otype) \ | ||||
|   instantiate_kernel("gg1_copy" #tname, copy_gg_nd1, itype, otype) \ | ||||
|   instantiate_kernel("gg2_copy" #tname, copy_gg_nd2, itype, otype) \ | ||||
|   instantiate_kernel("gg3_copy" #tname, copy_gg_nd3, itype, otype) \ | ||||
|   instantiate_kernel("g4_copy" #tname, copy_g_nd, itype, otype, 4) \ | ||||
|   instantiate_kernel("g5_copy" #tname, copy_g_nd, itype, otype, 5) \ | ||||
|   instantiate_kernel("gg4_copy" #tname, copy_gg_nd, itype, otype, 4) \ | ||||
|   instantiate_kernel("gg5_copy" #tname, copy_gg_nd, itype, otype, 5) \ | ||||
|   instantiate_kernel("g_copy" #tname, copy_g, itype, otype) \ | ||||
|   instantiate_kernel("gg_copy" #tname, copy_gg, itype, otype) | ||||
|  | ||||
| #define instantiate_copy_itype(itname, itype)                \ | ||||
|   instantiate_copy_all(itname ##bool_, itype, bool)          \ | ||||
|   | ||||
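The copy instantiations above drop the long hand-written template [[host_name(...)]] blocks in favor of one instantiate_kernel line per variant (and add the new s2/v2 copies). The macro itself lives in the shared kernel headers; a simplified, assumed form of the pattern it expands to:

    // Assumed sketch of the helper, not its verbatim definition:
    #define instantiate_kernel(name, func, ...)        \
      template [[host_name(name)]] [[kernel]]          \
          decltype(func<__VA_ARGS__>) func<__VA_ARGS__>;
    // e.g. instantiate_kernel("v2_copy" #tname, copy_v2, itype, otype) emits an
    // explicit instantiation of copy_v2<itype, otype> that the host can look up
    // by its host_name string.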
| @@ -34,7 +34,7 @@ template <typename T, int N_READS = RMS_N_READS> | ||||
|   threadgroup float local_mean[1]; | ||||
|   threadgroup float local_normalizer[1]; | ||||
|  | ||||
|   x += gid * axis_size + lid * N_READS; | ||||
|   x += gid * size_t(axis_size) + lid * N_READS; | ||||
|   w += w_stride * lid * N_READS; | ||||
|   b += b_stride * lid * N_READS; | ||||
|  | ||||
| @@ -89,7 +89,7 @@ template <typename T, int N_READS = RMS_N_READS> | ||||
|   float normalizer = local_normalizer[0]; | ||||
|  | ||||
|   // Write the outputs | ||||
|   out += gid * axis_size + lid * N_READS; | ||||
|   out += gid * size_t(axis_size) + lid * N_READS; | ||||
|   if (lid * N_READS + N_READS <= axis_size) { | ||||
|     for (int i = 0; i < N_READS; i++) { | ||||
|       thread_x[i] = (thread_x[i] - mean) * normalizer; | ||||
| @@ -131,7 +131,7 @@ template <typename T, int N_READS = RMS_N_READS> | ||||
|   threadgroup float local_mean[1]; | ||||
|   threadgroup float local_normalizer[1]; | ||||
|  | ||||
|   x += gid * axis_size + lid * N_READS; | ||||
|   x += gid * size_t(axis_size) + lid * N_READS; | ||||
|   w += w_stride * lid * N_READS; | ||||
|   b += b_stride * lid * N_READS; | ||||
|  | ||||
| @@ -188,7 +188,7 @@ template <typename T, int N_READS = RMS_N_READS> | ||||
|   float normalizer = local_normalizer[0]; | ||||
|  | ||||
|   // Write the outputs | ||||
|   out += gid * axis_size + lid * N_READS; | ||||
|   out += gid * size_t(axis_size) + lid * N_READS; | ||||
|   for (uint r = 0; r < axis_size; r += lsize * N_READS) { | ||||
|     if (r + lid * N_READS + N_READS <= axis_size) { | ||||
|       for (int i = 0; i < N_READS; i++) { | ||||
| @@ -223,8 +223,8 @@ template <typename T, int N_READS = RMS_N_READS> | ||||
|     uint simd_lane_id [[thread_index_in_simdgroup]], | ||||
|     uint simd_group_id [[simdgroup_index_in_threadgroup]]) { | ||||
|   // Advance the input pointers | ||||
|   x += gid * axis_size + lid * N_READS; | ||||
|   g += gid * axis_size + lid * N_READS; | ||||
|   x += gid * size_t(axis_size) + lid * N_READS; | ||||
|   g += gid * size_t(axis_size) + lid * N_READS; | ||||
|   w += w_stride * lid * N_READS; | ||||
|  | ||||
|   // Allocate registers for the computation and accumulators | ||||
| @@ -321,8 +321,8 @@ template <typename T, int N_READS = RMS_N_READS> | ||||
|   float normalizer2 = normalizer * normalizer; | ||||
|  | ||||
|   // Write the outputs | ||||
|   gx += gid * axis_size + lid * N_READS; | ||||
|   gw += gid * axis_size + lid * N_READS; | ||||
|   gx += gid * size_t(axis_size) + lid * N_READS; | ||||
|   gw += gid * size_t(axis_size) + lid * N_READS; | ||||
|   if (lid * N_READS + N_READS <= axis_size) { | ||||
|     for (int i = 0; i < N_READS; i++) { | ||||
|       thread_x[i] = (thread_x[i] - mean) * normalizer; | ||||
| @@ -360,8 +360,8 @@ template <typename T, int N_READS = RMS_N_READS> | ||||
|     uint simd_lane_id [[thread_index_in_simdgroup]], | ||||
|     uint simd_group_id [[simdgroup_index_in_threadgroup]]) { | ||||
|   // Advance the input pointers | ||||
|   x += gid * axis_size + lid * N_READS; | ||||
|   g += gid * axis_size + lid * N_READS; | ||||
|   x += gid * size_t(axis_size) + lid * N_READS; | ||||
|   g += gid * size_t(axis_size) + lid * N_READS; | ||||
|   w += w_stride * lid * N_READS; | ||||
|  | ||||
|   // Allocate registers for the accumulators | ||||
| @@ -457,8 +457,8 @@ template <typename T, int N_READS = RMS_N_READS> | ||||
|   float normalizer2 = normalizer * normalizer; | ||||
|  | ||||
|   // Write the outputs | ||||
|   gx += gid * axis_size + lid * N_READS; | ||||
|   gw += gid * axis_size + lid * N_READS; | ||||
|   gx += gid * size_t(axis_size) + lid * N_READS; | ||||
|   gw += gid * size_t(axis_size) + lid * N_READS; | ||||
|   for (uint r = 0; r < axis_size; r += lsize * N_READS) { | ||||
|     if (r + lid * N_READS + N_READS <= axis_size) { | ||||
|       for (int i = 0; i < N_READS; i++) { | ||||
|   | ||||
| @@ -24,7 +24,7 @@ template <typename T, int N_READS = RMS_N_READS> | ||||
|     uint simd_lane_id [[thread_index_in_simdgroup]], | ||||
|     uint simd_group_id [[simdgroup_index_in_threadgroup]]) { | ||||
|   float acc = 0; | ||||
|   x += gid * axis_size + lid * N_READS; | ||||
|   x += gid * size_t(axis_size) + lid * N_READS; | ||||
|   w += w_stride * lid * N_READS; | ||||
|   if (lid * N_READS + N_READS <= axis_size) { | ||||
|     for (int i = 0; i < N_READS; i++) { | ||||
| @@ -62,7 +62,7 @@ template <typename T, int N_READS = RMS_N_READS> | ||||
|   threadgroup_barrier(mem_flags::mem_threadgroup); | ||||
|  | ||||
|   // Write the outputs | ||||
|   out += gid * axis_size + lid * N_READS; | ||||
|   out += gid * size_t(axis_size) + lid * N_READS; | ||||
|   if (lid * N_READS + N_READS <= axis_size) { | ||||
|     for (int i = 0; i < N_READS; i++) { | ||||
|       out[i] = w[w_stride * i] * static_cast<T>(x[i] * local_inv_mean[0]); | ||||
| @@ -92,7 +92,7 @@ template <typename T, int N_READS = RMS_N_READS> | ||||
|     uint simd_lane_id [[thread_index_in_simdgroup]], | ||||
|     uint simd_group_id [[simdgroup_index_in_threadgroup]]) { | ||||
|   float acc = 0; | ||||
|   x += gid * axis_size + lid * N_READS; | ||||
|   x += gid * size_t(axis_size) + lid * N_READS; | ||||
|   w += w_stride * lid * N_READS; | ||||
|   for (uint r = 0; r < axis_size; r += lsize * N_READS) { | ||||
|     if (r + lid * N_READS + N_READS <= axis_size) { | ||||
| @@ -132,7 +132,7 @@ template <typename T, int N_READS = RMS_N_READS> | ||||
|   threadgroup_barrier(mem_flags::mem_threadgroup); | ||||
|  | ||||
|   // Write the outputs | ||||
|   out += gid * axis_size + lid * N_READS; | ||||
|   out += gid * size_t(axis_size) + lid * N_READS; | ||||
|   for (uint r = 0; r < axis_size; r += lsize * N_READS) { | ||||
|     if (r + lid * N_READS + N_READS <= axis_size) { | ||||
|       for (int i = 0; i < N_READS; i++) { | ||||
| @@ -165,8 +165,8 @@ template <typename T, int N_READS = RMS_N_READS> | ||||
|     uint simd_lane_id [[thread_index_in_simdgroup]], | ||||
|     uint simd_group_id [[simdgroup_index_in_threadgroup]]) { | ||||
|   // Advance the input pointers | ||||
|   x += gid * axis_size + lid * N_READS; | ||||
|   g += gid * axis_size + lid * N_READS; | ||||
|   x += gid * size_t(axis_size) + lid * N_READS; | ||||
|   g += gid * size_t(axis_size) + lid * N_READS; | ||||
|   w += w_stride * lid * N_READS; | ||||
|  | ||||
|   // Allocate registers for the computation and accumulators | ||||
| @@ -233,8 +233,8 @@ template <typename T, int N_READS = RMS_N_READS> | ||||
|   float normalizer3 = normalizer * normalizer * normalizer; | ||||
|  | ||||
|   // Write the outputs | ||||
|   gx += gid * axis_size + lid * N_READS; | ||||
|   gw += gid * axis_size + lid * N_READS; | ||||
|   gx += gid * size_t(axis_size) + lid * N_READS; | ||||
|   gw += gid * size_t(axis_size) + lid * N_READS; | ||||
|   if (lid * N_READS + N_READS <= axis_size) { | ||||
|     for (int i = 0; i < N_READS; i++) { | ||||
|       gx[i] = static_cast<T>( | ||||
| @@ -270,8 +270,8 @@ template <typename T, int N_READS = RMS_N_READS> | ||||
|     uint simd_lane_id [[thread_index_in_simdgroup]], | ||||
|     uint simd_group_id [[simdgroup_index_in_threadgroup]]) { | ||||
|   // Advance the input pointers | ||||
|   x += gid * axis_size + lid * N_READS; | ||||
|   g += gid * axis_size + lid * N_READS; | ||||
|   x += gid * size_t(axis_size) + lid * N_READS; | ||||
|   g += gid * size_t(axis_size) + lid * N_READS; | ||||
|   w += w_stride * lid * N_READS; | ||||
|  | ||||
|   // Allocate registers for the accumulators | ||||
| @@ -337,8 +337,8 @@ template <typename T, int N_READS = RMS_N_READS> | ||||
|   float normalizer3 = normalizer * normalizer * normalizer; | ||||
|  | ||||
|   // Write the outputs | ||||
|   gx += gid * axis_size + lid * N_READS; | ||||
|   gw += gid * axis_size + lid * N_READS; | ||||
|   gx += gid * size_t(axis_size) + lid * N_READS; | ||||
|   gw += gid * size_t(axis_size) + lid * N_READS; | ||||
|   for (uint r = 0; r < axis_size; r += lsize * N_READS) { | ||||
|     if (r + lid * N_READS + N_READS <= axis_size) { | ||||
|       for (int i = 0; i < N_READS; i++) { | ||||
|   | ||||
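The layer-norm and RMS-norm kernels above advance their row pointers by gid * axis_size. Both are 32-bit uints, so the product wraps once a row's starting offset passes 2^32, even though neither factor is large on its own; casting axis_size to size_t promotes the multiplication to 64 bits. A quick way to see where the unfixed arithmetic starts to break, for a hypothetical row length:

    #include <cstdint>
    #include <cstdio>

    int main() {
      uint32_t axis_size = 4096; // hypothetical row length
      // First row whose starting offset no longer fits in 32 bits:
      uint64_t first_bad_row = (UINT32_MAX + 1ull) / axis_size; // 1048576
      std::printf("rows >= %llu overflow a 32-bit offset\n",
                  (unsigned long long)first_bad_row);
      return 0;
    }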
| @@ -25,7 +25,7 @@ template <typename T, typename AccT = T, int N_READS = SOFTMAX_N_READS> | ||||
|  | ||||
|   AccT ld[N_READS]; | ||||
|  | ||||
|   in += gid * axis_size + lid * N_READS; | ||||
|   in += gid * size_t(axis_size) + lid * N_READS; | ||||
|   if (lid * N_READS + N_READS <= axis_size) { | ||||
|     for (int i = 0; i < N_READS; i++) { | ||||
|       ld[i] = AccT(in[i]); | ||||
| @@ -83,7 +83,7 @@ template <typename T, typename AccT = T, int N_READS = SOFTMAX_N_READS> | ||||
|   normalizer = 1 / local_normalizer[0]; | ||||
|  | ||||
|   // Normalize and write to the output | ||||
|   out += gid * axis_size + lid * N_READS; | ||||
|   out += gid * size_t(axis_size) + lid * N_READS; | ||||
|   if (lid * N_READS + N_READS <= axis_size) { | ||||
|     for (int i = 0; i < N_READS; i++) { | ||||
|       out[i] = T(ld[i] * normalizer); | ||||
| @@ -107,7 +107,7 @@ template <typename T, typename AccT = T, int N_READS = SOFTMAX_N_READS> | ||||
|     uint lsize [[threads_per_threadgroup]], | ||||
|     uint simd_lane_id [[thread_index_in_simdgroup]], | ||||
|     uint simd_group_id [[simdgroup_index_in_threadgroup]]) { | ||||
|   in += gid * axis_size; | ||||
|   in += gid * size_t(axis_size); | ||||
|  | ||||
|   constexpr int SIMD_SIZE = 32; | ||||
|  | ||||
| @@ -170,7 +170,7 @@ template <typename T, typename AccT = T, int N_READS = SOFTMAX_N_READS> | ||||
|  | ||||
|   // Finally given the normalizer and max value we can directly write the | ||||
|   // softmax output | ||||
|   out += gid * axis_size; | ||||
|   out += gid * size_t(axis_size); | ||||
|   for (int r = 0; r < static_cast<int>(ceildiv(axis_size, N_READS * lsize)); | ||||
|        r++) { | ||||
|     int offset = r * lsize * N_READS + lid * N_READS; | ||||
|   | ||||
| @@ -10,6 +10,18 @@ template <typename T, typename Op> | ||||
|   d[index] = Op()(a[index], b[index], c[index]); | ||||
| } | ||||
|  | ||||
| template <typename T, typename Op> | ||||
| [[kernel]] void ternary_v2( | ||||
|     device const bool* a, | ||||
|     device const T* b, | ||||
|     device const T* c, | ||||
|     device T* d, | ||||
|     uint2 index [[thread_position_in_grid]], | ||||
|     uint2 grid_dim [[threads_per_grid]]) { | ||||
|   size_t offset = index.x + grid_dim.x * size_t(index.y); | ||||
|   d[offset] = Op()(a[offset], b[offset], c[offset]); | ||||
| } | ||||
|  | ||||
| template <typename T, typename Op> | ||||
| [[kernel]] void ternary_g_nd1( | ||||
|     device const bool* a, | ||||
|   | ||||
| @@ -11,6 +11,7 @@ | ||||
|  | ||||
| #define instantiate_ternary_all(op, tname, type)                  \ | ||||
|   instantiate_kernel("v_" #op #tname, ternary_v, type, op)        \ | ||||
|   instantiate_kernel("v2_" #op #tname, ternary_v2, type, op)      \ | ||||
|   instantiate_kernel("g_" #op #tname, ternary_g, type, op)        \ | ||||
|   instantiate_kernel("g1_" #op #tname, ternary_g_nd1, type, op)   \ | ||||
|   instantiate_kernel("g2_" #op #tname, ternary_g_nd2, type, op)   \ | ||||
|   | ||||
| @@ -8,6 +8,16 @@ template <typename T, typename Op> | ||||
|   out[index] = Op()(in[index]); | ||||
| } | ||||
|  | ||||
| template <typename T, typename Op> | ||||
| [[kernel]] void unary_v2( | ||||
|     device const T* in, | ||||
|     device T* out, | ||||
|     uint2 index [[thread_position_in_grid]], | ||||
|     uint2 grid_dim [[threads_per_grid]]) { | ||||
|   size_t offset = index.x + grid_dim.x * size_t(index.y); | ||||
|   out[offset] = Op()(in[offset]); | ||||
| } | ||||
|  | ||||
| template <typename T, typename Op> | ||||
| [[kernel]] void unary_g( | ||||
|     device const T* in, | ||||
|   | ||||
| @@ -5,8 +5,9 @@ | ||||
| #include "mlx/backend/metal/kernels/unary_ops.h" | ||||
| #include "mlx/backend/metal/kernels/unary.h" | ||||
|  | ||||
| #define instantiate_unary_all(op, tname, type)          \ | ||||
|   instantiate_kernel("v" #op #tname, unary_v, type, op) \ | ||||
| #define instantiate_unary_all(op, tname, type)            \ | ||||
|   instantiate_kernel("v" #op #tname, unary_v, type, op)   \ | ||||
|   instantiate_kernel("v2" #op #tname, unary_v2, type, op) \ | ||||
|   instantiate_kernel("g" #op #tname, unary_g, type, op) | ||||
|  | ||||
| #define instantiate_unary_float(op)               \ | ||||
|   | ||||
| @@ -1,5 +1,4 @@ | ||||
| // Copyright © 2023-2024 Apple Inc. | ||||
|  | ||||
| #include <algorithm> | ||||
|  | ||||
| #include "mlx/backend/metal/copy.h" | ||||
|   | ||||
| @@ -32,6 +32,7 @@ void ternary_op_gpu_inplace( | ||||
|   auto& strides_c = strides[2]; | ||||
|   auto& strides_out = strides[3]; | ||||
|  | ||||
|   bool use_2d = out.data_size() > UINT32_MAX; | ||||
|   std::string kernel_name; | ||||
|   { | ||||
|     std::ostringstream kname; | ||||
| @@ -40,6 +41,8 @@ void ternary_op_gpu_inplace( | ||||
|       if (shape.size() <= MAX_TERNARY_SPECIALIZED_DIMS) { | ||||
|         kname << shape.size(); | ||||
|       } | ||||
|     } else if (use_2d) { | ||||
|       kname << "v2"; | ||||
|     } else { | ||||
|       kname << "v"; | ||||
|     } | ||||
|   | ||||
| @@ -25,11 +25,14 @@ void unary_op_gpu_inplace( | ||||
|  | ||||
|   auto& d = metal::device(s.device); | ||||
|  | ||||
|   std::string kernel_name = (contig ? "v" : "g") + op + type_to_name(out); | ||||
|   size_t nthreads = contig ? in.data_size() : in.size(); | ||||
|   bool use_2d = nthreads > UINT32_MAX; | ||||
|   std::string kernel_name = | ||||
|       (contig ? (use_2d ? "v2" : "v") : "g") + op + type_to_name(out); | ||||
|   auto kernel = get_unary_kernel(d, kernel_name, out.dtype(), op); | ||||
|  | ||||
|   size_t nthreads = contig ? in.data_size() : in.size(); | ||||
|   MTL::Size grid_dims = MTL::Size(nthreads, 1, 1); | ||||
|   MTL::Size grid_dims = use_2d ? get_2d_grid_dims(in.shape(), in.strides()) | ||||
|                                : MTL::Size(nthreads, 1, 1); | ||||
|   NS::UInteger thread_group_size = kernel->maxTotalThreadsPerThreadgroup(); | ||||
|   if (thread_group_size > nthreads) { | ||||
|     thread_group_size = nthreads; | ||||
|   | ||||
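The unary path now derives both the thread count and the kernel-name prefix from the same use_2d flag: contiguous inputs use "v" or "v2", everything else falls back to "g". A sketch of the resulting name for a large contiguous input (the "float32" suffix stands in for whatever type_to_name returns and is only illustrative, as is the op string):

    #include <cstdint>
    #include <iostream>
    #include <string>

    int main() {
      bool contig = true;
      size_t nthreads = 1ull << 33;        // hypothetical element count
      bool use_2d = nthreads > UINT32_MAX; // true here
      std::string op = "abs";              // hypothetical op string
      std::string kernel_name =
          (contig ? (use_2d ? std::string("v2") : std::string("v"))
                  : std::string("g")) +
          op + "float32";
      std::cout << kernel_name << "\n"; // "v2absfloat32"
      return 0;
    }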
| @@ -104,6 +104,35 @@ MTL::Size get_block_dims(int dim0, int dim1, int dim2) { | ||||
|   return MTL::Size{1ul << pows[0], 1ul << pows[1], 1ul << pows[2]}; | ||||
| } | ||||
|  | ||||
| // Computes a 2D grid where each element is < UINT_MAX | ||||
| // Assumes: | ||||
| // - overall size (product of non-broadcasted dimensions) is < UINT_MAX^2 | ||||
| // - shape and strides correspond to a contiguous (no holes) but | ||||
| //   possibly broadcasted array | ||||
| MTL::Size get_2d_grid_dims( | ||||
|     const std::vector<int>& shape, | ||||
|     const std::vector<size_t>& strides) { | ||||
|   // Dims with strides of 0 are ignored as they | ||||
|   // correspond to broadcasted dimensions | ||||
|   size_t grid_x = 1; | ||||
|   size_t grid_y = 1; | ||||
|   for (int i = 0; i < shape.size(); ++i) { | ||||
|     if (strides[i] == 0) { | ||||
|       continue; | ||||
|     } | ||||
|     if (grid_x * shape[i] < UINT32_MAX) { | ||||
|       grid_x *= shape[i]; | ||||
|     } else { | ||||
|       grid_y *= shape[i]; | ||||
|     } | ||||
|   } | ||||
|   if (grid_y > UINT32_MAX || grid_x > UINT32_MAX) { | ||||
|     throw std::runtime_error("Unable to safely factor shape."); | ||||
|   } | ||||
|   return MTL::Size( | ||||
|       static_cast<uint32_t>(grid_x), static_cast<uint32_t>(grid_y), 1); | ||||
| } | ||||
|  | ||||
| inline NS::String* make_string(std::ostringstream& os) { | ||||
|   std::string string = os.str(); | ||||
|   return NS::String::string(string.c_str(), NS::UTF8StringEncoding); | ||||
|   | ||||
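get_2d_grid_dims greedily packs dimensions into grid_x until the next one would push it past UINT32_MAX, sends the rest to grid_y, and skips broadcast (stride-0) axes entirely. A standalone worked example of that factoring (re-implemented here only for illustration; the real function returns an MTL::Size):

    #include <cstdint>
    #include <cstdio>
    #include <vector>

    int main() {
      // 2^20 x 2^15 contiguous data broadcast along a leading axis of 8.
      std::vector<int> shape = {8, 1 << 20, 1 << 15};
      std::vector<size_t> strides = {0, 1 << 15, 1}; // axis 0 is broadcast
      size_t grid_x = 1, grid_y = 1;
      for (size_t i = 0; i < shape.size(); ++i) {
        if (strides[i] == 0) continue; // broadcast axes contribute no threads
        if (grid_x * shape[i] < UINT32_MAX) grid_x *= shape[i];
        else grid_y *= shape[i];
      }
      // 2^35 real elements factor into 2^20 x 2^15; both fit in 32 bits.
      std::printf("grid = %zu x %zu\n", grid_x, grid_y); // 1048576 x 32768
      return 0;
    }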
| @@ -273,7 +273,7 @@ array mlx_get_item_nd(array src, const nb::tuple& entries) { | ||||
|   // Check for the number of indices passed | ||||
|   if (non_none_indices > src.ndim()) { | ||||
|     std::ostringstream msg; | ||||
|     msg << "Too many indices for array with " << src.ndim() << "dimensions."; | ||||
|     msg << "Too many indices for array with " << src.ndim() << " dimensions."; | ||||
|     throw std::invalid_argument(msg.str()); | ||||
|   } | ||||
|  | ||||
| @@ -585,7 +585,7 @@ std::tuple<std::vector<array>, array, std::vector<int>> mlx_scatter_args_nd( | ||||
|  | ||||
|   if (non_none_indices > src.ndim()) { | ||||
|     std::ostringstream msg; | ||||
|     msg << "Too many indices for array with " << src.ndim() << "dimensions."; | ||||
|     msg << "Too many indices for array with " << src.ndim() << " dimensions."; | ||||
|     throw std::invalid_argument(msg.str()); | ||||
|   } | ||||
|  | ||||
| @@ -840,7 +840,7 @@ auto mlx_slice_update( | ||||
|   // Dimension check | ||||
|   if (non_none_indices > src.ndim()) { | ||||
|     std::ostringstream msg; | ||||
|     msg << "Too many indices for array with " << src.ndim() << "dimensions."; | ||||
|     msg << "Too many indices for array with " << src.ndim() << " dimensions."; | ||||
|     throw std::invalid_argument(msg.str()); | ||||
|   } | ||||
|  | ||||
|   | ||||