mirror of https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
redesign for faster cpu/gpu synch (#1869)
* redesign for faster cpu/gpu synch
* load + more async CPU
* use command encoder API and move more ops to use it
* make fence back-end generic + CPU only fence
* faster build
* fix async eval
* fixes + handle temporaries
* fix / improve cpu conv
* remove unused status, fix siblings
* fix extensions
* fix
* fix no cpu build
* format
* comments
* fix perf regression, remove unnecessary abort
* fix events, task limit cpu
* fix waiting
* fix donation / temporaries in normalization
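In effect, eval_cpu becomes a recording step on the CPU: each op registers its inputs and outputs with a per-stream command encoder and enqueues a closure that captures raw pointers and sizes, so the scheduling thread returns immediately and the work runs later, mirroring the GPU path. The following is a minimal, self-contained sketch of that pattern; the toy CommandEncoder (a plain task queue with a run_all() drain) is illustrative only, not the actual MLX API.

// Sketch of the dispatch pattern this commit introduces: ops record work
// instead of executing it inline. Toy types, not MLX's.
#include <functional>
#include <iostream>
#include <queue>
#include <vector>

struct CommandEncoder {
  std::queue<std::function<void()>> tasks;
  // Record a task; in MLX this hands off to the stream's worker thread.
  void dispatch(std::function<void()> f) { tasks.push(std::move(f)); }
  // Drain the queue; stands in for the stream executing recorded tasks.
  void run_all() {
    while (!tasks.empty()) {
      tasks.front()();
      tasks.pop();
    }
  }
};

int main() {
  CommandEncoder encoder;
  std::vector<float> out(8);
  float start = 0.0f, step = 0.5f;
  float* ptr = out.data();
  size_t size = out.size();
  // The "op" only records work; nothing touches `out` yet.
  encoder.dispatch([ptr, start, step, size]() mutable {
    for (size_t i = 0; i < size; ++i) {
      ptr[i] = start;
      start += step;
    }
  });
  encoder.run_all(); // later, the stream executes the recorded tasks
  for (auto v : out) std::cout << v << ' ';
  std::cout << '\n';
}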
@@ -44,7 +44,9 @@ target_sources(
     ${CMAKE_CURRENT_SOURCE_DIR}/binary.cpp
+    ${CMAKE_CURRENT_SOURCE_DIR}/conv.cpp
     ${CMAKE_CURRENT_SOURCE_DIR}/copy.cpp
     ${CMAKE_CURRENT_SOURCE_DIR}/distributed.cpp
     ${CMAKE_CURRENT_SOURCE_DIR}/eigh.cpp
+    ${CMAKE_CURRENT_SOURCE_DIR}/encoder.cpp
     ${CMAKE_CURRENT_SOURCE_DIR}/fft.cpp
     ${CMAKE_CURRENT_SOURCE_DIR}/hadamard.cpp
     ${CMAKE_CURRENT_SOURCE_DIR}/matmul.cpp
@@ -65,6 +67,7 @@ target_sources(
     ${CMAKE_CURRENT_SOURCE_DIR}/inverse.cpp
     ${CMAKE_CURRENT_SOURCE_DIR}/cholesky.cpp
     ${CMAKE_CURRENT_SOURCE_DIR}/unary.cpp
+    ${CMAKE_CURRENT_SOURCE_DIR}/eval.cpp
     ${CMAKE_CURRENT_BINARY_DIR}/compiled_preamble.cpp)
 
 if(MLX_BUILD_ACCELERATE)
@@ -2,76 +2,27 @@
 
 #pragma once
 
-#include "mlx/allocator.h"
 #include "mlx/array.h"
+#include "mlx/backend/cpu/encoder.h"
 
 namespace mlx::core {
 
 namespace {
 
 template <typename T>
-void arange(T start, T next, array& out, size_t size) {
+void arange(T start, T next, array& out, size_t size, Stream stream) {
   auto ptr = out.data<T>();
   auto step_size = next - start;
-  for (int i = 0; i < size; ++i) {
-    ptr[i] = start;
-    start += step_size;
-  }
+  auto& encoder = cpu::get_command_encoder(stream);
+  encoder.set_output_array(out);
+  encoder.dispatch([ptr, start, step_size, size]() mutable {
+    for (int i = 0; i < size; ++i) {
+      ptr[i] = start;
+      start += step_size;
+    }
+  });
 }
 
 } // namespace
 
-void arange(
-    const std::vector<array>& inputs,
-    array& out,
-    double start,
-    double step) {
-  assert(inputs.size() == 0);
-  out.set_data(allocator::malloc_or_wait(out.nbytes()));
-  switch (out.dtype()) {
-    case bool_:
-      throw std::runtime_error("Bool type unsupported for arange.");
-      break;
-    case uint8:
-      arange<uint8_t>(start, start + step, out, out.size());
-      break;
-    case uint16:
-      arange<uint16_t>(start, start + step, out, out.size());
-      break;
-    case uint32:
-      arange<uint32_t>(start, start + step, out, out.size());
-      break;
-    case uint64:
-      arange<uint64_t>(start, start + step, out, out.size());
-      break;
-    case int8:
-      arange<int8_t>(start, start + step, out, out.size());
-      break;
-    case int16:
-      arange<int16_t>(start, start + step, out, out.size());
-      break;
-    case int32:
-      arange<int32_t>(start, start + step, out, out.size());
-      break;
-    case int64:
-      arange<int64_t>(start, start + step, out, out.size());
-      break;
-    case float16:
-      arange<float16_t>(start, start + step, out, out.size());
-      break;
-    case float32:
-      arange<float>(start, start + step, out, out.size());
-      break;
-    case float64:
-      arange<double>(start, start + step, out, out.size());
-      break;
-    case bfloat16:
-      arange<bfloat16_t>(start, start + step, out, out.size());
-      break;
-    case complex64:
-      arange<complex64_t>(start, start + step, out, out.size());
-      break;
-  }
-}
-
 } // namespace mlx::core
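The arange dispatch above registers `out` with set_output_array before enqueueing work. A toy sketch of why that registration matters, under the assumption (mine, not taken from this diff) that the encoder keeps each task's buffers alive until the task has run; Task and Encoder below are illustrative names only, and the real encoder also handles temporaries and cross-stream ordering, which is elided here.

#include <cstdio>
#include <functional>
#include <memory>
#include <queue>
#include <utility>
#include <vector>

struct Task {
  std::function<void()> fn;
  std::vector<std::shared_ptr<void>> retained; // keeps buffers alive
};

struct Encoder {
  std::queue<Task> tasks;
  std::vector<std::shared_ptr<void>> pending;
  void set_input_array(std::shared_ptr<void> buf) {
    pending.push_back(std::move(buf));
  }
  void set_output_array(std::shared_ptr<void> buf) {
    pending.push_back(std::move(buf));
  }
  void dispatch(std::function<void()> fn) {
    tasks.push({std::move(fn), std::move(pending)});
    pending.clear();
  }
  void run_all() {
    while (!tasks.empty()) {
      tasks.front().fn();
      tasks.pop(); // buffers released only after the task ran
    }
  }
};

int main() {
  Encoder enc;
  auto out = std::make_shared<std::vector<float>>(4);
  float* ptr = out->data();
  enc.set_output_array(out);
  out.reset(); // caller drops its reference; the encoder still holds one
  enc.dispatch([ptr]() { ptr[0] = 42.0f; });
  enc.run_all();
  std::printf("ok\n");
}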
@@ -3,6 +3,7 @@
 #include <cassert>
 
 #include "mlx/backend/common/utils.h"
+#include "mlx/backend/cpu/encoder.h"
 #include "mlx/primitives.h"
 
 namespace mlx::core {
@@ -10,23 +11,43 @@ namespace mlx::core {
 namespace {
 
 template <typename InT, typename OpT>
-void arg_reduce(const array& in, array& out, const OpT& op, int axis) {
+void arg_reduce(
+    const array& in,
+    array& out,
+    const OpT& op,
+    int axis,
+    Stream stream) {
   auto axis_size = in.shape()[axis];
   auto axis_stride = in.strides()[axis];
   Strides strides = in.strides();
   Shape shape = in.shape();
   strides.erase(strides.begin() + axis);
   shape.erase(shape.begin() + axis);
-  for (uint32_t i = 0; i < out.size(); ++i) {
-    auto loc = elem_to_loc(i, shape, strides);
-    auto in_ptr = in.data<InT>() + loc;
-    uint32_t ind_v = 0;
-    InT v = (*in_ptr);
-    for (uint32_t j = 0; j < axis_size; ++j, in_ptr += axis_stride) {
-      op(j, (*in_ptr), &ind_v, &v);
+  auto in_ptr = in.data<InT>();
+  auto out_ptr = out.data<uint32_t>();
+
+  auto& encoder = cpu::get_command_encoder(stream);
+  encoder.set_input_array(in);
+  encoder.set_output_array(out);
+  encoder.dispatch([in_ptr,
+                    out_ptr,
+                    axis_size,
+                    axis_stride,
+                    op = std::move(op),
+                    shape = std::move(shape),
+                    strides = std::move(strides),
+                    size = out.size()]() {
+    for (uint32_t i = 0; i < size; ++i) {
+      auto loc = elem_to_loc(i, shape, strides);
+      auto local_in_ptr = in_ptr + loc;
+      uint32_t ind_v = 0;
+      InT v = (*local_in_ptr);
+      for (uint32_t j = 0; j < axis_size; ++j, local_in_ptr += axis_stride) {
+        op(j, (*local_in_ptr), &ind_v, &v);
+      }
+      out_ptr[i] = ind_v;
     }
-    out.data<uint32_t>()[i] = ind_v;
-  }
+  });
 }
 
 template <typename InT>
@@ -34,7 +55,8 @@ void arg_reduce_dispatch(
     const array& in,
     array& out,
     ArgReduce::ReduceType rtype,
-    int axis) {
+    int axis,
+    Stream stream) {
   switch (rtype) {
     case ArgReduce::ArgMin: {
       auto op = [](auto ind_x, auto x, auto ind_y, auto y) {
@@ -43,7 +65,7 @@ void arg_reduce_dispatch(
          (*ind_y) = ind_x;
        }
      };
-      arg_reduce<InT>(in, out, op, axis);
+      arg_reduce<InT>(in, out, op, axis, stream);
      break;
    }
    case ArgReduce::ArgMax: {
@@ -53,7 +75,7 @@ void arg_reduce_dispatch(
          (*ind_y) = ind_x;
        }
      };
-      arg_reduce<InT>(in, out, op, axis);
+      arg_reduce<InT>(in, out, op, axis, stream);
      break;
    }
  }
@@ -68,46 +90,46 @@ void ArgReduce::eval_cpu(const std::vector<array>& inputs, array& out) {
 
   switch (in.dtype()) {
     case bool_:
-      arg_reduce_dispatch<bool>(in, out, reduce_type_, axis_);
+      arg_reduce_dispatch<bool>(in, out, reduce_type_, axis_, stream());
       break;
     case uint8:
-      arg_reduce_dispatch<uint8_t>(in, out, reduce_type_, axis_);
+      arg_reduce_dispatch<uint8_t>(in, out, reduce_type_, axis_, stream());
       break;
     case uint16:
-      arg_reduce_dispatch<uint16_t>(in, out, reduce_type_, axis_);
+      arg_reduce_dispatch<uint16_t>(in, out, reduce_type_, axis_, stream());
       break;
     case uint32:
-      arg_reduce_dispatch<uint32_t>(in, out, reduce_type_, axis_);
+      arg_reduce_dispatch<uint32_t>(in, out, reduce_type_, axis_, stream());
       break;
     case uint64:
-      arg_reduce_dispatch<uint64_t>(in, out, reduce_type_, axis_);
+      arg_reduce_dispatch<uint64_t>(in, out, reduce_type_, axis_, stream());
       break;
     case int8:
-      arg_reduce_dispatch<int8_t>(in, out, reduce_type_, axis_);
+      arg_reduce_dispatch<int8_t>(in, out, reduce_type_, axis_, stream());
       break;
     case int16:
-      arg_reduce_dispatch<int16_t>(in, out, reduce_type_, axis_);
+      arg_reduce_dispatch<int16_t>(in, out, reduce_type_, axis_, stream());
       break;
     case int32:
-      arg_reduce_dispatch<int32_t>(in, out, reduce_type_, axis_);
+      arg_reduce_dispatch<int32_t>(in, out, reduce_type_, axis_, stream());
       break;
     case int64:
-      arg_reduce_dispatch<int64_t>(in, out, reduce_type_, axis_);
+      arg_reduce_dispatch<int64_t>(in, out, reduce_type_, axis_, stream());
       break;
     case float16:
-      arg_reduce_dispatch<float16_t>(in, out, reduce_type_, axis_);
+      arg_reduce_dispatch<float16_t>(in, out, reduce_type_, axis_, stream());
       break;
     case float32:
-      arg_reduce_dispatch<float>(in, out, reduce_type_, axis_);
+      arg_reduce_dispatch<float>(in, out, reduce_type_, axis_, stream());
       break;
     case bfloat16:
-      arg_reduce_dispatch<bfloat16_t>(in, out, reduce_type_, axis_);
+      arg_reduce_dispatch<bfloat16_t>(in, out, reduce_type_, axis_, stream());
       break;
     case float64:
-      arg_reduce_dispatch<double>(in, out, reduce_type_, axis_);
+      arg_reduce_dispatch<double>(in, out, reduce_type_, axis_, stream());
       break;
     case complex64:
-      arg_reduce_dispatch<complex64_t>(in, out, reduce_type_, axis_);
+      arg_reduce_dispatch<complex64_t>(in, out, reduce_type_, axis_, stream());
       break;
   }
 }
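The dispatched loop above leans on elem_to_loc to walk a strided array in flat row-major order. Below is a generic reimplementation of that mapping for illustration, assuming row-major flattening over arbitrary strides; it is not MLX's exact helper.

// Map a flat row-major element index to a memory offset under strides.
#include <cstdint>
#include <iostream>
#include <vector>

int64_t elem_to_loc(int64_t elem, const std::vector<int>& shape,
                    const std::vector<int64_t>& strides) {
  int64_t loc = 0;
  for (int i = shape.size() - 1; i >= 0; --i) {
    loc += (elem % shape[i]) * strides[i]; // offset along axis i
    elem /= shape[i];
  }
  return loc;
}

int main() {
  // A 2x3 array stored transposed: strides {1, 2} instead of {3, 1}.
  std::vector<int> shape = {2, 3};
  std::vector<int64_t> strides = {1, 2};
  for (int64_t e = 0; e < 6; ++e) {
    std::cout << elem_to_loc(e, shape, strides) << ' ';
  }
  std::cout << '\n'; // prints 0 2 4 1 3 5
}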
@@ -16,49 +16,49 @@ namespace mlx::core {
 namespace {
 
 template <typename Op>
-void comparison_op(const array& a, const array& b, array& out, Op op) {
+void comparison_op(const array& a, const array& b, array& out) {
   switch (a.dtype()) {
     case bool_:
-      binary_op<bool, bool>(a, b, out, op);
+      binary_op<bool, bool, Op>(a, b, out);
       break;
     case uint8:
-      binary_op<uint8_t, bool>(a, b, out, op);
+      binary_op<uint8_t, bool, Op>(a, b, out);
       break;
     case uint16:
-      binary_op<uint16_t, bool>(a, b, out, op);
+      binary_op<uint16_t, bool, Op>(a, b, out);
       break;
     case uint32:
-      binary_op<uint32_t, bool>(a, b, out, op);
+      binary_op<uint32_t, bool, Op>(a, b, out);
       break;
     case uint64:
-      binary_op<uint64_t, bool>(a, b, out, op);
+      binary_op<uint64_t, bool, Op>(a, b, out);
       break;
     case int8:
-      binary_op<int8_t, bool>(a, b, out, op);
+      binary_op<int8_t, bool, Op>(a, b, out);
       break;
     case int16:
-      binary_op<int16_t, bool>(a, b, out, op);
+      binary_op<int16_t, bool, Op>(a, b, out);
       break;
     case int32:
-      binary_op<int32_t, bool>(a, b, out, op);
+      binary_op<int32_t, bool, Op>(a, b, out);
       break;
     case int64:
-      binary_op<int64_t, bool>(a, b, out, op);
+      binary_op<int64_t, bool, Op>(a, b, out);
       break;
     case float16:
-      binary_op<float16_t, bool>(a, b, out, op);
+      binary_op<float16_t, bool, Op>(a, b, out);
       break;
     case float32:
-      binary_op<float, bool>(a, b, out, op);
+      binary_op<float, bool, Op>(a, b, out);
       break;
     case float64:
-      binary_op<double, bool>(a, b, out, op);
+      binary_op<double, bool, Op>(a, b, out);
       break;
     case bfloat16:
-      binary_op<bfloat16_t, bool>(a, b, out, op);
+      binary_op<bfloat16_t, bool, Op>(a, b, out);
       break;
     case complex64:
-      binary_op<complex64_t, bool>(a, b, out, op);
+      binary_op<complex64_t, bool, Op>(a, b, out);
       break;
   }
 }
@@ -151,47 +151,47 @@ void Equal::eval_cpu(const std::vector<array>& inputs, array& out) {
   if (equal_nan_) {
     switch (a.dtype()) {
       case float16:
-        binary_op<float16_t, bool>(a, b, out, detail::NaNEqual());
+        binary_op<float16_t, bool, detail::NaNEqual>(a, b, out);
         break;
       case float32:
-        binary_op<float, bool>(a, b, out, detail::NaNEqual());
+        binary_op<float, bool, detail::NaNEqual>(a, b, out);
         break;
       case float64:
-        binary_op<double, bool>(a, b, out, detail::NaNEqual());
+        binary_op<double, bool, detail::NaNEqual>(a, b, out);
         break;
       case bfloat16:
-        binary_op<bfloat16_t, bool>(a, b, out, detail::NaNEqual());
+        binary_op<bfloat16_t, bool, detail::NaNEqual>(a, b, out);
        break;
      case complex64:
-        binary_op<complex64_t, bool>(a, b, out, detail::NaNEqual());
+        binary_op<complex64_t, bool, detail::NaNEqual>(a, b, out);
        break;
      default:
        throw std::runtime_error(
            "[NanEqual::eval_cpu] Only for floating point types.");
    }
  } else {
-    comparison_op(a, b, out, detail::Equal());
+    comparison_op<detail::Equal>(a, b, out);
  }
 }
 
 void Greater::eval_cpu(const std::vector<array>& inputs, array& out) {
   assert(inputs.size() == 2);
-  comparison_op(inputs[0], inputs[1], out, detail::Greater());
+  comparison_op<detail::Greater>(inputs[0], inputs[1], out);
 }
 
 void GreaterEqual::eval_cpu(const std::vector<array>& inputs, array& out) {
   assert(inputs.size() == 2);
-  comparison_op(inputs[0], inputs[1], out, detail::GreaterEqual());
+  comparison_op<detail::GreaterEqual>(inputs[0], inputs[1], out);
 }
 
 void Less::eval_cpu(const std::vector<array>& inputs, array& out) {
   assert(inputs.size() == 2);
-  comparison_op(inputs[0], inputs[1], out, detail::Less());
+  comparison_op<detail::Less>(inputs[0], inputs[1], out);
 }
 
 void LessEqual::eval_cpu(const std::vector<array>& inputs, array& out) {
   assert(inputs.size() == 2);
-  comparison_op(inputs[0], inputs[1], out, detail::LessEqual());
+  comparison_op<detail::LessEqual>(inputs[0], inputs[1], out);
 }
 
 void LogAddExp::eval_cpu(const std::vector<array>& inputs, array& out) {
@@ -200,16 +200,16 @@ void LogAddExp::eval_cpu(const std::vector<array>& inputs, array& out) {
   auto& b = inputs[1];
   switch (out.dtype()) {
     case float16:
-      binary_op<float16_t>(a, b, out, detail::LogAddExp());
+      binary_op<float16_t, detail::LogAddExp>(a, b, out);
       break;
     case float32:
-      binary_op<float>(a, b, out, detail::LogAddExp());
+      binary_op<float, detail::LogAddExp>(a, b, out);
       break;
     case float64:
-      binary_op<double>(a, b, out, detail::LogAddExp());
+      binary_op<double, detail::LogAddExp>(a, b, out);
       break;
     case bfloat16:
-      binary_op<bfloat16_t>(a, b, out, detail::LogAddExp());
+      binary_op<bfloat16_t, detail::LogAddExp>(a, b, out);
       break;
     default:
       throw std::runtime_error(
@@ -254,7 +254,7 @@ void Multiply::eval_cpu(const std::vector<array>& inputs, array& out) {
 
 void NotEqual::eval_cpu(const std::vector<array>& inputs, array& out) {
   assert(inputs.size() == 2);
-  comparison_op(inputs[0], inputs[1], out, detail::NotEqual());
+  comparison_op<detail::NotEqual>(inputs[0], inputs[1], out);
 }
 
 void Power::eval_cpu(const std::vector<array>& inputs, array& out) {
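The comparison and LogAddExp changes above all follow one refactor: the op moves from a runtime argument (op(a, b)) to a template parameter that is value-constructed at the call site (Op{}(a, b)), so no op object needs to be captured into the dispatched task. A minimal before/after sketch of that shape; Add and apply are hypothetical names, not MLX's.

#include <iostream>

struct Add {
  template <typename T>
  T operator()(T a, T b) const {
    return a + b;
  }
};

// Before (sketch): void apply(const T* a, const T* b, T* out, int n, Op op)
// After: Op is a template parameter; no op value is threaded through.
template <typename T, typename Op>
void apply(const T* a, const T* b, T* out, int n) {
  for (int i = 0; i < n; ++i) {
    out[i] = Op{}(a[i], b[i]); // value-construct the stateless functor
  }
}

int main() {
  float a[4] = {1, 2, 3, 4}, b[4] = {10, 20, 30, 40}, out[4];
  apply<float, Add>(a, b, out, 4);
  for (float v : out) {
    std::cout << v << ' ';
  }
  std::cout << '\n'; // 11 22 33 44
}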
@@ -7,6 +7,8 @@
 #include "mlx/array.h"
 #include "mlx/backend/common/binary.h"
 #include "mlx/backend/common/utils.h"
+#include "mlx/backend/cpu/encoder.h"
 #include "mlx/primitives.h"
 
 #include "mlx/backend/cpu/simd/simd.h"
 
@@ -14,22 +16,18 @@ namespace mlx::core {
 
 template <typename Op>
 struct VectorScalar {
-  Op op;
-
-  VectorScalar(Op op_) : op(op_) {}
-
   template <typename T, typename U>
   void operator()(const T* a, const T* b, U* dst, int size) {
     T scalar = *b;
     constexpr int N = simd::max_size<T>;
     while (size >= N) {
-      simd::store(dst, op(simd::load<T, N>(a), simd::Simd<T, N>(scalar)));
+      simd::store(dst, Op{}(simd::load<T, N>(a), simd::Simd<T, N>(scalar)));
       dst += N;
       a += N;
       size -= N;
     }
     while (size-- > 0) {
-      *dst = op(*a, scalar);
+      *dst = Op{}(*a, scalar);
       dst++;
       a++;
     }
@@ -38,22 +36,18 @@ struct VectorScalar {
 
 template <typename Op>
 struct ScalarVector {
-  Op op;
-
-  ScalarVector(Op op_) : op(op_) {}
-
   template <typename T, typename U>
   void operator()(const T* a, const T* b, U* dst, int size) {
     T scalar = *a;
     constexpr int N = simd::max_size<T>;
     while (size >= N) {
-      simd::store(dst, op(simd::Simd<T, N>(scalar), simd::load<T, N>(b)));
+      simd::store(dst, Op{}(simd::Simd<T, N>(scalar), simd::load<T, N>(b)));
       dst += N;
       b += N;
       size -= N;
     }
     while (size-- > 0) {
-      *dst = op(scalar, *b);
+      *dst = Op{}(scalar, *b);
       dst++;
       b++;
     }
@@ -62,22 +56,18 @@ struct ScalarVector {
 
 template <typename Op>
 struct VectorVector {
-  Op op;
-
-  VectorVector(Op op_) : op(op_) {}
-
   template <typename T, typename U>
   void operator()(const T* a, const T* b, U* dst, int size) {
     constexpr int N = simd::max_size<T>;
     while (size >= N) {
-      simd::store(dst, op(simd::load<T, N>(a), simd::load<T, N>(b)));
+      simd::store(dst, Op{}(simd::load<T, N>(a), simd::load<T, N>(b)));
       dst += N;
       a += N;
       b += N;
       size -= N;
     }
     while (size-- > 0) {
-      *dst = op(*a, *b);
+      *dst = Op{}(*a, *b);
       dst++;
       a++;
       b++;
@@ -90,7 +80,6 @@ void binary_op_dims(
     const T* a,
     const T* b,
     U* out,
-    Op op,
     const Shape& shape,
     const Strides& a_strides,
     const Strides& b_strides,
@@ -104,12 +93,12 @@ void binary_op_dims(
   for (int i = 0; i < N; i++) {
     if constexpr (D > 1) {
       binary_op_dims<T, U, Op, D - 1, Strided>(
-          a, b, out, op, shape, a_strides, b_strides, out_strides, axis + 1);
+          a, b, out, shape, a_strides, b_strides, out_strides, axis + 1);
     } else {
       if constexpr (Strided) {
-        op(a, b, out, stride_out);
+        Op{}(a, b, out, stride_out);
       } else {
-        *out = op(*a, *b);
+        *out = Op{}(*a, *b);
       }
     }
     out += stride_out;
@@ -120,66 +109,38 @@ void binary_op_dims(
 
 template <typename T, typename U, bool Strided, typename Op>
 void binary_op_dispatch_dims(
-    const array& a,
-    const array& b,
-    array& out,
-    Op op,
+    const T* a,
+    const T* b,
+    U* out,
     int dim,
+    int size,
     const Shape& shape,
     const Strides& a_strides,
     const Strides& b_strides,
     const Strides& out_strides) {
-  const T* a_ptr = a.data<T>();
-  const T* b_ptr = b.data<T>();
-  U* out_ptr = out.data<U>();
   switch (dim) {
     case 1:
       binary_op_dims<T, U, Op, 1, Strided>(
-          a_ptr,
-          b_ptr,
-          out_ptr,
-          op,
-          shape,
-          a_strides,
-          b_strides,
-          out_strides,
-          0);
+          a, b, out, shape, a_strides, b_strides, out_strides, 0);
       return;
     case 2:
       binary_op_dims<T, U, Op, 2, Strided>(
-          a_ptr,
-          b_ptr,
-          out_ptr,
-          op,
-          shape,
-          a_strides,
-          b_strides,
-          out_strides,
-          0);
+          a, b, out, shape, a_strides, b_strides, out_strides, 0);
       return;
     case 3:
       binary_op_dims<T, U, Op, 3, Strided>(
-          a_ptr,
-          b_ptr,
-          out_ptr,
-          op,
-          shape,
-          a_strides,
-          b_strides,
-          out_strides,
-          0);
+          a, b, out, shape, a_strides, b_strides, out_strides, 0);
       return;
   }
 
   ContiguousIterator a_it(shape, a_strides, dim - 3);
   ContiguousIterator b_it(shape, b_strides, dim - 3);
   auto stride = out_strides[dim - 4];
-  for (int64_t elem = 0; elem < a.size(); elem += stride) {
+  for (int64_t elem = 0; elem < size; elem += stride) {
     binary_op_dims<T, U, Op, 3, Strided>(
-        a_ptr + a_it.loc,
-        b_ptr + b_it.loc,
-        out_ptr + elem,
-        op,
+        a + a_it.loc,
+        b + b_it.loc,
+        out + elem,
        shape,
        a_strides,
        b_strides,
@@ -191,181 +152,216 @@ void binary_op_dispatch_dims(
 }
 
 template <typename T, typename U, typename Op>
-void binary_op(const array& a, const array& b, array& out, Op op) {
+void binary_op(const array& a, const array& b, array& out) {
   auto bopt = get_binary_op_type(a, b);
   set_binary_op_output_data(a, b, out, bopt);
 
-  // The full computation is scalar scalar so call the base op once
-  if (bopt == BinaryOpType::ScalarScalar) {
-    *(out.data<U>()) = op(*a.data<T>(), *b.data<T>());
-    return;
-  }
-
-  // The full computation is scalar vector so delegate to the op
-  if (bopt == BinaryOpType::ScalarVector) {
-    ScalarVector{op}(a.data<T>(), b.data<T>(), out.data<U>(), b.data_size());
-    return;
-  }
-
-  // The full computation is vector scalar so delegate to the op
-  if (bopt == BinaryOpType::VectorScalar) {
-    VectorScalar{op}(a.data<T>(), b.data<T>(), out.data<U>(), a.data_size());
-    return;
-  }
-
-  // The full computation is vector vector so delegate to the op
-  if (bopt == BinaryOpType::VectorVector) {
-    VectorVector{op}(a.data<T>(), b.data<T>(), out.data<U>(), out.size());
-    return;
-  }
-
-  // General computation so let's try to optimize
-  auto [new_shape, new_strides] = collapse_contiguous_dims(
-      a.shape(), {a.strides(), b.strides(), out.strides()});
-  const auto& a_strides = new_strides[0];
-  const auto& b_strides = new_strides[1];
-  const auto& strides = new_strides[2];
-
-  // Get the left-most dim such that the array is row contiguous after
-  auto leftmost_rc_dim = [&strides](const auto& arr_strides) {
-    int d = arr_strides.size() - 1;
-    for (; d >= 0 && arr_strides[d] == strides[d]; d--) {
-    }
-    return d + 1;
-  };
-  auto a_rc_dim = leftmost_rc_dim(a_strides);
-  auto b_rc_dim = leftmost_rc_dim(b_strides);
-
-  // Get the left-most dim such that the array is a broadcasted "scalar" after
-  auto leftmost_s_dim = [](const auto& arr_strides) {
-    int d = arr_strides.size() - 1;
-    for (; d >= 0 && arr_strides[d] == 0; d--) {
-    }
-    return d + 1;
-  };
-  auto a_s_dim = leftmost_s_dim(a_strides);
-  auto b_s_dim = leftmost_s_dim(b_strides);
-
-  auto ndim = new_shape.size();
-
-  // Case 1: LxM and FxM where L and F are broadcastable and M is row contiguous
-  int dim = ndim;
-  if (int d = std::max(a_rc_dim, b_rc_dim); d < ndim) {
-    bopt = BinaryOpType::VectorVector;
-    dim = d;
-    // Case 2: LxM and Fx1 where L and F are broadcastable and M is row
-    // contiguous
-  } else if (int d = std::max(a_rc_dim, b_s_dim); d < ndim) {
-    bopt = BinaryOpType::VectorScalar;
-    dim = d;
-    // Case 3: Lx1 and FxM where L and F are broadcastable and M is row
-    // contiguous
-  } else if (int d = std::max(a_s_dim, b_rc_dim); d < ndim) {
-    bopt = BinaryOpType::ScalarVector;
-    dim = d;
-  }
-
-  // Can be sure dim > 0 since otherwise we would have used one of the fully
-  // contiguous methods above. Except for the case that the flags do not
-  // correspond to the underlying contiguity.
-  if (dim == 0 || strides[dim - 1] < 16) {
-    bopt = BinaryOpType::General;
-    dim = ndim;
-  }
-
-  switch (bopt) {
-    case BinaryOpType::VectorVector:
-      binary_op_dispatch_dims<T, U, true>(
-          a,
-          b,
-          out,
-          VectorVector{op},
-          dim,
-          new_shape,
-          a_strides,
-          b_strides,
-          strides);
-      break;
-    case BinaryOpType::VectorScalar:
-      binary_op_dispatch_dims<T, U, true>(
-          a,
-          b,
-          out,
-          VectorScalar{op},
-          dim,
-          new_shape,
-          a_strides,
-          b_strides,
-          strides);
-      break;
-    case BinaryOpType::ScalarVector:
-      binary_op_dispatch_dims<T, U, true>(
-          a,
-          b,
-          out,
-          ScalarVector{op},
-          dim,
-          new_shape,
-          a_strides,
-          b_strides,
-          strides);
-      break;
-    default:
-      binary_op_dispatch_dims<T, U, false>(
-          a, b, out, op, dim, new_shape, a_strides, b_strides, strides);
-      break;
-  }
+  auto a_ptr = a.data<T>();
+  auto b_ptr = b.data<T>();
+
+  auto out_ptr = out.data<U>();
+  auto& encoder = cpu::get_command_encoder(out.primitive().stream());
+  encoder.set_input_array(a);
+  encoder.set_input_array(b);
+  encoder.set_output_array(out);
+  encoder.dispatch([bopt,
+                    a_ptr,
+                    b_ptr,
+                    out_ptr,
+                    a_data_size = a.data_size(),
+                    b_data_size = b.data_size(),
+                    size = a.size(),
+                    shape = a.shape(),
+                    a_strides = a.strides(),
+                    b_strides = b.strides(),
+                    strides = out.strides()]() mutable {
+    if (bopt == BinaryOpType::ScalarScalar) {
+      *out_ptr = Op{}(*a_ptr, *b_ptr);
+      return;
+    }
+
+    // The full computation is scalar vector so delegate to the op
+    if (bopt == BinaryOpType::ScalarVector) {
+      ScalarVector<Op>{}(a_ptr, b_ptr, out_ptr, b_data_size);
+      return;
+    }
+
+    // The full computation is vector scalar so delegate to the op
+    if (bopt == BinaryOpType::VectorScalar) {
+      VectorScalar<Op>{}(a_ptr, b_ptr, out_ptr, a_data_size);
+      return;
+    }
+
+    // The full computation is vector vector so delegate to the op
+    if (bopt == BinaryOpType::VectorVector) {
+      VectorVector<Op>{}(a_ptr, b_ptr, out_ptr, size);
+      return;
+    }
+
+    // General computation so let's try to optimize
+    auto [new_shape, new_strides] = collapse_contiguous_dims(
+        shape,
+        {std::move(a_strides), std::move(b_strides), std::move(strides)});
+    a_strides = new_strides[0];
+    b_strides = new_strides[1];
+    strides = new_strides[2];
+
+    // Get the left-most dim such that the array is row contiguous after
+    auto leftmost_rc_dim = [&strides](const auto& arr_strides) {
+      int d = arr_strides.size() - 1;
+      for (; d >= 0 && arr_strides[d] == strides[d]; d--) {
+      }
+      return d + 1;
+    };
+    auto a_rc_dim = leftmost_rc_dim(a_strides);
+    auto b_rc_dim = leftmost_rc_dim(b_strides);
+
+    // Get the left-most dim such that the array is a broadcasted "scalar" after
+    auto leftmost_s_dim = [](const auto& arr_strides) {
+      int d = arr_strides.size() - 1;
+      for (; d >= 0 && arr_strides[d] == 0; d--) {
+      }
+      return d + 1;
+    };
+    auto a_s_dim = leftmost_s_dim(a_strides);
+    auto b_s_dim = leftmost_s_dim(b_strides);
+
+    auto ndim = new_shape.size();
+
+    // Case 1: LxM and FxM where L and F are broadcastable and M is row
+    // contiguous
+    int dim = ndim;
+    if (int d = std::max(a_rc_dim, b_rc_dim); d < ndim) {
+      bopt = BinaryOpType::VectorVector;
+      dim = d;
+      // Case 2: LxM and Fx1 where L and F are broadcastable and M is row
+      // contiguous
+    } else if (int d = std::max(a_rc_dim, b_s_dim); d < ndim) {
+      bopt = BinaryOpType::VectorScalar;
+      dim = d;
+      // Case 3: Lx1 and FxM where L and F are broadcastable and M is row
+      // contiguous
+    } else if (int d = std::max(a_s_dim, b_rc_dim); d < ndim) {
+      bopt = BinaryOpType::ScalarVector;
+      dim = d;
+    }
+
+    // Can be sure dim > 0 since otherwise we would have used one of the fully
+    // contiguous methods above. Except for the case that the flags do not
+    // correspond to the underlying contiguity.
+    if (dim == 0 || strides[dim - 1] < 16) {
+      bopt = BinaryOpType::General;
+      dim = ndim;
+    }
+
+    switch (bopt) {
+      case BinaryOpType::VectorVector:
+        binary_op_dispatch_dims<T, U, true, VectorVector<Op>>(
+            a_ptr,
+            b_ptr,
+            out_ptr,
+            dim,
+            size,
+            new_shape,
+            a_strides,
+            b_strides,
+            strides);
+        break;
+      case BinaryOpType::VectorScalar:
+        binary_op_dispatch_dims<T, U, true, VectorScalar<Op>>(
+            a_ptr,
+            b_ptr,
+            out_ptr,
+            dim,
+            size,
+            new_shape,
+            a_strides,
+            b_strides,
+            strides);
+        break;
+      case BinaryOpType::ScalarVector:
+        binary_op_dispatch_dims<T, U, true, ScalarVector<Op>>(
+            a_ptr,
+            b_ptr,
+            out_ptr,
+            dim,
+            size,
+            new_shape,
+            a_strides,
+            b_strides,
+            strides);
+        break;
+      default:
+        binary_op_dispatch_dims<T, U, false, Op>(
+            a_ptr,
+            b_ptr,
+            out_ptr,
+            dim,
+            size,
+            new_shape,
+            a_strides,
+            b_strides,
+            strides);
+        break;
+    }
+  });
 }
 
+template <typename T, typename Op>
+void binary_op(const array& a, const array& b, array& out) {
+  binary_op<T, T, Op>(a, b, out);
+}
+
 template <typename T, typename Op>
 void binary_op(const array& a, const array& b, array& out, Op op) {
-  binary_op<T, T>(a, b, out, op);
+  binary_op<T, T, Op>(a, b, out);
 }
 
 template <typename Op>
 void binary(const array& a, const array& b, array& out, Op op) {
   switch (out.dtype()) {
     case bool_:
-      binary_op<bool>(a, b, out, op);
+      binary_op<bool, Op>(a, b, out);
       break;
     case uint8:
-      binary_op<uint8_t>(a, b, out, op);
+      binary_op<uint8_t, Op>(a, b, out);
       break;
     case uint16:
-      binary_op<uint16_t>(a, b, out, op);
+      binary_op<uint16_t, Op>(a, b, out);
      break;
    case uint32:
-      binary_op<uint32_t>(a, b, out, op);
+      binary_op<uint32_t, Op>(a, b, out);
      break;
    case uint64:
-      binary_op<uint64_t>(a, b, out, op);
+      binary_op<uint64_t, Op>(a, b, out);
      break;
    case int8:
-      binary_op<int8_t>(a, b, out, op);
+      binary_op<int8_t, Op>(a, b, out);
      break;
    case int16:
-      binary_op<int16_t>(a, b, out, op);
+      binary_op<int16_t, Op>(a, b, out);
      break;
    case int32:
-      binary_op<int32_t>(a, b, out, op);
+      binary_op<int32_t, Op>(a, b, out);
      break;
    case int64:
-      binary_op<int64_t>(a, b, out, op);
+      binary_op<int64_t, Op>(a, b, out);
      break;
    case float16:
-      binary_op<float16_t>(a, b, out, op);
+      binary_op<float16_t, Op>(a, b, out);
      break;
    case float32:
-      binary_op<float>(a, b, out, op);
+      binary_op<float, Op>(a, b, out);
      break;
    case float64:
-      binary_op<double>(a, b, out, op);
+      binary_op<double, Op>(a, b, out);
      break;
    case bfloat16:
-      binary_op<bfloat16_t>(a, b, out, op);
+      binary_op<bfloat16_t, Op>(a, b, out);
      break;
    case complex64:
-      binary_op<complex64_t>(a, b, out, op);
+      binary_op<complex64_t, Op>(a, b, out);
      break;
  }
 }
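The general path above first runs collapse_contiguous_dims so the strided loops cover as few, as large, dimensions as possible. Below is an illustrative single-array version of that idea; MLX's helper collapses over several stride lists at once, which is elided here, and the names are mine, not the library's.

// Merge adjacent dims whose strides chain as stride[i] == stride[i+1] * shape[i+1].
#include <cstddef>
#include <cstdint>
#include <iostream>
#include <vector>

void collapse(std::vector<int>& shape, std::vector<int64_t>& strides) {
  std::vector<int> out_shape;
  std::vector<int64_t> out_strides;
  for (std::size_t i = 0; i < shape.size(); ++i) {
    if (!out_shape.empty() && out_strides.back() == strides[i] * shape[i]) {
      out_shape.back() *= shape[i]; // merge into the previous dim
      out_strides.back() = strides[i];
    } else {
      out_shape.push_back(shape[i]);
      out_strides.push_back(strides[i]);
    }
  }
  shape = std::move(out_shape);
  strides = std::move(out_strides);
}

int main() {
  // A contiguous 2x3x4 array: strides {12, 4, 1} collapse to one dim of 24.
  std::vector<int> shape = {2, 3, 4};
  std::vector<int64_t> strides = {12, 4, 1};
  collapse(shape, strides);
  std::cout << shape[0] << " @ stride " << strides[0] << '\n'; // 24 @ stride 1
}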
@@ -4,6 +4,8 @@
 
 #include "mlx/backend/common/utils.h"
 #include "mlx/backend/cpu/binary.h"
+#include "mlx/backend/cpu/encoder.h"
 #include "mlx/primitives.h"
 
 namespace mlx::core {
 
@@ -55,65 +57,81 @@ void binary_op_dispatch_dims(
     const array& b,
     array& out_a,
     array& out_b,
+    Stream stream,
     Op op) {
+  auto& encoder = cpu::get_command_encoder(stream);
+  encoder.set_input_array(a);
+  encoder.set_input_array(b);
+  encoder.set_output_array(out_a);
+  encoder.set_output_array(out_b);
+
   auto [shape, strides] = collapse_contiguous_dims(
       a.shape(), {a.strides(), b.strides(), out_a.strides()});
-  const auto& a_strides = strides[0];
-  const auto& b_strides = strides[1];
-  const auto& out_strides = strides[2];
   const T* a_ptr = a.data<T>();
   const T* b_ptr = b.data<T>();
   U* out_a_ptr = out_a.data<U>();
   U* out_b_ptr = out_b.data<U>();
 
-  int ndim = shape.size();
-  switch (ndim) {
-    case 1:
-      binary_op_dims<T, U, Op, 1>(
-          a_ptr,
-          b_ptr,
-          out_a_ptr,
-          out_b_ptr,
-          op,
-          shape,
-          a_strides,
-          b_strides,
-          out_strides,
-          0);
-      return;
-    case 2:
-      binary_op_dims<T, U, Op, 2>(
-          a_ptr,
-          b_ptr,
-          out_a_ptr,
-          out_b_ptr,
-          op,
-          shape,
-          a_strides,
-          b_strides,
-          out_strides,
-          0);
-      return;
-  }
+  encoder.dispatch([a_ptr,
+                    b_ptr,
+                    out_a_ptr,
+                    out_b_ptr,
+                    size = a.size(),
+                    shape = std::move(shape),
+                    strides = std::move(strides),
+                    op = std::move(op)]() {
+    const auto& a_strides = strides[0];
+    const auto& b_strides = strides[1];
+    const auto& out_strides = strides[2];
+    int ndim = shape.size();
+    switch (ndim) {
+      case 1:
+        binary_op_dims<T, U, Op, 1>(
+            a_ptr,
+            b_ptr,
+            out_a_ptr,
+            out_b_ptr,
+            op,
+            shape,
+            a_strides,
+            b_strides,
+            out_strides,
+            0);
+        return;
+      case 2:
+        binary_op_dims<T, U, Op, 2>(
+            a_ptr,
+            b_ptr,
+            out_a_ptr,
+            out_b_ptr,
+            op,
+            shape,
+            a_strides,
+            b_strides,
+            out_strides,
+            0);
+        return;
+    }
 
-  ContiguousIterator a_it(shape, a_strides, ndim - 2);
-  ContiguousIterator b_it(shape, b_strides, ndim - 2);
-  auto stride = out_strides[ndim - 3];
-  for (size_t elem = 0; elem < a.size(); elem += stride) {
-    binary_op_dims<T, U, Op, 2>(
-        a_ptr + a_it.loc,
-        b_ptr + b_it.loc,
-        out_a_ptr + elem,
-        out_b_ptr + elem,
-        op,
-        shape,
-        a_strides,
-        b_strides,
-        out_strides,
-        ndim - 2);
-    a_it.step();
-    b_it.step();
-  }
+    ContiguousIterator a_it(shape, a_strides, ndim - 2);
+    ContiguousIterator b_it(shape, b_strides, ndim - 2);
+    auto stride = out_strides[ndim - 3];
+    for (size_t elem = 0; elem < size; elem += stride) {
+      binary_op_dims<T, U, Op, 2>(
+          a_ptr + a_it.loc,
+          b_ptr + b_it.loc,
+          out_a_ptr + elem,
+          out_b_ptr + elem,
+          op,
+          shape,
+          a_strides,
+          b_strides,
+          out_strides,
+          ndim - 2);
+      a_it.step();
+      b_it.step();
+    }
+  });
 }
 
 template <typename T, typename U = T, typename Op>
@@ -128,40 +146,71 @@ void binary_op(
   set_binary_op_output_data(a, b, out_a, bopt);
   set_binary_op_output_data(a, b, out_b, bopt);
 
+  auto stream = out_a.primitive().stream();
   // The full computation is scalar scalar so call the base op once
   if (bopt == BinaryOpType::General) {
-    binary_op_dispatch_dims<T, U, Op>(a, b, out_a, out_b, op);
+    binary_op_dispatch_dims<T, U, Op>(a, b, out_a, out_b, stream, op);
     return;
   }
 
+  auto& encoder = cpu::get_command_encoder(stream);
+  encoder.set_input_array(a);
+  encoder.set_input_array(b);
+  encoder.set_output_array(out_a);
+  encoder.set_output_array(out_b);
+
   auto a_ptr = a.data<T>();
   auto b_ptr = b.data<T>();
   auto out_a_ptr = out_a.data<U>();
   auto out_b_ptr = out_b.data<U>();
   if (bopt == BinaryOpType::ScalarScalar) {
-    std::tie(*out_a_ptr, *out_b_ptr) = op(*a_ptr, *b_ptr);
+    encoder.dispatch(
+        [a_ptr, b_ptr, out_a_ptr, out_b_ptr, op = std::move(op)]() mutable {
+          std::tie(*out_a_ptr, *out_b_ptr) = op(*a_ptr, *b_ptr);
+        });
   } else if (bopt == BinaryOpType::ScalarVector) {
-    for (size_t i = 0; i < b.size(); ++i) {
-      std::tie(*out_a_ptr, *out_b_ptr) = op(*a_ptr, *b_ptr);
-      out_a_ptr++;
-      out_b_ptr++;
-      b_ptr++;
-    }
+    encoder.dispatch([a_ptr,
+                      b_ptr,
+                      out_a_ptr,
+                      out_b_ptr,
+                      size = b.size(),
+                      op = std::move(op)]() mutable {
+      for (size_t i = 0; i < size; ++i) {
+        std::tie(*out_a_ptr, *out_b_ptr) = op(*a_ptr, *b_ptr);
+        out_a_ptr++;
+        out_b_ptr++;
+        b_ptr++;
+      }
+    });
   } else if (bopt == BinaryOpType::VectorScalar) {
-    for (size_t i = 0; i < a.size(); ++i) {
-      std::tie(*out_a_ptr, *out_b_ptr) = op(*a_ptr, *b_ptr);
-      out_a_ptr++;
-      out_b_ptr++;
-      a_ptr++;
-    }
+    encoder.dispatch([a_ptr,
+                      b_ptr,
+                      out_a_ptr,
+                      out_b_ptr,
+                      size = a.size(),
+                      op = std::move(op)]() mutable {
+      for (size_t i = 0; i < size; ++i) {
+        std::tie(*out_a_ptr, *out_b_ptr) = op(*a_ptr, *b_ptr);
+        out_a_ptr++;
+        out_b_ptr++;
+        a_ptr++;
+      }
+    });
   } else { // VectorVector
-    for (size_t i = 0; i < a.size(); ++i) {
-      std::tie(*out_a_ptr, *out_b_ptr) = op(*a_ptr, *b_ptr);
-      out_a_ptr++;
-      out_b_ptr++;
-      a_ptr++;
-      b_ptr++;
-    }
+    encoder.dispatch([a_ptr,
+                      b_ptr,
+                      out_a_ptr,
+                      out_b_ptr,
+                      size = a.size(),
+                      op = std::move(op)]() mutable {
+      for (size_t i = 0; i < size; ++i) {
+        std::tie(*out_a_ptr, *out_b_ptr) = op(*a_ptr, *b_ptr);
+        out_a_ptr++;
+        out_b_ptr++;
+        a_ptr++;
+        b_ptr++;
+      }
+    });
  }
 }
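This two-output binary_op serves primitives like DivMod: the functor returns a pair and std::tie scatters it into both output streams. A small stand-alone sketch of that shape; the DivMod functor below is illustrative, not MLX's.

#include <cstddef>
#include <iostream>
#include <tuple>
#include <utility>
#include <vector>

struct DivMod {
  std::pair<int, int> operator()(int a, int b) const {
    return {a / b, a % b};
  }
};

int main() {
  std::vector<int> a = {7, 9, 13}, b = {2, 4, 5};
  std::vector<int> quot(3), rem(3);
  int* q_ptr = quot.data();
  int* r_ptr = rem.data();
  DivMod op;
  for (std::size_t i = 0; i < a.size(); ++i) {
    std::tie(q_ptr[i], r_ptr[i]) = op(a[i], b[i]); // one op, two outputs
  }
  std::cout << quot[0] << ' ' << rem[0] << '\n'; // 3 1
}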
@@ -2,6 +2,7 @@
 
 #include "mlx/allocator.h"
 #include "mlx/backend/cpu/copy.h"
+#include "mlx/backend/cpu/encoder.h"
 #include "mlx/backend/cpu/lapack.h"
 #include "mlx/linalg.h"
 #include "mlx/primitives.h"
@@ -9,7 +10,7 @@
 namespace mlx::core {
 
 template <typename T>
-void cholesky_impl(const array& a, array& factor, bool upper) {
+void cholesky_impl(const array& a, array& factor, bool upper, Stream stream) {
   // Lapack uses the column-major convention. We take advantage of the fact that
   // the matrix should be symmetric:
   //   (A)ᵀ = A
@@ -17,60 +18,63 @@ void cholesky_impl(const array& a, array& factor, bool upper) {
   // triangular matrix, so uplo is the opposite of what we would expect from
   // upper
 
-  char uplo = (upper) ? 'L' : 'U';
-
   // The decomposition is computed in place, so just copy the input to the
   // output.
   copy(
       a,
       factor,
-      a.flags().row_contiguous ? CopyType::Vector : CopyType::General);
+      a.flags().row_contiguous ? CopyType::Vector : CopyType::General,
+      stream);
 
-  const int N = a.shape(-1);
-  const size_t num_matrices = a.size() / (N * N);
-
-  T* matrix = factor.data<T>();
-
-  for (int i = 0; i < num_matrices; i++) {
-    // Compute Cholesky factorization.
-    int info;
-    potrf<T>(
-        /* uplo = */ &uplo,
-        /* n = */ &N,
-        /* a = */ matrix,
-        /* lda = */ &N,
-        /* info = */ &info);
-
-    // TODO: We do nothing when the matrix is not positive semi-definite
-    // because throwing an error would result in a crash. If we figure out how
-    // to catch errors from the implementation we should throw.
-    if (info < 0) {
-      std::stringstream msg;
-      msg << "[cholesky] Cholesky decomposition failed with error code "
-          << info;
-      throw std::runtime_error(msg.str());
-    }
-
-    // Zero out the upper/lower triangle while advancing the pointer to the
-    // next matrix at the same time.
-    for (int row = 0; row < N; row++) {
-      if (upper) {
-        std::fill(matrix, matrix + row, 0);
-      } else {
-        std::fill(matrix + row + 1, matrix + N, 0);
-      }
-      matrix += N;
-    }
-  }
+  auto& encoder = cpu::get_command_encoder(stream);
+  encoder.set_output_array(factor);
+  encoder.dispatch([matrix = factor.data<T>(),
+                    upper,
+                    N = a.shape(-1),
+                    size = a.size()]() mutable {
+    char uplo = (upper) ? 'L' : 'U';
+    size_t num_matrices = size / (N * N);
+    for (int i = 0; i < num_matrices; i++) {
+      // Compute Cholesky factorization.
+      int info;
+      potrf<T>(
+          /* uplo = */ &uplo,
+          /* n = */ &N,
+          /* a = */ matrix,
+          /* lda = */ &N,
+          /* info = */ &info);
+
+      // TODO: We do nothing when the matrix is not positive semi-definite
+      // because throwing an error would result in a crash. If we figure out how
+      // to catch errors from the implementation we should throw.
+      if (info < 0) {
+        std::stringstream msg;
+        msg << "[Cholesky::eval_cpu] Cholesky decomposition failed with error code "
+            << info;
+        throw std::runtime_error(msg.str());
+      }
+
+      // Zero out the upper/lower triangle while advancing the pointer to the
+      // next matrix at the same time.
+      for (int row = 0; row < N; row++) {
+        if (upper) {
+          std::fill(matrix, matrix + row, 0);
+        } else {
+          std::fill(matrix + row + 1, matrix + N, 0);
+        }
+        matrix += N;
+      }
+    }
+  });
 }
 
 void Cholesky::eval_cpu(const std::vector<array>& inputs, array& output) {
   switch (inputs[0].dtype()) {
     case float32:
-      cholesky_impl<float>(inputs[0], output, upper_);
+      cholesky_impl<float>(inputs[0], output, upper_, stream());
      break;
    case float64:
-      cholesky_impl<double>(inputs[0], output, upper_);
+      cholesky_impl<double>(inputs[0], output, upper_, stream());
      break;
    default:
      throw std::runtime_error(
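A hedged sketch of the uplo trick the comment in cholesky_impl describes: potrf expects column-major data, but a symmetric matrix reads identically in either order, so passing row-major data with the opposite uplo ('U' when you want the lower factor) yields the intended result. This example calls LAPACKE directly (compile and link with -llapacke) rather than MLX's potrf<T> wrapper, which is an internal helper.

#include <lapacke.h>
#include <cstdio>

int main() {
  // Symmetric positive-definite 2x2 in row-major order.
  float a[4] = {4.0f, 2.0f, 2.0f, 3.0f};
  // We want the lower factor L (A = L * L^T), so in column-major terms we
  // ask LAPACK for 'U'; the symmetric buffer makes the two views agree.
  lapack_int info = LAPACKE_spotrf(LAPACK_COL_MAJOR, 'U', 2, a, 2);
  if (info != 0) {
    std::printf("potrf failed: %d\n", (int)info);
    return 1;
  }
  // Row-major, the lower triangle of `a` now holds L = [[2, 0], [1, sqrt(2)]].
  std::printf("%f %f\n%f %f\n", a[0], 0.0f, a[2], a[3]);
}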
@@ -11,6 +11,7 @@
 
 #include "mlx/backend/common/compiled.h"
 #include "mlx/backend/cpu/compiled_preamble.h"
+#include "mlx/backend/cpu/encoder.h"
 #include "mlx/backend/cpu/jit_compiler.h"
 #include "mlx/device.h"
 #include "mlx/graph_utils.h"
@@ -288,6 +289,7 @@ void Compiled::eval_cpu(
   // Figure out which kernel we are using
   auto& shape = outputs[0].shape();
   auto contiguous = compiled_check_contiguity(inputs, shape);
+  auto& encoder = cpu::get_command_encoder(stream());
 
   // Handle all broadcasting and collect function input arguments
   std::vector<void*> args;
@@ -298,6 +300,7 @@ void Compiled::eval_cpu(
       continue;
     }
     auto& x = inputs[i];
+    encoder.set_input_array(x);
     args.push_back((void*)x.data<void>());
 
     if (contiguous || is_scalar(x)) {
@@ -356,18 +359,25 @@ void Compiled::eval_cpu(
   });
 
   compiled_allocate_outputs(
-      inputs, outputs, inputs_, constant_ids_, contiguous, false);
+      inputs, outputs, inputs_, constant_ids_, contiguous);
 
   for (auto& x : outputs) {
     args.push_back(x.data<void>());
+    encoder.set_output_array(x);
   }
+  Shape out_shape;
   if (!contiguous) {
-    args.push_back((void*)outputs[0].shape().data());
+    out_shape = outputs[0].shape();
+    args.push_back((void*)out_shape.data());
   } else {
     args.push_back((void*)outputs[0].data_size());
   }
   auto fun = (void (*)(void**))fn_ptr;
-  fun(args.data());
+  encoder.dispatch(
+      [fun,
+       args = std::move(args),
+       strides = std::move(strides),
+       out_shape = std::move(out_shape)]() mutable { fun(args.data()); });
 }
 
 } // namespace mlx::core
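The out_shape change above is a lifetime fix: once fun(args.data()) is deferred, args must not point into data owned by the caller's scope, so the shape is copied and moved into the closure along with args and strides. A self-contained sketch of the same pattern, with illustrative names:

#include <cstdint>
#include <functional>
#include <iostream>
#include <vector>

// A deferred "kernel" with the same void** calling convention as above.
void kernel(void** args) {
  auto* shape = static_cast<const int32_t*>(args[0]);
  std::cout << shape[0] << " x " << shape[1] << '\n';
}

std::function<void()> make_task(const std::vector<int32_t>& shape) {
  std::vector<void*> args;
  // Copy the shape so the pointer stays valid after the caller returns;
  // pointing into the caller's object would dangle once the call is deferred.
  std::vector<int32_t> owned_shape = shape;
  args.push_back(owned_shape.data());
  // Moving a std::vector transfers its heap buffer, so the pointer already
  // stored in args remains valid inside the closure.
  return [args = std::move(args),
          owned_shape = std::move(owned_shape)]() mutable {
    kernel(args.data());
  };
}

int main() {
  std::function<void()> task;
  {
    std::vector<int32_t> shape = {3, 4};
    task = make_task(shape);
  } // the caller's shape is gone; the task still owns its copy
  task(); // prints "3 x 4"
}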
File diff suppressed because it is too large
@@ -5,6 +5,7 @@
|
||||
#include "mlx/allocator.h"
|
||||
#include "mlx/backend/common/utils.h"
|
||||
#include "mlx/backend/cpu/copy.h"
|
||||
#include "mlx/backend/cpu/encoder.h"
|
||||
#include "mlx/backend/cpu/simd/simd.h"
|
||||
|
||||
namespace mlx::core {
|
||||
@@ -12,20 +13,29 @@ namespace mlx::core {
|
||||
namespace {
|
||||
|
||||
template <typename SrcT, typename DstT>
|
||||
void copy_single(const array& src, array& dst) {
|
||||
auto val = static_cast<DstT>(src.data<SrcT>()[0]);
|
||||
void copy_single(const array& src, array& dst, Stream stream) {
|
||||
auto src_ptr = src.data<SrcT>();
|
||||
auto dst_ptr = dst.data<DstT>();
|
||||
for (int i = 0; i < dst.size(); ++i) {
|
||||
dst_ptr[i] = val;
|
||||
}
|
||||
auto& encoder = cpu::get_command_encoder(stream);
|
||||
encoder.set_input_array(src);
|
||||
encoder.set_output_array(dst);
|
||||
encoder.dispatch([src_ptr, dst_ptr, size = dst.size()]() {
|
||||
auto val = static_cast<DstT>(src_ptr[0]);
|
||||
std::fill_n(dst_ptr, size, val);
|
||||
});
|
||||
}
|
||||
|
||||
template <typename SrcT, typename DstT>
|
||||
void copy_vector(const array& src, array& dst) {
|
||||
void copy_vector(const array& src, array& dst, Stream stream) {
|
||||
auto src_ptr = src.data<SrcT>();
|
||||
auto dst_ptr = dst.data<DstT>();
|
||||
size_t size = src.data_size();
|
||||
std::copy(src_ptr, src_ptr + src.data_size(), dst_ptr);
|
||||
auto& encoder = cpu::get_command_encoder(stream);
|
||||
encoder.set_input_array(src);
|
||||
encoder.set_output_array(dst);
|
||||
encoder.dispatch([src_ptr, dst_ptr, size = src.data_size()]() {
|
||||
std::copy(src_ptr, src_ptr + size, dst_ptr);
|
||||
});
|
||||
}
|
||||
|
||||
template <typename SrcT, typename DstT, int D>
|
||||
@@ -56,151 +66,220 @@ template <typename SrcT, typename DstT>
|
||||
void copy_general_general(
|
||||
const array& src,
|
||||
array& dst,
|
||||
Stream stream,
|
||||
const Shape& data_shape,
|
||||
const Strides& i_strides,
|
||||
const Strides& o_strides,
|
||||
int64_t i_offset,
|
||||
int64_t o_offset) {
|
||||
if (data_shape.empty()) {
|
||||
auto val = static_cast<DstT>(*(src.data<SrcT>() + i_offset));
|
||||
auto dst_ptr = dst.data<DstT>() + o_offset;
|
||||
*dst_ptr = val;
|
||||
return;
|
||||
}
|
||||
auto [shape, strides] =
|
||||
collapse_contiguous_dims(data_shape, {i_strides, o_strides});
|
||||
int64_t o_offset,
|
||||
const std::optional<array>& dynamic_i_offset,
|
||||
const std::optional<array>& dynamic_o_offset) {
|
||||
auto src_ptr = src.data<SrcT>() + i_offset;
|
||||
auto dst_ptr = dst.data<DstT>() + o_offset;
|
||||
int ndim = shape.size();
|
||||
if (ndim == 1) {
|
||||
copy_dims<SrcT, DstT, 1>(
|
||||
src_ptr, dst_ptr, shape, strides[0], strides[1], 0);
|
||||
return;
|
||||
} else if (ndim == 2) {
|
||||
copy_dims<SrcT, DstT, 2>(
|
||||
src_ptr, dst_ptr, shape, strides[0], strides[1], 0);
|
||||
return;
|
||||
} else if (ndim == 3) {
|
||||
copy_dims<SrcT, DstT, 3>(
|
||||
src_ptr, dst_ptr, shape, strides[0], strides[1], 0);
|
||||
return;
|
||||
}
|
||||
ContiguousIterator in(shape, strides[0], ndim - 3);
|
||||
ContiguousIterator out(shape, strides[1], ndim - 3);
|
||||
auto stride = std::accumulate(
|
||||
shape.end() - 3, shape.end(), 1, std::multiplies<int64_t>());
|
||||
for (int64_t elem = 0; elem < src.size(); elem += stride) {
|
||||
copy_dims<SrcT, DstT, 3>(
|
||||
src_ptr + in.loc,
|
||||
dst_ptr + out.loc,
|
||||
shape,
|
||||
strides[0],
|
||||
strides[1],
|
||||
ndim - 3);
|
||||
in.step();
|
||||
out.step();
|
||||
}
|
||||
auto i_offset_ptr =
|
||||
dynamic_i_offset ? dynamic_i_offset->data<int64_t>() : nullptr;
|
||||
auto o_offset_ptr =
|
||||
dynamic_o_offset ? dynamic_o_offset->data<int64_t>() : nullptr;
|
||||
|
||||
auto& encoder = cpu::get_command_encoder(stream);
|
||||
encoder.set_input_array(src);
|
||||
encoder.set_output_array(dst);
|
||||
encoder.dispatch([src_ptr,
|
||||
dst_ptr,
|
||||
size = src.size(),
|
||||
data_shape = data_shape,
|
||||
i_strides = i_strides,
|
||||
o_strides = o_strides,
|
||||
i_offset_ptr,
|
||||
o_offset_ptr]() mutable {
|
||||
if (data_shape.empty()) {
|
||||
auto val = static_cast<DstT>(*src_ptr);
|
||||
*dst_ptr = val;
|
||||
return;
|
||||
}
|
||||
auto [shape, strides] =
|
||||
collapse_contiguous_dims(data_shape, {i_strides, o_strides});
|
||||
|
||||
int ndim = shape.size();
|
||||
if (ndim < 3) {
|
||||
if (i_offset_ptr) {
|
||||
src_ptr += i_offset_ptr[0];
|
||||
}
|
||||
if (o_offset_ptr) {
|
||||
dst_ptr += o_offset_ptr[0];
|
||||
}
|
||||
|
||||
if (ndim == 1) {
|
||||
copy_dims<SrcT, DstT, 1>(
|
||||
src_ptr, dst_ptr, shape, strides[0], strides[1], 0);
|
||||
} else if (ndim == 2) {
|
||||
copy_dims<SrcT, DstT, 2>(
|
||||
src_ptr, dst_ptr, shape, strides[0], strides[1], 0);
|
||||
} else if (ndim == 3) {
|
||||
copy_dims<SrcT, DstT, 3>(
|
||||
src_ptr, dst_ptr, shape, strides[0], strides[1], 0);
|
||||
}
|
||||
return;
|
||||
}
|
||||
if (i_offset_ptr) {
|
||||
src_ptr += i_offset_ptr[0];
|
||||
}
|
||||
if (o_offset_ptr) {
|
||||
dst_ptr += o_offset_ptr[0];
|
||||
}
|
||||
|
||||
ContiguousIterator in(shape, strides[0], ndim - 3);
|
||||
ContiguousIterator out(shape, strides[1], ndim - 3);
|
||||
auto stride = std::accumulate(
|
||||
shape.end() - 3, shape.end(), 1, std::multiplies<int64_t>());
|
||||
for (int64_t elem = 0; elem < size; elem += stride) {
|
||||
copy_dims<SrcT, DstT, 3>(
|
||||
src_ptr + in.loc,
|
||||
dst_ptr + out.loc,
|
||||
shape,
|
||||
strides[0],
|
||||
strides[1],
|
||||
ndim - 3);
|
||||
in.step();
|
||||
out.step();
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
template <typename SrcT, typename DstT>
|
||||
inline void copy_general_general(const array& src, array& dst) {
|
||||
inline void copy_general_general(const array& src, array& dst, Stream stream) {
|
||||
copy_general_general<SrcT, DstT>(
|
||||
src, dst, src.shape(), src.strides(), dst.strides(), 0, 0);
|
||||
src,
|
||||
dst,
|
||||
stream,
|
||||
src.shape(),
|
||||
src.strides(),
|
||||
dst.strides(),
|
||||
0,
|
||||
0,
|
||||
std::nullopt,
|
||||
std::nullopt);
|
||||
}
|
||||
|
||||
template <typename SrcT, typename DstT>
|
||||
void copy_general(
|
||||
const array& src,
|
||||
array& dst,
|
||||
Stream stream,
|
||||
const Shape& data_shape,
|
||||
const Strides& i_strides,
|
||||
const Strides&,
|
||||
int64_t i_offset,
|
||||
int64_t o_offset) {
|
||||
int64_t o_offset,
|
||||
const std::optional<array>& dynamic_i_offset,
|
||||
const std::optional<array>& dynamic_o_offset) {
|
||||
copy_general_general<SrcT, DstT>(
|
||||
src,
|
||||
dst,
|
||||
stream,
|
||||
data_shape,
|
||||
i_strides,
|
||||
make_contiguous_strides(data_shape),
|
||||
i_offset,
|
||||
o_offset);
|
||||
o_offset,
|
||||
dynamic_i_offset,
|
||||
dynamic_o_offset);
|
||||
}
|
||||
|
||||
template <typename SrcT, typename DstT>
|
||||
inline void copy_general(const array& src, array& dst) {
|
||||
inline void copy_general(const array& src, array& dst, Stream stream) {
|
||||
copy_general_general<SrcT, DstT>(
|
||||
src,
|
||||
dst,
|
||||
stream,
|
||||
src.shape(),
|
||||
src.strides(),
|
||||
make_contiguous_strides(src.shape()),
|
||||
0,
|
||||
0);
|
||||
0,
|
||||
std::nullopt,
|
||||
std::nullopt);
|
||||
}
|
||||
|
||||
template <typename SrcT, typename DstT, typename... Args>
|
||||
void copy(const array& src, array& dst, CopyType ctype, Args&&... args) {
|
||||
void copy(
|
||||
const array& src,
|
||||
array& dst,
|
||||
CopyType ctype,
|
||||
Stream stream,
|
||||
Args&&... args) {
|
||||
switch (ctype) {
|
||||
case CopyType::Scalar:
|
||||
copy_single<SrcT, DstT>(src, dst);
|
||||
copy_single<SrcT, DstT>(src, dst, stream);
|
||||
return;
|
||||
case CopyType::Vector:
|
||||
copy_vector<SrcT, DstT>(src, dst);
|
||||
copy_vector<SrcT, DstT>(src, dst, stream);
|
||||
return;
|
||||
case CopyType::General:
|
||||
copy_general<SrcT, DstT>(src, dst, std::forward<Args>(args)...);
|
||||
copy_general<SrcT, DstT>(src, dst, stream, std::forward<Args>(args)...);
|
||||
return;
|
||||
case CopyType::GeneralGeneral:
|
||||
copy_general_general<SrcT, DstT>(src, dst, std::forward<Args>(args)...);
|
||||
copy_general_general<SrcT, DstT>(
|
||||
src, dst, stream, std::forward<Args>(args)...);
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
template <typename SrcT, typename... Args>
|
||||
void copy(const array& src, array& dst, CopyType ctype, Args&&... args) {
|
||||
void copy(
|
||||
const array& src,
|
||||
array& dst,
|
||||
CopyType ctype,
|
||||
Stream stream,
|
||||
Args&&... args) {
|
||||
switch (dst.dtype()) {
|
||||
case bool_:
|
||||
copy<SrcT, bool>(src, dst, ctype, std::forward<Args>(args)...);
|
||||
copy<SrcT, bool>(src, dst, ctype, stream, std::forward<Args>(args)...);
|
||||
break;
|
||||
case uint8:
|
||||
copy<SrcT, uint8_t>(src, dst, ctype, std::forward<Args>(args)...);
|
||||
copy<SrcT, uint8_t>(src, dst, ctype, stream, std::forward<Args>(args)...);
|
||||
break;
|
||||
case uint16:
|
||||
copy<SrcT, uint16_t>(src, dst, ctype, std::forward<Args>(args)...);
|
||||
copy<SrcT, uint16_t>(
|
||||
src, dst, ctype, stream, std::forward<Args>(args)...);
|
||||
break;
|
||||
case uint32:
|
||||
copy<SrcT, uint32_t>(src, dst, ctype, std::forward<Args>(args)...);
|
||||
copy<SrcT, uint32_t>(
|
||||
src, dst, ctype, stream, std::forward<Args>(args)...);
|
||||
break;
|
||||
case uint64:
|
||||
copy<SrcT, uint64_t>(src, dst, ctype, std::forward<Args>(args)...);
|
||||
copy<SrcT, uint64_t>(
|
||||
src, dst, ctype, stream, std::forward<Args>(args)...);
|
||||
break;
|
||||
case int8:
|
||||
copy<SrcT, int8_t>(src, dst, ctype, std::forward<Args>(args)...);
|
||||
copy<SrcT, int8_t>(src, dst, ctype, stream, std::forward<Args>(args)...);
|
||||
break;
|
||||
case int16:
|
||||
copy<SrcT, int16_t>(src, dst, ctype, std::forward<Args>(args)...);
|
||||
copy<SrcT, int16_t>(src, dst, ctype, stream, std::forward<Args>(args)...);
|
||||
break;
|
||||
case int32:
|
||||
copy<SrcT, int32_t>(src, dst, ctype, std::forward<Args>(args)...);
|
||||
copy<SrcT, int32_t>(src, dst, ctype, stream, std::forward<Args>(args)...);
|
||||
break;
|
||||
case int64:
|
||||
copy<SrcT, int64_t>(src, dst, ctype, std::forward<Args>(args)...);
|
||||
copy<SrcT, int64_t>(src, dst, ctype, stream, std::forward<Args>(args)...);
|
||||
break;
|
||||
case float16:
|
||||
copy<SrcT, float16_t>(src, dst, ctype, std::forward<Args>(args)...);
|
||||
copy<SrcT, float16_t>(
|
||||
src, dst, ctype, stream, std::forward<Args>(args)...);
|
||||
break;
|
||||
case float32:
|
||||
copy<SrcT, float>(src, dst, ctype, std::forward<Args>(args)...);
|
||||
copy<SrcT, float>(src, dst, ctype, stream, std::forward<Args>(args)...);
|
||||
break;
|
||||
case float64:
|
||||
copy<SrcT, double>(src, dst, ctype, std::forward<Args>(args)...);
|
||||
copy<SrcT, double>(src, dst, ctype, stream, std::forward<Args>(args)...);
|
||||
break;
|
||||
case bfloat16:
|
||||
copy<SrcT, bfloat16_t>(src, dst, ctype, std::forward<Args>(args)...);
|
||||
copy<SrcT, bfloat16_t>(
|
||||
src, dst, ctype, stream, std::forward<Args>(args)...);
|
||||
break;
|
||||
case complex64:
|
||||
copy<SrcT, complex64_t>(src, dst, ctype, std::forward<Args>(args)...);
|
||||
copy<SrcT, complex64_t>(
|
||||
src, dst, ctype, stream, std::forward<Args>(args)...);
|
||||
break;
|
||||
}
|
||||
}
|
||||
@@ -210,84 +289,71 @@ inline void copy_inplace_dispatch(
    const array& src,
    array& dst,
    CopyType ctype,
    Stream stream,
    Args&&... args) {
  switch (src.dtype()) {
    case bool_:
      copy<bool>(src, dst, ctype, std::forward<Args>(args)...);
      copy<bool>(src, dst, ctype, stream, std::forward<Args>(args)...);
      break;
    case uint8:
      copy<uint8_t>(src, dst, ctype, std::forward<Args>(args)...);
      copy<uint8_t>(src, dst, ctype, stream, std::forward<Args>(args)...);
      break;
    case uint16:
      copy<uint16_t>(src, dst, ctype, std::forward<Args>(args)...);
      copy<uint16_t>(src, dst, ctype, stream, std::forward<Args>(args)...);
      break;
    case uint32:
      copy<uint32_t>(src, dst, ctype, std::forward<Args>(args)...);
      copy<uint32_t>(src, dst, ctype, stream, std::forward<Args>(args)...);
      break;
    case uint64:
      copy<uint64_t>(src, dst, ctype, std::forward<Args>(args)...);
      copy<uint64_t>(src, dst, ctype, stream, std::forward<Args>(args)...);
      break;
    case int8:
      copy<int8_t>(src, dst, ctype, std::forward<Args>(args)...);
      copy<int8_t>(src, dst, ctype, stream, std::forward<Args>(args)...);
      break;
    case int16:
      copy<int16_t>(src, dst, ctype, std::forward<Args>(args)...);
      copy<int16_t>(src, dst, ctype, stream, std::forward<Args>(args)...);
      break;
    case int32:
      copy<int32_t>(src, dst, ctype, std::forward<Args>(args)...);
      copy<int32_t>(src, dst, ctype, stream, std::forward<Args>(args)...);
      break;
    case int64:
      copy<int64_t>(src, dst, ctype, std::forward<Args>(args)...);
      copy<int64_t>(src, dst, ctype, stream, std::forward<Args>(args)...);
      break;
    case float16:
      copy<float16_t>(src, dst, ctype, std::forward<Args>(args)...);
      copy<float16_t>(src, dst, ctype, stream, std::forward<Args>(args)...);
      break;
    case float32:
      copy<float>(src, dst, ctype, std::forward<Args>(args)...);
      copy<float>(src, dst, ctype, stream, std::forward<Args>(args)...);
      break;
    case float64:
      copy<double>(src, dst, ctype, std::forward<Args>(args)...);
      copy<double>(src, dst, ctype, stream, std::forward<Args>(args)...);
      break;
    case bfloat16:
      copy<bfloat16_t>(src, dst, ctype, std::forward<Args>(args)...);
      copy<bfloat16_t>(src, dst, ctype, stream, std::forward<Args>(args)...);
      break;
    case complex64:
      copy<complex64_t>(src, dst, ctype, std::forward<Args>(args)...);
      copy<complex64_t>(src, dst, ctype, stream, std::forward<Args>(args)...);
      break;
  }
}

} // namespace

void copy_inplace(const array& src, array& dst, CopyType ctype) {
  copy_inplace_dispatch(src, dst, ctype);
void copy_inplace(const array& src, array& dst, CopyType ctype, Stream stream) {
  copy_inplace_dispatch(src, dst, ctype, stream);
}

void copy(const array& src, array& dst, CopyType ctype) {
  // Allocate the output
  switch (ctype) {
    case CopyType::Vector:
      if (src.is_donatable() && src.itemsize() == dst.itemsize()) {
        dst.copy_shared_buffer(src);
      } else {
        auto size = src.data_size();
        dst.set_data(
            allocator::malloc_or_wait(size * dst.itemsize()),
            size,
            src.strides(),
            src.flags());
      }
      break;
    case CopyType::Scalar:
    case CopyType::General:
    case CopyType::GeneralGeneral:
      dst.set_data(allocator::malloc_or_wait(dst.nbytes()));
      break;
void copy(const array& src, array& dst, CopyType ctype, Stream stream) {
  bool donated = set_copy_output_data(src, dst, ctype);
  if (donated && src.dtype() == dst.dtype()) {
    // If the output has the same type as the input then there is nothing to
    // copy, just use the buffer.
    return;
  }
  if (ctype == CopyType::GeneralGeneral) {
    ctype = CopyType::General;
  }
  copy_inplace(src, dst, ctype);
  copy_inplace(src, dst, ctype, stream);
}

void copy_inplace(
@@ -298,7 +364,10 @@ void copy_inplace(
    const Strides& o_strides,
    int64_t i_offset,
    int64_t o_offset,
    CopyType ctype) {
    CopyType ctype,
    Stream stream,
    const std::optional<array>& dynamic_i_offset, /* = std::nullopt */
    const std::optional<array>& dynamic_o_offset /* = std::nullopt */) {
  switch (ctype) {
    case CopyType::General:
    case CopyType::GeneralGeneral:
@@ -306,15 +375,18 @@ void copy_inplace(
          src,
          dst,
          ctype,
          stream,
          data_shape,
          i_strides,
          o_strides,
          i_offset,
          o_offset);
          o_offset,
          dynamic_i_offset,
          dynamic_o_offset);
      break;
    case CopyType::Scalar:
    case CopyType::Vector:
      copy_inplace_dispatch(src, dst, ctype);
      copy_inplace_dispatch(src, dst, ctype, stream);
  }
}
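The reworked `copy` above relies on `set_copy_output_data`, a helper that is not shown in this diff. A minimal sketch of what such a helper is assumed to do, reconstructed from the allocation logic it replaces (the name and return convention are taken from the call site; this body is an illustration, not the actual implementation):

// Sketch only: reconstructs the allocation logic the old `copy` inlined.
// Returns true when the input buffer was donated to the output.
bool set_copy_output_data(const array& src, array& dst, CopyType ctype) {
  if (ctype == CopyType::Vector) {
    if (src.is_donatable() && src.itemsize() == dst.itemsize()) {
      dst.copy_shared_buffer(src);
      return true; // donated: the caller may be able to skip the copy
    }
    auto size = src.data_size();
    dst.set_data(
        allocator::malloc_or_wait(size * dst.itemsize()),
        size,
        src.strides(),
        src.flags());
  } else {
    // Scalar / General / GeneralGeneral: dense allocation.
    dst.set_data(allocator::malloc_or_wait(dst.nbytes()));
  }
  return false;
}

The donation fast path matters here because the copy itself now runs asynchronously: when the input buffer can simply be adopted, no task needs to be enqueued at all.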
@@ -2,14 +2,16 @@

#pragma once

#include <optional>

#include "mlx/array.h"
#include "mlx/backend/common/copy.h"
#include "mlx/backend/common/utils.h"

namespace mlx::core {

void copy(const array& src, array& dst, CopyType ctype);
void copy_inplace(const array& src, array& dst, CopyType ctype);
void copy(const array& src, array& dst, CopyType ctype, Stream stream);
void copy_inplace(const array& src, array& dst, CopyType ctype, Stream stream);

void copy_inplace(
    const array& src,
@@ -19,6 +21,9 @@ void copy_inplace(
    const Strides& o_strides,
    int64_t i_offset,
    int64_t o_offset,
    CopyType ctype);
    CopyType ctype,
    Stream stream,
    const std::optional<array>& dynamic_i_offset = std::nullopt,
    const std::optional<array>& dynamic_o_offset = std::nullopt);

} // namespace mlx::core
94 mlx/backend/cpu/distributed.cpp Normal file
@@ -0,0 +1,94 @@
// Copyright © 2024 Apple Inc.

#include <cassert>

#include "mlx/allocator.h"
#include "mlx/backend/cpu/copy.h"
#include "mlx/backend/cpu/encoder.h"
#include "mlx/distributed/primitives.h"

namespace mlx::core::distributed {

std::pair<array, bool> ensure_row_contiguous(const array& arr, Stream stream) {
  if (arr.flags().row_contiguous) {
    return {arr, false};
  } else {
    array arr_copy(arr.shape(), arr.dtype(), nullptr, {});
    copy(arr, arr_copy, CopyType::General, stream);
    return {arr_copy, true};
  }
};

void AllReduce::eval_cpu(
    const std::vector<array>& inputs,
    std::vector<array>& outputs) {
  assert(inputs.size() == 1);
  assert(outputs.size() == 1);

  auto donate_or_copy = [s = stream()](const array& in, array& out) {
    if (in.flags().row_contiguous) {
      if (in.is_donatable()) {
        out.copy_shared_buffer(in);
      } else {
        out.set_data(allocator::malloc_or_wait(out.nbytes()));
      }
      return in;
    } else {
      array arr_copy(in.shape(), in.dtype(), nullptr, {});
      copy(in, arr_copy, CopyType::General, s);
      out.copy_shared_buffer(arr_copy);
      return arr_copy;
    }
  };

  auto in = donate_or_copy(inputs[0], outputs[0]);
  switch (reduce_type_) {
    case Sum:
      distributed::detail::all_sum(group(), in, outputs[0], stream());
      break;
    default:
      throw std::runtime_error("Only all reduce sum is supported for now");
  }
}

void AllGather::eval_cpu(
    const std::vector<array>& inputs,
    std::vector<array>& outputs) {
  assert(inputs.size() == 1);
  assert(outputs.size() == 1);

  auto [in, copied] = ensure_row_contiguous(inputs[0], stream());
  outputs[0].set_data(allocator::malloc_or_wait(outputs[0].nbytes()));
  distributed::detail::all_gather(group(), in, outputs[0], stream());
  if (copied) {
    auto& enc = cpu::get_command_encoder(stream());
    enc.add_temporary(in);
  }
}

void Send::eval_cpu(
    const std::vector<array>& inputs,
    std::vector<array>& outputs) {
  assert(inputs.size() == 1);
  assert(outputs.size() == 1);

  auto [in, copied] = ensure_row_contiguous(inputs[0], stream());
  distributed::detail::send(group(), in, dst_, stream());
  outputs[0].copy_shared_buffer(inputs[0]);
  if (copied) {
    auto& enc = cpu::get_command_encoder(stream());
    enc.add_temporary(in);
  }
}

void Recv::eval_cpu(
    const std::vector<array>& inputs,
    std::vector<array>& outputs) {
  assert(inputs.size() == 0);
  assert(outputs.size() == 1);

  outputs[0].set_data(allocator::malloc_or_wait(outputs[0].nbytes()));
  distributed::detail::recv(group(), outputs[0], src_, stream());
}

} // namespace mlx::core::distributed
@@ -3,6 +3,7 @@
#include "mlx/allocator.h"
#include "mlx/array.h"
#include "mlx/backend/cpu/copy.h"
#include "mlx/backend/cpu/encoder.h"
#include "mlx/backend/cpu/lapack.h"
#include "mlx/linalg.h"
#include "mlx/primitives.h"
@@ -16,59 +17,72 @@ void eigh_impl(
    array& vectors,
    array& values,
    const std::string& uplo,
    bool compute_eigenvectors) {
    bool compute_eigenvectors,
    Stream stream) {
  auto vec_ptr = vectors.data<T>();
  auto eig_ptr = values.data<T>();

  char jobz = compute_eigenvectors ? 'V' : 'N';
  auto N = vectors.shape(-1);

  // Work query
  int lwork = -1;
  int liwork = -1;
  int info;
  {
    T work;
    int iwork;
    syevd<T>(
        &jobz,
        uplo.c_str(),
        &N,
        nullptr,
        &N,
        nullptr,
        &work,
        &lwork,
        &iwork,
        &liwork,
        &info);
    lwork = static_cast<int>(work);
    liwork = iwork;
  }

  auto work_buf = array::Data{allocator::malloc_or_wait(sizeof(T) * lwork)};
  auto iwork_buf = array::Data{allocator::malloc_or_wait(sizeof(int) * liwork)};
  for (size_t i = 0; i < vectors.size() / (N * N); ++i) {
    syevd<T>(
        &jobz,
        uplo.c_str(),
        &N,
        vec_ptr,
        &N,
        eig_ptr,
        static_cast<T*>(work_buf.buffer.raw_ptr()),
        &lwork,
        static_cast<int*>(iwork_buf.buffer.raw_ptr()),
        &liwork,
        &info);
    vec_ptr += N * N;
    eig_ptr += N;
    if (info != 0) {
      std::stringstream msg;
      msg << "[Eigh::eval_cpu] Eigenvalue decomposition failed with error code "
          << info;
      throw std::runtime_error(msg.str());
  auto& encoder = cpu::get_command_encoder(stream);
  encoder.set_output_array(vectors);
  encoder.set_output_array(values);
  encoder.dispatch([vec_ptr,
                    eig_ptr,
                    jobz,
                    uplo = uplo[0],
                    N = vectors.shape(-1),
                    size = vectors.size()]() mutable {
    // Work query
    int lwork = -1;
    int liwork = -1;
    int info;
    {
      T work;
      int iwork;
      syevd<T>(
          &jobz,
          &uplo,
          &N,
          nullptr,
          &N,
          nullptr,
          &work,
          &lwork,
          &iwork,
          &liwork,
          &info);
      lwork = static_cast<int>(work);
      liwork = iwork;
    }

    auto work_buf = array::Data{allocator::malloc_or_wait(sizeof(T) * lwork)};
    auto iwork_buf =
        array::Data{allocator::malloc_or_wait(sizeof(int) * liwork)};
    for (size_t i = 0; i < size / (N * N); ++i) {
      syevd<T>(
          &jobz,
          &uplo,
          &N,
          vec_ptr,
          &N,
          eig_ptr,
          static_cast<T*>(work_buf.buffer.raw_ptr()),
          &lwork,
          static_cast<int*>(iwork_buf.buffer.raw_ptr()),
          &liwork,
          &info);
      vec_ptr += N * N;
      eig_ptr += N;
      if (info != 0) {
        std::stringstream msg;
        msg << "[Eigh::eval_cpu] Eigenvalue decomposition failed with error code "
            << info;
        throw std::runtime_error(msg.str());
      }
    }
  });
  if (!compute_eigenvectors) {
    encoder.add_temporary(vectors);
  }
}

@@ -89,7 +103,8 @@ void Eigh::eval_cpu(
  copy(
      a,
      vectors,
      a.flags().row_contiguous ? CopyType::Vector : CopyType::General);
      a.flags().row_contiguous ? CopyType::Vector : CopyType::General,
      stream());

  if (compute_eigenvectors_) {
    // Set the strides and flags so the eigenvectors
@@ -107,14 +122,15 @@ void Eigh::eval_cpu(
        flags.col_contiguous = true;
      }
    }
    vectors.move_shared_buffer(vectors, strides, flags, vectors.data_size());
    vectors.copy_shared_buffer(vectors, strides, flags, vectors.data_size());
  }
  switch (a.dtype()) {
    case float32:
      eigh_impl<float>(vectors, values, uplo_, compute_eigenvectors_);
      eigh_impl<float>(vectors, values, uplo_, compute_eigenvectors_, stream());
      break;
    case float64:
      eigh_impl<double>(vectors, values, uplo_, compute_eigenvectors_);
      eigh_impl<double>(
          vectors, values, uplo_, compute_eigenvectors_, stream());
      break;
    default:
      throw std::runtime_error(
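The `syevd` calls above follow the standard LAPACK two-pass workspace protocol: a first call with `lwork = -1` only reports the required buffer sizes, and the real factorization runs once those buffers exist. A minimal standalone sketch of the same pattern for a single N x N problem, reusing the `syevd<T>` wrapper that this diff calls (the wrapper itself is declared elsewhere in the tree):

// Sketch: query-then-run LAPACK workspace protocol.
template <typename T>
void eigh_one(T* mat, T* eigenvalues, int N) {
  char jobz = 'V'; // also compute eigenvectors
  char uplo = 'L'; // data stored in the lower triangle
  int info;
  // Pass 1: lwork == -1 asks LAPACK for the required workspace sizes.
  int lwork = -1, liwork = -1;
  T work_query;
  int iwork_query;
  syevd<T>(&jobz, &uplo, &N, nullptr, &N, nullptr,
           &work_query, &lwork, &iwork_query, &liwork, &info);
  lwork = static_cast<int>(work_query);
  liwork = iwork_query;
  // Pass 2: run the decomposition with properly sized buffers.
  std::vector<T> work(lwork);
  std::vector<int> iwork(liwork);
  syevd<T>(&jobz, &uplo, &N, mat, &N, eigenvalues,
           work.data(), &lwork, iwork.data(), &liwork, &info);
  if (info != 0) {
    throw std::runtime_error("eigh failed");
  }
}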
16 mlx/backend/cpu/encoder.cpp Normal file
@@ -0,0 +1,16 @@
// Copyright © 2025 Apple Inc.

#include "mlx/backend/cpu/encoder.h"

namespace mlx::core::cpu {

CommandEncoder& get_command_encoder(Stream stream) {
  static std::unordered_map<int, CommandEncoder> encoder_map;
  auto it = encoder_map.find(stream.index);
  if (it == encoder_map.end()) {
    it = encoder_map.emplace(stream.index, stream).first;
  }
  return it->second;
}

} // namespace mlx::core::cpu
53 mlx/backend/cpu/encoder.h Normal file
@@ -0,0 +1,53 @@
// Copyright © 2025 Apple Inc.

#pragma once

#include <unordered_map>

#include "mlx/array.h"
#include "mlx/scheduler.h"

namespace mlx::core::cpu {

struct CommandEncoder {
  CommandEncoder(Stream stream) : stream_(stream) {}

  CommandEncoder(const CommandEncoder&) = delete;
  CommandEncoder& operator=(const CommandEncoder&) = delete;
  CommandEncoder(CommandEncoder&&) = delete;
  CommandEncoder& operator=(CommandEncoder&&) = delete;

  void set_input_array(const array& a) {}
  void set_output_array(array& a) {}

  // Hold onto a temporary until any already scheduled tasks which use it as
  // an input are complete.
  void add_temporary(array arr) {
    temporaries_.push_back(std::move(arr));
  }

  void add_temporaries(std::vector<array> arrays) {
    temporaries_.insert(
        temporaries_.end(),
        std::make_move_iterator(arrays.begin()),
        std::make_move_iterator(arrays.end()));
  }

  std::vector<array>& temporaries() {
    return temporaries_;
  }

  template <class F, class... Args>
  void dispatch(F&& f, Args&&... args) {
    auto task = std::bind(std::forward<F>(f), std::forward<Args>(args)...);
    scheduler::enqueue(stream_, std::move(task));
  }

 private:
  Stream stream_;
  std::vector<array> temporaries_;
};

CommandEncoder& get_command_encoder(Stream stream);

} // namespace mlx::core::cpu
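For orientation, this is roughly how the ops in this commit use the encoder: register inputs and outputs, capture raw pointers and plain values in the lambda (not `array` objects), and enqueue the body on the stream's worker thread. The element-wise `negate` below is a hypothetical example, not part of the diff, following the same pattern the real kernels use:

// Hypothetical example: an async element-wise op built on CommandEncoder.
void negate_impl(const array& in, array& out, Stream stream) {
  auto in_ptr = in.data<float>();
  auto out_ptr = out.data<float>();
  auto& encoder = cpu::get_command_encoder(stream);
  encoder.set_input_array(in);
  encoder.set_output_array(out);
  // The lambda runs later on the stream's thread, so it captures only
  // pointers and sizes; buffer lifetime is managed by the caller/eval.
  encoder.dispatch([in_ptr, out_ptr, size = in.size()]() {
    for (size_t i = 0; i < size; ++i) {
      out_ptr[i] = -in_ptr[i];
    }
  });
}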
44 mlx/backend/cpu/eval.cpp Normal file
@@ -0,0 +1,44 @@
// Copyright © 2025 Apple Inc.
#include "mlx/backend/cpu/eval.h"
#include "mlx/backend/cpu/encoder.h"
#include "mlx/primitives.h"
#include "mlx/scheduler.h"
#include "mlx/utils.h"

namespace mlx::core::cpu {

void eval(array& arr) {
  auto s = arr.primitive().stream();

  auto outputs = arr.outputs();
  {
    // If the array is a tracer hold a reference
    // to its inputs so they don't get donated
    std::vector<array> inputs;
    if (arr.is_tracer()) {
      inputs = arr.inputs();
    }
    arr.primitive().eval_cpu(arr.inputs(), outputs);
  }

  std::unordered_set<std::shared_ptr<array::Data>> buffers;
  for (auto& in : arr.inputs()) {
    buffers.insert(in.data_shared_ptr());
  }
  for (auto& s : arr.siblings()) {
    buffers.insert(s.data_shared_ptr());
  }
  // Remove the output if it was donated to by an input
  if (auto it = buffers.find(arr.data_shared_ptr()); it != buffers.end()) {
    buffers.erase(it);
  }
  auto& encoder = cpu::get_command_encoder(s);
  scheduler::notify_new_task(s);
  encoder.dispatch([s,
                    buffers = std::move(buffers),
                    temps = std::move(encoder.temporaries())]() {
    scheduler::notify_task_completion(s);
  });
}

} // namespace mlx::core::cpu
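The final `dispatch` above looks like a no-op, but the captures are the point: moving the input-buffer set and the encoder's temporaries into the closure keeps those shared-pointer-owned buffers alive until every previously enqueued task on the stream has run. A stripped-down illustration of the idiom (names and the queue are illustrative, not the scheduler's API):

// Sketch: capturing a shared_ptr in a queued closure pins the pointee
// until the closure is destroyed after it runs.
#include <functional>
#include <memory>
#include <queue>
#include <vector>

std::queue<std::function<void()>> task_queue;

void keep_alive_until_done(std::vector<std::shared_ptr<int>> buffers) {
  // The closure owns `buffers`; the data cannot be freed or reused until
  // the queue pops and drops this task, i.e. after all earlier tasks ran.
  task_queue.push([buffers = std::move(buffers)]() { /* completion hook */ });
}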
12 mlx/backend/cpu/eval.h Normal file
@@ -0,0 +1,12 @@
// Copyright © 2025 Apple Inc.

#pragma once

#include "mlx/array.h"
#include "mlx/stream.h"

namespace mlx::core::cpu {

void eval(array& arr);

} // namespace mlx::core::cpu
@@ -4,6 +4,7 @@

#include "mlx/3rdparty/pocketfft.h"
#include "mlx/allocator.h"
#include "mlx/backend/cpu/encoder.h"
#include "mlx/primitives.h"

namespace mlx::core {

@@ -38,46 +39,78 @@ void FFT::eval_cpu(const std::vector<array>& inputs, array& out) {
    });
    scale /= nelem;
  }

  auto& encoder = cpu::get_command_encoder(stream());
  encoder.set_input_array(in);
  encoder.set_output_array(out);

  if (in.dtype() == complex64 && out.dtype() == complex64) {
    auto in_ptr =
        reinterpret_cast<const std::complex<float>*>(in.data<complex64_t>());
    auto out_ptr =
        reinterpret_cast<std::complex<float>*>(out.data<complex64_t>());
    pocketfft::c2c(
        shape,
        strides_in,
        strides_out,
        axes_,
        !inverse_,
        in_ptr,
        out_ptr,
        scale);
    encoder.dispatch([shape = std::move(shape),
                      strides_in = std::move(strides_in),
                      strides_out = std::move(strides_out),
                      axes = axes_,
                      inverse = inverse_,
                      in_ptr,
                      out_ptr,
                      scale]() {
      pocketfft::c2c(
          shape,
          strides_in,
          strides_out,
          axes,
          !inverse,
          in_ptr,
          out_ptr,
          scale);
    });
  } else if (in.dtype() == float32 && out.dtype() == complex64) {
    auto in_ptr = in.data<float>();
    auto out_ptr =
        reinterpret_cast<std::complex<float>*>(out.data<complex64_t>());
    pocketfft::r2c(
        shape,
        strides_in,
        strides_out,
        axes_,
        !inverse_,
        in_ptr,
        out_ptr,
        scale);
    encoder.dispatch([shape = std::move(shape),
                      strides_in = std::move(strides_in),
                      strides_out = std::move(strides_out),
                      axes = axes_,
                      inverse = inverse_,
                      in_ptr,
                      out_ptr,
                      scale]() {
      pocketfft::r2c(
          shape,
          strides_in,
          strides_out,
          axes,
          !inverse,
          in_ptr,
          out_ptr,
          scale);
    });
  } else if (in.dtype() == complex64 && out.dtype() == float32) {
    auto in_ptr =
        reinterpret_cast<const std::complex<float>*>(in.data<complex64_t>());
    auto out_ptr = out.data<float>();
    pocketfft::c2r(
        shape,
        strides_in,
        strides_out,
        axes_,
        !inverse_,
        in_ptr,
        out_ptr,
        scale);
    encoder.dispatch([shape = std::move(shape),
                      strides_in = std::move(strides_in),
                      strides_out = std::move(strides_out),
                      axes = axes_,
                      inverse = inverse_,
                      in_ptr,
                      out_ptr,
                      scale]() {
      pocketfft::c2r(
          shape,
          strides_in,
          strides_out,
          axes,
          !inverse,
          in_ptr,
          out_ptr,
          scale);
    });
  } else {
    throw std::runtime_error(
        "[FFT] Received unexpected input and output type combination.");
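As a reference point for the calls above, a self-contained pocketfft invocation looks like the following. Note that pocketfft strides are expressed in bytes and the boolean selects forward versus inverse; this sketch assumes the header-only pocketfft API that MLX bundles under `mlx/3rdparty`:

// Sketch: forward 1-D complex-to-complex FFT over 8 points.
#include <complex>
#include <vector>
#include "mlx/3rdparty/pocketfft.h"

void fft_demo() {
  std::vector<std::complex<float>> in(8, {1.0f, 0.0f});
  std::vector<std::complex<float>> out(8);
  pocketfft::shape_t shape{8};
  // Strides are in *bytes*: adjacent elements are one complex<float> apart.
  pocketfft::stride_t stride{sizeof(std::complex<float>)};
  pocketfft::shape_t axes{0};
  pocketfft::c2c(
      shape, stride, stride, axes, /* forward */ true,
      in.data(), out.data(), /* scale */ 1.0f);
}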
@@ -7,14 +7,20 @@ namespace mlx::core {

template <typename T>
void matmul(
    const array& a,
    const array& b,
    array& out,
    const T* a,
    const T* b,
    T* out,
    bool a_transposed,
    bool b_transposed,
    size_t lda,
    size_t ldb,
    size_t ldc,
    float alpha,
    float beta);
    float beta,
    size_t batch_size,
    const Shape& a_shape,
    const Strides& a_strides,
    const Shape& b_shape,
    const Strides& b_strides);

} // namespace mlx::core
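The raw-pointer signature above pushes leading dimensions and batch bookkeeping onto the caller. For a row-major `M x K` by `K x N` product with contiguous operands, the conventional values are `lda = K`, `ldb = N`, `ldc = N` (each `ld*` names the row stride of its operand). A sketch of a direct CBLAS call with those values; the header name varies by BLAS vendor:

// Sketch: single row-major GEMM, C = 1.0 * A * B + 0.0 * C.
#include <cblas.h>

void gemm_demo(const float* A, const float* B, float* C, int M, int N, int K) {
  cblas_sgemm(
      CblasRowMajor,
      CblasNoTrans, CblasNoTrans,
      M, N, K,
      1.0f,
      A, /* lda = */ K,
      B, /* ldb = */ N,
      0.0f,
      C, /* ldc = */ N);
}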
@@ -9,39 +9,46 @@

namespace mlx::core {

BNNSDataType to_bnns_dtype(Dtype mlx_dtype) {
  uint32_t size_bits = size_of(mlx_dtype) * 8;
  switch (kindof(mlx_dtype)) {
    case Dtype::Kind::b:
      return BNNSDataTypeBoolean;
    case Dtype::Kind::u:
      return BNNSDataType(BNNSDataTypeUIntBit | size_bits);
    case Dtype::Kind::i:
      return BNNSDataType(BNNSDataTypeIntBit | size_bits);
    case Dtype::Kind::f:
      return BNNSDataType(BNNSDataTypeFloatBit | size_bits);
    case Dtype::Kind::V:
      return BNNSDataTypeBFloat16;
    case Dtype::Kind::c:
      throw std::invalid_argument("BNNS does not support complex types");
  }
template <typename T>
constexpr BNNSDataType to_bnns_dtype();

template <>
constexpr BNNSDataType to_bnns_dtype<float>() {
  return BNNSDataType(BNNSDataTypeFloatBit | 32);
}
template <>
constexpr BNNSDataType to_bnns_dtype<float16_t>() {
  return BNNSDataType(BNNSDataTypeFloatBit | 16);
}

template <>
constexpr BNNSDataType to_bnns_dtype<bfloat16_t>() {
  return BNNSDataTypeBFloat16;
}

template <typename T>
void matmul_bnns(
    const array& a,
    const array& b,
    array& out,
    const T* a,
    const T* b,
    T* out,
    bool a_transposed,
    bool b_transposed,
    size_t lda,
    size_t ldb,
    size_t ldc,
    float alpha,
    float beta) {
  size_t M = a.shape(-2);
  size_t N = b.shape(-1);
  size_t K = a.shape(-1);
    float beta,
    size_t batch_size,
    const Shape& a_shape,
    const Strides& a_strides,
    const Shape& b_shape,
    const Strides& b_strides) {
  auto ndim = a_shape.size();
  size_t M = a_shape[ndim - 2];
  size_t N = b_shape[ndim - 1];
  size_t K = a_shape[ndim - 1];

  BNNSDataType bnns_dtype = to_bnns_dtype(out.dtype());
  BNNSDataType bnns_dtype = to_bnns_dtype<T>();

#pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Wdeprecated-declarations"
@@ -115,14 +122,14 @@ void matmul_bnns(
  auto bnns_filter =
      BNNSFilterCreateLayerBroadcastMatMul(&gemm_params, nullptr);

  for (int i = 0; i < (a.size() / (M * K)); ++i) {
  for (int i = 0; i < batch_size; ++i) {
    BNNSFilterApplyTwoInput(
        bnns_filter,
        a.data<uint8_t>() +
            elem_to_loc(M * K * i, a.shape(), a.strides()) * a.itemsize(),
        b.data<uint8_t>() +
            elem_to_loc(K * N * i, b.shape(), b.strides()) * b.itemsize(),
        out.data<uint8_t>() + M * N * i * out.itemsize());
        reinterpret_cast<const uint8_t*>(
            a + elem_to_loc(M * K * i, a_shape, a_strides)),
        reinterpret_cast<const uint8_t*>(
            b + elem_to_loc(K * N * i, b_shape, b_strides)),
        reinterpret_cast<uint8_t*>(out + M * N * i));
  }

  BNNSFilterDestroy(bnns_filter);
@@ -131,30 +138,72 @@ void matmul_bnns(

template <>
void matmul<float16_t>(
    const array& a,
    const array& b,
    array& out,
    const float16_t* a,
    const float16_t* b,
    float16_t* out,
    bool a_transposed,
    bool b_transposed,
    size_t lda,
    size_t ldb,
    size_t ldc,
    float alpha,
    float beta) {
  matmul_bnns(a, b, out, a_transposed, b_transposed, lda, ldb, alpha, beta);
    float beta,
    size_t batch_size,
    const Shape& a_shape,
    const Strides& a_strides,
    const Shape& b_shape,
    const Strides& b_strides) {
  matmul_bnns(
      a,
      b,
      out,
      a_transposed,
      b_transposed,
      lda,
      ldb,
      ldc,
      alpha,
      beta,
      batch_size,
      a_shape,
      a_strides,
      b_shape,
      b_strides);
}

template <>
void matmul<bfloat16_t>(
    const array& a,
    const array& b,
    array& out,
    const bfloat16_t* a,
    const bfloat16_t* b,
    bfloat16_t* out,
    bool a_transposed,
    bool b_transposed,
    size_t lda,
    size_t ldb,
    size_t ldc,
    float alpha,
    float beta) {
  matmul_bnns(a, b, out, a_transposed, b_transposed, lda, ldb, alpha, beta);
    float beta,
    size_t batch_size,
    const Shape& a_shape,
    const Strides& a_strides,
    const Shape& b_shape,
    const Strides& b_strides) {
  matmul_bnns(
      a,
      b,
      out,
      a_transposed,
      b_transposed,
      lda,
      ldb,
      ldc,
      alpha,
      beta,
      batch_size,
      a_shape,
      a_strides,
      b_shape,
      b_strides);
}

} // namespace mlx::core
@@ -8,20 +8,27 @@ namespace mlx::core {

template <>
void matmul<float>(
    const array& a,
    const array& b,
    array& out,
    const float* a,
    const float* b,
    float* out,
    bool a_transposed,
    bool b_transposed,
    size_t lda,
    size_t ldb,
    size_t ldc,
    float alpha,
    float beta) {
  size_t M = a.shape(-2);
  size_t N = b.shape(-1);
  size_t K = a.shape(-1);
    float beta,
    size_t batch_size,
    const Shape& a_shape,
    const Strides& a_strides,
    const Shape& b_shape,
    const Strides& b_strides) {
  auto ndim = a_shape.size();
  size_t M = a_shape[ndim - 2];
  size_t N = b_shape[ndim - 1];
  size_t K = a_shape[ndim - 1];

  for (int i = 0; i < (a.size() / (M * K)); ++i) {
  for (int i = 0; i < batch_size; ++i) {
    cblas_sgemm(
        CblasRowMajor,
        a_transposed ? CblasTrans : CblasNoTrans, // transA
@@ -29,34 +36,40 @@ void matmul<float>(
        M,
        N,
        K,
        alpha, // alpha
        a.data<float>() + elem_to_loc(M * K * i, a.shape(), a.strides()),
        alpha,
        a + elem_to_loc(M * K * i, a_shape, a_strides),
        lda,
        b.data<float>() + elem_to_loc(K * N * i, b.shape(), b.strides()),
        b + elem_to_loc(K * N * i, b_shape, b_strides),
        ldb,
        beta, // beta
        out.data<float>() + M * N * i,
        out.shape(-1) // ldc
    );
        beta,
        out + M * N * i,
        ldc);
  }
}

template <>
void matmul<double>(
    const array& a,
    const array& b,
    array& out,
    const double* a,
    const double* b,
    double* out,
    bool a_transposed,
    bool b_transposed,
    size_t lda,
    size_t ldb,
    size_t ldc,
    float alpha,
    float beta) {
  size_t M = a.shape(-2);
  size_t N = b.shape(-1);
  size_t K = a.shape(-1);
    float beta,
    size_t batch_size,
    const Shape& a_shape,
    const Strides& a_strides,
    const Shape& b_shape,
    const Strides& b_strides) {
  auto ndim = a_shape.size();
  size_t M = a_shape[ndim - 2];
  size_t N = b_shape[ndim - 1];
  size_t K = a_shape[ndim - 1];

  for (int i = 0; i < (a.size() / (M * K)); ++i) {
  for (int i = 0; i < batch_size; ++i) {
    cblas_dgemm(
        CblasRowMajor,
        a_transposed ? CblasTrans : CblasNoTrans, // transA
@@ -64,15 +77,14 @@ void matmul<double>(
        M,
        N,
        K,
        alpha, // alpha
        a.data<double>() + elem_to_loc(M * K * i, a.shape(), a.strides()),
        alpha,
        a + elem_to_loc(M * K * i, a_shape, a_strides),
        lda,
        b.data<double>() + elem_to_loc(K * N * i, b.shape(), b.strides()),
        b + elem_to_loc(K * N * i, b_shape, b_strides),
        ldb,
        beta, // beta
        out.data<double>() + M * N * i,
        out.shape(-1) // ldc
    );
        beta,
        out + M * N * i,
        ldc);
  }
}
@@ -6,15 +6,21 @@ namespace mlx::core {

template <>
void matmul<bfloat16_t>(
    const array&,
    const array&,
    array&,
    const bfloat16_t*,
    const bfloat16_t*,
    bfloat16_t*,
    bool,
    bool,
    size_t,
    size_t,
    size_t,
    float,
    float) {
    float,
    size_t,
    const Shape&,
    const Strides&,
    const Shape&,
    const Strides&) {
  throw std::runtime_error("[Matmul::eval_cpu] bfloat16 not supported.");
}

@@ -6,15 +6,21 @@ namespace mlx::core {

template <>
void matmul<float16_t>(
    const array&,
    const array&,
    array&,
    const float16_t*,
    const float16_t*,
    float16_t*,
    bool,
    bool,
    size_t,
    size_t,
    size_t,
    float,
    float) {
    float,
    size_t,
    const Shape&,
    const Strides&,
    const Shape&,
    const Strides&) {
  throw std::runtime_error("[Matmul::eval_cpu] float16 not supported.");
}
@@ -4,16 +4,17 @@

#include "mlx/backend/common/hadamard.h"
#include "mlx/backend/cpu/copy.h"
#include "mlx/backend/cpu/encoder.h"
#include "mlx/primitives.h"

namespace mlx::core {

// n = 2^k component
template <typename T>
void hadamard_n(array& out, int n, int m, float scale) {
  for (int b = 0; b < out.size() / n; b++) {
void hadamard_n(T* out, int n, int m, float scale, size_t size) {
  for (int b = 0; b < size / n; b++) {
    size_t loc = b * n;
    T* data_ptr = out.data<T>() + loc;
    T* data_ptr = out + loc;
    int h = 1;
    int n_over_2 = n / 2;
    while (h < n) {
@@ -36,7 +37,7 @@ void hadamard_n(array& out, int n, int m, float scale) {

// m component
template <typename T>
void hadamard_m(array& out, int n, int m, float scale) {
void hadamard_m(T* out, int n, int m, float scale, size_t size) {
  auto h_matrices = hadamard_matrices();
  auto& matrix = h_matrices[m];
  auto start = 1;
@@ -51,9 +52,9 @@ void hadamard_m(array& out, int n, int m, float scale) {
    end = matrix.find('\n', start);
  }

  for (int b = 0; b < out.size() / m / n; b++) {
  for (int b = 0; b < size / m / n; b++) {
    size_t loc = b * n * m;
    T* data_ptr = out.data<T>() + loc;
    T* data_ptr = out + loc;
    for (int i = 0; i < n; i++) {
      std::vector<float> out(m);
      for (int j = 0; j < m; j++) {
@@ -74,12 +75,17 @@ void hadamard_m(array& out, int n, int m, float scale) {
}

template <typename T>
void hadamard(array& out, int n, int m, float scale) {
  float n_scale = m > 1 ? 1.0 : scale;
  hadamard_n<T>(out, n, m, n_scale);
  if (m > 1) {
    hadamard_m<T>(out, n, m, scale);
  }
void hadamard(array& out, int n, int m, float scale, Stream stream) {
  auto& encoder = cpu::get_command_encoder(stream);
  encoder.set_output_array(out);
  auto out_ptr = out.data<T>();
  encoder.dispatch([out_ptr, size = out.size(), n, m, scale]() {
    float n_scale = m > 1 ? 1.0 : scale;
    hadamard_n<T>(out_ptr, n, m, n_scale, size);
    if (m > 1) {
      hadamard_m<T>(out_ptr, n, m, scale, size);
    }
  });
}

void Hadamard::eval_cpu(const std::vector<array>& inputs, array& out) {
@@ -87,18 +93,26 @@ void Hadamard::eval_cpu(const std::vector<array>& inputs, array& out) {
  auto& in = inputs[0];

  // Copy input to output
  copy(in, out, CopyType::General);
  if (in.flags().row_contiguous && in.is_donatable()) {
    out.copy_shared_buffer(in);
  } else {
    copy(
        in,
        out,
        in.flags().row_contiguous ? CopyType::Vector : CopyType::General,
        stream());
  }

  int axis = out.ndim() - 1;
  auto [n, m] = decompose_hadamard(out.shape(axis));

  switch (in.dtype()) {
    case float32:
      return hadamard<float>(out, n, m, scale_);
      return hadamard<float>(out, n, m, scale_, stream());
    case float16:
      return hadamard<float16_t>(out, n, m, scale_);
      return hadamard<float16_t>(out, n, m, scale_, stream());
    case bfloat16:
      return hadamard<bfloat16_t>(out, n, m, scale_);
      return hadamard<bfloat16_t>(out, n, m, scale_, stream());
    default:
      throw std::invalid_argument("[hadamard] Unsupported type.");
  }
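The `hadamard_n` butterfly above is the standard in-place fast Walsh-Hadamard transform: log2(n) passes, each combining pairs (a, b) -> (a + b, a - b) at stride h. A self-contained, unscaled reference version for comparison (the kernel in this diff additionally applies `scale` and handles the non-power-of-two `m` factor separately):

// Sketch: in-place fast Walsh-Hadamard transform; n must be a power of two.
void fwht(float* data, int n) {
  for (int h = 1; h < n; h *= 2) {
    for (int i = 0; i < n; i += 2 * h) {
      for (int j = i; j < i + h; ++j) {
        float a = data[j];
        float b = data[j + h];
        data[j] = a + b;      // butterfly: sum
        data[j + h] = a - b;  // butterfly: difference
      }
    }
  }
}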
@@ -8,6 +8,7 @@

#include "mlx/backend/common/utils.h"
#include "mlx/backend/cpu/copy.h"
#include "mlx/backend/cpu/encoder.h"

namespace mlx::core {

@@ -27,7 +28,8 @@ void gather(
    const std::vector<array>& inds,
    array& out,
    const std::vector<int>& axes,
    const Shape& slice_sizes) {
    const Shape& slice_sizes,
    Stream stream) {
  // If the array is row contiguous then we can do a contiguous copy given
  // two conditions on the slice size:
  // - Any number of leading ones in the slice sizes are allowed
@@ -73,38 +75,60 @@ void gather(
  size_t ind_size = slice_size == 0 ? 0 : out.size() / slice_size;
  const T* src_ptr = src.data<T>();
  T* dst_ptr = out.data<T>();
  size_t out_idx = 0;

  std::vector<ContiguousIterator> its(inds.begin(), inds.end());
  ContiguousIterator src_it;
  if (!can_copy && src.ndim() > 0) {
    src_it = ContiguousIterator(slice_sizes, src.strides(), src.ndim());
  }
  for (int idx = 0; idx < ind_size; idx++) {
    size_t src_idx = 0;
    for (int ii = 0; ii < inds.size(); ++ii) {
      auto ax = axes[ii];
      auto idx_loc = its[ii].loc;
      its[ii].step();
      auto idx_val =
          offset_neg_idx(inds[ii].data<IdxT>()[idx_loc], src.shape(ax));
      src_idx += (idx_val * src.strides()[ax]);
    }

    if (slice_size == 1) {
      dst_ptr[out_idx++] = src_ptr[src_idx];
    } else if (can_copy) {
      std::copy(
          src_ptr + src_idx, src_ptr + src_idx + slice_size, dst_ptr + out_idx);
      out_idx += slice_size;
    } else {
      for (int jj = 0; jj < slice_size; jj++) {
        dst_ptr[out_idx++] = src_ptr[src_idx + src_it.loc];
        src_it.step();
      }
      src_it.reset();
    }
  std::vector<const IdxT*> ind_ptrs;
  auto& encoder = cpu::get_command_encoder(stream);
  encoder.set_input_array(src);
  for (auto& idx : inds) {
    ind_ptrs.push_back(idx.data<IdxT>());
    encoder.set_input_array(idx);
  }
  encoder.set_output_array(out);
  encoder.dispatch([src_ptr,
                    dst_ptr,
                    ind_ptrs = std::move(ind_ptrs),
                    axes,
                    ind_size,
                    slice_size,
                    src_shape = src.shape(),
                    src_strides = src.strides(),
                    src_it = std::move(src_it),
                    its = std::move(its),
                    can_copy]() mutable {
    size_t out_idx = 0;
    for (int idx = 0; idx < ind_size; idx++) {
      size_t src_idx = 0;
      for (int ii = 0; ii < ind_ptrs.size(); ++ii) {
        auto ax = axes[ii];
        auto idx_loc = its[ii].loc;
        its[ii].step();
        auto idx_val = offset_neg_idx(ind_ptrs[ii][idx_loc], src_shape[ax]);
        src_idx += (idx_val * src_strides[ax]);
      }

      if (slice_size == 1) {
        dst_ptr[out_idx++] = src_ptr[src_idx];
      } else if (can_copy) {
        std::copy(
            src_ptr + src_idx,
            src_ptr + src_idx + slice_size,
            dst_ptr + out_idx);
        out_idx += slice_size;
      } else {
        for (int jj = 0; jj < slice_size; jj++) {
          dst_ptr[out_idx++] = src_ptr[src_idx + src_it.loc];
          src_it.step();
        }
        src_it.reset();
      }
    }
  });
}
template <typename IdxT>
@@ -113,49 +137,50 @@ void dispatch_gather(
    const array& src,
    const std::vector<array>& inds,
    array& out,
    const std::vector<int>& axes,
    const Shape& size) {
    const Shape& size,
    Stream stream) {
  switch (out.dtype()) {
    case bool_:
      gather<bool, IdxT>(src, inds, out, axes, size);
      gather<bool, IdxT>(src, inds, out, axes, size, stream);
      break;
    case uint8:
      gather<uint8_t, IdxT>(src, inds, out, axes, size);
      gather<uint8_t, IdxT>(src, inds, out, axes, size, stream);
      break;
    case uint16:
      gather<uint16_t, IdxT>(src, inds, out, axes, size);
      gather<uint16_t, IdxT>(src, inds, out, axes, size, stream);
      break;
    case uint32:
      gather<uint32_t, IdxT>(src, inds, out, axes, size);
      gather<uint32_t, IdxT>(src, inds, out, axes, size, stream);
      break;
    case uint64:
      gather<uint64_t, IdxT>(src, inds, out, axes, size);
      gather<uint64_t, IdxT>(src, inds, out, axes, size, stream);
      break;
    case int8:
      gather<int8_t, IdxT>(src, inds, out, axes, size);
      gather<int8_t, IdxT>(src, inds, out, axes, size, stream);
      break;
    case int16:
      gather<int16_t, IdxT>(src, inds, out, axes, size);
      gather<int16_t, IdxT>(src, inds, out, axes, size, stream);
      break;
    case int32:
      gather<int32_t, IdxT>(src, inds, out, axes, size);
      gather<int32_t, IdxT>(src, inds, out, axes, size, stream);
      break;
    case int64:
      gather<int64_t, IdxT>(src, inds, out, axes, size);
      gather<int64_t, IdxT>(src, inds, out, axes, size, stream);
      break;
    case float16:
      gather<float16_t, IdxT>(src, inds, out, axes, size);
      gather<float16_t, IdxT>(src, inds, out, axes, size, stream);
      break;
    case float32:
      gather<float, IdxT>(src, inds, out, axes, size);
      gather<float, IdxT>(src, inds, out, axes, size, stream);
      break;
    case float64:
      gather<double, IdxT>(src, inds, out, axes, size);
      gather<double, IdxT>(src, inds, out, axes, size, stream);
      break;
    case bfloat16:
      gather<bfloat16_t, IdxT>(src, inds, out, axes, size);
      gather<bfloat16_t, IdxT>(src, inds, out, axes, size, stream);
      break;
    case complex64:
      gather<complex64_t, IdxT>(src, inds, out, axes, size);
      gather<complex64_t, IdxT>(src, inds, out, axes, size, stream);
      break;
  }
}
@@ -167,34 +192,34 @@ void Gather::eval_cpu(const std::vector<array>& inputs, array& out) {
  std::vector<array> inds(inputs.begin() + 1, inputs.end());

  if (inds.empty()) {
    dispatch_gather<uint8_t>(src, inds, out, axes_, slice_sizes_);
    dispatch_gather<uint8_t>(src, inds, out, axes_, slice_sizes_, stream());
    return;
  }

  switch (inds[0].dtype()) {
    case uint8:
      dispatch_gather<uint8_t>(src, inds, out, axes_, slice_sizes_);
      dispatch_gather<uint8_t>(src, inds, out, axes_, slice_sizes_, stream());
      break;
    case uint16:
      dispatch_gather<uint16_t>(src, inds, out, axes_, slice_sizes_);
      dispatch_gather<uint16_t>(src, inds, out, axes_, slice_sizes_, stream());
      break;
    case uint32:
      dispatch_gather<uint32_t>(src, inds, out, axes_, slice_sizes_);
      dispatch_gather<uint32_t>(src, inds, out, axes_, slice_sizes_, stream());
      break;
    case uint64:
      dispatch_gather<uint64_t>(src, inds, out, axes_, slice_sizes_);
      dispatch_gather<uint64_t>(src, inds, out, axes_, slice_sizes_, stream());
      break;
    case int8:
      dispatch_gather<int8_t>(src, inds, out, axes_, slice_sizes_);
      dispatch_gather<int8_t>(src, inds, out, axes_, slice_sizes_, stream())
      break;
    case int16:
      dispatch_gather<int16_t>(src, inds, out, axes_, slice_sizes_);
      dispatch_gather<int16_t>(src, inds, out, axes_, slice_sizes_, stream());
      break;
    case int32:
      dispatch_gather<int32_t>(src, inds, out, axes_, slice_sizes_);
      dispatch_gather<int32_t>(src, inds, out, axes_, slice_sizes_, stream());
      break;
    case int64:
      dispatch_gather<int64_t>(src, inds, out, axes_, slice_sizes_);
      dispatch_gather<int64_t>(src, inds, out, axes_, slice_sizes_, stream());
      break;
    default:
      throw std::runtime_error(
@@ -207,7 +232,8 @@ void gather_axis(
    const array& src,
    const array& ind,
    array& out,
    const int axis) {
    const int axis,
    Stream stream) {
  auto strides = ind.strides();
  strides.erase(strides.begin() + axis);
  auto shape = ind.shape();
@@ -235,20 +261,39 @@ void gather_axis(
  for (int i = axis + 1; i < ind.ndim(); ++i) {
    size_post *= ind.shape(i);
  }
  size_t stride_pre = size_post * ind_ax_size;
  for (size_t i = 0; i < size_pre; i++) {
    for (size_t k = 0; k < size_post; k++) {
      for (int j = 0; j < ind_ax_size; ++j) {
        auto ind_val = offset_neg_idx(
            ind_ptr[ind_it.loc + j * ind_ax_stride], src_ax_size);
        dst_ptr[k + j * dst_ax_stride] =
            src_ptr[src_it.loc + ind_val * src_ax_stride];
      }
      ind_it.step();
      src_it.step();
    }
    dst_ptr += stride_pre;
  }

  auto& encoder = cpu::get_command_encoder(stream);
  encoder.set_input_array(src);
  encoder.set_input_array(ind);
  encoder.set_output_array(out);

  encoder.dispatch([ind_ptr,
                    src_ptr,
                    dst_ptr,
                    size_pre,
                    size_post,
                    ind_ax_size,
                    src_ax_size,
                    ind_ax_stride,
                    src_ax_stride,
                    dst_ax_stride,
                    ind_it = std::move(ind_it),
                    src_it = std::move(src_it)]() mutable {
    size_t stride_pre = size_post * ind_ax_size;
    for (size_t i = 0; i < size_pre; i++) {
      for (size_t k = 0; k < size_post; k++) {
        for (int j = 0; j < ind_ax_size; ++j) {
          auto ind_val = offset_neg_idx(
              ind_ptr[ind_it.loc + j * ind_ax_stride], src_ax_size);
          dst_ptr[k + j * dst_ax_stride] =
              src_ptr[src_it.loc + ind_val * src_ax_stride];
        }
        ind_it.step();
        src_it.step();
      }
      dst_ptr += stride_pre;
    }
  });
}
template <typename IdxT>
@@ -256,49 +301,50 @@ void dispatch_gather_axis(
    const array& src,
    const array& inds,
    array& out,
    const int axis) {
    const int axis,
    Stream stream) {
  switch (out.dtype()) {
    case bool_:
      gather_axis<bool, IdxT>(src, inds, out, axis);
      gather_axis<bool, IdxT>(src, inds, out, axis, stream);
      break;
    case uint8:
      gather_axis<uint8_t, IdxT>(src, inds, out, axis);
      gather_axis<uint8_t, IdxT>(src, inds, out, axis, stream);
      break;
    case uint16:
      gather_axis<uint16_t, IdxT>(src, inds, out, axis);
      gather_axis<uint16_t, IdxT>(src, inds, out, axis, stream);
      break;
    case uint32:
      gather_axis<uint32_t, IdxT>(src, inds, out, axis);
      gather_axis<uint32_t, IdxT>(src, inds, out, axis, stream);
      break;
    case uint64:
      gather_axis<uint64_t, IdxT>(src, inds, out, axis);
      gather_axis<uint64_t, IdxT>(src, inds, out, axis, stream);
      break;
    case int8:
      gather_axis<int8_t, IdxT>(src, inds, out, axis);
      gather_axis<int8_t, IdxT>(src, inds, out, axis, stream);
      break;
    case int16:
      gather_axis<int16_t, IdxT>(src, inds, out, axis);
      gather_axis<int16_t, IdxT>(src, inds, out, axis, stream);
      break;
    case int32:
      gather_axis<int32_t, IdxT>(src, inds, out, axis);
      gather_axis<int32_t, IdxT>(src, inds, out, axis, stream);
      break;
    case int64:
      gather_axis<int64_t, IdxT>(src, inds, out, axis);
      gather_axis<int64_t, IdxT>(src, inds, out, axis, stream);
      break;
    case float16:
      gather_axis<float16_t, IdxT>(src, inds, out, axis);
      gather_axis<float16_t, IdxT>(src, inds, out, axis, stream);
      break;
    case float32:
      gather_axis<float, IdxT>(src, inds, out, axis);
      gather_axis<float, IdxT>(src, inds, out, axis, stream);
      break;
    case float64:
      gather_axis<double, IdxT>(src, inds, out, axis);
      gather_axis<double, IdxT>(src, inds, out, axis, stream);
      break;
    case bfloat16:
      gather_axis<bfloat16_t, IdxT>(src, inds, out, axis);
      gather_axis<bfloat16_t, IdxT>(src, inds, out, axis, stream);
      break;
    case complex64:
      gather_axis<complex64_t, IdxT>(src, inds, out, axis);
      gather_axis<complex64_t, IdxT>(src, inds, out, axis, stream);
      break;
  }
}
@@ -309,28 +355,28 @@ void GatherAxis::eval_cpu(const std::vector<array>& inputs, array& out) {
  auto& inds = inputs[1];
  switch (inds.dtype()) {
    case uint8:
      dispatch_gather_axis<uint8_t>(src, inds, out, axis_);
      dispatch_gather_axis<uint8_t>(src, inds, out, axis_, stream());
      break;
    case uint16:
      dispatch_gather_axis<uint16_t>(src, inds, out, axis_);
      dispatch_gather_axis<uint16_t>(src, inds, out, axis_, stream());
      break;
    case uint32:
      dispatch_gather_axis<uint32_t>(src, inds, out, axis_);
      dispatch_gather_axis<uint32_t>(src, inds, out, axis_, stream());
      break;
    case uint64:
      dispatch_gather_axis<uint64_t>(src, inds, out, axis_);
      dispatch_gather_axis<uint64_t>(src, inds, out, axis_, stream());
      break;
    case int8:
      dispatch_gather_axis<int8_t>(src, inds, out, axis_);
      dispatch_gather_axis<int8_t>(src, inds, out, axis_, stream());
      break;
    case int16:
      dispatch_gather_axis<int16_t>(src, inds, out, axis_);
      dispatch_gather_axis<int16_t>(src, inds, out, axis_, stream());
      break;
    case int32:
      dispatch_gather_axis<int32_t>(src, inds, out, axis_);
      dispatch_gather_axis<int32_t>(src, inds, out, axis_, stream());
      break;
    case int64:
      dispatch_gather_axis<int64_t>(src, inds, out, axis_);
      dispatch_gather_axis<int64_t>(src, inds, out, axis_, stream());
      break;
    default:
      throw std::runtime_error(
@@ -345,7 +391,8 @@ void scatter(
    array& out,
    const std::vector<array>& inds,
    const std::vector<int>& axes,
    const OpT& op) {
    const OpT& op,
    Stream stream) {
  int nind = inds.size();
  auto inds_ndim = updates.ndim() - out.ndim();
  size_t n_updates = nind ? inds[0].size() : 1;
@@ -361,26 +408,45 @@ void scatter(
  ContiguousIterator update_it(updates);
  ContiguousIterator out_it(update_shape, out.strides(), out.ndim());

  for (int i = 0; i < n_updates; ++i) {
    size_t out_offset = 0;
    for (int j = 0; j < nind; ++j) {
      auto ax = axes[j];
      auto idx_loc = its[j].loc;
      its[j].step();
      auto idx_val =
          offset_neg_idx(inds[j].data<IdxT>()[idx_loc], out.shape(ax));
      out_offset += (idx_val * out.strides()[ax]);
    }
    update_it.seek(i * update_size);
    for (int j = 0; j < update_size; ++j) {
      op(updates.data<InT>()[update_it.loc],
         out.data<InT>() + out_offset + out_it.loc);
      update_it.step();
      out_it.step();
    }
    out_it.reset();
    update_it.reset();
  std::vector<const IdxT*> ind_ptrs;
  auto& encoder = cpu::get_command_encoder(stream);
  encoder.set_input_array(updates);
  for (auto& idx : inds) {
    ind_ptrs.push_back(idx.data<IdxT>());
    encoder.set_input_array(idx);
  }
  encoder.set_output_array(out);
  encoder.dispatch([out_ptr = out.data<InT>(),
                    upd_ptr = updates.data<InT>(),
                    ind_ptrs = std::move(ind_ptrs),
                    axes,
                    n_updates,
                    update_size,
                    op = std::move(op),
                    out_shape = out.shape(),
                    out_strides = out.strides(),
                    out_it = std::move(out_it),
                    update_it = std::move(update_it),
                    its = std::move(its)]() mutable {
    for (int i = 0; i < n_updates; ++i) {
      size_t out_offset = 0;
      for (int j = 0; j < ind_ptrs.size(); ++j) {
        auto ax = axes[j];
        auto idx_loc = its[j].loc;
        its[j].step();
        auto idx_val = offset_neg_idx(ind_ptrs[j][idx_loc], out_shape[ax]);
        out_offset += (idx_val * out_strides[ax]);
      }
      update_it.seek(i * update_size);
      for (int j = 0; j < update_size; ++j) {
        op(upd_ptr[update_it.loc], out_ptr + out_offset + out_it.loc);
        update_it.step();
        out_it.step();
      }
      out_it.reset();
      update_it.reset();
    }
  });
}
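The `op` parameter threaded through `scatter` above is a tiny callable with the shape `op(update_value, pointer_into_output)`, which is how a single kernel serves assign, sum, prod, max, and min. A minimal standalone illustration of the same dispatch shape:

// Sketch: one accumulation kernel parameterized by a reduce op.
#include <cstdio>

template <typename T, typename Op>
void accumulate(const T* updates, T* out, int n, Op op) {
  for (int i = 0; i < n; ++i) {
    op(updates[i], &out[i]); // same call shape as the scatter kernel
  }
}

void demo() {
  float upd[3] = {1, 2, 3};
  float out[3] = {10, 20, 30};
  // Scatter::Sum flavor: accumulate into the destination.
  accumulate(upd, out, 3, [](auto x, auto* y) { (*y) += x; });
  std::printf("%g %g %g\n", out[0], out[1], out[2]); // 11 22 33
}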
template <typename InT, typename IdxT>
@@ -389,29 +455,53 @@ void dispatch_scatter_inds(
    const std::vector<array>& indices,
    const array& updates,
    const std::vector<int>& axes,
    Scatter::ReduceType rtype) {
    Scatter::ReduceType rtype,
    Stream stream) {
  switch (rtype) {
    case Scatter::None:
      scatter<InT, IdxT>(
          updates, out, indices, axes, [](auto x, auto* y) { (*y) = x; });
          updates,
          out,
          indices,
          axes,
          [](auto x, auto* y) { (*y) = x; },
          stream);
      break;
    case Scatter::Sum:
      scatter<InT, IdxT>(
          updates, out, indices, axes, [](auto x, auto* y) { (*y) += x; });
          updates,
          out,
          indices,
          axes,
          [](auto x, auto* y) { (*y) += x; },
          stream);
      break;
    case Scatter::Prod:
      scatter<InT, IdxT>(
          updates, out, indices, axes, [](auto x, auto* y) { (*y) *= x; });
          updates,
          out,
          indices,
          axes,
          [](auto x, auto* y) { (*y) *= x; },
          stream);
      break;
    case Scatter::Max:
      scatter<InT, IdxT>(updates, out, indices, axes, [](auto x, auto* y) {
        (*y) = (*y > x) ? *y : x;
      });
      scatter<InT, IdxT>(
          updates,
          out,
          indices,
          axes,
          [](auto x, auto* y) { (*y) = (*y > x) ? *y : x; },
          stream);
      break;
    case Scatter::Min:
      scatter<InT, IdxT>(updates, out, indices, axes, [](auto x, auto* y) {
        (*y) = (*y < x) ? *y : x;
      });
      scatter<InT, IdxT>(
          updates,
          out,
          indices,
          axes,
          [](auto x, auto* y) { (*y) = (*y < x) ? *y : x; },
          stream);
      break;
  }
}
@@ -422,36 +512,46 @@ void dispatch_scatter(
    const std::vector<array>& inds,
    const array& updates,
    const std::vector<int>& axes,
    Scatter::ReduceType rtype) {
    Scatter::ReduceType rtype,
    Stream stream) {
  if (inds.empty()) {
    dispatch_scatter_inds<InT, uint8_t>(out, inds, updates, axes, rtype);
    dispatch_scatter_inds<InT, uint8_t>(
        out, inds, updates, axes, rtype, stream);
    return;
  }

  switch (inds[0].dtype()) {
    case uint8:
      dispatch_scatter_inds<InT, uint8_t>(out, inds, updates, axes, rtype);
      dispatch_scatter_inds<InT, uint8_t>(
          out, inds, updates, axes, rtype, stream);
      break;
    case uint16:
      dispatch_scatter_inds<InT, uint16_t>(out, inds, updates, axes, rtype);
      dispatch_scatter_inds<InT, uint16_t>(
          out, inds, updates, axes, rtype, stream);
      break;
    case uint32:
      dispatch_scatter_inds<InT, uint32_t>(out, inds, updates, axes, rtype);
      dispatch_scatter_inds<InT, uint32_t>(
          out, inds, updates, axes, rtype, stream);
      break;
    case uint64:
      dispatch_scatter_inds<InT, uint64_t>(out, inds, updates, axes, rtype);
      dispatch_scatter_inds<InT, uint64_t>(
          out, inds, updates, axes, rtype, stream);
      break;
    case int8:
      dispatch_scatter_inds<InT, int8_t>(out, inds, updates, axes, rtype);
      dispatch_scatter_inds<InT, int8_t>(
          out, inds, updates, axes, rtype, stream);
      break;
    case int16:
      dispatch_scatter_inds<InT, int16_t>(out, inds, updates, axes, rtype);
      dispatch_scatter_inds<InT, int16_t>(
          out, inds, updates, axes, rtype, stream);
      break;
    case int32:
      dispatch_scatter_inds<InT, int32_t>(out, inds, updates, axes, rtype);
      dispatch_scatter_inds<InT, int32_t>(
          out, inds, updates, axes, rtype, stream);
      break;
    case int64:
      dispatch_scatter_inds<InT, int64_t>(out, inds, updates, axes, rtype);
      dispatch_scatter_inds<InT, int64_t>(
          out, inds, updates, axes, rtype, stream);
      break;
    default:
      throw std::runtime_error(
@@ -469,50 +569,63 @@ void Scatter::eval_cpu(const std::vector<array>& inputs, array& out) {
  // Copy src into out (copy allocates memory for out)
  auto ctype =
      src.flags().row_contiguous ? CopyType::Vector : CopyType::General;
  copy(src, out, ctype);
  copy(src, out, ctype, stream());

  switch (src.dtype()) {
    case bool_:
      dispatch_scatter<bool>(out, inds, updates, axes_, reduce_type_);
      dispatch_scatter<bool>(out, inds, updates, axes_, reduce_type_, stream());
      break;
    case uint8:
      dispatch_scatter<uint8_t>(out, inds, updates, axes_, reduce_type_);
      dispatch_scatter<uint8_t>(
          out, inds, updates, axes_, reduce_type_, stream());
      break;
    case uint16:
      dispatch_scatter<uint16_t>(out, inds, updates, axes_, reduce_type_);
      dispatch_scatter<uint16_t>(
          out, inds, updates, axes_, reduce_type_, stream());
      break;
    case uint32:
      dispatch_scatter<uint32_t>(out, inds, updates, axes_, reduce_type_);
      dispatch_scatter<uint32_t>(
          out, inds, updates, axes_, reduce_type_, stream());
      break;
    case uint64:
      dispatch_scatter<uint64_t>(out, inds, updates, axes_, reduce_type_);
      dispatch_scatter<uint64_t>(
          out, inds, updates, axes_, reduce_type_, stream());
      break;
    case int8:
      dispatch_scatter<int8_t>(out, inds, updates, axes_, reduce_type_);
      dispatch_scatter<int8_t>(
          out, inds, updates, axes_, reduce_type_, stream());
      break;
    case int16:
      dispatch_scatter<int16_t>(out, inds, updates, axes_, reduce_type_);
      dispatch_scatter<int16_t>(
          out, inds, updates, axes_, reduce_type_, stream());
      break;
    case int32:
      dispatch_scatter<int32_t>(out, inds, updates, axes_, reduce_type_);
      dispatch_scatter<int32_t>(
          out, inds, updates, axes_, reduce_type_, stream());
      break;
    case int64:
      dispatch_scatter<int64_t>(out, inds, updates, axes_, reduce_type_);
      dispatch_scatter<int64_t>(
          out, inds, updates, axes_, reduce_type_, stream());
      break;
    case float16:
      dispatch_scatter<float16_t>(out, inds, updates, axes_, reduce_type_);
      dispatch_scatter<float16_t>(
          out, inds, updates, axes_, reduce_type_, stream());
      break;
    case float32:
      dispatch_scatter<float>(out, inds, updates, axes_, reduce_type_);
      dispatch_scatter<float>(
          out, inds, updates, axes_, reduce_type_, stream());
      break;
    case float64:
      dispatch_scatter<double>(out, inds, updates, axes_, reduce_type_);
      dispatch_scatter<double>(
          out, inds, updates, axes_, reduce_type_, stream());
      break;
    case bfloat16:
      dispatch_scatter<bfloat16_t>(out, inds, updates, axes_, reduce_type_);
      dispatch_scatter<bfloat16_t>(
          out, inds, updates, axes_, reduce_type_, stream());
      break;
    case complex64:
      dispatch_scatter<complex64_t>(out, inds, updates, axes_, reduce_type_);
      dispatch_scatter<complex64_t>(
          out, inds, updates, axes_, reduce_type_, stream());
      break;
  }
}
@@ -523,7 +636,8 @@ void scatter_axis(
const array idx,
const array& upd,
int axis,
const OpT& op) {
const OpT& op,
Stream stream) {
auto strides = idx.strides();
strides.erase(strides.begin() + axis);
auto shape = idx.shape();
@@ -543,6 +657,11 @@ void scatter_axis(
auto idx_ax_size = idx.shape(axis);
auto dst_ax_size = out.shape(axis);

auto& encoder = cpu::get_command_encoder(stream);
encoder.set_input_array(idx);
encoder.set_input_array(upd);
encoder.set_output_array(out);

size_t size_pre = 1;
size_t size_post = 1;
for (int i = 0; i < axis; ++i) {
@@ -551,20 +670,34 @@ void scatter_axis(
for (int i = axis + 1; i < idx.ndim(); ++i) {
size_post *= idx.shape(i);
}
size_t stride_pre = size_post * dst_ax_size;
for (size_t i = 0; i < size_pre; i++) {
for (size_t k = 0; k < size_post; k++) {
for (int j = 0; j < idx_ax_size; ++j) {
auto ind_val = offset_neg_idx(
idx_ptr[idx_it.loc + j * idx_ax_stride], dst_ax_size);
op(upd_ptr[upd_it.loc + j * upd_ax_stride],
dst_ptr + k + ind_val * dst_ax_stride);
encoder.dispatch([idx_ptr,
upd_ptr,
dst_ptr,
size_pre,
size_post,
idx_ax_size,
dst_ax_size,
idx_ax_stride,
upd_ax_stride,
dst_ax_stride,
idx_it = std::move(idx_it),
upd_it = std::move(upd_it),
op = std::move(op)]() mutable {
size_t stride_pre = size_post * dst_ax_size;
for (size_t i = 0; i < size_pre; i++) {
for (size_t k = 0; k < size_post; k++) {
for (int j = 0; j < idx_ax_size; ++j) {
auto ind_val = offset_neg_idx(
idx_ptr[idx_it.loc + j * idx_ax_stride], dst_ax_size);
op(upd_ptr[upd_it.loc + j * upd_ax_stride],
dst_ptr + k + ind_val * dst_ax_stride);
}
idx_it.step();
upd_it.step();
}
idx_it.step();
upd_it.step();
dst_ptr += stride_pre;
}
dst_ptr += stride_pre;
}
});
}

template <typename InT, typename IdxT>
@@ -573,15 +706,16 @@ void dispatch_scatter_axis_op(
const array& idx,
const array& updates,
int axis,
ScatterAxis::ReduceType rtype) {
ScatterAxis::ReduceType rtype,
Stream stream) {
switch (rtype) {
case ScatterAxis::None:
scatter_axis<InT, IdxT>(
out, idx, updates, axis, [](auto x, auto* y) { (*y) = x; });
out, idx, updates, axis, [](auto x, auto* y) { (*y) = x; }, stream);
break;
case ScatterAxis::Sum:
scatter_axis<InT, IdxT>(
out, idx, updates, axis, [](auto x, auto* y) { (*y) += x; });
out, idx, updates, axis, [](auto x, auto* y) { (*y) += x; }, stream);
break;
}
}
@@ -592,31 +726,40 @@ void dispatch_scatter_axis(
const array& idx,
const array& updates,
int axis,
ScatterAxis::ReduceType rtype) {
ScatterAxis::ReduceType rtype,
Stream stream) {
switch (idx.dtype()) {
case uint8:
dispatch_scatter_axis_op<InT, uint8_t>(out, idx, updates, axis, rtype);
dispatch_scatter_axis_op<InT, uint8_t>(
out, idx, updates, axis, rtype, stream);
break;
case uint16:
dispatch_scatter_axis_op<InT, uint16_t>(out, idx, updates, axis, rtype);
dispatch_scatter_axis_op<InT, uint16_t>(
out, idx, updates, axis, rtype, stream);
break;
case uint32:
dispatch_scatter_axis_op<InT, uint32_t>(out, idx, updates, axis, rtype);
dispatch_scatter_axis_op<InT, uint32_t>(
out, idx, updates, axis, rtype, stream);
break;
case uint64:
dispatch_scatter_axis_op<InT, uint64_t>(out, idx, updates, axis, rtype);
dispatch_scatter_axis_op<InT, uint64_t>(
out, idx, updates, axis, rtype, stream);
break;
case int8:
dispatch_scatter_axis_op<InT, int8_t>(out, idx, updates, axis, rtype);
dispatch_scatter_axis_op<InT, int8_t>(
out, idx, updates, axis, rtype, stream);
break;
case int16:
dispatch_scatter_axis_op<InT, int16_t>(out, idx, updates, axis, rtype);
dispatch_scatter_axis_op<InT, int16_t>(
out, idx, updates, axis, rtype, stream);
break;
case int32:
dispatch_scatter_axis_op<InT, int32_t>(out, idx, updates, axis, rtype);
dispatch_scatter_axis_op<InT, int32_t>(
out, idx, updates, axis, rtype, stream);
break;
case int64:
dispatch_scatter_axis_op<InT, int64_t>(out, idx, updates, axis, rtype);
dispatch_scatter_axis_op<InT, int64_t>(
out, idx, updates, axis, rtype, stream);
break;
default:
throw std::runtime_error(
@@ -634,51 +777,64 @@ void ScatterAxis::eval_cpu(const std::vector<array>& inputs, array& out) {
// Copy src into out (copy allocates memory for out)
auto ctype =
src.flags().row_contiguous ? CopyType::Vector : CopyType::General;
copy(src, out, ctype);
copy(src, out, ctype, stream());

switch (src.dtype()) {
case bool_:
dispatch_scatter_axis<bool>(out, idx, updates, axis_, reduce_type_);
dispatch_scatter_axis<bool>(
out, idx, updates, axis_, reduce_type_, stream());
break;
case uint8:
dispatch_scatter_axis<uint8_t>(out, idx, updates, axis_, reduce_type_);
dispatch_scatter_axis<uint8_t>(
out, idx, updates, axis_, reduce_type_, stream());
break;
case uint16:
dispatch_scatter_axis<uint16_t>(out, idx, updates, axis_, reduce_type_);
dispatch_scatter_axis<uint16_t>(
out, idx, updates, axis_, reduce_type_, stream());
break;
case uint32:
dispatch_scatter_axis<uint32_t>(out, idx, updates, axis_, reduce_type_);
dispatch_scatter_axis<uint32_t>(
out, idx, updates, axis_, reduce_type_, stream());
break;
case uint64:
dispatch_scatter_axis<uint64_t>(out, idx, updates, axis_, reduce_type_);
dispatch_scatter_axis<uint64_t>(
out, idx, updates, axis_, reduce_type_, stream());
break;
case int8:
dispatch_scatter_axis<int8_t>(out, idx, updates, axis_, reduce_type_);
dispatch_scatter_axis<int8_t>(
out, idx, updates, axis_, reduce_type_, stream());
break;
case int16:
dispatch_scatter_axis<int16_t>(out, idx, updates, axis_, reduce_type_);
dispatch_scatter_axis<int16_t>(
out, idx, updates, axis_, reduce_type_, stream());
break;
case int32:
dispatch_scatter_axis<int32_t>(out, idx, updates, axis_, reduce_type_);
dispatch_scatter_axis<int32_t>(
out, idx, updates, axis_, reduce_type_, stream());
break;
case int64:
dispatch_scatter_axis<int64_t>(out, idx, updates, axis_, reduce_type_);
dispatch_scatter_axis<int64_t>(
out, idx, updates, axis_, reduce_type_, stream());
break;
case float16:
dispatch_scatter_axis<float16_t>(out, idx, updates, axis_, reduce_type_);
dispatch_scatter_axis<float16_t>(
out, idx, updates, axis_, reduce_type_, stream());
break;
case float32:
dispatch_scatter_axis<float>(out, idx, updates, axis_, reduce_type_);
dispatch_scatter_axis<float>(
out, idx, updates, axis_, reduce_type_, stream());
break;
case float64:
dispatch_scatter_axis<double>(out, idx, updates, axis_, reduce_type_);
dispatch_scatter_axis<double>(
out, idx, updates, axis_, reduce_type_, stream());
break;
case bfloat16:
dispatch_scatter_axis<bfloat16_t>(out, idx, updates, axis_, reduce_type_);
dispatch_scatter_axis<bfloat16_t>(
out, idx, updates, axis_, reduce_type_, stream());
break;
case complex64:
dispatch_scatter_axis<complex64_t>(
out, idx, updates, axis_, reduce_type_);
out, idx, updates, axis_, reduce_type_, stream());
break;
}
}

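// Note on the capture style used throughout this diff: dispatched closures
// capture raw pointers plus value copies of shapes, strides, and iterators
// (see the idx_it = std::move(idx_it) and op = std::move(op) captures
// above), never references to locals, because the task outlives eval_cpu.
// A small illustrative sketch, reusing the TaskQueue assumption from above;
// scale_rows and its parameters are hypothetical names:
#include <cstddef>
#include <vector>

void scale_rows(
    TaskQueue& q,
    float* data, // must stay allocated until the task has run
    std::vector<size_t> row_offsets, // moved into the closure, not referenced
    float scale) {
  q.dispatch([data, row_offsets = std::move(row_offsets), scale]() {
    for (auto off : row_offsets) {
      data[off] *= scale;
    }
  });
}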
@@ -2,20 +2,21 @@

#include "mlx/allocator.h"
#include "mlx/backend/cpu/copy.h"
#include "mlx/backend/cpu/encoder.h"
#include "mlx/backend/cpu/lapack.h"
#include "mlx/primitives.h"

namespace mlx::core {

template <typename T>
void general_inv(array& inv, int N, int i) {
void general_inv(T* inv, int N) {
int info;
auto ipiv = array::Data{allocator::malloc_or_wait(sizeof(int) * N)};
// Compute LU factorization.
getrf<T>(
/* m = */ &N,
/* n = */ &N,
/* a = */ inv.data<T>() + N * N * i,
/* a = */ inv,
/* lda = */ &N,
/* ipiv = */ static_cast<int*>(ipiv.buffer.raw_ptr()),
/* info = */ &info);
@@ -53,7 +54,7 @@ void general_inv(array& inv, int N, int i) {
// Compute inverse.
getri<T>(
/* m = */ &N,
/* a = */ inv.data<T>() + N * N * i,
/* a = */ inv,
/* lda = */ &N,
/* ipiv = */ static_cast<int*>(ipiv.buffer.raw_ptr()),
/* work = */ static_cast<T*>(scratch.buffer.raw_ptr()),
@@ -68,29 +69,28 @@ void general_inv(array& inv, int N, int i) {
}

template <typename T>
void tri_inv(array& inv, int N, int i, bool upper) {
void tri_inv(T* inv, int N, bool upper) {
const char uplo = upper ? 'L' : 'U';
const char diag = 'N';
T* data = inv.data<T>() + N * N * i;
int info;
trtri<T>(
/* uplo = */ &uplo,
/* diag = */ &diag,
/* N = */ &N,
/* a = */ data,
/* a = */ inv,
/* lda = */ &N,
/* info = */ &info);

// zero out the other triangle
if (upper) {
for (int i = 0; i < N; i++) {
std::fill(data, data + i, 0.0f);
data += N;
std::fill(inv, inv + i, 0.0f);
inv += N;
}
} else {
for (int i = 0; i < N; i++) {
std::fill(data + i + 1, data + N, 0.0f);
data += N;
std::fill(inv + i + 1, inv + N, 0.0f);
inv += N;
}
}

@@ -103,34 +103,53 @@ void tri_inv(array& inv, int N, int i, bool upper) {
}

template <typename T>
void inverse_impl(const array& a, array& inv, bool tri, bool upper) {
void inverse_impl(
const array& a,
array& inv,
bool tri,
bool upper,
Stream stream) {
// Lapack uses the column-major convention. We take advantage of the following
// identity to avoid transposing (see
// https://math.stackexchange.com/a/340234):
// (A⁻¹)ᵀ = (Aᵀ)⁻¹

// The inverse is computed in place, so just copy the input to the output.
copy(a, inv, a.flags().row_contiguous ? CopyType::Vector : CopyType::General);
copy(
a,
inv,
a.flags().row_contiguous ? CopyType::Vector : CopyType::General,
stream);

const int N = a.shape(-1);
const size_t num_matrices = a.size() / (N * N);

for (int i = 0; i < num_matrices; i++) {
if (tri) {
tri_inv<T>(inv, N, i, upper);
} else {
general_inv<T>(inv, N, i);
}
auto& encoder = cpu::get_command_encoder(stream);
encoder.set_output_array(inv);

auto inv_ptr = inv.data<T>();
if (tri) {
encoder.dispatch([inv_ptr, N, num_matrices, upper]() {
for (int i = 0; i < num_matrices; i++) {
tri_inv<T>(inv_ptr + N * N * i, N, upper);
}
});
} else {
encoder.dispatch([inv_ptr, N, num_matrices]() {
for (int i = 0; i < num_matrices; i++) {
general_inv<T>(inv_ptr + N * N * i, N);
}
});
}
}

void Inverse::eval_cpu(const std::vector<array>& inputs, array& output) {
switch (inputs[0].dtype()) {
case float32:
inverse_impl<float>(inputs[0], output, tri_, upper_);
inverse_impl<float>(inputs[0], output, tri_, upper_, stream());
break;
case float64:
inverse_impl<double>(inputs[0], output, tri_, upper_);
inverse_impl<double>(inputs[0], output, tri_, upper_, stream());
break;
default:
throw std::runtime_error(

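// Note how inverse_impl above moves the per-matrix loop *inside* a single
// dispatch instead of dispatching once per matrix: one task per batch keeps
// the queueing overhead constant in the batch size. A hedged sketch of the
// same shape under the TaskQueue assumption, with a caller-supplied kernel
// standing in for tri_inv / general_inv:
#include <cstddef>
#include <functional>

void invert_batch(
    TaskQueue& q,
    float* data, // num_matrices contiguous n x n matrices
    int n,
    size_t num_matrices,
    std::function<void(float*, int)> invert_one) {
  q.dispatch([data, n, num_matrices, invert_one = std::move(invert_one)]() {
    for (size_t i = 0; i < num_matrices; ++i) {
      invert_one(data + static_cast<size_t>(n) * n * i, n); // one matrix each
    }
  });
}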
@@ -4,15 +4,22 @@

#include "mlx/allocator.h"
#include "mlx/backend/cpu/copy.h"
#include "mlx/backend/cpu/encoder.h"
#include "mlx/backend/cpu/lapack.h"
#include "mlx/primitives.h"

namespace mlx::core {

template <typename T>
void luf_impl(const array& a, array& lu, array& pivots, array& row_indices) {
void luf_impl(
const array& a,
array& lu,
array& pivots,
array& row_indices,
Stream stream) {
int M = a.shape(-2);
int N = a.shape(-1);
int K = std::min(M, N);

// Copy a into lu and make it col contiguous
auto ndim = lu.ndim();
@@ -26,57 +33,72 @@ void luf_impl(const array& a, array& lu, array& pivots, array& row_indices) {
lu.set_data(
allocator::malloc_or_wait(lu.nbytes()), lu.nbytes(), strides, flags);
copy_inplace(
a, lu, a.shape(), a.strides(), strides, 0, 0, CopyType::GeneralGeneral);
a,
lu,
a.shape(),
a.strides(),
strides,
0,
0,
CopyType::GeneralGeneral,
stream);

auto a_ptr = lu.data<T>();

pivots.set_data(allocator::malloc_or_wait(pivots.nbytes()));
row_indices.set_data(allocator::malloc_or_wait(row_indices.nbytes()));
auto pivots_ptr = pivots.data<uint32_t>();
auto row_indices_ptr = row_indices.data<uint32_t>();

int info;
size_t num_matrices = a.size() / (M * N);
for (size_t i = 0; i < num_matrices; ++i) {
// Compute LU factorization of A
getrf<T>(
/* m */ &M,
/* n */ &N,
/* a */ a_ptr,
/* lda */ &M,
/* ipiv */ reinterpret_cast<int*>(pivots_ptr),
/* info */ &info);
auto& encoder = cpu::get_command_encoder(stream);
encoder.set_input_array(a);
encoder.set_output_array(lu);
encoder.set_output_array(pivots);
encoder.set_output_array(row_indices);

if (info != 0) {
std::stringstream ss;
ss << "[LUF::eval_cpu] sgetrf_ failed with code " << info
<< ((info > 0) ? " because matrix is singular"
: " because argument had an illegal value");
throw std::runtime_error(ss.str());
}
encoder.dispatch(
[a_ptr, pivots_ptr, row_indices_ptr, num_matrices, M, N, K]() mutable {
int info;
for (size_t i = 0; i < num_matrices; ++i) {
// Compute LU factorization of A
getrf<T>(
/* m */ &M,
/* n */ &N,
/* a */ a_ptr,
/* lda */ &M,
/* ipiv */ reinterpret_cast<int*>(pivots_ptr),
/* info */ &info);

// Subtract 1 to get 0-based index
int j = 0;
for (; j < pivots.shape(-1); ++j) {
pivots_ptr[j]--;
row_indices_ptr[j] = j;
}
for (; j < row_indices.shape(-1); ++j) {
row_indices_ptr[j] = j;
}
for (int j = pivots.shape(-1) - 1; j >= 0; --j) {
auto piv = pivots_ptr[j];
auto t1 = row_indices_ptr[piv];
auto t2 = row_indices_ptr[j];
row_indices_ptr[j] = t1;
row_indices_ptr[piv] = t2;
}
if (info != 0) {
std::stringstream ss;
ss << "[LUF::eval_cpu] sgetrf_ failed with code " << info
<< ((info > 0) ? " because matrix is singular"
: " because argument had an illegal value");
throw std::runtime_error(ss.str());
}

// Advance pointers to the next matrix
a_ptr += M * N;
pivots_ptr += pivots.shape(-1);
row_indices_ptr += pivots.shape(-1);
}
// Subtract 1 to get 0-based index
int j = 0;
for (; j < K; ++j) {
pivots_ptr[j]--;
row_indices_ptr[j] = j;
}
for (; j < M; ++j) {
row_indices_ptr[j] = j;
}
for (int j = K - 1; j >= 0; --j) {
auto piv = pivots_ptr[j];
auto t1 = row_indices_ptr[piv];
auto t2 = row_indices_ptr[j];
row_indices_ptr[j] = t1;
row_indices_ptr[piv] = t2;
}

// Advance pointers to the next matrix
a_ptr += M * N;
pivots_ptr += K;
row_indices_ptr += M;
}
});
}

void LUF::eval_cpu(
@@ -85,10 +107,10 @@ void LUF::eval_cpu(
assert(inputs.size() == 1);
switch (inputs[0].dtype()) {
case float32:
luf_impl<float>(inputs[0], outputs[0], outputs[1], outputs[2]);
luf_impl<float>(inputs[0], outputs[0], outputs[1], outputs[2], stream());
break;
case float64:
luf_impl<double>(inputs[0], outputs[0], outputs[1], outputs[2]);
luf_impl<double>(inputs[0], outputs[0], outputs[1], outputs[2], stream());
break;
default:
throw std::runtime_error(

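// One behavioral consequence of luf_impl above: the info != 0 check now runs
// inside the dispatched task, so the runtime_error is raised on the worker,
// after luf_impl itself has returned. A sketch of how a worker could capture
// such an exception and re-surface it at a later synchronization point (an
// assumption about the design, not MLX's exact mechanism):
#include <exception>
#include <functional>

// Wraps a task so a thrown exception is stored instead of escaping on the
// worker thread; `slot` must outlive the task, and the caller rethrows it
// via std::rethrow_exception at the next synchronization point.
std::function<void()> capture_errors(
    std::function<void()> task,
    std::exception_ptr& slot) {
  return [task = std::move(task), &slot]() {
    try {
      task();
    } catch (...) {
      slot = std::current_exception();
    }
  };
}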
@@ -5,6 +5,7 @@
#include "mlx/array.h"
#include "mlx/backend/common/utils.h"
#include "mlx/backend/cpu/copy.h"
#include "mlx/backend/cpu/encoder.h"
#include "mlx/backend/cpu/lapack.h"
#include "mlx/primitives.h"

@@ -64,36 +65,36 @@ void BlockMaskedMM::eval_cpu(const std::vector<array>& inputs, array& out) {
auto& b_pre = inputs[1];

auto check_transpose =
[](const array& arr, bool do_copy, bool expand_all = false) {
[s = stream()](const array& arr, bool do_copy, bool expand_all = false) {
auto stx = arr.strides()[arr.ndim() - 2];
auto sty = arr.strides()[arr.ndim() - 1];
if (!expand_all && stx == arr.shape(-1) && sty == 1) {
if (do_copy) {
array arr_copy(arr.shape(), arr.dtype(), nullptr, {});
copy(arr, arr_copy, CopyType::Vector);
return std::make_tuple(false, stx, arr_copy);
copy(arr, arr_copy, CopyType::Vector, s);
return std::make_tuple(false, stx, arr_copy, true);
}
return std::make_tuple(false, stx, arr);
return std::make_tuple(false, stx, arr, false);
} else if (!expand_all && stx == 1 && sty == arr.shape(-2)) {
if (do_copy) {
array arr_copy(arr.shape(), arr.dtype(), nullptr, {});
copy(arr, arr_copy, CopyType::Vector);
return std::make_tuple(true, sty, arr_copy);
copy(arr, arr_copy, CopyType::Vector, s);
return std::make_tuple(true, sty, arr_copy, true);
}
return std::make_tuple(true, sty, arr);
return std::make_tuple(true, sty, arr, false);
} else {
array arr_copy(arr.shape(), arr.dtype(), nullptr, {});
copy(arr, arr_copy, CopyType::General);
copy(arr, arr_copy, CopyType::General, s);
int64_t stx = arr.shape(-1);
return std::make_tuple(false, stx, arr_copy);
return std::make_tuple(false, stx, arr_copy, true);
}
};

bool has_op_mask = inputs.size() > 3;
bool has_out_mask = inputs.size() == 3 || inputs.size() == 5;
auto [a_transposed, lda, a] =
auto [a_transposed, lda, a, a_copied] =
check_transpose(a_pre, has_op_mask, inputs.back().dtype() != bool_);
auto [b_transposed, ldb, b] =
auto [b_transposed, ldb, b, b_copied] =
check_transpose(b_pre, has_op_mask, inputs.back().dtype() != bool_);

size_t M = a.shape(-2);
@@ -104,31 +105,39 @@ void BlockMaskedMM::eval_cpu(const std::vector<array>& inputs, array& out) {
return;
}

auto& encoder = cpu::get_command_encoder(stream());
if (K == 0) {
std::memset(static_cast<void*>(out.data<float>()), 0, out.nbytes());
encoder.set_output_array(out);
encoder.dispatch([out_ptr = out.data<void>(), nbytes = out.nbytes()]() {
std::memset(out_ptr, 0, nbytes);
});
return;
}

auto mask_array = [](const array& mask,
auto mask_array = [](const void* mask,
float* data,
int block_size,
int batch_idx,
int X,
int Y,
size_t X_data_str,
size_t Y_data_str) {
size_t Y_data_str,
const Shape& mask_shape,
const Strides& mask_strides,
bool is_bool) {
auto ndim = mask_shape.size();
auto mask_offset = elem_to_loc(
mask.shape(-1) * mask.shape(-2) * batch_idx,
mask.shape(),
mask.strides());
mask_shape[ndim - 1] * mask_shape[ndim - 2] * batch_idx,
mask_shape,
mask_strides);

auto X_mask_str = mask.strides()[mask.ndim() - 2];
auto Y_mask_str = mask.strides()[mask.ndim() - 1];
auto X_mask_str = mask_strides[ndim - 2];
auto Y_mask_str = mask_strides[ndim - 1];

if (mask.dtype() == bool_) {
if (is_bool) {
return mask_matrix(
data,
mask.data<bool>(),
static_cast<const bool*>(mask),
block_size,
X,
Y,
@@ -140,7 +149,7 @@ void BlockMaskedMM::eval_cpu(const std::vector<array>& inputs, array& out) {
} else {
return mask_matrix(
data,
mask.data<float>(),
static_cast<const float*>(mask),
block_size,
X,
Y,
@@ -152,61 +161,155 @@ void BlockMaskedMM::eval_cpu(const std::vector<array>& inputs, array& out) {
}
};

for (int i = 0; i < (out.size() / (M * size_t(N))); ++i) {
// Adjust pointer
float* ai =
a.data<float>() + elem_to_loc(M * K * i, a.shape(), a.strides());
float* bi =
b.data<float>() + elem_to_loc(K * N * i, b.shape(), b.strides());
float* ci = out.data<float>() + M * N * i;
encoder.set_input_array(a);
encoder.set_input_array(b);
const void* a_mask_ptr;
const void* b_mask_ptr;
const void* out_mask_ptr;
Shape a_mask_shape;
Shape b_mask_shape;
Shape out_mask_shape;
Strides a_mask_strides;
Strides b_mask_strides;
Strides out_mask_strides;
bool a_mask_bool;
bool b_mask_bool;
bool out_mask_bool;
if (has_op_mask) {
auto& a_mask = inputs[inputs.size() - 2];
auto& b_mask = inputs[inputs.size() - 1];
a_mask_ptr = a_mask.data<void>();
b_mask_ptr = b_mask.data<void>();
a_mask_shape = a_mask.shape();
b_mask_shape = b_mask.shape();
a_mask_strides = a_mask.strides();
b_mask_strides = b_mask.strides();
a_mask_bool = (a_mask.dtype() == bool_);
b_mask_bool = (b_mask.dtype() == bool_);
encoder.set_input_array(a_mask);
encoder.set_input_array(b_mask);
}
if (has_out_mask) {
auto& out_mask = inputs[2];
out_mask_ptr = out_mask.data<void>();
out_mask_bool = (out_mask.dtype() == bool_);
encoder.set_input_array(out_mask);
out_mask_shape = out_mask.shape();
out_mask_strides = out_mask.strides();
}
encoder.set_output_array(out);
auto a_ptr = a.data<float>();
auto b_ptr = b.data<float>();
auto out_ptr = out.data<float>();
size_t num_matrices = out.size() / (M * size_t(N));
auto ldc = out.shape(-1);

// Zero out blocks in a and b if needed
if (has_op_mask) {
auto& a_mask = inputs[inputs.size() - 2];
mask_array(
a_mask,
ai,
block_size_,
i,
encoder.dispatch([a_ptr,
b_ptr,
out_ptr,
a_mask_ptr,
b_mask_ptr,
out_mask_ptr,
has_op_mask,
has_out_mask,
block_size = block_size_,
num_matrices,
M,
N,
K,
a_transposed = a_transposed,
b_transposed = b_transposed,
lda = lda,
ldb = ldb,
ldc,
a_shape = a.shape(),
a_strides = a.strides(),
b_shape = b.shape(),
b_strides = b.strides(),
a_mask_shape = std::move(a_mask_shape),
b_mask_shape = std::move(b_mask_shape),
out_mask_shape = std::move(out_mask_shape),
a_mask_strides = std::move(a_mask_strides),
b_mask_strides = std::move(b_mask_strides),
out_mask_strides = std::move(out_mask_strides),
mask_array,
a_mask_bool,
b_mask_bool,
out_mask_bool]() {
for (int i = 0; i < num_matrices; ++i) {
// Adjust pointer
float* ai = a_ptr + elem_to_loc(M * K * i, a_shape, a_strides);
float* bi = b_ptr + elem_to_loc(K * N * i, b_shape, b_strides);
float* ci = out_ptr + M * N * i;

// Zero out blocks in a and b if needed
if (has_op_mask) {
mask_array(
a_mask_ptr,
ai,
block_size,
i,
M,
K,
a_transposed ? 1 : lda,
a_transposed ? lda : 1,
a_mask_shape,
a_mask_strides,
a_mask_bool);

mask_array(
b_mask_ptr,
bi,
block_size,
i,
K,
N,
b_transposed ? 1 : ldb,
b_transposed ? ldb : 1,
b_mask_shape,
b_mask_strides,
b_mask_bool);
}

// Do matmul
cblas_sgemm(
CblasRowMajor,
a_transposed ? CblasTrans : CblasNoTrans, // transA
b_transposed ? CblasTrans : CblasNoTrans, // transB
M,
K,
a_transposed ? 1 : lda,
a_transposed ? lda : 1);

auto& b_mask = inputs[inputs.size() - 1];
mask_array(
b_mask,
bi,
block_size_,
i,
K,
N,
b_transposed ? 1 : ldb,
b_transposed ? ldb : 1);
}
K,
1.0, // alpha
ai,
lda,
bi,
ldb,
0.0, // beta
ci,
ldc);

// Do matmul
cblas_sgemm(
CblasRowMajor,
a_transposed ? CblasTrans : CblasNoTrans, // transA
b_transposed ? CblasTrans : CblasNoTrans, // transB
M,
N,
K,
1.0, // alpha
ai,
lda,
bi,
ldb,
0.0, // beta
ci,
out.shape(-1) // ldc
);

// Zero out blocks in out
if (has_out_mask) {
mask_array(inputs[2], ci, block_size_, i, M, N, N, 1);
// Zero out blocks in out
if (has_out_mask) {
mask_array(
out_mask_ptr,
ci,
block_size,
i,
M,
N,
N,
1,
out_mask_shape,
out_mask_strides,
out_mask_bool);
}
}
});
if (a_copied) {
encoder.add_temporary(a);
}
if (b_copied) {
encoder.add_temporary(b);
}
}

@@ -220,7 +323,8 @@ void GatherMM::eval_cpu(const std::vector<array>& inputs, array& out) {
auto& a_pre = inputs[0];
auto& b_pre = inputs[1];

auto check_transpose = [](const array& arr) {
std::vector<array> temps;
auto check_transpose = [s = stream(), &temps](const array& arr) {
auto stx = arr.strides()[arr.ndim() - 2];
auto sty = arr.strides()[arr.ndim() - 1];
if (stx == arr.shape(-1) && sty == 1) {
@@ -228,10 +332,10 @@ void GatherMM::eval_cpu(const std::vector<array>& inputs, array& out) {
} else if (stx == 1 && sty == arr.shape(-2)) {
return std::make_tuple(true, sty, arr);
} else {
array arr_copy(arr.shape(), arr.dtype(), nullptr, {});
copy(arr, arr_copy, CopyType::General);
temps.push_back(array(arr.shape(), arr.dtype(), nullptr, {}));
copy(arr, temps.back(), CopyType::General, s);
int64_t stx = arr.shape(-1);
return std::make_tuple(false, stx, arr_copy);
return std::make_tuple(false, stx, temps.back());
}
};

@@ -246,8 +350,12 @@ void GatherMM::eval_cpu(const std::vector<array>& inputs, array& out) {
return;
}

auto& encoder = cpu::get_command_encoder(stream());
if (K == 0) {
std::memset(static_cast<void*>(out.data<float>()), 0, out.nbytes());
encoder.set_output_array(out);
encoder.dispatch([out_ptr = out.data<float>(), size = out.size()]() {
std::fill_n(out_ptr, size, 0);
});
return;
}

@@ -272,29 +380,61 @@ void GatherMM::eval_cpu(const std::vector<array>& inputs, array& out) {

const uint32_t* lhs_indices_ptr = lhs_indices.data<uint32_t>();
const uint32_t* rhs_indices_ptr = rhs_indices.data<uint32_t>();
encoder.set_input_array(a);
encoder.set_input_array(b);
encoder.set_input_array(lhs_indices);
encoder.set_input_array(rhs_indices);
encoder.set_output_array(out);
auto ldc = out.shape(-1);

for (int i = 0; i < batch_size_out; i++) {
// Get index
uint32_t indx_A = lhs_indices_ptr[elem_to_loc(i, lhs_indices)];
uint32_t indx_B = rhs_indices_ptr[elem_to_loc(i, rhs_indices)];
encoder.dispatch([a_ptr = a.data<float>(),
b_ptr = b.data<float>(),
out_ptr = out.data<float>(),
M,
N,
K,
lda = lda,
ldb = ldb,
a_transposed = a_transposed,
b_transposed = b_transposed,
ldc,
lhs_indices_ptr,
rhs_indices_ptr,
lhs_indices_shape = lhs_indices.shape(),
lhs_indices_strides = lhs_indices.strides(),
rhs_indices_shape = rhs_indices.shape(),
rhs_indices_strides = rhs_indices.strides(),
batch_size_out,
matrix_stride_out,
batch_shape_A = std::move(batch_shape_A),
batch_shape_B = std::move(batch_shape_B),
batch_strides_A = std::move(batch_strides_A),
batch_strides_B = std::move(batch_strides_B)]() {
for (int i = 0; i < batch_size_out; i++) {
// Get index
uint32_t indx_A = lhs_indices_ptr[elem_to_loc(
i, lhs_indices_shape, lhs_indices_strides)];
uint32_t indx_B = rhs_indices_ptr[elem_to_loc(
i, rhs_indices_shape, rhs_indices_strides)];

cblas_sgemm(
CblasRowMajor,
a_transposed ? CblasTrans : CblasNoTrans, // transA
b_transposed ? CblasTrans : CblasNoTrans, // transB
M,
N,
K,
1.0f, // alpha
a.data<float>() + elem_to_loc(indx_A, batch_shape_A, batch_strides_A),
lda,
b.data<float>() + elem_to_loc(indx_B, batch_shape_B, batch_strides_B),
ldb,
0.0f, // beta
out.data<float>() + matrix_stride_out * i,
out.shape(-1) // ldc
);
}
cblas_sgemm(
CblasRowMajor,
a_transposed ? CblasTrans : CblasNoTrans, // transA
b_transposed ? CblasTrans : CblasNoTrans, // transB
M,
N,
K,
1.0f, // alpha
a_ptr + elem_to_loc(indx_A, batch_shape_A, batch_strides_A),
lda,
b_ptr + elem_to_loc(indx_B, batch_shape_B, batch_strides_B),
ldb,
0.0f, // beta
out_ptr + matrix_stride_out * i,
ldc);
}
});
encoder.add_temporaries(std::move(temps));
}

} // namespace mlx::core

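// Note the add_temporary / add_temporaries calls above: when check_transpose
// materializes a contiguous copy, that copy must stay alive until the
// dispatched gemm has actually run, so its ownership is handed to the
// encoder instead of being dropped when eval_cpu returns. A minimal sketch
// of the idea, again assuming the TaskQueue from earlier:
#include <memory>
#include <vector>

void gemm_with_temporary(TaskQueue& q, const float* src, size_t n) {
  // The temporary is owned by a shared_ptr captured into the task, so it is
  // freed only after the work that reads it has completed.
  auto tmp = std::make_shared<std::vector<float>>(src, src + n);
  q.dispatch([tmp]() {
    // ... use tmp->data() as the contiguous gemm input ...
  });
}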
@@ -3,18 +3,76 @@
#include <cstring>
#include "mlx/array.h"
#include "mlx/backend/cpu/copy.h"
#include "mlx/backend/cpu/encoder.h"
#include "mlx/backend/cpu/gemm.h"
#include "mlx/primitives.h"

namespace mlx::core {

template <typename T>
void matmul_dispatch(
const array& a,
const array& b,
array& out,
bool a_transposed,
bool b_transposed,
size_t lda,
size_t ldb,
float alpha,
float beta,
Stream stream) {
const T* a_ptr = a.data<T>();
const T* b_ptr = b.data<T>();
T* out_ptr = out.data<T>();
size_t ldc = out.shape(-1);
size_t batch_size = a.size() / (a.shape(-2) * a.shape(-1));
auto& encoder = cpu::get_command_encoder(stream);
encoder.set_input_array(a);
encoder.set_input_array(b);
encoder.set_output_array(out);
encoder.dispatch([a_ptr,
b_ptr,
out_ptr,
a_transposed,
b_transposed,
lda,
ldb,
ldc,
alpha,
beta,
batch_size,
a_shape = a.shape(),
a_strides = a.strides(),
b_shape = b.shape(),
b_strides = b.strides()]() {
matmul<T>(
a_ptr,
b_ptr,
out_ptr,
a_transposed,
b_transposed,
lda,
ldb,
ldc,
alpha,
beta,
batch_size,
a_shape,
a_strides,
b_shape,
b_strides);
});
}

void matmul_general(
const array& a_pre,
const array& b_pre,
array& out,
Stream stream,
float alpha = 1.0f,
float beta = 0.0f) {
auto check_transpose = [](const array& arr) {
std::vector<array> temps;
auto check_transpose = [stream, &temps](const array& arr) {
auto stx = arr.strides()[arr.ndim() - 2];
auto sty = arr.strides()[arr.ndim() - 1];
if (stx == arr.shape(-1) && sty == 1) {
@@ -22,10 +80,10 @@ void matmul_general(
} else if (stx == 1 && sty == arr.shape(-2)) {
return std::make_tuple(true, sty, arr);
} else {
array arr_copy(arr.shape(), arr.dtype(), nullptr, {});
copy(arr, arr_copy, CopyType::General);
temps.push_back(array(arr.shape(), arr.dtype(), nullptr, {}));
copy(arr, temps.back(), CopyType::General, stream);
stx = arr.shape(-1);
return std::make_tuple(false, stx, arr_copy);
return std::make_tuple(false, stx, temps.back());
}
};

@@ -39,28 +97,34 @@ void matmul_general(
}

if (out.dtype() == float32) {
matmul<float>(a, b, out, a_transposed, b_transposed, lda, ldb, alpha, beta);
matmul_dispatch<float>(
a, b, out, a_transposed, b_transposed, lda, ldb, alpha, beta, stream);
} else if (out.dtype() == float16) {
matmul<float16_t>(
a, b, out, a_transposed, b_transposed, lda, ldb, alpha, beta);
matmul_dispatch<float16_t>(
a, b, out, a_transposed, b_transposed, lda, ldb, alpha, beta, stream);
} else if (out.dtype() == bfloat16) {
matmul<bfloat16_t>(
a, b, out, a_transposed, b_transposed, lda, ldb, alpha, beta);
matmul_dispatch<bfloat16_t>(
a, b, out, a_transposed, b_transposed, lda, ldb, alpha, beta, stream);
} else if (out.dtype() == float64) {
matmul<double>(
a, b, out, a_transposed, b_transposed, lda, ldb, alpha, beta);
matmul_dispatch<double>(
a, b, out, a_transposed, b_transposed, lda, ldb, alpha, beta, stream);
} else {
throw std::runtime_error("[Matmul::eval_cpu] Invalid type.");
}
cpu::get_command_encoder(stream).add_temporaries(std::move(temps));
}

void Matmul::eval_cpu(const std::vector<array>& inputs, array& out) {
out.set_data(allocator::malloc_or_wait(out.nbytes()));
if (inputs[0].shape(-1) == 0) {
std::memset(out.data<void>(), 0, out.nbytes());
auto& encoder = cpu::get_command_encoder(stream());
encoder.set_output_array(out);
encoder.dispatch([out_ptr = out.data<void>(), nbytes = out.nbytes()]() {
std::memset(out_ptr, 0, nbytes);
});
return;
}
return matmul_general(inputs[0], inputs[1], out);
matmul_general(inputs[0], inputs[1], out, stream());
}

void AddMM::eval_cpu(const std::vector<array>& inputs, array& out) {
@@ -74,9 +138,9 @@ void AddMM::eval_cpu(const std::vector<array>& inputs, array& out) {
CopyType ctype = c.data_size() == 1
? CopyType::Scalar
: (c.flags().row_contiguous ? CopyType::Vector : CopyType::General);
copy(c, out, ctype);
copy(c, out, ctype, stream());

return matmul_general(inputs[0], inputs[1], out, alpha_, beta_);
matmul_general(inputs[0], inputs[1], out, stream(), alpha_, beta_);
}

} // namespace mlx::core

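// Matmul::eval_cpu above routes even the trivial zero-fill through the
// encoder. That matters because tasks on one queue run in submission order:
// a later dispatch may safely assume an earlier one has completed. A tiny
// ordering illustration under the same TaskQueue assumption:
#include <vector>

void fill_then_accumulate(TaskQueue& q, std::vector<float>* out, float v) {
  q.dispatch([out]() { out->assign(out->size(), 0.0f); }); // runs first
  q.dispatch([out, v]() { // runs second, sees the zeroed buffer
    for (auto& x : *out) {
      x += v;
    }
  });
}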
@@ -7,11 +7,11 @@
#include <sstream>

#include "mlx/allocator.h"
#include "mlx/backend/common/load.h"
#include "mlx/backend/common/slicing.h"
#include "mlx/backend/common/utils.h"
#include "mlx/backend/cpu/arange.h"
#include "mlx/backend/cpu/copy.h"
#include "mlx/backend/cpu/encoder.h"
#include "mlx/backend/cpu/threefry.h"
#include "mlx/primitives.h"
#include "mlx/utils.h"
@@ -22,39 +22,58 @@ void reshape(const array& in, array& out) {
auto [copy_necessary, out_strides] = prepare_reshape(in, out);
if (copy_necessary) {
out.set_data(allocator::malloc_or_wait(out.nbytes()));
copy_inplace(in, out, CopyType::General);
copy_inplace(in, out, CopyType::General, out.primitive().stream());
} else {
shared_buffer_reshape(in, out_strides, out);
}
}

int64_t compute_dynamic_offset(
static std::pair<array, bool> compute_dynamic_offset(
const array& indices,
const Strides& strides,
const std::vector<int>& axes) {
auto compute_offset = [&strides, &axes](const auto* indices) {
int64_t offset = 0;
for (int i = 0; i < axes.size(); ++i) {
offset += indices[i] * strides[axes[i]];
}
return offset;
};
const std::vector<int>& axes,
Stream stream) {
array offset({1}, int64, nullptr, {});
bool donate = indices.is_donatable() &&
(indices.data_size() * indices.itemsize()) >= offset.itemsize();
if (donate) {
offset.copy_shared_buffer(indices);
} else {
offset.set_data(allocator::malloc_or_wait(offset.itemsize()));
}

auto& encoder = cpu::get_command_encoder(stream);
encoder.set_input_array(indices);
encoder.set_output_array(offset);
auto compute_offset =
[strides, axes, offset = offset.data<int64_t>()](const auto* indices) {
int64_t offset_ = 0;
for (int i = 0; i < axes.size(); ++i) {
offset_ += indices[i] * strides[axes[i]];
}
offset[0] = offset_;
};
switch (indices.dtype()) {
case int8:
case uint8:
return compute_offset(indices.data<uint8_t>());
encoder.dispatch(compute_offset, indices.data<uint8_t>());
break;
case int16:
case uint16:
return compute_offset(indices.data<uint16_t>());
encoder.dispatch(compute_offset, indices.data<uint16_t>());
break;
case int32:
case uint32:
return compute_offset(indices.data<uint32_t>());
encoder.dispatch(compute_offset, indices.data<uint32_t>());
break;
case int64:
case uint64:
return compute_offset(indices.data<uint64_t>());
encoder.dispatch(compute_offset, indices.data<uint64_t>());
break;
default:
throw std::runtime_error("Invalid indices type.");
}
return {offset, donate};
}

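// compute_dynamic_offset above also shows the donation pattern: if the
// indices buffer is donatable and at least as large as the result, the
// result aliases it and no temporary needs to be tracked. A sketch of just
// the size check; can_donate is a hypothetical helper for illustration, with
// is_donatable treated as a given predicate:
#include <cstddef>

// Returns true when an input buffer may be reused in place for an output of
// out_bytes bytes.
bool can_donate(
    bool is_donatable,
    size_t in_elems,
    size_t in_itemsize,
    size_t out_bytes) {
  return is_donatable && in_elems * in_itemsize >= out_bytes;
}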
void AsStrided::eval_cpu(const std::vector<array>& inputs, array& out) {
@@ -104,14 +123,59 @@ void Transpose::eval_cpu(const std::vector<array>& inputs, array& out) {
}

void Arange::eval_cpu(const std::vector<array>& inputs, array& out) {
arange(inputs, out, start_, step_);
assert(inputs.size() == 0);
out.set_data(allocator::malloc_or_wait(out.nbytes()));
switch (out.dtype()) {
case bool_:
throw std::runtime_error("Bool type unsupported for arange.");
break;
case uint8:
arange<uint8_t>(start_, start_ + step_, out, out.size(), stream());
break;
case uint16:
arange<uint16_t>(start_, start_ + step_, out, out.size(), stream());
break;
case uint32:
arange<uint32_t>(start_, start_ + step_, out, out.size(), stream());
break;
case uint64:
arange<uint64_t>(start_, start_ + step_, out, out.size(), stream());
break;
case int8:
arange<int8_t>(start_, start_ + step_, out, out.size(), stream());
break;
case int16:
arange<int16_t>(start_, start_ + step_, out, out.size(), stream());
break;
case int32:
arange<int32_t>(start_, start_ + step_, out, out.size(), stream());
break;
case int64:
arange<int64_t>(start_, start_ + step_, out, out.size(), stream());
break;
case float16:
arange<float16_t>(start_, start_ + step_, out, out.size(), stream());
break;
case float32:
arange<float>(start_, start_ + step_, out, out.size(), stream());
break;
case float64:
arange<double>(start_, start_ + step_, out, out.size(), stream());
break;
case bfloat16:
arange<bfloat16_t>(start_, start_ + step_, out, out.size(), stream());
break;
case complex64:
arange<complex64_t>(start_, start_ + step_, out, out.size(), stream());
break;
}
}

void AsType::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
auto& in = inputs[0];
CopyType ctype = in.flags().contiguous ? CopyType::Vector : CopyType::General;
copy(in, out, ctype);
copy(in, out, ctype, stream());
}

void Concatenate::eval_cpu(const std::vector<array>& inputs, array& out) {
@@ -134,7 +198,7 @@ void Concatenate::eval_cpu(const std::vector<array>& inputs, array& out) {
size_t data_offset = strides[axis_] * sizes[i];
out_slice.copy_shared_buffer(
out, strides, flags, out_slice.size(), data_offset);
copy_inplace(inputs[i], out_slice, CopyType::GeneralGeneral);
copy_inplace(inputs[i], out_slice, CopyType::GeneralGeneral, stream());
}
}

@@ -145,7 +209,7 @@ void Contiguous::eval_cpu(const std::vector<array>& inputs, array& out) {
(allow_col_major_ && in.flags().col_contiguous)) {
out.copy_shared_buffer(in);
} else {
copy(in, out, CopyType::General);
copy(in, out, CopyType::General, stream());
}
}

@@ -169,14 +233,7 @@ void Full::eval_cpu(const std::vector<array>& inputs, array& out) {
} else {
ctype = CopyType::General;
}
copy(in, out, ctype);
}

void Load::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 0);
out.set_data(allocator::malloc_or_wait(out.nbytes()));

load(out, offset_, reader_, swap_endianness_);
copy(in, out, ctype, stream());
}

void Pad::eval_cpu(const std::vector<array>& inputs, array& out) {
@@ -192,7 +249,7 @@ void Pad::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(val.dtype() == in.dtype() && in.dtype() == out.dtype());

// Fill output with val
copy(val, out, CopyType::Scalar);
copy(val, out, CopyType::Scalar, stream());

// Find offset for start of input values
size_t data_offset = 0;
@@ -207,7 +264,7 @@ void Pad::eval_cpu(const std::vector<array>& inputs, array& out) {
out, out.strides(), out.flags(), out_slice.size(), data_offset);

// Copy input values into the slice
copy_inplace(in, out_slice, CopyType::GeneralGeneral);
copy_inplace(in, out_slice, CopyType::GeneralGeneral, stream());
}

void RandomBits::eval_cpu(const std::vector<array>& inputs, array& out) {
@@ -223,39 +280,49 @@ void RandomBits::eval_cpu(const std::vector<array>& inputs, array& out) {

auto kptr = inputs[0].data<uint32_t>();
auto cptr = out.data<char>();
size_t out_skip = (bytes_per_key + 4 - 1) / 4;
auto half_size = out_skip / 2;
bool even = out_skip % 2 == 0;
for (int i = 0; i < num_keys; ++i, cptr += bytes_per_key) {
auto ptr = reinterpret_cast<uint32_t*>(cptr);
// Get ith key
auto kidx = 2 * i;
auto k1_elem = elem_to_loc(kidx, keys.shape(), keys.strides());
auto k2_elem = elem_to_loc(kidx + 1, keys.shape(), keys.strides());
auto key = std::make_pair(kptr[k1_elem], kptr[k2_elem]);
auto& encoder = cpu::get_command_encoder(stream());
encoder.set_input_array(inputs[0]);
encoder.set_output_array(out);
encoder.dispatch([kptr,
cptr,
bytes_per_key,
num_keys,
kshape = keys.shape(),
kstrides = keys.strides()]() mutable {
size_t out_skip = (bytes_per_key + 4 - 1) / 4;
auto half_size = out_skip / 2;
bool even = out_skip % 2 == 0;
for (int i = 0; i < num_keys; ++i, cptr += bytes_per_key) {
auto ptr = reinterpret_cast<uint32_t*>(cptr);
// Get ith key
auto kidx = 2 * i;
auto k1_elem = elem_to_loc(kidx, kshape, kstrides);
auto k2_elem = elem_to_loc(kidx + 1, kshape, kstrides);
auto key = std::make_pair(kptr[k1_elem], kptr[k2_elem]);

std::pair<uintptr_t, uintptr_t> count{0, half_size + !even};
for (; count.first + 1 < half_size; count.first++, count.second++) {
std::tie(ptr[count.first], ptr[count.second]) =
random::threefry2x32_hash(key, count);
}
if (count.first < half_size) {
auto rb = random::threefry2x32_hash(key, count);
ptr[count.first++] = rb.first;
if (bytes_per_key % 4 > 0) {
std::copy(
reinterpret_cast<char*>(&rb.second),
reinterpret_cast<char*>(&rb.second) + bytes_per_key % 4,
cptr + 4 * count.second);
} else {
ptr[count.second] = rb.second;
std::pair<uintptr_t, uintptr_t> count{0, half_size + !even};
for (; count.first + 1 < half_size; count.first++, count.second++) {
std::tie(ptr[count.first], ptr[count.second]) =
random::threefry2x32_hash(key, count);
}
if (count.first < half_size) {
auto rb = random::threefry2x32_hash(key, count);
ptr[count.first++] = rb.first;
if (bytes_per_key % 4 > 0) {
std::copy(
reinterpret_cast<char*>(&rb.second),
reinterpret_cast<char*>(&rb.second) + bytes_per_key % 4,
cptr + 4 * count.second);
} else {
ptr[count.second] = rb.second;
}
}
if (!even) {
count.second = 0;
ptr[half_size] = random::threefry2x32_hash(key, count).first;
}
}
if (!even) {
count.second = 0;
ptr[half_size] = random::threefry2x32_hash(key, count).first;
}
}
});
}

void Reshape::eval_cpu(const std::vector<array>& inputs, array& out) {
@@ -269,16 +336,23 @@ void DynamicSlice::eval_cpu(const std::vector<array>& inputs, array& out) {
}
auto& in = inputs[0];
out.set_data(allocator::malloc_or_wait(out.nbytes()));
auto i_offset = compute_dynamic_offset(inputs[1], in.strides(), axes_);
auto [in_offset, donated] =
compute_dynamic_offset(inputs[1], in.strides(), axes_, stream());
copy_inplace(
/* const array& src = */ in,
/* array& dst = */ out,
/* const Shape& data_shape = */ out.shape(),
/* const Strides& i_strides = */ in.strides(),
/* const Strides& o_strides = */ out.strides(),
/* int64_t i_offset = */ i_offset,
/* int64_t i_offset = */ 0,
/* int64_t o_offset = */ 0,
/* CopyType ctype = */ CopyType::GeneralGeneral);
/* CopyType ctype = */ CopyType::GeneralGeneral,
stream(),
/* const std::optional<array>& dynamic_i_offset = */ in_offset,
/* const std::optional<array>& dynamic_o_offset = */ std::nullopt);
if (!donated) {
cpu::get_command_encoder(stream()).add_temporary(std::move(in_offset));
}
}

void DynamicSliceUpdate::eval_cpu(
@@ -296,9 +370,10 @@ void DynamicSliceUpdate::eval_cpu(
auto ctype = in.flags().contiguous && in.size() == in.data_size()
? CopyType::Vector
: CopyType::General;
copy(in, out, in.data_size() == 1 ? CopyType::Scalar : ctype);
copy(in, out, in.data_size() == 1 ? CopyType::Scalar : ctype, stream());

auto o_offset = compute_dynamic_offset(inputs[2], out.strides(), axes_);
auto [out_offset, donated] =
compute_dynamic_offset(inputs[2], out.strides(), axes_, stream());
copy_inplace(
/* const array& src = */ upd,
/* array& dst = */ out,
@@ -306,8 +381,14 @@ void DynamicSliceUpdate::eval_cpu(
/* const std::vector<stride_t>& i_strides = */ upd.strides(),
/* const std::vector<stride_t>& o_strides = */ out.strides(),
/* int64_t i_offset = */ 0,
/* int64_t o_offset = */ o_offset,
/* CopyType ctype = */ CopyType::GeneralGeneral);
/* int64_t o_offset = */ 0,
/* CopyType ctype = */ CopyType::GeneralGeneral,
stream(),
/* const std::optional<array>& dynamic_i_offset = */ std::nullopt,
/* const std::optional<array>& dynamic_o_offset = */ out_offset);
if (!donated) {
cpu::get_command_encoder(stream()).add_temporary(std::move(out_offset));
}
}

void SliceUpdate::eval_cpu(const std::vector<array>& inputs, array& out) {
@@ -329,7 +410,7 @@ void SliceUpdate::eval_cpu(const std::vector<array>& inputs, array& out) {
auto ctype = in.flags().contiguous && in.size() == in.data_size()
? CopyType::Vector
: CopyType::General;
copy(in, out, in.data_size() == 1 ? CopyType::Scalar : ctype);
copy(in, out, in.data_size() == 1 ? CopyType::Scalar : ctype, stream());

// Calculate out strides, initial offset and if copy needs to be made
auto [data_offset, out_strides] =
@@ -344,7 +425,8 @@ void SliceUpdate::eval_cpu(const std::vector<array>& inputs, array& out) {
/* const std::vector<stride_t>& o_strides = */ out_strides,
/* int64_t i_offset = */ 0,
/* int64_t o_offset = */ data_offset,
/* CopyType ctype = */ CopyType::GeneralGeneral);
/* CopyType ctype = */ CopyType::GeneralGeneral,
stream());
}

void View::eval_cpu(const std::vector<array>& inputs, array& out) {
@@ -372,9 +454,9 @@ void View::eval_cpu(const std::vector<array>& inputs, array& out) {
if (in.dtype() == bool_) {
auto in_tmp = array(in.shape(), uint8, nullptr, {});
in_tmp.copy_shared_buffer(in);
copy_inplace(in_tmp, tmp, CopyType::General);
copy_inplace(in_tmp, tmp, CopyType::General, stream());
} else {
copy_inplace(in, tmp, CopyType::General);
copy_inplace(in, tmp, CopyType::General, stream());
}

auto flags = out.flags();
@@ -382,7 +464,7 @@ void View::eval_cpu(const std::vector<array>& inputs, array& out) {
flags.row_contiguous = true;
auto max_dim = std::max_element(out.shape().begin(), out.shape().end());
flags.col_contiguous = out.size() <= 1 || out.size() == *max_dim;
out.move_shared_buffer(tmp, out.strides(), flags, out.size());
out.copy_shared_buffer(tmp, out.strides(), flags, out.size());
}
}


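// A small detail worth flagging from RandomBits::eval_cpu above: the
// dispatched lambda is marked mutable because it advances its own copies of
// the captured pointers (cptr += bytes_per_key) as it walks the keys.
// Minimal illustration of the same idiom, using the TaskQueue assumption:
#include <cstdint>

void emit_pattern(TaskQueue& q, uint32_t* out, int n) {
  q.dispatch([out, n]() mutable {
    for (int i = 0; i < n; ++i) {
      *out++ = static_cast<uint32_t>(i); // mutates the captured copy of out
    }
  });
}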
@@ -2,20 +2,18 @@

#include "mlx/allocator.h"
#include "mlx/backend/cpu/copy.h"
#include "mlx/backend/cpu/encoder.h"
#include "mlx/backend/cpu/lapack.h"
#include "mlx/primitives.h"

namespace mlx::core {

template <typename T>
void qrf_impl(const array& a, array& q, array& r) {
void qrf_impl(const array& a, array& q, array& r, Stream stream) {
const int M = a.shape(-2);
const int N = a.shape(-1);
const int lda = M;
size_t num_matrices = a.size() / (M * N);
int num_reflectors = std::min(M, N);
auto tau =
allocator::malloc_or_wait(sizeof(T) * num_matrices * num_reflectors);

// Copy A to inplace input and make it col-contiguous
array in(a.shape(), a.dtype(), nullptr, {});
@@ -29,93 +27,107 @@ void qrf_impl(const array& a, array& q, array& r) {
strides[in.ndim() - 1] = M;
in.set_data(
allocator::malloc_or_wait(in.nbytes()), in.nbytes(), strides, flags);
copy_inplace(a, in, CopyType::GeneralGeneral);

T optimal_work;
int lwork = -1;
int info;

// Compute workspace size
geqrf<T>(&M, &N, nullptr, &lda, nullptr, &optimal_work, &lwork, &info);

// Update workspace size
lwork = optimal_work;
auto work = allocator::malloc_or_wait(sizeof(T) * lwork);

// Loop over matrices
for (int i = 0; i < num_matrices; ++i) {
// Solve
geqrf<T>(
&M,
&N,
in.data<T>() + M * N * i,
&lda,
static_cast<T*>(tau.raw_ptr()) + num_reflectors * i,
static_cast<T*>(work.raw_ptr()),
&lwork,
&info);
}
allocator::free(work);

copy_inplace(a, in, CopyType::GeneralGeneral, stream);
auto& encoder = cpu::get_command_encoder(stream);
q.set_data(allocator::malloc_or_wait(q.nbytes()));
r.set_data(allocator::malloc_or_wait(r.nbytes()));

for (int i = 0; i < num_matrices; ++i) {
/// num_reflectors x N
for (int j = 0; j < r.shape(-2); ++j) {
for (int k = 0; k < j; ++k) {
r.data<T>()[i * N * num_reflectors + j * N + k] = 0;
}
for (int k = j; k < r.shape(-1); ++k) {
r.data<T>()[i * N * num_reflectors + j * N + k] =
in.data<T>()[i * N * M + j + k * M];
auto in_ptr = in.data<T>();
auto r_ptr = r.data<T>();
auto q_ptr = q.data<T>();

encoder.set_input_array(in);
encoder.set_output_array(q);
encoder.set_output_array(r);
encoder.dispatch([in_ptr, q_ptr, r_ptr, M, N, lda, num_matrices]() {
int num_reflectors = std::min(M, N);
auto tau =
allocator::malloc_or_wait(sizeof(T) * num_matrices * num_reflectors);

T optimal_work;
int lwork = -1;
int info;

// Compute workspace size
geqrf<T>(&M, &N, nullptr, &lda, nullptr, &optimal_work, &lwork, &info);

// Update workspace size
lwork = optimal_work;
auto work = allocator::malloc_or_wait(sizeof(T) * lwork);

// Loop over matrices
for (int i = 0; i < num_matrices; ++i) {
// Solve
geqrf<T>(
&M,
&N,
in_ptr + M * N * i,
&lda,
static_cast<T*>(tau.raw_ptr()) + num_reflectors * i,
static_cast<T*>(work.raw_ptr()),
&lwork,
&info);
}
allocator::free(work);

for (int i = 0; i < num_matrices; ++i) {
/// num_reflectors x N
for (int j = 0; j < num_reflectors; ++j) {
for (int k = 0; k < j; ++k) {
r_ptr[i * N * num_reflectors + j * N + k] = 0;
}
for (int k = j; k < N; ++k) {
r_ptr[i * N * num_reflectors + j * N + k] =
in_ptr[i * N * M + j + k * M];
}
}
}
}

// Get work size
lwork = -1;
orgqr<T>(
&M,
&num_reflectors,
&num_reflectors,
nullptr,
&lda,
nullptr,
&optimal_work,
&lwork,
&info);
lwork = optimal_work;
work = allocator::malloc_or_wait(sizeof(T) * lwork);

// Loop over matrices
for (int i = 0; i < num_matrices; ++i) {
// Compute Q
// Get work size
lwork = -1;
orgqr<T>(
&M,
&num_reflectors,
&num_reflectors,
in.data<T>() + M * N * i,
nullptr,
&lda,
static_cast<T*>(tau.raw_ptr()) + num_reflectors * i,
static_cast<T*>(work.raw_ptr()),
nullptr,
&optimal_work,
&lwork,
&info);
}
lwork = optimal_work;
work = allocator::malloc_or_wait(sizeof(T) * lwork);

q.set_data(allocator::malloc_or_wait(q.nbytes()));
for (int i = 0; i < num_matrices; ++i) {
// M x num_reflectors
for (int j = 0; j < q.shape(-2); ++j) {
for (int k = 0; k < q.shape(-1); ++k) {
q.data<T>()[i * M * num_reflectors + j * num_reflectors + k] =
in.data<T>()[i * N * M + j + k * M];
// Loop over matrices
for (int i = 0; i < num_matrices; ++i) {
// Compute Q
orgqr<T>(
&M,
&num_reflectors,
&num_reflectors,
in_ptr + M * N * i,
&lda,
static_cast<T*>(tau.raw_ptr()) + num_reflectors * i,
static_cast<T*>(work.raw_ptr()),
&lwork,
&info);
}

for (int i = 0; i < num_matrices; ++i) {
// M x num_reflectors
for (int j = 0; j < M; ++j) {
for (int k = 0; k < num_reflectors; ++k) {
q_ptr[i * M * num_reflectors + j * num_reflectors + k] =
in_ptr[i * N * M + j + k * M];
}
}
}
}

// Cleanup
allocator::free(work);
allocator::free(tau);
// Cleanup
allocator::free(work);
allocator::free(tau);
});
encoder.add_temporary(in);
}

void QRF::eval_cpu(
@@ -123,10 +135,10 @@ void QRF::eval_cpu(
std::vector<array>& outputs) {
switch (inputs[0].dtype()) {
case float32:
qrf_impl<float>(inputs[0], outputs[0], outputs[1]);
qrf_impl<float>(inputs[0], outputs[0], outputs[1], stream());
break;
case float64:
qrf_impl<double>(inputs[0], outputs[0], outputs[1]);
qrf_impl<double>(inputs[0], outputs[0], outputs[1], stream());
break;
default:
throw std::runtime_error(

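// The qrf_impl body above keeps LAPACK's two-phase workspace convention, now
// entirely inside the dispatched task: calling a routine with lwork == -1
// performs no factorization and only writes the optimal workspace size into
// work[0]. A standalone sketch of that convention against the standard
// Fortran symbol (requires linking a LAPACK implementation; error handling
// elided for brevity):
#include <algorithm>
#include <vector>

extern "C" void sgeqrf_(
    const int* m, const int* n, float* a, const int* lda,
    float* tau, float* work, const int* lwork, int* info);

void qr_factor(std::vector<float>& a_colmajor, int m, int n) {
  std::vector<float> tau(std::min(m, n));
  int info = 0;
  float query = 0.0f;
  int lwork = -1; // workspace query: nothing is factorized yet
  sgeqrf_(&m, &n, a_colmajor.data(), &m, tau.data(), &query, &lwork, &info);
  lwork = static_cast<int>(query);
  std::vector<float> work(lwork);
  sgeqrf_(&m, &n, a_colmajor.data(), &m, tau.data(), work.data(), &lwork,
          &info); // actual in-place QR of the m x n column-major matrix
}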
@@ -3,6 +3,7 @@
|
||||
#include <cassert>
|
||||
|
||||
#include "mlx/backend/cpu/copy.h"
|
||||
#include "mlx/backend/cpu/encoder.h"
|
||||
#include "mlx/backend/cpu/simd/simd.h"
|
||||
#include "mlx/fast_primitives.h"
|
||||
#include "mlx/primitives.h"
|
||||
@@ -316,6 +317,76 @@ void _qmm_dispatch_typed(
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void _qmm_dispatch_typed(
|
||||
array& out,
|
||||
const array& x,
|
||||
const array& w,
|
||||
const array& scales,
|
||||
const array& biases,
|
||||
int bits,
|
||||
int group_size,
|
||||
bool transposed_w,
|
||||
Stream stream) {
|
||||
int K = x.shape(-1);
|
||||
int M = x.ndim() > 1 ? x.shape(-2) : 1;
|
||||
int N = out.shape(-1);
|
||||
int w_els = w.ndim() > 2 ? w.shape(-1) * w.shape(-2) : 0;
|
||||
int g_els = w.ndim() > 2 ? scales.shape(-1) * scales.shape(-2) : 0;
|
||||
int batch_size = x.size() / (K * M);
|
||||
|
||||
auto& encoder = cpu::get_command_encoder(stream);
|
||||
encoder.set_input_array(x);
|
||||
encoder.set_input_array(w);
|
||||
encoder.set_input_array(scales);
|
||||
encoder.set_input_array(biases);
|
||||
encoder.set_output_array(out);
|
||||
|
||||
auto out_ptr = out.data<T>();
|
||||
auto x_ptr = x.data<T>();
|
||||
auto w_ptr = w.data<uint32_t>();
|
||||
auto scales_ptr = scales.data<T>();
|
||||
auto biases_ptr = biases.data<T>();
|
||||
|
||||
encoder.dispatch([out_ptr,
|
||||
x_ptr,
|
||||
w_ptr,
|
||||
scales_ptr,
|
||||
biases_ptr,
|
||||
x_shape = x.shape(),
|
||||
x_strides = x.strides(),
|
||||
w_shape = w.shape(),
|
||||
w_strides = w.strides(),
|
||||
scales_shape = scales.shape(),
|
||||
scales_strides = scales.strides(),
|
||||
biases_shape = biases.shape(),
|
||||
biases_strides = biases.strides(),
|
||||
w_els,
|
||||
g_els,
|
||||
batch_size,
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
bits,
|
||||
group_size,
|
||||
transposed_w] {
|
||||
for (int i = 0; i < batch_size; i++) {
|
||||
_qmm_dispatch_typed<T>(
|
||||
out_ptr + i * M * N,
|
||||
x_ptr + elem_to_loc(i * M * K, x_shape, x_strides),
|
||||
w_ptr + elem_to_loc(i * w_els, w_shape, w_strides),
|
||||
scales_ptr + elem_to_loc(i * g_els, scales_shape, scales_strides),
|
||||
biases_ptr + elem_to_loc(i * g_els, biases_shape, biases_strides),
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
bits,
|
||||
group_size,
|
||||
transposed_w);
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
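Note: the shapes and strides are copied into the closure (x_shape = x.shape() and so on) because the task may run after eval_cpu has returned, so it must not touch the array objects themselves. The per-batch offset arithmetic is then plain elem_to_loc over the copied metadata; a sketch, where the helper name batch_offset is illustrative:

// Map flat batch index i to the memory offset of its first element, using
// only POD copies of the metadata (safe to call inside a deferred task).
template <typename ShapeT, typename StridesT>
size_t batch_offset(
    int i, int M, int K, const ShapeT& shape, const StridesT& strides) {
  return elem_to_loc(i * M * K, shape, strides);
}
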
void _qmm_dispatch(
    array& out,
    const array& x,
@@ -324,64 +395,111 @@ void _qmm_dispatch(
    const array& biases,
    int bits,
    int group_size,
    bool transposed_w) {
    bool transposed_w,
    Stream stream) {
  switch (x.dtype()) {
    case float32:
      _qmm_dispatch_typed<float>(
          out, x, w, scales, biases, bits, group_size, transposed_w, stream);
      break;
    case float16:
      _qmm_dispatch_typed<float16_t>(
          out, x, w, scales, biases, bits, group_size, transposed_w, stream);
      break;
    case bfloat16:
      _qmm_dispatch_typed<bfloat16_t>(
          out, x, w, scales, biases, bits, group_size, transposed_w, stream);
      break;
    default:
      throw std::invalid_argument(
          "[quantized_matmul] only floating types are supported");
  }
}

template <typename T>
void _bs_qmm_dispatch_typed(
    array& out,
    const array& x,
    const array& w,
    const array& scales,
    const array& biases,
    const array& lhs_indices,
    const array& rhs_indices,
    int bits,
    int group_size,
    bool transposed_w,
    Stream stream) {
  int K = x.shape(-1);
  int M = x.ndim() > 1 ? x.shape(-2) : 1;
  int M = x.shape(-2);
  int N = out.shape(-1);

  int w_els = w.ndim() > 2 ? w.shape(-1) * w.shape(-2) : 0;
  int g_els = w.ndim() > 2 ? scales.shape(-1) * scales.shape(-2) : 0;
  int w_els = w.shape(-1) * w.shape(-2);
  int g_els = scales.shape(-1) * scales.shape(-2);

  int batch_size = x.size() / (K * M);
  for (int i = 0; i < batch_size; i++) {
    switch (x.dtype()) {
      case float32:
        _qmm_dispatch_typed<float>(
            out.data<float>() + i * M * N,
            x.data<float>() + elem_to_loc(i * M * K, x),
            w.data<uint32_t>() + elem_to_loc(i * w_els, w),
            scales.data<float>() + elem_to_loc(i * g_els, scales),
            biases.data<float>() + elem_to_loc(i * g_els, biases),
            M,
            N,
            K,
            bits,
            group_size,
            transposed_w);
        break;
      case float16:
        _qmm_dispatch_typed<float16_t>(
            out.data<float16_t>() + i * M * N,
            x.data<float16_t>() + elem_to_loc(i * M * K, x),
            w.data<uint32_t>() + elem_to_loc(i * w_els, w),
            scales.data<float16_t>() + elem_to_loc(i * g_els, scales),
            biases.data<float16_t>() + elem_to_loc(i * g_els, biases),
            M,
            N,
            K,
            bits,
            group_size,
            transposed_w);
        break;
      case bfloat16:
        _qmm_dispatch_typed<bfloat16_t>(
            out.data<bfloat16_t>() + i * M * N,
            x.data<bfloat16_t>() + elem_to_loc(i * M * K, x),
            w.data<uint32_t>() + elem_to_loc(i * w_els, w),
            scales.data<bfloat16_t>() + elem_to_loc(i * g_els, scales),
            biases.data<bfloat16_t>() + elem_to_loc(i * g_els, biases),
            M,
            N,
            K,
            bits,
            group_size,
            transposed_w);
        break;
      default:
        throw std::invalid_argument(
            "[quantized_matmul] only floating types are supported");
  auto& encoder = cpu::get_command_encoder(stream);
  encoder.set_input_array(x);
  encoder.set_input_array(w);
  encoder.set_input_array(scales);
  encoder.set_input_array(biases);
  encoder.set_input_array(lhs_indices);
  encoder.set_input_array(rhs_indices);
  encoder.set_output_array(out);

  auto out_ptr = out.data<T>();
  auto x_ptr = x.data<T>();
  auto w_ptr = w.data<uint32_t>();
  auto scales_ptr = scales.data<T>();
  auto biases_ptr = biases.data<T>();
  auto lhs_indices_ptr = lhs_indices.data<uint32_t>();
  auto rhs_indices_ptr = rhs_indices.data<uint32_t>();

  encoder.dispatch([out_ptr,
                    x_ptr,
                    w_ptr,
                    scales_ptr,
                    biases_ptr,
                    lhs_indices_ptr,
                    rhs_indices_ptr,
                    x_shape = x.shape(),
                    x_strides = x.strides(),
                    w_shape = w.shape(),
                    w_strides = w.strides(),
                    scales_shape = scales.shape(),
                    scales_strides = scales.strides(),
                    biases_shape = biases.shape(),
                    biases_strides = biases.strides(),
                    lhs_indices_shape = lhs_indices.shape(),
                    lhs_indices_strides = lhs_indices.strides(),
                    rhs_indices_shape = rhs_indices.shape(),
                    rhs_indices_strides = rhs_indices.strides(),
                    w_els,
                    g_els,
                    indices_size = lhs_indices.size(),
                    M,
                    N,
                    K,
                    bits,
                    group_size,
                    transposed_w]() {
    for (int i = 0; i < indices_size; i++) {
      int x_idx = lhs_indices_ptr[elem_to_loc(
          i, lhs_indices_shape, lhs_indices_strides)];
      int w_idx = rhs_indices_ptr[elem_to_loc(
          i, rhs_indices_shape, rhs_indices_strides)];
      _qmm_dispatch_typed<T>(
          out_ptr + i * M * N,
          x_ptr + elem_to_loc(x_idx * M * K, x_shape, x_strides),
          w_ptr + elem_to_loc(w_idx * w_els, w_shape, w_strides),
          scales_ptr + elem_to_loc(w_idx * g_els, scales_shape, scales_strides),
          biases_ptr + elem_to_loc(w_idx * g_els, biases_shape, biases_strides),
          M,
          N,
          K,
          bits,
          group_size,
          transposed_w);
    }
  }
  });
}

void _bs_qmm_dispatch(
@@ -394,68 +512,54 @@ void _bs_qmm_dispatch(
    const array& rhs_indices,
    int bits,
    int group_size,
    bool transposed_w) {
  int K = x.shape(-1);
  int M = x.shape(-2);
  int N = out.shape(-1);

  int w_els = w.shape(-1) * w.shape(-2);
  int g_els = scales.shape(-1) * scales.shape(-2);

  const uint32_t* lhs_indices_data = lhs_indices.data<uint32_t>();
  const uint32_t* rhs_indices_data = rhs_indices.data<uint32_t>();

  for (int i = 0; i < lhs_indices.size(); i++) {
    int x_idx = lhs_indices_data[elem_to_loc(i, lhs_indices)];
    int w_idx = rhs_indices_data[elem_to_loc(i, rhs_indices)];

    switch (x.dtype()) {
      case float32:
        _qmm_dispatch_typed<float>(
            out.data<float>() + i * M * N,
            x.data<float>() + elem_to_loc(x_idx * M * K, x),
            w.data<uint32_t>() + elem_to_loc(w_idx * w_els, w),
            scales.data<float>() + elem_to_loc(w_idx * g_els, scales),
            biases.data<float>() + elem_to_loc(w_idx * g_els, biases),
            M,
            N,
            K,
            bits,
            group_size,
            transposed_w);
        break;
      case float16:
        _qmm_dispatch_typed<float16_t>(
            out.data<float16_t>() + i * M * N,
            x.data<float16_t>() + elem_to_loc(x_idx * M * K, x),
            w.data<uint32_t>() + elem_to_loc(w_idx * w_els, w),
            scales.data<float16_t>() + elem_to_loc(w_idx * g_els, scales),
            biases.data<float16_t>() + elem_to_loc(w_idx * g_els, biases),
            M,
            N,
            K,
            bits,
            group_size,
            transposed_w);
        break;
      case bfloat16:
        _qmm_dispatch_typed<bfloat16_t>(
            out.data<bfloat16_t>() + i * M * N,
            x.data<bfloat16_t>() + elem_to_loc(x_idx * M * K, x),
            w.data<uint32_t>() + elem_to_loc(w_idx * w_els, w),
            scales.data<bfloat16_t>() + elem_to_loc(w_idx * g_els, scales),
            biases.data<bfloat16_t>() + elem_to_loc(w_idx * g_els, biases),
            M,
            N,
            K,
            bits,
            group_size,
            transposed_w);
        break;
      default:
        throw std::invalid_argument(
            "[quantized_matmul] only floating types are supported");
    }
    bool transposed_w,
    Stream stream) {
  switch (x.dtype()) {
    case float32:
      _bs_qmm_dispatch_typed<float>(
          out,
          x,
          w,
          scales,
          biases,
          lhs_indices,
          rhs_indices,
          bits,
          group_size,
          transposed_w,
          stream);
      break;
    case float16:
      _bs_qmm_dispatch_typed<float16_t>(
          out,
          x,
          w,
          scales,
          biases,
          lhs_indices,
          rhs_indices,
          bits,
          group_size,
          transposed_w,
          stream);
      break;
    case bfloat16:
      _bs_qmm_dispatch_typed<bfloat16_t>(
          out,
          x,
          w,
          scales,
          biases,
          lhs_indices,
          rhs_indices,
          bits,
          group_size,
          transposed_w,
          stream);
      break;
    default:
      throw std::invalid_argument(
          "[quantized_matmul] only floating types are supported");
  }
}

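Note: in the gathered path above, output block i does not read batch i of x. The index arrays select the operands, which is why both lhs_indices and rhs_indices are registered as encoder inputs. Restating the lookup performed inside the dispatched loop:

// For output block i (recap of the loop above):
//   x_idx = lhs_indices[i]  selects the x batch:
//       x_ptr + elem_to_loc(x_idx * M * K, x_shape, x_strides)
//   w_idx = rhs_indices[i]  selects the quantized weight matrix:
//       w_ptr + elem_to_loc(w_idx * w_els, w_shape, w_strides)
//   and the same w_idx (with g_els) selects the matching scales and biases.
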
@@ -469,13 +573,14 @@ void QuantizedMatmul::eval_cpu(const std::vector<array>& inputs, array& out) {
  auto& scales_pre = inputs[2];
  auto& biases_pre = inputs[3];

  auto ensure_row_contiguous = [](const array& arr) {
  std::vector<array> temps;
  auto ensure_row_contiguous = [s = stream(), &temps](const array& arr) {
    if (arr.flags().row_contiguous) {
      return arr;
    } else {
      array arr_copy(arr.shape(), arr.dtype(), nullptr, {});
      copy(arr, arr_copy, CopyType::General);
      return arr_copy;
      temps.push_back(array(arr.shape(), arr.dtype(), nullptr, {}));
      copy(arr, temps.back(), CopyType::General, s);
      return temps.back();
    }
  };

@@ -485,7 +590,10 @@ void QuantizedMatmul::eval_cpu(const std::vector<array>& inputs, array& out) {
  auto biases = ensure_row_contiguous(biases_pre);

  out.set_data(allocator::malloc_or_wait(out.nbytes()));
  _qmm_dispatch(out, x, w, scales, biases, group_size_, bits_, transpose_);
  _qmm_dispatch(
      out, x, w, scales, biases, group_size_, bits_, transpose_, stream());
  auto& enc = cpu::get_command_encoder(stream());
  enc.add_temporaries(std::move(temps));
}

void GatherQMM::eval_cpu(const std::vector<array>& inputs, array& out) {
@@ -498,15 +606,17 @@ void GatherQMM::eval_cpu(const std::vector<array>& inputs, array& out) {
  auto& lhs_indices = inputs[4];
  auto& rhs_indices = inputs[5];

  auto ensure_row_contiguous_last_dims = [](const array& arr) {
  std::vector<array> temps;
  auto ensure_row_contiguous_last_dims = [s = stream(),
                                          &temps](const array& arr) {
    auto stride_0 = arr.strides()[arr.ndim() - 2];
    auto stride_1 = arr.strides()[arr.ndim() - 1];
    if (stride_0 == arr.shape(-1) && stride_1 == 1) {
      return arr;
    } else {
      array arr_copy(arr.shape(), arr.dtype(), nullptr, {});
      copy(arr, arr_copy, CopyType::General);
      return arr_copy;
      temps.push_back(array(arr.shape(), arr.dtype(), nullptr, {}));
      copy(arr, temps.back(), CopyType::General, s);
      return temps.back();
    }
  };

@@ -526,31 +636,30 @@ void GatherQMM::eval_cpu(const std::vector<array>& inputs, array& out) {
      rhs_indices,
      group_size_,
      bits_,
      transpose_);
      transpose_,
      stream());
  auto& enc = cpu::get_command_encoder(stream());
  enc.add_temporaries(std::move(temps));
}

template <typename T, typename U>
void quantize(
    const array& w_,
    array& out_,
    array& scales_,
    array& biases_,
    const T* w,
    U* out,
    T* scales,
    T* biases,
    int bits,
    int group_size) {
  const T* w = w_.data<T>();

  auto out = out_.data<U>();
  T* scales = scales_.data<T>();
  T* biases = biases_.data<T>();

    int group_size,
    size_t w_size) {
  float n_bins = (1 << bits) - 1;
  float eps = 1e-7;

  bool power_of_2_bits = is_power_of_2(bits);
  int el_per_int = bits == 3 ? 8 : bits == 6 ? 4 : 32 / bits;
  // For 3/6 bits we read 3 uint8s at a time instead of 1 uint32
  int bytes_per_pack = power_of_2_bits ? 1 : 3;
  int int_per_group = group_size * bytes_per_pack / el_per_int;
  size_t n_groups = w_.size() / group_size;
  size_t n_groups = w_size / group_size;

  for (size_t i = 0; i < n_groups; ++i) {
    size_t w_idx = i * group_size;
@@ -593,20 +702,50 @@ void quantize(
  }
}

template <typename T, typename U>
void dispatch_quantize(
    const array& w,
    array& out,
    array& scales,
    array& biases,
    int bits,
    int group_size,
    Stream stream) {
  auto w_ptr = w.data<T>();
  auto out_ptr = out.data<U>();
  auto scales_ptr = scales.data<T>();
  auto biases_ptr = biases.data<T>();
  auto& encoder = cpu::get_command_encoder(stream);
  encoder.set_input_array(w);
  encoder.set_input_array(scales);
  encoder.set_input_array(biases);
  encoder.set_output_array(out);
  encoder.dispatch([w_ptr,
                    out_ptr,
                    scales_ptr,
                    biases_ptr,
                    bits,
                    group_size,
                    w_size = w.size()]() {
    quantize<T, U>(
        w_ptr, out_ptr, scales_ptr, biases_ptr, bits, group_size, w_size);
  });
}

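Note: quantize() now takes w_size instead of reading w_.size() so that only raw pointers and scalars cross into the dispatched task. The packing arithmetic it uses is worth a worked example; these numbers follow directly from the expressions above:

// bits = 4 (power of two): el_per_int = 32 / 4 = 8, bytes_per_pack = 1,
//   so a group of 64 values packs into 64 * 1 / 8 = 8 uint32 packs.
// bits = 3 (not a power of two): el_per_int = 8, bytes_per_pack = 3,
//   so a group of 64 values packs into 64 * 3 / 8 = 24 uint8s (3 bytes
//   carry 8 three-bit values), and n_bins = (1 << 3) - 1 = 7.
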
void fast::AffineQuantize::eval_cpu(
    const std::vector<array>& inputs,
    std::vector<array>& outputs) {
  auto ensure_row_contiguous = [](const array& arr) {
  auto ensure_row_contiguous = [s = stream()](const array& arr) {
    if (arr.flags().row_contiguous) {
      return arr;
      return std::make_pair(arr, false);
    } else {
      array arr_copy(arr.shape(), arr.dtype(), nullptr, {});
      copy(arr, arr_copy, CopyType::General);
      return arr_copy;
      copy(arr, arr_copy, CopyType::General, s);
      return std::make_pair(arr_copy, true);
    }
  };
  auto w = ensure_row_contiguous(inputs[0]);

  auto [w, copied] = ensure_row_contiguous(inputs[0]);
  auto& out = outputs[0];
  out.set_data(allocator::malloc_or_wait(out.nbytes()));

@@ -616,27 +755,35 @@ void fast::AffineQuantize::eval_cpu(
  biases.set_data(allocator::malloc_or_wait(biases.nbytes()));
  if (w.dtype() == float16) {
    if (is_power_of_2(bits_)) {
      quantize<float16_t, uint32_t>(w, out, scales, biases, bits_, group_size_);
      dispatch_quantize<float16_t, uint32_t>(
          w, out, scales, biases, bits_, group_size_, stream());
    } else {
      quantize<float16_t, uint8_t>(w, out, scales, biases, bits_, group_size_);
      dispatch_quantize<float16_t, uint8_t>(
          w, out, scales, biases, bits_, group_size_, stream());
    }
  } else if (w.dtype() == bfloat16) {
    if (is_power_of_2(bits_)) {
      quantize<bfloat16_t, uint32_t>(
          w, out, scales, biases, bits_, group_size_);
      dispatch_quantize<bfloat16_t, uint32_t>(
          w, out, scales, biases, bits_, group_size_, stream());
    } else {
      quantize<bfloat16_t, uint8_t>(w, out, scales, biases, bits_, group_size_);
      dispatch_quantize<bfloat16_t, uint8_t>(
          w, out, scales, biases, bits_, group_size_, stream());
    }
  } else if (w.dtype() == float32) {
    if (is_power_of_2(bits_)) {
      quantize<float, uint32_t>(w, out, scales, biases, bits_, group_size_);
      dispatch_quantize<float, uint32_t>(
          w, out, scales, biases, bits_, group_size_, stream());
    } else {
      quantize<float, uint8_t>(w, out, scales, biases, bits_, group_size_);
      dispatch_quantize<float, uint8_t>(
          w, out, scales, biases, bits_, group_size_, stream());
    }
  } else {
    throw std::runtime_error(
        "[fast::AffineQuantize::eval_cpu] Only supports floating point inputs");
  }
  if (copied) {
    cpu::get_command_encoder(stream()).add_temporary(w);
  }
}

} // namespace mlx::core

@@ -5,6 +5,7 @@
#include <limits>

#include "mlx/backend/common/reduce.h"
#include "mlx/backend/cpu/encoder.h"
#include "mlx/backend/cpu/simd/simd.h"
#include "mlx/primitives.h"

@@ -140,25 +141,33 @@ void reduction_op(
    array& out,
    const std::vector<int>& axes,
    U init,
    Op op) {
    Stream stream) {
  out.set_data(allocator::malloc_or_wait(out.nbytes()));
  ReductionPlan plan = get_reduction_plan(x, axes);

  auto& encoder = cpu::get_command_encoder(stream);
  encoder.set_input_array(x);
  encoder.set_output_array(out);

  auto in_ptr = x.data<T>();
  auto out_ptr = out.data<U>();
  if (plan.type == ContiguousAllReduce) {
    U* out_ptr = out.data<U>();
    *out_ptr = init;
    contiguous_reduce(x.data<T>(), out_ptr, x.size(), op, init);
    encoder.dispatch([in_ptr, out_ptr, init, size = x.size()]() {
      *out_ptr = init;
      contiguous_reduce(in_ptr, out_ptr, size, Op{}, init);
    });
    return;
  }

  if (plan.type == ContiguousReduce && plan.shape.size() == 1) {
    int reduction_size = plan.shape[0];
    const T* x_ptr = x.data<T>();
    U* out_ptr = out.data<U>();
    for (int i = 0; i < out.size(); i++, out_ptr++, x_ptr += reduction_size) {
      *out_ptr = init;
      contiguous_reduce(x_ptr, out_ptr, reduction_size, op, init);
    }
    encoder.dispatch(
        [in_ptr, out_ptr, init, reduction_size, size = out.size()]() mutable {
          for (int i = 0; i < size; i++, out_ptr++, in_ptr += reduction_size) {
            *out_ptr = init;
            contiguous_reduce(in_ptr, out_ptr, reduction_size, Op{}, init);
          }
        });
    return;
  }

@@ -166,34 +175,43 @@ void reduction_op(
    int reduction_size = plan.shape.back();
    plan.shape.pop_back();
    plan.strides.pop_back();
    const T* x_ptr = x.data<T>();
    U* out_ptr = out.data<U>();
    // Unrolling the following loop (and implementing it in order for
    // ContiguousReduce) should hold extra performance boost.
    auto [shape, strides] = shapes_without_reduction_axes(x, axes);
    if (plan.shape.size() == 0) {
      for (int i = 0; i < out.size(); i++, out_ptr++) {
        int offset = elem_to_loc(i, shape, strides);
        *out_ptr = init;
        contiguous_reduce(x_ptr + offset, out_ptr, reduction_size, op, init);

    encoder.dispatch([in_ptr,
                      out_ptr,
                      init,
                      reduction_size,
                      size = out.size(),
                      plan = std::move(plan),
                      shape = std::move(shape),
                      strides = std::move(strides)]() mutable {
      if (plan.shape.size() == 0) {
        for (int i = 0; i < size; i++, out_ptr++) {
          int offset = elem_to_loc(i, shape, strides);
          *out_ptr = init;
          contiguous_reduce(
              in_ptr + offset, out_ptr, reduction_size, Op{}, init);
        }
      } else {
        for (int i = 0; i < size; i++, out_ptr++) {
          int offset = elem_to_loc(i, shape, strides);
          *out_ptr = init;
          nd_loop(
              [&](int extra_offset) {
                contiguous_reduce(
                    in_ptr + offset + extra_offset,
                    out_ptr,
                    reduction_size,
                    Op{},
                    init);
              },
              plan.shape,
              plan.strides);
        }
      }
    } else {
      for (int i = 0; i < out.size(); i++, out_ptr++) {
        int offset = elem_to_loc(i, shape, strides);
        *out_ptr = init;
        nd_loop(
            [&](int extra_offset) {
              contiguous_reduce(
                  x_ptr + offset + extra_offset,
                  out_ptr,
                  reduction_size,
                  op,
                  init);
            },
            plan.shape,
            plan.strides);
      }
    }
    });
    return;
  }

@@ -202,14 +220,20 @@ void reduction_op(
    size_t reduction_stride = plan.strides.back();
    plan.shape.pop_back();
    plan.strides.pop_back();
    const T* x_ptr = x.data<T>();
    U* out_ptr = out.data<U>();
    for (int i = 0; i < out.size(); i += reduction_stride) {
      std::fill_n(out_ptr, reduction_stride, init);
      strided_reduce(x_ptr, out_ptr, reduction_size, reduction_stride, op);
      x_ptr += reduction_stride * reduction_size;
      out_ptr += reduction_stride;
    }

    encoder.dispatch([in_ptr,
                      out_ptr,
                      init,
                      reduction_size,
                      reduction_stride,
                      size = out.size()]() mutable {
      for (int i = 0; i < size; i += reduction_stride) {
        std::fill_n(out_ptr, reduction_stride, init);
        strided_reduce(in_ptr, out_ptr, reduction_size, reduction_stride, Op{});
        in_ptr += reduction_stride * reduction_size;
        out_ptr += reduction_stride;
      }
    });
    return;
  }

@@ -219,53 +243,69 @@ void reduction_op(
    size_t reduction_stride = plan.strides.back();
    plan.shape.pop_back();
    plan.strides.pop_back();
    const T* x_ptr = x.data<T>();
    U* out_ptr = out.data<U>();
    auto [shape, strides] = shapes_without_reduction_axes(x, axes);
    if (plan.shape.size() == 0) {
      for (int i = 0; i < out.size(); i += reduction_stride) {
        int offset = elem_to_loc(i, shape, strides);
        std::fill_n(out_ptr, reduction_stride, init);
        strided_reduce(
            x_ptr + offset, out_ptr, reduction_size, reduction_stride, op);
        out_ptr += reduction_stride;

    encoder.dispatch([in_ptr,
                      out_ptr,
                      init,
                      reduction_size,
                      reduction_stride,
                      size = out.size(),
                      plan = std::move(plan),
                      shape = std::move(shape),
                      strides = std::move(strides)]() mutable {
      if (plan.shape.size() == 0) {
        for (int i = 0; i < size; i += reduction_stride) {
          int offset = elem_to_loc(i, shape, strides);
          std::fill_n(out_ptr, reduction_stride, init);
          strided_reduce(
              in_ptr + offset, out_ptr, reduction_size, reduction_stride, Op{});
          out_ptr += reduction_stride;
        }
      } else {
        for (int i = 0; i < size; i += reduction_stride) {
          int offset = elem_to_loc(i, shape, strides);
          std::fill_n(out_ptr, reduction_stride, init);
          nd_loop(
              [&](int extra_offset) {
                strided_reduce(
                    in_ptr + offset + extra_offset,
                    out_ptr,
                    reduction_size,
                    reduction_stride,
                    Op{});
              },
              plan.shape,
              plan.strides);
          out_ptr += reduction_stride;
        }
      }
    } else {
      for (int i = 0; i < out.size(); i += reduction_stride) {
        int offset = elem_to_loc(i, shape, strides);
        std::fill_n(out_ptr, reduction_stride, init);
        nd_loop(
            [&](int extra_offset) {
              strided_reduce(
                  x_ptr + offset + extra_offset,
                  out_ptr,
                  reduction_size,
                  reduction_stride,
                  op);
            },
            plan.shape,
            plan.strides);
        out_ptr += reduction_stride;
      }
    }
    });
    return;
  }

  if (plan.type == GeneralReduce) {
    const T* x_ptr = x.data<T>();
    U* out_ptr = out.data<U>();
    auto [shape, strides] = shapes_without_reduction_axes(x, axes);
    for (int i = 0; i < out.size(); i++, out_ptr++) {
      int offset = elem_to_loc(i, shape, strides);
      U val = init;
      nd_loop(
          [&](int extra_offset) {
            val = op(val, *(x_ptr + offset + extra_offset));
          },
          plan.shape,
          plan.strides);
      *out_ptr = val;
    }

    encoder.dispatch([in_ptr,
                      out_ptr,
                      init,
                      size = out.size(),
                      plan = std::move(plan),
                      shape = std::move(shape),
                      strides = std::move(strides)]() mutable {
      for (int i = 0; i < size; i++, out_ptr++) {
        int offset = elem_to_loc(i, shape, strides);
        U val = init;
        nd_loop(
            [&](int extra_offset) {
              val = Op{}(val, *(in_ptr + offset + extra_offset));
            },
            plan.shape,
            plan.strides);
        *out_ptr = val;
      }
    });
  }
}

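Note: reduction_op now takes Op as a template parameter rather than a value, so each dispatched closure can rebuild the operator as Op{} and only trivially-copyable state (pointers, sizes, init) crosses into the deferred task. Any default-constructible, stateless functor fits this contract; an illustrative reducer in that style (not the mlx one):

// A stateless reducer usable as the Op template parameter.
struct MaxOpExample {
  template <typename U, typename T>
  U operator()(U acc, T v) const {
    return static_cast<U>(v) > acc ? static_cast<U>(v) : acc;
  }
};
// e.g. reduction_op<float, float, MaxOpExample>(in, out, axes, init, stream);
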
@@ -394,11 +434,12 @@ void reduce_dispatch_and_or(
    const array& in,
    array& out,
    Reduce::ReduceType rtype,
    const std::vector<int>& axes) {
    const std::vector<int>& axes,
    Stream stream) {
  if (rtype == Reduce::And) {
    reduction_op<InT, bool>(in, out, axes, true, AndReduce());
    reduction_op<InT, bool, AndReduce>(in, out, axes, true, stream);
  } else {
    reduction_op<InT, bool>(in, out, axes, false, OrReduce());
    reduction_op<InT, bool, OrReduce>(in, out, axes, false, stream);
  }
}

@@ -407,18 +448,19 @@ void reduce_dispatch_sum_prod(
    const array& in,
    array& out,
    Reduce::ReduceType rtype,
    const std::vector<int>& axes) {
    const std::vector<int>& axes,
    Stream stream) {
  if (rtype == Reduce::Sum) {
    if constexpr (std::is_integral_v<InT> && sizeof(InT) <= 4) {
      reduction_op<InT, int32_t>(in, out, axes, 0, SumReduce());
      reduction_op<InT, int32_t, SumReduce>(in, out, axes, 0, stream);
    } else {
      reduction_op<InT, InT>(in, out, axes, 0, SumReduce());
      reduction_op<InT, InT, SumReduce>(in, out, axes, 0, stream);
    }
  } else {
    if constexpr (std::is_integral_v<InT> && sizeof(InT) <= 4) {
      reduction_op<InT, int32_t>(in, out, axes, 1, ProdReduce());
      reduction_op<InT, int32_t, ProdReduce>(in, out, axes, 1, stream);
    } else {
      reduction_op<InT, InT>(in, out, axes, 1, ProdReduce());
      reduction_op<InT, InT, ProdReduce>(in, out, axes, 1, stream);
    }
  }
}
@@ -428,13 +470,14 @@ void reduce_dispatch_min_max(
    const array& in,
    array& out,
    Reduce::ReduceType rtype,
    const std::vector<int>& axes) {
    const std::vector<int>& axes,
    Stream stream) {
  if (rtype == Reduce::Max) {
    auto init = Limits<InT>::min;
    reduction_op<InT, InT>(in, out, axes, init, MaxReduce());
    reduction_op<InT, InT, MaxReduce>(in, out, axes, init, stream);
  } else {
    auto init = Limits<InT>::max;
    reduction_op<InT, InT>(in, out, axes, init, MinReduce());
    reduction_op<InT, InT, MinReduce>(in, out, axes, init, stream);
  }
}

@@ -448,24 +491,28 @@ void Reduce::eval_cpu(const std::vector<array>& inputs, array& out) {
      case bool_:
      case uint8:
      case int8:
        reduce_dispatch_and_or<int8_t>(in, out, reduce_type_, axes_);
        reduce_dispatch_and_or<int8_t>(
            in, out, reduce_type_, axes_, stream());
        break;
      case int16:
      case uint16:
      case float16:
      case bfloat16:
        reduce_dispatch_and_or<int16_t>(in, out, reduce_type_, axes_);
        reduce_dispatch_and_or<int16_t>(
            in, out, reduce_type_, axes_, stream());
        break;
      case uint32:
      case int32:
      case float32:
        reduce_dispatch_and_or<int32_t>(in, out, reduce_type_, axes_);
        reduce_dispatch_and_or<int32_t>(
            in, out, reduce_type_, axes_, stream());
        break;
      case uint64:
      case int64:
      case float64:
      case complex64:
        reduce_dispatch_and_or<int64_t>(in, out, reduce_type_, axes_);
        reduce_dispatch_and_or<int64_t>(
            in, out, reduce_type_, axes_, stream());
        break;
    }
    break;
@@ -476,34 +523,43 @@ void Reduce::eval_cpu(const std::vector<array>& inputs, array& out) {
      case bool_:
      case uint8:
      case int8:
        reduce_dispatch_sum_prod<int8_t>(in, out, reduce_type_, axes_);
        reduce_dispatch_sum_prod<int8_t>(
            in, out, reduce_type_, axes_, stream());
        break;
      case int16:
      case uint16:
        reduce_dispatch_sum_prod<int16_t>(in, out, reduce_type_, axes_);
        reduce_dispatch_sum_prod<int16_t>(
            in, out, reduce_type_, axes_, stream());
        break;
      case int32:
      case uint32:
        reduce_dispatch_sum_prod<int32_t>(in, out, reduce_type_, axes_);
        reduce_dispatch_sum_prod<int32_t>(
            in, out, reduce_type_, axes_, stream());
        break;
      case int64:
      case uint64:
        reduce_dispatch_sum_prod<int64_t>(in, out, reduce_type_, axes_);
        reduce_dispatch_sum_prod<int64_t>(
            in, out, reduce_type_, axes_, stream());
        break;
      case float16:
        reduce_dispatch_sum_prod<float16_t>(in, out, reduce_type_, axes_);
        reduce_dispatch_sum_prod<float16_t>(
            in, out, reduce_type_, axes_, stream());
        break;
      case bfloat16:
        reduce_dispatch_sum_prod<bfloat16_t>(in, out, reduce_type_, axes_);
        reduce_dispatch_sum_prod<bfloat16_t>(
            in, out, reduce_type_, axes_, stream());
        break;
      case float32:
        reduce_dispatch_sum_prod<float>(in, out, reduce_type_, axes_);
        reduce_dispatch_sum_prod<float>(
            in, out, reduce_type_, axes_, stream());
        break;
      case float64:
        reduce_dispatch_sum_prod<double>(in, out, reduce_type_, axes_);
        reduce_dispatch_sum_prod<double>(
            in, out, reduce_type_, axes_, stream());
        break;
      case complex64:
        reduce_dispatch_sum_prod<complex64_t>(in, out, reduce_type_, axes_);
        reduce_dispatch_sum_prod<complex64_t>(
            in, out, reduce_type_, axes_, stream());
        break;
    }
    break;
@@ -512,46 +568,59 @@ void Reduce::eval_cpu(const std::vector<array>& inputs, array& out) {
    case Reduce::Min: {
      switch (in.dtype()) {
        case bool_:
          reduce_dispatch_min_max<bool>(in, out, reduce_type_, axes_);
          reduce_dispatch_min_max<bool>(in, out, reduce_type_, axes_, stream());
          break;
        case uint8:
          reduce_dispatch_min_max<uint8_t>(in, out, reduce_type_, axes_);
          reduce_dispatch_min_max<uint8_t>(
              in, out, reduce_type_, axes_, stream());
          break;
        case uint16:
          reduce_dispatch_min_max<uint16_t>(in, out, reduce_type_, axes_);
          reduce_dispatch_min_max<uint16_t>(
              in, out, reduce_type_, axes_, stream());
          break;
        case uint32:
          reduce_dispatch_min_max<uint32_t>(in, out, reduce_type_, axes_);
          reduce_dispatch_min_max<uint32_t>(
              in, out, reduce_type_, axes_, stream());
          break;
        case uint64:
          reduce_dispatch_min_max<uint64_t>(in, out, reduce_type_, axes_);
          reduce_dispatch_min_max<uint64_t>(
              in, out, reduce_type_, axes_, stream());
          break;
        case int8:
          reduce_dispatch_min_max<uint8_t>(in, out, reduce_type_, axes_);
          reduce_dispatch_min_max<uint8_t>(
              in, out, reduce_type_, axes_, stream());
          break;
        case int16:
          reduce_dispatch_min_max<uint16_t>(in, out, reduce_type_, axes_);
          reduce_dispatch_min_max<uint16_t>(
              in, out, reduce_type_, axes_, stream());
          break;
        case int32:
          reduce_dispatch_min_max<int32_t>(in, out, reduce_type_, axes_);
          reduce_dispatch_min_max<int32_t>(
              in, out, reduce_type_, axes_, stream());
          break;
        case int64:
          reduce_dispatch_min_max<int64_t>(in, out, reduce_type_, axes_);
          reduce_dispatch_min_max<int64_t>(
              in, out, reduce_type_, axes_, stream());
          break;
        case float16:
          reduce_dispatch_min_max<float16_t>(in, out, reduce_type_, axes_);
          reduce_dispatch_min_max<float16_t>(
              in, out, reduce_type_, axes_, stream());
          break;
        case float32:
          reduce_dispatch_min_max<float>(in, out, reduce_type_, axes_);
          reduce_dispatch_min_max<float>(
              in, out, reduce_type_, axes_, stream());
          break;
        case float64:
          reduce_dispatch_min_max<double>(in, out, reduce_type_, axes_);
          reduce_dispatch_min_max<double>(
              in, out, reduce_type_, axes_, stream());
          break;
        case bfloat16:
          reduce_dispatch_min_max<bfloat16_t>(in, out, reduce_type_, axes_);
          reduce_dispatch_min_max<bfloat16_t>(
              in, out, reduce_type_, axes_, stream());
          break;
        case complex64:
          reduce_dispatch_min_max<complex64_t>(in, out, reduce_type_, axes_);
          reduce_dispatch_min_max<complex64_t>(
              in, out, reduce_type_, axes_, stream());
          break;
      }
      break;

@@ -4,6 +4,7 @@

#include "mlx/backend/common/utils.h"
#include "mlx/backend/cpu/copy.h"
#include "mlx/backend/cpu/encoder.h"
#include "mlx/backend/cpu/simd/simd.h"
#include "mlx/primitives.h"

@@ -153,37 +154,44 @@ void strided_scan(

template <typename T, typename U, typename Op>
void scan_op(
    const array& input,
    array& output,
    const array& in,
    array& out,
    int axis,
    bool reverse,
    bool inclusive,
    const Op& op,
    U init) {
  output.set_data(allocator::malloc_or_wait(output.nbytes()));
    U init,
    Stream stream) {
  auto& encoder = cpu::get_command_encoder(stream);
  encoder.set_input_array(in);
  encoder.set_output_array(out);

  if (input.flags().row_contiguous) {
    if (input.strides()[axis] == 1) {
      contiguous_scan(
          input.data<T>(),
          output.data<U>(),
          input.size() / input.shape(axis),
          input.shape(axis),
          reverse,
          inclusive,
          op,
          init);
  if (in.flags().row_contiguous) {
    if (in.strides()[axis] == 1) {
      encoder.dispatch([in_ptr = in.data<T>(),
                        out_ptr = out.data<U>(),
                        count = in.size() / in.shape(axis),
                        stride = in.shape(axis),
                        reverse,
                        inclusive,
                        op = std::move(op),
                        init]() {
        contiguous_scan(
            in_ptr, out_ptr, count, stride, reverse, inclusive, op, init);
      });
    } else {
      strided_scan(
          input.data<T>(),
          output.data<U>(),
          input.size() / input.shape(axis) / input.strides()[axis],
          input.shape(axis),
          input.strides()[axis],
          reverse,
          inclusive,
          op,
          init);
      encoder.dispatch([in_ptr = in.data<T>(),
                        out_ptr = out.data<U>(),
                        count = in.size() / in.shape(axis) / in.strides()[axis],
                        size = in.shape(axis),
                        stride = in.strides()[axis],
                        reverse,
                        inclusive,
                        op = std::move(op),
                        init]() {
        strided_scan(
            in_ptr, out_ptr, count, size, stride, reverse, inclusive, op, init);
      });
    }
  } else {
    throw std::runtime_error("Scan op supports only contiguous inputs");
@@ -193,38 +201,39 @@ void scan_op(
template <typename T, typename U>
void scan_dispatch(
    Scan::ReduceType rtype,
    const array& input,
    array& output,
    const array& in,
    array& out,
    int axis,
    bool reverse,
    bool inclusive) {
    bool inclusive,
    Stream stream) {
  switch (rtype) {
    case Scan::Sum: {
      auto op = [](U y, T x) { return y + x; };
      auto init = static_cast<U>(0);
      scan_op<T, U>(input, output, axis, reverse, inclusive, op, init);
      scan_op<T, U>(in, out, axis, reverse, inclusive, op, init, stream);
      break;
    }
    case Scan::Prod: {
      auto op = [](U y, T x) { return y * x; };
      auto init = static_cast<U>(1);
      scan_op<T, U>(input, output, axis, reverse, inclusive, op, init);
      scan_op<T, U>(in, out, axis, reverse, inclusive, op, init, stream);
      break;
    }
    case Scan::Min: {
      auto op = [](U y, T x) { return x < y ? x : y; };
      auto init = (issubdtype(input.dtype(), floating))
      auto init = (issubdtype(in.dtype(), floating))
          ? static_cast<U>(std::numeric_limits<float>::infinity())
          : std::numeric_limits<U>::max();
      scan_op<T, U>(input, output, axis, reverse, inclusive, op, init);
      scan_op<T, U>(in, out, axis, reverse, inclusive, op, init, stream);
      break;
    }
    case Scan::Max: {
      auto op = [](U y, T x) { return x < y ? y : x; };
      auto init = (issubdtype(input.dtype(), floating))
      auto init = (issubdtype(in.dtype(), floating))
          ? static_cast<U>(-std::numeric_limits<float>::infinity())
          : std::numeric_limits<U>::min();
      scan_op<T, U>(input, output, axis, reverse, inclusive, op, init);
      scan_op<T, U>(in, out, axis, reverse, inclusive, op, init, stream);
      break;
    }
  }
@@ -237,11 +246,14 @@ void Scan::eval_cpu(const std::vector<array>& inputs, array& out) {

  // Ensure contiguity
  auto in = inputs[0];
  bool copied = false;
  if (!in.flags().row_contiguous) {
    array arr_copy(in.shape(), in.dtype(), nullptr, {});
    copy(in, arr_copy, CopyType::General);
    copy(in, arr_copy, CopyType::General, stream());
    in = arr_copy;
    copied = true;
  }
  out.set_data(allocator::malloc_or_wait(out.nbytes()));

  switch (in.dtype()) {
    case bool_: {
@@ -252,65 +264,68 @@ void Scan::eval_cpu(const std::vector<array>& inputs, array& out) {
      // floats perhaps we should add the full all-to-all dispatch.
      if (reduce_type_ == Scan::Sum && out.dtype() == int32) {
        scan_dispatch<bool, int32_t>(
            reduce_type_, in, out, axis_, reverse_, inclusive_);
            reduce_type_, in, out, axis_, reverse_, inclusive_, stream());
      } else {
        scan_dispatch<bool, bool>(
            reduce_type_, in, out, axis_, reverse_, inclusive_);
            reduce_type_, in, out, axis_, reverse_, inclusive_, stream());
      }
      break;
    }
    case uint8:
      scan_dispatch<uint8_t, uint8_t>(
          reduce_type_, in, out, axis_, reverse_, inclusive_);
          reduce_type_, in, out, axis_, reverse_, inclusive_, stream());
      break;
    case uint16:
      scan_dispatch<uint16_t, uint16_t>(
          reduce_type_, in, out, axis_, reverse_, inclusive_);
          reduce_type_, in, out, axis_, reverse_, inclusive_, stream());
      break;
    case uint32:
      scan_dispatch<uint32_t, uint32_t>(
          reduce_type_, in, out, axis_, reverse_, inclusive_);
          reduce_type_, in, out, axis_, reverse_, inclusive_, stream());
      break;
    case uint64:
      scan_dispatch<uint64_t, uint64_t>(
          reduce_type_, in, out, axis_, reverse_, inclusive_);
          reduce_type_, in, out, axis_, reverse_, inclusive_, stream());
      break;
    case int8:
      scan_dispatch<int8_t, int8_t>(
          reduce_type_, in, out, axis_, reverse_, inclusive_);
          reduce_type_, in, out, axis_, reverse_, inclusive_, stream());
      break;
    case int16:
      scan_dispatch<int16_t, int16_t>(
          reduce_type_, in, out, axis_, reverse_, inclusive_);
          reduce_type_, in, out, axis_, reverse_, inclusive_, stream());
      break;
    case int32:
      scan_dispatch<int32_t, int32_t>(
          reduce_type_, in, out, axis_, reverse_, inclusive_);
          reduce_type_, in, out, axis_, reverse_, inclusive_, stream());
      break;
    case int64:
      scan_dispatch<int64_t, int64_t>(
          reduce_type_, in, out, axis_, reverse_, inclusive_);
          reduce_type_, in, out, axis_, reverse_, inclusive_, stream());
      break;
    case float16:
      scan_dispatch<float16_t, float16_t>(
          reduce_type_, in, out, axis_, reverse_, inclusive_);
          reduce_type_, in, out, axis_, reverse_, inclusive_, stream());
      break;
    case float32:
      scan_dispatch<float, float>(
          reduce_type_, in, out, axis_, reverse_, inclusive_);
          reduce_type_, in, out, axis_, reverse_, inclusive_, stream());
      break;
    case float64:
      scan_dispatch<double, double>(
          reduce_type_, in, out, axis_, reverse_, inclusive_);
          reduce_type_, in, out, axis_, reverse_, inclusive_, stream());
      break;
    case bfloat16:
      scan_dispatch<bfloat16_t, bfloat16_t>(
          reduce_type_, in, out, axis_, reverse_, inclusive_);
          reduce_type_, in, out, axis_, reverse_, inclusive_, stream());
      break;
    case complex64:
      throw std::runtime_error("Scan ops do not support complex types yet");
      break;
  }
  if (copied) {
    cpu::get_command_encoder(stream()).add_temporary(std::move(in));
  }
}

} // namespace mlx::core

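Note: the scan conversion shows the ordering rule all of these primitives follow: allocations and copies are issued eagerly on the eval thread, compute is deferred, and any temporary the task reads (here the contiguous copy of the input) is handed to the encoder so it outlives the closure. A condensed sketch of that lifecycle; the std::partial_sum stand-in is illustrative, not the mlx scan kernel:

#include <numeric>

void scan_sketch(array in, array& out, Stream stream, bool copied) {
  out.set_data(allocator::malloc_or_wait(out.nbytes())); // eager allocation
  auto& enc = cpu::get_command_encoder(stream);
  enc.set_input_array(in);
  enc.set_output_array(out);
  enc.dispatch([p = in.data<float>(), q = out.data<float>(), n = in.size()]() {
    std::partial_sum(p, p + n, q); // deferred inclusive sum over a flat row
  });
  if (copied) {
    enc.add_temporary(std::move(in)); // keep the copy alive until the task runs
  }
}
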
@@ -4,6 +4,7 @@
#include <cmath>

#include "mlx/backend/cpu/copy.h"
#include "mlx/backend/cpu/encoder.h"
#include "mlx/backend/cpu/simd/simd.h"
#include "mlx/primitives.h"
#include "mlx/types/limits.h"
@@ -15,92 +16,100 @@ namespace {
using namespace mlx::core::simd;

template <typename T, typename AccT>
void softmax(const array& in, array& out) {
  constexpr bool same_t = std::is_same_v<T, AccT>;
  constexpr int N = std::min(max_size<AccT>, max_size<T>);
void softmax(const array& in, array& out, Stream stream) {
  auto& encoder = cpu::get_command_encoder(stream);
  encoder.set_input_array(in);
  encoder.set_output_array(out);

  const T* in_ptr = in.data<T>();
  T* out_ptr = out.data<T>();

  int M = in.shape().back();
  int L = in.data_size() / M;
  const T* current_in_ptr;
  T* current_out_ptr;

  for (int i = 0; i < L; i++, in_ptr += M, out_ptr += M) {
    // Find the maximum
    current_in_ptr = in_ptr;
    Simd<AccT, N> vmaximum(-numeric_limits<AccT>::infinity());
    size_t s = M;
    while (s >= N) {
      Simd<AccT, N> vals = load<T, N>(current_in_ptr);
      vmaximum = maximum(vals, vmaximum);
      current_in_ptr += N;
      s -= N;
    }
  encoder.dispatch([in_ptr, out_ptr, M, L]() mutable {
    constexpr bool same_t = std::is_same_v<T, AccT>;
    constexpr int N = std::min(max_size<AccT>, max_size<T>);

    AccT maximum = max(vmaximum);
    while (s-- > 0) {
      maximum = std::max(maximum, static_cast<AccT>(*current_in_ptr));
      current_in_ptr++;
    }
    const T* current_in_ptr;
    T* current_out_ptr;

    // Compute the normalizer and the exponentials
    Simd<AccT, N> vnormalizer(0.0);
    current_out_ptr = out_ptr;
    current_in_ptr = in_ptr;
    s = M;
    while (s >= N) {
      Simd<AccT, N> vexp = load<T, N>(current_in_ptr);
      vexp = exp(vexp - maximum);
      if constexpr (same_t) {
        store(current_out_ptr, vexp);
      }
      vnormalizer = vnormalizer + vexp;
      current_in_ptr += N;
      current_out_ptr += N;
      s -= N;
    }
    AccT normalizer = sum(vnormalizer);
    while (s-- > 0) {
      AccT _exp = std::exp(*current_in_ptr - maximum);
      if constexpr (same_t) {
        *current_out_ptr = _exp;
      }
      normalizer += _exp;
      current_in_ptr++;
      current_out_ptr++;
    }
    normalizer = 1 / normalizer;

    // Normalize
    current_out_ptr = out_ptr;
    current_in_ptr = in_ptr;
    s = M;
    while (s >= N) {
      if constexpr (same_t) {
        store(
            current_out_ptr,
            Simd<T, N>(load<T, N>(current_out_ptr) * normalizer));
      } else {
        Simd<AccT, N> vexp = load<T, N>(current_in_ptr);
        vexp = exp(vexp - maximum) * normalizer;
        store(current_out_ptr, Simd<T, N>(vexp));
    for (int i = 0; i < L; i++, in_ptr += M, out_ptr += M) {
      // Find the maximum
      current_in_ptr = in_ptr;
      Simd<AccT, N> vmaximum(-numeric_limits<AccT>::infinity());
      size_t s = M;
      while (s >= N) {
        Simd<AccT, N> vals = load<T, N>(current_in_ptr);
        vmaximum = maximum(vals, vmaximum);
        current_in_ptr += N;
        s -= N;
      }
      current_out_ptr += N;
      s -= N;
    }
    while (s-- > 0) {
      if constexpr (same_t) {
        *current_out_ptr *= normalizer;
      } else {
        AccT _exp = std::exp(*current_in_ptr - maximum);
        *current_out_ptr = static_cast<T>(_exp * normalizer);

      AccT maximum = max(vmaximum);
      while (s-- > 0) {
        maximum = std::max(maximum, static_cast<AccT>(*current_in_ptr));
        current_in_ptr++;
      }
      current_out_ptr++;

      // Compute the normalizer and the exponentials
      Simd<AccT, N> vnormalizer(0.0);
      current_out_ptr = out_ptr;
      current_in_ptr = in_ptr;
      s = M;
      while (s >= N) {
        Simd<AccT, N> vexp = load<T, N>(current_in_ptr);
        vexp = exp(vexp - maximum);
        if constexpr (same_t) {
          store(current_out_ptr, vexp);
        }
        vnormalizer = vnormalizer + vexp;
        current_in_ptr += N;
        current_out_ptr += N;
        s -= N;
      }
      AccT normalizer = sum(vnormalizer);
      while (s-- > 0) {
        AccT _exp = std::exp(*current_in_ptr - maximum);
        if constexpr (same_t) {
          *current_out_ptr = _exp;
        }
        normalizer += _exp;
        current_in_ptr++;
        current_out_ptr++;
      }
      normalizer = 1 / normalizer;

      // Normalize
      current_out_ptr = out_ptr;
      current_in_ptr = in_ptr;
      s = M;
      while (s >= N) {
        if constexpr (same_t) {
          store(
              current_out_ptr,
              Simd<T, N>(load<T, N>(current_out_ptr) * normalizer));
        } else {
          Simd<AccT, N> vexp = load<T, N>(current_in_ptr);
          vexp = exp(vexp - maximum) * normalizer;
          store(current_out_ptr, Simd<T, N>(vexp));
          current_in_ptr += N;
        }
        current_out_ptr += N;
        s -= N;
      }
      while (s-- > 0) {
        if constexpr (same_t) {
          *current_out_ptr *= normalizer;
        } else {
          AccT _exp = std::exp(*current_in_ptr - maximum);
          *current_out_ptr = static_cast<T>(_exp * normalizer);
          current_in_ptr++;
        }
        current_out_ptr++;
      }
    }
  });
}

} // namespace
@@ -109,30 +118,32 @@ void Softmax::eval_cpu(const std::vector<array>& inputs, array& out) {
  assert(inputs.size() == 1);

  // Make sure that the last dimension is contiguous
  auto check_input = [](array x) {
  auto set_output = [s = stream(), &out](const array& x) {
    bool no_copy = x.strides()[x.ndim() - 1] == 1;
    if (x.ndim() > 1) {
      auto s = x.strides()[x.ndim() - 2];
      no_copy &= (s == 0 || s == x.shape().back());
    }
    if (no_copy) {
      if (x.is_donatable()) {
        out.copy_shared_buffer(x);
      } else {
        out.set_data(
            allocator::malloc_or_wait(x.data_size() * x.itemsize()),
            x.data_size(),
            x.strides(),
            x.flags());
      }
      return x;
    } else {
      array x_copy(x.shape(), x.dtype(), nullptr, {});
      copy(x, x_copy, CopyType::General);
      copy(x, x_copy, CopyType::General, s);
      out.copy_shared_buffer(x_copy);
      return x_copy;
    }
  };
  array in = check_input(std::move(inputs[0]));
  if (in.is_donatable()) {
    out.copy_shared_buffer(in);
  } else {
    out.set_data(
        allocator::malloc_or_wait(in.data_size() * in.itemsize()),
        in.data_size(),
        in.strides(),
        in.flags());
  }

  auto in = set_output(inputs[0]);

  switch (in.dtype()) {
    case bool_:
@@ -148,24 +159,24 @@ void Softmax::eval_cpu(const std::vector<array>& inputs, array& out) {
          "Softmax is defined only for floating point types");
      break;
    case float32:
      softmax<float, float>(in, out);
      softmax<float, float>(in, out, stream());
      break;
    case float16:
      if (precise_) {
        softmax<float16_t, float>(in, out);
        softmax<float16_t, float>(in, out, stream());
      } else {
        softmax<float16_t, float16_t>(in, out);
        softmax<float16_t, float16_t>(in, out, stream());
      }
      break;
    case bfloat16:
      if (precise_) {
        softmax<bfloat16_t, float>(in, out);
        softmax<bfloat16_t, float>(in, out, stream());
      } else {
        softmax<bfloat16_t, bfloat16_t>(in, out);
        softmax<bfloat16_t, bfloat16_t>(in, out, stream());
      }
      break;
    case float64:
      softmax<double, double>(in, out);
      softmax<double, double>(in, out, stream());
      break;
    case complex64:
      throw std::invalid_argument(

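Note: the softmax body is unchanged numerically; the whole three-pass SIMD routine simply moved inside the dispatched closure with the two raw pointers captured by value. For reference, the scalar form of those three passes over one row of length M (assuming float in/out; needs <algorithm>, <cmath>, <limits>):

void softmax_row(const float* x, float* y, int M) {
  float m = -std::numeric_limits<float>::infinity();
  for (int j = 0; j < M; ++j) {
    m = std::max(m, x[j]); // pass 1: row maximum, for numerical stability
  }
  float z = 0.f;
  for (int j = 0; j < M; ++j) {
    y[j] = std::exp(x[j] - m); // pass 2: shifted exponentials
    z += y[j];
  }
  float inv = 1.f / z;
  for (int j = 0; j < M; ++j) {
    y[j] *= inv; // pass 3: normalize
  }
}
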
@@ -7,6 +7,7 @@
|
||||
|
||||
#include "mlx/backend/common/utils.h"
|
||||
#include "mlx/backend/cpu/copy.h"
|
||||
#include "mlx/backend/cpu/encoder.h"
|
||||
|
||||
#include "mlx/primitives.h"
|
||||
|
||||
@@ -103,11 +104,11 @@ struct StridedIterator {
|
||||
T* ptr_;
|
||||
};
|
||||
|
||||
template <typename T, typename IdxT = uint32_t>
|
||||
void sort(const array& in, array& out, int axis) {
|
||||
template <typename T>
|
||||
void sort(const array& in, array& out, int axis, Stream stream) {
|
||||
// Copy input to output
|
||||
CopyType ctype = in.flags().contiguous ? CopyType::Vector : CopyType::General;
|
||||
copy(in, out, ctype);
|
||||
copy(in, out, ctype, stream);
|
||||
|
||||
// Get axis, shape and stride info
|
||||
axis = axis < 0 ? axis + in.ndim() : axis;
|
||||
@@ -126,19 +127,27 @@ void sort(const array& in, array& out, int axis) {
|
||||
// Perform sorting in place
|
||||
ContiguousIterator src_it(
|
||||
remaining_shape, remaining_strides, remaining_shape.size());
|
||||
for (int i = 0; i < n_rows; i++) {
|
||||
T* data_ptr = out.data<T>() + src_it.loc;
|
||||
auto& encoder = cpu::get_command_encoder(stream);
|
||||
encoder.set_output_array(out);
|
||||
encoder.dispatch([out_ptr = out.data<T>(),
|
||||
src_it = std::move(src_it),
|
||||
n_rows,
|
||||
axis_size,
|
||||
axis_stride]() mutable {
|
||||
for (int i = 0; i < n_rows; i++) {
|
||||
T* data_ptr = out_ptr + src_it.loc;
|
||||
|
||||
StridedIterator st(data_ptr, axis_stride, 0);
|
||||
StridedIterator ed(data_ptr, axis_stride, axis_size);
|
||||
StridedIterator st(data_ptr, axis_stride, 0);
|
||||
StridedIterator ed(data_ptr, axis_stride, axis_size);
|
||||
|
||||
std::stable_sort(st, ed);
|
||||
src_it.step();
|
||||
}
|
||||
std::stable_sort(st, ed);
|
||||
src_it.step();
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
template <typename T, typename IdxT = uint32_t>
|
||||
void argsort(const array& in, array& out, int axis) {
|
||||
void argsort(const array& in, array& out, int axis, Stream stream) {
|
||||
// Allocate output
|
||||
out.set_data(allocator::malloc_or_wait(out.nbytes()));
|
||||
|
||||
@@ -167,35 +176,48 @@ void argsort(const array& in, array& out, int axis) {
|
||||
in_remaining_shape, in_remaining_strides, in_remaining_shape.size());
|
||||
ContiguousIterator out_it(
|
||||
out_remaining_shape, out_remaining_strides, out_remaining_shape.size());
|
||||
for (int i = 0; i < n_rows; i++) {
|
||||
const T* data_ptr = in.data<T>() + in_it.loc;
|
||||
IdxT* idx_ptr = out.data<IdxT>() + out_it.loc;
|
||||
in_it.step();
|
||||
out_it.step();
|
||||
auto& encoder = cpu::get_command_encoder(stream);
|
||||
encoder.set_input_array(in);
|
||||
encoder.set_input_array(out);
|
||||
encoder.dispatch([in_ptr = in.data<T>(),
|
||||
out_ptr = out.data<IdxT>(),
|
||||
in_it = std::move(in_it),
|
||||
out_it = std::move(out_it),
|
||||
n_rows,
|
||||
axis_size,
|
||||
in_stride,
|
||||
out_stride]() mutable {
|
||||
for (int i = 0; i < n_rows; i++) {
|
||||
const T* data_ptr = in_ptr + in_it.loc;
|
||||
IdxT* idx_ptr = out_ptr + out_it.loc;
|
||||
|
||||
StridedIterator st_(idx_ptr, out_stride, 0);
|
||||
StridedIterator ed_(idx_ptr, out_stride, axis_size);
|
||||
in_it.step();
|
||||
out_it.step();
|
||||
|
||||
// Initialize with iota
|
||||
std::iota(st_, ed_, IdxT(0));
|
||||
StridedIterator st_(idx_ptr, out_stride, 0);
|
||||
StridedIterator ed_(idx_ptr, out_stride, axis_size);
|
||||
|
||||
// Sort according to vals
|
||||
StridedIterator st(idx_ptr, out_stride, 0);
|
||||
StridedIterator ed(idx_ptr, out_stride, axis_size);
|
||||
// Initialize with iota
|
||||
std::iota(st_, ed_, IdxT(0));
|
||||
|
||||
std::stable_sort(st, ed, [data_ptr, in_stride](IdxT a, IdxT b) {
|
||||
auto v1 = data_ptr[a * in_stride];
|
||||
auto v2 = data_ptr[b * in_stride];
|
||||
return v1 < v2 || (v1 == v2 && a < b);
|
||||
});
|
||||
}
|
||||
// Sort according to vals
|
||||
StridedIterator st(idx_ptr, out_stride, 0);
|
||||
StridedIterator ed(idx_ptr, out_stride, axis_size);
|
||||
|
||||
std::stable_sort(st, ed, [data_ptr, in_stride](IdxT a, IdxT b) {
|
||||
auto v1 = data_ptr[a * in_stride];
|
||||
auto v2 = data_ptr[b * in_stride];
|
||||
return v1 < v2 || (v1 == v2 && a < b);
|
||||
});
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
template <typename T, typename IdxT = uint32_t>
|
||||
void partition(const array& in, array& out, int axis, int kth) {
|
||||
template <typename T>
|
||||
void partition(const array& in, array& out, int axis, int kth, Stream stream) {
|
||||
// Copy input to output
|
||||
CopyType ctype = in.flags().contiguous ? CopyType::Vector : CopyType::General;
|
||||
copy(in, out, ctype);
|
||||
copy(in, out, ctype, stream);
|
||||
|
||||
// Get axis, shape and stride info
|
||||
axis = axis < 0 ? axis + in.ndim() : axis;
|
||||
@@ -216,20 +238,34 @@ void partition(const array& in, array& out, int axis, int kth) {
|
||||
// Perform partition in place
|
||||
ContiguousIterator src_it(
|
||||
remaining_shape, remaining_strides, remaining_shape.size());
|
||||
for (int i = 0; i < n_rows; i++) {
|
||||
T* data_ptr = out.data<T>() + src_it.loc;
|
||||
src_it.step();
|
||||
auto& encoder = cpu::get_command_encoder(stream);
|
||||
encoder.set_output_array(out);
|
||||
encoder.dispatch([out_ptr = out.data<T>(),
|
||||
src_it = std::move(src_it),
|
||||
n_rows,
|
||||
axis_size,
|
||||
axis_stride,
|
||||
kth]() mutable {
|
||||
for (int i = 0; i < n_rows; i++) {
|
||||
T* data_ptr = out_ptr + src_it.loc;
|
||||
src_it.step();
|
||||
|
||||
StridedIterator st(data_ptr, axis_stride, 0);
|
||||
StridedIterator md(data_ptr, axis_stride, kth);
|
||||
StridedIterator ed(data_ptr, axis_stride, axis_size);
|
||||
StridedIterator st(data_ptr, axis_stride, 0);
|
||||
StridedIterator md(data_ptr, axis_stride, kth);
|
||||
StridedIterator ed(data_ptr, axis_stride, axis_size);
|
||||
|
||||
std::nth_element(st, md, ed);
|
||||
}
|
||||
std::nth_element(st, md, ed);
|
||||
}
|
||||
});
|
||||
}

template <typename T, typename IdxT = uint32_t>
void argpartition(const array& in, array& out, int axis, int kth) {
void argpartition(
    const array& in,
    array& out,
    int axis,
    int kth,
    Stream stream) {
  // Allocate output
  out.set_data(allocator::malloc_or_wait(out.nbytes()));

@@ -260,29 +296,43 @@ void argpartition(const array& in, array& out, int axis, int kth) {
      in_remaining_shape, in_remaining_strides, in_remaining_shape.size());
  ContiguousIterator out_it(
      out_remaining_shape, out_remaining_strides, out_remaining_shape.size());
  for (int i = 0; i < n_rows; i++) {
    const T* data_ptr = in.data<T>() + in_it.loc;
    IdxT* idx_ptr = out.data<IdxT>() + out_it.loc;
    in_it.step();
    out_it.step();

    StridedIterator st_(idx_ptr, out_stride, 0);
    StridedIterator ed_(idx_ptr, out_stride, axis_size);
  auto& encoder = cpu::get_command_encoder(stream);
  encoder.set_input_array(in);
  encoder.set_output_array(out);
  encoder.dispatch([in_ptr = in.data<T>(),
                    out_ptr = out.data<IdxT>(),
                    in_it = std::move(in_it),
                    out_it = std::move(out_it),
                    n_rows,
                    axis_size,
                    in_stride,
                    out_stride,
                    kth]() mutable {
    for (int i = 0; i < n_rows; i++) {
      const T* data_ptr = in_ptr + in_it.loc;
      IdxT* idx_ptr = out_ptr + out_it.loc;
      in_it.step();
      out_it.step();

    // Initialize with iota
    std::iota(st_, ed_, IdxT(0));
      StridedIterator st_(idx_ptr, out_stride, 0);
      StridedIterator ed_(idx_ptr, out_stride, axis_size);

    // Sort according to vals
    StridedIterator st(idx_ptr, out_stride, 0);
    StridedIterator md(idx_ptr, out_stride, kth);
    StridedIterator ed(idx_ptr, out_stride, axis_size);
      // Initialize with iota
      std::iota(st_, ed_, IdxT(0));

    std::nth_element(st, md, ed, [data_ptr, in_stride](IdxT a, IdxT b) {
      auto v1 = data_ptr[a * in_stride];
      auto v2 = data_ptr[b * in_stride];
      return v1 < v2 || (v1 == v2 && a < b);
    });
  }
      // Sort according to vals
      StridedIterator st(idx_ptr, out_stride, 0);
      StridedIterator md(idx_ptr, out_stride, kth);
      StridedIterator ed(idx_ptr, out_stride, axis_size);

      std::nth_element(st, md, ed, [data_ptr, in_stride](IdxT a, IdxT b) {
        auto v1 = data_ptr[a * in_stride];
        auto v2 = data_ptr[b * in_stride];
        return v1 < v2 || (v1 == v2 && a < b);
      });
    }
  });
}

} // namespace
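
A toy sketch of the command-encoder idea these functions now rely on; this is illustrative only, not the real encoder behind cpu::get_command_encoder. Work is captured by value (raw pointers, moved iterators, sizes) and queued, so the kernel can run later on a worker thread without touching the arrays themselves.

#include <functional>
#include <queue>

// toy encoder: a queue of type-erased tasks, drained later by a worker
class ToyEncoder {
 public:
  template <typename F>
  void dispatch(F&& task) {
    tasks_.push(std::function<void()>(std::forward<F>(task)));
  }
  void run_all() {
    while (!tasks_.empty()) {
      tasks_.front()();
      tasks_.pop();
    }
  }

 private:
  std::queue<std::function<void()>> tasks_;
};

int main() {
  ToyEncoder enc;
  int result = 0;
  // the task owns a raw pointer, mirroring the captured in_ptr/out_ptr above
  enc.dispatch([p = &result]() { *p = 42; });
  enc.run_all(); // in the real design this happens on the stream's thread
  return result == 42 ? 0 : 1;
}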

@@ -293,33 +343,33 @@ void ArgSort::eval_cpu(const std::vector<array>& inputs, array& out) {

  switch (in.dtype()) {
    case bool_:
      return argsort<bool>(in, out, axis_);
      return argsort<bool>(in, out, axis_, stream());
    case uint8:
      return argsort<uint8_t>(in, out, axis_);
      return argsort<uint8_t>(in, out, axis_, stream());
    case uint16:
      return argsort<uint16_t>(in, out, axis_);
      return argsort<uint16_t>(in, out, axis_, stream());
    case uint32:
      return argsort<uint32_t>(in, out, axis_);
      return argsort<uint32_t>(in, out, axis_, stream());
    case uint64:
      return argsort<uint64_t>(in, out, axis_);
      return argsort<uint64_t>(in, out, axis_, stream());
    case int8:
      return argsort<int8_t>(in, out, axis_);
      return argsort<int8_t>(in, out, axis_, stream());
    case int16:
      return argsort<int16_t>(in, out, axis_);
      return argsort<int16_t>(in, out, axis_, stream());
    case int32:
      return argsort<int32_t>(in, out, axis_);
      return argsort<int32_t>(in, out, axis_, stream());
    case int64:
      return argsort<int64_t>(in, out, axis_);
      return argsort<int64_t>(in, out, axis_, stream());
    case float32:
      return argsort<float>(in, out, axis_);
      return argsort<float>(in, out, axis_, stream());
    case float64:
      return argsort<double>(in, out, axis_);
      return argsort<double>(in, out, axis_, stream());
    case float16:
      return argsort<float16_t>(in, out, axis_);
      return argsort<float16_t>(in, out, axis_, stream());
    case bfloat16:
      return argsort<bfloat16_t>(in, out, axis_);
      return argsort<bfloat16_t>(in, out, axis_, stream());
    case complex64:
      return argsort<complex64_t>(in, out, axis_);
      return argsort<complex64_t>(in, out, axis_, stream());
  }
}
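
The per-dtype switch above repeats for every primitive. A hedged sketch of how such dispatch can be factored with a generic lambda, using a toy enum rather than the real dtype type; MLX itself keeps the explicit switches, so this is only an illustration of the pattern.

#include <cstdint>
#include <stdexcept>

enum class ToyDtype { f32, i32 }; // stand-in for the real dtype enum

template <typename F>
void dispatch_dtype(ToyDtype t, F&& f) {
  switch (t) {
    case ToyDtype::f32:
      return f(float{});
    case ToyDtype::i32:
      return f(int32_t{});
  }
  throw std::runtime_error("unhandled dtype");
}

int main() {
  int calls = 0;
  dispatch_dtype(ToyDtype::f32, [&](auto tag) {
    using T = decltype(tag); // a real kernel would run its T-typed body here
    (void)sizeof(T);
    ++calls;
  });
  return calls == 1 ? 0 : 1;
}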

@@ -329,33 +379,33 @@ void Sort::eval_cpu(const std::vector<array>& inputs, array& out) {

  switch (in.dtype()) {
    case bool_:
      return sort<bool>(in, out, axis_);
      return sort<bool>(in, out, axis_, stream());
    case uint8:
      return sort<uint8_t>(in, out, axis_);
      return sort<uint8_t>(in, out, axis_, stream());
    case uint16:
      return sort<uint16_t>(in, out, axis_);
      return sort<uint16_t>(in, out, axis_, stream());
    case uint32:
      return sort<uint32_t>(in, out, axis_);
      return sort<uint32_t>(in, out, axis_, stream());
    case uint64:
      return sort<uint64_t>(in, out, axis_);
      return sort<uint64_t>(in, out, axis_, stream());
    case int8:
      return sort<int8_t>(in, out, axis_);
      return sort<int8_t>(in, out, axis_, stream());
    case int16:
      return sort<int16_t>(in, out, axis_);
      return sort<int16_t>(in, out, axis_, stream());
    case int32:
      return sort<int32_t>(in, out, axis_);
      return sort<int32_t>(in, out, axis_, stream());
    case int64:
      return sort<int64_t>(in, out, axis_);
      return sort<int64_t>(in, out, axis_, stream());
    case float32:
      return sort<float>(in, out, axis_);
      return sort<float>(in, out, axis_, stream());
    case float64:
      return sort<double>(in, out, axis_);
      return sort<double>(in, out, axis_, stream());
    case float16:
      return sort<float16_t>(in, out, axis_);
      return sort<float16_t>(in, out, axis_, stream());
    case bfloat16:
      return sort<bfloat16_t>(in, out, axis_);
      return sort<bfloat16_t>(in, out, axis_, stream());
    case complex64:
      return sort<complex64_t>(in, out, axis_);
      return sort<complex64_t>(in, out, axis_, stream());
  }
}

@@ -365,33 +415,33 @@ void ArgPartition::eval_cpu(const std::vector<array>& inputs, array& out) {

  switch (in.dtype()) {
    case bool_:
      return argpartition<bool>(in, out, axis_, kth_);
      return argpartition<bool>(in, out, axis_, kth_, stream());
    case uint8:
      return argpartition<uint8_t>(in, out, axis_, kth_);
      return argpartition<uint8_t>(in, out, axis_, kth_, stream());
    case uint16:
      return argpartition<uint16_t>(in, out, axis_, kth_);
      return argpartition<uint16_t>(in, out, axis_, kth_, stream());
    case uint32:
      return argpartition<uint32_t>(in, out, axis_, kth_);
      return argpartition<uint32_t>(in, out, axis_, kth_, stream());
    case uint64:
      return argpartition<uint64_t>(in, out, axis_, kth_);
      return argpartition<uint64_t>(in, out, axis_, kth_, stream());
    case int8:
      return argpartition<int8_t>(in, out, axis_, kth_);
      return argpartition<int8_t>(in, out, axis_, kth_, stream());
    case int16:
      return argpartition<int16_t>(in, out, axis_, kth_);
      return argpartition<int16_t>(in, out, axis_, kth_, stream());
    case int32:
      return argpartition<int32_t>(in, out, axis_, kth_);
      return argpartition<int32_t>(in, out, axis_, kth_, stream());
    case int64:
      return argpartition<int64_t>(in, out, axis_, kth_);
      return argpartition<int64_t>(in, out, axis_, kth_, stream());
    case float32:
      return argpartition<float>(in, out, axis_, kth_);
      return argpartition<float>(in, out, axis_, kth_, stream());
    case float64:
      return argpartition<double>(in, out, axis_, kth_);
      return argpartition<double>(in, out, axis_, kth_, stream());
    case float16:
      return argpartition<float16_t>(in, out, axis_, kth_);
      return argpartition<float16_t>(in, out, axis_, kth_, stream());
    case bfloat16:
      return argpartition<bfloat16_t>(in, out, axis_, kth_);
      return argpartition<bfloat16_t>(in, out, axis_, kth_, stream());
    case complex64:
      return argpartition<complex64_t>(in, out, axis_, kth_);
      return argpartition<complex64_t>(in, out, axis_, kth_, stream());
  }
}

@@ -401,33 +451,33 @@ void Partition::eval_cpu(const std::vector<array>& inputs, array& out) {

  switch (in.dtype()) {
    case bool_:
      return partition<bool>(in, out, axis_, kth_);
      return partition<bool>(in, out, axis_, kth_, stream());
    case uint8:
      return partition<uint8_t>(in, out, axis_, kth_);
      return partition<uint8_t>(in, out, axis_, kth_, stream());
    case uint16:
      return partition<uint16_t>(in, out, axis_, kth_);
      return partition<uint16_t>(in, out, axis_, kth_, stream());
    case uint32:
      return partition<uint32_t>(in, out, axis_, kth_);
      return partition<uint32_t>(in, out, axis_, kth_, stream());
    case uint64:
      return partition<uint64_t>(in, out, axis_, kth_);
      return partition<uint64_t>(in, out, axis_, kth_, stream());
    case int8:
      return partition<int8_t>(in, out, axis_, kth_);
      return partition<int8_t>(in, out, axis_, kth_, stream());
    case int16:
      return partition<int16_t>(in, out, axis_, kth_);
      return partition<int16_t>(in, out, axis_, kth_, stream());
    case int32:
      return partition<int32_t>(in, out, axis_, kth_);
      return partition<int32_t>(in, out, axis_, kth_, stream());
    case int64:
      return partition<int64_t>(in, out, axis_, kth_);
      return partition<int64_t>(in, out, axis_, kth_, stream());
    case float32:
      return partition<float>(in, out, axis_, kth_);
      return partition<float>(in, out, axis_, kth_, stream());
    case float64:
      return partition<double>(in, out, axis_, kth_);
      return partition<double>(in, out, axis_, kth_, stream());
    case float16:
      return partition<float16_t>(in, out, axis_, kth_);
      return partition<float16_t>(in, out, axis_, kth_, stream());
    case bfloat16:
      return partition<bfloat16_t>(in, out, axis_, kth_);
      return partition<bfloat16_t>(in, out, axis_, kth_, stream());
    case complex64:
      return partition<complex64_t>(in, out, axis_, kth_);
      return partition<complex64_t>(in, out, axis_, kth_, stream());
  }
}

@@ -2,13 +2,18 @@

#include "mlx/allocator.h"
#include "mlx/backend/cpu/copy.h"
#include "mlx/backend/cpu/encoder.h"
#include "mlx/backend/cpu/lapack.h"
#include "mlx/primitives.h"

namespace mlx::core {

template <typename T>
void svd_impl(const array& a, T* u_data, T* s_data, T* vt_data) {
void svd_impl(
    const array& a,
    std::vector<array>& outputs,
    bool compute_uv,
    Stream stream) {
  // Lapack uses the column-major convention. To avoid having to transpose
  // the input and then transpose the outputs, we swap the indices/sizes of the
  // matrices and take advantage of the following identity (see

@@ -22,118 +27,24 @@ void svd_impl(const array& a, T* u_data, T* s_data, T* vt_data) {
  const int N = a.shape(-1);
  const int K = std::min(M, N);

  // A of shape M x N. The leading dimension is N since lapack receives Aᵀ.
  const int lda = N;
  // U of shape M x M. (N x N in lapack).
  const int ldu = N;
  // Vᵀ of shape N x N. (M x M in lapack).
  const int ldvt = M;

  size_t num_matrices = a.size() / (M * N);

  // lapack clobbers the input, so we have to make a copy.
  array in(a.shape(), a.dtype(), nullptr, {});
  copy(a, in, a.flags().row_contiguous ? CopyType::Vector : CopyType::General);
  copy(
      a,
      in,
      a.flags().row_contiguous ? CopyType::Vector : CopyType::General,
      stream);

  auto job_u = (u_data && vt_data) ? "V" : "N";
  auto job_vt = (u_data && vt_data) ? "V" : "N";
  static constexpr auto range = "A";
  // Allocate outputs.
  auto& encoder = cpu::get_command_encoder(stream);
  encoder.set_input_array(a);
  auto in_ptr = in.data<T>();
  T* u_ptr;
  T* s_ptr;
  T* vt_ptr;

  // Will contain the number of singular values after the call has returned.
  int ns = 0;
  T workspace_dimension = 0;

  // Will contain the indices of eigenvectors that failed to converge (not used
  // here but required by lapack).
  auto iwork = array::Data{allocator::malloc_or_wait(sizeof(int) * 12 * K)};

  static const int lwork_query = -1;

  static const int ignored_int = 0;
  static const T ignored_float = 0;
  static T ignored_output = 0;

  int info;

  // Compute workspace size.
  gesvdx<T>(
      /* jobu = */ job_u,
      /* jobvt = */ job_vt,
      /* range = */ range,
      // M and N are swapped since lapack expects column-major.
      /* m = */ &N,
      /* n = */ &M,
      /* a = */ nullptr,
      /* lda = */ &lda,
      /* vl = */ &ignored_float,
      /* vu = */ &ignored_float,
      /* il = */ &ignored_int,
      /* iu = */ &ignored_int,
      /* ns = */ &ns,
      /* s = */ nullptr,
      /* u = */ nullptr,
      /* ldu = */ &ldu,
      /* vt = */ nullptr,
      /* ldvt = */ &ldvt,
      /* work = */ &workspace_dimension,
      /* lwork = */ &lwork_query,
      /* iwork = */ static_cast<int*>(iwork.buffer.raw_ptr()),
      /* info = */ &info);

  if (info != 0) {
    std::stringstream ss;
    ss << "[SVD::eval_cpu] workspace calculation failed with code " << info;
    throw std::runtime_error(ss.str());
  }

  const int lwork = workspace_dimension;
  auto scratch = array::Data{allocator::malloc_or_wait(sizeof(T) * lwork)};

  // Loop over matrices.
  for (int i = 0; i < num_matrices; i++) {
    gesvdx<T>(
        /* jobu = */ job_u,
        /* jobvt = */ job_vt,
        /* range = */ range,
        // M and N are swapped since lapack expects column-major.
        /* m = */ &N,
        /* n = */ &M,
        /* a = */ in.data<T>() + M * N * i,
        /* lda = */ &lda,
        /* vl = */ &ignored_float,
        /* vu = */ &ignored_float,
        /* il = */ &ignored_int,
        /* iu = */ &ignored_int,
        /* ns = */ &ns,
        /* s = */ s_data + K * i,
        // According to the identity above, lapack will write Vᵀᵀ as U.
        /* u = */ vt_data ? vt_data + N * N * i : &ignored_output,
        /* ldu = */ &ldu,
        // According to the identity above, lapack will write Uᵀ as Vᵀ.
        /* vt = */ u_data ? u_data + M * M * i : &ignored_output,
        /* ldvt = */ &ldvt,
        /* work = */ static_cast<T*>(scratch.buffer.raw_ptr()),
        /* lwork = */ &lwork,
        /* iwork = */ static_cast<int*>(iwork.buffer.raw_ptr()),
        /* info = */ &info);

    if (info != 0) {
      std::stringstream ss;
      ss << "[SVD::eval_cpu] failed with code " << info;
      throw std::runtime_error(ss.str());
    }

    if (ns != K) {
      std::stringstream ss;
      ss << "[SVD::eval_cpu] expected " << K << " singular values, but " << ns
         << " were computed.";
      throw std::runtime_error(ss.str());
    }
  }
}

template <typename T>
void compute_svd(const array& a, bool compute_uv, std::vector<array>& outputs) {
  if (compute_uv) {
    array& u = outputs[0];
    array& s = outputs[1];

@@ -143,25 +54,147 @@ void compute_svd(const array& a, bool compute_uv, std::vector<array>& outputs) {
    s.set_data(allocator::malloc_or_wait(s.nbytes()));
    vt.set_data(allocator::malloc_or_wait(vt.nbytes()));

    svd_impl<T>(a, u.data<T>(), s.data<T>(), vt.data<T>());
    encoder.set_output_array(u);
    encoder.set_output_array(s);
    encoder.set_output_array(vt);

    s_ptr = s.data<T>();
    u_ptr = u.data<T>();
    vt_ptr = vt.data<T>();
  } else {
    array& s = outputs[0];

    s.set_data(allocator::malloc_or_wait(s.nbytes()));

    svd_impl<T>(a, nullptr, s.data<T>(), nullptr);
    encoder.set_output_array(s);

    s_ptr = s.data<T>();
    u_ptr = nullptr;
    vt_ptr = nullptr;
  }

  encoder.dispatch([in_ptr, u_ptr, s_ptr, vt_ptr, M, N, K, num_matrices]() {
    // A of shape M x N. The leading dimension is N since lapack receives Aᵀ.
    const int lda = N;
    // U of shape M x M. (N x N in lapack).
    const int ldu = N;
    // Vᵀ of shape N x N. (M x M in lapack).
    const int ldvt = M;

    auto job_u = (u_ptr) ? "V" : "N";
    auto job_vt = (u_ptr) ? "V" : "N";
    static constexpr auto range = "A";

    // Will contain the number of singular values after the call has returned.
    int ns = 0;
    T workspace_dimension = 0;

    // Will contain the indices of eigenvectors that failed to converge (not
    // used here but required by lapack).
    auto iwork = array::Data{allocator::malloc_or_wait(sizeof(int) * 12 * K)};

    static const int lwork_query = -1;

    static const int ignored_int = 0;
    static const T ignored_float = 0;

    int info;

    // Compute workspace size.
    gesvdx<T>(
        /* jobu = */ job_u,
        /* jobvt = */ job_vt,
        /* range = */ range,
        // M and N are swapped since lapack expects column-major.
        /* m = */ &N,
        /* n = */ &M,
        /* a = */ nullptr,
        /* lda = */ &lda,
        /* vl = */ &ignored_float,
        /* vu = */ &ignored_float,
        /* il = */ &ignored_int,
        /* iu = */ &ignored_int,
        /* ns = */ &ns,
        /* s = */ nullptr,
        /* u = */ nullptr,
        /* ldu = */ &ldu,
        /* vt = */ nullptr,
        /* ldvt = */ &ldvt,
        /* work = */ &workspace_dimension,
        /* lwork = */ &lwork_query,
        /* iwork = */ static_cast<int*>(iwork.buffer.raw_ptr()),
        /* info = */ &info);

    if (info != 0) {
      std::stringstream ss;
      ss << "[SVD::eval_cpu] workspace calculation failed with code " << info;
      throw std::runtime_error(ss.str());
    }

    const int lwork = workspace_dimension;
    auto scratch = array::Data{allocator::malloc_or_wait(sizeof(T) * lwork)};

    // Loop over matrices.
    for (int i = 0; i < num_matrices; i++) {
      gesvdx<T>(
          /* jobu = */ job_u,
          /* jobvt = */ job_vt,
          /* range = */ range,
          // M and N are swapped since lapack expects column-major.
          /* m = */ &N,
          /* n = */ &M,
          /* a = */ in_ptr + M * N * i,
          /* lda = */ &lda,
          /* vl = */ &ignored_float,
          /* vu = */ &ignored_float,
          /* il = */ &ignored_int,
          /* iu = */ &ignored_int,
          /* ns = */ &ns,
          /* s = */ s_ptr + K * i,
          // According to the identity above, lapack will write Vᵀᵀ as U.
          /* u = */ vt_ptr ? vt_ptr + N * N * i : nullptr,
          /* ldu = */ &ldu,
          // According to the identity above, lapack will write Uᵀ as Vᵀ.
          /* vt = */ u_ptr ? u_ptr + M * M * i : nullptr,
          /* ldvt = */ &ldvt,
          /* work = */ static_cast<T*>(scratch.buffer.raw_ptr()),
          /* lwork = */ &lwork,
          /* iwork = */ static_cast<int*>(iwork.buffer.raw_ptr()),
          /* info = */ &info);

      if (info != 0) {
        std::stringstream ss;
        ss << "svd_impl: sgesvdx_ failed with code " << info;
        throw std::runtime_error(ss.str());
      }

      if (ns != K) {
        std::stringstream ss;
        ss << "svd_impl: expected " << K << " singular values, but " << ns
           << " were computed.";
        throw std::runtime_error(ss.str());
      }
    }
  });
  encoder.add_temporary(in);
}
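
The two-phase LAPACK pattern used above, shown standalone with the plain sgesvd_ driver rather than the gesvdx wrapper from the diff; the extern "C" prototype below is the standard Fortran interface and must match the LAPACK you link against, so treat it as an assumption. The first call with lwork = -1 only reports the optimal workspace size; the second call does the factorization. The column-major layout is also why the diff exploits the identity that if A = U S Vᵀ then Aᵀ = V S Uᵀ.

#include <algorithm>
#include <vector>

extern "C" void sgesvd_(
    const char* jobu, const char* jobvt, const int* m, const int* n,
    float* a, const int* lda, float* s, float* u, const int* ldu,
    float* vt, const int* ldvt, float* work, const int* lwork, int* info);

int main() {
  int m = 2, n = 2, lda = 2, ldu = 1, ldvt = 1, info = 0;
  std::vector<float> a = {3.f, 0.f, 0.f, 4.f}; // column-major 2x2
  std::vector<float> s(std::min(m, n));

  float work_query = 0.f;
  int lwork = -1; // workspace query: no factorization happens yet
  sgesvd_("N", "N", &m, &n, a.data(), &lda, s.data(), nullptr, &ldu,
          nullptr, &ldvt, &work_query, &lwork, &info);

  lwork = static_cast<int>(work_query);
  std::vector<float> work(lwork);
  sgesvd_("N", "N", &m, &n, a.data(), &lda, s.data(), nullptr, &ldu,
          nullptr, &ldvt, work.data(), &lwork, &info);
  return info; // on success s holds {4, 3}
}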

template <typename T>
void compute_svd(
    const array& a,
    bool compute_uv,
    std::vector<array>& outputs,
    Stream stream) {}

void SVD::eval_cpu(
    const std::vector<array>& inputs,
    std::vector<array>& outputs) {
  switch (inputs[0].dtype()) {
    case float32:
      compute_svd<float>(inputs[0], compute_uv_, outputs);
      svd_impl<float>(inputs[0], outputs, compute_uv_, stream());
      break;
    case float64:
      compute_svd<double>(inputs[0], compute_uv_, outputs);
      svd_impl<double>(inputs[0], outputs, compute_uv_, stream());
      break;
    default:
      throw std::runtime_error(

@@ -5,6 +5,8 @@
#include "mlx/array.h"
#include "mlx/backend/common/ternary.h"
#include "mlx/backend/common/utils.h"
#include "mlx/backend/cpu/encoder.h"
#include "mlx/primitives.h"

namespace mlx::core {

@@ -53,22 +55,18 @@ void ternary_op_dims(

template <typename T1, typename T2, typename T3, typename U, typename Op>
void ternary_op_dispatch_dims(
    const array& a,
    const array& b,
    const array& c,
    array& out,
    Op op) {
  auto [shape, strides] = collapse_contiguous_dims(
      a.shape(), {a.strides(), b.strides(), c.strides(), out.strides()});
    const T1* a_ptr,
    const T2* b_ptr,
    const T3* c_ptr,
    U* out_ptr,
    Op op,
    size_t size,
    Shape& shape,
    std::vector<Strides>& strides) {
  const auto& a_strides = strides[0];
  const auto& b_strides = strides[1];
  const auto& c_strides = strides[2];
  const auto& out_strides = strides[3];

  const T1* a_ptr = a.data<T1>();
  const T2* b_ptr = b.data<T2>();
  const T3* c_ptr = c.data<T3>();
  U* out_ptr = out.data<T3>();
  int ndim = shape.size();
  switch (ndim) {
    case 1:

@@ -105,7 +103,7 @@ void ternary_op_dispatch_dims(
  ContiguousIterator b_it(shape, b_strides, ndim - 2);
  ContiguousIterator c_it(shape, c_strides, ndim - 2);
  auto stride = out_strides[ndim - 3];
  for (size_t elem = 0; elem < a.size(); elem += stride) {
  for (size_t elem = 0; elem < size; elem += stride) {
    ternary_op_dims<T1, T2, T3, U, Op, 2>(
        a_ptr + a_it.loc,
        b_ptr + b_it.loc,

@@ -134,23 +132,53 @@ void ternary_op(
  TernaryOpType topt = get_ternary_op_type(a, b, c);
  set_ternary_op_output_data(a, b, c, out, topt);

  // The full computation is scalar-scalar-scalar so we call the base op once.
  auto& encoder = cpu::get_command_encoder(out.primitive().stream());
  encoder.set_input_array(a);
  encoder.set_input_array(b);
  encoder.set_input_array(c);
  encoder.set_output_array(out);

  const T1* a_ptr = a.data<T1>();
  const T2* b_ptr = b.data<T2>();
  const T3* c_ptr = c.data<T3>();
  U* out_ptr = out.data<U>();

  if (topt == TernaryOpType::ScalarScalarScalar) {
    *(out.data<U>()) = op(*a.data<T1>(), *b.data<T2>(), *c.data<T3>());
    encoder.dispatch(
        [a_ptr, b_ptr, c_ptr, out_ptr, op = std::move(op)]() mutable {
          *out_ptr = op(*a_ptr, *b_ptr, *c_ptr);
        });
  } else if (topt == TernaryOpType::VectorVectorVector) {
    const T1* a_ptr = a.data<T1>();
    const T2* b_ptr = b.data<T2>();
    const T3* c_ptr = c.data<T3>();
    U* out_ptr = out.data<U>();
    for (size_t i = 0; i < out.size(); ++i) {
      *out_ptr = op(*a_ptr, *b_ptr, *c_ptr);
      a_ptr++;
      b_ptr++;
      c_ptr++;
      out_ptr++;
    }
    encoder.dispatch([a_ptr,
                      b_ptr,
                      c_ptr,
                      out_ptr,
                      op = std::move(op),
                      size = out.size()]() mutable {
      for (size_t i = 0; i < size; ++i) {
        *out_ptr = op(*a_ptr, *b_ptr, *c_ptr);
        a_ptr++;
        b_ptr++;
        c_ptr++;
        out_ptr++;
      }
    });
  } else {
    ternary_op_dispatch_dims<T1, T2, T3, U>(a, b, c, out, op);
    auto [shape, strides] = collapse_contiguous_dims(
        a.shape(), {a.strides(), b.strides(), c.strides(), out.strides()});
    encoder.dispatch(
        [a_ptr,
         b_ptr,
         c_ptr,
         out_ptr,
         op = std::move(op),
         size = out.size(),
         shape = std::move(shape),
         strides = std::move(strides)]() mutable {
          ternary_op_dispatch_dims<T1, T2, T3, U>(
              a_ptr, b_ptr, c_ptr, out_ptr, op, size, shape, strides);
        });
  }
}
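
A condensed sketch of the fully contiguous ternary fast path above (the VectorVectorVector case), standard library only and illustrative rather than MLX code: every operand advances by one element per step, so a single flat loop suffices.

#include <cstddef>

template <typename T, typename Op>
void ternary_flat(const bool* a, const T* b, const T* c, T* out, size_t n, Op op) {
  // all four pointers walk in lockstep over contiguous memory
  for (size_t i = 0; i < n; ++i) {
    out[i] = op(a[i], b[i], c[i]);
  }
}

int main() {
  bool cond[3] = {true, false, true};
  float x[3] = {1.f, 2.f, 3.f}, y[3] = {10.f, 20.f, 30.f}, z[3];
  ternary_flat(cond, x, y, z, 3,
               [](bool c, float t, float f) { return c ? t : f; });
  return z[1] == 20.f ? 0 : 1; // select picks y where cond is false
}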

@@ -5,67 +5,83 @@
#include "mlx/allocator.h"
#include "mlx/array.h"
#include "mlx/backend/common/utils.h"
#include "mlx/backend/cpu/encoder.h"
#include "mlx/backend/cpu/simd/simd.h"
#include "mlx/primitives.h"
#include "mlx/utils.h"

namespace mlx::core {

void set_unary_output_data(const array& in, array& out) {
  if (is_donatable(in, out)) {
    out.copy_shared_buffer(in);
  if (in.flags().contiguous) {
    if (is_donatable(in, out)) {
      out.copy_shared_buffer(in);
    } else {
      auto size = in.data_size();
      out.set_data(
          allocator::malloc_or_wait(size * out.itemsize()),
          size,
          in.strides(),
          in.flags());
    }
  } else {
    auto size = in.data_size();
    out.set_data(
        allocator::malloc_or_wait(size * out.itemsize()),
        size,
        in.strides(),
        in.flags());
    out.set_data(allocator::malloc_or_wait(out.nbytes()));
  }
}
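
A reduced sketch of the donation decision above, using a hypothetical toy array type rather than the MLX implementation (the real is_donatable also considers item sizes): a contiguous input whose buffer has no other owners can be handed to the output, skipping one allocation; otherwise the output gets fresh memory.

#include <memory>
#include <vector>

struct ToyArray {
  std::shared_ptr<std::vector<float>> buf;
  bool contiguous = true;
};

// donatable: sole owner of a contiguous buffer, so in-place reuse is safe
bool is_donatable_toy(const ToyArray& in) {
  return in.contiguous && in.buf.use_count() == 1;
}

void set_output_toy(ToyArray& in, ToyArray& out) {
  if (is_donatable_toy(in)) {
    out.buf = in.buf; // donate: the unary op will overwrite the input buffer
  } else {
    out.buf = std::make_shared<std::vector<float>>(in.buf->size());
  }
}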

template <typename T, typename U = T, typename Op>
void unary_op(const T* a, U* out, Op op, size_t shape, size_t stride) {
void unary_op(const T* a, U* out, size_t shape, size_t stride) {
  for (size_t i = 0; i < shape; i += 1) {
    out[i] = op(*a);
    out[i] = Op{}(*a);
    a += stride;
  }
}

template <typename T, typename U = T, typename Op>
void unary_op(const array& a, array& out, Op op) {
  const T* a_ptr = a.data<T>();
  if (a.flags().contiguous) {
    set_unary_output_data(a, out);
    U* dst = out.data<U>();
    constexpr int N = simd::max_size<T>;
    size_t size = a.data_size();
    while (size >= N) {
      simd::store(dst, op(simd::load<T, N>(a_ptr)));
      size -= N;
      a_ptr += N;
      dst += N;
void unary_op(const array& a, array& out, Op) {
  set_unary_output_data(a, out);
  const T* src = a.data<T>();
  U* dst = out.data<U>();
  auto& encoder = cpu::get_command_encoder(out.primitive().stream());
  encoder.set_input_array(a);
  encoder.set_output_array(out);

  encoder.dispatch([src,
                    dst,
                    contig = a.flags().contiguous,
                    data_size = a.data_size(),
                    size = a.size(),
                    shapes = a.shape(),
                    strides = a.strides()]() mutable {
    auto ndim = shapes.size();
    if (contig) {
      constexpr int N = simd::max_size<T>;
      while (data_size >= N) {
        simd::store(dst, Op{}(simd::load<T, N>(src)));
        data_size -= N;
        src += N;
        dst += N;
      }
      while (data_size > 0) {
        *dst = Op{}(*src);
        data_size--;
        dst++;
        src++;
      }
    } else {
      size_t shape = ndim > 0 ? shapes.back() : 1;
      size_t stride = ndim > 0 ? strides.back() : 1;
      if (ndim <= 1) {
        unary_op<T, U, Op>(src, dst, shape, stride);
        return;
      }
      auto it = ContiguousIterator(shapes, strides, ndim - 1);
      for (size_t elem = 0; elem < size; elem += shape) {
        unary_op<T, U, Op>(src + it.loc, dst + elem, shape, stride);
        it.step();
      }
    }
    while (size > 0) {
      *dst = op(*a_ptr);
      size--;
      dst++;
      a_ptr++;
    }
  } else {
    out.set_data(allocator::malloc_or_wait(out.nbytes()));
    U* dst = out.data<U>();
    size_t shape = a.ndim() > 0 ? a.shape(-1) : 1;
    size_t stride = a.ndim() > 0 ? a.strides(-1) : 1;
    if (a.ndim() <= 1) {
      unary_op(a_ptr, dst, op, shape, stride);
      return;
    }
    ContiguousIterator it(a.shape(), a.strides(), a.ndim() - 1);
    for (size_t elem = 0; elem < a.size(); elem += shape) {
      unary_op(a_ptr + it.loc, dst + elem, op, shape, stride);
      it.step();
    }
  }
  });
}
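
The vector-then-tail loop structure above, sketched without the mlx simd wrappers; a fixed width of 4 floats stands in for simd::max_size<T>, which is an assumption of this sketch. The main loop handles full lanes and the scalar loop finishes whatever remains.

#include <cstddef>

void negate(const float* src, float* dst, size_t size) {
  constexpr size_t N = 4; // stand-in for the SIMD width
  size_t i = 0;
  for (; i + N <= size; i += N) {
    for (size_t j = 0; j < N; ++j) { // a real kernel issues one vector op here
      dst[i + j] = -src[i + j];
    }
  }
  for (; i < size; ++i) { // scalar tail for the leftover elements
    dst[i] = -src[i];
  }
}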

template <typename Op>