Compare commits

..

2 Commits

Author SHA1 Message Date
AN Long
85a8824a8c Fix cumulative operations when axis=None (#2653) 2025-10-08 15:25:38 -07:00
Awni Hannun
f5d4397e5c Fix fast synch when fence is waited before a command buffer is created (#2657) 2025-10-08 11:23:46 -07:00
4 changed files with 85 additions and 19 deletions

View File

@@ -471,6 +471,10 @@ void Device::end_encoding(int index) {
CommandEncoder& Device::get_command_encoder(int index) {
auto& stream = get_stream_(index);
if (stream.encoder == nullptr) {
// Ensure there is an active command buffer
if (stream.buffer == nullptr) {
get_command_buffer(index);
}
stream.encoder = std::make_unique<CommandEncoder>(stream);
stream.fence = std::make_shared<Fence>(device_->newFence());
}

View File

@@ -3468,6 +3468,14 @@ array cumsum(
{a});
}
/** Cumulative sum over all elements of `a`, scanned as a flattened array. */
array cumsum(
    const array& a,
    bool reverse /* = false*/,
    bool inclusive /* = true*/,
    StreamOrDevice s /* = {}*/) {
  // Resolve the stream once and hand the same one to both operations.
  auto stream = to_stream(s);
  auto flat = flatten(a, stream);
  // The flattened input has a single axis, so scan along axis 0.
  return cumsum(flat, 0, reverse, inclusive, stream);
}
array cumprod(
const array& a,
int axis,
@@ -3490,6 +3498,14 @@ array cumprod(
{a});
}
/** Cumulative product over all elements of `a`, scanned as a flattened array. */
array cumprod(
    const array& a,
    bool reverse /* = false*/,
    bool inclusive /* = true*/,
    StreamOrDevice s /* = {}*/) {
  // Flatten first; the axis overload then scans along the only axis (0).
  auto flat = flatten(a, s);
  return cumprod(flat, 0, reverse, inclusive, s);
}
array cummax(
const array& a,
int axis,
@@ -3512,6 +3528,14 @@ array cummax(
{a});
}
/** Cumulative maximum over all elements of `a`, scanned as a flattened array. */
array cummax(
    const array& a,
    bool reverse /* = false*/,
    bool inclusive /* = true*/,
    StreamOrDevice s /* = {}*/) {
  // Flatten first; the axis overload then scans along the only axis (0).
  auto flat = flatten(a, s);
  return cummax(flat, 0, reverse, inclusive, s);
}
array cummin(
const array& a,
int axis,
@@ -3534,6 +3558,14 @@ array cummin(
{a});
}
/** Cumulative minimum over all elements of `a`, scanned as a flattened array. */
array cummin(
    const array& a,
    bool reverse /* = false*/,
    bool inclusive /* = true*/,
    StreamOrDevice s /* = {}*/) {
  // Flatten first; the axis overload then scans along the only axis (0).
  auto flat = flatten(a, s);
  return cummin(flat, 0, reverse, inclusive, s);
}
array logcumsumexp(
const array& a,
int axis,
@@ -3556,6 +3588,15 @@ array logcumsumexp(
{a});
}
/** Cumulative logsumexp over all elements of `a`, scanned as a flattened array. */
array logcumsumexp(
    const array& a,
    bool reverse /* = false*/,
    bool inclusive /* = true*/,
    StreamOrDevice s /* = {}*/) {
  // Resolve the stream once and hand the same one to both operations.
  auto stream = to_stream(s);
  auto flat = flatten(a, stream);
  // The flattened input has a single axis, so scan along axis 0.
  return logcumsumexp(flat, 0, reverse, inclusive, stream);
}
/** Convolution operations */
namespace {

View File

@@ -716,6 +716,13 @@ array topk(const array& a, int k, StreamOrDevice s = {});
array topk(const array& a, int k, int axis, StreamOrDevice s = {});
/**
 * Cumulative logsumexp over all elements of an array.
 *
 * No axis is given: the input is flattened and the scan runs along axis 0 of
 * the flattened array. `reverse` and `inclusive` are forwarded unchanged to
 * the axis overload.
 */
array logcumsumexp(
const array& a,
bool reverse = false,
bool inclusive = true,
StreamOrDevice s = {});
/** Cumulative logsumexp of an array along the given axis. */
array logcumsumexp(
const array& a,
int axis,
@@ -1186,6 +1193,13 @@ softmax(const array& a, int axis, bool precise = false, StreamOrDevice s = {}) {
array power(const array& a, const array& b, StreamOrDevice s = {});
/**
 * Cumulative sum over all elements of an array.
 *
 * No axis is given: the input is flattened and the scan runs along axis 0 of
 * the flattened array. `reverse` and `inclusive` are forwarded unchanged to
 * the axis overload.
 */
array cumsum(
const array& a,
bool reverse = false,
bool inclusive = true,
StreamOrDevice s = {});
/** Cumulative sum of an array along the given axis. */
array cumsum(
const array& a,
int axis,
@@ -1194,6 +1208,13 @@ array cumsum(
StreamOrDevice s = {});
/**
 * Cumulative product over all elements of an array.
 *
 * No axis is given: the input is flattened and the scan runs along axis 0 of
 * the flattened array. `reverse` and `inclusive` are forwarded unchanged to
 * the axis overload.
 */
array cumprod(
const array& a,
bool reverse = false,
bool inclusive = true,
StreamOrDevice s = {});
/** Cumulative product of an array along the given axis. */
array cumprod(
const array& a,
int axis,
@@ -1202,6 +1223,13 @@ array cumprod(
StreamOrDevice s = {});
/**
 * Cumulative max over all elements of an array.
 *
 * No axis is given: the input is flattened and the scan runs along axis 0 of
 * the flattened array. `reverse` and `inclusive` are forwarded unchanged to
 * the axis overload.
 */
array cummax(
const array& a,
bool reverse = false,
bool inclusive = true,
StreamOrDevice s = {});
/** Cumulative max of an array along the given axis. */
array cummax(
const array& a,
int axis,
@@ -1210,6 +1238,13 @@ array cummax(
StreamOrDevice s = {});
/**
 * Cumulative min over all elements of an array.
 *
 * No axis is given: the input is flattened and the scan runs along axis 0 of
 * the flattened array. `reverse` and `inclusive` are forwarded unchanged to
 * the axis overload.
 */
array cummin(
const array& a,
bool reverse = false,
bool inclusive = true,
StreamOrDevice s = {});
/** Cumulative min of an array along the given axis. */
array cummin(
const array& a,
int axis,

View File

@@ -1275,10 +1275,7 @@ void init_array(nb::module_& m) {
if (axis) {
return mx::logcumsumexp(a, *axis, reverse, inclusive, s);
} else {
// TODO: Implement that in the C++ API as well. See concatenate
// above.
return mx::logcumsumexp(
mx::reshape(a, {-1}, s), 0, reverse, inclusive, s);
return mx::logcumsumexp(a, reverse, inclusive, s);
}
},
"axis"_a = nb::none(),
@@ -1408,9 +1405,7 @@ void init_array(nb::module_& m) {
if (axis) {
return mx::cumsum(a, *axis, reverse, inclusive, s);
} else {
// TODO: Implement that in the C++ API as well. See concatenate
// above.
return mx::cumsum(reshape(a, {-1}, s), 0, reverse, inclusive, s);
return mx::cumsum(a, reverse, inclusive, s);
}
},
"axis"_a = nb::none(),
@@ -1429,10 +1424,7 @@ void init_array(nb::module_& m) {
if (axis) {
return mx::cumprod(a, *axis, reverse, inclusive, s);
} else {
// TODO: Implement that in the C++ API as well. See concatenate
// above.
return mx::cumprod(
mx::reshape(a, {-1}, s), 0, reverse, inclusive, s);
return mx::cumprod(a, reverse, inclusive, s);
}
},
"axis"_a = nb::none(),
@@ -1451,10 +1443,7 @@ void init_array(nb::module_& m) {
if (axis) {
return mx::cummax(a, *axis, reverse, inclusive, s);
} else {
// TODO: Implement that in the C++ API as well. See concatenate
// above.
return mx::cummax(
mx::reshape(a, {-1}, s), 0, reverse, inclusive, s);
return mx::cummax(a, reverse, inclusive, s);
}
},
"axis"_a = nb::none(),
@@ -1473,10 +1462,7 @@ void init_array(nb::module_& m) {
if (axis) {
return mx::cummin(a, *axis, reverse, inclusive, s);
} else {
// TODO: Implement that in the C++ API as well. See concatenate
// above.
return mx::cummin(
mx::reshape(a, {-1}, s), 0, reverse, inclusive, s);
return mx::cummin(a, reverse, inclusive, s);
}
},
"axis"_a = nb::none(),