mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
Rename the copy util in cpu/copy.h to copy_cpu (#2378)
This commit is contained in:
@@ -883,7 +883,7 @@ void explicit_gemm_conv_1D_cpu(
|
||||
// Fill with zeros
|
||||
std::vector<array> temps;
|
||||
temps.push_back(array(0, conv_dtype));
|
||||
copy(temps.back(), in_padded, CopyType::Scalar, stream);
|
||||
copy_cpu(temps.back(), in_padded, CopyType::Scalar, stream);
|
||||
|
||||
// Pick input slice from padded
|
||||
size_t data_offset = padding_lo[0] * in_padded.strides()[1];
|
||||
@@ -895,7 +895,7 @@ void explicit_gemm_conv_1D_cpu(
|
||||
in_padded_slice.size(),
|
||||
data_offset);
|
||||
// Copy input values into the slice
|
||||
copy_inplace(in, in_padded_slice, CopyType::GeneralGeneral, stream);
|
||||
copy_cpu_inplace(in, in_padded_slice, CopyType::GeneralGeneral, stream);
|
||||
temps.push_back(in_padded_slice);
|
||||
|
||||
// Make strided view
|
||||
@@ -920,7 +920,7 @@ void explicit_gemm_conv_1D_cpu(
|
||||
// Materialize strided view
|
||||
Shape strided_reshape = {N * oH, wH * C};
|
||||
array in_strided(strided_reshape, in_strided_view.dtype(), nullptr, {});
|
||||
copy(in_strided_view, in_strided, CopyType::General, stream);
|
||||
copy_cpu(in_strided_view, in_strided, CopyType::General, stream);
|
||||
temps.push_back(in_strided);
|
||||
|
||||
// Check wt dtype and prepare
|
||||
@@ -938,13 +938,13 @@ void explicit_gemm_conv_1D_cpu(
|
||||
wt.size(),
|
||||
0);
|
||||
gemm_wt = array(wt_transpose.shape(), float32, nullptr, {});
|
||||
copy(wt_transpose, gemm_wt, CopyType::General, stream);
|
||||
copy_cpu(wt_transpose, gemm_wt, CopyType::General, stream);
|
||||
temps.push_back(gemm_wt);
|
||||
} else if (wt.dtype() != float32 || !wt.flags().row_contiguous) {
|
||||
auto ctype =
|
||||
wt.flags().row_contiguous ? CopyType::Vector : CopyType::General;
|
||||
gemm_wt = array(wt.shape(), float32, nullptr, {});
|
||||
copy(wt, gemm_wt, ctype, stream);
|
||||
copy_cpu(wt, gemm_wt, ctype, stream);
|
||||
temps.push_back(gemm_wt);
|
||||
}
|
||||
|
||||
@@ -991,7 +991,7 @@ void explicit_gemm_conv_1D_cpu(
|
||||
|
||||
// Copy results if needed
|
||||
if (out.dtype() != float32) {
|
||||
copy_inplace(gemm_out, out, CopyType::Vector, stream);
|
||||
copy_cpu_inplace(gemm_out, out, CopyType::Vector, stream);
|
||||
}
|
||||
encoder.add_temporaries(std::move(temps));
|
||||
}
|
||||
@@ -1029,7 +1029,7 @@ void explicit_gemm_conv_2D_cpu(
|
||||
// Fill with zeros
|
||||
std::vector<array> temps;
|
||||
temps.push_back(array(0, conv_dtype));
|
||||
copy(temps.back(), in_padded, CopyType::Scalar, stream);
|
||||
copy_cpu(temps.back(), in_padded, CopyType::Scalar, stream);
|
||||
|
||||
// Pick input slice from padded
|
||||
size_t data_offset = padding_lo[0] * in_padded.strides()[1] +
|
||||
@@ -1044,7 +1044,7 @@ void explicit_gemm_conv_2D_cpu(
|
||||
temps.push_back(in_padded_slice);
|
||||
|
||||
// Copy input values into the slice
|
||||
copy_inplace(in, in_padded_slice, CopyType::GeneralGeneral, stream);
|
||||
copy_cpu_inplace(in, in_padded_slice, CopyType::GeneralGeneral, stream);
|
||||
|
||||
// Make strided view
|
||||
Shape strided_shape = {N, oH, oW, wH, wW, C};
|
||||
@@ -1065,7 +1065,7 @@ void explicit_gemm_conv_2D_cpu(
|
||||
// Materialize strided view
|
||||
Shape strided_reshape = {N * oH * oW, wH * wW * C};
|
||||
array in_strided(strided_reshape, in_strided_view.dtype(), nullptr, {});
|
||||
copy(in_strided_view, in_strided, CopyType::General, stream);
|
||||
copy_cpu(in_strided_view, in_strided, CopyType::General, stream);
|
||||
temps.push_back(in_strided);
|
||||
|
||||
// Check wt dtype and prepare
|
||||
@@ -1076,7 +1076,7 @@ void explicit_gemm_conv_2D_cpu(
|
||||
auto ctype =
|
||||
wt.flags().row_contiguous ? CopyType::Vector : CopyType::General;
|
||||
gemm_wt = array(wt.shape(), float32, nullptr, {});
|
||||
copy(wt, gemm_wt, ctype, stream);
|
||||
copy_cpu(wt, gemm_wt, ctype, stream);
|
||||
temps.push_back(gemm_wt);
|
||||
}
|
||||
|
||||
@@ -1116,7 +1116,7 @@ void explicit_gemm_conv_2D_cpu(
|
||||
|
||||
// Copy results if needed
|
||||
if (out.dtype() != float32) {
|
||||
copy_inplace(gemm_out, out, CopyType::Vector, stream);
|
||||
copy_cpu_inplace(gemm_out, out, CopyType::Vector, stream);
|
||||
}
|
||||
encoder.add_temporaries(std::move(temps));
|
||||
}
|
||||
@@ -1156,7 +1156,7 @@ void explicit_gemm_conv_ND_cpu(
|
||||
|
||||
// Fill with zeros
|
||||
std::vector<array> temps = {array(0, conv_dtype)};
|
||||
copy(temps.back(), in_padded, CopyType::Scalar, stream);
|
||||
copy_cpu(temps.back(), in_padded, CopyType::Scalar, stream);
|
||||
|
||||
// Pick input slice from padded
|
||||
size_t data_offset = 0;
|
||||
@@ -1173,7 +1173,7 @@ void explicit_gemm_conv_ND_cpu(
|
||||
data_offset);
|
||||
|
||||
// Copy input values into the slice
|
||||
copy_inplace(in, in_padded_slice, CopyType::GeneralGeneral, stream);
|
||||
copy_cpu_inplace(in, in_padded_slice, CopyType::GeneralGeneral, stream);
|
||||
temps.push_back(in_padded_slice);
|
||||
|
||||
// Make strided view
|
||||
@@ -1212,7 +1212,7 @@ void explicit_gemm_conv_ND_cpu(
|
||||
}
|
||||
|
||||
array in_strided(strided_reshape, in_strided_view.dtype(), nullptr, {});
|
||||
copy(in_strided_view, in_strided, CopyType::General, stream);
|
||||
copy_cpu(in_strided_view, in_strided, CopyType::General, stream);
|
||||
temps.push_back(in_strided);
|
||||
|
||||
// Check wt dtype and prepare
|
||||
@@ -1223,13 +1223,13 @@ void explicit_gemm_conv_ND_cpu(
|
||||
auto ctype =
|
||||
wt.flags().row_contiguous ? CopyType::Vector : CopyType::General;
|
||||
gemm_wt = array(wt.shape(), float32, nullptr, {});
|
||||
copy(wt, gemm_wt, ctype, stream);
|
||||
copy_cpu(wt, gemm_wt, ctype, stream);
|
||||
temps.push_back(gemm_wt);
|
||||
}
|
||||
|
||||
if (flip) {
|
||||
auto gemm_wt_ = array(gemm_wt.shape(), float32, nullptr, {});
|
||||
copy(gemm_wt, gemm_wt_, CopyType::Vector, stream);
|
||||
copy_cpu(gemm_wt, gemm_wt_, CopyType::Vector, stream);
|
||||
temps.push_back(gemm_wt_);
|
||||
|
||||
// Calculate the total size of the spatial dimensions
|
||||
@@ -1284,7 +1284,7 @@ void explicit_gemm_conv_ND_cpu(
|
||||
|
||||
// Copy results if needed
|
||||
if (out.dtype() != float32) {
|
||||
copy_inplace(gemm_out, out, CopyType::Vector, stream);
|
||||
copy_cpu_inplace(gemm_out, out, CopyType::Vector, stream);
|
||||
}
|
||||
encoder.add_temporaries(std::move(temps));
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user