awni's commit files

Author: Awni Hannun
Date:   2023-11-29 10:30:41 -08:00
parent  e411fcae68
commit  8ca7f9e8e9
130 changed files with 30159 additions and 0 deletions

mlx/3rdparty/.clang-format (vendored, new file, 2 lines)

@@ -0,0 +1,2 @@
DisableFormat: true
SortIncludes: Never

mlx/allocator.cpp (new file, 48 lines)

@@ -0,0 +1,48 @@
#include <cstdlib>
#include <sstream>
#include "mlx/allocator.h"
#include "mlx/scheduler.h"
namespace mlx::core::allocator {
Buffer malloc(size_t size) {
auto buffer = allocator().malloc(size);
if (size && !buffer.ptr()) {
std::ostringstream msg;
msg << "[malloc] Unable to allocate " << size << " bytes.";
throw std::runtime_error(msg.str());
}
return buffer;
}
void free(Buffer buffer) {
return allocator().free(buffer);
}
Buffer CommonAllocator::malloc(size_t size) {
return Buffer{std::malloc(size)};
}
void CommonAllocator::free(Buffer buffer) {
std::free(buffer.raw_ptr());
}
Buffer malloc_or_wait(size_t size) {
auto buffer = allocator().malloc(size);
while (size && !buffer.ptr() && scheduler::n_active_tasks() > 0) {
scheduler::wait_for_one();
buffer = allocator().malloc(size);
}
if (size && !buffer.ptr()) {
std::ostringstream msg;
msg << "[malloc_or_wait] Unable to allocate " << size << " bytes.";
throw std::runtime_error(msg.str());
}
return buffer;
}
} // namespace mlx::core::allocator
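A minimal usage sketch of the allocator interface above (not part of the commit): allocate through malloc_or_wait so the request can retry while queued tasks release memory, and free the buffer through the matching free call so the allocator's bookkeeping stays consistent. The size and the fill step are illustrative.

#include "mlx/allocator.h"

void example_allocate() {
  using namespace mlx::core;
  // Request 1 MiB; malloc_or_wait retries while other scheduled tasks
  // may still free memory, and only throws once nothing is pending.
  allocator::Buffer buf = allocator::malloc_or_wait(1 << 20);
  // ... write through buf.raw_ptr() ...
  allocator::free(buf);
}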

mlx/array.h (new file, 436 lines)

@@ -0,0 +1,436 @@
#pragma once
#include <algorithm>
#include <cstdint>
#include <functional>
#include <memory>
#include <vector>
#include "mlx/allocator.h"
#include "mlx/dtype.h"
namespace mlx::core {
// Forward declaration
class Primitive;
using deleter_t = std::function<void(allocator::Buffer)>;
class array {
/* An array is really a node in a graph. It contains a shared ArrayDesc
* object */
public:
/** Construct a scalar array with zero dimensions. */
template <typename T>
explicit array(T val, Dtype dtype = TypeToDtype<T>());
/* Special case since std::complex can't be implicitly converted to other
* types. */
explicit array(const std::complex<float>& val, Dtype dtype = complex64);
template <typename It>
array(
It data,
const std::vector<int>& shape,
Dtype dtype =
TypeToDtype<typename std::iterator_traits<It>::value_type>());
template <typename T>
array(std::initializer_list<T> data, Dtype dtype = TypeToDtype<T>());
/* Special case so empty lists default to float32. */
array(std::initializer_list<float> data);
template <typename T>
array(
std::initializer_list<T> data,
const std::vector<int>& shape,
Dtype dtype = TypeToDtype<T>());
/* Build an array from a buffer */
array(
allocator::Buffer data,
const std::vector<int>& shape,
Dtype dtype,
deleter_t deleter = allocator::free);
/** Assignment to rvalue does not compile. */
array& operator=(const array& other) && = delete;
array& operator=(array&& other) && = delete;
/** Default copy and move constructors otherwise. */
array& operator=(array&& other) & = default;
array(const array& other) = default;
array(array&& other) = default;
array& operator=(const array& other) & {
if (this->id() != other.id()) {
this->array_desc_ = other.array_desc_;
}
return *this;
};
/** The size of the array's datatype in bytes. */
size_t itemsize() const {
return size_of(dtype());
};
/** The number of elements in the array. */
size_t size() const {
return array_desc_->size;
};
/** The number of bytes in the array. */
size_t nbytes() const {
return size() * itemsize();
};
/** The number of dimensions of the array. */
size_t ndim() const {
return array_desc_->shape.size();
};
/** The shape of the array as a vector of integers. */
const std::vector<int>& shape() const {
return array_desc_->shape;
};
/**
* Get the size of the corresponding dimension.
*
* This function supports negative indexing and provides
* bounds checking. */
int shape(int dim) const {
return shape().at(dim < 0 ? dim + ndim() : dim);
};
/** The strides of the array. */
const std::vector<size_t>& strides() const {
return array_desc_->strides;
};
/** Get the array's data type. */
Dtype dtype() const {
return array_desc_->dtype;
};
/** Evaluate the array. */
void eval(bool retain_graph = false);
/** Get the value from a scalar array. */
template <typename T>
T item(bool retain_graph = false);
struct ArrayIterator {
using iterator_category = std::random_access_iterator_tag;
using difference_type = size_t;
using value_type = const array;
using reference = value_type;
explicit ArrayIterator(const array& arr, int idx = 0) : arr(arr), idx(idx) {
if (arr.ndim() == 0) {
throw std::invalid_argument("Cannot iterate over 0-d array.");
}
}
reference operator*() const;
ArrayIterator& operator+(difference_type diff) {
idx += diff;
return *this;
}
ArrayIterator& operator++() {
idx++;
return *this;
}
friend bool operator==(const ArrayIterator& a, const ArrayIterator& b) {
return a.arr.id() == b.arr.id() && a.idx == b.idx;
};
friend bool operator!=(const ArrayIterator& a, const ArrayIterator& b) {
return !(a == b);
};
private:
int idx;
const array& arr;
};
ArrayIterator begin() const {
return ArrayIterator(*this);
}
ArrayIterator end() const {
return ArrayIterator(*this, shape(0));
}
/**
* The following methods should be used with caution.
* They are intended for use by the backend implementation and the
* API may change.
*/
array(
const std::vector<int>& shape,
Dtype dtype,
std::unique_ptr<Primitive> primitive,
const std::vector<array>& inputs);
/** A unique identifier for an array. */
std::uintptr_t id() const {
return reinterpret_cast<std::uintptr_t>(array_desc_.get());
}
struct Data {
allocator::Buffer buffer;
deleter_t d;
Data(allocator::Buffer buffer, deleter_t d = allocator::free)
: buffer(buffer), d(d){};
// Not copyable
Data(const Data& d) = delete;
Data& operator=(const Data& d) = delete;
~Data() {
d(buffer);
}
};
struct Flags {
// True if there are no gaps in the underlying data. Each item
// in the underlying data buffer belongs to at least one index.
bool contiguous : 1;
bool row_contiguous : 1;
bool col_contiguous : 1;
};
/** The array's primitive. */
Primitive& primitive() const {
return *(array_desc_->primitive);
};
/** Check if the array has an attached primitive or is a leaf node. */
bool has_primitive() const {
return array_desc_->primitive != nullptr;
};
/** The array's inputs. */
const std::vector<array>& inputs() const {
return array_desc_->inputs;
};
/** A non-const reference to the array's inputs so that they can be used to
* edit the graph. */
std::vector<array>& editable_inputs() {
return array_desc_->inputs;
}
/** Detach the array from the graph. */
void detach();
/** Get the Flags bit-field. */
const Flags& flags() const {
return array_desc_->flags;
};
/** The size (in elements) of the underlying buffer the array points to. */
size_t data_size() const {
return array_desc_->data_size;
};
allocator::Buffer& buffer() {
return array_desc_->data->buffer;
};
const allocator::Buffer& buffer() const {
return array_desc_->data->buffer;
};
template <typename T>
T* data() {
return static_cast<T*>(array_desc_->data_ptr);
};
template <typename T>
const T* data() const {
return static_cast<T*>(array_desc_->data_ptr);
};
// Check if the array has been evaluated
bool is_evaled() const {
return array_desc_->data != nullptr;
}
// Mark the array as a tracer array (true) or not.
void set_tracer(bool is_tracer) {
array_desc_->is_tracer = is_tracer;
}
// Check if the array is a tracer array
bool is_tracer() const {
return array_desc_->is_tracer;
}
void set_data(allocator::Buffer buffer, deleter_t d = allocator::free);
void set_data(
allocator::Buffer buffer,
size_t data_size,
std::vector<size_t> strides,
Flags flags,
deleter_t d = allocator::free);
void copy_shared_buffer(
const array& other,
const std::vector<size_t>& strides,
Flags flags,
size_t data_size,
size_t offset = 0);
void copy_shared_buffer(const array& other);
void overwrite_descriptor(const array& other) {
array_desc_ = other.array_desc_;
}
private:
// Initialize the array's data
template <typename It>
void init(const It src);
struct ArrayDesc {
std::vector<int> shape;
std::vector<size_t> strides;
size_t size;
Dtype dtype;
std::unique_ptr<Primitive> primitive{nullptr};
// Indicates an array is being used in a graph transform
// and should not be detached from the graph
bool is_tracer{false};
// This is a shared pointer so that *different* arrays
// can share the underlying data buffer.
std::shared_ptr<Data> data{nullptr};
// Properly offset data pointer
void* data_ptr{nullptr};
// The size in elements of the data buffer the array accesses
// This can be different than the actual size of the array if it
// has been broadcast or irregularly strided.
size_t data_size;
// Contains useful meta data about the array
Flags flags;
std::vector<array> inputs;
explicit ArrayDesc(const std::vector<int>& shape, Dtype dtype);
explicit ArrayDesc(
const std::vector<int>& shape,
Dtype dtype,
std::unique_ptr<Primitive> primitive,
const std::vector<array>& inputs);
~ArrayDesc();
};
// The ArrayDesc contains the details of the materialized array, including
// the shape, strides, and data type. It also includes the primitive, which
// knows how to compute the array's data from its inputs, and the list of
// the array's inputs for the primitive.
std::shared_ptr<ArrayDesc> array_desc_{nullptr};
};
template <typename T>
array::array(T val, Dtype dtype /* = TypeToDtype<T>() */)
: array_desc_(std::make_shared<ArrayDesc>(std::vector<int>{}, dtype)) {
init(&val);
}
template <typename It>
array::array(
It data,
const std::vector<int>& shape,
Dtype dtype /* = TypeToDtype<typename std::iterator_traits<It>::value_type>() */) :
array_desc_(std::make_shared<ArrayDesc>(shape, dtype)) {
init(data);
}
template <typename T>
array::array(
std::initializer_list<T> data,
Dtype dtype /* = TypeToDtype<T>() */)
: array_desc_(std::make_shared<ArrayDesc>(
std::vector<int>{static_cast<int>(data.size())},
dtype)) {
init(data.begin());
}
template <typename T>
array::array(
std::initializer_list<T> data,
const std::vector<int>& shape,
Dtype dtype /* = TypeToDtype<T>() */)
: array_desc_(std::make_shared<ArrayDesc>(shape, dtype)) {
if (data.size() != size()) {
throw std::invalid_argument(
"Data size and provided shape mismatch in array construction.");
}
init(data.begin());
}
template <typename T>
T array::item(bool retain_graph /* = false */) {
if (size() != 1) {
throw std::invalid_argument("item can only be called on arrays of size 1.");
}
eval(retain_graph);
return *data<T>();
}
template <typename It>
void array::init(It src) {
set_data(allocator::malloc(size() * size_of(dtype())));
switch (dtype()) {
case bool_:
std::copy(src, src + size(), data<bool>());
break;
case uint8:
std::copy(src, src + size(), data<uint8_t>());
break;
case uint16:
std::copy(src, src + size(), data<uint16_t>());
break;
case uint32:
std::copy(src, src + size(), data<uint32_t>());
break;
case uint64:
std::copy(src, src + size(), data<uint64_t>());
break;
case int8:
std::copy(src, src + size(), data<int8_t>());
break;
case int16:
std::copy(src, src + size(), data<int16_t>());
break;
case int32:
std::copy(src, src + size(), data<int32_t>());
break;
case int64:
std::copy(src, src + size(), data<int64_t>());
break;
case float16:
std::copy(src, src + size(), data<float16_t>());
break;
case float32:
std::copy(src, src + size(), data<float>());
break;
case bfloat16:
std::copy(src, src + size(), data<bfloat16_t>());
break;
case complex64:
std::copy(src, src + size(), data<complex64_t>());
break;
}
}
} // namespace mlx::core
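A small illustration of the public constructors and accessors declared above; it assumes the default dtype mappings from mlx/dtype.h behave as the declarations suggest (for example, a float scalar becomes float32) and is a sketch rather than code from the commit.

#include "mlx/array.h"

void example_array() {
  using namespace mlx::core;
  array a(3.0f);                          // scalar: ndim() == 0, dtype float32
  array b({1, 2, 3, 4}, {2, 2}, int32);   // 2x2 built from an initializer list
  size_t bytes = b.nbytes();              // size() * itemsize() == 4 * 4
  int rows = b.shape(0);                  // bounds-checked
  int cols = b.shape(-1);                 // negative indexing supported
  (void)a; (void)bytes; (void)rows; (void)cols;
}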


@@ -0,0 +1,9 @@
target_sources(
mlx
PRIVATE
${CMAKE_CURRENT_SOURCE_DIR}/conv.cpp
${CMAKE_CURRENT_SOURCE_DIR}/matmul.cpp
${CMAKE_CURRENT_SOURCE_DIR}/primitives.cpp
${CMAKE_CURRENT_SOURCE_DIR}/reduce.cpp
${CMAKE_CURRENT_SOURCE_DIR}/softmax.cpp
)


@@ -0,0 +1,167 @@
#include <cassert>
#include <vecLib/BNNS/bnns.h>
#include <vecLib/cblas_new.h>
#include "mlx/backend/accelerate/utils.h"
#include "mlx/backend/common/copy.h"
#include "mlx/primitives.h"
#include "mlx/utils.h"
namespace mlx::core {
namespace {
std::tuple<bool, size_t, array> check_transpose(const array& arr) {
auto stx = arr.strides()[arr.ndim() - 2];
auto sty = arr.strides()[arr.ndim() - 1];
if (stx == arr.shape(-1) && sty == 1) {
return std::make_tuple(false, stx, arr);
} else if (stx == 1 && sty == arr.shape(-2)) {
return std::make_tuple(true, sty, arr);
} else {
array arr_copy(arr.shape(), arr.dtype(), nullptr, {});
copy(arr, arr_copy, CopyType::General);
size_t stx = arr.shape(-1);
return std::make_tuple(false, stx, arr_copy);
}
}
inline void matmul_cblas(const array& a_pre, const array& b_pre, array& out) {
if (out.dtype() != float32) {
throw std::runtime_error(
"[matmul_cblas] on CPU currently only supports float32");
}
out.set_data(allocator::malloc_or_wait(out.nbytes()));
auto [a_transposed, lda, a] = check_transpose(a_pre);
auto [b_transposed, ldb, b] = check_transpose(b_pre);
size_t M = a.shape(-2);
size_t N = b.shape(-1);
size_t K = a.shape(-1);
for (int i = 0; i < (a.size() / (M * K)); ++i) {
cblas_sgemm(
CblasRowMajor,
a_transposed ? CblasTrans : CblasNoTrans, // transA
b_transposed ? CblasTrans : CblasNoTrans, // transB
M,
N,
K,
1.0f, // alpha
a.data<float>() + elem_to_loc(M * K * i, a.shape(), a.strides()),
lda,
b.data<float>() + elem_to_loc(K * N * i, b.shape(), b.strides()),
ldb,
0.0f, // beta
out.data<float>() + M * N * i,
out.shape(-1) // ldc
);
}
}
inline void matmul_bnns(const array& a_pre, const array& b_pre, array& out) {
// TODO: Update to utilize BNNS broadcasting
out.set_data(allocator::malloc_or_wait(out.nbytes()));
auto [a_transposed, lda, a] = check_transpose(a_pre);
auto [b_transposed, ldb, b] = check_transpose(b_pre);
size_t M = a.shape(-2);
size_t N = b.shape(-1);
size_t K = a.shape(-1);
BNNSDataType bnns_dtype = to_bnns_dtype(out.dtype());
const BNNSLayerParametersBroadcastMatMul gemm_params{
/* float alpha = */ 1.0,
/* float beta = */ 0.0,
/* bool transA = */ a_transposed,
/* bool transB = */ b_transposed,
/* bool quadratic = */ false,
/* bool a_is_weights = */ false,
/* bool b_is_weights = */ false,
/* BNNSNDArrayDescriptor iA_desc = */
BNNSNDArrayDescriptor{
/* BNNSNDArrayFlags flags = */ BNNSNDArrayFlagBackpropSet,
/* BNNSDataLayout layout = */ BNNSDataLayoutRowMajorMatrix,
/* size_t size[BNNS_MAX_TENSOR_DIMENSION] = */
{lda, (M * K) / lda, 0, 0, 0, 0, 0, 0},
/* size_t stride[BNNS_MAX_TENSOR_DIMENSION] = */
{1, lda, 0, 0, 0, 0, 0, 0},
/* void * _Nullable data = */ nullptr,
/* BNNSDataType data_type = */ bnns_dtype,
/* void * _Nullable table_data = */ nullptr,
/* BNNSDataType table_data_type = */ bnns_dtype,
/* float data_scale = */ 1.0,
/* float data_bias = */ 0.0,
},
/* BNNSNDArrayDescriptor iB_desc = */
BNNSNDArrayDescriptor{
/* BNNSNDArrayFlags flags = */ BNNSNDArrayFlagBackpropSet,
/* BNNSDataLayout layout = */ BNNSDataLayoutRowMajorMatrix,
/* size_t size[BNNS_MAX_TENSOR_DIMENSION] = */
{ldb, (K * N) / ldb, 0, 0, 0, 0, 0, 0},
/* size_t stride[BNNS_MAX_TENSOR_DIMENSION] = */
{1, ldb, 0, 0, 0, 0, 0, 0},
/* void * _Nullable data = */ nullptr,
/* BNNSDataType data_type = */ bnns_dtype,
/* void * _Nullable table_data = */ nullptr,
/* BNNSDataType table_data_type = */ bnns_dtype,
/* float data_scale = */ 1.0,
/* float data_bias = */ 0.0,
},
/* BNNSNDArrayDescriptor o_desc = */
BNNSNDArrayDescriptor{
/* BNNSNDArrayFlags flags = */ BNNSNDArrayFlagBackpropSet,
/* BNNSDataLayout layout = */ BNNSDataLayoutRowMajorMatrix,
/* size_t size[BNNS_MAX_TENSOR_DIMENSION] = */
{N, M, 0, 0, 0, 0, 0, 0},
/* size_t stride[BNNS_MAX_TENSOR_DIMENSION] = */
{1, N, 0, 0, 0, 0, 0, 0},
/* void * _Nullable data = */ nullptr,
/* BNNSDataType data_type = */ bnns_dtype,
/* void * _Nullable table_data = */ nullptr,
/* BNNSDataType table_data_type = */ bnns_dtype,
/* float data_scale = */ 1.0,
/* float data_bias = */ 0.0,
},
};
auto bnns_filter =
BNNSFilterCreateLayerBroadcastMatMul(&gemm_params, nullptr);
for (int i = 0; i < (a.size() / (M * K)); ++i) {
BNNSFilterApplyTwoInput(
bnns_filter,
a.data<uint8_t>() +
elem_to_loc(M * K * i, a.shape(), a.strides()) * a.itemsize(),
b.data<uint8_t>() +
elem_to_loc(K * N * i, b.shape(), b.strides()) * b.itemsize(),
out.data<uint8_t>() + M * N * i * out.itemsize());
}
BNNSFilterDestroy(bnns_filter);
}
} // namespace
void Matmul::eval_cpu(const std::vector<array>& inputs, array& out) {
if (out.dtype() == float32) {
return matmul_cblas(inputs[0], inputs[1], out);
}
return matmul_bnns(inputs[0], inputs[1], out);
}
} // namespace mlx::core
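For reference, a hedged, self-contained sketch of the CblasRowMajor convention that matmul_cblas relies on: check_transpose maps a row-contiguous matrix to CblasNoTrans with lda equal to its last-dimension size, and a column-contiguous one to CblasTrans. The header and the buffers here are illustrative, not from the commit.

#include <vecLib/cblas_new.h>

void example_sgemm() {
  // A is 2x3, row-major (lda = 3). The logical 3x2 B is stored as its
  // transpose Bt (2x3, row-major), so it is passed with CblasTrans, ldb = 3.
  float A[6] = {1, 2, 3, 4, 5, 6};
  float Bt[6] = {1, 2, 3, 4, 5, 6};
  float C[4] = {};  // 2x2 result; expected {14, 32, 32, 77}
  cblas_sgemm(
      CblasRowMajor,
      CblasNoTrans,
      CblasTrans,
      2,      // M
      2,      // N
      3,      // K
      1.0f,   // alpha
      A, 3,   // lda
      Bt, 3,  // ldb
      0.0f,   // beta
      C, 2);  // ldc
}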


@@ -0,0 +1,672 @@
#include <cassert>
#include <cmath>
#include <vecLib/vDSP.h>
#include <vecLib/vForce.h>
#include "mlx/allocator.h"
#include "mlx/backend/common/binary.h"
#include "mlx/backend/common/copy.h"
#include "mlx/backend/common/unary.h"
#include "mlx/primitives.h"
#define DEFAULT(primitive) \
void primitive::eval_cpu(const std::vector<array>& inputs, array& out) { \
primitive::eval(inputs, out); \
}
namespace mlx::core {
// Use the default implementation for the following primitives
DEFAULT(Arange)
DEFAULT(ArgPartition)
DEFAULT(ArgReduce)
DEFAULT(ArgSort)
DEFAULT(AsStrided)
DEFAULT(Broadcast)
DEFAULT(Concatenate)
DEFAULT(Copy)
DEFAULT(Equal)
DEFAULT(Erf)
DEFAULT(ErfInv)
DEFAULT(FFT)
DEFAULT(Gather)
DEFAULT(Greater)
DEFAULT(GreaterEqual)
DEFAULT(Less)
DEFAULT(LessEqual)
DEFAULT(Load)
DEFAULT(LogicalNot)
DEFAULT(LogAddExp)
DEFAULT(NotEqual)
DEFAULT(Pad)
DEFAULT(Partition)
DEFAULT(RandomBits)
DEFAULT(Reshape)
DEFAULT(Scatter)
DEFAULT(Sigmoid)
DEFAULT(Sign)
DEFAULT(Slice)
DEFAULT(Sort)
DEFAULT(StopGradient)
DEFAULT(Transpose)
void Abs::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
auto& in = inputs[0];
if (in.dtype() == float32 && in.flags().contiguous) {
auto size = in.data_size();
out.set_data(
allocator::malloc_or_wait(size * out.itemsize()),
size,
in.strides(),
in.flags());
vDSP_vabs(in.data<float>(), 1, out.data<float>(), 1, size);
} else if (in.dtype() == int32 && in.flags().contiguous) {
auto size = in.data_size();
out.set_data(
allocator::malloc_or_wait(size * out.itemsize()),
size,
in.strides(),
in.flags());
vDSP_vabsi(in.data<int>(), 1, out.data<int>(), 1, size);
} else if (is_unsigned(in.dtype())) {
// No-op for unsigned types
out.copy_shared_buffer(in);
} else {
unary(in, out, AbsOp());
}
}
void Add::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 2);
auto& a = inputs[0];
auto& b = inputs[1];
if (a.dtype() == float32) {
binary(
a,
b,
out,
[](auto x, auto y) { return x + y; },
[](const auto* s, const auto* vec, auto* o, auto n) {
vDSP_vsadd((const float*)vec, 1, (const float*)s, (float*)o, 1, n);
},
[](const auto* vec, const auto* s, auto* o, auto n) {
vDSP_vsadd((const float*)vec, 1, (const float*)s, (float*)o, 1, n);
},
[](const auto* a, const auto* b, auto* o, auto n) {
vDSP_vadd((const float*)a, 1, (const float*)b, 1, (float*)o, 1, n);
});
} else if (a.dtype() == int32) {
binary(
a,
b,
out,
[](auto x, auto y) { return x + y; },
[](const auto* s, const auto* vec, auto* o, auto n) {
vDSP_vsaddi((const int*)vec, 1, (const int*)s, (int*)o, 1, n);
},
[](const auto* vec, const auto* s, auto* o, auto n) {
vDSP_vsaddi((const int*)vec, 1, (const int*)s, (int*)o, 1, n);
},
[](const auto* a, const auto* b, auto* o, auto n) {
vDSP_vaddi((const int*)a, 1, (const int*)b, 1, (int*)o, 1, n);
});
} else {
binary(a, b, out, [](auto x, auto y) { return x + y; });
}
}
void ArcCos::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
const auto& in = inputs[0];
if (out.dtype() == float32 && in.flags().contiguous) {
int size = in.data_size();
out.set_data(
allocator::malloc_or_wait(size * out.itemsize()),
size,
in.strides(),
in.flags());
vvacosf(out.data<float>(), in.data<float>(), &size);
} else {
eval(inputs, out);
}
}
void ArcCosh::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
const auto& in = inputs[0];
if (out.dtype() == float32 && in.flags().contiguous) {
int size = in.data_size();
out.set_data(
allocator::malloc_or_wait(size * out.itemsize()),
size,
in.strides(),
in.flags());
vvacoshf(out.data<float>(), in.data<float>(), &size);
} else {
eval(inputs, out);
}
}
void ArcSin::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
const auto& in = inputs[0];
if (out.dtype() == float32 && in.flags().contiguous) {
int size = in.data_size();
out.set_data(
allocator::malloc_or_wait(size * out.itemsize()),
size,
in.strides(),
in.flags());
vvasinf(out.data<float>(), in.data<float>(), &size);
} else {
eval(inputs, out);
}
}
void ArcSinh::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
const auto& in = inputs[0];
if (out.dtype() == float32 && in.flags().contiguous) {
int size = in.data_size();
out.set_data(
allocator::malloc_or_wait(size * out.itemsize()),
size,
in.strides(),
in.flags());
vvasinhf(out.data<float>(), in.data<float>(), &size);
} else {
eval(inputs, out);
}
}
void ArcTan::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
const auto& in = inputs[0];
if (out.dtype() == float32 && in.flags().contiguous) {
int size = in.data_size();
out.set_data(
allocator::malloc_or_wait(size * out.itemsize()),
size,
in.strides(),
in.flags());
vvatanf(out.data<float>(), in.data<float>(), &size);
} else {
eval(inputs, out);
}
}
void ArcTanh::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
const auto& in = inputs[0];
if (out.dtype() == float32 && in.flags().contiguous) {
int size = in.data_size();
out.set_data(
allocator::malloc_or_wait(size * out.itemsize()),
size,
in.strides(),
in.flags());
vvatanhf(out.data<float>(), in.data<float>(), &size);
} else {
eval(inputs, out);
}
}
void AsType::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
auto& in = inputs[0];
if (in.flags().contiguous) {
auto allocfn = [&in, &out]() {
out.set_data(
allocator::malloc_or_wait(in.data_size() * out.itemsize()),
in.data_size(),
in.strides(),
in.flags());
};
// Use accelerate functions if possible
if (in.dtype() == float32 && out.dtype() == uint32) {
allocfn();
vDSP_vfixu32(
in.data<float>(), 1, out.data<uint32_t>(), 1, in.data_size());
return;
} else if (in.dtype() == float32 && out.dtype() == int32) {
allocfn();
vDSP_vfix32(in.data<float>(), 1, out.data<int32_t>(), 1, in.data_size());
return;
} else if (in.dtype() == uint32 && out.dtype() == float32) {
allocfn();
vDSP_vfltu32(
in.data<uint32_t>(), 1, out.data<float>(), 1, in.data_size());
return;
} else if (in.dtype() == int32 && out.dtype() == float32) {
allocfn();
vDSP_vflt32(in.data<int32_t>(), 1, out.data<float>(), 1, in.data_size());
return;
}
}
eval(inputs, out);
}
void Cos::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
const auto& in = inputs[0];
if (out.dtype() == float32 && in.flags().contiguous) {
int size = in.data_size();
out.set_data(
allocator::malloc_or_wait(size * out.itemsize()),
size,
in.strides(),
in.flags());
vvcosf(out.data<float>(), in.data<float>(), &size);
} else {
eval(inputs, out);
}
}
void Cosh::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
const auto& in = inputs[0];
if (out.dtype() == float32 && in.flags().contiguous) {
int size = in.data_size();
out.set_data(
allocator::malloc_or_wait(size * out.itemsize()),
size,
in.strides(),
in.flags());
vvcoshf(out.data<float>(), in.data<float>(), &size);
} else {
eval(inputs, out);
}
}
void Divide::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 2);
auto& a = inputs[0];
auto& b = inputs[1];
if (a.dtype() == int32) {
binary(
a,
b,
out,
[](auto x, auto y) { return x / y; },
UseDefaultBinaryOp(),
[](const auto* vec, const auto* s, auto* o, auto n) {
vDSP_vsdivi((const int*)vec, 1, (const int*)s, (int*)o, 1, n);
},
[](const auto* a, const auto* b, auto* o, auto n) {
vDSP_vdivi((const int*)b, 1, (const int*)a, 1, (int*)o, 1, n);
});
} else if (a.dtype() == float32) {
binary(
a,
b,
out,
[](auto x, auto y) { return x / y; },
[](const auto* s, const auto* vec, auto* o, auto n) {
vDSP_svdiv((const float*)s, (const float*)vec, 1, (float*)o, 1, n);
},
[](const auto* vec, const auto* s, auto* o, auto n) {
vDSP_vsdiv((const float*)vec, 1, (const float*)s, (float*)o, 1, n);
},
[](const auto* a, const auto* b, auto* o, auto n) {
vDSP_vdiv((const float*)b, 1, (const float*)a, 1, (float*)o, 1, n);
});
} else {
binary(a, b, out, [](auto x, auto y) { return x / y; });
}
}
void Exp::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
const auto& in = inputs[0];
if (out.dtype() == float32 && in.flags().contiguous) {
auto size = in.data_size();
out.set_data(
allocator::malloc_or_wait(size * out.itemsize()),
size,
in.strides(),
in.flags());
vvexpf(out.data<float>(), in.data<float>(), reinterpret_cast<int*>(&size));
} else if (is_floating_point(out.dtype())) {
unary_fp(in, out, [](auto x) { return std::exp(x); });
} else {
throw std::invalid_argument(
"[exp] Cannot exponentiate elements in array"
" with non floating point type.");
}
}
void Full::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
auto& in = inputs[0];
assert(in.dtype() == out.dtype());
if (in.data_size() == 1 && out.dtype() == float32) {
out.set_data(allocator::malloc_or_wait(out.nbytes()));
vDSP_vfill(in.data<float>(), out.data<float>(), 1, out.size());
} else {
eval(inputs, out);
}
}
void Log::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
const auto& in = inputs[0];
if (out.dtype() == float32 && in.flags().contiguous) {
auto size = in.data_size();
out.set_data(
allocator::malloc_or_wait(size * out.itemsize()),
size,
in.strides(),
in.flags());
switch (base_) {
case Base::e:
vvlogf(
out.data<float>(), in.data<float>(), reinterpret_cast<int*>(&size));
break;
case Base::two:
vvlog2f(
out.data<float>(), in.data<float>(), reinterpret_cast<int*>(&size));
break;
case Base::ten:
vvlog10f(
out.data<float>(), in.data<float>(), reinterpret_cast<int*>(&size));
break;
}
} else {
eval(inputs, out);
}
}
void Log1p::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
const auto& in = inputs[0];
if (out.dtype() == float32 && in.flags().contiguous) {
auto size = in.data_size();
out.set_data(
allocator::malloc_or_wait(size * out.itemsize()),
size,
in.strides(),
in.flags());
vvlog1pf(
out.data<float>(), in.data<float>(), reinterpret_cast<int*>(&size));
} else if (is_floating_point(out.dtype())) {
unary_fp(in, out, [](auto x) { return std::log1p(x); });
} else {
throw std::invalid_argument(
"[log1p] Cannot compute log of elements in array with"
" non floating point type.");
}
}
void Maximum::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 2);
auto& a = inputs[0];
auto& b = inputs[1];
if (out.dtype() == float32) {
binary(
a,
b,
out,
[](auto x, auto y) { return (x > y) ? x : y; },
UseDefaultBinaryOp(),
UseDefaultBinaryOp(),
[](const auto* a, const auto* b, auto* out, int n) {
vDSP_vmax((const float*)a, 1, (const float*)b, 1, (float*)out, 1, n);
});
} else {
binary(a, b, out, [](auto x, auto y) { return (x > y) ? x : y; });
}
}
void Minimum::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 2);
auto& a = inputs[0];
auto& b = inputs[1];
if (out.dtype() == float32) {
binary(
a,
b,
out,
[](auto x, auto y) { return (x < y) ? x : y; },
UseDefaultBinaryOp(),
UseDefaultBinaryOp(),
[](const auto* a, const auto* b, auto* out, int n) {
vDSP_vmin((const float*)a, 1, (const float*)b, 1, (float*)out, 1, n);
});
} else {
binary(a, b, out, [](auto x, auto y) { return (x < y) ? x : y; });
}
}
void Multiply::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 2);
auto& a = inputs[0];
auto& b = inputs[1];
if (a.dtype() == float32) {
binary(
a,
b,
out,
[](auto x, auto y) { return x * y; },
[](const auto* s, const auto* vec, auto* o, auto n) {
vDSP_vsmul((const float*)vec, 1, (const float*)s, (float*)o, 1, n);
},
[](const auto* vec, const auto* s, auto* o, auto n) {
vDSP_vsmul((const float*)vec, 1, (const float*)s, (float*)o, 1, n);
},
[](const auto* a, const auto* b, auto* o, auto n) {
vDSP_vmul((const float*)a, 1, (const float*)b, 1, (float*)o, 1, n);
});
} else {
binary(a, b, out, [](auto x, auto y) { return x * y; });
}
}
void Negative::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
auto& in = inputs[0];
if (in.dtype() == float32 && in.flags().contiguous) {
auto size = in.data_size();
out.set_data(
allocator::malloc_or_wait(size * out.itemsize()),
size,
in.strides(),
in.flags());
vDSP_vneg(in.data<float>(), 1, out.data<float>(), 1, size);
} else {
unary(in, out, [](auto x) { return -x; });
}
}
void Power::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 2);
auto& a = inputs[0];
auto& b = inputs[1];
if (out.dtype() == float32 && a.flags().row_contiguous &&
b.flags().row_contiguous) {
int size = a.size();
out.set_data(allocator::malloc_or_wait(out.nbytes()));
vvpowf(out.data<float>(), a.data<float>(), b.data<float>(), &size);
} else {
eval(inputs, out);
}
}
void Scan::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
const auto& in = inputs[0];
if (reduce_type_ == Scan::Sum && out.dtype() == float32 &&
in.flags().row_contiguous && in.strides()[axis_] == 1 && !inclusive_) {
out.set_data(allocator::malloc_or_wait(out.nbytes()));
int stride = in.shape(axis_);
int count = in.size() / stride;
const float* input = in.data<float>();
float* output = out.data<float>();
float s = 1.0;
if (!reverse_) {
for (int i = 0; i < count; i++) {
vDSP_vrsum(input - 1, 1, &s, output, 1, stride);
input += stride;
output += stride;
}
} else {
for (int i = 0; i < count; i++) {
input += stride - 1;
output += stride - 1;
vDSP_vrsum(input + 1, -1, &s, output, -1, stride);
input++;
output++;
}
}
} else {
eval(inputs, out);
}
}
void Sin::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
const auto& in = inputs[0];
if (out.dtype() == float32 && in.flags().contiguous) {
int size = in.data_size();
out.set_data(
allocator::malloc_or_wait(size * out.itemsize()),
size,
in.strides(),
in.flags());
vvsinf(out.data<float>(), in.data<float>(), &size);
} else {
eval(inputs, out);
}
}
void Sinh::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
const auto& in = inputs[0];
if (out.dtype() == float32 && in.flags().contiguous) {
int size = in.data_size();
out.set_data(
allocator::malloc_or_wait(size * out.itemsize()),
size,
in.strides(),
in.flags());
vvsinhf(out.data<float>(), in.data<float>(), &size);
} else {
eval(inputs, out);
}
}
void Square::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
auto& in = inputs[0];
if (in.dtype() == float32 && in.flags().contiguous) {
auto size = in.data_size();
out.set_data(
allocator::malloc_or_wait(size * out.itemsize()),
size,
in.strides(),
in.flags());
vDSP_vsq(in.data<float>(), 1, out.data<float>(), 1, size);
} else {
unary(in, out, [](auto x) { return x * x; });
}
}
void Sqrt::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
auto& in = inputs[0];
if (in.dtype() == float32 && in.flags().contiguous) {
int size = in.data_size();
out.set_data(
allocator::malloc_or_wait(size * out.itemsize()),
size,
in.strides(),
in.flags());
if (recip_) {
vvrsqrtf(out.data<float>(), in.data<float>(), &size);
} else {
vvsqrtf(out.data<float>(), in.data<float>(), &size);
}
} else {
eval(inputs, out);
}
}
void Subtract::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 2);
auto& a = inputs[0];
auto& b = inputs[1];
if (a.dtype() == float32) {
binary(
a,
b,
out,
[](auto x, auto y) { return x - y; },
[](const auto* s, const auto* vec, auto* o, auto n) {
float minus_1 = -1;
vDSP_vsmsa(
(const float*)vec, 1, &minus_1, (const float*)s, (float*)o, 1, n);
},
[](const auto* vec, const auto* s, auto* o, auto n) {
float val = -(*s);
vDSP_vsadd((const float*)vec, 1, &val, (float*)o, 1, n);
},
[](const auto* a, const auto* b, auto* o, auto n) {
vDSP_vsub((const float*)b, 1, (const float*)a, 1, (float*)o, 1, n);
});
} else if (a.dtype() == int32) {
binary(
a,
b,
out,
[](auto x, auto y) { return x - y; },
UseDefaultBinaryOp(),
[](const auto* vec, const auto* s, auto* o, auto n) {
int val = -(*s);
vDSP_vsaddi((const int*)vec, 1, &val, (int*)o, 1, n);
},
UseDefaultBinaryOp());
} else {
binary(a, b, out, [](auto x, auto y) { return x - y; });
}
}
void Tan::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
const auto& in = inputs[0];
if (out.dtype() == float32 && in.flags().contiguous) {
int size = in.data_size();
out.set_data(
allocator::malloc_or_wait(size * out.itemsize()),
size,
in.strides(),
in.flags());
vvtanf(out.data<float>(), in.data<float>(), &size);
} else {
eval(inputs, out);
}
}
void Tanh::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
const auto& in = inputs[0];
if (out.dtype() == float32 && in.flags().contiguous) {
int size = in.data_size();
out.set_data(
allocator::malloc_or_wait(size * out.itemsize()),
size,
in.strides(),
in.flags());
vvtanhf(out.data<float>(), in.data<float>(), &size);
} else {
eval(inputs, out);
}
}
} // namespace mlx::core
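The fast paths in this file all follow one shape: a contiguous float32 input goes through a vDSP or vForce call that takes raw pointers and an int element count, and everything else falls back to the generic eval. A standalone sketch of that vForce calling convention, using vvexpf as above on plain buffers (illustrative, not part of the commit):

#include <vecLib/vForce.h>
#include <vector>

std::vector<float> example_vvexpf(const std::vector<float>& x) {
  std::vector<float> y(x.size());
  int n = static_cast<int>(x.size());  // vForce takes the count as int*
  vvexpf(y.data(), x.data(), &n);      // y[i] = exp(x[i])
  return y;
}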


@@ -0,0 +1,147 @@
#include <cassert>
#include <simd/vector.h>
#include <vecLib/vDSP.h>
#include "mlx/backend/common/reduce.h"
#include "mlx/primitives.h"
namespace mlx::core {
template <typename T, typename VT, int N>
void _vectorized_strided_sum(const T* x, T* accum, int size, size_t stride) {
for (int i = 0; i < size; i++) {
size_t s = stride;
T* a = accum;
while (s >= N) {
VT val = (*(VT*)x);
*(VT*)a += val;
x += N;
a += N;
s -= N;
}
while (s-- > 0) {
*a++ += *x++;
}
}
}
// TODO: Add proper templates for the strided reduce algorithm so we don't have
// to write max/min/sum etc.
template <typename T, typename VT, int N>
void _vectorized_strided_max(const T* x, T* accum, int size, size_t stride) {
for (int i = 0; i < size; i++) {
size_t s = stride;
T* a = accum;
while (s >= N) {
*(VT*)a = simd_max((*(VT*)x), (*(VT*)a));
x += N;
a += N;
s -= N;
}
while (s-- > 0) {
*a = std::max(*a, *x);
a++;
x++;
}
}
}
template <typename T, typename VT, int N>
void _vectorized_strided_min(const T* x, T* accum, int size, size_t stride) {
for (int i = 0; i < size; i++) {
size_t s = stride;
T* a = accum;
while (s >= N) {
*(VT*)a = simd_min((*(VT*)x), (*(VT*)a));
x += N;
a += N;
s -= N;
}
while (s-- > 0) {
*a = std::min(*a, *x);
a++;
x++;
}
}
}
template <typename T, typename VT, int N>
void _vectorized_sum(const T* x, T* accum, int size) {
VT _sum = {0};
while (size >= N) {
_sum += (*(VT*)x);
x += N;
size -= N;
}
T sum = _sum[0];
for (int i = 1; i < N; i++) {
sum += _sum[i];
}
*accum += sum;
}
void Reduce::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
auto& in = inputs[0];
if (in.dtype() == float32) {
if (reduce_type_ == Reduce::Sum) {
reduction_op<float, float>(
in,
out,
axes_,
0,
[](const auto* x, auto* accum, int size, size_t stride) {
_vectorized_strided_sum<float, simd_float16, 16>(
(const float*)x, (float*)accum, size, stride);
},
[](const auto* x, auto* accum, int size) {
float acc;
vDSP_sve((const float*)x, 1, &acc, size);
(*accum) += acc;
},
[](auto* accum, auto x) { *accum += x; });
return;
} else if (reduce_type_ == Reduce::Max) {
reduction_op<float, float>(
in,
out,
axes_,
-std::numeric_limits<float>::infinity(),
[](const auto* x, auto* accum, int size, size_t stride) {
_vectorized_strided_max<float, simd_float16, 16>(
(const float*)x, (float*)accum, size, stride);
},
[](const auto* x, auto* accum, int size) {
float max;
vDSP_maxv((const float*)x, 1, &max, size);
(*accum) = (*accum < max) ? max : *accum;
},
[](auto* accum, auto x) { (*accum) = (*accum < x) ? x : *accum; });
return;
} else if (reduce_type_ == Reduce::Min) {
reduction_op<float, float>(
in,
out,
axes_,
std::numeric_limits<float>::infinity(),
[](const auto* x, auto* accum, int size, size_t stride) {
_vectorized_strided_min<float, simd_float16, 16>(
(const float*)x, (float*)accum, size, stride);
},
[](const auto* x, auto* accum, int size) {
float min;
vDSP_minv((const float*)x, 1, &min, size);
(*accum) = (*accum > min) ? min : *accum;
},
[](auto* accum, auto x) { (*accum) = (*accum > x) ? x : *accum; });
return;
}
}
// TODO: Add integer addition and min/max using the templates above and
// simd_int16 and friends.
eval(inputs, out);
}
} // namespace mlx::core
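A hedged usage sketch of the vectorized sum template above, assuming it is visible in the same translation unit: sixteen lanes are accumulated per iteration through simd_float16 and the scalar tail finishes the remainder.

#include <simd/vector.h>

float example_vectorized_sum(const float* x, int n) {
  float acc = 0.0f;
  // Adds the horizontal sum of x[0..n) into acc.
  mlx::core::_vectorized_sum<float, simd_float16, 16>(x, &acc, n);
  return acc;
}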


@@ -0,0 +1,18 @@
target_sources(
mlx
PRIVATE
${CMAKE_CURRENT_SOURCE_DIR}/arg_reduce.cpp
${CMAKE_CURRENT_SOURCE_DIR}/binary.cpp
${CMAKE_CURRENT_SOURCE_DIR}/conv.cpp
${CMAKE_CURRENT_SOURCE_DIR}/copy.cpp
${CMAKE_CURRENT_SOURCE_DIR}/erf.cpp
${CMAKE_CURRENT_SOURCE_DIR}/fft.cpp
${CMAKE_CURRENT_SOURCE_DIR}/primitives.cpp
${CMAKE_CURRENT_SOURCE_DIR}/reduce.cpp
${CMAKE_CURRENT_SOURCE_DIR}/scan.cpp
${CMAKE_CURRENT_SOURCE_DIR}/softmax.cpp
${CMAKE_CURRENT_SOURCE_DIR}/sort.cpp
${CMAKE_CURRENT_SOURCE_DIR}/threefry.cpp
${CMAKE_CURRENT_SOURCE_DIR}/indexing.cpp
${CMAKE_CURRENT_SOURCE_DIR}/load.cpp
)


@@ -0,0 +1,72 @@
#pragma once
#include "mlx/allocator.h"
#include "mlx/array.h"
namespace mlx::core {
namespace {
template <typename T>
void arange(T start, T next, array& out, size_t size) {
auto ptr = out.data<T>();
auto step_size = next - start;
for (int i = 0; i < size; ++i) {
ptr[i] = start;
start += step_size;
}
}
} // namespace
void arange(
const std::vector<array>& inputs,
array& out,
double start,
double step) {
assert(inputs.size() == 0);
out.set_data(allocator::malloc_or_wait(out.nbytes()));
switch (out.dtype()) {
case bool_:
throw std::runtime_error("Bool type unsupported for arange.");
break;
case uint8:
arange<uint8_t>(start, start + step, out, out.size());
break;
case uint16:
arange<uint16_t>(start, start + step, out, out.size());
break;
case uint32:
arange<uint32_t>(start, start + step, out, out.size());
break;
case uint64:
arange<uint64_t>(start, start + step, out, out.size());
break;
case int8:
arange<int8_t>(start, start + step, out, out.size());
break;
case int16:
arange<int16_t>(start, start + step, out, out.size());
break;
case int32:
arange<int32_t>(start, start + step, out, out.size());
break;
case int64:
arange<int64_t>(start, start + step, out, out.size());
break;
case float16:
arange<float16_t>(start, start + step, out, out.size());
break;
case float32:
arange<float>(start, start + step, out, out.size());
break;
case bfloat16:
arange<bfloat16_t>(start, start + step, out, out.size());
break;
case complex64:
arange<complex64_t>(start, start + step, out, out.size());
break;
}
}
} // namespace mlx::core
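For illustration only, the arithmetic the templated arange above performs, written out on a plain buffer: the public entry point passes next = start + step, so step_size reduces to step and element i holds start + i * step, with the accumulation carried out in the output dtype.

#include <vector>

std::vector<float> example_arange(float start, float step, size_t size) {
  std::vector<float> out(size);
  float v = start;
  for (size_t i = 0; i < size; ++i) {
    out[i] = v;  // e.g. start = 1.0, step = 0.5, size = 4 -> {1.0, 1.5, 2.0, 2.5}
    v += step;
  }
  return out;
}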


@@ -0,0 +1,110 @@
#include <cassert>
#include "mlx/primitives.h"
#include "utils.h"
namespace mlx::core {
namespace {
template <typename InT, typename OpT>
void arg_reduce(const array& in, array& out, const OpT& op, int axis) {
auto axis_size = in.shape()[axis];
auto axis_stride = in.strides()[axis];
std::vector<size_t> strides = in.strides();
std::vector<int> shape = in.shape();
strides.erase(strides.begin() + axis);
shape.erase(shape.begin() + axis);
for (uint32_t i = 0; i < out.size(); ++i) {
auto loc = elem_to_loc(i, shape, strides);
auto in_ptr = in.data<InT>() + loc;
uint32_t ind_v = 0;
InT v = (*in_ptr);
for (uint32_t j = 0; j < axis_size; ++j, in_ptr += axis_stride) {
op(j, (*in_ptr), &ind_v, &v);
}
out.data<uint32_t>()[i] = ind_v;
}
}
template <typename InT>
void arg_reduce_dispatch(
const array& in,
array& out,
ArgReduce::ReduceType rtype,
int axis) {
switch (rtype) {
case ArgReduce::ArgMin: {
auto op = [](auto ind_x, auto x, auto ind_y, auto y) {
if (x < (*y)) {
(*y) = x;
(*ind_y) = ind_x;
}
};
arg_reduce<InT>(in, out, op, axis);
break;
}
case ArgReduce::ArgMax: {
auto op = [](auto ind_x, auto x, auto ind_y, auto y) {
if (x > (*y)) {
(*y) = x;
(*ind_y) = ind_x;
}
};
arg_reduce<InT>(in, out, op, axis);
break;
}
}
}
} // namespace
void ArgReduce::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
auto& in = inputs[0];
out.set_data(allocator::malloc_or_wait(out.nbytes()));
switch (in.dtype()) {
case bool_:
arg_reduce_dispatch<bool>(in, out, reduce_type_, axis_);
break;
case uint8:
arg_reduce_dispatch<uint8_t>(in, out, reduce_type_, axis_);
break;
case uint16:
arg_reduce_dispatch<uint16_t>(in, out, reduce_type_, axis_);
break;
case uint32:
arg_reduce_dispatch<uint32_t>(in, out, reduce_type_, axis_);
break;
case uint64:
arg_reduce_dispatch<uint64_t>(in, out, reduce_type_, axis_);
break;
case int8:
arg_reduce_dispatch<int8_t>(in, out, reduce_type_, axis_);
break;
case int16:
arg_reduce_dispatch<int16_t>(in, out, reduce_type_, axis_);
break;
case int32:
arg_reduce_dispatch<int32_t>(in, out, reduce_type_, axis_);
break;
case int64:
arg_reduce_dispatch<int64_t>(in, out, reduce_type_, axis_);
break;
case float16:
arg_reduce_dispatch<float16_t>(in, out, reduce_type_, axis_);
break;
case float32:
arg_reduce_dispatch<float>(in, out, reduce_type_, axis_);
break;
case bfloat16:
arg_reduce_dispatch<bfloat16_t>(in, out, reduce_type_, axis_);
break;
case complex64:
arg_reduce_dispatch<complex64_t>(in, out, reduce_type_, axis_);
break;
}
}
} // namespace mlx::core
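A scalar sketch (not from the commit) of the comparator convention arg_reduce uses: op(candidate_index, candidate_value, &best_index, &best_value) updates the running best in place, and ArgMax keeps the first occurrence of the maximum because it only replaces on a strict greater-than.

#include <cstdint>

uint32_t example_argmax(const float* x, uint32_t n) {
  uint32_t best_idx = 0;
  float best_val = x[0];    // mirrors the initialization above
  for (uint32_t j = 0; j < n; ++j) {
    if (x[j] > best_val) {  // strict comparison, as in the ArgMax op
      best_val = x[j];
      best_idx = j;
    }
  }
  return best_idx;
}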

mlx/backend/common/conv.cpp (new file, 541 lines)

@@ -0,0 +1,541 @@
#include <cassert>
#ifdef ACCELERATE_NEW_LAPACK
#include <vecLib/cblas_new.h>
#else
#include <cblas.h>
#endif
#include "mlx/backend/common/copy.h"
#include "mlx/primitives.h"
#include "mlx/utils.h"
namespace mlx::core {
namespace {
///////////////////////////////////////////////////////////////////////////////
// Naive reference conv
///////////////////////////////////////////////////////////////////////////////
template <typename T>
void slow_conv_1D(
const array& in,
const array& wt,
array out,
const std::vector<int>& padding,
const std::vector<int>& wt_strides,
const std::vector<int>& wt_dilation) {
const T* start_wt_ptr = wt.data<T>();
const T* in_ptr = in.data<T>();
T* out_ptr = out.data<T>();
const int N = in.shape(0); // Batch size, should be the same as out.shape(0)
const int iH = in.shape(1); // Input spatial dim
const int oH = out.shape(1); // Output spatial dim
const int O = wt.shape(0); // Out channels
const int C = wt.shape(2); // In channels
const int wH = wt.shape(1); // Weight spatial dim
const size_t in_stride_N = in.strides()[0];
const size_t in_stride_H = in.strides()[1];
const size_t in_stride_C = in.strides()[2];
const size_t wt_stride_O = wt.strides()[0];
const size_t wt_stride_H = wt.strides()[1];
const size_t wt_stride_C = wt.strides()[2];
const size_t out_stride_N = out.strides()[0];
const size_t out_stride_H = out.strides()[1];
const size_t out_stride_O = out.strides()[2];
for (int n = 0; n < N; ++n) {
for (int oh = 0; oh < oH; ++oh) {
for (int o = 0; o < O; ++o) {
const T* filter_wt_ptr = start_wt_ptr + o * wt_stride_O;
float r = 0.;
for (int wh = 0; wh < wH; ++wh) {
const T* wt_ptr = filter_wt_ptr + wh * wt_stride_H;
int ih = oh * wt_strides[0] - padding[0] + wh * wt_dilation[0];
if (ih >= 0 && ih < iH) {
for (int c = 0; c < C; ++c) {
r += static_cast<float>(
in_ptr[ih * in_stride_H + c * in_stride_C]) *
static_cast<float>(wt_ptr[c * wt_stride_C]);
} // c
} // ih check
} // wh
out_ptr[oh * out_stride_H + o * out_stride_O] = static_cast<T>(r);
} // o
} // oh
in_ptr += in_stride_N;
out_ptr += out_stride_N;
} // n
}
template <typename T>
void slow_conv_2D(
const array& in,
const array& wt,
array out,
const std::vector<int>& padding,
const std::vector<int>& wt_strides,
const std::vector<int>& wt_dilation) {
const T* st_wt_ptr = wt.data<T>();
const T* st_in_ptr = in.data<T>();
T* st_out_ptr = out.data<T>();
const int N = in.shape(0); // Batch size, should be the same as out.shape(0)
const int iH = in.shape(1); // Input spatial dim
const int iW = in.shape(2); // Input spatial dim
const int oH = out.shape(1); // Output spatial dim
const int oW = out.shape(2); // Output spatial dim
const int O = wt.shape(0); // Out channels
const int C = wt.shape(3); // In channels
const int wH = wt.shape(1); // Weight spatial dim
const int wW = wt.shape(2); // Weight spatial dim
const size_t in_stride_N = in.strides()[0];
const size_t in_stride_H = in.strides()[1];
const size_t in_stride_W = in.strides()[2];
const size_t in_stride_C = in.strides()[3];
const size_t wt_stride_O = wt.strides()[0];
const size_t wt_stride_H = wt.strides()[1];
const size_t wt_stride_W = wt.strides()[2];
const size_t wt_stride_C = wt.strides()[3];
const size_t out_stride_N = out.strides()[0];
const size_t out_stride_H = out.strides()[1];
const size_t out_stride_W = out.strides()[2];
const size_t out_stride_O = out.strides()[3];
auto pt_conv_no_checks =
[&](const T* in_ptr, const T* wt_ptr, T* out_ptr, int oh, int ow) {
out_ptr += oh * out_stride_H + ow * out_stride_W;
int ih_base = oh * wt_strides[0] - padding[0];
int iw_base = ow * wt_strides[1] - padding[1];
for (int o = 0; o < O; ++o) {
float r = 0.;
for (int wh = 0; wh < wH; ++wh) {
for (int ww = 0; ww < wW; ++ww) {
int ih = ih_base + wh * wt_dilation[0];
int iw = iw_base + ww * wt_dilation[1];
const T* wt_ptr_pt = wt_ptr + wh * wt_stride_H + ww * wt_stride_W;
const T* in_ptr_pt = in_ptr + ih * in_stride_H + iw * in_stride_W;
for (int c = 0; c < C; ++c) {
r += static_cast<float>(in_ptr_pt[0]) *
static_cast<float>(wt_ptr_pt[0]);
in_ptr_pt += in_stride_C;
wt_ptr_pt += wt_stride_C;
} // c
} // ww
} // wh
out_ptr[0] = static_cast<T>(r);
out_ptr += out_stride_O;
wt_ptr += wt_stride_O;
} // o
};
auto pt_conv_all_checks =
[&](const T* in_ptr, const T* wt_ptr, T* out_ptr, int oh, int ow) {
out_ptr += oh * out_stride_H + ow * out_stride_W;
int ih_base = oh * wt_strides[0] - padding[0];
int iw_base = ow * wt_strides[1] - padding[1];
for (int o = 0; o < O; ++o) {
float r = 0.;
for (int wh = 0; wh < wH; ++wh) {
for (int ww = 0; ww < wW; ++ww) {
int ih = ih_base + wh * wt_dilation[0];
int iw = iw_base + ww * wt_dilation[1];
if (ih >= 0 && ih < iH && iw >= 0 && iw < iW) {
const T* wt_ptr_pt =
wt_ptr + wh * wt_stride_H + ww * wt_stride_W;
const T* in_ptr_pt =
in_ptr + ih * in_stride_H + iw * in_stride_W;
for (int c = 0; c < C; ++c) {
r += static_cast<float>(in_ptr_pt[0]) *
static_cast<float>(wt_ptr_pt[0]);
in_ptr_pt += in_stride_C;
wt_ptr_pt += wt_stride_C;
} // c
} // ih, iw check
} // ww
} // wh
out_ptr[0] = static_cast<T>(r);
out_ptr += out_stride_O;
wt_ptr += wt_stride_O;
} // o
};
int oH_border_0 = 0;
int oH_border_1 = (padding[0] + wt_strides[0] + 1) / wt_strides[0];
int oH_border_2 = (iH + padding[0] - wH * wt_dilation[0]) / wt_strides[0];
int oH_border_3 = oH;
int oW_border_0 = 0;
int oW_border_1 = (padding[1] + wt_strides[1] + 1) / wt_strides[1];
int oW_border_2 = (iW + padding[1] - wW * wt_dilation[1]) / wt_strides[1];
int oW_border_3 = oW;
for (int n = 0; n < N; ++n) {
// Case 1: oh might put us out of bounds
for (int oh = oH_border_0; oh < oH_border_1; ++oh) {
for (int ow = 0; ow < oW; ++ow) {
pt_conv_all_checks(st_in_ptr, st_wt_ptr, st_out_ptr, oh, ow);
} // ow
} // oh
// Case 2: oh in bounds
for (int oh = oH_border_1; oh < oH_border_2; ++oh) {
// Case a: ow might put us out of bounds
for (int ow = oW_border_0; ow < oW_border_1; ++ow) {
pt_conv_all_checks(st_in_ptr, st_wt_ptr, st_out_ptr, oh, ow);
} // ow
// Case b: ow in bounds
for (int ow = oW_border_1; ow < oW_border_2; ++ow) {
pt_conv_no_checks(st_in_ptr, st_wt_ptr, st_out_ptr, oh, ow);
} // ow
// Case c: ow might put us out of bounds
for (int ow = oW_border_2; ow < oW_border_3; ++ow) {
pt_conv_all_checks(st_in_ptr, st_wt_ptr, st_out_ptr, oh, ow);
} // ow
} // oh
// Case 3: oh might put us out of bounds
for (int oh = oH_border_2; oh < oH_border_3; ++oh) {
for (int ow = 0; ow < oW; ++ow) {
pt_conv_all_checks(st_in_ptr, st_wt_ptr, st_out_ptr, oh, ow);
} // ow
} // oh
st_in_ptr += in_stride_N;
st_out_ptr += out_stride_N;
} // n
}
void dispatch_slow_conv_1D(
const array& in,
const array& wt,
array out,
const std::vector<int>& padding,
const std::vector<int>& wt_strides,
const std::vector<int>& wt_dilation) {
if (in.dtype() == float32) {
return slow_conv_1D<float>(in, wt, out, padding, wt_strides, wt_dilation);
} else if (in.dtype() == float16) {
return slow_conv_1D<float16_t>(
in, wt, out, padding, wt_strides, wt_dilation);
} else if (in.dtype() == bfloat16) {
return slow_conv_1D<bfloat16_t>(
in, wt, out, padding, wt_strides, wt_dilation);
} else {
throw std::invalid_argument(
"[Convolution::eval] got unsupported data type.");
}
}
void dispatch_slow_conv_2D(
const array& in,
const array& wt,
array out,
const std::vector<int>& padding,
const std::vector<int>& wt_strides,
const std::vector<int>& wt_dilation) {
if (in.dtype() == float32) {
return slow_conv_2D<float>(in, wt, out, padding, wt_strides, wt_dilation);
} else if (in.dtype() == float16) {
return slow_conv_2D<float16_t>(
in, wt, out, padding, wt_strides, wt_dilation);
} else if (in.dtype() == bfloat16) {
return slow_conv_2D<bfloat16_t>(
in, wt, out, padding, wt_strides, wt_dilation);
} else {
throw std::invalid_argument(
"[Convolution::eval] got unsupported data type.");
}
}
///////////////////////////////////////////////////////////////////////////////
// Explicit gemm conv
///////////////////////////////////////////////////////////////////////////////
void explicit_gemm_conv_1D_cpu(
const array& in,
const array& wt,
array out,
const std::vector<int>& padding,
const std::vector<int>& wt_strides,
const std::vector<int>& wt_dilation) {
const int N = in.shape(0); // Batch size, should be the same as out.shape(0)
const int iH = in.shape(1); // Input spatial dim
const int oH = out.shape(1); // Output spatial dim
const int O = wt.shape(0); // Out channels
const int C = wt.shape(2); // In channels
const int wH = wt.shape(1); // Weight spatial dim
auto conv_dtype = float32;
// Pad input
std::vector<int> padded_shape = {N, iH + 2 * padding[0], C};
array in_padded(padded_shape, conv_dtype, nullptr, {});
// Fill with zeros
copy(array(0, conv_dtype), in_padded, CopyType::Scalar);
// Pick input slice from padded
size_t data_offset = padding[0] * in_padded.strides()[1];
array in_padded_slice(in.shape(), in_padded.dtype(), nullptr, {});
in_padded_slice.copy_shared_buffer(
in_padded,
in_padded.strides(),
in_padded.flags(),
in_padded_slice.size(),
data_offset);
// Copy input values into the slice
copy_inplace(in, in_padded_slice, CopyType::GeneralGeneral);
// Make strided view
std::vector<int> strided_shape = {N, oH, wH, C};
std::vector<size_t> strided_strides = {
in_padded.strides()[0],
in_padded.strides()[1] * wt_strides[0],
in_padded.strides()[1],
in_padded.strides()[2]};
auto flags = in_padded.flags();
array in_strided_view(strided_shape, in_padded.dtype(), nullptr, {});
in_strided_view.copy_shared_buffer(
in_padded, strided_strides, flags, in_strided_view.size(), 0);
// Materialize strided view
std::vector<int> strided_reshape = {N * oH, wH * C};
array in_strided(strided_reshape, in_strided_view.dtype(), nullptr, {});
copy(in_strided_view, in_strided, CopyType::General);
// Check wt dtype and prepare
auto gemm_wt = wt;
auto gemm_out = out;
if (wt.dtype() != float32 || !wt.flags().row_contiguous) {
auto ctype =
wt.flags().row_contiguous ? CopyType::Vector : CopyType::General;
gemm_wt = array(wt.shape(), float32, nullptr, {});
copy(wt, gemm_wt, ctype);
}
if (out.dtype() != float32) {
gemm_out = array(out.shape(), float32, nullptr, {});
gemm_out.set_data(allocator::malloc_or_wait(gemm_out.nbytes()));
}
// Perform gemm
cblas_sgemm(
CblasRowMajor,
CblasNoTrans, // no trans A
CblasTrans, // transB
strided_reshape[0], // M
O, // N
strided_reshape[1], // K
1.0f, // alpha
in_strided.data<float>(),
strided_reshape[1], // lda
gemm_wt.data<float>(),
strided_reshape[1], // ldb
0.0f, // beta
gemm_out.data<float>(),
O // ldc
);
// Copy results if needed
if (out.dtype() != float32) {
copy(gemm_out, out, CopyType::Vector);
}
}
void explicit_gemm_conv_2D_cpu(
const array& in,
const array& wt,
array out,
const std::vector<int>& padding,
const std::vector<int>& wt_strides,
const std::vector<int>& wt_dilation) {
const int N = in.shape(0); // Batch size, should be the same as out.shape(0)
const int iH = in.shape(1); // Input spatial dim
const int iW = in.shape(2); // Input spatial dim
const int oH = out.shape(1); // Output spatial dim
const int oW = out.shape(2); // Output spatial dim
const int O = wt.shape(0); // Out channels
const int C = wt.shape(3); // In channels
const int wH = wt.shape(1); // Weight spatial dim
const int wW = wt.shape(2); // Weight spatial dim
auto conv_dtype = out.dtype();
// Pad input
std::vector<int> padded_shape = {
N, iH + 2 * padding[0], iW + 2 * padding[1], C};
array in_padded(padded_shape, conv_dtype, nullptr, {});
// Fill with zeros
copy(array(0, conv_dtype), in_padded, CopyType::Scalar);
// Pick input slice from padded
size_t data_offset =
padding[0] * in_padded.strides()[1] + padding[1] * in_padded.strides()[2];
array in_padded_slice(in.shape(), in_padded.dtype(), nullptr, {});
in_padded_slice.copy_shared_buffer(
in_padded,
in_padded.strides(),
in_padded.flags(),
in_padded_slice.size(),
data_offset);
// Copy input values into the slice
copy_inplace(in, in_padded_slice, CopyType::GeneralGeneral);
// Make strided view
std::vector<int> strided_shape = {N, oH, oW, wH, wW, C};
std::vector<size_t> strided_strides = {
in_padded.strides()[0],
in_padded.strides()[1] * wt_strides[0],
in_padded.strides()[2] * wt_strides[1],
in_padded.strides()[1],
in_padded.strides()[2],
in_padded.strides()[3]};
auto flags = in_padded.flags();
array in_strided_view(strided_shape, in_padded.dtype(), nullptr, {});
in_strided_view.copy_shared_buffer(
in_padded, strided_strides, flags, in_strided_view.size(), 0);
// Materialize strided view
std::vector<int> strided_reshape = {N * oH * oW, wH * wW * C};
array in_strided(strided_reshape, in_strided_view.dtype(), nullptr, {});
copy(in_strided_view, in_strided, CopyType::General);
// Check wt dtype and prepare
auto gemm_wt = wt;
auto gemm_out = out;
if (wt.dtype() != float32 || !wt.flags().row_contiguous) {
auto ctype =
wt.flags().row_contiguous ? CopyType::Vector : CopyType::General;
gemm_wt = array(wt.shape(), float32, nullptr, {});
copy(wt, gemm_wt, ctype);
}
if (out.dtype() != float32) {
gemm_out = array(out.shape(), float32, nullptr, {});
gemm_out.set_data(allocator::malloc_or_wait(gemm_out.nbytes()));
}
// Perform gemm
cblas_sgemm(
CblasRowMajor,
CblasNoTrans, // no trans A
CblasTrans, // transB
strided_reshape[0], // M
O, // N
strided_reshape[1], // K
1.0f, // alpha
in_strided.data<float>(),
strided_reshape[1], // lda
gemm_wt.data<float>(),
strided_reshape[1], // ldb
0.0f, // beta
gemm_out.data<float>(),
O // ldc
);
// Copy results if needed
if (out.dtype() != float32) {
copy(gemm_out, out, CopyType::Vector);
}
}
///////////////////////////////////////////////////////////////////////////////
// Conv routing
///////////////////////////////////////////////////////////////////////////////
void conv_1D_cpu(
const array& in,
const array& wt,
array out,
const std::vector<int>& padding,
const std::vector<int>& wt_strides,
const std::vector<int>& wt_dilation) {
if (wt_dilation[0] == 1) {
return explicit_gemm_conv_1D_cpu(
in, wt, out, padding, wt_strides, wt_dilation);
}
return dispatch_slow_conv_1D(in, wt, out, padding, wt_strides, wt_dilation);
}
void conv_2D_cpu(
const array& in,
const array& wt,
array out,
const std::vector<int>& padding,
const std::vector<int>& wt_strides,
const std::vector<int>& wt_dilation) {
return dispatch_slow_conv_2D(in, wt, out, padding, wt_strides, wt_dilation);
}
} // namespace
void Convolution::eval(const std::vector<array>& inputs, array& out) {
out.set_data(allocator::malloc_or_wait(out.nbytes()));
auto& in = inputs[0];
auto& wt = inputs[1];
// 2D convolution
if (in.ndim() == (2 + 2)) {
return conv_2D_cpu(
in, wt, out, padding_, kernel_strides_, kernel_dilation_);
}
// 1D convolution
else if (in.ndim() == (1 + 2)) {
return conv_1D_cpu(
in, wt, out, padding_, kernel_strides_, kernel_dilation_);
}
// Throw error
else {
std::ostringstream msg;
msg << "[Convolution::eval] Convolution currently only supports"
<< " 1D and 2D convolutions. Got inputs with " << in.ndim() - 2
<< " spatial dimensions";
throw std::invalid_argument(msg.str());
}
}
} // namespace mlx::core
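A hedged sketch of the im2col step behind explicit_gemm_conv_1D_cpu, on plain buffers with stride 1, no padding, and no dilation: each output position gathers its wH x C window into one row, so the convolution becomes a single GEMM against the O x (wH * C) weight matrix. Names and layout here are illustrative.

#include <vector>

std::vector<float> example_im2col_1d(
    const std::vector<float>& in,  // shape (iH, C), row-major
    int iH,
    int C,
    int wH) {
  int oH = iH - wH + 1;  // stride 1, no padding
  std::vector<float> cols(static_cast<size_t>(oH) * wH * C);
  for (int oh = 0; oh < oH; ++oh) {
    for (int wh = 0; wh < wH; ++wh) {
      for (int c = 0; c < C; ++c) {
        cols[(static_cast<size_t>(oh) * wH + wh) * C + c] =
            in[static_cast<size_t>(oh + wh) * C + c];
      }
    }
  }
  return cols;  // shape (oH, wH * C), ready for a gemm with the weights
}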

mlx/backend/common/copy.cpp (new file, 308 lines)

@@ -0,0 +1,308 @@
#include <numeric>
#include "mlx/allocator.h"
#include "mlx/backend/common/copy.h"
namespace mlx::core {
namespace {
template <typename SrcT, typename DstT>
void copy_single(const array& src, array& dst) {
auto val = static_cast<DstT>(src.data<SrcT>()[0]);
auto dst_ptr = dst.data<DstT>();
for (int i = 0; i < dst.size(); ++i) {
dst_ptr[i] = val;
}
}
template <typename SrcT, typename DstT>
void copy_vector(const array& src, array& dst) {
auto src_ptr = src.data<SrcT>();
auto dst_ptr = dst.data<DstT>();
std::copy(src_ptr, src_ptr + src.data_size(), dst_ptr);
}
template <typename SrcT, typename DstT>
void copy_general_dim1(const array& src, array& dst) {
const SrcT* src_ptr = src.data<SrcT>();
DstT* dst_ptr = dst.data<DstT>();
size_t src_idx = 0;
size_t dst_idx = 0;
for (size_t i = 0; i < src.shape()[0]; ++i) {
dst_ptr[dst_idx++] = static_cast<DstT>(src_ptr[src_idx]);
src_idx += src.strides()[0];
}
}
template <typename SrcT, typename DstT>
void copy_general_dim2(const array& src, array& dst) {
const SrcT* src_ptr = src.data<SrcT>();
DstT* dst_ptr = dst.data<DstT>();
size_t src_idx = 0;
size_t dst_idx = 0;
for (size_t i = 0; i < src.shape()[0]; ++i) {
for (size_t j = 0; j < src.shape()[1]; ++j) {
dst_ptr[dst_idx++] = static_cast<DstT>(src_ptr[src_idx]);
src_idx += src.strides()[1];
}
src_idx += src.strides()[0] - src.strides()[1] * src.shape()[1];
}
}
template <typename SrcT, typename DstT>
void copy_general_dim3(const array& src, array& dst) {
const SrcT* src_ptr = src.data<SrcT>();
DstT* dst_ptr = dst.data<DstT>();
size_t src_idx = 0;
size_t dst_idx = 0;
for (size_t i = 0; i < src.shape()[0]; ++i) {
for (size_t j = 0; j < src.shape()[1]; ++j) {
for (size_t k = 0; k < src.shape()[2]; ++k) {
dst_ptr[dst_idx++] = static_cast<DstT>(src_ptr[src_idx]);
src_idx += src.strides()[2];
}
src_idx += src.strides()[1] - src.strides()[2] * src.shape()[2];
}
src_idx += src.strides()[0] - src.strides()[1] * src.shape()[1];
}
}
template <typename SrcT, typename DstT>
void copy_general_dim4(const array& src, array& dst) {
const SrcT* src_ptr = src.data<SrcT>();
DstT* dst_ptr = dst.data<DstT>();
size_t src_idx = 0;
size_t dst_idx = 0;
for (size_t i = 0; i < src.shape()[0]; ++i) {
for (size_t j = 0; j < src.shape()[1]; ++j) {
for (size_t k = 0; k < src.shape()[2]; ++k) {
for (size_t ii = 0; ii < src.shape()[3]; ++ii) {
dst_ptr[dst_idx++] = static_cast<DstT>(src_ptr[src_idx]);
src_idx += src.strides()[3];
}
src_idx += src.strides()[2] - src.strides()[3] * src.shape()[3];
}
src_idx += src.strides()[1] - src.strides()[2] * src.shape()[2];
}
src_idx += src.strides()[0] - src.strides()[1] * src.shape()[1];
}
}
template <typename SrcT, typename DstT>
void copy_general(const array& src, array& dst) {
switch (src.ndim()) {
case 1:
copy_general_dim1<SrcT, DstT>(src, dst);
return;
case 2:
copy_general_dim2<SrcT, DstT>(src, dst);
return;
case 3:
copy_general_dim3<SrcT, DstT>(src, dst);
return;
case 4:
copy_general_dim4<SrcT, DstT>(src, dst);
return;
}
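// Fallback for ndim > 4: map each flat destination index to its strided
// source location. For example, with shape {2, 3} and strides {1, 2}
// (column-major), flat index 4 unravels to (1, 1) and elem_to_loc returns
// 1 * 1 + 1 * 2 = 3, assuming elem_to_loc unravels the index in row-major order.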
auto src_ptr = src.data<SrcT>();
auto dst_ptr = dst.data<DstT>();
for (size_t i = 0; i < dst.size(); ++i) {
size_t src_elem = elem_to_loc(i, src.shape(), src.strides());
dst_ptr[i] = static_cast<DstT>(src_ptr[src_elem]);
}
}
template <typename SrcT, typename DstT, int D>
inline void copy_general_general_dims(
const array& src,
array& dst,
size_t offset_src,
size_t offset_dst) {
if constexpr (D > 1) {
int axis = src.ndim() - D;
auto stride_src = src.strides()[axis];
auto stride_dst = dst.strides()[axis];
auto N = src.shape(axis);
for (int i = 0; i < N; i++) {
copy_general_general_dims<SrcT, DstT, D - 1>(
src, dst, offset_src, offset_dst);
offset_src += stride_src;
offset_dst += stride_dst;
}
} else {
int axis = src.ndim() - 1;
auto stride_src = src.strides()[axis];
auto stride_dst = dst.strides()[axis];
auto N = src.shape(axis);
const SrcT* src_ptr = src.data<SrcT>() + offset_src;
DstT* dst_ptr = dst.data<DstT>() + offset_dst;
for (int i = 0; i < N; i++) {
*dst_ptr = static_cast<DstT>(*src_ptr);
src_ptr += stride_src;
dst_ptr += stride_dst;
}
}
}
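// The recursion above peels one axis per level: for D > 1 it walks axis
// (ndim - D) and accumulates the source/destination offsets, and the D == 1
// base case copies the innermost axis element by element at both strides.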
template <typename SrcT, typename DstT>
void copy_general_general(const array& src, array& dst) {
switch (src.ndim()) {
case 1:
copy_general_general_dims<SrcT, DstT, 1>(src, dst, 0, 0);
return;
case 2:
copy_general_general_dims<SrcT, DstT, 2>(src, dst, 0, 0);
return;
case 3:
copy_general_general_dims<SrcT, DstT, 3>(src, dst, 0, 0);
return;
case 4:
copy_general_general_dims<SrcT, DstT, 4>(src, dst, 0, 0);
return;
case 5:
copy_general_general_dims<SrcT, DstT, 5>(src, dst, 0, 0);
return;
}
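// For ndim > 5, handle the innermost five dimensions with the unrolled
// routine above and step through the remaining outer dimensions in chunks of
// `size` elements.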
int size = std::accumulate(
    src.shape().end() - 5, src.shape().end(), 1, std::multiplies<int>());
for (int i = 0; i < src.size(); i += size) {
size_t offset_src = elem_to_loc(i, src.shape(), src.strides());
size_t offset_dst = elem_to_loc(i, dst.shape(), dst.strides());
copy_general_general_dims<SrcT, DstT, 5>(src, dst, offset_src, offset_dst);
}
}
template <typename SrcT, typename DstT>
void copy(const array& src, array& dst, CopyType ctype) {
switch (ctype) {
case CopyType::Scalar:
copy_single<SrcT, DstT>(src, dst);
return;
case CopyType::Vector:
copy_vector<SrcT, DstT>(src, dst);
return;
case CopyType::General:
copy_general<SrcT, DstT>(src, dst);
return;
case CopyType::GeneralGeneral:
copy_general_general<SrcT, DstT>(src, dst);
}
}
template <typename SrcT>
void copy(const array& src, array& dst, CopyType ctype) {
switch (dst.dtype()) {
case bool_:
copy<SrcT, bool>(src, dst, ctype);
break;
case uint8:
copy<SrcT, uint8_t>(src, dst, ctype);
break;
case uint16:
copy<SrcT, uint16_t>(src, dst, ctype);
break;
case uint32:
copy<SrcT, uint32_t>(src, dst, ctype);
break;
case uint64:
copy<SrcT, uint64_t>(src, dst, ctype);
break;
case int8:
copy<SrcT, int8_t>(src, dst, ctype);
break;
case int16:
copy<SrcT, int16_t>(src, dst, ctype);
break;
case int32:
copy<SrcT, int32_t>(src, dst, ctype);
break;
case int64:
copy<SrcT, int64_t>(src, dst, ctype);
break;
case float16:
copy<SrcT, float16_t>(src, dst, ctype);
break;
case float32:
copy<SrcT, float>(src, dst, ctype);
break;
case bfloat16:
copy<SrcT, bfloat16_t>(src, dst, ctype);
break;
case complex64:
copy<SrcT, complex64_t>(src, dst, ctype);
break;
}
}
} // namespace
void copy_inplace(const array& src, array& dst, CopyType ctype) {
switch (src.dtype()) {
case bool_:
copy<bool>(src, dst, ctype);
break;
case uint8:
copy<uint8_t>(src, dst, ctype);
break;
case uint16:
copy<uint16_t>(src, dst, ctype);
break;
case uint32:
copy<uint32_t>(src, dst, ctype);
break;
case uint64:
copy<uint64_t>(src, dst, ctype);
break;
case int8:
copy<int8_t>(src, dst, ctype);
break;
case int16:
copy<int16_t>(src, dst, ctype);
break;
case int32:
copy<int32_t>(src, dst, ctype);
break;
case int64:
copy<int64_t>(src, dst, ctype);
break;
case float16:
copy<float16_t>(src, dst, ctype);
break;
case float32:
copy<float>(src, dst, ctype);
break;
case bfloat16:
copy<bfloat16_t>(src, dst, ctype);
break;
case complex64:
copy<complex64_t>(src, dst, ctype);
break;
}
}
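// A rough summary of the CopyType cases handled below: Scalar broadcasts a
// single value, Vector copies a contiguous buffer with the same layout,
// General reads a strided source into a contiguous destination, and
// GeneralGeneral allows both source and destination to be strided.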
void copy(const array& src, array& dst, CopyType ctype) {
// Allocate the output
switch (ctype) {
case CopyType::Vector:
dst.set_data(
allocator::malloc_or_wait(src.data_size() * dst.itemsize()),
src.data_size(),
src.strides(),
src.flags());
break;
case CopyType::Scalar:
case CopyType::General:
case CopyType::GeneralGeneral:
dst.set_data(allocator::malloc_or_wait(dst.nbytes()));
break;
}
if (ctype == CopyType::GeneralGeneral) {
ctype = CopyType::General;
}
copy_inplace(src, dst, ctype);
}
} // namespace mlx::core

View File

@@ -0,0 +1,130 @@
#include <cblas.h>
#include "mlx/array.h"
#include "mlx/backend/common/copy.h"
#include "mlx/backend/common/utils.h"
#include "mlx/primitives.h"
#define DEFAULT(primitive) \
void primitive::eval_cpu(const std::vector<array>& inputs, array& out) { \
primitive::eval(inputs, out); \
}
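// DEFAULT(Op) defines Op::eval_cpu as a thin wrapper that forwards the CPU
// evaluation to the primitive's generic eval implementation.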
namespace mlx::core {
DEFAULT(Abs)
DEFAULT(Add)
DEFAULT(Arange)
DEFAULT(ArcCos)
DEFAULT(ArcCosh)
DEFAULT(ArcSin)
DEFAULT(ArcSinh)
DEFAULT(ArcTan)
DEFAULT(ArcTanh)
DEFAULT(ArgPartition)
DEFAULT(ArgReduce)
DEFAULT(ArgSort)
DEFAULT(AsType)
DEFAULT(AsStrided)
DEFAULT(Broadcast)
DEFAULT(Concatenate)
DEFAULT(Convolution)
DEFAULT(Copy)
DEFAULT(Cos)
DEFAULT(Cosh)
DEFAULT(Divide)
DEFAULT(Equal)
DEFAULT(Erf)
DEFAULT(ErfInv)
DEFAULT(Exp)
DEFAULT(FFT)
DEFAULT(Full)
DEFAULT(Gather)
DEFAULT(Greater)
DEFAULT(GreaterEqual)
DEFAULT(Less)
DEFAULT(LessEqual)
DEFAULT(Load)
DEFAULT(Log)
DEFAULT(Log1p)
DEFAULT(LogicalNot)
DEFAULT(LogAddExp)
DEFAULT(Maximum)
DEFAULT(Minimum)
DEFAULT(Multiply)
DEFAULT(Negative)
DEFAULT(NotEqual)
DEFAULT(Pad)
DEFAULT(Partition)
DEFAULT(Power)
DEFAULT(RandomBits)
DEFAULT(Reduce)
DEFAULT(Reshape)
DEFAULT(Scan)
DEFAULT(Scatter)
DEFAULT(Sigmoid)
DEFAULT(Sign)
DEFAULT(Sin)
DEFAULT(Sinh)
DEFAULT(Slice)
DEFAULT(Softmax)
DEFAULT(Sort)
DEFAULT(Square)
DEFAULT(Sqrt)
DEFAULT(StopGradient)
DEFAULT(Subtract)
DEFAULT(Tan)
DEFAULT(Tanh)
DEFAULT(Transpose)
void Matmul::eval_cpu(const std::vector<array>& inputs, array& out) {
if (out.dtype() != float32) {
throw std::runtime_error(
"[Matmul::eval_cpu] Currently only supports float32.");
}
out.set_data(allocator::malloc_or_wait(out.nbytes()));
auto& a_pre = inputs[0];
auto& b_pre = inputs[1];
auto check_transpose = [](const array& arr) {
auto stx = arr.strides()[arr.ndim() - 2];
auto sty = arr.strides()[arr.ndim() - 1];
if (stx == arr.shape(-1) && sty == 1) {
return std::make_tuple(false, stx, arr);
} else if (stx == 1 && sty == arr.shape(-2)) {
return std::make_tuple(true, sty, arr);
} else {
array arr_copy(arr.shape(), arr.dtype(), nullptr, {});
copy(arr, arr_copy, CopyType::General);
size_t stx = arr.shape(-1);
return std::make_tuple(false, stx, arr_copy);
}
};
auto [a_transposed, lda, a] = check_transpose(a_pre);
auto [b_transposed, ldb, b] = check_transpose(b_pre);
int M = a.shape(-2);
int N = b.shape(-1);
int K = a.shape(-1);
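// Loop over the leading (batch) dimensions, issuing one sgemm per M x K by
// K x N matrix pair; elem_to_loc locates each batch's starting element even
// when the inputs carry broadcasted or transposed strides.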
for (int i = 0; i < (a.size() / (M * K)); ++i) {
cblas_sgemm(
CblasRowMajor,
a_transposed ? CblasTrans : CblasNoTrans, // transA
b_transposed ? CblasTrans : CblasNoTrans, // transB
M,
N,
K,
1.0f, // alpha
a.data<float>() + elem_to_loc(M * K * i, a.shape(), a.strides()),
lda,
b.data<float>() + elem_to_loc(K * N * i, b.shape(), b.strides()),
ldb,
0.0f, // beta
out.data<float>() + M * N * i,
out.shape(-1) // ldc
);
}
}
} // namespace mlx::core

View File

@@ -0,0 +1,38 @@
#include <cmath>
namespace mlx::core {
/* Approximation to the inverse error function.
* Based on code from:
* https://stackoverflow.com/questions/27229371/inverse-error-function-in-c#answer-49743348
*/
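// Roughly: t = log(1 - a * a), then a polynomial in t is evaluated, with one
// set of coefficients for |t| > 6.125 (|a| near 1) and another for the central
// range; the quoted maximum ulp errors come from the referenced post.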
float erfinv(float a) {
auto t = std::fma(a, 0.0f - a, 1.0f);
t = std::log(t);
float p;
if (std::abs(t) > 6.125f) { // maximum ulp error = 2.35793
p = 3.03697567e-10f; // 0x1.4deb44p-32
p = std::fma(p, t, 2.93243101e-8f); // 0x1.f7c9aep-26
p = std::fma(p, t, 1.22150334e-6f); // 0x1.47e512p-20
p = std::fma(p, t, 2.84108955e-5f); // 0x1.dca7dep-16
p = std::fma(p, t, 3.93552968e-4f); // 0x1.9cab92p-12
p = std::fma(p, t, 3.02698812e-3f); // 0x1.8cc0dep-9
p = std::fma(p, t, 4.83185798e-3f); // 0x1.3ca920p-8
p = std::fma(p, t, -2.64646143e-1f); // -0x1.0eff66p-2
p = std::fma(p, t, 8.40016484e-1f); // 0x1.ae16a4p-1
} else { // maximum ulp error = 2.35002
p = 5.43877832e-9f; // 0x1.75c000p-28
p = std::fma(p, t, 1.43285448e-7f); // 0x1.33b402p-23
p = std::fma(p, t, 1.22774793e-6f); // 0x1.499232p-20
p = std::fma(p, t, 1.12963626e-7f); // 0x1.e52cd2p-24
p = std::fma(p, t, -5.61530760e-5f); // -0x1.d70bd0p-15
p = std::fma(p, t, -1.47697632e-4f); // -0x1.35be90p-13
p = std::fma(p, t, 2.31468678e-3f); // 0x1.2f6400p-9
p = std::fma(p, t, 1.15392581e-2f); // 0x1.7a1e50p-7
p = std::fma(p, t, -2.32015476e-1f); // -0x1.db2aeep-3
p = std::fma(p, t, 8.86226892e-1f); // 0x1.c5bf88p-1
}
return a * p;
}
} // namespace mlx::core

10
mlx/backend/common/erf.h Normal file
View File

@@ -0,0 +1,10 @@
namespace mlx::core {
/* Approximation to the inverse error function.
* Based on code from:
* https://stackoverflow.com/questions/27229371/inverse-error-function-in-c#answer-49743348
*/
float erfinv(float a);
} // namespace mlx::core

View File

@@ -0,0 +1,377 @@
#include <algorithm>
#include <cassert>
#include <cmath>
#include "mlx/allocator.h"
#include "mlx/primitives.h"
#include "mlx/backend/common/copy.h"
#include "mlx/backend/common/utils.h"
namespace mlx::core {
template <typename IdxT>
inline size_t offset_neg_idx(IdxT idx, size_t size) {
return (idx < 0) ? idx + size : idx;
}
template <>
inline size_t offset_neg_idx(bool idx, size_t) {
return idx;
}
template <>
inline size_t offset_neg_idx(uint32_t idx, size_t) {
return idx;
}
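// e.g. offset_neg_idx(-1, 10) == 9; the bool and uint32_t overloads skip the
// wrap-around since those index types cannot be negative.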
template <typename T, typename IdxT>
void gather(
const array& src,
const std::vector<array>& inds,
array& out,
const std::vector<int>& axes,
const std::vector<int>& slice_sizes) {
// If the array is row contiguous then we can do a contiguous copy given
// two conditions on the slice size:
// - Any number of leading ones in the slice sizes are allowed
// - All other slice sizes match the corresponding dimension except the
// first non-singleton slice size
// If the array is col contiguous then the reverse is the case:
// - Any number of trailing ones in the slice sizes are allowed
// - All other slice sizes match the corresponding dimension except the
// first non-singleton slice size from the end
bool can_copy = false;
if (src.flags().row_contiguous) {
can_copy = true;
// Ignore leading 1s
int i = 0;
for (; i < slice_sizes.size() && slice_sizes[i] == 1; ++i)
;
// Check the remaining
i++;
for (; i < src.ndim() && can_copy; ++i) {
can_copy = (src.shape(i) == slice_sizes[i]);
}
} else if (src.flags().col_contiguous) {
can_copy = true;
// Ignore trailing 1s
int i = slice_sizes.size() - 1;
for (; i >= 0 && slice_sizes[i] == 1; --i)
;
// Skip the next slice size and check the remaining
i--;
for (; i >= 0 && can_copy; --i) {
can_copy = (src.shape(i) == slice_sizes[i]);
}
}
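// Example: a row contiguous source of shape (4, 5, 6) with slice_sizes
// (1, 5, 6) sets can_copy, so each gathered slice is a single contiguous
// block of 30 elements.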
size_t slice_size = 1;
for (auto s : slice_sizes) {
slice_size *= s;
}
size_t ind_size = slice_size == 0 ? 0 : out.size() / slice_size;
const T* src_ptr = src.data<T>();
T* dst_ptr = out.data<T>();
size_t out_idx = 0;
for (int idx = 0; idx < ind_size; idx++) {
size_t src_idx = 0;
for (int ii = 0; ii < inds.size(); ++ii) {
auto ax = axes[ii];
auto idx_loc = elem_to_loc(idx, inds[ii]);
auto idx_val =
offset_neg_idx(inds[ii].data<IdxT>()[idx_loc], src.shape(ax));
src_idx += (idx_val * src.strides()[ax]);
}
if (slice_size == 1) {
dst_ptr[out_idx++] = src_ptr[src_idx];
} else if (can_copy) {
std::copy(
src_ptr + src_idx, src_ptr + src_idx + slice_size, dst_ptr + out_idx);
out_idx += slice_size;
} else {
for (int jj = 0; jj < slice_size; jj++) {
auto src_offset = elem_to_loc(jj, slice_sizes, src.strides());
dst_ptr[out_idx++] = src_ptr[src_idx + src_offset];
}
}
}
}
template <typename IdxT>
void dispatch_gather(
const array& src,
const std::vector<array>& inds,
array& out,
const std::vector<int>& axes,
const std::vector<int>& size) {
switch (out.dtype()) {
case bool_:
gather<bool, IdxT>(src, inds, out, axes, size);
break;
case uint8:
gather<uint8_t, IdxT>(src, inds, out, axes, size);
break;
case uint16:
gather<uint16_t, IdxT>(src, inds, out, axes, size);
break;
case uint32:
gather<uint32_t, IdxT>(src, inds, out, axes, size);
break;
case uint64:
gather<uint64_t, IdxT>(src, inds, out, axes, size);
break;
case int8:
gather<int8_t, IdxT>(src, inds, out, axes, size);
break;
case int16:
gather<int16_t, IdxT>(src, inds, out, axes, size);
break;
case int32:
gather<int32_t, IdxT>(src, inds, out, axes, size);
break;
case int64:
gather<int64_t, IdxT>(src, inds, out, axes, size);
break;
case float16:
gather<float16_t, IdxT>(src, inds, out, axes, size);
break;
case float32:
gather<float, IdxT>(src, inds, out, axes, size);
break;
case bfloat16:
gather<bfloat16_t, IdxT>(src, inds, out, axes, size);
break;
case complex64:
gather<complex64_t, IdxT>(src, inds, out, axes, size);
break;
}
}
void Gather::eval(const std::vector<array>& inputs, array& out) {
out.set_data(allocator::malloc_or_wait(out.nbytes()));
auto& src = inputs[0];
std::vector<array> inds(inputs.begin() + 1, inputs.end());
if (inds.empty()) {
dispatch_gather<bool>(src, inds, out, axes_, slice_sizes_);
return;
}
switch (inds[0].dtype()) {
case bool_:
dispatch_gather<bool>(src, inds, out, axes_, slice_sizes_);
break;
case uint8:
dispatch_gather<uint8_t>(src, inds, out, axes_, slice_sizes_);
break;
case uint16:
dispatch_gather<uint16_t>(src, inds, out, axes_, slice_sizes_);
break;
case uint32:
dispatch_gather<uint32_t>(src, inds, out, axes_, slice_sizes_);
break;
case uint64:
dispatch_gather<uint64_t>(src, inds, out, axes_, slice_sizes_);
break;
case int8:
dispatch_gather<int8_t>(src, inds, out, axes_, slice_sizes_);
break;
case int16:
dispatch_gather<int16_t>(src, inds, out, axes_, slice_sizes_);
break;
case int32:
dispatch_gather<int32_t>(src, inds, out, axes_, slice_sizes_);
break;
case int64:
dispatch_gather<int64_t>(src, inds, out, axes_, slice_sizes_);
break;
case float16:
case float32:
case bfloat16:
case complex64:
throw std::runtime_error(
"[Gather::eval] Cannot gather with floating point indices.");
break;
}
}
template <typename InT, typename IdxT, typename OpT>
void scatter(
const array& updates,
array& out,
const std::vector<array>& inds,
const std::vector<int>& axes,
const OpT& op) {
int nind = inds.size();
auto inds_ndim = updates.ndim() - out.ndim();
size_t n_updates = nind ? inds[0].size() : 1;
std::vector<int> update_shape(
updates.shape().begin() + inds_ndim, updates.shape().end());
size_t update_size = 1;
for (auto us : update_shape) {
update_size *= us;
}
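// A plausible layout (illustrative shapes): with out of shape (4, 4) and one
// index array of shape (3,), updates of shape (3, 1, 4) gives inds_ndim = 1,
// update_shape = {1, 4}, and update_size = 4.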
for (int i = 0; i < n_updates; ++i) {
size_t out_offset = 0;
for (int j = 0; j < nind; ++j) {
auto ax = axes[j];
auto idx_loc = elem_to_loc(i, inds[j]);
auto idx_val =
offset_neg_idx(inds[j].data<IdxT>()[idx_loc], out.shape(ax));
out_offset += (idx_val * out.strides()[ax]);
}
for (int j = 0; j < update_size; ++j) {
auto update_loc = elem_to_loc(i * update_size + j, updates);
auto out_loc = elem_to_loc(j, update_shape, out.strides());
op(updates.data<InT>()[update_loc],
out.data<InT>() + out_offset + out_loc);
}
}
}
template <typename InT, typename IdxT>
void dispatch_scatter_inds(
array& out,
const std::vector<array>& indices,
const array& updates,
const std::vector<int>& axes,
Scatter::ReduceType rtype) {
switch (rtype) {
case Scatter::None:
scatter<InT, IdxT>(
updates, out, indices, axes, [](auto x, auto* y) { (*y) = x; });
break;
case Scatter::Sum:
scatter<InT, IdxT>(
updates, out, indices, axes, [](auto x, auto* y) { (*y) += x; });
break;
case Scatter::Prod:
scatter<InT, IdxT>(
updates, out, indices, axes, [](auto x, auto* y) { (*y) *= x; });
break;
case Scatter::Max:
scatter<InT, IdxT>(updates, out, indices, axes, [](auto x, auto* y) {
(*y) = (*y > x) ? *y : x;
});
break;
case Scatter::Min:
scatter<InT, IdxT>(updates, out, indices, axes, [](auto x, auto* y) {
(*y) = (*y < x) ? *y : x;
});
break;
}
}
template <typename InT>
void dispatch_scatter(
array& out,
const std::vector<array>& inds,
const array& updates,
const std::vector<int>& axes,
Scatter::ReduceType rtype) {
if (inds.empty()) {
dispatch_scatter_inds<InT, bool>(out, inds, updates, axes, rtype);
return;
}
switch (inds[0].dtype()) {
case bool_:
dispatch_scatter_inds<InT, bool>(out, inds, updates, axes, rtype);
break;
case uint8:
dispatch_scatter_inds<InT, uint8_t>(out, inds, updates, axes, rtype);
break;
case uint16:
dispatch_scatter_inds<InT, uint16_t>(out, inds, updates, axes, rtype);
break;
case uint32:
dispatch_scatter_inds<InT, uint32_t>(out, inds, updates, axes, rtype);
break;
case uint64:
dispatch_scatter_inds<InT, uint64_t>(out, inds, updates, axes, rtype);
break;
case int8:
dispatch_scatter_inds<InT, int8_t>(out, inds, updates, axes, rtype);
break;
case int16:
dispatch_scatter_inds<InT, int16_t>(out, inds, updates, axes, rtype);
break;
case int32:
dispatch_scatter_inds<InT, int32_t>(out, inds, updates, axes, rtype);
break;
case int64:
dispatch_scatter_inds<InT, int64_t>(out, inds, updates, axes, rtype);
break;
case float16:
case float32:
case bfloat16:
case complex64:
throw std::runtime_error(
"[Scatter::eval_cpu] Cannot scatter with floating point indices.");
}
}
void Scatter::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() >= 2);
auto& src = inputs[0];
std::vector<array> inds(inputs.begin() + 1, inputs.end() - 1);
auto& updates = inputs.back();
// Copy src into out (copy allocates memory for out)
copy(src, out, CopyType::General);
switch (src.dtype()) {
case bool_:
dispatch_scatter<bool>(out, inds, updates, axes_, reduce_type_);
break;
case uint8:
dispatch_scatter<uint8_t>(out, inds, updates, axes_, reduce_type_);
break;
case uint16:
dispatch_scatter<uint16_t>(out, inds, updates, axes_, reduce_type_);
break;
case uint32:
dispatch_scatter<uint32_t>(out, inds, updates, axes_, reduce_type_);
break;
case uint64:
dispatch_scatter<uint64_t>(out, inds, updates, axes_, reduce_type_);
break;
case int8:
dispatch_scatter<int8_t>(out, inds, updates, axes_, reduce_type_);
break;
case int16:
dispatch_scatter<int16_t>(out, inds, updates, axes_, reduce_type_);
break;
case int32:
dispatch_scatter<int32_t>(out, inds, updates, axes_, reduce_type_);
break;
case int64:
dispatch_scatter<int64_t>(out, inds, updates, axes_, reduce_type_);
break;
case float16:
dispatch_scatter<float16_t>(out, inds, updates, axes_, reduce_type_);
break;
case float32:
dispatch_scatter<float>(out, inds, updates, axes_, reduce_type_);
break;
case bfloat16:
dispatch_scatter<bfloat16_t>(out, inds, updates, axes_, reduce_type_);
break;
case complex64:
dispatch_scatter<complex64_t>(out, inds, updates, axes_, reduce_type_);
break;
}
}
} // namespace mlx::core

View File

@@ -0,0 +1,52 @@
#include <algorithm>
#include <cassert>
#include <utility>
#include "mlx/allocator.h"
#include "mlx/load.h"
#include "mlx/primitives.h"
namespace mlx::core {
namespace {
template <const uint8_t scalar_size>
void swap_endianess(uint8_t* data_bytes, size_t N) {
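// Reverse the byte order of each of the N fixed-size elements in place,
// e.g. a 4-byte element 01 02 03 04 becomes 04 03 02 01.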
struct Elem {
uint8_t bytes[scalar_size];
};
Elem* data = reinterpret_cast<Elem*>(data_bytes);
for (size_t i = 0; i < N; i++) {
for (size_t j = 0; j < (scalar_size / 2); j++) {
std::swap(data[i].bytes[j], data[i].bytes[scalar_size - j - 1]);
}
}
}
} // namespace
void Load::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 0);
out.set_data(allocator::malloc_or_wait(out.nbytes()));
reader_->seek(offset_, std::ios_base::beg);
reader_->read(out.data<char>(), out.nbytes());
if (swap_endianness_) {
switch (out.itemsize()) {
case 2:
swap_endianess<2>(out.data<uint8_t>(), out.data_size());
break;
case 4:
swap_endianess<4>(out.data<uint8_t>(), out.data_size());
break;
case 8:
swap_endianess<8>(out.data<uint8_t>(), out.data_size());
break;
}
}
}
} // namespace mlx::core

View File

@@ -0,0 +1,622 @@
#include <algorithm>
#include <cassert>
#include <cmath>
#include <numeric>
#include <sstream>
#include "mlx/allocator.h"
#include "mlx/backend/common/arange.h"
#include "mlx/backend/common/copy.h"
#include "mlx/backend/common/erf.h"
#include "mlx/backend/common/threefry.h"
#include "mlx/backend/common/unary.h"
#include "mlx/backend/common/utils.h"
#include "mlx/primitives.h"
#include "mlx/utils.h"
namespace mlx::core {
void Abs::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
auto& in = inputs[0];
if (is_unsigned(in.dtype())) {
// No-op for unsigned types
out.copy_shared_buffer(in);
} else {
unary(in, out, AbsOp());
}
}
void Arange::eval(const std::vector<array>& inputs, array& out) {
arange(inputs, out, start_, step_);
}
void ArcCos::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
const auto& in = inputs[0];
if (is_floating_point(out.dtype())) {
unary_fp(in, out, [](auto x) { return std::acos(x); });
} else {
throw std::invalid_argument(
"[arccos] Cannot compute inverse cosine of elements in array"
" with non floating point type.");
}
}
void ArcCosh::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
const auto& in = inputs[0];
if (is_floating_point(out.dtype())) {
unary_fp(in, out, [](auto x) { return std::acosh(x); });
} else {
throw std::invalid_argument(
"[arccosh] Cannot compute inverse hyperbolic cosine of elements in"
" array with non floating point type.");
}
}
void ArcSin::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
const auto& in = inputs[0];
if (is_floating_point(out.dtype())) {
unary_fp(in, out, [](auto x) { return std::asin(x); });
} else {
throw std::invalid_argument(
"[arcsin] Cannot compute inverse sine of elements in array"
" with non floating point type.");
}
}
void ArcSinh::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
const auto& in = inputs[0];
if (is_floating_point(out.dtype())) {
unary_fp(in, out, [](auto x) { return std::asinh(x); });
} else {
throw std::invalid_argument(
"[arcsinh] Cannot compute inverse hyperbolic sine of elements in"
" array with non floating point type.");
}
}
void ArcTan::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
const auto& in = inputs[0];
if (is_floating_point(out.dtype())) {
unary_fp(in, out, [](auto x) { return std::atan(x); });
} else {
throw std::invalid_argument(
"[arctan] Cannot compute inverse tangent of elements in array"
" with non floating point type.");
}
}
void ArcTanh::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
const auto& in = inputs[0];
if (is_floating_point(out.dtype())) {
unary_fp(in, out, [](auto x) { return std::atanh(x); });
} else {
throw std::invalid_argument(
"[arctanh] Cannot compute inverse hyperbolic tangent of elements in"
" array with non floating point type.");
}
}
void AsType::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
auto& in = inputs[0];
CopyType ctype = in.flags().contiguous ? CopyType::Vector : CopyType::General;
copy(in, out, ctype);
}
void AsStrided::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
auto& in = inputs[0];
if (!in.flags().row_contiguous) {
// inputs[0] should come from ops that guarantee a row contiguous input;
// check that assumption here.
throw std::runtime_error(
"AsStrided must be used with row contiguous arrays only.");
}
// Compute the flags given the shape and strides
bool row_contiguous = true, col_contiguous = true;
size_t r = 1, c = 1;
for (int i = strides_.size() - 1, j = 0; i >= 0; i--, j++) {
row_contiguous &= (r == strides_[i]) || (shape_[i] == 1);
col_contiguous &= (c == strides_[j]) || (shape_[j] == 1);
r *= shape_[i];
c *= shape_[j];
}
auto flags = in.flags();
// TODO: Compute the contiguous flag in a better way because now we are
// unnecessarily strict.
flags.contiguous = row_contiguous || col_contiguous;
flags.row_contiguous = row_contiguous;
flags.col_contiguous = col_contiguous;
// There is no easy way to compute the actual data size so we use out.size().
// The contiguous flag will almost certainly not be set so no code should
// rely on data_size anyway.
size_t data_size = out.size();
return out.copy_shared_buffer(in, strides_, flags, data_size, offset_);
}
void Broadcast::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
const auto& in = inputs[0];
if (out.size() == 0) {
out.set_data(nullptr);
return;
}
std::vector<size_t> strides(out.ndim(), 0);
int diff = out.ndim() - in.ndim();
for (int i = in.ndim() - 1; i >= 0; --i) {
strides[i + diff] = (in.shape()[i] == 1) ? 0 : in.strides()[i];
}
auto flags = in.flags();
if (out.size() > in.size()) {
flags.row_contiguous = flags.col_contiguous = false;
}
out.copy_shared_buffer(in, strides, flags, in.data_size());
}
void Concatenate::eval(const std::vector<array>& inputs, array& out) {
std::vector<int> sizes;
sizes.push_back(0);
for (auto& p : inputs) {
sizes.push_back(p.shape(axis_));
}
std::partial_sum(sizes.cbegin(), sizes.cend(), sizes.begin());
out.set_data(allocator::malloc_or_wait(out.nbytes()));
auto strides = out.strides();
auto flags = out.flags();
flags.row_contiguous = false;
flags.col_contiguous = false;
flags.contiguous = false;
for (int i = 0; i < inputs.size(); i++) {
array out_slice(inputs[i].shape(), out.dtype(), nullptr, {});
size_t data_offset = strides[axis_] * sizes[i];
out_slice.copy_shared_buffer(
out, strides, flags, out_slice.size(), data_offset);
copy_inplace(inputs[i], out_slice, CopyType::GeneralGeneral);
}
}
void Copy::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
out.copy_shared_buffer(inputs[0]);
}
void Cos::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
const auto& in = inputs[0];
if (is_floating_point(out.dtype())) {
unary_fp(in, out, [](auto x) { return std::cos(x); });
} else {
throw std::invalid_argument(
"[cos] Cannot compute cosine of elements in array"
" with non floating point type.");
}
}
void Cosh::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
const auto& in = inputs[0];
if (is_floating_point(out.dtype())) {
unary_fp(in, out, [](auto x) { return std::cosh(x); });
} else {
throw std::invalid_argument(
"[cosh] Cannot compute hyperbolic cosine of elements in array"
" with non floating point type.");
}
}
void Erf::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
const auto& in = inputs[0];
switch (out.dtype()) {
case float32:
out.set_data(allocator::malloc_or_wait(out.nbytes()));
unary_op<float>(in, out, [](auto x) { return std::erf(x); });
break;
case float16:
out.set_data(allocator::malloc_or_wait(out.nbytes()));
unary_op<float16_t>(in, out, [](auto x) {
return static_cast<float16_t>(std::erf(static_cast<float>(x)));
});
break;
case bfloat16:
out.set_data(allocator::malloc_or_wait(out.nbytes()));
unary_op<bfloat16_t>(in, out, [](auto x) {
return static_cast<bfloat16_t>(std::erf(static_cast<float>(x)));
});
break;
default:
throw std::invalid_argument(
"[erf] Error function only defined for arrays"
" with real floating point type.");
}
}
void ErfInv::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
const auto& in = inputs[0];
switch (out.dtype()) {
case float32:
out.set_data(allocator::malloc_or_wait(out.nbytes()));
unary_op<float>(in, out, [](auto x) { return erfinv(x); });
break;
case float16:
out.set_data(allocator::malloc_or_wait(out.nbytes()));
unary_op<float16_t>(in, out, [](auto x) {
return static_cast<float16_t>(erfinv(static_cast<float>(x)));
});
break;
case bfloat16:
out.set_data(allocator::malloc_or_wait(out.nbytes()));
unary_op<bfloat16_t>(in, out, [](auto x) {
return static_cast<bfloat16_t>(erfinv(static_cast<float>(x)));
});
break;
default:
throw std::invalid_argument(
"[erf_inv] Inverse error function only defined for arrays"
" with real floating point type.");
}
}
void Exp::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
const auto& in = inputs[0];
if (is_floating_point(out.dtype())) {
unary_fp(in, out, [](auto x) { return std::exp(x); });
} else {
throw std::invalid_argument(
"[exp] Cannot exponentiate elements in array"
" with non floating point type.");
}
}
void Full::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
auto& in = inputs[0];
assert(in.dtype() == out.dtype());
CopyType ctype;
if (in.data_size() == 1) {
ctype = CopyType::Scalar;
} else if (in.flags().contiguous) {
ctype = CopyType::Vector;
} else {
ctype = CopyType::General;
}
copy(in, out, ctype);
}
void Log::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
const auto& in = inputs[0];
if (is_floating_point(out.dtype())) {
switch (base_) {
case Base::e:
unary_fp(in, out, [](auto x) { return std::log(x); });
break;
case Base::two:
unary_fp(in, out, [](auto x) { return std::log2(x); });
break;
case Base::ten:
unary_fp(in, out, [](auto x) { return std::log10(x); });
break;
}
} else {
throw std::invalid_argument(
"[log] Cannot compute log of elements in array with"
" non floating point type.");
}
}
void Log1p::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
const auto& in = inputs[0];
if (is_floating_point(out.dtype())) {
unary_fp(in, out, [](auto x) { return std::log1p(x); });
} else {
throw std::invalid_argument(
"[log1p] Cannot compute log of elements in array with"
" non floating point type.");
}
}
void LogicalNot::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
auto& in = inputs[0];
unary(in, out, [](auto x) { return !x; });
}
void Negative::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
auto& in = inputs[0];
unary(in, out, [](auto x) { return -x; });
}
void Pad::eval(const std::vector<array>& inputs, array& out) {
// Inputs must be base input array and scalar val array
assert(inputs.size() == 2);
auto& in = inputs[0];
auto& val = inputs[1];
// Padding value must be a scalar
assert(val.size() == 1);
// Padding value, input and output must be of the same type
assert(val.dtype() == in.dtype() && in.dtype() == out.dtype());
// Fill output with val
copy(val, out, CopyType::Scalar);
// Find offset for start of input values
size_t data_offset = 0;
for (int i = 0; i < axes_.size(); i++) {
auto ax = axes_[i] < 0 ? out.ndim() + axes_[i] : axes_[i];
data_offset += out.strides()[ax] * low_pad_size_[i];
}
// Extract slice from output where input will be pasted
array out_slice(in.shape(), out.dtype(), nullptr, {});
out_slice.copy_shared_buffer(
out, out.strides(), out.flags(), out_slice.size(), data_offset);
// Copy input values into the slice
copy_inplace(in, out_slice, CopyType::GeneralGeneral);
}
void RandomBits::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
// keys has shape (N1, ..., NK, 2)
// out has shape (N1, ..., NK, M1, M2, ...)
auto& keys = inputs[0];
size_t num_keys = keys.size() / 2;
size_t elems_per_key = out.size() / num_keys;
size_t bytes_per_key = out.itemsize() * elems_per_key;
out.set_data(allocator::malloc_or_wait(out.nbytes()));
auto kptr = inputs[0].data<uint32_t>();
auto cptr = out.data<char>();
size_t out_skip = (bytes_per_key + 4 - 1) / 4;
auto half_size = out_skip / 2;
bool even = out_skip % 2 == 0;
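// Each key yields bytes_per_key random bytes: threefry2x32_hash maps a
// (key, counter) pair to two 32-bit words, the counter pair walks the first
// and second halves of the output, and the tail handling below covers sizes
// that are not a multiple of 8 bytes (plus the odd middle word when out_skip
// is odd).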
for (int i = 0; i < num_keys; ++i, cptr += bytes_per_key) {
auto ptr = reinterpret_cast<uint32_t*>(cptr);
// Get ith key
auto kidx = 2 * i;
auto k1_elem = elem_to_loc(kidx, keys.shape(), keys.strides());
auto k2_elem = elem_to_loc(kidx + 1, keys.shape(), keys.strides());
auto key = std::make_pair(kptr[k1_elem], kptr[k2_elem]);
std::pair<uintptr_t, uintptr_t> count{0, half_size + !even};
for (; count.first + 1 < half_size; count.first++, count.second++) {
std::tie(ptr[count.first], ptr[count.second]) =
random::threefry2x32_hash(key, count);
}
if (count.first < half_size) {
auto rb = random::threefry2x32_hash(key, count);
ptr[count.first++] = rb.first;
if (bytes_per_key % 4 > 0) {
std::copy(
reinterpret_cast<char*>(&rb.second),
reinterpret_cast<char*>(&rb.second) + bytes_per_key % 4,
cptr + 4 * count.second);
} else {
ptr[count.second] = rb.second;
}
}
if (!even) {
count.second = 0;
ptr[half_size] = random::threefry2x32_hash(key, count).first;
}
}
}
void Reshape::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
const auto& in = inputs[0];
if (in.flags().row_contiguous) {
// For row contiguous reshapes:
// - Shallow copy the buffer
// - If reshaping into a vector (all singleton dimensions except one) it
// becomes col contiguous again.
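// e.g. reshaping a row contiguous (2, 3) array into (6,) yields a vector,
// which is then both row and col contiguous.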
auto flags = in.flags();
auto max_dim = std::max_element(out.shape().begin(), out.shape().end());
flags.col_contiguous = out.size() <= 1 || out.size() == *max_dim;
out.copy_shared_buffer(in, out.strides(), flags, in.data_size());
} else {
copy(in, out, in.data_size() == 1 ? CopyType::Scalar : CopyType::General);
}
}
void Sigmoid::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
const auto& in = inputs[0];
if (is_floating_point(out.dtype())) {
auto sigmoid_op = [](auto x) {
auto one = static_cast<decltype(x)>(1.0);
return one / (one + std::exp(-x));
};
unary_fp(in, out, sigmoid_op);
} else {
throw std::invalid_argument(
"[sigmoid] Cannot sigmoid of elements in array with"
" non floating point type.");
}
}
void Sign::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
auto& in = inputs[0];
if (in.dtype() == bool_) {
out.copy_shared_buffer(in);
} else {
unary(in, out, SignOp());
}
}
void Sin::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
const auto& in = inputs[0];
if (is_floating_point(out.dtype())) {
unary_fp(in, out, [](auto x) { return std::sin(x); });
} else {
throw std::invalid_argument(
"[sin] Cannot compute sine of elements in array"
" with non floating point type.");
}
}
void Sinh::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
const auto& in = inputs[0];
if (is_floating_point(out.dtype())) {
unary_fp(in, out, [](auto x) { return std::sinh(x); });
} else {
throw std::invalid_argument(
"[sinh] Cannot compute hyperbolic sine of elements in array"
" with non floating point type.");
}
}
void Slice::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
if (out.size() == 0) {
out.set_data(nullptr);
return;
}
auto& in = inputs[0];
auto strides = in.strides();
auto flags = in.flags();
size_t data_offset = 0;
for (int i = 0; i < in.ndim(); ++i) {
data_offset += start_indices_[i] * in.strides()[i];
strides[i] *= strides_[i];
}
// Compute row/col contiguity
size_t data_size = 1;
size_t f_stride = 1;
size_t b_stride = 1;
flags.row_contiguous = true;
flags.col_contiguous = true;
for (int i = 0, ri = out.ndim() - 1; ri >= 0; i++, ri--) {
flags.col_contiguous &= strides[i] == f_stride || out.shape(i) == 1;
flags.row_contiguous &= strides[ri] == b_stride || out.shape(ri) == 1;
f_stride *= out.shape(i);
b_stride *= out.shape(ri);
if (strides[i] > 0) {
data_size *= out.shape(i);
}
}
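// Example: slicing a contiguous length-10 vector with step 2 yields shape
// (5) and stride {2}; neither row nor col contiguous, so the contiguous flag
// is cleared below.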
if (data_size == 1) {
// Broadcasted scalar array is contiguous.
flags.contiguous = true;
} else if (data_size == in.data_size()) {
// Means we sliced a broadcasted dimension so leave the "no holes" flag
// alone.
} else {
// We sliced something. So either we are row or col contiguous or we
// punched a hole.
flags.contiguous &= flags.row_contiguous || flags.col_contiguous;
}
out.copy_shared_buffer(in, strides, flags, data_size, data_offset);
}
void Square::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
auto& in = inputs[0];
unary(in, out, [](auto x) { return x * x; });
}
void Sqrt::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
auto& in = inputs[0];
if (recip_) {
unary_fp(in, out, [](auto x) {
return static_cast<decltype(x)>(1.0) / sqrt(x);
});
} else {
unary_fp(in, out, [](auto x) { return sqrt(x); });
}
}
void StopGradient::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
out.copy_shared_buffer(inputs[0]);
}
void Tan::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
const auto& in = inputs[0];
if (is_floating_point(out.dtype())) {
unary_fp(in, out, [](auto x) { return std::tan(x); });
} else {
throw std::invalid_argument(
"[tan] Cannot compute tangent of elements in array"
" with non floating point type.");
}
}
void Tanh::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
const auto& in = inputs[0];
if (is_floating_point(out.dtype())) {
unary_fp(in, out, [](auto x) { return std::tanh(x); });
} else {
throw std::invalid_argument(
"[tanh] Cannot compute hyperbolic tangent of elements in array"
" with non floating point type.");
}
}
void Transpose::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
std::vector<size_t> out_strides(out.ndim());
auto& in = inputs[0];
for (int ax = 0; ax < axes_.size(); ++ax) {
out_strides[ax] = in.strides()[axes_[ax]];
}
// Conditions for {row/col}_contiguous
// - array must be contiguous (no gaps)
// - underlying buffer size should have the same size as the array
// - cumulative product of shapes is equal to the strides (we can ignore axes
// with size == 1)
// - in the forward direction (column contiguous)
// - in the reverse direction (row contiguous)
// - vectors are both row and col contiguous (hence if both row/col are
// true, they stay true)
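// e.g. transposing a row contiguous (2, 3) array gives out_strides {1, 3}
// for the (3, 2) output: col contiguous but no longer row contiguous.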
auto flags = in.flags();
if (flags.contiguous && in.data_size() == in.size()) {
size_t f_stride = 1;
size_t b_stride = 1;
flags.col_contiguous = true;
flags.row_contiguous = true;
for (int i = 0, ri = out.ndim() - 1; i < out.ndim(); ++i, --ri) {
flags.col_contiguous &= (out_strides[i] == f_stride || out.shape(i) == 1);
f_stride *= out.shape(i);
flags.row_contiguous &=
(out_strides[ri] == b_stride || out.shape(ri) == 1);
b_stride *= out.shape(ri);
}
}
out.copy_shared_buffer(in, out_strides, flags, in.data_size());
}
} // namespace mlx::core

View File

@@ -0,0 +1,215 @@
#include <cassert>
#include <functional>
#include <limits>
#include "mlx/backend/common/reduce.h"
#include "mlx/primitives.h"
namespace mlx::core {
namespace {
template <typename U>
struct Limits {
static const U max;
static const U min;
};
#define instantiate_default_limit(type) \
template <> \
struct Limits<type> { \
static constexpr type max = std::numeric_limits<type>::max(); \
static constexpr type min = std::numeric_limits<type>::min(); \
};
instantiate_default_limit(uint8_t);
instantiate_default_limit(uint16_t);
instantiate_default_limit(uint32_t);
instantiate_default_limit(uint64_t);
instantiate_default_limit(int8_t);
instantiate_default_limit(int16_t);
instantiate_default_limit(int32_t);
instantiate_default_limit(int64_t);
#define instantiate_float_limit(type) \
template <> \
struct Limits<type> { \
static const type max; \
static const type min; \
};
instantiate_float_limit(float16_t);
instantiate_float_limit(bfloat16_t);
instantiate_float_limit(float);
instantiate_float_limit(complex64_t);
template <>
struct Limits<bool> {
static constexpr bool max = true;
static constexpr bool min = false;
};
const float Limits<float>::max = std::numeric_limits<float>::infinity();
const float Limits<float>::min = -std::numeric_limits<float>::infinity();
const bfloat16_t Limits<bfloat16_t>::max =
std::numeric_limits<float>::infinity();
const bfloat16_t Limits<bfloat16_t>::min =
-std::numeric_limits<float>::infinity();
const float16_t Limits<float16_t>::max = std::numeric_limits<float>::infinity();
const float16_t Limits<float16_t>::min =
-std::numeric_limits<float>::infinity();
const complex64_t Limits<complex64_t>::max =
std::numeric_limits<float>::infinity();
const complex64_t Limits<complex64_t>::min =
-std::numeric_limits<float>::infinity();
struct AndReduce {
template <typename T>
void operator()(bool* a, T b) {
(*a) &= (b != 0);
}
void operator()(bool* y, bool x) {
(*y) &= x;
}
};
struct OrReduce {
template <typename T>
void operator()(bool* a, T b) {
(*a) |= (b != 0);
}
void operator()(bool* y, bool x) {
(*y) |= x;
}
};
template <typename InT>
void reduce_dispatch_out(
const array& in,
array& out,
Reduce::ReduceType rtype,
const std::vector<int>& axes) {
switch (rtype) {
case Reduce::And: {
reduction_op<InT, bool>(in, out, axes, true, AndReduce());
break;
}
case Reduce::Or: {
reduction_op<InT, bool>(in, out, axes, false, OrReduce());
break;
}
case Reduce::Sum: {
auto op = [](auto y, auto x) { (*y) = (*y) + x; };
switch (out.dtype()) {
case bool_:
reduction_op<InT, bool>(in, out, axes, false, op);
break;
case uint8:
reduction_op<InT, uint8_t>(in, out, axes, 0, op);
break;
case uint16:
reduction_op<InT, uint16_t>(in, out, axes, 0, op);
break;
case uint32:
reduction_op<InT, uint32_t>(in, out, axes, 0, op);
break;
case uint64:
reduction_op<InT, uint64_t>(in, out, axes, 0, op);
break;
case int8:
reduction_op<InT, int8_t>(in, out, axes, 0, op);
break;
case int16:
reduction_op<InT, int16_t>(in, out, axes, 0, op);
break;
case int32:
reduction_op<InT, int32_t>(in, out, axes, 0, op);
break;
case int64:
reduction_op<InT, int64_t>(in, out, axes, 0, op);
break;
case float16:
reduction_op<InT, float16_t>(in, out, axes, 0.0f, op);
break;
case float32:
reduction_op<InT, float>(in, out, axes, 0.0f, op);
break;
case bfloat16:
reduction_op<InT, bfloat16_t>(in, out, axes, 0.0f, op);
break;
case complex64:
reduction_op<InT, complex64_t>(in, out, axes, complex64_t{0.0f}, op);
break;
}
} break;
case Reduce::Prod: {
auto op = [](auto y, auto x) { (*y) *= x; };
reduction_op<InT, InT>(in, out, axes, 1, op);
break;
}
case Reduce::Max: {
auto op = [](auto y, auto x) { (*y) = (*y > x) ? *y : x; };
auto init = Limits<InT>::min;
reduction_op<InT, InT>(in, out, axes, init, op);
break;
}
case Reduce::Min: {
auto op = [](auto y, auto x) { (*y) = (*y < x) ? *y : x; };
auto init = Limits<InT>::max;
reduction_op<InT, InT>(in, out, axes, init, op);
break;
}
}
}
} // namespace
void Reduce::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
auto& in = inputs[0];
switch (in.dtype()) {
case bool_:
reduce_dispatch_out<bool>(in, out, reduce_type_, axes_);
break;
case uint8:
reduce_dispatch_out<uint8_t>(in, out, reduce_type_, axes_);
break;
case uint16:
reduce_dispatch_out<uint16_t>(in, out, reduce_type_, axes_);
break;
case uint32:
reduce_dispatch_out<uint32_t>(in, out, reduce_type_, axes_);
break;
case uint64:
reduce_dispatch_out<uint64_t>(in, out, reduce_type_, axes_);
break;
case int8:
reduce_dispatch_out<int8_t>(in, out, reduce_type_, axes_);
break;
case int16:
reduce_dispatch_out<int16_t>(in, out, reduce_type_, axes_);
break;
case int32:
reduce_dispatch_out<int32_t>(in, out, reduce_type_, axes_);
break;
case int64:
reduce_dispatch_out<int64_t>(in, out, reduce_type_, axes_);
break;
case float16:
reduce_dispatch_out<float16_t>(in, out, reduce_type_, axes_);
break;
case float32:
reduce_dispatch_out<float>(in, out, reduce_type_, axes_);
break;
case bfloat16:
reduce_dispatch_out<bfloat16_t>(in, out, reduce_type_, axes_);
break;
case complex64:
reduce_dispatch_out<complex64_t>(in, out, reduce_type_, axes_);
break;
}
}
} // namespace mlx::core

364
mlx/backend/common/reduce.h Normal file
View File

@@ -0,0 +1,364 @@
#pragma once
#include "mlx/backend/common/utils.h"
namespace mlx::core {
namespace {
enum ReductionOpType {
// Self-explanatory. Read everything and produce 1 output.
ContiguousAllReduce,
// The input is contiguous and the last axis is reduced
// N1xR1xN2xR2x...xNnxRn
ContiguousReduce,
// The input is contiguous and the last axis is not reduced
// R1xN1xR2xN2x...xRnxNn
ContiguousStridedReduce,
// The input is not contiguous, but the last axis is contiguous and is
// reduced, so we figure out the offsets and can then call the contiguous
// reduce.
// N3xR1xN1xR4x...xRn
GeneralContiguousReduce,
// The input is not contiguous, but the last reduction axis and the last
// axis are, so we figure out the offset and can then call the strided
// reduce.
GeneralStridedReduce,
// The input is not contiguous after the reduction axis and it may contain
// 0-stride axes or transpositions. We could copy the strides and produce a
// transposed outcome or we can read the input out of order and write the
// output in order.
GeneralReduce
};
// Helper for the ndimensional strided loop
// Should this be in utils?
inline void nd_loop(
std::function<void(int)> callback,
const std::vector<int>& shape,
const std::vector<size_t>& strides) {
std::function<void(int, int)> loop_inner;
loop_inner = [&](int dim, int offset) {
if (dim < shape.size() - 1) {
int size = shape[dim];
size_t stride = strides[dim];
for (int i = 0; i < size; i++) {
loop_inner(dim + 1, offset + i * stride);
}
} else {
int size = shape[dim];
size_t stride = strides[dim];
for (int i = 0; i < size; i++) {
callback(offset + i * stride);
}
}
};
loop_inner(0, 0);
}
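// e.g. nd_loop(f, {2, 3}, {3, 1}) invokes f with offsets 0, 1, 2, 3, 4, 5 in
// row-major order.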
std::pair<std::vector<int>, std::vector<size_t>> shapes_without_reduction_axes(
const array& x,
const std::vector<int>& axes) {
std::vector<int> shape = x.shape();
std::vector<size_t> strides = x.strides();
for (int i = axes.size() - 1; i >= 0; i--) {
int a = axes[i];
shape.erase(shape.begin() + a);
strides.erase(strides.begin() + a);
}
return std::make_pair(shape, strides);
}
template <typename T, typename U, typename Op>
struct DefaultStridedReduce {
Op op;
DefaultStridedReduce(Op op_) : op(op_) {}
void operator()(const T* x, U* accumulator, int size, size_t stride) {
for (int i = 0; i < size; i++) {
U* moving_accumulator = accumulator;
for (int j = 0; j < stride; j++) {
op(moving_accumulator, *x);
moving_accumulator++;
x++;
}
}
}
};
template <typename T, typename U, typename Op>
struct DefaultContiguousReduce {
Op op;
DefaultContiguousReduce(Op op_) : op(op_) {}
void operator()(const T* x, U* accumulator, int size) {
while (size-- > 0) {
op(accumulator, *x);
x++;
}
}
};
struct ReductionPlan {
ReductionOpType type;
std::vector<int> shape;
std::vector<size_t> strides;
ReductionPlan(
ReductionOpType type_,
std::vector<int> shape_,
std::vector<size_t> strides_)
: type(type_), shape(std::move(shape_)), strides(std::move(strides_)) {}
ReductionPlan(ReductionOpType type_) : type(type_) {}
};
ReductionPlan get_reduction_plan(const array& x, const std::vector<int> axes) {
// The data is all there and we are reducing over everything
if (x.size() == x.data_size() && axes.size() == x.ndim() &&
(x.flags().row_contiguous || x.flags().col_contiguous)) {
return ContiguousAllReduce;
}
// Row contiguous input so the output is row contiguous
if (x.flags().row_contiguous) {
// Merge consecutive axes
std::vector<int> shape = {x.shape(axes[0])};
std::vector<size_t> strides = {x.strides()[axes[0]]};
for (int i = 1; i < axes.size(); i++) {
if (axes[i] - 1 == axes[i - 1]) {
shape.back() *= x.shape(axes[i]);
strides.back() = x.strides()[axes[i]];
} else {
shape.push_back(x.shape(axes[i]));
strides.push_back(x.strides()[axes[i]]);
}
}
if (strides.back() == 1) {
return ReductionPlan(ContiguousReduce, shape, strides);
} else if (strides.back() > 1) {
return ReductionPlan(ContiguousStridedReduce, shape, strides);
}
}
// Let's check if we can optimize our access patterns
//
// 1. We have a reduction axis with stride 1. Simply call
// GeneralContiguousReduce and be done with it.
// 2. We have transpositions and we are not reducing over the axis with
// stride 1. However, we are reducing over an axis where everything is
// contiguous in memory to the right of that axis. We can call strided
// reduce and be done with it.
// 3. We have weird transpositions and expands. Copy the strides to the
// output, then call strided reduce.
// Sort reduction axes by stride in order to merge them and figure out if we
// have a contiguous reduction.
std::vector<std::pair<int, size_t>> reductions;
for (auto a : axes) {
reductions.push_back(std::make_pair(x.shape(a), x.strides()[a]));
}
std::sort(reductions.begin(), reductions.end(), [](auto a, auto b) {
return a.second > b.second;
});
// Extract the two smallest and try to merge them in case the contiguous
// reduction can be bigger than just the last axis.
for (int i = reductions.size() - 1; i >= 1; i--) {
auto a = reductions[i];
auto b = reductions[i - 1];
// If b.stride == a.shape * a.stride then a and b are contiguous
if (b.second == a.first * a.second) {
reductions.erase(reductions.begin() + i);
reductions[i - 1] = std::make_pair(a.first * b.first, a.second);
}
}
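// Example: reducing axes {1, 2} of an array with shape (4, 2, 3) and strides
// (1, 12, 4) sorts the reductions to {(2, 12), (3, 4)} and merges them into
// a single (6, 4) entry since 12 == 3 * 4.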
std::vector<int> shape;
std::vector<size_t> strides;
for (auto r : reductions) {
shape.push_back(r.first);
strides.push_back(r.second);
}
// We can call the contiguous reduction op for every weird way the input is
// structured in the rest of the axes.
if (strides.back() == 1) {
return ReductionPlan(GeneralContiguousReduce, shape, strides);
}
// Delegate to the general strided reduction op if the axes after
// strides.back() are contiguous.
if (strides.back() > 1) {
int size = 1;
for (int i = x.ndim() - 1; i >= 0; i--) {
if (axes.back() == i) {
continue;
}
if (x.strides()[i] != size) {
break;
}
size *= x.shape(i);
}
if (size >= strides.back()) {
return ReductionPlan(GeneralStridedReduce, shape, strides);
}
}
return ReductionPlan(GeneralReduce, shape, strides);
}
template <typename T, typename U, typename OpS, typename OpC, typename Op>
void reduction_op(
const array& x,
array& out,
const std::vector<int>& axes,
U init,
OpS ops,
OpC opc,
Op op) {
out.set_data(allocator::malloc_or_wait(out.nbytes()));
ReductionPlan plan = get_reduction_plan(x, axes);
if (plan.type == ContiguousAllReduce) {
U* out_ptr = out.data<U>();
*out_ptr = init;
opc(x.data<T>(), out_ptr, x.size());
return;
}
std::vector<int> shape;
std::vector<size_t> strides;
if (plan.type == ContiguousReduce && plan.shape.size() == 1) {
int reduction_size = plan.shape[0];
const T* x_ptr = x.data<T>();
U* out_ptr = out.data<U>();
for (int i = 0; i < out.size(); i++, out_ptr++, x_ptr += reduction_size) {
*out_ptr = init;
opc(x_ptr, out_ptr, reduction_size);
}
return;
}
if (plan.type == GeneralContiguousReduce || plan.type == ContiguousReduce) {
int reduction_size = plan.shape.back();
plan.shape.pop_back();
plan.strides.pop_back();
const T* x_ptr = x.data<T>();
U* out_ptr = out.data<U>();
// Unrolling the following loop (and implementing it in order for
// ContiguousReduce) should yield an extra performance boost.
std::tie(shape, strides) = shapes_without_reduction_axes(x, axes);
if (plan.shape.size() == 0) {
for (int i = 0; i < out.size(); i++, out_ptr++) {
int offset = elem_to_loc(i, shape, strides);
*out_ptr = init;
opc(x_ptr + offset, out_ptr, reduction_size);
}
} else {
for (int i = 0; i < out.size(); i++, out_ptr++) {
int offset = elem_to_loc(i, shape, strides);
*out_ptr = init;
nd_loop(
[&](int extra_offset) {
opc(x_ptr + offset + extra_offset, out_ptr, reduction_size);
},
plan.shape,
plan.strides);
}
}
return;
}
if (plan.type == ContiguousStridedReduce && plan.shape.size() == 1) {
int reduction_size = plan.shape.back();
size_t reduction_stride = plan.strides.back();
plan.shape.pop_back();
plan.strides.pop_back();
const T* x_ptr = x.data<T>();
U* out_ptr = out.data<U>();
for (int i = 0; i < out.size(); i += reduction_stride) {
std::fill_n(out_ptr, reduction_stride, init);
ops(x_ptr, out_ptr, reduction_size, reduction_stride);
x_ptr += reduction_stride * reduction_size;
out_ptr += reduction_stride;
}
return;
}
if (plan.type == GeneralStridedReduce ||
plan.type == ContiguousStridedReduce) {
int reduction_size = plan.shape.back();
size_t reduction_stride = plan.strides.back();
plan.shape.pop_back();
plan.strides.pop_back();
const T* x_ptr = x.data<T>();
U* out_ptr = out.data<U>();
std::tie(shape, strides) = shapes_without_reduction_axes(x, axes);
if (plan.shape.size() == 0) {
for (int i = 0; i < out.size(); i += reduction_stride) {
int offset = elem_to_loc(i, shape, strides);
std::fill_n(out_ptr, reduction_stride, init);
ops(x_ptr + offset, out_ptr, reduction_size, reduction_stride);
out_ptr += reduction_stride;
}
} else {
for (int i = 0; i < out.size(); i += reduction_stride) {
int offset = elem_to_loc(i, shape, strides);
std::fill_n(out_ptr, reduction_stride, init);
nd_loop(
[&](int extra_offset) {
ops(x_ptr + offset + extra_offset,
out_ptr,
reduction_size,
reduction_stride);
},
plan.shape,
plan.strides);
out_ptr += reduction_stride;
}
}
return;
}
if (plan.type == GeneralReduce) {
const T* x_ptr = x.data<T>();
U* out_ptr = out.data<U>();
std::tie(shape, strides) = shapes_without_reduction_axes(x, axes);
for (int i = 0; i < out.size(); i++, out_ptr++) {
int offset = elem_to_loc(i, shape, strides);
U val = init;
nd_loop(
[&](int extra_offset) { op(&val, *(x_ptr + offset + extra_offset)); },
plan.shape,
plan.strides);
*out_ptr = val;
}
}
}
template <typename T, typename U, typename Op>
void reduction_op(
const array& x,
array& out,
const std::vector<int>& axes,
U init,
Op op) {
DefaultStridedReduce<T, U, Op> ops(op);
DefaultContiguousReduce<T, U, Op> opc(op);
reduction_op<T, U>(x, out, axes, init, ops, opc, op);
}
} // namespace
} // namespace mlx::core

323
mlx/backend/common/scan.cpp Normal file
View File

@@ -0,0 +1,323 @@
#include <cassert>
#include "mlx/backend/common/copy.h"
#include "mlx/backend/common/utils.h"
#include "mlx/primitives.h"
namespace mlx::core {
namespace {
template <typename T, typename U, typename Op>
struct DefaultContiguousScan {
Op op;
U init;
DefaultContiguousScan(Op op_, U init_) : op(op_), init(init_) {}
void operator()(
const T* input,
U* output,
int count,
int stride,
bool reverse,
bool inclusive) {
if (!reverse) {
if (inclusive) {
for (int i = 0; i < count; i++) {
*output = *input;
for (int j = 1; j < stride; j++) {
input++;
output++;
op(output, output - 1, input);
}
output++;
input++;
}
} else {
for (int i = 0; i < count; i++) {
*output = init;
for (int j = 1; j < stride; j++) {
op(output + 1, output, input);
input++;
output++;
}
output++;
input++;
}
}
} else {
if (inclusive) {
for (int i = 0; i < count; i++) {
output += stride - 1;
input += stride - 1;
*output = *input;
for (int j = 1; j < stride; j++) {
input--;
output--;
op(output, output + 1, input);
}
output += stride;
input += stride;
}
} else {
for (int i = 0; i < count; i++) {
output += stride - 1;
input += stride - 1;
*output = init;
for (int j = 1; j < stride; j++) {
op(output - 1, output, input);
input--;
output--;
}
output += stride;
input += stride;
}
}
}
}
};
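// e.g. an inclusive sum scan of {1, 2, 3, 4} produces {1, 3, 6, 10}; the
// exclusive variant produces {0, 1, 3, 6}.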
template <typename T, typename U, typename Op>
struct DefaultStridedScan {
Op op;
U init;
DefaultStridedScan(Op op_, U init_) : op(op_), init(init_) {}
void operator()(
const T* input,
U* output,
int count,
int size,
int stride,
bool reverse,
bool inclusive) {
// TODO: Vectorize the following naive implementation
if (!reverse) {
if (inclusive) {
for (int i = 0; i < count; i++) {
std::copy(input, input + stride, output);
output += stride;
input += stride;
for (int j = 1; j < size; j++) {
for (int k = 0; k < stride; k++) {
op(output, output - stride, input);
output++;
input++;
}
}
}
} else {
for (int i = 0; i < count; i++) {
std::fill(output, output + stride, init);
output += stride;
input += stride;
for (int j = 1; j < size; j++) {
for (int k = 0; k < stride; k++) {
op(output, output - stride, input - stride);
output++;
input++;
}
}
}
}
} else {
if (inclusive) {
for (int i = 0; i < count; i++) {
output += (size - 1) * stride;
input += (size - 1) * stride;
std::copy(input, input + stride, output);
for (int j = 1; j < size; j++) {
for (int k = 0; k < stride; k++) {
output--;
input--;
op(output, output + stride, input);
}
}
output += size * stride;
input += size * stride;
}
} else {
for (int i = 0; i < count; i++) {
output += (size - 1) * stride;
input += (size - 1) * stride;
std::fill(output, output + stride, init);
for (int j = 1; j < size; j++) {
for (int k = 0; k < stride; k++) {
output--;
input--;
op(output, output + stride, input + stride);
}
}
output += size * stride;
input += size * stride;
}
}
}
}
};
template <typename T, typename U, typename OpCS, typename OpSS>
void scan_op(
OpCS opcs,
OpSS opss,
const array& input,
array& output,
int axis,
bool reverse,
bool inclusive) {
output.set_data(allocator::malloc_or_wait(output.nbytes()));
if (input.flags().row_contiguous) {
if (input.strides()[axis] == 1) {
opcs(
input.data<T>(),
output.data<U>(),
input.size() / input.shape(axis),
input.shape(axis),
reverse,
inclusive);
} else {
opss(
input.data<T>(),
output.data<U>(),
input.size() / input.shape(axis) / input.strides()[axis],
input.shape(axis),
input.strides()[axis],
reverse,
inclusive);
}
} else {
throw std::runtime_error("Scan op supports only contiguous inputs");
}
}
template <typename T, typename U>
void scan_dispatch(
Scan::ReduceType rtype,
const array& input,
array& output,
int axis,
bool reverse,
bool inclusive) {
switch (rtype) {
case Scan::Sum: {
auto op = [](U* o, const U* y, const T* x) { *o = *y + *x; };
auto init = static_cast<U>(0);
auto opcs = DefaultContiguousScan<T, U, decltype(op)>(op, init);
auto opss = DefaultStridedScan<T, U, decltype(op)>(op, init);
scan_op<T, U>(opcs, opss, input, output, axis, reverse, inclusive);
break;
}
case Scan::Prod: {
auto op = [](U* o, const U* y, const T* x) { *o = *y * (*x); };
auto init = static_cast<U>(1);
auto opcs = DefaultContiguousScan<T, U, decltype(op)>(op, init);
auto opss = DefaultStridedScan<T, U, decltype(op)>(op, init);
scan_op<T, U>(opcs, opss, input, output, axis, reverse, inclusive);
break;
}
case Scan::Min: {
auto op = [](U* o, const U* y, const T* x) { *o = (*x < *y) ? *x : *y; };
auto init = (is_floating_point(input.dtype()))
? static_cast<U>(std::numeric_limits<float>::infinity())
: std::numeric_limits<U>::max();
auto opcs = DefaultContiguousScan<T, U, decltype(op)>(op, init);
auto opss = DefaultStridedScan<T, U, decltype(op)>(op, init);
scan_op<T, U>(opcs, opss, input, output, axis, reverse, inclusive);
break;
}
case Scan::Max: {
auto op = [](U* o, const U* y, const T* x) { *o = (*x < *y) ? *y : *x; };
auto init = (is_floating_point(input.dtype()))
? static_cast<U>(-std::numeric_limits<float>::infinity())
: std::numeric_limits<U>::min();
auto opcs = DefaultContiguousScan<T, U, decltype(op)>(op, init);
auto opss = DefaultStridedScan<T, U, decltype(op)>(op, init);
scan_op<T, U>(opcs, opss, input, output, axis, reverse, inclusive);
break;
}
}
}
} // namespace
void Scan::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
// Ensure contiguity
auto in = inputs[0];
if (!in.flags().row_contiguous) {
array arr_copy(in.shape(), in.dtype(), nullptr, {});
copy(in, arr_copy, CopyType::General);
in = arr_copy;
}
switch (in.dtype()) {
case bool_: {
// We could do a full dtype x dtype switch but this is the only case
// where we accumulate in a different type, for now.
//
// TODO: If we add the option to accumulate floats in higher precision
// floats perhaps we should add the full all-to-all dispatch.
if (reduce_type_ == Scan::Sum && out.dtype() == int32) {
scan_dispatch<bool, int32_t>(
reduce_type_, in, out, axis_, reverse_, inclusive_);
} else {
scan_dispatch<bool, bool>(
reduce_type_, in, out, axis_, reverse_, inclusive_);
}
break;
}
case uint8:
scan_dispatch<uint8_t, uint8_t>(
reduce_type_, in, out, axis_, reverse_, inclusive_);
break;
case uint16:
scan_dispatch<uint16_t, uint16_t>(
reduce_type_, in, out, axis_, reverse_, inclusive_);
break;
case uint32:
scan_dispatch<uint32_t, uint32_t>(
reduce_type_, in, out, axis_, reverse_, inclusive_);
break;
case uint64:
scan_dispatch<uint64_t, uint64_t>(
reduce_type_, in, out, axis_, reverse_, inclusive_);
break;
case int8:
scan_dispatch<int8_t, int8_t>(
reduce_type_, in, out, axis_, reverse_, inclusive_);
break;
case int16:
scan_dispatch<int16_t, int16_t>(
reduce_type_, in, out, axis_, reverse_, inclusive_);
break;
case int32:
scan_dispatch<int32_t, int32_t>(
reduce_type_, in, out, axis_, reverse_, inclusive_);
break;
case int64:
scan_dispatch<int64_t, int64_t>(
reduce_type_, in, out, axis_, reverse_, inclusive_);
break;
case float16:
scan_dispatch<float16_t, float16_t>(
reduce_type_, in, out, axis_, reverse_, inclusive_);
break;
case float32:
scan_dispatch<float, float>(
reduce_type_, in, out, axis_, reverse_, inclusive_);
break;
case bfloat16:
scan_dispatch<bfloat16_t, bfloat16_t>(
reduce_type_, in, out, axis_, reverse_, inclusive_);
break;
case complex64:
throw std::runtime_error("Scan ops do not support complex types yet");
break;
}
}
} // namespace mlx::core

View File

@@ -0,0 +1,29 @@
#include "mlx/backend/common/threefry.h"
namespace mlx::core::random {
std::pair<uint32_t, uint32_t> threefry2x32_hash(
const std::pair<uint32_t, uint32_t>& key,
std::pair<uint32_t, uint32_t> count) {
constexpr static uint32_t rotations[2][4] = {
{13, 15, 26, 6}, {17, 29, 16, 24}};
uint32_t ks[3] = {key.first, key.second, key.first ^ key.second ^ 0x1BD11BDA};
count.first += ks[0];
count.second += ks[1];
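// 20 rounds in total: 5 groups of 4 rounds, with a key injection after each group.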
for (int i = 0; i < 5; ++i) {
for (auto r : rotations[i % 2]) {
count.first += count.second;
count.second = (count.second << r) | (count.second >> (32 - r));
count.second ^= count.first;
}
count.first += ks[(i + 1) % 3];
count.second += ks[(i + 2) % 3] + i + 1;
}
return count;
}
} // namespace mlx::core::random

View File

@@ -0,0 +1,19 @@
#pragma once
#include <cstdint>
#include <utility>
namespace mlx::core::random {
/** Applies the Threefry 2x32 hash function.
* This code is based on the Jax counter-based and splittable PRNG
* https://github.com/google/jax/blob/main/docs/jep/263-prng.md
*
* Original Threefry reference:
* http://www.thesalmons.org/john/random123/papers/random123sc11.pdf
*/
std::pair<uint32_t, uint32_t> threefry2x32_hash(
const std::pair<uint32_t, uint32_t>& key,
std::pair<uint32_t, uint32_t> count);
} // namespace mlx::core::random
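The hash above is counter-based: keeping the key fixed and bumping the counter yields independent 64-bit blocks, which is what makes the PRNG splittable. A minimal, illustrative sketch of driving it (not part of this commit; it only assumes the header shown above):
#include <cstdio>
#include "mlx/backend/common/threefry.h"
int main() {
  using namespace mlx::core::random;
  std::pair<uint32_t, uint32_t> key{0u, 42u};
  // Same key, different counters -> independent pseudo-random 64-bit blocks.
  auto block0 = threefry2x32_hash(key, {0u, 0u});
  auto block1 = threefry2x32_hash(key, {0u, 1u});
  std::printf("%08x%08x\n", block0.first, block0.second);
  std::printf("%08x%08x\n", block1.first, block1.second);
  return 0;
}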

147
mlx/backend/common/unary.h Normal file
View File

@@ -0,0 +1,147 @@
#pragma once
#include "mlx/allocator.h"
#include "mlx/array.h"
#include "mlx/backend/common/utils.h"
#include "mlx/utils.h"
namespace mlx::core {
namespace {
struct AbsOp {
template <typename T>
T operator()(T x) {
return std::abs(x);
}
uint8_t operator()(uint8_t x) {
return x;
}
uint16_t operator()(uint16_t x) {
return x;
}
uint32_t operator()(uint32_t x) {
return x;
}
uint64_t operator()(uint64_t x) {
return x;
}
bool operator()(bool x) {
return x;
}
};
struct SignOp {
template <typename T>
T operator()(T x) {
return (x > T(0)) - (x < T(0));
}
uint8_t operator()(uint8_t x) {
return x != 0;
}
uint16_t operator()(uint16_t x) {
return x != 0;
}
uint32_t operator()(uint32_t x) {
return x != 0;
}
uint64_t operator()(uint64_t x) {
return x != 0;
}
};
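// Apply `op` element-wise. Contiguous inputs take a fast path that reuses the
// input's data_size, strides, and flags and walks the raw buffer directly;
// otherwise every output element is located with elem_to_loc on the input's
// shape and strides.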
template <typename T, typename Op>
void unary_op(const array& a, array& out, Op op) {
const T* a_ptr = a.data<T>();
if (a.flags().contiguous) {
out.set_data(
allocator::malloc_or_wait(a.data_size() * out.itemsize()),
a.data_size(),
a.strides(),
a.flags());
T* dst = out.data<T>();
for (size_t i = 0; i < a.data_size(); ++i) {
dst[i] = op(a_ptr[i]);
}
} else {
out.set_data(allocator::malloc_or_wait(out.nbytes()));
T* dst = out.data<T>();
for (size_t i = 0; i < out.size(); ++i) {
// TODO this is super inefficient, need to fix.
int a_idx = elem_to_loc(i, a.shape(), a.strides());
dst[i] = op(a_ptr[a_idx]);
}
}
}
template <typename Op>
void unary(const array& a, array& out, Op op) {
switch (out.dtype()) {
case bool_:
unary_op<bool>(a, out, op);
break;
case uint8:
unary_op<uint8_t>(a, out, op);
break;
case uint16:
unary_op<uint16_t>(a, out, op);
break;
case uint32:
unary_op<uint32_t>(a, out, op);
break;
case uint64:
unary_op<uint64_t>(a, out, op);
break;
case int8:
unary_op<int8_t>(a, out, op);
break;
case int16:
unary_op<int16_t>(a, out, op);
break;
case int32:
unary_op<int32_t>(a, out, op);
break;
case int64:
unary_op<int64_t>(a, out, op);
break;
case float16:
unary_op<float16_t>(a, out, op);
break;
case float32:
unary_op<float>(a, out, op);
break;
case bfloat16:
unary_op<bfloat16_t>(a, out, op);
break;
case complex64:
unary_op<complex64_t>(a, out, op);
break;
}
}
template <typename Op>
void unary_fp(const array& a, array& out, Op op) {
switch (out.dtype()) {
case bfloat16:
unary_op<bfloat16_t>(a, out, op);
break;
case float16:
unary_op<float16_t>(a, out, op);
break;
case float32:
unary_op<float>(a, out, op);
break;
case complex64:
unary_op<complex64_t>(a, out, op);
break;
default:
std::ostringstream err;
err << "[unary_fp] Does not support " << out.dtype();
throw std::runtime_error(err.str());
}
}
} // namespace
} // namespace mlx::core

View File

@@ -0,0 +1,29 @@
#pragma once
#include <vector>
#include "mlx/array.h"
namespace mlx::core {
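// Map a flat element index to a memory offset from the array's shape and
// strides, peeling indices off the innermost dimension first.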
inline size_t elem_to_loc(
int elem,
const std::vector<int>& shape,
const std::vector<size_t>& strides) {
size_t loc = 0;
for (int i = shape.size() - 1; i >= 0; --i) {
auto q_and_r = ldiv(elem, shape[i]);
loc += q_and_r.rem * strides[i];
elem = q_and_r.quot;
}
return loc;
}
inline size_t elem_to_loc(int elem, const array& a) {
if (a.flags().row_contiguous) {
return elem;
}
return elem_to_loc(elem, a.shape(), a.strides());
}
} // namespace mlx::core

View File

@@ -0,0 +1,26 @@
target_sources(
mlx
PRIVATE
${CMAKE_CURRENT_SOURCE_DIR}/allocator.cpp
${CMAKE_CURRENT_SOURCE_DIR}/conv.cpp
${CMAKE_CURRENT_SOURCE_DIR}/copy.cpp
${CMAKE_CURRENT_SOURCE_DIR}/device.cpp
${CMAKE_CURRENT_SOURCE_DIR}/fft.cpp
${CMAKE_CURRENT_SOURCE_DIR}/indexing.cpp
${CMAKE_CURRENT_SOURCE_DIR}/matmul.cpp
${CMAKE_CURRENT_SOURCE_DIR}/metal.cpp
${CMAKE_CURRENT_SOURCE_DIR}/primitives.cpp
${CMAKE_CURRENT_SOURCE_DIR}/scan.cpp
${CMAKE_CURRENT_SOURCE_DIR}/softmax.cpp
${CMAKE_CURRENT_SOURCE_DIR}/sort.cpp
${CMAKE_CURRENT_SOURCE_DIR}/reduce.cpp
)
if (NOT MLX_METAL_PATH)
set(MLX_METAL_PATH ${CMAKE_CURRENT_BINARY_DIR}/kernels/)
endif()
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/kernels)
target_compile_definitions(
mlx PRIVATE METAL_PATH="${MLX_METAL_PATH}/mlx.metallib")

113
mlx/backend/metal/copy.cpp Normal file
View File

@@ -0,0 +1,113 @@
#include <sstream>
#include "mlx/backend/metal/copy.h"
#include "mlx/backend/metal/device.h"
#include "mlx/backend/metal/kernels/defines.h"
#include "mlx/backend/metal/utils.h"
#include "mlx/primitives.h"
namespace mlx::core {
void copy_gpu(const array& in, array& out, CopyType ctype, const Stream& s) {
if (ctype == CopyType::Vector) {
out.set_data(
allocator::malloc_or_wait(in.data_size() * out.itemsize()),
in.data_size(),
in.strides(),
in.flags());
} else {
out.set_data(allocator::malloc_or_wait(out.nbytes()));
}
if (ctype == CopyType::GeneralGeneral) {
ctype = CopyType::General;
}
copy_gpu_inplace(in, out, ctype, s);
}
void copy_gpu(const array& in, array& out, CopyType ctype) {
copy_gpu(in, out, ctype, out.primitive().stream());
}
void copy_gpu_inplace(
const array& in,
array& out,
CopyType ctype,
const Stream& s) {
// Try to collapse contiguous dims
auto [shape, strides] = collapse_contiguous_dims(in, out);
auto& strides_in = strides[0];
auto& strides_out = strides[1];
auto& d = metal::device(s.device);
std::ostringstream kname;
switch (ctype) {
case CopyType::Scalar:
kname << "scopy";
break;
case CopyType::Vector:
kname << "vcopy";
break;
case CopyType::General:
kname << "gcopy";
break;
case CopyType::GeneralGeneral:
kname << "ggcopy";
break;
}
kname << type_to_name(in) << type_to_name(out);
if ((ctype == CopyType::General || ctype == CopyType::GeneralGeneral) &&
shape.size() <= MAX_COPY_SPECIALIZED_DIMS) {
kname << "_" << shape.size();
}
auto kernel = d.get_kernel(kname.str());
auto compute_encoder = d.get_command_encoder(s.index);
compute_encoder->setComputePipelineState(kernel);
set_array_buffer(compute_encoder, in, 0);
set_array_buffer(compute_encoder, out, 1);
if (ctype == CopyType::General || ctype == CopyType::GeneralGeneral) {
size_t ndim = shape.size();
if (ndim > 3) {
compute_encoder->setBytes(shape.data(), ndim * sizeof(int), 2);
compute_encoder->setBytes(strides_in.data(), ndim * sizeof(size_t), 3);
if (ctype == CopyType::GeneralGeneral) {
compute_encoder->setBytes(strides_out.data(), ndim * sizeof(size_t), 4);
}
} else {
// The shape is implicit in the grid for <= 3D
compute_encoder->setBytes(strides_in.data(), ndim * sizeof(size_t), 2);
if (ctype == CopyType::GeneralGeneral) {
compute_encoder->setBytes(strides_out.data(), ndim * sizeof(size_t), 3);
}
}
if (ndim > MAX_BINARY_SPECIALIZED_DIMS) {
compute_encoder->setBytes(
&ndim, sizeof(int), (ctype == CopyType::GeneralGeneral) ? 5 : 4);
}
int dim0 = ndim > 0 ? shape[ndim - 1] : 1;
int dim1 = ndim > 1 ? shape[ndim - 2] : 1;
int rest = in.size() / (dim0 * dim1);
// NB assuming thread_group_size is a power of 2 larger than 32 x 32
NS::UInteger thread_group_size = kernel->maxTotalThreadsPerThreadgroup();
if (thread_group_size != 1024) {
throw std::runtime_error("[Metal::copy] Must use 1024 sized block");
}
auto group_dims = get_block_dims(dim0, dim1, rest);
MTL::Size grid_dims = MTL::Size(dim0, dim1, rest);
compute_encoder->dispatchThreads(grid_dims, group_dims);
} else {
size_t nthreads = out.data_size();
MTL::Size grid_dims = MTL::Size(nthreads, 1, 1);
NS::UInteger thread_group_size = kernel->maxTotalThreadsPerThreadgroup();
if (thread_group_size > nthreads) {
thread_group_size = nthreads;
}
MTL::Size group_dims = MTL::Size(thread_group_size, 1, 1);
compute_encoder->dispatchThreads(grid_dims, group_dims);
}
}
} // namespace mlx::core

View File

@@ -0,0 +1,81 @@
#pragma once
#include <Metal/Metal.hpp>
#include <functional>
#include <mutex>
#include <string>
#include <unordered_map>
#include <dlfcn.h>
#include <filesystem>
#include "mlx/device.h"
namespace fs = std::filesystem;
namespace mlx::core::metal {
inline std::string get_colocated_mtllib_path(const std::string& lib_name) {
Dl_info info;
std::string mtllib_path;
std::string lib_ext = lib_name + ".metallib";
int success = dladdr((void*)get_colocated_mtllib_path, &info);
if (success) {
auto mtllib = fs::path(info.dli_fname).remove_filename() / lib_ext;
mtllib_path = mtllib.c_str();
}
return mtllib_path;
}
class Device {
public:
Device();
Device(const Device&) = delete;
Device& operator=(const Device&) = delete;
~Device();
MTL::Device* mtl_device() {
return device_;
};
void new_queue(int index);
MTL::CommandBuffer* new_command_buffer(int index);
MTL::CommandBuffer* get_command_buffer(int index);
int get_command_buffer_ops(int index);
void increment_command_buffer_ops(int index);
void commit_command_buffer(int index);
MTL::ComputeCommandEncoder* get_command_encoder(int index);
void end_encoding(int index);
void register_library(
const std::string& lib_name,
const std::string& lib_path);
void register_library(
const std::string& lib_name,
const std::function<std::string(const std::string&)>& lib_path_func =
get_colocated_mtllib_path);
MTL::ComputePipelineState* get_kernel(
const std::string& name,
const std::string& lib_name = "mlx");
MTL::ArgumentEncoder* argument_encoder(
const std::vector<MTL::ArgumentDescriptor*>& arg_descs) const;
private:
NS::AutoreleasePool* pool_;
MTL::Device* device_;
std::unordered_map<int32_t, MTL::CommandQueue*> queue_map_;
std::unordered_map<int32_t, std::pair<int, MTL::CommandBuffer*>> buffer_map_;
std::unordered_map<int32_t, MTL::ComputeCommandEncoder*> encoder_map_;
std::unordered_map<std::string, MTL::ComputePipelineState*> kernel_map_;
std::unordered_map<std::string, MTL::Library*> library_map_;
std::mutex mtx_;
};
Device& device(mlx::core::Device);
NS::AutoreleasePool*& thread_autorelease_pool();
} // namespace mlx::core::metal

View File

@@ -0,0 +1,296 @@
#include <algorithm>
#include <cassert>
#include <numeric>
#include <sstream>
#include "mlx/backend/common/binary.h"
#include "mlx/backend/metal/copy.h"
#include "mlx/backend/metal/device.h"
#include "mlx/backend/metal/kernels/defines.h"
#include "mlx/backend/metal/utils.h"
#include "mlx/primitives.h"
#include "mlx/utils.h"
namespace mlx::core {
namespace {
static constexpr int METAL_MAX_INDEX_ARRAYS = 10;
} // namespace
void Gather::eval_gpu(const std::vector<array>& inputs, array& out) {
auto& src = inputs[0];
int nidx = inputs.size() - 1;
if (nidx > METAL_MAX_INDEX_ARRAYS) {
std::ostringstream msg;
msg << "[Gather::eval_gpu] Gathering with more than "
<< METAL_MAX_INDEX_ARRAYS << " index arrays not yet supported.";
throw std::runtime_error(msg.str());
}
out.set_data(allocator::malloc_or_wait(out.nbytes()));
auto& s = stream();
auto& d = metal::device(s.device);
std::ostringstream kname;
std::string idx_type_name = nidx ? type_to_name(inputs[1]) : "";
kname << "gather" << type_to_name(src) << idx_type_name << "_" << nidx;
auto compute_encoder = d.get_command_encoder(s.index);
auto kernel = d.get_kernel(kname.str());
size_t slice_size = 1;
for (auto s : slice_sizes_) {
slice_size *= s;
}
size_t ndim = src.ndim();
size_t nthreads = out.size();
NS::UInteger thread_group_size = kernel->maxTotalThreadsPerThreadgroup();
if (thread_group_size > nthreads) {
thread_group_size = nthreads;
}
MTL::Size grid_dims = MTL::Size(nthreads, 1, 1);
MTL::Size group_dims = MTL::Size(thread_group_size, 1, 1);
compute_encoder->setComputePipelineState(kernel);
// Make the argument buffer to store the indices for the
// `Indices` struct in kernels/indexing.metal
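// A plausible sketch of that struct (inferred from the argument descriptor
// indices set up below; not the verbatim kernel code):
//
//   template <typename IdxT, int NIDX>
//   struct Indices {
//     array<device const IdxT*, NIDX> buffers; // [[id(0)]]
//     device const int* shapes;                // [[id(NIDX + 1)]]
//     device const size_t* strides;            // [[id(NIDX + 2)]]
//     const int ndim;                          // [[id(NIDX + 3)]]
//   };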
std::vector<MTL::ArgumentDescriptor*> arg_descs(4);
arg_descs[0] = MTL::ArgumentDescriptor::argumentDescriptor();
arg_descs[0]->setIndex(0);
arg_descs[0]->setDataType(MTL::DataType::DataTypePointer);
arg_descs[0]->setArrayLength(nidx);
// Shapes
arg_descs[1] = MTL::ArgumentDescriptor::argumentDescriptor();
arg_descs[1]->setDataType(MTL::DataType::DataTypePointer);
arg_descs[1]->setIndex(nidx + 1);
// Strides
arg_descs[2] = MTL::ArgumentDescriptor::argumentDescriptor();
arg_descs[2]->setDataType(MTL::DataType::DataTypePointer);
arg_descs[2]->setIndex(nidx + 2);
// Indices ndim
arg_descs[3] = MTL::ArgumentDescriptor::argumentDescriptor();
arg_descs[3]->setDataType(MTL::DataType::DataTypeInt);
arg_descs[3]->setIndex(nidx + 3);
// Get the argument encoder
auto arg_enc = d.argument_encoder(arg_descs);
// Allocate and fill buffers for shapes and strides
int idx_ndim = nidx ? inputs[1].ndim() : 0;
auto idx_shapes_buf = allocator::malloc_or_wait(sizeof(int) * idx_ndim * nidx);
auto idx_strides_buf = allocator::malloc_or_wait(sizeof(size_t) * idx_ndim * nidx);
for (int i = 0; i < nidx; ++i) {
std::copy(
inputs[i + 1].shape().begin(),
inputs[i + 1].shape().end(),
static_cast<int*>(idx_shapes_buf.raw_ptr()) + i * idx_ndim);
std::copy(
inputs[i + 1].strides().begin(),
inputs[i + 1].strides().end(),
static_cast<size_t*>(idx_strides_buf.raw_ptr()) + i * idx_ndim);
}
// Allocate the argument buffer
auto arg_buf = allocator::malloc_or_wait(arg_enc->encodedLength());
// Register data with the encoder
arg_enc->setArgumentBuffer(static_cast<MTL::Buffer*>(arg_buf.ptr()), 0);
for (int i = 0; i < nidx; ++i) {
set_array_buffer(compute_encoder, arg_enc, inputs[i + 1], i);
}
arg_enc->setBuffer(
static_cast<MTL::Buffer*>(idx_shapes_buf.ptr()), 0, nidx + 1);
compute_encoder->useResource(
static_cast<MTL::Buffer*>(idx_shapes_buf.ptr()), MTL::ResourceUsageRead);
arg_enc->setBuffer(
static_cast<MTL::Buffer*>(idx_strides_buf.ptr()), 0, nidx + 2);
compute_encoder->useResource(
static_cast<MTL::Buffer*>(idx_strides_buf.ptr()), MTL::ResourceUsageRead);
*static_cast<int*>(arg_enc->constantData(nidx + 3)) = idx_ndim;
// Set all the buffers
set_array_buffer(compute_encoder, src, 0);
compute_encoder->setBuffer(static_cast<MTL::Buffer*>(arg_buf.ptr()), 0, 1);
set_array_buffer(compute_encoder, out, 2);
compute_encoder->setBytes(src.shape().data(), ndim * sizeof(int), 3);
compute_encoder->setBytes(src.strides().data(), ndim * sizeof(size_t), 4);
compute_encoder->setBytes(&ndim, sizeof(size_t), 5);
compute_encoder->setBytes(slice_sizes_.data(), ndim * sizeof(int), 6);
compute_encoder->setBytes(&slice_size, sizeof(size_t), 7);
compute_encoder->setBytes(axes_.data(), nidx * sizeof(int), 8);
compute_encoder->dispatchThreads(grid_dims, group_dims);
// Cleanup temporaries
arg_enc->release();
d.get_command_buffer(s.index)->addCompletedHandler(
[arg_buf, idx_shapes_buf, idx_strides_buf](MTL::CommandBuffer*) {
allocator::free(arg_buf);
allocator::free(idx_shapes_buf);
allocator::free(idx_strides_buf);
});
}
void Scatter::eval_gpu(const std::vector<array>& inputs, array& out) {
if (size_of(out.dtype()) == 8) {
std::ostringstream msg;
msg << "[Scatter::eval_gpu] Does not support " << out.dtype();
throw std::invalid_argument(msg.str());
}
int nidx = axes_.size();
if (nidx > METAL_MAX_INDEX_ARRAYS) {
std::ostringstream msg;
msg << "[Scatter::eval_gpu] Gathering with more than "
<< METAL_MAX_INDEX_ARRAYS << " index arrays not yet supported.";
throw std::runtime_error(msg.str());
}
// Copy src into out
auto copy_type =
inputs[0].data_size() == 1 ? CopyType::Scalar : CopyType::General;
copy_gpu(inputs[0], out, copy_type);
// Get stream
auto& s = stream();
auto& d = metal::device(s.device);
// Get kernel name
std::ostringstream kname;
std::string idx_type_name = nidx ? type_to_name(inputs[1]) : "";
kname << "scatter" << type_to_name(out) << idx_type_name;
switch (reduce_type_) {
case Scatter::None:
kname << "_none";
break;
case Scatter::Sum:
kname << "_sum";
break;
case Scatter::Prod:
kname << "_prod";
break;
case Scatter::Max:
kname << "_max";
break;
case Scatter::Min:
kname << "_min";
break;
}
kname << "_" << nidx;
auto compute_encoder = d.get_command_encoder(s.index);
auto kernel = d.get_kernel(kname.str());
auto& upd = inputs.back();
size_t nthreads = upd.size();
NS::UInteger thread_group_size = kernel->maxTotalThreadsPerThreadgroup();
if (thread_group_size > nthreads) {
thread_group_size = nthreads;
}
MTL::Size grid_dims = MTL::Size(nthreads, 1, 1);
MTL::Size group_dims = MTL::Size(thread_group_size, 1, 1);
compute_encoder->setComputePipelineState(kernel);
// Make the argument buffer to store the indices for the
// `Indices` struct in kernels/indexing.metal
std::vector<MTL::ArgumentDescriptor*> arg_descs(4);
arg_descs[0] = MTL::ArgumentDescriptor::argumentDescriptor();
arg_descs[0]->setIndex(0);
arg_descs[0]->setDataType(MTL::DataType::DataTypePointer);
arg_descs[0]->setArrayLength(nidx);
// Shapes
arg_descs[1] = MTL::ArgumentDescriptor::argumentDescriptor();
arg_descs[1]->setDataType(MTL::DataType::DataTypePointer);
arg_descs[1]->setIndex(nidx + 1);
// Strides
arg_descs[2] = MTL::ArgumentDescriptor::argumentDescriptor();
arg_descs[2]->setDataType(MTL::DataType::DataTypePointer);
arg_descs[2]->setIndex(nidx + 2);
// Indices ndim
arg_descs[3] = MTL::ArgumentDescriptor::argumentDescriptor();
arg_descs[3]->setDataType(MTL::DataType::DataTypeInt);
arg_descs[3]->setIndex(nidx + 3);
// Get the argument encoder
auto arg_enc = d.argument_encoder(arg_descs);
// Allocate and fill buffers for shapes and strides
int idx_ndim = nidx ? inputs[1].ndim() : 0;
auto idx_shapes_buf = allocator::malloc_or_wait(sizeof(int) * idx_ndim * nidx);
auto idx_strides_buf = allocator::malloc_or_wait(sizeof(size_t) * idx_ndim * nidx);
for (int i = 0; i < nidx; ++i) {
std::copy(
inputs[i + 1].shape().begin(),
inputs[i + 1].shape().end(),
static_cast<int*>(idx_shapes_buf.raw_ptr()) + i * idx_ndim);
std::copy(
inputs[i + 1].strides().begin(),
inputs[i + 1].strides().end(),
static_cast<size_t*>(idx_strides_buf.raw_ptr()) + i * idx_ndim);
}
// Allocate the argument buffer
auto arg_buf = allocator::malloc_or_wait(arg_enc->encodedLength());
// Register data with the encoder
arg_enc->setArgumentBuffer(static_cast<MTL::Buffer*>(arg_buf.ptr()), 0);
for (int i = 0; i < nidx; ++i) {
set_array_buffer(compute_encoder, arg_enc, inputs[i + 1], i);
}
arg_enc->setBuffer(
static_cast<MTL::Buffer*>(idx_shapes_buf.ptr()), 0, nidx + 1);
compute_encoder->useResource(
static_cast<MTL::Buffer*>(idx_shapes_buf.ptr()), MTL::ResourceUsageRead);
arg_enc->setBuffer(
static_cast<MTL::Buffer*>(idx_strides_buf.ptr()), 0, nidx + 2);
compute_encoder->useResource(
static_cast<MTL::Buffer*>(idx_strides_buf.ptr()), MTL::ResourceUsageRead);
*static_cast<int*>(arg_enc->constantData(nidx + 3)) = idx_ndim;
compute_encoder->setBuffer(static_cast<MTL::Buffer*>(arg_buf.ptr()), 0, 0);
size_t upd_ndim = upd.ndim();
size_t upd_size = 1;
for (int i = idx_ndim; i < upd.ndim(); ++i) {
upd_size *= upd.shape(i);
}
set_array_buffer(compute_encoder, upd, 1);
set_array_buffer(compute_encoder, out, 2);
compute_encoder->setBytes(upd.shape().data(), upd_ndim * sizeof(int), 3);
compute_encoder->setBytes(upd.strides().data(), upd_ndim * sizeof(size_t), 4);
compute_encoder->setBytes(&upd_ndim, sizeof(size_t), 5);
compute_encoder->setBytes(&upd_size, sizeof(size_t), 6);
size_t out_ndim = out.ndim();
compute_encoder->setBytes(out.shape().data(), out_ndim * sizeof(int), 7);
compute_encoder->setBytes(out.strides().data(), out_ndim * sizeof(size_t), 8);
compute_encoder->setBytes(&out_ndim, sizeof(size_t), 9);
compute_encoder->setBytes(axes_.data(), axes_.size() * sizeof(int), 10);
compute_encoder->dispatchThreads(grid_dims, group_dims);
// Cleanup temporaries
arg_enc->release();
d.get_command_buffer(s.index)->addCompletedHandler(
[arg_buf, idx_shapes_buf, idx_strides_buf](MTL::CommandBuffer*) {
allocator::free(arg_buf);
allocator::free(idx_shapes_buf);
allocator::free(idx_strides_buf);
});
}
} // namespace mlx::core

View File

@@ -0,0 +1,320 @@
#pragma once
#include <metal_atomic>
#include <metal_stdlib>
#include "mlx/backend/metal/kernels/bf16.h"
using namespace metal;
///////////////////////////////////////////////////////////////////////////////
// Atomic utils
///////////////////////////////////////////////////////////////////////////////
#pragma METAL internals : enable
template <typename T>
constexpr constant bool is_metal_atomic = _disjunction<
is_same<T, int>,
is_same<T, uint>,
is_same<T, ulong>,
is_same<T, float>>::value;
#pragma METAL internals : disable
template <typename T, typename = void>
struct mlx_atomic {
atomic<uint> val;
};
template <typename T>
struct mlx_atomic<T, enable_if_t<is_metal_atomic<T>>> {
atomic<T> val;
};
///////////////////////////////////////////////////////////////////////////////
// Native metal atomics
///////////////////////////////////////////////////////////////////////////////
template <typename T, enable_if_t<is_metal_atomic<T>, bool> = true>
METAL_FUNC T
mlx_atomic_load_explicit(device mlx_atomic<T>* object, int offset) {
return atomic_load_explicit(&(object[offset].val), memory_order_relaxed);
}
template <typename T, enable_if_t<is_metal_atomic<T>, bool> = true>
METAL_FUNC void
mlx_atomic_store_explicit(device mlx_atomic<T>* object, T val, int offset) {
atomic_store_explicit(&(object[offset].val), val, memory_order_relaxed);
}
template <typename T, enable_if_t<is_metal_atomic<T>, bool> = true>
METAL_FUNC void
mlx_atomic_fetch_and_explicit(device mlx_atomic<T>* object, T val, int offset) {
atomic_fetch_and_explicit(&(object[offset].val), val, memory_order_relaxed);
}
template <typename T, enable_if_t<is_metal_atomic<T>, bool> = true>
METAL_FUNC void
mlx_atomic_fetch_or_explicit(device mlx_atomic<T>* object, T val, int offset) {
atomic_fetch_or_explicit(&(object[offset].val), val, memory_order_relaxed);
}
template <typename T, enable_if_t<is_metal_atomic<T>, bool> = true>
METAL_FUNC void
mlx_atomic_fetch_min_explicit(device mlx_atomic<T>* object, T val, int offset) {
atomic_fetch_min_explicit(&(object[offset].val), val, memory_order_relaxed);
}
template <typename T, enable_if_t<is_metal_atomic<T>, bool> = true>
METAL_FUNC void
mlx_atomic_fetch_max_explicit(device mlx_atomic<T>* object, T val, int offset) {
atomic_fetch_max_explicit(&(object[offset].val), val, memory_order_relaxed);
}
template <typename T, enable_if_t<is_metal_atomic<T>, bool> = true>
METAL_FUNC void
mlx_atomic_fetch_add_explicit(device mlx_atomic<T>* object, T val, int offset) {
atomic_fetch_add_explicit(&(object[offset].val), val, memory_order_relaxed);
}
template <typename T, enable_if_t<is_metal_atomic<T>, bool> = true>
METAL_FUNC void
mlx_atomic_fetch_mul_explicit(device mlx_atomic<T>* object, T val, int offset) {
T expected = mlx_atomic_load_explicit(object, offset);
while (!mlx_atomic_compare_exchange_weak_explicit(
object, &expected, val * expected, offset)) {
}
}
template <typename T, enable_if_t<is_metal_atomic<T>, bool> = true>
METAL_FUNC bool mlx_atomic_compare_exchange_weak_explicit(
device mlx_atomic<T>* object,
thread T* expected,
T val,
int offset) {
return atomic_compare_exchange_weak_explicit(
&(object[offset].val),
expected,
val,
memory_order_relaxed,
memory_order_relaxed);
}
// Specialization for float since it does not support atomic_fetch_min_explicit
template <>
METAL_FUNC void mlx_atomic_fetch_min_explicit<float>(
device mlx_atomic<float>* object,
float val,
int offset) {
float expected = mlx_atomic_load_explicit(object, offset);
while (val < expected) {
if (mlx_atomic_compare_exchange_weak_explicit(
object, &expected, val, offset)) {
return;
}
}
}
// Specialization for float since it does not support atomic_fetch_max_explicit
template <>
METAL_FUNC void mlx_atomic_fetch_max_explicit<float>(
device mlx_atomic<float>* object,
float val,
int offset) {
float expected = mlx_atomic_load_explicit(object, offset);
while (val > expected) {
if (mlx_atomic_compare_exchange_weak_explicit(
object, &expected, val, offset)) {
return;
}
}
}
///////////////////////////////////////////////////////////////////////////////
// Custom atomics
///////////////////////////////////////////////////////////////////////////////
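// Sub-32-bit types (e.g. uint8, int16, float16, bfloat16) have no native Metal
// atomics, so several elements are packed into one atomic<uint>. An update
// loads the packed word, rewrites a single lane through `uint_or_packed`, and
// commits it with a weak compare-exchange, retrying while `Op::condition`
// still holds for the freshly observed value.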
namespace {
template <typename T>
constexpr constant uint packing_size = sizeof(uint) / sizeof(T);
template <typename T>
union uint_or_packed {
T val[packing_size<T>];
uint bits;
};
template <typename T, typename Op>
struct mlx_atomic_update_helper {
uint operator()(uint_or_packed<T> init, T update, int elem_offset) {
Op op;
init.val[elem_offset] = op(update, init.val[elem_offset]);
return init.bits;
}
};
template <typename T, typename Op>
METAL_FUNC void mlx_atomic_update_and_store(
device mlx_atomic<T>* object,
T update,
int offset) {
int pack_offset = offset / packing_size<T>;
int elem_offset = offset % packing_size<T>;
mlx_atomic_update_helper<T, Op> helper;
uint_or_packed<T> expected;
expected.bits =
atomic_load_explicit(&(object[pack_offset].val), memory_order_relaxed);
while (Op::condition(update, expected.val[elem_offset]) &&
!mlx_atomic_compare_exchange_weak_explicit(
object,
&(expected.bits),
helper(expected, update, elem_offset),
pack_offset)) {
}
}
template <typename T>
struct __None {
static bool condition(T a, T b) {
#pragma unused(a)
#pragma unused(b)
return true;
}
T operator()(T a, T b) {
#pragma unused(b)
return a;
}
};
template <typename T>
struct __Add {
static bool condition(T a, T b) {
#pragma unused(a)
#pragma unused(b)
return true;
}
T operator()(T a, T b) {
return a + b;
}
};
template <typename T>
struct __Mul {
static bool condition(T a, T b) {
#pragma unused(a)
return b != 0;
}
T operator()(T a, T b) {
return a * b;
}
};
template <typename T>
struct __Max {
static bool condition(T a, T b) {
return a > b;
}
T operator()(T a, T b) {
return max(a, b);
}
};
template <typename T>
struct __Min {
static bool condition(T a, T b) {
return a < b;
}
T operator()(T a, T b) {
return min(a, b);
}
};
} // namespace
template <typename T, enable_if_t<!is_metal_atomic<T>, bool> = true>
METAL_FUNC T
mlx_atomic_load_explicit(device mlx_atomic<T>* object, int offset) {
int pack_offset = offset / packing_size<T>;
int elem_offset = offset % packing_size<T>;
uint_or_packed<T> packed_val;
packed_val.bits =
atomic_load_explicit(&(object[pack_offset].val), memory_order_relaxed);
return packed_val.val[elem_offset];
}
template <typename T, enable_if_t<!is_metal_atomic<T>, bool> = true>
METAL_FUNC void
mlx_atomic_store_explicit(device mlx_atomic<T>* object, T val, int offset) {
mlx_atomic_update_and_store<T, __None<T>>(object, val, offset);
}
template <typename T, enable_if_t<!is_metal_atomic<T>, bool> = true>
METAL_FUNC void
mlx_atomic_fetch_and_explicit(device mlx_atomic<T>* object, T val, int offset) {
int pack_offset = offset / packing_size<T>;
int elem_offset = offset % packing_size<T>;
uint_or_packed<T> identity;
identity.bits = __UINT32_MAX__;
identity.val[elem_offset] = val;
atomic_fetch_and_explicit(
&(object[pack_offset].val), identity.bits, memory_order_relaxed);
}
template <typename T, enable_if_t<!is_metal_atomic<T>, bool> = true>
METAL_FUNC void
mlx_atomic_fetch_or_explicit(device mlx_atomic<T>* object, T val, int offset) {
int pack_offset = offset / packing_size<T>;
int elem_offset = offset % packing_size<T>;
uint_or_packed<T> identity;
identity.bits = 0;
identity.val[elem_offset] = val;
atomic_fetch_or_explicit(
&(object[pack_offset].val), identity.bits, memory_order_relaxed);
}
template <typename T, enable_if_t<!is_metal_atomic<T>, bool> = true>
METAL_FUNC void
mlx_atomic_fetch_min_explicit(device mlx_atomic<T>* object, T val, int offset) {
mlx_atomic_update_and_store<T, __Min<T>>(object, val, offset);
}
template <typename T, enable_if_t<!is_metal_atomic<T>, bool> = true>
METAL_FUNC void
mlx_atomic_fetch_max_explicit(device mlx_atomic<T>* object, T val, int offset) {
mlx_atomic_update_and_store<T, __Max<T>>(object, val, offset);
}
template <typename T, enable_if_t<!is_metal_atomic<T>, bool> = true>
METAL_FUNC void
mlx_atomic_fetch_add_explicit(device mlx_atomic<T>* object, T val, int offset) {
mlx_atomic_update_and_store<T, __Add<T>>(object, val, offset);
}
template <typename T, enable_if_t<!is_metal_atomic<T>, bool> = true>
METAL_FUNC void
mlx_atomic_fetch_mul_explicit(device mlx_atomic<T>* object, T val, int offset) {
mlx_atomic_update_and_store<T, __Mul<T>>(object, val, offset);
}
template <typename T, enable_if_t<!is_metal_atomic<T>, bool> = true>
METAL_FUNC bool mlx_atomic_compare_exchange_weak_explicit(
device mlx_atomic<T>* object,
thread uint* expected,
uint val,
int offset) {
return atomic_compare_exchange_weak_explicit(
&(object[offset].val),
expected,
val,
memory_order_relaxed,
memory_order_relaxed);
}

View File

@@ -0,0 +1,315 @@
#pragma once
#include <metal_stdlib>
using namespace metal;
#if defined(__HAVE_BFLOAT__)
typedef bfloat bfloat16_t;
#else
/////////////////////////////////////////////////////////////////////////////
// Helpers
/////////////////////////////////////////////////////////////////////////////
constexpr METAL_FUNC uint16_t float_to_bfloat_bits(float x) {
// Check for nan
if ((as_type<uint32_t>(x) & ~_fp_encoding_traits<float>::sign_mask) >
_fp_encoding_traits<float>::inf_mask) {
return uint16_t(as_type<uint32_t>(0x7FC0));
}
// Take bits
uint32_t float_bits = as_type<uint32_t>(x);
// Round to nearest even
float_bits += ((float_bits >> 16) & 1) + as_type<uint32_t>(0x7FFF);
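// e.g. 1.0f = 0x3F800000: bit 16 is 0, so only 0x7FFF is added and the upper
// 16 bits stay 0x3F80, which is 1.0 in bfloat16.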
// Take upper 16 bits
return float_bits >> 16;
}
constexpr METAL_FUNC float bfloat_bits_to_float(uint16_t x) {
// Upper 16 bits are the data and lower 16 bits are 0s
return as_type<float>((uint32_t)x << 16);
}
struct _MLX_BFloat16;
template <typename T>
static constexpr constant bool can_convert_to_bfloat =
!is_same_v<T, _MLX_BFloat16> && is_convertible_v<T, float>;
template <typename T>
static constexpr constant bool can_convert_from_bfloat =
!is_same_v<T, _MLX_BFloat16> && is_convertible_v<float, T>;
/////////////////////////////////////////////////////////////////////////////
// Bfloat struct
/////////////////////////////////////////////////////////////////////////////
struct _MLX_BFloat16 {
/////////////////////////////////////////////////////////////////////////////
// Constructors
uint16_t bits_;
_MLX_BFloat16() thread = default;
_MLX_BFloat16() threadgroup = default;
_MLX_BFloat16() device = default;
_MLX_BFloat16() constant = default;
struct bits_to_bfloat_struct {};
static constexpr METAL_FUNC bits_to_bfloat_struct bits_to_bfloat() {
return bits_to_bfloat_struct();
}
constexpr METAL_FUNC _MLX_BFloat16(uint16_t bits, bits_to_bfloat_struct)
: bits_(bits) {}
/////////////////////////////////////////////////////////////////////////////
// Conversions to bfloat
template <
typename T,
typename = typename enable_if<can_convert_to_bfloat<T>>::type>
constexpr METAL_FUNC _MLX_BFloat16(T x) thread
: bits_(float_to_bfloat_bits(static_cast<float>(x))) {}
template <
typename T,
typename = typename enable_if<can_convert_to_bfloat<T>>::type>
constexpr METAL_FUNC _MLX_BFloat16(T x) threadgroup
: bits_(float_to_bfloat_bits(static_cast<float>(x))) {}
template <
typename T,
typename = typename enable_if<can_convert_to_bfloat<T>>::type>
constexpr METAL_FUNC _MLX_BFloat16(T x) device
: bits_(float_to_bfloat_bits(static_cast<float>(x))) {}
template <
typename T,
typename = typename enable_if<can_convert_to_bfloat<T>>::type>
constexpr METAL_FUNC _MLX_BFloat16(T x) constant
: bits_(float_to_bfloat_bits(static_cast<float>(x))) {}
/////////////////////////////////////////////////////////////////////////////
// Conversions from bfloat
template <
typename T,
typename = typename enable_if<can_convert_from_bfloat<T>>::type>
constexpr METAL_FUNC operator T() const thread {
return static_cast<T>(bfloat_bits_to_float(bits_));
}
template <
typename T,
typename = typename enable_if<can_convert_from_bfloat<T>>::type>
constexpr METAL_FUNC operator T() const threadgroup {
return static_cast<T>(bfloat_bits_to_float(bits_));
}
template <
typename T,
typename = typename enable_if<can_convert_from_bfloat<T>>::type>
constexpr METAL_FUNC operator T() const device {
return static_cast<T>(bfloat_bits_to_float(bits_));
}
template <
typename T,
typename = typename enable_if<can_convert_from_bfloat<T>>::type>
constexpr METAL_FUNC operator T() const constant {
return static_cast<T>(bfloat_bits_to_float(bits_));
}
};
/////////////////////////////////////////////////////////////////////////////
// Bfloat operators
/////////////////////////////////////////////////////////////////////////////
/////////////////////////////////////////////////////////////////////////////
// Unary ops
constexpr METAL_FUNC _MLX_BFloat16 operator-(_MLX_BFloat16 x) {
return -static_cast<float>(x);
}
/////////////////////////////////////////////////////////////////////////////
// Binary operators
#define bfloat_binop_base(__op__, __operator__, otype, atype, btype, ctype) \
constexpr METAL_FUNC otype __operator__(atype lhs, btype rhs) { \
return static_cast<ctype>(lhs) __op__ static_cast<ctype>(rhs); \
}
#define bfloat_binop_helper(__op__, __operator__, otype, itype, ctype) \
constexpr METAL_FUNC otype __operator__(_MLX_BFloat16 lhs, itype rhs) { \
return static_cast<ctype>(lhs) __op__ static_cast<ctype>(rhs); \
} \
constexpr METAL_FUNC otype __operator__(itype lhs, _MLX_BFloat16 rhs) { \
return static_cast<ctype>(lhs) __op__ static_cast<ctype>(rhs); \
}
/////////////////////////////////////////////////////////////////////////////
// Arithmetic Operators
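// All arithmetic is computed in float. Mixing bfloat16 with an integer type
// returns bfloat16, while mixing it with float or half returns float.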
#define bfloat_binop(_op_, _operator_) \
bfloat_binop_base( \
_op_, _operator_, _MLX_BFloat16, _MLX_BFloat16, _MLX_BFloat16, float); \
bfloat_binop_helper(_op_, _operator_, float, float, float); \
bfloat_binop_helper(_op_, _operator_, float, half, float); \
bfloat_binop_helper(_op_, _operator_, _MLX_BFloat16, int32_t, float); \
bfloat_binop_helper(_op_, _operator_, _MLX_BFloat16, uint32_t, float); \
bfloat_binop_helper(_op_, _operator_, _MLX_BFloat16, int64_t, float); \
bfloat_binop_helper(_op_, _operator_, _MLX_BFloat16, uint64_t, float);
bfloat_binop(+, operator+);
bfloat_binop(-, operator-);
bfloat_binop(*, operator*);
bfloat_binop(/, operator/);
/////////////////////////////////////////////////////////////////////////////
// Comparison ops
#define bfloat_compop(__op__, __operator__) \
bfloat_binop_base( \
__op__, __operator__, bool, _MLX_BFloat16, _MLX_BFloat16, float); \
bfloat_binop_helper(__op__, __operator__, bool, float, float); \
bfloat_binop_helper(__op__, __operator__, bool, half, float); \
bfloat_binop_helper(__op__, __operator__, bool, int32_t, float); \
bfloat_binop_helper(__op__, __operator__, bool, uint32_t, float); \
bfloat_binop_helper(__op__, __operator__, bool, int64_t, float); \
bfloat_binop_helper(__op__, __operator__, bool, uint64_t, float);
bfloat_compop(>, operator>);
bfloat_compop(<, operator<);
bfloat_compop(>=, operator>=);
bfloat_compop(<=, operator<=);
bfloat_compop(==, operator==);
bfloat_compop(!=, operator!=);
#undef bfloat_compop
#undef bfloat_binop_base
#undef bfloat_binop_helper
#undef bfloat_binop
/////////////////////////////////////////////////////////////////////////////
// Inplace Operators
#define bfloat_inplace_op_helper(__op__, __operator__, itype, addr_space) \
constexpr METAL_FUNC addr_space _MLX_BFloat16& __operator__( \
addr_space _MLX_BFloat16& lhs, itype rhs) { \
lhs = static_cast<float>(lhs) __op__ static_cast<float>(rhs); \
return lhs; \
} \
constexpr METAL_FUNC addr_space itype& __operator__( \
addr_space itype& lhs, _MLX_BFloat16 rhs) { \
lhs = static_cast<float>(lhs) __op__ static_cast<float>(rhs); \
return lhs; \
}
#define bfloat_inplace_op_addr_space_helper(__op__, __operator__, itype) \
bfloat_inplace_op_helper(__op__, __operator__, itype, device); \
bfloat_inplace_op_helper(__op__, __operator__, itype, thread); \
bfloat_inplace_op_helper(__op__, __operator__, itype, threadgroup);
#define bfloat_inplace_op(itype) \
bfloat_inplace_op_addr_space_helper(+, operator+=, itype); \
bfloat_inplace_op_addr_space_helper(-, operator-=, itype); \
bfloat_inplace_op_addr_space_helper(*, operator*=, itype); \
bfloat_inplace_op_addr_space_helper(/, operator/=, itype);
bfloat_inplace_op(float);
bfloat_inplace_op(half);
bfloat_inplace_op(int16_t);
bfloat_inplace_op(int32_t);
bfloat_inplace_op(int64_t);
bfloat_inplace_op(uint16_t);
bfloat_inplace_op(uint32_t);
bfloat_inplace_op(uint64_t);
#undef bfloat_inplace_op_helper
#undef bfloat_inplace_op_addr_space_helper
#undef bfloat_inplace_op
#define bfloat_inplace_op_helper(__op__, __operator__, addr_space) \
constexpr METAL_FUNC addr_space _MLX_BFloat16& __operator__( \
addr_space _MLX_BFloat16& lhs, _MLX_BFloat16 rhs) { \
lhs = static_cast<float>(lhs) __op__ static_cast<float>(rhs); \
return lhs; \
}
#define bfloat_inplace_op_addr_space_helper(__op__, __operator__) \
bfloat_inplace_op_helper(__op__, __operator__, device); \
bfloat_inplace_op_helper(__op__, __operator__, thread); \
bfloat_inplace_op_helper(__op__, __operator__, threadgroup);
bfloat_inplace_op_addr_space_helper(+, operator+=);
bfloat_inplace_op_addr_space_helper(-, operator-=);
bfloat_inplace_op_addr_space_helper(*, operator*=);
bfloat_inplace_op_addr_space_helper(/, operator/=);
#undef bfloat_inplace_op_helper
#undef bfloat_inplace_op_addr_space_helper
/////////////////////////////////////////////////////////////////////////////
// Bfloat typedef
/////////////////////////////////////////////////////////////////////////////
typedef struct _MLX_BFloat16 bfloat16_t;
/////////////////////////////////////////////////////////////////////////////
// Bfloat numeric limits
/////////////////////////////////////////////////////////////////////////////
#pragma METAL internals : enable
namespace metal {
template <>
struct _numeric_limits_impl<bfloat16_t> : _fp_numeric_limits_impl_base {
static constexpr constant int digits = 8;
static constexpr constant int digits10 = 2;
static constexpr constant int max_digits10 = 4;
static constexpr constant int radix = 2;
static constexpr constant int min_exponent = -125;
static constexpr constant int min_exponent10 = -37;
static constexpr constant int max_exponent = 128;
static constexpr constant int max_exponent10 = 38;
static constexpr bfloat16_t min() {
return _MLX_BFloat16(0x0080, _MLX_BFloat16::bits_to_bfloat());
}
static constexpr bfloat16_t lowest() {
return _MLX_BFloat16(0xFF7F, _MLX_BFloat16::bits_to_bfloat());
}
static constexpr bfloat16_t max() {
return _MLX_BFloat16(0x7F7F, _MLX_BFloat16::bits_to_bfloat());
}
static constexpr bfloat16_t epsilon() {
return _MLX_BFloat16(0x3C00, _MLX_BFloat16::bits_to_bfloat());
}
static constexpr bfloat16_t round_error() {
return _MLX_BFloat16(0x3F00, _MLX_BFloat16::bits_to_bfloat());
}
static constexpr bfloat16_t infinity() {
return _MLX_BFloat16(0x7F80, _MLX_BFloat16::bits_to_bfloat());
}
static constexpr bfloat16_t quiet_NaN() {
return _MLX_BFloat16(0x7FC0, _MLX_BFloat16::bits_to_bfloat());
}
static constexpr bfloat16_t signaling_NaN() {
return _MLX_BFloat16(0x7F80, _MLX_BFloat16::bits_to_bfloat());
}
static constexpr bfloat16_t denorm_min() {
return _MLX_BFloat16(0x0001, _MLX_BFloat16::bits_to_bfloat());
}
};
METAL_FUNC bool isnan(_MLX_BFloat16 x) {
return x != x;
}
} // namespace metal
#pragma METAL internals : disable
#endif // defined(__HAVE_BFLOAT__)
#include "mlx/backend/metal/kernels/bf16_math.h"

View File

@@ -0,0 +1,369 @@
#include <metal_integer>
#include <metal_math>
#include "mlx/backend/metal/kernels/utils.h"
#include "mlx/backend/metal/kernels/bf16.h"
struct Add {
template <typename T> T operator()(T x, T y) { return x + y; }
};
struct Divide {
template <typename T> T operator()(T x, T y) { return x / y; }
};
struct Equal {
template <typename T> bool operator()(T x, T y) { return x == y; }
};
struct NaNEqual {
template <typename T> bool operator()(T x, T y) {
return x == y || (metal::isnan(x) && metal::isnan(y));
}
template <>
bool operator()(complex64_t x, complex64_t y) {
return x == y ||
(metal::isnan(x.real) && metal::isnan(y.real)
&& metal::isnan(x.imag) && metal::isnan(y.imag)) ||
(x.real == y.real && metal::isnan(x.imag) && metal::isnan(y.imag)) ||
(metal::isnan(x.real) && metal::isnan(y.real) && x.imag == y.imag);
}
};
struct Greater {
template <typename T> bool operator()(T x, T y) { return x > y; }
};
struct GreaterEqual {
template <typename T> bool operator()(T x, T y) { return x >= y; }
};
struct Less {
template <typename T> bool operator()(T x, T y) { return x < y; }
};
struct LessEqual {
template <typename T> bool operator()(T x, T y) { return x <= y; }
};
struct LogAddExp {
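// Numerically stable log(exp(x) + exp(y)): factor out the larger value and
// short-circuit the +/-infinity cases.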
template <typename T>
T operator()(T x, T y) {
constexpr T inf = metal::numeric_limits<T>::infinity();
T maxval = metal::max(x, y);
T minval = metal::min(x, y);
return (minval == -inf || maxval == inf) ? maxval :
(maxval + log1p(metal::exp(minval - maxval)));
};
};
struct Maximum {
template <typename T> T operator()(T x, T y) { return metal::max(x, y); }
template <>
complex64_t operator()(complex64_t x, complex64_t y) {
return x >= y ? x : y;
}
};
struct Minimum {
template <typename T> T operator()(T x, T y) { return metal::min(x, y); }
template <>
complex64_t operator()(complex64_t x, complex64_t y) {
return x <= y ? x : y;
}
};
struct Multiply {
template <typename T> T operator()(T x, T y) { return x * y; }
};
struct NotEqual {
template <typename T> bool operator()(T x, T y) { return x != y; }
template <>
bool operator()(complex64_t x, complex64_t y) {
return x.real != y.real || x.imag != y.imag;
}
};
struct Power {
template <typename T>
metal::enable_if_t<!metal::is_integral_v<T>, T> operator()(T base, T exp) {
return metal::pow(base, exp);
}
template <typename T>
metal::enable_if_t<metal::is_integral_v<T>, T> operator()(T base, T exp) {
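// Exponentiation by squaring for integer exponents.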
T res = 1;
while (exp) {
if (exp & 1) {
res *= base;
}
exp >>= 1;
base *= base;
}
return res;
}
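// The complex specialization computes x^y via the polar form exp(y * log x),
// with log x = ln|x| + i*theta; theta comes from atan(imag/real), i.e. the
// principal branch assuming a positive real part.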
template <>
complex64_t operator()(complex64_t x, complex64_t y) {
auto x_theta = metal::atan(x.imag / x.real);
auto x_ln_r = 0.5 * metal::log(x.real * x.real + x.imag * x.imag);
auto mag = metal::exp(y.real * x_ln_r - y.imag * x_theta);
auto phase = y.imag * x_ln_r + y.real * x_theta;
return {mag * metal::cos(phase), mag * metal::sin(phase)};
}
};
struct Subtract {
template <typename T> T operator()(T x, T y) { return x - y; }
};
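// Kernel naming: the suffix encodes how each operand is indexed.
// s = scalar (element 0 broadcast), v = vector (contiguous), g = general
// (arbitrary strides, with _nd1/_nd2/_nd3/_nd specializations by rank).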
template <typename T, typename U, typename Op>
[[kernel]] void binary_op_s2s(
device const T* a,
device const T* b,
device U* c,
uint index [[thread_position_in_grid]]) {
c[index] = Op()(a[0], b[0]);
}
template <typename T, typename U, typename Op>
[[kernel]] void binary_op_ss(
device const T* a,
device const T* b,
device U* c,
uint index [[thread_position_in_grid]]) {
c[index] = Op()(a[0], b[0]);
}
template <typename T, typename U, typename Op>
[[kernel]] void binary_op_sv(
device const T* a,
device const T* b,
device U* c,
uint index [[thread_position_in_grid]]) {
c[index] = Op()(a[0], b[index]);
}
template <typename T, typename U, typename Op>
[[kernel]] void binary_op_vs(
device const T* a,
device const T* b,
device U* c,
uint index [[thread_position_in_grid]]) {
c[index] = Op()(a[index], b[0]);
}
template <typename T, typename U, typename Op>
[[kernel]] void binary_op_vv(
device const T* a,
device const T* b,
device U* c,
uint index [[thread_position_in_grid]]) {
c[index] = Op()(a[index], b[index]);
}
template <typename T, typename U, typename Op>
[[kernel]] void binary_op_g_nd1(
device const T* a,
device const T* b,
device U* c,
constant const size_t& a_stride,
constant const size_t& b_stride,
uint index [[thread_position_in_grid]]) {
auto a_idx = elem_to_loc_1(index, a_stride);
auto b_idx = elem_to_loc_1(index, b_stride);
c[index] = Op()(a[a_idx], b[b_idx]);
}
template <typename T, typename U, typename Op>
[[kernel]] void binary_op_g_nd2(
device const T* a,
device const T* b,
device U* c,
constant const size_t a_strides[2],
constant const size_t b_strides[2],
uint2 index [[thread_position_in_grid]],
uint2 grid_dim [[threads_per_grid]]) {
auto a_idx = elem_to_loc_2(index, a_strides);
auto b_idx = elem_to_loc_2(index, b_strides);
size_t out_idx = index.x + (size_t)grid_dim.x * index.y;
c[out_idx] = Op()(a[a_idx], b[b_idx]);
}
template <typename T, typename U, typename Op>
[[kernel]] void binary_op_g_nd3(
device const T* a,
device const T* b,
device U* c,
constant const size_t a_strides[3],
constant const size_t b_strides[3],
uint3 index [[thread_position_in_grid]],
uint3 grid_dim [[threads_per_grid]]) {
auto a_idx = elem_to_loc_3(index, a_strides);
auto b_idx = elem_to_loc_3(index, b_strides);
size_t out_idx = index.x + (size_t)grid_dim.x * (index.y + (size_t)grid_dim.y * index.z);
c[out_idx] = Op()(a[a_idx], b[b_idx]);
}
template <typename T, typename U, typename Op, int DIM>
[[kernel]] void binary_op_g_nd(
device const T* a,
device const T* b,
device U* c,
constant const int shape[DIM],
constant const size_t a_strides[DIM],
constant const size_t b_strides[DIM],
uint3 index [[thread_position_in_grid]],
uint3 grid_dim [[threads_per_grid]]) {
auto idx = elem_to_loc_2_nd<DIM>(index, shape, a_strides, b_strides);
size_t out_idx = index.x + (size_t)grid_dim.x * (index.y + (size_t)grid_dim.y * index.z);
c[out_idx] = Op()(a[idx.x], b[idx.y]);
}
template <typename T, typename U, typename Op>
[[kernel]] void binary_op_g(
device const T* a,
device const T* b,
device U* c,
constant const int* shape,
constant const size_t* a_strides,
constant const size_t* b_strides,
constant const int& ndim,
uint3 index [[thread_position_in_grid]],
uint3 grid_dim [[threads_per_grid]]) {
auto idx = elem_to_loc_2_nd(index, shape, a_strides, b_strides, ndim);
size_t out_idx = index.x + (size_t)grid_dim.x * (index.y + (size_t)grid_dim.y * index.z);
c[out_idx] = Op()(a[idx.x], b[idx.y]);
}
#define instantiate_binary(name, itype, otype, op, bopt) \
template [[host_name(name)]] \
[[kernel]] void binary_op_##bopt<itype, otype, op>( \
device const itype* a, \
device const itype* b, \
device otype* c, \
uint index [[thread_position_in_grid]]);
#define instantiate_binary_g_dim(name, itype, otype, op, dims) \
template [[host_name(name "_" #dims)]] \
[[kernel]] void binary_op_g_nd<itype, otype, op, dims>( \
device const itype* a, \
device const itype* b, \
device otype* c, \
constant const int shape[dims], \
constant const size_t a_strides[dims], \
constant const size_t b_strides[dims], \
uint3 index [[thread_position_in_grid]], \
uint3 grid_dim [[threads_per_grid]]);
#define instantiate_binary_g_nd(name, itype, otype, op) \
template [[host_name(name "_1")]] \
[[kernel]] void binary_op_g_nd1<itype, otype, op>( \
device const itype* a, \
device const itype* b, \
device otype* c, \
constant const size_t& a_stride, \
constant const size_t& b_stride, \
uint index [[thread_position_in_grid]]); \
template [[host_name(name "_2")]] \
[[kernel]] void binary_op_g_nd2<itype, otype, op>( \
device const itype* a, \
device const itype* b, \
device otype* c, \
constant const size_t a_strides[2], \
constant const size_t b_strides[2], \
uint2 index [[thread_position_in_grid]], \
uint2 grid_dim [[threads_per_grid]]); \
template [[host_name(name "_3")]] \
[[kernel]] void binary_op_g_nd3<itype, otype, op>( \
device const itype* a, \
device const itype* b, \
device otype* c, \
constant const size_t a_strides[3], \
constant const size_t b_strides[3], \
uint3 index [[thread_position_in_grid]], \
uint3 grid_dim [[threads_per_grid]]); \
instantiate_binary_g_dim(name, itype, otype, op, 4) \
instantiate_binary_g_dim(name, itype, otype, op, 5)
#define instantiate_binary_g(name, itype, otype, op) \
template [[host_name(name)]] \
[[kernel]] void binary_op_g<itype, otype, op>( \
device const itype* a, \
device const itype* b, \
device otype* c, \
constant const int* shape, \
constant const size_t* a_strides, \
constant const size_t* b_strides, \
constant const int& ndim, \
uint3 index [[thread_position_in_grid]], \
uint3 grid_dim [[threads_per_grid]]);
#define instantiate_binary_all(name, tname, itype, otype, op) \
instantiate_binary("ss" #name #tname, itype, otype, op, ss) \
instantiate_binary("sv" #name #tname, itype, otype, op, sv) \
instantiate_binary("vs" #name #tname, itype, otype, op, vs) \
instantiate_binary("vv" #name #tname, itype, otype, op, vv) \
instantiate_binary_g("g" #name #tname, itype, otype, op) \
instantiate_binary_g_nd("g" #name #tname, itype, otype, op)
#define instantiate_binary_float(name, op) \
instantiate_binary_all(name, float16, half, half, op) \
instantiate_binary_all(name, float32, float, float, op) \
instantiate_binary_all(name, bfloat16, bfloat16_t, bfloat16_t, op)
#define instantiate_binary_types(name, op) \
instantiate_binary_all(name, bool_, bool, bool, op) \
instantiate_binary_all(name, uint8, uint8_t, uint8_t, op) \
instantiate_binary_all(name, uint16, uint16_t, uint16_t, op) \
instantiate_binary_all(name, uint32, uint32_t, uint32_t, op) \
instantiate_binary_all(name, uint64, uint64_t, uint64_t, op) \
instantiate_binary_all(name, int8, int8_t, int8_t, op) \
instantiate_binary_all(name, int16, int16_t, int16_t, op) \
instantiate_binary_all(name, int32, int32_t, int32_t, op) \
instantiate_binary_all(name, int64, int64_t, int64_t, op) \
instantiate_binary_all(name, complex64, complex64_t, complex64_t, op) \
instantiate_binary_float(name, op)
#define instantiate_binary_types_bool(name, op) \
instantiate_binary_all(name, bool_, bool, bool, op) \
instantiate_binary_all(name, uint8, uint8_t, bool, op) \
instantiate_binary_all(name, uint16, uint16_t, bool, op) \
instantiate_binary_all(name, uint32, uint32_t, bool, op) \
instantiate_binary_all(name, uint64, uint64_t, bool, op) \
instantiate_binary_all(name, int8, int8_t, bool, op) \
instantiate_binary_all(name, int16, int16_t, bool, op) \
instantiate_binary_all(name, int32, int32_t, bool, op) \
instantiate_binary_all(name, int64, int64_t, bool, op) \
instantiate_binary_all(name, float16, half, bool, op) \
instantiate_binary_all(name, float32, float, bool, op) \
instantiate_binary_all(name, bfloat16, bfloat16_t, bool, op) \
instantiate_binary_all(name, complex64, complex64_t, bool, op)
instantiate_binary_types(add, Add)
instantiate_binary_float(div, Divide)
instantiate_binary_types_bool(eq, Equal)
instantiate_binary_types_bool(ge, Greater)
instantiate_binary_types_bool(geq, GreaterEqual)
instantiate_binary_types_bool(le, Less)
instantiate_binary_types_bool(leq, LessEqual)
instantiate_binary_types_bool(neq, NotEqual)
instantiate_binary_float(lae, LogAddExp)
instantiate_binary_types(max, Maximum)
instantiate_binary_types(min, Minimum)
instantiate_binary_types(mul, Multiply)
instantiate_binary_types(sub, Subtract)
instantiate_binary_types(pow, Power)
// NaNEqual only needed for floating point types with boolean output
instantiate_binary_all(naneq, float16, half, bool, NaNEqual)
instantiate_binary_all(naneq, float32, float, bool, NaNEqual)
instantiate_binary_all(naneq, bfloat16, bfloat16_t, bool, NaNEqual)
instantiate_binary_all(naneq, complex64, complex64_t, bool, NaNEqual)
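// Example (illustrative): a single line such as
//   instantiate_binary_all(add, float32, float, float, Add)
// expands to one host-visible kernel per layout specialization, i.e.
// "ssaddfloat32", "svaddfloat32", "vsaddfloat32", "vvaddfloat32", the general
// strided "gaddfloat32", and the dimension-specialized variants, each
// instantiating the corresponding binary_op_* template with itype = float,
// otype = float, op = Add.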

View File

@@ -0,0 +1,110 @@
#pragma once
#include <metal_stdlib>
using namespace metal;
struct complex64_t;
template <typename T>
static constexpr constant bool can_convert_to_complex64 =
!is_same_v<T, complex64_t> && is_convertible_v<T, float>;
template <typename T>
static constexpr constant bool can_convert_from_complex64 =
!is_same_v<T, complex64_t> &&
(is_convertible_v<float, T> || is_convertible_v<bfloat16_t, T>);
struct complex64_t {
float real;
float imag;
// Constructors
constexpr complex64_t(float real, float imag) : real(real), imag(imag){};
// Conversions to complex64_t
template <
typename T,
typename = typename enable_if<can_convert_to_complex64<T>>::type>
constexpr complex64_t(T x) thread : real(x), imag(0) {}
template <
typename T,
typename = typename enable_if<can_convert_to_complex64<T>>::type>
constexpr complex64_t(T x) threadgroup : real(x), imag(0) {}
template <
typename T,
typename = typename enable_if<can_convert_to_complex64<T>>::type>
constexpr complex64_t(T x) device : real(x), imag(0) {}
template <
typename T,
typename = typename enable_if<can_convert_to_complex64<T>>::type>
constexpr complex64_t(T x) constant : real(x), imag(0) {}
// Conversions from complex64_t
template <
typename T,
typename = typename enable_if<can_convert_from_complex64<T>>::type>
constexpr operator T() const thread {
return static_cast<T>(real);
}
template <
typename T,
typename = typename enable_if<can_convert_from_complex64<T>>::type>
constexpr operator T() const threadgroup {
return static_cast<T>(real);
}
template <
typename T,
typename = typename enable_if<can_convert_from_complex64<T>>::type>
constexpr operator T() const device {
return static_cast<T>(real);
}
template <
typename T,
typename = typename enable_if<can_convert_from_complex64<T>>::type>
constexpr operator T() const constant {
return static_cast<T>(real);
}
};
constexpr complex64_t operator-(complex64_t x) {
return {-x.real, -x.imag};
}
constexpr bool operator>=(complex64_t a, complex64_t b) {
return (a.real > b.real) || (a.real == b.real && a.imag >= b.imag);
}
constexpr bool operator>(complex64_t a, complex64_t b) {
return (a.real > b.real) || (a.real == b.real && a.imag > b.imag);
}
constexpr bool operator<=(complex64_t a, complex64_t b) {
return operator>=(b, a);
}
constexpr bool operator<(complex64_t a, complex64_t b) {
return operator>(b, a);
}
constexpr bool operator==(complex64_t a, complex64_t b) {
return a.real == b.real && a.imag == b.imag;
}
constexpr complex64_t operator+(complex64_t a, complex64_t b) {
return {a.real + b.real, a.imag + b.imag};
}
constexpr complex64_t operator-(complex64_t a, complex64_t b) {
return {a.real - b.real, a.imag - b.imag};
}
constexpr complex64_t operator*(complex64_t a, complex64_t b) {
return {a.real * b.real - a.imag * b.imag, a.real * b.imag + a.imag * b.real};
}
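// Example usage (illustrative): the address-space qualified constructors and
// conversion operators above give complex64_t scalar-like semantics, e.g.
//   complex64_t z = 2.0f;                         // z = {2.0f, 0.0f}
//   complex64_t w = z * complex64_t(0.0f, 1.0f);  // w = {0.0f, 2.0f}
//   float r = static_cast<float>(w);              // keeps only w.real -> 0.0f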

View File

@@ -0,0 +1,14 @@
#pragma once
#ifdef __METAL__
#define MTL_CONST constant
#else
#define MTL_CONST
#endif
static MTL_CONST constexpr int MAX_BINARY_SPECIALIZED_DIMS = 5;
static MTL_CONST constexpr int MAX_COPY_SPECIALIZED_DIMS = 5;
static MTL_CONST constexpr int MAX_REDUCE_SPECIALIZED_DIMS = 4;
static MTL_CONST constexpr int REDUCE_N_READS = 16;
static MTL_CONST constexpr int SOFTMAX_N_READS = 4;
static MTL_CONST constexpr int SOFTMAX_LOOPED_LIMIT = 4096;

View File

@@ -0,0 +1,479 @@
#pragma once
#include <metal_simdgroup>
#include <metal_simdgroup_matrix>
#include <metal_stdlib>
#include "mlx/backend/metal/kernels/bf16.h"
#include "mlx/backend/metal/kernels/conv_params.h"
#define MLX_MTL_CONST static constant constexpr const
using namespace metal;
///////////////////////////////////////////////////////////////////////////////
// Loading helper
///////////////////////////////////////////////////////////////////////////////
template <
typename T,
int BM,
int BN,
int BK,
int vec_size,
int tgp_size,
int tgp_padding = 0>
struct Conv2DInputBlockLoader {
// Destination dimensions
MLX_MTL_CONST int dst_fd = BM;
MLX_MTL_CONST int dst_ld = BK + tgp_padding;
MLX_MTL_CONST int n_vecs = BK / vec_size;
// Stride along block row within the block
MLX_MTL_CONST int bstride = tgp_size / n_vecs;
MLX_MTL_CONST int n_rows = dst_fd / bstride;
// Thread location indices
const short thread_idx;
const short bi;
const short bj;
// threadgroup and device memory
threadgroup T* dst;
const device T* src;
const constant MLXConvParams<2>& params;
int weight_h;
int weight_w;
int offsets_n[n_rows];
int offsets_oh[n_rows];
int offsets_ow[n_rows];
/* Constructor */
METAL_FUNC Conv2DInputBlockLoader(
const device T* src_,
threadgroup T* dst_,
const constant MLXConvParams<2>& params_,
uint3 tid [[threadgroup_position_in_grid]],
uint3 lid [[thread_position_in_threadgroup]],
uint simd_group_id [[simdgroup_index_in_threadgroup]],
uint simd_lane_id [[thread_index_in_simdgroup]])
: thread_idx(simd_group_id * 32 + simd_lane_id),
bi(thread_idx / n_vecs),
bj(vec_size * (thread_idx % n_vecs)),
dst(dst_ + bi * dst_ld + bj),
src(src_ + bj),
params(params_),
weight_h(0),
weight_w(0) {
int out_n_pixels = params.oS[0] * params.oS[1];
for (int i = 0; i < n_rows; ++i) {
int offset_nhw = tid.y * BM + bi + i * bstride;
offsets_n[i] = offset_nhw / out_n_pixels;
int hw = offset_nhw % out_n_pixels;
offsets_oh[i] = hw / params.oS[1];
offsets_ow[i] = hw % params.oS[1];
}
(void)lid;
}
/* Load from device memory into threadgroup memory - without bound checking */
METAL_FUNC void load_unsafe() const {
#pragma clang loop unroll(full)
for (short i = 0, is = 0; i < n_rows; ++i, is += bstride) {
int n = offsets_n[i];
int oh = offsets_oh[i];
int ow = offsets_ow[i];
int ih = oh * params.str[0] - params.pad[0] + weight_h * params.dil[0];
int iw = ow * params.str[1] - params.pad[1] + weight_w * params.dil[1];
// Read from input if in bounds
if (ih >= 0 && ih < params.iS[0] && iw >= 0 && iw < params.iS[1]) {
const device T* curr_src = src + n * params.in_strides[0] +
ih * params.in_strides[1] + iw * params.in_strides[2];
#pragma clang loop unroll(full)
for (short j = 0; j < vec_size; ++j) {
dst[is * dst_ld + j] = curr_src[j];
}
}
// Zero pad otherwise
else {
#pragma clang loop unroll(full)
for (short j = 0; j < vec_size; ++j) {
dst[is * dst_ld + j] = T(0);
}
}
}
}
/* Iteration helper */
METAL_FUNC void next() {
if (++weight_w < params.wS[1]) {
return;
}
weight_w = 0;
if (++weight_h < params.wS[0]) {
return;
}
weight_h = 0;
src += BK;
}
};
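// Traversal sketch (as the loader above appears to work): each row of the
// implicit GEMM "A" tile is one output pixel, decomposed in the constructor
// from offset_nhw into (n, oh, ow). load_unsafe() gathers vec_size input
// channels at the window position (ih, iw) implied by the current
// (weight_h, weight_w), and next() sweeps weight_w, then weight_h, and
// finally advances src by BK to move on to the next block of input channels.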
template <
typename T,
int BM,
int BN,
int BK,
int vec_size,
int tgp_size,
int tgp_padding = 0>
struct Conv2DWeightBlockLoader {
// Destination dimensions
MLX_MTL_CONST int dst_fd = BN;
MLX_MTL_CONST int dst_ld = BK + tgp_padding;
MLX_MTL_CONST int n_vecs = BK / vec_size;
// Stride along block row within the block
MLX_MTL_CONST int bstride = tgp_size / n_vecs;
MLX_MTL_CONST int n_rows = dst_fd / bstride;
// Leading dimension for src
const int src_ld;
// Thread location indices
const short thread_idx;
const short bi;
const short bj;
// threadgroup and device memory
threadgroup T* dst;
const device T* src;
const constant MLXConvParams<2>& params;
int weight_h;
int weight_w;
/* Constructor */
METAL_FUNC Conv2DWeightBlockLoader(
const device T* src_,
threadgroup T* dst_,
const constant MLXConvParams<2>& params_,
uint3 tid [[threadgroup_position_in_grid]],
uint3 lid [[thread_position_in_threadgroup]],
uint simd_group_id [[simdgroup_index_in_threadgroup]],
uint simd_lane_id [[thread_index_in_simdgroup]])
: src_ld(params_.wt_strides[0]),
thread_idx(simd_group_id * 32 + simd_lane_id),
bi(thread_idx / n_vecs),
bj(vec_size * (thread_idx % n_vecs)),
dst(dst_ + bi * dst_ld + bj),
src(src_ + bi * src_ld + bj),
params(params_),
weight_h(0),
weight_w(0) {
(void)lid;
(void)tid;
}
/* Load from device memory into threadgroup memory - without bound checking */
METAL_FUNC void load_unsafe() const {
const device T* curr_src =
src + weight_h * params.wt_strides[1] + weight_w * params.wt_strides[2];
#pragma clang loop unroll(full)
for (short i = 0; i < dst_fd; i += bstride) {
#pragma clang loop unroll(full)
for (short j = 0; j < vec_size; j++) {
dst[i * dst_ld + j] = curr_src[i * src_ld + j];
}
}
}
/* Iteration helper */
METAL_FUNC void next() {
if (++weight_w < params.wS[1]) {
return;
}
weight_w = 0;
if (++weight_h < params.wS[0]) {
return;
}
weight_h = 0;
src += BK;
}
};
///////////////////////////////////////////////////////////////////////////////
// Transforms
///////////////////////////////////////////////////////////////////////////////
template <typename OutT, typename InT>
struct TransformNone {
static METAL_FUNC OutT apply(InT x) {
return static_cast<OutT>(x);
}
};
template <typename T>
struct AccumHelper {
typedef float accum_type;
};
///////////////////////////////////////////////////////////////////////////////
// MMA helper
///////////////////////////////////////////////////////////////////////////////
template <
typename T,
int BM,
int BN,
int BK,
int WM,
int WN,
bool transpose_a,
bool transpose_b,
int tgp_padding_a = 0,
int tgp_padding_b = 0,
typename AccumType = typename AccumHelper<T>::accum_type,
typename Epilogue = TransformNone<T, AccumType>>
struct Conv2DBlockMMA {
// Warp tile size along M
MLX_MTL_CONST int TM = BM / (WM * 8);
// Warp tile size along N
MLX_MTL_CONST int TN = BN / (WN * 8);
// Warp tile simdgroup matrix strides along M
MLX_MTL_CONST int TM_stride = 8 * WM;
// Warp tile simdgroup matrix strides along M
MLX_MTL_CONST int TN_stride = 8 * WN;
// Leading dimensions of threadgroup A, B blocks
MLX_MTL_CONST int lda_tgp = (transpose_a ? BM : BK) + tgp_padding_a;
MLX_MTL_CONST int ldb_tgp = (transpose_b ? BK : BN) + tgp_padding_b;
// Strides of A, B along reduction axis
MLX_MTL_CONST short simd_stride_a =
transpose_a ? TM_stride : TM_stride * lda_tgp;
MLX_MTL_CONST short simd_stride_b =
transpose_b ? TN_stride * ldb_tgp : TN_stride;
// Jump between elements
MLX_MTL_CONST short jump_a = transpose_a ? lda_tgp : 1;
MLX_MTL_CONST short jump_b = transpose_b ? ldb_tgp : 1;
// Offsets within threadgroup
const int tm;
const int tn;
// Simdgroup matrices
simdgroup_matrix<AccumType, 8, 8> Asimd[TM];
simdgroup_matrix<AccumType, 8, 8> Bsimd[TN];
simdgroup_matrix<AccumType, 8, 8> results[TM * TN] = {
simdgroup_matrix<AccumType, 8, 8>(0)};
short sm;
short sn;
/* Constructor */
METAL_FUNC Conv2DBlockMMA(
uint simd_group_id [[simdgroup_index_in_threadgroup]],
uint simd_lane_id [[thread_index_in_simdgroup]])
: tm(8 * (simd_group_id / WN)), tn(8 * (simd_group_id % WN)) {
short qid = simd_lane_id / 4;
sm = (qid & 4) + (simd_lane_id / 2) % 4;
sn = (qid & 2) * 2 + (simd_lane_id % 2) * 2;
}
/* (BM, BK) X (BK, BN) multiply accumulate function */
METAL_FUNC void mma(const threadgroup T* As, const threadgroup T* Bs) {
// Iterate over BK in blocks of 8
#pragma clang loop unroll(full)
for (short kk = 0; kk < BK; kk += 8) {
short2 offset_a =
transpose_a ? short2(tm + sm, kk + sn) : short2(kk + sn, tm + sm);
short2 offset_b =
transpose_b ? short2(kk + sm, tn + sn) : short2(tn + sn, kk + sm);
const threadgroup T* As__ = As + offset_a.y * lda_tgp + offset_a.x;
const threadgroup T* Bs__ = Bs + offset_b.y * ldb_tgp + offset_b.x;
simdgroup_barrier(mem_flags::mem_none);
// Load elements from threadgroup A as simdgroup matrices
#pragma clang loop unroll(full)
for (short i = 0; i < TM; i++) {
Asimd[i].thread_elements()[0] = static_cast<AccumType>(As__[0]);
Asimd[i].thread_elements()[1] = static_cast<AccumType>(As__[jump_a]);
As__ += simd_stride_a;
}
simdgroup_barrier(mem_flags::mem_none);
// Load elements from threadgroup B as simdgroup matrices
#pragma clang loop unroll(full)
for (short j = 0; j < TN; j++) {
Bsimd[j].thread_elements()[0] = static_cast<AccumType>(Bs__[0]);
Bsimd[j].thread_elements()[1] = static_cast<AccumType>(Bs__[jump_b]);
Bs__ += simd_stride_b;
}
simdgroup_barrier(mem_flags::mem_none);
// Multiply and accumulate into result simdgroup matrices
#pragma clang loop unroll(full)
for (short i = 0; i < TM; i++) {
#pragma clang loop unroll(full)
for (short j = 0; j < TN; j++) {
simdgroup_multiply_accumulate(
results[i * TN + j], Asimd[i], Bsimd[j], results[i * TN + j]);
}
}
}
}
/* Store results from simdgroup_matrix results into device memory */
METAL_FUNC void store_result(device T* C, const int ldc) const {
#pragma clang loop unroll(full)
for (int i = 0; i < TM; i++) {
#pragma clang loop unroll(full)
for (int j = 0; j < TN; j++) {
C[(i * TM_stride + sm + tm) * ldc + j * TN_stride + tn + sn] =
Epilogue::apply(results[i * TN + j].thread_elements()[0]);
C[(i * TM_stride + sm + tm) * ldc + j * TN_stride + tn + sn + 1] =
Epilogue::apply(results[i * TN + j].thread_elements()[1]);
}
}
}
METAL_FUNC void
store_result_safe(device T* C, const int ldc, short2 dst_tile_dims) const {
#pragma clang loop unroll(full)
for (int i = 0; i < TM; i++) {
if (tm + i * TM_stride + sm < dst_tile_dims.y) {
#pragma clang loop unroll(full)
for (int j = 0; j < TN; j++) {
if (tn + j * TN_stride + sn < dst_tile_dims.x) {
C[(tm + i * TM_stride + sm) * ldc + tn + j * TN_stride + sn] =
Epilogue::apply(results[i * TN + j].thread_elements()[0]);
}
if (tn + j * TN_stride + sn + 1 < dst_tile_dims.x) {
C[(tm + i * TM_stride + sm) * ldc + tn + j * TN_stride + sn + 1] =
Epilogue::apply(results[i * TN + j].thread_elements()[1]);
}
}
}
}
}
};
///////////////////////////////////////////////////////////////////////////////
// GEMM kernels
///////////////////////////////////////////////////////////////////////////////
template <
typename T,
int BM,
int BN,
int BK,
int WM,
int WN,
bool transpose_a,
bool transpose_b,
typename AccumType = typename AccumHelper<T>::accum_type,
typename Epilogue = TransformNone<T, AccumType>>
struct Conv2DImplicitGEMMKernel {
MLX_MTL_CONST short tgp_padding_a = 16 / sizeof(T);
MLX_MTL_CONST short tgp_padding_b = 16 / sizeof(T);
MLX_MTL_CONST short tgp_mem_size_a =
transpose_a ? BK * (BM + tgp_padding_a) : BM * (BK + tgp_padding_a);
MLX_MTL_CONST short tgp_mem_size_b =
transpose_b ? BN * (BK + tgp_padding_b) : BK * (BN + tgp_padding_b);
MLX_MTL_CONST short tgp_mem_size = tgp_mem_size_a + tgp_mem_size_b;
MLX_MTL_CONST short tgp_size = WM * WN * 32;
MLX_MTL_CONST short vec_size = (BM == 64 && BN == 64) ? 8 : 4;
using loader_a_t =
Conv2DInputBlockLoader<T, BM, BN, BK, vec_size, tgp_size, tgp_padding_a>;
using loader_b_t =
Conv2DWeightBlockLoader<T, BM, BN, BK, vec_size, tgp_size, tgp_padding_b>;
using mma_t = Conv2DBlockMMA<
T,
BM,
BN,
BK,
WM,
WN,
transpose_a,
transpose_b,
tgp_padding_a,
tgp_padding_b,
AccumType,
Epilogue>;
/* Main kernel function */
static METAL_FUNC void run(
const device T* A [[buffer(0)]],
const device T* B [[buffer(1)]],
device T* C [[buffer(2)]],
const constant MLXConvParams<2>& params [[buffer(3)]],
threadgroup T* tgp_memory [[threadgroup(0)]],
uint3 tid [[threadgroup_position_in_grid]],
uint3 lid [[thread_position_in_threadgroup]],
uint simd_gid [[simdgroup_index_in_threadgroup]],
uint simd_lid [[thread_index_in_simdgroup]]) {
const int c_row = tid.y * BM;
const int c_col = tid.x * BN;
const int K = params.wt_strides[0];
const int N = params.O;
B += c_col * K;
C += c_row * N + c_col;
// Prepare threadgroup memory for loading
threadgroup T* As = tgp_memory;
threadgroup T* Bs = tgp_memory + tgp_mem_size_a;
// Prepare threadgroup loading operations
loader_a_t loader_a(A, As, params, tid, lid, simd_gid, simd_lid);
loader_b_t loader_b(B, Bs, params, tid, lid, simd_gid, simd_lid);
// Prepare threadgroup mma operation
mma_t mma_op(simd_gid, simd_lid);
for (int k = 0; k < K; k += BK) {
threadgroup_barrier(mem_flags::mem_threadgroup);
// Load elements into threadgroup
loader_a.load_unsafe();
loader_b.load_unsafe();
threadgroup_barrier(mem_flags::mem_threadgroup);
// Multiply and accumulate threadgroup elements
mma_op.mma(As, Bs);
// Prepare for next iteration
loader_a.next();
loader_b.next();
}
threadgroup_barrier(mem_flags::mem_none);
// Store results to device memory
mma_op.store_result(C, N);
}
};
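// Worked example (hypothetical tile sizes, for illustration): with T = float,
// BM = BN = 64, BK = 16, WM = WN = 2, and no transposes,
//   tgp_padding_a = tgp_padding_b = 16 / sizeof(float) = 4
//   tgp_mem_size_a = BM * (BK + 4) = 64 * 20 = 1280 elements
//   tgp_mem_size_b = BK * (BN + 4) = 16 * 68 = 1088 elements
// so the kernel requires (1280 + 1088) * sizeof(float) = 9472 bytes of
// threadgroup memory and runs with tgp_size = 2 * 2 * 32 = 128 threads,
// using vec_size = 8 since BM == 64 && BN == 64.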

View File

@@ -0,0 +1,536 @@
#pragma once
#include <metal_simdgroup>
#include <metal_simdgroup_matrix>
#include <metal_stdlib>
#define MLX_MTL_CONST static constant constexpr const
using namespace metal;
///////////////////////////////////////////////////////////////////////////////
// Loading helper
///////////////////////////////////////////////////////////////////////////////
template <
typename T,
int BROWS,
int BCOLS,
int BK,
int vec_size,
int tgp_size,
bool transpose,
bool ldK,
int tgp_padding = 0>
struct BlockLoader {
// Destination dimensions
MLX_MTL_CONST int dst_fd = transpose ? BCOLS : BROWS;
MLX_MTL_CONST int dst_ld = (transpose ? BROWS : BCOLS) + tgp_padding;
MLX_MTL_CONST int n_vecs = (transpose ? BROWS : BCOLS) / vec_size;
// Stride along block row within the block
MLX_MTL_CONST int bstride = tgp_size / n_vecs;
// Leading dimension for src
const int src_ld;
// Stride along reduction axis between blocks
const int tstride;
// Thread location indices
const short thread_idx;
const short bi;
const short bj;
// threadgroup and device memory
threadgroup T* dst;
const device T* src;
/* Constructor */
METAL_FUNC BlockLoader(
const device T* src_,
const int src_ld_,
threadgroup T* dst_,
uint simd_group_id [[simdgroup_index_in_threadgroup]],
uint simd_lane_id [[thread_index_in_simdgroup]])
: src_ld(src_ld_),
tstride(
BK * ((int)(transpose ^ !ldK) * src_ld + (int)(transpose ^ ldK))),
thread_idx(simd_group_id * 32 + simd_lane_id),
bi(thread_idx / n_vecs),
bj(vec_size * (thread_idx % n_vecs)),
dst(dst_ + bi * dst_ld + bj),
src(src_ + bi * src_ld + bj) {}
/* Load from device memory into threadgroup memory - without bound checking */
METAL_FUNC void load_unsafe() const {
#pragma clang loop unroll(full)
for (short i = 0; i < dst_fd; i += bstride) {
#pragma clang loop unroll(full)
for (short j = 0; j < vec_size; j++) {
dst[i * dst_ld + j] = src[i * src_ld + j];
}
}
}
/* Load from device memory into threadgroup memory - with bound checking */
METAL_FUNC void load_safe(short2 src_tile_dim) const {
src_tile_dim = transpose ? src_tile_dim.yx : src_tile_dim.xy;
// Iterate over rows of block
#pragma clang loop unroll(full)
for (short i = 0; i < dst_fd; i += bstride) {
// Row is in bounds, we check against column
if ((bi + i) < src_tile_dim.y) {
// Use fast thread memory for bound checks
short tmp_idx[vec_size];
T tmp_val[vec_size];
// Make sure tmp_idx only contains valid indices
#pragma clang loop unroll(full)
for (short j = 0; j < vec_size; j++) {
tmp_idx[j] = bj + j < src_tile_dim.x ? j : 0;
}
// Read all valid indices into tmp_val
#pragma clang loop unroll(full)
for (short j = 0; j < vec_size; j++) {
tmp_val[j] = src[i * src_ld + tmp_idx[j]];
}
// Zero out unneeded values
#pragma clang loop unroll(full)
for (short j = 0; j < vec_size; j++) {
tmp_val[j] = bj + j < src_tile_dim.x ? tmp_val[j] : T(0);
}
// Copy values to threadgroup memory
#pragma clang loop unroll(full)
for (short j = 0; j < vec_size; j++) {
dst[i * dst_ld + j] = tmp_val[j];
}
}
// Row is out of bounds, we just fill tgp memory with zeros
else {
#pragma clang loop unroll(full)
for (short j = 0; j < vec_size; j++) {
dst[i * dst_ld + j] = T(0);
}
}
}
}
/* Iteration helper */
METAL_FUNC void next() {
src += tstride;
}
};
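// Example (derived from the tstride expression above): the "A" loader is
// instantiated with ldK = true, so with transpose = false
//   tstride = BK * (0 * src_ld + 1) = BK
// and next() steps BK columns along the reduction axis, while with
// transpose = true
//   tstride = BK * (1 * src_ld + 0) = BK * src_ld
// and next() steps BK rows of the transposed operand instead.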
///////////////////////////////////////////////////////////////////////////////
// Transforms
///////////////////////////////////////////////////////////////////////////////
template <typename OutT, typename InT>
struct TransformNone {
static METAL_FUNC OutT apply(InT x) {
return static_cast<OutT>(x);
}
};
template <typename T>
struct AccumHelper {
typedef float accum_type;
};
///////////////////////////////////////////////////////////////////////////////
// MMA helper
///////////////////////////////////////////////////////////////////////////////
template <
typename T,
int BM,
int BN,
int BK,
int WM,
int WN,
bool transpose_a,
bool transpose_b,
int tgp_padding_a = 0,
int tgp_padding_b = 0,
typename AccumType = typename AccumHelper<T>::accum_type,
typename Epilogue = TransformNone<T, AccumType>>
struct BlockMMA {
// Warp tile size along M
MLX_MTL_CONST int TM = BM / (WM * 8);
// Warp tile size along N
MLX_MTL_CONST int TN = BN / (WN * 8);
// Warp tile simdgroup matrix strides along M
MLX_MTL_CONST int TM_stride = 8 * WM;
// Warp tile simdgroup matrix strides along M
MLX_MTL_CONST int TN_stride = 8 * WN;
// Leading dimensions of threadgroup A, B blocks
MLX_MTL_CONST int lda_tgp = (transpose_a ? BM : BK) + tgp_padding_a;
MLX_MTL_CONST int ldb_tgp = (transpose_b ? BK : BN) + tgp_padding_b;
// Strides of A, B along reduction axis
MLX_MTL_CONST short simd_stride_a =
transpose_a ? TM_stride : TM_stride * lda_tgp;
MLX_MTL_CONST short simd_stride_b =
transpose_b ? TN_stride * ldb_tgp : TN_stride;
// Jump between elements
MLX_MTL_CONST short jump_a = transpose_a ? lda_tgp : 1;
MLX_MTL_CONST short jump_b = transpose_b ? ldb_tgp : 1;
// Offsets within threadgroup
const int tm;
const int tn;
// Simdgroup matrices
simdgroup_matrix<AccumType, 8, 8> Asimd[TM];
simdgroup_matrix<AccumType, 8, 8> Bsimd[TN];
simdgroup_matrix<AccumType, 8, 8> results[TM * TN] = {
simdgroup_matrix<AccumType, 8, 8>(0)};
short sm;
short sn;
/* Constructor */
METAL_FUNC BlockMMA(
uint simd_group_id [[simdgroup_index_in_threadgroup]],
uint simd_lane_id [[thread_index_in_simdgroup]])
: tm(8 * (simd_group_id / WN)), tn(8 * (simd_group_id % WN)) {
short qid = simd_lane_id / 4;
sm = (qid & 4) + (simd_lane_id / 2) % 4;
sn = (qid & 2) * 2 + (simd_lane_id % 2) * 2;
}
/* (BM, BK) X (BK, BN) multiply accumulate function */
METAL_FUNC void mma(const threadgroup T* As, const threadgroup T* Bs) {
// Iterate over BK in blocks of 8
#pragma clang loop unroll(full)
for (short kk = 0; kk < BK; kk += 8) {
short2 offset_a =
transpose_a ? short2(tm + sm, kk + sn) : short2(kk + sn, tm + sm);
short2 offset_b =
transpose_b ? short2(kk + sm, tn + sn) : short2(tn + sn, kk + sm);
const threadgroup T* As__ = As + offset_a.y * lda_tgp + offset_a.x;
const threadgroup T* Bs__ = Bs + offset_b.y * ldb_tgp + offset_b.x;
simdgroup_barrier(mem_flags::mem_none);
// Load elements from threadgroup A as simdgroup matrices
#pragma clang loop unroll(full)
for (short i = 0; i < TM; i++) {
Asimd[i].thread_elements()[0] = static_cast<AccumType>(As__[0]);
Asimd[i].thread_elements()[1] = static_cast<AccumType>(As__[jump_a]);
As__ += simd_stride_a;
}
simdgroup_barrier(mem_flags::mem_none);
// Load elements from threadgroup B as simdgroup matrices
#pragma clang loop unroll(full)
for (short j = 0; j < TN; j++) {
Bsimd[j].thread_elements()[0] = static_cast<AccumType>(Bs__[0]);
Bsimd[j].thread_elements()[1] = static_cast<AccumType>(Bs__[jump_b]);
Bs__ += simd_stride_b;
}
simdgroup_barrier(mem_flags::mem_none);
// Multiply and accumulate into result simdgroup matrices
#pragma clang loop unroll(full)
for (short i = 0; i < TM; i++) {
#pragma clang loop unroll(full)
for (short j = 0; j < TN; j++) {
simdgroup_multiply_accumulate(
results[i * TN + j], Asimd[i], Bsimd[j], results[i * TN + j]);
}
}
}
}
/* Store results from simdgroup_matrix results into device memory */
METAL_FUNC void store_result(device T* C, const int ldc) const {
#pragma clang loop unroll(full)
for (int i = 0; i < TM; i++) {
#pragma clang loop unroll(full)
for (int j = 0; j < TN; j++) {
C[(i * TM_stride + sm + tm) * ldc + j * TN_stride + tn + sn] =
Epilogue::apply(results[i * TN + j].thread_elements()[0]);
C[(i * TM_stride + sm + tm) * ldc + j * TN_stride + tn + sn + 1] =
Epilogue::apply(results[i * TN + j].thread_elements()[1]);
}
}
}
METAL_FUNC void
store_result_safe(device T* C, const int ldc, short2 dst_tile_dims) const {
#pragma clang loop unroll(full)
for (int i = 0; i < TM; i++) {
if (tm + i * TM_stride + sm < dst_tile_dims.y) {
#pragma clang loop unroll(full)
for (int j = 0; j < TN; j++) {
if (tn + j * TN_stride + sn < dst_tile_dims.x) {
C[(tm + i * TM_stride + sm) * ldc + tn + j * TN_stride + sn] =
Epilogue::apply(results[i * TN + j].thread_elements()[0]);
}
if (tn + j * TN_stride + sn + 1 < dst_tile_dims.x) {
C[(tm + i * TM_stride + sm) * ldc + tn + j * TN_stride + sn + 1] =
Epilogue::apply(results[i * TN + j].thread_elements()[1]);
}
}
}
}
}
};
///////////////////////////////////////////////////////////////////////////////
// GEMM kernels
///////////////////////////////////////////////////////////////////////////////
template <
typename T,
int BM,
int BN,
int BK,
int WM,
int WN,
bool transpose_a,
bool transpose_b,
bool MN_aligned,
bool K_aligned,
typename AccumType = typename AccumHelper<T>::accum_type,
typename Epilogue = TransformNone<T, AccumType>>
struct GEMMKernel {
MLX_MTL_CONST short tgp_padding_a = 16 / sizeof(T);
MLX_MTL_CONST short tgp_padding_b = 16 / sizeof(T);
MLX_MTL_CONST short tgp_mem_size_a =
transpose_a ? BK * (BM + tgp_padding_a) : BM * (BK + tgp_padding_a);
MLX_MTL_CONST short tgp_mem_size_b =
transpose_b ? BN * (BK + tgp_padding_b) : BK * (BN + tgp_padding_b);
MLX_MTL_CONST short tgp_mem_size = tgp_mem_size_a + tgp_mem_size_b;
MLX_MTL_CONST short tgp_size = WM * WN * 32;
MLX_MTL_CONST short vec_size = (BM == 64 && BN == 64) ? 8 : 4;
using loader_a_t = BlockLoader<
T,
BM,
BK,
BK,
vec_size,
tgp_size,
transpose_a,
true,
tgp_padding_a>;
using loader_b_t = BlockLoader<
T,
BK,
BN,
BK,
vec_size,
tgp_size,
transpose_b,
false,
tgp_padding_b>;
using mma_t = BlockMMA<
T,
BM,
BN,
BK,
WM,
WN,
transpose_a,
transpose_b,
tgp_padding_a,
tgp_padding_b,
AccumType,
Epilogue>;
/* Main kernel function */
static METAL_FUNC void run(
const device T* A [[buffer(0)]],
const device T* B [[buffer(1)]],
device T* C [[buffer(2)]],
const constant int& M [[buffer(3)]],
const constant int& N [[buffer(4)]],
const constant int& K [[buffer(5)]],
const constant int& batch_stride_a [[buffer(6)]],
const constant int& batch_stride_b [[buffer(7)]],
const constant int& batch_stride_c [[buffer(8)]],
threadgroup T* tgp_memory [[threadgroup(0)]],
uint simd_lane_id [[thread_index_in_simdgroup]],
uint simd_group_id [[simdgroup_index_in_threadgroup]],
uint3 tid [[threadgroup_position_in_grid]],
uint3 lid [[thread_position_in_threadgroup]]) {
// Pacifying compiler
(void)lid;
// Adjust for batch
A += batch_stride_a * tid.z;
B += batch_stride_b * tid.z;
C += batch_stride_c * tid.z;
// Adjust for transpose
const int lda_dev = transpose_a ? M : K;
const int ldb_dev = transpose_b ? K : N;
// Find block in A, B, C
const int c_row = tid.y * BM;
const int c_col = tid.x * BN;
A += transpose_a ? c_row : c_row * K;
B += transpose_b ? c_col * K : c_col;
C += c_row * N + c_col;
// Prepare threadgroup memory for loading
threadgroup T* As = tgp_memory;
threadgroup T* Bs = tgp_memory + tgp_mem_size_a;
// Prepare threadgroup loading operations
loader_a_t loader_a(A, lda_dev, As, simd_group_id, simd_lane_id);
loader_b_t loader_b(B, ldb_dev, Bs, simd_group_id, simd_lane_id);
// Prepare threadgroup mma operation
mma_t mma_op(simd_group_id, simd_lane_id);
///////////////////////////////////////////////////////////////////////////////
// MNK aligned loop
if (MN_aligned && K_aligned) {
for (int k = 0; k < K; k += BK) {
threadgroup_barrier(mem_flags::mem_threadgroup);
// Load elements into threadgroup
loader_a.load_unsafe();
loader_b.load_unsafe();
threadgroup_barrier(mem_flags::mem_threadgroup);
// Multiply and accumulate threadgroup elements
mma_op.mma(As, Bs);
// Prepare for next iteration
loader_a.next();
loader_b.next();
}
threadgroup_barrier(mem_flags::mem_none);
// Store results to device memory
mma_op.store_result(C, N);
return;
}
///////////////////////////////////////////////////////////////////////////////
// MN aligned, K unaligned loop
else if (MN_aligned && !K_aligned) {
// Main loop
int k = 0;
for (; k + BK <= K; k += BK) {
threadgroup_barrier(mem_flags::mem_threadgroup);
// Load elements into threadgroup
loader_a.load_unsafe();
loader_b.load_unsafe();
threadgroup_barrier(mem_flags::mem_threadgroup);
// Multiply and accumulate threadgroup elements
mma_op.mma(As, Bs);
// Prepare for next iteration
loader_a.next();
loader_b.next();
}
// Loop tail
threadgroup_barrier(mem_flags::mem_threadgroup);
loader_a.load_safe(short2(K - k, BM));
loader_b.load_safe(short2(BN, K - k));
threadgroup_barrier(mem_flags::mem_threadgroup);
mma_op.mma(As, Bs);
// Store results to device memory
mma_op.store_result(C, N);
return;
}
///////////////////////////////////////////////////////////////////////////////
// MNK unaligned loop
else { // Loop over K - unaligned case
short2 src_tile_dims(min(BN, N - c_col), min(BM, M - c_row));
if (src_tile_dims.y == BM && src_tile_dims.x == BN) {
int k = 0;
for (; k + BK <= K; k += BK) {
threadgroup_barrier(mem_flags::mem_threadgroup);
// Load elements into threadgroup
loader_a.load_unsafe();
loader_b.load_unsafe();
threadgroup_barrier(mem_flags::mem_threadgroup);
// Multiply and accumulate threadgroup elements
mma_op.mma(As, Bs);
// Prepare for next iteration
loader_a.next();
loader_b.next();
}
threadgroup_barrier(mem_flags::mem_none);
if (k < K) {
loader_a.load_safe(short2(K - k, BM));
loader_b.load_safe(short2(BN, K - k));
threadgroup_barrier(mem_flags::mem_threadgroup);
mma_op.mma(As, Bs);
}
mma_op.store_result(C, N);
return;
} else {
int k = 0;
for (; k + BK <= K; k += BK) {
threadgroup_barrier(mem_flags::mem_threadgroup);
// Load elements into threadgroup
loader_a.load_safe(short2(BK, src_tile_dims.y));
loader_b.load_safe(short2(src_tile_dims.x, BK));
threadgroup_barrier(mem_flags::mem_threadgroup);
// Multiply and accumulate threadgroup elements
mma_op.mma(As, Bs);
// Prepare for next iteration
loader_a.next();
loader_b.next();
}
threadgroup_barrier(mem_flags::mem_none);
if (k < K) {
loader_a.load_safe(short2(K - k, src_tile_dims.y));
loader_b.load_safe(short2(src_tile_dims.x, K - k));
threadgroup_barrier(mem_flags::mem_threadgroup);
mma_op.mma(As, Bs);
}
threadgroup_barrier(mem_flags::mem_none);
mma_op.store_result_safe(C, N, src_tile_dims);
return;
}
}
}
};
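// Dispatch sketch (an assumption about the host side, which is not shown
// here): given c_row = tid.y * BM, c_col = tid.x * BN and the batch on tid.z,
// a matching launch would use a grid of
//   (ceil(N / BN), ceil(M / BM), batch_size)
// threadgroups of tgp_size threads each, with tgp_mem_size * sizeof(T) bytes
// of threadgroup memory bound at threadgroup index 0.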

View File

@@ -0,0 +1,302 @@
#include <metal_stdlib>
#include <metal_simdgroup>
#include "mlx/backend/metal/kernels/defines.h"
#include "mlx/backend/metal/kernels/bf16.h"
using namespace metal;
///////////////////////////////////////////////////////////////////////////////
/// Matrix vector multiplication
///////////////////////////////////////////////////////////////////////////////
static constant constexpr int SIMD_SIZE = 32;
template <typename T,
const int BM, /* Threadgroup rows (in threads) */
const int BN, /* Threadgroup cols (in threads) */
const int TM, /* Thread rows (in elements) */
const int TN> /* Thread cols (in elements) */
[[kernel]] void gemv(
const device T* mat [[buffer(0)]],
const device T* in_vec [[buffer(1)]],
device T* out_vec [[buffer(2)]],
const constant int& in_vec_size [[buffer(3)]],
const constant int& out_vec_size [[buffer(4)]],
const constant int& vector_batch_stride [[buffer(5)]],
const constant int& matrix_batch_stride [[buffer(6)]],
uint3 tid [[threadgroup_position_in_grid]],
uint simd_gid [[simdgroup_index_in_threadgroup]],
uint simd_lid [[thread_index_in_simdgroup]]) {
static_assert(BN == SIMD_SIZE, "gemv block must have a width of SIMD_SIZE");
// - The matrix of size (M = out_vec_size, N = in_vec_size) is divided up
// into blocks of (BM * TM, BN * TN) divided among threadgroups
// - Every thread works on a block of (TM, TN)
// - We assume each threadgroup is launched with (BN, BM, 1) threads
//
// 1. A thread loads TN elements each from mat along TM contiguous rows
// and the corresponding scalar from the vector
// 2. The thread then multiplies and adds to accumulate its local result for the block
// 3. At the end, each thread has accumulated results over all blocks across the rows
// These are then summed up across the threadgroup
// 4. Each threadgroup writes its accumulated BN * TN outputs
//
// Edge case handling:
// - The threadgroup with the largest tid will have blocks that exceed the matrix
// * The blocks that start outside the matrix are never read (thread results remain zero)
// * The last thread that partially overlaps with the matrix is shifted inwards
// such that the thread block fits exactly in the matrix
// Update batch offsets
in_vec += tid.z * vector_batch_stride;
mat += tid.z * matrix_batch_stride;
out_vec += tid.z * out_vec_size;
// Threadgroup in_vec cache
threadgroup T in_vec_block[BN][TN * 2];
// Thread local accumulation results
thread T result[TM] = {0};
thread T inter[TN];
thread T v_coeff[TN];
// Block position
int out_row = (tid.x * BM + simd_gid) * TM;
// Exit simdgroup if rows are out of bounds
if(out_row >= out_vec_size)
return;
// Adjust tail simdgroup to ensure in bound reads
out_row = out_row + TM <= out_vec_size ? out_row : out_vec_size - TM;
// Advance matrix
mat += out_row * in_vec_size;
// Loop over in_vec in blocks of BN * TN
for(int bn = simd_lid * TN; bn < in_vec_size; bn += BN * TN) {
threadgroup_barrier(mem_flags::mem_threadgroup);
// Prefetch in_vector for threadgroup use
if(simd_gid == 0) {
// Main load loop
if(bn + TN <= in_vec_size) {
#pragma clang loop unroll(full)
for(int tn = 0; tn < TN; tn++) {
in_vec_block[simd_lid][tn] = in_vec[bn + tn];
}
} else { // Edgecase
#pragma clang loop unroll(full)
for(int tn = 0; tn < TN; tn++) {
in_vec_block[simd_lid][tn] = bn + tn < in_vec_size ? in_vec[bn + tn] : 0;
}
}
}
threadgroup_barrier(mem_flags::mem_threadgroup);
// Load for all rows
#pragma clang loop unroll(full)
for(int tn = 0; tn < TN; tn++) {
v_coeff[tn] = in_vec_block[simd_lid][tn];
}
// Per thread work loop
#pragma clang loop unroll(full)
for(int tm = 0; tm < TM; tm++) {
// Load for the row
for(int tn = 0; tn < TN; tn++) {
inter[tn] = mat[tm * in_vec_size + bn + tn];
}
// Accumulate results
for(int tn = 0; tn < TN; tn++) {
result[tm] += inter[tn] * v_coeff[tn];
}
}
}
// Simdgroup accumulations
#pragma clang loop unroll(full)
for(int tm = 0; tm < TM; tm++) {
result[tm] = simd_sum(result[tm]);
}
// Write outputs
if(simd_lid == 0) {
#pragma clang loop unroll(full)
for(int tm = 0; tm < TM; tm++) {
out_vec[out_row + tm] = result[tm];
}
}
}
#define instantiate_gemv(name, itype, bm, bn, tm, tn) \
template [[host_name("gemv_" #name "_bm" #bm "_bn" #bn "_tm" #tm "_tn" #tn)]] \
[[kernel]] void gemv<itype, bm, bn, tm, tn>( \
const device itype* mat [[buffer(0)]], \
const device itype* vec [[buffer(1)]], \
device itype* out [[buffer(2)]], \
const constant int& in_vec_size [[buffer(3)]], \
const constant int& out_vec_size [[buffer(4)]], \
const constant int& vector_batch_stride [[buffer(5)]], \
const constant int& matrix_batch_stride [[buffer(6)]], \
uint3 tid [[threadgroup_position_in_grid]], \
uint simd_gid [[simdgroup_index_in_threadgroup]], \
uint simd_lid [[thread_index_in_simdgroup]]);
#define instantiate_gemv_blocks(name, itype) \
instantiate_gemv(name, itype, 4, 32, 1, 4) \
instantiate_gemv(name, itype, 4, 32, 4, 4) \
instantiate_gemv(name, itype, 8, 32, 4, 4)
instantiate_gemv_blocks(float32, float)
instantiate_gemv_blocks(float16, half)
instantiate_gemv_blocks(bfloat16, bfloat16_t)
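// Illustrative tiling arithmetic: for the (bm=8, bn=32, tm=4, tn=4) variant
// above, each simdgroup accumulates TM = 4 output rows, a threadgroup of
// BM * BN = 256 threads covers BM * TM = 32 rows, and every iteration of the
// inner loop consumes BN * TN = 128 elements of in_vec.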
///////////////////////////////////////////////////////////////////////////////
/// Vector matrix multiplication
///////////////////////////////////////////////////////////////////////////////
template <typename T,
const int BM, /* Threadgroup rows (in threads) */
const int BN, /* Threadgroup cols (in threads) */
const int TM, /* Thread rows (in elements) */
const int TN> /* Thread cols (in elements) */
[[kernel]] void gemv_t(
const device T* mat [[buffer(0)]],
const device T* in_vec [[buffer(1)]],
device T* out_vec [[buffer(2)]],
const constant int& in_vec_size [[buffer(3)]],
const constant int& out_vec_size [[buffer(4)]],
const constant int& vector_batch_stride [[buffer(5)]],
const constant int& matrix_batch_stride [[buffer(6)]],
uint3 tid [[threadgroup_position_in_grid]],
uint3 lid [[thread_position_in_threadgroup]]) {
// - The matrix of size (M = in_vec_size, N = out_vec_size) is divided up
// into blocks of (BM * TM, BN * TN) divided among threadgroups
// - Every thread works on a block of (TM, TN)
// - We assume each threadgroup is launched with (BN, BM, 1) threads
//
// 1. A thread loads TN elements each from mat along TM contiguous rows
// and the corresponding scalar from the vector
// 2. The thread then multiplies and adds to accumulate its local result for the block
// 3. At the end, each thread has accumulated results over all blocks across the rows
// These are then summed up across the threadgroup
// 4. Each threadgroup writes its accumulated BN * TN outputs
//
// Edge case handling:
// - The threadgroup with the largest tid will have blocks that exceed the matrix
// * The blocks that start outside the matrix are never read (thread results remain zero)
// * The last thread that partially overlaps with the matrix is shifted inwards
// such that the thread block fits exactly in the matrix
// Update batch offsets
in_vec += tid.z * vector_batch_stride;
mat += tid.z * matrix_batch_stride;
out_vec += tid.z * out_vec_size;
// Thread local accumulation results
T result[TN] = {0};
T inter[TN];
T v_coeff[TM];
// Threadgroup accumulation results
threadgroup T tgp_results[BN][BM][TM];
int out_col = (tid.x * BN + lid.x) * TN;
int in_row = lid.y * TM;
// Edgecase handling
if (out_col < out_vec_size) {
// Adjust the tail thread block to ensure in-bounds reads
out_col = out_col + TN < out_vec_size ? out_col : out_vec_size - TN;
// Per thread accumulation main loop
int bm = in_row;
for(; bm < in_vec_size; bm += BM * TM) {
// Adding a threadgroup_barrier improves performance slightly;
// possibly it helps to better exploit the cache
threadgroup_barrier(mem_flags::mem_none);
if(bm + TM <= in_vec_size) {
#pragma clang loop unroll(full)
for(int tm = 0; tm < TM; tm++) {
v_coeff[tm] = in_vec[bm + tm];
}
#pragma clang loop unroll(full)
for(int tm = 0; tm < TM; tm++) {
for(int tn = 0; tn < TN; tn++) {
inter[tn] = mat[(bm + tm) * out_vec_size + out_col + tn];
}
for(int tn = 0; tn < TN; tn++) {
result[tn] += v_coeff[tm] * inter[tn];
}
}
} else { // Edgecase handling
for(int tm = 0; bm + tm < in_vec_size; tm++) {
v_coeff[tm] = in_vec[bm + tm];
for(int tn = 0; tn < TN; tn++) {
inter[tn] = mat[(bm + tm) * out_vec_size + out_col + tn];
}
for(int tn = 0; tn < TN; tn++) {
result[tn] += v_coeff[tm] * inter[tn];
}
}
}
}
}
// Threadgroup collection
for(int i = 0; i < TN; i++) {
tgp_results[lid.x][lid.y][i] = result[i];
}
threadgroup_barrier(mem_flags::mem_threadgroup);
if(lid.y == 0 && out_col < out_vec_size) {
// Threadgroup accumulation
#pragma clang loop unroll(full)
for(int i = 1; i < BM; i++) {
for(int j = 0; j < TN; j++) {
result[j] += tgp_results[lid.x][i][j];
}
}
for(int j = 0; j < TN; j++) {
out_vec[out_col + j] = result[j];
}
}
}
#define instantiate_gemv_t(name, itype, bm, bn, tm, tn) \
template [[host_name("gemv_t_" #name "_bm" #bm "_bn" #bn "_tm" #tm "_tn" #tn)]] \
[[kernel]] void gemv_t<itype, bm, bn, tm, tn>( \
const device itype* mat [[buffer(0)]], \
const device itype* vec [[buffer(1)]], \
device itype* out [[buffer(2)]], \
const constant int& in_vec_size [[buffer(3)]], \
const constant int& out_vec_size [[buffer(4)]], \
const constant int& vector_batch_stride [[buffer(5)]], \
const constant int& matrix_batch_stride [[buffer(6)]], \
uint3 tid [[threadgroup_position_in_grid]], \
uint3 lid [[thread_position_in_threadgroup]]);
#define instantiate_gemv_t_blocks(name, itype) \
instantiate_gemv_t(name, itype, 8, 8, 4, 1) \
instantiate_gemv_t(name, itype, 8, 8, 4, 4) \
instantiate_gemv_t(name, itype, 8, 16, 4, 4) \
instantiate_gemv_t(name, itype, 8, 32, 4, 4) \
instantiate_gemv_t(name, itype, 8, 64, 4, 4) \
instantiate_gemv_t(name, itype, 8, 128, 4, 4)
instantiate_gemv_t_blocks(float32, float)
instantiate_gemv_t_blocks(float16, half)
instantiate_gemv_t_blocks(bfloat16, bfloat16_t)
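// Illustrative tiling arithmetic for the transposed case: with the
// (bm=8, bn=32, tm=4, tn=4) variant, out_col = (tid.x * BN + lid.x) * TN, so a
// threadgroup of BM * BN = 256 threads produces BN * TN = 128 output columns,
// while its BM = 8 rows of threads split the reduction over in_vec and are
// summed through tgp_results at the end.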

View File

@@ -0,0 +1,226 @@
#include <metal_atomic>
#include <metal_common>
#include <metal_simdgroup>
#include "mlx/backend/metal/kernels/bf16.h"
#include "mlx/backend/metal/kernels/defines.h"
#include "mlx/backend/metal/kernels/utils.h"
using namespace metal;
template <typename T>
inline T softmax_exp(T x) {
// Softmax doesn't need a high precision exponential because x will be in
// (-oo, 0] anyway and the result is subsequently divided by sum(exp(x_i)).
return fast::exp(x);
}
template <typename T, int N_READS = SOFTMAX_N_READS>
[[kernel]] void softmax_single_row(
const device T* in,
device T* out,
constant int& axis_size,
threadgroup T* local_max [[threadgroup(0)]],
threadgroup T* local_normalizer [[threadgroup(1)]],
uint gid [[threadgroup_position_in_grid]],
uint _lid [[thread_position_in_threadgroup]],
uint simd_lane_id [[thread_index_in_simdgroup]],
uint simd_group_id [[simdgroup_index_in_threadgroup]]) {
int lid = _lid;
T ld[N_READS];
in += gid * axis_size + lid * N_READS;
if (lid * N_READS + N_READS <= axis_size) {
for (int i=0; i<N_READS; i++) {
ld[i] = in[i];
}
} else {
for (int i = 0; i < N_READS; i++) {
ld[i] =
((lid * N_READS + i) < axis_size) ? in[i] : T(Limits<T>::finite_min);
}
}
if (simd_group_id == 0) {
local_max[simd_lane_id] = Limits<T>::finite_min;
local_normalizer[simd_lane_id] = 0;
}
threadgroup_barrier(mem_flags::mem_threadgroup);
// Get the max
T maxval = Limits<T>::finite_min;
for (int i = 0; i < N_READS; i++) {
maxval = (maxval < ld[i]) ? ld[i] : maxval;
}
maxval = simd_max(maxval);
if (simd_lane_id == 0) {
local_max[simd_group_id] = maxval;
}
threadgroup_barrier(mem_flags::mem_threadgroup);
if (simd_group_id == 0) {
maxval = simd_max(local_max[simd_lane_id]);
if (simd_lane_id == 0) {
local_max[0] = maxval;
}
}
threadgroup_barrier(mem_flags::mem_threadgroup);
maxval = local_max[0];
// Compute exp(x_i - maxval) and store the partial sums in local_normalizer
T normalizer = 0;
for (int i = 0; i < N_READS; i++) {
T exp_x = softmax_exp(ld[i] - maxval);
ld[i] = exp_x;
normalizer += exp_x;
}
normalizer = simd_sum(normalizer);
if (simd_lane_id == 0) {
local_normalizer[simd_group_id] = normalizer;
}
threadgroup_barrier(mem_flags::mem_threadgroup);
if (simd_group_id == 0) {
normalizer = simd_sum(local_normalizer[simd_lane_id]);
if (simd_lane_id == 0) {
local_normalizer[0] = normalizer;
}
}
threadgroup_barrier(mem_flags::mem_threadgroup);
normalizer = 1 / local_normalizer[0];
// Normalize and write to the output
out += gid * axis_size + lid * N_READS;
if (lid * N_READS + N_READS <= axis_size) {
for (int i=0; i<N_READS; i++) {
out[i] = ld[i] * normalizer;
}
} else {
for (int i = 0; i < N_READS; i++) {
if ((lid * N_READS + i) < axis_size) {
out[i] = ld[i] * normalizer;
}
}
}
}
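// The kernel above computes the numerically stable softmax
//   y_i = exp(x_i - max_j x_j) / sum_k exp(x_k - max_j x_j)
// which equals exp(x_i) / sum_k exp(x_k) but keeps every exponent in
// (-inf, 0], avoiding overflow in the intermediate sums.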
template <typename T, int N_READS = SOFTMAX_N_READS>
[[kernel]] void softmax_looped(
const device T* in,
device T* out,
constant int& axis_size,
threadgroup T* local_max [[threadgroup(0)]],
threadgroup T* local_normalizer [[threadgroup(1)]],
uint gid [[threadgroup_position_in_grid]],
uint lid [[thread_position_in_threadgroup]],
uint lsize [[threads_per_threadgroup]],
uint simd_lane_id [[thread_index_in_simdgroup]],
uint simd_group_id [[simdgroup_index_in_threadgroup]]) {
in += gid * axis_size;
// Get the max and the normalizer in one go
T prevmax;
T maxval = Limits<T>::finite_min;
T normalizer = 0;
for (int r = 0; r < static_cast<int>(ceildiv(axis_size, N_READS * lsize));
r++) {
int offset = r * lsize * N_READS + lid * N_READS;
T vals[N_READS];
if (offset + N_READS <= axis_size) {
for (int i = 0; i < N_READS; i++) {
vals[i] = in[offset + i];
}
} else {
for (int i = 0; i < N_READS; i++) {
vals[i] =
(offset + i < axis_size) ? in[offset + i] : T(Limits<T>::finite_min);
}
}
prevmax = maxval;
for (int i = 0; i < N_READS; i++) {
maxval = (maxval < vals[i]) ? vals[i] : maxval;
}
normalizer *= softmax_exp(prevmax - maxval);
for (int i = 0; i < N_READS; i++) {
normalizer += softmax_exp(vals[i] - maxval);
}
}
// At this point each thread holds a partial normalizer covering
// N_READS * ceildiv(axis_size, N_READS * lsize) elements. We need to combine them.
// 1. We start by finding the max across simd groups
// 2. We then change the partial normalizers to account for a possible
// change in max
// 3. We sum all normalizers
prevmax = maxval;
maxval = simd_max(maxval);
normalizer *= softmax_exp(prevmax - maxval);
normalizer = simd_sum(normalizer);
// Now the normalizer and max value are correct for each simdgroup. We write
// them to shared memory and combine them.
prevmax = maxval;
if (simd_lane_id == 0) {
local_max[simd_group_id] = maxval;
}
threadgroup_barrier(mem_flags::mem_threadgroup);
maxval = simd_max(local_max[simd_lane_id]);
normalizer *= softmax_exp(prevmax - maxval);
if (simd_lane_id == 0) {
local_normalizer[simd_group_id] = normalizer;
}
threadgroup_barrier(mem_flags::mem_threadgroup);
normalizer = simd_sum(local_normalizer[simd_lane_id]);
normalizer = 1 / normalizer;
// Finally given the normalizer and max value we can directly write the
// softmax output
out += gid * axis_size;
for (int r = 0; r < static_cast<int>(ceildiv(axis_size, N_READS * lsize));
r++) {
int offset = r * lsize * N_READS + lid * N_READS;
if (offset + N_READS <= axis_size) {
for (int i=0; i<N_READS; i++) {
out[offset + i] = softmax_exp(in[offset + i] - maxval) * normalizer;
}
} else {
for (int i = 0; i < N_READS; i++) {
if (offset + i < axis_size) {
out[offset + i] = softmax_exp(in[offset + i] - maxval) * normalizer;
}
}
}
}
}
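// The running renormalization above relies on the identity
//   sum_i exp(x_i - m_new) = exp(m_old - m_new) * sum_i exp(x_i - m_old)
// so whenever a larger maximum m_new is encountered, the partial normalizer
// accumulated with m_old is rescaled by exp(m_old - m_new) before new terms
// are added.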
#define instantiate_softmax_single_row(name, itype) \
template [[host_name("softmax_" #name)]] [[kernel]] void \
softmax_single_row<itype>( \
const device itype* in, \
device itype* out, \
constant int& axis_size, \
threadgroup itype* local_max [[threadgroup(0)]], \
threadgroup itype* local_normalizer [[threadgroup(1)]], \
uint gid [[thread_position_in_grid]], \
uint _lid [[thread_position_in_threadgroup]], \
uint simd_lane_id [[thread_index_in_simdgroup]], \
uint simd_group_id [[simdgroup_index_in_threadgroup]]);
#define instantiate_softmax_looped(name, itype) \
template [[host_name("softmax_looped_" #name)]] [[kernel]] void \
softmax_looped<itype>( \
const device itype* in, \
device itype* out, \
constant int& axis_size, \
threadgroup itype* local_max [[threadgroup(0)]], \
threadgroup itype* local_normalizer [[threadgroup(1)]], \
uint gid [[threadgroup_position_in_grid]], \
uint lid [[thread_position_in_threadgroup]], \
uint lsize [[threads_per_threadgroup]], \
uint simd_lane_id [[thread_index_in_simdgroup]], \
uint simd_group_id [[simdgroup_index_in_threadgroup]]);
#define instantiate_softmax(name, itype) \
instantiate_softmax_single_row(name, itype) \
instantiate_softmax_looped(name, itype)
instantiate_softmax(float32, float)
instantiate_softmax(float16, half)
instantiate_softmax(bfloat16, bfloat16_t)

View File

@@ -0,0 +1,818 @@
#include <metal_stdlib>
#include "mlx/backend/metal/kernels/bf16.h"
#include "mlx/backend/metal/kernels/defines.h"
#include "mlx/backend/metal/kernels/utils.h"
#define MLX_MTL_CONST static constant constexpr const
#define MLX_MTL_LOOP_UNROLL _Pragma("clang loop unroll(full)")
using namespace metal;
// Based on GPU merge sort algorithm at https://github.com/NVIDIA/cccl/tree/main/cub/cub
///////////////////////////////////////////////////////////////////////////////
// Thread-level sort
///////////////////////////////////////////////////////////////////////////////
template <typename T>
METAL_FUNC void thread_swap(thread T& a, thread T& b) {
T w = a;
a = b;
b = w;
}
template <typename T>
struct LessThan {
static constexpr constant T init = Limits<T>::max;
METAL_FUNC bool operator()(T a, T b) {
return a < b;
}
};
template <
typename val_t,
typename idx_t,
bool ARG_SORT,
short N_PER_THREAD,
typename CompareOp>
struct ThreadSort {
static METAL_FUNC void sort(
thread val_t (&vals)[N_PER_THREAD],
thread idx_t (&idxs)[N_PER_THREAD]) {
CompareOp op;
MLX_MTL_LOOP_UNROLL
for(short i = 0; i < N_PER_THREAD; ++i) {
MLX_MTL_LOOP_UNROLL
for(short j = i & 1; j < N_PER_THREAD - 1; j += 2) {
if(op(vals[j + 1], vals[j])) {
thread_swap(vals[j + 1], vals[j]);
thread_swap(idxs[j + 1], idxs[j]);
}
}
}
}
};
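// ThreadSort is an odd-even transposition network: pass i compares adjacent
// pairs starting at index i & 1, and N_PER_THREAD passes suffice to sort.
// Worked example for N_PER_THREAD = 4 on vals = [3, 1, 4, 2]:
//   pass 0 (pairs 0-1, 2-3): [1, 3, 2, 4]
//   pass 1 (pair 1-2):       [1, 2, 3, 4]
//   passes 2 and 3:          no further swaps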
///////////////////////////////////////////////////////////////////////////////
// Threadgroup-level sort
///////////////////////////////////////////////////////////////////////////////
template <
typename val_t,
typename idx_t,
bool ARG_SORT,
short BLOCK_THREADS,
short N_PER_THREAD,
typename CompareOp>
struct BlockMergeSort {
using thread_sort_t = ThreadSort<val_t, idx_t, ARG_SORT, N_PER_THREAD, CompareOp>;
static METAL_FUNC int merge_partition(
const threadgroup val_t* As,
const threadgroup val_t* Bs,
short A_sz,
short B_sz,
short sort_md) {
CompareOp op;
short A_st = max(0, sort_md - B_sz);
short A_ed = min(sort_md, A_sz);
while(A_st < A_ed) {
short md = A_st + (A_ed - A_st) / 2;
auto a = As[md];
auto b = Bs[sort_md - 1 - md];
if(op(b, a)) {
A_ed = md;
} else {
A_st = md + 1;
}
}
return A_ed;
}
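// Worked example (illustrative): As = {1, 3, 5}, Bs = {2, 4, 6}, sort_md = 3.
// The binary search above returns 2, so the first sort_md merged elements are
// As[0:2] = {1, 3} and Bs[0:1] = {2}; the caller then starts its local merge
// at As + 2 and Bs + 1.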
static METAL_FUNC void merge_step(
const threadgroup val_t* As,
const threadgroup val_t* Bs,
const threadgroup idx_t* As_idx,
const threadgroup idx_t* Bs_idx,
short A_sz,
short B_sz,
thread val_t (&vals)[N_PER_THREAD],
thread idx_t (&idxs)[N_PER_THREAD]) {
CompareOp op;
short a_idx = 0;
short b_idx = 0;
for(int i = 0; i < N_PER_THREAD; ++i) {
auto a = As[a_idx];
auto b = Bs[b_idx];
bool pred = (b_idx < B_sz) && (a_idx >= A_sz || op(b, a));
vals[i] = pred ? b : a;
idxs[i] = pred ? Bs_idx[b_idx] : As_idx[a_idx];
b_idx += short(pred);
a_idx += short(!pred);
}
}
static METAL_FUNC void sort(
threadgroup val_t* tgp_vals [[threadgroup(0)]],
threadgroup idx_t* tgp_idxs [[threadgroup(1)]],
int size_sorted_axis,
uint3 lid [[thread_position_in_threadgroup]]) {
// Get thread location
int idx = lid.x * N_PER_THREAD;
// Load from shared memory
thread val_t thread_vals[N_PER_THREAD];
thread idx_t thread_idxs[N_PER_THREAD];
for(int i = 0; i < N_PER_THREAD; ++i) {
thread_vals[i] = tgp_vals[idx + i];
if(ARG_SORT) {
thread_idxs[i] = tgp_idxs[idx + i];
}
}
// Per thread sort
if(idx < size_sorted_axis) {
thread_sort_t::sort(thread_vals, thread_idxs);
}
// Do merges using threadgroup memory
for (int merge_threads = 2; merge_threads <= BLOCK_THREADS; merge_threads *= 2) {
// Update threadgroup memory
threadgroup_barrier(mem_flags::mem_threadgroup);
for(int i = 0; i < N_PER_THREAD; ++i) {
tgp_vals[idx + i] = thread_vals[i];
if(ARG_SORT) {
tgp_idxs[idx + i] = thread_idxs[i];
}
}
threadgroup_barrier(mem_flags::mem_threadgroup);
// Find location in merge step
int merge_group = lid.x / merge_threads;
int merge_lane = lid.x % merge_threads;
int sort_sz = N_PER_THREAD * merge_threads;
int sort_st = N_PER_THREAD * merge_threads * merge_group;
// As = tgp_vals[A_st:A_ed] is sorted
// Bs = tgp_vals[B_st:B_ed] is sorted
int A_st = sort_st;
int A_ed = sort_st + sort_sz / 2;
int B_st = sort_st + sort_sz / 2;
int B_ed = sort_st + sort_sz;
const threadgroup val_t* As = tgp_vals + A_st;
const threadgroup val_t* Bs = tgp_vals + B_st;
int A_sz = A_ed - A_st;
int B_sz = B_ed - B_st;
// Find a partition of merge elements
// Ci = merge(As[partition:], Bs[sort_md - partition:])
// of size N_PER_THREAD for each merge lane i
// C = [Ci] is sorted
int sort_md = N_PER_THREAD * merge_lane;
int partition = merge_partition(
As,
Bs,
A_sz,
B_sz,
sort_md);
As += partition;
Bs += sort_md - partition;
A_sz -= partition;
B_sz -= sort_md - partition;
const threadgroup idx_t* As_idx = ARG_SORT ? tgp_idxs + A_st + partition : nullptr;
const threadgroup idx_t* Bs_idx = ARG_SORT ? tgp_idxs + B_st + sort_md - partition : nullptr;
// Merge starting at the partition and store results in thread registers
merge_step(
As,
Bs,
As_idx,
Bs_idx,
A_sz,
B_sz,
thread_vals,
thread_idxs);
}
// Write out to shared memory
threadgroup_barrier(mem_flags::mem_threadgroup);
for(int i = 0; i < N_PER_THREAD; ++i) {
tgp_vals[idx + i] = thread_vals[i];
if(ARG_SORT) {
tgp_idxs[idx + i] = thread_idxs[i];
}
}
}
};
///////////////////////////////////////////////////////////////////////////////
// Kernel sort
///////////////////////////////////////////////////////////////////////////////
template <
typename T,
typename U,
bool ARG_SORT,
short BLOCK_THREADS,
short N_PER_THREAD,
typename CompareOp = LessThan<T>>
struct KernelMergeSort {
using val_t = T;
using idx_t = uint;
using block_merge_sort_t = BlockMergeSort<
val_t,
idx_t,
ARG_SORT,
BLOCK_THREADS,
N_PER_THREAD,
CompareOp>;
MLX_MTL_CONST short N_PER_BLOCK = BLOCK_THREADS * N_PER_THREAD;
static METAL_FUNC void block_sort(
const device T* inp,
device U* out,
const constant int& size_sorted_axis,
const constant int& stride_sorted_axis,
const constant int& stride_segment_axis,
threadgroup val_t* tgp_vals,
threadgroup idx_t* tgp_idxs,
uint3 tid [[threadgroup_position_in_grid]],
uint3 lid [[thread_position_in_threadgroup]]) {
// tid.y tells us the segment index
inp += tid.y * stride_segment_axis;
out += tid.y * stride_segment_axis;
// Copy into threadgroup memory
for(short i = lid.x; i < N_PER_BLOCK; i+= BLOCK_THREADS) {
tgp_vals[i] = i < size_sorted_axis ? inp[i * stride_sorted_axis] : val_t(CompareOp::init);
if(ARG_SORT) {
tgp_idxs[i] = i;
}
}
// Sort elements within the block
threadgroup_barrier(mem_flags::mem_threadgroup);
block_merge_sort_t::sort(tgp_vals, tgp_idxs, size_sorted_axis, lid);
threadgroup_barrier(mem_flags::mem_threadgroup);
// Write output
for(int i = lid.x; i < size_sorted_axis; i+= BLOCK_THREADS) {
if(ARG_SORT) {
out[i * stride_sorted_axis] = tgp_idxs[i];
} else {
out[i * stride_sorted_axis] = tgp_vals[i];
}
}
}
};
template <
typename T,
typename U,
bool ARG_SORT,
short BLOCK_THREADS,
short N_PER_THREAD>
[[kernel, max_total_threads_per_threadgroup(BLOCK_THREADS)]] void block_sort(
const device T* inp [[buffer(0)]],
device U* out [[buffer(1)]],
const constant int& size_sorted_axis [[buffer(2)]],
const constant int& stride_sorted_axis [[buffer(3)]],
const constant int& stride_segment_axis [[buffer(4)]],
uint3 tid [[threadgroup_position_in_grid]],
uint3 lid [[thread_position_in_threadgroup]]) {
using sort_kernel = KernelMergeSort<T, U, ARG_SORT, BLOCK_THREADS, N_PER_THREAD>;
using val_t = typename sort_kernel::val_t;
using idx_t = typename sort_kernel::idx_t;
if(ARG_SORT) {
threadgroup val_t tgp_vals[sort_kernel::N_PER_BLOCK];
threadgroup idx_t tgp_idxs[sort_kernel::N_PER_BLOCK];
sort_kernel::block_sort(
inp,
out,
size_sorted_axis,
stride_sorted_axis,
stride_segment_axis,
tgp_vals,
tgp_idxs,
tid,
lid);
} else {
threadgroup val_t tgp_vals[sort_kernel::N_PER_BLOCK];
sort_kernel::block_sort(
inp,
out,
size_sorted_axis,
stride_sorted_axis,
stride_segment_axis,
tgp_vals,
nullptr,
tid,
lid);
}
}
constant constexpr const int zero_helper = 0;
template <
typename T,
typename U,
bool ARG_SORT,
short BLOCK_THREADS,
short N_PER_THREAD>
[[kernel, max_total_threads_per_threadgroup(BLOCK_THREADS)]] void block_sort_nc(
const device T* inp [[buffer(0)]],
device U* out [[buffer(1)]],
const constant int& size_sorted_axis [[buffer(2)]],
const constant int& stride_sorted_axis [[buffer(3)]],
const constant int& nc_dim [[buffer(4)]],
const device int* nc_shape [[buffer(5)]],
const device size_t* nc_strides [[buffer(6)]],
uint3 tid [[threadgroup_position_in_grid]],
uint3 lid [[thread_position_in_threadgroup]]) {
using sort_kernel = KernelMergeSort<T, U, ARG_SORT, BLOCK_THREADS, N_PER_THREAD>;
using val_t = typename sort_kernel::val_t;
using idx_t = typename sort_kernel::idx_t;
auto block_idx = elem_to_loc(tid.y, nc_shape, nc_strides, nc_dim);
inp += block_idx;
out += block_idx;
if(ARG_SORT) {
threadgroup val_t tgp_vals[sort_kernel::N_PER_BLOCK];
threadgroup idx_t tgp_idxs[sort_kernel::N_PER_BLOCK];
sort_kernel::block_sort(
inp,
out,
size_sorted_axis,
stride_sorted_axis,
zero_helper,
tgp_vals,
tgp_idxs,
tid,
lid);
} else {
threadgroup val_t tgp_vals[sort_kernel::N_PER_BLOCK];
sort_kernel::block_sort(
inp,
out,
size_sorted_axis,
stride_sorted_axis,
zero_helper,
tgp_vals,
nullptr,
tid,
lid);
}
}
///////////////////////////////////////////////////////////////////////////////
// Instantiations
///////////////////////////////////////////////////////////////////////////////
#define instantiate_block_sort(name, itname, itype, otname, otype, arg_sort, bn, tn) \
template [[host_name(#name "_" #itname "_" #otname "_bn" #bn "_tn" #tn)]] \
[[kernel]] void block_sort<itype, otype, arg_sort, bn, tn>( \
const device itype* inp [[buffer(0)]], \
device otype* out [[buffer(1)]], \
const constant int& size_sorted_axis [[buffer(2)]], \
const constant int& stride_sorted_axis [[buffer(3)]], \
const constant int& stride_segment_axis [[buffer(4)]], \
uint3 tid [[threadgroup_position_in_grid]], \
uint3 lid [[thread_position_in_threadgroup]]); \
template [[host_name(#name "_" #itname "_" #otname "_bn" #bn "_tn" #tn "_nc")]] \
[[kernel]] void block_sort_nc<itype, otype, arg_sort, bn, tn>( \
const device itype* inp [[buffer(0)]], \
device otype* out [[buffer(1)]], \
const constant int& size_sorted_axis [[buffer(2)]], \
const constant int& stride_sorted_axis [[buffer(3)]], \
const constant int& nc_dim [[buffer(4)]], \
const device int* nc_shape [[buffer(5)]], \
const device size_t* nc_strides [[buffer(6)]], \
uint3 tid [[threadgroup_position_in_grid]], \
uint3 lid [[thread_position_in_threadgroup]]);
#define instantiate_arg_block_sort_base(itname, itype, bn, tn) \
instantiate_block_sort(arg_block_merge_sort, itname, itype, uint32, uint32_t, true, bn, tn)
#define instantiate_block_sort_base(itname, itype, bn, tn) \
instantiate_block_sort(block_merge_sort, itname, itype, itname, itype, false, bn, tn)
#define instantiate_block_sort_tn(itname, itype, bn) \
instantiate_block_sort_base(itname, itype, bn, 8) \
instantiate_arg_block_sort_base(itname, itype, bn, 8)
#define instantiate_block_sort_bn(itname, itype) \
instantiate_block_sort_tn(itname, itype, 128) \
instantiate_block_sort_tn(itname, itype, 256) \
instantiate_block_sort_tn(itname, itype, 512)
instantiate_block_sort_bn(uint8, uint8_t)
instantiate_block_sort_bn(uint16, uint16_t)
instantiate_block_sort_bn(uint32, uint32_t)
instantiate_block_sort_bn(int8, int8_t)
instantiate_block_sort_bn(int16, int16_t)
instantiate_block_sort_bn(int32, int32_t)
instantiate_block_sort_bn(float16, half)
instantiate_block_sort_bn(float32, float)
instantiate_block_sort_bn(bfloat16, bfloat16_t)
#define instantiate_block_sort_long(itname, itype) \
instantiate_block_sort_tn(itname, itype, 128) \
instantiate_block_sort_tn(itname, itype, 256)
instantiate_block_sort_long(uint64, uint64_t)
instantiate_block_sort_long(int64, int64_t)
///////////////////////////////////////////////////////////////////////////////
// Multi block merge sort
///////////////////////////////////////////////////////////////////////////////
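// Multi-block merge sort proceeds in three passes:
//   1. mb_block_sort:     sort independent tiles of N_PER_BLOCK elements and
//                         write the sorted values and their source indices
//   2. mb_block_partiton: compute merge-path split points between pairs of
//                         already-sorted runs
//   3. mb_block_merge:    merge each pair of runs tile by tile
// Presumably the host repeats passes 2 and 3 with a doubling merge_tiles until
// the whole axis forms a single sorted run (the driver code is not shown here).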
template <
typename val_t,
typename idx_t,
bool ARG_SORT,
short BLOCK_THREADS,
short N_PER_THREAD,
typename CompareOp = LessThan<val_t>>
struct KernelMultiBlockMergeSort {
using block_merge_sort_t = BlockMergeSort<
val_t,
idx_t,
ARG_SORT,
BLOCK_THREADS,
N_PER_THREAD,
CompareOp>;
MLX_MTL_CONST short N_PER_BLOCK = BLOCK_THREADS * N_PER_THREAD;
static METAL_FUNC void block_sort(
const device val_t* inp,
device val_t* out_vals,
device idx_t* out_idxs,
const constant int& size_sorted_axis,
const constant int& stride_sorted_axis,
threadgroup val_t* tgp_vals,
threadgroup idx_t* tgp_idxs,
uint3 tid [[threadgroup_position_in_grid]],
uint3 lid [[thread_position_in_threadgroup]]) {
// tid.y tells us the segment index
int base_idx = tid.x * N_PER_BLOCK;
// Copy into threadgroup memory
for(short i = lid.x; i < N_PER_BLOCK; i+= BLOCK_THREADS) {
int idx = base_idx + i;
tgp_vals[i] = idx < size_sorted_axis ? inp[idx * stride_sorted_axis] : val_t(CompareOp::init);
tgp_idxs[i] = idx;
}
// Sort elements within the block
threadgroup_barrier(mem_flags::mem_threadgroup);
block_merge_sort_t::sort(tgp_vals, tgp_idxs, size_sorted_axis, lid);
threadgroup_barrier(mem_flags::mem_threadgroup);
// Write output
for(int i = lid.x; i < N_PER_BLOCK; i+= BLOCK_THREADS) {
int idx = base_idx + i;
if(idx < size_sorted_axis) {
out_vals[idx] = tgp_vals[i];
out_idxs[idx] = tgp_idxs[i];
}
}
}
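// Merge-path partition: given sorted runs A (length A_sz) and B (length B_sz)
// and a combined rank sort_md, binary search for how many elements come from A
// so that the first sort_md merged elements are A[0:result] and
// B[0:sort_md - result].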
static METAL_FUNC int merge_partition(
const device val_t* As,
const device val_t* Bs,
int A_sz,
int B_sz,
int sort_md) {
CompareOp op;
int A_st = max(0, sort_md - B_sz);
int A_ed = min(sort_md, A_sz);
while(A_st < A_ed) {
int md = A_st + (A_ed - A_st) / 2;
auto a = As[md];
auto b = Bs[sort_md - 1 - md];
if(op(b, a)) {
A_ed = md;
} else {
A_st = md + 1;
}
}
return A_ed;
}
};
template <
typename val_t,
typename idx_t,
bool ARG_SORT,
short BLOCK_THREADS,
short N_PER_THREAD>
[[kernel, max_total_threads_per_threadgroup(BLOCK_THREADS)]] void mb_block_sort(
const device val_t* inp [[buffer(0)]],
device val_t* out_vals [[buffer(1)]],
device idx_t* out_idxs [[buffer(2)]],
const constant int& size_sorted_axis [[buffer(3)]],
const constant int& stride_sorted_axis [[buffer(4)]],
const constant int& nc_dim [[buffer(5)]],
const device int* nc_shape [[buffer(6)]],
const device size_t* nc_strides [[buffer(7)]],
uint3 tid [[threadgroup_position_in_grid]],
uint3 lid [[thread_position_in_threadgroup]]) {
using sort_kernel = KernelMultiBlockMergeSort<val_t, idx_t, ARG_SORT, BLOCK_THREADS, N_PER_THREAD>;
auto block_idx = elem_to_loc(tid.y, nc_shape, nc_strides, nc_dim);
inp += block_idx;
out_vals += tid.y * size_sorted_axis;
out_idxs += tid.y * size_sorted_axis;
threadgroup val_t tgp_vals[sort_kernel::N_PER_BLOCK];
threadgroup idx_t tgp_idxs[sort_kernel::N_PER_BLOCK];
sort_kernel::block_sort(
inp,
out_vals,
out_idxs,
size_sorted_axis,
stride_sorted_axis,
tgp_vals,
tgp_idxs,
tid,
lid);
}
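// Pass 2 of the multi-block sort: each thread computes one merge-path split
// point within its pair of sorted runs so that every threadgroup in the merge
// pass consumes at most N_PER_BLOCK elements.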
template <
typename val_t,
typename idx_t,
bool ARG_SORT,
short BLOCK_THREADS,
short N_PER_THREAD>
[[kernel, max_total_threads_per_threadgroup(BLOCK_THREADS)]] void mb_block_partiton(
device idx_t* block_partitions [[buffer(0)]],
const device val_t* dev_vals [[buffer(1)]],
const device idx_t* dev_idxs [[buffer(2)]],
const constant int& size_sorted_axis [[buffer(3)]],
const constant int& merge_tiles [[buffer(4)]],
uint3 tid [[threadgroup_position_in_grid]],
uint3 lid [[thread_position_in_threadgroup]],
uint3 tgp_dims [[threads_per_threadgroup]]) {
using sort_kernel = KernelMultiBlockMergeSort<
val_t,
idx_t,
ARG_SORT,
BLOCK_THREADS,
N_PER_THREAD>;
block_partitions += tid.y * tgp_dims.x;
dev_vals += tid.y * size_sorted_axis;
dev_idxs += tid.y * size_sorted_axis;
// Find location in merge step
int merge_group = lid.x / merge_tiles;
int merge_lane = lid.x % merge_tiles;
int sort_sz = sort_kernel::N_PER_BLOCK * merge_tiles;
int sort_st = sort_kernel::N_PER_BLOCK * merge_tiles * merge_group;
int A_st = min(size_sorted_axis, sort_st);
int A_ed = min(size_sorted_axis, sort_st + sort_sz / 2);
int B_st = A_ed;
int B_ed = min(size_sorted_axis, B_st + sort_sz / 2);
int partition_at = min(B_ed - A_st, sort_kernel::N_PER_BLOCK * merge_lane);
int partition = sort_kernel::merge_partition(
dev_vals + A_st,
dev_vals + B_st,
A_ed - A_st,
B_ed - B_st,
partition_at);
block_partitions[lid.x] = A_st + partition;
}
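// Pass 3 of the multi-block sort: each threadgroup loads its slice of the two
// runs (bounded by the precomputed block partitions), merges it in threadgroup
// memory, and writes N_PER_BLOCK (or fewer, at the tail) sorted elements back
// to global memory.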
template <
typename val_t,
typename idx_t,
bool ARG_SORT,
short BLOCK_THREADS,
short N_PER_THREAD,
typename CompareOp = LessThan<val_t>>
[[kernel, max_total_threads_per_threadgroup(BLOCK_THREADS)]] void mb_block_merge(
const device idx_t* block_partitions [[buffer(0)]],
const device val_t* dev_vals_in [[buffer(1)]],
const device idx_t* dev_idxs_in [[buffer(2)]],
device val_t* dev_vals_out [[buffer(3)]],
device idx_t* dev_idxs_out [[buffer(4)]],
const constant int& size_sorted_axis [[buffer(5)]],
const constant int& merge_tiles [[buffer(6)]],
const constant int& num_tiles [[buffer(7)]],
uint3 tid [[threadgroup_position_in_grid]],
uint3 lid [[thread_position_in_threadgroup]]) {
using sort_kernel = KernelMultiBlockMergeSort<
val_t,
idx_t,
ARG_SORT,
BLOCK_THREADS,
N_PER_THREAD,
CompareOp>;
using block_sort_t = typename sort_kernel::block_merge_sort_t;
block_partitions += tid.y * (num_tiles + 1);
dev_vals_in += tid.y * size_sorted_axis;
dev_idxs_in += tid.y * size_sorted_axis;
dev_vals_out += tid.y * size_sorted_axis;
dev_idxs_out += tid.y * size_sorted_axis;
int block_idx = tid.x;
int merge_group = block_idx / merge_tiles;
int sort_st = sort_kernel::N_PER_BLOCK * merge_tiles * merge_group;
int sort_sz = sort_kernel::N_PER_BLOCK * merge_tiles;
int sort_md = sort_kernel::N_PER_BLOCK * block_idx - sort_st;
int A_st = block_partitions[block_idx + 0];
int A_ed = block_partitions[block_idx + 1];
int B_st = min(size_sorted_axis, 2 * sort_st + sort_sz/2 + sort_md - A_st);
int B_ed = min(size_sorted_axis, 2 * sort_st + sort_sz/2 + sort_md + sort_kernel::N_PER_BLOCK - A_ed);
if((block_idx % merge_tiles) == merge_tiles - 1) {
A_ed = min(size_sorted_axis, sort_st + sort_sz/2);
B_ed = min(size_sorted_axis, sort_st + sort_sz);
}
int A_sz = A_ed - A_st;
int B_sz = B_ed - B_st;
// Load from global memory
thread val_t thread_vals[N_PER_THREAD];
thread idx_t thread_idxs[N_PER_THREAD];
for(int i = 0; i < N_PER_THREAD; i++) {
int idx = BLOCK_THREADS * i + lid.x;
if(idx < (A_sz + B_sz)) {
thread_vals[i] = (idx < A_sz) ? dev_vals_in[A_st + idx] : dev_vals_in[B_st + idx - A_sz];
thread_idxs[i] = (idx < A_sz) ? dev_idxs_in[A_st + idx] : dev_idxs_in[B_st + idx - A_sz];
} else {
thread_vals[i] = CompareOp::init;
thread_idxs[i] = 0;
}
}
// Write to shared memory
threadgroup val_t tgp_vals[sort_kernel::N_PER_BLOCK];
threadgroup idx_t tgp_idxs[sort_kernel::N_PER_BLOCK];
threadgroup_barrier(mem_flags::mem_threadgroup);
for(int i = 0; i < N_PER_THREAD; i++) {
int idx = BLOCK_THREADS * i + lid.x;
tgp_vals[idx] = thread_vals[i];
tgp_idxs[idx] = thread_idxs[i];
}
threadgroup_barrier(mem_flags::mem_threadgroup);
// Merge
int sort_md_local = min(A_sz + B_sz, N_PER_THREAD * int(lid.x));
int A_st_local = block_sort_t::merge_partition(
tgp_vals,
tgp_vals + A_sz,
A_sz,
B_sz,
sort_md_local);
int A_ed_local = A_sz;
int B_st_local = sort_md_local - A_st_local;
int B_ed_local = B_sz;
int A_sz_local = A_ed_local - A_st_local;
int B_sz_local = B_ed_local - B_st_local;
// Do merge
block_sort_t::merge_step(
tgp_vals + A_st_local,
tgp_vals + A_ed_local + B_st_local,
tgp_idxs + A_st_local,
tgp_idxs + A_ed_local + B_st_local,
A_sz_local,
B_sz_local,
thread_vals,
thread_idxs);
threadgroup_barrier(mem_flags::mem_threadgroup);
for(int i = 0; i < N_PER_THREAD; ++i) {
int idx = lid.x * N_PER_THREAD;
tgp_vals[idx + i] = thread_vals[i];
tgp_idxs[idx + i] = thread_idxs[i];
}
threadgroup_barrier(mem_flags::mem_threadgroup);
// Write output
int base_idx = tid.x * sort_kernel::N_PER_BLOCK;
for(int i = lid.x; i < sort_kernel::N_PER_BLOCK; i+= BLOCK_THREADS) {
int idx = base_idx + i;
if(idx < size_sorted_axis) {
dev_vals_out[idx] = tgp_vals[i];
dev_idxs_out[idx] = tgp_idxs[i];
}
}
}
#define instantiate_multi_block_sort(vtname, vtype, itname, itype, arg_sort, bn, tn) \
template [[host_name("mb_block_sort_" #vtname "_" #itname "_bn" #bn "_tn" #tn)]] \
[[kernel]] void mb_block_sort<vtype, itype, arg_sort, bn, tn>( \
const device vtype* inp [[buffer(0)]], \
device vtype* out_vals [[buffer(1)]], \
device itype* out_idxs [[buffer(2)]], \
const constant int& size_sorted_axis [[buffer(3)]], \
const constant int& stride_sorted_axis [[buffer(4)]], \
const constant int& nc_dim [[buffer(5)]], \
const device int* nc_shape [[buffer(6)]], \
const device size_t* nc_strides [[buffer(7)]], \
uint3 tid [[threadgroup_position_in_grid]], \
uint3 lid [[thread_position_in_threadgroup]]); \
template [[host_name("mb_block_partiton_" #vtname "_" #itname "_bn" #bn "_tn" #tn)]] \
[[kernel]] void mb_block_partiton<vtype, itype, arg_sort, bn, tn>( \
device itype* block_partitions [[buffer(0)]], \
const device vtype* dev_vals [[buffer(1)]], \
const device itype* dev_idxs [[buffer(2)]], \
const constant int& size_sorted_axis [[buffer(3)]], \
const constant int& merge_tiles [[buffer(4)]], \
uint3 tid [[threadgroup_position_in_grid]], \
uint3 lid [[thread_position_in_threadgroup]], \
uint3 tgp_dims [[threads_per_threadgroup]]); \
template [[host_name("mb_block_merge_" #vtname "_" #itname "_bn" #bn "_tn" #tn)]] \
[[kernel]] void mb_block_merge<vtype, itype, arg_sort, bn, tn>( \
const device itype* block_partitions [[buffer(0)]], \
const device vtype* dev_vals_in [[buffer(1)]], \
const device itype* dev_idxs_in [[buffer(2)]], \
device vtype* dev_vals_out [[buffer(3)]], \
device itype* dev_idxs_out [[buffer(4)]], \
const constant int& size_sorted_axis [[buffer(5)]], \
const constant int& merge_tiles [[buffer(6)]], \
const constant int& num_tiles [[buffer(7)]], \
uint3 tid [[threadgroup_position_in_grid]], \
uint3 lid [[thread_position_in_threadgroup]]);
#define instantiate_multi_block_sort_base(vtname, vtype) \
instantiate_multi_block_sort(vtname, vtype, uint32, uint32_t, true, 512, 8)
instantiate_multi_block_sort_base(uint8, uint8_t)
instantiate_multi_block_sort_base(uint16, uint16_t)
instantiate_multi_block_sort_base(uint32, uint32_t)
instantiate_multi_block_sort_base(int8, int8_t)
instantiate_multi_block_sort_base(int16, int16_t)
instantiate_multi_block_sort_base(int32, int32_t)
instantiate_multi_block_sort_base(float16, half)
instantiate_multi_block_sort_base(float32, float)
instantiate_multi_block_sort_base(bfloat16, bfloat16_t)
#define instantiate_multi_block_sort_long(vtname, vtype) \
instantiate_multi_block_sort(vtname, vtype, uint32, uint32_t, true, 256, 8)
instantiate_multi_block_sort_long(uint64, uint64_t)
instantiate_multi_block_sort_long(int64, int64_t)

View File

@@ -0,0 +1,244 @@
#pragma once
#include <metal_math>
#include "mlx/backend/metal/kernels/bf16.h"
#include "mlx/backend/metal/kernels/complex.h"
///////////////////////////////////////////////////////////////////////////////
// Type limits utils
///////////////////////////////////////////////////////////////////////////////
template <typename U>
struct Limits {
static const constant U max;
static const constant U min;
static const constant U finite_max;
static const constant U finite_min;
};
#define instantiate_default_limit(type) \
template <> \
struct Limits<type> { \
static constexpr constant type max = metal::numeric_limits<type>::max(); \
static constexpr constant type min = metal::numeric_limits<type>::min(); \
static constexpr constant type finite_max = \
metal::numeric_limits<type>::max(); \
static constexpr constant type finite_min = \
metal::numeric_limits<type>::min(); \
};
instantiate_default_limit(uint8_t);
instantiate_default_limit(uint16_t);
instantiate_default_limit(uint32_t);
instantiate_default_limit(uint64_t);
instantiate_default_limit(int8_t);
instantiate_default_limit(int16_t);
instantiate_default_limit(int32_t);
instantiate_default_limit(int64_t);
#define instantiate_float_limit(type) \
template <> \
struct Limits<type> { \
static constexpr constant type max = \
metal::numeric_limits<type>::infinity(); \
static constexpr constant type min = \
-metal::numeric_limits<type>::infinity(); \
static constexpr constant type finite_max = \
metal::numeric_limits<type>::max(); \
static constexpr constant type finite_min = \
-metal::numeric_limits<type>::max(); \
};
instantiate_float_limit(half);
instantiate_float_limit(float);
instantiate_float_limit(bfloat16_t);
template <>
struct Limits<bool> {
static constexpr constant bool max = true;
static constexpr constant bool min = false;
};
///////////////////////////////////////////////////////////////////////////////
// Indexing utils
///////////////////////////////////////////////////////////////////////////////
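// elem_to_loc maps a flat row-major element index to a memory offset for an
// arbitrarily strided array: peel coordinates off starting from the innermost
// dimension and accumulate coordinate * stride.
// e.g. shape = {2, 3}, strides = {1, 2}: elem 4 -> coords (1, 1) -> 1*1 + 1*2 = 3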
inline size_t elem_to_loc(
uint elem,
device const int* shape,
device const size_t* strides,
int ndim) {
size_t loc = 0;
for (int i = ndim - 1; i >= 0; --i) {
loc += (elem % shape[i]) * strides[i];
elem /= shape[i];
}
return loc;
}
inline size_t elem_to_loc(
uint elem,
constant const int* shape,
constant const size_t* strides,
int ndim) {
size_t loc = 0;
for (int i = ndim - 1; i >= 0; --i) {
loc += (elem % shape[i]) * strides[i];
elem /= shape[i];
}
return loc;
}
template <int NDIM>
inline uint2 elem_to_loc_2_nd(
uint3 elem,
constant const int shape[NDIM],
constant const size_t a_strides[NDIM],
constant const size_t b_strides[NDIM]) {
uint2 loc = {
static_cast<uint>(
elem.x * a_strides[NDIM - 1] + elem.y * a_strides[NDIM - 2]),
static_cast<uint>(
elem.x * b_strides[NDIM - 1] + elem.y * b_strides[NDIM - 2])};
for (int d = NDIM - 3; d >= 0; --d) {
uint l = elem.z % shape[d];
loc.x += l * a_strides[d];
loc.y += l * b_strides[d];
elem.z /= shape[d];
}
return loc;
}
template <int NDIM>
inline size_t elem_to_loc_nd(
uint3 elem,
constant const int shape[NDIM],
constant const size_t strides[NDIM]) {
size_t loc = elem.x * strides[NDIM - 1] + elem.y * strides[NDIM - 2];
for (int d = NDIM - 3; d >= 0; --d) {
loc += (elem.z % shape[d]) * strides[d];
elem.z /= shape[d];
}
return loc;
}
inline size_t elem_to_loc_1(uint elem, constant const size_t& stride) {
return elem * stride;
}
inline size_t elem_to_loc_2(uint2 elem, constant const size_t strides[2]) {
return elem.x * strides[1] + elem.y * strides[0];
}
inline size_t elem_to_loc_3(uint3 elem, constant const size_t strides[3]) {
return elem.x * strides[2] + elem.y * strides[1] + elem.z * strides[0];
}
// Non templated version to handle arbitrary dims
inline size_t elem_to_loc(
uint3 elem,
constant const int* shape,
constant const size_t* strides,
int ndim) {
size_t loc = elem.x * strides[ndim - 1] + elem.y * strides[ndim - 2];
for (int d = ndim - 3; d >= 0; --d) {
loc += (elem.z % shape[d]) * strides[d];
elem.z /= shape[d];
}
return loc;
}
inline uint2 elem_to_loc_2_nd(
uint3 elem,
constant const int* shape,
constant const size_t* a_strides,
constant const size_t* b_strides,
int ndim) {
uint2 loc = {
static_cast<uint>(
elem.x * a_strides[ndim - 1] + elem.y * a_strides[ndim - 2]),
static_cast<uint>(
elem.x * b_strides[ndim - 1] + elem.y * b_strides[ndim - 2])};
for (int d = ndim - 3; d >= 0; --d) {
uint l = elem.z % shape[d];
loc.x += l * a_strides[d];
loc.y += l * b_strides[d];
elem.z /= shape[d];
}
return loc;
}
template <int NDIM>
inline uint elem_to_loc_nd(
uint elem,
device const int* shape,
device const size_t* strides);
template <>
inline uint elem_to_loc_nd<1>(
uint elem,
device const int* shape,
device const size_t* strides) {
return (elem % shape[0]) * strides[0];
}
template <>
inline uint elem_to_loc_nd<2>(
uint elem,
device const int* shape,
device const size_t* strides) {
uint loc = (elem % shape[1]) * strides[1];
elem /= shape[1];
loc += (elem % shape[0]) * strides[0];
return loc;
}
template <>
inline uint elem_to_loc_nd<3>(
uint elem,
device const int* shape,
device const size_t* strides) {
uint loc = (elem % shape[2]) * strides[2];
elem /= shape[2];
loc += (elem % shape[1]) * strides[1];
elem /= shape[1];
loc += (elem % shape[0]) * strides[0];
return loc;
}
template <>
inline uint elem_to_loc_nd<4>(
uint elem,
device const int* shape,
device const size_t* strides) {
uint loc = (elem % shape[3]) * strides[3];
elem /= shape[3];
loc += (elem % shape[2]) * strides[2];
elem /= shape[2];
loc += (elem % shape[1]) * strides[1];
elem /= shape[1];
loc += (elem % shape[0]) * strides[0];
return loc;
}
///////////////////////////////////////////////////////////////////////////////
// Calculation utils
///////////////////////////////////////////////////////////////////////////////
/** Compute ceil((float)N/(float)M) */
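// e.g. ceildiv(10, 4) == 3 and ceildiv(8, 4) == 2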
inline size_t ceildiv(size_t N, size_t M) {
return (N + M - 1) / M;
}
// https://docs.oracle.com/cd/E19957-01/806-3568/ncg_goldberg.html#1202
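// Accurate log(1 + x) for small x: when (1 + x) rounds to exactly 1 the best
// float answer is x itself; otherwise x * log(xp1) / (xp1 - 1) rescales
// log(xp1) to cancel the rounding error introduced when forming xp1.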
inline float log1p(float x) {
float xp1 = 1.0f + x;
return (xp1 == 1.0f) ? x : x * (metal::log(xp1) / (xp1 - 1.0f));
}
inline bfloat16_t log1p(bfloat16_t x) {
float xp1 = 1.0f + static_cast<float>(x);
bfloat16_t ret =
(xp1 == 1.0f) ? x : bfloat16_t(x * (metal::log(xp1) / (xp1 - 1.0f)));
return ret;
}

View File

@@ -0,0 +1,446 @@
#include <algorithm>
#include <cassert>
#include <numeric>
#include <sstream>
#include "mlx/backend/metal/copy.h"
#include "mlx/backend/metal/device.h"
#include "mlx/backend/metal/kernels/defines.h"
#include "mlx/backend/metal/matmul.h"
#include "mlx/backend/metal/mps/gemm.h"
#include "mlx/backend/metal/utils.h"
#include "mlx/primitives.h"
#include "mlx/utils.h"
namespace mlx::core {
namespace {
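// Route matmuls through MPSMatrixMultiplication only when the MLX_USE_MPS
// environment variable is set (to anything other than "OFF"); otherwise the
// custom Metal GEMM kernels are used.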
bool use_mps() {
auto get_val = []() {
if (const char* buff_str = std::getenv("MLX_USE_MPS")) {
return std::string(buff_str) != "OFF";
} else {
return false;
}
};
static bool use_mps_ = get_val();
return use_mps_;
}
#define MAX_OPS_PER_BUFFER max_ops_per_buffer()
inline void mps_matmul(
const Stream& s,
metal::Device& d,
const array& a,
const array& b,
array& out,
int M,
int N,
int K,
int batch_size_out,
int lda,
int ldb,
bool transpose_a,
bool transpose_b,
std::vector<array>& copies) {
MPS::DataType mps_dtype = MPS::DataTypeFloat32;
if (out.dtype() == float16) {
mps_dtype = MPS::DataTypeFloat16;
} else if (out.dtype() == bfloat16) {
mps_dtype = MPS::DataTypeBFloat16;
}
// Use batched MPSMatrixMultiplication if batch_size_out > 1
// We only accept the following cases:
// 1. Both a and b have batch_size_out matrices worth of data
// 2. Only one of a or b has batch_size_out matrices worth of data and
//    the other has one matrix worth of data
// The matrix dimensions of a and b are sure to be regularly strided
if (batch_size_out > 1) {
// No broadcasting defaults
auto batch_size_a = a.data_size() / (M * K);
auto batch_size_b = b.data_size() / (K * N);
auto matrix_stride_a = M * K;
auto matrix_stride_b = K * N;
auto matrix_stride_out = M * N;
// At this point, batch_size_a, batch_size_b show the number of matrices
// in data, no broadcasted strides considered
if (batch_size_out == std::max(batch_size_a, batch_size_b)) {
// Handle simple broadcasting
if (std::min(batch_size_a, batch_size_b) == 1) {
matrix_stride_a = (batch_size_a == 1) ? 0 : matrix_stride_a;
matrix_stride_b = (batch_size_b == 1) ? 0 : matrix_stride_b;
batch_size_a = batch_size_out;
batch_size_b = batch_size_out;
}
// Only proceed if broadcasting between a and b is simple
// At this point, batch_size_a, batch_size_b show the number of matrices
// after broadcasting
if (batch_size_a == batch_size_b) {
auto a_desc = MPS::MatrixDescriptor::matrixDescriptor(
(M * K) / lda,
lda,
batch_size_a,
lda * a.itemsize(),
(matrix_stride_a * a.itemsize()),
mps_dtype);
auto b_desc = MPS::MatrixDescriptor::matrixDescriptor(
(K * N) / ldb,
ldb,
batch_size_b,
ldb * b.itemsize(),
(matrix_stride_b * b.itemsize()),
mps_dtype);
auto out_desc = MPS::MatrixDescriptor::matrixDescriptor(
M,
N,
batch_size_out,
N * out.itemsize(),
matrix_stride_out * out.itemsize(),
mps_dtype);
auto a_buf = static_cast<const MTL::Buffer*>(a.buffer().ptr());
auto a_mat = MPS::Matrix::alloc()->init(a_buf, a_desc);
auto b_buf = static_cast<const MTL::Buffer*>(b.buffer().ptr());
auto b_mat = MPS::Matrix::alloc()->init(b_buf, b_desc);
auto out_buf = static_cast<MTL::Buffer*>(out.buffer().ptr());
auto out_mat = MPS::Matrix::alloc()->init(out_buf, out_desc);
auto kernel = MPS::MatrixMultiplication::alloc()->init(
d.mtl_device(), transpose_a, transpose_b, M, N, K, 1.0, 0.0);
auto command_buffer = d.get_command_buffer(s.index);
kernel->setBatchSize(batch_size_out);
kernel->setBatchStart(0);
kernel->encodeToCommandBuffer(command_buffer, a_mat, b_mat, out_mat);
command_buffer->addCompletedHandler(
[a_mat, b_mat, out_mat, kernel, copies](
MTL::CommandBuffer*) mutable {
a_mat->release();
b_mat->release();
out_mat->release();
kernel->release();
copies.clear();
});
return;
}
}
}
// Schedule as many calls to MPSMatrixMultiplication as needed otherwise
auto a_desc = MPS::MatrixDescriptor::matrixDescriptor(
a.data_size() / lda, lda, lda * a.itemsize(), mps_dtype);
auto b_desc = MPS::MatrixDescriptor::matrixDescriptor(
b.data_size() / ldb, ldb, ldb * b.itemsize(), mps_dtype);
auto out_desc = MPS::MatrixDescriptor::matrixDescriptor(
batch_size_out * M, N, N * out.itemsize(), mps_dtype);
auto a_buf = static_cast<const MTL::Buffer*>(a.buffer().ptr());
auto a_mat = MPS::Matrix::alloc()->init(a_buf, a_desc);
auto b_buf = static_cast<const MTL::Buffer*>(b.buffer().ptr());
auto b_mat = MPS::Matrix::alloc()->init(b_buf, b_desc);
auto out_buf = static_cast<MTL::Buffer*>(out.buffer().ptr());
auto out_mat = MPS::Matrix::alloc()->init(out_buf, out_desc);
auto kernel = MPS::MatrixMultiplication::alloc()->init(
d.mtl_device(), transpose_a, transpose_b, M, N, K, 1.0, 0.0);
auto command_buffer = d.get_command_buffer(s.index);
for (int i = 0; i < batch_size_out; ++i) {
auto a_row = elem_to_loc(M * K * i, a.shape(), a.strides()) / lda;
auto b_row = elem_to_loc(K * N * i, b.shape(), b.strides()) / ldb;
kernel->setLeftMatrixOrigin({a_row, 0, 0});
kernel->setRightMatrixOrigin({b_row, 0, 0});
kernel->setResultMatrixOrigin({i * static_cast<size_t>(M), 0, 0});
kernel->encodeToCommandBuffer(command_buffer, a_mat, b_mat, out_mat);
}
command_buffer->addCompletedHandler(
[a_mat, b_mat, out_mat, kernel, copies](MTL::CommandBuffer*) mutable {
a_mat->release();
b_mat->release();
out_mat->release();
kernel->release();
copies.clear();
});
}
} // namespace
void mlx_matmul(
const Stream& s,
metal::Device& d,
const array& a,
const array& b,
array& out,
int M,
int N,
int K,
int batch_size_out,
int lda,
int ldb,
bool transpose_a,
bool transpose_b,
std::vector<array>& copies) {
// Account for batch sizes and basic broadcasting
int batch_size_a = a.data_size() / (M * K);
int batch_size_b = b.data_size() / (K * N);
int matrix_stride_a = (batch_size_a == 1) ? 0 : M * K;
int matrix_stride_b = (batch_size_b == 1) ? 0 : K * N;
int matrix_stride_out = M * N;
// Determine dispatch kernel
int bm = 32, bn = 32, bk = 16;
int wm = 2, wn = 2;
if ((size_t)batch_size_out * M * N >= 2ul << 20) {
if (!transpose_a && transpose_b) {
bm = 64;
bn = (out.dtype() == float32) ? 64 : 32;
bk = (out.dtype() == float32) ? 16 : 32;
} else {
bm = 64;
bn = 64;
}
}
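// The kernel name encodes the transposition pattern, the input/output dtypes,
// the tile sizes (bm, bn, bk), the simdgroups per threadgroup (wm, wn), and
// whether M/N and K are multiples of the tile sizes, e.g. a name like
// "gemm_nt_float32_float32_bm64_bn64_bk16_wm2_wn2_MN_taligned_K_taligned".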
std::ostringstream kname;
kname << "gemm_" << (transpose_a ? 't' : 'n') << (transpose_b ? 't' : 'n')
<< "_" << type_to_name(a) << "_" << type_to_name(out) << "_bm" << bm
<< "_bn" << bn << "_bk" << bk << "_wm" << wm << "_wn" << wn << "_MN_"
<< ((M % bm == 0 && N % bn == 0) ? "t" : "n") << "aligned"
<< "_K_" << ((K % bk == 0) ? "t" : "n") << "aligned";
// Encode and dispatch kernel
auto compute_encoder = d.get_command_encoder(s.index);
auto kernel = d.get_kernel(kname.str());
compute_encoder->setComputePipelineState(kernel);
// Launch only 1 kernel in the case of simple batching / broadcasting
if (batch_size_out == std::max(batch_size_a, batch_size_b) &&
(batch_size_a == batch_size_b ||
std::min(batch_size_a, batch_size_b) == 1)) {
MTL::Size group_dims = MTL::Size(32, wn, wm);
MTL::Size grid_dims =
MTL::Size((N + bn - 1) / bn, (M + bm - 1) / bm, batch_size_out);
set_array_buffer(compute_encoder, a, 0);
set_array_buffer(compute_encoder, b, 1);
set_array_buffer(compute_encoder, out, 2);
compute_encoder->setBytes(&M, sizeof(int), 3);
compute_encoder->setBytes(&N, sizeof(int), 4);
compute_encoder->setBytes(&K, sizeof(int), 5);
compute_encoder->setBytes(&matrix_stride_a, sizeof(int), 6);
compute_encoder->setBytes(&matrix_stride_b, sizeof(int), 7);
compute_encoder->setBytes(&matrix_stride_out, sizeof(int), 8);
compute_encoder->dispatchThreadgroups(grid_dims, group_dims);
} else { // Otherwise, launch one kernel per batch entry with explicit buffer offsets
for (int i = 0; i < batch_size_out; ++i) {
auto a_off = elem_to_loc(M * K * i, a.shape(), a.strides());
auto b_off = elem_to_loc(K * N * i, b.shape(), b.strides());
MTL::Size group_dims = MTL::Size(32, wn, wm);
MTL::Size grid_dims = MTL::Size((N + bn - 1) / bn, (M + bm - 1) / bm, 1);
auto a_buf = static_cast<const MTL::Buffer*>(a.buffer().ptr());
auto b_buf = static_cast<const MTL::Buffer*>(b.buffer().ptr());
auto out_buf = static_cast<const MTL::Buffer*>(out.buffer().ptr());
compute_encoder->setBuffer(a_buf, a_off * a.itemsize(), 0);
compute_encoder->setBuffer(b_buf, b_off * b.itemsize(), 1);
compute_encoder->setBuffer(out_buf, i * M * N * out.itemsize(), 2);
compute_encoder->setBytes(&M, sizeof(int), 3);
compute_encoder->setBytes(&N, sizeof(int), 4);
compute_encoder->setBytes(&K, sizeof(int), 5);
compute_encoder->setBytes(&matrix_stride_a, sizeof(int), 6);
compute_encoder->setBytes(&matrix_stride_b, sizeof(int), 7);
compute_encoder->setBytes(&matrix_stride_out, sizeof(int), 8);
compute_encoder->dispatchThreadgroups(grid_dims, group_dims);
}
}
d.get_command_buffer(s.index)->addCompletedHandler(
[copies](MTL::CommandBuffer*) mutable { copies.clear(); });
return;
}
void Matmul::eval_gpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 2);
if (!is_floating_point(out.dtype())) {
throw std::runtime_error(
"[matmul] Does not yet support non-floating point types.");
}
out.set_data(allocator::malloc_or_wait(out.nbytes()));
auto& s = stream();
auto& d = metal::device(s.device);
auto& a_pre = inputs[0];
auto& b_pre = inputs[1];
// Keep a vector of temporary copies; it is cleared in the command buffer's
// completed handler to release the arrays
std::vector<array> copies;
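// check_transpose returns (is_transposed, leading_dim, array): the array is
// used as-is when its last two axes are row- or column-contiguous, otherwise a
// row-contiguous copy is made and retained in `copies` until the command
// buffer completes.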
auto check_transpose = [&copies, &s](const array& arr) {
auto stx = arr.strides()[arr.ndim() - 2];
auto sty = arr.strides()[arr.ndim() - 1];
if (stx == arr.shape(-1) && sty == 1) {
return std::make_tuple(false, stx, arr);
} else if (stx == 1 && sty == arr.shape(-2)) {
return std::make_tuple(true, sty, arr);
} else {
array arr_copy(arr.shape(), arr.dtype(), nullptr, {});
copy_gpu(arr, arr_copy, CopyType::General, s);
copies.push_back(arr_copy);
size_t stx = arr.shape(-1);
return std::make_tuple(false, stx, arr_copy);
}
};
auto [a_transposed, a_cols, a] = check_transpose(a_pre);
auto [b_transposed, b_cols, b] = check_transpose(b_pre);
int M = a.shape(-2);
int N = b.shape(-1);
int K = a.shape(-1);
auto batch_size_out = out.size() / (M * N);
// Route to gemv if needed
if (std::min(M, N) == 1) {
// Collect problem info
bool is_b_matrix = N != 1;
auto& mat = is_b_matrix ? b : a;
auto& vec = is_b_matrix ? a : b;
bool transpose_mat = is_b_matrix ? !b_transposed : a_transposed;
int in_vector_len = K;
int out_vector_len = is_b_matrix ? N : M;
int mat_cols = transpose_mat ? out_vector_len : in_vector_len;
int mat_rows = transpose_mat ? in_vector_len : out_vector_len;
int batch_size_mat = mat.data_size() / (mat_cols * mat_rows);
int stride_mat = batch_size_mat == batch_size_out ? mat_cols * mat_rows : 0;
int batch_size_vec = vec.data_size() / in_vector_len;
int stride_vec = batch_size_vec == batch_size_out ? in_vector_len : 0;
// Determine dispatch kernel
int tm = 4, tn = 4;
int bm, bn, n_out_per_tgp;
std::ostringstream kname;
if (transpose_mat) {
bm = 8;
bn = 8;
if (out_vector_len >= 24576) {
bn = 128;
} else if (out_vector_len >= 16384) {
bn = 64;
} else if (out_vector_len >= 8192) {
bn = 16;
}
// Specialized kernel for very small outputs
tn = out_vector_len < tn ? 1 : tn;
n_out_per_tgp = bn * tn;
kname << "gemv_t_" << type_to_name(out);
} else {
bm = out_vector_len >= 4096 ? 8 : 4;
bn = 32;
// Specialized kernel for very small outputs
tm = out_vector_len < tm ? 1 : tm;
n_out_per_tgp = bm * tm;
kname << "gemv_" << type_to_name(out);
}
kname << "_bm" << bm << "_bn" << bn << "_tm" << tm << "_tn" << tn;
// Encode and dispatch kernel
auto compute_encoder = d.get_command_encoder(s.index);
auto kernel = d.get_kernel(kname.str());
compute_encoder->setComputePipelineState(kernel);
int n_tgp = (out_vector_len + n_out_per_tgp - 1) / n_out_per_tgp;
MTL::Size group_dims = MTL::Size(bn, bm, 1);
MTL::Size grid_dims = MTL::Size(n_tgp, 1, batch_size_out);
set_array_buffer(compute_encoder, mat, 0);
set_array_buffer(compute_encoder, vec, 1);
set_array_buffer(compute_encoder, out, 2);
compute_encoder->setBytes(&in_vector_len, sizeof(int), 3);
compute_encoder->setBytes(&out_vector_len, sizeof(int), 4);
compute_encoder->setBytes(&stride_vec, sizeof(int), 5);
compute_encoder->setBytes(&stride_mat, sizeof(int), 6);
compute_encoder->dispatchThreadgroups(grid_dims, group_dims);
d.get_command_buffer(s.index)->addCompletedHandler(
[copies](MTL::CommandBuffer*) mutable { copies.clear(); });
return;
}
d.end_encoding(s.index);
if (use_mps()) {
mps_matmul(
s,
d,
a,
b,
out,
M,
N,
K,
batch_size_out,
a_cols,
b_cols,
a_transposed,
b_transposed,
copies);
return;
}
mlx_matmul(
s,
d,
a,
b,
out,
M,
N,
K,
batch_size_out,
a_cols,
b_cols,
a_transposed,
b_transposed,
copies);
}
} // namespace mlx::core

View File

@@ -0,0 +1,29 @@
#include <algorithm>
#include <cassert>
#include <sstream>
#include "mlx/backend/metal/copy.h"
#include "mlx/backend/metal/device.h"
#include "mlx/backend/metal/mps/gemm.h"
#include "mlx/backend/metal/utils.h"
#include "mlx/utils.h"
namespace mlx::core {
void mlx_matmul(
const Stream& s,
metal::Device& d,
const array& a,
const array& b,
array& out,
int M,
int N,
int K,
int batch_size_out,
int lda,
int ldb,
bool transpose_a,
bool transpose_b,
std::vector<array>& copies);
} // namespace mlx::core

View File

@@ -0,0 +1,368 @@
#pragma once
#include <Metal/Metal.hpp>
#define _MPS_PRIVATE_CLS(symbol) (MTL::Private::Class::s_k##symbol)
#define _MPS_PRIVATE_SEL(accessor) (MTL::Private::Selector::s_k##accessor)
namespace MTL::Private::Class {
_MTL_PRIVATE_DEF_CLS(MPSMatrixDescriptor);
_MTL_PRIVATE_DEF_CLS(MPSMatrix);
_MTL_PRIVATE_DEF_CLS(MPSVectorDescriptor);
_MTL_PRIVATE_DEF_CLS(MPSVector);
_MTL_PRIVATE_DEF_CLS(MPSKernel);
_MTL_PRIVATE_DEF_CLS(MPSMatrixMultiplication);
_MTL_PRIVATE_DEF_CLS(MPSMatrixVectorMultiplication);
} // namespace MTL::Private::Class
namespace MTL::Private::Selector {
_MTL_PRIVATE_DEF_SEL(
matrixDescriptorWithRows_columns_rowBytes_dataType,
"matrixDescriptorWithRows:columns:rowBytes:dataType:");
_MTL_PRIVATE_DEF_SEL(
matrixDescriptorWithRows_columns_matrices_rowBytes_matrixBytes_dataType,
"matrixDescriptorWithRows:columns:matrices:rowBytes:matrixBytes:dataType:");
_MTL_PRIVATE_DEF_SEL(rows, "rows");
_MTL_PRIVATE_DEF_SEL(initWithBuffer_descriptor, "initWithBuffer:descriptor:");
_MTL_PRIVATE_DEF_SEL(
initWithDevice_,
"initWithDevice:transposeLeft:transposeRight:"
"resultRows:resultColumns:interiorColumns:alpha:beta:");
_MTL_PRIVATE_DEF_SEL(
encodeToCommandBuffer_leftMatrix_rightMatrix_resultMatrix,
"encodeToCommandBuffer:leftMatrix:rightMatrix:resultMatrix:");
_MTL_PRIVATE_DEF_SEL(setLeftMatrixOrigin_, "setLeftMatrixOrigin:");
_MTL_PRIVATE_DEF_SEL(setRightMatrixOrigin_, "setRightMatrixOrigin:");
_MTL_PRIVATE_DEF_SEL(setResultMatrixOrigin_, "setResultMatrixOrigin:");
_MTL_PRIVATE_DEF_SEL(setBatchStart_, "setBatchStart:");
_MTL_PRIVATE_DEF_SEL(setBatchSize_, "setBatchSize:");
_MTL_PRIVATE_DEF_SEL(
vectorDescriptorWithLength_dataType,
"vectorDescriptorWithLength:dataType:");
_MTL_PRIVATE_DEF_SEL(
vectorDescriptorWithLength_vectors_vectorBytes_dataType,
"vectorDescriptorWithLength:vectors:vectorBytes:dataType:");
_MTL_PRIVATE_DEF_SEL(
initWithDevice_transpose_rows_columns_alpha_beta,
"initWithDevice:transpose:rows:columns:alpha:beta:");
_MTL_PRIVATE_DEF_SEL(
encodeToCommandBuffer_inputMatrix_inputVector_resultVector,
"encodeToCommandBuffer:inputMatrix:inputVector:resultVector:");
} // namespace MTL::Private::Selector
namespace MPS {
typedef enum DataType : uint32_t {
DataTypeFloatBit = 0x10000000,
DataTypeAlternateEncodingBit = 0x80000000,
DataTypeFloat16 = DataTypeFloatBit | 16,
DataTypeFloat32 = DataTypeFloatBit | 32,
DataTypeBFloat16 = DataTypeAlternateEncodingBit | DataTypeFloat16
} DataType;
class MatrixDescriptor : public NS::Copying<MatrixDescriptor> {
public:
static class MatrixDescriptor* matrixDescriptor(
NS::UInteger rows,
NS::UInteger columns,
NS::UInteger rowBytes,
NS::UInteger dataType);
static class MatrixDescriptor* matrixDescriptor(
NS::UInteger rows,
NS::UInteger columns,
NS::UInteger matrices,
NS::UInteger rowBytes,
NS::UInteger matrixBytes,
NS::UInteger dataType);
NS::UInteger rows() const;
};
class Matrix : public NS::Referencing<Matrix> {
public:
static class Matrix* alloc();
Matrix* init(MTL::Buffer* buffer, MatrixDescriptor* descriptor);
Matrix* init(const MTL::Buffer* buffer, MatrixDescriptor* descriptor);
};
class Kernel : public NS::Referencing<Kernel> {
public:
NS::String* label() const;
MTL::Device* device() const;
};
class MatrixMultiplication
: public NS::Referencing<MatrixMultiplication, Kernel> {
public:
static class MatrixMultiplication* alloc();
MatrixMultiplication* init(
MTL::Device* device,
bool transposeLeft,
bool transposeRight,
NS::UInteger resultRows,
NS::UInteger resultColumns,
NS::UInteger interiorColumns,
double alpha,
double beta);
void encodeToCommandBuffer(
MTL::CommandBuffer* commandBuffer,
Matrix* leftMatrix,
Matrix* rightMatrix,
Matrix* resultMatrix);
void setLeftMatrixOrigin(MTL::Origin origin);
void setRightMatrixOrigin(MTL::Origin origin);
void setResultMatrixOrigin(MTL::Origin origin);
void setBatchStart(NS::UInteger batchStart);
void setBatchSize(NS::UInteger batchSize);
};
class VectorDescriptor : public NS::Copying<VectorDescriptor> {
public:
static class VectorDescriptor* vectorDescriptor(
NS::UInteger length,
NS::UInteger dataType);
static class VectorDescriptor* vectorDescriptor(
NS::UInteger length,
NS::UInteger vectors,
NS::UInteger vectorBytes,
NS::UInteger dataType);
};
class Vector : public NS::Referencing<Vector> {
public:
static class Vector* alloc();
Vector* init(MTL::Buffer* buffer, VectorDescriptor* descriptor);
Vector* init(const MTL::Buffer* buffer, VectorDescriptor* descriptor);
};
class MatrixVectorMultiplication
: public NS::Referencing<MatrixVectorMultiplication, Kernel> {
public:
static class MatrixVectorMultiplication* alloc();
MatrixVectorMultiplication* init(
MTL::Device* device,
bool transpose,
NS::UInteger rows,
NS::UInteger columns,
double alpha,
double beta);
void encodeToCommandBuffer(
MTL::CommandBuffer* commandBuffer,
Matrix* inputMatrix,
Vector* inputVector,
Vector* resultVector);
};
_MTL_INLINE MatrixDescriptor* MatrixDescriptor::matrixDescriptor(
NS::UInteger rows,
NS::UInteger columns,
NS::UInteger rowBytes,
NS::UInteger dataType) {
return Object::sendMessage<MatrixDescriptor*>(
_MPS_PRIVATE_CLS(MPSMatrixDescriptor),
_MPS_PRIVATE_SEL(matrixDescriptorWithRows_columns_rowBytes_dataType),
rows,
columns,
rowBytes,
dataType);
}
_MTL_INLINE MatrixDescriptor* MatrixDescriptor::matrixDescriptor(
NS::UInteger rows,
NS::UInteger columns,
NS::UInteger matrices,
NS::UInteger rowBytes,
NS::UInteger matrixBytes,
NS::UInteger dataType) {
return Object::sendMessage<MatrixDescriptor*>(
_MPS_PRIVATE_CLS(MPSMatrixDescriptor),
_MPS_PRIVATE_SEL(
matrixDescriptorWithRows_columns_matrices_rowBytes_matrixBytes_dataType),
rows,
columns,
matrices,
rowBytes,
matrixBytes,
dataType);
}
_MTL_INLINE NS::UInteger MatrixDescriptor::rows() const {
return Object::sendMessage<NS::UInteger>(this, _MPS_PRIVATE_SEL(rows));
}
_MTL_INLINE Matrix* Matrix::alloc() {
return NS::Object::alloc<Matrix>(_MPS_PRIVATE_CLS(MPSMatrix));
}
_MTL_INLINE Matrix* Matrix::init(
MTL::Buffer* buffer,
MatrixDescriptor* descriptor) {
return Object::sendMessage<Matrix*>(
this, _MPS_PRIVATE_SEL(initWithBuffer_descriptor), buffer, descriptor);
}
_MTL_INLINE Matrix* Matrix::init(
const MTL::Buffer* buffer,
MatrixDescriptor* descriptor) {
return init(const_cast<MTL::Buffer*>(buffer), descriptor);
}
_MTL_INLINE NS::String* Kernel::label() const {
return Object::sendMessage<NS::String*>(this, _MPS_PRIVATE_SEL(label));
}
_MTL_INLINE MTL::Device* Kernel::device() const {
return Object::sendMessage<MTL::Device*>(this, _MPS_PRIVATE_SEL(device));
}
_MTL_INLINE MatrixMultiplication* MatrixMultiplication::alloc() {
return NS::Object::alloc<MatrixMultiplication>(
_MPS_PRIVATE_CLS(MPSMatrixMultiplication));
}
_MTL_INLINE MatrixMultiplication* MatrixMultiplication::init(
MTL::Device* device,
bool transposeLeft,
bool transposeRight,
NS::UInteger resultRows,
NS::UInteger resultColumns,
NS::UInteger interiorColumns,
double alpha,
double beta) {
return Object::sendMessage<MatrixMultiplication*>(
this,
_MPS_PRIVATE_SEL(initWithDevice_),
device,
transposeLeft,
transposeRight,
resultRows,
resultColumns,
interiorColumns,
alpha,
beta);
}
_MTL_INLINE void MatrixMultiplication::encodeToCommandBuffer(
MTL::CommandBuffer* commandBuffer,
Matrix* leftMatrix,
Matrix* rightMatrix,
Matrix* resultMatrix) {
return Object::sendMessage<void>(
this,
_MPS_PRIVATE_SEL(
encodeToCommandBuffer_leftMatrix_rightMatrix_resultMatrix),
commandBuffer,
leftMatrix,
rightMatrix,
resultMatrix);
}
_MTL_INLINE void MatrixMultiplication::setLeftMatrixOrigin(MTL::Origin origin) {
Object::sendMessage<void>(
this, _MPS_PRIVATE_SEL(setLeftMatrixOrigin_), origin);
}
_MTL_INLINE void MatrixMultiplication::setRightMatrixOrigin(
MTL::Origin origin) {
Object::sendMessage<void>(
this, _MPS_PRIVATE_SEL(setRightMatrixOrigin_), origin);
}
_MTL_INLINE void MatrixMultiplication::setResultMatrixOrigin(
MTL::Origin origin) {
Object::sendMessage<void>(
this, _MPS_PRIVATE_SEL(setResultMatrixOrigin_), origin);
}
_MTL_INLINE void MatrixMultiplication::setBatchStart(NS::UInteger batchStart) {
Object::sendMessage<void>(this, _MPS_PRIVATE_SEL(setBatchStart_), batchStart);
}
_MTL_INLINE void MatrixMultiplication::setBatchSize(NS::UInteger batchSize) {
Object::sendMessage<void>(this, _MPS_PRIVATE_SEL(setBatchSize_), batchSize);
}
_MTL_INLINE VectorDescriptor* VectorDescriptor::vectorDescriptor(
NS::UInteger length,
NS::UInteger dataType) {
return Object::sendMessage<VectorDescriptor*>(
_MPS_PRIVATE_CLS(MPSVectorDescriptor),
_MPS_PRIVATE_SEL(vectorDescriptorWithLength_dataType),
length,
dataType);
}
_MTL_INLINE VectorDescriptor* VectorDescriptor::vectorDescriptor(
NS::UInteger length,
NS::UInteger vectors,
NS::UInteger vectorBytes,
NS::UInteger dataType) {
return Object::sendMessage<VectorDescriptor*>(
_MPS_PRIVATE_CLS(MPSVectorDescriptor),
_MPS_PRIVATE_SEL(vectorDescriptorWithLength_vectors_vectorBytes_dataType),
length,
vectors,
vectorBytes,
dataType);
}
_MTL_INLINE Vector* Vector::alloc() {
return NS::Object::alloc<Vector>(_MPS_PRIVATE_CLS(MPSVector));
}
_MTL_INLINE Vector* Vector::init(
MTL::Buffer* buffer,
VectorDescriptor* descriptor) {
return Object::sendMessage<Vector*>(
this, _MPS_PRIVATE_SEL(initWithBuffer_descriptor), buffer, descriptor);
}
_MTL_INLINE Vector* Vector::init(
const MTL::Buffer* buffer,
VectorDescriptor* descriptor) {
return init(const_cast<MTL::Buffer*>(buffer), descriptor);
}
_MTL_INLINE MatrixVectorMultiplication* MatrixVectorMultiplication::alloc() {
return NS::Object::alloc<MatrixVectorMultiplication>(
_MPS_PRIVATE_CLS(MPSMatrixVectorMultiplication));
}
_MTL_INLINE MatrixVectorMultiplication* MatrixVectorMultiplication::init(
MTL::Device* device,
bool transpose,
NS::UInteger rows,
NS::UInteger columns,
double alpha,
double beta) {
return Object::sendMessage<MatrixVectorMultiplication*>(
this,
_MPS_PRIVATE_SEL(initWithDevice_transpose_rows_columns_alpha_beta),
device,
transpose,
rows,
columns,
alpha,
beta);
}
_MTL_INLINE void MatrixVectorMultiplication::encodeToCommandBuffer(
MTL::CommandBuffer* commandBuffer,
Matrix* inputMatrix,
Vector* inputVector,
Vector* resultVector) {
return Object::sendMessage<void>(
this,
_MPS_PRIVATE_SEL(
encodeToCommandBuffer_inputMatrix_inputVector_resultVector),
commandBuffer,
inputMatrix,
inputVector,
resultVector);
}
} // namespace MPS

View File

@@ -0,0 +1,604 @@
#include <algorithm>
#include <cassert>
#include <numeric>
#include <sstream>
#include "mlx/backend/common/binary.h"
#include "mlx/backend/metal/copy.h"
#include "mlx/backend/metal/device.h"
#include "mlx/backend/metal/kernels/defines.h"
#include "mlx/backend/metal/utils.h"
#include "mlx/primitives.h"
#include "mlx/utils.h"
namespace mlx::core {
namespace {
static constexpr int METAL_MAX_INDEX_ARRAYS = 10;
void binary_op(
const std::vector<array>& inputs,
array& out,
const std::string op) {
assert(inputs.size() == 2);
auto& a = inputs[0];
auto& b = inputs[1];
auto bopt = get_binary_op_type(a, b);
set_binary_op_output_data(a, b, out, bopt);
// Try to collapse contiguous dims
auto [shape, strides] = collapse_contiguous_dims(a, b, out);
auto& strides_a = strides[0];
auto& strides_b = strides[1];
auto& strides_out = strides[2];
std::ostringstream kname;
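// The kernel name prefix encodes the contiguity case: ss = scalar/scalar,
// sv = scalar/vector, vs = vector/scalar, vv = vector/vector, and g = general
// (strided), optionally specialized by dimensionality below.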
switch (bopt) {
case ScalarScalar:
kname << "ss";
break;
case ScalarVector:
kname << "sv";
break;
case VectorScalar:
kname << "vs";
break;
case VectorVector:
kname << "vv";
break;
case General:
kname << "g";
break;
}
kname << op << type_to_name(a);
if (bopt == General && out.ndim() <= MAX_BINARY_SPECIALIZED_DIMS) {
kname << "_" << shape.size();
}
auto& s = out.primitive().stream();
auto& d = metal::device(s.device);
auto kernel = d.get_kernel(kname.str());
auto compute_encoder = d.get_command_encoder(s.index);
compute_encoder->setComputePipelineState(kernel);
set_array_buffer(compute_encoder, a, 0);
set_array_buffer(compute_encoder, b, 1);
set_array_buffer(compute_encoder, out, 2);
if (bopt == General) {
auto ndim = shape.size();
if (ndim > 3) {
compute_encoder->setBytes(shape.data(), ndim * sizeof(int), 3);
compute_encoder->setBytes(strides_a.data(), ndim * sizeof(size_t), 4);
compute_encoder->setBytes(strides_b.data(), ndim * sizeof(size_t), 5);
} else {
// The shape is implicit in the grid for <= 3D
compute_encoder->setBytes(strides_a.data(), ndim * sizeof(size_t), 3);
compute_encoder->setBytes(strides_b.data(), ndim * sizeof(size_t), 4);
}
if (ndim > MAX_BINARY_SPECIALIZED_DIMS) {
compute_encoder->setBytes(&ndim, sizeof(int), 6);
}
// Launch up to a 3D grid of threads
int dim0 = ndim > 0 ? shape[ndim - 1] : 1;
int dim1 = ndim > 1 ? shape[ndim - 2] : 1;
int rest = out.size() / (dim0 * dim1);
NS::UInteger thread_group_size = kernel->maxTotalThreadsPerThreadgroup();
if (thread_group_size != 1024) {
throw std::runtime_error("[Metal::binary] Must use 1024 sized block");
}
auto group_dims = get_block_dims(dim0, dim1, rest);
MTL::Size grid_dims = MTL::Size(dim0, dim1, rest);
compute_encoder->dispatchThreads(grid_dims, group_dims);
} else {
// Launch a 1D grid of threads
size_t nthreads = bopt == General ? out.size() : out.data_size();
MTL::Size grid_dims = MTL::Size(nthreads, 1, 1);
NS::UInteger thread_group_size = kernel->maxTotalThreadsPerThreadgroup();
if (thread_group_size > nthreads) {
thread_group_size = nthreads;
}
MTL::Size group_dims = MTL::Size(thread_group_size, 1, 1);
compute_encoder->dispatchThreads(grid_dims, group_dims);
}
}
void unary_op(
const std::vector<array>& inputs,
array& out,
const std::string op) {
auto& in = inputs[0];
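// Contiguous inputs reuse their strides and flags for the output so the "v"
// (vector) kernel can treat the data as flat; otherwise the "g" (general)
// kernel walks the shape and strides explicitly.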
bool contig = in.flags().contiguous;
if (contig) {
out.set_data(
allocator::malloc_or_wait(in.data_size() * out.itemsize()),
in.data_size(),
in.strides(),
in.flags());
} else {
out.set_data(allocator::malloc_or_wait(out.nbytes()));
}
auto& s = out.primitive().stream();
auto& d = metal::device(s.device);
std::string tname = type_to_name(in);
std::string opt_name = contig ? "v" : "g";
auto kernel = d.get_kernel(opt_name + op + tname);
size_t nthreads = contig ? in.data_size() : in.size();
MTL::Size grid_dims = MTL::Size(nthreads, 1, 1);
NS::UInteger thread_group_size = kernel->maxTotalThreadsPerThreadgroup();
if (thread_group_size > nthreads) {
thread_group_size = nthreads;
}
MTL::Size group_dims = MTL::Size(thread_group_size, 1, 1);
auto compute_encoder = d.get_command_encoder(s.index);
compute_encoder->setComputePipelineState(kernel);
set_array_buffer(compute_encoder, in, 0);
set_array_buffer(compute_encoder, out, 1);
if (!contig) {
compute_encoder->setBytes(in.shape().data(), in.ndim() * sizeof(int), 2);
compute_encoder->setBytes(
in.strides().data(), in.ndim() * sizeof(size_t), 3);
int ndim = in.ndim();
compute_encoder->setBytes(&ndim, sizeof(int), 4);
}
compute_encoder->dispatchThreads(grid_dims, group_dims);
}
} // namespace
void Abs::eval_gpu(const std::vector<array>& inputs, array& out) {
unary_op(inputs, out, "abs");
}
void Add::eval_gpu(const std::vector<array>& inputs, array& out) {
binary_op(inputs, out, "add");
}
template <typename T>
void arange_set_scalars(T start, T next, MTL::ComputeCommandEncoder* enc) {
enc->setBytes(&start, sizeof(T), 0);
T step = next - start;
enc->setBytes(&step, sizeof(T), 1);
}
void Arange::eval_gpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 0);
out.set_data(allocator::malloc_or_wait(out.nbytes()));
auto& s = stream();
auto& d = metal::device(s.device);
auto kernel = d.get_kernel("arange" + type_to_name(out));
size_t nthreads = out.size();
MTL::Size grid_dims = MTL::Size(nthreads, 1, 1);
MTL::Size group_dims = MTL::Size(
std::min(nthreads, kernel->maxTotalThreadsPerThreadgroup()), 1, 1);
auto compute_encoder = d.get_command_encoder(s.index);
compute_encoder->setComputePipelineState(kernel);
switch (out.dtype()) {
case bool_: // unsupported
throw std::runtime_error("[Arange::eval_gpu] Does not support bool");
case uint8:
arange_set_scalars<uint8_t>(start_, start_ + step_, compute_encoder);
break;
case uint16:
arange_set_scalars<uint16_t>(start_, start_ + step_, compute_encoder);
break;
case uint32:
arange_set_scalars<uint32_t>(start_, start_ + step_, compute_encoder);
break;
case uint64:
arange_set_scalars<uint64_t>(start_, start_ + step_, compute_encoder);
break;
case int8:
arange_set_scalars<int8_t>(start_, start_ + step_, compute_encoder);
break;
case int16:
arange_set_scalars<int16_t>(start_, start_ + step_, compute_encoder);
break;
case int32:
arange_set_scalars<int32_t>(start_, start_ + step_, compute_encoder);
break;
case int64:
arange_set_scalars<int64_t>(start_, start_ + step_, compute_encoder);
break;
case float16:
arange_set_scalars<float16_t>(start_, start_ + step_, compute_encoder);
break;
case float32:
arange_set_scalars<float>(start_, start_ + step_, compute_encoder);
break;
case bfloat16:
throw std::runtime_error("[Arange::eval_gpu] Does not support bfloat16");
case complex64:
throw std::runtime_error("[Arange::eval_gpu] Does not support complex64");
}
set_array_buffer(compute_encoder, out, 2);
compute_encoder->dispatchThreads(grid_dims, group_dims);
}
void ArcCos::eval_gpu(const std::vector<array>& inputs, array& out) {
unary_op(inputs, out, "arccos");
}
void ArcCosh::eval_gpu(const std::vector<array>& inputs, array& out) {
unary_op(inputs, out, "arccosh");
}
void ArcSin::eval_gpu(const std::vector<array>& inputs, array& out) {
unary_op(inputs, out, "arcsin");
}
void ArcSinh::eval_gpu(const std::vector<array>& inputs, array& out) {
unary_op(inputs, out, "arcsinh");
}
void ArcTan::eval_gpu(const std::vector<array>& inputs, array& out) {
unary_op(inputs, out, "arctan");
}
void ArcTanh::eval_gpu(const std::vector<array>& inputs, array& out) {
unary_op(inputs, out, "arctanh");
}
void ArgReduce::eval_gpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
auto& in = inputs[0];
out.set_data(allocator::malloc_or_wait(out.nbytes()));
auto& s = stream();
auto& d = metal::device(s.device);
std::string op_name;
switch (reduce_type_) {
case ArgReduce::ArgMin:
op_name = "argmin_";
break;
case ArgReduce::ArgMax:
op_name = "argmax_";
break;
}
// Prepare the shapes, strides and axis arguments.
std::vector<size_t> in_strides = in.strides();
std::vector<int> shape = in.shape();
std::vector<size_t> out_strides = out.strides();
size_t axis_stride = in_strides[axis_];
size_t axis_size = shape[axis_];
if (out_strides.size() == in_strides.size()) {
out_strides.erase(out_strides.begin() + axis_);
}
in_strides.erase(in_strides.begin() + axis_);
shape.erase(shape.begin() + axis_);
size_t ndim = shape.size();
// ArgReduce
int simd_size = 32;
int n_reads = 4;
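// One threadgroup is launched per output element; each thread reads a strided
// chunk of the reduced axis and the threadgroup combines partial results,
// which appears to use simd_size slots of threadgroup memory for per-simdgroup
// (index, value) partials.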
auto compute_encoder = d.get_command_encoder(s.index);
{
auto kernel = d.get_kernel(op_name + type_to_name(in));
NS::UInteger thread_group_size = std::min(
(axis_size + n_reads - 1) / n_reads,
kernel->maxTotalThreadsPerThreadgroup());
// round up to the nearest multiple of simd_size
thread_group_size =
(thread_group_size + simd_size - 1) / simd_size * simd_size;
assert(thread_group_size <= kernel->maxTotalThreadsPerThreadgroup());
size_t n_threads = out.size() * thread_group_size;
MTL::Size grid_dims = MTL::Size(n_threads, 1, 1);
MTL::Size group_dims = MTL::Size(thread_group_size, 1, 1);
compute_encoder->setComputePipelineState(kernel);
set_array_buffer(compute_encoder, in, 0);
set_array_buffer(compute_encoder, out, 1);
compute_encoder->setBytes(shape.data(), ndim * sizeof(int), 2);
compute_encoder->setBytes(in_strides.data(), ndim * sizeof(size_t), 3);
compute_encoder->setBytes(out_strides.data(), ndim * sizeof(size_t), 4);
compute_encoder->setBytes(&ndim, sizeof(size_t), 5);
compute_encoder->setBytes(&axis_stride, sizeof(size_t), 6);
compute_encoder->setBytes(&axis_size, sizeof(size_t), 7);
compute_encoder->setThreadgroupMemoryLength(
simd_size * (sizeof(uint32_t) + in.itemsize()), 0);
compute_encoder->dispatchThreads(grid_dims, group_dims);
}
}
void AsType::eval_gpu(const std::vector<array>& inputs, array& out) {
CopyType ctype =
inputs[0].flags().contiguous ? CopyType::Vector : CopyType::General;
copy_gpu(inputs[0], out, ctype);
}
void AsStrided::eval_gpu(const std::vector<array>& inputs, array& out) {
eval(inputs, out);
}
void Broadcast::eval_gpu(const std::vector<array>& inputs, array& out) {
eval(inputs, out);
}
void Concatenate::eval_gpu(const std::vector<array>& inputs, array& out) {
std::vector<int> sizes;
sizes.push_back(0);
for (auto& p : inputs) {
sizes.push_back(p.shape(axis_));
}
std::partial_sum(sizes.cbegin(), sizes.cend(), sizes.begin());
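// After the partial sum, sizes[i] is the offset of input i along axis_.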
out.set_data(allocator::malloc_or_wait(out.nbytes()));
auto strides = out.strides();
auto flags = out.flags();
flags.row_contiguous = false;
flags.col_contiguous = false;
flags.contiguous = false;
for (int i = 0; i < inputs.size(); i++) {
array out_slice(inputs[i].shape(), out.dtype(), nullptr, {});
size_t data_offset = strides[axis_] * sizes[i];
out_slice.copy_shared_buffer(
out, strides, flags, out_slice.size(), data_offset);
copy_gpu_inplace(inputs[i], out_slice, CopyType::GeneralGeneral, stream());
}
}
void Copy::eval_gpu(const std::vector<array>& inputs, array& out) {
eval(inputs, out);
}
void Cos::eval_gpu(const std::vector<array>& inputs, array& out) {
unary_op(inputs, out, "cos");
}
void Cosh::eval_gpu(const std::vector<array>& inputs, array& out) {
unary_op(inputs, out, "cosh");
}
void Divide::eval_gpu(const std::vector<array>& inputs, array& out) {
binary_op(inputs, out, "div");
}
void Equal::eval_gpu(const std::vector<array>& inputs, array& out) {
binary_op(inputs, out, equal_nan_ ? "naneq" : "eq");
}
void Erf::eval_gpu(const std::vector<array>& inputs, array& out) {
unary_op(inputs, out, "erf");
}
void ErfInv::eval_gpu(const std::vector<array>& inputs, array& out) {
unary_op(inputs, out, "erfinv");
}
void Exp::eval_gpu(const std::vector<array>& inputs, array& out) {
unary_op(inputs, out, "exp");
}
void Full::eval_gpu(const std::vector<array>& inputs, array& out) {
auto in = inputs[0];
CopyType ctype;
if (in.data_size() == 1) {
ctype = CopyType::Scalar;
} else if (in.flags().contiguous) {
ctype = CopyType::Vector;
} else {
ctype = CopyType::General;
}
copy_gpu(in, out, ctype);
}
void Greater::eval_gpu(const std::vector<array>& inputs, array& out) {
binary_op(inputs, out, "ge");
}
void GreaterEqual::eval_gpu(const std::vector<array>& inputs, array& out) {
binary_op(inputs, out, "geq");
}
void Less::eval_gpu(const std::vector<array>& inputs, array& out) {
binary_op(inputs, out, "le");
}
void LessEqual::eval_gpu(const std::vector<array>& inputs, array& out) {
binary_op(inputs, out, "leq");
}
void Load::eval_gpu(const std::vector<array>& inputs, array& out) {
eval(inputs, out);
}
void Log::eval_gpu(const std::vector<array>& inputs, array& out) {
switch (base_) {
case Base::e:
unary_op(inputs, out, "log");
break;
case Base::two:
unary_op(inputs, out, "log2");
break;
case Base::ten:
unary_op(inputs, out, "log10");
break;
}
}
void Log1p::eval_gpu(const std::vector<array>& inputs, array& out) {
unary_op(inputs, out, "log1p");
}
void LogicalNot::eval_gpu(const std::vector<array>& inputs, array& out) {
unary_op(inputs, out, "lnot");
}
void LogAddExp::eval_gpu(const std::vector<array>& inputs, array& out) {
binary_op(inputs, out, "lae");
}
void Maximum::eval_gpu(const std::vector<array>& inputs, array& out) {
binary_op(inputs, out, "max");
}
void Minimum::eval_gpu(const std::vector<array>& inputs, array& out) {
binary_op(inputs, out, "min");
}
void Multiply::eval_gpu(const std::vector<array>& inputs, array& out) {
binary_op(inputs, out, "mul");
}
void Negative::eval_gpu(const std::vector<array>& inputs, array& out) {
unary_op(inputs, out, "neg");
}
void NotEqual::eval_gpu(const std::vector<array>& inputs, array& out) {
binary_op(inputs, out, "neq");
}
void Pad::eval_gpu(const std::vector<array>& inputs, array& out) {
// Inputs must be base input array and scalar val array
assert(inputs.size() == 2);
auto& in = inputs[0];
auto& val = inputs[1];
// Padding value must be a scalar
assert(val.size() == 1);
// Padding value, input and output must be of the same type
assert(val.dtype() == in.dtype() && in.dtype() == out.dtype());
// Fill output with val
copy_gpu(val, out, CopyType::Scalar, stream());
// Find offset for start of input values
size_t data_offset = 0;
for (int i = 0; i < axes_.size(); i++) {
auto ax = axes_[i] < 0 ? out.ndim() + axes_[i] : axes_[i];
data_offset += out.strides()[ax] * low_pad_size_[i];
}
// Extract slice from output where input will be pasted
array out_slice(in.shape(), out.dtype(), nullptr, {});
out_slice.copy_shared_buffer(
out, out.strides(), out.flags(), out_slice.size(), data_offset);
// Copy input values into the slice
copy_gpu_inplace(in, out_slice, CopyType::GeneralGeneral, stream());
}
void Power::eval_gpu(const std::vector<array>& inputs, array& out) {
binary_op(inputs, out, "pow");
}
void RandomBits::eval_gpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
// keys has shape (N1, ..., NK, 2)
// out has shape (N1, ..., NK, M1, M2, ...)
auto& keys = inputs[0];
size_t num_keys = keys.size() / 2;
size_t elems_per_key = out.size() / num_keys;
size_t bytes_per_key = out.itemsize() * elems_per_key;
out.set_data(allocator::malloc_or_wait(out.nbytes()));
size_t out_per_key = (bytes_per_key + 4 - 1) / 4;
size_t half_size = out_per_key / 2;
bool odd = out_per_key % 2;
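// Each key produces bytes_per_key random bytes emitted as 32-bit words; the
// grid is num_keys x (half_size + odd) since each thread presumably writes two
// words per key, with the odd flag covering a trailing word.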
auto& s = stream();
auto& d = metal::device(s.device);
std::string kname = keys.flags().row_contiguous ? "rbitsc" : "rbits";
auto kernel = d.get_kernel(kname);
// organize into a grid of num_keys x (half_size + odd)
MTL::Size grid_dims = MTL::Size(num_keys, half_size + odd, 1);
NS::UInteger thread_group_size = kernel->maxTotalThreadsPerThreadgroup();
auto nthreads = std::min(num_keys * (half_size + odd), thread_group_size);
  MTL::Size group_dims = MTL::Size(nthreads, 1, 1);
auto compute_encoder = d.get_command_encoder(s.index);
compute_encoder->setComputePipelineState(kernel);
set_array_buffer(compute_encoder, keys, 0);
set_array_buffer(compute_encoder, out, 1);
compute_encoder->setBytes(&odd, sizeof(bool), 2);
compute_encoder->setBytes(&bytes_per_key, sizeof(size_t), 3);
if (!keys.flags().row_contiguous) {
int ndim = keys.ndim();
compute_encoder->setBytes(&ndim, sizeof(int), 4);
compute_encoder->setBytes(
keys.shape().data(), keys.ndim() * sizeof(int), 5);
compute_encoder->setBytes(
keys.strides().data(), keys.ndim() * sizeof(size_t), 6);
}
compute_encoder->dispatchThreads(grid_dims, group_dims);
}
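// Illustrative sizing sketch (shapes assumed, not from a test): with uint32
// keys of shape (4, 2) and a uint32 output of shape (4, 5):
//   num_keys      = 8 / 2        = 4
//   elems_per_key = 20 / 4       = 5
//   bytes_per_key = 4 * 5        = 20
//   out_per_key   = (20 + 3) / 4 = 5   (32-bit words per key)
//   half_size = 2, odd = 1
// so the dispatch covers a (4, 3, 1) grid: one thread per pair of output
// words, plus one thread for the odd trailing word, per key.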
void Reshape::eval_gpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
const auto& in = inputs[0];
if (in.flags().row_contiguous) {
auto flags = in.flags();
auto max_dim = std::max_element(out.shape().begin(), out.shape().end());
flags.col_contiguous = out.size() <= 1 || out.size() == *max_dim;
out.copy_shared_buffer(in, out.strides(), flags, in.data_size());
} else {
copy_gpu(in, out, CopyType::General);
}
}
void Sigmoid::eval_gpu(const std::vector<array>& inputs, array& out) {
unary_op(inputs, out, "sigmoid");
}
void Sign::eval_gpu(const std::vector<array>& inputs, array& out) {
unary_op(inputs, out, "sign");
}
void Sin::eval_gpu(const std::vector<array>& inputs, array& out) {
unary_op(inputs, out, "sin");
}
void Sinh::eval_gpu(const std::vector<array>& inputs, array& out) {
unary_op(inputs, out, "sinh");
}
void Square::eval_gpu(const std::vector<array>& inputs, array& out) {
unary_op(inputs, out, "square");
}
void Sqrt::eval_gpu(const std::vector<array>& inputs, array& out) {
if (recip_) {
unary_op(inputs, out, "rsqrt");
} else {
unary_op(inputs, out, "sqrt");
}
}
void Slice::eval_gpu(const std::vector<array>& inputs, array& out) {
eval(inputs, out);
}
void StopGradient::eval_gpu(const std::vector<array>& inputs, array& out) {
eval(inputs, out);
}
void Subtract::eval_gpu(const std::vector<array>& inputs, array& out) {
binary_op(inputs, out, "sub");
}
void Tan::eval_gpu(const std::vector<array>& inputs, array& out) {
unary_op(inputs, out, "tan");
}
void Tanh::eval_gpu(const std::vector<array>& inputs, array& out) {
unary_op(inputs, out, "tanh");
}
void Transpose::eval_gpu(const std::vector<array>& inputs, array& out) {
eval(inputs, out);
}
} // namespace mlx::core

130
mlx/backend/metal/scan.cpp Normal file
View File

@@ -0,0 +1,130 @@
#include <cassert>
#include <sstream>
#include "mlx/backend/metal/copy.h"
#include "mlx/backend/metal/device.h"
#include "mlx/backend/metal/utils.h"
#include "mlx/primitives.h"
namespace mlx::core {
void Scan::eval_gpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
out.set_data(allocator::malloc_or_wait(out.nbytes()));
auto& s = stream();
auto& d = metal::device(s.device);
// Ensure contiguity
std::vector<array> copies;
auto in = inputs[0];
if (!in.flags().row_contiguous) {
array arr_copy(in.shape(), in.dtype(), nullptr, {});
copy_gpu(in, arr_copy, CopyType::General, s);
copies.push_back(arr_copy);
in = arr_copy;
}
std::ostringstream kname;
if (in.strides()[axis_] == 1) {
kname << "contiguous_scan_";
if (reverse_) {
kname << "reverse_";
}
kname << ((inclusive_) ? "inclusive_" : "exclusive_");
switch (reduce_type_) {
case Scan::Sum:
kname << "sum_";
break;
case Scan::Prod:
kname << "prod_";
break;
case Scan::Max:
kname << "max_";
break;
case Scan::Min:
kname << "min_";
break;
}
kname << type_to_name(in) << "_" << type_to_name(out);
auto kernel = d.get_kernel(kname.str());
auto compute_encoder = d.get_command_encoder(s.index);
compute_encoder->setComputePipelineState(kernel);
set_array_buffer(compute_encoder, in, 0);
set_array_buffer(compute_encoder, out, 1);
size_t size = in.shape(axis_);
compute_encoder->setBytes(&size, sizeof(size_t), 2);
// Compute the thread grid
int n_reads = (in.itemsize() <= 4) ? 4 : 2;
int elements_per_simd = n_reads * 32;
int thread_groups = in.size() / size;
int thread_group_size = kernel->maxTotalThreadsPerThreadgroup();
if (size < n_reads * 1024) {
thread_group_size = ((size + elements_per_simd - 1) / elements_per_simd) *
elements_per_simd;
} else if (size < n_reads * 2048) {
thread_group_size =
((size / 2 + elements_per_simd - 1) / elements_per_simd) *
elements_per_simd;
}
thread_group_size = std::min(
thread_group_size,
static_cast<int>(kernel->maxTotalThreadsPerThreadgroup()));
MTL::Size grid_dims = MTL::Size(thread_groups * thread_group_size, 1, 1);
MTL::Size group_dims = MTL::Size(thread_group_size, 1, 1);
compute_encoder->dispatchThreads(grid_dims, group_dims);
} else {
kname << "strided_scan_";
if (reverse_) {
kname << "reverse_";
}
kname << ((inclusive_) ? "inclusive_" : "exclusive_");
switch (reduce_type_) {
case Scan::Sum:
kname << "sum_";
break;
case Scan::Prod:
kname << "prod_";
break;
case Scan::Max:
kname << "max_";
break;
case Scan::Min:
kname << "min_";
break;
}
kname << type_to_name(in) << "_" << type_to_name(out);
auto kernel = d.get_kernel(kname.str());
auto compute_encoder = d.get_command_encoder(s.index);
compute_encoder->setComputePipelineState(kernel);
set_array_buffer(compute_encoder, in, 0);
set_array_buffer(compute_encoder, out, 1);
size_t size = in.shape(axis_);
size_t stride = in.strides()[axis_];
compute_encoder->setBytes(&size, sizeof(size_t), 2);
compute_encoder->setBytes(&stride, sizeof(size_t), 3);
// Compute the thread grid
int n_reads = (in.itemsize() <= 4) ? 4 : 2;
int tile_x = 32;
int tile_y = 32;
int elements_per_tile_x = tile_x * n_reads;
int grid_y = in.size() / size / stride;
int grid_x = (stride + elements_per_tile_x - 1) / elements_per_tile_x;
MTL::Size grid_dims = MTL::Size(grid_x * tile_x, grid_y * tile_y, 1);
MTL::Size group_dims = MTL::Size(tile_x, tile_y, 1);
compute_encoder->dispatchThreads(grid_dims, group_dims);
}
if (copies.size() > 0) {
auto command_buffer = d.get_command_buffer(s.index);
command_buffer->addCompletedHandler(
[copies](MTL::CommandBuffer*) mutable { copies.clear(); });
}
}
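// Threadgroup sizing sketch for the contiguous branch above (illustrative
// numbers, not from a benchmark): a float32 scan over an axis of size 1000
// uses n_reads = 4 and elements_per_simd = 128. Since 1000 < 4 * 1024, the
// group size is rounded up to a multiple of 128:
//   thread_group_size = ceil(1000 / 128) * 128 = 1024,
// then clamped to maxTotalThreadsPerThreadgroup. Each scanned row gets one
// threadgroup, so thread_groups = in.size() / 1000.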
} // namespace mlx::core

167
mlx/backend/metal/utils.h Normal file
View File

@@ -0,0 +1,167 @@
#pragma once
#include "mlx/array.h"
#include "mlx/backend/metal/device.h"
namespace mlx::core {
namespace {
void set_array_buffer(
MTL::ComputeCommandEncoder* compute_encoder,
MTL::ArgumentEncoder* enc,
const array& a,
int idx) {
auto a_buf = static_cast<const MTL::Buffer*>(a.buffer().ptr());
auto offset = a.data<char>() -
static_cast<char*>(const_cast<MTL::Buffer*>(a_buf)->contents());
enc->setBuffer(a_buf, offset, idx);
  // MTL::Resource usage through argument buffer needs to be explicitly
// flagged to enable hazard tracking
compute_encoder->useResource(a_buf, MTL::ResourceUsageRead);
}
void set_array_buffer(
MTL::ComputeCommandEncoder* enc,
const array& a,
int idx) {
auto a_buf = static_cast<const MTL::Buffer*>(a.buffer().ptr());
auto offset = a.data<char>() -
static_cast<char*>(const_cast<MTL::Buffer*>(a_buf)->contents());
enc->setBuffer(a_buf, offset, idx);
}
std::string type_to_name(const array& a) {
std::string tname;
switch (a.dtype()) {
case bool_:
tname = "bool_";
break;
case uint8:
tname = "uint8";
break;
case uint16:
tname = "uint16";
break;
case uint32:
tname = "uint32";
break;
case uint64:
tname = "uint64";
break;
case int8:
tname = "int8";
break;
case int16:
tname = "int16";
break;
case int32:
tname = "int32";
break;
case int64:
tname = "int64";
break;
case float16:
tname = "float16";
break;
case float32:
tname = "float32";
break;
case bfloat16:
tname = "bfloat16";
break;
case complex64:
tname = "complex64";
break;
}
return tname;
}
MTL::Size get_block_dims(int dim0, int dim1, int dim2) {
int pows[3] = {0, 0, 0};
int sum = 0;
while (true) {
int presum = sum;
// Check all the pows
if (dim0 >= (1 << (pows[0] + 1))) {
pows[0]++;
sum++;
}
if (sum == 10) {
break;
}
if (dim1 >= (1 << (pows[1] + 1))) {
pows[1]++;
sum++;
}
if (sum == 10) {
break;
}
if (dim2 >= (1 << (pows[2] + 1))) {
pows[2]++;
sum++;
}
if (sum == presum || sum == 10) {
break;
}
}
return MTL::Size{1ul << pows[0], 1ul << pows[1], 1ul << pows[2]};
}
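// Example trace (hypothetical inputs): get_block_dims(100, 3, 1) grows the
// per-axis powers of two round-robin while each axis can still double and
// the total stays under 2^10 threads. dim0 stops at 2^6 (128 > 100), dim1 at
// 2^1 (4 > 3), and dim2 stays at 2^0, so the result is MTL::Size{64, 2, 1}.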
// Collapse dims that are contiguous to possibly route to a better kernel
// e.g. for x = transpose(array({0, 1, 2, 3, 4, 5, 6, 7}, {2, 2, 2}), {2, 0, 1})
// should return {{2, 4}, {{1, 2}}}.
//
// When multiple arrays are passed they should all have the same shape. The
// collapsed axes are also the same so one shape is returned.
std::tuple<std::vector<int>, std::vector<std::vector<size_t>>>
collapse_contiguous_dims(const std::vector<array>& xs) {
// Make a vector that has axes separated with -1. Collapse all axes between
// -1.
std::vector<int> to_collapse;
if (xs[0].ndim() > 0) {
to_collapse.push_back(0);
for (int i = 1; i < xs[0].ndim(); i++) {
bool contiguous = true;
for (auto& x : xs) {
if (x.strides()[i] * x.shape()[i] != x.strides()[i - 1]) {
contiguous = false;
}
if (!contiguous) {
break;
}
}
if (!contiguous) {
to_collapse.push_back(-1);
}
to_collapse.push_back(i);
}
to_collapse.push_back(-1);
}
std::vector<int> out_shape;
std::vector<std::vector<size_t>> out_strides(xs.size());
for (int i = 0; i < to_collapse.size(); i++) {
int current_shape = xs[0].shape()[to_collapse[i]];
while (to_collapse[++i] != -1) {
current_shape *= xs[0].shape()[to_collapse[i]];
}
out_shape.push_back(current_shape);
for (int j = 0; j < xs.size(); j++) {
out_strides[j].push_back(xs[j].strides()[to_collapse[i - 1]]);
}
}
return std::make_tuple(out_shape, out_strides);
}
template <typename... Arrays>
std::tuple<std::vector<int>, std::vector<std::vector<size_t>>>
collapse_contiguous_dims(Arrays... xs) {
return collapse_contiguous_dims(
std::vector<array>{std::forward<Arrays>(xs)...});
}
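// Usage sketch, assuming the transposed array x from the comment above:
//   auto [shape, strides] = collapse_contiguous_dims(x);
//   // shape   == {2, 4}
//   // strides == {{1, 2}}
// The two trailing axes of the transposed (2, 2, 2) array merge because
// strides[2] * shape[2] == strides[1].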
} // namespace
} // namespace mlx::core

View File

@@ -0,0 +1,18 @@
#include <stdexcept>
#include "mlx/backend/metal/metal.h"
namespace mlx::core::metal {
void new_stream(Stream) {}
std::function<void()> make_task(
array& arr,
std::vector<std::shared_future<void>> deps,
std::shared_ptr<std::promise<void>> p,
bool retain_graph) {
throw std::runtime_error(
"[metal::make_task] Cannot make GPU task without metal backend");
}
} // namespace mlx::core::metal

27
mlx/device.h Normal file
View File

@@ -0,0 +1,27 @@
#pragma once
namespace mlx::core {
struct Device {
enum class DeviceType {
cpu,
gpu,
};
static constexpr DeviceType cpu = DeviceType::cpu;
static constexpr DeviceType gpu = DeviceType::gpu;
Device(DeviceType type, int index = 0) : type(type), index(index){};
DeviceType type;
int index;
};
const Device& default_device();
void set_default_device(const Device& d);
bool operator==(const Device& lhs, const Device& rhs);
bool operator!=(const Device& lhs, const Device& rhs);
} // namespace mlx::core

205
mlx/dtype.cpp Normal file
View File

@@ -0,0 +1,205 @@
#include <cstdint>
#include <sstream>
#include <vector>
#include "mlx/dtype.h"
#include "mlx/utils.h"
namespace mlx::core {
namespace {
static constexpr int num_types = 13;
static constexpr Dtype::Kind type_kinds[num_types] = {
Dtype::Kind::b, // bool_,
Dtype::Kind::u, // uint8,
Dtype::Kind::u, // uint16,
Dtype::Kind::u, // uint32,
Dtype::Kind::u, // uint64,
Dtype::Kind::i, // int8,
Dtype::Kind::i, // int16,
Dtype::Kind::i, // int32,
Dtype::Kind::i, // int64,
Dtype::Kind::f, // float16,
Dtype::Kind::f, // float32,
Dtype::Kind::V, // bfloat16,
Dtype::Kind::c // complex64,
};
// Following Jax type promotion rules:
// https://jax.readthedocs.io/en/latest/type_promotion.html
// clang-format off
static constexpr Dtype type_rules[num_types][num_types] = {
// bool uint8 uint16 uint32 uint64 int8 int16 int32 int64 float16 float32 bfloat16 complex64
{bool_, uint8, uint16, uint32, uint64, int8, int16, int32, int64, float16, float32, bfloat16, complex64}, // bool
{uint8, uint8, uint16, uint32, uint64, int16, int16, int32, int64, float16, float32, bfloat16, complex64}, // uint8
{uint16, uint16, uint16, uint32, uint64, int32, int32, int32, int64, float16, float32, bfloat16, complex64}, // uint16
{uint32, uint32, uint32, uint32, uint64, int64, int64, int64, int64, float16, float32, bfloat16, complex64}, // uint32
{uint64, uint64, uint64, uint64, uint64, float32, float32, float32, float32, float16, float32, bfloat16, complex64}, // uint64
{int8, int16, int32, int64, float32, int8, int16, int32, int64, float16, float32, bfloat16, complex64}, // int8
{int16, int16, int32, int64, float32, int16, int16, int32, int64, float16, float32, bfloat16, complex64}, // int16
{int32, int32, int32, int64, float32, int32, int32, int32, int64, float16, float32, bfloat16, complex64}, // int32
{int64, int64, int64, int64, float32, int64, int64, int64, int64, float16, float32, bfloat16, complex64}, // int64
{float16, float16, float16, float16, float16, float16, float16, float16, float16, float16, float32, float32, complex64}, // float16
{float32, float32, float32, float32, float32, float32, float32, float32, float32, float32, float32, float32, complex64}, // float32
{bfloat16, bfloat16, bfloat16, bfloat16, bfloat16, bfloat16, bfloat16, bfloat16, bfloat16, float32, float32, bfloat16, complex64}, // bfloat16
{complex64, complex64, complex64, complex64, complex64, complex64, complex64, complex64, complex64, complex64, complex64, complex64, complex64}, // complex64
};
// clang-format on
inline bool is_big_endian() {
union ByteOrder {
int32_t i;
uint8_t c[4];
};
ByteOrder b = {0x01234567};
return b.c[0] == 0x01;
}
} // namespace
Dtype promote_types(const Dtype& t1, const Dtype& t2) {
return Dtype(type_rules[static_cast<int>(t1.val)][static_cast<int>(t2.val)]);
}
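// A few sample lookups in the table above (illustrative, not exhaustive):
//   promote_types(int8, uint8)       == int16    // smallest signed type holding both
//   promote_types(uint64, int32)     == float32  // no integer type holds both ranges
//   promote_types(float16, bfloat16) == float32
//   promote_types(bool_, float32)    == float32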
Dtype::Kind kindof(const Dtype& t) {
return type_kinds[static_cast<int>(t.val)];
}
template <>
TypeToDtype<bool>::operator Dtype() {
return bool_;
}
template <>
TypeToDtype<uint8_t>::operator Dtype() {
return uint8;
}
template <>
TypeToDtype<uint16_t>::operator Dtype() {
return uint16;
}
template <>
TypeToDtype<uint32_t>::operator Dtype() {
return uint32;
}
template <>
TypeToDtype<uint64_t>::operator Dtype() {
return uint64;
}
template <>
TypeToDtype<int8_t>::operator Dtype() {
return int8;
}
template <>
TypeToDtype<int16_t>::operator Dtype() {
return int16;
}
template <>
TypeToDtype<int32_t>::operator Dtype() {
return int32;
}
template <>
TypeToDtype<int64_t>::operator Dtype() {
return int64;
}
template <>
TypeToDtype<float16_t>::operator Dtype() {
return float16;
}
template <>
TypeToDtype<float>::operator Dtype() {
return float32;
}
template <>
TypeToDtype<double>::operator Dtype() {
return float32;
}
template <>
TypeToDtype<bfloat16_t>::operator Dtype() {
return bfloat16;
}
template <>
TypeToDtype<complex64_t>::operator Dtype() {
return complex64;
}
// Array protocol typestring for Dtype
std::string dtype_to_array_protocol(const Dtype& t) {
std::ostringstream r;
if (size_of(t) > 1)
r << (is_big_endian() ? ">" : "<");
else
r << "|";
r << kindof(t) << (int)size_of(t);
return r.str();
}
// Dtype from array protocol type string
Dtype dtype_from_array_protocol(const std::string& t) {
if (t.length() == 2 || t.length() == 3) {
std::string r = t.length() == 3 ? t.substr(1, 2) : t;
if (r == "V2") {
return bfloat16;
}
uint8_t size = r[1] - '0';
switch (r[0]) {
case 'b': {
if (size == 1)
return bool_;
}
case 'i': {
if (size == 1)
return int8;
else if (size == 2)
return int16;
else if (size == 4)
return int32;
else if (size == 8)
return int64;
}
case 'u': {
if (size == 1)
return uint8;
else if (size == 2)
return uint16;
else if (size == 4)
return uint32;
else if (size == 8)
return uint64;
}
case 'f': {
if (size == 2)
return float16;
else if (size == 4)
return float32;
}
case 'c': {
return complex64;
}
}
}
throw std::invalid_argument(
"[from_str] Invalid array protocol type-string: " + t);
}
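// Round-trip sketch (typestrings shown for a little-endian machine, and
// assuming the Dtype::Kind stream operator prints the single-letter codes):
//   dtype_to_array_protocol(float32) == "<f4"
//   dtype_to_array_protocol(uint8)   == "|u1"   // 1-byte types get no byte order
//   dtype_from_array_protocol("<i8") == int64
//   dtype_from_array_protocol("V2")  == bfloat16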
} // namespace mlx::core

144
mlx/graph_utils.cpp Normal file
View File

@@ -0,0 +1,144 @@
#include <functional>
#include <optional>
#include <sstream>
#include <unordered_map>
#include <unordered_set>
#include "mlx/graph_utils.h"
#include "mlx/primitives.h"
#include "mlx/utils.h"
namespace mlx::core {
using OptionalArrayRef = std::optional<std::reference_wrapper<const array>>;
struct ArrayNames {
std::unordered_map<std::uintptr_t, std::string> names;
std::string get_name(const array& x) {
auto it = names.find(x.id());
if (it == names.end()) {
// Get the next name in the sequence
// [A, B, ..., Z, AA, AB, ...]
std::vector<char> letters;
auto var_num = names.size() + 1;
while (var_num > 0) {
letters.push_back('A' + (var_num - 1) % 26);
var_num = (var_num - 1) / 26;
}
std::string name(letters.rbegin(), letters.rend());
names.insert({x.id(), name});
return name;
}
return it->second;
}
};
void depth_first_traversal(
std::function<void(OptionalArrayRef, const array&, int)> callback,
const std::vector<array>& outputs) {
std::function<void(OptionalArrayRef, const array&, int)> recurse;
std::unordered_set<std::uintptr_t> cache;
recurse = [&](OptionalArrayRef parent, const array& x, int input_index) {
auto id = x.id();
if (cache.find(id) != cache.end()) {
return;
}
cache.insert(id);
for (int i = 0; i < x.inputs().size(); i++) {
recurse(x, x.inputs()[i], i);
}
callback(parent, x, input_index);
};
for (auto x : outputs) {
recurse(std::nullopt, x, 0);
}
}
void depth_first_traversal(
std::function<void(const array&)> callback,
const std::vector<array>& outputs) {
depth_first_traversal(
[&callback](OptionalArrayRef p, const array& x, int input_index) {
callback(x);
},
outputs);
}
void print_graph(std::ostream& os, const std::vector<array>& outputs) {
std::vector<array> tape;
std::vector<array> inputs;
depth_first_traversal(
[&](const array& x) {
if (x.has_primitive()) {
tape.push_back(x);
} else {
inputs.push_back(x);
}
},
outputs);
ArrayNames namer;
auto print_arr = [&namer, &os](const array& a) {
os << namer.get_name(a);
os << " [" << a.shape() << ", " << a.dtype() << "]";
};
auto print_arrs = [&](const std::vector<array>& arrs) {
for (auto& arr : arrs) {
print_arr(arr);
if (&arr != &arrs.back()) {
os << ", ";
}
}
};
os << "Inputs: ";
print_arrs(inputs);
os << "\nOutputs: ";
print_arrs(outputs);
os << "\n";
for (auto& arr : tape) {
arr.primitive().print(os);
os << " ";
print_arrs(arr.inputs());
os << " -> ";
print_arr(arr);
os << "\n";
}
}
void export_to_dot(std::ostream& os, const std::vector<array>& outputs) {
os << "digraph {" << std::endl;
ArrayNames namer;
depth_first_traversal(
[&namer, &os](auto parent, const array& x, int input_index) {
os << "{ ";
if (!x.has_primitive()) {
os << "rank=source; ";
}
if (!parent) {
os << "rank=sink; ";
}
os << namer.get_name(x);
if (x.has_primitive()) {
os << " [label =\"";
x.primitive().print(os);
os << "\"]";
}
os << "; }" << std::endl;
for (auto c : x.inputs()) {
os << namer.get_name(c) << " -> " << namer.get_name(x) << std::endl;
}
},
outputs);
os << "}";
}
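// Usage sketch (hypothetical graph; only the graph structure is traversed,
// so the arrays need not be evaluated):
//   auto x = ones({2, 2});
//   auto y = multiply(add(x, x), x);
//   print_graph(std::cout, {y});   // tape of primitives with named arrays
//   std::ofstream f("graph.dot");
//   export_to_dot(f, {y});         // DOT source for graphviz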
} // namespace mlx::core

112
mlx/load.h Normal file
View File

@@ -0,0 +1,112 @@
#pragma once
#include <fstream>
#include <istream>
#include <memory>
namespace mlx::core {
namespace io {
class Reader {
public:
virtual bool is_open() const = 0;
virtual bool good() const = 0;
virtual size_t tell() const = 0;
virtual void seek(
int64_t off,
std::ios_base::seekdir way = std::ios_base::beg) = 0;
virtual void read(char* data, size_t n) = 0;
virtual std::string label() const = 0;
};
class Writer {
public:
virtual bool is_open() const = 0;
virtual bool good() const = 0;
virtual size_t tell() const = 0;
virtual void seek(
int64_t off,
std::ios_base::seekdir way = std::ios_base::beg) = 0;
virtual void write(const char* data, size_t n) = 0;
virtual std::string label() const = 0;
};
class FileReader : public Reader {
public:
explicit FileReader(const std::shared_ptr<std::ifstream>& is)
: is_(is), label_("stream") {}
explicit FileReader(const std::string& file_path)
: is_(std::make_shared<std::ifstream>(file_path, std::ios::binary)),
label_(file_path) {}
bool is_open() const override {
return is_->is_open();
}
bool good() const override {
return is_->good();
}
size_t tell() const override {
return is_->tellg();
}
void seek(int64_t off, std::ios_base::seekdir way = std::ios_base::beg)
override {
is_->seekg(off, way);
}
void read(char* data, size_t n) override {
is_->read(data, n);
}
std::string label() const override {
return "file " + label_;
}
private:
std::shared_ptr<std::ifstream> is_;
std::string label_;
};
class FileWriter : public Writer {
public:
explicit FileWriter(const std::shared_ptr<std::ofstream>& is)
: os_(is), label_("stream") {}
explicit FileWriter(const std::string& file_path)
: os_(std::make_shared<std::ofstream>(file_path, std::ios::binary)),
label_(file_path) {}
bool is_open() const override {
return os_->is_open();
}
bool good() const override {
return os_->good();
}
size_t tell() const override {
return os_->tellp();
}
void seek(int64_t off, std::ios_base::seekdir way = std::ios_base::beg)
override {
os_->seekp(off, way);
}
void write(const char* data, size_t n) override {
os_->write(data, n);
}
std::string label() const override {
return "file " + label_;
}
private:
std::shared_ptr<std::ofstream> os_;
std::string label_;
};
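// Usage sketch (paths, buffers, and sizes are placeholders):
//   FileWriter w("weights.npy");
//   if (w.is_open() && w.good()) {
//     w.write(reinterpret_cast<const char*>(data), nbytes);
//   }
//   FileReader r("weights.npy");
//   r.seek(0);
//   r.read(buffer, nbytes);
// These are the concrete io::Reader / io::Writer implementations, presumably
// what the file-path overloads of save() and load() construct internally.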
} // namespace io
} // namespace mlx::core

11
mlx/mlx.h Normal file
View File

@@ -0,0 +1,11 @@
#pragma once
#include "mlx/array.h"
#include "mlx/backend/metal/metal.h"
#include "mlx/device.h"
#include "mlx/fft.h"
#include "mlx/ops.h"
#include "mlx/random.h"
#include "mlx/stream.h"
#include "mlx/transforms.h"
#include "mlx/utils.h"

932
mlx/ops.h Normal file
View File

@@ -0,0 +1,932 @@
#pragma once
#include <variant>
#include "array.h"
#include "device.h"
#include "load.h"
#include "stream.h"
namespace mlx::core {
using StreamOrDevice = std::variant<std::monostate, Stream, Device>;
Stream to_stream(StreamOrDevice s);
/** Creation operations */
/**
* A 1D array of numbers starting at `start` (optional),
 * stopping at `stop`, stepping by `step` (optional). **/
array arange(
double start,
double stop,
double step,
Dtype dtype,
StreamOrDevice s = {});
array arange(double start, double stop, double step, StreamOrDevice s = {});
array arange(double start, double stop, Dtype dtype, StreamOrDevice s = {});
array arange(double start, double stop, StreamOrDevice s = {});
array arange(double stop, Dtype dtype, StreamOrDevice s = {});
array arange(double stop, StreamOrDevice s = {});
array arange(int start, int stop, int step, StreamOrDevice s = {});
array arange(int start, int stop, StreamOrDevice s = {});
array arange(int stop, StreamOrDevice s = {});
/** Convert an array to the given data type. */
array astype(const array& a, Dtype dtype, StreamOrDevice s = {});
/** Create a view of an array with the given shape and strides. */
array as_strided(
const array& a,
std::vector<int> shape,
std::vector<size_t> strides,
size_t offset,
StreamOrDevice s = {});
/** Copy another array. */
array copy(const array& a, StreamOrDevice s = {});
/** Fill an array of the given shape with the given value(s). */
array full(
const std::vector<int>& shape,
const array& vals,
Dtype dtype,
StreamOrDevice s = {});
array full(
const std::vector<int>& shape,
const array& vals,
StreamOrDevice s = {});
template <typename T>
array full(
const std::vector<int>& shape,
T val,
Dtype dtype,
StreamOrDevice s = {}) {
return full(shape, array(val, dtype), to_stream(s));
}
template <typename T>
array full(const std::vector<int>& shape, T val, StreamOrDevice s = {}) {
return full(shape, array(val), to_stream(s));
}
/** Fill an array of the given shape with zeros. */
array zeros(const std::vector<int>& shape, Dtype dtype, StreamOrDevice s = {});
inline array zeros(const std::vector<int>& shape, StreamOrDevice s = {}) {
return zeros(shape, float32, s);
}
array zeros_like(const array& a, StreamOrDevice s = {});
/** Fill an array of the given shape with ones. */
array ones(const std::vector<int>& shape, Dtype dtype, StreamOrDevice s = {});
inline array ones(const std::vector<int>& shape, StreamOrDevice s = {}) {
return ones(shape, float32, s);
}
array ones_like(const array& a, StreamOrDevice s = {});
/** array manipulation */
/** Reshape an array to the given shape. */
array reshape(const array& a, std::vector<int> shape, StreamOrDevice s = {});
/** Remove singleton dimensions at the given axes. */
array squeeze(
const array& a,
const std::vector<int>& axes,
StreamOrDevice s = {});
/** Remove singleton dimensions at the given axis. */
inline array squeeze(const array& a, int axis, StreamOrDevice s = {}) {
return squeeze(a, std::vector<int>{axis}, s);
}
/** Remove all singleton dimensions. */
array squeeze(const array& a, StreamOrDevice s = {});
/** Add a singleton dimension at the given axes. */
array expand_dims(
const array& a,
const std::vector<int>& axes,
StreamOrDevice s = {});
/** Add a singleton dimension at the given axis. */
inline array expand_dims(const array& a, int axis, StreamOrDevice s = {}) {
return expand_dims(a, std::vector<int>{axis}, s);
}
/** Slice an array. */
array slice(
const array& a,
std::vector<int> start,
std::vector<int> stop,
std::vector<int> strides,
StreamOrDevice s = {});
/** Slice an array with a stride of 1 in each dimension. */
array slice(
const array& a,
const std::vector<int>& start,
const std::vector<int>& stop,
StreamOrDevice s = {});
/** Split an array into sub-arrays along a given axis. */
std::vector<array>
split(const array& a, int num_splits, int axis, StreamOrDevice s = {});
std::vector<array> split(const array& a, int num_splits, StreamOrDevice s = {});
std::vector<array> split(
const array& a,
const std::vector<int>& indices,
int axis,
StreamOrDevice s = {});
std::vector<array>
split(const array& a, const std::vector<int>& indices, StreamOrDevice s = {});
/** Concatenate arrays along a given axis. */
array concatenate(
const std::vector<array>& arrays,
int axis,
StreamOrDevice s = {});
array concatenate(const std::vector<array>& arrays, StreamOrDevice s = {});
/** Permutes the dimensions according to the given axes. */
array transpose(const array& a, std::vector<int> axes, StreamOrDevice s = {});
inline array transpose(
const array& a,
std::initializer_list<int> axes,
StreamOrDevice s = {}) {
return transpose(a, std::vector<int>(axes), s);
}
/** Pad an array with a constant value */
array pad(
const array& a,
const std::vector<int>& axes,
const std::vector<int>& low_pad_size,
const std::vector<int>& high_pad_size,
const array& pad_value = array(0),
StreamOrDevice s = {});
/** Pad an array with a constant value along all axes */
array pad(
const array& a,
const std::vector<std::pair<int, int>>& pad_width,
const array& pad_value = array(0),
StreamOrDevice s = {});
array pad(
const array& a,
const std::pair<int, int>& pad_width,
const array& pad_value = array(0),
StreamOrDevice s = {});
array pad(
const array& a,
int pad_width,
const array& pad_value = array(0),
StreamOrDevice s = {});
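// Example of the axes/low/high form above (shapes assumed for illustration):
// for an input of shape (2, 3),
//   pad(a, {0, 1}, {1, 0}, {1, 2})   // -> shape (4, 5)
// adds one row before and after axis 0 and two columns after axis 1, filled
// with pad_value (0 by default). The pad_width overloads are convenience
// forms of the same operation.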
/** Permutes the dimensions in reverse order. */
array transpose(const array& a, StreamOrDevice s = {});
/** Broadcast an array to a given shape. */
array broadcast_to(
const array& a,
const std::vector<int>& shape,
StreamOrDevice s = {});
/** Broadcast a vector of arrays against one another. */
std::vector<array> broadcast_arrays(
const std::vector<array>& inputs,
StreamOrDevice s = {});
/** Comparison operations */
/** Returns the bool array with (a == b) element-wise. */
array equal(const array& a, const array& b, StreamOrDevice s = {});
inline array operator==(const array& a, const array& b) {
return equal(a, b);
}
template <typename T>
array operator==(T a, const array& b) {
return equal(array(a), b);
}
template <typename T>
array operator==(const array& a, T b) {
return equal(a, array(b));
}
/** Returns the bool array with (a != b) element-wise. */
array not_equal(const array& a, const array& b, StreamOrDevice s = {});
inline array operator!=(const array& a, const array& b) {
return not_equal(a, b);
}
template <typename T>
array operator!=(T a, const array& b) {
return not_equal(array(a), b);
}
template <typename T>
array operator!=(const array& a, T b) {
return not_equal(a, array(b));
}
/** Returns bool array with (a > b) element-wise. */
array greater(const array& a, const array& b, StreamOrDevice s = {});
inline array operator>(const array& a, const array& b) {
return greater(a, b);
}
template <typename T>
array operator>(T a, const array& b) {
return greater(array(a), b);
}
template <typename T>
array operator>(const array& a, T b) {
return greater(a, array(b));
}
/** Returns bool array with (a >= b) element-wise. */
array greater_equal(const array& a, const array& b, StreamOrDevice s = {});
inline array operator>=(const array& a, const array& b) {
return greater_equal(a, b);
}
template <typename T>
array operator>=(T a, const array& b) {
return greater_equal(array(a), b);
}
template <typename T>
array operator>=(const array& a, T b) {
return greater_equal(a, array(b));
}
/** Returns bool array with (a < b) element-wise. */
array less(const array& a, const array& b, StreamOrDevice s = {});
inline array operator<(const array& a, const array& b) {
return less(a, b);
}
template <typename T>
array operator<(T a, const array& b) {
return less(array(a), b);
}
template <typename T>
array operator<(const array& a, T b) {
return less(a, array(b));
}
/** Returns bool array with (a <= b) element-wise. */
array less_equal(const array& a, const array& b, StreamOrDevice s = {});
inline array operator<=(const array& a, const array& b) {
return less_equal(a, b);
}
template <typename T>
array operator<=(T a, const array& b) {
return less_equal(array(a), b);
}
template <typename T>
array operator<=(const array& a, T b) {
return less_equal(a, array(b));
}
/** True if two arrays have the same shape and elements. */
array array_equal(
const array& a,
const array& b,
bool equal_nan,
StreamOrDevice s = {});
inline array
array_equal(const array& a, const array& b, StreamOrDevice s = {}) {
return array_equal(a, b, false, s);
}
/** Select from x or y depending on condition. */
array where(
const array& condition,
const array& x,
const array& y,
StreamOrDevice s = {});
/** Reduction operations */
/** True if all elements in the array are true (or non-zero). **/
array all(const array& a, bool keepdims, StreamOrDevice s = {});
inline array all(const array& a, StreamOrDevice s = {}) {
return all(a, false, to_stream(s));
}
/** True if the two arrays are equal within the specified tolerance. */
array allclose(
const array& a,
const array& b,
double rtol = 1e-5,
double atol = 1e-8,
StreamOrDevice s = {});
/**
* Reduces the input along the given axes. An output value is true
* if all the corresponding inputs are true.
**/
array all(
const array& a,
const std::vector<int>& axes,
bool keepdims = false,
StreamOrDevice s = {});
/**
* Reduces the input along the given axis. An output value is true
* if all the corresponding inputs are true.
**/
array all(
const array& a,
int axis,
bool keepdims = false,
StreamOrDevice s = {});
/** True if any elements in the array are true (or non-zero). **/
array any(const array& a, bool keepdims, StreamOrDevice s = {});
inline array any(const array& a, StreamOrDevice s = {}) {
return any(a, false, to_stream(s));
}
/**
* Reduces the input along the given axes. An output value is true
* if any of the corresponding inputs are true.
**/
array any(
const array& a,
const std::vector<int>& axes,
bool keepdims = false,
StreamOrDevice s = {});
/**
* Reduces the input along the given axis. An output value is true
* if any of the corresponding inputs are true.
**/
array any(
const array& a,
int axis,
bool keepdims = false,
StreamOrDevice s = {});
/** Sums the elements of an array. */
array sum(const array& a, bool keepdims, StreamOrDevice s = {});
inline array sum(const array& a, StreamOrDevice s = {}) {
return sum(a, false, to_stream(s));
}
/** Sums the elements of an array along the given axes. */
array sum(
const array& a,
const std::vector<int>& axes,
bool keepdims = false,
StreamOrDevice s = {});
/** Sums the elements of an array along the given axis. */
array sum(
const array& a,
int axis,
bool keepdims = false,
StreamOrDevice s = {});
/** Computes the mean of the elements of an array. */
array mean(const array& a, bool keepdims, StreamOrDevice s = {});
inline array mean(const array& a, StreamOrDevice s = {}) {
return mean(a, false, to_stream(s));
}
/** Computes the mean of the elements of an array along the given axes */
array mean(
const array& a,
const std::vector<int>& axes,
bool keepdims = false,
StreamOrDevice s = {});
/** Computes the mean of the elements of an array along the given axis */
array mean(
const array& a,
int axis,
bool keepdims = false,
StreamOrDevice s = {});
/** Computes the variance of the elements of an array. */
array var(const array& a, bool keepdims, int ddof = 0, StreamOrDevice s = {});
inline array var(const array& a, StreamOrDevice s = {}) {
return var(a, false, 0, to_stream(s));
}
/** Computes the var of the elements of an array along the given axes */
array var(
const array& a,
const std::vector<int>& axes,
bool keepdims = false,
int ddof = 0,
StreamOrDevice s = {});
/** Computes the var of the elements of an array along the given axis */
array var(
const array& a,
int axis,
bool keepdims = false,
int ddof = 0,
StreamOrDevice s = {});
/** The product of all elements of the array. */
array prod(const array& a, bool keepdims, StreamOrDevice s = {});
inline array prod(const array& a, StreamOrDevice s = {}) {
return prod(a, false, to_stream(s));
}
/** The product of the elements of an array along the given axes. */
array prod(
const array& a,
const std::vector<int>& axes,
bool keepdims = false,
StreamOrDevice s = {});
/** The product of the elements of an array along the given axis. */
array prod(
const array& a,
int axis,
bool keepdims = false,
StreamOrDevice s = {});
/** The maximum of all elements of the array. */
array max(const array& a, bool keepdims, StreamOrDevice s = {});
inline array max(const array& a, StreamOrDevice s = {}) {
return max(a, false, to_stream(s));
}
/** The maximum of the elements of an array along the given axes. */
array max(
const array& a,
const std::vector<int>& axes,
bool keepdims = false,
StreamOrDevice s = {});
/** The maximum of the elements of an array along the given axis. */
array max(
const array& a,
int axis,
bool keepdims = false,
StreamOrDevice s = {});
/** The minimum of all elements of the array. */
array min(const array& a, bool keepdims, StreamOrDevice s = {});
inline array min(const array& a, StreamOrDevice s = {}) {
return min(a, false, to_stream(s));
}
/** The minimum of the elements of an array along the given axes. */
array min(
const array& a,
const std::vector<int>& axes,
bool keepdims = false,
StreamOrDevice s = {});
/** The minimum of the elements of an array along the given axis. */
array min(
const array& a,
int axis,
bool keepdims = false,
StreamOrDevice s = {});
/** Returns the index of the minimum value in the array. */
array argmin(const array& a, bool keepdims, StreamOrDevice s = {});
inline array argmin(const array& a, StreamOrDevice s = {}) {
return argmin(a, false, s);
}
/** Returns the indices of the minimum values along a given axis. */
array argmin(
const array& a,
int axis,
bool keepdims = false,
StreamOrDevice s = {});
/** Returns the index of the maximum value in the array. */
array argmax(const array& a, bool keepdims, StreamOrDevice s = {});
inline array argmax(const array& a, StreamOrDevice s = {}) {
return argmax(a, false, s);
}
/** Returns the indices of the maximum values along a given axis. */
array argmax(
const array& a,
int axis,
bool keepdims = false,
StreamOrDevice s = {});
/** Returns a sorted copy of the flattened array. */
array sort(const array& a, StreamOrDevice s = {});
/** Returns a sorted copy of the array along a given axis. */
array sort(const array& a, int axis, StreamOrDevice s = {});
/** Returns indices that sort the flattened array. */
array argsort(const array& a, StreamOrDevice s = {});
/** Returns indices that sort the array along a given axis. */
array argsort(const array& a, int axis, StreamOrDevice s = {});
/**
* Returns a partitioned copy of the flattened array
* such that the smaller kth elements are first.
**/
array partition(const array& a, int kth, StreamOrDevice s = {});
/**
* Returns a partitioned copy of the array along a given axis
* such that the smaller kth elements are first.
**/
array partition(const array& a, int kth, int axis, StreamOrDevice s = {});
/**
* Returns indices that partition the flattened array
* such that the smaller kth elements are first.
**/
array argpartition(const array& a, int kth, StreamOrDevice s = {});
/**
* Returns indices that partition the array along a given axis
* such that the smaller kth elements are first.
**/
array argpartition(const array& a, int kth, int axis, StreamOrDevice s = {});
/** Returns topk elements of the flattened array. */
array topk(const array& a, int k, StreamOrDevice s = {});
/** Returns topk elements of the array along a given axis. */
array topk(const array& a, int k, int axis, StreamOrDevice s = {});
/** The logsumexp of all elements of the array. */
array logsumexp(const array& a, bool keepdims, StreamOrDevice s = {});
inline array logsumexp(const array& a, StreamOrDevice s = {}) {
return logsumexp(a, false, to_stream(s));
}
/** The logsumexp of the elements of an array along the given axes. */
array logsumexp(
const array& a,
const std::vector<int>& axes,
bool keepdims = false,
StreamOrDevice s = {});
/** The logsumexp of the elements of an array along the given axis. */
array logsumexp(
const array& a,
int axis,
bool keepdims = false,
StreamOrDevice s = {});
/** Simple arithmetic operations */
/** Absolute value of elements in an array. */
array abs(const array& a, StreamOrDevice s = {});
/** Negate an array. */
array negative(const array& a, StreamOrDevice s = {});
array operator-(const array& a);
/** The sign of the elements in an array. */
array sign(const array& a, StreamOrDevice s = {});
/** Logical not of an array */
array logical_not(const array& a, StreamOrDevice s = {});
/** The reciprocal (1/x) of the elements in an array. */
array reciprocal(const array& a, StreamOrDevice s = {});
/** Add two arrays. */
array add(const array& a, const array& b, StreamOrDevice s = {});
array operator+(const array& a, const array& b);
template <typename T>
array operator+(T a, const array& b) {
return add(array(a), b);
}
template <typename T>
array operator+(const array& a, T b) {
return add(a, array(b));
}
/** Subtract two arrays. */
array subtract(const array& a, const array& b, StreamOrDevice s = {});
array operator-(const array& a, const array& b);
template <typename T>
array operator-(T a, const array& b) {
return subtract(array(a), b);
}
template <typename T>
array operator-(const array& a, T b) {
return subtract(a, array(b));
}
/** Multiply two arrays. */
array multiply(const array& a, const array& b, StreamOrDevice s = {});
array operator*(const array& a, const array& b);
template <typename T>
array operator*(T a, const array& b) {
return multiply(array(a), b);
}
template <typename T>
array operator*(const array& a, T b) {
return multiply(a, array(b));
}
/** Divide two arrays. */
array divide(const array& a, const array& b, StreamOrDevice s = {});
array operator/(const array& a, const array& b);
array operator/(double a, const array& b);
array operator/(const array& a, double b);
/** Element-wise maximum between two arrays. */
array maximum(const array& a, const array& b, StreamOrDevice s = {});
/** Element-wise minimum between two arrays. */
array minimum(const array& a, const array& b, StreamOrDevice s = {});
/** Square the elements of an array. */
array square(const array& a, StreamOrDevice s = {});
/** Exponential of the elements of an array. */
array exp(const array& a, StreamOrDevice s = {});
/** Sine of the elements of an array */
array sin(const array& a, StreamOrDevice s = {});
/** Cosine of the elements of an array */
array cos(const array& a, StreamOrDevice s = {});
/** Tangent of the elements of an array */
array tan(const array& a, StreamOrDevice s = {});
/** Arc Sine of the elements of an array */
array arcsin(const array& a, StreamOrDevice s = {});
/** Arc Cosine of the elements of an array */
array arccos(const array& a, StreamOrDevice s = {});
/** Arc Tangent of the elements of an array */
array arctan(const array& a, StreamOrDevice s = {});
/** Hyperbolic Sine of the elements of an array */
array sinh(const array& a, StreamOrDevice s = {});
/** Hyperbolic Cosine of the elements of an array */
array cosh(const array& a, StreamOrDevice s = {});
/** Hyperbolic Tangent of the elements of an array */
array tanh(const array& a, StreamOrDevice s = {});
/** Inverse Hyperbolic Sine of the elements of an array */
array arcsinh(const array& a, StreamOrDevice s = {});
/** Inverse Hyperbolic Cosine of the elements of an array */
array arccosh(const array& a, StreamOrDevice s = {});
/** Inverse Hyperbolic Tangent of the elements of an array */
array arctanh(const array& a, StreamOrDevice s = {});
/** Natural logarithm of the elements of an array. */
array log(const array& a, StreamOrDevice s = {});
/** Log base 2 of the elements of an array. */
array log2(const array& a, StreamOrDevice s = {});
/** Log base 10 of the elements of an array. */
array log10(const array& a, StreamOrDevice s = {});
/** Natural logarithm of one plus the elements of an array: `log(1 + a)`. */
array log1p(const array& a, StreamOrDevice s = {});
/** Element-wise log-add-exp of two arrays: `log(exp(a) + exp(b))`. */
array logaddexp(const array& a, const array& b, StreamOrDevice s = {});
/** Element-wise logistic sigmoid of the array: `1 / (1 + exp(-x))`. */
array sigmoid(const array& a, StreamOrDevice s = {});
/** Computes the error function of the elements of an array. */
array erf(const array& a, StreamOrDevice s = {});
/** Computes the inverse error function of the elements of an array. */
array erfinv(const array& a, StreamOrDevice s = {});
/** Stop the flow of gradients. */
array stop_gradient(const array& a, StreamOrDevice s = {});
/** Matrix-matrix multiplication. */
array matmul(const array& a, const array& b, StreamOrDevice s = {});
/** Gather array entries given indices and slices */
array gather(
const array& a,
const std::vector<array>& indices,
const std::vector<int>& axes,
const std::vector<int>& slice_sizes,
StreamOrDevice s = {});
inline array gather(
const array& a,
const array& indices,
int axis,
const std::vector<int>& slice_sizes,
StreamOrDevice s = {}) {
return gather(a, {indices}, std::vector<int>{axis}, slice_sizes, s);
}
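// Illustrative call (shapes assumed, using the usual convention that the
// result shape is indices.shape() + slice_sizes): for a of shape (10, 4) and
// integer indices idx of shape (3,),
//   gather(a, idx, 0, {1, 4})   // result shape (3, 1, 4)
// Each index contributes one slice of size slice_sizes; take(a, idx, 0) is
// roughly the squeezed version of this pattern.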
/** Take array slices at the given indices of the specified axis. */
array take(
const array& a,
const array& indices,
int axis,
StreamOrDevice s = {});
/** Take array entries at the given indices treating the array as flattened. */
array take(const array& a, const array& indices, StreamOrDevice s = {});
/** Take array entries given indices along the axis */
array take_along_axis(
const array& a,
const array& indices,
int axis,
StreamOrDevice s = {});
/** Scatter updates to given linear indices */
array scatter(
const array& a,
const std::vector<array>& indices,
const array& updates,
const std::vector<int>& axes,
StreamOrDevice s = {});
inline array scatter(
const array& a,
const array& indices,
const array& updates,
int axis,
StreamOrDevice s = {}) {
return scatter(a, {indices}, updates, std::vector<int>{axis}, s);
}
/** Scatter and add updates to given indices */
array scatter_add(
const array& a,
const std::vector<array>& indices,
const array& updates,
const std::vector<int>& axes,
StreamOrDevice s = {});
inline array scatter_add(
const array& a,
const array& indices,
const array& updates,
int axis,
StreamOrDevice s = {}) {
return scatter_add(a, {indices}, updates, std::vector<int>{axis}, s);
}
/** Scatter and prod updates to given indices */
array scatter_prod(
const array& a,
const std::vector<array>& indices,
const array& updates,
const std::vector<int>& axes,
StreamOrDevice s = {});
inline array scatter_prod(
const array& a,
const array& indices,
const array& updates,
int axis,
StreamOrDevice s = {}) {
return scatter_prod(a, {indices}, updates, std::vector<int>{axis}, s);
}
/** Scatter and max updates to given linear indices */
array scatter_max(
const array& a,
const std::vector<array>& indices,
const array& updates,
const std::vector<int>& axes,
StreamOrDevice s = {});
inline array scatter_max(
const array& a,
const array& indices,
const array& updates,
int axis,
StreamOrDevice s = {}) {
return scatter_max(a, {indices}, updates, std::vector<int>{axis}, s);
}
/** Scatter and min updates to given linear indices */
array scatter_min(
const array& a,
const std::vector<array>& indices,
const array& updates,
const std::vector<int>& axes,
StreamOrDevice s = {});
inline array scatter_min(
const array& a,
const array& indices,
const array& updates,
int axis,
StreamOrDevice s = {}) {
return scatter_min(a, {indices}, updates, std::vector<int>{axis}, s);
}
/** Square root the elements of an array. */
array sqrt(const array& a, StreamOrDevice s = {});
/** Reciprocal square root of the elements of an array. */
array rsqrt(const array& a, StreamOrDevice s = {});
/** Softmax of an array. */
array softmax(
const array& a,
const std::vector<int>& axes,
StreamOrDevice s = {});
/** Softmax of an array. */
array softmax(const array& a, StreamOrDevice s = {});
/** Softmax of an array. */
inline array softmax(const array& a, int axis, StreamOrDevice s = {}) {
return softmax(a, std::vector<int>{axis}, s);
}
/** Raise elements of a to the power of b element-wise */
array power(const array& a, const array& b, StreamOrDevice s = {});
inline array operator^(const array& a, const array& b) {
return power(a, b);
}
template <typename T>
array operator^(T a, const array& b) {
return power(array(a), b);
}
template <typename T>
array operator^(const array& a, T b) {
return power(a, array(b));
}
/** Cumulative sum of an array. */
array cumsum(
const array& a,
int axis,
bool reverse = false,
bool inclusive = true,
StreamOrDevice s = {});
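// Flag semantics sketch (1-D example, values assumed):
//   a = [1, 2, 3]
//   cumsum(a, 0)                                      // [1, 3, 6]  inclusive
//   cumsum(a, 0, /*reverse*/ false, /*inclusive*/ false)  // [0, 1, 3]  exclusive
//   cumsum(a, 0, /*reverse*/ true)                    // [6, 5, 3]  reversed
// cumprod / cummax / cummin below take the same reverse/inclusive flags.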
/** Cumulative product of an array. */
array cumprod(
const array& a,
int axis,
bool reverse = false,
bool inclusive = true,
StreamOrDevice s = {});
/** Cumulative max of an array. */
array cummax(
const array& a,
int axis,
bool reverse = false,
bool inclusive = true,
StreamOrDevice s = {});
/** Cumulative min of an array. */
array cummin(
const array& a,
int axis,
bool reverse = false,
bool inclusive = true,
StreamOrDevice s = {});
/** Convolution operations */
/** 1D convolution with a filter */
array conv1d(
const array& input,
const array& weight,
int stride = 1,
int padding = 0,
int dilation = 1,
int groups = 1,
StreamOrDevice s = {});
/** 2D convolution with a filter */
array conv2d(
const array& input,
const array& weight,
const std::pair<int, int>& stride = {1, 1},
const std::pair<int, int>& padding = {0, 0},
const std::pair<int, int>& dilation = {1, 1},
int groups = 1,
StreamOrDevice s = {});
/** Serialization operations */
/** Save array to out stream in .npy format */
void save(
std::shared_ptr<io::Writer> out_stream,
array a,
bool retain_graph = true);
/** Save array to file in .npy format */
void save(const std::string& file, array a, bool retain_graph = true);
/** Load array from reader in .npy format */
array load(std::shared_ptr<io::Reader> in_stream, StreamOrDevice s = {});
/** Load array from file in .npy format */
array load(const std::string& file, StreamOrDevice s = {});
} // namespace mlx::core

1527
mlx/primitives.h Normal file

File diff suppressed because it is too large

300
mlx/random.cpp Normal file
View File

@@ -0,0 +1,300 @@
#include <cmath>
#include <sstream>
#include "mlx/ops.h"
#include "mlx/primitives.h"
#include "mlx/random.h"
#include "mlx/utils.h"
namespace mlx::core::random {
KeySequence::KeySequence(uint64_t seed) : key_(key(seed)) {}
void KeySequence::seed(uint64_t seed) {
key_ = key((seed));
}
array KeySequence::next() {
auto out = split(key_);
key_ = out.first;
return out.second;
}
void seed(uint64_t seed) {
KeySequence::default_().seed(seed);
}
array key(uint64_t seed) {
uint32_t k1 = static_cast<uint32_t>(seed >> 32);
uint32_t k2 = static_cast<uint32_t>(seed);
return array({k1, k2});
}
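// Example: key(0x0000000100000002) -> array({1u, 2u}), a shape-(2) uint32
// array holding the high and low words of the 64-bit seed. Every random op
// either takes such a key explicitly or draws one from the default
// KeySequence.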
array bits(
const std::vector<int>& shape,
int width /* 4 */,
const std::optional<array>& key_ /*= nullopt */,
StreamOrDevice s /* = {} */) {
auto key = key_ ? *key_ : KeySequence::default_().next();
if (key.dtype() != uint32) {
std::ostringstream msg;
msg << "Expected key type uint32 but received " << key.dtype() << ".";
throw std::invalid_argument(msg.str());
}
if (key.shape() != std::vector<int>{2}) {
std::ostringstream msg;
msg << "Expected key shape (2) but received " << key.shape() << ".";
throw std::invalid_argument(msg.str());
}
auto get_dtype = [width]() {
switch (width) {
case 4:
return uint32;
case 2:
return uint16;
case 1:
return uint8;
default:
std::ostringstream msg;
msg << "[bits] Bit width must be in {1, 2, 4} but got " << width << ".";
throw std::invalid_argument(msg.str());
}
};
return array(
shape,
get_dtype(),
std::make_unique<RandomBits>(to_stream(s), shape, width),
{key});
}
std::pair<array, array> split(const array& key, StreamOrDevice s /* = {} */) {
auto stream = to_stream(s);
auto out = mlx::core::split(random::split(key, 2, stream), 2, stream);
return {reshape(out[0], {2}, stream), reshape(out[1], {2}, stream)};
}
array split(const array& key, int num, StreamOrDevice s /* = {} */) {
return bits({num, 2}, 4, key, s);
}
array uniform(
const array& low,
const array& high,
const std::vector<int>& shape,
Dtype dtype /* = float32 */,
const std::optional<array>& key /*= nullopt */,
StreamOrDevice s /* = {} */) {
if (!is_floating_point(dtype)) {
throw std::invalid_argument(
"Can only generate uniform numbers with floating point type.");
}
auto stream = to_stream(s);
auto range = subtract(high, low, stream);
auto out_shape = broadcast_shapes(shape, range.shape());
if (out_shape != shape) {
std::ostringstream msg;
msg << "Cannot generate random values of shape " << shape
<< " from broadcasted shape " << out_shape << ".";
throw std::invalid_argument(msg.str());
}
// Get random values between [0, nextafter(maxval, 0.0f)] since samples must
// be in [low, high)
// TODO replace minimum with modulo uint32_t(nextafter(maxval, 0.0f)) to avoid
// clipping effects
float maxval = std::numeric_limits<uint32_t>::max();
auto upper = array(std::nextafter(maxval, 0.0f), dtype);
auto out = minimum(bits(shape, size_of(dtype), key, stream), upper, stream);
out = divide(out, array(maxval, dtype), stream);
return add(multiply(range, out, stream), low, stream);
}
array uniform(
const std::vector<int>& shape,
Dtype dtype,
const std::optional<array>& key /*= nullopt */,
StreamOrDevice s /* = {} */) {
return uniform(
array(0.0, dtype), array(1.0, dtype), shape, dtype, key, to_stream(s));
}
array normal(
const std::vector<int>& shape,
Dtype dtype,
const std::optional<array>& key /*= nullopt */,
StreamOrDevice s /* = {} */) {
auto stream = to_stream(s);
auto low = array(std::nextafter(-1.0f, 0.0f), dtype);
auto high = array(1.0f, dtype);
auto samples = uniform(low, high, shape, dtype, key, stream);
return multiply(
array(std::sqrt(2.0), dtype), erfinv(samples, stream), stream);
}
array randint(
const array& low,
const array& high,
const std::vector<int>& shape,
Dtype dtype /* = int32 */,
const std::optional<array>& key /*= nullopt */,
StreamOrDevice s /* = {} */) {
if (!is_integral(dtype)) {
throw std::invalid_argument(
"[randint] randint only accepts integer dtypes and bool.");
}
auto u = uniform(low, high, shape, float32, key, s);
return astype(maximum(u, low, s), dtype, s);
}
array bernoulli(
const array& p,
const std::vector<int>& shape,
const std::optional<array>& key /*= nullopt */,
StreamOrDevice s /* = {} */) {
if (!is_floating_point(p.dtype())) {
throw std::invalid_argument(
"[bernoulli] bernoulli probability `p` must be a float type.");
}
auto res = uniform(shape, p.dtype(), key, s);
res = less(res, p, s);
if (res.shape() != shape) {
throw std::invalid_argument(
"[bernoulli] shape of `p` is incompatible with argument `shape`.");
}
return res;
}
array bernoulli(
const array& p,
const std::optional<array>& key /*= nullopt */,
StreamOrDevice s /* = {} */) {
return bernoulli(p, p.shape(), key, s);
}
array bernoulli(
const std::optional<array>& key /*= nullopt */,
StreamOrDevice s /* = {} */) {
return bernoulli(array(0.5f), key, s);
}
array truncated_normal(
const array& lower,
const array& upper,
const std::vector<int>& shape,
Dtype dtype /* = float32 */,
const std::optional<array>& key /*= nullopt */,
StreamOrDevice s /* = {} */) {
// Same as
// https://jax.readthedocs.io/en/latest/_modules/jax/_src/random.html#truncated_normal
if (!is_floating_point(dtype)) {
throw std::invalid_argument(
"[trunc_normal] trunc_normal only accepts floating point dtypes.");
}
auto sqrt2 = array(std::sqrt(2.0), dtype);
auto lower_t = astype(lower, dtype, s);
auto upper_t = astype(upper, dtype, s);
auto a = erf(divide(lower_t, sqrt2, s), s);
auto b = erf(divide(upper_t, sqrt2, s), s);
auto u = uniform(a, b, shape, dtype, key, s);
auto out = multiply(sqrt2, erfinv(u, s), s);
  // Clip in bounds
return maximum(minimum(upper_t, out, s), lower_t, s);
}
array truncated_normal(
const array& lower,
const array& upper,
Dtype dtype /* = float32 */,
const std::optional<array>& key /*= nullopt */,
StreamOrDevice s /* = {} */) {
auto shape = broadcast_shapes(lower.shape(), upper.shape());
return truncated_normal(lower, upper, shape, dtype, key, s);
}
array gumbel(
const std::vector<int>& shape,
Dtype dtype /* = float32 */,
const std::optional<array>& key /*= nullopt */,
StreamOrDevice s /* = {} */) {
// -log(-log(uniform(shape)))
return negative(
log(negative(log(uniform(shape, dtype, key, s), s), s), s), s);
}
int get_valid_axis(int axis, int ndim) {
int ax = axis < 0 ? axis + ndim : axis;
if (ax < 0 || ax >= ndim) {
std::ostringstream msg;
msg << "[categorical] Invalid axis " << axis << " for logits with " << ndim
<< " dimensions.";
throw std::invalid_argument(msg.str());
}
return ax;
}
array categorical_impl(
const array& logits,
int axis,
const std::vector<int>& shape,
const std::optional<array>& key /*= nullopt */,
StreamOrDevice s) {
auto gumbel_shape = shape;
auto offset = axis + shape.size() - logits.ndim() + 1;
gumbel_shape.insert(gumbel_shape.begin() + offset, logits.shape(axis));
auto g = gumbel(gumbel_shape, float32, key, s);
return argmax(add(g, logits, s), offset, false, s);
}
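// Shape bookkeeping sketch (shapes assumed): logits of shape (5, 10) sampled
// over axis = 1 with a requested output shape of (3, 5):
//   offset       = 1 + 2 - 2 + 1 = 2
//   gumbel_shape = (3, 5, 10)
// Gumbel noise is added over a trailing copy of the class axis and the
// argmax over that axis (offset == 2) yields the (3, 5) samples.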
array categorical(
const array& logits,
int axis,
const std::vector<int>& shape,
const std::optional<array>& key /*= nullopt */,
StreamOrDevice s /* = {} */) {
// Validate and normalize axis
axis = get_valid_axis(axis, logits.ndim());
// Check that shape broadcasts with reduce(logits, axis)
auto reduced_shape = logits.shape();
reduced_shape.erase(reduced_shape.begin() + axis);
if (broadcast_shapes(shape, reduced_shape) != shape) {
std::ostringstream msg;
msg << "[categorical] Requested shape " << shape
<< " is not broadcast compatable with reduced logits shape"
<< reduced_shape << ".";
throw std::invalid_argument(msg.str());
}
return categorical_impl(logits, axis, shape, key, s);
}
array categorical(
const array& logits_,
int axis,
int num_samples,
const std::optional<array>& key /*= nullopt */,
StreamOrDevice s /* = {} */) {
axis = get_valid_axis(axis, logits_.ndim());
auto logits = expand_dims(logits_, -1);
auto shape = logits.shape();
shape.erase(shape.begin() + axis);
shape.back() = num_samples;
return categorical_impl(logits, axis, shape, key, s);
}
array categorical(
const array& logits,
int axis /* = -1 */,
const std::optional<array>& key /*= nullopt */,
StreamOrDevice s /* = {} */) {
axis = get_valid_axis(axis, logits.ndim());
auto shape = logits.shape();
shape.erase(shape.begin() + axis);
return categorical_impl(logits, axis, shape, key, s);
}
} // namespace mlx::core::random

43
mlx/scheduler.cpp Normal file
View File

@@ -0,0 +1,43 @@
#include "mlx/scheduler.h"
#include "mlx/backend/metal/metal.h"
namespace mlx::core {
Stream default_stream(Device d) {
if (!metal::is_available() && d == Device::gpu) {
throw std::invalid_argument(
"[default_stream] Cannot get gpu stream without gpu backend.");
}
return scheduler::scheduler().get_default_stream(d);
}
void set_default_stream(Stream s) {
if (!metal::is_available() && s.device == Device::gpu) {
throw std::invalid_argument(
"[set_default_stream] Cannot set gpu stream without gpu backend.");
}
return scheduler::scheduler().set_default_stream(s);
}
Stream new_stream(Device d) {
if (!metal::is_available() && d == Device::gpu) {
throw std::invalid_argument(
"[new_stream] Cannot make gpu stream without gpu backend.");
}
return scheduler::scheduler().new_stream(d);
}
Stream new_stream() {
return scheduler::scheduler().new_stream(default_device());
}
namespace scheduler {
/** A singleton scheduler to manage devices, streams, and task execution. */
Scheduler& scheduler() {
static Scheduler scheduler;
return scheduler;
}
} // namespace scheduler
} // namespace mlx::core

170
mlx/scheduler.h Normal file
View File

@@ -0,0 +1,170 @@
#pragma once
#include <atomic>
#include <condition_variable>
#include <functional>
#include <future>
#include <mutex>
#include <queue>
#include <thread>
#include <unordered_map>
#include "mlx/backend/metal/metal.h"
#include "mlx/device.h"
#include "mlx/stream.h"
namespace mlx::core::scheduler {
struct StreamThread {
std::mutex mtx;
std::queue<std::function<void()>> q;
std::condition_variable cond;
bool stop;
Stream stream;
std::thread thread;
StreamThread(Stream stream)
: stop(false), stream(stream), thread(&StreamThread::thread_fn, this) {}
~StreamThread() {
{
std::unique_lock<std::mutex> lk(mtx);
stop = true;
}
cond.notify_one();
thread.join();
}
void thread_fn() {
metal::new_stream(stream);
while (true) {
std::function<void()> task;
{
std::unique_lock<std::mutex> lk(mtx);
cond.wait(lk, [this] { return !this->q.empty() || this->stop; });
if (q.empty() && stop) {
return;
}
task = std::move(q.front());
q.pop();
}
task();
}
}
template <typename F>
void enqueue(F&& f) {
{
std::unique_lock<std::mutex> lk(mtx);
if (stop) {
throw std::runtime_error(
"Cannot enqueue work after stream is stopped.");
}
q.emplace(std::forward<F>(f));
}
cond.notify_one();
}
};
class Scheduler {
public:
Scheduler() : n_active_tasks_(0) {
if (metal::is_available()) {
default_streams_.insert({Device::gpu, new_stream(Device::gpu)});
}
default_streams_.insert({Device::cpu, new_stream(Device::cpu)});
}
// Not copyable or moveable
Scheduler(const Scheduler&) = delete;
Scheduler(Scheduler&&) = delete;
Scheduler& operator=(const Scheduler&) = delete;
Scheduler& operator=(Scheduler&&) = delete;
Stream new_stream(const Device& d) {
auto stream = Stream(streams_.size(), d);
streams_.push_back(new StreamThread{stream});
return stream;
}
template <typename F>
void enqueue(const Stream& stream, F&& f);
Stream get_default_stream(const Device& d) {
return default_streams_.at(d.type);
}
void set_default_stream(const Stream& s) {
default_streams_.at(s.device.type) = s;
}
void notify_new_task(const Stream& stream) {
{
std::unique_lock<std::mutex> lk(mtx);
n_active_tasks_++;
}
completion_cv.notify_all();
}
void notify_task_completion(const Stream& stream) {
{
std::unique_lock<std::mutex> lk(mtx);
n_active_tasks_--;
}
completion_cv.notify_all();
}
int n_active_tasks() const {
return n_active_tasks_;
}
void wait_for_one() {
std::unique_lock<std::mutex> lk(mtx);
int n_tasks_old = n_active_tasks();
if (n_tasks_old > 1) {
completion_cv.wait(lk, [this, n_tasks_old] {
return this->n_active_tasks() != n_tasks_old;
});
}
}
~Scheduler() {
for (auto s : streams_) {
delete s;
}
}
private:
int n_active_tasks_;
std::vector<StreamThread*> streams_;
std::unordered_map<Device::DeviceType, Stream> default_streams_;
std::condition_variable completion_cv;
std::mutex mtx;
};
template <typename F>
void Scheduler::enqueue(const Stream& stream, F&& f) {
streams_[stream.index]->enqueue(std::forward<F>(f));
}
Scheduler& scheduler();
template <typename F>
void enqueue(const Stream& stream, F&& f) {
scheduler().enqueue(stream, std::forward<F>(f));
}
inline int n_active_tasks() {
return scheduler().n_active_tasks();
}
inline void notify_new_task(const Stream& stream) {
scheduler().notify_new_task(stream);
}
inline void notify_task_completion(const Stream& stream) {
scheduler().notify_task_completion(stream);
}
inline void wait_for_one() {
scheduler().wait_for_one();
}
} // namespace mlx::core::scheduler
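An illustrative sketch of running a task through the scheduler above (not part of the diff). The free functions wrap the singleton, and the active-task counter they update is what wait_for_one() observes.
#include "mlx/scheduler.h"
namespace mlx::core {
void scheduler_sketch() {
  // A dedicated CPU stream; its StreamThread runs enqueued closures in FIFO
  // order on a worker thread.
  Stream s = new_stream(Device::cpu);
  // Count the task as active before queueing it, and mark it complete at the
  // end of the closure so waiters can make progress.
  scheduler::notify_new_task(s);
  scheduler::enqueue(s, [s]() {
    // ... work bound to this stream would go here ...
    scheduler::notify_task_completion(s);
  });
}
} // namespace mlx::core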

30
mlx/stream.h Normal file
View File

@@ -0,0 +1,30 @@
#pragma once
#include "mlx/device.h"
namespace mlx::core {
struct Stream {
int index;
Device device;
explicit Stream(int index, Device device) : index(index), device(device) {}
};
/** Get the default stream for the given device. */
Stream default_stream(Device d);
/** Make the stream the default for its device. */
void set_default_stream(Stream s);
/** Make a new stream on the given device. */
Stream new_stream(Device d);
inline bool operator==(const Stream& lhs, const Stream& rhs) {
return lhs.index == rhs.index;
}
inline bool operator!=(const Stream& lhs, const Stream& rhs) {
return !(lhs == rhs);
}
} // namespace mlx::core

54
mlx/types/half_types.h Normal file
View File

@@ -0,0 +1,54 @@
#pragma once
#ifdef __ARM_FEATURE_FP16_SCALAR_ARITHMETIC
#include <arm_fp16.h>
namespace mlx::core {
typedef __fp16 float16_t;
} // namespace mlx::core
#else
#define ADD_HALF_BINOPS
#include "mlx/types/fp16.h"
namespace mlx::core {
typedef struct _MLX_Float16 float16_t;
} // namespace mlx::core
#endif // __ARM_FEATURE_FP16_SCALAR_ARITHMETIC
#ifdef __ARM_FEATURE_BF16
#include <arm_bf16.h>
namespace mlx::core {
typedef __bf16 bfloat16_t;
} // namespace mlx::core
#else
#define ADD_HALF_BINOPS
#include "mlx/types/bf16.h"
namespace mlx::core {
typedef struct _MLX_BFloat16 bfloat16_t;
} // namespace mlx::core
#endif // __ARM_FEATURE_BF16
#ifdef ADD_HALF_BINOPS
namespace mlx::core {
// clang-format off
#define fp16_bf16_binop_helper(__op__, __operator__) \
inline float __operator__(float16_t lhs, bfloat16_t rhs) { \
return static_cast<float>(lhs) __op__ static_cast<float>(rhs); \
} \
inline float __operator__(bfloat16_t lhs, float16_t rhs) { \
return static_cast<float>(lhs) __op__ static_cast<float>(rhs); \
}
fp16_bf16_binop_helper(+, operator+)
fp16_bf16_binop_helper(-, operator-)
fp16_bf16_binop_helper(*, operator*)
fp16_bf16_binop_helper(/, operator/)
// clang-format on
} // namespace mlx::core
#endif
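A small sketch of the mixed fp16/bf16 helpers above (not part of the diff). It assumes the software-emulated path (ADD_HALF_BINOPS defined) and that the emulated types are constructible from float, as their headers suggest; mixed operands then promote to float.
#include "mlx/types/half_types.h"
using namespace mlx::core;
float mixed_half_sketch() {
  // Assumed converting constructors from float on float16_t / bfloat16_t.
  float16_t a = 1.5f;
  bfloat16_t b = 2.0f;
  // fp16_bf16_binop_helper casts both sides to float, so each mixed
  // operation below yields a float.
  return a * b + (b - a);
}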

255
mlx/utils.cpp Normal file
View File

@@ -0,0 +1,255 @@
#include <algorithm>
#include <cstdlib>
#include <sstream>
#include <vector>
#include "utils.h"
namespace mlx::core {
Dtype result_type(const std::vector<array>& arrays) {
std::vector<Dtype> dtypes(1, bool_);
for (auto& arr : arrays) {
dtypes.push_back(promote_types(dtypes.back(), arr.dtype()));
}
return dtypes.back();
}
std::vector<int> broadcast_shapes(
const std::vector<int>& s1,
const std::vector<int>& s2) {
// Use the same broadcasting rules as numpy
// https://numpy.org/doc/1.20/user/theory.broadcasting.html
// "The size of the trailing axes for both arrays in an operation must
// either be the same size or one of them must be one."
int ndim1 = s1.size();
int ndim2 = s2.size();
int ndim = std::max(ndim1, ndim2);
int diff = std::abs(ndim1 - ndim2);
const auto& big = ndim1 > ndim2 ? s1 : s2;
const auto& small = ndim1 > ndim2 ? s2 : s1;
std::vector<int> out_shape(ndim);
for (int i = ndim - 1; i >= diff; --i) {
int a = big[i];
int b = small[i - diff];
if (b == a) {
out_shape[i] = a;
} else if (a == 1 || b == 1) {
// 0 if a or b is 0 otherwise max(a, b)
out_shape[i] = a * b;
} else {
std::ostringstream msg;
msg << "Shapes " << s1 << " and " << s2 << " cannot be broadcast.";
throw std::invalid_argument(msg.str());
}
}
for (int i = diff - 1; i >= 0; --i) {
out_shape[i] = big[i];
}
return out_shape;
}
std::ostream& operator<<(std::ostream& os, const Device& d) {
os << "Device(";
switch (d.type) {
case Device::cpu:
os << "cpu";
break;
case Device::gpu:
os << "gpu";
break;
}
os << ", " << d.index << ")";
return os;
}
std::ostream& operator<<(std::ostream& os, const Stream& s) {
os << "Stream(";
os << s.device;
os << ", " << s.index << ")";
return os;
}
std::ostream& operator<<(std::ostream& os, int8_t x) {
os << static_cast<int>(x);
return os;
}
std::ostream& operator<<(std::ostream& os, uint8_t x) {
os << static_cast<unsigned int>(x);
return os;
}
namespace {
inline size_t elem_to_loc(
int elem,
const std::vector<int>& shape,
const std::vector<size_t>& strides) {
size_t loc = 0;
for (int i = shape.size() - 1; i >= 0; --i) {
auto q_and_r = ldiv(elem, shape[i]);
loc += q_and_r.rem * strides[i];
elem = q_and_r.quot;
}
return loc;
}
template <typename T>
void print_subarray(std::ostream& os, const array& a, size_t index, int dim) {
int num_print = 3;
int n = a.shape(dim);
size_t s = a.strides()[dim];
bool is_last = dim == a.ndim() - 1;
auto prefix = is_last ? "" : std::string(7 + dim, ' ');
auto postfix = is_last ? ", " : ",\n";
os << "[";
for (int i = 0; i < n; ++i) {
os << (i == 0 ? "" : prefix);
if (i == num_print && n > 2 * num_print) {
os << "...";
i = n - num_print - 1;
index += s * (n - 2 * num_print - 1);
} else if (is_last) {
os << a.data<T>()[index];
} else {
print_subarray<T>(os, a, index, dim + 1);
}
os << (i == n - 1 ? "" : postfix);
index += s;
}
os << "]";
}
template <typename T>
void print_array(std::ostream& os, const array& a) {
std::vector<int> indices(a.ndim(), 0);
os << std::boolalpha;
os << "array(";
if (a.ndim() == 0) {
auto data = a.data<T>();
os << data[0];
} else {
print_subarray<T>(os, a, 0, 0);
}
os << ", dtype=" << a.dtype() << ")";
os << std::noboolalpha;
}
} // namespace
std::ostream& operator<<(std::ostream& os, const Dtype& dtype) {
switch (dtype) {
case bool_:
return os << "bool";
case uint8:
return os << "uint8";
case uint16:
return os << "uint16";
case uint32:
return os << "uint32";
case uint64:
return os << "uint64";
case int8:
return os << "int8";
case int16:
return os << "int16";
case int32:
return os << "int32";
case int64:
return os << "int64";
case float16:
return os << "float16";
case float32:
return os << "float32";
case bfloat16:
return os << "bfloat16";
case complex64:
return os << "complex64";
}
return os;
}
std::ostream& operator<<(std::ostream& os, const Dtype::Kind& k) {
switch (k) {
case Dtype::Kind::b:
return os << "b";
case Dtype::Kind::i:
return os << "i";
case Dtype::Kind::u:
return os << "u";
case Dtype::Kind::f:
return os << "f";
case Dtype::Kind::c:
return os << "c";
case Dtype::Kind::V:
return os << "V";
}
return os;
}
std::ostream& operator<<(std::ostream& os, array a) {
if (!a.is_evaled()) {
a.eval();
}
switch (a.dtype()) {
case bool_:
print_array<bool>(os, a);
break;
case uint8:
print_array<uint8_t>(os, a);
break;
case uint16:
print_array<uint16_t>(os, a);
break;
case uint32:
print_array<uint32_t>(os, a);
break;
case uint64:
print_array<uint64_t>(os, a);
break;
case int8:
print_array<int8_t>(os, a);
break;
case int16:
print_array<int16_t>(os, a);
break;
case int32:
print_array<int32_t>(os, a);
break;
case int64:
print_array<int64_t>(os, a);
break;
case float16:
print_array<float16_t>(os, a);
break;
case bfloat16:
print_array<bfloat16_t>(os, a);
break;
case float32:
print_array<float>(os, a);
break;
case complex64:
print_array<complex64_t>(os, a);
break;
}
return os;
}
std::ostream& operator<<(std::ostream& os, const std::vector<int>& v) {
os << "(";
for (int i = 0; i < v.size(); ++i) {
os << v[i] << ((i == v.size() - 1) ? "" : ",");
}
os << ")";
return os;
}
std::ostream& operator<<(std::ostream& os, const std::vector<size_t>& v) {
os << "(";
for (int i = 0; i < v.size(); ++i) {
os << v[i] << ((i == v.size() - 1) ? "" : ",");
}
os << ")";
return os;
}
} // namespace mlx::core
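A worked sketch of broadcast_shapes above (not part of the diff), using the standard NumPy examples the comment refers to.
#include <cassert>
#include <vector>
#include "mlx/utils.h"
using namespace mlx::core;
void broadcast_sketch() {
  // Trailing axes must match or be 1; leading axes come from the higher-rank
  // shape: (8,1,6,1) with (7,1,5) -> (8,7,6,5).
  assert((broadcast_shapes({8, 1, 6, 1}, {7, 1, 5}) ==
          std::vector<int>{8, 7, 6, 5}));
  // When one axis is 1 and the other is 0, the a * b branch keeps the 0.
  assert((broadcast_shapes({0, 4}, {1, 4}) == std::vector<int>{0, 4}));
  // Incompatible trailing axes (e.g. 3 vs. 4) throw std::invalid_argument.
}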

33
mlx/utils.h Normal file
View File

@@ -0,0 +1,33 @@
#pragma once
#include "array.h"
#include "device.h"
#include "dtype.h"
#include "stream.h"
namespace mlx::core {
/** The type from promoting the arrays' types with one another. */
Dtype result_type(const std::vector<array>& arrays);
std::vector<int> broadcast_shapes(
const std::vector<int>& s1,
const std::vector<int>& s2);
std::ostream& operator<<(std::ostream& os, const Device& d);
std::ostream& operator<<(std::ostream& os, const Stream& s);
std::ostream& operator<<(std::ostream& os, const Dtype& d);
std::ostream& operator<<(std::ostream& os, const Dtype::Kind& k);
std::ostream& operator<<(std::ostream& os, array a);
std::ostream& operator<<(std::ostream& os, const std::vector<int>& v);
std::ostream& operator<<(std::ostream& os, const std::vector<size_t>& v);
inline std::ostream& operator<<(std::ostream& os, const complex64_t& v) {
return os << v.real() << (v.imag() > 0 ? "+" : "") << v.imag() << "j";
}
inline std::ostream& operator<<(std::ostream& os, const float16_t& v) {
return os << static_cast<float>(v);
}
inline std::ostream& operator<<(std::ostream& os, const bfloat16_t& v) {
return os << static_cast<float>(v);
}
} // namespace mlx::core
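An illustrative sketch of the stream-insertion helpers declared above (not part of the diff); it assumes the array accessors (shape(), dtype()) used elsewhere in this commit.
#include <iostream>
#include "mlx/utils.h"
using namespace mlx::core;
void print_sketch(const array& a) {
  // Shapes print as "(d0,d1,...)", dtypes by name, streams with their device
  // and index.
  std::cout << a.shape() << " " << a.dtype() << "\n";
  std::cout << default_stream(Device::cpu) << "\n";
  // Printing an array evaluates it first (operator<< calls eval() if needed).
  std::cout << a << "\n";
}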