MLX
 
Loading...
Searching...
No Matches
mma.h File Reference
#include <metal_simdgroup>
#include <metal_simdgroup_matrix>
#include <metal_stdlib>
#include "mlx/backend/metal/kernels/steel/attn/transforms.h"
#include "mlx/backend/metal/kernels/steel/defines.h"
#include "mlx/backend/metal/kernels/steel/utils/integral_constant.h"

Go to the source code of this file.

Classes

struct  mlx::steel::Shape2D< RInt, CInt >
 
struct  mlx::steel::Layout2D< Shape, Layout >
 
struct  mlx::steel::BaseMMAFrag< T, kFragRows_, kFragCols_ >
 
struct  mlx::steel::BaseMMAFrag< T, 8, 8 >
 
struct  mlx::steel::MMATile< T, kTileRows_, kTileCols_, MMAFrag_ >
 
struct  mlx::steel::BlockMMA< T, U, BM, BN, BK, WM, WN, transpose_a, transpose_b, lda_tgp, ldb_tgp, AccumType, Epilogue >
 

Namespaces

namespace  mlx
 
namespace  mlx::steel
 

Functions

template<typename Dtype, typename Atype, typename Btype, typename Ctype, int M, int N, int K, class MMAFragD, class MMAFragA, class MMAFragB, class MMAFragC>
METAL_FUNC void mlx::steel::tile_matmad (thread MMATile< Dtype, M, N, MMAFragD > &D, thread MMATile< Atype, M, K, MMAFragA > &A, thread MMATile< Btype, K, N, MMAFragB > &B, thread MMATile< Ctype, M, N, MMAFragC > &C)