mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
Fully wrap the command encoder (#1572)
* fully wrap the command encoder * use consistent style + fix extensions
This commit is contained in:
@@ -17,12 +17,15 @@ template <typename T, int N_READS = RMS_N_READS>
|
||||
constant float& eps,
|
||||
constant uint& axis_size,
|
||||
constant uint& w_stride,
|
||||
threadgroup float* local_inv_mean [[threadgroup(0)]],
|
||||
threadgroup float* local_sums [[threadgroup(1)]],
|
||||
uint gid [[threadgroup_position_in_grid]],
|
||||
uint lid [[thread_position_in_threadgroup]],
|
||||
uint simd_lane_id [[thread_index_in_simdgroup]],
|
||||
uint simd_group_id [[simdgroup_index_in_threadgroup]]) {
|
||||
constexpr int SIMD_SIZE = 32;
|
||||
|
||||
threadgroup float local_inv_mean[1];
|
||||
threadgroup float local_sums[SIMD_SIZE];
|
||||
|
||||
float acc = 0;
|
||||
x += gid * size_t(axis_size) + lid * N_READS;
|
||||
w += w_stride * lid * N_READS;
|
||||
@@ -84,13 +87,15 @@ template <typename T, int N_READS = RMS_N_READS>
|
||||
constant float& eps,
|
||||
constant uint& axis_size,
|
||||
constant uint& w_stride,
|
||||
threadgroup float* local_inv_mean [[threadgroup(0)]],
|
||||
threadgroup float* local_sums [[threadgroup(1)]],
|
||||
uint gid [[threadgroup_position_in_grid]],
|
||||
uint lid [[thread_position_in_threadgroup]],
|
||||
uint lsize [[threads_per_threadgroup]],
|
||||
uint simd_lane_id [[thread_index_in_simdgroup]],
|
||||
uint simd_group_id [[simdgroup_index_in_threadgroup]]) {
|
||||
constexpr int SIMD_SIZE = 32;
|
||||
threadgroup float local_inv_mean[1];
|
||||
threadgroup float local_sums[SIMD_SIZE];
|
||||
|
||||
float acc = 0;
|
||||
x += gid * size_t(axis_size) + lid * N_READS;
|
||||
w += w_stride * lid * N_READS;
|
||||
@@ -376,8 +381,6 @@ template <typename T, int N_READS = RMS_N_READS>
|
||||
constant float& eps, \
|
||||
constant uint& axis_size, \
|
||||
constant uint& w_stride, \
|
||||
threadgroup float* local_inv_mean [[threadgroup(0)]], \
|
||||
threadgroup float* local_sums [[threadgroup(1)]], \
|
||||
uint gid [[thread_position_in_grid]], \
|
||||
uint lid [[thread_position_in_threadgroup]], \
|
||||
uint simd_lane_id [[thread_index_in_simdgroup]], \
|
||||
@@ -407,8 +410,6 @@ template <typename T, int N_READS = RMS_N_READS>
|
||||
constant float& eps, \
|
||||
constant uint& axis_size, \
|
||||
constant uint& w_stride, \
|
||||
threadgroup float* local_inv_mean [[threadgroup(0)]], \
|
||||
threadgroup float* local_sums [[threadgroup(1)]], \
|
||||
uint gid [[thread_position_in_grid]], \
|
||||
uint lid [[thread_position_in_threadgroup]], \
|
||||
uint lsize [[threads_per_threadgroup]], \
|
||||
|
||||
Reference in New Issue
Block a user