Mirror of https://github.com/ml-explore/mlx.git (synced 2025-06-25 01:41:17 +08:00)
Commit: Fix copy donation and add partial rope (#881)
Commit hash: 6ee1112f30
Parent: 8e5a5a1ccd
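The "partial rope" part of this change removes the old restriction that nn.RoPE must cover the full feature dimension: when the requested dims are smaller than the last axis, only the first dims features are rotated and the rest are carried through via the new copy path (see the removed "[RoPE] Partial RoPE application not supported" error and the new dims_ < in.shape(-1) branch below). A minimal usage sketch in Python, assuming an mlx build that includes this commit; the shapes and the pass-through check are illustrative, not part of the commit itself:

import mlx.core as mx
import mlx.nn as nn

# Rotate only the first 64 of 128 features per head (partial RoPE).
rope = nn.RoPE(64)
x = mx.random.uniform(shape=(1, 32, 1, 128)).astype(mx.float16)
y = rope(x, offset=100)
mx.eval(y)

# The un-rotated tail of the feature dimension should pass through unchanged.
print(mx.array_equal(x[..., 64:], y[..., 64:]))  # expected: True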
@@ -6,21 +6,21 @@ from time_utils import time_fn


 def time_rope():
-    rope = nn.RoPE(4096)
+    rope = nn.RoPE(64)

     # vec
-    x = mx.random.uniform(shape=(1, 4096)).astype(mx.float16)
+    x = mx.random.uniform(shape=(1, 32, 1, 128)).astype(mx.float16)
     mx.eval(x)

     def rope_vec(x):
         for _ in range(32):
-            x = rope(x)
+            x = rope(x, offset=100)
         return x

     time_fn(rope_vec, x)

     # matrix
-    x = mx.random.uniform(shape=(1024, 4096)).astype(mx.float16)
+    x = mx.random.uniform(shape=(1, 32, 1024, 128)).astype(mx.float16)
     mx.eval(x)

     def rope_mat(x):
@@ -12,8 +12,15 @@ namespace mlx::core {

 void copy_gpu(const array& in, array& out, CopyType ctype, const Stream& s) {
   if (ctype == CopyType::Vector) {
+    // If the input is donateable, we are doing a vector copy and the types
+    // have the same size, then the input buffer can hold the output.
     if (in.is_donatable() && in.itemsize() == out.itemsize()) {
       out.move_shared_buffer(in);
+      // If the output has the same type as the input then there is nothing to
+      // copy, just use the buffer.
+      if (in.dtype() == out.dtype()) {
+        return;
+      }
     } else {
       out.set_data(
           allocator::malloc_or_wait(in.data_size() * out.itemsize()),
@@ -10,6 +10,7 @@ template <typename T, bool traditional>
     const device T *in [[buffer(0)]],
     device T * out [[buffer(1)]],
     constant const size_t strides[3],
+    constant const size_t out_strides[3],
     constant const int& offset,
     constant const float& base,
     constant const float& scale,
@@ -19,13 +20,13 @@ template <typename T, bool traditional>
   uint in_index_1, in_index_2;
   uint out_index_1, out_index_2;
   if (traditional) {
-    out_index_1 = 2 * (pos.x + grid.x * (pos.y + grid.y * pos.z));
+    out_index_1 = 2 * pos.x * out_strides[2] + pos.y * out_strides[1] + pos.z * out_strides[0];
     out_index_2 = out_index_1 + 1;
     in_index_1 = 2 * pos.x * strides[2] + pos.y * strides[1] + pos.z * strides[0];
     in_index_2 = in_index_1 + strides[2];
   } else {
-    out_index_1 = pos.x + 2*(grid.x * (pos.y + grid.y * pos.z));
-    out_index_2 = out_index_1 + grid.x;
+    out_index_1 = pos.x * out_strides[2] + pos.y * out_strides[1] + pos.z * out_strides[0];
+    out_index_2 = out_index_1 + grid.x * out_strides[2];
     in_index_1 = pos.x * strides[2] + pos.y * strides[1] + pos.z * strides[0];
     in_index_2 = in_index_1 + grid.x * strides[2];
   }
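With a separate out_strides triple, the kernel can address the output independently of the launch grid, which is what lets the partially rotated slice land in the right place of a freshly written output buffer. A plain-Python transcription of the new index arithmetic, with made-up position and stride values purely for illustration:

# Mirrors the kernel's index math; pos is the (x, y, z) thread position.
def rope_indices(pos, strides, out_strides, grid_x, traditional):
    x, y, z = pos
    if traditional:
        out_1 = 2 * x * out_strides[2] + y * out_strides[1] + z * out_strides[0]
        out_2 = out_1 + 1
        in_1 = 2 * x * strides[2] + y * strides[1] + z * strides[0]
        in_2 = in_1 + strides[2]
    else:
        out_1 = x * out_strides[2] + y * out_strides[1] + z * out_strides[0]
        out_2 = out_1 + grid_x * out_strides[2]
        in_1 = x * strides[2] + y * strides[1] + z * strides[0]
        in_2 = in_1 + grid_x * strides[2]
    return (in_1, in_2), (out_1, out_2)

# Example: row-contiguous (batch, seq, head_dim) = (2, 16, 128), rotating
# dims = 64 features, so the grid's x extent is dims // 2 = 32.
print(rope_indices((3, 5, 1), [2048, 128, 1], [2048, 128, 1], 32, False))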
@@ -54,6 +55,7 @@ template <typename T, bool traditional>
     const device type* in [[buffer(0)]], \
     device type* out [[buffer(1)]], \
     constant const size_t strides[3], \
+    constant const size_t out_strides[3], \
     constant const int& offset, \
     constant const float& base, \
     constant const float& scale, \
@@ -16,18 +16,24 @@ void RoPE::eval_gpu(
   if (in.ndim() < 3) {
     throw std::runtime_error("[RoPE] Input must have at least 3 dimensions");
   }
-  if (dims_ != in.shape(-1)) {
-    throw std::runtime_error("[RoPE] Partial RoPE application not supported");
-  }

   auto& s = out.primitive().stream();
   auto& d = metal::device(s.device);

   size_t strides[3];
+  size_t out_strides[3];
   bool donated = false;
   int ndim = in.ndim();
-  size_t mat_size = in.shape()[ndim - 2] * in.shape()[ndim - 1];
-  if (in.flags().row_contiguous) {
+  size_t mat_size = in.shape(-2) * in.shape(-1);
+  if (dims_ < in.shape(-1)) {
+    donated = true;
+    auto ctype =
+        (in.flags().row_contiguous) ? CopyType::Vector : CopyType::General;
+    copy_gpu(in, out, ctype, s);
+    strides[0] = mat_size;
+    strides[1] = out.strides()[ndim - 2];
+    strides[2] = out.strides()[ndim - 1];
+  } else if (in.flags().row_contiguous) {
     if (in.is_donatable()) {
       donated = true;
       out.move_shared_buffer(in);
@@ -52,6 +58,9 @@ void RoPE::eval_gpu(
     strides[1] = out.strides()[ndim - 2];
     strides[2] = out.strides()[ndim - 1];
   }
+  out_strides[0] = mat_size;
+  out_strides[1] = out.strides()[ndim - 2];
+  out_strides[2] = out.strides()[ndim - 1];

   std::ostringstream kname;
   kname << "rope_" << (traditional_ ? "traditional_" : "") << type_to_name(in);
@@ -63,12 +72,13 @@ void RoPE::eval_gpu(
   set_array_buffer(compute_encoder, donated ? out : in, 0);
   set_array_buffer(compute_encoder, out, 1);
   compute_encoder->setBytes(&strides, 3 * sizeof(size_t), 2);
-  compute_encoder->setBytes(&offset_, sizeof(int), 3);
-  compute_encoder->setBytes(&base, sizeof(float), 4);
-  compute_encoder->setBytes(&scale_, sizeof(float), 5);
+  compute_encoder->setBytes(&out_strides, 3 * sizeof(size_t), 3);
+  compute_encoder->setBytes(&offset_, sizeof(int), 4);
+  compute_encoder->setBytes(&base, sizeof(float), 5);
+  compute_encoder->setBytes(&scale_, sizeof(float), 6);

-  int dim0 = in.shape()[ndim - 1] / 2;
-  int dim1 = in.shape()[ndim - 2];
+  int dim0 = dims_ / 2;
+  int dim1 = in.shape(-2);
   int dim2 = in.size() / mat_size;
   auto group_dims = get_block_dims(dim0, dim1, dim2);
   auto grid_dims = MTL::Size(dim0, dim1, dim2);
@@ -244,7 +244,7 @@ array rope(
     }
   };
   auto stream = to_stream(s);
-  if (stream.device == Device::gpu && x.shape(-1) == dims) {
+  if (stream.device == Device::gpu) {
     return array(
         x.shape(),
         x.dtype(),
@@ -2590,8 +2590,11 @@ std::vector<array> Scatter::vjp(
       break;
     case Scatter::Max:
     case Scatter::Min: {
-      auto mask = where(result == values, array({1}), array({0}));
-      vjps.push_back(multiply(cotangents[0], mask));
+      vjps.push_back(where(
+          equal(result, values, stream()),
+          cotangents[0],
+          array(0, cotangents[0].dtype()),
+          stream()));
       break;
     }
     default:
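The reworked Min/Max branch routes the upstream cotangent directly with a single where(): gradient flows only at positions where the scattered value actually equals the reduction's result, instead of building a 0/1 mask and multiplying. A rough Python analogue of that masking idea (illustrative values, not the C++ primitive itself):

import mlx.core as mx

# result: output of a max-scatter; values: the candidates that were scattered.
result = mx.array([4.0, 7.0, 2.0])
values = mx.array([4.0, 5.0, 2.0])
cotangent = mx.array([10.0, 20.0, 30.0])

# Cotangent flows only where the candidate value "won" the max.
vjp_values = mx.where(mx.equal(result, values), cotangent, mx.zeros_like(cotangent))
print(vjp_values)  # -> [10, 0, 30]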