Implement vjps for some primitives in the fast namespace (#883)

* Implement rope vjp in terms of rope
* RMSNormVJP primitive and kernel
* Add LayerNormVJP primitive and kernel
Author: Angelos Katharopoulos
Date: 2024-03-26 16:35:34 -07:00 (committed by GitHub)
Parent: a789685c63
Commit: 29221fa238
14 changed files with 1383 additions and 110 deletions
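As a rough reference for what the new VJP primitives compute, the RMS-norm gradients can be written with ordinary mlx ops. The sketch below mirrors the fallback closure added in fast.cpp later in this commit; the helper name and standalone form are illustrative, not code from the commit:

import mlx.core as mx

def rms_norm_vjp_reference(x, w, g, eps):
    # n = 1 / sqrt(mean(x^2) + eps), the per-row normalizer
    n = mx.rsqrt(mx.mean(mx.square(x), axis=-1, keepdims=True) + eps)
    # dL/dx = g * w * n - x * mean(g * w * x) * n^3
    gx = g * w * n - x * mx.mean(g * w * x, axis=-1, keepdims=True) * n**3
    # dL/dw = sum of g * (x * n) over all leading (non-reduced) axes
    gw = mx.sum(g * x * n, axis=list(range(x.ndim - 1)))
    return gx, gw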


@@ -0,0 +1,41 @@
# Copyright © 2023-2024 Apple Inc.

import mlx.core as mx
import mlx.nn as nn
from time_utils import time_fn


def layer_norm(x, w, b, eps):
    ot = x.dtype
    x = x.astype(mx.float32)
    mu = mx.mean(x, -1, keepdims=True)
    v = mx.var(x, -1, keepdims=True)
    return (x - mu) * mx.rsqrt(v + eps) * w + b


def time_layer_norm():
    f1 = lambda x, w, b, y: (layer_norm(x, w, b, 1e-5) * y).sum()
    f2 = lambda x, w, b, y: (mx.fast.layer_norm(x, w, b, 1e-5) * y).sum()
    g1 = mx.grad(f1, argnums=(0, 1, 2))
    g2 = mx.grad(f2, argnums=(0, 1, 2))

    x = mx.random.uniform(shape=(8, 1024, 4096)).astype(mx.float16)
    w = mx.random.uniform(shape=(4096,)).astype(mx.float16)
    b = mx.random.uniform(shape=(4096,)).astype(mx.float16)
    y = mx.random.uniform(shape=(8, 1024, 4096)).astype(mx.float16)
    mx.eval(x, w, b, y)

    def layer_norm_loop(g, x, w, b):
        gx, gw, gb = x, w, b
        for _ in range(32):
            gx, gw, gb = g(gx, gw, gb, y)
        return gx, gw, gb

    time_fn(layer_norm_loop, g1, x, w, b)
    time_fn(layer_norm_loop, g2, x, w, b)
    time_fn(layer_norm_loop, mx.compile(g1), x, w, b)
    time_fn(layer_norm_loop, mx.compile(g2), x, w, b)


if __name__ == "__main__":
    time_layer_norm()


@@ -0,0 +1,39 @@
# Copyright © 2023-2024 Apple Inc.

import mlx.core as mx
import mlx.nn as nn
from time_utils import time_fn


def rms_norm(x, w, eps):
    ot = x.dtype
    x = x.astype(mx.float32)
    n = mx.rsqrt(x.square().mean(-1, keepdims=True) + eps)
    return (x * n).astype(ot) * w


def time_rms_norm():
    f1 = lambda x, w, y: (rms_norm(x, w, 1e-5) * y).sum()
    f2 = lambda x, w, y: (mx.fast.rms_norm(x, w, 1e-5) * y).sum()
    g1 = mx.grad(f1, argnums=(0, 1))
    g2 = mx.grad(f2, argnums=(0, 1))

    x = mx.random.uniform(shape=(8, 1024, 4096)).astype(mx.float16)
    w = mx.random.uniform(shape=(4096,)).astype(mx.float16)
    y = mx.random.uniform(shape=(8, 1024, 4096)).astype(mx.float16)
    mx.eval(x, w, y)

    def rms_norm_loop(g, x, w):
        gx, gw = x, w
        for _ in range(32):
            gx, gw = g(gx, gw, y)
        return gx, gw

    time_fn(rms_norm_loop, g1, x, w)
    time_fn(rms_norm_loop, g2, x, w)
    time_fn(rms_norm_loop, mx.compile(g1), x, w)
    time_fn(rms_norm_loop, mx.compile(g2), x, w)


if __name__ == "__main__":
    time_rms_norm()


@@ -6,8 +6,6 @@
namespace mlx::core {

namespace {

enum ReductionOpType {
  // Self-explanatory. Read everything and produce 1 output.
  ContiguousAllReduce,
@@ -38,6 +36,21 @@ enum ReductionOpType {
  GeneralReduce
};

struct ReductionPlan {
  ReductionOpType type;
  std::vector<int> shape;
  std::vector<size_t> strides;
  ReductionPlan(
      ReductionOpType type_,
      std::vector<int> shape_,
      std::vector<size_t> strides_)
      : type(type_), shape(std::move(shape_)), strides(std::move(strides_)) {}
  ReductionPlan(ReductionOpType type_) : type(type_) {}
};

namespace {

// Helper for the ndimensional strided loop
// Should this be in utils?
inline void nd_loop(
@@ -110,19 +123,6 @@ struct DefaultContiguousReduce {
  }
};

struct ReductionPlan {
  ReductionOpType type;
  std::vector<int> shape;
  std::vector<size_t> strides;
  ReductionPlan(
      ReductionOpType type_,
      std::vector<int> shape_,
      std::vector<size_t> strides_)
      : type(type_), shape(std::move(shape_)), strides(std::move(strides_)) {}
  ReductionPlan(ReductionOpType type_) : type(type_) {}
};

ReductionPlan get_reduction_plan(const array& x, const std::vector<int> axes) {
  // The data is all there and we are reducing over everything
  if (x.size() == x.data_size() && axes.size() == x.ndim() &&


@@ -205,39 +205,341 @@ template <typename T, int N_READS = RMS_N_READS>
  }
}
template <typename T, int N_READS = RMS_N_READS>
[[kernel]] void vjp_layer_norm_single_row(
const device T* x,
const device T* w,
const device T* g,
device T* gx,
device T* gw,
constant float& eps,
constant uint& axis_size,
constant uint& w_stride,
uint gid [[threadgroup_position_in_grid]],
uint lid [[thread_position_in_threadgroup]],
uint simd_lane_id [[thread_index_in_simdgroup]],
uint simd_group_id [[simdgroup_index_in_threadgroup]]) {
// Advance the input pointers
x += gid * axis_size + lid * N_READS;
g += gid * axis_size + lid * N_READS;
w += w_stride * lid * N_READS;
// Allocate registers for the computation and accumulators
float thread_x[N_READS];
float thread_w[N_READS];
float thread_g[N_READS];
float sumx = 0;
float sumx2 = 0;
float sumwg = 0;
float sumwgx = 0;
constexpr int SIMD_SIZE = 32;
threadgroup float local_sumx[SIMD_SIZE];
threadgroup float local_sumx2[SIMD_SIZE];
threadgroup float local_sumwg[SIMD_SIZE];
threadgroup float local_sumwgx[SIMD_SIZE];
threadgroup float local_mean[1];
threadgroup float local_normalizer[1];
threadgroup float local_meanwg[1];
threadgroup float local_meanwgx[1];
if (lid * N_READS + N_READS <= axis_size) {
for (int i = 0; i < N_READS; i++) {
thread_x[i] = x[i];
thread_w[i] = w[i * w_stride];
thread_g[i] = g[i];
float wg = thread_w[i] * thread_g[i];
sumx += thread_x[i];
sumx2 += thread_x[i] * thread_x[i];
sumwg += wg;
sumwgx += wg * thread_x[i];
}
} else {
for (int i = 0; i < N_READS; i++) {
if ((lid * N_READS + i) < axis_size) {
thread_x[i] = x[i];
thread_w[i] = w[i * w_stride];
thread_g[i] = g[i];
float wg = thread_w[i] * thread_g[i];
sumx += thread_x[i];
sumx2 += thread_x[i] * thread_x[i];
sumwg += wg;
sumwgx += wg * thread_x[i];
}
}
}
sumx = simd_sum(sumx);
sumx2 = simd_sum(sumx2);
sumwg = simd_sum(sumwg);
sumwgx = simd_sum(sumwgx);
// Initialize shared memory
if (simd_group_id == 0) {
local_sumx[simd_lane_id] = 0;
local_sumx2[simd_lane_id] = 0;
local_sumwg[simd_lane_id] = 0;
local_sumwgx[simd_lane_id] = 0;
}
threadgroup_barrier(mem_flags::mem_threadgroup);
// Write simd accumulations into shared memory
if (simd_lane_id == 0) {
local_sumx[simd_group_id] = sumx;
local_sumx2[simd_group_id] = sumx2;
local_sumwg[simd_group_id] = sumwg;
local_sumwgx[simd_group_id] = sumwgx;
}
threadgroup_barrier(mem_flags::mem_threadgroup);
// Accumulate over simd groups
if (simd_group_id == 0) {
sumx = simd_sum(local_sumx[simd_lane_id]);
sumx2 = simd_sum(local_sumx2[simd_lane_id]);
sumwg = simd_sum(local_sumwg[simd_lane_id]);
sumwgx = simd_sum(local_sumwgx[simd_lane_id]);
if (simd_lane_id == 0) {
float mean = sumx / axis_size;
float variance = sumx2 / axis_size - mean * mean;
local_mean[0] = mean;
local_normalizer[0] = metal::precise::rsqrt(variance + eps);
local_meanwg[0] = sumwg / axis_size;
local_meanwgx[0] = sumwgx / axis_size;
}
}
threadgroup_barrier(mem_flags::mem_threadgroup);
float mean = local_mean[0];
float normalizer = local_normalizer[0];
float meanwg = local_meanwg[0];
float meanwgxc = local_meanwgx[0] - meanwg * mean;
float normalizer2 = normalizer * normalizer;
// Write the outputs
gx += gid * axis_size + lid * N_READS;
gw += gid * axis_size + lid * N_READS;
if (lid * N_READS + N_READS <= axis_size) {
for (int i = 0; i < N_READS; i++) {
thread_x[i] = (thread_x[i] - mean) * normalizer;
gx[i] = static_cast<T>(normalizer * (thread_w[i] * thread_g[i] - meanwg) -
thread_x[i] * meanwgxc * normalizer2);
gw[i] = static_cast<T>(thread_g[i] * thread_x[i]);
}
} else {
for (int i = 0; i < N_READS; i++) {
if ((lid * N_READS + i) < axis_size) {
thread_x[i] = (thread_x[i] - mean) * normalizer;
gx[i] = static_cast<T>(normalizer * (thread_w[i] * thread_g[i] - meanwg) -
thread_x[i] * meanwgxc * normalizer2);
gw[i] = static_cast<T>(thread_g[i] * thread_x[i]);
}
}
}
}
template <typename T, int N_READS = RMS_N_READS>
[[kernel]] void vjp_layer_norm_looped(
const device T* x,
const device T* w,
const device T* g,
device T* gx,
device T* gw,
constant float& eps,
constant uint& axis_size,
constant uint& w_stride,
uint gid [[threadgroup_position_in_grid]],
uint lid [[thread_position_in_threadgroup]],
uint lsize [[threads_per_threadgroup]],
uint simd_lane_id [[thread_index_in_simdgroup]],
uint simd_group_id [[simdgroup_index_in_threadgroup]]) {
// Advance the input pointers
x += gid * axis_size + lid * N_READS;
g += gid * axis_size + lid * N_READS;
w += w_stride * lid * N_READS;
// Allocate registers for the accumulators
float sumx = 0;
float sumx2 = 0;
float sumwg = 0;
float sumwgx = 0;
constexpr int SIMD_SIZE = 32;
threadgroup float local_sumx[SIMD_SIZE];
threadgroup float local_sumx2[SIMD_SIZE];
threadgroup float local_sumwg[SIMD_SIZE];
threadgroup float local_sumwgx[SIMD_SIZE];
threadgroup float local_mean[1];
threadgroup float local_normalizer[1];
threadgroup float local_meanwg[1];
threadgroup float local_meanwgx[1];
for (uint r = 0; r < axis_size; r += lsize * N_READS) {
if (r + lid * N_READS + N_READS <= axis_size) {
for (int i = 0; i < N_READS; i++) {
float xi = x[i + r];
float wi = w[(i + r) * w_stride];
float gi = g[i + r];
float wg = wi * gi;
sumx += xi;
sumx2 += xi * xi;
sumwg += wg;
sumwgx += wg * xi;
}
} else {
for (int i = 0; i < N_READS; i++) {
if ((r + lid * N_READS + i) < axis_size) {
float xi = x[i + r];
float wi = w[(i + r) * w_stride];
float gi = g[i + r];
float wg = wi * gi;
sumx += xi;
sumx2 += xi * xi;
sumwg += wg;
sumwgx += wg * xi;
}
}
}
}
sumx = simd_sum(sumx);
sumx2 = simd_sum(sumx2);
sumwg = simd_sum(sumwg);
sumwgx = simd_sum(sumwgx);
// Initialize shared memory
if (simd_group_id == 0) {
local_sumx[simd_lane_id] = 0;
local_sumx2[simd_lane_id] = 0;
local_sumwg[simd_lane_id] = 0;
local_sumwgx[simd_lane_id] = 0;
}
threadgroup_barrier(mem_flags::mem_threadgroup);
// Write simd accumulations into shared memory
if (simd_lane_id == 0) {
local_sumx[simd_group_id] = sumx;
local_sumx2[simd_group_id] = sumx2;
local_sumwg[simd_group_id] = sumwg;
local_sumwgx[simd_group_id] = sumwgx;
}
threadgroup_barrier(mem_flags::mem_threadgroup);
// Accumulate over simd groups
if (simd_group_id == 0) {
sumx = simd_sum(local_sumx[simd_lane_id]);
sumx2 = simd_sum(local_sumx2[simd_lane_id]);
sumwg = simd_sum(local_sumwg[simd_lane_id]);
sumwgx = simd_sum(local_sumwgx[simd_lane_id]);
if (simd_lane_id == 0) {
float mean = sumx / axis_size;
float variance = sumx2 / axis_size - mean * mean;
local_mean[0] = mean;
local_normalizer[0] = metal::precise::rsqrt(variance + eps);
local_meanwg[0] = sumwg / axis_size;
local_meanwgx[0] = sumwgx / axis_size;
}
}
threadgroup_barrier(mem_flags::mem_threadgroup);
float mean = local_mean[0];
float normalizer = local_normalizer[0];
float meanwg = local_meanwg[0];
float meanwgxc = local_meanwgx[0] - meanwg * mean;
float normalizer2 = normalizer * normalizer;
// Write the outputs
gx += gid * axis_size + lid * N_READS;
gw += gid * axis_size + lid * N_READS;
for (uint r = 0; r < axis_size; r += lsize * N_READS) {
if (r + lid * N_READS + N_READS <= axis_size) {
for (int i = 0; i < N_READS; i++) {
float xi = (x[i + r] - mean) * normalizer;
float wi = w[(i + r) * w_stride];
float gi = g[i + r];
gx[i + r] = static_cast<T>(normalizer * (wi * gi - meanwg) -
xi * meanwgxc * normalizer2);
gw[i + r] = static_cast<T>(gi * xi);
}
} else {
for (int i = 0; i < N_READS; i++) {
if ((r + lid * N_READS + i) < axis_size) {
float xi = (x[i + r] - mean) * normalizer;
float wi = w[(i + r) * w_stride];
float gi = g[i + r];
gx[i + r] = static_cast<T>(normalizer * (wi * gi - meanwg) -
xi * meanwgxc * normalizer2);
gw[i + r] = static_cast<T>(gi * xi);
}
}
}
}
}
// clang-format off
#define instantiate_layer_norm_single_row(name, itype) \
  template [[host_name("layer_norm" #name)]] [[kernel]] void \
  layer_norm_single_row<itype>( \
      const device itype* x, \
      const device itype* w, \
      const device itype* b, \
      device itype* out, \
      constant float& eps, \
      constant uint& axis_size, \
      constant uint& w_stride, \
      constant uint& b_stride, \
      uint gid [[thread_position_in_grid]], \
      uint lid [[thread_position_in_threadgroup]], \
      uint simd_lane_id [[thread_index_in_simdgroup]], \
      uint simd_group_id [[simdgroup_index_in_threadgroup]]); \
  template [[host_name("vjp_layer_norm" #name)]] [[kernel]] void \
  vjp_layer_norm_single_row<itype>( \
      const device itype* x, \
      const device itype* w, \
      const device itype* g, \
      device itype* gx, \
      device itype* gw, \
      constant float& eps, \
      constant uint& axis_size, \
      constant uint& w_stride, \
      uint gid [[thread_position_in_grid]], \
      uint lid [[thread_position_in_threadgroup]], \
      uint simd_lane_id [[thread_index_in_simdgroup]], \
      uint simd_group_id [[simdgroup_index_in_threadgroup]]);

#define instantiate_layer_norm_looped(name, itype) \
  template [[host_name("layer_norm_looped" #name)]] [[kernel]] void \
  layer_norm_looped<itype>( \
      const device itype* x, \
      const device itype* w, \
      const device itype* b, \
      device itype* out, \
      constant float& eps, \
      constant uint& axis_size, \
      constant uint& w_stride, \
      constant uint& b_stride, \
      uint gid [[thread_position_in_grid]], \
      uint lid [[thread_position_in_threadgroup]], \
      uint lsize [[threads_per_threadgroup]], \
      uint simd_lane_id [[thread_index_in_simdgroup]], \
      uint simd_group_id [[simdgroup_index_in_threadgroup]]); \
  template [[host_name("vjp_layer_norm_looped" #name)]] [[kernel]] void \
  vjp_layer_norm_looped<itype>( \
      const device itype* x, \
      const device itype* w, \
      const device itype* g, \
      device itype* gx, \
      device itype* gb, \
      constant float& eps, \
      constant uint& axis_size, \
      constant uint& w_stride, \
      uint gid [[thread_position_in_grid]], \
      uint lid [[thread_position_in_threadgroup]], \
      uint lsize [[threads_per_threadgroup]], \
      uint simd_lane_id [[thread_index_in_simdgroup]], \
      uint simd_group_id [[simdgroup_index_in_threadgroup]]);

#define instantiate_layer_norm(name, itype) \
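For reference, the layer-norm VJP kernels above compute, per row, the gradients below (a sketch in plain mlx ops with a hypothetical helper name; it is an illustration of the math, not code from this commit):

import mlx.core as mx

def layer_norm_vjp_reference(x, w, g, eps):
    mu = mx.mean(x, axis=-1, keepdims=True)
    n = mx.rsqrt(mx.var(x, axis=-1, keepdims=True) + eps)  # 1 / std per row
    x_hat = (x - mu) * n                                    # normalized input
    wg = w * g
    mean_wg = mx.mean(wg, axis=-1, keepdims=True)
    mean_wg_xhat = mx.mean(wg * x_hat, axis=-1, keepdims=True)
    gx = n * (wg - mean_wg) - x_hat * mean_wg_xhat * n      # dL/dx
    gw = mx.sum(g * x_hat, axis=list(range(x.ndim - 1)))    # dL/dw
    gb = mx.sum(g, axis=list(range(x.ndim - 1)))            # dL/db
    return gx, gw, gb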


@@ -150,6 +150,216 @@ template <typename T, int N_READS = RMS_N_READS>
  }
}
template <typename T, int N_READS = RMS_N_READS>
[[kernel]] void vjp_rms_single_row(
const device T* x,
const device T* w,
const device T* g,
device T* gx,
device T* gw,
constant float& eps,
constant uint& axis_size,
constant uint& w_stride,
uint gid [[threadgroup_position_in_grid]],
uint lid [[thread_position_in_threadgroup]],
uint simd_lane_id [[thread_index_in_simdgroup]],
uint simd_group_id [[simdgroup_index_in_threadgroup]]) {
// Advance the input pointers
x += gid * axis_size + lid * N_READS;
g += gid * axis_size + lid * N_READS;
w += w_stride * lid * N_READS;
// Allocate registers for the computation and accumulators
float thread_x[N_READS];
float thread_w[N_READS];
float thread_g[N_READS];
float sumx2 = 0;
float sumgwx = 0;
// Allocate shared memory to implement the reduction
constexpr int SIMD_SIZE = 32;
threadgroup float local_sumx2[SIMD_SIZE];
threadgroup float local_sumgwx[SIMD_SIZE];
threadgroup float local_normalizer[1];
threadgroup float local_meangwx[1];
// Read and accumulate locally
if (lid * N_READS + N_READS <= axis_size) {
for (int i = 0; i < N_READS; i++) {
thread_x[i] = x[i];
thread_w[i] = w[w_stride * i];
thread_g[i] = g[i];
sumx2 += thread_x[i] * thread_x[i];
sumgwx += thread_x[i] * thread_w[i] * thread_g[i];
}
} else {
for (int i = 0; i < N_READS; i++) {
if ((lid * N_READS + i) < axis_size) {
thread_x[i] = x[i];
thread_w[i] = w[w_stride * i];
thread_g[i] = g[i];
sumx2 += thread_x[i] * thread_x[i];
sumgwx += thread_x[i] * thread_w[i] * thread_g[i];
}
}
}
// Accumulate across threads
sumx2 = simd_sum(sumx2);
sumgwx = simd_sum(sumgwx);
if (simd_group_id == 0) {
local_sumx2[simd_lane_id] = 0;
local_sumgwx[simd_lane_id] = 0;
}
threadgroup_barrier(mem_flags::mem_threadgroup);
if (simd_lane_id == 0) {
local_sumx2[simd_group_id] = sumx2;
local_sumgwx[simd_group_id] = sumgwx;
}
threadgroup_barrier(mem_flags::mem_threadgroup);
if (simd_group_id == 0) {
sumx2 = simd_sum(local_sumx2[simd_lane_id]);
sumgwx = simd_sum(local_sumgwx[simd_lane_id]);
if (simd_lane_id == 0) {
local_meangwx[0] = sumgwx / axis_size;
local_normalizer[0] = metal::precise::rsqrt(sumx2 / axis_size + eps);
}
}
threadgroup_barrier(mem_flags::mem_threadgroup);
float meangwx = local_meangwx[0];
float normalizer = local_normalizer[0];
float normalizer3 = normalizer * normalizer * normalizer;
// Write the outputs
gx += gid * axis_size + lid * N_READS;
gw += gid * axis_size + lid * N_READS;
if (lid * N_READS + N_READS <= axis_size) {
for (int i = 0; i < N_READS; i++) {
gx[i] = static_cast<T>(thread_g[i] * thread_w[i] * normalizer - thread_x[i] * meangwx * normalizer3);
gw[i] = static_cast<T>(thread_g[i] * thread_x[i] * normalizer);
}
} else {
for (int i = 0; i < N_READS; i++) {
if ((lid * N_READS + i) < axis_size) {
gx[i] = static_cast<T>(thread_g[i] * thread_w[i] * normalizer - thread_x[i] * meangwx * normalizer3);
gw[i] = static_cast<T>(thread_g[i] * thread_x[i] * normalizer);
}
}
}
}
template <typename T, int N_READS = RMS_N_READS>
[[kernel]] void vjp_rms_looped(
const device T* x,
const device T* w,
const device T* g,
device T* gx,
device T* gw,
constant float& eps,
constant uint& axis_size,
constant uint& w_stride,
uint gid [[threadgroup_position_in_grid]],
uint lid [[thread_position_in_threadgroup]],
uint lsize [[threads_per_threadgroup]],
uint simd_lane_id [[thread_index_in_simdgroup]],
uint simd_group_id [[simdgroup_index_in_threadgroup]]) {
// Advance the input pointers
x += gid * axis_size + lid * N_READS;
g += gid * axis_size + lid * N_READS;
w += w_stride * lid * N_READS;
// Allocate registers for the accumulators
float sumx2 = 0;
float sumgwx = 0;
// Allocate shared memory to implement the reduction
constexpr int SIMD_SIZE = 32;
threadgroup float local_sumx2[SIMD_SIZE];
threadgroup float local_sumgwx[SIMD_SIZE];
threadgroup float local_normalizer[1];
threadgroup float local_meangwx[1];
// Read and accumulate locally
for (uint r = 0; r < axis_size; r += lsize * N_READS) {
if (r + lid * N_READS + N_READS <= axis_size) {
for (int i = 0; i < N_READS; i++) {
float xi = x[i + r];
float wi = w[w_stride * (i + r)];
float gi = g[i + r];
sumx2 += xi * xi;
sumgwx += xi * wi * gi;
}
} else {
for (int i = 0; i < N_READS; i++) {
if ((r + lid * N_READS + i) < axis_size) {
float xi = x[i + r];
float wi = w[w_stride * (i + r)];
float gi = g[i + r];
sumx2 += xi * xi;
sumgwx += xi * wi * gi;
}
}
}
}
// Accumulate across threads
sumx2 = simd_sum(sumx2);
sumgwx = simd_sum(sumgwx);
if (simd_group_id == 0) {
local_sumx2[simd_lane_id] = 0;
local_sumgwx[simd_lane_id] = 0;
}
threadgroup_barrier(mem_flags::mem_threadgroup);
if (simd_lane_id == 0) {
local_sumx2[simd_group_id] = sumx2;
local_sumgwx[simd_group_id] = sumgwx;
}
threadgroup_barrier(mem_flags::mem_threadgroup);
if (simd_group_id == 0) {
sumx2 = simd_sum(local_sumx2[simd_lane_id]);
sumgwx = simd_sum(local_sumgwx[simd_lane_id]);
if (simd_lane_id == 0) {
local_meangwx[0] = sumgwx / axis_size;
local_normalizer[0] = metal::precise::rsqrt(sumx2 / axis_size + eps);
}
}
threadgroup_barrier(mem_flags::mem_threadgroup);
float meangwx = local_meangwx[0];
float normalizer = local_normalizer[0];
float normalizer3 = normalizer * normalizer * normalizer;
// Write the outputs
gx += gid * axis_size + lid * N_READS;
gw += gid * axis_size + lid * N_READS;
for (uint r = 0; r < axis_size; r += lsize * N_READS) {
if (r + lid * N_READS + N_READS <= axis_size) {
for (int i = 0; i < N_READS; i++) {
float xi = x[i + r];
float wi = w[w_stride * (i + r)];
float gi = g[i + r];
gx[i + r] = static_cast<T>(gi * wi * normalizer - xi * meangwx * normalizer3);
gw[i + r] = static_cast<T>(gi * xi * normalizer);
}
} else {
for (int i = 0; i < N_READS; i++) {
if ((r + lid * N_READS + i) < axis_size) {
float xi = x[i + r];
float wi = w[w_stride * (i + r)];
float gi = g[i + r];
gx[i + r] = static_cast<T>(gi * wi * normalizer - xi * meangwx * normalizer3);
gw[i + r] = static_cast<T>(gi * xi * normalizer);
}
}
}
}
}
// clang-format off
#define instantiate_rms_single_row(name, itype) \
  template [[host_name("rms" #name)]] [[kernel]] void \
@@ -165,25 +375,56 @@ template <typename T, int N_READS = RMS_N_READS>
      uint gid [[thread_position_in_grid]], \
      uint lid [[thread_position_in_threadgroup]], \
      uint simd_lane_id [[thread_index_in_simdgroup]], \
      uint simd_group_id [[simdgroup_index_in_threadgroup]]); \
  \
  template [[host_name("vjp_rms" #name)]] [[kernel]] void \
  vjp_rms_single_row<itype>( \
      const device itype* x, \
      const device itype* w, \
      const device itype* g, \
      device itype* gx, \
      device itype* gw, \
      constant float& eps, \
      constant uint& axis_size, \
      constant uint& w_stride, \
      uint gid [[thread_position_in_grid]], \
      uint lid [[thread_position_in_threadgroup]], \
      uint simd_lane_id [[thread_index_in_simdgroup]], \
      uint simd_group_id [[simdgroup_index_in_threadgroup]]);

#define instantiate_rms_looped(name, itype) \
  template [[host_name("rms_looped" #name)]] [[kernel]] void \
  rms_looped<itype>( \
      const device itype* x, \
      const device itype* w, \
      device itype* out, \
      constant float& eps, \
      constant uint& axis_size, \
      constant uint& w_stride, \
      threadgroup float* local_inv_mean [[threadgroup(0)]], \
      threadgroup float* local_sums [[threadgroup(1)]], \
      uint gid [[thread_position_in_grid]], \
      uint lid [[thread_position_in_threadgroup]], \
      uint lsize [[threads_per_threadgroup]], \
      uint simd_lane_id [[thread_index_in_simdgroup]], \
      uint simd_group_id [[simdgroup_index_in_threadgroup]]); \
  \
  template [[host_name("vjp_rms_looped" #name)]] [[kernel]] void \
  vjp_rms_looped<itype>( \
      const device itype* x, \
      const device itype* w, \
      const device itype* g, \
      device itype* gx, \
      device itype* gw, \
      constant float& eps, \
      constant uint& axis_size, \
      constant uint& w_stride, \
      uint gid [[thread_position_in_grid]], \
      uint lid [[thread_position_in_threadgroup]], \
      uint lsize [[threads_per_threadgroup]], \
      uint simd_lane_id [[thread_index_in_simdgroup]], \
      uint simd_group_id [[simdgroup_index_in_threadgroup]]);

#define instantiate_rms(name, itype) \
  instantiate_rms_single_row(name, itype) \
  instantiate_rms_looped(name, itype)


@@ -5,7 +5,7 @@
#include "mlx/backend/metal/kernels/bf16.h" #include "mlx/backend/metal/kernels/bf16.h"
#include "mlx/backend/metal/kernels/utils.h" #include "mlx/backend/metal/kernels/utils.h"
template <typename T, bool traditional> template <typename T, bool traditional, bool forward>
[[kernel]] void rope( [[kernel]] void rope(
const device T *in [[buffer(0)]], const device T *in [[buffer(0)]],
device T * out [[buffer(1)]], device T * out [[buffer(1)]],
@@ -43,15 +43,22 @@ template <typename T, bool traditional>
  // Read and write the output
  float x1 = static_cast<float>(in[in_index_1]);
  float x2 = static_cast<float>(in[in_index_2]);
  float rx1;
  float rx2;
  if (forward) {
    rx1 = x1 * costheta - x2 * sintheta;
    rx2 = x1 * sintheta + x2 * costheta;
  } else {
    rx1 = x2 * sintheta + x1 * costheta;
    rx2 = x2 * costheta - x1 * sintheta;
  }
  out[out_index_1] = static_cast<T>(rx1);
  out[out_index_2] = static_cast<T>(rx2);
}

#define instantiate_rope(name, type, traditional, forward) \
  template [[host_name("rope_" #name)]] \
  [[kernel]] void rope<type, traditional, forward>( \
      const device type* in [[buffer(0)]], \
      device type* out [[buffer(1)]], \
      constant const size_t strides[3], \
@@ -62,9 +69,15 @@ template <typename T, bool traditional>
      uint3 pos [[thread_position_in_grid]], \
      uint3 grid [[threads_per_grid]]);

instantiate_rope(traditional_float16, half, true, true)
instantiate_rope(traditional_bfloat16, bfloat16_t, true, true)
instantiate_rope(traditional_float32, float, true, true)
instantiate_rope(float16, half, false, true)
instantiate_rope(bfloat16, bfloat16_t, false, true)
instantiate_rope(float32, float, false, true)
instantiate_rope(vjp_traditional_float16, half, true, false)
instantiate_rope(vjp_traditional_bfloat16, bfloat16_t, true, false)
instantiate_rope(vjp_traditional_float32, float, true, false)
instantiate_rope(vjp_float16, half, false, false)
instantiate_rope(vjp_bfloat16, bfloat16_t, false, false)
instantiate_rope(vjp_float32, float, false, false)
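Each (x1, x2) pair is rotated by a plane rotation, so the backward (vjp) branch only needs the transpose, i.e. rotation by -theta. A tiny standalone check of that identity (illustrative only, not part of the commit):

import math

theta = 0.3
x1, x2 = 1.0, 2.0
# forward rotation, as in the forward branch of the kernel
r1 = x1 * math.cos(theta) - x2 * math.sin(theta)
r2 = x1 * math.sin(theta) + x2 * math.cos(theta)
# inverse rotation, as in the else (vjp) branch, applied to the rotated pair
b1 = r2 * math.sin(theta) + r1 * math.cos(theta)
b2 = r2 * math.cos(theta) - r1 * math.sin(theta)
assert abs(b1 - x1) < 1e-12 and abs(b2 - x2) < 1e-12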


@@ -4,6 +4,7 @@
#include "mlx/backend/metal/copy.h" #include "mlx/backend/metal/copy.h"
#include "mlx/backend/metal/device.h" #include "mlx/backend/metal/device.h"
#include "mlx/backend/metal/kernels/defines.h" #include "mlx/backend/metal/kernels/defines.h"
#include "mlx/backend/metal/reduce.h"
#include "mlx/backend/metal/utils.h" #include "mlx/backend/metal/utils.h"
#include "mlx/fast_primitives.h" #include "mlx/fast_primitives.h"
@@ -95,6 +96,113 @@ void RMSNorm::eval_gpu(
      [copies](MTL::CommandBuffer*) mutable { copies.clear(); });
}
void RMSNormVJP::eval_gpu(
const std::vector<array>& inputs,
std::vector<array>& outputs) {
auto& s = stream();
auto& d = metal::device(s.device);
// Ensure row contiguity. We could relax this step by checking that the array
// is contiguous (no broadcasts or holes) and that the input strides are the
// same as the cotangent strides but for now this is simpler.
std::vector<array> copies;
auto check_input = [&copies, &s](const array& x) {
if (x.flags().row_contiguous) {
return x;
}
array x_copy(x.shape(), x.dtype(), nullptr, {});
copy_gpu(x, x_copy, CopyType::General, s);
copies.push_back(x_copy);
return x_copy;
};
const array& x = check_input(inputs[0]);
const array& w = inputs[1];
const array& g = check_input(inputs[2]);
array& gx = outputs[0];
array& gw = outputs[1];
// Allocate space for the outputs
bool x_in_gx = false;
bool g_in_gx = false;
if (x.is_donatable()) {
gx.move_shared_buffer(x);
x_in_gx = true;
} else if (g.is_donatable()) {
gx.move_shared_buffer(g);
g_in_gx = true;
} else {
gx.set_data(allocator::malloc_or_wait(gx.nbytes()));
}
auto axis_size = static_cast<uint32_t>(x.shape().back());
int n_rows = x.data_size() / axis_size;
// Allocate a temporary to store the gradients for w and initialize the
// gradient accumulator to 0.
array gw_temp({n_rows, x.shape().back()}, gw.dtype(), nullptr, {});
bool g_in_gw = false;
if (!g_in_gx && g.is_donatable()) {
gw_temp.move_shared_buffer(g);
g_in_gw = true;
} else {
gw_temp.set_data(allocator::malloc_or_wait(gw_temp.nbytes()));
}
copies.push_back(gw_temp);
array zero(0, gw.dtype());
copy_gpu(zero, gw, CopyType::Scalar, s);
const int simd_size = 32;
const int n_reads = RMS_N_READS;
const int looped_limit = RMS_LOOPED_LIMIT;
std::string op_name = "vjp_rms";
if (axis_size > looped_limit) {
op_name += "_looped";
}
op_name += type_to_name(gx);
auto compute_encoder = d.get_command_encoder(s.index);
{
auto kernel = d.get_kernel(op_name);
MTL::Size grid_dims, group_dims;
if (axis_size <= looped_limit) {
size_t threadgroup_needed = (axis_size + n_reads - 1) / n_reads;
size_t simds_needed = (threadgroup_needed + simd_size - 1) / simd_size;
size_t threadgroup_size = simd_size * simds_needed;
assert(threadgroup_size <= kernel->maxTotalThreadsPerThreadgroup());
size_t n_threads = n_rows * threadgroup_size;
grid_dims = MTL::Size(n_threads, 1, 1);
group_dims = MTL::Size(threadgroup_size, 1, 1);
} else {
size_t threadgroup_size = kernel->maxTotalThreadsPerThreadgroup();
size_t n_threads = n_rows * threadgroup_size;
grid_dims = MTL::Size(n_threads, 1, 1);
group_dims = MTL::Size(threadgroup_size, 1, 1);
}
uint32_t w_stride = w.strides()[0];
compute_encoder->setComputePipelineState(kernel);
set_array_buffer(compute_encoder, x_in_gx ? gx : x, 0);
set_array_buffer(compute_encoder, w, 1);
set_array_buffer(
compute_encoder, g_in_gx ? gx : (g_in_gw ? gw_temp : g), 2);
set_array_buffer(compute_encoder, gx, 3);
set_array_buffer(compute_encoder, gw_temp, 4);
compute_encoder->setBytes(&eps_, sizeof(float), 5);
compute_encoder->setBytes(&axis_size, sizeof(int), 6);
compute_encoder->setBytes(&w_stride, sizeof(uint32_t), 7);
compute_encoder->dispatchThreads(grid_dims, group_dims);
}
ReductionPlan plan(
ReductionOpType::ContiguousStridedReduce, {n_rows}, {axis_size});
strided_reduce_general_dispatch(
gw_temp, gw, "sum", plan, {0}, compute_encoder, d, s);
d.get_command_buffer(s.index)->addCompletedHandler(
[copies](MTL::CommandBuffer*) mutable { copies.clear(); });
}
void LayerNorm::eval_gpu(
    const std::vector<array>& inputs,
    std::vector<array>& outputs) {
@@ -182,4 +290,124 @@ void LayerNorm::eval_gpu(
      [copies](MTL::CommandBuffer*) mutable { copies.clear(); });
}
void LayerNormVJP::eval_gpu(
const std::vector<array>& inputs,
std::vector<array>& outputs) {
auto& s = stream();
auto& d = metal::device(s.device);
// Ensure row contiguity. We could relax this step by checking that the array
// is contiguous (no broadcasts or holes) and that the input strides are the
// same as the cotangent strides but for now this is simpler.
std::vector<array> copies;
auto check_input = [&copies, &s](const array& x) {
if (x.flags().row_contiguous) {
return x;
}
array x_copy(x.shape(), x.dtype(), nullptr, {});
copy_gpu(x, x_copy, CopyType::General, s);
copies.push_back(x_copy);
return x_copy;
};
const array& x = check_input(inputs[0]);
const array& w = inputs[1];
const array& b = inputs[2];
const array& g = check_input(inputs[3]);
array& gx = outputs[0];
array& gw = outputs[1];
array& gb = outputs[2];
// Allocate space for the outputs
bool x_in_gx = false;
bool g_in_gx = false;
if (x.is_donatable()) {
gx.move_shared_buffer(x);
x_in_gx = true;
} else if (g.is_donatable()) {
gx.move_shared_buffer(g);
g_in_gx = true;
} else {
gx.set_data(allocator::malloc_or_wait(gx.nbytes()));
}
auto axis_size = static_cast<uint32_t>(x.shape().back());
int n_rows = x.data_size() / axis_size;
// Allocate a temporary to store the gradients for w and initialize the
// gradient accumulator to 0.
array gw_temp({n_rows, x.shape().back()}, gw.dtype(), nullptr, {});
bool g_in_gw = false;
if (!g_in_gx && g.is_donatable()) {
gw_temp.move_shared_buffer(g);
g_in_gw = true;
} else {
gw_temp.set_data(allocator::malloc_or_wait(gw_temp.nbytes()));
}
copies.push_back(gw_temp);
array zero(0, gw.dtype());
copy_gpu(zero, gw, CopyType::Scalar, s);
copy_gpu(zero, gb, CopyType::Scalar, s);
// Finish with the gradient for b in case we had a b
auto compute_encoder = d.get_command_encoder(s.index);
if (gb.ndim() == 1 && gb.size() == axis_size) {
ReductionPlan plan(
ReductionOpType::ContiguousStridedReduce, {n_rows}, {axis_size});
strided_reduce_general_dispatch(
g, gb, "sum", plan, {0}, compute_encoder, d, s);
}
const int simd_size = 32;
const int n_reads = RMS_N_READS;
const int looped_limit = RMS_LOOPED_LIMIT;
std::string op_name = "vjp_layer_norm";
if (axis_size > looped_limit) {
op_name += "_looped";
}
op_name += type_to_name(gx);
{
auto kernel = d.get_kernel(op_name);
MTL::Size grid_dims, group_dims;
if (axis_size <= looped_limit) {
size_t threadgroup_needed = (axis_size + n_reads - 1) / n_reads;
size_t simds_needed = (threadgroup_needed + simd_size - 1) / simd_size;
size_t threadgroup_size = simd_size * simds_needed;
assert(threadgroup_size <= kernel->maxTotalThreadsPerThreadgroup());
size_t n_threads = n_rows * threadgroup_size;
grid_dims = MTL::Size(n_threads, 1, 1);
group_dims = MTL::Size(threadgroup_size, 1, 1);
} else {
size_t threadgroup_size = kernel->maxTotalThreadsPerThreadgroup();
size_t n_threads = n_rows * threadgroup_size;
grid_dims = MTL::Size(n_threads, 1, 1);
group_dims = MTL::Size(threadgroup_size, 1, 1);
}
uint32_t w_stride = w.strides()[0];
compute_encoder->setComputePipelineState(kernel);
set_array_buffer(compute_encoder, x_in_gx ? gx : x, 0);
set_array_buffer(compute_encoder, w, 1);
set_array_buffer(
compute_encoder, g_in_gx ? gx : (g_in_gw ? gw_temp : g), 2);
set_array_buffer(compute_encoder, gx, 3);
set_array_buffer(compute_encoder, gw_temp, 4);
compute_encoder->setBytes(&eps_, sizeof(float), 5);
compute_encoder->setBytes(&axis_size, sizeof(int), 6);
compute_encoder->setBytes(&w_stride, sizeof(uint32_t), 7);
compute_encoder->dispatchThreads(grid_dims, group_dims);
}
if (gw.ndim() == 1 && gw.size() == axis_size) {
ReductionPlan plan(
ReductionOpType::ContiguousStridedReduce, {n_rows}, {axis_size});
strided_reduce_general_dispatch(
gw_temp, gw, "sum", plan, {0}, compute_encoder, d, s);
}
d.get_command_buffer(s.index)->addCompletedHandler(
[copies](MTL::CommandBuffer*) mutable { copies.clear(); });
}
} // namespace mlx::core::fast
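In both VJP eval_gpu paths above, the kernel writes one row of weight-gradient contributions per input row into gw_temp, and strided_reduce_general_dispatch then sums those rows into gw (and, for layer norm, sums g into gb). Roughly, in plain mlx ops (shapes made up for illustration; not code from this commit):

import mlx.core as mx

n_rows, axis_size = 4, 8
gw_temp = mx.ones((n_rows, axis_size))  # stands in for the per-row g * x_hat contributions
gw = mx.sum(gw_temp, axis=0)            # what the "sum" strided reduction produces
assert gw.size == axis_size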


@@ -4,10 +4,10 @@
#include <cassert>
#include <sstream>

#include "mlx/backend/common/reduce.h"
#include "mlx/backend/metal/copy.h"
#include "mlx/backend/metal/device.h"
#include "mlx/backend/metal/kernels/defines.h"
#include "mlx/backend/metal/reduce.h"
#include "mlx/backend/metal/utils.h"
#include "mlx/primitives.h"
#include "mlx/utils.h"
@@ -18,8 +18,6 @@ namespace mlx::core {
// Case wise reduce dispatch
//////////////////////////////////////////////////////////////////////

namespace {

inline auto safe_div(size_t n, size_t m) {
  return m == 0 ? 0 : (n + m - 1) / m;
}
@@ -534,8 +532,6 @@ void strided_reduce_general_dispatch(
  }
}

} // namespace

//////////////////////////////////////////////////////////////////////
// Main reduce dispatch
//////////////////////////////////////////////////////////////////////


@@ -0,0 +1,39 @@
// Copyright © 2023-2024 Apple Inc.

#pragma once

#include "mlx/backend/common/reduce.h"
#include "mlx/backend/metal/device.h"
#include "mlx/stream.h"

namespace mlx::core {

void all_reduce_dispatch(
    const array& in,
    array& out,
    const std::string& op_name,
    MTL::ComputeCommandEncoder* compute_encoder,
    metal::Device& d,
    const Stream& s);

void row_reduce_general_dispatch(
    const array& in,
    array& out,
    const std::string& op_name,
    const ReductionPlan& plan,
    const std::vector<int>& axes,
    MTL::ComputeCommandEncoder* compute_encoder,
    metal::Device& d,
    const Stream& s);

void strided_reduce_general_dispatch(
    const array& in,
    array& out,
    const std::string& op_name,
    const ReductionPlan& plan,
    const std::vector<int>& axes,
    MTL::ComputeCommandEncoder* compute_encoder,
    metal::Device& d,
    const Stream& s);

} // namespace mlx::core


@@ -63,7 +63,8 @@ void RoPE::eval_gpu(
  out_strides[2] = out.strides()[ndim - 1];

  std::ostringstream kname;
  kname << "rope_" << (forward_ ? "" : "vjp_")
        << (traditional_ ? "traditional_" : "") << type_to_name(in);
  auto kernel = d.get_kernel(kname.str());
  auto compute_encoder = d.get_command_encoder(s.index);


@@ -103,7 +103,9 @@ NO_GPU(Inverse)
namespace fast {
NO_GPU_MULTI(LayerNorm)
NO_GPU_MULTI(LayerNormVJP)
NO_GPU_MULTI(RMSNorm)
NO_GPU_MULTI(RMSNormVJP)
NO_GPU_MULTI(RoPE)
NO_GPU(ScaledDotProductAttention)
} // namespace fast


@@ -1,5 +1,8 @@
// Copyright © 2023-2024 Apple Inc.

#include <cassert>
#include <numeric>

#include "mlx/fast.h"
#include "mlx/fast_primitives.h"
#include "mlx/ops.h"
@@ -94,11 +97,69 @@ array rms_norm(
  return fallback({x, weight})[0];
}
std::vector<array> RMSNorm::vjp(
const std::vector<array>& primals,
const std::vector<array>& cotangents,
const std::vector<int>& argnums,
const std::vector<array>& outputs) {
assert(primals.size() == 2);
assert(outputs.size() == 1);
assert(cotangents.size() == 1);
auto s = stream();
auto fallback = [eps = eps_, s](const std::vector<array>& inputs) {
auto& x = inputs[0];
auto& w = inputs[1];
auto& g = inputs[2];
std::vector<array> vjps;
auto n = rsqrt(
add(mean(square(x, s), /* axis= */ -1, /* keepdims= */ true, s),
array(eps, x.dtype()),
s),
s);
auto n3 = power(n, array(3, x.dtype()), s);
// df/dx
auto gw = multiply(g, w, s);
auto t = mean(multiply(gw, x, s), /* axis= */ -1, /* keepdims= */ true, s);
t = multiply(multiply(x, t, s), n3, s);
vjps.push_back(subtract(multiply(gw, n, s), t, s));
// df/dw
std::vector<int> axes(g.ndim() - 1);
std::iota(axes.begin(), axes.end(), 0);
vjps.push_back(
sum(multiply(g, multiply(x, n, s), s), axes, /* keepdims= */ false, s));
return vjps;
};
auto vjps = array::make_arrays(
{primals[0].shape(), primals[1].shape()},
{primals[0].dtype(), primals[1].dtype()},
std::make_shared<RMSNormVJP>(s, fallback, eps_),
{primals[0], primals[1], cotangents[0]});
std::vector<array> returned_vjps;
for (auto& arg : argnums) {
returned_vjps.push_back(std::move(vjps[arg]));
}
return returned_vjps;
}
bool RMSNorm::is_equivalent(const Primitive& other) const {
  const RMSNorm& a_other = static_cast<const RMSNorm&>(other);
  return eps_ == a_other.eps_;
}

bool RMSNormVJP::is_equivalent(const Primitive& other) const {
  const RMSNormVJP& a_other = static_cast<const RMSNormVJP&>(other);
  return eps_ == a_other.eps_;
}
array layer_norm(
    const array& x,
    const std::optional<array>& weight,
@@ -176,11 +237,90 @@ array layer_norm(
  return fallback({x, passed_weight, passed_bias})[0];
}
std::vector<array> LayerNorm::vjp(
const std::vector<array>& primals,
const std::vector<array>& cotangents,
const std::vector<int>& argnums,
const std::vector<array>& outputs) {
assert(primals.size() == 3);
assert(outputs.size() == 1);
assert(cotangents.size() == 1);
auto s = stream();
auto fallback = [eps = eps_, s](const std::vector<array>& inputs) {
auto& x = inputs[0];
auto& w = inputs[1];
auto& b = inputs[2];
auto& g = inputs[3];
std::vector<array> vjps;
auto norm = number_of_elements(x, {-1}, true, x.dtype(), s);
auto sumx = sum(x, /* axis= */ -1, /* keepdims= */ true, s);
auto sumx2 = sum(square(x, s), /* axis= */ -1, /* keepdims= */ true, s);
auto mu = multiply(sumx, norm, s);
auto mu2 = multiply(sumx2, norm, s);
auto var = subtract(mu2, square(mu, s), s);
auto n = rsqrt(add(var, array(eps, x.dtype()), s));
auto n3 = power(n, array(3, x.dtype()), s);
auto x_c = subtract(x, mu, s);
// df/dx
auto wg = multiply(w, g, s);
auto sumwg =
multiply(sum(wg, /* axis= */ -1, /* keepdims= */ true, s), norm, s);
auto sumwgxc = multiply(
sum(multiply(wg, x_c, s), /* axis= */ -1, /* keepdims= */ true, s),
norm,
s);
auto t1 = multiply(multiply(x_c, sumwgxc, s), n3, s);
auto t2 = multiply(subtract(wg, sumwg, s), n, s);
vjps.push_back(subtract(t2, t1, s));
// df/dw
std::vector<int> axes(g.ndim() - 1);
std::iota(axes.begin(), axes.end(), 0);
if (w.ndim() == 0) {
vjps.push_back(zeros_like(w, s));
} else {
vjps.push_back(sum(
multiply(g, multiply(x_c, n, s), s), axes, /* keepdims= */ false, s));
}
// df/db
if (b.ndim() == 0) {
vjps.push_back(zeros_like(w, s));
} else {
vjps.push_back(sum(g, axes, /* keepdims= */ false, s));
}
return vjps;
};
auto vjps = array::make_arrays(
{primals[0].shape(), primals[1].shape(), primals[2].shape()},
{primals[0].dtype(), primals[1].dtype(), primals[2].dtype()},
std::make_shared<LayerNormVJP>(s, fallback, eps_),
{primals[0], primals[1], primals[2], cotangents[0]});
std::vector<array> returned_vjps;
for (auto& arg : argnums) {
returned_vjps.push_back(std::move(vjps[arg]));
}
return returned_vjps;
}
bool LayerNorm::is_equivalent(const Primitive& other) const {
  const LayerNorm& a_other = static_cast<const LayerNorm&>(other);
  return eps_ == a_other.eps_;
}

bool LayerNormVJP::is_equivalent(const Primitive& other) const {
  const LayerNormVJP& a_other = static_cast<const LayerNormVJP&>(other);
  return eps_ == a_other.eps_;
}
array rope(
    const array& x,
    int dims,
@@ -188,19 +328,16 @@ array rope(
    float base,
    float scale,
    int offset,
    bool forward,
    StreamOrDevice s) {
  if (x.ndim() < 3) {
    std::ostringstream msg;
    msg << "[rope] Input must have at least 3 dimensions but got input with "
        << x.ndim() << " dimensions.";
    throw std::invalid_argument(msg.str());
  }

  auto fallback = [dims, traditional, base, scale, offset, forward, s](
                      const std::vector<array>& inputs) {
    auto& shape = inputs[0].shape();
    int ndim = shape.size();
@@ -217,16 +354,39 @@ array rope(
    auto coss = cos(theta, s);
    auto sins = sin(theta, s);

    auto apply_rope = [forward, s](
                          const array& x1,
                          const array& x2,
                          const array& coss,
                          const array& sins) {
      std::vector<array> outs;
      if (forward) {
        outs.push_back(
            subtract(multiply(x1, coss, s), multiply(x2, sins, s), s));
        outs.push_back(add(multiply(x1, sins, s), multiply(x2, coss, s), s));
      } else {
        outs.push_back(add(multiply(x2, sins, s), multiply(x1, coss, s), s));
        outs.push_back(
            subtract(multiply(x2, coss, s), multiply(x1, sins, s), s));
      }
      return outs;
    };

    if (traditional) {
      auto x1 =
          slice(x, {0, 0, 0}, {x.shape(0), x.shape(1), dims}, {1, 1, 2}, s);
      auto x2 =
          slice(x, {0, 0, 1}, {x.shape(0), x.shape(1), dims}, {1, 1, 2}, s);
      auto outs = apply_rope(x1, x2, coss, sins);
      for (auto& o : outs) {
        o = expand_dims(o, 3, s);
      }
      auto out = concatenate(outs, 3, s);
      if (dims < x.shape(-1)) {
        out = reshape(out, {x.shape(0), x.shape(1), dims});
        out = concatenate({out, slice(x, {0, 0, dims}, x.shape(), s)}, 2, s);
      }
      return std::vector<array>{reshape(out, shape, s)};
    } else {
      auto out_s = x.shape();
      out_s.back() = half_dims;
@@ -234,9 +394,7 @@ array rope(
      out_s.back() = dims;
      auto x2 = slice(x, {0, 0, half_dims}, out_s, s);
      auto outs = apply_rope(x1, x2, coss, sins);
      if (dims < x.shape(-1)) {
        outs.push_back(slice(x, {0, 0, dims}, x.shape(), s));
      }
@@ -249,18 +407,54 @@ array rope(
        x.shape(),
        x.dtype(),
        std::make_shared<RoPE>(
            stream, fallback, dims, traditional, base, scale, offset, forward),
        {x});
  }
  return fallback({x})[0];
}
array rope(
const array& x,
int dims,
bool traditional,
float base,
float scale,
int offset,
StreamOrDevice s /* = {} */) {
return rope(x, dims, traditional, base, scale, offset, true, s);
}
std::vector<array> RoPE::vjp(
const std::vector<array>& primals,
const std::vector<array>& cotangents,
const std::vector<int>& argnums,
const std::vector<array>& outputs) {
auto s = stream();
auto fallback = [dims = dims_,
traditional = traditional_,
base = base_,
scale = scale_,
offset = offset_,
forward = forward_,
s](std::vector<array> inputs) {
return std::vector<array>{
rope(inputs[0], dims, traditional, base, scale, offset, !forward, s)};
};
return {array(
cotangents[0].shape(),
cotangents[0].dtype(),
std::make_shared<RoPE>(
s, fallback, dims_, traditional_, base_, scale_, offset_, !forward_),
cotangents)};
}
bool RoPE::is_equivalent(const Primitive& other) const {
  const RoPE& a_other = static_cast<const RoPE&>(other);
  return (
      dims_ == a_other.dims_ && base_ == a_other.base_ &&
      scale_ == a_other.scale_ && traditional_ == a_other.traditional_ &&
      offset_ == a_other.offset_ && forward_ == a_other.forward_);
}

/** Computes: O = softmax(Q @ K.T) @ V **/


@@ -48,6 +48,12 @@ class RMSNorm : public Custom {
  void eval_gpu(const std::vector<array>& inputs, std::vector<array>& outputs)
      override;

  std::vector<array> vjp(
      const std::vector<array>& primals,
      const std::vector<array>& cotangents,
      const std::vector<int>& argnums,
      const std::vector<array>& outputs) override;

  DEFINE_PRINT(RMSNorm)
  bool is_equivalent(const Primitive& other) const override;
@@ -56,6 +62,29 @@ class RMSNorm : public Custom {
  float eps_;
};

class RMSNormVJP : public Custom {
 public:
  RMSNormVJP(
      Stream stream,
      std::function<std::vector<array>(std::vector<array>)> fallback,
      float eps)
      : Custom(stream, fallback), eps_(eps){};

  void eval_cpu(const std::vector<array>& inputs, std::vector<array>& outputs)
      override {
    throw std::runtime_error("NYI");
  };
  void eval_gpu(const std::vector<array>& inputs, std::vector<array>& outputs)
      override;

  DEFINE_PRINT(RMSNormVJP)
  bool is_equivalent(const Primitive& other) const override;

 private:
  std::function<std::vector<array>(std::vector<array>)> fallback_;
  float eps_;
};

class LayerNorm : public Custom {
 public:
  LayerNorm(
@@ -71,6 +100,12 @@ class LayerNorm : public Custom {
  void eval_gpu(const std::vector<array>& inputs, std::vector<array>& outputs)
      override;

  std::vector<array> vjp(
      const std::vector<array>& primals,
      const std::vector<array>& cotangents,
      const std::vector<int>& argnums,
      const std::vector<array>& outputs) override;

  DEFINE_PRINT(LayerNorm)
  bool is_equivalent(const Primitive& other) const override;
@@ -79,6 +114,29 @@ class LayerNorm : public Custom {
  float eps_;
};

class LayerNormVJP : public Custom {
 public:
  LayerNormVJP(
      Stream stream,
      std::function<std::vector<array>(std::vector<array>)> fallback,
      float eps)
      : Custom(stream, fallback), eps_(eps){};

  void eval_cpu(const std::vector<array>& inputs, std::vector<array>& outputs)
      override {
    throw std::runtime_error("NYI");
  };
  void eval_gpu(const std::vector<array>& inputs, std::vector<array>& outputs)
      override;

  DEFINE_PRINT(LayerNormVJP)
  bool is_equivalent(const Primitive& other) const override;

 private:
  std::function<std::vector<array>(std::vector<array>)> fallback_;
  float eps_;
};

class RoPE : public Custom {
 public:
  RoPE(
@@ -88,13 +146,15 @@ class RoPE : public Custom {
      bool traditional,
      float base,
      float scale,
      int offset,
      bool forward)
      : Custom(stream, fallback),
        dims_(dims),
        traditional_(traditional),
        base_(base),
        scale_(scale),
        offset_(offset),
        forward_(forward){};

  void eval_cpu(const std::vector<array>& inputs, std::vector<array>& outputs)
      override {
@@ -103,6 +163,12 @@ class RoPE : public Custom {
  void eval_gpu(const std::vector<array>& inputs, std::vector<array>& outputs)
      override;

  std::vector<array> vjp(
      const std::vector<array>& primals,
      const std::vector<array>& cotangents,
      const std::vector<int>& argnums,
      const std::vector<array>& outputs) override;

  DEFINE_PRINT(RoPE)
  bool is_equivalent(const Primitive& other) const override;
@@ -113,6 +179,7 @@ class RoPE : public Custom {
  float base_;
  float scale_;
  int offset_;
  bool forward_;
};

class ScaledDotProductAttention : public Custom {
@@ -126,7 +193,7 @@ class ScaledDotProductAttention : public Custom {
  void eval_cpu(const std::vector<array>& inputs, std::vector<array>& outputs)
      override {
    throw std::runtime_error("NYI");
  };
  void eval_gpu(const std::vector<array>& inputs, std::vector<array>& outputs)


@@ -16,11 +16,14 @@ def rope_orig(x, dims, traditional, base, scale, offset):
    theta = mx.reshape(positions, (-1, 1)) * mx.reshape(freqs, (1, -1))
    costheta, sintheta = mx.cos(theta), mx.sin(theta)
    if traditional:
        x1 = x[..., :dims:2]
        x2 = x[..., 1:dims:2]
        rx1 = x1 * costheta - x2 * sintheta
        rx2 = x1 * sintheta + x2 * costheta
        rx = mx.concatenate([rx1[..., None], rx2[..., None]], axis=-1)
        if dims < x.shape[-1]:
            rx = mx.reshape(rx, (*x.shape[:-1], dims))
            rx = mx.concatenate([rx, x[..., dims:]], axis=-1)
        return mx.reshape(rx, x.shape)
    else:
        x1 = x[..., : dims // 2]
@@ -34,6 +37,26 @@ def rope_orig(x, dims, traditional, base, scale, offset):
        return rx


def rms_norm(x, weight, eps):
    x = x.astype(mx.float32)
    x = x * mx.rsqrt(x.square().mean(-1, keepdims=True) + eps)
    return weight * x.astype(weight.dtype)


def layer_norm(x, weight, bias, eps):
    ot = x.dtype
    x = x.astype(mx.float32)
    mean = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    x = (x - mean) * mx.rsqrt(var + eps)
    x = x.astype(ot)
    if weight is not None:
        x = x * weight
    if bias is not None:
        x = x + bias
    return x


class TestFast(mlx_tests.MLXTestCase):
    def test_rope(self):
        T = 4
@@ -115,12 +138,34 @@ class TestFast(mlx_tests.MLXTestCase):
            )
            self.assertLess(mx.abs(rx - rx_fast).max(), tolerances[dtype])

    def test_rope_grad(self):
        D = 32
        defaults = (D, 10000.0, 1.0, 0, False)
        for dims in (D, D // 2):
            for traditional in (True, False):
                _, base, scale, offset, _ = defaults
                f1 = lambda x, y: (
                    rope_orig(x, dims, traditional, base, scale, offset) * y
                ).sum()
                f2 = lambda x, y: (
                    mx.fast.rope(
                        x,
                        dims,
                        traditional=traditional,
                        base=base,
                        scale=scale,
                        offset=offset,
                    )
                    * y
                ).sum()

                x = mx.random.uniform(shape=(2, 100, D))
                y = mx.random.uniform(shape=(2, 100, D))
                g1 = mx.grad(f1)(x, y)
                g2 = mx.grad(f2)(x, y)
                self.assertLess(mx.abs(g1 - g2).max(), 1e-5)

    def test_rms_norm(self):
        # Per dtype absolute tolerance
        tolerances = {mx.float32: 1e-6, mx.float16: 1e-3, mx.bfloat16: 1e-2}
@@ -166,20 +211,42 @@ class TestFast(mlx_tests.MLXTestCase):
        rx_fast = mx.fast.rms_norm(x, weight, eps)
        self.assertLess(mx.abs(rx - rx_fast).max(), 1e-6)

    def test_rms_norm_grad(self):
        D = 32
        eps = 1e-5
        f1 = lambda x, w, y: (rms_norm(x, w, eps) * y).sum()
        f2 = lambda x, w, y: (mx.fast.rms_norm(x, w, eps) * y).sum()

        x = mx.random.uniform(shape=(8, 100, D))
        w = mx.random.uniform(shape=(D,))
        y = mx.random.uniform(shape=(8, 100, D))
        gx1, gw1 = mx.grad(f1, argnums=(0, 1))(x, w, y)
        gx2, gw2 = mx.grad(f2, argnums=(0, 1))(x, w, y)
        self.assertLess(mx.abs(gx1 - gx2).max(), 1e-5)
        self.assertLess(mx.abs(gw1 - gw2).max() / mx.abs(gw1).mean(), 1e-5)

        D = 8192
        x = mx.random.uniform(shape=(2, 2, D))
        w = mx.random.uniform(shape=(D,))
        y = mx.random.uniform(shape=(2, 2, D))
        gx1, gw1 = mx.grad(f1, argnums=(0, 1))(x, w, y)
        gx2, gw2 = mx.grad(f2, argnums=(0, 1))(x, w, y)
        self.assertLess(mx.abs(gx1 - gx2).max(), 1e-5)
        self.assertLess(mx.abs(gw1 - gw2).max() / mx.abs(gw1).mean(), 1e-5)

        def gf(f):
            def inner(x, w, y):
                gx, gw = mx.grad(f, argnums=(0, 1))(x, w, y)
                return (gx + gw).sum()

            return inner

        gx1, gw1 = mx.grad(gf(f1), argnums=(0, 1))(x, w, y)
        gx2, gw2 = mx.grad(gf(f2), argnums=(0, 1))(x, w, y)
        self.assertLess(mx.abs(gx1 - gx2).max(), 1e-5)
        self.assertLess(mx.abs(gw1 - gw2).max() / mx.abs(gw1).mean(), 1e-5)

    def test_layer_norm(self):
        # Per dtype absolute tolerance
        tolerances = {mx.float32: 3e-6, mx.float16: 3e-3, mx.bfloat16: 3e-2}
@@ -265,6 +332,49 @@ class TestFast(mlx_tests.MLXTestCase):
            rx_fast = mx.fast.layer_norm(x, None, None, eps)
            self.assertLess(mx.abs(rx - rx_fast).max(), tolerances[dtype])

    def test_layer_norm_grad(self):
        D = 32
        eps = 1e-5
        f1 = lambda x, w, b, y: (layer_norm(x, w, b, eps) * y).sum()
        f2 = lambda x, w, b, y: (mx.fast.layer_norm(x, w, b, eps) * y).sum()

        x = mx.random.uniform(shape=(8, 100, D))
        w = mx.random.uniform(shape=(D,))
        b = mx.random.uniform(shape=(D,))
        y = mx.random.uniform(shape=(8, 100, D))

        gx1, gw1, gb1 = mx.grad(f1, argnums=(0, 1, 2))(x, w, b, y)
        gx2, gw2, gb2 = mx.grad(f2, argnums=(0, 1, 2))(x, w, b, y)
        self.assertLess(mx.abs(gx1 - gx2).max(), 1e-5)
        self.assertLess(mx.abs(gw1 - gw2).max() / mx.abs(gw1).mean(), 1e-5)
        self.assertLess(mx.abs(gb1 - gb2).max() / mx.abs(gb1).mean(), 1e-5)

        D = 8192
        x = mx.random.uniform(shape=(8, 100, D))
        w = mx.random.uniform(shape=(D,))
        b = mx.random.uniform(shape=(D,))
        y = mx.random.uniform(shape=(8, 100, D))

        gx1, gw1, gb1 = mx.grad(f1, argnums=(0, 1, 2))(x, w, b, y)
        gx2, gw2, gb2 = mx.grad(f2, argnums=(0, 1, 2))(x, w, b, y)
        self.assertLess(mx.abs(gx1 - gx2).max(), 1e-5)
        self.assertLess(mx.abs(gw1 - gw2).max() / mx.abs(gw1).mean(), 1e-5)
        self.assertLess(mx.abs(gb1 - gb2).max() / mx.abs(gb1).mean(), 1e-5)

        def gf(f):
            def inner(x, w, b, y):
                gx, gw, gb = mx.grad(f, argnums=(0, 1, 2))(x, w, b, y)
                return ((gx + gw + gb) * y).sum()

            return inner

        gx1, gw1, gb1 = mx.grad(gf(f1), argnums=(0, 1, 2))(x, w, b, y)
        gx2, gw2, gb2 = mx.grad(gf(f2), argnums=(0, 1, 2))(x, w, b, y)
        self.assertLess(mx.abs(gx1 - gx2).max() / mx.abs(gx1).mean(), 1e-5)
        self.assertLess(mx.abs(gw1 - gw2).max() / mx.abs(gw1).mean(), 1e-5)
        self.assertLess(mx.abs(gb1).max(), 1e-9)
        self.assertLess(mx.abs(gb2).max(), 1e-9)

    def test_fast_transforms(self):
        x = mx.random.uniform(shape=(2, 2, 8))