diff --git a/docs/build/html/.buildinfo b/docs/build/html/.buildinfo index 50696688e..3cf7fbd01 100644 --- a/docs/build/html/.buildinfo +++ b/docs/build/html/.buildinfo @@ -1,4 +1,4 @@ # Sphinx build info version 1 # This file hashes the configuration used when building these files. When it is not found, a full rebuild will be done. -config: c431be841be9331fc029403834cef1bf +config: ad0493b39127084c2ab6331071fb3c9b tags: 645f666f9bcd5a90fca523b33c5a78b7 diff --git a/docs/build/html/_static/documentation_options.js b/docs/build/html/_static/documentation_options.js index 51bee6816..213607d15 100644 --- a/docs/build/html/_static/documentation_options.js +++ b/docs/build/html/_static/documentation_options.js @@ -1,5 +1,5 @@ const DOCUMENTATION_OPTIONS = { - VERSION: '0.19.1', + VERSION: '0.19.2', LANGUAGE: 'en', COLLAPSE_INDEX: false, BUILDER: 'html', diff --git a/docs/build/html/annotated.html b/docs/build/html/annotated.html index 780672292..87141e0a0 100644 --- a/docs/build/html/annotated.html +++ b/docs/build/html/annotated.html @@ -91,6 +91,14 @@ $(function(){ initResizable(false); });
+
|
+ +inline | +
+ MLX
+
+ |
+
+Files | |
integral_constant.h | |
type_traits.h | |
+ MLX
+
+ |
+
Go to the source code of this file.
++Classes | |
struct | mlx::steel::integral_constant< T, v > |
struct | mlx::steel::is_integral< T > |
struct | mlx::steel::is_integral< integral_constant< T, v > > |
+Namespaces | |
namespace | mlx |
namespace | mlx::steel |
+Macros | |
#define | integral_const_binop(__op__, __operator__) |
+Typedefs | |
template<bool B> | |
using | mlx::steel::bool_constant = integral_constant<bool, B> |
using | mlx::steel::true_type = bool_constant<true> |
using | mlx::steel::false_type = bool_constant<false> |
template<int val> | |
using | mlx::steel::Int = integral_constant<int, val> |
+Functions | |
template<typename T , T tv, typename U , U uv> | |
METAL_FUNC constexpr auto | mlx::steel::operator+ (integral_constant< T, tv >, integral_constant< U, uv >) |
template<typename T , T tv, typename U , U uv> | |
METAL_FUNC constexpr auto | mlx::steel::operator- (integral_constant< T, tv >, integral_constant< U, uv >) |
template<typename T , T tv, typename U , U uv> | |
METAL_FUNC constexpr auto | mlx::steel::operator* (integral_constant< T, tv >, integral_constant< U, uv >) |
template<typename T , T tv, typename U , U uv> | |
METAL_FUNC constexpr auto | mlx::steel::operator/ (integral_constant< T, tv >, integral_constant< U, uv >) |
template<typename T , T tv, typename U , U uv> | |
METAL_FUNC constexpr auto | mlx::steel::operator== (integral_constant< T, tv >, integral_constant< U, uv >) |
template<typename T , T tv, typename U , U uv> | |
METAL_FUNC constexpr auto | mlx::steel::operator!= (integral_constant< T, tv >, integral_constant< U, uv >) |
template<typename T , T tv, typename U , U uv> | |
METAL_FUNC constexpr auto | mlx::steel::operator< (integral_constant< T, tv >, integral_constant< U, uv >) |
template<typename T , T tv, typename U , U uv> | |
METAL_FUNC constexpr auto | mlx::steel::operator> (integral_constant< T, tv >, integral_constant< U, uv >) |
template<typename T , T tv, typename U , U uv> | |
METAL_FUNC constexpr auto | mlx::steel::operator<= (integral_constant< T, tv >, integral_constant< U, uv >) |
template<typename T , T tv, typename U , U uv> | |
METAL_FUNC constexpr auto | mlx::steel::operator>= (integral_constant< T, tv >, integral_constant< U, uv >) |
template<typename T , T tv, typename U , U uv> | |
METAL_FUNC constexpr auto | mlx::steel::operator&& (integral_constant< T, tv >, integral_constant< U, uv >) |
template<typename T , T tv, typename U , U uv> | |
METAL_FUNC constexpr auto | mlx::steel::operator|| (integral_constant< T, tv >, integral_constant< U, uv >) |
template<typename T > | |
METAL_FUNC constexpr T | mlx::steel::sum (T x) |
template<typename T , typename... Us> | |
METAL_FUNC constexpr auto | mlx::steel::sum (T x, Us... us) |
+Variables | |
template<typename T > | |
constexpr constant bool | mlx::steel::is_integral_v = is_integral<T>::value |
#define integral_const_binop | +( | +__op__, | +|
+ | + | __operator__ ) | +
+ MLX
+
+ |
+
#include <metal_stdlib>
#include "mlx/backend/metal/kernels/steel/defines.h"
#include "mlx/backend/metal/kernels/steel/gemm/transforms.h"
#include "mlx/backend/metal/kernels/steel/utils/integral_constant.h"
Go to the source code of this file.
namespace | mlx::steel |
+Functions | |
template<typename T , typename U , int M, int N, int K> | |
METAL_FUNC void | mlx::steel::tile_matmad (thread MMATile< T, M, N > &D, thread MMATile< U, M, K > &A, thread MMATile< U, K, N > &B, thread MMATile< T, M, N > &C) |
+Typedefs | |
template<typename... Ts> | |
using | void_t = typename make_void<Ts...>::type |
template<typename T > | |
using | pointer_element_t = typename pointer_element<remove_cv_t<T>>::type |
Functions | |
METAL_FUNC bfloat16_t | simd_xor (bfloat16_t data) |
using metal::pointer_element_t = typename pointer_element<remove_cv_t<T>>::type | +
using metal::void_t = typename make_void<Ts...>::type | +
+Typedefs | |
template<bool B> | |
using | bool_constant = integral_constant<bool, B> |
using | true_type = bool_constant<true> |
using | false_type = bool_constant<false> |
template<int val> | |
using | Int = integral_constant<int, val> |
+Functions | |
template<typename T , typename U , int M, int N, int K> | |
METAL_FUNC void | tile_matmad (thread MMATile< T, M, N > &D, thread MMATile< U, M, K > &A, thread MMATile< U, K, N > &B, thread MMATile< T, M, N > &C) |
template<typename T , T tv, typename U , U uv> | |
METAL_FUNC constexpr auto | operator+ (integral_constant< T, tv >, integral_constant< U, uv >) |
template<typename T , T tv, typename U , U uv> | |
METAL_FUNC constexpr auto | operator- (integral_constant< T, tv >, integral_constant< U, uv >) |
template<typename T , T tv, typename U , U uv> | |
METAL_FUNC constexpr auto | operator* (integral_constant< T, tv >, integral_constant< U, uv >) |
template<typename T , T tv, typename U , U uv> | |
METAL_FUNC constexpr auto | operator/ (integral_constant< T, tv >, integral_constant< U, uv >) |
template<typename T , T tv, typename U , U uv> | |
METAL_FUNC constexpr auto | operator== (integral_constant< T, tv >, integral_constant< U, uv >) |
template<typename T , T tv, typename U , U uv> | |
METAL_FUNC constexpr auto | operator!= (integral_constant< T, tv >, integral_constant< U, uv >) |
template<typename T , T tv, typename U , U uv> | |
METAL_FUNC constexpr auto | operator< (integral_constant< T, tv >, integral_constant< U, uv >) |
template<typename T , T tv, typename U , U uv> | |
METAL_FUNC constexpr auto | operator> (integral_constant< T, tv >, integral_constant< U, uv >) |
template<typename T , T tv, typename U , U uv> | |
METAL_FUNC constexpr auto | operator<= (integral_constant< T, tv >, integral_constant< U, uv >) |
template<typename T , T tv, typename U , U uv> | |
METAL_FUNC constexpr auto | operator>= (integral_constant< T, tv >, integral_constant< U, uv >) |
template<typename T , T tv, typename U , U uv> | |
METAL_FUNC constexpr auto | operator&& (integral_constant< T, tv >, integral_constant< U, uv >) |
template<typename T , T tv, typename U , U uv> | |
METAL_FUNC constexpr auto | operator|| (integral_constant< T, tv >, integral_constant< U, uv >) |
template<typename T > | |
METAL_FUNC constexpr T | sum (T x) |
template<typename T , typename... Us> | |
METAL_FUNC constexpr auto | sum (T x, Us... us) |
+Variables | |
template<typename T > | |
constexpr constant bool | is_integral_v = is_integral<T>::value |
using mlx::steel::bool_constant = integral_constant<bool, B> | +
using mlx::steel::false_type = bool_constant<false> | +
using mlx::steel::Int = integral_constant<int, val> | +
using mlx::steel::true_type = bool_constant<true> | +
+
|
+ +constexpr | +
+
|
+ +constexpr | +
+
|
+ +constexpr | +
+
|
+ +constexpr | +
+
|
+ +constexpr | +
+
|
+ +constexpr | +
+
|
+ +constexpr | +
+
|
+ +constexpr | +
+
|
+ +constexpr | +
+
|
+ +constexpr | +
+
|
+ +constexpr | +
+
|
+ +constexpr | +
+
|
+ +constexpr | +
+
|
+ +constexpr | +
METAL_FUNC void mlx::steel::tile_matmad | +( | +thread MMATile< T, M, N > & | D, | +
+ | + | thread MMATile< U, M, K > & | A, | +
+ | + | thread MMATile< U, K, N > & | B, | +
+ | + | thread MMATile< T, M, N > & | C ) | +
+
|
+ +constexpr | +
Functions | |
template<typename T , typename IdxT , typename Op , int NIDX> | |
METAL_FUNC void | scatter_1d_index_impl (const device T *updates, device mlx_atomic< T > *out, const constant int *out_shape, const constant size_t *out_strides, const constant size_t &out_ndim, const constant int *upd_shape, const constant size_t &upd_ndim, const constant size_t &upd_size, const thread array< const device IdxT *, NIDX > &idx_buffers, uint2 gid) |
template<typename T , typename IdxT , typename Op , int NIDX> | |
METAL_FUNC void | scatter_impl (const device T *updates, device mlx_atomic< T > *out, const constant int *upd_shape, const constant size_t *upd_strides, const constant size_t &upd_ndim, const constant size_t &upd_size, const constant int *out_shape, const constant size_t *out_strides, const constant size_t &out_ndim, const constant int *axes, const thread Indices< IdxT, NIDX > &indices, uint2 gid) |
template<typename T , typename IdxT , typename Op , int NIDX, bool UPD_ROW_CONTIG, int NWORK> | |
METAL_FUNC void | scatter_impl (const device T *updates, device mlx_atomic< T > *out, const constant int *upd_shape, const constant size_t *upd_strides, const constant size_t &upd_ndim, const constant size_t &upd_size, const constant int *out_shape, const constant size_t *out_strides, const constant size_t &out_ndim, const constant int *axes, const constant size_t &idx_size, const thread Indices< IdxT, NIDX > &indices, uint2 gid) |
METAL_FUNC void scatter_1d_index_impl | -( | -const device T * | updates, | -
- | - | device mlx_atomic< T > * | out, | -
- | - | const constant int * | out_shape, | -
- | - | const constant size_t * | out_strides, | -
- | - | const constant size_t & | out_ndim, | -
- | - | const constant int * | upd_shape, | -
- | - | const constant size_t & | upd_ndim, | -
- | - | const constant size_t & | upd_size, | -
- | - | const thread array< const device IdxT *, NIDX > & | idx_buffers, | -
- | - | uint2 | gid ) | -
METAL_FUNC void scatter_impl | @@ -228,6 +162,11 @@ template<typename T , typename IdxT , typename Op , int NIDX>const constant int * | axes, | |
+ | + | const constant size_t & | idx_size, | +
diff --git a/docs/build/html/scatter_8h_source.html b/docs/build/html/scatter_8h_source.html index 694aa7df9..2c89521d7 100644 --- a/docs/build/html/scatter_8h_source.html +++ b/docs/build/html/scatter_8h_source.html @@ -97,86 +97,64 @@ $(function(){ initResizable(false); }); - - | |||
const constant size_t * | strides | ||
const constant bool * | row_contiguous | ||
const int | ndim | ||
const constant bool* Indices< IdxT, NIDX >::row_contiguous | +
+ MLX
+
+ |
+
#include <type_traits.h>
+ MLX
+
+ |
+
#include <type_traits.h>
+ MLX
+
+ |
+
This is the complete list of members for metal::make_void< Ts >, including all inherited members.
+type typedef | metal::make_void< Ts > |
+ MLX
+
+ |
+
#include <type_traits.h>
+Public Types | |
typedef void | type |
void metal::make_void< Ts >::type | +
+ MLX
+
+ |
+
#include <type_traits.h>
+ MLX
+
+ |
+
This is the complete list of members for metal::pointer_element< constant T * >, including all inherited members.
+type typedef | metal::pointer_element< constant T * > |
+ MLX
+
+ |
+
#include <type_traits.h>
+Public Types | |
using | type = remove_cv_t<T> |
using metal::pointer_element< constant T * >::type = remove_cv_t<T> | +
+ MLX
+
+ |
+
This is the complete list of members for metal::pointer_element< device T * >, including all inherited members.
+type typedef | metal::pointer_element< device T * > |
+ MLX
+
+ |
+
#include <type_traits.h>
+Public Types | |
using | type = remove_cv_t<T> |
using metal::pointer_element< device T * >::type = remove_cv_t<T> | +
+ MLX
+
+ |
+
This is the complete list of members for metal::pointer_element< thread T * >, including all inherited members.
+type typedef | metal::pointer_element< thread T * > |
+ MLX
+
+ |
+
#include <type_traits.h>
+Public Types | |
using | type = remove_cv_t<T> |
using metal::pointer_element< thread T * >::type = remove_cv_t<T> | +
+ MLX
+
+ |
+
This is the complete list of members for metal::pointer_element< threadgroup T * >, including all inherited members.
+type typedef | metal::pointer_element< threadgroup T * > |
+ MLX
+
+ |
+
#include <type_traits.h>
+Public Types | |
using | type = remove_cv_t<T> |
using metal::pointer_element< threadgroup T * >::type = remove_cv_t<T> | +
+ MLX
+
+ |
+
#include <mma.h>
+ MLX
+
+ |
+
This is the complete list of members for mlx::steel::BaseMMAFrag< T, 8, 8 >, including all inherited members.
+frag_type typedef | mlx::steel::BaseMMAFrag< T, 8, 8 > | |
get_coord(ushort simd_lane_id) | mlx::steel::BaseMMAFrag< T, 8, 8 > | inlinestatic |
kElemCols | mlx::steel::BaseMMAFrag< T, 8, 8 > | |
kElemRows | mlx::steel::BaseMMAFrag< T, 8, 8 > | |
kElemsPerFrag | mlx::steel::BaseMMAFrag< T, 8, 8 > | |
kFragCols | mlx::steel::BaseMMAFrag< T, 8, 8 > | |
kFragRows | mlx::steel::BaseMMAFrag< T, 8, 8 > | |
load(thread frag_type &dst, SrcPtrType src, StrX str_x, StrY str_y) | mlx::steel::BaseMMAFrag< T, 8, 8 > | inlinestatic |
load_safe(thread frag_type &dst, SrcPtrType src, StrX str_x, StrY str_y, LimX lim_x, LimY lim_y, OffX off_x=Int< 0 >{}, OffY off_y=Int< 0 >{}) | mlx::steel::BaseMMAFrag< T, 8, 8 > | inlinestatic |
mat_type typedef | mlx::steel::BaseMMAFrag< T, 8, 8 > | |
mma(thread frag_type &D, thread frag_type &A, thread frag_type &B, thread frag_type &C) | mlx::steel::BaseMMAFrag< T, 8, 8 > | inlinestatic |
mma(thread mat_type &D, thread mat_type &A, thread mat_type &B, thread mat_type &C) | mlx::steel::BaseMMAFrag< T, 8, 8 > | inlinestatic |
store(const thread frag_type &src, DstPtrType dst, StrX str_x, StrY str_y) | mlx::steel::BaseMMAFrag< T, 8, 8 > | inlinestatic |
store_safe(const thread frag_type &src, DstPtrType dst, StrX str_x, StrY str_y, LimX lim_x, LimY lim_y, OffX off_x=Int< 0 >{}, OffY off_y=Int< 0 >{}) | mlx::steel::BaseMMAFrag< T, 8, 8 > | inlinestatic |
+ MLX
+
+ |
+
#include <mma.h>
+Public Types | |
typedef metal::simdgroup_matrix< T, kFragRows, kFragCols > | mat_type |
typedef metal::vec< T, kElemsPerFrag > | frag_type |
+Static Public Member Functions | |
static METAL_FUNC constexpr short2 | get_coord (ushort simd_lane_id) |
template<typename SrcPtrType , typename StrX , typename StrY > | |
static METAL_FUNC constexpr void | load (thread frag_type &dst, SrcPtrType src, StrX str_x, StrY str_y) |
template<typename SrcPtrType , typename StrX , typename StrY , typename LimX , typename LimY , typename OffX , typename OffY > | |
static METAL_FUNC constexpr void | load_safe (thread frag_type &dst, SrcPtrType src, StrX str_x, StrY str_y, LimX lim_x, LimY lim_y, OffX off_x=Int< 0 >{}, OffY off_y=Int< 0 >{}) |
template<typename DstPtrType , typename StrX , typename StrY > | |
static METAL_FUNC constexpr void | store (const thread frag_type &src, DstPtrType dst, StrX str_x, StrY str_y) |
template<typename DstPtrType , typename StrX , typename StrY , typename LimX , typename LimY , typename OffX , typename OffY > | |
static METAL_FUNC constexpr void | store_safe (const thread frag_type &src, DstPtrType dst, StrX str_x, StrY str_y, LimX lim_x, LimY lim_y, OffX off_x=Int< 0 >{}, OffY off_y=Int< 0 >{}) |
static METAL_FUNC constexpr void | mma (thread frag_type &D, thread frag_type &A, thread frag_type &B, thread frag_type &C) |
static METAL_FUNC constexpr void | mma (thread mat_type &D, thread mat_type &A, thread mat_type &B, thread mat_type &C) |
+Public Attributes | |
STEEL_CONST int | kFragRows = 8 |
STEEL_CONST int | kFragCols = 8 |
STEEL_CONST int | kElemsPerFrag = (kFragRows * kFragCols) / 32 |
STEEL_CONST int | kElemRows = 1 |
STEEL_CONST int | kElemCols = 2 |
metal::vec<T, kElemsPerFrag> mlx::steel::BaseMMAFrag< T, 8, 8 >::frag_type | +
metal::simdgroup_matrix<T, kFragRows, kFragCols> mlx::steel::BaseMMAFrag< T, 8, 8 >::mat_type | +
+
|
+ +inlinestaticconstexpr | +
+
|
+ +inlinestaticconstexpr | +
+
|
+ +inlinestaticconstexpr | +
+
|
+ +inlinestaticconstexpr | +
+
|
+ +inlinestaticconstexpr | +
+
|
+ +inlinestaticconstexpr | +
+
|
+ +inlinestaticconstexpr | +
STEEL_CONST int mlx::steel::BaseMMAFrag< T, 8, 8 >::kElemCols = 2 | +
STEEL_CONST int mlx::steel::BaseMMAFrag< T, 8, 8 >::kElemRows = 1 | +
STEEL_CONST int mlx::steel::BaseMMAFrag< T, 8, 8 >::kElemsPerFrag = (kFragRows * kFragCols) / 32 | +
STEEL_CONST int mlx::steel::BaseMMAFrag< T, 8, 8 >::kFragCols = 8 | +
STEEL_CONST int mlx::steel::BaseMMAFrag< T, 8, 8 >::kFragRows = 8 | +
This is the complete list of members for mlx::steel::BlockMMA< T, U, BM, BN, BK, WM, WN, transpose_a, transpose_b, lda_tgp, ldb_tgp, AccumType, Epilogue >, including all inherited members.
#include <mma.h>
+Public Types | |
using | MMAFrag_acc_t = BaseMMAFrag<AccumType, kFragSize, kFragSize> |
Public Member Functions | |
METAL_FUNC | BlockMMA (ushort simd_group_id, ushort simd_lane_id) |
METAL_FUNC void | mma (const threadgroup T *As, const threadgroup T *Bs) |
METAL_FUNC void | store_result (device U *D, const int ldd) const |
METAL_FUNC void | store_result_safe (device U *D, const int ldd, short2 dst_tile_dims) const |
METAL_FUNC void | store_result (device U *D, const int ldd) |
METAL_FUNC void | store_result_safe (device U *D, const int ldd, short2 dst_tile_dims) |
template<typename UnaryEpilogue > | |
METAL_FUNC void | apply_epilogue (thread const UnaryEpilogue &epilogue_op) |
Public Attributes | |
STEEL_CONST short | TM_stride = 8 * WM |
STEEL_CONST short | kFragSize = 8 |
STEEL_CONST short | TM_stride = kFragSize * WM |
STEEL_CONST short | TN_stride = 8 * WN |
STEEL_CONST short | TN_stride = kFragSize * WN |
STEEL_CONST short | TM = BM / TM_stride |
STEEL_CONST short | TN = BN / TN_stride |
STEEL_CONST short | simd_stride_a |
STEEL_CONST short | simd_stride_b |
STEEL_CONST short | jump_a = {transpose_a ? lda_tgp : 1} |
STEEL_CONST short | jump_b = {transpose_b ? ldb_tgp : 1} |
STEEL_CONST short | tile_stride_a = {transpose_a ? 8 * lda_tgp : 8} |
STEEL_CONST short | A_str_m = transpose_a ? 1 : lda_tgp |
STEEL_CONST short | A_str_k = transpose_a ? lda_tgp : 1 |
STEEL_CONST short | B_str_k = transpose_b ? 1 : ldb_tgp |
STEEL_CONST short | B_str_n = transpose_b ? ldb_tgp : 1 |
STEEL_CONST short | tile_stride_a = kFragSize * A_str_k |
STEEL_CONST short | tile_stride_b = {transpose_b ? 8 : 8 * ldb_tgp} |
STEEL_CONST short | tile_stride_b = kFragSize * B_str_k |
simdgroup_matrix< AccumType, 8, 8 > | Asimd [TM] |
simdgroup_matrix< AccumType, 8, 8 > | Bsimd [TN] |
simdgroup_matrix< AccumType, 8, 8 > | results [TM *TN] |
const short | tm |
const short | tn |
MMATile< AccumType, TM, 1, MMAFrag_acc_t > | Atile |
MMATile< AccumType, 1, TN, MMAFrag_acc_t > | Btile |
MMATile< AccumType, TM, TN, MMAFrag_acc_t > | Ctile |
short | sm |
short | sn |
short | Bs_offset |
using mlx::steel::BlockMMA< T, U, BM, BN, BK, WM, WN, transpose_a, transpose_b, lda_tgp, ldb_tgp, AccumType, Epilogue >::MMAFrag_acc_t = BaseMMAFrag<AccumType, kFragSize, kFragSize> | +
STEEL_CONST short mlx::steel::BlockMMA< T, U, BM, BN, BK, WM, WN, transpose_a, transpose_b, lda_tgp, ldb_tgp, AccumType, Epilogue >::A_str_k = transpose_a ? lda_tgp : 1 | +
STEEL_CONST short mlx::steel::BlockMMA< T, U, BM, BN, BK, WM, WN, transpose_a, transpose_b, lda_tgp, ldb_tgp, AccumType, Epilogue >::A_str_m = transpose_a ? 1 : lda_tgp | +
simdgroup_matrix<AccumType, 8, 8> mlx::steel::BlockMMA< T, U, BM, BN, BK, WM, WN, transpose_a, transpose_b, lda_tgp, ldb_tgp, AccumType, Epilogue >::Asimd[TM] | +MMATile<AccumType, TM, 1, MMAFrag_acc_t> mlx::steel::BlockMMA< T, U, BM, BN, BK, WM, WN, transpose_a, transpose_b, lda_tgp, ldb_tgp, AccumType, Epilogue >::Atile | +
STEEL_CONST short mlx::steel::BlockMMA< T, U, BM, BN, BK, WM, WN, transpose_a, transpose_b, lda_tgp, ldb_tgp, AccumType, Epilogue >::B_str_k = transpose_b ? 1 : ldb_tgp | +
STEEL_CONST short mlx::steel::BlockMMA< T, U, BM, BN, BK, WM, WN, transpose_a, transpose_b, lda_tgp, ldb_tgp, AccumType, Epilogue >::B_str_n = transpose_b ? ldb_tgp : 1 |
simdgroup_matrix<AccumType, 8, 8> mlx::steel::BlockMMA< T, U, BM, BN, BK, WM, WN, transpose_a, transpose_b, lda_tgp, ldb_tgp, AccumType, Epilogue >::Bsimd[TN] | +MMATile<AccumType, 1, TN, MMAFrag_acc_t> mlx::steel::BlockMMA< T, U, BM, BN, BK, WM, WN, transpose_a, transpose_b, lda_tgp, ldb_tgp, AccumType, Epilogue >::Btile |
STEEL_CONST short mlx::steel::BlockMMA< T, U, BM, BN, BK, WM, WN, transpose_a, transpose_b, lda_tgp, ldb_tgp, AccumType, Epilogue >::jump_a = {transpose_a ? lda_tgp : 1} | +MMATile<AccumType, TM, TN, MMAFrag_acc_t> mlx::steel::BlockMMA< T, U, BM, BN, BK, WM, WN, transpose_a, transpose_b, lda_tgp, ldb_tgp, AccumType, Epilogue >::Ctile |
STEEL_CONST short mlx::steel::BlockMMA< T, U, BM, BN, BK, WM, WN, transpose_a, transpose_b, lda_tgp, ldb_tgp, AccumType, Epilogue >::jump_b = {transpose_b ? ldb_tgp : 1} | +STEEL_CONST short mlx::steel::BlockMMA< T, U, BM, BN, BK, WM, WN, transpose_a, transpose_b, lda_tgp, ldb_tgp, AccumType, Epilogue >::kFragSize = 8 |
simdgroup_matrix<AccumType, 8, 8> mlx::steel::BlockMMA< T, U, BM, BN, BK, WM, WN, transpose_a, transpose_b, lda_tgp, ldb_tgp, AccumType, Epilogue >::results[TM *TN] | -
STEEL_CONST short mlx::steel::BlockMMA< T, U, BM, BN, BK, WM, WN, transpose_a, transpose_b, lda_tgp, ldb_tgp, AccumType, Epilogue >::simd_stride_a | -
STEEL_CONST short mlx::steel::BlockMMA< T, U, BM, BN, BK, WM, WN, transpose_a, transpose_b, lda_tgp, ldb_tgp, AccumType, Epilogue >::simd_stride_b | -
STEEL_CONST short mlx::steel::BlockMMA< T, U, BM, BN, BK, WM, WN, transpose_a, transpose_b, lda_tgp, ldb_tgp, AccumType, Epilogue >::tile_stride_a = {transpose_a ? 8 * lda_tgp : 8} | +STEEL_CONST short mlx::steel::BlockMMA< T, U, BM, BN, BK, WM, WN, transpose_a, transpose_b, lda_tgp, ldb_tgp, AccumType, Epilogue >::tile_stride_a = kFragSize * A_str_k |
STEEL_CONST short mlx::steel::BlockMMA< T, U, BM, BN, BK, WM, WN, transpose_a, transpose_b, lda_tgp, ldb_tgp, AccumType, Epilogue >::tile_stride_b = {transpose_b ? 8 : 8 * ldb_tgp} | +STEEL_CONST short mlx::steel::BlockMMA< T, U, BM, BN, BK, WM, WN, transpose_a, transpose_b, lda_tgp, ldb_tgp, AccumType, Epilogue >::tile_stride_b = kFragSize * B_str_k |
const short mlx::steel::BlockMMA< T, U, BM, BN, BK, WM, WN, transpose_a, transpose_b, lda_tgp, ldb_tgp, AccumType, Epilogue >::tm | -
STEEL_CONST short mlx::steel::BlockMMA< T, U, BM, BN, BK, WM, WN, transpose_a, transpose_b, lda_tgp, ldb_tgp, AccumType, Epilogue >::TM_stride = 8 * WM | +STEEL_CONST short mlx::steel::BlockMMA< T, U, BM, BN, BK, WM, WN, transpose_a, transpose_b, lda_tgp, ldb_tgp, AccumType, Epilogue >::TM_stride = kFragSize * WM |
const short mlx::steel::BlockMMA< T, U, BM, BN, BK, WM, WN, transpose_a, transpose_b, lda_tgp, ldb_tgp, AccumType, Epilogue >::tn | -
STEEL_CONST short mlx::steel::BlockMMA< T, U, BM, BN, BK, WM, WN, transpose_a, transpose_b, lda_tgp, ldb_tgp, AccumType, Epilogue >::TN_stride = 8 * WN | +STEEL_CONST short mlx::steel::BlockMMA< T, U, BM, BN, BK, WM, WN, transpose_a, transpose_b, lda_tgp, ldb_tgp, AccumType, Epilogue >::TN_stride = kFragSize * WN |
+ MLX
+
+ |
+
This is the complete list of members for mlx::steel::MMATile< T, kTileRows_, kTileCols_, MMAFrag_ >, including all inherited members.
+
+ MLX
+
+ |
+
#include <mma.h>
+Public Types | |
using | MMAFrag_t = MMAFrag_ |
using | elem_type = T |
typedef MMAFrag_t::mat_type | mat_type |
typedef MMAFrag_t::frag_type | frag_type |
+Public Member Functions | |
METAL_FUNC | MMATile () thread |
METAL_FUNC constexpr void | clear () |
METAL_FUNC constexpr thread frag_type & | frag_at (const short i, const short j) |
METAL_FUNC constexpr const thread frag_type & | frag_at (const short i, const short j) const |
METAL_FUNC mat_type | mat_at (const short i, const short j) |
METAL_FUNC thread elem_type * | elems () |
METAL_FUNC const thread elem_type * | elems () const |
template<typename U , int w_x, int w_y, int str_x, int str_y> | |
METAL_FUNC void | load (const threadgroup U *src) |
template<typename U , int w_x, int w_y, int str_x, int str_y> | |
METAL_FUNC void | store (threadgroup U *dst) const |
template<typename U , int w_x, int w_y> | |
METAL_FUNC void | load (const device U *src, const int ld) |
template<typename U , int w_x, int w_y> | |
METAL_FUNC void | store (device U *dst, const int ld) const |
template<typename U , int w_x, int w_y> | |
METAL_FUNC void | load_safe (const device U *src, const int ld, const short2 src_tile_dims) |
template<typename U , int w_x, int w_y> | |
METAL_FUNC void | store_safe (device U *dst, const int ld, const short2 dst_tile_dims) const |
+Public Attributes | |
STEEL_CONST int | kFragRows = MMAFrag_t::kFragRows |
STEEL_CONST int | kFragCols = MMAFrag_t::kFragCols |
STEEL_CONST int | kElemsPerFrag = MMAFrag_t::kElemsPerFrag |
STEEL_CONST int | kTileRows = kTileRows_ |
STEEL_CONST int | kTileCols = kTileCols_ |
STEEL_CONST int | kRows = kTileRows * kFragRows |
STEEL_CONST int | kCols = kTileCols * kFragCols |
STEEL_CONST int | kNumFrags = kTileRows * kTileCols |
STEEL_CONST int | kElemsPerTile = kNumFrags * kElemsPerFrag |
frag_type | val_frags [kNumFrags] = {frag_type(0)} |
using mlx::steel::MMATile< T, kTileRows_, kTileCols_, MMAFrag_ >::elem_type = T | +
MMAFrag_t::frag_type mlx::steel::MMATile< T, kTileRows_, kTileCols_, MMAFrag_ >::frag_type | +
MMAFrag_t::mat_type mlx::steel::MMATile< T, kTileRows_, kTileCols_, MMAFrag_ >::mat_type | +
using mlx::steel::MMATile< T, kTileRows_, kTileCols_, MMAFrag_ >::MMAFrag_t = MMAFrag_ | +
+
|
+ +inline | +
+
|
+ +inlineconstexpr | +
+
|
+ +inline | +
+
|
+ +inline | +
+
|
+ +inlineconstexpr | +
+
|
+ +inlineconstexpr | +
+
|
+ +inline | +
+
|
+ +inline | +
+
|
+ +inline | +
+
|
+ +inline | +
+
|
+ +inline | +
+
|
+ +inline | +
+
|
+ +inline | +
STEEL_CONST int mlx::steel::MMATile< T, kTileRows_, kTileCols_, MMAFrag_ >::kCols = kTileCols * kFragCols | +
STEEL_CONST int mlx::steel::MMATile< T, kTileRows_, kTileCols_, MMAFrag_ >::kElemsPerFrag = MMAFrag_t::kElemsPerFrag | +
STEEL_CONST int mlx::steel::MMATile< T, kTileRows_, kTileCols_, MMAFrag_ >::kElemsPerTile = kNumFrags * kElemsPerFrag | +
STEEL_CONST int mlx::steel::MMATile< T, kTileRows_, kTileCols_, MMAFrag_ >::kFragCols = MMAFrag_t::kFragCols | +
STEEL_CONST int mlx::steel::MMATile< T, kTileRows_, kTileCols_, MMAFrag_ >::kFragRows = MMAFrag_t::kFragRows | +
STEEL_CONST int mlx::steel::MMATile< T, kTileRows_, kTileCols_, MMAFrag_ >::kNumFrags = kTileRows * kTileCols | +
STEEL_CONST int mlx::steel::MMATile< T, kTileRows_, kTileCols_, MMAFrag_ >::kRows = kTileRows * kFragRows | +
STEEL_CONST int mlx::steel::MMATile< T, kTileRows_, kTileCols_, MMAFrag_ >::kTileCols = kTileCols_ | +
STEEL_CONST int mlx::steel::MMATile< T, kTileRows_, kTileCols_, MMAFrag_ >::kTileRows = kTileRows_ | +
frag_type mlx::steel::MMATile< T, kTileRows_, kTileCols_, MMAFrag_ >::val_frags[kNumFrags] = {frag_type(0)} | +
+ MLX
+
+ |
+
This is the complete list of members for mlx::steel::integral_constant< T, v >, including all inherited members.
+operator value_type() const noexcept | mlx::steel::integral_constant< T, v > | inline |
type typedef | mlx::steel::integral_constant< T, v > | |
value | mlx::steel::integral_constant< T, v > | static |
value_type typedef | mlx::steel::integral_constant< T, v > |
+ MLX
+
+ |
+
#include <integral_constant.h>
+Public Types | |
using | value_type = T |
using | type = integral_constant |
+Public Member Functions | |
METAL_FUNC constexpr | operator value_type () const noexcept |
+Static Public Attributes | |
static constexpr constant T | value = v |
using mlx::steel::integral_constant< T, v >::type = integral_constant | +
using mlx::steel::integral_constant< T, v >::value_type = T | +
+
|
+ +inlineconstexprnoexcept | +
+
|
+ +staticconstexpr | +
+ MLX
+
+ |
+
This is the complete list of members for mlx::steel::is_integral< T >, including all inherited members.
+operator value_type() const noexcept | mlx::steel::integral_constant< T, v > | inline |
type typedef | mlx::steel::integral_constant< T, v > | |
value | mlx::steel::integral_constant< T, v > | static |
value_type typedef | mlx::steel::integral_constant< T, v > |
+ MLX
+
+ |
+
#include <integral_constant.h>
+Additional Inherited Members | |
![]() | |
using | value_type = T |
using | type = integral_constant |
![]() | |
METAL_FUNC constexpr | operator value_type () const noexcept |
![]() | |
static constexpr constant T | value = v |
+ MLX
+
+ |
+
This is the complete list of members for mlx::steel::is_integral< integral_constant< T, v > >, including all inherited members.
+operator value_type() const noexcept | mlx::steel::integral_constant< T, v > | inline |
type typedef | mlx::steel::integral_constant< T, v > | |
value | mlx::steel::integral_constant< T, v > | static |
value_type typedef | mlx::steel::integral_constant< T, v > |
+ MLX
+
+ |
+
#include <integral_constant.h>
+Additional Inherited Members | |
![]() | |
using | value_type = T |
using | type = integral_constant |
![]() | |
METAL_FUNC constexpr | operator value_type () const noexcept |
![]() | |
static constexpr constant T | value = v |
+ MLX
+
+ |
+
#include <metal_stdlib>
Go to the source code of this file.
++Classes | |
struct | metal::is_empty< T > |
struct | metal::make_void< Ts > |
struct | metal::is_static< T > |
struct | metal::pointer_element< T > |
struct | metal::pointer_element< thread T * > |
struct | metal::pointer_element< device T * > |
struct | metal::pointer_element< constant T * > |
struct | metal::pointer_element< threadgroup T * > |
+Namespaces | |
namespace | metal |
+Typedefs | |
template<typename... Ts> | |
using | metal::void_t = typename make_void<Ts...>::type |
template<typename T > | |
using | metal::pointer_element_t = typename pointer_element<remove_cv_t<T>>::type |
+ MLX
+
+ |
+