More shape type (#1705)

* more shape type

* fix

Awni Hannun
2024-12-19 08:08:20 -08:00
committed by GitHub
parent f17536af9c
commit e03f0372b1
38 changed files with 260 additions and 258 deletions
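The change running through every file below replaces raw std::vector<int> shape containers with the library's Shape alias, adding explicit static_cast<int> narrowing where shape entries are copied into int-typed parameter structs. For orientation only, a minimal sketch of such an alias; the names and element width here are assumptions, not taken from this commit:

    #include <cstdint>
    #include <vector>

    // Hypothetical sketch: the element width is assumed, and the
    // static_cast<int>(...) calls added below suggest it need not be plain int.
    using ShapeElem = int32_t;
    using Shape = std::vector<ShapeElem>;

    int main() {
      Shape s = {8, 32, 32, 3};  // N, H, W, C
      return s.size() == 4 ? 0 : 1;
    }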

View File

@@ -130,7 +130,7 @@ std::string build_lib_name(
bool compiled_check_contiguity(
const std::vector<array>& inputs,
-const std::vector<int>& shape) {
+const Shape& shape) {
bool contiguous = true;
bool all_contig = true;
bool all_row_contig = true;

View File

@@ -56,7 +56,7 @@ inline bool is_scalar(const array& x) {
// Check if we can use a contiguous operation given inputs and the output shape
bool compiled_check_contiguity(
const std::vector<array>& inputs,
-const std::vector<int>& shape);
+const Shape& shape);
// Allocate space for the outputs possibly with input donation
void compiled_allocate_outputs(

View File

@@ -726,7 +726,7 @@ void explicit_gemm_conv_1D_cpu(
auto conv_dtype = float32;
// Pad input
-std::vector<int> padded_shape = {N, iH + 2 * padding[0], C};
+Shape padded_shape = {N, iH + 2 * padding[0], C};
array in_padded(padded_shape, conv_dtype, nullptr, {});
// Fill with zeros
@@ -765,7 +765,7 @@ void explicit_gemm_conv_1D_cpu(
in_padded, strided_strides, flags, in_strided_view.size(), 0);
// Materialize strided view
-std::vector<int> strided_reshape = {N * oH, wH * C};
+Shape strided_reshape = {N * oH, wH * C};
array in_strided(strided_reshape, in_strided_view.dtype(), nullptr, {});
copy(in_strided_view, in_strided, CopyType::General);
@@ -843,8 +843,7 @@ void explicit_gemm_conv_2D_cpu(
auto conv_dtype = out.dtype();
// Pad input
-std::vector<int> padded_shape = {
-N, iH + 2 * padding[0], iW + 2 * padding[1], C};
+Shape padded_shape = {N, iH + 2 * padding[0], iW + 2 * padding[1], C};
array in_padded(padded_shape, conv_dtype, nullptr, {});
// Fill with zeros
@@ -881,7 +880,7 @@ void explicit_gemm_conv_2D_cpu(
in_padded, strided_strides, flags, in_strided_view.size(), 0);
// Materialize strided view
-std::vector<int> strided_reshape = {N * oH * oW, wH * wW * C};
+Shape strided_reshape = {N * oH * oW, wH * wW * C};
array in_strided(strided_reshape, in_strided_view.dtype(), nullptr, {});
copy(in_strided_view, in_strided, CopyType::General);
@@ -934,19 +933,19 @@ void explicit_gemm_conv_ND_cpu(
const std::vector<int>& wt_dilation,
const bool flip) {
const int N = in.shape(0); // Batch size, should be the same as out.shape(0)
-const auto iDim = std::vector<int>(
-in.shape().begin() + 1, in.shape().end() - 1); // Input spatial dim
-const auto oDim = std::vector<int>(
+const auto iDim =
+Shape(in.shape().begin() + 1, in.shape().end() - 1); // Input spatial dim
+const auto oDim = Shape(
out.shape().begin() + 1, out.shape().end() - 1); // Output spatial dim
const int O = wt.shape(0); // Out channels
const int C = wt.shape(-1); // In channels
-const auto wDim = std::vector<int>(
-wt.shape().begin() + 1, wt.shape().end() - 1); // Weight spatial dim
+const auto wDim =
+Shape(wt.shape().begin() + 1, wt.shape().end() - 1); // Weight spatial dim
auto conv_dtype = float32;
// Pad input
-std::vector<int> padded_shape(in.shape().size());
+Shape padded_shape(in.shape().size());
padded_shape.front() = N;
for (size_t i = 0; i < iDim.size(); i++) {
padded_shape[i + 1] = iDim[i] + 2 * padding[i];
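In the ND hunk above, the spatial dimensions are pulled out of the full NHWC-style shape by constructing a Shape from an iterator sub-range. A small standalone illustration of that pattern, assuming the vector-based Shape alias sketched earlier:

    #include <cstdint>
    #include <vector>

    using Shape = std::vector<int32_t>;  // assumed alias, as sketched above

    int main() {
      Shape full = {8, 32, 32, 3};                      // N, H, W, C
      // Drop the leading batch dim and trailing channel dim, mirroring the
      // iDim/oDim/wDim construction in the hunk above.
      Shape spatial(full.begin() + 1, full.end() - 1);  // {32, 32}
      return spatial.size() == 2 ? 0 : 1;
    }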

View File

@@ -14,10 +14,10 @@ namespace mlx::core {
namespace {
-template <typename T, typename IdxT = int32_t>
+template <typename T>
struct StridedIterator {
using iterator_category = std::random_access_iterator_tag;
-using difference_type = IdxT;
+using difference_type = int32_t;
using value_type = T;
using reference = value_type&;
using pointer = value_type*;

View File

@@ -107,7 +107,7 @@ struct ContiguousIterator {
: shape_(a.shape()), strides_(a.strides()) {
if (!shape_.empty()) {
std::tie(shape_, strides_) = collapse_contiguous_dims(shape_, strides_);
-pos_ = std::vector<int>(shape_.size(), 0);
+pos_ = Shape(shape_.size(), 0);
}
}

View File

@@ -34,7 +34,7 @@ void explicit_gemm_conv_ND_gpu(
int implicit_K = wt.size() / conv_params.O;
int implicit_N = conv_params.O;
// Prepare unfolding array
-std::vector<int> unfolded_shape{implicit_M, implicit_K};
+Shape unfolded_shape{implicit_M, implicit_K};
array in_unfolded(unfolded_shape, in.dtype(), nullptr, {});
in_unfolded.set_data(allocator::malloc_or_wait(in_unfolded.nbytes()));
@@ -113,7 +113,7 @@ void explicit_gemm_conv_group_ND_gpu(
}
// Prepare unfolding array
-std::vector<int> unfolded_shape{implicit_M, implicit_K * groups};
+Shape unfolded_shape{implicit_M, implicit_K * groups};
array in_unfolded(unfolded_shape, in.dtype(), nullptr, {});
in_unfolded.set_data(allocator::malloc_or_wait(in_unfolded.nbytes()));
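The two hunks above size the unfolded (im2col) buffer for the explicit-GEMM convolution: implicit_M counts output positions and implicit_K weight taps, i.e. the M and K dimensions of the resulting GEMM. A hedged worked example with made-up dimensions:

    #include <cstdint>
    #include <vector>

    using Shape = std::vector<int32_t>;  // assumed alias

    int main() {
      // Hypothetical 2D conv: input N x iH x iW x C, weights O x wH x wW x C,
      // output N x oH x oW x O.
      int N = 2, oH = 16, oW = 16, C = 3, wH = 3, wW = 3, O = 8;
      int implicit_M = N * oH * oW;  // GEMM rows: one per output position
      int implicit_K = wH * wW * C;  // GEMM cols: one per weight tap
      int implicit_N = O;            // output channels
      Shape unfolded_shape{implicit_M, implicit_K};
      return (unfolded_shape[0] == 512 && implicit_N == 8) ? 0 : 1;
    }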
@@ -192,12 +192,12 @@ void conv_1D_gpu(
bool flip) {
// Make conv params
MLXConvParams<1> conv_params{
-/* const int N = */ in.shape(0),
-/* const int C = */ in.shape(2),
-/* const int O = */ wt.shape(0),
-/* const int iS[NDIM] = */ {in.shape(1)},
-/* const int wS[NDIM] = */ {wt.shape(1)},
-/* const int oS[NDIM] = */ {out.shape(1)},
+/* const int N = */ static_cast<int>(in.shape(0)),
+/* const int C = */ static_cast<int>(in.shape(2)),
+/* const int O = */ static_cast<int>(wt.shape(0)),
+/* const int iS[NDIM] = */ {static_cast<int>(in.shape(1))},
+/* const int wS[NDIM] = */ {static_cast<int>(wt.shape(1))},
+/* const int oS[NDIM] = */ {static_cast<int>(out.shape(1))},
/* const int str[NDIM] = */ {wt_strides[0]},
/* const int pad[NDIM] = */ {padding[0]},
/* const int kdil[NDIM] = */ {wt_dilation[0]},
@@ -541,7 +541,7 @@ void winograd_conv_2D_gpu(
array out,
const MLXConvParams<2>& conv_params,
std::vector<array>& copies_w) {
-std::vector<int> padded_shape = {
+Shape padded_shape = {
conv_params.N,
conv_params.iS[0] + 2 * conv_params.pad[0],
conv_params.iS[1] + 2 * conv_params.pad[1],
@@ -550,7 +550,7 @@ void winograd_conv_2D_gpu(
padded_shape[1] = 6 * ((padded_shape[1] - 2 + 5) / 6) + 2;
padded_shape[2] = 6 * ((padded_shape[2] - 2 + 5) / 6) + 2;
-array in_padded(padded_shape, in.dtype(), nullptr, {});
+array in_padded(std::move(padded_shape), in.dtype(), nullptr, {});
// Fill with zeros
array zero_arr = array(0, in.dtype());
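For context, the round-up applied to padded_shape[1] and padded_shape[2] above makes (size - 2) a multiple of 6, which lines up with the 8 * 8 input tiles (6 * 6 outputs plus 2 of overlap) used by the transforms below. A small standalone check of that arithmetic, with a made-up helper name:

    #include <cassert>

    // Reproduces the round-up from the hunk above: pad a spatial size so
    // that (size - 2) is a multiple of 6.
    int round_up_to_tiles(int s) { return 6 * ((s - 2 + 5) / 6) + 2; }

    int main() {
      assert(round_up_to_tiles(33) == 38);  // 31 rounds up to 36, plus 2
      assert(round_up_to_tiles(38) == 38);  // already aligned
      return 0;
    }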
@@ -575,12 +575,16 @@ void winograd_conv_2D_gpu(
copies_w.push_back(in_padded);
MLXConvParams<2> conv_params_updated{
-/* const int N = */ in_padded.shape(0),
-/* const int C = */ in_padded.shape(3),
-/* const int O = */ wt.shape(0),
-/* const int iS[NDIM] = */ {in_padded.shape(1), in_padded.shape(2)},
-/* const int wS[NDIM] = */ {wt.shape(1), wt.shape(2)},
-/* const int oS[NDIM] = */ {out.shape(1), out.shape(2)},
+/* const int N = */ static_cast<int>(in_padded.shape(0)),
+/* const int C = */ static_cast<int>(in_padded.shape(3)),
+/* const int O = */ static_cast<int>(wt.shape(0)),
+/* const int iS[NDIM] = */
+{static_cast<int>(in_padded.shape(1)),
+static_cast<int>(in_padded.shape(2))},
+/* const int wS[NDIM] = */
+{static_cast<int>(wt.shape(1)), static_cast<int>(wt.shape(2))},
+/* const int oS[NDIM] = */
+{static_cast<int>(out.shape(1)), static_cast<int>(out.shape(2))},
/* const int str[NDIM] = */ {1, 1},
/* const int pad[NDIM] = */ {0, 0},
/* const int kdil[NDIM] = */ {1, 1},
@@ -607,8 +611,8 @@ void winograd_conv_2D_gpu(
int N_tiles = N_tiles_n * N_tiles_h * N_tiles_w;
// Do filter transform
-std::vector<int> filt_wg_shape = {8 * 8, conv_params.C, conv_params.O};
-array filt_wg(filt_wg_shape, wt.dtype(), nullptr, {});
+Shape filt_wg_shape = {8 * 8, conv_params.C, conv_params.O};
+array filt_wg(std::move(filt_wg_shape), wt.dtype(), nullptr, {});
filt_wg.set_data(allocator::malloc_or_wait(filt_wg.nbytes()));
copies_w.push_back(filt_wg);
{
@@ -634,8 +638,8 @@ void winograd_conv_2D_gpu(
}
// Do input transform
-std::vector<int> inp_wg_shape = {8 * 8, N_tiles, conv_params.C};
-array inp_wg(inp_wg_shape, in.dtype(), nullptr, {});
+Shape inp_wg_shape = {8 * 8, N_tiles, conv_params.C};
+array inp_wg(std::move(inp_wg_shape), in.dtype(), nullptr, {});
inp_wg.set_data(allocator::malloc_or_wait(inp_wg.nbytes()));
copies_w.push_back(inp_wg);
{
@@ -661,8 +665,8 @@ void winograd_conv_2D_gpu(
}
// Do batched gemm
-std::vector<int> out_wg_shape = {8 * 8, N_tiles, conv_params.O};
-array out_wg(out_wg_shape, in.dtype(), nullptr, {});
+Shape out_wg_shape = {8 * 8, N_tiles, conv_params.O};
+array out_wg(std::move(out_wg_shape), in.dtype(), nullptr, {});
out_wg.set_data(allocator::malloc_or_wait(out_wg.nbytes()));
copies_w.push_back(out_wg);
{
@@ -723,12 +727,15 @@ void conv_2D_gpu(
std::vector<array>& copies) {
// Make conv params
MLXConvParams<2> conv_params{
-/* const int N = */ in.shape(0),
-/* const int C = */ in.shape(3),
-/* const int O = */ wt.shape(0),
-/* const int iS[NDIM] = */ {in.shape(1), in.shape(2)},
-/* const int wS[NDIM] = */ {wt.shape(1), wt.shape(2)},
-/* const int oS[NDIM] = */ {out.shape(1), out.shape(2)},
+/* const int N = */ static_cast<int>(in.shape(0)),
+/* const int C = */ static_cast<int>(in.shape(3)),
+/* const int O = */ static_cast<int>(wt.shape(0)),
+/* const int iS[NDIM] = */
+{static_cast<int>(in.shape(1)), static_cast<int>(in.shape(2))},
+/* const int wS[NDIM] = */
+{static_cast<int>(wt.shape(1)), static_cast<int>(wt.shape(2))},
+/* const int oS[NDIM] = */
+{static_cast<int>(out.shape(1)), static_cast<int>(out.shape(2))},
/* const int str[NDIM] = */ {wt_strides[0], wt_strides[1]},
/* const int pad[NDIM] = */ {padding[0], padding[1]},
/* const int kdil[NDIM] = */ {wt_dilation[0], wt_dilation[1]},
@@ -800,12 +807,21 @@ void conv_3D_gpu(
std::vector<array>& copies) {
// Make conv params
MLXConvParams<3> conv_params{
-/* const int N = */ in.shape(0),
-/* const int C = */ in.shape(4),
-/* const int O = */ wt.shape(0),
-/* const int iS[NDIM] = */ {in.shape(1), in.shape(2), in.shape(3)},
-/* const int wS[NDIM] = */ {wt.shape(1), wt.shape(2), wt.shape(3)},
-/* const int oS[NDIM] = */ {out.shape(1), out.shape(2), out.shape(3)},
+/* const int N = */ static_cast<int>(in.shape(0)),
+/* const int C = */ static_cast<int>(in.shape(4)),
+/* const int O = */ static_cast<int>(wt.shape(0)),
+/* const int iS[NDIM] = */
+{static_cast<int>(in.shape(1)),
+static_cast<int>(in.shape(2)),
+static_cast<int>(in.shape(3))},
+/* const int wS[NDIM] = */
+{static_cast<int>(wt.shape(1)),
+static_cast<int>(wt.shape(2)),
+static_cast<int>(wt.shape(3))},
+/* const int oS[NDIM] = */
+{static_cast<int>(out.shape(1)),
+static_cast<int>(out.shape(2)),
+static_cast<int>(out.shape(3))},
/* const int str[NDIM] = */ {wt_strides[0], wt_strides[1], wt_strides[2]},
/* const int pad[NDIM] = */ {padding[0], padding[1], padding[2]},
/* const int kdil[NDIM] = */
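Throughout the conv_*_gpu changes above, MLXConvParams keeps plain int fields while array::shape(i) presumably yields Shape's element type, so each entry is narrowed with an explicit static_cast<int> to keep the brace initialization warning-free. A minimal sketch of the pattern with stand-in types; Params1D and make_params are made up for illustration:

    #include <cstdint>
    #include <vector>

    using ShapeElem = int32_t;  // assumed; the real element type may be wider
    using Shape = std::vector<ShapeElem>;

    struct Params1D {  // stand-in for an int-typed params struct
      int N, C, O;
    };

    // Narrow each shape entry explicitly so the aggregate initializer stays
    // warning-free even if ShapeElem is wider than int.
    Params1D make_params(const Shape& in, const Shape& wt) {
      return Params1D{
          static_cast<int>(in[0]),   // batch
          static_cast<int>(in[2]),   // channels
          static_cast<int>(wt[0])};  // out channels
    }

    int main() {
      Shape in = {8, 128, 16};
      Shape wt = {32, 5, 16};
      return make_params(in, wt).O == 32 ? 0 : 1;
    }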

View File

@@ -635,7 +635,7 @@ void strided_reduce_longcolumn(
}
// Prepare the temporary accumulator
-std::vector<int> intermediate_shape;
+Shape intermediate_shape;
intermediate_shape.reserve(out.ndim() + 1);
intermediate_shape.push_back(outer_blocks);
intermediate_shape.insert(
@@ -806,7 +806,7 @@ void strided_reduce_2pass(
auto [in_type, out_type] = remap_reduce_types(in, op_name);
// Prepare the temporary accumulator
-std::vector<int> intermediate_shape;
+Shape intermediate_shape;
intermediate_shape.reserve(out.ndim() + 1);
intermediate_shape.push_back(32);
intermediate_shape.insert(
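Both reduce hunks above build the temporary accumulator shape by prepending a block count to the output shape. A standalone sketch of that pattern, again assuming the vector-based Shape alias:

    #include <cstdint>
    #include <vector>

    using Shape = std::vector<int32_t>;  // assumed alias

    int main() {
      Shape out_shape = {16, 128};
      int outer_blocks = 4;
      // Build {outer_blocks, out...} the way the kernels above do:
      // reserve, push the new leading dim, then append the output shape.
      Shape intermediate_shape;
      intermediate_shape.reserve(out_shape.size() + 1);
      intermediate_shape.push_back(outer_blocks);
      intermediate_shape.insert(
          intermediate_shape.end(), out_shape.begin(), out_shape.end());
      return intermediate_shape.size() == 3 ? 0 : 1;
    }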

View File

@@ -63,8 +63,8 @@ void pad_gpu(
const array& in,
const array& val,
array& out,
-std::vector<int> axes,
-std::vector<int> low_pad_size,
+const std::vector<int>& axes,
+const Shape& low_pad_size,
const Stream& s) {
// Fill output with val
fill_gpu(val, out, s);

View File

@@ -23,8 +23,8 @@ void pad_gpu(
const array& in,
const array& val,
array& out,
-std::vector<int> axes,
-std::vector<int> low_pad_size,
+const std::vector<int>& axes,
+const Shape& low_pad_size,
const Stream& s);
} // namespace mlx::core