mirror of
https://github.com/ml-explore/mlx.git
synced 2025-07-29 13:55:29 +08:00
Donation bug (#933)
* donation * buf * fix bug in softmax * comment * remove print
This commit is contained in:
parent
f48bc496c7
commit
8915901966
@ -19,7 +19,7 @@ void RMSNorm::eval_gpu(
|
||||
|
||||
// Make sure that the last dimension is contiguous
|
||||
std::vector<array> copies;
|
||||
auto check_input = [&copies, &s](const array& x) {
|
||||
auto check_input = [&copies, &s](const array& x) -> const array& {
|
||||
bool no_copy = x.strides()[x.ndim() - 1] == 1;
|
||||
if (x.ndim() > 1) {
|
||||
auto s = x.strides()[x.ndim() - 2];
|
||||
@ -28,10 +28,9 @@ void RMSNorm::eval_gpu(
|
||||
if (no_copy) {
|
||||
return x;
|
||||
} else {
|
||||
array x_copy(x.shape(), x.dtype(), nullptr, {});
|
||||
copy_gpu(x, x_copy, CopyType::General, s);
|
||||
copies.push_back(x_copy);
|
||||
return x_copy;
|
||||
copies.push_back(array(x.shape(), x.dtype(), nullptr, {}));
|
||||
copy_gpu(x, copies.back(), CopyType::General, s);
|
||||
return copies.back();
|
||||
}
|
||||
};
|
||||
const array& x = check_input(inputs[0]);
|
||||
@ -106,15 +105,13 @@ void RMSNormVJP::eval_gpu(
|
||||
// is contiguous (no broadcasts or holes) and that the input strides are the
|
||||
// same as the cotangent strides but for now this is simpler.
|
||||
std::vector<array> copies;
|
||||
auto check_input = [&copies, &s](const array& x) {
|
||||
auto check_input = [&copies, &s](const array& x) -> const array& {
|
||||
if (x.flags().row_contiguous) {
|
||||
return x;
|
||||
}
|
||||
|
||||
array x_copy(x.shape(), x.dtype(), nullptr, {});
|
||||
copy_gpu(x, x_copy, CopyType::General, s);
|
||||
copies.push_back(x_copy);
|
||||
return x_copy;
|
||||
copies.push_back(array(x.shape(), x.dtype(), nullptr, {}));
|
||||
copy_gpu(x, copies.back(), CopyType::General, s);
|
||||
return copies.back();
|
||||
};
|
||||
const array& x = check_input(inputs[0]);
|
||||
const array& w = inputs[1];
|
||||
@ -149,8 +146,11 @@ void RMSNormVJP::eval_gpu(
|
||||
gw_temp.set_data(allocator::malloc_or_wait(gw_temp.nbytes()));
|
||||
}
|
||||
copies.push_back(gw_temp);
|
||||
array zero(0, gw.dtype());
|
||||
copy_gpu(zero, gw, CopyType::Scalar, s);
|
||||
{
|
||||
array zero(0, gw.dtype());
|
||||
copy_gpu(zero, gw, CopyType::Scalar, s);
|
||||
copies.push_back(std::move(zero));
|
||||
}
|
||||
|
||||
const int simd_size = 32;
|
||||
const int n_reads = RMS_N_READS;
|
||||
@ -212,7 +212,7 @@ void LayerNorm::eval_gpu(
|
||||
|
||||
// Make sure that the last dimension is contiguous
|
||||
std::vector<array> copies;
|
||||
auto check_input = [&copies, &s](const array& x) {
|
||||
auto check_input = [&copies, &s](const array& x) -> const array& {
|
||||
bool no_copy = x.strides()[x.ndim() - 1] == 1;
|
||||
if (x.ndim() > 1) {
|
||||
auto s = x.strides()[x.ndim() - 2];
|
||||
@ -221,10 +221,9 @@ void LayerNorm::eval_gpu(
|
||||
if (no_copy) {
|
||||
return x;
|
||||
} else {
|
||||
array x_copy(x.shape(), x.dtype(), nullptr, {});
|
||||
copy_gpu(x, x_copy, CopyType::General, s);
|
||||
copies.push_back(x_copy);
|
||||
return x_copy;
|
||||
copies.push_back(array(x.shape(), x.dtype(), nullptr, {}));
|
||||
copy_gpu(x, copies.back(), CopyType::General, s);
|
||||
return copies.back();
|
||||
}
|
||||
};
|
||||
const array& x = check_input(inputs[0]);
|
||||
@ -300,15 +299,13 @@ void LayerNormVJP::eval_gpu(
|
||||
// is contiguous (no broadcasts or holes) and that the input strides are the
|
||||
// same as the cotangent strides but for now this is simpler.
|
||||
std::vector<array> copies;
|
||||
auto check_input = [&copies, &s](const array& x) {
|
||||
auto check_input = [&copies, &s](const array& x) -> const array& {
|
||||
if (x.flags().row_contiguous) {
|
||||
return x;
|
||||
}
|
||||
|
||||
array x_copy(x.shape(), x.dtype(), nullptr, {});
|
||||
copy_gpu(x, x_copy, CopyType::General, s);
|
||||
copies.push_back(x_copy);
|
||||
return x_copy;
|
||||
copies.push_back(array(x.shape(), x.dtype(), nullptr, {}));
|
||||
copy_gpu(x, copies.back(), CopyType::General, s);
|
||||
return copies.back();
|
||||
};
|
||||
const array& x = check_input(inputs[0]);
|
||||
const array& w = inputs[1];
|
||||
@ -345,9 +342,12 @@ void LayerNormVJP::eval_gpu(
|
||||
gw_temp.set_data(allocator::malloc_or_wait(gw_temp.nbytes()));
|
||||
}
|
||||
copies.push_back(gw_temp);
|
||||
array zero(0, gw.dtype());
|
||||
copy_gpu(zero, gw, CopyType::Scalar, s);
|
||||
copy_gpu(zero, gb, CopyType::Scalar, s);
|
||||
{
|
||||
array zero(0, gw.dtype());
|
||||
copy_gpu(zero, gw, CopyType::Scalar, s);
|
||||
copy_gpu(zero, gb, CopyType::Scalar, s);
|
||||
copies.push_back(std::move(zero));
|
||||
}
|
||||
|
||||
// Finish with the gradient for b in case we had a b
|
||||
auto compute_encoder = d.get_command_encoder(s.index);
|
||||
|
@ -21,7 +21,7 @@ void Softmax::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
|
||||
// Make sure that the last dimension is contiguous
|
||||
std::vector<array> copies;
|
||||
auto check_input = [&copies, &s](const array& x) {
|
||||
auto check_input = [&copies, &s](const array& x) -> const array& {
|
||||
bool no_copy = x.strides()[x.ndim() - 1] == 1;
|
||||
if (x.ndim() > 1) {
|
||||
auto s = x.strides()[x.ndim() - 2];
|
||||
@ -30,10 +30,9 @@ void Softmax::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
if (no_copy) {
|
||||
return x;
|
||||
} else {
|
||||
array x_copy(x.shape(), x.dtype(), nullptr, {});
|
||||
copy_gpu(x, x_copy, CopyType::General, s);
|
||||
copies.push_back(x_copy);
|
||||
return x_copy;
|
||||
copies.push_back(array(x.shape(), x.dtype(), nullptr, {}));
|
||||
copy_gpu(x, copies.back(), CopyType::General, s);
|
||||
return copies.back();
|
||||
}
|
||||
};
|
||||
const array& in = check_input(inputs[0]);
|
||||
@ -81,7 +80,6 @@ void Softmax::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
compute_encoder->setComputePipelineState(kernel);
|
||||
set_array_buffer(
|
||||
compute_encoder, in.data_shared_ptr() == nullptr ? out : in, 0);
|
||||
set_array_buffer(compute_encoder, in, 0);
|
||||
set_array_buffer(compute_encoder, out, 1);
|
||||
compute_encoder->setBytes(&axis_size, sizeof(int), 2);
|
||||
compute_encoder->setThreadgroupMemoryLength(simd_size * in.itemsize(), 0);
|
||||
|
Loading…
Reference in New Issue
Block a user