mirror of https://github.com/ml-explore/mlx.git
Commit: WIP (cpu)
@@ -17,14 +17,14 @@ void arg_reduce(const array& in, array& out, const OpT& op, int axis) {
   Strides strides = remove_index(in.strides(), axis);
   Shape shape = remove_index(in.shape(), axis);
   auto in_ptr = in.data<InT>();
-  auto out_ptr = out.data<uint32_t>();
+  auto out_ptr = out.data<int64_t>();

-  for (uint32_t i = 0; i < out.size(); ++i) {
+  for (int64_t i = 0; i < out.size(); ++i) {
     auto loc = elem_to_loc(i, shape, strides);
     auto local_in_ptr = in_ptr + loc;
-    uint32_t ind_v = 0;
+    int64_t ind_v = 0;
     InT v = (*local_in_ptr);
-    for (uint32_t j = 0; j < axis_size; ++j, local_in_ptr += axis_stride) {
+    for (int64_t j = 0; j < axis_size; ++j, local_in_ptr += axis_stride) {
       op(j, (*local_in_ptr), &ind_v, &v);
     }
     out_ptr[i] = ind_v;
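
Aside (not part of the commit): the hunk above moves the arg-reduce indices from uint32_t to int64_t. A minimal standalone sketch, with hypothetical values, of why 32-bit unsigned indices are fragile once element counts can exceed 2^32:

#include <cstdint>
#include <iostream>

int main() {
  // Hypothetical element count just past the 32-bit limit.
  int64_t n = (1LL << 32) + 2;

  // A uint32_t index wraps modulo 2^32, so it cannot address the last
  // elements of such an array; an int64_t index can, and it also avoids
  // signed/unsigned comparison warnings against other signed quantities.
  uint32_t i32 = static_cast<uint32_t>(n - 1);
  int64_t i64 = n - 1;
  std::cout << "uint32_t index: " << i32 << "\n";  // prints 1 (wrapped)
  std::cout << "int64_t index:  " << i64 << "\n";  // prints 4294967297
  return 0;
}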
@@ -17,7 +17,7 @@ namespace mlx::core {
 namespace {

 template <typename Op>
-void binary(const array& a, const array& b, array& out, Op op, Stream stream) {
+void binary(const array& a, const array& b, array& out, Op /* op */, Stream stream) {
   auto bopt = get_binary_op_type(a, b);
   set_binary_op_output_data(a, b, out, bopt);

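Aside (not part of the commit): renaming Op op to Op /* op */ keeps the signature and call sites unchanged while silencing -Wunused-parameter in instantiations that never invoke the functor. A small self-contained sketch of the idiom with hypothetical function names, next to the C++17 [[maybe_unused]] alternative:

// Both declarations accept the same arguments; only the parameter name is
// commented out (or attributed) so unused-parameter warnings go away.
template <typename Op>
void dispatch_commented(Op /* op */) {}

template <typename Op>
void dispatch_attributed([[maybe_unused]] Op op) {}

int main() {
  auto add = [](int a, int b) { return a + b; };
  dispatch_commented(add);   // op unused: no warning, the name is gone
  dispatch_attributed(add);  // op unused: no warning, the attribute opts out
  return 0;
}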
@@ -81,7 +81,7 @@ void comparison_op(
     const array& a,
     const array& b,
     array& out,
-    Op op,
+    Op /* op */,
     Stream stream) {
   auto bopt = get_binary_op_type(a, b);
   set_binary_op_output_data(a, b, out, bopt);
@@ -146,7 +146,7 @@ void binary_float(
     const array& a,
     const array& b,
     array& out,
-    Op op,
+    Op /* op */,
     Stream stream) {
   auto bopt = get_binary_op_type(a, b);
   set_binary_op_output_data(a, b, out, bopt);
@@ -187,7 +187,7 @@ void binary_int(
     const array& a,
     const array& b,
     array& out,
-    Op op,
+    Op /* op */,
     Stream stream) {
   auto bopt = get_binary_op_type(a, b);
   set_binary_op_output_data(a, b, out, bopt);
@@ -99,7 +99,7 @@ void binary_op_dispatch_dims(
   ContiguousIterator a_it(shape, a_strides, ndim - 2);
   ContiguousIterator b_it(shape, b_strides, ndim - 2);
   auto stride = out_strides[ndim - 3];
-  for (size_t elem = 0; elem < a.size(); elem += stride) {
+  for (int64_t elem = 0; elem < std::ssize(a); elem += stride) {
     binary_op_dims<T, U, Op, 2>(
         a_ptr + a_it.loc,
         b_ptr + b_it.loc,
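
Aside (not part of the commit): the loop bound above switches to std::ssize, the C++20 helper that returns a container's size as a signed value, so it compares cleanly against a signed counter. A minimal sketch with a std::vector (whether std::ssize applies to mlx's array type depends on that type exposing a suitable size() member, which the diff implies):

#include <cstdint>
#include <iostream>
#include <iterator>  // std::ssize (C++20)
#include <vector>

int main() {
  std::vector<float> v(10);

  // v.size() is unsigned; mixing it with a signed index draws -Wsign-compare.
  // std::ssize(v) returns a signed value (ptrdiff_t here), so the loop below
  // compares only signed types.
  for (std::int64_t i = 0; i < std::ssize(v); ++i) {
    v[i] = static_cast<float>(i);
  }
  std::cout << v.back() << "\n";  // prints 9
  return 0;
}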
@@ -137,21 +137,21 @@ void binary_op(
   if (bopt == BinaryOpType::ScalarScalar) {
     std::tie(*out_a_ptr, *out_b_ptr) = op(*a_ptr, *b_ptr);
   } else if (bopt == BinaryOpType::ScalarVector) {
-    for (size_t i = 0; i < b.data_size(); ++i) {
+    for (int64_t i = 0; i < b.data_size(); ++i) {
       std::tie(*out_a_ptr, *out_b_ptr) = op(*a_ptr, *b_ptr);
       out_a_ptr++;
       out_b_ptr++;
       b_ptr++;
     }
   } else if (bopt == BinaryOpType::VectorScalar) {
-    for (size_t i = 0; i < a.data_size(); ++i) {
+    for (int64_t i = 0; i < a.data_size(); ++i) {
       std::tie(*out_a_ptr, *out_b_ptr) = op(*a_ptr, *b_ptr);
       out_a_ptr++;
       out_b_ptr++;
       a_ptr++;
     }
   } else { // VectorVector
-    for (size_t i = 0; i < a.size(); ++i) {
+    for (int64_t i = 0; i < a.size(); ++i) {
       std::tie(*out_a_ptr, *out_b_ptr) = op(*a_ptr, *b_ptr);
       out_a_ptr++;
       out_b_ptr++;
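
Aside (not part of the commit): the loops above are selected by a BinaryOpType that encodes how the two operands are laid out, so the cheapest traversal can be used. A simplified stand-in for that dispatch (illustrative only, not the mlx implementation, which also accounts for contiguity):

#include <cstdio>

enum class BinaryOpType { ScalarScalar, ScalarVector, VectorScalar, VectorVector };

// Pick a loop shape from the operand sizes: a one-element operand is
// broadcast against the other, and equal-sized operands go elementwise.
BinaryOpType classify(long long a_size, long long b_size) {
  if (a_size == 1 && b_size == 1) return BinaryOpType::ScalarScalar;
  if (a_size == 1) return BinaryOpType::ScalarVector;  // broadcast a over b
  if (b_size == 1) return BinaryOpType::VectorScalar;  // broadcast b over a
  return BinaryOpType::VectorVector;                   // same-sized inputs
}

int main() {
  std::printf("%d\n", static_cast<int>(classify(1, 8)));  // 1 == ScalarVector
  return 0;
}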
@@ -860,7 +860,7 @@ void explicit_gemm_conv_1D_cpu(
     const std::vector<int>& padding_lo,
     const std::vector<int>& padding_hi,
     const std::vector<int>& wt_strides,
-    const std::vector<int>& wt_dilation,
+    const std::vector<int>& /* wt_dilation */,
     Stream stream) {
   const int N = in.shape(0); // Batch size, should be the same as out.shape(0)
   const int iH = in.shape(1); // Input spatial dim
@@ -1003,7 +1003,7 @@ void explicit_gemm_conv_ND_cpu(
     const std::vector<int>& padding_lo,
     const std::vector<int>& padding_hi,
     const std::vector<int>& wt_strides,
-    const std::vector<int>& wt_dilation,
+    const std::vector<int>& /* wt_dilation */,
     const bool flip,
     Stream stream) {
   const int N = in.shape(0); // Batch size, should be the same as out.shape(0)
@@ -1023,7 +1023,7 @@ void explicit_gemm_conv_ND_cpu(
   // Pad input
   Shape padded_shape(in.shape().size());
   padded_shape.front() = N;
-  for (size_t i = 0; i < iDim.size(); i++) {
+  for (int i = 0; i < iDim.size(); i++) {
     padded_shape[i + 1] = iDim[i] + padding_lo[i] + padding_hi[i];
   }
   padded_shape.back() = C;
@@ -1054,20 +1054,20 @@ void explicit_gemm_conv_ND_cpu(
   // Make strided view
   Shape strided_shape(oDim.size() + wDim.size() + 2);
   strided_shape.front() = N;
-  for (size_t i = 0; i < oDim.size(); i++) {
+  for (int i = 0; i < oDim.size(); i++) {
     strided_shape[i + 1] = oDim[i];
   }
-  for (size_t i = 0; i < wDim.size(); i++) {
+  for (int i = 0; i < wDim.size(); i++) {
     strided_shape[i + 1 + oDim.size()] = wDim[i];
   }
   strided_shape.back() = C;

   Strides strided_strides(in.shape().size() * 2 - 2);
   strided_strides[0] = in_padded.strides()[0];
-  for (size_t i = 0; i < wt_strides.size(); i++) {
+  for (int i = 0; i < std::ssize(wt_strides); i++) {
     strided_strides[i + 1] = in_padded.strides()[i + 1] * wt_strides[i];
   }
-  for (size_t i = 1; i < in_padded.strides().size(); i++) {
+  for (int i = 1; i < std::ssize(in_padded.strides()); i++) {
     strided_strides[i + wt_strides.size()] = in_padded.strides()[i];
   }

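Aside (not part of the commit): the explicit_gemm_conv_ND_cpu hunks set up a strided view whose rows are the convolution windows of the padded input, so the convolution reduces to a GEMM against the flattened weights. A hedged 1D, single-channel illustration of that indexing (hypothetical values, not the mlx kernel): element (o, k) of the implicit window matrix lives at offset o * wt_stride + k of the padded input.

#include <cstdio>
#include <vector>

int main() {
  std::vector<float> in_padded = {0, 1, 2, 3, 4, 5, 6, 0};  // already padded
  std::vector<float> wt = {1, 0, -1};                       // filter, wH = 3
  const int wH = 3, wt_stride = 1;
  const int oH = static_cast<int>(in_padded.size()) - wH + 1;

  // Multiply the (oH x wH) window matrix by the filter; this is the
  // cross-correlation form (flip == false).
  std::vector<float> out(oH, 0.0f);
  for (int o = 0; o < oH; ++o) {      // row of the implicit strided view
    for (int k = 0; k < wH; ++k) {    // column: offset inside the window
      out[o] += in_padded[o * wt_stride + k] * wt[k];
    }
  }
  for (float v : out) std::printf("%g ", v);  // -2 -2 -2 -2 -2 5
  std::printf("\n");
  return 0;
}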
@@ -70,7 +70,7 @@ void eig_impl(
   auto eig_tmp = static_cast<T*>(eig_tmp_data.buffer.raw_ptr());
   auto vec_tmp = static_cast<T*>(vec_tmp_data.buffer.raw_ptr());
   auto work_buf = array::Data{allocator::malloc(sizeof(T) * lwork)};
-  for (size_t i = 0; i < size / (N * N); ++i) {
+  for (int64_t i = 0; i < size / (N * N); ++i) {
     geev<T>(
         &jobl,
         &jobr,
@@ -165,7 +165,7 @@ void eigh_impl(
   EighWork<T> work(jobz, uplo, N);

   // Work loop
-  for (size_t i = 0; i < size / (N * N); ++i) {
+  for (int64_t i = 0; i < size / (N * N); ++i) {
     work.run(vec_ptr, eig_ptr);
     vec_ptr += N * N;
     eig_ptr += N;
@@ -20,8 +20,8 @@ struct CommandEncoder {
   CommandEncoder(CommandEncoder&&) = delete;
   CommandEncoder& operator=(CommandEncoder&&) = delete;

-  void set_input_array(const array& a) {}
-  void set_output_array(array& a) {}
+  void set_input_array(const array& /* a */) {}
+  void set_output_array(array& /* a */) {}

   // Hold onto a temporary until any already scheduled tasks which use it as
   // an input are complete.