MLX
 
Loading...
Searching...
No Matches
steel_attention.h File Reference

Go to the source code of this file.

Classes

struct  TransformScale< T >
 
struct  MaxOp
 
struct  SumOp
 
struct  MulOp
 
struct  SubOp
 
struct  ExpSubOp
 
struct  DivOp
 

Functions

template<typename T, int BQ, int BK, int BD, int WM, int WN, typename MaskType = float, typename AccumType = float>
void attention (const device T *Q, const device T *K, const device T *V, device T *O, const constant AttnParams *params, const constant AttnMaskParams *mask_params, const device MaskType *mask, uint simd_lane_id, uint simd_group_id, uint3 tid, uint3 lid)
 

Variables

constant bool align_Q
 
constant bool align_K
 
constant bool has_mask
 
constant bool do_causal
 

Function Documentation

◆ attention()

template<typename T, int BQ, int BK, int BD, int WM, int WN, typename MaskType = float, typename AccumType = float>
void attention(
    const device T* Q,
    const device T* K,
    const device T* V,
    device T* O,
    const constant AttnParams* params,
    const constant AttnMaskParams* mask_params,
    const device MaskType* mask,
    uint simd_lane_id,
    uint simd_group_id,
    uint3 tid,
    uint3 lid)

Variable Documentation

◆ align_K

constant bool align_K

◆ align_Q

constant bool align_Q

◆ do_causal

constant bool do_causal

◆ has_mask

constant bool has_mask