mirror of
https://github.com/ml-explore/mlx.git
synced 2025-10-22 11:14:32 +08:00
Implement vjps for some primitives in the fast namespace (#883)
* Implement rope vjp in terms of rope * RMSNormVJP primitive and kernel * Add LayerNormVJP primitive and kernel
This commit is contained in:

committed by
GitHub

parent
a789685c63
commit
29221fa238
@@ -205,39 +205,341 @@ template <typename T, int N_READS = RMS_N_READS>
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T, int N_READS = RMS_N_READS>
|
||||
[[kernel]] void vjp_layer_norm_single_row(
|
||||
const device T* x,
|
||||
const device T* w,
|
||||
const device T* g,
|
||||
device T* gx,
|
||||
device T* gw,
|
||||
constant float& eps,
|
||||
constant uint& axis_size,
|
||||
constant uint& w_stride,
|
||||
uint gid [[threadgroup_position_in_grid]],
|
||||
uint lid [[thread_position_in_threadgroup]],
|
||||
uint simd_lane_id [[thread_index_in_simdgroup]],
|
||||
uint simd_group_id [[simdgroup_index_in_threadgroup]]) {
|
||||
// Advance the input pointers
|
||||
x += gid * axis_size + lid * N_READS;
|
||||
g += gid * axis_size + lid * N_READS;
|
||||
w += w_stride * lid * N_READS;
|
||||
|
||||
// Allocate registers for the computation and accumulators
|
||||
float thread_x[N_READS];
|
||||
float thread_w[N_READS];
|
||||
float thread_g[N_READS];
|
||||
float sumx = 0;
|
||||
float sumx2 = 0;
|
||||
float sumwg = 0;
|
||||
float sumwgx = 0;
|
||||
|
||||
constexpr int SIMD_SIZE = 32;
|
||||
|
||||
threadgroup float local_sumx[SIMD_SIZE];
|
||||
threadgroup float local_sumx2[SIMD_SIZE];
|
||||
threadgroup float local_sumwg[SIMD_SIZE];
|
||||
threadgroup float local_sumwgx[SIMD_SIZE];
|
||||
threadgroup float local_mean[1];
|
||||
threadgroup float local_normalizer[1];
|
||||
threadgroup float local_meanwg[1];
|
||||
threadgroup float local_meanwgx[1];
|
||||
|
||||
if (lid * N_READS + N_READS <= axis_size) {
|
||||
for (int i = 0; i < N_READS; i++) {
|
||||
thread_x[i] = x[i];
|
||||
thread_w[i] = w[i * w_stride];
|
||||
thread_g[i] = g[i];
|
||||
float wg = thread_w[i] * thread_g[i];
|
||||
sumx += thread_x[i];
|
||||
sumx2 += thread_x[i] * thread_x[i];
|
||||
sumwg += wg;
|
||||
sumwgx += wg * thread_x[i];
|
||||
}
|
||||
} else {
|
||||
for (int i = 0; i < N_READS; i++) {
|
||||
if ((lid * N_READS + i) < axis_size) {
|
||||
thread_x[i] = x[i];
|
||||
thread_w[i] = w[i * w_stride];
|
||||
thread_g[i] = g[i];
|
||||
float wg = thread_w[i] * thread_g[i];
|
||||
sumx += thread_x[i];
|
||||
sumx2 += thread_x[i] * thread_x[i];
|
||||
sumwg += wg;
|
||||
sumwgx += wg * thread_x[i];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
sumx = simd_sum(sumx);
|
||||
sumx2 = simd_sum(sumx2);
|
||||
sumwg = simd_sum(sumwg);
|
||||
sumwgx = simd_sum(sumwgx);
|
||||
|
||||
// Initialize shared memory
|
||||
if (simd_group_id == 0) {
|
||||
local_sumx[simd_lane_id] = 0;
|
||||
local_sumx2[simd_lane_id] = 0;
|
||||
local_sumwg[simd_lane_id] = 0;
|
||||
local_sumwgx[simd_lane_id] = 0;
|
||||
}
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
// Write simd accumulations into shared memory
|
||||
if (simd_lane_id == 0) {
|
||||
local_sumx[simd_group_id] = sumx;
|
||||
local_sumx2[simd_group_id] = sumx2;
|
||||
local_sumwg[simd_group_id] = sumwg;
|
||||
local_sumwgx[simd_group_id] = sumwgx;
|
||||
}
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
// Accumulate over simd groups
|
||||
if (simd_group_id == 0) {
|
||||
sumx = simd_sum(local_sumx[simd_lane_id]);
|
||||
sumx2 = simd_sum(local_sumx2[simd_lane_id]);
|
||||
sumwg = simd_sum(local_sumwg[simd_lane_id]);
|
||||
sumwgx = simd_sum(local_sumwgx[simd_lane_id]);
|
||||
if (simd_lane_id == 0) {
|
||||
float mean = sumx / axis_size;
|
||||
float variance = sumx2 / axis_size - mean * mean;
|
||||
|
||||
local_mean[0] = mean;
|
||||
local_normalizer[0] = metal::precise::rsqrt(variance + eps);
|
||||
local_meanwg[0] = sumwg / axis_size;
|
||||
local_meanwgx[0] = sumwgx / axis_size;
|
||||
}
|
||||
}
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
float mean = local_mean[0];
|
||||
float normalizer = local_normalizer[0];
|
||||
float meanwg = local_meanwg[0];
|
||||
float meanwgxc = local_meanwgx[0] - meanwg * mean;
|
||||
float normalizer2 = normalizer * normalizer;
|
||||
|
||||
// Write the outputs
|
||||
gx += gid * axis_size + lid * N_READS;
|
||||
gw += gid * axis_size + lid * N_READS;
|
||||
if (lid * N_READS + N_READS <= axis_size) {
|
||||
for (int i = 0; i < N_READS; i++) {
|
||||
thread_x[i] = (thread_x[i] - mean) * normalizer;
|
||||
gx[i] = static_cast<T>(normalizer * (thread_w[i] * thread_g[i] - meanwg) -
|
||||
thread_x[i] * meanwgxc * normalizer2);
|
||||
gw[i] = static_cast<T>(thread_g[i] * thread_x[i]);
|
||||
}
|
||||
} else {
|
||||
for (int i = 0; i < N_READS; i++) {
|
||||
if ((lid * N_READS + i) < axis_size) {
|
||||
thread_x[i] = (thread_x[i] - mean) * normalizer;
|
||||
gx[i] = static_cast<T>(normalizer * (thread_w[i] * thread_g[i] - meanwg) -
|
||||
thread_x[i] * meanwgxc * normalizer2);
|
||||
gw[i] = static_cast<T>(thread_g[i] * thread_x[i]);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T, int N_READS = RMS_N_READS>
|
||||
[[kernel]] void vjp_layer_norm_looped(
|
||||
const device T* x,
|
||||
const device T* w,
|
||||
const device T* g,
|
||||
device T* gx,
|
||||
device T* gw,
|
||||
constant float& eps,
|
||||
constant uint& axis_size,
|
||||
constant uint& w_stride,
|
||||
uint gid [[threadgroup_position_in_grid]],
|
||||
uint lid [[thread_position_in_threadgroup]],
|
||||
uint lsize [[threads_per_threadgroup]],
|
||||
uint simd_lane_id [[thread_index_in_simdgroup]],
|
||||
uint simd_group_id [[simdgroup_index_in_threadgroup]]) {
|
||||
// Advance the input pointers
|
||||
x += gid * axis_size + lid * N_READS;
|
||||
g += gid * axis_size + lid * N_READS;
|
||||
w += w_stride * lid * N_READS;
|
||||
|
||||
// Allocate registers for the accumulators
|
||||
float sumx = 0;
|
||||
float sumx2 = 0;
|
||||
float sumwg = 0;
|
||||
float sumwgx = 0;
|
||||
|
||||
constexpr int SIMD_SIZE = 32;
|
||||
|
||||
threadgroup float local_sumx[SIMD_SIZE];
|
||||
threadgroup float local_sumx2[SIMD_SIZE];
|
||||
threadgroup float local_sumwg[SIMD_SIZE];
|
||||
threadgroup float local_sumwgx[SIMD_SIZE];
|
||||
threadgroup float local_mean[1];
|
||||
threadgroup float local_normalizer[1];
|
||||
threadgroup float local_meanwg[1];
|
||||
threadgroup float local_meanwgx[1];
|
||||
|
||||
for (uint r = 0; r < axis_size; r += lsize * N_READS) {
|
||||
if (r + lid * N_READS + N_READS <= axis_size) {
|
||||
for (int i = 0; i < N_READS; i++) {
|
||||
float xi = x[i + r];
|
||||
float wi = w[(i + r) * w_stride];
|
||||
float gi = g[i + r];
|
||||
float wg = wi * gi;
|
||||
sumx += xi;
|
||||
sumx2 += xi * xi;
|
||||
sumwg += wg;
|
||||
sumwgx += wg * xi;
|
||||
}
|
||||
} else {
|
||||
for (int i = 0; i < N_READS; i++) {
|
||||
if ((r + lid * N_READS + i) < axis_size) {
|
||||
float xi = x[i + r];
|
||||
float wi = w[(i + r) * w_stride];
|
||||
float gi = g[i + r];
|
||||
float wg = wi * gi;
|
||||
sumx += xi;
|
||||
sumx2 += xi * xi;
|
||||
sumwg += wg;
|
||||
sumwgx += wg * xi;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
sumx = simd_sum(sumx);
|
||||
sumx2 = simd_sum(sumx2);
|
||||
sumwg = simd_sum(sumwg);
|
||||
sumwgx = simd_sum(sumwgx);
|
||||
|
||||
// Initialize shared memory
|
||||
if (simd_group_id == 0) {
|
||||
local_sumx[simd_lane_id] = 0;
|
||||
local_sumx2[simd_lane_id] = 0;
|
||||
local_sumwg[simd_lane_id] = 0;
|
||||
local_sumwgx[simd_lane_id] = 0;
|
||||
}
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
// Write simd accumulations into shared memory
|
||||
if (simd_lane_id == 0) {
|
||||
local_sumx[simd_group_id] = sumx;
|
||||
local_sumx2[simd_group_id] = sumx2;
|
||||
local_sumwg[simd_group_id] = sumwg;
|
||||
local_sumwgx[simd_group_id] = sumwgx;
|
||||
}
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
// Accumulate over simd groups
|
||||
if (simd_group_id == 0) {
|
||||
sumx = simd_sum(local_sumx[simd_lane_id]);
|
||||
sumx2 = simd_sum(local_sumx2[simd_lane_id]);
|
||||
sumwg = simd_sum(local_sumwg[simd_lane_id]);
|
||||
sumwgx = simd_sum(local_sumwgx[simd_lane_id]);
|
||||
if (simd_lane_id == 0) {
|
||||
float mean = sumx / axis_size;
|
||||
float variance = sumx2 / axis_size - mean * mean;
|
||||
|
||||
local_mean[0] = mean;
|
||||
local_normalizer[0] = metal::precise::rsqrt(variance + eps);
|
||||
local_meanwg[0] = sumwg / axis_size;
|
||||
local_meanwgx[0] = sumwgx / axis_size;
|
||||
}
|
||||
}
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
float mean = local_mean[0];
|
||||
float normalizer = local_normalizer[0];
|
||||
float meanwg = local_meanwg[0];
|
||||
float meanwgxc = local_meanwgx[0] - meanwg * mean;
|
||||
float normalizer2 = normalizer * normalizer;
|
||||
|
||||
// Write the outputs
|
||||
gx += gid * axis_size + lid * N_READS;
|
||||
gw += gid * axis_size + lid * N_READS;
|
||||
for (uint r = 0; r < axis_size; r += lsize * N_READS) {
|
||||
if (r + lid * N_READS + N_READS <= axis_size) {
|
||||
for (int i = 0; i < N_READS; i++) {
|
||||
float xi = (x[i + r] - mean) * normalizer;
|
||||
float wi = w[(i + r) * w_stride];
|
||||
float gi = g[i + r];
|
||||
gx[i + r] = static_cast<T>(normalizer * (wi * gi - meanwg) -
|
||||
xi * meanwgxc * normalizer2);
|
||||
gw[i + r] = static_cast<T>(gi * xi);
|
||||
}
|
||||
} else {
|
||||
for (int i = 0; i < N_READS; i++) {
|
||||
if ((r + lid * N_READS + i) < axis_size) {
|
||||
float xi = (x[i + r] - mean) * normalizer;
|
||||
float wi = w[(i + r) * w_stride];
|
||||
float gi = g[i + r];
|
||||
gx[i + r] = static_cast<T>(normalizer * (wi * gi - meanwg) -
|
||||
xi * meanwgxc * normalizer2);
|
||||
gw[i + r] = static_cast<T>(gi * xi);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// clang-format off
|
||||
#define instantiate_layer_norm_single_row(name, itype) \
|
||||
template [[host_name("layer_norm" #name)]] [[kernel]] void \
|
||||
layer_norm_single_row<itype>( \
|
||||
const device itype* x, \
|
||||
const device itype* w, \
|
||||
const device itype* b, \
|
||||
device itype* out, \
|
||||
constant float& eps, \
|
||||
constant uint& axis_size, \
|
||||
constant uint& w_stride, \
|
||||
constant uint& b_stride, \
|
||||
uint gid [[thread_position_in_grid]], \
|
||||
uint lid [[thread_position_in_threadgroup]], \
|
||||
uint simd_lane_id [[thread_index_in_simdgroup]], \
|
||||
#define instantiate_layer_norm_single_row(name, itype) \
|
||||
template [[host_name("layer_norm" #name)]] [[kernel]] void \
|
||||
layer_norm_single_row<itype>( \
|
||||
const device itype* x, \
|
||||
const device itype* w, \
|
||||
const device itype* b, \
|
||||
device itype* out, \
|
||||
constant float& eps, \
|
||||
constant uint& axis_size, \
|
||||
constant uint& w_stride, \
|
||||
constant uint& b_stride, \
|
||||
uint gid [[thread_position_in_grid]], \
|
||||
uint lid [[thread_position_in_threadgroup]], \
|
||||
uint simd_lane_id [[thread_index_in_simdgroup]], \
|
||||
uint simd_group_id [[simdgroup_index_in_threadgroup]]); \
|
||||
template [[host_name("vjp_layer_norm" #name)]] [[kernel]] void \
|
||||
vjp_layer_norm_single_row<itype>( \
|
||||
const device itype* x, \
|
||||
const device itype* w, \
|
||||
const device itype* g, \
|
||||
device itype* gx, \
|
||||
device itype* gw, \
|
||||
constant float& eps, \
|
||||
constant uint& axis_size, \
|
||||
constant uint& w_stride, \
|
||||
uint gid [[thread_position_in_grid]], \
|
||||
uint lid [[thread_position_in_threadgroup]], \
|
||||
uint simd_lane_id [[thread_index_in_simdgroup]], \
|
||||
uint simd_group_id [[simdgroup_index_in_threadgroup]]);
|
||||
|
||||
#define instantiate_layer_norm_looped(name, itype) \
|
||||
template [[host_name("layer_norm_looped" #name)]] [[kernel]] void \
|
||||
layer_norm_looped<itype>( \
|
||||
const device itype* x, \
|
||||
const device itype* w, \
|
||||
const device itype* b, \
|
||||
device itype* out, \
|
||||
constant float& eps, \
|
||||
constant uint& axis_size, \
|
||||
constant uint& w_stride, \
|
||||
constant uint& b_stride, \
|
||||
uint gid [[thread_position_in_grid]], \
|
||||
uint lid [[thread_position_in_threadgroup]], \
|
||||
uint lsize [[threads_per_threadgroup]], \
|
||||
uint simd_lane_id [[thread_index_in_simdgroup]], \
|
||||
#define instantiate_layer_norm_looped(name, itype) \
|
||||
template [[host_name("layer_norm_looped" #name)]] [[kernel]] void \
|
||||
layer_norm_looped<itype>( \
|
||||
const device itype* x, \
|
||||
const device itype* w, \
|
||||
const device itype* b, \
|
||||
device itype* out, \
|
||||
constant float& eps, \
|
||||
constant uint& axis_size, \
|
||||
constant uint& w_stride, \
|
||||
constant uint& b_stride, \
|
||||
uint gid [[thread_position_in_grid]], \
|
||||
uint lid [[thread_position_in_threadgroup]], \
|
||||
uint lsize [[threads_per_threadgroup]], \
|
||||
uint simd_lane_id [[thread_index_in_simdgroup]], \
|
||||
uint simd_group_id [[simdgroup_index_in_threadgroup]]); \
|
||||
template [[host_name("vjp_layer_norm_looped" #name)]] [[kernel]] void \
|
||||
vjp_layer_norm_looped<itype>( \
|
||||
const device itype* x, \
|
||||
const device itype* w, \
|
||||
const device itype* g, \
|
||||
device itype* gx, \
|
||||
device itype* gb, \
|
||||
constant float& eps, \
|
||||
constant uint& axis_size, \
|
||||
constant uint& w_stride, \
|
||||
uint gid [[thread_position_in_grid]], \
|
||||
uint lid [[thread_position_in_threadgroup]], \
|
||||
uint lsize [[threads_per_threadgroup]], \
|
||||
uint simd_lane_id [[thread_index_in_simdgroup]], \
|
||||
uint simd_group_id [[simdgroup_index_in_threadgroup]]);
|
||||
|
||||
#define instantiate_layer_norm(name, itype) \
|
||||
|
@@ -150,6 +150,216 @@ template <typename T, int N_READS = RMS_N_READS>
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T, int N_READS = RMS_N_READS>
|
||||
[[kernel]] void vjp_rms_single_row(
|
||||
const device T* x,
|
||||
const device T* w,
|
||||
const device T* g,
|
||||
device T* gx,
|
||||
device T* gw,
|
||||
constant float& eps,
|
||||
constant uint& axis_size,
|
||||
constant uint& w_stride,
|
||||
uint gid [[threadgroup_position_in_grid]],
|
||||
uint lid [[thread_position_in_threadgroup]],
|
||||
uint simd_lane_id [[thread_index_in_simdgroup]],
|
||||
uint simd_group_id [[simdgroup_index_in_threadgroup]]) {
|
||||
// Advance the input pointers
|
||||
x += gid * axis_size + lid * N_READS;
|
||||
g += gid * axis_size + lid * N_READS;
|
||||
w += w_stride * lid * N_READS;
|
||||
|
||||
// Allocate registers for the computation and accumulators
|
||||
float thread_x[N_READS];
|
||||
float thread_w[N_READS];
|
||||
float thread_g[N_READS];
|
||||
float sumx2 = 0;
|
||||
float sumgwx = 0;
|
||||
|
||||
// Allocate shared memory to implement the reduction
|
||||
constexpr int SIMD_SIZE = 32;
|
||||
threadgroup float local_sumx2[SIMD_SIZE];
|
||||
threadgroup float local_sumgwx[SIMD_SIZE];
|
||||
threadgroup float local_normalizer[1];
|
||||
threadgroup float local_meangwx[1];
|
||||
|
||||
// Read and accumulate locally
|
||||
if (lid * N_READS + N_READS <= axis_size) {
|
||||
for (int i = 0; i < N_READS; i++) {
|
||||
thread_x[i] = x[i];
|
||||
thread_w[i] = w[w_stride * i];
|
||||
thread_g[i] = g[i];
|
||||
|
||||
sumx2 += thread_x[i] * thread_x[i];
|
||||
sumgwx += thread_x[i] * thread_w[i] * thread_g[i];
|
||||
}
|
||||
} else {
|
||||
for (int i = 0; i < N_READS; i++) {
|
||||
if ((lid * N_READS + i) < axis_size) {
|
||||
thread_x[i] = x[i];
|
||||
thread_w[i] = w[w_stride * i];
|
||||
thread_g[i] = g[i];
|
||||
|
||||
sumx2 += thread_x[i] * thread_x[i];
|
||||
sumgwx += thread_x[i] * thread_w[i] * thread_g[i];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Accumulate across threads
|
||||
sumx2 = simd_sum(sumx2);
|
||||
sumgwx = simd_sum(sumgwx);
|
||||
if (simd_group_id == 0) {
|
||||
local_sumx2[simd_lane_id] = 0;
|
||||
local_sumgwx[simd_lane_id] = 0;
|
||||
}
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
if (simd_lane_id == 0) {
|
||||
local_sumx2[simd_group_id] = sumx2;
|
||||
local_sumgwx[simd_group_id] = sumgwx;
|
||||
}
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
if (simd_group_id == 0) {
|
||||
sumx2 = simd_sum(local_sumx2[simd_lane_id]);
|
||||
sumgwx = simd_sum(local_sumgwx[simd_lane_id]);
|
||||
if (simd_lane_id == 0) {
|
||||
local_meangwx[0] = sumgwx / axis_size;
|
||||
local_normalizer[0] = metal::precise::rsqrt(sumx2 / axis_size + eps);
|
||||
}
|
||||
}
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
float meangwx = local_meangwx[0];
|
||||
float normalizer = local_normalizer[0];
|
||||
float normalizer3 = normalizer * normalizer * normalizer;
|
||||
|
||||
// Write the outputs
|
||||
gx += gid * axis_size + lid * N_READS;
|
||||
gw += gid * axis_size + lid * N_READS;
|
||||
if (lid * N_READS + N_READS <= axis_size) {
|
||||
for (int i = 0; i < N_READS; i++) {
|
||||
gx[i] = static_cast<T>(thread_g[i] * thread_w[i] * normalizer - thread_x[i] * meangwx * normalizer3);
|
||||
gw[i] = static_cast<T>(thread_g[i] * thread_x[i] * normalizer);
|
||||
}
|
||||
} else {
|
||||
for (int i = 0; i < N_READS; i++) {
|
||||
if ((lid * N_READS + i) < axis_size) {
|
||||
gx[i] = static_cast<T>(thread_g[i] * thread_w[i] * normalizer - thread_x[i] * meangwx * normalizer3);
|
||||
gw[i] = static_cast<T>(thread_g[i] * thread_x[i] * normalizer);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T, int N_READS = RMS_N_READS>
|
||||
[[kernel]] void vjp_rms_looped(
|
||||
const device T* x,
|
||||
const device T* w,
|
||||
const device T* g,
|
||||
device T* gx,
|
||||
device T* gw,
|
||||
constant float& eps,
|
||||
constant uint& axis_size,
|
||||
constant uint& w_stride,
|
||||
uint gid [[threadgroup_position_in_grid]],
|
||||
uint lid [[thread_position_in_threadgroup]],
|
||||
uint lsize [[threads_per_threadgroup]],
|
||||
uint simd_lane_id [[thread_index_in_simdgroup]],
|
||||
uint simd_group_id [[simdgroup_index_in_threadgroup]]) {
|
||||
// Advance the input pointers
|
||||
x += gid * axis_size + lid * N_READS;
|
||||
g += gid * axis_size + lid * N_READS;
|
||||
w += w_stride * lid * N_READS;
|
||||
|
||||
// Allocate registers for the accumulators
|
||||
float sumx2 = 0;
|
||||
float sumgwx = 0;
|
||||
|
||||
// Allocate shared memory to implement the reduction
|
||||
constexpr int SIMD_SIZE = 32;
|
||||
threadgroup float local_sumx2[SIMD_SIZE];
|
||||
threadgroup float local_sumgwx[SIMD_SIZE];
|
||||
threadgroup float local_normalizer[1];
|
||||
threadgroup float local_meangwx[1];
|
||||
|
||||
// Read and accumulate locally
|
||||
for (uint r = 0; r < axis_size; r += lsize * N_READS) {
|
||||
if (r + lid * N_READS + N_READS <= axis_size) {
|
||||
for (int i = 0; i < N_READS; i++) {
|
||||
float xi = x[i + r];
|
||||
float wi = w[w_stride * (i + r)];
|
||||
float gi = g[i + r];
|
||||
|
||||
sumx2 += xi * xi;
|
||||
sumgwx += xi * wi * gi;
|
||||
}
|
||||
} else {
|
||||
for (int i = 0; i < N_READS; i++) {
|
||||
if ((r + lid * N_READS + i) < axis_size) {
|
||||
float xi = x[i + r];
|
||||
float wi = w[w_stride * (i + r)];
|
||||
float gi = g[i + r];
|
||||
|
||||
sumx2 += xi * xi;
|
||||
sumgwx += xi * wi * gi;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Accumulate across threads
|
||||
sumx2 = simd_sum(sumx2);
|
||||
sumgwx = simd_sum(sumgwx);
|
||||
if (simd_group_id == 0) {
|
||||
local_sumx2[simd_lane_id] = 0;
|
||||
local_sumgwx[simd_lane_id] = 0;
|
||||
}
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
if (simd_lane_id == 0) {
|
||||
local_sumx2[simd_group_id] = sumx2;
|
||||
local_sumgwx[simd_group_id] = sumgwx;
|
||||
}
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
if (simd_group_id == 0) {
|
||||
sumx2 = simd_sum(local_sumx2[simd_lane_id]);
|
||||
sumgwx = simd_sum(local_sumgwx[simd_lane_id]);
|
||||
if (simd_lane_id == 0) {
|
||||
local_meangwx[0] = sumgwx / axis_size;
|
||||
local_normalizer[0] = metal::precise::rsqrt(sumx2 / axis_size + eps);
|
||||
}
|
||||
}
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
float meangwx = local_meangwx[0];
|
||||
float normalizer = local_normalizer[0];
|
||||
float normalizer3 = normalizer * normalizer * normalizer;
|
||||
|
||||
// Write the outputs
|
||||
gx += gid * axis_size + lid * N_READS;
|
||||
gw += gid * axis_size + lid * N_READS;
|
||||
for (uint r = 0; r < axis_size; r += lsize * N_READS) {
|
||||
if (r + lid * N_READS + N_READS <= axis_size) {
|
||||
for (int i = 0; i < N_READS; i++) {
|
||||
float xi = x[i + r];
|
||||
float wi = w[w_stride * (i + r)];
|
||||
float gi = g[i + r];
|
||||
|
||||
gx[i + r] = static_cast<T>(gi * wi * normalizer - xi * meangwx * normalizer3);
|
||||
gw[i + r] = static_cast<T>(gi * xi * normalizer);
|
||||
}
|
||||
} else {
|
||||
for (int i = 0; i < N_READS; i++) {
|
||||
if ((r + lid * N_READS + i) < axis_size) {
|
||||
float xi = x[i + r];
|
||||
float wi = w[w_stride * (i + r)];
|
||||
float gi = g[i + r];
|
||||
|
||||
gx[i + r] = static_cast<T>(gi * wi * normalizer - xi * meangwx * normalizer3);
|
||||
gw[i + r] = static_cast<T>(gi * xi * normalizer);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// clang-format off
|
||||
#define instantiate_rms_single_row(name, itype) \
|
||||
template [[host_name("rms" #name)]] [[kernel]] void \
|
||||
@@ -165,25 +375,56 @@ template <typename T, int N_READS = RMS_N_READS>
|
||||
uint gid [[thread_position_in_grid]], \
|
||||
uint lid [[thread_position_in_threadgroup]], \
|
||||
uint simd_lane_id [[thread_index_in_simdgroup]], \
|
||||
uint simd_group_id [[simdgroup_index_in_threadgroup]]);
|
||||
|
||||
#define instantiate_rms_looped(name, itype) \
|
||||
template [[host_name("rms_looped" #name)]] [[kernel]] void \
|
||||
rms_looped<itype>( \
|
||||
uint simd_group_id [[simdgroup_index_in_threadgroup]]); \
|
||||
\
|
||||
template [[host_name("vjp_rms" #name)]] [[kernel]] void \
|
||||
vjp_rms_single_row<itype>( \
|
||||
const device itype* x, \
|
||||
const device itype* w, \
|
||||
device itype* out, \
|
||||
const device itype* g, \
|
||||
device itype* gx, \
|
||||
device itype* gw, \
|
||||
constant float& eps, \
|
||||
constant uint& axis_size, \
|
||||
constant uint& w_stride, \
|
||||
threadgroup float* local_inv_mean [[threadgroup(0)]], \
|
||||
threadgroup float* local_sums [[threadgroup(1)]], \
|
||||
uint gid [[thread_position_in_grid]], \
|
||||
uint lid [[thread_position_in_threadgroup]], \
|
||||
uint lsize [[threads_per_threadgroup]], \
|
||||
uint simd_lane_id [[thread_index_in_simdgroup]], \
|
||||
uint simd_group_id [[simdgroup_index_in_threadgroup]]);
|
||||
|
||||
#define instantiate_rms_looped(name, itype) \
|
||||
template [[host_name("rms_looped" #name)]] [[kernel]] void \
|
||||
rms_looped<itype>( \
|
||||
const device itype* x, \
|
||||
const device itype* w, \
|
||||
device itype* out, \
|
||||
constant float& eps, \
|
||||
constant uint& axis_size, \
|
||||
constant uint& w_stride, \
|
||||
threadgroup float* local_inv_mean [[threadgroup(0)]], \
|
||||
threadgroup float* local_sums [[threadgroup(1)]], \
|
||||
uint gid [[thread_position_in_grid]], \
|
||||
uint lid [[thread_position_in_threadgroup]], \
|
||||
uint lsize [[threads_per_threadgroup]], \
|
||||
uint simd_lane_id [[thread_index_in_simdgroup]], \
|
||||
uint simd_group_id [[simdgroup_index_in_threadgroup]]); \
|
||||
\
|
||||
template [[host_name("vjp_rms_looped" #name)]] [[kernel]] void \
|
||||
vjp_rms_looped<itype>( \
|
||||
const device itype* x, \
|
||||
const device itype* w, \
|
||||
const device itype* g, \
|
||||
device itype* gx, \
|
||||
device itype* gw, \
|
||||
constant float& eps, \
|
||||
constant uint& axis_size, \
|
||||
constant uint& w_stride, \
|
||||
uint gid [[thread_position_in_grid]], \
|
||||
uint lid [[thread_position_in_threadgroup]], \
|
||||
uint lsize [[threads_per_threadgroup]], \
|
||||
uint simd_lane_id [[thread_index_in_simdgroup]], \
|
||||
uint simd_group_id [[simdgroup_index_in_threadgroup]]);
|
||||
|
||||
#define instantiate_rms(name, itype) \
|
||||
instantiate_rms_single_row(name, itype) \
|
||||
instantiate_rms_looped(name, itype)
|
||||
|
@@ -5,7 +5,7 @@
|
||||
#include "mlx/backend/metal/kernels/bf16.h"
|
||||
#include "mlx/backend/metal/kernels/utils.h"
|
||||
|
||||
template <typename T, bool traditional>
|
||||
template <typename T, bool traditional, bool forward>
|
||||
[[kernel]] void rope(
|
||||
const device T *in [[buffer(0)]],
|
||||
device T * out [[buffer(1)]],
|
||||
@@ -43,15 +43,22 @@ template <typename T, bool traditional>
|
||||
// Read and write the output
|
||||
float x1 = static_cast<float>(in[in_index_1]);
|
||||
float x2 = static_cast<float>(in[in_index_2]);
|
||||
float rx1 = x1 * costheta - x2 * sintheta;
|
||||
float rx2 = x1 * sintheta + x2 * costheta;
|
||||
float rx1;
|
||||
float rx2;
|
||||
if (forward) {
|
||||
rx1 = x1 * costheta - x2 * sintheta;
|
||||
rx2 = x1 * sintheta + x2 * costheta;
|
||||
} else {
|
||||
rx1 = x2 * sintheta + x1 * costheta;
|
||||
rx2 = x2 * costheta - x1 * sintheta;
|
||||
}
|
||||
out[out_index_1] = static_cast<T>(rx1);
|
||||
out[out_index_2] = static_cast<T>(rx2);
|
||||
}
|
||||
|
||||
#define instantiate_rope(name, type, traditional) \
|
||||
#define instantiate_rope(name, type, traditional, forward) \
|
||||
template [[host_name("rope_" #name)]] \
|
||||
[[kernel]] void rope<type, traditional>( \
|
||||
[[kernel]] void rope<type, traditional, forward>( \
|
||||
const device type* in [[buffer(0)]], \
|
||||
device type* out [[buffer(1)]], \
|
||||
constant const size_t strides[3], \
|
||||
@@ -62,9 +69,15 @@ template <typename T, bool traditional>
|
||||
uint3 pos [[thread_position_in_grid]], \
|
||||
uint3 grid [[threads_per_grid]]);
|
||||
|
||||
instantiate_rope(traditional_float16, half, true)
|
||||
instantiate_rope(traditional_bfloat16, bfloat16_t, true)
|
||||
instantiate_rope(traditional_float32, float, true)
|
||||
instantiate_rope(float16, half, false)
|
||||
instantiate_rope(bfloat16, bfloat16_t, false)
|
||||
instantiate_rope(float32, float, false)
|
||||
instantiate_rope(traditional_float16, half, true, true)
|
||||
instantiate_rope(traditional_bfloat16, bfloat16_t, true, true)
|
||||
instantiate_rope(traditional_float32, float, true, true)
|
||||
instantiate_rope(float16, half, false, true)
|
||||
instantiate_rope(bfloat16, bfloat16_t, false, true)
|
||||
instantiate_rope(float32, float, false, true)
|
||||
instantiate_rope(vjp_traditional_float16, half, true, false)
|
||||
instantiate_rope(vjp_traditional_bfloat16, bfloat16_t, true, false)
|
||||
instantiate_rope(vjp_traditional_float32, float, true, false)
|
||||
instantiate_rope(vjp_float16, half, false, false)
|
||||
instantiate_rope(vjp_bfloat16, bfloat16_t, false, false)
|
||||
instantiate_rope(vjp_float32, float, false, false)
|
||||
|
@@ -4,6 +4,7 @@
|
||||
#include "mlx/backend/metal/copy.h"
|
||||
#include "mlx/backend/metal/device.h"
|
||||
#include "mlx/backend/metal/kernels/defines.h"
|
||||
#include "mlx/backend/metal/reduce.h"
|
||||
#include "mlx/backend/metal/utils.h"
|
||||
#include "mlx/fast_primitives.h"
|
||||
|
||||
@@ -95,6 +96,113 @@ void RMSNorm::eval_gpu(
|
||||
[copies](MTL::CommandBuffer*) mutable { copies.clear(); });
|
||||
}
|
||||
|
||||
void RMSNormVJP::eval_gpu(
|
||||
const std::vector<array>& inputs,
|
||||
std::vector<array>& outputs) {
|
||||
auto& s = stream();
|
||||
auto& d = metal::device(s.device);
|
||||
|
||||
// Ensure row contiguity. We could relax this step by checking that the array
|
||||
// is contiguous (no broadcasts or holes) and that the input strides are the
|
||||
// same as the cotangent strides but for now this is simpler.
|
||||
std::vector<array> copies;
|
||||
auto check_input = [&copies, &s](const array& x) {
|
||||
if (x.flags().row_contiguous) {
|
||||
return x;
|
||||
}
|
||||
|
||||
array x_copy(x.shape(), x.dtype(), nullptr, {});
|
||||
copy_gpu(x, x_copy, CopyType::General, s);
|
||||
copies.push_back(x_copy);
|
||||
return x_copy;
|
||||
};
|
||||
const array& x = check_input(inputs[0]);
|
||||
const array& w = inputs[1];
|
||||
const array& g = check_input(inputs[2]);
|
||||
array& gx = outputs[0];
|
||||
array& gw = outputs[1];
|
||||
|
||||
// Allocate space for the outputs
|
||||
bool x_in_gx = false;
|
||||
bool g_in_gx = false;
|
||||
if (x.is_donatable()) {
|
||||
gx.move_shared_buffer(x);
|
||||
x_in_gx = true;
|
||||
} else if (g.is_donatable()) {
|
||||
gx.move_shared_buffer(g);
|
||||
g_in_gx = true;
|
||||
} else {
|
||||
gx.set_data(allocator::malloc_or_wait(gx.nbytes()));
|
||||
}
|
||||
|
||||
auto axis_size = static_cast<uint32_t>(x.shape().back());
|
||||
int n_rows = x.data_size() / axis_size;
|
||||
|
||||
// Allocate a temporary to store the gradients for w and initialize the
|
||||
// gradient accumulator to 0.
|
||||
array gw_temp({n_rows, x.shape().back()}, gw.dtype(), nullptr, {});
|
||||
bool g_in_gw = false;
|
||||
if (!g_in_gx && g.is_donatable()) {
|
||||
gw_temp.move_shared_buffer(g);
|
||||
g_in_gw = true;
|
||||
} else {
|
||||
gw_temp.set_data(allocator::malloc_or_wait(gw_temp.nbytes()));
|
||||
}
|
||||
copies.push_back(gw_temp);
|
||||
array zero(0, gw.dtype());
|
||||
copy_gpu(zero, gw, CopyType::Scalar, s);
|
||||
|
||||
const int simd_size = 32;
|
||||
const int n_reads = RMS_N_READS;
|
||||
const int looped_limit = RMS_LOOPED_LIMIT;
|
||||
std::string op_name = "vjp_rms";
|
||||
if (axis_size > looped_limit) {
|
||||
op_name += "_looped";
|
||||
}
|
||||
op_name += type_to_name(gx);
|
||||
auto compute_encoder = d.get_command_encoder(s.index);
|
||||
{
|
||||
auto kernel = d.get_kernel(op_name);
|
||||
|
||||
MTL::Size grid_dims, group_dims;
|
||||
if (axis_size <= looped_limit) {
|
||||
size_t threadgroup_needed = (axis_size + n_reads - 1) / n_reads;
|
||||
size_t simds_needed = (threadgroup_needed + simd_size - 1) / simd_size;
|
||||
size_t threadgroup_size = simd_size * simds_needed;
|
||||
assert(threadgroup_size <= kernel->maxTotalThreadsPerThreadgroup());
|
||||
size_t n_threads = n_rows * threadgroup_size;
|
||||
grid_dims = MTL::Size(n_threads, 1, 1);
|
||||
group_dims = MTL::Size(threadgroup_size, 1, 1);
|
||||
} else {
|
||||
size_t threadgroup_size = kernel->maxTotalThreadsPerThreadgroup();
|
||||
size_t n_threads = n_rows * threadgroup_size;
|
||||
grid_dims = MTL::Size(n_threads, 1, 1);
|
||||
group_dims = MTL::Size(threadgroup_size, 1, 1);
|
||||
}
|
||||
|
||||
uint32_t w_stride = w.strides()[0];
|
||||
compute_encoder->setComputePipelineState(kernel);
|
||||
set_array_buffer(compute_encoder, x_in_gx ? gx : x, 0);
|
||||
set_array_buffer(compute_encoder, w, 1);
|
||||
set_array_buffer(
|
||||
compute_encoder, g_in_gx ? gx : (g_in_gw ? gw_temp : g), 2);
|
||||
set_array_buffer(compute_encoder, gx, 3);
|
||||
set_array_buffer(compute_encoder, gw_temp, 4);
|
||||
compute_encoder->setBytes(&eps_, sizeof(float), 5);
|
||||
compute_encoder->setBytes(&axis_size, sizeof(int), 6);
|
||||
compute_encoder->setBytes(&w_stride, sizeof(uint32_t), 7);
|
||||
compute_encoder->dispatchThreads(grid_dims, group_dims);
|
||||
}
|
||||
|
||||
ReductionPlan plan(
|
||||
ReductionOpType::ContiguousStridedReduce, {n_rows}, {axis_size});
|
||||
strided_reduce_general_dispatch(
|
||||
gw_temp, gw, "sum", plan, {0}, compute_encoder, d, s);
|
||||
|
||||
d.get_command_buffer(s.index)->addCompletedHandler(
|
||||
[copies](MTL::CommandBuffer*) mutable { copies.clear(); });
|
||||
}
|
||||
|
||||
void LayerNorm::eval_gpu(
|
||||
const std::vector<array>& inputs,
|
||||
std::vector<array>& outputs) {
|
||||
@@ -182,4 +290,124 @@ void LayerNorm::eval_gpu(
|
||||
[copies](MTL::CommandBuffer*) mutable { copies.clear(); });
|
||||
}
|
||||
|
||||
void LayerNormVJP::eval_gpu(
|
||||
const std::vector<array>& inputs,
|
||||
std::vector<array>& outputs) {
|
||||
auto& s = stream();
|
||||
auto& d = metal::device(s.device);
|
||||
|
||||
// Ensure row contiguity. We could relax this step by checking that the array
|
||||
// is contiguous (no broadcasts or holes) and that the input strides are the
|
||||
// same as the cotangent strides but for now this is simpler.
|
||||
std::vector<array> copies;
|
||||
auto check_input = [&copies, &s](const array& x) {
|
||||
if (x.flags().row_contiguous) {
|
||||
return x;
|
||||
}
|
||||
|
||||
array x_copy(x.shape(), x.dtype(), nullptr, {});
|
||||
copy_gpu(x, x_copy, CopyType::General, s);
|
||||
copies.push_back(x_copy);
|
||||
return x_copy;
|
||||
};
|
||||
const array& x = check_input(inputs[0]);
|
||||
const array& w = inputs[1];
|
||||
const array& b = inputs[2];
|
||||
const array& g = check_input(inputs[3]);
|
||||
array& gx = outputs[0];
|
||||
array& gw = outputs[1];
|
||||
array& gb = outputs[2];
|
||||
|
||||
// Allocate space for the outputs
|
||||
bool x_in_gx = false;
|
||||
bool g_in_gx = false;
|
||||
if (x.is_donatable()) {
|
||||
gx.move_shared_buffer(x);
|
||||
x_in_gx = true;
|
||||
} else if (g.is_donatable()) {
|
||||
gx.move_shared_buffer(g);
|
||||
g_in_gx = true;
|
||||
} else {
|
||||
gx.set_data(allocator::malloc_or_wait(gx.nbytes()));
|
||||
}
|
||||
|
||||
auto axis_size = static_cast<uint32_t>(x.shape().back());
|
||||
int n_rows = x.data_size() / axis_size;
|
||||
|
||||
// Allocate a temporary to store the gradients for w and initialize the
|
||||
// gradient accumulator to 0.
|
||||
array gw_temp({n_rows, x.shape().back()}, gw.dtype(), nullptr, {});
|
||||
bool g_in_gw = false;
|
||||
if (!g_in_gx && g.is_donatable()) {
|
||||
gw_temp.move_shared_buffer(g);
|
||||
g_in_gw = true;
|
||||
} else {
|
||||
gw_temp.set_data(allocator::malloc_or_wait(gw_temp.nbytes()));
|
||||
}
|
||||
copies.push_back(gw_temp);
|
||||
array zero(0, gw.dtype());
|
||||
copy_gpu(zero, gw, CopyType::Scalar, s);
|
||||
copy_gpu(zero, gb, CopyType::Scalar, s);
|
||||
|
||||
// Finish with the gradient for b in case we had a b
|
||||
auto compute_encoder = d.get_command_encoder(s.index);
|
||||
if (gb.ndim() == 1 && gb.size() == axis_size) {
|
||||
ReductionPlan plan(
|
||||
ReductionOpType::ContiguousStridedReduce, {n_rows}, {axis_size});
|
||||
strided_reduce_general_dispatch(
|
||||
g, gb, "sum", plan, {0}, compute_encoder, d, s);
|
||||
}
|
||||
|
||||
const int simd_size = 32;
|
||||
const int n_reads = RMS_N_READS;
|
||||
const int looped_limit = RMS_LOOPED_LIMIT;
|
||||
std::string op_name = "vjp_layer_norm";
|
||||
if (axis_size > looped_limit) {
|
||||
op_name += "_looped";
|
||||
}
|
||||
op_name += type_to_name(gx);
|
||||
{
|
||||
auto kernel = d.get_kernel(op_name);
|
||||
|
||||
MTL::Size grid_dims, group_dims;
|
||||
if (axis_size <= looped_limit) {
|
||||
size_t threadgroup_needed = (axis_size + n_reads - 1) / n_reads;
|
||||
size_t simds_needed = (threadgroup_needed + simd_size - 1) / simd_size;
|
||||
size_t threadgroup_size = simd_size * simds_needed;
|
||||
assert(threadgroup_size <= kernel->maxTotalThreadsPerThreadgroup());
|
||||
size_t n_threads = n_rows * threadgroup_size;
|
||||
grid_dims = MTL::Size(n_threads, 1, 1);
|
||||
group_dims = MTL::Size(threadgroup_size, 1, 1);
|
||||
} else {
|
||||
size_t threadgroup_size = kernel->maxTotalThreadsPerThreadgroup();
|
||||
size_t n_threads = n_rows * threadgroup_size;
|
||||
grid_dims = MTL::Size(n_threads, 1, 1);
|
||||
group_dims = MTL::Size(threadgroup_size, 1, 1);
|
||||
}
|
||||
|
||||
uint32_t w_stride = w.strides()[0];
|
||||
compute_encoder->setComputePipelineState(kernel);
|
||||
set_array_buffer(compute_encoder, x_in_gx ? gx : x, 0);
|
||||
set_array_buffer(compute_encoder, w, 1);
|
||||
set_array_buffer(
|
||||
compute_encoder, g_in_gx ? gx : (g_in_gw ? gw_temp : g), 2);
|
||||
set_array_buffer(compute_encoder, gx, 3);
|
||||
set_array_buffer(compute_encoder, gw_temp, 4);
|
||||
compute_encoder->setBytes(&eps_, sizeof(float), 5);
|
||||
compute_encoder->setBytes(&axis_size, sizeof(int), 6);
|
||||
compute_encoder->setBytes(&w_stride, sizeof(uint32_t), 7);
|
||||
compute_encoder->dispatchThreads(grid_dims, group_dims);
|
||||
}
|
||||
|
||||
if (gw.ndim() == 1 && gw.size() == axis_size) {
|
||||
ReductionPlan plan(
|
||||
ReductionOpType::ContiguousStridedReduce, {n_rows}, {axis_size});
|
||||
strided_reduce_general_dispatch(
|
||||
gw_temp, gw, "sum", plan, {0}, compute_encoder, d, s);
|
||||
}
|
||||
|
||||
d.get_command_buffer(s.index)->addCompletedHandler(
|
||||
[copies](MTL::CommandBuffer*) mutable { copies.clear(); });
|
||||
}
|
||||
|
||||
} // namespace mlx::core::fast
|
||||
|
@@ -4,10 +4,10 @@
|
||||
#include <cassert>
|
||||
#include <sstream>
|
||||
|
||||
#include "mlx/backend/common/reduce.h"
|
||||
#include "mlx/backend/metal/copy.h"
|
||||
#include "mlx/backend/metal/device.h"
|
||||
#include "mlx/backend/metal/kernels/defines.h"
|
||||
#include "mlx/backend/metal/reduce.h"
|
||||
#include "mlx/backend/metal/utils.h"
|
||||
#include "mlx/primitives.h"
|
||||
#include "mlx/utils.h"
|
||||
@@ -18,8 +18,6 @@ namespace mlx::core {
|
||||
// Case wise reduce dispatch
|
||||
//////////////////////////////////////////////////////////////////////
|
||||
|
||||
namespace {
|
||||
|
||||
inline auto safe_div(size_t n, size_t m) {
|
||||
return m == 0 ? 0 : (n + m - 1) / m;
|
||||
}
|
||||
@@ -534,8 +532,6 @@ void strided_reduce_general_dispatch(
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
//////////////////////////////////////////////////////////////////////
|
||||
// Main reduce dispatch
|
||||
//////////////////////////////////////////////////////////////////////
|
||||
|
39
mlx/backend/metal/reduce.h
Normal file
39
mlx/backend/metal/reduce.h
Normal file
@@ -0,0 +1,39 @@
|
||||
// Copyright @ 2023 - 2024 Apple Inc.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "mlx/backend/common/reduce.h"
|
||||
#include "mlx/backend/metal/device.h"
|
||||
#include "mlx/stream.h"
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
void all_reduce_dispatch(
|
||||
const array& in,
|
||||
array& out,
|
||||
const std::string& op_name,
|
||||
MTL::ComputeCommandEncoder* compute_encoder,
|
||||
metal::Device& d,
|
||||
const Stream& s);
|
||||
|
||||
void row_reduce_general_dispatch(
|
||||
const array& in,
|
||||
array& out,
|
||||
const std::string& op_name,
|
||||
const ReductionPlan& plan,
|
||||
const std::vector<int>& axes,
|
||||
MTL::ComputeCommandEncoder* compute_encoder,
|
||||
metal::Device& d,
|
||||
const Stream& s);
|
||||
|
||||
void strided_reduce_general_dispatch(
|
||||
const array& in,
|
||||
array& out,
|
||||
const std::string& op_name,
|
||||
const ReductionPlan& plan,
|
||||
const std::vector<int>& axes,
|
||||
MTL::ComputeCommandEncoder* compute_encoder,
|
||||
metal::Device& d,
|
||||
const Stream& s);
|
||||
|
||||
} // namespace mlx::core
|
@@ -63,7 +63,8 @@ void RoPE::eval_gpu(
|
||||
out_strides[2] = out.strides()[ndim - 1];
|
||||
|
||||
std::ostringstream kname;
|
||||
kname << "rope_" << (traditional_ ? "traditional_" : "") << type_to_name(in);
|
||||
kname << "rope_" << (forward_ ? "" : "vjp_")
|
||||
<< (traditional_ ? "traditional_" : "") << type_to_name(in);
|
||||
auto kernel = d.get_kernel(kname.str());
|
||||
auto compute_encoder = d.get_command_encoder(s.index);
|
||||
|
||||
|
Reference in New Issue
Block a user