mirror of
				https://github.com/ml-explore/mlx.git
				synced 2025-10-31 07:58:14 +08:00 
			
		
		
		
	Add batch offsets for mx.fast.rope (#2564)
	* implement batch rope for Metal
	* cuda rope (#2576)
This commit is contained in:
		| @@ -103,15 +103,21 @@ template <typename T, bool traditional, bool forward, int N = 4> | ||||
| __device__ void rope_impl( | ||||
|     const T* in, | ||||
|     T* out, | ||||
|     int offset, | ||||
|     const int* offset, | ||||
|     float inv_freq, | ||||
|     float scale, | ||||
|     const cuda::std::array<int64_t, 3> strides, | ||||
|     const cuda::std::array<int64_t, 3> out_strides, | ||||
|     int64_t n_batch, | ||||
|     int64_t offset_stride, | ||||
|     int n_head, | ||||
|     uint3 pos, | ||||
|     uint3 dims) { | ||||
|   float L = scale * static_cast<float>(pos.y + offset); | ||||
|   auto n_head_up = N * ((n_head + N - 1) / N); | ||||
|   auto head_idx = static_cast<int>((pos.z * N) % n_head_up); | ||||
|   auto batch_idx = (pos.z * N) / n_head_up; | ||||
|   auto batch_offset = offset[batch_idx * offset_stride]; | ||||
|   float L = scale * static_cast<float>(pos.y + batch_offset); | ||||
|   auto mat_idx = batch_idx * n_head + head_idx; | ||||
|  | ||||
|   // Compute costheta, sintheta | ||||
|   float theta = L * inv_freq; | ||||
| @@ -123,20 +129,19 @@ __device__ void rope_impl( | ||||
|   size_t out_index_1, out_index_2; | ||||
|   if (traditional) { | ||||
|     out_index_1 = 2 * pos.x * out_strides[2] + pos.y * out_strides[1] + | ||||
|         N * pos.z * out_strides[0]; | ||||
|         mat_idx * out_strides[0]; | ||||
|     out_index_2 = out_index_1 + 1; | ||||
|     in_index_1 = | ||||
|         2 * pos.x * strides[2] + pos.y * strides[1] + N * pos.z * strides[0]; | ||||
|         2 * pos.x * strides[2] + pos.y * strides[1] + mat_idx * strides[0]; | ||||
|     in_index_2 = in_index_1 + strides[2]; | ||||
|   } else { | ||||
|     out_index_1 = pos.x * out_strides[2] + pos.y * out_strides[1] + | ||||
|         N * pos.z * out_strides[0]; | ||||
|         mat_idx * out_strides[0]; | ||||
|     out_index_2 = out_index_1 + dims.x * out_strides[2]; | ||||
|     in_index_1 = | ||||
|         pos.x * strides[2] + pos.y * strides[1] + N * pos.z * strides[0]; | ||||
|     in_index_1 = pos.x * strides[2] + pos.y * strides[1] + mat_idx * strides[0]; | ||||
|     in_index_2 = in_index_1 + dims.x * strides[2]; | ||||
|   } | ||||
|   for (int i = 0; i < N && pos.z * N + i < n_batch; ++i) { | ||||
|   for (int i = 0; i < N && head_idx + i < n_head; ++i) { | ||||
|     // Read and write the output | ||||
|     float x1 = static_cast<float>(in[in_index_1]); | ||||
|     float x2 = static_cast<float>(in[in_index_2]); | ||||
| @@ -167,7 +172,8 @@ __global__ void rope( | ||||
|     float base, | ||||
|     const __grid_constant__ cuda::std::array<int64_t, 3> strides, | ||||
|     const __grid_constant__ cuda::std::array<int64_t, 3> out_strides, | ||||
|     int64_t n_batch, | ||||
|     int64_t offset_stride, | ||||
|     int n_head, | ||||
|     uint3 dims) { | ||||
|   uint3 pos = make_uint3( | ||||
|       blockIdx.x * blockDim.x + threadIdx.x, | ||||
| @@ -182,12 +188,13 @@ __global__ void rope( | ||||
|   rope_impl<T, traditional, forward>( | ||||
|       in, | ||||
|       out, | ||||
|       *offset, | ||||
|       offset, | ||||
|       inv_freq, | ||||
|       scale, | ||||
|       strides, | ||||
|       out_strides, | ||||
|       n_batch, | ||||
|       offset_stride, | ||||
|       n_head, | ||||
|       pos, | ||||
|       dims); | ||||
| } | ||||
| @@ -202,7 +209,8 @@ __global__ void rope_freqs( | ||||
|     float base, | ||||
|     const __grid_constant__ cuda::std::array<int64_t, 3> strides, | ||||
|     const __grid_constant__ cuda::std::array<int64_t, 3> out_strides, | ||||
|     int64_t n_batch, | ||||
|     int64_t offset_stride, | ||||
|     int n_head, | ||||
|     uint3 dims, | ||||
|     int64_t freq_stride) { | ||||
|   uint3 pos = make_uint3( | ||||
| @@ -217,12 +225,13 @@ __global__ void rope_freqs( | ||||
|   rope_impl<T, traditional, forward>( | ||||
|       in, | ||||
|       out, | ||||
|       *offset, | ||||
|       offset, | ||||
|       inv_freq, | ||||
|       scale, | ||||
|       strides, | ||||
|       out_strides, | ||||
|       n_batch, | ||||
|       offset_stride, | ||||
|       n_head, | ||||
|       pos, | ||||
|       dims); | ||||
| } | ||||
| @@ -245,23 +254,28 @@ void RoPE::eval_gpu( | ||||
|   auto& offset = inputs[1]; | ||||
|   auto& out = outputs[0]; | ||||
|  | ||||
|   if (in.ndim() < 3) { | ||||
|     throw std::runtime_error("[RoPE] Input must have at least 3 dimensions"); | ||||
|   } | ||||
|  | ||||
|   cuda::std::array<int64_t, 3> strides; | ||||
|   cuda::std::array<int64_t, 3> out_strides; | ||||
|   bool donated = false; | ||||
|   int ndim = in.ndim(); | ||||
|   int dispatch_ndim = in.ndim(); | ||||
|  | ||||
|   int B = in.shape(0); | ||||
|   int T = in.shape(-2); | ||||
|   int D = in.shape(-1); | ||||
|   size_t mat_size = T * D; | ||||
|   int dispatch_ndim = ndim; | ||||
|   while (in.shape(-dispatch_ndim) == 1 && dispatch_ndim > 3) { | ||||
|     dispatch_ndim--; | ||||
|   } | ||||
|   size_t mat_size = in.shape(-2) * in.shape(-1); | ||||
|  | ||||
|   int N = 1; | ||||
|   for (int i = 1; i < (ndim - 2); ++i) { | ||||
|     N *= in.shape(i); | ||||
|   } | ||||
|  | ||||
|   // We apply rope to less than the whole vector so copy to output and then | ||||
|   // apply in-place. | ||||
|   if (dims_ < in.shape(-1)) { | ||||
|   if (dims_ < D) { | ||||
|     donated = true; | ||||
|     auto ctype = | ||||
|         (in.flags().row_contiguous) ? CopyType::Vector : CopyType::General; | ||||
| @@ -302,7 +316,7 @@ void RoPE::eval_gpu( | ||||
|   out_strides[2] = out.strides()[ndim - 1]; | ||||
|  | ||||
|   // Some flags to help us dispatch below | ||||
|   bool single = in.flags().row_contiguous && (mat_size == in.shape(-1)); | ||||
|   bool single = in.flags().row_contiguous && B == 1 && T == 1; | ||||
|   bool with_freqs = inputs.size() == 3; | ||||
|  | ||||
|   auto& encoder = cu::get_command_encoder(s); | ||||
| @@ -319,7 +333,7 @@ void RoPE::eval_gpu( | ||||
|         if (single && !with_freqs) { | ||||
|           auto kernel = | ||||
|               cu::rope_single<DataType, traditional.value, forward.value>; | ||||
|           uint2 dims = make_uint2(dims_ / 2, in.size() / mat_size); | ||||
|           uint2 dims = make_uint2(dims_ / 2, N); | ||||
|           auto [grid, block] = get_grid_and_block(dims.x, dims.y, 1); | ||||
|           encoder.add_kernel_node( | ||||
|               kernel, | ||||
| @@ -336,7 +350,7 @@ void RoPE::eval_gpu( | ||||
|         } else if (single) { | ||||
|           auto kernel = | ||||
|               cu::rope_single_freqs<DataType, traditional.value, forward.value>; | ||||
|           uint2 dims = make_uint2(dims_ / 2, in.size() / mat_size); | ||||
|           uint2 dims = make_uint2(dims_ / 2, N); | ||||
|           auto [grid, block] = get_grid_and_block(dims.x, dims.y, 1); | ||||
|           encoder.add_kernel_node( | ||||
|               kernel, | ||||
| @@ -354,10 +368,14 @@ void RoPE::eval_gpu( | ||||
|         } else if (with_freqs) { | ||||
|           auto kernel = | ||||
|               cu::rope_freqs<DataType, traditional.value, forward.value>; | ||||
|           uint3 dims = | ||||
|               make_uint3(dims_ / 2, in.shape(-2), in.size() / mat_size); | ||||
|           dims.z = (dims.z + 3) / 4; | ||||
|           int n_per_thread = 4; | ||||
|           uint32_t dimz = B * ((N + n_per_thread - 1) / n_per_thread); | ||||
|           uint3 dims = make_uint3(dims_ / 2, T, dimz); | ||||
|           auto [grid, block] = get_grid_and_block(dims.x, dims.y, dims.z); | ||||
|           int64_t offset_stride = 0; | ||||
|           if (inputs[1].ndim() > 0) { | ||||
|             offset_stride = inputs[1].strides()[0]; | ||||
|           } | ||||
|           encoder.add_kernel_node( | ||||
|               kernel, | ||||
|               grid, | ||||
| @@ -371,15 +389,20 @@ void RoPE::eval_gpu( | ||||
|               std::log2(base_), | ||||
|               strides, | ||||
|               out_strides, | ||||
|               in.size() / mat_size, | ||||
|               offset_stride, | ||||
|               N, | ||||
|               dims, | ||||
|               inputs[2].strides(0)); | ||||
|         } else { | ||||
|           auto kernel = cu::rope<DataType, traditional.value, forward.value>; | ||||
|           uint3 dims = | ||||
|               make_uint3(dims_ / 2, in.shape(-2), in.size() / mat_size); | ||||
|           dims.z = (dims.z + 3) / 4; | ||||
|           int n_per_thread = 4; | ||||
|           uint32_t dimz = B * ((N + n_per_thread - 1) / n_per_thread); | ||||
|           uint3 dims = make_uint3(dims_ / 2, T, dimz); | ||||
|           auto [grid, block] = get_grid_and_block(dims.x, dims.y, dims.z); | ||||
|           int64_t offset_stride = 0; | ||||
|           if (inputs[1].ndim() > 0) { | ||||
|             offset_stride = inputs[1].strides()[0]; | ||||
|           } | ||||
|           encoder.add_kernel_node( | ||||
|               kernel, | ||||
|               grid, | ||||
| @@ -392,7 +415,8 @@ void RoPE::eval_gpu( | ||||
|               std::log2(base_), | ||||
|               strides, | ||||
|               out_strides, | ||||
|               in.size() / mat_size, | ||||
|               offset_stride, | ||||
|               N, | ||||
|               dims); | ||||
|         } | ||||
|       }); | ||||
|   | ||||
| @@ -10,7 +10,7 @@ void rope_single_impl( | ||||
|     constant const int& offset, | ||||
|     const float inv_freq, | ||||
|     constant const float& scale, | ||||
|     constant const size_t& stride, | ||||
|     constant const int64_t& stride, | ||||
|     uint2 pos, | ||||
|     uint2 grid) { | ||||
|   float L = scale * static_cast<float>(offset); | ||||
| @@ -52,7 +52,7 @@ template <typename T, bool traditional, bool forward> | ||||
|     device T* out [[buffer(1)]], | ||||
|     constant const int& offset, | ||||
|     constant const float& scale, | ||||
|     constant const size_t& stride, | ||||
|     constant const int64_t& stride, | ||||
|     constant const float& base [[buffer(10)]], | ||||
|     uint2 pos [[thread_position_in_grid]], | ||||
|     uint2 grid [[threads_per_grid]]) { | ||||
| @@ -68,9 +68,9 @@ template <typename T, bool traditional, bool forward> | ||||
|     device T* out [[buffer(1)]], | ||||
|     constant const int& offset, | ||||
|     constant const float& scale, | ||||
|     constant const size_t& stride, | ||||
|     constant const int64_t& stride, | ||||
|     const device float* freqs [[buffer(10)]], | ||||
|     constant const size_t& freq_stride [[buffer(11)]], | ||||
|     constant const int64_t& freq_stride [[buffer(11)]], | ||||
|     uint2 pos [[thread_position_in_grid]], | ||||
|     uint2 grid [[threads_per_grid]]) { | ||||
|   float inv_freq = 1.0 / (freqs[freq_stride * pos.x]); | ||||
| @@ -82,15 +82,21 @@ template <typename T, bool traditional, bool forward, int N = 4> | ||||
| void rope_impl( | ||||
|     const device T* in, | ||||
|     device T* out, | ||||
|     constant const int& offset, | ||||
|     const device int* offset, | ||||
|     const float inv_freq, | ||||
|     constant const float& scale, | ||||
|     constant const size_t strides[3], | ||||
|     constant const size_t out_strides[3], | ||||
|     constant const size_t& n_batch, | ||||
|     constant const int64_t strides[3], | ||||
|     constant const int64_t out_strides[3], | ||||
|     constant const int64_t& offset_stride, | ||||
|     constant const int& n_head, | ||||
|     uint3 pos, | ||||
|     uint3 grid) { | ||||
|   float L = scale * static_cast<float>(pos.y + offset); | ||||
|   auto n_head_up = N * ((n_head + N - 1) / N); | ||||
|   auto head_idx = static_cast<int>((pos.z * N) % n_head_up); | ||||
|   auto batch_idx = (pos.z * N) / n_head_up; | ||||
|   auto batch_offset = offset[batch_idx * offset_stride]; | ||||
|   float L = scale * static_cast<float>(pos.y + batch_offset); | ||||
|   auto mat_idx = batch_idx * n_head + head_idx; | ||||
|  | ||||
|   // Compute costheta, sintheta | ||||
|   float theta = L * inv_freq; | ||||
| @@ -102,20 +108,19 @@ void rope_impl( | ||||
|   size_t out_index_1, out_index_2; | ||||
|   if (traditional) { | ||||
|     out_index_1 = 2 * pos.x * out_strides[2] + pos.y * out_strides[1] + | ||||
|         N * pos.z * out_strides[0]; | ||||
|         mat_idx * out_strides[0]; | ||||
|     out_index_2 = out_index_1 + 1; | ||||
|     in_index_1 = | ||||
|         2 * pos.x * strides[2] + pos.y * strides[1] + N * pos.z * strides[0]; | ||||
|         2 * pos.x * strides[2] + pos.y * strides[1] + mat_idx * strides[0]; | ||||
|     in_index_2 = in_index_1 + strides[2]; | ||||
|   } else { | ||||
|     out_index_1 = pos.x * out_strides[2] + pos.y * out_strides[1] + | ||||
|         N * pos.z * out_strides[0]; | ||||
|         mat_idx * out_strides[0]; | ||||
|     out_index_2 = out_index_1 + grid.x * out_strides[2]; | ||||
|     in_index_1 = | ||||
|         pos.x * strides[2] + pos.y * strides[1] + N * pos.z * strides[0]; | ||||
|     in_index_1 = pos.x * strides[2] + pos.y * strides[1] + mat_idx * strides[0]; | ||||
|     in_index_2 = in_index_1 + grid.x * strides[2]; | ||||
|   } | ||||
|   for (int i = 0; i < N && pos.z * N + i < n_batch; ++i) { | ||||
|   for (int i = 0; i < N && head_idx + i < n_head; ++i) { | ||||
|     // Read and write the output | ||||
|     float x1 = static_cast<float>(in[in_index_1]); | ||||
|     float x2 = static_cast<float>(in[in_index_2]); | ||||
| @@ -141,11 +146,12 @@ template <typename T, bool traditional, bool forward, int N = 4> | ||||
| [[kernel]] void rope( | ||||
|     const device T* in [[buffer(0)]], | ||||
|     device T* out [[buffer(1)]], | ||||
|     constant const int& offset, | ||||
|     const device int* offset, | ||||
|     constant const float& scale, | ||||
|     constant const size_t strides[3], | ||||
|     constant const size_t out_strides[3], | ||||
|     constant const size_t& n_batch, | ||||
|     constant const int64_t strides[3], | ||||
|     constant const int64_t out_strides[3], | ||||
|     constant const int64_t& offset_stride, | ||||
|     constant const int& n_head, | ||||
|     constant const float& base [[buffer(10)]], | ||||
|     uint3 pos [[thread_position_in_grid]], | ||||
|     uint3 grid [[threads_per_grid]]) { | ||||
| @@ -159,7 +165,8 @@ template <typename T, bool traditional, bool forward, int N = 4> | ||||
|       scale, | ||||
|       strides, | ||||
|       out_strides, | ||||
|       n_batch, | ||||
|       offset_stride, | ||||
|       n_head, | ||||
|       pos, | ||||
|       grid); | ||||
| } | ||||
| @@ -168,13 +175,14 @@ template <typename T, bool traditional, bool forward, int N = 4> | ||||
| [[kernel]] void rope_freqs( | ||||
|     const device T* in [[buffer(0)]], | ||||
|     device T* out [[buffer(1)]], | ||||
|     constant const int& offset, | ||||
|     const device int* offset, | ||||
|     constant const float& scale, | ||||
|     constant const size_t strides[3], | ||||
|     constant const size_t out_strides[3], | ||||
|     constant const size_t& n_batch, | ||||
|     constant const int64_t strides[3], | ||||
|     constant const int64_t out_strides[3], | ||||
|     constant const int64_t& offset_stride, | ||||
|     constant const int& n_head, | ||||
|     const device float* freqs [[buffer(10)]], | ||||
|     constant const size_t& freq_stride [[buffer(11)]], | ||||
|     constant const int64_t& freq_stride [[buffer(11)]], | ||||
|     uint3 pos [[thread_position_in_grid]], | ||||
|     uint3 grid [[threads_per_grid]]) { | ||||
|   float inv_freq = 1.0 / (freqs[freq_stride * pos.x]); | ||||
| @@ -186,61 +194,20 @@ template <typename T, bool traditional, bool forward, int N = 4> | ||||
|       scale, | ||||
|       strides, | ||||
|       out_strides, | ||||
|       n_batch, | ||||
|       offset_stride, | ||||
|       n_head, | ||||
|       pos, | ||||
|       grid); | ||||
| } | ||||
|  | ||||
| // clang-format off | ||||
| #define instantiate_rope_g(name, type, traditional, forward) \ | ||||
|   template [[host_name("rope_" #name)]] [[kernel]] void      \ | ||||
|   rope<type, traditional, forward>(                          \ | ||||
|       const device type* in [[buffer(0)]],                   \ | ||||
|       device type* out [[buffer(1)]],                        \ | ||||
|       constant const int& offset,                            \ | ||||
|       constant const float& scale,                           \ | ||||
|       constant const size_t strides[3],                      \ | ||||
|       constant const size_t out_strides[3],                  \ | ||||
|       constant const size_t& n_batch,                        \ | ||||
|       constant const float& base [[buffer(10)]],             \ | ||||
|       uint3 pos [[thread_position_in_grid]],                 \ | ||||
|       uint3 grid [[threads_per_grid]]);                      \ | ||||
|   template [[host_name("rope_freqs_" #name)]]                \ | ||||
|   [[kernel]] void rope_freqs<type, traditional, forward>(    \ | ||||
|       const device type* in [[buffer(0)]],                   \ | ||||
|       device type* out [[buffer(1)]],                        \ | ||||
|       constant const int& offset,                            \ | ||||
|       constant const float& scale,                           \ | ||||
|       constant const size_t strides[3],                      \ | ||||
|       constant const size_t out_strides[3],                  \ | ||||
|       constant const size_t& n_batch,                        \ | ||||
|       const device float* freqs [[buffer(10)]],              \ | ||||
|       constant const size_t& freq_stride [[buffer(11)]],     \ | ||||
|       uint3 pos [[thread_position_in_grid]],                 \ | ||||
|       uint3 grid [[threads_per_grid]]); | ||||
|   instantiate_kernel("rope_" #name, rope, type, traditional, forward) \ | ||||
|   instantiate_kernel("rope_freqs_" #name, rope_freqs, type, traditional, forward) | ||||
|  | ||||
| #define instantiate_rope_s(name, type, traditional, forward)     \ | ||||
|   template [[host_name("rope_single_" #name)]] [[kernel]] void   \ | ||||
|   rope_single<type, traditional, forward>(                       \ | ||||
|       const device type* in [[buffer(0)]],                       \ | ||||
|       device type* out [[buffer(1)]],                            \ | ||||
|       constant const int& offset,                                \ | ||||
|       constant const float& scale,                               \ | ||||
|       constant const size_t& stride,                             \ | ||||
|       constant const float& base [[buffer(10)]],                 \ | ||||
|       uint2 pos [[thread_position_in_grid]],                     \ | ||||
|       uint2 grid [[threads_per_grid]]);                          \ | ||||
|   template [[host_name("rope_single_freqs_" #name)]]             \ | ||||
|   [[kernel]] void rope_single_freqs<type, traditional, forward>( \ | ||||
|       const device type* in [[buffer(0)]],                       \ | ||||
|       device type* out [[buffer(1)]],                            \ | ||||
|       constant const int& offset,                                \ | ||||
|       constant const float& scale,                               \ | ||||
|       constant const size_t& stride,                             \ | ||||
|       const device float* freqs [[buffer(10)]],                  \ | ||||
|       constant const size_t& freq_stride [[buffer(11)]],         \ | ||||
|       uint2 pos [[thread_position_in_grid]],                     \ | ||||
|       uint2 grid [[threads_per_grid]]); | ||||
| #define instantiate_rope_s(name, type, traditional, forward) \ | ||||
|   instantiate_kernel("rope_single_" #name, rope_single, type, traditional, forward) \ | ||||
|   instantiate_kernel("rope_single_freqs_" #name, rope_single_freqs, type, traditional, forward) | ||||
|  | ||||
| #define instantiate_rope(name, type, traditional, forward) \ | ||||
|   instantiate_rope_s(name, type, traditional, forward)     \ | ||||
|   | ||||
| @@ -18,23 +18,29 @@ void RoPE::eval_gpu( | ||||
|   auto& in = inputs[0]; | ||||
|   auto& out = outputs[0]; | ||||
|  | ||||
|   if (in.ndim() < 3) { | ||||
|     throw std::runtime_error("[RoPE] Input must have at least 3 dimensions"); | ||||
|   } | ||||
|  | ||||
|   auto& s = out.primitive().stream(); | ||||
|   auto& d = metal::device(s.device); | ||||
|  | ||||
|   size_t strides[3]; | ||||
|   size_t out_strides[3]; | ||||
|   int64_t strides[3]; | ||||
|   int64_t out_strides[3]; | ||||
|   bool donated = false; | ||||
|   int ndim = in.ndim(); | ||||
|   int dispatch_ndim = in.ndim(); | ||||
|   int B = in.shape(0); | ||||
|   int T = in.shape(-2); | ||||
|   int D = in.shape(-1); | ||||
|   size_t mat_size = T * D; | ||||
|  | ||||
|   int dispatch_ndim = ndim; | ||||
|   while (in.shape(-dispatch_ndim) == 1 && dispatch_ndim > 3) { | ||||
|     dispatch_ndim--; | ||||
|   } | ||||
|   size_t mat_size = in.shape(-2) * in.shape(-1); | ||||
|   if (dims_ < in.shape(-1)) { | ||||
|  | ||||
|   int N = 1; | ||||
|   for (int i = 1; i < (ndim - 2); ++i) { | ||||
|     N *= in.shape(i); | ||||
|   } | ||||
|  | ||||
|   if (dims_ < D) { | ||||
|     donated = true; | ||||
|     auto ctype = | ||||
|         (in.flags().row_contiguous) ? CopyType::Vector : CopyType::General; | ||||
| @@ -71,8 +77,8 @@ void RoPE::eval_gpu( | ||||
|   out_strides[1] = out.strides()[ndim - 2]; | ||||
|   out_strides[2] = out.strides()[ndim - 1]; | ||||
|  | ||||
|   // Special case for inference (single time step and contiguous) | ||||
|   bool single = in.flags().row_contiguous && (mat_size == in.shape(-1)); | ||||
|   // Special case for inference (single batch, single time step, and contiguous) | ||||
|   bool single = in.flags().row_contiguous && B == 1 && T == 1; | ||||
|  | ||||
|   bool with_freqs = inputs.size() == 3; | ||||
|   std::ostringstream kname; | ||||
| @@ -86,24 +92,29 @@ void RoPE::eval_gpu( | ||||
|   compute_encoder.set_compute_pipeline_state(kernel); | ||||
|   compute_encoder.set_input_array(donated ? out : in, 0); | ||||
|   compute_encoder.set_output_array(out, 1); | ||||
|  | ||||
|   compute_encoder.set_input_array(inputs[1], 2); | ||||
|   compute_encoder.set_bytes(scale_, 3); | ||||
|  | ||||
|   size_t n_batch = in.size() / mat_size; | ||||
|   MTL::Size group_dims; | ||||
|   MTL::Size grid_dims; | ||||
|   if (single) { | ||||
|     compute_encoder.set_bytes(out_strides, 1, 4); | ||||
|     uint32_t dim0 = dims_ / 2; | ||||
|     group_dims = get_block_dims(dim0, n_batch, 1); | ||||
|     grid_dims = MTL::Size(dim0, n_batch, 1); | ||||
|     group_dims = get_block_dims(dim0, N, 1); | ||||
|     grid_dims = MTL::Size(dim0, N, 1); | ||||
|   } else { | ||||
|     compute_encoder.set_bytes(strides, 3, 4); | ||||
|     compute_encoder.set_bytes(out_strides, 3, 5); | ||||
|     compute_encoder.set_bytes(n_batch, 6); | ||||
|     int64_t offset_stride = 0; | ||||
|     if (inputs[1].ndim() > 0) { | ||||
|       offset_stride = inputs[1].strides()[0]; | ||||
|     } | ||||
|     compute_encoder.set_bytes(offset_stride, 6); | ||||
|     compute_encoder.set_bytes(N, 7); | ||||
|     uint32_t dim0 = dims_ / 2; | ||||
|     uint32_t dim1 = in.shape(-2); | ||||
|     uint32_t dim2 = (n_batch + n_per_thread - 1) / n_per_thread; | ||||
|     uint32_t dim1 = T; | ||||
|     uint32_t dim2 = B * ((N + n_per_thread - 1) / n_per_thread); | ||||
|     group_dims = get_block_dims(dim0, dim1, dim2); | ||||
|     grid_dims = MTL::Size(dim0, dim1, dim2); | ||||
|   } | ||||
|   | ||||
							
								
								
									
										60
									
								
								mlx/fast.cpp
									
									
									
									
									
								
							
							
						
						
									
										60
									
								
								mlx/fast.cpp
									
									
									
									
									
								
							| @@ -366,10 +366,16 @@ array rope( | ||||
|     msg << "[rope] Input must be a floating type but got " << x.dtype() << "."; | ||||
|     throw std::invalid_argument(msg.str()); | ||||
|   } | ||||
|   if (offset.size() != 1) { | ||||
|   if (offset.ndim() > 1) { | ||||
|     std::ostringstream msg; | ||||
|     msg << "[rope] offset must be a scalar but has shape " << offset.shape() | ||||
|         << "."; | ||||
|     msg << "[rope] offset must have at most one dimension but has shape " | ||||
|         << offset.shape() << "."; | ||||
|     throw std::invalid_argument(msg.str()); | ||||
|   } | ||||
|   if (offset.size() != 1 && offset.size() != x.shape(0)) { | ||||
|     std::ostringstream msg; | ||||
|     msg << "[rope] offset must be a scalar or vector with " << x.shape(0) | ||||
|         << " elements but has shape " << offset.shape() << "."; | ||||
|     throw std::invalid_argument(msg.str()); | ||||
|   } | ||||
|   if (!issubdtype(offset.dtype(), integer)) { | ||||
| @@ -379,7 +385,7 @@ array rope( | ||||
|     throw std::invalid_argument(msg.str()); | ||||
|   } | ||||
|   if (offset.dtype().size() != 4) { | ||||
|     inputs[1] = astype(offset, uint32, s); | ||||
|     inputs[1] = astype(offset, int32, s); | ||||
|   } | ||||
|   if (inputs.size() == 3 && | ||||
|       (inputs[2].ndim() != 1 || inputs[2].shape(0) != dims / 2)) { | ||||
| @@ -391,15 +397,26 @@ array rope( | ||||
|  | ||||
|   auto fallback = [dims, traditional, base, scale, forward, s]( | ||||
|                       std::vector<array> inputs) { | ||||
|     auto& shape = inputs[0].shape(); | ||||
|     int ndim = shape.size(); | ||||
|     auto x = flatten(inputs[0], 0, ndim - 3, s); | ||||
|     auto x = inputs[0]; | ||||
|     auto shape = x.shape(); | ||||
|     if (x.ndim() == 3) { | ||||
|       x = expand_dims(x, 1, s); | ||||
|     } else if (x.ndim() > 4) { | ||||
|       x = flatten(x, 1, 1 + (x.ndim() - 4), s); | ||||
|     } | ||||
|  | ||||
|     auto B = x.shape(0); | ||||
|     auto N = x.shape(1); | ||||
|     auto T = x.shape(2); | ||||
|     auto t = x.dtype(); | ||||
|     // Compute sines and cosines | ||||
|     auto half_dims = dims / 2; | ||||
|     auto& offset = inputs[1]; | ||||
|     auto offset = inputs[1]; | ||||
|     if (offset.size() > 1) { | ||||
|       offset = expand_dims(offset, {-1, -2}, s); | ||||
|     } | ||||
|     auto positions = | ||||
|         multiply(add(arange(x.shape(1), t, s), offset, s), array(scale, t), s); | ||||
|         multiply(add(arange(x.shape(2), t, s), offset, s), array(scale, t), s); | ||||
|  | ||||
|     auto default_inv_freqs = [&inputs, &s, &t, base, half_dims]() { | ||||
|       return exp( | ||||
| @@ -412,8 +429,7 @@ array rope( | ||||
|  | ||||
|     auto inv_freqs = inputs.size() == 3 ? astype(reciprocal(inputs[2], s), t, s) | ||||
|                                         : default_inv_freqs(); | ||||
|     auto theta = | ||||
|         multiply(expand_dims(positions, 1, s), expand_dims(inv_freqs, 0, s), s); | ||||
|     auto theta = multiply(expand_dims(positions, -1, s), inv_freqs, s); | ||||
|     auto coss = cos(theta, s); | ||||
|     auto sins = sin(theta, s); | ||||
|  | ||||
| @@ -436,32 +452,30 @@ array rope( | ||||
|     }; | ||||
|  | ||||
|     if (traditional) { | ||||
|       auto x1 = | ||||
|           slice(x, {0, 0, 0}, {x.shape(0), x.shape(1), dims}, {1, 1, 2}, s); | ||||
|       auto x2 = | ||||
|           slice(x, {0, 0, 1}, {x.shape(0), x.shape(1), dims}, {1, 1, 2}, s); | ||||
|       auto x1 = slice(x, {0, 0, 0, 0}, {B, N, T, dims}, {1, 1, 1, 2}, s); | ||||
|       auto x2 = slice(x, {0, 0, 0, 1}, {B, N, T, dims}, {1, 1, 1, 2}, s); | ||||
|       auto outs = apply_rope(x1, x2, coss, sins); | ||||
|       for (auto& o : outs) { | ||||
|         o = expand_dims(o, 3, s); | ||||
|         o = expand_dims(o, -1, s); | ||||
|       } | ||||
|       auto out = concatenate(outs, 3, s); | ||||
|       auto out = reshape(concatenate(outs, -1, s), {B, N, T, dims}, s); | ||||
|       if (dims < x.shape(-1)) { | ||||
|         out = reshape(out, {x.shape(0), x.shape(1), dims}); | ||||
|         out = concatenate({out, slice(x, {0, 0, dims}, x.shape(), s)}, 2, s); | ||||
|         out = | ||||
|             concatenate({out, slice(x, {0, 0, 0, dims}, x.shape(), s)}, -1, s); | ||||
|       } | ||||
|       return std::vector<array>{reshape(out, shape, s)}; | ||||
|     } else { | ||||
|       auto out_s = x.shape(); | ||||
|       out_s.back() = half_dims; | ||||
|       auto x1 = slice(x, {0, 0, 0}, out_s, s); | ||||
|       auto x1 = slice(x, {0, 0, 0, 0}, out_s, s); | ||||
|       out_s.back() = dims; | ||||
|       auto x2 = slice(x, {0, 0, half_dims}, out_s, s); | ||||
|       auto x2 = slice(x, {0, 0, 0, half_dims}, out_s, s); | ||||
|  | ||||
|       auto outs = apply_rope(x1, x2, coss, sins); | ||||
|       if (dims < x.shape(-1)) { | ||||
|         outs.push_back(slice(x, {0, 0, dims}, x.shape(), s)); | ||||
|         outs.push_back(slice(x, {0, 0, 0, dims}, x.shape(), s)); | ||||
|       } | ||||
|       return std::vector<array>{reshape(concatenate(outs, 2, s), shape, s)}; | ||||
|       return std::vector<array>{reshape(concatenate(outs, -1, s), shape, s)}; | ||||
|     } | ||||
|   }; | ||||
|   auto stream = to_stream(s); | ||||
|   | ||||
| @@ -164,8 +164,13 @@ void init_fast(nb::module_& parent_module) { | ||||
|       R"pbdoc( | ||||
|         Apply rotary positional encoding to the input. | ||||
|  | ||||
|         The input is expected to be at least 3D with shape ``(B, *, T, D)`` where: | ||||
|           * ``B`` is the batch size. | ||||
|           * ``T`` is the sequence length. | ||||
|           * ``D`` is the feature dimension. | ||||
|  | ||||
|         Args: | ||||
|             a (array): Input array. | ||||
|             a (array): The input array. | ||||
|             dims (int): The feature dimensions to be rotated. If the input feature | ||||
|               is larger than dims then the rest is left unchanged. | ||||
|             traditional (bool): If set to ``True`` choose the traditional | ||||
| @@ -174,7 +179,9 @@ void init_fast(nb::module_& parent_module) { | ||||
|               each dimension in the positional encodings. Exactly one of ``base`` and | ||||
|               ``freqs`` must be ``None``. | ||||
|             scale (float): The scale used to scale the positions. | ||||
|             offset (int or array): The position offset to start at. | ||||
|             offset (int or array): The position offset to start at. If an | ||||
|               :obj:`array` is given it can be a scalar or vector of ``B`` | ||||
|               offsets for each example in the batch. | ||||
|             freqs (array, optional): Optional frequencies to use with RoPE. | ||||
|               If set, the ``base`` parameter must be ``None``. Default: ``None``. | ||||
|  | ||||
|   | ||||
| @@ -91,7 +91,7 @@ mx::array to_array_with_accessor(nb::object obj) { | ||||
|     return nb::cast<mx::array>(obj.attr("__mlx_array__")()); | ||||
|   } else { | ||||
|     std::ostringstream msg; | ||||
|     msg << "Invalid type  " << nb::type_name(obj.type()).c_str() | ||||
|     msg << "Invalid type " << nb::type_name(obj.type()).c_str() | ||||
|         << " received in array initialization."; | ||||
|     throw std::invalid_argument(msg.str()); | ||||
|   } | ||||
|   | ||||
| @@ -8,18 +8,23 @@ import mlx_tests | ||||
|  | ||||
|  | ||||
| def rope_orig(x, dims, traditional, base, scale, offset, freqs=None): | ||||
|     offset = offset.item() if isinstance(offset, mx.array) else offset | ||||
|     N = x.shape[-2] + offset | ||||
|     N = x.shape[-2] | ||||
|     dtype = x.dtype | ||||
|     half_D = dims // 2 | ||||
|     positions = mx.arange(offset, N, dtype=dtype) * scale | ||||
|     positions = mx.arange(N, dtype=dtype) | ||||
|     if isinstance(offset, mx.array) and offset.size > 1: | ||||
|         expand = tuple(range(1, x.ndim - 1)) | ||||
|         positions = mx.expand_dims(offset, expand) + positions | ||||
|     else: | ||||
|         positions = offset + positions | ||||
|     positions = positions * scale | ||||
|     if freqs is None: | ||||
|         inv_freqs = mx.exp( | ||||
|             -mx.arange(0.0, half_D, dtype=dtype) * (math.log(base) / half_D) | ||||
|         ) | ||||
|     else: | ||||
|         inv_freqs = (1 / freqs).astype(x.dtype) | ||||
|     theta = mx.reshape(positions, (-1, 1)) * mx.reshape(inv_freqs, (1, -1)) | ||||
|     theta = mx.expand_dims(positions, -1) * inv_freqs | ||||
|     costheta, sintheta = mx.cos(theta), mx.sin(theta) | ||||
|     if traditional: | ||||
|         x1 = x[..., :dims:2] | ||||
| @@ -214,6 +219,7 @@ class TestFast(mlx_tests.MLXTestCase): | ||||
|             ) | ||||
|             self.assertEqual(dtype, rx.dtype) | ||||
|             self.assertLess(mx.abs(rx - rx_fast).max(), tolerances[dtype]) | ||||
|             return | ||||
|  | ||||
|         # Test single vector | ||||
|         x = mx.random.uniform(shape=(1, 1, dims)) | ||||
| @@ -277,6 +283,55 @@ class TestFast(mlx_tests.MLXTestCase): | ||||
|                 g2 = mx.grad(f2)(x, y) | ||||
|                 self.assertLess(mx.abs(g1 - g2).max(), 1e-5) | ||||
|  | ||||
|     def test_rope_batch(self): | ||||
|         T = 4 | ||||
|         base = 10000.0 | ||||
|         scale = 1.0 | ||||
|         traditional = True | ||||
|         batch_sizes = [3, 8, 11] | ||||
|         num_heads = [1, 3, 5] | ||||
|         dims = 32 | ||||
|  | ||||
|         x = mx.random.uniform(shape=(8, 4, T, dims)) | ||||
|  | ||||
|         offset = mx.array([1, 2, 3]) | ||||
|         with self.assertRaises(ValueError): | ||||
|             mx.fast.rope( | ||||
|                 x, | ||||
|                 dims, | ||||
|                 traditional=traditional, | ||||
|                 base=base, | ||||
|                 scale=scale, | ||||
|                 offset=offset, | ||||
|             ) | ||||
|  | ||||
|         for batch_size in batch_sizes: | ||||
|             for n_head in num_heads: | ||||
|                 x = mx.random.uniform(shape=(batch_size, n_head, T, dims)) | ||||
|                 offset = mx.arange(batch_size) | ||||
|                 rx = rope_orig(x, dims, traditional, base, scale, offset) | ||||
|                 rx_fast = mx.fast.rope( | ||||
|                     x, | ||||
|                     dims, | ||||
|                     traditional=traditional, | ||||
|                     base=base, | ||||
|                     scale=scale, | ||||
|                     offset=offset, | ||||
|                 ) | ||||
|                 self.assertLess(mx.abs(rx - rx_fast).max(), 1e-5) | ||||
|         x = mx.random.normal(shape=(2, 6, 8, 64)).transpose(0, 2, 1, 3) | ||||
|         dims = 64 | ||||
|         offset = 0 | ||||
|         rx_fast = mx.fast.rope( | ||||
|             x, dims, traditional=traditional, scale=scale, base=base, offset=offset | ||||
|         ) | ||||
|         rx_fast_single = mx.fast.rope( | ||||
|             x[0:1], dims, traditional=traditional, scale=scale, base=base, offset=offset | ||||
|         ) | ||||
|  | ||||
|         rx = rope_orig(x, dims, traditional, base, scale, offset) | ||||
|         self.assertLess(mx.abs(rx - rx_fast).max(), 1e-5) | ||||
|  | ||||
|     def test_rms_norm(self): | ||||
|         # Per dtype absolute tolerance | ||||
|         tolerances = {mx.float32: 1e-6, mx.float16: 1e-3, mx.bfloat16: 1e-2} | ||||
|   | ||||
		Reference in New Issue
	
	Block a user
	 Awni Hannun
					Awni Hannun