Mirror of https://github.com/ml-explore/mlx.git (synced 2025-06-24 17:31:16 +08:00)
feat: metal formatting and pre-commit bump (#1038)
* feat: metal formatting and pre-commit bump
* add guards
* update
* more guards
* more guards
* small fix
* Refactor instantiation of ternary types in ternary.metal
* fix scan.metal
This commit is contained in:
parent 8db7161c94
commit a30e7ed2da
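Note: most of this diff is mechanical clang-format output, plus "// clang-format off" guards added around kernel-instantiation macro lists. As a minimal sketch (not verbatim from the commit) of why the guards are needed, a guarded block keeps the formatter from re-wrapping single-line macro invocations, each of which expands to a full template instantiation statement:

    // clang-format off
    // Without the guard, clang-format may merge or reflow these invocations.
    instantiate_arange(uint8, uint8_t)
    instantiate_arange(uint16, uint16_t)
    // clang-format on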
@@ -1,11 +1,11 @@
 repos:
   - repo: https://github.com/pre-commit/mirrors-clang-format
-    rev: v18.1.3
+    rev: v18.1.4
     hooks:
       - id: clang-format
   # Using this mirror lets us use mypyc-compiled black, which is about 2x faster
   - repo: https://github.com/psf/black-pre-commit-mirror
-    rev: 24.3.0
+    rev: 24.4.2
     hooks:
       - id: black
   - repo: https://github.com/pycqa/isort
@@ -33,7 +33,7 @@ array axpby(
 class Axpby : public Primitive {
  public:
   explicit Axpby(Stream stream, float alpha, float beta)
-      : Primitive(stream), alpha_(alpha), beta_(beta){};
+      : Primitive(stream), alpha_(alpha), beta_(beta) {};
 
   /**
    * A primitive must know how to evaluate itself on the CPU/GPU
@@ -36,8 +36,8 @@ template <typename T>
 }
 
 #define instantiate_axpby(type_name, type) \
-  template [[host_name("axpby_general_" #type_name)]] \
-  [[kernel]] void axpby_general<type>( \
+  template [[host_name("axpby_general_" #type_name)]] [[kernel]] void \
+  axpby_general<type>( \
       device const type* x [[buffer(0)]], \
       device const type* y [[buffer(1)]], \
       device type* out [[buffer(2)]], \
@@ -48,8 +48,8 @@ template <typename T>
       constant const size_t* y_strides [[buffer(7)]], \
       constant const int& ndim [[buffer(8)]], \
       uint index [[thread_position_in_grid]]); \
-  template [[host_name("axpby_contiguous_" #type_name)]] \
-  [[kernel]] void axpby_contiguous<type>( \
+  template [[host_name("axpby_contiguous_" #type_name)]] [[kernel]] void \
+  axpby_contiguous<type>( \
       device const type* x [[buffer(0)]], \
       device const type* y [[buffer(1)]], \
      device type* out [[buffer(2)]], \
@@ -14,7 +14,7 @@ class Buffer {
   void* ptr_;
 
  public:
-  Buffer(void* ptr) : ptr_(ptr){};
+  Buffer(void* ptr) : ptr_(ptr) {};
 
   // Get the raw data pointer from the buffer
   void* raw_ptr();
@@ -209,7 +209,7 @@ class array {
     allocator::Buffer buffer;
     deleter_t d;
     Data(allocator::Buffer buffer, deleter_t d = allocator::free)
-        : buffer(buffer), d(d){};
+        : buffer(buffer), d(d) {};
     // Not copyable
     Data(const Data& d) = delete;
     Data& operator=(const Data& d) = delete;
@@ -38,7 +38,7 @@ using MTLFCList =
 
 struct CommandEncoder {
   CommandEncoder(MTL::ComputeCommandEncoder* enc)
-      : enc(enc), concurrent(false){};
+      : enc(enc), concurrent(false) {};
   CommandEncoder(const CommandEncoder&) = delete;
   CommandEncoder& operator=(const CommandEncoder&) = delete;
 
@@ -12,13 +12,13 @@ template <typename T>
 }
 
 #define instantiate_arange(tname, type) \
-  template [[host_name("arange" #tname)]] \
-  [[kernel]] void arange<type>( \
+  template [[host_name("arange" #tname)]] [[kernel]] void arange<type>( \
       constant const type& start, \
       constant const type& step, \
       device type* out, \
       uint index [[thread_position_in_grid]]);
 
+// clang-format off
 instantiate_arange(uint8, uint8_t)
 instantiate_arange(uint16, uint16_t)
 instantiate_arange(uint32, uint32_t)
@@ -29,4 +29,4 @@ instantiate_arange(int32, int32_t)
 instantiate_arange(int64, int64_t)
 instantiate_arange(float16, half)
 instantiate_arange(float32, float)
-instantiate_arange(bfloat16, bfloat16_t)
+instantiate_arange(bfloat16, bfloat16_t) // clang-format on
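Note: each instantiate_arange(tname, type) invocation above expands to one explicit kernel instantiation whose [[host_name(...)]] attribute fixes the symbol name in the compiled metallib. A sketch, assuming ordinary preprocessor string pasting, of roughly what instantiate_arange(uint8, uint8_t) produces (whitespace differs from real preprocessor output):

    template [[host_name("arangeuint8")]] [[kernel]] void arange<uint8_t>(
        constant const uint8_t& start,
        constant const uint8_t& step,
        device uint8_t* out,
        uint index [[thread_position_in_grid]]);

Merging [[host_name(...)]] and [[kernel]] onto one line, as this commit does throughout, lets clang-format treat the whole declaration as a single statement instead of splitting the attributes across wraps.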
@@ -18,7 +18,8 @@ struct ArgMin {
   static constexpr constant U init = Limits<U>::max;
 
   IndexValPair<U> reduce(IndexValPair<U> best, IndexValPair<U> current) {
-    if (best.val > current.val || (best.val == current.val && best.index > current.index)) {
+    if (best.val > current.val ||
+        (best.val == current.val && best.index > current.index)) {
       return current;
     } else {
       return best;
@@ -26,11 +27,12 @@ struct ArgMin {
   }
 
   template <int N>
-  IndexValPair<U> reduce_many(IndexValPair<U> best, thread U* vals, uint32_t offset) {
-    for (int i=0; i<N; i++) {
+  IndexValPair<U>
+  reduce_many(IndexValPair<U> best, thread U* vals, uint32_t offset) {
+    for (int i = 0; i < N; i++) {
       if (vals[i] < best.val) {
         best.val = vals[i];
-        best.index = offset+i;
+        best.index = offset + i;
       }
     }
     return best;
@@ -42,7 +44,8 @@ struct ArgMax {
   static constexpr constant U init = Limits<U>::min;
 
   IndexValPair<U> reduce(IndexValPair<U> best, IndexValPair<U> current) {
-    if (best.val < current.val || (best.val == current.val && best.index > current.index)) {
+    if (best.val < current.val ||
+        (best.val == current.val && best.index > current.index)) {
       return current;
     } else {
       return best;
@@ -50,11 +53,12 @@ struct ArgMax {
   }
 
   template <int N>
-  IndexValPair<U> reduce_many(IndexValPair<U> best, thread U* vals, uint32_t offset) {
-    for (int i=0; i<N; i++) {
+  IndexValPair<U>
+  reduce_many(IndexValPair<U> best, thread U* vals, uint32_t offset) {
+    for (int i = 0; i < N; i++) {
       if (vals[i] > best.val) {
         best.val = vals[i];
-        best.index = offset+i;
+        best.index = offset + i;
       }
     }
     return best;
@@ -64,19 +68,16 @@ struct ArgMax {
 template <typename U>
 IndexValPair<U> simd_shuffle_down(IndexValPair<U> data, uint16_t delta) {
   return IndexValPair<U>{
-      simd_shuffle_down(data.index, delta),
-      simd_shuffle_down(data.val, delta)
-  };
+      simd_shuffle_down(data.index, delta), simd_shuffle_down(data.val, delta)};
 }
 
 
 template <typename T, typename Op, int N_READS>
 [[kernel]] void arg_reduce_general(
-    const device T *in [[buffer(0)]],
-    device uint32_t *out [[buffer(1)]],
-    const device int *shape [[buffer(2)]],
-    const device size_t *in_strides [[buffer(3)]],
-    const device size_t *out_strides [[buffer(4)]],
+    const device T* in [[buffer(0)]],
+    device uint32_t* out [[buffer(1)]],
+    const device int* shape [[buffer(2)]],
+    const device size_t* in_strides [[buffer(3)]],
+    const device size_t* out_strides [[buffer(4)]],
     const device size_t& ndim [[buffer(5)]],
     const device size_t& axis_stride [[buffer(6)]],
     const device size_t& axis_size [[buffer(7)]],
@@ -86,7 +87,6 @@ template <typename T, typename Op, int N_READS>
     uint simd_size [[threads_per_simdgroup]],
     uint simd_lane_id [[thread_index_in_simdgroup]],
     uint simd_group_id [[simdgroup_index_in_threadgroup]]) {
-
   // Shapes and strides *do not* contain the reduction axis. The reduction size
   // and stride are provided in axis_stride and axis_size.
   //
@@ -113,13 +113,13 @@ template <typename T, typename Op, int N_READS>
   threadgroup IndexValPair<T> local_data[32];
 
   // Loop over the reduction axis in lsize*N_READS buckets
-  for (uint r=0; r < ceildiv(axis_size, N_READS*lsize); r++) {
+  for (uint r = 0; r < ceildiv(axis_size, N_READS * lsize); r++) {
     // Read the current value
-    uint32_t current_index = r*lsize*N_READS + lid*N_READS;
+    uint32_t current_index = r * lsize * N_READS + lid * N_READS;
     uint32_t offset = current_index;
-    const device T * current_in = in + in_idx + current_index * axis_stride;
+    const device T* current_in = in + in_idx + current_index * axis_stride;
     T vals[N_READS];
-    for (int i=0; i<N_READS; i++) {
+    for (int i = 0; i < N_READS; i++) {
       vals[i] = (current_index < axis_size) ? *current_in : T(Op::init);
       current_index++;
       current_in += axis_stride;
@@ -130,7 +130,7 @@ template <typename T, typename Op, int N_READS>
   // need to reduce across the thread group.
 
   // First per simd reduction.
-  for (uint offset=simd_size/2; offset>0; offset/=2) {
+  for (uint offset = simd_size / 2; offset > 0; offset /= 2) {
     IndexValPair<T> neighbor = simd_shuffle_down(best, offset);
     best = op.reduce(best, neighbor);
   }
@@ -149,7 +149,7 @@ template <typename T, typename Op, int N_READS>
     if (simd_lane_id < simd_groups) {
       best = local_data[simd_lane_id];
     }
-    for (uint offset=simd_size/2; offset>0; offset/=2) {
+    for (uint offset = simd_size / 2; offset > 0; offset /= 2) {
       IndexValPair<T> neighbor = simd_shuffle_down(best, offset);
       best = op.reduce(best, neighbor);
     }
@@ -161,13 +161,13 @@ template <typename T, typename Op, int N_READS>
 }
 
 #define instantiate_arg_reduce_helper(name, itype, op) \
-  template [[host_name(name)]] \
-  [[kernel]] void arg_reduce_general<itype, op<itype>, 4>( \
-      const device itype *in [[buffer(0)]], \
-      device uint32_t * out [[buffer(1)]], \
-      const device int *shape [[buffer(2)]], \
-      const device size_t *in_strides [[buffer(3)]], \
-      const device size_t *out_strides [[buffer(4)]], \
+  template [[host_name(name)]] [[kernel]] void \
+  arg_reduce_general<itype, op<itype>, 4>( \
+      const device itype* in [[buffer(0)]], \
+      device uint32_t* out [[buffer(1)]], \
+      const device int* shape [[buffer(2)]], \
+      const device size_t* in_strides [[buffer(3)]], \
+      const device size_t* out_strides [[buffer(4)]], \
       const device size_t& ndim [[buffer(5)]], \
       const device size_t& axis_stride [[buffer(6)]], \
       const device size_t& axis_size [[buffer(7)]], \
@@ -178,6 +178,7 @@ template <typename T, typename Op, int N_READS>
       uint simd_lane_id [[thread_index_in_simdgroup]], \
       uint simd_group_id [[simdgroup_index_in_threadgroup]]);
 
+// clang-format off
 #define instantiate_arg_reduce(name, itype) \
   instantiate_arg_reduce_helper("argmin_" #name , itype, ArgMin) \
   instantiate_arg_reduce_helper("argmax_" #name , itype, ArgMax)
@@ -193,4 +194,4 @@ instantiate_arg_reduce(int32, int32_t)
 instantiate_arg_reduce(int64, int64_t)
 instantiate_arg_reduce(float16, half)
 instantiate_arg_reduce(float32, float)
-instantiate_arg_reduce(bfloat16, bfloat16_t)
+instantiate_arg_reduce(bfloat16, bfloat16_t) // clang-format on
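Note: the arg_reduce_general kernel above reduces in two stages, as its comments describe: a butterfly reduction within each simdgroup via simd_shuffle_down, then a second pass over the per-simdgroup partials staged in threadgroup memory. A minimal sketch of the in-simdgroup stage (illustrative Metal code, not from this commit; simdgroup_max is a hypothetical helper name):

    // After log2(simd_size) halving steps, lane 0 holds the reduction of
    // every lane in the simdgroup.
    inline float simdgroup_max(float v, uint simd_size) {
      for (uint offset = simd_size / 2; offset > 0; offset /= 2) {
        v = metal::max(v, metal::simd_shuffle_down(v, offset));
      }
      return v;
    }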
@@ -77,7 +77,8 @@ template <typename T, typename U, typename Op>
     uint3 grid_dim [[threads_per_grid]]) {
   auto a_idx = elem_to_loc_3(index, a_strides);
   auto b_idx = elem_to_loc_3(index, b_strides);
-  size_t out_idx = index.x + (size_t)grid_dim.x * (index.y + (size_t)grid_dim.y * index.z);
+  size_t out_idx =
+      index.x + (size_t)grid_dim.x * (index.y + (size_t)grid_dim.y * index.z);
   c[out_idx] = Op()(a[a_idx], b[b_idx]);
 }
 
@@ -92,7 +93,8 @@ template <typename T, typename U, typename Op, int DIM>
     uint3 index [[thread_position_in_grid]],
     uint3 grid_dim [[threads_per_grid]]) {
   auto idx = elem_to_loc_2_nd<DIM>(index, shape, a_strides, b_strides);
-  size_t out_idx = index.x + (size_t)grid_dim.x * (index.y + (size_t)grid_dim.y * index.z);
+  size_t out_idx =
+      index.x + (size_t)grid_dim.x * (index.y + (size_t)grid_dim.y * index.z);
   c[out_idx] = Op()(a[idx.x], b[idx.y]);
 }
 
@@ -113,16 +115,16 @@ template <typename T, typename U, typename Op>
 }
 
 #define instantiate_binary(name, itype, otype, op, bopt) \
-  template [[host_name(name)]] \
-  [[kernel]] void binary_op_##bopt<itype, otype, op>( \
+  template \
+      [[host_name(name)]] [[kernel]] void binary_op_##bopt<itype, otype, op>( \
       device const itype* a, \
       device const itype* b, \
       device otype* c, \
       uint index [[thread_position_in_grid]]);
 
 #define instantiate_binary_g_dim(name, itype, otype, op, dims) \
-  template [[host_name(name "_" #dims)]] \
-  [[kernel]] void binary_op_g_nd<itype, otype, op, dims>( \
+  template [[host_name(name "_" #dims)]] [[kernel]] void \
+  binary_op_g_nd<itype, otype, op, dims>( \
      device const itype* a, \
      device const itype* b, \
      device otype* c, \
@@ -133,16 +135,16 @@ template <typename T, typename U, typename Op>
       uint3 grid_dim [[threads_per_grid]]);
 
 #define instantiate_binary_g_nd(name, itype, otype, op) \
-  template [[host_name(name "_1")]] \
-  [[kernel]] void binary_op_g_nd1<itype, otype, op>( \
+  template [[host_name(name "_1")]] [[kernel]] void \
+  binary_op_g_nd1<itype, otype, op>( \
       device const itype* a, \
       device const itype* b, \
       device otype* c, \
       constant const size_t& a_stride, \
       constant const size_t& b_stride, \
       uint index [[thread_position_in_grid]]); \
-  template [[host_name(name "_2")]] \
-  [[kernel]] void binary_op_g_nd2<itype, otype, op>( \
+  template [[host_name(name "_2")]] [[kernel]] void \
+  binary_op_g_nd2<itype, otype, op>( \
      device const itype* a, \
      device const itype* b, \
      device otype* c, \
@@ -150,8 +152,8 @@ template <typename T, typename U, typename Op>
       constant const size_t b_strides[2], \
       uint2 index [[thread_position_in_grid]], \
       uint2 grid_dim [[threads_per_grid]]); \
-  template [[host_name(name "_3")]] \
-  [[kernel]] void binary_op_g_nd3<itype, otype, op>( \
+  template [[host_name(name "_3")]] [[kernel]] void \
+  binary_op_g_nd3<itype, otype, op>( \
      device const itype* a, \
      device const itype* b, \
      device otype* c, \
@@ -162,10 +164,8 @@ template <typename T, typename U, typename Op>
   instantiate_binary_g_dim(name, itype, otype, op, 4) \
   instantiate_binary_g_dim(name, itype, otype, op, 5)
 
-
 #define instantiate_binary_g(name, itype, otype, op) \
-  template [[host_name(name)]] \
-  [[kernel]] void binary_op_g<itype, otype, op>( \
+  template [[host_name(name)]] [[kernel]] void binary_op_g<itype, otype, op>( \
      device const itype* a, \
      device const itype* b, \
      device otype* c, \
@@ -176,14 +176,16 @@ template <typename T, typename U, typename Op>
       uint3 index [[thread_position_in_grid]], \
       uint3 grid_dim [[threads_per_grid]]);
 
+// clang-format off
 #define instantiate_binary_all(name, tname, itype, otype, op) \
   instantiate_binary("ss" #name #tname, itype, otype, op, ss) \
   instantiate_binary("sv" #name #tname, itype, otype, op, sv) \
   instantiate_binary("vs" #name #tname, itype, otype, op, vs) \
   instantiate_binary("vv" #name #tname, itype, otype, op, vv) \
   instantiate_binary_g("g" #name #tname, itype, otype, op) \
-  instantiate_binary_g_nd("g" #name #tname, itype, otype, op)
+  instantiate_binary_g_nd("g" #name #tname, itype, otype, op) // clang-format on
 
+// clang-format off
 #define instantiate_binary_integer(name, op) \
   instantiate_binary_all(name, uint8, uint8_t, uint8_t, op) \
   instantiate_binary_all(name, uint16, uint16_t, uint16_t, op) \
@@ -192,19 +194,22 @@ template <typename T, typename U, typename Op>
   instantiate_binary_all(name, int8, int8_t, int8_t, op) \
   instantiate_binary_all(name, int16, int16_t, int16_t, op) \
   instantiate_binary_all(name, int32, int32_t, int32_t, op) \
-  instantiate_binary_all(name, int64, int64_t, int64_t, op) \
+  instantiate_binary_all(name, int64, int64_t, int64_t, op) // clang-format on
 
+// clang-format off
 #define instantiate_binary_float(name, op) \
   instantiate_binary_all(name, float16, half, half, op) \
   instantiate_binary_all(name, float32, float, float, op) \
-  instantiate_binary_all(name, bfloat16, bfloat16_t, bfloat16_t, op)
+  instantiate_binary_all(name, bfloat16, bfloat16_t, bfloat16_t, op) // clang-format on
 
+// clang-format off
 #define instantiate_binary_types(name, op) \
   instantiate_binary_all(name, bool_, bool, bool, op) \
   instantiate_binary_integer(name, op) \
   instantiate_binary_all(name, complex64, complex64_t, complex64_t, op) \
-  instantiate_binary_float(name, op)
+  instantiate_binary_float(name, op) // clang-format on
 
+// clang-format off
 #define instantiate_binary_types_bool(name, op) \
   instantiate_binary_all(name, bool_, bool, bool, op) \
   instantiate_binary_all(name, uint8, uint8_t, bool, op) \
@@ -218,8 +223,9 @@ template <typename T, typename U, typename Op>
   instantiate_binary_all(name, float16, half, bool, op) \
   instantiate_binary_all(name, float32, float, bool, op) \
   instantiate_binary_all(name, bfloat16, bfloat16_t, bool, op) \
-  instantiate_binary_all(name, complex64, complex64_t, bool, op)
+  instantiate_binary_all(name, complex64, complex64_t, bool, op) // clang-format on
 
+// clang-format off
 instantiate_binary_types(add, Add)
 instantiate_binary_types(div, Divide)
 instantiate_binary_types_bool(eq, Equal)
@@ -253,4 +259,4 @@ instantiate_binary_all(bitwise_or, bool_, bool, bool, BitwiseOr)
 instantiate_binary_integer(bitwise_xor, BitwiseXor)
 instantiate_binary_all(bitwise_xor, bool_, bool, bool, BitwiseXor)
 instantiate_binary_integer(left_shift, LeftShift)
-instantiate_binary_integer(right_shift, RightShift)
+instantiate_binary_integer(right_shift, RightShift) // clang-format on
@@ -3,23 +3,37 @@
 #include <metal_integer>
 #include <metal_math>
 
-#include "mlx/backend/metal/kernels/utils.h"
 #include "mlx/backend/metal/kernels/bf16.h"
+#include "mlx/backend/metal/kernels/utils.h"
 
 struct FloorDivide {
-  template <typename T> T operator()(T x, T y) { return x / y; }
-  template <> float operator()(float x, float y) { return trunc(x / y); }
-  template <> half operator()(half x, half y) { return trunc(x / y); }
-  template <> bfloat16_t operator()(bfloat16_t x, bfloat16_t y) { return trunc(x / y); }
+  template <typename T>
+  T operator()(T x, T y) {
+    return x / y;
+  }
+  template <>
+  float operator()(float x, float y) {
+    return trunc(x / y);
+  }
+  template <>
+  half operator()(half x, half y) {
+    return trunc(x / y);
+  }
+  template <>
+  bfloat16_t operator()(bfloat16_t x, bfloat16_t y) {
+    return trunc(x / y);
+  }
 };
 
 struct Remainder {
   template <typename T>
-  metal::enable_if_t<metal::is_integral_v<T> & !metal::is_signed_v<T>, T> operator()(T x, T y) {
+  metal::enable_if_t<metal::is_integral_v<T> & !metal::is_signed_v<T>, T>
+  operator()(T x, T y) {
     return x % y;
   }
   template <typename T>
-  metal::enable_if_t<metal::is_integral_v<T> & metal::is_signed_v<T>, T> operator()(T x, T y) {
+  metal::enable_if_t<metal::is_integral_v<T> & metal::is_signed_v<T>, T>
+  operator()(T x, T y) {
     auto r = x % y;
     if (r != 0 && (r < 0 != y < 0)) {
       r += y;
@@ -34,7 +48,8 @@ struct Remainder {
     }
     return r;
   }
-  template <> complex64_t operator()(complex64_t x, complex64_t y) {
+  template <>
+  complex64_t operator()(complex64_t x, complex64_t y) {
     return x % y;
   }
 };
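Note: the signed-integer Remainder overload above turns the truncate-toward-zero % of C-family languages into a floored (Python-style) modulo. A worked example with illustrative values:

    // x = -7, y = 3
    // r = x % y;                      // r == -1 (truncated remainder)
    // r != 0 && ((r < 0) != (y < 0))  // true: r and y have opposite signs
    // r += y;                         // r == 2, matching floor division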
@@ -50,7 +65,6 @@ template <typename T, typename U, typename Op1, typename Op2>
   d[index] = Op2()(a[0], b[0]);
 }
 
-
 template <typename T, typename U, typename Op1, typename Op2>
 [[kernel]] void binary_op_ss(
     device const T* a,
@@ -139,7 +153,8 @@ template <typename T, typename U, typename Op1, typename Op2>
     uint3 grid_dim [[threads_per_grid]]) {
   auto a_idx = elem_to_loc_3(index, a_strides);
   auto b_idx = elem_to_loc_3(index, b_strides);
-  size_t out_idx = index.x + (size_t)grid_dim.x * (index.y + (size_t)grid_dim.y * index.z);
+  size_t out_idx =
+      index.x + (size_t)grid_dim.x * (index.y + (size_t)grid_dim.y * index.z);
   c[out_idx] = Op1()(a[a_idx], b[b_idx]);
   d[out_idx] = Op2()(a[a_idx], b[b_idx]);
 }
@@ -156,7 +171,8 @@ template <typename T, typename U, typename Op1, typename Op2, int DIM>
     uint3 index [[thread_position_in_grid]],
     uint3 grid_dim [[threads_per_grid]]) {
   auto idx = elem_to_loc_2_nd<DIM>(index, shape, a_strides, b_strides);
-  size_t out_idx = index.x + (size_t)grid_dim.x * (index.y + (size_t)grid_dim.y * index.z);
+  size_t out_idx =
+      index.x + (size_t)grid_dim.x * (index.y + (size_t)grid_dim.y * index.z);
   c[out_idx] = Op1()(a[idx.x], b[idx.y]);
   d[out_idx] = Op2()(a[idx.x], b[idx.y]);
 }
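Note: unlike the single-output binary kernels earlier in the diff, these kernels write two results per element, which is what lets divmod produce the quotient and the remainder in one pass over the inputs. A simplified sketch of the pattern (signature abbreviated and the vv variant assumed; not verbatim from the file):

    template <typename T, typename U, typename Op1, typename Op2>
    [[kernel]] void binary_op_vv(
        device const T* a,
        device const T* b,
        device U* c, // first output, e.g. FloorDivide
        device U* d, // second output, e.g. Remainder
        uint index [[thread_position_in_grid]]) {
      // One pair of loads feeds both operators.
      c[index] = Op1()(a[index], b[index]);
      d[index] = Op2()(a[index], b[index]);
    }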
@@ -180,8 +196,8 @@ template <typename T, typename U, typename Op1, typename Op2>
 }
 
 #define instantiate_binary(name, itype, otype, op1, op2, bopt) \
-  template [[host_name(name)]] \
-  [[kernel]] void binary_op_##bopt<itype, otype, op1, op2>( \
+  template [[host_name(name)]] [[kernel]] void \
+  binary_op_##bopt<itype, otype, op1, op2>( \
      device const itype* a, \
      device const itype* b, \
      device otype* c, \
@@ -189,8 +205,8 @@ template <typename T, typename U, typename Op1, typename Op2>
       uint index [[thread_position_in_grid]]);
 
 #define instantiate_binary_g_dim(name, itype, otype, op1, op2, dims) \
-  template [[host_name(name "_" #dims)]] \
-  [[kernel]] void binary_op_g_nd<itype, otype, op1, op2, dims>( \
+  template [[host_name(name "_" #dims)]] [[kernel]] void \
+  binary_op_g_nd<itype, otype, op1, op2, dims>( \
      device const itype* a, \
      device const itype* b, \
      device otype* c, \
@@ -201,9 +217,10 @@ template <typename T, typename U, typename Op1, typename Op2>
       uint3 index [[thread_position_in_grid]], \
       uint3 grid_dim [[threads_per_grid]]);
 
+// clang-format off
 #define instantiate_binary_g_nd(name, itype, otype, op1, op2) \
-  template [[host_name(name "_1")]] \
-  [[kernel]] void binary_op_g_nd1<itype, otype, op1, op2>( \
+  template [[host_name(name "_1")]] [[kernel]] void \
+  binary_op_g_nd1<itype, otype, op1, op2>( \
      device const itype* a, \
      device const itype* b, \
      device otype* c, \
@@ -211,8 +228,8 @@ template <typename T, typename U, typename Op1, typename Op2>
       constant const size_t& a_stride, \
       constant const size_t& b_stride, \
       uint index [[thread_position_in_grid]]); \
-  template [[host_name(name "_2")]] \
-  [[kernel]] void binary_op_g_nd2<itype, otype, op1, op2>( \
+  template [[host_name(name "_2")]] [[kernel]] void \
+  binary_op_g_nd2<itype, otype, op1, op2>( \
      device const itype* a, \
      device const itype* b, \
      device otype* c, \
@@ -221,8 +238,8 @@ template <typename T, typename U, typename Op1, typename Op2>
       constant const size_t b_strides[2], \
       uint2 index [[thread_position_in_grid]], \
       uint2 grid_dim [[threads_per_grid]]); \
-  template [[host_name(name "_3")]] \
-  [[kernel]] void binary_op_g_nd3<itype, otype, op1, op2>( \
+  template [[host_name(name "_3")]] [[kernel]] void \
+  binary_op_g_nd3<itype, otype, op1, op2>( \
      device const itype* a, \
      device const itype* b, \
      device otype* c, \
@@ -232,12 +249,11 @@ template <typename T, typename U, typename Op1, typename Op2>
       uint3 index [[thread_position_in_grid]], \
       uint3 grid_dim [[threads_per_grid]]); \
   instantiate_binary_g_dim(name, itype, otype, op1, op2, 4) \
-  instantiate_binary_g_dim(name, itype, otype, op1, op2, 5)
+  instantiate_binary_g_dim(name, itype, otype, op1, op2, 5) // clang-format on
 
-
 #define instantiate_binary_g(name, itype, otype, op1, op2) \
-  template [[host_name(name)]] \
-  [[kernel]] void binary_op_g<itype, otype, op2, op2>( \
+  template [[host_name(name)]] [[kernel]] void \
+  binary_op_g<itype, otype, op2, op2>( \
      device const itype* a, \
      device const itype* b, \
      device otype* c, \
@@ -249,19 +265,22 @@ template <typename T, typename U, typename Op1, typename Op2>
       uint3 index [[thread_position_in_grid]], \
       uint3 grid_dim [[threads_per_grid]]);
 
+// clang-format off
 #define instantiate_binary_all(name, tname, itype, otype, op1, op2) \
   instantiate_binary("ss" #name #tname, itype, otype, op1, op2, ss) \
   instantiate_binary("sv" #name #tname, itype, otype, op1, op2, sv) \
   instantiate_binary("vs" #name #tname, itype, otype, op1, op2, vs) \
   instantiate_binary("vv" #name #tname, itype, otype, op1, op2, vv) \
   instantiate_binary_g("g" #name #tname, itype, otype, op1, op2) \
-  instantiate_binary_g_nd("g" #name #tname, itype, otype, op1, op2)
+  instantiate_binary_g_nd("g" #name #tname, itype, otype, op1, op2) // clang-format on
 
+// clang-format off
 #define instantiate_binary_float(name, op1, op2) \
   instantiate_binary_all(name, float16, half, half, op1, op2) \
   instantiate_binary_all(name, float32, float, float, op1, op2) \
-  instantiate_binary_all(name, bfloat16, bfloat16_t, bfloat16_t, op1, op2)
+  instantiate_binary_all(name, bfloat16, bfloat16_t, bfloat16_t, op1, op2) // clang-format on
 
+// clang-format off
 #define instantiate_binary_types(name, op1, op2) \
   instantiate_binary_all(name, bool_, bool, bool, op1, op2) \
   instantiate_binary_all(name, uint8, uint8_t, uint8_t, op1, op2) \
@@ -275,4 +294,4 @@ template <typename T, typename U, typename Op1, typename Op2>
   instantiate_binary_all(name, complex64, complex64_t, complex64_t, op1, op2) \
   instantiate_binary_float(name, op1, op2)
 
-instantiate_binary_types(divmod, FloorDivide, Remainder)
+instantiate_binary_types(divmod, FloorDivide, Remainder) // clang-format on
@@ -22,7 +22,7 @@ struct complex64_t {
   float imag;
 
   // Constructors
-  constexpr complex64_t(float real, float imag) : real(real), imag(imag){};
+  constexpr complex64_t(float real, float imag) : real(real), imag(imag) {};
 
   // Conversions to complex64_t
   template <
@@ -1,13 +1,11 @@
 // Copyright © 2023-2024 Apple Inc.
 
-#include <metal_stdlib>
 #include <metal_simdgroup>
 #include <metal_simdgroup_matrix>
 #include <metal_stdlib>
 
-
-#include "mlx/backend/metal/kernels/steel/conv/params.h"
 #include "mlx/backend/metal/kernels/bf16.h"
+#include "mlx/backend/metal/kernels/steel/conv/params.h"
 
 #define MLX_MTL_CONST static constant constexpr const
 
@@ -23,12 +21,13 @@ template <typename T, int N>
     device T* out [[buffer(1)]],
     const constant MLXConvParams<N>* params [[buffer(2)]],
     uint3 gid [[thread_position_in_grid]]) {
-
   int filter_size = params->C;
-  for(short i = 0; i < N; i++) filter_size *= params->wS[i];
+  for (short i = 0; i < N; i++)
+    filter_size *= params->wS[i];
 
   int out_pixels = 1;
-  for(short i = 0; i < N; i++) out_pixels *= params->oS[i];
+  for (short i = 0; i < N; i++)
+    out_pixels *= params->oS[i];
 
   // Set out
   out += gid.z * filter_size + gid.y * (params->C);
@@ -64,10 +63,10 @@ template <typename T, int N>
       wS /= params->wS[i];
     }
 
-    if(valid) {
+    if (valid) {
       size_t in_offset = n * params->in_strides[0];
 
-      for(int i = 0; i < N; ++i) {
+      for (int i = 0; i < N; ++i) {
         in_offset += is[i] * params->in_strides[i + 1];
       }
 
@@ -85,12 +84,13 @@ template <typename T, int N>
     device T* out [[buffer(1)]],
     const constant MLXConvParams<N>* params [[buffer(2)]],
     uint3 gid [[thread_position_in_grid]]) {
-
   int filter_size = params->C;
-  for(short i = 0; i < N; i++) filter_size *= params->wS[i];
+  for (short i = 0; i < N; i++)
+    filter_size *= params->wS[i];
 
   int out_pixels = 1;
-  for(short i = 0; i < N; i++) out_pixels *= params->oS[i];
+  for (short i = 0; i < N; i++)
+    out_pixels *= params->oS[i];
 
   // Set out
   out += gid.z * filter_size + gid.x * (filter_size / params->C);
@@ -128,10 +128,10 @@ template <typename T, int N>
       out += ws_ * params->str[i];
     }
 
-    if(valid) {
+    if (valid) {
      size_t in_offset = n * params->in_strides[0];
 
-      for(int i = 0; i < N; ++i) {
+      for (int i = 0; i < N; ++i) {
        in_offset += is[i] * params->in_strides[i + 1];
      }
 
@@ -142,23 +142,23 @@ template <typename T, int N>
 }
 
 #define instantiate_naive_unfold_nd(name, itype, n) \
-  template [[host_name("naive_unfold_nd_" #name "_" #n)]] \
-  [[kernel]] void naive_unfold_Nd( \
+  template [[host_name("naive_unfold_nd_" #name "_" #n)]] [[kernel]] void \
+  naive_unfold_Nd( \
      const device itype* in [[buffer(0)]], \
      device itype* out [[buffer(1)]], \
      const constant MLXConvParams<n>* params [[buffer(2)]], \
      uint3 gid [[thread_position_in_grid]]); \
-  template [[host_name("naive_unfold_transpose_nd_" #name "_" #n)]] \
-  [[kernel]] void naive_unfold_transpose_Nd( \
+  template \
+      [[host_name("naive_unfold_transpose_nd_" #name "_" #n)]] [[kernel]] void \
+  naive_unfold_transpose_Nd( \
      const device itype* in [[buffer(0)]], \
      device itype* out [[buffer(1)]], \
      const constant MLXConvParams<n>* params [[buffer(2)]], \
      uint3 gid [[thread_position_in_grid]]);
 
 #define instantiate_naive_unfold_nd_dims(name, itype) \
-  instantiate_naive_unfold_nd(name, itype, 1) \
-  instantiate_naive_unfold_nd(name, itype, 2) \
-  instantiate_naive_unfold_nd(name, itype, 3)
+  instantiate_naive_unfold_nd(name, itype, 1) instantiate_naive_unfold_nd( \
+      name, itype, 2) instantiate_naive_unfold_nd(name, itype, 3)
 
 instantiate_naive_unfold_nd_dims(float32, float);
 instantiate_naive_unfold_nd_dims(float16, half);
@@ -168,7 +168,8 @@ instantiate_naive_unfold_nd_dims(bfloat16, bfloat16_t);
 /// Slow and naive conv2d kernels
 ///////////////////////////////////////////////////////////////////////////////
 
-template <typename T,
+template <
+    typename T,
     const int BM, /* Threadgroup rows (in threads) */
     const int BN, /* Threadgroup cols (in threads) */
     const int TM, /* Thread rows (in elements) */
@@ -183,7 +184,6 @@ template <typename T,
     uint3 lid [[thread_position_in_threadgroup]],
     uint simd_gid [[simdgroup_index_in_threadgroup]],
     uint simd_lid [[thread_index_in_simdgroup]]) {
-
   (void)simd_gid;
   (void)simd_lid;
 
@ -196,64 +196,66 @@ template <typename T,
|
|||||||
int out_h[TM];
|
int out_h[TM];
|
||||||
int out_w[TN];
|
int out_w[TN];
|
||||||
|
|
||||||
for(int m = 0; m < TM; ++m) {
|
for (int m = 0; m < TM; ++m) {
|
||||||
int mm = (out_hw + m);
|
int mm = (out_hw + m);
|
||||||
out_h[m] = mm / params.oS[1];
|
out_h[m] = mm / params.oS[1];
|
||||||
out_w[m] = mm % params.oS[1];
|
out_w[m] = mm % params.oS[1];
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
T in_local[TM];
|
T in_local[TM];
|
||||||
T wt_local[TN];
|
T wt_local[TN];
|
||||||
T out_local[TM * TN] = {T(0)};
|
T out_local[TM * TN] = {T(0)};
|
||||||
|
|
||||||
for(int h = 0; h < params.wS[0]; ++h) {
|
for (int h = 0; h < params.wS[0]; ++h) {
|
||||||
for(int w = 0; w < params.wS[1]; ++w) {
|
for (int w = 0; w < params.wS[1]; ++w) {
|
||||||
for(int c = 0; c < params.C; ++c) {
|
for (int c = 0; c < params.C; ++c) {
|
||||||
|
|
||||||
// Local in
|
// Local in
|
||||||
for(int m = 0; m < TM; m++) {
|
for (int m = 0; m < TM; m++) {
|
||||||
int i = out_h[m] * params.str[0] - params.pad[0] + h * params.kdil[0];
|
int i = out_h[m] * params.str[0] - params.pad[0] + h * params.kdil[0];
|
||||||
int j = out_w[m] * params.str[1] - params.pad[1] + w * params.kdil[1];
|
int j = out_w[m] * params.str[1] - params.pad[1] + w * params.kdil[1];
|
||||||
|
|
||||||
bool valid = i >= 0 && i < params.iS[0] && j >= 0 && j < params.iS[1];
|
bool valid = i >= 0 && i < params.iS[0] && j >= 0 && j < params.iS[1];
|
||||||
in_local[m] = valid ? in[i * params.in_strides[1] + j * params.in_strides[2] + c] : T(0);
|
in_local[m] = valid
|
||||||
|
? in[i * params.in_strides[1] + j * params.in_strides[2] + c]
|
||||||
|
: T(0);
|
||||||
}
|
}
|
||||||
|
|
||||||
// Load weight
|
// Load weight
|
||||||
for (int n = 0; n < TN; ++n) {
|
for (int n = 0; n < TN; ++n) {
|
||||||
int o = out_o + n;
|
int o = out_o + n;
|
||||||
wt_local[n] = o < params.O ? wt[o * params.wt_strides[0] +
|
wt_local[n] = o < params.O
|
||||||
h * params.wt_strides[1] +
|
? wt[o * params.wt_strides[0] + h * params.wt_strides[1] +
|
||||||
w * params.wt_strides[2] + c] : T(0);
|
w * params.wt_strides[2] + c]
|
||||||
|
: T(0);
|
||||||
}
|
}
|
||||||
|
|
||||||
// Accumulate
|
// Accumulate
|
||||||
for(int m = 0; m < TM; ++m) {
|
for (int m = 0; m < TM; ++m) {
|
||||||
for(int n = 0; n < TN; ++n) {
|
for (int n = 0; n < TN; ++n) {
|
||||||
out_local[m * TN + n] += in_local[m] * wt_local[n];
|
out_local[m * TN + n] += in_local[m] * wt_local[n];
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
for(int m = 0; m < TM; ++m) {
|
for (int m = 0; m < TM; ++m) {
|
||||||
for(int n = 0; n < TN; ++n) {
|
for (int n = 0; n < TN; ++n) {
|
||||||
if(out_h[m] < params.oS[0] && out_w[m] < params.oS[1] && (out_o + n) < params.O)
|
if (out_h[m] < params.oS[0] && out_w[m] < params.oS[1] &&
|
||||||
|
(out_o + n) < params.O)
|
||||||
out[out_h[m] * params.out_strides[1] +
|
out[out_h[m] * params.out_strides[1] +
|
||||||
out_w[m] * params.out_strides[2] + out_o + n] = out_local[m * TN + n];
|
out_w[m] * params.out_strides[2] + out_o + n] =
|
||||||
|
out_local[m * TN + n];
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Instantiations
|
// Instantiations
|
||||||
|
|
||||||
#define instantiate_naive_conv_2d(name, itype, bm, bn, tm, tn) \
|
#define instantiate_naive_conv_2d(name, itype, bm, bn, tm, tn) \
|
||||||
template [[host_name("naive_conv_2d_" #name "_bm" #bm "_bn" #bn "_tm" #tm "_tn" #tn)]] \
|
template [[host_name("naive_conv_2d_" #name "_bm" #bm "_bn" #bn "_tm" #tm \
|
||||||
[[kernel]] void naive_conv_2d<itype, bm, bn, tm, tn>( \
|
"_tn" #tn)]] [[kernel]] void \
|
||||||
|
naive_conv_2d<itype, bm, bn, tm, tn>( \
|
||||||
const device itype* in [[buffer(0)]], \
|
const device itype* in [[buffer(0)]], \
|
||||||
const device itype* wt [[buffer(1)]], \
|
const device itype* wt [[buffer(1)]], \
|
||||||
device itype* out [[buffer(2)]], \
|
device itype* out [[buffer(2)]], \
|
||||||
@ -276,9 +278,7 @@ instantiate_naive_conv_2d_blocks(bfloat16, bfloat16_t);
|
|||||||
///////////////////////////////////////////////////////////////////////////////
|
///////////////////////////////////////////////////////////////////////////////
|
||||||
|
|
||||||
template <int M, int R, int S>
|
template <int M, int R, int S>
|
||||||
struct WinogradTransforms {
|
struct WinogradTransforms {};
|
||||||
|
|
||||||
};
|
|
||||||
|
|
||||||
template <>
|
template <>
|
||||||
struct WinogradTransforms<6, 3, 8> {
|
struct WinogradTransforms<6, 3, 8> {
|
||||||
@ -287,36 +287,36 @@ struct WinogradTransforms<6, 3, 8> {
|
|||||||
MLX_MTL_CONST int IN_TILE_SIZE = OUT_TILE_SIZE + FILTER_SIZE - 1;
|
MLX_MTL_CONST int IN_TILE_SIZE = OUT_TILE_SIZE + FILTER_SIZE - 1;
|
||||||
MLX_MTL_CONST int SIMD_MATRIX_SIZE = 8;
|
MLX_MTL_CONST int SIMD_MATRIX_SIZE = 8;
|
||||||
MLX_MTL_CONST float in_transform[SIMD_MATRIX_SIZE][SIMD_MATRIX_SIZE] = {
|
MLX_MTL_CONST float in_transform[SIMD_MATRIX_SIZE][SIMD_MATRIX_SIZE] = {
|
||||||
{ 1.00f, 0.00f, 0.00f, 0.00f, 0.00f, 0.00f, 0.00f, 0.00f},
|
{1.00f, 0.00f, 0.00f, 0.00f, 0.00f, 0.00f, 0.00f, 0.00f},
|
||||||
{ 0.00f, 1.00f, -1.00f, 0.50f, -0.50f, 2.00f, -2.00f, -1.00f},
|
{0.00f, 1.00f, -1.00f, 0.50f, -0.50f, 2.00f, -2.00f, -1.00f},
|
||||||
{-5.25f, 1.00f, 1.00f, 0.25f, 0.25f, 4.00f, 4.00f, 0.00f},
|
{-5.25f, 1.00f, 1.00f, 0.25f, 0.25f, 4.00f, 4.00f, 0.00f},
|
||||||
{ 0.00f, -4.25f, 4.25f, -2.50f, 2.50f, -2.50f, 2.50f, 5.25f},
|
{0.00f, -4.25f, 4.25f, -2.50f, 2.50f, -2.50f, 2.50f, 5.25f},
|
||||||
{ 5.25f, -4.25f, -4.25f, -1.25f, -1.25f, -5.00f, -5.00f, 0.00f},
|
{5.25f, -4.25f, -4.25f, -1.25f, -1.25f, -5.00f, -5.00f, 0.00f},
|
||||||
{ 0.00f, 1.00f, -1.00f, 2.00f, -2.00f, 0.50f, -0.50f, -5.25f},
|
{0.00f, 1.00f, -1.00f, 2.00f, -2.00f, 0.50f, -0.50f, -5.25f},
|
||||||
{-1.00f, 1.00f, 1.00f, 1.00f, 1.00f, 1.00f, 1.00f, 0.00f},
|
{-1.00f, 1.00f, 1.00f, 1.00f, 1.00f, 1.00f, 1.00f, 0.00f},
|
||||||
{ 0.00f, 0.00f, 0.00f, 0.00f, 0.00f, 0.00f, 0.00f, 1.00f},
|
{0.00f, 0.00f, 0.00f, 0.00f, 0.00f, 0.00f, 0.00f, 1.00f},
|
||||||
};
|
};
|
||||||
|
|
||||||
MLX_MTL_CONST float out_transform[SIMD_MATRIX_SIZE][SIMD_MATRIX_SIZE] = {
|
MLX_MTL_CONST float out_transform[SIMD_MATRIX_SIZE][SIMD_MATRIX_SIZE] = {
|
||||||
{ 1.00f, 0.00f, 0.00f, 0.00f, 0.00f, 0.00f},
|
{1.00f, 0.00f, 0.00f, 0.00f, 0.00f, 0.00f},
|
||||||
{ 1.00f, 1.00f, 1.00f, 1.00f, 1.00f, 1.00f},
|
{1.00f, 1.00f, 1.00f, 1.00f, 1.00f, 1.00f},
|
||||||
{ 1.00f, -1.00f, 1.00f, -1.00f, 1.00f, -1.00f},
|
{1.00f, -1.00f, 1.00f, -1.00f, 1.00f, -1.00f},
|
||||||
{ 1.00f, 2.00f, 4.00f, 8.00f, 16.00f, 32.00f},
|
{1.00f, 2.00f, 4.00f, 8.00f, 16.00f, 32.00f},
|
||||||
{ 1.00f, -2.00f, 4.00f, -8.00f, 16.00f, -32.00f},
|
{1.00f, -2.00f, 4.00f, -8.00f, 16.00f, -32.00f},
|
||||||
{ 1.00f, 0.50f, 0.25f, 0.125f, 0.0625f, 0.03125f},
|
{1.00f, 0.50f, 0.25f, 0.125f, 0.0625f, 0.03125f},
|
||||||
{ 1.00f, -0.50f, 0.25f, -0.125f, 0.0625f, -0.03125f},
|
{1.00f, -0.50f, 0.25f, -0.125f, 0.0625f, -0.03125f},
|
||||||
{ 0.00f, 0.00f, 0.00f, 0.00f, 0.00f, 1.00f},
|
{0.00f, 0.00f, 0.00f, 0.00f, 0.00f, 1.00f},
|
||||||
};
|
};
|
||||||
|
|
||||||
MLX_MTL_CONST float wt_transform[SIMD_MATRIX_SIZE][SIMD_MATRIX_SIZE] = {
|
MLX_MTL_CONST float wt_transform[SIMD_MATRIX_SIZE][SIMD_MATRIX_SIZE] = {
|
||||||
{ 1.00, 0.00, 0.00},
|
{1.00, 0.00, 0.00},
|
||||||
{ -2.0/9.00, -2.0/9.00, -2.0/9.00},
|
{-2.0 / 9.00, -2.0 / 9.00, -2.0 / 9.00},
|
||||||
{ -2.0/9.00, 2.0/9.00, -2.0/9.00},
|
{-2.0 / 9.00, 2.0 / 9.00, -2.0 / 9.00},
|
||||||
{ 1.0/90.0, 1.0/45.0, 2.0/45.0},
|
{1.0 / 90.0, 1.0 / 45.0, 2.0 / 45.0},
|
||||||
{ 1.0/90.0, -1.0/45.0, 2.0/45.0},
|
{1.0 / 90.0, -1.0 / 45.0, 2.0 / 45.0},
|
||||||
{ 32.0/45.0, 16.0/45.0, 8.0/45.0},
|
{32.0 / 45.0, 16.0 / 45.0, 8.0 / 45.0},
|
||||||
{ 32.0/45.0, -16.0/45.0, 8.0/45.0},
|
{32.0 / 45.0, -16.0 / 45.0, 8.0 / 45.0},
|
||||||
{ 0.00, 0.00, 1.00},
|
{0.00, 0.00, 1.00},
|
||||||
};
|
};
|
||||||
};
|
};
|
||||||
|
|
||||||
@ -324,12 +324,9 @@ constant constexpr const float WinogradTransforms<6, 3, 8>::wt_transform[8][8];
|
|||||||
constant constexpr const float WinogradTransforms<6, 3, 8>::in_transform[8][8];
|
constant constexpr const float WinogradTransforms<6, 3, 8>::in_transform[8][8];
|
||||||
constant constexpr const float WinogradTransforms<6, 3, 8>::out_transform[8][8];
|
constant constexpr const float WinogradTransforms<6, 3, 8>::out_transform[8][8];
|
||||||
|
|
||||||
template <typename T,
|
template <typename T, int BC = 32, int BO = 4, int M = 6, int R = 3>
|
||||||
int BC = 32,
|
[[kernel, max_total_threads_per_threadgroup(BO * 32)]] void
|
||||||
int BO = 4,
|
winograd_conv_2d_weight_transform(
|
||||||
int M = 6,
|
|
||||||
int R = 3>
|
|
||||||
[[kernel, max_total_threads_per_threadgroup(BO * 32)]] void winograd_conv_2d_weight_transform(
|
|
||||||
const device T* wt_in [[buffer(0)]],
|
const device T* wt_in [[buffer(0)]],
|
||||||
device T* wt_out [[buffer(1)]],
|
device T* wt_out [[buffer(1)]],
|
||||||
const constant int& C [[buffer(2)]],
|
const constant int& C [[buffer(2)]],
|
||||||
@ -337,7 +334,6 @@ template <typename T,
|
|||||||
uint tid [[threadgroup_position_in_grid]],
|
uint tid [[threadgroup_position_in_grid]],
|
||||||
uint simd_group_id [[simdgroup_index_in_threadgroup]],
|
uint simd_group_id [[simdgroup_index_in_threadgroup]],
|
||||||
uint simd_lane_id [[thread_index_in_simdgroup]]) {
|
uint simd_lane_id [[thread_index_in_simdgroup]]) {
|
||||||
|
|
||||||
using WGT = WinogradTransforms<M, R, 8>;
|
using WGT = WinogradTransforms<M, R, 8>;
|
||||||
|
|
||||||
// Get lane position in simdgroup
|
// Get lane position in simdgroup
|
||||||
@ -369,12 +365,12 @@ template <typename T,
|
|||||||
threadgroup T Ws[BO][R][R][BC];
|
threadgroup T Ws[BO][R][R][BC];
|
||||||
|
|
||||||
// Loop over C
|
// Loop over C
|
||||||
for(int bc = 0; bc < C; bc += BC) {
|
for (int bc = 0; bc < C; bc += BC) {
|
||||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||||
// Read into shared memory
|
// Read into shared memory
|
||||||
for(int kh = 0; kh < R; ++kh) {
|
for (int kh = 0; kh < R; ++kh) {
|
||||||
for(int kw = 0; kw < R; ++kw) {
|
for (int kw = 0; kw < R; ++kw) {
|
||||||
for(int kc = simd_lane_id; kc < BC; kc += 32) {
|
for (int kc = simd_lane_id; kc < BC; kc += 32) {
|
||||||
Ws[simd_group_id][kh][kw][kc] = wt_in[kh * R * C + kw * C + kc];
|
Ws[simd_group_id][kh][kw][kc] = wt_in[kh * R * C + kw * C + kc];
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -382,10 +378,12 @@ template <typename T,

threadgroup_barrier(mem_flags::mem_threadgroup);
// Do transform and store the result
for(int c = 0; c < BC; ++c) {
for (int c = 0; c < BC; ++c) {
simdgroup_matrix<T, 8, 8> g;
g.thread_elements()[0] = sm < R && sn < R ? Ws[simd_group_id][sm][sn][c] : T(0);
g.thread_elements()[0] =
g.thread_elements()[1] = sm < R && sn + 1 < R ? Ws[simd_group_id][sm][sn + 1][c] : T(0);
sm < R && sn < R ? Ws[simd_group_id][sm][sn][c] : T(0);
g.thread_elements()[1] =
sm < R && sn + 1 < R ? Ws[simd_group_id][sm][sn + 1][c] : T(0);

simdgroup_matrix<T, 8, 8> g_out = (G * g) * Gt;
wt_out_0[c * O] = g_out.thread_elements()[0];
@ -396,27 +394,23 @@ template <typename T,
wt_out_0 += BC * O;
wt_out_1 += BC * O;
}

}

#define instantiate_winograd_conv_2d_weight_transform_base(name, itype, bc) \
template [[host_name("winograd_conv_2d_weight_transform_" #name "_bc" #bc)]]\
template [[host_name("winograd_conv_2d_weight_transform_" #name \
[[kernel]] void winograd_conv_2d_weight_transform<itype, bc>(\
"_bc" #bc)]] [[kernel]] void \
const device itype* wt_in [[buffer(0)]],\
winograd_conv_2d_weight_transform<itype, bc>( \
device itype* wt_out [[buffer(1)]],\
const device itype* wt_in [[buffer(0)]], \
const constant int& C [[buffer(2)]],\
device itype* wt_out [[buffer(1)]], \
const constant int& O [[buffer(3)]],\
const constant int& C [[buffer(2)]], \
uint tid [[threadgroup_position_in_grid]],\
const constant int& O [[buffer(3)]], \
uint simd_group_id [[simdgroup_index_in_threadgroup]],\
uint tid [[threadgroup_position_in_grid]], \
uint simd_group_id [[simdgroup_index_in_threadgroup]], \
uint simd_lane_id [[thread_index_in_simdgroup]]);

template <typename T,
template <typename T, int BC, int WM, int WN, int M = 6, int R = 3>
int BC,
[[kernel, max_total_threads_per_threadgroup(WM* WN * 32)]] void
int WM,
winograd_conv_2d_input_transform(
int WN,
int M = 6,
int R = 3>
[[kernel, max_total_threads_per_threadgroup(WM * WN * 32)]] void winograd_conv_2d_input_transform(
const device T* inp_in [[buffer(0)]],
device T* inp_out [[buffer(1)]],
const constant MLXConvParams<2>& params [[buffer(2)]],
@ -425,7 +419,6 @@ template <typename T,
uint3 tgp_per_grid [[threadgroups_per_grid]],
uint simd_group_id [[simdgroup_index_in_threadgroup]],
uint simd_lane_id [[thread_index_in_simdgroup]]) {

(void)lid;

using WGT = WinogradTransforms<M, R, 8>;
@ -456,38 +449,40 @@ template <typename T,
int bw = M * tid.x + kw;

// Move to the correct input tile
inp_in += tid.z * params.in_strides[0]
inp_in += tid.z * params.in_strides[0] + bh * params.in_strides[1] +
+ bh * params.in_strides[1]
bw * params.in_strides[2];
+ bw * params.in_strides[2];

// Pre compute strides
int jump_in[TH][TW];

for(int h = 0; h < TH; h++) {
for (int h = 0; h < TH; h++) {
for(int w = 0; w < TW; w++) {
for (int w = 0; w < TW; w++) {
jump_in[h][w] = h * params.in_strides[1] + w * params.in_strides[2];
}
}

// inp_out is stored interleaved (A x A x tiles x C)
size_t N_TILES = tgp_per_grid.x * tgp_per_grid.y * tgp_per_grid.z;
size_t tile_id = tid.z * tgp_per_grid.x * tgp_per_grid.y + tid.y * tgp_per_grid.x + tid.x;
size_t tile_id =
tid.z * tgp_per_grid.x * tgp_per_grid.y + tid.y * tgp_per_grid.x + tid.x;
size_t ohw_0 = sm * 8 + sn;
size_t ohw_1 = sm * 8 + sn + 1;
device T* inp_out_0 = inp_out + ohw_0 * N_TILES * params.C + tile_id * params.C;
device T* inp_out_0 =
device T* inp_out_1 = inp_out + ohw_1 * N_TILES * params.C + tile_id * params.C;
inp_out + ohw_0 * N_TILES * params.C + tile_id * params.C;
device T* inp_out_1 =
inp_out + ohw_1 * N_TILES * params.C + tile_id * params.C;

// Prepare shared memory
threadgroup T Is[A][A][BC];

// Loop over C
for(int bc = 0; bc < params.C; bc += BC) {
for (int bc = 0; bc < params.C; bc += BC) {
threadgroup_barrier(mem_flags::mem_threadgroup);
// Read into shared memory
for(int h = 0; h < TH; h++) {
for (int h = 0; h < TH; h++) {
for(int w = 0; w < TW; w++) {
for (int w = 0; w < TW; w++) {
const device T* in_ptr = inp_in + jump_in[h][w];
for(int c = simd_lane_id; c < BC; c += 32) {
for (int c = simd_lane_id; c < BC; c += 32) {
Is[kh + h][kw + w][c] = in_ptr[c];
}
}
@ -495,7 +490,7 @@ template <typename T,

threadgroup_barrier(mem_flags::mem_threadgroup);
// Do transform and store the result
for(int c = simd_group_id; c < BC; c += N_SIMD_GROUPS) {
for (int c = simd_group_id; c < BC; c += N_SIMD_GROUPS) {
simdgroup_matrix<T, 8, 8> I;
I.thread_elements()[0] = Is[sm][sn][c];
I.thread_elements()[1] = Is[sm][sn + 1][c];
@ -509,28 +504,24 @@ template <typename T,
inp_out_0 += BC;
inp_out_1 += BC;
}

}

#define instantiate_winograd_conv_2d_input_transform(name, itype, bc) \
template [[host_name("winograd_conv_2d_input_transform_" #name "_bc" #bc)]]\
template [[host_name("winograd_conv_2d_input_transform_" #name \
[[kernel]] void winograd_conv_2d_input_transform<itype, bc, 2, 2>(\
"_bc" #bc)]] [[kernel]] void \
const device itype* inp_in [[buffer(0)]],\
winograd_conv_2d_input_transform<itype, bc, 2, 2>( \
device itype* inp_out [[buffer(1)]],\
const device itype* inp_in [[buffer(0)]], \
const constant MLXConvParams<2>& params [[buffer(2)]],\
device itype* inp_out [[buffer(1)]], \
uint3 tid [[threadgroup_position_in_grid]],\
const constant MLXConvParams<2>& params [[buffer(2)]], \
uint3 lid [[thread_position_in_threadgroup]],\
uint3 tid [[threadgroup_position_in_grid]], \
uint3 tgp_per_grid [[threadgroups_per_grid]],\
uint3 lid [[thread_position_in_threadgroup]], \
uint simd_group_id [[simdgroup_index_in_threadgroup]],\
uint3 tgp_per_grid [[threadgroups_per_grid]], \
uint simd_group_id [[simdgroup_index_in_threadgroup]], \
uint simd_lane_id [[thread_index_in_simdgroup]]);

template <typename T,
template <typename T, int BO, int WM, int WN, int M = 6, int R = 3>
int BO,
[[kernel, max_total_threads_per_threadgroup(WM* WN * 32)]] void
int WM,
winograd_conv_2d_output_transform(
int WN,
int M = 6,
int R = 3>
[[kernel, max_total_threads_per_threadgroup(WM * WN * 32)]] void winograd_conv_2d_output_transform(
const device T* out_in [[buffer(0)]],
device T* out_out [[buffer(1)]],
const constant MLXConvParams<2>& params [[buffer(2)]],
@ -539,7 +530,6 @@ template <typename T,
uint3 tgp_per_grid [[threadgroups_per_grid]],
uint simd_group_id [[simdgroup_index_in_threadgroup]],
uint simd_lane_id [[thread_index_in_simdgroup]]) {

(void)lid;

using WGT = WinogradTransforms<M, R, 8>;
@ -572,57 +562,59 @@ template <typename T,
int bw = M * tid.x + kw;

// Move to the correct input tile
out_out += tid.z * params.out_strides[0]
out_out += tid.z * params.out_strides[0] + bh * params.out_strides[1] +
+ bh * params.out_strides[1]
bw * params.out_strides[2];
+ bw * params.out_strides[2];

// Pre compute strides
int jump_in[TH][TW];

for(int h = 0; h < TH; h++) {
for (int h = 0; h < TH; h++) {
for(int w = 0; w < TW; w++) {
for (int w = 0; w < TW; w++) {
bool valid = ((bh + h) < params.oS[0]) && ((bw + w) < params.oS[1]);
jump_in[h][w] = valid ? h * params.out_strides[1] + w * params.out_strides[2] : -1;
jump_in[h][w] =
valid ? h * params.out_strides[1] + w * params.out_strides[2] : -1;
}
}

// out_in is stored interleaved (A x A x tiles x O)
size_t N_TILES = tgp_per_grid.x * tgp_per_grid.y * tgp_per_grid.z;
size_t tile_id = tid.z * tgp_per_grid.x * tgp_per_grid.y + tid.y * tgp_per_grid.x + tid.x;
size_t tile_id =
tid.z * tgp_per_grid.x * tgp_per_grid.y + tid.y * tgp_per_grid.x + tid.x;
size_t ohw_0 = sm * 8 + sn;
size_t ohw_1 = sm * 8 + sn + 1;
const device T* out_in_0 = out_in + ohw_0 * N_TILES * params.O + tile_id * params.O;
const device T* out_in_0 =
const device T* out_in_1 = out_in + ohw_1 * N_TILES * params.O + tile_id * params.O;
out_in + ohw_0 * N_TILES * params.O + tile_id * params.O;
const device T* out_in_1 =
out_in + ohw_1 * N_TILES * params.O + tile_id * params.O;

// Prepare shared memory
threadgroup T Os[M][M][BO];

// Loop over O
for(int bo = 0; bo < params.O; bo += BO) {
for (int bo = 0; bo < params.O; bo += BO) {

threadgroup_barrier(mem_flags::mem_threadgroup);
// Do transform and store the result
for(int c = simd_group_id; c < BO; c += N_SIMD_GROUPS) {
for (int c = simd_group_id; c < BO; c += N_SIMD_GROUPS) {
simdgroup_matrix<T, 8, 8> O_mat;
O_mat.thread_elements()[0] = out_in_0[c];
O_mat.thread_elements()[1] = out_in_1[c];

simdgroup_matrix<T, 8, 8> O_out = (Bt * (O_mat * B));
if((sm < M) && (sn < M)) {
if ((sm < M) && (sn < M)) {
Os[sm][sn][c] = O_out.thread_elements()[0];
}
if((sm < M) && ((sn + 1) < M)) {
if ((sm < M) && ((sn + 1) < M)) {
Os[sm][sn + 1][c] = O_out.thread_elements()[1];
}
}

threadgroup_barrier(mem_flags::mem_threadgroup);
// Read out from shared memory
for(int h = 0; h < TH; h++) {
for (int h = 0; h < TH; h++) {
for(int w = 0; w < TW; w++) {
for (int w = 0; w < TW; w++) {
if(jump_in[h][w] >= 0) {
if (jump_in[h][w] >= 0) {
device T* out_ptr = out_out + jump_in[h][w];
for(int c = simd_lane_id; c < BO; c += 32) {
for (int c = simd_lane_id; c < BO; c += 32) {
out_ptr[c] = Os[kh + h][kw + w][c];
}
}
@ -633,25 +625,27 @@ template <typename T,
out_in_0 += BO;
out_in_1 += BO;
}

}

#define instantiate_winograd_conv_2d_output_transform(name, itype, bo) \
template [[host_name("winograd_conv_2d_output_transform_" #name "_bo" #bo)]]\
template [[host_name("winograd_conv_2d_output_transform_" #name \
[[kernel]] void winograd_conv_2d_output_transform<itype, bo, 2, 2>(\
"_bo" #bo)]] [[kernel]] void \
const device itype* out_in [[buffer(0)]],\
winograd_conv_2d_output_transform<itype, bo, 2, 2>( \
device itype* out_out [[buffer(1)]],\
const device itype* out_in [[buffer(0)]], \
const constant MLXConvParams<2>& params [[buffer(2)]],\
device itype* out_out [[buffer(1)]], \
uint3 tid [[threadgroup_position_in_grid]],\
const constant MLXConvParams<2>& params [[buffer(2)]], \
uint3 lid [[thread_position_in_threadgroup]],\
uint3 tid [[threadgroup_position_in_grid]], \
uint3 tgp_per_grid [[threadgroups_per_grid]],\
uint3 lid [[thread_position_in_threadgroup]], \
uint simd_group_id [[simdgroup_index_in_threadgroup]],\
uint3 tgp_per_grid [[threadgroups_per_grid]], \
uint simd_group_id [[simdgroup_index_in_threadgroup]], \
uint simd_lane_id [[thread_index_in_simdgroup]]);

// clang-format off
#define instantiate_winograd_conv_2d(name, itype) \
instantiate_winograd_conv_2d_weight_transform_base(name, itype, 32) \
instantiate_winograd_conv_2d_input_transform(name, itype, 32) \
instantiate_winograd_conv_2d_output_transform(name, itype, 32)
instantiate_winograd_conv_2d_output_transform(name, itype, 32) // clang-format on

// clang-format off
instantiate_winograd_conv_2d(float32, float);
instantiate_winograd_conv_2d(float16, half);
instantiate_winograd_conv_2d(float16, half); // clang-format on
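To make the macro layering concrete, here is one invocation expanded by hand from the macros in this hunk (a sketch for orientation; the real expansion is whatever the preprocessor produces). instantiate_winograd_conv_2d(float32, float) yields, for the weight transform:

template [[host_name("winograd_conv_2d_weight_transform_float32_bc32")]] [[kernel]] void
winograd_conv_2d_weight_transform<float, 32>(
    const device float* wt_in [[buffer(0)]],
    device float* wt_out [[buffer(1)]],
    const constant int& C [[buffer(2)]],
    const constant int& O [[buffer(3)]],
    uint tid [[threadgroup_position_in_grid]],
    uint simd_group_id [[simdgroup_index_in_threadgroup]],
    uint simd_lane_id [[thread_index_in_simdgroup]]);

plus matching input and output transforms named winograd_conv_2d_input_transform_float32_bc32 and winograd_conv_2d_output_transform_float32_bo32. The host looks kernels up by these host_name strings, which is why the reflow must not change them, and why the macro tables get fenced with clang-format off/on guards.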
@ -49,7 +49,8 @@ template <typename T, typename U>
uint3 index [[thread_position_in_grid]],
uint3 grid_dim [[threads_per_grid]]) {
auto src_idx = elem_to_loc_3(index, src_strides);
int64_t dst_idx = index.x + (int64_t)grid_dim.x * (index.y + (int64_t)grid_dim.y * index.z);
int64_t dst_idx =
index.x + (int64_t)grid_dim.x * (index.y + (int64_t)grid_dim.y * index.z);
dst[dst_idx] = static_cast<U>(src[src_idx]);
}

@ -62,7 +63,8 @@ template <typename T, typename U, int DIM>
uint3 index [[thread_position_in_grid]],
uint3 grid_dim [[threads_per_grid]]) {
auto src_idx = elem_to_loc_nd<DIM>(index, src_shape, src_strides);
int64_t dst_idx = index.x + (int64_t)grid_dim.x * (index.y + (int64_t)grid_dim.y * index.z);
int64_t dst_idx =
index.x + (int64_t)grid_dim.x * (index.y + (int64_t)grid_dim.y * index.z);
dst[dst_idx] = static_cast<U>(src[src_idx]);
}

@ -76,7 +78,8 @@ template <typename T, typename U>
uint3 index [[thread_position_in_grid]],
uint3 grid_dim [[threads_per_grid]]) {
auto src_idx = elem_to_loc(index, src_shape, src_strides, ndim);
int64_t dst_idx = index.x + (int64_t)grid_dim.x * (index.y + (int64_t)grid_dim.y * index.z);
int64_t dst_idx =
index.x + (int64_t)grid_dim.x * (index.y + (int64_t)grid_dim.y * index.z);
dst[dst_idx] = static_cast<U>(src[src_idx]);
}

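The dst_idx expression that gets rewrapped in all three hunks above is the usual row-major linearization of a 3D grid coordinate. A minimal host-side C++ sketch of the same arithmetic (illustrative only, not part of the commit):

#include <cstdint>

// Row-major linearization of a 3D grid coordinate, matching the kernel's
// expression: consecutive x values land in consecutive destination slots.
int64_t linear_index(uint32_t x, uint32_t y, uint32_t z, uint32_t gx, uint32_t gy) {
  return x + (int64_t)gx * (y + (int64_t)gy * z);
}
// e.g. linear_index(1, 2, 3, /*gx=*/4, /*gy=*/5) == 1 + 4 * (2 + 5 * 3) == 69

The 64-bit casts matter: for large arrays the product of the grid extents can overflow 32 bits.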
@ -144,23 +147,22 @@ template <typename T, typename U>
}

#define instantiate_copy(name, itype, otype, ctype) \
template [[host_name(name)]] \
template [[host_name(name)]] [[kernel]] void copy_##ctype<itype, otype>( \
[[kernel]] void copy_##ctype<itype, otype>( \
device const itype* src [[buffer(0)]], \
device otype* dst [[buffer(1)]], \
uint index [[thread_position_in_grid]]);

#define instantiate_copy_g_dim(name, itype, otype, dims) \
template [[host_name(name "_" #dims)]] \
template [[host_name(name "_" #dims)]] [[kernel]] void \
[[kernel]] void copy_g_nd<itype, otype, dims>( \
copy_g_nd<itype, otype, dims>( \
device const itype* src [[buffer(0)]], \
device otype* dst [[buffer(1)]], \
constant const int* src_shape [[buffer(2)]], \
constant const int64_t* src_strides [[buffer(3)]], \
uint3 index [[thread_position_in_grid]], \
uint3 grid_dim [[threads_per_grid]]); \
template [[host_name("g" name "_" #dims)]] \
template [[host_name("g" name "_" #dims)]] [[kernel]] void \
[[kernel]] void copy_gg_nd<itype, otype, dims>( \
copy_gg_nd<itype, otype, dims>( \
device const itype* src [[buffer(0)]], \
device otype* dst [[buffer(1)]], \
constant const int* src_shape [[buffer(2)]], \
@ -168,44 +170,40 @@ template <typename T, typename U>
constant const int64_t* dst_strides [[buffer(4)]], \
uint3 index [[thread_position_in_grid]]);


#define instantiate_copy_g_nd(name, itype, otype) \
template [[host_name(name "_1")]] \
template [[host_name(name "_1")]] [[kernel]] void copy_g_nd1<itype, otype>( \
[[kernel]] void copy_g_nd1<itype, otype>( \
device const itype* src [[buffer(0)]], \
device otype* dst [[buffer(1)]], \
constant const int64_t& src_stride [[buffer(3)]], \
uint index [[thread_position_in_grid]]); \
template [[host_name(name "_2")]] \
template [[host_name(name "_2")]] [[kernel]] void copy_g_nd2<itype, otype>( \
[[kernel]] void copy_g_nd2<itype, otype>( \
device const itype* src [[buffer(0)]], \
device otype* dst [[buffer(1)]], \
constant const int64_t* src_strides [[buffer(3)]], \
uint2 index [[thread_position_in_grid]], \
uint2 grid_dim [[threads_per_grid]]); \
template [[host_name(name "_3")]] \
template [[host_name(name "_3")]] [[kernel]] void copy_g_nd3<itype, otype>( \
[[kernel]] void copy_g_nd3<itype, otype>( \
device const itype* src [[buffer(0)]], \
device otype* dst [[buffer(1)]], \
constant const int64_t* src_strides [[buffer(3)]], \
uint3 index [[thread_position_in_grid]], \
uint3 grid_dim [[threads_per_grid]]); \
template [[host_name("g" name "_1")]] \
template [[host_name("g" name "_1")]] [[kernel]] void \
[[kernel]] void copy_gg_nd1<itype, otype>( \
copy_gg_nd1<itype, otype>( \
device const itype* src [[buffer(0)]], \
device otype* dst [[buffer(1)]], \
constant const int64_t& src_stride [[buffer(3)]], \
constant const int64_t& dst_stride [[buffer(4)]], \
uint index [[thread_position_in_grid]]); \
template [[host_name("g" name "_2")]] \
template [[host_name("g" name "_2")]] [[kernel]] void \
[[kernel]] void copy_gg_nd2<itype, otype>( \
copy_gg_nd2<itype, otype>( \
device const itype* src [[buffer(0)]], \
device otype* dst [[buffer(1)]], \
constant const int64_t* src_strides [[buffer(3)]], \
constant const int64_t* dst_strides [[buffer(4)]], \
uint2 index [[thread_position_in_grid]]); \
template [[host_name("g" name "_3")]] \
template [[host_name("g" name "_3")]] [[kernel]] void \
[[kernel]] void copy_gg_nd3<itype, otype>( \
copy_gg_nd3<itype, otype>( \
device const itype* src [[buffer(0)]], \
device otype* dst [[buffer(1)]], \
constant const int64_t* src_strides [[buffer(3)]], \
@ -214,10 +212,8 @@ template <typename T, typename U>
instantiate_copy_g_dim(name, itype, otype, 4) \
instantiate_copy_g_dim(name, itype, otype, 5)


#define instantiate_copy_g(name, itype, otype) \
template [[host_name(name)]] \
template [[host_name(name)]] [[kernel]] void copy_g<itype, otype>( \
[[kernel]] void copy_g<itype, otype>( \
device const itype* src [[buffer(0)]], \
device otype* dst [[buffer(1)]], \
constant const int* src_shape [[buffer(2)]], \
@ -225,8 +221,7 @@ template <typename T, typename U>
constant const int& ndim [[buffer(5)]], \
uint3 index [[thread_position_in_grid]], \
uint3 grid_dim [[threads_per_grid]]); \
template [[host_name("g" name)]] \
template [[host_name("g" name)]] [[kernel]] void copy_gg<itype, otype>( \
[[kernel]] void copy_gg<itype, otype>( \
device const itype* src [[buffer(0)]], \
device otype* dst [[buffer(1)]], \
constant const int* src_shape [[buffer(2)]], \
@ -235,12 +230,14 @@ template <typename T, typename U>
constant const int& ndim [[buffer(5)]], \
uint3 index [[thread_position_in_grid]]);

// clang-format off
#define instantiate_copy_all(tname, itype, otype) \
instantiate_copy("scopy" #tname, itype, otype, s) \
instantiate_copy("vcopy" #tname, itype, otype, v) \
instantiate_copy_g("gcopy" #tname, itype, otype) \
instantiate_copy_g_nd("gcopy" #tname, itype, otype)
instantiate_copy_g_nd("gcopy" #tname, itype, otype) // clang-format on

// clang-format off
#define instantiate_copy_itype(itname, itype) \
instantiate_copy_all(itname ##bool_, itype, bool) \
instantiate_copy_all(itname ##uint8, itype, uint8_t) \
@ -268,4 +265,4 @@ instantiate_copy_itype(int64, int64_t)
instantiate_copy_itype(float16, half)
instantiate_copy_itype(float32, float)
instantiate_copy_itype(bfloat16, bfloat16_t)
instantiate_copy_itype(complex64, complex64_t)
instantiate_copy_itype(complex64, complex64_t) // clang-format on
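Tracing the naming fan-out above: instantiate_copy_itype feeds instantiate_copy_all, which stamps out the s/v/g/g_nd kernel families per (input, output) dtype pair. One row expanded by hand as a sketch (names derived from the macros shown; treat as illustrative):

// instantiate_copy_itype(float32, float) includes
// instantiate_copy_all(float32bool_, float, bool), which via
// instantiate_copy("scopy" #tname, itype, otype, s) emits among others:
template [[host_name("scopyfloat32bool_")]] [[kernel]] void copy_s<float, bool>(
    device const float* src [[buffer(0)]],
    device bool* dst [[buffer(1)]],
    uint index [[thread_position_in_grid]]);

The general and fixed-dimension variants follow the same pattern with host names vcopyfloat32bool_, gcopyfloat32bool_, and gcopyfloat32bool__1 onward for the nd kernels (the double underscore is real, since tname already ends in one).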
@ -6,9 +6,8 @@
// - VkFFT (https://github.com/DTolm/VkFFT)
// - Eric Bainville's excellent page (http://www.bealto.com/gpu-fft.html)

#include <metal_math>
#include <metal_common>
#include <metal_math>

#include "mlx/backend/metal/kernels/defines.h"
#include "mlx/backend/metal/kernels/utils.h"
@ -23,7 +22,7 @@ float2 complex_mul(float2 a, float2 b) {
}

float2 get_twiddle(int k, int p) {
float theta = -1.0f * k * M_PI_F / (2*p);
float theta = -1.0f * k * M_PI_F / (2 * p);

float2 twiddle;
twiddle.x = metal::fast::cos(theta);
@ -32,7 +31,12 @@ float2 get_twiddle(int k, int p) {
}

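For reference, get_twiddle computes the standard FFT twiddle factor (a general DFT fact, not something specific to this commit):

$$ w = e^{i\theta}, \qquad \theta = -\frac{\pi k}{2p}, $$

returned as the float2 (cos θ, sin θ); the imaginary part is presumably filled in by the sin line this hunk elides. Multiplying by $w$ via complex_mul rotates a point clockwise through $|\theta|$ radians.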
// single threaded radix2 implemetation
void radix2(int i, int p, int m, threadgroup float2* read_buf, threadgroup float2* write_buf) {
void radix2(
int i,
int p,
int m,
threadgroup float2* read_buf,
threadgroup float2* write_buf) {
float2 x_0 = read_buf[i];
float2 x_1 = read_buf[i + m];

@ -53,11 +57,16 @@ void radix2(int i, int p, int m, threadgroup float2* read_buf, threadgroup float
}

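The radix2 body elided between these hunks is, judging from the reads and writes that are visible, a two-point butterfly. As a hedged sketch of the usual formulation (the standard radix-2 step; the exact output indexing in this file is not shown here):

$$ y_0 = x_0 + w\,x_1, \qquad y_1 = x_0 - w\,x_1 $$

with $w$ the twiddle for the position within the current sub-DFT, which is why the function reads read_buf[i] and read_buf[i + m] and why the signature change above is formatting-only.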
// single threaded radix4 implemetation
void radix4(int i, int p, int m, threadgroup float2* read_buf, threadgroup float2* write_buf) {
void radix4(
int i,
int p,
int m,
threadgroup float2* read_buf,
threadgroup float2* write_buf) {
float2 x_0 = read_buf[i];
float2 x_1 = read_buf[i + m];
float2 x_2 = read_buf[i + 2*m];
float2 x_2 = read_buf[i + 2 * m];
float2 x_3 = read_buf[i + 3*m];
float2 x_3 = read_buf[i + 3 * m];

// The index within this sub-DFT
int k = i & (p - 1);
@ -90,11 +99,10 @@ void radix4(int i, int p, int m, threadgroup float2* read_buf, threadgroup float

write_buf[j] = y_0;
write_buf[j + p] = y_1;
write_buf[j + 2*p] = y_2;
write_buf[j + 2 * p] = y_2;
write_buf[j + 3*p] = y_3;
write_buf[j + 3 * p] = y_3;
}


// Each FFT is computed entirely in shared GPU memory.
//
// N is decomposed into radix-2 and radix-4 DFTs:
@ -107,11 +115,10 @@ void radix4(int i, int p, int m, threadgroup float2* read_buf, threadgroup float
// steps at compile time for a ~20% performance boost.
template <size_t n, size_t radix_2_steps, size_t radix_4_steps>
[[kernel]] void fft(
const device float2 *in [[buffer(0)]],
const device float2* in [[buffer(0)]],
device float2 * out [[buffer(1)]],
device float2* out [[buffer(1)]],
uint3 thread_position_in_grid [[thread_position_in_grid]],
uint3 threads_per_grid [[threads_per_grid]]) {

// Index of the DFT in batch
int batch_idx = thread_position_in_grid.x * n;
// The index in the DFT we're working on
@ -132,16 +139,16 @@ template <size_t n, size_t radix_2_steps, size_t radix_4_steps>
// Copy input into shared memory
shared_in[i] = in[batch_idx + i];
shared_in[i + m] = in[batch_idx + i + m];
shared_in[i + 2*m] = in[batch_idx + i + 2*m];
shared_in[i + 2 * m] = in[batch_idx + i + 2 * m];
shared_in[i + 3*m] = in[batch_idx + i + 3*m];
shared_in[i + 3 * m] = in[batch_idx + i + 3 * m];

threadgroup_barrier(mem_flags::mem_threadgroup);

int p = 1;

for (size_t r = 0; r < radix_2_steps; r++) {
radix2(i, p, m*2, read_buf, write_buf);
radix2(i, p, m * 2, read_buf, write_buf);
radix2(i + m, p, m*2, read_buf, write_buf);
radix2(i + m, p, m * 2, read_buf, write_buf);
p *= 2;

threadgroup_barrier(mem_flags::mem_threadgroup);
@ -167,29 +174,26 @@ template <size_t n, size_t radix_2_steps, size_t radix_4_steps>
// Copy shared memory to output
out[batch_idx + i] = read_buf[i];
out[batch_idx + i + m] = read_buf[i + m];
out[batch_idx + i + 2*m] = read_buf[i + 2*m];
out[batch_idx + i + 2 * m] = read_buf[i + 2 * m];
out[batch_idx + i + 3*m] = read_buf[i + 3*m];
out[batch_idx + i + 3 * m] = read_buf[i + 3 * m];
}

#define instantiate_fft(name, n, radix_2_steps, radix_4_steps) \
template [[host_name("fft_" #name)]] \
template [[host_name("fft_" #name)]] [[kernel]] void \
[[kernel]] void fft<n, radix_2_steps, radix_4_steps>( \
fft<n, radix_2_steps, radix_4_steps>( \
const device float2* in [[buffer(0)]], \
device float2* out [[buffer(1)]], \
uint3 thread_position_in_grid [[thread_position_in_grid]], \
uint3 threads_per_grid [[threads_per_grid]]);


// Explicitly define kernels for each power of 2.
// clang-format off
instantiate_fft(4, /* n= */ 4, /* radix_2_steps= */ 0, /* radix_4_steps= */ 1)
instantiate_fft(8, 8, 1, 1)
instantiate_fft(8, 8, 1, 1) instantiate_fft(16, 16, 0, 2)
instantiate_fft(16, 16, 0, 2)
instantiate_fft(32, 32, 1, 2) instantiate_fft(64, 64, 0, 3)
instantiate_fft(32, 32, 1, 2)
instantiate_fft(128, 128, 1, 3) instantiate_fft(256, 256, 0, 4)
instantiate_fft(64, 64, 0, 3)
instantiate_fft(128, 128, 1, 3)
instantiate_fft(256, 256, 0, 4)
instantiate_fft(512, 512, 1, 4)
instantiate_fft(1024, 1024, 0, 5)
// 2048 is the max that will fit into 32KB of threadgroup memory.
// TODO: implement 4 step FFT for larger n.
instantiate_fft(2048, 2048, 1, 5)
instantiate_fft(2048, 2048, 1, 5) // clang-format on
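Two sanity checks on the instantiation table above, written as a host-side C++ sketch (my own verification, not part of the commit; it assumes the kernel holds two n-element float2 buffers in threadgroup memory, consistent with the shared_in/read_buf/write_buf usage shown earlier):

#include <cstddef>

// Every instantiate_fft(name, n, r2, r4) row satisfies n = 2^r2 * 4^r4.
constexpr size_t fft_size(size_t r2, size_t r4) {
  return (size_t(1) << r2) * (size_t(1) << (2 * r4));
}
static_assert(fft_size(1, 5) == 2048, "2048 = 2^1 * 4^5");
static_assert(fft_size(0, 5) == 1024, "1024 = 4^5");
static_assert(fft_size(1, 1) == 8, "8 = 2^1 * 4^1");

// Two 2048-element float2 buffers exactly fill the 32KB budget the
// comment above cites: 2 * 2048 * 8 bytes = 32768 bytes.
static_assert(2 * 2048 * 2 * sizeof(float) == 32 * 1024, "32KB threadgroup limit");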
@ -14,17 +14,16 @@ using namespace metal;

template <typename T, typename IdxT, int NIDX, int IDX_NDIM>
METAL_FUNC void gather_impl(
const device T *src [[buffer(0)]],
const device T* src [[buffer(0)]],
device T *out [[buffer(1)]],
device T* out [[buffer(1)]],
const constant int *src_shape [[buffer(2)]],
const constant int* src_shape [[buffer(2)]],
const constant size_t *src_strides [[buffer(3)]],
const constant size_t* src_strides [[buffer(3)]],
const constant size_t& src_ndim [[buffer(4)]],
const constant int *slice_sizes [[buffer(5)]],
const constant int* slice_sizes [[buffer(5)]],
const constant int *axes [[buffer(6)]],
const constant int* axes [[buffer(6)]],
const thread Indices<IdxT, NIDX>& indices,
uint2 index [[thread_position_in_grid]],
uint2 grid_dim [[threads_per_grid]]) {

auto ind_idx = index.x;
auto ind_offset = index.y;

@ -43,41 +42,33 @@ METAL_FUNC void gather_impl(
indices.ndim);
}
auto ax = axes[i];
auto idx_val = offset_neg_idx(
auto idx_val = offset_neg_idx(indices.buffers[i][idx_loc], src_shape[ax]);
indices.buffers[i][idx_loc], src_shape[ax]);
src_idx += idx_val * src_strides[ax];
}

auto src_offset = elem_to_loc(
auto src_offset = elem_to_loc(ind_offset, slice_sizes, src_strides, src_ndim);
ind_offset, slice_sizes, src_strides, src_ndim);

size_t out_idx = index.y + static_cast<size_t>(grid_dim.y) * index.x;
out[out_idx] = src[src_offset + src_idx];

}

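For orientation, gather_impl splits every output element's address into two parts: a gathered base (src_idx, built from the index arrays along the gathered axes) plus an intra-slice offset (src_offset, from linearizing index.y over slice_sizes). A scalar C++ sketch of that decomposition (my paraphrase of the logic above, not library code):

#include <cstddef>
#include <vector>

// Scalar paraphrase of gather_impl's addressing: every output element is
// src[base + offset], where base comes from the resolved index arrays and
// offset from walking one slice.
std::size_t gather_address(
    const std::vector<std::size_t>& idx_vals,  // one resolved index per gathered axis
    const std::vector<int>& axes,
    const std::vector<std::size_t>& src_strides,
    std::size_t slice_offset) {                // analogue of elem_to_loc's result
  std::size_t src_idx = 0;
  for (std::size_t i = 0; i < axes.size(); ++i) {
    src_idx += idx_vals[i] * src_strides[axes[i]];
  }
  return src_idx + slice_offset;
}

Negative indices are resolved before this point by offset_neg_idx, which (per its call site) folds an index against src_shape[ax].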
#define make_gather_impl(IDX_ARG, IDX_ARR) \
template <typename T, typename IdxT, int NIDX, int IDX_NDIM> \
[[kernel]] void gather( \
const device T *src [[buffer(0)]], \
const device T* src [[buffer(0)]], \
device T *out [[buffer(1)]], \
device T* out [[buffer(1)]], \
const constant int *src_shape [[buffer(2)]], \
const constant int* src_shape [[buffer(2)]], \
const constant size_t *src_strides [[buffer(3)]], \
const constant size_t* src_strides [[buffer(3)]], \
const constant size_t& src_ndim [[buffer(4)]], \
const constant int *slice_sizes [[buffer(5)]], \
const constant int* slice_sizes [[buffer(5)]], \
const constant int *axes [[buffer(6)]], \
const constant int* axes [[buffer(6)]], \
const constant int *idx_shapes [[buffer(7)]], \
const constant int* idx_shapes [[buffer(7)]], \
const constant size_t *idx_strides [[buffer(8)]], \
const constant size_t* idx_strides [[buffer(8)]], \
const constant int& idx_ndim [[buffer(9)]], \
IDX_ARG(IdxT) \
IDX_ARG(IdxT) uint2 index [[thread_position_in_grid]], \
uint2 index [[thread_position_in_grid]], \
uint2 grid_dim [[threads_per_grid]]) { \
\
Indices<IdxT, NIDX> idxs{ \
{{IDX_ARR()}}, \
{{IDX_ARR()}}, idx_shapes, idx_strides, idx_ndim}; \
idx_shapes, \
idx_strides, \
idx_ndim}; \
\
return gather_impl<T, IdxT, NIDX, IDX_NDIM>( \
src, \
@ -90,46 +81,39 @@ template <typename T, typename IdxT, int NIDX, int IDX_NDIM> \
idxs, \
index, \
grid_dim); \
}

#define make_gather(n) make_gather_impl(IDX_ARG_ ##n, IDX_ARR_ ##n)
#define make_gather(n) make_gather_impl(IDX_ARG_##n, IDX_ARR_##n)

make_gather(0)
make_gather(0) make_gather(1) make_gather(2) make_gather(3) make_gather(4)
make_gather(1)
make_gather(5) make_gather(6) make_gather(7) make_gather(8) make_gather(9)
make_gather(2)
make_gather(10)
make_gather(3)
make_gather(4)
make_gather(5)
make_gather(6)
make_gather(7)
make_gather(8)
make_gather(9)
make_gather(10)

/////////////////////////////////////////////////////////////////////
// Gather instantiations
/////////////////////////////////////////////////////////////////////

#define instantiate_gather6(name, src_t, idx_t, nidx, IDX_ARG, nd, nd_name) \
template [[host_name("gather" name "_" #nidx "" #nd_name)]] \
template [[host_name("gather" name "_" #nidx "" #nd_name)]] [[kernel]] void \
[[kernel]] void gather<src_t, idx_t, nidx, nd>( \
gather<src_t, idx_t, nidx, nd>( \
const device src_t *src [[buffer(0)]], \
const device src_t* src [[buffer(0)]], \
device src_t *out [[buffer(1)]], \
device src_t* out [[buffer(1)]], \
const constant int *src_shape [[buffer(2)]], \
const constant int* src_shape [[buffer(2)]], \
const constant size_t *src_strides [[buffer(3)]], \
const constant size_t* src_strides [[buffer(3)]], \
const constant size_t& src_ndim [[buffer(4)]], \
const constant int *slice_sizes [[buffer(5)]], \
const constant int* slice_sizes [[buffer(5)]], \
const constant int *axes [[buffer(6)]], \
const constant int* axes [[buffer(6)]], \
const constant int *idx_shapes [[buffer(7)]], \
const constant int* idx_shapes [[buffer(7)]], \
const constant size_t *idx_strides [[buffer(8)]], \
const constant size_t* idx_strides [[buffer(8)]], \
const constant int& idx_ndim [[buffer(9)]], \
IDX_ARG(idx_t) \
IDX_ARG(idx_t) uint2 index [[thread_position_in_grid]], \
uint2 index [[thread_position_in_grid]], \
uint2 grid_dim [[threads_per_grid]]);

// clang-format off
#define instantiate_gather5(name, src_t, idx_t, nidx, nd, nd_name) \
instantiate_gather6(name, src_t, idx_t, nidx, IDX_ARG_ ##nidx, nd, nd_name)
instantiate_gather6(name, src_t, idx_t, nidx, IDX_ARG_ ##nidx, nd, nd_name) // clang-format on

// clang-format off
#define instantiate_gather4(name, src_t, idx_t, nidx) \
instantiate_gather5(name, src_t, idx_t, nidx, 0, _0) \
instantiate_gather5(name, src_t, idx_t, nidx, 1, _1) \
@ -148,8 +132,9 @@ instantiate_gather4("int32", int32_t, bool, 0)
instantiate_gather4("int64", int64_t, bool, 0)
instantiate_gather4("float16", half, bool, 0)
instantiate_gather4("float32", float, bool, 0)
instantiate_gather4("bfloat16", bfloat16_t, bool, 0)
instantiate_gather4("bfloat16", bfloat16_t, bool, 0) // clang-format on

// clang-format off
#define instantiate_gather3(name, src_type, ind_type) \
instantiate_gather4(name, src_type, ind_type, 1) \
instantiate_gather4(name, src_type, ind_type, 2) \
@ -160,8 +145,9 @@ instantiate_gather4("bfloat16", bfloat16_t, bool, 0)
instantiate_gather4(name, src_type, ind_type, 7) \
instantiate_gather4(name, src_type, ind_type, 8) \
instantiate_gather4(name, src_type, ind_type, 9) \
instantiate_gather4(name, src_type, ind_type, 10)
instantiate_gather4(name, src_type, ind_type, 10) // clang-format on

// clang-format off
#define instantiate_gather(name, src_type) \
instantiate_gather3(#name "bool_", src_type, bool) \
instantiate_gather3(#name "uint8", src_type, uint8_t) \
@ -184,4 +170,4 @@ instantiate_gather(int32, int32_t)
instantiate_gather(int64, int64_t)
instantiate_gather(float16, half)
instantiate_gather(float32, float)
instantiate_gather(bfloat16, bfloat16_t)
instantiate_gather(bfloat16, bfloat16_t) // clang-format on
@ -1,7 +1,7 @@
// Copyright © 2023-2024 Apple Inc.

#include <metal_stdlib>
#include <metal_simdgroup>
#include <metal_stdlib>

#include "mlx/backend/metal/kernels/bf16.h"
#include "mlx/backend/metal/kernels/defines.h"
@ -22,10 +22,9 @@ template <
const int BM, /* Threadgroup rows (in threads) */
const int BN, /* Threadgroup cols (in threads) */
const int TM, /* Thread rows (in elements) */
const int TN , /* Thread cols (in elements) */
const int TN, /* Thread cols (in elements) */
const bool kDoAxpby> /* Do out = alpha * out + beta * bias */
struct GEMVKernel {

static_assert(BN == SIMD_SIZE, "gemv block must have a width of SIMD_SIZE");

// - The matrix of size (M = out_vec_size, N = in_vec_size) is divided up
@ -35,15 +34,20 @@ struct GEMVKernel {
//
// 1. A thread loads TN elements each from mat along TM contiguous rows
// and the corresponding scalar from the vector
// 2. The thread then multiplies and adds to accumulate its local result for the block
// 2. The thread then multiplies and adds to accumulate its local result for
// 3. At the end, each thread has accumulated results over all blocks across the rows
// the block
// 3. At the end, each thread has accumulated results over all blocks across
// the rows
// These are then summed up across the threadgroup
// 4. Each threadgroup writes its accumulated BN * TN outputs
//
// Edge case handling:
// - The threadgroup with the largest tid will have blocks that exceed the matrix
// - The threadgroup with the largest tid will have blocks that exceed the
// * The blocks that start outside the matrix are never read (thread results remain zero)
// matrix
// * The last thread that partially overlaps with the matrix is shifted inwards
// * The blocks that start outside the matrix are never read (thread results
// remain zero)
// * The last thread that partially overlaps with the matrix is shifted
// inwards
// such that the thread block fits exactly in the matrix

MLX_MTL_CONST short tgp_mem_size = BN * TN * 2;
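To make the block geometry in that comment concrete, a small host-side C++ sketch (illustrative values only; the actual BM/BN/TM/TN choices live in the dispatch code, which this diff does not show):

#include <cstdio>

int main() {
  // One threadgroup of BM x BN threads, each owning a TM x TN tile.
  // BN must equal SIMD_SIZE (32), per the static_assert above.
  const int BM = 8, BN = 32, TM = 4, TN = 4;
  std::printf("output rows per threadgroup: %d\n", BM * TM); // 32
  std::printf("in_vec elements per block:   %d\n", BN * TN); // 128
  std::printf("tgp_mem_size (elements):     %d\n", BN * TN * 2); // 256
  return 0;
}

The in_vec prefetch in the kernel body stages a block of vector elements in threadgroup memory each iteration, which is why tgp_mem_size scales with BN * TN here; the transposed GEMVT variant later in this file sizes it as BN * BM * TN instead.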
@ -64,7 +68,6 @@ struct GEMVKernel {
uint3 lid [[thread_position_in_threadgroup]],
uint simd_gid [[simdgroup_index_in_threadgroup]],
uint simd_lid [[thread_index_in_simdgroup]]) {

// Appease compiler
(void)lid;

@ -80,7 +83,7 @@ struct GEMVKernel {
int out_row = (tid.x * BM + simd_gid) * TM;

// Exit simdgroup if rows out of bound
if(out_row >= out_vec_size)
if (out_row >= out_vec_size)
return;

// Adjust tail simdgroup to ensure in bound reads
@ -90,89 +93,80 @@ struct GEMVKernel {
mat += out_row * marix_ld;

// Loop over in_vec in blocks of BN * TN
for(int bn = simd_lid * TN; bn < in_vec_size; bn += BN * TN) {
for (int bn = simd_lid * TN; bn < in_vec_size; bn += BN * TN) {

threadgroup_barrier(mem_flags::mem_threadgroup);

// Prefetch in_vector for threadgroup use
if(simd_gid == 0) {
if (simd_gid == 0) {
// Main load loop
if(bn + TN <= in_vec_size) {
if (bn + TN <= in_vec_size) {
#pragma clang loop unroll(full)
#pragma clang loop unroll(full)
for (int tn = 0; tn < TN; tn++) {
for(int tn = 0; tn < TN; tn++) {
in_vec_block[tn] = in_vec[bn + tn];
}

} else { // Edgecase

#pragma clang loop unroll(full)
for(int tn = 0; tn < TN; tn++) {
for (int tn = 0; tn < TN; tn++) {
in_vec_block[tn] = bn + tn < in_vec_size ? in_vec[bn + tn] : 0;
}

}
}

threadgroup_barrier(mem_flags::mem_threadgroup);

// Load for all rows
#pragma clang loop unroll(full)
for(int tn = 0; tn < TN; tn++) {
for (int tn = 0; tn < TN; tn++) {
v_coeff[tn] = in_vec_block[tn];
}

// Per thread work loop
#pragma clang loop unroll(full)
for(int tm = 0; tm < TM; tm++) {
for (int tm = 0; tm < TM; tm++) {

// Load for the row
if(bn + TN <= in_vec_size) {
if (bn + TN <= in_vec_size) {
#pragma clang loop unroll(full)
for(int tn = 0; tn < TN; tn++) {
for (int tn = 0; tn < TN; tn++) {
inter[tn] = mat[tm * marix_ld + bn + tn];
}

} else { // Edgecase
#pragma clang loop unroll(full)
for(int tn = 0; tn < TN; tn++) {
for (int tn = 0; tn < TN; tn++) {
int col_idx = (bn + tn) < in_vec_size ? (bn + tn) : (in_vec_size - 1);
int col_idx =
(bn + tn) < in_vec_size ? (bn + tn) : (in_vec_size - 1);
inter[tn] = mat[tm * marix_ld + col_idx];
}
}

// Accumulate results
for(int tn = 0; tn < TN; tn++) {
for (int tn = 0; tn < TN; tn++) {
result[tm] += inter[tn] * v_coeff[tn];
}

}
}

// Simdgroup accumulations
#pragma clang loop unroll(full)
for(int tm = 0; tm < TM; tm++) {
for (int tm = 0; tm < TM; tm++) {
result[tm] = simd_sum(result[tm]);
}

// Write outputs
if(simd_lid == 0) {
if (simd_lid == 0) {
#pragma clang loop unroll(full)
#pragma clang loop unroll(full)
for (int tm = 0; tm < TM; tm++) {
for(int tm = 0; tm < TM; tm++) {
if (kDoAxpby) {
if(kDoAxpby) {
out_vec[out_row + tm] = static_cast<T>(alpha) * result[tm] +
out_vec[out_row + tm] =
static_cast<T>(alpha) * result[tm] +
static_cast<T>(beta) * bias[(out_row + tm) * bias_stride];
} else {
out_vec[out_row + tm] = result[tm];
}
}

}

}

};

///////////////////////////////////////////////////////////////////////////////
@ -187,7 +181,6 @@ template <
const int TN, /* Thread cols (in elements) */
const bool kDoAxpby> /* Do out = alpha * out + beta * bias */
struct GEMVTKernel {

// - The matrix of size (M = in_vec_size, N = out_vec_size) is divided up
// into blocks of (BM * TM, BN * TN) divided among threadgroups
// - Every thread works on a block of (TM, TN)
@ -195,18 +188,22 @@ struct GEMVTKernel {
//
// 1. A thread loads TN elements each from mat along TM contiguous rows
// and the corresponding scalar from the vector
// 2. The thread then multiplies and adds to accumulate its local result for the block
// 2. The thread then multiplies and adds to accumulate its local result for
// 3. At the end, each thread has accumulated results over all blocks across the rows
// the block
// 3. At the end, each thread has accumulated results over all blocks across
// the rows
// These are then summed up across the threadgroup
// 4. Each threadgroup writes its accumulated BN * TN outputs
//
// Edge case handling:
// - The threadgroup with the largest tid will have blocks that exceed the matrix
// - The threadgroup with the largest tid will have blocks that exceed the
// * The blocks that start outside the matrix are never read (thread results remain zero)
// matrix
// * The last thread that partially overlaps with the matrix is shifted inwards
// * The blocks that start outside the matrix are never read (thread results
// remain zero)
// * The last thread that partially overlaps with the matrix is shifted
// inwards
// such that the thread block fits exactly in the matrix


MLX_MTL_CONST short tgp_mem_size = BN * BM * TN;

static METAL_FUNC void run(
|
static METAL_FUNC void run(
|
||||||
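The comment block above describes how the transposed kernel tiles the work; the underlying loop nest is just vec^T * mat. A scalar C++ reference of those semantics (illustrative only; gemv_t_ref is a stand-in name, not part of the commit):

#include <vector>

// Reference semantics the GEMVT kernel tiles and parallelizes:
// out[n] = sum_m mat[m * ld + n] * vec[m], i.e. vec^T * mat.
std::vector<float> gemv_t_ref(const std::vector<float>& mat,
                              const std::vector<float>& vec,
                              int m_rows, int n_cols, int ld) {
  std::vector<float> out(n_cols, 0.0f);
  for (int m = 0; m < m_rows; m++) {    // rows walked in BM * TM blocks on GPU
    for (int n = 0; n < n_cols; n++) {  // columns split into BN * TN blocks
      out[n] += vec[m] * mat[m * ld + n];
    }
  }
  return out;
}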
@@ -225,7 +222,6 @@ struct GEMVTKernel {
       uint3 lid [[thread_position_in_threadgroup]],
       uint simd_gid [[simdgroup_index_in_threadgroup]],
       uint simd_lid [[thread_index_in_simdgroup]]) {
-
    // Appease compiler
    (void)simd_gid;
    (void)simd_lid;
@@ -243,77 +239,69 @@ struct GEMVTKernel {

    // Edgecase handling
    if (out_col < out_vec_size) {

      out_col = out_col + TN < out_vec_size ? out_col : out_vec_size - TN;

      // Per thread accumulation main loop
      int bm = in_row;
-      for(; bm < in_vec_size; bm += BM * TM) {
+      for (; bm < in_vec_size; bm += BM * TM) {
        // Adding a threadgroup_barrier improves performance slightly
        // This is possibly it may help exploit cache better
        threadgroup_barrier(mem_flags::mem_none);

-        if(bm + TM <= in_vec_size) {
-          #pragma clang loop unroll(full)
-          for(int tm = 0; tm < TM; tm++) {
+        if (bm + TM <= in_vec_size) {
+#pragma clang loop unroll(full)
+          for (int tm = 0; tm < TM; tm++) {
            v_coeff[tm] = in_vec[bm + tm];
          }

 #pragma clang loop unroll(full)
-          for(int tm = 0; tm < TM; tm++) {
-            for(int tn = 0; tn < TN; tn++) {
+          for (int tm = 0; tm < TM; tm++) {
+            for (int tn = 0; tn < TN; tn++) {
              inter[tn] = mat[(bm + tm) * marix_ld + out_col + tn];
            }
-            for(int tn = 0; tn < TN; tn++) {
+            for (int tn = 0; tn < TN; tn++) {
              result[tn] += v_coeff[tm] * inter[tn];
            }
          }

        } else { // Edgecase handling
-          for(int tm = 0; bm + tm < in_vec_size; tm++) {
+          for (int tm = 0; bm + tm < in_vec_size; tm++) {
            v_coeff[tm] = in_vec[bm + tm];

-            for(int tn = 0; tn < TN; tn++) {
+            for (int tn = 0; tn < TN; tn++) {
              inter[tn] = mat[(bm + tm) * marix_ld + out_col + tn];
            }
-            for(int tn = 0; tn < TN; tn++) {
+            for (int tn = 0; tn < TN; tn++) {
              result[tn] += v_coeff[tm] * inter[tn];
            }

          }
        }
      }

    }

    // Threadgroup collection

 #pragma clang loop unroll(full)
-    for(int i = 0; i < TN; i++) {
+    for (int i = 0; i < TN; i++) {
      tgp_results[lid.y * TN + i] = result[i];
    }

    threadgroup_barrier(mem_flags::mem_threadgroup);

    // Threadgroup accumulation and writing out results
-    if(lid.y == 0 && out_col < out_vec_size) {
-      #pragma clang loop unroll(full)
-      for(int i = 1; i < BM; i++) {
-        #pragma clang loop unroll(full)
-        for(int j = 0; j < TN; j++) {
+    if (lid.y == 0 && out_col < out_vec_size) {
+#pragma clang loop unroll(full)
+      for (int i = 1; i < BM; i++) {
+#pragma clang loop unroll(full)
+        for (int j = 0; j < TN; j++) {
          result[j] += tgp_results[i * TN + j];
        }
      }

 #pragma clang loop unroll(full)
-      for(int j = 0; j < TN; j++) {
-        if(kDoAxpby) {
-          out_vec[out_col + j] =
-              static_cast<T>(alpha) * result[j] +
+      for (int j = 0; j < TN; j++) {
+        if (kDoAxpby) {
+          out_vec[out_col + j] = static_cast<T>(alpha) * result[j] +
              static_cast<T>(beta) * bias[(out_col + j) * bias_stride];
        } else {
          out_vec[out_col + j] = result[j];
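The collection step above is a two-stage reduction: each row of threads (indexed by lid.y) parks TN partials in threadgroup memory, and row 0 then folds the remaining BM - 1 rows before storing. A plain C++ stand-in for that fold (illustrative, assuming result already holds row 0's partials, as it does in the kernel):

// Illustrative C++ stand-in for the tgp_results fold: rows 1..BM-1 of the
// threadgroup's partial sums are accumulated into row 0's result.
template <int BM, int TN>
void collect_partials(const float (&tgp_results)[BM * TN],
                      float (&result)[TN]) {
  for (int i = 1; i < BM; i++) {
    for (int j = 0; j < TN; j++) {
      result[j] += tgp_results[i * TN + j];
    }
  }
}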
@@ -335,7 +323,7 @@ template <
     const int TN, /* Thread cols (in elements) */
     const bool kDoNCBatch, /* Batch ndim > 1 */
     const bool kDoAxpby> /* Do out = alpha * out + beta * bias */
-[[kernel, max_total_threads_per_threadgroup(BM * BN)]] void gemv(
+[[kernel, max_total_threads_per_threadgroup(BM* BN)]] void gemv(
     const device T* mat [[buffer(0)]],
     const device T* in_vec [[buffer(1)]],
     const device T* bias [[buffer(2)]],
@@ -355,16 +343,15 @@ template <
     uint3 lid [[thread_position_in_threadgroup]],
     uint simd_gid [[simdgroup_index_in_threadgroup]],
     uint simd_lid [[thread_index_in_simdgroup]]) {
-
  using gemv_kernel = GEMVKernel<T, BM, BN, TM, TN, kDoAxpby>;
  threadgroup T tgp_memory[gemv_kernel::tgp_mem_size];

  // Update batch offsets
-  if(kDoNCBatch) {
+  if (kDoNCBatch) {
    in_vec += elem_to_loc(tid.z, batch_shape, vector_batch_stride, batch_ndim);
    mat += elem_to_loc(tid.z, batch_shape, matrix_batch_stride, batch_ndim);

-    if(kDoAxpby) {
+    if (kDoAxpby) {
      bias += elem_to_loc(tid.z, batch_shape, bias_batch_stride, batch_ndim);
    }

@@ -372,7 +359,7 @@ template <
    in_vec += tid.z * vector_batch_stride[0];
    mat += tid.z * matrix_batch_stride[0];

-    if(kDoAxpby) {
+    if (kDoAxpby) {
      bias += tid.z * bias_batch_stride[0];
    }
  }
|
|||||||
tid,
|
tid,
|
||||||
lid,
|
lid,
|
||||||
simd_gid,
|
simd_gid,
|
||||||
simd_lid
|
simd_lid);
|
||||||
);
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
#define instantiate_gemv_helper(name, itype, bm, bn, tm, tn, nc, axpby) \
|
#define instantiate_gemv_helper(name, itype, bm, bn, tm, tn, nc, axpby) \
|
||||||
template [[host_name("gemv_" #name "_bm" #bm "_bn" #bn "_tm" #tm "_tn" #tn "_nc" #nc "_axpby" #axpby)]] \
|
template [[host_name("gemv_" #name "_bm" #bm "_bn" #bn "_tm" #tm "_tn" #tn \
|
||||||
[[kernel]] void gemv<itype, bm, bn, tm, tn, nc, axpby>( \
|
"_nc" #nc "_axpby" #axpby)]] [[kernel]] void \
|
||||||
|
gemv<itype, bm, bn, tm, tn, nc, axpby>( \
|
||||||
const device itype* mat [[buffer(0)]], \
|
const device itype* mat [[buffer(0)]], \
|
||||||
const device itype* in_vec [[buffer(1)]], \
|
const device itype* in_vec [[buffer(1)]], \
|
||||||
const device itype* bias [[buffer(2)]], \
|
const device itype* bias [[buffer(2)]], \
|
||||||
@ -430,9 +415,8 @@ template <
|
|||||||
instantiate_gemv_helper(name, itype, bm, bn, tm, tn, 1, 1)
|
instantiate_gemv_helper(name, itype, bm, bn, tm, tn, 1, 1)
|
||||||
|
|
||||||
#define instantiate_gemv_blocks(name, itype) \
|
#define instantiate_gemv_blocks(name, itype) \
|
||||||
instantiate_gemv(name, itype, 4, 32, 1, 4) \
|
instantiate_gemv(name, itype, 4, 32, 1, 4) instantiate_gemv( \
|
||||||
instantiate_gemv(name, itype, 4, 32, 4, 4) \
|
name, itype, 4, 32, 4, 4) instantiate_gemv(name, itype, 8, 32, 4, 4)
|
||||||
instantiate_gemv(name, itype, 8, 32, 4, 4)
|
|
||||||
|
|
||||||
instantiate_gemv_blocks(float32, float);
|
instantiate_gemv_blocks(float32, float);
|
||||||
instantiate_gemv_blocks(float16, half);
|
instantiate_gemv_blocks(float16, half);
|
||||||
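Each helper expansion above stamps out one kernel specialization whose host_name encodes the tile configuration. As a worked example of the macro (an illustration, not a line from the commit): instantiate_gemv_helper(float32, float, 4, 32, 1, 4, 0, 0) declares the specialization gemv<float, 4, 32, 1, 4, 0, 0> under the host-visible name "gemv_float32_bm4_bn32_tm1_tn4_nc0_axpby0".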
@@ -450,7 +434,7 @@ template <
     const int TN, /* Thread cols (in elements) */
     const bool kDoNCBatch, /* Batch ndim > 1 */
     const bool kDoAxpby> /* Do out = alpha * out + beta * bias */
-[[kernel, max_total_threads_per_threadgroup(BM * BN)]] void gemv_t(
+[[kernel, max_total_threads_per_threadgroup(BM* BN)]] void gemv_t(
     const device T* mat [[buffer(0)]],
     const device T* in_vec [[buffer(1)]],
     const device T* bias [[buffer(2)]],
@@ -470,16 +454,15 @@ template <
     uint3 lid [[thread_position_in_threadgroup]],
     uint simd_gid [[simdgroup_index_in_threadgroup]],
     uint simd_lid [[thread_index_in_simdgroup]]) {
-
  using gemv_kernel = GEMVTKernel<T, BM, BN, TM, TN, kDoAxpby>;
  threadgroup T tgp_memory[gemv_kernel::tgp_mem_size];

  // Update batch offsets
-  if(kDoNCBatch) {
+  if (kDoNCBatch) {
    in_vec += elem_to_loc(tid.z, batch_shape, vector_batch_stride, batch_ndim);
    mat += elem_to_loc(tid.z, batch_shape, matrix_batch_stride, batch_ndim);

-    if(kDoAxpby) {
+    if (kDoAxpby) {
      bias += elem_to_loc(tid.z, batch_shape, bias_batch_stride, batch_ndim);
    }

@@ -487,7 +470,7 @@ template <
    in_vec += tid.z * vector_batch_stride[0];
    mat += tid.z * matrix_batch_stride[0];

-    if(kDoAxpby) {
+    if (kDoAxpby) {
      bias += tid.z * bias_batch_stride[0];
    }
  }
@@ -509,14 +492,13 @@ template <
       tid,
       lid,
       simd_gid,
-      simd_lid
-  );
+      simd_lid);
 }
-
 #define instantiate_gemv_t_helper(name, itype, bm, bn, tm, tn, nc, axpby) \
-  template [[host_name("gemv_t_" #name "_bm" #bm "_bn" #bn "_tm" #tm "_tn" #tn "_nc" #nc "_axpby" #axpby)]] \
-  [[kernel]] void gemv_t<itype, bm, bn, tm, tn, nc, axpby>( \
+  template [[host_name("gemv_t_" #name "_bm" #bm "_bn" #bn "_tm" #tm "_tn" #tn \
+                       "_nc" #nc "_axpby" #axpby)]] [[kernel]] void \
+  gemv_t<itype, bm, bn, tm, tn, nc, axpby>( \
       const device itype* mat [[buffer(0)]], \
       const device itype* in_vec [[buffer(1)]], \
       const device itype* bias [[buffer(2)]], \
@@ -537,20 +519,23 @@ template <
       uint simd_gid [[simdgroup_index_in_threadgroup]], \
       uint simd_lid [[thread_index_in_simdgroup]]);

+// clang-format off
 #define instantiate_gemv_t(name, itype, bm, bn, tm, tn) \
   instantiate_gemv_t_helper(name, itype, bm, bn, tm, tn, 0, 0) \
   instantiate_gemv_t_helper(name, itype, bm, bn, tm, tn, 0, 1) \
   instantiate_gemv_t_helper(name, itype, bm, bn, tm, tn, 1, 0) \
-  instantiate_gemv_t_helper(name, itype, bm, bn, tm, tn, 1, 1)
+  instantiate_gemv_t_helper(name, itype, bm, bn, tm, tn, 1, 1) // clang-format on

+// clang-format off
 #define instantiate_gemv_t_blocks(name, itype) \
   instantiate_gemv_t(name, itype, 8, 8, 4, 1) \
   instantiate_gemv_t(name, itype, 8, 8, 4, 4) \
   instantiate_gemv_t(name, itype, 8, 16, 4, 4) \
   instantiate_gemv_t(name, itype, 8, 32, 4, 4) \
   instantiate_gemv_t(name, itype, 8, 64, 4, 4) \
-  instantiate_gemv_t(name, itype, 8, 128, 4, 4)
+  instantiate_gemv_t(name, itype, 8, 128, 4, 4) // clang-format on

+// clang-format off
 instantiate_gemv_t_blocks(float32, float);
 instantiate_gemv_t_blocks(float16, half);
-instantiate_gemv_t_blocks(bfloat16, bfloat16_t);
+instantiate_gemv_t_blocks(bfloat16, bfloat16_t); // clang-format on
@@ -99,7 +99,8 @@ template <typename T, int N_READS = RMS_N_READS>
     for (int i = 0; i < N_READS; i++) {
       if ((lid * N_READS + i) < axis_size) {
         thread_x[i] = (thread_x[i] - mean) * normalizer;
-        out[i] = w[w_stride * i] * static_cast<T>(thread_x[i]) + b[b_stride * i];
+        out[i] =
+            w[w_stride * i] * static_cast<T>(thread_x[i]) + b[b_stride * i];
       }
     }
   }
@@ -192,13 +193,15 @@ template <typename T, int N_READS = RMS_N_READS>
     if (r + lid * N_READS + N_READS <= axis_size) {
       for (int i = 0; i < N_READS; i++) {
         float xi = (x[r + i] - mean) * normalizer;
-        out[r + i] = w[w_stride * (i + r)] * static_cast<T>(xi) + b[b_stride * (i + r)];
+        out[r + i] =
+            w[w_stride * (i + r)] * static_cast<T>(xi) + b[b_stride * (i + r)];
       }
     } else {
       for (int i = 0; i < N_READS; i++) {
         if ((r + lid * N_READS + i) < axis_size) {
           float xi = (x[r + i] - mean) * normalizer;
-          out[r + i] = w[w_stride * (i + r)] * static_cast<T>(xi) + b[b_stride * (i + r)];
+          out[r + i] = w[w_stride * (i + r)] * static_cast<T>(xi) +
+              b[b_stride * (i + r)];
         }
       }
     }
@@ -323,7 +326,8 @@ template <typename T, int N_READS = RMS_N_READS>
   if (lid * N_READS + N_READS <= axis_size) {
     for (int i = 0; i < N_READS; i++) {
       thread_x[i] = (thread_x[i] - mean) * normalizer;
-      gx[i] = static_cast<T>(normalizer * (thread_w[i] * thread_g[i] - meanwg) -
+      gx[i] = static_cast<T>(
+          normalizer * (thread_w[i] * thread_g[i] - meanwg) -
          thread_x[i] * meanwgxc * normalizer2);
      gw[i] = static_cast<T>(thread_g[i] * thread_x[i]);
    }
@@ -331,7 +335,8 @@ template <typename T, int N_READS = RMS_N_READS>
    for (int i = 0; i < N_READS; i++) {
      if ((lid * N_READS + i) < axis_size) {
        thread_x[i] = (thread_x[i] - mean) * normalizer;
-        gx[i] = static_cast<T>(normalizer * (thread_w[i] * thread_g[i] - meanwg) -
+        gx[i] = static_cast<T>(
+            normalizer * (thread_w[i] * thread_g[i] - meanwg) -
            thread_x[i] * meanwgxc * normalizer2);
        gw[i] = static_cast<T>(thread_g[i] * thread_x[i]);
      }
@@ -460,8 +465,8 @@ template <typename T, int N_READS = RMS_N_READS>
        float xi = (x[i + r] - mean) * normalizer;
        float wi = w[(i + r) * w_stride];
        float gi = g[i + r];
-        gx[i + r] = static_cast<T>(normalizer * (wi * gi - meanwg) -
-            xi * meanwgxc * normalizer2);
+        gx[i + r] = static_cast<T>(
+            normalizer * (wi * gi - meanwg) - xi * meanwgxc * normalizer2);
        gw[i + r] = static_cast<T>(gi * xi);
      }
    } else {
@@ -470,8 +475,8 @@ template <typename T, int N_READS = RMS_N_READS>
        float xi = (x[i + r] - mean) * normalizer;
        float wi = w[(i + r) * w_stride];
        float gi = g[i + r];
-        gx[i + r] = static_cast<T>(normalizer * (wi * gi - meanwg) -
-            xi * meanwgxc * normalizer2);
+        gx[i + r] = static_cast<T>(
+            normalizer * (wi * gi - meanwg) - xi * meanwgxc * normalizer2);
        gw[i + r] = static_cast<T>(gi * xi);
      }
    }
@@ -548,6 +553,4 @@ template <typename T, int N_READS = RMS_N_READS>

 instantiate_layer_norm(float32, float)
 instantiate_layer_norm(float16, half)
-instantiate_layer_norm(bfloat16, bfloat16_t)
-// clang-format on
-
+instantiate_layer_norm(bfloat16, bfloat16_t) // clang-format on
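These hunks only rewrap the layer-norm arithmetic; the computation itself is out = w * (x - mean) * normalizer + b, where normalizer is apparently 1/sqrt(var + eps) given the (x - mean) * normalizer form. A small C++ reference of that forward pass (an illustrative sketch under that assumption, not code from the commit):

#include <cmath>
#include <cstddef>
#include <vector>

// Rowwise layer-norm reference: normalize x, then apply the affine w, b.
std::vector<float> layer_norm_ref(const std::vector<float>& x,
                                  const std::vector<float>& w,
                                  const std::vector<float>& b, float eps) {
  float mean = 0.f, var = 0.f;
  for (float v : x) mean += v;
  mean /= x.size();
  for (float v : x) var += (v - mean) * (v - mean);
  var /= x.size();
  float normalizer = 1.f / std::sqrt(var + eps);  // matches (x - mean) * normalizer
  std::vector<float> out(x.size());
  for (std::size_t i = 0; i < x.size(); i++) {
    out[i] = w[i] * ((x[i] - mean) * normalizer) + b[i];
  }
  return out;
}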
@@ -1,7 +1,7 @@
 // Copyright © 2023-2024 Apple Inc.

-#include <metal_stdlib>
 #include <metal_simdgroup>
+#include <metal_stdlib>

 #include "mlx/backend/metal/kernels/bf16.h"
 #include "mlx/backend/metal/kernels/defines.h"
@@ -15,30 +15,31 @@ using namespace metal;

 MLX_MTL_CONST int SIMD_SIZE = 32;


 template <typename T, typename U, int values_per_thread, int bits>
-inline U load_vector(const device T *x, thread U *x_thread) {
-  static_assert(bits == 2 || bits == 4 || bits == 8, "Template undefined for bits not in {2, 4, 8}");
+inline U load_vector(const device T* x, thread U* x_thread) {
+  static_assert(
+      bits == 2 || bits == 4 || bits == 8,
+      "Template undefined for bits not in {2, 4, 8}");

  U sum = 0;

  if (bits == 2) {
    for (int i = 0; i < values_per_thread; i += 4) {
-      sum += x[i] + x[i+1] + x[i+2] + x[i+3];
+      sum += x[i] + x[i + 1] + x[i + 2] + x[i + 3];
      x_thread[i] = x[i];
-      x_thread[i+1] = x[i+1] / 4.0f;
-      x_thread[i+2] = x[i+2] / 16.0f;
-      x_thread[i+3] = x[i+3] / 64.0f;
+      x_thread[i + 1] = x[i + 1] / 4.0f;
+      x_thread[i + 2] = x[i + 2] / 16.0f;
+      x_thread[i + 3] = x[i + 3] / 64.0f;
    }
  }

  else if (bits == 4) {
    for (int i = 0; i < values_per_thread; i += 4) {
-      sum += x[i] + x[i+1] + x[i+2] + x[i+3];
+      sum += x[i] + x[i + 1] + x[i + 2] + x[i + 3];
      x_thread[i] = x[i];
-      x_thread[i+1] = x[i+1] / 16.0f;
-      x_thread[i+2] = x[i+2] / 256.0f;
-      x_thread[i+3] = x[i+3] / 4096.0f;
+      x_thread[i + 1] = x[i + 1] / 16.0f;
+      x_thread[i + 2] = x[i + 2] / 256.0f;
+      x_thread[i + 3] = x[i + 3] / 4096.0f;
    }
  }

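The divisions by 4, 16, and 64 above are not dequantization; they pre-scale x so that the dot product can later use masked-but-unshifted weight fields. A standalone C++ check of the identity this relies on for bits == 2 (illustrative only):

#include <cassert>

// Multiplying the unshifted masked field (w & 0x0c) by x/4 equals the shifted
// 2-bit value ((w >> 2) & 3) times x, so qdot can skip the shifts entirely.
// Exact in float: both sides are small integers scaled by powers of two.
int main() {
  for (int w = 0; w < 256; w++) {
    for (int x = 1; x <= 8; x++) {
      float lhs = (w & 0x0c) * (x / 4.0f);
      float rhs = ((w >> 2) & 0x03) * float(x);
      assert(lhs == rhs);
    }
  }
  return 0;
}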
@@ -53,33 +54,35 @@ inline U load_vector(const device T *x, thread U *x_thread) {
 }

 template <typename T, typename U, int values_per_thread, int bits>
-inline U load_vector_safe(const device T *x, thread U *x_thread, int N) {
-  static_assert(bits == 2 || bits == 4 || bits == 8, "Template undefined for bits not in {2, 4, 8}");
+inline U load_vector_safe(const device T* x, thread U* x_thread, int N) {
+  static_assert(
+      bits == 2 || bits == 4 || bits == 8,
+      "Template undefined for bits not in {2, 4, 8}");

  U sum = 0;

  if (bits == 2) {
    for (int i = 0; i < N; i += 4) {
-      sum += x[i] + x[i+1] + x[i+2] + x[i+3];
+      sum += x[i] + x[i + 1] + x[i + 2] + x[i + 3];
      x_thread[i] = x[i];
-      x_thread[i+1] = x[i+1] / 4.0f;
-      x_thread[i+2] = x[i+2] / 16.0f;
-      x_thread[i+3] = x[i+3] / 64.0f;
+      x_thread[i + 1] = x[i + 1] / 4.0f;
+      x_thread[i + 2] = x[i + 2] / 16.0f;
+      x_thread[i + 3] = x[i + 3] / 64.0f;
    }
-    for (int i=N; i<values_per_thread; i++) {
+    for (int i = N; i < values_per_thread; i++) {
      x_thread[i] = 0;
    }
  }

  else if (bits == 4) {
    for (int i = 0; i < N; i += 4) {
-      sum += x[i] + x[i+1] + x[i+2] + x[i+3];
+      sum += x[i] + x[i + 1] + x[i + 2] + x[i + 3];
      x_thread[i] = x[i];
-      x_thread[i+1] = x[i+1] / 16.0f;
-      x_thread[i+2] = x[i+2] / 256.0f;
-      x_thread[i+3] = x[i+3] / 4096.0f;
+      x_thread[i + 1] = x[i + 1] / 16.0f;
+      x_thread[i + 2] = x[i + 2] / 256.0f;
+      x_thread[i + 3] = x[i + 3] / 4096.0f;
    }
-    for (int i=N; i<values_per_thread; i++) {
+    for (int i = N; i < values_per_thread; i++) {
      x_thread[i] = 0;
    }
  }
@@ -89,7 +92,7 @@ inline U load_vector_safe(const device T *x, thread U *x_thread, int N) {
      sum += x[i];
      x_thread[i] = x[i];
    }
-    for (int i=N; i<values_per_thread; i++) {
+    for (int i = N; i < values_per_thread; i++) {
      x_thread[i] = 0;
    }
  }
@@ -98,29 +101,36 @@ inline U load_vector_safe(const device T *x, thread U *x_thread, int N) {
 }

 template <typename U, int values_per_thread, int bits>
-inline U qdot(const device uint8_t* w, const thread U *x_thread, U scale, U bias, U sum) {
-  static_assert(bits == 2 || bits == 4 || bits == 8, "Template undefined for bits not in {2, 4, 8}");
+inline U qdot(
+    const device uint8_t* w,
+    const thread U* x_thread,
+    U scale,
+    U bias,
+    U sum) {
+  static_assert(
+      bits == 2 || bits == 4 || bits == 8,
+      "Template undefined for bits not in {2, 4, 8}");

  U accum = 0;

  if (bits == 2) {
    for (int i = 0; i < (values_per_thread / 4); i++) {
-      accum += (
-          x_thread[4*i] * (w[i] & 0x03)
-          + x_thread[4*i+1] * (w[i] & 0x0c)
-          + x_thread[4*i+2] * (w[i] & 0x30)
-          + x_thread[4*i+3] * (w[i] & 0xc0));
+      accum +=
+          (x_thread[4 * i] * (w[i] & 0x03) +
+           x_thread[4 * i + 1] * (w[i] & 0x0c) +
+           x_thread[4 * i + 2] * (w[i] & 0x30) +
+           x_thread[4 * i + 3] * (w[i] & 0xc0));
    }
  }

  else if (bits == 4) {
    const device uint16_t* ws = (const device uint16_t*)w;
    for (int i = 0; i < (values_per_thread / 4); i++) {
-      accum += (
-          x_thread[4*i] * (ws[i] & 0x000f)
-          + x_thread[4*i+1] * (ws[i] & 0x00f0)
-          + x_thread[4*i+2] * (ws[i] & 0x0f00)
-          + x_thread[4*i+3] * (ws[i] & 0xf000));
+      accum +=
+          (x_thread[4 * i] * (ws[i] & 0x000f) +
+           x_thread[4 * i + 1] * (ws[i] & 0x00f0) +
+           x_thread[4 * i + 2] * (ws[i] & 0x0f00) +
+           x_thread[4 * i + 3] * (ws[i] & 0xf000));
    }
  }

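The reason qdot takes scale, bias, and the precomputed sum of x: for affine-quantized weights w_i = scale * q_i + bias, the dot product factors so that only the quantized accumulator and the x-sum are needed. A standalone C++ check of that algebra (illustrative; the factored return value is inferred from the arguments, since the return statement itself sits outside this hunk):

#include <cassert>
#include <cmath>

// sum_i (s * q_i + b) * x_i == s * sum_i q_i * x_i + b * sum_i x_i
int main() {
  float s = 0.25f, b = -1.5f;
  float q[4] = {0, 3, 1, 2}, x[4] = {0.5f, 2.0f, -1.0f, 4.0f};
  float direct = 0, accum = 0, xsum = 0;
  for (int i = 0; i < 4; i++) {
    direct += (s * q[i] + b) * x[i];  // dequantize-then-dot
    accum += q[i] * x[i];             // what qdot accumulates
    xsum += x[i];                     // what load_vector returns
  }
  assert(std::fabs(direct - (s * accum + b * xsum)) < 1e-6f);
  return 0;
}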
@@ -134,29 +144,37 @@ inline U qdot(const device uint8_t* w, const thread U* x_thread, U scale, U bias, U sum) {
 }

 template <typename U, int values_per_thread, int bits>
-inline U qdot_safe(const device uint8_t* w, const thread U *x_thread, U scale, U bias, U sum, int N) {
-  static_assert(bits == 2 || bits == 4 || bits == 8, "Template undefined for bits not in {2, 4, 8}");
+inline U qdot_safe(
+    const device uint8_t* w,
+    const thread U* x_thread,
+    U scale,
+    U bias,
+    U sum,
+    int N) {
+  static_assert(
+      bits == 2 || bits == 4 || bits == 8,
+      "Template undefined for bits not in {2, 4, 8}");

  U accum = 0;

  if (bits == 2) {
    for (int i = 0; i < (N / 4); i++) {
-      accum += (
-          x_thread[4*i] * (w[i] & 0x03)
-          + x_thread[4*i+1] * (w[i] & 0x0c)
-          + x_thread[4*i+2] * (w[i] & 0x30)
-          + x_thread[4*i+3] * (w[i] & 0xc0));
+      accum +=
+          (x_thread[4 * i] * (w[i] & 0x03) +
+           x_thread[4 * i + 1] * (w[i] & 0x0c) +
+           x_thread[4 * i + 2] * (w[i] & 0x30) +
+           x_thread[4 * i + 3] * (w[i] & 0xc0));
    }
  }

  else if (bits == 4) {
    const device uint16_t* ws = (const device uint16_t*)w;
    for (int i = 0; i < (N / 4); i++) {
-      accum += (
-          x_thread[4*i] * (ws[i] & 0x000f)
-          + x_thread[4*i+1] * (ws[i] & 0x00f0)
-          + x_thread[4*i+2] * (ws[i] & 0x0f00)
-          + x_thread[4*i+3] * (ws[i] & 0xf000));
+      accum +=
+          (x_thread[4 * i] * (ws[i] & 0x000f) +
+           x_thread[4 * i + 1] * (ws[i] & 0x00f0) +
+           x_thread[4 * i + 2] * (ws[i] & 0x0f00) +
+           x_thread[4 * i + 3] * (ws[i] & 0xf000));
    }
  }

@@ -170,16 +188,19 @@ inline U qdot_safe(const device uint8_t* w, const thread U* x_thread, U scale, U bias, U sum, int N) {
 }

 template <typename U, int values_per_thread, int bits>
-inline void qouter(const thread uint8_t* w, U x, U scale, U bias, thread U* result) {
-  static_assert(bits == 2 || bits == 4 || bits == 8, "Template undefined for bits not in {2, 4, 8}");
+inline void
+qouter(const thread uint8_t* w, U x, U scale, U bias, thread U* result) {
+  static_assert(
+      bits == 2 || bits == 4 || bits == 8,
+      "Template undefined for bits not in {2, 4, 8}");

  if (bits == 2) {
    U s[4] = {scale, scale / 4.0f, scale / 16.0f, scale / 64.0f};
    for (int i = 0; i < (values_per_thread / 4); i++) {
-      result[4*i] += x * (s[0] * (w[i] & 0x03) + bias);
-      result[4*i+1] += x * (s[1] * (w[i] & 0x0c) + bias);
-      result[4*i+2] += x * (s[2] * (w[i] & 0x30) + bias);
-      result[4*i+3] += x * (s[3] * (w[i] & 0xc0) + bias);
+      result[4 * i] += x * (s[0] * (w[i] & 0x03) + bias);
+      result[4 * i + 1] += x * (s[1] * (w[i] & 0x0c) + bias);
+      result[4 * i + 2] += x * (s[2] * (w[i] & 0x30) + bias);
+      result[4 * i + 3] += x * (s[3] * (w[i] & 0xc0) + bias);
    }
  }

@@ -187,10 +208,10 @@ inline void qouter(const thread uint8_t* w, U x, U scale, U bias, thread U* result) {
    const thread uint16_t* ws = (const thread uint16_t*)w;
    U s[4] = {scale, scale / 16.0f, scale / 256.0f, scale / 4096.0f};
    for (int i = 0; i < (values_per_thread / 4); i++) {
-      result[4*i] += x * (s[0] * (ws[i] & 0x000f) + bias);
-      result[4*i+1] += x * (s[1] * (ws[i] & 0x00f0) + bias);
-      result[4*i+2] += x * (s[2] * (ws[i] & 0x0f00) + bias);
-      result[4*i+3] += x * (s[3] * (ws[i] & 0xf000) + bias);
+      result[4 * i] += x * (s[0] * (ws[i] & 0x000f) + bias);
+      result[4 * i + 1] += x * (s[1] * (ws[i] & 0x00f0) + bias);
+      result[4 * i + 2] += x * (s[2] * (ws[i] & 0x0f00) + bias);
+      result[4 * i + 3] += x * (s[3] * (ws[i] & 0xf000) + bias);
    }
  }

@@ -202,27 +223,38 @@ inline void qouter(const thread uint8_t* w, U x, U scale, U bias, thread U* result) {
 }

 template <typename U, int N, int bits>
-inline void dequantize(const device uint8_t* w, U scale, U bias, threadgroup U* w_local) {
-  static_assert(bits == 2 || bits == 4 || bits == 8, "Template undefined for bits not in {2, 4, 8}");
+inline void
+dequantize(const device uint8_t* w, U scale, U bias, threadgroup U* w_local) {
+  static_assert(
+      bits == 2 || bits == 4 || bits == 8,
+      "Template undefined for bits not in {2, 4, 8}");

  if (bits == 2) {
-    U s[4] = {scale, scale / static_cast<U>(4.0f), scale / static_cast<U>(16.0f), scale / static_cast<U>(64.0f)};
+    U s[4] = {
+        scale,
+        scale / static_cast<U>(4.0f),
+        scale / static_cast<U>(16.0f),
+        scale / static_cast<U>(64.0f)};
    for (int i = 0; i < (N / 4); i++) {
-      w_local[4*i] = s[0] * (w[i] & 0x03) + bias;
-      w_local[4*i+1] = s[1] * (w[i] & 0x0c) + bias;
-      w_local[4*i+2] = s[2] * (w[i] & 0x30) + bias;
-      w_local[4*i+3] = s[3] * (w[i] & 0xc0) + bias;
+      w_local[4 * i] = s[0] * (w[i] & 0x03) + bias;
+      w_local[4 * i + 1] = s[1] * (w[i] & 0x0c) + bias;
+      w_local[4 * i + 2] = s[2] * (w[i] & 0x30) + bias;
+      w_local[4 * i + 3] = s[3] * (w[i] & 0xc0) + bias;
    }
  }

  else if (bits == 4) {
    const device uint16_t* ws = (const device uint16_t*)w;
-    U s[4] = {scale, scale / static_cast<U>(16.0f), scale / static_cast<U>(256.0f), scale / static_cast<U>(4096.0f)};
+    U s[4] = {
+        scale,
+        scale / static_cast<U>(16.0f),
+        scale / static_cast<U>(256.0f),
+        scale / static_cast<U>(4096.0f)};
    for (int i = 0; i < (N / 4); i++) {
-      w_local[4*i] = s[0] * (ws[i] & 0x000f) + bias;
-      w_local[4*i+1] = s[1] * (ws[i] & 0x00f0) + bias;
-      w_local[4*i+2] = s[2] * (ws[i] & 0x0f00) + bias;
-      w_local[4*i+3] = s[3] * (ws[i] & 0xf000) + bias;
+      w_local[4 * i] = s[0] * (ws[i] & 0x000f) + bias;
+      w_local[4 * i + 1] = s[1] * (ws[i] & 0x00f0) + bias;
+      w_local[4 * i + 2] = s[2] * (ws[i] & 0x0f00) + bias;
+      w_local[4 * i + 3] = s[3] * (ws[i] & 0xf000) + bias;
    }
  }

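dequantize's s[k] table folds the bit shifts into the scale exactly as load_vector folds them into x: s[k] * (w & mask_k) + bias equals scale * ((w >> 2k) & 3) + bias for bits == 2, so the unpacking loop needs masks only. A standalone C++ check (illustrative only):

#include <cassert>

// s[k] = scale / 4^k cancels the implicit factor 4^k left in the masked,
// unshifted field, recovering scale * q + bias without any shifts.
int main() {
  float scale = 0.125f, bias = -0.5f;
  float s[4] = {scale, scale / 4.0f, scale / 16.0f, scale / 64.0f};
  int mask[4] = {0x03, 0x0c, 0x30, 0xc0};
  for (int w = 0; w < 256; w++) {
    for (int k = 0; k < 4; k++) {
      float lhs = s[k] * (w & mask[k]) + bias;
      float rhs = scale * ((w >> (2 * k)) & 0x03) + bias;
      assert(lhs == rhs);  // exact: power-of-two scales
    }
  }
  return 0;
}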
@@ -243,13 +275,20 @@ template <
     short group_size,
     short bits>
 struct QuantizedBlockLoader {
-  static_assert(BCOLS <= group_size, "The group size should be larger than the columns");
-  static_assert(group_size % BCOLS == 0, "The group size should be divisible by the columns");
-  static_assert(bits == 2 || bits == 4 || bits == 8, "Template undefined for bits not in {2, 4, 8}");
+  static_assert(
+      BCOLS <= group_size,
+      "The group size should be larger than the columns");
+  static_assert(
+      group_size % BCOLS == 0,
+      "The group size should be divisible by the columns");
+  static_assert(
+      bits == 2 || bits == 4 || bits == 8,
+      "Template undefined for bits not in {2, 4, 8}");

  MLX_MTL_CONST short pack_factor = 32 / bits;
  MLX_MTL_CONST short BCOLS_PACKED = BCOLS / pack_factor;
-  MLX_MTL_CONST short n_reads = (BCOLS_PACKED * BROWS < tgp_size) ? 1 : (BCOLS_PACKED * BROWS) / tgp_size;
+  MLX_MTL_CONST short n_reads =
+      (BCOLS_PACKED * BROWS < tgp_size) ? 1 : (BCOLS_PACKED * BROWS) / tgp_size;
  MLX_MTL_CONST short group_steps = group_size / BCOLS;

  const int src_ld;
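A note on the constants being reflowed here: pack_factor = 32 / bits is the number of quantized values carried by each 32-bit word (16, 8, or 4 for 2-, 4-, and 8-bit weights), so BCOLS_PACKED = BCOLS / pack_factor is the tile width measured in words, and the static_asserts simply pin BCOLS inside a single scale/bias group. For example (worked arithmetic, not from the diff): bits = 4 and BCOLS = 32 give pack_factor = 8 and BCOLS_PACKED = 4 words per tile row.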
@@ -275,7 +314,8 @@ struct QuantizedBlockLoader {
       ushort simd_group_id [[simdgroup_index_in_threadgroup]],
       ushort simd_lane_id [[thread_index_in_simdgroup]])
       : src_ld(src_ld_),
-        tile_stride(reduction_dim ? BCOLS_PACKED : BROWS * src_ld / pack_factor),
+        tile_stride(
+            reduction_dim ? BCOLS_PACKED : BROWS * src_ld / pack_factor),
        group_step_cnt(0),
        group_stride(BROWS * src_ld / group_size),
        thread_idx(simd_group_id * 32 + simd_lane_id),
@@ -293,8 +333,9 @@ struct QuantizedBlockLoader {

    T scale = *scales;
    T bias = *biases;
-    for (int i=0; i<n_reads; i++) {
-      dequantize<T, pack_factor, bits>((device uint8_t*)(src + i), scale, bias, dst + i * pack_factor);
+    for (int i = 0; i < n_reads; i++) {
+      dequantize<T, pack_factor, bits>(
+          (device uint8_t*)(src + i), scale, bias, dst + i * pack_factor);
    }
  }

@@ -304,14 +345,14 @@ struct QuantizedBlockLoader {
    }

    if (reduction_dim == 1 && bi >= src_tile_dim.y) {
-      for (int i=0; i<n_reads*pack_factor; i++) {
+      for (int i = 0; i < n_reads * pack_factor; i++) {
        dst[i] = T(0);
      }
      return;
    }

    if (reduction_dim == 0 && bi >= src_tile_dim.x) {
-      for (int i=0; i<n_reads*pack_factor; i++) {
+      for (int i = 0; i < n_reads * pack_factor; i++) {
        dst[i] = T(0);
      }
      return;
@@ -319,8 +360,9 @@ struct QuantizedBlockLoader {

    T scale = *scales;
    T bias = *biases;
-    for (int i=0; i<n_reads; i++) {
-      dequantize<T, pack_factor, bits>((device uint8_t*)(src + i), scale, bias, dst + i * pack_factor);
+    for (int i = 0; i < n_reads; i++) {
+      dequantize<T, pack_factor, bits>(
+          (device uint8_t*)(src + i), scale, bias, dst + i * pack_factor);
    }
  }

@@ -357,7 +399,6 @@ template <typename T, int group_size, int bits, int packs_per_thread>
     uint3 tid [[threadgroup_position_in_grid]],
     uint simd_gid [[simdgroup_index_in_threadgroup]],
     uint simd_lid [[thread_index_in_simdgroup]]) {
-
  constexpr int num_simdgroups = 2;
  constexpr int results_per_simdgroup = 4;
  constexpr int pack_factor = 32 / bits;
@@ -373,7 +414,8 @@ template <typename T, int group_size, int bits, int packs_per_thread>
  // Adjust positions
  const int in_vec_size_w = in_vec_size / pack_factor;
  const int in_vec_size_g = in_vec_size / group_size;
-  const int out_row = tid.y * (num_simdgroups * results_per_simdgroup) + simd_gid * results_per_simdgroup;
+  const int out_row = tid.y * (num_simdgroups * results_per_simdgroup) +
+      simd_gid * results_per_simdgroup;
  w += out_row * in_vec_size_w + simd_lid * packs_per_thread;
  scales += out_row * in_vec_size_g + simd_lid / scale_step_per_thread;
  biases += out_row * in_vec_size_g + simd_lid / scale_step_per_thread;
@@ -384,7 +426,8 @@ template <typename T, int group_size, int bits, int packs_per_thread>
    U sum = load_vector<T, U, values_per_thread, bits>(x, x_thread);

    for (int row = 0; row < results_per_simdgroup; row++) {
-      const device uint8_t* wl = (const device uint8_t *)(w + row * in_vec_size_w);
+      const device uint8_t* wl =
+          (const device uint8_t*)(w + row * in_vec_size_w);
      const device T* sl = scales + row * in_vec_size_g;
      const device T* bl = biases + row * in_vec_size_g;

@@ -407,7 +450,6 @@ template <typename T, int group_size, int bits, int packs_per_thread>
    }
  }

-
 template <typename T, const int group_size, const int bits>
 [[kernel]] void qmv(
     const device uint32_t* w [[buffer(0)]],
@@ -420,7 +462,6 @@ template <typename T, const int group_size, const int bits>
     uint3 tid [[threadgroup_position_in_grid]],
     uint simd_gid [[simdgroup_index_in_threadgroup]],
     uint simd_lid [[thread_index_in_simdgroup]]) {
-
  constexpr int num_simdgroups = 2;
  constexpr int results_per_simdgroup = 4;
  constexpr int packs_per_thread = 1;
@@ -437,7 +478,8 @@ template <typename T, const int group_size, const int bits>
  // Adjust positions
  const int in_vec_size_w = in_vec_size / pack_factor;
  const int in_vec_size_g = in_vec_size / group_size;
-  const int out_row = tid.y * (num_simdgroups * results_per_simdgroup) + simd_gid * results_per_simdgroup;
+  const int out_row = tid.y * (num_simdgroups * results_per_simdgroup) +
+      simd_gid * results_per_simdgroup;
  const int used_out_row = min(out_vec_size - results_per_simdgroup, out_row);

  if (out_row >= out_vec_size) {
@@ -454,17 +496,19 @@ template <typename T, const int group_size, const int bits>
     y += tid.z * out_vec_size + out_row;

     int k = 0;
-    for (; k < in_vec_size-block_size; k += block_size) {
+    for (; k < in_vec_size - block_size; k += block_size) {
      U sum = load_vector<T, U, values_per_thread, bits>(x, x_thread);

      for (int row = 0; out_row + row < out_vec_size; row++) {
-        const device uint8_t* wl = (const device uint8_t *)(w + row * in_vec_size_w);
+        const device uint8_t* wl =
+            (const device uint8_t*)(w + row * in_vec_size_w);
        const device T* sl = scales + row * in_vec_size_g;
        const device T* bl = biases + row * in_vec_size_g;

        U s = sl[0];
        U b = bl[0];
-        result[row] += qdot<U, values_per_thread, bits>(wl, x_thread, s, b, sum);
+        result[row] +=
+            qdot<U, values_per_thread, bits>(wl, x_thread, s, b, sum);
      }

      w += block_size / pack_factor;
@@ -472,11 +516,16 @@ template <typename T, const int group_size, const int bits>
      biases += block_size / group_size;
      x += block_size;
    }
-    const int remaining = clamp(static_cast<int>(in_vec_size - k - simd_lid * values_per_thread), 0, values_per_thread);
-    U sum = load_vector_safe<T, U, values_per_thread, bits>(x, x_thread, remaining);
+    const int remaining = clamp(
+        static_cast<int>(in_vec_size - k - simd_lid * values_per_thread),
+        0,
+        values_per_thread);
+    U sum =
+        load_vector_safe<T, U, values_per_thread, bits>(x, x_thread, remaining);

    for (int row = 0; out_row + row < out_vec_size; row++) {
-      const device uint8_t* wl = (const device uint8_t *)(w + row * in_vec_size_w);
+      const device uint8_t* wl =
+          (const device uint8_t*)(w + row * in_vec_size_w);
      const device T* sl = scales + row * in_vec_size_g;
      const device T* bl = biases + row * in_vec_size_g;

@@ -502,17 +551,19 @@ template <typename T, const int group_size, const int bits>
    y += tid.z * out_vec_size + used_out_row;

    int k = 0;
-    for (; k < in_vec_size-block_size; k += block_size) {
+    for (; k < in_vec_size - block_size; k += block_size) {
      U sum = load_vector<T, U, values_per_thread, bits>(x, x_thread);

      for (int row = 0; row < results_per_simdgroup; row++) {
-        const device uint8_t* wl = (const device uint8_t *)(w + row * in_vec_size_w);
+        const device uint8_t* wl =
+            (const device uint8_t*)(w + row * in_vec_size_w);
        const device T* sl = scales + row * in_vec_size_g;
        const device T* bl = biases + row * in_vec_size_g;

        U s = sl[0];
        U b = bl[0];
-        result[row] += qdot<U, values_per_thread, bits>(wl, x_thread, s, b, sum);
+        result[row] +=
+            qdot<U, values_per_thread, bits>(wl, x_thread, s, b, sum);
      }

      w += block_size / pack_factor;
@@ -520,17 +571,23 @@ template <typename T, const int group_size, const int bits>
      biases += block_size / group_size;
      x += block_size;
    }
-    const int remaining = clamp(static_cast<int>(in_vec_size - k - simd_lid * values_per_thread), 0, values_per_thread);
-    U sum = load_vector_safe<T, U, values_per_thread, bits>(x, x_thread, remaining);
+    const int remaining = clamp(
+        static_cast<int>(in_vec_size - k - simd_lid * values_per_thread),
+        0,
+        values_per_thread);
+    U sum =
+        load_vector_safe<T, U, values_per_thread, bits>(x, x_thread, remaining);

    for (int row = 0; row < results_per_simdgroup; row++) {
-      const device uint8_t* wl = (const device uint8_t *)(w + row * in_vec_size_w);
+      const device uint8_t* wl =
+          (const device uint8_t*)(w + row * in_vec_size_w);
      const device T* sl = scales + row * in_vec_size_g;
      const device T* bl = biases + row * in_vec_size_g;

      U s = sl[0];
      U b = bl[0];
-      result[row] += qdot_safe<U, values_per_thread, bits>(wl, x_thread, s, b, sum, remaining);
+      result[row] += qdot_safe<U, values_per_thread, bits>(
+          wl, x_thread, s, b, sum, remaining);
    }

    for (int row = 0; row < results_per_simdgroup; row++) {
@@ -542,7 +599,6 @@ template <typename T, const int group_size, const int bits>
    }
  }

-
 template <typename T, const int group_size, const int bits>
 [[kernel]] void qvm(
     const device T* x [[buffer(0)]],
@@ -555,7 +611,6 @@ template <typename T, const int group_size, const int bits>
     uint3 tid [[threadgroup_position_in_grid]],
     uint simd_gid [[simdgroup_index_in_threadgroup]],
     uint simd_lid [[thread_index_in_simdgroup]]) {
-
  constexpr int num_simdgroups = 8;
  constexpr int pack_factor = 32 / bits;
  constexpr int blocksize = SIMD_SIZE;
@@ -590,7 +645,8 @@ template <typename T, const int group_size, const int bits>
      bias = biases[(i + simd_lid) * out_vec_size_g];
      w_local = w[(i + simd_lid) * out_vec_size_w];

-      qouter<U, pack_factor, bits>((thread uint8_t *)&w_local, x_local, scale, bias, result);
+      qouter<U, pack_factor, bits>(
+          (thread uint8_t*)&w_local, x_local, scale, bias, result);
    }
    if (static_cast<int>(i + simd_lid) < in_vec_size) {
      x_local = x[i + simd_lid];
@@ -603,25 +659,32 @@ template <typename T, const int group_size, const int bits>
      bias = 0;
      w_local = 0;
    }
-    qouter<U, pack_factor, bits>((thread uint8_t *)&w_local, x_local, scale, bias, result);
+    qouter<U, pack_factor, bits>(
+        (thread uint8_t*)&w_local, x_local, scale, bias, result);

  // Accumulate in the simdgroup
 #pragma clang loop unroll(full)
-  for (int k=0; k<pack_factor; k++) {
+  for (int k = 0; k < pack_factor; k++) {
    result[k] = simd_sum(result[k]);
  }

  // Store the result
  if (simd_lid == 0) {
 #pragma clang loop unroll(full)
-    for (int k=0; k<pack_factor; k++) {
+    for (int k = 0; k < pack_factor; k++) {
      y[k] = static_cast<T>(result[k]);
    }
  }
 }

-template <typename T, const int BM, const int BK, const int BN, const int group_size, const int bits, const bool aligned_N>
+template <
+    typename T,
+    const int BM,
+    const int BK,
+    const int BN,
+    const int group_size,
+    const int bits,
+    const bool aligned_N>
 [[kernel]] void qmm_t(
     const device T* x [[buffer(0)]],
     const device uint32_t* w [[buffer(1)]],
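In qvm above, each simdgroup lane owns one element of x plus the matching packed word of W, expands it with qouter into pack_factor partial outputs, and simd_sum then folds the 32 lanes. A scalar C++ stand-in for that reduction (illustrative only; the names and the fixed 4-bit pack factor are assumptions for the sketch):

// Lane i contributes x[i] * w_row[i][k] for each output k; the inner loop
// over lanes is what simd_sum performs in parallel on the GPU.
const int kLanes = 32;
const int kPackFactor = 8;  // illustrative 4-bit case: 32 / 4
void qvm_block_ref(const float x[kLanes],
                   const float w[kLanes][kPackFactor],
                   float y[kPackFactor]) {
  for (int k = 0; k < kPackFactor; k++) {
    float acc = 0.f;
    for (int lane = 0; lane < kLanes; lane++) {
      acc += x[lane] * w[lane][k];
    }
    y[k] = acc;
  }
}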
@@ -635,7 +698,6 @@ template <typename T, const int BM, const int BK, const int BN, const int group_size, const int bits, const bool aligned_N>
     uint lid [[thread_index_in_threadgroup]],
     uint simd_gid [[simdgroup_index_in_threadgroup]],
     uint simd_lid [[thread_index_in_simdgroup]]) {
-
  static_assert(BK >= SIMD_SIZE, "BK should be larger than SIMD_SIZE");
  static_assert(BK % SIMD_SIZE == 0, "BK should be divisible by SIMD_SIZE");

@@ -647,9 +709,19 @@ template <typename T, const int BM, const int BK, const int BN, const int group_size, const int bits, const bool aligned_N>
  constexpr int BK_padded = (BK + 16 / sizeof(T));

  // Instantiate the appropriate BlockMMA and Loader
-  using mma_t = mlx::steel::BlockMMA<T, T, BM, BN, BK, WM, WN, false, true, BK_padded, BK_padded>;
-  using loader_x_t = mlx::steel::BlockLoader<T, BM, BK, BK_padded, 1, WM * WN * SIMD_SIZE>;
-  using loader_w_t = QuantizedBlockLoader<T, BN, BK, BK_padded, 1, WM * WN * SIMD_SIZE, group_size, bits>;
+  using mma_t = mlx::steel::
+      BlockMMA<T, T, BM, BN, BK, WM, WN, false, true, BK_padded, BK_padded>;
+  using loader_x_t =
+      mlx::steel::BlockLoader<T, BM, BK, BK_padded, 1, WM * WN * SIMD_SIZE>;
+  using loader_w_t = QuantizedBlockLoader<
+      T,
+      BN,
+      BK,
+      BK_padded,
+      1,
+      WM * WN * SIMD_SIZE,
+      group_size,
+      bits>;

  threadgroup T Xs[BM * BK_padded];
  threadgroup T Ws[BN * BK_padded];
@@ -675,7 +747,7 @@ template <typename T, const int BM, const int BK, const int BN, const int group_size, const int bits, const bool aligned_N>

  if (num_els < BM) {
    if (!aligned_N && num_outs < BN) {
-      for (int k=0; k<K; k += BK) {
+      for (int k = 0; k < K; k += BK) {
        threadgroup_barrier(mem_flags::mem_threadgroup);
        loader_x.load_safe(short2(BK, num_els));
        loader_w.load_safe(short2(BK, num_outs));
@@ -685,7 +757,7 @@ template <typename T, const int BM, const int BK, const int BN, const int group_size, const int bits, const bool aligned_N>
        loader_w.next();
      }
    } else {
-      for (int k=0; k<K; k += BK) {
+      for (int k = 0; k < K; k += BK) {
        threadgroup_barrier(mem_flags::mem_threadgroup);
        loader_x.load_safe(short2(BK, num_els));
        loader_w.load_unsafe();
@@ -697,7 +769,7 @@ template <typename T, const int BM, const int BK, const int BN, const int group_size, const int bits, const bool aligned_N>
    }
  } else {
    if (!aligned_N && num_outs < BN) {
-      for (int k=0; k<K; k += BK) {
+      for (int k = 0; k < K; k += BK) {
        threadgroup_barrier(mem_flags::mem_threadgroup);
        loader_x.load_unsafe();
        loader_w.load_safe(short2(BK, num_outs));
@@ -707,7 +779,7 @@ template <typename T, const int BM, const int BK, const int BN, const int group_size, const int bits, const bool aligned_N>
        loader_w.next();
      }
    } else {
-      for (int k=0; k<K; k += BK) {
+      for (int k = 0; k < K; k += BK) {
        threadgroup_barrier(mem_flags::mem_threadgroup);
        loader_x.load_unsafe();
        loader_w.load_unsafe();
@@ -728,8 +800,13 @@ template <typename T, const int BM, const int BK, const int BN, const int group_size, const int bits, const bool aligned_N>
  }
 }

-template <typename T, const int BM, const int BK, const int BN, const int group_size, const int bits>
+template <
+    typename T,
+    const int BM,
+    const int BK,
+    const int BN,
+    const int group_size,
+    const int bits>
 [[kernel]] void qmm_n(
     const device T* x [[buffer(0)]],
     const device uint32_t* w [[buffer(1)]],
@@ -743,7 +820,6 @@ template <typename T, const int BM, const int BK, const int BN, const int group_size, const int bits>
     uint lid [[thread_index_in_threadgroup]],
     uint simd_gid [[simdgroup_index_in_threadgroup]],
     uint simd_lid [[thread_index_in_simdgroup]]) {
-
  static_assert(BK >= SIMD_SIZE, "BK should be larger than SIMD_SIZE");
  static_assert(BK % SIMD_SIZE == 0, "BK should be divisible by SIMD_SIZE");

@@ -756,9 +832,19 @@ template <typename T, const int BM, const int BK, const int BN, const int group_size, const int bits>
  constexpr int BN_padded = (BN + 16 / sizeof(T));

  // Instantiate the appropriate BlockMMA and Loader
-  using mma_t = mlx::steel::BlockMMA<T, T, BM, BN, BK, WM, WN, false, false, BK_padded, BN_padded>;
-  using loader_x_t = mlx::steel::BlockLoader<T, BM, BK, BK_padded, 1, WM * WN * SIMD_SIZE, 1, 4>;
-  using loader_w_t = QuantizedBlockLoader<T, BK, BN, BN_padded, 0, WM * WN * SIMD_SIZE, group_size, bits>;
+  using mma_t = mlx::steel::
+      BlockMMA<T, T, BM, BN, BK, WM, WN, false, false, BK_padded, BN_padded>;
+  using loader_x_t = mlx::steel::
+      BlockLoader<T, BM, BK, BK_padded, 1, WM * WN * SIMD_SIZE, 1, 4>;
+  using loader_w_t = QuantizedBlockLoader<
+      T,
+      BK,
+      BN,
+      BN_padded,
+      0,
+      WM * WN * SIMD_SIZE,
+      group_size,
+      bits>;

  threadgroup T Xs[BM * BK_padded];
  threadgroup T Ws[BK * BN_padded];
@@ -780,8 +866,8 @@ template <typename T, const int BM, const int BK, const int BN, const int group_size, const int bits>

  if (num_els < BM) {
    if ((K % BK) != 0) {
-      const int k_blocks = K/BK;
-      for (int k=0; k<k_blocks; k++) {
+      const int k_blocks = K / BK;
+      for (int k = 0; k < k_blocks; k++) {
        threadgroup_barrier(mem_flags::mem_threadgroup);
        loader_x.load_safe(short2(BK, num_els));
        loader_w.load_unsafe();
@@ -797,7 +883,7 @@ template <typename T, const int BM, const int BK, const int BN, const int group_size, const int bits>
      threadgroup_barrier(mem_flags::mem_threadgroup);
      mma_op.mma(Xs, Ws);
    } else {
-      for (int k=0; k<K; k += BK) {
+      for (int k = 0; k < K; k += BK) {
        threadgroup_barrier(mem_flags::mem_threadgroup);
        loader_x.load_safe(short2(BK, num_els));
        loader_w.load_unsafe();
@@ -809,8 +895,8 @@ template <typename T, const int BM, const int BK, const int BN, const int group_size, const int bits>
      }
    } else {
      if ((K % BK) != 0) {
-        const int k_blocks = K/BK;
-        for (int k=0; k<k_blocks; k++) {
+        const int k_blocks = K / BK;
+        for (int k = 0; k < k_blocks; k++) {
          threadgroup_barrier(mem_flags::mem_threadgroup);
          loader_x.load_unsafe();
          loader_w.load_unsafe();
|
loader_w.load_unsafe();
|
||||||
@ -826,7 +912,7 @@ template <typename T, const int BM, const int BK, const int BN, const int group_
|
|||||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||||
mma_op.mma(Xs, Ws);
|
mma_op.mma(Xs, Ws);
|
||||||
} else {
|
} else {
|
||||||
for (int k=0; k<K; k += BK) {
|
for (int k = 0; k < K; k += BK) {
|
||||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||||
loader_x.load_unsafe();
|
loader_x.load_unsafe();
|
||||||
loader_w.load_unsafe();
|
loader_w.load_unsafe();
|
||||||
@ -847,10 +933,10 @@ template <typename T, const int BM, const int BK, const int BN, const int group_
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
#define instantiate_qmv_fast(name, itype, group_size, bits, packs_per_thread) \
|
#define instantiate_qmv_fast(name, itype, group_size, bits, packs_per_thread) \
|
||||||
template [[host_name("qmv_" #name "_gs_" #group_size "_b_" #bits "_fast")]] \
|
template [[host_name("qmv_" #name "_gs_" #group_size "_b_" #bits \
|
||||||
[[kernel]] void qmv_fast<itype, group_size, bits, packs_per_thread>( \
|
"_fast")]] [[kernel]] void \
|
||||||
|
qmv_fast<itype, group_size, bits, packs_per_thread>( \
|
||||||
const device uint32_t* w [[buffer(0)]], \
|
const device uint32_t* w [[buffer(0)]], \
|
||||||
const device itype* scales [[buffer(1)]], \
|
const device itype* scales [[buffer(1)]], \
|
||||||
const device itype* biases [[buffer(2)]], \
|
const device itype* biases [[buffer(2)]], \
|
||||||
@ -862,11 +948,13 @@ template <typename T, const int BM, const int BK, const int BN, const int group_
|
|||||||
uint simd_gid [[simdgroup_index_in_threadgroup]], \
|
uint simd_gid [[simdgroup_index_in_threadgroup]], \
|
||||||
uint simd_lid [[thread_index_in_simdgroup]]);
|
uint simd_lid [[thread_index_in_simdgroup]]);
|
||||||
|
|
||||||
|
// clang-format off
|
||||||
#define instantiate_qmv_fast_types(group_size, bits, packs_per_thread) \
|
#define instantiate_qmv_fast_types(group_size, bits, packs_per_thread) \
|
||||||
instantiate_qmv_fast(float32, float, group_size, bits, packs_per_thread) \
|
instantiate_qmv_fast(float32, float, group_size, bits, packs_per_thread) \
|
||||||
instantiate_qmv_fast(float16, half, group_size, bits, packs_per_thread) \
|
instantiate_qmv_fast(float16, half, group_size, bits, packs_per_thread) \
|
||||||
instantiate_qmv_fast(bfloat16, bfloat16_t, group_size, bits, packs_per_thread)
|
instantiate_qmv_fast(bfloat16, bfloat16_t, group_size, bits, packs_per_thread) // clang-format on
|
||||||
|
|
||||||
|
// clang-format off
|
||||||
instantiate_qmv_fast_types(128, 2, 1)
|
instantiate_qmv_fast_types(128, 2, 1)
|
||||||
instantiate_qmv_fast_types(128, 4, 2)
|
instantiate_qmv_fast_types(128, 4, 2)
|
||||||
instantiate_qmv_fast_types(128, 8, 2)
|
instantiate_qmv_fast_types(128, 8, 2)
|
||||||
@ -875,11 +963,12 @@ instantiate_qmv_fast_types( 64, 4, 2)
|
|||||||
instantiate_qmv_fast_types( 64, 8, 2)
|
instantiate_qmv_fast_types( 64, 8, 2)
|
||||||
instantiate_qmv_fast_types( 32, 2, 1)
|
instantiate_qmv_fast_types( 32, 2, 1)
|
||||||
instantiate_qmv_fast_types( 32, 4, 2)
|
instantiate_qmv_fast_types( 32, 4, 2)
|
||||||
instantiate_qmv_fast_types( 32, 8, 2)
|
instantiate_qmv_fast_types( 32, 8, 2) // clang-format on
|
||||||
|
|
||||||
#define instantiate_qmv(name, itype, group_size, bits) \
|
#define instantiate_qmv(name, itype, group_size, bits) \
|
||||||
template [[host_name("qmv_" #name "_gs_" #group_size "_b_" #bits)]] \
|
template [[host_name("qmv_" #name "_gs_" #group_size \
|
||||||
[[kernel]] void qmv<itype, group_size, bits>( \
|
"_b_" #bits)]] [[kernel]] void \
|
||||||
|
qmv<itype, group_size, bits>( \
|
||||||
const device uint32_t* w [[buffer(0)]], \
|
const device uint32_t* w [[buffer(0)]], \
|
||||||
const device itype* scales [[buffer(1)]], \
|
const device itype* scales [[buffer(1)]], \
|
||||||
const device itype* biases [[buffer(2)]], \
|
const device itype* biases [[buffer(2)]], \
|
||||||
@ -891,11 +980,13 @@ instantiate_qmv_fast_types( 32, 8, 2)
|
|||||||
uint simd_gid [[simdgroup_index_in_threadgroup]], \
|
uint simd_gid [[simdgroup_index_in_threadgroup]], \
|
||||||
uint simd_lid [[thread_index_in_simdgroup]]);
|
uint simd_lid [[thread_index_in_simdgroup]]);
|
||||||
|
|
||||||
|
// clang-format off
|
||||||
#define instantiate_qmv_types(group_size, bits) \
|
#define instantiate_qmv_types(group_size, bits) \
|
||||||
instantiate_qmv(float32, float, group_size, bits) \
|
instantiate_qmv(float32, float, group_size, bits) \
|
||||||
instantiate_qmv(float16, half, group_size, bits) \
|
instantiate_qmv(float16, half, group_size, bits) \
|
||||||
instantiate_qmv(bfloat16, bfloat16_t, group_size, bits)
|
instantiate_qmv(bfloat16, bfloat16_t, group_size, bits) // clang-format on
|
||||||
|
|
||||||
|
// clang-format off
|
||||||
instantiate_qmv_types(128, 2)
|
instantiate_qmv_types(128, 2)
|
||||||
instantiate_qmv_types(128, 4)
|
instantiate_qmv_types(128, 4)
|
||||||
instantiate_qmv_types(128, 8)
|
instantiate_qmv_types(128, 8)
|
||||||
@ -904,11 +995,12 @@ instantiate_qmv_types( 64, 4)
|
|||||||
instantiate_qmv_types( 64, 8)
|
instantiate_qmv_types( 64, 8)
|
||||||
instantiate_qmv_types( 32, 2)
|
instantiate_qmv_types( 32, 2)
|
||||||
instantiate_qmv_types( 32, 4)
|
instantiate_qmv_types( 32, 4)
|
||||||
instantiate_qmv_types( 32, 8)
|
instantiate_qmv_types( 32, 8) // clang-format on
|
||||||
|
|
||||||
#define instantiate_qvm(name, itype, group_size, bits) \
|
#define instantiate_qvm(name, itype, group_size, bits) \
|
||||||
template [[host_name("qvm_" #name "_gs_" #group_size "_b_" #bits)]] \
|
template [[host_name("qvm_" #name "_gs_" #group_size \
|
||||||
[[kernel]] void qvm<itype, group_size, bits>( \
|
"_b_" #bits)]] [[kernel]] void \
|
||||||
|
qvm<itype, group_size, bits>( \
|
||||||
const device itype* x [[buffer(0)]], \
|
const device itype* x [[buffer(0)]], \
|
||||||
const device uint32_t* w [[buffer(1)]], \
|
const device uint32_t* w [[buffer(1)]], \
|
||||||
const device itype* scales [[buffer(2)]], \
|
const device itype* scales [[buffer(2)]], \
|
||||||
@ -920,11 +1012,13 @@ instantiate_qmv_types( 32, 8)
|
|||||||
uint simd_gid [[simdgroup_index_in_threadgroup]], \
|
uint simd_gid [[simdgroup_index_in_threadgroup]], \
|
||||||
uint simd_lid [[thread_index_in_simdgroup]]);
|
uint simd_lid [[thread_index_in_simdgroup]]);
|
||||||
|
|
||||||
|
// clang-format off
|
||||||
#define instantiate_qvm_types(group_size, bits) \
|
#define instantiate_qvm_types(group_size, bits) \
|
||||||
instantiate_qvm(float32, float, group_size, bits) \
|
instantiate_qvm(float32, float, group_size, bits) \
|
||||||
instantiate_qvm(float16, half, group_size, bits) \
|
instantiate_qvm(float16, half, group_size, bits) \
|
||||||
instantiate_qvm(bfloat16, bfloat16_t, group_size, bits)
|
instantiate_qvm(bfloat16, bfloat16_t, group_size, bits) // clang-format on
|
||||||
|
|
||||||
|
// clang-format off
|
||||||
instantiate_qvm_types(128, 2)
|
instantiate_qvm_types(128, 2)
|
||||||
instantiate_qvm_types(128, 4)
|
instantiate_qvm_types(128, 4)
|
||||||
instantiate_qvm_types(128, 8)
|
instantiate_qvm_types(128, 8)
|
||||||
@ -933,11 +1027,12 @@ instantiate_qvm_types( 64, 4)
|
|||||||
instantiate_qvm_types( 64, 8)
|
instantiate_qvm_types( 64, 8)
|
||||||
instantiate_qvm_types( 32, 2)
|
instantiate_qvm_types( 32, 2)
|
||||||
instantiate_qvm_types( 32, 4)
|
instantiate_qvm_types( 32, 4)
|
||||||
instantiate_qvm_types( 32, 8)
|
instantiate_qvm_types( 32, 8) // clang-format on
|
||||||
|
|
||||||
#define instantiate_qmm_t(name, itype, group_size, bits, aligned_N) \
|
#define instantiate_qmm_t(name, itype, group_size, bits, aligned_N) \
|
||||||
template [[host_name("qmm_t_" #name "_gs_" #group_size "_b_" #bits "_alN_" #aligned_N)]] \
|
template [[host_name("qmm_t_" #name "_gs_" #group_size "_b_" #bits \
|
||||||
[[kernel]] void qmm_t<itype, 32, 32, 32, group_size, bits, aligned_N>( \
|
"_alN_" #aligned_N)]] [[kernel]] void \
|
||||||
|
qmm_t<itype, 32, 32, 32, group_size, bits, aligned_N>( \
|
||||||
const device itype* x [[buffer(0)]], \
|
const device itype* x [[buffer(0)]], \
|
||||||
const device uint32_t* w [[buffer(1)]], \
|
const device uint32_t* w [[buffer(1)]], \
|
||||||
const device itype* scales [[buffer(2)]], \
|
const device itype* scales [[buffer(2)]], \
|
||||||
@ -951,14 +1046,16 @@ instantiate_qvm_types( 32, 8)
|
|||||||
uint simd_gid [[simdgroup_index_in_threadgroup]], \
|
uint simd_gid [[simdgroup_index_in_threadgroup]], \
|
||||||
uint simd_lid [[thread_index_in_simdgroup]]);
|
uint simd_lid [[thread_index_in_simdgroup]]);
|
||||||
|
|
||||||
|
// clang-format off
|
||||||
#define instantiate_qmm_t_types(group_size, bits) \
|
#define instantiate_qmm_t_types(group_size, bits) \
|
||||||
instantiate_qmm_t(float32, float, group_size, bits, false) \
|
instantiate_qmm_t(float32, float, group_size, bits, false) \
|
||||||
instantiate_qmm_t(float16, half, group_size, bits, false) \
|
instantiate_qmm_t(float16, half, group_size, bits, false) \
|
||||||
instantiate_qmm_t(bfloat16, bfloat16_t, group_size, bits, false) \
|
instantiate_qmm_t(bfloat16, bfloat16_t, group_size, bits, false) \
|
||||||
instantiate_qmm_t(float32, float, group_size, bits, true) \
|
instantiate_qmm_t(float32, float, group_size, bits, true) \
|
||||||
instantiate_qmm_t(float16, half, group_size, bits, true) \
|
instantiate_qmm_t(float16, half, group_size, bits, true) \
|
||||||
instantiate_qmm_t(bfloat16, bfloat16_t, group_size, bits, true)
|
instantiate_qmm_t(bfloat16, bfloat16_t, group_size, bits, true) // clang-format on
|
||||||
|
|
||||||
|
// clang-format off
|
||||||
instantiate_qmm_t_types(128, 2)
|
instantiate_qmm_t_types(128, 2)
|
||||||
instantiate_qmm_t_types(128, 4)
|
instantiate_qmm_t_types(128, 4)
|
||||||
instantiate_qmm_t_types(128, 8)
|
instantiate_qmm_t_types(128, 8)
|
||||||
@ -967,11 +1064,12 @@ instantiate_qmm_t_types( 64, 4)
|
|||||||
instantiate_qmm_t_types( 64, 8)
|
instantiate_qmm_t_types( 64, 8)
|
||||||
instantiate_qmm_t_types( 32, 2)
|
instantiate_qmm_t_types( 32, 2)
|
||||||
instantiate_qmm_t_types( 32, 4)
|
instantiate_qmm_t_types( 32, 4)
|
||||||
instantiate_qmm_t_types( 32, 8)
|
instantiate_qmm_t_types( 32, 8) // clang-format on
|
||||||
|
|
||||||
#define instantiate_qmm_n(name, itype, group_size, bits) \
|
#define instantiate_qmm_n(name, itype, group_size, bits) \
|
||||||
template [[host_name("qmm_n_" #name "_gs_" #group_size "_b_" #bits)]] \
|
template [[host_name("qmm_n_" #name "_gs_" #group_size \
|
||||||
[[kernel]] void qmm_n<itype, 32, 32, 32, group_size, bits>( \
|
"_b_" #bits)]] [[kernel]] void \
|
||||||
|
qmm_n<itype, 32, 32, 32, group_size, bits>( \
|
||||||
const device itype* x [[buffer(0)]], \
|
const device itype* x [[buffer(0)]], \
|
||||||
const device uint32_t* w [[buffer(1)]], \
|
const device uint32_t* w [[buffer(1)]], \
|
||||||
const device itype* scales [[buffer(2)]], \
|
const device itype* scales [[buffer(2)]], \
|
||||||
@ -985,11 +1083,13 @@ instantiate_qmm_t_types( 32, 8)
|
|||||||
uint simd_gid [[simdgroup_index_in_threadgroup]], \
|
uint simd_gid [[simdgroup_index_in_threadgroup]], \
|
||||||
uint simd_lid [[thread_index_in_simdgroup]]);
|
uint simd_lid [[thread_index_in_simdgroup]]);
|
||||||
|
|
||||||
|
// clang-format off
|
||||||
#define instantiate_qmm_n_types(group_size, bits) \
|
#define instantiate_qmm_n_types(group_size, bits) \
|
||||||
instantiate_qmm_n(float32, float, group_size, bits) \
|
instantiate_qmm_n(float32, float, group_size, bits) \
|
||||||
instantiate_qmm_n(float16, half, group_size, bits) \
|
instantiate_qmm_n(float16, half, group_size, bits) \
|
||||||
instantiate_qmm_n(bfloat16, bfloat16_t, group_size, bits)
|
instantiate_qmm_n(bfloat16, bfloat16_t, group_size, bits) // clang-format on
|
||||||
|
|
||||||
|
// clang-format off
|
||||||
instantiate_qmm_n_types(128, 2)
|
instantiate_qmm_n_types(128, 2)
|
||||||
instantiate_qmm_n_types(128, 4)
|
instantiate_qmm_n_types(128, 4)
|
||||||
instantiate_qmm_n_types(128, 8)
|
instantiate_qmm_n_types(128, 8)
|
||||||
@ -998,4 +1098,4 @@ instantiate_qmm_n_types( 64, 4)
|
|||||||
instantiate_qmm_n_types( 64, 8)
|
instantiate_qmm_n_types( 64, 8)
|
||||||
instantiate_qmm_n_types( 32, 2)
|
instantiate_qmm_n_types( 32, 2)
|
||||||
instantiate_qmm_n_types( 32, 4)
|
instantiate_qmm_n_types( 32, 4)
|
||||||
instantiate_qmm_n_types( 32, 8)
|
instantiate_qmm_n_types( 32, 8) // clang-format on
|
||||||
|
@ -4,8 +4,7 @@

static constexpr constant uint32_t rotations[2][4] = {
{13, 15, 26, 6},
{17, 29, 16, 24}
{17, 29, 16, 24}};
};

union rbits {
uint2 val;
@ -13,7 +12,6 @@ union rbits {
};

rbits threefry2x32_hash(const thread uint2& key, uint2 count) {

uint4 ks = {key.x, key.y, key.x ^ key.y ^ 0x1BD11BDA};

rbits v;

@ -1,8 +1,8 @@
// Copyright © 2023-2024 Apple Inc.

#include "mlx/backend/metal/kernels/reduction/utils.h"
#include "mlx/backend/metal/kernels/reduction/ops.h"
#include "mlx/backend/metal/kernels/reduction/reduce_inst.h"
#include "mlx/backend/metal/kernels/reduction/utils.h"

using namespace metal;

@ -60,14 +60,13 @@ METAL_FUNC U per_thread_all_reduce(
// All reduce kernel
///////////////////////////////////////////////////////////////////////////////


// NB: This kernel assumes threads_per_threadgroup is at most
// 1024. This way with a simd_size of 32, we are guaranteed to
// complete the reduction in two steps of simd-level reductions.
template <typename T, typename U, typename Op, int N_READS=REDUCE_N_READS>
template <typename T, typename U, typename Op, int N_READS = REDUCE_N_READS>
[[kernel]] void all_reduce(
const device T *in [[buffer(0)]],
const device T* in [[buffer(0)]],
device mlx_atomic<U> *out [[buffer(1)]],
device mlx_atomic<U>* out [[buffer(1)]],
const device size_t& in_size [[buffer(2)]],
uint gid [[thread_position_in_grid]],
uint lid [[thread_position_in_threadgroup]],
@ -75,11 +74,11 @@ template <typename T, typename U, typename Op, int N_READS=REDUCE_N_READS>
uint simd_per_group [[simdgroups_per_threadgroup]],
uint simd_lane_id [[thread_index_in_simdgroup]],
uint simd_group_id [[simdgroup_index_in_threadgroup]]) {

Op op;
threadgroup U local_vals[simd_size];

U total_val = per_thread_all_reduce<T, U, Op, N_READS>(in, in_size, gid, grid_size);
U total_val =
per_thread_all_reduce<T, U, Op, N_READS>(in, in_size, gid, grid_size);

// Reduction within simd group
total_val = op.simd_reduce(total_val);
@ -98,10 +97,10 @@ template <typename T, typename U, typename Op, int N_READS=REDUCE_N_READS>
}
}

template <typename T, typename U, typename Op, int N_READS=REDUCE_N_READS>
template <typename T, typename U, typename Op, int N_READS = REDUCE_N_READS>
[[kernel]] void all_reduce_no_atomics(
const device T *in [[buffer(0)]],
const device T* in [[buffer(0)]],
device U *out [[buffer(1)]],
device U* out [[buffer(1)]],
const device size_t& in_size [[buffer(2)]],
uint gid [[thread_position_in_grid]],
uint lid [[thread_position_in_threadgroup]],
@ -110,14 +109,16 @@ template <typename T, typename U, typename Op, int N_READS=REDUCE_N_READS>
uint simd_lane_id [[thread_index_in_simdgroup]],
uint simd_group_id [[simdgroup_index_in_threadgroup]],
uint thread_group_id [[threadgroup_position_in_grid]]) {

Op op;
threadgroup U local_vals[simd_size];

U total_val = per_thread_all_reduce<T, U, Op, N_READS>(in, in_size, gid, grid_size);
U total_val =
per_thread_all_reduce<T, U, Op, N_READS>(in, in_size, gid, grid_size);

// Reduction within simd group (simd_add isn't supported for uint64/int64 types)
for (uint16_t lane_offset = simd_size/2; lane_offset > 0; lane_offset /= 2) {
// Reduction within simd group (simd_add isn't supported for uint64/int64
// types)
for (uint16_t lane_offset = simd_size / 2; lane_offset > 0;
lane_offset /= 2) {
total_val = op(total_val, simd_shuffle_down(total_val, lane_offset));
}
// Write simd group reduction results to local memory
@ -128,7 +129,8 @@ template <typename T, typename U, typename Op, int N_READS=REDUCE_N_READS>

// Reduction of simdgroup reduction results within threadgroup.
total_val = lid < simd_per_group ? local_vals[lid] : op.init;
for (uint16_t lane_offset = simd_size/2; lane_offset > 0; lane_offset /= 2) {
for (uint16_t lane_offset = simd_size / 2; lane_offset > 0;
lane_offset /= 2) {
total_val = op(total_val, simd_shuffle_down(total_val, lane_offset));
}

@ -139,10 +141,10 @@ template <typename T, typename U, typename Op, int N_READS=REDUCE_N_READS>
}

#define instantiate_all_reduce(name, itype, otype, op) \
template [[host_name("all_reduce_" #name)]] \
[[kernel]] void all_reduce<itype, otype, op>( \
template [[host_name("all_reduce_" #name)]] [[kernel]] void \
all_reduce<itype, otype, op>( \
const device itype *in [[buffer(0)]], \
const device itype* in [[buffer(0)]], \
device mlx_atomic<otype> *out [[buffer(1)]], \
device mlx_atomic<otype>* out [[buffer(1)]], \
const device size_t& in_size [[buffer(2)]], \
uint gid [[thread_position_in_grid]], \
uint lid [[thread_position_in_threadgroup]], \
@ -152,10 +154,10 @@ template <typename T, typename U, typename Op, int N_READS=REDUCE_N_READS>
uint simd_group_id [[simdgroup_index_in_threadgroup]]);

#define instantiate_all_reduce_no_atomics(name, itype, otype, op) \
template [[host_name("all_reduce_no_atomics_" #name)]] \
[[kernel]] void all_reduce_no_atomics<itype, otype, op>( \
template [[host_name("all_reduce_no_atomics_" #name)]] [[kernel]] void \
all_reduce_no_atomics<itype, otype, op>( \
const device itype *in [[buffer(0)]], \
const device itype* in [[buffer(0)]], \
device otype *out [[buffer(1)]], \
device otype* out [[buffer(1)]], \
const device size_t& in_size [[buffer(2)]], \
uint gid [[thread_position_in_grid]], \
uint lid [[thread_position_in_threadgroup]], \
@ -170,11 +172,12 @@ template <typename T, typename U, typename Op, int N_READS=REDUCE_N_READS>
///////////////////////////////////////////////////////////////////////////////

#define instantiate_same_all_reduce_helper(name, tname, type, op) \
instantiate_all_reduce(name ##tname, type, type, op<type>)
instantiate_all_reduce(name##tname, type, type, op<type>)

#define instantiate_same_all_reduce_na_helper(name, tname, type, op) \
instantiate_all_reduce_no_atomics(name ##tname, type, type, op<type>)
instantiate_all_reduce_no_atomics(name##tname, type, type, op<type>)

// clang-format off
instantiate_reduce_ops(instantiate_same_all_reduce_helper, instantiate_reduce_helper_types)
instantiate_reduce_ops(instantiate_same_all_reduce_na_helper, instantiate_reduce_helper_64b)

@ -182,4 +185,4 @@ instantiate_reduce_from_types(instantiate_all_reduce, and, bool, And)
instantiate_reduce_from_types(instantiate_all_reduce, or, bool, Or)

// special case bool with larger output type
instantiate_all_reduce(sumbool_, bool, uint32_t, Sum<uint32_t>)
instantiate_all_reduce(sumbool_, bool, uint32_t, Sum<uint32_t>) // clang-format on

@ -1,8 +1,8 @@
// Copyright © 2023-2024 Apple Inc.

#include "mlx/backend/metal/kernels/reduction/utils.h"
#include "mlx/backend/metal/kernels/reduction/ops.h"
#include "mlx/backend/metal/kernels/reduction/reduce_inst.h"
#include "mlx/backend/metal/kernels/reduction/utils.h"

using namespace metal;

@ -12,8 +12,8 @@ using namespace metal;

template <typename T, typename U, typename Op>
[[kernel]] void col_reduce_small(
const device T *in [[buffer(0)]],
const device T* in [[buffer(0)]],
device U *out [[buffer(1)]],
device U* out [[buffer(1)]],
const constant size_t& reduction_size [[buffer(2)]],
const constant size_t& reduction_stride [[buffer(3)]],
const constant size_t& out_size [[buffer(4)]],
@ -25,7 +25,6 @@ template <typename T, typename U, typename Op>
const constant size_t* non_col_strides [[buffer(10)]],
const constant int& non_col_ndim [[buffer(11)]],
uint tid [[thread_position_in_grid]]) {

// Appease the compiler
(void)out_size;

@ -40,10 +39,11 @@ template <typename T, typename U, typename Op>
strides + non_col_ndim,
ndim - non_col_ndim);

for(uint i = 0; i < non_col_reductions; i++) {
for (uint i = 0; i < non_col_reductions; i++) {
size_t in_idx = elem_to_loc(i, non_col_shapes, non_col_strides, non_col_ndim);
size_t in_idx =
elem_to_loc(i, non_col_shapes, non_col_strides, non_col_ndim);

for(uint j = 0; j < reduction_size; j++, in_idx += reduction_stride) {
for (uint j = 0; j < reduction_size; j++, in_idx += reduction_stride) {
U val = static_cast<U>(in[in_idx]);
total_val = op(total_val, val);
}
@ -53,10 +53,10 @@ template <typename T, typename U, typename Op>
}

#define instantiate_col_reduce_small(name, itype, otype, op) \
template [[host_name("col_reduce_small_" #name)]] \
[[kernel]] void col_reduce_small<itype, otype, op>( \
template [[host_name("col_reduce_small_" #name)]] [[kernel]] void \
col_reduce_small<itype, otype, op>( \
const device itype *in [[buffer(0)]], \
const device itype* in [[buffer(0)]], \
device otype *out [[buffer(1)]], \
device otype* out [[buffer(1)]], \
const constant size_t& reduction_size [[buffer(2)]], \
const constant size_t& reduction_stride [[buffer(3)]], \
const constant size_t& out_size [[buffer(4)]], \
@ -112,28 +112,23 @@ METAL_FUNC U _contiguous_strided_reduce(

template <typename T, typename U, typename Op, int N_READS = REDUCE_N_READS>
[[kernel]] void col_reduce_general(
const device T *in [[buffer(0)]],
const device T* in [[buffer(0)]],
device mlx_atomic<U> *out [[buffer(1)]],
device mlx_atomic<U>* out [[buffer(1)]],
const constant size_t& reduction_size [[buffer(2)]],
const constant size_t& reduction_stride [[buffer(3)]],
const constant size_t& out_size [[buffer(4)]],
const constant int* shape [[buffer(5)]],
const constant size_t* strides [[buffer(6)]],
const constant int& ndim [[buffer(7)]],
threadgroup U *local_data [[threadgroup(0)]],
threadgroup U* local_data [[threadgroup(0)]],
uint3 tid [[threadgroup_position_in_grid]],
uint3 lid [[thread_position_in_threadgroup]],
uint3 lsize [[threads_per_threadgroup]]) {
auto out_idx = tid.x * lsize.x + lid.x;
auto in_idx = elem_to_loc(
out_idx + tid.z * out_size,
shape,
strides,
ndim
);
auto in_idx = elem_to_loc(out_idx + tid.z * out_size, shape, strides, ndim);

Op op;
if(out_idx < out_size) {
if (out_idx < out_size) {
U val = _contiguous_strided_reduce<T, U, Op, N_READS>(
in,
local_data,
@ -144,7 +139,8 @@ template <typename T, typename U, typename Op, int N_READS = REDUCE_N_READS>
lid.xy,
lsize.xy);

// Write out reduction results generated by threadgroups working on specific output element, contiguously.
// Write out reduction results generated by threadgroups working on specific
// output element, contiguously.
if (lid.y == 0) {
op.atomic_update(out, val, out_idx);
}
@ -153,29 +149,24 @@ template <typename T, typename U, typename Op, int N_READS = REDUCE_N_READS>

template <typename T, typename U, typename Op, int N_READS = REDUCE_N_READS>
[[kernel]] void col_reduce_general_no_atomics(
const device T *in [[buffer(0)]],
const device T* in [[buffer(0)]],
device U *out [[buffer(1)]],
device U* out [[buffer(1)]],
const constant size_t& reduction_size [[buffer(2)]],
const constant size_t& reduction_stride [[buffer(3)]],
const constant size_t& out_size [[buffer(4)]],
const constant int* shape [[buffer(5)]],
const constant size_t* strides [[buffer(6)]],
const constant int& ndim [[buffer(7)]],
threadgroup U *local_data [[threadgroup(0)]],
threadgroup U* local_data [[threadgroup(0)]],
uint3 tid [[threadgroup_position_in_grid]],
uint3 lid [[thread_position_in_threadgroup]],
uint3 gid [[thread_position_in_grid]],
uint3 lsize [[threads_per_threadgroup]],
uint3 gsize [[threads_per_grid]]) {
auto out_idx = tid.x * lsize.x + lid.x;
auto in_idx = elem_to_loc(
out_idx + tid.z * out_size,
shape,
strides,
ndim
);
auto in_idx = elem_to_loc(out_idx + tid.z * out_size, shape, strides, ndim);

if(out_idx < out_size) {
if (out_idx < out_size) {
U val = _contiguous_strided_reduce<T, U, Op, N_READS>(
in,
local_data,
@ -186,7 +177,8 @@ template <typename T, typename U, typename Op, int N_READS = REDUCE_N_READS>
lid.xy,
lsize.xy);

// Write out reduction results generated by threadgroups working on specific output element, contiguously.
// Write out reduction results generated by threadgroups working on specific
// output element, contiguously.
if (lid.y == 0) {
uint tgsize_y = ceildiv(gsize.y, lsize.y);
uint tgsize_z = ceildiv(gsize.z, lsize.z);
@ -196,33 +188,34 @@ template <typename T, typename U, typename Op, int N_READS = REDUCE_N_READS>
}

#define instantiate_col_reduce_general(name, itype, otype, op) \
template [[host_name("col_reduce_general_" #name)]] \
[[kernel]] void col_reduce_general<itype, otype, op>( \
template [[host_name("col_reduce_general_" #name)]] [[kernel]] void \
col_reduce_general<itype, otype, op>( \
const device itype *in [[buffer(0)]], \
const device itype* in [[buffer(0)]], \
device mlx_atomic<otype> *out [[buffer(1)]], \
device mlx_atomic<otype>* out [[buffer(1)]], \
const constant size_t& reduction_size [[buffer(2)]], \
const constant size_t& reduction_stride [[buffer(3)]], \
const constant size_t& out_size [[buffer(4)]], \
const constant int* shape [[buffer(5)]], \
const constant size_t* strides [[buffer(6)]], \
const constant int& ndim [[buffer(7)]], \
threadgroup otype *local_data [[threadgroup(0)]], \
threadgroup otype* local_data [[threadgroup(0)]], \
uint3 tid [[threadgroup_position_in_grid]], \
uint3 lid [[thread_position_in_threadgroup]], \
uint3 lsize [[threads_per_threadgroup]]);

#define instantiate_col_reduce_general_no_atomics(name, itype, otype, op) \
template [[host_name("col_reduce_general_no_atomics_" #name)]] \
[[kernel]] void col_reduce_general_no_atomics<itype, otype, op>( \
template \
[[host_name("col_reduce_general_no_atomics_" #name)]] [[kernel]] void \
col_reduce_general_no_atomics<itype, otype, op>( \
const device itype *in [[buffer(0)]], \
const device itype* in [[buffer(0)]], \
device otype *out [[buffer(1)]], \
device otype* out [[buffer(1)]], \
const constant size_t& reduction_size [[buffer(2)]], \
const constant size_t& reduction_stride [[buffer(3)]], \
const constant size_t& out_size [[buffer(4)]], \
const constant int* shape [[buffer(5)]], \
const constant size_t* strides [[buffer(6)]], \
const constant int& ndim [[buffer(7)]], \
threadgroup otype *local_data [[threadgroup(0)]], \
threadgroup otype* local_data [[threadgroup(0)]], \
uint3 tid [[threadgroup_position_in_grid]], \
uint3 lid [[thread_position_in_threadgroup]], \
uint3 gid [[thread_position_in_grid]], \
@ -233,14 +226,17 @@ template <typename T, typename U, typename Op, int N_READS = REDUCE_N_READS>
// Instantiations
///////////////////////////////////////////////////////////////////////////////

// clang-format off
#define instantiate_same_col_reduce_helper(name, tname, type, op) \
instantiate_col_reduce_small(name ##tname, type, type, op<type>) \
instantiate_col_reduce_general(name ##tname, type, type, op<type>)
instantiate_col_reduce_general(name ##tname, type, type, op<type>) // clang-format on

// clang-format off
#define instantiate_same_col_reduce_na_helper(name, tname, type, op) \
instantiate_col_reduce_small(name ##tname, type, type, op<type>) \
instantiate_col_reduce_general_no_atomics(name ##tname, type, type, op<type>)
instantiate_col_reduce_general_no_atomics(name ##tname, type, type, op<type>) // clang-format on

// clang-format off
instantiate_reduce_ops(instantiate_same_col_reduce_helper, instantiate_reduce_helper_types)
instantiate_reduce_ops(instantiate_same_col_reduce_na_helper, instantiate_reduce_helper_64b)

@ -250,4 +246,4 @@ instantiate_reduce_from_types(instantiate_col_reduce_general, or, bool, Or)

instantiate_col_reduce_small(sumbool_, bool, uint32_t, Sum<uint32_t>)
instantiate_reduce_from_types(instantiate_col_reduce_small, and, bool, And)
instantiate_reduce_from_types(instantiate_col_reduce_small, or, bool, Or)
instantiate_reduce_from_types(instantiate_col_reduce_small, or, bool, Or) // clang-format on

@ -1,8 +1,8 @@
// Copyright © 2023-2024 Apple Inc.

#include "mlx/backend/metal/kernels/reduction/utils.h"
#include "mlx/backend/metal/kernels/reduction/ops.h"
#include "mlx/backend/metal/kernels/reduction/reduce_inst.h"
#include "mlx/backend/metal/kernels/reduction/utils.h"

using namespace metal;

@ -12,22 +12,21 @@ using namespace metal;

template <typename T, typename Op>
[[kernel]] void init_reduce(
device T *out [[buffer(0)]],
device T* out [[buffer(0)]],
uint tid [[thread_position_in_grid]]) {
out[tid] = Op::init;
}

#define instantiate_init_reduce(name, otype, op) \
template [[host_name("i" #name)]] \
[[kernel]] void init_reduce<otype, op>( \
device otype *out [[buffer(1)]], \
uint tid [[thread_position_in_grid]]);
template [[host_name("i" #name)]] [[kernel]] void init_reduce<otype, op>( \
device otype * out [[buffer(1)]], uint tid [[thread_position_in_grid]]);

#define instantiate_init_reduce_helper(name, tname, type, op) \
instantiate_init_reduce(name ##tname, type, op<type>)
instantiate_init_reduce(name##tname, type, op<type>)

// clang-format off
instantiate_reduce_ops(instantiate_init_reduce_helper, instantiate_reduce_helper_types)
instantiate_reduce_ops(instantiate_init_reduce_helper, instantiate_reduce_helper_64b)

instantiate_init_reduce(andbool_, bool, And)
instantiate_init_reduce(orbool_, bool, Or)
instantiate_init_reduce(orbool_, bool, Or) // clang-format on

@ -1,8 +1,8 @@
|
|||||||
// Copyright © 2023-2024 Apple Inc.
|
// Copyright © 2023-2024 Apple Inc.
|
||||||
|
|
||||||
#include "mlx/backend/metal/kernels/reduction/utils.h"
|
|
||||||
#include "mlx/backend/metal/kernels/reduction/ops.h"
|
#include "mlx/backend/metal/kernels/reduction/ops.h"
|
||||||
#include "mlx/backend/metal/kernels/reduction/reduce_inst.h"
|
#include "mlx/backend/metal/kernels/reduction/reduce_inst.h"
|
||||||
|
#include "mlx/backend/metal/kernels/reduction/utils.h"
|
||||||
|
|
||||||
using namespace metal;
|
using namespace metal;
|
||||||
|
|
||||||
@ -13,8 +13,8 @@ using namespace metal;
|
|||||||
// Each thread reduces for one output
|
// Each thread reduces for one output
|
||||||
template <typename T, typename U, typename Op>
|
template <typename T, typename U, typename Op>
|
||||||
[[kernel]] void row_reduce_general_small(
|
[[kernel]] void row_reduce_general_small(
|
||||||
const device T *in [[buffer(0)]],
|
const device T* in [[buffer(0)]],
|
||||||
device U *out [[buffer(1)]],
|
device U* out [[buffer(1)]],
|
||||||
const constant size_t& reduction_size [[buffer(2)]],
|
const constant size_t& reduction_size [[buffer(2)]],
|
||||||
const constant size_t& out_size [[buffer(3)]],
|
const constant size_t& out_size [[buffer(3)]],
|
||||||
const constant size_t& non_row_reductions [[buffer(4)]],
|
const constant size_t& non_row_reductions [[buffer(4)]],
|
||||||
@ -22,22 +22,21 @@ template <typename T, typename U, typename Op>
|
|||||||
const constant size_t* strides [[buffer(6)]],
|
const constant size_t* strides [[buffer(6)]],
|
||||||
const constant int& ndim [[buffer(7)]],
|
const constant int& ndim [[buffer(7)]],
|
||||||
uint lid [[thread_position_in_grid]]) {
|
uint lid [[thread_position_in_grid]]) {
|
||||||
|
|
||||||
Op op;
|
Op op;
|
||||||
|
|
||||||
uint out_idx = lid;
|
uint out_idx = lid;
|
||||||
|
|
||||||
if(out_idx >= out_size) {
|
if (out_idx >= out_size) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
U total_val = Op::init;
|
U total_val = Op::init;
|
||||||
|
|
||||||
for(short r = 0; r < short(non_row_reductions); r++) {
|
for (short r = 0; r < short(non_row_reductions); r++) {
|
||||||
uint in_idx = elem_to_loc(out_idx + r * out_size, shape, strides, ndim);
|
uint in_idx = elem_to_loc(out_idx + r * out_size, shape, strides, ndim);
|
||||||
const device T * in_row = in + in_idx;
|
const device T* in_row = in + in_idx;
|
||||||
|
|
||||||
for(short i = 0; i < short(reduction_size); i++) {
|
for (short i = 0; i < short(reduction_size); i++) {
|
||||||
total_val = op(static_cast<U>(in_row[i]), total_val);
|
total_val = op(static_cast<U>(in_row[i]), total_val);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -48,8 +47,8 @@ template <typename T, typename U, typename Op>
|
|||||||
// Each simdgroup reduces for one output
|
// Each simdgroup reduces for one output
|
||||||
template <typename T, typename U, typename Op>
|
template <typename T, typename U, typename Op>
|
||||||
[[kernel]] void row_reduce_general_med(
|
[[kernel]] void row_reduce_general_med(
|
||||||
const device T *in [[buffer(0)]],
|
const device T* in [[buffer(0)]],
|
||||||
device U *out [[buffer(1)]],
|
device U* out [[buffer(1)]],
|
||||||
const constant size_t& reduction_size [[buffer(2)]],
|
const constant size_t& reduction_size [[buffer(2)]],
|
||||||
const constant size_t& out_size [[buffer(3)]],
|
const constant size_t& out_size [[buffer(3)]],
|
||||||
const constant size_t& non_row_reductions [[buffer(4)]],
|
const constant size_t& non_row_reductions [[buffer(4)]],
|
||||||
@ -60,45 +59,42 @@ template <typename T, typename U, typename Op>
|
|||||||
uint simd_lane_id [[thread_index_in_simdgroup]],
|
uint simd_lane_id [[thread_index_in_simdgroup]],
|
||||||
uint simd_per_group [[dispatch_simdgroups_per_threadgroup]],
|
uint simd_per_group [[dispatch_simdgroups_per_threadgroup]],
|
||||||
uint simd_group_id [[simdgroup_index_in_threadgroup]]) {
|
uint simd_group_id [[simdgroup_index_in_threadgroup]]) {
|
||||||
|
|
||||||
Op op;
|
Op op;
|
||||||
|
|
||||||
uint out_idx = simd_per_group * tid + simd_group_id;
|
uint out_idx = simd_per_group * tid + simd_group_id;
|
||||||
|
|
||||||
if(out_idx >= out_size) {
|
if (out_idx >= out_size) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
U total_val = Op::init;
|
U total_val = Op::init;
|
||||||
|
|
||||||
if(short(non_row_reductions) == 1) {
|
if (short(non_row_reductions) == 1) {
|
||||||
uint in_idx = elem_to_loc(out_idx, shape, strides, ndim);
|
uint in_idx = elem_to_loc(out_idx, shape, strides, ndim);
|
||||||
const device T * in_row = in + in_idx;
|
const device T* in_row = in + in_idx;
|
||||||
|
|
||||||
for(short i = simd_lane_id; i < short(reduction_size); i += 32) {
|
for (short i = simd_lane_id; i < short(reduction_size); i += 32) {
|
||||||
total_val = op(static_cast<U>(in_row[i]), total_val);
|
total_val = op(static_cast<U>(in_row[i]), total_val);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
else if (short(non_row_reductions) >= 32) {
|
else if (short(non_row_reductions) >= 32) {
|
||||||
|
for (short r = simd_lane_id; r < short(non_row_reductions); r += 32) {
|
||||||
for(short r = simd_lane_id; r < short(non_row_reductions); r+=32) {
|
|
||||||
|
|
||||||
uint in_idx = elem_to_loc(out_idx + r * out_size, shape, strides, ndim);
|
uint in_idx = elem_to_loc(out_idx + r * out_size, shape, strides, ndim);
|
||||||
const device T * in_row = in + in_idx;
|
const device T* in_row = in + in_idx;
|
||||||
|
|
||||||
for(short i = 0; i < short(reduction_size); i++) {
|
for (short i = 0; i < short(reduction_size); i++) {
|
||||||
total_val = op(static_cast<U>(in_row[i]), total_val);
|
total_val = op(static_cast<U>(in_row[i]), total_val);
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
else {
|
else {
|
||||||
|
const short n_reductions =
|
||||||
const short n_reductions = short(reduction_size) * short(non_row_reductions);
|
short(reduction_size) * short(non_row_reductions);
|
||||||
const short reductions_per_thread = (n_reductions + simd_size - 1) / simd_size;
|
const short reductions_per_thread =
|
||||||
|
(n_reductions + simd_size - 1) / simd_size;
|
||||||
|
|
||||||
const short r_st = simd_lane_id / reductions_per_thread;
|
const short r_st = simd_lane_id / reductions_per_thread;
|
||||||
const short r_ed = short(non_row_reductions);
|
const short r_ed = short(non_row_reductions);
|
||||||
@ -108,34 +104,30 @@ template <typename T, typename U, typename Op>
|
|||||||
const short i_ed = short(reduction_size);
|
const short i_ed = short(reduction_size);
|
||||||
const short i_jump = reductions_per_thread;
|
const short i_jump = reductions_per_thread;
|
||||||
|
|
||||||
if(r_st < r_jump) {
|
if (r_st < r_jump) {
|
||||||
for(short r = r_st; r < r_ed; r += r_jump) {
|
for (short r = r_st; r < r_ed; r += r_jump) {
|
||||||
|
|
||||||
uint in_idx = elem_to_loc(out_idx + r * out_size, shape, strides, ndim);
|
uint in_idx = elem_to_loc(out_idx + r * out_size, shape, strides, ndim);
|
||||||
const device T * in_row = in + in_idx;
|
const device T* in_row = in + in_idx;
|
||||||
|
|
||||||
for(short i = i_st; i < i_ed; i += i_jump) {
|
for (short i = i_st; i < i_ed; i += i_jump) {
|
||||||
total_val = op(static_cast<U>(in_row[i]), total_val);
|
total_val = op(static_cast<U>(in_row[i]), total_val);
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
total_val = op.simd_reduce(total_val);
|
total_val = op.simd_reduce(total_val);
|
||||||
|
|
||||||
if(simd_lane_id == 0) {
|
if (simd_lane_id == 0) {
|
||||||
out[out_idx] = total_val;
|
out[out_idx] = total_val;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#define instantiate_row_reduce_small(name, itype, otype, op) \
|
#define instantiate_row_reduce_small(name, itype, otype, op) \
|
||||||
template[[host_name("row_reduce_general_small_" #name)]] \
|
template [[host_name("row_reduce_general_small_" #name)]] [[kernel]] void \
|
||||||
[[kernel]] void row_reduce_general_small<itype, otype, op>( \
|
row_reduce_general_small<itype, otype, op>( \
|
||||||
const device itype *in [[buffer(0)]], \
|
const device itype* in [[buffer(0)]], \
|
||||||
device otype *out [[buffer(1)]], \
|
device otype* out [[buffer(1)]], \
|
||||||
const constant size_t& reduction_size [[buffer(2)]], \
|
const constant size_t& reduction_size [[buffer(2)]], \
|
||||||
const constant size_t& out_size [[buffer(3)]], \
|
const constant size_t& out_size [[buffer(3)]], \
|
||||||
const constant size_t& non_row_reductions [[buffer(4)]], \
|
const constant size_t& non_row_reductions [[buffer(4)]], \
|
||||||
@ -143,10 +135,10 @@ template <typename T, typename U, typename Op>
|
|||||||
const constant size_t* strides [[buffer(6)]], \
|
const constant size_t* strides [[buffer(6)]], \
|
||||||
const constant int& ndim [[buffer(7)]], \
|
const constant int& ndim [[buffer(7)]], \
|
||||||
uint lid [[thread_position_in_grid]]); \
|
      uint lid [[thread_position_in_grid]]); \
-  template[[host_name("row_reduce_general_med_" #name)]] \
-  [[kernel]] void row_reduce_general_med<itype, otype, op>( \
-      const device itype *in [[buffer(0)]], \
-      device otype *out [[buffer(1)]], \
+  template [[host_name("row_reduce_general_med_" #name)]] [[kernel]] void \
+  row_reduce_general_med<itype, otype, op>( \
+      const device itype* in [[buffer(0)]], \
+      device otype* out [[buffer(1)]], \
       const constant size_t& reduction_size [[buffer(2)]], \
       const constant size_t& out_size [[buffer(3)]], \
       const constant size_t& non_row_reductions [[buffer(4)]], \
@@ -217,10 +209,10 @@ METAL_FUNC U per_thread_row_reduce(
   return total_val;
 }
 
-template <typename T, typename U, typename Op, int N_READS=REDUCE_N_READS>
+template <typename T, typename U, typename Op, int N_READS = REDUCE_N_READS>
 [[kernel]] void row_reduce_general(
-    const device T *in [[buffer(0)]],
-    device mlx_atomic<U> *out [[buffer(1)]],
+    const device T* in [[buffer(0)]],
+    device mlx_atomic<U>* out [[buffer(1)]],
     const constant size_t& reduction_size [[buffer(2)]],
     const constant size_t& out_size [[buffer(3)]],
     const constant size_t& non_row_reductions [[buffer(4)]],
@@ -233,13 +225,21 @@ template <typename T, typename U, typename Op, int N_READS=REDUCE_N_READS>
     uint simd_lane_id [[thread_index_in_simdgroup]],
     uint simd_per_group [[simdgroups_per_threadgroup]],
     uint simd_group_id [[simdgroup_index_in_threadgroup]]) {
 
   (void)non_row_reductions;
 
   Op op;
   threadgroup U local_vals[simd_size];
 
-  U total_val = per_thread_row_reduce<T, U, Op, N_READS>(in, reduction_size, out_size, shape, strides, ndim, lsize.x, lid.x, tid.xy);
+  U total_val = per_thread_row_reduce<T, U, Op, N_READS>(
+      in,
+      reduction_size,
+      out_size,
+      shape,
+      strides,
+      ndim,
+      lsize.x,
+      lid.x,
+      tid.xy);
 
   total_val = op.simd_reduce(total_val);
 
@@ -251,7 +251,7 @@ template <typename T, typename U, typename Op, int N_READS=REDUCE_N_READS>
 
   // Reduction within thread group
   // Only needed if multiple simd groups
-  if(reduction_size > simd_size) {
+  if (reduction_size > simd_size) {
     total_val = lid.x < simd_per_group ? local_vals[lid.x] : op.init;
     total_val = op.simd_reduce(total_val);
   }
@@ -261,10 +261,10 @@ template <typename T, typename U, typename Op, int N_READS=REDUCE_N_READS>
   }
 }
 
-template <typename T, typename U, typename Op, int N_READS=REDUCE_N_READS>
+template <typename T, typename U, typename Op, int N_READS = REDUCE_N_READS>
 [[kernel]] void row_reduce_general_no_atomics(
-    const device T *in [[buffer(0)]],
-    device U *out [[buffer(1)]],
+    const device T* in [[buffer(0)]],
+    device U* out [[buffer(1)]],
     const constant size_t& reduction_size [[buffer(2)]],
     const constant size_t& out_size [[buffer(3)]],
     const constant size_t& non_row_reductions [[buffer(4)]],
@@ -278,16 +278,24 @@ template <typename T, typename U, typename Op, int N_READS=REDUCE_N_READS>
     uint simd_lane_id [[thread_index_in_simdgroup]],
     uint simd_per_group [[simdgroups_per_threadgroup]],
     uint simd_group_id [[simdgroup_index_in_threadgroup]]) {
 
   (void)non_row_reductions;
 
   Op op;
 
   threadgroup U local_vals[simd_size];
-  U total_val = per_thread_row_reduce<T, U, Op, N_READS>(in, reduction_size, out_size, shape, strides, ndim, lsize.x, lid.x, tid.xy);
+  U total_val = per_thread_row_reduce<T, U, Op, N_READS>(
+      in,
+      reduction_size,
+      out_size,
+      shape,
+      strides,
+      ndim,
+      lsize.x,
+      lid.x,
+      tid.xy);
 
   // Reduction within simd group - simd_add isn't supported for int64 types
-  for (uint16_t i = simd_size/2; i > 0; i /= 2) {
+  for (uint16_t i = simd_size / 2; i > 0; i /= 2) {
     total_val = op(total_val, simd_shuffle_down(total_val, i));
   }
 
@@ -299,9 +307,9 @@ template <typename T, typename U, typename Op, int N_READS=REDUCE_N_READS>
 
   // Reduction within thread group
   // Only needed if thread group has multiple simd groups
-  if(ceildiv(reduction_size, N_READS) > simd_size) {
+  if (ceildiv(reduction_size, N_READS) > simd_size) {
     total_val = lid.x < simd_per_group ? local_vals[lid.x] : op.init;
-    for (uint16_t i = simd_size/2; i > 0; i /= 2) {
+    for (uint16_t i = simd_size / 2; i > 0; i /= 2) {
       total_val = op(total_val, simd_shuffle_down(total_val, i));
     }
   }
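
Note: the `simd_shuffle_down` loop reformatted above is a butterfly tree reduction — each lane combines with the value held `i` lanes higher, halving `i` each round until lane 0 holds the reduction of the whole simdgroup. A minimal host-side sketch of the same access pattern (plain C++; array indices stand in for SIMD lanes, names are illustrative and not part of MLX):

#include <cstdio>

// Emulate a 32-lane simd_shuffle_down tree reduction on the host.
int main() {
  constexpr int simd_size = 32;
  float lanes[simd_size];
  for (int l = 0; l < simd_size; l++)
    lanes[l] = float(l + 1); // 1 + 2 + ... + 32 = 528

  for (int i = simd_size / 2; i > 0; i /= 2) {
    for (int l = 0; l < simd_size; l++) {
      // shuffle_down(x, i): lane l reads lane l + i. Iterating l upward reads
      // pre-round values, matching the simultaneous SIMD semantics; lanes
      // that read out of range never feed into lane 0's result.
      float other = (l + i < simd_size) ? lanes[l + i] : lanes[l];
      lanes[l] += other;
    }
  }
  printf("lane 0 holds %g\n", lanes[0]); // 528
  return 0;
}
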
@@ -312,11 +320,11 @@ template <typename T, typename U, typename Op, int N_READS=REDUCE_N_READS>
   }
 }
 
 #define instantiate_row_reduce_general(name, itype, otype, op) \
-  instantiate_row_reduce_small(name, itype, otype, op) \
-  template [[host_name("row_reduce_general_" #name)]] \
-  [[kernel]] void row_reduce_general<itype, otype, op>( \
-      const device itype *in [[buffer(0)]], \
-      device mlx_atomic<otype> *out [[buffer(1)]], \
+  instantiate_row_reduce_small(name, itype, otype, op) template \
+  [[host_name("row_reduce_general_" #name)]] [[kernel]] void \
+  row_reduce_general<itype, otype, op>( \
+      const device itype* in [[buffer(0)]], \
+      device mlx_atomic<otype>* out [[buffer(1)]], \
       const constant size_t& reduction_size [[buffer(2)]], \
       const constant size_t& out_size [[buffer(3)]], \
       const constant size_t& non_row_reductions [[buffer(4)]], \
@@ -331,11 +339,11 @@ template <typename T, typename U, typename Op, int N_READS=REDUCE_N_READS>
       uint simd_group_id [[simdgroup_index_in_threadgroup]]);
 
 #define instantiate_row_reduce_general_no_atomics(name, itype, otype, op) \
-  instantiate_row_reduce_small(name, itype, otype, op) \
-  template [[host_name("row_reduce_general_no_atomics_" #name)]] \
-  [[kernel]] void row_reduce_general_no_atomics<itype, otype, op>( \
-      const device itype *in [[buffer(0)]], \
-      device otype *out [[buffer(1)]], \
+  instantiate_row_reduce_small(name, itype, otype, op) template \
+  [[host_name("row_reduce_general_no_atomics_" #name)]] [[kernel]] void \
+  row_reduce_general_no_atomics<itype, otype, op>( \
+      const device itype* in [[buffer(0)]], \
+      device otype* out [[buffer(1)]], \
       const constant size_t& reduction_size [[buffer(2)]], \
       const constant size_t& out_size [[buffer(3)]], \
       const constant size_t& non_row_reductions [[buffer(4)]], \
@@ -350,22 +358,21 @@ template <typename T, typename U, typename Op, int N_READS=REDUCE_N_READS>
       uint simd_per_group [[simdgroups_per_threadgroup]], \
       uint simd_group_id [[simdgroup_index_in_threadgroup]]);
 
 
 ///////////////////////////////////////////////////////////////////////////////
 // Instantiations
 ///////////////////////////////////////////////////////////////////////////////
 
 #define instantiate_same_row_reduce_helper(name, tname, type, op) \
-  instantiate_row_reduce_general(name ##tname, type, type, op<type>)
+  instantiate_row_reduce_general(name##tname, type, type, op<type>)
 
 #define instantiate_same_row_reduce_na_helper(name, tname, type, op) \
-  instantiate_row_reduce_general_no_atomics(name ##tname, type, type, op<type>)
+  instantiate_row_reduce_general_no_atomics(name##tname, type, type, op<type>)
 
+// clang-format off
 instantiate_reduce_ops(instantiate_same_row_reduce_helper, instantiate_reduce_helper_types)
 instantiate_reduce_ops(instantiate_same_row_reduce_na_helper, instantiate_reduce_helper_64b)
 
 
 instantiate_reduce_from_types(instantiate_row_reduce_general, and, bool, And)
 instantiate_reduce_from_types(instantiate_row_reduce_general, or, bool, Or)
 
-instantiate_row_reduce_general(sumbool_, bool, uint32_t, Sum<uint32_t>)
+instantiate_row_reduce_general(sumbool_, bool, uint32_t, Sum<uint32_t>) // clang-format on
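
Note: the `name ##tname` → `name##tname` change is purely cosmetic — the preprocessor's `##` operator pastes the two tokens regardless of surrounding whitespace, so both spellings yield the same identifier. A standalone sketch of the paste-plus-stringize pattern these instantiation macros rely on (plain C++; `MAKE_KERNEL` and its names are hypothetical, not MLX code):

#include <cstdio>

// Token-paste to build an identifier, stringize to build its exported name,
// mirroring how the host_name instantiation macros are assembled.
#define MAKE_KERNEL(name, tname)                            \
  void name##tname() {                                      \
    printf("host_name: %s\n", "row_reduce_" #name #tname);  \
  }

MAKE_KERNEL(sum, float32) // defines sumfloat32()

int main() {
  sumfloat32(); // prints "host_name: row_reduce_sumfloat32"
  return 0;
}
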
@@ -237,13 +237,17 @@ template <typename T, int N_READS = RMS_N_READS>
     gw += gid * axis_size + lid * N_READS;
     if (lid * N_READS + N_READS <= axis_size) {
       for (int i = 0; i < N_READS; i++) {
-        gx[i] = static_cast<T>(thread_g[i] * thread_w[i] * normalizer - thread_x[i] * meangwx * normalizer3);
+        gx[i] = static_cast<T>(
+            thread_g[i] * thread_w[i] * normalizer -
+            thread_x[i] * meangwx * normalizer3);
         gw[i] = static_cast<T>(thread_g[i] * thread_x[i] * normalizer);
       }
     } else {
       for (int i = 0; i < N_READS; i++) {
         if ((lid * N_READS + i) < axis_size) {
-          gx[i] = static_cast<T>(thread_g[i] * thread_w[i] * normalizer - thread_x[i] * meangwx * normalizer3);
+          gx[i] = static_cast<T>(
+              thread_g[i] * thread_w[i] * normalizer -
+              thread_x[i] * meangwx * normalizer3);
           gw[i] = static_cast<T>(thread_g[i] * thread_x[i] * normalizer);
         }
       }
@@ -342,7 +346,8 @@ template <typename T, int N_READS = RMS_N_READS>
         float wi = w[w_stride * (i + r)];
         float gi = g[i + r];
 
-        gx[i + r] = static_cast<T>(gi * wi * normalizer - xi * meangwx * normalizer3);
+        gx[i + r] =
+            static_cast<T>(gi * wi * normalizer - xi * meangwx * normalizer3);
         gw[i + r] = static_cast<T>(gi * xi * normalizer);
       }
     } else {
@@ -352,7 +357,8 @@ template <typename T, int N_READS = RMS_N_READS>
         float wi = w[w_stride * (i + r)];
         float gi = g[i + r];
 
-        gx[i + r] = static_cast<T>(gi * wi * normalizer - xi * meangwx * normalizer3);
+        gx[i + r] =
+            static_cast<T>(gi * wi * normalizer - xi * meangwx * normalizer3);
         gw[i + r] = static_cast<T>(gi * xi * normalizer);
       }
     }
@@ -431,5 +437,4 @@ template <typename T, int N_READS = RMS_N_READS>
 
 instantiate_rms(float32, float)
 instantiate_rms(float16, half)
-instantiate_rms(bfloat16, bfloat16_t)
-// clang-format on
+instantiate_rms(bfloat16, bfloat16_t) // clang-format on
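
Note: for readers checking the rewrapped gradient lines: with `normalizer` playing the role of 1/σ (the inverse RMS, up to the epsilon the kernel uses), `normalizer3` of 1/σ³, and `meangwx` of the axis mean of g·w·x, the two terms are the standard RMS-norm VJP. Restated in LaTeX (a transcription of the expressions above, not new math):

\frac{\partial \mathcal{L}}{\partial x_i}
  = \frac{g_i w_i}{\sigma} - \frac{x_i}{\sigma^3}\,\overline{g w x},
\qquad
\frac{\partial \mathcal{L}}{\partial w_i} = \frac{g_i x_i}{\sigma},
\qquad
\sigma = \sqrt{\tfrac{1}{N}\textstyle\sum_j x_j^2 + \varepsilon}
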
@@ -7,8 +7,8 @@
 
 template <typename T, bool traditional, bool forward>
 [[kernel]] void rope(
-    const device T *in [[buffer(0)]],
-    device T * out [[buffer(1)]],
+    const device T* in [[buffer(0)]],
+    device T* out [[buffer(1)]],
     constant const size_t strides[3],
     constant const size_t out_strides[3],
     constant const int& offset,
@@ -20,12 +20,15 @@ template <typename T, bool traditional, bool forward>
   uint in_index_1, in_index_2;
   uint out_index_1, out_index_2;
   if (traditional) {
-    out_index_1 = 2 * pos.x * out_strides[2] + pos.y * out_strides[1] + pos.z * out_strides[0];
+    out_index_1 = 2 * pos.x * out_strides[2] + pos.y * out_strides[1] +
+        pos.z * out_strides[0];
     out_index_2 = out_index_1 + 1;
-    in_index_1 = 2 * pos.x * strides[2] + pos.y * strides[1] + pos.z * strides[0];
+    in_index_1 =
+        2 * pos.x * strides[2] + pos.y * strides[1] + pos.z * strides[0];
     in_index_2 = in_index_1 + strides[2];
   } else {
-    out_index_1 = pos.x * out_strides[2] + pos.y * out_strides[1] + pos.z * out_strides[0];
+    out_index_1 = pos.x * out_strides[2] + pos.y * out_strides[1] +
+        pos.z * out_strides[0];
     out_index_2 = out_index_1 + grid.x * out_strides[2];
     in_index_1 = pos.x * strides[2] + pos.y * strides[1] + pos.z * strides[0];
     in_index_2 = in_index_1 + grid.x * strides[2];
@@ -57,8 +60,8 @@ template <typename T, bool traditional, bool forward>
 }
 
 #define instantiate_rope(name, type, traditional, forward) \
-  template [[host_name("rope_" #name)]] \
-  [[kernel]] void rope<type, traditional, forward>( \
+  template [[host_name("rope_" #name)]] [[kernel]] void \
+  rope<type, traditional, forward>( \
       const device type* in [[buffer(0)]], \
       device type* out [[buffer(1)]], \
       constant const size_t strides[3], \
@@ -69,6 +72,7 @@ template <typename T, bool traditional, bool forward>
       uint3 pos [[thread_position_in_grid]], \
       uint3 grid [[threads_per_grid]]);
 
+// clang-format off
 instantiate_rope(traditional_float16, half, true, true)
 instantiate_rope(traditional_bfloat16, bfloat16_t, true, true)
 instantiate_rope(traditional_float32, float, true, true)
@@ -80,4 +84,4 @@ instantiate_rope(vjp_traditional_bfloat16, bfloat16_t, true, false)
 instantiate_rope(vjp_traditional_float32, float, true, false)
 instantiate_rope(vjp_float16, half, false, false)
 instantiate_rope(vjp_bfloat16, bfloat16_t, false, false)
-instantiate_rope(vjp_float32, float, false, false)
+instantiate_rope(vjp_float32, float, false, false) // clang-format on
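
Note: the two index branches encode the two RoPE layouts — `traditional` rotates adjacent pairs `(2i, 2i+1)`, while the default layout rotates split halves `(i, i + dims/2)`, which is why the second element sits `grid.x` strides away in the kernel. A host-side sketch of the rotation the indices feed into (plain C++; the fixed `theta` is a stand-in for the per-position, per-frequency angle the real kernel computes):

#include <cmath>
#include <cstdio>

// Rotate one RoPE pair (x1, x2) by angle theta, as in the forward path.
static void rope_rotate(float theta, float x1, float x2, float* o1, float* o2) {
  *o1 = x1 * std::cos(theta) - x2 * std::sin(theta);
  *o2 = x1 * std::sin(theta) + x2 * std::cos(theta);
}

int main() {
  const int dims = 8;
  float x[dims] = {1, 2, 3, 4, 5, 6, 7, 8}, out[dims];
  float theta = 0.5f; // illustrative; position- and frequency-dependent in practice
  for (int i = 0; i < dims / 2; i++) {
    // default layout: pairs (i, i + dims/2); traditional would pair (2i, 2i+1)
    rope_rotate(theta, x[i], x[i + dims / 2], &out[i], &out[i + dims / 2]);
  }
  printf("%f %f\n", out[0], out[4]);
  return 0;
}
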
@@ -1,13 +1,19 @@
-#include <metal_stdlib>
 #include <metal_simdgroup>
+#include <metal_stdlib>
 
 #include "mlx/backend/metal/kernels/scaled_dot_product_attention_params.h"
 using namespace metal;
 
-template<typename T, typename T2, typename T4, uint16_t TILE_SIZE_CONST, uint16_t NSIMDGROUPS>
-[[kernel]] void fast_inference_sdpa_compute_partials_template(const device T *Q [[buffer(0)]],
-    const device T *K [[buffer(1)]],
-    const device T *V [[buffer(2)]],
+template <
+    typename T,
+    typename T2,
+    typename T4,
+    uint16_t TILE_SIZE_CONST,
+    uint16_t NSIMDGROUPS>
+[[kernel]] void fast_inference_sdpa_compute_partials_template(
+    const device T* Q [[buffer(0)]],
+    const device T* K [[buffer(1)]],
+    const device T* V [[buffer(2)]],
     const device uint64_t& L [[buffer(3)]],
     const device MLXScaledDotProductAttentionParams& params [[buffer(4)]],
     device float* O_partials [[buffer(5)]],
@@ -23,29 +29,36 @@ template<typename T, typename T2, typename T4, uint16_t TILE_SIZE_CONST, uint16_
   constexpr const uint iter_offset = NSIMDGROUPS * 4;
   const bool is_gqa = params.N_KV_HEADS != params.N_Q_HEADS;
   uint kv_head_offset_factor = tid.x;
-  if(is_gqa) {
+  if (is_gqa) {
     int q_kv_head_ratio = params.N_Q_HEADS / params.N_KV_HEADS;
     kv_head_offset_factor = tid.x / q_kv_head_ratio;
   }
   constexpr const uint16_t P_VEC4 = TILE_SIZE_CONST / NSIMDGROUPS / 4;
-  constexpr const size_t MATRIX_LOADS_PER_SIMDGROUP = TILE_SIZE_CONST / (SIMDGROUP_MATRIX_LOAD_FACTOR * NSIMDGROUPS);
+  constexpr const size_t MATRIX_LOADS_PER_SIMDGROUP =
+      TILE_SIZE_CONST / (SIMDGROUP_MATRIX_LOAD_FACTOR * NSIMDGROUPS);
   constexpr const size_t MATRIX_COLS = DK / SIMDGROUP_MATRIX_LOAD_FACTOR;
-  constexpr const uint totalSmemV = SIMDGROUP_MATRIX_LOAD_FACTOR * SIMDGROUP_MATRIX_LOAD_FACTOR * (MATRIX_LOADS_PER_SIMDGROUP + 1) * NSIMDGROUPS;
+  constexpr const uint totalSmemV = SIMDGROUP_MATRIX_LOAD_FACTOR *
+      SIMDGROUP_MATRIX_LOAD_FACTOR * (MATRIX_LOADS_PER_SIMDGROUP + 1) *
+      NSIMDGROUPS;
 
   threadgroup T4* smemFlush = (threadgroup T4*)threadgroup_block;
 #pragma clang loop unroll(full)
-  for(uint i = 0; i < 8; i++) {
-    smemFlush[simd_lane_id + simd_group_id * THREADS_PER_SIMDGROUP + i * NSIMDGROUPS * THREADS_PER_SIMDGROUP] = T4(0.f);
+  for (uint i = 0; i < 8; i++) {
+    smemFlush
+        [simd_lane_id + simd_group_id * THREADS_PER_SIMDGROUP +
+         i * NSIMDGROUPS * THREADS_PER_SIMDGROUP] = T4(0.f);
   }
   threadgroup_barrier(mem_flags::mem_threadgroup);
   // TODO: multiple query sequence length for speculative decoding
-  const uint tgroup_query_head_offset = tid.x * DK + tid.z * (params.N_Q_HEADS * DK);
+  const uint tgroup_query_head_offset =
+      tid.x * DK + tid.z * (params.N_Q_HEADS * DK);
 
   const uint tgroup_k_head_offset = kv_head_offset_factor * DK * L;
   const uint tgroup_k_tile_offset = tid.y * TILE_SIZE_CONST * DK;
   const uint tgroup_k_batch_offset = tid.z * L * params.N_KV_HEADS * DK;
 
-  const device T* baseK = K + tgroup_k_batch_offset + tgroup_k_tile_offset + tgroup_k_head_offset;
+  const device T* baseK =
+      K + tgroup_k_batch_offset + tgroup_k_tile_offset + tgroup_k_head_offset;
   const device T* baseQ = Q + tgroup_query_head_offset;
 
   device T4* simdgroupQueryData = (device T4*)baseQ;
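
Note: the `is_gqa` branch maps each query head onto its shared key/value head. With grouped-query attention, `N_Q_HEADS` is an integer multiple of `N_KV_HEADS`, so query head `h` reads KV head `h / (N_Q_HEADS / N_KV_HEADS)`. A quick sketch of that mapping (plain C++, illustrative head counts):

#include <cstdio>

int main() {
  const unsigned n_q_heads = 32, n_kv_heads = 8; // example GQA shapes
  const bool is_gqa = n_kv_heads != n_q_heads;
  for (unsigned q_head = 0; q_head < n_q_heads; q_head++) {
    unsigned kv_head = q_head; // MHA default: one KV head per query head
    if (is_gqa) {
      unsigned q_kv_head_ratio = n_q_heads / n_kv_heads; // 4 queries share a KV head
      kv_head = q_head / q_kv_head_ratio;
    }
    if (q_head % 8 == 0)
      printf("q_head %u -> kv_head %u\n", q_head, kv_head);
  }
  return 0;
}
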
@@ -53,8 +66,9 @@ template<typename T, typename T2, typename T4, uint16_t TILE_SIZE_CONST, uint16_
   constexpr const size_t ACCUM_PER_GROUP = TILE_SIZE_CONST / NSIMDGROUPS;
   float threadAccum[ACCUM_PER_GROUP];
 
 #pragma clang loop unroll(full)
-  for(size_t threadAccumIndex = 0; threadAccumIndex < ACCUM_PER_GROUP; threadAccumIndex++) {
+  for (size_t threadAccumIndex = 0; threadAccumIndex < ACCUM_PER_GROUP;
+       threadAccumIndex++) {
     threadAccum[threadAccumIndex] = -INFINITY;
   }
 
@@ -62,14 +76,16 @@ template<typename T, typename T2, typename T4, uint16_t TILE_SIZE_CONST, uint16_
 
   const int32_t SEQUENCE_LENGTH_LESS_TILE_SIZE = L - TILE_SIZE_CONST;
   const bool LAST_TILE = (tid.y + 1) * TILE_SIZE_CONST >= L;
-  const bool LAST_TILE_ALIGNED = (SEQUENCE_LENGTH_LESS_TILE_SIZE == int32_t(tid.y * TILE_SIZE_CONST));
+  const bool LAST_TILE_ALIGNED =
+      (SEQUENCE_LENGTH_LESS_TILE_SIZE == int32_t(tid.y * TILE_SIZE_CONST));
 
   T4 thread_data_x4;
   T4 thread_data_y4;
-  if(!LAST_TILE || LAST_TILE_ALIGNED) {
+  if (!LAST_TILE || LAST_TILE_ALIGNED) {
     thread_data_x4 = *(simdgroupQueryData + simd_lane_id);
 #pragma clang loop unroll(full)
-    for(size_t KROW = simd_group_id; KROW < TILE_SIZE_CONST; KROW += NSIMDGROUPS) {
+    for (size_t KROW = simd_group_id; KROW < TILE_SIZE_CONST;
+         KROW += NSIMDGROUPS) {
       const uint KROW_OFFSET = KROW * DK;
       const device T* baseKRow = baseK + KROW_OFFSET;
       device T4* keysData = (device T4*)baseKRow;
@@ -81,9 +97,11 @@ template<typename T, typename T2, typename T4, uint16_t TILE_SIZE_CONST, uint16_
   } else {
     thread_data_x4 = *(simdgroupQueryData + simd_lane_id);
     const uint START_ROW = tid.y * TILE_SIZE_CONST;
-    const device T* baseKThisHead = K + tgroup_k_batch_offset + tgroup_k_head_offset;
+    const device T* baseKThisHead =
+        K + tgroup_k_batch_offset + tgroup_k_head_offset;
 
-    for(size_t KROW = START_ROW + simd_group_id; KROW < L; KROW += NSIMDGROUPS) {
+    for (size_t KROW = START_ROW + simd_group_id; KROW < L;
+         KROW += NSIMDGROUPS) {
       const uint KROW_OFFSET = KROW * DK;
       const device T* baseKRow = baseKThisHead + KROW_OFFSET;
       device T4* keysData = (device T4*)baseKRow;
@@ -95,12 +113,16 @@ template<typename T, typename T2, typename T4, uint16_t TILE_SIZE_CONST, uint16_
   }
   threadgroup float* smemP = (threadgroup float*)threadgroup_block;
 
 #pragma clang loop unroll(full)
-  for(size_t i = 0; i < P_VEC4; i++) {
-    thread_data_x4 = T4(threadAccum[4 * i], threadAccum[4 * i + 1], threadAccum[4 * i + 2], threadAccum[4 * i + 3]);
+  for (size_t i = 0; i < P_VEC4; i++) {
+    thread_data_x4 =
+        T4(threadAccum[4 * i],
+           threadAccum[4 * i + 1],
+           threadAccum[4 * i + 2],
+           threadAccum[4 * i + 3]);
     simdgroup_barrier(mem_flags::mem_none);
     thread_data_y4 = simd_sum(thread_data_x4);
-    if(simd_lane_id == 0) {
+    if (simd_lane_id == 0) {
       const uint base_smem_p_offset = i * iter_offset + simd_group_id;
       smemP[base_smem_p_offset + NSIMDGROUPS * 0] = float(thread_data_y4.x);
       smemP[base_smem_p_offset + NSIMDGROUPS * 1] = float(thread_data_y4.y);
@@ -115,11 +137,13 @@ template<typename T, typename T2, typename T4, uint16_t TILE_SIZE_CONST, uint16_
   float lse = 0.f;
 
   constexpr const size_t THREADS_PER_THREADGROUP_TIMES_4 = 4 * 32;
-  constexpr const size_t ACCUM_ARRAY_LENGTH = TILE_SIZE_CONST / THREADS_PER_THREADGROUP_TIMES_4 + 1;
+  constexpr const size_t ACCUM_ARRAY_LENGTH =
+      TILE_SIZE_CONST / THREADS_PER_THREADGROUP_TIMES_4 + 1;
   float4 pvals[ACCUM_ARRAY_LENGTH];
 
 #pragma clang loop unroll(full)
-  for(uint accum_array_iter = 0; accum_array_iter < ACCUM_ARRAY_LENGTH; accum_array_iter++) {
+  for (uint accum_array_iter = 0; accum_array_iter < ACCUM_ARRAY_LENGTH;
+       accum_array_iter++) {
     pvals[accum_array_iter] = float4(-INFINITY);
   }
 
@@ -148,8 +172,8 @@ template<typename T, typename T2, typename T4, uint16_t TILE_SIZE_CONST, uint16_
   if (TILE_SIZE_LARGER_THAN_64) {
     float maxval = -INFINITY;
     threadgroup float4* smemPtrFlt4 = (threadgroup float4*)threadgroup_block;
 #pragma clang loop unroll(full)
-    for(int i = 0; i < TILE_SIZE_ITERS_128; i++) {
+    for (int i = 0; i < TILE_SIZE_ITERS_128; i++) {
       float4 vals = smemPtrFlt4[simd_lane_id + i * THREADS_PER_SIMDGROUP];
       vals *= params.INV_ALPHA;
       pvals[i] = vals;
@@ -160,16 +184,16 @@ template<typename T, typename T2, typename T4, uint16_t TILE_SIZE_CONST, uint16_
     groupMax = simd_max(maxval);
 
     float sumExpLocal = 0.f;
 #pragma clang loop unroll(full)
-    for(int i = 0; i < TILE_SIZE_ITERS_128; i++) {
+    for (int i = 0; i < TILE_SIZE_ITERS_128; i++) {
       pvals[i] = exp(pvals[i] - groupMax);
       sumExpLocal += pvals[i].x + pvals[i].y + pvals[i].z + pvals[i].w;
     }
     simdgroup_barrier(mem_flags::mem_none);
     float tgroupExpSum = simd_sum(sumExpLocal);
     lse = log(tgroupExpSum);
 #pragma clang loop unroll(full)
-    for(int i = 0; i < TILE_SIZE_ITERS_128; i++) {
+    for (int i = 0; i < TILE_SIZE_ITERS_128; i++) {
       pvals[i] = pvals[i] / tgroupExpSum;
       smemPtrFlt4[simd_lane_id + i * THREADS_PER_SIMDGROUP] = float4(0.f);
     }
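
Note: the three loops above are the classic numerically stable softmax over a tile — find the running maximum, exponentiate the shifted scores, normalize, and keep the log-sum-exp for later cross-tile merging. In LaTeX, with m standing for `groupMax` and s for `tgroupExpSum` (a restatement of the code, not new math):

m = \max_j x_j,\qquad
s = \sum_j e^{x_j - m},\qquad
p_j = \frac{e^{x_j - m}}{s},\qquad
\mathrm{lse} = \log s
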
@@ -187,15 +211,20 @@ template<typename T, typename T2, typename T4, uint16_t TILE_SIZE_CONST, uint16_
   threadgroup float* smemOpartial = (threadgroup float*)(smemV + totalSmemV);
 
   if (!LAST_TILE || LAST_TILE_ALIGNED) {
 #pragma clang loop unroll(full)
-    for(size_t col = 0; col < MATRIX_COLS; col++) {
+    for (size_t col = 0; col < MATRIX_COLS; col++) {
       uint matrix_load_loop_iter = 0;
       constexpr const size_t TILE_SIZE_CONST_DIV_8 = TILE_SIZE_CONST / 8;
 
-      for(size_t tile_start = simd_group_id; tile_start < TILE_SIZE_CONST_DIV_8; tile_start += NSIMDGROUPS) {
+      for (size_t tile_start = simd_group_id;
+           tile_start < TILE_SIZE_CONST_DIV_8;
+           tile_start += NSIMDGROUPS) {
         simdgroup_matrix<T, 8, 8> tmp;
-        ulong simdgroup_matrix_offset = matrix_load_loop_iter * NSIMDGROUPS * SIMDGROUP_MATRIX_LOAD_FACTOR + simd_group_id * SIMDGROUP_MATRIX_LOAD_FACTOR;
-        ulong2 matrixOrigin = ulong2(col * SIMDGROUP_MATRIX_LOAD_FACTOR, simdgroup_matrix_offset);
+        ulong simdgroup_matrix_offset =
+            matrix_load_loop_iter * NSIMDGROUPS * SIMDGROUP_MATRIX_LOAD_FACTOR +
+            simd_group_id * SIMDGROUP_MATRIX_LOAD_FACTOR;
+        ulong2 matrixOrigin =
+            ulong2(col * SIMDGROUP_MATRIX_LOAD_FACTOR, simdgroup_matrix_offset);
         simdgroup_load(tmp, baseV, DK, matrixOrigin, true);
         const ulong2 matrixOriginSmem = ulong2(simdgroup_matrix_offset, 0);
         const ulong elemsPerRowSmem = TILE_SIZE_CONST;
@@ -208,10 +237,12 @@ template<typename T, typename T2, typename T4, uint16_t TILE_SIZE_CONST, uint16_
       if (TILE_SIZE_CONST == 64) {
         T2 local_p_hat = T2(pvals[0].x, pvals[0].y);
         uint loop_iter = 0;
-        threadgroup float* oPartialSmem = smemOpartial + SIMDGROUP_MATRIX_LOAD_FACTOR * col;
+        threadgroup float* oPartialSmem =
+            smemOpartial + SIMDGROUP_MATRIX_LOAD_FACTOR * col;
 
 #pragma clang loop unroll(full)
-        for(size_t row = simd_group_id; row < SIMDGROUP_MATRIX_LOAD_FACTOR; row += NSIMDGROUPS) {
+        for (size_t row = simd_group_id; row < SIMDGROUP_MATRIX_LOAD_FACTOR;
+             row += NSIMDGROUPS) {
           threadgroup T* smemV_row = smemV + (TILE_SIZE_CONST * row);
           threadgroup T2* smemV2 = (threadgroup T2*)smemV_row;
           T2 v_local = *(smemV2 + simd_lane_id);
@@ -220,20 +251,24 @@ template<typename T, typename T2, typename T4, uint16_t TILE_SIZE_CONST, uint16_
           simdgroup_barrier(mem_flags::mem_none);
 
           T row_sum = simd_sum(val);
-          oPartialSmem[simd_group_id + loop_iter * NSIMDGROUPS] = float(row_sum);
+          oPartialSmem[simd_group_id + loop_iter * NSIMDGROUPS] =
+              float(row_sum);
           loop_iter++;
         }
       }
 
       if (TILE_SIZE_CONST > 64) {
-        constexpr const size_t TILE_SIZE_CONST_DIV_128 = (TILE_SIZE_CONST + 1) / 128;
-        threadgroup float* oPartialSmem = smemOpartial + SIMDGROUP_MATRIX_LOAD_FACTOR * col;
+        constexpr const size_t TILE_SIZE_CONST_DIV_128 =
+            (TILE_SIZE_CONST + 1) / 128;
+        threadgroup float* oPartialSmem =
+            smemOpartial + SIMDGROUP_MATRIX_LOAD_FACTOR * col;
         uint loop_iter = 0;
-        for(size_t row = simd_group_id; row < SIMDGROUP_MATRIX_LOAD_FACTOR; row += NSIMDGROUPS) {
+        for (size_t row = simd_group_id; row < SIMDGROUP_MATRIX_LOAD_FACTOR;
+             row += NSIMDGROUPS) {
          threadgroup T* smemV_row = smemV + (TILE_SIZE_CONST * row);
 
           T row_sum = 0.f;
-          for(size_t i = 0; i < TILE_SIZE_CONST_DIV_128; i++) {
+          for (size_t i = 0; i < TILE_SIZE_CONST_DIV_128; i++) {
             threadgroup T4* smemV2 = (threadgroup T4*)smemV_row;
             T4 v_local = *(smemV2 + simd_lane_id + i * THREADS_PER_SIMDGROUP);
             T4 p_local = T4(pvals[i]);
@@ -242,7 +277,8 @@ template<typename T, typename T2, typename T4, uint16_t TILE_SIZE_CONST, uint16_
           }
           simdgroup_barrier(mem_flags::mem_none);
           row_sum = simd_sum(row_sum);
-          oPartialSmem[simd_group_id + loop_iter * NSIMDGROUPS] = float(row_sum);
+          oPartialSmem[simd_group_id + loop_iter * NSIMDGROUPS] =
+              float(row_sum);
           loop_iter++;
         }
       }
@@ -252,35 +288,50 @@ template<typename T, typename T2, typename T4, uint16_t TILE_SIZE_CONST, uint16_
     const int32_t MAX_START_ROW = L - SIMDGROUP_MATRIX_LOAD_FACTOR + 1;
     const device T* baseVThisHead = V + v_batch_offset + v_head_offset;
     constexpr const int ROWS_PER_ITER = 8;
 #pragma clang loop unroll(full)
-    for(size_t col = 0; col < MATRIX_COLS; col++) {
+    for (size_t col = 0; col < MATRIX_COLS; col++) {
       uint smem_col_index = simd_group_id * SIMDGROUP_MATRIX_LOAD_FACTOR;
       int32_t tile_start;
-      for(tile_start = START_ROW + simd_group_id * SIMDGROUP_MATRIX_LOAD_FACTOR; tile_start < MAX_START_ROW; tile_start += NSIMDGROUPS * SIMDGROUP_MATRIX_LOAD_FACTOR) {
+      for (tile_start =
+               START_ROW + simd_group_id * SIMDGROUP_MATRIX_LOAD_FACTOR;
+           tile_start < MAX_START_ROW;
+           tile_start += NSIMDGROUPS * SIMDGROUP_MATRIX_LOAD_FACTOR) {
         simdgroup_matrix<T, 8, 8> tmp;
-        ulong2 matrixOrigin = ulong2(col * SIMDGROUP_MATRIX_LOAD_FACTOR, tile_start);
-        simdgroup_load(tmp, baseVThisHead, DK, matrixOrigin, /* transpose */ true);
+        ulong2 matrixOrigin =
+            ulong2(col * SIMDGROUP_MATRIX_LOAD_FACTOR, tile_start);
+        simdgroup_load(
+            tmp, baseVThisHead, DK, matrixOrigin, /* transpose */ true);
         const ulong2 matrixOriginSmem = ulong2(smem_col_index, 0);
         constexpr const ulong elemsPerRowSmem = TILE_SIZE_CONST;
-        simdgroup_store(tmp, smemV, elemsPerRowSmem, matrixOriginSmem, /* transpose */ false);
+        simdgroup_store(
+            tmp,
+            smemV,
+            elemsPerRowSmem,
+            matrixOriginSmem,
+            /* transpose */ false);
         smem_col_index += NSIMDGROUPS * SIMDGROUP_MATRIX_LOAD_FACTOR;
       };
 
-      tile_start = ((L / SIMDGROUP_MATRIX_LOAD_FACTOR) * SIMDGROUP_MATRIX_LOAD_FACTOR);
+      tile_start =
+          ((L / SIMDGROUP_MATRIX_LOAD_FACTOR) * SIMDGROUP_MATRIX_LOAD_FACTOR);
 
       const int32_t INT_L = int32_t(L);
-      for(int row_index = tile_start + simd_group_id ; row_index < INT_L; row_index += NSIMDGROUPS) {
-        if(simd_lane_id < SIMDGROUP_MATRIX_LOAD_FACTOR) {
+      for (int row_index = tile_start + simd_group_id; row_index < INT_L;
+           row_index += NSIMDGROUPS) {
+        if (simd_lane_id < SIMDGROUP_MATRIX_LOAD_FACTOR) {
          const uint elems_per_row_gmem = DK;
-          const uint col_index_v_gmem = col * SIMDGROUP_MATRIX_LOAD_FACTOR + simd_lane_id;
+          const uint col_index_v_gmem =
+              col * SIMDGROUP_MATRIX_LOAD_FACTOR + simd_lane_id;
           const uint row_index_v_gmem = row_index;
 
           const uint elems_per_row_smem = TILE_SIZE_CONST;
           const uint col_index_v_smem = row_index % TILE_SIZE_CONST;
           const uint row_index_v_smem = simd_lane_id;
 
-          const uint scalar_offset_gmem = row_index_v_gmem * elems_per_row_gmem + col_index_v_gmem;
-          const uint scalar_offset_smem = row_index_v_smem * elems_per_row_smem + col_index_v_smem;
+          const uint scalar_offset_gmem =
+              row_index_v_gmem * elems_per_row_gmem + col_index_v_gmem;
+          const uint scalar_offset_smem =
+              row_index_v_smem * elems_per_row_smem + col_index_v_smem;
           T vdata = T(*(baseVThisHead + scalar_offset_gmem));
           smemV[scalar_offset_smem] = vdata;
           smem_col_index += NSIMDGROUPS;
@@ -291,9 +342,11 @@ template<typename T, typename T2, typename T4, uint16_t TILE_SIZE_CONST, uint16_
 
       if (TILE_SIZE_CONST == 64) {
         T2 local_p_hat = T2(pvals[0].x, pvals[0].y);
-        threadgroup float* oPartialSmem = smemOpartial + SIMDGROUP_MATRIX_LOAD_FACTOR * col;
-        for(size_t smem_row_index = simd_group_id;
-            smem_row_index < ROWS_PER_ITER; smem_row_index += NSIMDGROUPS) {
+        threadgroup float* oPartialSmem =
+            smemOpartial + SIMDGROUP_MATRIX_LOAD_FACTOR * col;
+        for (size_t smem_row_index = simd_group_id;
+             smem_row_index < ROWS_PER_ITER;
+             smem_row_index += NSIMDGROUPS) {
           threadgroup T* smemV_row = smemV + (TILE_SIZE_CONST * smem_row_index);
           threadgroup T2* smemV2 = (threadgroup T2*)smemV_row;
           T2 v_local = *(smemV2 + simd_lane_id);
@@ -305,22 +358,25 @@ template<typename T, typename T2, typename T4, uint16_t TILE_SIZE_CONST, uint16_
       }
 
       if (TILE_SIZE_CONST > 64) {
-        threadgroup float* oPartialSmem = smemOpartial + SIMDGROUP_MATRIX_LOAD_FACTOR * col;
+        threadgroup float* oPartialSmem =
+            smemOpartial + SIMDGROUP_MATRIX_LOAD_FACTOR * col;
         uint loop_count = 0;
-        for(size_t row_index = simd_group_id;
-            row_index < ROWS_PER_ITER; row_index += NSIMDGROUPS) {
+        for (size_t row_index = simd_group_id; row_index < ROWS_PER_ITER;
+             row_index += NSIMDGROUPS) {
          T row_sum = 0.f;
-          for(size_t tile_iters = 0; tile_iters < TILE_SIZE_ITERS_128; tile_iters++) {
+          for (size_t tile_iters = 0; tile_iters < TILE_SIZE_ITERS_128;
+               tile_iters++) {
            threadgroup T* smemV_row = smemV + (TILE_SIZE_CONST * row_index);
            threadgroup T4* smemV2 = (threadgroup T4*)smemV_row;
-            T4 v_local = *(smemV2 + simd_lane_id + tile_iters * THREADS_PER_SIMDGROUP);
+            T4 v_local =
+                *(smemV2 + simd_lane_id + tile_iters * THREADS_PER_SIMDGROUP);
            T4 p_local = T4(pvals[tile_iters]);
            row_sum += dot(p_local, v_local);
 
           }
           simdgroup_barrier(mem_flags::mem_none);
           row_sum = simd_sum(row_sum);
-          oPartialSmem[simd_group_id + NSIMDGROUPS * loop_count] = float(row_sum);
+          oPartialSmem[simd_group_id + NSIMDGROUPS * loop_count] =
+              float(row_sum);
           loop_count++;
         }
       }
@@ -329,76 +385,121 @@ template<typename T, typename T2, typename T4, uint16_t TILE_SIZE_CONST, uint16_
 
   threadgroup_barrier(mem_flags::mem_threadgroup);
 
-  if(simd_group_id == 0) {
+  if (simd_group_id == 0) {
     threadgroup float4* oPartialVec4 = (threadgroup float4*)smemOpartial;
     float4 vals = *(oPartialVec4 + simd_lane_id);
-    device float* oPartialGmem = O_partials + tid.x * DK * params.KV_TILES + tid.y * DK;
+    device float* oPartialGmem =
+        O_partials + tid.x * DK * params.KV_TILES + tid.y * DK;
     device float4* oPartialGmemVec4 = (device float4*)oPartialGmem;
     oPartialGmemVec4[simd_lane_id] = vals;
   }
 
-  if(simd_group_id == 0 && simd_lane_id == 0) {
+  if (simd_group_id == 0 && simd_lane_id == 0) {
     const uint tileIndex = tid.y;
-    const uint gmem_partial_scalar_offset = tid.z * params.N_Q_HEADS * params.KV_TILES + tid.x * params.KV_TILES + tileIndex;
+    const uint gmem_partial_scalar_offset =
+        tid.z * params.N_Q_HEADS * params.KV_TILES + tid.x * params.KV_TILES +
+        tileIndex;
     p_lse[gmem_partial_scalar_offset] = lse;
     p_maxes[gmem_partial_scalar_offset] = groupMax;
   }
 }
 
-#define instantiate_fast_inference_sdpa_to_partials_kernel(itype, itype2, itype4, tile_size, nsimdgroups) \
-  template [[host_name("fast_inference_sdpa_compute_partials_" #itype "_" #tile_size "_" #nsimdgroups )]] \
-  [[kernel]] void fast_inference_sdpa_compute_partials_template<itype, itype2, itype4, tile_size, nsimdgroups>( \
-      const device itype *Q [[buffer(0)]], \
-      const device itype *K [[buffer(1)]], \
-      const device itype *V [[buffer(2)]], \
+#define instantiate_fast_inference_sdpa_to_partials_kernel( \
+    itype, itype2, itype4, tile_size, nsimdgroups) \
+  template [[host_name("fast_inference_sdpa_compute_partials_" #itype \
+                       "_" #tile_size "_" #nsimdgroups)]] [[kernel]] void \
+  fast_inference_sdpa_compute_partials_template< \
+      itype, \
+      itype2, \
+      itype4, \
+      tile_size, \
+      nsimdgroups>( \
+      const device itype* Q [[buffer(0)]], \
+      const device itype* K [[buffer(1)]], \
+      const device itype* V [[buffer(2)]], \
       const device uint64_t& L [[buffer(3)]], \
       const device MLXScaledDotProductAttentionParams& params [[buffer(4)]], \
       device float* O_partials [[buffer(5)]], \
       device float* p_lse [[buffer(6)]], \
       device float* p_maxes [[buffer(7)]], \
-      threadgroup itype *threadgroup_block [[threadgroup(0)]], \
+      threadgroup itype* threadgroup_block [[threadgroup(0)]], \
       uint simd_lane_id [[thread_index_in_simdgroup]], \
       uint simd_group_id [[simdgroup_index_in_threadgroup]], \
       uint3 tid [[threadgroup_position_in_grid]]);
 
+// clang-format off
+#define instantiate_fast_inference_sdpa_to_partials_shapes_helper( \
+    itype, itype2, itype4, tile_size) \
+  instantiate_fast_inference_sdpa_to_partials_kernel( \
+      itype, itype2, itype4, tile_size, 4) \
+  instantiate_fast_inference_sdpa_to_partials_kernel( \
+      itype, itype2, itype4, tile_size, 8) // clang-format on
 
-#define instantiate_fast_inference_sdpa_to_partials_shapes_helper(itype, itype2, itype4, tile_size) \
-  instantiate_fast_inference_sdpa_to_partials_kernel(itype, itype2, itype4, tile_size, 4) \
-  instantiate_fast_inference_sdpa_to_partials_kernel(itype, itype2, itype4, tile_size, 8) \
-
-instantiate_fast_inference_sdpa_to_partials_shapes_helper(float, float2, float4, 64);
-instantiate_fast_inference_sdpa_to_partials_shapes_helper(float, float2, float4, 128);
-instantiate_fast_inference_sdpa_to_partials_shapes_helper(float, float2, float4, 256);
-instantiate_fast_inference_sdpa_to_partials_shapes_helper(float, float2, float4, 512);
+instantiate_fast_inference_sdpa_to_partials_shapes_helper(
+    float,
+    float2,
+    float4,
+    64);
+instantiate_fast_inference_sdpa_to_partials_shapes_helper(
+    float,
+    float2,
+    float4,
+    128);
+instantiate_fast_inference_sdpa_to_partials_shapes_helper(
+    float,
+    float2,
+    float4,
+    256);
+instantiate_fast_inference_sdpa_to_partials_shapes_helper(
+    float,
+    float2,
+    float4,
+    512);
 
-instantiate_fast_inference_sdpa_to_partials_shapes_helper(half, half2, half4, 64);
-instantiate_fast_inference_sdpa_to_partials_shapes_helper(half, half2, half4, 128);
-instantiate_fast_inference_sdpa_to_partials_shapes_helper(half, half2, half4, 256);
-instantiate_fast_inference_sdpa_to_partials_shapes_helper(half, half2, half4, 512);
+instantiate_fast_inference_sdpa_to_partials_shapes_helper(
+    half,
+    half2,
+    half4,
+    64);
+instantiate_fast_inference_sdpa_to_partials_shapes_helper(
+    half,
+    half2,
+    half4,
+    128);
+instantiate_fast_inference_sdpa_to_partials_shapes_helper(
+    half,
+    half2,
+    half4,
+    256);
+instantiate_fast_inference_sdpa_to_partials_shapes_helper(
+    half,
+    half2,
+    half4,
+    512);
 
 template <typename T>
 void fast_inference_sdpa_reduce_tiles_template(
-    const device float *O_partials [[buffer(0)]],
-    const device float *p_lse[[buffer(1)]],
-    const device float *p_maxes [[buffer(2)]],
+    const device float* O_partials [[buffer(0)]],
+    const device float* p_lse [[buffer(1)]],
+    const device float* p_maxes [[buffer(2)]],
     const device MLXScaledDotProductAttentionParams& params [[buffer(3)]],
     device T* O [[buffer(4)]],
     uint3 tid [[threadgroup_position_in_grid]],
     uint3 lid [[thread_position_in_threadgroup]]) {
 
   constexpr const int DK = 128;
-  const ulong offset_rows = tid.z * params.KV_TILES * params.N_Q_HEADS + tid.x * params.KV_TILES;
+  const ulong offset_rows =
+      tid.z * params.KV_TILES * params.N_Q_HEADS + tid.x * params.KV_TILES;
   const device float* p_lse_row = p_lse + offset_rows;
   const device float* p_rowmax_row = p_maxes + offset_rows;
-  // reserve some number of registers. this constitutes an assumption on max value of KV TILES.
+  // reserve some number of registers. this constitutes an assumption on max
  // value of KV TILES.
   constexpr const uint8_t reserve = 128;
   float p_lse_regs[reserve];
   float p_rowmax_regs[reserve];
   float weights[reserve];
 
   float true_max = -INFINITY;
-  for(size_t i = 0; i < params.KV_TILES; i++) {
+  for (size_t i = 0; i < params.KV_TILES; i++) {
     p_lse_regs[i] = float(*(p_lse_row + i));
     p_rowmax_regs[i] = float(*(p_rowmax_row + i));
     true_max = fmax(p_rowmax_regs[i], true_max);
@@ -406,15 +507,17 @@ void fast_inference_sdpa_reduce_tiles_template(
   }
 
   float denom = 0.f;
-  for(size_t i = 0; i < params.KV_TILES; i++) {
-    weights[i] *= exp(p_rowmax_regs[i]-true_max);
+  for (size_t i = 0; i < params.KV_TILES; i++) {
+    weights[i] *= exp(p_rowmax_regs[i] - true_max);
     denom += weights[i];
   }
 
-  const device float* O_partials_with_offset = O_partials + tid.z * params.N_Q_HEADS * DK * params.KV_TILES + tid.x * DK * params.KV_TILES;
+  const device float* O_partials_with_offset = O_partials +
+      tid.z * params.N_Q_HEADS * DK * params.KV_TILES +
+      tid.x * DK * params.KV_TILES;
 
   float o_value = 0.f;
-  for(size_t i = 0; i < params.KV_TILES; i++) {
+  for (size_t i = 0; i < params.KV_TILES; i++) {
     float val = *(O_partials_with_offset + i * DK + lid.x);
     o_value += val * weights[i] / denom;
   }
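
Note: `reduce_tiles` recombines the per-tile partial outputs with softmax weights derived from each tile's stored max and log-sum-exp — the flash-decoding style merge. A host-side sketch of the weighting (plain C++ with illustrative data; the initialization of `weights` from the lse values happens in a part of the kernel not shown in these hunks, so the sketch assumes weights start as exp(lse_i)):

#include <cmath>
#include <cstdio>

// Combine per-tile partial attention outputs using each tile's row max m_i
// and lse_i = log(sum exp(scores - m_i)).
int main() {
  const int kv_tiles = 3;
  float m[kv_tiles] = {1.0f, 2.5f, 0.5f};         // per-tile row maxima
  float lse[kv_tiles] = {0.7f, 1.1f, 0.3f};       // per-tile log-sum-exp
  float o_partial[kv_tiles] = {0.2f, 0.4f, 0.1f}; // per-tile partial outputs

  float true_max = -INFINITY;
  for (int i = 0; i < kv_tiles; i++)
    true_max = std::fmax(m[i], true_max);

  // weight_i = exp(lse_i) * exp(m_i - true_max); denom restores the global
  // softmax normalization across tiles.
  float denom = 0.f, weights[kv_tiles];
  for (int i = 0; i < kv_tiles; i++) {
    weights[i] = std::exp(lse[i]) * std::exp(m[i] - true_max);
    denom += weights[i];
  }
  float o = 0.f;
  for (int i = 0; i < kv_tiles; i++)
    o += o_partial[i] * weights[i] / denom;
  printf("combined output: %f\n", o);
  return 0;
}
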
@@ -423,29 +526,26 @@ void fast_inference_sdpa_reduce_tiles_template(
   return;
 }
 
 
 kernel void fast_inference_sdpa_reduce_tiles_float(
-    const device float *O_partials [[buffer(0)]],
-    const device float *p_lse[[buffer(1)]],
-    const device float *p_maxes [[buffer(2)]],
+    const device float* O_partials [[buffer(0)]],
+    const device float* p_lse [[buffer(1)]],
+    const device float* p_maxes [[buffer(2)]],
     const device MLXScaledDotProductAttentionParams& params [[buffer(3)]],
     device float* O [[buffer(4)]],
     uint3 tid [[threadgroup_position_in_grid]],
-    uint3 lid [[thread_position_in_threadgroup]])
-{
-  fast_inference_sdpa_reduce_tiles_template<float>(O_partials, p_lse, p_maxes, params,
-      O, tid, lid);
+    uint3 lid [[thread_position_in_threadgroup]]) {
+  fast_inference_sdpa_reduce_tiles_template<float>(
+      O_partials, p_lse, p_maxes, params, O, tid, lid);
 }
 
 kernel void fast_inference_sdpa_reduce_tiles_half(
-    const device float *O_partials [[buffer(0)]],
-    const device float *p_lse[[buffer(1)]],
-    const device float *p_maxes [[buffer(2)]],
+    const device float* O_partials [[buffer(0)]],
+    const device float* p_lse [[buffer(1)]],
+    const device float* p_maxes [[buffer(2)]],
     const device MLXScaledDotProductAttentionParams& params [[buffer(3)]],
     device half* O [[buffer(4)]],
     uint3 tid [[threadgroup_position_in_grid]],
-    uint3 lid [[thread_position_in_threadgroup]])
-{
-  fast_inference_sdpa_reduce_tiles_template<half>(O_partials, p_lse, p_maxes, params,
-      O, tid, lid);
+    uint3 lid [[thread_position_in_threadgroup]]) {
+  fast_inference_sdpa_reduce_tiles_template<half>(
+      O_partials, p_lse, p_maxes, params, O, tid, lid);
 }
@@ -54,7 +54,7 @@ struct CumProd<bool> {
   }
 
   bool simd_scan(bool x) {
-    for (int i=1; i<=16; i*=2) {
+    for (int i = 1; i <= 16; i *= 2) {
       bool other = simd_shuffle_up(x, i);
       x &= other;
     }
@@ -77,7 +77,7 @@ struct CumMax {
   }
 
   U simd_scan(U x) {
-    for (int i=1; i<=16; i*=2) {
+    for (int i = 1; i <= 16; i *= 2) {
       U other = simd_shuffle_up(x, i);
       x = (x >= other) ? x : other;
     }
@@ -100,7 +100,7 @@ struct CumMin {
   }
 
   U simd_scan(U x) {
-    for (int i=1; i<=16; i*=2) {
+    for (int i = 1; i <= 16; i *= 2) {
       U other = simd_shuffle_up(x, i);
       x = (x <= other) ? x : other;
     }
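
Note: these `simd_scan` loops are a Hillis-Steele inclusive scan across the 32 lanes of a simdgroup: at step `i`, each lane combines with the value `i` lanes below it, doubling `i` until every lane holds the scan of lanes 0..lane. A host emulation of the pattern (plain C++; array indices stand in for lanes, and the boundary handling is a simplification of Metal's shuffle semantics):

#include <cstdio>
#include <vector>

// Emulate a simd_shuffle_up-based inclusive scan (sum) over 32 lanes.
int main() {
  constexpr int simd_size = 32;
  std::vector<int> lanes(simd_size, 1); // scan of all-ones -> 1, 2, 3, ...

  for (int i = 1; i <= simd_size / 2; i *= 2) {
    std::vector<int> prev = lanes; // all lanes read before any lane writes
    for (int l = 0; l < simd_size; l++) {
      // shuffle_up(x, i): lane l reads lane l - i; lanes below i keep their
      // own value here so the prefix stays correct.
      if (l - i >= 0)
        lanes[l] = prev[l] + prev[l - i];
    }
  }
  printf("lane 31 holds %d\n", lanes[simd_size - 1]); // 32
  return 0;
}
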
@@ -114,54 +114,60 @@ struct CumMin {
 };
 
 template <typename T, typename U, int N_READS, bool reverse>
-inline void load_unsafe(U values[N_READS], const device T * input) {
+inline void load_unsafe(U values[N_READS], const device T* input) {
   if (reverse) {
-    for (int i=0; i<N_READS; i++) {
-      values[N_READS-i-1] = input[i];
+    for (int i = 0; i < N_READS; i++) {
+      values[N_READS - i - 1] = input[i];
     }
   } else {
-    for (int i=0; i<N_READS; i++) {
+    for (int i = 0; i < N_READS; i++) {
       values[i] = input[i];
     }
   }
 }
 
 template <typename T, typename U, int N_READS, bool reverse>
-inline void load_safe(U values[N_READS], const device T * input, int start, int total, U init) {
+inline void load_safe(
+    U values[N_READS],
+    const device T* input,
+    int start,
+    int total,
+    U init) {
   if (reverse) {
-    for (int i=0; i<N_READS; i++) {
-      values[N_READS-i-1] = (start + N_READS - i - 1 < total) ? input[i] : init;
+    for (int i = 0; i < N_READS; i++) {
+      values[N_READS - i - 1] =
+          (start + N_READS - i - 1 < total) ? input[i] : init;
     }
   } else {
-    for (int i=0; i<N_READS; i++) {
+    for (int i = 0; i < N_READS; i++) {
       values[i] = (start + i < total) ? input[i] : init;
     }
   }
 }
 
 template <typename U, int N_READS, bool reverse>
-inline void write_unsafe(U values[N_READS], device U * out) {
+inline void write_unsafe(U values[N_READS], device U* out) {
   if (reverse) {
-    for (int i=0; i<N_READS; i++) {
-      out[i] = values[N_READS-i-1];
+    for (int i = 0; i < N_READS; i++) {
+      out[i] = values[N_READS - i - 1];
     }
   } else {
-    for (int i=0; i<N_READS; i++) {
+    for (int i = 0; i < N_READS; i++) {
       out[i] = values[i];
     }
   }
 }
 
 template <typename U, int N_READS, bool reverse>
-inline void write_safe(U values[N_READS], device U * out, int start, int total) {
+inline void write_safe(U values[N_READS], device U* out, int start, int total) {
   if (reverse) {
-    for (int i=0; i<N_READS; i++) {
+    for (int i = 0; i < N_READS; i++) {
       if (start + N_READS - i - 1 < total) {
-        out[i] = values[N_READS-i-1];
+        out[i] = values[N_READS - i - 1];
       }
     }
   } else {
-    for (int i=0; i<N_READS; i++) {
+    for (int i = 0; i < N_READS; i++) {
       if (start + i < total) {
         out[i] = values[i];
       }
@@ -169,12 +175,17 @@ inline void write_safe(U values[N_READS], device U * out, int start, int total)
     }
   }
 }
 
-template <typename T, typename U, typename Op, int N_READS, bool inclusive, bool reverse>
+template <
+    typename T,
+    typename U,
+    typename Op,
+    int N_READS,
+    bool inclusive,
+    bool reverse>
 [[kernel]] void contiguous_scan(
     const device T* in [[buffer(0)]],
     device U* out [[buffer(1)]],
-    const constant size_t & axis_size [[buffer(2)]],
+    const constant size_t& axis_size [[buffer(2)]],
     uint gid [[thread_position_in_grid]],
     uint lid [[thread_position_in_threadgroup]],
     uint lsize [[threads_per_threadgroup]],
@@ -195,42 +206,51 @@ template <typename T, typename U, typename Op, int N_READS, bool inclusive, bool
   U values[N_READS];
   threadgroup U simdgroup_sums[32];
 
-  // Loop over the reduced axis in blocks of size ceildiv(axis_size, N_READS*lsize)
+  // Loop over the reduced axis in blocks of size ceildiv(axis_size,
+  // N_READS*lsize)
   //    Read block
   //    Compute inclusive scan of the block
   //      Compute inclusive scan per thread
   //      Compute exclusive scan of thread sums in simdgroup
   //      Write simdgroup sums in SM
   //      Compute exclusive scan of simdgroup sums
-  //      Compute the output by scanning prefix, prev_simdgroup, prev_thread, value
+  //      Compute the output by scanning prefix, prev_simdgroup, prev_thread,
+  //      value
   //    Write block
 
-  for (uint r = 0; r < ceildiv(axis_size, N_READS*lsize); r++) {
+  for (uint r = 0; r < ceildiv(axis_size, N_READS * lsize); r++) {
     // Compute the block offset
-    uint offset = r*lsize*N_READS + lid*N_READS;
+    uint offset = r * lsize * N_READS + lid * N_READS;
 
     // Read the values
     if (reverse) {
       if ((offset + N_READS) < axis_size) {
-        load_unsafe<T, U, N_READS, reverse>(values, in + axis_size - offset - N_READS);
+        load_unsafe<T, U, N_READS, reverse>(
+            values, in + axis_size - offset - N_READS);
       } else {
-        load_safe<T, U, N_READS, reverse>(values, in + axis_size - offset - N_READS, offset, axis_size, Op::init);
+        load_safe<T, U, N_READS, reverse>(
+            values,
+            in + axis_size - offset - N_READS,
+            offset,
+            axis_size,
+            Op::init);
       }
     } else {
       if ((offset + N_READS) < axis_size) {
         load_unsafe<T, U, N_READS, reverse>(values, in + offset);
       } else {
-        load_safe<T, U, N_READS, reverse>(values, in + offset, offset, axis_size, Op::init);
|
load_safe<T, U, N_READS, reverse>(
|
||||||
|
values, in + offset, offset, axis_size, Op::init);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Compute an inclusive scan per thread
|
// Compute an inclusive scan per thread
|
||||||
for (int i=1; i<N_READS; i++) {
|
for (int i = 1; i < N_READS; i++) {
|
||||||
values[i] = op(values[i], values[i-1]);
|
values[i] = op(values[i], values[i - 1]);
|
||||||
}
|
}
|
||||||
|
|
||||||
// Compute exclusive scan of thread sums
|
// Compute exclusive scan of thread sums
|
||||||
U prev_thread = op.simd_exclusive_scan(values[N_READS-1]);
|
U prev_thread = op.simd_exclusive_scan(values[N_READS - 1]);
|
||||||
|
|
||||||
// Write simdgroup_sums to SM
|
// Write simdgroup_sums to SM
|
||||||
if (simd_lane_id == simd_size - 1) {
|
if (simd_lane_id == simd_size - 1) {
|
||||||
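A note on the structure the formatter is reflowing above: contiguous_scan composes three scan levels — an inclusive scan of each thread's N_READS values in registers, a simdgroup exclusive scan of the per-thread totals (simd_exclusive_scan), and a threadgroup combine through simdgroup_sums — then folds in the running prefix carried across blocks. A minimal host-side C++ sketch of that composition, assuming addition as the op; the sizes and names (n_threads, n_reads) are illustrative only, not part of the kernel:

    #include <cstdio>
    #include <vector>

    // Serial model of the three-level scan composition, assuming Op = sum.
    int main() {
      const int n_threads = 4, n_reads = 4;
      std::vector<int> in(n_threads * n_reads, 1); // all ones
      std::vector<int> out(in.size());

      // 1) Inclusive scan per "thread" over its N_READS chunk.
      std::vector<int> thread_sums(n_threads);
      for (int t = 0; t < n_threads; t++) {
        int acc = 0;
        for (int i = 0; i < n_reads; i++) {
          acc += in[t * n_reads + i];
          out[t * n_reads + i] = acc;
        }
        thread_sums[t] = acc;
      }
      // 2) Exclusive scan of the per-thread totals (the simdgroup step).
      std::vector<int> prev_thread(n_threads, 0);
      for (int t = 1; t < n_threads; t++)
        prev_thread[t] = prev_thread[t - 1] + thread_sums[t - 1];
      // 3) Fold each thread's exclusive prefix back into its chunk.
      for (int t = 0; t < n_threads; t++)
        for (int i = 0; i < n_reads; i++)
          out[t * n_reads + i] += prev_thread[t];

      for (int v : out)
        printf("%d ", v); // 1 2 3 ... 16
      printf("\n");
    }

Keeping the per-thread scan serial in registers is what makes the cross-thread portion short: only one value per thread enters the simdgroup scan.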
@@ -246,7 +266,7 @@ template <typename T, typename U, typename Op, int N_READS, bool inclusive, bool
     threadgroup_barrier(mem_flags::mem_threadgroup);

     // Compute the output
-    for (int i=0; i<N_READS; i++) {
+    for (int i = 0; i < N_READS; i++) {
       values[i] = op(values[i], prefix);
       values[i] = op(values[i], simdgroup_sums[simd_group_id]);
       values[i] = op(values[i], prev_thread);
@@ -256,18 +276,25 @@ template <typename T, typename U, typename Op, int N_READS, bool inclusive, bool
     if (reverse) {
       if (inclusive) {
         if ((offset + N_READS) < axis_size) {
-          write_unsafe<U, N_READS, reverse>(values, out + axis_size - offset - N_READS);
+          write_unsafe<U, N_READS, reverse>(
+              values, out + axis_size - offset - N_READS);
         } else {
-          write_safe<U, N_READS, reverse>(values, out + axis_size - offset - N_READS, offset, axis_size);
+          write_safe<U, N_READS, reverse>(
+              values, out + axis_size - offset - N_READS, offset, axis_size);
         }
       } else {
         if (lid == 0 && offset == 0) {
-          out[axis_size-1] = Op::init;
+          out[axis_size - 1] = Op::init;
         }
         if ((offset + N_READS + 1) < axis_size) {
-          write_unsafe<U, N_READS, reverse>(values, out + axis_size - offset - 1 - N_READS);
+          write_unsafe<U, N_READS, reverse>(
+              values, out + axis_size - offset - 1 - N_READS);
         } else {
-          write_safe<U, N_READS, reverse>(values, out + axis_size - offset - 1 - N_READS, offset + 1, axis_size);
+          write_safe<U, N_READS, reverse>(
+              values,
+              out + axis_size - offset - 1 - N_READS,
+              offset + 1,
+              axis_size);
         }
       }
     } else {
@@ -275,7 +302,8 @@ template <typename T, typename U, typename Op, int N_READS, bool inclusive, bool
       if ((offset + N_READS) < axis_size) {
         write_unsafe<U, N_READS, reverse>(values, out + offset);
       } else {
-        write_safe<U, N_READS, reverse>(values, out + offset, offset, axis_size);
+        write_safe<U, N_READS, reverse>(
+            values, out + offset, offset, axis_size);
       }
     } else {
       if (lid == 0 && offset == 0) {
@@ -284,26 +312,33 @@ template <typename T, typename U, typename Op, int N_READS, bool inclusive, bool
       if ((offset + N_READS + 1) < axis_size) {
         write_unsafe<U, N_READS, reverse>(values, out + offset + 1);
       } else {
-        write_safe<U, N_READS, reverse>(values, out + offset + 1, offset + 1, axis_size);
+        write_safe<U, N_READS, reverse>(
+            values, out + offset + 1, offset + 1, axis_size);
       }
     }
   }

   // Share the prefix
   if (simd_group_id == simd_groups - 1 && simd_lane_id == simd_size - 1) {
-    simdgroup_sums[0] = values[N_READS-1];
+    simdgroup_sums[0] = values[N_READS - 1];
   }
   threadgroup_barrier(mem_flags::mem_threadgroup);
   prefix = simdgroup_sums[0];
  }
 }

-template <typename T, typename U, typename Op, int N_READS, bool inclusive, bool reverse>
+template <
+    typename T,
+    typename U,
+    typename Op,
+    int N_READS,
+    bool inclusive,
+    bool reverse>
 [[kernel]] void strided_scan(
     const device T* in [[buffer(0)]],
     device U* out [[buffer(1)]],
-    const constant size_t & axis_size [[buffer(2)]],
-    const constant size_t & stride [[buffer(3)]],
+    const constant size_t& axis_size [[buffer(2)]],
+    const constant size_t& stride [[buffer(3)]],
     uint2 gid [[threadgroup_position_in_grid]],
     uint2 lid [[thread_position_in_threadgroup]],
     uint2 lsize [[threads_per_threadgroup]],
@@ -311,10 +346,10 @@ template <typename T, typename U, typename Op, int N_READS, bool inclusive, bool
   Op op;

   // Allocate memory
-  threadgroup U read_buffer[N_READS*32*32 + N_READS*32];
+  threadgroup U read_buffer[N_READS * 32 * 32 + N_READS * 32];
   U values[N_READS];
   U prefix[N_READS];
-  for (int i=0; i<N_READS; i++) {
+  for (int i = 0; i < N_READS; i++) {
     prefix[i] = Op::init;
   }

@@ -322,7 +357,7 @@ template <typename T, typename U, typename Op, int N_READS, bool inclusive, bool
   int offset = gid.y * axis_size * stride;
   int global_index_x = gid.x * lsize.y * N_READS;

-  for (uint j=0; j<axis_size; j+=simd_size) {
+  for (uint j = 0; j < axis_size; j += simd_size) {
     // Calculate the indices for the current thread
     uint index_y = j + lid.y;
     uint check_index_y = index_y;
@@ -333,37 +368,43 @@ template <typename T, typename U, typename Op, int N_READS, bool inclusive, bool

     // Read in SM
     if (check_index_y < axis_size && (index_x + N_READS) < stride) {
-      for (int i=0; i<N_READS; i++) {
-        read_buffer[lid.y * simd_size * N_READS + lid.x * N_READS + i] = in[offset + index_y * stride + index_x + i];
+      for (int i = 0; i < N_READS; i++) {
+        read_buffer[lid.y * simd_size * N_READS + lid.x * N_READS + i] =
+            in[offset + index_y * stride + index_x + i];
       }
     } else {
-      for (int i=0; i<N_READS; i++) {
+      for (int i = 0; i < N_READS; i++) {
         if (check_index_y < axis_size && (index_x + i) < stride) {
-          read_buffer[lid.y * simd_size * N_READS + lid.x * N_READS + i] = in[offset + index_y * stride + index_x + i];
+          read_buffer[lid.y * simd_size * N_READS + lid.x * N_READS + i] =
+              in[offset + index_y * stride + index_x + i];
         } else {
-          read_buffer[lid.y * simd_size * N_READS + lid.x * N_READS + i] = Op::init;
+          read_buffer[lid.y * simd_size * N_READS + lid.x * N_READS + i] =
+              Op::init;
         }
       }
     }
     threadgroup_barrier(mem_flags::mem_threadgroup);

     // Read strided into registers
-    for (int i=0; i<N_READS; i++) {
-      values[i] = read_buffer[lid.x * simd_size * N_READS + lid.y * N_READS + i];
+    for (int i = 0; i < N_READS; i++) {
+      values[i] =
+          read_buffer[lid.x * simd_size * N_READS + lid.y * N_READS + i];
     }
-    // Do we need the following barrier? Shouldn't all simd threads execute simultaneously?
+    // Do we need the following barrier? Shouldn't all simd threads execute
+    // simultaneously?
     simdgroup_barrier(mem_flags::mem_threadgroup);

     // Perform the scan
-    for (int i=0; i<N_READS; i++) {
+    for (int i = 0; i < N_READS; i++) {
       values[i] = op.simd_scan(values[i]);
       values[i] = op(values[i], prefix[i]);
-      prefix[i] = simd_shuffle(values[i], simd_size-1);
+      prefix[i] = simd_shuffle(values[i], simd_size - 1);
     }

     // Write to SM
-    for (int i=0; i<N_READS; i++) {
-      read_buffer[lid.x * simd_size * N_READS + lid.y * N_READS + i] = values[i];
+    for (int i = 0; i < N_READS; i++) {
+      read_buffer[lid.x * simd_size * N_READS + lid.y * N_READS + i] =
+          values[i];
     }
     threadgroup_barrier(mem_flags::mem_threadgroup);

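The strided kernel above scans a non-contiguous axis by staging a tile in threadgroup memory with swapped roles: rows are read coalesced into read_buffer, each thread then pulls a strided column into registers, scans it with simd_scan, and carries the tile's last element forward as prefix via simd_shuffle. A small C++ sketch of just the per-column carry between tiles, assuming addition; tile stands in for simd_size and the data is illustrative:

    #include <cstdio>

    // Model of the tile-to-tile carry: scan each tile of `tile` rows,
    // then shuffle the last element forward as the next tile's prefix.
    int main() {
      const int axis_size = 10, tile = 4;
      int column[axis_size];
      for (int i = 0; i < axis_size; i++)
        column[i] = i + 1;

      int prefix = 0; // Op::init for a sum scan
      for (int j = 0; j < axis_size; j += tile) {
        int acc = 0;
        for (int i = j; i < j + tile && i < axis_size; i++) {
          acc += column[i];         // simd_scan within the tile
          column[i] = acc + prefix; // fold in the carry from earlier tiles
        }
        // like prefix[i] = simd_shuffle(values[i], simd_size - 1)
        prefix = column[(j + tile < axis_size ? j + tile : axis_size) - 1];
      }
      for (int i = 0; i < axis_size; i++)
        printf("%d ", column[i]); // 1 3 6 10 15 21 28 36 45 55
      printf("\n");
    }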
@@ -371,11 +412,11 @@ template <typename T, typename U, typename Op, int N_READS, bool inclusive, bool
     if (!inclusive) {
       if (check_index_y == 0) {
         if ((index_x + N_READS) < stride) {
-          for (int i=0; i<N_READS; i++) {
+          for (int i = 0; i < N_READS; i++) {
             out[offset + index_y * stride + index_x + i] = Op::init;
           }
         } else {
-          for (int i=0; i<N_READS; i++) {
+          for (int i = 0; i < N_READS; i++) {
             if ((index_x + i) < stride) {
               out[offset + index_y * stride + index_x + i] = Op::init;
             }
@@ -391,25 +432,28 @@ template <typename T, typename U, typename Op, int N_READS, bool inclusive, bool
       }
     }
     if (check_index_y < axis_size && (index_x + N_READS) < stride) {
-      for (int i=0; i<N_READS; i++) {
-        out[offset + index_y * stride + index_x + i] = read_buffer[lid.y * simd_size * N_READS + lid.x * N_READS + i];
+      for (int i = 0; i < N_READS; i++) {
+        out[offset + index_y * stride + index_x + i] =
+            read_buffer[lid.y * simd_size * N_READS + lid.x * N_READS + i];
       }
     } else {
-      for (int i=0; i<N_READS; i++) {
+      for (int i = 0; i < N_READS; i++) {
        if (check_index_y < axis_size && (index_x + i) < stride) {
-          out[offset + index_y * stride + index_x + i] = read_buffer[lid.y * simd_size * N_READS + lid.x * N_READS + i];
+          out[offset + index_y * stride + index_x + i] =
+              read_buffer[lid.y * simd_size * N_READS + lid.x * N_READS + i];
        }
      }
    }
   }
 }

-#define instantiate_contiguous_scan(name, itype, otype, op, inclusive, reverse, nreads) \
-  template [[host_name("contiguous_scan_" #name)]] \
-  [[kernel]] void contiguous_scan<itype, otype, op<otype>, nreads, inclusive, reverse>( \
+#define instantiate_contiguous_scan( \
+    name, itype, otype, op, inclusive, reverse, nreads) \
+  template [[host_name("contiguous_scan_" #name)]] [[kernel]] void \
+  contiguous_scan<itype, otype, op<otype>, nreads, inclusive, reverse>( \
       const device itype* in [[buffer(0)]], \
       device otype* out [[buffer(1)]], \
-      const constant size_t & axis_size [[buffer(2)]], \
+      const constant size_t& axis_size [[buffer(2)]], \
       uint gid [[thread_position_in_grid]], \
       uint lid [[thread_position_in_threadgroup]], \
       uint lsize [[threads_per_threadgroup]], \
@@ -417,19 +461,20 @@ template <typename T, typename U, typename Op, int N_READS, bool inclusive, bool
       uint simd_lane_id [[thread_index_in_simdgroup]], \
       uint simd_group_id [[simdgroup_index_in_threadgroup]]);

-#define instantiate_strided_scan(name, itype, otype, op, inclusive, reverse, nreads) \
-  template [[host_name("strided_scan_" #name)]] \
-  [[kernel]] void strided_scan<itype, otype, op<otype>, nreads, inclusive, reverse>( \
+#define instantiate_strided_scan( \
+    name, itype, otype, op, inclusive, reverse, nreads) \
+  template [[host_name("strided_scan_" #name)]] [[kernel]] void \
+  strided_scan<itype, otype, op<otype>, nreads, inclusive, reverse>( \
       const device itype* in [[buffer(0)]], \
       device otype* out [[buffer(1)]], \
-      const constant size_t & axis_size [[buffer(2)]], \
-      const constant size_t & stride [[buffer(3)]], \
+      const constant size_t& axis_size [[buffer(2)]], \
+      const constant size_t& stride [[buffer(3)]], \
       uint2 gid [[thread_position_in_grid]], \
       uint2 lid [[thread_position_in_threadgroup]], \
       uint2 lsize [[threads_per_threadgroup]], \
       uint simd_size [[threads_per_simdgroup]]);

+// clang-format off
 #define instantiate_scan_helper(name, itype, otype, op, nreads) \
   instantiate_contiguous_scan(inclusive_##name, itype, otype, op, true, false, nreads) \
   instantiate_contiguous_scan(exclusive_##name, itype, otype, op, false, false, nreads) \
@@ -438,8 +483,9 @@ template <typename T, typename U, typename Op, int N_READS, bool inclusive, bool
   instantiate_strided_scan(inclusive_##name, itype, otype, op, true, false, nreads) \
   instantiate_strided_scan(exclusive_##name, itype, otype, op, false, false, nreads) \
   instantiate_strided_scan(reverse_inclusive_##name, itype, otype, op, true, true, nreads) \
-  instantiate_strided_scan(reverse_exclusive_##name, itype, otype, op, false, true, nreads)
+  instantiate_strided_scan(reverse_exclusive_##name, itype, otype, op, false, true, nreads) // clang-format on

+// clang-format off
 instantiate_scan_helper(sum_bool__int32, bool, int32_t, CumSum, 4)
 instantiate_scan_helper(sum_uint8_uint8, uint8_t, uint8_t, CumSum, 4)
 instantiate_scan_helper(sum_uint16_uint16, uint16_t, uint16_t, CumSum, 4)
@@ -491,4 +537,4 @@ instantiate_scan_helper(min_int32_int32, int32_t, int32_t, CumMi
 instantiate_scan_helper(min_float16_float16, half, half, CumMin, 4)
 instantiate_scan_helper(min_float32_float32, float, float, CumMin, 4)
 instantiate_scan_helper(min_bfloat16_bfloat16, bfloat16_t, bfloat16_t, CumMin, 4)
-//instantiate_scan_helper(min_complex64_complex64, complex64_t, complex64_t, CumMin)
+//instantiate_scan_helper(min_complex64_complex64, complex64_t, complex64_t, CumMin) // clang-format on
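For readers unfamiliar with the instantiation pattern the formatter keeps reflowing above: each instantiate_* macro stamps out an explicit template instantiation tagged with a [[host_name(...)]] string so the host side can look the compiled kernel up by name, one per (op, dtype, variant) combination. A reduced C++ sketch of the same macro-stamping trick; [[host_name]] and [[kernel]] are Metal-only, so a plain function stands in, and all names here are illustrative:

    #include <cstdio>

    // One generic "kernel" plus a macro that forces one explicit
    // instantiation per type, mirroring instantiate_contiguous_scan.
    template <typename T>
    void scan_kernel(const T* in, T* out, int n) {
      T acc = T(0);
      for (int i = 0; i < n; i++)
        out[i] = (acc += in[i]); // inclusive sum scan
    }

    #define instantiate_scan(tname, type) \
      template void scan_kernel<type>(const type*, type*, int);

    instantiate_scan(float32, float)
    instantiate_scan(int32, int)

    int main() {
      int in[4] = {1, 2, 3, 4}, out[4];
      scan_kernel(in, out, 4);
      printf("%d %d %d %d\n", out[0], out[1], out[2], out[3]); // 1 3 6 10
    }

The // clang-format off / on markers added by this commit exist because these macro invocation lists are deliberately laid out one-per-line and the formatter would otherwise repack them.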
@@ -13,22 +13,20 @@ using namespace metal;
 // Scatter kernel
 /////////////////////////////////////////////////////////////////////

-template <typename T, typename IdxT, typename Op, int NIDX> \
+template <typename T, typename IdxT, typename Op, int NIDX>
 METAL_FUNC void scatter_1d_index_impl(
-    const device T *updates [[buffer(1)]],
-    device mlx_atomic<T> *out [[buffer(2)]],
+    const device T* updates [[buffer(1)]],
+    device mlx_atomic<T>* out [[buffer(2)]],
     const constant int* out_shape [[buffer(3)]],
     const constant size_t* out_strides [[buffer(4)]],
     const constant size_t& upd_size [[buffer(5)]],
     const thread array<const device IdxT*, NIDX>& idx_buffers,
     uint2 gid [[thread_position_in_grid]]) {
-
   Op op;

   uint out_idx = 0;
   for (int i = 0; i < NIDX; i++) {
-    auto idx_val = offset_neg_idx(
-        idx_buffers[i][gid.y], out_shape[i]);
+    auto idx_val = offset_neg_idx(idx_buffers[i][gid.y], out_shape[i]);
     out_idx += idx_val * out_strides[i];
   }

@@ -36,44 +34,34 @@ METAL_FUNC void scatter_1d_index_impl(
 }

 #define make_scatter_1d_index(IDX_ARG, IDX_ARR) \
 template <typename T, typename IdxT, typename Op, int NIDX> \
 [[kernel]] void scatter_1d_index( \
-    const device T *updates [[buffer(1)]], \
-    device mlx_atomic<T> *out [[buffer(2)]], \
+    const device T* updates [[buffer(1)]], \
+    device mlx_atomic<T>* out [[buffer(2)]], \
     const constant int* out_shape [[buffer(3)]], \
     const constant size_t* out_strides [[buffer(4)]], \
     const constant size_t& upd_size [[buffer(5)]], \
-    IDX_ARG(IdxT) \
-    uint2 gid [[thread_position_in_grid]]) { \
-  \
+    IDX_ARG(IdxT) uint2 gid [[thread_position_in_grid]]) { \
   const array<const device IdxT*, NIDX> idx_buffers = {IDX_ARR()}; \
   \
   return scatter_1d_index_impl<T, IdxT, Op, NIDX>( \
-      updates, \
-      out, \
-      out_shape, \
-      out_strides, \
-      upd_size, \
-      idx_buffers, \
-      gid); \
-  \
-}
+      updates, out, out_shape, out_strides, upd_size, idx_buffers, gid); \
+}

 template <typename T, typename IdxT, typename Op, int NIDX>
 METAL_FUNC void scatter_impl(
-    const device T *updates [[buffer(1)]],
-    device mlx_atomic<T> *out [[buffer(2)]],
-    const constant int *upd_shape [[buffer(3)]],
-    const constant size_t *upd_strides [[buffer(4)]],
+    const device T* updates [[buffer(1)]],
+    device mlx_atomic<T>* out [[buffer(2)]],
+    const constant int* upd_shape [[buffer(3)]],
+    const constant size_t* upd_strides [[buffer(4)]],
     const constant size_t& upd_ndim [[buffer(5)]],
     const constant size_t& upd_size [[buffer(6)]],
-    const constant int *out_shape [[buffer(7)]],
-    const constant size_t *out_strides [[buffer(8)]],
+    const constant int* out_shape [[buffer(7)]],
+    const constant size_t* out_strides [[buffer(8)]],
     const constant size_t& out_ndim [[buffer(9)]],
     const constant int* axes [[buffer(10)]],
     const thread Indices<IdxT, NIDX>& indices,
     uint2 gid [[thread_position_in_grid]]) {
-
   Op op;
   auto ind_idx = gid.y;
   auto ind_offset = gid.x;
@@ -86,8 +74,7 @@ METAL_FUNC void scatter_impl(
         &indices.strides[indices.ndim * i],
         indices.ndim);
     auto ax = axes[i];
-    auto idx_val = offset_neg_idx(
-        indices.buffers[i][idx_loc], out_shape[ax]);
+    auto idx_val = offset_neg_idx(indices.buffers[i][idx_loc], out_shape[ax]);
     out_idx += idx_val * out_strides[ax];
   }

@@ -97,34 +84,30 @@ METAL_FUNC void scatter_impl(
     out_idx += out_offset;
   }

-  auto upd_idx = elem_to_loc(gid.y * upd_size + gid.x, upd_shape, upd_strides, upd_ndim);
+  auto upd_idx =
+      elem_to_loc(gid.y * upd_size + gid.x, upd_shape, upd_strides, upd_ndim);
   op.atomic_update(out, updates[upd_idx], out_idx);
 }

 #define make_scatter_impl(IDX_ARG, IDX_ARR) \
 template <typename T, typename IdxT, typename Op, int NIDX> \
 [[kernel]] void scatter( \
-    const device T *updates [[buffer(1)]], \
-    device mlx_atomic<T> *out [[buffer(2)]], \
-    const constant int *upd_shape [[buffer(3)]], \
-    const constant size_t *upd_strides [[buffer(4)]], \
+    const device T* updates [[buffer(1)]], \
+    device mlx_atomic<T>* out [[buffer(2)]], \
+    const constant int* upd_shape [[buffer(3)]], \
+    const constant size_t* upd_strides [[buffer(4)]], \
     const constant size_t& upd_ndim [[buffer(5)]], \
     const constant size_t& upd_size [[buffer(6)]], \
-    const constant int *out_shape [[buffer(7)]], \
-    const constant size_t *out_strides [[buffer(8)]], \
+    const constant int* out_shape [[buffer(7)]], \
+    const constant size_t* out_strides [[buffer(8)]], \
     const constant size_t& out_ndim [[buffer(9)]], \
     const constant int* axes [[buffer(10)]], \
-    const constant int *idx_shapes [[buffer(11)]], \
-    const constant size_t *idx_strides [[buffer(12)]], \
+    const constant int* idx_shapes [[buffer(11)]], \
+    const constant size_t* idx_strides [[buffer(12)]], \
     const constant int& idx_ndim [[buffer(13)]], \
-    IDX_ARG(IdxT) \
-    uint2 gid [[thread_position_in_grid]]) { \
-  \
+    IDX_ARG(IdxT) uint2 gid [[thread_position_in_grid]]) { \
   Indices<IdxT, NIDX> idxs{ \
-      {{IDX_ARR()}}, \
-      idx_shapes, \
-      idx_strides, \
-      idx_ndim}; \
+      {{IDX_ARR()}}, idx_shapes, idx_strides, idx_ndim}; \
   \
   return scatter_impl<T, IdxT, Op, NIDX>( \
       updates, \
@@ -139,70 +122,63 @@ template <typename T, typename IdxT, typename Op, int NIDX> \
       axes, \
       idxs, \
       gid); \
 }

 #define make_scatter(n) \
-  make_scatter_impl(IDX_ARG_ ##n, IDX_ARR_ ##n) \
-  make_scatter_1d_index(IDX_ARG_ ##n, IDX_ARR_ ##n)
+  make_scatter_impl(IDX_ARG_##n, IDX_ARR_##n) \
+  make_scatter_1d_index(IDX_ARG_##n, IDX_ARR_##n)

-make_scatter(0)
-make_scatter(1)
-make_scatter(2)
-make_scatter(3)
-make_scatter(4)
-make_scatter(5)
-make_scatter(6)
-make_scatter(7)
-make_scatter(8)
-make_scatter(9)
-make_scatter(10)
+make_scatter(0) make_scatter(1) make_scatter(2) make_scatter(3) make_scatter(4)
+    make_scatter(5) make_scatter(6) make_scatter(7) make_scatter(8)
+        make_scatter(9) make_scatter(10)

 /////////////////////////////////////////////////////////////////////
 // Scatter instantiations
 /////////////////////////////////////////////////////////////////////

 #define instantiate_scatter5(name, src_t, idx_t, op_t, nidx, IDX_ARG) \
-  template [[host_name("scatter" name "_" #nidx)]] \
-  [[kernel]] void scatter<src_t, idx_t, op_t, nidx>( \
-      const device src_t *updates [[buffer(1)]], \
-      device mlx_atomic<src_t> *out [[buffer(2)]], \
-      const constant int *upd_shape [[buffer(3)]], \
-      const constant size_t *upd_strides [[buffer(4)]], \
+  template [[host_name("scatter" name "_" #nidx)]] [[kernel]] void \
+  scatter<src_t, idx_t, op_t, nidx>( \
+      const device src_t* updates [[buffer(1)]], \
+      device mlx_atomic<src_t>* out [[buffer(2)]], \
+      const constant int* upd_shape [[buffer(3)]], \
+      const constant size_t* upd_strides [[buffer(4)]], \
       const constant size_t& upd_ndim [[buffer(5)]], \
       const constant size_t& upd_size [[buffer(6)]], \
-      const constant int *out_shape [[buffer(7)]], \
-      const constant size_t *out_strides [[buffer(8)]], \
+      const constant int* out_shape [[buffer(7)]], \
+      const constant size_t* out_strides [[buffer(8)]], \
       const constant size_t& out_ndim [[buffer(9)]], \
       const constant int* axes [[buffer(10)]], \
-      const constant int *idx_shapes [[buffer(11)]], \
-      const constant size_t *idx_strides [[buffer(12)]], \
+      const constant int* idx_shapes [[buffer(11)]], \
+      const constant size_t* idx_strides [[buffer(12)]], \
       const constant int& idx_ndim [[buffer(13)]], \
-      IDX_ARG(idx_t) \
-      uint2 gid [[thread_position_in_grid]]);
+      IDX_ARG(idx_t) uint2 gid [[thread_position_in_grid]]);

 #define instantiate_scatter6(name, src_t, idx_t, op_t, nidx, IDX_ARG) \
-  template [[host_name("scatter_1d_index" name "_" #nidx)]] \
-  [[kernel]] void scatter_1d_index<src_t, idx_t, op_t, nidx>( \
-      const device src_t *updates [[buffer(1)]], \
-      device mlx_atomic<src_t> *out [[buffer(2)]], \
+  template [[host_name("scatter_1d_index" name "_" #nidx)]] [[kernel]] void \
+  scatter_1d_index<src_t, idx_t, op_t, nidx>( \
+      const device src_t* updates [[buffer(1)]], \
+      device mlx_atomic<src_t>* out [[buffer(2)]], \
       const constant int* out_shape [[buffer(3)]], \
       const constant size_t* out_strides [[buffer(4)]], \
       const constant size_t& upd_size [[buffer(5)]], \
-      IDX_ARG(idx_t) \
-      uint2 gid [[thread_position_in_grid]]);
+      IDX_ARG(idx_t) uint2 gid [[thread_position_in_grid]]);

+// clang-format off
 #define instantiate_scatter4(name, src_t, idx_t, op_t, nidx) \
   instantiate_scatter5(name, src_t, idx_t, op_t, nidx, IDX_ARG_ ##nidx) \
-  instantiate_scatter6(name, src_t, idx_t, op_t, nidx, IDX_ARG_ ##nidx)
+  instantiate_scatter6(name, src_t, idx_t, op_t, nidx, IDX_ARG_ ##nidx) // clang-format on

+// clang-format off
 // Special case NINDEX=0
 #define instantiate_scatter_nd0(name, type) \
   instantiate_scatter4(#name "none", type, bool, None, 0) \
   instantiate_scatter4(#name "_sum", type, bool, Sum<type>, 0) \
   instantiate_scatter4(#name "_prod", type, bool, Prod<type>, 0) \
   instantiate_scatter4(#name "_max", type, bool, Max<type>, 0) \
-  instantiate_scatter4(#name "_min", type, bool, Min<type>, 0)
+  instantiate_scatter4(#name "_min", type, bool, Min<type>, 0) // clang-format on

+// clang-format off
 #define instantiate_scatter3(name, type, ind_type, op_type) \
   instantiate_scatter4(name, type, ind_type, op_type, 1) \
   instantiate_scatter4(name, type, ind_type, op_type, 2) \
@@ -213,15 +189,17 @@ template [[host_name("scatter_1d_index" name "_" #nidx)]] \
   instantiate_scatter4(name, type, ind_type, op_type, 7) \
   instantiate_scatter4(name, type, ind_type, op_type, 8) \
   instantiate_scatter4(name, type, ind_type, op_type, 9) \
-  instantiate_scatter4(name, type, ind_type, op_type, 10)
+  instantiate_scatter4(name, type, ind_type, op_type, 10) // clang-format on

+// clang-format off
 #define instantiate_scatter2(name, type, ind_type) \
   instantiate_scatter3(name "_none", type, ind_type, None) \
   instantiate_scatter3(name "_sum", type, ind_type, Sum<type>) \
   instantiate_scatter3(name "_prod", type, ind_type, Prod<type>) \
   instantiate_scatter3(name "_max", type, ind_type, Max<type>) \
-  instantiate_scatter3(name "_min", type, ind_type, Min<type>)
+  instantiate_scatter3(name "_min", type, ind_type, Min<type>) // clang-format on

+// clang-format off
 #define instantiate_scatter(name, type) \
   instantiate_scatter2(#name "bool_", type, bool) \
   instantiate_scatter2(#name "uint8", type, uint8_t) \
@@ -231,8 +209,9 @@ template [[host_name("scatter_1d_index" name "_" #nidx)]] \
   instantiate_scatter2(#name "int8", type, int8_t) \
   instantiate_scatter2(#name "int16", type, int16_t) \
   instantiate_scatter2(#name "int32", type, int32_t) \
-  instantiate_scatter2(#name "int64", type, int64_t)
+  instantiate_scatter2(#name "int64", type, int64_t) // clang-format on

+// clang-format off
 // TODO uint64 and int64 unsupported
 instantiate_scatter_nd0(bool_, bool)
 instantiate_scatter_nd0(uint8, uint8_t)
@@ -254,4 +233,4 @@ instantiate_scatter(int16, int16_t)
 instantiate_scatter(int32, int32_t)
 instantiate_scatter(float16, half)
 instantiate_scatter(float32, float)
-instantiate_scatter(bfloat16, bfloat16_t)
+instantiate_scatter(bfloat16, bfloat16_t) // clang-format on
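All of the scatter kernels above reduce to the same address computation: each of the NIDX index arrays selects a coordinate along one output axis, coordinates are wrapped for negative indices (offset_neg_idx) and dotted with the output strides, and the op is applied atomically at the resulting flat location. A serial C++ sketch of that index math, assuming a sum-scatter into a 2-D output; shapes and data are illustrative:

    #include <cstdio>

    // Serial model of scatter_impl's addressing: out_idx is the dot product
    // of wrapped per-axis indices with the output strides.
    int main() {
      const int rows = 3, cols = 4;
      float out[rows * cols] = {0};
      const long out_strides[2] = {cols, 1};

      // One index array per indexed axis, plus matching updates.
      const int idx0[] = {0, -1, 1}; // negative indices wrap around
      const int idx1[] = {2, 0, 3};
      const float updates[] = {1.f, 2.f, 3.f};

      for (int u = 0; u < 3; u++) {
        int i0 = idx0[u] < 0 ? idx0[u] + rows : idx0[u]; // offset_neg_idx
        int i1 = idx1[u] < 0 ? idx1[u] + cols : idx1[u];
        long out_idx = i0 * out_strides[0] + i1 * out_strides[1];
        out[out_idx] += updates[u]; // op.atomic_update with Op = Sum
      }
      for (int r = 0; r < rows; r++, printf("\n"))
        for (int c = 0; c < cols; c++)
          printf("%g ", out[r * cols + c]);
    }

The atomic update in the real kernel is what makes duplicate indices safe: two threads landing on the same out_idx combine rather than race.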
@@ -198,7 +198,6 @@ template <typename T, typename AccT = T, int N_READS = SOFTMAX_N_READS>
   }
 }

-// clang-format off
 #define instantiate_softmax(name, itype) \
   template [[host_name("softmax_" #name)]] [[kernel]] void \
   softmax_single_row<itype>( \
@@ -241,9 +240,9 @@ template <typename T, typename AccT = T, int N_READS = SOFTMAX_N_READS>
       uint simd_lane_id [[thread_index_in_simdgroup]], \
       uint simd_group_id [[simdgroup_index_in_threadgroup]]);

+// clang-format off
 instantiate_softmax(float32, float)
 instantiate_softmax(float16, half)
 instantiate_softmax(bfloat16, bfloat16_t)
 instantiate_softmax_precise(float16, half)
-instantiate_softmax_precise(bfloat16, bfloat16_t)
-// clang-format on
+instantiate_softmax_precise(bfloat16, bfloat16_t) // clang-format on
@@ -11,7 +11,8 @@

 using namespace metal;

-// Based on GPU merge sort algorithm at https://github.com/NVIDIA/cccl/tree/main/cub/cub
+// Based on GPU merge sort algorithm at
+// https://github.com/NVIDIA/cccl/tree/main/cub/cub

 ///////////////////////////////////////////////////////////////////////////////
 // Thread-level sort
@@ -43,20 +44,18 @@ struct ThreadSort {
   static METAL_FUNC void sort(
       thread val_t (&vals)[N_PER_THREAD],
       thread idx_t (&idxs)[N_PER_THREAD]) {
-
     CompareOp op;

     MLX_MTL_LOOP_UNROLL
-    for(short i = 0; i < N_PER_THREAD; ++i) {
+    for (short i = 0; i < N_PER_THREAD; ++i) {
       MLX_MTL_LOOP_UNROLL
-      for(short j = i & 1; j < N_PER_THREAD - 1; j += 2) {
-        if(op(vals[j + 1], vals[j])) {
+      for (short j = i & 1; j < N_PER_THREAD - 1; j += 2) {
+        if (op(vals[j + 1], vals[j])) {
           thread_swap(vals[j + 1], vals[j]);
           thread_swap(idxs[j + 1], idxs[j]);
         }
       }
     }
-
   }
 };

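ThreadSort above is an odd-even transposition sort: N_PER_THREAD passes over the register array, comparing adjacent pairs at alternating parities, which fully sorts a fixed-size array with no data-dependent control flow (good for unrolling into registers). A host-side C++ sketch of the same loop, carrying indices along as in the arg-sort path:

    #include <cstdio>

    // Odd-even transposition sort over a fixed-size array, with a parallel
    // index array, mirroring ThreadSort::sort.
    int main() {
      const int N = 8;
      int vals[N] = {5, 2, 7, 1, 8, 3, 6, 4};
      int idxs[N] = {0, 1, 2, 3, 4, 5, 6, 7};

      for (int i = 0; i < N; i++) {
        // Even passes compare (0,1),(2,3),...; odd passes (1,2),(3,4),...
        for (int j = i & 1; j < N - 1; j += 2) {
          if (vals[j + 1] < vals[j]) {
            int tv = vals[j]; vals[j] = vals[j + 1]; vals[j + 1] = tv;
            int ti = idxs[j]; idxs[j] = idxs[j + 1]; idxs[j + 1] = ti;
          }
        }
      }
      for (int i = 0; i < N; i++)
        printf("%d(%d) ", vals[i], idxs[i]); // sorted values (source index)
      printf("\n");
    }

N passes suffice to sort N elements, which is why the kernel's outer loop runs exactly N_PER_THREAD times.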
@@ -72,25 +71,25 @@ template <
     short N_PER_THREAD,
     typename CompareOp>
 struct BlockMergeSort {
-  using thread_sort_t = ThreadSort<val_t, idx_t, ARG_SORT, N_PER_THREAD, CompareOp>;
+  using thread_sort_t =
+      ThreadSort<val_t, idx_t, ARG_SORT, N_PER_THREAD, CompareOp>;
   static METAL_FUNC int merge_partition(
       const threadgroup val_t* As,
       const threadgroup val_t* Bs,
       short A_sz,
       short B_sz,
       short sort_md) {
-
     CompareOp op;

     short A_st = max(0, sort_md - B_sz);
     short A_ed = min(sort_md, A_sz);

-    while(A_st < A_ed) {
+    while (A_st < A_ed) {
       short md = A_st + (A_ed - A_st) / 2;
       auto a = As[md];
       auto b = Bs[sort_md - 1 - md];

-      if(op(b, a)) {
+      if (op(b, a)) {
         A_ed = md;
       } else {
         A_st = md + 1;
@@ -98,7 +97,6 @@ struct BlockMergeSort {
     }

     return A_ed;
-
   }

   static METAL_FUNC void merge_step(
@@ -110,12 +108,11 @@ struct BlockMergeSort {
       short B_sz,
       thread val_t (&vals)[N_PER_THREAD],
       thread idx_t (&idxs)[N_PER_THREAD]) {
-
     CompareOp op;
     short a_idx = 0;
     short b_idx = 0;

-    for(int i = 0; i < N_PER_THREAD; ++i) {
+    for (int i = 0; i < N_PER_THREAD; ++i) {
       auto a = As[a_idx];
       auto b = Bs[b_idx];
       bool pred = (b_idx < B_sz) && (a_idx >= A_sz || op(b, a));
@@ -126,7 +123,6 @@ struct BlockMergeSort {
       b_idx += short(pred);
       a_idx += short(!pred);
     }
-
   }

   static METAL_FUNC void sort(
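merge_partition above is the classic merge-path trick: binary-search for how many elements of A belong in the first sort_md outputs of merge(A, B), so that each merge lane can then run merge_step over an independent slice. A serial C++ sketch of the search and a small check of what it returns:

    #include <cstdio>

    // Find how many elements of A fall among the first `md` outputs of
    // merge(A, B), by binary search on the cross diagonal (merge path).
    int merge_partition(const int* A, const int* B, int A_sz, int B_sz, int md) {
      int lo = md > B_sz ? md - B_sz : 0; // A_st
      int hi = md < A_sz ? md : A_sz;     // A_ed
      while (lo < hi) {
        int mid = lo + (hi - lo) / 2;
        if (B[md - 1 - mid] < A[mid])
          hi = mid; // taking A[mid] would be too greedy
        else
          lo = mid + 1;
      }
      return lo;
    }

    int main() {
      int A[] = {1, 4, 6, 9};
      int B[] = {2, 3, 7, 8};
      // The first 4 merged outputs are 1,2,3,4: two from A, two from B.
      int p = merge_partition(A, B, 4, 4, 4);
      printf("A contributes %d, B contributes %d\n", p, 4 - p);
    }

Because every lane can compute its own partition with only reads, the block-wide merge needs no serialization beyond the surrounding barriers.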
@@ -134,32 +130,32 @@ struct BlockMergeSort {
       threadgroup idx_t* tgp_idxs [[threadgroup(1)]],
       int size_sorted_axis,
       uint3 lid [[thread_position_in_threadgroup]]) {
-
     // Get thread location
     int idx = lid.x * N_PER_THREAD;

     // Load from shared memory
     thread val_t thread_vals[N_PER_THREAD];
     thread idx_t thread_idxs[N_PER_THREAD];
-    for(int i = 0; i < N_PER_THREAD; ++i) {
+    for (int i = 0; i < N_PER_THREAD; ++i) {
       thread_vals[i] = tgp_vals[idx + i];
-      if(ARG_SORT) {
+      if (ARG_SORT) {
         thread_idxs[i] = tgp_idxs[idx + i];
       }
     }

     // Per thread sort
-    if(idx < size_sorted_axis) {
+    if (idx < size_sorted_axis) {
       thread_sort_t::sort(thread_vals, thread_idxs);
     }

     // Do merges using threadgroup memory
-    for (int merge_threads = 2; merge_threads <= BLOCK_THREADS; merge_threads *= 2) {
+    for (int merge_threads = 2; merge_threads <= BLOCK_THREADS;
+         merge_threads *= 2) {
       // Update threadgroup memory
       threadgroup_barrier(mem_flags::mem_threadgroup);
-      for(int i = 0; i < N_PER_THREAD; ++i) {
+      for (int i = 0; i < N_PER_THREAD; ++i) {
         tgp_vals[idx + i] = thread_vals[i];
-        if(ARG_SORT) {
+        if (ARG_SORT) {
           tgp_idxs[idx + i] = thread_idxs[i];
         }
       }
@@ -189,12 +185,7 @@ struct BlockMergeSort {
       // of size N_PER_THREAD for each merge lane i
       // C = [Ci] is sorted
       int sort_md = N_PER_THREAD * merge_lane;
-      int partition = merge_partition(
-          As,
-          Bs,
-          A_sz,
-          B_sz,
-          sort_md);
+      int partition = merge_partition(As, Bs, A_sz, B_sz, sort_md);

       As += partition;
       Bs += sort_md - partition;
@@ -202,27 +193,20 @@ struct BlockMergeSort {
       A_sz -= partition;
       B_sz -= sort_md - partition;

-      const threadgroup idx_t* As_idx = ARG_SORT ? tgp_idxs + A_st + partition : nullptr;
-      const threadgroup idx_t* Bs_idx = ARG_SORT ? tgp_idxs + B_st + sort_md - partition : nullptr;
+      const threadgroup idx_t* As_idx =
+          ARG_SORT ? tgp_idxs + A_st + partition : nullptr;
+      const threadgroup idx_t* Bs_idx =
+          ARG_SORT ? tgp_idxs + B_st + sort_md - partition : nullptr;

       // Merge starting at the partition and store results in thread registers
-      merge_step(
-          As,
-          Bs,
-          As_idx,
-          Bs_idx,
-          A_sz,
-          B_sz,
-          thread_vals,
-          thread_idxs);
-
+      merge_step(As, Bs, As_idx, Bs_idx, A_sz, B_sz, thread_vals, thread_idxs);
     }

     // Write out to shared memory
     threadgroup_barrier(mem_flags::mem_threadgroup);
-    for(int i = 0; i < N_PER_THREAD; ++i) {
+    for (int i = 0; i < N_PER_THREAD; ++i) {
       tgp_vals[idx + i] = thread_vals[i];
-      if(ARG_SORT) {
+      if (ARG_SORT) {
         tgp_idxs[idx + i] = thread_idxs[i];
       }
     }
@@ -263,15 +247,15 @@ struct KernelMergeSort {
       threadgroup idx_t* tgp_idxs,
       uint3 tid [[threadgroup_position_in_grid]],
       uint3 lid [[thread_position_in_threadgroup]]) {
-
     // tid.y tells us the segment index
     inp += tid.y * stride_segment_axis;
     out += tid.y * stride_segment_axis;

     // Copy into threadgroup memory
-    for(short i = lid.x; i < N_PER_BLOCK; i+= BLOCK_THREADS) {
-      tgp_vals[i] = i < size_sorted_axis ? inp[i * stride_sorted_axis] : val_t(CompareOp::init);
-      if(ARG_SORT) {
+    for (short i = lid.x; i < N_PER_BLOCK; i += BLOCK_THREADS) {
+      tgp_vals[i] = i < size_sorted_axis ? inp[i * stride_sorted_axis]
+                                         : val_t(CompareOp::init);
+      if (ARG_SORT) {
         tgp_idxs[i] = i;
       }
     }
@@ -284,8 +268,8 @@ struct KernelMergeSort {
     threadgroup_barrier(mem_flags::mem_threadgroup);

     // Write output
-    for(int i = lid.x; i < size_sorted_axis; i+= BLOCK_THREADS) {
-      if(ARG_SORT) {
+    for (int i = lid.x; i < size_sorted_axis; i += BLOCK_THREADS) {
+      if (ARG_SORT) {
         out[i * stride_sorted_axis] = tgp_idxs[i];
       } else {
         out[i * stride_sorted_axis] = tgp_vals[i];
@@ -308,12 +292,12 @@ template <
     const constant int& stride_segment_axis [[buffer(4)]],
     uint3 tid [[threadgroup_position_in_grid]],
     uint3 lid [[thread_position_in_threadgroup]]) {
-
-  using sort_kernel = KernelMergeSort<T, U, ARG_SORT, BLOCK_THREADS, N_PER_THREAD>;
+  using sort_kernel =
+      KernelMergeSort<T, U, ARG_SORT, BLOCK_THREADS, N_PER_THREAD>;
   using val_t = typename sort_kernel::val_t;
   using idx_t = typename sort_kernel::idx_t;

-  if(ARG_SORT) {
+  if (ARG_SORT) {
     threadgroup val_t tgp_vals[sort_kernel::N_PER_BLOCK];
     threadgroup idx_t tgp_idxs[sort_kernel::N_PER_BLOCK];
     sort_kernel::block_sort(
@@ -339,7 +323,6 @@ template <
         tid,
         lid);
   }
-
 }

 constant constexpr const int zero_helper = 0;
@@ -360,8 +343,8 @@ template <
     const device size_t* nc_strides [[buffer(6)]],
     uint3 tid [[threadgroup_position_in_grid]],
     uint3 lid [[thread_position_in_threadgroup]]) {
-
-  using sort_kernel = KernelMergeSort<T, U, ARG_SORT, BLOCK_THREADS, N_PER_THREAD>;
+  using sort_kernel =
+      KernelMergeSort<T, U, ARG_SORT, BLOCK_THREADS, N_PER_THREAD>;
   using val_t = typename sort_kernel::val_t;
   using idx_t = typename sort_kernel::idx_t;

@@ -369,7 +352,7 @@ template <
   inp += block_idx;
   out += block_idx;

-  if(ARG_SORT) {
+  if (ARG_SORT) {
     threadgroup val_t tgp_vals[sort_kernel::N_PER_BLOCK];
     threadgroup idx_t tgp_idxs[sort_kernel::N_PER_BLOCK];
     sort_kernel::block_sort(
@@ -395,17 +378,17 @@ template <
       tid,
       lid);
   }
-
 }

 ///////////////////////////////////////////////////////////////////////////////
 // Instantiations
 ///////////////////////////////////////////////////////////////////////////////

-#define instantiate_block_sort(name, itname, itype, otname, otype, arg_sort, bn, tn) \
-  template [[host_name(#name "_" #itname "_" #otname "_bn" #bn "_tn" #tn)]] \
-  [[kernel]] void block_sort<itype, otype, arg_sort, bn, tn>( \
+#define instantiate_block_sort( \
+    name, itname, itype, otname, otype, arg_sort, bn, tn) \
+  template [[host_name(#name "_" #itname "_" #otname "_bn" #bn \
+                       "_tn" #tn)]] [[kernel]] void \
+  block_sort<itype, otype, arg_sort, bn, tn>( \
       const device itype* inp [[buffer(0)]], \
       device otype* out [[buffer(1)]], \
       const constant int& size_sorted_axis [[buffer(2)]], \
@@ -413,8 +396,9 @@ template <
       const constant int& stride_segment_axis [[buffer(4)]], \
       uint3 tid [[threadgroup_position_in_grid]], \
      uint3 lid [[thread_position_in_threadgroup]]); \
-  template [[host_name(#name "_" #itname "_" #otname "_bn" #bn "_tn" #tn "_nc")]] \
-  [[kernel]] void block_sort_nc<itype, otype, arg_sort, bn, tn>( \
+  template [[host_name(#name "_" #itname "_" #otname "_bn" #bn "_tn" #tn \
+                       "_nc")]] [[kernel]] void \
+  block_sort_nc<itype, otype, arg_sort, bn, tn>( \
       const device itype* inp [[buffer(0)]], \
       device otype* out [[buffer(1)]], \
       const constant int& size_sorted_axis [[buffer(2)]], \
@@ -426,15 +410,19 @@ template <
       uint3 lid [[thread_position_in_threadgroup]]);

 #define instantiate_arg_block_sort_base(itname, itype, bn, tn) \
-  instantiate_block_sort(arg_block_merge_sort, itname, itype, uint32, uint32_t, true, bn, tn)
+  instantiate_block_sort( \
+      arg_block_merge_sort, itname, itype, uint32, uint32_t, true, bn, tn)

 #define instantiate_block_sort_base(itname, itype, bn, tn) \
-  instantiate_block_sort(block_merge_sort, itname, itype, itname, itype, false, bn, tn)
+  instantiate_block_sort( \
+      block_merge_sort, itname, itype, itname, itype, false, bn, tn)

+// clang-format off
 #define instantiate_block_sort_tn(itname, itype, bn) \
   instantiate_block_sort_base(itname, itype, bn, 8) \
-  instantiate_arg_block_sort_base(itname, itype, bn, 8)
+  instantiate_arg_block_sort_base(itname, itype, bn, 8) // clang-format on

+// clang-format off
 #define instantiate_block_sort_bn(itname, itype) \
   instantiate_block_sort_tn(itname, itype, 128) \
   instantiate_block_sort_tn(itname, itype, 256) \
@@ -448,27 +436,27 @@ instantiate_block_sort_bn(int16, int16_t)
 instantiate_block_sort_bn(int32, int32_t)
 instantiate_block_sort_bn(float16, half)
 instantiate_block_sort_bn(float32, float)
-instantiate_block_sort_bn(bfloat16, bfloat16_t)
+instantiate_block_sort_bn(bfloat16, bfloat16_t) // clang-format on
+// clang-format off
 #define instantiate_block_sort_long(itname, itype) \
   instantiate_block_sort_tn(itname, itype, 128) \
   instantiate_block_sort_tn(itname, itype, 256)

 instantiate_block_sort_long(uint64, uint64_t)
-instantiate_block_sort_long(int64, int64_t)
+instantiate_block_sort_long(int64, int64_t) // clang-format on

 ///////////////////////////////////////////////////////////////////////////////
 // Multi block merge sort
 ///////////////////////////////////////////////////////////////////////////////

 template <
     typename val_t,
     typename idx_t,
     bool ARG_SORT,
     short BLOCK_THREADS,
     short N_PER_THREAD,
     typename CompareOp = LessThan<val_t>>
 struct KernelMultiBlockMergeSort {
   using block_merge_sort_t = BlockMergeSort<
       val_t,
       idx_t,
@@ -489,14 +477,14 @@ struct KernelMultiBlockMergeSort {
       threadgroup idx_t* tgp_idxs,
       uint3 tid [[threadgroup_position_in_grid]],
       uint3 lid [[thread_position_in_threadgroup]]) {
-
     // tid.y tells us the segment index
     int base_idx = tid.x * N_PER_BLOCK;

     // Copy into threadgroup memory
-    for(short i = lid.x; i < N_PER_BLOCK; i+= BLOCK_THREADS) {
+    for (short i = lid.x; i < N_PER_BLOCK; i += BLOCK_THREADS) {
       int idx = base_idx + i;
-      tgp_vals[i] = idx < size_sorted_axis ? inp[idx * stride_sorted_axis] : val_t(CompareOp::init);
+      tgp_vals[i] = idx < size_sorted_axis ? inp[idx * stride_sorted_axis]
+                                           : val_t(CompareOp::init);
       tgp_idxs[i] = idx;
     }

@@ -508,9 +496,9 @@ struct KernelMultiBlockMergeSort {
     threadgroup_barrier(mem_flags::mem_threadgroup);

     // Write output
-    for(int i = lid.x; i < N_PER_BLOCK; i+= BLOCK_THREADS) {
+    for (int i = lid.x; i < N_PER_BLOCK; i += BLOCK_THREADS) {
       int idx = base_idx + i;
-      if(idx < size_sorted_axis) {
+      if (idx < size_sorted_axis) {
        out_vals[idx] = tgp_vals[i];
        out_idxs[idx] = tgp_idxs[i];
      }
@@ -523,18 +511,17 @@ struct KernelMultiBlockMergeSort {
       int A_sz,
       int B_sz,
|
int B_sz,
|
||||||
int sort_md) {
|
int sort_md) {
|
||||||
|
|
||||||
CompareOp op;
|
CompareOp op;
|
||||||
|
|
||||||
int A_st = max(0, sort_md - B_sz);
|
int A_st = max(0, sort_md - B_sz);
|
||||||
int A_ed = min(sort_md, A_sz);
|
int A_ed = min(sort_md, A_sz);
|
||||||
|
|
||||||
while(A_st < A_ed) {
|
while (A_st < A_ed) {
|
||||||
int md = A_st + (A_ed - A_st) / 2;
|
int md = A_st + (A_ed - A_st) / 2;
|
||||||
auto a = As[md];
|
auto a = As[md];
|
||||||
auto b = Bs[sort_md - 1 - md];
|
auto b = Bs[sort_md - 1 - md];
|
||||||
|
|
||||||
if(op(b, a)) {
|
if (op(b, a)) {
|
||||||
A_ed = md;
|
A_ed = md;
|
||||||
} else {
|
} else {
|
||||||
A_st = md + 1;
|
A_st = md + 1;
|
||||||
@ -542,7 +529,6 @@ struct KernelMultiBlockMergeSort {
|
|||||||
}
|
}
|
||||||
|
|
||||||
return A_ed;
|
return A_ed;
|
||||||
|
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
@ -563,8 +549,12 @@ template <
|
|||||||
const device size_t* nc_strides [[buffer(7)]],
|
const device size_t* nc_strides [[buffer(7)]],
|
||||||
uint3 tid [[threadgroup_position_in_grid]],
|
uint3 tid [[threadgroup_position_in_grid]],
|
||||||
uint3 lid [[thread_position_in_threadgroup]]) {
|
uint3 lid [[thread_position_in_threadgroup]]) {
|
||||||
|
using sort_kernel = KernelMultiBlockMergeSort<
|
||||||
using sort_kernel = KernelMultiBlockMergeSort<val_t, idx_t, ARG_SORT, BLOCK_THREADS, N_PER_THREAD>;
|
val_t,
|
||||||
|
idx_t,
|
||||||
|
ARG_SORT,
|
||||||
|
BLOCK_THREADS,
|
||||||
|
N_PER_THREAD>;
|
||||||
|
|
||||||
auto block_idx = elem_to_loc(tid.y, nc_shape, nc_strides, nc_dim);
|
auto block_idx = elem_to_loc(tid.y, nc_shape, nc_strides, nc_dim);
|
||||||
inp += block_idx;
|
inp += block_idx;
|
||||||
@ -592,7 +582,8 @@ template <
|
|||||||
bool ARG_SORT,
|
bool ARG_SORT,
|
||||||
short BLOCK_THREADS,
|
short BLOCK_THREADS,
|
||||||
short N_PER_THREAD>
|
short N_PER_THREAD>
|
||||||
[[kernel, max_total_threads_per_threadgroup(BLOCK_THREADS)]] void mb_block_partition(
|
[[kernel, max_total_threads_per_threadgroup(BLOCK_THREADS)]] void
|
||||||
|
mb_block_partition(
|
||||||
device idx_t* block_partitions [[buffer(0)]],
|
device idx_t* block_partitions [[buffer(0)]],
|
||||||
const device val_t* dev_vals [[buffer(1)]],
|
const device val_t* dev_vals [[buffer(1)]],
|
||||||
const device idx_t* dev_idxs [[buffer(2)]],
|
const device idx_t* dev_idxs [[buffer(2)]],
|
||||||
@ -601,7 +592,6 @@ template <
|
|||||||
uint3 tid [[threadgroup_position_in_grid]],
|
uint3 tid [[threadgroup_position_in_grid]],
|
||||||
uint3 lid [[thread_position_in_threadgroup]],
|
uint3 lid [[thread_position_in_threadgroup]],
|
||||||
uint3 tgp_dims [[threads_per_threadgroup]]) {
|
uint3 tgp_dims [[threads_per_threadgroup]]) {
|
||||||
|
|
||||||
using sort_kernel = KernelMultiBlockMergeSort<
|
using sort_kernel = KernelMultiBlockMergeSort<
|
||||||
val_t,
|
val_t,
|
||||||
idx_t,
|
idx_t,
|
||||||
@ -627,14 +617,9 @@ template <
|
|||||||
|
|
||||||
int partition_at = min(B_ed - A_st, sort_kernel::N_PER_BLOCK * merge_lane);
|
int partition_at = min(B_ed - A_st, sort_kernel::N_PER_BLOCK * merge_lane);
|
||||||
int partition = sort_kernel::merge_partition(
|
int partition = sort_kernel::merge_partition(
|
||||||
dev_vals + A_st,
|
dev_vals + A_st, dev_vals + B_st, A_ed - A_st, B_ed - B_st, partition_at);
|
||||||
dev_vals + B_st,
|
|
||||||
A_ed - A_st,
|
|
||||||
B_ed - B_st,
|
|
||||||
partition_at);
|
|
||||||
|
|
||||||
block_partitions[lid.x] = A_st + partition;
|
block_partitions[lid.x] = A_st + partition;
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
template <
|
template <
|
||||||
@ -644,7 +629,8 @@ template <
|
|||||||
short BLOCK_THREADS,
|
short BLOCK_THREADS,
|
||||||
short N_PER_THREAD,
|
short N_PER_THREAD,
|
||||||
typename CompareOp = LessThan<val_t>>
|
typename CompareOp = LessThan<val_t>>
|
||||||
[[kernel, max_total_threads_per_threadgroup(BLOCK_THREADS)]] void mb_block_merge(
|
[[kernel, max_total_threads_per_threadgroup(BLOCK_THREADS)]] void
|
||||||
|
mb_block_merge(
|
||||||
const device idx_t* block_partitions [[buffer(0)]],
|
const device idx_t* block_partitions [[buffer(0)]],
|
||||||
const device val_t* dev_vals_in [[buffer(1)]],
|
const device val_t* dev_vals_in [[buffer(1)]],
|
||||||
const device idx_t* dev_idxs_in [[buffer(2)]],
|
const device idx_t* dev_idxs_in [[buffer(2)]],
|
||||||
@ -655,7 +641,6 @@ template <
|
|||||||
const constant int& num_tiles [[buffer(7)]],
|
const constant int& num_tiles [[buffer(7)]],
|
||||||
uint3 tid [[threadgroup_position_in_grid]],
|
uint3 tid [[threadgroup_position_in_grid]],
|
||||||
uint3 lid [[thread_position_in_threadgroup]]) {
|
uint3 lid [[thread_position_in_threadgroup]]) {
|
||||||
|
|
||||||
using sort_kernel = KernelMultiBlockMergeSort<
|
using sort_kernel = KernelMultiBlockMergeSort<
|
||||||
val_t,
|
val_t,
|
||||||
idx_t,
|
idx_t,
|
||||||
@ -680,11 +665,13 @@ template <
|
|||||||
|
|
||||||
int A_st = block_partitions[block_idx + 0];
|
int A_st = block_partitions[block_idx + 0];
|
||||||
int A_ed = block_partitions[block_idx + 1];
|
int A_ed = block_partitions[block_idx + 1];
|
||||||
int B_st = min(size_sorted_axis, 2 * sort_st + sort_sz/2 + sort_md - A_st);
|
int B_st = min(size_sorted_axis, 2 * sort_st + sort_sz / 2 + sort_md - A_st);
|
||||||
int B_ed = min(size_sorted_axis, 2 * sort_st + sort_sz/2 + sort_md + sort_kernel::N_PER_BLOCK - A_ed);
|
int B_ed = min(
|
||||||
|
size_sorted_axis,
|
||||||
|
2 * sort_st + sort_sz / 2 + sort_md + sort_kernel::N_PER_BLOCK - A_ed);
|
||||||
|
|
||||||
if((block_idx % merge_tiles) == merge_tiles - 1) {
|
if ((block_idx % merge_tiles) == merge_tiles - 1) {
|
||||||
A_ed = min(size_sorted_axis, sort_st + sort_sz/2);
|
A_ed = min(size_sorted_axis, sort_st + sort_sz / 2);
|
||||||
B_ed = min(size_sorted_axis, sort_st + sort_sz);
|
B_ed = min(size_sorted_axis, sort_st + sort_sz);
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -694,11 +681,13 @@ template <
|
|||||||
// Load from global memory
|
// Load from global memory
|
||||||
thread val_t thread_vals[N_PER_THREAD];
|
thread val_t thread_vals[N_PER_THREAD];
|
||||||
thread idx_t thread_idxs[N_PER_THREAD];
|
thread idx_t thread_idxs[N_PER_THREAD];
|
||||||
for(int i = 0; i < N_PER_THREAD; i++) {
|
for (int i = 0; i < N_PER_THREAD; i++) {
|
||||||
int idx = BLOCK_THREADS * i + lid.x;
|
int idx = BLOCK_THREADS * i + lid.x;
|
||||||
if(idx < (A_sz + B_sz)) {
|
if (idx < (A_sz + B_sz)) {
|
||||||
thread_vals[i] = (idx < A_sz) ? dev_vals_in[A_st + idx] : dev_vals_in[B_st + idx - A_sz];
|
thread_vals[i] = (idx < A_sz) ? dev_vals_in[A_st + idx]
|
||||||
thread_idxs[i] = (idx < A_sz) ? dev_idxs_in[A_st + idx] : dev_idxs_in[B_st + idx - A_sz];
|
: dev_vals_in[B_st + idx - A_sz];
|
||||||
|
thread_idxs[i] = (idx < A_sz) ? dev_idxs_in[A_st + idx]
|
||||||
|
: dev_idxs_in[B_st + idx - A_sz];
|
||||||
} else {
|
} else {
|
||||||
thread_vals[i] = CompareOp::init;
|
thread_vals[i] = CompareOp::init;
|
||||||
thread_idxs[i] = 0;
|
thread_idxs[i] = 0;
|
||||||
@ -709,7 +698,7 @@ template <
|
|||||||
threadgroup val_t tgp_vals[sort_kernel::N_PER_BLOCK];
|
threadgroup val_t tgp_vals[sort_kernel::N_PER_BLOCK];
|
||||||
threadgroup idx_t tgp_idxs[sort_kernel::N_PER_BLOCK];
|
threadgroup idx_t tgp_idxs[sort_kernel::N_PER_BLOCK];
|
||||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||||
for(int i = 0; i < N_PER_THREAD; i++) {
|
for (int i = 0; i < N_PER_THREAD; i++) {
|
||||||
int idx = BLOCK_THREADS * i + lid.x;
|
int idx = BLOCK_THREADS * i + lid.x;
|
||||||
tgp_vals[idx] = thread_vals[i];
|
tgp_vals[idx] = thread_vals[i];
|
||||||
tgp_idxs[idx] = thread_idxs[i];
|
tgp_idxs[idx] = thread_idxs[i];
|
||||||
@ -720,11 +709,7 @@ template <
|
|||||||
int sort_md_local = min(A_sz + B_sz, N_PER_THREAD * int(lid.x));
|
int sort_md_local = min(A_sz + B_sz, N_PER_THREAD * int(lid.x));
|
||||||
|
|
||||||
int A_st_local = block_sort_t::merge_partition(
|
int A_st_local = block_sort_t::merge_partition(
|
||||||
tgp_vals,
|
tgp_vals, tgp_vals + A_sz, A_sz, B_sz, sort_md_local);
|
||||||
tgp_vals + A_sz,
|
|
||||||
A_sz,
|
|
||||||
B_sz,
|
|
||||||
sort_md_local);
|
|
||||||
int A_ed_local = A_sz;
|
int A_ed_local = A_sz;
|
||||||
|
|
||||||
int B_st_local = sort_md_local - A_st_local;
|
int B_st_local = sort_md_local - A_st_local;
|
||||||
@ -745,7 +730,7 @@ template <
|
|||||||
thread_idxs);
|
thread_idxs);
|
||||||
|
|
||||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||||
for(int i = 0; i < N_PER_THREAD; ++i) {
|
for (int i = 0; i < N_PER_THREAD; ++i) {
|
||||||
int idx = lid.x * N_PER_THREAD;
|
int idx = lid.x * N_PER_THREAD;
|
||||||
tgp_vals[idx + i] = thread_vals[i];
|
tgp_vals[idx + i] = thread_vals[i];
|
||||||
tgp_idxs[idx + i] = thread_idxs[i];
|
tgp_idxs[idx + i] = thread_idxs[i];
|
||||||
@ -754,19 +739,20 @@ template <
|
|||||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||||
// Write output
|
// Write output
|
||||||
int base_idx = tid.x * sort_kernel::N_PER_BLOCK;
|
int base_idx = tid.x * sort_kernel::N_PER_BLOCK;
|
||||||
for(int i = lid.x; i < sort_kernel::N_PER_BLOCK; i+= BLOCK_THREADS) {
|
for (int i = lid.x; i < sort_kernel::N_PER_BLOCK; i += BLOCK_THREADS) {
|
||||||
int idx = base_idx + i;
|
int idx = base_idx + i;
|
||||||
if(idx < size_sorted_axis) {
|
if (idx < size_sorted_axis) {
|
||||||
dev_vals_out[idx] = tgp_vals[i];
|
dev_vals_out[idx] = tgp_vals[i];
|
||||||
dev_idxs_out[idx] = tgp_idxs[i];
|
dev_idxs_out[idx] = tgp_idxs[i];
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
#define instantiate_multi_block_sort(vtname, vtype, itname, itype, arg_sort, bn, tn) \
|
#define instantiate_multi_block_sort( \
|
||||||
template [[host_name("mb_block_sort_" #vtname "_" #itname "_bn" #bn "_tn" #tn)]] \
|
vtname, vtype, itname, itype, arg_sort, bn, tn) \
|
||||||
[[kernel]] void mb_block_sort<vtype, itype, arg_sort, bn, tn>( \
|
template [[host_name("mb_block_sort_" #vtname "_" #itname "_bn" #bn \
|
||||||
|
"_tn" #tn)]] [[kernel]] void \
|
||||||
|
mb_block_sort<vtype, itype, arg_sort, bn, tn>( \
|
||||||
const device vtype* inp [[buffer(0)]], \
|
const device vtype* inp [[buffer(0)]], \
|
||||||
device vtype* out_vals [[buffer(1)]], \
|
device vtype* out_vals [[buffer(1)]], \
|
||||||
device itype* out_idxs [[buffer(2)]], \
|
device itype* out_idxs [[buffer(2)]], \
|
||||||
@ -777,9 +763,10 @@ template <
|
|||||||
const device size_t* nc_strides [[buffer(7)]], \
|
const device size_t* nc_strides [[buffer(7)]], \
|
||||||
uint3 tid [[threadgroup_position_in_grid]], \
|
uint3 tid [[threadgroup_position_in_grid]], \
|
||||||
uint3 lid [[thread_position_in_threadgroup]]); \
|
uint3 lid [[thread_position_in_threadgroup]]); \
|
||||||
template [[host_name("mb_block_partition_" #vtname "_" #itname "_bn" #bn "_tn" #tn)]] \
|
template [[host_name("mb_block_partition_" #vtname "_" #itname "_bn" #bn \
|
||||||
[[kernel]] void mb_block_partition<vtype, itype, arg_sort, bn, tn>( \
|
"_tn" #tn)]] [[kernel]] void \
|
||||||
device itype* block_partitions [[buffer(0)]], \
|
mb_block_partition<vtype, itype, arg_sort, bn, tn>( \
|
||||||
|
device itype * block_partitions [[buffer(0)]], \
|
||||||
const device vtype* dev_vals [[buffer(1)]], \
|
const device vtype* dev_vals [[buffer(1)]], \
|
||||||
const device itype* dev_idxs [[buffer(2)]], \
|
const device itype* dev_idxs [[buffer(2)]], \
|
||||||
const constant int& size_sorted_axis [[buffer(3)]], \
|
const constant int& size_sorted_axis [[buffer(3)]], \
|
||||||
@ -787,8 +774,9 @@ template <
|
|||||||
uint3 tid [[threadgroup_position_in_grid]], \
|
uint3 tid [[threadgroup_position_in_grid]], \
|
||||||
uint3 lid [[thread_position_in_threadgroup]], \
|
uint3 lid [[thread_position_in_threadgroup]], \
|
||||||
uint3 tgp_dims [[threads_per_threadgroup]]); \
|
uint3 tgp_dims [[threads_per_threadgroup]]); \
|
||||||
template [[host_name("mb_block_merge_" #vtname "_" #itname "_bn" #bn "_tn" #tn)]] \
|
template [[host_name("mb_block_merge_" #vtname "_" #itname "_bn" #bn \
|
||||||
[[kernel]] void mb_block_merge<vtype, itype, arg_sort, bn, tn>( \
|
"_tn" #tn)]] [[kernel]] void \
|
||||||
|
mb_block_merge<vtype, itype, arg_sort, bn, tn>( \
|
||||||
const device itype* block_partitions [[buffer(0)]], \
|
const device itype* block_partitions [[buffer(0)]], \
|
||||||
const device vtype* dev_vals_in [[buffer(1)]], \
|
const device vtype* dev_vals_in [[buffer(1)]], \
|
||||||
const device itype* dev_idxs_in [[buffer(2)]], \
|
const device itype* dev_idxs_in [[buffer(2)]], \
|
||||||
@ -800,6 +788,7 @@ template <
|
|||||||
uint3 tid [[threadgroup_position_in_grid]], \
|
uint3 tid [[threadgroup_position_in_grid]], \
|
||||||
uint3 lid [[thread_position_in_threadgroup]]);
|
uint3 lid [[thread_position_in_threadgroup]]);
|
||||||
|
|
||||||
|
// clang-format off
|
||||||
#define instantiate_multi_block_sort_base(vtname, vtype) \
|
#define instantiate_multi_block_sort_base(vtname, vtype) \
|
||||||
instantiate_multi_block_sort(vtname, vtype, uint32, uint32_t, true, 512, 8)
|
instantiate_multi_block_sort(vtname, vtype, uint32, uint32_t, true, 512, 8)
|
||||||
|
|
||||||
@ -811,10 +800,11 @@ instantiate_multi_block_sort_base(int16, int16_t)
|
|||||||
instantiate_multi_block_sort_base(int32, int32_t)
|
instantiate_multi_block_sort_base(int32, int32_t)
|
||||||
instantiate_multi_block_sort_base(float16, half)
|
instantiate_multi_block_sort_base(float16, half)
|
||||||
instantiate_multi_block_sort_base(float32, float)
|
instantiate_multi_block_sort_base(float32, float)
|
||||||
instantiate_multi_block_sort_base(bfloat16, bfloat16_t)
|
instantiate_multi_block_sort_base(bfloat16, bfloat16_t) // clang-format on
|
||||||
|
|
||||||
|
// clang-format off
|
||||||
#define instantiate_multi_block_sort_long(vtname, vtype) \
|
#define instantiate_multi_block_sort_long(vtname, vtype) \
|
||||||
instantiate_multi_block_sort(vtname, vtype, uint32, uint32_t, true, 256, 8)
|
instantiate_multi_block_sort(vtname, vtype, uint32, uint32_t, true, 256, 8)
|
||||||
|
|
||||||
instantiate_multi_block_sort_long(uint64, uint64_t)
|
instantiate_multi_block_sort_long(uint64, uint64_t)
|
||||||
instantiate_multi_block_sort_long(int64, int64_t)
|
instantiate_multi_block_sort_long(int64, int64_t) // clang-format on
|
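The merge hunks above hinge on merge_partition, a co-rank binary search: given sorted runs As and Bs and a target output position sort_md, it finds how many elements of As land before that position in the merged order. A minimal host-side C++ sketch, assuming plain int keys and a std::vector harness of our own (only the search body mirrors the kernel):

#include <algorithm>
#include <cassert>
#include <vector>

// Host-side sketch of the kernel's co-rank search; the vector harness and
// main() are illustrative, not MLX code.
int merge_partition(
    const std::vector<int>& A, const std::vector<int>& B, int sort_md) {
  int A_st = std::max(0, sort_md - static_cast<int>(B.size()));
  int A_ed = std::min(sort_md, static_cast<int>(A.size()));
  while (A_st < A_ed) {
    int md = A_st + (A_ed - A_st) / 2;
    // If B's candidate is strictly smaller, too many A elements were taken.
    if (B[sort_md - 1 - md] < A[md]) {
      A_ed = md;
    } else {
      A_st = md + 1;
    }
  }
  return A_ed;
}

int main() {
  std::vector<int> A{1, 3, 5, 7};
  std::vector<int> B{2, 4, 6, 8};
  // The first four merged elements are {1, 2, 3, 4}; two of them come from A.
  assert(merge_partition(A, B, 4) == 2);
  return 0;
}

Each threadgroup runs this search against its partition boundary, so the blocks can merge disjoint slices without any cross-block communication.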
@@ -4,13 +4,14 @@

 #include "mlx/backend/metal/kernels/steel/gemm/mma.h"

+#include "mlx/backend/metal/kernels/bf16.h"
 #include "mlx/backend/metal/kernels/steel/conv/conv.h"
 #include "mlx/backend/metal/kernels/steel/conv/params.h"
-#include "mlx/backend/metal/kernels/bf16.h"

 using namespace metal;

-template <typename T,
+template <
+    typename T,
     int BM,
     int BN,
     int BK,
@@ -18,7 +19,8 @@ template <typename T,
     int WN,
     int N_CHANNELS = 0,
     bool SMALL_FILTER = false>
-[[kernel, max_total_threads_per_threadgroup(WM * WN * 32)]] void implicit_gemm_conv_2d(
+[[kernel, max_total_threads_per_threadgroup(WM* WN * 32)]] void
+implicit_gemm_conv_2d(
     const device T* A [[buffer(0)]],
     const device T* B [[buffer(1)]],
     device T* C [[buffer(2)]],
@@ -28,8 +30,6 @@ template <typename T,
     uint3 lid [[thread_position_in_threadgroup]],
     uint simd_gid [[simdgroup_index_in_threadgroup]],
     uint simd_lid [[thread_index_in_simdgroup]]) {
-
-
   using namespace mlx::steel;

   (void)lid;
@@ -56,7 +56,13 @@ template <typename T,

       // Go to small channel specialization
       Conv2DInputBlockLoaderSmallChannels<
-          T, BM, BN, BK, tgp_size, N_CHANNELS, tgp_padding_a>,
+          T,
+          BM,
+          BN,
+          BK,
+          tgp_size,
+          N_CHANNELS,
+          tgp_padding_a>,

       // Else go to general loader
       typename metal::conditional_t<
@@ -65,14 +71,21 @@ template <typename T,

           // Go to small filter specialization
           Conv2DInputBlockLoaderSmallFilter<
-              T, BM, BN, BK, tgp_size, tgp_padding_a>,
+              T,
+              BM,
+              BN,
+              BK,
+              tgp_size,
+              tgp_padding_a>,

           // Else go to large filter generalization
           Conv2DInputBlockLoaderLargeFilter<
-              T, BM, BN, BK, tgp_size, tgp_padding_a>
-      >
-  >;
+              T,
+              BM,
+              BN,
+              BK,
+              tgp_size,
+              tgp_padding_a>>>;

   // Weight loader
   using loader_b_t = typename metal::conditional_t<
@@ -81,11 +94,16 @@ template <typename T,

       // Go to small channel specialization
       Conv2DWeightBlockLoaderSmallChannels<
-          T, BM, BN, BK, tgp_size, N_CHANNELS, tgp_padding_b>,
+          T,
+          BM,
+          BN,
+          BK,
+          tgp_size,
+          N_CHANNELS,
+          tgp_padding_b>,

       // Else go to general loader
-      Conv2DWeightBlockLoader<T, BM, BN, BK, tgp_size, tgp_padding_b>
-  >;
+      Conv2DWeightBlockLoader<T, BM, BN, BK, tgp_size, tgp_padding_b>>;

   using mma_t = BlockMMA<
       T,
@@ -123,8 +141,10 @@ template <typename T,
   const int2 offsets_b(0, c_col);

   // Prepare threadgroup loading operations
-  loader_a_t loader_a(A, As, offsets_a, params, gemm_params, simd_gid, simd_lid);
-  loader_b_t loader_b(B, Bs, offsets_b, params, gemm_params, simd_gid, simd_lid);
+  loader_a_t loader_a(
+      A, As, offsets_a, params, gemm_params, simd_gid, simd_lid);
+  loader_b_t loader_b(
+      B, Bs, offsets_b, params, gemm_params, simd_gid, simd_lid);

   // Prepare threadgroup mma operation
   mma_t mma_op(simd_gid, simd_lid);
@@ -152,12 +172,24 @@ template <typename T,
     short tgp_bm = min(BM, gemm_params->M - c_row);
     short tgp_bn = min(BN, gemm_params->N - c_col);
     mma_op.store_result_safe(C, N, short2(tgp_bn, tgp_bm));
-
   }

-#define instantiate_implicit_conv_2d(name, itype, bm, bn, bk, wm, wn, channel_name, n_channels, filter_name, small_filter) \
-  template [[host_name("implicit_gemm_conv_2d_" #name "_bm" #bm "_bn" #bn "_bk" #bk "_wm" #wm "_wn" #wn "_channel_" #channel_name "_filter_" #filter_name)]] \
-  [[kernel]] void implicit_gemm_conv_2d<itype, bm, bn, bk, wm, wn, n_channels, small_filter>( \
+#define instantiate_implicit_conv_2d( \
+    name, \
+    itype, \
+    bm, \
+    bn, \
+    bk, \
+    wm, \
+    wn, \
+    channel_name, \
+    n_channels, \
+    filter_name, \
+    small_filter) \
+  template [[host_name("implicit_gemm_conv_2d_" #name "_bm" #bm "_bn" #bn \
+                       "_bk" #bk "_wm" #wm "_wn" #wn "_channel_" #channel_name \
+                       "_filter_" #filter_name)]] [[kernel]] void \
+  implicit_gemm_conv_2d<itype, bm, bn, bk, wm, wn, n_channels, small_filter>( \
      const device itype* A [[buffer(0)]], \
      const device itype* B [[buffer(1)]], \
      device itype* C [[buffer(2)]], \
@@ -168,22 +200,25 @@ template <typename T,
      uint simd_gid [[simdgroup_index_in_threadgroup]], \
      uint simd_lid [[thread_index_in_simdgroup]]);

+// clang-format off
 #define instantiate_implicit_2d_filter(name, itype, bm, bn, bk, wm, wn) \
   instantiate_implicit_conv_2d(name, itype, bm, bn, bk, wm, wn, l, 0, s, true) \
   instantiate_implicit_conv_2d(name, itype, bm, bn, bk, wm, wn, l, 0, l, false) \
   instantiate_implicit_conv_2d(name, itype, bm, bn, bk, wm, wn, 1, 1, l, false) \
   instantiate_implicit_conv_2d(name, itype, bm, bn, bk, wm, wn, 2, 2, l, false) \
   instantiate_implicit_conv_2d(name, itype, bm, bn, bk, wm, wn, 3, 3, l, false) \
-  instantiate_implicit_conv_2d(name, itype, bm, bn, bk, wm, wn, 4, 4, l, false)
+  instantiate_implicit_conv_2d(name, itype, bm, bn, bk, wm, wn, 4, 4, l, false) // clang-format on

+// clang-format off
 #define instantiate_implicit_2d_blocks(name, itype) \
   instantiate_implicit_2d_filter(name, itype, 32, 8, 16, 4, 1) \
   instantiate_implicit_2d_filter(name, itype, 64, 8, 16, 4, 1) \
   instantiate_implicit_2d_filter(name, itype, 32, 32, 16, 2, 2) \
   instantiate_implicit_2d_filter(name, itype, 32, 64, 16, 2, 2) \
   instantiate_implicit_2d_filter(name, itype, 64, 32, 16, 2, 2) \
-  instantiate_implicit_2d_filter(name, itype, 64, 64, 16, 2, 2)
+  instantiate_implicit_2d_filter(name, itype, 64, 64, 16, 2, 2) // clang-format on

+// clang-format off
 instantiate_implicit_2d_blocks(float32, float);
 instantiate_implicit_2d_blocks(float16, half);
-instantiate_implicit_2d_blocks(bfloat16, bfloat16_t);
+instantiate_implicit_2d_blocks(bfloat16, bfloat16_t); // clang-format on
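The input-loader selection in the hunks above is a compile-time dispatch: nested metal::conditional_t aliases pick a small-channel, small-filter, or large-filter loader from the template parameters, so the rejected paths are never instantiated. A minimal sketch of the same pattern in standard C++ with std::conditional_t; the loader types and the channel predicate here are stand-ins for illustration, not the actual MLX conditions:

#include <type_traits>

// Hypothetical stand-ins for the three Conv2D input-loader specializations.
struct SmallChannelLoader {};
struct SmallFilterLoader {};
struct LargeFilterLoader {};

// Nested compile-time selection, mirroring the structure (not the exact
// predicates) of the loader_a_t alias in the diff above.
template <int N_CHANNELS, bool SMALL_FILTER>
using loader_a_t = std::conditional_t<
    (N_CHANNELS > 0),
    // Go to the small channel specialization
    SmallChannelLoader,
    // Else pick by filter size
    std::conditional_t<SMALL_FILTER, SmallFilterLoader, LargeFilterLoader>>;

static_assert(std::is_same_v<loader_a_t<3, false>, SmallChannelLoader>);
static_assert(std::is_same_v<loader_a_t<0, true>, SmallFilterLoader>);
static_assert(std::is_same_v<loader_a_t<0, false>, LargeFilterLoader>);

int main() {
  return 0;
}

Because the choice resolves at template instantiation time, each compiled kernel carries exactly one loader and no runtime branch.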
@@ -4,15 +4,16 @@

 #include "mlx/backend/metal/kernels/steel/gemm/mma.h"

-#include "mlx/backend/metal/kernels/steel/conv/conv.h"
-#include "mlx/backend/metal/kernels/steel/conv/params.h"
-#include "mlx/backend/metal/kernels/steel/conv/loaders/loader_general.h"
 #include "mlx/backend/metal/kernels/bf16.h"
+#include "mlx/backend/metal/kernels/steel/conv/conv.h"
+#include "mlx/backend/metal/kernels/steel/conv/loaders/loader_general.h"
+#include "mlx/backend/metal/kernels/steel/conv/params.h"

 using namespace metal;
 using namespace mlx::steel;

-template <typename T,
+template <
+    typename T,
     int BM,
     int BN,
     int BK,
@@ -20,7 +21,8 @@ template <typename T,
     int WN,
     typename AccumType = float,
     typename Epilogue = TransformNone<T, AccumType>>
-[[kernel, max_total_threads_per_threadgroup(WM * WN * 32)]] void implicit_gemm_conv_2d_general(
+[[kernel, max_total_threads_per_threadgroup(WM* WN * 32)]] void
+implicit_gemm_conv_2d_general(
     const device T* A [[buffer(0)]],
     const device T* B [[buffer(1)]],
     device T* C [[buffer(2)]],
@@ -33,7 +35,6 @@ template <typename T,
     uint3 lid [[thread_position_in_threadgroup]],
     uint simd_gid [[simdgroup_index_in_threadgroup]],
     uint simd_lid [[thread_index_in_simdgroup]]) {
-
   (void)lid;

   constexpr bool transpose_a = false;
@@ -51,12 +52,12 @@ template <typename T,
   constexpr short tgp_size = WM * WN * 32;

   // Input loader
-  using loader_a_t = Conv2DInputBlockLoaderGeneral<
-      T, BM, BN, BK, tgp_size, tgp_padding_a>;
+  using loader_a_t =
+      Conv2DInputBlockLoaderGeneral<T, BM, BN, BK, tgp_size, tgp_padding_a>;

   // Weight loader
-  using loader_b_t = Conv2DWeightBlockLoaderGeneral<
-      T, BM, BN, BK, tgp_size, tgp_padding_b>;
+  using loader_b_t =
+      Conv2DWeightBlockLoaderGeneral<T, BM, BN, BK, tgp_size, tgp_padding_b>;

   using mma_t = BlockMMA<
       T,
@@ -103,13 +104,32 @@ template <typename T,
   const int2 offsets_b(0, c_col);

   // Prepare threadgroup loading operations
-  loader_a_t loader_a(A, As, offsets_a, params, jump_params, base_wh, base_ww, simd_gid, simd_lid);
-  loader_b_t loader_b(B, Bs, offsets_b, params, jump_params, base_wh, base_ww, simd_gid, simd_lid);
+  loader_a_t loader_a(
+      A,
+      As,
+      offsets_a,
+      params,
+      jump_params,
+      base_wh,
+      base_ww,
+      simd_gid,
+      simd_lid);
+  loader_b_t loader_b(
+      B,
+      Bs,
+      offsets_b,
+      params,
+      jump_params,
+      base_wh,
+      base_ww,
+      simd_gid,
+      simd_lid);

   // Prepare threadgroup mma operation
   mma_t mma_op(simd_gid, simd_lid);

-  int gemm_k_iterations = base_wh_size * base_ww_size * gemm_params->gemm_k_iterations;
+  int gemm_k_iterations =
+      base_wh_size * base_ww_size * gemm_params->gemm_k_iterations;

   for (int k = 0; k < gemm_k_iterations; k++) {
     threadgroup_barrier(mem_flags::mem_threadgroup);
@@ -143,22 +163,24 @@ template <typename T,

     STEEL_PRAGMA_UNROLL
     for (int i = 0; i < mma_t::TM; i++) {

       int cm = offset_m + i * mma_t::TM_stride;

       int n = cm / jump_params->adj_out_hw;
       int hw = cm % jump_params->adj_out_hw;
-      int oh = (hw / jump_params->adj_out_w) * jump_params->f_out_jump_h + base_oh;
-      int ow = (hw % jump_params->adj_out_w) * jump_params->f_out_jump_w + base_ow;
+      int oh =
+          (hw / jump_params->adj_out_w) * jump_params->f_out_jump_h + base_oh;
+      int ow =
+          (hw % jump_params->adj_out_w) * jump_params->f_out_jump_w + base_ow;

-      if(n < params->N && oh < params->oS[0] && ow < params->oS[1]) {
-        int offset_cm = n * params->out_strides[0] + oh * params->out_strides[1] + ow * params->out_strides[2];
+      if (n < params->N && oh < params->oS[0] && ow < params->oS[1]) {
+        int offset_cm = n * params->out_strides[0] +
+            oh * params->out_strides[1] + ow * params->out_strides[2];

         STEEL_PRAGMA_UNROLL
         for (int j = 0; j < mma_t::TN; j++) {
           // Get accumulated result and associated offset in C
-          thread const auto& accum = mma_op.results[i * mma_t::TN + j].thread_elements();
+          thread const auto& accum =
+              mma_op.results[i * mma_t::TN + j].thread_elements();
           int offset = offset_cm + (j * mma_t::TN_stride);

           // Apply epilogue and output C
@@ -170,16 +192,16 @@ template <typename T,
           C[offset + 1] = Epilogue::apply(accum[1]);
         }
       }

     }
   }
 }

 }

 #define instantiate_implicit_conv_2d(name, itype, bm, bn, bk, wm, wn) \
-  template [[host_name("implicit_gemm_conv_2d_general_" #name "_bm" #bm "_bn" #bn "_bk" #bk "_wm" #wm "_wn" #wn)]] \
-  [[kernel]] void implicit_gemm_conv_2d_general<itype, bm, bn, bk, wm, wn>( \
+  template \
+      [[host_name("implicit_gemm_conv_2d_general_" #name "_bm" #bm "_bn" #bn \
+                  "_bk" #bk "_wm" #wm "_wn" #wn)]] [[kernel]] void \
+      implicit_gemm_conv_2d_general<itype, bm, bn, bk, wm, wn>( \
      const device itype* A [[buffer(0)]], \
      const device itype* B [[buffer(1)]], \
      device itype* C [[buffer(2)]], \
@@ -196,14 +218,16 @@ template <typename T,
 #define instantiate_implicit_2d_filter(name, itype, bm, bn, bk, wm, wn) \
   instantiate_implicit_conv_2d(name, itype, bm, bn, bk, wm, wn)

+// clang-format off
 #define instantiate_implicit_2d_blocks(name, itype) \
   instantiate_implicit_2d_filter(name, itype, 32, 8, 16, 4, 1) \
   instantiate_implicit_2d_filter(name, itype, 64, 8, 16, 4, 1) \
   instantiate_implicit_2d_filter(name, itype, 32, 32, 16, 2, 2) \
   instantiate_implicit_2d_filter(name, itype, 32, 64, 16, 2, 2) \
   instantiate_implicit_2d_filter(name, itype, 64, 32, 16, 2, 2) \
-  instantiate_implicit_2d_filter(name, itype, 64, 64, 16, 2, 2)
+  instantiate_implicit_2d_filter(name, itype, 64, 64, 16, 2, 2) // clang-format on

+// clang-format off
 instantiate_implicit_2d_blocks(float32, float);
 instantiate_implicit_2d_blocks(float16, half);
-instantiate_implicit_2d_blocks(bfloat16, bfloat16_t);
+instantiate_implicit_2d_blocks(bfloat16, bfloat16_t); // clang-format on
@@ -1,8 +1,8 @@
 // Copyright © 2024 Apple Inc.

 #include "mlx/backend/metal/kernels/bf16.h"
-#include "mlx/backend/metal/kernels/utils.h"
 #include "mlx/backend/metal/kernels/steel/gemm/gemm.h"
+#include "mlx/backend/metal/kernels/utils.h"

 using namespace metal;
 using namespace mlx::steel;
@@ -11,7 +11,8 @@ using namespace mlx::steel;
 // GEMM kernels
 ///////////////////////////////////////////////////////////////////////////////

-template <typename T,
+template <
+    typename T,
     int BM,
     int BN,
     int BK,
@@ -21,10 +22,10 @@ template <typename T,
     bool transpose_b,
     bool MN_aligned,
     bool K_aligned>
-[[kernel, max_total_threads_per_threadgroup(WM * WN * 32)]] void gemm(
-    const device T *A [[buffer(0)]],
-    const device T *B [[buffer(1)]],
-    device T *D [[buffer(3)]],
+[[kernel, max_total_threads_per_threadgroup(WM* WN * 32)]] void gemm(
+    const device T* A [[buffer(0)]],
+    const device T* B [[buffer(1)]],
+    device T* D [[buffer(3)]],
     const constant GEMMParams* params [[buffer(4)]],
     const constant int* batch_shape [[buffer(6)]],
     const constant size_t* batch_strides [[buffer(7)]],
@@ -32,14 +33,24 @@ template <typename T,
     uint simd_group_id [[simdgroup_index_in_threadgroup]],
     uint3 tid [[threadgroup_position_in_grid]],
     uint3 lid [[thread_position_in_threadgroup]]) {
-  using gemm_kernel = GEMMKernel<T, T, BM, BN, BK, WM, WN, transpose_a, transpose_b, MN_aligned, K_aligned>;
+  using gemm_kernel = GEMMKernel<
+      T,
+      T,
+      BM,
+      BN,
+      BK,
+      WM,
+      WN,
+      transpose_a,
+      transpose_b,
+      MN_aligned,
+      K_aligned>;

   threadgroup T As[gemm_kernel::tgp_mem_size_a];
   threadgroup T Bs[gemm_kernel::tgp_mem_size_b];

   // Adjust for batch
-  if(params->batch_ndim > 1) {
+  if (params->batch_ndim > 1) {
     const constant size_t* A_bstrides = batch_strides;
     const constant size_t* B_bstrides = batch_strides + params->batch_ndim;

@@ -57,23 +68,37 @@ template <typename T,
   D += params->batch_stride_d * tid.z;

   gemm_kernel::run(
-      A, B, D,
-      params,
-      As, Bs,
-      simd_lane_id, simd_group_id, tid, lid
-  );
+      A, B, D, params, As, Bs, simd_lane_id, simd_group_id, tid, lid);
 }

 ///////////////////////////////////////////////////////////////////////////////
 // GEMM kernel initializations
 ///////////////////////////////////////////////////////////////////////////////

-#define instantiate_gemm(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, aname, mn_aligned, kname, k_aligned) \
-  template [[host_name("steel_gemm_" #tname "_" #iname "_" #oname "_bm" #bm "_bn" #bn "_bk" #bk "_wm" #wm "_wn" #wn "_MN_" #aname "_K_" #kname)]] \
-  [[kernel]] void gemm<itype, bm, bn, bk, wm, wn, trans_a, trans_b, mn_aligned, k_aligned>( \
-      const device itype *A [[buffer(0)]], \
-      const device itype *B [[buffer(1)]], \
-      device itype *D [[buffer(3)]], \
+#define instantiate_gemm( \
+    tname, \
+    trans_a, \
+    trans_b, \
+    iname, \
+    itype, \
+    oname, \
+    otype, \
+    bm, \
+    bn, \
+    bk, \
+    wm, \
+    wn, \
+    aname, \
+    mn_aligned, \
+    kname, \
+    k_aligned) \
+  template [[host_name("steel_gemm_" #tname "_" #iname "_" #oname "_bm" #bm \
+                       "_bn" #bn "_bk" #bk "_wm" #wm "_wn" #wn "_MN_" #aname \
+                       "_K_" #kname)]] [[kernel]] void \
+  gemm<itype, bm, bn, bk, wm, wn, trans_a, trans_b, mn_aligned, k_aligned>( \
+      const device itype* A [[buffer(0)]], \
+      const device itype* B [[buffer(1)]], \
+      device itype* D [[buffer(3)]], \
      const constant GEMMParams* params [[buffer(4)]], \
      const constant int* batch_shape [[buffer(6)]], \
      const constant size_t* batch_strides [[buffer(7)]], \
@@ -82,26 +107,30 @@ template <typename T,
      uint3 tid [[threadgroup_position_in_grid]], \
      uint3 lid [[thread_position_in_threadgroup]]);

+// clang-format off
 #define instantiate_gemm_aligned_helper(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn) \
   instantiate_gemm(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, taligned, true, taligned, true) \
   instantiate_gemm(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, taligned, true, naligned, false) \
   instantiate_gemm(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, naligned, false, taligned, true) \
-  instantiate_gemm(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, naligned, false, naligned, false)
+  instantiate_gemm(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, naligned, false, naligned, false) // clang-format on

+// clang-format off
 #define instantiate_gemm_transpose_helper(iname, itype, oname, otype, bm, bn, bk, wm, wn) \
   instantiate_gemm_aligned_helper(nn, false, false, iname, itype, oname, otype, bm, bn, bk, wm, wn) \
   instantiate_gemm_aligned_helper(nt, false, true , iname, itype, oname, otype, bm, bn, bk, wm, wn) \
   instantiate_gemm_aligned_helper(tn, true , false, iname, itype, oname, otype, bm, bn, bk, wm, wn) \
-  instantiate_gemm_aligned_helper(tt, true , true , iname, itype, oname, otype, bm, bn, bk, wm, wn)
+  instantiate_gemm_aligned_helper(tt, true , true , iname, itype, oname, otype, bm, bn, bk, wm, wn) // clang-format on

+// clang-format off
 #define instantiate_gemm_shapes_helper(iname, itype, oname, otype) \
   instantiate_gemm_transpose_helper(iname, itype, oname, otype, 32, 32, 16, 2, 2) \
   instantiate_gemm_transpose_helper(iname, itype, oname, otype, 64, 64, 16, 2, 2) \
   instantiate_gemm_transpose_helper(iname, itype, oname, otype, 64, 32, 32, 2, 2) \
   instantiate_gemm_transpose_helper(iname, itype, oname, otype, 64, 32, 16, 2, 2) \
-  instantiate_gemm_transpose_helper(iname, itype, oname, otype, 32, 64, 16, 2, 2)
+  instantiate_gemm_transpose_helper(iname, itype, oname, otype, 32, 64, 16, 2, 2) // clang-format on

+// clang-format off
 instantiate_gemm_shapes_helper(float16, half, float16, half);
 instantiate_gemm_shapes_helper(bfloat16, bfloat16_t, bfloat16, bfloat16_t);

-instantiate_gemm_shapes_helper(float32, float, float32, float);
+instantiate_gemm_shapes_helper(float32, float, float32, float); // clang-format on
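Splitting the host_name literals across continuation lines, as the hunks above do, is safe because the # operator stringizes each macro argument into a string literal and adjacent literals concatenate at compile time, so the wrapped name is byte-for-byte the same kernel name. A small sketch with a reduced, hypothetical macro (not the full instantiate_gemm signature):

#include <cstdio>

// Reduced illustration of the stringize-and-concatenate pattern used by the
// instantiation macros above; KERNEL_NAME itself is hypothetical.
#define KERNEL_NAME(tname, iname, oname, bm, bn) \
  "steel_gemm_" #tname "_" #iname "_" #oname "_bm" #bm "_bn" #bn

int main() {
  // Prints the single literal "steel_gemm_nn_float16_float16_bm32_bn32".
  std::puts(KERNEL_NAME(nn, float16, float16, 32, 32));
  return 0;
}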
@@ -10,7 +10,8 @@ using namespace mlx::steel;
 // GEMM kernels
 ///////////////////////////////////////////////////////////////////////////////

-template <typename T,
+template <
+    typename T,
     int BM,
     int BN,
     int BK,
@@ -22,11 +23,11 @@ template <typename T,
     bool K_aligned,
     typename AccumType = float,
     typename Epilogue = TransformAdd<T, AccumType>>
-[[kernel, max_total_threads_per_threadgroup(WM * WN * 32)]] void addmm(
-    const device T *A [[buffer(0)]],
-    const device T *B [[buffer(1)]],
-    const device T *C [[buffer(2)]],
-    device T *D [[buffer(3)]],
+[[kernel, max_total_threads_per_threadgroup(WM* WN * 32)]] void addmm(
+    const device T* A [[buffer(0)]],
+    const device T* B [[buffer(1)]],
+    const device T* C [[buffer(2)]],
+    device T* D [[buffer(3)]],
     const constant GEMMParams* params [[buffer(4)]],
     const constant GEMMAddMMParams* addmm_params [[buffer(5)]],
     const constant int* batch_shape [[buffer(6)]],
@@ -35,15 +36,23 @@ template <typename T,
     uint simd_group_id [[simdgroup_index_in_threadgroup]],
     uint3 tid [[threadgroup_position_in_grid]],
     uint3 lid [[thread_position_in_threadgroup]]) {
-
   // Pacifying compiler
   (void)lid;

-  using gemm_kernel =
-      GEMMKernel<T, T, BM, BN, BK, WM, WN,
-          transpose_a, transpose_b,
-          MN_aligned, K_aligned,
-          AccumType, Epilogue>;
+  using gemm_kernel = GEMMKernel<
+      T,
+      T,
+      BM,
+      BN,
+      BK,
+      WM,
+      WN,
+      transpose_a,
+      transpose_b,
+      MN_aligned,
+      K_aligned,
+      AccumType,
+      Epilogue>;

   using loader_a_t = typename gemm_kernel::loader_a_t;
   using loader_b_t = typename gemm_kernel::loader_b_t;
@@ -53,13 +62,18 @@ template <typename T,
   threadgroup T Bs[gemm_kernel::tgp_mem_size_b];

   // Adjust for batch
-  if(params->batch_ndim > 1) {
+  if (params->batch_ndim > 1) {
     const constant size_t* A_bstrides = batch_strides;
     const constant size_t* B_bstrides = batch_strides + params->batch_ndim;
     const constant size_t* C_bstrides = B_bstrides + params->batch_ndim;

     ulong3 batch_offsets = elem_to_loc_broadcast(
-        tid.z, batch_shape, A_bstrides, B_bstrides, C_bstrides, params->batch_ndim);
+        tid.z,
+        batch_shape,
+        A_bstrides,
+        B_bstrides,
+        C_bstrides,
+        params->batch_ndim);

     A += batch_offsets.x;
     B += batch_offsets.y;
@@ -140,7 +154,8 @@ template <typename T,
     }

     // Store results to device memory
-    mma_op.store_result(D, params->ldd, C, addmm_params->ldc, addmm_params->fdc, epilogue_op);
+    mma_op.store_result(
+        D, params->ldd, C, addmm_params->ldc, addmm_params->fdc, epilogue_op);
     return;

   }
@@ -164,7 +179,8 @@ template <typename T,
         leftover_bk,
         LoopAlignment<true, true, K_aligned>{});

-    mma_op.store_result(D, params->ldd, C, addmm_params->ldc, addmm_params->fdc, epilogue_op);
+    mma_op.store_result(
+        D, params->ldd, C, addmm_params->ldc, addmm_params->fdc, epilogue_op);
     return;

   } else if (tgp_bn == BN) {
@@ -181,8 +197,11 @@ template <typename T,
         LoopAlignment<false, true, K_aligned>{});

     return mma_op.store_result_safe(
-        D, params->ldd,
-        C, addmm_params->ldc, addmm_params->fdc,
+        D,
+        params->ldd,
+        C,
+        addmm_params->ldc,
+        addmm_params->fdc,
         short2(tgp_bn, tgp_bm),
         epilogue_op);

@@ -200,8 +219,11 @@ template <typename T,
         LoopAlignment<true, false, K_aligned>{});

     return mma_op.store_result_safe(
-        D, params->ldd,
-        C, addmm_params->ldc, addmm_params->fdc,
+        D,
+        params->ldd,
+        C,
+        addmm_params->ldc,
+        addmm_params->fdc,
         short2(tgp_bn, tgp_bm),
         epilogue_op);

@@ -219,8 +241,11 @@ template <typename T,
         LoopAlignment<false, false, K_aligned>{});

     return mma_op.store_result_safe(
-        D, params->ldd,
-        C, addmm_params->ldc, addmm_params->fdc,
+        D,
+        params->ldd,
+        C,
+        addmm_params->ldc,
+        addmm_params->fdc,
         short2(tgp_bn, tgp_bm),
         epilogue_op);
   }
@@ -231,13 +256,45 @@ template <typename T,
 // GEMM kernel initializations
 ///////////////////////////////////////////////////////////////////////////////

-#define instantiate_gemm(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, aname, mn_aligned, kname, k_aligned, ep_name, epilogue) \
-  template [[host_name("steel_addmm_" #tname "_" #iname "_" #oname "_bm" #bm "_bn" #bn "_bk" #bk "_wm" #wm "_wn" #wn "_MN_" #aname "_K_" #kname "_" #ep_name)]] \
-  [[kernel]] void addmm<itype, bm, bn, bk, wm, wn, trans_a, trans_b, mn_aligned, k_aligned, float, epilogue<itype, float>>( \
-      const device itype *A [[buffer(0)]], \
-      const device itype *B [[buffer(1)]], \
-      const device itype *C [[buffer(2)]], \
-      device itype *D [[buffer(3)]], \
+#define instantiate_gemm( \
+    tname, \
+    trans_a, \
+    trans_b, \
+    iname, \
+    itype, \
+    oname, \
+    otype, \
+    bm, \
+    bn, \
+    bk, \
+    wm, \
+    wn, \
+    aname, \
+    mn_aligned, \
+    kname, \
+    k_aligned, \
+    ep_name, \
+    epilogue) \
+  template [[host_name("steel_addmm_" #tname "_" #iname "_" #oname "_bm" #bm \
+                       "_bn" #bn "_bk" #bk "_wm" #wm "_wn" #wn "_MN_" #aname \
+                       "_K_" #kname "_" #ep_name)]] [[kernel]] void \
+  addmm< \
+      itype, \
+      bm, \
+      bn, \
+      bk, \
+      wm, \
+      wn, \
+      trans_a, \
+      trans_b, \
+      mn_aligned, \
+      k_aligned, \
+      float, \
+      epilogue<itype, float>>( \
+      const device itype* A [[buffer(0)]], \
+      const device itype* B [[buffer(1)]], \
+      const device itype* C [[buffer(2)]], \
+      device itype* D [[buffer(3)]], \
      const constant GEMMParams* gemm_params [[buffer(4)]], \
      const constant GEMMAddMMParams* params [[buffer(5)]], \
      const constant int* batch_shape [[buffer(6)]], \
@@ -247,30 +304,35 @@ template <typename T,
      uint3 tid [[threadgroup_position_in_grid]], \
      uint3 lid [[thread_position_in_threadgroup]]);

+// clang-format off
 #define instantiate_gemm_bias_helper(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, aname, mn_aligned, kname, k_aligned) \
   instantiate_gemm(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, aname, mn_aligned, kname, k_aligned, add, TransformAdd) \
-  instantiate_gemm(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, aname, mn_aligned, kname, k_aligned, axpby, TransformAxpby)
+  instantiate_gemm(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, aname, mn_aligned, kname, k_aligned, axpby, TransformAxpby) // clang-format on

+// clang-format off
 #define instantiate_gemm_aligned_helper(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn) \
   instantiate_gemm_bias_helper(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, taligned, true, taligned, true) \
   instantiate_gemm_bias_helper(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, taligned, true, naligned, false) \
   instantiate_gemm_bias_helper(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, naligned, false, taligned, true) \
-  instantiate_gemm_bias_helper(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, naligned, false, naligned, false)
+  instantiate_gemm_bias_helper(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, naligned, false, naligned, false) // clang-format on

+// clang-format off
 #define instantiate_gemm_transpose_helper(iname, itype, oname, otype, bm, bn, bk, wm, wn) \
   instantiate_gemm_aligned_helper(nn, false, false, iname, itype, oname, otype, bm, bn, bk, wm, wn) \
   instantiate_gemm_aligned_helper(nt, false, true , iname, itype, oname, otype, bm, bn, bk, wm, wn) \
   instantiate_gemm_aligned_helper(tn, true , false, iname, itype, oname, otype, bm, bn, bk, wm, wn) \
-  instantiate_gemm_aligned_helper(tt, true , true , iname, itype, oname, otype, bm, bn, bk, wm, wn)
+  instantiate_gemm_aligned_helper(tt, true , true , iname, itype, oname, otype, bm, bn, bk, wm, wn) // clang-format on

+// clang-format off
 #define instantiate_gemm_shapes_helper(iname, itype, oname, otype) \
   instantiate_gemm_transpose_helper(iname, itype, oname, otype, 32, 32, 16, 2, 2) \
   instantiate_gemm_transpose_helper(iname, itype, oname, otype, 64, 64, 16, 2, 2) \
   instantiate_gemm_transpose_helper(iname, itype, oname, otype, 64, 32, 32, 2, 2) \
   instantiate_gemm_transpose_helper(iname, itype, oname, otype, 64, 32, 16, 2, 2) \
-  instantiate_gemm_transpose_helper(iname, itype, oname, otype, 32, 64, 16, 2, 2)
+  instantiate_gemm_transpose_helper(iname, itype, oname, otype, 32, 64, 16, 2, 2) // clang-format on

+// clang-format off
 instantiate_gemm_shapes_helper(float16, half, float16, half);
 instantiate_gemm_shapes_helper(bfloat16, bfloat16_t, bfloat16, bfloat16_t);

-instantiate_gemm_shapes_helper(float32, float, float32, float);
+instantiate_gemm_shapes_helper(float32, float, float32, float); // clang-format on
@@ -1,8 +1,8 @@
 // Copyright © 2024 Apple Inc.
 
 #include "mlx/backend/metal/kernels/bf16.h"
-#include "mlx/backend/metal/kernels/utils.h"
 #include "mlx/backend/metal/kernels/steel/gemm/gemm.h"
+#include "mlx/backend/metal/kernels/utils.h"
 
 using namespace metal;
 using namespace mlx::steel;
@@ -11,7 +11,8 @@ using namespace mlx::steel;
 // GEMM kernels
 ///////////////////////////////////////////////////////////////////////////////
 
-template <typename T,
+template <
+typename T,
 int BM,
 int BN,
 int BK,
@@ -21,27 +22,38 @@ template <typename T,
 bool transpose_b,
 bool MN_aligned,
 bool K_aligned,
-bool has_operand_mask=false>
-[[kernel, max_total_threads_per_threadgroup(WM * WN * 32)]] void block_masked_gemm(
-const device T *A [[buffer(0)]],
-const device T *B [[buffer(1)]],
-device T *D [[buffer(3)]],
+bool has_operand_mask = false>
+[[kernel, max_total_threads_per_threadgroup(WM* WN * 32)]] void
+block_masked_gemm(
+const device T* A [[buffer(0)]],
+const device T* B [[buffer(1)]],
+device T* D [[buffer(3)]],
 const constant GEMMParams* params [[buffer(4)]],
 const constant int* batch_shape [[buffer(6)]],
 const constant size_t* batch_strides [[buffer(7)]],
-const device bool *out_mask [[buffer(10)]],
-const device bool *lhs_mask [[buffer(11)]],
-const device bool *rhs_mask [[buffer(12)]],
+const device bool* out_mask [[buffer(10)]],
+const device bool* lhs_mask [[buffer(11)]],
+const device bool* rhs_mask [[buffer(12)]],
 const constant int* mask_strides [[buffer(13)]],
 uint simd_lane_id [[thread_index_in_simdgroup]],
 uint simd_group_id [[simdgroup_index_in_threadgroup]],
 uint3 tid [[threadgroup_position_in_grid]],
 uint3 lid [[thread_position_in_threadgroup]]) {
 
 // Appease the compiler
 (void)lid;
 
-using gemm_kernel = GEMMKernel<T, T, BM, BN, BK, WM, WN, transpose_a, transpose_b, MN_aligned, K_aligned>;
+using gemm_kernel = GEMMKernel<
+T,
+T,
+BM,
+BN,
+BK,
+WM,
+WN,
+transpose_a,
+transpose_b,
+MN_aligned,
+K_aligned>;
 
 const int tid_y = ((tid.y) << params->swizzle_log) +
 ((tid.x) & ((1 << params->swizzle_log) - 1));
@@ -51,30 +63,38 @@ template <typename T,
 return;
 }
 
-if(params->batch_ndim > 1) {
-const constant size_t* mask_batch_strides = batch_strides + 2 * params->batch_ndim;
-out_mask += elem_to_loc(tid.z, batch_shape, mask_batch_strides, params->batch_ndim);
+if (params->batch_ndim > 1) {
+const constant size_t* mask_batch_strides =
+batch_strides + 2 * params->batch_ndim;
+out_mask +=
+elem_to_loc(tid.z, batch_shape, mask_batch_strides, params->batch_ndim);
 
-if(has_operand_mask) {
-const constant size_t* mask_strides_lhs = mask_batch_strides + params->batch_ndim;
-const constant size_t* mask_strides_rhs = mask_strides_lhs + params->batch_ndim;
+if (has_operand_mask) {
+const constant size_t* mask_strides_lhs =
+mask_batch_strides + params->batch_ndim;
+const constant size_t* mask_strides_rhs =
+mask_strides_lhs + params->batch_ndim;
 
 ulong2 batch_offsets = elem_to_loc_broadcast(
-tid.z, batch_shape, mask_strides_lhs, mask_strides_rhs, params->batch_ndim);
+tid.z,
+batch_shape,
+mask_strides_lhs,
+mask_strides_rhs,
+params->batch_ndim);
 
 lhs_mask += batch_offsets.x;
 rhs_mask += batch_offsets.y;
 }
 } else {
 out_mask += tid.z * batch_strides[2 * params->batch_ndim];
-if(has_operand_mask) {
+if (has_operand_mask) {
 lhs_mask += tid.z * batch_strides[3 * params->batch_ndim];
 rhs_mask += tid.z * batch_strides[4 * params->batch_ndim];
 }
 }
 
 // Adjust for batch
-if(params->batch_ndim > 1) {
+if (params->batch_ndim > 1) {
 const constant size_t* A_bstrides = batch_strides;
 const constant size_t* B_bstrides = batch_strides + params->batch_ndim;
 
@@ -99,11 +119,10 @@ template <typename T,
 B += transpose_b ? c_col * params->ldb : c_col;
 D += c_row * params->ldd + c_col;
 
-
 bool mask_out = out_mask[tid_y * mask_strides[1] + tid_x * mask_strides[0]];
 
 // Write zeros and return
-if(!mask_out) {
+if (!mask_out) {
 constexpr short tgp_size = WM * WN * 32;
 constexpr short vec_size = 4;
 
@@ -123,7 +142,7 @@ template <typename T,
 if (MN_aligned || (tgp_bm == BM && tgp_bn == BN)) {
 for (short ti = 0; ti < BM; ti += TM) {
 STEEL_PRAGMA_UNROLL
-for(short j = 0; j < vec_size; j++) {
+for (short j = 0; j < vec_size; j++) {
 D[ti * params->ldd + j] = T(0.);
 }
 }
@@ -131,7 +150,7 @@ template <typename T,
 short jmax = tgp_bn - bj;
 jmax = jmax < vec_size ? jmax : vec_size;
 for (short ti = 0; (bi + ti) < tgp_bm; ti += TM) {
-for(short j = 0; j < jmax; j++) {
+for (short j = 0; j < jmax; j++) {
 D[ti * params->ldd + j] = T(0.);
 }
 }
@@ -151,8 +170,10 @@ template <typename T,
 threadgroup T Bs[gemm_kernel::tgp_mem_size_b];
 
 // Prepare threadgroup loading operations
-thread typename gemm_kernel::loader_a_t loader_a(A, params->lda, As, simd_group_id, simd_lane_id);
-thread typename gemm_kernel::loader_b_t loader_b(B, params->ldb, Bs, simd_group_id, simd_lane_id);
+thread typename gemm_kernel::loader_a_t loader_a(
+A, params->lda, As, simd_group_id, simd_lane_id);
+thread typename gemm_kernel::loader_b_t loader_b(
+B, params->ldb, Bs, simd_group_id, simd_lane_id);
 
 ///////////////////////////////////////////////////////////////////////////////
 // MNK aligned loop
@@ -160,10 +181,11 @@ template <typename T,
 for (int k = 0; k < gemm_k_iterations; k++) {
 threadgroup_barrier(mem_flags::mem_threadgroup);
 
-if(!has_operand_mask ||
-(lhs_mask[tid_y * mask_strides[3] + ((k * BK) / BM) * mask_strides[2]] &&
-rhs_mask[((k * BK) / BM) * mask_strides[5] + tid_x * mask_strides[4]])) {
+if (!has_operand_mask ||
+(lhs_mask
+[tid_y * mask_strides[3] + ((k * BK) / BM) * mask_strides[2]] &&
+rhs_mask
+[((k * BK) / BM) * mask_strides[5] + tid_x * mask_strides[4]])) {
 // Load elements into threadgroup
 loader_a.load_unsafe();
 loader_b.load_unsafe();
@@ -172,7 +194,6 @@ template <typename T,
 
 // Multiply and accumulate threadgroup elements
 mma_op.mma(As, Bs);
-
 }
 
 // Prepare for next iteration
@@ -184,11 +205,12 @@ template <typename T,
 
 // Loop tail
 if (!K_aligned) {
-if(!has_operand_mask ||
-(lhs_mask[tid_y * mask_strides[3] + (params->K / BM) * mask_strides[2]] &&
-rhs_mask[(params->K / BM) * mask_strides[5] + tid_x * mask_strides[4]])) {
+if (!has_operand_mask ||
+(lhs_mask
+[tid_y * mask_strides[3] + (params->K / BM) * mask_strides[2]] &&
+rhs_mask
+[(params->K / BM) * mask_strides[5] +
+tid_x * mask_strides[4]])) {
 int lbk = params->K - params->gemm_k_iterations_aligned * BK;
 short2 tile_dims_A = transpose_a ? short2(BM, lbk) : short2(lbk, BM);
 short2 tile_dims_B = transpose_b ? short2(lbk, BN) : short2(BN, lbk);
@@ -199,7 +221,6 @@ template <typename T,
 threadgroup_barrier(mem_flags::mem_threadgroup);
 
 mma_op.mma(As, Bs);
-
 }
 }
 
@@ -223,10 +244,11 @@ template <typename T,
 
 for (int k = 0; k < gemm_k_iterations; k++) {
 threadgroup_barrier(mem_flags::mem_threadgroup);
-if(!has_operand_mask ||
-(lhs_mask[tid_y * mask_strides[3] + ((k * BK) / BM) * mask_strides[2]] &&
-rhs_mask[((k * BK) / BM) * mask_strides[5] + tid_x * mask_strides[4]])) {
+if (!has_operand_mask ||
+(lhs_mask
+[tid_y * mask_strides[3] + ((k * BK) / BM) * mask_strides[2]] &&
+rhs_mask
+[((k * BK) / BM) * mask_strides[5] + tid_x * mask_strides[4]])) {
 // Load elements into threadgroup
 if (M_aligned) {
 loader_a.load_unsafe();
@@ -244,7 +266,6 @@ template <typename T,
 
 // Multiply and accumulate threadgroup elements
 mma_op.mma(As, Bs);
-
 }
 
 // Prepare for next iteration
@@ -255,10 +276,12 @@ template <typename T,
 if (!K_aligned) {
 threadgroup_barrier(mem_flags::mem_threadgroup);
 
-if(!has_operand_mask ||
-(lhs_mask[tid_y * mask_strides[3] + (params->K / BM) * mask_strides[2]] &&
-rhs_mask[(params->K / BM) * mask_strides[5] + tid_x * mask_strides[4]])) {
+if (!has_operand_mask ||
+(lhs_mask
+[tid_y * mask_strides[3] + (params->K / BM) * mask_strides[2]] &&
+rhs_mask
+[(params->K / BM) * mask_strides[5] +
+tid_x * mask_strides[4]])) {
 short2 tile_dims_A_last =
 transpose_a ? short2(tgp_bm, lbk) : short2(lbk, tgp_bm);
 short2 tile_dims_B_last =
@@ -270,11 +293,10 @@ template <typename T,
 threadgroup_barrier(mem_flags::mem_threadgroup);
 
 mma_op.mma(As, Bs);
-
 }
 }
 
-if(M_aligned && N_aligned) {
+if (M_aligned && N_aligned) {
 mma_op.store_result(D, params->ldd);
 } else {
 mma_op.store_result_safe(D, params->ldd, short2(tgp_bn, tgp_bm));
@@ -286,44 +308,81 @@ template <typename T,
 // GEMM kernel initializations
 ///////////////////////////////////////////////////////////////////////////////
 
-#define instantiate_gemm(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, aname, mn_aligned, kname, k_aligned, omname, op_mask) \
-template [[host_name("steel_block_masked_gemm_" #tname "_" #iname "_" #oname "_bm" #bm "_bn" #bn "_bk" #bk "_wm" #wm "_wn" #wn "_MN_" #aname "_K_" #kname "_op_mask_" #omname)]] \
-[[kernel]] void block_masked_gemm<itype, bm, bn, bk, wm, wn, trans_a, trans_b, mn_aligned, k_aligned, op_mask>( \
-const device itype *A [[buffer(0)]], \
-const device itype *B [[buffer(1)]], \
-device itype *D [[buffer(3)]], \
+#define instantiate_gemm( \
+tname, \
+trans_a, \
+trans_b, \
+iname, \
+itype, \
+oname, \
+otype, \
+bm, \
+bn, \
+bk, \
+wm, \
+wn, \
+aname, \
+mn_aligned, \
+kname, \
+k_aligned, \
+omname, \
+op_mask) \
+template [[host_name("steel_block_masked_gemm_" #tname "_" #iname "_" #oname \
+"_bm" #bm "_bn" #bn "_bk" #bk "_wm" #wm "_wn" #wn \
+"_MN_" #aname "_K_" #kname \
+"_op_mask_" #omname)]] [[kernel]] void \
+block_masked_gemm< \
+itype, \
+bm, \
+bn, \
+bk, \
+wm, \
+wn, \
+trans_a, \
+trans_b, \
+mn_aligned, \
+k_aligned, \
+op_mask>( \
+const device itype* A [[buffer(0)]], \
+const device itype* B [[buffer(1)]], \
+device itype* D [[buffer(3)]], \
 const constant GEMMParams* params [[buffer(4)]], \
 const constant int* batch_shape [[buffer(6)]], \
 const constant size_t* batch_strides [[buffer(7)]], \
-const device bool *out_mask [[buffer(10)]], \
-const device bool *lhs_mask [[buffer(11)]], \
-const device bool *rhs_mask [[buffer(12)]], \
+const device bool* out_mask [[buffer(10)]], \
+const device bool* lhs_mask [[buffer(11)]], \
+const device bool* rhs_mask [[buffer(12)]], \
 const constant int* mask_strides [[buffer(13)]], \
 uint simd_lane_id [[thread_index_in_simdgroup]], \
 uint simd_group_id [[simdgroup_index_in_threadgroup]], \
 uint3 tid [[threadgroup_position_in_grid]], \
 uint3 lid [[thread_position_in_threadgroup]]);
 
+// clang-format off
 #define instantiate_gemm_mask_helper(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, aname, mn_aligned, kname, k_aligned) \
 instantiate_gemm(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, aname, mn_aligned, kname, k_aligned, N, false) \
-instantiate_gemm(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, aname, mn_aligned, kname, k_aligned, T, true)
+instantiate_gemm(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, aname, mn_aligned, kname, k_aligned, T, true) // clang-format on
 
+// clang-format off
 #define instantiate_gemm_aligned_helper(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn) \
 instantiate_gemm_mask_helper(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, taligned, true, taligned, true) \
 instantiate_gemm_mask_helper(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, taligned, true, naligned, false) \
 instantiate_gemm_mask_helper(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, naligned, false, taligned, true) \
-instantiate_gemm_mask_helper(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, naligned, false, naligned, false)
+instantiate_gemm_mask_helper(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, naligned, false, naligned, false) // clang-format on
 
+// clang-format off
 #define instantiate_gemm_transpose_helper(iname, itype, oname, otype, bm, bn, bk, wm, wn) \
 instantiate_gemm_aligned_helper(nn, false, false, iname, itype, oname, otype, bm, bn, bk, wm, wn) \
 instantiate_gemm_aligned_helper(nt, false, true , iname, itype, oname, otype, bm, bn, bk, wm, wn) \
 instantiate_gemm_aligned_helper(tn, true , false, iname, itype, oname, otype, bm, bn, bk, wm, wn) \
-instantiate_gemm_aligned_helper(tt, true , true , iname, itype, oname, otype, bm, bn, bk, wm, wn)
+instantiate_gemm_aligned_helper(tt, true , true , iname, itype, oname, otype, bm, bn, bk, wm, wn) // clang-format on
 
+// clang-format off
 #define instantiate_gemm_shapes_helper(iname, itype, oname, otype) \
 instantiate_gemm_transpose_helper(iname, itype, oname, otype, 32, 32, 16, 2, 2) \
-instantiate_gemm_transpose_helper(iname, itype, oname, otype, 64, 64, 16, 2, 2)
+instantiate_gemm_transpose_helper(iname, itype, oname, otype, 64, 64, 16, 2, 2) // clang-format on
 
+// clang-format off
 instantiate_gemm_shapes_helper(float16, half, float16, half);
 instantiate_gemm_shapes_helper(bfloat16, bfloat16_t, bfloat16, bfloat16_t);
-instantiate_gemm_shapes_helper(float32, float, float32, float);
+instantiate_gemm_shapes_helper(float32, float, float32, float); // clang-format on
@@ -10,7 +10,8 @@ using namespace mlx::steel;
 // GEMM kernels
 ///////////////////////////////////////////////////////////////////////////////
 
-template <typename T,
+template <
+typename T,
 typename U,
 int BM,
 int BN,
@@ -21,19 +22,29 @@ template <typename T,
 bool transpose_b,
 bool MN_aligned,
 bool K_aligned>
-[[kernel, max_total_threads_per_threadgroup(WM * WN * 32)]] void gemm_splitk(
-const device T *A [[buffer(0)]],
-const device T *B [[buffer(1)]],
-device U *C [[buffer(2)]],
+[[kernel, max_total_threads_per_threadgroup(WM* WN * 32)]] void gemm_splitk(
+const device T* A [[buffer(0)]],
+const device T* B [[buffer(1)]],
+device U* C [[buffer(2)]],
 const constant GEMMSpiltKParams* params [[buffer(3)]],
 uint simd_lane_id [[thread_index_in_simdgroup]],
 uint simd_group_id [[simdgroup_index_in_threadgroup]],
 uint3 tid [[threadgroup_position_in_grid]],
 uint3 lid [[thread_position_in_threadgroup]]) {
 
 (void)lid;
 
-using gemm_kernel = GEMMKernel<T, U, BM, BN, BK, WM, WN, transpose_a, transpose_b, MN_aligned, K_aligned>;
+using gemm_kernel = GEMMKernel<
+T,
+U,
+BM,
+BN,
+BK,
+WM,
+WN,
+transpose_a,
+transpose_b,
+MN_aligned,
+K_aligned>;
 using loader_a_t = typename gemm_kernel::loader_a_t;
 using loader_b_t = typename gemm_kernel::loader_b_t;
 using mma_t = typename gemm_kernel::mma_t;
@@ -54,9 +65,12 @@ template <typename T,
 const int c_col = tid_x * BN;
 const int k_start = params->split_k_partition_size * tid_z;
 
-A += transpose_a ? (c_row + k_start * params->lda) : (k_start + c_row * params->lda);
-B += transpose_b ? (k_start + c_col * params->ldb) : (c_col + k_start * params->ldb);
-C += (params->split_k_partition_stride * tid_z) + (c_row * params->ldc + c_col);
+A += transpose_a ? (c_row + k_start * params->lda)
+: (k_start + c_row * params->lda);
+B += transpose_b ? (k_start + c_col * params->ldb)
+: (c_col + k_start * params->ldb);
+C += (params->split_k_partition_stride * tid_z) +
+(c_row * params->ldc + c_col);
 
 // Prepare threadgroup loading operations
 thread loader_a_t loader_a(A, params->lda, As, simd_group_id, simd_lane_id);
@@ -71,7 +85,7 @@ template <typename T,
 short tgp_bn = min(BN, params->N - c_col);
 short leftover_bk = params->K % BK;
 
-if(MN_aligned || (tgp_bm == BM && tgp_bn == BN)) {
+if (MN_aligned || (tgp_bm == BM && tgp_bn == BN)) {
 gemm_kernel::gemm_loop(
 As,
 Bs,
@@ -124,8 +138,9 @@ template <typename T,
 threadgroup_barrier(mem_flags::mem_threadgroup);
 
 if ((tid_z + 1) == (params->split_k_partitions)) {
-int gemm_k_iter_remaining = (params->K - (k_start + params->split_k_partition_size)) / BK;
-if(!K_aligned || gemm_k_iter_remaining > 0)
+int gemm_k_iter_remaining =
+(params->K - (k_start + params->split_k_partition_size)) / BK;
+if (!K_aligned || gemm_k_iter_remaining > 0)
 gemm_kernel::gemm_loop(
 As,
 Bs,
@@ -139,7 +154,7 @@ template <typename T,
 LoopAlignment<false, false, K_aligned>{});
 }
 
-if(MN_aligned || (tgp_bm == BM && tgp_bn == BN)) {
+if (MN_aligned || (tgp_bm == BM && tgp_bn == BN)) {
 mma_op.store_result(C, params->ldc);
 } else {
 mma_op.store_result_safe(C, params->ldc, short2(tgp_bn, tgp_bm));
@@ -150,56 +165,89 @@ template <typename T,
 // GEMM kernel initializations
 ///////////////////////////////////////////////////////////////////////////////
 
-#define instantiate_gemm(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, aname, mn_aligned, kname, k_aligned) \
-template [[host_name("steel_gemm_splitk_" #tname "_" #iname "_" #oname "_bm" #bm "_bn" #bn "_bk" #bk "_wm" #wm "_wn" #wn "_MN_" #aname "_K_" #kname)]] \
-[[kernel]] void gemm_splitk<itype, otype, bm, bn, bk, wm, wn, trans_a, trans_b, mn_aligned, k_aligned>( \
-const device itype *A [[buffer(0)]], \
-const device itype *B [[buffer(1)]], \
-device otype *C [[buffer(2)]], \
+#define instantiate_gemm( \
+tname, \
+trans_a, \
+trans_b, \
+iname, \
+itype, \
+oname, \
+otype, \
+bm, \
+bn, \
+bk, \
+wm, \
+wn, \
+aname, \
+mn_aligned, \
+kname, \
+k_aligned) \
+template [[host_name("steel_gemm_splitk_" #tname "_" #iname "_" #oname \
+"_bm" #bm "_bn" #bn "_bk" #bk "_wm" #wm "_wn" #wn \
+"_MN_" #aname "_K_" #kname)]] [[kernel]] void \
+gemm_splitk< \
+itype, \
+otype, \
+bm, \
+bn, \
+bk, \
+wm, \
+wn, \
+trans_a, \
+trans_b, \
+mn_aligned, \
+k_aligned>( \
+const device itype* A [[buffer(0)]], \
+const device itype* B [[buffer(1)]], \
+device otype* C [[buffer(2)]], \
 const constant GEMMSpiltKParams* params [[buffer(3)]], \
 uint simd_lane_id [[thread_index_in_simdgroup]], \
 uint simd_group_id [[simdgroup_index_in_threadgroup]], \
 uint3 tid [[threadgroup_position_in_grid]], \
 uint3 lid [[thread_position_in_threadgroup]]);
 
+// clang-format off
 #define instantiate_gemm_aligned_helper(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn) \
 instantiate_gemm(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, taligned, true, taligned, true) \
 instantiate_gemm(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, taligned, true, naligned, false) \
 instantiate_gemm(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, naligned, false, taligned, true) \
-instantiate_gemm(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, naligned, false, naligned, false)
+instantiate_gemm(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, naligned, false, naligned, false) // clang-format on
 
+// clang-format off
 #define instantiate_gemm_transpose_helper(iname, itype, oname, otype, bm, bn, bk, wm, wn) \
 instantiate_gemm_aligned_helper(nn, false, false, iname, itype, oname, otype, bm, bn, bk, wm, wn) \
 instantiate_gemm_aligned_helper(nt, false, true , iname, itype, oname, otype, bm, bn, bk, wm, wn) \
 instantiate_gemm_aligned_helper(tn, true , false, iname, itype, oname, otype, bm, bn, bk, wm, wn) \
-instantiate_gemm_aligned_helper(tt, true , true , iname, itype, oname, otype, bm, bn, bk, wm, wn)
+instantiate_gemm_aligned_helper(tt, true , true , iname, itype, oname, otype, bm, bn, bk, wm, wn) // clang-format on
 
+// clang-format off
 #define instantiate_gemm_shapes_helper(iname, itype, oname, otype) \
 instantiate_gemm_transpose_helper(iname, itype, oname, otype, 16, 16, 16, 2, 2) \
 instantiate_gemm_transpose_helper(iname, itype, oname, otype, 16, 32, 16, 2, 2) \
 instantiate_gemm_transpose_helper(iname, itype, oname, otype, 32, 16, 16, 2, 2) \
-instantiate_gemm_transpose_helper(iname, itype, oname, otype, 32, 32, 16, 2, 2)
+instantiate_gemm_transpose_helper(iname, itype, oname, otype, 32, 32, 16, 2, 2) // clang-format on
 
+// clang-format off
 instantiate_gemm_shapes_helper(float16, half, float32, float);
 instantiate_gemm_shapes_helper(bfloat16, bfloat16_t, float32, float);
 
-instantiate_gemm_shapes_helper(float32, float, float32, float);
+instantiate_gemm_shapes_helper(float32, float, float32, float); // clang-format on
 
 ///////////////////////////////////////////////////////////////////////////////
 // Split k accumulation kernel
 ///////////////////////////////////////////////////////////////////////////////
 
-template <typename AccT,
+template <
+typename AccT,
 typename OutT,
 typename Epilogue = TransformNone<OutT, AccT>>
 [[kernel]] void gemm_splitk_accum(
-const device AccT *C_split [[buffer(0)]],
-device OutT *D [[buffer(1)]],
+const device AccT* C_split [[buffer(0)]],
+device OutT* D [[buffer(1)]],
 const constant int& k_partitions [[buffer(2)]],
 const constant int& partition_stride [[buffer(3)]],
 const constant int& ldd [[buffer(4)]],
 uint2 gid [[thread_position_in_grid]]) {
 
 // Ajust D and C
 D += gid.x + gid.y * ldd;
 C_split += gid.x + gid.y * ldd;
@@ -207,32 +255,31 @@ template <typename AccT,
 int offset = 0;
 AccT out = 0;
 
-for(int i = 0; i < k_partitions; i++) {
+for (int i = 0; i < k_partitions; i++) {
 out += C_split[offset];
 offset += partition_stride;
 }
 
 // Write output
 D[0] = Epilogue::apply(out);
-
 }
 
-template <typename AccT,
+template <
+typename AccT,
 typename OutT,
 typename Epilogue = TransformAxpby<OutT, AccT>>
 [[kernel]] void gemm_splitk_accum_axpby(
-const device AccT *C_split [[buffer(0)]],
-device OutT *D [[buffer(1)]],
+const device AccT* C_split [[buffer(0)]],
+device OutT* D [[buffer(1)]],
 const constant int& k_partitions [[buffer(2)]],
 const constant int& partition_stride [[buffer(3)]],
 const constant int& ldd [[buffer(4)]],
-const device OutT *C [[buffer(5)]],
+const device OutT* C [[buffer(5)]],
 const constant int& ldc [[buffer(6)]],
 const constant int& fdc [[buffer(7)]],
 const constant float& alpha [[buffer(8)]],
 const constant float& beta [[buffer(9)]],
 uint2 gid [[thread_position_in_grid]]) {
 
 // Ajust D and C
 C += gid.x * fdc + gid.y * ldc;
 D += gid.x + gid.y * ldd;
@@ -241,7 +288,7 @@ template <typename AccT,
 int offset = 0;
 AccT out = 0;
 
-for(int i = 0; i < k_partitions; i++) {
+for (int i = 0; i < k_partitions; i++) {
 out += C_split[offset];
 offset += partition_stride;
 }
@@ -249,32 +296,34 @@ template <typename AccT,
 // Write output
 Epilogue op(alpha, beta);
 D[0] = op.apply(out, *C);
-
 }
 
 #define instantiate_accum(oname, otype, aname, atype) \
-template [[host_name("steel_gemm_splitk_accum_" #oname "_" #aname)]] \
-[[kernel]] void gemm_splitk_accum<atype, otype>( \
-const device atype *C_split [[buffer(0)]], \
-device otype *D [[buffer(1)]], \
+template [[host_name("steel_gemm_splitk_accum_" #oname \
+"_" #aname)]] [[kernel]] void \
+gemm_splitk_accum<atype, otype>( \
+const device atype* C_split [[buffer(0)]], \
+device otype* D [[buffer(1)]], \
 const constant int& k_partitions [[buffer(2)]], \
 const constant int& partition_stride [[buffer(3)]], \
 const constant int& ldd [[buffer(4)]], \
 uint2 gid [[thread_position_in_grid]]); \
-template [[host_name("steel_gemm_splitk_accum_" #oname "_" #aname "_axpby")]] \
-[[kernel]] void gemm_splitk_accum_axpby<atype, otype>( \
-const device atype *C_split [[buffer(0)]], \
-device otype *D [[buffer(1)]], \
+template [[host_name("steel_gemm_splitk_accum_" #oname "_" #aname \
+"_axpby")]] [[kernel]] void \
+gemm_splitk_accum_axpby<atype, otype>( \
+const device atype* C_split [[buffer(0)]], \
+device otype* D [[buffer(1)]], \
 const constant int& k_partitions [[buffer(2)]], \
 const constant int& partition_stride [[buffer(3)]], \
 const constant int& ldd [[buffer(4)]], \
-const device otype *C [[buffer(5)]], \
+const device otype* C [[buffer(5)]], \
 const constant int& ldc [[buffer(6)]], \
 const constant int& fdc [[buffer(7)]], \
 const constant float& alpha [[buffer(8)]], \
 const constant float& beta [[buffer(9)]], \
 uint2 gid [[thread_position_in_grid]]);
 
+// clang-format off
 instantiate_accum(bfloat16, bfloat16_t, float32, float);
 instantiate_accum(float16, half, float32, float);
-instantiate_accum(float32, float, float32, float);
+instantiate_accum(float32, float, float32, float); // clang-format on
@@ -3,9 +3,9 @@
 #include <metal_integer>
 #include <metal_math>
 
-#include "mlx/backend/metal/kernels/utils.h"
 #include "mlx/backend/metal/kernels/bf16.h"
 #include "mlx/backend/metal/kernels/ternary.h"
+#include "mlx/backend/metal/kernels/utils.h"
 
 template <typename T, typename Op>
 [[kernel]] void ternary_op_v(
@@ -65,7 +65,8 @@ template <typename T, typename Op>
 auto a_idx = elem_to_loc_3(index, a_strides);
 auto b_idx = elem_to_loc_3(index, b_strides);
 auto c_idx = elem_to_loc_3(index, c_strides);
-size_t out_idx = index.x + (size_t)grid_dim.x * (index.y + (size_t)grid_dim.y * index.z);
+size_t out_idx =
+index.x + (size_t)grid_dim.x * (index.y + (size_t)grid_dim.y * index.z);
 d[out_idx] = Op()(a[a_idx], b[b_idx], c[c_idx]);
 }
 
@@ -81,8 +82,10 @@ template <typename T, typename Op, int DIM>
 constant const size_t c_strides[DIM],
 uint3 index [[thread_position_in_grid]],
 uint3 grid_dim [[threads_per_grid]]) {
-auto idx = elem_to_loc_3_nd<DIM>(index, shape, a_strides, b_strides, c_strides);
-size_t out_idx = index.x + (size_t)grid_dim.x * (index.y + (size_t)grid_dim.y * index.z);
+auto idx =
+elem_to_loc_3_nd<DIM>(index, shape, a_strides, b_strides, c_strides);
+size_t out_idx =
+index.x + (size_t)grid_dim.x * (index.y + (size_t)grid_dim.y * index.z);
 d[out_idx] = Op()(a[idx.x], b[idx.y], c[idx.z]);
 }
 
@@ -99,23 +102,22 @@ template <typename T, typename Op>
 constant const int& ndim,
 uint3 index [[thread_position_in_grid]],
 uint3 grid_dim [[threads_per_grid]]) {
-auto idx = elem_to_loc_3_nd(index, shape, a_strides, b_strides, c_strides, ndim);
+auto idx =
+elem_to_loc_3_nd(index, shape, a_strides, b_strides, c_strides, ndim);
 size_t out_idx = index.x + grid_dim.x * (index.y + grid_dim.y * index.z);
 d[out_idx] = Op()(a[idx.x], b[idx.y], c[idx.z]);
 }
 
 #define instantiate_ternary_v(name, type, op) \
-template [[host_name(name)]] \
-[[kernel]] void ternary_op_v<type, op>( \
+template [[host_name(name)]] [[kernel]] void ternary_op_v<type, op>( \
 device const bool* a, \
 device const type* b, \
 device const type* c, \
 device type* d, \
-uint index [[thread_position_in_grid]]); \
+uint index [[thread_position_in_grid]]);
 
 #define instantiate_ternary_g(name, type, op) \
-template [[host_name(name)]] \
-[[kernel]] void ternary_op_g<type, op>( \
+template [[host_name(name)]] [[kernel]] void ternary_op_g<type, op>( \
 device const bool* a, \
 device const type* b, \
 device const type* c, \
@@ -126,11 +128,11 @@ template <typename T, typename Op>
 constant const size_t* c_strides, \
 constant const int& ndim, \
 uint3 index [[thread_position_in_grid]], \
-uint3 grid_dim [[threads_per_grid]]); \
+uint3 grid_dim [[threads_per_grid]]);
 
 #define instantiate_ternary_g_dim(name, type, op, dims) \
-template [[host_name(name "_" #dims)]] \
-[[kernel]] void ternary_op_g_nd<type, op, dims>( \
+template [[host_name(name "_" #dims)]] [[kernel]] void \
+ternary_op_g_nd<type, op, dims>( \
 device const bool* a, \
 device const type* b, \
 device const type* c, \
@@ -140,11 +142,11 @@ template <typename T, typename Op>
 constant const size_t b_strides[dims], \
 constant const size_t c_strides[dims], \
 uint3 index [[thread_position_in_grid]], \
-uint3 grid_dim [[threads_per_grid]]); \
+uint3 grid_dim [[threads_per_grid]]);
 
 #define instantiate_ternary_g_nd(name, type, op) \
-template [[host_name(name "_1")]] \
-[[kernel]] void ternary_op_g_nd1<type, op>( \
+template [[host_name(name "_1")]] [[kernel]] void \
+ternary_op_g_nd1<type, op>( \
 device const bool* a, \
 device const type* b, \
 device const type* c, \
@@ -153,8 +155,8 @@ template <typename T, typename Op>
 constant const size_t& b_strides, \
 constant const size_t& c_strides, \
 uint index [[thread_position_in_grid]]); \
-template [[host_name(name "_2")]] \
-[[kernel]] void ternary_op_g_nd2<type, op>( \
+template [[host_name(name "_2")]] [[kernel]] void \
+ternary_op_g_nd2<type, op>( \
 device const bool* a, \
 device const type* b, \
 device const type* c, \
@@ -164,8 +166,8 @@ template <typename T, typename Op>
 constant const size_t c_strides[2], \
 uint2 index [[thread_position_in_grid]], \
 uint2 grid_dim [[threads_per_grid]]); \
-template [[host_name(name "_3")]] \
-[[kernel]] void ternary_op_g_nd3<type, op>( \
+template [[host_name(name "_3")]] [[kernel]] void \
+ternary_op_g_nd3<type, op>( \
 device const bool* a, \
 device const type* b, \
 device const type* c, \
@@ -176,13 +178,15 @@ template <typename T, typename Op>
 uint3 index [[thread_position_in_grid]], \
 uint3 grid_dim [[threads_per_grid]]); \
 instantiate_ternary_g_dim(name, type, op, 4) \
-instantiate_ternary_g_dim(name, type, op, 5) \
+instantiate_ternary_g_dim(name, type, op, 5)
 
+// clang-format off
 #define instantiate_ternary_all(name, tname, type, op) \
 instantiate_ternary_v("v" #name #tname, type, op) \
 instantiate_ternary_g("g" #name #tname, type, op) \
-instantiate_ternary_g_nd("g" #name #tname, type, op) \
+instantiate_ternary_g_nd("g" #name #tname, type, op) // clang-format on
 
+// clang-format off
 #define instantiate_ternary_types(name, op) \
 instantiate_ternary_all(name, bool_, bool, op) \
 instantiate_ternary_all(name, uint8, uint8_t, op) \
@@ -196,6 +200,6 @@ template <typename T, typename Op>
 instantiate_ternary_all(name, float16, half, op) \
 instantiate_ternary_all(name, float32, float, op) \
 instantiate_ternary_all(name, bfloat16, bfloat16_t, op) \
-instantiate_ternary_all(name, complex64, complex64_t, op) \
+instantiate_ternary_all(name, complex64, complex64_t, op) // clang-format on
 
 instantiate_ternary_types(select, Select)
@@ -23,15 +23,13 @@ template <typename T, typename Op>
 }
 
 #define instantiate_unary_v(name, type, op) \
-template [[host_name(name)]] \
-[[kernel]] void unary_op_v<type, op>( \
+template [[host_name(name)]] [[kernel]] void unary_op_v<type, op>( \
 device const type* in, \
 device type* out, \
 uint index [[thread_position_in_grid]]);
 
 #define instantiate_unary_g(name, type, op) \
-template [[host_name(name)]] \
-[[kernel]] void unary_op_g<type, op>( \
+template [[host_name(name)]] [[kernel]] void unary_op_g<type, op>( \
 device const type* in, \
 device type* out, \
 device const int* in_shape, \
@@ -39,15 +37,18 @@ template <typename T, typename Op>
 device const int& ndim, \
 uint index [[thread_position_in_grid]]);
 
+// clang-format off
 #define instantiate_unary_all(name, tname, type, op) \
 instantiate_unary_v("v" #name #tname, type, op) \
-instantiate_unary_g("g" #name #tname, type, op)
+instantiate_unary_g("g" #name #tname, type, op) // clang-format on
 
+// clang-format off
 #define instantiate_unary_float(name, op) \
 instantiate_unary_all(name, float16, half, op) \
 instantiate_unary_all(name, float32, float, op) \
-instantiate_unary_all(name, bfloat16, bfloat16_t, op) \
+instantiate_unary_all(name, bfloat16, bfloat16_t, op) // clang-format on
 
+// clang-format off
 #define instantiate_unary_types(name, op) \
 instantiate_unary_all(name, bool_, bool, op) \
 instantiate_unary_all(name, uint8, uint8_t, op) \
@@ -58,8 +59,9 @@ template <typename T, typename Op>
 instantiate_unary_all(name, int16, int16_t, op) \
 instantiate_unary_all(name, int32, int32_t, op) \
 instantiate_unary_all(name, int64, int64_t, op) \
-instantiate_unary_float(name, op)
+instantiate_unary_float(name, op) // clang-format on
 
+// clang-format off
 instantiate_unary_types(abs, Abs)
 instantiate_unary_float(arccos, ArcCos)
 instantiate_unary_float(arccosh, ArcCosh)
@@ -102,4 +104,4 @@ instantiate_unary_all(tan, complex64, complex64_t, Tan)
 instantiate_unary_all(tanh, complex64, complex64_t, Tanh)
 instantiate_unary_all(round, complex64, complex64_t, Round)
 
-instantiate_unary_all(lnot, bool_, bool, LogicalNot)
+instantiate_unary_all(lnot, bool_, bool, LogicalNot) // clang-format on
@@ -13,7 +13,7 @@ struct Device {
 static constexpr DeviceType cpu = DeviceType::cpu;
 static constexpr DeviceType gpu = DeviceType::gpu;
 
-Device(DeviceType type, int index = 0) : type(type), index(index){};
+Device(DeviceType type, int index = 0) : type(type), index(index) {};
 
 DeviceType type;
 int index;
@@ -51,7 +51,7 @@ struct Dtype {
 
 Val val;
 const uint8_t size;
-constexpr explicit Dtype(Val val, uint8_t size) : val(val), size(size){};
+constexpr explicit Dtype(Val val, uint8_t size) : val(val), size(size) {};
 constexpr operator Val() const {
 return val;
 };
@@ -10,7 +10,7 @@ namespace mlx::core {
 
 class Event {
 public:
-Event(){};
+Event() {};
 
 Event(const Stream& steam);
 
@@ -12,7 +12,7 @@ class Custom : public Primitive {
 explicit Custom(
 Stream stream,
 std::function<std::vector<array>(std::vector<array>)> fallback)
-: Primitive(stream), fallback_(fallback){};
+: Primitive(stream), fallback_(fallback) {};
 
 virtual std::pair<std::vector<array>, std::vector<int>> vmap(
 const std::vector<array>& inputs,
@@ -39,7 +39,7 @@ class RMSNorm : public Custom {
 Stream stream,
 std::function<std::vector<array>(std::vector<array>)> fallback,
 float eps)
-: Custom(stream, fallback), eps_(eps){};
+: Custom(stream, fallback), eps_(eps) {};
 
 void eval_cpu(const std::vector<array>& inputs, std::vector<array>& outputs)
 override {
@@ -68,7 +68,7 @@ class RMSNormVJP : public Custom {
 Stream stream,
 std::function<std::vector<array>(std::vector<array>)> fallback,
 float eps)
-: Custom(stream, fallback), eps_(eps){};
+: Custom(stream, fallback), eps_(eps) {};
 
 void eval_cpu(const std::vector<array>& inputs, std::vector<array>& outputs)
 override {
@@ -91,7 +91,7 @@ class LayerNorm : public Custom {
 Stream stream,
 std::function<std::vector<array>(std::vector<array>)> fallback,
 float eps)
-: Custom(stream, fallback), eps_(eps){};
+: Custom(stream, fallback), eps_(eps) {};
 
 void eval_cpu(const std::vector<array>& inputs, std::vector<array>& outputs)
 override {
@@ -120,7 +120,7 @@ class LayerNormVJP : public Custom {
 Stream stream,
 std::function<std::vector<array>(std::vector<array>)> fallback,
 float eps)
-: Custom(stream, fallback), eps_(eps){};
+: Custom(stream, fallback), eps_(eps) {};
 
 void eval_cpu(const std::vector<array>& inputs, std::vector<array>& outputs)
 override {
@@ -154,7 +154,7 @@ class RoPE : public Custom {
 base_(base),
 scale_(scale),
 offset_(offset),
|
||||||
forward_(forward){};
|
forward_(forward) {};
|
||||||
|
|
||||||
void eval_cpu(const std::vector<array>& inputs, std::vector<array>& outputs)
|
void eval_cpu(const std::vector<array>& inputs, std::vector<array>& outputs)
|
||||||
override {
|
override {
|
||||||
@ -189,7 +189,7 @@ class ScaledDotProductAttention : public Custom {
|
|||||||
std::function<std::vector<array>(std::vector<array>)> fallback,
|
std::function<std::vector<array>(std::vector<array>)> fallback,
|
||||||
const float scale,
|
const float scale,
|
||||||
const bool needs_mask)
|
const bool needs_mask)
|
||||||
: Custom(stream, fallback), scale_(scale), needs_mask_(needs_mask){};
|
: Custom(stream, fallback), scale_(scale), needs_mask_(needs_mask) {};
|
||||||
|
|
||||||
void eval_cpu(const std::vector<array>& inputs, std::vector<array>& outputs)
|
void eval_cpu(const std::vector<array>& inputs, std::vector<array>& outputs)
|
||||||
override {
|
override {
|
||||||
|
168
mlx/primitives.h
168
mlx/primitives.h
@ -154,7 +154,7 @@ class UnaryPrimitive : public Primitive {
|
|||||||
|
|
||||||
class Abs : public UnaryPrimitive {
|
class Abs : public UnaryPrimitive {
|
||||||
public:
|
public:
|
||||||
explicit Abs(Stream stream) : UnaryPrimitive(stream){};
|
explicit Abs(Stream stream) : UnaryPrimitive(stream) {};
|
||||||
|
|
||||||
void eval_cpu(const std::vector<array>& inputs, array& out) override;
|
void eval_cpu(const std::vector<array>& inputs, array& out) override;
|
||||||
void eval_gpu(const std::vector<array>& inputs, array& out) override;
|
void eval_gpu(const std::vector<array>& inputs, array& out) override;
|
||||||
@ -171,7 +171,7 @@ class Abs : public UnaryPrimitive {
|
|||||||
|
|
||||||
class Add : public UnaryPrimitive {
|
class Add : public UnaryPrimitive {
|
||||||
public:
|
public:
|
||||||
explicit Add(Stream stream) : UnaryPrimitive(stream){};
|
explicit Add(Stream stream) : UnaryPrimitive(stream) {};
|
||||||
|
|
||||||
void eval_cpu(const std::vector<array>& inputs, array& out) override;
|
void eval_cpu(const std::vector<array>& inputs, array& out) override;
|
||||||
void eval_gpu(const std::vector<array>& inputs, array& out) override;
|
void eval_gpu(const std::vector<array>& inputs, array& out) override;
|
||||||
@ -189,7 +189,7 @@ class Add : public UnaryPrimitive {
|
|||||||
class AddMM : public UnaryPrimitive {
|
class AddMM : public UnaryPrimitive {
|
||||||
public:
|
public:
|
||||||
explicit AddMM(Stream stream, float alpha, float beta)
|
explicit AddMM(Stream stream, float alpha, float beta)
|
||||||
: UnaryPrimitive(stream), alpha_(alpha), beta_(beta){};
|
: UnaryPrimitive(stream), alpha_(alpha), beta_(beta) {};
|
||||||
|
|
||||||
void eval_cpu(const std::vector<array>& inputs, array& out) override;
|
void eval_cpu(const std::vector<array>& inputs, array& out) override;
|
||||||
void eval_gpu(const std::vector<array>& inputs, array& out) override;
|
void eval_gpu(const std::vector<array>& inputs, array& out) override;
|
||||||
@ -213,7 +213,7 @@ class AddMM : public UnaryPrimitive {
|
|||||||
class Arange : public UnaryPrimitive {
|
class Arange : public UnaryPrimitive {
|
||||||
public:
|
public:
|
||||||
explicit Arange(Stream stream, double start, double stop, double step)
|
explicit Arange(Stream stream, double start, double stop, double step)
|
||||||
: UnaryPrimitive(stream), start_(start), stop_(stop), step_(step){};
|
: UnaryPrimitive(stream), start_(start), stop_(stop), step_(step) {};
|
||||||
|
|
||||||
void eval_cpu(const std::vector<array>& inputs, array& out) override;
|
void eval_cpu(const std::vector<array>& inputs, array& out) override;
|
||||||
void eval_gpu(const std::vector<array>& inputs, array& out) override;
|
void eval_gpu(const std::vector<array>& inputs, array& out) override;
|
||||||
@ -231,7 +231,7 @@ class Arange : public UnaryPrimitive {
|
|||||||
|
|
||||||
class ArcCos : public UnaryPrimitive {
|
class ArcCos : public UnaryPrimitive {
|
||||||
public:
|
public:
|
||||||
explicit ArcCos(Stream stream) : UnaryPrimitive(stream){};
|
explicit ArcCos(Stream stream) : UnaryPrimitive(stream) {};
|
||||||
|
|
||||||
void eval_cpu(const std::vector<array>& inputs, array& out) override;
|
void eval_cpu(const std::vector<array>& inputs, array& out) override;
|
||||||
void eval_gpu(const std::vector<array>& inputs, array& out) override;
|
void eval_gpu(const std::vector<array>& inputs, array& out) override;
|
||||||
@ -248,7 +248,7 @@ class ArcCos : public UnaryPrimitive {
|
|||||||
|
|
||||||
class ArcCosh : public UnaryPrimitive {
|
class ArcCosh : public UnaryPrimitive {
|
||||||
public:
|
public:
|
||||||
explicit ArcCosh(Stream stream) : UnaryPrimitive(stream){};
|
explicit ArcCosh(Stream stream) : UnaryPrimitive(stream) {};
|
||||||
|
|
||||||
void eval_cpu(const std::vector<array>& inputs, array& out) override;
|
void eval_cpu(const std::vector<array>& inputs, array& out) override;
|
||||||
void eval_gpu(const std::vector<array>& inputs, array& out) override;
|
void eval_gpu(const std::vector<array>& inputs, array& out) override;
|
||||||
@ -265,7 +265,7 @@ class ArcCosh : public UnaryPrimitive {
|
|||||||
|
|
||||||
class ArcSin : public UnaryPrimitive {
|
class ArcSin : public UnaryPrimitive {
|
||||||
public:
|
public:
|
||||||
explicit ArcSin(Stream stream) : UnaryPrimitive(stream){};
|
explicit ArcSin(Stream stream) : UnaryPrimitive(stream) {};
|
||||||
|
|
||||||
void eval_cpu(const std::vector<array>& inputs, array& out) override;
|
void eval_cpu(const std::vector<array>& inputs, array& out) override;
|
||||||
void eval_gpu(const std::vector<array>& inputs, array& out) override;
|
void eval_gpu(const std::vector<array>& inputs, array& out) override;
|
||||||
@ -282,7 +282,7 @@ class ArcSin : public UnaryPrimitive {
|
|||||||
|
|
||||||
class ArcSinh : public UnaryPrimitive {
|
class ArcSinh : public UnaryPrimitive {
|
||||||
public:
|
public:
|
||||||
explicit ArcSinh(Stream stream) : UnaryPrimitive(stream){};
|
explicit ArcSinh(Stream stream) : UnaryPrimitive(stream) {};
|
||||||
|
|
||||||
void eval_cpu(const std::vector<array>& inputs, array& out) override;
|
void eval_cpu(const std::vector<array>& inputs, array& out) override;
|
||||||
void eval_gpu(const std::vector<array>& inputs, array& out) override;
|
void eval_gpu(const std::vector<array>& inputs, array& out) override;
|
||||||
@ -299,7 +299,7 @@ class ArcSinh : public UnaryPrimitive {
|
|||||||
|
|
||||||
class ArcTan : public UnaryPrimitive {
|
class ArcTan : public UnaryPrimitive {
|
||||||
public:
|
public:
|
||||||
explicit ArcTan(Stream stream) : UnaryPrimitive(stream){};
|
explicit ArcTan(Stream stream) : UnaryPrimitive(stream) {};
|
||||||
|
|
||||||
void eval_cpu(const std::vector<array>& inputs, array& out) override;
|
void eval_cpu(const std::vector<array>& inputs, array& out) override;
|
||||||
void eval_gpu(const std::vector<array>& inputs, array& out) override;
|
void eval_gpu(const std::vector<array>& inputs, array& out) override;
|
||||||
@ -316,7 +316,7 @@ class ArcTan : public UnaryPrimitive {
|
|||||||
|
|
||||||
class ArcTanh : public UnaryPrimitive {
|
class ArcTanh : public UnaryPrimitive {
|
||||||
public:
|
public:
|
||||||
explicit ArcTanh(Stream stream) : UnaryPrimitive(stream){};
|
explicit ArcTanh(Stream stream) : UnaryPrimitive(stream) {};
|
||||||
|
|
||||||
void eval_cpu(const std::vector<array>& inputs, array& out) override;
|
void eval_cpu(const std::vector<array>& inputs, array& out) override;
|
||||||
void eval_gpu(const std::vector<array>& inputs, array& out) override;
|
void eval_gpu(const std::vector<array>& inputs, array& out) override;
|
||||||
@ -334,7 +334,7 @@ class ArcTanh : public UnaryPrimitive {
|
|||||||
class ArgPartition : public UnaryPrimitive {
|
class ArgPartition : public UnaryPrimitive {
|
||||||
public:
|
public:
|
||||||
explicit ArgPartition(Stream stream, int kth, int axis)
|
explicit ArgPartition(Stream stream, int kth, int axis)
|
||||||
: UnaryPrimitive(stream), kth_(kth), axis_(axis){};
|
: UnaryPrimitive(stream), kth_(kth), axis_(axis) {};
|
||||||
|
|
||||||
void eval_cpu(const std::vector<array>& inputs, array& out) override;
|
void eval_cpu(const std::vector<array>& inputs, array& out) override;
|
||||||
void eval_gpu(const std::vector<array>& inputs, array& out) override;
|
void eval_gpu(const std::vector<array>& inputs, array& out) override;
|
||||||
@ -359,7 +359,7 @@ class ArgReduce : public UnaryPrimitive {
|
|||||||
};
|
};
|
||||||
|
|
||||||
explicit ArgReduce(Stream stream, ReduceType reduce_type, int axis)
|
explicit ArgReduce(Stream stream, ReduceType reduce_type, int axis)
|
||||||
: UnaryPrimitive(stream), reduce_type_(reduce_type), axis_(axis){};
|
: UnaryPrimitive(stream), reduce_type_(reduce_type), axis_(axis) {};
|
||||||
|
|
||||||
void eval_cpu(const std::vector<array>& inputs, array& out) override;
|
void eval_cpu(const std::vector<array>& inputs, array& out) override;
|
||||||
void eval_gpu(const std::vector<array>& inputs, array& out) override;
|
void eval_gpu(const std::vector<array>& inputs, array& out) override;
|
||||||
@ -380,7 +380,7 @@ class ArgReduce : public UnaryPrimitive {
|
|||||||
class ArgSort : public UnaryPrimitive {
|
class ArgSort : public UnaryPrimitive {
|
||||||
public:
|
public:
|
||||||
explicit ArgSort(Stream stream, int axis)
|
explicit ArgSort(Stream stream, int axis)
|
||||||
: UnaryPrimitive(stream), axis_(axis){};
|
: UnaryPrimitive(stream), axis_(axis) {};
|
||||||
|
|
||||||
void eval_cpu(const std::vector<array>& inputs, array& out) override;
|
void eval_cpu(const std::vector<array>& inputs, array& out) override;
|
||||||
void eval_gpu(const std::vector<array>& inputs, array& out) override;
|
void eval_gpu(const std::vector<array>& inputs, array& out) override;
|
||||||
@ -399,7 +399,7 @@ class ArgSort : public UnaryPrimitive {
|
|||||||
class AsType : public UnaryPrimitive {
|
class AsType : public UnaryPrimitive {
|
||||||
public:
|
public:
|
||||||
explicit AsType(Stream stream, Dtype dtype)
|
explicit AsType(Stream stream, Dtype dtype)
|
||||||
: UnaryPrimitive(stream), dtype_(dtype){};
|
: UnaryPrimitive(stream), dtype_(dtype) {};
|
||||||
|
|
||||||
void eval_cpu(const std::vector<array>& inputs, array& out) override;
|
void eval_cpu(const std::vector<array>& inputs, array& out) override;
|
||||||
void eval_gpu(const std::vector<array>& inputs, array& out) override;
|
void eval_gpu(const std::vector<array>& inputs, array& out) override;
|
||||||
@ -426,7 +426,7 @@ class AsStrided : public UnaryPrimitive {
|
|||||||
: UnaryPrimitive(stream),
|
: UnaryPrimitive(stream),
|
||||||
shape_(std::move(shape)),
|
shape_(std::move(shape)),
|
||||||
strides_(std::move(strides)),
|
strides_(std::move(strides)),
|
||||||
offset_(offset){};
|
offset_(offset) {};
|
||||||
|
|
||||||
void eval_cpu(const std::vector<array>& inputs, array& out) override;
|
void eval_cpu(const std::vector<array>& inputs, array& out) override;
|
||||||
void eval_gpu(const std::vector<array>& inputs, array& out) override;
|
void eval_gpu(const std::vector<array>& inputs, array& out) override;
|
||||||
@ -448,7 +448,7 @@ class BitwiseBinary : public UnaryPrimitive {
|
|||||||
enum Op { And, Or, Xor, LeftShift, RightShift };
|
enum Op { And, Or, Xor, LeftShift, RightShift };
|
||||||
|
|
||||||
explicit BitwiseBinary(Stream stream, Op op)
|
explicit BitwiseBinary(Stream stream, Op op)
|
||||||
: UnaryPrimitive(stream), op_(op){};
|
: UnaryPrimitive(stream), op_(op) {};
|
||||||
|
|
||||||
void eval_cpu(const std::vector<array>& inputs, array& out) override;
|
void eval_cpu(const std::vector<array>& inputs, array& out) override;
|
||||||
void eval_gpu(const std::vector<array>& inputs, array& out) override;
|
void eval_gpu(const std::vector<array>& inputs, array& out) override;
|
||||||
@ -465,7 +465,7 @@ class BitwiseBinary : public UnaryPrimitive {
|
|||||||
class BlockMaskedMM : public UnaryPrimitive {
|
class BlockMaskedMM : public UnaryPrimitive {
|
||||||
public:
|
public:
|
||||||
explicit BlockMaskedMM(Stream stream, int block_size)
|
explicit BlockMaskedMM(Stream stream, int block_size)
|
||||||
: UnaryPrimitive(stream), block_size_(block_size){};
|
: UnaryPrimitive(stream), block_size_(block_size) {};
|
||||||
|
|
||||||
void eval_cpu(const std::vector<array>& inputs, array& out) override;
|
void eval_cpu(const std::vector<array>& inputs, array& out) override;
|
||||||
void eval_gpu(const std::vector<array>& inputs, array& out) override;
|
void eval_gpu(const std::vector<array>& inputs, array& out) override;
|
||||||
@ -488,7 +488,7 @@ class BlockMaskedMM : public UnaryPrimitive {
|
|||||||
class Broadcast : public UnaryPrimitive {
|
class Broadcast : public UnaryPrimitive {
|
||||||
public:
|
public:
|
||||||
explicit Broadcast(Stream stream, const std::vector<int>& shape)
|
explicit Broadcast(Stream stream, const std::vector<int>& shape)
|
||||||
: UnaryPrimitive(stream), shape_(shape){};
|
: UnaryPrimitive(stream), shape_(shape) {};
|
||||||
|
|
||||||
void eval_cpu(const std::vector<array>& inputs, array& out) override;
|
void eval_cpu(const std::vector<array>& inputs, array& out) override;
|
||||||
void eval_gpu(const std::vector<array>& inputs, array& out) override;
|
void eval_gpu(const std::vector<array>& inputs, array& out) override;
|
||||||
@ -506,7 +506,7 @@ class Broadcast : public UnaryPrimitive {
|
|||||||
|
|
||||||
class Ceil : public UnaryPrimitive {
|
class Ceil : public UnaryPrimitive {
|
||||||
public:
|
public:
|
||||||
explicit Ceil(Stream stream) : UnaryPrimitive(stream){};
|
explicit Ceil(Stream stream) : UnaryPrimitive(stream) {};
|
||||||
|
|
||||||
void eval_cpu(const std::vector<array>& inputs, array& out) override;
|
void eval_cpu(const std::vector<array>& inputs, array& out) override;
|
||||||
void eval_gpu(const std::vector<array>& inputs, array& out) override;
|
void eval_gpu(const std::vector<array>& inputs, array& out) override;
|
||||||
@ -567,7 +567,7 @@ class Compiled : public Primitive {
|
|||||||
class Concatenate : public UnaryPrimitive {
|
class Concatenate : public UnaryPrimitive {
|
||||||
public:
|
public:
|
||||||
explicit Concatenate(Stream stream, int axis)
|
explicit Concatenate(Stream stream, int axis)
|
||||||
: UnaryPrimitive(stream), axis_(axis){};
|
: UnaryPrimitive(stream), axis_(axis) {};
|
||||||
|
|
||||||
void eval_cpu(const std::vector<array>& inputs, array& out) override;
|
void eval_cpu(const std::vector<array>& inputs, array& out) override;
|
||||||
void eval_gpu(const std::vector<array>& inputs, array& out) override;
|
void eval_gpu(const std::vector<array>& inputs, array& out) override;
|
||||||
@ -599,7 +599,7 @@ class Convolution : public UnaryPrimitive {
|
|||||||
kernel_dilation_(kernel_dilation),
|
kernel_dilation_(kernel_dilation),
|
||||||
input_dilation_(input_dilation),
|
input_dilation_(input_dilation),
|
||||||
groups_(groups),
|
groups_(groups),
|
||||||
flip_(flip){};
|
flip_(flip) {};
|
||||||
|
|
||||||
void eval_cpu(const std::vector<array>& inputs, array& out) override;
|
void eval_cpu(const std::vector<array>& inputs, array& out) override;
|
||||||
void eval_gpu(const std::vector<array>& inputs, array& out) override;
|
void eval_gpu(const std::vector<array>& inputs, array& out) override;
|
||||||
@ -626,7 +626,7 @@ class Convolution : public UnaryPrimitive {
|
|||||||
|
|
||||||
class Copy : public UnaryPrimitive {
|
class Copy : public UnaryPrimitive {
|
||||||
public:
|
public:
|
||||||
explicit Copy(Stream stream) : UnaryPrimitive(stream){};
|
explicit Copy(Stream stream) : UnaryPrimitive(stream) {};
|
||||||
|
|
||||||
void eval_cpu(const std::vector<array>& inputs, array& out) override;
|
void eval_cpu(const std::vector<array>& inputs, array& out) override;
|
||||||
void eval_gpu(const std::vector<array>& inputs, array& out) override;
|
void eval_gpu(const std::vector<array>& inputs, array& out) override;
|
||||||
@ -643,7 +643,7 @@ class Copy : public UnaryPrimitive {
|
|||||||
|
|
||||||
class Cos : public UnaryPrimitive {
|
class Cos : public UnaryPrimitive {
|
||||||
public:
|
public:
|
||||||
explicit Cos(Stream stream) : UnaryPrimitive(stream){};
|
explicit Cos(Stream stream) : UnaryPrimitive(stream) {};
|
||||||
|
|
||||||
void eval_cpu(const std::vector<array>& inputs, array& out) override;
|
void eval_cpu(const std::vector<array>& inputs, array& out) override;
|
||||||
void eval_gpu(const std::vector<array>& inputs, array& out) override;
|
void eval_gpu(const std::vector<array>& inputs, array& out) override;
|
||||||
@ -660,7 +660,7 @@ class Cos : public UnaryPrimitive {
|
|||||||
|
|
||||||
class Cosh : public UnaryPrimitive {
|
class Cosh : public UnaryPrimitive {
|
||||||
public:
|
public:
|
||||||
explicit Cosh(Stream stream) : UnaryPrimitive(stream){};
|
explicit Cosh(Stream stream) : UnaryPrimitive(stream) {};
|
||||||
|
|
||||||
void eval_cpu(const std::vector<array>& inputs, array& out) override;
|
void eval_cpu(const std::vector<array>& inputs, array& out) override;
|
||||||
void eval_gpu(const std::vector<array>& inputs, array& out) override;
|
void eval_gpu(const std::vector<array>& inputs, array& out) override;
|
||||||
@ -731,7 +731,7 @@ class Depends : public Primitive {
|
|||||||
|
|
||||||
class Divide : public UnaryPrimitive {
|
class Divide : public UnaryPrimitive {
|
||||||
public:
|
public:
|
||||||
explicit Divide(Stream stream) : UnaryPrimitive(stream){};
|
explicit Divide(Stream stream) : UnaryPrimitive(stream) {};
|
||||||
|
|
||||||
void eval_cpu(const std::vector<array>& inputs, array& out) override;
|
void eval_cpu(const std::vector<array>& inputs, array& out) override;
|
||||||
void eval_gpu(const std::vector<array>& inputs, array& out) override;
|
void eval_gpu(const std::vector<array>& inputs, array& out) override;
|
||||||
@ -748,7 +748,7 @@ class Divide : public UnaryPrimitive {
|
|||||||
|
|
||||||
class DivMod : public Primitive {
|
class DivMod : public Primitive {
|
||||||
public:
|
public:
|
||||||
explicit DivMod(Stream stream) : Primitive(stream){};
|
explicit DivMod(Stream stream) : Primitive(stream) {};
|
||||||
|
|
||||||
void eval_cpu(const std::vector<array>& inputs, std::vector<array>& outputs)
|
void eval_cpu(const std::vector<array>& inputs, std::vector<array>& outputs)
|
||||||
override;
|
override;
|
||||||
@ -770,7 +770,7 @@ class DivMod : public Primitive {
|
|||||||
|
|
||||||
class Select : public UnaryPrimitive {
|
class Select : public UnaryPrimitive {
|
||||||
public:
|
public:
|
||||||
explicit Select(Stream stream) : UnaryPrimitive(stream){};
|
explicit Select(Stream stream) : UnaryPrimitive(stream) {};
|
||||||
|
|
||||||
void eval_cpu(const std::vector<array>& inputs, array& out) override;
|
void eval_cpu(const std::vector<array>& inputs, array& out) override;
|
||||||
void eval_gpu(const std::vector<array>& inputs, array& out) override;
|
void eval_gpu(const std::vector<array>& inputs, array& out) override;
|
||||||
@ -787,7 +787,7 @@ class Select : public UnaryPrimitive {
|
|||||||
|
|
||||||
class Remainder : public UnaryPrimitive {
|
class Remainder : public UnaryPrimitive {
|
||||||
public:
|
public:
|
||||||
explicit Remainder(Stream stream) : UnaryPrimitive(stream){};
|
explicit Remainder(Stream stream) : UnaryPrimitive(stream) {};
|
||||||
|
|
||||||
void eval_cpu(const std::vector<array>& inputs, array& out) override;
|
void eval_cpu(const std::vector<array>& inputs, array& out) override;
|
||||||
void eval_gpu(const std::vector<array>& inputs, array& out) override;
|
void eval_gpu(const std::vector<array>& inputs, array& out) override;
|
||||||
@ -805,7 +805,7 @@ class Remainder : public UnaryPrimitive {
|
|||||||
class Equal : public UnaryPrimitive {
|
class Equal : public UnaryPrimitive {
|
||||||
public:
|
public:
|
||||||
explicit Equal(Stream stream, bool equal_nan = false)
|
explicit Equal(Stream stream, bool equal_nan = false)
|
||||||
: UnaryPrimitive(stream), equal_nan_(equal_nan){};
|
: UnaryPrimitive(stream), equal_nan_(equal_nan) {};
|
||||||
|
|
||||||
void eval_cpu(const std::vector<array>& inputs, array& out) override;
|
void eval_cpu(const std::vector<array>& inputs, array& out) override;
|
||||||
void eval_gpu(const std::vector<array>& inputs, array& out) override;
|
void eval_gpu(const std::vector<array>& inputs, array& out) override;
|
||||||
@ -830,7 +830,7 @@ class Equal : public UnaryPrimitive {
|
|||||||
|
|
||||||
class Erf : public UnaryPrimitive {
|
class Erf : public UnaryPrimitive {
|
||||||
public:
|
public:
|
||||||
explicit Erf(Stream stream) : UnaryPrimitive(stream){};
|
explicit Erf(Stream stream) : UnaryPrimitive(stream) {};
|
||||||
|
|
||||||
void eval_cpu(const std::vector<array>& inputs, array& out) override;
|
void eval_cpu(const std::vector<array>& inputs, array& out) override;
|
||||||
void eval_gpu(const std::vector<array>& inputs, array& out) override;
|
void eval_gpu(const std::vector<array>& inputs, array& out) override;
|
||||||
@ -847,7 +847,7 @@ class Erf : public UnaryPrimitive {
|
|||||||
|
|
||||||
class ErfInv : public UnaryPrimitive {
|
class ErfInv : public UnaryPrimitive {
|
||||||
public:
|
public:
|
||||||
explicit ErfInv(Stream stream) : UnaryPrimitive(stream){};
|
explicit ErfInv(Stream stream) : UnaryPrimitive(stream) {};
|
||||||
|
|
||||||
void eval_cpu(const std::vector<array>& inputs, array& out) override;
|
void eval_cpu(const std::vector<array>& inputs, array& out) override;
|
||||||
void eval_gpu(const std::vector<array>& inputs, array& out) override;
|
void eval_gpu(const std::vector<array>& inputs, array& out) override;
|
||||||
@ -864,7 +864,7 @@ class ErfInv : public UnaryPrimitive {
|
|||||||
|
|
||||||
class Exp : public UnaryPrimitive {
|
class Exp : public UnaryPrimitive {
|
||||||
public:
|
public:
|
||||||
explicit Exp(Stream stream) : UnaryPrimitive(stream){};
|
explicit Exp(Stream stream) : UnaryPrimitive(stream) {};
|
||||||
|
|
||||||
void eval_cpu(const std::vector<array>& inputs, array& out) override;
|
void eval_cpu(const std::vector<array>& inputs, array& out) override;
|
||||||
void eval_gpu(const std::vector<array>& inputs, array& out) override;
|
void eval_gpu(const std::vector<array>& inputs, array& out) override;
|
||||||
@ -881,7 +881,7 @@ class Exp : public UnaryPrimitive {
|
|||||||
|
|
||||||
class Expm1 : public UnaryPrimitive {
|
class Expm1 : public UnaryPrimitive {
|
||||||
public:
|
public:
|
||||||
explicit Expm1(Stream stream) : UnaryPrimitive(stream){};
|
explicit Expm1(Stream stream) : UnaryPrimitive(stream) {};
|
||||||
|
|
||||||
void eval_cpu(const std::vector<array>& inputs, array& out) override;
|
void eval_cpu(const std::vector<array>& inputs, array& out) override;
|
||||||
void eval_gpu(const std::vector<array>& inputs, array& out) override;
|
void eval_gpu(const std::vector<array>& inputs, array& out) override;
|
||||||
@ -902,7 +902,7 @@ class FFT : public UnaryPrimitive {
|
|||||||
const std::vector<size_t>& axes,
|
const std::vector<size_t>& axes,
|
||||||
bool inverse,
|
bool inverse,
|
||||||
bool real)
|
bool real)
|
||||||
: UnaryPrimitive(stream), axes_(axes), inverse_(inverse), real_(real){};
|
: UnaryPrimitive(stream), axes_(axes), inverse_(inverse), real_(real) {};
|
||||||
|
|
||||||
void eval_cpu(const std::vector<array>& inputs, array& out) override;
|
void eval_cpu(const std::vector<array>& inputs, array& out) override;
|
||||||
void eval_gpu(const std::vector<array>& inputs, array& out) override;
|
void eval_gpu(const std::vector<array>& inputs, array& out) override;
|
||||||
@ -923,7 +923,7 @@ class FFT : public UnaryPrimitive {
|
|||||||
|
|
||||||
class Floor : public UnaryPrimitive {
|
class Floor : public UnaryPrimitive {
|
||||||
public:
|
public:
|
||||||
explicit Floor(Stream stream) : UnaryPrimitive(stream){};
|
explicit Floor(Stream stream) : UnaryPrimitive(stream) {};
|
||||||
|
|
||||||
void eval_cpu(const std::vector<array>& inputs, array& out) override;
|
void eval_cpu(const std::vector<array>& inputs, array& out) override;
|
||||||
void eval_gpu(const std::vector<array>& inputs, array& out) override;
|
void eval_gpu(const std::vector<array>& inputs, array& out) override;
|
||||||
@ -940,7 +940,7 @@ class Floor : public UnaryPrimitive {
|
|||||||
|
|
||||||
class Full : public UnaryPrimitive {
|
class Full : public UnaryPrimitive {
|
||||||
public:
|
public:
|
||||||
explicit Full(Stream stream) : UnaryPrimitive(stream){};
|
explicit Full(Stream stream) : UnaryPrimitive(stream) {};
|
||||||
|
|
||||||
void eval_cpu(const std::vector<array>& inputs, array& out) override;
|
void eval_cpu(const std::vector<array>& inputs, array& out) override;
|
||||||
void eval_gpu(const std::vector<array>& inputs, array& out) override;
|
void eval_gpu(const std::vector<array>& inputs, array& out) override;
|
||||||
@ -960,7 +960,7 @@ class Gather : public UnaryPrimitive {
|
|||||||
Stream stream,
|
Stream stream,
|
||||||
const std::vector<int>& axes,
|
const std::vector<int>& axes,
|
||||||
const std::vector<int>& slice_sizes)
|
const std::vector<int>& slice_sizes)
|
||||||
: UnaryPrimitive(stream), axes_(axes), slice_sizes_(slice_sizes){};
|
: UnaryPrimitive(stream), axes_(axes), slice_sizes_(slice_sizes) {};
|
||||||
|
|
||||||
void eval_cpu(const std::vector<array>& inputs, array& out) override;
|
void eval_cpu(const std::vector<array>& inputs, array& out) override;
|
||||||
void eval_gpu(const std::vector<array>& inputs, array& out) override;
|
void eval_gpu(const std::vector<array>& inputs, array& out) override;
|
||||||
@ -978,7 +978,7 @@ class Gather : public UnaryPrimitive {
|
|||||||
|
|
||||||
class Greater : public UnaryPrimitive {
|
class Greater : public UnaryPrimitive {
|
||||||
public:
|
public:
|
||||||
explicit Greater(Stream stream) : UnaryPrimitive(stream){};
|
explicit Greater(Stream stream) : UnaryPrimitive(stream) {};
|
||||||
|
|
||||||
void eval_cpu(const std::vector<array>& inputs, array& out) override;
|
void eval_cpu(const std::vector<array>& inputs, array& out) override;
|
||||||
void eval_gpu(const std::vector<array>& inputs, array& out) override;
|
void eval_gpu(const std::vector<array>& inputs, array& out) override;
|
||||||
@ -995,7 +995,7 @@ class Greater : public UnaryPrimitive {
|
|||||||
|
|
||||||
class GreaterEqual : public UnaryPrimitive {
|
class GreaterEqual : public UnaryPrimitive {
|
||||||
public:
|
public:
|
||||||
explicit GreaterEqual(Stream stream) : UnaryPrimitive(stream){};
|
explicit GreaterEqual(Stream stream) : UnaryPrimitive(stream) {};
|
||||||
|
|
||||||
void eval_cpu(const std::vector<array>& inputs, array& out) override;
|
void eval_cpu(const std::vector<array>& inputs, array& out) override;
|
||||||
void eval_gpu(const std::vector<array>& inputs, array& out) override;
|
void eval_gpu(const std::vector<array>& inputs, array& out) override;
|
||||||
@ -1012,7 +1012,7 @@ class GreaterEqual : public UnaryPrimitive {
|
|||||||
|
|
||||||
class Less : public UnaryPrimitive {
|
class Less : public UnaryPrimitive {
|
||||||
public:
|
public:
|
||||||
explicit Less(Stream stream) : UnaryPrimitive(stream){};
|
explicit Less(Stream stream) : UnaryPrimitive(stream) {};
|
||||||
|
|
||||||
void eval_cpu(const std::vector<array>& inputs, array& out) override;
|
void eval_cpu(const std::vector<array>& inputs, array& out) override;
|
||||||
void eval_gpu(const std::vector<array>& inputs, array& out) override;
|
void eval_gpu(const std::vector<array>& inputs, array& out) override;
|
||||||
@ -1029,7 +1029,7 @@ class Less : public UnaryPrimitive {
|
|||||||
|
|
||||||
class LessEqual : public UnaryPrimitive {
|
class LessEqual : public UnaryPrimitive {
|
||||||
public:
|
public:
|
||||||
explicit LessEqual(Stream stream) : UnaryPrimitive(stream){};
|
explicit LessEqual(Stream stream) : UnaryPrimitive(stream) {};
|
||||||
|
|
||||||
void eval_cpu(const std::vector<array>& inputs, array& out) override;
|
void eval_cpu(const std::vector<array>& inputs, array& out) override;
|
||||||
void eval_gpu(const std::vector<array>& inputs, array& out) override;
|
void eval_gpu(const std::vector<array>& inputs, array& out) override;
|
||||||
@ -1054,7 +1054,7 @@ class Load : public UnaryPrimitive {
|
|||||||
: UnaryPrimitive(stream),
|
: UnaryPrimitive(stream),
|
||||||
reader_(reader),
|
reader_(reader),
|
||||||
offset_(offset),
|
offset_(offset),
|
||||||
swap_endianness_(swap_endianness){};
|
swap_endianness_(swap_endianness) {};
|
||||||
|
|
||||||
void eval_cpu(const std::vector<array>& inputs, array& out) override;
|
void eval_cpu(const std::vector<array>& inputs, array& out) override;
|
||||||
void eval_gpu(const std::vector<array>& inputs, array& out) override;
|
void eval_gpu(const std::vector<array>& inputs, array& out) override;
|
||||||
@ -1073,7 +1073,7 @@ class Log : public UnaryPrimitive {
|
|||||||
enum Base { two, ten, e };
|
enum Base { two, ten, e };
|
||||||
|
|
||||||
explicit Log(Stream stream, Base base)
|
explicit Log(Stream stream, Base base)
|
||||||
: UnaryPrimitive(stream), base_(base){};
|
: UnaryPrimitive(stream), base_(base) {};
|
||||||
|
|
||||||
void eval_cpu(const std::vector<array>& inputs, array& out) override;
|
void eval_cpu(const std::vector<array>& inputs, array& out) override;
|
||||||
void eval_gpu(const std::vector<array>& inputs, array& out) override;
|
void eval_gpu(const std::vector<array>& inputs, array& out) override;
|
||||||
@ -1104,7 +1104,7 @@ class Log : public UnaryPrimitive {
|
|||||||
|
|
||||||
class Log1p : public UnaryPrimitive {
|
class Log1p : public UnaryPrimitive {
|
||||||
public:
|
public:
|
||||||
explicit Log1p(Stream stream) : UnaryPrimitive(stream){};
|
explicit Log1p(Stream stream) : UnaryPrimitive(stream) {};
|
||||||
|
|
||||||
void eval_cpu(const std::vector<array>& inputs, array& out) override;
|
void eval_cpu(const std::vector<array>& inputs, array& out) override;
|
||||||
void eval_gpu(const std::vector<array>& inputs, array& out) override;
|
void eval_gpu(const std::vector<array>& inputs, array& out) override;
|
||||||
@ -1120,7 +1120,7 @@ class Log1p : public UnaryPrimitive {
|
|||||||
|
|
||||||
class LogicalNot : public UnaryPrimitive {
|
class LogicalNot : public UnaryPrimitive {
|
||||||
public:
|
public:
|
||||||
explicit LogicalNot(Stream stream) : UnaryPrimitive(stream){};
|
explicit LogicalNot(Stream stream) : UnaryPrimitive(stream) {};
|
||||||
|
|
||||||
void eval_cpu(const std::vector<array>& inputs, array& out) override;
|
void eval_cpu(const std::vector<array>& inputs, array& out) override;
|
||||||
void eval_gpu(const std::vector<array>& inputs, array& out) override;
|
void eval_gpu(const std::vector<array>& inputs, array& out) override;
|
||||||
@ -1137,7 +1137,7 @@ class LogicalNot : public UnaryPrimitive {
|
|||||||
|
|
||||||
class LogicalAnd : public UnaryPrimitive {
|
class LogicalAnd : public UnaryPrimitive {
|
||||||
public:
|
public:
|
||||||
explicit LogicalAnd(Stream stream) : UnaryPrimitive(stream){};
|
explicit LogicalAnd(Stream stream) : UnaryPrimitive(stream) {};
|
||||||
|
|
||||||
void eval_cpu(const std::vector<array>& inputs, array& out) override;
|
void eval_cpu(const std::vector<array>& inputs, array& out) override;
|
||||||
void eval_gpu(const std::vector<array>& inputs, array& out) override;
|
void eval_gpu(const std::vector<array>& inputs, array& out) override;
|
||||||
@ -1154,7 +1154,7 @@ class LogicalAnd : public UnaryPrimitive {
|
|||||||
|
|
||||||
class LogicalOr : public UnaryPrimitive {
|
class LogicalOr : public UnaryPrimitive {
|
||||||
public:
|
public:
|
||||||
explicit LogicalOr(Stream stream) : UnaryPrimitive(stream){};
|
explicit LogicalOr(Stream stream) : UnaryPrimitive(stream) {};
|
||||||
|
|
||||||
void eval_cpu(const std::vector<array>& inputs, array& out) override;
|
void eval_cpu(const std::vector<array>& inputs, array& out) override;
|
||||||
void eval_gpu(const std::vector<array>& inputs, array& out) override;
|
void eval_gpu(const std::vector<array>& inputs, array& out) override;
|
||||||
@ -1171,7 +1171,7 @@ class LogicalOr : public UnaryPrimitive {
|
|||||||
|
|
||||||
class LogAddExp : public UnaryPrimitive {
|
class LogAddExp : public UnaryPrimitive {
|
||||||
public:
|
public:
|
||||||
explicit LogAddExp(Stream stream) : UnaryPrimitive(stream){};
|
explicit LogAddExp(Stream stream) : UnaryPrimitive(stream) {};
|
||||||
|
|
||||||
void eval_cpu(const std::vector<array>& inputs, array& out) override;
|
void eval_cpu(const std::vector<array>& inputs, array& out) override;
|
||||||
void eval_gpu(const std::vector<array>& inputs, array& out) override;
|
void eval_gpu(const std::vector<array>& inputs, array& out) override;
|
||||||
@ -1188,7 +1188,7 @@ class LogAddExp : public UnaryPrimitive {
|
|||||||
|
|
||||||
class Matmul : public UnaryPrimitive {
|
class Matmul : public UnaryPrimitive {
|
||||||
public:
|
public:
|
||||||
explicit Matmul(Stream stream) : UnaryPrimitive(stream){};
|
explicit Matmul(Stream stream) : UnaryPrimitive(stream) {};
|
||||||
|
|
||||||
void eval_cpu(const std::vector<array>& inputs, array& out) override;
|
void eval_cpu(const std::vector<array>& inputs, array& out) override;
|
||||||
void eval_gpu(const std::vector<array>& inputs, array& out) override;
|
void eval_gpu(const std::vector<array>& inputs, array& out) override;
|
||||||
@ -1206,7 +1206,7 @@ class Matmul : public UnaryPrimitive {
|
|||||||
|
|
||||||
class Maximum : public UnaryPrimitive {
|
class Maximum : public UnaryPrimitive {
|
||||||
public:
|
public:
|
||||||
explicit Maximum(Stream stream) : UnaryPrimitive(stream){};
|
explicit Maximum(Stream stream) : UnaryPrimitive(stream) {};
|
||||||
|
|
||||||
void eval_cpu(const std::vector<array>& inputs, array& out) override;
|
void eval_cpu(const std::vector<array>& inputs, array& out) override;
|
||||||
void eval_gpu(const std::vector<array>& inputs, array& out) override;
|
void eval_gpu(const std::vector<array>& inputs, array& out) override;
|
||||||
@ -1223,7 +1223,7 @@ class Maximum : public UnaryPrimitive {
|
|||||||
|
|
||||||
class Minimum : public UnaryPrimitive {
|
class Minimum : public UnaryPrimitive {
|
||||||
public:
|
public:
|
||||||
explicit Minimum(Stream stream) : UnaryPrimitive(stream){};
|
explicit Minimum(Stream stream) : UnaryPrimitive(stream) {};
|
||||||
|
|
||||||
void eval_cpu(const std::vector<array>& inputs, array& out) override;
|
void eval_cpu(const std::vector<array>& inputs, array& out) override;
|
||||||
void eval_gpu(const std::vector<array>& inputs, array& out) override;
|
void eval_gpu(const std::vector<array>& inputs, array& out) override;
|
||||||
@ -1240,7 +1240,7 @@ class Minimum : public UnaryPrimitive {
|
|||||||
|
|
||||||
class Multiply : public UnaryPrimitive {
|
class Multiply : public UnaryPrimitive {
|
||||||
public:
|
public:
|
||||||
explicit Multiply(Stream stream) : UnaryPrimitive(stream){};
|
explicit Multiply(Stream stream) : UnaryPrimitive(stream) {};
|
||||||
|
|
||||||
void eval_cpu(const std::vector<array>& inputs, array& out) override;
|
void eval_cpu(const std::vector<array>& inputs, array& out) override;
|
||||||
void eval_gpu(const std::vector<array>& inputs, array& out) override;
|
void eval_gpu(const std::vector<array>& inputs, array& out) override;
|
||||||
@ -1257,7 +1257,7 @@ class Multiply : public UnaryPrimitive {
|
|||||||
|
|
||||||
class Negative : public UnaryPrimitive {
|
class Negative : public UnaryPrimitive {
|
||||||
public:
|
public:
|
||||||
explicit Negative(Stream stream) : UnaryPrimitive(stream){};
|
explicit Negative(Stream stream) : UnaryPrimitive(stream) {};
|
||||||
|
|
||||||
void eval_cpu(const std::vector<array>& inputs, array& out) override;
|
void eval_cpu(const std::vector<array>& inputs, array& out) override;
|
||||||
void eval_gpu(const std::vector<array>& inputs, array& out) override;
|
void eval_gpu(const std::vector<array>& inputs, array& out) override;
|
||||||
@ -1274,7 +1274,7 @@ class Negative : public UnaryPrimitive {
|
|||||||
|
|
||||||
class NotEqual : public UnaryPrimitive {
|
class NotEqual : public UnaryPrimitive {
|
||||||
public:
|
public:
|
||||||
explicit NotEqual(Stream stream) : UnaryPrimitive(stream){};
|
explicit NotEqual(Stream stream) : UnaryPrimitive(stream) {};
|
||||||
|
|
||||||
void eval_cpu(const std::vector<array>& inputs, array& out) override;
|
void eval_cpu(const std::vector<array>& inputs, array& out) override;
|
||||||
void eval_gpu(const std::vector<array>& inputs, array& out) override;
|
void eval_gpu(const std::vector<array>& inputs, array& out) override;
|
||||||
@ -1330,7 +1330,7 @@ class Pad : public UnaryPrimitive {
|
|||||||
: UnaryPrimitive(stream),
|
: UnaryPrimitive(stream),
|
||||||
axes_(axes),
|
axes_(axes),
|
||||||
low_pad_size_(low_pad_size),
|
low_pad_size_(low_pad_size),
|
||||||
high_pad_size_(high_pad_size){};
|
high_pad_size_(high_pad_size) {};
|
||||||
|
|
||||||
void eval_cpu(const std::vector<array>& inputs, array& out) override;
|
void eval_cpu(const std::vector<array>& inputs, array& out) override;
|
||||||
void eval_gpu(const std::vector<array>& inputs, array& out) override;
|
void eval_gpu(const std::vector<array>& inputs, array& out) override;
|
||||||
@ -1351,7 +1351,7 @@ class Pad : public UnaryPrimitive {
|
|||||||
class Partition : public UnaryPrimitive {
|
class Partition : public UnaryPrimitive {
|
||||||
public:
|
public:
|
||||||
explicit Partition(Stream stream, int kth, int axis)
|
explicit Partition(Stream stream, int kth, int axis)
|
||||||
: UnaryPrimitive(stream), kth_(kth), axis_(axis){};
|
: UnaryPrimitive(stream), kth_(kth), axis_(axis) {};
|
||||||
|
|
||||||
void eval_cpu(const std::vector<array>& inputs, array& out) override;
|
void eval_cpu(const std::vector<array>& inputs, array& out) override;
|
||||||
void eval_gpu(const std::vector<array>& inputs, array& out) override;
|
void eval_gpu(const std::vector<array>& inputs, array& out) override;
|
||||||
@ -1371,7 +1371,7 @@ class Partition : public UnaryPrimitive {
|
|||||||
|
|
||||||
class Power : public UnaryPrimitive {
|
class Power : public UnaryPrimitive {
|
||||||
public:
|
public:
|
||||||
explicit Power(Stream stream) : UnaryPrimitive(stream){};
|
explicit Power(Stream stream) : UnaryPrimitive(stream) {};
|
||||||
|
|
||||||
void eval_cpu(const std::vector<array>& inputs, array& out) override;
|
void eval_cpu(const std::vector<array>& inputs, array& out) override;
|
||||||
void eval_gpu(const std::vector<array>& inputs, array& out) override;
|
void eval_gpu(const std::vector<array>& inputs, array& out) override;
|
||||||
@ -1396,7 +1396,7 @@ class QuantizedMatmul : public UnaryPrimitive {
|
|||||||
: UnaryPrimitive(stream),
|
: UnaryPrimitive(stream),
|
||||||
group_size_(group_size),
|
group_size_(group_size),
|
||||||
bits_(bits),
|
bits_(bits),
|
||||||
transpose_(transpose){};
|
transpose_(transpose) {};
|
||||||
|
|
||||||
void eval_cpu(const std::vector<array>& inputs, array& out) override;
|
void eval_cpu(const std::vector<array>& inputs, array& out) override;
|
||||||
void eval_gpu(const std::vector<array>& inputs, array& out) override;
|
void eval_gpu(const std::vector<array>& inputs, array& out) override;
|
||||||
@ -1417,7 +1417,7 @@ class QuantizedMatmul : public UnaryPrimitive {
|
|||||||
class RandomBits : public UnaryPrimitive {
|
class RandomBits : public UnaryPrimitive {
|
||||||
public:
|
public:
|
||||||
explicit RandomBits(Stream stream, const std::vector<int>& shape, int width)
|
explicit RandomBits(Stream stream, const std::vector<int>& shape, int width)
|
||||||
: UnaryPrimitive(stream), shape_(shape), width_(width){};
|
: UnaryPrimitive(stream), shape_(shape), width_(width) {};
|
||||||
|
|
||||||
void eval_cpu(const std::vector<array>& inputs, array& out) override;
|
void eval_cpu(const std::vector<array>& inputs, array& out) override;
|
||||||
void eval_gpu(const std::vector<array>& inputs, array& out) override;
|
void eval_gpu(const std::vector<array>& inputs, array& out) override;
|
||||||
@ -1436,7 +1436,7 @@ class RandomBits : public UnaryPrimitive {
|
|||||||
class Reshape : public UnaryPrimitive {
|
class Reshape : public UnaryPrimitive {
|
||||||
public:
|
public:
|
||||||
explicit Reshape(Stream stream, const std::vector<int>& shape)
|
explicit Reshape(Stream stream, const std::vector<int>& shape)
|
||||||
: UnaryPrimitive(stream), shape_(shape){};
|
: UnaryPrimitive(stream), shape_(shape) {};
|
||||||
|
|
||||||
void eval_cpu(const std::vector<array>& inputs, array& out) override;
|
void eval_cpu(const std::vector<array>& inputs, array& out) override;
|
||||||
void eval_gpu(const std::vector<array>& inputs, array& out) override;
|
void eval_gpu(const std::vector<array>& inputs, array& out) override;
|
||||||
@ -1468,7 +1468,7 @@ class Reduce : public UnaryPrimitive {
|
|||||||
Stream stream,
|
Stream stream,
|
||||||
ReduceType reduce_type,
|
ReduceType reduce_type,
|
||||||
const std::vector<int>& axes)
|
const std::vector<int>& axes)
|
||||||
: UnaryPrimitive(stream), reduce_type_(reduce_type), axes_(axes){};
|
: UnaryPrimitive(stream), reduce_type_(reduce_type), axes_(axes) {};
|
||||||
|
|
||||||
void eval_cpu(const std::vector<array>& inputs, array& out) override;
|
void eval_cpu(const std::vector<array>& inputs, array& out) override;
|
||||||
void eval_gpu(const std::vector<array>& inputs, array& out) override;
|
void eval_gpu(const std::vector<array>& inputs, array& out) override;
|
||||||
@ -1517,7 +1517,7 @@ class Reduce : public UnaryPrimitive {
|
|||||||
|
|
||||||
class Round : public UnaryPrimitive {
|
class Round : public UnaryPrimitive {
|
||||||
public:
|
public:
|
||||||
explicit Round(Stream stream) : UnaryPrimitive(stream){};
|
explicit Round(Stream stream) : UnaryPrimitive(stream) {};
|
||||||
|
|
||||||
void eval_cpu(const std::vector<array>& inputs, array& out) override;
|
void eval_cpu(const std::vector<array>& inputs, array& out) override;
|
||||||
void eval_gpu(const std::vector<array>& inputs, array& out) override;
|
void eval_gpu(const std::vector<array>& inputs, array& out) override;
|
||||||
@ -1546,7 +1546,7 @@ class Scan : public UnaryPrimitive {
|
|||||||
reduce_type_(reduce_type),
|
reduce_type_(reduce_type),
|
||||||
axis_(axis),
|
axis_(axis),
|
||||||
reverse_(reverse),
|
reverse_(reverse),
|
||||||
inclusive_(inclusive){};
|
inclusive_(inclusive) {};
|
||||||
|
|
||||||
void eval_cpu(const std::vector<array>& inputs, array& out) override;
|
void eval_cpu(const std::vector<array>& inputs, array& out) override;
|
||||||
void eval_gpu(const std::vector<array>& inputs, array& out) override;
|
void eval_gpu(const std::vector<array>& inputs, array& out) override;
|
||||||
@ -1591,7 +1591,7 @@ class Scatter : public UnaryPrimitive {
|
|||||||
Stream stream,
|
Stream stream,
|
||||||
ReduceType reduce_type,
|
ReduceType reduce_type,
|
||||||
const std::vector<int>& axes)
|
const std::vector<int>& axes)
|
||||||
: UnaryPrimitive(stream), reduce_type_(reduce_type), axes_(axes){};
|
: UnaryPrimitive(stream), reduce_type_(reduce_type), axes_(axes) {};
|
||||||
|
|
||||||
void eval_cpu(const std::vector<array>& inputs, array& out) override;
|
void eval_cpu(const std::vector<array>& inputs, array& out) override;
|
||||||
void eval_gpu(const std::vector<array>& inputs, array& out) override;
|
void eval_gpu(const std::vector<array>& inputs, array& out) override;
|
||||||
@ -1626,7 +1626,7 @@ class Scatter : public UnaryPrimitive {
|
|||||||
|
|
||||||
class Sigmoid : public UnaryPrimitive {
|
class Sigmoid : public UnaryPrimitive {
|
||||||
public:
|
public:
|
||||||
explicit Sigmoid(Stream stream) : UnaryPrimitive(stream){};
|
explicit Sigmoid(Stream stream) : UnaryPrimitive(stream) {};
|
||||||
|
|
||||||
void eval_cpu(const std::vector<array>& inputs, array& out) override;
|
void eval_cpu(const std::vector<array>& inputs, array& out) override;
|
||||||
void eval_gpu(const std::vector<array>& inputs, array& out) override;
|
void eval_gpu(const std::vector<array>& inputs, array& out) override;
|
||||||
@ -1643,7 +1643,7 @@ class Sigmoid : public UnaryPrimitive {
|
|||||||
|
|
||||||
class Sign : public UnaryPrimitive {
|
class Sign : public UnaryPrimitive {
|
||||||
public:
|
public:
|
||||||
explicit Sign(Stream stream) : UnaryPrimitive(stream){};
|
explicit Sign(Stream stream) : UnaryPrimitive(stream) {};
|
||||||
|
|
||||||
void eval_cpu(const std::vector<array>& inputs, array& out) override;
|
void eval_cpu(const std::vector<array>& inputs, array& out) override;
|
||||||
void eval_gpu(const std::vector<array>& inputs, array& out) override;
|
void eval_gpu(const std::vector<array>& inputs, array& out) override;
|
||||||
@ -1660,7 +1660,7 @@ class Sign : public UnaryPrimitive {
|
|||||||
|
|
||||||
class Sin : public UnaryPrimitive {
|
class Sin : public UnaryPrimitive {
|
||||||
public:
|
public:
|
||||||
explicit Sin(Stream stream) : UnaryPrimitive(stream){};
|
explicit Sin(Stream stream) : UnaryPrimitive(stream) {};
|
||||||
|
|
||||||
void eval_cpu(const std::vector<array>& inputs, array& out) override;
|
void eval_cpu(const std::vector<array>& inputs, array& out) override;
|
||||||
void eval_gpu(const std::vector<array>& inputs, array& out) override;
|
void eval_gpu(const std::vector<array>& inputs, array& out) override;
|
||||||
@ -1677,7 +1677,7 @@ class Sin : public UnaryPrimitive {
|
|||||||
|
|
||||||
class Sinh : public UnaryPrimitive {
|
class Sinh : public UnaryPrimitive {
|
||||||
public:
|
public:
|
||||||
explicit Sinh(Stream stream) : UnaryPrimitive(stream){};
|
explicit Sinh(Stream stream) : UnaryPrimitive(stream) {};
|
||||||
|
|
||||||
void eval_cpu(const std::vector<array>& inputs, array& out) override;
|
void eval_cpu(const std::vector<array>& inputs, array& out) override;
|
||||||
void eval_gpu(const std::vector<array>& inputs, array& out) override;
|
void eval_gpu(const std::vector<array>& inputs, array& out) override;
|
||||||
@ -1702,7 +1702,7 @@ class Slice : public UnaryPrimitive {
|
|||||||
: UnaryPrimitive(stream),
|
: UnaryPrimitive(stream),
|
||||||
start_indices_(start_indices),
|
start_indices_(start_indices),
|
||||||
end_indices_(end_indices),
|
end_indices_(end_indices),
|
||||||
strides_(strides){};
|
strides_(strides) {};
|
||||||
|
|
||||||
void eval_cpu(const std::vector<array>& inputs, array& out) override;
|
void eval_cpu(const std::vector<array>& inputs, array& out) override;
|
||||||
void eval_gpu(const std::vector<array>& inputs, array& out) override;
|
void eval_gpu(const std::vector<array>& inputs, array& out) override;
|
||||||
@ -1738,7 +1738,7 @@ class SliceUpdate : public UnaryPrimitive {
|
|||||||
: UnaryPrimitive(stream),
|
: UnaryPrimitive(stream),
|
||||||
start_indices_(start_indices),
|
start_indices_(start_indices),
|
||||||
end_indices_(end_indices),
|
end_indices_(end_indices),
|
||||||
strides_(strides){};
|
strides_(strides) {};
|
||||||
|
|
||||||
void eval_cpu(const std::vector<array>& inputs, array& out) override;
|
void eval_cpu(const std::vector<array>& inputs, array& out) override;
|
||||||
void eval_gpu(const std::vector<array>& inputs, array& out) override;
|
void eval_gpu(const std::vector<array>& inputs, array& out) override;
|
||||||
@ -1761,7 +1761,7 @@ class SliceUpdate : public UnaryPrimitive {
|
|||||||
class Softmax : public UnaryPrimitive {
|
class Softmax : public UnaryPrimitive {
|
||||||
public:
|
public:
|
||||||
explicit Softmax(Stream stream, bool precise)
|
explicit Softmax(Stream stream, bool precise)
|
||||||
: UnaryPrimitive(stream), precise_(precise){};
|
: UnaryPrimitive(stream), precise_(precise) {};
|
||||||
|
|
||||||
void eval_cpu(const std::vector<array>& inputs, array& out) override;
|
void eval_cpu(const std::vector<array>& inputs, array& out) override;
|
||||||
void eval_gpu(const std::vector<array>& inputs, array& out) override;
|
void eval_gpu(const std::vector<array>& inputs, array& out) override;
|
||||||
@ -1781,7 +1781,7 @@ class Softmax : public UnaryPrimitive {
|
|||||||
class Sort : public UnaryPrimitive {
|
class Sort : public UnaryPrimitive {
|
||||||
public:
|
public:
|
||||||
explicit Sort(Stream stream, int axis)
|
explicit Sort(Stream stream, int axis)
|
||||||
: UnaryPrimitive(stream), axis_(axis){};
|
: UnaryPrimitive(stream), axis_(axis) {};
|
||||||
|
|
||||||
void eval_cpu(const std::vector<array>& inputs, array& out) override;
|
void eval_cpu(const std::vector<array>& inputs, array& out) override;
|
||||||
void eval_gpu(const std::vector<array>& inputs, array& out) override;
|
void eval_gpu(const std::vector<array>& inputs, array& out) override;
|
||||||
@ -1801,7 +1801,7 @@ class Sort : public UnaryPrimitive {
|
|||||||
class Split : public Primitive {
|
class Split : public Primitive {
|
||||||
public:
|
public:
|
||||||
explicit Split(Stream stream, const std::vector<int>& indices, int axis)
|
explicit Split(Stream stream, const std::vector<int>& indices, int axis)
|
||||||
: Primitive(stream), indices_(indices), axis_(axis){};
|
: Primitive(stream), indices_(indices), axis_(axis) {};
|
||||||
|
|
||||||
void eval_cpu(const std::vector<array>& inputs, std::vector<array>& outputs)
|
void eval_cpu(const std::vector<array>& inputs, std::vector<array>& outputs)
|
||||||
override;
|
override;
|
||||||
@ -1822,7 +1822,7 @@ class Split : public Primitive {
|
|||||||
|
|
||||||
class Square : public UnaryPrimitive {
|
class Square : public UnaryPrimitive {
|
||||||
public:
|
public:
|
||||||
explicit Square(Stream stream) : UnaryPrimitive(stream){};
|
explicit Square(Stream stream) : UnaryPrimitive(stream) {};
|
||||||
|
|
||||||
void eval_cpu(const std::vector<array>& inputs, array& out) override;
|
void eval_cpu(const std::vector<array>& inputs, array& out) override;
|
||||||
void eval_gpu(const std::vector<array>& inputs, array& out) override;
|
void eval_gpu(const std::vector<array>& inputs, array& out) override;
|
||||||
@ -1840,7 +1840,7 @@ class Square : public UnaryPrimitive {
|
|||||||
class Sqrt : public UnaryPrimitive {
|
class Sqrt : public UnaryPrimitive {
|
||||||
public:
|
public:
|
||||||
explicit Sqrt(Stream stream, bool recip = false)
|
explicit Sqrt(Stream stream, bool recip = false)
|
||||||
: UnaryPrimitive(stream), recip_(recip){};
|
: UnaryPrimitive(stream), recip_(recip) {};
|
||||||
|
|
||||||
void eval_cpu(const std::vector<array>& inputs, array& out) override;
|
void eval_cpu(const std::vector<array>& inputs, array& out) override;
|
||||||
void eval_gpu(const std::vector<array>& inputs, array& out) override;
|
void eval_gpu(const std::vector<array>& inputs, array& out) override;
|
||||||
@ -1865,7 +1865,7 @@ class Sqrt : public UnaryPrimitive {
|
|||||||
|
|
||||||
class StopGradient : public UnaryPrimitive {
|
class StopGradient : public UnaryPrimitive {
|
||||||
public:
|
public:
|
||||||
explicit StopGradient(Stream stream) : UnaryPrimitive(stream){};
|
explicit StopGradient(Stream stream) : UnaryPrimitive(stream) {};
|
||||||
|
|
||||||
void eval_cpu(const std::vector<array>& inputs, array& out) override;
|
void eval_cpu(const std::vector<array>& inputs, array& out) override;
|
||||||
void eval_gpu(const std::vector<array>& inputs, array& out) override;
|
void eval_gpu(const std::vector<array>& inputs, array& out) override;
|
||||||
@ -1881,7 +1881,7 @@ class StopGradient : public UnaryPrimitive {
|
|||||||
|
|
||||||
class Subtract : public UnaryPrimitive {
|
class Subtract : public UnaryPrimitive {
|
||||||
public:
|
public:
|
||||||
explicit Subtract(Stream stream) : UnaryPrimitive(stream){};
|
explicit Subtract(Stream stream) : UnaryPrimitive(stream) {};
|
||||||
|
|
||||||
void eval_cpu(const std::vector<array>& inputs, array& out) override;
|
void eval_cpu(const std::vector<array>& inputs, array& out) override;
|
||||||
void eval_gpu(const std::vector<array>& inputs, array& out) override;
|
void eval_gpu(const std::vector<array>& inputs, array& out) override;
|
||||||
@ -1898,7 +1898,7 @@ class Subtract : public UnaryPrimitive {
|
|||||||
|
|
||||||
class Tan : public UnaryPrimitive {
|
class Tan : public UnaryPrimitive {
|
||||||
public:
|
public:
|
||||||
explicit Tan(Stream stream) : UnaryPrimitive(stream){};
|
explicit Tan(Stream stream) : UnaryPrimitive(stream) {};
|
||||||
|
|
||||||
void eval_cpu(const std::vector<array>& inputs, array& out) override;
|
void eval_cpu(const std::vector<array>& inputs, array& out) override;
|
||||||
void eval_gpu(const std::vector<array>& inputs, array& out) override;
|
void eval_gpu(const std::vector<array>& inputs, array& out) override;
|
||||||
@ -1915,7 +1915,7 @@ class Tan : public UnaryPrimitive {
|
|||||||
|
|
||||||
class Tanh : public UnaryPrimitive {
|
class Tanh : public UnaryPrimitive {
|
||||||
public:
|
public:
|
||||||
explicit Tanh(Stream stream) : UnaryPrimitive(stream){};
|
explicit Tanh(Stream stream) : UnaryPrimitive(stream) {};
|
||||||
|
|
||||||
void eval_cpu(const std::vector<array>& inputs, array& out) override;
|
void eval_cpu(const std::vector<array>& inputs, array& out) override;
|
||||||
void eval_gpu(const std::vector<array>& inputs, array& out) override;
|
void eval_gpu(const std::vector<array>& inputs, array& out) override;
|
||||||
@@ -1932,7 +1932,7 @@ class Tanh : public UnaryPrimitive {
 
 class Uniform : public UnaryPrimitive {
  public:
-  explicit Uniform(Stream stream) : UnaryPrimitive(stream){};
+  explicit Uniform(Stream stream) : UnaryPrimitive(stream) {};
 
   void eval_cpu(const std::vector<array>& inputs, array& out) override;
   void eval_gpu(const std::vector<array>& inputs, array& out) override;
@@ -1948,7 +1948,7 @@ class Uniform : public UnaryPrimitive {
 class Transpose : public UnaryPrimitive {
  public:
   explicit Transpose(Stream stream, const std::vector<int>& axes)
-      : UnaryPrimitive(stream), axes_(axes){};
+      : UnaryPrimitive(stream), axes_(axes) {};
 
   void eval_cpu(const std::vector<array>& inputs, array& out) override;
   void eval_gpu(const std::vector<array>& inputs, array& out) override;
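Transpose, like Sqrt above, carries per-instance state: `axes_` records the axis permutation, with output dimension i taking its extent from input dimension axes[i]. A hedged sketch of what such a permutation does to a shape (`permute_shape` is an illustrative helper, not MLX code):

#include <vector>

// e.g. permute_shape({2, 3, 4}, {2, 0, 1}) == {4, 2, 3}
std::vector<int> permute_shape(
    const std::vector<int>& shape,
    const std::vector<int>& axes) {
  std::vector<int> out(axes.size());
  for (size_t i = 0; i < axes.size(); ++i) {
    // Output dimension i takes its extent from input dimension axes[i].
    out[i] = shape[axes[i]];
  }
  return out;
}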
@@ -1967,7 +1967,7 @@ class Transpose : public UnaryPrimitive {
 /* QR Factorization primitive. */
 class QRF : public Primitive {
  public:
-  explicit QRF(Stream stream) : Primitive(stream){};
+  explicit QRF(Stream stream) : Primitive(stream) {};
 
   void eval_cpu(const std::vector<array>& inputs, std::vector<array>& outputs)
       override;
@@ -1983,7 +1983,7 @@ class QRF : public Primitive {
 /* SVD primitive. */
 class SVD : public Primitive {
  public:
-  explicit SVD(Stream stream) : Primitive(stream){};
+  explicit SVD(Stream stream) : Primitive(stream) {};
 
   void eval_cpu(const std::vector<array>& inputs, std::vector<array>& outputs)
       override;
@@ -2000,7 +2000,7 @@ class SVD : public Primitive {
 /* Matrix inversion primitive. */
 class Inverse : public UnaryPrimitive {
  public:
-  explicit Inverse(Stream stream) : UnaryPrimitive(stream){};
+  explicit Inverse(Stream stream) : UnaryPrimitive(stream) {};
 
   void eval_cpu(const std::vector<array>& inputs, array& output) override;
   void eval_gpu(const std::vector<array>& inputs, array& output) override;

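The linear-algebra primitives split by output arity: QRF and SVD derive from `Primitive` and write into a `std::vector<array>& outputs` because they produce several arrays (Q and R; U, S, and Vt), while Inverse derives from `UnaryPrimitive` and fills a single `array& output`. A hypothetical multi-output primitive following the same pattern (`Eig` is illustrative, not part of this diff):

/* Hypothetical eigendecomposition primitive, for illustration only. */
class Eig : public Primitive {
 public:
  explicit Eig(Stream stream) : Primitive(stream) {};

  // Multi-output primitives receive all outputs at once, matching QRF/SVD.
  void eval_cpu(const std::vector<array>& inputs, std::vector<array>& outputs)
      override;
  void eval_gpu(const std::vector<array>& inputs, std::vector<array>& outputs)
      override;
};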
@@ -22,7 +22,7 @@ namespace mlx::core {
  * for synchronizing with the main thread. */
 class Synchronizer : public Primitive {
  public:
-  explicit Synchronizer(Stream stream) : Primitive(stream){};
+  explicit Synchronizer(Stream stream) : Primitive(stream) {};
 
   void eval_cpu(const std::vector<array>&, std::vector<array>&) override {};
   void eval_gpu(const std::vector<array>&, std::vector<array>&) override {};

@@ -14,8 +14,8 @@ inline constexpr bool can_convert_to_complex128 =
     !std::is_same_v<T, complex128_t> && std::is_convertible_v<T, double>;
 
 struct complex128_t : public std::complex<double> {
-  complex128_t(double v, double u) : std::complex<double>(v, u){};
-  complex128_t(std::complex<double> v) : std::complex<double>(v){};
+  complex128_t(double v, double u) : std::complex<double>(v, u) {};
+  complex128_t(std::complex<double> v) : std::complex<double>(v) {};
 
   template <
       typename T,
@@ -32,8 +32,8 @@ inline constexpr bool can_convert_to_complex64 =
     !std::is_same_v<T, complex64_t> && std::is_convertible_v<T, float>;
 
 struct complex64_t : public std::complex<float> {
-  complex64_t(float v, float u) : std::complex<float>(v, u){};
-  complex64_t(std::complex<float> v) : std::complex<float>(v){};
+  complex64_t(float v, float u) : std::complex<float>(v, u) {};
+  complex64_t(std::complex<float> v) : std::complex<float>(v) {};
 
   template <
       typename T,
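Both wrappers declare their converting constructors non-explicit, so a `complex64_t` can be built implicitly from a (real, imaginary) pair or from a plain `std::complex<float>`. A sketch of the conversions this enables, assuming the `complex64_t` definition from the hunk above (`magnitude` is a hypothetical caller, not MLX code):

#include <complex>

// Accepts anything implicitly convertible to complex64_t.
float magnitude(complex64_t z) {
  // complex64_t publicly derives from std::complex<float>, so the
  // standard abs overload applies directly.
  return std::abs(static_cast<std::complex<float>>(z));
}

// Both calls rely on the non-explicit constructors shown above:
//   magnitude({1.0f, 2.0f});                     // from a (real, imag) pair
//   magnitude(std::complex<float>(3.0f, 4.0f));  // from std::complex<float>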