MLX
|
#include <metal_math>
#include "mlx/backend/metal/kernels/bf16.h"
#include "mlx/backend/metal/kernels/complex.h"
#include "mlx/backend/metal/kernels/defines.h"
Go to the source code of this file.
Classes | |
struct | Limits< U > |
struct | Limits< uint8_t > |
struct | Limits< uint16_t > |
struct | Limits< uint32_t > |
struct | Limits< uint64_t > |
struct | Limits< int8_t > |
struct | Limits< int16_t > |
struct | Limits< int32_t > |
struct | Limits< int64_t > |
struct | Limits< half > |
struct | Limits< float > |
struct | Limits< bfloat16_t > |
struct | Limits< bool > |
struct | Limits< complex64_t > |
struct | looped_elem_to_loc< dim, offset_t > |
struct | looped_elem_to_loc< 1, offset_t > |
struct | looped_elem_to_loc< 0, offset_t > |
Macros | |
#define | instantiate_default_limit(type) |
#define | instantiate_float_limit(type) |
#define | MLX_MTL_PRAGMA_UNROLL _Pragma("clang loop unroll(full)") |
Typedefs | |
typedef half | float16_t |
Functions | |
template<typename stride_t > | |
METAL_FUNC stride_t | elem_to_loc (uint elem, device const int *shape, device const stride_t *strides, int ndim) |
template<typename stride_t > | |
METAL_FUNC stride_t | elem_to_loc (uint elem, constant const int *shape, constant const stride_t *strides, int ndim) |
template<typename stride_t > | |
METAL_FUNC stride_t | elem_to_loc (stride_t elem, device const int *shape, device const stride_t *strides, int ndim) |
template<typename stride_t > | |
METAL_FUNC stride_t | elem_to_loc (stride_t elem, constant const int *shape, constant const stride_t *strides, int ndim) |
template<typename stride_t > | |
METAL_FUNC stride_t | elem_to_loc (uint3 elem, constant const int *shape, constant const stride_t *strides, int ndim) |
template<typename stride_t > | |
METAL_FUNC stride_t | elem_to_loc_1 (uint elem, constant const stride_t &stride) |
template<typename stride_t > | |
METAL_FUNC stride_t | elem_to_loc_2 (uint2 elem, constant const stride_t strides[2]) |
template<typename stride_t > | |
METAL_FUNC stride_t | elem_to_loc_3 (uint3 elem, constant const stride_t strides[3]) |
template<int NDIM> | |
METAL_FUNC size_t | elem_to_loc_nd (uint elem, device const int *shape, device const size_t *strides) |
template<int NDIM> | |
METAL_FUNC size_t | elem_to_loc_nd (uint3 elem, constant const int shape[NDIM], constant const size_t strides[NDIM]) |
template<int NDIM> | |
METAL_FUNC int64_t | elem_to_loc_nd (uint elem, constant const int shape[NDIM], constant const int64_t strides[NDIM]) |
template<int NDIM> | |
METAL_FUNC int64_t | elem_to_loc_nd (uint3 elem, constant const int shape[NDIM], constant const int64_t strides[NDIM]) |
METAL_FUNC uint2 | elem_to_loc_2_nd (uint3 elem, constant const int *shape, constant const size_t *a_strides, constant const size_t *b_strides, int ndim) |
METAL_FUNC uint3 | elem_to_loc_3_nd (uint3 elem, constant const int *shape, constant const size_t *a_strides, constant const size_t *b_strides, constant const size_t *c_strides, int ndim) |
template<int NDIM> | |
METAL_FUNC uint2 | elem_to_loc_2_nd (uint3 elem, constant const int shape[NDIM], constant const size_t a_strides[NDIM], constant const size_t b_strides[NDIM]) |
template<int NDIM> | |
METAL_FUNC uint3 | elem_to_loc_3_nd (uint3 elem, constant const int shape[NDIM], constant const size_t a_strides[NDIM], constant const size_t b_strides[NDIM], constant const size_t c_strides[NDIM]) |
template<typename T , typename U > | |
T | ceildiv (T N, U M) |
Compute ceil((float)N/(float)M) | |
float | log1p (float x) |
bfloat16_t | log1p (bfloat16_t x) |
uint64_t | simd_shuffle_down (uint64_t data, uint16_t delta) |
int64_t | simd_shuffle_down (int64_t data, uint16_t delta) |
bool | simd_shuffle_down (bool data, uint16_t delta) |
complex64_t | simd_shuffle_down (complex64_t data, uint16_t delta) |
#define instantiate_default_limit | ( | type | ) |
#define instantiate_float_limit | ( | type | ) |
#define MLX_MTL_PRAGMA_UNROLL _Pragma("clang loop unroll(full)") |
typedef half float16_t |
T ceildiv | ( | T | N, |
U | M ) |
inline |
Compute ceil((float)N/(float)M)
METAL_FUNC stride_t elem_to_loc | ( | stride_t | elem, |
constant const int * | shape, | ||
constant const stride_t * | strides, | ||
int | ndim ) |
METAL_FUNC stride_t elem_to_loc | ( | stride_t | elem, |
device const int * | shape, | ||
device const stride_t * | strides, | ||
int | ndim ) |
METAL_FUNC stride_t elem_to_loc | ( | uint | elem, |
constant const int * | shape, | ||
constant const stride_t * | strides, | ||
int | ndim ) |
METAL_FUNC stride_t elem_to_loc | ( | uint | elem, |
device const int * | shape, | ||
device const stride_t * | strides, | ||
int | ndim ) |
METAL_FUNC stride_t elem_to_loc | ( | uint3 | elem, |
constant const int * | shape, | ||
constant const stride_t * | strides, | ||
int | ndim ) |
METAL_FUNC stride_t elem_to_loc_1 | ( | uint | elem, |
constant const stride_t & | stride ) |
METAL_FUNC stride_t elem_to_loc_2 | ( | uint2 | elem, |
constant const stride_t | strides[2] ) |
METAL_FUNC uint2 elem_to_loc_2_nd | ( | uint3 | elem, |
constant const int * | shape, | ||
constant const size_t * | a_strides, | ||
constant const size_t * | b_strides, | ||
int | ndim ) |
METAL_FUNC uint2 elem_to_loc_2_nd | ( | uint3 | elem, |
constant const int | shape[NDIM], | ||
constant const size_t | a_strides[NDIM], | ||
constant const size_t | b_strides[NDIM] ) |
METAL_FUNC stride_t elem_to_loc_3 | ( | uint3 | elem, |
constant const stride_t | strides[3] ) |
METAL_FUNC uint3 elem_to_loc_3_nd | ( | uint3 | elem, |
constant const int * | shape, | ||
constant const size_t * | a_strides, | ||
constant const size_t * | b_strides, | ||
constant const size_t * | c_strides, | ||
int | ndim ) |
METAL_FUNC uint3 elem_to_loc_3_nd | ( | uint3 | elem, |
constant const int | shape[NDIM], | ||
constant const size_t | a_strides[NDIM], | ||
constant const size_t | b_strides[NDIM], | ||
constant const size_t | c_strides[NDIM] ) |
METAL_FUNC int64_t elem_to_loc_nd | ( | uint | elem, |
constant const int | shape[NDIM], | ||
constant const int64_t | strides[NDIM] ) |
METAL_FUNC size_t elem_to_loc_nd | ( | uint | elem, |
device const int * | shape, | ||
device const size_t * | strides ) |
METAL_FUNC int64_t elem_to_loc_nd | ( | uint3 | elem, |
constant const int | shape[NDIM], | ||
constant const int64_t | strides[NDIM] ) |
METAL_FUNC size_t elem_to_loc_nd | ( | uint3 | elem, |
constant const int | shape[NDIM], | ||
constant const size_t | strides[NDIM] ) |
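bfloat16_t log1p | ( | bfloat16_t | x ) |
inline |
float log1p | ( | float | x ) |
inline |
bool simd_shuffle_down | ( | bool | data, |
uint16_t | delta ) |
inline |
complex64_t simd_shuffle_down | ( | complex64_t | data, |
uint16_t | delta ) |
inline |
int64_t simd_shuffle_down | ( | int64_t | data, |
uint16_t | delta ) |
inline |
uint64_t simd_shuffle_down | ( | uint64_t | data, |
uint16_t | delta ) |
inline |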