Mirror of https://github.com/ml-explore/mlx.git (synced 2025-07-21 16:51:15 +08:00)

Rename the copy util in cpu/copy.h to copy_cpu (#2378)

This commit is contained in:
parent d7734edd9f
commit 30571e2326
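The diff below is a mechanical rename of the CPU backend copy helpers: copy becomes copy_cpu and copy_inplace becomes copy_cpu_inplace, with every call site updated to match. As a before/after sketch of a typical call site (signatures taken from the cpu/copy.h hunk further down; src, dst, and stream are illustrative names, not code from this commit):

    // Before this commit:
    //   copy(src, dst, CopyType::General, stream);
    //   copy_inplace(src, dst, CopyType::General, stream);

    // After this commit, the same calls read:
    copy_cpu(src, dst, CopyType::General, stream);
    copy_cpu_inplace(src, dst, CopyType::General, stream);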
@@ -20,7 +20,7 @@ void cholesky_impl(const array& a, array& factor, bool upper, Stream stream) {
   // The decomposition is computed in place, so just copy the input to the
   // output.
-  copy(
+  copy_cpu(
       a,
       factor,
       a.flags().row_contiguous ? CopyType::Vector : CopyType::General,
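A pattern repeated throughout the hunks below is choosing the CopyType from the source layout before invoking the renamed helper. The following gloss is my reading of how the enum is used at these call sites, not upstream documentation:

    // CopyType as used in this commit (interpretation):
    //   Scalar         - broadcast a single value over the destination
    //   Vector         - contiguous source, same-layout destination
    //   General        - strided source copied into a dense destination
    //   GeneralGeneral - both source and destination may be strided
    copy_cpu(
        a,
        factor,
        a.flags().row_contiguous ? CopyType::Vector : CopyType::General,
        stream);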
@@ -883,7 +883,7 @@ void explicit_gemm_conv_1D_cpu(
   // Fill with zeros
   std::vector<array> temps;
   temps.push_back(array(0, conv_dtype));
-  copy(temps.back(), in_padded, CopyType::Scalar, stream);
+  copy_cpu(temps.back(), in_padded, CopyType::Scalar, stream);
 
   // Pick input slice from padded
   size_t data_offset = padding_lo[0] * in_padded.strides()[1];
@@ -895,7 +895,7 @@ void explicit_gemm_conv_1D_cpu(
       in_padded_slice.size(),
       data_offset);
   // Copy input values into the slice
-  copy_inplace(in, in_padded_slice, CopyType::GeneralGeneral, stream);
+  copy_cpu_inplace(in, in_padded_slice, CopyType::GeneralGeneral, stream);
   temps.push_back(in_padded_slice);
 
   // Make strided view
@@ -920,7 +920,7 @@ void explicit_gemm_conv_1D_cpu(
   // Materialize strided view
   Shape strided_reshape = {N * oH, wH * C};
   array in_strided(strided_reshape, in_strided_view.dtype(), nullptr, {});
-  copy(in_strided_view, in_strided, CopyType::General, stream);
+  copy_cpu(in_strided_view, in_strided, CopyType::General, stream);
   temps.push_back(in_strided);
 
   // Check wt dtype and prepare
@@ -938,13 +938,13 @@ void explicit_gemm_conv_1D_cpu(
         wt.size(),
         0);
     gemm_wt = array(wt_transpose.shape(), float32, nullptr, {});
-    copy(wt_transpose, gemm_wt, CopyType::General, stream);
+    copy_cpu(wt_transpose, gemm_wt, CopyType::General, stream);
     temps.push_back(gemm_wt);
   } else if (wt.dtype() != float32 || !wt.flags().row_contiguous) {
     auto ctype =
         wt.flags().row_contiguous ? CopyType::Vector : CopyType::General;
     gemm_wt = array(wt.shape(), float32, nullptr, {});
-    copy(wt, gemm_wt, ctype, stream);
+    copy_cpu(wt, gemm_wt, ctype, stream);
     temps.push_back(gemm_wt);
   }
 
@@ -991,7 +991,7 @@ void explicit_gemm_conv_1D_cpu(
 
   // Copy results if needed
   if (out.dtype() != float32) {
-    copy_inplace(gemm_out, out, CopyType::Vector, stream);
+    copy_cpu_inplace(gemm_out, out, CopyType::Vector, stream);
   }
   encoder.add_temporaries(std::move(temps));
 }
@@ -1029,7 +1029,7 @@ void explicit_gemm_conv_2D_cpu(
   // Fill with zeros
   std::vector<array> temps;
   temps.push_back(array(0, conv_dtype));
-  copy(temps.back(), in_padded, CopyType::Scalar, stream);
+  copy_cpu(temps.back(), in_padded, CopyType::Scalar, stream);
 
   // Pick input slice from padded
   size_t data_offset = padding_lo[0] * in_padded.strides()[1] +
@@ -1044,7 +1044,7 @@ void explicit_gemm_conv_2D_cpu(
   temps.push_back(in_padded_slice);
 
   // Copy input values into the slice
-  copy_inplace(in, in_padded_slice, CopyType::GeneralGeneral, stream);
+  copy_cpu_inplace(in, in_padded_slice, CopyType::GeneralGeneral, stream);
 
   // Make strided view
   Shape strided_shape = {N, oH, oW, wH, wW, C};
@@ -1065,7 +1065,7 @@ void explicit_gemm_conv_2D_cpu(
   // Materialize strided view
   Shape strided_reshape = {N * oH * oW, wH * wW * C};
   array in_strided(strided_reshape, in_strided_view.dtype(), nullptr, {});
-  copy(in_strided_view, in_strided, CopyType::General, stream);
+  copy_cpu(in_strided_view, in_strided, CopyType::General, stream);
   temps.push_back(in_strided);
 
   // Check wt dtype and prepare
@@ -1076,7 +1076,7 @@ void explicit_gemm_conv_2D_cpu(
     auto ctype =
         wt.flags().row_contiguous ? CopyType::Vector : CopyType::General;
     gemm_wt = array(wt.shape(), float32, nullptr, {});
-    copy(wt, gemm_wt, ctype, stream);
+    copy_cpu(wt, gemm_wt, ctype, stream);
     temps.push_back(gemm_wt);
   }
 
@@ -1116,7 +1116,7 @@ void explicit_gemm_conv_2D_cpu(
 
   // Copy results if needed
   if (out.dtype() != float32) {
-    copy_inplace(gemm_out, out, CopyType::Vector, stream);
+    copy_cpu_inplace(gemm_out, out, CopyType::Vector, stream);
   }
   encoder.add_temporaries(std::move(temps));
 }
@@ -1156,7 +1156,7 @@ void explicit_gemm_conv_ND_cpu(
 
   // Fill with zeros
   std::vector<array> temps = {array(0, conv_dtype)};
-  copy(temps.back(), in_padded, CopyType::Scalar, stream);
+  copy_cpu(temps.back(), in_padded, CopyType::Scalar, stream);
 
   // Pick input slice from padded
   size_t data_offset = 0;
@@ -1173,7 +1173,7 @@ void explicit_gemm_conv_ND_cpu(
       data_offset);
 
   // Copy input values into the slice
-  copy_inplace(in, in_padded_slice, CopyType::GeneralGeneral, stream);
+  copy_cpu_inplace(in, in_padded_slice, CopyType::GeneralGeneral, stream);
   temps.push_back(in_padded_slice);
 
   // Make strided view
@@ -1212,7 +1212,7 @@ void explicit_gemm_conv_ND_cpu(
   }
 
   array in_strided(strided_reshape, in_strided_view.dtype(), nullptr, {});
-  copy(in_strided_view, in_strided, CopyType::General, stream);
+  copy_cpu(in_strided_view, in_strided, CopyType::General, stream);
   temps.push_back(in_strided);
 
   // Check wt dtype and prepare
@@ -1223,13 +1223,13 @@ void explicit_gemm_conv_ND_cpu(
     auto ctype =
         wt.flags().row_contiguous ? CopyType::Vector : CopyType::General;
     gemm_wt = array(wt.shape(), float32, nullptr, {});
-    copy(wt, gemm_wt, ctype, stream);
+    copy_cpu(wt, gemm_wt, ctype, stream);
     temps.push_back(gemm_wt);
   }
 
   if (flip) {
     auto gemm_wt_ = array(gemm_wt.shape(), float32, nullptr, {});
-    copy(gemm_wt, gemm_wt_, CopyType::Vector, stream);
+    copy_cpu(gemm_wt, gemm_wt_, CopyType::Vector, stream);
     temps.push_back(gemm_wt_);
 
     // Calculate the total size of the spatial dimensions
@@ -1284,7 +1284,7 @@ void explicit_gemm_conv_ND_cpu(
 
   // Copy results if needed
   if (out.dtype() != float32) {
-    copy_inplace(gemm_out, out, CopyType::Vector, stream);
+    copy_cpu_inplace(gemm_out, out, CopyType::Vector, stream);
   }
   encoder.add_temporaries(std::move(temps));
 }
@@ -295,7 +295,11 @@ inline void copy_inplace_dispatch(
 
 } // namespace
 
-void copy_inplace(const array& src, array& dst, CopyType ctype, Stream stream) {
+void copy_cpu_inplace(
+    const array& src,
+    array& dst,
+    CopyType ctype,
+    Stream stream) {
   auto& encoder = cpu::get_command_encoder(stream);
   encoder.set_input_array(src);
   encoder.set_output_array(dst);
@@ -305,7 +309,7 @@ void copy_inplace(const array& src, array& dst, CopyType ctype, Stream stream) {
       ctype]() mutable { copy_inplace_dispatch(src, dst, ctype); });
 }
 
-void copy(const array& src, array& dst, CopyType ctype, Stream stream) {
+void copy_cpu(const array& src, array& dst, CopyType ctype, Stream stream) {
   bool donated = set_copy_output_data(src, dst, ctype);
   if (donated && src.dtype() == dst.dtype()) {
     // If the output has the same type as the input then there is nothing to
@@ -315,10 +319,10 @@ void copy(const array& src, array& dst, CopyType ctype, Stream stream) {
   if (ctype == CopyType::GeneralGeneral) {
     ctype = CopyType::General;
   }
-  copy_inplace(src, dst, ctype, stream);
+  copy_cpu_inplace(src, dst, ctype, stream);
 }
 
-void copy_inplace(
+void copy_cpu_inplace(
     const array& src,
     array& dst,
     const Shape& data_shape,
@@ -10,10 +10,14 @@
 
 namespace mlx::core {
 
-void copy(const array& src, array& dst, CopyType ctype, Stream stream);
-void copy_inplace(const array& src, array& dst, CopyType ctype, Stream stream);
+void copy_cpu(const array& src, array& dst, CopyType ctype, Stream stream);
+void copy_cpu_inplace(
+    const array& src,
+    array& dst,
+    CopyType ctype,
+    Stream stream);
 
-void copy_inplace(
+void copy_cpu_inplace(
     const array& src,
     array& dst,
     const Shape& data_shape,
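The header hunk above lists the renamed entry points in one place. As a usage sketch (my assumption from the copy.cpp hunk: copy_cpu sets up the destination's storage via set_copy_output_data, possibly donating the source buffer, while the _inplace variants write into memory the caller has already arranged; dst_slice is an illustrative name):

    // Out-of-place: dst has no data yet; copy_cpu allocates or donates it.
    array dst(src.shape(), src.dtype(), nullptr, {});
    copy_cpu(src, dst, CopyType::General, stream);

    // In-place: dst_slice already shares an allocated buffer (e.g. a view
    // into an output), so only its elements are overwritten.
    copy_cpu_inplace(src, dst_slice, CopyType::GeneralGeneral, stream);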
@@ -14,7 +14,7 @@ std::pair<array, bool> ensure_row_contiguous(const array& arr, Stream stream) {
     return {arr, false};
   } else {
     array arr_copy(arr.shape(), arr.dtype(), nullptr, {});
-    copy(arr, arr_copy, CopyType::General, stream);
+    copy_cpu(arr, arr_copy, CopyType::General, stream);
     return {arr_copy, true};
   }
 };
@@ -35,7 +35,7 @@ void AllReduce::eval_cpu(
       return in;
     } else {
       array arr_copy(in.shape(), in.dtype(), nullptr, {});
-      copy(in, arr_copy, CopyType::General, s);
+      copy_cpu(in, arr_copy, CopyType::General, s);
       out.copy_shared_buffer(arr_copy);
       return arr_copy;
     }
@@ -135,7 +135,7 @@ void Eig::eval_cpu(
       : array(a.shape(), complex64, nullptr, {});
 
   auto a_copy = array(a.shape(), a.dtype(), nullptr, {});
-  copy(
+  copy_cpu(
       a,
       a_copy,
       a.flags().row_contiguous ? CopyType::Vector : CopyType::General,
@@ -196,7 +196,7 @@ void Eigh::eval_cpu(
 
   values.set_data(allocator::malloc(values.nbytes()));
 
-  copy(
+  copy_cpu(
       a,
       vectors,
       a.flags().row_contiguous ? CopyType::Vector : CopyType::General,
@@ -96,7 +96,7 @@ void Hadamard::eval_cpu(const std::vector<array>& inputs, array& out) {
   if (in.flags().row_contiguous && in.is_donatable()) {
     out.copy_shared_buffer(in);
   } else {
-    copy(
+    copy_cpu(
         in,
         out,
         in.flags().row_contiguous ? CopyType::Vector : CopyType::General,
@@ -517,7 +517,7 @@ void Scatter::eval_cpu(const std::vector<array>& inputs, array& out) {
   // Copy src into out (copy allocates memory for out)
   auto ctype =
       src.flags().row_contiguous ? CopyType::Vector : CopyType::General;
-  copy(src, out, ctype, stream());
+  copy_cpu(src, out, ctype, stream());
 
   auto& encoder = cpu::get_command_encoder(stream());
   std::vector<array> inds;
@@ -686,7 +686,7 @@ void ScatterAxis::eval_cpu(const std::vector<array>& inputs, array& out) {
   // Copy src into out (copy allocates memory for out)
   auto ctype =
       src.flags().row_contiguous ? CopyType::Vector : CopyType::General;
-  copy(src, out, ctype, stream());
+  copy_cpu(src, out, ctype, stream());
 
   auto& encoder = cpu::get_command_encoder(stream());
   encoder.set_input_array(idx);
@@ -115,7 +115,7 @@ void inverse_impl(
   // (A⁻¹)ᵀ = (Aᵀ)⁻¹
 
   // The inverse is computed in place, so just copy the input to the output.
-  copy(
+  copy_cpu(
       a,
       inv,
       a.flags().row_contiguous ? CopyType::Vector : CopyType::General,
@@ -88,7 +88,7 @@ void LogSumExp::eval_cpu(const std::vector<array>& inputs, array& out) {
       return x;
     } else {
       auto x_copy = array(x.shape(), x.dtype(), nullptr, {});
-      copy(x, x_copy, CopyType::General, s);
+      copy_cpu(x, x_copy, CopyType::General, s);
       encoder.add_temporary(x_copy);
       return x_copy;
     }
@@ -31,7 +31,7 @@ void luf_impl(
   strides[ndim - 1] = M;
   strides[ndim - 2] = 1;
   lu.set_data(allocator::malloc(lu.nbytes()), lu.nbytes(), strides, flags);
-  copy_inplace(
+  copy_cpu_inplace(
       a,
       lu,
       a.shape(),
@@ -124,20 +124,20 @@ void BlockMaskedMM::eval_cpu(const std::vector<array>& inputs, array& out) {
     if (!expand_all && stx == arr.shape(-1) && sty == 1) {
       if (do_copy) {
         array arr_copy(arr.shape(), arr.dtype(), nullptr, {});
-        copy(arr, arr_copy, CopyType::Vector, s);
+        copy_cpu(arr, arr_copy, CopyType::Vector, s);
         return std::make_tuple(false, stx, arr_copy, true);
       }
       return std::make_tuple(false, stx, arr, false);
     } else if (!expand_all && stx == 1 && sty == arr.shape(-2)) {
       if (do_copy) {
         array arr_copy(arr.shape(), arr.dtype(), nullptr, {});
-        copy(arr, arr_copy, CopyType::Vector, s);
+        copy_cpu(arr, arr_copy, CopyType::Vector, s);
         return std::make_tuple(true, sty, arr_copy, true);
       }
       return std::make_tuple(true, sty, arr, false);
     } else {
       array arr_copy(arr.shape(), arr.dtype(), nullptr, {});
-      copy(arr, arr_copy, CopyType::General, s);
+      copy_cpu(arr, arr_copy, CopyType::General, s);
       int64_t stx = arr.shape(-1);
       return std::make_tuple(false, stx, arr_copy, true);
     }
@@ -386,7 +386,7 @@ void GatherMM::eval_cpu(const std::vector<array>& inputs, array& out) {
       return std::make_tuple(true, sty, arr);
     } else {
       temps.push_back(array(arr.shape(), arr.dtype(), nullptr, {}));
-      copy(arr, temps.back(), CopyType::General, s);
+      copy_cpu(arr, temps.back(), CopyType::General, s);
       int64_t stx = arr.shape(-1);
       return std::make_tuple(false, stx, temps.back());
     }
@@ -504,7 +504,7 @@ void SegmentedMM::eval_cpu(const std::vector<array>& inputs, array& out) {
       return std::make_tuple(true, sty, x);
     } else {
       array xc(x.shape(), x.dtype(), nullptr, {});
-      copy(x, xc, CopyType::General, s);
+      copy_cpu(x, xc, CopyType::General, s);
       encoder.add_temporary(xc);
       int64_t stx = x.shape(-1);
       return std::make_tuple(false, stx, xc);
@@ -81,7 +81,7 @@ void matmul_general(
       return std::make_tuple(true, sty, arr);
     } else {
       temps.push_back(array(arr.shape(), arr.dtype(), nullptr, {}));
-      copy(arr, temps.back(), CopyType::General, stream);
+      copy_cpu(arr, temps.back(), CopyType::General, stream);
       stx = arr.shape(-1);
       return std::make_tuple(false, stx, temps.back());
     }
@@ -142,7 +142,7 @@ void AddMM::eval_cpu(const std::vector<array>& inputs, array& out) {
   CopyType ctype = c.data_size() == 1
       ? CopyType::Scalar
       : (c.flags().row_contiguous ? CopyType::Vector : CopyType::General);
-  copy(c, out, ctype, stream());
+  copy_cpu(c, out, ctype, stream());
   if (inputs[0].shape(-1) == 0) {
     return;
   }
@@ -22,7 +22,7 @@ void reshape(const array& in, array& out) {
   auto [copy_necessary, out_strides] = prepare_reshape(in, out);
   if (copy_necessary) {
     out.set_data(allocator::malloc(out.nbytes()));
-    copy_inplace(in, out, CopyType::General, out.primitive().stream());
+    copy_cpu_inplace(in, out, CopyType::General, out.primitive().stream());
   } else {
     shared_buffer_reshape(in, out_strides, out);
   }
@@ -175,7 +175,7 @@ void AsType::eval_cpu(const std::vector<array>& inputs, array& out) {
   assert(inputs.size() == 1);
   auto& in = inputs[0];
   CopyType ctype = in.flags().contiguous ? CopyType::Vector : CopyType::General;
-  copy(in, out, ctype, stream());
+  copy_cpu(in, out, ctype, stream());
 }
 
 void Concatenate::eval_cpu(const std::vector<array>& inputs, array& out) {
@@ -198,7 +198,7 @@ void Concatenate::eval_cpu(const std::vector<array>& inputs, array& out) {
     size_t data_offset = strides[axis_] * sizes[i];
     out_slice.copy_shared_buffer(
         out, strides, flags, out_slice.size(), data_offset);
-    copy_inplace(inputs[i], out_slice, CopyType::GeneralGeneral, stream());
+    copy_cpu_inplace(inputs[i], out_slice, CopyType::GeneralGeneral, stream());
   }
 }
 
@@ -211,7 +211,7 @@ void Contiguous::eval_cpu(const std::vector<array>& inputs, array& out) {
       (allow_col_major_ && in.flags().col_contiguous))) {
     out.copy_shared_buffer(in);
   } else {
-    copy(in, out, CopyType::General, stream());
+    copy_cpu(in, out, CopyType::General, stream());
   }
 }
 
@@ -235,7 +235,7 @@ void Full::eval_cpu(const std::vector<array>& inputs, array& out) {
   } else {
     ctype = CopyType::General;
   }
-  copy(in, out, ctype, stream());
+  copy_cpu(in, out, ctype, stream());
 }
 
 void Pad::eval_cpu(const std::vector<array>& inputs, array& out) {
@@ -251,7 +251,7 @@ void Pad::eval_cpu(const std::vector<array>& inputs, array& out) {
   assert(val.dtype() == in.dtype() && in.dtype() == out.dtype());
 
   // Fill output with val
-  copy(val, out, CopyType::Scalar, stream());
+  copy_cpu(val, out, CopyType::Scalar, stream());
 
   // Find offset for start of input values
   size_t data_offset = 0;
@@ -266,7 +266,7 @@ void Pad::eval_cpu(const std::vector<array>& inputs, array& out) {
       out, out.strides(), out.flags(), out_slice.size(), data_offset);
 
   // Copy input values into the slice
-  copy_inplace(in, out_slice, CopyType::GeneralGeneral, stream());
+  copy_cpu_inplace(in, out_slice, CopyType::GeneralGeneral, stream());
 }
 
 void RandomBits::eval_cpu(const std::vector<array>& inputs, array& out) {
@@ -340,7 +340,7 @@ void DynamicSlice::eval_cpu(const std::vector<array>& inputs, array& out) {
   out.set_data(allocator::malloc(out.nbytes()));
   auto [in_offset, donated] =
       compute_dynamic_offset(inputs[1], in.strides(), axes_, stream());
-  copy_inplace(
+  copy_cpu_inplace(
       /* const array& src = */ in,
       /* array& dst = */ out,
       /* const Shape& data_shape = */ out.shape(),
|
|||||||
auto ctype = in.flags().contiguous && in.size() == in.data_size()
|
auto ctype = in.flags().contiguous && in.size() == in.data_size()
|
||||||
? CopyType::Vector
|
? CopyType::Vector
|
||||||
: CopyType::General;
|
: CopyType::General;
|
||||||
copy(in, out, in.data_size() == 1 ? CopyType::Scalar : ctype, stream());
|
copy_cpu(in, out, in.data_size() == 1 ? CopyType::Scalar : ctype, stream());
|
||||||
|
|
||||||
auto [out_offset, donated] =
|
auto [out_offset, donated] =
|
||||||
compute_dynamic_offset(inputs[2], out.strides(), axes_, stream());
|
compute_dynamic_offset(inputs[2], out.strides(), axes_, stream());
|
||||||
copy_inplace(
|
copy_cpu_inplace(
|
||||||
/* const array& src = */ upd,
|
/* const array& src = */ upd,
|
||||||
/* array& dst = */ out,
|
/* array& dst = */ out,
|
||||||
/* const std::vector<int>& data_shape = */ upd.shape(),
|
/* const std::vector<int>& data_shape = */ upd.shape(),
|
||||||
@ -412,14 +412,14 @@ void SliceUpdate::eval_cpu(const std::vector<array>& inputs, array& out) {
|
|||||||
auto ctype = in.flags().contiguous && in.size() == in.data_size()
|
auto ctype = in.flags().contiguous && in.size() == in.data_size()
|
||||||
? CopyType::Vector
|
? CopyType::Vector
|
||||||
: CopyType::General;
|
: CopyType::General;
|
||||||
copy(in, out, in.data_size() == 1 ? CopyType::Scalar : ctype, stream());
|
copy_cpu(in, out, in.data_size() == 1 ? CopyType::Scalar : ctype, stream());
|
||||||
|
|
||||||
// Calculate out strides, initial offset and if copy needs to be made
|
// Calculate out strides, initial offset and if copy needs to be made
|
||||||
auto [data_offset, out_strides] =
|
auto [data_offset, out_strides] =
|
||||||
prepare_slice(out, start_indices_, strides_);
|
prepare_slice(out, start_indices_, strides_);
|
||||||
|
|
||||||
// Do copy
|
// Do copy
|
||||||
copy_inplace(
|
copy_cpu_inplace(
|
||||||
/* const array& src = */ upd,
|
/* const array& src = */ upd,
|
||||||
/* array& dst = */ out,
|
/* array& dst = */ out,
|
||||||
/* const std::vector<int>& data_shape = */ upd.shape(),
|
/* const std::vector<int>& data_shape = */ upd.shape(),
|
||||||
@@ -456,9 +456,9 @@ void View::eval_cpu(const std::vector<array>& inputs, array& out) {
     if (in.dtype() == bool_) {
       auto in_tmp = array(in.shape(), uint8, nullptr, {});
       in_tmp.copy_shared_buffer(in);
-      copy_inplace(in_tmp, tmp, CopyType::General, stream());
+      copy_cpu_inplace(in_tmp, tmp, CopyType::General, stream());
     } else {
-      copy_inplace(in, tmp, CopyType::General, stream());
+      copy_cpu_inplace(in, tmp, CopyType::General, stream());
     }
 
     auto flags = out.flags();
@@ -26,7 +26,7 @@ void qrf_impl(const array& a, array& q, array& r, Stream stream) {
   strides[in.ndim() - 2] = 1;
   strides[in.ndim() - 1] = M;
   in.set_data(allocator::malloc(in.nbytes()), in.nbytes(), strides, flags);
-  copy_inplace(a, in, CopyType::GeneralGeneral, stream);
+  copy_cpu_inplace(a, in, CopyType::GeneralGeneral, stream);
   auto& encoder = cpu::get_command_encoder(stream);
   q.set_data(allocator::malloc(q.nbytes()));
   r.set_data(allocator::malloc(r.nbytes()));
@@ -529,7 +529,7 @@ void QuantizedMatmul::eval_cpu(const std::vector<array>& inputs, array& out) {
       return arr;
     } else {
       temps.push_back(array(arr.shape(), arr.dtype(), nullptr, {}));
-      copy(arr, temps.back(), CopyType::General, s);
+      copy_cpu(arr, temps.back(), CopyType::General, s);
       return temps.back();
     }
   };
@@ -579,7 +579,7 @@ void GatherQMM::eval_cpu(const std::vector<array>& inputs, array& out) {
       return arr;
     } else {
       temps.push_back(array(arr.shape(), arr.dtype(), nullptr, {}));
-      copy(arr, temps.back(), CopyType::General, s);
+      copy_cpu(arr, temps.back(), CopyType::General, s);
       return temps.back();
     }
   };
@@ -713,7 +713,7 @@ void fast::AffineQuantize::eval_cpu(
       return std::make_pair(arr, false);
     } else {
       array arr_copy(arr.shape(), arr.dtype(), nullptr, {});
-      copy(arr, arr_copy, CopyType::General, s);
+      copy_cpu(arr, arr_copy, CopyType::General, s);
       return std::make_pair(arr_copy, true);
     }
   };
@@ -251,7 +251,7 @@ void Scan::eval_cpu(const std::vector<array>& inputs, array& out) {
   auto in = inputs[0];
   if (!in.flags().row_contiguous) {
     array arr_copy(in.shape(), in.dtype(), nullptr, {});
-    copy(in, arr_copy, CopyType::General, stream());
+    copy_cpu(in, arr_copy, CopyType::General, stream());
     in = arr_copy;
    encoder.add_temporary(arr_copy);
   }
@@ -132,7 +132,7 @@ void Softmax::eval_cpu(const std::vector<array>& inputs, array& out) {
       return x;
     } else {
       array x_copy(x.shape(), x.dtype(), nullptr, {});
-      copy(x, x_copy, CopyType::General, s);
+      copy_cpu(x, x_copy, CopyType::General, s);
       out.copy_shared_buffer(x_copy);
       return x_copy;
     }
@@ -335,7 +335,7 @@ void Sort::eval_cpu(const std::vector<array>& inputs, array& out) {
 
   // Copy input to output
   CopyType ctype = in.flags().contiguous ? CopyType::Vector : CopyType::General;
-  copy(in, out, ctype, stream());
+  copy_cpu(in, out, ctype, stream());
 
   auto& encoder = cpu::get_command_encoder(stream());
   encoder.set_output_array(out);
@@ -427,7 +427,7 @@ void Partition::eval_cpu(const std::vector<array>& inputs, array& out) {
 
   // Copy input to output
   CopyType ctype = in.flags().contiguous ? CopyType::Vector : CopyType::General;
-  copy(in, out, ctype, stream());
+  copy_cpu(in, out, ctype, stream());
 
   auto& encoder = cpu::get_command_encoder(stream());
   encoder.set_output_array(out);
@@ -31,7 +31,7 @@ void svd_impl(
 
   // lapack clobbers the input, so we have to make a copy.
   array in(a.shape(), a.dtype(), nullptr, {});
-  copy(
+  copy_cpu(
       a,
       in,
       a.flags().row_contiguous ? CopyType::Vector : CopyType::General,