mlx/mlx/backend/metal/kernels/steel/gemm/transforms.h
Jagrit Digani 78102a47ad
Update GEMM (#424)
* Organize and collect Metal subroutine templates and elements in `metal/kernels/steel/`
* Update GEMM elements for better performance
* Add split-K specialization for GEMM
* Add `addmm` primitive, op, and bindings for fused matmul and bias addition
* Update tests and benchmarks as needed
2024-01-17 12:42:39 -08:00

63 lines
1.4 KiB
C++

// Copyright © 2024 Apple Inc.
#pragma once
#include "mlx/backend/metal/kernels/steel/utils.h"
///////////////////////////////////////////////////////////////////////////////
// Transforms and Epilogues
///////////////////////////////////////////////////////////////////////////////
namespace mlx {
namespace steel {
/// Identity epilogue: converts an input value to the output type and nothing
/// more.
template <typename OutT, typename InT>
struct TransformNone {
  /// Returns `x` converted to OutT.
  static METAL_FUNC OutT apply(InT x) {
    return static_cast<OutT>(x);
  }

  /// Two-argument form with the same signature shape as the other epilogues.
  /// The OutT argument is ignored; the result is just the conversion of `x`.
  static METAL_FUNC OutT apply(InT x, OutT /* ignored */) {
    return apply(x);
  }
};
/// Epilogue that adds an existing output value `c` to the converted input `x`.
///
/// The two float constructor arguments are accepted and ignored. Presumably
/// this lets TransformAdd be constructed with the same (alpha, beta) arguments
/// as TransformAxpby; confirm against callers.
template <typename OutT, typename InT>
struct TransformAdd {
  TransformAdd(const float, const float) {}

  /// Single-argument form: a plain conversion of `x` to OutT (nothing to add).
  /// Mirrors TransformNone::apply(InT) so both epilogues expose the same
  /// overload set.
  static METAL_FUNC OutT apply(InT x) {
    return static_cast<OutT>(x);
  }

  /// Returns `x` converted to OutT plus `c`. The conversion of `x` happens
  /// before the addition.
  static METAL_FUNC OutT apply(InT x, OutT c) {
    return static_cast<OutT>(x) + c;
  }
};
/// Epilogue computing `alpha * x + beta * c`, converted to OutT.
///
/// The scale factors are stored per instance, so (unlike TransformNone and
/// TransformAdd) `apply` is a const member function rather than static.
template <typename OutT, typename InT>
struct TransformAxpby {
  const float alpha; // multiplies the incoming value `x`
  const float beta; // multiplies the existing output value `c`

  TransformAxpby(const float alpha_in, const float beta_in)
      : alpha{alpha_in}, beta{beta_in} {}

  /// Returns x * alpha + beta * c, converted to OutT.
  METAL_FUNC OutT apply(InT x, OutT c) const {
    // Keep this as one expression so the compiler's floating-point
    // contraction (FMA) choices match the original formulation.
    return static_cast<OutT>(x * alpha + (beta * c));
  }
};
/// Maps an element type `T` to the type used for accumulation.
/// The primary template selects `float` for every `T`.
template <typename T>
struct AccumHelper {
  using accum_type = float;
};
/// Remaps a threadgroup position.
///
/// With k = swizzle_log, write tid.x = q * 2^k + r (0 <= r < 2^k). The result
/// is (q, tid.y * 2^k + r): x is divided by 2^k, and its remainder becomes the
/// low k bits of y scaled by 2^k. With k == 0 the (x, y) position is returned
/// unchanged. Presumably this reorders tiles so that threadgroups launched
/// together touch nearby data; confirm with callers.
struct BlockSwizzle {
  static METAL_FUNC int2
  swizzle(uint3 tid [[threadgroup_position_in_grid]], const int swizzle_log) {
    // Selects the low swizzle_log bits of an index.
    const int low_bits_mask = (1 << swizzle_log) - 1;

    const int swizzled_x = (tid.x) >> swizzle_log;
    const int swizzled_y =
        ((tid.y) << swizzle_log) + ((tid.x) & low_bits_mask);

    return int2(swizzled_x, swizzled_y);
  }
};
} // namespace steel
} // namespace mlx