Add gemv masked to JIT plus some fixes (#1310)

* add gemv masked to JIT plus some fixes

* some cleanup

* add utils

* fix

* fix 2

* more cleaning

* fix

* remove unused mps matmul support

* one more nit

* revert
This commit is contained in:
Awni Hannun
2024-08-07 13:38:07 -07:00
committed by GitHub
parent 635ccd9e25
commit 30bbea2f08
25 changed files with 1230 additions and 1702 deletions

View File

@@ -8,8 +8,6 @@
namespace mlx::core {
namespace {
using metal::CommandEncoder;
template <typename T>
@@ -27,82 +25,13 @@ set_vector_bytes(CommandEncoder& enc, const std::vector<T>& vec, int idx) {
return set_vector_bytes(enc, vec, vec.size(), idx);
}
std::string type_to_name(const array& a) {
std::string tname;
switch (a.dtype()) {
case bool_:
tname = "bool_";
break;
case uint8:
tname = "uint8";
break;
case uint16:
tname = "uint16";
break;
case uint32:
tname = "uint32";
break;
case uint64:
tname = "uint64";
break;
case int8:
tname = "int8";
break;
case int16:
tname = "int16";
break;
case int32:
tname = "int32";
break;
case int64:
tname = "int64";
break;
case float16:
tname = "float16";
break;
case float32:
tname = "float32";
break;
case bfloat16:
tname = "bfloat16";
break;
case complex64:
tname = "complex64";
break;
}
return tname;
}
std::string type_to_name(const array& a);
MTL::Size get_block_dims(int dim0, int dim1, int dim2) {
int pows[3] = {0, 0, 0};
int sum = 0;
while (true) {
int presum = sum;
// Check all the pows
if (dim0 >= (1 << (pows[0] + 1))) {
pows[0]++;
sum++;
}
if (sum == 10) {
break;
}
if (dim1 >= (1 << (pows[1] + 1))) {
pows[1]++;
sum++;
}
if (sum == 10) {
break;
}
if (dim2 >= (1 << (pows[2] + 1))) {
pows[2]++;
sum++;
}
if (sum == presum || sum == 10) {
break;
}
}
return MTL::Size{1ul << pows[0], 1ul << pows[1], 1ul << pows[2]};
}
// Compute the thread block dimensions which fit the given
// input dimensions.
// - The thread block dimensions will be powers of two
// - The thread block size will be less than 1024
MTL::Size get_block_dims(int dim0, int dim1, int dim2);
// Computes a 2D grid where each element is < UINT_MAX
// Assumes:
@@ -111,27 +40,7 @@ MTL::Size get_block_dims(int dim0, int dim1, int dim2) {
// possibly broadcasted array
MTL::Size get_2d_grid_dims(
const std::vector<int>& shape,
const std::vector<size_t>& strides) {
// Dims with strides of 0 are ignored as they
// correspond to broadcasted dimensions
size_t grid_x = 1;
size_t grid_y = 1;
for (int i = 0; i < shape.size(); ++i) {
if (strides[i] == 0) {
continue;
}
if (grid_x * shape[i] < UINT32_MAX) {
grid_x *= shape[i];
} else {
grid_y *= shape[i];
}
}
if (grid_y > UINT32_MAX || grid_x > UINT32_MAX) {
throw std::runtime_error("Unable to safely factor shape.");
}
return MTL::Size(
static_cast<uint32_t>(grid_x), static_cast<uint32_t>(grid_y), 1);
}
const std::vector<size_t>& strides);
inline NS::String* make_string(std::ostringstream& os) {
std::string string = os.str();
@@ -159,12 +68,6 @@ inline void debug_set_primitive_buffer_label(
#endif
}
std::string get_primitive_string(Primitive* primitive) {
std::ostringstream op_t;
primitive->print(op_t);
return op_t.str();
}
} // namespace
std::string get_primitive_string(Primitive* primitive);
} // namespace mlx::core