Fix cumulative operations when axis=None (#2653)

This commit is contained in:
AN Long
2025-10-09 07:25:38 +09:00
committed by GitHub
parent f5d4397e5c
commit 85a8824a8c
3 changed files with 81 additions and 19 deletions

View File

@@ -716,6 +716,13 @@ array topk(const array& a, int k, StreamOrDevice s = {});
array topk(const array& a, int k, int axis, StreamOrDevice s = {});
/** Cumulative logsumexp of an array. */
array logcumsumexp(
const array& a,
bool reverse = false,
bool inclusive = true,
StreamOrDevice s = {});
/** Cumulative logsumexp of an array along the given axis. */
array logcumsumexp(
const array& a,
int axis,
@@ -1186,6 +1193,13 @@ softmax(const array& a, int axis, bool precise = false, StreamOrDevice s = {}) {
array power(const array& a, const array& b, StreamOrDevice s = {});
/** Cumulative sum of an array. */
array cumsum(
const array& a,
bool reverse = false,
bool inclusive = true,
StreamOrDevice s = {});
/** Cumulative sum of an array along the given axis. */
array cumsum(
const array& a,
int axis,
@@ -1194,6 +1208,13 @@ array cumsum(
StreamOrDevice s = {});
/** Cumulative product of an array. */
array cumprod(
const array& a,
bool reverse = false,
bool inclusive = true,
StreamOrDevice s = {});
/** Cumulative product of an array along the given axis. */
array cumprod(
const array& a,
int axis,
@@ -1202,6 +1223,13 @@ array cumprod(
StreamOrDevice s = {});
/** Cumulative max of an array. */
array cummax(
const array& a,
bool reverse = false,
bool inclusive = true,
StreamOrDevice s = {});
/** Cumulative max of an array along the given axis. */
array cummax(
const array& a,
int axis,
@@ -1210,6 +1238,13 @@ array cummax(
StreamOrDevice s = {});
/** Cumulative min of an array. */
array cummin(
const array& a,
bool reverse = false,
bool inclusive = true,
StreamOrDevice s = {});
/** Cumulative min of an array along the given axis. */
array cummin(
const array& a,
int axis,