Fix copy donation and add partial rope (#881)

Commit 6ee1112f30 (parent 8e5a5a1ccd), committed by GitHub
Author: Angelos Katharopoulos
Date: 2024-03-22 17:28:26 -07:00
6 changed files with 42 additions and 20 deletions


@@ -6,21 +6,21 @@ from time_utils import time_fn

 def time_rope():
-    rope = nn.RoPE(4096)
+    rope = nn.RoPE(64)

     # vec
-    x = mx.random.uniform(shape=(1, 4096)).astype(mx.float16)
+    x = mx.random.uniform(shape=(1, 32, 1, 128)).astype(mx.float16)
     mx.eval(x)

     def rope_vec(x):
         for _ in range(32):
-            x = rope(x)
+            x = rope(x, offset=100)
         return x

     time_fn(rope_vec, x)

     # matrix
-    x = mx.random.uniform(shape=(1024, 4096)).astype(mx.float16)
+    x = mx.random.uniform(shape=(1, 32, 1024, 128)).astype(mx.float16)
     mx.eval(x)

     def rope_mat(x):
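The updated benchmark exercises the new partial-RoPE path: the inputs now carry an explicit (batch, heads, sequence, head_dim) shape, and nn.RoPE(64) rotates only 64 of the 128 features. A minimal standalone version of the vector case, using only the shapes and calls that appear in the benchmark above:

    import mlx.core as mx
    import mlx.nn as nn

    rope = nn.RoPE(64)  # rope dims (64) < head_dim (128): a partial RoPE
    x = mx.random.uniform(shape=(1, 32, 1, 128)).astype(mx.float16)
    y = rope(x, offset=100)  # offset shifts the position index, as in decoding
    mx.eval(y)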


@@ -12,8 +12,15 @@ namespace mlx::core {

 void copy_gpu(const array& in, array& out, CopyType ctype, const Stream& s) {
   if (ctype == CopyType::Vector) {
+    // If the input is donatable, we are doing a vector copy, and the types
+    // have the same size, then the input buffer can hold the output.
     if (in.is_donatable() && in.itemsize() == out.itemsize()) {
       out.move_shared_buffer(in);
+
+      // If the output has the same type as the input then there is nothing
+      // to copy; just use the buffer.
+      if (in.dtype() == out.dtype()) {
+        return;
+      }
     } else {
       out.set_data(
           allocator::malloc_or_wait(in.data_size() * out.itemsize()),
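The early return above is the donation fix: a donated vector copy moves the input buffer into the output, and when the dtypes also match the bytes are already correct, so the copy kernel previously launched here only copied the buffer onto itself. A rough Python model of the fixed control flow; the names (is_donatable, buffer, allocate, launch_copy_kernel) are illustrative stand-ins for the MLX internals, not its API:

    def vector_copy(src, dst):
        if src.is_donatable and src.itemsize == dst.itemsize:
            dst.buffer = src.buffer   # move_shared_buffer: reuse the input buffer
            if src.dtype == dst.dtype:
                return                # identical types: nothing left to copy
        else:
            dst.buffer = allocate(src.data_size * dst.itemsize)
        launch_copy_kernel(src, dst)  # element-wise cast into dst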


@@ -10,6 +10,7 @@ template <typename T, bool traditional>
     const device T *in [[buffer(0)]],
     device T * out [[buffer(1)]],
     constant const size_t strides[3],
+    constant const size_t out_strides[3],
     constant const int& offset,
     constant const float& base,
     constant const float& scale,
@@ -19,13 +20,13 @@ template <typename T, bool traditional>
   uint in_index_1, in_index_2;
   uint out_index_1, out_index_2;
   if (traditional) {
-    out_index_1 = 2 * (pos.x + grid.x * (pos.y + grid.y * pos.z));
+    out_index_1 = 2 * pos.x * out_strides[2] + pos.y * out_strides[1] + pos.z * out_strides[0];
     out_index_2 = out_index_1 + 1;
     in_index_1 = 2 * pos.x * strides[2] + pos.y * strides[1] + pos.z * strides[0];
     in_index_2 = in_index_1 + strides[2];
   } else {
-    out_index_1 = pos.x + 2*(grid.x * (pos.y + grid.y * pos.z));
-    out_index_2 = out_index_1 + grid.x;
+    out_index_1 = pos.x * out_strides[2] + pos.y * out_strides[1] + pos.z * out_strides[0];
+    out_index_2 = out_index_1 + grid.x * out_strides[2];
     in_index_1 = pos.x * strides[2] + pos.y * strides[1] + pos.z * strides[0];
     in_index_2 = in_index_1 + grid.x * strides[2];
   }
@@ -54,6 +55,7 @@ template <typename T, bool traditional>
       const device type* in [[buffer(0)]], \
       device type* out [[buffer(1)]], \
       constant const size_t strides[3], \
+      constant const size_t out_strides[3], \
       constant const int& offset, \
       constant const float& base, \
       constant const float& scale, \
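Threading out_strides through the kernel decouples where results are written from the launch grid: the grid covers only the rotated features, while the output row may be wider (the partial-RoPE case, where the copied tail must not be overwritten). The same index arithmetic, transcribed to Python for clarity; pos and grid stand in for the Metal thread position and grid size, and the strides are ordered [matrix, row, feature]:

    def rope_indices(pos, grid, strides, out_strides, traditional):
        if traditional:  # adjacent features (2i, 2i + 1) form a rotated pair
            out_1 = 2 * pos[0] * out_strides[2] + pos[1] * out_strides[1] + pos[2] * out_strides[0]
            out_2 = out_1 + 1
            in_1 = 2 * pos[0] * strides[2] + pos[1] * strides[1] + pos[2] * strides[0]
            in_2 = in_1 + strides[2]
        else:  # features (i, i + grid.x) form a rotated pair
            out_1 = pos[0] * out_strides[2] + pos[1] * out_strides[1] + pos[2] * out_strides[0]
            out_2 = out_1 + grid[0] * out_strides[2]
            in_1 = pos[0] * strides[2] + pos[1] * strides[1] + pos[2] * strides[0]
            in_2 = in_1 + grid[0] * strides[2]
        return in_1, in_2, out_1, out_2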


@@ -16,18 +16,24 @@ void RoPE::eval_gpu(
   if (in.ndim() < 3) {
     throw std::runtime_error("[RoPE] Input must have at least 3 dimensions");
   }
-  if (dims_ != in.shape(-1)) {
-    throw std::runtime_error("[RoPE] Partial RoPE application not supported");
-  }

   auto& s = out.primitive().stream();
   auto& d = metal::device(s.device);

   size_t strides[3];
+  size_t out_strides[3];
   bool donated = false;
   int ndim = in.ndim();
-  size_t mat_size = in.shape()[ndim - 2] * in.shape()[ndim - 1];
-  if (in.flags().row_contiguous) {
+  size_t mat_size = in.shape(-2) * in.shape(-1);
+  if (dims_ < in.shape(-1)) {
+    donated = true;
+    auto ctype =
+        (in.flags().row_contiguous) ? CopyType::Vector : CopyType::General;
+    copy_gpu(in, out, ctype, s);
+    strides[0] = mat_size;
+    strides[1] = out.strides()[ndim - 2];
+    strides[2] = out.strides()[ndim - 1];
+  } else if (in.flags().row_contiguous) {
     if (in.is_donatable()) {
       donated = true;
       out.move_shared_buffer(in);
@@ -52,6 +58,9 @@ void RoPE::eval_gpu(
     strides[1] = out.strides()[ndim - 2];
     strides[2] = out.strides()[ndim - 1];
   }
+  out_strides[0] = mat_size;
+  out_strides[1] = out.strides()[ndim - 2];
+  out_strides[2] = out.strides()[ndim - 1];

   std::ostringstream kname;
   kname << "rope_" << (traditional_ ? "traditional_" : "") << type_to_name(in);
@@ -63,12 +72,13 @@ void RoPE::eval_gpu(
   set_array_buffer(compute_encoder, donated ? out : in, 0);
   set_array_buffer(compute_encoder, out, 1);
   compute_encoder->setBytes(&strides, 3 * sizeof(size_t), 2);
-  compute_encoder->setBytes(&offset_, sizeof(int), 3);
-  compute_encoder->setBytes(&base, sizeof(float), 4);
-  compute_encoder->setBytes(&scale_, sizeof(float), 5);
+  compute_encoder->setBytes(&out_strides, 3 * sizeof(size_t), 3);
+  compute_encoder->setBytes(&offset_, sizeof(int), 4);
+  compute_encoder->setBytes(&base, sizeof(float), 5);
+  compute_encoder->setBytes(&scale_, sizeof(float), 6);

-  int dim0 = in.shape()[ndim - 1] / 2;
-  int dim1 = in.shape()[ndim - 2];
+  int dim0 = dims_ / 2;
+  int dim1 = in.shape(-2);
   int dim2 = in.size() / mat_size;
   auto group_dims = get_block_dims(dim0, dim1, dim2);
   auto grid_dims = MTL::Size(dim0, dim1, dim2);
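The new first branch is the heart of partial application: when dims_ is smaller than the feature dimension, the whole input is first copied into the output, and the kernel (launched with dim0 = dims_ / 2 threads in x) then rotates only the leading dims_ features in place, leaving the copied tail untouched. A hypothetical high-level model of that behavior in terms of public ops; rotate stands in for the fused kernel:

    def rope_partial(x, dims, rotate):
        head, tail = x[..., :dims], x[..., dims:]
        return mx.concatenate([rotate(head), tail], axis=-1)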


@@ -244,7 +244,7 @@ array rope(
     }
   };
   auto stream = to_stream(s);
-  if (stream.device == Device::gpu && x.shape(-1) == dims) {
+  if (stream.device == Device::gpu) {
     return array(
         x.shape(),
         x.dtype(),
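With the x.shape(-1) == dims condition dropped, partial RoPE now also dispatches to the fused GPU primitive instead of the composed fallback. Assuming the benchmark's shapes (which path is taken is an inference from this diff, not something the code reports):

    x = mx.random.uniform(shape=(1, 32, 1, 128))
    nn.RoPE(128)(x)  # dims == head_dim: fused kernel before and after this change
    nn.RoPE(64)(x)   # dims < head_dim: fused kernel only after this change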


@@ -2590,8 +2590,11 @@ std::vector<array> Scatter::vjp(
       break;
     case Scatter::Max:
     case Scatter::Min: {
-      auto mask = where(result == values, array({1}), array({0}));
-      vjps.push_back(multiply(cotangents[0], mask));
+      vjps.push_back(where(
+          equal(result, values, stream()),
+          cotangents[0],
+          array(0, cotangents[0].dtype()),
+          stream()));
       break;
     }
     default:
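The rewritten Max/Min case builds the gradient directly with where, on the primitive's stream and with a zero of the cotangent's dtype, rather than multiplying by a separately constructed {0, 1} mask; the cotangent flows only to positions whose scattered value actually achieved the max (or min). The same rule, sketched with the public Python API:

    import mlx.core as mx

    result = mx.array([3.0, 5.0])     # values after the scatter-max
    values = mx.array([3.0, 4.0])     # values that were scattered
    cotangent = mx.array([1.0, 1.0])
    vjp = mx.where(mx.equal(result, values), cotangent,
                   mx.array(0, dtype=cotangent.dtype))
    # vjp == [1.0, 0.0]: only the first scattered value matched the result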