mirror of
https://github.com/ml-explore/mlx.git
synced 2025-09-19 10:48:09 +08:00
Implement vjps for some primitives in the fast namespace (#883)
* Implement rope vjp in terms of rope * RMSNormVJP primitive and kernel * Add LayerNormVJP primitive and kernel
This commit is contained in:

committed by
GitHub

parent
a789685c63
commit
29221fa238
@@ -6,8 +6,6 @@
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
namespace {
|
||||
|
||||
enum ReductionOpType {
|
||||
// Self-explanatory. Read everything and produce 1 output.
|
||||
ContiguousAllReduce,
|
||||
@@ -38,6 +36,21 @@ enum ReductionOpType {
|
||||
GeneralReduce
|
||||
};
|
||||
|
||||
struct ReductionPlan {
|
||||
ReductionOpType type;
|
||||
std::vector<int> shape;
|
||||
std::vector<size_t> strides;
|
||||
|
||||
ReductionPlan(
|
||||
ReductionOpType type_,
|
||||
std::vector<int> shape_,
|
||||
std::vector<size_t> strides_)
|
||||
: type(type_), shape(std::move(shape_)), strides(std::move(strides_)) {}
|
||||
ReductionPlan(ReductionOpType type_) : type(type_) {}
|
||||
};
|
||||
|
||||
namespace {
|
||||
|
||||
// Helper for the ndimensional strided loop
|
||||
// Should this be in utils?
|
||||
inline void nd_loop(
|
||||
@@ -110,19 +123,6 @@ struct DefaultContiguousReduce {
|
||||
}
|
||||
};
|
||||
|
||||
struct ReductionPlan {
|
||||
ReductionOpType type;
|
||||
std::vector<int> shape;
|
||||
std::vector<size_t> strides;
|
||||
|
||||
ReductionPlan(
|
||||
ReductionOpType type_,
|
||||
std::vector<int> shape_,
|
||||
std::vector<size_t> strides_)
|
||||
: type(type_), shape(std::move(shape_)), strides(std::move(strides_)) {}
|
||||
ReductionPlan(ReductionOpType type_) : type(type_) {}
|
||||
};
|
||||
|
||||
ReductionPlan get_reduction_plan(const array& x, const std::vector<int> axes) {
|
||||
// The data is all there and we are reducing over everything
|
||||
if (x.size() == x.data_size() && axes.size() == x.ndim() &&
|
||||
|
@@ -205,39 +205,341 @@ template <typename T, int N_READS = RMS_N_READS>
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T, int N_READS = RMS_N_READS>
|
||||
[[kernel]] void vjp_layer_norm_single_row(
|
||||
const device T* x,
|
||||
const device T* w,
|
||||
const device T* g,
|
||||
device T* gx,
|
||||
device T* gw,
|
||||
constant float& eps,
|
||||
constant uint& axis_size,
|
||||
constant uint& w_stride,
|
||||
uint gid [[threadgroup_position_in_grid]],
|
||||
uint lid [[thread_position_in_threadgroup]],
|
||||
uint simd_lane_id [[thread_index_in_simdgroup]],
|
||||
uint simd_group_id [[simdgroup_index_in_threadgroup]]) {
|
||||
// Advance the input pointers
|
||||
x += gid * axis_size + lid * N_READS;
|
||||
g += gid * axis_size + lid * N_READS;
|
||||
w += w_stride * lid * N_READS;
|
||||
|
||||
// Allocate registers for the computation and accumulators
|
||||
float thread_x[N_READS];
|
||||
float thread_w[N_READS];
|
||||
float thread_g[N_READS];
|
||||
float sumx = 0;
|
||||
float sumx2 = 0;
|
||||
float sumwg = 0;
|
||||
float sumwgx = 0;
|
||||
|
||||
constexpr int SIMD_SIZE = 32;
|
||||
|
||||
threadgroup float local_sumx[SIMD_SIZE];
|
||||
threadgroup float local_sumx2[SIMD_SIZE];
|
||||
threadgroup float local_sumwg[SIMD_SIZE];
|
||||
threadgroup float local_sumwgx[SIMD_SIZE];
|
||||
threadgroup float local_mean[1];
|
||||
threadgroup float local_normalizer[1];
|
||||
threadgroup float local_meanwg[1];
|
||||
threadgroup float local_meanwgx[1];
|
||||
|
||||
if (lid * N_READS + N_READS <= axis_size) {
|
||||
for (int i = 0; i < N_READS; i++) {
|
||||
thread_x[i] = x[i];
|
||||
thread_w[i] = w[i * w_stride];
|
||||
thread_g[i] = g[i];
|
||||
float wg = thread_w[i] * thread_g[i];
|
||||
sumx += thread_x[i];
|
||||
sumx2 += thread_x[i] * thread_x[i];
|
||||
sumwg += wg;
|
||||
sumwgx += wg * thread_x[i];
|
||||
}
|
||||
} else {
|
||||
for (int i = 0; i < N_READS; i++) {
|
||||
if ((lid * N_READS + i) < axis_size) {
|
||||
thread_x[i] = x[i];
|
||||
thread_w[i] = w[i * w_stride];
|
||||
thread_g[i] = g[i];
|
||||
float wg = thread_w[i] * thread_g[i];
|
||||
sumx += thread_x[i];
|
||||
sumx2 += thread_x[i] * thread_x[i];
|
||||
sumwg += wg;
|
||||
sumwgx += wg * thread_x[i];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
sumx = simd_sum(sumx);
|
||||
sumx2 = simd_sum(sumx2);
|
||||
sumwg = simd_sum(sumwg);
|
||||
sumwgx = simd_sum(sumwgx);
|
||||
|
||||
// Initialize shared memory
|
||||
if (simd_group_id == 0) {
|
||||
local_sumx[simd_lane_id] = 0;
|
||||
local_sumx2[simd_lane_id] = 0;
|
||||
local_sumwg[simd_lane_id] = 0;
|
||||
local_sumwgx[simd_lane_id] = 0;
|
||||
}
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
// Write simd accumulations into shared memory
|
||||
if (simd_lane_id == 0) {
|
||||
local_sumx[simd_group_id] = sumx;
|
||||
local_sumx2[simd_group_id] = sumx2;
|
||||
local_sumwg[simd_group_id] = sumwg;
|
||||
local_sumwgx[simd_group_id] = sumwgx;
|
||||
}
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
// Accumulate over simd groups
|
||||
if (simd_group_id == 0) {
|
||||
sumx = simd_sum(local_sumx[simd_lane_id]);
|
||||
sumx2 = simd_sum(local_sumx2[simd_lane_id]);
|
||||
sumwg = simd_sum(local_sumwg[simd_lane_id]);
|
||||
sumwgx = simd_sum(local_sumwgx[simd_lane_id]);
|
||||
if (simd_lane_id == 0) {
|
||||
float mean = sumx / axis_size;
|
||||
float variance = sumx2 / axis_size - mean * mean;
|
||||
|
||||
local_mean[0] = mean;
|
||||
local_normalizer[0] = metal::precise::rsqrt(variance + eps);
|
||||
local_meanwg[0] = sumwg / axis_size;
|
||||
local_meanwgx[0] = sumwgx / axis_size;
|
||||
}
|
||||
}
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
float mean = local_mean[0];
|
||||
float normalizer = local_normalizer[0];
|
||||
float meanwg = local_meanwg[0];
|
||||
float meanwgxc = local_meanwgx[0] - meanwg * mean;
|
||||
float normalizer2 = normalizer * normalizer;
|
||||
|
||||
// Write the outputs
|
||||
gx += gid * axis_size + lid * N_READS;
|
||||
gw += gid * axis_size + lid * N_READS;
|
||||
if (lid * N_READS + N_READS <= axis_size) {
|
||||
for (int i = 0; i < N_READS; i++) {
|
||||
thread_x[i] = (thread_x[i] - mean) * normalizer;
|
||||
gx[i] = static_cast<T>(normalizer * (thread_w[i] * thread_g[i] - meanwg) -
|
||||
thread_x[i] * meanwgxc * normalizer2);
|
||||
gw[i] = static_cast<T>(thread_g[i] * thread_x[i]);
|
||||
}
|
||||
} else {
|
||||
for (int i = 0; i < N_READS; i++) {
|
||||
if ((lid * N_READS + i) < axis_size) {
|
||||
thread_x[i] = (thread_x[i] - mean) * normalizer;
|
||||
gx[i] = static_cast<T>(normalizer * (thread_w[i] * thread_g[i] - meanwg) -
|
||||
thread_x[i] * meanwgxc * normalizer2);
|
||||
gw[i] = static_cast<T>(thread_g[i] * thread_x[i]);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T, int N_READS = RMS_N_READS>
|
||||
[[kernel]] void vjp_layer_norm_looped(
|
||||
const device T* x,
|
||||
const device T* w,
|
||||
const device T* g,
|
||||
device T* gx,
|
||||
device T* gw,
|
||||
constant float& eps,
|
||||
constant uint& axis_size,
|
||||
constant uint& w_stride,
|
||||
uint gid [[threadgroup_position_in_grid]],
|
||||
uint lid [[thread_position_in_threadgroup]],
|
||||
uint lsize [[threads_per_threadgroup]],
|
||||
uint simd_lane_id [[thread_index_in_simdgroup]],
|
||||
uint simd_group_id [[simdgroup_index_in_threadgroup]]) {
|
||||
// Advance the input pointers
|
||||
x += gid * axis_size + lid * N_READS;
|
||||
g += gid * axis_size + lid * N_READS;
|
||||
w += w_stride * lid * N_READS;
|
||||
|
||||
// Allocate registers for the accumulators
|
||||
float sumx = 0;
|
||||
float sumx2 = 0;
|
||||
float sumwg = 0;
|
||||
float sumwgx = 0;
|
||||
|
||||
constexpr int SIMD_SIZE = 32;
|
||||
|
||||
threadgroup float local_sumx[SIMD_SIZE];
|
||||
threadgroup float local_sumx2[SIMD_SIZE];
|
||||
threadgroup float local_sumwg[SIMD_SIZE];
|
||||
threadgroup float local_sumwgx[SIMD_SIZE];
|
||||
threadgroup float local_mean[1];
|
||||
threadgroup float local_normalizer[1];
|
||||
threadgroup float local_meanwg[1];
|
||||
threadgroup float local_meanwgx[1];
|
||||
|
||||
for (uint r = 0; r < axis_size; r += lsize * N_READS) {
|
||||
if (r + lid * N_READS + N_READS <= axis_size) {
|
||||
for (int i = 0; i < N_READS; i++) {
|
||||
float xi = x[i + r];
|
||||
float wi = w[(i + r) * w_stride];
|
||||
float gi = g[i + r];
|
||||
float wg = wi * gi;
|
||||
sumx += xi;
|
||||
sumx2 += xi * xi;
|
||||
sumwg += wg;
|
||||
sumwgx += wg * xi;
|
||||
}
|
||||
} else {
|
||||
for (int i = 0; i < N_READS; i++) {
|
||||
if ((r + lid * N_READS + i) < axis_size) {
|
||||
float xi = x[i + r];
|
||||
float wi = w[(i + r) * w_stride];
|
||||
float gi = g[i + r];
|
||||
float wg = wi * gi;
|
||||
sumx += xi;
|
||||
sumx2 += xi * xi;
|
||||
sumwg += wg;
|
||||
sumwgx += wg * xi;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
sumx = simd_sum(sumx);
|
||||
sumx2 = simd_sum(sumx2);
|
||||
sumwg = simd_sum(sumwg);
|
||||
sumwgx = simd_sum(sumwgx);
|
||||
|
||||
// Initialize shared memory
|
||||
if (simd_group_id == 0) {
|
||||
local_sumx[simd_lane_id] = 0;
|
||||
local_sumx2[simd_lane_id] = 0;
|
||||
local_sumwg[simd_lane_id] = 0;
|
||||
local_sumwgx[simd_lane_id] = 0;
|
||||
}
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
// Write simd accumulations into shared memory
|
||||
if (simd_lane_id == 0) {
|
||||
local_sumx[simd_group_id] = sumx;
|
||||
local_sumx2[simd_group_id] = sumx2;
|
||||
local_sumwg[simd_group_id] = sumwg;
|
||||
local_sumwgx[simd_group_id] = sumwgx;
|
||||
}
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
// Accumulate over simd groups
|
||||
if (simd_group_id == 0) {
|
||||
sumx = simd_sum(local_sumx[simd_lane_id]);
|
||||
sumx2 = simd_sum(local_sumx2[simd_lane_id]);
|
||||
sumwg = simd_sum(local_sumwg[simd_lane_id]);
|
||||
sumwgx = simd_sum(local_sumwgx[simd_lane_id]);
|
||||
if (simd_lane_id == 0) {
|
||||
float mean = sumx / axis_size;
|
||||
float variance = sumx2 / axis_size - mean * mean;
|
||||
|
||||
local_mean[0] = mean;
|
||||
local_normalizer[0] = metal::precise::rsqrt(variance + eps);
|
||||
local_meanwg[0] = sumwg / axis_size;
|
||||
local_meanwgx[0] = sumwgx / axis_size;
|
||||
}
|
||||
}
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
float mean = local_mean[0];
|
||||
float normalizer = local_normalizer[0];
|
||||
float meanwg = local_meanwg[0];
|
||||
float meanwgxc = local_meanwgx[0] - meanwg * mean;
|
||||
float normalizer2 = normalizer * normalizer;
|
||||
|
||||
// Write the outputs
|
||||
gx += gid * axis_size + lid * N_READS;
|
||||
gw += gid * axis_size + lid * N_READS;
|
||||
for (uint r = 0; r < axis_size; r += lsize * N_READS) {
|
||||
if (r + lid * N_READS + N_READS <= axis_size) {
|
||||
for (int i = 0; i < N_READS; i++) {
|
||||
float xi = (x[i + r] - mean) * normalizer;
|
||||
float wi = w[(i + r) * w_stride];
|
||||
float gi = g[i + r];
|
||||
gx[i + r] = static_cast<T>(normalizer * (wi * gi - meanwg) -
|
||||
xi * meanwgxc * normalizer2);
|
||||
gw[i + r] = static_cast<T>(gi * xi);
|
||||
}
|
||||
} else {
|
||||
for (int i = 0; i < N_READS; i++) {
|
||||
if ((r + lid * N_READS + i) < axis_size) {
|
||||
float xi = (x[i + r] - mean) * normalizer;
|
||||
float wi = w[(i + r) * w_stride];
|
||||
float gi = g[i + r];
|
||||
gx[i + r] = static_cast<T>(normalizer * (wi * gi - meanwg) -
|
||||
xi * meanwgxc * normalizer2);
|
||||
gw[i + r] = static_cast<T>(gi * xi);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// clang-format off
|
||||
#define instantiate_layer_norm_single_row(name, itype) \
|
||||
template [[host_name("layer_norm" #name)]] [[kernel]] void \
|
||||
layer_norm_single_row<itype>( \
|
||||
const device itype* x, \
|
||||
const device itype* w, \
|
||||
const device itype* b, \
|
||||
device itype* out, \
|
||||
constant float& eps, \
|
||||
constant uint& axis_size, \
|
||||
constant uint& w_stride, \
|
||||
constant uint& b_stride, \
|
||||
uint gid [[thread_position_in_grid]], \
|
||||
uint lid [[thread_position_in_threadgroup]], \
|
||||
uint simd_lane_id [[thread_index_in_simdgroup]], \
|
||||
#define instantiate_layer_norm_single_row(name, itype) \
|
||||
template [[host_name("layer_norm" #name)]] [[kernel]] void \
|
||||
layer_norm_single_row<itype>( \
|
||||
const device itype* x, \
|
||||
const device itype* w, \
|
||||
const device itype* b, \
|
||||
device itype* out, \
|
||||
constant float& eps, \
|
||||
constant uint& axis_size, \
|
||||
constant uint& w_stride, \
|
||||
constant uint& b_stride, \
|
||||
uint gid [[thread_position_in_grid]], \
|
||||
uint lid [[thread_position_in_threadgroup]], \
|
||||
uint simd_lane_id [[thread_index_in_simdgroup]], \
|
||||
uint simd_group_id [[simdgroup_index_in_threadgroup]]); \
|
||||
template [[host_name("vjp_layer_norm" #name)]] [[kernel]] void \
|
||||
vjp_layer_norm_single_row<itype>( \
|
||||
const device itype* x, \
|
||||
const device itype* w, \
|
||||
const device itype* g, \
|
||||
device itype* gx, \
|
||||
device itype* gw, \
|
||||
constant float& eps, \
|
||||
constant uint& axis_size, \
|
||||
constant uint& w_stride, \
|
||||
uint gid [[thread_position_in_grid]], \
|
||||
uint lid [[thread_position_in_threadgroup]], \
|
||||
uint simd_lane_id [[thread_index_in_simdgroup]], \
|
||||
uint simd_group_id [[simdgroup_index_in_threadgroup]]);
|
||||
|
||||
#define instantiate_layer_norm_looped(name, itype) \
|
||||
template [[host_name("layer_norm_looped" #name)]] [[kernel]] void \
|
||||
layer_norm_looped<itype>( \
|
||||
const device itype* x, \
|
||||
const device itype* w, \
|
||||
const device itype* b, \
|
||||
device itype* out, \
|
||||
constant float& eps, \
|
||||
constant uint& axis_size, \
|
||||
constant uint& w_stride, \
|
||||
constant uint& b_stride, \
|
||||
uint gid [[thread_position_in_grid]], \
|
||||
uint lid [[thread_position_in_threadgroup]], \
|
||||
uint lsize [[threads_per_threadgroup]], \
|
||||
uint simd_lane_id [[thread_index_in_simdgroup]], \
|
||||
#define instantiate_layer_norm_looped(name, itype) \
|
||||
template [[host_name("layer_norm_looped" #name)]] [[kernel]] void \
|
||||
layer_norm_looped<itype>( \
|
||||
const device itype* x, \
|
||||
const device itype* w, \
|
||||
const device itype* b, \
|
||||
device itype* out, \
|
||||
constant float& eps, \
|
||||
constant uint& axis_size, \
|
||||
constant uint& w_stride, \
|
||||
constant uint& b_stride, \
|
||||
uint gid [[thread_position_in_grid]], \
|
||||
uint lid [[thread_position_in_threadgroup]], \
|
||||
uint lsize [[threads_per_threadgroup]], \
|
||||
uint simd_lane_id [[thread_index_in_simdgroup]], \
|
||||
uint simd_group_id [[simdgroup_index_in_threadgroup]]); \
|
||||
template [[host_name("vjp_layer_norm_looped" #name)]] [[kernel]] void \
|
||||
vjp_layer_norm_looped<itype>( \
|
||||
const device itype* x, \
|
||||
const device itype* w, \
|
||||
const device itype* g, \
|
||||
device itype* gx, \
|
||||
device itype* gb, \
|
||||
constant float& eps, \
|
||||
constant uint& axis_size, \
|
||||
constant uint& w_stride, \
|
||||
uint gid [[thread_position_in_grid]], \
|
||||
uint lid [[thread_position_in_threadgroup]], \
|
||||
uint lsize [[threads_per_threadgroup]], \
|
||||
uint simd_lane_id [[thread_index_in_simdgroup]], \
|
||||
uint simd_group_id [[simdgroup_index_in_threadgroup]]);
|
||||
|
||||
#define instantiate_layer_norm(name, itype) \
|
||||
|
@@ -150,6 +150,216 @@ template <typename T, int N_READS = RMS_N_READS>
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T, int N_READS = RMS_N_READS>
|
||||
[[kernel]] void vjp_rms_single_row(
|
||||
const device T* x,
|
||||
const device T* w,
|
||||
const device T* g,
|
||||
device T* gx,
|
||||
device T* gw,
|
||||
constant float& eps,
|
||||
constant uint& axis_size,
|
||||
constant uint& w_stride,
|
||||
uint gid [[threadgroup_position_in_grid]],
|
||||
uint lid [[thread_position_in_threadgroup]],
|
||||
uint simd_lane_id [[thread_index_in_simdgroup]],
|
||||
uint simd_group_id [[simdgroup_index_in_threadgroup]]) {
|
||||
// Advance the input pointers
|
||||
x += gid * axis_size + lid * N_READS;
|
||||
g += gid * axis_size + lid * N_READS;
|
||||
w += w_stride * lid * N_READS;
|
||||
|
||||
// Allocate registers for the computation and accumulators
|
||||
float thread_x[N_READS];
|
||||
float thread_w[N_READS];
|
||||
float thread_g[N_READS];
|
||||
float sumx2 = 0;
|
||||
float sumgwx = 0;
|
||||
|
||||
// Allocate shared memory to implement the reduction
|
||||
constexpr int SIMD_SIZE = 32;
|
||||
threadgroup float local_sumx2[SIMD_SIZE];
|
||||
threadgroup float local_sumgwx[SIMD_SIZE];
|
||||
threadgroup float local_normalizer[1];
|
||||
threadgroup float local_meangwx[1];
|
||||
|
||||
// Read and accumulate locally
|
||||
if (lid * N_READS + N_READS <= axis_size) {
|
||||
for (int i = 0; i < N_READS; i++) {
|
||||
thread_x[i] = x[i];
|
||||
thread_w[i] = w[w_stride * i];
|
||||
thread_g[i] = g[i];
|
||||
|
||||
sumx2 += thread_x[i] * thread_x[i];
|
||||
sumgwx += thread_x[i] * thread_w[i] * thread_g[i];
|
||||
}
|
||||
} else {
|
||||
for (int i = 0; i < N_READS; i++) {
|
||||
if ((lid * N_READS + i) < axis_size) {
|
||||
thread_x[i] = x[i];
|
||||
thread_w[i] = w[w_stride * i];
|
||||
thread_g[i] = g[i];
|
||||
|
||||
sumx2 += thread_x[i] * thread_x[i];
|
||||
sumgwx += thread_x[i] * thread_w[i] * thread_g[i];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Accumulate across threads
|
||||
sumx2 = simd_sum(sumx2);
|
||||
sumgwx = simd_sum(sumgwx);
|
||||
if (simd_group_id == 0) {
|
||||
local_sumx2[simd_lane_id] = 0;
|
||||
local_sumgwx[simd_lane_id] = 0;
|
||||
}
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
if (simd_lane_id == 0) {
|
||||
local_sumx2[simd_group_id] = sumx2;
|
||||
local_sumgwx[simd_group_id] = sumgwx;
|
||||
}
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
if (simd_group_id == 0) {
|
||||
sumx2 = simd_sum(local_sumx2[simd_lane_id]);
|
||||
sumgwx = simd_sum(local_sumgwx[simd_lane_id]);
|
||||
if (simd_lane_id == 0) {
|
||||
local_meangwx[0] = sumgwx / axis_size;
|
||||
local_normalizer[0] = metal::precise::rsqrt(sumx2 / axis_size + eps);
|
||||
}
|
||||
}
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
float meangwx = local_meangwx[0];
|
||||
float normalizer = local_normalizer[0];
|
||||
float normalizer3 = normalizer * normalizer * normalizer;
|
||||
|
||||
// Write the outputs
|
||||
gx += gid * axis_size + lid * N_READS;
|
||||
gw += gid * axis_size + lid * N_READS;
|
||||
if (lid * N_READS + N_READS <= axis_size) {
|
||||
for (int i = 0; i < N_READS; i++) {
|
||||
gx[i] = static_cast<T>(thread_g[i] * thread_w[i] * normalizer - thread_x[i] * meangwx * normalizer3);
|
||||
gw[i] = static_cast<T>(thread_g[i] * thread_x[i] * normalizer);
|
||||
}
|
||||
} else {
|
||||
for (int i = 0; i < N_READS; i++) {
|
||||
if ((lid * N_READS + i) < axis_size) {
|
||||
gx[i] = static_cast<T>(thread_g[i] * thread_w[i] * normalizer - thread_x[i] * meangwx * normalizer3);
|
||||
gw[i] = static_cast<T>(thread_g[i] * thread_x[i] * normalizer);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T, int N_READS = RMS_N_READS>
|
||||
[[kernel]] void vjp_rms_looped(
|
||||
const device T* x,
|
||||
const device T* w,
|
||||
const device T* g,
|
||||
device T* gx,
|
||||
device T* gw,
|
||||
constant float& eps,
|
||||
constant uint& axis_size,
|
||||
constant uint& w_stride,
|
||||
uint gid [[threadgroup_position_in_grid]],
|
||||
uint lid [[thread_position_in_threadgroup]],
|
||||
uint lsize [[threads_per_threadgroup]],
|
||||
uint simd_lane_id [[thread_index_in_simdgroup]],
|
||||
uint simd_group_id [[simdgroup_index_in_threadgroup]]) {
|
||||
// Advance the input pointers
|
||||
x += gid * axis_size + lid * N_READS;
|
||||
g += gid * axis_size + lid * N_READS;
|
||||
w += w_stride * lid * N_READS;
|
||||
|
||||
// Allocate registers for the accumulators
|
||||
float sumx2 = 0;
|
||||
float sumgwx = 0;
|
||||
|
||||
// Allocate shared memory to implement the reduction
|
||||
constexpr int SIMD_SIZE = 32;
|
||||
threadgroup float local_sumx2[SIMD_SIZE];
|
||||
threadgroup float local_sumgwx[SIMD_SIZE];
|
||||
threadgroup float local_normalizer[1];
|
||||
threadgroup float local_meangwx[1];
|
||||
|
||||
// Read and accumulate locally
|
||||
for (uint r = 0; r < axis_size; r += lsize * N_READS) {
|
||||
if (r + lid * N_READS + N_READS <= axis_size) {
|
||||
for (int i = 0; i < N_READS; i++) {
|
||||
float xi = x[i + r];
|
||||
float wi = w[w_stride * (i + r)];
|
||||
float gi = g[i + r];
|
||||
|
||||
sumx2 += xi * xi;
|
||||
sumgwx += xi * wi * gi;
|
||||
}
|
||||
} else {
|
||||
for (int i = 0; i < N_READS; i++) {
|
||||
if ((r + lid * N_READS + i) < axis_size) {
|
||||
float xi = x[i + r];
|
||||
float wi = w[w_stride * (i + r)];
|
||||
float gi = g[i + r];
|
||||
|
||||
sumx2 += xi * xi;
|
||||
sumgwx += xi * wi * gi;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Accumulate across threads
|
||||
sumx2 = simd_sum(sumx2);
|
||||
sumgwx = simd_sum(sumgwx);
|
||||
if (simd_group_id == 0) {
|
||||
local_sumx2[simd_lane_id] = 0;
|
||||
local_sumgwx[simd_lane_id] = 0;
|
||||
}
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
if (simd_lane_id == 0) {
|
||||
local_sumx2[simd_group_id] = sumx2;
|
||||
local_sumgwx[simd_group_id] = sumgwx;
|
||||
}
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
if (simd_group_id == 0) {
|
||||
sumx2 = simd_sum(local_sumx2[simd_lane_id]);
|
||||
sumgwx = simd_sum(local_sumgwx[simd_lane_id]);
|
||||
if (simd_lane_id == 0) {
|
||||
local_meangwx[0] = sumgwx / axis_size;
|
||||
local_normalizer[0] = metal::precise::rsqrt(sumx2 / axis_size + eps);
|
||||
}
|
||||
}
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
float meangwx = local_meangwx[0];
|
||||
float normalizer = local_normalizer[0];
|
||||
float normalizer3 = normalizer * normalizer * normalizer;
|
||||
|
||||
// Write the outputs
|
||||
gx += gid * axis_size + lid * N_READS;
|
||||
gw += gid * axis_size + lid * N_READS;
|
||||
for (uint r = 0; r < axis_size; r += lsize * N_READS) {
|
||||
if (r + lid * N_READS + N_READS <= axis_size) {
|
||||
for (int i = 0; i < N_READS; i++) {
|
||||
float xi = x[i + r];
|
||||
float wi = w[w_stride * (i + r)];
|
||||
float gi = g[i + r];
|
||||
|
||||
gx[i + r] = static_cast<T>(gi * wi * normalizer - xi * meangwx * normalizer3);
|
||||
gw[i + r] = static_cast<T>(gi * xi * normalizer);
|
||||
}
|
||||
} else {
|
||||
for (int i = 0; i < N_READS; i++) {
|
||||
if ((r + lid * N_READS + i) < axis_size) {
|
||||
float xi = x[i + r];
|
||||
float wi = w[w_stride * (i + r)];
|
||||
float gi = g[i + r];
|
||||
|
||||
gx[i + r] = static_cast<T>(gi * wi * normalizer - xi * meangwx * normalizer3);
|
||||
gw[i + r] = static_cast<T>(gi * xi * normalizer);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// clang-format off
|
||||
#define instantiate_rms_single_row(name, itype) \
|
||||
template [[host_name("rms" #name)]] [[kernel]] void \
|
||||
@@ -165,25 +375,56 @@ template <typename T, int N_READS = RMS_N_READS>
|
||||
uint gid [[thread_position_in_grid]], \
|
||||
uint lid [[thread_position_in_threadgroup]], \
|
||||
uint simd_lane_id [[thread_index_in_simdgroup]], \
|
||||
uint simd_group_id [[simdgroup_index_in_threadgroup]]);
|
||||
|
||||
#define instantiate_rms_looped(name, itype) \
|
||||
template [[host_name("rms_looped" #name)]] [[kernel]] void \
|
||||
rms_looped<itype>( \
|
||||
uint simd_group_id [[simdgroup_index_in_threadgroup]]); \
|
||||
\
|
||||
template [[host_name("vjp_rms" #name)]] [[kernel]] void \
|
||||
vjp_rms_single_row<itype>( \
|
||||
const device itype* x, \
|
||||
const device itype* w, \
|
||||
device itype* out, \
|
||||
const device itype* g, \
|
||||
device itype* gx, \
|
||||
device itype* gw, \
|
||||
constant float& eps, \
|
||||
constant uint& axis_size, \
|
||||
constant uint& w_stride, \
|
||||
threadgroup float* local_inv_mean [[threadgroup(0)]], \
|
||||
threadgroup float* local_sums [[threadgroup(1)]], \
|
||||
uint gid [[thread_position_in_grid]], \
|
||||
uint lid [[thread_position_in_threadgroup]], \
|
||||
uint lsize [[threads_per_threadgroup]], \
|
||||
uint simd_lane_id [[thread_index_in_simdgroup]], \
|
||||
uint simd_group_id [[simdgroup_index_in_threadgroup]]);
|
||||
|
||||
#define instantiate_rms_looped(name, itype) \
|
||||
template [[host_name("rms_looped" #name)]] [[kernel]] void \
|
||||
rms_looped<itype>( \
|
||||
const device itype* x, \
|
||||
const device itype* w, \
|
||||
device itype* out, \
|
||||
constant float& eps, \
|
||||
constant uint& axis_size, \
|
||||
constant uint& w_stride, \
|
||||
threadgroup float* local_inv_mean [[threadgroup(0)]], \
|
||||
threadgroup float* local_sums [[threadgroup(1)]], \
|
||||
uint gid [[thread_position_in_grid]], \
|
||||
uint lid [[thread_position_in_threadgroup]], \
|
||||
uint lsize [[threads_per_threadgroup]], \
|
||||
uint simd_lane_id [[thread_index_in_simdgroup]], \
|
||||
uint simd_group_id [[simdgroup_index_in_threadgroup]]); \
|
||||
\
|
||||
template [[host_name("vjp_rms_looped" #name)]] [[kernel]] void \
|
||||
vjp_rms_looped<itype>( \
|
||||
const device itype* x, \
|
||||
const device itype* w, \
|
||||
const device itype* g, \
|
||||
device itype* gx, \
|
||||
device itype* gw, \
|
||||
constant float& eps, \
|
||||
constant uint& axis_size, \
|
||||
constant uint& w_stride, \
|
||||
uint gid [[thread_position_in_grid]], \
|
||||
uint lid [[thread_position_in_threadgroup]], \
|
||||
uint lsize [[threads_per_threadgroup]], \
|
||||
uint simd_lane_id [[thread_index_in_simdgroup]], \
|
||||
uint simd_group_id [[simdgroup_index_in_threadgroup]]);
|
||||
|
||||
#define instantiate_rms(name, itype) \
|
||||
instantiate_rms_single_row(name, itype) \
|
||||
instantiate_rms_looped(name, itype)
|
||||
|
@@ -5,7 +5,7 @@
|
||||
#include "mlx/backend/metal/kernels/bf16.h"
|
||||
#include "mlx/backend/metal/kernels/utils.h"
|
||||
|
||||
template <typename T, bool traditional>
|
||||
template <typename T, bool traditional, bool forward>
|
||||
[[kernel]] void rope(
|
||||
const device T *in [[buffer(0)]],
|
||||
device T * out [[buffer(1)]],
|
||||
@@ -43,15 +43,22 @@ template <typename T, bool traditional>
|
||||
// Read and write the output
|
||||
float x1 = static_cast<float>(in[in_index_1]);
|
||||
float x2 = static_cast<float>(in[in_index_2]);
|
||||
float rx1 = x1 * costheta - x2 * sintheta;
|
||||
float rx2 = x1 * sintheta + x2 * costheta;
|
||||
float rx1;
|
||||
float rx2;
|
||||
if (forward) {
|
||||
rx1 = x1 * costheta - x2 * sintheta;
|
||||
rx2 = x1 * sintheta + x2 * costheta;
|
||||
} else {
|
||||
rx1 = x2 * sintheta + x1 * costheta;
|
||||
rx2 = x2 * costheta - x1 * sintheta;
|
||||
}
|
||||
out[out_index_1] = static_cast<T>(rx1);
|
||||
out[out_index_2] = static_cast<T>(rx2);
|
||||
}
|
||||
|
||||
#define instantiate_rope(name, type, traditional) \
|
||||
#define instantiate_rope(name, type, traditional, forward) \
|
||||
template [[host_name("rope_" #name)]] \
|
||||
[[kernel]] void rope<type, traditional>( \
|
||||
[[kernel]] void rope<type, traditional, forward>( \
|
||||
const device type* in [[buffer(0)]], \
|
||||
device type* out [[buffer(1)]], \
|
||||
constant const size_t strides[3], \
|
||||
@@ -62,9 +69,15 @@ template <typename T, bool traditional>
|
||||
uint3 pos [[thread_position_in_grid]], \
|
||||
uint3 grid [[threads_per_grid]]);
|
||||
|
||||
instantiate_rope(traditional_float16, half, true)
|
||||
instantiate_rope(traditional_bfloat16, bfloat16_t, true)
|
||||
instantiate_rope(traditional_float32, float, true)
|
||||
instantiate_rope(float16, half, false)
|
||||
instantiate_rope(bfloat16, bfloat16_t, false)
|
||||
instantiate_rope(float32, float, false)
|
||||
instantiate_rope(traditional_float16, half, true, true)
|
||||
instantiate_rope(traditional_bfloat16, bfloat16_t, true, true)
|
||||
instantiate_rope(traditional_float32, float, true, true)
|
||||
instantiate_rope(float16, half, false, true)
|
||||
instantiate_rope(bfloat16, bfloat16_t, false, true)
|
||||
instantiate_rope(float32, float, false, true)
|
||||
instantiate_rope(vjp_traditional_float16, half, true, false)
|
||||
instantiate_rope(vjp_traditional_bfloat16, bfloat16_t, true, false)
|
||||
instantiate_rope(vjp_traditional_float32, float, true, false)
|
||||
instantiate_rope(vjp_float16, half, false, false)
|
||||
instantiate_rope(vjp_bfloat16, bfloat16_t, false, false)
|
||||
instantiate_rope(vjp_float32, float, false, false)
|
||||
|
@@ -4,6 +4,7 @@
|
||||
#include "mlx/backend/metal/copy.h"
|
||||
#include "mlx/backend/metal/device.h"
|
||||
#include "mlx/backend/metal/kernels/defines.h"
|
||||
#include "mlx/backend/metal/reduce.h"
|
||||
#include "mlx/backend/metal/utils.h"
|
||||
#include "mlx/fast_primitives.h"
|
||||
|
||||
@@ -95,6 +96,113 @@ void RMSNorm::eval_gpu(
|
||||
[copies](MTL::CommandBuffer*) mutable { copies.clear(); });
|
||||
}
|
||||
|
||||
void RMSNormVJP::eval_gpu(
|
||||
const std::vector<array>& inputs,
|
||||
std::vector<array>& outputs) {
|
||||
auto& s = stream();
|
||||
auto& d = metal::device(s.device);
|
||||
|
||||
// Ensure row contiguity. We could relax this step by checking that the array
|
||||
// is contiguous (no broadcasts or holes) and that the input strides are the
|
||||
// same as the cotangent strides but for now this is simpler.
|
||||
std::vector<array> copies;
|
||||
auto check_input = [&copies, &s](const array& x) {
|
||||
if (x.flags().row_contiguous) {
|
||||
return x;
|
||||
}
|
||||
|
||||
array x_copy(x.shape(), x.dtype(), nullptr, {});
|
||||
copy_gpu(x, x_copy, CopyType::General, s);
|
||||
copies.push_back(x_copy);
|
||||
return x_copy;
|
||||
};
|
||||
const array& x = check_input(inputs[0]);
|
||||
const array& w = inputs[1];
|
||||
const array& g = check_input(inputs[2]);
|
||||
array& gx = outputs[0];
|
||||
array& gw = outputs[1];
|
||||
|
||||
// Allocate space for the outputs
|
||||
bool x_in_gx = false;
|
||||
bool g_in_gx = false;
|
||||
if (x.is_donatable()) {
|
||||
gx.move_shared_buffer(x);
|
||||
x_in_gx = true;
|
||||
} else if (g.is_donatable()) {
|
||||
gx.move_shared_buffer(g);
|
||||
g_in_gx = true;
|
||||
} else {
|
||||
gx.set_data(allocator::malloc_or_wait(gx.nbytes()));
|
||||
}
|
||||
|
||||
auto axis_size = static_cast<uint32_t>(x.shape().back());
|
||||
int n_rows = x.data_size() / axis_size;
|
||||
|
||||
// Allocate a temporary to store the gradients for w and initialize the
|
||||
// gradient accumulator to 0.
|
||||
array gw_temp({n_rows, x.shape().back()}, gw.dtype(), nullptr, {});
|
||||
bool g_in_gw = false;
|
||||
if (!g_in_gx && g.is_donatable()) {
|
||||
gw_temp.move_shared_buffer(g);
|
||||
g_in_gw = true;
|
||||
} else {
|
||||
gw_temp.set_data(allocator::malloc_or_wait(gw_temp.nbytes()));
|
||||
}
|
||||
copies.push_back(gw_temp);
|
||||
array zero(0, gw.dtype());
|
||||
copy_gpu(zero, gw, CopyType::Scalar, s);
|
||||
|
||||
const int simd_size = 32;
|
||||
const int n_reads = RMS_N_READS;
|
||||
const int looped_limit = RMS_LOOPED_LIMIT;
|
||||
std::string op_name = "vjp_rms";
|
||||
if (axis_size > looped_limit) {
|
||||
op_name += "_looped";
|
||||
}
|
||||
op_name += type_to_name(gx);
|
||||
auto compute_encoder = d.get_command_encoder(s.index);
|
||||
{
|
||||
auto kernel = d.get_kernel(op_name);
|
||||
|
||||
MTL::Size grid_dims, group_dims;
|
||||
if (axis_size <= looped_limit) {
|
||||
size_t threadgroup_needed = (axis_size + n_reads - 1) / n_reads;
|
||||
size_t simds_needed = (threadgroup_needed + simd_size - 1) / simd_size;
|
||||
size_t threadgroup_size = simd_size * simds_needed;
|
||||
assert(threadgroup_size <= kernel->maxTotalThreadsPerThreadgroup());
|
||||
size_t n_threads = n_rows * threadgroup_size;
|
||||
grid_dims = MTL::Size(n_threads, 1, 1);
|
||||
group_dims = MTL::Size(threadgroup_size, 1, 1);
|
||||
} else {
|
||||
size_t threadgroup_size = kernel->maxTotalThreadsPerThreadgroup();
|
||||
size_t n_threads = n_rows * threadgroup_size;
|
||||
grid_dims = MTL::Size(n_threads, 1, 1);
|
||||
group_dims = MTL::Size(threadgroup_size, 1, 1);
|
||||
}
|
||||
|
||||
uint32_t w_stride = w.strides()[0];
|
||||
compute_encoder->setComputePipelineState(kernel);
|
||||
set_array_buffer(compute_encoder, x_in_gx ? gx : x, 0);
|
||||
set_array_buffer(compute_encoder, w, 1);
|
||||
set_array_buffer(
|
||||
compute_encoder, g_in_gx ? gx : (g_in_gw ? gw_temp : g), 2);
|
||||
set_array_buffer(compute_encoder, gx, 3);
|
||||
set_array_buffer(compute_encoder, gw_temp, 4);
|
||||
compute_encoder->setBytes(&eps_, sizeof(float), 5);
|
||||
compute_encoder->setBytes(&axis_size, sizeof(int), 6);
|
||||
compute_encoder->setBytes(&w_stride, sizeof(uint32_t), 7);
|
||||
compute_encoder->dispatchThreads(grid_dims, group_dims);
|
||||
}
|
||||
|
||||
ReductionPlan plan(
|
||||
ReductionOpType::ContiguousStridedReduce, {n_rows}, {axis_size});
|
||||
strided_reduce_general_dispatch(
|
||||
gw_temp, gw, "sum", plan, {0}, compute_encoder, d, s);
|
||||
|
||||
d.get_command_buffer(s.index)->addCompletedHandler(
|
||||
[copies](MTL::CommandBuffer*) mutable { copies.clear(); });
|
||||
}
|
||||
|
||||
void LayerNorm::eval_gpu(
|
||||
const std::vector<array>& inputs,
|
||||
std::vector<array>& outputs) {
|
||||
@@ -182,4 +290,124 @@ void LayerNorm::eval_gpu(
|
||||
[copies](MTL::CommandBuffer*) mutable { copies.clear(); });
|
||||
}
|
||||
|
||||
void LayerNormVJP::eval_gpu(
|
||||
const std::vector<array>& inputs,
|
||||
std::vector<array>& outputs) {
|
||||
auto& s = stream();
|
||||
auto& d = metal::device(s.device);
|
||||
|
||||
// Ensure row contiguity. We could relax this step by checking that the array
|
||||
// is contiguous (no broadcasts or holes) and that the input strides are the
|
||||
// same as the cotangent strides but for now this is simpler.
|
||||
std::vector<array> copies;
|
||||
auto check_input = [&copies, &s](const array& x) {
|
||||
if (x.flags().row_contiguous) {
|
||||
return x;
|
||||
}
|
||||
|
||||
array x_copy(x.shape(), x.dtype(), nullptr, {});
|
||||
copy_gpu(x, x_copy, CopyType::General, s);
|
||||
copies.push_back(x_copy);
|
||||
return x_copy;
|
||||
};
|
||||
const array& x = check_input(inputs[0]);
|
||||
const array& w = inputs[1];
|
||||
const array& b = inputs[2];
|
||||
const array& g = check_input(inputs[3]);
|
||||
array& gx = outputs[0];
|
||||
array& gw = outputs[1];
|
||||
array& gb = outputs[2];
|
||||
|
||||
// Allocate space for the outputs
|
||||
bool x_in_gx = false;
|
||||
bool g_in_gx = false;
|
||||
if (x.is_donatable()) {
|
||||
gx.move_shared_buffer(x);
|
||||
x_in_gx = true;
|
||||
} else if (g.is_donatable()) {
|
||||
gx.move_shared_buffer(g);
|
||||
g_in_gx = true;
|
||||
} else {
|
||||
gx.set_data(allocator::malloc_or_wait(gx.nbytes()));
|
||||
}
|
||||
|
||||
auto axis_size = static_cast<uint32_t>(x.shape().back());
|
||||
int n_rows = x.data_size() / axis_size;
|
||||
|
||||
// Allocate a temporary to store the gradients for w and initialize the
|
||||
// gradient accumulator to 0.
|
||||
array gw_temp({n_rows, x.shape().back()}, gw.dtype(), nullptr, {});
|
||||
bool g_in_gw = false;
|
||||
if (!g_in_gx && g.is_donatable()) {
|
||||
gw_temp.move_shared_buffer(g);
|
||||
g_in_gw = true;
|
||||
} else {
|
||||
gw_temp.set_data(allocator::malloc_or_wait(gw_temp.nbytes()));
|
||||
}
|
||||
copies.push_back(gw_temp);
|
||||
array zero(0, gw.dtype());
|
||||
copy_gpu(zero, gw, CopyType::Scalar, s);
|
||||
copy_gpu(zero, gb, CopyType::Scalar, s);
|
||||
|
||||
// Finish with the gradient for b in case we had a b
|
||||
auto compute_encoder = d.get_command_encoder(s.index);
|
||||
if (gb.ndim() == 1 && gb.size() == axis_size) {
|
||||
ReductionPlan plan(
|
||||
ReductionOpType::ContiguousStridedReduce, {n_rows}, {axis_size});
|
||||
strided_reduce_general_dispatch(
|
||||
g, gb, "sum", plan, {0}, compute_encoder, d, s);
|
||||
}
|
||||
|
||||
const int simd_size = 32;
|
||||
const int n_reads = RMS_N_READS;
|
||||
const int looped_limit = RMS_LOOPED_LIMIT;
|
||||
std::string op_name = "vjp_layer_norm";
|
||||
if (axis_size > looped_limit) {
|
||||
op_name += "_looped";
|
||||
}
|
||||
op_name += type_to_name(gx);
|
||||
{
|
||||
auto kernel = d.get_kernel(op_name);
|
||||
|
||||
MTL::Size grid_dims, group_dims;
|
||||
if (axis_size <= looped_limit) {
|
||||
size_t threadgroup_needed = (axis_size + n_reads - 1) / n_reads;
|
||||
size_t simds_needed = (threadgroup_needed + simd_size - 1) / simd_size;
|
||||
size_t threadgroup_size = simd_size * simds_needed;
|
||||
assert(threadgroup_size <= kernel->maxTotalThreadsPerThreadgroup());
|
||||
size_t n_threads = n_rows * threadgroup_size;
|
||||
grid_dims = MTL::Size(n_threads, 1, 1);
|
||||
group_dims = MTL::Size(threadgroup_size, 1, 1);
|
||||
} else {
|
||||
size_t threadgroup_size = kernel->maxTotalThreadsPerThreadgroup();
|
||||
size_t n_threads = n_rows * threadgroup_size;
|
||||
grid_dims = MTL::Size(n_threads, 1, 1);
|
||||
group_dims = MTL::Size(threadgroup_size, 1, 1);
|
||||
}
|
||||
|
||||
uint32_t w_stride = w.strides()[0];
|
||||
compute_encoder->setComputePipelineState(kernel);
|
||||
set_array_buffer(compute_encoder, x_in_gx ? gx : x, 0);
|
||||
set_array_buffer(compute_encoder, w, 1);
|
||||
set_array_buffer(
|
||||
compute_encoder, g_in_gx ? gx : (g_in_gw ? gw_temp : g), 2);
|
||||
set_array_buffer(compute_encoder, gx, 3);
|
||||
set_array_buffer(compute_encoder, gw_temp, 4);
|
||||
compute_encoder->setBytes(&eps_, sizeof(float), 5);
|
||||
compute_encoder->setBytes(&axis_size, sizeof(int), 6);
|
||||
compute_encoder->setBytes(&w_stride, sizeof(uint32_t), 7);
|
||||
compute_encoder->dispatchThreads(grid_dims, group_dims);
|
||||
}
|
||||
|
||||
if (gw.ndim() == 1 && gw.size() == axis_size) {
|
||||
ReductionPlan plan(
|
||||
ReductionOpType::ContiguousStridedReduce, {n_rows}, {axis_size});
|
||||
strided_reduce_general_dispatch(
|
||||
gw_temp, gw, "sum", plan, {0}, compute_encoder, d, s);
|
||||
}
|
||||
|
||||
d.get_command_buffer(s.index)->addCompletedHandler(
|
||||
[copies](MTL::CommandBuffer*) mutable { copies.clear(); });
|
||||
}
|
||||
|
||||
} // namespace mlx::core::fast
|
||||
|
@@ -4,10 +4,10 @@
|
||||
#include <cassert>
|
||||
#include <sstream>
|
||||
|
||||
#include "mlx/backend/common/reduce.h"
|
||||
#include "mlx/backend/metal/copy.h"
|
||||
#include "mlx/backend/metal/device.h"
|
||||
#include "mlx/backend/metal/kernels/defines.h"
|
||||
#include "mlx/backend/metal/reduce.h"
|
||||
#include "mlx/backend/metal/utils.h"
|
||||
#include "mlx/primitives.h"
|
||||
#include "mlx/utils.h"
|
||||
@@ -18,8 +18,6 @@ namespace mlx::core {
|
||||
// Case wise reduce dispatch
|
||||
//////////////////////////////////////////////////////////////////////
|
||||
|
||||
namespace {
|
||||
|
||||
inline auto safe_div(size_t n, size_t m) {
|
||||
return m == 0 ? 0 : (n + m - 1) / m;
|
||||
}
|
||||
@@ -534,8 +532,6 @@ void strided_reduce_general_dispatch(
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
//////////////////////////////////////////////////////////////////////
|
||||
// Main reduce dispatch
|
||||
//////////////////////////////////////////////////////////////////////
|
||||
|
39
mlx/backend/metal/reduce.h
Normal file
39
mlx/backend/metal/reduce.h
Normal file
@@ -0,0 +1,39 @@
|
||||
// Copyright @ 2023 - 2024 Apple Inc.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "mlx/backend/common/reduce.h"
|
||||
#include "mlx/backend/metal/device.h"
|
||||
#include "mlx/stream.h"
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
void all_reduce_dispatch(
|
||||
const array& in,
|
||||
array& out,
|
||||
const std::string& op_name,
|
||||
MTL::ComputeCommandEncoder* compute_encoder,
|
||||
metal::Device& d,
|
||||
const Stream& s);
|
||||
|
||||
void row_reduce_general_dispatch(
|
||||
const array& in,
|
||||
array& out,
|
||||
const std::string& op_name,
|
||||
const ReductionPlan& plan,
|
||||
const std::vector<int>& axes,
|
||||
MTL::ComputeCommandEncoder* compute_encoder,
|
||||
metal::Device& d,
|
||||
const Stream& s);
|
||||
|
||||
void strided_reduce_general_dispatch(
|
||||
const array& in,
|
||||
array& out,
|
||||
const std::string& op_name,
|
||||
const ReductionPlan& plan,
|
||||
const std::vector<int>& axes,
|
||||
MTL::ComputeCommandEncoder* compute_encoder,
|
||||
metal::Device& d,
|
||||
const Stream& s);
|
||||
|
||||
} // namespace mlx::core
|
@@ -63,7 +63,8 @@ void RoPE::eval_gpu(
|
||||
out_strides[2] = out.strides()[ndim - 1];
|
||||
|
||||
std::ostringstream kname;
|
||||
kname << "rope_" << (traditional_ ? "traditional_" : "") << type_to_name(in);
|
||||
kname << "rope_" << (forward_ ? "" : "vjp_")
|
||||
<< (traditional_ ? "traditional_" : "") << type_to_name(in);
|
||||
auto kernel = d.get_kernel(kname.str());
|
||||
auto compute_encoder = d.get_command_encoder(s.index);
|
||||
|
||||
|
@@ -103,7 +103,9 @@ NO_GPU(Inverse)
|
||||
|
||||
namespace fast {
|
||||
NO_GPU_MULTI(LayerNorm)
|
||||
NO_GPU_MULTI(LayerNormVJP)
|
||||
NO_GPU_MULTI(RMSNorm)
|
||||
NO_GPU_MULTI(RMSNormVJP)
|
||||
NO_GPU_MULTI(RoPE)
|
||||
NO_GPU(ScaledDotProductAttention)
|
||||
} // namespace fast
|
||||
|
228
mlx/fast.cpp
228
mlx/fast.cpp
@@ -1,5 +1,8 @@
|
||||
// Copyright © 2023-2024 Apple Inc.
|
||||
|
||||
#include <cassert>
|
||||
#include <numeric>
|
||||
|
||||
#include "mlx/fast.h"
|
||||
#include "mlx/fast_primitives.h"
|
||||
#include "mlx/ops.h"
|
||||
@@ -94,11 +97,69 @@ array rms_norm(
|
||||
return fallback({x, weight})[0];
|
||||
}
|
||||
|
||||
std::vector<array> RMSNorm::vjp(
|
||||
const std::vector<array>& primals,
|
||||
const std::vector<array>& cotangents,
|
||||
const std::vector<int>& argnums,
|
||||
const std::vector<array>& outputs) {
|
||||
assert(primals.size() == 2);
|
||||
assert(outputs.size() == 1);
|
||||
assert(cotangents.size() == 1);
|
||||
|
||||
auto s = stream();
|
||||
auto fallback = [eps = eps_, s](const std::vector<array>& inputs) {
|
||||
auto& x = inputs[0];
|
||||
auto& w = inputs[1];
|
||||
auto& g = inputs[2];
|
||||
|
||||
std::vector<array> vjps;
|
||||
|
||||
auto n = rsqrt(
|
||||
add(mean(square(x, s), /* axis= */ -1, /* keepdims= */ true, s),
|
||||
array(eps, x.dtype()),
|
||||
s),
|
||||
s);
|
||||
auto n3 = power(n, array(3, x.dtype()), s);
|
||||
|
||||
// df/dx
|
||||
auto gw = multiply(g, w, s);
|
||||
auto t = mean(multiply(gw, x, s), /* axis= */ -1, /* keepdims= */ true, s);
|
||||
t = multiply(multiply(x, t, s), n3, s);
|
||||
vjps.push_back(subtract(multiply(gw, n, s), t, s));
|
||||
|
||||
// df/dw
|
||||
std::vector<int> axes(g.ndim() - 1);
|
||||
std::iota(axes.begin(), axes.end(), 0);
|
||||
vjps.push_back(
|
||||
sum(multiply(g, multiply(x, n, s), s), axes, /* keepdims= */ false, s));
|
||||
|
||||
return vjps;
|
||||
};
|
||||
|
||||
auto vjps = array::make_arrays(
|
||||
{primals[0].shape(), primals[1].shape()},
|
||||
{primals[0].dtype(), primals[1].dtype()},
|
||||
std::make_shared<RMSNormVJP>(s, fallback, eps_),
|
||||
{primals[0], primals[1], cotangents[0]});
|
||||
|
||||
std::vector<array> returned_vjps;
|
||||
for (auto& arg : argnums) {
|
||||
returned_vjps.push_back(std::move(vjps[arg]));
|
||||
}
|
||||
|
||||
return returned_vjps;
|
||||
}
|
||||
|
||||
bool RMSNorm::is_equivalent(const Primitive& other) const {
|
||||
const RMSNorm& a_other = static_cast<const RMSNorm&>(other);
|
||||
return eps_ == a_other.eps_;
|
||||
}
|
||||
|
||||
bool RMSNormVJP::is_equivalent(const Primitive& other) const {
|
||||
const RMSNormVJP& a_other = static_cast<const RMSNormVJP&>(other);
|
||||
return eps_ == a_other.eps_;
|
||||
}
|
||||
|
||||
array layer_norm(
|
||||
const array& x,
|
||||
const std::optional<array>& weight,
|
||||
@@ -176,11 +237,90 @@ array layer_norm(
|
||||
return fallback({x, passed_weight, passed_bias})[0];
|
||||
}
|
||||
|
||||
std::vector<array> LayerNorm::vjp(
|
||||
const std::vector<array>& primals,
|
||||
const std::vector<array>& cotangents,
|
||||
const std::vector<int>& argnums,
|
||||
const std::vector<array>& outputs) {
|
||||
assert(primals.size() == 3);
|
||||
assert(outputs.size() == 1);
|
||||
assert(cotangents.size() == 1);
|
||||
|
||||
auto s = stream();
|
||||
auto fallback = [eps = eps_, s](const std::vector<array>& inputs) {
|
||||
auto& x = inputs[0];
|
||||
auto& w = inputs[1];
|
||||
auto& b = inputs[2];
|
||||
auto& g = inputs[3];
|
||||
|
||||
std::vector<array> vjps;
|
||||
|
||||
auto norm = number_of_elements(x, {-1}, true, x.dtype(), s);
|
||||
auto sumx = sum(x, /* axis= */ -1, /* keepdims= */ true, s);
|
||||
auto sumx2 = sum(square(x, s), /* axis= */ -1, /* keepdims= */ true, s);
|
||||
auto mu = multiply(sumx, norm, s);
|
||||
auto mu2 = multiply(sumx2, norm, s);
|
||||
auto var = subtract(mu2, square(mu, s), s);
|
||||
auto n = rsqrt(add(var, array(eps, x.dtype()), s));
|
||||
auto n3 = power(n, array(3, x.dtype()), s);
|
||||
auto x_c = subtract(x, mu, s);
|
||||
|
||||
// df/dx
|
||||
auto wg = multiply(w, g, s);
|
||||
auto sumwg =
|
||||
multiply(sum(wg, /* axis= */ -1, /* keepdims= */ true, s), norm, s);
|
||||
auto sumwgxc = multiply(
|
||||
sum(multiply(wg, x_c, s), /* axis= */ -1, /* keepdims= */ true, s),
|
||||
norm,
|
||||
s);
|
||||
auto t1 = multiply(multiply(x_c, sumwgxc, s), n3, s);
|
||||
auto t2 = multiply(subtract(wg, sumwg, s), n, s);
|
||||
vjps.push_back(subtract(t2, t1, s));
|
||||
|
||||
// df/dw
|
||||
std::vector<int> axes(g.ndim() - 1);
|
||||
std::iota(axes.begin(), axes.end(), 0);
|
||||
if (w.ndim() == 0) {
|
||||
vjps.push_back(zeros_like(w, s));
|
||||
} else {
|
||||
vjps.push_back(sum(
|
||||
multiply(g, multiply(x_c, n, s), s), axes, /* keepdims= */ false, s));
|
||||
}
|
||||
|
||||
// df/db
|
||||
if (b.ndim() == 0) {
|
||||
vjps.push_back(zeros_like(w, s));
|
||||
} else {
|
||||
vjps.push_back(sum(g, axes, /* keepdims= */ false, s));
|
||||
}
|
||||
|
||||
return vjps;
|
||||
};
|
||||
|
||||
auto vjps = array::make_arrays(
|
||||
{primals[0].shape(), primals[1].shape(), primals[2].shape()},
|
||||
{primals[0].dtype(), primals[1].dtype(), primals[2].dtype()},
|
||||
std::make_shared<LayerNormVJP>(s, fallback, eps_),
|
||||
{primals[0], primals[1], primals[2], cotangents[0]});
|
||||
|
||||
std::vector<array> returned_vjps;
|
||||
for (auto& arg : argnums) {
|
||||
returned_vjps.push_back(std::move(vjps[arg]));
|
||||
}
|
||||
|
||||
return returned_vjps;
|
||||
}
|
||||
|
||||
bool LayerNorm::is_equivalent(const Primitive& other) const {
|
||||
const LayerNorm& a_other = static_cast<const LayerNorm&>(other);
|
||||
return eps_ == a_other.eps_;
|
||||
}
|
||||
|
||||
bool LayerNormVJP::is_equivalent(const Primitive& other) const {
|
||||
const LayerNormVJP& a_other = static_cast<const LayerNormVJP&>(other);
|
||||
return eps_ == a_other.eps_;
|
||||
}
|
||||
|
||||
array rope(
|
||||
const array& x,
|
||||
int dims,
|
||||
@@ -188,19 +328,16 @@ array rope(
|
||||
float base,
|
||||
float scale,
|
||||
int offset,
|
||||
StreamOrDevice s /* = {} */) {
|
||||
bool forward,
|
||||
StreamOrDevice s) {
|
||||
if (x.ndim() < 3) {
|
||||
std::ostringstream msg;
|
||||
msg << "[rope] Input must have at least 3 dimensions but got input with "
|
||||
<< x.ndim() << " dimensions.";
|
||||
throw std::invalid_argument(msg.str());
|
||||
}
|
||||
if (traditional && x.shape(-1) != dims) {
|
||||
throw std::invalid_argument(
|
||||
"[rope] Does not support partial traditional application.");
|
||||
}
|
||||
|
||||
auto fallback = [dims, traditional, base, scale, offset, s](
|
||||
auto fallback = [dims, traditional, base, scale, offset, forward, s](
|
||||
const std::vector<array>& inputs) {
|
||||
auto& shape = inputs[0].shape();
|
||||
int ndim = shape.size();
|
||||
@@ -217,16 +354,39 @@ array rope(
|
||||
auto coss = cos(theta, s);
|
||||
auto sins = sin(theta, s);
|
||||
|
||||
if (traditional) {
|
||||
auto x1 = slice(x, {0, 0, 0}, x.shape(), {1, 1, 2}, s);
|
||||
auto x2 = slice(x, {0, 0, 1}, x.shape(), {1, 1, 2}, s);
|
||||
auto apply_rope = [forward, s](
|
||||
const array& x1,
|
||||
const array& x2,
|
||||
const array& coss,
|
||||
const array& sins) {
|
||||
std::vector<array> outs;
|
||||
outs.push_back(subtract(multiply(x1, coss, s), multiply(x2, sins, s), s));
|
||||
outs.push_back(add(multiply(x1, sins, s), multiply(x2, coss, s), s));
|
||||
if (forward) {
|
||||
outs.push_back(
|
||||
subtract(multiply(x1, coss, s), multiply(x2, sins, s), s));
|
||||
outs.push_back(add(multiply(x1, sins, s), multiply(x2, coss, s), s));
|
||||
} else {
|
||||
outs.push_back(add(multiply(x2, sins, s), multiply(x1, coss, s), s));
|
||||
outs.push_back(
|
||||
subtract(multiply(x2, coss, s), multiply(x1, sins, s), s));
|
||||
}
|
||||
return outs;
|
||||
};
|
||||
|
||||
if (traditional) {
|
||||
auto x1 =
|
||||
slice(x, {0, 0, 0}, {x.shape(0), x.shape(1), dims}, {1, 1, 2}, s);
|
||||
auto x2 =
|
||||
slice(x, {0, 0, 1}, {x.shape(0), x.shape(1), dims}, {1, 1, 2}, s);
|
||||
auto outs = apply_rope(x1, x2, coss, sins);
|
||||
for (auto& o : outs) {
|
||||
o = expand_dims(o, 3, s);
|
||||
}
|
||||
return std::vector<array>{reshape(concatenate(outs, 3, s), shape, s)};
|
||||
auto out = concatenate(outs, 3, s);
|
||||
if (dims < x.shape(-1)) {
|
||||
out = reshape(out, {x.shape(0), x.shape(1), dims});
|
||||
out = concatenate({out, slice(x, {0, 0, dims}, x.shape(), s)}, 2, s);
|
||||
}
|
||||
return std::vector<array>{reshape(out, shape, s)};
|
||||
} else {
|
||||
auto out_s = x.shape();
|
||||
out_s.back() = half_dims;
|
||||
@@ -234,9 +394,7 @@ array rope(
|
||||
out_s.back() = dims;
|
||||
auto x2 = slice(x, {0, 0, half_dims}, out_s, s);
|
||||
|
||||
std::vector<array> outs;
|
||||
outs.push_back(subtract(multiply(x1, coss, s), multiply(x2, sins, s), s));
|
||||
outs.push_back(add(multiply(x1, sins, s), multiply(x2, coss, s), s));
|
||||
auto outs = apply_rope(x1, x2, coss, sins);
|
||||
if (dims < x.shape(-1)) {
|
||||
outs.push_back(slice(x, {0, 0, dims}, x.shape(), s));
|
||||
}
|
||||
@@ -249,18 +407,54 @@ array rope(
|
||||
x.shape(),
|
||||
x.dtype(),
|
||||
std::make_shared<RoPE>(
|
||||
stream, fallback, dims, traditional, base, scale, offset),
|
||||
stream, fallback, dims, traditional, base, scale, offset, forward),
|
||||
{x});
|
||||
}
|
||||
return fallback({x})[0];
|
||||
}
|
||||
|
||||
array rope(
|
||||
const array& x,
|
||||
int dims,
|
||||
bool traditional,
|
||||
float base,
|
||||
float scale,
|
||||
int offset,
|
||||
StreamOrDevice s /* = {} */) {
|
||||
return rope(x, dims, traditional, base, scale, offset, true, s);
|
||||
}
|
||||
|
||||
std::vector<array> RoPE::vjp(
|
||||
const std::vector<array>& primals,
|
||||
const std::vector<array>& cotangents,
|
||||
const std::vector<int>& argnums,
|
||||
const std::vector<array>& outputs) {
|
||||
auto s = stream();
|
||||
auto fallback = [dims = dims_,
|
||||
traditional = traditional_,
|
||||
base = base_,
|
||||
scale = scale_,
|
||||
offset = offset_,
|
||||
forward = forward_,
|
||||
s](std::vector<array> inputs) {
|
||||
return std::vector<array>{
|
||||
rope(inputs[0], dims, traditional, base, scale, offset, !forward, s)};
|
||||
};
|
||||
|
||||
return {array(
|
||||
cotangents[0].shape(),
|
||||
cotangents[0].dtype(),
|
||||
std::make_shared<RoPE>(
|
||||
s, fallback, dims_, traditional_, base_, scale_, offset_, !forward_),
|
||||
cotangents)};
|
||||
}
|
||||
|
||||
bool RoPE::is_equivalent(const Primitive& other) const {
|
||||
const RoPE& a_other = static_cast<const RoPE&>(other);
|
||||
return (
|
||||
dims_ == a_other.dims_ && base_ == a_other.base_ &&
|
||||
scale_ == a_other.scale_ && traditional_ == a_other.traditional_ &&
|
||||
offset_ == a_other.offset_);
|
||||
offset_ == a_other.offset_ && forward_ == a_other.forward_);
|
||||
}
|
||||
|
||||
/** Computes: O = softmax(Q @ K.T) @ V **/
|
||||
|
@@ -48,6 +48,12 @@ class RMSNorm : public Custom {
|
||||
void eval_gpu(const std::vector<array>& inputs, std::vector<array>& outputs)
|
||||
override;
|
||||
|
||||
std::vector<array> vjp(
|
||||
const std::vector<array>& primals,
|
||||
const std::vector<array>& cotangents,
|
||||
const std::vector<int>& argnums,
|
||||
const std::vector<array>& outputs) override;
|
||||
|
||||
DEFINE_PRINT(RMSNorm)
|
||||
bool is_equivalent(const Primitive& other) const override;
|
||||
|
||||
@@ -56,6 +62,29 @@ class RMSNorm : public Custom {
|
||||
float eps_;
|
||||
};
|
||||
|
||||
class RMSNormVJP : public Custom {
|
||||
public:
|
||||
RMSNormVJP(
|
||||
Stream stream,
|
||||
std::function<std::vector<array>(std::vector<array>)> fallback,
|
||||
float eps)
|
||||
: Custom(stream, fallback), eps_(eps){};
|
||||
|
||||
void eval_cpu(const std::vector<array>& inputs, std::vector<array>& outputs)
|
||||
override {
|
||||
throw std::runtime_error("NYI");
|
||||
};
|
||||
void eval_gpu(const std::vector<array>& inputs, std::vector<array>& outputs)
|
||||
override;
|
||||
|
||||
DEFINE_PRINT(RMSNormVJP)
|
||||
bool is_equivalent(const Primitive& other) const override;
|
||||
|
||||
private:
|
||||
std::function<std::vector<array>(std::vector<array>)> fallback_;
|
||||
float eps_;
|
||||
};
|
||||
|
||||
class LayerNorm : public Custom {
|
||||
public:
|
||||
LayerNorm(
|
||||
@@ -71,6 +100,12 @@ class LayerNorm : public Custom {
|
||||
void eval_gpu(const std::vector<array>& inputs, std::vector<array>& outputs)
|
||||
override;
|
||||
|
||||
std::vector<array> vjp(
|
||||
const std::vector<array>& primals,
|
||||
const std::vector<array>& cotangents,
|
||||
const std::vector<int>& argnums,
|
||||
const std::vector<array>& outputs) override;
|
||||
|
||||
DEFINE_PRINT(LayerNorm)
|
||||
bool is_equivalent(const Primitive& other) const override;
|
||||
|
||||
@@ -79,6 +114,29 @@ class LayerNorm : public Custom {
|
||||
float eps_;
|
||||
};
|
||||
|
||||
class LayerNormVJP : public Custom {
|
||||
public:
|
||||
LayerNormVJP(
|
||||
Stream stream,
|
||||
std::function<std::vector<array>(std::vector<array>)> fallback,
|
||||
float eps)
|
||||
: Custom(stream, fallback), eps_(eps){};
|
||||
|
||||
void eval_cpu(const std::vector<array>& inputs, std::vector<array>& outputs)
|
||||
override {
|
||||
throw std::runtime_error("NYI");
|
||||
};
|
||||
void eval_gpu(const std::vector<array>& inputs, std::vector<array>& outputs)
|
||||
override;
|
||||
|
||||
DEFINE_PRINT(LayerNormVJP)
|
||||
bool is_equivalent(const Primitive& other) const override;
|
||||
|
||||
private:
|
||||
std::function<std::vector<array>(std::vector<array>)> fallback_;
|
||||
float eps_;
|
||||
};
|
||||
|
||||
class RoPE : public Custom {
|
||||
public:
|
||||
RoPE(
|
||||
@@ -88,13 +146,15 @@ class RoPE : public Custom {
|
||||
bool traditional,
|
||||
float base,
|
||||
float scale,
|
||||
int offset)
|
||||
int offset,
|
||||
bool forward)
|
||||
: Custom(stream, fallback),
|
||||
dims_(dims),
|
||||
traditional_(traditional),
|
||||
base_(base),
|
||||
scale_(scale),
|
||||
offset_(offset){};
|
||||
offset_(offset),
|
||||
forward_(forward){};
|
||||
|
||||
void eval_cpu(const std::vector<array>& inputs, std::vector<array>& outputs)
|
||||
override {
|
||||
@@ -103,6 +163,12 @@ class RoPE : public Custom {
|
||||
void eval_gpu(const std::vector<array>& inputs, std::vector<array>& outputs)
|
||||
override;
|
||||
|
||||
std::vector<array> vjp(
|
||||
const std::vector<array>& primals,
|
||||
const std::vector<array>& cotangents,
|
||||
const std::vector<int>& argnums,
|
||||
const std::vector<array>& outputs) override;
|
||||
|
||||
DEFINE_PRINT(RoPE)
|
||||
bool is_equivalent(const Primitive& other) const override;
|
||||
|
||||
@@ -113,6 +179,7 @@ class RoPE : public Custom {
|
||||
float base_;
|
||||
float scale_;
|
||||
int offset_;
|
||||
bool forward_;
|
||||
};
|
||||
|
||||
class ScaledDotProductAttention : public Custom {
|
||||
@@ -126,7 +193,7 @@ class ScaledDotProductAttention : public Custom {
|
||||
|
||||
void eval_cpu(const std::vector<array>& inputs, std::vector<array>& outputs)
|
||||
override {
|
||||
outputs[0] = fallback_(inputs)[0];
|
||||
throw std::runtime_error("NYI");
|
||||
};
|
||||
|
||||
void eval_gpu(const std::vector<array>& inputs, std::vector<array>& outputs)
|
||||
|
Reference in New Issue
Block a user