Fix copy donation and add partial rope (#881)

This commit is contained in:
Angelos Katharopoulos 2024-03-22 17:28:26 -07:00 committed by GitHub
parent 8e5a5a1ccd
commit 6ee1112f30
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 42 additions and 20 deletions

View File

@ -6,21 +6,21 @@ from time_utils import time_fn
def time_rope(): def time_rope():
rope = nn.RoPE(4096) rope = nn.RoPE(64)
# vec # vec
x = mx.random.uniform(shape=(1, 4096)).astype(mx.float16) x = mx.random.uniform(shape=(1, 32, 1, 128)).astype(mx.float16)
mx.eval(x) mx.eval(x)
def rope_vec(x): def rope_vec(x):
for _ in range(32): for _ in range(32):
x = rope(x) x = rope(x, offset=100)
return x return x
time_fn(rope_vec, x) time_fn(rope_vec, x)
# matrix # matrix
x = mx.random.uniform(shape=(1024, 4096)).astype(mx.float16) x = mx.random.uniform(shape=(1, 32, 1024, 128)).astype(mx.float16)
mx.eval(x) mx.eval(x)
def rope_mat(x): def rope_mat(x):

View File

@ -12,8 +12,15 @@ namespace mlx::core {
void copy_gpu(const array& in, array& out, CopyType ctype, const Stream& s) { void copy_gpu(const array& in, array& out, CopyType ctype, const Stream& s) {
if (ctype == CopyType::Vector) { if (ctype == CopyType::Vector) {
// If the input is donateable, we are doing a vector copy and the types
// have the same size, then the input buffer can hold the output.
if (in.is_donatable() && in.itemsize() == out.itemsize()) { if (in.is_donatable() && in.itemsize() == out.itemsize()) {
out.move_shared_buffer(in); out.move_shared_buffer(in);
// If the output has the same type as the input then there is nothing to
// copy, just use the buffer.
if (in.dtype() == out.dtype()) {
return;
}
} else { } else {
out.set_data( out.set_data(
allocator::malloc_or_wait(in.data_size() * out.itemsize()), allocator::malloc_or_wait(in.data_size() * out.itemsize()),

View File

@ -10,6 +10,7 @@ template <typename T, bool traditional>
const device T *in [[buffer(0)]], const device T *in [[buffer(0)]],
device T * out [[buffer(1)]], device T * out [[buffer(1)]],
constant const size_t strides[3], constant const size_t strides[3],
constant const size_t out_strides[3],
constant const int& offset, constant const int& offset,
constant const float& base, constant const float& base,
constant const float& scale, constant const float& scale,
@ -19,13 +20,13 @@ template <typename T, bool traditional>
uint in_index_1, in_index_2; uint in_index_1, in_index_2;
uint out_index_1, out_index_2; uint out_index_1, out_index_2;
if (traditional) { if (traditional) {
out_index_1 = 2 * (pos.x + grid.x * (pos.y + grid.y * pos.z)); out_index_1 = 2 * pos.x * out_strides[2] + pos.y * out_strides[1] + pos.z * out_strides[0];
out_index_2 = out_index_1 + 1; out_index_2 = out_index_1 + 1;
in_index_1 = 2 * pos.x * strides[2] + pos.y * strides[1] + pos.z * strides[0]; in_index_1 = 2 * pos.x * strides[2] + pos.y * strides[1] + pos.z * strides[0];
in_index_2 = in_index_1 + strides[2]; in_index_2 = in_index_1 + strides[2];
} else { } else {
out_index_1 = pos.x + 2*(grid.x * (pos.y + grid.y * pos.z)); out_index_1 = pos.x * out_strides[2] + pos.y * out_strides[1] + pos.z * out_strides[0];
out_index_2 = out_index_1 + grid.x; out_index_2 = out_index_1 + grid.x * out_strides[2];
in_index_1 = pos.x * strides[2] + pos.y * strides[1] + pos.z * strides[0]; in_index_1 = pos.x * strides[2] + pos.y * strides[1] + pos.z * strides[0];
in_index_2 = in_index_1 + grid.x * strides[2]; in_index_2 = in_index_1 + grid.x * strides[2];
} }
@ -54,6 +55,7 @@ template <typename T, bool traditional>
const device type* in [[buffer(0)]], \ const device type* in [[buffer(0)]], \
device type* out [[buffer(1)]], \ device type* out [[buffer(1)]], \
constant const size_t strides[3], \ constant const size_t strides[3], \
constant const size_t out_strides[3], \
constant const int& offset, \ constant const int& offset, \
constant const float& base, \ constant const float& base, \
constant const float& scale, \ constant const float& scale, \

View File

@ -16,18 +16,24 @@ void RoPE::eval_gpu(
if (in.ndim() < 3) { if (in.ndim() < 3) {
throw std::runtime_error("[RoPE] Input must have at least 3 dimensions"); throw std::runtime_error("[RoPE] Input must have at least 3 dimensions");
} }
if (dims_ != in.shape(-1)) {
throw std::runtime_error("[RoPE] Partial RoPE application not supported");
}
auto& s = out.primitive().stream(); auto& s = out.primitive().stream();
auto& d = metal::device(s.device); auto& d = metal::device(s.device);
size_t strides[3]; size_t strides[3];
size_t out_strides[3];
bool donated = false; bool donated = false;
int ndim = in.ndim(); int ndim = in.ndim();
size_t mat_size = in.shape()[ndim - 2] * in.shape()[ndim - 1]; size_t mat_size = in.shape(-2) * in.shape(-1);
if (in.flags().row_contiguous) { if (dims_ < in.shape(-1)) {
donated = true;
auto ctype =
(in.flags().row_contiguous) ? CopyType::Vector : CopyType::General;
copy_gpu(in, out, ctype, s);
strides[0] = mat_size;
strides[1] = out.strides()[ndim - 2];
strides[2] = out.strides()[ndim - 1];
} else if (in.flags().row_contiguous) {
if (in.is_donatable()) { if (in.is_donatable()) {
donated = true; donated = true;
out.move_shared_buffer(in); out.move_shared_buffer(in);
@ -52,6 +58,9 @@ void RoPE::eval_gpu(
strides[1] = out.strides()[ndim - 2]; strides[1] = out.strides()[ndim - 2];
strides[2] = out.strides()[ndim - 1]; strides[2] = out.strides()[ndim - 1];
} }
out_strides[0] = mat_size;
out_strides[1] = out.strides()[ndim - 2];
out_strides[2] = out.strides()[ndim - 1];
std::ostringstream kname; std::ostringstream kname;
kname << "rope_" << (traditional_ ? "traditional_" : "") << type_to_name(in); kname << "rope_" << (traditional_ ? "traditional_" : "") << type_to_name(in);
@ -63,12 +72,13 @@ void RoPE::eval_gpu(
set_array_buffer(compute_encoder, donated ? out : in, 0); set_array_buffer(compute_encoder, donated ? out : in, 0);
set_array_buffer(compute_encoder, out, 1); set_array_buffer(compute_encoder, out, 1);
compute_encoder->setBytes(&strides, 3 * sizeof(size_t), 2); compute_encoder->setBytes(&strides, 3 * sizeof(size_t), 2);
compute_encoder->setBytes(&offset_, sizeof(int), 3); compute_encoder->setBytes(&out_strides, 3 * sizeof(size_t), 3);
compute_encoder->setBytes(&base, sizeof(float), 4); compute_encoder->setBytes(&offset_, sizeof(int), 4);
compute_encoder->setBytes(&scale_, sizeof(float), 5); compute_encoder->setBytes(&base, sizeof(float), 5);
compute_encoder->setBytes(&scale_, sizeof(float), 6);
int dim0 = in.shape()[ndim - 1] / 2; int dim0 = dims_ / 2;
int dim1 = in.shape()[ndim - 2]; int dim1 = in.shape(-2);
int dim2 = in.size() / mat_size; int dim2 = in.size() / mat_size;
auto group_dims = get_block_dims(dim0, dim1, dim2); auto group_dims = get_block_dims(dim0, dim1, dim2);
auto grid_dims = MTL::Size(dim0, dim1, dim2); auto grid_dims = MTL::Size(dim0, dim1, dim2);

View File

@ -244,7 +244,7 @@ array rope(
} }
}; };
auto stream = to_stream(s); auto stream = to_stream(s);
if (stream.device == Device::gpu && x.shape(-1) == dims) { if (stream.device == Device::gpu) {
return array( return array(
x.shape(), x.shape(),
x.dtype(), x.dtype(),

View File

@ -2590,8 +2590,11 @@ std::vector<array> Scatter::vjp(
break; break;
case Scatter::Max: case Scatter::Max:
case Scatter::Min: { case Scatter::Min: {
auto mask = where(result == values, array({1}), array({0})); vjps.push_back(where(
vjps.push_back(multiply(cotangents[0], mask)); equal(result, values, stream()),
cotangents[0],
array(0, cotangents[0].dtype()),
stream()));
break; break;
} }
default: default: