redesign for faster cpu/gpu synch (#1869)

* redesign for faster cpu/gpu synch

* load + more async CPU

* use command encoder API and move more ops to use it

* make fence back-end generic + CPU only fence

* faster build

* fix async eval

* fixes + handle temporaries

* fix / improve cpu conv

* remove unused status, fix siblings

* fix extensions

* fix

* fix no cpu build

* format

* comments

* fix perf regression, remove unnecessary abort

* fix events, task limit cpu

* fix waiting

* fix donation / temporaries in normalization
Awni Hannun
2025-03-06 19:23:38 -08:00
committed by GitHub
parent 5245f12a46
commit c4230747a1
103 changed files with 5013 additions and 3873 deletions
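
The common thread in the diffs below: CPU ops no longer execute inline inside eval_cpu. Each op records its work as a closure on a per-stream CommandEncoder (new files mlx/backend/cpu/encoder.h and encoder.cpp below), which hands the closure to the stream's scheduler. Closures capture raw pointers, sizes, shapes, and strides by value because they may run after eval_cpu returns. Below is a standalone mini version of the mechanism, assuming one worker thread per stream; all names here are illustrative, not MLX's API:

#include <condition_variable>
#include <cstdio>
#include <functional>
#include <mutex>
#include <queue>
#include <thread>

// MiniEncoder: illustrative stand-in for CommandEncoder + scheduler::enqueue.
class MiniEncoder {
 public:
  MiniEncoder() : worker_([this] { loop(); }) {}
  ~MiniEncoder() {
    dispatch({}); // empty task acts as a stop sentinel
    worker_.join();
  }
  // Record a task; it runs later, in order, on the stream's worker thread.
  void dispatch(std::function<void()> task) {
    {
      std::lock_guard<std::mutex> lock(mutex_);
      tasks_.push(std::move(task));
    }
    cv_.notify_one();
  }
 private:
  void loop() {
    while (true) {
      std::function<void()> task;
      {
        std::unique_lock<std::mutex> lock(mutex_);
        cv_.wait(lock, [this] { return !tasks_.empty(); });
        task = std::move(tasks_.front());
        tasks_.pop();
      }
      if (!task) return; // sentinel reached
      task();
    }
  }
  std::mutex mutex_;
  std::condition_variable cv_;
  std::queue<std::function<void()>> tasks_;
  std::thread worker_; // declared last so loop() sees initialized members
};

int main() {
  float out[4];
  MiniEncoder encoder;
  // An "eval_cpu" returns immediately after recording; work runs async.
  encoder.dispatch([&out] {
    for (int i = 0; i < 4; ++i) out[i] = 0.5f * i;
  });
  encoder.dispatch([&out] {
    std::printf("%g %g %g %g\n", out[0], out[1], out[2], out[3]);
  });
  return 0; // ~MiniEncoder joins, so both tasks finish before exit
}

Dispatch returns immediately, so the graph walk can run ahead of the actual compute; that is presumably where the faster cpu/gpu synchronization in the commit title comes from.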


@@ -44,7 +44,9 @@ target_sources(
${CMAKE_CURRENT_SOURCE_DIR}/binary.cpp
${CMAKE_CURRENT_SOURCE_DIR}/conv.cpp
${CMAKE_CURRENT_SOURCE_DIR}/copy.cpp
${CMAKE_CURRENT_SOURCE_DIR}/distributed.cpp
${CMAKE_CURRENT_SOURCE_DIR}/eigh.cpp
${CMAKE_CURRENT_SOURCE_DIR}/encoder.cpp
${CMAKE_CURRENT_SOURCE_DIR}/fft.cpp
${CMAKE_CURRENT_SOURCE_DIR}/hadamard.cpp
${CMAKE_CURRENT_SOURCE_DIR}/matmul.cpp
@@ -65,6 +67,7 @@ target_sources(
${CMAKE_CURRENT_SOURCE_DIR}/inverse.cpp
${CMAKE_CURRENT_SOURCE_DIR}/cholesky.cpp
${CMAKE_CURRENT_SOURCE_DIR}/unary.cpp
${CMAKE_CURRENT_SOURCE_DIR}/eval.cpp
${CMAKE_CURRENT_BINARY_DIR}/compiled_preamble.cpp)
if(MLX_BUILD_ACCELERATE)


@@ -2,76 +2,27 @@
#pragma once
#include "mlx/allocator.h"
#include "mlx/array.h"
#include "mlx/backend/cpu/encoder.h"
namespace mlx::core {
namespace {
template <typename T>
void arange(T start, T next, array& out, size_t size) {
void arange(T start, T next, array& out, size_t size, Stream stream) {
auto ptr = out.data<T>();
auto step_size = next - start;
for (int i = 0; i < size; ++i) {
ptr[i] = start;
start += step_size;
}
auto& encoder = cpu::get_command_encoder(stream);
encoder.set_output_array(out);
encoder.dispatch([ptr, start, step_size, size]() mutable {
for (int i = 0; i < size; ++i) {
ptr[i] = start;
start += step_size;
}
});
}
} // namespace
void arange(
const std::vector<array>& inputs,
array& out,
double start,
double step) {
assert(inputs.size() == 0);
out.set_data(allocator::malloc_or_wait(out.nbytes()));
switch (out.dtype()) {
case bool_:
throw std::runtime_error("Bool type unsupported for arange.");
break;
case uint8:
arange<uint8_t>(start, start + step, out, out.size());
break;
case uint16:
arange<uint16_t>(start, start + step, out, out.size());
break;
case uint32:
arange<uint32_t>(start, start + step, out, out.size());
break;
case uint64:
arange<uint64_t>(start, start + step, out, out.size());
break;
case int8:
arange<int8_t>(start, start + step, out, out.size());
break;
case int16:
arange<int16_t>(start, start + step, out, out.size());
break;
case int32:
arange<int32_t>(start, start + step, out, out.size());
break;
case int64:
arange<int64_t>(start, start + step, out, out.size());
break;
case float16:
arange<float16_t>(start, start + step, out, out.size());
break;
case float32:
arange<float>(start, start + step, out, out.size());
break;
case float64:
arange<double>(start, start + step, out, out.size());
break;
case bfloat16:
arange<bfloat16_t>(start, start + step, out, out.size());
break;
case complex64:
arange<complex64_t>(start, start + step, out, out.size());
break;
}
}
} // namespace mlx::core
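
A side note on the arange closure above: start is captured by value and the lambda is marked mutable so the copy can be advanced inside the loop. A minimal standalone illustration of that capture semantics:

#include <cstdio>
int main() {
  double start = 0.0, step = 0.5;
  // `mutable` lets the lambda modify its own copies of the captures;
  // the caller's `start` is untouched.
  auto task = [start, step]() mutable {
    for (int i = 0; i < 4; ++i) {
      std::printf("%g ", start);
      start += step;
    }
    std::printf("\n");
  };
  task();                     // prints: 0 0.5 1 1.5
  std::printf("%g\n", start); // still 0: only the closure's copy advanced
  return 0;
}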


@@ -3,6 +3,7 @@
#include <cassert>
#include "mlx/backend/common/utils.h"
#include "mlx/backend/cpu/encoder.h"
#include "mlx/primitives.h"
namespace mlx::core {
@@ -10,23 +11,43 @@ namespace mlx::core {
namespace {
template <typename InT, typename OpT>
void arg_reduce(const array& in, array& out, const OpT& op, int axis) {
void arg_reduce(
const array& in,
array& out,
const OpT& op,
int axis,
Stream stream) {
auto axis_size = in.shape()[axis];
auto axis_stride = in.strides()[axis];
Strides strides = in.strides();
Shape shape = in.shape();
strides.erase(strides.begin() + axis);
shape.erase(shape.begin() + axis);
for (uint32_t i = 0; i < out.size(); ++i) {
auto loc = elem_to_loc(i, shape, strides);
auto in_ptr = in.data<InT>() + loc;
uint32_t ind_v = 0;
InT v = (*in_ptr);
for (uint32_t j = 0; j < axis_size; ++j, in_ptr += axis_stride) {
op(j, (*in_ptr), &ind_v, &v);
auto in_ptr = in.data<InT>();
auto out_ptr = out.data<uint32_t>();
auto& encoder = cpu::get_command_encoder(stream);
encoder.set_input_array(in);
encoder.set_output_array(out);
encoder.dispatch([in_ptr,
out_ptr,
axis_size,
axis_stride,
op = std::move(op),
shape = std::move(shape),
strides = std::move(strides),
size = out.size()]() {
for (uint32_t i = 0; i < size; ++i) {
auto loc = elem_to_loc(i, shape, strides);
auto local_in_ptr = in_ptr + loc;
uint32_t ind_v = 0;
InT v = (*local_in_ptr);
for (uint32_t j = 0; j < axis_size; ++j, local_in_ptr += axis_stride) {
op(j, (*local_in_ptr), &ind_v, &v);
}
out_ptr[i] = ind_v;
}
out.data<uint32_t>()[i] = ind_v;
}
});
}
template <typename InT>
@@ -34,7 +55,8 @@ void arg_reduce_dispatch(
const array& in,
array& out,
ArgReduce::ReduceType rtype,
int axis) {
int axis,
Stream stream) {
switch (rtype) {
case ArgReduce::ArgMin: {
auto op = [](auto ind_x, auto x, auto ind_y, auto y) {
@@ -43,7 +65,7 @@ void arg_reduce_dispatch(
(*ind_y) = ind_x;
}
};
arg_reduce<InT>(in, out, op, axis);
arg_reduce<InT>(in, out, op, axis, stream);
break;
}
case ArgReduce::ArgMax: {
@@ -53,7 +75,7 @@ void arg_reduce_dispatch(
(*ind_y) = ind_x;
}
};
arg_reduce<InT>(in, out, op, axis);
arg_reduce<InT>(in, out, op, axis, stream);
break;
}
}
@@ -68,46 +90,46 @@ void ArgReduce::eval_cpu(const std::vector<array>& inputs, array& out) {
switch (in.dtype()) {
case bool_:
arg_reduce_dispatch<bool>(in, out, reduce_type_, axis_);
arg_reduce_dispatch<bool>(in, out, reduce_type_, axis_, stream());
break;
case uint8:
arg_reduce_dispatch<uint8_t>(in, out, reduce_type_, axis_);
arg_reduce_dispatch<uint8_t>(in, out, reduce_type_, axis_, stream());
break;
case uint16:
arg_reduce_dispatch<uint16_t>(in, out, reduce_type_, axis_);
arg_reduce_dispatch<uint16_t>(in, out, reduce_type_, axis_, stream());
break;
case uint32:
arg_reduce_dispatch<uint32_t>(in, out, reduce_type_, axis_);
arg_reduce_dispatch<uint32_t>(in, out, reduce_type_, axis_, stream());
break;
case uint64:
arg_reduce_dispatch<uint64_t>(in, out, reduce_type_, axis_);
arg_reduce_dispatch<uint64_t>(in, out, reduce_type_, axis_, stream());
break;
case int8:
arg_reduce_dispatch<int8_t>(in, out, reduce_type_, axis_);
arg_reduce_dispatch<int8_t>(in, out, reduce_type_, axis_, stream());
break;
case int16:
arg_reduce_dispatch<int16_t>(in, out, reduce_type_, axis_);
arg_reduce_dispatch<int16_t>(in, out, reduce_type_, axis_, stream());
break;
case int32:
arg_reduce_dispatch<int32_t>(in, out, reduce_type_, axis_);
arg_reduce_dispatch<int32_t>(in, out, reduce_type_, axis_, stream());
break;
case int64:
arg_reduce_dispatch<int64_t>(in, out, reduce_type_, axis_);
arg_reduce_dispatch<int64_t>(in, out, reduce_type_, axis_, stream());
break;
case float16:
arg_reduce_dispatch<float16_t>(in, out, reduce_type_, axis_);
arg_reduce_dispatch<float16_t>(in, out, reduce_type_, axis_, stream());
break;
case float32:
arg_reduce_dispatch<float>(in, out, reduce_type_, axis_);
arg_reduce_dispatch<float>(in, out, reduce_type_, axis_, stream());
break;
case bfloat16:
arg_reduce_dispatch<bfloat16_t>(in, out, reduce_type_, axis_);
arg_reduce_dispatch<bfloat16_t>(in, out, reduce_type_, axis_, stream());
break;
case float64:
arg_reduce_dispatch<double>(in, out, reduce_type_, axis_);
arg_reduce_dispatch<double>(in, out, reduce_type_, axis_, stream());
break;
case complex64:
arg_reduce_dispatch<complex64_t>(in, out, reduce_type_, axis_);
arg_reduce_dispatch<complex64_t>(in, out, reduce_type_, axis_, stream());
break;
}
}
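
For reference, elem_to_loc comes from mlx/backend/common/utils.h and maps a flat row-major element index to a strided memory offset. A sketch of the usual definition (an assumption about that helper, not code from this commit):

#include <cstddef>
#include <vector>

// Assumed shape of the helper: peel off the innermost dimension first,
// accumulating index-times-stride for each axis.
inline long long elem_to_loc_sketch(
    size_t elem,
    const std::vector<int>& shape,
    const std::vector<long long>& strides) {
  long long loc = 0;
  for (int i = static_cast<int>(shape.size()) - 1; i >= 0; --i) {
    loc += static_cast<long long>(elem % shape[i]) * strides[i];
    elem /= shape[i];
  }
  return loc;
}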


@@ -16,49 +16,49 @@ namespace mlx::core {
namespace {
template <typename Op>
void comparison_op(const array& a, const array& b, array& out, Op op) {
void comparison_op(const array& a, const array& b, array& out) {
switch (a.dtype()) {
case bool_:
binary_op<bool, bool>(a, b, out, op);
binary_op<bool, bool, Op>(a, b, out);
break;
case uint8:
binary_op<uint8_t, bool>(a, b, out, op);
binary_op<uint8_t, bool, Op>(a, b, out);
break;
case uint16:
binary_op<uint16_t, bool>(a, b, out, op);
binary_op<uint16_t, bool, Op>(a, b, out);
break;
case uint32:
binary_op<uint32_t, bool>(a, b, out, op);
binary_op<uint32_t, bool, Op>(a, b, out);
break;
case uint64:
binary_op<uint64_t, bool>(a, b, out, op);
binary_op<uint64_t, bool, Op>(a, b, out);
break;
case int8:
binary_op<int8_t, bool>(a, b, out, op);
binary_op<int8_t, bool, Op>(a, b, out);
break;
case int16:
binary_op<int16_t, bool>(a, b, out, op);
binary_op<int16_t, bool, Op>(a, b, out);
break;
case int32:
binary_op<int32_t, bool>(a, b, out, op);
binary_op<int32_t, bool, Op>(a, b, out);
break;
case int64:
binary_op<int64_t, bool>(a, b, out, op);
binary_op<int64_t, bool, Op>(a, b, out);
break;
case float16:
binary_op<float16_t, bool>(a, b, out, op);
binary_op<float16_t, bool, Op>(a, b, out);
break;
case float32:
binary_op<float, bool>(a, b, out, op);
binary_op<float, bool, Op>(a, b, out);
break;
case float64:
binary_op<double, bool>(a, b, out, op);
binary_op<double, bool, Op>(a, b, out);
break;
case bfloat16:
binary_op<bfloat16_t, bool>(a, b, out, op);
binary_op<bfloat16_t, bool, Op>(a, b, out);
break;
case complex64:
binary_op<complex64_t, bool>(a, b, out, op);
binary_op<complex64_t, bool, Op>(a, b, out);
break;
}
}
@@ -151,47 +151,47 @@ void Equal::eval_cpu(const std::vector<array>& inputs, array& out) {
if (equal_nan_) {
switch (a.dtype()) {
case float16:
binary_op<float16_t, bool>(a, b, out, detail::NaNEqual());
binary_op<float16_t, bool, detail::NaNEqual>(a, b, out);
break;
case float32:
binary_op<float, bool>(a, b, out, detail::NaNEqual());
binary_op<float, bool, detail::NaNEqual>(a, b, out);
break;
case float64:
binary_op<double, bool>(a, b, out, detail::NaNEqual());
binary_op<double, bool, detail::NaNEqual>(a, b, out);
break;
case bfloat16:
binary_op<bfloat16_t, bool>(a, b, out, detail::NaNEqual());
binary_op<bfloat16_t, bool, detail::NaNEqual>(a, b, out);
break;
case complex64:
binary_op<complex64_t, bool>(a, b, out, detail::NaNEqual());
binary_op<complex64_t, bool, detail::NaNEqual>(a, b, out);
break;
default:
throw std::runtime_error(
"[NanEqual::eval_cpu] Only for floating point types.");
}
} else {
comparison_op(a, b, out, detail::Equal());
comparison_op<detail::Equal>(a, b, out);
}
}
void Greater::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 2);
comparison_op(inputs[0], inputs[1], out, detail::Greater());
comparison_op<detail::Greater>(inputs[0], inputs[1], out);
}
void GreaterEqual::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 2);
comparison_op(inputs[0], inputs[1], out, detail::GreaterEqual());
comparison_op<detail::GreaterEqual>(inputs[0], inputs[1], out);
}
void Less::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 2);
comparison_op(inputs[0], inputs[1], out, detail::Less());
comparison_op<detail::Less>(inputs[0], inputs[1], out);
}
void LessEqual::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 2);
comparison_op(inputs[0], inputs[1], out, detail::LessEqual());
comparison_op<detail::LessEqual>(inputs[0], inputs[1], out);
}
void LogAddExp::eval_cpu(const std::vector<array>& inputs, array& out) {
@@ -200,16 +200,16 @@ void LogAddExp::eval_cpu(const std::vector<array>& inputs, array& out) {
auto& b = inputs[1];
switch (out.dtype()) {
case float16:
binary_op<float16_t>(a, b, out, detail::LogAddExp());
binary_op<float16_t, detail::LogAddExp>(a, b, out);
break;
case float32:
binary_op<float>(a, b, out, detail::LogAddExp());
binary_op<float, detail::LogAddExp>(a, b, out);
break;
case float64:
binary_op<double>(a, b, out, detail::LogAddExp());
binary_op<double, detail::LogAddExp>(a, b, out);
break;
case bfloat16:
binary_op<bfloat16_t>(a, b, out, detail::LogAddExp());
binary_op<bfloat16_t, detail::LogAddExp>(a, b, out);
break;
default:
throw std::runtime_error(
@@ -254,7 +254,7 @@ void Multiply::eval_cpu(const std::vector<array>& inputs, array& out) {
void NotEqual::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 2);
comparison_op(inputs[0], inputs[1], out, detail::NotEqual());
comparison_op<detail::NotEqual>(inputs[0], inputs[1], out);
}
void Power::eval_cpu(const std::vector<array>& inputs, array& out) {


@@ -7,6 +7,8 @@
#include "mlx/array.h"
#include "mlx/backend/common/binary.h"
#include "mlx/backend/common/utils.h"
#include "mlx/backend/cpu/encoder.h"
#include "mlx/primitives.h"
#include "mlx/backend/cpu/simd/simd.h"
@@ -14,22 +16,18 @@ namespace mlx::core {
template <typename Op>
struct VectorScalar {
Op op;
VectorScalar(Op op_) : op(op_) {}
template <typename T, typename U>
void operator()(const T* a, const T* b, U* dst, int size) {
T scalar = *b;
constexpr int N = simd::max_size<T>;
while (size >= N) {
simd::store(dst, op(simd::load<T, N>(a), simd::Simd<T, N>(scalar)));
simd::store(dst, Op{}(simd::load<T, N>(a), simd::Simd<T, N>(scalar)));
dst += N;
a += N;
size -= N;
}
while (size-- > 0) {
*dst = op(*a, scalar);
*dst = Op{}(*a, scalar);
dst++;
a++;
}
@@ -38,22 +36,18 @@ struct VectorScalar {
template <typename Op>
struct ScalarVector {
Op op;
ScalarVector(Op op_) : op(op_) {}
template <typename T, typename U>
void operator()(const T* a, const T* b, U* dst, int size) {
T scalar = *a;
constexpr int N = simd::max_size<T>;
while (size >= N) {
simd::store(dst, op(simd::Simd<T, N>(scalar), simd::load<T, N>(b)));
simd::store(dst, Op{}(simd::Simd<T, N>(scalar), simd::load<T, N>(b)));
dst += N;
b += N;
size -= N;
}
while (size-- > 0) {
*dst = op(scalar, *b);
*dst = Op{}(scalar, *b);
dst++;
b++;
}
@@ -62,22 +56,18 @@ struct ScalarVector {
template <typename Op>
struct VectorVector {
Op op;
VectorVector(Op op_) : op(op_) {}
template <typename T, typename U>
void operator()(const T* a, const T* b, U* dst, int size) {
constexpr int N = simd::max_size<T>;
while (size >= N) {
simd::store(dst, op(simd::load<T, N>(a), simd::load<T, N>(b)));
simd::store(dst, Op{}(simd::load<T, N>(a), simd::load<T, N>(b)));
dst += N;
a += N;
b += N;
size -= N;
}
while (size-- > 0) {
*dst = op(*a, *b);
*dst = Op{}(*a, *b);
dst++;
a++;
b++;
@@ -90,7 +80,6 @@ void binary_op_dims(
const T* a,
const T* b,
U* out,
Op op,
const Shape& shape,
const Strides& a_strides,
const Strides& b_strides,
@@ -104,12 +93,12 @@ void binary_op_dims(
for (int i = 0; i < N; i++) {
if constexpr (D > 1) {
binary_op_dims<T, U, Op, D - 1, Strided>(
a, b, out, op, shape, a_strides, b_strides, out_strides, axis + 1);
a, b, out, shape, a_strides, b_strides, out_strides, axis + 1);
} else {
if constexpr (Strided) {
op(a, b, out, stride_out);
Op{}(a, b, out, stride_out);
} else {
*out = op(*a, *b);
*out = Op{}(*a, *b);
}
}
out += stride_out;
@@ -120,66 +109,38 @@ void binary_op_dims(
template <typename T, typename U, bool Strided, typename Op>
void binary_op_dispatch_dims(
const array& a,
const array& b,
array& out,
Op op,
const T* a,
const T* b,
U* out,
int dim,
int size,
const Shape& shape,
const Strides& a_strides,
const Strides& b_strides,
const Strides& out_strides) {
const T* a_ptr = a.data<T>();
const T* b_ptr = b.data<T>();
U* out_ptr = out.data<U>();
switch (dim) {
case 1:
binary_op_dims<T, U, Op, 1, Strided>(
a_ptr,
b_ptr,
out_ptr,
op,
shape,
a_strides,
b_strides,
out_strides,
0);
a, b, out, shape, a_strides, b_strides, out_strides, 0);
return;
case 2:
binary_op_dims<T, U, Op, 2, Strided>(
a_ptr,
b_ptr,
out_ptr,
op,
shape,
a_strides,
b_strides,
out_strides,
0);
a, b, out, shape, a_strides, b_strides, out_strides, 0);
return;
case 3:
binary_op_dims<T, U, Op, 3, Strided>(
a_ptr,
b_ptr,
out_ptr,
op,
shape,
a_strides,
b_strides,
out_strides,
0);
a, b, out, shape, a_strides, b_strides, out_strides, 0);
return;
}
ContiguousIterator a_it(shape, a_strides, dim - 3);
ContiguousIterator b_it(shape, b_strides, dim - 3);
auto stride = out_strides[dim - 4];
for (int64_t elem = 0; elem < a.size(); elem += stride) {
for (int64_t elem = 0; elem < size; elem += stride) {
binary_op_dims<T, U, Op, 3, Strided>(
a_ptr + a_it.loc,
b_ptr + b_it.loc,
out_ptr + elem,
op,
a + a_it.loc,
b + b_it.loc,
out + elem,
shape,
a_strides,
b_strides,
@@ -191,181 +152,216 @@ void binary_op_dispatch_dims(
}
template <typename T, typename U, typename Op>
void binary_op(const array& a, const array& b, array& out, Op op) {
void binary_op(const array& a, const array& b, array& out) {
auto bopt = get_binary_op_type(a, b);
set_binary_op_output_data(a, b, out, bopt);
// The full computation is scalar scalar so call the base op once
if (bopt == BinaryOpType::ScalarScalar) {
*(out.data<U>()) = op(*a.data<T>(), *b.data<T>());
return;
}
auto a_ptr = a.data<T>();
auto b_ptr = b.data<T>();
// The full computation is scalar vector so delegate to the op
if (bopt == BinaryOpType::ScalarVector) {
ScalarVector{op}(a.data<T>(), b.data<T>(), out.data<U>(), b.data_size());
return;
}
// The full computation is vector scalar so delegate to the op
if (bopt == BinaryOpType::VectorScalar) {
VectorScalar{op}(a.data<T>(), b.data<T>(), out.data<U>(), a.data_size());
return;
}
// The full computation is vector vector so delegate to the op
if (bopt == BinaryOpType::VectorVector) {
VectorVector{op}(a.data<T>(), b.data<T>(), out.data<U>(), out.size());
return;
}
// General computation so let's try to optimize
auto [new_shape, new_strides] = collapse_contiguous_dims(
a.shape(), {a.strides(), b.strides(), out.strides()});
const auto& a_strides = new_strides[0];
const auto& b_strides = new_strides[1];
const auto& strides = new_strides[2];
// Get the left-most dim such that the array is row contiguous after
auto leftmost_rc_dim = [&strides](const auto& arr_strides) {
int d = arr_strides.size() - 1;
for (; d >= 0 && arr_strides[d] == strides[d]; d--) {
auto out_ptr = out.data<U>();
auto& encoder = cpu::get_command_encoder(out.primitive().stream());
encoder.set_input_array(a);
encoder.set_input_array(b);
encoder.set_output_array(out);
encoder.dispatch([bopt,
a_ptr,
b_ptr,
out_ptr,
a_data_size = a.data_size(),
b_data_size = b.data_size(),
size = a.size(),
shape = a.shape(),
a_strides = a.strides(),
b_strides = b.strides(),
strides = out.strides()]() mutable {
if (bopt == BinaryOpType::ScalarScalar) {
*out_ptr = Op{}(*a_ptr, *b_ptr);
return;
}
return d + 1;
};
auto a_rc_dim = leftmost_rc_dim(a_strides);
auto b_rc_dim = leftmost_rc_dim(b_strides);
// Get the left-most dim such that the array is a broadcasted "scalar" after
auto leftmost_s_dim = [](const auto& arr_strides) {
int d = arr_strides.size() - 1;
for (; d >= 0 && arr_strides[d] == 0; d--) {
// The full computation is scalar vector so delegate to the op
if (bopt == BinaryOpType::ScalarVector) {
ScalarVector<Op>{}(a_ptr, b_ptr, out_ptr, b_data_size);
return;
}
return d + 1;
};
auto a_s_dim = leftmost_s_dim(a_strides);
auto b_s_dim = leftmost_s_dim(b_strides);
auto ndim = new_shape.size();
// The full computation is vector scalar so delegate to the op
if (bopt == BinaryOpType::VectorScalar) {
VectorScalar<Op>{}(a_ptr, b_ptr, out_ptr, a_data_size);
return;
}
// Case 1: LxM and FxM where L and F are broadcastable and M is row contiguous
int dim = ndim;
if (int d = std::max(a_rc_dim, b_rc_dim); d < ndim) {
bopt = BinaryOpType::VectorVector;
dim = d;
// Case 2: LxM and Fx1 where L and F are broadcastable and M is row
// The full computation is vector vector so delegate to the op
if (bopt == BinaryOpType::VectorVector) {
VectorVector<Op>{}(a_ptr, b_ptr, out_ptr, size);
return;
}
// General computation so let's try to optimize
auto [new_shape, new_strides] = collapse_contiguous_dims(
shape,
{std::move(a_strides), std::move(b_strides), std::move(strides)});
a_strides = new_strides[0];
b_strides = new_strides[1];
strides = new_strides[2];
// Get the left-most dim such that the array is row contiguous after
auto leftmost_rc_dim = [&strides](const auto& arr_strides) {
int d = arr_strides.size() - 1;
for (; d >= 0 && arr_strides[d] == strides[d]; d--) {
}
return d + 1;
};
auto a_rc_dim = leftmost_rc_dim(a_strides);
auto b_rc_dim = leftmost_rc_dim(b_strides);
// Get the left-most dim such that the array is a broadcasted "scalar" after
auto leftmost_s_dim = [](const auto& arr_strides) {
int d = arr_strides.size() - 1;
for (; d >= 0 && arr_strides[d] == 0; d--) {
}
return d + 1;
};
auto a_s_dim = leftmost_s_dim(a_strides);
auto b_s_dim = leftmost_s_dim(b_strides);
auto ndim = new_shape.size();
// Case 1: LxM and FxM where L and F are broadcastable and M is row
// contiguous
} else if (int d = std::max(a_rc_dim, b_s_dim); d < ndim) {
bopt = BinaryOpType::VectorScalar;
dim = d;
// Case 3: Lx1 and FxM where L and F are broadcastable and M is row
// contiguous
} else if (int d = std::max(a_s_dim, b_rc_dim); d < ndim) {
bopt = BinaryOpType::ScalarVector;
dim = d;
}
int dim = ndim;
if (int d = std::max(a_rc_dim, b_rc_dim); d < ndim) {
bopt = BinaryOpType::VectorVector;
dim = d;
// Case 2: LxM and Fx1 where L and F are broadcastable and M is row
// contiguous
} else if (int d = std::max(a_rc_dim, b_s_dim); d < ndim) {
bopt = BinaryOpType::VectorScalar;
dim = d;
// Case 3: Lx1 and FxM where L and F are broadcastable and M is row
// contiguous
} else if (int d = std::max(a_s_dim, b_rc_dim); d < ndim) {
bopt = BinaryOpType::ScalarVector;
dim = d;
}
// Can be sure dim > 0 since otherwise we would have used one of the fully
// contiguous methods above. Except for the case that the flags do not
// correspond to the underlying contiguity.
if (dim == 0 || strides[dim - 1] < 16) {
bopt = BinaryOpType::General;
dim = ndim;
}
// Can be sure dim > 0 since otherwise we would have used one of the fully
// contiguous methods above. Except for the case that the flags do not
// correspond to the underlying contiguity.
if (dim == 0 || strides[dim - 1] < 16) {
bopt = BinaryOpType::General;
dim = ndim;
}
switch (bopt) {
case BinaryOpType::VectorVector:
binary_op_dispatch_dims<T, U, true>(
a,
b,
out,
VectorVector{op},
dim,
new_shape,
a_strides,
b_strides,
strides);
break;
case BinaryOpType::VectorScalar:
binary_op_dispatch_dims<T, U, true>(
a,
b,
out,
VectorScalar{op},
dim,
new_shape,
a_strides,
b_strides,
strides);
break;
case BinaryOpType::ScalarVector:
binary_op_dispatch_dims<T, U, true>(
a,
b,
out,
ScalarVector{op},
dim,
new_shape,
a_strides,
b_strides,
strides);
break;
default:
binary_op_dispatch_dims<T, U, false>(
a, b, out, op, dim, new_shape, a_strides, b_strides, strides);
break;
}
switch (bopt) {
case BinaryOpType::VectorVector:
binary_op_dispatch_dims<T, U, true, VectorVector<Op>>(
a_ptr,
b_ptr,
out_ptr,
dim,
size,
new_shape,
a_strides,
b_strides,
strides);
break;
case BinaryOpType::VectorScalar:
binary_op_dispatch_dims<T, U, true, VectorScalar<Op>>(
a_ptr,
b_ptr,
out_ptr,
dim,
size,
new_shape,
a_strides,
b_strides,
strides);
break;
case BinaryOpType::ScalarVector:
binary_op_dispatch_dims<T, U, true, ScalarVector<Op>>(
a_ptr,
b_ptr,
out_ptr,
dim,
size,
new_shape,
a_strides,
b_strides,
strides);
break;
default:
binary_op_dispatch_dims<T, U, false, Op>(
a_ptr,
b_ptr,
out_ptr,
dim,
size,
new_shape,
a_strides,
b_strides,
strides);
break;
}
});
}
template <typename T, typename Op>
void binary_op(const array& a, const array& b, array& out) {
binary_op<T, T, Op>(a, b, out);
}
template <typename T, typename Op>
void binary_op(const array& a, const array& b, array& out, Op op) {
binary_op<T, T>(a, b, out, op);
binary_op<T, T, Op>(a, b, out);
}
template <typename Op>
void binary(const array& a, const array& b, array& out, Op op) {
switch (out.dtype()) {
case bool_:
binary_op<bool>(a, b, out, op);
binary_op<bool, Op>(a, b, out);
break;
case uint8:
binary_op<uint8_t>(a, b, out, op);
binary_op<uint8_t, Op>(a, b, out);
break;
case uint16:
binary_op<uint16_t>(a, b, out, op);
binary_op<uint16_t, Op>(a, b, out);
break;
case uint32:
binary_op<uint32_t>(a, b, out, op);
binary_op<uint32_t, Op>(a, b, out);
break;
case uint64:
binary_op<uint64_t>(a, b, out, op);
binary_op<uint64_t, Op>(a, b, out);
break;
case int8:
binary_op<int8_t>(a, b, out, op);
binary_op<int8_t, Op>(a, b, out);
break;
case int16:
binary_op<int16_t>(a, b, out, op);
binary_op<int16_t, Op>(a, b, out);
break;
case int32:
binary_op<int32_t>(a, b, out, op);
binary_op<int32_t, Op>(a, b, out);
break;
case int64:
binary_op<int64_t>(a, b, out, op);
binary_op<int64_t, Op>(a, b, out);
break;
case float16:
binary_op<float16_t>(a, b, out, op);
binary_op<float16_t, Op>(a, b, out);
break;
case float32:
binary_op<float>(a, b, out, op);
binary_op<float, Op>(a, b, out);
break;
case float64:
binary_op<double>(a, b, out, op);
binary_op<double, Op>(a, b, out);
break;
case bfloat16:
binary_op<bfloat16_t>(a, b, out, op);
binary_op<bfloat16_t, Op>(a, b, out);
break;
case complex64:
binary_op<complex64_t>(a, b, out, op);
binary_op<complex64_t, Op>(a, b, out);
break;
}
}
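
The structural change in this header: the operation moved from a stored member (Op op; VectorScalar(Op op_) : op(op_) {}) to a bare template parameter that is default-constructed at each call site (Op{}). Since the detail:: ops are stateless, nothing but the type needs to travel into the dispatched closure. A standalone illustration of the pattern:

#include <cstdio>

struct Add {
  int operator()(int a, int b) const { return a + b; }
};

// The op is a template parameter; no functor object is stored or passed.
template <typename Op>
void apply(const int* a, const int* b, int* dst, int n) {
  for (int i = 0; i < n; ++i) {
    dst[i] = Op{}(a[i], b[i]); // default-construct the stateless op
  }
}

int main() {
  int a[3] = {1, 2, 3}, b[3] = {4, 5, 6}, c[3];
  apply<Add>(a, b, c, 3);
  std::printf("%d %d %d\n", c[0], c[1], c[2]); // prints: 5 7 9
  return 0;
}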


@@ -4,6 +4,8 @@
#include "mlx/backend/common/utils.h"
#include "mlx/backend/cpu/binary.h"
#include "mlx/backend/cpu/encoder.h"
#include "mlx/primitives.h"
namespace mlx::core {
@@ -55,65 +57,81 @@ void binary_op_dispatch_dims(
const array& b,
array& out_a,
array& out_b,
Stream stream,
Op op) {
auto& encoder = cpu::get_command_encoder(stream);
encoder.set_input_array(a);
encoder.set_input_array(b);
encoder.set_output_array(out_a);
encoder.set_output_array(out_b);
auto [shape, strides] = collapse_contiguous_dims(
a.shape(), {a.strides(), b.strides(), out_a.strides()});
const auto& a_strides = strides[0];
const auto& b_strides = strides[1];
const auto& out_strides = strides[2];
const T* a_ptr = a.data<T>();
const T* b_ptr = b.data<T>();
U* out_a_ptr = out_a.data<U>();
U* out_b_ptr = out_b.data<U>();
int ndim = shape.size();
switch (ndim) {
case 1:
binary_op_dims<T, U, Op, 1>(
a_ptr,
b_ptr,
out_a_ptr,
out_b_ptr,
op,
shape,
a_strides,
b_strides,
out_strides,
0);
return;
case 2:
binary_op_dims<T, U, Op, 2>(
a_ptr,
b_ptr,
out_a_ptr,
out_b_ptr,
op,
shape,
a_strides,
b_strides,
out_strides,
0);
return;
}
encoder.dispatch([a_ptr,
b_ptr,
out_a_ptr,
out_b_ptr,
size = a.size(),
shape = std::move(shape),
strides = std::move(strides),
op = std::move(op)]() {
const auto& a_strides = strides[0];
const auto& b_strides = strides[1];
const auto& out_strides = strides[2];
int ndim = shape.size();
switch (ndim) {
case 1:
binary_op_dims<T, U, Op, 1>(
a_ptr,
b_ptr,
out_a_ptr,
out_b_ptr,
op,
shape,
a_strides,
b_strides,
out_strides,
0);
return;
case 2:
binary_op_dims<T, U, Op, 2>(
a_ptr,
b_ptr,
out_a_ptr,
out_b_ptr,
op,
shape,
a_strides,
b_strides,
out_strides,
0);
return;
}
ContiguousIterator a_it(shape, a_strides, ndim - 2);
ContiguousIterator b_it(shape, b_strides, ndim - 2);
auto stride = out_strides[ndim - 3];
for (size_t elem = 0; elem < a.size(); elem += stride) {
binary_op_dims<T, U, Op, 2>(
a_ptr + a_it.loc,
b_ptr + b_it.loc,
out_a_ptr + elem,
out_b_ptr + elem,
op,
shape,
a_strides,
b_strides,
out_strides,
ndim - 2);
a_it.step();
b_it.step();
}
ContiguousIterator a_it(shape, a_strides, ndim - 2);
ContiguousIterator b_it(shape, b_strides, ndim - 2);
auto stride = out_strides[ndim - 3];
for (size_t elem = 0; elem < size; elem += stride) {
binary_op_dims<T, U, Op, 2>(
a_ptr + a_it.loc,
b_ptr + b_it.loc,
out_a_ptr + elem,
out_b_ptr + elem,
op,
shape,
a_strides,
b_strides,
out_strides,
ndim - 2);
a_it.step();
b_it.step();
}
});
}
template <typename T, typename U = T, typename Op>
@@ -128,40 +146,71 @@ void binary_op(
set_binary_op_output_data(a, b, out_a, bopt);
set_binary_op_output_data(a, b, out_b, bopt);
auto stream = out_a.primitive().stream();
// The full computation is scalar scalar so call the base op once
if (bopt == BinaryOpType::General) {
binary_op_dispatch_dims<T, U, Op>(a, b, out_a, out_b, op);
binary_op_dispatch_dims<T, U, Op>(a, b, out_a, out_b, stream, op);
return;
}
auto& encoder = cpu::get_command_encoder(stream);
encoder.set_input_array(a);
encoder.set_input_array(b);
encoder.set_output_array(out_a);
encoder.set_output_array(out_b);
auto a_ptr = a.data<T>();
auto b_ptr = b.data<T>();
auto out_a_ptr = out_a.data<U>();
auto out_b_ptr = out_b.data<U>();
if (bopt == BinaryOpType::ScalarScalar) {
std::tie(*out_a_ptr, *out_b_ptr) = op(*a_ptr, *b_ptr);
encoder.dispatch(
[a_ptr, b_ptr, out_a_ptr, out_b_ptr, op = std::move(op)]() mutable {
std::tie(*out_a_ptr, *out_b_ptr) = op(*a_ptr, *b_ptr);
});
} else if (bopt == BinaryOpType::ScalarVector) {
for (size_t i = 0; i < b.size(); ++i) {
std::tie(*out_a_ptr, *out_b_ptr) = op(*a_ptr, *b_ptr);
out_a_ptr++;
out_b_ptr++;
b_ptr++;
}
encoder.dispatch([a_ptr,
b_ptr,
out_a_ptr,
out_b_ptr,
size = b.size(),
op = std::move(op)]() mutable {
for (size_t i = 0; i < size; ++i) {
std::tie(*out_a_ptr, *out_b_ptr) = op(*a_ptr, *b_ptr);
out_a_ptr++;
out_b_ptr++;
b_ptr++;
}
});
} else if (bopt == BinaryOpType::VectorScalar) {
for (size_t i = 0; i < a.size(); ++i) {
std::tie(*out_a_ptr, *out_b_ptr) = op(*a_ptr, *b_ptr);
out_a_ptr++;
out_b_ptr++;
a_ptr++;
}
encoder.dispatch([a_ptr,
b_ptr,
out_a_ptr,
out_b_ptr,
size = a.size(),
op = std::move(op)]() mutable {
for (size_t i = 0; i < size; ++i) {
std::tie(*out_a_ptr, *out_b_ptr) = op(*a_ptr, *b_ptr);
out_a_ptr++;
out_b_ptr++;
a_ptr++;
}
});
} else { // VectorVector
for (size_t i = 0; i < a.size(); ++i) {
std::tie(*out_a_ptr, *out_b_ptr) = op(*a_ptr, *b_ptr);
out_a_ptr++;
out_b_ptr++;
a_ptr++;
b_ptr++;
}
encoder.dispatch([a_ptr,
b_ptr,
out_a_ptr,
out_b_ptr,
size = a.size(),
op = std::move(op)]() mutable {
for (size_t i = 0; i < size; ++i) {
std::tie(*out_a_ptr, *out_b_ptr) = op(*a_ptr, *b_ptr);
out_a_ptr++;
out_b_ptr++;
a_ptr++;
b_ptr++;
}
});
}
}
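
The two-output variant follows the same recipe: each contiguity case becomes one dispatched closure, and std::tie unpacks the op's pair result into both outputs. A standalone illustration with a divmod-style op (illustrative, not the commit's own op set):

#include <cstdio>
#include <tuple>
#include <utility>

struct DivMod {
  std::pair<int, int> operator()(int a, int b) const {
    return {a / b, a % b};
  }
};

int main() {
  const int a[3] = {7, 9, 11};
  int quot[3], rem[3];
  for (int i = 0; i < 3; ++i) {
    // std::tie writes both results through references, matching the
    // *out_a_ptr / *out_b_ptr pattern above.
    std::tie(quot[i], rem[i]) = DivMod{}(a[i], 3);
  }
  std::printf("%d %d %d | %d %d %d\n",
              quot[0], quot[1], quot[2], rem[0], rem[1], rem[2]);
  // prints: 2 3 3 | 1 0 2
  return 0;
}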


@@ -2,6 +2,7 @@
#include "mlx/allocator.h"
#include "mlx/backend/cpu/copy.h"
#include "mlx/backend/cpu/encoder.h"
#include "mlx/backend/cpu/lapack.h"
#include "mlx/linalg.h"
#include "mlx/primitives.h"
@@ -9,7 +10,7 @@
namespace mlx::core {
template <typename T>
void cholesky_impl(const array& a, array& factor, bool upper) {
void cholesky_impl(const array& a, array& factor, bool upper, Stream stream) {
// Lapack uses the column-major convention. We take advantage of the fact that
// the matrix should be symmetric:
// (A)ᵀ = A
@@ -17,60 +18,63 @@ void cholesky_impl(const array& a, array& factor, bool upper) {
// triangular matrix, so uplo is the opposite of what we would expect from
// upper
char uplo = (upper) ? 'L' : 'U';
// The decomposition is computed in place, so just copy the input to the
// output.
copy(
a,
factor,
a.flags().row_contiguous ? CopyType::Vector : CopyType::General);
a.flags().row_contiguous ? CopyType::Vector : CopyType::General,
stream);
const int N = a.shape(-1);
const size_t num_matrices = a.size() / (N * N);
auto& encoder = cpu::get_command_encoder(stream);
encoder.set_output_array(factor);
encoder.dispatch([matrix = factor.data<T>(),
upper,
N = a.shape(-1),
size = a.size()]() mutable {
char uplo = (upper) ? 'L' : 'U';
size_t num_matrices = size / (N * N);
for (int i = 0; i < num_matrices; i++) {
// Compute Cholesky factorization.
int info;
potrf<T>(
/* uplo = */ &uplo,
/* n = */ &N,
/* a = */ matrix,
/* lda = */ &N,
/* info = */ &info);
T* matrix = factor.data<T>();
for (int i = 0; i < num_matrices; i++) {
// Compute Cholesky factorization.
int info;
potrf<T>(
/* uplo = */ &uplo,
/* n = */ &N,
/* a = */ matrix,
/* lda = */ &N,
/* info = */ &info);
// TODO: We do nothing when the matrix is not positive semi-definite
// because throwing an error would result in a crash. If we figure out how
// to catch errors from the implementation we should throw.
if (info < 0) {
std::stringstream msg;
msg << "[cholesky] Cholesky decomposition failed with error code "
<< info;
throw std::runtime_error(msg.str());
}
// Zero out the upper/lower triangle while advancing the pointer to the
// next matrix at the same time.
for (int row = 0; row < N; row++) {
if (upper) {
std::fill(matrix, matrix + row, 0);
} else {
std::fill(matrix + row + 1, matrix + N, 0);
// TODO: We do nothing when the matrix is not positive semi-definite
// because throwing an error would result in a crash. If we figure out how
// to catch errors from the implementation we should throw.
if (info < 0) {
std::stringstream msg;
msg << "[Cholesky::eval_cpu] Cholesky decomposition failed with error code "
<< info;
throw std::runtime_error(msg.str());
}
// Zero out the upper/lower triangle while advancing the pointer to the
// next matrix at the same time.
for (int row = 0; row < N; row++) {
if (upper) {
std::fill(matrix, matrix + row, 0);
} else {
std::fill(matrix + row + 1, matrix + N, 0);
}
matrix += N;
}
matrix += N;
}
}
});
}
void Cholesky::eval_cpu(const std::vector<array>& inputs, array& output) {
switch (inputs[0].dtype()) {
case float32:
cholesky_impl<float>(inputs[0], output, upper_);
cholesky_impl<float>(inputs[0], output, upper_, stream());
break;
case float64:
cholesky_impl<double>(inputs[0], output, upper_);
cholesky_impl<double>(inputs[0], output, upper_, stream());
break;
default:
throw std::runtime_error(


@@ -11,6 +11,7 @@
#include "mlx/backend/common/compiled.h"
#include "mlx/backend/cpu/compiled_preamble.h"
#include "mlx/backend/cpu/encoder.h"
#include "mlx/backend/cpu/jit_compiler.h"
#include "mlx/device.h"
#include "mlx/graph_utils.h"
@@ -288,6 +289,7 @@ void Compiled::eval_cpu(
// Figure out which kernel we are using
auto& shape = outputs[0].shape();
auto contiguous = compiled_check_contiguity(inputs, shape);
auto& encoder = cpu::get_command_encoder(stream());
// Handle all broadcasting and collect function input arguments
std::vector<void*> args;
@@ -298,6 +300,7 @@ void Compiled::eval_cpu(
continue;
}
auto& x = inputs[i];
encoder.set_input_array(x);
args.push_back((void*)x.data<void>());
if (contiguous || is_scalar(x)) {
@@ -356,18 +359,25 @@ void Compiled::eval_cpu(
});
compiled_allocate_outputs(
inputs, outputs, inputs_, constant_ids_, contiguous, false);
inputs, outputs, inputs_, constant_ids_, contiguous);
for (auto& x : outputs) {
args.push_back(x.data<void>());
encoder.set_output_array(x);
}
Shape out_shape;
if (!contiguous) {
args.push_back((void*)outputs[0].shape().data());
out_shape = outputs[0].shape();
args.push_back((void*)out_shape.data());
} else {
args.push_back((void*)outputs[0].data_size());
}
auto fun = (void (*)(void**))fn_ptr;
fun(args.data());
encoder.dispatch(
[fun,
args = std::move(args),
strides = std::move(strides),
out_shape = std::move(out_shape)]() mutable { fun(args.data()); });
}
} // namespace mlx::core
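
One subtlety in Compiled::eval_cpu above: args holds raw pointers, including one into out_shape, so the dispatch closure must take ownership of both containers. Moving a std::vector transfers its heap buffer, so pointers taken before the move stay valid. A standalone illustration of the hazard being avoided (hypothetical names):

#include <cstdio>
#include <functional>
#include <vector>

std::function<void()> make_task() {
  std::vector<int> out_shape = {2, 3};
  std::vector<void*> args = {out_shape.data()};
  // Move both containers into the closure so args[0] stays valid; if they
  // stayed on this stack frame, the deferred call would read freed memory.
  return [args = std::move(args),
          out_shape = std::move(out_shape)]() {
    auto* shape = static_cast<int*>(args[0]);
    std::printf("%d %d\n", shape[0], shape[1]); // prints: 2 3
  };
}

int main() {
  auto task = make_task();
  task(); // safe: the closure owns the storage the pointer targets
  return 0;
}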

(File diff suppressed because it is too large.)


@@ -5,6 +5,7 @@
#include "mlx/allocator.h"
#include "mlx/backend/common/utils.h"
#include "mlx/backend/cpu/copy.h"
#include "mlx/backend/cpu/encoder.h"
#include "mlx/backend/cpu/simd/simd.h"
namespace mlx::core {
@@ -12,20 +13,29 @@ namespace mlx::core {
namespace {
template <typename SrcT, typename DstT>
void copy_single(const array& src, array& dst) {
auto val = static_cast<DstT>(src.data<SrcT>()[0]);
void copy_single(const array& src, array& dst, Stream stream) {
auto src_ptr = src.data<SrcT>();
auto dst_ptr = dst.data<DstT>();
for (int i = 0; i < dst.size(); ++i) {
dst_ptr[i] = val;
}
auto& encoder = cpu::get_command_encoder(stream);
encoder.set_input_array(src);
encoder.set_output_array(dst);
encoder.dispatch([src_ptr, dst_ptr, size = dst.size()]() {
auto val = static_cast<DstT>(src_ptr[0]);
std::fill_n(dst_ptr, size, val);
});
}
template <typename SrcT, typename DstT>
void copy_vector(const array& src, array& dst) {
void copy_vector(const array& src, array& dst, Stream stream) {
auto src_ptr = src.data<SrcT>();
auto dst_ptr = dst.data<DstT>();
size_t size = src.data_size();
std::copy(src_ptr, src_ptr + src.data_size(), dst_ptr);
auto& encoder = cpu::get_command_encoder(stream);
encoder.set_input_array(src);
encoder.set_output_array(dst);
encoder.dispatch([src_ptr, dst_ptr, size = src.data_size()]() {
std::copy(src_ptr, src_ptr + size, dst_ptr);
});
}
template <typename SrcT, typename DstT, int D>
@@ -56,151 +66,220 @@ template <typename SrcT, typename DstT>
void copy_general_general(
const array& src,
array& dst,
Stream stream,
const Shape& data_shape,
const Strides& i_strides,
const Strides& o_strides,
int64_t i_offset,
int64_t o_offset) {
if (data_shape.empty()) {
auto val = static_cast<DstT>(*(src.data<SrcT>() + i_offset));
auto dst_ptr = dst.data<DstT>() + o_offset;
*dst_ptr = val;
return;
}
auto [shape, strides] =
collapse_contiguous_dims(data_shape, {i_strides, o_strides});
int64_t o_offset,
const std::optional<array>& dynamic_i_offset,
const std::optional<array>& dynamic_o_offset) {
auto src_ptr = src.data<SrcT>() + i_offset;
auto dst_ptr = dst.data<DstT>() + o_offset;
int ndim = shape.size();
if (ndim == 1) {
copy_dims<SrcT, DstT, 1>(
src_ptr, dst_ptr, shape, strides[0], strides[1], 0);
return;
} else if (ndim == 2) {
copy_dims<SrcT, DstT, 2>(
src_ptr, dst_ptr, shape, strides[0], strides[1], 0);
return;
} else if (ndim == 3) {
copy_dims<SrcT, DstT, 3>(
src_ptr, dst_ptr, shape, strides[0], strides[1], 0);
return;
}
ContiguousIterator in(shape, strides[0], ndim - 3);
ContiguousIterator out(shape, strides[1], ndim - 3);
auto stride = std::accumulate(
shape.end() - 3, shape.end(), 1, std::multiplies<int64_t>());
for (int64_t elem = 0; elem < src.size(); elem += stride) {
copy_dims<SrcT, DstT, 3>(
src_ptr + in.loc,
dst_ptr + out.loc,
shape,
strides[0],
strides[1],
ndim - 3);
in.step();
out.step();
}
auto i_offset_ptr =
dynamic_i_offset ? dynamic_i_offset->data<int64_t>() : nullptr;
auto o_offset_ptr =
dynamic_o_offset ? dynamic_o_offset->data<int64_t>() : nullptr;
auto& encoder = cpu::get_command_encoder(stream);
encoder.set_input_array(src);
encoder.set_output_array(dst);
encoder.dispatch([src_ptr,
dst_ptr,
size = src.size(),
data_shape = data_shape,
i_strides = i_strides,
o_strides = o_strides,
i_offset_ptr,
o_offset_ptr]() mutable {
if (data_shape.empty()) {
auto val = static_cast<DstT>(*src_ptr);
*dst_ptr = val;
return;
}
auto [shape, strides] =
collapse_contiguous_dims(data_shape, {i_strides, o_strides});
int ndim = shape.size();
if (ndim < 3) {
if (i_offset_ptr) {
src_ptr += i_offset_ptr[0];
}
if (o_offset_ptr) {
dst_ptr += o_offset_ptr[0];
}
if (ndim == 1) {
copy_dims<SrcT, DstT, 1>(
src_ptr, dst_ptr, shape, strides[0], strides[1], 0);
} else if (ndim == 2) {
copy_dims<SrcT, DstT, 2>(
src_ptr, dst_ptr, shape, strides[0], strides[1], 0);
} else if (ndim == 3) {
copy_dims<SrcT, DstT, 3>(
src_ptr, dst_ptr, shape, strides[0], strides[1], 0);
}
return;
}
if (i_offset_ptr) {
src_ptr += i_offset_ptr[0];
}
if (o_offset_ptr) {
dst_ptr += o_offset_ptr[0];
}
ContiguousIterator in(shape, strides[0], ndim - 3);
ContiguousIterator out(shape, strides[1], ndim - 3);
auto stride = std::accumulate(
shape.end() - 3, shape.end(), 1, std::multiplies<int64_t>());
for (int64_t elem = 0; elem < size; elem += stride) {
copy_dims<SrcT, DstT, 3>(
src_ptr + in.loc,
dst_ptr + out.loc,
shape,
strides[0],
strides[1],
ndim - 3);
in.step();
out.step();
}
});
}
template <typename SrcT, typename DstT>
inline void copy_general_general(const array& src, array& dst) {
inline void copy_general_general(const array& src, array& dst, Stream stream) {
copy_general_general<SrcT, DstT>(
src, dst, src.shape(), src.strides(), dst.strides(), 0, 0);
src,
dst,
stream,
src.shape(),
src.strides(),
dst.strides(),
0,
0,
std::nullopt,
std::nullopt);
}
template <typename SrcT, typename DstT>
void copy_general(
const array& src,
array& dst,
Stream stream,
const Shape& data_shape,
const Strides& i_strides,
const Strides&,
int64_t i_offset,
int64_t o_offset) {
int64_t o_offset,
const std::optional<array>& dynamic_i_offset,
const std::optional<array>& dynamic_o_offset) {
copy_general_general<SrcT, DstT>(
src,
dst,
stream,
data_shape,
i_strides,
make_contiguous_strides(data_shape),
i_offset,
o_offset);
o_offset,
dynamic_i_offset,
dynamic_o_offset);
}
template <typename SrcT, typename DstT>
inline void copy_general(const array& src, array& dst) {
inline void copy_general(const array& src, array& dst, Stream stream) {
copy_general_general<SrcT, DstT>(
src,
dst,
stream,
src.shape(),
src.strides(),
make_contiguous_strides(src.shape()),
0,
0);
0,
std::nullopt,
std::nullopt);
}
template <typename SrcT, typename DstT, typename... Args>
void copy(const array& src, array& dst, CopyType ctype, Args&&... args) {
void copy(
const array& src,
array& dst,
CopyType ctype,
Stream stream,
Args&&... args) {
switch (ctype) {
case CopyType::Scalar:
copy_single<SrcT, DstT>(src, dst);
copy_single<SrcT, DstT>(src, dst, stream);
return;
case CopyType::Vector:
copy_vector<SrcT, DstT>(src, dst);
copy_vector<SrcT, DstT>(src, dst, stream);
return;
case CopyType::General:
copy_general<SrcT, DstT>(src, dst, std::forward<Args>(args)...);
copy_general<SrcT, DstT>(src, dst, stream, std::forward<Args>(args)...);
return;
case CopyType::GeneralGeneral:
copy_general_general<SrcT, DstT>(src, dst, std::forward<Args>(args)...);
copy_general_general<SrcT, DstT>(
src, dst, stream, std::forward<Args>(args)...);
return;
}
}
template <typename SrcT, typename... Args>
void copy(const array& src, array& dst, CopyType ctype, Args&&... args) {
void copy(
const array& src,
array& dst,
CopyType ctype,
Stream stream,
Args&&... args) {
switch (dst.dtype()) {
case bool_:
copy<SrcT, bool>(src, dst, ctype, std::forward<Args>(args)...);
copy<SrcT, bool>(src, dst, ctype, stream, std::forward<Args>(args)...);
break;
case uint8:
copy<SrcT, uint8_t>(src, dst, ctype, std::forward<Args>(args)...);
copy<SrcT, uint8_t>(src, dst, ctype, stream, std::forward<Args>(args)...);
break;
case uint16:
copy<SrcT, uint16_t>(src, dst, ctype, std::forward<Args>(args)...);
copy<SrcT, uint16_t>(
src, dst, ctype, stream, std::forward<Args>(args)...);
break;
case uint32:
copy<SrcT, uint32_t>(src, dst, ctype, std::forward<Args>(args)...);
copy<SrcT, uint32_t>(
src, dst, ctype, stream, std::forward<Args>(args)...);
break;
case uint64:
copy<SrcT, uint64_t>(src, dst, ctype, std::forward<Args>(args)...);
copy<SrcT, uint64_t>(
src, dst, ctype, stream, std::forward<Args>(args)...);
break;
case int8:
copy<SrcT, int8_t>(src, dst, ctype, std::forward<Args>(args)...);
copy<SrcT, int8_t>(src, dst, ctype, stream, std::forward<Args>(args)...);
break;
case int16:
copy<SrcT, int16_t>(src, dst, ctype, std::forward<Args>(args)...);
copy<SrcT, int16_t>(src, dst, ctype, stream, std::forward<Args>(args)...);
break;
case int32:
copy<SrcT, int32_t>(src, dst, ctype, std::forward<Args>(args)...);
copy<SrcT, int32_t>(src, dst, ctype, stream, std::forward<Args>(args)...);
break;
case int64:
copy<SrcT, int64_t>(src, dst, ctype, std::forward<Args>(args)...);
copy<SrcT, int64_t>(src, dst, ctype, stream, std::forward<Args>(args)...);
break;
case float16:
copy<SrcT, float16_t>(src, dst, ctype, std::forward<Args>(args)...);
copy<SrcT, float16_t>(
src, dst, ctype, stream, std::forward<Args>(args)...);
break;
case float32:
copy<SrcT, float>(src, dst, ctype, std::forward<Args>(args)...);
copy<SrcT, float>(src, dst, ctype, stream, std::forward<Args>(args)...);
break;
case float64:
copy<SrcT, double>(src, dst, ctype, std::forward<Args>(args)...);
copy<SrcT, double>(src, dst, ctype, stream, std::forward<Args>(args)...);
break;
case bfloat16:
copy<SrcT, bfloat16_t>(src, dst, ctype, std::forward<Args>(args)...);
copy<SrcT, bfloat16_t>(
src, dst, ctype, stream, std::forward<Args>(args)...);
break;
case complex64:
copy<SrcT, complex64_t>(src, dst, ctype, std::forward<Args>(args)...);
copy<SrcT, complex64_t>(
src, dst, ctype, stream, std::forward<Args>(args)...);
break;
}
}
@@ -210,84 +289,71 @@ inline void copy_inplace_dispatch(
const array& src,
array& dst,
CopyType ctype,
Stream stream,
Args&&... args) {
switch (src.dtype()) {
case bool_:
copy<bool>(src, dst, ctype, std::forward<Args>(args)...);
copy<bool>(src, dst, ctype, stream, std::forward<Args>(args)...);
break;
case uint8:
copy<uint8_t>(src, dst, ctype, std::forward<Args>(args)...);
copy<uint8_t>(src, dst, ctype, stream, std::forward<Args>(args)...);
break;
case uint16:
copy<uint16_t>(src, dst, ctype, std::forward<Args>(args)...);
copy<uint16_t>(src, dst, ctype, stream, std::forward<Args>(args)...);
break;
case uint32:
copy<uint32_t>(src, dst, ctype, std::forward<Args>(args)...);
copy<uint32_t>(src, dst, ctype, stream, std::forward<Args>(args)...);
break;
case uint64:
copy<uint64_t>(src, dst, ctype, std::forward<Args>(args)...);
copy<uint64_t>(src, dst, ctype, stream, std::forward<Args>(args)...);
break;
case int8:
copy<int8_t>(src, dst, ctype, std::forward<Args>(args)...);
copy<int8_t>(src, dst, ctype, stream, std::forward<Args>(args)...);
break;
case int16:
copy<int16_t>(src, dst, ctype, std::forward<Args>(args)...);
copy<int16_t>(src, dst, ctype, stream, std::forward<Args>(args)...);
break;
case int32:
copy<int32_t>(src, dst, ctype, std::forward<Args>(args)...);
copy<int32_t>(src, dst, ctype, stream, std::forward<Args>(args)...);
break;
case int64:
copy<int64_t>(src, dst, ctype, std::forward<Args>(args)...);
copy<int64_t>(src, dst, ctype, stream, std::forward<Args>(args)...);
break;
case float16:
copy<float16_t>(src, dst, ctype, std::forward<Args>(args)...);
copy<float16_t>(src, dst, ctype, stream, std::forward<Args>(args)...);
break;
case float32:
copy<float>(src, dst, ctype, std::forward<Args>(args)...);
copy<float>(src, dst, ctype, stream, std::forward<Args>(args)...);
break;
case float64:
copy<double>(src, dst, ctype, std::forward<Args>(args)...);
copy<double>(src, dst, ctype, stream, std::forward<Args>(args)...);
break;
case bfloat16:
copy<bfloat16_t>(src, dst, ctype, std::forward<Args>(args)...);
copy<bfloat16_t>(src, dst, ctype, stream, std::forward<Args>(args)...);
break;
case complex64:
copy<complex64_t>(src, dst, ctype, std::forward<Args>(args)...);
copy<complex64_t>(src, dst, ctype, stream, std::forward<Args>(args)...);
break;
}
}
} // namespace
void copy_inplace(const array& src, array& dst, CopyType ctype) {
copy_inplace_dispatch(src, dst, ctype);
void copy_inplace(const array& src, array& dst, CopyType ctype, Stream stream) {
copy_inplace_dispatch(src, dst, ctype, stream);
}
void copy(const array& src, array& dst, CopyType ctype) {
// Allocate the output
switch (ctype) {
case CopyType::Vector:
if (src.is_donatable() && src.itemsize() == dst.itemsize()) {
dst.copy_shared_buffer(src);
} else {
auto size = src.data_size();
dst.set_data(
allocator::malloc_or_wait(size * dst.itemsize()),
size,
src.strides(),
src.flags());
}
break;
case CopyType::Scalar:
case CopyType::General:
case CopyType::GeneralGeneral:
dst.set_data(allocator::malloc_or_wait(dst.nbytes()));
break;
void copy(const array& src, array& dst, CopyType ctype, Stream stream) {
bool donated = set_copy_output_data(src, dst, ctype);
if (donated && src.dtype() == dst.dtype()) {
// If the output has the same type as the input then there is nothing to
// copy, just use the buffer.
return;
}
if (ctype == CopyType::GeneralGeneral) {
ctype = CopyType::General;
}
copy_inplace(src, dst, ctype);
copy_inplace(src, dst, ctype, stream);
}
void copy_inplace(
@@ -298,7 +364,10 @@ void copy_inplace(
const Strides& o_strides,
int64_t i_offset,
int64_t o_offset,
CopyType ctype) {
CopyType ctype,
Stream stream,
const std::optional<array>& dynamic_i_offset, /* = std::nullopt */
const std::optional<array>& dynamic_o_offset /* = std::nullopt */) {
switch (ctype) {
case CopyType::General:
case CopyType::GeneralGeneral:
@@ -306,15 +375,18 @@ void copy_inplace(
src,
dst,
ctype,
stream,
data_shape,
i_strides,
o_strides,
i_offset,
o_offset);
o_offset,
dynamic_i_offset,
dynamic_o_offset);
break;
case CopyType::Scalar:
case CopyType::Vector:
copy_inplace_dispatch(src, dst, ctype);
copy_inplace_dispatch(src, dst, ctype, stream);
}
}
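
The new dynamic_i_offset / dynamic_o_offset parameters let a copy's offsets live in array memory that an earlier task fills in: the closure reads offset_ptr[0] when it executes, not when it is recorded. A standalone sketch of that read-at-execution-time idea (names hypothetical):

#include <cstdint>
#include <cstdio>

int main() {
  int64_t offset = 0; // stands in for a dynamic offset array's storage
  const float src[4] = {10, 20, 30, 40};
  float dst[2] = {0, 0};
  const int64_t* offset_ptr = &offset;
  // Record: the closure holds the *pointer*, not the offset's value.
  auto task = [&src, &dst, offset_ptr] {
    const float* s = src + offset_ptr[0]; // dereferenced at run time
    dst[0] = s[0];
    dst[1] = s[1];
  };
  offset = 2; // an "earlier" scheduled task writes the offset
  task();     // now runs with the up-to-date offset
  std::printf("%g %g\n", dst[0], dst[1]); // prints: 30 40
  return 0;
}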


@@ -2,14 +2,16 @@
#pragma once
#include <optional>
#include "mlx/array.h"
#include "mlx/backend/common/copy.h"
#include "mlx/backend/common/utils.h"
namespace mlx::core {
void copy(const array& src, array& dst, CopyType ctype);
void copy_inplace(const array& src, array& dst, CopyType ctype);
void copy(const array& src, array& dst, CopyType ctype, Stream stream);
void copy_inplace(const array& src, array& dst, CopyType ctype, Stream stream);
void copy_inplace(
const array& src,
@@ -19,6 +21,9 @@ void copy_inplace(
const Strides& o_strides,
int64_t i_offset,
int64_t o_offset,
CopyType ctype);
CopyType ctype,
Stream stream,
const std::optional<array>& dynamic_i_offset = std::nullopt,
const std::optional<array>& dynamic_o_offset = std::nullopt);
} // namespace mlx::core


@@ -0,0 +1,94 @@
// Copyright © 2024 Apple Inc.
#include <cassert>
#include "mlx/allocator.h"
#include "mlx/backend/cpu/copy.h"
#include "mlx/backend/cpu/encoder.h"
#include "mlx/distributed/primitives.h"
namespace mlx::core::distributed {
std::pair<array, bool> ensure_row_contiguous(const array& arr, Stream stream) {
if (arr.flags().row_contiguous) {
return {arr, false};
} else {
array arr_copy(arr.shape(), arr.dtype(), nullptr, {});
copy(arr, arr_copy, CopyType::General, stream);
return {arr_copy, true};
}
};
void AllReduce::eval_cpu(
const std::vector<array>& inputs,
std::vector<array>& outputs) {
assert(inputs.size() == 1);
assert(outputs.size() == 1);
auto donate_or_copy = [s = stream()](const array& in, array& out) {
if (in.flags().row_contiguous) {
if (in.is_donatable()) {
out.copy_shared_buffer(in);
} else {
out.set_data(allocator::malloc_or_wait(out.nbytes()));
}
return in;
} else {
array arr_copy(in.shape(), in.dtype(), nullptr, {});
copy(in, arr_copy, CopyType::General, s);
out.copy_shared_buffer(arr_copy);
return arr_copy;
}
};
auto in = donate_or_copy(inputs[0], outputs[0]);
switch (reduce_type_) {
case Sum:
distributed::detail::all_sum(group(), in, outputs[0], stream());
break;
default:
throw std::runtime_error("Only all reduce sum is supported for now");
}
}
void AllGather::eval_cpu(
const std::vector<array>& inputs,
std::vector<array>& outputs) {
assert(inputs.size() == 1);
assert(outputs.size() == 1);
auto [in, copied] = ensure_row_contiguous(inputs[0], stream());
outputs[0].set_data(allocator::malloc_or_wait(outputs[0].nbytes()));
distributed::detail::all_gather(group(), in, outputs[0], stream());
if (copied) {
auto& enc = cpu::get_command_encoder(stream());
enc.add_temporary(in);
}
}
void Send::eval_cpu(
const std::vector<array>& inputs,
std::vector<array>& outputs) {
assert(inputs.size() == 1);
assert(outputs.size() == 1);
auto [in, copied] = ensure_row_contiguous(inputs[0], stream());
distributed::detail::send(group(), in, dst_, stream());
outputs[0].copy_shared_buffer(inputs[0]);
if (copied) {
auto& enc = cpu::get_command_encoder(stream());
enc.add_temporary(in);
}
}
void Recv::eval_cpu(
const std::vector<array>& inputs,
std::vector<array>& outputs) {
assert(inputs.size() == 0);
assert(outputs.size() == 1);
outputs[0].set_data(allocator::malloc_or_wait(outputs[0].nbytes()));
distributed::detail::recv(group(), outputs[0], src_, stream());
}
} // namespace mlx::core::distributed
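
Note the add_temporary calls: arrays share their buffers through a shared_ptr, so parking a copy of the array on the encoder keeps the row-contiguous scratch alive until the scheduled communication task has consumed it. The same lifetime mechanics, in standalone form:

#include <cstdio>
#include <functional>
#include <memory>
#include <vector>

int main() {
  auto temp = std::make_shared<std::vector<float>>(3, 1.0f);
  // The closure copies the shared_ptr, so the buffer survives even
  // after the caller drops its reference (the add_temporary idea).
  std::function<void()> task = [temp] {
    std::printf("still alive, size %zu\n", temp->size());
  };
  temp.reset(); // caller is done with it
  task();       // buffer is still owned by the closure
  return 0;
}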


@@ -3,6 +3,7 @@
#include "mlx/allocator.h"
#include "mlx/array.h"
#include "mlx/backend/cpu/copy.h"
#include "mlx/backend/cpu/encoder.h"
#include "mlx/backend/cpu/lapack.h"
#include "mlx/linalg.h"
#include "mlx/primitives.h"
@@ -16,59 +17,72 @@ void eigh_impl(
array& vectors,
array& values,
const std::string& uplo,
bool compute_eigenvectors) {
bool compute_eigenvectors,
Stream stream) {
auto vec_ptr = vectors.data<T>();
auto eig_ptr = values.data<T>();
char jobz = compute_eigenvectors ? 'V' : 'N';
auto N = vectors.shape(-1);
// Work query
int lwork = -1;
int liwork = -1;
int info;
{
T work;
int iwork;
syevd<T>(
&jobz,
uplo.c_str(),
&N,
nullptr,
&N,
nullptr,
&work,
&lwork,
&iwork,
&liwork,
&info);
lwork = static_cast<int>(work);
liwork = iwork;
}
auto work_buf = array::Data{allocator::malloc_or_wait(sizeof(T) * lwork)};
auto iwork_buf = array::Data{allocator::malloc_or_wait(sizeof(int) * liwork)};
for (size_t i = 0; i < vectors.size() / (N * N); ++i) {
syevd<T>(
&jobz,
uplo.c_str(),
&N,
vec_ptr,
&N,
eig_ptr,
static_cast<T*>(work_buf.buffer.raw_ptr()),
&lwork,
static_cast<int*>(iwork_buf.buffer.raw_ptr()),
&liwork,
&info);
vec_ptr += N * N;
eig_ptr += N;
if (info != 0) {
std::stringstream msg;
msg << "[Eigh::eval_cpu] Eigenvalue decomposition failed with error code "
<< info;
throw std::runtime_error(msg.str());
auto& encoder = cpu::get_command_encoder(stream);
encoder.set_output_array(vectors);
encoder.set_output_array(values);
encoder.dispatch([vec_ptr,
eig_ptr,
jobz,
uplo = uplo[0],
N = vectors.shape(-1),
size = vectors.size()]() mutable {
// Work query
int lwork = -1;
int liwork = -1;
int info;
{
T work;
int iwork;
syevd<T>(
&jobz,
&uplo,
&N,
nullptr,
&N,
nullptr,
&work,
&lwork,
&iwork,
&liwork,
&info);
lwork = static_cast<int>(work);
liwork = iwork;
}
auto work_buf = array::Data{allocator::malloc_or_wait(sizeof(T) * lwork)};
auto iwork_buf =
array::Data{allocator::malloc_or_wait(sizeof(int) * liwork)};
for (size_t i = 0; i < size / (N * N); ++i) {
syevd<T>(
&jobz,
&uplo,
&N,
vec_ptr,
&N,
eig_ptr,
static_cast<T*>(work_buf.buffer.raw_ptr()),
&lwork,
static_cast<int*>(iwork_buf.buffer.raw_ptr()),
&liwork,
&info);
vec_ptr += N * N;
eig_ptr += N;
if (info != 0) {
std::stringstream msg;
msg << "[Eigh::eval_cpu] Eigenvalue decomposition failed with error code "
<< info;
throw std::runtime_error(msg.str());
}
}
});
if (!compute_eigenvectors) {
encoder.add_temporary(vectors);
}
}
@@ -89,7 +103,8 @@ void Eigh::eval_cpu(
copy(
a,
vectors,
a.flags().row_contiguous ? CopyType::Vector : CopyType::General);
a.flags().row_contiguous ? CopyType::Vector : CopyType::General,
stream());
if (compute_eigenvectors_) {
// Set the strides and flags so the eigenvectors
@@ -107,14 +122,15 @@ void Eigh::eval_cpu(
flags.col_contiguous = true;
}
}
vectors.move_shared_buffer(vectors, strides, flags, vectors.data_size());
vectors.copy_shared_buffer(vectors, strides, flags, vectors.data_size());
}
switch (a.dtype()) {
case float32:
eigh_impl<float>(vectors, values, uplo_, compute_eigenvectors_);
eigh_impl<float>(vectors, values, uplo_, compute_eigenvectors_, stream());
break;
case float64:
eigh_impl<double>(vectors, values, uplo_, compute_eigenvectors_);
eigh_impl<double>(
vectors, values, uplo_, compute_eigenvectors_, stream());
break;
default:
throw std::runtime_error(


@@ -0,0 +1,16 @@
// Copyright © 2025 Apple Inc.
#include "mlx/backend/cpu/encoder.h"
namespace mlx::core::cpu {
CommandEncoder& get_command_encoder(Stream stream) {
static std::unordered_map<int, CommandEncoder> encoder_map;
auto it = encoder_map.find(stream.index);
if (it == encoder_map.end()) {
it = encoder_map.emplace(stream.index, stream).first;
}
return it->second;
}
} // namespace mlx::core::cpu

mlx/backend/cpu/encoder.h (new file, 53 lines)

@@ -0,0 +1,53 @@
// Copyright © 2025 Apple Inc.
#pragma once
#include <unordered_map>
#include "mlx/array.h"
#include "mlx/scheduler.h"
namespace mlx::core::cpu {
struct CommandEncoder {
CommandEncoder(Stream stream) : stream_(stream) {}
CommandEncoder(const CommandEncoder&) = delete;
CommandEncoder& operator=(const CommandEncoder&) = delete;
CommandEncoder(CommandEncoder&&) = delete;
CommandEncoder& operator=(CommandEncoder&&) = delete;
void set_input_array(const array& a) {}
void set_output_array(array& a) {}
// Hold onto a temporary until any already scheduled tasks which use it as
// an input are complete.
void add_temporary(array arr) {
temporaries_.push_back(std::move(arr));
}
void add_temporaries(std::vector<array> arrays) {
temporaries_.insert(
temporaries_.end(),
std::make_move_iterator(arrays.begin()),
std::make_move_iterator(arrays.end()));
}
std::vector<array>& temporaries() {
return temporaries_;
}
template <class F, class... Args>
void dispatch(F&& f, Args&&... args) {
auto task = std::bind(std::forward<F>(f), std::forward<Args>(args)...);
scheduler::enqueue(stream_, std::move(task));
}
private:
Stream stream_;
std::vector<array> temporaries_;
};
CommandEncoder& get_command_encoder(Stream stream);
} // namespace mlx::core::cpu
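
Taken together, a typical op built on this API looks like the sketch below (hypothetical negate kernel; the pattern matches arange and copy_vector in this commit): mark inputs and outputs on the encoder, then dispatch a closure that owns plain pointers and sizes.

// Hypothetical example op using the encoder API above.
void negate_impl(const array& in, array& out, Stream stream) {
  auto& encoder = cpu::get_command_encoder(stream);
  encoder.set_input_array(in);
  encoder.set_output_array(out);
  encoder.dispatch([in_ptr = in.data<float>(),
                    out_ptr = out.data<float>(),
                    size = out.size()]() {
    for (size_t i = 0; i < size; ++i) {
      out_ptr[i] = -in_ptr[i];
    }
  });
}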

mlx/backend/cpu/eval.cpp (new file, 44 lines)

@@ -0,0 +1,44 @@
// Copyright © 2025 Apple Inc.
#include "mlx/backend/cpu/eval.h"
#include "mlx/backend/cpu/encoder.h"
#include "mlx/primitives.h"
#include "mlx/scheduler.h"
#include "mlx/utils.h"
namespace mlx::core::cpu {
void eval(array& arr) {
auto s = arr.primitive().stream();
auto outputs = arr.outputs();
{
// If the array is a tracer hold a reference
// to its inputs so they don't get donated
std::vector<array> inputs;
if (arr.is_tracer()) {
inputs = arr.inputs();
}
arr.primitive().eval_cpu(arr.inputs(), outputs);
}
std::unordered_set<std::shared_ptr<array::Data>> buffers;
for (auto& in : arr.inputs()) {
buffers.insert(in.data_shared_ptr());
}
for (auto& s : arr.siblings()) {
buffers.insert(s.data_shared_ptr());
}
// Remove the output if it was donated to by an input
if (auto it = buffers.find(arr.data_shared_ptr()); it != buffers.end()) {
buffers.erase(it);
}
auto& encoder = cpu::get_command_encoder(s);
scheduler::notify_new_task(s);
encoder.dispatch([s,
buffers = std::move(buffers),
temps = std::move(encoder.temporaries())]() {
scheduler::notify_task_completion(s);
});
}
} // namespace mlx::core::cpu
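
The completion task above looks empty, but the captures are the point: moving the input buffers and the encoder's temporaries into the closure keeps those allocations alive until every task queued before it on the same stream has run. A standalone sketch of the idiom (editorial, using plain std types rather than the commit's scheduler):

#include <functional>
#include <memory>
#include <queue>

std::queue<std::function<void()>> tasks; // stand-in for the stream's queue

void retain_until_done(std::shared_ptr<int> buf) {
  // The capture alone extends buf's lifetime: it is released only when the
  // task object is executed and destroyed, after all earlier tasks.
  tasks.push([buf = std::move(buf)]() {});
}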

mlx/backend/cpu/eval.h (new file, 12 lines)

@@ -0,0 +1,12 @@
// Copyright © 2025 Apple Inc.
#pragma once
#include "mlx/array.h"
#include "mlx/stream.h"
namespace mlx::core::cpu {
void eval(array& arr);
} // namespace mlx::core::cpu

View File

@@ -4,6 +4,7 @@
#include "mlx/3rdparty/pocketfft.h"
#include "mlx/allocator.h"
#include "mlx/backend/cpu/encoder.h"
#include "mlx/primitives.h"
namespace mlx::core {
@@ -38,46 +39,78 @@ void FFT::eval_cpu(const std::vector<array>& inputs, array& out) {
});
scale /= nelem;
}
auto& encoder = cpu::get_command_encoder(stream());
encoder.set_input_array(in);
encoder.set_output_array(out);
if (in.dtype() == complex64 && out.dtype() == complex64) {
auto in_ptr =
reinterpret_cast<const std::complex<float>*>(in.data<complex64_t>());
auto out_ptr =
reinterpret_cast<std::complex<float>*>(out.data<complex64_t>());
pocketfft::c2c(
shape,
strides_in,
strides_out,
axes_,
!inverse_,
in_ptr,
out_ptr,
scale);
encoder.dispatch([shape = std::move(shape),
strides_in = std::move(strides_in),
strides_out = std::move(strides_out),
axes = axes_,
inverse = inverse_,
in_ptr,
out_ptr,
scale]() {
pocketfft::c2c(
shape,
strides_in,
strides_out,
axes,
!inverse,
in_ptr,
out_ptr,
scale);
});
} else if (in.dtype() == float32 && out.dtype() == complex64) {
auto in_ptr = in.data<float>();
auto out_ptr =
reinterpret_cast<std::complex<float>*>(out.data<complex64_t>());
pocketfft::r2c(
shape,
strides_in,
strides_out,
axes_,
!inverse_,
in_ptr,
out_ptr,
scale);
encoder.dispatch([shape = std::move(shape),
strides_in = std::move(strides_in),
strides_out = std::move(strides_out),
axes = axes_,
inverse = inverse_,
in_ptr,
out_ptr,
scale]() {
pocketfft::r2c(
shape,
strides_in,
strides_out,
axes,
!inverse,
in_ptr,
out_ptr,
scale);
});
} else if (in.dtype() == complex64 && out.dtype() == float32) {
auto in_ptr =
reinterpret_cast<const std::complex<float>*>(in.data<complex64_t>());
auto out_ptr = out.data<float>();
pocketfft::c2r(
shape,
strides_in,
strides_out,
axes_,
!inverse_,
in_ptr,
out_ptr,
scale);
encoder.dispatch([shape = std::move(shape),
strides_in = std::move(strides_in),
strides_out = std::move(strides_out),
axes = axes_,
inverse = inverse_,
in_ptr,
out_ptr,
scale]() {
pocketfft::c2r(
shape,
strides_in,
strides_out,
axes,
!inverse,
in_ptr,
out_ptr,
scale);
});
} else {
throw std::runtime_error(
"[FFT] Received unexpected input and output type combination.");

View File

@@ -7,14 +7,20 @@ namespace mlx::core {
template <typename T>
void matmul(
const array& a,
const array& b,
array& out,
const T* a,
const T* b,
T* out,
bool a_transposed,
bool b_transposed,
size_t lda,
size_t ldb,
size_t ldc,
float alpha,
float beta);
float beta,
size_t batch_size,
const Shape& a_shape,
const Strides& a_strides,
const Shape& b_shape,
const Strides& b_strides);
} // namespace mlx::core
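
A hedged call-site sketch for the new pointer-based, batch-aware signature, assuming a_ptr, b_ptr, and out_ptr point at contiguous row-major float buffers; every shape, stride, and leading dimension below is illustrative, not from the commit.

// Two-batch 4x8 @ 8x16 GEMM, no transposes, C = A * B.
Shape a_shape{2, 4, 8}, b_shape{2, 8, 16};
Strides a_strides{32, 8, 1}, b_strides{128, 16, 1};
matmul<float>(
    a_ptr, b_ptr, out_ptr,
    /* a_transposed */ false, /* b_transposed */ false,
    /* lda */ 8, /* ldb */ 16, /* ldc */ 16,
    /* alpha */ 1.0f, /* beta */ 0.0f,
    /* batch_size */ 2,
    a_shape, a_strides, b_shape, b_strides);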

View File

@@ -9,39 +9,46 @@
namespace mlx::core {
BNNSDataType to_bnns_dtype(Dtype mlx_dtype) {
uint32_t size_bits = size_of(mlx_dtype) * 8;
switch (kindof(mlx_dtype)) {
case Dtype::Kind::b:
return BNNSDataTypeBoolean;
case Dtype::Kind::u:
return BNNSDataType(BNNSDataTypeUIntBit | size_bits);
case Dtype::Kind::i:
return BNNSDataType(BNNSDataTypeIntBit | size_bits);
case Dtype::Kind::f:
return BNNSDataType(BNNSDataTypeFloatBit | size_bits);
case Dtype::Kind::V:
return BNNSDataTypeBFloat16;
case Dtype::Kind::c:
throw std::invalid_argument("BNNS does not support complex types");
}
template <typename T>
constexpr BNNSDataType to_bnns_dtype();
template <>
constexpr BNNSDataType to_bnns_dtype<float>() {
return BNNSDataType(BNNSDataTypeFloatBit | 32);
}
template <>
constexpr BNNSDataType to_bnns_dtype<float16_t>() {
return BNNSDataType(BNNSDataTypeFloatBit | 16);
}
template <>
constexpr BNNSDataType to_bnns_dtype<bfloat16_t>() {
return BNNSDataTypeBFloat16;
}
template <typename T>
void matmul_bnns(
const array& a,
const array& b,
array& out,
const T* a,
const T* b,
T* out,
bool a_transposed,
bool b_transposed,
size_t lda,
size_t ldb,
size_t ldc,
float alpha,
float beta) {
size_t M = a.shape(-2);
size_t N = b.shape(-1);
size_t K = a.shape(-1);
float beta,
size_t batch_size,
const Shape& a_shape,
const Strides& a_strides,
const Shape& b_shape,
const Strides& b_strides) {
auto ndim = a_shape.size();
size_t M = a_shape[ndim - 2];
size_t N = b_shape[ndim - 1];
size_t K = a_shape[ndim - 1];
BNNSDataType bnns_dtype = to_bnns_dtype(out.dtype());
BNNSDataType bnns_dtype = to_bnns_dtype<T>();
#pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Wdeprecated-declarations"
@@ -115,14 +122,14 @@ void matmul_bnns(
auto bnns_filter =
BNNSFilterCreateLayerBroadcastMatMul(&gemm_params, nullptr);
for (int i = 0; i < (a.size() / (M * K)); ++i) {
for (int i = 0; i < batch_size; ++i) {
BNNSFilterApplyTwoInput(
bnns_filter,
a.data<uint8_t>() +
elem_to_loc(M * K * i, a.shape(), a.strides()) * a.itemsize(),
b.data<uint8_t>() +
elem_to_loc(K * N * i, b.shape(), b.strides()) * b.itemsize(),
out.data<uint8_t>() + M * N * i * out.itemsize());
reinterpret_cast<const uint8_t*>(
a + elem_to_loc(M * K * i, a_shape, a_strides)),
reinterpret_cast<const uint8_t*>(
b + elem_to_loc(K * N * i, b_shape, b_strides)),
reinterpret_cast<uint8_t*>(out + M * N * i));
}
BNNSFilterDestroy(bnns_filter);
@@ -131,30 +138,72 @@ void matmul_bnns(
template <>
void matmul<float16_t>(
const array& a,
const array& b,
array& out,
const float16_t* a,
const float16_t* b,
float16_t* out,
bool a_transposed,
bool b_transposed,
size_t lda,
size_t ldb,
size_t ldc,
float alpha,
float beta) {
matmul_bnns(a, b, out, a_transposed, b_transposed, lda, ldb, alpha, beta);
float beta,
size_t batch_size,
const Shape& a_shape,
const Strides& a_strides,
const Shape& b_shape,
const Strides& b_strides) {
matmul_bnns(
a,
b,
out,
a_transposed,
b_transposed,
lda,
ldb,
ldc,
alpha,
beta,
batch_size,
a_shape,
a_strides,
b_shape,
b_strides);
}
template <>
void matmul<bfloat16_t>(
const array& a,
const array& b,
array& out,
const bfloat16_t* a,
const bfloat16_t* b,
bfloat16_t* out,
bool a_transposed,
bool b_transposed,
size_t lda,
size_t ldb,
size_t ldc,
float alpha,
float beta) {
matmul_bnns(a, b, out, a_transposed, b_transposed, lda, ldb, alpha, beta);
float beta,
size_t batch_size,
const Shape& a_shape,
const Strides& a_strides,
const Shape& b_shape,
const Strides& b_strides) {
matmul_bnns(
a,
b,
out,
a_transposed,
b_transposed,
lda,
ldb,
ldc,
alpha,
beta,
batch_size,
a_shape,
a_strides,
b_shape,
b_strides);
}
} // namespace mlx::core
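
The dtype mapping here changed from a runtime switch over Dtype to constexpr template specializations, so each matmul<T> instantiation bakes in its BNNS type and an unsupported T fails at build time instead of throwing. The general shape of that refactor, as a simplified editorial sketch:

// Before: switch (kindof(dtype)) { ... } resolved at runtime.
// After: one specialization per supported type; leaving the primary
// template undefined makes an unsupported instantiation fail to link.
template <typename T>
constexpr int type_code(); // intentionally undefined

template <> constexpr int type_code<float>() { return 32; }
template <> constexpr int type_code<double>() { return 64; }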

View File

@@ -8,20 +8,27 @@ namespace mlx::core {
template <>
void matmul<float>(
const array& a,
const array& b,
array& out,
const float* a,
const float* b,
float* out,
bool a_transposed,
bool b_transposed,
size_t lda,
size_t ldb,
size_t ldc,
float alpha,
float beta) {
size_t M = a.shape(-2);
size_t N = b.shape(-1);
size_t K = a.shape(-1);
float beta,
size_t batch_size,
const Shape& a_shape,
const Strides& a_strides,
const Shape& b_shape,
const Strides& b_strides) {
auto ndim = a_shape.size();
size_t M = a_shape[ndim - 2];
size_t N = b_shape[ndim - 1];
size_t K = a_shape[ndim - 1];
for (int i = 0; i < (a.size() / (M * K)); ++i) {
for (int i = 0; i < batch_size; ++i) {
cblas_sgemm(
CblasRowMajor,
a_transposed ? CblasTrans : CblasNoTrans, // transA
@@ -29,34 +36,40 @@ void matmul<float>(
M,
N,
K,
alpha, // alpha
a.data<float>() + elem_to_loc(M * K * i, a.shape(), a.strides()),
alpha,
a + elem_to_loc(M * K * i, a_shape, a_strides),
lda,
b.data<float>() + elem_to_loc(K * N * i, b.shape(), b.strides()),
b + elem_to_loc(K * N * i, b_shape, b_strides),
ldb,
beta, // beta
out.data<float>() + M * N * i,
out.shape(-1) // ldc
);
beta,
out + M * N * i,
ldc);
}
}
template <>
void matmul<double>(
const array& a,
const array& b,
array& out,
const double* a,
const double* b,
double* out,
bool a_transposed,
bool b_transposed,
size_t lda,
size_t ldb,
size_t ldc,
float alpha,
float beta) {
size_t M = a.shape(-2);
size_t N = b.shape(-1);
size_t K = a.shape(-1);
float beta,
size_t batch_size,
const Shape& a_shape,
const Strides& a_strides,
const Shape& b_shape,
const Strides& b_strides) {
auto ndim = a_shape.size();
size_t M = a_shape[ndim - 2];
size_t N = b_shape[ndim - 1];
size_t K = a_shape[ndim - 1];
for (int i = 0; i < (a.size() / (M * K)); ++i) {
for (int i = 0; i < batch_size; ++i) {
cblas_dgemm(
CblasRowMajor,
a_transposed ? CblasTrans : CblasNoTrans, // transA
@@ -64,15 +77,14 @@ void matmul<double>(
M,
N,
K,
alpha, // alpha
a.data<double>() + elem_to_loc(M * K * i, a.shape(), a.strides()),
alpha,
a + elem_to_loc(M * K * i, a_shape, a_strides),
lda,
b.data<double>() + elem_to_loc(K * N * i, b.shape(), b.strides()),
b + elem_to_loc(K * N * i, b_shape, b_strides),
ldb,
beta, // beta
out.data<double>() + M * N * i,
out.shape(-1) // ldc
);
beta,
out + M * N * i,
ldc);
}
}

View File

@@ -6,15 +6,21 @@ namespace mlx::core {
template <>
void matmul<bfloat16_t>(
const array&,
const array&,
array&,
const bfloat16_t*,
const bfloat16_t*,
bfloat16_t*,
bool,
bool,
size_t,
size_t,
size_t,
float,
float) {
float,
size_t,
const Shape&,
const Strides&,
const Shape&,
const Strides&) {
throw std::runtime_error("[Matmul::eval_cpu] bfloat16 not supported.");
}

View File

@@ -6,15 +6,21 @@ namespace mlx::core {
template <>
void matmul<float16_t>(
const array&,
const array&,
array&,
const float16_t*,
const float16_t*,
float16_t*,
bool,
bool,
size_t,
size_t,
size_t,
float,
float) {
float,
size_t,
const Shape&,
const Strides&,
const Shape&,
const Strides&) {
throw std::runtime_error("[Matmul::eval_cpu] float16 not supported.");
}

View File

@@ -4,16 +4,17 @@
#include "mlx/backend/common/hadamard.h"
#include "mlx/backend/cpu/copy.h"
#include "mlx/backend/cpu/encoder.h"
#include "mlx/primitives.h"
namespace mlx::core {
// n = 2^k component
template <typename T>
void hadamard_n(array& out, int n, int m, float scale) {
for (int b = 0; b < out.size() / n; b++) {
void hadamard_n(T* out, int n, int m, float scale, size_t size) {
for (int b = 0; b < size / n; b++) {
size_t loc = b * n;
T* data_ptr = out.data<T>() + loc;
T* data_ptr = out + loc;
int h = 1;
int n_over_2 = n / 2;
while (h < n) {
@@ -36,7 +37,7 @@ void hadamard_n(array& out, int n, int m, float scale) {
// m component
template <typename T>
void hadamard_m(array& out, int n, int m, float scale) {
void hadamard_m(T* out, int n, int m, float scale, size_t size) {
auto h_matrices = hadamard_matrices();
auto& matrix = h_matrices[m];
auto start = 1;
@@ -51,9 +52,9 @@ void hadamard_m(array& out, int n, int m, float scale) {
end = matrix.find('\n', start);
}
for (int b = 0; b < out.size() / m / n; b++) {
for (int b = 0; b < size / m / n; b++) {
size_t loc = b * n * m;
T* data_ptr = out.data<T>() + loc;
T* data_ptr = out + loc;
for (int i = 0; i < n; i++) {
std::vector<float> out(m);
for (int j = 0; j < m; j++) {
@@ -74,12 +75,17 @@ void hadamard_m(array& out, int n, int m, float scale) {
}
template <typename T>
void hadamard(array& out, int n, int m, float scale) {
float n_scale = m > 1 ? 1.0 : scale;
hadamard_n<T>(out, n, m, n_scale);
if (m > 1) {
hadamard_m<T>(out, n, m, scale);
}
void hadamard(array& out, int n, int m, float scale, Stream stream) {
auto& encoder = cpu::get_command_encoder(stream);
encoder.set_output_array(out);
auto out_ptr = out.data<T>();
encoder.dispatch([out_ptr, size = out.size(), n, m, scale]() {
float n_scale = m > 1 ? 1.0 : scale;
hadamard_n<T>(out_ptr, n, m, n_scale, size);
if (m > 1) {
hadamard_m<T>(out_ptr, n, m, scale, size);
}
});
}
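
A worked example (editorial) of the n = 2^k butterfly in hadamard_n above, on a single length-4 block with scale = 1:

// input:      [a, b, c, d]
// h = 1 pass: [a+b, a-b, c+d, c-d]                   (pairs at distance 1)
// h = 2 pass: [a+b+c+d, a-b+c-d, a+b-c-d, a-b-c+d]   (pairs at distance 2)
// i.e. the order-4 Walsh-Hadamard transform, computed in place per block.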
void Hadamard::eval_cpu(const std::vector<array>& inputs, array& out) {
@@ -87,18 +93,26 @@ void Hadamard::eval_cpu(const std::vector<array>& inputs, array& out) {
auto& in = inputs[0];
// Copy input to output
copy(in, out, CopyType::General);
if (in.flags().row_contiguous && in.is_donatable()) {
out.copy_shared_buffer(in);
} else {
copy(
in,
out,
in.flags().row_contiguous ? CopyType::Vector : CopyType::General,
stream());
}
int axis = out.ndim() - 1;
auto [n, m] = decompose_hadamard(out.shape(axis));
switch (in.dtype()) {
case float32:
return hadamard<float>(out, n, m, scale_);
return hadamard<float>(out, n, m, scale_, stream());
case float16:
return hadamard<float16_t>(out, n, m, scale_);
return hadamard<float16_t>(out, n, m, scale_, stream());
case bfloat16:
return hadamard<bfloat16_t>(out, n, m, scale_);
return hadamard<bfloat16_t>(out, n, m, scale_, stream());
default:
throw std::invalid_argument("[hadamard] Unsupported type.");
}

View File

@@ -8,6 +8,7 @@
#include "mlx/backend/common/utils.h"
#include "mlx/backend/cpu/copy.h"
#include "mlx/backend/cpu/encoder.h"
namespace mlx::core {
@@ -27,7 +28,8 @@ void gather(
const std::vector<array>& inds,
array& out,
const std::vector<int>& axes,
const Shape& slice_sizes) {
const Shape& slice_sizes,
Stream stream) {
// If the array is row contiguous then we can do a contiguous copy given
// two conditions on the slice size:
// - Any number of leading ones in the slice sizes are allowed
@@ -73,38 +75,60 @@ void gather(
size_t ind_size = slice_size == 0 ? 0 : out.size() / slice_size;
const T* src_ptr = src.data<T>();
T* dst_ptr = out.data<T>();
size_t out_idx = 0;
std::vector<ContiguousIterator> its(inds.begin(), inds.end());
ContiguousIterator src_it;
if (!can_copy && src.ndim() > 0) {
src_it = ContiguousIterator(slice_sizes, src.strides(), src.ndim());
}
for (int idx = 0; idx < ind_size; idx++) {
size_t src_idx = 0;
for (int ii = 0; ii < inds.size(); ++ii) {
auto ax = axes[ii];
auto idx_loc = its[ii].loc;
its[ii].step();
auto idx_val =
offset_neg_idx(inds[ii].data<IdxT>()[idx_loc], src.shape(ax));
src_idx += (idx_val * src.strides()[ax]);
}
if (slice_size == 1) {
dst_ptr[out_idx++] = src_ptr[src_idx];
} else if (can_copy) {
std::copy(
src_ptr + src_idx, src_ptr + src_idx + slice_size, dst_ptr + out_idx);
out_idx += slice_size;
} else {
for (int jj = 0; jj < slice_size; jj++) {
dst_ptr[out_idx++] = src_ptr[src_idx + src_it.loc];
src_it.step();
}
src_it.reset();
}
std::vector<const IdxT*> ind_ptrs;
auto& encoder = cpu::get_command_encoder(stream);
encoder.set_input_array(src);
for (auto& idx : inds) {
ind_ptrs.push_back(idx.data<IdxT>());
encoder.set_input_array(idx);
}
encoder.set_output_array(out);
encoder.dispatch([src_ptr,
dst_ptr,
ind_ptrs = std::move(ind_ptrs),
axes,
ind_size,
slice_size,
src_shape = src.shape(),
src_strides = src.strides(),
src_it = std::move(src_it),
its = std::move(its),
can_copy]() mutable {
size_t out_idx = 0;
for (int idx = 0; idx < ind_size; idx++) {
size_t src_idx = 0;
for (int ii = 0; ii < ind_ptrs.size(); ++ii) {
auto ax = axes[ii];
auto idx_loc = its[ii].loc;
its[ii].step();
auto idx_val = offset_neg_idx(ind_ptrs[ii][idx_loc], src_shape[ax]);
src_idx += (idx_val * src_strides[ax]);
}
if (slice_size == 1) {
dst_ptr[out_idx++] = src_ptr[src_idx];
} else if (can_copy) {
std::copy(
src_ptr + src_idx,
src_ptr + src_idx + slice_size,
dst_ptr + out_idx);
out_idx += slice_size;
} else {
for (int jj = 0; jj < slice_size; jj++) {
dst_ptr[out_idx++] = src_ptr[src_idx + src_it.loc];
src_it.step();
}
src_it.reset();
}
}
});
}
template <typename IdxT>
@@ -113,49 +137,50 @@ void dispatch_gather(
const std::vector<array>& inds,
array& out,
const std::vector<int>& axes,
const Shape& size) {
const Shape& size,
Stream stream) {
switch (out.dtype()) {
case bool_:
gather<bool, IdxT>(src, inds, out, axes, size);
gather<bool, IdxT>(src, inds, out, axes, size, stream);
break;
case uint8:
gather<uint8_t, IdxT>(src, inds, out, axes, size);
gather<uint8_t, IdxT>(src, inds, out, axes, size, stream);
break;
case uint16:
gather<uint16_t, IdxT>(src, inds, out, axes, size);
gather<uint16_t, IdxT>(src, inds, out, axes, size, stream);
break;
case uint32:
gather<uint32_t, IdxT>(src, inds, out, axes, size);
gather<uint32_t, IdxT>(src, inds, out, axes, size, stream);
break;
case uint64:
gather<uint64_t, IdxT>(src, inds, out, axes, size);
gather<uint64_t, IdxT>(src, inds, out, axes, size, stream);
break;
case int8:
gather<int8_t, IdxT>(src, inds, out, axes, size);
gather<int8_t, IdxT>(src, inds, out, axes, size, stream);
break;
case int16:
gather<int16_t, IdxT>(src, inds, out, axes, size);
gather<int16_t, IdxT>(src, inds, out, axes, size, stream);
break;
case int32:
gather<int32_t, IdxT>(src, inds, out, axes, size);
gather<int32_t, IdxT>(src, inds, out, axes, size, stream);
break;
case int64:
gather<int64_t, IdxT>(src, inds, out, axes, size);
gather<int64_t, IdxT>(src, inds, out, axes, size, stream);
break;
case float16:
gather<float16_t, IdxT>(src, inds, out, axes, size);
gather<float16_t, IdxT>(src, inds, out, axes, size, stream);
break;
case float32:
gather<float, IdxT>(src, inds, out, axes, size);
gather<float, IdxT>(src, inds, out, axes, size, stream);
break;
case float64:
gather<double, IdxT>(src, inds, out, axes, size);
gather<double, IdxT>(src, inds, out, axes, size, stream);
break;
case bfloat16:
gather<bfloat16_t, IdxT>(src, inds, out, axes, size);
gather<bfloat16_t, IdxT>(src, inds, out, axes, size, stream);
break;
case complex64:
gather<complex64_t, IdxT>(src, inds, out, axes, size);
gather<complex64_t, IdxT>(src, inds, out, axes, size, stream);
break;
}
}
@@ -167,34 +192,34 @@ void Gather::eval_cpu(const std::vector<array>& inputs, array& out) {
std::vector<array> inds(inputs.begin() + 1, inputs.end());
if (inds.empty()) {
dispatch_gather<uint8_t>(src, inds, out, axes_, slice_sizes_);
dispatch_gather<uint8_t>(src, inds, out, axes_, slice_sizes_, stream());
return;
}
switch (inds[0].dtype()) {
case uint8:
dispatch_gather<uint8_t>(src, inds, out, axes_, slice_sizes_);
dispatch_gather<uint8_t>(src, inds, out, axes_, slice_sizes_, stream());
break;
case uint16:
dispatch_gather<uint16_t>(src, inds, out, axes_, slice_sizes_);
dispatch_gather<uint16_t>(src, inds, out, axes_, slice_sizes_, stream());
break;
case uint32:
dispatch_gather<uint32_t>(src, inds, out, axes_, slice_sizes_);
dispatch_gather<uint32_t>(src, inds, out, axes_, slice_sizes_, stream());
break;
case uint64:
dispatch_gather<uint64_t>(src, inds, out, axes_, slice_sizes_);
dispatch_gather<uint64_t>(src, inds, out, axes_, slice_sizes_, stream());
break;
case int8:
dispatch_gather<int8_t>(src, inds, out, axes_, slice_sizes_);
dispatch_gather<int8_t>(src, inds, out, axes_, slice_sizes_, stream());
break;
case int16:
dispatch_gather<int16_t>(src, inds, out, axes_, slice_sizes_);
dispatch_gather<int16_t>(src, inds, out, axes_, slice_sizes_, stream());
break;
case int32:
dispatch_gather<int32_t>(src, inds, out, axes_, slice_sizes_);
dispatch_gather<int32_t>(src, inds, out, axes_, slice_sizes_, stream());
break;
case int64:
dispatch_gather<int64_t>(src, inds, out, axes_, slice_sizes_);
dispatch_gather<int64_t>(src, inds, out, axes_, slice_sizes_, stream());
break;
default:
throw std::runtime_error(
@@ -207,7 +232,8 @@ void gather_axis(
const array& src,
const array& ind,
array& out,
const int axis) {
const int axis,
Stream stream) {
auto strides = ind.strides();
strides.erase(strides.begin() + axis);
auto shape = ind.shape();
@@ -235,20 +261,39 @@ void gather_axis(
for (int i = axis + 1; i < ind.ndim(); ++i) {
size_post *= ind.shape(i);
}
size_t stride_pre = size_post * ind_ax_size;
for (size_t i = 0; i < size_pre; i++) {
for (size_t k = 0; k < size_post; k++) {
for (int j = 0; j < ind_ax_size; ++j) {
auto ind_val = offset_neg_idx(
ind_ptr[ind_it.loc + j * ind_ax_stride], src_ax_size);
dst_ptr[k + j * dst_ax_stride] =
src_ptr[src_it.loc + ind_val * src_ax_stride];
auto& encoder = cpu::get_command_encoder(stream);
encoder.set_input_array(src);
encoder.set_input_array(ind);
encoder.set_output_array(out);
encoder.dispatch([ind_ptr,
src_ptr,
dst_ptr,
size_pre,
size_post,
ind_ax_size,
src_ax_size,
ind_ax_stride,
src_ax_stride,
dst_ax_stride,
ind_it = std::move(ind_it),
src_it = std::move(src_it)]() mutable {
size_t stride_pre = size_post * ind_ax_size;
for (size_t i = 0; i < size_pre; i++) {
for (size_t k = 0; k < size_post; k++) {
for (int j = 0; j < ind_ax_size; ++j) {
auto ind_val = offset_neg_idx(
ind_ptr[ind_it.loc + j * ind_ax_stride], src_ax_size);
dst_ptr[k + j * dst_ax_stride] =
src_ptr[src_it.loc + ind_val * src_ax_stride];
}
ind_it.step();
src_it.step();
}
ind_it.step();
src_it.step();
dst_ptr += stride_pre;
}
dst_ptr += stride_pre;
}
});
}
template <typename IdxT>
@@ -256,49 +301,50 @@ void dispatch_gather_axis(
const array& src,
const array& inds,
array& out,
const int axis) {
const int axis,
Stream stream) {
switch (out.dtype()) {
case bool_:
gather_axis<bool, IdxT>(src, inds, out, axis);
gather_axis<bool, IdxT>(src, inds, out, axis, stream);
break;
case uint8:
gather_axis<uint8_t, IdxT>(src, inds, out, axis);
gather_axis<uint8_t, IdxT>(src, inds, out, axis, stream);
break;
case uint16:
gather_axis<uint16_t, IdxT>(src, inds, out, axis);
gather_axis<uint16_t, IdxT>(src, inds, out, axis, stream);
break;
case uint32:
gather_axis<uint32_t, IdxT>(src, inds, out, axis);
gather_axis<uint32_t, IdxT>(src, inds, out, axis, stream);
break;
case uint64:
gather_axis<uint64_t, IdxT>(src, inds, out, axis);
gather_axis<uint64_t, IdxT>(src, inds, out, axis, stream);
break;
case int8:
gather_axis<int8_t, IdxT>(src, inds, out, axis);
gather_axis<int8_t, IdxT>(src, inds, out, axis, stream);
break;
case int16:
gather_axis<int16_t, IdxT>(src, inds, out, axis);
gather_axis<int16_t, IdxT>(src, inds, out, axis, stream);
break;
case int32:
gather_axis<int32_t, IdxT>(src, inds, out, axis);
gather_axis<int32_t, IdxT>(src, inds, out, axis, stream);
break;
case int64:
gather_axis<int64_t, IdxT>(src, inds, out, axis);
gather_axis<int64_t, IdxT>(src, inds, out, axis, stream);
break;
case float16:
gather_axis<float16_t, IdxT>(src, inds, out, axis);
gather_axis<float16_t, IdxT>(src, inds, out, axis, stream);
break;
case float32:
gather_axis<float, IdxT>(src, inds, out, axis);
gather_axis<float, IdxT>(src, inds, out, axis, stream);
break;
case float64:
gather_axis<double, IdxT>(src, inds, out, axis);
gather_axis<double, IdxT>(src, inds, out, axis, stream);
break;
case bfloat16:
gather_axis<bfloat16_t, IdxT>(src, inds, out, axis);
gather_axis<bfloat16_t, IdxT>(src, inds, out, axis, stream);
break;
case complex64:
gather_axis<complex64_t, IdxT>(src, inds, out, axis);
gather_axis<complex64_t, IdxT>(src, inds, out, axis, stream);
break;
}
}
@@ -309,28 +355,28 @@ void GatherAxis::eval_cpu(const std::vector<array>& inputs, array& out) {
auto& inds = inputs[1];
switch (inds.dtype()) {
case uint8:
dispatch_gather_axis<uint8_t>(src, inds, out, axis_);
dispatch_gather_axis<uint8_t>(src, inds, out, axis_, stream());
break;
case uint16:
dispatch_gather_axis<uint16_t>(src, inds, out, axis_);
dispatch_gather_axis<uint16_t>(src, inds, out, axis_, stream());
break;
case uint32:
dispatch_gather_axis<uint32_t>(src, inds, out, axis_);
dispatch_gather_axis<uint32_t>(src, inds, out, axis_, stream());
break;
case uint64:
dispatch_gather_axis<uint64_t>(src, inds, out, axis_);
dispatch_gather_axis<uint64_t>(src, inds, out, axis_, stream());
break;
case int8:
dispatch_gather_axis<int8_t>(src, inds, out, axis_);
dispatch_gather_axis<int8_t>(src, inds, out, axis_, stream());
break;
case int16:
dispatch_gather_axis<int16_t>(src, inds, out, axis_);
dispatch_gather_axis<int16_t>(src, inds, out, axis_, stream());
break;
case int32:
dispatch_gather_axis<int32_t>(src, inds, out, axis_);
dispatch_gather_axis<int32_t>(src, inds, out, axis_, stream());
break;
case int64:
dispatch_gather_axis<int64_t>(src, inds, out, axis_);
dispatch_gather_axis<int64_t>(src, inds, out, axis_, stream());
break;
default:
throw std::runtime_error(
@@ -345,7 +391,8 @@ void scatter(
array& out,
const std::vector<array>& inds,
const std::vector<int>& axes,
const OpT& op) {
const OpT& op,
Stream stream) {
int nind = inds.size();
auto inds_ndim = updates.ndim() - out.ndim();
size_t n_updates = nind ? inds[0].size() : 1;
@@ -361,26 +408,45 @@ void scatter(
ContiguousIterator update_it(updates);
ContiguousIterator out_it(update_shape, out.strides(), out.ndim());
for (int i = 0; i < n_updates; ++i) {
size_t out_offset = 0;
for (int j = 0; j < nind; ++j) {
auto ax = axes[j];
auto idx_loc = its[j].loc;
its[j].step();
auto idx_val =
offset_neg_idx(inds[j].data<IdxT>()[idx_loc], out.shape(ax));
out_offset += (idx_val * out.strides()[ax]);
}
update_it.seek(i * update_size);
for (int j = 0; j < update_size; ++j) {
op(updates.data<InT>()[update_it.loc],
out.data<InT>() + out_offset + out_it.loc);
update_it.step();
out_it.step();
}
out_it.reset();
update_it.reset();
std::vector<const IdxT*> ind_ptrs;
auto& encoder = cpu::get_command_encoder(stream);
encoder.set_input_array(updates);
for (auto& idx : inds) {
ind_ptrs.push_back(idx.data<IdxT>());
encoder.set_input_array(idx);
}
encoder.set_output_array(out);
encoder.dispatch([out_ptr = out.data<InT>(),
upd_ptr = updates.data<InT>(),
ind_ptrs = std::move(ind_ptrs),
axes,
n_updates,
update_size,
op = std::move(op),
out_shape = out.shape(),
out_strides = out.strides(),
out_it = std::move(out_it),
update_it = std::move(update_it),
its = std::move(its)]() mutable {
for (int i = 0; i < n_updates; ++i) {
size_t out_offset = 0;
for (int j = 0; j < ind_ptrs.size(); ++j) {
auto ax = axes[j];
auto idx_loc = its[j].loc;
its[j].step();
auto idx_val = offset_neg_idx(ind_ptrs[j][idx_loc], out_shape[ax]);
out_offset += (idx_val * out_strides[ax]);
}
update_it.seek(i * update_size);
for (int j = 0; j < update_size; ++j) {
op(upd_ptr[update_it.loc], out_ptr + out_offset + out_it.loc);
update_it.step();
out_it.step();
}
out_it.reset();
update_it.reset();
}
});
}
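
For concreteness, the effect of the Scatter::Sum reduce op dispatched below on a tiny case (editorial example):

// out = [0, 0, 0], indices = [1, 1, 2], updates = [5, 7, 9]
// op(x, y) { *y += x; } applied per update yields out = [0, 12, 9]:
// duplicate indices accumulate rather than overwrite.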
template <typename InT, typename IdxT>
@@ -389,29 +455,53 @@ void dispatch_scatter_inds(
const std::vector<array>& indices,
const array& updates,
const std::vector<int>& axes,
Scatter::ReduceType rtype) {
Scatter::ReduceType rtype,
Stream stream) {
switch (rtype) {
case Scatter::None:
scatter<InT, IdxT>(
updates, out, indices, axes, [](auto x, auto* y) { (*y) = x; });
updates,
out,
indices,
axes,
[](auto x, auto* y) { (*y) = x; },
stream);
break;
case Scatter::Sum:
scatter<InT, IdxT>(
updates, out, indices, axes, [](auto x, auto* y) { (*y) += x; });
updates,
out,
indices,
axes,
[](auto x, auto* y) { (*y) += x; },
stream);
break;
case Scatter::Prod:
scatter<InT, IdxT>(
updates, out, indices, axes, [](auto x, auto* y) { (*y) *= x; });
updates,
out,
indices,
axes,
[](auto x, auto* y) { (*y) *= x; },
stream);
break;
case Scatter::Max:
scatter<InT, IdxT>(updates, out, indices, axes, [](auto x, auto* y) {
(*y) = (*y > x) ? *y : x;
});
scatter<InT, IdxT>(
updates,
out,
indices,
axes,
[](auto x, auto* y) { (*y) = (*y > x) ? *y : x; },
stream);
break;
case Scatter::Min:
scatter<InT, IdxT>(updates, out, indices, axes, [](auto x, auto* y) {
(*y) = (*y < x) ? *y : x;
});
scatter<InT, IdxT>(
updates,
out,
indices,
axes,
[](auto x, auto* y) { (*y) = (*y < x) ? *y : x; },
stream);
break;
}
}
@@ -422,36 +512,46 @@ void dispatch_scatter(
const std::vector<array>& inds,
const array& updates,
const std::vector<int>& axes,
Scatter::ReduceType rtype) {
Scatter::ReduceType rtype,
Stream stream) {
if (inds.empty()) {
dispatch_scatter_inds<InT, uint8_t>(out, inds, updates, axes, rtype);
dispatch_scatter_inds<InT, uint8_t>(
out, inds, updates, axes, rtype, stream);
return;
}
switch (inds[0].dtype()) {
case uint8:
dispatch_scatter_inds<InT, uint8_t>(out, inds, updates, axes, rtype);
dispatch_scatter_inds<InT, uint8_t>(
out, inds, updates, axes, rtype, stream);
break;
case uint16:
dispatch_scatter_inds<InT, uint16_t>(out, inds, updates, axes, rtype);
dispatch_scatter_inds<InT, uint16_t>(
out, inds, updates, axes, rtype, stream);
break;
case uint32:
dispatch_scatter_inds<InT, uint32_t>(out, inds, updates, axes, rtype);
dispatch_scatter_inds<InT, uint32_t>(
out, inds, updates, axes, rtype, stream);
break;
case uint64:
dispatch_scatter_inds<InT, uint64_t>(out, inds, updates, axes, rtype);
dispatch_scatter_inds<InT, uint64_t>(
out, inds, updates, axes, rtype, stream);
break;
case int8:
dispatch_scatter_inds<InT, int8_t>(out, inds, updates, axes, rtype);
dispatch_scatter_inds<InT, int8_t>(
out, inds, updates, axes, rtype, stream);
break;
case int16:
dispatch_scatter_inds<InT, int16_t>(out, inds, updates, axes, rtype);
dispatch_scatter_inds<InT, int16_t>(
out, inds, updates, axes, rtype, stream);
break;
case int32:
dispatch_scatter_inds<InT, int32_t>(out, inds, updates, axes, rtype);
dispatch_scatter_inds<InT, int32_t>(
out, inds, updates, axes, rtype, stream);
break;
case int64:
dispatch_scatter_inds<InT, int64_t>(out, inds, updates, axes, rtype);
dispatch_scatter_inds<InT, int64_t>(
out, inds, updates, axes, rtype, stream);
break;
default:
throw std::runtime_error(
@@ -469,50 +569,63 @@ void Scatter::eval_cpu(const std::vector<array>& inputs, array& out) {
// Copy src into out (copy allocates memory for out)
auto ctype =
src.flags().row_contiguous ? CopyType::Vector : CopyType::General;
copy(src, out, ctype);
copy(src, out, ctype, stream());
switch (src.dtype()) {
case bool_:
dispatch_scatter<bool>(out, inds, updates, axes_, reduce_type_);
dispatch_scatter<bool>(out, inds, updates, axes_, reduce_type_, stream());
break;
case uint8:
dispatch_scatter<uint8_t>(out, inds, updates, axes_, reduce_type_);
dispatch_scatter<uint8_t>(
out, inds, updates, axes_, reduce_type_, stream());
break;
case uint16:
dispatch_scatter<uint16_t>(out, inds, updates, axes_, reduce_type_);
dispatch_scatter<uint16_t>(
out, inds, updates, axes_, reduce_type_, stream());
break;
case uint32:
dispatch_scatter<uint32_t>(out, inds, updates, axes_, reduce_type_);
dispatch_scatter<uint32_t>(
out, inds, updates, axes_, reduce_type_, stream());
break;
case uint64:
dispatch_scatter<uint64_t>(out, inds, updates, axes_, reduce_type_);
dispatch_scatter<uint64_t>(
out, inds, updates, axes_, reduce_type_, stream());
break;
case int8:
dispatch_scatter<int8_t>(out, inds, updates, axes_, reduce_type_);
dispatch_scatter<int8_t>(
out, inds, updates, axes_, reduce_type_, stream());
break;
case int16:
dispatch_scatter<int16_t>(out, inds, updates, axes_, reduce_type_);
dispatch_scatter<int16_t>(
out, inds, updates, axes_, reduce_type_, stream());
break;
case int32:
dispatch_scatter<int32_t>(out, inds, updates, axes_, reduce_type_);
dispatch_scatter<int32_t>(
out, inds, updates, axes_, reduce_type_, stream());
break;
case int64:
dispatch_scatter<int64_t>(out, inds, updates, axes_, reduce_type_);
dispatch_scatter<int64_t>(
out, inds, updates, axes_, reduce_type_, stream());
break;
case float16:
dispatch_scatter<float16_t>(out, inds, updates, axes_, reduce_type_);
dispatch_scatter<float16_t>(
out, inds, updates, axes_, reduce_type_, stream());
break;
case float32:
dispatch_scatter<float>(out, inds, updates, axes_, reduce_type_);
dispatch_scatter<float>(
out, inds, updates, axes_, reduce_type_, stream());
break;
case float64:
dispatch_scatter<double>(out, inds, updates, axes_, reduce_type_);
dispatch_scatter<double>(
out, inds, updates, axes_, reduce_type_, stream());
break;
case bfloat16:
dispatch_scatter<bfloat16_t>(out, inds, updates, axes_, reduce_type_);
dispatch_scatter<bfloat16_t>(
out, inds, updates, axes_, reduce_type_, stream());
break;
case complex64:
dispatch_scatter<complex64_t>(out, inds, updates, axes_, reduce_type_);
dispatch_scatter<complex64_t>(
out, inds, updates, axes_, reduce_type_, stream());
break;
}
}
@@ -523,7 +636,8 @@ void scatter_axis(
const array idx,
const array& upd,
int axis,
const OpT& op) {
const OpT& op,
Stream stream) {
auto strides = idx.strides();
strides.erase(strides.begin() + axis);
auto shape = idx.shape();
@@ -543,6 +657,11 @@ void scatter_axis(
auto idx_ax_size = idx.shape(axis);
auto dst_ax_size = out.shape(axis);
auto& encoder = cpu::get_command_encoder(stream);
encoder.set_input_array(idx);
encoder.set_input_array(upd);
encoder.set_output_array(out);
size_t size_pre = 1;
size_t size_post = 1;
for (int i = 0; i < axis; ++i) {
@@ -551,20 +670,34 @@ void scatter_axis(
for (int i = axis + 1; i < idx.ndim(); ++i) {
size_post *= idx.shape(i);
}
size_t stride_pre = size_post * dst_ax_size;
for (size_t i = 0; i < size_pre; i++) {
for (size_t k = 0; k < size_post; k++) {
for (int j = 0; j < idx_ax_size; ++j) {
auto ind_val = offset_neg_idx(
idx_ptr[idx_it.loc + j * idx_ax_stride], dst_ax_size);
op(upd_ptr[upd_it.loc + j * upd_ax_stride],
dst_ptr + k + ind_val * dst_ax_stride);
encoder.dispatch([idx_ptr,
upd_ptr,
dst_ptr,
size_pre,
size_post,
idx_ax_size,
dst_ax_size,
idx_ax_stride,
upd_ax_stride,
dst_ax_stride,
idx_it = std::move(idx_it),
upd_it = std::move(upd_it),
op = std::move(op)]() mutable {
size_t stride_pre = size_post * dst_ax_size;
for (size_t i = 0; i < size_pre; i++) {
for (size_t k = 0; k < size_post; k++) {
for (int j = 0; j < idx_ax_size; ++j) {
auto ind_val = offset_neg_idx(
idx_ptr[idx_it.loc + j * idx_ax_stride], dst_ax_size);
op(upd_ptr[upd_it.loc + j * upd_ax_stride],
dst_ptr + k + ind_val * dst_ax_stride);
}
idx_it.step();
upd_it.step();
}
idx_it.step();
upd_it.step();
dst_ptr += stride_pre;
}
dst_ptr += stride_pre;
}
});
}
template <typename InT, typename IdxT>
@@ -573,15 +706,16 @@ void dispatch_scatter_axis_op(
const array& idx,
const array& updates,
int axis,
ScatterAxis::ReduceType rtype) {
ScatterAxis::ReduceType rtype,
Stream stream) {
switch (rtype) {
case ScatterAxis::None:
scatter_axis<InT, IdxT>(
out, idx, updates, axis, [](auto x, auto* y) { (*y) = x; });
out, idx, updates, axis, [](auto x, auto* y) { (*y) = x; }, stream);
break;
case ScatterAxis::Sum:
scatter_axis<InT, IdxT>(
out, idx, updates, axis, [](auto x, auto* y) { (*y) += x; });
out, idx, updates, axis, [](auto x, auto* y) { (*y) += x; }, stream);
break;
}
}
@@ -592,31 +726,40 @@ void dispatch_scatter_axis(
const array& idx,
const array& updates,
int axis,
ScatterAxis::ReduceType rtype) {
ScatterAxis::ReduceType rtype,
Stream stream) {
switch (idx.dtype()) {
case uint8:
dispatch_scatter_axis_op<InT, uint8_t>(out, idx, updates, axis, rtype);
dispatch_scatter_axis_op<InT, uint8_t>(
out, idx, updates, axis, rtype, stream);
break;
case uint16:
dispatch_scatter_axis_op<InT, uint16_t>(out, idx, updates, axis, rtype);
dispatch_scatter_axis_op<InT, uint16_t>(
out, idx, updates, axis, rtype, stream);
break;
case uint32:
dispatch_scatter_axis_op<InT, uint32_t>(out, idx, updates, axis, rtype);
dispatch_scatter_axis_op<InT, uint32_t>(
out, idx, updates, axis, rtype, stream);
break;
case uint64:
dispatch_scatter_axis_op<InT, uint64_t>(out, idx, updates, axis, rtype);
dispatch_scatter_axis_op<InT, uint64_t>(
out, idx, updates, axis, rtype, stream);
break;
case int8:
dispatch_scatter_axis_op<InT, int8_t>(out, idx, updates, axis, rtype);
dispatch_scatter_axis_op<InT, int8_t>(
out, idx, updates, axis, rtype, stream);
break;
case int16:
dispatch_scatter_axis_op<InT, int16_t>(out, idx, updates, axis, rtype);
dispatch_scatter_axis_op<InT, int16_t>(
out, idx, updates, axis, rtype, stream);
break;
case int32:
dispatch_scatter_axis_op<InT, int32_t>(out, idx, updates, axis, rtype);
dispatch_scatter_axis_op<InT, int32_t>(
out, idx, updates, axis, rtype, stream);
break;
case int64:
dispatch_scatter_axis_op<InT, int64_t>(out, idx, updates, axis, rtype);
dispatch_scatter_axis_op<InT, int64_t>(
out, idx, updates, axis, rtype, stream);
break;
default:
throw std::runtime_error(
@@ -634,51 +777,64 @@ void ScatterAxis::eval_cpu(const std::vector<array>& inputs, array& out) {
// Copy src into out (copy allocates memory for out)
auto ctype =
src.flags().row_contiguous ? CopyType::Vector : CopyType::General;
copy(src, out, ctype);
copy(src, out, ctype, stream());
switch (src.dtype()) {
case bool_:
dispatch_scatter_axis<bool>(out, idx, updates, axis_, reduce_type_);
dispatch_scatter_axis<bool>(
out, idx, updates, axis_, reduce_type_, stream());
break;
case uint8:
dispatch_scatter_axis<uint8_t>(out, idx, updates, axis_, reduce_type_);
dispatch_scatter_axis<uint8_t>(
out, idx, updates, axis_, reduce_type_, stream());
break;
case uint16:
dispatch_scatter_axis<uint16_t>(out, idx, updates, axis_, reduce_type_);
dispatch_scatter_axis<uint16_t>(
out, idx, updates, axis_, reduce_type_, stream());
break;
case uint32:
dispatch_scatter_axis<uint32_t>(out, idx, updates, axis_, reduce_type_);
dispatch_scatter_axis<uint32_t>(
out, idx, updates, axis_, reduce_type_, stream());
break;
case uint64:
dispatch_scatter_axis<uint64_t>(out, idx, updates, axis_, reduce_type_);
dispatch_scatter_axis<uint64_t>(
out, idx, updates, axis_, reduce_type_, stream());
break;
case int8:
dispatch_scatter_axis<int8_t>(out, idx, updates, axis_, reduce_type_);
dispatch_scatter_axis<int8_t>(
out, idx, updates, axis_, reduce_type_, stream());
break;
case int16:
dispatch_scatter_axis<int16_t>(out, idx, updates, axis_, reduce_type_);
dispatch_scatter_axis<int16_t>(
out, idx, updates, axis_, reduce_type_, stream());
break;
case int32:
dispatch_scatter_axis<int32_t>(out, idx, updates, axis_, reduce_type_);
dispatch_scatter_axis<int32_t>(
out, idx, updates, axis_, reduce_type_, stream());
break;
case int64:
dispatch_scatter_axis<int64_t>(out, idx, updates, axis_, reduce_type_);
dispatch_scatter_axis<int64_t>(
out, idx, updates, axis_, reduce_type_, stream());
break;
case float16:
dispatch_scatter_axis<float16_t>(out, idx, updates, axis_, reduce_type_);
dispatch_scatter_axis<float16_t>(
out, idx, updates, axis_, reduce_type_, stream());
break;
case float32:
dispatch_scatter_axis<float>(out, idx, updates, axis_, reduce_type_);
dispatch_scatter_axis<float>(
out, idx, updates, axis_, reduce_type_, stream());
break;
case float64:
dispatch_scatter_axis<double>(out, idx, updates, axis_, reduce_type_);
dispatch_scatter_axis<double>(
out, idx, updates, axis_, reduce_type_, stream());
break;
case bfloat16:
dispatch_scatter_axis<bfloat16_t>(out, idx, updates, axis_, reduce_type_);
dispatch_scatter_axis<bfloat16_t>(
out, idx, updates, axis_, reduce_type_, stream());
break;
case complex64:
dispatch_scatter_axis<complex64_t>(
out, idx, updates, axis_, reduce_type_);
out, idx, updates, axis_, reduce_type_, stream());
break;
}
}

View File

@@ -2,20 +2,21 @@
#include "mlx/allocator.h"
#include "mlx/backend/cpu/copy.h"
#include "mlx/backend/cpu/encoder.h"
#include "mlx/backend/cpu/lapack.h"
#include "mlx/primitives.h"
namespace mlx::core {
template <typename T>
void general_inv(array& inv, int N, int i) {
void general_inv(T* inv, int N) {
int info;
auto ipiv = array::Data{allocator::malloc_or_wait(sizeof(int) * N)};
// Compute LU factorization.
getrf<T>(
/* m = */ &N,
/* n = */ &N,
/* a = */ inv.data<T>() + N * N * i,
/* a = */ inv,
/* lda = */ &N,
/* ipiv = */ static_cast<int*>(ipiv.buffer.raw_ptr()),
/* info = */ &info);
@@ -53,7 +54,7 @@ void general_inv(array& inv, int N, int i) {
// Compute inverse.
getri<T>(
/* m = */ &N,
/* a = */ inv.data<T>() + N * N * i,
/* a = */ inv,
/* lda = */ &N,
/* ipiv = */ static_cast<int*>(ipiv.buffer.raw_ptr()),
/* work = */ static_cast<T*>(scratch.buffer.raw_ptr()),
@@ -68,29 +69,28 @@ void general_inv(array& inv, int N, int i) {
}
template <typename T>
void tri_inv(array& inv, int N, int i, bool upper) {
void tri_inv(T* inv, int N, bool upper) {
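// Editorial note: LAPACK reads column-major, so our row-major data arrives
// transposed; an upper triangle therefore looks lower to trtri, which is
// why the flag below is deliberately flipped.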
const char uplo = upper ? 'L' : 'U';
const char diag = 'N';
T* data = inv.data<T>() + N * N * i;
int info;
trtri<T>(
/* uplo = */ &uplo,
/* diag = */ &diag,
/* N = */ &N,
/* a = */ data,
/* a = */ inv,
/* lda = */ &N,
/* info = */ &info);
// zero out the other triangle
if (upper) {
for (int i = 0; i < N; i++) {
std::fill(data, data + i, 0.0f);
data += N;
std::fill(inv, inv + i, 0.0f);
inv += N;
}
} else {
for (int i = 0; i < N; i++) {
std::fill(data + i + 1, data + N, 0.0f);
data += N;
std::fill(inv + i + 1, inv + N, 0.0f);
inv += N;
}
}
@@ -103,34 +103,53 @@ void tri_inv(array& inv, int N, int i, bool upper) {
}
template <typename T>
void inverse_impl(const array& a, array& inv, bool tri, bool upper) {
void inverse_impl(
const array& a,
array& inv,
bool tri,
bool upper,
Stream stream) {
// Lapack uses the column-major convention. We take advantage of the following
// identity to avoid transposing (see
// https://math.stackexchange.com/a/340234):
// (A⁻¹)ᵀ = (Aᵀ)⁻¹
// The inverse is computed in place, so just copy the input to the output.
copy(a, inv, a.flags().row_contiguous ? CopyType::Vector : CopyType::General);
copy(
a,
inv,
a.flags().row_contiguous ? CopyType::Vector : CopyType::General,
stream);
const int N = a.shape(-1);
const size_t num_matrices = a.size() / (N * N);
for (int i = 0; i < num_matrices; i++) {
if (tri) {
tri_inv<T>(inv, N, i, upper);
} else {
general_inv<T>(inv, N, i);
}
auto& encoder = cpu::get_command_encoder(stream);
encoder.set_output_array(inv);
auto inv_ptr = inv.data<T>();
if (tri) {
encoder.dispatch([inv_ptr, N, num_matrices, upper]() {
for (int i = 0; i < num_matrices; i++) {
tri_inv<T>(inv_ptr + N * N * i, N, upper);
}
});
} else {
encoder.dispatch([inv_ptr, N, num_matrices]() {
for (int i = 0; i < num_matrices; i++) {
general_inv<T>(inv_ptr + N * N * i, N);
}
});
}
}
void Inverse::eval_cpu(const std::vector<array>& inputs, array& output) {
switch (inputs[0].dtype()) {
case float32:
inverse_impl<float>(inputs[0], output, tri_, upper_);
inverse_impl<float>(inputs[0], output, tri_, upper_, stream());
break;
case float64:
inverse_impl<double>(inputs[0], output, tri_, upper_);
inverse_impl<double>(inputs[0], output, tri_, upper_, stream());
break;
default:
throw std::runtime_error(

View File

@@ -4,15 +4,22 @@
#include "mlx/allocator.h"
#include "mlx/backend/cpu/copy.h"
#include "mlx/backend/cpu/encoder.h"
#include "mlx/backend/cpu/lapack.h"
#include "mlx/primitives.h"
namespace mlx::core {
template <typename T>
void luf_impl(const array& a, array& lu, array& pivots, array& row_indices) {
void luf_impl(
const array& a,
array& lu,
array& pivots,
array& row_indices,
Stream stream) {
int M = a.shape(-2);
int N = a.shape(-1);
int K = std::min(M, N);
// Copy a into lu and make it col contiguous
auto ndim = lu.ndim();
@@ -26,57 +33,72 @@ void luf_impl(const array& a, array& lu, array& pivots, array& row_indices) {
lu.set_data(
allocator::malloc_or_wait(lu.nbytes()), lu.nbytes(), strides, flags);
copy_inplace(
a, lu, a.shape(), a.strides(), strides, 0, 0, CopyType::GeneralGeneral);
a,
lu,
a.shape(),
a.strides(),
strides,
0,
0,
CopyType::GeneralGeneral,
stream);
auto a_ptr = lu.data<T>();
pivots.set_data(allocator::malloc_or_wait(pivots.nbytes()));
row_indices.set_data(allocator::malloc_or_wait(row_indices.nbytes()));
auto pivots_ptr = pivots.data<uint32_t>();
auto row_indices_ptr = row_indices.data<uint32_t>();
int info;
size_t num_matrices = a.size() / (M * N);
for (size_t i = 0; i < num_matrices; ++i) {
// Compute LU factorization of A
getrf<T>(
/* m */ &M,
/* n */ &N,
/* a */ a_ptr,
/* lda */ &M,
/* ipiv */ reinterpret_cast<int*>(pivots_ptr),
/* info */ &info);
auto& encoder = cpu::get_command_encoder(stream);
encoder.set_input_array(a);
encoder.set_output_array(lu);
encoder.set_output_array(pivots);
encoder.set_output_array(row_indices);
if (info != 0) {
std::stringstream ss;
ss << "[LUF::eval_cpu] sgetrf_ failed with code " << info
<< ((info > 0) ? " because matrix is singular"
: " because argument had an illegal value");
throw std::runtime_error(ss.str());
}
encoder.dispatch(
[a_ptr, pivots_ptr, row_indices_ptr, num_matrices, M, N, K]() mutable {
int info;
for (size_t i = 0; i < num_matrices; ++i) {
// Compute LU factorization of A
getrf<T>(
/* m */ &M,
/* n */ &N,
/* a */ a_ptr,
/* lda */ &M,
/* ipiv */ reinterpret_cast<int*>(pivots_ptr),
/* info */ &info);
// Subtract 1 to get 0-based index
int j = 0;
for (; j < pivots.shape(-1); ++j) {
pivots_ptr[j]--;
row_indices_ptr[j] = j;
}
for (; j < row_indices.shape(-1); ++j) {
row_indices_ptr[j] = j;
}
for (int j = pivots.shape(-1) - 1; j >= 0; --j) {
auto piv = pivots_ptr[j];
auto t1 = row_indices_ptr[piv];
auto t2 = row_indices_ptr[j];
row_indices_ptr[j] = t1;
row_indices_ptr[piv] = t2;
}
if (info != 0) {
std::stringstream ss;
ss << "[LUF::eval_cpu] sgetrf_ failed with code " << info
<< ((info > 0) ? " because matrix is singular"
: " because argument had an illegal value");
throw std::runtime_error(ss.str());
}
// Advance pointers to the next matrix
a_ptr += M * N;
pivots_ptr += pivots.shape(-1);
row_indices_ptr += pivots.shape(-1);
}
// Subtract 1 to get 0-based index
int j = 0;
for (; j < K; ++j) {
pivots_ptr[j]--;
row_indices_ptr[j] = j;
}
for (; j < M; ++j) {
row_indices_ptr[j] = j;
}
for (int j = K - 1; j >= 0; --j) {
auto piv = pivots_ptr[j];
auto t1 = row_indices_ptr[piv];
auto t2 = row_indices_ptr[j];
row_indices_ptr[j] = t1;
row_indices_ptr[piv] = t2;
}
// Advance pointers to the next matrix
a_ptr += M * N;
pivots_ptr += K;
row_indices_ptr += M;
}
});
}
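
The pivot post-processing above converts LAPACK's transposition vector into an explicit row permutation. A worked example (editorial) for a 3x3 factorization:

// getrf returns ipiv = [2, 3, 3] (1-based); after the decrement,
// pivots = [1, 2, 2], i.e. swap rows (0,1), then (1,2), then (2,2).
// row_indices starts as [0, 1, 2]; the reverse-order swap loop leaves
// row_indices = [2, 0, 1], meaning original row i of A lands at row
// row_indices[i] of the permuted matrix P*A.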
void LUF::eval_cpu(
@@ -85,10 +107,10 @@ void LUF::eval_cpu(
assert(inputs.size() == 1);
switch (inputs[0].dtype()) {
case float32:
luf_impl<float>(inputs[0], outputs[0], outputs[1], outputs[2]);
luf_impl<float>(inputs[0], outputs[0], outputs[1], outputs[2], stream());
break;
case float64:
luf_impl<double>(inputs[0], outputs[0], outputs[1], outputs[2]);
luf_impl<double>(inputs[0], outputs[0], outputs[1], outputs[2], stream());
break;
default:
throw std::runtime_error(

View File

@@ -5,6 +5,7 @@
#include "mlx/array.h"
#include "mlx/backend/common/utils.h"
#include "mlx/backend/cpu/copy.h"
#include "mlx/backend/cpu/encoder.h"
#include "mlx/backend/cpu/lapack.h"
#include "mlx/primitives.h"
@@ -64,36 +65,36 @@ void BlockMaskedMM::eval_cpu(const std::vector<array>& inputs, array& out) {
auto& b_pre = inputs[1];
auto check_transpose =
[](const array& arr, bool do_copy, bool expand_all = false) {
[s = stream()](const array& arr, bool do_copy, bool expand_all = false) {
auto stx = arr.strides()[arr.ndim() - 2];
auto sty = arr.strides()[arr.ndim() - 1];
if (!expand_all && stx == arr.shape(-1) && sty == 1) {
if (do_copy) {
array arr_copy(arr.shape(), arr.dtype(), nullptr, {});
copy(arr, arr_copy, CopyType::Vector);
return std::make_tuple(false, stx, arr_copy);
copy(arr, arr_copy, CopyType::Vector, s);
return std::make_tuple(false, stx, arr_copy, true);
}
return std::make_tuple(false, stx, arr);
return std::make_tuple(false, stx, arr, false);
} else if (!expand_all && stx == 1 && sty == arr.shape(-2)) {
if (do_copy) {
array arr_copy(arr.shape(), arr.dtype(), nullptr, {});
copy(arr, arr_copy, CopyType::Vector);
return std::make_tuple(true, sty, arr_copy);
copy(arr, arr_copy, CopyType::Vector, s);
return std::make_tuple(true, sty, arr_copy, true);
}
return std::make_tuple(true, sty, arr);
return std::make_tuple(true, sty, arr, false);
} else {
array arr_copy(arr.shape(), arr.dtype(), nullptr, {});
copy(arr, arr_copy, CopyType::General);
copy(arr, arr_copy, CopyType::General, s);
int64_t stx = arr.shape(-1);
return std::make_tuple(false, stx, arr_copy);
return std::make_tuple(false, stx, arr_copy, true);
}
};
bool has_op_mask = inputs.size() > 3;
bool has_out_mask = inputs.size() == 3 || inputs.size() == 5;
auto [a_transposed, lda, a] =
auto [a_transposed, lda, a, a_copied] =
check_transpose(a_pre, has_op_mask, inputs.back().dtype() != bool_);
auto [b_transposed, ldb, b] =
auto [b_transposed, ldb, b, b_copied] =
check_transpose(b_pre, has_op_mask, inputs.back().dtype() != bool_);
size_t M = a.shape(-2);
@@ -104,31 +105,39 @@ void BlockMaskedMM::eval_cpu(const std::vector<array>& inputs, array& out) {
return;
}
auto& encoder = cpu::get_command_encoder(stream());
if (K == 0) {
std::memset(static_cast<void*>(out.data<float>()), 0, out.nbytes());
encoder.set_output_array(out);
encoder.dispatch([out_ptr = out.data<void>(), nbytes = out.nbytes()]() {
std::memset(out_ptr, 0, nbytes);
});
return;
}
auto mask_array = [](const array& mask,
auto mask_array = [](const void* mask,
float* data,
int block_size,
int batch_idx,
int X,
int Y,
size_t X_data_str,
size_t Y_data_str) {
size_t Y_data_str,
const Shape& mask_shape,
const Strides& mask_strides,
bool is_bool) {
auto ndim = mask_shape.size();
auto mask_offset = elem_to_loc(
mask.shape(-1) * mask.shape(-2) * batch_idx,
mask.shape(),
mask.strides());
mask_shape[ndim - 1] * mask_shape[ndim - 2] * batch_idx,
mask_shape,
mask_strides);
auto X_mask_str = mask.strides()[mask.ndim() - 2];
auto Y_mask_str = mask.strides()[mask.ndim() - 1];
auto X_mask_str = mask_strides[ndim - 2];
auto Y_mask_str = mask_strides[ndim - 1];
if (mask.dtype() == bool_) {
if (is_bool) {
return mask_matrix(
data,
mask.data<bool>(),
static_cast<const bool*>(mask),
block_size,
X,
Y,
@@ -140,7 +149,7 @@ void BlockMaskedMM::eval_cpu(const std::vector<array>& inputs, array& out) {
} else {
return mask_matrix(
data,
mask.data<float>(),
static_cast<const float*>(mask),
block_size,
X,
Y,
@@ -152,61 +161,155 @@ void BlockMaskedMM::eval_cpu(const std::vector<array>& inputs, array& out) {
}
};
for (int i = 0; i < (out.size() / (M * size_t(N))); ++i) {
// Adjust pointer
float* ai =
a.data<float>() + elem_to_loc(M * K * i, a.shape(), a.strides());
float* bi =
b.data<float>() + elem_to_loc(K * N * i, b.shape(), b.strides());
float* ci = out.data<float>() + M * N * i;
encoder.set_input_array(a);
encoder.set_input_array(b);
const void* a_mask_ptr;
const void* b_mask_ptr;
const void* out_mask_ptr;
Shape a_mask_shape;
Shape b_mask_shape;
Shape out_mask_shape;
Strides a_mask_strides;
Strides b_mask_strides;
Strides out_mask_strides;
bool a_mask_bool;
bool b_mask_bool;
bool out_mask_bool;
if (has_op_mask) {
auto& a_mask = inputs[inputs.size() - 2];
auto& b_mask = inputs[inputs.size() - 1];
a_mask_ptr = a_mask.data<void>();
b_mask_ptr = b_mask.data<void>();
a_mask_shape = a_mask.shape();
b_mask_shape = b_mask.shape();
a_mask_strides = a_mask.strides();
b_mask_strides = b_mask.strides();
a_mask_bool = (a_mask.dtype() == bool_);
b_mask_bool = (b_mask.dtype() == bool_);
encoder.set_input_array(a_mask);
encoder.set_input_array(b_mask);
}
if (has_out_mask) {
auto& out_mask = inputs[2];
out_mask_ptr = out_mask.data<void>();
out_mask_bool = (out_mask.dtype() == bool_);
encoder.set_input_array(out_mask);
out_mask_shape = out_mask.shape();
out_mask_strides = out_mask.strides();
}
encoder.set_output_array(out);
auto a_ptr = a.data<float>();
auto b_ptr = b.data<float>();
auto out_ptr = out.data<float>();
size_t num_matrices = out.size() / (M * size_t(N));
auto ldc = out.shape(-1);
// Zero out blocks in a and b if needed
if (has_op_mask) {
auto& a_mask = inputs[inputs.size() - 2];
mask_array(
a_mask,
ai,
block_size_,
i,
encoder.dispatch([a_ptr,
b_ptr,
out_ptr,
a_mask_ptr,
b_mask_ptr,
out_mask_ptr,
has_op_mask,
has_out_mask,
block_size = block_size_,
num_matrices,
M,
N,
K,
a_transposed = a_transposed,
b_transposed = b_transposed,
lda = lda,
ldb = ldb,
ldc,
a_shape = a.shape(),
a_strides = a.strides(),
b_shape = b.shape(),
b_strides = b.strides(),
a_mask_shape = std::move(a_mask_shape),
b_mask_shape = std::move(b_mask_shape),
out_mask_shape = std::move(out_mask_shape),
a_mask_strides = std::move(a_mask_strides),
b_mask_strides = std::move(b_mask_strides),
out_mask_strides = std::move(out_mask_strides),
mask_array,
a_mask_bool,
b_mask_bool,
out_mask_bool]() {
for (int i = 0; i < num_matrices; ++i) {
// Adjust pointer
float* ai = a_ptr + elem_to_loc(M * K * i, a_shape, a_strides);
float* bi = b_ptr + elem_to_loc(K * N * i, b_shape, b_strides);
float* ci = out_ptr + M * N * i;
// Zero out blocks in a and b if needed
if (has_op_mask) {
mask_array(
a_mask_ptr,
ai,
block_size,
i,
M,
K,
a_transposed ? 1 : lda,
a_transposed ? lda : 1,
a_mask_shape,
a_mask_strides,
a_mask_bool);
mask_array(
b_mask_ptr,
bi,
block_size,
i,
K,
N,
b_transposed ? 1 : ldb,
b_transposed ? ldb : 1,
b_mask_shape,
b_mask_strides,
b_mask_bool);
}
// Do matmul
cblas_sgemm(
CblasRowMajor,
a_transposed ? CblasTrans : CblasNoTrans, // transA
b_transposed ? CblasTrans : CblasNoTrans, // transB
M,
K,
a_transposed ? 1 : lda,
a_transposed ? lda : 1);
auto& b_mask = inputs[inputs.size() - 1];
mask_array(
b_mask,
bi,
block_size_,
i,
K,
N,
b_transposed ? 1 : ldb,
b_transposed ? ldb : 1);
}
K,
1.0, // alpha
ai,
lda,
bi,
ldb,
0.0, // beta
ci,
ldc);
// Do matmul
cblas_sgemm(
CblasRowMajor,
a_transposed ? CblasTrans : CblasNoTrans, // transA
b_transposed ? CblasTrans : CblasNoTrans, // transB
M,
N,
K,
1.0, // alpha
ai,
lda,
bi,
ldb,
0.0, // beta
ci,
out.shape(-1) // ldc
);
// Zero out blocks in out
if (has_out_mask) {
mask_array(inputs[2], ci, block_size_, i, M, N, N, 1);
// Zero out blocks in out
if (has_out_mask) {
mask_array(
out_mask_ptr,
ci,
block_size,
i,
M,
N,
N,
1,
out_mask_shape,
out_mask_strides,
out_mask_bool);
}
}
});
if (a_copied) {
encoder.add_temporary(a);
}
if (b_copied) {
encoder.add_temporary(b);
}
}
@@ -220,7 +323,8 @@ void GatherMM::eval_cpu(const std::vector<array>& inputs, array& out) {
auto& a_pre = inputs[0];
auto& b_pre = inputs[1];
auto check_transpose = [](const array& arr) {
std::vector<array> temps;
auto check_transpose = [s = stream(), &temps](const array& arr) {
auto stx = arr.strides()[arr.ndim() - 2];
auto sty = arr.strides()[arr.ndim() - 1];
if (stx == arr.shape(-1) && sty == 1) {
@@ -228,10 +332,10 @@ void GatherMM::eval_cpu(const std::vector<array>& inputs, array& out) {
} else if (stx == 1 && sty == arr.shape(-2)) {
return std::make_tuple(true, sty, arr);
} else {
array arr_copy(arr.shape(), arr.dtype(), nullptr, {});
copy(arr, arr_copy, CopyType::General);
temps.push_back(array(arr.shape(), arr.dtype(), nullptr, {}));
copy(arr, temps.back(), CopyType::General, s);
int64_t stx = arr.shape(-1);
return std::make_tuple(false, stx, arr_copy);
return std::make_tuple(false, stx, temps.back());
}
};
@@ -246,8 +350,12 @@ void GatherMM::eval_cpu(const std::vector<array>& inputs, array& out) {
return;
}
auto& encoder = cpu::get_command_encoder(stream());
if (K == 0) {
std::memset(static_cast<void*>(out.data<float>()), 0, out.nbytes());
encoder.set_output_array(out);
encoder.dispatch([out_ptr = out.data<float>(), size = out.size()]() {
std::fill_n(out_ptr, size, 0);
});
return;
}
@@ -272,29 +380,61 @@ void GatherMM::eval_cpu(const std::vector<array>& inputs, array& out) {
const uint32_t* lhs_indices_ptr = lhs_indices.data<uint32_t>();
const uint32_t* rhs_indices_ptr = rhs_indices.data<uint32_t>();
encoder.set_input_array(a);
encoder.set_input_array(b);
encoder.set_input_array(lhs_indices);
encoder.set_input_array(rhs_indices);
encoder.set_output_array(out);
auto ldc = out.shape(-1);
for (int i = 0; i < batch_size_out; i++) {
// Get index
uint32_t indx_A = lhs_indices_ptr[elem_to_loc(i, lhs_indices)];
uint32_t indx_B = rhs_indices_ptr[elem_to_loc(i, rhs_indices)];
encoder.dispatch([a_ptr = a.data<float>(),
b_ptr = b.data<float>(),
out_ptr = out.data<float>(),
M,
N,
K,
lda = lda,
ldb = ldb,
a_transposed = a_transposed,
b_transposed = b_transposed,
ldc,
lhs_indices_ptr,
rhs_indices_ptr,
lhs_indices_shape = lhs_indices.shape(),
lhs_indices_strides = lhs_indices.strides(),
rhs_indices_shape = rhs_indices.shape(),
rhs_indices_strides = rhs_indices.strides(),
batch_size_out,
matrix_stride_out,
batch_shape_A = std::move(batch_shape_A),
batch_shape_B = std::move(batch_shape_B),
batch_strides_A = std::move(batch_strides_A),
batch_strides_B = std::move(batch_strides_B)]() {
for (int i = 0; i < batch_size_out; i++) {
// Get index
uint32_t indx_A = lhs_indices_ptr[elem_to_loc(
i, lhs_indices_shape, lhs_indices_strides)];
uint32_t indx_B = rhs_indices_ptr[elem_to_loc(
i, rhs_indices_shape, rhs_indices_strides)];
cblas_sgemm(
CblasRowMajor,
a_transposed ? CblasTrans : CblasNoTrans, // transA
b_transposed ? CblasTrans : CblasNoTrans, // transB
M,
N,
K,
1.0f, // alpha
a.data<float>() + elem_to_loc(indx_A, batch_shape_A, batch_strides_A),
lda,
b.data<float>() + elem_to_loc(indx_B, batch_shape_B, batch_strides_B),
ldb,
0.0f, // beta
out.data<float>() + matrix_stride_out * i,
out.shape(-1) // ldc
);
}
cblas_sgemm(
CblasRowMajor,
a_transposed ? CblasTrans : CblasNoTrans, // transA
b_transposed ? CblasTrans : CblasNoTrans, // transB
M,
N,
K,
1.0f, // alpha
a_ptr + elem_to_loc(indx_A, batch_shape_A, batch_strides_A),
lda,
b_ptr + elem_to_loc(indx_B, batch_shape_B, batch_strides_B),
ldb,
0.0f, // beta
out_ptr + matrix_stride_out * i,
ldc);
}
});
encoder.add_temporaries(std::move(temps));
}
} // namespace mlx::core
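
The ownership chain for the copied operands in this file, summarized (editorial sketch):

// check_transpose() materializes a non-contiguous operand into temps;
// encoder.add_temporaries(std::move(temps)) hands them to the encoder;
// cpu::eval() later moves encoder.temporaries() into the stream's
// completion task, so the copies outlive every dispatched GEMM.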

View File

@@ -3,18 +3,76 @@
#include <cstring>
#include "mlx/array.h"
#include "mlx/backend/cpu/copy.h"
#include "mlx/backend/cpu/encoder.h"
#include "mlx/backend/cpu/gemm.h"
#include "mlx/primitives.h"
namespace mlx::core {
template <typename T>
void matmul_dispatch(
const array& a,
const array& b,
array& out,
bool a_transposed,
bool b_transposed,
size_t lda,
size_t ldb,
float alpha,
float beta,
Stream stream) {
const T* a_ptr = a.data<T>();
const T* b_ptr = b.data<T>();
T* out_ptr = out.data<T>();
size_t ldc = out.shape(-1);
size_t batch_size = a.size() / (a.shape(-2) * a.shape(-1));
auto& encoder = cpu::get_command_encoder(stream);
encoder.set_input_array(a);
encoder.set_input_array(b);
encoder.set_output_array(out);
encoder.dispatch([a_ptr,
b_ptr,
out_ptr,
a_transposed,
b_transposed,
lda,
ldb,
ldc,
alpha,
beta,
batch_size,
a_shape = a.shape(),
a_strides = a.strides(),
b_shape = b.shape(),
b_strides = b.strides()]() {
matmul<T>(
a_ptr,
b_ptr,
out_ptr,
a_transposed,
b_transposed,
lda,
ldb,
ldc,
alpha,
beta,
batch_size,
a_shape,
a_strides,
b_shape,
b_strides);
});
}
void matmul_general(
const array& a_pre,
const array& b_pre,
array& out,
Stream stream,
float alpha = 1.0f,
float beta = 0.0f) {
auto check_transpose = [](const array& arr) {
std::vector<array> temps;
auto check_transpose = [stream, &temps](const array& arr) {
auto stx = arr.strides()[arr.ndim() - 2];
auto sty = arr.strides()[arr.ndim() - 1];
if (stx == arr.shape(-1) && sty == 1) {
@@ -22,10 +80,10 @@ void matmul_general(
} else if (stx == 1 && sty == arr.shape(-2)) {
return std::make_tuple(true, sty, arr);
} else {
array arr_copy(arr.shape(), arr.dtype(), nullptr, {});
copy(arr, arr_copy, CopyType::General);
temps.push_back(array(arr.shape(), arr.dtype(), nullptr, {}));
copy(arr, temps.back(), CopyType::General, stream);
stx = arr.shape(-1);
return std::make_tuple(false, stx, arr_copy);
return std::make_tuple(false, stx, temps.back());
}
};
@@ -39,28 +97,34 @@ void matmul_general(
}
if (out.dtype() == float32) {
matmul<float>(a, b, out, a_transposed, b_transposed, lda, ldb, alpha, beta);
matmul_dispatch<float>(
a, b, out, a_transposed, b_transposed, lda, ldb, alpha, beta, stream);
} else if (out.dtype() == float16) {
matmul<float16_t>(
a, b, out, a_transposed, b_transposed, lda, ldb, alpha, beta);
matmul_dispatch<float16_t>(
a, b, out, a_transposed, b_transposed, lda, ldb, alpha, beta, stream);
} else if (out.dtype() == bfloat16) {
matmul<bfloat16_t>(
a, b, out, a_transposed, b_transposed, lda, ldb, alpha, beta);
matmul_dispatch<bfloat16_t>(
a, b, out, a_transposed, b_transposed, lda, ldb, alpha, beta, stream);
} else if (out.dtype() == float64) {
matmul<double>(
a, b, out, a_transposed, b_transposed, lda, ldb, alpha, beta);
matmul_dispatch<double>(
a, b, out, a_transposed, b_transposed, lda, ldb, alpha, beta, stream);
} else {
throw std::runtime_error("[Matmul::eval_cpu] Invalid type.");
}
cpu::get_command_encoder(stream).add_temporaries(std::move(temps));
}
void Matmul::eval_cpu(const std::vector<array>& inputs, array& out) {
out.set_data(allocator::malloc_or_wait(out.nbytes()));
if (inputs[0].shape(-1) == 0) {
std::memset(out.data<void>(), 0, out.nbytes());
auto& encoder = cpu::get_command_encoder(stream());
encoder.set_output_array(out);
encoder.dispatch([out_ptr = out.data<void>(), nbytes = out.nbytes()]() {
std::memset(out_ptr, 0, nbytes);
});
return;
}
return matmul_general(inputs[0], inputs[1], out);
matmul_general(inputs[0], inputs[1], out, stream());
}
void AddMM::eval_cpu(const std::vector<array>& inputs, array& out) {
@@ -74,9 +138,9 @@ void AddMM::eval_cpu(const std::vector<array>& inputs, array& out) {
CopyType ctype = c.data_size() == 1
? CopyType::Scalar
: (c.flags().row_contiguous ? CopyType::Vector : CopyType::General);
copy(c, out, ctype);
copy(c, out, ctype, stream());
return matmul_general(inputs[0], inputs[1], out, alpha_, beta_);
matmul_general(inputs[0], inputs[1], out, stream(), alpha_, beta_);
}
} // namespace mlx::core
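
When a non-contiguous input must first be copied, the copy is now queued on the same stream and its buffer is handed to the encoder so it outlives the asynchronous task. A sketch of that lifetime contract, assuming a float input; run_with_contiguous_input is hypothetical:

// Sketch: scratch copies must survive until the dispatched task runs, so
// they are handed to the encoder instead of dying at end of scope.
void run_with_contiguous_input(const array& in, array& out, Stream s) {
  std::vector<array> temps;
  const array* src = &in;
  if (!in.flags().row_contiguous) {
    temps.push_back(array(in.shape(), in.dtype(), nullptr, {}));
    copy(in, temps.back(), CopyType::General, s); // async copy on s
    src = &temps.back();
  }
  auto& encoder = cpu::get_command_encoder(s);
  encoder.set_input_array(*src);
  encoder.set_output_array(out);
  encoder.dispatch([in_ptr = src->data<float>(),
                    out_ptr = out.data<float>(),
                    size = out.size()]() {
    for (size_t i = 0; i < size; ++i) {
      out_ptr[i] = in_ptr[i];
    }
  });
  // Without this, the copy would be freed before the task reads it.
  encoder.add_temporaries(std::move(temps));
}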

View File

@@ -7,11 +7,11 @@
#include <sstream>
#include "mlx/allocator.h"
#include "mlx/backend/common/load.h"
#include "mlx/backend/common/slicing.h"
#include "mlx/backend/common/utils.h"
#include "mlx/backend/cpu/arange.h"
#include "mlx/backend/cpu/copy.h"
#include "mlx/backend/cpu/encoder.h"
#include "mlx/backend/cpu/threefry.h"
#include "mlx/primitives.h"
#include "mlx/utils.h"
@@ -22,39 +22,58 @@ void reshape(const array& in, array& out) {
auto [copy_necessary, out_strides] = prepare_reshape(in, out);
if (copy_necessary) {
out.set_data(allocator::malloc_or_wait(out.nbytes()));
copy_inplace(in, out, CopyType::General);
copy_inplace(in, out, CopyType::General, out.primitive().stream());
} else {
shared_buffer_reshape(in, out_strides, out);
}
}
int64_t compute_dynamic_offset(
static std::pair<array, bool> compute_dynamic_offset(
const array& indices,
const Strides& strides,
const std::vector<int>& axes) {
auto compute_offset = [&strides, &axes](const auto* indices) {
int64_t offset = 0;
for (int i = 0; i < axes.size(); ++i) {
offset += indices[i] * strides[axes[i]];
}
return offset;
};
const std::vector<int>& axes,
Stream stream) {
array offset({1}, int64, nullptr, {});
bool donate = indices.is_donatable() &&
(indices.data_size() * indices.itemsize()) >= offset.itemsize();
if (donate) {
offset.copy_shared_buffer(indices);
} else {
offset.set_data(allocator::malloc_or_wait(offset.itemsize()));
}
auto& encoder = cpu::get_command_encoder(stream);
encoder.set_input_array(indices);
encoder.set_output_array(offset);
auto compute_offset =
[strides, axes, offset = offset.data<int64_t>()](const auto* indices) {
int64_t offset_ = 0;
for (int i = 0; i < axes.size(); ++i) {
offset_ += indices[i] * strides[axes[i]];
}
offset[0] = offset_;
};
switch (indices.dtype()) {
case int8:
case uint8:
return compute_offset(indices.data<uint8_t>());
encoder.dispatch(compute_offset, indices.data<uint8_t>());
break;
case int16:
case uint16:
return compute_offset(indices.data<uint16_t>());
encoder.dispatch(compute_offset, indices.data<uint16_t>());
break;
case int32:
case uint32:
return compute_offset(indices.data<uint32_t>());
encoder.dispatch(compute_offset, indices.data<uint32_t>());
break;
case int64:
case uint64:
return compute_offset(indices.data<uint64_t>());
encoder.dispatch(compute_offset, indices.data<uint64_t>());
break;
default:
throw std::runtime_error("Invalid indices type.");
}
return {offset, donate};
}
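
compute_dynamic_offset now returns the offset as a one-element int64 array produced on-stream, plus a flag saying whether the indices buffer was donated into it. Restating the caller-side contract as a fragment (it assumes the surrounding eval_cpu context: in, out, indices, axes; DynamicSlice below follows exactly this shape):

// Sketch: consume the on-stream offset, and keep it alive as a temporary
// when its buffer was freshly allocated rather than donated.
auto [offset, donated] =
    compute_dynamic_offset(indices, in.strides(), axes, stream());
copy_inplace(
    in, out, out.shape(), in.strides(), out.strides(),
    /* i_offset = */ 0, /* o_offset = */ 0,
    CopyType::GeneralGeneral, stream(),
    /* dynamic_i_offset = */ offset,
    /* dynamic_o_offset = */ std::nullopt);
if (!donated) {
  cpu::get_command_encoder(stream()).add_temporary(std::move(offset));
}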
void AsStrided::eval_cpu(const std::vector<array>& inputs, array& out) {
@@ -104,14 +123,59 @@ void Transpose::eval_cpu(const std::vector<array>& inputs, array& out) {
}
void Arange::eval_cpu(const std::vector<array>& inputs, array& out) {
arange(inputs, out, start_, step_);
assert(inputs.size() == 0);
out.set_data(allocator::malloc_or_wait(out.nbytes()));
switch (out.dtype()) {
case bool_:
throw std::runtime_error("Bool type unsupported for arange.");
break;
case uint8:
arange<uint8_t>(start_, start_ + step_, out, out.size(), stream());
break;
case uint16:
arange<uint16_t>(start_, start_ + step_, out, out.size(), stream());
break;
case uint32:
arange<uint32_t>(start_, start_ + step_, out, out.size(), stream());
break;
case uint64:
arange<uint64_t>(start_, start_ + step_, out, out.size(), stream());
break;
case int8:
arange<int8_t>(start_, start_ + step_, out, out.size(), stream());
break;
case int16:
arange<int16_t>(start_, start_ + step_, out, out.size(), stream());
break;
case int32:
arange<int32_t>(start_, start_ + step_, out, out.size(), stream());
break;
case int64:
arange<int64_t>(start_, start_ + step_, out, out.size(), stream());
break;
case float16:
arange<float16_t>(start_, start_ + step_, out, out.size(), stream());
break;
case float32:
arange<float>(start_, start_ + step_, out, out.size(), stream());
break;
case float64:
arange<double>(start_, start_ + step_, out, out.size(), stream());
break;
case bfloat16:
arange<bfloat16_t>(start_, start_ + step_, out, out.size(), stream());
break;
case complex64:
arange<complex64_t>(start_, start_ + step_, out, out.size(), stream());
break;
}
}
void AsType::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
auto& in = inputs[0];
CopyType ctype = in.flags().contiguous ? CopyType::Vector : CopyType::General;
copy(in, out, ctype);
copy(in, out, ctype, stream());
}
void Concatenate::eval_cpu(const std::vector<array>& inputs, array& out) {
@@ -134,7 +198,7 @@ void Concatenate::eval_cpu(const std::vector<array>& inputs, array& out) {
size_t data_offset = strides[axis_] * sizes[i];
out_slice.copy_shared_buffer(
out, strides, flags, out_slice.size(), data_offset);
copy_inplace(inputs[i], out_slice, CopyType::GeneralGeneral);
copy_inplace(inputs[i], out_slice, CopyType::GeneralGeneral, stream());
}
}
@@ -145,7 +209,7 @@ void Contiguous::eval_cpu(const std::vector<array>& inputs, array& out) {
(allow_col_major_ && in.flags().col_contiguous)) {
out.copy_shared_buffer(in);
} else {
copy(in, out, CopyType::General);
copy(in, out, CopyType::General, stream());
}
}
@@ -169,14 +233,7 @@ void Full::eval_cpu(const std::vector<array>& inputs, array& out) {
} else {
ctype = CopyType::General;
}
copy(in, out, ctype);
}
void Load::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 0);
out.set_data(allocator::malloc_or_wait(out.nbytes()));
load(out, offset_, reader_, swap_endianness_);
copy(in, out, ctype, stream());
}
void Pad::eval_cpu(const std::vector<array>& inputs, array& out) {
@@ -192,7 +249,7 @@ void Pad::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(val.dtype() == in.dtype() && in.dtype() == out.dtype());
// Fill output with val
copy(val, out, CopyType::Scalar);
copy(val, out, CopyType::Scalar, stream());
// Find offset for start of input values
size_t data_offset = 0;
@@ -207,7 +264,7 @@ void Pad::eval_cpu(const std::vector<array>& inputs, array& out) {
out, out.strides(), out.flags(), out_slice.size(), data_offset);
// Copy input values into the slice
copy_inplace(in, out_slice, CopyType::GeneralGeneral);
copy_inplace(in, out_slice, CopyType::GeneralGeneral, stream());
}
void RandomBits::eval_cpu(const std::vector<array>& inputs, array& out) {
@@ -223,39 +280,49 @@ void RandomBits::eval_cpu(const std::vector<array>& inputs, array& out) {
auto kptr = inputs[0].data<uint32_t>();
auto cptr = out.data<char>();
size_t out_skip = (bytes_per_key + 4 - 1) / 4;
auto half_size = out_skip / 2;
bool even = out_skip % 2 == 0;
for (int i = 0; i < num_keys; ++i, cptr += bytes_per_key) {
auto ptr = reinterpret_cast<uint32_t*>(cptr);
// Get ith key
auto kidx = 2 * i;
auto k1_elem = elem_to_loc(kidx, keys.shape(), keys.strides());
auto k2_elem = elem_to_loc(kidx + 1, keys.shape(), keys.strides());
auto key = std::make_pair(kptr[k1_elem], kptr[k2_elem]);
auto& encoder = cpu::get_command_encoder(stream());
encoder.set_input_array(inputs[0]);
encoder.set_output_array(out);
encoder.dispatch([kptr,
cptr,
bytes_per_key,
num_keys,
kshape = keys.shape(),
kstrides = keys.strides()]() mutable {
size_t out_skip = (bytes_per_key + 4 - 1) / 4;
auto half_size = out_skip / 2;
bool even = out_skip % 2 == 0;
for (int i = 0; i < num_keys; ++i, cptr += bytes_per_key) {
auto ptr = reinterpret_cast<uint32_t*>(cptr);
// Get ith key
auto kidx = 2 * i;
auto k1_elem = elem_to_loc(kidx, kshape, kstrides);
auto k2_elem = elem_to_loc(kidx + 1, kshape, kstrides);
auto key = std::make_pair(kptr[k1_elem], kptr[k2_elem]);
std::pair<uintptr_t, uintptr_t> count{0, half_size + !even};
for (; count.first + 1 < half_size; count.first++, count.second++) {
std::tie(ptr[count.first], ptr[count.second]) =
random::threefry2x32_hash(key, count);
}
if (count.first < half_size) {
auto rb = random::threefry2x32_hash(key, count);
ptr[count.first++] = rb.first;
if (bytes_per_key % 4 > 0) {
std::copy(
reinterpret_cast<char*>(&rb.second),
reinterpret_cast<char*>(&rb.second) + bytes_per_key % 4,
cptr + 4 * count.second);
} else {
ptr[count.second] = rb.second;
std::pair<uintptr_t, uintptr_t> count{0, half_size + !even};
for (; count.first + 1 < half_size; count.first++, count.second++) {
std::tie(ptr[count.first], ptr[count.second]) =
random::threefry2x32_hash(key, count);
}
if (count.first < half_size) {
auto rb = random::threefry2x32_hash(key, count);
ptr[count.first++] = rb.first;
if (bytes_per_key % 4 > 0) {
std::copy(
reinterpret_cast<char*>(&rb.second),
reinterpret_cast<char*>(&rb.second) + bytes_per_key % 4,
cptr + 4 * count.second);
} else {
ptr[count.second] = rb.second;
}
}
if (!even) {
count.second = 0;
ptr[half_size] = random::threefry2x32_hash(key, count).first;
}
}
if (!even) {
count.second = 0;
ptr[half_size] = random::threefry2x32_hash(key, count).first;
}
}
});
}
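
threefry2x32 is a counter-based generator: each pair of 32-bit output words is a pure function of (key, counter), which is what lets the loop above move wholesale into a dispatched task with no shared RNG state. A scalar sketch assuming the random::threefry2x32_hash helper used above; the real code splits counters across buffer halves, while this version uses consecutive counters for brevity:

// Sketch: counter-based generation; out receives n words (n even here).
void fill_random_words(
    uint32_t* out,
    size_t n,
    std::pair<uint32_t, uint32_t> key) {
  for (size_t i = 0; i + 1 < n; i += 2) {
    // Each call is independent: same (key, counter) -> same two words.
    auto rb = random::threefry2x32_hash(key, {i, i + 1});
    out[i] = rb.first;
    out[i + 1] = rb.second;
  }
}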
void Reshape::eval_cpu(const std::vector<array>& inputs, array& out) {
@@ -269,16 +336,23 @@ void DynamicSlice::eval_cpu(const std::vector<array>& inputs, array& out) {
}
auto& in = inputs[0];
out.set_data(allocator::malloc_or_wait(out.nbytes()));
auto i_offset = compute_dynamic_offset(inputs[1], in.strides(), axes_);
auto [in_offset, donated] =
compute_dynamic_offset(inputs[1], in.strides(), axes_, stream());
copy_inplace(
/* const array& src = */ in,
/* array& dst = */ out,
/* const Shape& data_shape = */ out.shape(),
/* const Strides& i_strides = */ in.strides(),
/* const Strides& o_strides = */ out.strides(),
/* int64_t i_offset = */ i_offset,
/* int64_t i_offset = */ 0,
/* int64_t o_offset = */ 0,
/* CopyType ctype = */ CopyType::GeneralGeneral);
/* CopyType ctype = */ CopyType::GeneralGeneral,
stream(),
/* const std::optional<array>& dynamic_i_offset = */ in_offset,
/* const std::optional<array>& dynamic_o_offset = */ std::nullopt);
if (!donated) {
cpu::get_command_encoder(stream()).add_temporary(std::move(in_offset));
}
}
void DynamicSliceUpdate::eval_cpu(
@@ -296,9 +370,10 @@ void DynamicSliceUpdate::eval_cpu(
auto ctype = in.flags().contiguous && in.size() == in.data_size()
? CopyType::Vector
: CopyType::General;
copy(in, out, in.data_size() == 1 ? CopyType::Scalar : ctype);
copy(in, out, in.data_size() == 1 ? CopyType::Scalar : ctype, stream());
auto o_offset = compute_dynamic_offset(inputs[2], out.strides(), axes_);
auto [out_offset, donated] =
compute_dynamic_offset(inputs[2], out.strides(), axes_, stream());
copy_inplace(
/* const array& src = */ upd,
/* array& dst = */ out,
@@ -306,8 +381,14 @@ void DynamicSliceUpdate::eval_cpu(
/* const std::vector<stride_t>& i_strides = */ upd.strides(),
/* const std::vector<stride_t>& o_strides = */ out.strides(),
/* int64_t i_offset = */ 0,
/* int64_t o_offset = */ o_offset,
/* CopyType ctype = */ CopyType::GeneralGeneral);
/* int64_t o_offset = */ 0,
/* CopyType ctype = */ CopyType::GeneralGeneral,
stream(),
/* const std::optional<array>& dynamic_i_offset = */ std::nullopt,
/* const std::optional<array>& dynamic_o_offset = */ out_offset);
if (!donated) {
cpu::get_command_encoder(stream()).add_temporary(std::move(out_offset));
}
}
void SliceUpdate::eval_cpu(const std::vector<array>& inputs, array& out) {
@@ -329,7 +410,7 @@ void SliceUpdate::eval_cpu(const std::vector<array>& inputs, array& out) {
auto ctype = in.flags().contiguous && in.size() == in.data_size()
? CopyType::Vector
: CopyType::General;
copy(in, out, in.data_size() == 1 ? CopyType::Scalar : ctype);
copy(in, out, in.data_size() == 1 ? CopyType::Scalar : ctype, stream());
// Calculate out strides, initial offset and if copy needs to be made
auto [data_offset, out_strides] =
@@ -344,7 +425,8 @@ void SliceUpdate::eval_cpu(const std::vector<array>& inputs, array& out) {
/* const std::vector<stride_t>& o_strides = */ out_strides,
/* int64_t i_offset = */ 0,
/* int64_t o_offset = */ data_offset,
/* CopyType ctype = */ CopyType::GeneralGeneral);
/* CopyType ctype = */ CopyType::GeneralGeneral,
stream());
}
void View::eval_cpu(const std::vector<array>& inputs, array& out) {
@@ -372,9 +454,9 @@ void View::eval_cpu(const std::vector<array>& inputs, array& out) {
if (in.dtype() == bool_) {
auto in_tmp = array(in.shape(), uint8, nullptr, {});
in_tmp.copy_shared_buffer(in);
copy_inplace(in_tmp, tmp, CopyType::General);
copy_inplace(in_tmp, tmp, CopyType::General, stream());
} else {
copy_inplace(in, tmp, CopyType::General);
copy_inplace(in, tmp, CopyType::General, stream());
}
auto flags = out.flags();
@@ -382,7 +464,7 @@ void View::eval_cpu(const std::vector<array>& inputs, array& out) {
flags.row_contiguous = true;
auto max_dim = std::max_element(out.shape().begin(), out.shape().end());
flags.col_contiguous = out.size() <= 1 || out.size() == *max_dim;
out.move_shared_buffer(tmp, out.strides(), flags, out.size());
out.copy_shared_buffer(tmp, out.strides(), flags, out.size());
}
}

View File

@@ -2,20 +2,18 @@
#include "mlx/allocator.h"
#include "mlx/backend/cpu/copy.h"
#include "mlx/backend/cpu/encoder.h"
#include "mlx/backend/cpu/lapack.h"
#include "mlx/primitives.h"
namespace mlx::core {
template <typename T>
void qrf_impl(const array& a, array& q, array& r) {
void qrf_impl(const array& a, array& q, array& r, Stream stream) {
const int M = a.shape(-2);
const int N = a.shape(-1);
const int lda = M;
size_t num_matrices = a.size() / (M * N);
int num_reflectors = std::min(M, N);
auto tau =
allocator::malloc_or_wait(sizeof(T) * num_matrices * num_reflectors);
// Copy A into an in-place input buffer and make it col-contiguous

array in(a.shape(), a.dtype(), nullptr, {});
@@ -29,93 +27,107 @@ void qrf_impl(const array& a, array& q, array& r) {
strides[in.ndim() - 1] = M;
in.set_data(
allocator::malloc_or_wait(in.nbytes()), in.nbytes(), strides, flags);
copy_inplace(a, in, CopyType::GeneralGeneral);
T optimal_work;
int lwork = -1;
int info;
// Compute workspace size
geqrf<T>(&M, &N, nullptr, &lda, nullptr, &optimal_work, &lwork, &info);
// Update workspace size
lwork = optimal_work;
auto work = allocator::malloc_or_wait(sizeof(T) * lwork);
// Loop over matrices
for (int i = 0; i < num_matrices; ++i) {
// Solve
geqrf<T>(
&M,
&N,
in.data<T>() + M * N * i,
&lda,
static_cast<T*>(tau.raw_ptr()) + num_reflectors * i,
static_cast<T*>(work.raw_ptr()),
&lwork,
&info);
}
allocator::free(work);
copy_inplace(a, in, CopyType::GeneralGeneral, stream);
auto& encoder = cpu::get_command_encoder(stream);
q.set_data(allocator::malloc_or_wait(q.nbytes()));
r.set_data(allocator::malloc_or_wait(r.nbytes()));
for (int i = 0; i < num_matrices; ++i) {
/// num_reflectors x N
for (int j = 0; j < r.shape(-2); ++j) {
for (int k = 0; k < j; ++k) {
r.data<T>()[i * N * num_reflectors + j * N + k] = 0;
}
for (int k = j; k < r.shape(-1); ++k) {
r.data<T>()[i * N * num_reflectors + j * N + k] =
in.data<T>()[i * N * M + j + k * M];
auto in_ptr = in.data<T>();
auto r_ptr = r.data<T>();
auto q_ptr = q.data<T>();
encoder.set_input_array(in);
encoder.set_output_array(q);
encoder.set_output_array(r);
encoder.dispatch([in_ptr, q_ptr, r_ptr, M, N, lda, num_matrices]() {
int num_reflectors = std::min(M, N);
auto tau =
allocator::malloc_or_wait(sizeof(T) * num_matrices * num_reflectors);
T optimal_work;
int lwork = -1;
int info;
// Compute workspace size
geqrf<T>(&M, &N, nullptr, &lda, nullptr, &optimal_work, &lwork, &info);
// Update workspace size
lwork = optimal_work;
auto work = allocator::malloc_or_wait(sizeof(T) * lwork);
// Loop over matrices
for (int i = 0; i < num_matrices; ++i) {
// Solve
geqrf<T>(
&M,
&N,
in_ptr + M * N * i,
&lda,
static_cast<T*>(tau.raw_ptr()) + num_reflectors * i,
static_cast<T*>(work.raw_ptr()),
&lwork,
&info);
}
allocator::free(work);
for (int i = 0; i < num_matrices; ++i) {
/// num_reflectors x N
for (int j = 0; j < num_reflectors; ++j) {
for (int k = 0; k < j; ++k) {
r_ptr[i * N * num_reflectors + j * N + k] = 0;
}
for (int k = j; k < N; ++k) {
r_ptr[i * N * num_reflectors + j * N + k] =
in_ptr[i * N * M + j + k * M];
}
}
}
}
// Get work size
lwork = -1;
orgqr<T>(
&M,
&num_reflectors,
&num_reflectors,
nullptr,
&lda,
nullptr,
&optimal_work,
&lwork,
&info);
lwork = optimal_work;
work = allocator::malloc_or_wait(sizeof(T) * lwork);
// Loop over matrices
for (int i = 0; i < num_matrices; ++i) {
// Compute Q
// Get work size
lwork = -1;
orgqr<T>(
&M,
&num_reflectors,
&num_reflectors,
in.data<T>() + M * N * i,
nullptr,
&lda,
static_cast<T*>(tau.raw_ptr()) + num_reflectors * i,
static_cast<T*>(work.raw_ptr()),
nullptr,
&optimal_work,
&lwork,
&info);
}
lwork = optimal_work;
work = allocator::malloc_or_wait(sizeof(T) * lwork);
q.set_data(allocator::malloc_or_wait(q.nbytes()));
for (int i = 0; i < num_matrices; ++i) {
// M x num_reflectors
for (int j = 0; j < q.shape(-2); ++j) {
for (int k = 0; k < q.shape(-1); ++k) {
q.data<T>()[i * M * num_reflectors + j * num_reflectors + k] =
in.data<T>()[i * N * M + j + k * M];
// Loop over matrices
for (int i = 0; i < num_matrices; ++i) {
// Compute Q
orgqr<T>(
&M,
&num_reflectors,
&num_reflectors,
in_ptr + M * N * i,
&lda,
static_cast<T*>(tau.raw_ptr()) + num_reflectors * i,
static_cast<T*>(work.raw_ptr()),
&lwork,
&info);
}
for (int i = 0; i < num_matrices; ++i) {
// M x num_reflectors
for (int j = 0; j < M; ++j) {
for (int k = 0; k < num_reflectors; ++k) {
q_ptr[i * M * num_reflectors + j * num_reflectors + k] =
in_ptr[i * N * M + j + k * M];
}
}
}
}
// Cleanup
allocator::free(work);
allocator::free(tau);
// Cleanup
allocator::free(work);
allocator::free(tau);
});
encoder.add_temporary(in);
}
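
The geqrf/orgqr calls inside the dispatched lambda follow the standard LAPACK two-call convention: query the optimal workspace size with lwork = -1, allocate, then compute. A single-matrix float sketch using the same wrappers (a_ptr is column-major M x N; tau_ptr holds min(M, N) reflectors):

// Sketch of the LAPACK workspace-query idiom used above.
void qr_factor_one(float* a_ptr, float* tau_ptr, int M, int N) {
  int lda = M;
  int info;
  int lwork = -1; // -1 asks geqrf for the optimal workspace size
  float optimal_work;
  geqrf<float>(&M, &N, nullptr, &lda, nullptr, &optimal_work, &lwork, &info);
  lwork = optimal_work; // optimal size is returned through the work argument
  auto work = allocator::malloc_or_wait(sizeof(float) * lwork);
  geqrf<float>(
      &M, &N, a_ptr, &lda, tau_ptr,
      static_cast<float*>(work.raw_ptr()), &lwork, &info);
  allocator::free(work);
}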
void QRF::eval_cpu(
@@ -123,10 +135,10 @@ void QRF::eval_cpu(
std::vector<array>& outputs) {
switch (inputs[0].dtype()) {
case float32:
qrf_impl<float>(inputs[0], outputs[0], outputs[1]);
qrf_impl<float>(inputs[0], outputs[0], outputs[1], stream());
break;
case float64:
qrf_impl<double>(inputs[0], outputs[0], outputs[1]);
qrf_impl<double>(inputs[0], outputs[0], outputs[1], stream());
break;
default:
throw std::runtime_error(

View File

@@ -3,6 +3,7 @@
#include <cassert>
#include "mlx/backend/cpu/copy.h"
#include "mlx/backend/cpu/encoder.h"
#include "mlx/backend/cpu/simd/simd.h"
#include "mlx/fast_primitives.h"
#include "mlx/primitives.h"
@@ -316,6 +317,76 @@ void _qmm_dispatch_typed(
}
}
template <typename T>
void _qmm_dispatch_typed(
array& out,
const array& x,
const array& w,
const array& scales,
const array& biases,
int bits,
int group_size,
bool transposed_w,
Stream stream) {
int K = x.shape(-1);
int M = x.ndim() > 1 ? x.shape(-2) : 1;
int N = out.shape(-1);
int w_els = w.ndim() > 2 ? w.shape(-1) * w.shape(-2) : 0;
int g_els = w.ndim() > 2 ? scales.shape(-1) * scales.shape(-2) : 0;
int batch_size = x.size() / (K * M);
auto& encoder = cpu::get_command_encoder(stream);
encoder.set_input_array(x);
encoder.set_input_array(w);
encoder.set_input_array(scales);
encoder.set_input_array(biases);
encoder.set_output_array(out);
auto out_ptr = out.data<T>();
auto x_ptr = x.data<T>();
auto w_ptr = w.data<uint32_t>();
auto scales_ptr = scales.data<T>();
auto biases_ptr = biases.data<T>();
encoder.dispatch([out_ptr,
x_ptr,
w_ptr,
scales_ptr,
biases_ptr,
x_shape = x.shape(),
x_strides = x.strides(),
w_shape = w.shape(),
w_strides = w.strides(),
scales_shape = scales.shape(),
scales_strides = scales.strides(),
biases_shape = biases.shape(),
biases_strides = biases.strides(),
w_els,
g_els,
batch_size,
M,
N,
K,
bits,
group_size,
transposed_w] {
for (int i = 0; i < batch_size; i++) {
_qmm_dispatch_typed<T>(
out_ptr + i * M * N,
x_ptr + elem_to_loc(i * M * K, x_shape, x_strides),
w_ptr + elem_to_loc(i * w_els, w_shape, w_strides),
scales_ptr + elem_to_loc(i * g_els, scales_shape, scales_strides),
biases_ptr + elem_to_loc(i * g_els, biases_shape, biases_strides),
M,
N,
K,
bits,
group_size,
transposed_w);
}
});
}
void _qmm_dispatch(
array& out,
const array& x,
@@ -324,64 +395,111 @@ void _qmm_dispatch(
const array& biases,
int bits,
int group_size,
bool transposed_w) {
bool transposed_w,
Stream stream) {
switch (x.dtype()) {
case float32:
_qmm_dispatch_typed<float>(
out, x, w, scales, biases, bits, group_size, transposed_w, stream);
break;
case float16:
_qmm_dispatch_typed<float16_t>(
out, x, w, scales, biases, bits, group_size, transposed_w, stream);
break;
case bfloat16:
_qmm_dispatch_typed<bfloat16_t>(
out, x, w, scales, biases, bits, group_size, transposed_w, stream);
break;
default:
throw std::invalid_argument(
"[quantized_matmul] only floating types are supported");
}
}
template <typename T>
void _bs_qmm_dispatch_typed(
array& out,
const array& x,
const array& w,
const array& scales,
const array& biases,
const array& lhs_indices,
const array& rhs_indices,
int bits,
int group_size,
bool transposed_w,
Stream stream) {
int K = x.shape(-1);
int M = x.ndim() > 1 ? x.shape(-2) : 1;
int M = x.shape(-2);
int N = out.shape(-1);
int w_els = w.ndim() > 2 ? w.shape(-1) * w.shape(-2) : 0;
int g_els = w.ndim() > 2 ? scales.shape(-1) * scales.shape(-2) : 0;
int w_els = w.shape(-1) * w.shape(-2);
int g_els = scales.shape(-1) * scales.shape(-2);
int batch_size = x.size() / (K * M);
for (int i = 0; i < batch_size; i++) {
switch (x.dtype()) {
case float32:
_qmm_dispatch_typed<float>(
out.data<float>() + i * M * N,
x.data<float>() + elem_to_loc(i * M * K, x),
w.data<uint32_t>() + elem_to_loc(i * w_els, w),
scales.data<float>() + elem_to_loc(i * g_els, scales),
biases.data<float>() + elem_to_loc(i * g_els, biases),
M,
N,
K,
bits,
group_size,
transposed_w);
break;
case float16:
_qmm_dispatch_typed<float16_t>(
out.data<float16_t>() + i * M * N,
x.data<float16_t>() + elem_to_loc(i * M * K, x),
w.data<uint32_t>() + elem_to_loc(i * w_els, w),
scales.data<float16_t>() + elem_to_loc(i * g_els, scales),
biases.data<float16_t>() + elem_to_loc(i * g_els, biases),
M,
N,
K,
bits,
group_size,
transposed_w);
break;
case bfloat16:
_qmm_dispatch_typed<bfloat16_t>(
out.data<bfloat16_t>() + i * M * N,
x.data<bfloat16_t>() + elem_to_loc(i * M * K, x),
w.data<uint32_t>() + elem_to_loc(i * w_els, w),
scales.data<bfloat16_t>() + elem_to_loc(i * g_els, scales),
biases.data<bfloat16_t>() + elem_to_loc(i * g_els, biases),
M,
N,
K,
bits,
group_size,
transposed_w);
break;
default:
throw std::invalid_argument(
"[quantized_matmul] only floating types are supported");
auto& encoder = cpu::get_command_encoder(stream);
encoder.set_input_array(x);
encoder.set_input_array(w);
encoder.set_input_array(scales);
encoder.set_input_array(biases);
encoder.set_input_array(lhs_indices);
encoder.set_input_array(rhs_indices);
encoder.set_output_array(out);
auto out_ptr = out.data<T>();
auto x_ptr = x.data<T>();
auto w_ptr = w.data<uint32_t>();
auto scales_ptr = scales.data<T>();
auto biases_ptr = biases.data<T>();
auto lhs_indices_ptr = lhs_indices.data<uint32_t>();
auto rhs_indices_ptr = rhs_indices.data<uint32_t>();
encoder.dispatch([out_ptr,
x_ptr,
w_ptr,
scales_ptr,
biases_ptr,
lhs_indices_ptr,
rhs_indices_ptr,
x_shape = x.shape(),
x_strides = x.strides(),
w_shape = w.shape(),
w_strides = w.strides(),
scales_shape = scales.shape(),
scales_strides = scales.strides(),
biases_shape = biases.shape(),
biases_strides = biases.strides(),
lhs_indices_shape = lhs_indices.shape(),
lhs_indices_strides = lhs_indices.strides(),
rhs_indices_shape = rhs_indices.shape(),
rhs_indices_strides = rhs_indices.strides(),
w_els,
g_els,
indices_size = lhs_indices.size(),
M,
N,
K,
bits,
group_size,
transposed_w]() {
for (int i = 0; i < indices_size; i++) {
int x_idx = lhs_indices_ptr[elem_to_loc(
i, lhs_indices_shape, lhs_indices_strides)];
int w_idx = rhs_indices_ptr[elem_to_loc(
i, rhs_indices_shape, rhs_indices_strides)];
_qmm_dispatch_typed<T>(
out_ptr + i * M * N,
x_ptr + elem_to_loc(x_idx * M * K, x_shape, x_strides),
w_ptr + elem_to_loc(w_idx * w_els, w_shape, w_strides),
scales_ptr + elem_to_loc(w_idx * g_els, scales_shape, scales_strides),
biases_ptr + elem_to_loc(w_idx * g_els, biases_shape, biases_strides),
M,
N,
K,
bits,
group_size,
transposed_w);
}
}
});
}
void _bs_qmm_dispatch(
@@ -394,68 +512,54 @@ void _bs_qmm_dispatch(
const array& rhs_indices,
int bits,
int group_size,
bool transposed_w) {
int K = x.shape(-1);
int M = x.shape(-2);
int N = out.shape(-1);
int w_els = w.shape(-1) * w.shape(-2);
int g_els = scales.shape(-1) * scales.shape(-2);
const uint32_t* lhs_indices_data = lhs_indices.data<uint32_t>();
const uint32_t* rhs_indices_data = rhs_indices.data<uint32_t>();
for (int i = 0; i < lhs_indices.size(); i++) {
int x_idx = lhs_indices_data[elem_to_loc(i, lhs_indices)];
int w_idx = rhs_indices_data[elem_to_loc(i, rhs_indices)];
switch (x.dtype()) {
case float32:
_qmm_dispatch_typed<float>(
out.data<float>() + i * M * N,
x.data<float>() + elem_to_loc(x_idx * M * K, x),
w.data<uint32_t>() + elem_to_loc(w_idx * w_els, w),
scales.data<float>() + elem_to_loc(w_idx * g_els, scales),
biases.data<float>() + elem_to_loc(w_idx * g_els, biases),
M,
N,
K,
bits,
group_size,
transposed_w);
break;
case float16:
_qmm_dispatch_typed<float16_t>(
out.data<float16_t>() + i * M * N,
x.data<float16_t>() + elem_to_loc(x_idx * M * K, x),
w.data<uint32_t>() + elem_to_loc(w_idx * w_els, w),
scales.data<float16_t>() + elem_to_loc(w_idx * g_els, scales),
biases.data<float16_t>() + elem_to_loc(w_idx * g_els, biases),
M,
N,
K,
bits,
group_size,
transposed_w);
break;
case bfloat16:
_qmm_dispatch_typed<bfloat16_t>(
out.data<bfloat16_t>() + i * M * N,
x.data<bfloat16_t>() + elem_to_loc(x_idx * M * K, x),
w.data<uint32_t>() + elem_to_loc(w_idx * w_els, w),
scales.data<bfloat16_t>() + elem_to_loc(w_idx * g_els, scales),
biases.data<bfloat16_t>() + elem_to_loc(w_idx * g_els, biases),
M,
N,
K,
bits,
group_size,
transposed_w);
break;
default:
throw std::invalid_argument(
"[quantized_matmul] only floating types are supported");
}
bool transposed_w,
Stream stream) {
switch (x.dtype()) {
case float32:
_bs_qmm_dispatch_typed<float>(
out,
x,
w,
scales,
biases,
lhs_indices,
rhs_indices,
bits,
group_size,
transposed_w,
stream);
break;
case float16:
_bs_qmm_dispatch_typed<float16_t>(
out,
x,
w,
scales,
biases,
lhs_indices,
rhs_indices,
bits,
group_size,
transposed_w,
stream);
break;
case bfloat16:
_bs_qmm_dispatch_typed<bfloat16_t>(
out,
x,
w,
scales,
biases,
lhs_indices,
rhs_indices,
bits,
group_size,
transposed_w,
stream);
break;
default:
throw std::invalid_argument(
"[quantized_matmul] only floating types are supported");
}
}
@@ -469,13 +573,14 @@ void QuantizedMatmul::eval_cpu(const std::vector<array>& inputs, array& out) {
auto& scales_pre = inputs[2];
auto& biases_pre = inputs[3];
auto ensure_row_contiguous = [](const array& arr) {
std::vector<array> temps;
auto ensure_row_contiguous = [s = stream(), &temps](const array& arr) {
if (arr.flags().row_contiguous) {
return arr;
} else {
array arr_copy(arr.shape(), arr.dtype(), nullptr, {});
copy(arr, arr_copy, CopyType::General);
return arr_copy;
temps.push_back(array(arr.shape(), arr.dtype(), nullptr, {}));
copy(arr, temps.back(), CopyType::General, s);
return temps.back();
}
};
@@ -485,7 +590,10 @@ void QuantizedMatmul::eval_cpu(const std::vector<array>& inputs, array& out) {
auto biases = ensure_row_contiguous(biases_pre);
out.set_data(allocator::malloc_or_wait(out.nbytes()));
_qmm_dispatch(out, x, w, scales, biases, group_size_, bits_, transpose_);
_qmm_dispatch(
out, x, w, scales, biases, group_size_, bits_, transpose_, stream());
auto& enc = cpu::get_command_encoder(stream());
enc.add_temporaries(std::move(temps));
}
void GatherQMM::eval_cpu(const std::vector<array>& inputs, array& out) {
@@ -498,15 +606,17 @@ void GatherQMM::eval_cpu(const std::vector<array>& inputs, array& out) {
auto& lhs_indices = inputs[4];
auto& rhs_indices = inputs[5];
auto ensure_row_contiguous_last_dims = [](const array& arr) {
std::vector<array> temps;
auto ensure_row_contiguous_last_dims = [s = stream(),
&temps](const array& arr) {
auto stride_0 = arr.strides()[arr.ndim() - 2];
auto stride_1 = arr.strides()[arr.ndim() - 1];
if (stride_0 == arr.shape(-1) && stride_1 == 1) {
return arr;
} else {
array arr_copy(arr.shape(), arr.dtype(), nullptr, {});
copy(arr, arr_copy, CopyType::General);
return arr_copy;
temps.push_back(array(arr.shape(), arr.dtype(), nullptr, {}));
copy(arr, temps.back(), CopyType::General, s);
return temps.back();
}
};
@@ -526,31 +636,30 @@ void GatherQMM::eval_cpu(const std::vector<array>& inputs, array& out) {
rhs_indices,
group_size_,
bits_,
transpose_);
transpose_,
stream());
auto& enc = cpu::get_command_encoder(stream());
enc.add_temporaries(std::move(temps));
}
template <typename T, typename U>
void quantize(
const array& w_,
array& out_,
array& scales_,
array& biases_,
const T* w,
U* out,
T* scales,
T* biases,
int bits,
int group_size) {
const T* w = w_.data<T>();
auto out = out_.data<U>();
T* scales = scales_.data<T>();
T* biases = biases_.data<T>();
int group_size,
size_t w_size) {
float n_bins = (1 << bits) - 1;
float eps = 1e-7;
bool power_of_2_bits = is_power_of_2(bits);
int el_per_int = bits == 3 ? 8 : bits == 6 ? 4 : 32 / bits;
// For 3/6 bits we read 3 uint8s at a time instead of 1 uint32
int bytes_per_pack = power_of_2_bits ? 1 : 3;
int int_per_group = group_size * bytes_per_pack / el_per_int;
size_t n_groups = w_.size() / group_size;
size_t n_groups = w_size / group_size;
for (size_t i = 0; i < n_groups; ++i) {
size_t w_idx = i * group_size;
@@ -593,20 +702,50 @@ void quantize(
}
}
template <typename T, typename U>
void dispatch_quantize(
const array& w,
array& out,
array& scales,
array& biases,
int bits,
int group_size,
Stream stream) {
auto w_ptr = w.data<T>();
auto out_ptr = out.data<U>();
auto scales_ptr = scales.data<T>();
auto biases_ptr = biases.data<T>();
auto& encoder = cpu::get_command_encoder(stream);
encoder.set_input_array(w);
encoder.set_input_array(scales);
encoder.set_input_array(biases);
encoder.set_output_array(out);
encoder.dispatch([w_ptr,
out_ptr,
scales_ptr,
biases_ptr,
bits,
group_size,
w_size = w.size()]() {
quantize<T, U>(
w_ptr, out_ptr, scales_ptr, biases_ptr, bits, group_size, w_size);
});
}
void fast::AffineQuantize::eval_cpu(
const std::vector<array>& inputs,
std::vector<array>& outputs) {
auto ensure_row_contiguous = [](const array& arr) {
auto ensure_row_contiguous = [s = stream()](const array& arr) {
if (arr.flags().row_contiguous) {
return arr;
return std::make_pair(arr, false);
} else {
array arr_copy(arr.shape(), arr.dtype(), nullptr, {});
copy(arr, arr_copy, CopyType::General);
return arr_copy;
copy(arr, arr_copy, CopyType::General, s);
return std::make_pair(arr_copy, true);
}
};
auto w = ensure_row_contiguous(inputs[0]);
auto [w, copied] = ensure_row_contiguous(inputs[0]);
auto& out = outputs[0];
out.set_data(allocator::malloc_or_wait(out.nbytes()));
@@ -616,27 +755,35 @@ void fast::AffineQuantize::eval_cpu(
biases.set_data(allocator::malloc_or_wait(biases.nbytes()));
if (w.dtype() == float16) {
if (is_power_of_2(bits_)) {
quantize<float16_t, uint32_t>(w, out, scales, biases, bits_, group_size_);
dispatch_quantize<float16_t, uint32_t>(
w, out, scales, biases, bits_, group_size_, stream());
} else {
quantize<float16_t, uint8_t>(w, out, scales, biases, bits_, group_size_);
dispatch_quantize<float16_t, uint8_t>(
w, out, scales, biases, bits_, group_size_, stream());
}
} else if (w.dtype() == bfloat16) {
if (is_power_of_2(bits_)) {
quantize<bfloat16_t, uint32_t>(
w, out, scales, biases, bits_, group_size_);
dispatch_quantize<bfloat16_t, uint32_t>(
w, out, scales, biases, bits_, group_size_, stream());
} else {
quantize<bfloat16_t, uint8_t>(w, out, scales, biases, bits_, group_size_);
dispatch_quantize<bfloat16_t, uint8_t>(
w, out, scales, biases, bits_, group_size_, stream());
}
} else if (w.dtype() == float32) {
if (is_power_of_2(bits_)) {
quantize<float, uint32_t>(w, out, scales, biases, bits_, group_size_);
dispatch_quantize<float, uint32_t>(
w, out, scales, biases, bits_, group_size_, stream());
} else {
quantize<float, uint8_t>(w, out, scales, biases, bits_, group_size_);
dispatch_quantize<float, uint8_t>(
w, out, scales, biases, bits_, group_size_, stream());
}
} else {
throw std::runtime_error(
"[fast::AffineQuantize::eval_cpu] Only supports floating point inputs");
}
if (copied) {
cpu::get_command_encoder(stream()).add_temporary(w);
}
}
} // namespace mlx::core
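
For reference, the affine scheme that quantize implements: per group of group_size weights it stores q = round((w - bias) / scale) in bits bits, with scale = (w_max - w_min) / (2^bits - 1) and bias = w_min. A sketch for one group with power-of-two bits packed into uint32 words; the exact rounding/clamping and the 3- and 6-bit byte-packed path in the diff may differ:

// Sketch: affine-quantize one group (power-of-two bits only).
void quantize_group_sketch(
    const float* w,
    uint32_t* out,
    float& scale,
    float& bias,
    int group_size,
    int bits) {
  float n_bins = (1 << bits) - 1;
  float w_min = w[0];
  float w_max = w[0];
  for (int i = 1; i < group_size; ++i) {
    w_min = std::min(w_min, w[i]);
    w_max = std::max(w_max, w[i]);
  }
  scale = std::max((w_max - w_min) / n_bins, 1e-7f);
  bias = w_min;
  int el_per_int = 32 / bits;
  for (int i = 0; i < group_size / el_per_int; ++i) {
    uint32_t packed = 0;
    for (int j = 0; j < el_per_int; ++j) {
      float q = std::round((w[i * el_per_int + j] - bias) / scale);
      q = std::min(std::max(q, 0.0f), n_bins); // clamp to representable range
      packed |= static_cast<uint32_t>(q) << (j * bits);
    }
    out[i] = packed;
  }
}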

View File

@@ -5,6 +5,7 @@
#include <limits>
#include "mlx/backend/common/reduce.h"
#include "mlx/backend/cpu/encoder.h"
#include "mlx/backend/cpu/simd/simd.h"
#include "mlx/primitives.h"
@@ -140,25 +141,33 @@ void reduction_op(
array& out,
const std::vector<int>& axes,
U init,
Op op) {
Stream stream) {
out.set_data(allocator::malloc_or_wait(out.nbytes()));
ReductionPlan plan = get_reduction_plan(x, axes);
auto& encoder = cpu::get_command_encoder(stream);
encoder.set_input_array(x);
encoder.set_output_array(out);
auto in_ptr = x.data<T>();
auto out_ptr = out.data<U>();
if (plan.type == ContiguousAllReduce) {
U* out_ptr = out.data<U>();
*out_ptr = init;
contiguous_reduce(x.data<T>(), out_ptr, x.size(), op, init);
encoder.dispatch([in_ptr, out_ptr, init, size = x.size()]() {
*out_ptr = init;
contiguous_reduce(in_ptr, out_ptr, size, Op{}, init);
});
return;
}
if (plan.type == ContiguousReduce && plan.shape.size() == 1) {
int reduction_size = plan.shape[0];
const T* x_ptr = x.data<T>();
U* out_ptr = out.data<U>();
for (int i = 0; i < out.size(); i++, out_ptr++, x_ptr += reduction_size) {
*out_ptr = init;
contiguous_reduce(x_ptr, out_ptr, reduction_size, op, init);
}
encoder.dispatch(
[in_ptr, out_ptr, init, reduction_size, size = out.size()]() mutable {
for (int i = 0; i < size; i++, out_ptr++, in_ptr += reduction_size) {
*out_ptr = init;
contiguous_reduce(in_ptr, out_ptr, reduction_size, Op{}, init);
}
});
return;
}
@@ -166,34 +175,43 @@ void reduction_op(
int reduction_size = plan.shape.back();
plan.shape.pop_back();
plan.strides.pop_back();
const T* x_ptr = x.data<T>();
U* out_ptr = out.data<U>();
// Unrolling the following loop (and implementing it in order for
// ContiguousReduce) should yield an extra performance boost.
auto [shape, strides] = shapes_without_reduction_axes(x, axes);
if (plan.shape.size() == 0) {
for (int i = 0; i < out.size(); i++, out_ptr++) {
int offset = elem_to_loc(i, shape, strides);
*out_ptr = init;
contiguous_reduce(x_ptr + offset, out_ptr, reduction_size, op, init);
encoder.dispatch([in_ptr,
out_ptr,
init,
reduction_size,
size = out.size(),
plan = std::move(plan),
shape = std::move(shape),
strides = std::move(strides)]() mutable {
if (plan.shape.size() == 0) {
for (int i = 0; i < size; i++, out_ptr++) {
int offset = elem_to_loc(i, shape, strides);
*out_ptr = init;
contiguous_reduce(
in_ptr + offset, out_ptr, reduction_size, Op{}, init);
}
} else {
for (int i = 0; i < size; i++, out_ptr++) {
int offset = elem_to_loc(i, shape, strides);
*out_ptr = init;
nd_loop(
[&](int extra_offset) {
contiguous_reduce(
in_ptr + offset + extra_offset,
out_ptr,
reduction_size,
Op{},
init);
},
plan.shape,
plan.strides);
}
}
} else {
for (int i = 0; i < out.size(); i++, out_ptr++) {
int offset = elem_to_loc(i, shape, strides);
*out_ptr = init;
nd_loop(
[&](int extra_offset) {
contiguous_reduce(
x_ptr + offset + extra_offset,
out_ptr,
reduction_size,
op,
init);
},
plan.shape,
plan.strides);
}
}
});
return;
}
@@ -202,14 +220,20 @@ void reduction_op(
size_t reduction_stride = plan.strides.back();
plan.shape.pop_back();
plan.strides.pop_back();
const T* x_ptr = x.data<T>();
U* out_ptr = out.data<U>();
for (int i = 0; i < out.size(); i += reduction_stride) {
std::fill_n(out_ptr, reduction_stride, init);
strided_reduce(x_ptr, out_ptr, reduction_size, reduction_stride, op);
x_ptr += reduction_stride * reduction_size;
out_ptr += reduction_stride;
}
encoder.dispatch([in_ptr,
out_ptr,
init,
reduction_size,
reduction_stride,
size = out.size()]() mutable {
for (int i = 0; i < size; i += reduction_stride) {
std::fill_n(out_ptr, reduction_stride, init);
strided_reduce(in_ptr, out_ptr, reduction_size, reduction_stride, Op{});
in_ptr += reduction_stride * reduction_size;
out_ptr += reduction_stride;
}
});
return;
}
@@ -219,53 +243,69 @@ void reduction_op(
size_t reduction_stride = plan.strides.back();
plan.shape.pop_back();
plan.strides.pop_back();
const T* x_ptr = x.data<T>();
U* out_ptr = out.data<U>();
auto [shape, strides] = shapes_without_reduction_axes(x, axes);
if (plan.shape.size() == 0) {
for (int i = 0; i < out.size(); i += reduction_stride) {
int offset = elem_to_loc(i, shape, strides);
std::fill_n(out_ptr, reduction_stride, init);
strided_reduce(
x_ptr + offset, out_ptr, reduction_size, reduction_stride, op);
out_ptr += reduction_stride;
encoder.dispatch([in_ptr,
out_ptr,
init,
reduction_size,
reduction_stride,
size = out.size(),
plan = std::move(plan),
shape = std::move(shape),
strides = std::move(strides)]() mutable {
if (plan.shape.size() == 0) {
for (int i = 0; i < size; i += reduction_stride) {
int offset = elem_to_loc(i, shape, strides);
std::fill_n(out_ptr, reduction_stride, init);
strided_reduce(
in_ptr + offset, out_ptr, reduction_size, reduction_stride, Op{});
out_ptr += reduction_stride;
}
} else {
for (int i = 0; i < size; i += reduction_stride) {
int offset = elem_to_loc(i, shape, strides);
std::fill_n(out_ptr, reduction_stride, init);
nd_loop(
[&](int extra_offset) {
strided_reduce(
in_ptr + offset + extra_offset,
out_ptr,
reduction_size,
reduction_stride,
Op{});
},
plan.shape,
plan.strides);
out_ptr += reduction_stride;
}
}
} else {
for (int i = 0; i < out.size(); i += reduction_stride) {
int offset = elem_to_loc(i, shape, strides);
std::fill_n(out_ptr, reduction_stride, init);
nd_loop(
[&](int extra_offset) {
strided_reduce(
x_ptr + offset + extra_offset,
out_ptr,
reduction_size,
reduction_stride,
op);
},
plan.shape,
plan.strides);
out_ptr += reduction_stride;
}
}
});
return;
}
if (plan.type == GeneralReduce) {
const T* x_ptr = x.data<T>();
U* out_ptr = out.data<U>();
auto [shape, strides] = shapes_without_reduction_axes(x, axes);
for (int i = 0; i < out.size(); i++, out_ptr++) {
int offset = elem_to_loc(i, shape, strides);
U val = init;
nd_loop(
[&](int extra_offset) {
val = op(val, *(x_ptr + offset + extra_offset));
},
plan.shape,
plan.strides);
*out_ptr = val;
}
encoder.dispatch([in_ptr,
out_ptr,
init,
size = out.size(),
plan = std::move(plan),
shape = std::move(shape),
strides = std::move(strides)]() mutable {
for (int i = 0; i < size; i++, out_ptr++) {
int offset = elem_to_loc(i, shape, strides);
U val = init;
nd_loop(
[&](int extra_offset) {
val = Op{}(val, *(in_ptr + offset + extra_offset));
},
plan.shape,
plan.strides);
*out_ptr = val;
}
});
}
}
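
Note the signature change: the reduce op is now a template parameter and is default-constructed inside each dispatched lambda (Op{}), so the stateless functor never needs to be captured. A sketch of the shape of that change; array tracking on the encoder is elided and SumReduceSketch is hypothetical:

// Sketch: stateless reduce functor passed as a type, built inside the task.
struct SumReduceSketch {
  template <typename T, typename U>
  U operator()(U acc, T x) const {
    return acc + static_cast<U>(x);
  }
};

template <typename T, typename U, typename Op>
void all_reduce_sketch(const T* in, U* out, size_t n, U init, Stream s) {
  auto& encoder = cpu::get_command_encoder(s);
  encoder.dispatch([in, out, n, init]() {
    U acc = init;
    for (size_t i = 0; i < n; ++i) {
      acc = Op{}(acc, in[i]); // default-construct the functor in the task
    }
    *out = acc;
  });
}

Called as, e.g., all_reduce_sketch<float, float, SumReduceSketch>(in_ptr, out_ptr, n, 0.0f, s);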
@@ -394,11 +434,12 @@ void reduce_dispatch_and_or(
const array& in,
array& out,
Reduce::ReduceType rtype,
const std::vector<int>& axes) {
const std::vector<int>& axes,
Stream stream) {
if (rtype == Reduce::And) {
reduction_op<InT, bool>(in, out, axes, true, AndReduce());
reduction_op<InT, bool, AndReduce>(in, out, axes, true, stream);
} else {
reduction_op<InT, bool>(in, out, axes, false, OrReduce());
reduction_op<InT, bool, OrReduce>(in, out, axes, false, stream);
}
}
@@ -407,18 +448,19 @@ void reduce_dispatch_sum_prod(
const array& in,
array& out,
Reduce::ReduceType rtype,
const std::vector<int>& axes) {
const std::vector<int>& axes,
Stream stream) {
if (rtype == Reduce::Sum) {
if constexpr (std::is_integral_v<InT> && sizeof(InT) <= 4) {
reduction_op<InT, int32_t>(in, out, axes, 0, SumReduce());
reduction_op<InT, int32_t, SumReduce>(in, out, axes, 0, stream);
} else {
reduction_op<InT, InT>(in, out, axes, 0, SumReduce());
reduction_op<InT, InT, SumReduce>(in, out, axes, 0, stream);
}
} else {
if constexpr (std::is_integral_v<InT> && sizeof(InT) <= 4) {
reduction_op<InT, int32_t>(in, out, axes, 1, ProdReduce());
reduction_op<InT, int32_t, ProdReduce>(in, out, axes, 1, stream);
} else {
reduction_op<InT, InT>(in, out, axes, 1, ProdReduce());
reduction_op<InT, InT, ProdReduce>(in, out, axes, 1, stream);
}
}
}
@@ -428,13 +470,14 @@ void reduce_dispatch_min_max(
const array& in,
array& out,
Reduce::ReduceType rtype,
const std::vector<int>& axes) {
const std::vector<int>& axes,
Stream stream) {
if (rtype == Reduce::Max) {
auto init = Limits<InT>::min;
reduction_op<InT, InT>(in, out, axes, init, MaxReduce());
reduction_op<InT, InT, MaxReduce>(in, out, axes, init, stream);
} else {
auto init = Limits<InT>::max;
reduction_op<InT, InT>(in, out, axes, init, MinReduce());
reduction_op<InT, InT, MinReduce>(in, out, axes, init, stream);
}
}
@@ -448,24 +491,28 @@ void Reduce::eval_cpu(const std::vector<array>& inputs, array& out) {
case bool_:
case uint8:
case int8:
reduce_dispatch_and_or<int8_t>(in, out, reduce_type_, axes_);
reduce_dispatch_and_or<int8_t>(
in, out, reduce_type_, axes_, stream());
break;
case int16:
case uint16:
case float16:
case bfloat16:
reduce_dispatch_and_or<int16_t>(in, out, reduce_type_, axes_);
reduce_dispatch_and_or<int16_t>(
in, out, reduce_type_, axes_, stream());
break;
case uint32:
case int32:
case float32:
reduce_dispatch_and_or<int32_t>(in, out, reduce_type_, axes_);
reduce_dispatch_and_or<int32_t>(
in, out, reduce_type_, axes_, stream());
break;
case uint64:
case int64:
case float64:
case complex64:
reduce_dispatch_and_or<int64_t>(in, out, reduce_type_, axes_);
reduce_dispatch_and_or<int64_t>(
in, out, reduce_type_, axes_, stream());
break;
}
break;
@@ -476,34 +523,43 @@ void Reduce::eval_cpu(const std::vector<array>& inputs, array& out) {
case bool_:
case uint8:
case int8:
reduce_dispatch_sum_prod<int8_t>(in, out, reduce_type_, axes_);
reduce_dispatch_sum_prod<int8_t>(
in, out, reduce_type_, axes_, stream());
break;
case int16:
case uint16:
reduce_dispatch_sum_prod<int16_t>(in, out, reduce_type_, axes_);
reduce_dispatch_sum_prod<int16_t>(
in, out, reduce_type_, axes_, stream());
break;
case int32:
case uint32:
reduce_dispatch_sum_prod<int32_t>(in, out, reduce_type_, axes_);
reduce_dispatch_sum_prod<int32_t>(
in, out, reduce_type_, axes_, stream());
break;
case int64:
case uint64:
reduce_dispatch_sum_prod<int64_t>(in, out, reduce_type_, axes_);
reduce_dispatch_sum_prod<int64_t>(
in, out, reduce_type_, axes_, stream());
break;
case float16:
reduce_dispatch_sum_prod<float16_t>(in, out, reduce_type_, axes_);
reduce_dispatch_sum_prod<float16_t>(
in, out, reduce_type_, axes_, stream());
break;
case bfloat16:
reduce_dispatch_sum_prod<bfloat16_t>(in, out, reduce_type_, axes_);
reduce_dispatch_sum_prod<bfloat16_t>(
in, out, reduce_type_, axes_, stream());
break;
case float32:
reduce_dispatch_sum_prod<float>(in, out, reduce_type_, axes_);
reduce_dispatch_sum_prod<float>(
in, out, reduce_type_, axes_, stream());
break;
case float64:
reduce_dispatch_sum_prod<double>(in, out, reduce_type_, axes_);
reduce_dispatch_sum_prod<double>(
in, out, reduce_type_, axes_, stream());
break;
case complex64:
reduce_dispatch_sum_prod<complex64_t>(in, out, reduce_type_, axes_);
reduce_dispatch_sum_prod<complex64_t>(
in, out, reduce_type_, axes_, stream());
break;
}
break;
@@ -512,46 +568,59 @@ void Reduce::eval_cpu(const std::vector<array>& inputs, array& out) {
case Reduce::Min: {
switch (in.dtype()) {
case bool_:
reduce_dispatch_min_max<bool>(in, out, reduce_type_, axes_);
reduce_dispatch_min_max<bool>(in, out, reduce_type_, axes_, stream());
break;
case uint8:
reduce_dispatch_min_max<uint8_t>(in, out, reduce_type_, axes_);
reduce_dispatch_min_max<uint8_t>(
in, out, reduce_type_, axes_, stream());
break;
case uint16:
reduce_dispatch_min_max<uint16_t>(in, out, reduce_type_, axes_);
reduce_dispatch_min_max<uint16_t>(
in, out, reduce_type_, axes_, stream());
break;
case uint32:
reduce_dispatch_min_max<uint32_t>(in, out, reduce_type_, axes_);
reduce_dispatch_min_max<uint32_t>(
in, out, reduce_type_, axes_, stream());
break;
case uint64:
reduce_dispatch_min_max<uint64_t>(in, out, reduce_type_, axes_);
reduce_dispatch_min_max<uint64_t>(
in, out, reduce_type_, axes_, stream());
break;
case int8:
reduce_dispatch_min_max<uint8_t>(in, out, reduce_type_, axes_);
reduce_dispatch_min_max<uint8_t>(
in, out, reduce_type_, axes_, stream());
break;
case int16:
reduce_dispatch_min_max<uint16_t>(in, out, reduce_type_, axes_);
reduce_dispatch_min_max<uint16_t>(
in, out, reduce_type_, axes_, stream());
break;
case int32:
reduce_dispatch_min_max<int32_t>(in, out, reduce_type_, axes_);
reduce_dispatch_min_max<int32_t>(
in, out, reduce_type_, axes_, stream());
break;
case int64:
reduce_dispatch_min_max<int64_t>(in, out, reduce_type_, axes_);
reduce_dispatch_min_max<int64_t>(
in, out, reduce_type_, axes_, stream());
break;
case float16:
reduce_dispatch_min_max<float16_t>(in, out, reduce_type_, axes_);
reduce_dispatch_min_max<float16_t>(
in, out, reduce_type_, axes_, stream());
break;
case float32:
reduce_dispatch_min_max<float>(in, out, reduce_type_, axes_);
reduce_dispatch_min_max<float>(
in, out, reduce_type_, axes_, stream());
break;
case float64:
reduce_dispatch_min_max<double>(in, out, reduce_type_, axes_);
reduce_dispatch_min_max<double>(
in, out, reduce_type_, axes_, stream());
break;
case bfloat16:
reduce_dispatch_min_max<bfloat16_t>(in, out, reduce_type_, axes_);
reduce_dispatch_min_max<bfloat16_t>(
in, out, reduce_type_, axes_, stream());
break;
case complex64:
reduce_dispatch_min_max<complex64_t>(in, out, reduce_type_, axes_);
reduce_dispatch_min_max<complex64_t>(
in, out, reduce_type_, axes_, stream());
break;
}
break;

View File

@@ -4,6 +4,7 @@
#include "mlx/backend/common/utils.h"
#include "mlx/backend/cpu/copy.h"
#include "mlx/backend/cpu/encoder.h"
#include "mlx/backend/cpu/simd/simd.h"
#include "mlx/primitives.h"
@@ -153,37 +154,44 @@ void strided_scan(
template <typename T, typename U, typename Op>
void scan_op(
const array& input,
array& output,
const array& in,
array& out,
int axis,
bool reverse,
bool inclusive,
const Op& op,
U init) {
output.set_data(allocator::malloc_or_wait(output.nbytes()));
U init,
Stream stream) {
auto& encoder = cpu::get_command_encoder(stream);
encoder.set_input_array(in);
encoder.set_output_array(out);
if (input.flags().row_contiguous) {
if (input.strides()[axis] == 1) {
contiguous_scan(
input.data<T>(),
output.data<U>(),
input.size() / input.shape(axis),
input.shape(axis),
reverse,
inclusive,
op,
init);
if (in.flags().row_contiguous) {
if (in.strides()[axis] == 1) {
encoder.dispatch([in_ptr = in.data<T>(),
out_ptr = out.data<U>(),
count = in.size() / in.shape(axis),
stride = in.shape(axis),
reverse,
inclusive,
op = std::move(op),
init]() {
contiguous_scan(
in_ptr, out_ptr, count, stride, reverse, inclusive, op, init);
});
} else {
strided_scan(
input.data<T>(),
output.data<U>(),
input.size() / input.shape(axis) / input.strides()[axis],
input.shape(axis),
input.strides()[axis],
reverse,
inclusive,
op,
init);
encoder.dispatch([in_ptr = in.data<T>(),
out_ptr = out.data<U>(),
count = in.size() / in.shape(axis) / in.strides()[axis],
size = in.shape(axis),
stride = in.strides()[axis],
reverse,
inclusive,
op = std::move(op),
init]() {
strided_scan(
in_ptr, out_ptr, count, size, stride, reverse, inclusive, op, init);
});
}
} else {
throw std::runtime_error("Scan op supports only contiguous inputs");
@@ -193,38 +201,39 @@ void scan_op(
template <typename T, typename U>
void scan_dispatch(
Scan::ReduceType rtype,
const array& input,
array& output,
const array& in,
array& out,
int axis,
bool reverse,
bool inclusive) {
bool inclusive,
Stream stream) {
switch (rtype) {
case Scan::Sum: {
auto op = [](U y, T x) { return y + x; };
auto init = static_cast<U>(0);
scan_op<T, U>(input, output, axis, reverse, inclusive, op, init);
scan_op<T, U>(in, out, axis, reverse, inclusive, op, init, stream);
break;
}
case Scan::Prod: {
auto op = [](U y, T x) { return y * x; };
auto init = static_cast<U>(1);
scan_op<T, U>(input, output, axis, reverse, inclusive, op, init);
scan_op<T, U>(in, out, axis, reverse, inclusive, op, init, stream);
break;
}
case Scan::Min: {
auto op = [](U y, T x) { return x < y ? x : y; };
auto init = (issubdtype(input.dtype(), floating))
auto init = (issubdtype(in.dtype(), floating))
? static_cast<U>(std::numeric_limits<float>::infinity())
: std::numeric_limits<U>::max();
scan_op<T, U>(input, output, axis, reverse, inclusive, op, init);
scan_op<T, U>(in, out, axis, reverse, inclusive, op, init, stream);
break;
}
case Scan::Max: {
auto op = [](U y, T x) { return x < y ? y : x; };
auto init = (issubdtype(input.dtype(), floating))
auto init = (issubdtype(in.dtype(), floating))
? static_cast<U>(-std::numeric_limits<float>::infinity())
: std::numeric_limits<U>::min();
scan_op<T, U>(input, output, axis, reverse, inclusive, op, init);
scan_op<T, U>(in, out, axis, reverse, inclusive, op, init, stream);
break;
}
}
@@ -237,11 +246,14 @@ void Scan::eval_cpu(const std::vector<array>& inputs, array& out) {
// Ensure contiguity
auto in = inputs[0];
bool copied = false;
if (!in.flags().row_contiguous) {
array arr_copy(in.shape(), in.dtype(), nullptr, {});
copy(in, arr_copy, CopyType::General);
copy(in, arr_copy, CopyType::General, stream());
in = arr_copy;
copied = true;
}
out.set_data(allocator::malloc_or_wait(out.nbytes()));
switch (in.dtype()) {
case bool_: {
@@ -252,65 +264,68 @@ void Scan::eval_cpu(const std::vector<array>& inputs, array& out) {
// floats perhaps we should add the full all-to-all dispatch.
if (reduce_type_ == Scan::Sum && out.dtype() == int32) {
scan_dispatch<bool, int32_t>(
reduce_type_, in, out, axis_, reverse_, inclusive_);
reduce_type_, in, out, axis_, reverse_, inclusive_, stream());
} else {
scan_dispatch<bool, bool>(
reduce_type_, in, out, axis_, reverse_, inclusive_);
reduce_type_, in, out, axis_, reverse_, inclusive_, stream());
}
break;
}
case uint8:
scan_dispatch<uint8_t, uint8_t>(
reduce_type_, in, out, axis_, reverse_, inclusive_);
reduce_type_, in, out, axis_, reverse_, inclusive_, stream());
break;
case uint16:
scan_dispatch<uint16_t, uint16_t>(
reduce_type_, in, out, axis_, reverse_, inclusive_);
reduce_type_, in, out, axis_, reverse_, inclusive_, stream());
break;
case uint32:
scan_dispatch<uint32_t, uint32_t>(
reduce_type_, in, out, axis_, reverse_, inclusive_);
reduce_type_, in, out, axis_, reverse_, inclusive_, stream());
break;
case uint64:
scan_dispatch<uint64_t, uint64_t>(
reduce_type_, in, out, axis_, reverse_, inclusive_);
reduce_type_, in, out, axis_, reverse_, inclusive_, stream());
break;
case int8:
scan_dispatch<int8_t, int8_t>(
reduce_type_, in, out, axis_, reverse_, inclusive_);
reduce_type_, in, out, axis_, reverse_, inclusive_, stream());
break;
case int16:
scan_dispatch<int16_t, int16_t>(
reduce_type_, in, out, axis_, reverse_, inclusive_);
reduce_type_, in, out, axis_, reverse_, inclusive_, stream());
break;
case int32:
scan_dispatch<int32_t, int32_t>(
reduce_type_, in, out, axis_, reverse_, inclusive_);
reduce_type_, in, out, axis_, reverse_, inclusive_, stream());
break;
case int64:
scan_dispatch<int64_t, int64_t>(
reduce_type_, in, out, axis_, reverse_, inclusive_);
reduce_type_, in, out, axis_, reverse_, inclusive_, stream());
break;
case float16:
scan_dispatch<float16_t, float16_t>(
reduce_type_, in, out, axis_, reverse_, inclusive_);
reduce_type_, in, out, axis_, reverse_, inclusive_, stream());
break;
case float32:
scan_dispatch<float, float>(
reduce_type_, in, out, axis_, reverse_, inclusive_);
reduce_type_, in, out, axis_, reverse_, inclusive_, stream());
break;
case float64:
scan_dispatch<double, double>(
reduce_type_, in, out, axis_, reverse_, inclusive_);
reduce_type_, in, out, axis_, reverse_, inclusive_, stream());
break;
case bfloat16:
scan_dispatch<bfloat16_t, bfloat16_t>(
reduce_type_, in, out, axis_, reverse_, inclusive_);
reduce_type_, in, out, axis_, reverse_, inclusive_, stream());
break;
case complex64:
throw std::runtime_error("Scan ops do not support complex types yet");
break;
}
if (copied) {
cpu::get_command_encoder(stream()).add_temporary(std::move(in));
}
}
} // namespace mlx::core
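
For reference, what the reverse/inclusive flags threaded through scan_op mean on a single row (op = +, init = 0); a scalar sketch:

// in:                 [1, 2, 3, 4]
// inclusive, forward: [1, 3, 6, 10]
// exclusive, forward: [0, 1, 3, 6]
// inclusive, reverse: [10, 9, 7, 4]
void scan_row(const int* in, int* out, int n, bool reverse, bool inclusive) {
  int acc = 0;
  for (int k = 0; k < n; ++k) {
    int i = reverse ? n - 1 - k : k;
    if (inclusive) {
      acc += in[i]; // accumulate first, then write
      out[i] = acc;
    } else {
      out[i] = acc; // write the running total, then accumulate
      acc += in[i];
    }
  }
}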

View File

@@ -4,6 +4,7 @@
#include <cmath>
#include "mlx/backend/cpu/copy.h"
#include "mlx/backend/cpu/encoder.h"
#include "mlx/backend/cpu/simd/simd.h"
#include "mlx/primitives.h"
#include "mlx/types/limits.h"
@@ -15,92 +16,100 @@ namespace {
using namespace mlx::core::simd;
template <typename T, typename AccT>
void softmax(const array& in, array& out) {
constexpr bool same_t = std::is_same_v<T, AccT>;
constexpr int N = std::min(max_size<AccT>, max_size<T>);
void softmax(const array& in, array& out, Stream stream) {
auto& encoder = cpu::get_command_encoder(stream);
encoder.set_input_array(in);
encoder.set_output_array(out);
const T* in_ptr = in.data<T>();
T* out_ptr = out.data<T>();
int M = in.shape().back();
int L = in.data_size() / M;
const T* current_in_ptr;
T* current_out_ptr;
for (int i = 0; i < L; i++, in_ptr += M, out_ptr += M) {
// Find the maximum
current_in_ptr = in_ptr;
Simd<AccT, N> vmaximum(-numeric_limits<AccT>::infinity());
size_t s = M;
while (s >= N) {
Simd<AccT, N> vals = load<T, N>(current_in_ptr);
vmaximum = maximum(vals, vmaximum);
current_in_ptr += N;
s -= N;
}
encoder.dispatch([in_ptr, out_ptr, M, L]() mutable {
constexpr bool same_t = std::is_same_v<T, AccT>;
constexpr int N = std::min(max_size<AccT>, max_size<T>);
AccT maximum = max(vmaximum);
while (s-- > 0) {
maximum = std::max(maximum, static_cast<AccT>(*current_in_ptr));
current_in_ptr++;
}
const T* current_in_ptr;
T* current_out_ptr;
// Compute the normalizer and the exponentials
Simd<AccT, N> vnormalizer(0.0);
current_out_ptr = out_ptr;
current_in_ptr = in_ptr;
s = M;
while (s >= N) {
Simd<AccT, N> vexp = load<T, N>(current_in_ptr);
vexp = exp(vexp - maximum);
if constexpr (same_t) {
store(current_out_ptr, vexp);
}
vnormalizer = vnormalizer + vexp;
current_in_ptr += N;
current_out_ptr += N;
s -= N;
}
AccT normalizer = sum(vnormalizer);
while (s-- > 0) {
AccT _exp = std::exp(*current_in_ptr - maximum);
if constexpr (same_t) {
*current_out_ptr = _exp;
}
normalizer += _exp;
current_in_ptr++;
current_out_ptr++;
}
normalizer = 1 / normalizer;
// Normalize
current_out_ptr = out_ptr;
current_in_ptr = in_ptr;
s = M;
while (s >= N) {
if constexpr (same_t) {
store(
current_out_ptr,
Simd<T, N>(load<T, N>(current_out_ptr) * normalizer));
} else {
Simd<AccT, N> vexp = load<T, N>(current_in_ptr);
vexp = exp(vexp - maximum) * normalizer;
store(current_out_ptr, Simd<T, N>(vexp));
for (int i = 0; i < L; i++, in_ptr += M, out_ptr += M) {
// Find the maximum
current_in_ptr = in_ptr;
Simd<AccT, N> vmaximum(-numeric_limits<AccT>::infinity());
size_t s = M;
while (s >= N) {
Simd<AccT, N> vals = load<T, N>(current_in_ptr);
vmaximum = maximum(vals, vmaximum);
current_in_ptr += N;
s -= N;
}
current_out_ptr += N;
s -= N;
}
while (s-- > 0) {
if constexpr (same_t) {
*current_out_ptr *= normalizer;
} else {
AccT _exp = std::exp(*current_in_ptr - maximum);
*current_out_ptr = static_cast<T>(_exp * normalizer);
AccT maximum = max(vmaximum);
while (s-- > 0) {
maximum = std::max(maximum, static_cast<AccT>(*current_in_ptr));
current_in_ptr++;
}
current_out_ptr++;
// Compute the normalizer and the exponentials
Simd<AccT, N> vnormalizer(0.0);
current_out_ptr = out_ptr;
current_in_ptr = in_ptr;
s = M;
while (s >= N) {
Simd<AccT, N> vexp = load<T, N>(current_in_ptr);
vexp = exp(vexp - maximum);
if constexpr (same_t) {
store(current_out_ptr, vexp);
}
vnormalizer = vnormalizer + vexp;
current_in_ptr += N;
current_out_ptr += N;
s -= N;
}
AccT normalizer = sum(vnormalizer);
while (s-- > 0) {
AccT _exp = std::exp(*current_in_ptr - maximum);
if constexpr (same_t) {
*current_out_ptr = _exp;
}
normalizer += _exp;
current_in_ptr++;
current_out_ptr++;
}
normalizer = 1 / normalizer;
// Normalize
current_out_ptr = out_ptr;
current_in_ptr = in_ptr;
s = M;
while (s >= N) {
if constexpr (same_t) {
store(
current_out_ptr,
Simd<T, N>(load<T, N>(current_out_ptr) * normalizer));
} else {
Simd<AccT, N> vexp = load<T, N>(current_in_ptr);
vexp = exp(vexp - maximum) * normalizer;
store(current_out_ptr, Simd<T, N>(vexp));
current_in_ptr += N;
}
current_out_ptr += N;
s -= N;
}
while (s-- > 0) {
if constexpr (same_t) {
*current_out_ptr *= normalizer;
} else {
AccT _exp = std::exp(*current_in_ptr - maximum);
*current_out_ptr = static_cast<T>(_exp * normalizer);
current_in_ptr++;
}
current_out_ptr++;
}
}
}
});
}
} // namespace
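Note the `mutable` on the dispatch lambda above: `in_ptr` and `out_ptr` are captured by value and then advanced as the row loop runs, and by-value captures are const unless the lambda is mutable. A self-contained illustration of the same idiom (fill_rows is a hypothetical example, not MLX code):

#include <cstddef>

// The lambda owns copies of out/rows/cols; `mutable` lets the loop advance
// the pointer copy without touching the caller's variable.
void fill_rows(float* out, std::size_t rows, std::size_t cols) {
  auto task = [out, rows, cols]() mutable {
    for (std::size_t r = 0; r < rows; ++r, out += cols) {
      for (std::size_t c = 0; c < cols; ++c) {
        out[c] = 0.0f;
      }
    }
  };
  task();  // an encoder would run this later on the stream's thread
}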
@@ -109,30 +118,32 @@ void Softmax::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
// Make sure that the last dimension is contiguous
auto check_input = [](array x) {
auto set_output = [s = stream(), &out](const array& x) {
bool no_copy = x.strides()[x.ndim() - 1] == 1;
if (x.ndim() > 1) {
auto s = x.strides()[x.ndim() - 2];
no_copy &= (s == 0 || s == x.shape().back());
}
if (no_copy) {
if (x.is_donatable()) {
out.copy_shared_buffer(x);
} else {
out.set_data(
allocator::malloc_or_wait(x.data_size() * x.itemsize()),
x.data_size(),
x.strides(),
x.flags());
}
return x;
} else {
array x_copy(x.shape(), x.dtype(), nullptr, {});
copy(x, x_copy, CopyType::General);
copy(x, x_copy, CopyType::General, s);
out.copy_shared_buffer(x_copy);
return x_copy;
}
};
array in = check_input(std::move(inputs[0]));
if (in.is_donatable()) {
out.copy_shared_buffer(in);
} else {
out.set_data(
allocator::malloc_or_wait(in.data_size() * in.itemsize()),
in.data_size(),
in.strides(),
in.flags());
}
auto in = set_output(inputs[0]);
switch (in.dtype()) {
case bool_:
@@ -148,24 +159,24 @@ void Softmax::eval_cpu(const std::vector<array>& inputs, array& out) {
"Softmax is defined only for floating point types");
break;
case float32:
softmax<float, float>(in, out);
softmax<float, float>(in, out, stream());
break;
case float16:
if (precise_) {
softmax<float16_t, float>(in, out);
softmax<float16_t, float>(in, out, stream());
} else {
softmax<float16_t, float16_t>(in, out);
softmax<float16_t, float16_t>(in, out, stream());
}
break;
case bfloat16:
if (precise_) {
softmax<bfloat16_t, float>(in, out);
softmax<bfloat16_t, float>(in, out, stream());
} else {
softmax<bfloat16_t, bfloat16_t>(in, out);
softmax<bfloat16_t, bfloat16_t>(in, out, stream());
}
break;
case float64:
softmax<double, double>(in, out);
softmax<double, double>(in, out, stream());
break;
case complex64:
throw std::invalid_argument(

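For reference, the SIMD kernel in this file is the vectorized form of the standard three-pass softmax: find the row maximum for numerical stability, accumulate exp(x - max), then scale by the reciprocal of the sum. A scalar sketch of one row in plain C++, with no MLX types assumed:

#include <algorithm>
#include <cmath>
#include <cstddef>
#include <vector>

std::vector<float> softmax_row(const std::vector<float>& x) {
  // Pass 1: row maximum (assumes x is non-empty, as M >= 1 in the kernel).
  float m = *std::max_element(x.begin(), x.end());
  std::vector<float> y(x.size());
  float normalizer = 0.0f;
  // Pass 2: exponentials and their running sum.
  for (std::size_t i = 0; i < x.size(); ++i) {
    y[i] = std::exp(x[i] - m);
    normalizer += y[i];
  }
  // Pass 3: normalize.
  for (auto& v : y) {
    v /= normalizer;
  }
  return y;
}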
View File

@@ -7,6 +7,7 @@
#include "mlx/backend/common/utils.h"
#include "mlx/backend/cpu/copy.h"
#include "mlx/backend/cpu/encoder.h"
#include "mlx/primitives.h"
@@ -103,11 +104,11 @@ struct StridedIterator {
T* ptr_;
};
template <typename T, typename IdxT = uint32_t>
void sort(const array& in, array& out, int axis) {
template <typename T>
void sort(const array& in, array& out, int axis, Stream stream) {
// Copy input to output
CopyType ctype = in.flags().contiguous ? CopyType::Vector : CopyType::General;
copy(in, out, ctype);
copy(in, out, ctype, stream);
// Get axis, shape and stride info
axis = axis < 0 ? axis + in.ndim() : axis;
@@ -126,19 +127,27 @@ void sort(const array& in, array& out, int axis) {
// Perform sorting in place
ContiguousIterator src_it(
remaining_shape, remaining_strides, remaining_shape.size());
for (int i = 0; i < n_rows; i++) {
T* data_ptr = out.data<T>() + src_it.loc;
auto& encoder = cpu::get_command_encoder(stream);
encoder.set_output_array(out);
encoder.dispatch([out_ptr = out.data<T>(),
src_it = std::move(src_it),
n_rows,
axis_size,
axis_stride]() mutable {
for (int i = 0; i < n_rows; i++) {
T* data_ptr = out_ptr + src_it.loc;
StridedIterator st(data_ptr, axis_stride, 0);
StridedIterator ed(data_ptr, axis_stride, axis_size);
StridedIterator st(data_ptr, axis_stride, 0);
StridedIterator ed(data_ptr, axis_stride, axis_size);
std::stable_sort(st, ed);
src_it.step();
}
std::stable_sort(st, ed);
src_it.step();
}
});
}
template <typename T, typename IdxT = uint32_t>
void argsort(const array& in, array& out, int axis) {
void argsort(const array& in, array& out, int axis, Stream stream) {
// Allocate output
out.set_data(allocator::malloc_or_wait(out.nbytes()));
@@ -167,35 +176,48 @@ void argsort(const array& in, array& out, int axis) {
in_remaining_shape, in_remaining_strides, in_remaining_shape.size());
ContiguousIterator out_it(
out_remaining_shape, out_remaining_strides, out_remaining_shape.size());
for (int i = 0; i < n_rows; i++) {
const T* data_ptr = in.data<T>() + in_it.loc;
IdxT* idx_ptr = out.data<IdxT>() + out_it.loc;
in_it.step();
out_it.step();
auto& encoder = cpu::get_command_encoder(stream);
encoder.set_input_array(in);
  encoder.set_output_array(out);
encoder.dispatch([in_ptr = in.data<T>(),
out_ptr = out.data<IdxT>(),
in_it = std::move(in_it),
out_it = std::move(out_it),
n_rows,
axis_size,
in_stride,
out_stride]() mutable {
for (int i = 0; i < n_rows; i++) {
const T* data_ptr = in_ptr + in_it.loc;
IdxT* idx_ptr = out_ptr + out_it.loc;
StridedIterator st_(idx_ptr, out_stride, 0);
StridedIterator ed_(idx_ptr, out_stride, axis_size);
in_it.step();
out_it.step();
// Initialize with iota
std::iota(st_, ed_, IdxT(0));
StridedIterator st_(idx_ptr, out_stride, 0);
StridedIterator ed_(idx_ptr, out_stride, axis_size);
// Sort according to vals
StridedIterator st(idx_ptr, out_stride, 0);
StridedIterator ed(idx_ptr, out_stride, axis_size);
// Initialize with iota
std::iota(st_, ed_, IdxT(0));
std::stable_sort(st, ed, [data_ptr, in_stride](IdxT a, IdxT b) {
auto v1 = data_ptr[a * in_stride];
auto v2 = data_ptr[b * in_stride];
return v1 < v2 || (v1 == v2 && a < b);
});
}
// Sort according to vals
StridedIterator st(idx_ptr, out_stride, 0);
StridedIterator ed(idx_ptr, out_stride, axis_size);
std::stable_sort(st, ed, [data_ptr, in_stride](IdxT a, IdxT b) {
auto v1 = data_ptr[a * in_stride];
auto v2 = data_ptr[b * in_stride];
return v1 < v2 || (v1 == v2 && a < b);
});
}
});
}
template <typename T, typename IdxT = uint32_t>
void partition(const array& in, array& out, int axis, int kth) {
template <typename T>
void partition(const array& in, array& out, int axis, int kth, Stream stream) {
// Copy input to output
CopyType ctype = in.flags().contiguous ? CopyType::Vector : CopyType::General;
copy(in, out, ctype);
copy(in, out, ctype, stream);
// Get axis, shape and stride info
axis = axis < 0 ? axis + in.ndim() : axis;
@@ -216,20 +238,34 @@ void partition(const array& in, array& out, int axis, int kth) {
// Perform partition in place
ContiguousIterator src_it(
remaining_shape, remaining_strides, remaining_shape.size());
for (int i = 0; i < n_rows; i++) {
T* data_ptr = out.data<T>() + src_it.loc;
src_it.step();
auto& encoder = cpu::get_command_encoder(stream);
encoder.set_output_array(out);
encoder.dispatch([out_ptr = out.data<T>(),
src_it = std::move(src_it),
n_rows,
axis_size,
axis_stride,
kth]() mutable {
for (int i = 0; i < n_rows; i++) {
T* data_ptr = out_ptr + src_it.loc;
src_it.step();
StridedIterator st(data_ptr, axis_stride, 0);
StridedIterator md(data_ptr, axis_stride, kth);
StridedIterator ed(data_ptr, axis_stride, axis_size);
StridedIterator st(data_ptr, axis_stride, 0);
StridedIterator md(data_ptr, axis_stride, kth);
StridedIterator ed(data_ptr, axis_stride, axis_size);
std::nth_element(st, md, ed);
}
std::nth_element(st, md, ed);
}
});
}
template <typename T, typename IdxT = uint32_t>
void argpartition(const array& in, array& out, int axis, int kth) {
void argpartition(
const array& in,
array& out,
int axis,
int kth,
Stream stream) {
// Allocate output
out.set_data(allocator::malloc_or_wait(out.nbytes()));
@@ -260,29 +296,43 @@ void argpartition(const array& in, array& out, int axis, int kth) {
in_remaining_shape, in_remaining_strides, in_remaining_shape.size());
ContiguousIterator out_it(
out_remaining_shape, out_remaining_strides, out_remaining_shape.size());
for (int i = 0; i < n_rows; i++) {
const T* data_ptr = in.data<T>() + in_it.loc;
IdxT* idx_ptr = out.data<IdxT>() + out_it.loc;
in_it.step();
out_it.step();
StridedIterator st_(idx_ptr, out_stride, 0);
StridedIterator ed_(idx_ptr, out_stride, axis_size);
auto& encoder = cpu::get_command_encoder(stream);
encoder.set_input_array(in);
  encoder.set_output_array(out);
encoder.dispatch([in_ptr = in.data<T>(),
out_ptr = out.data<IdxT>(),
in_it = std::move(in_it),
out_it = std::move(out_it),
n_rows,
axis_size,
in_stride,
out_stride,
kth]() mutable {
for (int i = 0; i < n_rows; i++) {
const T* data_ptr = in_ptr + in_it.loc;
IdxT* idx_ptr = out_ptr + out_it.loc;
in_it.step();
out_it.step();
// Initialize with iota
std::iota(st_, ed_, IdxT(0));
StridedIterator st_(idx_ptr, out_stride, 0);
StridedIterator ed_(idx_ptr, out_stride, axis_size);
// Sort according to vals
StridedIterator st(idx_ptr, out_stride, 0);
StridedIterator md(idx_ptr, out_stride, kth);
StridedIterator ed(idx_ptr, out_stride, axis_size);
// Initialize with iota
std::iota(st_, ed_, IdxT(0));
std::nth_element(st, md, ed, [data_ptr, in_stride](IdxT a, IdxT b) {
auto v1 = data_ptr[a * in_stride];
auto v2 = data_ptr[b * in_stride];
return v1 < v2 || (v1 == v2 && a < b);
});
}
// Sort according to vals
StridedIterator st(idx_ptr, out_stride, 0);
StridedIterator md(idx_ptr, out_stride, kth);
StridedIterator ed(idx_ptr, out_stride, axis_size);
std::nth_element(st, md, ed, [data_ptr, in_stride](IdxT a, IdxT b) {
auto v1 = data_ptr[a * in_stride];
auto v2 = data_ptr[b * in_stride];
return v1 < v2 || (v1 == v2 && a < b);
});
}
});
}
} // namespace
@@ -293,33 +343,33 @@ void ArgSort::eval_cpu(const std::vector<array>& inputs, array& out) {
switch (in.dtype()) {
case bool_:
return argsort<bool>(in, out, axis_);
return argsort<bool>(in, out, axis_, stream());
case uint8:
return argsort<uint8_t>(in, out, axis_);
return argsort<uint8_t>(in, out, axis_, stream());
case uint16:
return argsort<uint16_t>(in, out, axis_);
return argsort<uint16_t>(in, out, axis_, stream());
case uint32:
return argsort<uint32_t>(in, out, axis_);
return argsort<uint32_t>(in, out, axis_, stream());
case uint64:
return argsort<uint64_t>(in, out, axis_);
return argsort<uint64_t>(in, out, axis_, stream());
case int8:
return argsort<int8_t>(in, out, axis_);
return argsort<int8_t>(in, out, axis_, stream());
case int16:
return argsort<int16_t>(in, out, axis_);
return argsort<int16_t>(in, out, axis_, stream());
case int32:
return argsort<int32_t>(in, out, axis_);
return argsort<int32_t>(in, out, axis_, stream());
case int64:
return argsort<int64_t>(in, out, axis_);
return argsort<int64_t>(in, out, axis_, stream());
case float32:
return argsort<float>(in, out, axis_);
return argsort<float>(in, out, axis_, stream());
case float64:
return argsort<double>(in, out, axis_);
return argsort<double>(in, out, axis_, stream());
case float16:
return argsort<float16_t>(in, out, axis_);
return argsort<float16_t>(in, out, axis_, stream());
case bfloat16:
return argsort<bfloat16_t>(in, out, axis_);
return argsort<bfloat16_t>(in, out, axis_, stream());
case complex64:
return argsort<complex64_t>(in, out, axis_);
return argsort<complex64_t>(in, out, axis_, stream());
}
}
@@ -329,33 +379,33 @@ void Sort::eval_cpu(const std::vector<array>& inputs, array& out) {
switch (in.dtype()) {
case bool_:
return sort<bool>(in, out, axis_);
return sort<bool>(in, out, axis_, stream());
case uint8:
return sort<uint8_t>(in, out, axis_);
return sort<uint8_t>(in, out, axis_, stream());
case uint16:
return sort<uint16_t>(in, out, axis_);
return sort<uint16_t>(in, out, axis_, stream());
case uint32:
return sort<uint32_t>(in, out, axis_);
return sort<uint32_t>(in, out, axis_, stream());
case uint64:
return sort<uint64_t>(in, out, axis_);
return sort<uint64_t>(in, out, axis_, stream());
case int8:
return sort<int8_t>(in, out, axis_);
return sort<int8_t>(in, out, axis_, stream());
case int16:
return sort<int16_t>(in, out, axis_);
return sort<int16_t>(in, out, axis_, stream());
case int32:
return sort<int32_t>(in, out, axis_);
return sort<int32_t>(in, out, axis_, stream());
case int64:
return sort<int64_t>(in, out, axis_);
return sort<int64_t>(in, out, axis_, stream());
case float32:
return sort<float>(in, out, axis_);
return sort<float>(in, out, axis_, stream());
case float64:
return sort<double>(in, out, axis_);
return sort<double>(in, out, axis_, stream());
case float16:
return sort<float16_t>(in, out, axis_);
return sort<float16_t>(in, out, axis_, stream());
case bfloat16:
return sort<bfloat16_t>(in, out, axis_);
return sort<bfloat16_t>(in, out, axis_, stream());
case complex64:
return sort<complex64_t>(in, out, axis_);
return sort<complex64_t>(in, out, axis_, stream());
}
}
@@ -365,33 +415,33 @@ void ArgPartition::eval_cpu(const std::vector<array>& inputs, array& out) {
switch (in.dtype()) {
case bool_:
return argpartition<bool>(in, out, axis_, kth_);
return argpartition<bool>(in, out, axis_, kth_, stream());
case uint8:
return argpartition<uint8_t>(in, out, axis_, kth_);
return argpartition<uint8_t>(in, out, axis_, kth_, stream());
case uint16:
return argpartition<uint16_t>(in, out, axis_, kth_);
return argpartition<uint16_t>(in, out, axis_, kth_, stream());
case uint32:
return argpartition<uint32_t>(in, out, axis_, kth_);
return argpartition<uint32_t>(in, out, axis_, kth_, stream());
case uint64:
return argpartition<uint64_t>(in, out, axis_, kth_);
return argpartition<uint64_t>(in, out, axis_, kth_, stream());
case int8:
return argpartition<int8_t>(in, out, axis_, kth_);
return argpartition<int8_t>(in, out, axis_, kth_, stream());
case int16:
return argpartition<int16_t>(in, out, axis_, kth_);
return argpartition<int16_t>(in, out, axis_, kth_, stream());
case int32:
return argpartition<int32_t>(in, out, axis_, kth_);
return argpartition<int32_t>(in, out, axis_, kth_, stream());
case int64:
return argpartition<int64_t>(in, out, axis_, kth_);
return argpartition<int64_t>(in, out, axis_, kth_, stream());
case float32:
return argpartition<float>(in, out, axis_, kth_);
return argpartition<float>(in, out, axis_, kth_, stream());
case float64:
return argpartition<double>(in, out, axis_, kth_);
return argpartition<double>(in, out, axis_, kth_, stream());
case float16:
return argpartition<float16_t>(in, out, axis_, kth_);
return argpartition<float16_t>(in, out, axis_, kth_, stream());
case bfloat16:
return argpartition<bfloat16_t>(in, out, axis_, kth_);
return argpartition<bfloat16_t>(in, out, axis_, kth_, stream());
case complex64:
return argpartition<complex64_t>(in, out, axis_, kth_);
return argpartition<complex64_t>(in, out, axis_, kth_, stream());
}
}
@@ -401,33 +451,33 @@ void Partition::eval_cpu(const std::vector<array>& inputs, array& out) {
switch (in.dtype()) {
case bool_:
return partition<bool>(in, out, axis_, kth_);
return partition<bool>(in, out, axis_, kth_, stream());
case uint8:
return partition<uint8_t>(in, out, axis_, kth_);
return partition<uint8_t>(in, out, axis_, kth_, stream());
case uint16:
return partition<uint16_t>(in, out, axis_, kth_);
return partition<uint16_t>(in, out, axis_, kth_, stream());
case uint32:
return partition<uint32_t>(in, out, axis_, kth_);
return partition<uint32_t>(in, out, axis_, kth_, stream());
case uint64:
return partition<uint64_t>(in, out, axis_, kth_);
return partition<uint64_t>(in, out, axis_, kth_, stream());
case int8:
return partition<int8_t>(in, out, axis_, kth_);
return partition<int8_t>(in, out, axis_, kth_, stream());
case int16:
return partition<int16_t>(in, out, axis_, kth_);
return partition<int16_t>(in, out, axis_, kth_, stream());
case int32:
return partition<int32_t>(in, out, axis_, kth_);
return partition<int32_t>(in, out, axis_, kth_, stream());
case int64:
return partition<int64_t>(in, out, axis_, kth_);
return partition<int64_t>(in, out, axis_, kth_, stream());
case float32:
return partition<float>(in, out, axis_, kth_);
return partition<float>(in, out, axis_, kth_, stream());
case float64:
return partition<double>(in, out, axis_, kth_);
return partition<double>(in, out, axis_, kth_, stream());
case float16:
return partition<float16_t>(in, out, axis_, kth_);
return partition<float16_t>(in, out, axis_, kth_, stream());
case bfloat16:
return partition<bfloat16_t>(in, out, axis_, kth_);
return partition<bfloat16_t>(in, out, axis_, kth_, stream());
case complex64:
return partition<complex64_t>(in, out, axis_, kth_);
return partition<complex64_t>(in, out, axis_, kth_, stream());
}
}
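The comparator shared by argsort and argpartition above, v1 < v2 || (v1 == v2 && a < b), breaks ties on the original index so equal keys come out in a deterministic order. A standalone 1-D version of the same index-sort pattern:

#include <algorithm>
#include <cstdint>
#include <numeric>
#include <vector>

std::vector<uint32_t> argsort_1d(const std::vector<float>& v) {
  std::vector<uint32_t> idx(v.size());
  std::iota(idx.begin(), idx.end(), 0u);  // initialize with iota
  // Sort indices by the values they point at, ties broken by index.
  std::stable_sort(idx.begin(), idx.end(), [&v](uint32_t a, uint32_t b) {
    return v[a] < v[b] || (v[a] == v[b] && a < b);
  });
  return idx;
}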

View File

@@ -2,13 +2,18 @@
#include "mlx/allocator.h"
#include "mlx/backend/cpu/copy.h"
#include "mlx/backend/cpu/encoder.h"
#include "mlx/backend/cpu/lapack.h"
#include "mlx/primitives.h"
namespace mlx::core {
template <typename T>
void svd_impl(const array& a, T* u_data, T* s_data, T* vt_data) {
void svd_impl(
const array& a,
std::vector<array>& outputs,
bool compute_uv,
Stream stream) {
// Lapack uses the column-major convention. To avoid having to transpose
// the input and then transpose the outputs, we swap the indices/sizes of the
// matrices and take advantage of the following identity (see
@@ -22,118 +27,24 @@ void svd_impl(const array& a, T* u_data, T* s_data, T* vt_data) {
const int N = a.shape(-1);
const int K = std::min(M, N);
// A of shape M x N. The leading dimension is N since lapack receives Aᵀ.
const int lda = N;
// U of shape M x M. (N x N in lapack).
const int ldu = N;
// Vᵀ of shape N x N. (M x M in lapack).
const int ldvt = M;
size_t num_matrices = a.size() / (M * N);
// lapack clobbers the input, so we have to make a copy.
array in(a.shape(), a.dtype(), nullptr, {});
copy(a, in, a.flags().row_contiguous ? CopyType::Vector : CopyType::General);
copy(
a,
in,
a.flags().row_contiguous ? CopyType::Vector : CopyType::General,
stream);
auto job_u = (u_data && vt_data) ? "V" : "N";
auto job_vt = (u_data && vt_data) ? "V" : "N";
static constexpr auto range = "A";
// Allocate outputs.
auto& encoder = cpu::get_command_encoder(stream);
encoder.set_input_array(a);
auto in_ptr = in.data<T>();
T* u_ptr;
T* s_ptr;
T* vt_ptr;
// Will contain the number of singular values after the call has returned.
int ns = 0;
T workspace_dimension = 0;
// Will contain the indices of eigenvectors that failed to converge (not used
// here but required by lapack).
auto iwork = array::Data{allocator::malloc_or_wait(sizeof(int) * 12 * K)};
static const int lwork_query = -1;
static const int ignored_int = 0;
static const T ignored_float = 0;
static T ignored_output = 0;
int info;
// Compute workspace size.
gesvdx<T>(
/* jobu = */ job_u,
/* jobvt = */ job_vt,
/* range = */ range,
// M and N are swapped since lapack expects column-major.
/* m = */ &N,
/* n = */ &M,
/* a = */ nullptr,
/* lda = */ &lda,
/* vl = */ &ignored_float,
/* vu = */ &ignored_float,
/* il = */ &ignored_int,
/* iu = */ &ignored_int,
/* ns = */ &ns,
/* s = */ nullptr,
/* u = */ nullptr,
/* ldu = */ &ldu,
/* vt = */ nullptr,
/* ldvt = */ &ldvt,
/* work = */ &workspace_dimension,
/* lwork = */ &lwork_query,
/* iwork = */ static_cast<int*>(iwork.buffer.raw_ptr()),
/* info = */ &info);
if (info != 0) {
std::stringstream ss;
ss << "[SVD::eval_cpu] workspace calculation failed with code " << info;
throw std::runtime_error(ss.str());
}
const int lwork = workspace_dimension;
auto scratch = array::Data{allocator::malloc_or_wait(sizeof(T) * lwork)};
// Loop over matrices.
for (int i = 0; i < num_matrices; i++) {
gesvdx<T>(
/* jobu = */ job_u,
/* jobvt = */ job_vt,
/* range = */ range,
// M and N are swapped since lapack expects column-major.
/* m = */ &N,
/* n = */ &M,
/* a = */ in.data<T>() + M * N * i,
/* lda = */ &lda,
/* vl = */ &ignored_float,
/* vu = */ &ignored_float,
/* il = */ &ignored_int,
/* iu = */ &ignored_int,
/* ns = */ &ns,
/* s = */ s_data + K * i,
// According to the identity above, lapack will write Vᵀᵀ as U.
/* u = */ vt_data ? vt_data + N * N * i : &ignored_output,
/* ldu = */ &ldu,
// According to the identity above, lapack will write Uᵀ as Vᵀ.
/* vt = */ u_data ? u_data + M * M * i : &ignored_output,
/* ldvt = */ &ldvt,
/* work = */ static_cast<T*>(scratch.buffer.raw_ptr()),
/* lwork = */ &lwork,
/* iwork = */ static_cast<int*>(iwork.buffer.raw_ptr()),
/* info = */ &info);
if (info != 0) {
std::stringstream ss;
ss << "[SVD::eval_cpu] failed with code " << info;
throw std::runtime_error(ss.str());
}
if (ns != K) {
std::stringstream ss;
ss << "[SVD::eval_cpu] expected " << K << " singular values, but " << ns
<< " were computed.";
throw std::runtime_error(ss.str());
}
}
}
template <typename T>
void compute_svd(const array& a, bool compute_uv, std::vector<array>& outputs) {
if (compute_uv) {
array& u = outputs[0];
array& s = outputs[1];
@@ -143,25 +54,147 @@ void compute_svd(const array& a, bool compute_uv, std::vector<array>& outputs) {
s.set_data(allocator::malloc_or_wait(s.nbytes()));
vt.set_data(allocator::malloc_or_wait(vt.nbytes()));
svd_impl<T>(a, u.data<T>(), s.data<T>(), vt.data<T>());
encoder.set_output_array(u);
encoder.set_output_array(s);
encoder.set_output_array(vt);
s_ptr = s.data<T>();
u_ptr = u.data<T>();
vt_ptr = vt.data<T>();
} else {
array& s = outputs[0];
s.set_data(allocator::malloc_or_wait(s.nbytes()));
svd_impl<T>(a, nullptr, s.data<T>(), nullptr);
encoder.set_output_array(s);
s_ptr = s.data<T>();
u_ptr = nullptr;
vt_ptr = nullptr;
}
encoder.dispatch([in_ptr, u_ptr, s_ptr, vt_ptr, M, N, K, num_matrices]() {
// A of shape M x N. The leading dimension is N since lapack receives Aᵀ.
const int lda = N;
// U of shape M x M. (N x N in lapack).
const int ldu = N;
// Vᵀ of shape N x N. (M x M in lapack).
const int ldvt = M;
auto job_u = (u_ptr) ? "V" : "N";
auto job_vt = (u_ptr) ? "V" : "N";
static constexpr auto range = "A";
// Will contain the number of singular values after the call has returned.
int ns = 0;
T workspace_dimension = 0;
// Will contain the indices of eigenvectors that failed to converge (not
// used here but required by lapack).
auto iwork = array::Data{allocator::malloc_or_wait(sizeof(int) * 12 * K)};
static const int lwork_query = -1;
static const int ignored_int = 0;
static const T ignored_float = 0;
int info;
// Compute workspace size.
gesvdx<T>(
/* jobu = */ job_u,
/* jobvt = */ job_vt,
/* range = */ range,
// M and N are swapped since lapack expects column-major.
/* m = */ &N,
/* n = */ &M,
/* a = */ nullptr,
/* lda = */ &lda,
/* vl = */ &ignored_float,
/* vu = */ &ignored_float,
/* il = */ &ignored_int,
/* iu = */ &ignored_int,
/* ns = */ &ns,
/* s = */ nullptr,
/* u = */ nullptr,
/* ldu = */ &ldu,
/* vt = */ nullptr,
/* ldvt = */ &ldvt,
/* work = */ &workspace_dimension,
/* lwork = */ &lwork_query,
/* iwork = */ static_cast<int*>(iwork.buffer.raw_ptr()),
/* info = */ &info);
if (info != 0) {
std::stringstream ss;
ss << "[SVD::eval_cpu] workspace calculation failed with code " << info;
throw std::runtime_error(ss.str());
}
const int lwork = workspace_dimension;
auto scratch = array::Data{allocator::malloc_or_wait(sizeof(T) * lwork)};
// Loop over matrices.
for (int i = 0; i < num_matrices; i++) {
gesvdx<T>(
/* jobu = */ job_u,
/* jobvt = */ job_vt,
/* range = */ range,
// M and N are swapped since lapack expects column-major.
/* m = */ &N,
/* n = */ &M,
/* a = */ in_ptr + M * N * i,
/* lda = */ &lda,
/* vl = */ &ignored_float,
/* vu = */ &ignored_float,
/* il = */ &ignored_int,
/* iu = */ &ignored_int,
/* ns = */ &ns,
/* s = */ s_ptr + K * i,
// According to the identity above, lapack will write Vᵀᵀ as U.
/* u = */ vt_ptr ? vt_ptr + N * N * i : nullptr,
/* ldu = */ &ldu,
// According to the identity above, lapack will write Uᵀ as Vᵀ.
/* vt = */ u_ptr ? u_ptr + M * M * i : nullptr,
/* ldvt = */ &ldvt,
/* work = */ static_cast<T*>(scratch.buffer.raw_ptr()),
/* lwork = */ &lwork,
/* iwork = */ static_cast<int*>(iwork.buffer.raw_ptr()),
/* info = */ &info);
if (info != 0) {
std::stringstream ss;
ss << "svd_impl: sgesvdx_ failed with code " << info;
throw std::runtime_error(ss.str());
}
if (ns != K) {
std::stringstream ss;
ss << "svd_impl: expected " << K << " singular values, but " << ns
<< " were computed.";
throw std::runtime_error(ss.str());
}
}
});
encoder.add_temporary(in);
}
template <typename T>
void compute_svd(
const array& a,
bool compute_uv,
std::vector<array>& outputs,
Stream stream) {}
void SVD::eval_cpu(
const std::vector<array>& inputs,
std::vector<array>& outputs) {
switch (inputs[0].dtype()) {
case float32:
compute_svd<float>(inputs[0], compute_uv_, outputs);
svd_impl<float>(inputs[0], outputs, compute_uv_, stream());
break;
case float64:
compute_svd<double>(inputs[0], compute_uv_, outputs);
svd_impl<double>(inputs[0], outputs, compute_uv_, stream());
break;
default:
throw std::runtime_error(

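The index swapping in the gesvdx calls above follows from a single identity: handing the row-major M x N buffer to column-major LAPACK unchanged means LAPACK sees Aᵀ, and

  A = U Σ Vᵀ  ⟹  Aᵀ = V Σ Uᵀ

so the factor LAPACK writes through its u argument is V (which reads back row-major as Vᵀ, hence vt_ptr is passed there), and the factor written through vt is Uᵀ (which reads back as U, hence u_ptr); no transposes are needed on input or output. The first gesvdx call is LAPACK's standard workspace query: with lwork = -1 it only reports the required scratch size in workspace_dimension, and the second call performs the actual decomposition with a buffer of that size.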
View File

@@ -5,6 +5,8 @@
#include "mlx/array.h"
#include "mlx/backend/common/ternary.h"
#include "mlx/backend/common/utils.h"
#include "mlx/backend/cpu/encoder.h"
#include "mlx/primitives.h"
namespace mlx::core {
@@ -53,22 +55,18 @@ void ternary_op_dims(
template <typename T1, typename T2, typename T3, typename U, typename Op>
void ternary_op_dispatch_dims(
const array& a,
const array& b,
const array& c,
array& out,
Op op) {
auto [shape, strides] = collapse_contiguous_dims(
a.shape(), {a.strides(), b.strides(), c.strides(), out.strides()});
const T1* a_ptr,
const T2* b_ptr,
const T3* c_ptr,
U* out_ptr,
Op op,
size_t size,
Shape& shape,
std::vector<Strides>& strides) {
const auto& a_strides = strides[0];
const auto& b_strides = strides[1];
const auto& c_strides = strides[2];
const auto& out_strides = strides[3];
const T1* a_ptr = a.data<T1>();
const T2* b_ptr = b.data<T2>();
const T3* c_ptr = c.data<T3>();
U* out_ptr = out.data<T3>();
int ndim = shape.size();
switch (ndim) {
case 1:
@@ -105,7 +103,7 @@ void ternary_op_dispatch_dims(
ContiguousIterator b_it(shape, b_strides, ndim - 2);
ContiguousIterator c_it(shape, c_strides, ndim - 2);
auto stride = out_strides[ndim - 3];
for (size_t elem = 0; elem < a.size(); elem += stride) {
for (size_t elem = 0; elem < size; elem += stride) {
ternary_op_dims<T1, T2, T3, U, Op, 2>(
a_ptr + a_it.loc,
b_ptr + b_it.loc,
@@ -134,23 +132,53 @@ void ternary_op(
TernaryOpType topt = get_ternary_op_type(a, b, c);
set_ternary_op_output_data(a, b, c, out, topt);
// The full computation is scalar-scalar-scalar so we call the base op once.
auto& encoder = cpu::get_command_encoder(out.primitive().stream());
encoder.set_input_array(a);
encoder.set_input_array(b);
encoder.set_input_array(c);
encoder.set_output_array(out);
const T1* a_ptr = a.data<T1>();
const T2* b_ptr = b.data<T2>();
const T3* c_ptr = c.data<T3>();
U* out_ptr = out.data<U>();
if (topt == TernaryOpType::ScalarScalarScalar) {
*(out.data<U>()) = op(*a.data<T1>(), *b.data<T2>(), *c.data<T3>());
encoder.dispatch(
[a_ptr, b_ptr, c_ptr, out_ptr, op = std::move(op)]() mutable {
*out_ptr = op(*a_ptr, *b_ptr, *c_ptr);
});
} else if (topt == TernaryOpType::VectorVectorVector) {
const T1* a_ptr = a.data<T1>();
const T2* b_ptr = b.data<T2>();
const T3* c_ptr = c.data<T3>();
U* out_ptr = out.data<U>();
for (size_t i = 0; i < out.size(); ++i) {
*out_ptr = op(*a_ptr, *b_ptr, *c_ptr);
a_ptr++;
b_ptr++;
c_ptr++;
out_ptr++;
}
encoder.dispatch([a_ptr,
b_ptr,
c_ptr,
out_ptr,
op = std::move(op),
size = out.size()]() mutable {
for (size_t i = 0; i < size; ++i) {
*out_ptr = op(*a_ptr, *b_ptr, *c_ptr);
a_ptr++;
b_ptr++;
c_ptr++;
out_ptr++;
}
});
} else {
ternary_op_dispatch_dims<T1, T2, T3, U>(a, b, c, out, op);
auto [shape, strides] = collapse_contiguous_dims(
a.shape(), {a.strides(), b.strides(), c.strides(), out.strides()});
encoder.dispatch(
[a_ptr,
b_ptr,
c_ptr,
out_ptr,
op = std::move(op),
size = out.size(),
shape = std::move(shape),
strides = std::move(strides)]() mutable {
ternary_op_dispatch_dims<T1, T2, T3, U>(
a_ptr, b_ptr, c_ptr, out_ptr, op, size, shape, strides);
});
}
}
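ternary_op's three branches mirror the input layouts: a single base-op call when everything is scalar, one flat loop when all inputs are contiguous, and the strided ContiguousIterator walk otherwise. A standalone model of the contiguous fast path, shaped like the select/where op that is the typical user of this header (the predicate type and the lambda are illustrative):

#include <cstddef>

// Contiguous "vector-vector-vector" path: apply op elementwise over flat
// buffers, one load per input and one store per output element.
template <typename C, typename T, typename Op>
void ternary_vvv(const C* a, const T* b, const T* c, T* out, std::size_t n,
                 Op op) {
  for (std::size_t i = 0; i < n; ++i) {
    out[i] = op(a[i], b[i], c[i]);
  }
}

// Usage, modeling where(cond, x, y):
//   ternary_vvv(cond, x, y, out, n,
//               [](bool c, float x, float y) { return c ? x : y; });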

View File

@@ -5,67 +5,83 @@
#include "mlx/allocator.h"
#include "mlx/array.h"
#include "mlx/backend/common/utils.h"
#include "mlx/backend/cpu/encoder.h"
#include "mlx/backend/cpu/simd/simd.h"
#include "mlx/primitives.h"
#include "mlx/utils.h"
namespace mlx::core {
void set_unary_output_data(const array& in, array& out) {
if (is_donatable(in, out)) {
out.copy_shared_buffer(in);
if (in.flags().contiguous) {
if (is_donatable(in, out)) {
out.copy_shared_buffer(in);
} else {
auto size = in.data_size();
out.set_data(
allocator::malloc_or_wait(size * out.itemsize()),
size,
in.strides(),
in.flags());
}
} else {
auto size = in.data_size();
out.set_data(
allocator::malloc_or_wait(size * out.itemsize()),
size,
in.strides(),
in.flags());
out.set_data(allocator::malloc_or_wait(out.nbytes()));
}
}
template <typename T, typename U = T, typename Op>
void unary_op(const T* a, U* out, Op op, size_t shape, size_t stride) {
void unary_op(const T* a, U* out, size_t shape, size_t stride) {
for (size_t i = 0; i < shape; i += 1) {
out[i] = op(*a);
out[i] = Op{}(*a);
a += stride;
}
}
template <typename T, typename U = T, typename Op>
void unary_op(const array& a, array& out, Op op) {
const T* a_ptr = a.data<T>();
if (a.flags().contiguous) {
set_unary_output_data(a, out);
U* dst = out.data<U>();
constexpr int N = simd::max_size<T>;
size_t size = a.data_size();
while (size >= N) {
simd::store(dst, op(simd::load<T, N>(a_ptr)));
size -= N;
a_ptr += N;
dst += N;
void unary_op(const array& a, array& out, Op) {
set_unary_output_data(a, out);
const T* src = a.data<T>();
U* dst = out.data<U>();
auto& encoder = cpu::get_command_encoder(out.primitive().stream());
encoder.set_input_array(a);
encoder.set_output_array(out);
encoder.dispatch([src,
dst,
contig = a.flags().contiguous,
data_size = a.data_size(),
size = a.size(),
shapes = a.shape(),
strides = a.strides()]() mutable {
auto ndim = shapes.size();
if (contig) {
constexpr int N = simd::max_size<T>;
while (data_size >= N) {
simd::store(dst, Op{}(simd::load<T, N>(src)));
data_size -= N;
src += N;
dst += N;
}
while (data_size > 0) {
*dst = Op{}(*src);
data_size--;
dst++;
src++;
}
} else {
size_t shape = ndim > 0 ? shapes.back() : 1;
size_t stride = ndim > 0 ? strides.back() : 1;
if (ndim <= 1) {
unary_op<T, U, Op>(src, dst, shape, stride);
return;
}
auto it = ContiguousIterator(shapes, strides, ndim - 1);
for (size_t elem = 0; elem < size; elem += shape) {
unary_op<T, U, Op>(src + it.loc, dst + elem, shape, stride);
it.step();
}
}
while (size > 0) {
*dst = op(*a_ptr);
size--;
dst++;
a_ptr++;
}
} else {
out.set_data(allocator::malloc_or_wait(out.nbytes()));
U* dst = out.data<U>();
size_t shape = a.ndim() > 0 ? a.shape(-1) : 1;
size_t stride = a.ndim() > 0 ? a.strides(-1) : 1;
if (a.ndim() <= 1) {
unary_op(a_ptr, dst, op, shape, stride);
return;
}
ContiguousIterator it(a.shape(), a.strides(), a.ndim() - 1);
for (size_t elem = 0; elem < a.size(); elem += shape) {
unary_op(a_ptr + it.loc, dst + elem, op, shape, stride);
it.step();
}
}
});
}
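The contiguous branch above is the usual SIMD main loop followed by a scalar tail. A self-contained model with a fixed chunk width standing in for simd::max_size<T> (plain C++, no MLX SIMD types):

#include <cmath>
#include <cstddef>

void unary_abs(const float* src, float* dst, std::size_t n) {
  constexpr std::size_t N = 8;  // stand-in for simd::max_size<float>
  std::size_t i = 0;
  // Full lanes: the compiler can vectorize this fixed-width inner loop.
  for (; i + N <= n; i += N) {
    for (std::size_t j = 0; j < N; ++j) {
      dst[i + j] = std::fabs(src[i + j]);
    }
  }
  // Scalar tail for the remainder.
  for (; i < n; ++i) {
    dst[i] = std::fabs(src[i]);
  }
}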
template <typename Op>