mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
Use int64 stride everywhere (#1671)
* use int64 stride everywhere * fix ext * fix ext * more shape + cleanup * one more * few more
This commit is contained in:
44
mlx/random.h
44
mlx/random.h
@@ -42,12 +42,12 @@ void seed(uint64_t seed);
|
||||
|
||||
/** Generate an array with type uint32 filled with random bits. */
|
||||
array bits(
|
||||
const std::vector<int>& shape,
|
||||
const Shape& shape,
|
||||
int width,
|
||||
const std::optional<array>& key = std::nullopt,
|
||||
StreamOrDevice s = {});
|
||||
inline array bits(
|
||||
const std::vector<int>& shape,
|
||||
const Shape& shape,
|
||||
const std::optional<array>& key = std::nullopt,
|
||||
StreamOrDevice s = {}) {
|
||||
return bits(shape, 4, key, s);
|
||||
@@ -63,7 +63,7 @@ array split(const array& key, int num, StreamOrDevice s = {});
|
||||
array uniform(
|
||||
const array& low,
|
||||
const array& high,
|
||||
const std::vector<int>& shape,
|
||||
const Shape& shape,
|
||||
Dtype dtype = float32,
|
||||
const std::optional<array>& key = std::nullopt,
|
||||
StreamOrDevice s = {});
|
||||
@@ -72,7 +72,7 @@ template <typename T, typename U>
|
||||
array uniform(
|
||||
T low,
|
||||
U high,
|
||||
const std::vector<int>& shape,
|
||||
const Shape& shape,
|
||||
Dtype dtype = float32,
|
||||
const std::optional<array>& key = std::nullopt,
|
||||
StreamOrDevice s = {}) {
|
||||
@@ -81,12 +81,12 @@ array uniform(
|
||||
|
||||
/** Generate uniform random numbers between 0 and 1. */
|
||||
array uniform(
|
||||
const std::vector<int>& shape,
|
||||
const Shape& shape,
|
||||
Dtype dtype,
|
||||
const std::optional<array>& key = std::nullopt,
|
||||
StreamOrDevice s = {});
|
||||
inline array uniform(
|
||||
const std::vector<int>& shape,
|
||||
const Shape& shape,
|
||||
const std::optional<array>& key = std::nullopt,
|
||||
StreamOrDevice s = {}) {
|
||||
return uniform(shape, float32, key);
|
||||
@@ -94,14 +94,14 @@ inline array uniform(
|
||||
|
||||
/** Generate samples from the standard normal distribution. */
|
||||
array normal(
|
||||
const std::vector<int>& shape,
|
||||
const Shape& shape,
|
||||
Dtype dtype,
|
||||
const float loc,
|
||||
const float scale,
|
||||
const std::optional<array>& key = std::nullopt,
|
||||
StreamOrDevice s = {});
|
||||
inline array normal(
|
||||
const std::vector<int>& shape,
|
||||
const Shape& shape,
|
||||
const float loc,
|
||||
const float scale,
|
||||
const std::optional<array>& key = std::nullopt,
|
||||
@@ -109,14 +109,14 @@ inline array normal(
|
||||
return normal(shape, float32, loc, scale, key, s);
|
||||
}
|
||||
inline array normal(
|
||||
const std::vector<int>& shape,
|
||||
const Shape& shape,
|
||||
const Dtype dtype,
|
||||
const std::optional<array>& key = std::nullopt,
|
||||
StreamOrDevice s = {}) {
|
||||
return normal(shape, dtype, 0.0, 1.0, key, s);
|
||||
}
|
||||
inline array normal(
|
||||
const std::vector<int>& shape,
|
||||
const Shape& shape,
|
||||
const std::optional<array>& key = std::nullopt,
|
||||
StreamOrDevice s = {}) {
|
||||
return normal(shape, float32, 0.0, 1.0, key, s);
|
||||
@@ -126,7 +126,7 @@ inline array normal(
|
||||
array multivariate_normal(
|
||||
const array& mean,
|
||||
const array& cov,
|
||||
const std::vector<int>& shape,
|
||||
const Shape& shape,
|
||||
Dtype dtype,
|
||||
const std::optional<array>& key = std::nullopt,
|
||||
StreamOrDevice s = {});
|
||||
@@ -135,7 +135,7 @@ array multivariate_normal(
|
||||
array randint(
|
||||
const array& low,
|
||||
const array& high,
|
||||
const std::vector<int>& shape,
|
||||
const Shape& shape,
|
||||
Dtype dtype = int32,
|
||||
const std::optional<array>& key = std::nullopt,
|
||||
StreamOrDevice s = {});
|
||||
@@ -144,7 +144,7 @@ template <typename T, typename U>
|
||||
array randint(
|
||||
T low,
|
||||
U high,
|
||||
const std::vector<int>& shape,
|
||||
const Shape& shape,
|
||||
Dtype dtype = int32,
|
||||
const std::optional<array>& key = std::nullopt,
|
||||
StreamOrDevice s = {}) {
|
||||
@@ -154,7 +154,7 @@ array randint(
|
||||
/** Generate binary variables with probability to be true equal to p */
|
||||
array bernoulli(
|
||||
const array& p,
|
||||
const std::vector<int>& shape,
|
||||
const Shape& shape,
|
||||
const std::optional<array>& key = std::nullopt,
|
||||
StreamOrDevice s = {});
|
||||
array bernoulli(
|
||||
@@ -173,7 +173,7 @@ array bernoulli(
|
||||
template <typename T>
|
||||
array bernoulli(
|
||||
T p,
|
||||
const std::vector<int>& shape,
|
||||
const Shape& shape,
|
||||
const std::optional<array>& key = std::nullopt,
|
||||
StreamOrDevice s = {}) {
|
||||
return bernoulli(array(p), shape, key, s);
|
||||
@@ -186,7 +186,7 @@ array bernoulli(
|
||||
array truncated_normal(
|
||||
const array& lower,
|
||||
const array& upper,
|
||||
const std::vector<int>& shape,
|
||||
const Shape& shape,
|
||||
Dtype dtype = float32,
|
||||
const std::optional<array>& key = std::nullopt,
|
||||
StreamOrDevice s = {});
|
||||
@@ -199,7 +199,7 @@ array truncated_normal(
|
||||
StreamOrDevice s = {});
|
||||
|
||||
array gumbel(
|
||||
const std::vector<int>& shape,
|
||||
const Shape& shape,
|
||||
Dtype dtype = float32,
|
||||
const std::optional<array>& key = std::nullopt,
|
||||
StreamOrDevice s = {});
|
||||
@@ -207,7 +207,7 @@ array gumbel(
|
||||
array categorical(
|
||||
const array& logits,
|
||||
int axis,
|
||||
const std::vector<int>& shape,
|
||||
const Shape& shape,
|
||||
const std::optional<array>& key = std::nullopt,
|
||||
StreamOrDevice s = {});
|
||||
|
||||
@@ -226,14 +226,14 @@ array categorical(
|
||||
|
||||
/** Generate samples from the laplace distribution. */
|
||||
array laplace(
|
||||
const std::vector<int>& shape,
|
||||
const Shape& shape,
|
||||
Dtype dtype,
|
||||
const float loc,
|
||||
const float scale,
|
||||
const std::optional<array>& key = std::nullopt,
|
||||
StreamOrDevice s = {});
|
||||
inline array laplace(
|
||||
const std::vector<int>& shape,
|
||||
const Shape& shape,
|
||||
const float loc,
|
||||
const float scale,
|
||||
const std::optional<array>& key = std::nullopt,
|
||||
@@ -241,14 +241,14 @@ inline array laplace(
|
||||
return laplace(shape, float32, loc, scale, key, s);
|
||||
}
|
||||
inline array laplace(
|
||||
const std::vector<int>& shape,
|
||||
const Shape& shape,
|
||||
const Dtype dtype,
|
||||
const std::optional<array>& key = std::nullopt,
|
||||
StreamOrDevice s = {}) {
|
||||
return laplace(shape, dtype, 0.0, 1.0, key, s);
|
||||
}
|
||||
inline array laplace(
|
||||
const std::vector<int>& shape,
|
||||
const Shape& shape,
|
||||
const std::optional<array>& key = std::nullopt,
|
||||
StreamOrDevice s = {}) {
|
||||
return laplace(shape, float32, 0.0, 1.0, key, s);
|
||||
|
||||
Reference in New Issue
Block a user