Use int64 stride everywhere (#1671)
* use int64 stride everywhere
* fix ext
* fix ext
* more shape + cleanup
* one more
* few more

Parent: 35b412c099
Commit: 40c62c1321
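The heart of the change is in the array.h hunk below: the Strides alias moves from std::vector<size_t> to std::vector<int64_t>, and backend signatures that took std::vector<size_t> or std::vector<int> for strides and shapes are rewritten against the shared Shape/Strides aliases. A minimal standalone sketch (not MLX code) of why the signedness matters: a reversing slice produces a negative step, which an unsigned stride type silently wraps.

#include <cstdint>
#include <cstdio>
#include <vector>

int main() {
  int64_t step = -1;                     // reversing slice step
  std::vector<int64_t> strides = {step}; // signed Strides: stays -1
  size_t unsigned_stride = static_cast<size_t>(step); // size_t: wraps to 2^64 - 1
  std::printf("signed: %lld, unsigned: %zu\n",
              static_cast<long long>(strides[0]), unsigned_stride);
  return 0;
}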
@@ -420,8 +420,8 @@ element in the output.
 constant const float& alpha [[buffer(3)]],
 constant const float& beta [[buffer(4)]],
 constant const int* shape [[buffer(5)]],
-constant const size_t* x_strides [[buffer(6)]],
-constant const size_t* y_strides [[buffer(7)]],
+constant const int64_t* x_strides [[buffer(6)]],
+constant const int64_t* y_strides [[buffer(7)]],
 constant const int& ndim [[buffer(8)]],
 uint index [[thread_position_in_grid]]) {
 // Convert linear indices to offsets in array

@@ -438,24 +438,10 @@ each instantiation a unique host name so we can identify it.
 
 .. code-block:: C++
 
-#define instantiate_axpby(type_name, type) \
-template [[host_name("axpby_general_" #type_name)]] \
-[[kernel]] void axpby_general<type>( \
-device const type* x [[buffer(0)]], \
-device const type* y [[buffer(1)]], \
-device type* out [[buffer(2)]], \
-constant const float& alpha [[buffer(3)]], \
-constant const float& beta [[buffer(4)]], \
-constant const int* shape [[buffer(5)]], \
-constant const size_t* x_strides [[buffer(6)]], \
-constant const size_t* y_strides [[buffer(7)]], \
-constant const int& ndim [[buffer(8)]], \
-uint index [[thread_position_in_grid]]);
-
-instantiate_axpby(float32, float);
-instantiate_axpby(float16, half);
-instantiate_axpby(bfloat16, bfloat16_t);
-instantiate_axpby(complex64, complex64_t);
+instantiate_kernel("axpby_general_float32", axpby_general, float)
+instantiate_kernel("axpby_general_float16", axpby_general, float16_t)
+instantiate_kernel("axpby_general_bfloat16", axpby_general, bfloat16_t)
+instantiate_kernel("axpby_general_complex64", axpby_general, complex64_t)
 
 The logic to determine the kernel, set the inputs, resolve the grid dimensions,
 and dispatch to the GPU are contained in :meth:`Axpby::eval_gpu` as shown

@@ -12,8 +12,8 @@ template <typename T>
 constant const float& alpha [[buffer(3)]],
 constant const float& beta [[buffer(4)]],
 constant const int* shape [[buffer(5)]],
-constant const size_t* x_strides [[buffer(6)]],
-constant const size_t* y_strides [[buffer(7)]],
+constant const int64_t* x_strides [[buffer(6)]],
+constant const int64_t* y_strides [[buffer(7)]],
 constant const int& ndim [[buffer(8)]],
 uint index [[thread_position_in_grid]]) {
 auto x_offset = elem_to_loc(index, shape, x_strides, ndim);

@@ -34,29 +34,14 @@ template <typename T>
 static_cast<T>(alpha) * x[index] + static_cast<T>(beta) * y[index];
 }
 
-#define instantiate_axpby(type_name, type) \
-template [[host_name("axpby_general_" #type_name)]] [[kernel]] void \
-axpby_general<type>( \
-device const type* x [[buffer(0)]], \
-device const type* y [[buffer(1)]], \
-device type* out [[buffer(2)]], \
-constant const float& alpha [[buffer(3)]], \
-constant const float& beta [[buffer(4)]], \
-constant const int* shape [[buffer(5)]], \
-constant const size_t* x_strides [[buffer(6)]], \
-constant const size_t* y_strides [[buffer(7)]], \
-constant const int& ndim [[buffer(8)]], \
-uint index [[thread_position_in_grid]]); \
-template [[host_name("axpby_contiguous_" #type_name)]] [[kernel]] void \
-axpby_contiguous<type>( \
-device const type* x [[buffer(0)]], \
-device const type* y [[buffer(1)]], \
-device type* out [[buffer(2)]], \
-constant const float& alpha [[buffer(3)]], \
-constant const float& beta [[buffer(4)]], \
-uint index [[thread_position_in_grid]]);
+// clang-format off
+#define instantiate_axpby(type_name, type) \
+instantiate_kernel("axpby_general_" #type_name, axpby_general, type) \
+instantiate_kernel( \
+"axpby_contiguous_" #type_name, axpby_contiguous, type)
 
 instantiate_axpby(float32, float);
 instantiate_axpby(float16, half);
 instantiate_axpby(bfloat16, bfloat16_t);
 instantiate_axpby(complex64, complex64_t);
+// clang-format on

@@ -18,7 +18,7 @@ class Primitive;
 
 using Deleter = std::function<void(allocator::Buffer)>;
 using Shape = std::vector<int32_t>;
-using Strides = std::vector<size_t>;
+using Strides = std::vector<int64_t>;
 
 class array {
 /* An array is really a node in a graph. It contains a shared ArrayDesc
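With the aliases above (Shape is std::vector<int32_t>, Strides is std::vector<int64_t>), row-major strides follow directly from a shape. A standalone sketch of the computation that the make_contiguous_strides helper in the utils.h hunk further down performs:

#include <cstdint>
#include <vector>

using Shape = std::vector<int32_t>;
using Strides = std::vector<int64_t>;

// Innermost axis has stride 1; each outer stride is the product of the
// inner extents. E.g. shape {2, 3, 4} -> strides {12, 4, 1}.
Strides make_contiguous_strides(const Shape& shape) {
  Strides strides(shape.size(), 1);
  for (int i = shape.size() - 1; i > 0; i--) {
    strides[i - 1] = strides[i] * shape[i];
  }
  return strides;
}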
@@ -13,8 +13,8 @@ template <typename InT, typename OpT>
 void arg_reduce(const array& in, array& out, const OpT& op, int axis) {
 auto axis_size = in.shape()[axis];
 auto axis_stride = in.strides()[axis];
-std::vector<size_t> strides = in.strides();
-std::vector<int> shape = in.shape();
+Strides strides = in.strides();
+Shape shape = in.shape();
 strides.erase(strides.begin() + axis);
 shape.erase(shape.begin() + axis);
 for (uint32_t i = 0; i < out.size(); ++i) {

@@ -178,10 +178,10 @@ void binary_op_dims(
 const T* b,
 U* out,
 Op op,
-const std::vector<int>& shape,
-const std::vector<size_t>& a_strides,
-const std::vector<size_t>& b_strides,
-const std::vector<size_t>& out_strides,
+const Shape& shape,
+const Strides& a_strides,
+const Strides& b_strides,
+const Strides& out_strides,
 int axis) {
 auto stride_a = a_strides[axis];
 auto stride_b = b_strides[axis];

@@ -212,10 +212,10 @@ void binary_op_dispatch_dims(
 array& out,
 Op op,
 int dim,
-const std::vector<int>& shape,
-const std::vector<size_t>& a_strides,
-const std::vector<size_t>& b_strides,
-const std::vector<size_t>& out_strides) {
+const Shape& shape,
+const Strides& a_strides,
+const Strides& b_strides,
+const Strides& out_strides) {
 const T* a_ptr = a.data<T>();
 const T* b_ptr = b.data<T>();
 U* out_ptr = out.data<U>();

@@ -258,10 +258,10 @@ void binary_op_dispatch_dims(
 return;
 }
 
-ContiguousIterator<size_t> a_it(shape, a_strides, dim - 3);
-ContiguousIterator<size_t> b_it(shape, b_strides, dim - 3);
-size_t stride = out_strides[dim - 4];
-for (size_t elem = 0; elem < a.size(); elem += stride) {
+ContiguousIterator a_it(shape, a_strides, dim - 3);
+ContiguousIterator b_it(shape, b_strides, dim - 3);
+auto stride = out_strides[dim - 4];
+for (int64_t elem = 0; elem < a.size(); elem += stride) {
 binary_op_dims<T, U, Op, 3, Strided>(
 a_ptr + a_it.loc,
 b_ptr + b_it.loc,

@@ -327,7 +327,7 @@ void binary_op(
 const auto& strides = new_strides[2];
 
 // Get the left-most dim such that the array is row contiguous after
-auto leftmost_rc_dim = [&strides](const std::vector<size_t>& arr_strides) {
+auto leftmost_rc_dim = [&strides](const auto& arr_strides) {
 int d = arr_strides.size() - 1;
 for (; d >= 0 && arr_strides[d] == strides[d]; d--) {
 }

@@ -337,7 +337,7 @@ void binary_op(
 auto b_rc_dim = leftmost_rc_dim(b_strides);
 
 // Get the left-most dim such that the array is a broadcasted "scalar" after
-auto leftmost_s_dim = [](const std::vector<size_t>& arr_strides) {
+auto leftmost_s_dim = [](const auto& arr_strides) {
 int d = arr_strides.size() - 1;
 for (; d >= 0 && arr_strides[d] == 0; d--) {
 }

@@ -16,10 +16,10 @@ void binary_op_dims(
 U* out_a,
 U* out_b,
 Op op,
-const std::vector<int>& shape,
-const std::vector<size_t>& a_strides,
-const std::vector<size_t>& b_strides,
-const std::vector<size_t>& out_strides,
+const Shape& shape,
+const Strides& a_strides,
+const Strides& b_strides,
+const Strides& out_strides,
 int axis) {
 auto stride_a = a_strides[axis];
 auto stride_b = b_strides[axis];

@@ -96,9 +96,9 @@ void binary_op_dispatch_dims(
 return;
 }
 
-ContiguousIterator<size_t> a_it(shape, a_strides, ndim - 2);
-ContiguousIterator<size_t> b_it(shape, b_strides, ndim - 2);
-size_t stride = out_strides[ndim - 3];
+ContiguousIterator a_it(shape, a_strides, ndim - 2);
+ContiguousIterator b_it(shape, b_strides, ndim - 2);
+auto stride = out_strides[ndim - 3];
 for (size_t elem = 0; elem < a.size(); elem += stride) {
 binary_op_dims<T, U, Op, 2>(
 a_ptr + a_it.loc,

@@ -49,7 +49,7 @@ void Broadcast::eval(const std::vector<array>& inputs, array& out) {
 out.set_data(nullptr);
 return;
 }
-std::vector<size_t> strides(out.ndim(), 0);
+Strides strides(out.ndim(), 0);
 int diff = out.ndim() - in.ndim();
 for (int i = in.ndim() - 1; i >= 0; --i) {
 strides[i + diff] = (in.shape()[i] == 1) ? 0 : in.strides()[i];

@@ -141,7 +141,7 @@ void NumberOfElements::eval(const std::vector<array>& inputs, array& out) {
 }
 }
 
-std::pair<bool, std::vector<size_t>> Reshape::prepare_reshape(
+std::pair<bool, Strides> Reshape::prepare_reshape(
 const array& in,
 const array& out) {
 // Special case for empty arrays or row contiguous arrays

@@ -151,8 +151,7 @@ std::pair<bool, std::vector<size_t>> Reshape::prepare_reshape(
 
 // Special case for scalars
 if (in.ndim() == 0) {
-std::vector<size_t> out_strides(out.ndim(), 0);
-return {false, out_strides};
+return {false, Strides(out.ndim(), 0)};
 }
 
 // Firstly let's collapse all the contiguous dimensions of the input

@@ -160,7 +159,7 @@ std::pair<bool, std::vector<size_t>> Reshape::prepare_reshape(
 
 // If shapes fit exactly in the contiguous dims then no copy is necessary so
 // let's check.
-std::vector<size_t> out_strides;
+Strides out_strides;
 bool copy_necessary = false;
 int j = 0;
 for (int i = 0; i < out.ndim(); i++) {

@@ -183,7 +182,7 @@ std::pair<bool, std::vector<size_t>> Reshape::prepare_reshape(
 
 void Reshape::shared_buffer_reshape(
 const array& in,
-const std::vector<size_t>& out_strides,
+const Strides& out_strides,
 array& out) {
 auto flags = in.flags();
 if (flags.row_contiguous) {

@@ -249,18 +248,6 @@ void Split::eval(
 }
 }
 
-std::tuple<int64_t, std::vector<int64_t>> SliceUpdate::prepare_slice(
-const array& in) {
-int64_t data_offset = 0;
-std::vector<int64_t> inp_strides(in.ndim(), 0);
-for (int i = 0; i < in.ndim(); ++i) {
-data_offset += start_indices_[i] * in.strides()[i];
-inp_strides[i] = in.strides()[i] * strides_[i];
-}
-
-return std::make_tuple(data_offset, inp_strides);
-}
-
 void StopGradient::eval(const std::vector<array>& inputs, array& out) {
 assert(inputs.size() == 1);
 move_or_copy(inputs[0], out);

@@ -268,7 +255,7 @@ void StopGradient::eval(const std::vector<array>& inputs, array& out) {
 
 void Transpose::eval(const std::vector<array>& inputs, array& out) {
 assert(inputs.size() == 1);
-std::vector<size_t> out_strides(out.ndim());
+Strides out_strides(out.ndim());
 auto& in = inputs[0];
 for (int ax = 0; ax < axes_.size(); ++ax) {
 out_strides[ax] = in.strides()[axes_[ax]];

@@ -285,8 +272,8 @@ void Transpose::eval(const std::vector<array>& inputs, array& out) {
 // true, they stay true)
 auto flags = in.flags();
 if (flags.contiguous && in.data_size() == in.size()) {
-size_t f_stride = 1;
-size_t b_stride = 1;
+int64_t f_stride = 1;
+int64_t b_stride = 1;
 flags.col_contiguous = true;
 flags.row_contiguous = true;
 for (int i = 0, ri = out.ndim() - 1; i < out.ndim(); ++i, --ri) {

@@ -165,7 +165,7 @@ void compiled_allocate_outputs(
 bool move_buffers /* = false */) {
 if (contiguous) {
 int o = 0;
-std::vector<size_t> strides;
+Strides strides;
 size_t data_size;
 array::Flags flags;
 for (int i = 0; i < inputs.size() && o < outputs.size(); ++i) {

@@ -746,9 +746,9 @@ void explicit_gemm_conv_1D_cpu(
 copy_inplace(in, in_padded_slice, CopyType::GeneralGeneral);
 
 // Make strided view
-std::vector<int> strided_shape = {N, oH, wH, C};
+Shape strided_shape = {N, oH, wH, C};
 
-std::vector<size_t> strided_strides = {
+Strides strided_strides = {
 in_padded.strides()[0],
 in_padded.strides()[1] * wt_strides[0],
 in_padded.strides()[1],

@@ -865,9 +865,9 @@ void explicit_gemm_conv_2D_cpu(
 copy_inplace(in, in_padded_slice, CopyType::GeneralGeneral);
 
 // Make strided view
-std::vector<int> strided_shape = {N, oH, oW, wH, wW, C};
+Shape strided_shape = {N, oH, oW, wH, wW, C};
 
-std::vector<size_t> strided_strides = {
+Strides strided_strides = {
 in_padded.strides()[0],
 in_padded.strides()[1] * wt_strides[0],
 in_padded.strides()[2] * wt_strides[1],

@@ -974,7 +974,7 @@ void explicit_gemm_conv_ND_cpu(
 copy_inplace(in, in_padded_slice, CopyType::GeneralGeneral);
 
 // Make strided view
-std::vector<int> strided_shape(oDim.size() + wDim.size() + 2);
+Shape strided_shape(oDim.size() + wDim.size() + 2);
 strided_shape.front() = N;
 for (size_t i = 0; i < oDim.size(); i++) {
 strided_shape[i + 1] = oDim[i];

@@ -984,7 +984,7 @@ void explicit_gemm_conv_ND_cpu(
 }
 strided_shape.back() = C;
 
-std::vector<size_t> strided_strides(in.shape().size() * 2 - 2);
+Strides strided_strides(in.shape().size() * 2 - 2);
 strided_strides[0] = in_padded.strides()[0];
 for (size_t i = 0; i < wt_strides.size(); i++) {
 strided_strides[i + 1] = in_padded.strides()[i + 1] * wt_strides[i];

@@ -1000,7 +1000,7 @@ void explicit_gemm_conv_ND_cpu(
 in_padded, strided_strides, flags, in_strided_view.size(), 0);
 
 // Materialize strided view
-std::vector<int> strided_reshape = {N, C};
+Shape strided_reshape = {N, C};
 for (const auto& o : oDim) {
 strided_reshape[0] *= o;
 }

@@ -26,13 +26,13 @@ void copy_vector(const array& src, array& dst) {
 std::copy(src_ptr, src_ptr + src.data_size(), dst_ptr);
 }
 
-template <typename SrcT, typename DstT, typename StrideT, int D>
+template <typename SrcT, typename DstT, int D>
 inline void copy_dims(
 const SrcT* src,
 DstT* dst,
-const std::vector<int>& shape,
-const std::vector<StrideT>& i_strides,
-const std::vector<StrideT>& o_strides,
+const Shape& shape,
+const Strides& i_strides,
+const Strides& o_strides,
 int axis) {
 auto stride_src = i_strides[axis];
 auto stride_dst = o_strides[axis];

@@ -40,7 +40,7 @@ inline void copy_dims(
 
 for (int i = 0; i < N; i++) {
 if constexpr (D > 1) {
-copy_dims<SrcT, DstT, StrideT, D - 1>(
+copy_dims<SrcT, DstT, D - 1>(
 src, dst, shape, i_strides, o_strides, axis + 1);
 } else {
 *dst = static_cast<DstT>(*src);

@@ -50,13 +50,13 @@ inline void copy_dims(
 }
 }
 
-template <typename SrcT, typename DstT, typename StrideT>
+template <typename SrcT, typename DstT>
 void copy_general_general(
 const array& src,
 array& dst,
-const std::vector<int>& data_shape,
-const std::vector<StrideT>& i_strides,
-const std::vector<StrideT>& o_strides,
+const Shape& data_shape,
+const Strides& i_strides,
+const Strides& o_strides,
 int64_t i_offset,
 int64_t o_offset) {
 if (data_shape.empty()) {

@@ -65,30 +65,30 @@ void copy_general_general(
 *dst_ptr = val;
 return;
 }
-auto [shape, strides] = collapse_contiguous_dims(
-data_shape, std::vector<std::vector<StrideT>>{i_strides, o_strides});
+auto [shape, strides] =
+collapse_contiguous_dims(data_shape, {i_strides, o_strides});
 auto src_ptr = src.data<SrcT>() + i_offset;
 auto dst_ptr = dst.data<DstT>() + o_offset;
 int ndim = shape.size();
 if (ndim == 1) {
-copy_dims<SrcT, DstT, StrideT, 1>(
+copy_dims<SrcT, DstT, 1>(
 src_ptr, dst_ptr, shape, strides[0], strides[1], 0);
 return;
 } else if (ndim == 2) {
-copy_dims<SrcT, DstT, StrideT, 2>(
+copy_dims<SrcT, DstT, 2>(
 src_ptr, dst_ptr, shape, strides[0], strides[1], 0);
 return;
 } else if (ndim == 3) {
-copy_dims<SrcT, DstT, StrideT, 3>(
+copy_dims<SrcT, DstT, 3>(
 src_ptr, dst_ptr, shape, strides[0], strides[1], 0);
 return;
 }
-ContiguousIterator<StrideT> in(shape, strides[0], ndim - 3);
-ContiguousIterator<StrideT> out(shape, strides[1], ndim - 3);
-StrideT stride = std::accumulate(
-shape.end() - 3, shape.end(), 1, std::multiplies<StrideT>());
-for (StrideT elem = 0; elem < src.size(); elem += stride) {
-copy_dims<SrcT, DstT, StrideT, 3>(
+ContiguousIterator in(shape, strides[0], ndim - 3);
+ContiguousIterator out(shape, strides[1], ndim - 3);
+auto stride = std::accumulate(
+shape.end() - 3, shape.end(), 1, std::multiplies<int64_t>());
+for (int64_t elem = 0; elem < src.size(); elem += stride) {
+copy_dims<SrcT, DstT, 3>(
 src_ptr + in.loc,
 dst_ptr + out.loc,
 shape,

@@ -102,37 +102,37 @@ void copy_general_general(
 
 template <typename SrcT, typename DstT>
 inline void copy_general_general(const array& src, array& dst) {
-copy_general_general<SrcT, DstT, size_t>(
+copy_general_general<SrcT, DstT>(
 src, dst, src.shape(), src.strides(), dst.strides(), 0, 0);
 }
 
-template <typename SrcT, typename DstT, typename StrideT>
+template <typename SrcT, typename DstT>
 void copy_general(
 const array& src,
 array& dst,
-const std::vector<int>& data_shape,
-const std::vector<StrideT>& i_strides,
-const std::vector<StrideT>&,
+const Shape& data_shape,
+const Strides& i_strides,
+const Strides&,
 int64_t i_offset,
 int64_t o_offset) {
-copy_general_general<SrcT, DstT, StrideT>(
+copy_general_general<SrcT, DstT>(
 src,
 dst,
 data_shape,
 i_strides,
-make_contiguous_strides<StrideT>(data_shape),
+make_contiguous_strides(data_shape),
 i_offset,
 o_offset);
 }
 
 template <typename SrcT, typename DstT>
 inline void copy_general(const array& src, array& dst) {
-copy_general_general<SrcT, DstT, size_t>(
+copy_general_general<SrcT, DstT>(
 src,
 dst,
 src.shape(),
 src.strides(),
-make_contiguous_strides<size_t>(src.shape()),
+make_contiguous_strides(src.shape()),
 0,
 0);
 }

@@ -282,13 +282,12 @@ void copy(const array& src, array& dst, CopyType ctype) {
 copy_inplace(src, dst, ctype);
 }
 
-template <typename StrideT>
 void copy_inplace(
 const array& src,
 array& dst,
-const std::vector<int>& data_shape,
-const std::vector<StrideT>& i_strides,
-const std::vector<StrideT>& o_strides,
+const Shape& data_shape,
+const Strides& i_strides,
+const Strides& o_strides,
 int64_t i_offset,
 int64_t o_offset,
 CopyType ctype) {

@@ -311,24 +310,4 @@ void copy_inplace(
 }
 }
 
-template void copy_inplace<size_t>(
-const array& src,
-array& dst,
-const std::vector<int>& data_shape,
-const std::vector<size_t>& i_strides,
-const std::vector<size_t>& o_strides,
-int64_t i_offset,
-int64_t o_offset,
-CopyType ctype);
-
-template void copy_inplace<int64_t>(
-const array& src,
-array& dst,
-const std::vector<int>& data_shape,
-const std::vector<int64_t>& i_strides,
-const std::vector<int64_t>& o_strides,
-int64_t i_offset,
-int64_t o_offset,
-CopyType ctype);
-
 } // namespace mlx::core

@@ -26,13 +26,12 @@ enum class CopyType {
 void copy(const array& src, array& dst, CopyType ctype);
 void copy_inplace(const array& src, array& dst, CopyType ctype);
 
-template <typename stride_t>
 void copy_inplace(
 const array& src,
 array& dst,
-const std::vector<int>& data_shape,
-const std::vector<stride_t>& i_strides,
-const std::vector<stride_t>& o_strides,
+const Shape& data_shape,
+const Strides& i_strides,
+const Strides& o_strides,
 int64_t i_offset,
 int64_t o_offset,
 CopyType ctype);
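With one concrete stride type, the copy_inplace declaration above no longer needs its stride_t template parameter, which is also why the explicit copy_inplace<size_t> / copy_inplace<int64_t> instantiations in the preceding hunk disappear. A simplified illustration of the pattern (assumed shape of the change, not the actual MLX sources):

#include <cstdint>
#include <vector>

using Strides = std::vector<int64_t>;

// Before: template <typename stride_t>
// void copy_inplace(..., const std::vector<stride_t>& i_strides, ...);
// plus two explicit instantiations so both callers would link.
// After: a single non-template signature over the one stride type.
void copy_inplace_sketch(const Strides& i_strides, const Strides& o_strides) {
  (void)i_strides; // body elided in this sketch
  (void)o_strides;
}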
@@ -130,7 +130,7 @@ inline void matmul_common_general(
 } else {
 array arr_copy(arr.shape(), arr.dtype(), nullptr, {});
 copy(arr, arr_copy, CopyType::General);
-size_t stx = arr.shape(-1);
+stx = arr.shape(-1);
 return std::make_tuple(false, stx, arr_copy);
 }
 };

@@ -32,7 +32,7 @@ void gather(
 const std::vector<array>& inds,
 array& out,
 const std::vector<int>& axes,
-const std::vector<int>& slice_sizes) {
+const Shape& slice_sizes) {
 // If the array is row contiguous then we can do a contiguous copy given
 // two conditions on the slice size:
 // - Any number of leading ones in the slice sizes are allowed

@@ -80,11 +80,10 @@ void gather(
 T* dst_ptr = out.data<T>();
 size_t out_idx = 0;
 
-std::vector<ContiguousIterator<size_t>> its(inds.begin(), inds.end());
-ContiguousIterator<size_t> src_it;
+std::vector<ContiguousIterator> its(inds.begin(), inds.end());
+ContiguousIterator src_it;
 if (!can_copy && src.ndim() > 0) {
-src_it = std::move(
-ContiguousIterator<size_t>(slice_sizes, src.strides(), src.ndim()));
+src_it = ContiguousIterator(slice_sizes, src.strides(), src.ndim());
 }
 for (int idx = 0; idx < ind_size; idx++) {
 size_t src_idx = 0;

@@ -119,7 +118,7 @@ void dispatch_gather(
 const std::vector<array>& inds,
 array& out,
 const std::vector<int>& axes,
-const std::vector<int>& size) {
+const Shape& size) {
 switch (out.dtype()) {
 case bool_:
 gather<bool, IdxT>(src, inds, out, axes, size);

@@ -223,16 +222,16 @@ void scatter(
 auto inds_ndim = updates.ndim() - out.ndim();
 size_t n_updates = nind ? inds[0].size() : 1;
 
-std::vector<int> update_shape(
+Shape update_shape(
 updates.shape().begin() + inds_ndim, updates.shape().end());
 size_t update_size = 1;
 for (auto us : update_shape) {
 update_size *= us;
 }
 
-std::vector<ContiguousIterator<size_t>> its(inds.begin(), inds.end());
-ContiguousIterator<size_t> update_it(updates);
-ContiguousIterator<size_t> out_it(update_shape, out.strides(), out.ndim());
+std::vector<ContiguousIterator> its(inds.begin(), inds.end());
+ContiguousIterator update_it(updates);
+ContiguousIterator out_it(update_shape, out.strides(), out.ndim());
 
 for (int i = 0; i < n_updates; ++i) {
 size_t out_offset = 0;

@@ -19,10 +19,10 @@ inline void mask_matrix(
 int block_size,
 const int X,
 const int Y,
-const size_t X_data_str,
-const size_t Y_data_str,
-const size_t X_mask_str,
-const size_t Y_mask_str,
+const int64_t X_data_str,
+const int64_t Y_data_str,
+const int64_t X_mask_str,
+const int64_t Y_mask_str,
 const size_t mask_offset) {
 int tX = (X + block_size - 1) / block_size;
 int tY = (Y + block_size - 1) / block_size;

@@ -84,7 +84,7 @@ void BlockMaskedMM::eval(const std::vector<array>& inputs, array& out) {
 } else {
 array arr_copy(arr.shape(), arr.dtype(), nullptr, {});
 copy(arr, arr_copy, CopyType::General);
-size_t stx = arr.shape(-1);
+int64_t stx = arr.shape(-1);
 return std::make_tuple(false, stx, arr_copy);
 }
 };

@@ -117,13 +117,13 @@ void BlockMaskedMM::eval(const std::vector<array>& inputs, array& out) {
 int Y,
 size_t X_data_str,
 size_t Y_data_str) {
-size_t mask_offset = elem_to_loc(
+auto mask_offset = elem_to_loc(
 mask.shape(-1) * mask.shape(-2) * batch_idx,
 mask.shape(),
 mask.strides());
 
-size_t X_mask_str = mask.strides()[mask.ndim() - 2];
-size_t Y_mask_str = mask.strides()[mask.ndim() - 1];
+auto X_mask_str = mask.strides()[mask.ndim() - 2];
+auto Y_mask_str = mask.strides()[mask.ndim() - 1];
 
 if (mask.dtype() == bool_) {
 return mask_matrix(

@@ -230,7 +230,7 @@ void GatherMM::eval(const std::vector<array>& inputs, array& out) {
 } else {
 array arr_copy(arr.shape(), arr.dtype(), nullptr, {});
 copy(arr, arr_copy, CopyType::General);
-size_t stx = arr.shape(-1);
+int64_t stx = arr.shape(-1);
 return std::make_tuple(false, stx, arr_copy);
 }
 };

@@ -262,13 +262,13 @@ void GatherMM::eval(const std::vector<array>& inputs, array& out) {
 auto& lhs_indices = inputs[2];
 auto& rhs_indices = inputs[3];
 
-std::vector<int> batch_shape = get_batch_dims(out.shape());
+auto batch_shape = get_batch_dims(out.shape());
 int batch_ndim = batch_shape.size();
 
-std::vector<int> batch_shape_A = get_batch_dims(a.shape());
-std::vector<size_t> batch_strides_A = get_batch_dims(a.strides());
-std::vector<int> batch_shape_B = get_batch_dims(b.shape());
-std::vector<size_t> batch_strides_B = get_batch_dims(b.strides());
+auto batch_shape_A = get_batch_dims(a.shape());
+auto batch_strides_A = get_batch_dims(a.strides());
+auto batch_shape_B = get_batch_dims(b.shape());
+auto batch_strides_B = get_batch_dims(b.strides());
 
 const uint32_t* lhs_indices_ptr = lhs_indices.data<uint32_t>();
 const uint32_t* rhs_indices_ptr = rhs_indices.data<uint32_t>();

@@ -498,14 +498,15 @@ void Slice::eval(const std::vector<array>& inputs, array& out) {
 auto& in = inputs[0];
 
 // Calculate out strides, initial offset and if copy needs to be made
-auto [copy_needed, data_offset, inp_strides] =
-prepare_slice(in, start_indices_, strides_);
+auto [data_offset, inp_strides] = prepare_slice(in, start_indices_, strides_);
+auto copy_needed = std::any_of(
+strides_.begin(), strides_.end(), [](auto i) { return i < 0; });
 
 // Do copy if needed
 if (copy_needed) {
 out.set_data(allocator::malloc_or_wait(out.nbytes()));
-std::vector<int64_t> ostrides{out.strides().begin(), out.strides().end()};
-copy_inplace<int64_t>(
+Strides ostrides{out.strides().begin(), out.strides().end()};
+copy_inplace(
 /* const array& src = */ in,
 /* array& dst = */ out,
 /* const std::vector<int>& data_shape = */ out.shape(),

@@ -523,7 +524,7 @@ void Slice::eval(const std::vector<array>& inputs, array& out) {
 }
 }
 size_t data_size = data_end - data_offset;
-std::vector<size_t> ostrides{inp_strides.begin(), inp_strides.end()};
+Strides ostrides{inp_strides.begin(), inp_strides.end()};
 shared_buffer_slice(in, ostrides, data_offset, data_size, out);
 }
 }

@@ -550,11 +551,11 @@ void SliceUpdate::eval(const std::vector<array>& inputs, array& out) {
 copy(in, out, in.data_size() == 1 ? CopyType::Scalar : ctype);
 
 // Calculate out strides, initial offset and if copy needs to be made
-auto [data_offset, out_strides] = prepare_slice(out);
+auto [data_offset, out_strides] = prepare_slice(in, start_indices_, strides_);
 
 // Do copy
-std::vector<int64_t> upd_strides{upd.strides().begin(), upd.strides().end()};
-copy_inplace<int64_t>(
+Strides upd_strides{upd.strides().begin(), upd.strides().end()};
+copy_inplace(
 /* const array& src = */ upd,
 /* array& dst = */ out,
 /* const std::vector<int>& data_shape = */ upd.shape(),

@@ -54,7 +54,7 @@ void qrf_impl(const array& a, array& q, array& r) {
 // Copy the input to be column contiguous
 flags.col_contiguous = num_matrices == 1;
 flags.row_contiguous = false;
-std::vector<size_t> strides = in.strides();
+auto strides = in.strides();
 strides[in.ndim() - 2] = 1;
 strides[in.ndim() - 1] = M;
 in.set_data(

@@ -174,19 +174,19 @@ void reduce_dispatch_min_max(
 
 void nd_loop(
 std::function<void(int)> callback,
-const std::vector<int>& shape,
-const std::vector<size_t>& strides) {
+const Shape& shape,
+const Strides& strides) {
 std::function<void(int, int)> loop_inner;
 loop_inner = [&](int dim, int offset) {
 if (dim < shape.size() - 1) {
-int size = shape[dim];
-size_t stride = strides[dim];
+auto size = shape[dim];
+auto stride = strides[dim];
 for (int i = 0; i < size; i++) {
 loop_inner(dim + 1, offset + i * stride);
 }
 } else {
-int size = shape[dim];
-size_t stride = strides[dim];
+auto size = shape[dim];
+auto stride = strides[dim];
 for (int i = 0; i < size; i++) {
 callback(offset + i * stride);
 }

@@ -38,13 +38,10 @@ enum ReductionOpType {
 
 struct ReductionPlan {
 ReductionOpType type;
-std::vector<int> shape;
-std::vector<size_t> strides;
+Shape shape;
+Strides strides;
 
-ReductionPlan(
-ReductionOpType type_,
-std::vector<int> shape_,
-std::vector<size_t> strides_)
+ReductionPlan(ReductionOpType type_, Shape shape_, Strides strides_)
 : type(type_), shape(std::move(shape_)), strides(std::move(strides_)) {}
 ReductionPlan(ReductionOpType type_) : type(type_) {}
 };

@@ -55,10 +52,10 @@ ReductionPlan get_reduction_plan(const array& x, const std::vector<int>& axes);
 // Should this be in utils?
 void nd_loop(
 std::function<void(int)> callback,
-const std::vector<int>& shape,
-const std::vector<size_t>& strides);
+const Shape& shape,
+const Strides& strides);
 
-std::pair<std::vector<int>, std::vector<size_t>> shapes_without_reduction_axes(
+std::pair<Shape, Strides> shapes_without_reduction_axes(
 const array& x,
 const std::vector<int>& axes);
 

@@ -113,9 +110,6 @@ void reduction_op(
 return;
 }
 
-std::vector<int> shape;
-std::vector<size_t> strides;
-
 if (plan.type == ContiguousReduce && plan.shape.size() == 1) {
 int reduction_size = plan.shape[0];
 const T* x_ptr = x.data<T>();

@@ -135,7 +129,7 @@ void reduction_op(
 U* out_ptr = out.data<U>();
 // Unrolling the following loop (and implementing it in order for
 // ContiguousReduce) should hold extra performance boost.
-std::tie(shape, strides) = shapes_without_reduction_axes(x, axes);
+auto [shape, strides] = shapes_without_reduction_axes(x, axes);
 if (plan.shape.size() == 0) {
 for (int i = 0; i < out.size(); i++, out_ptr++) {
 int offset = elem_to_loc(i, shape, strides);

@@ -181,7 +175,7 @@ void reduction_op(
 plan.strides.pop_back();
 const T* x_ptr = x.data<T>();
 U* out_ptr = out.data<U>();
-std::tie(shape, strides) = shapes_without_reduction_axes(x, axes);
+auto [shape, strides] = shapes_without_reduction_axes(x, axes);
 if (plan.shape.size() == 0) {
 for (int i = 0; i < out.size(); i += reduction_stride) {
 int offset = elem_to_loc(i, shape, strides);

@@ -211,7 +205,7 @@ void reduction_op(
 if (plan.type == GeneralReduce) {
 const T* x_ptr = x.data<T>();
 U* out_ptr = out.data<U>();
-std::tie(shape, strides) = shapes_without_reduction_axes(x, axes);
+auto [shape, strides] = shapes_without_reduction_axes(x, axes);
 for (int i = 0; i < out.size(); i++, out_ptr++) {
 int offset = elem_to_loc(i, shape, strides);
 U val = init;

@@ -4,11 +4,11 @@
 
 namespace mlx::core {
 
-std::pair<std::vector<int>, std::vector<size_t>> shapes_without_reduction_axes(
+std::pair<Shape, Strides> shapes_without_reduction_axes(
 const array& x,
 const std::vector<int>& axes) {
-std::vector<int> shape = x.shape();
-std::vector<size_t> strides = x.strides();
+auto shape = x.shape();
+auto strides = x.strides();
 
 for (int i = axes.size() - 1; i >= 0; i--) {
 int a = axes[i];

@@ -29,8 +29,8 @@ ReductionPlan get_reduction_plan(const array& x, const std::vector<int>& axes) {
 // Row contiguous input so the output is row contiguous
 if (x.flags().row_contiguous) {
 // Merge consecutive axes
-std::vector<int> shape = {x.shape(axes[0])};
-std::vector<size_t> strides = {x.strides()[axes[0]]};
+Shape shape = {x.shape(axes[0])};
+Strides strides = {x.strides()[axes[0]]};
 for (int i = 1; i < axes.size(); i++) {
 if (axes[i] - 1 == axes[i - 1] && x.shape(axes[i]) > 1) {
 shape.back() *= x.shape(axes[i]);

@@ -69,7 +69,7 @@ ReductionPlan get_reduction_plan(const array& x, const std::vector<int>& axes) {
 
 // Sort reduction axes by stride in order to merge them and figure out if we
 // have a contiguous reduction.
-std::vector<std::pair<int, size_t>> reductions;
+std::vector<std::pair<int, int64_t>> reductions;
 for (auto a : axes) {
 if (x.shape(a) > 1) {
 reductions.push_back(std::make_pair(x.shape(a), x.strides()[a]));

@@ -93,8 +93,8 @@ ReductionPlan get_reduction_plan(const array& x, const std::vector<int>& axes) {
 }
 }
 
-std::vector<int> shape;
-std::vector<size_t> strides;
+Shape shape;
+Strides strides;
 for (auto r : reductions) {
 shape.push_back(r.first);
 strides.push_back(r.second);

@@ -109,15 +109,15 @@ ReductionPlan get_reduction_plan(const array& x, const std::vector<int>& axes) {
 // Delegate to the general strided reduction op if the axes after
 // strides.back() are contiguous.
 if (strides.back() > 1) {
-int size = 1;
+int64_t size = 1;
 bool have_expand = false;
 for (int i = x.ndim() - 1; i >= 0; i--) {
 if (axes.back() == i) {
 continue;
 }
 
-size_t stride_i = x.strides()[i];
-int shape_i = x.shape(i);
+auto stride_i = x.strides()[i];
+auto shape_i = x.shape(i);
 if (stride_i == 0) {
 if (shape_i == 1) {
 continue;

@@ -4,24 +4,22 @@
 
 namespace mlx::core {
 
-std::tuple<bool, int64_t, std::vector<int64_t>> prepare_slice(
+std::tuple<int64_t, Strides> prepare_slice(
 const array& in,
-const std::vector<int>& start_indices,
-const std::vector<int>& strides) {
+const Shape& start_indices,
+const Shape& strides) {
 int64_t data_offset = 0;
-bool copy_needed = false;
-std::vector<int64_t> inp_strides(in.ndim(), 0);
+Strides inp_strides(in.ndim(), 0);
 for (int i = 0; i < in.ndim(); ++i) {
 data_offset += start_indices[i] * in.strides()[i];
 inp_strides[i] = in.strides()[i] * strides[i];
-copy_needed |= strides[i] < 0;
 }
-return std::make_tuple(copy_needed, data_offset, inp_strides);
+return std::make_tuple(data_offset, inp_strides);
 }
 
 void shared_buffer_slice(
 const array& in,
-const std::vector<size_t>& out_strides,
+const Strides& out_strides,
 size_t data_offset,
 size_t data_size,
 array& out) {

@@ -6,14 +6,14 @@
 
 namespace mlx::core {
 
-std::tuple<bool, int64_t, std::vector<int64_t>> prepare_slice(
+std::tuple<int64_t, Strides> prepare_slice(
 const array& in,
-const std::vector<int>& start_indices,
-const std::vector<int>& strides);
+const Shape& start_indices,
+const Shape& strides);
 
 void shared_buffer_slice(
 const array& in,
-const std::vector<size_t>& out_strides,
+const Strides& out_strides,
 size_t data_offset,
 size_t data_size,
 array& out);
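prepare_slice above multiplies the input strides by the (possibly negative) slice steps, so the result only fits a signed type; the caller now recomputes copy_needed from the sign of the steps instead of threading a bool through the tuple. A standalone worked example of that arithmetic (toy values, not MLX):

#include <cstdint>
#include <cstdio>
#include <vector>

int main() {
  std::vector<int64_t> in_strides = {4, 1}; // row-major, shape {3, 4}
  std::vector<int> start = {0, 3};
  std::vector<int> step = {1, -1};          // reverse the last axis
  int64_t data_offset = 0;
  std::vector<int64_t> out_strides(2, 0);
  for (int i = 0; i < 2; ++i) {
    data_offset += start[i] * in_strides[i];
    out_strides[i] = in_strides[i] * step[i];
  }
  // offset 3, strides {4, -1}: representable only with signed strides.
  std::printf("offset=%lld strides={%lld, %lld}\n",
              (long long)data_offset,
              (long long)out_strides[0],
              (long long)out_strides[1]);
  return 0;
}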
@@ -25,7 +25,7 @@ struct StridedIterator {
 // Constructors
 StridedIterator() = default;
 
-explicit StridedIterator(T* ptr, size_t stride, difference_type offset = 0)
+explicit StridedIterator(T* ptr, int64_t stride, difference_type offset = 0)
 : ptr_(ptr + offset * stride), stride_(stride) {}
 
 explicit StridedIterator(array& arr, int axis, difference_type offset = 0)

@@ -99,7 +99,7 @@ struct StridedIterator {
 }
 
 private:
-size_t stride_;
+int64_t stride_;
 T* ptr_;
 };
 

@@ -120,11 +120,11 @@ void sort(const array& in, array& out, int axis) {
 auto remaining_strides = out.strides();
 remaining_strides.erase(remaining_strides.begin() + axis);
 
-size_t axis_stride = out.strides()[axis];
-int axis_size = out.shape(axis);
+auto axis_stride = out.strides()[axis];
+auto axis_size = out.shape(axis);
 
 // Perform sorting in place
-ContiguousIterator<size_t> src_it(
+ContiguousIterator src_it(
 remaining_shape, remaining_strides, remaining_shape.size());
 for (int i = 0; i < n_rows; i++) {
 T* data_ptr = out.data<T>() + src_it.loc;

@@ -158,14 +158,14 @@ void argsort(const array& in, array& out, int axis) {
 auto out_remaining_strides = out.strides();
 out_remaining_strides.erase(out_remaining_strides.begin() + axis);
 
-size_t in_stride = in.strides()[axis];
-size_t out_stride = out.strides()[axis];
-int axis_size = in.shape(axis);
+auto in_stride = in.strides()[axis];
+auto out_stride = out.strides()[axis];
+auto axis_size = in.shape(axis);
 
 // Perform sorting
-ContiguousIterator<size_t> in_it(
+ContiguousIterator in_it(
 in_remaining_shape, in_remaining_strides, in_remaining_shape.size());
-ContiguousIterator<size_t> out_it(
+ContiguousIterator out_it(
 out_remaining_shape, out_remaining_strides, out_remaining_shape.size());
 for (int i = 0; i < n_rows; i++) {
 const T* data_ptr = in.data<T>() + in_it.loc;

@@ -208,13 +208,13 @@ void partition(const array& in, array& out, int axis, int kth) {
 auto remaining_strides = in.strides();
 remaining_strides.erase(remaining_strides.begin() + axis);
 
-size_t axis_stride = in.strides()[axis];
+auto axis_stride = in.strides()[axis];
 int axis_size = in.shape(axis);
 
 kth = kth < 0 ? kth + axis_size : kth;
 
 // Perform partition in place
-ContiguousIterator<size_t> src_it(
+ContiguousIterator src_it(
 remaining_shape, remaining_strides, remaining_shape.size());
 for (int i = 0; i < n_rows; i++) {
 T* data_ptr = out.data<T>() + src_it.loc;

@@ -249,16 +249,16 @@ void argpartition(const array& in, array& out, int axis, int kth) {
 auto out_remaining_strides = out.strides();
 out_remaining_strides.erase(out_remaining_strides.begin() + axis);
 
-size_t in_stride = in.strides()[axis];
-size_t out_stride = out.strides()[axis];
-int axis_size = in.shape(axis);
+auto in_stride = in.strides()[axis];
+auto out_stride = out.strides()[axis];
+auto axis_size = in.shape(axis);
 
 kth = kth < 0 ? kth + axis_size : kth;
 
 // Perform partition
-ContiguousIterator<size_t> in_it(
+ContiguousIterator in_it(
 in_remaining_shape, in_remaining_strides, in_remaining_shape.size());
-ContiguousIterator<size_t> out_it(
+ContiguousIterator out_it(
 out_remaining_shape, out_remaining_strides, out_remaining_shape.size());
 for (int i = 0; i < n_rows; i++) {
 const T* data_ptr = in.data<T>() + in_it.loc;

@@ -78,11 +78,11 @@ void ternary_op_dims(
 const T3* c,
 U* out,
 Op op,
-const std::vector<int>& shape,
-const std::vector<size_t>& a_strides,
-const std::vector<size_t>& b_strides,
-const std::vector<size_t>& c_strides,
-const std::vector<size_t>& out_strides,
+const Shape& shape,
+const Strides& a_strides,
+const Strides& b_strides,
+const Strides& c_strides,
+const Strides& out_strides,
 int axis) {
 auto stride_a = a_strides[axis];
 auto stride_b = b_strides[axis];

@@ -164,10 +164,10 @@ void ternary_op_dispatch_dims(
 return;
 }
 
-ContiguousIterator<size_t> a_it(shape, a_strides, ndim - 2);
-ContiguousIterator<size_t> b_it(shape, b_strides, ndim - 2);
-ContiguousIterator<size_t> c_it(shape, c_strides, ndim - 2);
-size_t stride = out_strides[ndim - 3];
+ContiguousIterator a_it(shape, a_strides, ndim - 2);
+ContiguousIterator b_it(shape, b_strides, ndim - 2);
+ContiguousIterator c_it(shape, c_strides, ndim - 2);
+auto stride = out_strides[ndim - 3];
 for (size_t elem = 0; elem < a.size(); elem += stride) {
 ternary_op_dims<T1, T2, T3, U, Op, 2>(
 a_ptr + a_it.loc,

@@ -15,7 +15,7 @@ void move_or_copy(const array& in, array& out) {
 void move_or_copy(
 const array& in,
 array& out,
-const std::vector<size_t>& strides,
+const Strides& strides,
 array::Flags flags,
 size_t data_size,
 size_t offset /* = 0 */) {

@@ -26,15 +26,13 @@ void move_or_copy(
 }
 }
 
-template <typename StrideT>
-std::tuple<std::vector<int>, std::vector<std::vector<StrideT>>>
-collapse_contiguous_dims_impl(
-const std::vector<int>& shape,
-const std::vector<std::vector<StrideT>>& strides,
-StrideT size_cap) {
+std::tuple<Shape, std::vector<Strides>> collapse_contiguous_dims(
+const Shape& shape,
+const std::vector<Strides>& strides,
+int64_t size_cap) {
 // Make a vector that has axes separated with -1. Collapse all axes between
 // -1.
-std::vector<int> to_collapse;
+Shape to_collapse;
 if (shape.size() > 0) {
 if (shape[0] != 1) {
 to_collapse.push_back(0);

@@ -43,7 +41,7 @@ collapse_contiguous_dims_impl(
 for (int i = 1; i < shape.size(); i++) {
 bool contiguous = true;
 size *= shape[i];
-for (const std::vector<StrideT>& st : strides) {
+for (const auto& st : strides) {
 if (st[i] * shape[i] != st[i - 1] || size > size_cap) {
 contiguous = false;
 size = shape[i];

@@ -60,8 +58,8 @@ collapse_contiguous_dims_impl(
 to_collapse.push_back(-1);
 }
 
-std::vector<int> out_shape;
-std::vector<std::vector<StrideT>> out_strides(strides.size());
+Shape out_shape;
+std::vector<Strides> out_strides(strides.size());
 for (int i = 0;;) {
 while (i < to_collapse.size() && to_collapse[i] == -1) {
 ++i;

@@ -76,7 +74,7 @@ collapse_contiguous_dims_impl(
 }
 out_shape.push_back(current_shape);
 for (int j = 0; j < strides.size(); j++) {
-const std::vector<StrideT>& st = strides[j];
+const auto& st = strides[j];
 out_strides[j].push_back(st[to_collapse[k - 1]]);
 }
 i = k + 1;

@@ -91,29 +89,12 @@ collapse_contiguous_dims_impl(
 return std::make_tuple(out_shape, out_strides);
 }
 
-std::tuple<std::vector<int>, std::vector<std::vector<int64_t>>>
-collapse_contiguous_dims(
-const std::vector<int>& shape,
-const std::vector<std::vector<int64_t>>& strides,
-int64_t size_cap /* = std::numeric_limits<int32_t>::max() */) {
-return collapse_contiguous_dims_impl(shape, strides, size_cap);
-}
-
-std::tuple<std::vector<int>, std::vector<std::vector<size_t>>>
-collapse_contiguous_dims(
-const std::vector<int>& shape,
-const std::vector<std::vector<size_t>>& strides,
-size_t size_cap /* = std::numeric_limits<int32>::max() */) {
-return collapse_contiguous_dims_impl(shape, strides, size_cap);
-}
-
-template <typename StrideT>
-std::pair<std::vector<int>, std::vector<StrideT>> collapse_contiguous_dims_impl(
-const std::vector<int>& shape,
-const std::vector<StrideT>& strides,
-StrideT size_cap) {
-std::vector<int> collapsed_shape;
-std::vector<StrideT> collapsed_strides;
+std::pair<Shape, Strides> collapse_contiguous_dims(
+const Shape& shape,
+const Strides& strides,
+int64_t size_cap) {
+Shape collapsed_shape;
+Strides collapsed_strides;
 
 if (shape.size() > 0) {
 collapsed_shape.push_back(shape[0]);

@@ -123,7 +104,7 @@ std::pair<std::vector<int>, std::vector<StrideT>> collapse_contiguous_dims_impl(
 continue;
 } else if (
 strides[i] * shape[i] != collapsed_strides.back() ||
-collapsed_shape.back() * static_cast<StrideT>(shape[i]) > size_cap) {
+collapsed_shape.back() * static_cast<int64_t>(shape[i]) > size_cap) {
 collapsed_shape.push_back(shape[i]);
 collapsed_strides.push_back(strides[i]);
 } else {

@@ -136,25 +117,10 @@ std::pair<std::vector<int>, std::vector<StrideT>> collapse_contiguous_dims_impl(
 return std::make_pair(collapsed_shape, collapsed_strides);
 }
 
-std::pair<std::vector<int>, std::vector<int64_t>> collapse_contiguous_dims(
-const std::vector<int>& shape,
-const std::vector<int64_t>& strides,
-int64_t size_cap /* = std::numeric_limits<int32_t>::max() */) {
-return collapse_contiguous_dims_impl<int64_t>(shape, strides, size_cap);
-}
-
-std::pair<std::vector<int>, std::vector<size_t>> collapse_contiguous_dims(
-const std::vector<int>& shape,
-const std::vector<size_t>& strides,
-size_t size_cap /* = std::numeric_limits<int32_t>::max() */) {
-return collapse_contiguous_dims_impl<size_t>(shape, strides, size_cap);
-}
-
-std::pair<std::vector<int>, std::vector<size_t>> collapse_contiguous_dims(
+std::pair<Shape, Strides> collapse_contiguous_dims(
 const array& a,
-size_t size_cap /* = std::numeric_limits<int32_t>::max()*/) {
-return collapse_contiguous_dims_impl<size_t>(
-a.shape(), a.strides(), size_cap);
+int64_t size_cap /* = std::numeric_limits<int32_t>::max()*/) {
+return collapse_contiguous_dims(a.shape(), a.strides(), size_cap);
 }
 
 } // namespace mlx::core

@@ -8,12 +8,9 @@
 
 namespace mlx::core {
 
-template <typename StrideT>
-inline StrideT elem_to_loc(
-int elem,
-const std::vector<int>& shape,
-const std::vector<StrideT>& strides) {
-StrideT loc = 0;
+inline int64_t
+elem_to_loc(int elem, const Shape& shape, const Strides& strides) {
+int64_t loc = 0;
 for (int i = shape.size() - 1; i >= 0; --i) {
 auto q_and_r = ldiv(elem, shape[i]);
 loc += q_and_r.rem * strides[i];

@@ -22,16 +19,15 @@ inline StrideT elem_to_loc(
 return loc;
 }
 
-inline size_t elem_to_loc(int elem, const array& a) {
+inline int64_t elem_to_loc(int elem, const array& a) {
 if (a.flags().row_contiguous) {
 return elem;
 }
 return elem_to_loc(elem, a.shape(), a.strides());
 }
 
-template <typename StrideT>
-std::vector<StrideT> make_contiguous_strides(const std::vector<int>& shape) {
-std::vector<StrideT> strides(shape.size(), 1);
+inline Strides make_contiguous_strides(const Shape& shape) {
+Strides strides(shape.size(), 1);
 for (int i = shape.size() - 1; i > 0; i--) {
 strides[i - 1] = strides[i] * shape[i];
 }
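The elem_to_loc helpers above map a linear element index to a storage location by peeling remainders off the innermost axis, so a broadcast axis (stride 0) contributes nothing. A standalone restatement with a worked value (assumes plain % and / in place of ldiv, which is equivalent for non-negative inputs):

#include <cstdint>
#include <cstdio>
#include <vector>

int64_t elem_to_loc(int elem, const std::vector<int32_t>& shape,
                    const std::vector<int64_t>& strides) {
  int64_t loc = 0;
  for (int i = shape.size() - 1; i >= 0; --i) {
    loc += (elem % shape[i]) * strides[i]; // remainder = index on axis i
    elem /= shape[i];
  }
  return loc;
}

int main() {
  // shape {2, 3} broadcast over the first axis: strides {0, 1}.
  // elem 4 -> indices (1, 1) -> 1 * 0 + 1 * 1 = 1.
  std::printf("%lld\n", (long long)elem_to_loc(4, {2, 3}, {0, 1}));
  return 0;
}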
@@ -44,22 +40,15 @@ std::vector<StrideT> make_contiguous_strides(const std::vector<int>& shape) {
 //
 // When multiple arrays are passed they should all have the same shape. The
 // collapsed axes are also the same so one shape is returned.
-std::tuple<std::vector<int>, std::vector<std::vector<int64_t>>>
-collapse_contiguous_dims(
-const std::vector<int>& shape,
-const std::vector<std::vector<int64_t>>& strides,
+std::tuple<Shape, std::vector<Strides>> collapse_contiguous_dims(
+const Shape& shape,
+const std::vector<Strides>& strides,
 int64_t size_cap = std::numeric_limits<int32_t>::max());
-std::tuple<std::vector<int>, std::vector<std::vector<size_t>>>
-collapse_contiguous_dims(
-const std::vector<int>& shape,
-const std::vector<std::vector<size_t>>& strides,
-size_t size_cap = std::numeric_limits<int32_t>::max());
 
-inline std::tuple<std::vector<int>, std::vector<std::vector<size_t>>>
-collapse_contiguous_dims(
+inline std::tuple<Shape, std::vector<Strides>> collapse_contiguous_dims(
 const std::vector<array>& xs,
 size_t size_cap = std::numeric_limits<int32_t>::max()) {
-std::vector<std::vector<size_t>> strides;
+std::vector<Strides> strides;
 for (auto& x : xs) {
 strides.emplace_back(x.strides());
 }

@@ -73,19 +62,14 @@ inline auto collapse_contiguous_dims(Arrays&&... xs) {
 }
 
 // The single array version of the above.
-std::pair<std::vector<int>, std::vector<int64_t>> collapse_contiguous_dims(
-const std::vector<int>& shape,
-const std::vector<int64_t>& strides,
+std::pair<Shape, Strides> collapse_contiguous_dims(
+const Shape& shape,
+const Strides& strides,
 int64_t size_cap = std::numeric_limits<int32_t>::max());
-std::pair<std::vector<int>, std::vector<size_t>> collapse_contiguous_dims(
-const std::vector<int>& shape,
-const std::vector<size_t>& strides,
-size_t size_cap = std::numeric_limits<int32_t>::max());
-std::pair<std::vector<int>, std::vector<size_t>> collapse_contiguous_dims(
+std::pair<Shape, Strides> collapse_contiguous_dims(
 const array& a,
-size_t size_cap = std::numeric_limits<int32_t>::max());
+int64_t size_cap = std::numeric_limits<int32_t>::max());
 
-template <typename StrideT>
 struct ContiguousIterator {
 inline void step() {
 int dims = shape_.size();

@@ -102,7 +86,7 @@ struct ContiguousIterator {
 loc += strides_[i];
 }
 
-void seek(StrideT n) {
+void seek(int64_t n) {
 loc = 0;
 for (int i = shape_.size() - 1; i >= 0; --i) {
 auto q_and_r = ldiv(n, shape_[i]);

@@ -128,32 +112,29 @@ struct ContiguousIterator {
 }
 
 explicit ContiguousIterator(
-const std::vector<int>& shape,
-const std::vector<StrideT>& strides,
+const Shape& shape,
+const Strides& strides,
 int dims)
 : shape_(shape.begin(), shape.begin() + dims),
 strides_(strides.begin(), strides.begin() + dims) {
 if (!shape_.empty()) {
 std::tie(shape_, strides_) = collapse_contiguous_dims(shape_, strides_);
-pos_ = std::vector<int>(shape_.size(), 0);
+pos_ = Shape(shape_.size(), 0);
 }
 }
 
-StrideT loc{0};
+int64_t loc{0};
 
 private:
-std::vector<int> shape_;
-std::vector<StrideT> strides_;
-std::vector<int> pos_;
+Shape shape_;
+Strides strides_;
+Shape pos_;
 };
 
-template <typename StrideT>
-inline auto check_contiguity(
-const std::vector<int>& shape,
-const std::vector<StrideT>& strides) {
+inline auto check_contiguity(const Shape& shape, const Strides& strides) {
 size_t no_broadcast_data_size = 1;
-size_t f_stride = 1;
-size_t b_stride = 1;
+int64_t f_stride = 1;
+int64_t b_stride = 1;
 bool is_row_contiguous = true;
 bool is_col_contiguous = true;
 
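The collapse_contiguous_dims family declared above merges adjacent axes whenever strides[i] * shape[i] equals the previous collapsed stride, so a fully contiguous array degenerates to a single flat axis. A standalone sketch of the single-array case (simplified: no size-1 axis skipping and no size_cap check):

#include <cstdint>
#include <cstdio>
#include <vector>

int main() {
  std::vector<int32_t> shape = {2, 3, 4};
  std::vector<int64_t> strides = {12, 4, 1}; // fully contiguous
  std::vector<int32_t> cshape{shape[0]};
  std::vector<int64_t> cstrides{strides[0]};
  for (size_t i = 1; i < shape.size(); ++i) {
    if (strides[i] * shape[i] == cstrides.back()) {
      cshape.back() *= shape[i]; // merge axis i into its neighbor
      cstrides.back() = strides[i];
    } else {
      cshape.push_back(shape[i]);
      cstrides.push_back(strides[i]);
    }
  }
  // Prints: axes=1 extent=24 stride=1
  std::printf("axes=%zu extent=%d stride=%lld\n", cshape.size(),
              (int)cshape[0], (long long)cstrides[0]);
  return 0;
}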
@ -182,7 +163,7 @@ void move_or_copy(const array& in, array& out);
|
||||
void move_or_copy(
|
||||
const array& in,
|
||||
array& out,
|
||||
const std::vector<size_t>& strides,
|
||||
const Strides& strides,
|
||||
array::Flags flags,
|
||||
size_t data_size,
|
||||
size_t offset = 0);
|
||||
|
@ -75,8 +75,8 @@ void binary_op_gpu_inplace(
|
||||
auto [shape, strides] = collapse_contiguous_dims(a, b, out);
|
||||
return std::make_tuple(shape, strides[0], strides[1], strides[2]);
|
||||
} else {
|
||||
std::vector<size_t> e;
|
||||
return std::make_tuple(std::vector<int>{}, e, e, e);
|
||||
decltype(a.strides()) e{};
|
||||
return std::make_tuple(decltype(a.shape()){}, e, e, e);
|
||||
}
|
||||
};
|
||||
auto [shape, strides_a, strides_b, strides_out] = maybe_collapse();
|
||||
|
@ -67,7 +67,7 @@ inline void build_kernel(
|
||||
|
||||
if (add_indices) {
|
||||
os += fmt::format(
|
||||
" constant const size_t* in_strides [[buffer({0})]],\n", cnt++);
|
||||
" constant const int64_t* in_strides [[buffer({0})]],\n", cnt++);
|
||||
}
|
||||
|
||||
// Add the output arguments
|
||||
@ -81,7 +81,7 @@ inline void build_kernel(
|
||||
// Add output strides and shape to extract the indices.
|
||||
if (!contiguous) {
|
||||
os += fmt::format(
|
||||
" constant const size_t* output_strides [[buffer({0})]],\n", cnt++);
|
||||
" constant const int64_t* output_strides [[buffer({0})]],\n", cnt++);
|
||||
os += fmt::format(
|
||||
" constant const int* output_shape [[buffer({0})]],\n", cnt++);
|
||||
}
|
||||
@ -93,11 +93,11 @@ inline void build_kernel(
|
||||
os += " uint3 pos [[thread_position_in_grid]],\n";
|
||||
os += " uint3 grid [[threads_per_grid]]) {\n";
|
||||
|
||||
std::string idx_type = use_big_index ? "size_t" : "uint";
|
||||
std::string idx_type = use_big_index ? "int64_t" : "uint";
|
||||
if (contiguous && use_big_index) {
|
||||
// This is only used for contiguous kernels which don't have
|
||||
// a third grid dimension
|
||||
os += " size_t index = pos.x + grid.x * size_t(pos.y);\n";
|
||||
os += " int64_t index = pos.x + grid.x * int64_t(pos.y);\n";
|
||||
} else if (work_per_thread > 1) {
|
||||
os += fmt::format(" constexpr int N_ = {0};\n", work_per_thread);
|
||||
os += fmt::format(
|
||||
@ -144,20 +144,18 @@ inline void build_kernel(
|
||||
os += fmt::format(" {0} index_{1} = ", idx_type, xname);
|
||||
if (ndim == 1) {
|
||||
int offset = i * ndim;
|
||||
os += fmt::format(
|
||||
"elem_to_loc_1<size_t, uint>(pos.x, in_strides[{0}]);\n", offset);
|
||||
os +=
|
||||
fmt::format("elem_to_loc_1<uint>(pos.x, in_strides[{0}]);\n", offset);
|
||||
} else if (ndim == 2) {
|
||||
int offset = i * ndim;
|
||||
os += fmt::format(
|
||||
"elem_to_loc_2<size_t, {0}>({{pos.x, pos.y}}, in_strides + {1});\n",
|
||||
"elem_to_loc_2<{0}>({{pos.x, pos.y}}, in_strides + {1});\n",
|
||||
idx_type,
|
||||
offset);
|
||||
} else if (ndim == 3) {
|
||||
int offset = i * ndim;
|
||||
os += fmt::format(
|
||||
"elem_to_loc_3<size_t, {0}>(pos, in_strides + {1});\n",
|
||||
idx_type,
|
||||
offset);
|
||||
"elem_to_loc_3<{0}>(pos, in_strides + {1});\n", idx_type, offset);
|
||||
} else if (!dynamic_dims) {
|
||||
int offset = (i + 1) * ndim;
|
||||
os += fmt::format(
|
||||
@ -360,10 +358,10 @@ void Compiled::eval_gpu(

// Collapse contiguous dims to route to a faster kernel if possible. Also
// handle all broadcasting.
std::vector<std::vector<size_t>> initial_strides;
std::vector<Strides> initial_strides;
initial_strides.push_back(outputs[0].strides());
std::vector<int> shape;
std::vector<std::vector<size_t>> strides;
Shape shape;
std::vector<Strides> strides;
if (!contiguous) {
for (int i = 0; i < inputs.size(); i++) {
// Skip constants.
@ -378,7 +376,7 @@ void Compiled::eval_gpu(
}

// Broadcast the inputs to the output shape.
std::vector<size_t> xstrides;
Strides xstrides;
int j = 0;
for (; j < output_shape.size() - x.ndim(); j++) {
if (output_shape[j] == 1) {
@ -440,7 +438,7 @@ void Compiled::eval_gpu(
// Put the inputs in
int cnt = 0;
int stride_idx = 1; // idx 0 is the output strides
std::vector<size_t> in_strides;
Strides in_strides;
for (int i = 0; i < inputs.size(); i++) {
if (constant_ids_.find(inputs_[i].id()) != constant_ids_.end()) {
continue;

@ -64,8 +64,8 @@ void explicit_gemm_conv_ND_gpu(
compute_encoder.dispatch_threads(grid_dims, group_dims);

// Reshape weight
std::vector<int> wt_reshape{implicit_K, implicit_N};
std::vector<size_t> wt_restride{1, static_cast<size_t>(implicit_K)};
Shape wt_reshape{implicit_K, implicit_N};
Strides wt_restride{1, implicit_K};
array wt_reshaped(wt_reshape, wt.dtype(), nullptr, {});
auto wt_flags = wt.flags();
wt_flags.row_contiguous = false;
@ -147,10 +147,7 @@ void explicit_gemm_conv_group_ND_gpu(
array wt_view(
{wt.shape(0), C_per_group, kernel_size}, wt.dtype(), nullptr, {});
wt_view.copy_shared_buffer(
wt,
{wt.strides(0), 1, static_cast<size_t>(C_per_group)},
wt.flags(),
wt.size());
wt, {wt.strides(0), 1, C_per_group}, wt.flags(), wt.size());

// Materialize
auto wt_transpose = array(wt_view.shape(), wt_view.dtype(), nullptr, {});

@ -43,13 +43,12 @@ void copy_gpu(const array& in, array& out, CopyType ctype) {
copy_gpu(in, out, ctype, out.primitive().stream());
}

template <typename stride_t>
void copy_gpu_inplace(
const array& in,
array& out,
const std::vector<int>& data_shape,
const std::vector<stride_t>& strides_in_pre,
const std::vector<stride_t>& strides_out_pre,
const Shape& data_shape,
const Strides& strides_in_pre,
const Strides& strides_out_pre,
int64_t inp_offset,
int64_t out_offset,
CopyType ctype,
@ -68,8 +67,8 @@ void copy_gpu_inplace(
/* size_cap = */ INT32_MAX);
return std::make_tuple(shape, strides[0], strides[1]);
} else {
std::vector<stride_t> e;
return std::make_tuple(std::vector<int>{}, e, e);
Strides e{};
return std::make_tuple(Shape{}, e, e);
}
};
auto [shape, strides_in_, strides_out_] = maybe_collapse();
@ -124,8 +123,8 @@ void copy_gpu_inplace(

auto thread_group_size = kernel->maxTotalThreadsPerThreadgroup();
if (ctype == CopyType::General || ctype == CopyType::GeneralGeneral) {
std::vector<int64_t> strides_in{strides_in_.begin(), strides_in_.end()};
std::vector<int64_t> strides_out{strides_out_.begin(), strides_out_.end()};
Strides strides_in{strides_in_.begin(), strides_in_.end()};
Strides strides_out{strides_out_.begin(), strides_out_.end()};
if (ndim > 3) {
compute_encoder.set_vector_bytes(shape, ndim, 2);
}
@ -180,14 +179,13 @@ void copy_gpu_inplace(
void copy_gpu_inplace(
const array& in,
array& out,
const std::vector<int64_t>& istride,
const Strides& istride,
int64_t ioffset,
CopyType ctype,
const Stream& s) {
assert(in.shape() == out.shape());
std::vector<int64_t> ostrides{out.strides().begin(), out.strides().end()};
return copy_gpu_inplace(
in, out, in.shape(), istride, ostrides, ioffset, 0, ctype, s);
in, out, in.shape(), istride, out.strides(), ioffset, 0, ctype, s);
}

void fill_gpu(const array& val, array& out, const Stream& s) {

@ -8,13 +8,12 @@
namespace mlx::core {

// Generic copy inplace
template <typename stride_t>
void copy_gpu_inplace(
const array& in,
array& out,
const std::vector<int>& data_shape,
const std::vector<stride_t>& i_strides,
const std::vector<stride_t>& o_strides,
const Shape& data_shape,
const Strides& i_strides,
const Strides& o_strides,
int64_t i_offset,
int64_t o_offset,
CopyType ctype,
@ -32,7 +31,7 @@ void copy_gpu_inplace(
void copy_gpu_inplace(
const array& in,
array& out,
const std::vector<int64_t>& istride,
const Strides& istride,
int64_t ioffset,
CopyType ctype,
const Stream& s);

@ -363,7 +363,7 @@ void multi_upload_bluestein_fft(
auto [w_k, w_q] = compute_bluestein_constants(n, plan.bluestein_n);

// Broadcast w_q and w_k to the batch size
std::vector<size_t> b_strides(in.ndim(), 0);
Strides b_strides(in.ndim(), 0);
b_strides[axis] = 1;
array w_k_broadcast({}, complex64, nullptr, {});
array w_q_broadcast({}, complex64, nullptr, {});
@ -386,8 +386,8 @@ void multi_upload_bluestein_fft(
copies.push_back(slice_temp);
copies.push_back(conj_temp);

std::vector<int> rstarts(in.ndim(), 0);
std::vector<int> rstrides(in.ndim(), 1);
Shape rstarts(in.ndim(), 0);
Shape rstrides(in.ndim(), 1);
rstarts[axis] = in.shape(axis) - back_offset;
rstrides[axis] = -1;
unary_op_gpu({in}, conj_temp, "Conjugate", s);
@ -431,19 +431,19 @@ void multi_upload_bluestein_fft(
s);

int offset = plan.bluestein_n - (2 * n - 1);
std::vector<int> starts(in.ndim(), 0);
std::vector<int> strides(in.ndim(), 1);
Shape starts(in.ndim(), 0);
Shape strides(in.ndim(), 1);
starts[axis] = plan.bluestein_n - offset - n;
slice_gpu(pad_temp1, temp, starts, strides, s);

binary_op_gpu_inplace({temp, w_k_broadcast}, temp1, "Multiply", s);

if (real && !inverse) {
std::vector<int> rstarts(in.ndim(), 0);
std::vector<int> rstrides(in.ndim(), 1);
Shape rstarts(in.ndim(), 0);
Shape rstrides(in.ndim(), 1);
slice_gpu(temp1, out, rstarts, strides, s);
} else if (real && inverse) {
std::vector<size_t> b_strides(in.ndim(), 0);
Strides b_strides(in.ndim(), 0);
auto inv_n = array({1.0f / n}, {1}, float32);
array temp_float(out.shape(), out.dtype(), nullptr, {});
copies.push_back(temp_float);
@ -531,8 +531,8 @@ void fft_op(
return x;
} else {
array x_copy(x.shape(), x.dtype(), nullptr, {});
std::vector<size_t> strides;
size_t cur_stride = x.shape(axis);
Strides strides;
int64_t cur_stride = x.shape(axis);
for (int a = 0; a < x.ndim(); a++) {
if (a == axis) {
strides.push_back(1);
@ -777,7 +777,7 @@ void nd_fft_op(
// Mirror np.fft.(i)rfftn and perform a real transform
// only on the final axis.
bool step_real = (real && index == axes.size() - 1);
int step_shape = inverse ? out.shape(axis) : in.shape(axis);
auto step_shape = inverse ? out.shape(axis) : in.shape(axis);
const array& in_arr = i == axes.size() - 1 ? in : temp_arrs[1 - i % 2];
array& out_arr = i == 0 ? out : temp_arrs[i % 2];
fft_op(in_arr, out_arr, axis, inverse, step_real, inplace, s);

@ -65,7 +65,7 @@ void Gather::eval_gpu(const std::vector<array>& inputs, array& out) {
idx_type_name,
nidx,
idx_ndim,
large ? "size_t" : "uint");
large ? "int64_t" : "uint");
std::string lib_name = kernel_name;

auto lib = d.get_library(lib_name, [&]() {
@ -86,7 +86,7 @@ void Gather::eval_gpu(const std::vector<array>& inputs, array& out) {
idx_args,
idx_arr,
idx_ndim,
large ? "size_t" : "uint");
large ? "int64_t" : "uint");
return kernel_source;
});

@ -246,7 +246,7 @@ void Scatter::eval_gpu(const std::vector<array>& inputs, array& out) {
nidx,
upd_contig ? "updc_true" : "updc_false",
nwork,
large ? "size_t" : "uint");
large ? "int64_t" : "uint");
std::string lib_name = kernel_name;

auto lib = d.get_library(lib_name, [&]() {
@ -290,7 +290,7 @@ void Scatter::eval_gpu(const std::vector<array>& inputs, array& out) {
idx_arr,
upd_contig,
nwork,
large ? "size_t" : "uint");
large ? "int64_t" : "uint");
return kernel_source;
});

@ -312,8 +312,8 @@ void Scatter::eval_gpu(const std::vector<array>& inputs, array& out) {
upd_size *= upd.shape(i);
}
// Collect all idx shapes and strides into one place
std::vector<int> idx_shapes;
std::vector<size_t> idx_strides;
Shape idx_shapes;
Strides idx_strides;
// To access .data() use char instead of bool
// bool is 1 byte in Metal so this is safe
std::vector<char> idx_contigs;
@ -332,7 +332,7 @@ void Scatter::eval_gpu(const std::vector<array>& inputs, array& out) {
if (upd_ndim == 0) {
// Need placeholders so Metal doesn't compalain
int shape_ = 0;
size_t stride_ = 0;
int64_t stride_ = 0;
compute_encoder.set_bytes(shape_, 3);
compute_encoder.set_bytes(stride_, 4);
} else {
@ -347,7 +347,7 @@ void Scatter::eval_gpu(const std::vector<array>& inputs, array& out) {
if (out_ndim == 0) {
// Need placeholders so Metal doesn't compalain
int shape_ = 0;
size_t stride_ = 0;
int64_t stride_ = 0;
compute_encoder.set_bytes(shape_, 7);
compute_encoder.set_bytes(stride_, 8);
} else {

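The `large ? "int64_t" : "uint"` selections above decide which index type gets baked into the JIT-compiled kernel source. A hedged sketch of that dispatch, with `index_type_for` as a hypothetical helper (the real `large` flag is computed at the call sites from the array sizes):

.. code-block:: C++

   #include <cstdint>
   #include <limits>
   #include <string>

   // Hypothetical helper: use 64-bit kernel indices only when the array
   // cannot be addressed with 32 bits; small arrays keep cheap uint math.
   inline std::string index_type_for(size_t nelem) {
     bool large = nelem > static_cast<size_t>(std::numeric_limits<int32_t>::max());
     return large ? "int64_t" : "uint";
   }
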
@ -11,13 +11,13 @@ gemv_{trans}masked<{itype}, {outm_t}, {opm_t}, {bm}, {bn}, {sm}, {sn}, {tm}, {tn
const constant int& marix_ld [[buffer(6)]],
const constant int& batch_ndim [[buffer(9)]],
const constant int* batch_shape [[buffer(10)]],
const constant size_t* vector_batch_stride [[buffer(11)]],
const constant size_t* matrix_batch_stride [[buffer(12)]],
const constant int64_t* vector_batch_stride [[buffer(11)]],
const constant int64_t* matrix_batch_stride [[buffer(12)]],
const device {outm_t}* out_mask [[buffer(20)]],
const device {opm_t}* mat_mask [[buffer(21)]],
const device {opm_t}* vec_mask [[buffer(22)]],
const constant int* mask_strides [[buffer(23)]],
const constant size_t* mask_batch_strides [[buffer(24)]],
const constant int64_t* mask_batch_strides [[buffer(24)]],
uint3 tid [[threadgroup_position_in_grid]],
uint3 lid [[thread_position_in_threadgroup]],
uint simd_gid [[simdgroup_index_in_threadgroup]],

@ -5,12 +5,12 @@ constexpr std::string_view gather_kernels = R"(
const device {1}* src [[buffer(0)]],
device {1}* out [[buffer(1)]],
const constant int* src_shape [[buffer(2)]],
const constant size_t* src_strides [[buffer(3)]],
const constant int64_t* src_strides [[buffer(3)]],
const constant size_t& src_ndim [[buffer(4)]],
const constant int* slice_sizes [[buffer(5)]],
const constant int* axes [[buffer(6)]],
const constant int* idx_shapes [[buffer(7)]],
const constant size_t* idx_strides [[buffer(8)]],
const constant int64_t* idx_strides [[buffer(8)]],
const constant bool* idx_contigs [[buffer(9)]],
const constant int& idx_ndim [[buffer(10)]],
{4}
@ -38,15 +38,15 @@ constexpr std::string_view scatter_kernels = R"(
const device {1}* updates [[buffer(1)]],
device mlx_atomic<{1}>* out [[buffer(2)]],
const constant int* upd_shape [[buffer(3)]],
const constant size_t* upd_strides [[buffer(4)]],
const constant int64_t* upd_strides [[buffer(4)]],
const constant size_t& upd_ndim [[buffer(5)]],
const constant size_t& upd_size [[buffer(6)]],
const constant int* out_shape [[buffer(7)]],
const constant size_t* out_strides [[buffer(8)]],
const constant int64_t* out_strides [[buffer(8)]],
const constant size_t& out_ndim [[buffer(9)]],
const constant int* axes [[buffer(10)]],
const constant int* idx_shapes [[buffer(11)]],
const constant size_t* idx_strides [[buffer(12)]],
const constant int64_t* idx_strides [[buffer(12)]],
const constant bool* idx_contigs [[buffer(13)]],
const constant int& idx_ndim [[buffer(14)]],
const constant size_t& idx_size [[buffer(15)]],

@ -10,12 +10,12 @@ template [[host_name("{name}")]]
const constant GEMMParams* params [[buffer(4)]],
const constant GEMMAddMMParams* addmm_params [[buffer(5), function_constant(use_out_source)]],
const constant int* batch_shape [[buffer(6)]],
const constant size_t* batch_strides [[buffer(7)]],
const constant int64_t* batch_strides [[buffer(7)]],
const constant uint32_t* lhs_indices [[buffer(10), function_constant(do_gather)]],
const constant uint32_t* rhs_indices [[buffer(11), function_constant(do_gather)]],
const constant uint32_t* C_indices [[buffer(12), function_constant(gather_bias)]],
const constant int* operand_shape [[buffer(13), function_constant(do_gather)]],
const constant size_t* operand_strides [[buffer(14), function_constant(do_gather)]],
const constant int64_t* operand_strides [[buffer(14), function_constant(do_gather)]],
const constant packed_int3& operand_batch_ndim [[buffer(15), function_constant(do_gather)]],
uint simd_lane_id [[thread_index_in_simdgroup]],
uint simd_group_id [[simdgroup_index_in_threadgroup]],
@ -43,7 +43,7 @@ block_masked_gemm<
device {itype}* D [[buffer(3)]],
const constant GEMMParams* params [[buffer(4)]],
const constant int* batch_shape [[buffer(6)]],
const constant size_t* batch_strides [[buffer(7)]],
const constant int64_t* batch_strides [[buffer(7)]],
const device {outmasktype}* out_mask [[buffer(10)]],
const device {opmasktype}* lhs_mask [[buffer(11)]],
const device {opmasktype}* rhs_mask [[buffer(12)]],

@ -75,10 +75,10 @@ template <typename T, typename Op, int N_READS = 4>
const device T* in [[buffer(0)]],
device uint32_t* out [[buffer(1)]],
const constant int* shape [[buffer(2)]],
const constant size_t* in_strides [[buffer(3)]],
const constant size_t* out_strides [[buffer(4)]],
const constant int64_t* in_strides [[buffer(3)]],
const constant int64_t* out_strides [[buffer(4)]],
const constant size_t& ndim [[buffer(5)]],
const constant size_t& axis_stride [[buffer(6)]],
const constant int64_t& axis_stride [[buffer(6)]],
const constant size_t& axis_size [[buffer(7)]],
uint gid [[thread_position_in_grid]],
uint lid [[thread_position_in_threadgroup]],

@ -43,7 +43,7 @@ template <typename T, typename U, typename Op>
device U* c,
uint2 index [[thread_position_in_grid]],
uint2 grid_dim [[threads_per_grid]]) {
size_t offset = index.x + grid_dim.x * size_t(index.y);
int64_t offset = index.x + grid_dim.x * int64_t(index.y);
c[offset] = Op()(a[0], b[offset]);
}

@ -54,7 +54,7 @@ template <typename T, typename U, typename Op>
device U* c,
uint2 index [[thread_position_in_grid]],
uint2 grid_dim [[threads_per_grid]]) {
size_t offset = index.x + grid_dim.x * size_t(index.y);
int64_t offset = index.x + grid_dim.x * int64_t(index.y);
c[offset] = Op()(a[offset], b[0]);
}

@ -65,49 +65,49 @@ template <typename T, typename U, typename Op>
device U* c,
uint2 index [[thread_position_in_grid]],
uint2 grid_dim [[threads_per_grid]]) {
size_t offset = index.x + grid_dim.x * size_t(index.y);
int64_t offset = index.x + grid_dim.x * int64_t(index.y);
c[offset] = Op()(a[offset], b[offset]);
}

template <typename T, typename U, typename Op, typename IdxT = size_t>
template <typename T, typename U, typename Op, typename IdxT = int64_t>
[[kernel]] void binary_g_nd1(
device const T* a,
device const T* b,
device U* c,
constant const size_t& a_stride,
constant const size_t& b_stride,
constant const int64_t& a_stride,
constant const int64_t& b_stride,
uint index [[thread_position_in_grid]]) {
auto a_idx = elem_to_loc_1<size_t, IdxT>(index, a_stride);
auto b_idx = elem_to_loc_1<size_t, IdxT>(index, b_stride);
auto a_idx = elem_to_loc_1<IdxT>(index, a_stride);
auto b_idx = elem_to_loc_1<IdxT>(index, b_stride);
c[index] = Op()(a[a_idx], b[b_idx]);
}

template <typename T, typename U, typename Op, typename IdxT = size_t>
template <typename T, typename U, typename Op, typename IdxT = int64_t>
[[kernel]] void binary_g_nd2(
device const T* a,
device const T* b,
device U* c,
constant const size_t a_strides[2],
constant const size_t b_strides[2],
constant const int64_t a_strides[2],
constant const int64_t b_strides[2],
uint2 index [[thread_position_in_grid]],
uint2 grid_dim [[threads_per_grid]]) {
auto a_idx = elem_to_loc_2<size_t, IdxT>(index, a_strides);
auto b_idx = elem_to_loc_2<size_t, IdxT>(index, b_strides);
auto a_idx = elem_to_loc_2<IdxT>(index, a_strides);
auto b_idx = elem_to_loc_2<IdxT>(index, b_strides);
IdxT out_idx = index.x + IdxT(grid_dim.x) * index.y;
c[out_idx] = Op()(a[a_idx], b[b_idx]);
}

template <typename T, typename U, typename Op, typename IdxT = size_t>
template <typename T, typename U, typename Op, typename IdxT = int64_t>
[[kernel]] void binary_g_nd3(
device const T* a,
device const T* b,
device U* c,
constant const size_t a_strides[3],
constant const size_t b_strides[3],
constant const int64_t a_strides[3],
constant const int64_t b_strides[3],
uint3 index [[thread_position_in_grid]],
uint3 grid_dim [[threads_per_grid]]) {
auto a_idx = elem_to_loc_3<size_t, IdxT>(index, a_strides);
auto b_idx = elem_to_loc_3<size_t, IdxT>(index, b_strides);
auto a_idx = elem_to_loc_3<IdxT>(index, a_strides);
auto b_idx = elem_to_loc_3<IdxT>(index, b_strides);
IdxT out_idx = index.x + grid_dim.x * (index.y + IdxT(grid_dim.y) * index.z);
c[out_idx] = Op()(a[a_idx], b[b_idx]);
}
@ -117,18 +117,18 @@ template <
typename U,
typename Op,
int N = 1,
typename IdxT = size_t>
typename IdxT = int64_t>
[[kernel]] void binary_g(
device const T* a,
device const T* b,
device U* c,
constant const int* shape,
constant const size_t* a_strides,
constant const size_t* b_strides,
constant const int64_t* a_strides,
constant const int64_t* b_strides,
constant const int& ndim,
uint3 index [[thread_position_in_grid]],
uint3 grid_dim [[threads_per_grid]]) {
auto idx = elem_to_loc_2_nd<size_t, IdxT>(
auto idx = elem_to_loc_2_nd<IdxT>(
{N * index.x, index.y, index.z}, shape, a_strides, b_strides, ndim);
auto xshape = shape[ndim - 1];
IdxT out_idx = N * index.x + xshape * (index.y + IdxT(grid_dim.y) * index.z);

@ -56,7 +56,7 @@ template <typename T, typename U, typename Op>
device U* d,
uint2 index [[thread_position_in_grid]],
uint2 grid_dim [[threads_per_grid]]) {
size_t offset = index.x + grid_dim.x * size_t(index.y);
auto offset = index.x + grid_dim.x * int64_t(index.y);
auto out = Op()(a[0], b[offset]);
c[offset] = out[0];
d[offset] = out[1];
@ -70,7 +70,7 @@ template <typename T, typename U, typename Op>
device U* d,
uint2 index [[thread_position_in_grid]],
uint2 grid_dim [[threads_per_grid]]) {
size_t offset = index.x + grid_dim.x * size_t(index.y);
auto offset = index.x + grid_dim.x * int64_t(index.y);
auto out = Op()(a[offset], b[0]);
c[offset] = out[0];
d[offset] = out[1];
@ -84,58 +84,58 @@ template <typename T, typename U, typename Op>
device U* d,
uint2 index [[thread_position_in_grid]],
uint2 grid_dim [[threads_per_grid]]) {
size_t offset = index.x + grid_dim.x * size_t(index.y);
auto offset = index.x + grid_dim.x * int64_t(index.y);
auto out = Op()(a[offset], b[offset]);
c[offset] = out[0];
d[offset] = out[1];
}

template <typename T, typename U, typename Op, typename IdxT = size_t>
template <typename T, typename U, typename Op, typename IdxT = int64_t>
[[kernel]] void binary_g_nd1(
device const T* a,
device const T* b,
device U* c,
device U* d,
constant const size_t& a_stride,
constant const size_t& b_stride,
constant const int64_t& a_stride,
constant const int64_t& b_stride,
uint index [[thread_position_in_grid]]) {
auto a_idx = elem_to_loc_1<size_t, IdxT>(index, a_stride);
auto b_idx = elem_to_loc_1<size_t, IdxT>(index, b_stride);
auto a_idx = elem_to_loc_1<IdxT>(index, a_stride);
auto b_idx = elem_to_loc_1<IdxT>(index, b_stride);
auto out = Op()(a[a_idx], b[b_idx]);
c[index] = out[0];
d[index] = out[1];
}

template <typename T, typename U, typename Op, typename IdxT = size_t>
template <typename T, typename U, typename Op, typename IdxT = int64_t>
[[kernel]] void binary_g_nd2(
device const T* a,
device const T* b,
device U* c,
device U* d,
constant const size_t a_strides[2],
constant const size_t b_strides[2],
constant const int64_t a_strides[2],
constant const int64_t b_strides[2],
uint2 index [[thread_position_in_grid]],
uint2 grid_dim [[threads_per_grid]]) {
auto a_idx = elem_to_loc_2<size_t, IdxT>(index, a_strides);
auto b_idx = elem_to_loc_2<size_t, IdxT>(index, b_strides);
auto a_idx = elem_to_loc_2<IdxT>(index, a_strides);
auto b_idx = elem_to_loc_2<IdxT>(index, b_strides);
IdxT out_idx = index.x + IdxT(grid_dim.x) * index.y;
auto out = Op()(a[a_idx], b[b_idx]);
c[out_idx] = out[0];
d[out_idx] = out[1];
}

template <typename T, typename U, typename Op, typename IdxT = size_t>
template <typename T, typename U, typename Op, typename IdxT = int64_t>
[[kernel]] void binary_g_nd3(
device const T* a,
device const T* b,
device U* c,
device U* d,
constant const size_t a_strides[3],
constant const size_t b_strides[3],
constant const int64_t a_strides[3],
constant const int64_t b_strides[3],
uint3 index [[thread_position_in_grid]],
uint3 grid_dim [[threads_per_grid]]) {
auto a_idx = elem_to_loc_3<size_t, IdxT>(index, a_strides);
auto b_idx = elem_to_loc_3<size_t, IdxT>(index, b_strides);
auto a_idx = elem_to_loc_3<IdxT>(index, a_strides);
auto b_idx = elem_to_loc_3<IdxT>(index, b_strides);
IdxT out_idx = index.x + grid_dim.x * (index.y + IdxT(grid_dim.y) * index.z);
auto out = Op()(a[a_idx], b[b_idx]);
c[out_idx] = out[0];
@ -147,19 +147,19 @@ template <
typename U,
typename Op,
int N = 1,
typename IdxT = size_t>
typename IdxT = int64_t>
[[kernel]] void binary_g(
device const T* a,
device const T* b,
device U* c,
device U* d,
constant const int* shape,
constant const size_t* a_strides,
constant const size_t* b_strides,
constant const int64_t* a_strides,
constant const int64_t* b_strides,
constant const int& ndim,
uint3 index [[thread_position_in_grid]],
uint3 grid_dim [[threads_per_grid]]) {
auto idx = elem_to_loc_2_nd<size_t, IdxT>(
auto idx = elem_to_loc_2_nd<IdxT>(
{N * index.x, index.y, index.z}, shape, a_strides, b_strides, ndim);
auto xshape = shape[ndim - 1];
IdxT out_idx = N * index.x + xshape * (index.y + IdxT(grid_dim.y) * index.z);

@ -22,7 +22,7 @@ template <typename T, typename U>
device U* dst [[buffer(1)]],
uint2 index [[thread_position_in_grid]],
uint2 grid_dim [[threads_per_grid]]) {
size_t offset = index.x + grid_dim.x * size_t(index.y);
auto offset = index.x + grid_dim.x * int64_t(index.y);
dst[offset] = static_cast<U>(src[0]);
}

@ -32,7 +32,7 @@ template <typename T, typename U>
device U* dst [[buffer(1)]],
uint2 index [[thread_position_in_grid]],
uint2 grid_dim [[threads_per_grid]]) {
size_t offset = index.x + grid_dim.x * size_t(index.y);
auto offset = index.x + grid_dim.x * int64_t(index.y);
dst[offset] = static_cast<U>(src[offset]);
}

@ -42,7 +42,7 @@ template <typename T, typename U, typename IdxT = int64_t>
device U* dst [[buffer(1)]],
constant const int64_t& src_stride [[buffer(3)]],
uint index [[thread_position_in_grid]]) {
auto src_idx = elem_to_loc_1<int64_t, IdxT>(index, src_stride);
auto src_idx = elem_to_loc_1<IdxT>(index, src_stride);
dst[index] = static_cast<U>(src[src_idx]);
}

@ -53,7 +53,7 @@ template <typename T, typename U, typename IdxT = int64_t>
constant const int64_t* src_strides [[buffer(3)]],
uint2 index [[thread_position_in_grid]],
uint2 grid_dim [[threads_per_grid]]) {
auto src_idx = elem_to_loc_2<int64_t, IdxT>(index, src_strides);
auto src_idx = elem_to_loc_2<IdxT>(index, src_strides);
IdxT dst_idx = index.x + IdxT(grid_dim.x) * index.y;
dst[dst_idx] = static_cast<U>(src[src_idx]);
}
@ -65,7 +65,7 @@ template <typename T, typename U, typename IdxT = int64_t>
constant const int64_t* src_strides [[buffer(3)]],
uint3 index [[thread_position_in_grid]],
uint3 grid_dim [[threads_per_grid]]) {
auto src_idx = elem_to_loc_3<int64_t, IdxT>(index, src_strides);
auto src_idx = elem_to_loc_3<IdxT>(index, src_strides);
IdxT dst_idx =
index.x + IdxT(grid_dim.x) * (index.y + IdxT(grid_dim.y) * index.z);
dst[dst_idx] = static_cast<U>(src[src_idx]);
@ -80,7 +80,7 @@ template <typename T, typename U, int N = 1, typename IdxT = int64_t>
constant const int& ndim [[buffer(5)]],
uint3 index [[thread_position_in_grid]],
uint3 grid_dim [[threads_per_grid]]) {
auto src_idx = elem_to_loc<int64_t, IdxT>(
auto src_idx = elem_to_loc<IdxT>(
{N * index.x, index.y, index.z}, src_shape, src_strides, ndim);
if (N == 1) {
IdxT dst_idx =
@ -104,8 +104,8 @@ template <typename T, typename U, typename IdxT = int64_t>
constant const int64_t& src_stride [[buffer(3)]],
constant const int64_t& dst_stride [[buffer(4)]],
uint index [[thread_position_in_grid]]) {
auto src_idx = elem_to_loc_1<int64_t, IdxT>(index, src_stride);
auto dst_idx = elem_to_loc_1<int64_t, IdxT>(index, dst_stride);
auto src_idx = elem_to_loc_1<IdxT>(index, src_stride);
auto dst_idx = elem_to_loc_1<IdxT>(index, dst_stride);
dst[dst_idx] = static_cast<U>(src[src_idx]);
}

@ -116,8 +116,8 @@ template <typename T, typename U, typename IdxT = int64_t>
constant const int64_t* src_strides [[buffer(3)]],
constant const int64_t* dst_strides [[buffer(4)]],
uint2 index [[thread_position_in_grid]]) {
auto src_idx = elem_to_loc_2<int64_t, IdxT>(index, src_strides);
auto dst_idx = elem_to_loc_2<int64_t, IdxT>(index, dst_strides);
auto src_idx = elem_to_loc_2<IdxT>(index, src_strides);
auto dst_idx = elem_to_loc_2<IdxT>(index, dst_strides);
dst[dst_idx] = static_cast<U>(src[src_idx]);
}

@ -128,8 +128,8 @@ template <typename T, typename U, typename IdxT = int64_t>
constant const int64_t* src_strides [[buffer(3)]],
constant const int64_t* dst_strides [[buffer(4)]],
uint3 index [[thread_position_in_grid]]) {
auto src_idx = elem_to_loc_3<int64_t, IdxT>(index, src_strides);
auto dst_idx = elem_to_loc_3<int64_t, IdxT>(index, dst_strides);
auto src_idx = elem_to_loc_3<IdxT>(index, src_strides);
auto dst_idx = elem_to_loc_3<IdxT>(index, dst_strides);
dst[dst_idx] = static_cast<U>(src[src_idx]);
}

@ -142,7 +142,7 @@ template <typename T, typename U, int N = 1, typename IdxT = int64_t>
constant const int64_t* dst_strides [[buffer(4)]],
constant const int& ndim [[buffer(5)]],
uint3 index [[thread_position_in_grid]]) {
auto idx = elem_to_loc_2_nd<int64_t, IdxT>(
auto idx = elem_to_loc_2_nd<IdxT>(
{N * index.x, index.y, index.z},
src_shape,
src_strides,

@ -9,7 +9,7 @@ METAL_FUNC void gather_impl(
const device T* src [[buffer(0)]],
device T* out [[buffer(1)]],
const constant int* src_shape [[buffer(2)]],
const constant size_t* src_strides [[buffer(3)]],
const constant int64_t* src_strides [[buffer(3)]],
const constant size_t& src_ndim [[buffer(4)]],
const constant int* slice_sizes [[buffer(5)]],
const constant int* axes [[buffer(6)]],
@ -27,7 +27,7 @@ METAL_FUNC void gather_impl(
idx_loc = index.x * static_cast<LocT>(indices.strides[indices.ndim * i]);
idx_loc += indices.row_contiguous[i]
? index.y
: elem_to_loc<size_t, LocT>(
: elem_to_loc<LocT>(
index.y,
&indices.shapes[indices.ndim * i + 1],
&indices.strides[indices.ndim * i + 1],
@ -39,7 +39,7 @@ METAL_FUNC void gather_impl(
}

auto src_offset =
elem_to_loc<size_t, LocT>(index.z, slice_sizes, src_strides, src_ndim);
elem_to_loc<LocT>(index.z, slice_sizes, src_strides, src_ndim);

LocT out_idx = index.z;
if (IDX_NDIM == 1) {

@ -436,9 +436,9 @@ template <
const constant float& beta [[buffer(8)]],
const constant int& batch_ndim [[buffer(9)]],
const constant int* batch_shape [[buffer(10)]],
const constant size_t* vector_batch_stride [[buffer(11)]],
const constant size_t* matrix_batch_stride [[buffer(12)]],
const constant size_t* bias_batch_stride [[buffer(13)]],
const constant int64_t* vector_batch_stride [[buffer(11)]],
const constant int64_t* matrix_batch_stride [[buffer(12)]],
const constant int64_t* bias_batch_stride [[buffer(13)]],
const constant int& bias_stride [[buffer(14)]],
uint3 tid [[threadgroup_position_in_grid]],
uint3 lid [[thread_position_in_threadgroup]],
@ -486,31 +486,21 @@ template <
simd_lid);
}

#define instantiate_gemv_helper( \
name, itype, bm, bn, sm, sn, tm, tn, nc, axpby) \
template [[host_name("gemv_" #name "_bm" #bm "_bn" #bn "_sm" #sm "_sn" #sn \
"_tm" #tm "_tn" #tn "_nc" #nc \
"_axpby" #axpby)]] [[kernel]] void \
gemv<itype, bm, bn, sm, sn, tm, tn, nc, axpby>( \
const device itype* mat [[buffer(0)]], \
const device itype* in_vec [[buffer(1)]], \
const device itype* bias [[buffer(2)]], \
device itype* out_vec [[buffer(3)]], \
const constant int& in_vec_size [[buffer(4)]], \
const constant int& out_vec_size [[buffer(5)]], \
const constant int& marix_ld [[buffer(6)]], \
const constant float& alpha [[buffer(7)]], \
const constant float& beta [[buffer(8)]], \
const constant int& batch_ndim [[buffer(9)]], \
const constant int* batch_shape [[buffer(10)]], \
const constant size_t* vector_batch_stride [[buffer(11)]], \
const constant size_t* matrix_batch_stride [[buffer(12)]], \
const constant size_t* bias_batch_stride [[buffer(13)]], \
const constant int& bias_stride [[buffer(14)]], \
uint3 tid [[threadgroup_position_in_grid]], \
uint3 lid [[thread_position_in_threadgroup]], \
uint simd_gid [[simdgroup_index_in_threadgroup]], \
uint simd_lid [[thread_index_in_simdgroup]]);
#define instantiate_gemv_helper( \
name, itype, bm, bn, sm, sn, tm, tn, nc, axpby) \
instantiate_kernel( \
"gemv_" #name "_bm" #bm "_bn" #bn "_sm" #sm "_sn" #sn "_tm" #tm \
"_tn" #tn "_nc" #nc "_axpby" #axpby, \
gemv, \
itype, \
bm, \
bn, \
sm, \
sn, \
tm, \
tn, \
nc, \
axpby)

// clang-format off
#define instantiate_gemv(name, itype, bm, bn, tm, tn) \
@ -549,13 +539,13 @@ template <
const constant float& beta [[buffer(8)]],
const constant int& batch_ndim [[buffer(9)]],
const constant int* batch_shape [[buffer(10)]],
const constant size_t* index_batch_strides [[buffer(11)]],
const constant int64_t* index_batch_strides [[buffer(11)]],
const constant int& vector_batch_ndim [[buffer(12)]],
const constant int* vector_batch_shape [[buffer(13)]],
const constant size_t* vector_batch_stride [[buffer(14)]],
const constant int64_t* vector_batch_stride [[buffer(14)]],
const constant int& matrix_batch_ndim [[buffer(15)]],
const constant int* matrix_batch_shape [[buffer(16)]],
const constant size_t* matrix_batch_stride [[buffer(17)]],
const constant int64_t* matrix_batch_stride [[buffer(17)]],
const constant uint32_t* vec_indices [[buffer(18)]],
const constant uint32_t* mat_indices [[buffer(19)]],
uint3 tid [[threadgroup_position_in_grid]],
@ -571,8 +561,8 @@ template <

// Update batch offsets
if (batch_ndim > 1) {
const constant size_t* veci_bstrides = index_batch_strides;
const constant size_t* mati_bstrides = index_batch_strides + batch_ndim;
const constant auto* veci_bstrides = index_batch_strides;
const constant auto* mati_bstrides = index_batch_strides + batch_ndim;

ulong2 batch_offsets = elem_to_loc_broadcast(
tid.z, batch_shape, veci_bstrides, mati_bstrides, batch_ndim);
@ -619,37 +609,14 @@ template <
simd_lid);
}

#define instantiate_gemv_bs_helper(nm, itype, bm, bn, sm, sn, tm, tn) \
template [[host_name("gemv_gather_" #nm "_bm" #bm "_bn" #bn "_sm" #sm \
"_sn" #sn "_tm" #tm "_tn" #tn)]] [[kernel]] void \
gemv_gather<itype, bm, bn, sm, sn, tm, tn>( \
const device itype* mat [[buffer(0)]], \
const device itype* in_vec [[buffer(1)]], \
const device itype* bias [[buffer(2)]], \
device itype* out_vec [[buffer(3)]], \
const constant int& in_vec_size [[buffer(4)]], \
const constant int& out_vec_size [[buffer(5)]], \
const constant int& marix_ld [[buffer(6)]], \
const constant float& alpha [[buffer(7)]], \
const constant float& beta [[buffer(8)]], \
const constant int& batch_ndim [[buffer(9)]], \
const constant int* batch_shape [[buffer(10)]], \
const constant size_t* index_batch_strides [[buffer(11)]], \
const constant int& vector_batch_ndim [[buffer(12)]], \
const constant int* vector_batch_shape [[buffer(13)]], \
const constant size_t* vector_batch_stride [[buffer(14)]], \
const constant int& matrix_batch_ndim [[buffer(15)]], \
const constant int* matrix_batch_shape [[buffer(16)]], \
const constant size_t* matrix_batch_stride [[buffer(17)]], \
const constant uint32_t* vec_indices [[buffer(18)]], \
const constant uint32_t* mat_indices [[buffer(19)]], \
uint3 tid [[threadgroup_position_in_grid]], \
uint3 lid [[thread_position_in_threadgroup]], \
uint simd_gid [[simdgroup_index_in_threadgroup]], \
uint simd_lid [[thread_index_in_simdgroup]]);

// clang-format off
#define instantiate_gemv_bs_blocks(name, itype) \
#define instantiate_gemv_bs_helper(nm, itype, bm, bn, sm, sn, tm, tn) \
instantiate_kernel( \
"gemv_gather_" #nm "_bm" #bm "_bn" #bn "_sm" #sm \
"_sn" #sn "_tm" #tm "_tn" #tn, \
gemv_gather, itype, bm, bn, sm, sn, tm, tn)

#define instantiate_gemv_bs_blocks(name, itype) \
instantiate_gemv_bs_helper(name, itype, 4, 1, 1, 32, 1, 4) \
instantiate_gemv_bs_helper(name, itype, 4, 1, 1, 32, 4, 4) \
instantiate_gemv_bs_helper(name, itype, 8, 1, 1, 32, 4, 4) // clang-format on
@ -684,9 +651,9 @@ template <
const constant float& beta [[buffer(8)]],
const constant int& batch_ndim [[buffer(9)]],
const constant int* batch_shape [[buffer(10)]],
const constant size_t* vector_batch_stride [[buffer(11)]],
const constant size_t* matrix_batch_stride [[buffer(12)]],
const constant size_t* bias_batch_stride [[buffer(13)]],
const constant int64_t* vector_batch_stride [[buffer(11)]],
const constant int64_t* matrix_batch_stride [[buffer(12)]],
const constant int64_t* bias_batch_stride [[buffer(13)]],
const constant int& bias_stride [[buffer(14)]],
uint3 tid [[threadgroup_position_in_grid]],
uint3 lid [[thread_position_in_threadgroup]],
@ -734,33 +701,14 @@ template <
simd_lid);
}

#define instantiate_gemv_t_helper( \
name, itype, bm, bn, sm, sn, tm, tn, nc, axpby) \
template [[host_name("gemv_t_" #name "_bm" #bm "_bn" #bn "_sm" #sm "_sn" #sn \
"_tm" #tm "_tn" #tn "_nc" #nc \
"_axpby" #axpby)]] [[kernel]] void \
gemv_t<itype, bm, bn, sm, sn, tm, tn, nc, axpby>( \
const device itype* mat [[buffer(0)]], \
const device itype* in_vec [[buffer(1)]], \
const device itype* bias [[buffer(2)]], \
device itype* out_vec [[buffer(3)]], \
const constant int& in_vec_size [[buffer(4)]], \
const constant int& out_vec_size [[buffer(5)]], \
const constant int& marix_ld [[buffer(6)]], \
const constant float& alpha [[buffer(7)]], \
const constant float& beta [[buffer(8)]], \
const constant int& batch_ndim [[buffer(9)]], \
const constant int* batch_shape [[buffer(10)]], \
const constant size_t* vector_batch_stride [[buffer(11)]], \
const constant size_t* matrix_batch_stride [[buffer(12)]], \
const constant size_t* bias_batch_stride [[buffer(13)]], \
const constant int& bias_stride [[buffer(14)]], \
uint3 tid [[threadgroup_position_in_grid]], \
uint3 lid [[thread_position_in_threadgroup]], \
uint simd_gid [[simdgroup_index_in_threadgroup]], \
uint simd_lid [[thread_index_in_simdgroup]]);

// clang-format off
#define instantiate_gemv_t_helper( \
name, itype, bm, bn, sm, sn, tm, tn, nc, axpby) \
instantiate_kernel( \
"gemv_t_" #name "_bm" #bm "_bn" #bn "_sm" #sm "_sn" #sn \
"_tm" #tm "_tn" #tn "_nc" #nc "_axpby" #axpby, \
gemv_t, itype, bm, bn, sm, sn, tm, tn, nc, axpby)

#define instantiate_gemv_t(name, itype, bm, bn, sm, sn, tm, tn) \
instantiate_gemv_t_helper(name, itype, bm, bn, sm, sn, tm, tn, 0, 0) \
instantiate_gemv_t_helper(name, itype, bm, bn, sm, sn, tm, tn, 0, 1) \
@ -800,13 +748,13 @@ template <
const constant float& beta [[buffer(8)]],
const constant int& batch_ndim [[buffer(9)]],
const constant int* batch_shape [[buffer(10)]],
const constant size_t* index_batch_strides [[buffer(11)]],
const constant int64_t* index_batch_strides [[buffer(11)]],
const constant int& vector_batch_ndim [[buffer(12)]],
const constant int* vector_batch_shape [[buffer(13)]],
const constant size_t* vector_batch_stride [[buffer(14)]],
const constant int64_t* vector_batch_stride [[buffer(14)]],
const constant int& matrix_batch_ndim [[buffer(15)]],
const constant int* matrix_batch_shape [[buffer(16)]],
const constant size_t* matrix_batch_stride [[buffer(17)]],
const constant int64_t* matrix_batch_stride [[buffer(17)]],
const constant uint32_t* vec_indices [[buffer(18)]],
const constant uint32_t* mat_indices [[buffer(19)]],
uint3 tid [[threadgroup_position_in_grid]],
@ -822,8 +770,8 @@ template <

// Update batch offsets
if (batch_ndim > 1) {
const constant size_t* veci_bstrides = index_batch_strides;
const constant size_t* mati_bstrides = index_batch_strides + batch_ndim;
const constant auto* veci_bstrides = index_batch_strides;
const constant auto* mati_bstrides = index_batch_strides + batch_ndim;

ulong2 batch_offsets = elem_to_loc_broadcast(
tid.z, batch_shape, veci_bstrides, mati_bstrides, batch_ndim);
@ -870,36 +818,14 @@ template <
simd_lid);
}

#define instantiate_gemv_t_bs_helper(nm, itype, bm, bn, sm, sn, tm, tn) \
template [[host_name("gemv_t_gather_" #nm "_bm" #bm "_bn" #bn "_sm" #sm \
"_sn" #sn "_tm" #tm "_tn" #tn)]] [[kernel]] void \
gemv_t_gather<itype, bm, bn, sm, sn, tm, tn>( \
const device itype* mat [[buffer(0)]], \
const device itype* in_vec [[buffer(1)]], \
const device itype* bias [[buffer(2)]], \
device itype* out_vec [[buffer(3)]], \
const constant int& in_vec_size [[buffer(4)]], \
const constant int& out_vec_size [[buffer(5)]], \
const constant int& marix_ld [[buffer(6)]], \
const constant float& alpha [[buffer(7)]], \
const constant float& beta [[buffer(8)]], \
const constant int& batch_ndim [[buffer(9)]], \
const constant int* batch_shape [[buffer(10)]], \
const constant size_t* index_batch_strides [[buffer(11)]], \
const constant int& vector_batch_ndim [[buffer(12)]], \
const constant int* vector_batch_shape [[buffer(13)]], \
const constant size_t* vector_batch_stride [[buffer(14)]], \
const constant int& matrix_batch_ndim [[buffer(15)]], \
const constant int* matrix_batch_shape [[buffer(16)]], \
const constant size_t* matrix_batch_stride [[buffer(17)]], \
const constant uint32_t* vec_indices [[buffer(18)]], \
const constant uint32_t* mat_indices [[buffer(19)]], \
uint3 tid [[threadgroup_position_in_grid]], \
uint3 lid [[thread_position_in_threadgroup]], \
uint simd_gid [[simdgroup_index_in_threadgroup]], \
uint simd_lid [[thread_index_in_simdgroup]]);

// clang-format off
#define instantiate_gemv_t_bs_helper( \
nm, itype, bm, bn, sm, sn, tm, tn) \
instantiate_kernel( \
"gemv_t_gather_" #nm "_bm" #bm "_bn" #bn "_sm" #sm \
"_sn" #sn "_tm" #tm "_tn" #tn, \
gemv_t_gather, itype, bm, bn, sm, sn, tm, tn)

#define instantiate_gemv_t_bs_blocks(name, itype) \
instantiate_gemv_t_bs_helper(name, itype, 1, 2, 8, 4, 4, 1) \
instantiate_gemv_t_bs_helper(name, itype, 1, 2, 8, 4, 4, 4) \

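The long `host_name` instantiation macros above collapse into `instantiate_kernel` calls, which is what lets every duplicated parameter list disappear. A sketch of how such a macro can be written, assuming the `decltype`-based form so the instantiation inherits the kernel's full signature without restating it:

.. code-block:: C++

   // Instantiate a kernel template and bind it to a Metal host name;
   // decltype recovers the signature so no buffer arguments are repeated.
   #define instantiate_kernel(name, func, ...)          \
     template [[host_name(                              \
         name)]] [[kernel]] decltype(func<__VA_ARGS__>) \
         func<__VA_ARGS__>;
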
@ -642,13 +642,13 @@ template <
const constant int& marix_ld [[buffer(6)]],
const constant int& batch_ndim [[buffer(9)]],
const constant int* batch_shape [[buffer(10)]],
const constant size_t* vector_batch_stride [[buffer(11)]],
const constant size_t* matrix_batch_stride [[buffer(12)]],
const constant int64_t* vector_batch_stride [[buffer(11)]],
const constant int64_t* matrix_batch_stride [[buffer(12)]],
const device out_mask_t* out_mask [[buffer(20)]],
const device op_mask_t* mat_mask [[buffer(21)]],
const device op_mask_t* vec_mask [[buffer(22)]],
const constant int* mask_strides [[buffer(23)]],
const constant size_t* mask_batch_strides [[buffer(24)]],
const constant int64_t* mask_batch_strides [[buffer(24)]],
uint3 tid [[threadgroup_position_in_grid]],
uint3 lid [[thread_position_in_threadgroup]],
uint simd_gid [[simdgroup_index_in_threadgroup]],
@ -673,8 +673,8 @@ template <
}

if (has_operand_mask) {
const constant size_t* mask_strides_mat = mask_batch_strides;
const constant size_t* mask_strides_vec = mask_strides_mat + batch_ndim;
const constant auto* mask_strides_mat = mask_batch_strides;
const constant auto* mask_strides_vec = mask_strides_mat + batch_ndim;

ulong2 batch_offsets = elem_to_loc_broadcast(
tid.z, batch_shape, mask_strides_mat, mask_strides_vec, batch_ndim);
@ -742,13 +742,13 @@ template <
const constant int& marix_ld [[buffer(6)]],
const constant int& batch_ndim [[buffer(9)]],
const constant int* batch_shape [[buffer(10)]],
const constant size_t* vector_batch_stride [[buffer(11)]],
const constant size_t* matrix_batch_stride [[buffer(12)]],
const constant int64_t* vector_batch_stride [[buffer(11)]],
const constant int64_t* matrix_batch_stride [[buffer(12)]],
const device out_mask_t* out_mask [[buffer(20)]],
const device op_mask_t* mat_mask [[buffer(21)]],
const device op_mask_t* vec_mask [[buffer(22)]],
const constant int* mask_strides [[buffer(23)]],
const constant size_t* mask_batch_strides [[buffer(24)]],
const constant int64_t* mask_batch_strides [[buffer(24)]],
uint3 tid [[threadgroup_position_in_grid]],
uint3 lid [[thread_position_in_threadgroup]],
uint simd_gid [[simdgroup_index_in_threadgroup]],
@ -773,8 +773,8 @@ template <
}

if (has_operand_mask) {
const constant size_t* mask_strides_mat = mask_batch_strides;
const constant size_t* mask_strides_vec = mask_strides_mat + batch_ndim;
const constant auto* mask_strides_mat = mask_batch_strides;
const constant auto* mask_strides_vec = mask_strides_mat + batch_ndim;

ulong2 batch_offsets = elem_to_loc_broadcast(
tid.z, batch_shape, mask_strides_mat, mask_strides_vec, batch_ndim);

@ -10,29 +10,11 @@

#define instantiate_gemv_helper( \
outm_n, outm_t, opm_n, opm_t, name, itype, bm, bn, sm, sn, tm, tn, nc) \
template [[host_name("gemv_outmask_" #outm_n "_opmask_" #opm_n "_" #name \
"_bm" #bm "_bn" #bn "_sm" #sm "_sn" #sn "_tm" #tm \
"_tn" #tn "_nc" #nc)]] [[kernel]] void \
gemv_masked<itype, outm_t, opm_t, bm, bn, sm, sn, tm, tn, nc>( \
const device itype* mat [[buffer(0)]], \
const device itype* in_vec [[buffer(1)]], \
device itype* out_vec [[buffer(3)]], \
const constant int& in_vec_size [[buffer(4)]], \
const constant int& out_vec_size [[buffer(5)]], \
const constant int& marix_ld [[buffer(6)]], \
const constant int& batch_ndim [[buffer(9)]], \
const constant int* batch_shape [[buffer(10)]], \
const constant size_t* vector_batch_stride [[buffer(11)]], \
const constant size_t* matrix_batch_stride [[buffer(12)]], \
const device outm_t* out_mask [[buffer(20)]], \
const device opm_t* mat_mask [[buffer(21)]], \
const device opm_t* vec_mask [[buffer(22)]], \
const constant int* mask_strides [[buffer(23)]], \
const constant size_t* mask_batch_strides [[buffer(24)]], \
uint3 tid [[threadgroup_position_in_grid]], \
uint3 lid [[thread_position_in_threadgroup]], \
uint simd_gid [[simdgroup_index_in_threadgroup]], \
uint simd_lid [[thread_index_in_simdgroup]]);
instantiate_kernel( \
"gemv_outmask_" #outm_n "_opmask_" #opm_n "_" #name \
"_bm" #bm "_bn" #bn "_sm" #sm "_sn" #sn "_tm" #tm \
"_tn" #tn "_nc" #nc, \
gemv_masked, itype, outm_t, opm_t, bm, bn, sm, sn, tm, tn, nc)

#define instantiate_gemv_base(name, itype, bm, bn, sm, sn, tm, tn, nc) \
instantiate_gemv_helper(bool_, bool, bool_, bool, name, itype, bm, bn, sm, sn, tm, tn, nc) \
@ -61,29 +43,11 @@ instantiate_gemv_blocks(bfloat16, bfloat16_t);

#define instantiate_gemv_t_helper( \
outm_n, outm_t, opm_n, opm_t, name, itype, bm, bn, sm, sn, tm, tn, nc) \
template [[host_name("gemv_t_outmask_" #outm_n "_opmask_" #opm_n "_" #name \
"_bm" #bm "_bn" #bn "_sm" #sm "_sn" #sn "_tm" #tm \
"_tn" #tn "_nc" #nc)]] [[kernel]] void \
gemv_t_masked<itype, outm_t, opm_t, bm, bn, sm, sn, tm, tn, nc>( \
const device itype* mat [[buffer(0)]], \
const device itype* in_vec [[buffer(1)]], \
device itype* out_vec [[buffer(3)]], \
const constant int& in_vec_size [[buffer(4)]], \
const constant int& out_vec_size [[buffer(5)]], \
const constant int& marix_ld [[buffer(6)]], \
const constant int& batch_ndim [[buffer(9)]], \
const constant int* batch_shape [[buffer(10)]], \
const constant size_t* vector_batch_stride [[buffer(11)]], \
const constant size_t* matrix_batch_stride [[buffer(12)]], \
const device outm_t* out_mask [[buffer(20)]], \
const device opm_t* mat_mask [[buffer(21)]], \
const device opm_t* vec_mask [[buffer(22)]], \
const constant int* mask_strides [[buffer(23)]], \
const constant size_t* mask_batch_strides [[buffer(24)]], \
uint3 tid [[threadgroup_position_in_grid]], \
uint3 lid [[thread_position_in_threadgroup]], \
uint simd_gid [[simdgroup_index_in_threadgroup]], \
uint simd_lid [[thread_index_in_simdgroup]]);
instantiate_kernel( \
"gemv_t_outmask_" #outm_n "_opmask_" #opm_n "_" #name \
"_bm" #bm "_bn" #bn "_sm" #sm "_sn" #sn "_tm" #tm \
"_tn" #tn "_nc" #nc, \
gemv_t_masked, itype, outm_t, opm_t, bm, bn, sm, sn, tm, tn, nc)

#define instantiate_gemv_t_base(name, itype, bm, bn, sm, sn, tm, tn, nc) \
instantiate_gemv_t_helper(bool_, bool, bool_, bool, name, itype, bm, bn, sm, sn, tm, tn, nc) \

@ -8,7 +8,7 @@ template <typename IdxT, int NIDX>
struct Indices {
const array<const device IdxT*, NIDX> buffers;
const constant int* shapes;
const constant size_t* strides;
const constant int64_t* strides;
const constant bool* row_contiguous;
const int ndim;
};

@ -1219,12 +1219,12 @@ METAL_FUNC void adjust_matrix_offsets(
int output_stride,
const constant int& x_batch_ndims,
const constant int* x_shape,
const constant size_t* x_strides,
const constant int64_t* x_strides,
const constant int& w_batch_ndims,
const constant int* w_shape,
const constant size_t* w_strides,
const constant size_t* s_strides,
const constant size_t* b_strides,
const constant int64_t* w_strides,
const constant int64_t* s_strides,
const constant int64_t* b_strides,
uint3 tid [[threadgroup_position_in_grid]]) {
// Set the input/output matrices
uint32_t x_idx = tid.z;
@ -1260,16 +1260,16 @@ METAL_FUNC void adjust_matrix_offsets(
int output_stride,
const constant int& batch_ndims,
const constant int* batch_shape,
const constant size_t* lhs_strides,
const constant size_t* rhs_strides,
const constant int64_t* lhs_strides,
const constant int64_t* rhs_strides,
const constant int& x_batch_ndims,
const constant int* x_shape,
const constant size_t* x_strides,
const constant int64_t* x_strides,
const constant int& w_batch_ndims,
const constant int* w_shape,
const constant size_t* w_strides,
const constant size_t* s_strides,
const constant size_t* b_strides,
const constant int64_t* w_strides,
const constant int64_t* s_strides,
const constant int64_t* b_strides,
uint3 tid [[threadgroup_position_in_grid]]) {
// Set the input/output matrices
uint32_t x_idx;
@ -1313,12 +1313,12 @@ template <typename T, int group_size, int bits, int D, bool batched>
const constant int& out_vec_size [[buffer(6)]],
const constant int& x_batch_ndims [[buffer(7)]],
const constant int* x_shape [[buffer(8)]],
const constant size_t* x_strides [[buffer(9)]],
const constant int64_t* x_strides [[buffer(9)]],
const constant int& w_batch_ndims [[buffer(10)]],
const constant int* w_shape [[buffer(11)]],
const constant size_t* w_strides [[buffer(12)]],
const constant size_t* s_strides [[buffer(13)]],
const constant size_t* b_strides [[buffer(14)]],
const constant int64_t* w_strides [[buffer(12)]],
const constant int64_t* s_strides [[buffer(13)]],
const constant int64_t* b_strides [[buffer(14)]],
uint3 tid [[threadgroup_position_in_grid]],
uint quad_gid [[quadgroup_index_in_threadgroup]],
uint quad_lid [[thread_index_in_quadgroup]]) {
@ -1364,12 +1364,12 @@ template <typename T, int group_size, int bits, bool batched>
const constant int& out_vec_size [[buffer(6)]],
const constant int& x_batch_ndims [[buffer(7)]],
const constant int* x_shape [[buffer(8)]],
const constant size_t* x_strides [[buffer(9)]],
const constant int64_t* x_strides [[buffer(9)]],
const constant int& w_batch_ndims [[buffer(10)]],
const constant int* w_shape [[buffer(11)]],
const constant size_t* w_strides [[buffer(12)]],
const constant size_t* s_strides [[buffer(13)]],
const constant size_t* b_strides [[buffer(14)]],
const constant int64_t* w_strides [[buffer(12)]],
const constant int64_t* s_strides [[buffer(13)]],
const constant int64_t* b_strides [[buffer(14)]],
uint3 tid [[threadgroup_position_in_grid]],
uint simd_gid [[simdgroup_index_in_threadgroup]],
uint simd_lid [[thread_index_in_simdgroup]]) {
@ -1415,12 +1415,12 @@ template <typename T, const int group_size, const int bits, bool batched>
const constant int& out_vec_size [[buffer(6)]],
const constant int& x_batch_ndims [[buffer(7)]],
const constant int* x_shape [[buffer(8)]],
const constant size_t* x_strides [[buffer(9)]],
const constant int64_t* x_strides [[buffer(9)]],
const constant int& w_batch_ndims [[buffer(10)]],
const constant int* w_shape [[buffer(11)]],
const constant size_t* w_strides [[buffer(12)]],
const constant size_t* s_strides [[buffer(13)]],
const constant size_t* b_strides [[buffer(14)]],
const constant int64_t* w_strides [[buffer(12)]],
const constant int64_t* s_strides [[buffer(13)]],
const constant int64_t* b_strides [[buffer(14)]],
uint3 tid [[threadgroup_position_in_grid]],
uint simd_gid [[simdgroup_index_in_threadgroup]],
uint simd_lid [[thread_index_in_simdgroup]]) {
@ -1466,12 +1466,12 @@ template <typename T, const int group_size, const int bits, bool batched>
const constant int& out_vec_size [[buffer(6)]],
const constant int& x_batch_ndims [[buffer(7)]],
const constant int* x_shape [[buffer(8)]],
const constant size_t* x_strides [[buffer(9)]],
const constant int64_t* x_strides [[buffer(9)]],
const constant int& w_batch_ndims [[buffer(10)]],
const constant int* w_shape [[buffer(11)]],
const constant size_t* w_strides [[buffer(12)]],
const constant size_t* s_strides [[buffer(13)]],
const constant size_t* b_strides [[buffer(14)]],
const constant int64_t* w_strides [[buffer(12)]],
const constant int64_t* s_strides [[buffer(13)]],
const constant int64_t* b_strides [[buffer(14)]],
uint3 tid [[threadgroup_position_in_grid]],
uint simd_gid [[simdgroup_index_in_threadgroup]],
uint simd_lid [[thread_index_in_simdgroup]]) {
@ -1517,12 +1517,12 @@ template <typename T, const int group_size, const int bits, int split_k = 32>
const constant int& out_vec_size [[buffer(6)]],
const constant int& x_batch_ndims [[buffer(7)]],
const constant int* x_shape [[buffer(8)]],
const constant size_t* x_strides [[buffer(9)]],
const constant int64_t* x_strides [[buffer(9)]],
const constant int& w_batch_ndims [[buffer(10)]],
const constant int* w_shape [[buffer(11)]],
const constant size_t* w_strides [[buffer(12)]],
const constant size_t* s_strides [[buffer(13)]],
const constant size_t* b_strides [[buffer(14)]],
const constant int64_t* w_strides [[buffer(12)]],
const constant int64_t* s_strides [[buffer(13)]],
const constant int64_t* b_strides [[buffer(14)]],
const constant int& final_block_size [[buffer(15)]],
uint3 tid [[threadgroup_position_in_grid]],
uint simd_gid [[simdgroup_index_in_threadgroup]],
@ -1581,12 +1581,12 @@ template <
const constant int& M [[buffer(7)]],
const constant int& x_batch_ndims [[buffer(8)]],
const constant int* x_shape [[buffer(9)]],
const constant size_t* x_strides [[buffer(10)]],
const constant int64_t* x_strides [[buffer(10)]],
const constant int& w_batch_ndims [[buffer(11)]],
const constant int* w_shape [[buffer(12)]],
const constant size_t* w_strides [[buffer(13)]],
const constant size_t* s_strides [[buffer(14)]],
const constant size_t* b_strides [[buffer(15)]],
const constant int64_t* w_strides [[buffer(13)]],
const constant int64_t* s_strides [[buffer(14)]],
const constant int64_t* b_strides [[buffer(15)]],
uint3 tid [[threadgroup_position_in_grid]],
uint lid [[thread_index_in_threadgroup]],
uint simd_gid [[simdgroup_index_in_threadgroup]],
@ -1639,12 +1639,12 @@ template <
const constant int& M [[buffer(7)]],
const constant int& x_batch_ndims [[buffer(8)]],
const constant int* x_shape [[buffer(9)]],
const constant size_t* x_strides [[buffer(10)]],
const constant int64_t* x_strides [[buffer(10)]],
const constant int& w_batch_ndims [[buffer(11)]],
const constant int* w_shape [[buffer(12)]],
const constant size_t* w_strides [[buffer(13)]],
const constant size_t* s_strides [[buffer(14)]],
const constant size_t* b_strides [[buffer(15)]],
const constant int64_t* w_strides [[buffer(13)]],
const constant int64_t* s_strides [[buffer(14)]],
const constant int64_t* b_strides [[buffer(15)]],
uint3 tid [[threadgroup_position_in_grid]],
uint lid [[thread_index_in_threadgroup]],
uint simd_gid [[simdgroup_index_in_threadgroup]],
@ -1691,18 +1691,18 @@ template <typename T, int group_size, int bits>
const constant int& out_vec_size [[buffer(6)]],
const constant int& x_batch_ndims [[buffer(7)]],
const constant int* x_shape [[buffer(8)]],
const constant size_t* x_strides [[buffer(9)]],
const constant int64_t* x_strides [[buffer(9)]],
const constant int& w_batch_ndims [[buffer(10)]],
const constant int* w_shape [[buffer(11)]],
const constant size_t* w_strides [[buffer(12)]],
const constant size_t* s_strides [[buffer(13)]],
const constant size_t* b_strides [[buffer(14)]],
const constant int64_t* w_strides [[buffer(12)]],
|
||||
const constant int64_t* s_strides [[buffer(13)]],
|
||||
const constant int64_t* b_strides [[buffer(14)]],
|
||||
const constant int& batch_ndims [[buffer(15)]],
|
||||
const constant int* batch_shape [[buffer(16)]],
|
||||
const device uint32_t* lhs_indices [[buffer(17)]],
|
||||
const device uint32_t* rhs_indices [[buffer(18)]],
|
||||
const constant size_t* lhs_strides [[buffer(19)]],
|
||||
const constant size_t* rhs_strides [[buffer(20)]],
|
||||
const constant int64_t* lhs_strides [[buffer(19)]],
|
||||
const constant int64_t* rhs_strides [[buffer(20)]],
|
||||
uint3 tid [[threadgroup_position_in_grid]],
|
||||
uint simd_gid [[simdgroup_index_in_threadgroup]],
|
||||
uint simd_lid [[thread_index_in_simdgroup]]) {
|
||||
@ -1752,18 +1752,18 @@ template <typename T, int group_size, int bits>
|
||||
const constant int& out_vec_size [[buffer(6)]],
|
||||
const constant int& x_batch_ndims [[buffer(7)]],
|
||||
const constant int* x_shape [[buffer(8)]],
|
||||
const constant size_t* x_strides [[buffer(9)]],
|
||||
const constant int64_t* x_strides [[buffer(9)]],
|
||||
const constant int& w_batch_ndims [[buffer(10)]],
|
||||
const constant int* w_shape [[buffer(11)]],
|
||||
const constant size_t* w_strides [[buffer(12)]],
|
||||
const constant size_t* s_strides [[buffer(13)]],
|
||||
const constant size_t* b_strides [[buffer(14)]],
|
||||
const constant int64_t* w_strides [[buffer(12)]],
|
||||
const constant int64_t* s_strides [[buffer(13)]],
|
||||
const constant int64_t* b_strides [[buffer(14)]],
|
||||
const constant int& batch_ndims [[buffer(15)]],
|
||||
const constant int* batch_shape [[buffer(16)]],
|
||||
const device uint32_t* lhs_indices [[buffer(17)]],
|
||||
const device uint32_t* rhs_indices [[buffer(18)]],
|
||||
const constant size_t* lhs_strides [[buffer(19)]],
|
||||
const constant size_t* rhs_strides [[buffer(20)]],
|
||||
const constant int64_t* lhs_strides [[buffer(19)]],
|
||||
const constant int64_t* rhs_strides [[buffer(20)]],
|
||||
uint3 tid [[threadgroup_position_in_grid]],
|
||||
uint simd_gid [[simdgroup_index_in_threadgroup]],
|
||||
uint simd_lid [[thread_index_in_simdgroup]]) {
|
||||
@ -1813,18 +1813,18 @@ template <typename T, int group_size, int bits>
|
||||
const constant int& out_vec_size [[buffer(6)]],
|
||||
const constant int& x_batch_ndims [[buffer(7)]],
|
||||
const constant int* x_shape [[buffer(8)]],
|
||||
const constant size_t* x_strides [[buffer(9)]],
|
||||
const constant int64_t* x_strides [[buffer(9)]],
|
||||
const constant int& w_batch_ndims [[buffer(10)]],
|
||||
const constant int* w_shape [[buffer(11)]],
|
||||
const constant size_t* w_strides [[buffer(12)]],
|
||||
const constant size_t* s_strides [[buffer(13)]],
|
||||
const constant size_t* b_strides [[buffer(14)]],
|
||||
const constant int64_t* w_strides [[buffer(12)]],
|
||||
const constant int64_t* s_strides [[buffer(13)]],
|
||||
const constant int64_t* b_strides [[buffer(14)]],
|
||||
const constant int& batch_ndims [[buffer(15)]],
|
||||
const constant int* batch_shape [[buffer(16)]],
|
||||
const device uint32_t* lhs_indices [[buffer(17)]],
|
||||
const device uint32_t* rhs_indices [[buffer(18)]],
|
||||
const constant size_t* lhs_strides [[buffer(19)]],
|
||||
const constant size_t* rhs_strides [[buffer(20)]],
|
||||
const constant int64_t* lhs_strides [[buffer(19)]],
|
||||
const constant int64_t* rhs_strides [[buffer(20)]],
|
||||
uint3 tid [[threadgroup_position_in_grid]],
|
||||
uint simd_gid [[simdgroup_index_in_threadgroup]],
|
||||
uint simd_lid [[thread_index_in_simdgroup]]) {
|
||||
@ -1882,18 +1882,18 @@ template <
|
||||
const constant int& M [[buffer(7)]],
|
||||
const constant int& x_batch_ndims [[buffer(8)]],
|
||||
const constant int* x_shape [[buffer(9)]],
|
||||
const constant size_t* x_strides [[buffer(10)]],
|
||||
const constant int64_t* x_strides [[buffer(10)]],
|
||||
const constant int& w_batch_ndims [[buffer(11)]],
|
||||
const constant int* w_shape [[buffer(12)]],
|
||||
const constant size_t* w_strides [[buffer(13)]],
|
||||
const constant size_t* s_strides [[buffer(14)]],
|
||||
const constant size_t* b_strides [[buffer(15)]],
|
||||
const constant int64_t* w_strides [[buffer(13)]],
|
||||
const constant int64_t* s_strides [[buffer(14)]],
|
||||
const constant int64_t* b_strides [[buffer(15)]],
|
||||
const constant int& batch_ndims [[buffer(16)]],
|
||||
const constant int* batch_shape [[buffer(17)]],
|
||||
const device uint32_t* lhs_indices [[buffer(18)]],
|
||||
const device uint32_t* rhs_indices [[buffer(19)]],
|
||||
const constant size_t* lhs_strides [[buffer(20)]],
|
||||
const constant size_t* rhs_strides [[buffer(21)]],
|
||||
const constant int64_t* lhs_strides [[buffer(20)]],
|
||||
const constant int64_t* rhs_strides [[buffer(21)]],
|
||||
uint3 tid [[threadgroup_position_in_grid]],
|
||||
uint lid [[thread_index_in_threadgroup]],
|
||||
uint simd_gid [[simdgroup_index_in_threadgroup]],
|
||||
@ -1949,18 +1949,18 @@ template <
|
||||
const constant int& M [[buffer(7)]],
|
||||
const constant int& x_batch_ndims [[buffer(8)]],
|
||||
const constant int* x_shape [[buffer(9)]],
|
||||
const constant size_t* x_strides [[buffer(10)]],
|
||||
const constant int64_t* x_strides [[buffer(10)]],
|
||||
const constant int& w_batch_ndims [[buffer(11)]],
|
||||
const constant int* w_shape [[buffer(12)]],
|
||||
const constant size_t* w_strides [[buffer(13)]],
|
||||
const constant size_t* s_strides [[buffer(14)]],
|
||||
const constant size_t* b_strides [[buffer(15)]],
|
||||
const constant int64_t* w_strides [[buffer(13)]],
|
||||
const constant int64_t* s_strides [[buffer(14)]],
|
||||
const constant int64_t* b_strides [[buffer(15)]],
|
||||
const constant int& batch_ndims [[buffer(16)]],
|
||||
const constant int* batch_shape [[buffer(17)]],
|
||||
const device uint32_t* lhs_indices [[buffer(18)]],
|
||||
const device uint32_t* rhs_indices [[buffer(19)]],
|
||||
const constant size_t* lhs_strides [[buffer(20)]],
|
||||
const constant size_t* rhs_strides [[buffer(21)]],
|
||||
const constant int64_t* lhs_strides [[buffer(20)]],
|
||||
const constant int64_t* rhs_strides [[buffer(21)]],
|
||||
uint3 tid [[threadgroup_position_in_grid]],
|
||||
uint lid [[thread_index_in_threadgroup]],
|
||||
uint simd_gid [[simdgroup_index_in_threadgroup]],
|
||||
|
@ -71,7 +71,7 @@ rbits threefry2x32_hash(const thread uint2& key, uint2 count) {
constant const uint& bytes_per_key,
constant const int& ndim,
constant const int* key_shape,
constant const size_t* key_strides,
constant const int64_t* key_strides,
uint2 grid_dim [[threads_per_grid]],
uint2 index [[thread_position_in_grid]]) {
auto kidx = 2 * index.x;
|
@ -59,10 +59,10 @@ instantiate_init_min_max(max, Max)
itype, otype, op, uint, dim) \
instantiate_kernel("col_reduce_small_large_" #dim "_reduce_" #name, \
col_reduce_small, \
itype, otype, op, size_t, dim) \
itype, otype, op, int64_t, dim) \
instantiate_kernel("col_reduce_longcolumn_large_" #dim "_reduce_" #name, \
col_reduce_longcolumn, \
itype, otype, op, size_t, dim)
itype, otype, op, int64_t, dim)

#define instantiate_col_reduce_looped_tile(name, itype, otype, op, dim, bm, bn) \
instantiate_kernel("col_reduce_looped_" #dim "_" #bm "_" #bn "_reduce_" #name, \
@ -70,7 +70,7 @@ instantiate_init_min_max(max, Max)
itype, otype, op, uint, dim, bm, bn) \
instantiate_kernel("col_reduce_looped_large_" #dim "_" #bm "_" #bn "_reduce_" #name, \
col_reduce_looped, \
itype, otype, op, size_t, dim, bm, bn)
itype, otype, op, int64_t, dim, bm, bn)

#define instantiate_col_reduce_2pass_tile(name, itype, otype, op, dim, bm, bn) \
instantiate_kernel("col_reduce_2pass_" #dim "_" #bm "_" #bn "_reduce_" #name, \
@ -78,7 +78,7 @@ instantiate_init_min_max(max, Max)
itype, otype, op, uint, dim, bm, bn) \
instantiate_kernel("col_reduce_2pass_large_" #dim "_" #bm "_" #bn "_reduce_" #name, \
col_reduce_2pass, \
itype, otype, op, size_t, dim, bm, bn)
itype, otype, op, int64_t, dim, bm, bn)

#define instantiate_col_reduce_looped(name, itype, otype, op, dim) \
instantiate_col_reduce_looped_tile(name, itype, otype, op, dim, 32, 32) \
@ -98,7 +98,7 @@ instantiate_init_min_max(max, Max)
itype, otype, op, uint, dim) \
instantiate_kernel("row_reduce_small_large_" #dim "_reduce_" #name, \
row_reduce_small, \
itype, otype, op, size_t, dim)
itype, otype, op, int64_t, dim)

#define instantiate_row_reduce_looped(name, itype, otype, op, dim) \
instantiate_kernel("row_reduce_looped_" #dim "_reduce_" #name, \
@ -106,7 +106,7 @@ instantiate_init_min_max(max, Max)
itype, otype, op, uint, dim) \
instantiate_kernel("row_reduce_looped_large_" #dim "_reduce_" #name, \
row_reduce_looped, \
itype, otype, op, size_t, dim)
itype, otype, op, int64_t, dim)

#define instantiate_row_reduce_general(name, itype, otype, op) \
instantiate_row_reduce_small(name, itype, otype, op, 1) \
@ -125,7 +125,7 @@ instantiate_init_min_max(max, Max)
instantiate_col_reduce_general(name##tname, itype, otype, op<otype>)

#define instantiate_and_or(name, op) \
instantiate_reduce_functions(name, bool_, bool, bool, op) \
instantiate_reduce_functions(name, bool_, bool, bool, op) \
instantiate_reduce_functions(name, int16, int16_t, bool, op) \
instantiate_reduce_functions(name, int32, int32_t, bool, op) \
instantiate_reduce_functions(name, int64, int64_t, bool, op)
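
A note on the pattern above: every reduction is instantiated twice, once with a `uint` index type and once, under a `large`-infixed host name, with a 64-bit index type (`int64_t` after this commit, `size_t` before). A hypothetical host-side selection, sketched in plain C++ only to show the intent — the helper name and the exact threshold here are illustrative, not MLX's actual dispatch code:

#include <cstdint>
#include <limits>
#include <string>

// Pick the 64-bit-index ("large") instantiation only when a flat element
// offset could overflow a 32-bit index; otherwise use the cheaper uint one.
std::string pick_reduce_variant(const std::string& base, int64_t n_elements) {
  if (n_elements > int64_t(std::numeric_limits<uint32_t>::max())) {
    return base + "_large";  // e.g. "col_reduce_small" -> "col_reduce_small_large"
  }
  return base;
}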
|
@ -5,12 +5,12 @@ template <typename T, typename U, typename Op, typename IdxT, int NDIMS>
const device T* in [[buffer(0)]],
device U* out [[buffer(1)]],
const constant size_t& reduction_size [[buffer(2)]],
const constant size_t& reduction_stride [[buffer(3)]],
const constant int64_t& reduction_stride [[buffer(3)]],
const constant int* shape [[buffer(4)]],
const constant size_t* strides [[buffer(5)]],
const constant int64_t* strides [[buffer(5)]],
const constant int& ndim [[buffer(6)]],
const constant int* reduce_shape [[buffer(7)]],
const constant size_t* reduce_strides [[buffer(8)]],
const constant int64_t* reduce_strides [[buffer(8)]],
const constant int& reduce_ndim [[buffer(9)]],
const constant size_t& non_col_reductions [[buffer(10)]],
uint3 gid [[threadgroup_position_in_grid]],
@ -34,7 +34,7 @@ template <typename T, typename U, typename Op, typename IdxT, int NDIMS>
bool safe = column + n_reads <= reduction_stride;

IdxT out_idx = gid.y + gsize.y * IdxT(gid.z);
IdxT in_idx = elem_to_loc<size_t, IdxT>(out_idx, shape, strides, ndim);
IdxT in_idx = elem_to_loc<IdxT>(out_idx, shape, strides, ndim);
in += in_idx + column;

IdxT total_rows = IdxT(non_col_reductions) * IdxT(reduction_size);
@ -100,10 +100,10 @@ template <typename T, typename U, typename Op, typename IdxT, int NDIMS>
const constant size_t& reduction_size [[buffer(2)]],
const constant size_t& reduction_stride [[buffer(3)]],
const constant int* shape [[buffer(4)]],
const constant size_t* strides [[buffer(5)]],
const constant int64_t* strides [[buffer(5)]],
const constant int& ndim [[buffer(6)]],
const constant int* reduce_shape [[buffer(7)]],
const constant size_t* reduce_strides [[buffer(8)]],
const constant int64_t* reduce_strides [[buffer(8)]],
const constant int& reduce_ndim [[buffer(9)]],
const constant size_t& non_col_reductions [[buffer(10)]],
const constant size_t& out_size [[buffer(11)]],
@ -116,7 +116,7 @@ template <typename T, typename U, typename Op, typename IdxT, int NDIMS>
const device T* row;

IdxT out_idx = gid.x + gsize.x * IdxT(gid.y);
IdxT in_idx = elem_to_loc<size_t, IdxT>(out_idx, shape, strides, ndim);
IdxT in_idx = elem_to_loc<IdxT>(out_idx, shape, strides, ndim);
in += in_idx + lid.x;

U total = Op::init;
@ -164,12 +164,12 @@ template <
const device T* in [[buffer(0)]],
device U* out [[buffer(1)]],
const constant size_t& reduction_size [[buffer(2)]],
const constant size_t& reduction_stride [[buffer(3)]],
const constant int64_t& reduction_stride [[buffer(3)]],
const constant int* shape [[buffer(4)]],
const constant size_t* strides [[buffer(5)]],
const constant int64_t* strides [[buffer(5)]],
const constant int& ndim [[buffer(6)]],
const constant int* reduce_shape [[buffer(7)]],
const constant size_t* reduce_strides [[buffer(8)]],
const constant int64_t* reduce_strides [[buffer(8)]],
const constant int& reduce_ndim [[buffer(9)]],
const constant size_t& non_col_reductions [[buffer(10)]],
uint3 gid [[threadgroup_position_in_grid]],
@ -197,7 +197,7 @@ template <
bool safe = column + n_reads <= reduction_stride;

IdxT out_idx = gid.y + gsize.y * IdxT(gid.z);
IdxT in_idx = elem_to_loc<size_t, IdxT>(out_idx, shape, strides, ndim);
IdxT in_idx = elem_to_loc<IdxT>(out_idx, shape, strides, ndim);
in += in_idx + column;

IdxT total = IdxT(non_col_reductions) * IdxT(reduction_size);
@ -303,12 +303,12 @@ template <
const device T* in [[buffer(0)]],
device U* out [[buffer(1)]],
const constant size_t& reduction_size [[buffer(2)]],
const constant size_t& reduction_stride [[buffer(3)]],
const constant int64_t& reduction_stride [[buffer(3)]],
const constant int* shape [[buffer(4)]],
const constant size_t* strides [[buffer(5)]],
const constant int64_t* strides [[buffer(5)]],
const constant int& ndim [[buffer(6)]],
const constant int* reduce_shape [[buffer(7)]],
const constant size_t* reduce_strides [[buffer(8)]],
const constant int64_t* reduce_strides [[buffer(8)]],
const constant int& reduce_ndim [[buffer(9)]],
const constant size_t& non_col_reductions [[buffer(10)]],
const constant size_t& out_size [[buffer(11)]],
@ -342,7 +342,7 @@ template <
IdxT full_idx = gid.y + gsize.y * IdxT(gid.z);
IdxT block_idx = full_idx / IdxT(out_size);
IdxT out_idx = full_idx % IdxT(out_size);
IdxT in_idx = elem_to_loc<size_t, IdxT>(out_idx, shape, strides, ndim);
IdxT in_idx = elem_to_loc<IdxT>(out_idx, shape, strides, ndim);
in += in_idx + column;

IdxT total = IdxT(non_col_reductions) * IdxT(reduction_size);
|
@ -98,11 +98,11 @@ template <
METAL_FUNC void per_thread_row_reduce(
thread U totals[N_WRITES],
const device T* in,
const size_t row_idx,
const int64_t row_idx,
int blocks,
int extra,
const constant int* shape,
const constant size_t* strides,
const constant int64_t* strides,
const constant int& ndim,
uint lsize_x,
uint lid_x) {
@ -199,13 +199,13 @@ template <
[[kernel]] void row_reduce_small(
const device T* in [[buffer(0)]],
device U* out [[buffer(1)]],
const constant size_t& row_size [[buffer(2)]],
const constant size_t& non_row_reductions [[buffer(3)]],
const constant int64_t& row_size [[buffer(2)]],
const constant int64_t& non_row_reductions [[buffer(3)]],
const constant int* shape [[buffer(4)]],
const constant size_t* strides [[buffer(5)]],
const constant int64_t* strides [[buffer(5)]],
const constant int& ndim [[buffer(6)]],
const constant int* reduce_shape [[buffer(7)]],
const constant size_t* reduce_strides [[buffer(8)]],
const constant int64_t* reduce_strides [[buffer(8)]],
const constant int& reduce_ndim [[buffer(9)]],
uint simd_lane_id [[thread_index_in_simdgroup]],
uint3 gid [[threadgroup_position_in_grid]],
@ -225,7 +225,7 @@ template <
if ((non_row_reductions < 32 && row_size <= 8) || non_row_reductions <= 8) {
// Simple loop over non_row_reductions and reduce the row in the thread.
IdxT out_idx = tid.x + tsize.y * IdxT(tid.y);
in += elem_to_loc<size_t, IdxT>(out_idx, shape, strides, ndim);
in += elem_to_loc<IdxT>(out_idx, shape, strides, ndim);

for (uint r = 0; r < non_row_reductions; r++) {
row = in + loop.location();
@ -238,7 +238,7 @@ template <
// Collaboratively reduce over non_row_reductions in the simdgroup. Each
// thread reduces every 32nd row and then a simple simd reduce.
IdxT out_idx = gid.y + gsize.y * IdxT(gid.z);
in += elem_to_loc<size_t, IdxT>(out_idx, shape, strides, ndim);
in += elem_to_loc<IdxT>(out_idx, shape, strides, ndim);

loop.next(simd_lane_id, reduce_shape, reduce_strides);

@ -260,14 +260,14 @@ template <
typename T,
typename U,
typename Op,
typename IdxT = size_t,
typename IdxT = int64_t,
int N_READS = REDUCE_N_READS,
int N_WRITES = REDUCE_N_WRITES>
[[kernel]] void row_reduce_simple(
const device T* in [[buffer(0)]],
device U* out [[buffer(1)]],
const constant size_t& reduction_size [[buffer(2)]],
const constant size_t& out_size [[buffer(3)]],
const constant int64_t& out_size [[buffer(3)]],
uint3 gid [[threadgroup_position_in_grid]],
uint3 gsize [[threadgroups_per_grid]],
uint3 lid [[thread_position_in_threadgroup]],
@ -314,13 +314,13 @@ template <
[[kernel]] void row_reduce_looped(
const device T* in [[buffer(0)]],
device U* out [[buffer(1)]],
const constant size_t& row_size [[buffer(2)]],
const constant size_t& non_row_reductions [[buffer(3)]],
const constant int64_t& row_size [[buffer(2)]],
const constant int64_t& non_row_reductions [[buffer(3)]],
const constant int* shape [[buffer(4)]],
const constant size_t* strides [[buffer(5)]],
const constant int64_t* strides [[buffer(5)]],
const constant int& ndim [[buffer(6)]],
const constant int* reduce_shape [[buffer(7)]],
const constant size_t* reduce_strides [[buffer(8)]],
const constant int64_t* reduce_strides [[buffer(8)]],
const constant int& reduce_ndim [[buffer(9)]],
uint3 gid [[threadgroup_position_in_grid]],
uint3 gsize [[threadgroups_per_grid]],
@ -337,8 +337,7 @@ template <

// lid.x * N_READS breaks the per_thread_row_reduce interface a bit. Maybe it
// needs a small refactor.
in += elem_to_loc<size_t, IdxT>(out_idx, shape, strides, ndim) +
lid.x * N_READS;
in += elem_to_loc<IdxT>(out_idx, shape, strides, ndim) + lid.x * N_READS;

LoopedElemToLoc<NDIMS, IdxT, (NDIMS > 2)> loop(reduce_ndim);
const device T* row;
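
The two branches of `row_reduce_small` above choose between a per-thread loop and a simdgroup-collaborative reduction ("each thread reduces every 32nd row and then a simple simd reduce"). A host-side C++ analogue of the collaborative branch, assuming contiguous rows and a sum reduction for simplicity — the real kernel locates rows through `loop.location()` and uses the templated `Op`:

#include <cstdint>

// Lane l of a 32-wide simdgroup reduces rows l, l+32, l+64, ...; the 32
// partial results are then combined (a simd reduce on the GPU).
float row_reduce_collab_ref(const float* rows, int64_t row_size,
                            int64_t non_row_reductions) {
  float partial[32] = {};
  for (int lane = 0; lane < 32; ++lane) {
    for (int64_t r = lane; r < non_row_reductions; r += 32) {
      const float* row = rows + r * row_size;
      for (int64_t i = 0; i < row_size; ++i) {
        partial[lane] += row[i];
      }
    }
  }
  float total = 0.f;
  for (int lane = 0; lane < 32; ++lane) {
    total += partial[lane];  // stands in for the final simd reduce
  }
  return total;
}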
|
@ -16,11 +16,11 @@ METAL_FUNC void scatter_impl(
const device T* updates,
device mlx_atomic<T>* out,
const constant int* upd_shape,
const constant size_t* upd_strides,
const constant int64_t* upd_strides,
const constant size_t& upd_ndim,
const constant size_t& upd_size,
const constant int* out_shape,
const constant size_t* out_strides,
const constant int64_t* out_strides,
const constant size_t& out_ndim,
const constant int* axes,
const constant size_t& idx_size,
@ -31,7 +31,7 @@ METAL_FUNC void scatter_impl(
auto ind_idx = gid.y * NWORK;
LocT out_offset = 0;
if (upd_size > 1) {
out_offset = elem_to_loc<size_t, LocT>(
out_offset = elem_to_loc<LocT>(
gid.x, upd_shape + indices.ndim, out_strides, out_ndim);
}

@ -40,7 +40,7 @@ METAL_FUNC void scatter_impl(
for (int i = 0; i < NIDX; ++i) {
auto idx_loc = indices.row_contiguous[i]
? ind_idx
: elem_to_loc<size_t, LocT>(
: elem_to_loc<LocT>(
ind_idx,
&indices.shapes[indices.ndim * i],
&indices.strides[indices.ndim * i],
@ -52,8 +52,7 @@ METAL_FUNC void scatter_impl(
}
auto upd_idx = ind_idx * static_cast<LocT>(upd_size) + gid.x;
if constexpr (!UPD_ROW_CONTIG) {
upd_idx =
elem_to_loc<size_t, LocT>(upd_idx, upd_shape, upd_strides, upd_ndim);
upd_idx = elem_to_loc<LocT>(upd_idx, upd_shape, upd_strides, upd_ndim);
}
op.atomic_update(out, updates[upd_idx], out_idx);
}
|
@ -343,8 +343,8 @@ template <
const constant int& out_stride_sorted_axis [[buffer(4)]],
const constant int& nc_dim [[buffer(5)]],
const constant int* nc_shape [[buffer(6)]],
const constant size_t* in_nc_strides [[buffer(7)]],
const constant size_t* out_nc_strides [[buffer(8)]],
const constant int64_t* in_nc_strides [[buffer(7)]],
const constant int64_t* out_nc_strides [[buffer(8)]],
uint3 tid [[threadgroup_position_in_grid]],
uint3 lid [[thread_position_in_threadgroup]]) {
using sort_kernel =
@ -486,7 +486,7 @@ template <
const constant int& stride_sorted_axis [[buffer(4)]],
const constant int& nc_dim [[buffer(5)]],
const constant int* nc_shape [[buffer(6)]],
const constant size_t* nc_strides [[buffer(7)]],
const constant int64_t* nc_strides [[buffer(7)]],
uint3 tid [[threadgroup_position_in_grid]],
uint3 lid [[thread_position_in_threadgroup]]) {
using sort_kernel = KernelMultiBlockMergeSort<
|
@ -26,10 +26,10 @@ struct AttnParams {
int NQ_aligned; ///< Number of full query blocks
int NK_aligned; ///< Number of full key/value blocks

size_t Q_strides[3]; ///< Query strides (B, H, L, D = 1)
size_t K_strides[3]; ///< Key strides (B, H, L, D = 1)
size_t V_strides[3]; ///< Value strides (B, H, L, D = 1)
size_t O_strides[3]; ///< Output strides (B, H, L, D = 1)
int64_t Q_strides[3]; ///< Query strides (B, H, L, D = 1)
int64_t K_strides[3]; ///< Key strides (B, H, L, D = 1)
int64_t V_strides[3]; ///< Value strides (B, H, L, D = 1)
int64_t O_strides[3]; ///< Output strides (B, H, L, D = 1)
};

} // namespace steel
|
@ -14,9 +14,9 @@ struct MLXConvParams {
const int pad[NDIM]; // Input padding
const int kdil[NDIM]; // Kernel dilation
const int idil[NDIM]; // Input dilation
const size_t in_strides[NDIM + 2]; // In strides
const size_t wt_strides[NDIM + 2]; // Wt strides
const size_t out_strides[NDIM + 2]; // Out strides
const int64_t in_strides[NDIM + 2]; // In strides
const int64_t wt_strides[NDIM + 2]; // Wt strides
const int64_t out_strides[NDIM + 2]; // Out strides
const int groups; // Input channel groups
const bool flip;
};
@ -59,4 +59,4 @@ struct Conv2DGeneralBaseInfo {
};

} // namespace steel
} // namespace mlx
} // namespace mlx
|
@ -38,12 +38,12 @@ template <
const constant GEMMParams* params [[buffer(4)]],
const constant GEMMAddMMParams* addmm_params [[buffer(5), function_constant(use_out_source)]],
const constant int* batch_shape [[buffer(6)]],
const constant size_t* batch_strides [[buffer(7)]],
const constant int64_t* batch_strides [[buffer(7)]],
const constant uint32_t* lhs_indices [[buffer(10), function_constant(do_gather)]],
const constant uint32_t* rhs_indices [[buffer(11), function_constant(do_gather)]],
const constant uint32_t* C_indices [[buffer(12), function_constant(gather_bias)]],
const constant int* operand_shape [[buffer(13), function_constant(do_gather)]],
const constant size_t* operand_strides [[buffer(14), function_constant(do_gather)]],
const constant int64_t* operand_strides [[buffer(14), function_constant(do_gather)]],
const constant packed_int3& operand_batch_ndim [[buffer(15), function_constant(do_gather)]],
uint simd_lane_id [[thread_index_in_simdgroup]],
uint simd_group_id [[simdgroup_index_in_threadgroup]],
@ -88,9 +88,8 @@ template <
uint32_t indx_A, indx_B, indx_C;

if (has_batch) {
const constant size_t* indx_A_bstrides = batch_strides;
const constant size_t* indx_B_bstrides =
batch_strides + params->batch_ndim;
const constant auto* indx_A_bstrides = batch_strides;
const constant auto* indx_B_bstrides = batch_strides + params->batch_ndim;

ulong2 indx_offsets = elem_to_loc_broadcast(
tid.z,
@ -102,7 +101,7 @@ template <
indx_B = rhs_indices[indx_offsets.y];

if (use_out_source) {
const constant size_t* indx_C_bstrides =
const constant auto* indx_C_bstrides =
indx_B_bstrides + params->batch_ndim;
auto indx_offset_C = elem_to_loc(
tid.z, batch_shape, indx_C_bstrides, params->batch_ndim);
@ -120,18 +119,18 @@ template <
// Translate indices to offsets
int batch_ndim_A = operand_batch_ndim.x;
const constant int* batch_shape_A = operand_shape;
const constant size_t* batch_strides_A = operand_strides;
const constant auto* batch_strides_A = operand_strides;
A += elem_to_loc(indx_A, batch_shape_A, batch_strides_A, batch_ndim_A);

int batch_ndim_B = operand_batch_ndim.y;
const constant int* batch_shape_B = batch_shape_A + batch_ndim_A;
const constant size_t* batch_strides_B = batch_strides_A + batch_ndim_A;
const constant auto* batch_strides_B = batch_strides_A + batch_ndim_A;
B += elem_to_loc(indx_B, batch_shape_B, batch_strides_B, batch_ndim_B);

if (use_out_source) {
int batch_ndim_C = operand_batch_ndim.z;
const constant int* batch_shape_C = batch_shape_B + batch_ndim_B;
const constant size_t* batch_strides_C = batch_strides_B + batch_ndim_B;
const constant auto* batch_strides_C = batch_strides_B + batch_ndim_B;
C += elem_to_loc(indx_C, batch_shape_C, batch_strides_C, batch_ndim_C);
}

@ -140,8 +139,8 @@ template <
// Handle regular batch
else {
if (has_batch) {
const constant size_t* A_bstrides = batch_strides;
const constant size_t* B_bstrides = batch_strides + params->batch_ndim;
const constant auto* A_bstrides = batch_strides;
const constant auto* B_bstrides = batch_strides + params->batch_ndim;

ulong2 batch_offsets = elem_to_loc_broadcast(
tid.z, batch_shape, A_bstrides, B_bstrides, params->batch_ndim);
@ -150,7 +149,7 @@ template <
B += batch_offsets.y;

if (use_out_source) {
const constant size_t* C_bstrides = B_bstrides + params->batch_ndim;
const constant auto* C_bstrides = B_bstrides + params->batch_ndim;
C += elem_to_loc(tid.z, batch_shape, C_bstrides, params->batch_ndim);
}
} else {
|
@ -7,26 +7,10 @@
#include "mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_fused.h"

#define instantiate_gemm(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn) \
template [[host_name("steel_gemm_fused_" #tname "_" #iname "_" #oname "_bm" #bm "_bn" #bn "_bk" #bk "_wm" #wm "_wn" #wn)]] \
[[kernel]] void gemm<itype, bm, bn, bk, wm, wn, trans_a, trans_b, float>( \
const device itype *A [[buffer(0)]], \
const device itype *B [[buffer(1)]], \
const device itype *C [[buffer(2), function_constant(use_out_source)]], \
device itype *D [[buffer(3)]], \
const constant GEMMParams* params [[buffer(4)]], \
const constant GEMMAddMMParams* addmm_params [[buffer(5), function_constant(use_out_source)]], \
const constant int* batch_shape [[buffer(6)]], \
const constant size_t* batch_strides [[buffer(7)]], \
const constant uint32_t* lhs_indices [[buffer(10), function_constant(do_gather)]], \
const constant uint32_t* rhs_indices [[buffer(11), function_constant(do_gather)]], \
const constant uint32_t* C_indices [[buffer(12), function_constant(gather_bias)]], \
const constant int* operand_shape [[buffer(13), function_constant(do_gather)]], \
const constant size_t* operand_strides [[buffer(14), function_constant(do_gather)]], \
const constant packed_int3& operand_batch_ndim [[buffer(15), function_constant(do_gather)]], \
uint simd_lane_id [[thread_index_in_simdgroup]], \
uint simd_group_id [[simdgroup_index_in_threadgroup]], \
uint3 tid [[threadgroup_position_in_grid]], \
uint3 lid [[thread_position_in_threadgroup]]);
instantiate_kernel( \
"steel_gemm_fused_" #tname "_" #iname "_" #oname \
"_bm" #bm "_bn" #bn "_bk" #bk "_wm" #wm "_wn" #wn, \
gemm, itype, bm, bn, bk, wm, wn, trans_a, trans_b, float)

#define instantiate_gemm_transpose_helper(iname, itype, oname, otype, bm, bn, bk, wm, wn) \
instantiate_gemm(nn, false, false, iname, itype, oname, otype, bm, bn, bk, wm, wn) \
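
The change above collapses the hand-written instantiation into `instantiate_kernel`, which lives in `mlx/backend/metal/kernels/utils.h`. It expands, roughly, to the same named template instantiation the old macro spelled out by hand — this is a sketch; see that header for the authoritative definition:

// Approximate shape of the helper: one [[kernel]] instantiation of `func`
// registered under the unique [[host_name]] `name`.
#define instantiate_kernel(name, func, ...) \
  template [[host_name(name)]] [[kernel]] decltype(func<__VA_ARGS__>) \
      func<__VA_ARGS__>;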
|
@ -56,7 +56,7 @@ block_masked_gemm(
device T* D [[buffer(3)]],
const constant GEMMParams* params [[buffer(4)]],
const constant int* batch_shape [[buffer(6)]],
const constant size_t* batch_strides [[buffer(7)]],
const constant int64_t* batch_strides [[buffer(7)]],
const device out_mask_t* out_mask [[buffer(10)]],
const device op_mask_t* lhs_mask [[buffer(11)]],
const device op_mask_t* rhs_mask [[buffer(12)]],
@ -104,7 +104,7 @@ block_masked_gemm(
return;
}

const constant size_t* mask_batch_strides =
const constant auto* mask_batch_strides =
batch_strides + 2 * params->batch_ndim;

if (params->batch_ndim > 1) {
@ -116,8 +116,8 @@ block_masked_gemm(
}

if (has_operand_mask) {
const constant size_t* mask_strides_lhs = mask_batch_strides;
const constant size_t* mask_strides_rhs =
const constant auto* mask_strides_lhs = mask_batch_strides;
const constant auto* mask_strides_rhs =
mask_strides_lhs + params->batch_ndim;

ulong2 batch_offsets = elem_to_loc_broadcast(
@ -144,8 +144,8 @@ block_masked_gemm(

// Adjust for batch
if (params->batch_ndim > 1) {
const constant size_t* A_bstrides = batch_strides;
const constant size_t* B_bstrides = batch_strides + params->batch_ndim;
const constant auto* A_bstrides = batch_strides;
const constant auto* B_bstrides = batch_strides + params->batch_ndim;

ulong2 batch_offsets = elem_to_loc_broadcast(
tid.z, batch_shape, A_bstrides, B_bstrides, params->batch_ndim);
@ -442,7 +442,7 @@ block_masked_gemm(
device T* D [[buffer(3)]],
const constant GEMMParams* params [[buffer(4)]],
const constant int* batch_shape [[buffer(6)]],
const constant size_t* batch_strides [[buffer(7)]],
const constant int64_t* batch_strides [[buffer(7)]],
const device bool* out_mask [[buffer(10)]],
const device bool* lhs_mask [[buffer(11)]],
const device bool* rhs_mask [[buffer(12)]],
@ -476,15 +476,15 @@ block_masked_gemm(
}

if (params->batch_ndim > 1) {
const constant size_t* mask_batch_strides =
const constant auto* mask_batch_strides =
batch_strides + 2 * params->batch_ndim;
out_mask +=
elem_to_loc(tid.z, batch_shape, mask_batch_strides, params->batch_ndim);

if (has_operand_mask) {
const constant size_t* mask_strides_lhs =
const constant auto* mask_strides_lhs =
mask_batch_strides + params->batch_ndim;
const constant size_t* mask_strides_rhs =
const constant auto* mask_strides_rhs =
mask_strides_lhs + params->batch_ndim;

ulong2 batch_offsets = elem_to_loc_broadcast(
@ -507,8 +507,8 @@ block_masked_gemm(

// Adjust for batch
if (params->batch_ndim > 1) {
const constant size_t* A_bstrides = batch_strides;
const constant size_t* B_bstrides = batch_strides + params->batch_ndim;
const constant auto* A_bstrides = batch_strides;
const constant auto* B_bstrides = batch_strides + params->batch_ndim;

ulong2 batch_offsets = elem_to_loc_broadcast(
tid.z, batch_shape, A_bstrides, B_bstrides, params->batch_ndim);
|
@ -5,58 +5,45 @@
#include "mlx/backend/metal/kernels/utils.h"
#include "mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_masked.h"

#define instantiate_gemm( \
outmaskname, \
outmasktype, \
opmaskname, \
opmasktype, \
tname, \
trans_a, \
trans_b, \
iname, \
itype, \
oname, \
otype, \
bm, \
bn, \
bk, \
wm, \
wn, \
aname, \
mn_aligned, \
kname, \
k_aligned) \
template [[host_name("steel_gemm_block_outmask_" #outmaskname \
"_opmask_" #opmaskname "_" #tname "_" #iname "_" #oname \
"_bm" #bm "_bn" #bn "_bk" #bk "_wm" #wm "_wn" #wn \
"_MN_" #aname "_K_" #kname)]] [[kernel]] void \
block_masked_gemm< \
itype, \
outmasktype, \
opmasktype, \
bm, \
bn, \
bk, \
wm, \
wn, \
trans_a, \
trans_b, \
mn_aligned, \
k_aligned>( \
const device itype* A [[buffer(0)]], \
const device itype* B [[buffer(1)]], \
device itype* D [[buffer(3)]], \
const constant GEMMParams* params [[buffer(4)]], \
const constant int* batch_shape [[buffer(6)]], \
const constant size_t* batch_strides [[buffer(7)]], \
const device outmasktype* out_mask [[buffer(10)]], \
const device opmasktype* lhs_mask [[buffer(11)]], \
const device opmasktype* rhs_mask [[buffer(12)]], \
const constant int* mask_strides [[buffer(13)]], \
uint simd_lane_id [[thread_index_in_simdgroup]], \
uint simd_group_id [[simdgroup_index_in_threadgroup]], \
uint3 tid [[threadgroup_position_in_grid]], \
uint3 lid [[thread_position_in_threadgroup]]);
#define instantiate_gemm( \
outmaskname, \
outmasktype, \
opmaskname, \
opmasktype, \
tname, \
trans_a, \
trans_b, \
iname, \
itype, \
oname, \
otype, \
bm, \
bn, \
bk, \
wm, \
wn, \
aname, \
mn_aligned, \
kname, \
k_aligned) \
instantiate_kernel( \
"steel_gemm_block_outmask_" #outmaskname \
"_opmask_" #opmaskname "_" #tname "_" #iname "_" #oname \
"_bm" #bm "_bn" #bn "_bk" #bk "_wm" #wm "_wn" #wn \
"_MN_" #aname "_K_" #kname, \
block_masked_gemm, \
itype, \
outmasktype, \
opmasktype, \
bm, \
bn, \
bk, \
wm, \
wn, \
trans_a, \
trans_b, \
mn_aligned, \
k_aligned)

#define instantiate_gemm_mask_helper(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, aname, mn_aligned, kname, k_aligned) \
instantiate_gemm(bool_, bool, bool_, bool, tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, aname, mn_aligned, kname, k_aligned) \
|
@ -5,46 +5,39 @@
#include "mlx/backend/metal/kernels/steel/gemm/gemm.h"
#include "mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_splitk.h"

#define instantiate_gemm( \
tname, \
trans_a, \
trans_b, \
iname, \
itype, \
oname, \
otype, \
bm, \
bn, \
bk, \
wm, \
wn, \
aname, \
mn_aligned, \
kname, \
k_aligned) \
template [[host_name("steel_gemm_splitk_" #tname "_" #iname "_" #oname \
"_bm" #bm "_bn" #bn "_bk" #bk "_wm" #wm "_wn" #wn \
"_MN_" #aname "_K_" #kname)]] [[kernel]] void \
gemm_splitk< \
itype, \
otype, \
bm, \
bn, \
bk, \
wm, \
wn, \
trans_a, \
trans_b, \
mn_aligned, \
k_aligned>( \
const device itype* A [[buffer(0)]], \
const device itype* B [[buffer(1)]], \
device otype* C [[buffer(2)]], \
const constant GEMMSpiltKParams* params [[buffer(3)]], \
uint simd_lane_id [[thread_index_in_simdgroup]], \
uint simd_group_id [[simdgroup_index_in_threadgroup]], \
uint3 tid [[threadgroup_position_in_grid]], \
uint3 lid [[thread_position_in_threadgroup]]);
#define instantiate_gemm( \
tname, \
trans_a, \
trans_b, \
iname, \
itype, \
oname, \
otype, \
bm, \
bn, \
bk, \
wm, \
wn, \
aname, \
mn_aligned, \
kname, \
k_aligned) \
instantiate_kernel( \
"steel_gemm_splitk_" #tname "_" #iname "_" #oname \
"_bm" #bm "_bn" #bn "_bk" #bk "_wm" #wm "_wn" #wn \
"_MN_" #aname "_K_" #kname, \
gemm_splitk, \
itype, \
otype, \
bm, \
bn, \
bk, \
wm, \
wn, \
trans_a, \
trans_b, \
mn_aligned, \
k_aligned)

#define instantiate_gemm_aligned_helper(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn) \
instantiate_gemm(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, taligned, true, taligned, true) \
@ -68,30 +61,13 @@ instantiate_gemm_shapes_helper(float16, half, float32, float);
instantiate_gemm_shapes_helper(bfloat16, bfloat16_t, float32, float);
instantiate_gemm_shapes_helper(float32, float, float32, float);

#define instantiate_accum(oname, otype, aname, atype) \
template [[host_name("steel_gemm_splitk_accum_" #oname \
"_" #aname)]] [[kernel]] void \
gemm_splitk_accum<atype, otype>( \
const device atype* C_split [[buffer(0)]], \
device otype* D [[buffer(1)]], \
const constant int& k_partitions [[buffer(2)]], \
const constant int& partition_stride [[buffer(3)]], \
const constant int& ldd [[buffer(4)]], \
uint2 gid [[thread_position_in_grid]]); \
template [[host_name("steel_gemm_splitk_accum_" #oname "_" #aname \
"_axbpy")]] [[kernel]] void \
gemm_splitk_accum_axpby<atype, otype>( \
const device atype* C_split [[buffer(0)]], \
device otype* D [[buffer(1)]], \
const constant int& k_partitions [[buffer(2)]], \
const constant int& partition_stride [[buffer(3)]], \
const constant int& ldd [[buffer(4)]], \
const device otype* C [[buffer(5)]], \
const constant int& ldc [[buffer(6)]], \
const constant int& fdc [[buffer(7)]], \
const constant float& alpha [[buffer(8)]], \
const constant float& beta [[buffer(9)]], \
uint2 gid [[thread_position_in_grid]]);
#define instantiate_accum(oname, otype, aname, atype) \
instantiate_kernel( \
"steel_gemm_splitk_accum_" #oname "_" #aname, \
gemm_splitk_accum, atype, otype) \
instantiate_kernel( \
"steel_gemm_splitk_accum_" #oname "_" #aname "_axbpy", \
gemm_splitk_accum_axpby, atype, otype) \

instantiate_accum(bfloat16, bfloat16_t, float32, float);
instantiate_accum(float16, half, float32, float);
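
For readers unfamiliar with split-k GEMM: `gemm_splitk` writes one partial product per k-partition into `C_split`, and the accumulation kernels instantiated above sum those partitions into `D` (the `_axbpy` variant additionally applies `alpha`/`beta` against `C`). A CPU reference sketch of the plain accumulation, reconstructed from the parameter list rather than the kernel source:

#include <cstdint>

// Each output element is the sum of its k_partitions partial results.
void splitk_accum_ref(const float* C_split, float* D, int k_partitions,
                      int64_t partition_stride, int ldd, int M, int N) {
  for (int m = 0; m < M; ++m) {
    for (int n = 0; n < N; ++n) {
      int64_t elem = int64_t(m) * ldd + n;
      float acc = 0.f;
      for (int k = 0; k < k_partitions; ++k) {
        acc += C_split[int64_t(k) * partition_stride + elem];
      }
      D[elem] = acc;
    }
  }
}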
|
@ -21,9 +21,9 @@ struct GEMMParams {
const int tiles_n;
const int tiles_m;

const size_t batch_stride_a;
const size_t batch_stride_b;
const size_t batch_stride_d;
const int64_t batch_stride_a;
const int64_t batch_stride_b;
const int64_t batch_stride_d;

const int swizzle_log;
const int gemm_k_iterations_aligned;
@ -54,7 +54,7 @@ struct GEMMAddMMParams {
const int ldc;
const int fdc;

const size_t batch_stride_c;
const int64_t batch_stride_c;

const float alpha;
const float beta;
|
@ -7,8 +7,8 @@
METAL_FUNC ulong2 elem_to_loc_broadcast(
uint elem,
constant const int* shape,
constant const size_t* a_strides,
constant const size_t* b_strides,
constant const int64_t* a_strides,
constant const int64_t* b_strides,
int ndim) {
ulong loc_a{0};
ulong loc_b{0};
@ -24,9 +24,9 @@ METAL_FUNC ulong2 elem_to_loc_broadcast(
METAL_FUNC ulong3 elem_to_loc_broadcast(
uint elem,
constant const int* shape,
constant const size_t* a_strides,
constant const size_t* b_strides,
constant const size_t* c_strides,
constant const int64_t* a_strides,
constant const int64_t* b_strides,
constant const int64_t* c_strides,
int ndim) {
ulong loc_a{0};
ulong loc_b{0};
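
Both overloads above resolve broadcasting by walking a shared `shape` while accumulating a separate offset per operand; a broadcast dimension simply carries stride 0. A plain-C++ reconstruction of the two-operand loop, for illustration (the Metal source may differ in minor details):

#include <cstdint>
#include <utility>

std::pair<uint64_t, uint64_t> elem_to_loc_broadcast_ref(
    uint32_t elem, const int* shape,
    const int64_t* a_strides, const int64_t* b_strides, int ndim) {
  uint64_t loc_a = 0, loc_b = 0;
  // Peel indices from the innermost dimension outwards.
  for (int i = ndim - 1; i >= 0 && elem > 0; --i) {
    int pos = elem % shape[i];
    elem /= shape[i];
    loc_a += uint64_t(pos * a_strides[i]);
    loc_b += uint64_t(pos * b_strides[i]);  // stride 0 where broadcast
  }
  return {loc_a, loc_b};
}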
|
@ -18,72 +18,72 @@ template <typename T, typename Op>
device T* d,
uint2 index [[thread_position_in_grid]],
uint2 grid_dim [[threads_per_grid]]) {
size_t offset = index.x + grid_dim.x * size_t(index.y);
auto offset = index.x + grid_dim.x * int64_t(index.y);
d[offset] = Op()(a[offset], b[offset], c[offset]);
}

template <typename T, typename Op, typename IdxT = size_t>
template <typename T, typename Op, typename IdxT = int64_t>
[[kernel]] void ternary_g_nd1(
device const bool* a,
device const T* b,
device const T* c,
device T* d,
constant const size_t& a_strides,
constant const size_t& b_strides,
constant const size_t& c_strides,
constant const int64_t& a_strides,
constant const int64_t& b_strides,
constant const int64_t& c_strides,
uint index [[thread_position_in_grid]]) {
auto a_idx = elem_to_loc_1<size_t, IdxT>(index, a_strides);
auto b_idx = elem_to_loc_1<size_t, IdxT>(index, b_strides);
auto c_idx = elem_to_loc_1<size_t, IdxT>(index, c_strides);
auto a_idx = elem_to_loc_1<IdxT>(index, a_strides);
auto b_idx = elem_to_loc_1<IdxT>(index, b_strides);
auto c_idx = elem_to_loc_1<IdxT>(index, c_strides);
d[index] = Op()(a[a_idx], b[b_idx], c[c_idx]);
}

template <typename T, typename Op, typename IdxT = size_t>
template <typename T, typename Op, typename IdxT = int64_t>
[[kernel]] void ternary_g_nd2(
device const bool* a,
device const T* b,
device const T* c,
device T* d,
constant const size_t a_strides[2],
constant const size_t b_strides[2],
constant const size_t c_strides[2],
constant const int64_t a_strides[2],
constant const int64_t b_strides[2],
constant const int64_t c_strides[2],
uint2 index [[thread_position_in_grid]],
uint2 grid_dim [[threads_per_grid]]) {
auto a_idx = elem_to_loc_2<size_t, IdxT>(index, a_strides);
auto b_idx = elem_to_loc_2<size_t, IdxT>(index, b_strides);
auto c_idx = elem_to_loc_2<size_t, IdxT>(index, c_strides);
auto a_idx = elem_to_loc_2<IdxT>(index, a_strides);
auto b_idx = elem_to_loc_2<IdxT>(index, b_strides);
auto c_idx = elem_to_loc_2<IdxT>(index, c_strides);
IdxT out_idx = index.x + IdxT(grid_dim.x) * index.y;
d[out_idx] = Op()(a[a_idx], b[b_idx], c[c_idx]);
}

template <typename T, typename Op, typename IdxT = size_t>
template <typename T, typename Op, typename IdxT = int64_t>
[[kernel]] void ternary_g_nd3(
device const bool* a,
device const T* b,
device const T* c,
device T* d,
constant const size_t a_strides[3],
constant const size_t b_strides[3],
constant const size_t c_strides[3],
constant const int64_t a_strides[3],
constant const int64_t b_strides[3],
constant const int64_t c_strides[3],
uint3 index [[thread_position_in_grid]],
uint3 grid_dim [[threads_per_grid]]) {
auto a_idx = elem_to_loc_3<size_t, IdxT>(index, a_strides);
auto b_idx = elem_to_loc_3<size_t, IdxT>(index, b_strides);
auto c_idx = elem_to_loc_3<size_t, IdxT>(index, c_strides);
auto a_idx = elem_to_loc_3<IdxT>(index, a_strides);
auto b_idx = elem_to_loc_3<IdxT>(index, b_strides);
auto c_idx = elem_to_loc_3<IdxT>(index, c_strides);
IdxT out_idx = index.x + grid_dim.x * (index.y + IdxT(grid_dim.y) * index.z);
d[out_idx] = Op()(a[a_idx], b[b_idx], c[c_idx]);
}

template <typename T, typename Op, int N = 1, typename IdxT = size_t>
template <typename T, typename Op, int N = 1, typename IdxT = int64_t>
[[kernel]] void ternary_g(
device const bool* a,
device const T* b,
device const T* c,
device T* d,
constant const int* shape,
constant const size_t* a_strides,
constant const size_t* b_strides,
constant const size_t* c_strides,
constant const int64_t* a_strides,
constant const int64_t* b_strides,
constant const int64_t* c_strides,
constant const int& ndim,
uint3 index [[thread_position_in_grid]],
uint3 grid_dim [[threads_per_grid]]) {
|
@ -14,7 +14,7 @@ template <typename T, typename U, typename Op>
device U* out,
uint2 index [[thread_position_in_grid]],
uint2 grid_dim [[threads_per_grid]]) {
size_t offset = index.x + grid_dim.x * size_t(index.y);
auto offset = index.x + grid_dim.x * int64_t(index.y);
out[offset] = Op()(in[offset]);
}

@ -23,16 +23,16 @@ template <
typename U,
typename Op,
int N = 1,
typename IdxT = size_t>
typename IdxT = int64_t>
[[kernel]] void unary_g(
device const T* in,
device U* out,
constant const int* in_shape,
constant const size_t* in_strides,
constant const int64_t* in_strides,
device const int& ndim,
uint3 index [[thread_position_in_grid]],
uint3 grid_dim [[threads_per_grid]]) {
auto idx = elem_to_loc<size_t, IdxT>(
auto idx = elem_to_loc<IdxT>(
{N * index.x, index.y, index.z}, in_shape, in_strides, ndim);
auto xshape = in_shape[ndim - 1];
IdxT xstride = in_strides[ndim - 1];
|
@ -89,11 +89,11 @@ struct Limits<complex64_t> {
///////////////////////////////////////////////////////////////////////////////
// Single Array with generic dims

template <typename StrideT, typename IdxT = StrideT>
template <typename IdxT = int64_t>
METAL_FUNC IdxT elem_to_loc(
uint elem,
constant const int* shape,
constant const StrideT* strides,
constant const int64_t* strides,
int ndim) {
IdxT loc = 0;
for (int i = ndim - 1; i >= 0 && elem > 0; --i) {
@ -103,11 +103,11 @@ METAL_FUNC IdxT elem_to_loc(
return loc;
}

template <typename StrideT, typename IdxT = StrideT>
template <typename IdxT = int64_t>
METAL_FUNC IdxT elem_to_loc(
StrideT elem,
int64_t elem,
constant const int* shape,
constant const StrideT* strides,
constant const int64_t* strides,
int ndim) {
IdxT loc = 0;
for (int i = ndim - 1; i >= 0 && elem > 0; --i) {
@ -118,11 +118,11 @@ METAL_FUNC IdxT elem_to_loc(
}

// Non templated version to handle arbitrary dims
template <typename StrideT, typename IdxT = StrideT>
template <typename IdxT = int64_t>
METAL_FUNC IdxT elem_to_loc(
uint3 elem,
constant const int* shape,
constant const StrideT* strides,
constant const int64_t* strides,
int ndim) {
IdxT loc =
elem.x * IdxT(strides[ndim - 1]) + elem.y * IdxT(strides[ndim - 2]);
@ -136,18 +136,18 @@ METAL_FUNC IdxT elem_to_loc(
///////////////////////////////////////////////////////////////////////////////
// Single Array with fixed N dims

template <typename StrideT, typename IdxT = StrideT>
METAL_FUNC IdxT elem_to_loc_1(uint elem, constant const StrideT& stride) {
template <typename IdxT = int64_t>
METAL_FUNC IdxT elem_to_loc_1(uint elem, constant const int64_t& stride) {
return elem * IdxT(stride);
}

template <typename StrideT, typename IdxT = StrideT>
METAL_FUNC IdxT elem_to_loc_2(uint2 elem, constant const StrideT strides[2]) {
template <typename IdxT = int64_t>
METAL_FUNC IdxT elem_to_loc_2(uint2 elem, constant const int64_t strides[2]) {
return elem.x * IdxT(strides[1]) + elem.y * IdxT(strides[0]);
}

template <typename StrideT, typename IdxT = StrideT>
METAL_FUNC IdxT elem_to_loc_3(uint3 elem, constant const StrideT strides[3]) {
template <typename IdxT = int64_t>
METAL_FUNC IdxT elem_to_loc_3(uint3 elem, constant const int64_t strides[3]) {
return elem.x * IdxT(strides[2]) + elem.y * IdxT(strides[1]) +
elem.z * IdxT(strides[0]);
}
@ -155,12 +155,12 @@ METAL_FUNC IdxT elem_to_loc_3(uint3 elem, constant const StrideT strides[3]) {
///////////////////////////////////////////////////////////////////////////////
// Multiple Arrays with generic dims

template <typename StrideT, typename IdxT = StrideT>
template <typename IdxT = int64_t>
METAL_FUNC vec<IdxT, 2> elem_to_loc_2_nd(
uint3 elem,
constant const int* shape,
constant const StrideT* a_strides,
constant const StrideT* b_strides,
constant const int64_t* a_strides,
constant const int64_t* b_strides,
int ndim) {
vec<IdxT, 2> loc = {
IdxT(
@ -178,13 +178,13 @@ METAL_FUNC vec<IdxT, 2> elem_to_loc_2_nd(
return loc;
}

template <typename IdxT = size_t>
template <typename IdxT = int64_t>
METAL_FUNC vec<IdxT, 3> elem_to_loc_3_nd(
uint3 elem,
constant const int* shape,
constant const size_t* a_strides,
constant const size_t* b_strides,
constant const size_t* c_strides,
constant const int64_t* a_strides,
constant const int64_t* b_strides,
constant const int64_t* c_strides,
int ndim) {
vec<IdxT, 3> loc = {
elem.x * IdxT(a_strides[ndim - 1]) + elem.y * IdxT(a_strides[ndim - 2]),
@ -213,7 +213,7 @@ struct LoopedElemToLoc {

LoopedElemToLoc(int dim) : dim(dim), inner_looper(dim - 1) {}

void next(const constant int* shape, const constant size_t* strides) {
void next(const constant int* shape, const constant int64_t* strides) {
if (dim == 0) {
return;
}
@ -226,7 +226,7 @@ struct LoopedElemToLoc {
}
}

void next(int n, const constant int* shape, const constant size_t* strides) {
void next(int n, const constant int* shape, const constant int64_t* strides) {
if (dim == 0) {
return;
}
@ -262,19 +262,19 @@ struct LoopedElemToLoc<1, OffsetT, true> {

LoopedElemToLoc(int dim) : dim(dim) {}

void next(const constant int* shape, const constant size_t* strides) {
void next(const constant int* shape, const constant int64_t* strides) {
index++;
if (dim > 1) {
offset = elem_to_loc<size_t, OffsetT>(index, shape, strides, dim);
offset = elem_to_loc<OffsetT>(index, shape, strides, dim);
} else {
offset += OffsetT(strides[0]);
}
}

void next(int n, const constant int* shape, const constant size_t* strides) {
void next(int n, const constant int* shape, const constant int64_t* strides) {
index += n;
if (dim > 1) {
offset = elem_to_loc<size_t, OffsetT>(index, shape, strides, dim);
offset = elem_to_loc<OffsetT>(index, shape, strides, dim);
} else {
offset = index * OffsetT(strides[0]);
}
@ -291,11 +291,11 @@ struct LoopedElemToLoc<1, OffsetT, false> {

LoopedElemToLoc(int) {}

void next(const constant int*, const constant size_t* strides) {
void next(const constant int*, const constant int64_t* strides) {
offset += OffsetT(strides[0]);
}

void next(int n, const constant int*, const constant size_t* strides) {
void next(int n, const constant int*, const constant int64_t* strides) {
offset += n * OffsetT(strides[0]);
}

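
A quick standalone check of the `elem_to_loc` recurrence above (host C++, address-space qualifiers dropped, loop body reconstructed from the usual definition): a linear element index is peeled dimension by dimension and each digit is weighted by that dimension's stride.

#include <cassert>
#include <cstdint>

int64_t elem_to_loc_ref(uint32_t elem, const int* shape,
                        const int64_t* strides, int ndim) {
  int64_t loc = 0;
  for (int i = ndim - 1; i >= 0 && elem > 0; --i) {
    loc += int64_t(elem % shape[i]) * strides[i];
    elem /= shape[i];
  }
  return loc;
}

int main() {
  int shape[3] = {2, 3, 4};
  int64_t row_major[3] = {12, 4, 1};  // contiguous: loc == elem
  int64_t col_major[3] = {1, 2, 6};   // same shape, transposed layout
  assert(elem_to_loc_ref(17, shape, row_major, 3) == 17);
  assert(elem_to_loc_ref(17, shape, col_major, 3) == 9);  // 1*6 + 1*2 + 1*1
  return 0;
}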
|
@ -21,8 +21,8 @@ namespace {

inline auto collapse_batches(const array& a, const array& b) {
  // Get and check the shape for the batched dims
  std::vector<int> A_bshape{a.shape().begin(), a.shape().end() - 2};
  std::vector<int> B_bshape{b.shape().begin(), b.shape().end() - 2};
  Shape A_bshape{a.shape().begin(), a.shape().end() - 2};
  Shape B_bshape{b.shape().begin(), b.shape().end() - 2};
  if (A_bshape != B_bshape) {
    std::ostringstream msg;
    msg << "[matmul] Got matrices with incorrectly broadcasted shapes: " << "A "
@ -30,8 +30,8 @@ inline auto collapse_batches(const array& a, const array& b) {
    throw std::runtime_error(msg.str());
  }

  std::vector<size_t> A_bstride{a.strides().begin(), a.strides().end() - 2};
  std::vector<size_t> B_bstride{b.strides().begin(), b.strides().end() - 2};
  Strides A_bstride{a.strides().begin(), a.strides().end() - 2};
  Strides B_bstride{b.strides().begin(), b.strides().end() - 2};

  auto [batch_shape, batch_strides] =
      collapse_contiguous_dims(A_bshape, std::vector{A_bstride, B_bstride});
@ -50,9 +50,9 @@ inline auto collapse_batches(const array& a, const array& b) {

inline auto collapse_batches(const array& a, const array& b, const array& c) {
  // Get and check the shape for the batched dims
  std::vector<int> A_bshape{a.shape().begin(), a.shape().end() - 2};
  std::vector<int> B_bshape{b.shape().begin(), b.shape().end() - 2};
  std::vector<int> C_bshape{c.shape().begin(), c.shape().end() - 2};
  Shape A_bshape{a.shape().begin(), a.shape().end() - 2};
  Shape B_bshape{b.shape().begin(), b.shape().end() - 2};
  Shape C_bshape{c.shape().begin(), c.shape().end() - 2};
  if (A_bshape != B_bshape || A_bshape != C_bshape) {
    std::ostringstream msg;
    msg << "[addmm] Got matrices with incorrectly broadcasted shapes: " << "A "
@ -60,9 +60,9 @@ inline auto collapse_batches(const array& a, const array& b, const array& c) {
    throw std::runtime_error(msg.str());
  }

  std::vector<size_t> A_bstride{a.strides().begin(), a.strides().end() - 2};
  std::vector<size_t> B_bstride{b.strides().begin(), b.strides().end() - 2};
  std::vector<size_t> C_bstride{c.strides().begin(), c.strides().end() - 2};
  Strides A_bstride{a.strides().begin(), a.strides().end() - 2};
  Strides B_bstride{b.strides().begin(), b.strides().end() - 2};
  Strides C_bstride{c.strides().begin(), c.strides().end() - 2};

  auto [batch_shape, batch_strides] = collapse_contiguous_dims(
      A_bshape, std::vector{A_bstride, B_bstride, C_bstride});
@ -82,6 +82,25 @@ inline auto collapse_batches(const array& a, const array& b, const array& c) {
      batch_shape, A_batch_stride, B_batch_stride, C_batch_stride);
}

std::tuple<bool, int64_t, array> check_transpose(
    std::vector<array>& copies,
    const Stream& s,
    const array& arr,
    bool is_vector) {
  auto stx = arr.strides()[arr.ndim() - 2];
  auto sty = arr.strides()[arr.ndim() - 1];
  if (sty == 1 && (!is_vector || stx == arr.shape(-1))) {
    return std::make_tuple(false, stx, arr);
  } else if (stx == 1 && (!is_vector || sty == arr.shape(-2))) {
    return std::make_tuple(true, sty, arr);
  } else {
    array arr_copy(arr.shape(), arr.dtype(), nullptr, {});
    copy_gpu(arr, arr_copy, CopyType::General, s);
    copies.push_back(arr_copy);
    return std::make_tuple(false, arr.shape(-1), arr_copy);
  }
};

} // namespace

///////////////////////////////////////////////////////////////////////////////
@ -180,11 +199,11 @@ void steel_matmul_regular(
    int ldd,
    bool transpose_a,
    bool transpose_b,
    std::vector<int> batch_shape,
    std::vector<size_t> batch_strides,
    size_t A_batch_stride,
    size_t B_batch_stride,
    size_t matrix_stride_out,
    Shape batch_shape,
    Strides batch_strides,
    int64_t A_batch_stride,
    int64_t B_batch_stride,
    int64_t matrix_stride_out,
    std::vector<array>& copies) {
  using namespace mlx::steel;

@ -268,9 +287,9 @@ void steel_matmul_regular(
      /* const int ldd = */ ldd,
      /* const int tiles_n = */ tn,
      /* const int tiles_m = */ tm,
      /* const size_t batch_stride_a = */ A_batch_stride,
      /* const size_t batch_stride_b = */ B_batch_stride,
      /* const size_t batch_stride_d = */ matrix_stride_out,
      /* const int64_t batch_stride_a = */ A_batch_stride,
      /* const int64_t batch_stride_b = */ B_batch_stride,
      /* const int64_t batch_stride_d = */ matrix_stride_out,
      /* const int swizzle_log = */ swizzle_log,
      /* const int gemm_k_iterations_aligned = */ (K / bk),
      /* const int batch_ndim = */ int(batch_shape.size())};
@ -314,9 +333,9 @@ void steel_matmul(
    bool transpose_a,
    bool transpose_b,
    std::vector<array>& copies,
    std::vector<int> batch_shape /* = {} */,
    std::vector<size_t> A_batch_stride /* = {} */,
    std::vector<size_t> B_batch_stride /* = {} */) {
    Shape batch_shape /* = {} */,
    Strides A_batch_stride /* = {} */,
    Strides B_batch_stride /* = {} */) {
  using namespace mlx::steel;

  if (batch_shape.empty()) {
@ -447,7 +466,7 @@ void steel_matmul(

  /////////////////////////////////////////////////////////////////////////////
  // Regular kernel dispatch
  std::vector<size_t> batch_strides = A_batch_stride;
  auto batch_strides = A_batch_stride;
  batch_strides.insert(
      batch_strides.end(), B_batch_stride.begin(), B_batch_stride.end());

@ -505,24 +524,8 @@ void Matmul::eval_gpu(const std::vector<array>& inputs, array& out) {
  // Keep a vector with copies to be cleared in the completed buffer to release
  // the arrays
  std::vector<array> copies;
  auto check_transpose = [&copies, &s](const array& arr, bool is_vector) {
    auto stx = arr.strides()[arr.ndim() - 2];
    auto sty = arr.strides()[arr.ndim() - 1];
    if (sty == 1 && (!is_vector || stx == arr.shape(-1))) {
      return std::make_tuple(false, stx, arr);
    } else if (stx == 1 && (!is_vector || sty == arr.shape(-2))) {
      return std::make_tuple(true, sty, arr);
    } else {
      array arr_copy(arr.shape(), arr.dtype(), nullptr, {});
      copy_gpu(arr, arr_copy, CopyType::General, s);
      copies.push_back(arr_copy);
      size_t stx = arr.shape(-1);
      return std::make_tuple(false, stx, arr_copy);
    }
  };

  auto [a_transposed, a_cols, a] = check_transpose(a_pre, M == 1);
  auto [b_transposed, b_cols, b] = check_transpose(b_pre, N == 1);
  auto [a_transposed, a_cols, a] = check_transpose(copies, s, a_pre, M == 1);
  auto [b_transposed, b_cols, b] = check_transpose(copies, s, b_pre, N == 1);

  /////////////////////////////////////////////////////////////////////////////
  // Check and collapse batch dimensions
@ -662,9 +665,9 @@ void Matmul::eval_gpu(const std::vector<array>& inputs, array& out) {
      /* bool transpose_a = */ a_transposed,
      /* bool transpose_b = */ b_transposed,
      /* std::vector<array>& = */ copies,
      /* std::vector<int> batch_shape = */ batch_shape,
      /* std::vector<size_t> A_batch_stride = */ A_batch_stride,
      /* std::vector<size_t> B_batch_stride = */ B_batch_stride);
      /* Shape batch_shape = */ batch_shape,
      /* Strides A_batch_stride = */ A_batch_stride,
      /* Strides B_batch_stride = */ B_batch_stride);
}

void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {
@ -691,24 +694,8 @@ void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {
  // Keep a vector with copies to be cleared in the completed buffer to release
  // the arrays
  std::vector<array> copies;
  auto check_transpose = [&copies, &s](const array& arr, bool is_vector) {
    auto stx = arr.strides()[arr.ndim() - 2];
    auto sty = arr.strides()[arr.ndim() - 1];
    if (sty == 1 && (!is_vector || stx == arr.shape(-1))) {
      return std::make_tuple(false, stx, arr);
    } else if (stx == 1 && (!is_vector || sty == arr.shape(-2))) {
      return std::make_tuple(true, sty, arr);
    } else {
      array arr_copy(arr.shape(), arr.dtype(), nullptr, {});
      copy_gpu(arr, arr_copy, CopyType::General, s);
      copies.push_back(arr_copy);
      size_t stx = arr.shape(-1);
      return std::make_tuple(false, stx, arr_copy);
    }
  };

  auto [transpose_a, a_cols, a] = check_transpose(a_pre, M == 1);
  auto [transpose_b, b_cols, b] = check_transpose(b_pre, N == 1);
  auto [transpose_a, a_cols, a] = check_transpose(copies, s, a_pre, M == 1);
  auto [transpose_b, b_cols, b] = check_transpose(copies, s, b_pre, N == 1);

  array c = c_pre;
  int ldc = c.strides()[c.ndim() - 2];
@ -723,7 +710,7 @@ void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {
  auto [batch_shape, A_batch_stride, B_batch_stride, C_batch_stride] =
      collapse_batches(a, b, c);

  size_t matrix_stride_out = size_t(M) * size_t(N);
  int64_t matrix_stride_out = M * static_cast<int64_t>(N);
  auto batch_size_out = out.size() / (matrix_stride_out);

  // Collapse batches into M if needed
@ -1044,9 +1031,9 @@ void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {
      /* const int ldd = */ N,
      /* const int tiles_n = */ tn,
      /* const int tiles_m = */ tm,
      /* const size_t batch_stride_a = */ A_batch_stride.back(),
      /* const size_t batch_stride_b = */ B_batch_stride.back(),
      /* const size_t batch_stride_d = */ matrix_stride_out,
      /* const int64_t batch_stride_a = */ A_batch_stride.back(),
      /* const int64_t batch_stride_b = */ B_batch_stride.back(),
      /* const int64_t batch_stride_d = */ matrix_stride_out,
      /* const int swizzle_log = */ swizzle_log,
      /* const int gemm_k_iterations_aligned = */ (K / bk),
      /* const int batch_ndim = */ int(batch_shape.size())};
@ -1054,7 +1041,7 @@ void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {
  GEMMAddMMParams params{
      /* const int ldc = */ ldc,
      /* const int fdc = */ fdc,
      /* const size_t batch_stride_c = */ C_batch_stride.back(),
      /* const int64_t batch_stride_c = */ C_batch_stride.back(),
      /* const float alpha = */ alpha_,
      /* const float beta = */ beta_};

@ -1065,7 +1052,7 @@ void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {
  MTL::Size group_dims = MTL::Size(32, wn, wm);
  MTL::Size grid_dims = MTL::Size(tn, tm, batch_size_out);

  std::vector<size_t> batch_strides = A_batch_stride;
  Strides batch_strides = A_batch_stride;
  batch_strides.insert(
      batch_strides.end(), B_batch_stride.begin(), B_batch_stride.end());
  batch_strides.insert(
@ -1120,24 +1107,8 @@ void BlockMaskedMM::eval_gpu(const std::vector<array>& inputs, array& out) {
  // Keep a vector with copies to be cleared in the completed buffer to release
  // the arrays
  std::vector<array> copies;
  auto check_transpose = [&copies, &s](const array& arr, bool is_vector) {
    auto stx = arr.strides()[arr.ndim() - 2];
    auto sty = arr.strides()[arr.ndim() - 1];
    if (sty == 1 && (!is_vector || stx == arr.shape(-1))) {
      return std::make_tuple(false, stx, arr);
    } else if (stx == 1 && (!is_vector || sty == arr.shape(-2))) {
      return std::make_tuple(true, sty, arr);
    } else {
      array arr_copy(arr.shape(), arr.dtype(), nullptr, {});
      copy_gpu(arr, arr_copy, CopyType::General, s);
      copies.push_back(arr_copy);
      size_t stx = arr.shape(-1);
      return std::make_tuple(false, stx, arr_copy);
    }
  };

  auto [transpose_a, a_cols, a] = check_transpose(a_pre, M == 1);
  auto [transpose_b, b_cols, b] = check_transpose(b_pre, N == 1);
  auto [transpose_a, a_cols, a] = check_transpose(copies, s, a_pre, M == 1);
  auto [transpose_b, b_cols, b] = check_transpose(copies, s, b_pre, N == 1);

  int lda = a_cols;
  int ldb = b_cols;
@ -1156,20 +1127,20 @@ void BlockMaskedMM::eval_gpu(const std::vector<array>& inputs, array& out) {
    return decltype(v){v.begin(), v.end() - 2};
  };

  std::vector<int> batch_shape{1};
  std::vector<size_t> A_batch_stride{0};
  std::vector<size_t> B_batch_stride{0};
  std::vector<size_t> outmask_bstride{0};
  std::vector<size_t> Amask_bstride{0};
  std::vector<size_t> Bmask_bstride{0};
  size_t A_batch_str = 0;
  size_t B_batch_str = 0;
  Shape batch_shape{1};
  Strides A_batch_stride{0};
  Strides B_batch_stride{0};
  Strides outmask_bstride{0};
  Strides Amask_bstride{0};
  Strides Bmask_bstride{0};
  int64_t A_batch_str = 0;
  int64_t B_batch_str = 0;

  std::vector<size_t> batch_strides;
  Strides batch_strides;

  if (out.ndim() > 2) {
    std::vector<int> bshape{out.shape().begin(), out.shape().end() - 2};
    std::vector<std::vector<size_t>> bstrides;
    Shape bshape{out.shape().begin(), out.shape().end() - 2};
    std::vector<Strides> bstrides;

    for (auto& arr : inputs) {
      bstrides.emplace_back(arr.strides().begin(), arr.strides().end() - 2);
@ -1196,10 +1167,10 @@ void BlockMaskedMM::eval_gpu(const std::vector<array>& inputs, array& out) {
    }

  } else {
    batch_strides = std::vector<size_t>(inputs.size(), 0);
    batch_strides = Strides(inputs.size(), 0);
  }

  size_t matrix_stride_out = size_t(M) * N;
  int64_t matrix_stride_out = static_cast<int64_t>(M) * N;
  size_t batch_size_out = out.size() / (matrix_stride_out);

  /////////////////////////////////////////////////////////////////////////////
@ -1306,7 +1277,7 @@ void BlockMaskedMM::eval_gpu(const std::vector<array>& inputs, array& out) {

  // Get mask params
  std::vector<int> mask_strides;
  std::vector<size_t> mask_batch_strides;
  Strides mask_batch_strides;
  if (has_out_mask) {
    auto& out_mask = inputs[2];

@ -1436,9 +1407,9 @@ void BlockMaskedMM::eval_gpu(const std::vector<array>& inputs, array& out) {
      /* const int ldd = */ N,
      /* const int tiles_n = */ tn,
      /* const int tiles_m = */ tm,
      /* const size_t batch_stride_a = */ A_batch_str,
      /* const size_t batch_stride_b = */ B_batch_str,
      /* const size_t batch_stride_d = */ matrix_stride_out,
      /* const int64_t batch_stride_a = */ A_batch_str,
      /* const int64_t batch_stride_b = */ B_batch_str,
      /* const int64_t batch_stride_d = */ matrix_stride_out,
      /* const int swizzle_log = */ swizzle_log,
      /* const int gemm_k_iterations_aligned = */ (K / bk),
      /* const int batch_ndim = */ int(batch_shape.size())};
@ -1524,24 +1495,8 @@ void GatherMM::eval_gpu(const std::vector<array>& inputs, array& out) {
  // Keep a vector with copies to be cleared in the completed buffer to release
  // the arrays
  std::vector<array> copies;
  auto check_transpose = [&copies, &s](const array& arr, bool is_vector) {
    auto stx = arr.strides()[arr.ndim() - 2];
    auto sty = arr.strides()[arr.ndim() - 1];
    if (sty == 1 && (!is_vector || stx == arr.shape(-1))) {
      return std::make_tuple(false, stx, arr);
    } else if (stx == 1 && (!is_vector || sty == arr.shape(-2))) {
      return std::make_tuple(true, sty, arr);
    } else {
      array arr_copy(arr.shape(), arr.dtype(), nullptr, {});
      copy_gpu(arr, arr_copy, CopyType::General, s);
      copies.push_back(arr_copy);
      size_t stx = arr.shape(-1);
      return std::make_tuple(false, stx, arr_copy);
    }
  };

  auto [transpose_a, a_cols, a] = check_transpose(a_pre, M == 1);
  auto [transpose_b, b_cols, b] = check_transpose(b_pre, N == 1);
  auto [transpose_a, a_cols, a] = check_transpose(copies, s, a_pre, M == 1);
  auto [transpose_b, b_cols, b] = check_transpose(copies, s, b_pre, N == 1);

  int lda = a_cols;
  int ldb = b_cols;
@ -1556,20 +1511,20 @@ void GatherMM::eval_gpu(const std::vector<array>& inputs, array& out) {
  auto& lhs_indices = inputs[2];
  auto& rhs_indices = inputs[3];

  std::vector<int> batch_shape = get_batch_dims(out.shape());
  std::vector<size_t> batch_strides;
  Shape batch_shape = get_batch_dims(out.shape());
  Strides batch_strides;

  batch_strides.insert(
      batch_strides.end(),
      lhs_indices.strides().begin(),
      lhs_indices.strides().end());
  size_t lhs_indices_str = batch_strides.empty() ? 0 : batch_strides.back();
  auto lhs_indices_str = batch_strides.empty() ? 0 : batch_strides.back();

  batch_strides.insert(
      batch_strides.end(),
      rhs_indices.strides().begin(),
      rhs_indices.strides().end());
  size_t rhs_indices_str = batch_strides.empty() ? 0 : batch_strides.back();
  auto rhs_indices_str = batch_strides.empty() ? 0 : batch_strides.back();

  int batch_ndim = batch_shape.size();

@ -1582,10 +1537,10 @@ void GatherMM::eval_gpu(const std::vector<array>& inputs, array& out) {
  int batch_ndim_B = b.ndim() - 2;
  std::vector<int> operand_batch_ndim = {batch_ndim_A, batch_ndim_B};

  std::vector<int> batch_shape_A = get_batch_dims(a.shape());
  std::vector<size_t> batch_strides_A = get_batch_dims(a.strides());
  std::vector<int> batch_shape_B = get_batch_dims(b.shape());
  std::vector<size_t> batch_strides_B = get_batch_dims(b.strides());
  Shape batch_shape_A = get_batch_dims(a.shape());
  Strides batch_strides_A = get_batch_dims(a.strides());
  Shape batch_shape_B = get_batch_dims(b.shape());
  Strides batch_strides_B = get_batch_dims(b.strides());

  if (batch_ndim_A == 0) {
    batch_shape_A = {1};
@ -1597,7 +1552,7 @@ void GatherMM::eval_gpu(const std::vector<array>& inputs, array& out) {
    batch_strides_B = {0};
  }

  size_t matrix_stride_out = size_t(M) * N;
  auto matrix_stride_out = static_cast<int64_t>(M) * N;
  auto batch_size_out = out.size() / matrix_stride_out;

  /////////////////////////////////////////////////////////////////////////////
@ -1801,9 +1756,9 @@ void GatherMM::eval_gpu(const std::vector<array>& inputs, array& out) {
      /* const int ldd = */ N,
      /* const int tiles_n = */ tn,
      /* const int tiles_m = */ tm,
      /* const size_t batch_stride_a = */ lhs_indices_str,
      /* const size_t batch_stride_b = */ rhs_indices_str,
      /* const size_t batch_stride_d = */ matrix_stride_out,
      /* const int64_t batch_stride_a = */ lhs_indices_str,
      /* const int64_t batch_stride_b = */ rhs_indices_str,
      /* const int64_t batch_stride_d = */ matrix_stride_out,
      /* const int swizzle_log = */ swizzle_log,
      /* const int gemm_k_iterations_aligned = */ (K / bk),
      /* const int batch_ndim = */ batch_ndim};

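Both collapse_batches overloads defer to collapse_contiguous_dims, which fuses adjacent batch dimensions whenever the outer stride equals the inner extent times the inner stride in every operand, so fewer batch dimensions reach the kernel. A simplified single-array sketch of that fusion rule (the real helper processes several stride lists at once):

    #include <cstdint>
    #include <utility>
    #include <vector>

    using Shape = std::vector<int>;
    using Strides = std::vector<int64_t>;

    // Fuse dim i into its predecessor when the predecessor's stride
    // equals shape[i] * strides[i], i.e. the two dims form one dense run.
    std::pair<Shape, Strides> collapse_contiguous_dims_1(
        const Shape& shape,
        const Strides& strides) {
      Shape out_shape;
      Strides out_strides;
      for (size_t i = 0; i < shape.size(); ++i) {
        if (!out_shape.empty() &&
            out_strides.back() == static_cast<int64_t>(shape[i]) * strides[i]) {
          out_shape.back() *= shape[i];     // fuse into the previous dim
          out_strides.back() = strides[i];  // keep the inner stride
        } else {
          out_shape.push_back(shape[i]);
          out_strides.push_back(strides[i]);
        }
      }
      return {out_shape, out_strides};
    }
    // e.g. shape {2, 3, 4}, strides {12, 4, 1} collapses to {24} / {1}
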
@ -21,11 +21,11 @@ void steel_matmul_regular(
    int ldd,
    bool transpose_a,
    bool transpose_b,
    std::vector<int> batch_shape,
    std::vector<size_t> batch_strides,
    size_t A_batch_stride,
    size_t B_batch_stride,
    size_t matrix_stride_out,
    Shape batch_shape,
    Strides batch_strides,
    int64_t A_batch_stride,
    int64_t B_batch_stride,
    int64_t matrix_stride_out,
    std::vector<array>& copies);

void steel_matmul(
@ -43,8 +43,8 @@ void steel_matmul(
    bool transpose_a,
    bool transpose_b,
    std::vector<array>& copies,
    std::vector<int> batch_shape = {},
    std::vector<size_t> A_batch_stride = {},
    std::vector<size_t> B_batch_stride = {});
    Shape batch_shape = {},
    Strides A_batch_stride = {},
    Strides B_batch_stride = {});

} // namespace mlx::core

@ -5,6 +5,7 @@
#include <sstream>

#include "mlx/backend/common/load.h"
#include "mlx/backend/common/slicing.h"
#include "mlx/backend/common/utils.h"
#include "mlx/backend/metal/copy.h"
#include "mlx/backend/metal/device.h"
@ -101,10 +102,10 @@ void ArgReduce::eval_gpu(const std::vector<array>& inputs, array& out) {
  }

  // Prepare the shapes, strides and axis arguments.
  std::vector<size_t> in_strides = in.strides();
  std::vector<int> shape = in.shape();
  std::vector<size_t> out_strides = out.strides();
  size_t axis_stride = in_strides[axis_];
  auto in_strides = in.strides();
  auto shape = in.shape();
  auto out_strides = out.strides();
  auto axis_stride = in_strides[axis_];
  size_t axis_size = shape[axis_];
  if (out_strides.size() == in_strides.size()) {
    out_strides.erase(out_strides.begin() + axis_);
@ -136,7 +137,7 @@ void ArgReduce::eval_gpu(const std::vector<array>& inputs, array& out) {
  if (ndim == 0) {
    // Pass place holders so metal doesn't complain
    int shape_ = 0;
    size_t stride_ = 0;
    int64_t stride_ = 0;
    compute_encoder.set_bytes(shape_, 2);
    compute_encoder.set_bytes(stride_, 3);
    compute_encoder.set_bytes(stride_, 4);
@ -311,13 +312,12 @@ void Reshape::eval_gpu(const std::vector<array>& inputs, array& out) {

  if (copy_necessary) {
    out.set_data(allocator::malloc_or_wait(out.nbytes()));
    auto out_strides = make_contiguous_strides<size_t>(in.shape());
    copy_gpu_inplace(
        in,
        out,
        in.shape(),
        in.strides(),
        out_strides,
        make_contiguous_strides(in.shape()),
        0,
        0,
        CopyType::General,
@ -366,16 +366,15 @@ void SliceUpdate::eval_gpu(const std::vector<array>& inputs, array& out) {
  copy_gpu(in, out, in.data_size() == 1 ? CopyType::Scalar : ctype, stream());

  // Calculate out strides, initial offset and if copy needs to be made
  auto [data_offset, out_strides] = prepare_slice(out);
  auto [data_offset, out_strides] = prepare_slice(in, start_indices_, strides_);

  // Do copy
  std::vector<int64_t> upd_strides{upd.strides().begin(), upd.strides().end()};
  copy_gpu_inplace<int64_t>(
  copy_gpu_inplace(
      /* const array& src = */ upd,
      /* array& dst = */ out,
      /* const std::vector<int>& data_shape = */ upd.shape(),
      /* const std::vector<stride_t>& i_strides = */ upd_strides,
      /* const std::vector<stride_t>& o_strides = */ out_strides,
      /* const Shape& data_shape = */ upd.shape(),
      /* const Strides& i_strides = */ upd.strides(),
      /* const Strides& o_strides = */ out_strides,
      /* int64_t i_offset = */ 0,
      /* int64_t o_offset = */ data_offset,
      /* CopyType ctype = */ CopyType::GeneralGeneral,

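The Reshape path above now calls the untemplated make_contiguous_strides, which builds row-major strides for a given shape. A short sketch of what that helper computes, with Shape and Strides assumed to be the int/int64_t vector aliases used throughout this change:

    #include <cstdint>
    #include <vector>

    using Shape = std::vector<int>;
    using Strides = std::vector<int64_t>;

    // Row-major (C order) strides for a dense array: the last dim has
    // stride 1 and each dim's stride is the product of the extents to
    // its right.
    Strides make_contiguous_strides(const Shape& shape) {
      Strides strides(shape.size(), 1);
      for (int i = static_cast<int>(shape.size()) - 1; i > 0; --i) {
        strides[i - 1] = strides[i] * shape[i];
      }
      return strides;
    }
    // e.g. shape {2, 3, 4} -> strides {12, 4, 1}
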
@ -18,13 +18,13 @@ namespace {

struct RowReduceArgs {
  // Input shape and strides not including the reduction axes
  std::vector<int> shape;
  std::vector<size_t> strides;
  Shape shape;
  Strides strides;
  int ndim;

  // Input shape and strides for the reduction axes
  std::vector<int> reduce_shape;
  std::vector<size_t> reduce_strides;
  Shape reduce_shape;
  Strides reduce_strides;
  int reduce_ndim;

  // The number of rows we are reducing. Namely prod(reduce_shape).
@ -88,13 +88,13 @@ struct RowReduceArgs {

struct ColReduceArgs {
  // Input shape and strides not including the reduction axes
  std::vector<int> shape;
  std::vector<size_t> strides;
  Shape shape;
  Strides strides;
  int ndim;

  // Input shape and strides for the reduction axes
  std::vector<int> reduce_shape;
  std::vector<size_t> reduce_strides;
  Shape reduce_shape;
  Strides reduce_strides;
  int reduce_ndim;

  // The number of column reductions we are doing. Namely prod(reduce_shape).
@ -102,7 +102,7 @@ struct ColReduceArgs {

  // The size of the contiguous column reduction.
  size_t reduction_size;
  size_t reduction_stride;
  int64_t reduction_stride;

  ColReduceArgs(
      const array& in,
@ -126,7 +126,7 @@ struct ColReduceArgs {
    // yet we may have removed the appropriate amount of elements. It is safe
    // to compute the stride by multiplying shapes (while < reduction_stride)
    // because it is a contiguous section.
    size_t stride_back = 1;
    int64_t stride_back = 1;
    std::tie(shape, strides) = shapes_without_reduction_axes(in, axes);
    while (!shape.empty() && stride_back < reduction_stride) {
      stride_back *= shape.back();
@ -683,7 +683,7 @@ void strided_reduce_longcolumn(
      op_name,
      in_type,
      out_type,
      large ? "size_t" : "uint",
      large ? "int64_t" : "uint",
      n);
  compute_encoder.set_compute_pipeline_state(kernel);

@ -718,7 +718,7 @@ void strided_reduce_longcolumn(
      op_name,
      intermediate.dtype(),
      out_type,
      large ? "size_t" : "uint",
      large ? "int64_t" : "uint",
      1,
      32,
      32);
@ -782,7 +782,7 @@ void strided_reduce_looped(
      op_name,
      in_type,
      out_type,
      large ? "size_t" : "uint",
      large ? "int64_t" : "uint",
      n,
      BM,
      BN);
@ -859,7 +859,7 @@ void strided_reduce_2pass(
      op_name,
      in_type,
      out_type,
      large ? "size_t" : "uint",
      large ? "int64_t" : "uint",
      n,
      BM,
      BN);
@ -894,7 +894,7 @@ void strided_reduce_2pass(
      op_name,
      intermediate.dtype(),
      out_type,
      large ? "size_t" : "uint",
      large ? "int64_t" : "uint",
      1,
      32,
      32);

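The stride_back walk in ColReduceArgs recovers the column stride by peeling trailing extents until their product reaches reduction_stride; this is sound only because that trailing section is contiguous, as the comment in the code notes. A standalone sketch of the same walk (popping the consumed dimension is assumed from context, since the loop body is truncated in this hunk):

    #include <cstdint>
    #include <vector>

    // Peel trailing dims off `shape`, accumulating their product, until
    // it covers `reduction_stride`. Only valid for a contiguous tail.
    int64_t strip_trailing_dims(
        std::vector<int>& shape,
        int64_t reduction_stride) {
      int64_t stride_back = 1;
      while (!shape.empty() && stride_back < reduction_stride) {
        stride_back *= shape.back();
        shape.pop_back();  // assumed: the consumed dim is removed
      }
      return stride_back;
    }
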
@ -50,17 +50,17 @@ void sdpa_full_self_attention_metal(

  std::ostringstream kname;
  // clang-format off
  kname << "steel_attention_"
        << type_to_name(q)
        << "_bq" << bq
        << "_bk" << bk
        << "_bd" << bd
        << "_wm" << wm << "_wn" << wn; // clang-format on

  std::string base_name = kname.str();

  // clang-format off
  kname << "_align_Q_" << (align_Q ? 't' : 'n')
        << "_align_K_" << (align_K ? 't' : 'n'); // clang-format on

  std::string hash_name = kname.str();
@ -92,10 +92,10 @@ void sdpa_full_self_attention_metal(
      /* int NQ_aligned = */ NQ_aligned,
      /* int NK_aligned = */ NK_aligned,

      /* size_t Q_strides[3] = */ {q.strides(0), q.strides(1), q.strides(2)},
      /* size_t K_strides[3] = */ {k.strides(0), k.strides(1), k.strides(2)},
      /* size_t V_strides[3] = */ {v.strides(0), v.strides(1), v.strides(2)},
      /* size_t O_strides[3] = */ {o.strides(0), o.strides(1), o.strides(2)}};
      /* int64_t Q_strides[3] = */ {q.strides(0), q.strides(1), q.strides(2)},
      /* int64_t K_strides[3] = */ {k.strides(0), k.strides(1), k.strides(2)},
      /* int64_t V_strides[3] = */ {v.strides(0), v.strides(1), v.strides(2)},
      /* int64_t O_strides[3] = */ {o.strides(0), o.strides(1), o.strides(2)}};

  compute_encoder.set_input_array(q, 0);
  compute_encoder.set_input_array(k, 1);
@ -175,13 +175,13 @@ void sdpa_vector_2pass(
  int N = k.shape(2);
  int blocks = 32;
  int B = q.shape(0) * q.shape(1);
  size_t k_stride = k.strides()[1];
  size_t v_stride = v.strides()[1];
  auto k_stride = k.strides()[1];
  auto v_stride = v.strides()[1];
  MTL::Size group_dims(8 * 32, 1, 1);
  MTL::Size grid_dims(1, B, blocks);

  // Allocate the intermediates
  std::vector<int> intermediate_shape;
  Shape intermediate_shape;
  intermediate_shape.reserve(out.ndim() + 1);
  intermediate_shape.insert(
      intermediate_shape.end(), out.shape().begin(), out.shape().end() - 1);
@ -324,10 +324,10 @@ void ScaledDotProductAttention::eval_gpu(
  const auto& k = copy_unless(is_matrix_contiguous, k_pre);
  const auto& v = copy_unless(is_matrix_contiguous, v_pre);

  size_t str_oD = 1;
  size_t str_oH = o.shape(3);
  size_t str_oL = o.shape(1) * str_oH;
  size_t str_oB = o.shape(2) * str_oL;
  int64_t str_oD = 1;
  int64_t str_oH = o.shape(3);
  int64_t str_oL = o.shape(1) * str_oH;
  int64_t str_oB = o.shape(2) * str_oL;
  size_t data_size = o.shape(0) * str_oB;

  array::Flags flags{

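The str_o* lines hand-build int64_t strides so the attention output buffer can be written in one memory layout and viewed in another without a transpose copy. A sketch of the arithmetic; reading o.shape(0..3) as (B, H, L, D) here is an assumption for illustration, not something this hunk states:

    #include <cstdint>

    struct OStrides {
      int64_t B, H, L, D;
    };

    // Strides that let a buffer laid out contiguously as (B, L, H, D)
    // be indexed as the logical (B, H, L, D) attention output.
    OStrides output_strides(int H, int L, int D) {
      OStrides s;
      s.D = 1;
      s.H = D;                            // heads sit next to D in memory
      s.L = static_cast<int64_t>(H) * D;  // one sequence step spans H * D
      s.B = static_cast<int64_t>(L) * s.L;
      return s;
    }
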
@ -11,29 +11,28 @@ namespace mlx::core {
void slice_gpu(
    const array& in,
    array& out,
    const std::vector<int>& start_indices,
    const std::vector<int>& strides,
    const Shape& start_indices,
    const Shape& strides,
    const Stream& s) {
  // Calculate out strides, initial offset and if copy needs to be made
  auto [copy_needed, data_offset, inp_strides] =
      prepare_slice(in, start_indices, strides);
  auto [data_offset, inp_strides] = prepare_slice(in, start_indices, strides);
  auto copy_needed =
      std::any_of(strides.begin(), strides.end(), [](auto i) { return i < 0; });

  // Do copy if needed
  if (copy_needed) {
    out.set_data(allocator::malloc_or_wait(out.nbytes()));
    std::vector<int64_t> ostrides{out.strides().begin(), out.strides().end()};
    copy_gpu_inplace(
        /* const array& in = */ in,
        /* array& out = */ out,
        /* const std::vector<int>& data_shape = */ out.shape(),
        /* const std::vector<stride_t>& i_strides = */ inp_strides,
        /* const std::vector<stride_t>& o_strides = */ ostrides,
        /* const std::vector<stride_t>& o_strides = */ out.strides(),
        /* int64_t i_offset = */ data_offset,
        /* int64_t o_offset = */ 0,
        /* CopyType ctype = */ CopyType::General,
        /* const Stream& s = */ s);
  } else {
    std::vector<size_t> ostrides{inp_strides.begin(), inp_strides.end()};
    size_t data_end = 1;
    for (int i = 0; i < strides.size(); ++i) {
      if (in.shape()[i] > 1) {
@ -42,7 +41,7 @@ void slice_gpu(
      }
    }
    size_t data_size = data_end - data_offset;
    shared_buffer_slice(in, ostrides, data_offset, data_size, out);
    shared_buffer_slice(in, inp_strides, data_offset, data_size, out);
  }
}

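prepare_slice reduces a slice to an element offset plus scaled view strides: the offset is the dot product of the start indices with the input strides, and each view stride is the input stride scaled by the slice step. A sketch under the Shape/Strides aliases of this change (negative steps are what force the copy path above):

    #include <cstdint>
    #include <utility>
    #include <vector>

    using Shape = std::vector<int>;
    using Strides = std::vector<int64_t>;

    // Flat element offset of the first sliced element, plus the strides
    // of the sliced view over the original buffer.
    std::pair<int64_t, Strides> prepare_slice_sketch(
        const Strides& in_strides,
        const Shape& start_indices,
        const Shape& steps) {
      int64_t data_offset = 0;
      Strides out_strides(in_strides.size());
      for (size_t i = 0; i < in_strides.size(); ++i) {
        data_offset += static_cast<int64_t>(start_indices[i]) * in_strides[i];
        out_strides[i] = steps[i] * in_strides[i];
      }
      return {data_offset, out_strides};
    }
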
@ -9,8 +9,8 @@ namespace mlx::core {
void slice_gpu(
    const array& in,
    array& out,
    const std::vector<int>& start_indices,
    const std::vector<int>& strides,
    const Shape& start_indices,
    const Shape& strides,
    const Stream& s);

void concatenate_gpu(

@ -24,13 +24,13 @@ void single_block_sort(
  // Prepare shapes
  int n_rows = in.size() / in.shape(axis);

  std::vector<size_t> in_nc_str = in.strides();
  auto in_nc_str = in.strides();
  in_nc_str.erase(in_nc_str.begin() + axis);

  std::vector<size_t> out_nc_str = out.strides();
  auto out_nc_str = out.strides();
  out_nc_str.erase(out_nc_str.begin() + axis);

  std::vector<int> nc_shape = in.shape();
  auto nc_shape = in.shape();
  nc_shape.erase(nc_shape.begin() + axis);

  int nc_dim = nc_shape.size();
@ -106,10 +106,10 @@ void multi_block_sort(
  // Prepare shapes
  int n_rows = in.size() / in.shape(axis);

  std::vector<size_t> nc_str = in.strides();
  auto nc_str = in.strides();
  nc_str.erase(nc_str.begin() + axis);

  std::vector<int> nc_shape = in.shape();
  auto nc_shape = in.shape();
  nc_shape.erase(nc_shape.begin() + axis);

  int nc_dim = nc_shape.size();

@ -30,8 +30,8 @@ void ternary_op_gpu_inplace(
      return std::make_tuple(
          shape, strides[0], strides[1], strides[2], strides[3]);
    } else {
      std::vector<size_t> e;
      return std::make_tuple(std::vector<int>{}, e, e, e, e);
      Strides e;
      return std::make_tuple(Shape{}, e, e, e, e);
    }
  };
  auto [shape, strides_a, strides_b, strides_c, strides_out] = maybe_collapse();

@ -30,7 +30,7 @@ void unary_op_gpu_inplace(
    if (!contig) {
      return collapse_contiguous_dims(in);
    } else {
      return std::make_pair(std::vector<int>{}, std::vector<size_t>{});
      return std::make_pair(Shape{}, Strides{});
    }
  };
  auto [shape, strides] = maybe_collapse();

@ -87,9 +87,7 @@ MTL::Size get_block_dims(int dim0, int dim1, int dim2, int pow2 /* = 10 */) {
  return MTL::Size{1ul << pows[0], 1ul << pows[1], 1ul << pows[2]};
}

MTL::Size get_2d_grid_dims(
    const std::vector<int>& shape,
    const std::vector<size_t>& strides) {
MTL::Size get_2d_grid_dims(const Shape& shape, const Strides& strides) {
  // Dims with strides of 0 are ignored as they
  // correspond to broadcasted dimensions
  size_t grid_x = 1;
@ -114,10 +112,8 @@ MTL::Size get_2d_grid_dims(
      static_cast<uint32_t>(grid_x), static_cast<uint32_t>(grid_y), 1);
}

MTL::Size get_2d_grid_dims(
    const std::vector<int>& shape,
    const std::vector<size_t>& strides,
    size_t divisor) {
MTL::Size
get_2d_grid_dims(const Shape& shape, const Strides& strides, size_t divisor) {
  // Compute the 2d grid dimensions such that the total size of the grid is
  // divided by divisor.
  size_t grid_x = 1;

|
||||
// - overall size (product of non-broadcasted dimensions) is < UINT_MAX^2
|
||||
// - shape and strides correspond to a contiguous (no holes) but
|
||||
// possibly broadcasted array
|
||||
MTL::Size get_2d_grid_dims(
|
||||
const std::vector<int>& shape,
|
||||
const std::vector<size_t>& strides);
|
||||
MTL::Size get_2d_grid_dims(const Shape& shape, const Strides& strides);
|
||||
|
||||
// Same as above but we do an implicit division with divisor.
|
||||
// Basically, equivalent to factorizing
|
||||
// Prod(s \forall s in shape if strides[s] > 0) / divisor.
|
||||
MTL::Size get_2d_grid_dims(
|
||||
const std::vector<int>& shape,
|
||||
const std::vector<size_t>& strides,
|
||||
size_t divisor);
|
||||
MTL::Size
|
||||
get_2d_grid_dims(const Shape& shape, const Strides& strides, size_t divisor);
|
||||
|
||||
inline NS::String* make_string(std::ostringstream& os) {
|
||||
std::string string = os.str();
|
||||
|
@ -381,7 +381,7 @@ array batch_tensordot(
|
||||
size2 *= x.shape(s);
|
||||
}
|
||||
|
||||
std::vector<int> shape;
|
||||
Shape shape;
|
||||
for (auto ax : i) {
|
||||
shape.push_back(x.shape(ax));
|
||||
}
|
||||
@ -391,7 +391,7 @@ array batch_tensordot(
|
||||
return reshape(transpose(x, reorder, s), std::move(shape), s);
|
||||
};
|
||||
|
||||
std::vector<int> out_shape;
|
||||
Shape out_shape;
|
||||
for (auto ax : a_batch) {
|
||||
out_shape.push_back(a.shape(ax));
|
||||
}
|
||||
@ -455,7 +455,7 @@ array collapse_repeats(array in, Subscript& subscript, StreamOrDevice s) {
|
||||
axes.push_back(i);
|
||||
}
|
||||
}
|
||||
std::vector<int> idx_shape(n_expand--, 1);
|
||||
Shape idx_shape(n_expand--, 1);
|
||||
idx_shape[0] = in.shape(axes.back());
|
||||
auto idx = reshape(arange(in.shape(axes.back()), s), idx_shape, s);
|
||||
for (int i = 0; i < v; ++i) {
|
||||
|
@ -1014,7 +1014,7 @@ std::string write_signature(
|
||||
}
|
||||
if (shape_infos[i].strides) {
|
||||
kernel_source +=
|
||||
(" const constant size_t* " + name + "_strides [[buffer(" +
|
||||
(" const constant int64_t* " + name + "_strides [[buffer(" +
|
||||
std::to_string(index) + ")]],\n");
|
||||
index++;
|
||||
}
|
||||
@ -1144,7 +1144,7 @@ MetalKernelFunction metal_kernel(
|
||||
shape_infos = std::move(shape_infos),
|
||||
attributes = std::move(attributes)](
|
||||
const std::vector<array>& inputs,
|
||||
const std::vector<std::vector<int>>& output_shapes,
|
||||
const std::vector<Shape>& output_shapes,
|
||||
const std::vector<Dtype>& output_dtypes,
|
||||
std::tuple<int, int, int> grid,
|
||||
std::tuple<int, int, int> threadgroup,
|
||||
|
mlx/fft.cpp
@ -12,7 +12,7 @@ namespace mlx::core::fft {

array fft_impl(
    const array& a,
    std::vector<int> n,
    Shape n,
    const std::vector<int>& axes,
    bool real,
    bool inverse,
@ -59,7 +59,7 @@ array fft_impl(
    throw std::invalid_argument(msg.str());
  }

  std::vector<int> in_shape = a.shape();
  auto in_shape = a.shape();
  for (int i = 0; i < valid_axes.size(); ++i) {
    in_shape[valid_axes[i]] = n[i];
  }
@ -76,13 +76,12 @@ array fft_impl(

  auto in = a;
  if (any_less) {
    in = slice(in, std::vector<int>(in.ndim(), 0), in_shape, s);
    in = slice(in, Shape(in.ndim(), 0), in_shape, s);
  }
  if (any_greater) {
    // Pad with zeros
    auto tmp = zeros(in_shape, a.dtype(), s);
    std::vector<int> starts(in.ndim(), 0);
    in = slice_update(tmp, in, starts, in.shape());
    in = slice_update(tmp, in, Shape(in.ndim(), 0), in.shape());
  }

  auto out_shape = in_shape;
@ -106,7 +105,7 @@ array fft_impl(
    bool real,
    bool inverse,
    StreamOrDevice s) {
  std::vector<int> n;
  Shape n;
  for (auto ax : axes) {
    n.push_back(a.shape(ax));
  }
@ -124,7 +123,7 @@ array fft_impl(const array& a, bool real, bool inverse, StreamOrDevice s) {

array fftn(
    const array& a,
    const std::vector<int>& n,
    const Shape& n,
    const std::vector<int>& axes,
    StreamOrDevice s /* = {} */) {
  return fft_impl(a, n, axes, false, false, s);
@ -141,7 +140,7 @@ array fftn(const array& a, StreamOrDevice s /* = {} */) {

array ifftn(
    const array& a,
    const std::vector<int>& n,
    const Shape& n,
    const std::vector<int>& axes,
    StreamOrDevice s /* = {} */) {
  return fft_impl(a, n, axes, false, true, s);
@ -158,7 +157,7 @@ array ifftn(const array& a, StreamOrDevice s /* = {} */) {

array rfftn(
    const array& a,
    const std::vector<int>& n,
    const Shape& n,
    const std::vector<int>& axes,
    StreamOrDevice s /* = {} */) {
  return fft_impl(a, n, axes, true, false, s);
@ -175,7 +174,7 @@ array rfftn(const array& a, StreamOrDevice s /* = {} */) {

array irfftn(
    const array& a,
    const std::vector<int>& n,
    const Shape& n,
    const std::vector<int>& axes,
    StreamOrDevice s /* = {} */) {
  return fft_impl(a, n, axes, true, true, s);

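fft_impl reconciles the requested sizes n with the input shape: axes where n is smaller get sliced, axes where n is larger get zero-padded through slice_update into a zeros array. A small usage sketch against the new Shape-based signatures (the array values are hypothetical):

    #include "mlx/mlx.h"

    using namespace mlx::core;

    int main() {
      // Hypothetical 100 x 100 real input.
      auto x = random::normal({100, 100}, float32);

      // Transform at sizes {128, 64}: axis 0 is zero-padded from 100 up
      // to 128, axis 1 is sliced from 100 down to 64 before the FFT.
      auto y = fft::rfftn(x, Shape{128, 64}, {0, 1});
      eval({y});
    }
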
mlx/fft.h
@ -13,7 +13,7 @@ namespace mlx::core::fft {
/** Compute the n-dimensional Fourier Transform. */
array fftn(
    const array& a,
    const std::vector<int>& n,
    const Shape& n,
    const std::vector<int>& axes,
    StreamOrDevice s = {});
array fftn(const array& a, const std::vector<int>& axes, StreamOrDevice s = {});
@ -22,7 +22,7 @@ array fftn(const array& a, StreamOrDevice s = {});
/** Compute the n-dimensional inverse Fourier Transform. */
array ifftn(
    const array& a,
    const std::vector<int>& n,
    const Shape& n,
    const std::vector<int>& axes,
    StreamOrDevice s = {});
array ifftn(
@ -50,7 +50,7 @@ inline array ifft(const array& a, int axis = -1, StreamOrDevice s = {}) {
/** Compute the two-dimensional Fourier Transform. */
inline array fft2(
    const array& a,
    const std::vector<int>& n,
    const Shape& n,
    const std::vector<int>& axes,
    StreamOrDevice s = {}) {
  return fftn(a, n, axes, s);
@ -65,7 +65,7 @@ inline array fft2(
/** Compute the two-dimensional inverse Fourier Transform. */
inline array ifft2(
    const array& a,
    const std::vector<int>& n,
    const Shape& n,
    const std::vector<int>& axes,
    StreamOrDevice s = {}) {
  return ifftn(a, n, axes, s);
@ -80,7 +80,7 @@ inline array ifft2(
/** Compute the n-dimensional Fourier Transform on a real input. */
array rfftn(
    const array& a,
    const std::vector<int>& n,
    const Shape& n,
    const std::vector<int>& axes,
    StreamOrDevice s = {});
array rfftn(
@ -92,7 +92,7 @@ array rfftn(const array& a, StreamOrDevice s = {});
/** Compute the n-dimensional inverse of `rfftn`. */
array irfftn(
    const array& a,
    const std::vector<int>& n,
    const Shape& n,
    const std::vector<int>& axes,
    StreamOrDevice s = {});
array irfftn(
@ -119,7 +119,7 @@ inline array irfft(const array& a, int axis = -1, StreamOrDevice s = {}) {
/** Compute the two-dimensional Fourier Transform on a real input. */
inline array rfft2(
    const array& a,
    const std::vector<int>& n,
    const Shape& n,
    const std::vector<int>& axes,
    StreamOrDevice s = {}) {
  return rfftn(a, n, axes, s);
@ -134,7 +134,7 @@ inline array rfft2(
/** Compute the two-dimensional inverse of `rfft2`. */
inline array irfft2(
    const array& a,
    const std::vector<int>& n,
    const Shape& n,
    const std::vector<int>& axes,
    StreamOrDevice s = {}) {
  return irfftn(a, n, axes, s);

@ -138,7 +138,7 @@ SafetensorsLoad load_safetensors(
      continue;
    }
    const std::string& dtype = item.value().at("dtype");
    const std::vector<int>& shape = item.value().at("shape");
    const Shape& shape = item.value().at("shape");
    const std::vector<size_t>& data_offsets = item.value().at("data_offsets");
    Dtype type = dtype_from_safetensor_str(dtype);
    auto loaded_array = array(

mlx/ops.cpp
@ -856,14 +856,7 @@ array concatenate(
        "[concatenate] No arrays provided for concatenation");
  }

  // Normalize the given axis
  auto ax = axis < 0 ? axis + arrays[0].ndim() : axis;
  if (ax < 0 || ax >= arrays[0].ndim()) {
    std::ostringstream msg;
    msg << "[concatenate] Invalid axis (" << axis << ") passed to concatenate"
        << " for array with shape " << arrays[0].shape() << ".";
    throw std::invalid_argument(msg.str());
  }
  auto ax = normalize_axis_index(axis, arrays[0].ndim(), "[concatenate] ");

  auto throw_invalid_shapes = [&]() {
    std::ostringstream msg;
@ -925,12 +918,15 @@ array stack(
    int axis,
    StreamOrDevice s /* = {} */) {
  if (arrays.empty()) {
    throw std::invalid_argument("No arrays provided for stacking");
    throw std::invalid_argument("[stack] No arrays provided for stacking");
  }
  if (!is_same_shape(arrays)) {
    throw std::invalid_argument("All arrays must have the same shape");
  if (!std::all_of(arrays.begin(), arrays.end(), [&](const auto& a) {
        return arrays[0].shape() == a.shape();
      })) {
    throw std::invalid_argument("[stack] All arrays must have the same shape");
  }
  int normalized_axis = normalize_axis(axis, arrays[0].ndim() + 1);
  auto normalized_axis =
      normalize_axis_index(axis, arrays[0].ndim() + 1, "[stack] ");
  std::vector<array> new_arrays;
  new_arrays.reserve(arrays.size());
  for (auto& a : arrays) {
@ -945,7 +941,7 @@ array stack(const std::vector<array>& arrays, StreamOrDevice s /* = {} */) {

/** array repeat with axis */
array repeat(const array& arr, int repeats, int axis, StreamOrDevice s) {
  axis = normalize_axis(axis, arr.ndim());
  axis = normalize_axis_index(axis, arr.ndim(), "[repeat] ");

  if (repeats < 0) {
    throw std::invalid_argument(

@ -144,8 +144,7 @@ std::pair<std::vector<array>, std::vector<int>> Primitive::vmap(
  throw std::invalid_argument(msg.str());
}

std::vector<std::vector<int>> Primitive::output_shapes(
    const std::vector<array>&) {
std::vector<Shape> Primitive::output_shapes(const std::vector<array>&) {
  std::ostringstream msg;
  msg << "[Primitive::output_shapes] ";
  this->print(msg);
@ -969,7 +968,7 @@ array conv_weight_backward_patches(
  }

  // padded strides (contiguous)
  std::vector<size_t> in_padded_strides(in.ndim(), 1);
  Strides in_padded_strides(in.ndim(), 1);
  for (int i = in.ndim() - 2; i >= 0; --i) {
    in_padded_strides[i] = in_padded_strides[i + 1] * in_padded_shape[i + 1];
  }
@ -984,14 +983,13 @@ array conv_weight_backward_patches(

  // patches are shaped as
  // (batch_dim, out_spatial_dims, weight_spatial_dims, in_channels)
  std::vector<int> patches_shape{
      cotan.shape().begin(), cotan.shape().end() - 1};
  Shape patches_shape{cotan.shape().begin(), cotan.shape().end() - 1};
  patches_shape.insert(
      patches_shape.end(), wt.shape().begin() + 1, wt.shape().end());

  // Resolve patch strides
  int n_spatial_dim = in.ndim() - 2;
  std::vector<size_t> patches_strides(patches_shape.size(), 1);
  Strides patches_strides(patches_shape.size(), 1);
  patches_strides[0] = in_padded_strides[0];
  for (int i = 1; i < n_spatial_dim + 1; i++) {
    patches_strides[i] = in_padded_strides[i] * kernel_strides[i - 1];
@ -1095,8 +1093,8 @@ std::vector<array> Convolution::vjp(

  // Handle negative padding
  if (has_neg_padding) {
    std::vector<int> starts(grad.ndim(), 0);
    std::vector<int> stops = grad.shape();
    Shape starts(grad.ndim(), 0);
    auto stops = grad.shape();

    for (int i = 0; i < grad.ndim() - 2; i++) {
      if (padding_lo[i] < 0) {

@ -1917,8 +1917,6 @@ class SliceUpdate : public UnaryPrimitive {
  std::vector<int> strides_;

  void eval(const std::vector<array>& inputs, array& out);

  std::tuple<int64_t, std::vector<int64_t>> prepare_slice(const array& in);
};

class Softmax : public UnaryPrimitive {

@ -34,7 +34,7 @@ array key(uint64_t seed) {
}

array bits(
    const std::vector<int>& shape,
    const Shape& shape,
    int width /* 4 */,
    const std::optional<array>& key_ /*= nullopt */,
    StreamOrDevice s /* = {} */) {
@ -45,7 +45,7 @@ array bits(
        << ".";
    throw std::invalid_argument(msg.str());
  }
  if (key.shape() != std::vector<int>{2}) {
  if (key.shape() != Shape{2}) {
    std::ostringstream msg;
    msg << "[bits] Expected key shape (2) but received " << key.shape() << ".";
    throw std::invalid_argument(msg.str());
@ -118,7 +118,7 @@ array above_minus_one_with_default(Dtype dtype) {
array uniform(
    const array& low,
    const array& high,
    const std::vector<int>& shape,
    const Shape& shape,
    Dtype dtype /* = float32 */,
    const std::optional<array>& key /*= nullopt */,
    StreamOrDevice s /* = {} */) {
@ -168,7 +168,7 @@ array uniform(
}

array uniform(
    const std::vector<int>& shape,
    const Shape& shape,
    Dtype dtype,
    const std::optional<array>& key /*= nullopt */,
    StreamOrDevice s /* = {} */) {
@ -177,7 +177,7 @@ array uniform(
}

array normal(
    const std::vector<int>& shape,
    const Shape& shape,
    Dtype dtype,
    const float loc /* = 0.0 */,
    const float scale /* = 1.0 */,
@ -201,7 +201,7 @@ array normal(
array multivariate_normal(
    const array& mean,
    const array& cov,
    const std::vector<int>& shape,
    const Shape& shape,
    Dtype dtype,
    const std::optional<array>& key /* = nullopt */,
    StreamOrDevice s) {
@ -234,12 +234,9 @@ array multivariate_normal(
  }

  // Compute output shape
  std::vector<int> truncated_output_shape;

  auto truncated_mean_shape =
      std::vector<int>(mean.shape().begin(), mean.shape().end() - 1);
  auto truncated_cov_shape =
      std::vector<int>(cov.shape().begin(), cov.shape().end() - 2);
      Shape(mean.shape().begin(), mean.shape().end() - 1);
  auto truncated_cov_shape = Shape(cov.shape().begin(), cov.shape().end() - 2);
  auto output_shape =
      broadcast_shapes(truncated_cov_shape, truncated_mean_shape);
  output_shape = broadcast_shapes(output_shape, shape);
@ -269,7 +266,7 @@ array multivariate_normal(
array randint(
    const array& low,
    const array& high,
    const std::vector<int>& shape,
    const Shape& shape,
    Dtype dtype /* = int32 */,
    const std::optional<array>& key /*= nullopt */,
    StreamOrDevice s /* = {} */) {
@ -283,7 +280,7 @@ array randint(

array bernoulli(
    const array& p,
    const std::vector<int>& shape,
    const Shape& shape,
    const std::optional<array>& key /*= nullopt */,
    StreamOrDevice s /* = {} */) {
  if (!issubdtype(p.dtype(), floating)) {
@ -322,7 +319,7 @@ array bernoulli(
array truncated_normal(
    const array& lower,
    const array& upper,
    const std::vector<int>& shape,
    const Shape& shape,
    Dtype dtype /* = float32 */,
    const std::optional<array>& key /*= nullopt */,
    StreamOrDevice s /* = {} */) {
@ -357,7 +354,7 @@ array truncated_normal(
}

array gumbel(
    const std::vector<int>& shape,
    const Shape& shape,
    Dtype dtype /* = float32 */,
    const std::optional<array>& key /*= nullopt */,
    StreamOrDevice s /* = {} */) {
@ -380,7 +377,7 @@ int get_valid_axis(int axis, int ndim) {
array categorical_impl(
    const array& logits,
    int axis,
    const std::vector<int>& shape,
    const Shape& shape,
    const std::optional<array>& key /*= nullopt */,
    StreamOrDevice s) {
  auto gumbel_shape = shape;
@ -393,7 +390,7 @@ array categorical_impl(
array categorical(
    const array& logits,
    int axis,
    const std::vector<int>& shape,
    const Shape& shape,
    const std::optional<array>& key /*= nullopt */,
    StreamOrDevice s /* = {} */) {
  // Validate and normalize axis
@ -439,7 +436,7 @@ array categorical(
}

array laplace(
    const std::vector<int>& shape,
    const Shape& shape,
    Dtype dtype,
    const float loc /* = 0.0 */,
    const float scale /* = 1.0 */,

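Call sites are mostly unaffected by the move to Shape, since brace-initialized extents still convert at the call site. A usage sketch against the signatures above (values are hypothetical):

    #include "mlx/mlx.h"

    using namespace mlx::core;

    int main() {
      auto k = random::key(0);
      // Shape{...} or a plain braced list both work here.
      auto u = random::uniform(Shape{512, 512}, float32, k);
      auto b = random::bernoulli(array(0.3f), Shape{512, 512}, k);
      eval({u, b});
    }
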
mlx/random.h
@ -42,12 +42,12 @@ void seed(uint64_t seed);
|
||||
|
||||
/** Generate an array with type uint32 filled with random bits. */
|
||||
array bits(
|
||||
const std::vector<int>& shape,
|
||||
const Shape& shape,
|
||||
int width,
|
||||
const std::optional<array>& key = std::nullopt,
|
||||
StreamOrDevice s = {});
|
||||
inline array bits(
|
||||
const std::vector<int>& shape,
|
||||
const Shape& shape,
|
||||
const std::optional<array>& key = std::nullopt,
|
||||
StreamOrDevice s = {}) {
|
||||
return bits(shape, 4, key, s);
|
||||
@ -63,7 +63,7 @@ array split(const array& key, int num, StreamOrDevice s = {});
|
||||
array uniform(
|
||||
const array& low,
|
||||
const array& high,
|
||||
const std::vector<int>& shape,
|
||||
const Shape& shape,
|
||||
Dtype dtype = float32,
|
||||
const std::optional<array>& key = std::nullopt,
|
||||
StreamOrDevice s = {});
|
||||
@ -72,7 +72,7 @@ template <typename T, typename U>
|
||||
array uniform(
|
||||
T low,
|
||||
U high,
|
||||
const std::vector<int>& shape,
|
||||
const Shape& shape,
|
||||
Dtype dtype = float32,
|
||||
const std::optional<array>& key = std::nullopt,
|
||||
StreamOrDevice s = {}) {
|
||||
@ -81,12 +81,12 @@ array uniform(
|
||||
|
||||
/** Generate uniform random numbers between 0 and 1. */
|
||||
array uniform(
|
||||
const std::vector<int>& shape,
|
||||
const Shape& shape,
|
||||
Dtype dtype,
|
const std::optional<array>& key = std::nullopt,
StreamOrDevice s = {});

inline array uniform(
const std::vector<int>& shape,
const Shape& shape,
const std::optional<array>& key = std::nullopt,
StreamOrDevice s = {}) {
return uniform(shape, float32, key);

@ -94,14 +94,14 @@ inline array uniform(

/** Generate samples from the standard normal distribution. */
array normal(
const std::vector<int>& shape,
const Shape& shape,
Dtype dtype,
const float loc,
const float scale,
const std::optional<array>& key = std::nullopt,
StreamOrDevice s = {});

inline array normal(
const std::vector<int>& shape,
const Shape& shape,
const float loc,
const float scale,
const std::optional<array>& key = std::nullopt,

@ -109,14 +109,14 @@ inline array normal(
return normal(shape, float32, loc, scale, key, s);
}

inline array normal(
const std::vector<int>& shape,
const Shape& shape,
const Dtype dtype,
const std::optional<array>& key = std::nullopt,
StreamOrDevice s = {}) {
return normal(shape, dtype, 0.0, 1.0, key, s);
}

inline array normal(
const std::vector<int>& shape,
const Shape& shape,
const std::optional<array>& key = std::nullopt,
StreamOrDevice s = {}) {
return normal(shape, float32, 0.0, 1.0, key, s);

@ -126,7 +126,7 @@ inline array normal(
array multivariate_normal(
const array& mean,
const array& cov,
const std::vector<int>& shape,
const Shape& shape,
Dtype dtype,
const std::optional<array>& key = std::nullopt,
StreamOrDevice s = {});

@ -135,7 +135,7 @@ array multivariate_normal(
array randint(
const array& low,
const array& high,
const std::vector<int>& shape,
const Shape& shape,
Dtype dtype = int32,
const std::optional<array>& key = std::nullopt,
StreamOrDevice s = {});

@ -144,7 +144,7 @@ template <typename T, typename U>
array randint(
T low,
U high,
const std::vector<int>& shape,
const Shape& shape,
Dtype dtype = int32,
const std::optional<array>& key = std::nullopt,
StreamOrDevice s = {}) {

@ -154,7 +154,7 @@ array randint(
/** Generate binary variables with probability to be true equal to p */
array bernoulli(
const array& p,
const std::vector<int>& shape,
const Shape& shape,
const std::optional<array>& key = std::nullopt,
StreamOrDevice s = {});

array bernoulli(

@ -173,7 +173,7 @@ array bernoulli(
template <typename T>
array bernoulli(
T p,
const std::vector<int>& shape,
const Shape& shape,
const std::optional<array>& key = std::nullopt,
StreamOrDevice s = {}) {
return bernoulli(array(p), shape, key, s);

@ -186,7 +186,7 @@ array bernoulli(
array truncated_normal(
const array& lower,
const array& upper,
const std::vector<int>& shape,
const Shape& shape,
Dtype dtype = float32,
const std::optional<array>& key = std::nullopt,
StreamOrDevice s = {});

@ -199,7 +199,7 @@ array truncated_normal(
StreamOrDevice s = {});

array gumbel(
const std::vector<int>& shape,
const Shape& shape,
Dtype dtype = float32,
const std::optional<array>& key = std::nullopt,
StreamOrDevice s = {});

@ -207,7 +207,7 @@ array gumbel(
array categorical(
const array& logits,
int axis,
const std::vector<int>& shape,
const Shape& shape,
const std::optional<array>& key = std::nullopt,
StreamOrDevice s = {});

@ -226,14 +226,14 @@ array categorical(

/** Generate samples from the laplace distribution. */
array laplace(
const std::vector<int>& shape,
const Shape& shape,
Dtype dtype,
const float loc,
const float scale,
const std::optional<array>& key = std::nullopt,
StreamOrDevice s = {});

inline array laplace(
const std::vector<int>& shape,
const Shape& shape,
const float loc,
const float scale,
const std::optional<array>& key = std::nullopt,

@ -241,14 +241,14 @@ inline array laplace(
return laplace(shape, float32, loc, scale, key, s);
}

inline array laplace(
const std::vector<int>& shape,
const Shape& shape,
const Dtype dtype,
const std::optional<array>& key = std::nullopt,
StreamOrDevice s = {}) {
return laplace(shape, dtype, 0.0, 1.0, key, s);
}

inline array laplace(
const std::vector<int>& shape,
const Shape& shape,
const std::optional<array>& key = std::nullopt,
StreamOrDevice s = {}) {
return laplace(shape, float32, 0.0, 1.0, key, s);
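Every hunk in this header makes the same substitution, std::vector<int> to Shape. As a hedged sketch of why call sites survive unchanged — the alias definitions are not shown in this diff, so the exact types below are assumptions inferred from the commit title:

    // Hypothetical illustration, not part of the diff: assumed aliases
    // used throughout this commit.
    #include <cstdint>
    #include <vector>
    using Shape = std::vector<int32_t>;   // assumption
    using Strides = std::vector<int64_t>; // assumption ("use int64 stride everywhere")

    // A call such as uniform({2, 3}) still compiles after the change,
    // because the braced list now deduces to Shape rather than
    // std::vector<int>.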
@ -681,7 +681,7 @@ std::pair<std::vector<array>, std::vector<array>> vmap_trace(
std::vector<array> s_inputs;
for (int i = 0; i < inputs.size(); ++i) {
if (in_axes[i] != -1) {
std::vector<int> shape = inputs[i].shape();
auto shape = inputs[i].shape();
shape.erase(shape.begin() + in_axes[i]);
array in(shape, inputs[i].dtype(), nullptr, {});
s_inputs.push_back(in);

@ -924,7 +924,7 @@ std::function<std::vector<array>(const std::vector<array>&)> custom_function(
: default_stream(default_device());

// Make the output info
std::vector<std::vector<int>> shapes;
std::vector<Shape> shapes;
std::vector<Dtype> dtypes;
for (const auto& out : outputs) {
shapes.emplace_back(out.shape());
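With the switch to auto, the shape local in vmap_trace now deduces the Shape alias; the axis-removal step itself is unchanged. A standalone sketch of that step, under the same assumed alias as above:

    #include <cstdint>
    #include <vector>
    using Shape = std::vector<int32_t>; // assumption carried over

    // Drop the vmapped axis from an input's shape: a {4, 8, 3} input
    // vmapped over axis 1 traces with per-example shape {4, 3}.
    Shape drop_axis(Shape shape, int axis) {
      shape.erase(shape.begin() + axis);
      return shape;
    }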
@ -98,29 +98,17 @@ Shape broadcast_shapes(const Shape& s1, const Shape& s2) {
return out_shape;
}

bool is_same_shape(const std::vector<array>& arrays) {
if (arrays.empty()) {
return true;
}
return std::all_of(arrays.begin() + 1, arrays.end(), [&](const array& a) {
return (a.shape() == arrays[0].shape());
});
}

int normalize_axis(int axis, int ndim) {
if (ndim <= 0) {
throw std::invalid_argument("Number of dimensions must be positive.");
}
int normalize_axis_index(
int axis,
int ndim,
const std::string& msg_prefix /* = "" */) {
if (axis < -ndim || axis >= ndim) {
std::ostringstream msg;
msg << "Axis " << axis << " is out of bounds for array with " << ndim
<< " dimensions.";
msg << msg_prefix << "Axis " << axis << " is out of bounds for array with "
<< ndim << " dimensions.";
throw std::invalid_argument(msg.str());
}
if (axis < 0) {
axis += ndim;
}
return axis;
return axis < 0 ? axis + ndim : axis;
}

std::ostream& operator<<(std::ostream& os, const Device& d) {

@ -323,15 +311,6 @@ std::ostream& operator<<(std::ostream& os, const Strides& v) {
return os;
}

std::ostream& operator<<(std::ostream& os, const std::vector<int64_t>& v) {
os << "(";
for (int i = 0; i < v.size(); ++i) {
os << v[i] << ((i == v.size() - 1) ? "" : ",");
}
os << ")";
return os;
}

namespace env {

int get_var(const char* name, int default_value) {
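The rename from normalize_axis to normalize_axis_index mirrors numpy's helper of the same name and threads an error-message prefix through to the exception. A free-standing copy of the new semantics, for illustration only:

    #include <sstream>
    #include <stdexcept>
    #include <string>

    // Standalone restatement of the function introduced above.
    int normalize_axis_index(int axis, int ndim, const std::string& msg_prefix = "") {
      if (axis < -ndim || axis >= ndim) {
        std::ostringstream msg;
        msg << msg_prefix << "Axis " << axis << " is out of bounds for array with "
            << ndim << " dimensions.";
        throw std::invalid_argument(msg.str());
      }
      return axis < 0 ? axis + ndim : axis;
    }

    // normalize_axis_index(-1, 3)          -> 2
    // normalize_axis_index(3, 3, "[sum] ") -> throws "[sum] Axis 3 is out of bounds ..."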
26 mlx/utils.h

@ -64,30 +64,13 @@ Dtype result_type(const std::vector<array>& arrays);

Shape broadcast_shapes(const Shape& s1, const Shape& s2);

bool is_same_shape(const std::vector<array>& arrays);

/** Returns the shape dimension if it's within allowed range. */
template <typename T>
int check_shape_dim(const T dim) {
constexpr bool is_signed = std::numeric_limits<T>::is_signed;
using U = std::conditional_t<is_signed, int64_t, size_t>;
constexpr U min = static_cast<U>(std::numeric_limits<int>::min());
constexpr U max = static_cast<U>(std::numeric_limits<int>::max());

if ((is_signed && dim < min) || dim > max) {
throw std::invalid_argument(
"Shape dimension falls outside supported `int` range.");
}

return static_cast<int>(dim);
}

/**
* Returns the axis normalized to be in the range [0, ndim).
* Based on numpy's normalize_axis_index. See
* https://numpy.org/devdocs/reference/generated/numpy.lib.array_utils.normalize_axis_index.html
*/
int normalize_axis(int axis, int ndim);
int normalize_axis_index(
int axis,
int ndim,
const std::string& msg_prefix = "");

std::ostream& operator<<(std::ostream& os, const Device& d);
std::ostream& operator<<(std::ostream& os, const Stream& s);

@ -96,7 +79,6 @@ std::ostream& operator<<(std::ostream& os, const Dtype::Kind& k);
std::ostream& operator<<(std::ostream& os, array a);
std::ostream& operator<<(std::ostream& os, const Shape& v);
std::ostream& operator<<(std::ostream& os, const Strides& v);
std::ostream& operator<<(std::ostream& os, const std::vector<int64_t>& v);
inline std::ostream& operator<<(std::ostream& os, const complex64_t& v) {
return os << v.real() << (v.imag() >= 0 ? "+" : "") << v.imag() << "j";
}
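The generic check_shape_dim template leaves the public header here; the Python conversion layer in the next file re-adds a narrower int64_t-only version. A sketch of that simplified check with hypothetical inputs:

    #include <cstdint>
    #include <limits>
    #include <stdexcept>

    // Standalone restatement: with an int64_t input, only the upper
    // bound can be exceeded by a real ndarray dimension.
    int check_shape_dim(int64_t dim) {
      if (dim > std::numeric_limits<int>::max()) {
        throw std::invalid_argument(
            "Shape dimension falls outside supported `int` range.");
      }
      return static_cast<int>(dim);
    }

    // check_shape_dim(4096)          -> 4096
    // check_shape_dim(3'000'000'000) -> throws: does not fit in a 32-bit int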
@ -27,10 +27,18 @@ struct ndarray_traits<float16_t> {
static constexpr dlpack::dtype bfloat16{4, 16, 1};
}; // namespace nanobind

int check_shape_dim(int64_t dim) {
if (dim > std::numeric_limits<int>::max()) {
throw std::invalid_argument(
"Shape dimension falls outside supported `int` range.");
}
return static_cast<int>(dim);
}

template <typename T>
array nd_array_to_mlx_contiguous(
nb::ndarray<nb::ro, nb::c_contig, nb::device::cpu> nd_array,
const std::vector<int>& shape,
const Shape& shape,
Dtype dtype) {
// Make a copy of the numpy buffer
// Get buffer ptr pass to array constructor

@ -42,7 +50,7 @@ array nd_array_to_mlx(
nb::ndarray<nb::ro, nb::c_contig, nb::device::cpu> nd_array,
std::optional<Dtype> dtype) {
// Compute the shape and size
std::vector<int> shape;
Shape shape;
for (int i = 0; i < nd_array.ndim(); i++) {
shape.push_back(check_shape_dim(nd_array.shape(i)));
}

@ -108,13 +116,12 @@ nb::ndarray<NDParams...> mlx_to_nd_array_impl(
a.eval();
}
std::vector<size_t> shape(a.shape().begin(), a.shape().end());
std::vector<int64_t> strides(a.strides().begin(), a.strides().end());
return nb::ndarray<NDParams...>(
a.data<T>(),
a.ndim(),
shape.data(),
/* owner= */ nb::none(),
strides.data(),
a.strides().data(),
t.value_or(nb::dtype<T>()));
}

@ -272,7 +279,7 @@ void fill_vector(T list, std::vector<U>& vals) {
template <typename T>
PyScalarT validate_shape(
T list,
const std::vector<int>& shape,
const Shape& shape,
int idx,
bool& all_python_primitive_elements) {
if (idx >= shape.size()) {

@ -340,7 +347,7 @@ PyScalarT validate_shape(
}

template <typename T>
void get_shape(T list, std::vector<int>& shape) {
void get_shape(T list, Shape& shape) {
shape.push_back(check_shape_dim(nb::len(list)));
if (shape.back() > 0) {
auto l = list.begin();

@ -351,7 +358,7 @@ void get_shape(T list, std::vector<int>& shape) {
} else if (nb::isinstance<array>(*l)) {
auto arr = nb::cast<array>(*l);
for (int i = 0; i < arr.ndim(); i++) {
shape.push_back(check_shape_dim(arr.shape(i)));
shape.push_back(arr.shape(i));
}
return;
}

@ -363,7 +370,7 @@ array array_from_list_impl(
T pl,
const PyScalarT& inferred_type,
std::optional<Dtype> specified_type,
const std::vector<int>& shape) {
const Shape& shape) {
// Make the array
switch (inferred_type) {
case pybool: {

@ -420,7 +427,7 @@ array array_from_list_impl(
template <typename T>
array array_from_list_impl(T pl, std::optional<Dtype> dtype) {
// Compute the shape
std::vector<int> shape;
Shape shape;
get_shape(pl, shape);

// Validate the shape and type
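Worth noting in mlx_to_nd_array_impl above: the temporary int64_t copy of the strides disappears because a.strides() now already stores int64_t, which is what nb::ndarray takes. A hedged before/after sketch (the surrounding types are assumptions):

    #include <cstdint>
    #include <vector>
    using Strides = std::vector<int64_t>; // assumption

    // Before: strides were size_t, so exporting through a signed-stride
    // ABI (DLPack / nanobind) forced an element-wise copy:
    //   std::vector<int64_t> strides(a.strides().begin(), a.strides().end());
    //   ... strides.data() ...
    // After: Strides already holds int64_t, so the buffer passes through:
    //   ... a.strides().data() ...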
@ -2953,16 +2953,16 @@ void init_ops(nb::module_& m) {
m.def(
"as_strided",
[](const array& a,
std::optional<std::vector<int>> shape,
std::optional<std::vector<size_t>> strides,
std::optional<Shape> shape,
std::optional<Strides> strides,
size_t offset,
StreamOrDevice s) {
std::vector<int> a_shape = (shape) ? *shape : a.shape();
std::vector<size_t> a_strides;
auto a_shape = (shape) ? *shape : a.shape();
Strides a_strides;
if (strides) {
a_strides = *strides;
} else {
a_strides = std::vector<size_t>(a_shape.size(), 1);
a_strides = Strides(a_shape.size(), 1);
for (int i = a_shape.size() - 1; i > 0; i--) {
a_strides[i - 1] = a_shape[i] * a_strides[i];
}
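When no strides are supplied, the binding synthesizes row-major (C-contiguous) defaults: the last axis gets stride 1 and each earlier stride is the product of all later dimensions. A standalone sketch with a worked value, using the same assumed aliases:

    #include <cstdint>
    #include <vector>
    using Shape = std::vector<int32_t>;   // assumption
    using Strides = std::vector<int64_t>; // assumption

    // Row-major defaults: for shape {2, 3, 4} this yields strides {12, 4, 1}.
    Strides default_strides(const Shape& shape) {
      Strides strides(shape.size(), 1);
      for (int i = static_cast<int>(shape.size()) - 1; i > 0; i--) {
        strides[i - 1] = shape[i] * strides[i];
      }
      return strides;
    }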
@ -11,7 +11,7 @@ void test_arg_reduce_small(
Device d,
const array& x,
ArgReduce::ReduceType r,
std::vector<int> out_shape,
Shape out_shape,
int axis,
std::vector<int> expected_output) {
auto s = default_stream(d);

@ -27,7 +27,7 @@ void test_arg_reduce_small(
void test_arg_reduce_against_cpu(
const array& x,
ArgReduce::ReduceType r,
std::vector<int> out_shape,
Shape out_shape,
int axis) {
auto y1 = array(
out_shape,

@ -125,7 +125,7 @@ TEST_CASE("test arg reduce against cpu") {
void test_arg_reduce_small_bool(
Device d,
ArgReduce::ReduceType r,
std::vector<int> out_shape,
Shape out_shape,
int axis,
std::vector<int> expected_output) {
auto s = default_stream(d);
@ -13,10 +13,10 @@ TEST_CASE("test array basics") {
array x(1.0);
CHECK_EQ(x.size(), 1);
CHECK_EQ(x.ndim(), 0);
CHECK_EQ(x.shape(), std::vector<int>{});
CHECK_EQ(x.shape(), Shape{});
CHECK_THROWS_AS(x.shape(0), std::out_of_range);
CHECK_THROWS_AS(x.shape(-1), std::out_of_range);
CHECK_EQ(x.strides(), std::vector<size_t>{});
CHECK_EQ(x.strides(), Strides{});
CHECK_EQ(x.itemsize(), sizeof(float));
CHECK_EQ(x.nbytes(), sizeof(float));
CHECK_EQ(x.dtype(), float32);

@ -39,12 +39,12 @@ TEST_CASE("test array basics") {
CHECK_EQ(x.dtype(), float32);
CHECK_EQ(x.size(), 1);
CHECK_EQ(x.ndim(), 1);
CHECK_EQ(x.shape(), std::vector<int>{1});
CHECK_EQ(x.shape(), Shape{1});
CHECK_EQ(x.shape(0), 1);
CHECK_EQ(x.shape(-1), 1);
CHECK_THROWS_AS(x.shape(1), std::out_of_range);
CHECK_THROWS_AS(x.shape(-2), std::out_of_range);
CHECK_EQ(x.strides(), std::vector<size_t>{1});
CHECK_EQ(x.strides(), Strides{1});
CHECK_EQ(x.item<float>(), 1.0);

// Check empty array

@ -57,7 +57,7 @@ TEST_CASE("test array basics") {

x = array({1.0, 1.0});
CHECK_EQ(x.size(), 2);
CHECK_EQ(x.shape(), std::vector<int>{2});
CHECK_EQ(x.shape(), Shape{2});
CHECK_EQ(x.itemsize(), sizeof(float));
CHECK_EQ(x.nbytes(), x.itemsize() * x.size());

@ -65,9 +65,9 @@ TEST_CASE("test array basics") {
CHECK_THROWS_AS(x.item<float>(), std::invalid_argument);

x = array({1.0, 1.0, 1.0}, {1, 3});
CHECK(x.size() == 3);
CHECK(x.shape() == std::vector<int>{1, 3});
CHECK(x.strides() == std::vector<size_t>{3, 1});
CHECK_EQ(x.size(), 3);
CHECK_EQ(x.shape(), Shape{1, 3});
CHECK_EQ(x.strides(), Strides{3, 1});

// Test wrong size/shapes throw:
CHECK_THROWS_AS(array({1.0, 1.0, 1.0}, {4}), std::invalid_argument);

@ -472,7 +472,7 @@ TEST_CASE("test array metadata") {
x = array({1.0f, 2.0f, 3.0f}, {1, 3});
y = slice(x, {0, 0}, {1, 2}, {2, 3});
eval(y);
CHECK_EQ(y.shape(), std::vector<int>{1, 1});
CHECK_EQ(y.shape(), Shape{1, 1});
CHECK_EQ(y.data_size(), 1);
CHECK_EQ(y.flags().contiguous, true);
CHECK_EQ(y.flags().row_contiguous, true);

@ -481,7 +481,7 @@ TEST_CASE("test array metadata") {
x = array({0.0f, 1.0f, 2.0f, 3.0f}, {1, 4});
y = slice(x, {0, 0}, {1, 4}, {1, 2});
eval(y);
CHECK_EQ(y.shape(), std::vector<int>{1, 2});
CHECK_EQ(y.shape(), Shape{1, 2});
CHECK_EQ(y.flags().contiguous, false);
CHECK_EQ(y.flags().row_contiguous, false);
CHECK_EQ(y.flags().col_contiguous, false);

@ -489,7 +489,7 @@ TEST_CASE("test array metadata") {
x = broadcast_to(array(1.0f), {4, 10});
y = slice(x, {0, 0}, {4, 10}, {2, 2});
eval(y);
CHECK_EQ(y.shape(), std::vector<int>{2, 5});
CHECK_EQ(y.shape(), Shape{2, 5});
CHECK_EQ(y.data_size(), 1);
CHECK_EQ(y.flags().contiguous, true);
CHECK_EQ(y.flags().row_contiguous, false);

@ -566,8 +566,8 @@ TEST_CASE("test array iteration") {
}

TEST_CASE("test array shared buffer") {
std::vector<int> shape = {2, 2};
int n_elem = shape[0] * shape[1];
Shape shape = {2, 2};
auto n_elem = shape[0] * shape[1];

allocator::Buffer buf_b = allocator::malloc(n_elem * sizeof(float));
void* buf_b_ptr = buf_b.raw_ptr();
@ -617,7 +617,7 @@ TEST_CASE("test op vjps") {
axes = {0};
out = vjp(fun, array({}), array(3.0f)).second;
CHECK_EQ(out.size(), 0);
CHECK_EQ(out.shape(), std::vector<int>{0});
CHECK_EQ(out.shape(), Shape{0});

axes = {0};
out = vjp(fun, ones({2, 2, 2}), array({1.0f, 2.0f, 3.0f, 4.0f}, {2, 2}))

@ -725,9 +725,9 @@ TEST_CASE("test gather and take grads") {
}

TEST_CASE("test slice grads") {
std::vector<int> start = {5, 0, 0};
std::vector<int> stop = {7, 2, 4};
std::vector<int> strides = {1, 1, 1};
Shape start = {5, 0, 0};
Shape stop = {7, 2, 4};
Shape strides = {1, 1, 1};

auto fn = [&start, &stop, &strides](array input) {
return slice(input, start, stop, strides);

@ -982,8 +982,8 @@ TEST_CASE("test comparison grads") {

TEST_CASE("test as_strided grads") {
auto x = ones({11});
std::vector<int> shape = {5, 5};
std::vector<size_t> strides = {1, 1};
Shape shape = {5, 5};
Strides strides = {1, 1};
size_t offset = 0;

auto fun = [&shape, &strides, &offset](array x) {
@ -16,7 +16,7 @@ TEST_CASE("test matmul") {
a = array({1.0});
b = array({1.0});
auto out = matmul(a, b);
CHECK_EQ(out.shape(), std::vector<int>{});
CHECK_EQ(out.shape(), Shape{});
CHECK_EQ(out.size(), 1);
CHECK_EQ(out.dtype(), float32);
CHECK_EQ(out.item<float>(), 1.0f);
@ -208,14 +208,14 @@ TEST_CASE("test full") {
// Check zeros and ones
{
auto x = zeros({2, 2}, float32);
CHECK_EQ(x.shape(), std::vector<int>{2, 2});
CHECK_EQ(x.shape(), Shape{2, 2});
CHECK_EQ(x.ndim(), 2);
CHECK_EQ(x.dtype(), float32);
auto y = array({0.0, 0.0, 0.0, 0.0}, {2, 2});
CHECK(array_equal(x, y).item<bool>());

x = ones({2, 2}, float32);
CHECK_EQ(x.shape(), std::vector<int>{2, 2});
CHECK_EQ(x.shape(), Shape{2, 2});
CHECK_EQ(x.ndim(), 2);
CHECK_EQ(x.dtype(), float32);
y = array({1.0, 1.0, 1.0, 1.0}, {2, 2});

@ -235,11 +235,11 @@ TEST_CASE("test full") {
// Works for empty shape and empty array
{
array x = ones({}, int32);
CHECK_EQ(x.shape(), std::vector<int>{});
CHECK_EQ(x.shape(), Shape{});
CHECK_EQ(x.item<int>(), 1);

x = full({0}, array({}));
CHECK_EQ(x.shape(), std::vector<int>{0});
CHECK_EQ(x.shape(), Shape{0});
CHECK_EQ(x.size(), 0);

CHECK_THROWS_AS(full({}, array({})), std::invalid_argument);
@ -162,35 +162,35 @@ TEST_CASE("test fftn") {

x = reshape(arange(20, float32), {5, 4});
y = fft::rfftn(x);
CHECK_EQ(y.shape(), std::vector<int>{5, 3});
CHECK_EQ(y.shape(), Shape{5, 3});
y = fft::rfftn(x, {1, 0});
CHECK_EQ(y.shape(), std::vector<int>{3, 4});
CHECK_EQ(y.shape(), Shape{3, 4});

x = reshape(arange(20, float32), {5, 4});
y = fft::irfftn(x);
CHECK_EQ(y.shape(), std::vector<int>{5, 6});
CHECK_EQ(y.shape(), Shape{5, 6});
y = fft::irfftn(x, {1, 0});
CHECK_EQ(y.shape(), std::vector<int>{8, 4});
CHECK_EQ(y.shape(), Shape{8, 4});
}

// Check the types of real ffts
{
x = zeros({5, 5}, float32);
auto y = fft::rfft2(x);
CHECK_EQ(y.shape(), std::vector<int>{5, 3});
CHECK_EQ(y.shape(), Shape{5, 3});
CHECK_EQ(y.dtype(), complex64);

y = fft::rfftn(x);
CHECK_EQ(y.shape(), std::vector<int>{5, 3});
CHECK_EQ(y.shape(), Shape{5, 3});
CHECK_EQ(y.dtype(), complex64);

x = zeros({5, 5}, complex64);
y = fft::irfft2(x);
CHECK_EQ(y.shape(), std::vector<int>{5, 8});
CHECK_EQ(y.shape(), Shape{5, 8});
CHECK_EQ(y.dtype(), float32);

y = fft::irfftn(x);
CHECK_EQ(y.shape(), std::vector<int>{5, 8});
CHECK_EQ(y.shape(), Shape{5, 8});
CHECK_EQ(y.dtype(), float32);
}
}

@ -199,25 +199,25 @@ TEST_CASE("test fft with provided shape") {
auto x = ones({5, 5});

auto y = fft::fft(x, 7, 0);
CHECK_EQ(y.shape(), std::vector<int>{7, 5});
CHECK_EQ(y.shape(), Shape{7, 5});

y = fft::fft(x, 3, 0);
CHECK_EQ(y.shape(), std::vector<int>{3, 5});
CHECK_EQ(y.shape(), Shape{3, 5});

y = fft::fft(x, 7, 1);
CHECK_EQ(y.shape(), std::vector<int>{5, 7});
CHECK_EQ(y.shape(), Shape{5, 7});

y = fft::fft(x, 3, 1);
CHECK_EQ(y.shape(), std::vector<int>{5, 3});
CHECK_EQ(y.shape(), Shape{5, 3});

y = fft::rfft(x, 7, 0);
CHECK_EQ(y.shape(), std::vector<int>{4, 5});
CHECK_EQ(y.shape(), Shape{4, 5});

y = fft::rfft(x, 3, 0);
CHECK_EQ(y.shape(), std::vector<int>{2, 5});
CHECK_EQ(y.shape(), Shape{2, 5});

y = fft::rfft(x, 3, 1);
CHECK_EQ(y.shape(), std::vector<int>{5, 2});
CHECK_EQ(y.shape(), Shape{5, 2});
}

TEST_CASE("test fft vmap") {

@ -288,23 +288,23 @@ TEST_CASE("test fft grads") {
astype(zeros({5, 5}), complex64),
astype(zeros({5, 5}), complex64))
.second;
CHECK_EQ(vjp_out.shape(), std::vector<int>{5, 5});
CHECK_EQ(vjp_out.shape(), Shape{5, 5});

vjp_out = vjp([](array x) { return fft::ifftn(x); },
astype(zeros({5, 5}), complex64),
astype(zeros({5, 5}), complex64))
.second;
CHECK_EQ(vjp_out.shape(), std::vector<int>{5, 5});
CHECK_EQ(vjp_out.shape(), Shape{5, 5});

vjp_out = vjp([](array x) { return fft::rfftn(x); },
zeros({5, 9}),
astype(zeros({5, 5}), complex64))
.second;
CHECK_EQ(vjp_out.shape(), std::vector<int>{5, 9});
CHECK_EQ(vjp_out.shape(), Shape{5, 9});

vjp_out = vjp([](array x) { return fft::irfftn(x); },
astype(zeros({5, 5}), complex64),
zeros({5, 8}))
.second;
CHECK_EQ(vjp_out.shape(), std::vector<int>{5, 5});
CHECK_EQ(vjp_out.shape(), Shape{5, 5});
}
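The shapes asserted above follow the standard real-FFT size rule: rfft along an axis of length n yields n/2 + 1 complex values, and irfft along an axis of length m yields 2*(m - 1) real values; hence rfftn on {5, 4} gives {5, 3} and irfftn on {5, 4} gives {5, 6}. Spelled out as plain arithmetic (not library code):

    #include <cassert>

    int rfft_out_len(int n) { return n / 2 + 1; }    // real -> complex
    int irfft_out_len(int m) { return 2 * (m - 1); } // complex -> real

    int main() {
      assert(rfft_out_len(4) == 3);  // rfftn({5, 4})  -> {5, 3}
      assert(irfft_out_len(4) == 6); // irfftn({5, 4}) -> {5, 6}
      assert(rfft_out_len(7) == 4);  // rfft(x, 7, 0) on {5, 5} -> {4, 5}
      return 0;
    }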
@ -129,18 +129,10 @@ TEST_CASE("[mlx.core.linalg.norm] double ord") {
CHECK_EQ(
norm(x, -1.0, std::vector<int>{1, 0}).item<float>(),
doctest::Approx(3.0));
CHECK_EQ(
norm(x, 1.0, std::vector<int>{0, 1}, true).shape(),
std::vector<int>{1, 1});
CHECK_EQ(
norm(x, 1.0, std::vector<int>{1, 0}, true).shape(),
std::vector<int>{1, 1});
CHECK_EQ(
norm(x, -1.0, std::vector<int>{0, 1}, true).shape(),
std::vector<int>{1, 1});
CHECK_EQ(
norm(x, -1.0, std::vector<int>{1, 0}, true).shape(),
std::vector<int>{1, 1});
CHECK_EQ(norm(x, 1.0, std::vector<int>{0, 1}, true).shape(), Shape{1, 1});
CHECK_EQ(norm(x, 1.0, std::vector<int>{1, 0}, true).shape(), Shape{1, 1});
CHECK_EQ(norm(x, -1.0, std::vector<int>{0, 1}, true).shape(), Shape{1, 1});
CHECK_EQ(norm(x, -1.0, std::vector<int>{1, 0}, true).shape(), Shape{1, 1});

CHECK_EQ(
norm(x, -1.0, std::vector<int>{-2, -1}, false).item<float>(),

@ -286,9 +278,9 @@ TEST_CASE("test SVD factorization") {
const auto& S = outs[1];
const auto& Vt = outs[2];

CHECK_EQ(U.shape(), std::vector<int>{5, 5});
CHECK_EQ(S.shape(), std::vector<int>{4});
CHECK_EQ(Vt.shape(), std::vector<int>{4, 4});
CHECK_EQ(U.shape(), Shape{5, 5});
CHECK_EQ(S.shape(), Shape{4});
CHECK_EQ(Vt.shape(), Shape{4, 4});

const auto U_slice = slice(U, {0, 0}, {U.shape(0), S.shape(0)});
@ -32,11 +32,11 @@ TEST_CASE("test save_safetensors") {
CHECK_EQ(dict.count("test2"), 1);
array test = dict.at("test");
CHECK_EQ(test.dtype(), float32);
CHECK_EQ(test.shape(), std::vector<int>({4}));
CHECK_EQ(test.shape(), Shape{4});
CHECK(array_equal(test, array({1.0, 2.0, 3.0, 4.0})).item<bool>());
array test2 = dict.at("test2");
CHECK_EQ(test2.dtype(), float32);
CHECK_EQ(test2.shape(), std::vector<int>({2, 2}));
CHECK_EQ(test2.shape(), Shape{2, 2});
CHECK(array_equal(test2, ones({2, 2})).item<bool>());
}
@ -15,13 +15,13 @@ using namespace mlx::core;
TEST_CASE("test copy") {
array x(1.0);
auto y = copy(x);
CHECK_EQ(y.shape(), std::vector<int>{});
CHECK_EQ(y.shape(), Shape{});
CHECK_NE(y.id(), x.id());
CHECK_EQ(y.item<float>(), 1.0f);

x = array({1, 2}, {2, 1});
y = copy(x);
CHECK_EQ(y.shape(), std::vector<int>{2, 1});
CHECK_EQ(y.shape(), Shape{2, 1});
CHECK_EQ(y.dtype(), int32);
CHECK_NE(y.id(), x.id());
CHECK(array_equal(y, x).item<bool>());

@ -29,37 +29,37 @@ TEST_CASE("test copy") {

TEST_CASE("test reshape") {
array x(1.0);
CHECK_EQ(reshape(x, {}).shape(), std::vector<int>{});
CHECK_EQ(reshape(x, {}).shape(), Shape{});
CHECK_THROWS_AS(reshape(x, {2}), std::invalid_argument);
auto y = reshape(x, {1, 1, 1});
CHECK_EQ(y.shape(), std::vector<int>{1, 1, 1});
CHECK_EQ(y.shape(), Shape{1, 1, 1});
y = reshape(x, {-1, 1, 1});
CHECK_EQ(y.shape(), std::vector<int>{1, 1, 1});
CHECK_EQ(y.shape(), Shape{1, 1, 1});
y = reshape(x, {1, 1, -1});
CHECK_EQ(y.shape(), std::vector<int>{1, 1, 1});
CHECK_EQ(y.shape(), Shape{1, 1, 1});
CHECK_THROWS_AS(reshape(x, {1, -1, -1}), std::invalid_argument);
CHECK_THROWS_AS(reshape(x, {2, -1}), std::invalid_argument);

x = zeros({2, 2, 2});
y = reshape(x, {8});
CHECK_EQ(y.shape(), std::vector<int>{8});
CHECK_EQ(y.shape(), Shape{8});
CHECK_THROWS_AS(reshape(x, {7}), std::invalid_argument);
y = reshape(x, {-1});
CHECK_EQ(y.shape(), std::vector<int>{8});
CHECK_EQ(y.shape(), Shape{8});
y = reshape(x, {-1, 2});
CHECK_EQ(y.shape(), std::vector<int>{4, 2});
CHECK_EQ(y.shape(), Shape{4, 2});
CHECK_THROWS_AS(reshape(x, {-1, 7}), std::invalid_argument);

// Works with empty array
x = array({});
y = reshape(x, {0, 0, 0});
CHECK_EQ(y.shape(), std::vector<int>{0, 0, 0});
CHECK_EQ(y.shape(), Shape{0, 0, 0});
y.eval();
CHECK_EQ(y.size(), 0);
CHECK_THROWS_AS(reshape(x, {}), std::invalid_argument);
CHECK_THROWS_AS(reshape(x, {1}), std::invalid_argument);
y = reshape(x, {1, 5, 0});
CHECK_EQ(y.shape(), std::vector<int>{1, 5, 0});
CHECK_EQ(y.shape(), Shape{1, 5, 0});

// Check that reshaping a transposed array doesn't result in a copy
x = reshape(arange(64), {2, 4, 8});

@ -138,15 +138,15 @@ TEST_CASE("test reshape") {

TEST_CASE("test flatten") {
array x = zeros({2, 3, 4});
CHECK_EQ(flatten(x).shape(), std::vector<int>({2 * 3 * 4}));
CHECK_EQ(flatten(x).shape(), Shape({2 * 3 * 4}));

CHECK_EQ(flatten(x, 1, 1).shape(), std::vector<int>({2, 3, 4}));
CHECK_EQ(flatten(x, 1, 2).shape(), std::vector<int>({2, 3 * 4}));
CHECK_EQ(flatten(x, 1, 3).shape(), std::vector<int>({2, 3 * 4}));
CHECK_EQ(flatten(x, 1, -1).shape(), std::vector<int>({2, 3 * 4}));
CHECK_EQ(flatten(x, -2, -1).shape(), std::vector<int>({2, 3 * 4}));
CHECK_EQ(flatten(x, -3, -1).shape(), std::vector<int>({2 * 3 * 4}));
CHECK_EQ(flatten(x, -4, -1).shape(), std::vector<int>({2 * 3 * 4}));
CHECK_EQ(flatten(x, 1, 1).shape(), Shape({2, 3, 4}));
CHECK_EQ(flatten(x, 1, 2).shape(), Shape({2, 3 * 4}));
CHECK_EQ(flatten(x, 1, 3).shape(), Shape({2, 3 * 4}));
CHECK_EQ(flatten(x, 1, -1).shape(), Shape({2, 3 * 4}));
CHECK_EQ(flatten(x, -2, -1).shape(), Shape({2, 3 * 4}));
CHECK_EQ(flatten(x, -3, -1).shape(), Shape({2 * 3 * 4}));
CHECK_EQ(flatten(x, -4, -1).shape(), Shape({2 * 3 * 4}));

// Check start > end throws
CHECK_THROWS(flatten(x, 2, 1));

@ -159,17 +159,17 @@ TEST_CASE("test flatten") {

// Check scalar flattens to 1D
x = array(1);
CHECK_EQ(flatten(x, -3, -1).shape(), std::vector<int>({1}));
CHECK_EQ(flatten(x, 0, 0).shape(), std::vector<int>({1}));
CHECK_EQ(flatten(x, -3, -1).shape(), Shape({1}));
CHECK_EQ(flatten(x, 0, 0).shape(), Shape({1}));
}

TEST_CASE("test squeeze and expand") {
array x = zeros({2, 1, 2, 1, 2, 1});
CHECK_EQ(squeeze(x).shape(), std::vector<int>{2, 2, 2});
CHECK_EQ(squeeze(x, {1, 3, 5}).shape(), std::vector<int>{2, 2, 2});
CHECK_EQ(squeeze(x, {-1, -3, -5}).shape(), std::vector<int>{2, 2, 2});
CHECK_EQ(squeeze(x, 1).shape(), std::vector<int>{2, 2, 1, 2, 1});
CHECK_EQ(squeeze(x, -1).shape(), std::vector<int>{2, 1, 2, 1, 2});
CHECK_EQ(squeeze(x).shape(), Shape{2, 2, 2});
CHECK_EQ(squeeze(x, {1, 3, 5}).shape(), Shape{2, 2, 2});
CHECK_EQ(squeeze(x, {-1, -3, -5}).shape(), Shape{2, 2, 2});
CHECK_EQ(squeeze(x, 1).shape(), Shape{2, 2, 1, 2, 1});
CHECK_EQ(squeeze(x, -1).shape(), Shape{2, 1, 2, 1, 2});

CHECK_THROWS(squeeze(x, 0));
CHECK_THROWS(squeeze(x, 2));

@ -177,13 +177,13 @@ TEST_CASE("test squeeze and expand") {
CHECK_THROWS(squeeze(x, {1, 3, -3}));

x = zeros({2, 2});
CHECK_EQ(expand_dims(x, 0).shape(), std::vector<int>{1, 2, 2});
CHECK_EQ(expand_dims(x, -1).shape(), std::vector<int>{2, 2, 1});
CHECK_EQ(expand_dims(x, 1).shape(), std::vector<int>{2, 1, 2});
CHECK_EQ(expand_dims(x, {0, 1, 2}).shape(), std::vector<int>{1, 1, 1, 2, 2});
CHECK_EQ(expand_dims(x, 0).shape(), Shape{1, 2, 2});
CHECK_EQ(expand_dims(x, -1).shape(), Shape{2, 2, 1});
CHECK_EQ(expand_dims(x, 1).shape(), Shape{2, 1, 2});
CHECK_EQ(expand_dims(x, {0, 1, 2}).shape(), Shape{1, 1, 1, 2, 2});
CHECK_EQ(
expand_dims(x, {0, 1, 2, 5, 6, 7}).shape(),
std::vector<int>{1, 1, 1, 2, 2, 1, 1, 1});
Shape{1, 1, 1, 2, 2, 1, 1, 1});

CHECK_THROWS(expand_dims(x, 3));
CHECK_THROWS(expand_dims(x, -4));

@ -210,7 +210,7 @@ TEST_CASE("test slice") {

out = slice(x, {1}, {0});
eval(out);
CHECK_EQ(out.shape(), std::vector<int>{0});
CHECK_EQ(out.shape(), Shape{0});

out = slice(x, {0}, {1}, {1});
CHECK_EQ(out.item<int>(), 3);

@ -353,7 +353,7 @@ TEST_CASE("test split") {
out = split(x, 3, -1);
CHECK_EQ(out.size(), 3);
for (auto i = 0; i < 3; ++i) {
CHECK_EQ(out[i].shape(), std::vector<int>{1});
CHECK_EQ(out[i].shape(), Shape{1});
CHECK_EQ(out[i].dtype(), int32);
CHECK_EQ(out[i].item<int>(), i);
}

@ -370,13 +370,13 @@ TEST_CASE("test split") {
x = zeros({8, 12});
out = split(x, 2);
CHECK_EQ(out.size(), 2);
CHECK_EQ(out[0].shape(), std::vector<int>{4, 12});
CHECK_EQ(out[1].shape(), std::vector<int>{4, 12});
CHECK_EQ(out[0].shape(), Shape{4, 12});
CHECK_EQ(out[1].shape(), Shape{4, 12});
out = split(x, 3, 1);
CHECK_EQ(out.size(), 3);
CHECK_EQ(out[0].shape(), std::vector<int>{8, 4});
CHECK_EQ(out[1].shape(), std::vector<int>{8, 4});
CHECK_EQ(out[2].shape(), std::vector<int>{8, 4});
CHECK_EQ(out[0].shape(), Shape{8, 4});
CHECK_EQ(out[1].shape(), Shape{8, 4});
CHECK_EQ(out[2].shape(), Shape{8, 4});

out = split(x, std::vector<int>{});
CHECK_EQ(out.size(), 1);

@ -384,25 +384,25 @@ TEST_CASE("test split") {

out = split(x, {3, 7});
CHECK_EQ(out.size(), 3);
CHECK_EQ(out[0].shape(), std::vector<int>{3, 12});
CHECK_EQ(out[1].shape(), std::vector<int>{4, 12});
CHECK_EQ(out[2].shape(), std::vector<int>{1, 12});
CHECK_EQ(out[0].shape(), Shape{3, 12});
CHECK_EQ(out[1].shape(), Shape{4, 12});
CHECK_EQ(out[2].shape(), Shape{1, 12});

out = split(x, std::vector<int>{20});
CHECK_EQ(out.size(), 2);
CHECK_EQ(out[0].shape(), std::vector<int>{8, 12});
CHECK_EQ(out[1].shape(), std::vector<int>{0, 12});
CHECK_EQ(out[0].shape(), Shape{8, 12});
CHECK_EQ(out[1].shape(), Shape{0, 12});

// Negative indices
out = split(x, std::vector<int>{-5});
CHECK_EQ(out[0].shape(), std::vector<int>{3, 12});
CHECK_EQ(out[1].shape(), std::vector<int>{5, 12});
CHECK_EQ(out[0].shape(), Shape{3, 12});
CHECK_EQ(out[1].shape(), Shape{5, 12});

// Different axis
out = split(x, std::vector<int>{2, 8}, 1);
CHECK_EQ(out[0].shape(), std::vector<int>{8, 2});
CHECK_EQ(out[1].shape(), std::vector<int>{8, 6});
CHECK_EQ(out[2].shape(), std::vector<int>{8, 4});
CHECK_EQ(out[0].shape(), Shape{8, 2});
CHECK_EQ(out[1].shape(), Shape{8, 6});
CHECK_EQ(out[2].shape(), Shape{8, 4});

// Out of order indices
x = arange(5);
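The index form of split cuts at the given positions, so split on 8 rows at {3, 7} yields ranges [0, 3), [3, 7), [7, 8), i.e. shapes {3, 12}, {4, 12}, {1, 12}, and a negative cut like -5 counts from the end (8 - 5 = 3). A standalone sketch of the size rule (out-of-range cuts, which the tests show being clamped, are left out here):

    #include <cassert>
    #include <vector>

    // Sizes of the pieces produced by splitting `n` rows at `cuts`.
    std::vector<int> split_sizes(int n, std::vector<int> cuts) {
      std::vector<int> sizes;
      int prev = 0;
      for (int c : cuts) {
        if (c < 0) c += n;       // a negative cut counts from the end
        sizes.push_back(c - prev);
        prev = c;
      }
      sizes.push_back(n - prev); // remainder after the last cut
      return sizes;
    }

    int main() {
      assert((split_sizes(8, {3, 7}) == std::vector<int>{3, 4, 1}));
      assert((split_sizes(8, {-5}) == std::vector<int>{3, 5}));
      return 0;
    }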
@ -420,18 +420,18 @@ TEST_CASE("test swap and move axes") {

a = zeros({2});
CHECK_THROWS(swapaxes(a, 0, 1));
CHECK_EQ(swapaxes(a, 0, 0).shape(), std::vector<int>{2});
CHECK_EQ(swapaxes(a, -1, -1).shape(), std::vector<int>{2});
CHECK_EQ(swapaxes(a, 0, 0).shape(), Shape{2});
CHECK_EQ(swapaxes(a, -1, -1).shape(), Shape{2});

a = zeros({2, 3, 4});
CHECK_THROWS(swapaxes(a, 0, -4));
CHECK_THROWS(swapaxes(a, 0, 3));
CHECK_THROWS(swapaxes(a, 3, 0));
CHECK_THROWS(swapaxes(a, -4, 0));
CHECK_EQ(swapaxes(a, 0, 2).shape(), std::vector<int>{4, 3, 2});
CHECK_EQ(swapaxes(a, 0, 1).shape(), std::vector<int>{3, 2, 4});
CHECK_EQ(swapaxes(a, 0, -1).shape(), std::vector<int>{4, 3, 2});
CHECK_EQ(swapaxes(a, -2, 2).shape(), std::vector<int>{2, 4, 3});
CHECK_EQ(swapaxes(a, 0, 2).shape(), Shape{4, 3, 2});
CHECK_EQ(swapaxes(a, 0, 1).shape(), Shape{3, 2, 4});
CHECK_EQ(swapaxes(a, 0, -1).shape(), Shape{4, 3, 2});
CHECK_EQ(swapaxes(a, -2, 2).shape(), Shape{2, 4, 3});

// Test moveaxis
a = array(0.0);

@ -439,36 +439,36 @@ TEST_CASE("test swap and move axes") {

a = zeros({2});
CHECK_THROWS(moveaxis(a, 0, 1));
CHECK_EQ(moveaxis(a, 0, 0).shape(), std::vector<int>{2});
CHECK_EQ(moveaxis(a, -1, -1).shape(), std::vector<int>{2});
CHECK_EQ(moveaxis(a, 0, 0).shape(), Shape{2});
CHECK_EQ(moveaxis(a, -1, -1).shape(), Shape{2});

a = zeros({2, 3, 4});
CHECK_THROWS(moveaxis(a, 0, -4));
CHECK_THROWS(moveaxis(a, 0, 3));
CHECK_THROWS(moveaxis(a, 3, 0));
CHECK_THROWS(moveaxis(a, -4, 0));
CHECK_EQ(moveaxis(a, 0, 2).shape(), std::vector<int>{3, 4, 2});
CHECK_EQ(moveaxis(a, 0, 1).shape(), std::vector<int>{3, 2, 4});
CHECK_EQ(moveaxis(a, 0, -1).shape(), std::vector<int>{3, 4, 2});
CHECK_EQ(moveaxis(a, -2, 2).shape(), std::vector<int>{2, 4, 3});
CHECK_EQ(moveaxis(a, 0, 2).shape(), Shape{3, 4, 2});
CHECK_EQ(moveaxis(a, 0, 1).shape(), Shape{3, 2, 4});
CHECK_EQ(moveaxis(a, 0, -1).shape(), Shape{3, 4, 2});
CHECK_EQ(moveaxis(a, -2, 2).shape(), Shape{2, 4, 3});
}

TEST_CASE("test transpose") {
array x(1);
auto y = transpose(x);
CHECK_EQ(y.shape(), std::vector<int>{});
CHECK_EQ(y.shape(), Shape{});
CHECK_EQ(y.item<int>(), 1);
CHECK_THROWS_AS(transpose(x, {0}), std::invalid_argument);
CHECK_THROWS_AS(transpose(x, {1}), std::invalid_argument);

x = array({1}, {1});
y = transpose(x);
CHECK_EQ(y.shape(), std::vector<int>{1});
CHECK_EQ(y.shape(), Shape{1});
CHECK_EQ(y.item<int>(), 1);

// Negative indices
y = transpose(x, {-1});
CHECK_EQ(y.shape(), std::vector<int>{1});
CHECK_EQ(y.shape(), Shape{1});
CHECK_EQ(y.item<int>(), 1);

CHECK_THROWS_AS(transpose(x, {1}), std::invalid_argument);

@ -477,24 +477,24 @@ TEST_CASE("test transpose") {
// Works with empty array
x = array({});
y = transpose(x);
CHECK_EQ(y.shape(), std::vector<int>{0});
CHECK_EQ(y.shape(), Shape{0});
y.eval();
CHECK_EQ(y.size(), 0);

x = array({1, 2, 3, 4, 5, 6}, {2, 3});
y = transpose(x);
CHECK_EQ(y.shape(), std::vector<int>{3, 2});
CHECK_EQ(y.shape(), Shape{3, 2});
y = transpose(x, {-1, 0});
CHECK_EQ(y.shape(), std::vector<int>{3, 2});
CHECK_EQ(y.shape(), Shape{3, 2});
y = transpose(x, {-1, -2});
CHECK_EQ(y.shape(), std::vector<int>{3, 2});
CHECK_EQ(y.shape(), Shape{3, 2});
y.eval();
CHECK(array_equal(y, array({1, 4, 2, 5, 3, 6}, {3, 2})).item<bool>());
y = transpose(x, {0, 1});
CHECK_EQ(y.shape(), std::vector<int>{2, 3});
CHECK_EQ(y.shape(), Shape{2, 3});
CHECK(array_equal(y, x).item<bool>());
y = transpose(x, {0, -1});
CHECK_EQ(y.shape(), std::vector<int>{2, 3});
CHECK_EQ(y.shape(), Shape{2, 3});
CHECK(array_equal(y, x).item<bool>());

CHECK_THROWS_AS(transpose(x, {}), std::invalid_argument);

@ -505,19 +505,19 @@ TEST_CASE("test transpose") {

x = array({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}, {2, 3, 2});
y = transpose(x);
CHECK_EQ(y.shape(), std::vector<int>{2, 3, 2});
CHECK_EQ(y.shape(), Shape{2, 3, 2});
auto expected = array({1, 7, 3, 9, 5, 11, 2, 8, 4, 10, 6, 12}, {2, 3, 2});
CHECK(array_equal(y, expected).item<bool>());

y = transpose(x, {0, 1, 2});
CHECK_EQ(y.shape(), std::vector<int>{2, 3, 2});
CHECK_EQ(y.shape(), Shape{2, 3, 2});
CHECK(array_equal(y, x).item<bool>());
y = transpose(x, {1, 0, 2});
CHECK_EQ(y.shape(), std::vector<int>{3, 2, 2});
CHECK_EQ(y.shape(), Shape{3, 2, 2});
expected = array({1, 2, 7, 8, 3, 4, 9, 10, 5, 6, 11, 12}, {3, 2, 2});
CHECK(array_equal(y, expected).item<bool>());
y = transpose(x, {0, 2, 1});
CHECK_EQ(y.shape(), std::vector<int>{2, 2, 3});
CHECK_EQ(y.shape(), Shape{2, 2, 3});
expected = array({1, 3, 5, 2, 4, 6, 7, 9, 11, 8, 10, 12}, {2, 2, 3});
CHECK(array_equal(y, expected).item<bool>());
|
||||
array y({});
|
||||
auto z = x == y;
|
||||
CHECK_EQ(z.dtype(), bool_);
|
||||
CHECK_EQ(z.shape(), std::vector<int>{0});
|
||||
CHECK_EQ(z.shape(), Shape{0});
|
||||
}
|
||||
|
||||
// Basic cases
|
||||
@ -631,7 +631,7 @@ TEST_CASE("test comparison ops") {
|
||||
auto y = zeros({2, 1});
|
||||
auto z = equal(x, y);
|
||||
CHECK_EQ(z.dtype(), bool_);
|
||||
CHECK_EQ(z.shape(), std::vector<int>{2, 2});
|
||||
CHECK_EQ(z.shape(), Shape{2, 2});
|
||||
auto expected = array({true, true, true, true}, {2, 2});
|
||||
CHECK(array_equal(z, expected).item<bool>());
|
||||
|
||||
@ -639,7 +639,7 @@ TEST_CASE("test comparison ops") {
|
||||
y = array({1.0, 2.0}, {2, 1});
|
||||
z = equal(x, y);
|
||||
CHECK_EQ(z.dtype(), bool_);
|
||||
CHECK_EQ(z.shape(), std::vector<int>{2, 2});
|
||||
CHECK_EQ(z.shape(), Shape{2, 2});
|
||||
expected = array({true, false, false, true}, {2, 2});
|
||||
CHECK(array_equal(z, expected).item<bool>());
|
||||
|
||||
@ -769,15 +769,15 @@ TEST_CASE("test reduction ops") {
|
||||
CHECK_THROWS_AS(sum(x, 0), std::out_of_range);
|
||||
CHECK_THROWS_AS(sum(x, -1), std::out_of_range);
|
||||
out = sum(x, std::vector<int>{});
|
||||
CHECK_EQ(out.shape(), std::vector<int>{});
|
||||
CHECK_EQ(out.shape(), Shape{});
|
||||
CHECK_EQ(out.size(), 1);
|
||||
|
||||
x = array({});
|
||||
out = sum(x);
|
||||
CHECK_EQ(out.shape(), std::vector<int>{});
|
||||
CHECK_EQ(out.shape(), Shape{});
|
||||
CHECK_EQ(out.size(), 1);
|
||||
out = sum(x, true);
|
||||
CHECK_EQ(out.shape(), std::vector<int>{1});
|
||||
CHECK_EQ(out.shape(), Shape{1});
|
||||
out = sum(x, std::vector<int>{});
|
||||
CHECK_EQ(out.shape(), x.shape());
|
||||
|
||||
@ -788,7 +788,7 @@ TEST_CASE("test reduction ops") {
|
||||
CHECK_EQ(out.ndim(), 0);
|
||||
out = sum(x, -1, true);
|
||||
CHECK_EQ(out.ndim(), 1);
|
||||
CHECK_EQ(out.shape(), std::vector<int>{1});
|
||||
CHECK_EQ(out.shape(), Shape{1});
|
||||
|
||||
CHECK_THROWS_AS(sum(x, 1), std::out_of_range);
|
||||
CHECK_THROWS_AS(sum(x, -2), std::out_of_range);
|
||||
@ -797,21 +797,21 @@ TEST_CASE("test reduction ops") {
|
||||
|
||||
x = zeros({2, 3, 4});
|
||||
out = sum(x, {0, 2});
|
||||
CHECK_EQ(out.shape(), std::vector<int>{3});
|
||||
CHECK_EQ(out.shape(), Shape{3});
|
||||
out = sum(x, std::vector<int>{});
|
||||
CHECK_EQ(out.shape(), x.shape());
|
||||
|
||||
out = sum(x, {0, -1});
|
||||
CHECK_EQ(out.shape(), std::vector<int>{3});
|
||||
CHECK_EQ(out.shape(), Shape{3});
|
||||
|
||||
out = sum(x, {0, -1}, true);
|
||||
CHECK_EQ(out.shape(), std::vector<int>{1, 3, 1});
|
||||
CHECK_EQ(out.shape(), Shape{1, 3, 1});
|
||||
|
||||
out = sum(x, true);
|
||||
CHECK_EQ(out.shape(), std::vector<int>{1, 1, 1});
|
||||
CHECK_EQ(out.shape(), Shape{1, 1, 1});
|
||||
|
||||
out = sum(x);
|
||||
CHECK_EQ(out.shape(), std::vector<int>{});
|
||||
CHECK_EQ(out.shape(), Shape{});
|
||||
|
||||
CHECK_THROWS_AS(sum(x, 3), std::out_of_range);
|
||||
CHECK_THROWS_AS(sum(x, -4), std::out_of_range);
|
||||
@ -986,7 +986,7 @@ TEST_CASE("test reduction ops") {
|
||||
std::vector<float> nums = {0.0f, 1.0f, 2.0f, 3.0f};
|
||||
x = array(nums.data(), {2, 2});
|
||||
auto y = logsumexp(x, {0, 1}, true);
|
||||
CHECK_EQ(y.shape(), std::vector<int>{1, 1});
|
||||
CHECK_EQ(y.shape(), Shape{1, 1});
|
||||
auto result = std::log(
|
||||
std::exp(nums[0]) + std::exp(nums[1]) + std::exp(nums[2]) +
|
||||
std::exp(nums[3]));
|
||||
@ -1594,7 +1594,7 @@ TEST_CASE("test arithmetic binary ops") {
|
||||
x = array({1.0, 2.0, 3.0}, {1, 3});
|
||||
y = array({1.0, 2.0, 3.0}, {1, 3});
|
||||
z = add(x, y);
|
||||
CHECK_EQ(z.shape(), std::vector<int>{1, 3});
|
||||
CHECK_EQ(z.shape(), Shape{1, 3});
|
||||
auto eq = array_equal(z, array({2.0, 4.0, 6.0}, {1, 3}));
|
||||
CHECK(eq.item<bool>());
|
||||
|
||||
@ -1626,13 +1626,13 @@ TEST_CASE("test arithmetic binary ops") {
|
||||
x = array({1.0, 2.0}, {1, 2});
|
||||
y = array({1.0, 2.0}, {2, 1});
|
||||
z = add(x, y);
|
||||
CHECK_EQ(z.shape(), std::vector<int>{2, 2});
|
||||
CHECK_EQ(z.shape(), Shape{2, 2});
|
||||
eq = array_equal(z, array({2.0, 3.0, 3.0, 4.0}, {2, 2}));
|
||||
CHECK(eq.item<bool>());
|
||||
|
||||
x = ones({3, 2, 1});
|
||||
z = x + 2.0;
|
||||
CHECK_EQ(z.shape(), std::vector<int>{3, 2, 1});
|
||||
CHECK_EQ(z.shape(), Shape{3, 2, 1});
|
||||
eq = array_equal(z, array({3.0, 3.0, 3.0, 3.0, 3.0, 3.0}, {3, 2, 1}));
|
||||
CHECK(eq.item<bool>());
|
||||
|
||||
@ -1642,7 +1642,7 @@ TEST_CASE("test arithmetic binary ops") {
|
||||
z = x + y;
|
||||
z.eval();
|
||||
CHECK_EQ(z.size(), 0);
|
||||
CHECK_EQ(z.shape(), std::vector<int>{0});
|
||||
CHECK_EQ(z.shape(), Shape{0});
|
||||
|
||||
// Check subtraction
|
||||
x = array({3, 2, 1});
|
||||
@ -1725,46 +1725,46 @@ TEST_CASE("test arithmetic binary ops") {
|
||||
|
||||
TEST_CASE("test broadcast") {
|
||||
auto s = broadcast_shapes({1}, {1, 2});
|
||||
CHECK_EQ(s, std::vector<int>{1, 2});
|
||||
CHECK_EQ(s, Shape{1, 2});
|
||||
|
||||
s = broadcast_shapes({1, 2}, {1});
|
||||
CHECK_EQ(s, std::vector<int>{1, 2});
|
||||
CHECK_EQ(s, Shape{1, 2});
|
||||
|
||||
s = broadcast_shapes({2, 2}, {});
|
||||
CHECK_EQ(s, std::vector<int>{2, 2});
|
||||
CHECK_EQ(s, Shape{2, 2});
|
||||
|
||||
s = broadcast_shapes({}, {1, 1});
|
||||
CHECK_EQ(s, std::vector<int>{1, 1});
|
||||
CHECK_EQ(s, Shape{1, 1});
|
||||
|
||||
s = broadcast_shapes({1, 2, 1}, {2});
|
||||
CHECK_EQ(s, std::vector<int>{1, 2, 2});
|
||||
CHECK_EQ(s, Shape{1, 2, 2});
|
||||
|
||||
s = broadcast_shapes({2}, {1, 2, 1});
|
||||
CHECK_EQ(s, std::vector<int>{1, 2, 2});
|
||||
CHECK_EQ(s, Shape{1, 2, 2});
|
||||
|
||||
s = broadcast_shapes({2, 2, 2}, {1, 2, 1});
|
||||
CHECK_EQ(s, std::vector<int>{2, 2, 2});
|
||||
CHECK_EQ(s, Shape{2, 2, 2});
|
||||
|
||||
s = broadcast_shapes({2, 2, 2, 1}, {1, 2, 1});
|
||||
CHECK_EQ(s, std::vector<int>{2, 2, 2, 1});
|
||||
CHECK_EQ(s, Shape{2, 2, 2, 1});
|
||||
|
||||
s = broadcast_shapes({0}, {0, 0});
|
||||
CHECK_EQ(s, std::vector<int>{0, 0});
|
||||
CHECK_EQ(s, Shape{0, 0});
|
||||
|
||||
CHECK_EQ(broadcast_shapes({}, {0}), std::vector<int>{0});
|
||||
CHECK_EQ(broadcast_shapes({}, {0}), Shape{0});
|
||||
|
||||
s = broadcast_shapes({5, 0}, {0, 5, 0});
|
||||
CHECK_EQ(s, std::vector<int>{0, 5, 0});
|
||||
CHECK_EQ(s, Shape{0, 5, 0});
|
||||
|
||||
CHECK_EQ(broadcast_shapes({}, {0}), std::vector<int>{0});
|
||||
CHECK_EQ(broadcast_shapes({1}, {0}), std::vector<int>{0});
|
||||
CHECK_EQ(broadcast_shapes({1}, {0}), std::vector<int>{0});
|
||||
CHECK_EQ(broadcast_shapes({1}, {0, 0}), std::vector<int>{0, 0});
|
||||
CHECK_EQ(broadcast_shapes({1, 1}, {0}), std::vector<int>{1, 0});
|
||||
CHECK_EQ(broadcast_shapes({1, 1}, {0, 0}), std::vector<int>{0, 0});
|
||||
CHECK_EQ(broadcast_shapes({2, 1}, {1, 0}), std::vector<int>{2, 0});
|
||||
CHECK_EQ(broadcast_shapes({2, 1}, {2, 0}), std::vector<int>{2, 0});
|
||||
CHECK_EQ(broadcast_shapes({2, 1}, {1, 2, 0}), std::vector<int>{1, 2, 0});
|
||||
CHECK_EQ(broadcast_shapes({}, {0}), Shape{0});
|
||||
CHECK_EQ(broadcast_shapes({1}, {0}), Shape{0});
|
||||
CHECK_EQ(broadcast_shapes({1}, {0}), Shape{0});
|
||||
CHECK_EQ(broadcast_shapes({1}, {0, 0}), Shape{0, 0});
|
||||
CHECK_EQ(broadcast_shapes({1, 1}, {0}), Shape{1, 0});
|
||||
CHECK_EQ(broadcast_shapes({1, 1}, {0, 0}), Shape{0, 0});
|
||||
CHECK_EQ(broadcast_shapes({2, 1}, {1, 0}), Shape{2, 0});
|
||||
CHECK_EQ(broadcast_shapes({2, 1}, {2, 0}), Shape{2, 0});
|
||||
CHECK_EQ(broadcast_shapes({2, 1}, {1, 2, 0}), Shape{1, 2, 0});
|
||||
CHECK_THROWS_AS(broadcast_shapes({2}, {0}), std::invalid_argument);
|
||||
CHECK_THROWS_AS(broadcast_shapes({2, 1}, {0, 0}), std::invalid_argument);
|
||||
|
||||
@ -1778,19 +1778,19 @@ TEST_CASE("test broadcast") {
|
||||
CHECK_EQ(broadcast_to(x, {1, 1}).item<float>(), 2.3f);
|
||||
|
||||
x = broadcast_to(x, {5, 1});
|
||||
CHECK_EQ(x.shape(), std::vector<int>{5, 1});
|
||||
CHECK_EQ(x.shape(), Shape{5, 1});
|
||||
x.eval();
|
||||
CHECK_EQ(x.strides(), std::vector<size_t>{0, 0});
|
||||
CHECK_EQ(x.strides(), Strides{0, 0});
|
||||
|
||||
CHECK_THROWS_AS(broadcast_to(x, {1, 5}), std::invalid_argument);
|
||||
x = broadcast_to(x, {5, 5});
|
||||
CHECK_EQ(x.shape(), std::vector<int>{5, 5});
|
||||
CHECK_EQ(x.shape(), Shape{5, 5});
|
||||
|
||||
x = zeros({2, 1, 2});
|
||||
x = broadcast_to(x, {4, 2, 1, 2});
|
||||
CHECK_EQ(x.shape(), std::vector<int>{4, 2, 1, 2});
|
||||
CHECK_EQ(x.shape(), Shape{4, 2, 1, 2});
|
||||
x.eval();
|
||||
CHECK_EQ(x.strides(), std::vector<size_t>{0, 2, 0, 1});
|
||||
CHECK_EQ(x.strides(), Strides{0, 2, 0, 1});
|
||||
|
||||
// Broadcast on empty arrays works as expected
|
||||
x = array({});
|
||||
@ -1801,29 +1801,29 @@ TEST_CASE("test broadcast") {
|
||||
auto y = broadcast_to(x, {0});
|
||||
eval(y);
|
||||
CHECK_EQ(y.size(), 0);
|
||||
CHECK_EQ(y.shape(), std::vector<int>{0});
|
||||
CHECK_EQ(y.shape(), Shape{0});
|
||||
|
||||
x = array({1, 2}, {2, 1});
|
||||
y = broadcast_to(x, {2, 0});
|
||||
eval(y);
|
||||
CHECK_EQ(y.size(), 0);
|
||||
CHECK_EQ(y.shape(), std::vector<int>{2, 0});
|
||||
CHECK_EQ(y.shape(), Shape{2, 0});
|
||||
|
||||
// Check repeat application works
|
||||
x = zeros({2});
|
||||
x = broadcast_to(broadcast_to(x, {2, 2}), {2, 2});
|
||||
CHECK_EQ(x.shape(), std::vector<int>{2, 2});
|
||||
CHECK_EQ(x.shape(), Shape{2, 2});
|
||||
x.eval();
|
||||
CHECK_EQ(x.strides(), std::vector<size_t>{0, 1});
|
||||
CHECK_EQ(x.strides(), Strides{0, 1});
|
||||
x = broadcast_to(broadcast_to(x, {2, 2}), {2, 2, 2});
|
||||
CHECK_EQ(x.shape(), std::vector<int>{2, 2, 2});
|
||||
CHECK_EQ(x.shape(), Shape{2, 2, 2});
|
||||
x.eval();
|
||||
CHECK_EQ(x.strides(), std::vector<size_t>{0, 0, 1});
|
||||
CHECK_EQ(x.strides(), Strides{0, 0, 1});
|
||||
|
||||
// Broadcast on transposed array works
|
||||
x = array({0, 1, 2, 3, 4, 5}, {2, 3});
|
||||
x = broadcast_to(transpose(x), {2, 3, 2});
|
||||
CHECK_EQ(x.shape(), std::vector<int>{2, 3, 2});
|
||||
CHECK_EQ(x.shape(), Shape{2, 3, 2});
|
||||
y = broadcast_to(array({0, 3, 1, 4, 2, 5}, {3, 2}), {2, 3, 2});
|
||||
CHECK(array_equal(x, y).item<bool>());
|
||||
|
||||
@ -1867,16 +1867,16 @@ TEST_CASE("test gather") {
auto x = arange(20);
auto y = arange(10);
auto out = gather(x, y, 0, {1});
CHECK_EQ(out.shape(), std::vector<int>{10, 1});
CHECK_EQ(out.shape(), Shape{10, 1});
CHECK(array_equal(reshape(out, {-1}), y).item<bool>());

out = gather(x, array({15}, uint32), 0, {1});
CHECK_EQ(out.shape(), std::vector<int>{1, 1});
CHECK_EQ(out.shape(), Shape{1, 1});
CHECK_EQ(out.item<int32_t>(), 15);

// No index gather works
out = gather(x, {}, std::vector<int>{}, {10});
CHECK_EQ(out.shape(), std::vector<int>{10});
CHECK_EQ(out.shape(), Shape{10});
CHECK(array_equal(out, arange(10)).item<bool>());

// Basic test of correctness with 2D input

@ -1884,13 +1884,13 @@ TEST_CASE("test gather") {
x = reshape(x, {4, 32});
y = array({0, 1}, uint32);
out = gather(x, y, 0, {1, 32});
CHECK_EQ(out.shape(), std::vector<int>{2, 1, 32});
CHECK_EQ(out.shape(), Shape{2, 1, 32});
CHECK(array_equal(reshape(out, {64}), arange(64)).item<bool>());

x = reshape(x, {64, 2});
y = array({0}, uint32);
out = gather(x, y, 0, {64, 1});
CHECK_EQ(out.shape(), std::vector<int>{1, 64, 1});
CHECK_EQ(out.shape(), Shape{1, 64, 1});
CHECK(array_equal(out, reshape(arange(0, 128, 2), {1, 64, 1})).item<bool>());

// Basic test of correctness with 3D input

@ -1898,7 +1898,7 @@ TEST_CASE("test gather") {
x = reshape(x, {8, 4, 8});
y = array({0}, uint32);
out = gather(x, y, 0, {8, 1, 1});
CHECK_EQ(out.shape(), std::vector<int>{1, 8, 1, 1});
CHECK_EQ(out.shape(), Shape{1, 8, 1, 1});
CHECK(
array_equal(out, reshape(arange(0, 256, 32), {1, 8, 1, 1})).item<bool>());
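The assertions above all follow one rule: a gather's output shape is the index shape followed by the slice sizes, so ten indices with slice size {1} give {10, 1} and two indices with slice size {1, 32} give {2, 1, 32}. A minimal sketch of the shape computation, under the same assumed Shape alias:

    #include <cstdint>
    #include <vector>
    using Shape = std::vector<int32_t>; // assumption

    // Output shape of a gather: index shape ++ slice sizes.
    Shape gather_out_shape(const Shape& idx_shape, const Shape& slice_sizes) {
      Shape out = idx_shape;
      out.insert(out.end(), slice_sizes.begin(), slice_sizes.end());
      return out;
    }

    // gather_out_shape({10}, {1})      -> {10, 1}
    // gather_out_shape({2}, {1, 32})   -> {2, 1, 32}
    // gather_out_shape({1}, {8, 1, 1}) -> {1, 8, 1, 1}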
@ -1913,10 +1913,10 @@ TEST_CASE("test take") {
// Empty takes
auto empty = astype(array({}), int32);
auto z = take(array({1}), empty);
CHECK_EQ(z.shape(), std::vector<int>{0});
CHECK_EQ(z.shape(), Shape{0});
empty = reshape(empty, {1, 0, 1});
z = take(array({1}), empty);
CHECK_EQ(z.shape(), std::vector<int>{1, 0, 1});
CHECK_EQ(z.shape(), Shape{1, 0, 1});

CHECK_THROWS(take(array({}), array(1)));

@ -1926,7 +1926,7 @@ TEST_CASE("test take") {
// Take a single row
auto x = reshape(arange(256), {8, 4, 8});
z = take(x, array({0}, uint32), 0);
CHECK_EQ(z.shape(), std::vector<int>{1, 4, 8});
CHECK_EQ(z.shape(), Shape{1, 4, 8});
z = reshape(z, {32});
CHECK(array_equal(z, arange(32)).item<bool>());

@ -2017,12 +2017,12 @@ TEST_CASE("test take along axis") {

out = take_along_axis(a, reshape(array({1}), {1, 1}), 0);
eval(out); // Make sure it runs
CHECK_EQ(out.shape(), std::vector<int>{1, 0});
CHECK_EQ(out.shape(), Shape{1, 0});

auto inds = reshape(astype(array({}), int32), {1, 0});
out = take_along_axis(a, inds, 0);
eval(out); // Make sure it runs
CHECK_EQ(out.shape(), std::vector<int>{1, 0});
CHECK_EQ(out.shape(), Shape{1, 0});

a = array({1, 2, 3, 4}, {2, 2});
inds = array({0, 1}, {1, 2});

@ -2084,7 +2084,7 @@ TEST_CASE("test put along axis") {
auto inds = reshape(astype(array({}), int32), {1, 0});
out = take_along_axis(a, inds, 0);
eval(out); // Make sure it runs
CHECK_EQ(out.shape(), std::vector<int>{1, 0});
CHECK_EQ(out.shape(), Shape{1, 0});

a = array({1, 2, 3, 4}, {2, 2});
inds = array({0, 1}, {1, 2});

@ -2506,9 +2506,9 @@ TEST_CASE("test scan op") {

TEST_CASE("test pad") {
auto x = zeros({1, 2, 3});
CHECK_EQ(pad(x, 1).shape(), std::vector<int>{3, 4, 5});
CHECK_EQ(pad(x, {0, 1}).shape(), std::vector<int>{2, 3, 4});
CHECK_EQ(pad(x, {{1, 1}, {1, 2}, {3, 1}}).shape(), std::vector<int>{3, 5, 7});
CHECK_EQ(pad(x, 1).shape(), Shape{3, 4, 5});
CHECK_EQ(pad(x, {0, 1}).shape(), Shape{2, 3, 4});
CHECK_EQ(pad(x, {{1, 1}, {1, 2}, {3, 1}}).shape(), Shape{3, 5, 7});

x = array({1.0f, 2.0f, 3.0f, 4.0f}, {2, 2});
auto padded_x = pad(x, 1);

@ -2647,20 +2647,20 @@ TEST_CASE("test where") {

TEST_CASE("test stack") {
auto x = array({});
CHECK_EQ(stack({x}, 0).shape(), std::vector<int>{1, 0});
CHECK_EQ(stack({x}, 1).shape(), std::vector<int>{0, 1});
CHECK_EQ(stack({x}, 0).shape(), Shape{1, 0});
CHECK_EQ(stack({x}, 1).shape(), Shape{0, 1});

x = array({1, 2, 3}, {3});
CHECK_EQ(stack({x}, 0).shape(), std::vector<int>{1, 3});
CHECK_EQ(stack({x}, 1).shape(), std::vector<int>{3, 1});
CHECK_EQ(stack({x}, 0).shape(), Shape{1, 3});
CHECK_EQ(stack({x}, 1).shape(), Shape{3, 1});

auto y = array({4, 5, 6}, {3});
auto z = std::vector<array>{x, y};
CHECK_EQ(stack(z).shape(), std::vector<int>{2, 3});
CHECK_EQ(stack(z, 0).shape(), std::vector<int>{2, 3});
CHECK_EQ(stack(z, 1).shape(), std::vector<int>{3, 2});
CHECK_EQ(stack(z, -1).shape(), std::vector<int>{3, 2});
CHECK_EQ(stack(z, -2).shape(), std::vector<int>{2, 3});
CHECK_EQ(stack(z).shape(), Shape{2, 3});
CHECK_EQ(stack(z, 0).shape(), Shape{2, 3});
CHECK_EQ(stack(z, 1).shape(), Shape{3, 2});
CHECK_EQ(stack(z, -1).shape(), Shape{3, 2});
CHECK_EQ(stack(z, -2).shape(), Shape{2, 3});

CHECK_THROWS_MESSAGE(stack({}, 0), "No arrays provided for stacking");

@ -2676,20 +2676,20 @@ TEST_CASE("test stack") {

TEST_CASE("test eye") {
auto eye_3 = eye(3);
CHECK_EQ(eye_3.shape(), std::vector<int>{3, 3});
CHECK_EQ(eye_3.shape(), Shape{3, 3});
auto expected_eye_3 =
array({1.0f, 0.0f, 0.0f, 0.0f, 1.0f, 0.0f, 0.0f, 0.0f, 1.0f}, {3, 3});
CHECK(array_equal(eye_3, expected_eye_3).item<bool>());

auto eye_3x2 = eye(3, 2);
CHECK_EQ(eye_3x2.shape(), std::vector<int>{3, 2});
CHECK_EQ(eye_3x2.shape(), Shape{3, 2});
auto expected_eye_3x2 = array({1.0f, 0.0f, 0.0f, 1.0f, 0.0f, 0.0f}, {3, 2});
CHECK(array_equal(eye_3x2, expected_eye_3x2).item<bool>());
}

TEST_CASE("test tri") {
auto _tri = tri(4, 4, 0, float32);
CHECK_EQ(_tri.shape(), std::vector<int>{4, 4});
CHECK_EQ(_tri.shape(), Shape{4, 4});
auto expected_tri = array(
{1.0f,
0.0f,

@ -2712,8 +2712,8 @@ TEST_CASE("test tri") {
}

TEST_CASE("test tril") {
auto _tril = tril(full(std::vector<int>{4, 4}, 2.0f, float32), 0);
CHECK_EQ(_tril.shape(), std::vector<int>{4, 4});
auto _tril = tril(full({4, 4}, 2.0f, float32), 0);
CHECK_EQ(_tril.shape(), Shape{4, 4});
auto expected_tri = array(
{2.0f,
0.0f,

@ -2736,8 +2736,8 @@ TEST_CASE("test tril") {
}

TEST_CASE("test triu") {
auto _triu = triu(full(std::vector<int>{4, 4}, 2.0f, float32), 0);
CHECK_EQ(_triu.shape(), std::vector<int>{4, 4});
auto _triu = triu(full({4, 4}, 2.0f, float32), 0);
CHECK_EQ(_triu.shape(), Shape{4, 4});
auto expected_tri = array(
{2.0f,
2.0f,

@ -2761,7 +2761,7 @@ TEST_CASE("test triu") {

TEST_CASE("test identity") {
auto id_4 = identity(4);
CHECK_EQ(id_4.shape(), std::vector<int>{4, 4});
CHECK_EQ(id_4.shape(), Shape{4, 4});
auto expected_id_4 = array(
{1.0f,
0.0f,

@ -2785,7 +2785,7 @@ TEST_CASE("test identity") {

TEST_CASE("test eye with positive k offset") {
auto eye_3_k1 = eye(3, 4, 1);
CHECK_EQ(eye_3_k1.shape(), std::vector<int>{3, 4});
CHECK_EQ(eye_3_k1.shape(), Shape{3, 4});
auto expected_eye_3_k1 = array(
{0.0f, 1.0f, 0.0f, 0.0f, 0.0f, 0.0f, 1.0f, 0.0f, 0.0f, 0.0f, 0.0f, 1.0f},
{3, 4});

@ -2794,7 +2794,7 @@ TEST_CASE("test eye with positive k offset") {

TEST_CASE("test eye with negative k offset") {
auto eye_4_k_minus1 = eye(4, 3, -1);
CHECK_EQ(eye_4_k_minus1.shape(), std::vector<int>{4, 3});
CHECK_EQ(eye_4_k_minus1.shape(), Shape{4, 3});
auto expected_eye_4_k_minus1 = array(
{0.0f, 0.0f, 0.0f, 1.0f, 0.0f, 0.0f, 0.0f, 1.0f, 0.0f, 0.0f, 0.0f, 1.0f},
{4, 3});

@ -2844,9 +2844,9 @@ TEST_CASE("test quantize dequantize") {
for (int i = 2; i <= 8; i *= 2) {
int el_per_int = 32 / i;
auto [x_q, scales, biases] = quantize(x, 128, i);
CHECK_EQ(x_q.shape(), std::vector<int>{128, 512 / el_per_int});
CHECK_EQ(scales.shape(), std::vector<int>{128, 4});
CHECK_EQ(biases.shape(), std::vector<int>{128, 4});
CHECK_EQ(x_q.shape(), Shape{128, 512 / el_per_int});
CHECK_EQ(scales.shape(), Shape{128, 4});
CHECK_EQ(biases.shape(), Shape{128, 4});

auto x_hat = dequantize(x_q, scales, biases, 128, i);
auto max_diff = max(abs(x - x_hat)).item<float>();
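The quantize shapes are pure arithmetic: with i bits per value, each 32-bit word holds el_per_int = 32 / i values, so a {128, 512} input packs to {128, 512 / el_per_int}, and a group size of 128 yields 512 / 128 = 4 scale and bias columns per row. As a standalone check (plain arithmetic, not library code):

    #include <cassert>

    int main() {
      const int cols = 512, group_size = 128;
      for (int bits = 2; bits <= 8; bits *= 2) {
        int el_per_int = 32 / bits;          // values packed per uint32
        int packed_cols = cols / el_per_int; // 32, 64, 128 for 2, 4, 8 bits
        int groups = cols / group_size;      // scale/bias columns per row
        assert(groups == 4);
        (void)packed_cols;
      }
      return 0;
    }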
@ -3081,7 +3081,7 @@ TEST_CASE("test diagonal") {

  out = diagonal(x, -5, 0, 1);
  eval(out);
  CHECK_EQ(out.shape(), std::vector<int>{0});
  CHECK_EQ(out.shape(), Shape{0});

  x = array({0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11}, {3, 2, 2});
  out = diagonal(x, 1, 0, 1);
@ -3337,17 +3337,17 @@ TEST_CASE("test atleast_1d") {
  auto x = array(1);
  auto out = atleast_1d(x);
  CHECK_EQ(out.ndim(), 1);
  CHECK_EQ(out.shape(), std::vector<int>{1});
  CHECK_EQ(out.shape(), Shape{1});

  x = array({1, 2, 3}, {3});
  out = atleast_1d(x);
  CHECK_EQ(out.ndim(), 1);
  CHECK_EQ(out.shape(), std::vector<int>{3});
  CHECK_EQ(out.shape(), Shape{3});

  x = array({1, 2, 3}, {3, 1});
  out = atleast_1d(x);
  CHECK_EQ(out.ndim(), 2);
  CHECK_EQ(out.shape(), std::vector<int>{3, 1});
  CHECK_EQ(out.shape(), Shape{3, 1});
}

TEST_CASE("test atleast_1d vector") {
@ -3356,28 +3356,28 @@ TEST_CASE("test atleast_1d vector") {
  auto out = atleast_1d(x);
  CHECK_EQ(out.size(), 3);
  CHECK_EQ(out[0].ndim(), 1);
  CHECK_EQ(out[0].shape(), std::vector<int>{1});
  CHECK_EQ(out[0].shape(), Shape{1});
  CHECK_EQ(out[1].ndim(), 1);
  CHECK_EQ(out[1].shape(), std::vector<int>{3});
  CHECK_EQ(out[1].shape(), Shape{3});
  CHECK_EQ(out[2].ndim(), 2);
  CHECK_EQ(out[2].shape(), std::vector<int>{3, 1});
  CHECK_EQ(out[2].shape(), Shape{3, 1});
}

TEST_CASE("test atleast_2d") {
  auto x = array(1);
  auto out = atleast_2d(x);
  CHECK_EQ(out.ndim(), 2);
  CHECK_EQ(out.shape(), std::vector<int>{1, 1});
  CHECK_EQ(out.shape(), Shape{1, 1});

  x = array({1, 2, 3}, {3});
  out = atleast_2d(x);
  CHECK_EQ(out.ndim(), 2);
  CHECK_EQ(out.shape(), std::vector<int>{1, 3});
  CHECK_EQ(out.shape(), Shape{1, 3});

  x = array({1, 2, 3}, {3, 1});
  out = atleast_2d(x);
  CHECK_EQ(out.ndim(), 2);
  CHECK_EQ(out.shape(), std::vector<int>{3, 1});
  CHECK_EQ(out.shape(), Shape{3, 1});
}

TEST_CASE("test atleast_2d vector") {
@ -3386,28 +3386,28 @@ TEST_CASE("test atleast_2d vector") {
  auto out = atleast_2d(x);
  CHECK_EQ(out.size(), 3);
  CHECK_EQ(out[0].ndim(), 2);
  CHECK_EQ(out[0].shape(), std::vector<int>{1, 1});
  CHECK_EQ(out[0].shape(), Shape{1, 1});
  CHECK_EQ(out[1].ndim(), 2);
  CHECK_EQ(out[1].shape(), std::vector<int>{1, 3});
  CHECK_EQ(out[1].shape(), Shape{1, 3});
  CHECK_EQ(out[2].ndim(), 2);
  CHECK_EQ(out[2].shape(), std::vector<int>{3, 1});
  CHECK_EQ(out[2].shape(), Shape{3, 1});
}

TEST_CASE("test atleast_3d") {
  auto x = array(1);
  auto out = atleast_3d(x);
  CHECK_EQ(out.ndim(), 3);
  CHECK_EQ(out.shape(), std::vector<int>{1, 1, 1});
  CHECK_EQ(out.shape(), Shape{1, 1, 1});

  x = array({1, 2, 3}, {3});
  out = atleast_3d(x);
  CHECK_EQ(out.ndim(), 3);
  CHECK_EQ(out.shape(), std::vector<int>{1, 3, 1});
  CHECK_EQ(out.shape(), Shape{1, 3, 1});

  x = array({1, 2, 3}, {3, 1});
  out = atleast_3d(x);
  CHECK_EQ(out.ndim(), 3);
  CHECK_EQ(out.shape(), std::vector<int>{3, 1, 1});
  CHECK_EQ(out.shape(), Shape{3, 1, 1});
}

TEST_CASE("test atleast_3d vector") {
@ -3416,11 +3416,11 @@ TEST_CASE("test atleast_3d vector") {
  auto out = atleast_3d(x);
  CHECK_EQ(out.size(), 3);
  CHECK_EQ(out[0].ndim(), 3);
  CHECK_EQ(out[0].shape(), std::vector<int>{1, 1, 1});
  CHECK_EQ(out[0].shape(), Shape{1, 1, 1});
  CHECK_EQ(out[1].ndim(), 3);
  CHECK_EQ(out[1].shape(), std::vector<int>{1, 3, 1});
  CHECK_EQ(out[1].shape(), Shape{1, 3, 1});
  CHECK_EQ(out[2].ndim(), 3);
  CHECK_EQ(out[2].shape(), std::vector<int>{3, 1, 1});
  CHECK_EQ(out[2].shape(), Shape{3, 1, 1});
}

TEST_CASE("test topk") {
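The `atleast_*` checks above all encode the same promotion rule: pad missing axes with ones. `atleast_1d` turns a scalar into `{1}`, `atleast_2d` prepends an axis (`{3}` becomes `{1, 3}`), and `atleast_3d` pads to `{1, 3, 1}` or `{3, 1, 1}` depending on the input rank. A rough sketch of the 3-D case under those assumptions (not the mlx implementation itself):

.. code-block:: C++

   // Sketch: promote to ndim >= 3 following the shapes checked above:
   //   ()     -> {1, 1, 1}
   //   {n}    -> {1, n, 1}
   //   {m, n} -> {m, n, 1}
   array atleast_3d_sketch(const array& a) {
     switch (a.ndim()) {
       case 0:
         return reshape(a, {1, 1, 1});
       case 1:
         return reshape(a, {1, a.shape(0), 1});
       case 2:
         return reshape(a, {a.shape(0), a.shape(1), 1});
       default:
         return a; // already at least 3-D
     }
   }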
@ -141,7 +141,7 @@ TEST_CASE("test random bits") {

  {
    auto key = array({0u, 0u, 1u, 1u}, {2, 2});
    auto shape = std::vector<int>{3};
    auto shape = Shape{3};
    auto fn = [&shape](array k) { return random::bits(shape, k); };

    auto expected = array(
@ -264,7 +264,7 @@ TEST_CASE("test random uniform") {

  // Check broadcasting
  x = random::uniform(zeros({3, 1}), ones({1, 3}), {3, 3});
  CHECK_EQ(x.shape(), std::vector<int>{3, 3});
  CHECK_EQ(x.shape(), Shape{3, 3});
  CHECK_THROWS_AS(
      random::uniform(zeros({3, 3}), 1.0, {1, 3}), std::invalid_argument);
  CHECK_THROWS_AS(
@ -332,11 +332,11 @@ TEST_CASE("test random uniform") {
    return random::uniform(low, 1, {3}, float32, k);
  };
  auto out = vmap(fun, -1)(key, zeros({2, 3}));
  CHECK_EQ(out.shape(), std::vector<int>{2, 3});
  CHECK_EQ(out.shape(), Shape{2, 3});

  key = zeros({2, 2}, uint32);
  out = vmap(fun)(key, zeros({2, 3}));
  CHECK_EQ(out.shape(), std::vector<int>{2, 3});
  CHECK_EQ(out.shape(), Shape{2, 3});
  }

  // Check bounds are respected
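The broadcasting checks in the uniform hunk rest on one rule, visible in the assertions themselves: `low` and `high` broadcast together, and the requested `shape` must be able to hold the broadcast result, otherwise the call throws. A hedged recap of exactly the two cases exercised above:

.. code-block:: C++

   // Assumed rule for random::uniform(low, high, shape):
   // low and high broadcast together; `shape` must be compatible.
   auto ok = random::uniform(zeros({3, 1}), ones({1, 3}), {3, 3}); // -> {3, 3}
   // shape {1, 3} cannot hold the broadcast {3, 3} bounds:
   // random::uniform(zeros({3, 3}), 1.0, {1, 3}); // std::invalid_argument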
@ -425,7 +425,7 @@ TEST_CASE("test random multivariate_normal") {
    auto mean = zeros({3});
    auto cov = eye(3);
    auto x = random::multivariate_normal(mean, cov, {1000}, float32);
    CHECK_EQ(x.shape(), std::vector<int>({1000, 3}));
    CHECK_EQ(x.shape(), Shape{1000, 3});
    CHECK_EQ(x.dtype(), float32);
  }

@ -435,7 +435,7 @@ TEST_CASE("test random multivariate_normal") {
    auto cov = array({1., -1, -.1, 1.});
    cov = reshape(cov, {2, 2});
    auto x = random::multivariate_normal(mean, cov, {1}, float32);
    CHECK_EQ(x.shape(), std::vector<int>({1, 2}));
    CHECK_EQ(x.shape(), Shape{1, 2});
    CHECK_EQ(x.dtype(), float32);
  }

@ -457,7 +457,7 @@ TEST_CASE("test random multivariate_normal") {
    auto mean = zeros({3});
    auto cov = zeros({1, 2, 3, 3});
    auto x = random::multivariate_normal(mean, cov, {1000, 2}, float32);
    CHECK_EQ(x.shape(), std::vector<int>({1000, 2, 3}));
    CHECK_EQ(x.shape(), Shape{1000, 2, 3});
  }
  {
    auto mean = zeros({3});
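All three `multivariate_normal` hunks check the same shape rule: `mean` contributes the event size `n` from its last axis, `cov` must end in `{n, n}` (its leading axes broadcast with the batch dimensions), and the requested `shape` supplies those batch dimensions, so the result is `shape + {n}`. Restating the third case with comments, purely to spell the rule out:

.. code-block:: C++

   auto mean = zeros({3});          // event size n = 3
   auto cov = zeros({1, 2, 3, 3});  // leading {1, 2} broadcasts with `shape`
   auto x = random::multivariate_normal(mean, cov, {1000, 2}, float32);
   // result shape = requested batch shape + {n}:
   CHECK_EQ(x.shape(), Shape{1000, 2, 3});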
@ -537,7 +537,7 @@ TEST_CASE("test random bernoulli") {

  // Return array with correct shape
  x = random::bernoulli(0.5, {3, 3});
  CHECK_EQ(x.shape(), std::vector<int>({3, 3}));
  CHECK_EQ(x.shape(), Shape{3, 3});

  // Try with p = {}
  x = random::bernoulli(array({}));
@ -547,7 +547,7 @@ TEST_CASE("test random bernoulli") {
  auto p = array({0.1, 0.2, 0.3});
  p = reshape(p, {1, 3});
  x = random::bernoulli(p, {4, 3});
  CHECK_EQ(x.shape(), std::vector<int>({4, 3}));
  CHECK_EQ(x.shape(), Shape{4, 3});

  CHECK_THROWS_AS(random::bernoulli(array({}), {3, 3}), std::invalid_argument);
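The bernoulli checks mirror the uniform ones: a scalar `p` with an explicit shape fills that shape, an array `p` (here reshaped to `{1, 3}`) must broadcast against the requested `{4, 3}`, and an empty `p` with a non-empty shape is rejected. A short recap of the cases above:

.. code-block:: C++

   auto p = reshape(array({0.1, 0.2, 0.3}), {1, 3});
   auto x = random::bernoulli(p, {4, 3}); // p broadcasts row-wise -> {4, 3}
   // an empty p cannot broadcast to {3, 3}:
   // random::bernoulli(array({}), {3, 3}); // throws std::invalid_argument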
@ -572,7 +572,7 @@ TEST_CASE("Test truncated normal") {

  // Requested shape
  x = random::truncated_normal(array(-2.0), array(2.0), {3, 4});
  CHECK_EQ(x.shape(), std::vector<int>({3, 4}));
  CHECK_EQ(x.shape(), Shape{3, 4});

  // Empty array
  x = random::truncated_normal(array({}), array({}));
@ -584,7 +584,7 @@ TEST_CASE("Test truncated normal") {
  x = random::truncated_normal(lower, higher);

  // All in bounds
  CHECK_EQ(x.shape(), std::vector<int>({3, 2}));
  CHECK_EQ(x.shape(), Shape{3, 2});
  CHECK((all(x <= higher).item<bool>() && all(lower <= x).item<bool>()));

  // high < low => all equal to low
@ -615,17 +615,17 @@ TEST_CASE("test categorical") {
  CHECK_THROWS(categorical(logits, 1, std::vector<int>{11}));
  CHECK_THROWS(categorical(logits, 1, {10, 1}));

  CHECK_EQ(categorical(logits, -1).shape(), std::vector<int>{10});
  CHECK_EQ(categorical(logits, 0).shape(), std::vector<int>{20});
  CHECK_EQ(categorical(logits, 1).shape(), std::vector<int>{10});
  CHECK_EQ(categorical(logits, -1).shape(), Shape{10});
  CHECK_EQ(categorical(logits, 0).shape(), Shape{20});
  CHECK_EQ(categorical(logits, 1).shape(), Shape{10});

  auto out = categorical(logits);
  CHECK_EQ(out.shape(), std::vector<int>{10});
  CHECK_EQ(out.shape(), Shape{10});
  CHECK_EQ(out.dtype(), uint32);
  CHECK(max(out).item<uint32_t>() < 20);

  out = categorical(logits, 0, {5, 20});
  CHECK_EQ(out.shape(), std::vector<int>{5, 20});
  CHECK_EQ(out.shape(), Shape{5, 20});
  CHECK(max(out).item<uint32_t>() < 10);

  float inf = std::numeric_limits<float>::infinity();
@ -636,9 +636,9 @@ TEST_CASE("test categorical") {
  CHECK_EQ(categorical(logits).item<uint32_t>(), 1);

  logits = zeros({5, 4, 3});
  CHECK_EQ(categorical(logits, -1, 7).shape(), std::vector<int>{5, 4, 7});
  CHECK_EQ(categorical(logits, -2, 7).shape(), std::vector<int>{5, 3, 7});
  CHECK_EQ(categorical(logits, -3, 7).shape(), std::vector<int>{4, 3, 7});
  CHECK_EQ(categorical(logits, -1, 7).shape(), Shape{5, 4, 7});
  CHECK_EQ(categorical(logits, -2, 7).shape(), Shape{5, 3, 7});
  CHECK_EQ(categorical(logits, -3, 7).shape(), Shape{4, 3, 7});
}

TEST_CASE("test laplace") {
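The categorical hunks pin down its shape rule as exercised above: without `num_samples`, sampling consumes `axis` and the remaining axes form the output (`{10, 20}` logits over axis 0 give `{20}`); with `num_samples`, the consumed axis is dropped and a trailing sample axis is appended (`{5, 4, 3}` over axis -2 with 7 samples gives `{5, 3, 7}`). A short hedged recap:

.. code-block:: C++

   auto logits = zeros({5, 4, 3});
   // Assumed rule: drop `axis`, then append {num_samples}:
   //   axis -1: {5, 4} + {7} -> {5, 4, 7}
   //   axis -2: {5, 3} + {7} -> {5, 3, 7}
   //   axis -3: {4, 3} + {7} -> {4, 3, 7}
   auto out = categorical(logits, -2, 7); // shape {5, 3, 7}, dtype uint32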