More shape type (#1705)

* more shape type

* fix
Awni Hannun
2024-12-19 08:08:20 -08:00 (committed by GitHub)
parent f17536af9c
commit e03f0372b1

38 changed files with 260 additions and 258 deletions
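The recurring change in the hunks below is mechanical: bare std::vector<int> shape types become the Shape alias. A minimal sketch of what the alias is assumed to look like (the ShapeElem name and the 32-bit element width are my reading of the MLX source, so treat this as an assumption rather than the authoritative definition):

#include <cstdint>
#include <vector>

// Assumed alias: a named shape type reads as intent rather than as a bare
// vector, and the element width can be changed in one place if needed.
using ShapeElem = int32_t;
using Shape = std::vector<ShapeElem>;

Because the alias is still a std::vector, the brace-initialized shapes, fill constructors, and iterator-range constructors in the hunks below compile unchanged.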

@@ -130,7 +130,7 @@ std::string build_lib_name(
 bool compiled_check_contiguity(
     const std::vector<array>& inputs,
-    const std::vector<int>& shape) {
+    const Shape& shape) {
   bool contiguous = true;
   bool all_contig = true;
   bool all_row_contig = true;

@@ -56,7 +56,7 @@ inline bool is_scalar(const array& x) {
 // Check if we can use a contiguous operation given inputs and the output shape
 bool compiled_check_contiguity(
     const std::vector<array>& inputs,
-    const std::vector<int>& shape);
+    const Shape& shape);

 // Allocate space for the outputs possibly with input donation
 void compiled_allocate_outputs(
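The definition (previous hunk) and this declaration change together, keeping the header and the translation unit in sync. Since the parameter is taken by const reference, call sites can pass either a stored Shape or a braced temporary. A small self-contained sketch of that property (the function here is illustrative, not MLX API):

#include <cstdint>
#include <vector>

using Shape = std::vector<int32_t>; // assumed alias, as sketched above

// Illustrative stand-in for a const Shape& parameter: it binds to named
// shapes and to braced temporaries alike.
bool shape_is_nonempty(const Shape& shape) {
  return !shape.empty();
}

int main() {
  Shape s = {2, 3, 4};
  return shape_is_nonempty(s) && shape_is_nonempty({8, 8}) ? 0 : 1;
}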

@@ -726,7 +726,7 @@ void explicit_gemm_conv_1D_cpu(
   auto conv_dtype = float32;

   // Pad input
-  std::vector<int> padded_shape = {N, iH + 2 * padding[0], C};
+  Shape padded_shape = {N, iH + 2 * padding[0], C};
   array in_padded(padded_shape, conv_dtype, nullptr, {});

   // Fill with zeros
@@ -765,7 +765,7 @@ void explicit_gemm_conv_1D_cpu(
       in_padded, strided_strides, flags, in_strided_view.size(), 0);

   // Materialize strided view
-  std::vector<int> strided_reshape = {N * oH, wH * C};
+  Shape strided_reshape = {N * oH, wH * C};
   array in_strided(strided_reshape, in_strided_view.dtype(), nullptr, {});
   copy(in_strided_view, in_strided, CopyType::General);
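For context on the {N * oH, wH * C} shape: this is the im2col flattening that lets the convolution run as a single GEMM, with one row per output position holding that position's full receptive field. A hedged, self-contained sketch of the idea for stride 1 and no padding (all names here are illustrative, not the MLX implementation):

#include <cstdint>
#include <vector>

using Shape = std::vector<int32_t>; // assumed alias

// im2col for a 1-D conv: gather each length-wH window (over C channels)
// into one row of an (N * oH) x (wH * C) matrix, so the convolution
// becomes one matrix multiply against the (wH * C) x O weights.
std::vector<float> im2col_1d(
    const std::vector<float>& in, int N, int iH, int C, int wH) {
  const int oH = iH - wH + 1; // stride 1, no padding
  Shape strided_reshape = {N * oH, wH * C};
  std::vector<float> cols(
      static_cast<size_t>(strided_reshape[0]) * strided_reshape[1]);
  for (int n = 0; n < N; ++n)
    for (int o = 0; o < oH; ++o)
      for (int w = 0; w < wH; ++w)
        for (int c = 0; c < C; ++c)
          cols[((n * oH + o) * wH + w) * C + c] =
              in[(n * iH + o + w) * C + c];
  return cols;
}
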
@@ -843,8 +843,7 @@ void explicit_gemm_conv_2D_cpu(
   auto conv_dtype = out.dtype();

   // Pad input
-  std::vector<int> padded_shape = {
-      N, iH + 2 * padding[0], iW + 2 * padding[1], C};
+  Shape padded_shape = {N, iH + 2 * padding[0], iW + 2 * padding[1], C};
   array in_padded(padded_shape, conv_dtype, nullptr, {});

   // Fill with zeros
@@ -881,7 +880,7 @@ void explicit_gemm_conv_2D_cpu(
       in_padded, strided_strides, flags, in_strided_view.size(), 0);

   // Materialize strided view
-  std::vector<int> strided_reshape = {N * oH * oW, wH * wW * C};
+  Shape strided_reshape = {N * oH * oW, wH * wW * C};
   array in_strided(strided_reshape, in_strided_view.dtype(), nullptr, {});
   copy(in_strided_view, in_strided, CopyType::General);
@@ -934,19 +933,19 @@ void explicit_gemm_conv_ND_cpu(
     const std::vector<int>& wt_dilation,
     const bool flip) {
   const int N = in.shape(0); // Batch size, should be the same as out.shape(0)
-  const auto iDim = std::vector<int>(
-      in.shape().begin() + 1, in.shape().end() - 1); // Input spatial dim
-  const auto oDim = std::vector<int>(
+  const auto iDim =
+      Shape(in.shape().begin() + 1, in.shape().end() - 1); // Input spatial dim
+  const auto oDim = Shape(
       out.shape().begin() + 1, out.shape().end() - 1); // Output spatial dim
   const int O = wt.shape(0); // Out channels
   const int C = wt.shape(-1); // In channels
-  const auto wDim = std::vector<int>(
-      wt.shape().begin() + 1, wt.shape().end() - 1); // Weight spatial dim
+  const auto wDim =
+      Shape(wt.shape().begin() + 1, wt.shape().end() - 1); // Weight spatial dim
   auto conv_dtype = float32;

   // Pad input
-  std::vector<int> padded_shape(in.shape().size());
+  Shape padded_shape(in.shape().size());
   padded_shape.front() = N;
   for (size_t i = 0; i < iDim.size(); i++) {
     padded_shape[i + 1] = iDim[i] + 2 * padding[i];
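The ND path cannot use a braced list, so it sizes the Shape up front and fills it per dimension. A self-contained sketch of the same computation under the assumed alias (the helper name is hypothetical; the layout is batch x spatial dims x channels, as in the hunks above):

#include <cstdint>
#include <vector>

using Shape = std::vector<int32_t>; // assumed alias

// Pad every spatial dim of an N x spatial... x C shape by 2 * padding[i],
// mirroring the loop in explicit_gemm_conv_ND_cpu above.
Shape pad_spatial(const Shape& in_shape, const std::vector<int>& padding) {
  Shape padded(in_shape.size());
  padded.front() = in_shape.front(); // batch dim unchanged
  for (size_t i = 0; i + 2 < in_shape.size(); i++) {
    padded[i + 1] = in_shape[i + 1] + 2 * padding[i]; // spatial dims
  }
  padded.back() = in_shape.back(); // channel dim unchanged
  return padded;
}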

@@ -14,10 +14,10 @@ namespace mlx::core {
 namespace {

-template <typename T, typename IdxT = int32_t>
+template <typename T>
 struct StridedIterator {
   using iterator_category = std::random_access_iterator_tag;
-  using difference_type = IdxT;
+  using difference_type = int32_t;
   using value_type = T;
   using reference = value_type&;
   using pointer = value_type*;
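Dropping the IdxT parameter pins difference_type to int32_t for every instantiation. A minimal compilable sketch of an iterator shaped like this one, to show where difference_type enters (the members below are illustrative, not the full MLX iterator):

#include <cstdint>
#include <iterator>

template <typename T>
struct StridedIterator {
  using iterator_category = std::random_access_iterator_tag;
  using difference_type = int32_t; // fixed width, as in the hunk above
  using value_type = T;
  using reference = value_type&;
  using pointer = value_type*;

  StridedIterator(T* ptr, int64_t stride) : ptr_(ptr), stride_(stride) {}

  reference operator*() const { return *ptr_; }
  StridedIterator& operator++() { ptr_ += stride_; return *this; }

  // difference_type is what iterator subtraction (and std::distance) returns.
  difference_type operator-(const StridedIterator& other) const {
    return static_cast<difference_type>((ptr_ - other.ptr_) / stride_);
  }

 private:
  T* ptr_;
  int64_t stride_;
};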

@@ -107,7 +107,7 @@ struct ContiguousIterator {
       : shape_(a.shape()), strides_(a.strides()) {
     if (!shape_.empty()) {
       std::tie(shape_, strides_) = collapse_contiguous_dims(shape_, strides_);
-      pos_ = std::vector<int>(shape_.size(), 0);
+      pos_ = Shape(shape_.size(), 0);
     }
   }
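Shape(shape_.size(), 0) is the std::vector fill constructor, so pos_ starts as one zero counter per collapsed dimension, exactly as the old std::vector<int> spelling did. A tiny sketch under the assumed alias:

#include <cassert>
#include <cstdint>
#include <vector>

using Shape = std::vector<int32_t>; // assumed alias

int main() {
  Shape pos(4, 0); // fill constructor: four dimensions, all counters at zero
  assert(pos.size() == 4);
  assert(pos[0] == 0 && pos[3] == 0);
  return 0;
}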