mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
WIP (cpu)
This commit is contained in:
@@ -17,14 +17,14 @@ void arg_reduce(const array& in, array& out, const OpT& op, int axis) {
|
||||
Strides strides = remove_index(in.strides(), axis);
|
||||
Shape shape = remove_index(in.shape(), axis);
|
||||
auto in_ptr = in.data<InT>();
|
||||
auto out_ptr = out.data<uint32_t>();
|
||||
auto out_ptr = out.data<int64_t>();
|
||||
|
||||
for (uint32_t i = 0; i < out.size(); ++i) {
|
||||
for (int64_t i = 0; i < out.size(); ++i) {
|
||||
auto loc = elem_to_loc(i, shape, strides);
|
||||
auto local_in_ptr = in_ptr + loc;
|
||||
uint32_t ind_v = 0;
|
||||
int64_t ind_v = 0;
|
||||
InT v = (*local_in_ptr);
|
||||
for (uint32_t j = 0; j < axis_size; ++j, local_in_ptr += axis_stride) {
|
||||
for (int64_t j = 0; j < axis_size; ++j, local_in_ptr += axis_stride) {
|
||||
op(j, (*local_in_ptr), &ind_v, &v);
|
||||
}
|
||||
out_ptr[i] = ind_v;
|
||||
|
||||
@@ -17,7 +17,7 @@ namespace mlx::core {
|
||||
namespace {
|
||||
|
||||
template <typename Op>
|
||||
void binary(const array& a, const array& b, array& out, Op op, Stream stream) {
|
||||
void binary(const array& a, const array& b, array& out, Op /* op */, Stream stream) {
|
||||
auto bopt = get_binary_op_type(a, b);
|
||||
set_binary_op_output_data(a, b, out, bopt);
|
||||
|
||||
@@ -81,7 +81,7 @@ void comparison_op(
|
||||
const array& a,
|
||||
const array& b,
|
||||
array& out,
|
||||
Op op,
|
||||
Op /* op */,
|
||||
Stream stream) {
|
||||
auto bopt = get_binary_op_type(a, b);
|
||||
set_binary_op_output_data(a, b, out, bopt);
|
||||
@@ -146,7 +146,7 @@ void binary_float(
|
||||
const array& a,
|
||||
const array& b,
|
||||
array& out,
|
||||
Op op,
|
||||
Op /* op */,
|
||||
Stream stream) {
|
||||
auto bopt = get_binary_op_type(a, b);
|
||||
set_binary_op_output_data(a, b, out, bopt);
|
||||
@@ -187,7 +187,7 @@ void binary_int(
|
||||
const array& a,
|
||||
const array& b,
|
||||
array& out,
|
||||
Op op,
|
||||
Op /* op */,
|
||||
Stream stream) {
|
||||
auto bopt = get_binary_op_type(a, b);
|
||||
set_binary_op_output_data(a, b, out, bopt);
|
||||
|
||||
@@ -99,7 +99,7 @@ void binary_op_dispatch_dims(
|
||||
ContiguousIterator a_it(shape, a_strides, ndim - 2);
|
||||
ContiguousIterator b_it(shape, b_strides, ndim - 2);
|
||||
auto stride = out_strides[ndim - 3];
|
||||
for (size_t elem = 0; elem < a.size(); elem += stride) {
|
||||
for (int64_t elem = 0; elem < std::ssize(a); elem += stride) {
|
||||
binary_op_dims<T, U, Op, 2>(
|
||||
a_ptr + a_it.loc,
|
||||
b_ptr + b_it.loc,
|
||||
@@ -137,21 +137,21 @@ void binary_op(
|
||||
if (bopt == BinaryOpType::ScalarScalar) {
|
||||
std::tie(*out_a_ptr, *out_b_ptr) = op(*a_ptr, *b_ptr);
|
||||
} else if (bopt == BinaryOpType::ScalarVector) {
|
||||
for (size_t i = 0; i < b.data_size(); ++i) {
|
||||
for (int64_t i = 0; i < b.data_size(); ++i) {
|
||||
std::tie(*out_a_ptr, *out_b_ptr) = op(*a_ptr, *b_ptr);
|
||||
out_a_ptr++;
|
||||
out_b_ptr++;
|
||||
b_ptr++;
|
||||
}
|
||||
} else if (bopt == BinaryOpType::VectorScalar) {
|
||||
for (size_t i = 0; i < a.data_size(); ++i) {
|
||||
for (int64_t i = 0; i < a.data_size(); ++i) {
|
||||
std::tie(*out_a_ptr, *out_b_ptr) = op(*a_ptr, *b_ptr);
|
||||
out_a_ptr++;
|
||||
out_b_ptr++;
|
||||
a_ptr++;
|
||||
}
|
||||
} else { // VectorVector
|
||||
for (size_t i = 0; i < a.size(); ++i) {
|
||||
for (int64_t i = 0; i < a.size(); ++i) {
|
||||
std::tie(*out_a_ptr, *out_b_ptr) = op(*a_ptr, *b_ptr);
|
||||
out_a_ptr++;
|
||||
out_b_ptr++;
|
||||
|
||||
@@ -860,7 +860,7 @@ void explicit_gemm_conv_1D_cpu(
|
||||
const std::vector<int>& padding_lo,
|
||||
const std::vector<int>& padding_hi,
|
||||
const std::vector<int>& wt_strides,
|
||||
const std::vector<int>& wt_dilation,
|
||||
const std::vector<int>& /* wt_dilation */,
|
||||
Stream stream) {
|
||||
const int N = in.shape(0); // Batch size, should be the same as out.shape(0)
|
||||
const int iH = in.shape(1); // Input spatial dim
|
||||
@@ -1003,7 +1003,7 @@ void explicit_gemm_conv_ND_cpu(
|
||||
const std::vector<int>& padding_lo,
|
||||
const std::vector<int>& padding_hi,
|
||||
const std::vector<int>& wt_strides,
|
||||
const std::vector<int>& wt_dilation,
|
||||
const std::vector<int>& /* wt_dilation */,
|
||||
const bool flip,
|
||||
Stream stream) {
|
||||
const int N = in.shape(0); // Batch size, should be the same as out.shape(0)
|
||||
@@ -1023,7 +1023,7 @@ void explicit_gemm_conv_ND_cpu(
|
||||
// Pad input
|
||||
Shape padded_shape(in.shape().size());
|
||||
padded_shape.front() = N;
|
||||
for (size_t i = 0; i < iDim.size(); i++) {
|
||||
for (int i = 0; i < iDim.size(); i++) {
|
||||
padded_shape[i + 1] = iDim[i] + padding_lo[i] + padding_hi[i];
|
||||
}
|
||||
padded_shape.back() = C;
|
||||
@@ -1054,20 +1054,20 @@ void explicit_gemm_conv_ND_cpu(
|
||||
// Make strided view
|
||||
Shape strided_shape(oDim.size() + wDim.size() + 2);
|
||||
strided_shape.front() = N;
|
||||
for (size_t i = 0; i < oDim.size(); i++) {
|
||||
for (int i = 0; i < oDim.size(); i++) {
|
||||
strided_shape[i + 1] = oDim[i];
|
||||
}
|
||||
for (size_t i = 0; i < wDim.size(); i++) {
|
||||
for (int i = 0; i < wDim.size(); i++) {
|
||||
strided_shape[i + 1 + oDim.size()] = wDim[i];
|
||||
}
|
||||
strided_shape.back() = C;
|
||||
|
||||
Strides strided_strides(in.shape().size() * 2 - 2);
|
||||
strided_strides[0] = in_padded.strides()[0];
|
||||
for (size_t i = 0; i < wt_strides.size(); i++) {
|
||||
for (int i = 0; i < std::ssize(wt_strides); i++) {
|
||||
strided_strides[i + 1] = in_padded.strides()[i + 1] * wt_strides[i];
|
||||
}
|
||||
for (size_t i = 1; i < in_padded.strides().size(); i++) {
|
||||
for (int i = 1; i < std::ssize(in_padded.strides()); i++) {
|
||||
strided_strides[i + wt_strides.size()] = in_padded.strides()[i];
|
||||
}
|
||||
|
||||
|
||||
@@ -70,7 +70,7 @@ void eig_impl(
|
||||
auto eig_tmp = static_cast<T*>(eig_tmp_data.buffer.raw_ptr());
|
||||
auto vec_tmp = static_cast<T*>(vec_tmp_data.buffer.raw_ptr());
|
||||
auto work_buf = array::Data{allocator::malloc(sizeof(T) * lwork)};
|
||||
for (size_t i = 0; i < size / (N * N); ++i) {
|
||||
for (int64_t i = 0; i < size / (N * N); ++i) {
|
||||
geev<T>(
|
||||
&jobl,
|
||||
&jobr,
|
||||
|
||||
@@ -165,7 +165,7 @@ void eigh_impl(
|
||||
EighWork<T> work(jobz, uplo, N);
|
||||
|
||||
// Work loop
|
||||
for (size_t i = 0; i < size / (N * N); ++i) {
|
||||
for (int64_t i = 0; i < size / (N * N); ++i) {
|
||||
work.run(vec_ptr, eig_ptr);
|
||||
vec_ptr += N * N;
|
||||
eig_ptr += N;
|
||||
|
||||
@@ -20,8 +20,8 @@ struct CommandEncoder {
|
||||
CommandEncoder(CommandEncoder&&) = delete;
|
||||
CommandEncoder& operator=(CommandEncoder&&) = delete;
|
||||
|
||||
void set_input_array(const array& a) {}
|
||||
void set_output_array(array& a) {}
|
||||
void set_input_array(const array& /* a */) {}
|
||||
void set_output_array(array& /* a */) {}
|
||||
|
||||
// Hold onto a temporary until any already scheduled tasks which use it as
|
||||
// an input are complete.
|
||||
|
||||
Reference in New Issue
Block a user