Mirror of https://github.com/ml-explore/mlx.git (synced 2025-06-24 17:31:16 +08:00)
Commit e03f0372b1, parent f17536af9c
@@ -25,7 +25,7 @@ bool retain_graph() {
 } // namespace
 
 array::array(const std::complex<float>& val, Dtype dtype /* = complex64 */)
-    : array_desc_(std::make_shared<ArrayDesc>(std::vector<int>{}, dtype)) {
+    : array_desc_(std::make_shared<ArrayDesc>(Shape{}, dtype)) {
   auto cval = static_cast<complex64_t>(val);
   init(&cval);
 }
@@ -61,14 +61,14 @@ std::vector<array> array::make_arrays(
 
 array::array(std::initializer_list<float> data)
     : array_desc_(std::make_shared<ArrayDesc>(
-          std::vector<int>{static_cast<int>(data.size())},
+          Shape{static_cast<ShapeElem>(data.size())},
           float32)) {
   init(data.begin());
 }
 
 array::array(std::initializer_list<int> data, Dtype dtype)
     : array_desc_(std::make_shared<ArrayDesc>(
-          std::vector<int>{static_cast<int>(data.size())},
+          Shape{static_cast<ShapeElem>(data.size())},
           dtype)) {
   init(data.begin());
 }
@@ -322,7 +322,7 @@ array::ArrayIterator::ArrayIterator(const array& arr, int idx)
 }
 
 array::ArrayIterator::reference array::ArrayIterator::operator*() const {
-  auto start = std::vector<int>(arr.ndim(), 0);
+  auto start = Shape(arr.ndim(), 0);
   auto end = arr.shape();
   auto shape = arr.shape();
   shape.erase(shape.begin());
@@ -17,7 +17,8 @@ namespace mlx::core {
 class Primitive;
 
 using Deleter = std::function<void(allocator::Buffer)>;
-using Shape = std::vector<int32_t>;
+using ShapeElem = int32_t;
+using Shape = std::vector<ShapeElem>;
 using Strides = std::vector<int64_t>;
 
 class array {
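Note: the new ShapeElem alias names the element type of Shape, so call sites can cast dimension sizes through it instead of assuming int, which is the pattern the rest of this diff follows. A minimal standalone sketch of the intended usage (the aliases are redeclared here for illustration only):

#include <cstdint>
#include <vector>

using ShapeElem = int32_t;            // one dimension extent
using Shape = std::vector<ShapeElem>; // a full array shape

int main() {
  std::size_t n = 128; // e.g. a container size
  // Cast through ShapeElem rather than a hard-coded int.
  Shape shape{static_cast<ShapeElem>(n), 3};
  return shape.size() == 2 ? 0 : 1;
}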
@@ -498,7 +499,7 @@ class array {
 
 template <typename T>
 array::array(T val, Dtype dtype /* = TypeToDtype<T>() */)
-    : array_desc_(std::make_shared<ArrayDesc>(std::vector<int>{}, dtype)) {
+    : array_desc_(std::make_shared<ArrayDesc>(Shape{}, dtype)) {
   init(&val);
 }
 
@@ -516,7 +517,7 @@ array::array(
     std::initializer_list<T> data,
     Dtype dtype /* = TypeToDtype<T>() */)
     : array_desc_(std::make_shared<ArrayDesc>(
-          std::vector<int>{static_cast<int>(data.size())},
+          Shape{static_cast<ShapeElem>(data.size())},
           dtype)) {
   init(data.begin());
 }
@@ -130,7 +130,7 @@ std::string build_lib_name(
 
 bool compiled_check_contiguity(
     const std::vector<array>& inputs,
-    const std::vector<int>& shape) {
+    const Shape& shape) {
   bool contiguous = true;
   bool all_contig = true;
   bool all_row_contig = true;
@@ -56,7 +56,7 @@ inline bool is_scalar(const array& x) {
 // Check if we can use a contiguous operation given inputs and the output shape
 bool compiled_check_contiguity(
     const std::vector<array>& inputs,
-    const std::vector<int>& shape);
+    const Shape& shape);
 
 // Allocate space for the outputs possibly with input donation
 void compiled_allocate_outputs(
@@ -726,7 +726,7 @@ void explicit_gemm_conv_1D_cpu(
   auto conv_dtype = float32;
 
   // Pad input
-  std::vector<int> padded_shape = {N, iH + 2 * padding[0], C};
+  Shape padded_shape = {N, iH + 2 * padding[0], C};
   array in_padded(padded_shape, conv_dtype, nullptr, {});
 
   // Fill with zeros
@@ -765,7 +765,7 @@ void explicit_gemm_conv_1D_cpu(
       in_padded, strided_strides, flags, in_strided_view.size(), 0);
 
   // Materialize strided view
-  std::vector<int> strided_reshape = {N * oH, wH * C};
+  Shape strided_reshape = {N * oH, wH * C};
   array in_strided(strided_reshape, in_strided_view.dtype(), nullptr, {});
   copy(in_strided_view, in_strided, CopyType::General);
 
@@ -843,8 +843,7 @@ void explicit_gemm_conv_2D_cpu(
   auto conv_dtype = out.dtype();
 
   // Pad input
-  std::vector<int> padded_shape = {
-      N, iH + 2 * padding[0], iW + 2 * padding[1], C};
+  Shape padded_shape = {N, iH + 2 * padding[0], iW + 2 * padding[1], C};
   array in_padded(padded_shape, conv_dtype, nullptr, {});
 
   // Fill with zeros
@@ -881,7 +880,7 @@ void explicit_gemm_conv_2D_cpu(
       in_padded, strided_strides, flags, in_strided_view.size(), 0);
 
   // Materialize strided view
-  std::vector<int> strided_reshape = {N * oH * oW, wH * wW * C};
+  Shape strided_reshape = {N * oH * oW, wH * wW * C};
   array in_strided(strided_reshape, in_strided_view.dtype(), nullptr, {});
   copy(in_strided_view, in_strided, CopyType::General);
 
@@ -934,19 +933,19 @@ void explicit_gemm_conv_ND_cpu(
     const std::vector<int>& wt_dilation,
     const bool flip) {
   const int N = in.shape(0); // Batch size, should be the same as out.shape(0)
-  const auto iDim = std::vector<int>(
-      in.shape().begin() + 1, in.shape().end() - 1); // Input spatial dim
-  const auto oDim = std::vector<int>(
+  const auto iDim =
+      Shape(in.shape().begin() + 1, in.shape().end() - 1); // Input spatial dim
+  const auto oDim = Shape(
       out.shape().begin() + 1, out.shape().end() - 1); // Output spatial dim
   const int O = wt.shape(0); // Out channels
   const int C = wt.shape(-1); // In channels
-  const auto wDim = std::vector<int>(
-      wt.shape().begin() + 1, wt.shape().end() - 1); // Weight spatial dim
+  const auto wDim =
+      Shape(wt.shape().begin() + 1, wt.shape().end() - 1); // Weight spatial dim
 
   auto conv_dtype = float32;
 
   // Pad input
-  std::vector<int> padded_shape(in.shape().size());
+  Shape padded_shape(in.shape().size());
   padded_shape.front() = N;
   for (size_t i = 0; i < iDim.size(); i++) {
     padded_shape[i + 1] = iDim[i] + 2 * padding[i];
@@ -14,10 +14,10 @@ namespace mlx::core {
 
 namespace {
 
-template <typename T, typename IdxT = int32_t>
+template <typename T>
 struct StridedIterator {
   using iterator_category = std::random_access_iterator_tag;
-  using difference_type = IdxT;
+  using difference_type = int32_t;
   using value_type = T;
   using reference = value_type&;
   using pointer = value_type*;
@@ -107,7 +107,7 @@ struct ContiguousIterator {
       : shape_(a.shape()), strides_(a.strides()) {
     if (!shape_.empty()) {
       std::tie(shape_, strides_) = collapse_contiguous_dims(shape_, strides_);
-      pos_ = std::vector<int>(shape_.size(), 0);
+      pos_ = Shape(shape_.size(), 0);
     }
   }
 
@@ -34,7 +34,7 @@ void explicit_gemm_conv_ND_gpu(
   int implicit_K = wt.size() / conv_params.O;
   int implicit_N = conv_params.O;
   // Prepare unfolding array
-  std::vector<int> unfolded_shape{implicit_M, implicit_K};
+  Shape unfolded_shape{implicit_M, implicit_K};
   array in_unfolded(unfolded_shape, in.dtype(), nullptr, {});
 
   in_unfolded.set_data(allocator::malloc_or_wait(in_unfolded.nbytes()));
@@ -113,7 +113,7 @@ void explicit_gemm_conv_group_ND_gpu(
   }
 
   // Prepare unfolding array
-  std::vector<int> unfolded_shape{implicit_M, implicit_K * groups};
+  Shape unfolded_shape{implicit_M, implicit_K * groups};
   array in_unfolded(unfolded_shape, in.dtype(), nullptr, {});
   in_unfolded.set_data(allocator::malloc_or_wait(in_unfolded.nbytes()));
 
@@ -192,12 +192,12 @@ void conv_1D_gpu(
     bool flip) {
   // Make conv params
   MLXConvParams<1> conv_params{
-      /* const int N = */ in.shape(0),
-      /* const int C = */ in.shape(2),
-      /* const int O = */ wt.shape(0),
-      /* const int iS[NDIM] = */ {in.shape(1)},
-      /* const int wS[NDIM] = */ {wt.shape(1)},
-      /* const int oS[NDIM] = */ {out.shape(1)},
+      /* const int N = */ static_cast<int>(in.shape(0)),
+      /* const int C = */ static_cast<int>(in.shape(2)),
+      /* const int O = */ static_cast<int>(wt.shape(0)),
+      /* const int iS[NDIM] = */ {static_cast<int>(in.shape(1))},
+      /* const int wS[NDIM] = */ {static_cast<int>(wt.shape(1))},
+      /* const int oS[NDIM] = */ {static_cast<int>(out.shape(1))},
       /* const int str[NDIM] = */ {wt_strides[0]},
       /* const int pad[NDIM] = */ {padding[0]},
       /* const int kdil[NDIM] = */ {wt_dilation[0]},
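An aside on why the static_casts show up throughout the conv paths (my reading, not stated in the patch): shape extents now carry the ShapeElem type, and the int fields of MLXConvParams are filled inside braced initializer lists, where any conversion the compiler cannot prove value-preserving is ill-formed narrowing. Casting makes the conversion explicit and keeps the braced init well-formed even if ShapeElem is later widened. A standalone sketch of the rule, using a hypothetical Params struct and a deliberately wider index type:

#include <cstdint>

struct Params {
  int n;
  int sizes[2];
};

int main() {
  int64_t big = 7; // imagine a wider shape-element type
  // Params p{big, {big, big}};  // ill-formed: narrowing inside a braced list
  Params p{static_cast<int>(big),
           {static_cast<int>(big), static_cast<int>(big)}};
  return p.n == 7 ? 0 : 1;
}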
@@ -541,7 +541,7 @@ void winograd_conv_2D_gpu(
     array out,
     const MLXConvParams<2>& conv_params,
     std::vector<array>& copies_w) {
-  std::vector<int> padded_shape = {
+  Shape padded_shape = {
       conv_params.N,
       conv_params.iS[0] + 2 * conv_params.pad[0],
       conv_params.iS[1] + 2 * conv_params.pad[1],
@@ -550,7 +550,7 @@ void winograd_conv_2D_gpu(
   padded_shape[1] = 6 * ((padded_shape[1] - 2 + 5) / 6) + 2;
   padded_shape[2] = 6 * ((padded_shape[2] - 2 + 5) / 6) + 2;
 
-  array in_padded(padded_shape, in.dtype(), nullptr, {});
+  array in_padded(std::move(padded_shape), in.dtype(), nullptr, {});
 
   // Fill with zeros
   array zero_arr = array(0, in.dtype());
@@ -575,12 +575,16 @@ void winograd_conv_2D_gpu(
   copies_w.push_back(in_padded);
 
   MLXConvParams<2> conv_params_updated{
-      /* const int N = */ in_padded.shape(0),
-      /* const int C = */ in_padded.shape(3),
-      /* const int O = */ wt.shape(0),
-      /* const int iS[NDIM] = */ {in_padded.shape(1), in_padded.shape(2)},
-      /* const int wS[NDIM] = */ {wt.shape(1), wt.shape(2)},
-      /* const int oS[NDIM] = */ {out.shape(1), out.shape(2)},
+      /* const int N = */ static_cast<int>(in_padded.shape(0)),
+      /* const int C = */ static_cast<int>(in_padded.shape(3)),
+      /* const int O = */ static_cast<int>(wt.shape(0)),
+      /* const int iS[NDIM] = */
+      {static_cast<int>(in_padded.shape(1)),
+       static_cast<int>(in_padded.shape(2))},
+      /* const int wS[NDIM] = */
+      {static_cast<int>(wt.shape(1)), static_cast<int>(wt.shape(2))},
+      /* const int oS[NDIM] = */
+      {static_cast<int>(out.shape(1)), static_cast<int>(out.shape(2))},
       /* const int str[NDIM] = */ {1, 1},
       /* const int pad[NDIM] = */ {0, 0},
       /* const int kdil[NDIM] = */ {1, 1},
@@ -607,8 +611,8 @@ void winograd_conv_2D_gpu(
   int N_tiles = N_tiles_n * N_tiles_h * N_tiles_w;
 
   // Do filter transform
-  std::vector<int> filt_wg_shape = {8 * 8, conv_params.C, conv_params.O};
-  array filt_wg(filt_wg_shape, wt.dtype(), nullptr, {});
+  Shape filt_wg_shape = {8 * 8, conv_params.C, conv_params.O};
+  array filt_wg(std::move(filt_wg_shape), wt.dtype(), nullptr, {});
   filt_wg.set_data(allocator::malloc_or_wait(filt_wg.nbytes()));
   copies_w.push_back(filt_wg);
   {
@@ -634,8 +638,8 @@ void winograd_conv_2D_gpu(
   }
 
   // Do input transform
-  std::vector<int> inp_wg_shape = {8 * 8, N_tiles, conv_params.C};
-  array inp_wg(inp_wg_shape, in.dtype(), nullptr, {});
+  Shape inp_wg_shape = {8 * 8, N_tiles, conv_params.C};
+  array inp_wg(std::move(inp_wg_shape), in.dtype(), nullptr, {});
   inp_wg.set_data(allocator::malloc_or_wait(inp_wg.nbytes()));
   copies_w.push_back(inp_wg);
   {
@@ -661,8 +665,8 @@ void winograd_conv_2D_gpu(
   }
 
   // Do batched gemm
-  std::vector<int> out_wg_shape = {8 * 8, N_tiles, conv_params.O};
-  array out_wg(out_wg_shape, in.dtype(), nullptr, {});
+  Shape out_wg_shape = {8 * 8, N_tiles, conv_params.O};
+  array out_wg(std::move(out_wg_shape), in.dtype(), nullptr, {});
   out_wg.set_data(allocator::malloc_or_wait(out_wg.nbytes()));
   copies_w.push_back(out_wg);
   {
@@ -723,12 +727,15 @@ void conv_2D_gpu(
     std::vector<array>& copies) {
   // Make conv params
   MLXConvParams<2> conv_params{
-      /* const int N = */ in.shape(0),
-      /* const int C = */ in.shape(3),
-      /* const int O = */ wt.shape(0),
-      /* const int iS[NDIM] = */ {in.shape(1), in.shape(2)},
-      /* const int wS[NDIM] = */ {wt.shape(1), wt.shape(2)},
-      /* const int oS[NDIM] = */ {out.shape(1), out.shape(2)},
+      /* const int N = */ static_cast<int>(in.shape(0)),
+      /* const int C = */ static_cast<int>(in.shape(3)),
+      /* const int O = */ static_cast<int>(wt.shape(0)),
+      /* const int iS[NDIM] = */
+      {static_cast<int>(in.shape(1)), static_cast<int>(in.shape(2))},
+      /* const int wS[NDIM] = */
+      {static_cast<int>(wt.shape(1)), static_cast<int>(wt.shape(2))},
+      /* const int oS[NDIM] = */
+      {static_cast<int>(out.shape(1)), static_cast<int>(out.shape(2))},
       /* const int str[NDIM] = */ {wt_strides[0], wt_strides[1]},
       /* const int pad[NDIM] = */ {padding[0], padding[1]},
       /* const int kdil[NDIM] = */ {wt_dilation[0], wt_dilation[1]},
@@ -800,12 +807,21 @@ void conv_3D_gpu(
     std::vector<array>& copies) {
   // Make conv params
   MLXConvParams<3> conv_params{
-      /* const int N = */ in.shape(0),
-      /* const int C = */ in.shape(4),
-      /* const int O = */ wt.shape(0),
-      /* const int iS[NDIM] = */ {in.shape(1), in.shape(2), in.shape(3)},
-      /* const int wS[NDIM] = */ {wt.shape(1), wt.shape(2), wt.shape(3)},
-      /* const int oS[NDIM] = */ {out.shape(1), out.shape(2), out.shape(3)},
+      /* const int N = */ static_cast<int>(in.shape(0)),
+      /* const int C = */ static_cast<int>(in.shape(4)),
+      /* const int O = */ static_cast<int>(wt.shape(0)),
+      /* const int iS[NDIM] = */
+      {static_cast<int>(in.shape(1)),
+       static_cast<int>(in.shape(2)),
+       static_cast<int>(in.shape(3))},
+      /* const int wS[NDIM] = */
+      {static_cast<int>(wt.shape(1)),
+       static_cast<int>(wt.shape(2)),
+       static_cast<int>(wt.shape(3))},
+      /* const int oS[NDIM] = */
+      {static_cast<int>(out.shape(1)),
+       static_cast<int>(out.shape(2)),
+       static_cast<int>(out.shape(3))},
       /* const int str[NDIM] = */ {wt_strides[0], wt_strides[1], wt_strides[2]},
      /* const int pad[NDIM] = */ {padding[0], padding[1], padding[2]},
      /* const int kdil[NDIM] = */
@@ -635,7 +635,7 @@ void strided_reduce_longcolumn(
   }
 
   // Prepare the temporary accumulator
-  std::vector<int> intermediate_shape;
+  Shape intermediate_shape;
   intermediate_shape.reserve(out.ndim() + 1);
   intermediate_shape.push_back(outer_blocks);
   intermediate_shape.insert(
@@ -806,7 +806,7 @@ void strided_reduce_2pass(
   auto [in_type, out_type] = remap_reduce_types(in, op_name);
 
   // Prepare the temporary accumulator
-  std::vector<int> intermediate_shape;
+  Shape intermediate_shape;
   intermediate_shape.reserve(out.ndim() + 1);
   intermediate_shape.push_back(32);
   intermediate_shape.insert(
@@ -63,8 +63,8 @@ void pad_gpu(
     const array& in,
     const array& val,
     array& out,
-    std::vector<int> axes,
-    std::vector<int> low_pad_size,
+    const std::vector<int>& axes,
+    const Shape& low_pad_size,
     const Stream& s) {
   // Fill output with val
   fill_gpu(val, out, s);
@@ -23,8 +23,8 @@ void pad_gpu(
     const array& in,
     const array& val,
     array& out,
-    std::vector<int> axes,
-    std::vector<int> low_pad_size,
+    const std::vector<int>& axes,
+    const Shape& low_pad_size,
     const Stream& s);
 
 } // namespace mlx::core
@@ -82,7 +82,7 @@ array send(
 }
 
 array recv(
-    std::vector<int> shape,
+    Shape shape,
     Dtype dtype,
     int src,
     std::optional<Group> group_ /* = std::nullopt */,
@@ -26,7 +26,7 @@ array send(
     StreamOrDevice s = {});
 
 array recv(
-    std::vector<int> shape,
+    Shape shape,
     Dtype dtype,
     int src,
     std::optional<Group> group = std::nullopt,
@@ -91,7 +91,7 @@ std::vector<array> AllGather::vjp(
     const std::vector<int>& argnums,
     const std::vector<array>& outputs) {
   auto g = group();
-  std::vector<int> starts(primals[0].ndim(), 0);
+  Shape starts(primals[0].ndim(), 0);
   auto stops = primals[0].shape();
   starts[0] = g.rank() * stops[0];
   stops[0] += starts[0];
@@ -108,7 +108,7 @@ bool disjoint(const CharSet& x, const CharSet& y) {
 }
 
 template <typename T>
-size_t term_size(const T& term, std::unordered_map<char, int> dict) {
+size_t term_size(const T& term, std::unordered_map<char, ShapeElem> dict) {
   size_t size = 1;
   for (auto c : term) {
     size *= dict[c];
@@ -120,7 +120,7 @@ size_t flop_count(
     const CharSet& term,
     bool inner,
     int num_terms,
-    std::unordered_map<char, int> dict) {
+    std::unordered_map<char, ShapeElem> dict) {
   size_t size = term_size(term, dict);
   auto op_factor = 1;
   if ((num_terms - 1) > op_factor) {
@@ -135,7 +135,7 @@ size_t flop_count(
 std::pair<size_t, int> compute_cost_and_scaling(
     const std::vector<Subscript>& inputs,
     const Subscript& output,
-    std::unordered_map<char, int> dim_map) {
+    std::unordered_map<char, ShapeElem> dim_map) {
   CharSet contractions;
   for (auto& in : inputs) {
     contractions.insert(in.set.begin(), in.set.end());
@@ -155,7 +155,7 @@ std::pair<size_t, int> compute_cost_and_scaling(
 std::tuple<std::vector<PathNode>, size_t, int> greedy_path(
     std::vector<Subscript> inputs,
     const Subscript& output,
-    std::unordered_map<char, int> dim_map,
+    std::unordered_map<char, ShapeElem> dim_map,
     size_t cost_limit,
     size_t memory_limit) {
   // Helper struct for building the greedy path
@@ -457,7 +457,8 @@ array collapse_repeats(array in, Subscript& subscript, StreamOrDevice s) {
     }
     Shape idx_shape(n_expand--, 1);
     idx_shape[0] = in.shape(axes.back());
-    auto idx = reshape(arange(in.shape(axes.back()), s), idx_shape, s);
+    auto idx = reshape(
+        arange(static_cast<ShapeElem>(in.shape(axes.back())), s), idx_shape, s);
     for (int i = 0; i < v; ++i) {
       indices.push_back(idx);
     }
@@ -663,7 +664,7 @@ std::pair<std::vector<PathNode>, PathInfo> einsum_path_helper(
   }
   Subscript output(out_subscript, std::move(out_set));
 
-  std::unordered_map<char, int> dim_map;
+  std::unordered_map<char, ShapeElem> dim_map;
   std::vector<Subscript> inputs;
   for (int i = 0; i < in_subscripts.size(); ++i) {
     auto& in = in_subscripts[i];
@@ -680,7 +681,7 @@ std::pair<std::vector<PathNode>, PathInfo> einsum_path_helper(
 
   // Check repeat subscripts are valid
   if (in_set.size() < in.size()) {
-    std::unordered_map<char, int> local_dims;
+    std::unordered_map<char, ShapeElem> local_dims;
     for (int j = 0; j < in.size(); ++j) {
       auto dim = operands[i].shape(j);
       auto inserted = local_dims.insert({in[j], dim});
@@ -670,8 +670,7 @@ array scaled_dot_product_attention(
       supports_sdpa_full || supports_sdpa_vector;
 
   if (implementation_supports_use_case) {
-    auto out_shape =
-        std::vector<int>({q.shape(0), q.shape(1), q.shape(2), v.shape(-1)});
+    auto out_shape = Shape{q.shape(0), q.shape(1), q.shape(2), v.shape(-1)};
     return array(
         std::move(out_shape),
         final_type,
@@ -59,7 +59,7 @@ typedef std::variant<int, bool, Dtype> TemplateArg;
 
 typedef std::function<std::vector<array>(
     const std::vector<array>&,
-    const std::vector<std::vector<int>>&,
+    const std::vector<Shape>&,
     const std::vector<Dtype>&,
     std::tuple<int, int, int>,
     std::tuple<int, int, int>,
@@ -47,8 +47,8 @@ std::optional<Dtype> gguf_type_to_dtype(const uint32_t& gguf_type) {
   }
 }
 
-std::vector<int> get_shape(const gguf_tensor& tensor) {
-  std::vector<int> shape;
+Shape get_shape(const gguf_tensor& tensor) {
+  Shape shape;
   // The dimension order in GGML is the reverse of the order used in MLX.
   for (int i = tensor.ndim - 1; i >= 0; i--) {
     shape.push_back(tensor.dim[i]);
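For context on the reversal the comment describes, here is a hedged standalone sketch; GGMLTensor is a hypothetical stand-in for the gguf_tensor fields used above, not the real gguflib type:

#include <cstdint>
#include <vector>

using ShapeElem = int32_t;
using Shape = std::vector<ShapeElem>;

// Hypothetical stand-in for the gguf_tensor fields used above.
struct GGMLTensor {
  uint32_t ndim;
  uint64_t dim[4];
};

Shape get_shape_sketch(const GGMLTensor& t) {
  Shape shape;
  // GGML lists the fastest-moving dimension first and MLX lists it last,
  // so walk the dims back to front.
  for (int i = static_cast<int>(t.ndim) - 1; i >= 0; i--) {
    shape.push_back(static_cast<ShapeElem>(t.dim[i]));
  }
  return shape;
}

int main() {
  GGMLTensor t{3, {32, 4, 2, 0}};
  Shape s = get_shape_sketch(t); // {2, 4, 32}
  return s.front() == 2 ? 0 : 1;
}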
@@ -12,7 +12,7 @@ extern "C" {
 
 namespace mlx::core {
 
-std::vector<int> get_shape(const gguf_tensor& tensor);
+Shape get_shape(const gguf_tensor& tensor);
 void gguf_load_quantized(
     std::unordered_map<std::string, array>& a,
     const gguf_tensor& tensor);
@@ -109,7 +109,7 @@ void gguf_load_quantized(
 
   std::string name(tensor.name, tensor.namelen);
 
-  std::vector<int> shape = get_shape(tensor);
+  auto shape = get_shape(tensor);
   const uint64_t weights_per_block = 32;
   if (shape[shape.size() - 1] % weights_per_block != 0) {
     std::ostringstream msg;
@@ -118,7 +118,7 @@ void gguf_load_quantized(
     throw std::runtime_error(msg.str());
   }
 
-  std::vector<int> weights_shape = shape;
+  auto weights_shape = shape;
   weights_shape.back() /= (weights_per_byte * 4);
   auto w_nbytes = uint32.size() *
       std::accumulate(weights_shape.begin(),
@@ -271,7 +271,7 @@ array load(std::shared_ptr<io::Reader> in_stream, StreamOrDevice s) {
   bool col_contiguous = header[34] == 'T';
 
   // Read array shape from header
-  std::vector<int> shape;
+  Shape shape;
 
   size_t st = header.find_last_of('(') + 1;
   size_t ed = header.find_last_of(')');
@@ -219,15 +219,15 @@ std::vector<array> svd(const array& a, StreamOrDevice s /* = {} */) {
   const auto n = a.shape(-1);
   const auto rank = a.ndim();
 
-  std::vector<int> u_shape = a.shape();
+  auto u_shape = a.shape();
   u_shape[rank - 2] = m;
   u_shape[rank - 1] = m;
 
-  std::vector<int> s_shape = a.shape();
+  auto s_shape = a.shape();
   s_shape.pop_back();
   s_shape[rank - 2] = std::min(m, n);
 
-  std::vector<int> vt_shape = a.shape();
+  auto vt_shape = a.shape();
   vt_shape[rank - 2] = n;
   vt_shape[rank - 1] = n;
 
@@ -328,8 +328,8 @@ array pinv(const array& a, StreamOrDevice s /* = {} */) {
   array S = outs[1];
   array V = outs[2];
 
-  std::vector<int> starts(a.ndim(), 0);
-  std::vector<int> ends = a.shape();
+  Shape starts(a.ndim(), 0);
+  auto ends = a.shape();
   int i = a.ndim() - 2;
   int j = a.ndim() - 1;
 
@@ -479,7 +479,7 @@ array eigvalsh(
     std::string UPLO /* = "L" */,
     StreamOrDevice s /* = {} */) {
   validate_eigh(a, "[linalg::eigvalsh]");
-  std::vector<int> out_shape(a.shape().begin(), a.shape().end() - 1);
+  Shape out_shape(a.shape().begin(), a.shape().end() - 1);
   return array(
       std::move(out_shape),
       a.dtype(),
@@ -493,7 +493,7 @@ std::pair<array, array> eigh(
     StreamOrDevice s /* = {} */) {
   validate_eigh(a, "[linalg::eigh]");
   auto out = array::make_arrays(
-      {std::vector<int>(a.shape().begin(), a.shape().end() - 1), a.shape()},
+      {Shape(a.shape().begin(), a.shape().end() - 1), a.shape()},
       {a.dtype(), a.dtype()},
      std::make_shared<Eigh>(to_stream(s), UPLO, true),
      {a});
mlx/ops.cpp (42 changed lines)
@@ -649,7 +649,7 @@ normalize_slice(const Shape& shape, Shape& start, Shape& stop, Shape& strides) {
 
       // Clamp to bounds
       auto st = std::min(s, n - 1);
-      auto ed = std::max(-1, e);
+      auto ed = e > -1 ? e : -1;
 
       start[i] = st;
       stop[i] = ed > st ? st : ed;
@@ -659,8 +659,8 @@ normalize_slice(const Shape& shape, Shape& start, Shape& stop, Shape& strides) {
 
     } else {
       // Clamp to bounds
-      auto st = std::max(0, std::min(s, n));
-      auto ed = std::max(0, std::min(e, n));
+      auto st = std::max(static_cast<ShapeElem>(0), std::min(s, n));
+      auto ed = std::max(static_cast<ShapeElem>(0), std::min(e, n));
 
       start[i] = st;
       stop[i] = ed < st ? st : ed;
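Aside (my reading of why these lines changed, not stated in the patch): std::max deduces one template parameter from both arguments, so once the slice bounds carry the ShapeElem type, a mixed call like std::max(-1, e) fails to deduce wherever ShapeElem is not exactly int, e.g. if it were ever widened. A standalone sketch of the failure mode and the two fixes used above:

#include <algorithm>
#include <cstdint>

int main() {
  int64_t e = -3; // imagine ShapeElem widened to 64 bits
  // auto ed = std::max(-1, e);  // error: deduction conflict, int vs int64_t
  auto ed = e > -1 ? e : int64_t{-1};               // works for any width
  auto ed2 = std::max(static_cast<int64_t>(-1), e); // equivalent explicit fix
  return (ed == -1 && ed2 == -1) ? 0 : 1;
}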
@@ -765,7 +765,7 @@ array slice_update(
 
 std::vector<array> split(
     const array& a,
-    const std::vector<int>& indices,
+    const Shape& indices,
     int axis,
     StreamOrDevice s /* = {} */) {
   auto ax = axis < 0 ? axis + a.ndim() : axis;
@@ -809,10 +809,8 @@ std::vector<array> split(
   return res;
 }
 
-std::vector<array> split(
-    const array& a,
-    const std::vector<int>& indices,
-    StreamOrDevice s /* = {} */) {
+std::vector<array>
+split(const array& a, const Shape& indices, StreamOrDevice s /* = {} */) {
   return split(a, indices, 0, s);
 }
 
@@ -834,7 +832,7 @@ split(const array& a, int num_splits, int axis, StreamOrDevice s /* = {} */) {
     throw std::invalid_argument(msg.str());
   }
   auto split_size = q_and_r.quot;
-  std::vector<int> indices(num_splits - 1);
+  Shape indices(num_splits - 1);
   for (int i = 0; i < indices.size(); ++i) {
     indices[i] = (i + 1) * split_size;
   }
@@ -1104,7 +1102,7 @@ array edge_pad(
 /** Pad an array with a constant value */
 array pad(
     const array& a,
-    const Shape& axes,
+    const std::vector<int>& axes,
     const Shape& low_pad_size,
     const Shape& high_pad_size,
     const array& pad_value /*= array(0)*/,
@@ -1904,9 +1902,11 @@ array min(
 
 array argmin(const array& a, bool keepdims, StreamOrDevice s /* = {} */) {
   int size = a.size();
-  auto result = argmin(reshape(a, {size}, s), 0, true, s);
+  auto result = argmin(flatten(a, s), 0, true, s);
   if (keepdims) {
-    result = reshape(result, std::vector<int>(a.shape().size(), 1), s);
+    std::vector<int> axes(a.ndim() - 1);
+    std::iota(axes.begin(), axes.end(), 0);
+    result = expand_dims(result, axes, s);
   } else {
     result = squeeze(result, s);
   }
@@ -1940,9 +1940,11 @@ array argmin(
 
 array argmax(const array& a, bool keepdims, StreamOrDevice s /* = {} */) {
   int size = a.size();
-  auto result = argmax(reshape(a, {size}, s), 0, true, s);
+  auto result = argmax(flatten(a, s), 0, true, s);
   if (keepdims) {
-    result = reshape(result, Shape(a.shape().size(), 1), s);
+    std::vector<int> axes(a.ndim() - 1);
+    std::iota(axes.begin(), axes.end(), 0);
+    result = expand_dims(result, axes, s);
  } else {
    result = squeeze(result, s);
  }
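The keepdims path now rebuilds the singleton dimensions with expand_dims instead of a reshape; the two are equivalent here, since the flat argmin/argmax result has shape {1}. A standalone sketch of the shape bookkeeping (plain C++, no MLX calls):

#include <cstdint>
#include <numeric>
#include <vector>

using Shape = std::vector<int32_t>;

// After argmin(flatten(a), 0, /*keepdims=*/true) the result has shape {1}.
// Expanding axes 0..ndim-2 turns {1} into ndim ones, exactly what the old
// reshape(result, Shape(ndim, 1)) produced.
Shape keepdims_shape(int ndim) {
  std::vector<int> axes(ndim - 1);
  std::iota(axes.begin(), axes.end(), 0); // axes = {0, 1, ..., ndim-2}
  Shape out{1};                           // shape after the flat reduction
  for (int ax : axes) {
    out.insert(out.begin() + ax, 1);      // what expand_dims does per axis
  }
  return out;
}

int main() {
  return keepdims_shape(3) == Shape{1, 1, 1} ? 0 : 1;
}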
@@ -3238,8 +3240,8 @@ inline int dilate_size(int dim, int dil) {
 }
 
 Shape conv_out_shape(
-    const std::vector<int>& in_shape,
-    const std::vector<int>& wt_shape,
+    const Shape& in_shape,
+    const Shape& wt_shape,
     const std::vector<int>& strides,
     const std::vector<int>& pads_lo,
     const std::vector<int>& pads_hi,
@@ -4329,16 +4331,16 @@ array diagonal(
         "[diagonal] axis1 and axis2 cannot be the same axis");
   }
 
-  auto off1 = std::max(-offset, 0);
-  auto off2 = std::max(offset, 0);
+  ShapeElem off1 = std::max(-offset, 0);
+  ShapeElem off2 = std::max(offset, 0);
 
   auto diag_size = std::min(a.shape(ax1) - off1, a.shape(ax2) - off2);
-  diag_size = std::max(diag_size, 0);
+  diag_size = diag_size < 0 ? 0 : diag_size;
 
   std::vector<array> indices = {
       arange(off1, off1 + diag_size, s), arange(off2, off2 + diag_size, s)};
 
-  std::vector<int> slice_sizes = a.shape();
+  Shape slice_sizes = a.shape();
   slice_sizes[ax1] = 1;
   slice_sizes[ax2] = 1;
 
mlx/ops.h (19 changed lines)
@@ -189,13 +189,10 @@ array slice_update(
 std::vector<array>
 split(const array& a, int num_splits, int axis, StreamOrDevice s = {});
 std::vector<array> split(const array& a, int num_splits, StreamOrDevice s = {});
-std::vector<array> split(
-    const array& a,
-    const std::vector<int>& indices,
-    int axis,
-    StreamOrDevice s = {});
 std::vector<array>
-split(const array& a, const std::vector<int>& indices, StreamOrDevice s = {});
+split(const array& a, const Shape& indices, int axis, StreamOrDevice s = {});
+std::vector<array>
+split(const array& a, const Shape& indices, StreamOrDevice s = {});
 
 /** A vector of coordinate arrays from coordinate vectors. */
 std::vector<array> meshgrid(
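Callers now pass split points as a Shape; since Shape is a std::vector alias, braced literals like {2, 8} keep working, as the tests at the end of this diff show. A standalone sketch of the split-point semantics itself:

#include <cstdint>
#include <utility>
#include <vector>

using ShapeElem = int32_t;
using Shape = std::vector<ShapeElem>;

// Split points {2, 8} cut a length-12 axis into [0,2), [2,8), [8,12).
std::vector<std::pair<ShapeElem, ShapeElem>> segments(
    ShapeElem axis_size,
    const Shape& indices) {
  std::vector<std::pair<ShapeElem, ShapeElem>> out;
  ShapeElem lo = 0;
  for (auto cut : indices) {
    out.push_back({lo, cut}); // one segment per split point
    lo = cut;
  }
  out.push_back({lo, axis_size}); // remainder of the axis
  return out;
}

int main() {
  auto segs = segments(12, {2, 8});
  return (segs.size() == 3 && segs[1].first == 2 && segs[1].second == 8) ? 0
                                                                         : 1;
}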
@@ -253,8 +250,8 @@ array moveaxis(
 array pad(
     const array& a,
     const std::vector<int>& axes,
-    const std::vector<int>& low_pad_size,
-    const std::vector<int>& high_pad_size,
+    const Shape& low_pad_size,
+    const Shape& high_pad_size,
     const array& pad_value = array(0),
     const std::string mode = "constant",
     StreamOrDevice s = {});
@@ -1453,7 +1450,11 @@ array view(const array& a, const Dtype& dtype, StreamOrDevice s = {});
 array roll(const array& a, int shift, StreamOrDevice s = {});
 array roll(const array& a, const Shape& shift, StreamOrDevice s = {});
 array roll(const array& a, int shift, int axis, StreamOrDevice s = {});
-array roll(const array& a, int shift, const Shape& axes, StreamOrDevice s = {});
+array roll(
+    const array& a,
+    int shift,
+    const std::vector<int>& axes,
+    StreamOrDevice s = {});
 array roll(const array& a, const Shape& shift, int axis, StreamOrDevice s = {});
 array roll(
     const array& a,
@@ -817,10 +817,10 @@ std::vector<array> Concatenate::vjp(
     const std::vector<int>& argnums,
     const std::vector<array>&) {
   auto& cotan = cotangents[0];
-  std::vector<int> start(cotan.ndim(), 0);
-  std::vector<int> stop = cotan.shape();
+  Shape start(cotan.ndim(), 0);
+  Shape stop = cotan.shape();
 
-  std::vector<int> sizes;
+  Shape sizes;
   sizes.push_back(0);
   for (auto& p : primals) {
     sizes.push_back(p.shape(axis_));
@@ -956,9 +956,9 @@ array conv_weight_backward_patches(
     const std::vector<int>& padding,
     StreamOrDevice s) {
   // Resolve Padded input shapes and strides
-  std::vector<int> padding_starts(in.ndim(), 0);
-  std::vector<int> padding_ends = in.shape();
-  std::vector<int> in_padded_shape = in.shape();
+  Shape padding_starts(in.ndim(), 0);
+  auto padding_ends = in.shape();
+  auto in_padded_shape = in.shape();
 
   // padded shape
   for (int i = 1; i < in.ndim() - 1; i++) {
@@ -976,8 +976,9 @@ array conv_weight_backward_patches(
   // Pad input
   std::vector<int> padded_axes(in.ndim() - 2, 0);
   std::iota(padded_axes.begin(), padded_axes.end(), 1);
+  Shape padding_(padding.begin(), padding.end());
   auto in_padded = pad(
-      in, padded_axes, padding, padding, array(0, in.dtype()), "constant", s);
+      in, padded_axes, padding_, padding_, array(0, in.dtype()), "constant", s);
 
   // Resolve strided patches
 
@@ -1797,7 +1798,7 @@ std::vector<array> FFT::vjp(
   std::vector<int> axes(axes_.begin(), axes_.end());
   if (real_ && inverse_) {
     auto out = fft::fftn(cotangents[0], axes, stream());
-    auto start = std::vector<int>(out.ndim(), 0);
+    auto start = Shape(out.ndim(), 0);
     auto stop = in.shape();
     out = slice(out, start, stop, stream());
     auto mask_shape = out.shape();
@@ -1809,7 +1810,7 @@ std::vector<array> FFT::vjp(
     mask = concatenate({pad, mask, pad}, axes_.back(), stream());
     return {multiply(mask, out, stream())};
   } else if (real_) {
-    std::vector<int> n;
+    Shape n;
     for (auto ax : axes_) {
       n.push_back(in.shape()[ax]);
     }
@@ -1934,10 +1935,11 @@ std::pair<std::vector<array>, std::vector<int>> Gather::vmap(
   }
   if (indices_vmapped) {
     // Make a new index array for the vmapped dimension
-    auto vmap_inds = arange(0, src.shape(axes[0]), stream());
+    auto vmap_inds =
+        arange(static_cast<ShapeElem>(0), src.shape(axes[0]), stream());
     // Reshape it so it broadcasts with other index arrays
     {
-      auto shape = std::vector<int>(idx_dims, 1);
+      auto shape = Shape(idx_dims, 1);
       shape[out_ax] = vmap_inds.size();
       vmap_inds = reshape(vmap_inds, std::move(shape), stream());
     }
@@ -2628,8 +2630,8 @@ std::vector<array> Pad::vjp(
   assert(argnums.size() == 1 && argnums[0] == 0);
 
   auto& cotan = cotangents[0];
-  std::vector<int> start(cotan.ndim(), 0);
-  std::vector<int> stop = cotan.shape();
+  Shape start(cotan.ndim(), 0);
+  auto stop = cotan.shape();
 
   for (auto i : axes_) {
     start[i] = low_pad_size_[i];
@@ -3019,7 +3021,7 @@ std::vector<array> Reduce::vjp(
     const std::vector<array>& outputs) {
   auto in = primals[0];
 
-  std::vector<int> shape = in.shape();
+  auto shape = in.shape();
   for (auto ax : axes_) {
     shape[ax] = 1;
   }
@@ -3044,7 +3046,7 @@ std::vector<array> Reduce::vjp(
   if (axes_.size() > 1) {
     std::vector<int> transpose_to;
     std::vector<int> transpose_back;
-    std::vector<int> shape_flat;
+    Shape shape_flat;
     {
       // Find the transpose needed to move axes_ to the back and the shape
       // except the reduced over axes.
@@ -3422,7 +3424,7 @@ std::pair<std::vector<array>, std::vector<int>> Scatter::vmap(
   }
 
   auto vmap_inds = arange(vmap_size, inputs[1].dtype(), stream());
-  auto vmap_inds_shape = std::vector<int>(inputs[1].ndim(), 1);
+  auto vmap_inds_shape = Shape(inputs[1].ndim(), 1);
   vmap_inds_shape[0] = vmap_inds.size();
   vmap_inds = reshape(vmap_inds, std::move(vmap_inds_shape), stream());
   inputs.insert(
@@ -3607,7 +3609,7 @@ std::vector<array> Slice::vjp(
   // Transpose and reshape cotangents
   auto cotan = cotangents[0];
   if (!ind_axes.empty()) {
-    std::vector<int> cotan_shape;
+    Shape cotan_shape;
     for (auto ax : ind_axes) {
       cotan_shape.push_back(cotan.shape(ax));
     }
@@ -3626,7 +3628,7 @@ std::vector<array> Slice::vjp(
   }
 
   // Make indices broadcastable
-  std::vector<int> inds_shape(inds.size(), 1);
+  Shape inds_shape(inds.size(), 1);
   for (int i = 0; i < inds.size(); ++i) {
     inds_shape[i] = inds[i].size();
     inds[i] = reshape(inds[i], inds_shape, stream());
@@ -4184,7 +4186,7 @@ std::vector<array> BlockMaskedMM::vjp(
     // Slice mask
     mask_reshape[mask_ndim - 2] = Y;
     mask_reshape[mask_ndim - 1] = X;
-    mask = slice(mask, std::vector<int>(mask_ndim, 0), mask_reshape, stream());
+    mask = slice(mask, Shape(mask_ndim, 0), mask_reshape, stream());
 
     return mask;
   };
@@ -4202,7 +4204,7 @@ std::vector<array> BlockMaskedMM::vjp(
     }
 
     // Reshape
-    std::vector<int> r_reshape(r.shape().begin(), r.shape().end() - 2);
+    Shape r_reshape(r.shape().begin(), r.shape().end() - 2);
     r_reshape.push_back(r.shape(-2) / block_size_);
     r_reshape.push_back(block_size_);
     r_reshape.push_back(r.shape(-1) / block_size_);
@@ -4492,7 +4494,7 @@ std::pair<std::vector<array>, std::vector<int>> NumberOfElements::vmap(
   }
 
   array out = array(
-      std::vector<int>{},
+      {},
       dtype_,
      std::make_shared<NumberOfElements>(stream(), new_axes, inverted_, dtype_),
      inputs);
@@ -1088,10 +1088,7 @@ class Full : public UnaryPrimitive {
 
 class Gather : public UnaryPrimitive {
  public:
-  explicit Gather(
-      Stream stream,
-      std::vector<int> axes,
-      std::vector<int> slice_sizes)
+  explicit Gather(Stream stream, std::vector<int> axes, Shape slice_sizes)
       : UnaryPrimitive(stream),
         axes_(std::move(axes)),
         slice_sizes_(std::move(slice_sizes)) {}
@@ -1108,7 +1105,7 @@ class Gather : public UnaryPrimitive {
  private:
   void eval(const std::vector<array>& inputs, array& out);
   std::vector<int> axes_;
-  std::vector<int> slice_sizes_;
+  Shape slice_sizes_;
 };
 
 class Greater : public UnaryPrimitive {
@@ -1503,8 +1500,8 @@ class Pad : public UnaryPrimitive {
   explicit Pad(
       Stream stream,
       const std::vector<int>& axes,
-      const std::vector<int>& low_pad_size,
-      const std::vector<int>& high_pad_size)
+      const Shape& low_pad_size,
+      const Shape& high_pad_size)
       : UnaryPrimitive(stream),
         axes_(axes),
         low_pad_size_(low_pad_size),
@@ -1520,8 +1517,8 @@ class Pad : public UnaryPrimitive {
 
  private:
   std::vector<int> axes_;
-  std::vector<int> low_pad_size_;
-  std::vector<int> high_pad_size_;
+  Shape low_pad_size_;
+  Shape high_pad_size_;
 
   void eval(const std::vector<array>& inputs, array& out);
 };
@@ -1903,9 +1900,9 @@ class Slice : public UnaryPrimitive {
  public:
   explicit Slice(
       Stream stream,
-      const std::vector<int>& start_indices,
-      const std::vector<int>& end_indices,
-      const std::vector<int>& strides)
+      const Shape& start_indices,
+      const Shape& end_indices,
+      const Shape& strides)
       : UnaryPrimitive(stream),
         start_indices_(start_indices),
         end_indices_(end_indices),
@@ -1920,9 +1917,9 @@ class Slice : public UnaryPrimitive {
   bool is_equivalent(const Primitive& other) const override;
 
  private:
-  std::vector<int> start_indices_;
-  std::vector<int> end_indices_;
-  std::vector<int> strides_;
+  Shape start_indices_;
+  Shape end_indices_;
+  Shape strides_;
 
   void eval(const std::vector<array>& inputs, array& out);
 };
@@ -1931,9 +1928,9 @@ class SliceUpdate : public UnaryPrimitive {
  public:
   explicit SliceUpdate(
       Stream stream,
-      const std::vector<int>& start_indices,
-      const std::vector<int>& end_indices,
-      const std::vector<int>& strides)
+      const Shape& start_indices,
+      const Shape& end_indices,
+      const Shape& strides)
       : UnaryPrimitive(stream),
         start_indices_(start_indices),
         end_indices_(end_indices),
@@ -1948,9 +1945,9 @@ class SliceUpdate : public UnaryPrimitive {
   bool is_equivalent(const Primitive& other) const override;
 
  private:
-  std::vector<int> start_indices_;
-  std::vector<int> end_indices_;
-  std::vector<int> strides_;
+  Shape start_indices_;
+  Shape end_indices_;
+  Shape strides_;
 
   void eval(const std::vector<array>& inputs, array& out);
 };
@@ -1997,7 +1994,7 @@ class Sort : public UnaryPrimitive {
 
 class Split : public Primitive {
  public:
-  explicit Split(Stream stream, const std::vector<int>& indices, int axis)
+  explicit Split(Stream stream, const Shape& indices, int axis)
       : Primitive(stream), indices_(indices), axis_(axis) {}
 
   void eval_cpu(const std::vector<array>& inputs, std::vector<array>& outputs)
@@ -2013,7 +2010,7 @@ class Split : public Primitive {
 private:
  void eval(const std::vector<array>& inputs, std::vector<array>& outputs);
 
-  std::vector<int> indices_;
+  Shape indices_;
   int axis_;
 };
 
@@ -296,7 +296,7 @@ std::ostream& operator<<(std::ostream& os, array a) {
   return os;
 }
 
-std::ostream& operator<<(std::ostream& os, const Shape& v) {
+std::ostream& operator<<(std::ostream& os, const std::vector<int>& v) {
   os << "(";
   for (int i = 0; i < v.size(); ++i) {
     os << v[i] << ((i == v.size() - 1) ? "" : ",");
@@ -305,7 +305,7 @@ std::ostream& operator<<(std::ostream& os, const std::vector<int>& v) {
   return os;
 }
 
-std::ostream& operator<<(std::ostream& os, const Strides& v) {
+std::ostream& operator<<(std::ostream& os, const std::vector<int64_t>& v) {
   os << "(";
   for (int i = 0; i < v.size(); ++i) {
     os << v[i] << ((i == v.size() - 1) ? "" : ",");
@@ -77,8 +77,8 @@ std::ostream& operator<<(std::ostream& os, const Stream& s);
 std::ostream& operator<<(std::ostream& os, const Dtype& d);
 std::ostream& operator<<(std::ostream& os, const Dtype::Kind& k);
 std::ostream& operator<<(std::ostream& os, array a);
-std::ostream& operator<<(std::ostream& os, const Shape& v);
-std::ostream& operator<<(std::ostream& os, const Strides& v);
+std::ostream& operator<<(std::ostream& os, const std::vector<int>& v);
+std::ostream& operator<<(std::ostream& os, const std::vector<int64_t>& v);
 inline std::ostream& operator<<(std::ostream& os, const complex64_t& v) {
   return os << v.real() << (v.imag() >= 0 ? "+" : "") << v.imag() << "j";
 }
@@ -889,13 +889,13 @@ void init_array(nb::module_& m) {
       .def(
           "reshape",
           [](const mx::array& a, nb::args shape_, mx::StreamOrDevice s) {
-            std::vector<int> shape;
+            mx::Shape shape;
             if (!nb::isinstance<int>(shape_[0])) {
-              shape = nb::cast<std::vector<int>>(shape_[0]);
+              shape = nb::cast<mx::Shape>(shape_[0]);
             } else {
-              shape = nb::cast<std::vector<int>>(shape_);
+              shape = nb::cast<mx::Shape>(shape_);
             }
-            return mx::reshape(a, shape, s);
+            return mx::reshape(a, std::move(shape), s);
           },
           "shape"_a,
           "stream"_a = nb::none(),
@@ -1182,14 +1182,14 @@ void init_array(nb::module_& m) {
       .def(
           "split",
           [](const mx::array& a,
-             const std::variant<int, std::vector<int>>& indices_or_sections,
+             const std::variant<int, mx::Shape>& indices_or_sections,
              int axis,
              mx::StreamOrDevice s) {
             if (auto pv = std::get_if<int>(&indices_or_sections); pv) {
               return mx::split(a, *pv, axis, s);
             } else {
               return mx::split(
-                  a, std::get<std::vector<int>>(indices_or_sections), axis, s);
+                  a, std::get<mx::Shape>(indices_or_sections), axis, s);
             }
           },
           "indices_or_sections"_a,
@@ -181,7 +181,7 @@ void init_fast(nb::module_& parent_module) {
       return nb::cpp_function(
           [kernel = std::move(kernel)](
               const std::vector<ScalarOrArray>& inputs_,
-              const std::vector<std::vector<int>>& output_shapes,
+              const std::vector<mx::Shape>& output_shapes,
               const std::vector<mx::Dtype>& output_dtypes,
               std::tuple<int, int, int> grid,
               std::tuple<int, int, int> threadgroup,
@@ -79,7 +79,7 @@ void init_fft(nb::module_& parent_module) {
   m.def(
       "fft2",
       [](const mx::array& a,
-         const std::optional<std::vector<int>>& n,
+         const std::optional<mx::Shape>& n,
          const std::optional<std::vector<int>>& axes,
          mx::StreamOrDevice s) {
         if (axes.has_value() && n.has_value()) {
@@ -115,7 +115,7 @@ void init_fft(nb::module_& parent_module) {
   m.def(
       "ifft2",
       [](const mx::array& a,
-         const std::optional<std::vector<int>>& n,
+         const std::optional<mx::Shape>& n,
          const std::optional<std::vector<int>>& axes,
          mx::StreamOrDevice s) {
         if (axes.has_value() && n.has_value()) {
@@ -151,7 +151,7 @@ void init_fft(nb::module_& parent_module) {
   m.def(
       "fftn",
       [](const mx::array& a,
-         const std::optional<std::vector<int>>& n,
+         const std::optional<mx::Shape>& n,
          const std::optional<std::vector<int>>& axes,
          mx::StreamOrDevice s) {
         if (axes.has_value() && n.has_value()) {
@@ -188,7 +188,7 @@ void init_fft(nb::module_& parent_module) {
   m.def(
       "ifftn",
       [](const mx::array& a,
-         const std::optional<std::vector<int>>& n,
+         const std::optional<mx::Shape>& n,
          const std::optional<std::vector<int>>& axes,
          mx::StreamOrDevice s) {
         if (axes.has_value() && n.has_value()) {
@@ -294,7 +294,7 @@ void init_fft(nb::module_& parent_module) {
   m.def(
       "rfft2",
       [](const mx::array& a,
-         const std::optional<std::vector<int>>& n,
+         const std::optional<mx::Shape>& n,
          const std::optional<std::vector<int>>& axes,
          mx::StreamOrDevice s) {
         if (axes.has_value() && n.has_value()) {
@@ -336,7 +336,7 @@ void init_fft(nb::module_& parent_module) {
   m.def(
       "irfft2",
       [](const mx::array& a,
-         const std::optional<std::vector<int>>& n,
+         const std::optional<mx::Shape>& n,
          const std::optional<std::vector<int>>& axes,
          mx::StreamOrDevice s) {
         if (axes.has_value() && n.has_value()) {
@@ -378,7 +378,7 @@ void init_fft(nb::module_& parent_module) {
   m.def(
       "rfftn",
       [](const mx::array& a,
-         const std::optional<std::vector<int>>& n,
+         const std::optional<mx::Shape>& n,
          const std::optional<std::vector<int>>& axes,
          mx::StreamOrDevice s) {
         if (axes.has_value() && n.has_value()) {
@@ -420,7 +420,7 @@ void init_fft(nb::module_& parent_module) {
   m.def(
       "irfftn",
       [](const mx::array& a,
-         const std::optional<std::vector<int>>& n,
+         const std::optional<mx::Shape>& n,
          const std::optional<std::vector<int>>& axes,
          mx::StreamOrDevice s) {
         if (axes.has_value() && n.has_value()) {
@@ -25,9 +25,9 @@ int get_slice_int(nb::object obj, int default_val) {
 }
 
 void get_slice_params(
-    int& starts,
-    int& ends,
-    int& strides,
+    mx::ShapeElem& starts,
+    mx::ShapeElem& ends,
+    mx::ShapeElem& strides,
     const nb::slice& in_slice,
     int axis_size) {
   // Following numpy's convention
@@ -68,9 +68,9 @@ mx::array mlx_get_item_slice(const mx::array& src, const nb::slice& in_slice) {
     return src;
   }
 
-  std::vector<int> starts(src.ndim(), 0);
-  std::vector<int> ends = src.shape();
-  std::vector<int> strides(src.ndim(), 1);
+  mx::Shape starts(src.ndim(), 0);
+  auto ends = src.shape();
+  mx::Shape strides(src.ndim(), 1);
 
   // Check and update slice params
   get_slice_params(starts[0], ends[0], strides[0], in_slice, ends[0]);
@@ -119,7 +119,7 @@ mx::array mlx_gather_nd(
     auto& idx = indices[i];
 
     if (nb::isinstance<nb::slice>(idx)) {
-      int start, end, stride;
+      mx::ShapeElem start, end, stride;
       get_slice_params(
           start, end, stride, nb::cast<nb::slice>(idx), src.shape(i));
 
@@ -168,7 +168,7 @@ mx::array mlx_gather_nd(
   // Do the gather
   std::vector<int> axes(indices.size());
   std::iota(axes.begin(), axes.end(), 0);
-  std::vector<int> slice_sizes = src.shape();
+  auto slice_sizes = src.shape();
   std::fill(slice_sizes.begin(), slice_sizes.begin() + indices.size(), 1);
   src = gather(src, gather_indices, axes, slice_sizes);
 
@@ -179,9 +179,7 @@ mx::array mlx_gather_nd(
   return mx::squeeze(src, axes);
 }
 
-auto mlx_expand_ellipsis(
-    const std::vector<int>& shape,
-    const nb::tuple& entries) {
+auto mlx_expand_ellipsis(const mx::Shape& shape, const nb::tuple& entries) {
   std::vector<nb::object> indices;
 
   // Go over all entries and note the position of ellipsis
@@ -230,7 +228,8 @@ auto mlx_expand_ellipsis(
     for (int axis = non_none_indices_before;
          axis < shape.size() - non_none_indices_after;
          axis++) {
-      indices.push_back(nb::slice(0, shape[axis], 1));
+      indices.push_back(
+          nb::slice(mx::ShapeElem{0}, shape[axis], mx::ShapeElem{1}));
       non_none_indices++;
     }
   }
@@ -371,9 +370,9 @@ mx::array mlx_get_item_nd(mx::array src, const nb::tuple& entries) {
 
   // Slice handling
   {
-    std::vector<int> starts(src.ndim(), 0);
-    std::vector<int> ends = src.shape();
-    std::vector<int> strides(src.ndim(), 1);
+    mx::Shape starts(src.ndim(), 0);
+    auto ends = src.shape();
+    mx::Shape strides(src.ndim(), 1);
     int axis = 0;
     for (auto& idx : remaining_indices) {
       if (!idx.is_none()) {
@@ -461,8 +460,7 @@ mlx_scatter_args_int(
   int s = 0;
   for (; s < update.ndim() && update.shape(s) == 1; s++)
     ;
-  auto up_shape =
-      std::vector<int>(update.shape().begin() + s, update.shape().end());
+  auto up_shape = mx::Shape(update.shape().begin() + s, update.shape().end());
   auto shape = src.shape();
   shape[0] = 1;
 
@@ -521,9 +519,9 @@ mlx_scatter_args_slice(
         {}, broadcast_to(squeeze_leading_singletons(update), src.shape()), {}};
   }
 
-  int start = 0;
-  int end = src.shape(0);
-  int stride = 1;
+  mx::ShapeElem start = 0;
+  auto end = src.shape(0);
+  mx::ShapeElem stride = 1;
 
   // Check and update slice params
   get_slice_params(start, end, stride, in_slice, end);
@@ -645,7 +643,7 @@ mlx_scatter_args_nd(
   for (int i = 0; i < indices.size(); ++i) {
     auto& pyidx = indices[i];
     if (nb::isinstance<nb::slice>(pyidx)) {
-      int start, end, stride;
+      mx::ShapeElem start, end, stride;
       auto axis_size = src.shape(ax++);
       get_slice_params(
           start, end, stride, nb::cast<nb::slice>(pyidx), axis_size);
@@ -654,7 +652,7 @@ mlx_scatter_args_nd(
       start = (start < 0) ? start + axis_size : start;
       end = (end < 0) ? end + axis_size : end;
 
-      std::vector<int> idx_shape(idx_ndim, 1);
+      mx::Shape idx_shape(idx_ndim, 1);
 
       // If it's a simple slice, we only need to add the start index
       if (array_num >= num_arrays && num_strided_slices <= 0 && stride == 1) {
@@ -1571,15 +1571,14 @@ void init_ops(nb::module_& m) {
       )pbdoc");
   m.def(
       "full",
-      [](const std::variant<int, std::vector<int>>& shape,
+      [](const std::variant<int, mx::Shape>& shape,
         const ScalarOrArray& vals,
         std::optional<mx::Dtype> dtype,
         mx::StreamOrDevice s) {
        if (auto pv = std::get_if<int>(&shape); pv) {
          return mx::full({*pv}, to_array(vals, dtype), s);
        } else {
-          return mx::full(
-              std::get<std::vector<int>>(shape), to_array(vals, dtype), s);
+          return mx::full(std::get<mx::Shape>(shape), to_array(vals, dtype), s);
        }
      },
      "shape"_a,
@@ -1606,14 +1605,14 @@ void init_ops(nb::module_& m) {
       )pbdoc");
   m.def(
       "zeros",
-      [](const std::variant<int, std::vector<int>>& shape,
+      [](const std::variant<int, mx::Shape>& shape,
         std::optional<mx::Dtype> dtype,
         mx::StreamOrDevice s) {
        auto t = dtype.value_or(mx::float32);
        if (auto pv = std::get_if<int>(&shape); pv) {
          return mx::zeros({*pv}, t, s);
        } else {
-          return mx::zeros(std::get<std::vector<int>>(shape), t, s);
+          return mx::zeros(std::get<mx::Shape>(shape), t, s);
        }
      },
      "shape"_a,
@@ -1652,14 +1651,14 @@ void init_ops(nb::module_& m) {
       )pbdoc");
   m.def(
       "ones",
-      [](const std::variant<int, std::vector<int>>& shape,
+      [](const std::variant<int, mx::Shape>& shape,
         std::optional<mx::Dtype> dtype,
         mx::StreamOrDevice s) {
        auto t = dtype.value_or(mx::float32);
        if (auto pv = std::get_if<int>(&shape); pv) {
          return mx::ones({*pv}, t, s);
        } else {
-          return mx::ones(std::get<std::vector<int>>(shape), t, s);
+          return mx::ones(std::get<mx::Shape>(shape), t, s);
        }
      },
      "shape"_a,
@@ -2481,14 +2480,14 @@ void init_ops(nb::module_& m) {
   m.def(
       "split",
       [](const mx::array& a,
-         const std::variant<int, std::vector<int>>& indices_or_sections,
+         const std::variant<int, mx::Shape>& indices_or_sections,
         int axis,
         mx::StreamOrDevice s) {
        if (auto pv = std::get_if<int>(&indices_or_sections); pv) {
          return mx::split(a, *pv, axis, s);
        } else {
          return mx::split(
-              a, std::get<std::vector<int>>(indices_or_sections), axis, s);
+              a, std::get<mx::Shape>(indices_or_sections), axis, s);
        }
      },
      nb::arg(),
@@ -2744,9 +2743,7 @@ void init_ops(nb::module_& m) {
       )pbdoc");
   m.def(
       "broadcast_to",
-      [](const ScalarOrArray& a,
-         const std::vector<int>& shape,
-         mx::StreamOrDevice s) {
+      [](const ScalarOrArray& a, const mx::Shape& shape, mx::StreamOrDevice s) {
         return mx::broadcast_to(to_array(a), shape, s);
       },
       nb::arg(),
@@ -4895,24 +4892,16 @@ void init_ops(nb::module_& m) {
   m.def(
       "roll",
       [](const mx::array& a,
-         const IntOrVec& shift,
+         const std::variant<int, mx::Shape>& shift,
          const IntOrVec& axis,
          mx::StreamOrDevice s) {
         return std::visit(
             [&](auto sh, auto ax) -> mx::array {
-              using T = decltype(ax);
-              using V = decltype(sh);
-
-              if constexpr (std::is_same_v<V, std::monostate>) {
-                throw std::invalid_argument(
-                    "[roll] Expected two arguments but only one was given.");
-              } else {
-                if constexpr (std::is_same_v<T, std::monostate>) {
-                  return mx::roll(a, sh, s);
-                } else {
-                  return mx::roll(a, sh, ax, s);
-                }
+              if constexpr (std::is_same_v<decltype(ax), std::monostate>) {
+                return mx::roll(a, sh, s);
+              } else {
+                return mx::roll(a, sh, ax, s);
               }
             },
             shift,
             axis);
|
||||
"uniform",
|
||||
[](const ScalarOrArray& low,
|
||||
const ScalarOrArray& high,
|
||||
const std::vector<int>& shape,
|
||||
const mx::Shape& shape,
|
||||
std::optional<mx::Dtype> type,
|
||||
const std::optional<mx::array>& key_,
|
||||
mx::StreamOrDevice s) {
|
||||
@ -123,7 +123,7 @@ void init_random(nb::module_& parent_module) {
|
||||
},
|
||||
"low"_a = 0,
|
||||
"high"_a = 1,
|
||||
"shape"_a = std::vector<int>{},
|
||||
"shape"_a = mx::Shape{},
|
||||
"dtype"_a.none() = mx::float32,
|
||||
"key"_a = nb::none(),
|
||||
"stream"_a = nb::none(),
|
||||
@ -150,7 +150,7 @@ void init_random(nb::module_& parent_module) {
|
||||
)pbdoc");
|
||||
m.def(
|
||||
"normal",
|
||||
[](const std::vector<int>& shape,
|
||||
[](const mx::Shape& shape,
|
||||
std::optional<mx::Dtype> type,
|
||||
float loc,
|
||||
float scale,
|
||||
@ -160,7 +160,7 @@ void init_random(nb::module_& parent_module) {
|
||||
return mx::random::normal(
|
||||
shape, type.value_or(mx::float32), loc, scale, key, s);
|
||||
},
|
||||
"shape"_a = std::vector<int>{},
|
||||
"shape"_a = mx::Shape{},
|
||||
"dtype"_a.none() = mx::float32,
|
||||
"loc"_a = 0.0,
|
||||
"scale"_a = 1.0,
|
||||
@ -185,7 +185,7 @@ void init_random(nb::module_& parent_module) {
|
||||
"multivariate_normal",
|
||||
[](const mx::array& mean,
|
||||
const mx::array& cov,
|
||||
const std::vector<int>& shape,
|
||||
const mx::Shape& shape,
|
||||
std::optional<mx::Dtype> type,
|
||||
const std::optional<mx::array>& key_,
|
||||
mx::StreamOrDevice s) {
|
||||
@ -195,7 +195,7 @@ void init_random(nb::module_& parent_module) {
|
||||
},
|
||||
"mean"_a,
|
||||
"cov"_a,
|
||||
"shape"_a = std::vector<int>{},
|
||||
"shape"_a = mx::Shape{},
|
||||
"dtype"_a.none() = mx::float32,
|
||||
"key"_a = nb::none(),
|
||||
"stream"_a = nb::none(),
|
||||
@ -227,7 +227,7 @@ void init_random(nb::module_& parent_module) {
|
||||
"randint",
|
||||
[](const ScalarOrArray& low,
|
||||
const ScalarOrArray& high,
|
||||
const std::vector<int>& shape,
|
||||
const mx::Shape& shape,
|
||||
std::optional<mx::Dtype> type,
|
||||
const std::optional<mx::array>& key_,
|
||||
mx::StreamOrDevice s) {
|
||||
@ -242,7 +242,7 @@ void init_random(nb::module_& parent_module) {
|
||||
},
|
||||
"low"_a,
|
||||
"high"_a,
|
||||
"shape"_a = std::vector<int>{},
|
||||
"shape"_a = mx::Shape{},
|
||||
"dtype"_a.none() = mx::int32,
|
||||
"key"_a = nb::none(),
|
||||
"stream"_a = nb::none(),
|
||||
@ -268,7 +268,7 @@ void init_random(nb::module_& parent_module) {
|
||||
m.def(
|
||||
"bernoulli",
|
||||
[](const ScalarOrArray& p_,
|
||||
const std::optional<std::vector<int>> shape,
|
||||
const std::optional<mx::Shape> shape,
|
||||
const std::optional<mx::array>& key_,
|
||||
mx::StreamOrDevice s) {
|
||||
auto key = key_ ? key_.value() : default_key().next();
|
||||
@ -306,7 +306,7 @@ void init_random(nb::module_& parent_module) {
|
||||
"truncated_normal",
|
||||
[](const ScalarOrArray& lower_,
|
||||
const ScalarOrArray& upper_,
|
||||
const std::optional<std::vector<int>> shape_,
|
||||
const std::optional<mx::Shape> shape_,
|
||||
std::optional<mx::Dtype> type,
|
||||
const std::optional<mx::array>& key_,
|
||||
mx::StreamOrDevice s) {
|
||||
@ -350,14 +350,14 @@ void init_random(nb::module_& parent_module) {
|
||||
)pbdoc");
|
||||
m.def(
|
||||
"gumbel",
|
||||
[](const std::vector<int>& shape,
|
||||
[](const mx::Shape& shape,
|
||||
std::optional<mx::Dtype> type,
|
||||
const std::optional<mx::array>& key_,
|
||||
mx::StreamOrDevice s) {
|
||||
auto key = key_ ? key_.value() : default_key().next();
|
||||
return mx::random::gumbel(shape, type.value_or(mx::float32), key, s);
|
||||
},
|
||||
"shape"_a = std::vector<int>{},
|
||||
"shape"_a = mx::Shape{},
|
||||
"dtype"_a.none() = mx::float32,
|
||||
"key"_a = nb::none(),
|
||||
"stream"_a = nb::none(),
|
||||
@ -384,7 +384,7 @@ void init_random(nb::module_& parent_module) {
|
||||
"categorical",
|
||||
[](const mx::array& logits,
|
||||
int axis,
|
||||
const std::optional<std::vector<int>> shape,
|
||||
const std::optional<mx::Shape> shape,
|
||||
const std::optional<int> num_samples,
|
||||
const std::optional<mx::array>& key_,
|
||||
mx::StreamOrDevice s) {
|
||||
@ -434,7 +434,7 @@ void init_random(nb::module_& parent_module) {
|
||||
)pbdoc");
|
||||
m.def(
|
||||
"laplace",
|
||||
[](const std::vector<int>& shape,
|
||||
[](const mx::Shape& shape,
|
||||
std::optional<mx::Dtype> type,
|
||||
float loc,
|
||||
float scale,
|
||||
@ -444,7 +444,7 @@ void init_random(nb::module_& parent_module) {
|
||||
return mx::random::laplace(
|
||||
shape, type.value_or(mx::float32), loc, scale, key, s);
|
||||
},
|
||||
"shape"_a = std::vector<int>{},
|
||||
"shape"_a = mx::Shape{},
|
||||
"dtype"_a.none() = mx::float32,
|
||||
"loc"_a = 0.0,
|
||||
"scale"_a = 1.0,
|
||||
@ -479,7 +479,7 @@ void init_random(nb::module_& parent_module) {
|
||||
return mx::random::permutation(std::get<mx::array>(x), axis, key, s);
|
||||
}
|
||||
},
|
||||
"shape"_a = std::vector<int>{},
|
||||
"x"_a,
|
||||
"axis"_a = 0,
|
||||
"key"_a = nb::none(),
|
||||
"stream"_a = nb::none(),
|
||||
|
@@ -1,4 +1,5 @@
 // Copyright © 2023-2024 Apple Inc.
 
 #include <nanobind/nanobind.h>
+#include <nanobind/stl/optional.h>
 #include <nanobind/stl/pair.h>
@@ -395,7 +395,7 @@ TEST_CASE("test split") {
   CHECK_EQ(out[1].shape(), Shape{8, 4});
   CHECK_EQ(out[2].shape(), Shape{8, 4});
 
-  out = split(x, std::vector<int>{});
+  out = split(x, Shape{});
   CHECK_EQ(out.size(), 1);
   CHECK_EQ(out[0].shape(), x.shape());
 
@@ -405,25 +405,25 @@ TEST_CASE("test split") {
   CHECK_EQ(out[1].shape(), Shape{4, 12});
   CHECK_EQ(out[2].shape(), Shape{1, 12});
 
-  out = split(x, std::vector<int>{20});
+  out = split(x, Shape{20});
   CHECK_EQ(out.size(), 2);
   CHECK_EQ(out[0].shape(), Shape{8, 12});
   CHECK_EQ(out[1].shape(), Shape{0, 12});
 
   // Negative indices
-  out = split(x, std::vector<int>{-5});
+  out = split(x, Shape{-5});
   CHECK_EQ(out[0].shape(), Shape{3, 12});
   CHECK_EQ(out[1].shape(), Shape{5, 12});
 
   // Different axis
-  out = split(x, std::vector<int>{2, 8}, 1);
+  out = split(x, {2, 8}, 1);
   CHECK_EQ(out[0].shape(), Shape{8, 2});
   CHECK_EQ(out[1].shape(), Shape{8, 6});
   CHECK_EQ(out[2].shape(), Shape{8, 4});
 
   // Out of order indices
   x = arange(5);
-  out = split(x, std::vector<int>{2, 1, 2});
+  out = split(x, {2, 1, 2});
   CHECK(array_equal(out[0], array({0, 1})).item<bool>());
   CHECK(array_equal(out[1], array({})).item<bool>());
   CHECK(array_equal(out[2], array({1})).item<bool>());
@@ -611,8 +611,8 @@ TEST_CASE("test categorical") {
   CHECK_THROWS(categorical(logits, -3));
 
   // Invalid requested shapes
-  CHECK_THROWS(categorical(logits, 1, std::vector<int>{1}));
-  CHECK_THROWS(categorical(logits, 1, std::vector<int>{11}));
+  CHECK_THROWS(categorical(logits, 1, Shape{1}));
+  CHECK_THROWS(categorical(logits, 1, Shape{11}));
   CHECK_THROWS(categorical(logits, 1, {10, 1}));
 
   CHECK_EQ(categorical(logits, -1).shape(), Shape{10});
@@ -335,8 +335,7 @@ TEST_CASE("test vmap gather") {
   auto fun = [](std::vector<array> inputs) {
     auto src = inputs[0];
     auto indices = inputs[1];
-    std::vector<int> slice_sizes = {1, 2, 2};
-    auto out = squeeze(gather(src, indices, 0, slice_sizes), 2);
+    auto out = squeeze(gather(src, indices, 0, {1, 2, 2}), 2);
     return std::vector<array>{out};
   };
   auto x = zeros({2, 2, 2, 2});
@@ -351,8 +350,7 @@ TEST_CASE("test vmap gather") {
   auto fun = [](std::vector<array> inputs) {
     auto src = inputs[0];
     auto indices = inputs[1];
-    std::vector<int> slice_sizes = {1, 2, 2};
-    auto out = squeeze(gather(src, indices, 0, slice_sizes), 1);
+    auto out = squeeze(gather(src, indices, 0, {1, 2, 2}), 1);
     return std::vector<array>{out};
   };
   auto x = zeros({2, 2, 2, 2});
@@ -365,8 +363,7 @@ TEST_CASE("test vmap gather") {
   auto fun = [](std::vector<array> inputs) {
     auto src = inputs[0];
     auto indices = inputs[1];
-    std::vector<int> slice_sizes = {1, 2, 2, 2};
-    auto out = squeeze(gather(src, indices, 0, slice_sizes), 1);
+    auto out = squeeze(gather(src, indices, 0, {1, 2, 2, 2}), 1);
     return std::vector<array>{out};
   };
   auto x = zeros({2, 2, 2, 2});
@@ -380,8 +377,7 @@ TEST_CASE("test vmap gather") {
   auto fun = [](std::vector<array> inputs) {
     auto src = inputs[0];
     auto indices = std::vector<array>(inputs.begin() + 1, inputs.end());
-    std::vector<int> slice_sizes = {1, 1, 2, 2};
-    auto out = squeeze(gather(src, indices, {0, 1}, slice_sizes), {1, 2});
+    auto out = squeeze(gather(src, indices, {0, 1}, {1, 1, 2, 2}), {1, 2});
     return std::vector<array>{out};
   };
   auto x = zeros({2, 2, 2, 2});