Mirror of https://github.com/ml-explore/mlx.git, synced 2025-12-16 01:49:05 +08:00.
@@ -130,7 +130,7 @@ std::string build_lib_name(

 bool compiled_check_contiguity(
     const std::vector<array>& inputs,
-    const std::vector<int>& shape) {
+    const Shape& shape) {
   bool contiguous = true;
   bool all_contig = true;
   bool all_row_contig = true;

@@ -56,7 +56,7 @@ inline bool is_scalar(const array& x) {
 // Check if we can use a contiguous operation given inputs and the output shape
 bool compiled_check_contiguity(
     const std::vector<array>& inputs,
-    const std::vector<int>& shape);
+    const Shape& shape);

 // Allocate space for the outputs possibly with input donation
 void compiled_allocate_outputs(

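Note on the `Shape` type these hunks migrate to: the alias is assumed (consistent with how the diff uses it and with MLX conventions) to be a vector of a fixed-width element type defined in the core array header, roughly:

// Sketched for context only -- the real aliases live elsewhere in the repo;
// the exact widths of ShapeElem/Strides here are assumptions.
#include <cstdint>
#include <vector>
using ShapeElem = int32_t;
using Shape = std::vector<ShapeElem>; // e.g. {N, H, W, C}
using Strides = std::vector<int64_t>; // strides can exceed 32 bits

Taking `const Shape&` instead of `const std::vector<int>&` keeps a single element type for every shape in the codebase, so a future change to `ShapeElem` only touches the alias.
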
@@ -726,7 +726,7 @@ void explicit_gemm_conv_1D_cpu(
   auto conv_dtype = float32;

   // Pad input
-  std::vector<int> padded_shape = {N, iH + 2 * padding[0], C};
+  Shape padded_shape = {N, iH + 2 * padding[0], C};
   array in_padded(padded_shape, conv_dtype, nullptr, {});

   // Fill with zeros

@@ -765,7 +765,7 @@ void explicit_gemm_conv_1D_cpu(
       in_padded, strided_strides, flags, in_strided_view.size(), 0);

   // Materialize strided view
-  std::vector<int> strided_reshape = {N * oH, wH * C};
+  Shape strided_reshape = {N * oH, wH * C};
   array in_strided(strided_reshape, in_strided_view.dtype(), nullptr, {});
   copy(in_strided_view, in_strided, CopyType::General);

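For readers new to this path: explicit GEMM convolution materializes an im2col view, and these hunks only retype its shape vectors. A minimal sketch of the bookkeeping the 1-D path relies on, under the assumption of unit dilation (the stride layout here is illustrative, not copied from the source):

#include <cstdint>
#include <vector>
using Shape = std::vector<int32_t>;
using Strides = std::vector<int64_t>;

// A (N, oH, wH, C) window view over the padded (N, pH, C) input is
// flattened to (N * oH, wH * C) and multiplied against (O, wH * C) weights.
Strides im2col_strides_1d(int pH, int C, int stride) {
  return {
      static_cast<int64_t>(pH) * C, // next batch element
      static_cast<int64_t>(stride) * C, // next output position
      C, // next filter tap
      1}; // next channel
}

The 2-D hunks below are the same construction with an extra spatial axis: the view flattens to (N * oH * oW, wH * wW * C).
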
@@ -843,8 +843,7 @@ void explicit_gemm_conv_2D_cpu(
   auto conv_dtype = out.dtype();

   // Pad input
-  std::vector<int> padded_shape = {
-      N, iH + 2 * padding[0], iW + 2 * padding[1], C};
+  Shape padded_shape = {N, iH + 2 * padding[0], iW + 2 * padding[1], C};
   array in_padded(padded_shape, conv_dtype, nullptr, {});

   // Fill with zeros

@@ -881,7 +880,7 @@ void explicit_gemm_conv_2D_cpu(
       in_padded, strided_strides, flags, in_strided_view.size(), 0);

   // Materialize strided view
-  std::vector<int> strided_reshape = {N * oH * oW, wH * wW * C};
+  Shape strided_reshape = {N * oH * oW, wH * wW * C};
   array in_strided(strided_reshape, in_strided_view.dtype(), nullptr, {});
   copy(in_strided_view, in_strided, CopyType::General);

@@ -934,19 +933,19 @@ void explicit_gemm_conv_ND_cpu(
     const std::vector<int>& wt_dilation,
     const bool flip) {
   const int N = in.shape(0); // Batch size, should be the same as out.shape(0)
-  const auto iDim = std::vector<int>(
-      in.shape().begin() + 1, in.shape().end() - 1); // Input spatial dim
-  const auto oDim = std::vector<int>(
+  const auto iDim =
+      Shape(in.shape().begin() + 1, in.shape().end() - 1); // Input spatial dim
+  const auto oDim = Shape(
       out.shape().begin() + 1, out.shape().end() - 1); // Output spatial dim
   const int O = wt.shape(0); // Out channels
   const int C = wt.shape(-1); // In channels
-  const auto wDim = std::vector<int>(
-      wt.shape().begin() + 1, wt.shape().end() - 1); // Weight spatial dim
+  const auto wDim =
+      Shape(wt.shape().begin() + 1, wt.shape().end() - 1); // Weight spatial dim

   auto conv_dtype = float32;

   // Pad input
-  std::vector<int> padded_shape(in.shape().size());
+  Shape padded_shape(in.shape().size());
   padded_shape.front() = N;
   for (size_t i = 0; i < iDim.size(); i++) {
     padded_shape[i + 1] = iDim[i] + 2 * padding[i];

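The `Shape(begin, end)` constructions above slice the batch and channel dimensions off a channels-last shape. A tiny usage example of the same idiom:

#include <cstdint>
#include <vector>
using Shape = std::vector<int32_t>;

int main() {
  // NDHWC input: batch 2, spatial 8x16x16, 3 channels.
  Shape in_shape = {2, 8, 16, 16, 3};
  Shape spatial(in_shape.begin() + 1, in_shape.end() - 1); // {8, 16, 16}
  return spatial.size() == 3 ? 0 : 1;
}
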
@@ -14,10 +14,10 @@ namespace mlx::core {

 namespace {

-template <typename T, typename IdxT = int32_t>
+template <typename T>
 struct StridedIterator {
   using iterator_category = std::random_access_iterator_tag;
-  using difference_type = IdxT;
+  using difference_type = int32_t;
   using value_type = T;
   using reference = value_type&;
   using pointer = value_type*;

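Dropping the never-overridden `IdxT` parameter pins `difference_type` to `int32_t`. For context, a stripped-down sketch of what such an iterator does; this is not the full MLX implementation, which also provides the rest of the random-access operations:

#include <cstdint>

// Advancing by n moves the pointer n * stride elements, letting standard
// algorithms walk a non-contiguous axis as if it were a flat range.
template <typename T>
struct StridedIteratorSketch {
  using difference_type = int32_t;
  T* ptr_;
  int64_t stride_;
  T& operator*() const { return *ptr_; }
  StridedIteratorSketch& operator+=(difference_type n) {
    ptr_ += n * stride_;
    return *this;
  }
};
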
@@ -107,7 +107,7 @@ struct ContiguousIterator {
       : shape_(a.shape()), strides_(a.strides()) {
     if (!shape_.empty()) {
       std::tie(shape_, strides_) = collapse_contiguous_dims(shape_, strides_);
-      pos_ = std::vector<int>(shape_.size(), 0);
+      pos_ = Shape(shape_.size(), 0);
     }
   }

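`pos_` is a per-dimension counter over the collapsed shape, so retyping it as `Shape` keeps it in step with `shape_`'s element type. The step such an iterator performs is the usual odometer carry, sketched here under that assumption:

#include <cstdint>
#include <vector>
using Shape = std::vector<int32_t>;

// Bump the innermost coordinate; on wrap-around, reset it and carry
// into the next-outer dimension.
void odometer_step(Shape& pos, const Shape& shape) {
  for (int i = static_cast<int>(pos.size()) - 1; i >= 0; --i) {
    if (++pos[i] < shape[i]) {
      break;
    }
    pos[i] = 0; // wrapped; carry outward
  }
}
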
@@ -34,7 +34,7 @@ void explicit_gemm_conv_ND_gpu(
   int implicit_K = wt.size() / conv_params.O;
   int implicit_N = conv_params.O;
   // Prepare unfolding array
-  std::vector<int> unfolded_shape{implicit_M, implicit_K};
+  Shape unfolded_shape{implicit_M, implicit_K};
   array in_unfolded(unfolded_shape, in.dtype(), nullptr, {});

   in_unfolded.set_data(allocator::malloc_or_wait(in_unfolded.nbytes()));

@@ -113,7 +113,7 @@ void explicit_gemm_conv_group_ND_gpu(
   }

   // Prepare unfolding array
-  std::vector<int> unfolded_shape{implicit_M, implicit_K * groups};
+  Shape unfolded_shape{implicit_M, implicit_K * groups};
   array in_unfolded(unfolded_shape, in.dtype(), nullptr, {});
   in_unfolded.set_data(allocator::malloc_or_wait(in_unfolded.nbytes()));

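Both unfold paths size their scratch array from the implicit GEMM dimensions. How those are assumed to be derived, consistent with `implicit_K = wt.size() / conv_params.O` visible above (the helper below is illustrative, not from the source):

#include <cstdint>
#include <vector>
using Shape = std::vector<int32_t>;

// Implicit GEMM sizing: M rows (one per output location), K columns
// (one per filter tap and input channel), N = O output channels.
// The unfolded input is then an (M x K) matrix, or (M x K*groups) for
// the grouped path, where each group reads its own K-slice.
int64_t implicit_M(int batch, const Shape& out_spatial) {
  int64_t m = batch;
  for (auto s : out_spatial) {
    m *= s;
  }
  return m;
}
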
@@ -192,12 +192,12 @@ void conv_1D_gpu(
     bool flip) {
   // Make conv params
   MLXConvParams<1> conv_params{
-      /* const int N = */ in.shape(0),
-      /* const int C = */ in.shape(2),
-      /* const int O = */ wt.shape(0),
-      /* const int iS[NDIM] = */ {in.shape(1)},
-      /* const int wS[NDIM] = */ {wt.shape(1)},
-      /* const int oS[NDIM] = */ {out.shape(1)},
+      /* const int N = */ static_cast<int>(in.shape(0)),
+      /* const int C = */ static_cast<int>(in.shape(2)),
+      /* const int O = */ static_cast<int>(wt.shape(0)),
+      /* const int iS[NDIM] = */ {static_cast<int>(in.shape(1))},
+      /* const int wS[NDIM] = */ {static_cast<int>(wt.shape(1))},
+      /* const int oS[NDIM] = */ {static_cast<int>(out.shape(1))},
       /* const int str[NDIM] = */ {wt_strides[0]},
       /* const int pad[NDIM] = */ {padding[0]},
       /* const int kdil[NDIM] = */ {wt_dilation[0]},

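A note on the `static_cast<int>` pattern repeated in the conv hunks here and below (the rationale is an assumption, not stated in the diff): `MLXConvParams` stores plain `int` fields, while `array::shape(i)` now returns the `Shape` element type, so explicit casts keep the aggregate brace-initialization free of narrowing errors even if that element type is ever made wider than `int`. A reduced illustration with a hypothetical stand-in struct:

#include <cstdint>
using ShapeElem = int32_t; // assumed Shape element type

struct Params1D { // hypothetical stand-in for MLXConvParams<1>
  int N;
  int iS[1];
};

Params1D make_params(ShapeElem n, ShapeElem i0) {
  // Brace-init rejects narrowing; the casts make the conversion explicit
  // and future-proof against a wider ShapeElem.
  return Params1D{static_cast<int>(n), {static_cast<int>(i0)}};
}
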
@@ -541,7 +541,7 @@ void winograd_conv_2D_gpu(
     array out,
     const MLXConvParams<2>& conv_params,
     std::vector<array>& copies_w) {
-  std::vector<int> padded_shape = {
+  Shape padded_shape = {
       conv_params.N,
       conv_params.iS[0] + 2 * conv_params.pad[0],
       conv_params.iS[1] + 2 * conv_params.pad[1],

@@ -550,7 +550,7 @@ void winograd_conv_2D_gpu(
   padded_shape[1] = 6 * ((padded_shape[1] - 2 + 5) / 6) + 2;
   padded_shape[2] = 6 * ((padded_shape[2] - 2 + 5) / 6) + 2;

-  array in_padded(padded_shape, in.dtype(), nullptr, {});
+  array in_padded(std::move(padded_shape), in.dtype(), nullptr, {});

   // Fill with zeros
   array zero_arr = array(0, in.dtype());

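Two things happen above: the shape vector is moved into the array (it is dead afterwards, so the copy can be elided now that a `Shape` is passed by value), and the padded interior is rounded up for Winograd tiling. The expression `6 * ((x - 2 + 5) / 6) + 2` fits F(6x6, 3x3): 8x8 input tiles overlap by 2 pixels, so the interior `x - 2` is rounded up to a multiple of the 6-pixel output tile and the 2-pixel halo is added back. A quick self-contained check:

#include <cstdio>

int main() {
  // Maps x to the smallest 6k + 2 >= x (k >= 1): 8 -> 8, 9 -> 14,
  // 33 -> 38, 38 -> 38.
  for (int x : {8, 9, 33, 38}) {
    int padded = 6 * ((x - 2 + 5) / 6) + 2;
    std::printf("%2d -> %2d\n", x, padded);
  }
  return 0;
}
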
@@ -575,12 +575,16 @@ void winograd_conv_2D_gpu(
   copies_w.push_back(in_padded);

   MLXConvParams<2> conv_params_updated{
-      /* const int N = */ in_padded.shape(0),
-      /* const int C = */ in_padded.shape(3),
-      /* const int O = */ wt.shape(0),
-      /* const int iS[NDIM] = */ {in_padded.shape(1), in_padded.shape(2)},
-      /* const int wS[NDIM] = */ {wt.shape(1), wt.shape(2)},
-      /* const int oS[NDIM] = */ {out.shape(1), out.shape(2)},
+      /* const int N = */ static_cast<int>(in_padded.shape(0)),
+      /* const int C = */ static_cast<int>(in_padded.shape(3)),
+      /* const int O = */ static_cast<int>(wt.shape(0)),
+      /* const int iS[NDIM] = */
+      {static_cast<int>(in_padded.shape(1)),
+       static_cast<int>(in_padded.shape(2))},
+      /* const int wS[NDIM] = */
+      {static_cast<int>(wt.shape(1)), static_cast<int>(wt.shape(2))},
+      /* const int oS[NDIM] = */
+      {static_cast<int>(out.shape(1)), static_cast<int>(out.shape(2))},
       /* const int str[NDIM] = */ {1, 1},
       /* const int pad[NDIM] = */ {0, 0},
       /* const int kdil[NDIM] = */ {1, 1},

@@ -607,8 +611,8 @@ void winograd_conv_2D_gpu(
   int N_tiles = N_tiles_n * N_tiles_h * N_tiles_w;

   // Do filter transform
-  std::vector<int> filt_wg_shape = {8 * 8, conv_params.C, conv_params.O};
-  array filt_wg(filt_wg_shape, wt.dtype(), nullptr, {});
+  Shape filt_wg_shape = {8 * 8, conv_params.C, conv_params.O};
+  array filt_wg(std::move(filt_wg_shape), wt.dtype(), nullptr, {});
   filt_wg.set_data(allocator::malloc_or_wait(filt_wg.nbytes()));
   copies_w.push_back(filt_wg);
   {

@@ -634,8 +638,8 @@ void winograd_conv_2D_gpu(
   }

   // Do input transform
-  std::vector<int> inp_wg_shape = {8 * 8, N_tiles, conv_params.C};
-  array inp_wg(inp_wg_shape, in.dtype(), nullptr, {});
+  Shape inp_wg_shape = {8 * 8, N_tiles, conv_params.C};
+  array inp_wg(std::move(inp_wg_shape), in.dtype(), nullptr, {});
   inp_wg.set_data(allocator::malloc_or_wait(inp_wg.nbytes()));
   copies_w.push_back(inp_wg);
   {

@@ -661,8 +665,8 @@ void winograd_conv_2D_gpu(
   }

   // Do batched gemm
-  std::vector<int> out_wg_shape = {8 * 8, N_tiles, conv_params.O};
-  array out_wg(out_wg_shape, in.dtype(), nullptr, {});
+  Shape out_wg_shape = {8 * 8, N_tiles, conv_params.O};
+  array out_wg(std::move(out_wg_shape), in.dtype(), nullptr, {});
  out_wg.set_data(allocator::malloc_or_wait(out_wg.nbytes()));
   copies_w.push_back(out_wg);
   {

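The three scratch arrays retyped above carry the Winograd pipeline: `filt_wg` holds 8x8-transformed filters `{64, C, O}`, `inp_wg` holds transformed input tiles `{64, N_tiles, C}`, and `out_wg` receives the batched product before the inverse transform. A reference sketch of that batched GEMM (an illustrative row-major loop, not the Metal kernel):

#include <cstdint>

// 64 independent (N_tiles x C) * (C x O) products, one per Winograd
// frequency component.
void winograd_batched_gemm(
    const float* inp_wg, // {64, N_tiles, C}
    const float* filt_wg, // {64, C, O}
    float* out_wg, // {64, N_tiles, O}
    int N_tiles,
    int C,
    int O) {
  for (int b = 0; b < 64; ++b) {
    const float* A = inp_wg + int64_t(b) * N_tiles * C;
    const float* B = filt_wg + int64_t(b) * C * O;
    float* Y = out_wg + int64_t(b) * N_tiles * O;
    for (int m = 0; m < N_tiles; ++m) {
      for (int n = 0; n < O; ++n) {
        float acc = 0.0f;
        for (int k = 0; k < C; ++k) {
          acc += A[m * C + k] * B[k * O + n];
        }
        Y[m * O + n] = acc;
      }
    }
  }
}
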
@@ -723,12 +727,15 @@ void conv_2D_gpu(
     std::vector<array>& copies) {
   // Make conv params
   MLXConvParams<2> conv_params{
-      /* const int N = */ in.shape(0),
-      /* const int C = */ in.shape(3),
-      /* const int O = */ wt.shape(0),
-      /* const int iS[NDIM] = */ {in.shape(1), in.shape(2)},
-      /* const int wS[NDIM] = */ {wt.shape(1), wt.shape(2)},
-      /* const int oS[NDIM] = */ {out.shape(1), out.shape(2)},
+      /* const int N = */ static_cast<int>(in.shape(0)),
+      /* const int C = */ static_cast<int>(in.shape(3)),
+      /* const int O = */ static_cast<int>(wt.shape(0)),
+      /* const int iS[NDIM] = */
+      {static_cast<int>(in.shape(1)), static_cast<int>(in.shape(2))},
+      /* const int wS[NDIM] = */
+      {static_cast<int>(wt.shape(1)), static_cast<int>(wt.shape(2))},
+      /* const int oS[NDIM] = */
+      {static_cast<int>(out.shape(1)), static_cast<int>(out.shape(2))},
       /* const int str[NDIM] = */ {wt_strides[0], wt_strides[1]},
       /* const int pad[NDIM] = */ {padding[0], padding[1]},
       /* const int kdil[NDIM] = */ {wt_dilation[0], wt_dilation[1]},

@@ -800,12 +807,21 @@ void conv_3D_gpu(
     std::vector<array>& copies) {
   // Make conv params
   MLXConvParams<3> conv_params{
-      /* const int N = */ in.shape(0),
-      /* const int C = */ in.shape(4),
-      /* const int O = */ wt.shape(0),
-      /* const int iS[NDIM] = */ {in.shape(1), in.shape(2), in.shape(3)},
-      /* const int wS[NDIM] = */ {wt.shape(1), wt.shape(2), wt.shape(3)},
-      /* const int oS[NDIM] = */ {out.shape(1), out.shape(2), out.shape(3)},
+      /* const int N = */ static_cast<int>(in.shape(0)),
+      /* const int C = */ static_cast<int>(in.shape(4)),
+      /* const int O = */ static_cast<int>(wt.shape(0)),
+      /* const int iS[NDIM] = */
+      {static_cast<int>(in.shape(1)),
+       static_cast<int>(in.shape(2)),
+       static_cast<int>(in.shape(3))},
+      /* const int wS[NDIM] = */
+      {static_cast<int>(wt.shape(1)),
+       static_cast<int>(wt.shape(2)),
+       static_cast<int>(wt.shape(3))},
+      /* const int oS[NDIM] = */
+      {static_cast<int>(out.shape(1)),
+       static_cast<int>(out.shape(2)),
+       static_cast<int>(out.shape(3))},
       /* const int str[NDIM] = */ {wt_strides[0], wt_strides[1], wt_strides[2]},
       /* const int pad[NDIM] = */ {padding[0], padding[1], padding[2]},
       /* const int kdil[NDIM] = */

@@ -635,7 +635,7 @@ void strided_reduce_longcolumn(
   }

   // Prepare the temporary accumulator
-  std::vector<int> intermediate_shape;
+  Shape intermediate_shape;
   intermediate_shape.reserve(out.ndim() + 1);
   intermediate_shape.push_back(outer_blocks);
   intermediate_shape.insert(

@@ -806,7 +806,7 @@ void strided_reduce_2pass(
   auto [in_type, out_type] = remap_reduce_types(in, op_name);

   // Prepare the temporary accumulator
-  std::vector<int> intermediate_shape;
+  Shape intermediate_shape;
   intermediate_shape.reserve(out.ndim() + 1);
   intermediate_shape.push_back(32);
   intermediate_shape.insert(

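Both reduction hunks build the same accumulator layout: a block dimension (`outer_blocks`, or a fixed 32) prepended to the output shape, which a second pass then reduces away. The construction as a standalone sketch:

#include <cstdint>
#include <vector>
using Shape = std::vector<int32_t>;

// Accumulator shape {blocks, out_shape...}: partial results land in the
// leading dimension, and a follow-up reduction collapses it.
Shape intermediate_shape(const Shape& out_shape, int32_t blocks) {
  Shape s;
  s.reserve(out_shape.size() + 1);
  s.push_back(blocks);
  s.insert(s.end(), out_shape.begin(), out_shape.end());
  return s;
}
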
@@ -63,8 +63,8 @@ void pad_gpu(
     const array& in,
     const array& val,
     array& out,
-    std::vector<int> axes,
-    std::vector<int> low_pad_size,
+    const std::vector<int>& axes,
+    const Shape& low_pad_size,
     const Stream& s) {
   // Fill output with val
   fill_gpu(val, out, s);

@@ -23,8 +23,8 @@ void pad_gpu(
     const array& in,
     const array& val,
     array& out,
-    std::vector<int> axes,
-    std::vector<int> low_pad_size,
+    const std::vector<int>& axes,
+    const Shape& low_pad_size,
     const Stream& s);

 } // namespace mlx::core

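Per the two hunks above, `pad_gpu` fills the output with `val` and then copies `in` into the interior region offset by `low_pad_size` along each padded axis; taking the arguments by const reference avoids copying them on every call. The offset arithmetic that copy needs, as a hypothetical helper under the alias assumptions sketched earlier:

#include <cstdint>
#include <vector>
using Shape = std::vector<int32_t>;
using Strides = std::vector<int64_t>;

// Element offset of the unpadded block inside the padded output
// (hypothetical helper; the copy itself is a strided copy).
int64_t pad_offset(
    const std::vector<int>& axes,
    const Shape& low_pad_size,
    const Strides& out_strides) {
  int64_t offset = 0;
  for (size_t i = 0; i < axes.size(); ++i) {
    offset += int64_t(low_pad_size[i]) * out_strides[axes[i]];
  }
  return offset;
}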