MLX
Loading...
Searching...
No Matches
utils.h File Reference

Go to the source code of this file.

Classes

struct  Limits< U >
 
struct  Limits< uint8_t >
 
struct  Limits< uint16_t >
 
struct  Limits< uint32_t >
 
struct  Limits< uint64_t >
 
struct  Limits< int8_t >
 
struct  Limits< int16_t >
 
struct  Limits< int32_t >
 
struct  Limits< int64_t >
 
struct  Limits< half >
 
struct  Limits< float >
 
struct  Limits< bfloat16_t >
 
struct  Limits< bool >
 
struct  Limits< complex64_t >
 
struct  looped_elem_to_loc< dim, offset_t >
 
struct  looped_elem_to_loc< 1, offset_t >
 
struct  looped_elem_to_loc< 0, offset_t >
 

Macros

#define instantiate_default_limit(type)
 
#define instantiate_float_limit(type)
 
#define MLX_MTL_PRAGMA_UNROLL   _Pragma("clang loop unroll(full)")
 

Typedefs

typedef half float16_t
 

Functions

template<typename stride_t >
METAL_FUNC stride_t elem_to_loc (uint elem, constant const int *shape, constant const stride_t *strides, int ndim)
 
template<typename stride_t >
METAL_FUNC stride_t elem_to_loc (stride_t elem, constant const int *shape, constant const stride_t *strides, int ndim)
 
template<typename stride_t >
METAL_FUNC stride_t elem_to_loc (uint3 elem, constant const int *shape, constant const stride_t *strides, int ndim)
 
template<typename stride_t >
METAL_FUNC stride_t elem_to_loc_1 (uint elem, constant const stride_t &stride)
 
template<typename stride_t >
METAL_FUNC stride_t elem_to_loc_2 (uint2 elem, constant const stride_t strides[2])
 
template<typename stride_t >
METAL_FUNC stride_t elem_to_loc_3 (uint3 elem, constant const stride_t strides[3])
 
template<typename stride_t >
METAL_FUNC ulong2 elem_to_loc_2_nd (uint3 elem, constant const int *shape, constant const stride_t *a_strides, constant const stride_t *b_strides, int ndim)
 
METAL_FUNC ulong3 elem_to_loc_3_nd (uint3 elem, constant const int *shape, constant const size_t *a_strides, constant const size_t *b_strides, constant const size_t *c_strides, int ndim)
 
template<typename T , typename U >
T ceildiv (T N, U M)
 Computes ceil((float)N / (float)M), i.e. the quotient N / M rounded up.
 
float log1p (float x)
 
bfloat16_t log1p (bfloat16_t x)
 
uint64_t simd_shuffle_down (uint64_t data, uint16_t delta)
 
int64_t simd_shuffle_down (int64_t data, uint16_t delta)
 
bool simd_shuffle_down (bool data, uint16_t delta)
 
complex64_t simd_shuffle_down (complex64_t data, uint16_t delta)
 

Macro Definition Documentation

◆ instantiate_default_limit

#define instantiate_default_limit(type)
Value:
template <> \
struct Limits<type> { \
static constexpr constant type max = metal::numeric_limits<type>::max(); \
static constexpr constant type min = metal::numeric_limits<type>::min(); \
static constexpr constant type finite_max = \
metal::numeric_limits<type>::max(); \
static constexpr constant type finite_min = \
metal::numeric_limits<type>::min(); \
};
Referenced declarations (from the generic Limits< U > template):
Limits< U > — Definition utils.h:17
static const constant U max — Definition utils.h:18
static const constant U finite_max — Definition utils.h:20
static const constant U min — Definition utils.h:19
static const constant U finite_min — Definition utils.h:21

◆ instantiate_float_limit

#define instantiate_float_limit(type)
Value:
template <> \
struct Limits<type> { \
static constexpr constant type max = \
metal::numeric_limits<type>::infinity(); \
static constexpr constant type min = \
-metal::numeric_limits<type>::infinity(); \
static constexpr constant type finite_max = \
metal::numeric_limits<type>::max(); \
static constexpr constant type finite_min = \
-metal::numeric_limits<type>::max(); \
};

◆ MLX_MTL_PRAGMA_UNROLL

#define MLX_MTL_PRAGMA_UNROLL   _Pragma("clang loop unroll(full)")

Typedef Documentation

◆ float16_t

typedef half float16_t

Function Documentation

◆ ceildiv()

template<typename T , typename U >
T ceildiv ( T N,
U M )
inline

Computes ceil((float)N / (float)M), i.e. the quotient N / M rounded up.

◆ elem_to_loc() [1/3]

template<typename stride_t >
METAL_FUNC stride_t elem_to_loc ( stride_t elem,
constant const int * shape,
constant const stride_t * strides,
int ndim )

◆ elem_to_loc() [2/3]

template<typename stride_t >
METAL_FUNC stride_t elem_to_loc ( uint elem,
constant const int * shape,
constant const stride_t * strides,
int ndim )

◆ elem_to_loc() [3/3]

template<typename stride_t >
METAL_FUNC stride_t elem_to_loc ( uint3 elem,
constant const int * shape,
constant const stride_t * strides,
int ndim )

◆ elem_to_loc_1()

template<typename stride_t >
METAL_FUNC stride_t elem_to_loc_1 ( uint elem,
constant const stride_t & stride )

◆ elem_to_loc_2()

template<typename stride_t >
METAL_FUNC stride_t elem_to_loc_2 ( uint2 elem,
constant const stride_t strides[2] )

◆ elem_to_loc_2_nd()

template<typename stride_t >
METAL_FUNC ulong2 elem_to_loc_2_nd ( uint3 elem,
constant const int * shape,
constant const stride_t * a_strides,
constant const stride_t * b_strides,
int ndim )

◆ elem_to_loc_3()

template<typename stride_t >
METAL_FUNC stride_t elem_to_loc_3 ( uint3 elem,
constant const stride_t strides[3] )

◆ elem_to_loc_3_nd()

METAL_FUNC ulong3 elem_to_loc_3_nd ( uint3 elem,
constant const int * shape,
constant const size_t * a_strides,
constant const size_t * b_strides,
constant const size_t * c_strides,
int ndim )

◆ log1p() [1/2]

bfloat16_t log1p ( bfloat16_t x)
inline

◆ log1p() [2/2]

float log1p ( float x)
inline

◆ simd_shuffle_down() [1/4]

bool simd_shuffle_down ( bool data,
uint16_t delta )
inline

◆ simd_shuffle_down() [2/4]

complex64_t simd_shuffle_down ( complex64_t data,
uint16_t delta )
inline

◆ simd_shuffle_down() [3/4]

int64_t simd_shuffle_down ( int64_t data,
uint16_t delta )
inline

◆ simd_shuffle_down() [4/4]

uint64_t simd_shuffle_down ( uint64_t data,
uint16_t delta )
inline