#include <metal_common>
#include <metal_compute>
#include "mlx/backend/metal/kernels/steel/defines.h"
Go to the source code of this file.
template <short R>
METAL_FUNC void radix_func(thread float *x)

template <typename T, int N, int max_radix, int read_width>
void hadamard_n(const device T *in, device T *out, constant const float &scale, uint3 elem, uint3 grid)

template <typename T, int N, int M, int read_width>
void hadamard_m(const device T *in, device T *out, constant const float &scale, uint3 elem, uint3 grid)

◆ hadamard_m()
template <typename T, int N, int M, int read_width>
void hadamard_m(const device T *in,
                device T *out,
                constant const float &scale,
                uint3 elem,
                uint3 grid)
◆ hadamard_n()
template <typename T, int N, int max_radix, int read_width>
void hadamard_n(const device T *in,
                device T *out,
                constant const float &scale,
                uint3 elem,
                uint3 grid)
◆ radix_func()
template <short R>
METAL_FUNC void radix_func(thread float *x)