mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
Fix cumulative operations when axis=None (#2653)
This commit is contained in:
35
mlx/ops.h
35
mlx/ops.h
@@ -716,6 +716,13 @@ array topk(const array& a, int k, StreamOrDevice s = {});
|
||||
array topk(const array& a, int k, int axis, StreamOrDevice s = {});
|
||||
|
||||
/** Cumulative logsumexp of an array. */
|
||||
array logcumsumexp(
|
||||
const array& a,
|
||||
bool reverse = false,
|
||||
bool inclusive = true,
|
||||
StreamOrDevice s = {});
|
||||
|
||||
/** Cumulative logsumexp of an array along the given axis. */
|
||||
array logcumsumexp(
|
||||
const array& a,
|
||||
int axis,
|
||||
@@ -1186,6 +1193,13 @@ softmax(const array& a, int axis, bool precise = false, StreamOrDevice s = {}) {
|
||||
array power(const array& a, const array& b, StreamOrDevice s = {});
|
||||
|
||||
/** Cumulative sum of an array. */
|
||||
array cumsum(
|
||||
const array& a,
|
||||
bool reverse = false,
|
||||
bool inclusive = true,
|
||||
StreamOrDevice s = {});
|
||||
|
||||
/** Cumulative sum of an array along the given axis. */
|
||||
array cumsum(
|
||||
const array& a,
|
||||
int axis,
|
||||
@@ -1194,6 +1208,13 @@ array cumsum(
|
||||
StreamOrDevice s = {});
|
||||
|
||||
/** Cumulative product of an array. */
|
||||
array cumprod(
|
||||
const array& a,
|
||||
bool reverse = false,
|
||||
bool inclusive = true,
|
||||
StreamOrDevice s = {});
|
||||
|
||||
/** Cumulative product of an array along the given axis. */
|
||||
array cumprod(
|
||||
const array& a,
|
||||
int axis,
|
||||
@@ -1202,6 +1223,13 @@ array cumprod(
|
||||
StreamOrDevice s = {});
|
||||
|
||||
/** Cumulative max of an array. */
|
||||
array cummax(
|
||||
const array& a,
|
||||
bool reverse = false,
|
||||
bool inclusive = true,
|
||||
StreamOrDevice s = {});
|
||||
|
||||
/** Cumulative max of an array along the given axis. */
|
||||
array cummax(
|
||||
const array& a,
|
||||
int axis,
|
||||
@@ -1210,6 +1238,13 @@ array cummax(
|
||||
StreamOrDevice s = {});
|
||||
|
||||
/** Cumulative min of an array. */
|
||||
array cummin(
|
||||
const array& a,
|
||||
bool reverse = false,
|
||||
bool inclusive = true,
|
||||
StreamOrDevice s = {});
|
||||
|
||||
/** Cumulative min of an array along the given axis. */
|
||||
array cummin(
|
||||
const array& a,
|
||||
int axis,
|
||||
|
||||
Reference in New Issue
Block a user