mirror of https://github.com/ml-explore/mlx.git
synced 2025-06-25 01:41:17 +08:00
commit e03f0372b1 (parent f17536af9c)
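The diff below converts ad-hoc std::vector<int> shapes to the Shape/ShapeElem aliases defined in mlx/array.h and adds explicit static_cast<int> narrowing where shape elements feed int-typed fields. As a rough standalone sketch of that pattern only — the ConvParams2D struct and main function here are hypothetical and not code from the repository — assuming the aliases shown in the mlx/array.h hunk:

// Minimal, self-contained illustration of the Shape/ShapeElem pattern.
#include <cstdint>
#include <iostream>
#include <vector>

// Aliases as introduced in mlx/array.h by this commit.
using ShapeElem = int32_t;
using Shape = std::vector<ShapeElem>;
using Strides = std::vector<int64_t>;

// Hypothetical stand-in for a params struct with plain int fields,
// so ShapeElem values are narrowed explicitly with static_cast<int>.
struct ConvParams2D {
  int N, C, O;
};

int main() {
  Shape in_shape = {8, 32, 32, 3}; // e.g. an NHWC input shape
  ConvParams2D params{
      /* N = */ static_cast<int>(in_shape[0]),
      /* C = */ static_cast<int>(in_shape[3]),
      /* O = */ 16};
  std::cout << params.N << " " << params.C << " " << params.O << "\n";
  return 0;
}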
@@ -25,7 +25,7 @@ bool retain_graph() {
 } // namespace
 
 array::array(const std::complex<float>& val, Dtype dtype /* = complex64 */)
-    : array_desc_(std::make_shared<ArrayDesc>(std::vector<int>{}, dtype)) {
+    : array_desc_(std::make_shared<ArrayDesc>(Shape{}, dtype)) {
   auto cval = static_cast<complex64_t>(val);
   init(&cval);
 }
@@ -61,14 +61,14 @@ std::vector<array> array::make_arrays(
 
 array::array(std::initializer_list<float> data)
     : array_desc_(std::make_shared<ArrayDesc>(
-          std::vector<int>{static_cast<int>(data.size())},
+          Shape{static_cast<ShapeElem>(data.size())},
           float32)) {
   init(data.begin());
 }
 
 array::array(std::initializer_list<int> data, Dtype dtype)
     : array_desc_(std::make_shared<ArrayDesc>(
-          std::vector<int>{static_cast<int>(data.size())},
+          Shape{static_cast<ShapeElem>(data.size())},
           dtype)) {
   init(data.begin());
 }
@@ -322,7 +322,7 @@ array::ArrayIterator::ArrayIterator(const array& arr, int idx)
 }
 
 array::ArrayIterator::reference array::ArrayIterator::operator*() const {
-  auto start = std::vector<int>(arr.ndim(), 0);
+  auto start = Shape(arr.ndim(), 0);
   auto end = arr.shape();
   auto shape = arr.shape();
   shape.erase(shape.begin());
@@ -17,7 +17,8 @@ namespace mlx::core {
 class Primitive;
 
 using Deleter = std::function<void(allocator::Buffer)>;
-using Shape = std::vector<int32_t>;
+using ShapeElem = int32_t;
+using Shape = std::vector<ShapeElem>;
 using Strides = std::vector<int64_t>;
 
 class array {
@@ -498,7 +499,7 @@ class array {
 
 template <typename T>
 array::array(T val, Dtype dtype /* = TypeToDtype<T>() */)
-    : array_desc_(std::make_shared<ArrayDesc>(std::vector<int>{}, dtype)) {
+    : array_desc_(std::make_shared<ArrayDesc>(Shape{}, dtype)) {
   init(&val);
 }
 
@@ -516,7 +517,7 @@ array::array(
     std::initializer_list<T> data,
     Dtype dtype /* = TypeToDtype<T>() */)
     : array_desc_(std::make_shared<ArrayDesc>(
-          std::vector<int>{static_cast<int>(data.size())},
+          Shape{static_cast<ShapeElem>(data.size())},
           dtype)) {
   init(data.begin());
 }
@@ -130,7 +130,7 @@ std::string build_lib_name(
 
 bool compiled_check_contiguity(
     const std::vector<array>& inputs,
-    const std::vector<int>& shape) {
+    const Shape& shape) {
   bool contiguous = true;
   bool all_contig = true;
   bool all_row_contig = true;
@@ -56,7 +56,7 @@ inline bool is_scalar(const array& x) {
 // Check if we can use a contiguous operation given inputs and the output shape
 bool compiled_check_contiguity(
     const std::vector<array>& inputs,
-    const std::vector<int>& shape);
+    const Shape& shape);
 
 // Allocate space for the outputs possibly with input donation
 void compiled_allocate_outputs(
@@ -726,7 +726,7 @@ void explicit_gemm_conv_1D_cpu(
   auto conv_dtype = float32;
 
   // Pad input
-  std::vector<int> padded_shape = {N, iH + 2 * padding[0], C};
+  Shape padded_shape = {N, iH + 2 * padding[0], C};
   array in_padded(padded_shape, conv_dtype, nullptr, {});
 
   // Fill with zeros
@@ -765,7 +765,7 @@ void explicit_gemm_conv_1D_cpu(
       in_padded, strided_strides, flags, in_strided_view.size(), 0);
 
   // Materialize strided view
-  std::vector<int> strided_reshape = {N * oH, wH * C};
+  Shape strided_reshape = {N * oH, wH * C};
   array in_strided(strided_reshape, in_strided_view.dtype(), nullptr, {});
   copy(in_strided_view, in_strided, CopyType::General);
 
@@ -843,8 +843,7 @@ void explicit_gemm_conv_2D_cpu(
   auto conv_dtype = out.dtype();
 
   // Pad input
-  std::vector<int> padded_shape = {
-      N, iH + 2 * padding[0], iW + 2 * padding[1], C};
+  Shape padded_shape = {N, iH + 2 * padding[0], iW + 2 * padding[1], C};
   array in_padded(padded_shape, conv_dtype, nullptr, {});
 
   // Fill with zeros
@@ -881,7 +880,7 @@ void explicit_gemm_conv_2D_cpu(
       in_padded, strided_strides, flags, in_strided_view.size(), 0);
 
   // Materialize strided view
-  std::vector<int> strided_reshape = {N * oH * oW, wH * wW * C};
+  Shape strided_reshape = {N * oH * oW, wH * wW * C};
   array in_strided(strided_reshape, in_strided_view.dtype(), nullptr, {});
   copy(in_strided_view, in_strided, CopyType::General);
 
@@ -934,19 +933,19 @@ void explicit_gemm_conv_ND_cpu(
     const std::vector<int>& wt_dilation,
     const bool flip) {
   const int N = in.shape(0); // Batch size, should be the same as out.shape(0)
-  const auto iDim = std::vector<int>(
-      in.shape().begin() + 1, in.shape().end() - 1); // Input spatial dim
-  const auto oDim = std::vector<int>(
+  const auto iDim =
+      Shape(in.shape().begin() + 1, in.shape().end() - 1); // Input spatial dim
+  const auto oDim = Shape(
       out.shape().begin() + 1, out.shape().end() - 1); // Output spatial dim
   const int O = wt.shape(0); // Out channels
   const int C = wt.shape(-1); // In channels
-  const auto wDim = std::vector<int>(
-      wt.shape().begin() + 1, wt.shape().end() - 1); // Weight spatial dim
+  const auto wDim =
+      Shape(wt.shape().begin() + 1, wt.shape().end() - 1); // Weight spatial dim
 
   auto conv_dtype = float32;
 
   // Pad input
-  std::vector<int> padded_shape(in.shape().size());
+  Shape padded_shape(in.shape().size());
   padded_shape.front() = N;
   for (size_t i = 0; i < iDim.size(); i++) {
     padded_shape[i + 1] = iDim[i] + 2 * padding[i];
@@ -14,10 +14,10 @@ namespace mlx::core {
 
 namespace {
 
-template <typename T, typename IdxT = int32_t>
+template <typename T>
 struct StridedIterator {
   using iterator_category = std::random_access_iterator_tag;
-  using difference_type = IdxT;
+  using difference_type = int32_t;
   using value_type = T;
   using reference = value_type&;
   using pointer = value_type*;
@@ -107,7 +107,7 @@ struct ContiguousIterator {
       : shape_(a.shape()), strides_(a.strides()) {
     if (!shape_.empty()) {
       std::tie(shape_, strides_) = collapse_contiguous_dims(shape_, strides_);
-      pos_ = std::vector<int>(shape_.size(), 0);
+      pos_ = Shape(shape_.size(), 0);
     }
   }
 
@@ -34,7 +34,7 @@ void explicit_gemm_conv_ND_gpu(
   int implicit_K = wt.size() / conv_params.O;
   int implicit_N = conv_params.O;
   // Prepare unfolding array
-  std::vector<int> unfolded_shape{implicit_M, implicit_K};
+  Shape unfolded_shape{implicit_M, implicit_K};
   array in_unfolded(unfolded_shape, in.dtype(), nullptr, {});
 
   in_unfolded.set_data(allocator::malloc_or_wait(in_unfolded.nbytes()));
@@ -113,7 +113,7 @@ void explicit_gemm_conv_group_ND_gpu(
   }
 
   // Prepare unfolding array
-  std::vector<int> unfolded_shape{implicit_M, implicit_K * groups};
+  Shape unfolded_shape{implicit_M, implicit_K * groups};
   array in_unfolded(unfolded_shape, in.dtype(), nullptr, {});
   in_unfolded.set_data(allocator::malloc_or_wait(in_unfolded.nbytes()));
 
@@ -192,12 +192,12 @@ void conv_1D_gpu(
     bool flip) {
   // Make conv params
   MLXConvParams<1> conv_params{
-      /* const int N = */ in.shape(0),
-      /* const int C = */ in.shape(2),
-      /* const int O = */ wt.shape(0),
-      /* const int iS[NDIM] = */ {in.shape(1)},
-      /* const int wS[NDIM] = */ {wt.shape(1)},
-      /* const int oS[NDIM] = */ {out.shape(1)},
+      /* const int N = */ static_cast<int>(in.shape(0)),
+      /* const int C = */ static_cast<int>(in.shape(2)),
+      /* const int O = */ static_cast<int>(wt.shape(0)),
+      /* const int iS[NDIM] = */ {static_cast<int>(in.shape(1))},
+      /* const int wS[NDIM] = */ {static_cast<int>(wt.shape(1))},
+      /* const int oS[NDIM] = */ {static_cast<int>(out.shape(1))},
       /* const int str[NDIM] = */ {wt_strides[0]},
       /* const int pad[NDIM] = */ {padding[0]},
       /* const int kdil[NDIM] = */ {wt_dilation[0]},
@@ -541,7 +541,7 @@ void winograd_conv_2D_gpu(
     array out,
     const MLXConvParams<2>& conv_params,
     std::vector<array>& copies_w) {
-  std::vector<int> padded_shape = {
+  Shape padded_shape = {
       conv_params.N,
       conv_params.iS[0] + 2 * conv_params.pad[0],
       conv_params.iS[1] + 2 * conv_params.pad[1],
@@ -550,7 +550,7 @@ void winograd_conv_2D_gpu(
   padded_shape[1] = 6 * ((padded_shape[1] - 2 + 5) / 6) + 2;
   padded_shape[2] = 6 * ((padded_shape[2] - 2 + 5) / 6) + 2;
 
-  array in_padded(padded_shape, in.dtype(), nullptr, {});
+  array in_padded(std::move(padded_shape), in.dtype(), nullptr, {});
 
   // Fill with zeros
   array zero_arr = array(0, in.dtype());
@@ -575,12 +575,16 @@ void winograd_conv_2D_gpu(
   copies_w.push_back(in_padded);
 
   MLXConvParams<2> conv_params_updated{
-      /* const int N = */ in_padded.shape(0),
-      /* const int C = */ in_padded.shape(3),
-      /* const int O = */ wt.shape(0),
-      /* const int iS[NDIM] = */ {in_padded.shape(1), in_padded.shape(2)},
-      /* const int wS[NDIM] = */ {wt.shape(1), wt.shape(2)},
-      /* const int oS[NDIM] = */ {out.shape(1), out.shape(2)},
+      /* const int N = */ static_cast<int>(in_padded.shape(0)),
+      /* const int C = */ static_cast<int>(in_padded.shape(3)),
+      /* const int O = */ static_cast<int>(wt.shape(0)),
+      /* const int iS[NDIM] = */
+      {static_cast<int>(in_padded.shape(1)),
+       static_cast<int>(in_padded.shape(2))},
+      /* const int wS[NDIM] = */
+      {static_cast<int>(wt.shape(1)), static_cast<int>(wt.shape(2))},
+      /* const int oS[NDIM] = */
+      {static_cast<int>(out.shape(1)), static_cast<int>(out.shape(2))},
       /* const int str[NDIM] = */ {1, 1},
       /* const int pad[NDIM] = */ {0, 0},
       /* const int kdil[NDIM] = */ {1, 1},
@@ -607,8 +611,8 @@ void winograd_conv_2D_gpu(
   int N_tiles = N_tiles_n * N_tiles_h * N_tiles_w;
 
   // Do filter transform
-  std::vector<int> filt_wg_shape = {8 * 8, conv_params.C, conv_params.O};
-  array filt_wg(filt_wg_shape, wt.dtype(), nullptr, {});
+  Shape filt_wg_shape = {8 * 8, conv_params.C, conv_params.O};
+  array filt_wg(std::move(filt_wg_shape), wt.dtype(), nullptr, {});
   filt_wg.set_data(allocator::malloc_or_wait(filt_wg.nbytes()));
   copies_w.push_back(filt_wg);
   {
@@ -634,8 +638,8 @@ void winograd_conv_2D_gpu(
   }
 
   // Do input transform
-  std::vector<int> inp_wg_shape = {8 * 8, N_tiles, conv_params.C};
-  array inp_wg(inp_wg_shape, in.dtype(), nullptr, {});
+  Shape inp_wg_shape = {8 * 8, N_tiles, conv_params.C};
+  array inp_wg(std::move(inp_wg_shape), in.dtype(), nullptr, {});
   inp_wg.set_data(allocator::malloc_or_wait(inp_wg.nbytes()));
   copies_w.push_back(inp_wg);
   {
@@ -661,8 +665,8 @@ void winograd_conv_2D_gpu(
   }
 
   // Do batched gemm
-  std::vector<int> out_wg_shape = {8 * 8, N_tiles, conv_params.O};
-  array out_wg(out_wg_shape, in.dtype(), nullptr, {});
+  Shape out_wg_shape = {8 * 8, N_tiles, conv_params.O};
+  array out_wg(std::move(out_wg_shape), in.dtype(), nullptr, {});
   out_wg.set_data(allocator::malloc_or_wait(out_wg.nbytes()));
   copies_w.push_back(out_wg);
   {
@@ -723,12 +727,15 @@ void conv_2D_gpu(
     std::vector<array>& copies) {
   // Make conv params
   MLXConvParams<2> conv_params{
-      /* const int N = */ in.shape(0),
-      /* const int C = */ in.shape(3),
-      /* const int O = */ wt.shape(0),
-      /* const int iS[NDIM] = */ {in.shape(1), in.shape(2)},
-      /* const int wS[NDIM] = */ {wt.shape(1), wt.shape(2)},
-      /* const int oS[NDIM] = */ {out.shape(1), out.shape(2)},
+      /* const int N = */ static_cast<int>(in.shape(0)),
+      /* const int C = */ static_cast<int>(in.shape(3)),
+      /* const int O = */ static_cast<int>(wt.shape(0)),
+      /* const int iS[NDIM] = */
+      {static_cast<int>(in.shape(1)), static_cast<int>(in.shape(2))},
+      /* const int wS[NDIM] = */
+      {static_cast<int>(wt.shape(1)), static_cast<int>(wt.shape(2))},
+      /* const int oS[NDIM] = */
+      {static_cast<int>(out.shape(1)), static_cast<int>(out.shape(2))},
       /* const int str[NDIM] = */ {wt_strides[0], wt_strides[1]},
       /* const int pad[NDIM] = */ {padding[0], padding[1]},
       /* const int kdil[NDIM] = */ {wt_dilation[0], wt_dilation[1]},
@@ -800,12 +807,21 @@ void conv_3D_gpu(
     std::vector<array>& copies) {
   // Make conv params
   MLXConvParams<3> conv_params{
-      /* const int N = */ in.shape(0),
-      /* const int C = */ in.shape(4),
-      /* const int O = */ wt.shape(0),
-      /* const int iS[NDIM] = */ {in.shape(1), in.shape(2), in.shape(3)},
-      /* const int wS[NDIM] = */ {wt.shape(1), wt.shape(2), wt.shape(3)},
-      /* const int oS[NDIM] = */ {out.shape(1), out.shape(2), out.shape(3)},
+      /* const int N = */ static_cast<int>(in.shape(0)),
+      /* const int C = */ static_cast<int>(in.shape(4)),
+      /* const int O = */ static_cast<int>(wt.shape(0)),
+      /* const int iS[NDIM] = */
+      {static_cast<int>(in.shape(1)),
+       static_cast<int>(in.shape(2)),
+       static_cast<int>(in.shape(3))},
+      /* const int wS[NDIM] = */
+      {static_cast<int>(wt.shape(1)),
+       static_cast<int>(wt.shape(2)),
+       static_cast<int>(wt.shape(3))},
+      /* const int oS[NDIM] = */
+      {static_cast<int>(out.shape(1)),
+       static_cast<int>(out.shape(2)),
+       static_cast<int>(out.shape(3))},
       /* const int str[NDIM] = */ {wt_strides[0], wt_strides[1], wt_strides[2]},
       /* const int pad[NDIM] = */ {padding[0], padding[1], padding[2]},
       /* const int kdil[NDIM] = */
@@ -635,7 +635,7 @@ void strided_reduce_longcolumn(
   }
 
   // Prepare the temporary accumulator
-  std::vector<int> intermediate_shape;
+  Shape intermediate_shape;
   intermediate_shape.reserve(out.ndim() + 1);
   intermediate_shape.push_back(outer_blocks);
   intermediate_shape.insert(
@@ -806,7 +806,7 @@ void strided_reduce_2pass(
   auto [in_type, out_type] = remap_reduce_types(in, op_name);
 
   // Prepare the temporary accumulator
-  std::vector<int> intermediate_shape;
+  Shape intermediate_shape;
   intermediate_shape.reserve(out.ndim() + 1);
   intermediate_shape.push_back(32);
   intermediate_shape.insert(
@@ -63,8 +63,8 @@ void pad_gpu(
     const array& in,
     const array& val,
     array& out,
-    std::vector<int> axes,
-    std::vector<int> low_pad_size,
+    const std::vector<int>& axes,
+    const Shape& low_pad_size,
     const Stream& s) {
   // Fill output with val
   fill_gpu(val, out, s);
@@ -23,8 +23,8 @@ void pad_gpu(
     const array& in,
     const array& val,
     array& out,
-    std::vector<int> axes,
-    std::vector<int> low_pad_size,
+    const std::vector<int>& axes,
+    const Shape& low_pad_size,
     const Stream& s);
 
 } // namespace mlx::core
@@ -82,7 +82,7 @@ array send(
 }
 
 array recv(
-    std::vector<int> shape,
+    Shape shape,
     Dtype dtype,
     int src,
     std::optional<Group> group_ /* = std::nullopt */,
@@ -26,7 +26,7 @@ array send(
     StreamOrDevice s = {});
 
 array recv(
-    std::vector<int> shape,
+    Shape shape,
     Dtype dtype,
     int src,
     std::optional<Group> group = std::nullopt,
@@ -91,7 +91,7 @@ std::vector<array> AllGather::vjp(
     const std::vector<int>& argnums,
     const std::vector<array>& outputs) {
   auto g = group();
-  std::vector<int> starts(primals[0].ndim(), 0);
+  Shape starts(primals[0].ndim(), 0);
   auto stops = primals[0].shape();
   starts[0] = g.rank() * stops[0];
   stops[0] += starts[0];
@@ -108,7 +108,7 @@ bool disjoint(const CharSet& x, const CharSet& y) {
 }
 
 template <typename T>
-size_t term_size(const T& term, std::unordered_map<char, int> dict) {
+size_t term_size(const T& term, std::unordered_map<char, ShapeElem> dict) {
   size_t size = 1;
   for (auto c : term) {
     size *= dict[c];
@@ -120,7 +120,7 @@ size_t flop_count(
     const CharSet& term,
     bool inner,
     int num_terms,
-    std::unordered_map<char, int> dict) {
+    std::unordered_map<char, ShapeElem> dict) {
   size_t size = term_size(term, dict);
   auto op_factor = 1;
   if ((num_terms - 1) > op_factor) {
@@ -135,7 +135,7 @@ size_t flop_count(
 std::pair<size_t, int> compute_cost_and_scaling(
     const std::vector<Subscript>& inputs,
     const Subscript& output,
-    std::unordered_map<char, int> dim_map) {
+    std::unordered_map<char, ShapeElem> dim_map) {
   CharSet contractions;
   for (auto& in : inputs) {
     contractions.insert(in.set.begin(), in.set.end());
@@ -155,7 +155,7 @@ std::pair<size_t, int> compute_cost_and_scaling(
 std::tuple<std::vector<PathNode>, size_t, int> greedy_path(
     std::vector<Subscript> inputs,
     const Subscript& output,
-    std::unordered_map<char, int> dim_map,
+    std::unordered_map<char, ShapeElem> dim_map,
     size_t cost_limit,
     size_t memory_limit) {
   // Helper struct for building the greedy path
@@ -457,7 +457,8 @@ array collapse_repeats(array in, Subscript& subscript, StreamOrDevice s) {
     }
     Shape idx_shape(n_expand--, 1);
     idx_shape[0] = in.shape(axes.back());
-    auto idx = reshape(arange(in.shape(axes.back()), s), idx_shape, s);
+    auto idx = reshape(
+        arange(static_cast<ShapeElem>(in.shape(axes.back())), s), idx_shape, s);
     for (int i = 0; i < v; ++i) {
       indices.push_back(idx);
     }
@@ -663,7 +664,7 @@ std::pair<std::vector<PathNode>, PathInfo> einsum_path_helper(
   }
   Subscript output(out_subscript, std::move(out_set));
 
-  std::unordered_map<char, int> dim_map;
+  std::unordered_map<char, ShapeElem> dim_map;
   std::vector<Subscript> inputs;
   for (int i = 0; i < in_subscripts.size(); ++i) {
     auto& in = in_subscripts[i];
@@ -680,7 +681,7 @@ std::pair<std::vector<PathNode>, PathInfo> einsum_path_helper(
 
     // Check repeat subscripts are valid
     if (in_set.size() < in.size()) {
      std::unordered_map<char, int> local_dims;
-      std::unordered_map<char, int> local_dims;
+      std::unordered_map<char, ShapeElem> local_dims;
      for (int j = 0; j < in.size(); ++j) {
@@ -670,8 +670,7 @@ array scaled_dot_product_attention(
       supports_sdpa_full || supports_sdpa_vector;
 
   if (implementation_supports_use_case) {
-    auto out_shape =
-        std::vector<int>({q.shape(0), q.shape(1), q.shape(2), v.shape(-1)});
+    auto out_shape = Shape{q.shape(0), q.shape(1), q.shape(2), v.shape(-1)};
     return array(
         std::move(out_shape),
         final_type,
@@ -59,7 +59,7 @@ typedef std::variant<int, bool, Dtype> TemplateArg;
 
 typedef std::function<std::vector<array>(
     const std::vector<array>&,
-    const std::vector<std::vector<int>>&,
+    const std::vector<Shape>&,
     const std::vector<Dtype>&,
     std::tuple<int, int, int>,
     std::tuple<int, int, int>,
@@ -47,8 +47,8 @@ std::optional<Dtype> gguf_type_to_dtype(const uint32_t& gguf_type) {
   }
 }
 
-std::vector<int> get_shape(const gguf_tensor& tensor) {
-  std::vector<int> shape;
+Shape get_shape(const gguf_tensor& tensor) {
+  Shape shape;
   // The dimension order in GGML is the reverse of the order used in MLX.
   for (int i = tensor.ndim - 1; i >= 0; i--) {
     shape.push_back(tensor.dim[i]);
@@ -12,7 +12,7 @@ extern "C" {
 
 namespace mlx::core {
 
-std::vector<int> get_shape(const gguf_tensor& tensor);
+Shape get_shape(const gguf_tensor& tensor);
 void gguf_load_quantized(
     std::unordered_map<std::string, array>& a,
     const gguf_tensor& tensor);
@@ -109,7 +109,7 @@ void gguf_load_quantized(
 
   std::string name(tensor.name, tensor.namelen);
 
-  std::vector<int> shape = get_shape(tensor);
+  auto shape = get_shape(tensor);
   const uint64_t weights_per_block = 32;
   if (shape[shape.size() - 1] % weights_per_block != 0) {
     std::ostringstream msg;
@@ -118,7 +118,7 @@ void gguf_load_quantized(
     throw std::runtime_error(msg.str());
   }
 
-  std::vector<int> weights_shape = shape;
+  auto weights_shape = shape;
   weights_shape.back() /= (weights_per_byte * 4);
   auto w_nbytes = uint32.size() *
       std::accumulate(weights_shape.begin(),
@@ -271,7 +271,7 @@ array load(std::shared_ptr<io::Reader> in_stream, StreamOrDevice s) {
   bool col_contiguous = header[34] == 'T';
 
   // Read array shape from header
-  std::vector<int> shape;
+  Shape shape;
 
   size_t st = header.find_last_of('(') + 1;
   size_t ed = header.find_last_of(')');
@@ -219,15 +219,15 @@ std::vector<array> svd(const array& a, StreamOrDevice s /* = {} */) {
   const auto n = a.shape(-1);
   const auto rank = a.ndim();
 
-  std::vector<int> u_shape = a.shape();
+  auto u_shape = a.shape();
   u_shape[rank - 2] = m;
   u_shape[rank - 1] = m;
 
-  std::vector<int> s_shape = a.shape();
+  auto s_shape = a.shape();
   s_shape.pop_back();
   s_shape[rank - 2] = std::min(m, n);
 
-  std::vector<int> vt_shape = a.shape();
+  auto vt_shape = a.shape();
   vt_shape[rank - 2] = n;
   vt_shape[rank - 1] = n;
 
@@ -328,8 +328,8 @@ array pinv(const array& a, StreamOrDevice s /* = {} */) {
   array S = outs[1];
   array V = outs[2];
 
-  std::vector<int> starts(a.ndim(), 0);
-  std::vector<int> ends = a.shape();
+  Shape starts(a.ndim(), 0);
+  auto ends = a.shape();
   int i = a.ndim() - 2;
   int j = a.ndim() - 1;
 
@@ -479,7 +479,7 @@ array eigvalsh(
     std::string UPLO /* = "L" */,
     StreamOrDevice s /* = {} */) {
   validate_eigh(a, "[linalg::eigvalsh]");
-  std::vector<int> out_shape(a.shape().begin(), a.shape().end() - 1);
+  Shape out_shape(a.shape().begin(), a.shape().end() - 1);
   return array(
       std::move(out_shape),
       a.dtype(),
@@ -493,7 +493,7 @@ std::pair<array, array> eigh(
     StreamOrDevice s /* = {} */) {
   validate_eigh(a, "[linalg::eigh]");
   auto out = array::make_arrays(
-      {std::vector<int>(a.shape().begin(), a.shape().end() - 1), a.shape()},
+      {Shape(a.shape().begin(), a.shape().end() - 1), a.shape()},
       {a.dtype(), a.dtype()},
       std::make_shared<Eigh>(to_stream(s), UPLO, true),
       {a});

mlx/ops.cpp (42 changes)
@@ -649,7 +649,7 @@ normalize_slice(const Shape& shape, Shape& start, Shape& stop, Shape& strides) {
 
       // Clamp to bounds
       auto st = std::min(s, n - 1);
-      auto ed = std::max(-1, e);
+      auto ed = e > -1 ? e : -1;
 
       start[i] = st;
       stop[i] = ed > st ? st : ed;
@@ -659,8 +659,8 @@ normalize_slice(const Shape& shape, Shape& start, Shape& stop, Shape& strides) {
 
     } else {
       // Clamp to bounds
-      auto st = std::max(0, std::min(s, n));
-      auto ed = std::max(0, std::min(e, n));
+      auto st = std::max(static_cast<ShapeElem>(0), std::min(s, n));
+      auto ed = std::max(static_cast<ShapeElem>(0), std::min(e, n));
 
       start[i] = st;
       stop[i] = ed < st ? st : ed;
@@ -765,7 +765,7 @@ array slice_update(
 
 std::vector<array> split(
     const array& a,
-    const std::vector<int>& indices,
+    const Shape& indices,
     int axis,
     StreamOrDevice s /* = {} */) {
   auto ax = axis < 0 ? axis + a.ndim() : axis;
@@ -809,10 +809,8 @@ std::vector<array> split(
   return res;
 }
 
-std::vector<array> split(
-    const array& a,
-    const std::vector<int>& indices,
-    StreamOrDevice s /* = {} */) {
+std::vector<array>
+split(const array& a, const Shape& indices, StreamOrDevice s /* = {} */) {
   return split(a, indices, 0, s);
 }
 
@@ -834,7 +832,7 @@ split(const array& a, int num_splits, int axis, StreamOrDevice s /* = {} */) {
     throw std::invalid_argument(msg.str());
   }
   auto split_size = q_and_r.quot;
-  std::vector<int> indices(num_splits - 1);
+  Shape indices(num_splits - 1);
   for (int i = 0; i < indices.size(); ++i) {
     indices[i] = (i + 1) * split_size;
   }
@@ -1104,7 +1102,7 @@ array edge_pad(
 /** Pad an array with a constant value */
 array pad(
     const array& a,
-    const Shape& axes,
+    const std::vector<int>& axes,
     const Shape& low_pad_size,
     const Shape& high_pad_size,
     const array& pad_value /*= array(0)*/,
@@ -1904,9 +1902,11 @@ array min(
 
 array argmin(const array& a, bool keepdims, StreamOrDevice s /* = {} */) {
   int size = a.size();
-  auto result = argmin(reshape(a, {size}, s), 0, true, s);
+  auto result = argmin(flatten(a, s), 0, true, s);
   if (keepdims) {
-    result = reshape(result, std::vector<int>(a.shape().size(), 1), s);
+    std::vector<int> axes(a.ndim() - 1);
+    std::iota(axes.begin(), axes.end(), 0);
+    result = expand_dims(result, axes, s);
   } else {
     result = squeeze(result, s);
   }
@@ -1940,9 +1940,11 @@ array argmin(
 
 array argmax(const array& a, bool keepdims, StreamOrDevice s /* = {} */) {
   int size = a.size();
-  auto result = argmax(reshape(a, {size}, s), 0, true, s);
+  auto result = argmax(flatten(a, s), 0, true, s);
   if (keepdims) {
-    result = reshape(result, Shape(a.shape().size(), 1), s);
+    std::vector<int> axes(a.ndim() - 1);
+    std::iota(axes.begin(), axes.end(), 0);
+    result = expand_dims(result, axes, s);
   } else {
     result = squeeze(result, s);
   }
@@ -3238,8 +3240,8 @@ inline int dilate_size(int dim, int dil) {
 }
 
 Shape conv_out_shape(
-    const std::vector<int>& in_shape,
-    const std::vector<int>& wt_shape,
+    const Shape& in_shape,
+    const Shape& wt_shape,
     const std::vector<int>& strides,
     const std::vector<int>& pads_lo,
     const std::vector<int>& pads_hi,
@@ -4329,16 +4331,16 @@ array diagonal(
         "[diagonal] axis1 and axis2 cannot be the same axis");
   }
 
-  auto off1 = std::max(-offset, 0);
-  auto off2 = std::max(offset, 0);
+  ShapeElem off1 = std::max(-offset, 0);
+  ShapeElem off2 = std::max(offset, 0);
 
   auto diag_size = std::min(a.shape(ax1) - off1, a.shape(ax2) - off2);
-  diag_size = std::max(diag_size, 0);
+  diag_size = diag_size < 0 ? 0 : diag_size;
 
   std::vector<array> indices = {
       arange(off1, off1 + diag_size, s), arange(off2, off2 + diag_size, s)};
 
-  std::vector<int> slice_sizes = a.shape();
+  Shape slice_sizes = a.shape();
   slice_sizes[ax1] = 1;
   slice_sizes[ax2] = 1;
 

mlx/ops.h (19 changes)
@@ -189,13 +189,10 @@ array slice_update(
 std::vector<array>
 split(const array& a, int num_splits, int axis, StreamOrDevice s = {});
 std::vector<array> split(const array& a, int num_splits, StreamOrDevice s = {});
-std::vector<array> split(
-    const array& a,
-    const std::vector<int>& indices,
-    int axis,
-    StreamOrDevice s = {});
 std::vector<array>
-split(const array& a, const std::vector<int>& indices, StreamOrDevice s = {});
+split(const array& a, const Shape& indices, int axis, StreamOrDevice s = {});
+std::vector<array>
+split(const array& a, const Shape& indices, StreamOrDevice s = {});
 
 /** A vector of coordinate arrays from coordinate vectors. */
 std::vector<array> meshgrid(
@@ -253,8 +250,8 @@ array moveaxis(
 array pad(
     const array& a,
     const std::vector<int>& axes,
-    const std::vector<int>& low_pad_size,
-    const std::vector<int>& high_pad_size,
+    const Shape& low_pad_size,
+    const Shape& high_pad_size,
     const array& pad_value = array(0),
     const std::string mode = "constant",
     StreamOrDevice s = {});
@@ -1453,7 +1450,11 @@ array view(const array& a, const Dtype& dtype, StreamOrDevice s = {});
 array roll(const array& a, int shift, StreamOrDevice s = {});
 array roll(const array& a, const Shape& shift, StreamOrDevice s = {});
 array roll(const array& a, int shift, int axis, StreamOrDevice s = {});
-array roll(const array& a, int shift, const Shape& axes, StreamOrDevice s = {});
+array roll(
+    const array& a,
+    int shift,
+    const std::vector<int>& axes,
+    StreamOrDevice s = {});
 array roll(const array& a, const Shape& shift, int axis, StreamOrDevice s = {});
 array roll(
     const array& a,
@@ -817,10 +817,10 @@ std::vector<array> Concatenate::vjp(
     const std::vector<int>& argnums,
     const std::vector<array>&) {
   auto& cotan = cotangents[0];
-  std::vector<int> start(cotan.ndim(), 0);
-  std::vector<int> stop = cotan.shape();
+  Shape start(cotan.ndim(), 0);
+  Shape stop = cotan.shape();
 
-  std::vector<int> sizes;
+  Shape sizes;
   sizes.push_back(0);
   for (auto& p : primals) {
     sizes.push_back(p.shape(axis_));
@@ -956,9 +956,9 @@ array conv_weight_backward_patches(
     const std::vector<int>& padding,
     StreamOrDevice s) {
   // Resolve Padded input shapes and strides
-  std::vector<int> padding_starts(in.ndim(), 0);
-  std::vector<int> padding_ends = in.shape();
-  std::vector<int> in_padded_shape = in.shape();
+  Shape padding_starts(in.ndim(), 0);
+  auto padding_ends = in.shape();
+  auto in_padded_shape = in.shape();
 
   // padded shape
   for (int i = 1; i < in.ndim() - 1; i++) {
@@ -976,8 +976,9 @@ array conv_weight_backward_patches(
   // Pad input
   std::vector<int> padded_axes(in.ndim() - 2, 0);
   std::iota(padded_axes.begin(), padded_axes.end(), 1);
+  Shape padding_(padding.begin(), padding.end());
   auto in_padded = pad(
-      in, padded_axes, padding, padding, array(0, in.dtype()), "constant", s);
+      in, padded_axes, padding_, padding_, array(0, in.dtype()), "constant", s);
 
   // Resolve strided patches
 
@@ -1797,7 +1798,7 @@ std::vector<array> FFT::vjp(
   std::vector<int> axes(axes_.begin(), axes_.end());
   if (real_ && inverse_) {
     auto out = fft::fftn(cotangents[0], axes, stream());
-    auto start = std::vector<int>(out.ndim(), 0);
+    auto start = Shape(out.ndim(), 0);
     auto stop = in.shape();
     out = slice(out, start, stop, stream());
     auto mask_shape = out.shape();
@@ -1809,7 +1810,7 @@ std::vector<array> FFT::vjp(
     mask = concatenate({pad, mask, pad}, axes_.back(), stream());
     return {multiply(mask, out, stream())};
   } else if (real_) {
-    std::vector<int> n;
+    Shape n;
     for (auto ax : axes_) {
       n.push_back(in.shape()[ax]);
     }
@@ -1934,10 +1935,11 @@ std::pair<std::vector<array>, std::vector<int>> Gather::vmap(
   }
   if (indices_vmapped) {
     // Make a new index array for the vmapped dimension
-    auto vmap_inds = arange(0, src.shape(axes[0]), stream());
+    auto vmap_inds =
+        arange(static_cast<ShapeElem>(0), src.shape(axes[0]), stream());
     // Reshape it so it broadcasts with other index arrays
     {
-      auto shape = std::vector<int>(idx_dims, 1);
+      auto shape = Shape(idx_dims, 1);
       shape[out_ax] = vmap_inds.size();
       vmap_inds = reshape(vmap_inds, std::move(shape), stream());
     }
@@ -2628,8 +2630,8 @@ std::vector<array> Pad::vjp(
   assert(argnums.size() == 1 && argnums[0] == 0);
 
   auto& cotan = cotangents[0];
-  std::vector<int> start(cotan.ndim(), 0);
-  std::vector<int> stop = cotan.shape();
+  Shape start(cotan.ndim(), 0);
+  auto stop = cotan.shape();
 
   for (auto i : axes_) {
     start[i] = low_pad_size_[i];
@@ -3019,7 +3021,7 @@ std::vector<array> Reduce::vjp(
     const std::vector<array>& outputs) {
   auto in = primals[0];
 
-  std::vector<int> shape = in.shape();
+  auto shape = in.shape();
   for (auto ax : axes_) {
     shape[ax] = 1;
   }
@@ -3044,7 +3046,7 @@ std::vector<array> Reduce::vjp(
   if (axes_.size() > 1) {
     std::vector<int> transpose_to;
     std::vector<int> transpose_back;
-    std::vector<int> shape_flat;
+    Shape shape_flat;
     {
       // Find the transpose needed to move axes_ to the back and the shape
       // except the reduced over axes.
@@ -3422,7 +3424,7 @@ std::pair<std::vector<array>, std::vector<int>> Scatter::vmap(
   }
 
   auto vmap_inds = arange(vmap_size, inputs[1].dtype(), stream());
-  auto vmap_inds_shape = std::vector<int>(inputs[1].ndim(), 1);
+  auto vmap_inds_shape = Shape(inputs[1].ndim(), 1);
   vmap_inds_shape[0] = vmap_inds.size();
   vmap_inds = reshape(vmap_inds, std::move(vmap_inds_shape), stream());
   inputs.insert(
@@ -3607,7 +3609,7 @@ std::vector<array> Slice::vjp(
   // Transpose and reshape cotangents
   auto cotan = cotangents[0];
   if (!ind_axes.empty()) {
-    std::vector<int> cotan_shape;
+    Shape cotan_shape;
     for (auto ax : ind_axes) {
       cotan_shape.push_back(cotan.shape(ax));
     }
@@ -3626,7 +3628,7 @@ std::vector<array> Slice::vjp(
   }
 
   // Make indices broadcastable
-  std::vector<int> inds_shape(inds.size(), 1);
+  Shape inds_shape(inds.size(), 1);
   for (int i = 0; i < inds.size(); ++i) {
     inds_shape[i] = inds[i].size();
     inds[i] = reshape(inds[i], inds_shape, stream());
@@ -4184,7 +4186,7 @@ std::vector<array> BlockMaskedMM::vjp(
     // Slice mask
     mask_reshape[mask_ndim - 2] = Y;
     mask_reshape[mask_ndim - 1] = X;
-    mask = slice(mask, std::vector<int>(mask_ndim, 0), mask_reshape, stream());
+    mask = slice(mask, Shape(mask_ndim, 0), mask_reshape, stream());
 
     return mask;
   };
@@ -4202,7 +4204,7 @@ std::vector<array> BlockMaskedMM::vjp(
   }
 
   // Reshape
-  std::vector<int> r_reshape(r.shape().begin(), r.shape().end() - 2);
+  Shape r_reshape(r.shape().begin(), r.shape().end() - 2);
   r_reshape.push_back(r.shape(-2) / block_size_);
   r_reshape.push_back(block_size_);
   r_reshape.push_back(r.shape(-1) / block_size_);
@@ -4492,7 +4494,7 @@ std::pair<std::vector<array>, std::vector<int>> NumberOfElements::vmap(
   }
 
   array out = array(
-      std::vector<int>{},
+      {},
       dtype_,
       std::make_shared<NumberOfElements>(stream(), new_axes, inverted_, dtype_),
       inputs);
@@ -1088,10 +1088,7 @@ class Full : public UnaryPrimitive {
 
 class Gather : public UnaryPrimitive {
  public:
-  explicit Gather(
-      Stream stream,
-      std::vector<int> axes,
-      std::vector<int> slice_sizes)
+  explicit Gather(Stream stream, std::vector<int> axes, Shape slice_sizes)
       : UnaryPrimitive(stream),
         axes_(std::move(axes)),
         slice_sizes_(std::move(slice_sizes)) {}
@@ -1108,7 +1105,7 @@ class Gather : public UnaryPrimitive {
  private:
   void eval(const std::vector<array>& inputs, array& out);
   std::vector<int> axes_;
-  std::vector<int> slice_sizes_;
+  Shape slice_sizes_;
 };
 
 class Greater : public UnaryPrimitive {
@@ -1503,8 +1500,8 @@ class Pad : public UnaryPrimitive {
   explicit Pad(
       Stream stream,
       const std::vector<int>& axes,
-      const std::vector<int>& low_pad_size,
-      const std::vector<int>& high_pad_size)
+      const Shape& low_pad_size,
+      const Shape& high_pad_size)
       : UnaryPrimitive(stream),
         axes_(axes),
         low_pad_size_(low_pad_size),
@@ -1520,8 +1517,8 @@ class Pad : public UnaryPrimitive {
 
  private:
   std::vector<int> axes_;
-  std::vector<int> low_pad_size_;
-  std::vector<int> high_pad_size_;
+  Shape low_pad_size_;
+  Shape high_pad_size_;
 
   void eval(const std::vector<array>& inputs, array& out);
 };
@@ -1903,9 +1900,9 @@ class Slice : public UnaryPrimitive {
  public:
   explicit Slice(
       Stream stream,
-      const std::vector<int>& start_indices,
-      const std::vector<int>& end_indices,
-      const std::vector<int>& strides)
+      const Shape& start_indices,
+      const Shape& end_indices,
+      const Shape& strides)
       : UnaryPrimitive(stream),
         start_indices_(start_indices),
         end_indices_(end_indices),
@@ -1920,9 +1917,9 @@ class Slice : public UnaryPrimitive {
   bool is_equivalent(const Primitive& other) const override;
 
  private:
-  std::vector<int> start_indices_;
-  std::vector<int> end_indices_;
-  std::vector<int> strides_;
+  Shape start_indices_;
+  Shape end_indices_;
+  Shape strides_;
 
   void eval(const std::vector<array>& inputs, array& out);
 };
@@ -1931,9 +1928,9 @@ class SliceUpdate : public UnaryPrimitive {
  public:
   explicit SliceUpdate(
       Stream stream,
-      const std::vector<int>& start_indices,
-      const std::vector<int>& end_indices,
-      const std::vector<int>& strides)
+      const Shape& start_indices,
+      const Shape& end_indices,
+      const Shape& strides)
       : UnaryPrimitive(stream),
         start_indices_(start_indices),
         end_indices_(end_indices),
@@ -1948,9 +1945,9 @@ class SliceUpdate : public UnaryPrimitive {
   bool is_equivalent(const Primitive& other) const override;
 
  private:
-  std::vector<int> start_indices_;
-  std::vector<int> end_indices_;
-  std::vector<int> strides_;
+  Shape start_indices_;
+  Shape end_indices_;
+  Shape strides_;
 
   void eval(const std::vector<array>& inputs, array& out);
 };
@@ -1997,7 +1994,7 @@ class Sort : public UnaryPrimitive {
 
 class Split : public Primitive {
  public:
-  explicit Split(Stream stream, const std::vector<int>& indices, int axis)
+  explicit Split(Stream stream, const Shape& indices, int axis)
       : Primitive(stream), indices_(indices), axis_(axis) {}
 
   void eval_cpu(const std::vector<array>& inputs, std::vector<array>& outputs)
@@ -2013,7 +2010,7 @@ class Split : public Primitive {
  private:
   void eval(const std::vector<array>& inputs, std::vector<array>& outputs);
 
-  std::vector<int> indices_;
+  Shape indices_;
   int axis_;
 };
 
@@ -296,7 +296,7 @@ std::ostream& operator<<(std::ostream& os, array a) {
   return os;
 }

-std::ostream& operator<<(std::ostream& os, const Shape& v) {
+std::ostream& operator<<(std::ostream& os, const std::vector<int>& v) {
   os << "(";
   for (int i = 0; i < v.size(); ++i) {
     os << v[i] << ((i == v.size() - 1) ? "" : ",");
@@ -305,7 +305,7 @@ std::ostream& operator<<(std::ostream& os, const Shape& v) {
   return os;
 }

-std::ostream& operator<<(std::ostream& os, const Strides& v) {
+std::ostream& operator<<(std::ostream& os, const std::vector<int64_t>& v) {
   os << "(";
   for (int i = 0; i < v.size(); ++i) {
     os << v[i] << ((i == v.size() - 1) ? "" : ",");
@@ -77,8 +77,8 @@ std::ostream& operator<<(std::ostream& os, const Stream& s);
 std::ostream& operator<<(std::ostream& os, const Dtype& d);
 std::ostream& operator<<(std::ostream& os, const Dtype::Kind& k);
 std::ostream& operator<<(std::ostream& os, array a);
-std::ostream& operator<<(std::ostream& os, const Shape& v);
-std::ostream& operator<<(std::ostream& os, const Strides& v);
+std::ostream& operator<<(std::ostream& os, const std::vector<int>& v);
+std::ostream& operator<<(std::ostream& os, const std::vector<int64_t>& v);
 inline std::ostream& operator<<(std::ostream& os, const complex64_t& v) {
   return os << v.real() << (v.imag() >= 0 ? "+" : "") << v.imag() << "j";
 }

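For context, both overloads above print a parenthesized, comma-separated list of elements. A minimal standalone sketch of the same formatting, using local stand-ins for the mlx aliases rather than the real headers (the closing ")" is assumed, since it falls outside the excerpted hunks):

#include <cstdint>
#include <iostream>
#include <vector>

// Illustration only: local stand-ins for the aliases defined in mlx/array.h.
using ShapeElem = int32_t;
using Shape = std::vector<ShapeElem>;

// Mirrors the loop shown in the diff above: prints "(d0,d1,...,dn)".
std::ostream& operator<<(std::ostream& os, const Shape& v) {
  os << "(";
  for (size_t i = 0; i < v.size(); ++i) {
    os << v[i] << ((i == v.size() - 1) ? "" : ",");
  }
  return os << ")";
}

int main() {
  std::cout << Shape{8, 12} << "\n";  // prints: (8,12)
  return 0;
}
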
@@ -889,13 +889,13 @@ void init_array(nb::module_& m) {
       .def(
           "reshape",
           [](const mx::array& a, nb::args shape_, mx::StreamOrDevice s) {
-            std::vector<int> shape;
+            mx::Shape shape;
             if (!nb::isinstance<int>(shape_[0])) {
-              shape = nb::cast<std::vector<int>>(shape_[0]);
+              shape = nb::cast<mx::Shape>(shape_[0]);
             } else {
-              shape = nb::cast<std::vector<int>>(shape_);
+              shape = nb::cast<mx::Shape>(shape_);
             }
-            return mx::reshape(a, shape, s);
+            return mx::reshape(a, std::move(shape), s);
           },
           "shape"_a,
           "stream"_a = nb::none(),
@@ -1182,14 +1182,14 @@ void init_array(nb::module_& m) {
       .def(
           "split",
           [](const mx::array& a,
-             const std::variant<int, std::vector<int>>& indices_or_sections,
+             const std::variant<int, mx::Shape>& indices_or_sections,
              int axis,
              mx::StreamOrDevice s) {
             if (auto pv = std::get_if<int>(&indices_or_sections); pv) {
               return mx::split(a, *pv, axis, s);
             } else {
               return mx::split(
-                  a, std::get<std::vector<int>>(indices_or_sections), axis, s);
+                  a, std::get<mx::Shape>(indices_or_sections), axis, s);
             }
           },
           "indices_or_sections"_a,

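The binding above accepts either a section count or explicit split indices through a std::variant and branches with std::get_if. A small self-contained sketch of that dispatch pattern, illustrative only and printing in place of the real mx::split call:

#include <cstdint>
#include <iostream>
#include <variant>
#include <vector>

using Shape = std::vector<int32_t>;  // stand-in for mx::Shape

// An int means "split into N equal sections"; a Shape means "split at these
// indices", matching the two branches of the binding above.
void describe_split(const std::variant<int, Shape>& indices_or_sections) {
  if (auto pv = std::get_if<int>(&indices_or_sections); pv) {
    std::cout << "split into " << *pv << " equal sections\n";
  } else {
    const auto& indices = std::get<Shape>(indices_or_sections);
    std::cout << "split at " << indices.size() << " indices\n";
  }
}

int main() {
  describe_split(3);            // holds the int alternative
  describe_split(Shape{2, 8});  // holds the Shape alternative
  return 0;
}
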
@@ -181,7 +181,7 @@ void init_fast(nb::module_& parent_module) {
         return nb::cpp_function(
             [kernel = std::move(kernel)](
                 const std::vector<ScalarOrArray>& inputs_,
-                const std::vector<std::vector<int>>& output_shapes,
+                const std::vector<mx::Shape>& output_shapes,
                 const std::vector<mx::Dtype>& output_dtypes,
                 std::tuple<int, int, int> grid,
                 std::tuple<int, int, int> threadgroup,

@@ -79,7 +79,7 @@ void init_fft(nb::module_& parent_module) {
   m.def(
       "fft2",
       [](const mx::array& a,
-         const std::optional<std::vector<int>>& n,
+         const std::optional<mx::Shape>& n,
          const std::optional<std::vector<int>>& axes,
          mx::StreamOrDevice s) {
         if (axes.has_value() && n.has_value()) {
@@ -115,7 +115,7 @@ void init_fft(nb::module_& parent_module) {
   m.def(
       "ifft2",
       [](const mx::array& a,
-         const std::optional<std::vector<int>>& n,
+         const std::optional<mx::Shape>& n,
          const std::optional<std::vector<int>>& axes,
          mx::StreamOrDevice s) {
         if (axes.has_value() && n.has_value()) {
@@ -151,7 +151,7 @@ void init_fft(nb::module_& parent_module) {
   m.def(
       "fftn",
       [](const mx::array& a,
-         const std::optional<std::vector<int>>& n,
+         const std::optional<mx::Shape>& n,
          const std::optional<std::vector<int>>& axes,
          mx::StreamOrDevice s) {
         if (axes.has_value() && n.has_value()) {
@@ -188,7 +188,7 @@ void init_fft(nb::module_& parent_module) {
   m.def(
       "ifftn",
       [](const mx::array& a,
-         const std::optional<std::vector<int>>& n,
+         const std::optional<mx::Shape>& n,
          const std::optional<std::vector<int>>& axes,
          mx::StreamOrDevice s) {
         if (axes.has_value() && n.has_value()) {
@@ -294,7 +294,7 @@ void init_fft(nb::module_& parent_module) {
   m.def(
       "rfft2",
       [](const mx::array& a,
-         const std::optional<std::vector<int>>& n,
+         const std::optional<mx::Shape>& n,
          const std::optional<std::vector<int>>& axes,
          mx::StreamOrDevice s) {
         if (axes.has_value() && n.has_value()) {
@@ -336,7 +336,7 @@ void init_fft(nb::module_& parent_module) {
   m.def(
       "irfft2",
       [](const mx::array& a,
-         const std::optional<std::vector<int>>& n,
+         const std::optional<mx::Shape>& n,
          const std::optional<std::vector<int>>& axes,
          mx::StreamOrDevice s) {
         if (axes.has_value() && n.has_value()) {
@@ -378,7 +378,7 @@ void init_fft(nb::module_& parent_module) {
   m.def(
       "rfftn",
       [](const mx::array& a,
-         const std::optional<std::vector<int>>& n,
+         const std::optional<mx::Shape>& n,
          const std::optional<std::vector<int>>& axes,
          mx::StreamOrDevice s) {
         if (axes.has_value() && n.has_value()) {
@@ -420,7 +420,7 @@ void init_fft(nb::module_& parent_module) {
   m.def(
       "irfftn",
       [](const mx::array& a,
-         const std::optional<std::vector<int>>& n,
+         const std::optional<mx::Shape>& n,
          const std::optional<std::vector<int>>& axes,
          mx::StreamOrDevice s) {
         if (axes.has_value() && n.has_value()) {

@@ -25,9 +25,9 @@ int get_slice_int(nb::object obj, int default_val) {
 }

 void get_slice_params(
-    int& starts,
-    int& ends,
-    int& strides,
+    mx::ShapeElem& starts,
+    mx::ShapeElem& ends,
+    mx::ShapeElem& strides,
     const nb::slice& in_slice,
     int axis_size) {
   // Following numpy's convention
@@ -68,9 +68,9 @@ mx::array mlx_get_item_slice(const mx::array& src, const nb::slice& in_slice) {
     return src;
   }

-  std::vector<int> starts(src.ndim(), 0);
-  std::vector<int> ends = src.shape();
-  std::vector<int> strides(src.ndim(), 1);
+  mx::Shape starts(src.ndim(), 0);
+  auto ends = src.shape();
+  mx::Shape strides(src.ndim(), 1);

   // Check and update slice params
   get_slice_params(starts[0], ends[0], strides[0], in_slice, ends[0]);
@@ -119,7 +119,7 @@ mx::array mlx_gather_nd(
     auto& idx = indices[i];

     if (nb::isinstance<nb::slice>(idx)) {
-      int start, end, stride;
+      mx::ShapeElem start, end, stride;
       get_slice_params(
           start, end, stride, nb::cast<nb::slice>(idx), src.shape(i));

@@ -168,7 +168,7 @@ mx::array mlx_gather_nd(
   // Do the gather
   std::vector<int> axes(indices.size());
   std::iota(axes.begin(), axes.end(), 0);
-  std::vector<int> slice_sizes = src.shape();
+  auto slice_sizes = src.shape();
   std::fill(slice_sizes.begin(), slice_sizes.begin() + indices.size(), 1);
   src = gather(src, gather_indices, axes, slice_sizes);

@@ -179,9 +179,7 @@ mx::array mlx_gather_nd(
   return mx::squeeze(src, axes);
 }

-auto mlx_expand_ellipsis(
-    const std::vector<int>& shape,
-    const nb::tuple& entries) {
+auto mlx_expand_ellipsis(const mx::Shape& shape, const nb::tuple& entries) {
   std::vector<nb::object> indices;

   // Go over all entries and note the position of ellipsis
@@ -230,7 +228,8 @@ auto mlx_expand_ellipsis(
     for (int axis = non_none_indices_before;
          axis < shape.size() - non_none_indices_after;
          axis++) {
-      indices.push_back(nb::slice(0, shape[axis], 1));
+      indices.push_back(
+          nb::slice(mx::ShapeElem{0}, shape[axis], mx::ShapeElem{1}));
       non_none_indices++;
     }
   }
@@ -371,9 +370,9 @@ mx::array mlx_get_item_nd(mx::array src, const nb::tuple& entries) {

   // Slice handling
   {
-    std::vector<int> starts(src.ndim(), 0);
-    std::vector<int> ends = src.shape();
-    std::vector<int> strides(src.ndim(), 1);
+    mx::Shape starts(src.ndim(), 0);
+    auto ends = src.shape();
+    mx::Shape strides(src.ndim(), 1);
     int axis = 0;
     for (auto& idx : remaining_indices) {
       if (!idx.is_none()) {
@@ -461,8 +460,7 @@ mlx_scatter_args_int(
   int s = 0;
   for (; s < update.ndim() && update.shape(s) == 1; s++)
     ;
-  auto up_shape =
-      std::vector<int>(update.shape().begin() + s, update.shape().end());
+  auto up_shape = mx::Shape(update.shape().begin() + s, update.shape().end());
   auto shape = src.shape();
   shape[0] = 1;

@@ -521,9 +519,9 @@ mlx_scatter_args_slice(
         {}, broadcast_to(squeeze_leading_singletons(update), src.shape()), {}};
   }

-  int start = 0;
-  int end = src.shape(0);
-  int stride = 1;
+  mx::ShapeElem start = 0;
+  auto end = src.shape(0);
+  mx::ShapeElem stride = 1;

   // Check and update slice params
   get_slice_params(start, end, stride, in_slice, end);
@@ -645,7 +643,7 @@ mlx_scatter_args_nd(
   for (int i = 0; i < indices.size(); ++i) {
     auto& pyidx = indices[i];
     if (nb::isinstance<nb::slice>(pyidx)) {
-      int start, end, stride;
+      mx::ShapeElem start, end, stride;
       auto axis_size = src.shape(ax++);
       get_slice_params(
           start, end, stride, nb::cast<nb::slice>(pyidx), axis_size);
@@ -654,7 +652,7 @@ mlx_scatter_args_nd(
       start = (start < 0) ? start + axis_size : start;
       end = (end < 0) ? end + axis_size : end;

-      std::vector<int> idx_shape(idx_ndim, 1);
+      mx::Shape idx_shape(idx_ndim, 1);

       // If it's a simple slice, we only need to add the start index
       if (array_num >= num_arrays && num_strided_slices <= 0 && stride == 1) {

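The scatter hunks above normalize negative slice bounds against the axis size before building index shapes. A minimal sketch of that normalization step, using a local stand-in for mx::ShapeElem (the clamping done inside get_slice_params itself is outside this excerpt and not reproduced):

#include <cstdint>
#include <iostream>

using ShapeElem = int32_t;  // stand-in for mx::ShapeElem

int main() {
  ShapeElem start = -3, end = -1;
  ShapeElem axis_size = 10;
  // Same adjustment as the context lines in the diff: negative positions
  // count from the end of the axis.
  start = (start < 0) ? start + axis_size : start;
  end = (end < 0) ? end + axis_size : end;
  std::cout << start << " " << end << "\n";  // prints: 7 9
  return 0;
}
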
@@ -1571,15 +1571,14 @@ void init_ops(nb::module_& m) {
       )pbdoc");
   m.def(
       "full",
-      [](const std::variant<int, std::vector<int>>& shape,
+      [](const std::variant<int, mx::Shape>& shape,
          const ScalarOrArray& vals,
          std::optional<mx::Dtype> dtype,
          mx::StreamOrDevice s) {
         if (auto pv = std::get_if<int>(&shape); pv) {
           return mx::full({*pv}, to_array(vals, dtype), s);
         } else {
-          return mx::full(
-              std::get<std::vector<int>>(shape), to_array(vals, dtype), s);
+          return mx::full(std::get<mx::Shape>(shape), to_array(vals, dtype), s);
         }
       },
       "shape"_a,
@@ -1606,14 +1605,14 @@ void init_ops(nb::module_& m) {
       )pbdoc");
   m.def(
       "zeros",
-      [](const std::variant<int, std::vector<int>>& shape,
+      [](const std::variant<int, mx::Shape>& shape,
          std::optional<mx::Dtype> dtype,
          mx::StreamOrDevice s) {
         auto t = dtype.value_or(mx::float32);
         if (auto pv = std::get_if<int>(&shape); pv) {
           return mx::zeros({*pv}, t, s);
         } else {
-          return mx::zeros(std::get<std::vector<int>>(shape), t, s);
+          return mx::zeros(std::get<mx::Shape>(shape), t, s);
         }
       },
       "shape"_a,
@@ -1652,14 +1651,14 @@ void init_ops(nb::module_& m) {
       )pbdoc");
   m.def(
       "ones",
-      [](const std::variant<int, std::vector<int>>& shape,
+      [](const std::variant<int, mx::Shape>& shape,
          std::optional<mx::Dtype> dtype,
          mx::StreamOrDevice s) {
         auto t = dtype.value_or(mx::float32);
         if (auto pv = std::get_if<int>(&shape); pv) {
           return mx::ones({*pv}, t, s);
         } else {
-          return mx::ones(std::get<std::vector<int>>(shape), t, s);
+          return mx::ones(std::get<mx::Shape>(shape), t, s);
         }
       },
       "shape"_a,
@@ -2481,14 +2480,14 @@ void init_ops(nb::module_& m) {
   m.def(
       "split",
       [](const mx::array& a,
-         const std::variant<int, std::vector<int>>& indices_or_sections,
+         const std::variant<int, mx::Shape>& indices_or_sections,
          int axis,
          mx::StreamOrDevice s) {
         if (auto pv = std::get_if<int>(&indices_or_sections); pv) {
           return mx::split(a, *pv, axis, s);
         } else {
           return mx::split(
-              a, std::get<std::vector<int>>(indices_or_sections), axis, s);
+              a, std::get<mx::Shape>(indices_or_sections), axis, s);
         }
       },
       nb::arg(),
@@ -2744,9 +2743,7 @@ void init_ops(nb::module_& m) {
       )pbdoc");
   m.def(
       "broadcast_to",
-      [](const ScalarOrArray& a,
-         const std::vector<int>& shape,
-         mx::StreamOrDevice s) {
+      [](const ScalarOrArray& a, const mx::Shape& shape, mx::StreamOrDevice s) {
         return mx::broadcast_to(to_array(a), shape, s);
       },
       nb::arg(),
@@ -4895,23 +4892,15 @@ void init_ops(nb::module_& m) {
   m.def(
       "roll",
       [](const mx::array& a,
-         const IntOrVec& shift,
+         const std::variant<int, mx::Shape>& shift,
          const IntOrVec& axis,
          mx::StreamOrDevice s) {
         return std::visit(
             [&](auto sh, auto ax) -> mx::array {
-              using T = decltype(ax);
-              using V = decltype(sh);
-
-              if constexpr (std::is_same_v<V, std::monostate>) {
-                throw std::invalid_argument(
-                    "[roll] Expected two arguments but only one was given.");
+              if constexpr (std::is_same_v<decltype(ax), std::monostate>) {
+                return mx::roll(a, sh, s);
               } else {
-                if constexpr (std::is_same_v<T, std::monostate>) {
-                  return mx::roll(a, sh, s);
-                } else {
-                  return mx::roll(a, sh, ax, s);
-                }
-              }
+                return mx::roll(a, sh, ax, s);
               }
             },
             shift,

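The reworked "roll" lambda above collapses two nested if constexpr checks into a single test on the axis argument. A standalone sketch of that visit-plus-if-constexpr pattern follows; the exact definition of IntOrVec is not part of this excerpt, so the sketch assumes a variant that uses std::monostate for an omitted argument and prints instead of calling mx::roll:

#include <cstdint>
#include <iostream>
#include <type_traits>
#include <variant>
#include <vector>

using Shape = std::vector<int32_t>;  // stand-in for mx::Shape
// Assumed stand-in for IntOrVec: std::monostate marks an omitted argument.
using IntOrShape = std::variant<std::monostate, int, Shape>;

void roll_dispatch(const std::variant<int, Shape>& shift, const IntOrShape& axis) {
  std::visit(
      [&](auto sh, auto ax) {
        (void)sh;  // the real binding forwards sh on to mx::roll
        if constexpr (std::is_same_v<decltype(ax), std::monostate>) {
          std::cout << "roll the flattened array\n";
        } else {
          std::cout << "roll along the given axis or axes\n";
        }
      },
      shift,
      axis);
}

int main() {
  roll_dispatch(2, std::monostate{});  // no axis supplied
  roll_dispatch(2, 1);                 // axis supplied
  return 0;
}
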
@@ -108,7 +108,7 @@ void init_random(nb::module_& parent_module) {
       "uniform",
       [](const ScalarOrArray& low,
          const ScalarOrArray& high,
-         const std::vector<int>& shape,
+         const mx::Shape& shape,
          std::optional<mx::Dtype> type,
          const std::optional<mx::array>& key_,
          mx::StreamOrDevice s) {
@@ -123,7 +123,7 @@ void init_random(nb::module_& parent_module) {
       },
       "low"_a = 0,
       "high"_a = 1,
-      "shape"_a = std::vector<int>{},
+      "shape"_a = mx::Shape{},
       "dtype"_a.none() = mx::float32,
       "key"_a = nb::none(),
       "stream"_a = nb::none(),
@@ -150,7 +150,7 @@ void init_random(nb::module_& parent_module) {
       )pbdoc");
   m.def(
       "normal",
-      [](const std::vector<int>& shape,
+      [](const mx::Shape& shape,
          std::optional<mx::Dtype> type,
          float loc,
          float scale,
@@ -160,7 +160,7 @@ void init_random(nb::module_& parent_module) {
         return mx::random::normal(
             shape, type.value_or(mx::float32), loc, scale, key, s);
       },
-      "shape"_a = std::vector<int>{},
+      "shape"_a = mx::Shape{},
       "dtype"_a.none() = mx::float32,
       "loc"_a = 0.0,
       "scale"_a = 1.0,
@@ -185,7 +185,7 @@ void init_random(nb::module_& parent_module) {
       "multivariate_normal",
       [](const mx::array& mean,
          const mx::array& cov,
-         const std::vector<int>& shape,
+         const mx::Shape& shape,
          std::optional<mx::Dtype> type,
          const std::optional<mx::array>& key_,
          mx::StreamOrDevice s) {
@@ -195,7 +195,7 @@ void init_random(nb::module_& parent_module) {
       },
       "mean"_a,
       "cov"_a,
-      "shape"_a = std::vector<int>{},
+      "shape"_a = mx::Shape{},
       "dtype"_a.none() = mx::float32,
       "key"_a = nb::none(),
       "stream"_a = nb::none(),
@@ -227,7 +227,7 @@ void init_random(nb::module_& parent_module) {
       "randint",
       [](const ScalarOrArray& low,
          const ScalarOrArray& high,
-         const std::vector<int>& shape,
+         const mx::Shape& shape,
          std::optional<mx::Dtype> type,
          const std::optional<mx::array>& key_,
          mx::StreamOrDevice s) {
@@ -242,7 +242,7 @@ void init_random(nb::module_& parent_module) {
       },
       "low"_a,
       "high"_a,
-      "shape"_a = std::vector<int>{},
+      "shape"_a = mx::Shape{},
       "dtype"_a.none() = mx::int32,
       "key"_a = nb::none(),
       "stream"_a = nb::none(),
@@ -268,7 +268,7 @@ void init_random(nb::module_& parent_module) {
   m.def(
       "bernoulli",
       [](const ScalarOrArray& p_,
-         const std::optional<std::vector<int>> shape,
+         const std::optional<mx::Shape> shape,
          const std::optional<mx::array>& key_,
          mx::StreamOrDevice s) {
         auto key = key_ ? key_.value() : default_key().next();
@@ -306,7 +306,7 @@ void init_random(nb::module_& parent_module) {
       "truncated_normal",
       [](const ScalarOrArray& lower_,
          const ScalarOrArray& upper_,
-         const std::optional<std::vector<int>> shape_,
+         const std::optional<mx::Shape> shape_,
          std::optional<mx::Dtype> type,
          const std::optional<mx::array>& key_,
          mx::StreamOrDevice s) {
@@ -350,14 +350,14 @@ void init_random(nb::module_& parent_module) {
       )pbdoc");
   m.def(
       "gumbel",
-      [](const std::vector<int>& shape,
+      [](const mx::Shape& shape,
          std::optional<mx::Dtype> type,
          const std::optional<mx::array>& key_,
          mx::StreamOrDevice s) {
         auto key = key_ ? key_.value() : default_key().next();
         return mx::random::gumbel(shape, type.value_or(mx::float32), key, s);
       },
-      "shape"_a = std::vector<int>{},
+      "shape"_a = mx::Shape{},
       "dtype"_a.none() = mx::float32,
       "key"_a = nb::none(),
       "stream"_a = nb::none(),
@@ -384,7 +384,7 @@ void init_random(nb::module_& parent_module) {
       "categorical",
       [](const mx::array& logits,
          int axis,
-         const std::optional<std::vector<int>> shape,
+         const std::optional<mx::Shape> shape,
          const std::optional<int> num_samples,
          const std::optional<mx::array>& key_,
          mx::StreamOrDevice s) {
@@ -434,7 +434,7 @@ void init_random(nb::module_& parent_module) {
       )pbdoc");
   m.def(
       "laplace",
-      [](const std::vector<int>& shape,
+      [](const mx::Shape& shape,
          std::optional<mx::Dtype> type,
          float loc,
          float scale,
@@ -444,7 +444,7 @@ void init_random(nb::module_& parent_module) {
         return mx::random::laplace(
             shape, type.value_or(mx::float32), loc, scale, key, s);
       },
-      "shape"_a = std::vector<int>{},
+      "shape"_a = mx::Shape{},
       "dtype"_a.none() = mx::float32,
       "loc"_a = 0.0,
       "scale"_a = 1.0,
@@ -479,7 +479,7 @@ void init_random(nb::module_& parent_module) {
           return mx::random::permutation(std::get<mx::array>(x), axis, key, s);
         }
       },
-      "shape"_a = std::vector<int>{},
+      "x"_a,
       "axis"_a = 0,
       "key"_a = nb::none(),
       "stream"_a = nb::none(),

@@ -1,4 +1,5 @@
 // Copyright © 2023-2024 Apple Inc.
+
 #include <nanobind/nanobind.h>
 #include <nanobind/stl/optional.h>
 #include <nanobind/stl/pair.h>

@@ -395,7 +395,7 @@ TEST_CASE("test split") {
   CHECK_EQ(out[1].shape(), Shape{8, 4});
   CHECK_EQ(out[2].shape(), Shape{8, 4});

-  out = split(x, std::vector<int>{});
+  out = split(x, Shape{});
   CHECK_EQ(out.size(), 1);
   CHECK_EQ(out[0].shape(), x.shape());

@@ -405,25 +405,25 @@ TEST_CASE("test split") {
   CHECK_EQ(out[1].shape(), Shape{4, 12});
   CHECK_EQ(out[2].shape(), Shape{1, 12});

-  out = split(x, std::vector<int>{20});
+  out = split(x, Shape{20});
   CHECK_EQ(out.size(), 2);
   CHECK_EQ(out[0].shape(), Shape{8, 12});
   CHECK_EQ(out[1].shape(), Shape{0, 12});

   // Negative indices
-  out = split(x, std::vector<int>{-5});
+  out = split(x, Shape{-5});
   CHECK_EQ(out[0].shape(), Shape{3, 12});
   CHECK_EQ(out[1].shape(), Shape{5, 12});

   // Different axis
-  out = split(x, std::vector<int>{2, 8}, 1);
+  out = split(x, {2, 8}, 1);
   CHECK_EQ(out[0].shape(), Shape{8, 2});
   CHECK_EQ(out[1].shape(), Shape{8, 6});
   CHECK_EQ(out[2].shape(), Shape{8, 4});

   // Out of order indices
   x = arange(5);
-  out = split(x, std::vector<int>{2, 1, 2});
+  out = split(x, {2, 1, 2});
   CHECK(array_equal(out[0], array({0, 1})).item<bool>());
   CHECK(array_equal(out[1], array({})).item<bool>());
   CHECK(array_equal(out[2], array({1})).item<bool>());

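As a quick check of the index-based splits asserted above: boundaries {2, 8} on an axis of length 12 give segments of widths 2, 6, and 4, which is exactly what the shape checks expect. A tiny sketch of that arithmetic:

#include <iostream>
#include <vector>

int main() {
  std::vector<int> indices{2, 8};  // split boundaries on an axis of length 12
  int axis_size = 12;
  int prev = 0;
  for (int idx : indices) {
    std::cout << (idx - prev) << " ";  // prints 2, then 6
    prev = idx;
  }
  std::cout << (axis_size - prev) << "\n";  // prints 4
  return 0;
}
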
@@ -611,8 +611,8 @@ TEST_CASE("test categorical") {
   CHECK_THROWS(categorical(logits, -3));

   // Invalid requested shapes
-  CHECK_THROWS(categorical(logits, 1, std::vector<int>{1}));
-  CHECK_THROWS(categorical(logits, 1, std::vector<int>{11}));
+  CHECK_THROWS(categorical(logits, 1, Shape{1}));
+  CHECK_THROWS(categorical(logits, 1, Shape{11}));
   CHECK_THROWS(categorical(logits, 1, {10, 1}));

   CHECK_EQ(categorical(logits, -1).shape(), Shape{10});

@@ -335,8 +335,7 @@ TEST_CASE("test vmap gather") {
   auto fun = [](std::vector<array> inputs) {
     auto src = inputs[0];
     auto indices = inputs[1];
-    std::vector<int> slice_sizes = {1, 2, 2};
-    auto out = squeeze(gather(src, indices, 0, slice_sizes), 2);
+    auto out = squeeze(gather(src, indices, 0, {1, 2, 2}), 2);
     return std::vector<array>{out};
   };
   auto x = zeros({2, 2, 2, 2});
@@ -351,8 +350,7 @@ TEST_CASE("test vmap gather") {
   auto fun = [](std::vector<array> inputs) {
     auto src = inputs[0];
     auto indices = inputs[1];
-    std::vector<int> slice_sizes = {1, 2, 2};
-    auto out = squeeze(gather(src, indices, 0, slice_sizes), 1);
+    auto out = squeeze(gather(src, indices, 0, {1, 2, 2}), 1);
     return std::vector<array>{out};
   };
   auto x = zeros({2, 2, 2, 2});
@@ -365,8 +363,7 @@ TEST_CASE("test vmap gather") {
   auto fun = [](std::vector<array> inputs) {
     auto src = inputs[0];
     auto indices = inputs[1];
-    std::vector<int> slice_sizes = {1, 2, 2, 2};
-    auto out = squeeze(gather(src, indices, 0, slice_sizes), 1);
+    auto out = squeeze(gather(src, indices, 0, {1, 2, 2, 2}), 1);
     return std::vector<array>{out};
   };
   auto x = zeros({2, 2, 2, 2});
@@ -380,8 +377,7 @@ TEST_CASE("test vmap gather") {
   auto fun = [](std::vector<array> inputs) {
     auto src = inputs[0];
     auto indices = std::vector<array>(inputs.begin() + 1, inputs.end());
-    std::vector<int> slice_sizes = {1, 1, 2, 2};
-    auto out = squeeze(gather(src, indices, {0, 1}, slice_sizes), {1, 2});
+    auto out = squeeze(gather(src, indices, {0, 1}, {1, 1, 2, 2}), {1, 2});
     return std::vector<array>{out};
   };
   auto x = zeros({2, 2, 2, 2});