mirror of
https://github.com/ml-explore/mlx.git
synced 2025-06-24 17:31:16 +08:00

* Organize and collect metal subroutine templates and elements in `metal/kernels/steel/` * Update gemm elements for better performance * Add split-K specialization for gemm * Add `addmm` primitive, op and bindings for fused matmul and bias addition * Update tests and benchmarks as needed
63 lines
1.4 KiB
C++
63 lines
1.4 KiB
C++
// Copyright © 2024 Apple Inc.
|
|
|
|
#pragma once
|
|
|
|
#include "mlx/backend/metal/kernels/steel/utils.h"
|
|
|
|
///////////////////////////////////////////////////////////////////////////////
|
|
// Transforms and Epilogues
|
|
///////////////////////////////////////////////////////////////////////////////
|
|
|
|
namespace mlx {
|
|
namespace steel {
|
|
|
|
template <typename OutT, typename InT>
|
|
struct TransformNone {
|
|
static METAL_FUNC OutT apply(InT x) {
|
|
return static_cast<OutT>(x);
|
|
}
|
|
|
|
static METAL_FUNC OutT apply(InT x, OutT) {
|
|
return static_cast<OutT>(x);
|
|
}
|
|
};
|
|
|
|
template <typename OutT, typename InT>
|
|
struct TransformAdd {
|
|
TransformAdd(const float, const float) {}
|
|
|
|
static METAL_FUNC OutT apply(InT x, OutT c) {
|
|
return static_cast<OutT>(x) + c;
|
|
}
|
|
};
|
|
|
|
template <typename OutT, typename InT>
|
|
struct TransformAxpby {
|
|
const float alpha;
|
|
const float beta;
|
|
|
|
TransformAxpby(const float alpha_, const float beta_)
|
|
: alpha(alpha_), beta(beta_) {}
|
|
|
|
METAL_FUNC OutT apply(InT x, OutT c) const {
|
|
return static_cast<OutT>(x * alpha + (beta * c));
|
|
}
|
|
};
|
|
|
|
// Type trait that names the accumulator type associated with `T`.
// This primary template maps every `T` to `float`.
template <typename T>
struct AccumHelper {
  typedef float accum_type;
};
|
|
|
|
struct BlockSwizzle {
|
|
static METAL_FUNC int2
|
|
swizzle(uint3 tid [[threadgroup_position_in_grid]], const int swizzle_log) {
|
|
const int tid_x = (tid.x) >> swizzle_log;
|
|
const int tid_y =
|
|
((tid.y) << swizzle_log) + ((tid.x) & ((1 << swizzle_log) - 1));
|
|
return int2(tid_x, tid_y);
|
|
}
|
|
};
|
|
|
|
} // namespace steel
|
|
} // namespace mlx
|