mirror of
				https://github.com/ml-explore/mlx.git
				synced 2025-11-01 00:28:11 +08:00 
			
		
		
		
	Add contiguous_copy_gpu util for copying array (#2379)
This commit is contained in:
		| @@ -237,8 +237,7 @@ void LayerNorm::eval_gpu( | ||||
|       } | ||||
|       return x; | ||||
|     } else { | ||||
|       auto x_copy = array(x.shape(), x.dtype(), nullptr, {}); | ||||
|       copy_gpu(x, x_copy, CopyType::General, s); | ||||
|       array x_copy = contiguous_copy_gpu(x, s); | ||||
|       out.copy_shared_buffer(x_copy); | ||||
|       return x_copy; | ||||
|     } | ||||
| @@ -295,9 +294,7 @@ void LayerNormVJP::eval_gpu( | ||||
|       return x; | ||||
|     } | ||||
|     copied = true; | ||||
|     array x_copy(x.shape(), x.dtype(), nullptr, {}); | ||||
|     copy_gpu(x, x_copy, CopyType::General, s); | ||||
|     return x_copy; | ||||
|     return contiguous_copy_gpu(x, s); | ||||
|   }; | ||||
|   bool donate_x = inputs[0].is_donatable(); | ||||
|   bool donate_g = inputs[3].is_donatable(); | ||||
|   | ||||
| @@ -108,8 +108,7 @@ void LogSumExp::eval_gpu(const std::vector<array>& inputs, array& out) { | ||||
|     if (x.flags().contiguous && x.strides()[x.ndim() - 1] == 1) { | ||||
|       return x; | ||||
|     } else { | ||||
|       auto x_copy = array(x.shape(), x.dtype(), nullptr, {}); | ||||
|       copy_gpu(x, x_copy, CopyType::General, s); | ||||
|       array x_copy = contiguous_copy_gpu(x, s); | ||||
|       encoder.add_temporary(x_copy); | ||||
|       return x_copy; | ||||
|     } | ||||
|   | ||||
| @@ -297,8 +297,7 @@ check_transpose(cu::CommandEncoder& enc, const Stream& s, const array& arr) { | ||||
|   } else if (stx == 1 && sty == arr.shape(-2)) { | ||||
|     return std::make_tuple(true, sty, arr); | ||||
|   } else { | ||||
|     array arr_copy(arr.shape(), arr.dtype(), nullptr, {}); | ||||
|     copy_gpu(arr, arr_copy, CopyType::General, s); | ||||
|     array arr_copy = contiguous_copy_gpu(arr, s); | ||||
|     enc.add_temporary(arr_copy); | ||||
|     return std::make_tuple(false, arr.shape(-1), arr_copy); | ||||
|   } | ||||
|   | ||||
| @@ -247,8 +247,7 @@ inline array ensure_row_contiguous( | ||||
|     cu::CommandEncoder& enc, | ||||
|     const Stream& s) { | ||||
|   if (!x.flags().row_contiguous) { | ||||
|     array x_copy(x.shape(), x.dtype(), nullptr, {}); | ||||
|     copy_gpu(x, x_copy, CopyType::General, s); | ||||
|     array x_copy = contiguous_copy_gpu(x, s); | ||||
|     enc.add_temporary(x_copy); | ||||
|     return x_copy; | ||||
|   } else { | ||||
|   | ||||
| @@ -47,8 +47,7 @@ void Reduce::eval_gpu(const std::vector<array>& inputs, array& out) { | ||||
|     } | ||||
|   } | ||||
|   if (plan.type == GeneralReduce || broadcasted || !in.flags().contiguous) { | ||||
|     array in_copy(in.shape(), in.dtype(), nullptr, {}); | ||||
|     copy_gpu(in, in_copy, CopyType::General, s); | ||||
|     array in_copy = contiguous_copy_gpu(in, s); | ||||
|     encoder.add_temporary(in_copy); | ||||
|     in = in_copy; | ||||
|     plan = get_reduction_plan(in, axes_); | ||||
|   | ||||
| @@ -206,8 +206,7 @@ void RMSNorm::eval_gpu( | ||||
|       } | ||||
|       return x; | ||||
|     } else { | ||||
|       auto x_copy = array(x.shape(), x.dtype(), nullptr, {}); | ||||
|       copy_gpu(x, x_copy, CopyType::General, s); | ||||
|       array x_copy = contiguous_copy_gpu(x, s); | ||||
|       out.copy_shared_buffer(x_copy); | ||||
|       return x_copy; | ||||
|     } | ||||
| @@ -259,9 +258,7 @@ void RMSNormVJP::eval_gpu( | ||||
|       return x; | ||||
|     } | ||||
|     copied = true; | ||||
|     array x_copy(x.shape(), x.dtype(), nullptr, {}); | ||||
|     copy_gpu(x, x_copy, CopyType::General, s); | ||||
|     return x_copy; | ||||
|     return contiguous_copy_gpu(x, s); | ||||
|   }; | ||||
|   bool donate_x = inputs[0].is_donatable(); | ||||
|   bool donate_g = inputs[2].is_donatable(); | ||||
|   | ||||
| @@ -379,9 +379,7 @@ void Scan::eval_gpu(const std::vector<array>& inputs, array& out) { | ||||
|           in.flags()); | ||||
|     } | ||||
|   } else { | ||||
|     array arr_copy(in.shape(), in.dtype(), nullptr, {}); | ||||
|     copy_gpu(in, arr_copy, CopyType::General, s); | ||||
|     in = std::move(arr_copy); | ||||
|     in = contiguous_copy_gpu(in, s); | ||||
|     out.copy_shared_buffer(in); | ||||
|   } | ||||
|  | ||||
|   | ||||
| @@ -125,8 +125,7 @@ void Softmax::eval_gpu(const std::vector<array>& inputs, array& out) { | ||||
|       } | ||||
|       return x; | ||||
|     } else { | ||||
|       auto x_copy = array(x.shape(), x.dtype(), nullptr, {}); | ||||
|       copy_gpu(x, x_copy, CopyType::General, s); | ||||
|       array x_copy = contiguous_copy_gpu(x, s); | ||||
|       out.copy_shared_buffer(x_copy); | ||||
|       return x_copy; | ||||
|     } | ||||
|   | ||||
| @@ -72,8 +72,7 @@ void gpu_sort(const Stream& s, array in, array& out_, int axis, bool argsort) { | ||||
|   bool is_segmented_sort = in.flags().contiguous && in.strides()[axis] == 1; | ||||
|   if (!is_segmented_sort) { | ||||
|     array trans = swapaxes_in_eval(in, axis, last_dim); | ||||
|     in = array(trans.shape(), trans.dtype(), nullptr, {}); | ||||
|     copy_gpu(trans, in, CopyType::General, s); | ||||
|     in = contiguous_copy_gpu(trans, s); | ||||
|     encoder.add_temporary(in); | ||||
|     out = array(allocator::malloc(out.nbytes()), in.shape(), out.dtype()); | ||||
|     encoder.add_temporary(out); | ||||
|   | ||||
| @@ -46,4 +46,10 @@ void copy_gpu_inplace( | ||||
|       in, out, in.shape(), i_strides, out.strides(), i_offset, 0, ctype, s); | ||||
| } | ||||
|  | ||||
// Copies |arr| into a newly constructed array with the same shape and dtype
// and returns that copy. The copy is issued through copy_gpu using
// CopyType::General on stream |s|.
// NOTE(review): arr_copy is constructed with a null third argument and an
// empty fourth argument; presumably copy_gpu allocates its storage - confirm.
// NOTE(review): call sites visible in this change hand the returned array to
// add_temporary(...), push it onto a copies list, or share its buffer with the
// output; this helper does none of that itself.
| array contiguous_copy_gpu(const array& arr, const Stream& s) { | ||||
|   array arr_copy(arr.shape(), arr.dtype(), nullptr, {}); | ||||
|   copy_gpu(arr, arr_copy, CopyType::General, s); | ||||
|   return arr_copy; | ||||
| } | ||||
|  | ||||
| } // namespace mlx::core | ||||
|   | ||||
| @@ -43,4 +43,7 @@ void copy_gpu_inplace( | ||||
| // Fill the output with the scalar val | ||||
| void fill_gpu(const array& val, array& out, const Stream& s); | ||||
|  | ||||
| // Return a contiguous array with the same shape as |arr| that holds a copy of its data. | ||||
| array contiguous_copy_gpu(const array& arr, const Stream& s); | ||||
|  | ||||
| } // namespace mlx::core | ||||
|   | ||||
| @@ -149,8 +149,7 @@ void explicit_gemm_conv_group_ND_gpu( | ||||
|       wt, {wt.strides(0), 1, C_per_group}, wt.flags(), wt.size()); | ||||
|  | ||||
|   // Materialize | ||||
|   auto wt_transpose = array(wt_view.shape(), wt_view.dtype(), nullptr, {}); | ||||
|   copy_gpu(wt_view, wt_transpose, CopyType::General, s); | ||||
|   array wt_transpose = contiguous_copy_gpu(wt_view, s); | ||||
|  | ||||
|   // Perform gemm | ||||
|   std::vector<array> copies = {in_unfolded, wt_transpose}; | ||||
| @@ -961,16 +960,12 @@ void Convolution::eval_gpu(const std::vector<array>& inputs, array& out) { | ||||
|   auto in = inputs[0]; | ||||
|   auto wt = inputs[1]; | ||||
|   if (!in.flags().row_contiguous) { | ||||
|     array arr_copy(in.shape(), in.dtype(), nullptr, {}); | ||||
|     copy_gpu(in, arr_copy, CopyType::General, s); | ||||
|     copies.push_back(arr_copy); | ||||
|     in = arr_copy; | ||||
|     in = contiguous_copy_gpu(in, s); | ||||
|     copies.push_back(in); | ||||
|   } | ||||
|   if (!wt.flags().row_contiguous) { | ||||
|     array arr_copy(wt.shape(), wt.dtype(), nullptr, {}); | ||||
|     copy_gpu(wt, arr_copy, CopyType::General, s); | ||||
|     copies.push_back(arr_copy); | ||||
|     wt = arr_copy; | ||||
|     wt = contiguous_copy_gpu(wt, s); | ||||
|     copies.push_back(wt); | ||||
|   } | ||||
|  | ||||
|   // 3D conv | ||||
|   | ||||
| @@ -25,8 +25,7 @@ void LogSumExp::eval_gpu(const std::vector<array>& inputs, array& out) { | ||||
|     if (x.flags().contiguous && x.strides()[x.ndim() - 1] == 1) { | ||||
|       return x; | ||||
|     } else { | ||||
|       auto x_copy = array(x.shape(), x.dtype(), nullptr, {}); | ||||
|       copy_gpu(x, x_copy, CopyType::General, s); | ||||
|       array x_copy = contiguous_copy_gpu(x, s); | ||||
|       d.add_temporary(x_copy, s.index); | ||||
|       return x_copy; | ||||
|     } | ||||
|   | ||||
| @@ -33,8 +33,7 @@ std::tuple<bool, int64_t, array> check_transpose( | ||||
|   } else if (stx == 1 && (!is_vector || sty == arr.shape(-2))) { | ||||
|     return std::make_tuple(true, sty, arr); | ||||
|   } else { | ||||
|     array arr_copy(arr.shape(), arr.dtype(), nullptr, {}); | ||||
|     copy_gpu(arr, arr_copy, CopyType::General, s); | ||||
|     array arr_copy = contiguous_copy_gpu(arr, s); | ||||
|     copies.push_back(arr_copy); | ||||
|     return std::make_tuple(false, arr.shape(-1), arr_copy); | ||||
|   } | ||||
| @@ -43,8 +42,7 @@ std::tuple<bool, int64_t, array> check_transpose( | ||||
| inline array | ||||
| ensure_row_contiguous(const array& x, metal::Device& d, const Stream& s) { | ||||
|   if (!x.flags().row_contiguous) { | ||||
|     array x_copy(x.shape(), x.dtype(), nullptr, {}); | ||||
|     copy_gpu(x, x_copy, CopyType::General, s); | ||||
|     array x_copy = contiguous_copy_gpu(x, s); | ||||
|     d.add_temporary(x_copy, s.index); | ||||
|     return x_copy; | ||||
|   } else { | ||||
| @@ -75,8 +73,7 @@ ensure_batch_contiguous(const array& x, metal::Device& d, const Stream& s) { | ||||
|     } | ||||
|   } | ||||
|  | ||||
|   array x_copy(x.shape(), x.dtype(), nullptr, {}); | ||||
|   copy_gpu(x, x_copy, CopyType::General, s); | ||||
|   array x_copy = contiguous_copy_gpu(x, s); | ||||
|   d.add_temporary(x_copy, s.index); | ||||
|   return std::make_tuple(false, x_copy.strides()[x_copy.ndim() - 2], x_copy); | ||||
| } | ||||
| @@ -1894,8 +1891,7 @@ void segmented_mm( | ||||
|       return std::make_tuple(false, x); | ||||
|     } | ||||
|  | ||||
|     array x_copy(x.shape(), x.dtype(), nullptr, {}); | ||||
|     copy_gpu(x, x_copy, CopyType::General, s); | ||||
|     array x_copy = contiguous_copy_gpu(x, s); | ||||
|     d.add_temporary(x_copy, s.index); | ||||
|     return std::make_tuple(true, x_copy); | ||||
|   }; | ||||
|   | ||||
| @@ -40,8 +40,7 @@ void RMSNorm::eval_gpu( | ||||
|       } | ||||
|       return x; | ||||
|     } else { | ||||
|       auto x_copy = array(x.shape(), x.dtype(), nullptr, {}); | ||||
|       copy_gpu(x, x_copy, CopyType::General, s); | ||||
|       array x_copy = contiguous_copy_gpu(x, s); | ||||
|       out.copy_shared_buffer(x_copy); | ||||
|       return x_copy; | ||||
|     } | ||||
| @@ -107,9 +106,7 @@ void RMSNormVJP::eval_gpu( | ||||
|     if (x.flags().row_contiguous) { | ||||
|       return {x, false}; | ||||
|     } | ||||
|  | ||||
|     array x_copy(x.shape(), x.dtype(), nullptr, {}); | ||||
|     copy_gpu(x, x_copy, CopyType::General, s); | ||||
|     array x_copy = contiguous_copy_gpu(x, s); | ||||
|     return {x_copy, true}; | ||||
|   }; | ||||
|   bool donate_x = inputs[0].is_donatable(); | ||||
| @@ -241,8 +238,7 @@ void LayerNorm::eval_gpu( | ||||
|       } | ||||
|       return x; | ||||
|     } else { | ||||
|       auto x_copy = array(x.shape(), x.dtype(), nullptr, {}); | ||||
|       copy_gpu(x, x_copy, CopyType::General, s); | ||||
|       array x_copy = contiguous_copy_gpu(x, s); | ||||
|       out.copy_shared_buffer(x_copy); | ||||
|       return x_copy; | ||||
|     } | ||||
| @@ -319,8 +315,7 @@ void LayerNormVJP::eval_gpu( | ||||
|     if (x.flags().row_contiguous) { | ||||
|       return {x, false}; | ||||
|     } | ||||
|     array x_copy(x.shape(), x.dtype(), nullptr, {}); | ||||
|     copy_gpu(x, x_copy, CopyType::General, s); | ||||
|     array x_copy = contiguous_copy_gpu(x, s); | ||||
|     return {x_copy, true}; | ||||
|   }; | ||||
|   bool donate_x = inputs[0].is_donatable(); | ||||
|   | ||||
| @@ -20,8 +20,7 @@ namespace { | ||||
| inline array | ||||
| ensure_row_contiguous(const array& x, metal::Device& d, const Stream& s) { | ||||
|   if (!x.flags().row_contiguous) { | ||||
|     array x_copy(x.shape(), x.dtype(), nullptr, {}); | ||||
|     copy_gpu(x, x_copy, CopyType::General, s); | ||||
|     array x_copy = contiguous_copy_gpu(x, s); | ||||
|     d.add_temporary(x_copy, s.index); | ||||
|     return x_copy; | ||||
|   } else { | ||||
| @@ -38,8 +37,7 @@ inline array ensure_row_contiguous_matrix( | ||||
|   if (stride_0 == x.shape(-1) && stride_1 == 1) { | ||||
|     return x; | ||||
|   } else { | ||||
|     array x_copy(x.shape(), x.dtype(), nullptr, {}); | ||||
|     copy_gpu(x, x_copy, CopyType::General, s); | ||||
|     array x_copy = contiguous_copy_gpu(x, s); | ||||
|     d.add_temporary(x_copy, s.index); | ||||
|     return x_copy; | ||||
|   } | ||||
|   | ||||
| @@ -989,8 +989,7 @@ void Reduce::eval_gpu(const std::vector<array>& inputs, array& out) { | ||||
|     //       input for the axes with stride smaller than the minimum reduction | ||||
|     //       stride. | ||||
|     if (plan.type == GeneralReduce) { | ||||
|       array in_copy(in.shape(), in.dtype(), nullptr, {}); | ||||
|       copy_gpu(in, in_copy, CopyType::General, s); | ||||
|       array in_copy = contiguous_copy_gpu(in, s); | ||||
|       d.add_temporary(in_copy, s.index); | ||||
|       in = in_copy; | ||||
|       plan = get_reduction_plan(in, axes_); | ||||
|   | ||||
| @@ -398,8 +398,7 @@ void ScaledDotProductAttention::eval_gpu( | ||||
|   auto copy_unless = [&copies, &s]( | ||||
|                          auto predicate, const array& arr) -> const array& { | ||||
|     if (!predicate(arr)) { | ||||
|       array arr_copy(arr.shape(), arr.dtype(), nullptr, {}); | ||||
|       copy_gpu(arr, arr_copy, CopyType::General, s); | ||||
|       array arr_copy = contiguous_copy_gpu(arr, s); | ||||
|       copies.push_back(std::move(arr_copy)); | ||||
|       return copies.back(); | ||||
|     } else { | ||||
|   | ||||
| @@ -30,9 +30,7 @@ void Scan::eval_gpu(const std::vector<array>& inputs, array& out) { | ||||
|           in.flags()); | ||||
|     } | ||||
|   } else { | ||||
|     array arr_copy(in.shape(), in.dtype(), nullptr, {}); | ||||
|     copy_gpu(in, arr_copy, CopyType::General, s); | ||||
|     in = std::move(arr_copy); | ||||
|     in = contiguous_copy_gpu(in, s); | ||||
|     out.copy_shared_buffer(in); | ||||
|   } | ||||
|  | ||||
|   | ||||
| @@ -35,8 +35,7 @@ void Softmax::eval_gpu(const std::vector<array>& inputs, array& out) { | ||||
|       } | ||||
|       return x; | ||||
|     } else { | ||||
|       auto x_copy = array(x.shape(), x.dtype(), nullptr, {}); | ||||
|       copy_gpu(x, x_copy, CopyType::General, s); | ||||
|       array x_copy = contiguous_copy_gpu(x, s); | ||||
|       out.copy_shared_buffer(x_copy); | ||||
|       return x_copy; | ||||
|     } | ||||
|   | ||||
		Reference in New Issue
	
	Block a user
	 Cheng
					Cheng