MLX
 
Loading...
Searching...
No Matches
utils.h File Reference
#include <metal_math>
#include "bf16.h"
#include "mlx/backend/metal/kernels/bf16_math.h"
#include "mlx/backend/metal/kernels/complex.h"
#include "mlx/backend/metal/kernels/defines.h"

Go to the source code of this file.

Classes

struct  Limits< U >
 
struct  Limits< uint8_t >
 
struct  Limits< uint16_t >
 
struct  Limits< uint32_t >
 
struct  Limits< uint64_t >
 
struct  Limits< int8_t >
 
struct  Limits< int16_t >
 
struct  Limits< int32_t >
 
struct  Limits< int64_t >
 
struct  Limits< half >
 
struct  Limits< float >
 
struct  Limits< bfloat16_t >
 
struct  Limits< bool >
 
struct  Limits< complex64_t >
 
struct  LoopedElemToLoc< DIM, OffsetT, General >
 
struct  LoopedElemToLoc< 1, OffsetT, true >
 
struct  LoopedElemToLoc< 1, OffsetT, false >
 
struct  ConditionalType< condition, T, U >
 
struct  ConditionalType< true, T, U >
 

Macros

#define instantiate_default_limit(type)
 
#define instantiate_float_limit(type)
 
#define MLX_MTL_PRAGMA_UNROLL   _Pragma("clang loop unroll(full)")
 

Typedefs

typedef half float16_t
 

Functions

template<typename IdxT = int64_t>
METAL_FUNC IdxT elem_to_loc (IdxT elem, constant const int *shape, constant const int64_t *strides, int ndim)
 
template<typename IdxT = int64_t>
METAL_FUNC IdxT elem_to_loc (uint3 elem, constant const int *shape, constant const int64_t *strides, int ndim)
 
template<typename IdxT = int64_t>
METAL_FUNC IdxT elem_to_loc_1 (uint elem, constant const int64_t &stride)
 
template<typename IdxT = int64_t>
METAL_FUNC IdxT elem_to_loc_2 (uint2 elem, constant const int64_t strides[2])
 
template<typename IdxT = int64_t>
METAL_FUNC IdxT elem_to_loc_3 (uint3 elem, constant const int64_t strides[3])
 
template<typename IdxT = int64_t>
METAL_FUNC vec< IdxT, 2 > elem_to_loc_2_nd (uint3 elem, constant const int *shape, constant const int64_t *a_strides, constant const int64_t *b_strides, int ndim)
 
template<typename IdxT = int64_t>
METAL_FUNC vec< IdxT, 3 > elem_to_loc_3_nd (uint3 elem, constant const int *shape, constant const int64_t *a_strides, constant const int64_t *b_strides, constant const int64_t *c_strides, int ndim)
 
template<typename T, typename U>
T ceildiv (T N, U M)
 Compute ceil((float)N/(float)M)
 
float log1p (float x)
 
bfloat16_t log1p (bfloat16_t x)
 
uint64_t simd_shuffle_down (uint64_t data, uint16_t delta)
 
int64_t simd_shuffle_down (int64_t data, uint16_t delta)
 
bool simd_shuffle_down (bool data, uint16_t delta)
 
complex64_t simd_shuffle_down (complex64_t data, uint16_t delta)
 
uint64_t simd_shuffle_up (uint64_t data, uint16_t delta)
 
int64_t simd_shuffle_up (int64_t data, uint16_t delta)
 
bool simd_shuffle_up (bool data, uint16_t delta)
 
complex64_t simd_shuffle_up (complex64_t data, uint16_t delta)
 
uint64_t simd_shuffle_and_fill_up (uint64_t data, uint64_t filling, uint16_t delta)
 
int64_t simd_shuffle_and_fill_up (int64_t data, int64_t filling, uint16_t delta)
 
bool simd_shuffle_and_fill_up (bool data, bool filling, uint16_t delta)
 
complex64_t simd_shuffle_and_fill_up (complex64_t data, complex64_t filling, uint16_t delta)
 
uint64_t simd_shuffle (uint64_t data, uint16_t lane)
 
int64_t simd_shuffle (int64_t data, uint16_t lane)
 
bool simd_shuffle (bool data, uint16_t lane)
 
complex64_t simd_shuffle (complex64_t data, uint16_t lane)
 

Macro Definition Documentation

◆ instantiate_default_limit

#define instantiate_default_limit(type)
Value:
template <> \
struct Limits<type> { \
static constexpr constant type max = metal::numeric_limits<type>::max(); \
static constexpr constant type min = metal::numeric_limits<type>::min(); \
static constexpr constant type finite_max = \
metal::numeric_limits<type>::max(); \
static constexpr constant type finite_min = \
metal::numeric_limits<type>::min(); \
};
METAL_FUNC bfloat16_t min(bfloat16_t x, bfloat16_t y)
Definition bf16_math.h:232
METAL_FUNC bfloat16_t max(bfloat16_t x, bfloat16_t y)
Definition bf16_math.h:232
Definition utils.h:23

◆ instantiate_float_limit

#define instantiate_float_limit(type)
Value:
template <> \
struct Limits<type> { \
static constexpr constant type max = \
metal::numeric_limits<type>::infinity(); \
static constexpr constant type min = \
-metal::numeric_limits<type>::infinity(); \
static constexpr constant type finite_max = \
metal::numeric_limits<type>::max(); \
static constexpr constant type finite_min = \
-metal::numeric_limits<type>::max(); \
};

◆ MLX_MTL_PRAGMA_UNROLL

#define MLX_MTL_PRAGMA_UNROLL   _Pragma("clang loop unroll(full)")

Typedef Documentation

◆ float16_t

typedef half float16_t

Function Documentation

◆ ceildiv()

template<typename T, typename U>
T ceildiv ( T N,
U M )
inline

Compute ceil((float)N/(float)M)

◆ elem_to_loc() [1/2]

template<typename IdxT = int64_t>
METAL_FUNC IdxT elem_to_loc ( IdxT elem,
constant const int * shape,
constant const int64_t * strides,
int ndim )

◆ elem_to_loc() [2/2]

template<typename IdxT = int64_t>
METAL_FUNC IdxT elem_to_loc ( uint3 elem,
constant const int * shape,
constant const int64_t * strides,
int ndim )

◆ elem_to_loc_1()

template<typename IdxT = int64_t>
METAL_FUNC IdxT elem_to_loc_1 ( uint elem,
constant const int64_t & stride )

◆ elem_to_loc_2()

template<typename IdxT = int64_t>
METAL_FUNC IdxT elem_to_loc_2 ( uint2 elem,
constant const int64_t strides[2] )

◆ elem_to_loc_2_nd()

template<typename IdxT = int64_t>
METAL_FUNC vec< IdxT, 2 > elem_to_loc_2_nd ( uint3 elem,
constant const int * shape,
constant const int64_t * a_strides,
constant const int64_t * b_strides,
int ndim )

◆ elem_to_loc_3()

template<typename IdxT = int64_t>
METAL_FUNC IdxT elem_to_loc_3 ( uint3 elem,
constant const int64_t strides[3] )

◆ elem_to_loc_3_nd()

template<typename IdxT = int64_t>
METAL_FUNC vec< IdxT, 3 > elem_to_loc_3_nd ( uint3 elem,
constant const int * shape,
constant const int64_t * a_strides,
constant const int64_t * b_strides,
constant const int64_t * c_strides,
int ndim )

◆ log1p() [1/2]

bfloat16_t log1p ( bfloat16_t x)
inline

◆ log1p() [2/2]

float log1p ( float x)
inline

◆ simd_shuffle() [1/4]

bool simd_shuffle ( bool data,
uint16_t lane )
inline

◆ simd_shuffle() [2/4]

complex64_t simd_shuffle ( complex64_t data,
uint16_t lane )
inline

◆ simd_shuffle() [3/4]

int64_t simd_shuffle ( int64_t data,
uint16_t lane )
inline

◆ simd_shuffle() [4/4]

uint64_t simd_shuffle ( uint64_t data,
uint16_t lane )
inline

◆ simd_shuffle_and_fill_up() [1/4]

bool simd_shuffle_and_fill_up ( bool data,
bool filling,
uint16_t delta )
inline

◆ simd_shuffle_and_fill_up() [2/4]

complex64_t simd_shuffle_and_fill_up ( complex64_t data,
complex64_t filling,
uint16_t delta )
inline

◆ simd_shuffle_and_fill_up() [3/4]

int64_t simd_shuffle_and_fill_up ( int64_t data,
int64_t filling,
uint16_t delta )
inline

◆ simd_shuffle_and_fill_up() [4/4]

uint64_t simd_shuffle_and_fill_up ( uint64_t data,
uint64_t filling,
uint16_t delta )
inline

◆ simd_shuffle_down() [1/4]

bool simd_shuffle_down ( bool data,
uint16_t delta )
inline

◆ simd_shuffle_down() [2/4]

complex64_t simd_shuffle_down ( complex64_t data,
uint16_t delta )
inline

◆ simd_shuffle_down() [3/4]

int64_t simd_shuffle_down ( int64_t data,
uint16_t delta )
inline

◆ simd_shuffle_down() [4/4]

uint64_t simd_shuffle_down ( uint64_t data,
uint16_t delta )
inline

◆ simd_shuffle_up() [1/4]

bool simd_shuffle_up ( bool data,
uint16_t delta )
inline

◆ simd_shuffle_up() [2/4]

complex64_t simd_shuffle_up ( complex64_t data,
uint16_t delta )
inline

◆ simd_shuffle_up() [3/4]

int64_t simd_shuffle_up ( int64_t data,
uint16_t delta )
inline

◆ simd_shuffle_up() [4/4]

uint64_t simd_shuffle_up ( uint64_t data,
uint16_t delta )
inline