mirror of
https://github.com/ml-explore/mlx.git
synced 2025-10-28 14:18:13 +08:00
awni's commit files
This commit is contained in:
87
.clang-format
Normal file
87
.clang-format
Normal file
@@ -0,0 +1,87 @@
|
|||||||
|
---
|
||||||
|
AccessModifierOffset: -1
|
||||||
|
AlignAfterOpenBracket: AlwaysBreak
|
||||||
|
AlignConsecutiveAssignments: false
|
||||||
|
AlignConsecutiveDeclarations: false
|
||||||
|
AlignEscapedNewlinesLeft: true
|
||||||
|
AlignOperands: false
|
||||||
|
AlignTrailingComments: false
|
||||||
|
AllowAllParametersOfDeclarationOnNextLine: false
|
||||||
|
AllowShortBlocksOnASingleLine: false
|
||||||
|
AllowShortCaseLabelsOnASingleLine: false
|
||||||
|
AllowShortFunctionsOnASingleLine: Empty
|
||||||
|
AllowShortIfStatementsOnASingleLine: false
|
||||||
|
AllowShortLoopsOnASingleLine: false
|
||||||
|
AlwaysBreakAfterReturnType: None
|
||||||
|
AlwaysBreakBeforeMultilineStrings: true
|
||||||
|
AlwaysBreakTemplateDeclarations: true
|
||||||
|
BinPackArguments: false
|
||||||
|
BinPackParameters: false
|
||||||
|
BraceWrapping:
|
||||||
|
AfterClass: false
|
||||||
|
AfterControlStatement: false
|
||||||
|
AfterEnum: false
|
||||||
|
AfterFunction: false
|
||||||
|
AfterNamespace: false
|
||||||
|
AfterObjCDeclaration: false
|
||||||
|
AfterStruct: false
|
||||||
|
AfterUnion: false
|
||||||
|
BeforeCatch: false
|
||||||
|
BeforeElse: false
|
||||||
|
IndentBraces: false
|
||||||
|
BreakBeforeBinaryOperators: None
|
||||||
|
BreakBeforeBraces: Attach
|
||||||
|
BreakBeforeTernaryOperators: true
|
||||||
|
BreakConstructorInitializersBeforeComma: false
|
||||||
|
BreakAfterJavaFieldAnnotations: false
|
||||||
|
BreakStringLiterals: false
|
||||||
|
ColumnLimit: 80
|
||||||
|
CommentPragmas: '^ IWYU pragma:'
|
||||||
|
ConstructorInitializerAllOnOneLineOrOnePerLine: true
|
||||||
|
ConstructorInitializerIndentWidth: 4
|
||||||
|
ContinuationIndentWidth: 4
|
||||||
|
Cpp11BracedListStyle: true
|
||||||
|
DerivePointerAlignment: false
|
||||||
|
DisableFormat: false
|
||||||
|
ForEachMacros: [ FOR_EACH, FOR_EACH_R, FOR_EACH_RANGE, ]
|
||||||
|
IncludeCategories:
|
||||||
|
- Regex: '^<.*\.h(pp)?>'
|
||||||
|
Priority: 1
|
||||||
|
- Regex: '^<.*'
|
||||||
|
Priority: 2
|
||||||
|
- Regex: '.*'
|
||||||
|
Priority: 3
|
||||||
|
IndentCaseLabels: true
|
||||||
|
IndentWidth: 2
|
||||||
|
IndentWrappedFunctionNames: false
|
||||||
|
KeepEmptyLinesAtTheStartOfBlocks: false
|
||||||
|
MacroBlockBegin: ''
|
||||||
|
MacroBlockEnd: ''
|
||||||
|
MaxEmptyLinesToKeep: 1
|
||||||
|
NamespaceIndentation: None
|
||||||
|
ObjCBlockIndentWidth: 2
|
||||||
|
ObjCSpaceAfterProperty: false
|
||||||
|
ObjCSpaceBeforeProtocolList: false
|
||||||
|
PenaltyBreakBeforeFirstCallParameter: 1
|
||||||
|
PenaltyBreakComment: 300
|
||||||
|
PenaltyBreakFirstLessLess: 120
|
||||||
|
PenaltyBreakString: 1000
|
||||||
|
PenaltyExcessCharacter: 1000000
|
||||||
|
PenaltyReturnTypeOnItsOwnLine: 200
|
||||||
|
PointerAlignment: Left
|
||||||
|
ReflowComments: true
|
||||||
|
SortIncludes: true
|
||||||
|
SpaceAfterCStyleCast: false
|
||||||
|
SpaceBeforeAssignmentOperators: true
|
||||||
|
SpaceBeforeParens: ControlStatements
|
||||||
|
SpaceInEmptyParentheses: false
|
||||||
|
SpacesBeforeTrailingComments: 1
|
||||||
|
SpacesInAngles: false
|
||||||
|
SpacesInContainerLiterals: true
|
||||||
|
SpacesInCStyleCastParentheses: false
|
||||||
|
SpacesInParentheses: false
|
||||||
|
SpacesInSquareBrackets: false
|
||||||
|
Standard: Cpp11
|
||||||
|
TabWidth: 8
|
||||||
|
UseTab: Never
|
||||||
|
...
|
||||||
3
MANIFEST.in
Normal file
3
MANIFEST.in
Normal file
@@ -0,0 +1,3 @@
|
|||||||
|
include CMakeLists.txt
|
||||||
|
recursive-include mlx/ *
|
||||||
|
include python/src/*
|
||||||
198
benchmarks/cpp/irregular_strides.cpp
Normal file
198
benchmarks/cpp/irregular_strides.cpp
Normal file
@@ -0,0 +1,198 @@
|
|||||||
|
#include <iostream>
|
||||||
|
#include <sstream>
|
||||||
|
|
||||||
|
#include "mlx/mlx.h"
|
||||||
|
#include "time_utils.h"
|
||||||
|
|
||||||
|
using namespace mlx::core;
|
||||||
|
|
||||||
|
void time_irregular_binary_ops_1D() {
|
||||||
|
auto device = default_device();
|
||||||
|
int size = 1000000;
|
||||||
|
int step = 2;
|
||||||
|
auto a = random::uniform({size});
|
||||||
|
auto b = random::uniform({size});
|
||||||
|
eval(a, b);
|
||||||
|
a = slice(a, {0}, {size}, {step});
|
||||||
|
b = slice(b, {0}, {size}, {step});
|
||||||
|
TIMEM("1D strided", add, a, b, device);
|
||||||
|
}
|
||||||
|
|
||||||
|
void time_irregular_binary_ops_2D() {
|
||||||
|
auto device = default_device();
|
||||||
|
int size = 2048;
|
||||||
|
auto a = random::uniform({size, size});
|
||||||
|
auto b = random::uniform({size, size});
|
||||||
|
eval(a, b);
|
||||||
|
TIMEM("2D regular", add, a, b, device);
|
||||||
|
|
||||||
|
b = transpose(b);
|
||||||
|
eval(b);
|
||||||
|
TIMEM("2D transpose", add, a, b, device);
|
||||||
|
|
||||||
|
b = random::uniform({size});
|
||||||
|
eval(b);
|
||||||
|
TIMEM("2D broadcast dim 0", add, a, b, device);
|
||||||
|
|
||||||
|
b = reshape(b, {size, 1});
|
||||||
|
eval(b);
|
||||||
|
TIMEM("2D broadcast dim 1", add, a, b, device);
|
||||||
|
}
|
||||||
|
|
||||||
|
void time_irregular_binary_ops_3D() {
|
||||||
|
auto device = default_device();
|
||||||
|
int d0 = 32;
|
||||||
|
int d1 = 512;
|
||||||
|
int d2 = 512;
|
||||||
|
auto a = random::uniform({d0, d1, d2});
|
||||||
|
auto b = random::uniform({d0, d1, d2});
|
||||||
|
TIMEM("3D regular", add, a, b, device);
|
||||||
|
|
||||||
|
b = transpose(b, {0, 2, 1});
|
||||||
|
TIMEM("3D transpose", add, a, b, device);
|
||||||
|
|
||||||
|
b = random::uniform({d1, d2});
|
||||||
|
TIMEM("3D broadcast dim 0", add, a, b, device);
|
||||||
|
|
||||||
|
b = random::uniform({d0, 1, d2});
|
||||||
|
TIMEM("3D broadcast dim 1", add, a, b, device);
|
||||||
|
|
||||||
|
b = random::uniform({d0, d1, 1});
|
||||||
|
TIMEM("3D broadcast dim 2", add, a, b, device);
|
||||||
|
|
||||||
|
b = random::uniform({d2});
|
||||||
|
TIMEM("3D broadcast dims 0, 1", add, a, b, device);
|
||||||
|
|
||||||
|
b = random::uniform({d1, 1});
|
||||||
|
TIMEM("3D broadcast dims 0, 2", add, a, b, device);
|
||||||
|
|
||||||
|
b = random::uniform({d0, 1, 1});
|
||||||
|
TIMEM("3D broadcast dims 1, 2", add, a, b, device);
|
||||||
|
}
|
||||||
|
|
||||||
|
void time_irregular_binary_ops_4D() {
|
||||||
|
auto device = default_device();
|
||||||
|
std::vector<int> shape = {8, 8, 512, 512};
|
||||||
|
auto a = random::uniform(shape);
|
||||||
|
auto b = random::uniform(shape);
|
||||||
|
|
||||||
|
TIMEM("4D regular", add, a, b, device);
|
||||||
|
|
||||||
|
b = transpose(b, {0, 1, 3, 2});
|
||||||
|
TIMEM("4D transpose", add, a, b, device);
|
||||||
|
|
||||||
|
std::string om = "4D broadcast dims ";
|
||||||
|
for (int i = 0; i < shape.size(); ++i) {
|
||||||
|
shape[i] = 1;
|
||||||
|
b = random::uniform(shape);
|
||||||
|
std::ostringstream msg;
|
||||||
|
msg << om << i;
|
||||||
|
TIMEM(msg.str(), add, a, b, device);
|
||||||
|
|
||||||
|
for (int j = i + 1; j < shape.size(); ++j) {
|
||||||
|
shape[j] = 1;
|
||||||
|
std::ostringstream msg;
|
||||||
|
msg << om << i << ", " << j;
|
||||||
|
b = random::uniform(shape);
|
||||||
|
TIMEM(msg.str(), add, a, b, device);
|
||||||
|
shape[j] = a.shape(j);
|
||||||
|
|
||||||
|
for (int k = j + 1; k < shape.size(); ++k) {
|
||||||
|
shape[k] = 1;
|
||||||
|
std::ostringstream msg;
|
||||||
|
msg << om << i << ", " << j << ", " << k;
|
||||||
|
b = random::uniform(shape);
|
||||||
|
TIMEM(msg.str(), add, a, b, device);
|
||||||
|
shape[k] = a.shape(k);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
shape[i] = a.shape(i);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void time_irregular_reshape() {
|
||||||
|
auto device = default_device();
|
||||||
|
std::vector<int> shape;
|
||||||
|
auto reshape_fn = [&shape, device](const array& a) {
|
||||||
|
return reshape(a, shape, device);
|
||||||
|
};
|
||||||
|
|
||||||
|
int size = 64;
|
||||||
|
int d = 2 * size;
|
||||||
|
|
||||||
|
auto a = random::uniform({d, d, d});
|
||||||
|
|
||||||
|
shape = {8 * size, size, size};
|
||||||
|
TIMEM("3D contiguous", reshape_fn, a);
|
||||||
|
|
||||||
|
a = transpose(a);
|
||||||
|
shape = {8 * size, size, size};
|
||||||
|
TIMEM("3D transpose", reshape_fn, a);
|
||||||
|
|
||||||
|
a = transpose(a, {1, 2, 0});
|
||||||
|
shape = {8 * size, size, size};
|
||||||
|
TIMEM("3D transpose dims 1 2", reshape_fn, a);
|
||||||
|
|
||||||
|
a = broadcast_to(random::uniform({d, d}), {d, d, d});
|
||||||
|
TIMEM("3D broadcast dim 0", reshape_fn, a);
|
||||||
|
|
||||||
|
a = broadcast_to(random::uniform({d, 1, d}), {d, d, d});
|
||||||
|
TIMEM("3D broadcast dim 1", reshape_fn, a);
|
||||||
|
|
||||||
|
a = broadcast_to(random::uniform({d, d, 1}), {d, d, d});
|
||||||
|
TIMEM("3D broadcast dim 2", reshape_fn, a);
|
||||||
|
|
||||||
|
a = broadcast_to(random::uniform({d}), {d, d, d});
|
||||||
|
TIMEM("3D broadcast dims 0, 1", reshape_fn, a);
|
||||||
|
|
||||||
|
a = broadcast_to(random::uniform({d, 1}), {d, d, d});
|
||||||
|
TIMEM("3D broadcast dims 0, 2", reshape_fn, a);
|
||||||
|
|
||||||
|
a = broadcast_to(random::uniform({d, 1, 1}), {d, d, d});
|
||||||
|
TIMEM("3D broadcast dims 1, 2", reshape_fn, a);
|
||||||
|
|
||||||
|
a = broadcast_to(random::uniform({1, 1, 1}), {d, d, d});
|
||||||
|
TIMEM("3D broadcast dims 1, 2, 3", reshape_fn, a);
|
||||||
|
}
|
||||||
|
|
||||||
|
void time_irregular_astype_1D() {
|
||||||
|
auto device = default_device();
|
||||||
|
int size = 1000000;
|
||||||
|
int step = 2;
|
||||||
|
auto a = random::uniform({size});
|
||||||
|
a = slice(a, {0}, {size}, {step});
|
||||||
|
TIMEM("1D strided", astype, a, int32, device);
|
||||||
|
}
|
||||||
|
|
||||||
|
void time_irregular_astype_2D() {
|
||||||
|
auto device = default_device();
|
||||||
|
int size = 2048;
|
||||||
|
std::vector<int> shape = {size, size};
|
||||||
|
|
||||||
|
auto a = random::uniform(shape);
|
||||||
|
TIMEM("2D regular", astype, a, int32, device);
|
||||||
|
|
||||||
|
a = transpose(a);
|
||||||
|
TIMEM("2D transpose", astype, a, int32, device);
|
||||||
|
|
||||||
|
a = broadcast_to(random::uniform({size}), shape);
|
||||||
|
TIMEM("2D broadcast dim 0", astype, a, int32, device);
|
||||||
|
|
||||||
|
a = broadcast_to(random::uniform({size, 1}), shape);
|
||||||
|
TIMEM("2D broadcast dim 1", astype, a, int32, device);
|
||||||
|
}
|
||||||
|
|
||||||
|
int main(int argc, char** argv) {
|
||||||
|
if (argc > 1) {
|
||||||
|
bool use_gpu = !strcmp(argv[1], "gpu");
|
||||||
|
set_default_device(use_gpu ? Device::gpu : Device::cpu);
|
||||||
|
}
|
||||||
|
std::cout << "Benchmarks for " << default_device() << std::endl;
|
||||||
|
time_irregular_binary_ops_1D();
|
||||||
|
time_irregular_binary_ops_2D();
|
||||||
|
time_irregular_binary_ops_3D();
|
||||||
|
time_irregular_binary_ops_4D();
|
||||||
|
time_irregular_reshape();
|
||||||
|
time_irregular_astype_1D();
|
||||||
|
time_irregular_astype_2D();
|
||||||
|
}
|
||||||
247
benchmarks/cpp/single_ops.cpp
Normal file
247
benchmarks/cpp/single_ops.cpp
Normal file
@@ -0,0 +1,247 @@
|
|||||||
|
#include "mlx/mlx.h"
|
||||||
|
#include "time_utils.h"
|
||||||
|
|
||||||
|
using namespace mlx::core;
|
||||||
|
|
||||||
|
void time_creation_ops() {
|
||||||
|
int M = 2000;
|
||||||
|
int N = 500;
|
||||||
|
auto shape = {M, N};
|
||||||
|
auto full_fp32 = [&]() { return full(shape, 3.3f); };
|
||||||
|
TIME(full_fp32);
|
||||||
|
auto zeros_fp32 = [&]() { return zeros(shape, float32); };
|
||||||
|
TIME(zeros_fp32);
|
||||||
|
auto ones_fp32 = [&]() { return ones(shape, float32); };
|
||||||
|
TIME(ones_fp32);
|
||||||
|
|
||||||
|
auto arange_fp32 = [&]() { return arange(0.0, 10.0, 1e-4); };
|
||||||
|
TIME(arange_fp32);
|
||||||
|
}
|
||||||
|
|
||||||
|
void time_type_conversions() {
|
||||||
|
int M = 2000;
|
||||||
|
int N = 500;
|
||||||
|
auto shape = {M, N};
|
||||||
|
auto device = default_device();
|
||||||
|
|
||||||
|
auto a = zeros(shape, float32);
|
||||||
|
eval(a);
|
||||||
|
TIMEM("float32 to int32", astype, a, int32, device);
|
||||||
|
TIMEM("float32 to uint32", astype, a, uint32, device);
|
||||||
|
|
||||||
|
a = zeros(shape, int32);
|
||||||
|
eval(a);
|
||||||
|
TIMEM("int32 to float32", astype, a, float32, device);
|
||||||
|
|
||||||
|
a = zeros(shape, bool_);
|
||||||
|
eval(a);
|
||||||
|
TIMEM("bool to float32", astype, a, float32, device);
|
||||||
|
TIMEM("bool to int32", astype, a, int32, device);
|
||||||
|
TIMEM("bool to uint32", astype, a, uint32, device);
|
||||||
|
}
|
||||||
|
|
||||||
|
void time_random_generation() {
|
||||||
|
int M = 2000;
|
||||||
|
int N = 500;
|
||||||
|
|
||||||
|
auto uniform = [&]() { return random::uniform({M, N}, float32); };
|
||||||
|
TIME(uniform);
|
||||||
|
auto normal = [&]() { return random::normal({M, N}, float32); };
|
||||||
|
TIME(normal);
|
||||||
|
}
|
||||||
|
|
||||||
|
void time_unary_ops() {
|
||||||
|
int M = 2000;
|
||||||
|
int N = 500;
|
||||||
|
auto device = default_device();
|
||||||
|
|
||||||
|
auto a = random::normal({M, N});
|
||||||
|
eval(a);
|
||||||
|
TIME(mlx::core::abs, a, device);
|
||||||
|
TIME(negative, a, device);
|
||||||
|
TIME(sign, a, device);
|
||||||
|
TIME(square, a, device);
|
||||||
|
TIME(mlx::core::sqrt, a, device);
|
||||||
|
TIME(rsqrt, a, device);
|
||||||
|
TIME(mlx::core::exp, a, device);
|
||||||
|
|
||||||
|
a = random::uniform({M, N});
|
||||||
|
TIME(mlx::core::log, a, device);
|
||||||
|
}
|
||||||
|
|
||||||
|
void time_binary_ops() {
|
||||||
|
int M = 1000, N = 100, K = 10;
|
||||||
|
auto a = random::uniform({M, N, K});
|
||||||
|
auto b = random::uniform({M, N, K});
|
||||||
|
auto device = default_device();
|
||||||
|
eval(a, b);
|
||||||
|
|
||||||
|
TIME(add, a, b, device);
|
||||||
|
TIME(subtract, a, b, device);
|
||||||
|
TIME(multiply, a, b, device);
|
||||||
|
TIME(divide, a, b, device);
|
||||||
|
TIME(maximum, a, b, device);
|
||||||
|
TIME(minimum, a, b, device);
|
||||||
|
|
||||||
|
b = random::uniform({1});
|
||||||
|
eval(b);
|
||||||
|
TIMEM("scalar", add, a, b, device);
|
||||||
|
TIMEM("vector-scalar", subtract, a, b, device);
|
||||||
|
TIMEM("scalar-vector", subtract, b, a, device);
|
||||||
|
TIMEM("scalar", multiply, a, b, device);
|
||||||
|
TIMEM("vector-scalar", divide, a, b, device);
|
||||||
|
TIMEM("scalar-vector", divide, b, a, device);
|
||||||
|
|
||||||
|
a = broadcast_to(random::uniform({1}), {1000, 100});
|
||||||
|
b = broadcast_to(random::uniform({1}), {1000, 100});
|
||||||
|
eval(a, b);
|
||||||
|
TIMEM("scalar-scalar broadcast", add, a, b, device);
|
||||||
|
TIMEM("scalar-scalar broadcast", subtract, a, b, device);
|
||||||
|
TIMEM("scalar-scalar broadcast", multiply, a, b, device);
|
||||||
|
TIMEM("scalar-scalar broadcast", divide, a, b, device);
|
||||||
|
}
|
||||||
|
|
||||||
|
void time_strided_ops() {
|
||||||
|
int M = 50, N = 50, O = 50, P = 50;
|
||||||
|
auto a = random::uniform({M, N, O, P});
|
||||||
|
auto b = random::uniform({M, N, O, P});
|
||||||
|
auto device = default_device();
|
||||||
|
eval(a, b);
|
||||||
|
TIMEM("non-strided", add, a, b, device);
|
||||||
|
a = transpose(a, {1, 0, 2, 3});
|
||||||
|
b = transpose(b, {3, 2, 0, 1});
|
||||||
|
eval(a, b);
|
||||||
|
TIMEM("strided", add, a, b, device);
|
||||||
|
}
|
||||||
|
|
||||||
|
void time_comparisons() {
|
||||||
|
int M = 1000, N = 100, K = 10;
|
||||||
|
auto a = random::uniform({M, N, K});
|
||||||
|
auto b = random::uniform({M, N, K});
|
||||||
|
auto device = default_device();
|
||||||
|
eval(a, b);
|
||||||
|
TIME(equal, a, b, device);
|
||||||
|
TIME(greater, a, b, device);
|
||||||
|
TIME(greater_equal, a, b, device);
|
||||||
|
TIME(less, a, b, device);
|
||||||
|
TIME(less_equal, a, b, device);
|
||||||
|
}
|
||||||
|
|
||||||
|
void time_matvec() {
|
||||||
|
int M = 2000, N = 200;
|
||||||
|
auto a = random::uniform({M, N});
|
||||||
|
auto b = random::uniform({N});
|
||||||
|
auto c = random::uniform({M});
|
||||||
|
eval(a, b, c);
|
||||||
|
auto matvec = [&]() { return matmul(a, b); };
|
||||||
|
TIME(matvec);
|
||||||
|
|
||||||
|
auto matvec_transpose = [&]() { return matmul(transpose(a), c); };
|
||||||
|
TIME(matvec_transpose);
|
||||||
|
}
|
||||||
|
|
||||||
|
void time_matmul() {
|
||||||
|
int M = 1000, N = 1000, K = 1000;
|
||||||
|
auto a = random::uniform({M, K});
|
||||||
|
auto b = random::uniform({K, N});
|
||||||
|
auto device = default_device();
|
||||||
|
eval(a, b);
|
||||||
|
TIME(matmul, a, b, device);
|
||||||
|
|
||||||
|
auto transpose_matmul = [&]() { return matmul(transpose(a), b); };
|
||||||
|
TIME(transpose_matmul);
|
||||||
|
}
|
||||||
|
|
||||||
|
void time_reductions() {
|
||||||
|
auto a = random::normal({10000, 1000});
|
||||||
|
eval(a);
|
||||||
|
auto sum_all = [&a]() { return sum(a, false); };
|
||||||
|
TIME(sum_all);
|
||||||
|
|
||||||
|
auto sum_along_0 = [&a]() { return sum(a, 0, false); };
|
||||||
|
TIME(sum_along_0);
|
||||||
|
|
||||||
|
auto sum_along_1 = [&a]() { return sum(a, 1, false); };
|
||||||
|
TIME(sum_along_1);
|
||||||
|
|
||||||
|
auto prod_all = [&a]() { return prod(a, false); };
|
||||||
|
TIME(prod_all);
|
||||||
|
|
||||||
|
auto all_true = [&a]() { return all(a, false); };
|
||||||
|
TIME(all_true);
|
||||||
|
|
||||||
|
auto all_along_0 = [&a]() { return all(a, 0, false); };
|
||||||
|
TIME(all_along_0);
|
||||||
|
|
||||||
|
auto all_along_1 = [&a]() { return all(a, 1, false); };
|
||||||
|
TIME(all_along_1);
|
||||||
|
|
||||||
|
auto any_true = [&a]() { return any(a, false); };
|
||||||
|
TIME(any_true);
|
||||||
|
|
||||||
|
auto argmin_along_0 = [&a]() { return argmin(a, 0, false); };
|
||||||
|
TIME(argmin_along_0);
|
||||||
|
|
||||||
|
auto argmin_along_1 = [&a]() { return argmin(a, 1, false); };
|
||||||
|
TIME(argmin_along_1);
|
||||||
|
}
|
||||||
|
|
||||||
|
void time_gather_scatter() {
|
||||||
|
auto a = random::normal({1000, 768});
|
||||||
|
eval(a);
|
||||||
|
auto indices = random::randint(0, 1000, {256});
|
||||||
|
eval(indices);
|
||||||
|
|
||||||
|
auto embedding_lookup = [&a, &indices]() { return take(a, indices, 0); };
|
||||||
|
TIME(embedding_lookup);
|
||||||
|
|
||||||
|
indices = random::randint(0, 768 * 1000, {256 * 768});
|
||||||
|
eval(indices);
|
||||||
|
|
||||||
|
auto single_element_lookup = [&a, &indices]() { return take(a, indices); };
|
||||||
|
TIME(single_element_lookup);
|
||||||
|
|
||||||
|
indices = random::randint(0, 1000, {256});
|
||||||
|
auto updates = random::normal({256, 1, 768});
|
||||||
|
eval(indices, updates);
|
||||||
|
|
||||||
|
auto embedding_update = [&a, &indices, &updates]() {
|
||||||
|
return scatter(a, indices, updates, 0);
|
||||||
|
};
|
||||||
|
TIME(embedding_update);
|
||||||
|
|
||||||
|
auto embedding_add = [&a, &indices, &updates]() {
|
||||||
|
return scatter_add(a, indices, updates, 0);
|
||||||
|
};
|
||||||
|
TIME(embedding_add);
|
||||||
|
|
||||||
|
a = reshape(a, {-1});
|
||||||
|
indices = random::randint(0, 768 * 1000, {768 * 256});
|
||||||
|
updates = random::normal({256 * 768, 1});
|
||||||
|
eval(a, indices, updates);
|
||||||
|
|
||||||
|
auto single_element_update = [&a, &indices, &updates]() {
|
||||||
|
return scatter(a, indices, updates, 0);
|
||||||
|
};
|
||||||
|
TIME(single_element_update);
|
||||||
|
|
||||||
|
auto single_element_add = [&a, &indices, &updates]() {
|
||||||
|
return scatter_add(a, indices, updates, 0);
|
||||||
|
};
|
||||||
|
TIME(single_element_add);
|
||||||
|
}
|
||||||
|
|
||||||
|
int main() {
|
||||||
|
std::cout << "Benchmarks for " << default_device() << std::endl;
|
||||||
|
time_creation_ops();
|
||||||
|
time_type_conversions();
|
||||||
|
time_unary_ops();
|
||||||
|
time_binary_ops();
|
||||||
|
time_strided_ops();
|
||||||
|
time_random_generation();
|
||||||
|
time_comparisons();
|
||||||
|
time_matvec();
|
||||||
|
time_matmul();
|
||||||
|
time_reductions();
|
||||||
|
time_gather_scatter();
|
||||||
|
}
|
||||||
15
benchmarks/python/comparative/README.md
Normal file
15
benchmarks/python/comparative/README.md
Normal file
@@ -0,0 +1,15 @@
|
|||||||
|
Microbenchmarks comparing MLX to PyTorch
|
||||||
|
========================================
|
||||||
|
|
||||||
|
Implement the same microbenchmarks in MLX and PyTorch to compare and make a
|
||||||
|
list of the biggest possible performance improvements and/or regressions.
|
||||||
|
|
||||||
|
Run with `python bench_mlx.py sum_axis --size 8x1024x128 --axis 2 --cpu` for
|
||||||
|
instance to measure the times it takes to sum across the 3rd axis of the above
|
||||||
|
tensor on the cpu.
|
||||||
|
|
||||||
|
`compare.py` runs several benchmarks and compares the speed-up or lack thereof
|
||||||
|
in comparison to PyTorch.
|
||||||
|
|
||||||
|
Each bench script can be run with `--print-pid` to print the PID and wait for a
|
||||||
|
key in order to ease attaching a debugger.
|
||||||
313
benchmarks/python/comparative/bench_mlx.py
Normal file
313
benchmarks/python/comparative/bench_mlx.py
Normal file
@@ -0,0 +1,313 @@
|
|||||||
|
import argparse
|
||||||
|
import math
|
||||||
|
import os
|
||||||
|
import time
|
||||||
|
|
||||||
|
import mlx.core as mx
|
||||||
|
|
||||||
|
|
||||||
|
def int_or_list(x):
|
||||||
|
try:
|
||||||
|
return int(x)
|
||||||
|
except ValueError:
|
||||||
|
return [int(xi) for xi in x.split(",")]
|
||||||
|
|
||||||
|
|
||||||
|
def none_or_list(x):
|
||||||
|
if x == "":
|
||||||
|
return None
|
||||||
|
else:
|
||||||
|
return [int(xi) for xi in x.split(",")]
|
||||||
|
|
||||||
|
|
||||||
|
def bench(f, *args):
|
||||||
|
for i in range(10):
|
||||||
|
f(*args)
|
||||||
|
|
||||||
|
s = time.time()
|
||||||
|
for i in range(100):
|
||||||
|
f(*args)
|
||||||
|
e = time.time()
|
||||||
|
return e - s
|
||||||
|
|
||||||
|
|
||||||
|
def matmul_square(x):
|
||||||
|
y = x
|
||||||
|
for i in range(10):
|
||||||
|
y = y @ x
|
||||||
|
mx.eval(y)
|
||||||
|
return y
|
||||||
|
|
||||||
|
|
||||||
|
def matmul(x, y):
|
||||||
|
ys = []
|
||||||
|
for i in range(10):
|
||||||
|
ys.append(x @ y)
|
||||||
|
mx.eval(ys)
|
||||||
|
|
||||||
|
|
||||||
|
def conv1d(x, y):
|
||||||
|
ys = []
|
||||||
|
for i in range(10):
|
||||||
|
ys.append(mx.conv1d(x, y))
|
||||||
|
mx.eval(ys)
|
||||||
|
|
||||||
|
|
||||||
|
def conv2d(x, y):
|
||||||
|
ys = []
|
||||||
|
for i in range(10):
|
||||||
|
ys.append(mx.conv2d(x, y))
|
||||||
|
mx.eval(ys)
|
||||||
|
|
||||||
|
|
||||||
|
def binary(op, x, y):
|
||||||
|
for i in range(100):
|
||||||
|
y = getattr(mx, op)(x, y)
|
||||||
|
mx.eval(y)
|
||||||
|
|
||||||
|
|
||||||
|
def reduction(op, axis, x):
|
||||||
|
ys = []
|
||||||
|
for i in range(100):
|
||||||
|
ys.append(getattr(mx, op)(x, axis=axis))
|
||||||
|
mx.eval(ys)
|
||||||
|
|
||||||
|
|
||||||
|
def softmax(axis, x):
|
||||||
|
ys = []
|
||||||
|
for i in range(100):
|
||||||
|
ex = mx.exp(x - mx.max(x, axis=axis, keepdims=True))
|
||||||
|
y = ex / mx.sum(ex, axis=axis, keepdims=True)
|
||||||
|
ys.append(y)
|
||||||
|
mx.eval(ys)
|
||||||
|
|
||||||
|
|
||||||
|
def softmax_fused(axis, x):
|
||||||
|
ys = []
|
||||||
|
for i in range(100):
|
||||||
|
y = mx.softmax(x, axis=axis)
|
||||||
|
ys.append(y)
|
||||||
|
mx.eval(ys)
|
||||||
|
|
||||||
|
|
||||||
|
def relu(x):
|
||||||
|
y = x
|
||||||
|
for i in range(100):
|
||||||
|
y = mx.maximum(y, 0)
|
||||||
|
mx.eval(y)
|
||||||
|
|
||||||
|
|
||||||
|
def scalar_mult(x):
|
||||||
|
y = x
|
||||||
|
for i in range(100):
|
||||||
|
y = y * (1.0 / (1 + i))
|
||||||
|
mx.eval(y)
|
||||||
|
|
||||||
|
|
||||||
|
def cross_entropy(targets, x):
|
||||||
|
ys = []
|
||||||
|
for i in range(100):
|
||||||
|
y = mx.logsumexp(x, axis=-1, keepdims=True) - mx.take_along_axis(
|
||||||
|
x, mx.reshape(targets, (-1, 1)), axis=-1
|
||||||
|
)
|
||||||
|
ys.append(mx.mean(y))
|
||||||
|
mx.eval(ys)
|
||||||
|
|
||||||
|
|
||||||
|
def logsumexp(axis, x):
|
||||||
|
ys = []
|
||||||
|
for i in range(100):
|
||||||
|
ys.append(mx.logsumexp(x, axis=axis))
|
||||||
|
mx.eval(ys)
|
||||||
|
|
||||||
|
|
||||||
|
def linear(w, b, x):
|
||||||
|
ys = []
|
||||||
|
for i in range(10):
|
||||||
|
ys.append(x @ mx.transpose(w, (1, 0)) + b)
|
||||||
|
mx.eval(ys)
|
||||||
|
|
||||||
|
|
||||||
|
def rope(x):
|
||||||
|
*_, N, D = x.shape
|
||||||
|
ys = []
|
||||||
|
for i in range(10):
|
||||||
|
shape = x.shape
|
||||||
|
x = mx.reshape(x, (-1, N, D))
|
||||||
|
positions = mx.arange(N)
|
||||||
|
freqs = mx.exp(mx.arange(0.0, D // 2) / math.log(10000 / (D // 2 - 1)))
|
||||||
|
theta = mx.reshape(positions, (-1, 1)) * mx.reshape(freqs, (1, -1))
|
||||||
|
costheta = mx.cos(theta)
|
||||||
|
sintheta = mx.sin(theta)
|
||||||
|
x1 = x[..., ::2]
|
||||||
|
x2 = x[..., 1::2]
|
||||||
|
rx1 = x1 * costheta - x2 * sintheta
|
||||||
|
rx2 = x1 * sintheta + x2 * costheta
|
||||||
|
y = mx.concatenate([rx1[..., None], rx2[..., None]], axis=-1)
|
||||||
|
y = mx.reshape(y, (-1, N, D))
|
||||||
|
ys.append(y)
|
||||||
|
mx.eval(ys)
|
||||||
|
|
||||||
|
|
||||||
|
def concatenate(axis, x, y):
|
||||||
|
ys = []
|
||||||
|
for i in range(10):
|
||||||
|
ys.append(mx.concatenate([x, y], axis=axis))
|
||||||
|
mx.eval(ys)
|
||||||
|
|
||||||
|
|
||||||
|
def cumsum(axis, x):
|
||||||
|
ys = []
|
||||||
|
for i in range(10):
|
||||||
|
ys.append(mx.cumsum(x, axis))
|
||||||
|
mx.eval(ys)
|
||||||
|
|
||||||
|
|
||||||
|
def sort(axis, x):
|
||||||
|
ys = []
|
||||||
|
for i in range(10):
|
||||||
|
ys.append(mx.sort(x, axis))
|
||||||
|
mx.eval(ys)
|
||||||
|
|
||||||
|
|
||||||
|
def topk(axis, x):
|
||||||
|
k = x.shape[axis] // 3
|
||||||
|
ys = []
|
||||||
|
for i in range(10):
|
||||||
|
ys.append(mx.topk(x, k, axis))
|
||||||
|
mx.eval(ys)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
parser = argparse.ArgumentParser()
|
||||||
|
parser.add_argument("benchmark", help="Choose the benchmark to run")
|
||||||
|
parser.add_argument(
|
||||||
|
"--size",
|
||||||
|
default=[(1024, 1024)],
|
||||||
|
type=lambda x: list(map(int, x.split("x"))),
|
||||||
|
help="Set the matrix size",
|
||||||
|
action="append",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--axis",
|
||||||
|
default=[1],
|
||||||
|
type=int_or_list,
|
||||||
|
help="Set a reduction axis",
|
||||||
|
action="append",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--transpose",
|
||||||
|
type=none_or_list,
|
||||||
|
default=[],
|
||||||
|
help="Permute the matrix",
|
||||||
|
action="append",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--print-pid", action="store_true", help="Print the PID and pause"
|
||||||
|
)
|
||||||
|
parser.add_argument("--cpu", action="store_true", help="Use the CPU")
|
||||||
|
parser.add_argument(
|
||||||
|
"--fused", action="store_true", help="Use fused functions where possible"
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--dtype", choices=["float32", "float16", "bfloat16"], default="float32"
|
||||||
|
)
|
||||||
|
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
if len(args.size) > 1:
|
||||||
|
args.size.pop(0)
|
||||||
|
if len(args.axis) > 1:
|
||||||
|
args.axis.pop(0)
|
||||||
|
|
||||||
|
if args.print_pid:
|
||||||
|
print(os.getpid())
|
||||||
|
input("Press enter to run")
|
||||||
|
|
||||||
|
if args.cpu:
|
||||||
|
mx.set_default_device(mx.cpu)
|
||||||
|
else:
|
||||||
|
mx.set_default_device(mx.gpu)
|
||||||
|
dtype = dict(float32=mx.float32, float16=mx.float16, bfloat16=mx.bfloat16)[
|
||||||
|
args.dtype
|
||||||
|
]
|
||||||
|
xs = []
|
||||||
|
for size in args.size:
|
||||||
|
xs.append(mx.random.normal(size).astype(dtype))
|
||||||
|
for i, t in enumerate(args.transpose):
|
||||||
|
if t is None:
|
||||||
|
continue
|
||||||
|
xs[i] = mx.transpose(xs[i], t)
|
||||||
|
mx.eval(xs)
|
||||||
|
x = xs[0]
|
||||||
|
axis = args.axis[0]
|
||||||
|
|
||||||
|
if args.benchmark == "matmul_square":
|
||||||
|
print(bench(matmul_square, x))
|
||||||
|
|
||||||
|
elif args.benchmark == "matmul":
|
||||||
|
print(bench(matmul, *xs))
|
||||||
|
|
||||||
|
elif args.benchmark == "linear":
|
||||||
|
print(bench(linear, *xs))
|
||||||
|
|
||||||
|
elif args.benchmark == "sum_axis":
|
||||||
|
print(bench(reduction, "sum", axis, x))
|
||||||
|
|
||||||
|
elif args.benchmark == "sum_all":
|
||||||
|
print(bench(reduction, "sum", None, x))
|
||||||
|
|
||||||
|
elif args.benchmark == "argmax":
|
||||||
|
print(bench(reduction, "argmax", axis, x))
|
||||||
|
|
||||||
|
elif args.benchmark == "add":
|
||||||
|
print(bench(binary, "add", *xs))
|
||||||
|
|
||||||
|
elif args.benchmark == "mul":
|
||||||
|
print(bench(binary, "multiply", *xs))
|
||||||
|
|
||||||
|
elif args.benchmark == "softmax":
|
||||||
|
if args.fused:
|
||||||
|
print(bench(softmax_fused, axis, x))
|
||||||
|
else:
|
||||||
|
print(bench(softmax, axis, x))
|
||||||
|
|
||||||
|
elif args.benchmark == "relu":
|
||||||
|
print(bench(relu, x))
|
||||||
|
|
||||||
|
elif args.benchmark == "scalar_mul":
|
||||||
|
print(bench(scalar_mult, x))
|
||||||
|
|
||||||
|
elif args.benchmark == "cross_entropy":
|
||||||
|
if len(size) != 2:
|
||||||
|
raise ValueError("Error: [cross_entropy] benchmark requires a 2 dim size")
|
||||||
|
|
||||||
|
targets = mx.zeros((len(x),), dtype=mx.uint32)
|
||||||
|
print(bench(cross_entropy, targets, x))
|
||||||
|
|
||||||
|
elif args.benchmark == "logsumexp":
|
||||||
|
print(bench(logsumexp, axis, x))
|
||||||
|
|
||||||
|
elif args.benchmark == "rope":
|
||||||
|
print(bench(rope, x))
|
||||||
|
|
||||||
|
elif args.benchmark == "concatenate":
|
||||||
|
print(bench(concatenate, axis, *xs))
|
||||||
|
|
||||||
|
elif args.benchmark == "cumsum":
|
||||||
|
print(bench(cumsum, axis, *xs))
|
||||||
|
|
||||||
|
elif args.benchmark == "conv1d":
|
||||||
|
print(bench(conv1d, *xs))
|
||||||
|
|
||||||
|
elif args.benchmark == "conv2d":
|
||||||
|
print(bench(conv2d, *xs))
|
||||||
|
|
||||||
|
elif args.benchmark == "sort":
|
||||||
|
print(bench(sort, axis, x))
|
||||||
|
|
||||||
|
elif args.benchmark == "topk":
|
||||||
|
print(bench(topk, axis, x))
|
||||||
|
|
||||||
|
else:
|
||||||
|
raise ValueError("Unknown benchmark")
|
||||||
338
benchmarks/python/comparative/bench_torch.py
Normal file
338
benchmarks/python/comparative/bench_torch.py
Normal file
@@ -0,0 +1,338 @@
|
|||||||
|
import argparse
|
||||||
|
import os
|
||||||
|
import time
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.mps
|
||||||
|
|
||||||
|
|
||||||
|
def int_or_list(x):
|
||||||
|
try:
|
||||||
|
return int(x)
|
||||||
|
except ValueError:
|
||||||
|
return [int(xi) for xi in x.split(",")]
|
||||||
|
|
||||||
|
|
||||||
|
def none_or_list(x):
|
||||||
|
if x == "":
|
||||||
|
return None
|
||||||
|
else:
|
||||||
|
return [int(xi) for xi in x.split(",")]
|
||||||
|
|
||||||
|
|
||||||
|
def bench(f, *args):
|
||||||
|
for i in range(10):
|
||||||
|
f(*args)
|
||||||
|
|
||||||
|
s = time.time()
|
||||||
|
for i in range(100):
|
||||||
|
f(*args)
|
||||||
|
e = time.time()
|
||||||
|
return e - s
|
||||||
|
|
||||||
|
|
||||||
|
def sync_if_needed(x):
|
||||||
|
if x.device != torch.device("cpu"):
|
||||||
|
torch.mps.synchronize()
|
||||||
|
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def matmul_square(x):
|
||||||
|
y = x
|
||||||
|
for i in range(10):
|
||||||
|
y = y @ x
|
||||||
|
sync_if_needed(x)
|
||||||
|
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def matmul(x, y):
|
||||||
|
ys = []
|
||||||
|
for i in range(10):
|
||||||
|
ys.append(x @ y)
|
||||||
|
sync_if_needed(x)
|
||||||
|
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def conv1d(x, y):
|
||||||
|
x = torch.transpose(x, -1, -2)
|
||||||
|
y = torch.transpose(y, -1, -2)
|
||||||
|
ys = []
|
||||||
|
for i in range(10):
|
||||||
|
ys.append(torch.nn.functional.conv1d(x, y))
|
||||||
|
sync_if_needed(x)
|
||||||
|
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def conv2d(x, y):
|
||||||
|
x = torch.permute(x, (0, 3, 1, 2))
|
||||||
|
y = torch.permute(y, (0, 3, 1, 2))
|
||||||
|
ys = []
|
||||||
|
for i in range(10):
|
||||||
|
ys.append(torch.nn.functional.conv2d(x, y))
|
||||||
|
sync_if_needed(x)
|
||||||
|
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def binary(op, x, y):
|
||||||
|
for i in range(100):
|
||||||
|
y = getattr(torch, op)(x, y)
|
||||||
|
sync_if_needed(x)
|
||||||
|
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def reduction(op, axis, x):
|
||||||
|
ys = []
|
||||||
|
for i in range(100):
|
||||||
|
ys.append(getattr(x, op)(axis))
|
||||||
|
sync_if_needed(x)
|
||||||
|
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def softmax(axis, x):
|
||||||
|
ys = []
|
||||||
|
for i in range(100):
|
||||||
|
ex = torch.exp(x - torch.max(x, dim=axis, keepdims=True).values)
|
||||||
|
y = ex / torch.sum(ex, dim=axis, keepdims=True)
|
||||||
|
ys.append(y)
|
||||||
|
sync_if_needed(x)
|
||||||
|
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def softmax_fused(axis, x):
|
||||||
|
ys = []
|
||||||
|
for i in range(100):
|
||||||
|
ys.append(torch.nn.functional.softmax(x, dim=axis))
|
||||||
|
sync_if_needed(x)
|
||||||
|
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def relu(x):
|
||||||
|
y = x
|
||||||
|
for i in range(100):
|
||||||
|
y = torch.nn.functional.relu(y)
|
||||||
|
sync_if_needed(x)
|
||||||
|
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def scalar_mult(x):
|
||||||
|
y = x
|
||||||
|
for i in range(100):
|
||||||
|
y = y * (1.0 / (1 + i))
|
||||||
|
sync_if_needed(x)
|
||||||
|
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def cross_entropy(targets, x):
|
||||||
|
ys = []
|
||||||
|
for i in range(100):
|
||||||
|
ys.append(torch.nn.functional.cross_entropy(x, targets))
|
||||||
|
sync_if_needed(x)
|
||||||
|
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def logsumexp(axis, x):
|
||||||
|
ys = []
|
||||||
|
for i in range(100):
|
||||||
|
ys.append(torch.logsumexp(x, dim=axis))
|
||||||
|
sync_if_needed(x)
|
||||||
|
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def linear_fused(w, b, x):
|
||||||
|
ys = []
|
||||||
|
for i in range(10):
|
||||||
|
ys.append(torch.nn.functional.linear(x, w, b))
|
||||||
|
sync_if_needed(x)
|
||||||
|
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def linear(w, b, x):
|
||||||
|
ys = []
|
||||||
|
for i in range(10):
|
||||||
|
ys.append((x @ torch.transpose(w, -2, -1)) + b)
|
||||||
|
sync_if_needed(x)
|
||||||
|
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def rope(x):
|
||||||
|
*_, N, D = x.shape
|
||||||
|
ys = []
|
||||||
|
for i in range(10):
|
||||||
|
x = x.view(-1, N, D)
|
||||||
|
positions = torch.arange(N, device=x.device)
|
||||||
|
freqs = 10000 ** torch.linspace(0, 1, D // 2, device=x.device)
|
||||||
|
theta = positions[:, None] * freqs[None]
|
||||||
|
costheta = torch.cos(theta)
|
||||||
|
sintheta = torch.sin(theta)
|
||||||
|
x1 = x[..., ::2]
|
||||||
|
x2 = x[..., 1::2]
|
||||||
|
rx1 = x1 * costheta - x2 * sintheta
|
||||||
|
rx2 = x1 * sintheta + x2 * costheta
|
||||||
|
y = torch.cat([rx1[..., None], rx2[..., None]], dim=-1)
|
||||||
|
y = y.reshape(-1, N, D)
|
||||||
|
ys.append(y)
|
||||||
|
sync_if_needed(x)
|
||||||
|
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def concatenate(axis, x, y):
|
||||||
|
ys = []
|
||||||
|
for i in range(10):
|
||||||
|
ys.append(torch.cat([x, y], dim=axis))
|
||||||
|
sync_if_needed(x)
|
||||||
|
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def cumsum(axis, x):
|
||||||
|
ys = []
|
||||||
|
for i in range(10):
|
||||||
|
ys.append(x.cumsum(axis))
|
||||||
|
sync_if_needed(x)
|
||||||
|
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def sort(axis, x):
|
||||||
|
ys = []
|
||||||
|
for i in range(10):
|
||||||
|
ys.append(torch.sort(x, dim=axis)[0])
|
||||||
|
sync_if_needed(x)
|
||||||
|
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def topk(axis, x):
|
||||||
|
k = x.shape[axis] // 3
|
||||||
|
ys = []
|
||||||
|
for i in range(10):
|
||||||
|
ys.append(torch.topk(x, k, dim=axis)[0])
|
||||||
|
sync_if_needed(x)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
parser = argparse.ArgumentParser()
|
||||||
|
parser.add_argument("benchmark", help="Choose the benchmark to run")
|
||||||
|
parser.add_argument(
|
||||||
|
"--size",
|
||||||
|
default=[(1024, 1024)],
|
||||||
|
type=lambda x: list(map(int, x.split("x"))),
|
||||||
|
help="Set the matrix size",
|
||||||
|
action="append",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--axis",
|
||||||
|
default=[1],
|
||||||
|
type=int_or_list,
|
||||||
|
help="Set a reduction axis",
|
||||||
|
action="append",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--transpose",
|
||||||
|
type=none_or_list,
|
||||||
|
default=[],
|
||||||
|
help="Permute the matrix",
|
||||||
|
action="append",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--print-pid", action="store_true", help="Print the PID and pause"
|
||||||
|
)
|
||||||
|
parser.add_argument("--cpu", action="store_true", help="Use the CPU")
|
||||||
|
parser.add_argument(
|
||||||
|
"--fused", action="store_true", help="Use fused functions where possible"
|
||||||
|
)
|
||||||
|
parser.add_argument("--dtype", choices=["float32", "float16"], default="float32")
|
||||||
|
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
if len(args.size) > 1:
|
||||||
|
args.size.pop(0)
|
||||||
|
if len(args.axis) > 1:
|
||||||
|
args.axis.pop(0)
|
||||||
|
|
||||||
|
if args.print_pid:
|
||||||
|
print(os.getpid())
|
||||||
|
input("Press enter to run")
|
||||||
|
|
||||||
|
torch.set_num_threads(1)
|
||||||
|
device = "cpu" if args.cpu else "mps"
|
||||||
|
dtype = dict(float32=torch.float32, float16=torch.float16)[args.dtype]
|
||||||
|
xs = []
|
||||||
|
for size in args.size:
|
||||||
|
xs.append(torch.randn(*size).to(device).to(dtype))
|
||||||
|
for i, t in enumerate(args.transpose):
|
||||||
|
if t is None:
|
||||||
|
continue
|
||||||
|
xs[i] = xs[i].permute(*t)
|
||||||
|
x = xs[0]
|
||||||
|
axis = args.axis[0]
|
||||||
|
|
||||||
|
if args.benchmark == "matmul_square":
|
||||||
|
print(bench(matmul_square, x))
|
||||||
|
|
||||||
|
elif args.benchmark == "matmul":
|
||||||
|
print(bench(matmul, *xs))
|
||||||
|
|
||||||
|
elif args.benchmark == "linear":
|
||||||
|
if args.fused:
|
||||||
|
print(bench(linear_fused, *xs))
|
||||||
|
else:
|
||||||
|
print(bench(linear, *xs))
|
||||||
|
|
||||||
|
elif args.benchmark == "sum_axis":
|
||||||
|
print(bench(reduction, "sum", axis, x))
|
||||||
|
|
||||||
|
elif args.benchmark == "sum_all":
|
||||||
|
print(bench(reduction, "sum", None, x))
|
||||||
|
|
||||||
|
elif args.benchmark == "argmax":
|
||||||
|
print(bench(reduction, "argmax", axis, x))
|
||||||
|
|
||||||
|
elif args.benchmark == "add":
|
||||||
|
print(bench(binary, "add", *xs))
|
||||||
|
|
||||||
|
elif args.benchmark == "mul":
|
||||||
|
print(bench(binary, "mul", *xs))
|
||||||
|
|
||||||
|
elif args.benchmark == "softmax":
|
||||||
|
if args.fused:
|
||||||
|
print(bench(softmax_fused, axis, x))
|
||||||
|
else:
|
||||||
|
print(bench(softmax, axis, x))
|
||||||
|
|
||||||
|
elif args.benchmark == "relu":
|
||||||
|
print(bench(relu, x))
|
||||||
|
|
||||||
|
elif args.benchmark == "scalar_mul":
|
||||||
|
print(bench(scalar_mult, x))
|
||||||
|
|
||||||
|
elif args.benchmark == "cross_entropy":
|
||||||
|
if len(size) != 2:
|
||||||
|
raise ValueError("Error: [cross_entropy] benchmark requires a 2 dim size")
|
||||||
|
|
||||||
|
targets = torch.zeros(len(x), dtype=torch.long).to(x.device)
|
||||||
|
print(bench(cross_entropy, targets, x))
|
||||||
|
|
||||||
|
elif args.benchmark == "logsumexp":
|
||||||
|
print(bench(logsumexp, axis, x))
|
||||||
|
|
||||||
|
elif args.benchmark == "rope":
|
||||||
|
print(bench(rope, x))
|
||||||
|
|
||||||
|
elif args.benchmark == "concatenate":
|
||||||
|
print(bench(concatenate, axis, *xs))
|
||||||
|
|
||||||
|
elif args.benchmark == "cumsum":
|
||||||
|
print(bench(cumsum, axis, *xs))
|
||||||
|
|
||||||
|
elif args.benchmark == "conv1d":
|
||||||
|
print(bench(conv1d, *xs))
|
||||||
|
|
||||||
|
elif args.benchmark == "conv2d":
|
||||||
|
print(bench(conv2d, *xs))
|
||||||
|
|
||||||
|
elif args.benchmark == "sort":
|
||||||
|
print(bench(sort, axis, x))
|
||||||
|
|
||||||
|
elif args.benchmark == "topk":
|
||||||
|
print(bench(topk, axis, x))
|
||||||
|
|
||||||
|
else:
|
||||||
|
raise ValueError("Unknown benchmark")
|
||||||
253
benchmarks/python/comparative/compare.py
Normal file
253
benchmarks/python/comparative/compare.py
Normal file
@@ -0,0 +1,253 @@
|
|||||||
|
#!/usr/bin/env python
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import re
|
||||||
|
from pathlib import Path
|
||||||
|
from subprocess import run
|
||||||
|
|
||||||
|
BENCH_MLX = Path(__file__).parent / "bench_mlx.py"
|
||||||
|
BENCH_TORCH = Path(__file__).parent / "bench_torch.py"
|
||||||
|
|
||||||
|
|
||||||
|
def run_or_raise(*args, **kwargs):
|
||||||
|
try:
|
||||||
|
result = run(*args, capture_output=True, **kwargs)
|
||||||
|
return float(result.stdout)
|
||||||
|
except ValueError:
|
||||||
|
raise ValueError(f"stdout: {result.stdout}\nstderr: {result.stderr}")
|
||||||
|
|
||||||
|
|
||||||
|
def compare(args):
|
||||||
|
t_mlx = run_or_raise(["python", BENCH_MLX] + args)
|
||||||
|
t_torch = run_or_raise(["python", BENCH_TORCH] + args)
|
||||||
|
|
||||||
|
print((t_torch - t_mlx) / t_torch, " ".join(args), sep="\t")
|
||||||
|
|
||||||
|
|
||||||
|
def compare_mlx_dtypes(args, dt1, dt2):
|
||||||
|
t_mlx_dt1 = run_or_raise(["python", BENCH_MLX] + args + ["--dtype", dt1])
|
||||||
|
t_mlx_dt2 = run_or_raise(["python", BENCH_MLX] + args + ["--dtype", dt2])
|
||||||
|
|
||||||
|
print((t_mlx_dt2 - t_mlx_dt1) / t_mlx_dt2, " ".join(args), sep="\t")
|
||||||
|
|
||||||
|
|
||||||
|
def make_regex_search(regexes):
|
||||||
|
compiled_regexes = list(map(re.compile, regexes))
|
||||||
|
|
||||||
|
def search(x):
|
||||||
|
return (c.search(x) is not None for c in compiled_regexes)
|
||||||
|
|
||||||
|
return search
|
||||||
|
|
||||||
|
|
||||||
|
def make_predicate(positive_filter, negative_filter):
|
||||||
|
if positive_filter is not None:
|
||||||
|
positive_filter_search = make_regex_search(positive_filter)
|
||||||
|
positive_filter = lambda x: all(positive_filter_search(x))
|
||||||
|
else:
|
||||||
|
positive_filter = lambda x: True
|
||||||
|
|
||||||
|
if negative_filter is not None:
|
||||||
|
negative_filter_search = make_regex_search(negative_filter)
|
||||||
|
negative_filter = lambda x: not any(negative_filter_search(x))
|
||||||
|
else:
|
||||||
|
negative_filter = lambda x: True
|
||||||
|
|
||||||
|
def predicate(x):
|
||||||
|
return positive_filter(x) and negative_filter(x)
|
||||||
|
|
||||||
|
return predicate
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
parser = argparse.ArgumentParser(description="Run comparisons agains PyTorch")
|
||||||
|
parser.add_argument(
|
||||||
|
"--filter", "-f", help="Regex filter to select benchmarks", nargs="+"
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--negative_filter", "-n", help="Regex filter to remove benchmarks", nargs="+"
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--mlx_dtypes",
|
||||||
|
"-d",
|
||||||
|
help="Compare mlx benchmarks between the 2 provided data types",
|
||||||
|
nargs=2,
|
||||||
|
)
|
||||||
|
args, rest = parser.parse_known_args()
|
||||||
|
|
||||||
|
_filter = make_predicate(args.filter, args.negative_filter)
|
||||||
|
|
||||||
|
if args.mlx_dtypes:
|
||||||
|
compare_filtered = (
|
||||||
|
lambda x: compare_mlx_dtypes(
|
||||||
|
x.split() + rest, args.mlx_dtypes[0], args.mlx_dtypes[1]
|
||||||
|
)
|
||||||
|
if _filter(x)
|
||||||
|
else None
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
compare_filtered = lambda x: compare(x.split() + rest) if _filter(x) else None
|
||||||
|
|
||||||
|
# Binary ops
|
||||||
|
compare_filtered("add --size 10x1024x128 --size 1x1024x128 --cpu")
|
||||||
|
compare_filtered("add --size 10x1024x128 --size 1x1024x128")
|
||||||
|
compare_filtered("add --size 1024x128 --size 1x128 --cpu")
|
||||||
|
compare_filtered("add --size 1024x128 --size 1x128")
|
||||||
|
compare_filtered("add --size 1024x4096 --size 1x4096 --cpu")
|
||||||
|
compare_filtered("add --size 1024x4096 --size 1x4096")
|
||||||
|
compare_filtered("add --size 1024x4096 --size 1x1024 --transpose 1,0 --cpu")
|
||||||
|
compare_filtered("add --size 1024x4096 --size 1x1024 --transpose 1,0")
|
||||||
|
compare_filtered("add --size 1024x1024 --size 1024x1024 --cpu")
|
||||||
|
compare_filtered("add --size 1024x1024 --size 1024x1024")
|
||||||
|
compare_filtered("add --size 1024x1024 --size 1024x1024 --transpose 1,0 --cpu")
|
||||||
|
compare_filtered("add --size 1024x1024 --size 1024x1024 --transpose 1,0")
|
||||||
|
compare_filtered(
|
||||||
|
"add --size 1024x1024 --size 1024x1024 --transpose 1,0 --transpose 1,0 --cpu"
|
||||||
|
)
|
||||||
|
compare_filtered(
|
||||||
|
"add --size 1024x1024 --size 1024x1024 --transpose 1,0 --transpose 1,0"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Reduction ops
|
||||||
|
compare_filtered("sum_all --size 10x1024x128 --cpu")
|
||||||
|
compare_filtered("sum_all --size 10x1024x128")
|
||||||
|
compare_filtered("sum_axis --size 16x1024x128 --axis 2 --cpu")
|
||||||
|
compare_filtered("sum_axis --size 16x1024x128 --axis 2")
|
||||||
|
compare_filtered("sum_axis --size 16x128x1024 --axis 2 --cpu")
|
||||||
|
compare_filtered("sum_axis --size 16x128x1024 --axis 2")
|
||||||
|
compare_filtered("sum_axis --size 1024x1024 --axis 1 --cpu")
|
||||||
|
compare_filtered("sum_axis --size 1024x1024 --axis 1")
|
||||||
|
compare_filtered("sum_axis --size 1024x1024 --axis 0 --cpu")
|
||||||
|
compare_filtered("sum_axis --size 1024x1024 --axis 0")
|
||||||
|
compare_filtered("sum_axis --size 16x128x1024 --axis 1 --cpu")
|
||||||
|
compare_filtered("sum_axis --size 16x128x1024 --axis 1")
|
||||||
|
compare_filtered("sum_axis --size 16x128x1024 --axis 0 --cpu")
|
||||||
|
compare_filtered("sum_axis --size 16x128x1024 --axis 0")
|
||||||
|
compare_filtered("argmax --size 10x1024x128 --axis 1 --cpu")
|
||||||
|
compare_filtered("argmax --size 10x1024x128 --axis 1")
|
||||||
|
compare_filtered("argmax --size 10x1024x128 --axis 2 --cpu")
|
||||||
|
compare_filtered("argmax --size 10x1024x128 --axis 2")
|
||||||
|
compare_filtered("argmax --size 1024x1024 --axis 1 --cpu")
|
||||||
|
compare_filtered("argmax --size 1024x1024 --axis 1")
|
||||||
|
|
||||||
|
# Matmul ops
|
||||||
|
compare_filtered("matmul_square --size 1024x1024")
|
||||||
|
compare_filtered("matmul_square --size 1024x1024 --cpu")
|
||||||
|
compare_filtered("matmul_square --size 16x1024x1024")
|
||||||
|
compare_filtered("matmul_square --size 16x1024x1024 --cpu")
|
||||||
|
compare_filtered(
|
||||||
|
"matmul --size 16x768x768 --size 16x768x768 --transpose= --transpose 0,2,1"
|
||||||
|
)
|
||||||
|
compare_filtered(
|
||||||
|
"matmul --size 16x768x768 --size 16x768x768 --transpose= --transpose 0,2,1 --cpu"
|
||||||
|
)
|
||||||
|
compare_filtered(
|
||||||
|
"matmul --size 16x768x128 --size 16x768x128 --transpose= --transpose 0,2,1"
|
||||||
|
)
|
||||||
|
compare_filtered(
|
||||||
|
"matmul --size 16x768x128 --size 16x768x128 --transpose= --transpose 0,2,1 --cpu"
|
||||||
|
)
|
||||||
|
compare_filtered("matmul --size 512x8192 --size 8192x512")
|
||||||
|
compare_filtered("matmul --size 512x8192 --size 8192x512 --cpu")
|
||||||
|
# compare_filtered("matmul --size 512x131072 --size 131072x512")
|
||||||
|
# compare_filtered("matmul --size 512x131072 --size 131072x512 --cpu")
|
||||||
|
compare_filtered("matmul --size 8192x512 --size 512x8192")
|
||||||
|
compare_filtered("matmul --size 8192x512 --size 512x8192 --cpu")
|
||||||
|
# compare_filtered("matmul --size 131072x512 --size 512x512")
|
||||||
|
# compare_filtered("matmul --size 131072x512 --size 512x512 --cpu")
|
||||||
|
compare_filtered("linear --size 1024x1024 --size 1024 --size 128x1024")
|
||||||
|
compare_filtered("linear --size 1024x1024 --size 1024 --size 128x1024 --cpu")
|
||||||
|
compare_filtered("linear --size 1024x1024 --size 1024 --size 128x1024 --fused")
|
||||||
|
compare_filtered(
|
||||||
|
"linear --size 1024x1024 --size 1024 --size 128x1024 --fused --cpu"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Matvec ops
|
||||||
|
compare_filtered("matmul --size 1x1x4096 --size 4096x4096 --cpu")
|
||||||
|
compare_filtered("matmul --size 1x1x4096 --size 4096x4096")
|
||||||
|
compare_filtered(
|
||||||
|
"matmul --size 1x1x4096 --size 4096x4096 --transpose= --transpose 1,0 --cpu"
|
||||||
|
)
|
||||||
|
compare_filtered(
|
||||||
|
"matmul --size 1x1x4096 --size 4096x4096 --transpose= --transpose 1,0"
|
||||||
|
)
|
||||||
|
compare_filtered("matmul --size 32x1x1000 --size 32x1000x128 --cpu")
|
||||||
|
compare_filtered("matmul --size 32x1x1000 --size 32x1000x128")
|
||||||
|
compare_filtered(
|
||||||
|
"matmul --size 32x1x1000 --size 32x128x1000 --transpose= --transpose 0,2,1 --cpu"
|
||||||
|
)
|
||||||
|
compare_filtered(
|
||||||
|
"matmul --size 32x1x1000 --size 32x128x1000 --transpose= --transpose 0,2,1"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Various ops
|
||||||
|
compare_filtered("softmax --size 32x16x1024 --axis 2")
|
||||||
|
compare_filtered("softmax --size 32x16x1024 --axis 2 --cpu")
|
||||||
|
compare_filtered("softmax --size 32x16x1024 --axis 2 --fused")
|
||||||
|
compare_filtered("softmax --size 32x16x1024 --axis 2 --fused --cpu")
|
||||||
|
compare_filtered("softmax --size 2x1024x1024 --axis 1")
|
||||||
|
compare_filtered("softmax --size 2x1024x1024 --axis 1 --cpu")
|
||||||
|
compare_filtered("softmax --size 2x1024x1024 --axis 1 --fused")
|
||||||
|
compare_filtered("softmax --size 2x1024x1024 --axis 1 --fused --cpu")
|
||||||
|
compare_filtered("relu --size 32x16x1024")
|
||||||
|
compare_filtered("relu --size 32x16x1024 --cpu")
|
||||||
|
compare_filtered("scalar_mul --size 32x16x1024")
|
||||||
|
compare_filtered("scalar_mul --size 32x16x1024 --cpu")
|
||||||
|
compare_filtered("cross_entropy --size 256x1024")
|
||||||
|
compare_filtered("cross_entropy --size 256x1024 --cpu")
|
||||||
|
compare_filtered("logsumexp --size 1024x1024 --axis 1")
|
||||||
|
compare_filtered("logsumexp --size 1024x1024 --axis 1 --cpu")
|
||||||
|
compare_filtered("logsumexp --size 1024x1024 --axis 0")
|
||||||
|
compare_filtered("logsumexp --size 1024x1024 --axis 0 --cpu")
|
||||||
|
compare_filtered("concatenate --size 32x1024x128 --size 32x1024x128 --axis 2")
|
||||||
|
compare_filtered("concatenate --size 32x1024x128 --size 32x1024x128 --axis 2 --cpu")
|
||||||
|
compare_filtered("concatenate --size 32x1024x128 --size 32x1024x128 --axis 1")
|
||||||
|
compare_filtered("concatenate --size 32x1024x128 --size 32x1024x128 --axis 1 --cpu")
|
||||||
|
compare_filtered("concatenate --size 32x1024x128 --size 32x1024x128 --axis 0")
|
||||||
|
compare_filtered("concatenate --size 32x1024x128 --size 32x1024x128 --axis 0 --cpu")
|
||||||
|
compare_filtered("concatenate --size 32x1024x128 --size 32x16x128 --axis 1")
|
||||||
|
compare_filtered("concatenate --size 32x1024x128 --size 32x16x128 --axis 1 --cpu")
|
||||||
|
compare_filtered("concatenate --size 32x1024x128 --size 32x1x128 --axis 1")
|
||||||
|
compare_filtered("concatenate --size 32x1024x128 --size 32x1x128 --axis 1 --cpu")
|
||||||
|
compare_filtered("concatenate --size 1x32x1024x128 --size 1x32x1x128 --axis 2")
|
||||||
|
compare_filtered(
|
||||||
|
"concatenate --size 1x32x1024x128 --size 1x32x1x128 --axis 2 --cpu"
|
||||||
|
)
|
||||||
|
compare_filtered("conv1d --size 1x1000x80 --size 128x11x80")
|
||||||
|
compare_filtered("conv1d --size 1x1000x80 --size 128x11x80 --cpu")
|
||||||
|
compare_filtered("conv1d --size 16x1000x80 --size 128x11x80")
|
||||||
|
compare_filtered("conv1d --size 4x1000x80 --size 128x11x80 --cpu")
|
||||||
|
compare_filtered("conv2d --size 1x256x256x3 --size 8x3x3x3")
|
||||||
|
compare_filtered("conv2d --size 1x256x256x3 --size 8x3x3x3 --cpu")
|
||||||
|
compare_filtered("conv2d --size 16x256x256x3 --size 8x3x3x3")
|
||||||
|
compare_filtered("conv2d --size 4x256x256x3 --size 8x3x3x3 --cpu")
|
||||||
|
compare_filtered("cumsum --size 1024x1024 --axis 1 --cpu")
|
||||||
|
compare_filtered("cumsum --size 1024x1024 --axis 0 --cpu")
|
||||||
|
compare_filtered("cumsum --size 1024x1024 --axis 1")
|
||||||
|
compare_filtered("cumsum --size 1024x1024 --axis 0")
|
||||||
|
compare_filtered("cumsum --size 128x1024 --axis 1")
|
||||||
|
compare_filtered("cumsum --size 128x1024 --axis 0")
|
||||||
|
compare_filtered("cumsum --size 1024x4096 --axis 1")
|
||||||
|
compare_filtered("cumsum --size 1024x4096 --axis 0")
|
||||||
|
compare_filtered("cumsum --size 128x4096 --axis 1")
|
||||||
|
compare_filtered("cumsum --size 128x4096 --axis 0")
|
||||||
|
compare_filtered("cumsum --size 1024x7777 --axis 1")
|
||||||
|
compare_filtered("cumsum --size 1024x7777 --axis 0")
|
||||||
|
compare_filtered("cumsum --size 128x7777 --axis 1")
|
||||||
|
compare_filtered("cumsum --size 128x7777 --axis 0")
|
||||||
|
compare_filtered("cumsum --size 32768x128 --axis 1")
|
||||||
|
compare_filtered("cumsum --size 32768x128 --axis 0")
|
||||||
|
|
||||||
|
compare_filtered("sort --size 1024x1024 --axis 0")
|
||||||
|
compare_filtered("sort --size 1024x1024 --axis 1")
|
||||||
|
compare_filtered("sort --size 32768x128 --axis 0")
|
||||||
|
compare_filtered("sort --size 32768x128 --axis 1")
|
||||||
|
compare_filtered("sort --size 128x128 --axis 0 --cpu")
|
||||||
|
compare_filtered("sort --size 128x128 --axis 1 --cpu")
|
||||||
|
|
||||||
|
compare_filtered("topk --size 1024x1024 --axis 0")
|
||||||
|
compare_filtered("topk --size 1024x1024 --axis 1")
|
||||||
|
compare_filtered("topk --size 32768x128 --axis 0")
|
||||||
|
compare_filtered("topk --size 32768x128 --axis 1")
|
||||||
|
compare_filtered("topk --size 128x128 --axis 0 --cpu")
|
||||||
|
compare_filtered("topk --size 128x128 --axis 1 --cpu")
|
||||||
196
benchmarks/python/llama_jax_bench.py
Normal file
196
benchmarks/python/llama_jax_bench.py
Normal file
@@ -0,0 +1,196 @@
|
|||||||
|
import math
|
||||||
|
import time
|
||||||
|
|
||||||
|
import jax
|
||||||
|
import jax.numpy as jnp
|
||||||
|
from flax import linen as nn
|
||||||
|
|
||||||
|
|
||||||
|
class RoPE(nn.Module):
|
||||||
|
dims: int
|
||||||
|
traditional: bool = False
|
||||||
|
|
||||||
|
def _compute_rope(self, costheta, sintheta, x):
|
||||||
|
x1 = x[..., : self.dims // 2]
|
||||||
|
x2 = x[..., self.dims // 2 : self.dims]
|
||||||
|
rx1 = x1 * costheta - x2 * sintheta
|
||||||
|
rx2 = x1 * sintheta + x2 * costheta
|
||||||
|
|
||||||
|
if self.dims < x.shape[-1]:
|
||||||
|
rx = jnp.concatenate([rx1, rx2, x[..., self.dims :]], axis=-1)
|
||||||
|
else:
|
||||||
|
rx = jnp.concatenate([rx1, rx2], axis=-1)
|
||||||
|
|
||||||
|
return rx
|
||||||
|
|
||||||
|
def _compute_traditional_rope(self, costheta, sintheta, x):
|
||||||
|
x1 = x[..., ::2]
|
||||||
|
x2 = x[..., 1::2]
|
||||||
|
rx1 = x1 * costheta - x2 * sintheta
|
||||||
|
rx2 = x1 * sintheta + x2 * costheta
|
||||||
|
|
||||||
|
if self.dims < x.shape[-1]:
|
||||||
|
raise NotImplementedError(
|
||||||
|
"RoPE doesn't implement partial traditional application"
|
||||||
|
)
|
||||||
|
|
||||||
|
rx = jnp.concatenate([rx1[..., None], rx2[..., None]], axis=-1)
|
||||||
|
|
||||||
|
return rx
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def create_cos_sin_theta(
|
||||||
|
N: int,
|
||||||
|
D: int,
|
||||||
|
offset: int = 0,
|
||||||
|
base: float = 10000,
|
||||||
|
dtype=jnp.float32,
|
||||||
|
):
|
||||||
|
D = D // 2
|
||||||
|
positions = jnp.arange(offset, N, dtype=dtype)
|
||||||
|
freqs = jnp.exp(-jnp.arange(0, D, dtype=dtype) * (math.log(base) / D))
|
||||||
|
theta = positions.reshape((-1, 1)) * freqs.reshape((1, -1))
|
||||||
|
costheta = jnp.cos(theta)
|
||||||
|
sintheta = jnp.sin(theta)
|
||||||
|
|
||||||
|
return costheta, sintheta
|
||||||
|
|
||||||
|
@nn.compact
|
||||||
|
def __call__(self, x, offset: int = 0):
|
||||||
|
shape = x.shape
|
||||||
|
x = x.reshape((-1, shape[-2], shape[-1]))
|
||||||
|
N = x.shape[1] + offset
|
||||||
|
costheta, sintheta = RoPE.create_cos_sin_theta(
|
||||||
|
N, self.dims, offset=offset, dtype=x.dtype
|
||||||
|
)
|
||||||
|
|
||||||
|
rope = (
|
||||||
|
self._compute_traditional_rope if self.traditional else self._compute_rope
|
||||||
|
)
|
||||||
|
rx = rope(costheta, sintheta, x)
|
||||||
|
|
||||||
|
return rx.reshape(shape)
|
||||||
|
|
||||||
|
|
||||||
|
class LlamaAttention(nn.Module):
|
||||||
|
dims: int
|
||||||
|
num_heads: int
|
||||||
|
dtype: jnp.dtype
|
||||||
|
|
||||||
|
def setup(self):
|
||||||
|
num_heads = self.num_heads
|
||||||
|
dims = self.dims
|
||||||
|
|
||||||
|
self.rope = RoPE(dims // num_heads, True)
|
||||||
|
self.query_proj = nn.Dense(dims, use_bias=False, param_dtype=self.dtype)
|
||||||
|
self.key_proj = nn.Dense(dims, use_bias=False, param_dtype=self.dtype)
|
||||||
|
self.value_proj = nn.Dense(dims, use_bias=False, param_dtype=self.dtype)
|
||||||
|
self.out_proj = nn.Dense(dims, use_bias=False, param_dtype=self.dtype)
|
||||||
|
|
||||||
|
def __call__(self, queries, keys, values, mask=None, cache=None):
|
||||||
|
queries = self.query_proj(queries)
|
||||||
|
keys = self.key_proj(keys)
|
||||||
|
values = self.value_proj(values)
|
||||||
|
|
||||||
|
num_heads = self.num_heads
|
||||||
|
B, L, D = queries.shape
|
||||||
|
queries = queries.reshape((B, L, num_heads, -1)).transpose((0, 2, 1, 3))
|
||||||
|
keys = keys.reshape((B, L, num_heads, -1)).transpose((0, 2, 1, 3))
|
||||||
|
values = values.reshape((B, L, num_heads, -1)).transpose((0, 2, 1, 3))
|
||||||
|
|
||||||
|
if cache is not None:
|
||||||
|
key_cache, value_cache = cache
|
||||||
|
queries = self.rope(queries, offset=key_cache.shape[2])
|
||||||
|
keys = self.rope(keys, offset=key_cache.shape[2])
|
||||||
|
keys = jnp.concatenate([key_cache, keys], axis=2)
|
||||||
|
values = jnp.concatenate([value_cache, values], axis=2)
|
||||||
|
else:
|
||||||
|
queries = self.rope(queries)
|
||||||
|
keys = self.rope(keys)
|
||||||
|
|
||||||
|
# Dimensions are [batch x num heads x sequence x hidden dim]
|
||||||
|
scale = math.sqrt(1 / queries.shape[-1])
|
||||||
|
scores = (queries * scale) @ keys.transpose((0, 1, 3, 2))
|
||||||
|
if mask is not None:
|
||||||
|
scores = scores + mask
|
||||||
|
scores = jax.nn.softmax(scores, axis=-1)
|
||||||
|
values_hat = (scores @ values).transpose((0, 2, 1, 3)).reshape((B, L, -1))
|
||||||
|
|
||||||
|
return self.out_proj(values_hat), (keys, values)
|
||||||
|
|
||||||
|
|
||||||
|
class LlamaEncoderLayer(nn.Module):
|
||||||
|
dims: int
|
||||||
|
mlp_dims: int
|
||||||
|
num_heads: int
|
||||||
|
dtype: jnp.dtype
|
||||||
|
|
||||||
|
def setup(self):
|
||||||
|
dims = self.dims
|
||||||
|
mlp_dims = self.mlp_dims
|
||||||
|
num_heads = self.num_heads
|
||||||
|
|
||||||
|
self.attention = LlamaAttention(dims, num_heads, dtype)
|
||||||
|
|
||||||
|
self.norm1 = nn.RMSNorm(param_dtype=self.dtype)
|
||||||
|
self.norm2 = nn.RMSNorm(param_dtype=self.dtype)
|
||||||
|
|
||||||
|
self.linear1 = nn.Dense(mlp_dims, use_bias=False, param_dtype=self.dtype)
|
||||||
|
self.linear2 = nn.Dense(mlp_dims, use_bias=False, param_dtype=self.dtype)
|
||||||
|
self.linear3 = nn.Dense(dims, use_bias=False, param_dtype=self.dtype)
|
||||||
|
|
||||||
|
def __call__(self, x, mask=None, cache=None):
|
||||||
|
y = self.norm1(x)
|
||||||
|
y, cache = self.attention(y, y, y, mask, cache)
|
||||||
|
x = x + y
|
||||||
|
|
||||||
|
y = self.norm2(x)
|
||||||
|
a = self.linear1(y)
|
||||||
|
b = self.linear2(y)
|
||||||
|
y = jax.nn.silu(a) * b
|
||||||
|
y = self.linear3(y)
|
||||||
|
x = x + y
|
||||||
|
|
||||||
|
return x, cache
|
||||||
|
|
||||||
|
|
||||||
|
def measure(model, x, cache):
|
||||||
|
for i in range(5):
|
||||||
|
y, c = model(x, mask=None, cache=cache)
|
||||||
|
jax.block_until_ready((y, c))
|
||||||
|
|
||||||
|
start = time.time()
|
||||||
|
for i in range(5):
|
||||||
|
y, c = model(x, mask=None, cache=cache)
|
||||||
|
jax.block_until_ready((y, c))
|
||||||
|
|
||||||
|
end = time.time()
|
||||||
|
return (end - start) * 1000 / 5
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
H = 32
|
||||||
|
D = 4096
|
||||||
|
F = 43 * 256
|
||||||
|
C = 1000
|
||||||
|
dtype = jnp.float16
|
||||||
|
|
||||||
|
k1, k2, k3, k4 = jax.random.split(jax.random.PRNGKey(0), 4)
|
||||||
|
|
||||||
|
x = jax.random.normal(k1, (1, 1, D), dtype)
|
||||||
|
cache = [
|
||||||
|
jax.random.normal(k2, [1, H, C, D // H], dtype),
|
||||||
|
jax.random.normal(k3, [1, H, C, D // H], dtype),
|
||||||
|
]
|
||||||
|
|
||||||
|
layer = LlamaEncoderLayer(D, F, H, dtype=dtype)
|
||||||
|
params = layer.init(k4, x, mask=None, cache=cache)["params"]
|
||||||
|
|
||||||
|
@jax.jit
|
||||||
|
def model_fn(x, mask, cache):
|
||||||
|
return layer.apply({"params": params}, x, mask=mask, cache=cache)
|
||||||
|
|
||||||
|
T = measure(model_fn, x, cache)
|
||||||
|
|
||||||
|
print("Time per layer per token:", T, "ms")
|
||||||
|
print("Lower bound total time per token:", T * 32, "ms")
|
||||||
197
benchmarks/python/llama_torch_bench.py
Normal file
197
benchmarks/python/llama_torch_bench.py
Normal file
@@ -0,0 +1,197 @@
|
|||||||
|
import math
|
||||||
|
import time
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
import torch.mps
|
||||||
|
|
||||||
|
|
||||||
|
def sync_if_needed(x):
|
||||||
|
if x.device != torch.device("cpu"):
|
||||||
|
torch.mps.synchronize()
|
||||||
|
|
||||||
|
|
||||||
|
class RoPE(nn.Module):
|
||||||
|
def __init__(self, dims: int, traditional: bool = False):
|
||||||
|
super().__init__()
|
||||||
|
self.dims = dims
|
||||||
|
self.traditional = traditional
|
||||||
|
|
||||||
|
def _compute_rope(self, costheta, sintheta, x):
|
||||||
|
x1 = x[..., : self.dims // 2]
|
||||||
|
x2 = x[..., self.dims // 2 : self.dims]
|
||||||
|
rx1 = x1 * costheta - x2 * sintheta
|
||||||
|
rx2 = x1 * sintheta + x2 * costheta
|
||||||
|
|
||||||
|
if self.dims < x.shape[-1]:
|
||||||
|
rx = torch.cat([rx1, rx2, x[..., self.dims :]], dim=-1)
|
||||||
|
else:
|
||||||
|
rx = torch.cat([rx1, rx2], dim=-1)
|
||||||
|
|
||||||
|
return rx
|
||||||
|
|
||||||
|
def _compute_traditional_rope(self, costheta, sintheta, x):
|
||||||
|
x1 = x[..., ::2]
|
||||||
|
x2 = x[..., 1::2]
|
||||||
|
rx1 = x1 * costheta - x2 * sintheta
|
||||||
|
rx2 = x1 * sintheta + x2 * costheta
|
||||||
|
|
||||||
|
if self.dims < x.shape[-1]:
|
||||||
|
raise NotImplementedError(
|
||||||
|
"RoPE doesn't implement partial traditional application"
|
||||||
|
)
|
||||||
|
|
||||||
|
rx = torch.cat([rx1[..., None], rx2[..., None]], dim=-1)
|
||||||
|
|
||||||
|
return rx
|
||||||
|
|
||||||
|
def forward(self, x, offset: int = 0):
|
||||||
|
shape = x.shape
|
||||||
|
x = x.view(-1, shape[-2], shape[-1])
|
||||||
|
N = x.shape[1] + offset
|
||||||
|
costheta, sintheta = RoPE.create_cos_sin_theta(
|
||||||
|
N, self.dims, offset=offset, device=x.device, dtype=x.dtype
|
||||||
|
)
|
||||||
|
|
||||||
|
rope = (
|
||||||
|
self._compute_traditional_rope if self.traditional else self._compute_rope
|
||||||
|
)
|
||||||
|
rx = rope(costheta, sintheta, x)
|
||||||
|
|
||||||
|
return rx.view(*shape)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def create_cos_sin_theta(
|
||||||
|
N: int,
|
||||||
|
D: int,
|
||||||
|
offset: int = 0,
|
||||||
|
base: float = 10000,
|
||||||
|
device="cpu",
|
||||||
|
dtype=torch.float32,
|
||||||
|
):
|
||||||
|
D = D // 2
|
||||||
|
positions = torch.arange(offset, N, dtype=dtype, device=device)
|
||||||
|
freqs = torch.exp(
|
||||||
|
-torch.arange(0, D, dtype=dtype, device=device) * (math.log(base) / D)
|
||||||
|
)
|
||||||
|
theta = positions.view(-1, 1) * freqs.view(1, -1)
|
||||||
|
costheta = torch.cos(theta)
|
||||||
|
sintheta = torch.sin(theta)
|
||||||
|
|
||||||
|
return costheta, sintheta
|
||||||
|
|
||||||
|
|
||||||
|
class RMSNorm(nn.Module):
|
||||||
|
def __init__(self, dims: int, epsilon: float = 1e-6):
|
||||||
|
super().__init__()
|
||||||
|
self.gamma = nn.Parameter(torch.ones((dims,)))
|
||||||
|
self.epsilon = epsilon
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
n = torch.rsqrt(x.square().mean(dim=-1, keepdims=True) + self.epsilon)
|
||||||
|
return self.gamma * x * n
|
||||||
|
|
||||||
|
|
||||||
|
class LlamaAttention(nn.Module):
|
||||||
|
def __init__(self, dims: int, num_heads: int):
|
||||||
|
super().__init__()
|
||||||
|
self.num_heads = num_heads
|
||||||
|
self.rope = RoPE(dims // num_heads, True)
|
||||||
|
self.query_proj = nn.Linear(dims, dims, bias=False)
|
||||||
|
self.key_proj = nn.Linear(dims, dims, bias=False)
|
||||||
|
self.value_proj = nn.Linear(dims, dims, bias=False)
|
||||||
|
self.out_proj = nn.Linear(dims, dims, bias=False)
|
||||||
|
|
||||||
|
def forward(self, queries, keys, values, mask=None, cache=None):
|
||||||
|
queries = self.query_proj(queries)
|
||||||
|
keys = self.key_proj(keys)
|
||||||
|
values = self.value_proj(values)
|
||||||
|
|
||||||
|
num_heads = self.num_heads
|
||||||
|
B, L, D = queries.shape
|
||||||
|
queries = queries.view(B, L, num_heads, -1).permute(0, 2, 1, 3)
|
||||||
|
keys = keys.view(B, L, num_heads, -1).permute(0, 2, 1, 3)
|
||||||
|
values = values.view(B, L, num_heads, -1).permute(0, 2, 1, 3)
|
||||||
|
|
||||||
|
if cache is not None:
|
||||||
|
key_cache, value_cache = cache
|
||||||
|
queries = self.rope(queries, offset=key_cache.shape[2])
|
||||||
|
keys = self.rope(keys, offset=key_cache.shape[2])
|
||||||
|
keys = torch.cat([key_cache, keys], dim=2)
|
||||||
|
values = torch.cat([value_cache, values], dim=2)
|
||||||
|
else:
|
||||||
|
queries = self.rope(queries)
|
||||||
|
keys = self.rope(keys)
|
||||||
|
|
||||||
|
# Dimensions are [batch x num heads x sequence x hidden dim]
|
||||||
|
scale = math.sqrt(1 / queries.shape[-1])
|
||||||
|
scores = (queries * scale) @ keys.permute(0, 1, 3, 2)
|
||||||
|
if mask is not None:
|
||||||
|
scores = scores + mask
|
||||||
|
scores = torch.softmax(scores, dim=-1)
|
||||||
|
values_hat = (scores @ values).permute(0, 2, 1, 3).reshape(B, L, -1)
|
||||||
|
|
||||||
|
return self.out_proj(values_hat), (keys, values)
|
||||||
|
|
||||||
|
|
||||||
|
class LlamaEncoderLayer(nn.Module):
|
||||||
|
def __init__(self, dims: int, mlp_dims: int, num_heads: int):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
self.attention = LlamaAttention(dims, num_heads)
|
||||||
|
|
||||||
|
self.norm1 = RMSNorm(dims)
|
||||||
|
self.norm2 = RMSNorm(dims)
|
||||||
|
|
||||||
|
self.linear1 = nn.Linear(dims, mlp_dims, bias=False)
|
||||||
|
self.linear2 = nn.Linear(dims, mlp_dims, bias=False)
|
||||||
|
self.linear3 = nn.Linear(mlp_dims, dims, bias=False)
|
||||||
|
|
||||||
|
def forward(self, x, mask=None, cache=None):
|
||||||
|
y = self.norm1(x)
|
||||||
|
y, cache = self.attention(y, y, y, mask, cache)
|
||||||
|
x = x + y
|
||||||
|
|
||||||
|
y = self.norm2(x)
|
||||||
|
a = self.linear1(y)
|
||||||
|
b = self.linear2(y)
|
||||||
|
y = torch.nn.functional.silu(a) * b
|
||||||
|
y = self.linear3(y)
|
||||||
|
x = x + y
|
||||||
|
|
||||||
|
return x, cache
|
||||||
|
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def measure(model, x, cache):
|
||||||
|
for i in range(5):
|
||||||
|
y, c = model(x, mask=None, cache=cache)
|
||||||
|
sync_if_needed(x)
|
||||||
|
|
||||||
|
start = time.time()
|
||||||
|
for i in range(5):
|
||||||
|
y, c = model(x, mask=None, cache=cache)
|
||||||
|
sync_if_needed(x)
|
||||||
|
end = time.time()
|
||||||
|
return (end - start) * 1000 / 5
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
H = 32
|
||||||
|
D = 4096
|
||||||
|
F = 43 * 256
|
||||||
|
C = 1000
|
||||||
|
device = torch.device("mps")
|
||||||
|
dtype = torch.float16
|
||||||
|
|
||||||
|
layer = LlamaEncoderLayer(D, F, H).to(device).to(dtype)
|
||||||
|
x = torch.randn(1, 1, D).to(device).to(dtype)
|
||||||
|
cache = [
|
||||||
|
torch.randn(1, H, C, D // H).to(device).to(dtype),
|
||||||
|
torch.randn(1, H, C, D // H).to(device).to(dtype),
|
||||||
|
]
|
||||||
|
|
||||||
|
T = measure(layer, x, cache)
|
||||||
|
|
||||||
|
print("Time per layer per token:", T, "ms")
|
||||||
|
print("Lower bound total time per token:", T * 32, "ms")
|
||||||
106
benchmarks/python/single_ops.py
Normal file
106
benchmarks/python/single_ops.py
Normal file
@@ -0,0 +1,106 @@
|
|||||||
|
import argparse
|
||||||
|
import mlx.core as mx
|
||||||
|
|
||||||
|
from time_utils import time_fn
|
||||||
|
|
||||||
|
|
||||||
|
def time_add():
|
||||||
|
a = mx.random.uniform(shape=(32, 1024, 1024))
|
||||||
|
b = mx.random.uniform(shape=(32, 1024, 1024))
|
||||||
|
mx.eval(a, b)
|
||||||
|
time_fn(mx.add, a, b)
|
||||||
|
|
||||||
|
aT = mx.transpose(a, [0, 2, 1])
|
||||||
|
mx.eval(aT)
|
||||||
|
|
||||||
|
def transpose_add(a, b):
|
||||||
|
return mx.add(a, b)
|
||||||
|
|
||||||
|
time_fn(transpose_add, aT, b)
|
||||||
|
|
||||||
|
b = mx.random.uniform(shape=(1024,))
|
||||||
|
mx.eval(b)
|
||||||
|
|
||||||
|
def slice_add(a, b):
|
||||||
|
return mx.add(a, b)
|
||||||
|
|
||||||
|
time_fn(slice_add, a, b)
|
||||||
|
|
||||||
|
b = mx.reshape(b, (1, 1024, 1))
|
||||||
|
mx.eval(b)
|
||||||
|
|
||||||
|
def mid_slice_add(a, b):
|
||||||
|
return mx.add(a, b)
|
||||||
|
|
||||||
|
time_fn(mid_slice_add, a, b)
|
||||||
|
|
||||||
|
|
||||||
|
def time_matmul():
|
||||||
|
a = mx.random.uniform(shape=(1024, 1024))
|
||||||
|
b = mx.random.uniform(shape=(1024, 1024))
|
||||||
|
mx.eval(a, b)
|
||||||
|
time_fn(mx.matmul, a, b)
|
||||||
|
|
||||||
|
|
||||||
|
def time_negative():
|
||||||
|
a = mx.random.uniform(shape=(10000, 1000))
|
||||||
|
mx.eval(a)
|
||||||
|
|
||||||
|
def negative(a):
|
||||||
|
return -a
|
||||||
|
|
||||||
|
mx.eval(a)
|
||||||
|
|
||||||
|
time_fn(negative, a)
|
||||||
|
|
||||||
|
|
||||||
|
def time_exp():
|
||||||
|
a = mx.random.uniform(shape=(1000, 100))
|
||||||
|
mx.eval(a)
|
||||||
|
time_fn(mx.exp, a)
|
||||||
|
|
||||||
|
|
||||||
|
def time_logsumexp():
|
||||||
|
a = mx.random.uniform(shape=(64, 10, 10000))
|
||||||
|
mx.eval(a)
|
||||||
|
time_fn(mx.logsumexp, a, axis=-1)
|
||||||
|
|
||||||
|
|
||||||
|
def time_take():
|
||||||
|
a = mx.random.uniform(shape=(10000, 500))
|
||||||
|
ids = mx.random.randint(low=0, high=10000, shape=(20, 10))
|
||||||
|
ids = [mx.reshape(idx, (-1,)) for idx in ids]
|
||||||
|
mx.eval(ids)
|
||||||
|
|
||||||
|
def random_take():
|
||||||
|
return [mx.take(a, idx, 0) for idx in ids]
|
||||||
|
|
||||||
|
time_fn(random_take)
|
||||||
|
|
||||||
|
|
||||||
|
def time_reshape_transposed():
|
||||||
|
x = mx.random.uniform(shape=(256, 256, 128))
|
||||||
|
mx.eval(x)
|
||||||
|
|
||||||
|
def reshape_transposed():
|
||||||
|
return mx.reshape(mx.transpose(x, (1, 0, 2)), (-1,))
|
||||||
|
|
||||||
|
time_fn(reshape_transposed)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
parser = argparse.ArgumentParser("MLX benchmarks.")
|
||||||
|
parser.add_argument("--gpu", action="store_true", help="Use the Metal back-end.")
|
||||||
|
args = parser.parse_args()
|
||||||
|
if args.gpu:
|
||||||
|
mx.set_default_device(mx.gpu)
|
||||||
|
else:
|
||||||
|
mx.set_default_device(mx.cpu)
|
||||||
|
|
||||||
|
time_add()
|
||||||
|
time_matmul()
|
||||||
|
time_exp()
|
||||||
|
time_negative()
|
||||||
|
time_logsumexp()
|
||||||
|
time_take()
|
||||||
|
time_reshape_transposed()
|
||||||
1
docs/.gitignore
vendored
Normal file
1
docs/.gitignore
vendored
Normal file
@@ -0,0 +1 @@
|
|||||||
|
src/python/_autosummary*/
|
||||||
36
docs/README.md
Normal file
36
docs/README.md
Normal file
@@ -0,0 +1,36 @@
|
|||||||
|
## Build the Docs
|
||||||
|
|
||||||
|
### Setup (do once)
|
||||||
|
|
||||||
|
Install [sphinx](https://www.sphinx-doc.org/en/master/usage/installation.html)
|
||||||
|
for example with `conda`:
|
||||||
|
|
||||||
|
```
|
||||||
|
conda install sphinx
|
||||||
|
pip install sphinx-rtd-theme
|
||||||
|
```
|
||||||
|
|
||||||
|
### Build
|
||||||
|
|
||||||
|
Build the docs from `mlx/docs/`
|
||||||
|
|
||||||
|
```
|
||||||
|
make html
|
||||||
|
```
|
||||||
|
|
||||||
|
View the docs by running a server in `mlx/docs/build/html/`:
|
||||||
|
|
||||||
|
```
|
||||||
|
python -m http.server <port>
|
||||||
|
```
|
||||||
|
|
||||||
|
and point your browser to `http://localhost:<port>`.
|
||||||
|
|
||||||
|
### Push to Github Pages
|
||||||
|
|
||||||
|
Check-out the `gh-pages` branch (`git switch gh-pages`) and build
|
||||||
|
the docs. Then force add the `build/html` directory:
|
||||||
|
|
||||||
|
`git add -f build/html`
|
||||||
|
|
||||||
|
Commit and push the changes to the `gh-pages` branch.
|
||||||
20
docs/src/_templates/optimizers-template.rst
Normal file
20
docs/src/_templates/optimizers-template.rst
Normal file
@@ -0,0 +1,20 @@
|
|||||||
|
{{ fullname | escape | underline}}
|
||||||
|
|
||||||
|
.. currentmodule:: {{ module }}
|
||||||
|
|
||||||
|
.. autoclass:: {{ objname }}
|
||||||
|
|
||||||
|
{% block methods %}
|
||||||
|
|
||||||
|
{% if methods %}
|
||||||
|
.. rubric:: {{ _('Methods') }}
|
||||||
|
|
||||||
|
.. autosummary::
|
||||||
|
{% for item in methods %}
|
||||||
|
{%- if item not in inherited_members %}
|
||||||
|
~{{ name }}.{{ item }}
|
||||||
|
{%- endif %}
|
||||||
|
{%- endfor %}
|
||||||
|
{% endif %}
|
||||||
|
{% endblock %}
|
||||||
|
|
||||||
44
docs/src/conf.py
Normal file
44
docs/src/conf.py
Normal file
@@ -0,0 +1,44 @@
|
|||||||
|
# -*- coding: utf-8 -*-
|
||||||
|
|
||||||
|
import os
|
||||||
|
import subprocess
|
||||||
|
|
||||||
|
# -- Project information -----------------------------------------------------
|
||||||
|
|
||||||
|
project = "MLX"
|
||||||
|
copyright = "2023, MLX Contributors"
|
||||||
|
author = "MLX Contributors"
|
||||||
|
version = "0.0.0"
|
||||||
|
release = "0.0.0"
|
||||||
|
|
||||||
|
# -- General configuration ---------------------------------------------------
|
||||||
|
|
||||||
|
extensions = [
|
||||||
|
"sphinx.ext.autodoc",
|
||||||
|
"sphinx.ext.autosummary",
|
||||||
|
"sphinx.ext.intersphinx",
|
||||||
|
"sphinx.ext.napoleon",
|
||||||
|
]
|
||||||
|
|
||||||
|
python_use_unqualified_type_names = True
|
||||||
|
autosummary_generate = True
|
||||||
|
|
||||||
|
intersphinx_mapping = {
|
||||||
|
"https://docs.python.org/3": None,
|
||||||
|
"https://numpy.org/doc/stable/": None,
|
||||||
|
}
|
||||||
|
|
||||||
|
templates_path = ["_templates"]
|
||||||
|
html_static_path = ["_static"]
|
||||||
|
source_suffix = ".rst"
|
||||||
|
master_doc = "index"
|
||||||
|
highlight_language = "python"
|
||||||
|
pygments_style = "sphinx"
|
||||||
|
|
||||||
|
# -- Options for HTML output -------------------------------------------------
|
||||||
|
|
||||||
|
html_theme = "sphinx_rtd_theme"
|
||||||
|
|
||||||
|
# -- Options for HTMLHelp output ---------------------------------------------
|
||||||
|
|
||||||
|
htmlhelp_basename = "mlx_doc"
|
||||||
382
docs/src/examples/llama-inference.rst
Normal file
382
docs/src/examples/llama-inference.rst
Normal file
@@ -0,0 +1,382 @@
|
|||||||
|
LLM inference
|
||||||
|
==============
|
||||||
|
|
||||||
|
MLX enables efficient inference of large-ish transformers on Apple silicon
|
||||||
|
without compromising on ease of use. In this example we will create an
|
||||||
|
inference script for the Llama family of transformer models in which the model
|
||||||
|
is defined in less than 200 lines of python.
|
||||||
|
|
||||||
|
Implementing the model
|
||||||
|
----------------------
|
||||||
|
|
||||||
|
We will use the neural network building blocks defined in the :mod:`mlx.nn`
|
||||||
|
module to concisely define the model architecture.
|
||||||
|
|
||||||
|
Attention layer
|
||||||
|
^^^^^^^^^^^^^^^^
|
||||||
|
|
||||||
|
We will start with the llama attention layer which notably uses the RoPE
|
||||||
|
positional encoding. [1]_ In addition, our attention layer will optionally use a
|
||||||
|
key/value cache that will be concatenated with the provided keys and values to
|
||||||
|
support efficient inference.
|
||||||
|
|
||||||
|
Our implementation uses :class:`mlx.nn.Linear` for all the projections and
|
||||||
|
:class:`mlx.nn.RoPE` for the positional encoding.
|
||||||
|
|
||||||
|
.. code-block:: python
|
||||||
|
|
||||||
|
import mlx.core as mx
|
||||||
|
import mlx.nn as nn
|
||||||
|
|
||||||
|
class LlamaAttention(nn.Module):
|
||||||
|
def __init__(self, dims: int, num_heads: int):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
self.num_heads = num_heads
|
||||||
|
|
||||||
|
self.rope = nn.RoPE(dims // num_heads, traditional=True)
|
||||||
|
self.query_proj = nn.Linear(dims, dims, bias=False)
|
||||||
|
self.key_proj = nn.Linear(dims, dims, bias=False)
|
||||||
|
self.value_proj = nn.Linear(dims, dims, bias=False)
|
||||||
|
self.out_proj = nn.Linear(dims, dims, bias=False)
|
||||||
|
|
||||||
|
def __call__(self, queries, keys, values, mask=None, cache=None):
|
||||||
|
queries = self.query_proj(queries)
|
||||||
|
keys = self.key_proj(keys)
|
||||||
|
values = self.value_proj(values)
|
||||||
|
|
||||||
|
# Extract some shapes
|
||||||
|
num_heads = self.num_heads
|
||||||
|
B, L, D = queries.shape
|
||||||
|
|
||||||
|
# Prepare the queries, keys and values for the attention computation
|
||||||
|
queries = queries.reshape(B, L, num_heads, -1).transpose(0, 2, 1, 3)
|
||||||
|
keys = keys.reshape(B, L, num_heads, -1).transpose(0, 2, 1, 3)
|
||||||
|
values = values.reshape(B, L, num_heads, -1).transpose(0, 2, 1, 3)
|
||||||
|
|
||||||
|
# Add RoPE to the queries and keys and combine them with the cache
|
||||||
|
if cache is not None:
|
||||||
|
key_cache, value_cache = cache
|
||||||
|
queries = self.rope(queries, offset=key_cache.shape[2])
|
||||||
|
keys = self.rope(keys, offset=key_cache.shape[2])
|
||||||
|
keys = mx.concatenate([key_cache, keys], axis=2)
|
||||||
|
values = mx.concatenate([value_cache, values], axis=2)
|
||||||
|
else:
|
||||||
|
queries = self.rope(queries)
|
||||||
|
keys = self.rope(keys)
|
||||||
|
|
||||||
|
# Finally perform the attention computation
|
||||||
|
scale = math.sqrt(1 / queries.shape[-1])
|
||||||
|
scores = (queries * scale) @ keys.transpose(0, 1, 3, 2)
|
||||||
|
if mask is not None:
|
||||||
|
scores = scores + mask
|
||||||
|
scores = mx.softmax(scores, axis=-1)
|
||||||
|
values_hat = (scores @ values).transpose(0, 2, 1, 3).reshape(B, L, -1)
|
||||||
|
|
||||||
|
# Note that we return the keys and values to possibly be used as a cache
|
||||||
|
return self.out_proj(values_hat), (keys, values)
|
||||||
|
|
||||||
|
Encoder layer
|
||||||
|
^^^^^^^^^^^^^
|
||||||
|
|
||||||
|
The other component of the Llama model is the encoder layer which uses RMS
|
||||||
|
normalization [2]_ and SwiGLU. [3]_ For RMS normalization we will use
|
||||||
|
:class:`mlx.nn.RMSNorm` that is already provided in :mod:`mlx.nn`.
|
||||||
|
|
||||||
|
.. code-block:: python
|
||||||
|
|
||||||
|
class LlamaEncoderLayer(nn.Module):
|
||||||
|
def __init__(self, dims: int, mlp_dims: int, num_heads: int):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
self.attention = LlamaAttention(dims, num_heads)
|
||||||
|
|
||||||
|
self.norm1 = nn.RMSNorm(dims)
|
||||||
|
self.norm2 = nn.RMSNorm(dims)
|
||||||
|
|
||||||
|
self.linear1 = nn.Linear(dims, mlp_dims, bias=False)
|
||||||
|
self.linear2 = nn.Linear(dims, mlp_dims, bias=False)
|
||||||
|
self.linear3 = nn.Linear(mlp_dims, dims, bias=False)
|
||||||
|
|
||||||
|
def __call__(self, x, mask=None, cache=None):
|
||||||
|
y = self.norm1(x)
|
||||||
|
y, cache = self.attention(y, y, y, mask, cache)
|
||||||
|
x = x + y
|
||||||
|
|
||||||
|
y = self.norm2(x)
|
||||||
|
a = self.linear1(y)
|
||||||
|
b = self.linear2(y)
|
||||||
|
y = a * mx.sigmoid(a) * b
|
||||||
|
y = self.linear3(y)
|
||||||
|
x = x + y
|
||||||
|
|
||||||
|
return x, cache
|
||||||
|
|
||||||
|
Full model
|
||||||
|
^^^^^^^^^^
|
||||||
|
|
||||||
|
To implement any Llama model we simply have to combine ``LlamaEncoderLayer``
|
||||||
|
instances with an :class:`mlx.nn.Embedding` to embed the input tokens.
|
||||||
|
|
||||||
|
.. code-block:: python
|
||||||
|
|
||||||
|
class Llama(nn.Module):
|
||||||
|
def __init__(
|
||||||
|
self, num_layers: int, vocab_size: int, dims: int, mlp_dims: int, num_heads: int
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
self.embedding = nn.Embedding(vocab_size, dims)
|
||||||
|
self.layers = [
|
||||||
|
LlamaEncoderLayer(dims, mlp_dims, num_heads) for _ in range(num_layers)
|
||||||
|
]
|
||||||
|
self.norm = nn.RMSNorm(dims)
|
||||||
|
self.out_proj = nn.Linear(dims, vocab_size, bias=False)
|
||||||
|
|
||||||
|
def __call__(self, x):
|
||||||
|
mask = nn.MultiHeadAttention.create_additive_causal_mask(x.shape[1])
|
||||||
|
mask = mask.astype(self.embedding.weight.dtype)
|
||||||
|
|
||||||
|
x = self.embedding(x)
|
||||||
|
for l in self.layers:
|
||||||
|
x, _ = l(x, mask)
|
||||||
|
x = self.norm(x)
|
||||||
|
return self.out_proj(x)
|
||||||
|
|
||||||
|
Note that in the implementation above we use a simple list to hold the encoder
|
||||||
|
layers but using ``model.parameters()`` will still consider these layers.
|
||||||
|
|
||||||
|
Generation
|
||||||
|
^^^^^^^^^^^
|
||||||
|
|
||||||
|
Our ``Llama`` module can be used for training but not inference as the
|
||||||
|
``__call__`` method above processes one input, completely ignores the cache and
|
||||||
|
performs no sampling whatsoever. In the rest of this subsection, we will
|
||||||
|
implement the inference function as a python generator that processes the
|
||||||
|
prompt and then autoregressively yields tokens one at a time.
|
||||||
|
|
||||||
|
.. code-block:: python
|
||||||
|
|
||||||
|
class Llama(nn.Module):
|
||||||
|
...
|
||||||
|
|
||||||
|
def generate(self, x, temp=1.0):
|
||||||
|
cache = []
|
||||||
|
|
||||||
|
# Make an additive causal mask. We will need that to process the prompt.
|
||||||
|
mask = nn.MultiHeadAttention.create_additive_causal_mask(x.shape[1])
|
||||||
|
mask = mask.astype(self.embedding.weight.dtype)
|
||||||
|
|
||||||
|
# First we process the prompt x the same way as in __call__ but
|
||||||
|
# save the caches in cache
|
||||||
|
x = self.embedding(x)
|
||||||
|
for l in self.layers:
|
||||||
|
x, c = l(x, mask=mask)
|
||||||
|
cache.append(c) # <--- we store the per layer cache in a
|
||||||
|
# simple python list
|
||||||
|
x = self.norm(x)
|
||||||
|
y = self.out_proj(x[:, -1]) # <--- we only care about the last logits
|
||||||
|
# that generate the next token
|
||||||
|
y = mx.random.categorical(y * (1/temp))
|
||||||
|
|
||||||
|
# y now has size [1]
|
||||||
|
# Since MLX is lazily evaluated nothing is computed yet.
|
||||||
|
# Calling y.item() would force the computation to happen at
|
||||||
|
# this point but we can also choose not to do that and let the
|
||||||
|
# user choose when to start the computation.
|
||||||
|
yield y
|
||||||
|
|
||||||
|
# Now we parsed the prompt and generated the first token we
|
||||||
|
# need to feed it back into the model and loop to generate the
|
||||||
|
# rest.
|
||||||
|
while True:
|
||||||
|
# Unsqueezing the last dimension to add a sequence length
|
||||||
|
# dimension of 1
|
||||||
|
x = y[:, None]
|
||||||
|
|
||||||
|
x = self.embedding(x)
|
||||||
|
for i in range(len(cache)):
|
||||||
|
# We are overwriting the arrays in the cache list. When
|
||||||
|
# the computation will happen, MLX will be discarding the
|
||||||
|
# old cache the moment it is not needed anymore.
|
||||||
|
x, cache[i] = self.layers[i](x, mask=None, cache=cache[i])
|
||||||
|
x = self.norm(x)
|
||||||
|
y = self.out_proj(x[:, -1])
|
||||||
|
y = mx.random.categorical(y * (1/temp))
|
||||||
|
|
||||||
|
yield y
|
||||||
|
|
||||||
|
Putting it all together
|
||||||
|
^^^^^^^^^^^^^^^^^^^^^^^
|
||||||
|
|
||||||
|
We now have everything we need to create a Llama model and sample tokens from
|
||||||
|
it. In the following code, we randomly initialize a small Llama model, process
|
||||||
|
6 tokens of prompt and generate 10 tokens.
|
||||||
|
|
||||||
|
.. code-block:: python
|
||||||
|
|
||||||
|
model = Llama(num_layers=12, vocab_size=8192, dims=512, mlp_dims=1024, num_heads=8)
|
||||||
|
|
||||||
|
# Since MLX is lazily evaluated nothing has actually been materialized yet.
|
||||||
|
# We could have set the `dims` to 20_000 on a machine with 8GB of RAM and the
|
||||||
|
# code above would still run. Let's actually materialize the model.
|
||||||
|
mx.eval(model.parameters())
|
||||||
|
|
||||||
|
prompt = mx.array([[1, 10, 8, 32, 44, 7]]) # <-- Note the double brackets because we
|
||||||
|
# have a batch dimension even
|
||||||
|
# though it is 1 in this case
|
||||||
|
|
||||||
|
generated = [t for i, t in zip(range(10), model.generate(prompt, 0.8))]
|
||||||
|
|
||||||
|
# Since we haven't evaluated anything, nothing is computed yet. The list
|
||||||
|
# `generated` contains the arrays that hold the computation graph for the
|
||||||
|
# full processing of the prompt and the generation of 10 tokens.
|
||||||
|
#
|
||||||
|
# We can evaluate them one at a time, or all together. Concatenate them or
|
||||||
|
# print them. They would all result in very similar runtimes and give exactly
|
||||||
|
# the same results.
|
||||||
|
mx.eval(generated)
|
||||||
|
|
||||||
|
Converting the weights
|
||||||
|
----------------------
|
||||||
|
|
||||||
|
This section assumes that you have access to the original Llama weights and the
|
||||||
|
SentencePiece model that comes with them. We will write a small script to
|
||||||
|
convert the PyTorch weights to MLX compatible ones and write them in a NPZ file
|
||||||
|
that can be loaded directly by MLX.
|
||||||
|
|
||||||
|
.. code-block:: python
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
from itertools import starmap
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
|
||||||
|
def map_torch_to_mlx(key, value):
|
||||||
|
if "tok_embedding" in key:
|
||||||
|
key = "embedding.weight"
|
||||||
|
|
||||||
|
elif "norm" in key:
|
||||||
|
key = key.replace("attention_norm", "norm1").replace("ffn_norm", "norm2")
|
||||||
|
|
||||||
|
elif "wq" in key or "wk" in key or "wv" in key or "wo" in key:
|
||||||
|
key = key.replace("wq", "query_proj")
|
||||||
|
key = key.replace("wk", "key_proj")
|
||||||
|
key = key.replace("wv", "value_proj")
|
||||||
|
key = key.replace("wo", "out_proj")
|
||||||
|
|
||||||
|
elif "w1" in key or "w2" in key or "w3" in key:
|
||||||
|
# The FFN is a separate submodule in PyTorch
|
||||||
|
key = key.replace("feed_forward.w1", "linear1")
|
||||||
|
key = key.replace("feed_forward.w3", "linear2")
|
||||||
|
key = key.replace("feed_forward.w2", "linear3")
|
||||||
|
|
||||||
|
elif "output" in key:
|
||||||
|
key = key.replace("output", "out_proj")
|
||||||
|
|
||||||
|
elif "rope" in key:
|
||||||
|
return None, None
|
||||||
|
|
||||||
|
return key, value.numpy()
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
parser = argparse.ArgumentParser(description="Convert Llama weights to MLX")
|
||||||
|
parser.add_argument("torch_weights")
|
||||||
|
parser.add_argument("output_file")
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
state = torch.load(args.torch_weights)
|
||||||
|
np.savez(
|
||||||
|
args.output_file,
|
||||||
|
**{k: v for k, v in starmap(map_torch_to_mlx, state.items()) if k is not None}
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
Weight loading and benchmarking
|
||||||
|
-------------------------------
|
||||||
|
|
||||||
|
After converting the weights to be compatible to our implementation, all that is
|
||||||
|
left is to load them from disk and we can finally use the LLM to generate text.
|
||||||
|
We can load numpy format files using the :func:`mlx.core.load` operation.
|
||||||
|
|
||||||
|
To create a parameter dictionary from the key/value representation of NPZ files
|
||||||
|
we will use the :func:`mlx.utils.tree_unflatten` helper method as follows:
|
||||||
|
|
||||||
|
.. code-block:: python
|
||||||
|
|
||||||
|
from mlx.utils import tree_unflatten
|
||||||
|
|
||||||
|
model.update(tree_unflatten(list(mx.load(weight_file).items())))
|
||||||
|
|
||||||
|
:meth:`mlx.utils.tree_unflatten` will take keys from the NPZ file that look
|
||||||
|
like ``layers.2.attention.query_proj.weight`` and will transform them to
|
||||||
|
|
||||||
|
.. code-block:: python
|
||||||
|
|
||||||
|
{"layers": [..., ..., {"attention": {"query_proj": {"weight": ...}}}]}
|
||||||
|
|
||||||
|
which can then be used to update the model. Note that the method above incurs
|
||||||
|
several unnecessary copies from disk to numpy and then from numpy to MLX. It
|
||||||
|
will be replaced in the future with direct loading to MLX.
|
||||||
|
|
||||||
|
You can download the full example code in `mlx-examples <code>`_. Assuming, the
|
||||||
|
existence of ``weights.pth`` and ``tokenizer.model`` in the current working
|
||||||
|
directory we can play around with our inference script as follows (the timings
|
||||||
|
are representative of an M1 Ultra and the 7B parameter Llama model):
|
||||||
|
|
||||||
|
.. code-block:: bash
|
||||||
|
|
||||||
|
$ python convert.py weights.pth llama-7B.mlx.npz
|
||||||
|
$ python llama.py llama-7B.mlx.npz tokenizer.model 'Call me Ishmael. Some years ago never mind how long precisely'
|
||||||
|
[INFO] Loading model from disk: 5.247 s
|
||||||
|
Press enter to start generation
|
||||||
|
------
|
||||||
|
, having little or no money in my purse, and nothing of greater consequence in my mind, I happened to be walking down Gower Street in the afternoon, in the heavy rain, and I saw a few steps off, a man in rags, who sat upon his bundle and looked hard into the wet as if he were going to cry. I watched him attentively for some time, and could not but observe that, though a numerous crowd was hurrying up and down,
|
||||||
|
------
|
||||||
|
[INFO] Prompt processing: 0.437 s
|
||||||
|
[INFO] Full generation: 4.330 s
|
||||||
|
|
||||||
|
We observe that 4.3 seconds are required to generate 100 tokens and 0.4 seconds
|
||||||
|
of those are spent processing the prompt. This amounts to a little over **39 ms
|
||||||
|
per token**.
|
||||||
|
|
||||||
|
By running with a much bigger prompt we can see that the per token generation
|
||||||
|
time as well as the prompt processing time remains almost constant.
|
||||||
|
|
||||||
|
.. code-block:: bash
|
||||||
|
|
||||||
|
$ python llama.py llama-7B.mlx.npz tokenizer.model 'Call me Ishmael. Some years ago never mind how long precisely, having little or no money in my purse, and nothing of greater consequence in my mind, I happened to be walking down Gower Street in the afternoon, in the heavy rain, and I saw a few steps off, a man in rags, who sat upon his bundle and looked hard into the wet as if he were going to cry. I watched him attentively for some time, and could not but observe that, though a numerous crowd was hurrying up and down, nobody took the least notice of him. I stopped at last, at a little distance, as if I had been in doubt, and after looking on a few minutes, walked straight up to him. He slowly raised his eyes, and fixed them upon me for a moment, without speaking, and then resumed his place and posture as before. I stood looking at him for a while, feeling very much pain at heart, and then said to him, “What are you doing there?” Something like a smile passed over his face, as he said slowly, “I am waiting for someone; but it has been three quarters of an hour now, and he has not come.” “What is it you are waiting for?” said I. Still he made no immediate reply, but again put his face down upon his hands, and did not'
|
||||||
|
[INFO] Loading model from disk: 5.247 s
|
||||||
|
Press enter to start generation
|
||||||
|
------
|
||||||
|
take his eyes from the ground. “What is it you are waiting for?” said I. “I am not accustomed to be thus questioned,” said he. “You look like a reasonable man—tell me, then, what are you waiting for?” “You would not understand,” he replied; “and how could you help me, if I were to tell you?” “I should not only understand, but would do all that I could,” said I. He did not
|
||||||
|
------
|
||||||
|
[INFO] Prompt processing: 0.579 s
|
||||||
|
[INFO] Full generation: 4.690 s
|
||||||
|
$ python llama.py --num-tokens 500 llama-7B.mlx.npz tokenizer.model 'Call me Ishmael. Some years ago never mind how long precisely, having little or no money in my purse, and nothing of greater consequence in my mind, I happened to be walking down Gower Street in the afternoon, in the heavy rain, and I saw a few steps off, a man in rags, who sat upon his bundle and looked hard into the wet as if he were going to cry. I watched him attentively for some time, and could not but observe that, though a numerous crowd was hurrying up and down, nobody took the least notice of him. I stopped at last, at a little distance, as if I had been in doubt, and after looking on a few minutes, walked straight up to him. He slowly raised his eyes, and fixed them upon me for a moment, without speaking, and then resumed his place and posture as before. I stood looking at him for a while, feeling very much pain at heart, and then said to him, “What are you doing there?” Something like a smile passed over his face, as he said slowly, “I am waiting for someone; but it has been three quarters of an hour now, and he has not come.” “What is it you are waiting for?” said I. Still he made no immediate reply, but again put his face down upon his hands, and did not'
|
||||||
|
[INFO] Loading model from disk: 5.628 s
|
||||||
|
Press enter to start generation
|
||||||
|
------
|
||||||
|
take his eyes from the ground. “What is it you are waiting for?” said I. “I am not accustomed to be thus questioned,” said he. “You look like a reasonable man—tell me, then, what are you waiting for?” “You would not understand,” he replied; “and how could you help me, if I were to tell you?” “I should not only understand, but would do all that I could,” said I. He did not reply, but still went on looking at the ground, and took hold of his bundle with a nervous trembling. I waited some time, and then resumed. “It is of no use to say you would not understand, if I were to tell you,” said he. “I have not told you why I am waiting for him,” said I. “And I am sure I should not understand,” replied he. “I will tell you then,” said I, “and, perhaps, you would not be surprised.” “No matter,” said he, “I shall be surprised anyhow; so tell me why you are waiting for him.” “He is my friend,” said I. “Yes,” said he, with a slight smile, “I know.” “He has been kind to me,” said I, “and I am waiting for him. I want to see him, and could have waited as I am now, for a much longer time.” “He will not soon come,” said he. “Unless he sees you here, he will not know of your having waited, and he will be very unlikely to come.” “No matter,” said I, “I shall wait for him.” “This is a strange thing,” said he, still with the same amused smile. “How did you know,” said I, “that he was coming? How should you be waiting?” “That is my secret,” said he. “And you expect him?” “Yes,” said I. “Are you disappointed then, if he does not come?” “No,” said I, “it is his secret, not mine.” “If he comes,” said he, “do you mean to go straight away?” “Yes,” said I, “I cannot be happy if I do not go straight away after him.” “Did you know this place before?” asked he. “Yes,” said I. “Is there any shop to buy food here?” “
|
||||||
|
------
|
||||||
|
[INFO] Prompt processing: 0.633 s
|
||||||
|
[INFO] Full generation: 21.475 s
|
||||||
|
|
||||||
|
Scripts
|
||||||
|
-------
|
||||||
|
|
||||||
|
.. admonition:: Download the code
|
||||||
|
|
||||||
|
The full example code is available in `mlx-examples <code>`_.
|
||||||
|
|
||||||
|
.. code: `https://github.com/ml-explore/mlx-examples/tree/main/llama`_
|
||||||
|
|
||||||
|
.. [1] Su, J., Lu, Y., Pan, S., Murtadha, A., Wen, B. and Liu, Y., 2021.
|
||||||
|
Roformer: Enhanced transformer with rotary position embedding. arXiv
|
||||||
|
preprint arXiv:2104.09864.
|
||||||
|
.. [2] Zhang, B. and Sennrich, R., 2019. Root mean square layer normalization.
|
||||||
|
Advances in Neural Information Processing Systems, 32.
|
||||||
|
.. [3] Shazeer, N., 2020. Glu variants improve transformer. arXiv preprint
|
||||||
|
arXiv:2002.05202.
|
||||||
49
docs/src/index.rst
Normal file
49
docs/src/index.rst
Normal file
@@ -0,0 +1,49 @@
|
|||||||
|
MLX
|
||||||
|
===
|
||||||
|
|
||||||
|
.. toctree::
|
||||||
|
:caption: Install
|
||||||
|
:maxdepth: 1
|
||||||
|
|
||||||
|
install
|
||||||
|
|
||||||
|
.. toctree::
|
||||||
|
:caption: Usage
|
||||||
|
:maxdepth: 1
|
||||||
|
|
||||||
|
quick_start
|
||||||
|
using_streams
|
||||||
|
|
||||||
|
.. toctree::
|
||||||
|
:caption: Examples
|
||||||
|
:maxdepth: 1
|
||||||
|
|
||||||
|
examples/linear_regression
|
||||||
|
examples/mlp
|
||||||
|
examples/llama-inference
|
||||||
|
|
||||||
|
.. toctree::
|
||||||
|
:caption: Further Reading
|
||||||
|
:maxdepth: 1
|
||||||
|
|
||||||
|
dev/extensions
|
||||||
|
|
||||||
|
.. toctree::
|
||||||
|
:caption: Python API Reference
|
||||||
|
:maxdepth: 1
|
||||||
|
|
||||||
|
python/array
|
||||||
|
python/devices_and_streams
|
||||||
|
python/ops
|
||||||
|
python/random
|
||||||
|
python/transforms
|
||||||
|
python/fft
|
||||||
|
python/nn
|
||||||
|
python/optimizers
|
||||||
|
python/tree_utils
|
||||||
|
|
||||||
|
.. toctree::
|
||||||
|
:caption: C++ API Reference
|
||||||
|
:maxdepth: 1
|
||||||
|
|
||||||
|
cpp/ops
|
||||||
102
docs/src/install.rst
Normal file
102
docs/src/install.rst
Normal file
@@ -0,0 +1,102 @@
|
|||||||
|
Build and Install
|
||||||
|
=================
|
||||||
|
|
||||||
|
Install from PyPI
|
||||||
|
-----------------
|
||||||
|
|
||||||
|
MLX is available at Apple's internal PyPI repository. All you have to do to use
|
||||||
|
MLX with your own Apple silicon computer is
|
||||||
|
|
||||||
|
.. code-block:: shell
|
||||||
|
|
||||||
|
pip install apple-mlx -i https://pypi.apple.com/simple
|
||||||
|
|
||||||
|
Build from source
|
||||||
|
-----------------
|
||||||
|
|
||||||
|
Build Requirements
|
||||||
|
^^^^^^^^^^^^^^^^^^
|
||||||
|
|
||||||
|
- A C++ compiler with C++17 support (e.g. Clang >= 5.0)
|
||||||
|
- `cmake <https://cmake.org/>`_ -- version 3.24 or later, and ``make``
|
||||||
|
|
||||||
|
|
||||||
|
Python API
|
||||||
|
^^^^^^^^^^
|
||||||
|
|
||||||
|
To build and install the MLX python library from source, first, clone MLX from
|
||||||
|
`its GitHub repo <https://github.com/ml-explore/mlx>`_:
|
||||||
|
|
||||||
|
.. code-block:: shell
|
||||||
|
|
||||||
|
git clone git@github.com:ml-explore/mlx.git mlx && cd mlx
|
||||||
|
|
||||||
|
Make sure that you have `pybind11 <https://pybind11.readthedocs.io/en/stable/index.html>`_
|
||||||
|
installed. You can install ``pybind11`` with ``pip``, ``brew`` or ``conda`` as follows:
|
||||||
|
|
||||||
|
.. code-block:: shell
|
||||||
|
|
||||||
|
pip install "pybind11[global]"
|
||||||
|
conda install pybind11
|
||||||
|
brew install pybind11
|
||||||
|
|
||||||
|
Then simply build and install it using pip:
|
||||||
|
|
||||||
|
.. code-block:: shell
|
||||||
|
|
||||||
|
env CMAKE_BUILD_PARALLEL_LEVEL="" pip install .
|
||||||
|
|
||||||
|
|
||||||
|
C++ API
|
||||||
|
^^^^^^^
|
||||||
|
|
||||||
|
Currently, MLX must be built and installed from source.
|
||||||
|
|
||||||
|
Similarly to the python library, to build and install the MLX C++ library start
|
||||||
|
by cloning MLX from `its GitHub repo
|
||||||
|
<https://github.com/ml-explore/mlx>`_:
|
||||||
|
|
||||||
|
.. code-block:: shell
|
||||||
|
|
||||||
|
git clone git@github.com:ml-explore/mlx.git mlx && cd mlx
|
||||||
|
|
||||||
|
Create a build directory and run CMake and make:
|
||||||
|
|
||||||
|
.. code-block:: shell
|
||||||
|
|
||||||
|
mkdir -p build && cd build
|
||||||
|
cmake .. && make -j
|
||||||
|
|
||||||
|
Run tests with:
|
||||||
|
|
||||||
|
.. code-block:: shell
|
||||||
|
|
||||||
|
make test
|
||||||
|
|
||||||
|
Install with:
|
||||||
|
|
||||||
|
.. code-block:: shell
|
||||||
|
|
||||||
|
make install
|
||||||
|
|
||||||
|
Note that the built ``mlx.metallib`` file should be either at the same
|
||||||
|
directory as the executable statically linked to ``libmlx.a`` or the
|
||||||
|
preprocessor constant ``METAL_PATH`` should be defined at build time and it
|
||||||
|
should point to the path to the built metal library.
|
||||||
|
|
||||||
|
.. list-table:: Build Options
|
||||||
|
:widths: 25 8
|
||||||
|
:header-rows: 1
|
||||||
|
|
||||||
|
* - Option
|
||||||
|
- Default
|
||||||
|
* - MLX_BUILD_TESTS
|
||||||
|
- ON
|
||||||
|
* - MLX_BUILD_EXAMPLES
|
||||||
|
- OFF
|
||||||
|
* - MLX_BUILD_BENCHMARKS
|
||||||
|
- OFF
|
||||||
|
* - MLX_BUILD_METAL
|
||||||
|
- ON
|
||||||
|
* - MLX_BUILD_PYTHON_BINDINGS
|
||||||
|
- OFF
|
||||||
22
docs/src/python/fft.rst
Normal file
22
docs/src/python/fft.rst
Normal file
@@ -0,0 +1,22 @@
|
|||||||
|
.. _fft:
|
||||||
|
|
||||||
|
FFT
|
||||||
|
===
|
||||||
|
|
||||||
|
.. currentmodule:: mlx.core.fft
|
||||||
|
|
||||||
|
.. autosummary::
|
||||||
|
:toctree: _autosummary
|
||||||
|
|
||||||
|
fft
|
||||||
|
ifft
|
||||||
|
fft2
|
||||||
|
ifft2
|
||||||
|
fftn
|
||||||
|
ifftn
|
||||||
|
rfft
|
||||||
|
irfft
|
||||||
|
rfft2
|
||||||
|
irfft2
|
||||||
|
rfftn
|
||||||
|
irfftn
|
||||||
172
docs/src/python/nn.rst
Normal file
172
docs/src/python/nn.rst
Normal file
@@ -0,0 +1,172 @@
|
|||||||
|
.. _nn:
|
||||||
|
|
||||||
|
.. currentmodule:: mlx.nn
|
||||||
|
|
||||||
|
Neural Networks
|
||||||
|
===============
|
||||||
|
|
||||||
|
Writing arbitrarily complex neural networks in MLX can be done using only
|
||||||
|
:class:`mlx.core.array` and :meth:`mlx.core.value_and_grad`. However, this requires the
|
||||||
|
user to write again and again the same simple neural network operations as well
|
||||||
|
as handle all the parameter state and initialization manually and explicitly.
|
||||||
|
|
||||||
|
The module :mod:`mlx.nn` solves this problem by providing an intuitive way of
|
||||||
|
composing neural network layers, initializing their parameters, freezing them
|
||||||
|
for finetuning and more.
|
||||||
|
|
||||||
|
Quick Start with Neural Networks
|
||||||
|
---------------------------------
|
||||||
|
|
||||||
|
.. code-block:: python
|
||||||
|
|
||||||
|
import mlx.core as mx
|
||||||
|
import mlx.nn as nn
|
||||||
|
|
||||||
|
class MLP(nn.Module):
|
||||||
|
def __init__(self, in_dims: int, out_dims: int):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
self.layers = [
|
||||||
|
nn.Linear(in_dims, 128),
|
||||||
|
nn.Linear(128, 128),
|
||||||
|
nn.Linear(128, out_dims),
|
||||||
|
]
|
||||||
|
|
||||||
|
def __call__(self, x):
|
||||||
|
for i, l in enumerate(self.layers):
|
||||||
|
x = mx.maximum(x, 0) if i > 0 else x
|
||||||
|
x = l(x)
|
||||||
|
return x
|
||||||
|
|
||||||
|
# The model is created with all its parameters but nothing is initialized
|
||||||
|
# yet because MLX is lazily evaluated
|
||||||
|
mlp = MLP(2, 10)
|
||||||
|
|
||||||
|
# We can access its parameters by calling mlp.parameters()
|
||||||
|
params = mlp.parameters()
|
||||||
|
print(params["layers"][0]["weight"].shape)
|
||||||
|
|
||||||
|
# Printing a parameter will cause it to be evaluated and thus initialized
|
||||||
|
print(params["layers"][0])
|
||||||
|
|
||||||
|
# We can also force evaluate all parameters to initialize the model
|
||||||
|
mx.eval(mlp.parameters())
|
||||||
|
|
||||||
|
# A simple loss function.
|
||||||
|
# NOTE: It doesn't matter how it uses the mlp model. It currently captures
|
||||||
|
# it from the local scope. It could be a positional argument or a
|
||||||
|
# keyword argument.
|
||||||
|
def l2_loss(x, y):
|
||||||
|
y_hat = mlp(x)
|
||||||
|
return (y_hat - y).square().mean()
|
||||||
|
|
||||||
|
# Calling `nn.value_and_grad` instead of `mx.value_and_grad` returns the
|
||||||
|
# gradient with respect to `mlp.trainable_parameters()`
|
||||||
|
loss_and_grad = nn.value_and_grad(mlp, l2_loss)
|
||||||
|
|
||||||
|
|
||||||
|
.. _module_class:
|
||||||
|
|
||||||
|
The Module Class
|
||||||
|
----------------
|
||||||
|
|
||||||
|
The workhorse of any neural network library is the :class:`Module` class. In
|
||||||
|
MLX the :class:`Module` class is a container of :class:`mlx.core.array` or
|
||||||
|
:class:`Module` instances. Its main function is to provide a way to
|
||||||
|
recursively **access** and **update** its parameters and those of its
|
||||||
|
submodules.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
^^^^^^^^^^
|
||||||
|
|
||||||
|
A parameter of a module is any public member of type :class:`mlx.core.array` (its
|
||||||
|
name should not start with ``_``). It can be arbitrarily nested in other
|
||||||
|
:class:`Module` instances or lists and dictionaries.
|
||||||
|
|
||||||
|
:meth:`Module.parameters` can be used to extract a nested dictionary with all
|
||||||
|
the parameters of a module and its submodules.
|
||||||
|
|
||||||
|
A :class:`Module` can also keep track of "frozen" parameters.
|
||||||
|
:meth:`Module.trainable_parameters` returns only the subset of
|
||||||
|
:meth:`Module.parameters` that is not frozen. When using
|
||||||
|
:meth:`mlx.nn.value_and_grad` the gradients returned will be with respect to these
|
||||||
|
trainable parameters.
|
||||||
|
|
||||||
|
Updating the parameters
|
||||||
|
^^^^^^^^^^^^^^^^^^^^^^^
|
||||||
|
|
||||||
|
MLX modules allow accessing and updating individual parameters. However, most
|
||||||
|
times we need to update large subsets of a module's parameters. This action is
|
||||||
|
performed by :meth:`Module.update`.
|
||||||
|
|
||||||
|
Value and grad
|
||||||
|
--------------
|
||||||
|
|
||||||
|
Using a :class:`Module` does not preclude using MLX's high order function
|
||||||
|
transformations (:meth:`mlx.core.value_and_grad`, :meth:`mlx.core.grad`, etc.). However,
|
||||||
|
these function transformations assume pure functions, namely the parameters
|
||||||
|
should be passed as an argument to the function being transformed.
|
||||||
|
|
||||||
|
There is an easy pattern to achieve that with MLX modules
|
||||||
|
|
||||||
|
.. code-block:: python
|
||||||
|
|
||||||
|
model = ...
|
||||||
|
|
||||||
|
def f(params, other_inputs):
|
||||||
|
model.update(params) # <---- Necessary to make the model use the passed parameters
|
||||||
|
return model(other_inputs)
|
||||||
|
|
||||||
|
f(model.trainable_parameters(), mx.zeros((10,)))
|
||||||
|
|
||||||
|
However, :meth:`mlx.nn.value_and_grad` provides precisely this pattern and only
|
||||||
|
computes the gradients with respect to the trainable parameters of the model.
|
||||||
|
|
||||||
|
In detail:
|
||||||
|
|
||||||
|
- it wraps the passed function with a function that calls :meth:`Module.update`
|
||||||
|
to make sure the model is using the provided parameters.
|
||||||
|
- it calls :meth:`mlx.core.value_and_grad` to transform the function into a function
|
||||||
|
that also computes the gradients with respect to the passed parameters.
|
||||||
|
- it wraps the returned function with a function that passes the trainable
|
||||||
|
parameters as the first argument to the function returned by
|
||||||
|
:meth:`mlx.core.value_and_grad`
|
||||||
|
|
||||||
|
.. autosummary::
|
||||||
|
:toctree: _autosummary
|
||||||
|
|
||||||
|
value_and_grad
|
||||||
|
|
||||||
|
Neural Network Layers
|
||||||
|
---------------------
|
||||||
|
|
||||||
|
.. autosummary::
|
||||||
|
:toctree: _autosummary
|
||||||
|
:template: nn-module-template.rst
|
||||||
|
|
||||||
|
Embedding
|
||||||
|
ReLU
|
||||||
|
GELU
|
||||||
|
SiLU
|
||||||
|
Linear
|
||||||
|
Conv1d
|
||||||
|
Conv2d
|
||||||
|
LayerNorm
|
||||||
|
RMSNorm
|
||||||
|
GroupNorm
|
||||||
|
RoPE
|
||||||
|
MultiHeadAttention
|
||||||
|
Sequential
|
||||||
|
|
||||||
|
Layers without parameters (e.g. activation functions) are also provided as
|
||||||
|
simple functions.
|
||||||
|
|
||||||
|
.. autosummary::
|
||||||
|
:toctree: _autosummary_functions
|
||||||
|
:template: nn-module-template.rst
|
||||||
|
|
||||||
|
gelu
|
||||||
|
gelu_approx
|
||||||
|
gelu_fast_approx
|
||||||
|
relu
|
||||||
|
silu
|
||||||
7
docs/src/python/nn/module.rst
Normal file
7
docs/src/python/nn/module.rst
Normal file
@@ -0,0 +1,7 @@
|
|||||||
|
mlx.nn.Module
|
||||||
|
=============
|
||||||
|
|
||||||
|
.. currentmodule:: mlx.nn
|
||||||
|
|
||||||
|
.. autoclass:: Module
|
||||||
|
:members:
|
||||||
41
docs/src/python/optimizers.rst
Normal file
41
docs/src/python/optimizers.rst
Normal file
@@ -0,0 +1,41 @@
|
|||||||
|
.. _optimizers:
|
||||||
|
|
||||||
|
Optimizers
|
||||||
|
==========
|
||||||
|
|
||||||
|
The optimizers in MLX can be used both with :mod:`mlx.nn` but also with pure
|
||||||
|
:mod:`mlx.core` functions. A typical example involves calling
|
||||||
|
:meth:`Optimizer.update` to update a model's parameters based on the loss
|
||||||
|
gradients and subsequently calling :func:`mlx.core.eval` to evaluate both the
|
||||||
|
model's parameters and the **optimizer state**.
|
||||||
|
|
||||||
|
.. code-block:: python
|
||||||
|
|
||||||
|
# Create a model
|
||||||
|
model = MLP(num_layers, train_images.shape[-1], hidden_dim, num_classes)
|
||||||
|
mx.eval(model.parameters())
|
||||||
|
|
||||||
|
# Create the gradient function and the optimizer
|
||||||
|
loss_and_grad_fn = nn.value_and_grad(model, loss_fn)
|
||||||
|
optimizer = optim.SGD(learning_rate=learning_rate)
|
||||||
|
|
||||||
|
for e in range(num_epochs):
|
||||||
|
for X, y in batch_iterate(batch_size, train_images, train_labels):
|
||||||
|
loss, grads = loss_and_grad_fn(model, X, y)
|
||||||
|
|
||||||
|
# Update the model with the gradients. So far no computation has happened.
|
||||||
|
optimizer.update(model, grads)
|
||||||
|
|
||||||
|
# Compute the new parameters but also the optimizer state.
|
||||||
|
mx.eval(model.parameters(), optimizer.state)
|
||||||
|
|
||||||
|
.. currentmodule:: mlx.optimizers
|
||||||
|
|
||||||
|
.. autosummary::
|
||||||
|
:toctree: _autosummary
|
||||||
|
:template: optimizers-template.rst
|
||||||
|
|
||||||
|
OptimizerState
|
||||||
|
Optimizer
|
||||||
|
SGD
|
||||||
|
Adam
|
||||||
45
docs/src/python/random.rst
Normal file
45
docs/src/python/random.rst
Normal file
@@ -0,0 +1,45 @@
|
|||||||
|
.. _random:
|
||||||
|
|
||||||
|
Random
|
||||||
|
======
|
||||||
|
|
||||||
|
Random sampling functions in MLX use an implicit global PRNG state by default.
|
||||||
|
However, all function take an optional ``key`` keyword argument for when more
|
||||||
|
fine-grained control or explicit state management is needed.
|
||||||
|
|
||||||
|
For example, you can generate random numbers with:
|
||||||
|
|
||||||
|
.. code-block:: python
|
||||||
|
|
||||||
|
for _ in range(3):
|
||||||
|
print(mx.random.uniform())
|
||||||
|
|
||||||
|
which will print a sequence of unique pseudo random numbers. Alternatively you
|
||||||
|
can explicitly set the key:
|
||||||
|
|
||||||
|
.. code-block:: python
|
||||||
|
|
||||||
|
key = mx.random.key(0)
|
||||||
|
for _ in range(3):
|
||||||
|
print(mx.random.uniform(key=key))
|
||||||
|
|
||||||
|
which will yield the same pseudo random number at each iteration.
|
||||||
|
|
||||||
|
Following `JAX's PRNG design <https://jax.readthedocs.io/en/latest/jep/263-prng.html>`_
|
||||||
|
we use a splittable version of Threefry, which is a counter-based PRNG.
|
||||||
|
|
||||||
|
.. currentmodule:: mlx.core.random
|
||||||
|
|
||||||
|
.. autosummary::
|
||||||
|
:toctree: _autosummary
|
||||||
|
|
||||||
|
seed
|
||||||
|
key
|
||||||
|
split
|
||||||
|
bernoulli
|
||||||
|
categorical
|
||||||
|
gumbel
|
||||||
|
normal
|
||||||
|
randint
|
||||||
|
uniform
|
||||||
|
truncated_normal
|
||||||
93
docs/src/quick_start.rst
Normal file
93
docs/src/quick_start.rst
Normal file
@@ -0,0 +1,93 @@
|
|||||||
|
Quick Start Guide
|
||||||
|
=================
|
||||||
|
|
||||||
|
MLX is a NumPy-like array framework designed for efficient and flexible
|
||||||
|
machine learning on Apple silicon. The Python API closely follows NumPy with
|
||||||
|
a few exceptions. MLX also has a fully featured C++ API which closely follows
|
||||||
|
the Python API.
|
||||||
|
|
||||||
|
The main differences between MLX and NumPy are:
|
||||||
|
|
||||||
|
- **Composable function transformations**: MLX has composable function
|
||||||
|
transformations for automatic differentiation, automatic vectorization,
|
||||||
|
and computation graph optimization.
|
||||||
|
- **Lazy computation**: Computations in MLX are lazy. Arrays are only
|
||||||
|
materialized when needed.
|
||||||
|
- **Multi-device**: Operations can run on any of the suppoorted devices (CPU,
|
||||||
|
GPU, ...)
|
||||||
|
|
||||||
|
The design of MLX is strongly inspired by frameworks like `PyTorch
|
||||||
|
<https://pytorch.org/>`_, `Jax <https://github.com/google/jax>`_, and
|
||||||
|
`ArrayFire <https://arrayfire.org/>`_. A noteable difference from these
|
||||||
|
frameworks and MLX is the *unified memory model*. Arrays in MLX live in shared
|
||||||
|
memory. Operations on MLX arrays can be performed on any of the supported
|
||||||
|
device types without performing data copies. Currently supported device types
|
||||||
|
are the CPU and GPU.
|
||||||
|
|
||||||
|
Basics
|
||||||
|
------
|
||||||
|
|
||||||
|
.. currentmodule:: mlx.core
|
||||||
|
|
||||||
|
Import ``mlx.core`` and make an :class:`array`:
|
||||||
|
|
||||||
|
.. code-block:: python
|
||||||
|
|
||||||
|
>> import mlx.core as mx
|
||||||
|
>> a = mx.array([1, 2, 3, 4])
|
||||||
|
>> a.shape
|
||||||
|
[4]
|
||||||
|
>> a.dtype
|
||||||
|
int32
|
||||||
|
>> b = mx.array([1.0, 2.0, 3.0, 4.0])
|
||||||
|
>> b.dtype
|
||||||
|
float32
|
||||||
|
|
||||||
|
Operations in MLX are lazy. The outputs of MLX operations are not computed
|
||||||
|
until they are needed. To force an array to be evaluated use
|
||||||
|
:func:`eval`. Arrays will automatically be evaluated in a few cases. For
|
||||||
|
example, inspecting a scalar with :meth:`array.item`, printing an array,
|
||||||
|
or converting an array from :class:`array` to :class:`numpy.ndarray` all
|
||||||
|
automatically evaluate the array.
|
||||||
|
|
||||||
|
.. code-block:: python
|
||||||
|
|
||||||
|
>> c = a + b # c not yet evaluated
|
||||||
|
>> mx.eval(c) # evaluates c
|
||||||
|
>> c = a + b
|
||||||
|
>> print(c) # Also evaluates c
|
||||||
|
array([2, 4, 6, 8], dtype=float32)
|
||||||
|
>> c = a + b
|
||||||
|
>> import numpy as np
|
||||||
|
>> np.array(c) # Also evaluates c
|
||||||
|
array([2., 4., 6., 8.], dtype=float32)
|
||||||
|
|
||||||
|
Function and Graph Transformations
|
||||||
|
----------------------------------
|
||||||
|
|
||||||
|
MLX has standard function transformations like :func:`grad` and :func:`vmap`.
|
||||||
|
Transformations can be composed arbitrarily. For example
|
||||||
|
``grad(vmap(grad(fn)))`` (or any other composition) is allowed.
|
||||||
|
|
||||||
|
.. code-block:: python
|
||||||
|
|
||||||
|
>> x = mx.array(0.0)
|
||||||
|
>> mx.sin(x)
|
||||||
|
array(0, dtype=float32)
|
||||||
|
>> mx.grad(mx.sin)(x)
|
||||||
|
array(1, dtype=float32)
|
||||||
|
>> mx.grad(mx.grad(mx.sin))(x)
|
||||||
|
array(-0, dtype=float32)
|
||||||
|
|
||||||
|
Other gradient transformations include :func:`vjp` for vector-Jacobian products
|
||||||
|
and :func:`jvp` for Jacobian-vector products.
|
||||||
|
|
||||||
|
Use :func:`value_and_grad` to efficiently compute both a function's output and
|
||||||
|
gradient with respect to the function's input.
|
||||||
|
|
||||||
|
|
||||||
|
Devices and Streams
|
||||||
|
-------------------
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
16
docs/src/using_streams.rst
Normal file
16
docs/src/using_streams.rst
Normal file
@@ -0,0 +1,16 @@
|
|||||||
|
Using Streams
|
||||||
|
=============
|
||||||
|
|
||||||
|
.. currentmodule:: mlx.core
|
||||||
|
|
||||||
|
Specifying the :obj:`Stream`
|
||||||
|
~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||||
|
|
||||||
|
All operations (including random number generation) take an optional
|
||||||
|
keyword argument ``stream``. The ``stream`` kwarg specifies which
|
||||||
|
:obj:`Stream` the operation should run on. If the stream is unspecified then
|
||||||
|
the operation is run on the default stream of the default device:
|
||||||
|
``mx.default_stream(mx.default_device())``. The ``stream`` kwarg can also
|
||||||
|
be a :obj:`Device` (e.g. ``stream=my_device``) in which case the operation is
|
||||||
|
run on the default stream of the provided device
|
||||||
|
``mx.default_stream(my_device)``.
|
||||||
52
examples/cpp/logistic_regression.cpp
Normal file
52
examples/cpp/logistic_regression.cpp
Normal file
@@ -0,0 +1,52 @@
|
|||||||
|
#include <chrono>
|
||||||
|
#include <cmath>
|
||||||
|
#include <iostream>
|
||||||
|
|
||||||
|
#include "mlx/mlx.h"
|
||||||
|
#include "timer.h"
|
||||||
|
|
||||||
|
/**
|
||||||
|
* An example of logistic regression with MLX.
|
||||||
|
*/
|
||||||
|
using namespace mlx::core;
|
||||||
|
|
||||||
|
int main() {
|
||||||
|
int num_features = 100;
|
||||||
|
int num_examples = 1'000;
|
||||||
|
int num_iters = 10'000;
|
||||||
|
float learning_rate = 0.1;
|
||||||
|
|
||||||
|
// True parameters
|
||||||
|
auto w_star = random::normal({num_features});
|
||||||
|
|
||||||
|
// The input examples
|
||||||
|
auto X = random::normal({num_examples, num_features});
|
||||||
|
|
||||||
|
// Labels
|
||||||
|
auto y = matmul(X, w_star) > 0;
|
||||||
|
|
||||||
|
// Initialize random parameters
|
||||||
|
array w = 1e-2 * random::normal({num_features});
|
||||||
|
|
||||||
|
auto loss_fn = [&](array w) {
|
||||||
|
auto logits = matmul(X, w);
|
||||||
|
auto scale = (1.0f / num_examples);
|
||||||
|
return scale * sum(logaddexp(array(0.0f), logits) - y * logits);
|
||||||
|
};
|
||||||
|
|
||||||
|
auto grad_fn = grad(loss_fn);
|
||||||
|
|
||||||
|
auto tic = timer::time();
|
||||||
|
for (int it = 0; it < num_iters; ++it) {
|
||||||
|
auto grad = grad_fn(w);
|
||||||
|
w = w - learning_rate * grad;
|
||||||
|
eval(w);
|
||||||
|
}
|
||||||
|
auto toc = timer::time();
|
||||||
|
|
||||||
|
auto loss = loss_fn(w);
|
||||||
|
auto acc = sum((matmul(X, w) > 0) == y) / num_examples;
|
||||||
|
auto throughput = num_iters / timer::seconds(toc - tic);
|
||||||
|
std::cout << "Loss " << loss << ", Accuracy, " << acc << ", Throughput "
|
||||||
|
<< throughput << " (it/s)." << std::endl;
|
||||||
|
}
|
||||||
97
examples/cpp/tutorial.cpp
Normal file
97
examples/cpp/tutorial.cpp
Normal file
@@ -0,0 +1,97 @@
|
|||||||
|
#include <cassert>
|
||||||
|
#include <iostream>
|
||||||
|
|
||||||
|
#include "mlx/mlx.h"
|
||||||
|
|
||||||
|
using namespace mlx::core;
|
||||||
|
|
||||||
|
void array_basics() {
|
||||||
|
// Make a scalar array:
|
||||||
|
array x(1.0);
|
||||||
|
|
||||||
|
// Get the value out of it:
|
||||||
|
auto s = x.item<float>();
|
||||||
|
assert(s == 1.0);
|
||||||
|
|
||||||
|
// Scalars have a size of 1:
|
||||||
|
size_t size = x.size();
|
||||||
|
assert(size == 1);
|
||||||
|
|
||||||
|
// Scalars have 0 dimensions:
|
||||||
|
int ndim = x.ndim();
|
||||||
|
assert(ndim == 0);
|
||||||
|
|
||||||
|
// The shape should be an empty vector:
|
||||||
|
auto shape = x.shape();
|
||||||
|
assert(shape.empty());
|
||||||
|
|
||||||
|
// The datatype should be float32:
|
||||||
|
auto dtype = x.dtype();
|
||||||
|
assert(dtype == float32);
|
||||||
|
|
||||||
|
// Specify the dtype when constructing the array:
|
||||||
|
x = array(1, int32);
|
||||||
|
assert(x.dtype() == int32);
|
||||||
|
x.item<int>(); // OK
|
||||||
|
// x.item<float>(); // Undefined!
|
||||||
|
|
||||||
|
// Make a multidimensional array:
|
||||||
|
x = array({1.0f, 2.0f, 3.0f, 4.0f}, {2, 2});
|
||||||
|
// mlx is row-major by default so the first row of this array
|
||||||
|
// is [1.0, 2.0] and the second row is [3.0, 4.0]
|
||||||
|
|
||||||
|
// Make an array of shape {2, 2} filled with ones:
|
||||||
|
auto y = ones({2, 2});
|
||||||
|
|
||||||
|
// Pointwise add x and y:
|
||||||
|
auto z = add(x, y);
|
||||||
|
|
||||||
|
// Same thing:
|
||||||
|
z = x + y;
|
||||||
|
|
||||||
|
// mlx is lazy by default. At this point `z` only
|
||||||
|
// has a shape and a type but no actual data:
|
||||||
|
assert(z.dtype() == float32);
|
||||||
|
assert(z.shape(0) == 2);
|
||||||
|
assert(z.shape(1) == 2);
|
||||||
|
|
||||||
|
// To actually run the compuation you must evaluate `z`.
|
||||||
|
// Under the hood, mlx records operations in a graph.
|
||||||
|
// The variable `z` is a node in the graph which points to its operation
|
||||||
|
// and inputs. When `eval` is called on an array (or arrays), the array and
|
||||||
|
// all of its dependencies are recursively evaluated to produce the result.
|
||||||
|
// Once an array is evaluated, it has data and is detached from its inputs.
|
||||||
|
eval(z);
|
||||||
|
|
||||||
|
// Of course the array can still be an input to other operations. You can even
|
||||||
|
// call eval on the array again, this will just be a no-op:
|
||||||
|
eval(z); // no-op
|
||||||
|
|
||||||
|
// Some functions or methods on arrays implicitly evaluate them. For example
|
||||||
|
// accessing a value in an array or printing the array implicitly evaluate it:
|
||||||
|
z = ones({1});
|
||||||
|
z.item<float>(); // implicit evaluation
|
||||||
|
|
||||||
|
z = ones({2, 2});
|
||||||
|
std::cout << z << std::endl; // implicit evaluation
|
||||||
|
}
|
||||||
|
|
||||||
|
void automatic_differentiation() {
|
||||||
|
auto fn = [](array x) { return square(x); };
|
||||||
|
|
||||||
|
// Computing the derivative function of a function
|
||||||
|
auto grad_fn = grad(fn);
|
||||||
|
// Call grad_fn on the input to get the derivative
|
||||||
|
auto x = array(1.5);
|
||||||
|
auto dfdx = grad_fn(x);
|
||||||
|
// dfdx is 2 * x
|
||||||
|
|
||||||
|
// Get the second derivative by composing grad with grad
|
||||||
|
auto df2dx2 = grad(grad(fn))(x);
|
||||||
|
// df2dx2 is 2
|
||||||
|
}
|
||||||
|
|
||||||
|
int main() {
|
||||||
|
array_basics();
|
||||||
|
automatic_differentiation();
|
||||||
|
}
|
||||||
66
examples/extensions/CMakeLists.txt
Normal file
66
examples/extensions/CMakeLists.txt
Normal file
@@ -0,0 +1,66 @@
|
|||||||
|
cmake_minimum_required(VERSION 3.24)
|
||||||
|
|
||||||
|
project(mlx_sample_extensions LANGUAGES CXX)
|
||||||
|
|
||||||
|
# ----------------------------- Setup -----------------------------
|
||||||
|
set(CMAKE_CXX_STANDARD 17)
|
||||||
|
set(CMAKE_CXX_STANDARD_REQUIRED ON)
|
||||||
|
set(CMAKE_POSITION_INDEPENDENT_CODE ON)
|
||||||
|
|
||||||
|
option(BUILD_SHARED_LIBS "Build extensions as a shared library" ON)
|
||||||
|
|
||||||
|
# ----------------------------- Dependencies -----------------------------
|
||||||
|
find_package(MLX CONFIG REQUIRED)
|
||||||
|
find_package(Python COMPONENTS Interpreter Development)
|
||||||
|
find_package(pybind11 CONFIG REQUIRED)
|
||||||
|
|
||||||
|
# ----------------------------- Extensions -----------------------------
|
||||||
|
|
||||||
|
# Add library
|
||||||
|
add_library(mlx_ext)
|
||||||
|
|
||||||
|
# Add sources
|
||||||
|
target_sources(
|
||||||
|
mlx_ext
|
||||||
|
PUBLIC
|
||||||
|
${CMAKE_CURRENT_LIST_DIR}/axpby/axpby.cpp
|
||||||
|
)
|
||||||
|
|
||||||
|
# Add include headers
|
||||||
|
target_include_directories(
|
||||||
|
mlx_ext PUBLIC ${CMAKE_CURRENT_LIST_DIR}
|
||||||
|
)
|
||||||
|
|
||||||
|
# Link to mlx
|
||||||
|
target_link_libraries(mlx_ext PUBLIC mlx)
|
||||||
|
|
||||||
|
# ----------------------------- Metal -----------------------------
|
||||||
|
|
||||||
|
# Build metallib
|
||||||
|
if(MLX_BUILD_METAL)
|
||||||
|
|
||||||
|
mlx_build_metallib(
|
||||||
|
TARGET mlx_ext_metallib
|
||||||
|
TITLE mlx_ext
|
||||||
|
SOURCES ${CMAKE_CURRENT_LIST_DIR}/axpby/axpby.metal
|
||||||
|
INCLUDE_DIRS ${PROJECT_SOURCE_DIR} ${MLX_INCLUDE_DIRS}
|
||||||
|
OUTPUT_DIRECTORY ${CMAKE_LIBRARY_OUTPUT_DIRECTORY}
|
||||||
|
)
|
||||||
|
|
||||||
|
add_dependencies(
|
||||||
|
mlx_ext
|
||||||
|
mlx_ext_metallib
|
||||||
|
)
|
||||||
|
|
||||||
|
endif()
|
||||||
|
|
||||||
|
# ----------------------------- Pybind -----------------------------
|
||||||
|
pybind11_add_module(
|
||||||
|
mlx_sample_extensions
|
||||||
|
${CMAKE_CURRENT_LIST_DIR}/bindings.cpp
|
||||||
|
)
|
||||||
|
target_link_libraries(mlx_sample_extensions PRIVATE mlx_ext)
|
||||||
|
|
||||||
|
if(BUILD_SHARED_LIBS)
|
||||||
|
target_link_options(mlx_sample_extensions PRIVATE -Wl,-rpath,@loader_path)
|
||||||
|
endif()
|
||||||
84
examples/extensions/axpby/axpby.h
Normal file
84
examples/extensions/axpby/axpby.h
Normal file
@@ -0,0 +1,84 @@
|
|||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include "mlx/ops.h"
|
||||||
|
#include "mlx/primitives.h"
|
||||||
|
|
||||||
|
namespace mlx::core {
|
||||||
|
|
||||||
|
///////////////////////////////////////////////////////////////////////////////
|
||||||
|
// Operation
|
||||||
|
///////////////////////////////////////////////////////////////////////////////
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Scale and sum two vectors elementwise
|
||||||
|
* z = alpha * x + beta * y
|
||||||
|
*
|
||||||
|
* Follow numpy style broadcasting between x and y
|
||||||
|
* Inputs are upcasted to floats if needed
|
||||||
|
**/
|
||||||
|
array axpby(
|
||||||
|
const array& x, // Input array x
|
||||||
|
const array& y, // Input array y
|
||||||
|
const float alpha, // Scaling factor for x
|
||||||
|
const float beta, // Scaling factor for y
|
||||||
|
StreamOrDevice s = {} // Stream on which to schedule the operation
|
||||||
|
);
|
||||||
|
|
||||||
|
///////////////////////////////////////////////////////////////////////////////
|
||||||
|
// Primitive
|
||||||
|
///////////////////////////////////////////////////////////////////////////////
|
||||||
|
|
||||||
|
class Axpby : public Primitive {
|
||||||
|
public:
|
||||||
|
explicit Axpby(Stream stream, float alpha, float beta)
|
||||||
|
: Primitive(stream), alpha_(alpha), beta_(beta){};
|
||||||
|
|
||||||
|
/**
|
||||||
|
* A primitive must know how to evaluate itself on the CPU/GPU
|
||||||
|
* for the given inputs and populate the output array.
|
||||||
|
*
|
||||||
|
* To avoid unecessary allocations, the evaluation function
|
||||||
|
* is responsible for allocating space for the array.
|
||||||
|
*/
|
||||||
|
void eval_cpu(const std::vector<array>& inputs, array& out) override;
|
||||||
|
void eval_gpu(const std::vector<array>& inputs, array& out) override;
|
||||||
|
|
||||||
|
/** The Jacobian-vector product. */
|
||||||
|
array jvp(
|
||||||
|
const std::vector<array>& primals,
|
||||||
|
const std::vector<array>& tangents,
|
||||||
|
const std::vector<int>& argnums) override;
|
||||||
|
|
||||||
|
/** The vector-Jacobian product. */
|
||||||
|
std::vector<array> vjp(
|
||||||
|
const std::vector<array>& primals,
|
||||||
|
const array& cotan,
|
||||||
|
const std::vector<int>& argnums) override;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* The primitive must know how to vectorize itself accross
|
||||||
|
* the given axes. The output is a pair containing the array
|
||||||
|
* representing the vectorized computation and the axis which
|
||||||
|
* corresponds to the output vectorized dimension.
|
||||||
|
*/
|
||||||
|
std::pair<array, int> vmap(
|
||||||
|
const std::vector<array>& inputs,
|
||||||
|
const std::vector<int>& axes) override;
|
||||||
|
|
||||||
|
/** Print the primitive. */
|
||||||
|
void print(std::ostream& os) override {
|
||||||
|
os << "Axpby";
|
||||||
|
}
|
||||||
|
|
||||||
|
/** Equivalence check **/
|
||||||
|
bool is_equivalent(const Primitive& other) const override;
|
||||||
|
|
||||||
|
private:
|
||||||
|
float alpha_;
|
||||||
|
float beta_;
|
||||||
|
|
||||||
|
/** Fall back implementation for evaluation on CPU */
|
||||||
|
void eval(const std::vector<array>& inputs, array& out);
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace mlx::core
|
||||||
61
examples/extensions/axpby/axpby.metal
Normal file
61
examples/extensions/axpby/axpby.metal
Normal file
@@ -0,0 +1,61 @@
|
|||||||
|
#include <metal_stdlib>
|
||||||
|
|
||||||
|
#include "mlx/backend/metal/kernels/bf16.h"
|
||||||
|
#include "mlx/backend/metal/kernels/utils.h"
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
[[kernel]] void axpby_general(
|
||||||
|
device const T* x [[buffer(0)]],
|
||||||
|
device const T* y [[buffer(1)]],
|
||||||
|
device T* out [[buffer(2)]],
|
||||||
|
constant const float& alpha [[buffer(3)]],
|
||||||
|
constant const float& beta [[buffer(4)]],
|
||||||
|
constant const int* shape [[buffer(5)]],
|
||||||
|
constant const size_t* x_strides [[buffer(6)]],
|
||||||
|
constant const size_t* y_strides [[buffer(7)]],
|
||||||
|
constant const int& ndim [[buffer(8)]],
|
||||||
|
uint index [[thread_position_in_grid]]) {
|
||||||
|
auto x_offset = elem_to_loc(index, shape, x_strides, ndim);
|
||||||
|
auto y_offset = elem_to_loc(index, shape, y_strides, ndim);
|
||||||
|
out[index] =
|
||||||
|
static_cast<T>(alpha) * x[x_offset] + static_cast<T>(beta) * y[y_offset];
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
[[kernel]] void axpby_contiguous(
|
||||||
|
device const T* x [[buffer(0)]],
|
||||||
|
device const T* y [[buffer(1)]],
|
||||||
|
device T* out [[buffer(2)]],
|
||||||
|
constant const float& alpha [[buffer(3)]],
|
||||||
|
constant const float& beta [[buffer(4)]],
|
||||||
|
uint index [[thread_position_in_grid]]) {
|
||||||
|
out[index] =
|
||||||
|
static_cast<T>(alpha) * x[index] + static_cast<T>(beta) * y[index];
|
||||||
|
}
|
||||||
|
|
||||||
|
#define instantiate_axpby(type_name, type) \
|
||||||
|
template [[host_name("axpby_general_" #type_name)]] \
|
||||||
|
[[kernel]] void axpby_general<type>( \
|
||||||
|
device const type* x [[buffer(0)]], \
|
||||||
|
device const type* y [[buffer(1)]], \
|
||||||
|
device type* out [[buffer(2)]], \
|
||||||
|
constant const float& alpha [[buffer(3)]], \
|
||||||
|
constant const float& beta [[buffer(4)]], \
|
||||||
|
constant const int* shape [[buffer(5)]], \
|
||||||
|
constant const size_t* x_strides [[buffer(6)]], \
|
||||||
|
constant const size_t* y_strides [[buffer(7)]], \
|
||||||
|
constant const int& ndim [[buffer(8)]], \
|
||||||
|
uint index [[thread_position_in_grid]]); \
|
||||||
|
template [[host_name("axpby_contiguous_" #type_name)]] \
|
||||||
|
[[kernel]] void axpby_contiguous<type>( \
|
||||||
|
device const type* x [[buffer(0)]], \
|
||||||
|
device const type* y [[buffer(1)]], \
|
||||||
|
device type* out [[buffer(2)]], \
|
||||||
|
constant const float& alpha [[buffer(3)]], \
|
||||||
|
constant const float& beta [[buffer(4)]], \
|
||||||
|
uint index [[thread_position_in_grid]]);
|
||||||
|
|
||||||
|
instantiate_axpby(float32, float);
|
||||||
|
instantiate_axpby(float16, half);
|
||||||
|
instantiate_axpby(bflot16, bfloat16_t);
|
||||||
|
instantiate_axpby(complex64, complex64_t);
|
||||||
2
examples/extensions/mlx_sample_extensions/__init__.py
Normal file
2
examples/extensions/mlx_sample_extensions/__init__.py
Normal file
@@ -0,0 +1,2 @@
|
|||||||
|
import mlx.core as mx
|
||||||
|
from .mlx_sample_extensions import *
|
||||||
16
examples/extensions/setup.py
Normal file
16
examples/extensions/setup.py
Normal file
@@ -0,0 +1,16 @@
|
|||||||
|
from mlx import extension
|
||||||
|
from setuptools import setup
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
setup(
|
||||||
|
name="mlx_sample_extensions",
|
||||||
|
version="0.0.0",
|
||||||
|
description="Sample C++ and Metal extensions for MLX primitives.",
|
||||||
|
ext_modules=[extension.CMakeExtension("mlx_sample_extensions")],
|
||||||
|
cmdclass={"build_ext": extension.CMakeBuild},
|
||||||
|
packages=["mlx_sample_extensions"],
|
||||||
|
package_dir={"": "."},
|
||||||
|
package_data={"mlx_sample_extensions": ["*.so", "*.dylib", "*.metallib"]},
|
||||||
|
zip_safe=False,
|
||||||
|
python_requires=">=3.7",
|
||||||
|
)
|
||||||
43
examples/python/linear_regression.py
Normal file
43
examples/python/linear_regression.py
Normal file
@@ -0,0 +1,43 @@
|
|||||||
|
import mlx.core as mx
|
||||||
|
import time
|
||||||
|
|
||||||
|
num_features = 100
|
||||||
|
num_examples = 1_000
|
||||||
|
num_iters = 10_000
|
||||||
|
lr = 0.01
|
||||||
|
|
||||||
|
# True parameters
|
||||||
|
w_star = mx.random.normal((num_features,))
|
||||||
|
|
||||||
|
# Input examples (design matrix)
|
||||||
|
X = mx.random.normal((num_examples, num_features))
|
||||||
|
|
||||||
|
# Noisy labels
|
||||||
|
eps = 1e-2 * mx.random.normal((num_examples,))
|
||||||
|
y = X @ w_star + eps
|
||||||
|
|
||||||
|
# Initialize random parameters
|
||||||
|
w = 1e-2 * mx.random.normal((num_features,))
|
||||||
|
|
||||||
|
|
||||||
|
def loss_fn(w):
|
||||||
|
return 0.5 * mx.mean(mx.square(X @ w - y))
|
||||||
|
|
||||||
|
|
||||||
|
grad_fn = mx.grad(loss_fn)
|
||||||
|
|
||||||
|
tic = time.time()
|
||||||
|
for _ in range(num_iters):
|
||||||
|
grad = grad_fn(w)
|
||||||
|
w = w - lr * grad
|
||||||
|
mx.eval(w)
|
||||||
|
toc = time.time()
|
||||||
|
|
||||||
|
loss = loss_fn(w)
|
||||||
|
error_norm = mx.sum(mx.square(w - w_star)).item() ** 0.5
|
||||||
|
throughput = num_iters / (toc - tic)
|
||||||
|
|
||||||
|
print(
|
||||||
|
f"Loss {loss.item():.5f}, |w-w*| = {error_norm:.5f}, "
|
||||||
|
f"Throughput {throughput:.5f} (it/s)"
|
||||||
|
)
|
||||||
46
examples/python/logistic_regression.py
Normal file
46
examples/python/logistic_regression.py
Normal file
@@ -0,0 +1,46 @@
|
|||||||
|
import mlx.core as mx
|
||||||
|
import time
|
||||||
|
|
||||||
|
num_features = 100
|
||||||
|
num_examples = 1_000
|
||||||
|
num_iters = 10_000
|
||||||
|
lr = 0.1
|
||||||
|
|
||||||
|
# True parameters
|
||||||
|
w_star = mx.random.normal((num_features,))
|
||||||
|
|
||||||
|
# Input examples
|
||||||
|
X = mx.random.normal((num_examples, num_features))
|
||||||
|
|
||||||
|
# Labels
|
||||||
|
y = (X @ w_star) > 0
|
||||||
|
|
||||||
|
|
||||||
|
# Initialize random parameters
|
||||||
|
w = 1e-2 * mx.random.normal((num_features,))
|
||||||
|
|
||||||
|
|
||||||
|
def loss_fn(w):
|
||||||
|
logits = X @ w
|
||||||
|
return mx.mean(mx.logaddexp(0.0, logits) - y * logits)
|
||||||
|
|
||||||
|
|
||||||
|
grad_fn = mx.grad(loss_fn)
|
||||||
|
|
||||||
|
tic = time.time()
|
||||||
|
for _ in range(num_iters):
|
||||||
|
grad = grad_fn(w)
|
||||||
|
w = w - lr * grad
|
||||||
|
mx.eval(w)
|
||||||
|
|
||||||
|
toc = time.time()
|
||||||
|
|
||||||
|
loss = loss_fn(w)
|
||||||
|
final_preds = (X @ w) > 0
|
||||||
|
acc = mx.mean(final_preds == y)
|
||||||
|
|
||||||
|
throughput = num_iters / (toc - tic)
|
||||||
|
print(
|
||||||
|
f"Loss {loss.item():.5f}, Accuracy {acc.item():.5f} "
|
||||||
|
f"Throughput {throughput:.5f} (it/s)"
|
||||||
|
)
|
||||||
43
mlx.pc.in
Normal file
43
mlx.pc.in
Normal file
@@ -0,0 +1,43 @@
|
|||||||
|
# Find MLX
|
||||||
|
#
|
||||||
|
# Defines the following variables:
|
||||||
|
#
|
||||||
|
# MLX_FOUND : True if MLX is found
|
||||||
|
# MLX_INCLUDE_DIRS : Include directory
|
||||||
|
# MLX_LIBRARIES : Libraries to link against
|
||||||
|
# MLX_CXX_FLAGS : Additional compiler flags
|
||||||
|
# MLX_BUILD_ACCELERATE : True if MLX was built with accelerate
|
||||||
|
# MLX_BUILD_METAL : True if MLX was built with metal
|
||||||
|
|
||||||
|
@PACKAGE_INIT@
|
||||||
|
|
||||||
|
include(@PACKAGE_MLX_CMAKE_INSTALL_MODULE_DIR@/MLXTargets.cmake)
|
||||||
|
include(@PACKAGE_MLX_CMAKE_INSTALL_MODULE_DIR@/extension.cmake)
|
||||||
|
|
||||||
|
set_and_check(MLX_LIBRARY_DIRS @PACKAGE_CMAKE_INSTALL_LIBDIR@)
|
||||||
|
set_and_check(MLX_INCLUDE_DIRS @PACKAGE_CMAKE_INSTALL_INCLUDEDIR@)
|
||||||
|
set(MLX_LIBRARIES mlx)
|
||||||
|
|
||||||
|
find_library(MLX_LIBRARY mlx PATHS ${MLX_LIBRARY_DIRS})
|
||||||
|
|
||||||
|
if (@MLX_BUILD_ACCELERATE@)
|
||||||
|
set(MLX_BUILD_ACCELERATE @MLX_BUILD_ACCELERATE@)
|
||||||
|
set(MLX_CXX_FLAGS ${MLX_CXX_FLAGS} -DACCELERATE_NEW_LAPACK)
|
||||||
|
endif()
|
||||||
|
|
||||||
|
if (@MLX_BUILD_METAL@)
|
||||||
|
set(MLX_BUILD_METAL @MLX_BUILD_METAL@)
|
||||||
|
set(MLX_CXX_FLAGS ${MLX_CXX_FLAGS} -D_METAL_)
|
||||||
|
set_and_check(MLX_INCLUDE_DIRS
|
||||||
|
${MLX_INCLUDE_DIRS}
|
||||||
|
@PACKAGE_CMAKE_INSTALL_INCLUDEDIR@/metal_cpp
|
||||||
|
)
|
||||||
|
endif()
|
||||||
|
|
||||||
|
set_target_properties(mlx PROPERTIES
|
||||||
|
CXX_STANDARD 17
|
||||||
|
INTERFACE_COMPILE_OPTIONS "${MLX_CXX_FLAGS}"
|
||||||
|
)
|
||||||
|
|
||||||
|
include(FindPackageHandleStandardArgs)
|
||||||
|
find_package_handle_standard_args(MLX DEFAULT_MSG MLX_LIBRARY MLX_INCLUDE_DIRS)
|
||||||
2
mlx/3rdparty/.clang-format
vendored
Normal file
2
mlx/3rdparty/.clang-format
vendored
Normal file
@@ -0,0 +1,2 @@
|
|||||||
|
DisableFormat: true
|
||||||
|
SortIncludes: Never
|
||||||
48
mlx/allocator.cpp
Normal file
48
mlx/allocator.cpp
Normal file
@@ -0,0 +1,48 @@
|
|||||||
|
#include <cstdlib>
|
||||||
|
#include <sstream>
|
||||||
|
|
||||||
|
#include "mlx/allocator.h"
|
||||||
|
#include "mlx/scheduler.h"
|
||||||
|
|
||||||
|
namespace mlx::core::allocator {
|
||||||
|
|
||||||
|
Buffer malloc(size_t size) {
|
||||||
|
auto buffer = allocator().malloc(size);
|
||||||
|
if (size && !buffer.ptr()) {
|
||||||
|
std::ostringstream msg;
|
||||||
|
msg << "[malloc] Unable to allocate " << size << " bytes.";
|
||||||
|
throw std::runtime_error(msg.str());
|
||||||
|
}
|
||||||
|
return buffer;
|
||||||
|
}
|
||||||
|
|
||||||
|
void free(Buffer buffer) {
|
||||||
|
return allocator().free(buffer);
|
||||||
|
}
|
||||||
|
|
||||||
|
Buffer CommonAllocator::malloc(size_t size) {
|
||||||
|
return Buffer{std::malloc(size)};
|
||||||
|
}
|
||||||
|
|
||||||
|
void CommonAllocator::free(Buffer buffer) {
|
||||||
|
std::free(buffer.raw_ptr());
|
||||||
|
}
|
||||||
|
|
||||||
|
Buffer malloc_or_wait(size_t size) {
|
||||||
|
auto buffer = allocator().malloc(size);
|
||||||
|
|
||||||
|
while (size && !buffer.ptr() && scheduler::n_active_tasks() > 0) {
|
||||||
|
scheduler::wait_for_one();
|
||||||
|
buffer = allocator().malloc(size);
|
||||||
|
}
|
||||||
|
|
||||||
|
if (size && !buffer.ptr()) {
|
||||||
|
std::ostringstream msg;
|
||||||
|
msg << "[malloc_or_wait] Unable to allocate " << size << " bytes.";
|
||||||
|
throw std::runtime_error(msg.str());
|
||||||
|
}
|
||||||
|
|
||||||
|
return buffer;
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace mlx::core::allocator
|
||||||
436
mlx/array.h
Normal file
436
mlx/array.h
Normal file
@@ -0,0 +1,436 @@
|
|||||||
|
#pragma once
|
||||||
|
#include <algorithm>
|
||||||
|
#include <cstdint>
|
||||||
|
#include <functional>
|
||||||
|
#include <memory>
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
|
#include "mlx/allocator.h"
|
||||||
|
#include "mlx/dtype.h"
|
||||||
|
|
||||||
|
namespace mlx::core {
|
||||||
|
|
||||||
|
// Forward declaration
|
||||||
|
class Primitive;
|
||||||
|
using deleter_t = std::function<void(allocator::Buffer)>;
|
||||||
|
|
||||||
|
class array {
|
||||||
|
/* An array is really a node in a graph. It contains a shared ArrayDesc
|
||||||
|
* object */
|
||||||
|
|
||||||
|
public:
|
||||||
|
/** Construct a scalar array with zero dimensions. */
|
||||||
|
template <typename T>
|
||||||
|
explicit array(T val, Dtype dtype = TypeToDtype<T>());
|
||||||
|
|
||||||
|
/* Special case since std::complex can't be implicitly converted to other
|
||||||
|
* types. */
|
||||||
|
explicit array(const std::complex<float>& val, Dtype dtype = complex64);
|
||||||
|
|
||||||
|
template <typename It>
|
||||||
|
array(
|
||||||
|
It data,
|
||||||
|
const std::vector<int>& shape,
|
||||||
|
Dtype dtype =
|
||||||
|
TypeToDtype<typename std::iterator_traits<It>::value_type>());
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
array(std::initializer_list<T> data, Dtype dtype = TypeToDtype<T>());
|
||||||
|
|
||||||
|
/* Special case so empty lists default to float32. */
|
||||||
|
array(std::initializer_list<float> data);
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
array(
|
||||||
|
std::initializer_list<T> data,
|
||||||
|
const std::vector<int>& shape,
|
||||||
|
Dtype dtype = TypeToDtype<T>());
|
||||||
|
|
||||||
|
/* Build an array from a buffer */
|
||||||
|
array(
|
||||||
|
allocator::Buffer data,
|
||||||
|
const std::vector<int>& shape,
|
||||||
|
Dtype dtype,
|
||||||
|
deleter_t deleter = allocator::free);
|
||||||
|
|
||||||
|
/** Assignment to rvalue does not compile. */
|
||||||
|
array& operator=(const array& other) && = delete;
|
||||||
|
array& operator=(array&& other) && = delete;
|
||||||
|
|
||||||
|
/** Default copy and move constructors otherwise. */
|
||||||
|
array& operator=(array&& other) & = default;
|
||||||
|
array(const array& other) = default;
|
||||||
|
array(array&& other) = default;
|
||||||
|
|
||||||
|
array& operator=(const array& other) & {
|
||||||
|
if (this->id() != other.id()) {
|
||||||
|
this->array_desc_ = other.array_desc_;
|
||||||
|
}
|
||||||
|
return *this;
|
||||||
|
};
|
||||||
|
|
||||||
|
/** The size of the array's datatype in bytes. */
|
||||||
|
size_t itemsize() const {
|
||||||
|
return size_of(dtype());
|
||||||
|
};
|
||||||
|
|
||||||
|
/** The number of elements in the array. */
|
||||||
|
size_t size() const {
|
||||||
|
return array_desc_->size;
|
||||||
|
};
|
||||||
|
|
||||||
|
/** The number of bytes in the array. */
|
||||||
|
size_t nbytes() const {
|
||||||
|
return size() * itemsize();
|
||||||
|
};
|
||||||
|
|
||||||
|
/** The number of dimensions of the array. */
|
||||||
|
size_t ndim() const {
|
||||||
|
return array_desc_->shape.size();
|
||||||
|
};
|
||||||
|
|
||||||
|
/** The shape of the array as a vector of integers. */
|
||||||
|
const std::vector<int>& shape() const {
|
||||||
|
return array_desc_->shape;
|
||||||
|
};
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Get the size of the corresponding dimension.
|
||||||
|
*
|
||||||
|
* This function supports negative indexing and provides
|
||||||
|
* bounds checking. */
|
||||||
|
int shape(int dim) const {
|
||||||
|
return shape().at(dim < 0 ? dim + ndim() : dim);
|
||||||
|
};
|
||||||
|
|
||||||
|
/** The strides of the array. */
|
||||||
|
const std::vector<size_t>& strides() const {
|
||||||
|
return array_desc_->strides;
|
||||||
|
};
|
||||||
|
|
||||||
|
/** Get the arrays data type. */
|
||||||
|
Dtype dtype() const {
|
||||||
|
return array_desc_->dtype;
|
||||||
|
};
|
||||||
|
|
||||||
|
/** Evaluate the array. */
|
||||||
|
void eval(bool retain_graph = false);
|
||||||
|
|
||||||
|
/** Get the value from a scalar array. */
|
||||||
|
template <typename T>
|
||||||
|
T item(bool retain_graph = false);
|
||||||
|
|
||||||
|
struct ArrayIterator {
|
||||||
|
using iterator_category = std::random_access_iterator_tag;
|
||||||
|
using difference_type = size_t;
|
||||||
|
using value_type = const array;
|
||||||
|
using reference = value_type;
|
||||||
|
|
||||||
|
explicit ArrayIterator(const array& arr, int idx = 0) : arr(arr), idx(idx) {
|
||||||
|
if (arr.ndim() == 0) {
|
||||||
|
throw std::invalid_argument("Cannot iterate over 0-d array.");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
reference operator*() const;
|
||||||
|
|
||||||
|
ArrayIterator& operator+(difference_type diff) {
|
||||||
|
idx += diff;
|
||||||
|
return *this;
|
||||||
|
}
|
||||||
|
|
||||||
|
ArrayIterator& operator++() {
|
||||||
|
idx++;
|
||||||
|
return *this;
|
||||||
|
}
|
||||||
|
|
||||||
|
friend bool operator==(const ArrayIterator& a, const ArrayIterator& b) {
|
||||||
|
return a.arr.id() == b.arr.id() && a.idx == b.idx;
|
||||||
|
};
|
||||||
|
friend bool operator!=(const ArrayIterator& a, const ArrayIterator& b) {
|
||||||
|
return !(a == b);
|
||||||
|
};
|
||||||
|
|
||||||
|
private:
|
||||||
|
int idx;
|
||||||
|
const array& arr;
|
||||||
|
};
|
||||||
|
|
||||||
|
ArrayIterator begin() const {
|
||||||
|
return ArrayIterator(*this);
|
||||||
|
}
|
||||||
|
ArrayIterator end() const {
|
||||||
|
return ArrayIterator(*this, shape(0));
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* The following methods should be used with caution.
|
||||||
|
* They are intended for use by the backend implementation and the
|
||||||
|
* API may change.
|
||||||
|
*/
|
||||||
|
|
||||||
|
array(
|
||||||
|
const std::vector<int>& shape,
|
||||||
|
Dtype dtype,
|
||||||
|
std::unique_ptr<Primitive> primitive,
|
||||||
|
const std::vector<array>& inputs);
|
||||||
|
|
||||||
|
/** A unique identifier for an array. */
|
||||||
|
std::uintptr_t id() const {
|
||||||
|
return reinterpret_cast<std::uintptr_t>(array_desc_.get());
|
||||||
|
}
|
||||||
|
|
||||||
|
struct Data {
|
||||||
|
allocator::Buffer buffer;
|
||||||
|
deleter_t d;
|
||||||
|
Data(allocator::Buffer buffer, deleter_t d = allocator::free)
|
||||||
|
: buffer(buffer), d(d){};
|
||||||
|
// Not copyable
|
||||||
|
Data(const Data& d) = delete;
|
||||||
|
Data& operator=(const Data& d) = delete;
|
||||||
|
~Data() {
|
||||||
|
d(buffer);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
struct Flags {
|
||||||
|
// True if there are no gaps in the underlying data. Each item
|
||||||
|
// in the underlying data buffer belongs to at least one index.
|
||||||
|
bool contiguous : 1;
|
||||||
|
|
||||||
|
bool row_contiguous : 1;
|
||||||
|
bool col_contiguous : 1;
|
||||||
|
};
|
||||||
|
|
||||||
|
/** The array's primitive. */
|
||||||
|
Primitive& primitive() const {
|
||||||
|
return *(array_desc_->primitive);
|
||||||
|
};
|
||||||
|
|
||||||
|
/** Check if the array has an attached primitive or is a leaf node. */
|
||||||
|
bool has_primitive() const {
|
||||||
|
return array_desc_->primitive != nullptr;
|
||||||
|
};
|
||||||
|
|
||||||
|
/** The array's inputs. */
|
||||||
|
const std::vector<array>& inputs() const {
|
||||||
|
return array_desc_->inputs;
|
||||||
|
};
|
||||||
|
|
||||||
|
/** A non-const reference to the array's inputs so that they can be used to
|
||||||
|
* edit the graph. */
|
||||||
|
std::vector<array>& editable_inputs() {
|
||||||
|
return array_desc_->inputs;
|
||||||
|
}
|
||||||
|
|
||||||
|
/** Detach the array from the graph. */
|
||||||
|
void detach();
|
||||||
|
|
||||||
|
/** Get the Flags bit-field. */
|
||||||
|
const Flags& flags() const {
|
||||||
|
return array_desc_->flags;
|
||||||
|
};
|
||||||
|
|
||||||
|
/** The size (in elements) of the underlying buffer the array points to. */
|
||||||
|
size_t data_size() const {
|
||||||
|
return array_desc_->data_size;
|
||||||
|
};
|
||||||
|
|
||||||
|
allocator::Buffer& buffer() {
|
||||||
|
return array_desc_->data->buffer;
|
||||||
|
};
|
||||||
|
const allocator::Buffer& buffer() const {
|
||||||
|
return array_desc_->data->buffer;
|
||||||
|
};
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
T* data() {
|
||||||
|
return static_cast<T*>(array_desc_->data_ptr);
|
||||||
|
};
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
const T* data() const {
|
||||||
|
return static_cast<T*>(array_desc_->data_ptr);
|
||||||
|
};
|
||||||
|
|
||||||
|
// Check if the array has been evaluated
|
||||||
|
bool is_evaled() const {
|
||||||
|
return array_desc_->data != nullptr;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Mark the array as a tracer array (true) or not.
|
||||||
|
void set_tracer(bool is_tracer) {
|
||||||
|
array_desc_->is_tracer = is_tracer;
|
||||||
|
}
|
||||||
|
// Check if the array is a tracer array
|
||||||
|
bool is_tracer() const {
|
||||||
|
return array_desc_->is_tracer;
|
||||||
|
}
|
||||||
|
|
||||||
|
void set_data(allocator::Buffer buffer, deleter_t d = allocator::free);
|
||||||
|
|
||||||
|
void set_data(
|
||||||
|
allocator::Buffer buffer,
|
||||||
|
size_t data_size,
|
||||||
|
std::vector<size_t> strides,
|
||||||
|
Flags flags,
|
||||||
|
deleter_t d = allocator::free);
|
||||||
|
|
||||||
|
void copy_shared_buffer(
|
||||||
|
const array& other,
|
||||||
|
const std::vector<size_t>& strides,
|
||||||
|
Flags flags,
|
||||||
|
size_t data_size,
|
||||||
|
size_t offset = 0);
|
||||||
|
|
||||||
|
void copy_shared_buffer(const array& other);
|
||||||
|
|
||||||
|
void overwrite_descriptor(const array& other) {
|
||||||
|
array_desc_ = other.array_desc_;
|
||||||
|
}
|
||||||
|
|
||||||
|
private:
|
||||||
|
// Initialize the arrays data
|
||||||
|
template <typename It>
|
||||||
|
void init(const It src);
|
||||||
|
|
||||||
|
struct ArrayDesc {
|
||||||
|
std::vector<int> shape;
|
||||||
|
std::vector<size_t> strides;
|
||||||
|
size_t size;
|
||||||
|
Dtype dtype;
|
||||||
|
std::unique_ptr<Primitive> primitive{nullptr};
|
||||||
|
|
||||||
|
// Indicates an array is being used in a graph transform
|
||||||
|
// and should not be detached from the graph
|
||||||
|
bool is_tracer{false};
|
||||||
|
|
||||||
|
// This is a shared pointer so that *different* arrays
|
||||||
|
// can share the underlying data buffer.
|
||||||
|
std::shared_ptr<Data> data{nullptr};
|
||||||
|
|
||||||
|
// Properly offset data pointer
|
||||||
|
void* data_ptr{nullptr};
|
||||||
|
|
||||||
|
// The size in elements of the data buffer the array accesses
|
||||||
|
// This can be different than the actual size of the array if it
|
||||||
|
// has been broadcast or irregularly strided.
|
||||||
|
size_t data_size;
|
||||||
|
|
||||||
|
// Contains useful meta data about the array
|
||||||
|
Flags flags;
|
||||||
|
|
||||||
|
std::vector<array> inputs;
|
||||||
|
|
||||||
|
explicit ArrayDesc(const std::vector<int>& shape, Dtype dtype);
|
||||||
|
|
||||||
|
explicit ArrayDesc(
|
||||||
|
const std::vector<int>& shape,
|
||||||
|
Dtype dtype,
|
||||||
|
std::unique_ptr<Primitive> primitive,
|
||||||
|
const std::vector<array>& inputs);
|
||||||
|
|
||||||
|
~ArrayDesc();
|
||||||
|
};
|
||||||
|
|
||||||
|
// The ArrayDesc contains the details of the materialized array including the
|
||||||
|
// shape, strides, the data type. It also includes
|
||||||
|
// the primitive which knows how to compute the array's data from its inputs
|
||||||
|
// and a the list of array's inputs for the primitive.
|
||||||
|
std::shared_ptr<ArrayDesc> array_desc_{nullptr};
|
||||||
|
};
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
array::array(T val, Dtype dtype /* = TypeToDtype<T>() */)
|
||||||
|
: array_desc_(std::make_shared<ArrayDesc>(std::vector<int>{}, dtype)) {
|
||||||
|
init(&val);
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename It>
|
||||||
|
array::array(
|
||||||
|
It data,
|
||||||
|
const std::vector<int>& shape,
|
||||||
|
Dtype dtype /* = TypeToDtype<typename std::iterator_traits<It>::value_type>() */) :
|
||||||
|
array_desc_(std::make_shared<ArrayDesc>(shape, dtype)) {
|
||||||
|
init(data);
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
array::array(
|
||||||
|
std::initializer_list<T> data,
|
||||||
|
Dtype dtype /* = TypeToDtype<T>() */)
|
||||||
|
: array_desc_(std::make_shared<ArrayDesc>(
|
||||||
|
std::vector<int>{static_cast<int>(data.size())},
|
||||||
|
dtype)) {
|
||||||
|
init(data.begin());
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
array::array(
|
||||||
|
std::initializer_list<T> data,
|
||||||
|
const std::vector<int>& shape,
|
||||||
|
Dtype dtype /* = TypeToDtype<T>() */)
|
||||||
|
: array_desc_(std::make_shared<ArrayDesc>(shape, dtype)) {
|
||||||
|
if (data.size() != size()) {
|
||||||
|
throw std::invalid_argument(
|
||||||
|
"Data size and provided shape mismatch in array construction.");
|
||||||
|
}
|
||||||
|
init(data.begin());
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
T array::item(bool retain_graph /* = false */) {
|
||||||
|
if (size() != 1) {
|
||||||
|
throw std::invalid_argument("item can only be called on arrays of size 1.");
|
||||||
|
}
|
||||||
|
eval(retain_graph);
|
||||||
|
return *data<T>();
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename It>
|
||||||
|
void array::init(It src) {
|
||||||
|
set_data(allocator::malloc(size() * size_of(dtype())));
|
||||||
|
switch (dtype()) {
|
||||||
|
case bool_:
|
||||||
|
std::copy(src, src + size(), data<bool>());
|
||||||
|
break;
|
||||||
|
case uint8:
|
||||||
|
std::copy(src, src + size(), data<uint8_t>());
|
||||||
|
break;
|
||||||
|
case uint16:
|
||||||
|
std::copy(src, src + size(), data<uint16_t>());
|
||||||
|
break;
|
||||||
|
case uint32:
|
||||||
|
std::copy(src, src + size(), data<uint32_t>());
|
||||||
|
break;
|
||||||
|
case uint64:
|
||||||
|
std::copy(src, src + size(), data<uint64_t>());
|
||||||
|
break;
|
||||||
|
case int8:
|
||||||
|
std::copy(src, src + size(), data<int8_t>());
|
||||||
|
break;
|
||||||
|
case int16:
|
||||||
|
std::copy(src, src + size(), data<int16_t>());
|
||||||
|
break;
|
||||||
|
case int32:
|
||||||
|
std::copy(src, src + size(), data<int32_t>());
|
||||||
|
break;
|
||||||
|
case int64:
|
||||||
|
std::copy(src, src + size(), data<int64_t>());
|
||||||
|
break;
|
||||||
|
case float16:
|
||||||
|
std::copy(src, src + size(), data<float16_t>());
|
||||||
|
break;
|
||||||
|
case float32:
|
||||||
|
std::copy(src, src + size(), data<float>());
|
||||||
|
break;
|
||||||
|
case bfloat16:
|
||||||
|
std::copy(src, src + size(), data<bfloat16_t>());
|
||||||
|
break;
|
||||||
|
case complex64:
|
||||||
|
std::copy(src, src + size(), data<complex64_t>());
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace mlx::core
|
||||||
9
mlx/backend/accelerate/CMakeLists.txt
Normal file
9
mlx/backend/accelerate/CMakeLists.txt
Normal file
@@ -0,0 +1,9 @@
|
|||||||
|
target_sources(
|
||||||
|
mlx
|
||||||
|
PRIVATE
|
||||||
|
${CMAKE_CURRENT_SOURCE_DIR}/conv.cpp
|
||||||
|
${CMAKE_CURRENT_SOURCE_DIR}/matmul.cpp
|
||||||
|
${CMAKE_CURRENT_SOURCE_DIR}/primitives.cpp
|
||||||
|
${CMAKE_CURRENT_SOURCE_DIR}/reduce.cpp
|
||||||
|
${CMAKE_CURRENT_SOURCE_DIR}/softmax.cpp
|
||||||
|
)
|
||||||
167
mlx/backend/accelerate/matmul.cpp
Normal file
167
mlx/backend/accelerate/matmul.cpp
Normal file
@@ -0,0 +1,167 @@
|
|||||||
|
#include <cassert>
|
||||||
|
|
||||||
|
#include <vecLib/BNNS/bnns.h>
|
||||||
|
#include <vecLib/cblas_new.h>
|
||||||
|
|
||||||
|
#include "mlx/backend/accelerate/utils.h"
|
||||||
|
#include "mlx/backend/common/copy.h"
|
||||||
|
#include "mlx/primitives.h"
|
||||||
|
#include "mlx/utils.h"
|
||||||
|
|
||||||
|
namespace mlx::core {
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
|
||||||
|
std::tuple<bool, size_t, array> check_transpose(const array& arr) {
|
||||||
|
auto stx = arr.strides()[arr.ndim() - 2];
|
||||||
|
auto sty = arr.strides()[arr.ndim() - 1];
|
||||||
|
if (stx == arr.shape(-1) && sty == 1) {
|
||||||
|
return std::make_tuple(false, stx, arr);
|
||||||
|
} else if (stx == 1 && sty == arr.shape(-2)) {
|
||||||
|
return std::make_tuple(true, sty, arr);
|
||||||
|
} else {
|
||||||
|
array arr_copy(arr.shape(), arr.dtype(), nullptr, {});
|
||||||
|
copy(arr, arr_copy, CopyType::General);
|
||||||
|
size_t stx = arr.shape(-1);
|
||||||
|
return std::make_tuple(false, stx, arr_copy);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
inline void matmul_cblas(const array& a_pre, const array& b_pre, array& out) {
|
||||||
|
if (out.dtype() != float32) {
|
||||||
|
throw std::runtime_error(
|
||||||
|
"[matmul_cblas] on CPU currently only supports float32");
|
||||||
|
}
|
||||||
|
out.set_data(allocator::malloc_or_wait(out.nbytes()));
|
||||||
|
|
||||||
|
auto [a_transposed, lda, a] = check_transpose(a_pre);
|
||||||
|
auto [b_transposed, ldb, b] = check_transpose(b_pre);
|
||||||
|
size_t M = a.shape(-2);
|
||||||
|
size_t N = b.shape(-1);
|
||||||
|
size_t K = a.shape(-1);
|
||||||
|
|
||||||
|
for (int i = 0; i < (a.size() / (M * K)); ++i) {
|
||||||
|
cblas_sgemm(
|
||||||
|
CblasRowMajor,
|
||||||
|
a_transposed ? CblasTrans : CblasNoTrans, // transA
|
||||||
|
b_transposed ? CblasTrans : CblasNoTrans, // transB
|
||||||
|
M,
|
||||||
|
N,
|
||||||
|
K,
|
||||||
|
1.0f, // alpha
|
||||||
|
a.data<float>() + elem_to_loc(M * K * i, a.shape(), a.strides()),
|
||||||
|
lda,
|
||||||
|
b.data<float>() + elem_to_loc(K * N * i, b.shape(), b.strides()),
|
||||||
|
ldb,
|
||||||
|
0.0f, // beta
|
||||||
|
out.data<float>() + M * N * i,
|
||||||
|
out.shape(-1) // ldc
|
||||||
|
);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
inline void matmul_bnns(const array& a_pre, const array& b_pre, array& out) {
|
||||||
|
// TODO: Update to utilize BNNS broadcasting
|
||||||
|
out.set_data(allocator::malloc_or_wait(out.nbytes()));
|
||||||
|
|
||||||
|
auto [a_transposed, lda, a] = check_transpose(a_pre);
|
||||||
|
auto [b_transposed, ldb, b] = check_transpose(b_pre);
|
||||||
|
size_t M = a.shape(-2);
|
||||||
|
size_t N = b.shape(-1);
|
||||||
|
size_t K = a.shape(-1);
|
||||||
|
|
||||||
|
BNNSDataType bnns_dtype = to_bnns_dtype(out.dtype());
|
||||||
|
|
||||||
|
const BNNSLayerParametersBroadcastMatMul gemm_params{
|
||||||
|
/* float alpha = */ 1.0,
|
||||||
|
/* float beta = */ 0.0,
|
||||||
|
/* bool transA = */ a_transposed,
|
||||||
|
/* bool transB = */ b_transposed,
|
||||||
|
/* bool quadratic = */ false,
|
||||||
|
/* bool a_is_weights = */ false,
|
||||||
|
/* bool b_is_weights = */ false,
|
||||||
|
/* BNNSNDArrayDescriptor iA_desc = */
|
||||||
|
BNNSNDArrayDescriptor{
|
||||||
|
/* BNNSNDArrayFlags flags = */ BNNSNDArrayFlagBackpropSet,
|
||||||
|
/* BNNSDataLayout layout = */ BNNSDataLayoutRowMajorMatrix,
|
||||||
|
|
||||||
|
/* size_t size[BNNS_MAX_TENSOR_DIMENSION] = */
|
||||||
|
{lda, (M * K) / lda, 0, 0, 0, 0, 0, 0},
|
||||||
|
/* size_t stride[BNNS_MAX_TENSOR_DIMENSION] = */
|
||||||
|
{1, lda, 0, 0, 0, 0, 0, 0},
|
||||||
|
|
||||||
|
/* void * _Nullable data = */ nullptr,
|
||||||
|
/* BNNSDataType data_type = */ bnns_dtype,
|
||||||
|
|
||||||
|
/* void * _Nullable table_data = */ nullptr,
|
||||||
|
/* BNNSDataType table_data_type = */ bnns_dtype,
|
||||||
|
|
||||||
|
/* float data_scale = */ 1.0,
|
||||||
|
/* float data_bias = */ 0.0,
|
||||||
|
},
|
||||||
|
/* BNNSNDArrayDescriptor iB_desc = */
|
||||||
|
BNNSNDArrayDescriptor{
|
||||||
|
/* BNNSNDArrayFlags flags = */ BNNSNDArrayFlagBackpropSet,
|
||||||
|
/* BNNSDataLayout layout = */ BNNSDataLayoutRowMajorMatrix,
|
||||||
|
|
||||||
|
/* size_t size[BNNS_MAX_TENSOR_DIMENSION] = */
|
||||||
|
{ldb, (K * N) / ldb, 0, 0, 0, 0, 0, 0},
|
||||||
|
/* size_t stride[BNNS_MAX_TENSOR_DIMENSION] = */
|
||||||
|
{1, ldb, 0, 0, 0, 0, 0, 0},
|
||||||
|
|
||||||
|
/* void * _Nullable data = */ nullptr,
|
||||||
|
/* BNNSDataType data_type = */ bnns_dtype,
|
||||||
|
|
||||||
|
/* void * _Nullable table_data = */ nullptr,
|
||||||
|
/* BNNSDataType table_data_type = */ bnns_dtype,
|
||||||
|
|
||||||
|
/* float data_scale = */ 1.0,
|
||||||
|
/* float data_bias = */ 0.0,
|
||||||
|
},
|
||||||
|
/* BNNSNDArrayDescriptor o_desc = */
|
||||||
|
BNNSNDArrayDescriptor{
|
||||||
|
/* BNNSNDArrayFlags flags = */ BNNSNDArrayFlagBackpropSet,
|
||||||
|
/* BNNSDataLayout layout = */ BNNSDataLayoutRowMajorMatrix,
|
||||||
|
|
||||||
|
/* size_t size[BNNS_MAX_TENSOR_DIMENSION] = */
|
||||||
|
{N, M, 0, 0, 0, 0, 0, 0},
|
||||||
|
/* size_t stride[BNNS_MAX_TENSOR_DIMENSION] = */
|
||||||
|
{1, N, 0, 0, 0, 0, 0, 0},
|
||||||
|
|
||||||
|
/* void * _Nullable data = */ nullptr,
|
||||||
|
/* BNNSDataType data_type = */ bnns_dtype,
|
||||||
|
|
||||||
|
/* void * _Nullable table_data = */ nullptr,
|
||||||
|
/* BNNSDataType table_data_type = */ bnns_dtype,
|
||||||
|
|
||||||
|
/* float data_scale = */ 1.0,
|
||||||
|
/* float data_bias = */ 0.0,
|
||||||
|
},
|
||||||
|
};
|
||||||
|
|
||||||
|
auto bnns_filter =
|
||||||
|
BNNSFilterCreateLayerBroadcastMatMul(&gemm_params, nullptr);
|
||||||
|
|
||||||
|
for (int i = 0; i < (a.size() / (M * K)); ++i) {
|
||||||
|
BNNSFilterApplyTwoInput(
|
||||||
|
bnns_filter,
|
||||||
|
a.data<uint8_t>() +
|
||||||
|
elem_to_loc(M * K * i, a.shape(), a.strides()) * a.itemsize(),
|
||||||
|
b.data<uint8_t>() +
|
||||||
|
elem_to_loc(K * N * i, b.shape(), b.strides()) * b.itemsize(),
|
||||||
|
out.data<uint8_t>() + M * N * i * out.itemsize());
|
||||||
|
}
|
||||||
|
|
||||||
|
BNNSFilterDestroy(bnns_filter);
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace
|
||||||
|
|
||||||
|
void Matmul::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||||
|
if (out.dtype() == float32) {
|
||||||
|
return matmul_cblas(inputs[0], inputs[1], out);
|
||||||
|
}
|
||||||
|
return matmul_bnns(inputs[0], inputs[1], out);
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace mlx::core
|
||||||
672
mlx/backend/accelerate/primitives.cpp
Normal file
672
mlx/backend/accelerate/primitives.cpp
Normal file
@@ -0,0 +1,672 @@
|
|||||||
|
#include <cassert>
|
||||||
|
#include <cmath>
|
||||||
|
|
||||||
|
#include <vecLib/vDSP.h>
|
||||||
|
#include <vecLib/vForce.h>
|
||||||
|
|
||||||
|
#include "mlx/allocator.h"
|
||||||
|
#include "mlx/backend/common/binary.h"
|
||||||
|
#include "mlx/backend/common/copy.h"
|
||||||
|
#include "mlx/backend/common/unary.h"
|
||||||
|
#include "mlx/primitives.h"
|
||||||
|
|
||||||
|
#define DEFAULT(primitive) \
|
||||||
|
void primitive::eval_cpu(const std::vector<array>& inputs, array& out) { \
|
||||||
|
primitive::eval(inputs, out); \
|
||||||
|
}
|
||||||
|
|
||||||
|
namespace mlx::core {
|
||||||
|
|
||||||
|
// Use the default implementation for the following primitives
|
||||||
|
DEFAULT(Arange)
|
||||||
|
DEFAULT(ArgPartition)
|
||||||
|
DEFAULT(ArgReduce)
|
||||||
|
DEFAULT(ArgSort)
|
||||||
|
DEFAULT(AsStrided)
|
||||||
|
DEFAULT(Broadcast)
|
||||||
|
DEFAULT(Concatenate)
|
||||||
|
DEFAULT(Copy)
|
||||||
|
DEFAULT(Equal)
|
||||||
|
DEFAULT(Erf)
|
||||||
|
DEFAULT(ErfInv)
|
||||||
|
DEFAULT(FFT)
|
||||||
|
DEFAULT(Gather)
|
||||||
|
DEFAULT(Greater)
|
||||||
|
DEFAULT(GreaterEqual)
|
||||||
|
DEFAULT(Less)
|
||||||
|
DEFAULT(LessEqual)
|
||||||
|
DEFAULT(Load)
|
||||||
|
DEFAULT(LogicalNot)
|
||||||
|
DEFAULT(LogAddExp)
|
||||||
|
DEFAULT(NotEqual)
|
||||||
|
DEFAULT(Pad)
|
||||||
|
DEFAULT(Partition)
|
||||||
|
DEFAULT(RandomBits)
|
||||||
|
DEFAULT(Reshape)
|
||||||
|
DEFAULT(Scatter)
|
||||||
|
DEFAULT(Sigmoid)
|
||||||
|
DEFAULT(Sign)
|
||||||
|
DEFAULT(Slice)
|
||||||
|
DEFAULT(Sort)
|
||||||
|
DEFAULT(StopGradient)
|
||||||
|
DEFAULT(Transpose)
|
||||||
|
|
||||||
|
void Abs::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||||
|
assert(inputs.size() == 1);
|
||||||
|
auto& in = inputs[0];
|
||||||
|
if (in.dtype() == float32 && in.flags().contiguous) {
|
||||||
|
auto size = in.data_size();
|
||||||
|
out.set_data(
|
||||||
|
allocator::malloc_or_wait(size * out.itemsize()),
|
||||||
|
size,
|
||||||
|
in.strides(),
|
||||||
|
in.flags());
|
||||||
|
vDSP_vabs(in.data<float>(), 1, out.data<float>(), 1, size);
|
||||||
|
} else if (in.dtype() == int32 && in.flags().contiguous) {
|
||||||
|
auto size = in.data_size();
|
||||||
|
out.set_data(
|
||||||
|
allocator::malloc_or_wait(size * out.itemsize()),
|
||||||
|
size,
|
||||||
|
in.strides(),
|
||||||
|
in.flags());
|
||||||
|
vDSP_vabsi(in.data<int>(), 1, out.data<int>(), 1, size);
|
||||||
|
} else if (is_unsigned(in.dtype())) {
|
||||||
|
// No-op for unsigned types
|
||||||
|
out.copy_shared_buffer(in);
|
||||||
|
} else {
|
||||||
|
unary(in, out, AbsOp());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void Add::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||||
|
assert(inputs.size() == 2);
|
||||||
|
auto& a = inputs[0];
|
||||||
|
auto& b = inputs[1];
|
||||||
|
|
||||||
|
if (a.dtype() == float32) {
|
||||||
|
binary(
|
||||||
|
a,
|
||||||
|
b,
|
||||||
|
out,
|
||||||
|
[](auto x, auto y) { return x + y; },
|
||||||
|
[](const auto* s, const auto* vec, auto* o, auto n) {
|
||||||
|
vDSP_vsadd((const float*)vec, 1, (const float*)s, (float*)o, 1, n);
|
||||||
|
},
|
||||||
|
[](const auto* vec, const auto* s, auto* o, auto n) {
|
||||||
|
vDSP_vsadd((const float*)vec, 1, (const float*)s, (float*)o, 1, n);
|
||||||
|
},
|
||||||
|
[](const auto* a, const auto* b, auto* o, auto n) {
|
||||||
|
vDSP_vadd((const float*)a, 1, (const float*)b, 1, (float*)o, 1, n);
|
||||||
|
});
|
||||||
|
} else if (a.dtype() == int32) {
|
||||||
|
binary(
|
||||||
|
a,
|
||||||
|
b,
|
||||||
|
out,
|
||||||
|
[](auto x, auto y) { return x + y; },
|
||||||
|
[](const auto* s, const auto* vec, auto* o, auto n) {
|
||||||
|
vDSP_vsaddi((const int*)vec, 1, (const int*)s, (int*)o, 1, n);
|
||||||
|
},
|
||||||
|
[](const auto* vec, const auto* s, auto* o, auto n) {
|
||||||
|
vDSP_vsaddi((const int*)vec, 1, (const int*)s, (int*)o, 1, n);
|
||||||
|
},
|
||||||
|
[](const auto* a, const auto* b, auto* o, auto n) {
|
||||||
|
vDSP_vaddi((const int*)a, 1, (const int*)b, 1, (int*)o, 1, n);
|
||||||
|
});
|
||||||
|
} else {
|
||||||
|
binary(a, b, out, [](auto x, auto y) { return x + y; });
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void ArcCos::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||||
|
assert(inputs.size() == 1);
|
||||||
|
const auto& in = inputs[0];
|
||||||
|
if (out.dtype() == float32 && in.flags().contiguous) {
|
||||||
|
int size = in.data_size();
|
||||||
|
out.set_data(
|
||||||
|
allocator::malloc_or_wait(size * out.itemsize()),
|
||||||
|
size,
|
||||||
|
in.strides(),
|
||||||
|
in.flags());
|
||||||
|
vvacosf(out.data<float>(), in.data<float>(), &size);
|
||||||
|
} else {
|
||||||
|
eval(inputs, out);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void ArcCosh::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||||
|
assert(inputs.size() == 1);
|
||||||
|
const auto& in = inputs[0];
|
||||||
|
if (out.dtype() == float32 && in.flags().contiguous) {
|
||||||
|
int size = in.data_size();
|
||||||
|
out.set_data(
|
||||||
|
allocator::malloc_or_wait(size * out.itemsize()),
|
||||||
|
size,
|
||||||
|
in.strides(),
|
||||||
|
in.flags());
|
||||||
|
vvacoshf(out.data<float>(), in.data<float>(), &size);
|
||||||
|
} else {
|
||||||
|
eval(inputs, out);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void ArcSin::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||||
|
assert(inputs.size() == 1);
|
||||||
|
const auto& in = inputs[0];
|
||||||
|
if (out.dtype() == float32 && in.flags().contiguous) {
|
||||||
|
int size = in.data_size();
|
||||||
|
out.set_data(
|
||||||
|
allocator::malloc_or_wait(size * out.itemsize()),
|
||||||
|
size,
|
||||||
|
in.strides(),
|
||||||
|
in.flags());
|
||||||
|
vvasinf(out.data<float>(), in.data<float>(), &size);
|
||||||
|
} else {
|
||||||
|
eval(inputs, out);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void ArcSinh::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||||
|
assert(inputs.size() == 1);
|
||||||
|
const auto& in = inputs[0];
|
||||||
|
if (out.dtype() == float32 && in.flags().contiguous) {
|
||||||
|
int size = in.data_size();
|
||||||
|
out.set_data(
|
||||||
|
allocator::malloc_or_wait(size * out.itemsize()),
|
||||||
|
size,
|
||||||
|
in.strides(),
|
||||||
|
in.flags());
|
||||||
|
vvasinhf(out.data<float>(), in.data<float>(), &size);
|
||||||
|
} else {
|
||||||
|
eval(inputs, out);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void ArcTan::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||||
|
assert(inputs.size() == 1);
|
||||||
|
const auto& in = inputs[0];
|
||||||
|
if (out.dtype() == float32 && in.flags().contiguous) {
|
||||||
|
int size = in.data_size();
|
||||||
|
out.set_data(
|
||||||
|
allocator::malloc_or_wait(size * out.itemsize()),
|
||||||
|
size,
|
||||||
|
in.strides(),
|
||||||
|
in.flags());
|
||||||
|
vvatanf(out.data<float>(), in.data<float>(), &size);
|
||||||
|
} else {
|
||||||
|
eval(inputs, out);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void ArcTanh::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||||
|
assert(inputs.size() == 1);
|
||||||
|
const auto& in = inputs[0];
|
||||||
|
if (out.dtype() == float32 && in.flags().contiguous) {
|
||||||
|
int size = in.data_size();
|
||||||
|
out.set_data(
|
||||||
|
allocator::malloc_or_wait(size * out.itemsize()),
|
||||||
|
size,
|
||||||
|
in.strides(),
|
||||||
|
in.flags());
|
||||||
|
vvatanhf(out.data<float>(), in.data<float>(), &size);
|
||||||
|
} else {
|
||||||
|
eval(inputs, out);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void AsType::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||||
|
assert(inputs.size() == 1);
|
||||||
|
auto& in = inputs[0];
|
||||||
|
|
||||||
|
if (in.flags().contiguous) {
|
||||||
|
auto allocfn = [&in, &out]() {
|
||||||
|
out.set_data(
|
||||||
|
allocator::malloc_or_wait(in.data_size() * out.itemsize()),
|
||||||
|
in.data_size(),
|
||||||
|
in.strides(),
|
||||||
|
in.flags());
|
||||||
|
};
|
||||||
|
// Use accelerate functions if possible
|
||||||
|
if (in.dtype() == float32 && out.dtype() == uint32) {
|
||||||
|
allocfn();
|
||||||
|
vDSP_vfixu32(
|
||||||
|
in.data<float>(), 1, out.data<uint32_t>(), 1, in.data_size());
|
||||||
|
return;
|
||||||
|
} else if (in.dtype() == float32 && out.dtype() == int32) {
|
||||||
|
allocfn();
|
||||||
|
vDSP_vfix32(in.data<float>(), 1, out.data<int32_t>(), 1, in.data_size());
|
||||||
|
return;
|
||||||
|
} else if (in.dtype() == uint32 && out.dtype() == float32) {
|
||||||
|
allocfn();
|
||||||
|
vDSP_vfltu32(
|
||||||
|
in.data<uint32_t>(), 1, out.data<float>(), 1, in.data_size());
|
||||||
|
return;
|
||||||
|
} else if (in.dtype() == int32 && out.dtype() == float32) {
|
||||||
|
allocfn();
|
||||||
|
vDSP_vflt32(in.data<int32_t>(), 1, out.data<float>(), 1, in.data_size());
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
eval(inputs, out);
|
||||||
|
}
|
||||||
|
|
||||||
|
void Cos::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||||
|
assert(inputs.size() == 1);
|
||||||
|
const auto& in = inputs[0];
|
||||||
|
if (out.dtype() == float32 && in.flags().contiguous) {
|
||||||
|
int size = in.data_size();
|
||||||
|
out.set_data(
|
||||||
|
allocator::malloc_or_wait(size * out.itemsize()),
|
||||||
|
size,
|
||||||
|
in.strides(),
|
||||||
|
in.flags());
|
||||||
|
vvcosf(out.data<float>(), in.data<float>(), &size);
|
||||||
|
} else {
|
||||||
|
eval(inputs, out);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void Cosh::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||||
|
assert(inputs.size() == 1);
|
||||||
|
const auto& in = inputs[0];
|
||||||
|
if (out.dtype() == float32 && in.flags().contiguous) {
|
||||||
|
int size = in.data_size();
|
||||||
|
out.set_data(
|
||||||
|
allocator::malloc_or_wait(size * out.itemsize()),
|
||||||
|
size,
|
||||||
|
in.strides(),
|
||||||
|
in.flags());
|
||||||
|
vvcoshf(out.data<float>(), in.data<float>(), &size);
|
||||||
|
} else {
|
||||||
|
eval(inputs, out);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void Divide::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||||
|
assert(inputs.size() == 2);
|
||||||
|
auto& a = inputs[0];
|
||||||
|
auto& b = inputs[1];
|
||||||
|
|
||||||
|
if (a.dtype() == int32) {
|
||||||
|
binary(
|
||||||
|
a,
|
||||||
|
b,
|
||||||
|
out,
|
||||||
|
[](auto x, auto y) { return x / y; },
|
||||||
|
UseDefaultBinaryOp(),
|
||||||
|
[](const auto* vec, const auto* s, auto* o, auto n) {
|
||||||
|
vDSP_vsdivi((const int*)vec, 1, (const int*)s, (int*)o, 1, n);
|
||||||
|
},
|
||||||
|
[](const auto* a, const auto* b, auto* o, auto n) {
|
||||||
|
vDSP_vdivi((const int*)b, 1, (const int*)a, 1, (int*)o, 1, n);
|
||||||
|
});
|
||||||
|
} else if (a.dtype() == float32) {
|
||||||
|
binary(
|
||||||
|
a,
|
||||||
|
b,
|
||||||
|
out,
|
||||||
|
[](auto x, auto y) { return x / y; },
|
||||||
|
[](const auto* s, const auto* vec, auto* o, auto n) {
|
||||||
|
vDSP_svdiv((const float*)s, (const float*)vec, 1, (float*)o, 1, n);
|
||||||
|
},
|
||||||
|
[](const auto* vec, const auto* s, auto* o, auto n) {
|
||||||
|
vDSP_vsdiv((const float*)vec, 1, (const float*)s, (float*)o, 1, n);
|
||||||
|
},
|
||||||
|
[](const auto* a, const auto* b, auto* o, auto n) {
|
||||||
|
vDSP_vdiv((const float*)b, 1, (const float*)a, 1, (float*)o, 1, n);
|
||||||
|
});
|
||||||
|
} else {
|
||||||
|
binary(a, b, out, [](auto x, auto y) { return x / y; });
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void Exp::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||||
|
assert(inputs.size() == 1);
|
||||||
|
const auto& in = inputs[0];
|
||||||
|
if (out.dtype() == float32 && in.flags().contiguous) {
|
||||||
|
auto size = in.data_size();
|
||||||
|
out.set_data(
|
||||||
|
allocator::malloc_or_wait(size * out.itemsize()),
|
||||||
|
size,
|
||||||
|
in.strides(),
|
||||||
|
in.flags());
|
||||||
|
vvexpf(out.data<float>(), in.data<float>(), reinterpret_cast<int*>(&size));
|
||||||
|
} else if (is_floating_point(out.dtype())) {
|
||||||
|
unary_fp(in, out, [](auto x) { return std::exp(x); });
|
||||||
|
} else {
|
||||||
|
throw std::invalid_argument(
|
||||||
|
"[exp] Cannot exponentiate elements in array"
|
||||||
|
" with non floating point type.");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void Full::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||||
|
assert(inputs.size() == 1);
|
||||||
|
auto& in = inputs[0];
|
||||||
|
assert(in.dtype() == out.dtype());
|
||||||
|
if (in.data_size() == 1 && out.dtype() == float32) {
|
||||||
|
out.set_data(allocator::malloc_or_wait(out.nbytes()));
|
||||||
|
vDSP_vfill(in.data<float>(), out.data<float>(), 1, out.size());
|
||||||
|
} else {
|
||||||
|
eval(inputs, out);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void Log::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||||
|
assert(inputs.size() == 1);
|
||||||
|
const auto& in = inputs[0];
|
||||||
|
if (out.dtype() == float32 && in.flags().contiguous) {
|
||||||
|
auto size = in.data_size();
|
||||||
|
out.set_data(
|
||||||
|
allocator::malloc_or_wait(size * out.itemsize()),
|
||||||
|
size,
|
||||||
|
in.strides(),
|
||||||
|
in.flags());
|
||||||
|
switch (base_) {
|
||||||
|
case Base::e:
|
||||||
|
vvlogf(
|
||||||
|
out.data<float>(), in.data<float>(), reinterpret_cast<int*>(&size));
|
||||||
|
break;
|
||||||
|
case Base::two:
|
||||||
|
vvlog2f(
|
||||||
|
out.data<float>(), in.data<float>(), reinterpret_cast<int*>(&size));
|
||||||
|
break;
|
||||||
|
case Base::ten:
|
||||||
|
vvlog10f(
|
||||||
|
out.data<float>(), in.data<float>(), reinterpret_cast<int*>(&size));
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
eval(inputs, out);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void Log1p::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||||
|
assert(inputs.size() == 1);
|
||||||
|
const auto& in = inputs[0];
|
||||||
|
if (out.dtype() == float32 && in.flags().contiguous) {
|
||||||
|
auto size = in.data_size();
|
||||||
|
out.set_data(
|
||||||
|
allocator::malloc_or_wait(size * out.itemsize()),
|
||||||
|
size,
|
||||||
|
in.strides(),
|
||||||
|
in.flags());
|
||||||
|
vvlog1pf(
|
||||||
|
out.data<float>(), in.data<float>(), reinterpret_cast<int*>(&size));
|
||||||
|
} else if (is_floating_point(out.dtype())) {
|
||||||
|
unary_fp(in, out, [](auto x) { return std::log1p(x); });
|
||||||
|
} else {
|
||||||
|
throw std::invalid_argument(
|
||||||
|
"[log1p] Cannot compute log of elements in array with"
|
||||||
|
" non floating point type.");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void Maximum::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||||
|
assert(inputs.size() == 2);
|
||||||
|
auto& a = inputs[0];
|
||||||
|
auto& b = inputs[1];
|
||||||
|
if (out.dtype() == float32) {
|
||||||
|
binary(
|
||||||
|
a,
|
||||||
|
b,
|
||||||
|
out,
|
||||||
|
[](auto x, auto y) { return (x > y) ? x : y; },
|
||||||
|
UseDefaultBinaryOp(),
|
||||||
|
UseDefaultBinaryOp(),
|
||||||
|
[](const auto* a, const auto* b, auto* out, int n) {
|
||||||
|
vDSP_vmax((const float*)a, 1, (const float*)b, 1, (float*)out, 1, n);
|
||||||
|
});
|
||||||
|
} else {
|
||||||
|
binary(a, b, out, [](auto x, auto y) { return (x > y) ? x : y; });
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void Minimum::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||||
|
assert(inputs.size() == 2);
|
||||||
|
auto& a = inputs[0];
|
||||||
|
auto& b = inputs[1];
|
||||||
|
|
||||||
|
if (out.dtype() == float32) {
|
||||||
|
binary(
|
||||||
|
a,
|
||||||
|
b,
|
||||||
|
out,
|
||||||
|
[](auto x, auto y) { return (x < y) ? x : y; },
|
||||||
|
UseDefaultBinaryOp(),
|
||||||
|
UseDefaultBinaryOp(),
|
||||||
|
[](const auto* a, const auto* b, auto* out, int n) {
|
||||||
|
vDSP_vmin((const float*)a, 1, (const float*)b, 1, (float*)out, 1, n);
|
||||||
|
});
|
||||||
|
} else {
|
||||||
|
binary(a, b, out, [](auto x, auto y) { return (x < y) ? x : y; });
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void Multiply::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||||
|
assert(inputs.size() == 2);
|
||||||
|
auto& a = inputs[0];
|
||||||
|
auto& b = inputs[1];
|
||||||
|
|
||||||
|
if (a.dtype() == float32) {
|
||||||
|
binary(
|
||||||
|
a,
|
||||||
|
b,
|
||||||
|
out,
|
||||||
|
[](auto x, auto y) { return x * y; },
|
||||||
|
[](const auto* s, const auto* vec, auto* o, auto n) {
|
||||||
|
vDSP_vsmul((const float*)vec, 1, (const float*)s, (float*)o, 1, n);
|
||||||
|
},
|
||||||
|
[](const auto* vec, const auto* s, auto* o, auto n) {
|
||||||
|
vDSP_vsmul((const float*)vec, 1, (const float*)s, (float*)o, 1, n);
|
||||||
|
},
|
||||||
|
[](const auto* a, const auto* b, auto* o, auto n) {
|
||||||
|
vDSP_vmul((const float*)a, 1, (const float*)b, 1, (float*)o, 1, n);
|
||||||
|
});
|
||||||
|
} else {
|
||||||
|
binary(a, b, out, [](auto x, auto y) { return x * y; });
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void Negative::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||||
|
assert(inputs.size() == 1);
|
||||||
|
auto& in = inputs[0];
|
||||||
|
if (in.dtype() == float32 && in.flags().contiguous) {
|
||||||
|
auto size = in.data_size();
|
||||||
|
out.set_data(
|
||||||
|
allocator::malloc_or_wait(size * out.itemsize()),
|
||||||
|
size,
|
||||||
|
in.strides(),
|
||||||
|
in.flags());
|
||||||
|
vDSP_vneg(in.data<float>(), 1, out.data<float>(), 1, size);
|
||||||
|
} else {
|
||||||
|
unary(in, out, [](auto x) { return -x; });
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void Power::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||||
|
assert(inputs.size() == 2);
|
||||||
|
auto& a = inputs[0];
|
||||||
|
auto& b = inputs[1];
|
||||||
|
if (out.dtype() == float32 && a.flags().row_contiguous &&
|
||||||
|
b.flags().row_contiguous) {
|
||||||
|
int size = a.size();
|
||||||
|
out.set_data(allocator::malloc_or_wait(out.nbytes()));
|
||||||
|
vvpowf(out.data<float>(), a.data<float>(), b.data<float>(), &size);
|
||||||
|
} else {
|
||||||
|
eval(inputs, out);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void Scan::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||||
|
assert(inputs.size() == 1);
|
||||||
|
const auto& in = inputs[0];
|
||||||
|
if (reduce_type_ == Scan::Sum && out.dtype() == float32 &&
|
||||||
|
in.flags().row_contiguous && in.strides()[axis_] == 1 && !inclusive_) {
|
||||||
|
out.set_data(allocator::malloc_or_wait(out.nbytes()));
|
||||||
|
int stride = in.shape(axis_);
|
||||||
|
int count = in.size() / stride;
|
||||||
|
const float* input = in.data<float>();
|
||||||
|
float* output = out.data<float>();
|
||||||
|
float s = 1.0;
|
||||||
|
if (!reverse_) {
|
||||||
|
for (int i = 0; i < count; i++) {
|
||||||
|
vDSP_vrsum(input - 1, 1, &s, output, 1, stride);
|
||||||
|
input += stride;
|
||||||
|
output += stride;
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
for (int i = 0; i < count; i++) {
|
||||||
|
input += stride - 1;
|
||||||
|
output += stride - 1;
|
||||||
|
vDSP_vrsum(input + 1, -1, &s, output, -1, stride);
|
||||||
|
input++;
|
||||||
|
output++;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
eval(inputs, out);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void Sin::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||||
|
assert(inputs.size() == 1);
|
||||||
|
const auto& in = inputs[0];
|
||||||
|
if (out.dtype() == float32 && in.flags().contiguous) {
|
||||||
|
int size = in.data_size();
|
||||||
|
out.set_data(
|
||||||
|
allocator::malloc_or_wait(size * out.itemsize()),
|
||||||
|
size,
|
||||||
|
in.strides(),
|
||||||
|
in.flags());
|
||||||
|
vvsinf(out.data<float>(), in.data<float>(), &size);
|
||||||
|
} else {
|
||||||
|
eval(inputs, out);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void Sinh::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||||
|
assert(inputs.size() == 1);
|
||||||
|
const auto& in = inputs[0];
|
||||||
|
if (out.dtype() == float32 && in.flags().contiguous) {
|
||||||
|
int size = in.data_size();
|
||||||
|
out.set_data(
|
||||||
|
allocator::malloc_or_wait(size * out.itemsize()),
|
||||||
|
size,
|
||||||
|
in.strides(),
|
||||||
|
in.flags());
|
||||||
|
vvsinhf(out.data<float>(), in.data<float>(), &size);
|
||||||
|
} else {
|
||||||
|
eval(inputs, out);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void Square::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||||
|
assert(inputs.size() == 1);
|
||||||
|
auto& in = inputs[0];
|
||||||
|
if (in.dtype() == float32 && in.flags().contiguous) {
|
||||||
|
auto size = in.data_size();
|
||||||
|
out.set_data(
|
||||||
|
allocator::malloc_or_wait(size * out.itemsize()),
|
||||||
|
size,
|
||||||
|
in.strides(),
|
||||||
|
in.flags());
|
||||||
|
vDSP_vsq(in.data<float>(), 1, out.data<float>(), 1, size);
|
||||||
|
} else {
|
||||||
|
unary(in, out, [](auto x) { return x * x; });
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void Sqrt::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||||
|
assert(inputs.size() == 1);
|
||||||
|
auto& in = inputs[0];
|
||||||
|
if (in.dtype() == float32 && in.flags().contiguous) {
|
||||||
|
int size = in.data_size();
|
||||||
|
out.set_data(
|
||||||
|
allocator::malloc_or_wait(size * out.itemsize()),
|
||||||
|
size,
|
||||||
|
in.strides(),
|
||||||
|
in.flags());
|
||||||
|
if (recip_) {
|
||||||
|
vvrsqrtf(out.data<float>(), in.data<float>(), &size);
|
||||||
|
} else {
|
||||||
|
vvsqrtf(out.data<float>(), in.data<float>(), &size);
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
eval(inputs, out);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void Subtract::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||||
|
assert(inputs.size() == 2);
|
||||||
|
auto& a = inputs[0];
|
||||||
|
auto& b = inputs[1];
|
||||||
|
|
||||||
|
if (a.dtype() == float32) {
|
||||||
|
binary(
|
||||||
|
a,
|
||||||
|
b,
|
||||||
|
out,
|
||||||
|
[](auto x, auto y) { return x - y; },
|
||||||
|
[](const auto* s, const auto* vec, auto* o, auto n) {
|
||||||
|
float minus_1 = -1;
|
||||||
|
vDSP_vsmsa(
|
||||||
|
(const float*)vec, 1, &minus_1, (const float*)s, (float*)o, 1, n);
|
||||||
|
},
|
||||||
|
[](const auto* vec, const auto* s, auto* o, auto n) {
|
||||||
|
float val = -(*s);
|
||||||
|
vDSP_vsadd((const float*)vec, 1, &val, (float*)o, 1, n);
|
||||||
|
},
|
||||||
|
[](const auto* a, const auto* b, auto* o, auto n) {
|
||||||
|
vDSP_vsub((const float*)b, 1, (const float*)a, 1, (float*)o, 1, n);
|
||||||
|
});
|
||||||
|
} else if (a.dtype() == int32) {
|
||||||
|
binary(
|
||||||
|
a,
|
||||||
|
b,
|
||||||
|
out,
|
||||||
|
[](auto x, auto y) { return x - y; },
|
||||||
|
UseDefaultBinaryOp(),
|
||||||
|
[](const auto* vec, const auto* s, auto* o, auto n) {
|
||||||
|
int val = -(*s);
|
||||||
|
vDSP_vsaddi((const int*)vec, 1, &val, (int*)o, 1, n);
|
||||||
|
},
|
||||||
|
UseDefaultBinaryOp());
|
||||||
|
} else {
|
||||||
|
binary(a, b, out, [](auto x, auto y) { return x - y; });
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void Tan::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||||
|
assert(inputs.size() == 1);
|
||||||
|
const auto& in = inputs[0];
|
||||||
|
if (out.dtype() == float32 && in.flags().contiguous) {
|
||||||
|
int size = in.data_size();
|
||||||
|
out.set_data(
|
||||||
|
allocator::malloc_or_wait(size * out.itemsize()),
|
||||||
|
size,
|
||||||
|
in.strides(),
|
||||||
|
in.flags());
|
||||||
|
vvtanf(out.data<float>(), in.data<float>(), &size);
|
||||||
|
} else {
|
||||||
|
eval(inputs, out);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void Tanh::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||||
|
assert(inputs.size() == 1);
|
||||||
|
const auto& in = inputs[0];
|
||||||
|
if (out.dtype() == float32 && in.flags().contiguous) {
|
||||||
|
int size = in.data_size();
|
||||||
|
out.set_data(
|
||||||
|
allocator::malloc_or_wait(size * out.itemsize()),
|
||||||
|
size,
|
||||||
|
in.strides(),
|
||||||
|
in.flags());
|
||||||
|
vvtanhf(out.data<float>(), in.data<float>(), &size);
|
||||||
|
} else {
|
||||||
|
eval(inputs, out);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace mlx::core
|
||||||
147
mlx/backend/accelerate/reduce.cpp
Normal file
147
mlx/backend/accelerate/reduce.cpp
Normal file
@@ -0,0 +1,147 @@
|
|||||||
|
#include <cassert>
|
||||||
|
|
||||||
|
#include <simd/vector.h>
|
||||||
|
#include <vecLib/vDSP.h>
|
||||||
|
|
||||||
|
#include "mlx/backend/common/reduce.h"
|
||||||
|
#include "mlx/primitives.h"
|
||||||
|
|
||||||
|
namespace mlx::core {
|
||||||
|
|
||||||
|
template <typename T, typename VT, int N>
|
||||||
|
void _vectorized_strided_sum(const T* x, T* accum, int size, size_t stride) {
|
||||||
|
for (int i = 0; i < size; i++) {
|
||||||
|
size_t s = stride;
|
||||||
|
T* a = accum;
|
||||||
|
while (s >= N) {
|
||||||
|
VT val = (*(VT*)x);
|
||||||
|
*(VT*)a += val;
|
||||||
|
x += N;
|
||||||
|
a += N;
|
||||||
|
s -= N;
|
||||||
|
}
|
||||||
|
while (s-- > 0) {
|
||||||
|
*a++ += *x++;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TODO: Add proper templates for the strided reduce algorithm so we don't have
|
||||||
|
// to write max/min/sum etc.
|
||||||
|
template <typename T, typename VT, int N>
|
||||||
|
void _vectorized_strided_max(const T* x, T* accum, int size, size_t stride) {
|
||||||
|
for (int i = 0; i < size; i++) {
|
||||||
|
size_t s = stride;
|
||||||
|
T* a = accum;
|
||||||
|
while (s >= N) {
|
||||||
|
*(VT*)a = simd_max((*(VT*)x), (*(VT*)a));
|
||||||
|
x += N;
|
||||||
|
a += N;
|
||||||
|
s -= N;
|
||||||
|
}
|
||||||
|
while (s-- > 0) {
|
||||||
|
*a = std::max(*a, *x);
|
||||||
|
a++;
|
||||||
|
x++;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename T, typename VT, int N>
|
||||||
|
void _vectorized_strided_min(const T* x, T* accum, int size, size_t stride) {
|
||||||
|
for (int i = 0; i < size; i++) {
|
||||||
|
size_t s = stride;
|
||||||
|
T* a = accum;
|
||||||
|
while (s >= N) {
|
||||||
|
*(VT*)a = simd_min((*(VT*)x), (*(VT*)a));
|
||||||
|
x += N;
|
||||||
|
a += N;
|
||||||
|
s -= N;
|
||||||
|
}
|
||||||
|
while (s-- > 0) {
|
||||||
|
*a = std::min(*a, *x);
|
||||||
|
a++;
|
||||||
|
x++;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename T, typename VT, int N>
|
||||||
|
void _vectorized_sum(const T* x, T* accum, int size) {
|
||||||
|
VT _sum = {0};
|
||||||
|
while (size >= N) {
|
||||||
|
_sum += (*(VT*)x);
|
||||||
|
x += N;
|
||||||
|
size -= N;
|
||||||
|
}
|
||||||
|
T sum = _sum[0];
|
||||||
|
for (int i = 1; i < N; i++) {
|
||||||
|
sum += _sum[i];
|
||||||
|
}
|
||||||
|
*accum += sum;
|
||||||
|
}
|
||||||
|
|
||||||
|
void Reduce::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||||
|
assert(inputs.size() == 1);
|
||||||
|
auto& in = inputs[0];
|
||||||
|
|
||||||
|
if (in.dtype() == float32) {
|
||||||
|
if (reduce_type_ == Reduce::Sum) {
|
||||||
|
reduction_op<float, float>(
|
||||||
|
in,
|
||||||
|
out,
|
||||||
|
axes_,
|
||||||
|
0,
|
||||||
|
[](const auto* x, auto* accum, int size, size_t stride) {
|
||||||
|
_vectorized_strided_sum<float, simd_float16, 16>(
|
||||||
|
(const float*)x, (float*)accum, size, stride);
|
||||||
|
},
|
||||||
|
[](const auto* x, auto* accum, int size) {
|
||||||
|
float acc;
|
||||||
|
vDSP_sve((const float*)x, 1, &acc, size);
|
||||||
|
(*accum) += acc;
|
||||||
|
},
|
||||||
|
[](auto* accum, auto x) { *accum += x; });
|
||||||
|
return;
|
||||||
|
} else if (reduce_type_ == Reduce::Max) {
|
||||||
|
reduction_op<float, float>(
|
||||||
|
in,
|
||||||
|
out,
|
||||||
|
axes_,
|
||||||
|
-std::numeric_limits<float>::infinity(),
|
||||||
|
[](const auto* x, auto* accum, int size, size_t stride) {
|
||||||
|
_vectorized_strided_max<float, simd_float16, 16>(
|
||||||
|
(const float*)x, (float*)accum, size, stride);
|
||||||
|
},
|
||||||
|
[](const auto* x, auto* accum, int size) {
|
||||||
|
float max;
|
||||||
|
vDSP_maxv((const float*)x, 1, &max, size);
|
||||||
|
(*accum) = (*accum < max) ? max : *accum;
|
||||||
|
},
|
||||||
|
[](auto* accum, auto x) { (*accum) = (*accum < x) ? x : *accum; });
|
||||||
|
return;
|
||||||
|
} else if (reduce_type_ == Reduce::Min) {
|
||||||
|
reduction_op<float, float>(
|
||||||
|
in,
|
||||||
|
out,
|
||||||
|
axes_,
|
||||||
|
std::numeric_limits<float>::infinity(),
|
||||||
|
[](const auto* x, auto* accum, int size, size_t stride) {
|
||||||
|
_vectorized_strided_min<float, simd_float16, 16>(
|
||||||
|
(const float*)x, (float*)accum, size, stride);
|
||||||
|
},
|
||||||
|
[](const auto* x, auto* accum, int size) {
|
||||||
|
float min;
|
||||||
|
vDSP_minv((const float*)x, 1, &min, size);
|
||||||
|
(*accum) = (*accum > min) ? min : *accum;
|
||||||
|
},
|
||||||
|
[](auto* accum, auto x) { (*accum) = (*accum > x) ? x : *accum; });
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// TODO: Add integer addition and min/max using the templates above and
|
||||||
|
// simd_int16 and friends.
|
||||||
|
eval(inputs, out);
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace mlx::core
|
||||||
18
mlx/backend/common/CMakeLists.txt
Normal file
18
mlx/backend/common/CMakeLists.txt
Normal file
@@ -0,0 +1,18 @@
|
|||||||
|
target_sources(
|
||||||
|
mlx
|
||||||
|
PRIVATE
|
||||||
|
${CMAKE_CURRENT_SOURCE_DIR}/arg_reduce.cpp
|
||||||
|
${CMAKE_CURRENT_SOURCE_DIR}/binary.cpp
|
||||||
|
${CMAKE_CURRENT_SOURCE_DIR}/conv.cpp
|
||||||
|
${CMAKE_CURRENT_SOURCE_DIR}/copy.cpp
|
||||||
|
${CMAKE_CURRENT_SOURCE_DIR}/erf.cpp
|
||||||
|
${CMAKE_CURRENT_SOURCE_DIR}/fft.cpp
|
||||||
|
${CMAKE_CURRENT_SOURCE_DIR}/primitives.cpp
|
||||||
|
${CMAKE_CURRENT_SOURCE_DIR}/reduce.cpp
|
||||||
|
${CMAKE_CURRENT_SOURCE_DIR}/scan.cpp
|
||||||
|
${CMAKE_CURRENT_SOURCE_DIR}/softmax.cpp
|
||||||
|
${CMAKE_CURRENT_SOURCE_DIR}/sort.cpp
|
||||||
|
${CMAKE_CURRENT_SOURCE_DIR}/threefry.cpp
|
||||||
|
${CMAKE_CURRENT_SOURCE_DIR}/indexing.cpp
|
||||||
|
${CMAKE_CURRENT_SOURCE_DIR}/load.cpp
|
||||||
|
)
|
||||||
72
mlx/backend/common/arange.h
Normal file
72
mlx/backend/common/arange.h
Normal file
@@ -0,0 +1,72 @@
|
|||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include "mlx/allocator.h"
|
||||||
|
#include "mlx/array.h"
|
||||||
|
|
||||||
|
namespace mlx::core {
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
void arange(T start, T next, array& out, size_t size) {
|
||||||
|
auto ptr = out.data<T>();
|
||||||
|
auto step_size = next - start;
|
||||||
|
for (int i = 0; i < size; ++i) {
|
||||||
|
ptr[i] = start;
|
||||||
|
start += step_size;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace
|
||||||
|
|
||||||
|
void arange(
|
||||||
|
const std::vector<array>& inputs,
|
||||||
|
array& out,
|
||||||
|
double start,
|
||||||
|
double step) {
|
||||||
|
assert(inputs.size() == 0);
|
||||||
|
out.set_data(allocator::malloc_or_wait(out.nbytes()));
|
||||||
|
switch (out.dtype()) {
|
||||||
|
case bool_:
|
||||||
|
throw std::runtime_error("Bool type unsupported for arange.");
|
||||||
|
break;
|
||||||
|
case uint8:
|
||||||
|
arange<uint8_t>(start, start + step, out, out.size());
|
||||||
|
break;
|
||||||
|
case uint16:
|
||||||
|
arange<uint16_t>(start, start + step, out, out.size());
|
||||||
|
break;
|
||||||
|
case uint32:
|
||||||
|
arange<uint32_t>(start, start + step, out, out.size());
|
||||||
|
break;
|
||||||
|
case uint64:
|
||||||
|
arange<uint64_t>(start, start + step, out, out.size());
|
||||||
|
break;
|
||||||
|
case int8:
|
||||||
|
arange<int8_t>(start, start + step, out, out.size());
|
||||||
|
break;
|
||||||
|
case int16:
|
||||||
|
arange<int16_t>(start, start + step, out, out.size());
|
||||||
|
break;
|
||||||
|
case int32:
|
||||||
|
arange<int32_t>(start, start + step, out, out.size());
|
||||||
|
break;
|
||||||
|
case int64:
|
||||||
|
arange<int64_t>(start, start + step, out, out.size());
|
||||||
|
break;
|
||||||
|
case float16:
|
||||||
|
arange<float16_t>(start, start + step, out, out.size());
|
||||||
|
break;
|
||||||
|
case float32:
|
||||||
|
arange<float>(start, start + step, out, out.size());
|
||||||
|
break;
|
||||||
|
case bfloat16:
|
||||||
|
arange<bfloat16_t>(start, start + step, out, out.size());
|
||||||
|
break;
|
||||||
|
case complex64:
|
||||||
|
arange<complex64_t>(start, start + step, out, out.size());
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace mlx::core
|
||||||
110
mlx/backend/common/arg_reduce.cpp
Normal file
110
mlx/backend/common/arg_reduce.cpp
Normal file
@@ -0,0 +1,110 @@
|
|||||||
|
#include <cassert>
|
||||||
|
|
||||||
|
#include "mlx/primitives.h"
|
||||||
|
#include "utils.h"
|
||||||
|
|
||||||
|
namespace mlx::core {
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
|
||||||
|
template <typename InT, typename OpT>
|
||||||
|
void arg_reduce(const array& in, array& out, const OpT& op, int axis) {
|
||||||
|
auto axis_size = in.shape()[axis];
|
||||||
|
auto axis_stride = in.strides()[axis];
|
||||||
|
std::vector<size_t> strides = in.strides();
|
||||||
|
std::vector<int> shape = in.shape();
|
||||||
|
strides.erase(strides.begin() + axis);
|
||||||
|
shape.erase(shape.begin() + axis);
|
||||||
|
for (uint32_t i = 0; i < out.size(); ++i) {
|
||||||
|
auto loc = elem_to_loc(i, shape, strides);
|
||||||
|
auto in_ptr = in.data<InT>() + loc;
|
||||||
|
uint32_t ind_v = 0;
|
||||||
|
InT v = (*in_ptr);
|
||||||
|
for (uint32_t j = 0; j < axis_size; ++j, in_ptr += axis_stride) {
|
||||||
|
op(j, (*in_ptr), &ind_v, &v);
|
||||||
|
}
|
||||||
|
out.data<uint32_t>()[i] = ind_v;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename InT>
|
||||||
|
void arg_reduce_dispatch(
|
||||||
|
const array& in,
|
||||||
|
array& out,
|
||||||
|
ArgReduce::ReduceType rtype,
|
||||||
|
int axis) {
|
||||||
|
switch (rtype) {
|
||||||
|
case ArgReduce::ArgMin: {
|
||||||
|
auto op = [](auto ind_x, auto x, auto ind_y, auto y) {
|
||||||
|
if (x < (*y)) {
|
||||||
|
(*y) = x;
|
||||||
|
(*ind_y) = ind_x;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
arg_reduce<InT>(in, out, op, axis);
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
case ArgReduce::ArgMax: {
|
||||||
|
auto op = [](auto ind_x, auto x, auto ind_y, auto y) {
|
||||||
|
if (x > (*y)) {
|
||||||
|
(*y) = x;
|
||||||
|
(*ind_y) = ind_x;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
arg_reduce<InT>(in, out, op, axis);
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace
|
||||||
|
|
||||||
|
void ArgReduce::eval(const std::vector<array>& inputs, array& out) {
|
||||||
|
assert(inputs.size() == 1);
|
||||||
|
auto& in = inputs[0];
|
||||||
|
out.set_data(allocator::malloc_or_wait(out.nbytes()));
|
||||||
|
|
||||||
|
switch (in.dtype()) {
|
||||||
|
case bool_:
|
||||||
|
arg_reduce_dispatch<bool>(in, out, reduce_type_, axis_);
|
||||||
|
break;
|
||||||
|
case uint8:
|
||||||
|
arg_reduce_dispatch<uint8_t>(in, out, reduce_type_, axis_);
|
||||||
|
break;
|
||||||
|
case uint16:
|
||||||
|
arg_reduce_dispatch<uint16_t>(in, out, reduce_type_, axis_);
|
||||||
|
break;
|
||||||
|
case uint32:
|
||||||
|
arg_reduce_dispatch<uint32_t>(in, out, reduce_type_, axis_);
|
||||||
|
break;
|
||||||
|
case uint64:
|
||||||
|
arg_reduce_dispatch<uint64_t>(in, out, reduce_type_, axis_);
|
||||||
|
break;
|
||||||
|
case int8:
|
||||||
|
arg_reduce_dispatch<int8_t>(in, out, reduce_type_, axis_);
|
||||||
|
break;
|
||||||
|
case int16:
|
||||||
|
arg_reduce_dispatch<int16_t>(in, out, reduce_type_, axis_);
|
||||||
|
break;
|
||||||
|
case int32:
|
||||||
|
arg_reduce_dispatch<int32_t>(in, out, reduce_type_, axis_);
|
||||||
|
break;
|
||||||
|
case int64:
|
||||||
|
arg_reduce_dispatch<int64_t>(in, out, reduce_type_, axis_);
|
||||||
|
break;
|
||||||
|
case float16:
|
||||||
|
arg_reduce_dispatch<float16_t>(in, out, reduce_type_, axis_);
|
||||||
|
break;
|
||||||
|
case float32:
|
||||||
|
arg_reduce_dispatch<float>(in, out, reduce_type_, axis_);
|
||||||
|
break;
|
||||||
|
case bfloat16:
|
||||||
|
arg_reduce_dispatch<bfloat16_t>(in, out, reduce_type_, axis_);
|
||||||
|
break;
|
||||||
|
case complex64:
|
||||||
|
arg_reduce_dispatch<complex64_t>(in, out, reduce_type_, axis_);
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace mlx::core
|
||||||
541
mlx/backend/common/conv.cpp
Normal file
541
mlx/backend/common/conv.cpp
Normal file
@@ -0,0 +1,541 @@
|
|||||||
|
#include <cassert>
|
||||||
|
|
||||||
|
#ifdef ACCELERATE_NEW_LAPACK
|
||||||
|
#include <vecLib/cblas_new.h>
|
||||||
|
#else
|
||||||
|
#include <cblas.h>
|
||||||
|
#endif
|
||||||
|
|
||||||
|
#include "mlx/backend/common/copy.h"
|
||||||
|
#include "mlx/primitives.h"
|
||||||
|
#include "mlx/utils.h"
|
||||||
|
|
||||||
|
namespace mlx::core {
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
|
||||||
|
///////////////////////////////////////////////////////////////////////////////
|
||||||
|
// Naive reference conv
|
||||||
|
///////////////////////////////////////////////////////////////////////////////
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
void slow_conv_1D(
|
||||||
|
const array& in,
|
||||||
|
const array& wt,
|
||||||
|
array out,
|
||||||
|
const std::vector<int>& padding,
|
||||||
|
const std::vector<int>& wt_strides,
|
||||||
|
const std::vector<int>& wt_dilation) {
|
||||||
|
const T* start_wt_ptr = wt.data<T>();
|
||||||
|
|
||||||
|
const T* in_ptr = in.data<T>();
|
||||||
|
T* out_ptr = out.data<T>();
|
||||||
|
|
||||||
|
const int N = in.shape(0); // Batch size, should be the same as out.shape(0)
|
||||||
|
const int iH = in.shape(1); // Input spatial dim
|
||||||
|
const int oH = out.shape(1); // Output spatial dim
|
||||||
|
const int O = wt.shape(0); // Out channels
|
||||||
|
const int C = wt.shape(2); // In channels
|
||||||
|
const int wH = wt.shape(1); // Weight spatial dim
|
||||||
|
|
||||||
|
const size_t in_stride_N = in.strides()[0];
|
||||||
|
const size_t in_stride_H = in.strides()[1];
|
||||||
|
const size_t in_stride_C = in.strides()[2];
|
||||||
|
|
||||||
|
const size_t wt_stride_O = wt.strides()[0];
|
||||||
|
const size_t wt_stride_H = wt.strides()[1];
|
||||||
|
const size_t wt_stride_C = wt.strides()[2];
|
||||||
|
|
||||||
|
const size_t out_stride_N = out.strides()[0];
|
||||||
|
const size_t out_stride_H = out.strides()[1];
|
||||||
|
const size_t out_stride_O = out.strides()[2];
|
||||||
|
|
||||||
|
for (int n = 0; n < N; ++n) {
|
||||||
|
for (int oh = 0; oh < oH; ++oh) {
|
||||||
|
for (int o = 0; o < O; ++o) {
|
||||||
|
const T* filter_wt_ptr = start_wt_ptr + o * wt_stride_O;
|
||||||
|
float r = 0.;
|
||||||
|
|
||||||
|
for (int wh = 0; wh < wH; ++wh) {
|
||||||
|
const T* wt_ptr = filter_wt_ptr + wh * wt_stride_H;
|
||||||
|
|
||||||
|
int ih = oh * wt_strides[0] - padding[0] + wh * wt_dilation[0];
|
||||||
|
|
||||||
|
if (ih >= 0 && ih < iH) {
|
||||||
|
for (int c = 0; c < C; ++c) {
|
||||||
|
r += static_cast<float>(
|
||||||
|
in_ptr[ih * in_stride_H + c * in_stride_C]) *
|
||||||
|
static_cast<float>(wt_ptr[c * wt_stride_C]);
|
||||||
|
} // c
|
||||||
|
|
||||||
|
} // ih check
|
||||||
|
} // wh
|
||||||
|
|
||||||
|
out_ptr[oh * out_stride_H + o * out_stride_O] = static_cast<T>(r);
|
||||||
|
} // o
|
||||||
|
} // oh
|
||||||
|
|
||||||
|
in_ptr += in_stride_N;
|
||||||
|
out_ptr += out_stride_N;
|
||||||
|
|
||||||
|
} // n
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
void slow_conv_2D(
|
||||||
|
const array& in,
|
||||||
|
const array& wt,
|
||||||
|
array out,
|
||||||
|
const std::vector<int>& padding,
|
||||||
|
const std::vector<int>& wt_strides,
|
||||||
|
const std::vector<int>& wt_dilation) {
|
||||||
|
const T* st_wt_ptr = wt.data<T>();
|
||||||
|
const T* st_in_ptr = in.data<T>();
|
||||||
|
T* st_out_ptr = out.data<T>();
|
||||||
|
|
||||||
|
const int N = in.shape(0); // Batch size, should be the same as out.shape(0)
|
||||||
|
const int iH = in.shape(1); // Input spatial dim
|
||||||
|
const int iW = in.shape(2); // Input spatial dim
|
||||||
|
const int oH = out.shape(1); // Output spatial dim
|
||||||
|
const int oW = out.shape(2); // Output spatial dim
|
||||||
|
const int O = wt.shape(0); // Out channels
|
||||||
|
const int C = wt.shape(3); // In channels
|
||||||
|
const int wH = wt.shape(1); // Weight spatial dim
|
||||||
|
const int wW = wt.shape(2); // Weight spatial dim
|
||||||
|
|
||||||
|
const size_t in_stride_N = in.strides()[0];
|
||||||
|
const size_t in_stride_H = in.strides()[1];
|
||||||
|
const size_t in_stride_W = in.strides()[2];
|
||||||
|
const size_t in_stride_C = in.strides()[3];
|
||||||
|
|
||||||
|
const size_t wt_stride_O = wt.strides()[0];
|
||||||
|
const size_t wt_stride_H = wt.strides()[1];
|
||||||
|
const size_t wt_stride_W = wt.strides()[2];
|
||||||
|
const size_t wt_stride_C = wt.strides()[3];
|
||||||
|
|
||||||
|
const size_t out_stride_N = out.strides()[0];
|
||||||
|
const size_t out_stride_H = out.strides()[1];
|
||||||
|
const size_t out_stride_W = out.strides()[2];
|
||||||
|
const size_t out_stride_O = out.strides()[3];
|
||||||
|
|
||||||
|
auto pt_conv_no_checks =
|
||||||
|
[&](const T* in_ptr, const T* wt_ptr, T* out_ptr, int oh, int ow) {
|
||||||
|
out_ptr += oh * out_stride_H + ow * out_stride_W;
|
||||||
|
int ih_base = oh * wt_strides[0] - padding[0];
|
||||||
|
int iw_base = ow * wt_strides[1] - padding[1];
|
||||||
|
|
||||||
|
for (int o = 0; o < O; ++o) {
|
||||||
|
float r = 0.;
|
||||||
|
|
||||||
|
for (int wh = 0; wh < wH; ++wh) {
|
||||||
|
for (int ww = 0; ww < wW; ++ww) {
|
||||||
|
int ih = ih_base + wh * wt_dilation[0];
|
||||||
|
int iw = iw_base + ww * wt_dilation[1];
|
||||||
|
|
||||||
|
const T* wt_ptr_pt = wt_ptr + wh * wt_stride_H + ww * wt_stride_W;
|
||||||
|
const T* in_ptr_pt = in_ptr + ih * in_stride_H + iw * in_stride_W;
|
||||||
|
|
||||||
|
for (int c = 0; c < C; ++c) {
|
||||||
|
r += static_cast<float>(in_ptr_pt[0]) *
|
||||||
|
static_cast<float>(wt_ptr_pt[0]);
|
||||||
|
in_ptr_pt += in_stride_C;
|
||||||
|
wt_ptr_pt += wt_stride_C;
|
||||||
|
} // c
|
||||||
|
|
||||||
|
} // ww
|
||||||
|
} // wh
|
||||||
|
|
||||||
|
out_ptr[0] = static_cast<T>(r);
|
||||||
|
out_ptr += out_stride_O;
|
||||||
|
wt_ptr += wt_stride_O;
|
||||||
|
} // o
|
||||||
|
};
|
||||||
|
|
||||||
|
auto pt_conv_all_checks =
|
||||||
|
[&](const T* in_ptr, const T* wt_ptr, T* out_ptr, int oh, int ow) {
|
||||||
|
out_ptr += oh * out_stride_H + ow * out_stride_W;
|
||||||
|
int ih_base = oh * wt_strides[0] - padding[0];
|
||||||
|
int iw_base = ow * wt_strides[1] - padding[1];
|
||||||
|
|
||||||
|
for (int o = 0; o < O; ++o) {
|
||||||
|
float r = 0.;
|
||||||
|
|
||||||
|
for (int wh = 0; wh < wH; ++wh) {
|
||||||
|
for (int ww = 0; ww < wW; ++ww) {
|
||||||
|
int ih = ih_base + wh * wt_dilation[0];
|
||||||
|
int iw = iw_base + ww * wt_dilation[1];
|
||||||
|
|
||||||
|
if (ih >= 0 && ih < iH && iw >= 0 && iw < iW) {
|
||||||
|
const T* wt_ptr_pt =
|
||||||
|
wt_ptr + wh * wt_stride_H + ww * wt_stride_W;
|
||||||
|
const T* in_ptr_pt =
|
||||||
|
in_ptr + ih * in_stride_H + iw * in_stride_W;
|
||||||
|
|
||||||
|
for (int c = 0; c < C; ++c) {
|
||||||
|
r += static_cast<float>(in_ptr_pt[0]) *
|
||||||
|
static_cast<float>(wt_ptr_pt[0]);
|
||||||
|
in_ptr_pt += in_stride_C;
|
||||||
|
wt_ptr_pt += wt_stride_C;
|
||||||
|
} // c
|
||||||
|
|
||||||
|
} // ih, iw check
|
||||||
|
} // ww
|
||||||
|
} // wh
|
||||||
|
|
||||||
|
out_ptr[0] = static_cast<T>(r);
|
||||||
|
out_ptr += out_stride_O;
|
||||||
|
wt_ptr += wt_stride_O;
|
||||||
|
} // o
|
||||||
|
};
|
||||||
|
|
||||||
|
int oH_border_0 = 0;
|
||||||
|
int oH_border_1 = (padding[0] + wt_strides[0] + 1) / wt_strides[0];
|
||||||
|
int oH_border_2 = (iH + padding[0] - wH * wt_dilation[0]) / wt_strides[0];
|
||||||
|
int oH_border_3 = oH;
|
||||||
|
|
||||||
|
int oW_border_0 = 0;
|
||||||
|
int oW_border_1 = (padding[1] + wt_strides[0] + 1) / wt_strides[1];
|
||||||
|
int oW_border_2 = (iW + padding[1] - wW * wt_dilation[1]) / wt_strides[1];
|
||||||
|
int oW_border_3 = oW;
|
||||||
|
|
||||||
|
for (int n = 0; n < N; ++n) {
|
||||||
|
// Case 1: oh might put us out of bounds
|
||||||
|
for (int oh = oH_border_0; oh < oH_border_1; ++oh) {
|
||||||
|
for (int ow = 0; ow < oW; ++ow) {
|
||||||
|
pt_conv_all_checks(st_in_ptr, st_wt_ptr, st_out_ptr, oh, ow);
|
||||||
|
} // ow
|
||||||
|
} // oh
|
||||||
|
|
||||||
|
// Case 2: oh in bounds
|
||||||
|
for (int oh = oH_border_1; oh < oH_border_2; ++oh) {
|
||||||
|
// Case a: ow might put us out of bounds
|
||||||
|
for (int ow = oW_border_0; ow < oW_border_1; ++ow) {
|
||||||
|
pt_conv_all_checks(st_in_ptr, st_wt_ptr, st_out_ptr, oh, ow);
|
||||||
|
} // ow
|
||||||
|
|
||||||
|
// Case b: ow in bounds
|
||||||
|
for (int ow = oW_border_1; ow < oW_border_2; ++ow) {
|
||||||
|
pt_conv_no_checks(st_in_ptr, st_wt_ptr, st_out_ptr, oh, ow);
|
||||||
|
} // ow
|
||||||
|
|
||||||
|
// Case c: ow might put us out of bounds
|
||||||
|
for (int ow = oW_border_2; ow < oW_border_3; ++ow) {
|
||||||
|
pt_conv_all_checks(st_in_ptr, st_wt_ptr, st_out_ptr, oh, ow);
|
||||||
|
} // ow
|
||||||
|
|
||||||
|
} // oh
|
||||||
|
|
||||||
|
// Case 3: oh might put us out of bounds
|
||||||
|
for (int oh = oH_border_2; oh < oH_border_3; ++oh) {
|
||||||
|
for (int ow = 0; ow < oW; ++ow) {
|
||||||
|
pt_conv_all_checks(st_in_ptr, st_wt_ptr, st_out_ptr, oh, ow);
|
||||||
|
} // ow
|
||||||
|
} // oh
|
||||||
|
|
||||||
|
st_in_ptr += in_stride_N;
|
||||||
|
st_out_ptr += out_stride_N;
|
||||||
|
|
||||||
|
} // n
|
||||||
|
}
|
||||||
|
|
||||||
|
void dispatch_slow_conv_1D(
|
||||||
|
const array& in,
|
||||||
|
const array& wt,
|
||||||
|
array out,
|
||||||
|
const std::vector<int>& padding,
|
||||||
|
const std::vector<int>& wt_strides,
|
||||||
|
const std::vector<int>& wt_dilation) {
|
||||||
|
if (in.dtype() == float32) {
|
||||||
|
return slow_conv_1D<float>(in, wt, out, padding, wt_strides, wt_dilation);
|
||||||
|
} else if (in.dtype() == float16) {
|
||||||
|
return slow_conv_1D<float16_t>(
|
||||||
|
in, wt, out, padding, wt_strides, wt_dilation);
|
||||||
|
} else if (in.dtype() == bfloat16) {
|
||||||
|
return slow_conv_1D<bfloat16_t>(
|
||||||
|
in, wt, out, padding, wt_strides, wt_dilation);
|
||||||
|
} else {
|
||||||
|
throw std::invalid_argument(
|
||||||
|
"[Convolution::eval] got unsupported data type.");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void dispatch_slow_conv_2D(
|
||||||
|
const array& in,
|
||||||
|
const array& wt,
|
||||||
|
array out,
|
||||||
|
const std::vector<int>& padding,
|
||||||
|
const std::vector<int>& wt_strides,
|
||||||
|
const std::vector<int>& wt_dilation) {
|
||||||
|
if (in.dtype() == float32) {
|
||||||
|
return slow_conv_2D<float>(in, wt, out, padding, wt_strides, wt_dilation);
|
||||||
|
} else if (in.dtype() == float16) {
|
||||||
|
return slow_conv_2D<float16_t>(
|
||||||
|
in, wt, out, padding, wt_strides, wt_dilation);
|
||||||
|
} else if (in.dtype() == bfloat16) {
|
||||||
|
return slow_conv_2D<bfloat16_t>(
|
||||||
|
in, wt, out, padding, wt_strides, wt_dilation);
|
||||||
|
} else {
|
||||||
|
throw std::invalid_argument(
|
||||||
|
"[Convolution::eval] got unsupported data type.");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
///////////////////////////////////////////////////////////////////////////////
|
||||||
|
// Explicit gemm conv
|
||||||
|
///////////////////////////////////////////////////////////////////////////////
|
||||||
|
|
||||||
|
void explicit_gemm_conv_1D_cpu(
|
||||||
|
const array& in,
|
||||||
|
const array& wt,
|
||||||
|
array out,
|
||||||
|
const std::vector<int>& padding,
|
||||||
|
const std::vector<int>& wt_strides,
|
||||||
|
const std::vector<int>& wt_dilation) {
|
||||||
|
const int N = in.shape(0); // Batch size, should be the same as out.shape(0)
|
||||||
|
const int iH = in.shape(1); // Input spatial dim
|
||||||
|
const int oH = out.shape(1); // Output spatial dim
|
||||||
|
const int O = wt.shape(0); // Out channels
|
||||||
|
const int C = wt.shape(2); // In channels
|
||||||
|
const int wH = wt.shape(1); // Weight spatial dim
|
||||||
|
|
||||||
|
auto conv_dtype = float32;
|
||||||
|
|
||||||
|
// Pad input
|
||||||
|
std::vector<int> padded_shape = {N, iH + 2 * padding[0], C};
|
||||||
|
array in_padded(padded_shape, conv_dtype, nullptr, {});
|
||||||
|
|
||||||
|
// Fill with zeros
|
||||||
|
copy(array(0, conv_dtype), in_padded, CopyType::Scalar);
|
||||||
|
|
||||||
|
// Pick input slice from padded
|
||||||
|
size_t data_offset = padding[0] * in_padded.strides()[1];
|
||||||
|
array in_padded_slice(in.shape(), in_padded.dtype(), nullptr, {});
|
||||||
|
in_padded_slice.copy_shared_buffer(
|
||||||
|
in_padded,
|
||||||
|
in_padded.strides(),
|
||||||
|
in_padded.flags(),
|
||||||
|
in_padded_slice.size(),
|
||||||
|
data_offset);
|
||||||
|
|
||||||
|
// Copy input values into the slice
|
||||||
|
copy_inplace(in, in_padded_slice, CopyType::GeneralGeneral);
|
||||||
|
|
||||||
|
// Make strided view
|
||||||
|
std::vector<int> strided_shape = {N, oH, wH, C};
|
||||||
|
|
||||||
|
std::vector<size_t> strided_strides = {
|
||||||
|
in_padded.strides()[0],
|
||||||
|
in_padded.strides()[1] * wt_strides[0],
|
||||||
|
in_padded.strides()[1],
|
||||||
|
in_padded.strides()[2]};
|
||||||
|
auto flags = in_padded.flags();
|
||||||
|
|
||||||
|
array in_strided_view(strided_shape, in_padded.dtype(), nullptr, {});
|
||||||
|
in_strided_view.copy_shared_buffer(
|
||||||
|
in_padded, strided_strides, flags, in_strided_view.size(), 0);
|
||||||
|
|
||||||
|
// Materialize strided view
|
||||||
|
std::vector<int> strided_reshape = {N * oH, wH * C};
|
||||||
|
array in_strided(strided_reshape, in_strided_view.dtype(), nullptr, {});
|
||||||
|
copy(in_strided_view, in_strided, CopyType::General);
|
||||||
|
|
||||||
|
// Check wt dtype and prepare
|
||||||
|
auto gemm_wt = wt;
|
||||||
|
auto gemm_out = out;
|
||||||
|
|
||||||
|
if (wt.dtype() != float32 || !wt.flags().row_contiguous) {
|
||||||
|
auto ctype =
|
||||||
|
wt.flags().row_contiguous ? CopyType::Vector : CopyType::General;
|
||||||
|
gemm_wt = array(wt.shape(), float32, nullptr, {});
|
||||||
|
copy(wt, gemm_wt, ctype);
|
||||||
|
}
|
||||||
|
|
||||||
|
if (out.dtype() != float32) {
|
||||||
|
gemm_out = array(out.shape(), float32, nullptr, {});
|
||||||
|
gemm_out.set_data(allocator::malloc_or_wait(gemm_out.nbytes()));
|
||||||
|
}
|
||||||
|
|
||||||
|
// Peform gemm
|
||||||
|
cblas_sgemm(
|
||||||
|
CblasRowMajor,
|
||||||
|
CblasNoTrans, // no trans A
|
||||||
|
CblasTrans, // transB
|
||||||
|
strided_reshape[0], // M
|
||||||
|
O, // N
|
||||||
|
strided_reshape[1], // K
|
||||||
|
1.0f, // alpha
|
||||||
|
in_strided.data<float>(),
|
||||||
|
strided_reshape[1], // lda
|
||||||
|
gemm_wt.data<float>(),
|
||||||
|
strided_reshape[1], // ldb
|
||||||
|
0.0f, // beta
|
||||||
|
gemm_out.data<float>(),
|
||||||
|
O // ldc
|
||||||
|
);
|
||||||
|
|
||||||
|
// Copy results if needed
|
||||||
|
if (out.dtype() != float32) {
|
||||||
|
copy(gemm_out, out, CopyType::Vector);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void explicit_gemm_conv_2D_cpu(
|
||||||
|
const array& in,
|
||||||
|
const array& wt,
|
||||||
|
array out,
|
||||||
|
const std::vector<int>& padding,
|
||||||
|
const std::vector<int>& wt_strides,
|
||||||
|
const std::vector<int>& wt_dilation) {
|
||||||
|
const int N = in.shape(0); // Batch size, should be the same as out.shape(0)
|
||||||
|
const int iH = in.shape(1); // Input spatial dim
|
||||||
|
const int iW = in.shape(2); // Input spatial dim
|
||||||
|
const int oH = out.shape(1); // Output spatial dim
|
||||||
|
const int oW = out.shape(2); // Output spatial dim
|
||||||
|
const int O = wt.shape(0); // Out channels
|
||||||
|
const int C = wt.shape(3); // In channels
|
||||||
|
const int wH = wt.shape(1); // Weight spatial dim
|
||||||
|
const int wW = wt.shape(2); // Weight spatial dim
|
||||||
|
|
||||||
|
auto conv_dtype = out.dtype();
|
||||||
|
|
||||||
|
// Pad input
|
||||||
|
std::vector<int> padded_shape = {
|
||||||
|
N, iH + 2 * padding[0], iW + 2 * padding[1], C};
|
||||||
|
array in_padded(padded_shape, conv_dtype, nullptr, {});
|
||||||
|
|
||||||
|
// Fill with zeros
|
||||||
|
copy(array(0, conv_dtype), in_padded, CopyType::Scalar);
|
||||||
|
|
||||||
|
// Pick input slice from padded
|
||||||
|
size_t data_offset =
|
||||||
|
padding[0] * in_padded.strides()[1] + padding[1] * in_padded.strides()[2];
|
||||||
|
array in_padded_slice(in.shape(), in_padded.dtype(), nullptr, {});
|
||||||
|
in_padded_slice.copy_shared_buffer(
|
||||||
|
in_padded,
|
||||||
|
in_padded.strides(),
|
||||||
|
in_padded.flags(),
|
||||||
|
in_padded_slice.size(),
|
||||||
|
data_offset);
|
||||||
|
|
||||||
|
// Copy input values into the slice
|
||||||
|
copy_inplace(in, in_padded_slice, CopyType::GeneralGeneral);
|
||||||
|
|
||||||
|
// Make strided view
|
||||||
|
std::vector<int> strided_shape = {N, oH, oW, wH, wW, C};
|
||||||
|
|
||||||
|
std::vector<size_t> strided_strides = {
|
||||||
|
in_padded.strides()[0],
|
||||||
|
in_padded.strides()[1] * wt_strides[0],
|
||||||
|
in_padded.strides()[2] * wt_strides[1],
|
||||||
|
in_padded.strides()[1],
|
||||||
|
in_padded.strides()[2],
|
||||||
|
in_padded.strides()[3]};
|
||||||
|
auto flags = in_padded.flags();
|
||||||
|
|
||||||
|
array in_strided_view(strided_shape, in_padded.dtype(), nullptr, {});
|
||||||
|
in_strided_view.copy_shared_buffer(
|
||||||
|
in_padded, strided_strides, flags, in_strided_view.size(), 0);
|
||||||
|
|
||||||
|
// Materialize strided view
|
||||||
|
std::vector<int> strided_reshape = {N * oH * oW, wH * wW * C};
|
||||||
|
array in_strided(strided_reshape, in_strided_view.dtype(), nullptr, {});
|
||||||
|
copy(in_strided_view, in_strided, CopyType::General);
|
||||||
|
|
||||||
|
// Check wt dtype and prepare
|
||||||
|
auto gemm_wt = wt;
|
||||||
|
auto gemm_out = out;
|
||||||
|
|
||||||
|
if (wt.dtype() != float32 || !wt.flags().row_contiguous) {
|
||||||
|
auto ctype =
|
||||||
|
wt.flags().row_contiguous ? CopyType::Vector : CopyType::General;
|
||||||
|
gemm_wt = array(wt.shape(), float32, nullptr, {});
|
||||||
|
copy(wt, gemm_wt, ctype);
|
||||||
|
}
|
||||||
|
|
||||||
|
if (out.dtype() != float32) {
|
||||||
|
gemm_out = array(out.shape(), float32, nullptr, {});
|
||||||
|
gemm_out.set_data(allocator::malloc_or_wait(gemm_out.nbytes()));
|
||||||
|
}
|
||||||
|
|
||||||
|
// Peform gemm
|
||||||
|
cblas_sgemm(
|
||||||
|
CblasRowMajor,
|
||||||
|
CblasNoTrans, // no trans A
|
||||||
|
CblasTrans, // transB
|
||||||
|
strided_reshape[0], // M
|
||||||
|
O, // N
|
||||||
|
strided_reshape[1], // K
|
||||||
|
1.0f, // alpha
|
||||||
|
in_strided.data<float>(),
|
||||||
|
strided_reshape[1], // lda
|
||||||
|
gemm_wt.data<float>(),
|
||||||
|
strided_reshape[1], // ldb
|
||||||
|
0.0f, // beta
|
||||||
|
gemm_out.data<float>(),
|
||||||
|
O // ldc
|
||||||
|
);
|
||||||
|
|
||||||
|
// Copy results if needed
|
||||||
|
if (out.dtype() != float32) {
|
||||||
|
copy(gemm_out, out, CopyType::Vector);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
///////////////////////////////////////////////////////////////////////////////
|
||||||
|
// Conv routing
|
||||||
|
///////////////////////////////////////////////////////////////////////////////
|
||||||
|
|
||||||
|
void conv_1D_cpu(
|
||||||
|
const array& in,
|
||||||
|
const array& wt,
|
||||||
|
array out,
|
||||||
|
const std::vector<int>& padding,
|
||||||
|
const std::vector<int>& wt_strides,
|
||||||
|
const std::vector<int>& wt_dilation) {
|
||||||
|
if (wt_dilation[0] == 1) {
|
||||||
|
return explicit_gemm_conv_1D_cpu(
|
||||||
|
in, wt, out, padding, wt_strides, wt_dilation);
|
||||||
|
}
|
||||||
|
|
||||||
|
return dispatch_slow_conv_1D(in, wt, out, padding, wt_strides, wt_dilation);
|
||||||
|
}
|
||||||
|
|
||||||
|
void conv_2D_cpu(
|
||||||
|
const array& in,
|
||||||
|
const array& wt,
|
||||||
|
array out,
|
||||||
|
const std::vector<int>& padding,
|
||||||
|
const std::vector<int>& wt_strides,
|
||||||
|
const std::vector<int>& wt_dilation) {
|
||||||
|
return dispatch_slow_conv_2D(in, wt, out, padding, wt_strides, wt_dilation);
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace
|
||||||
|
|
||||||
|
void Convolution::eval(const std::vector<array>& inputs, array& out) {
|
||||||
|
out.set_data(allocator::malloc_or_wait(out.nbytes()));
|
||||||
|
|
||||||
|
auto& in = inputs[0];
|
||||||
|
auto& wt = inputs[1];
|
||||||
|
|
||||||
|
// 2D convolution
|
||||||
|
if (in.ndim() == (2 + 2)) {
|
||||||
|
return conv_2D_cpu(
|
||||||
|
in, wt, out, padding_, kernel_strides_, kernel_dilation_);
|
||||||
|
}
|
||||||
|
// 1D convolution
|
||||||
|
else if (in.ndim() == (1 + 2)) {
|
||||||
|
return conv_1D_cpu(
|
||||||
|
in, wt, out, padding_, kernel_strides_, kernel_dilation_);
|
||||||
|
}
|
||||||
|
// Throw error
|
||||||
|
else {
|
||||||
|
std::ostringstream msg;
|
||||||
|
msg << "[Convolution::eval] Convolution currently only supports"
|
||||||
|
<< " 1D and 2D convolutions. Got inputs with " << in.ndim() - 2
|
||||||
|
<< " spatial dimensions";
|
||||||
|
throw std::invalid_argument(msg.str());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace mlx::core
|
||||||
308
mlx/backend/common/copy.cpp
Normal file
308
mlx/backend/common/copy.cpp
Normal file
@@ -0,0 +1,308 @@
|
|||||||
|
#include <numeric>
|
||||||
|
|
||||||
|
#include "mlx/allocator.h"
|
||||||
|
#include "mlx/backend/common/copy.h"
|
||||||
|
|
||||||
|
namespace mlx::core {
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
|
||||||
|
template <typename SrcT, typename DstT>
|
||||||
|
void copy_single(const array& src, array& dst) {
|
||||||
|
auto val = static_cast<DstT>(src.data<SrcT>()[0]);
|
||||||
|
auto dst_ptr = dst.data<DstT>();
|
||||||
|
for (int i = 0; i < dst.size(); ++i) {
|
||||||
|
dst_ptr[i] = val;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename SrcT, typename DstT>
|
||||||
|
void copy_vector(const array& src, array& dst) {
|
||||||
|
auto src_ptr = src.data<SrcT>();
|
||||||
|
auto dst_ptr = dst.data<DstT>();
|
||||||
|
std::copy(src_ptr, src_ptr + src.data_size(), dst_ptr);
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename SrcT, typename DstT>
|
||||||
|
void copy_general_dim1(const array& src, array& dst) {
|
||||||
|
const SrcT* src_ptr = src.data<SrcT>();
|
||||||
|
DstT* dst_ptr = dst.data<DstT>();
|
||||||
|
size_t src_idx = 0;
|
||||||
|
size_t dst_idx = 0;
|
||||||
|
for (size_t i = 0; i < src.shape()[0]; ++i) {
|
||||||
|
dst_ptr[dst_idx++] = static_cast<DstT>(src_ptr[src_idx]);
|
||||||
|
src_idx += src.strides()[0];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename SrcT, typename DstT>
|
||||||
|
void copy_general_dim2(const array& src, array& dst) {
|
||||||
|
const SrcT* src_ptr = src.data<SrcT>();
|
||||||
|
DstT* dst_ptr = dst.data<DstT>();
|
||||||
|
size_t src_idx = 0;
|
||||||
|
size_t dst_idx = 0;
|
||||||
|
for (size_t i = 0; i < src.shape()[0]; ++i) {
|
||||||
|
for (size_t j = 0; j < src.shape()[1]; ++j) {
|
||||||
|
dst_ptr[dst_idx++] = static_cast<DstT>(src_ptr[src_idx]);
|
||||||
|
src_idx += src.strides()[1];
|
||||||
|
}
|
||||||
|
src_idx += src.strides()[0] - src.strides()[1] * src.shape()[1];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename SrcT, typename DstT>
|
||||||
|
void copy_general_dim3(const array& src, array& dst) {
|
||||||
|
const SrcT* src_ptr = src.data<SrcT>();
|
||||||
|
DstT* dst_ptr = dst.data<DstT>();
|
||||||
|
size_t src_idx = 0;
|
||||||
|
size_t dst_idx = 0;
|
||||||
|
for (size_t i = 0; i < src.shape()[0]; ++i) {
|
||||||
|
for (size_t j = 0; j < src.shape()[1]; ++j) {
|
||||||
|
for (size_t k = 0; k < src.shape()[2]; ++k) {
|
||||||
|
dst_ptr[dst_idx++] = static_cast<DstT>(src_ptr[src_idx]);
|
||||||
|
src_idx += src.strides()[2];
|
||||||
|
}
|
||||||
|
src_idx += src.strides()[1] - src.strides()[2] * src.shape()[2];
|
||||||
|
}
|
||||||
|
src_idx += src.strides()[0] - src.strides()[1] * src.shape()[1];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename SrcT, typename DstT>
|
||||||
|
void copy_general_dim4(const array& src, array& dst) {
|
||||||
|
const SrcT* src_ptr = src.data<SrcT>();
|
||||||
|
DstT* dst_ptr = dst.data<DstT>();
|
||||||
|
size_t src_idx = 0;
|
||||||
|
size_t dst_idx = 0;
|
||||||
|
for (size_t i = 0; i < src.shape()[0]; ++i) {
|
||||||
|
for (size_t j = 0; j < src.shape()[1]; ++j) {
|
||||||
|
for (size_t k = 0; k < src.shape()[2]; ++k) {
|
||||||
|
for (size_t ii = 0; ii < src.shape()[3]; ++ii) {
|
||||||
|
dst_ptr[dst_idx++] = static_cast<DstT>(src_ptr[src_idx]);
|
||||||
|
src_idx += src.strides()[3];
|
||||||
|
}
|
||||||
|
src_idx += src.strides()[2] - src.strides()[3] * src.shape()[3];
|
||||||
|
}
|
||||||
|
src_idx += src.strides()[1] - src.strides()[2] * src.shape()[2];
|
||||||
|
}
|
||||||
|
src_idx += src.strides()[0] - src.strides()[1] * src.shape()[1];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename SrcT, typename DstT>
|
||||||
|
void copy_general(const array& src, array& dst) {
|
||||||
|
switch (src.ndim()) {
|
||||||
|
case 1:
|
||||||
|
copy_general_dim1<SrcT, DstT>(src, dst);
|
||||||
|
return;
|
||||||
|
case 2:
|
||||||
|
copy_general_dim2<SrcT, DstT>(src, dst);
|
||||||
|
return;
|
||||||
|
case 3:
|
||||||
|
copy_general_dim3<SrcT, DstT>(src, dst);
|
||||||
|
return;
|
||||||
|
case 4:
|
||||||
|
copy_general_dim4<SrcT, DstT>(src, dst);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
auto src_ptr = src.data<SrcT>();
|
||||||
|
auto dst_ptr = dst.data<DstT>();
|
||||||
|
for (size_t i = 0; i < dst.size(); ++i) {
|
||||||
|
size_t src_elem = elem_to_loc(i, src.shape(), src.strides());
|
||||||
|
dst_ptr[i] = static_cast<DstT>(src_ptr[src_elem]);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename SrcT, typename DstT, int D>
|
||||||
|
inline void copy_general_general_dims(
|
||||||
|
const array& src,
|
||||||
|
array& dst,
|
||||||
|
size_t offset_src,
|
||||||
|
size_t offset_dst) {
|
||||||
|
if constexpr (D > 1) {
|
||||||
|
int axis = src.ndim() - D;
|
||||||
|
auto stride_src = src.strides()[axis];
|
||||||
|
auto stride_dst = dst.strides()[axis];
|
||||||
|
auto N = src.shape(axis);
|
||||||
|
for (int i = 0; i < N; i++) {
|
||||||
|
copy_general_general_dims<SrcT, DstT, D - 1>(
|
||||||
|
src, dst, offset_src, offset_dst);
|
||||||
|
offset_src += stride_src;
|
||||||
|
offset_dst += stride_dst;
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
int axis = src.ndim() - 1;
|
||||||
|
auto stride_src = src.strides()[axis];
|
||||||
|
auto stride_dst = dst.strides()[axis];
|
||||||
|
auto N = src.shape(axis);
|
||||||
|
const SrcT* src_ptr = src.data<SrcT>() + offset_src;
|
||||||
|
DstT* dst_ptr = dst.data<DstT>() + offset_dst;
|
||||||
|
for (int i = 0; i < N; i++) {
|
||||||
|
*dst_ptr = static_cast<DstT>(*src_ptr);
|
||||||
|
src_ptr += stride_src;
|
||||||
|
dst_ptr += stride_dst;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename SrcT, typename DstT>
|
||||||
|
void copy_general_general(const array& src, array& dst) {
|
||||||
|
switch (src.ndim()) {
|
||||||
|
case 1:
|
||||||
|
copy_general_general_dims<SrcT, DstT, 1>(src, dst, 0, 0);
|
||||||
|
return;
|
||||||
|
case 2:
|
||||||
|
copy_general_general_dims<SrcT, DstT, 2>(src, dst, 0, 0);
|
||||||
|
return;
|
||||||
|
case 3:
|
||||||
|
copy_general_general_dims<SrcT, DstT, 3>(src, dst, 0, 0);
|
||||||
|
return;
|
||||||
|
case 4:
|
||||||
|
copy_general_general_dims<SrcT, DstT, 4>(src, dst, 0, 0);
|
||||||
|
return;
|
||||||
|
case 5:
|
||||||
|
copy_general_general_dims<SrcT, DstT, 5>(src, dst, 0, 0);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
int size = std::accumulate(
|
||||||
|
src.shape().begin() - 5, src.shape().end(), 1, std::multiplies<int>());
|
||||||
|
for (int i = 0; i < src.size(); i += size) {
|
||||||
|
size_t offset_src = elem_to_loc(i, src.shape(), src.strides());
|
||||||
|
size_t offset_dst = elem_to_loc(i, dst.shape(), dst.strides());
|
||||||
|
copy_general_general_dims<SrcT, DstT, 5>(src, dst, offset_src, offset_dst);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename SrcT, typename DstT>
|
||||||
|
void copy(const array& src, array& dst, CopyType ctype) {
|
||||||
|
switch (ctype) {
|
||||||
|
case CopyType::Scalar:
|
||||||
|
copy_single<SrcT, DstT>(src, dst);
|
||||||
|
return;
|
||||||
|
case CopyType::Vector:
|
||||||
|
copy_vector<SrcT, DstT>(src, dst);
|
||||||
|
return;
|
||||||
|
case CopyType::General:
|
||||||
|
copy_general<SrcT, DstT>(src, dst);
|
||||||
|
return;
|
||||||
|
case CopyType::GeneralGeneral:
|
||||||
|
copy_general_general<SrcT, DstT>(src, dst);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename SrcT>
|
||||||
|
void copy(const array& src, array& dst, CopyType ctype) {
|
||||||
|
switch (dst.dtype()) {
|
||||||
|
case bool_:
|
||||||
|
copy<SrcT, bool>(src, dst, ctype);
|
||||||
|
break;
|
||||||
|
case uint8:
|
||||||
|
copy<SrcT, uint8_t>(src, dst, ctype);
|
||||||
|
break;
|
||||||
|
case uint16:
|
||||||
|
copy<SrcT, uint16_t>(src, dst, ctype);
|
||||||
|
break;
|
||||||
|
case uint32:
|
||||||
|
copy<SrcT, uint32_t>(src, dst, ctype);
|
||||||
|
break;
|
||||||
|
case uint64:
|
||||||
|
copy<SrcT, uint64_t>(src, dst, ctype);
|
||||||
|
break;
|
||||||
|
case int8:
|
||||||
|
copy<SrcT, int8_t>(src, dst, ctype);
|
||||||
|
break;
|
||||||
|
case int16:
|
||||||
|
copy<SrcT, int16_t>(src, dst, ctype);
|
||||||
|
break;
|
||||||
|
case int32:
|
||||||
|
copy<SrcT, int32_t>(src, dst, ctype);
|
||||||
|
break;
|
||||||
|
case int64:
|
||||||
|
copy<SrcT, int64_t>(src, dst, ctype);
|
||||||
|
break;
|
||||||
|
case float16:
|
||||||
|
copy<SrcT, float16_t>(src, dst, ctype);
|
||||||
|
break;
|
||||||
|
case float32:
|
||||||
|
copy<SrcT, float>(src, dst, ctype);
|
||||||
|
break;
|
||||||
|
case bfloat16:
|
||||||
|
copy<SrcT, bfloat16_t>(src, dst, ctype);
|
||||||
|
break;
|
||||||
|
case complex64:
|
||||||
|
copy<SrcT, complex64_t>(src, dst, ctype);
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace
|
||||||
|
|
||||||
|
void copy_inplace(const array& src, array& dst, CopyType ctype) {
|
||||||
|
switch (src.dtype()) {
|
||||||
|
case bool_:
|
||||||
|
copy<bool>(src, dst, ctype);
|
||||||
|
break;
|
||||||
|
case uint8:
|
||||||
|
copy<uint8_t>(src, dst, ctype);
|
||||||
|
break;
|
||||||
|
case uint16:
|
||||||
|
copy<uint16_t>(src, dst, ctype);
|
||||||
|
break;
|
||||||
|
case uint32:
|
||||||
|
copy<uint32_t>(src, dst, ctype);
|
||||||
|
break;
|
||||||
|
case uint64:
|
||||||
|
copy<uint64_t>(src, dst, ctype);
|
||||||
|
break;
|
||||||
|
case int8:
|
||||||
|
copy<int8_t>(src, dst, ctype);
|
||||||
|
break;
|
||||||
|
case int16:
|
||||||
|
copy<int16_t>(src, dst, ctype);
|
||||||
|
break;
|
||||||
|
case int32:
|
||||||
|
copy<int32_t>(src, dst, ctype);
|
||||||
|
break;
|
||||||
|
case int64:
|
||||||
|
copy<int64_t>(src, dst, ctype);
|
||||||
|
break;
|
||||||
|
case float16:
|
||||||
|
copy<float16_t>(src, dst, ctype);
|
||||||
|
break;
|
||||||
|
case float32:
|
||||||
|
copy<float>(src, dst, ctype);
|
||||||
|
break;
|
||||||
|
case bfloat16:
|
||||||
|
copy<bfloat16_t>(src, dst, ctype);
|
||||||
|
break;
|
||||||
|
case complex64:
|
||||||
|
copy<complex64_t>(src, dst, ctype);
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void copy(const array& src, array& dst, CopyType ctype) {
|
||||||
|
// Allocate the output
|
||||||
|
switch (ctype) {
|
||||||
|
case CopyType::Vector:
|
||||||
|
dst.set_data(
|
||||||
|
allocator::malloc_or_wait(src.data_size() * dst.itemsize()),
|
||||||
|
src.data_size(),
|
||||||
|
src.strides(),
|
||||||
|
src.flags());
|
||||||
|
break;
|
||||||
|
case CopyType::Scalar:
|
||||||
|
case CopyType::General:
|
||||||
|
case CopyType::GeneralGeneral:
|
||||||
|
dst.set_data(allocator::malloc_or_wait(dst.nbytes()));
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
if (ctype == CopyType::GeneralGeneral) {
|
||||||
|
ctype = CopyType::General;
|
||||||
|
}
|
||||||
|
copy_inplace(src, dst, ctype);
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace mlx::core
|
||||||
130
mlx/backend/common/default_primitives.cpp
Normal file
130
mlx/backend/common/default_primitives.cpp
Normal file
@@ -0,0 +1,130 @@
|
|||||||
|
#include <cblas.h>
|
||||||
|
|
||||||
|
#include "mlx/array.h"
|
||||||
|
#include "mlx/backend/common/copy.h"
|
||||||
|
#include "mlx/backend/common/utils.h"
|
||||||
|
#include "mlx/primitives.h"
|
||||||
|
|
||||||
|
#define DEFAULT(primitive) \
|
||||||
|
void primitive::eval_cpu(const std::vector<array>& inputs, array& out) { \
|
||||||
|
primitive::eval(inputs, out); \
|
||||||
|
}
|
||||||
|
|
||||||
|
namespace mlx::core {
|
||||||
|
|
||||||
|
DEFAULT(Abs)
|
||||||
|
DEFAULT(Add)
|
||||||
|
DEFAULT(Arange)
|
||||||
|
DEFAULT(ArcCos)
|
||||||
|
DEFAULT(ArcCosh)
|
||||||
|
DEFAULT(ArcSin)
|
||||||
|
DEFAULT(ArcSinh)
|
||||||
|
DEFAULT(ArcTan)
|
||||||
|
DEFAULT(ArcTanh)
|
||||||
|
DEFAULT(ArgPartition)
|
||||||
|
DEFAULT(ArgReduce)
|
||||||
|
DEFAULT(ArgSort)
|
||||||
|
DEFAULT(AsType)
|
||||||
|
DEFAULT(AsStrided)
|
||||||
|
DEFAULT(Broadcast)
|
||||||
|
DEFAULT(Concatenate)
|
||||||
|
DEFAULT(Convolution)
|
||||||
|
DEFAULT(Copy)
|
||||||
|
DEFAULT(Cos)
|
||||||
|
DEFAULT(Cosh)
|
||||||
|
DEFAULT(Divide)
|
||||||
|
DEFAULT(Equal)
|
||||||
|
DEFAULT(Erf)
|
||||||
|
DEFAULT(ErfInv)
|
||||||
|
DEFAULT(Exp)
|
||||||
|
DEFAULT(FFT)
|
||||||
|
DEFAULT(Full)
|
||||||
|
DEFAULT(Gather)
|
||||||
|
DEFAULT(Greater)
|
||||||
|
DEFAULT(GreaterEqual)
|
||||||
|
DEFAULT(Less)
|
||||||
|
DEFAULT(LessEqual)
|
||||||
|
DEFAULT(Load)
|
||||||
|
DEFAULT(Log)
|
||||||
|
DEFAULT(Log1p)
|
||||||
|
DEFAULT(LogicalNot)
|
||||||
|
DEFAULT(LogAddExp)
|
||||||
|
DEFAULT(Maximum)
|
||||||
|
DEFAULT(Minimum)
|
||||||
|
DEFAULT(Multiply)
|
||||||
|
DEFAULT(Negative)
|
||||||
|
DEFAULT(NotEqual)
|
||||||
|
DEFAULT(Pad)
|
||||||
|
DEFAULT(Partition)
|
||||||
|
DEFAULT(Power)
|
||||||
|
DEFAULT(RandomBits)
|
||||||
|
DEFAULT(Reduce)
|
||||||
|
DEFAULT(Reshape)
|
||||||
|
DEFAULT(Scan)
|
||||||
|
DEFAULT(Scatter)
|
||||||
|
DEFAULT(Sigmoid)
|
||||||
|
DEFAULT(Sign)
|
||||||
|
DEFAULT(Sin)
|
||||||
|
DEFAULT(Sinh)
|
||||||
|
DEFAULT(Slice)
|
||||||
|
DEFAULT(Softmax)
|
||||||
|
DEFAULT(Sort)
|
||||||
|
DEFAULT(Square)
|
||||||
|
DEFAULT(Sqrt)
|
||||||
|
DEFAULT(StopGradient)
|
||||||
|
DEFAULT(Subtract)
|
||||||
|
DEFAULT(Tan)
|
||||||
|
DEFAULT(Tanh)
|
||||||
|
DEFAULT(Transpose)
|
||||||
|
|
||||||
|
void Matmul::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||||
|
if (out.dtype() != float32) {
|
||||||
|
throw std::runtime_error(
|
||||||
|
"[Matmul::eval_cpu] Currently only supports float32.");
|
||||||
|
}
|
||||||
|
out.set_data(allocator::malloc_or_wait(out.nbytes()));
|
||||||
|
|
||||||
|
auto& a_pre = inputs[0];
|
||||||
|
auto& b_pre = inputs[1];
|
||||||
|
|
||||||
|
auto check_transpose = [](const array& arr) {
|
||||||
|
auto stx = arr.strides()[arr.ndim() - 2];
|
||||||
|
auto sty = arr.strides()[arr.ndim() - 1];
|
||||||
|
if (stx == arr.shape(-1) && sty == 1) {
|
||||||
|
return std::make_tuple(false, stx, arr);
|
||||||
|
} else if (stx == 1 && sty == arr.shape(-2)) {
|
||||||
|
return std::make_tuple(true, sty, arr);
|
||||||
|
} else {
|
||||||
|
array arr_copy(arr.shape(), arr.dtype(), nullptr, {});
|
||||||
|
copy(arr, arr_copy, CopyType::General);
|
||||||
|
size_t stx = arr.shape(-1);
|
||||||
|
return std::make_tuple(false, stx, arr_copy);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
auto [a_transposed, lda, a] = check_transpose(a_pre);
|
||||||
|
auto [b_transposed, ldb, b] = check_transpose(b_pre);
|
||||||
|
int M = a.shape(-2);
|
||||||
|
int N = b.shape(-1);
|
||||||
|
int K = a.shape(-1);
|
||||||
|
for (int i = 0; i < (a.size() / (M * K)); ++i) {
|
||||||
|
cblas_sgemm(
|
||||||
|
CblasRowMajor,
|
||||||
|
a_transposed ? CblasTrans : CblasNoTrans, // transA
|
||||||
|
b_transposed ? CblasTrans : CblasNoTrans, // transB
|
||||||
|
M,
|
||||||
|
N,
|
||||||
|
K,
|
||||||
|
1.0f, // alpha
|
||||||
|
a.data<float>() + elem_to_loc(M * K * i, a.shape(), a.strides()),
|
||||||
|
lda,
|
||||||
|
b.data<float>() + elem_to_loc(K * N * i, b.shape(), b.strides()),
|
||||||
|
ldb,
|
||||||
|
0.0f, // beta
|
||||||
|
out.data<float>() + M * N * i,
|
||||||
|
out.shape(-1) // ldc
|
||||||
|
);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace mlx::core
|
||||||
38
mlx/backend/common/erf.cpp
Normal file
38
mlx/backend/common/erf.cpp
Normal file
@@ -0,0 +1,38 @@
|
|||||||
|
#include <cmath>
|
||||||
|
|
||||||
|
namespace mlx::core {
|
||||||
|
|
||||||
|
/* Approximation to the inverse error function.
|
||||||
|
* Based on code from:
|
||||||
|
* https://stackoverflow.com/questions/27229371/inverse-error-function-in-c#answer-49743348
|
||||||
|
*/
|
||||||
|
float erfinv(float a) {
|
||||||
|
auto t = std::fma(a, 0.0f - a, 1.0f);
|
||||||
|
t = std::log(t);
|
||||||
|
float p;
|
||||||
|
if (std::abs(t) > 6.125f) { // maximum ulp error = 2.35793
|
||||||
|
p = 3.03697567e-10f; // 0x1.4deb44p-32
|
||||||
|
p = std::fma(p, t, 2.93243101e-8f); // 0x1.f7c9aep-26
|
||||||
|
p = std::fma(p, t, 1.22150334e-6f); // 0x1.47e512p-20
|
||||||
|
p = std::fma(p, t, 2.84108955e-5f); // 0x1.dca7dep-16
|
||||||
|
p = std::fma(p, t, 3.93552968e-4f); // 0x1.9cab92p-12
|
||||||
|
p = std::fma(p, t, 3.02698812e-3f); // 0x1.8cc0dep-9
|
||||||
|
p = std::fma(p, t, 4.83185798e-3f); // 0x1.3ca920p-8
|
||||||
|
p = std::fma(p, t, -2.64646143e-1f); // -0x1.0eff66p-2
|
||||||
|
p = std::fma(p, t, 8.40016484e-1f); // 0x1.ae16a4p-1
|
||||||
|
} else { // maximum ulp error = 2.35002
|
||||||
|
p = 5.43877832e-9f; // 0x1.75c000p-28
|
||||||
|
p = std::fma(p, t, 1.43285448e-7f); // 0x1.33b402p-23
|
||||||
|
p = std::fma(p, t, 1.22774793e-6f); // 0x1.499232p-20
|
||||||
|
p = std::fma(p, t, 1.12963626e-7f); // 0x1.e52cd2p-24
|
||||||
|
p = std::fma(p, t, -5.61530760e-5f); // -0x1.d70bd0p-15
|
||||||
|
p = std::fma(p, t, -1.47697632e-4f); // -0x1.35be90p-13
|
||||||
|
p = std::fma(p, t, 2.31468678e-3f); // 0x1.2f6400p-9
|
||||||
|
p = std::fma(p, t, 1.15392581e-2f); // 0x1.7a1e50p-7
|
||||||
|
p = std::fma(p, t, -2.32015476e-1f); // -0x1.db2aeep-3
|
||||||
|
p = std::fma(p, t, 8.86226892e-1f); // 0x1.c5bf88p-1
|
||||||
|
}
|
||||||
|
return a * p;
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace mlx::core
|
||||||
10
mlx/backend/common/erf.h
Normal file
10
mlx/backend/common/erf.h
Normal file
@@ -0,0 +1,10 @@
|
|||||||
|
|
||||||
|
namespace mlx::core {
|
||||||
|
|
||||||
|
/* Approximation to the inverse error function.
|
||||||
|
* Based on code from:
|
||||||
|
* https://stackoverflow.com/questions/27229371/inverse-error-function-in-c#answer-49743348
|
||||||
|
*/
|
||||||
|
float erfinv(float a);
|
||||||
|
|
||||||
|
} // namespace mlx::core
|
||||||
377
mlx/backend/common/indexing.cpp
Normal file
377
mlx/backend/common/indexing.cpp
Normal file
@@ -0,0 +1,377 @@
|
|||||||
|
#include <algorithm>
|
||||||
|
#include <cassert>
|
||||||
|
#include <cmath>
|
||||||
|
|
||||||
|
#include "mlx/allocator.h"
|
||||||
|
#include "mlx/primitives.h"
|
||||||
|
|
||||||
|
#include "mlx/backend/common/copy.h"
|
||||||
|
#include "mlx/backend/common/utils.h"
|
||||||
|
|
||||||
|
namespace mlx::core {
|
||||||
|
|
||||||
|
template <typename IdxT>
|
||||||
|
inline size_t offset_neg_idx(IdxT idx, size_t size) {
|
||||||
|
return (idx < 0) ? idx + size : idx;
|
||||||
|
}
|
||||||
|
|
||||||
|
template <>
|
||||||
|
inline size_t offset_neg_idx(bool idx, size_t) {
|
||||||
|
return idx;
|
||||||
|
}
|
||||||
|
|
||||||
|
template <>
|
||||||
|
inline size_t offset_neg_idx(uint32_t idx, size_t) {
|
||||||
|
return idx;
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename T, typename IdxT>
|
||||||
|
void gather(
|
||||||
|
const array& src,
|
||||||
|
const std::vector<array>& inds,
|
||||||
|
array& out,
|
||||||
|
const std::vector<int>& axes,
|
||||||
|
const std::vector<int>& slice_sizes) {
|
||||||
|
// If the array is row contiguous then we can do a contiguous copy given
|
||||||
|
// two conditions on the slice size:
|
||||||
|
// - Any number of leading ones in the slice sizes are allowed
|
||||||
|
// - All other slice sizes match the corresponding dimension except the
|
||||||
|
// first non-singleton slice size
|
||||||
|
// If the array is col contiguous then the reverse is the case:
|
||||||
|
// - Any number of trailing ones in the slice sizes are allowed
|
||||||
|
// - All other slice sizes match the corresponding dimension except the
|
||||||
|
// first non-singleton slice size from the end
|
||||||
|
|
||||||
|
bool can_copy = false;
|
||||||
|
if (src.flags().row_contiguous) {
|
||||||
|
can_copy = true;
|
||||||
|
|
||||||
|
// Ignore leading 1s
|
||||||
|
int i = 0;
|
||||||
|
for (; i < slice_sizes.size() && slice_sizes[i] == 1; ++i)
|
||||||
|
;
|
||||||
|
|
||||||
|
// Check the remaining
|
||||||
|
i++;
|
||||||
|
for (; i < src.ndim() && can_copy; ++i) {
|
||||||
|
can_copy = (src.shape(i) == slice_sizes[i]);
|
||||||
|
}
|
||||||
|
} else if (src.flags().col_contiguous) {
|
||||||
|
can_copy = true;
|
||||||
|
|
||||||
|
// Ignore trailing 1s
|
||||||
|
int i = slice_sizes.size() - 1;
|
||||||
|
for (; i >= 0 && slice_sizes[i] == 1; --i)
|
||||||
|
;
|
||||||
|
|
||||||
|
// Skip the next slice size and check the remaining
|
||||||
|
i--;
|
||||||
|
for (; i >= 0 && can_copy; --i) {
|
||||||
|
can_copy = (src.shape(i) == slice_sizes[i]);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
size_t slice_size = 1;
|
||||||
|
for (auto s : slice_sizes) {
|
||||||
|
slice_size *= s;
|
||||||
|
}
|
||||||
|
size_t ind_size = slice_size == 0 ? 0 : out.size() / slice_size;
|
||||||
|
const T* src_ptr = src.data<T>();
|
||||||
|
T* dst_ptr = out.data<T>();
|
||||||
|
size_t out_idx = 0;
|
||||||
|
|
||||||
|
for (int idx = 0; idx < ind_size; idx++) {
|
||||||
|
size_t src_idx = 0;
|
||||||
|
for (int ii = 0; ii < inds.size(); ++ii) {
|
||||||
|
auto ax = axes[ii];
|
||||||
|
auto idx_loc = elem_to_loc(idx, inds[ii]);
|
||||||
|
auto idx_val =
|
||||||
|
offset_neg_idx(inds[ii].data<IdxT>()[idx_loc], src.shape(ax));
|
||||||
|
src_idx += (idx_val * src.strides()[ax]);
|
||||||
|
}
|
||||||
|
|
||||||
|
if (slice_size == 1) {
|
||||||
|
dst_ptr[out_idx++] = src_ptr[src_idx];
|
||||||
|
} else if (can_copy) {
|
||||||
|
std::copy(
|
||||||
|
src_ptr + src_idx, src_ptr + src_idx + slice_size, dst_ptr + out_idx);
|
||||||
|
out_idx += slice_size;
|
||||||
|
} else {
|
||||||
|
for (int jj = 0; jj < slice_size; jj++) {
|
||||||
|
auto src_offset = elem_to_loc(jj, slice_sizes, src.strides());
|
||||||
|
dst_ptr[out_idx++] = src_ptr[src_idx + src_offset];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename IdxT>
|
||||||
|
void dispatch_gather(
|
||||||
|
const array& src,
|
||||||
|
const std::vector<array>& inds,
|
||||||
|
array& out,
|
||||||
|
const std::vector<int>& axes,
|
||||||
|
const std::vector<int>& size) {
|
||||||
|
switch (out.dtype()) {
|
||||||
|
case bool_:
|
||||||
|
gather<bool, IdxT>(src, inds, out, axes, size);
|
||||||
|
break;
|
||||||
|
case uint8:
|
||||||
|
gather<uint8_t, IdxT>(src, inds, out, axes, size);
|
||||||
|
break;
|
||||||
|
case uint16:
|
||||||
|
gather<uint16_t, IdxT>(src, inds, out, axes, size);
|
||||||
|
break;
|
||||||
|
case uint32:
|
||||||
|
gather<uint32_t, IdxT>(src, inds, out, axes, size);
|
||||||
|
break;
|
||||||
|
case uint64:
|
||||||
|
gather<uint64_t, IdxT>(src, inds, out, axes, size);
|
||||||
|
break;
|
||||||
|
case int8:
|
||||||
|
gather<int8_t, IdxT>(src, inds, out, axes, size);
|
||||||
|
break;
|
||||||
|
case int16:
|
||||||
|
gather<int16_t, IdxT>(src, inds, out, axes, size);
|
||||||
|
break;
|
||||||
|
case int32:
|
||||||
|
gather<int32_t, IdxT>(src, inds, out, axes, size);
|
||||||
|
break;
|
||||||
|
case int64:
|
||||||
|
gather<int64_t, IdxT>(src, inds, out, axes, size);
|
||||||
|
break;
|
||||||
|
case float16:
|
||||||
|
gather<float16_t, IdxT>(src, inds, out, axes, size);
|
||||||
|
break;
|
||||||
|
case float32:
|
||||||
|
gather<float, IdxT>(src, inds, out, axes, size);
|
||||||
|
break;
|
||||||
|
case bfloat16:
|
||||||
|
gather<bfloat16_t, IdxT>(src, inds, out, axes, size);
|
||||||
|
break;
|
||||||
|
case complex64:
|
||||||
|
gather<complex64_t, IdxT>(src, inds, out, axes, size);
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void Gather::eval(const std::vector<array>& inputs, array& out) {
|
||||||
|
out.set_data(allocator::malloc_or_wait(out.nbytes()));
|
||||||
|
|
||||||
|
auto& src = inputs[0];
|
||||||
|
std::vector<array> inds(inputs.begin() + 1, inputs.end());
|
||||||
|
|
||||||
|
if (inds.empty()) {
|
||||||
|
dispatch_gather<bool>(src, inds, out, axes_, slice_sizes_);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
switch (inds[0].dtype()) {
|
||||||
|
case bool_:
|
||||||
|
dispatch_gather<bool>(src, inds, out, axes_, slice_sizes_);
|
||||||
|
break;
|
||||||
|
case uint8:
|
||||||
|
dispatch_gather<uint8_t>(src, inds, out, axes_, slice_sizes_);
|
||||||
|
break;
|
||||||
|
case uint16:
|
||||||
|
dispatch_gather<uint16_t>(src, inds, out, axes_, slice_sizes_);
|
||||||
|
break;
|
||||||
|
case uint32:
|
||||||
|
dispatch_gather<uint32_t>(src, inds, out, axes_, slice_sizes_);
|
||||||
|
break;
|
||||||
|
case uint64:
|
||||||
|
dispatch_gather<uint64_t>(src, inds, out, axes_, slice_sizes_);
|
||||||
|
break;
|
||||||
|
case int8:
|
||||||
|
dispatch_gather<int8_t>(src, inds, out, axes_, slice_sizes_);
|
||||||
|
break;
|
||||||
|
case int16:
|
||||||
|
dispatch_gather<int16_t>(src, inds, out, axes_, slice_sizes_);
|
||||||
|
break;
|
||||||
|
case int32:
|
||||||
|
dispatch_gather<int32_t>(src, inds, out, axes_, slice_sizes_);
|
||||||
|
break;
|
||||||
|
case int64:
|
||||||
|
dispatch_gather<int64_t>(src, inds, out, axes_, slice_sizes_);
|
||||||
|
break;
|
||||||
|
case float16:
|
||||||
|
case float32:
|
||||||
|
case bfloat16:
|
||||||
|
case complex64:
|
||||||
|
throw std::runtime_error(
|
||||||
|
"[Gather::eval] Cannot gather with floating point indices.");
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename InT, typename IdxT, typename OpT>
|
||||||
|
void scatter(
|
||||||
|
const array& updates,
|
||||||
|
array& out,
|
||||||
|
const std::vector<array>& inds,
|
||||||
|
const std::vector<int>& axes,
|
||||||
|
const OpT& op) {
|
||||||
|
int nind = inds.size();
|
||||||
|
auto inds_ndim = updates.ndim() - out.ndim();
|
||||||
|
size_t n_updates = nind ? inds[0].size() : 1;
|
||||||
|
|
||||||
|
std::vector<int> update_shape(
|
||||||
|
updates.shape().begin() + inds_ndim, updates.shape().end());
|
||||||
|
size_t update_size = 1;
|
||||||
|
for (auto us : update_shape) {
|
||||||
|
update_size *= us;
|
||||||
|
}
|
||||||
|
|
||||||
|
for (int i = 0; i < n_updates; ++i) {
|
||||||
|
size_t out_offset = 0;
|
||||||
|
for (int j = 0; j < nind; ++j) {
|
||||||
|
auto ax = axes[j];
|
||||||
|
auto idx_loc = elem_to_loc(i, inds[j]);
|
||||||
|
auto idx_val =
|
||||||
|
offset_neg_idx(inds[j].data<IdxT>()[idx_loc], out.shape(ax));
|
||||||
|
out_offset += (idx_val * out.strides()[ax]);
|
||||||
|
}
|
||||||
|
for (int j = 0; j < update_size; ++j) {
|
||||||
|
auto update_loc = elem_to_loc(i * update_size + j, updates);
|
||||||
|
auto out_loc = elem_to_loc(j, update_shape, out.strides());
|
||||||
|
op(updates.data<InT>()[update_loc],
|
||||||
|
out.data<InT>() + out_offset + out_loc);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename InT, typename IdxT>
|
||||||
|
void dispatch_scatter_inds(
|
||||||
|
array& out,
|
||||||
|
const std::vector<array>& indices,
|
||||||
|
const array& updates,
|
||||||
|
const std::vector<int>& axes,
|
||||||
|
Scatter::ReduceType rtype) {
|
||||||
|
switch (rtype) {
|
||||||
|
case Scatter::None:
|
||||||
|
scatter<InT, IdxT>(
|
||||||
|
updates, out, indices, axes, [](auto x, auto* y) { (*y) = x; });
|
||||||
|
break;
|
||||||
|
case Scatter::Sum:
|
||||||
|
scatter<InT, IdxT>(
|
||||||
|
updates, out, indices, axes, [](auto x, auto* y) { (*y) += x; });
|
||||||
|
break;
|
||||||
|
case Scatter::Prod:
|
||||||
|
scatter<InT, IdxT>(
|
||||||
|
updates, out, indices, axes, [](auto x, auto* y) { (*y) *= x; });
|
||||||
|
break;
|
||||||
|
case Scatter::Max:
|
||||||
|
scatter<InT, IdxT>(updates, out, indices, axes, [](auto x, auto* y) {
|
||||||
|
(*y) = (*y > x) ? *y : x;
|
||||||
|
});
|
||||||
|
break;
|
||||||
|
case Scatter::Min:
|
||||||
|
scatter<InT, IdxT>(updates, out, indices, axes, [](auto x, auto* y) {
|
||||||
|
(*y) = (*y < x) ? *y : x;
|
||||||
|
});
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename InT>
|
||||||
|
void dispatch_scatter(
|
||||||
|
array& out,
|
||||||
|
const std::vector<array>& inds,
|
||||||
|
const array& updates,
|
||||||
|
const std::vector<int>& axes,
|
||||||
|
Scatter::ReduceType rtype) {
|
||||||
|
if (inds.empty()) {
|
||||||
|
dispatch_scatter_inds<InT, bool>(out, inds, updates, axes, rtype);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
switch (inds[0].dtype()) {
|
||||||
|
case bool_:
|
||||||
|
dispatch_scatter_inds<InT, bool>(out, inds, updates, axes, rtype);
|
||||||
|
break;
|
||||||
|
case uint8:
|
||||||
|
dispatch_scatter_inds<InT, uint8_t>(out, inds, updates, axes, rtype);
|
||||||
|
break;
|
||||||
|
case uint16:
|
||||||
|
dispatch_scatter_inds<InT, uint16_t>(out, inds, updates, axes, rtype);
|
||||||
|
break;
|
||||||
|
case uint32:
|
||||||
|
dispatch_scatter_inds<InT, uint32_t>(out, inds, updates, axes, rtype);
|
||||||
|
break;
|
||||||
|
case uint64:
|
||||||
|
dispatch_scatter_inds<InT, uint64_t>(out, inds, updates, axes, rtype);
|
||||||
|
break;
|
||||||
|
case int8:
|
||||||
|
dispatch_scatter_inds<InT, int8_t>(out, inds, updates, axes, rtype);
|
||||||
|
break;
|
||||||
|
case int16:
|
||||||
|
dispatch_scatter_inds<InT, int16_t>(out, inds, updates, axes, rtype);
|
||||||
|
break;
|
||||||
|
case int32:
|
||||||
|
dispatch_scatter_inds<InT, int32_t>(out, inds, updates, axes, rtype);
|
||||||
|
break;
|
||||||
|
case int64:
|
||||||
|
dispatch_scatter_inds<InT, int64_t>(out, inds, updates, axes, rtype);
|
||||||
|
break;
|
||||||
|
case float16:
|
||||||
|
case float32:
|
||||||
|
case bfloat16:
|
||||||
|
case complex64:
|
||||||
|
throw std::runtime_error(
|
||||||
|
"[Scatter::eval_cpu] Cannot scatter with floating point indices.");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void Scatter::eval(const std::vector<array>& inputs, array& out) {
|
||||||
|
assert(inputs.size() >= 2);
|
||||||
|
|
||||||
|
auto& src = inputs[0];
|
||||||
|
std::vector<array> inds(inputs.begin() + 1, inputs.end() - 1);
|
||||||
|
auto& updates = inputs.back();
|
||||||
|
|
||||||
|
// Copy src into out (copy allocates memory for out)
|
||||||
|
copy(src, out, CopyType::General);
|
||||||
|
|
||||||
|
switch (src.dtype()) {
|
||||||
|
case bool_:
|
||||||
|
dispatch_scatter<bool>(out, inds, updates, axes_, reduce_type_);
|
||||||
|
break;
|
||||||
|
case uint8:
|
||||||
|
dispatch_scatter<uint8_t>(out, inds, updates, axes_, reduce_type_);
|
||||||
|
break;
|
||||||
|
case uint16:
|
||||||
|
dispatch_scatter<uint16_t>(out, inds, updates, axes_, reduce_type_);
|
||||||
|
break;
|
||||||
|
case uint32:
|
||||||
|
dispatch_scatter<uint32_t>(out, inds, updates, axes_, reduce_type_);
|
||||||
|
break;
|
||||||
|
case uint64:
|
||||||
|
dispatch_scatter<uint64_t>(out, inds, updates, axes_, reduce_type_);
|
||||||
|
break;
|
||||||
|
case int8:
|
||||||
|
dispatch_scatter<int8_t>(out, inds, updates, axes_, reduce_type_);
|
||||||
|
break;
|
||||||
|
case int16:
|
||||||
|
dispatch_scatter<int16_t>(out, inds, updates, axes_, reduce_type_);
|
||||||
|
break;
|
||||||
|
case int32:
|
||||||
|
dispatch_scatter<int32_t>(out, inds, updates, axes_, reduce_type_);
|
||||||
|
break;
|
||||||
|
case int64:
|
||||||
|
dispatch_scatter<int64_t>(out, inds, updates, axes_, reduce_type_);
|
||||||
|
break;
|
||||||
|
case float16:
|
||||||
|
dispatch_scatter<float16_t>(out, inds, updates, axes_, reduce_type_);
|
||||||
|
break;
|
||||||
|
case float32:
|
||||||
|
dispatch_scatter<float>(out, inds, updates, axes_, reduce_type_);
|
||||||
|
break;
|
||||||
|
case bfloat16:
|
||||||
|
dispatch_scatter<bfloat16_t>(out, inds, updates, axes_, reduce_type_);
|
||||||
|
break;
|
||||||
|
case complex64:
|
||||||
|
dispatch_scatter<complex64_t>(out, inds, updates, axes_, reduce_type_);
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace mlx::core
|
||||||
52
mlx/backend/common/load.cpp
Normal file
52
mlx/backend/common/load.cpp
Normal file
@@ -0,0 +1,52 @@
|
|||||||
|
#include <algorithm>
|
||||||
|
#include <cassert>
|
||||||
|
#include <utility>
|
||||||
|
|
||||||
|
#include "mlx/allocator.h"
|
||||||
|
#include "mlx/load.h"
|
||||||
|
#include "mlx/primitives.h"
|
||||||
|
|
||||||
|
namespace mlx::core {
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
|
||||||
|
template <const uint8_t scalar_size>
|
||||||
|
void swap_endianess(uint8_t* data_bytes, size_t N) {
|
||||||
|
struct Elem {
|
||||||
|
uint8_t bytes[scalar_size];
|
||||||
|
};
|
||||||
|
|
||||||
|
Elem* data = reinterpret_cast<Elem*>(data_bytes);
|
||||||
|
|
||||||
|
for (size_t i = 0; i < N; i++) {
|
||||||
|
for (size_t j = 0; j < (scalar_size / 2); j++) {
|
||||||
|
std::swap(data[i].bytes[j], data[i].bytes[scalar_size - j - 1]);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace
|
||||||
|
|
||||||
|
void Load::eval(const std::vector<array>& inputs, array& out) {
|
||||||
|
assert(inputs.size() == 0);
|
||||||
|
out.set_data(allocator::malloc_or_wait(out.nbytes()));
|
||||||
|
|
||||||
|
reader_->seek(offset_, std::ios_base::beg);
|
||||||
|
reader_->read(out.data<char>(), out.nbytes());
|
||||||
|
|
||||||
|
if (swap_endianness_) {
|
||||||
|
switch (out.itemsize()) {
|
||||||
|
case 2:
|
||||||
|
swap_endianess<2>(out.data<uint8_t>(), out.data_size());
|
||||||
|
break;
|
||||||
|
case 4:
|
||||||
|
swap_endianess<4>(out.data<uint8_t>(), out.data_size());
|
||||||
|
break;
|
||||||
|
case 8:
|
||||||
|
swap_endianess<8>(out.data<uint8_t>(), out.data_size());
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace mlx::core
|
||||||
622
mlx/backend/common/primitives.cpp
Normal file
622
mlx/backend/common/primitives.cpp
Normal file
@@ -0,0 +1,622 @@
|
|||||||
|
#include <algorithm>
|
||||||
|
#include <cassert>
|
||||||
|
#include <cmath>
|
||||||
|
#include <numeric>
|
||||||
|
#include <sstream>
|
||||||
|
|
||||||
|
#include "mlx/allocator.h"
|
||||||
|
#include "mlx/backend/common/arange.h"
|
||||||
|
#include "mlx/backend/common/copy.h"
|
||||||
|
#include "mlx/backend/common/erf.h"
|
||||||
|
#include "mlx/backend/common/threefry.h"
|
||||||
|
#include "mlx/backend/common/unary.h"
|
||||||
|
#include "mlx/backend/common/utils.h"
|
||||||
|
#include "mlx/primitives.h"
|
||||||
|
#include "mlx/utils.h"
|
||||||
|
|
||||||
|
namespace mlx::core {
|
||||||
|
|
||||||
|
void Abs::eval(const std::vector<array>& inputs, array& out) {
|
||||||
|
assert(inputs.size() == 1);
|
||||||
|
auto& in = inputs[0];
|
||||||
|
if (is_unsigned(in.dtype())) {
|
||||||
|
// No-op for unsigned types
|
||||||
|
out.copy_shared_buffer(in);
|
||||||
|
} else {
|
||||||
|
unary(in, out, AbsOp());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void Arange::eval(const std::vector<array>& inputs, array& out) {
|
||||||
|
arange(inputs, out, start_, step_);
|
||||||
|
}
|
||||||
|
|
||||||
|
void ArcCos::eval(const std::vector<array>& inputs, array& out) {
|
||||||
|
assert(inputs.size() == 1);
|
||||||
|
const auto& in = inputs[0];
|
||||||
|
if (is_floating_point(out.dtype())) {
|
||||||
|
unary_fp(in, out, [](auto x) { return std::acos(x); });
|
||||||
|
} else {
|
||||||
|
throw std::invalid_argument(
|
||||||
|
"[arccos] Cannot compute inverse cosine of elements in array"
|
||||||
|
" with non floating point type.");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void ArcCosh::eval(const std::vector<array>& inputs, array& out) {
|
||||||
|
assert(inputs.size() == 1);
|
||||||
|
const auto& in = inputs[0];
|
||||||
|
if (is_floating_point(out.dtype())) {
|
||||||
|
unary_fp(in, out, [](auto x) { return std::acosh(x); });
|
||||||
|
} else {
|
||||||
|
throw std::invalid_argument(
|
||||||
|
"[arccosh] Cannot compute inverse hyperbolic cosine of elements in"
|
||||||
|
" array with non floating point type.");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void ArcSin::eval(const std::vector<array>& inputs, array& out) {
|
||||||
|
assert(inputs.size() == 1);
|
||||||
|
const auto& in = inputs[0];
|
||||||
|
if (is_floating_point(out.dtype())) {
|
||||||
|
unary_fp(in, out, [](auto x) { return std::asin(x); });
|
||||||
|
} else {
|
||||||
|
throw std::invalid_argument(
|
||||||
|
"[arcsin] Cannot compute inverse sine of elements in array"
|
||||||
|
" with non floating point type.");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void ArcSinh::eval(const std::vector<array>& inputs, array& out) {
|
||||||
|
assert(inputs.size() == 1);
|
||||||
|
const auto& in = inputs[0];
|
||||||
|
if (is_floating_point(out.dtype())) {
|
||||||
|
unary_fp(in, out, [](auto x) { return std::asinh(x); });
|
||||||
|
} else {
|
||||||
|
throw std::invalid_argument(
|
||||||
|
"[arcsinh] Cannot compute inverse hyperbolic sine of elements in"
|
||||||
|
" array with non floating point type.");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void ArcTan::eval(const std::vector<array>& inputs, array& out) {
|
||||||
|
assert(inputs.size() == 1);
|
||||||
|
const auto& in = inputs[0];
|
||||||
|
if (is_floating_point(out.dtype())) {
|
||||||
|
unary_fp(in, out, [](auto x) { return std::atan(x); });
|
||||||
|
} else {
|
||||||
|
throw std::invalid_argument(
|
||||||
|
"[arctan] Cannot compute inverse tangent of elements in array"
|
||||||
|
" with non floating point type.");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void ArcTanh::eval(const std::vector<array>& inputs, array& out) {
|
||||||
|
assert(inputs.size() == 1);
|
||||||
|
const auto& in = inputs[0];
|
||||||
|
if (is_floating_point(out.dtype())) {
|
||||||
|
unary_fp(in, out, [](auto x) { return std::atanh(x); });
|
||||||
|
} else {
|
||||||
|
throw std::invalid_argument(
|
||||||
|
"[arctanh] Cannot compute inverse hyperbolic tangent of elements in"
|
||||||
|
" array with non floating point type.");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void AsType::eval(const std::vector<array>& inputs, array& out) {
|
||||||
|
assert(inputs.size() == 1);
|
||||||
|
auto& in = inputs[0];
|
||||||
|
CopyType ctype = in.flags().contiguous ? CopyType::Vector : CopyType::General;
|
||||||
|
copy(in, out, ctype);
|
||||||
|
}
|
||||||
|
|
||||||
|
void AsStrided::eval(const std::vector<array>& inputs, array& out) {
|
||||||
|
assert(inputs.size() == 1);
|
||||||
|
|
||||||
|
auto& in = inputs[0];
|
||||||
|
|
||||||
|
if (!in.flags().row_contiguous) {
|
||||||
|
// Just ensuring that inputs[0] came from the ops which would ensure the
|
||||||
|
// input is row contiguous.
|
||||||
|
throw std::runtime_error(
|
||||||
|
"AsStrided must be used with row contiguous arrays only.");
|
||||||
|
}
|
||||||
|
|
||||||
|
// Compute the flags given the shape and strides
|
||||||
|
bool row_contiguous = true, col_contiguous = true;
|
||||||
|
size_t r = 1, c = 1;
|
||||||
|
for (int i = strides_.size() - 1, j = 0; i >= 0; i--, j++) {
|
||||||
|
row_contiguous &= (r == strides_[i]) || (shape_[i] == 1);
|
||||||
|
col_contiguous &= (c == strides_[j]) || (shape_[j] == 1);
|
||||||
|
r *= shape_[i];
|
||||||
|
c *= shape_[j];
|
||||||
|
}
|
||||||
|
auto flags = in.flags();
|
||||||
|
// TODO: Compute the contiguous flag in a better way cause now we are
|
||||||
|
// unnecessarily strict.
|
||||||
|
flags.contiguous = row_contiguous || col_contiguous;
|
||||||
|
flags.row_contiguous = row_contiguous;
|
||||||
|
flags.col_contiguous = col_contiguous;
|
||||||
|
|
||||||
|
// There is no easy way to compute the actual data size so we use out.size().
|
||||||
|
// The contiguous flag will almost certainly not be set so no code should
|
||||||
|
// rely on data_size anyway.
|
||||||
|
size_t data_size = out.size();
|
||||||
|
|
||||||
|
return out.copy_shared_buffer(in, strides_, flags, data_size, offset_);
|
||||||
|
}
|
||||||
|
|
||||||
|
void Broadcast::eval(const std::vector<array>& inputs, array& out) {
|
||||||
|
assert(inputs.size() == 1);
|
||||||
|
const auto& in = inputs[0];
|
||||||
|
if (out.size() == 0) {
|
||||||
|
out.set_data(nullptr);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
std::vector<size_t> strides(out.ndim(), 0);
|
||||||
|
int diff = out.ndim() - in.ndim();
|
||||||
|
for (int i = in.ndim() - 1; i >= 0; --i) {
|
||||||
|
strides[i + diff] = (in.shape()[i] == 1) ? 0 : in.strides()[i];
|
||||||
|
}
|
||||||
|
auto flags = in.flags();
|
||||||
|
if (out.size() > in.size()) {
|
||||||
|
flags.row_contiguous = flags.col_contiguous = false;
|
||||||
|
}
|
||||||
|
out.copy_shared_buffer(in, strides, flags, in.data_size());
|
||||||
|
}
|
||||||
|
|
||||||
|
void Concatenate::eval(const std::vector<array>& inputs, array& out) {
|
||||||
|
std::vector<int> sizes;
|
||||||
|
sizes.push_back(0);
|
||||||
|
for (auto& p : inputs) {
|
||||||
|
sizes.push_back(p.shape(axis_));
|
||||||
|
}
|
||||||
|
std::partial_sum(sizes.cbegin(), sizes.cend(), sizes.begin());
|
||||||
|
|
||||||
|
out.set_data(allocator::malloc_or_wait(out.nbytes()));
|
||||||
|
|
||||||
|
auto strides = out.strides();
|
||||||
|
auto flags = out.flags();
|
||||||
|
flags.row_contiguous = false;
|
||||||
|
flags.col_contiguous = false;
|
||||||
|
flags.contiguous = false;
|
||||||
|
for (int i = 0; i < inputs.size(); i++) {
|
||||||
|
array out_slice(inputs[i].shape(), out.dtype(), nullptr, {});
|
||||||
|
size_t data_offset = strides[axis_] * sizes[i];
|
||||||
|
out_slice.copy_shared_buffer(
|
||||||
|
out, strides, flags, out_slice.size(), data_offset);
|
||||||
|
copy_inplace(inputs[i], out_slice, CopyType::GeneralGeneral);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void Copy::eval(const std::vector<array>& inputs, array& out) {
|
||||||
|
assert(inputs.size() == 1);
|
||||||
|
out.copy_shared_buffer(inputs[0]);
|
||||||
|
}
|
||||||
|
|
||||||
|
void Cos::eval(const std::vector<array>& inputs, array& out) {
|
||||||
|
assert(inputs.size() == 1);
|
||||||
|
const auto& in = inputs[0];
|
||||||
|
if (is_floating_point(out.dtype())) {
|
||||||
|
unary_fp(in, out, [](auto x) { return std::cos(x); });
|
||||||
|
} else {
|
||||||
|
throw std::invalid_argument(
|
||||||
|
"[cos] Cannot compute cosine of elements in array"
|
||||||
|
" with non floating point type.");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void Cosh::eval(const std::vector<array>& inputs, array& out) {
|
||||||
|
assert(inputs.size() == 1);
|
||||||
|
const auto& in = inputs[0];
|
||||||
|
if (is_floating_point(out.dtype())) {
|
||||||
|
unary_fp(in, out, [](auto x) { return std::cosh(x); });
|
||||||
|
} else {
|
||||||
|
throw std::invalid_argument(
|
||||||
|
"[cosh] Cannot compute hyperbolic cosine of elements in array"
|
||||||
|
" with non floating point type.");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void Erf::eval(const std::vector<array>& inputs, array& out) {
|
||||||
|
assert(inputs.size() == 1);
|
||||||
|
const auto& in = inputs[0];
|
||||||
|
switch (out.dtype()) {
|
||||||
|
case float32:
|
||||||
|
out.set_data(allocator::malloc_or_wait(out.nbytes()));
|
||||||
|
unary_op<float>(in, out, [](auto x) { return std::erf(x); });
|
||||||
|
break;
|
||||||
|
case float16:
|
||||||
|
out.set_data(allocator::malloc_or_wait(out.nbytes()));
|
||||||
|
unary_op<float16_t>(in, out, [](auto x) {
|
||||||
|
return static_cast<float16_t>(std::erf(static_cast<float>(x)));
|
||||||
|
});
|
||||||
|
break;
|
||||||
|
case bfloat16:
|
||||||
|
out.set_data(allocator::malloc_or_wait(out.nbytes()));
|
||||||
|
unary_op<bfloat16_t>(in, out, [](auto x) {
|
||||||
|
return static_cast<bfloat16_t>(std::erf(static_cast<float>(x)));
|
||||||
|
});
|
||||||
|
break;
|
||||||
|
default:
|
||||||
|
throw std::invalid_argument(
|
||||||
|
"[erf] Error function only defined for arrays"
|
||||||
|
" with real floating point type.");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void ErfInv::eval(const std::vector<array>& inputs, array& out) {
|
||||||
|
assert(inputs.size() == 1);
|
||||||
|
const auto& in = inputs[0];
|
||||||
|
switch (out.dtype()) {
|
||||||
|
case float32:
|
||||||
|
out.set_data(allocator::malloc_or_wait(out.nbytes()));
|
||||||
|
unary_op<float>(in, out, [](auto x) { return erfinv(x); });
|
||||||
|
break;
|
||||||
|
case float16:
|
||||||
|
out.set_data(allocator::malloc_or_wait(out.nbytes()));
|
||||||
|
unary_op<float16_t>(in, out, [](auto x) {
|
||||||
|
return static_cast<float16_t>(erfinv(static_cast<float>(x)));
|
||||||
|
});
|
||||||
|
break;
|
||||||
|
case bfloat16:
|
||||||
|
out.set_data(allocator::malloc_or_wait(out.nbytes()));
|
||||||
|
unary_op<bfloat16_t>(in, out, [](auto x) {
|
||||||
|
return static_cast<bfloat16_t>(erfinv(static_cast<float>(x)));
|
||||||
|
});
|
||||||
|
break;
|
||||||
|
default:
|
||||||
|
throw std::invalid_argument(
|
||||||
|
"[erf_inv] Inverse error function only defined for arrays"
|
||||||
|
" with real floating point type.");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void Exp::eval(const std::vector<array>& inputs, array& out) {
|
||||||
|
assert(inputs.size() == 1);
|
||||||
|
const auto& in = inputs[0];
|
||||||
|
|
||||||
|
if (is_floating_point(out.dtype())) {
|
||||||
|
unary_fp(in, out, [](auto x) { return std::exp(x); });
|
||||||
|
} else {
|
||||||
|
throw std::invalid_argument(
|
||||||
|
"[exp] Cannot exponentiate elements in array"
|
||||||
|
" with non floating point type.");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void Full::eval(const std::vector<array>& inputs, array& out) {
|
||||||
|
assert(inputs.size() == 1);
|
||||||
|
auto& in = inputs[0];
|
||||||
|
assert(in.dtype() == out.dtype());
|
||||||
|
CopyType ctype;
|
||||||
|
if (in.data_size() == 1) {
|
||||||
|
ctype = CopyType::Scalar;
|
||||||
|
} else if (in.flags().contiguous) {
|
||||||
|
ctype = CopyType::Vector;
|
||||||
|
} else {
|
||||||
|
ctype = CopyType::General;
|
||||||
|
}
|
||||||
|
copy(in, out, ctype);
|
||||||
|
}
|
||||||
|
|
||||||
|
void Log::eval(const std::vector<array>& inputs, array& out) {
|
||||||
|
assert(inputs.size() == 1);
|
||||||
|
const auto& in = inputs[0];
|
||||||
|
if (is_floating_point(out.dtype())) {
|
||||||
|
switch (base_) {
|
||||||
|
case Base::e:
|
||||||
|
unary_fp(in, out, [](auto x) { return std::log(x); });
|
||||||
|
break;
|
||||||
|
case Base::two:
|
||||||
|
unary_fp(in, out, [](auto x) { return std::log2(x); });
|
||||||
|
break;
|
||||||
|
case Base::ten:
|
||||||
|
unary_fp(in, out, [](auto x) { return std::log10(x); });
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
throw std::invalid_argument(
|
||||||
|
"[log] Cannot compute log of elements in array with"
|
||||||
|
" non floating point type.");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void Log1p::eval(const std::vector<array>& inputs, array& out) {
|
||||||
|
assert(inputs.size() == 1);
|
||||||
|
const auto& in = inputs[0];
|
||||||
|
if (is_floating_point(out.dtype())) {
|
||||||
|
unary_fp(in, out, [](auto x) { return std::log1p(x); });
|
||||||
|
} else {
|
||||||
|
throw std::invalid_argument(
|
||||||
|
"[log1p] Cannot compute log of elements in array with"
|
||||||
|
" non floating point type.");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void LogicalNot::eval(const std::vector<array>& inputs, array& out) {
|
||||||
|
assert(inputs.size() == 1);
|
||||||
|
auto& in = inputs[0];
|
||||||
|
unary(in, out, [](auto x) { return !x; });
|
||||||
|
}
|
||||||
|
|
||||||
|
void Negative::eval(const std::vector<array>& inputs, array& out) {
|
||||||
|
assert(inputs.size() == 1);
|
||||||
|
auto& in = inputs[0];
|
||||||
|
unary(in, out, [](auto x) { return -x; });
|
||||||
|
}
|
||||||
|
|
||||||
|
void Pad::eval(const std::vector<array>& inputs, array& out) {
|
||||||
|
// Inputs must be base input array and scalar val array
|
||||||
|
assert(inputs.size() == 2);
|
||||||
|
auto& in = inputs[0];
|
||||||
|
auto& val = inputs[1];
|
||||||
|
|
||||||
|
// Padding value must be a scalar
|
||||||
|
assert(val.size() == 1);
|
||||||
|
|
||||||
|
// Padding value, input and output must be of the same type
|
||||||
|
assert(val.dtype() == in.dtype() && in.dtype() == out.dtype());
|
||||||
|
|
||||||
|
// Fill output with val
|
||||||
|
copy(val, out, CopyType::Scalar);
|
||||||
|
|
||||||
|
// Find offset for start of input values
|
||||||
|
size_t data_offset = 0;
|
||||||
|
for (int i = 0; i < axes_.size(); i++) {
|
||||||
|
auto ax = axes_[i] < 0 ? out.ndim() + axes_[i] : axes_[i];
|
||||||
|
data_offset += out.strides()[ax] * low_pad_size_[i];
|
||||||
|
}
|
||||||
|
|
||||||
|
// Extract slice from output where input will be pasted
|
||||||
|
array out_slice(in.shape(), out.dtype(), nullptr, {});
|
||||||
|
out_slice.copy_shared_buffer(
|
||||||
|
out, out.strides(), out.flags(), out_slice.size(), data_offset);
|
||||||
|
|
||||||
|
// Copy input values into the slice
|
||||||
|
copy_inplace(in, out_slice, CopyType::GeneralGeneral);
|
||||||
|
}
|
||||||
|
|
||||||
|
void RandomBits::eval(const std::vector<array>& inputs, array& out) {
|
||||||
|
assert(inputs.size() == 1);
|
||||||
|
// keys has shape (N1, ..., NK, 2)
|
||||||
|
// out has shape (N1, ..., NK, M1, M2, ...)
|
||||||
|
auto& keys = inputs[0];
|
||||||
|
size_t num_keys = keys.size() / 2;
|
||||||
|
|
||||||
|
size_t elems_per_key = out.size() / num_keys;
|
||||||
|
size_t bytes_per_key = out.itemsize() * elems_per_key;
|
||||||
|
out.set_data(allocator::malloc_or_wait(out.nbytes()));
|
||||||
|
|
||||||
|
auto kptr = inputs[0].data<uint32_t>();
|
||||||
|
auto cptr = out.data<char>();
|
||||||
|
size_t out_skip = (bytes_per_key + 4 - 1) / 4;
|
||||||
|
auto half_size = out_skip / 2;
|
||||||
|
bool even = out_skip % 2 == 0;
|
||||||
|
for (int i = 0; i < num_keys; ++i, cptr += bytes_per_key) {
|
||||||
|
auto ptr = reinterpret_cast<uint32_t*>(cptr);
|
||||||
|
// Get ith key
|
||||||
|
auto kidx = 2 * i;
|
||||||
|
auto k1_elem = elem_to_loc(kidx, keys.shape(), keys.strides());
|
||||||
|
auto k2_elem = elem_to_loc(kidx + 1, keys.shape(), keys.strides());
|
||||||
|
auto key = std::make_pair(kptr[k1_elem], kptr[k2_elem]);
|
||||||
|
|
||||||
|
std::pair<uintptr_t, uintptr_t> count{0, half_size + !even};
|
||||||
|
for (; count.first + 1 < half_size; count.first++, count.second++) {
|
||||||
|
std::tie(ptr[count.first], ptr[count.second]) =
|
||||||
|
random::threefry2x32_hash(key, count);
|
||||||
|
}
|
||||||
|
if (count.first < half_size) {
|
||||||
|
auto rb = random::threefry2x32_hash(key, count);
|
||||||
|
ptr[count.first++] = rb.first;
|
||||||
|
if (bytes_per_key % 4 > 0) {
|
||||||
|
std::copy(
|
||||||
|
reinterpret_cast<char*>(&rb.second),
|
||||||
|
reinterpret_cast<char*>(&rb.second) + bytes_per_key % 4,
|
||||||
|
cptr + 4 * count.second);
|
||||||
|
} else {
|
||||||
|
ptr[count.second] = rb.second;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if (!even) {
|
||||||
|
count.second = 0;
|
||||||
|
ptr[half_size] = random::threefry2x32_hash(key, count).first;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void Reshape::eval(const std::vector<array>& inputs, array& out) {
|
||||||
|
assert(inputs.size() == 1);
|
||||||
|
const auto& in = inputs[0];
|
||||||
|
if (in.flags().row_contiguous) {
|
||||||
|
// For row contiguous reshapes:
|
||||||
|
// - Shallow copy the buffer
|
||||||
|
// - If reshaping into a vector (all singleton dimensions except one) it
|
||||||
|
// becomes col contiguous again.
|
||||||
|
auto flags = in.flags();
|
||||||
|
auto max_dim = std::max_element(out.shape().begin(), out.shape().end());
|
||||||
|
flags.col_contiguous = out.size() <= 1 || out.size() == *max_dim;
|
||||||
|
out.copy_shared_buffer(in, out.strides(), flags, in.data_size());
|
||||||
|
} else {
|
||||||
|
copy(in, out, in.data_size() == 1 ? CopyType::Scalar : CopyType::General);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void Sigmoid::eval(const std::vector<array>& inputs, array& out) {
|
||||||
|
assert(inputs.size() == 1);
|
||||||
|
const auto& in = inputs[0];
|
||||||
|
if (is_floating_point(out.dtype())) {
|
||||||
|
auto sigmoid_op = [](auto x) {
|
||||||
|
auto one = static_cast<decltype(x)>(1.0);
|
||||||
|
return one / (one + std::exp(-x));
|
||||||
|
};
|
||||||
|
unary_fp(in, out, sigmoid_op);
|
||||||
|
} else {
|
||||||
|
throw std::invalid_argument(
|
||||||
|
"[sigmoid] Cannot sigmoid of elements in array with"
|
||||||
|
" non floating point type.");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void Sign::eval(const std::vector<array>& inputs, array& out) {
|
||||||
|
assert(inputs.size() == 1);
|
||||||
|
auto& in = inputs[0];
|
||||||
|
if (in.dtype() == bool_) {
|
||||||
|
out.copy_shared_buffer(in);
|
||||||
|
} else {
|
||||||
|
unary(in, out, SignOp());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void Sin::eval(const std::vector<array>& inputs, array& out) {
|
||||||
|
assert(inputs.size() == 1);
|
||||||
|
const auto& in = inputs[0];
|
||||||
|
if (is_floating_point(out.dtype())) {
|
||||||
|
unary_fp(in, out, [](auto x) { return std::sin(x); });
|
||||||
|
} else {
|
||||||
|
throw std::invalid_argument(
|
||||||
|
"[sin] Cannot compute sine of elements in array"
|
||||||
|
" with non floating point type.");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void Sinh::eval(const std::vector<array>& inputs, array& out) {
|
||||||
|
assert(inputs.size() == 1);
|
||||||
|
const auto& in = inputs[0];
|
||||||
|
if (is_floating_point(out.dtype())) {
|
||||||
|
unary_fp(in, out, [](auto x) { return std::sinh(x); });
|
||||||
|
} else {
|
||||||
|
throw std::invalid_argument(
|
||||||
|
"[sinh] Cannot compute hyperbolic sine of elements in array"
|
||||||
|
" with non floating point type.");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void Slice::eval(const std::vector<array>& inputs, array& out) {
|
||||||
|
assert(inputs.size() == 1);
|
||||||
|
if (out.size() == 0) {
|
||||||
|
out.set_data(nullptr);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
auto& in = inputs[0];
|
||||||
|
auto strides = in.strides();
|
||||||
|
auto flags = in.flags();
|
||||||
|
size_t data_offset = 0;
|
||||||
|
for (int i = 0; i < in.ndim(); ++i) {
|
||||||
|
data_offset += start_indices_[i] * in.strides()[i];
|
||||||
|
strides[i] *= strides_[i];
|
||||||
|
}
|
||||||
|
|
||||||
|
// Compute row/col contiguity
|
||||||
|
size_t data_size = 1;
|
||||||
|
size_t f_stride = 1;
|
||||||
|
size_t b_stride = 1;
|
||||||
|
flags.row_contiguous = true;
|
||||||
|
flags.col_contiguous = true;
|
||||||
|
for (int i = 0, ri = out.ndim() - 1; ri >= 0; i++, ri--) {
|
||||||
|
flags.col_contiguous &= strides[i] == f_stride || out.shape(i) == 1;
|
||||||
|
flags.row_contiguous &= strides[ri] == b_stride || out.shape(ri) == 1;
|
||||||
|
f_stride *= out.shape(i);
|
||||||
|
b_stride *= out.shape(ri);
|
||||||
|
if (strides[i] > 0) {
|
||||||
|
data_size *= out.shape(i);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if (data_size == 1) {
|
||||||
|
// Broadcasted scalar array is contiguous.
|
||||||
|
flags.contiguous = true;
|
||||||
|
} else if (data_size == in.data_size()) {
|
||||||
|
// Means we sliced a broadcasted dimension so leave the "no holes" flag
|
||||||
|
// alone.
|
||||||
|
} else {
|
||||||
|
// We sliced something. So either we are row or col contiguous or we
|
||||||
|
// punched a hole.
|
||||||
|
flags.contiguous &= flags.row_contiguous || flags.col_contiguous;
|
||||||
|
}
|
||||||
|
|
||||||
|
out.copy_shared_buffer(in, strides, flags, data_size, data_offset);
|
||||||
|
}
|
||||||
|
|
||||||
|
void Square::eval(const std::vector<array>& inputs, array& out) {
|
||||||
|
assert(inputs.size() == 1);
|
||||||
|
auto& in = inputs[0];
|
||||||
|
unary(in, out, [](auto x) { return x * x; });
|
||||||
|
}
|
||||||
|
|
||||||
|
void Sqrt::eval(const std::vector<array>& inputs, array& out) {
|
||||||
|
assert(inputs.size() == 1);
|
||||||
|
auto& in = inputs[0];
|
||||||
|
if (recip_) {
|
||||||
|
unary_fp(in, out, [](auto x) {
|
||||||
|
return static_cast<decltype(x)>(1.0) / sqrt(x);
|
||||||
|
});
|
||||||
|
} else {
|
||||||
|
unary_fp(in, out, [](auto x) { return sqrt(x); });
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void StopGradient::eval(const std::vector<array>& inputs, array& out) {
|
||||||
|
assert(inputs.size() == 1);
|
||||||
|
out.copy_shared_buffer(inputs[0]);
|
||||||
|
}
|
||||||
|
|
||||||
|
void Tan::eval(const std::vector<array>& inputs, array& out) {
|
||||||
|
assert(inputs.size() == 1);
|
||||||
|
const auto& in = inputs[0];
|
||||||
|
if (is_floating_point(out.dtype())) {
|
||||||
|
unary_fp(in, out, [](auto x) { return std::tan(x); });
|
||||||
|
} else {
|
||||||
|
throw std::invalid_argument(
|
||||||
|
"[tan] Cannot compute tangent of elements in array"
|
||||||
|
" with non floating point type.");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void Tanh::eval(const std::vector<array>& inputs, array& out) {
|
||||||
|
assert(inputs.size() == 1);
|
||||||
|
const auto& in = inputs[0];
|
||||||
|
if (is_floating_point(out.dtype())) {
|
||||||
|
unary_fp(in, out, [](auto x) { return std::tanh(x); });
|
||||||
|
} else {
|
||||||
|
throw std::invalid_argument(
|
||||||
|
"[tanh] Cannot compute hyperbolic tangent of elements in array"
|
||||||
|
" with non floating point type.");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void Transpose::eval(const std::vector<array>& inputs, array& out) {
|
||||||
|
assert(inputs.size() == 1);
|
||||||
|
std::vector<size_t> out_strides(out.ndim());
|
||||||
|
auto& in = inputs[0];
|
||||||
|
for (int ax = 0; ax < axes_.size(); ++ax) {
|
||||||
|
out_strides[ax] = in.strides()[axes_[ax]];
|
||||||
|
}
|
||||||
|
|
||||||
|
// Conditions for {row/col}_contiguous
|
||||||
|
// - array must be contiguous (no gaps)
|
||||||
|
// - underlying buffer size should have the same size as the array
|
||||||
|
// - cumulative product of shapes is equal to the strides (we can ignore axes
|
||||||
|
// with size == 1)
|
||||||
|
// - in the forward direction (column contiguous)
|
||||||
|
// - in the reverse direction (row contiguous)
|
||||||
|
// - vectors are both row and col contiguous (hence if both row/col are
|
||||||
|
// true, they stay true)
|
||||||
|
auto flags = in.flags();
|
||||||
|
if (flags.contiguous && in.data_size() == in.size()) {
|
||||||
|
size_t f_stride = 1;
|
||||||
|
size_t b_stride = 1;
|
||||||
|
flags.col_contiguous = true;
|
||||||
|
flags.row_contiguous = true;
|
||||||
|
for (int i = 0, ri = out.ndim() - 1; i < out.ndim(); ++i, --ri) {
|
||||||
|
flags.col_contiguous &= (out_strides[i] == f_stride || out.shape(i) == 1);
|
||||||
|
f_stride *= out.shape(i);
|
||||||
|
flags.row_contiguous &=
|
||||||
|
(out_strides[ri] == b_stride || out.shape(ri) == 1);
|
||||||
|
b_stride *= out.shape(ri);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
out.copy_shared_buffer(in, out_strides, flags, in.data_size());
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace mlx::core
|
||||||
215
mlx/backend/common/reduce.cpp
Normal file
215
mlx/backend/common/reduce.cpp
Normal file
@@ -0,0 +1,215 @@
|
|||||||
|
#include <cassert>
|
||||||
|
#include <functional>
|
||||||
|
#include <limits>
|
||||||
|
|
||||||
|
#include "mlx/backend/common/reduce.h"
|
||||||
|
#include "mlx/primitives.h"
|
||||||
|
|
||||||
|
namespace mlx::core {
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
|
||||||
|
template <typename U>
|
||||||
|
struct Limits {
|
||||||
|
static const U max;
|
||||||
|
static const U min;
|
||||||
|
};
|
||||||
|
|
||||||
|
#define instantiate_default_limit(type) \
|
||||||
|
template <> \
|
||||||
|
struct Limits<type> { \
|
||||||
|
static constexpr type max = std::numeric_limits<type>::max(); \
|
||||||
|
static constexpr type min = std::numeric_limits<type>::min(); \
|
||||||
|
};
|
||||||
|
|
||||||
|
instantiate_default_limit(uint8_t);
|
||||||
|
instantiate_default_limit(uint16_t);
|
||||||
|
instantiate_default_limit(uint32_t);
|
||||||
|
instantiate_default_limit(uint64_t);
|
||||||
|
instantiate_default_limit(int8_t);
|
||||||
|
instantiate_default_limit(int16_t);
|
||||||
|
instantiate_default_limit(int32_t);
|
||||||
|
instantiate_default_limit(int64_t);
|
||||||
|
|
||||||
|
#define instantiate_float_limit(type) \
|
||||||
|
template <> \
|
||||||
|
struct Limits<type> { \
|
||||||
|
static const type max; \
|
||||||
|
static const type min; \
|
||||||
|
};
|
||||||
|
|
||||||
|
instantiate_float_limit(float16_t);
|
||||||
|
instantiate_float_limit(bfloat16_t);
|
||||||
|
instantiate_float_limit(float);
|
||||||
|
instantiate_float_limit(complex64_t);
|
||||||
|
|
||||||
|
template <>
|
||||||
|
struct Limits<bool> {
|
||||||
|
static constexpr bool max = true;
|
||||||
|
static constexpr bool min = false;
|
||||||
|
};
|
||||||
|
|
||||||
|
const float Limits<float>::max = std::numeric_limits<float>::infinity();
|
||||||
|
const float Limits<float>::min = -std::numeric_limits<float>::infinity();
|
||||||
|
const bfloat16_t Limits<bfloat16_t>::max =
|
||||||
|
std::numeric_limits<float>::infinity();
|
||||||
|
const bfloat16_t Limits<bfloat16_t>::min =
|
||||||
|
-std::numeric_limits<float>::infinity();
|
||||||
|
const float16_t Limits<float16_t>::max = std::numeric_limits<float>::infinity();
|
||||||
|
const float16_t Limits<float16_t>::min =
|
||||||
|
-std::numeric_limits<float>::infinity();
|
||||||
|
const complex64_t Limits<complex64_t>::max =
|
||||||
|
std::numeric_limits<float>::infinity();
|
||||||
|
const complex64_t Limits<complex64_t>::min =
|
||||||
|
-std::numeric_limits<float>::infinity();
|
||||||
|
|
||||||
|
struct AndReduce {
|
||||||
|
template <typename T>
|
||||||
|
void operator()(bool* a, T b) {
|
||||||
|
(*a) &= (b != 0);
|
||||||
|
}
|
||||||
|
|
||||||
|
void operator()(bool* y, bool x) {
|
||||||
|
(*y) &= x;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
struct OrReduce {
|
||||||
|
template <typename T>
|
||||||
|
void operator()(bool* a, T b) {
|
||||||
|
(*a) |= (b != 0);
|
||||||
|
}
|
||||||
|
|
||||||
|
void operator()(bool* y, bool x) {
|
||||||
|
(*y) |= x;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
template <typename InT>
|
||||||
|
void reduce_dispatch_out(
|
||||||
|
const array& in,
|
||||||
|
array& out,
|
||||||
|
Reduce::ReduceType rtype,
|
||||||
|
const std::vector<int>& axes) {
|
||||||
|
switch (rtype) {
|
||||||
|
case Reduce::And: {
|
||||||
|
reduction_op<InT, bool>(in, out, axes, true, AndReduce());
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
case Reduce::Or: {
|
||||||
|
reduction_op<InT, bool>(in, out, axes, false, OrReduce());
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
case Reduce::Sum: {
|
||||||
|
auto op = [](auto y, auto x) { (*y) = (*y) + x; };
|
||||||
|
switch (out.dtype()) {
|
||||||
|
case bool_:
|
||||||
|
reduction_op<InT, bool>(in, out, axes, false, op);
|
||||||
|
break;
|
||||||
|
case uint8:
|
||||||
|
reduction_op<InT, uint8_t>(in, out, axes, 0, op);
|
||||||
|
break;
|
||||||
|
case uint16:
|
||||||
|
reduction_op<InT, uint16_t>(in, out, axes, 0, op);
|
||||||
|
break;
|
||||||
|
case uint32:
|
||||||
|
reduction_op<InT, uint32_t>(in, out, axes, 0, op);
|
||||||
|
break;
|
||||||
|
case uint64:
|
||||||
|
reduction_op<InT, uint64_t>(in, out, axes, 0, op);
|
||||||
|
break;
|
||||||
|
case int8:
|
||||||
|
reduction_op<InT, int8_t>(in, out, axes, 0, op);
|
||||||
|
break;
|
||||||
|
case int16:
|
||||||
|
reduction_op<InT, int16_t>(in, out, axes, 0, op);
|
||||||
|
break;
|
||||||
|
case int32:
|
||||||
|
reduction_op<InT, int32_t>(in, out, axes, 0, op);
|
||||||
|
break;
|
||||||
|
case int64:
|
||||||
|
reduction_op<InT, int64_t>(in, out, axes, 0, op);
|
||||||
|
break;
|
||||||
|
case float16:
|
||||||
|
reduction_op<InT, float16_t>(in, out, axes, 0.0f, op);
|
||||||
|
break;
|
||||||
|
case float32:
|
||||||
|
reduction_op<InT, float>(in, out, axes, 0.0f, op);
|
||||||
|
break;
|
||||||
|
case bfloat16:
|
||||||
|
reduction_op<InT, bfloat16_t>(in, out, axes, 0.0f, op);
|
||||||
|
break;
|
||||||
|
case complex64:
|
||||||
|
reduction_op<InT, complex64_t>(in, out, axes, complex64_t{0.0f}, op);
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
} break;
|
||||||
|
case Reduce::Prod: {
|
||||||
|
auto op = [](auto y, auto x) { (*y) *= x; };
|
||||||
|
reduction_op<InT, InT>(in, out, axes, 1, op);
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
case Reduce::Max: {
|
||||||
|
auto op = [](auto y, auto x) { (*y) = (*y > x) ? *y : x; };
|
||||||
|
auto init = Limits<InT>::min;
|
||||||
|
reduction_op<InT, InT>(in, out, axes, init, op);
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
case Reduce::Min: {
|
||||||
|
auto op = [](auto y, auto x) { (*y) = (*y < x) ? *y : x; };
|
||||||
|
auto init = Limits<InT>::max;
|
||||||
|
reduction_op<InT, InT>(in, out, axes, init, op);
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace
|
||||||
|
|
||||||
|
void Reduce::eval(const std::vector<array>& inputs, array& out) {
|
||||||
|
assert(inputs.size() == 1);
|
||||||
|
auto& in = inputs[0];
|
||||||
|
switch (in.dtype()) {
|
||||||
|
case bool_:
|
||||||
|
reduce_dispatch_out<bool>(in, out, reduce_type_, axes_);
|
||||||
|
break;
|
||||||
|
case uint8:
|
||||||
|
reduce_dispatch_out<uint8_t>(in, out, reduce_type_, axes_);
|
||||||
|
break;
|
||||||
|
case uint16:
|
||||||
|
reduce_dispatch_out<uint16_t>(in, out, reduce_type_, axes_);
|
||||||
|
break;
|
||||||
|
case uint32:
|
||||||
|
reduce_dispatch_out<uint32_t>(in, out, reduce_type_, axes_);
|
||||||
|
break;
|
||||||
|
case uint64:
|
||||||
|
reduce_dispatch_out<uint64_t>(in, out, reduce_type_, axes_);
|
||||||
|
break;
|
||||||
|
case int8:
|
||||||
|
reduce_dispatch_out<uint8_t>(in, out, reduce_type_, axes_);
|
||||||
|
break;
|
||||||
|
case int16:
|
||||||
|
reduce_dispatch_out<uint16_t>(in, out, reduce_type_, axes_);
|
||||||
|
break;
|
||||||
|
case int32:
|
||||||
|
reduce_dispatch_out<int32_t>(in, out, reduce_type_, axes_);
|
||||||
|
break;
|
||||||
|
case int64:
|
||||||
|
reduce_dispatch_out<int64_t>(in, out, reduce_type_, axes_);
|
||||||
|
break;
|
||||||
|
case float16:
|
||||||
|
reduce_dispatch_out<float16_t>(in, out, reduce_type_, axes_);
|
||||||
|
break;
|
||||||
|
case float32:
|
||||||
|
reduce_dispatch_out<float>(in, out, reduce_type_, axes_);
|
||||||
|
break;
|
||||||
|
case bfloat16:
|
||||||
|
reduce_dispatch_out<bfloat16_t>(in, out, reduce_type_, axes_);
|
||||||
|
break;
|
||||||
|
case complex64:
|
||||||
|
reduce_dispatch_out<complex64_t>(in, out, reduce_type_, axes_);
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace mlx::core
|
||||||
364
mlx/backend/common/reduce.h
Normal file
364
mlx/backend/common/reduce.h
Normal file
@@ -0,0 +1,364 @@
|
|||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include "mlx/backend/common/utils.h"
|
||||||
|
|
||||||
|
namespace mlx::core {
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
|
||||||
|
enum ReductionOpType {
|
||||||
|
// Self-explanatory. Read everything and produce 1 output.
|
||||||
|
ContiguousAllReduce,
|
||||||
|
|
||||||
|
// The input is contiguous and the last axis is reduced
|
||||||
|
// N1xR1xN2xR2x...xNnxRn
|
||||||
|
ContiguousReduce,
|
||||||
|
|
||||||
|
// The input is contiguous and the last axis is not reduced
|
||||||
|
// R1xN1xR2xN2x...xRnxNn
|
||||||
|
ContiguousStridedReduce,
|
||||||
|
|
||||||
|
// The input is not contiguous but the last axis is and it is reduced so we
|
||||||
|
// need to figure out the offsets but we can call the contiguous reduce after
|
||||||
|
// that.
|
||||||
|
// N3xR1xN1xR4x...xRn
|
||||||
|
GeneralContiguousReduce,
|
||||||
|
|
||||||
|
// The input is not contiguous but the last reduction axis and the last axis
|
||||||
|
// are so we need to figure out the offset but we can call the strided reduce
|
||||||
|
// after that.
|
||||||
|
GeneralStridedReduce,
|
||||||
|
|
||||||
|
// The input is not contiguous after the reduction axis and it may contain
|
||||||
|
// 0-stride axes or transpositions. We could copy the strides and produce a
|
||||||
|
// transposed outcome or we can read the input out of order and write the
|
||||||
|
// output in order.
|
||||||
|
GeneralReduce
|
||||||
|
};
|
||||||
|
|
||||||
|
// Helper for the ndimensional strided loop
|
||||||
|
// Should this be in utils?
|
||||||
|
inline void nd_loop(
|
||||||
|
std::function<void(int)> callback,
|
||||||
|
const std::vector<int>& shape,
|
||||||
|
const std::vector<size_t>& strides) {
|
||||||
|
std::function<void(int, int)> loop_inner;
|
||||||
|
loop_inner = [&](int dim, int offset) {
|
||||||
|
if (dim < shape.size() - 1) {
|
||||||
|
int size = shape[dim];
|
||||||
|
size_t stride = strides[dim];
|
||||||
|
for (int i = 0; i < size; i++) {
|
||||||
|
loop_inner(dim + 1, offset + i * stride);
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
int size = shape[dim];
|
||||||
|
size_t stride = strides[dim];
|
||||||
|
for (int i = 0; i < size; i++) {
|
||||||
|
callback(offset + i * stride);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
};
|
||||||
|
loop_inner(0, 0);
|
||||||
|
}
|
||||||
|
|
||||||
|
std::pair<std::vector<int>, std::vector<size_t>> shapes_without_reduction_axes(
|
||||||
|
const array& x,
|
||||||
|
const std::vector<int>& axes) {
|
||||||
|
std::vector<int> shape = x.shape();
|
||||||
|
std::vector<size_t> strides = x.strides();
|
||||||
|
|
||||||
|
for (int i = axes.size() - 1; i >= 0; i--) {
|
||||||
|
int a = axes[i];
|
||||||
|
shape.erase(shape.begin() + a);
|
||||||
|
strides.erase(strides.begin() + a);
|
||||||
|
}
|
||||||
|
|
||||||
|
return std::make_pair(shape, strides);
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename T, typename U, typename Op>
|
||||||
|
struct DefaultStridedReduce {
|
||||||
|
Op op;
|
||||||
|
|
||||||
|
DefaultStridedReduce(Op op_) : op(op_) {}
|
||||||
|
|
||||||
|
void operator()(const T* x, U* accumulator, int size, size_t stride) {
|
||||||
|
for (int i = 0; i < size; i++) {
|
||||||
|
U* moving_accumulator = accumulator;
|
||||||
|
for (int j = 0; j < stride; j++) {
|
||||||
|
op(moving_accumulator, *x);
|
||||||
|
moving_accumulator++;
|
||||||
|
x++;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
template <typename T, typename U, typename Op>
|
||||||
|
struct DefaultContiguousReduce {
|
||||||
|
Op op;
|
||||||
|
|
||||||
|
DefaultContiguousReduce(Op op_) : op(op_) {}
|
||||||
|
|
||||||
|
void operator()(const T* x, U* accumulator, int size) {
|
||||||
|
while (size-- > 0) {
|
||||||
|
op(accumulator, *x);
|
||||||
|
x++;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
struct ReductionPlan {
|
||||||
|
ReductionOpType type;
|
||||||
|
std::vector<int> shape;
|
||||||
|
std::vector<size_t> strides;
|
||||||
|
|
||||||
|
ReductionPlan(
|
||||||
|
ReductionOpType type_,
|
||||||
|
std::vector<int> shape_,
|
||||||
|
std::vector<size_t> strides_)
|
||||||
|
: type(type_), shape(std::move(shape_)), strides(std::move(strides_)) {}
|
||||||
|
ReductionPlan(ReductionOpType type_) : type(type_) {}
|
||||||
|
};
|
||||||
|
|
||||||
|
ReductionPlan get_reduction_plan(const array& x, const std::vector<int> axes) {
|
||||||
|
// The data is all there and we are reducing over everything
|
||||||
|
if (x.size() == x.data_size() && axes.size() == x.ndim() &&
|
||||||
|
(x.flags().row_contiguous || x.flags().col_contiguous)) {
|
||||||
|
return ContiguousAllReduce;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Row contiguous input so the output is row contiguous
|
||||||
|
if (x.flags().row_contiguous) {
|
||||||
|
// Merge consecutive axes
|
||||||
|
std::vector<int> shape = {x.shape(axes[0])};
|
||||||
|
std::vector<size_t> strides = {x.strides()[axes[0]]};
|
||||||
|
for (int i = 1; i < axes.size(); i++) {
|
||||||
|
if (axes[i] - 1 == axes[i - 1]) {
|
||||||
|
shape.back() *= x.shape(axes[i]);
|
||||||
|
strides.back() = x.strides()[axes[i]];
|
||||||
|
} else {
|
||||||
|
shape.push_back(x.shape(axes[i]));
|
||||||
|
strides.push_back(x.strides()[axes[i]]);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if (strides.back() == 1) {
|
||||||
|
return ReductionPlan(ContiguousReduce, shape, strides);
|
||||||
|
} else if (strides.back() > 1) {
|
||||||
|
return ReductionPlan(ContiguousStridedReduce, shape, strides);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Let's check if we can optimize our access patterns
|
||||||
|
//
|
||||||
|
// 1. We have a reduction axis with stride 1. Simply call
|
||||||
|
// GeneralContiguousReduce and be done with it.
|
||||||
|
// 2. We have transpositions and we are not reducing over the axis with
|
||||||
|
// stride 1. However, we are reducing over an axis where everything is
|
||||||
|
// contiguous in memory to the right of that axis. We can call strided
|
||||||
|
// reduce and be done with it.
|
||||||
|
// 2. We have weird transpositions and expands. Copy the strides to the
|
||||||
|
// output, then call strided reduce.
|
||||||
|
|
||||||
|
// Sort reduction axes by stride in order to merge them and figure out if we
|
||||||
|
// have a contiguous reduction.
|
||||||
|
std::vector<std::pair<int, size_t>> reductions;
|
||||||
|
for (auto a : axes) {
|
||||||
|
reductions.push_back(std::make_pair(x.shape(a), x.strides()[a]));
|
||||||
|
}
|
||||||
|
std::sort(reductions.begin(), reductions.end(), [](auto a, auto b) {
|
||||||
|
return a.second > b.second;
|
||||||
|
});
|
||||||
|
// Extract the two smallest and try to merge them in case the contiguous
|
||||||
|
// reduction can be bigger than just the last axis.
|
||||||
|
for (int i = reductions.size() - 1; i >= 1; i--) {
|
||||||
|
auto a = reductions[i];
|
||||||
|
auto b = reductions[i - 1];
|
||||||
|
|
||||||
|
// b.stride = a.shape * a.stride then a and b are contiguous
|
||||||
|
if (b.second == a.first * a.second) {
|
||||||
|
reductions.erase(reductions.begin() + i);
|
||||||
|
reductions[i - 1] = std::make_pair(a.first * b.first, a.second);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
std::vector<int> shape;
|
||||||
|
std::vector<size_t> strides;
|
||||||
|
for (auto r : reductions) {
|
||||||
|
shape.push_back(r.first);
|
||||||
|
strides.push_back(r.second);
|
||||||
|
}
|
||||||
|
|
||||||
|
// We can call the contiguous reduction op for every weird way the input is
|
||||||
|
// structured in the rest of the axes.
|
||||||
|
if (strides.back() == 1) {
|
||||||
|
return ReductionPlan(GeneralContiguousReduce, shape, strides);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Delegate to the general strided reduction op if the axes after
|
||||||
|
// strides.back() are contiguous.
|
||||||
|
if (strides.back() > 1) {
|
||||||
|
int size = 1;
|
||||||
|
for (int i = x.ndim() - 1; i >= 0; i--) {
|
||||||
|
if (axes.back() == i) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
if (x.strides()[i] != size) {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
size *= x.shape(i);
|
||||||
|
}
|
||||||
|
if (size >= strides.back()) {
|
||||||
|
return ReductionPlan(GeneralStridedReduce, shape, strides);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return ReductionPlan(GeneralReduce, shape, strides);
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename T, typename U, typename OpS, typename OpC, typename Op>
|
||||||
|
void reduction_op(
|
||||||
|
const array& x,
|
||||||
|
array& out,
|
||||||
|
const std::vector<int>& axes,
|
||||||
|
U init,
|
||||||
|
OpS ops,
|
||||||
|
OpC opc,
|
||||||
|
Op op) {
|
||||||
|
out.set_data(allocator::malloc_or_wait(out.nbytes()));
|
||||||
|
ReductionPlan plan = get_reduction_plan(x, axes);
|
||||||
|
|
||||||
|
if (plan.type == ContiguousAllReduce) {
|
||||||
|
U* out_ptr = out.data<U>();
|
||||||
|
*out_ptr = init;
|
||||||
|
opc(x.data<T>(), out_ptr, x.size());
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
std::vector<int> shape;
|
||||||
|
std::vector<size_t> strides;
|
||||||
|
|
||||||
|
if (plan.type == ContiguousReduce && plan.shape.size() == 1) {
|
||||||
|
int reduction_size = plan.shape[0];
|
||||||
|
const T* x_ptr = x.data<T>();
|
||||||
|
U* out_ptr = out.data<U>();
|
||||||
|
for (int i = 0; i < out.size(); i++, out_ptr++, x_ptr += reduction_size) {
|
||||||
|
*out_ptr = init;
|
||||||
|
opc(x_ptr, out_ptr, reduction_size);
|
||||||
|
}
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (plan.type == GeneralContiguousReduce || plan.type == ContiguousReduce) {
|
||||||
|
int reduction_size = plan.shape.back();
|
||||||
|
plan.shape.pop_back();
|
||||||
|
plan.strides.pop_back();
|
||||||
|
const T* x_ptr = x.data<T>();
|
||||||
|
U* out_ptr = out.data<U>();
|
||||||
|
// Unrolling the following loop (and implementing it in order for
|
||||||
|
// ContiguousReduce) should hold extra performance boost.
|
||||||
|
std::tie(shape, strides) = shapes_without_reduction_axes(x, axes);
|
||||||
|
if (plan.shape.size() == 0) {
|
||||||
|
for (int i = 0; i < out.size(); i++, out_ptr++) {
|
||||||
|
int offset = elem_to_loc(i, shape, strides);
|
||||||
|
*out_ptr = init;
|
||||||
|
opc(x_ptr + offset, out_ptr, reduction_size);
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
for (int i = 0; i < out.size(); i++, out_ptr++) {
|
||||||
|
int offset = elem_to_loc(i, shape, strides);
|
||||||
|
*out_ptr = init;
|
||||||
|
nd_loop(
|
||||||
|
[&](int extra_offset) {
|
||||||
|
opc(x_ptr + offset + extra_offset, out_ptr, reduction_size);
|
||||||
|
},
|
||||||
|
plan.shape,
|
||||||
|
plan.strides);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (plan.type == ContiguousStridedReduce && plan.shape.size() == 1) {
|
||||||
|
int reduction_size = plan.shape.back();
|
||||||
|
size_t reduction_stride = plan.strides.back();
|
||||||
|
plan.shape.pop_back();
|
||||||
|
plan.strides.pop_back();
|
||||||
|
const T* x_ptr = x.data<T>();
|
||||||
|
U* out_ptr = out.data<U>();
|
||||||
|
for (int i = 0; i < out.size(); i += reduction_stride) {
|
||||||
|
std::fill_n(out_ptr, reduction_stride, init);
|
||||||
|
ops(x_ptr, out_ptr, reduction_size, reduction_stride);
|
||||||
|
x_ptr += reduction_stride * reduction_size;
|
||||||
|
out_ptr += reduction_stride;
|
||||||
|
}
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (plan.type == GeneralStridedReduce ||
|
||||||
|
plan.type == ContiguousStridedReduce) {
|
||||||
|
int reduction_size = plan.shape.back();
|
||||||
|
size_t reduction_stride = plan.strides.back();
|
||||||
|
plan.shape.pop_back();
|
||||||
|
plan.strides.pop_back();
|
||||||
|
const T* x_ptr = x.data<T>();
|
||||||
|
U* out_ptr = out.data<U>();
|
||||||
|
std::tie(shape, strides) = shapes_without_reduction_axes(x, axes);
|
||||||
|
if (plan.shape.size() == 0) {
|
||||||
|
for (int i = 0; i < out.size(); i += reduction_stride) {
|
||||||
|
int offset = elem_to_loc(i, shape, strides);
|
||||||
|
std::fill_n(out_ptr, reduction_stride, init);
|
||||||
|
ops(x_ptr + offset, out_ptr, reduction_size, reduction_stride);
|
||||||
|
out_ptr += reduction_stride;
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
for (int i = 0; i < out.size(); i += reduction_stride) {
|
||||||
|
int offset = elem_to_loc(i, shape, strides);
|
||||||
|
std::fill_n(out_ptr, reduction_stride, init);
|
||||||
|
nd_loop(
|
||||||
|
[&](int extra_offset) {
|
||||||
|
ops(x_ptr + offset + extra_offset,
|
||||||
|
out_ptr,
|
||||||
|
reduction_size,
|
||||||
|
reduction_stride);
|
||||||
|
},
|
||||||
|
plan.shape,
|
||||||
|
plan.strides);
|
||||||
|
out_ptr += reduction_stride;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (plan.type == GeneralReduce) {
|
||||||
|
const T* x_ptr = x.data<T>();
|
||||||
|
U* out_ptr = out.data<U>();
|
||||||
|
std::tie(shape, strides) = shapes_without_reduction_axes(x, axes);
|
||||||
|
for (int i = 0; i < out.size(); i++, out_ptr++) {
|
||||||
|
int offset = elem_to_loc(i, shape, strides);
|
||||||
|
U val = init;
|
||||||
|
nd_loop(
|
||||||
|
[&](int extra_offset) { op(&val, *(x_ptr + offset + extra_offset)); },
|
||||||
|
plan.shape,
|
||||||
|
plan.strides);
|
||||||
|
*out_ptr = val;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename T, typename U, typename Op>
|
||||||
|
void reduction_op(
|
||||||
|
const array& x,
|
||||||
|
array& out,
|
||||||
|
const std::vector<int>& axes,
|
||||||
|
U init,
|
||||||
|
Op op) {
|
||||||
|
DefaultStridedReduce<T, U, Op> ops(op);
|
||||||
|
DefaultContiguousReduce<T, U, Op> opc(op);
|
||||||
|
reduction_op<T, U>(x, out, axes, init, ops, opc, op);
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace
|
||||||
|
|
||||||
|
} // namespace mlx::core
|
||||||
323
mlx/backend/common/scan.cpp
Normal file
323
mlx/backend/common/scan.cpp
Normal file
@@ -0,0 +1,323 @@
|
|||||||
|
#include <cassert>
|
||||||
|
|
||||||
|
#include "mlx/backend/common/copy.h"
|
||||||
|
#include "mlx/backend/common/utils.h"
|
||||||
|
#include "mlx/primitives.h"
|
||||||
|
|
||||||
|
namespace mlx::core {
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
|
||||||
|
template <typename T, typename U, typename Op>
|
||||||
|
struct DefaultContiguousScan {
|
||||||
|
Op op;
|
||||||
|
U init;
|
||||||
|
|
||||||
|
DefaultContiguousScan(Op op_, U init_) : op(op_), init(init_) {}
|
||||||
|
|
||||||
|
void operator()(
|
||||||
|
const T* input,
|
||||||
|
U* output,
|
||||||
|
int count,
|
||||||
|
int stride,
|
||||||
|
bool reverse,
|
||||||
|
bool inclusive) {
|
||||||
|
if (!reverse) {
|
||||||
|
if (inclusive) {
|
||||||
|
for (int i = 0; i < count; i++) {
|
||||||
|
*output = *input;
|
||||||
|
for (int j = 1; j < stride; j++) {
|
||||||
|
input++;
|
||||||
|
output++;
|
||||||
|
op(output, output - 1, input);
|
||||||
|
}
|
||||||
|
output++;
|
||||||
|
input++;
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
for (int i = 0; i < count; i++) {
|
||||||
|
*output = init;
|
||||||
|
for (int j = 1; j < stride; j++) {
|
||||||
|
op(output + 1, output, input);
|
||||||
|
input++;
|
||||||
|
output++;
|
||||||
|
}
|
||||||
|
output++;
|
||||||
|
input++;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
if (inclusive) {
|
||||||
|
for (int i = 0; i < count; i++) {
|
||||||
|
output += stride - 1;
|
||||||
|
input += stride - 1;
|
||||||
|
*output = *input;
|
||||||
|
for (int j = 1; j < stride; j++) {
|
||||||
|
input--;
|
||||||
|
output--;
|
||||||
|
op(output, output + 1, input);
|
||||||
|
}
|
||||||
|
output += stride;
|
||||||
|
input += stride;
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
for (int i = 0; i < count; i++) {
|
||||||
|
output += stride - 1;
|
||||||
|
input += stride - 1;
|
||||||
|
*output = init;
|
||||||
|
for (int j = 1; j < stride; j++) {
|
||||||
|
op(output - 1, output, input);
|
||||||
|
input--;
|
||||||
|
output--;
|
||||||
|
}
|
||||||
|
output += stride;
|
||||||
|
input += stride;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
template <typename T, typename U, typename Op>
|
||||||
|
struct DefaultStridedScan {
|
||||||
|
Op op;
|
||||||
|
U init;
|
||||||
|
|
||||||
|
DefaultStridedScan(Op op_, U init_) : op(op_), init(init_) {}
|
||||||
|
|
||||||
|
void operator()(
|
||||||
|
const T* input,
|
||||||
|
U* output,
|
||||||
|
int count,
|
||||||
|
int size,
|
||||||
|
int stride,
|
||||||
|
bool reverse,
|
||||||
|
bool inclusive) {
|
||||||
|
// TODO: Vectorize the following naive implementation
|
||||||
|
if (!reverse) {
|
||||||
|
if (inclusive) {
|
||||||
|
for (int i = 0; i < count; i++) {
|
||||||
|
std::copy(input, input + stride, output);
|
||||||
|
output += stride;
|
||||||
|
input += stride;
|
||||||
|
for (int j = 1; j < size; j++) {
|
||||||
|
for (int k = 0; k < stride; k++) {
|
||||||
|
op(output, output - stride, input);
|
||||||
|
output++;
|
||||||
|
input++;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
for (int i = 0; i < count; i++) {
|
||||||
|
std::fill(output, output + stride, init);
|
||||||
|
output += stride;
|
||||||
|
input += stride;
|
||||||
|
for (int j = 1; j < size; j++) {
|
||||||
|
for (int k = 0; k < stride; k++) {
|
||||||
|
op(output, output - stride, input - stride);
|
||||||
|
output++;
|
||||||
|
input++;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
if (inclusive) {
|
||||||
|
for (int i = 0; i < count; i++) {
|
||||||
|
output += (size - 1) * stride;
|
||||||
|
input += (size - 1) * stride;
|
||||||
|
std::copy(input, input + stride, output);
|
||||||
|
for (int j = 1; j < size; j++) {
|
||||||
|
for (int k = 0; k < stride; k++) {
|
||||||
|
output--;
|
||||||
|
input--;
|
||||||
|
op(output, output + stride, input);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
output += size * stride;
|
||||||
|
input += size * stride;
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
for (int i = 0; i < count; i++) {
|
||||||
|
output += (size - 1) * stride;
|
||||||
|
input += (size - 1) * stride;
|
||||||
|
std::fill(output, output + stride, init);
|
||||||
|
for (int j = 1; j < size; j++) {
|
||||||
|
for (int k = 0; k < stride; k++) {
|
||||||
|
output--;
|
||||||
|
input--;
|
||||||
|
op(output, output + stride, input + stride);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
output += size * stride;
|
||||||
|
input += size * stride;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
template <typename T, typename U, typename OpCS, typename OpSS>
|
||||||
|
void scan_op(
|
||||||
|
OpCS opcs,
|
||||||
|
OpSS opss,
|
||||||
|
const array& input,
|
||||||
|
array& output,
|
||||||
|
int axis,
|
||||||
|
bool reverse,
|
||||||
|
bool inclusive) {
|
||||||
|
output.set_data(allocator::malloc_or_wait(output.nbytes()));
|
||||||
|
|
||||||
|
if (input.flags().row_contiguous) {
|
||||||
|
if (input.strides()[axis] == 1) {
|
||||||
|
opcs(
|
||||||
|
input.data<T>(),
|
||||||
|
output.data<U>(),
|
||||||
|
input.size() / input.shape(axis),
|
||||||
|
input.shape(axis),
|
||||||
|
reverse,
|
||||||
|
inclusive);
|
||||||
|
} else {
|
||||||
|
opss(
|
||||||
|
input.data<T>(),
|
||||||
|
output.data<U>(),
|
||||||
|
input.size() / input.shape(axis) / input.strides()[axis],
|
||||||
|
input.shape(axis),
|
||||||
|
input.strides()[axis],
|
||||||
|
reverse,
|
||||||
|
inclusive);
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
throw std::runtime_error("Scan op supports only contiguous inputs");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename T, typename U>
|
||||||
|
void scan_dispatch(
|
||||||
|
Scan::ReduceType rtype,
|
||||||
|
const array& input,
|
||||||
|
array& output,
|
||||||
|
int axis,
|
||||||
|
bool reverse,
|
||||||
|
bool inclusive) {
|
||||||
|
switch (rtype) {
|
||||||
|
case Scan::Sum: {
|
||||||
|
auto op = [](U* o, const U* y, const T* x) { *o = *y + *x; };
|
||||||
|
auto init = static_cast<U>(0);
|
||||||
|
auto opcs = DefaultContiguousScan<T, U, decltype(op)>(op, init);
|
||||||
|
auto opss = DefaultStridedScan<T, U, decltype(op)>(op, init);
|
||||||
|
scan_op<T, U>(opcs, opss, input, output, axis, reverse, inclusive);
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
case Scan::Prod: {
|
||||||
|
auto op = [](U* o, const U* y, const T* x) { *o = *y * (*x); };
|
||||||
|
auto init = static_cast<U>(1);
|
||||||
|
auto opcs = DefaultContiguousScan<T, U, decltype(op)>(op, init);
|
||||||
|
auto opss = DefaultStridedScan<T, U, decltype(op)>(op, init);
|
||||||
|
scan_op<T, U>(opcs, opss, input, output, axis, reverse, inclusive);
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
case Scan::Min: {
|
||||||
|
auto op = [](U* o, const U* y, const T* x) { *o = (*x < *y) ? *x : *y; };
|
||||||
|
auto init = (is_floating_point(input.dtype()))
|
||||||
|
? static_cast<U>(std::numeric_limits<float>::infinity())
|
||||||
|
: std::numeric_limits<U>::max();
|
||||||
|
auto opcs = DefaultContiguousScan<T, U, decltype(op)>(op, init);
|
||||||
|
auto opss = DefaultStridedScan<T, U, decltype(op)>(op, init);
|
||||||
|
scan_op<T, U>(opcs, opss, input, output, axis, reverse, inclusive);
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
case Scan::Max: {
|
||||||
|
auto op = [](U* o, const U* y, const T* x) { *o = (*x < *y) ? *y : *x; };
|
||||||
|
auto init = (is_floating_point(input.dtype()))
|
||||||
|
? static_cast<U>(-std::numeric_limits<float>::infinity())
|
||||||
|
: std::numeric_limits<U>::max();
|
||||||
|
auto opcs = DefaultContiguousScan<T, U, decltype(op)>(op, init);
|
||||||
|
auto opss = DefaultStridedScan<T, U, decltype(op)>(op, init);
|
||||||
|
scan_op<T, U>(opcs, opss, input, output, axis, reverse, inclusive);
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace
|
||||||
|
|
||||||
|
void Scan::eval(const std::vector<array>& inputs, array& out) {
|
||||||
|
assert(inputs.size() == 1);
|
||||||
|
|
||||||
|
// Ensure contiguity
|
||||||
|
auto in = inputs[0];
|
||||||
|
if (!in.flags().row_contiguous) {
|
||||||
|
array arr_copy(in.shape(), in.dtype(), nullptr, {});
|
||||||
|
copy(in, arr_copy, CopyType::General);
|
||||||
|
in = arr_copy;
|
||||||
|
}
|
||||||
|
|
||||||
|
switch (in.dtype()) {
|
||||||
|
case bool_: {
|
||||||
|
// We could do a full dtype x dtype switch but this is the only case
|
||||||
|
// where we accumulate in a different type, for now.
|
||||||
|
//
|
||||||
|
// TODO: If we add the option to accumulate floats in higher precision
|
||||||
|
// floats perhaps we should add the full all-to-all dispatch.
|
||||||
|
if (reduce_type_ == Scan::Sum && out.dtype() == int32) {
|
||||||
|
scan_dispatch<bool, int32_t>(
|
||||||
|
reduce_type_, in, out, axis_, reverse_, inclusive_);
|
||||||
|
} else {
|
||||||
|
scan_dispatch<bool, bool>(
|
||||||
|
reduce_type_, in, out, axis_, reverse_, inclusive_);
|
||||||
|
}
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
case uint8:
|
||||||
|
scan_dispatch<uint8_t, uint8_t>(
|
||||||
|
reduce_type_, in, out, axis_, reverse_, inclusive_);
|
||||||
|
break;
|
||||||
|
case uint16:
|
||||||
|
scan_dispatch<uint16_t, uint16_t>(
|
||||||
|
reduce_type_, in, out, axis_, reverse_, inclusive_);
|
||||||
|
break;
|
||||||
|
case uint32:
|
||||||
|
scan_dispatch<uint32_t, uint32_t>(
|
||||||
|
reduce_type_, in, out, axis_, reverse_, inclusive_);
|
||||||
|
break;
|
||||||
|
case uint64:
|
||||||
|
scan_dispatch<uint64_t, uint64_t>(
|
||||||
|
reduce_type_, in, out, axis_, reverse_, inclusive_);
|
||||||
|
break;
|
||||||
|
case int8:
|
||||||
|
scan_dispatch<int8_t, int8_t>(
|
||||||
|
reduce_type_, in, out, axis_, reverse_, inclusive_);
|
||||||
|
break;
|
||||||
|
case int16:
|
||||||
|
scan_dispatch<int16_t, int16_t>(
|
||||||
|
reduce_type_, in, out, axis_, reverse_, inclusive_);
|
||||||
|
break;
|
||||||
|
case int32:
|
||||||
|
scan_dispatch<int32_t, int32_t>(
|
||||||
|
reduce_type_, in, out, axis_, reverse_, inclusive_);
|
||||||
|
break;
|
||||||
|
case int64:
|
||||||
|
scan_dispatch<int64_t, int64_t>(
|
||||||
|
reduce_type_, in, out, axis_, reverse_, inclusive_);
|
||||||
|
break;
|
||||||
|
case float16:
|
||||||
|
scan_dispatch<float16_t, float16_t>(
|
||||||
|
reduce_type_, in, out, axis_, reverse_, inclusive_);
|
||||||
|
break;
|
||||||
|
case float32:
|
||||||
|
scan_dispatch<float, float>(
|
||||||
|
reduce_type_, in, out, axis_, reverse_, inclusive_);
|
||||||
|
break;
|
||||||
|
case bfloat16:
|
||||||
|
scan_dispatch<bfloat16_t, bfloat16_t>(
|
||||||
|
reduce_type_, in, out, axis_, reverse_, inclusive_);
|
||||||
|
break;
|
||||||
|
case complex64:
|
||||||
|
throw std::runtime_error("Scan ops do not support complex types yet");
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace mlx::core
|
||||||
29
mlx/backend/common/threefry.cpp
Normal file
29
mlx/backend/common/threefry.cpp
Normal file
@@ -0,0 +1,29 @@
|
|||||||
|
#include "mlx/backend/common/threefry.h"
|
||||||
|
|
||||||
|
namespace mlx::core::random {
|
||||||
|
|
||||||
|
std::pair<uint32_t, uint32_t> threefry2x32_hash(
|
||||||
|
const std::pair<uint32_t, uint32_t>& key,
|
||||||
|
std::pair<uint32_t, uint32_t> count) {
|
||||||
|
constexpr static uint32_t rotations[2][4] = {
|
||||||
|
{13, 15, 26, 6}, {17, 29, 16, 24}};
|
||||||
|
|
||||||
|
uint32_t ks[3] = {key.first, key.second, key.first ^ key.second ^ 0x1BD11BDA};
|
||||||
|
|
||||||
|
count.first += ks[0];
|
||||||
|
count.second += ks[1];
|
||||||
|
|
||||||
|
for (int i = 0; i < 5; ++i) {
|
||||||
|
for (auto r : rotations[i % 2]) {
|
||||||
|
count.first += count.second;
|
||||||
|
count.second = (count.second << r) | (count.second >> (32 - r));
|
||||||
|
count.second ^= count.first;
|
||||||
|
}
|
||||||
|
count.first += ks[(i + 1) % 3];
|
||||||
|
count.second += ks[(i + 2) % 3] + i + 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
return count;
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace mlx::core::random
|
||||||
19
mlx/backend/common/threefry.h
Normal file
19
mlx/backend/common/threefry.h
Normal file
@@ -0,0 +1,19 @@
|
|||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include <cstdint>
|
||||||
|
#include <utility>
|
||||||
|
|
||||||
|
namespace mlx::core::random {
|
||||||
|
|
||||||
|
/** Applies the Threefry 2x32 hash function.
|
||||||
|
* This code is based on the Jax counter-based and splittable PRNG
|
||||||
|
* https://github.com/google/jax/blob/main/docs/jep/263-prng.md
|
||||||
|
*
|
||||||
|
* Original Threefry reference:
|
||||||
|
* http://www.thesalmons.org/john/random123/papers/random123sc11.pdf
|
||||||
|
*/
|
||||||
|
std::pair<uint32_t, uint32_t> threefry2x32_hash(
|
||||||
|
const std::pair<uint32_t, uint32_t>& key,
|
||||||
|
std::pair<uint32_t, uint32_t> count);
|
||||||
|
|
||||||
|
} // namespace mlx::core::random
|
||||||
147
mlx/backend/common/unary.h
Normal file
147
mlx/backend/common/unary.h
Normal file
@@ -0,0 +1,147 @@
|
|||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include "mlx/allocator.h"
|
||||||
|
#include "mlx/array.h"
|
||||||
|
#include "mlx/backend/common/utils.h"
|
||||||
|
#include "mlx/utils.h"
|
||||||
|
|
||||||
|
namespace mlx::core {
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
|
||||||
|
struct AbsOp {
|
||||||
|
template <typename T>
|
||||||
|
T operator()(T x) {
|
||||||
|
return std::abs(x);
|
||||||
|
}
|
||||||
|
uint8_t operator()(uint8_t x) {
|
||||||
|
return x;
|
||||||
|
}
|
||||||
|
uint16_t operator()(uint16_t x) {
|
||||||
|
return x;
|
||||||
|
}
|
||||||
|
uint32_t operator()(uint32_t x) {
|
||||||
|
return x;
|
||||||
|
}
|
||||||
|
uint64_t operator()(uint64_t x) {
|
||||||
|
return x;
|
||||||
|
}
|
||||||
|
bool operator()(bool x) {
|
||||||
|
return x;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
struct SignOp {
|
||||||
|
template <typename T>
|
||||||
|
T operator()(T x) {
|
||||||
|
return (x > T(0)) - (x < T(0));
|
||||||
|
}
|
||||||
|
|
||||||
|
uint8_t operator()(uint8_t x) {
|
||||||
|
return x != 0;
|
||||||
|
}
|
||||||
|
uint16_t operator()(uint16_t x) {
|
||||||
|
return x != 0;
|
||||||
|
}
|
||||||
|
uint32_t operator()(uint32_t x) {
|
||||||
|
return x != 0;
|
||||||
|
}
|
||||||
|
uint64_t operator()(uint64_t x) {
|
||||||
|
return x != 0;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
template <typename T, typename Op>
|
||||||
|
void unary_op(const array& a, array& out, Op op) {
|
||||||
|
const T* a_ptr = a.data<T>();
|
||||||
|
if (a.flags().contiguous) {
|
||||||
|
out.set_data(
|
||||||
|
allocator::malloc_or_wait(a.data_size() * out.itemsize()),
|
||||||
|
a.data_size(),
|
||||||
|
a.strides(),
|
||||||
|
a.flags());
|
||||||
|
T* dst = out.data<T>();
|
||||||
|
for (size_t i = 0; i < a.data_size(); ++i) {
|
||||||
|
dst[i] = op(a_ptr[i]);
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
out.set_data(allocator::malloc_or_wait(out.nbytes()));
|
||||||
|
T* dst = out.data<T>();
|
||||||
|
for (size_t i = 0; i < out.size(); ++i) {
|
||||||
|
// TODO this is super inefficient, need to fix.
|
||||||
|
int a_idx = elem_to_loc(i, a.shape(), a.strides());
|
||||||
|
dst[i] = op(a_ptr[a_idx]);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename Op>
|
||||||
|
void unary(const array& a, array& out, Op op) {
|
||||||
|
switch (out.dtype()) {
|
||||||
|
case bool_:
|
||||||
|
unary_op<bool>(a, out, op);
|
||||||
|
break;
|
||||||
|
case uint8:
|
||||||
|
unary_op<uint8_t>(a, out, op);
|
||||||
|
break;
|
||||||
|
case uint16:
|
||||||
|
unary_op<uint16_t>(a, out, op);
|
||||||
|
break;
|
||||||
|
case uint32:
|
||||||
|
unary_op<uint32_t>(a, out, op);
|
||||||
|
break;
|
||||||
|
case uint64:
|
||||||
|
unary_op<uint64_t>(a, out, op);
|
||||||
|
break;
|
||||||
|
case int8:
|
||||||
|
unary_op<int8_t>(a, out, op);
|
||||||
|
break;
|
||||||
|
case int16:
|
||||||
|
unary_op<int16_t>(a, out, op);
|
||||||
|
break;
|
||||||
|
case int32:
|
||||||
|
unary_op<int32_t>(a, out, op);
|
||||||
|
break;
|
||||||
|
case int64:
|
||||||
|
unary_op<int64_t>(a, out, op);
|
||||||
|
break;
|
||||||
|
case float16:
|
||||||
|
unary_op<float16_t>(a, out, op);
|
||||||
|
break;
|
||||||
|
case float32:
|
||||||
|
unary_op<float>(a, out, op);
|
||||||
|
break;
|
||||||
|
case bfloat16:
|
||||||
|
unary_op<bfloat16_t>(a, out, op);
|
||||||
|
break;
|
||||||
|
case complex64:
|
||||||
|
unary_op<complex64_t>(a, out, op);
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename Op>
|
||||||
|
void unary_fp(const array& a, array& out, Op op) {
|
||||||
|
switch (out.dtype()) {
|
||||||
|
case bfloat16:
|
||||||
|
unary_op<bfloat16_t>(a, out, op);
|
||||||
|
break;
|
||||||
|
case float16:
|
||||||
|
unary_op<float16_t>(a, out, op);
|
||||||
|
break;
|
||||||
|
case float32:
|
||||||
|
unary_op<float>(a, out, op);
|
||||||
|
break;
|
||||||
|
case complex64:
|
||||||
|
unary_op<complex64_t>(a, out, op);
|
||||||
|
break;
|
||||||
|
default:
|
||||||
|
std::ostringstream err;
|
||||||
|
err << "[unary_fp] Does not support " << out.dtype();
|
||||||
|
throw std::runtime_error(err.str());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace
|
||||||
|
|
||||||
|
} // namespace mlx::core
|
||||||
29
mlx/backend/common/utils.h
Normal file
29
mlx/backend/common/utils.h
Normal file
@@ -0,0 +1,29 @@
|
|||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
|
#include "mlx/array.h"
|
||||||
|
|
||||||
|
namespace mlx::core {
|
||||||
|
|
||||||
|
inline size_t elem_to_loc(
|
||||||
|
int elem,
|
||||||
|
const std::vector<int>& shape,
|
||||||
|
const std::vector<size_t>& strides) {
|
||||||
|
size_t loc = 0;
|
||||||
|
for (int i = shape.size() - 1; i >= 0; --i) {
|
||||||
|
auto q_and_r = ldiv(elem, shape[i]);
|
||||||
|
loc += q_and_r.rem * strides[i];
|
||||||
|
elem = q_and_r.quot;
|
||||||
|
}
|
||||||
|
return loc;
|
||||||
|
}
|
||||||
|
|
||||||
|
inline size_t elem_to_loc(int elem, const array& a) {
|
||||||
|
if (a.flags().row_contiguous) {
|
||||||
|
return elem;
|
||||||
|
}
|
||||||
|
return elem_to_loc(elem, a.shape(), a.strides());
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace mlx::core
|
||||||
26
mlx/backend/metal/CMakeLists.txt
Normal file
26
mlx/backend/metal/CMakeLists.txt
Normal file
@@ -0,0 +1,26 @@
|
|||||||
|
target_sources(
|
||||||
|
mlx
|
||||||
|
PRIVATE
|
||||||
|
${CMAKE_CURRENT_SOURCE_DIR}/allocator.cpp
|
||||||
|
${CMAKE_CURRENT_SOURCE_DIR}/conv.cpp
|
||||||
|
${CMAKE_CURRENT_SOURCE_DIR}/copy.cpp
|
||||||
|
${CMAKE_CURRENT_SOURCE_DIR}/device.cpp
|
||||||
|
${CMAKE_CURRENT_SOURCE_DIR}/fft.cpp
|
||||||
|
${CMAKE_CURRENT_SOURCE_DIR}/indexing.cpp
|
||||||
|
${CMAKE_CURRENT_SOURCE_DIR}/matmul.cpp
|
||||||
|
${CMAKE_CURRENT_SOURCE_DIR}/metal.cpp
|
||||||
|
${CMAKE_CURRENT_SOURCE_DIR}/primitives.cpp
|
||||||
|
${CMAKE_CURRENT_SOURCE_DIR}/scan.cpp
|
||||||
|
${CMAKE_CURRENT_SOURCE_DIR}/softmax.cpp
|
||||||
|
${CMAKE_CURRENT_SOURCE_DIR}/sort.cpp
|
||||||
|
${CMAKE_CURRENT_SOURCE_DIR}/reduce.cpp
|
||||||
|
)
|
||||||
|
|
||||||
|
if (NOT MLX_METAL_PATH)
|
||||||
|
set(MLX_METAL_PATH ${CMAKE_CURRENT_BINARY_DIR}/kernels/)
|
||||||
|
endif()
|
||||||
|
|
||||||
|
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/kernels)
|
||||||
|
|
||||||
|
target_compile_definitions(
|
||||||
|
mlx PRIVATE METAL_PATH="${MLX_METAL_PATH}/mlx.metallib")
|
||||||
113
mlx/backend/metal/copy.cpp
Normal file
113
mlx/backend/metal/copy.cpp
Normal file
@@ -0,0 +1,113 @@
|
|||||||
|
#include <sstream>
|
||||||
|
|
||||||
|
#include "mlx/backend/metal/copy.h"
|
||||||
|
#include "mlx/backend/metal/device.h"
|
||||||
|
#include "mlx/backend/metal/kernels/defines.h"
|
||||||
|
#include "mlx/backend/metal/utils.h"
|
||||||
|
#include "mlx/primitives.h"
|
||||||
|
|
||||||
|
namespace mlx::core {
|
||||||
|
|
||||||
|
void copy_gpu(const array& in, array& out, CopyType ctype, const Stream& s) {
|
||||||
|
if (ctype == CopyType::Vector) {
|
||||||
|
out.set_data(
|
||||||
|
allocator::malloc_or_wait(in.data_size() * out.itemsize()),
|
||||||
|
in.data_size(),
|
||||||
|
in.strides(),
|
||||||
|
in.flags());
|
||||||
|
} else {
|
||||||
|
out.set_data(allocator::malloc_or_wait(out.nbytes()));
|
||||||
|
}
|
||||||
|
if (ctype == CopyType::GeneralGeneral) {
|
||||||
|
ctype = CopyType::General;
|
||||||
|
}
|
||||||
|
copy_gpu_inplace(in, out, ctype, s);
|
||||||
|
}
|
||||||
|
|
||||||
|
void copy_gpu(const array& in, array& out, CopyType ctype) {
|
||||||
|
copy_gpu(in, out, ctype, out.primitive().stream());
|
||||||
|
}
|
||||||
|
|
||||||
|
void copy_gpu_inplace(
|
||||||
|
const array& in,
|
||||||
|
array& out,
|
||||||
|
CopyType ctype,
|
||||||
|
const Stream& s) {
|
||||||
|
// Try to collapse contiguous dims
|
||||||
|
auto [shape, strides] = collapse_contiguous_dims(in, out);
|
||||||
|
auto& strides_in = strides[0];
|
||||||
|
auto& strides_out = strides[1];
|
||||||
|
|
||||||
|
auto& d = metal::device(s.device);
|
||||||
|
std::ostringstream kname;
|
||||||
|
switch (ctype) {
|
||||||
|
case CopyType::Scalar:
|
||||||
|
kname << "scopy";
|
||||||
|
break;
|
||||||
|
case CopyType::Vector:
|
||||||
|
kname << "vcopy";
|
||||||
|
break;
|
||||||
|
case CopyType::General:
|
||||||
|
kname << "gcopy";
|
||||||
|
break;
|
||||||
|
case CopyType::GeneralGeneral:
|
||||||
|
kname << "ggcopy";
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
kname << type_to_name(in) << type_to_name(out);
|
||||||
|
if ((ctype == CopyType::General || ctype == CopyType::GeneralGeneral) &&
|
||||||
|
shape.size() <= MAX_COPY_SPECIALIZED_DIMS) {
|
||||||
|
kname << "_" << shape.size();
|
||||||
|
}
|
||||||
|
auto kernel = d.get_kernel(kname.str());
|
||||||
|
auto compute_encoder = d.get_command_encoder(s.index);
|
||||||
|
compute_encoder->setComputePipelineState(kernel);
|
||||||
|
set_array_buffer(compute_encoder, in, 0);
|
||||||
|
set_array_buffer(compute_encoder, out, 1);
|
||||||
|
|
||||||
|
if (ctype == CopyType::General || ctype == CopyType::GeneralGeneral) {
|
||||||
|
size_t ndim = shape.size();
|
||||||
|
if (ndim > 3) {
|
||||||
|
compute_encoder->setBytes(shape.data(), ndim * sizeof(int), 2);
|
||||||
|
compute_encoder->setBytes(strides_in.data(), ndim * sizeof(size_t), 3);
|
||||||
|
if (ctype == CopyType::GeneralGeneral) {
|
||||||
|
compute_encoder->setBytes(strides_out.data(), ndim * sizeof(size_t), 4);
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
// The shape is implicit in the grid for <= 3D
|
||||||
|
compute_encoder->setBytes(strides_in.data(), ndim * sizeof(size_t), 2);
|
||||||
|
if (ctype == CopyType::GeneralGeneral) {
|
||||||
|
compute_encoder->setBytes(strides_out.data(), ndim * sizeof(size_t), 3);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if (ndim > MAX_BINARY_SPECIALIZED_DIMS) {
|
||||||
|
compute_encoder->setBytes(
|
||||||
|
&ndim, sizeof(int), (ctype == CopyType::GeneralGeneral) ? 5 : 4);
|
||||||
|
}
|
||||||
|
|
||||||
|
int dim0 = ndim > 0 ? shape[ndim - 1] : 1;
|
||||||
|
int dim1 = ndim > 1 ? shape[ndim - 2] : 1;
|
||||||
|
int rest = in.size() / (dim0 * dim1);
|
||||||
|
|
||||||
|
// NB assuming thread_group_size is a power of 2 larger than 32 x 32
|
||||||
|
NS::UInteger thread_group_size = kernel->maxTotalThreadsPerThreadgroup();
|
||||||
|
if (thread_group_size != 1024) {
|
||||||
|
throw std::runtime_error("[Metal::copy] Must use 1024 sized block");
|
||||||
|
}
|
||||||
|
auto group_dims = get_block_dims(dim0, dim1, rest);
|
||||||
|
MTL::Size grid_dims = MTL::Size(dim0, dim1, rest);
|
||||||
|
compute_encoder->dispatchThreads(grid_dims, group_dims);
|
||||||
|
} else {
|
||||||
|
size_t nthreads = out.data_size();
|
||||||
|
MTL::Size grid_dims = MTL::Size(nthreads, 1, 1);
|
||||||
|
NS::UInteger thread_group_size = kernel->maxTotalThreadsPerThreadgroup();
|
||||||
|
if (thread_group_size > nthreads) {
|
||||||
|
thread_group_size = nthreads;
|
||||||
|
}
|
||||||
|
MTL::Size group_dims = MTL::Size(thread_group_size, 1, 1);
|
||||||
|
compute_encoder->dispatchThreads(grid_dims, group_dims);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace mlx::core
|
||||||
81
mlx/backend/metal/device.h
Normal file
81
mlx/backend/metal/device.h
Normal file
@@ -0,0 +1,81 @@
|
|||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include <Metal/Metal.hpp>
|
||||||
|
#include <functional>
|
||||||
|
#include <mutex>
|
||||||
|
#include <string>
|
||||||
|
#include <unordered_map>
|
||||||
|
|
||||||
|
#include <dlfcn.h>
|
||||||
|
#include <filesystem>
|
||||||
|
|
||||||
|
#include "mlx/device.h"
|
||||||
|
|
||||||
|
namespace fs = std::filesystem;
|
||||||
|
|
||||||
|
namespace mlx::core::metal {
|
||||||
|
|
||||||
|
inline std::string get_colocated_mtllib_path(const std::string& lib_name) {
|
||||||
|
Dl_info info;
|
||||||
|
std::string mtllib_path;
|
||||||
|
std::string lib_ext = lib_name + ".metallib";
|
||||||
|
|
||||||
|
int success = dladdr((void*)get_colocated_mtllib_path, &info);
|
||||||
|
if (success) {
|
||||||
|
auto mtllib = fs::path(info.dli_fname).remove_filename() / lib_ext;
|
||||||
|
mtllib_path = mtllib.c_str();
|
||||||
|
}
|
||||||
|
|
||||||
|
return mtllib_path;
|
||||||
|
}
|
||||||
|
|
||||||
|
class Device {
|
||||||
|
public:
|
||||||
|
Device();
|
||||||
|
Device(const Device&) = delete;
|
||||||
|
Device& operator=(const Device&) = delete;
|
||||||
|
~Device();
|
||||||
|
|
||||||
|
MTL::Device* mtl_device() {
|
||||||
|
return device_;
|
||||||
|
};
|
||||||
|
|
||||||
|
void new_queue(int index);
|
||||||
|
MTL::CommandBuffer* new_command_buffer(int index);
|
||||||
|
MTL::CommandBuffer* get_command_buffer(int index);
|
||||||
|
int get_command_buffer_ops(int index);
|
||||||
|
void increment_command_buffer_ops(int index);
|
||||||
|
void commit_command_buffer(int index);
|
||||||
|
MTL::ComputeCommandEncoder* get_command_encoder(int index);
|
||||||
|
void end_encoding(int index);
|
||||||
|
|
||||||
|
void register_library(
|
||||||
|
const std::string& lib_name,
|
||||||
|
const std::string& lib_path);
|
||||||
|
void register_library(
|
||||||
|
const std::string& lib_name,
|
||||||
|
const std::function<std::string(const std::string&)>& lib_path_func =
|
||||||
|
get_colocated_mtllib_path);
|
||||||
|
|
||||||
|
MTL::ComputePipelineState* get_kernel(
|
||||||
|
const std::string& name,
|
||||||
|
const std::string& lib_name = "mlx");
|
||||||
|
|
||||||
|
MTL::ArgumentEncoder* argument_encoder(
|
||||||
|
const std::vector<MTL::ArgumentDescriptor*>& arg_descs) const;
|
||||||
|
|
||||||
|
private:
|
||||||
|
NS::AutoreleasePool* pool_;
|
||||||
|
MTL::Device* device_;
|
||||||
|
std::unordered_map<int32_t, MTL::CommandQueue*> queue_map_;
|
||||||
|
std::unordered_map<int32_t, std::pair<int, MTL::CommandBuffer*>> buffer_map_;
|
||||||
|
std::unordered_map<int32_t, MTL::ComputeCommandEncoder*> encoder_map_;
|
||||||
|
std::unordered_map<std::string, MTL::ComputePipelineState*> kernel_map_;
|
||||||
|
std::unordered_map<std::string, MTL::Library*> library_map_;
|
||||||
|
std::mutex mtx_;
|
||||||
|
};
|
||||||
|
|
||||||
|
Device& device(mlx::core::Device);
|
||||||
|
NS::AutoreleasePool*& thread_autorelease_pool();
|
||||||
|
|
||||||
|
} // namespace mlx::core::metal
|
||||||
296
mlx/backend/metal/indexing.cpp
Normal file
296
mlx/backend/metal/indexing.cpp
Normal file
@@ -0,0 +1,296 @@
|
|||||||
|
#include <algorithm>
|
||||||
|
#include <cassert>
|
||||||
|
#include <numeric>
|
||||||
|
#include <sstream>
|
||||||
|
|
||||||
|
#include "mlx/backend/common/binary.h"
|
||||||
|
#include "mlx/backend/metal/copy.h"
|
||||||
|
#include "mlx/backend/metal/device.h"
|
||||||
|
#include "mlx/backend/metal/kernels/defines.h"
|
||||||
|
#include "mlx/backend/metal/utils.h"
|
||||||
|
#include "mlx/primitives.h"
|
||||||
|
#include "mlx/utils.h"
|
||||||
|
|
||||||
|
namespace mlx::core {
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
|
||||||
|
static constexpr int METAL_MAX_INDEX_ARRAYS = 10;
|
||||||
|
|
||||||
|
} // namespace
|
||||||
|
|
||||||
|
void Gather::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||||
|
auto& src = inputs[0];
|
||||||
|
int nidx = inputs.size() - 1;
|
||||||
|
|
||||||
|
if (nidx > METAL_MAX_INDEX_ARRAYS) {
|
||||||
|
std::ostringstream msg;
|
||||||
|
msg << "[Gather::eval_gpu] Gathering with more than "
|
||||||
|
<< METAL_MAX_INDEX_ARRAYS << " index arrays not yet supported.";
|
||||||
|
throw std::runtime_error(msg.str());
|
||||||
|
}
|
||||||
|
|
||||||
|
out.set_data(allocator::malloc_or_wait(out.nbytes()));
|
||||||
|
|
||||||
|
auto& s = stream();
|
||||||
|
auto& d = metal::device(s.device);
|
||||||
|
|
||||||
|
std::ostringstream kname;
|
||||||
|
std::string idx_type_name = nidx ? type_to_name(inputs[1]) : "";
|
||||||
|
kname << "gather" << type_to_name(src) << idx_type_name << "_" << nidx;
|
||||||
|
|
||||||
|
auto compute_encoder = d.get_command_encoder(s.index);
|
||||||
|
auto kernel = d.get_kernel(kname.str());
|
||||||
|
|
||||||
|
size_t slice_size = 1;
|
||||||
|
for (auto s : slice_sizes_) {
|
||||||
|
slice_size *= s;
|
||||||
|
}
|
||||||
|
|
||||||
|
size_t ndim = src.ndim();
|
||||||
|
size_t nthreads = out.size();
|
||||||
|
NS::UInteger thread_group_size = kernel->maxTotalThreadsPerThreadgroup();
|
||||||
|
if (thread_group_size > nthreads) {
|
||||||
|
thread_group_size = nthreads;
|
||||||
|
}
|
||||||
|
|
||||||
|
MTL::Size grid_dims = MTL::Size(nthreads, 1, 1);
|
||||||
|
MTL::Size group_dims = MTL::Size(thread_group_size, 1, 1);
|
||||||
|
|
||||||
|
compute_encoder->setComputePipelineState(kernel);
|
||||||
|
|
||||||
|
// Make the argument buffer to store the indices for the
|
||||||
|
// `Indices` struct in kernels/indexing.metal
|
||||||
|
std::vector<MTL::ArgumentDescriptor*> arg_descs(4);
|
||||||
|
arg_descs[0] = MTL::ArgumentDescriptor::argumentDescriptor();
|
||||||
|
arg_descs[0]->setIndex(0);
|
||||||
|
arg_descs[0]->setDataType(MTL::DataType::DataTypePointer);
|
||||||
|
arg_descs[0]->setArrayLength(nidx);
|
||||||
|
|
||||||
|
// Shapes
|
||||||
|
arg_descs[1] = MTL::ArgumentDescriptor::argumentDescriptor();
|
||||||
|
arg_descs[1]->setDataType(MTL::DataType::DataTypePointer);
|
||||||
|
arg_descs[1]->setIndex(nidx + 1);
|
||||||
|
|
||||||
|
// Strides
|
||||||
|
arg_descs[2] = MTL::ArgumentDescriptor::argumentDescriptor();
|
||||||
|
arg_descs[2]->setDataType(MTL::DataType::DataTypePointer);
|
||||||
|
arg_descs[2]->setIndex(nidx + 2);
|
||||||
|
|
||||||
|
// Indices ndim
|
||||||
|
arg_descs[3] = MTL::ArgumentDescriptor::argumentDescriptor();
|
||||||
|
arg_descs[3]->setDataType(MTL::DataType::DataTypeInt);
|
||||||
|
arg_descs[3]->setIndex(nidx + 3);
|
||||||
|
|
||||||
|
// Get the argument encoder
|
||||||
|
auto arg_enc = d.argument_encoder(arg_descs);
|
||||||
|
|
||||||
|
// Allocate and fill buffers for shapes and strides
|
||||||
|
int idx_ndim = nidx ? inputs[1].ndim() : 0;
|
||||||
|
auto idx_shapes_buf = allocator::malloc_or_wait(sizeof(int) * idx_ndim);
|
||||||
|
auto idx_strides_buf = allocator::malloc_or_wait(sizeof(size_t) * idx_ndim);
|
||||||
|
for (int i = 0; i < nidx; ++i) {
|
||||||
|
std::copy(
|
||||||
|
inputs[i + 1].shape().begin(),
|
||||||
|
inputs[i + 1].shape().end(),
|
||||||
|
static_cast<int*>(idx_shapes_buf.raw_ptr()) + i * idx_ndim);
|
||||||
|
std::copy(
|
||||||
|
inputs[i + 1].strides().begin(),
|
||||||
|
inputs[i + 1].strides().end(),
|
||||||
|
static_cast<size_t*>(idx_strides_buf.raw_ptr()) + i * idx_ndim);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Allocate the argument bufer
|
||||||
|
auto arg_buf = allocator::malloc_or_wait(arg_enc->encodedLength());
|
||||||
|
|
||||||
|
// Register data with the encoder
|
||||||
|
arg_enc->setArgumentBuffer(static_cast<MTL::Buffer*>(arg_buf.ptr()), 0);
|
||||||
|
for (int i = 0; i < nidx; ++i) {
|
||||||
|
set_array_buffer(compute_encoder, arg_enc, inputs[i + 1], i);
|
||||||
|
}
|
||||||
|
arg_enc->setBuffer(
|
||||||
|
static_cast<MTL::Buffer*>(idx_shapes_buf.ptr()), 0, nidx + 1);
|
||||||
|
compute_encoder->useResource(
|
||||||
|
static_cast<MTL::Buffer*>(idx_shapes_buf.ptr()), MTL::ResourceUsageRead);
|
||||||
|
arg_enc->setBuffer(
|
||||||
|
static_cast<MTL::Buffer*>(idx_strides_buf.ptr()), 0, nidx + 2);
|
||||||
|
compute_encoder->useResource(
|
||||||
|
static_cast<MTL::Buffer*>(idx_strides_buf.ptr()), MTL::ResourceUsageRead);
|
||||||
|
*static_cast<int*>(arg_enc->constantData(nidx + 3)) = idx_ndim;
|
||||||
|
|
||||||
|
// Set all the buffers
|
||||||
|
set_array_buffer(compute_encoder, src, 0);
|
||||||
|
compute_encoder->setBuffer(static_cast<MTL::Buffer*>(arg_buf.ptr()), 0, 1);
|
||||||
|
set_array_buffer(compute_encoder, out, 2);
|
||||||
|
compute_encoder->setBytes(src.shape().data(), ndim * sizeof(int), 3);
|
||||||
|
compute_encoder->setBytes(src.strides().data(), ndim * sizeof(size_t), 4);
|
||||||
|
compute_encoder->setBytes(&ndim, sizeof(size_t), 5);
|
||||||
|
compute_encoder->setBytes(slice_sizes_.data(), ndim * sizeof(int), 6);
|
||||||
|
compute_encoder->setBytes(&slice_size, sizeof(size_t), 7);
|
||||||
|
compute_encoder->setBytes(axes_.data(), nidx * sizeof(int), 8);
|
||||||
|
|
||||||
|
compute_encoder->dispatchThreads(grid_dims, group_dims);
|
||||||
|
|
||||||
|
// Cleanup temporaries
|
||||||
|
arg_enc->release();
|
||||||
|
d.get_command_buffer(s.index)->addCompletedHandler(
|
||||||
|
[arg_buf, idx_shapes_buf, idx_strides_buf](MTL::CommandBuffer*) {
|
||||||
|
allocator::free(arg_buf);
|
||||||
|
allocator::free(idx_shapes_buf);
|
||||||
|
allocator::free(idx_strides_buf);
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
void Scatter::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||||
|
if (size_of(out.dtype()) == 8) {
|
||||||
|
std::ostringstream msg;
|
||||||
|
msg << "[Scatter::eval_gpu] Does not support " << out.dtype();
|
||||||
|
throw std::invalid_argument(msg.str());
|
||||||
|
}
|
||||||
|
|
||||||
|
int nidx = axes_.size();
|
||||||
|
if (nidx > METAL_MAX_INDEX_ARRAYS) {
|
||||||
|
std::ostringstream msg;
|
||||||
|
msg << "[Scatter::eval_gpu] Gathering with more than "
|
||||||
|
<< METAL_MAX_INDEX_ARRAYS << " index arrays not yet supported.";
|
||||||
|
throw std::runtime_error(msg.str());
|
||||||
|
}
|
||||||
|
|
||||||
|
// Copy src into out
|
||||||
|
auto copy_type =
|
||||||
|
inputs[0].data_size() == 1 ? CopyType::Scalar : CopyType::General;
|
||||||
|
copy_gpu(inputs[0], out, copy_type);
|
||||||
|
|
||||||
|
// Get stream
|
||||||
|
auto& s = stream();
|
||||||
|
auto& d = metal::device(s.device);
|
||||||
|
|
||||||
|
// Get kernel name
|
||||||
|
std::ostringstream kname;
|
||||||
|
std::string idx_type_name = nidx ? type_to_name(inputs[1]) : "";
|
||||||
|
kname << "scatter" << type_to_name(out) << idx_type_name;
|
||||||
|
switch (reduce_type_) {
|
||||||
|
case Scatter::None:
|
||||||
|
kname << "_none";
|
||||||
|
break;
|
||||||
|
case Scatter::Sum:
|
||||||
|
kname << "_sum";
|
||||||
|
break;
|
||||||
|
case Scatter::Prod:
|
||||||
|
kname << "_prod";
|
||||||
|
break;
|
||||||
|
case Scatter::Max:
|
||||||
|
kname << "_max";
|
||||||
|
break;
|
||||||
|
case Scatter::Min:
|
||||||
|
kname << "_min";
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
kname << "_" << nidx;
|
||||||
|
|
||||||
|
auto compute_encoder = d.get_command_encoder(s.index);
|
||||||
|
auto kernel = d.get_kernel(kname.str());
|
||||||
|
|
||||||
|
auto& upd = inputs.back();
|
||||||
|
size_t nthreads = upd.size();
|
||||||
|
NS::UInteger thread_group_size = kernel->maxTotalThreadsPerThreadgroup();
|
||||||
|
if (thread_group_size > nthreads) {
|
||||||
|
thread_group_size = nthreads;
|
||||||
|
}
|
||||||
|
|
||||||
|
MTL::Size grid_dims = MTL::Size(nthreads, 1, 1);
|
||||||
|
MTL::Size group_dims = MTL::Size(thread_group_size, 1, 1);
|
||||||
|
|
||||||
|
compute_encoder->setComputePipelineState(kernel);
|
||||||
|
|
||||||
|
// Make the argument buffer to store the indices for the
|
||||||
|
// `Indices` struct in kernels/indexing.metal
|
||||||
|
std::vector<MTL::ArgumentDescriptor*> arg_descs(4);
|
||||||
|
arg_descs[0] = MTL::ArgumentDescriptor::argumentDescriptor();
|
||||||
|
arg_descs[0]->setIndex(0);
|
||||||
|
arg_descs[0]->setDataType(MTL::DataType::DataTypePointer);
|
||||||
|
arg_descs[0]->setArrayLength(nidx);
|
||||||
|
|
||||||
|
// Shapes
|
||||||
|
arg_descs[1] = MTL::ArgumentDescriptor::argumentDescriptor();
|
||||||
|
arg_descs[1]->setDataType(MTL::DataType::DataTypePointer);
|
||||||
|
arg_descs[1]->setIndex(nidx + 1);
|
||||||
|
|
||||||
|
// Strides
|
||||||
|
arg_descs[2] = MTL::ArgumentDescriptor::argumentDescriptor();
|
||||||
|
arg_descs[2]->setDataType(MTL::DataType::DataTypePointer);
|
||||||
|
arg_descs[2]->setIndex(nidx + 2);
|
||||||
|
|
||||||
|
// Indices ndim
|
||||||
|
arg_descs[3] = MTL::ArgumentDescriptor::argumentDescriptor();
|
||||||
|
arg_descs[3]->setDataType(MTL::DataType::DataTypeInt);
|
||||||
|
arg_descs[3]->setIndex(nidx + 3);
|
||||||
|
|
||||||
|
// Get the argument encoder
|
||||||
|
auto arg_enc = d.argument_encoder(arg_descs);
|
||||||
|
|
||||||
|
// Allocate and fill buffers for shapes and strides
|
||||||
|
int idx_ndim = nidx ? inputs[1].ndim() : 0;
|
||||||
|
auto idx_shapes_buf = allocator::malloc_or_wait(sizeof(int) * idx_ndim);
|
||||||
|
auto idx_strides_buf = allocator::malloc_or_wait(sizeof(size_t) * idx_ndim);
|
||||||
|
for (int i = 0; i < nidx; ++i) {
|
||||||
|
std::copy(
|
||||||
|
inputs[i + 1].shape().begin(),
|
||||||
|
inputs[i + 1].shape().end(),
|
||||||
|
static_cast<int*>(idx_shapes_buf.raw_ptr()) + i * idx_ndim);
|
||||||
|
std::copy(
|
||||||
|
inputs[i + 1].strides().begin(),
|
||||||
|
inputs[i + 1].strides().end(),
|
||||||
|
static_cast<size_t*>(idx_strides_buf.raw_ptr()) + i * idx_ndim);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Allocate the argument bufer
|
||||||
|
auto arg_buf = allocator::malloc_or_wait(arg_enc->encodedLength());
|
||||||
|
|
||||||
|
// Register data with the encoder
|
||||||
|
arg_enc->setArgumentBuffer(static_cast<MTL::Buffer*>(arg_buf.ptr()), 0);
|
||||||
|
for (int i = 0; i < nidx; ++i) {
|
||||||
|
set_array_buffer(compute_encoder, arg_enc, inputs[i + 1], i);
|
||||||
|
}
|
||||||
|
arg_enc->setBuffer(
|
||||||
|
static_cast<MTL::Buffer*>(idx_shapes_buf.ptr()), 0, nidx + 1);
|
||||||
|
compute_encoder->useResource(
|
||||||
|
static_cast<MTL::Buffer*>(idx_shapes_buf.ptr()), MTL::ResourceUsageRead);
|
||||||
|
arg_enc->setBuffer(
|
||||||
|
static_cast<MTL::Buffer*>(idx_strides_buf.ptr()), 0, nidx + 2);
|
||||||
|
compute_encoder->useResource(
|
||||||
|
static_cast<MTL::Buffer*>(idx_strides_buf.ptr()), MTL::ResourceUsageRead);
|
||||||
|
*static_cast<int*>(arg_enc->constantData(nidx + 3)) = idx_ndim;
|
||||||
|
|
||||||
|
compute_encoder->setBuffer(static_cast<MTL::Buffer*>(arg_buf.ptr()), 0, 0);
|
||||||
|
size_t upd_ndim = upd.ndim();
|
||||||
|
size_t upd_size = 1;
|
||||||
|
for (int i = idx_ndim; i < upd.ndim(); ++i) {
|
||||||
|
upd_size *= upd.shape(i);
|
||||||
|
}
|
||||||
|
set_array_buffer(compute_encoder, upd, 1);
|
||||||
|
set_array_buffer(compute_encoder, out, 2);
|
||||||
|
compute_encoder->setBytes(upd.shape().data(), upd_ndim * sizeof(int), 3);
|
||||||
|
compute_encoder->setBytes(upd.strides().data(), upd_ndim * sizeof(size_t), 4);
|
||||||
|
compute_encoder->setBytes(&upd_ndim, sizeof(size_t), 5);
|
||||||
|
compute_encoder->setBytes(&upd_size, sizeof(size_t), 6);
|
||||||
|
|
||||||
|
size_t out_ndim = out.ndim();
|
||||||
|
compute_encoder->setBytes(out.shape().data(), out_ndim * sizeof(int), 7);
|
||||||
|
compute_encoder->setBytes(out.strides().data(), out_ndim * sizeof(size_t), 8);
|
||||||
|
compute_encoder->setBytes(&out_ndim, sizeof(size_t), 9);
|
||||||
|
compute_encoder->setBytes(axes_.data(), axes_.size() * sizeof(int), 10);
|
||||||
|
|
||||||
|
compute_encoder->dispatchThreads(grid_dims, group_dims);
|
||||||
|
|
||||||
|
// Cleanup temporaries
|
||||||
|
arg_enc->release();
|
||||||
|
d.get_command_buffer(s.index)->addCompletedHandler(
|
||||||
|
[arg_buf, idx_shapes_buf, idx_strides_buf](MTL::CommandBuffer*) {
|
||||||
|
allocator::free(arg_buf);
|
||||||
|
allocator::free(idx_shapes_buf);
|
||||||
|
allocator::free(idx_strides_buf);
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace mlx::core
|
||||||
320
mlx/backend/metal/kernels/atomic.h
Normal file
320
mlx/backend/metal/kernels/atomic.h
Normal file
@@ -0,0 +1,320 @@
|
|||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include <metal_atomic>
|
||||||
|
#include <metal_stdlib>
|
||||||
|
#include "mlx/backend/metal/kernels/bf16.h"
|
||||||
|
|
||||||
|
using namespace metal;
|
||||||
|
|
||||||
|
///////////////////////////////////////////////////////////////////////////////
|
||||||
|
// Atomic utils
|
||||||
|
///////////////////////////////////////////////////////////////////////////////
|
||||||
|
|
||||||
|
#pragma METAL internals : enable
|
||||||
|
template <typename T>
|
||||||
|
constexpr constant bool is_metal_atomic = _disjunction<
|
||||||
|
is_same<T, int>,
|
||||||
|
is_same<T, uint>,
|
||||||
|
is_same<T, ulong>,
|
||||||
|
is_same<T, float>>::value;
|
||||||
|
|
||||||
|
#pragma METAL internals : disable
|
||||||
|
|
||||||
|
template <typename T, typename = void>
|
||||||
|
struct mlx_atomic {
|
||||||
|
atomic<uint> val;
|
||||||
|
};
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
struct mlx_atomic<T, enable_if_t<is_metal_atomic<T>>> {
|
||||||
|
atomic<T> val;
|
||||||
|
};
|
||||||
|
|
||||||
|
///////////////////////////////////////////////////////////////////////////////
|
||||||
|
// Native metal atomics
|
||||||
|
///////////////////////////////////////////////////////////////////////////////
|
||||||
|
|
||||||
|
template <typename T, enable_if_t<is_metal_atomic<T>, bool> = true>
|
||||||
|
METAL_FUNC T
|
||||||
|
mlx_atomic_load_explicit(device mlx_atomic<T>* object, int offset) {
|
||||||
|
return atomic_load_explicit(&(object[offset].val), memory_order_relaxed);
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename T, enable_if_t<is_metal_atomic<T>, bool> = true>
|
||||||
|
METAL_FUNC void
|
||||||
|
mlx_atomic_store_explicit(device mlx_atomic<T>* object, T val, int offset) {
|
||||||
|
atomic_store_explicit(&(object[offset].val), val, memory_order_relaxed);
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename T, enable_if_t<is_metal_atomic<T>, bool> = true>
|
||||||
|
METAL_FUNC void
|
||||||
|
mlx_atomic_fetch_and_explicit(device mlx_atomic<T>* object, T val, int offset) {
|
||||||
|
atomic_fetch_and_explicit(&(object[offset].val), val, memory_order_relaxed);
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename T, enable_if_t<is_metal_atomic<T>, bool> = true>
|
||||||
|
METAL_FUNC void
|
||||||
|
mlx_atomic_fetch_or_explicit(device mlx_atomic<T>* object, T val, int offset) {
|
||||||
|
atomic_fetch_or_explicit(&(object[offset].val), val, memory_order_relaxed);
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename T, enable_if_t<is_metal_atomic<T>, bool> = true>
|
||||||
|
METAL_FUNC void
|
||||||
|
mlx_atomic_fetch_min_explicit(device mlx_atomic<T>* object, T val, int offset) {
|
||||||
|
atomic_fetch_min_explicit(&(object[offset].val), val, memory_order_relaxed);
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename T, enable_if_t<is_metal_atomic<T>, bool> = true>
|
||||||
|
METAL_FUNC void
|
||||||
|
mlx_atomic_fetch_max_explicit(device mlx_atomic<T>* object, T val, int offset) {
|
||||||
|
atomic_fetch_max_explicit(&(object[offset].val), val, memory_order_relaxed);
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename T, enable_if_t<is_metal_atomic<T>, bool> = true>
|
||||||
|
METAL_FUNC void
|
||||||
|
mlx_atomic_fetch_add_explicit(device mlx_atomic<T>* object, T val, int offset) {
|
||||||
|
atomic_fetch_add_explicit(&(object[offset].val), val, memory_order_relaxed);
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename T, enable_if_t<is_metal_atomic<T>, bool> = true>
|
||||||
|
METAL_FUNC void
|
||||||
|
mlx_atomic_fetch_mul_explicit(device mlx_atomic<T>* object, T val, int offset) {
|
||||||
|
T expected = mlx_atomic_load_explicit(object, offset);
|
||||||
|
while (!mlx_atomic_compare_exchange_weak_explicit(
|
||||||
|
object, &expected, val * expected, offset)) {
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename T, enable_if_t<is_metal_atomic<T>, bool> = true>
|
||||||
|
METAL_FUNC bool mlx_atomic_compare_exchange_weak_explicit(
|
||||||
|
device mlx_atomic<T>* object,
|
||||||
|
thread T* expected,
|
||||||
|
T val,
|
||||||
|
int offset) {
|
||||||
|
return atomic_compare_exchange_weak_explicit(
|
||||||
|
&(object[offset].val),
|
||||||
|
expected,
|
||||||
|
val,
|
||||||
|
memory_order_relaxed,
|
||||||
|
memory_order_relaxed);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Specialization for float since it does not atomic_fetch_min_explicit
|
||||||
|
template <>
|
||||||
|
METAL_FUNC void mlx_atomic_fetch_min_explicit<float>(
|
||||||
|
device mlx_atomic<float>* object,
|
||||||
|
float val,
|
||||||
|
int offset) {
|
||||||
|
float expected = mlx_atomic_load_explicit(object, offset);
|
||||||
|
while (val < expected) {
|
||||||
|
if (mlx_atomic_compare_exchange_weak_explicit(
|
||||||
|
object, &expected, val, offset)) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Specialization for float since it does not atomic_fetch_max_explicit
|
||||||
|
template <>
|
||||||
|
METAL_FUNC void mlx_atomic_fetch_max_explicit<float>(
|
||||||
|
device mlx_atomic<float>* object,
|
||||||
|
float val,
|
||||||
|
int offset) {
|
||||||
|
float expected = mlx_atomic_load_explicit(object, offset);
|
||||||
|
while (val > expected) {
|
||||||
|
if (mlx_atomic_compare_exchange_weak_explicit(
|
||||||
|
object, &expected, val, offset)) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
///////////////////////////////////////////////////////////////////////////////
|
||||||
|
// Custom atomics
|
||||||
|
///////////////////////////////////////////////////////////////////////////////
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
constexpr constant uint packing_size = sizeof(uint) / sizeof(T);
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
union uint_or_packed {
|
||||||
|
T val[packing_size<T>];
|
||||||
|
uint bits;
|
||||||
|
};
|
||||||
|
|
||||||
|
template <typename T, typename Op>
|
||||||
|
struct mlx_atomic_update_helper {
|
||||||
|
uint operator()(uint_or_packed<T> init, T update, int elem_offset) {
|
||||||
|
Op op;
|
||||||
|
init.val[elem_offset] = op(update, init.val[elem_offset]);
|
||||||
|
return init.bits;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
template <typename T, typename Op>
|
||||||
|
METAL_FUNC void mlx_atomic_update_and_store(
|
||||||
|
device mlx_atomic<T>* object,
|
||||||
|
T update,
|
||||||
|
int offset) {
|
||||||
|
int pack_offset = offset / packing_size<T>;
|
||||||
|
int elem_offset = offset % packing_size<T>;
|
||||||
|
|
||||||
|
mlx_atomic_update_helper<T, Op> helper;
|
||||||
|
uint_or_packed<T> expected;
|
||||||
|
expected.bits =
|
||||||
|
atomic_load_explicit(&(object[pack_offset].val), memory_order_relaxed);
|
||||||
|
|
||||||
|
while (Op::condition(update, expected.val[elem_offset]) &&
|
||||||
|
!mlx_atomic_compare_exchange_weak_explicit(
|
||||||
|
object,
|
||||||
|
&(expected.bits),
|
||||||
|
helper(expected, update, elem_offset),
|
||||||
|
pack_offset)) {
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
struct __None {
|
||||||
|
static bool condition(T a, T b) {
|
||||||
|
#pragma unused(a)
|
||||||
|
#pragma unused(b)
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
T operator()(T a, T b) {
|
||||||
|
#pragma unused(b)
|
||||||
|
return a;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
struct __Add {
|
||||||
|
static bool condition(T a, T b) {
|
||||||
|
#pragma unused(a)
|
||||||
|
#pragma unused(b)
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
T operator()(T a, T b) {
|
||||||
|
return a + b;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
struct __Mul {
|
||||||
|
static bool condition(T a, T b) {
|
||||||
|
#pragma unused(a)
|
||||||
|
return b != 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
T operator()(T a, T b) {
|
||||||
|
return a * b;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
struct __Max {
|
||||||
|
static bool condition(T a, T b) {
|
||||||
|
return a > b;
|
||||||
|
}
|
||||||
|
|
||||||
|
T operator()(T a, T b) {
|
||||||
|
return max(a, b);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
struct __Min {
|
||||||
|
static bool condition(T a, T b) {
|
||||||
|
return a < b;
|
||||||
|
}
|
||||||
|
|
||||||
|
T operator()(T a, T b) {
|
||||||
|
return min(a, b);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace
|
||||||
|
|
||||||
|
template <typename T, enable_if_t<!is_metal_atomic<T>, bool> = true>
|
||||||
|
METAL_FUNC T
|
||||||
|
mlx_atomic_load_explicit(device mlx_atomic<T>* object, int offset) {
|
||||||
|
int pack_offset = offset / sizeof(T);
|
||||||
|
int elem_offset = offset % sizeof(T);
|
||||||
|
uint_or_packed<T> packed_val;
|
||||||
|
packed_val.bits =
|
||||||
|
atomic_load_explicit(&(object[pack_offset].val), memory_order_relaxed);
|
||||||
|
return packed_val.val[elem_offset];
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename T, enable_if_t<!is_metal_atomic<T>, bool> = true>
|
||||||
|
METAL_FUNC void
|
||||||
|
mlx_atomic_store_explicit(device mlx_atomic<T>* object, T val, int offset) {
|
||||||
|
mlx_atomic_update_and_store<T, __None<T>>(object, val, offset);
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename T, enable_if_t<!is_metal_atomic<T>, bool> = true>
|
||||||
|
METAL_FUNC void
|
||||||
|
mlx_atomic_fetch_and_explicit(device mlx_atomic<T>* object, T val, int offset) {
|
||||||
|
int pack_offset = offset / packing_size<T>;
|
||||||
|
int elem_offset = offset % packing_size<T>;
|
||||||
|
uint_or_packed<T> identity;
|
||||||
|
identity.bits = __UINT32_MAX__;
|
||||||
|
identity.val[elem_offset] = val;
|
||||||
|
|
||||||
|
atomic_fetch_and_explicit(
|
||||||
|
&(object[pack_offset].val), identity.bits, memory_order_relaxed);
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename T, enable_if_t<!is_metal_atomic<T>, bool> = true>
|
||||||
|
METAL_FUNC void
|
||||||
|
mlx_atomic_fetch_or_explicit(device mlx_atomic<T>* object, T val, int offset) {
|
||||||
|
int pack_offset = offset / packing_size<T>;
|
||||||
|
int elem_offset = offset % packing_size<T>;
|
||||||
|
uint_or_packed<T> identity;
|
||||||
|
identity.bits = 0;
|
||||||
|
identity.val[elem_offset] = val;
|
||||||
|
|
||||||
|
atomic_fetch_or_explicit(
|
||||||
|
&(object[pack_offset].val), identity.bits, memory_order_relaxed);
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename T, enable_if_t<!is_metal_atomic<T>, bool> = true>
|
||||||
|
METAL_FUNC void
|
||||||
|
mlx_atomic_fetch_min_explicit(device mlx_atomic<T>* object, T val, int offset) {
|
||||||
|
mlx_atomic_update_and_store<T, __Min<T>>(object, val, offset);
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename T, enable_if_t<!is_metal_atomic<T>, bool> = true>
|
||||||
|
METAL_FUNC void
|
||||||
|
mlx_atomic_fetch_max_explicit(device mlx_atomic<T>* object, T val, int offset) {
|
||||||
|
mlx_atomic_update_and_store<T, __Max<T>>(object, val, offset);
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename T, enable_if_t<!is_metal_atomic<T>, bool> = true>
|
||||||
|
METAL_FUNC void
|
||||||
|
mlx_atomic_fetch_add_explicit(device mlx_atomic<T>* object, T val, int offset) {
|
||||||
|
mlx_atomic_update_and_store<T, __Add<T>>(object, val, offset);
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename T, enable_if_t<!is_metal_atomic<T>, bool> = true>
|
||||||
|
METAL_FUNC void
|
||||||
|
mlx_atomic_fetch_mul_explicit(device mlx_atomic<T>* object, T val, int offset) {
|
||||||
|
mlx_atomic_update_and_store<T, __Mul<T>>(object, val, offset);
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename T, enable_if_t<!is_metal_atomic<T>, bool> = true>
|
||||||
|
METAL_FUNC bool mlx_atomic_compare_exchange_weak_explicit(
|
||||||
|
device mlx_atomic<T>* object,
|
||||||
|
thread uint* expected,
|
||||||
|
uint val,
|
||||||
|
int offset) {
|
||||||
|
return atomic_compare_exchange_weak_explicit(
|
||||||
|
&(object[offset].val),
|
||||||
|
expected,
|
||||||
|
val,
|
||||||
|
memory_order_relaxed,
|
||||||
|
memory_order_relaxed);
|
||||||
|
}
|
||||||
315
mlx/backend/metal/kernels/bf16.h
Normal file
315
mlx/backend/metal/kernels/bf16.h
Normal file
@@ -0,0 +1,315 @@
|
|||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include <metal_stdlib>
|
||||||
|
|
||||||
|
using namespace metal;
|
||||||
|
|
||||||
|
#if defined(__HAVE_BFLOAT__)
|
||||||
|
|
||||||
|
typedef bfloat bfloat16_t;
|
||||||
|
|
||||||
|
#else
|
||||||
|
|
||||||
|
/////////////////////////////////////////////////////////////////////////////
|
||||||
|
// Helpers
|
||||||
|
/////////////////////////////////////////////////////////////////////////////
|
||||||
|
|
||||||
|
constexpr METAL_FUNC uint16_t float_to_bfloat_bits(float x) {
|
||||||
|
// Check for nan
|
||||||
|
if ((as_type<uint32_t>(x) & ~_fp_encoding_traits<float>::sign_mask) >
|
||||||
|
_fp_encoding_traits<float>::inf_mask) {
|
||||||
|
return uint16_t(as_type<uint32_t>(0x7FC0));
|
||||||
|
}
|
||||||
|
// Take bits
|
||||||
|
uint32_t float_bits = as_type<uint32_t>(x);
|
||||||
|
|
||||||
|
// Round to nearest even
|
||||||
|
float_bits += ((float_bits >> 16) & 1) + as_type<uint32_t>(0x7FFF);
|
||||||
|
|
||||||
|
// Take upper 16 bits
|
||||||
|
return float_bits >> 16;
|
||||||
|
}
|
||||||
|
|
||||||
|
constexpr METAL_FUNC float bfloat_bits_to_float(uint16_t x) {
|
||||||
|
// Upper 16 bits are the data and lower 16 bits are 0s
|
||||||
|
return as_type<float>((uint32_t)x << 16);
|
||||||
|
}
|
||||||
|
|
||||||
|
struct _MLX_BFloat16;
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
static constexpr constant bool can_convert_to_bfloat =
|
||||||
|
!is_same_v<T, _MLX_BFloat16> && is_convertible_v<T, float>;
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
static constexpr constant bool can_convert_from_bfloat =
|
||||||
|
!is_same_v<T, _MLX_BFloat16> && is_convertible_v<float, T>;
|
||||||
|
|
||||||
|
/////////////////////////////////////////////////////////////////////////////
|
||||||
|
// Bfloat struct
|
||||||
|
/////////////////////////////////////////////////////////////////////////////
|
||||||
|
|
||||||
|
struct _MLX_BFloat16 {
|
||||||
|
/////////////////////////////////////////////////////////////////////////////
|
||||||
|
// Constructors
|
||||||
|
uint16_t bits_;
|
||||||
|
_MLX_BFloat16() thread = default;
|
||||||
|
_MLX_BFloat16() threadgroup = default;
|
||||||
|
_MLX_BFloat16() device = default;
|
||||||
|
_MLX_BFloat16() constant = default;
|
||||||
|
|
||||||
|
struct bits_to_bfloat_struct {};
|
||||||
|
static constexpr METAL_FUNC bits_to_bfloat_struct bits_to_bfloat() {
|
||||||
|
return bits_to_bfloat_struct();
|
||||||
|
}
|
||||||
|
constexpr METAL_FUNC _MLX_BFloat16(uint16_t bits, bits_to_bfloat_struct)
|
||||||
|
: bits_(bits) {}
|
||||||
|
|
||||||
|
/////////////////////////////////////////////////////////////////////////////
|
||||||
|
// Conversions to bfloat
|
||||||
|
|
||||||
|
template <
|
||||||
|
typename T,
|
||||||
|
typename = typename enable_if<can_convert_to_bfloat<T>>::type>
|
||||||
|
constexpr METAL_FUNC _MLX_BFloat16(T x) thread
|
||||||
|
: bits_(float_to_bfloat_bits(static_cast<float>(x))) {}
|
||||||
|
|
||||||
|
template <
|
||||||
|
typename T,
|
||||||
|
typename = typename enable_if<can_convert_to_bfloat<T>>::type>
|
||||||
|
constexpr METAL_FUNC _MLX_BFloat16(T x) threadgroup
|
||||||
|
: bits_(float_to_bfloat_bits(static_cast<float>(x))) {}
|
||||||
|
|
||||||
|
template <
|
||||||
|
typename T,
|
||||||
|
typename = typename enable_if<can_convert_to_bfloat<T>>::type>
|
||||||
|
constexpr METAL_FUNC _MLX_BFloat16(T x) device
|
||||||
|
: bits_(float_to_bfloat_bits(static_cast<float>(x))) {}
|
||||||
|
|
||||||
|
template <
|
||||||
|
typename T,
|
||||||
|
typename = typename enable_if<can_convert_to_bfloat<T>>::type>
|
||||||
|
constexpr METAL_FUNC _MLX_BFloat16(T x) constant
|
||||||
|
: bits_(float_to_bfloat_bits(static_cast<float>(x))) {}
|
||||||
|
|
||||||
|
/////////////////////////////////////////////////////////////////////////////
|
||||||
|
// Conversions from bfloat
|
||||||
|
|
||||||
|
template <
|
||||||
|
typename T,
|
||||||
|
typename = typename enable_if<can_convert_from_bfloat<T>>::type>
|
||||||
|
constexpr METAL_FUNC operator T() const thread {
|
||||||
|
return static_cast<T>(bfloat_bits_to_float(bits_));
|
||||||
|
}
|
||||||
|
|
||||||
|
template <
|
||||||
|
typename T,
|
||||||
|
typename = typename enable_if<can_convert_from_bfloat<T>>::type>
|
||||||
|
constexpr METAL_FUNC operator T() const threadgroup {
|
||||||
|
return static_cast<T>(bfloat_bits_to_float(bits_));
|
||||||
|
}
|
||||||
|
|
||||||
|
template <
|
||||||
|
typename T,
|
||||||
|
typename = typename enable_if<can_convert_from_bfloat<T>>::type>
|
||||||
|
constexpr METAL_FUNC operator T() const device {
|
||||||
|
return static_cast<T>(bfloat_bits_to_float(bits_));
|
||||||
|
}
|
||||||
|
|
||||||
|
template <
|
||||||
|
typename T,
|
||||||
|
typename = typename enable_if<can_convert_from_bfloat<T>>::type>
|
||||||
|
constexpr METAL_FUNC operator T() const constant {
|
||||||
|
return static_cast<T>(bfloat_bits_to_float(bits_));
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
/////////////////////////////////////////////////////////////////////////////
|
||||||
|
// Bfloat operators
|
||||||
|
/////////////////////////////////////////////////////////////////////////////
|
||||||
|
|
||||||
|
/////////////////////////////////////////////////////////////////////////////
|
||||||
|
// Unary ops
|
||||||
|
constexpr METAL_FUNC _MLX_BFloat16 operator-(_MLX_BFloat16 x) {
|
||||||
|
return -static_cast<float>(x);
|
||||||
|
}
|
||||||
|
|
||||||
|
/////////////////////////////////////////////////////////////////////////////
|
||||||
|
// Binary operators
|
||||||
|
#define bfloat_binop_base(__op__, __operator__, otype, atype, btype, ctype) \
|
||||||
|
constexpr METAL_FUNC otype __operator__(atype lhs, btype rhs) { \
|
||||||
|
return static_cast<ctype>(lhs) __op__ static_cast<ctype>(rhs); \
|
||||||
|
}
|
||||||
|
|
||||||
|
#define bfloat_binop_helper(__op__, __operator__, otype, itype, ctype) \
|
||||||
|
constexpr METAL_FUNC otype __operator__(_MLX_BFloat16 lhs, itype rhs) { \
|
||||||
|
return static_cast<ctype>(lhs) __op__ static_cast<ctype>(rhs); \
|
||||||
|
} \
|
||||||
|
constexpr METAL_FUNC otype __operator__(itype lhs, _MLX_BFloat16 rhs) { \
|
||||||
|
return static_cast<ctype>(lhs) __op__ static_cast<ctype>(rhs); \
|
||||||
|
}
|
||||||
|
|
||||||
|
/////////////////////////////////////////////////////////////////////////////
|
||||||
|
// Arithmetic Operators
|
||||||
|
#define bfloat_binop(_op_, _operator_) \
|
||||||
|
bfloat_binop_base( \
|
||||||
|
_op_, _operator_, _MLX_BFloat16, _MLX_BFloat16, _MLX_BFloat16, float); \
|
||||||
|
bfloat_binop_helper(_op_, _operator_, float, float, float); \
|
||||||
|
bfloat_binop_helper(_op_, _operator_, float, half, float); \
|
||||||
|
bfloat_binop_helper(_op_, _operator_, _MLX_BFloat16, int32_t, float); \
|
||||||
|
bfloat_binop_helper(_op_, _operator_, _MLX_BFloat16, uint32_t, float); \
|
||||||
|
bfloat_binop_helper(_op_, _operator_, _MLX_BFloat16, int64_t, float); \
|
||||||
|
bfloat_binop_helper(_op_, _operator_, _MLX_BFloat16, uint64_t, float);
|
||||||
|
|
||||||
|
bfloat_binop(+, operator+);
|
||||||
|
bfloat_binop(-, operator-);
|
||||||
|
bfloat_binop(*, operator*);
|
||||||
|
bfloat_binop(/, operator/);
|
||||||
|
|
||||||
|
/////////////////////////////////////////////////////////////////////////////
|
||||||
|
// Comparison ops
|
||||||
|
#define bfloat_compop(__op__, __operator__) \
|
||||||
|
bfloat_binop_base( \
|
||||||
|
__op__, __operator__, bool, _MLX_BFloat16, _MLX_BFloat16, float); \
|
||||||
|
bfloat_binop_helper(__op__, __operator__, bool, float, float); \
|
||||||
|
bfloat_binop_helper(__op__, __operator__, bool, half, float); \
|
||||||
|
bfloat_binop_helper(__op__, __operator__, bool, int32_t, float); \
|
||||||
|
bfloat_binop_helper(__op__, __operator__, bool, uint32_t, float); \
|
||||||
|
bfloat_binop_helper(__op__, __operator__, bool, int64_t, float); \
|
||||||
|
bfloat_binop_helper(__op__, __operator__, bool, uint64_t, float);
|
||||||
|
|
||||||
|
bfloat_compop(>, operator>);
|
||||||
|
bfloat_compop(<, operator<);
|
||||||
|
bfloat_compop(>=, operator>=);
|
||||||
|
bfloat_compop(<=, operator<=);
|
||||||
|
bfloat_compop(==, operator==);
|
||||||
|
bfloat_compop(!=, operator!=);
|
||||||
|
|
||||||
|
#undef bfloat_compop
|
||||||
|
#undef bfloat_binop_base
|
||||||
|
#undef bfloat_binop_helper
|
||||||
|
#undef bfloat_binop
|
||||||
|
|
||||||
|
/////////////////////////////////////////////////////////////////////////////
|
||||||
|
// Inplace Operators
|
||||||
|
#define bfloat_inplace_op_helper(__op__, __operator__, itype, addr_space) \
|
||||||
|
constexpr METAL_FUNC addr_space _MLX_BFloat16& __operator__( \
|
||||||
|
addr_space _MLX_BFloat16& lhs, itype rhs) { \
|
||||||
|
lhs = static_cast<float>(lhs) __op__ static_cast<float>(rhs); \
|
||||||
|
return lhs; \
|
||||||
|
} \
|
||||||
|
constexpr METAL_FUNC addr_space itype& __operator__( \
|
||||||
|
addr_space itype& lhs, _MLX_BFloat16 rhs) { \
|
||||||
|
lhs = static_cast<float>(lhs) __op__ static_cast<float>(rhs); \
|
||||||
|
return lhs; \
|
||||||
|
}
|
||||||
|
|
||||||
|
#define bfloat_inplace_op_addr_space_helper(__op__, __operator__, itype) \
|
||||||
|
bfloat_inplace_op_helper(__op__, __operator__, itype, device); \
|
||||||
|
bfloat_inplace_op_helper(__op__, __operator__, itype, thread); \
|
||||||
|
bfloat_inplace_op_helper(__op__, __operator__, itype, threadgroup);
|
||||||
|
|
||||||
|
#define bfloat_inplace_op(itype) \
|
||||||
|
bfloat_inplace_op_addr_space_helper(+, operator+=, itype); \
|
||||||
|
bfloat_inplace_op_addr_space_helper(-, operator-=, itype); \
|
||||||
|
bfloat_inplace_op_addr_space_helper(*, operator*=, itype); \
|
||||||
|
bfloat_inplace_op_addr_space_helper(/, operator/=, itype);
|
||||||
|
|
||||||
|
bfloat_inplace_op(float);
|
||||||
|
bfloat_inplace_op(half);
|
||||||
|
bfloat_inplace_op(int16_t);
|
||||||
|
bfloat_inplace_op(int32_t);
|
||||||
|
bfloat_inplace_op(int64_t);
|
||||||
|
bfloat_inplace_op(uint16_t);
|
||||||
|
bfloat_inplace_op(uint32_t);
|
||||||
|
bfloat_inplace_op(uint64_t);
|
||||||
|
|
||||||
|
#undef bfloat_inplace_op_helper
|
||||||
|
#undef bfloat_inplace_op_addr_space_helper
|
||||||
|
#undef bfloat_inplace_op
|
||||||
|
|
||||||
|
#define bfloat_inplace_op_helper(__op__, __operator__, addr_space) \
|
||||||
|
constexpr METAL_FUNC addr_space _MLX_BFloat16& __operator__( \
|
||||||
|
addr_space _MLX_BFloat16& lhs, _MLX_BFloat16 rhs) { \
|
||||||
|
lhs = static_cast<float>(lhs) __op__ static_cast<float>(rhs); \
|
||||||
|
return lhs; \
|
||||||
|
}
|
||||||
|
|
||||||
|
#define bfloat_inplace_op_addr_space_helper(__op__, __operator__) \
|
||||||
|
bfloat_inplace_op_helper(__op__, __operator__, device); \
|
||||||
|
bfloat_inplace_op_helper(__op__, __operator__, thread); \
|
||||||
|
bfloat_inplace_op_helper(__op__, __operator__, threadgroup);
|
||||||
|
|
||||||
|
bfloat_inplace_op_addr_space_helper(+, operator+=);
|
||||||
|
bfloat_inplace_op_addr_space_helper(-, operator-=);
|
||||||
|
bfloat_inplace_op_addr_space_helper(*, operator*=);
|
||||||
|
bfloat_inplace_op_addr_space_helper(/, operator/=);
|
||||||
|
|
||||||
|
#undef bfloat_inplace_op_helper
|
||||||
|
#undef bfloat_inplace_op_addr_space_helper
|
||||||
|
|
||||||
|
/////////////////////////////////////////////////////////////////////////////
|
||||||
|
// Bfloat typedef
|
||||||
|
/////////////////////////////////////////////////////////////////////////////
|
||||||
|
|
||||||
|
typedef struct _MLX_BFloat16 bfloat16_t;
|
||||||
|
|
||||||
|
/////////////////////////////////////////////////////////////////////////////
|
||||||
|
// Bfloat numeric limits
|
||||||
|
/////////////////////////////////////////////////////////////////////////////
|
||||||
|
|
||||||
|
#pragma METAL internals : enable
|
||||||
|
|
||||||
|
namespace metal {
|
||||||
|
|
||||||
|
template <>
|
||||||
|
struct _numeric_limits_impl<bfloat16_t> : _fp_numeric_limits_impl_base {
|
||||||
|
static constexpr constant int digits = 8;
|
||||||
|
static constexpr constant int digits10 = 2;
|
||||||
|
static constexpr constant int max_digits10 = 4;
|
||||||
|
static constexpr constant int radix = 2;
|
||||||
|
static constexpr constant int min_exponent = -125;
|
||||||
|
static constexpr constant int min_exponent10 = -37;
|
||||||
|
static constexpr constant int max_exponent = 128;
|
||||||
|
static constexpr constant int max_exponent10 = 38;
|
||||||
|
|
||||||
|
static constexpr bfloat16_t min() {
|
||||||
|
return _MLX_BFloat16(0x0080, _MLX_BFloat16::bits_to_bfloat());
|
||||||
|
}
|
||||||
|
static constexpr bfloat16_t lowest() {
|
||||||
|
return _MLX_BFloat16(0xFF7F, _MLX_BFloat16::bits_to_bfloat());
|
||||||
|
}
|
||||||
|
static constexpr bfloat16_t max() {
|
||||||
|
return _MLX_BFloat16(0x7F7F, _MLX_BFloat16::bits_to_bfloat());
|
||||||
|
}
|
||||||
|
static constexpr bfloat16_t epsilon() {
|
||||||
|
return _MLX_BFloat16(0x3C00, _MLX_BFloat16::bits_to_bfloat());
|
||||||
|
}
|
||||||
|
static constexpr bfloat16_t round_error() {
|
||||||
|
return _MLX_BFloat16(0x3F00, _MLX_BFloat16::bits_to_bfloat());
|
||||||
|
}
|
||||||
|
static constexpr bfloat16_t infinity() {
|
||||||
|
return _MLX_BFloat16(0x7F80, _MLX_BFloat16::bits_to_bfloat());
|
||||||
|
}
|
||||||
|
static constexpr bfloat16_t quiet_NaN() {
|
||||||
|
return _MLX_BFloat16(0x7FC0, _MLX_BFloat16::bits_to_bfloat());
|
||||||
|
}
|
||||||
|
static constexpr bfloat16_t signaling_NaN() {
|
||||||
|
return _MLX_BFloat16(0x7F80, _MLX_BFloat16::bits_to_bfloat());
|
||||||
|
}
|
||||||
|
static constexpr bfloat16_t denorm_min() {
|
||||||
|
return _MLX_BFloat16(0x0001, _MLX_BFloat16::bits_to_bfloat());
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
METAL_FUNC bool isnan(_MLX_BFloat16 x) {
|
||||||
|
return x != x;
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace metal
|
||||||
|
|
||||||
|
#pragma METAL internals : disable
|
||||||
|
|
||||||
|
#endif // defined(__HAVE_BFLOAT__)
|
||||||
|
|
||||||
|
#include "mlx/backend/metal/kernels/bf16_math.h"
|
||||||
369
mlx/backend/metal/kernels/binary.metal
Normal file
369
mlx/backend/metal/kernels/binary.metal
Normal file
@@ -0,0 +1,369 @@
|
|||||||
|
#include <metal_integer>
|
||||||
|
#include <metal_math>
|
||||||
|
|
||||||
|
#include "mlx/backend/metal/kernels/utils.h"
|
||||||
|
#include "mlx/backend/metal/kernels/bf16.h"
|
||||||
|
|
||||||
|
struct Add {
|
||||||
|
template <typename T> T operator()(T x, T y) { return x + y; }
|
||||||
|
};
|
||||||
|
|
||||||
|
struct Divide {
|
||||||
|
template <typename T> T operator()(T x, T y) { return x / y; }
|
||||||
|
};
|
||||||
|
|
||||||
|
struct Equal {
|
||||||
|
template <typename T> bool operator()(T x, T y) { return x == y; }
|
||||||
|
};
|
||||||
|
|
||||||
|
struct NaNEqual {
|
||||||
|
template <typename T> bool operator()(T x, T y) {
|
||||||
|
return x == y || (metal::isnan(x) && metal::isnan(y));
|
||||||
|
}
|
||||||
|
template <>
|
||||||
|
bool operator()(complex64_t x, complex64_t y) {
|
||||||
|
return x == y ||
|
||||||
|
(metal::isnan(x.real) && metal::isnan(y.real)
|
||||||
|
&& metal::isnan(x.imag) && metal::isnan(y.imag)) ||
|
||||||
|
(x.real == y.real && metal::isnan(x.imag) && metal::isnan(y.imag)) ||
|
||||||
|
(metal::isnan(x.real) && metal::isnan(y.real) && x.imag == y.imag);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
struct Greater {
|
||||||
|
template <typename T> bool operator()(T x, T y) { return x > y; }
|
||||||
|
};
|
||||||
|
|
||||||
|
struct GreaterEqual {
|
||||||
|
template <typename T> bool operator()(T x, T y) { return x >= y; }
|
||||||
|
};
|
||||||
|
|
||||||
|
struct Less {
|
||||||
|
template <typename T> bool operator()(T x, T y) { return x < y; }
|
||||||
|
};
|
||||||
|
|
||||||
|
struct LessEqual {
|
||||||
|
template <typename T> bool operator()(T x, T y) { return x <= y; }
|
||||||
|
};
|
||||||
|
|
||||||
|
struct LogAddExp {
|
||||||
|
template <typename T>
|
||||||
|
T operator()(T x, T y) {
|
||||||
|
constexpr T inf = metal::numeric_limits<T>::infinity();
|
||||||
|
T maxval = metal::max(x, y);
|
||||||
|
T minval = metal::min(x, y);
|
||||||
|
return (minval == -inf || maxval == inf) ? maxval :
|
||||||
|
(maxval + log1p(metal::exp(minval - maxval)));
|
||||||
|
};
|
||||||
|
};
|
||||||
|
|
||||||
|
struct Maximum {
|
||||||
|
template <typename T> T operator()(T x, T y) { return metal::max(x, y); }
|
||||||
|
|
||||||
|
template <>
|
||||||
|
complex64_t operator()(complex64_t x, complex64_t y) {
|
||||||
|
return x >= y ? x : y;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
struct Minimum {
|
||||||
|
template <typename T> T operator()(T x, T y) { return metal::min(x, y); }
|
||||||
|
|
||||||
|
template <>
|
||||||
|
complex64_t operator()(complex64_t x, complex64_t y) {
|
||||||
|
return x <= y ? x : y;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
struct Multiply {
|
||||||
|
template <typename T> T operator()(T x, T y) { return x * y; }
|
||||||
|
};
|
||||||
|
|
||||||
|
struct NotEqual {
|
||||||
|
template <typename T> bool operator()(T x, T y) { return x != y; }
|
||||||
|
template <>
|
||||||
|
bool operator()(complex64_t x, complex64_t y) {
|
||||||
|
return x.real != y.real || x.imag != y.imag;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
struct Power {
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
metal::enable_if_t<!metal::is_integral_v<T>, T> operator()(T base, T exp) {
|
||||||
|
return metal::pow(base, exp);
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
metal::enable_if_t<metal::is_integral_v<T>, T> operator()(T base, T exp) {
|
||||||
|
T res = 1;
|
||||||
|
while (exp) {
|
||||||
|
if (exp & 1) {
|
||||||
|
res *= base;
|
||||||
|
}
|
||||||
|
exp >>= 1;
|
||||||
|
base *= base;
|
||||||
|
}
|
||||||
|
return res;
|
||||||
|
}
|
||||||
|
|
||||||
|
template <>
|
||||||
|
complex64_t operator()(complex64_t x, complex64_t y) {
|
||||||
|
auto x_theta = metal::atan(x.imag / x.real);
|
||||||
|
auto x_ln_r = 0.5 * metal::log(x.real * x.real + x.imag * x.imag);
|
||||||
|
auto mag = metal::exp(y.real * x_ln_r - y.imag * x_theta);
|
||||||
|
auto phase = y.imag * x_ln_r + y.real * x_theta;
|
||||||
|
return {mag * metal::cos(phase), mag * metal::sin(phase)};
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
|
||||||
|
struct Subtract {
|
||||||
|
template <typename T> T operator()(T x, T y) { return x - y; }
|
||||||
|
};
|
||||||
|
|
||||||
|
template <typename T, typename U, typename Op>
|
||||||
|
[[kernel]] void binary_op_s2s(
|
||||||
|
device const T* a,
|
||||||
|
device const T* b,
|
||||||
|
device U* c,
|
||||||
|
uint index [[thread_position_in_grid]]) {
|
||||||
|
c[index] = Op()(a[0], b[0]);
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
template <typename T, typename U, typename Op>
|
||||||
|
[[kernel]] void binary_op_ss(
|
||||||
|
device const T* a,
|
||||||
|
device const T* b,
|
||||||
|
device U* c,
|
||||||
|
uint index [[thread_position_in_grid]]) {
|
||||||
|
c[index] = Op()(a[0], b[0]);
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename T, typename U, typename Op>
|
||||||
|
[[kernel]] void binary_op_sv(
|
||||||
|
device const T* a,
|
||||||
|
device const T* b,
|
||||||
|
device U* c,
|
||||||
|
uint index [[thread_position_in_grid]]) {
|
||||||
|
c[index] = Op()(a[0], b[index]);
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename T, typename U, typename Op>
|
||||||
|
[[kernel]] void binary_op_vs(
|
||||||
|
device const T* a,
|
||||||
|
device const T* b,
|
||||||
|
device U* c,
|
||||||
|
uint index [[thread_position_in_grid]]) {
|
||||||
|
c[index] = Op()(a[index], b[0]);
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename T, typename U, typename Op>
|
||||||
|
[[kernel]] void binary_op_vv(
|
||||||
|
device const T* a,
|
||||||
|
device const T* b,
|
||||||
|
device U* c,
|
||||||
|
uint index [[thread_position_in_grid]]) {
|
||||||
|
c[index] = Op()(a[index], b[index]);
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename T, typename U, typename Op>
|
||||||
|
[[kernel]] void binary_op_g_nd1(
|
||||||
|
device const T* a,
|
||||||
|
device const T* b,
|
||||||
|
device U* c,
|
||||||
|
constant const size_t& a_stride,
|
||||||
|
constant const size_t& b_stride,
|
||||||
|
uint index [[thread_position_in_grid]]) {
|
||||||
|
auto a_idx = elem_to_loc_1(index, a_stride);
|
||||||
|
auto b_idx = elem_to_loc_1(index, b_stride);
|
||||||
|
c[index] = Op()(a[a_idx], b[b_idx]);
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename T, typename U, typename Op>
|
||||||
|
[[kernel]] void binary_op_g_nd2(
|
||||||
|
device const T* a,
|
||||||
|
device const T* b,
|
||||||
|
device U* c,
|
||||||
|
constant const size_t a_strides[2],
|
||||||
|
constant const size_t b_strides[2],
|
||||||
|
uint2 index [[thread_position_in_grid]],
|
||||||
|
uint2 grid_dim [[threads_per_grid]]) {
|
||||||
|
auto a_idx = elem_to_loc_2(index, a_strides);
|
||||||
|
auto b_idx = elem_to_loc_2(index, b_strides);
|
||||||
|
size_t out_idx = index.x + (size_t)grid_dim.x * index.y;
|
||||||
|
c[out_idx] = Op()(a[a_idx], b[b_idx]);
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename T, typename U, typename Op>
|
||||||
|
[[kernel]] void binary_op_g_nd3(
|
||||||
|
device const T* a,
|
||||||
|
device const T* b,
|
||||||
|
device U* c,
|
||||||
|
constant const size_t a_strides[3],
|
||||||
|
constant const size_t b_strides[3],
|
||||||
|
uint3 index [[thread_position_in_grid]],
|
||||||
|
uint3 grid_dim [[threads_per_grid]]) {
|
||||||
|
auto a_idx = elem_to_loc_3(index, a_strides);
|
||||||
|
auto b_idx = elem_to_loc_3(index, b_strides);
|
||||||
|
size_t out_idx = index.x + (size_t)grid_dim.x * (index.y + (size_t)grid_dim.y * index.z);
|
||||||
|
c[out_idx] = Op()(a[a_idx], b[b_idx]);
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename T, typename U, typename Op, int DIM>
|
||||||
|
[[kernel]] void binary_op_g_nd(
|
||||||
|
device const T* a,
|
||||||
|
device const T* b,
|
||||||
|
device U* c,
|
||||||
|
constant const int shape[DIM],
|
||||||
|
constant const size_t a_strides[DIM],
|
||||||
|
constant const size_t b_strides[DIM],
|
||||||
|
uint3 index [[thread_position_in_grid]],
|
||||||
|
uint3 grid_dim [[threads_per_grid]]) {
|
||||||
|
auto idx = elem_to_loc_2_nd<DIM>(index, shape, a_strides, b_strides);
|
||||||
|
size_t out_idx = index.x + (size_t)grid_dim.x * (index.y + (size_t)grid_dim.y * index.z);
|
||||||
|
c[out_idx] = Op()(a[idx.x], b[idx.y]);
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename T, typename U, typename Op>
|
||||||
|
[[kernel]] void binary_op_g(
|
||||||
|
device const T* a,
|
||||||
|
device const T* b,
|
||||||
|
device U* c,
|
||||||
|
constant const int* shape,
|
||||||
|
constant const size_t* a_strides,
|
||||||
|
constant const size_t* b_strides,
|
||||||
|
constant const int& ndim,
|
||||||
|
uint3 index [[thread_position_in_grid]],
|
||||||
|
uint3 grid_dim [[threads_per_grid]]) {
|
||||||
|
auto idx = elem_to_loc_2_nd(index, shape, a_strides, b_strides, ndim);
|
||||||
|
size_t out_idx = index.x + grid_dim.x * (index.y + grid_dim.y * index.z);
|
||||||
|
c[out_idx] = Op()(a[idx.x], b[idx.y]);
|
||||||
|
}
|
||||||
|
|
||||||
|
#define instantiate_binary(name, itype, otype, op, bopt) \
|
||||||
|
template [[host_name(name)]] \
|
||||||
|
[[kernel]] void binary_op_##bopt<itype, otype, op>( \
|
||||||
|
device const itype* a, \
|
||||||
|
device const itype* b, \
|
||||||
|
device otype* c, \
|
||||||
|
uint index [[thread_position_in_grid]]);
|
||||||
|
|
||||||
|
#define instantiate_binary_g_dim(name, itype, otype, op, dims) \
|
||||||
|
template [[host_name(name "_" #dims)]] \
|
||||||
|
[[kernel]] void binary_op_g_nd<itype, otype, op, dims>( \
|
||||||
|
device const itype* a, \
|
||||||
|
device const itype* b, \
|
||||||
|
device otype* c, \
|
||||||
|
constant const int shape[dims], \
|
||||||
|
constant const size_t a_strides[dims], \
|
||||||
|
constant const size_t b_strides[dims], \
|
||||||
|
uint3 index [[thread_position_in_grid]], \
|
||||||
|
uint3 grid_dim [[threads_per_grid]]);
|
||||||
|
|
||||||
|
#define instantiate_binary_g_nd(name, itype, otype, op) \
|
||||||
|
template [[host_name(name "_1")]] \
|
||||||
|
[[kernel]] void binary_op_g_nd1<itype, otype, op>( \
|
||||||
|
device const itype* a, \
|
||||||
|
device const itype* b, \
|
||||||
|
device otype* c, \
|
||||||
|
constant const size_t& a_stride, \
|
||||||
|
constant const size_t& b_stride, \
|
||||||
|
uint index [[thread_position_in_grid]]); \
|
||||||
|
template [[host_name(name "_2")]] \
|
||||||
|
[[kernel]] void binary_op_g_nd2<itype, otype, op>( \
|
||||||
|
device const itype* a, \
|
||||||
|
device const itype* b, \
|
||||||
|
device otype* c, \
|
||||||
|
constant const size_t a_strides[2], \
|
||||||
|
constant const size_t b_strides[2], \
|
||||||
|
uint2 index [[thread_position_in_grid]], \
|
||||||
|
uint2 grid_dim [[threads_per_grid]]); \
|
||||||
|
template [[host_name(name "_3")]] \
|
||||||
|
[[kernel]] void binary_op_g_nd3<itype, otype, op>( \
|
||||||
|
device const itype* a, \
|
||||||
|
device const itype* b, \
|
||||||
|
device otype* c, \
|
||||||
|
constant const size_t a_strides[3], \
|
||||||
|
constant const size_t b_strides[3], \
|
||||||
|
uint3 index [[thread_position_in_grid]], \
|
||||||
|
uint3 grid_dim [[threads_per_grid]]); \
|
||||||
|
instantiate_binary_g_dim(name, itype, otype, op, 4) \
|
||||||
|
instantiate_binary_g_dim(name, itype, otype, op, 5)
|
||||||
|
|
||||||
|
|
||||||
|
#define instantiate_binary_g(name, itype, otype, op) \
|
||||||
|
template [[host_name(name)]] \
|
||||||
|
[[kernel]] void binary_op_g<itype, otype, op>( \
|
||||||
|
device const itype* a, \
|
||||||
|
device const itype* b, \
|
||||||
|
device otype* c, \
|
||||||
|
constant const int* shape, \
|
||||||
|
constant const size_t* a_strides, \
|
||||||
|
constant const size_t* b_strides, \
|
||||||
|
constant const int& ndim, \
|
||||||
|
uint3 index [[thread_position_in_grid]], \
|
||||||
|
uint3 grid_dim [[threads_per_grid]]);
|
||||||
|
|
||||||
|
#define instantiate_binary_all(name, tname, itype, otype, op) \
|
||||||
|
instantiate_binary("ss" #name #tname, itype, otype, op, ss) \
|
||||||
|
instantiate_binary("sv" #name #tname, itype, otype, op, sv) \
|
||||||
|
instantiate_binary("vs" #name #tname, itype, otype, op, vs) \
|
||||||
|
instantiate_binary("vv" #name #tname, itype, otype, op, vv) \
|
||||||
|
instantiate_binary_g("g" #name #tname, itype, otype, op) \
|
||||||
|
instantiate_binary_g_nd("g" #name #tname, itype, otype, op)
|
||||||
|
|
||||||
|
#define instantiate_binary_float(name, op) \
|
||||||
|
instantiate_binary_all(name, float16, half, half, op) \
|
||||||
|
instantiate_binary_all(name, float32, float, float, op) \
|
||||||
|
instantiate_binary_all(name, bfloat16, bfloat16_t, bfloat16_t, op)
|
||||||
|
|
||||||
|
#define instantiate_binary_types(name, op) \
|
||||||
|
instantiate_binary_all(name, bool_, bool, bool, op) \
|
||||||
|
instantiate_binary_all(name, uint8, uint8_t, uint8_t, op) \
|
||||||
|
instantiate_binary_all(name, uint16, uint16_t, uint16_t, op) \
|
||||||
|
instantiate_binary_all(name, uint32, uint32_t, uint32_t, op) \
|
||||||
|
instantiate_binary_all(name, uint64, uint64_t, uint64_t, op) \
|
||||||
|
instantiate_binary_all(name, int8, int8_t, int8_t, op) \
|
||||||
|
instantiate_binary_all(name, int16, int16_t, int16_t, op) \
|
||||||
|
instantiate_binary_all(name, int32, int32_t, int32_t, op) \
|
||||||
|
instantiate_binary_all(name, int64, int64_t, int64_t, op) \
|
||||||
|
instantiate_binary_all(name, complex64, complex64_t, complex64_t, op) \
|
||||||
|
instantiate_binary_float(name, op)
|
||||||
|
|
||||||
|
#define instantiate_binary_types_bool(name, op) \
|
||||||
|
instantiate_binary_all(name, bool_, bool, bool, op) \
|
||||||
|
instantiate_binary_all(name, uint8, uint8_t, bool, op) \
|
||||||
|
instantiate_binary_all(name, uint16, uint16_t, bool, op) \
|
||||||
|
instantiate_binary_all(name, uint32, uint32_t, bool, op) \
|
||||||
|
instantiate_binary_all(name, uint64, uint64_t, bool, op) \
|
||||||
|
instantiate_binary_all(name, int8, int8_t, bool, op) \
|
||||||
|
instantiate_binary_all(name, int16, int16_t, bool, op) \
|
||||||
|
instantiate_binary_all(name, int32, int32_t, bool, op) \
|
||||||
|
instantiate_binary_all(name, int64, int64_t, bool, op) \
|
||||||
|
instantiate_binary_all(name, float16, half, bool, op) \
|
||||||
|
instantiate_binary_all(name, float32, float, bool, op) \
|
||||||
|
instantiate_binary_all(name, bfloat16, bfloat16_t, bool, op) \
|
||||||
|
instantiate_binary_all(name, complex64, complex64_t, bool, op)
|
||||||
|
|
||||||
|
instantiate_binary_types(add, Add)
|
||||||
|
instantiate_binary_float(div, Divide)
|
||||||
|
instantiate_binary_types_bool(eq, Equal)
|
||||||
|
instantiate_binary_types_bool(ge, Greater)
|
||||||
|
instantiate_binary_types_bool(geq, GreaterEqual)
|
||||||
|
instantiate_binary_types_bool(le, Less)
|
||||||
|
instantiate_binary_types_bool(leq, LessEqual)
|
||||||
|
instantiate_binary_types_bool(neq, NotEqual)
|
||||||
|
instantiate_binary_float(lae, LogAddExp)
|
||||||
|
instantiate_binary_types(max, Maximum)
|
||||||
|
instantiate_binary_types(min, Minimum)
|
||||||
|
instantiate_binary_types(mul, Multiply)
|
||||||
|
instantiate_binary_types(sub, Subtract)
|
||||||
|
instantiate_binary_types(pow, Power)
|
||||||
|
|
||||||
|
// NaNEqual only needed for floating point types with boolean output
|
||||||
|
instantiate_binary_all(naneq, float16, half, bool, NaNEqual)
|
||||||
|
instantiate_binary_all(naneq, float32, float, bool, NaNEqual)
|
||||||
|
instantiate_binary_all(naneq, bfloat16, bfloat16_t, bool, NaNEqual)
|
||||||
|
instantiate_binary_all(naneq, complex64, complex64_t, bool, NaNEqual)
|
||||||
110
mlx/backend/metal/kernels/complex.h
Normal file
110
mlx/backend/metal/kernels/complex.h
Normal file
@@ -0,0 +1,110 @@
|
|||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include <metal_stdlib>
|
||||||
|
|
||||||
|
using namespace metal;
|
||||||
|
|
||||||
|
struct complex64_t;
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
static constexpr constant bool can_convert_to_complex64 =
|
||||||
|
!is_same_v<T, complex64_t> && is_convertible_v<T, float>;
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
static constexpr constant bool can_convert_from_complex64 =
|
||||||
|
!is_same_v<T, complex64_t> &&
|
||||||
|
(is_convertible_v<float, T> || is_convertible_v<bfloat16_t, T>);
|
||||||
|
|
||||||
|
struct complex64_t {
|
||||||
|
float real;
|
||||||
|
float imag;
|
||||||
|
|
||||||
|
// Constructors
|
||||||
|
constexpr complex64_t(float real, float imag) : real(real), imag(imag){};
|
||||||
|
|
||||||
|
// Conversions to complex64_t
|
||||||
|
template <
|
||||||
|
typename T,
|
||||||
|
typename = typename enable_if<can_convert_to_complex64<T>>::type>
|
||||||
|
constexpr complex64_t(T x) thread : real(x), imag(0) {}
|
||||||
|
|
||||||
|
template <
|
||||||
|
typename T,
|
||||||
|
typename = typename enable_if<can_convert_to_complex64<T>>::type>
|
||||||
|
constexpr complex64_t(T x) threadgroup : real(x), imag(0) {}
|
||||||
|
|
||||||
|
template <
|
||||||
|
typename T,
|
||||||
|
typename = typename enable_if<can_convert_to_complex64<T>>::type>
|
||||||
|
constexpr complex64_t(T x) device : real(x), imag(0) {}
|
||||||
|
|
||||||
|
template <
|
||||||
|
typename T,
|
||||||
|
typename = typename enable_if<can_convert_to_complex64<T>>::type>
|
||||||
|
constexpr complex64_t(T x) constant : real(x), imag(0) {}
|
||||||
|
|
||||||
|
// Converstions from complex64_t
|
||||||
|
template <
|
||||||
|
typename T,
|
||||||
|
typename = typename enable_if<can_convert_from_complex64<T>>::type>
|
||||||
|
constexpr operator T() const thread {
|
||||||
|
return static_cast<T>(real);
|
||||||
|
}
|
||||||
|
|
||||||
|
template <
|
||||||
|
typename T,
|
||||||
|
typename = typename enable_if<can_convert_from_complex64<T>>::type>
|
||||||
|
constexpr operator T() const threadgroup {
|
||||||
|
return static_cast<T>(real);
|
||||||
|
}
|
||||||
|
|
||||||
|
template <
|
||||||
|
typename T,
|
||||||
|
typename = typename enable_if<can_convert_from_complex64<T>>::type>
|
||||||
|
constexpr operator T() const device {
|
||||||
|
return static_cast<T>(real);
|
||||||
|
}
|
||||||
|
|
||||||
|
template <
|
||||||
|
typename T,
|
||||||
|
typename = typename enable_if<can_convert_from_complex64<T>>::type>
|
||||||
|
constexpr operator T() const constant {
|
||||||
|
return static_cast<T>(real);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
constexpr complex64_t operator-(complex64_t x) {
|
||||||
|
return {-x.real, -x.imag};
|
||||||
|
}
|
||||||
|
|
||||||
|
constexpr bool operator>=(complex64_t a, complex64_t b) {
|
||||||
|
return (a.real > b.real) || (a.real == b.real && a.imag >= b.imag);
|
||||||
|
}
|
||||||
|
|
||||||
|
constexpr bool operator>(complex64_t a, complex64_t b) {
|
||||||
|
return (a.real > b.real) || (a.real == b.real && a.imag > b.imag);
|
||||||
|
}
|
||||||
|
|
||||||
|
constexpr bool operator<=(complex64_t a, complex64_t b) {
|
||||||
|
return operator>=(b, a);
|
||||||
|
}
|
||||||
|
|
||||||
|
constexpr bool operator<(complex64_t a, complex64_t b) {
|
||||||
|
return operator>(b, a);
|
||||||
|
}
|
||||||
|
|
||||||
|
constexpr bool operator==(complex64_t a, complex64_t b) {
|
||||||
|
return a.real == b.real && a.imag == b.imag;
|
||||||
|
}
|
||||||
|
|
||||||
|
constexpr complex64_t operator+(complex64_t a, complex64_t b) {
|
||||||
|
return {a.real + b.real, a.imag + b.imag};
|
||||||
|
}
|
||||||
|
|
||||||
|
constexpr complex64_t operator-(complex64_t a, complex64_t b) {
|
||||||
|
return {a.real - b.real, a.imag - b.imag};
|
||||||
|
}
|
||||||
|
|
||||||
|
constexpr complex64_t operator*(complex64_t a, complex64_t b) {
|
||||||
|
return {a.real * b.real - a.imag * b.imag, a.real * b.imag + a.imag * b.real};
|
||||||
|
}
|
||||||
14
mlx/backend/metal/kernels/defines.h
Normal file
14
mlx/backend/metal/kernels/defines.h
Normal file
@@ -0,0 +1,14 @@
|
|||||||
|
#pragma once
|
||||||
|
|
||||||
|
#ifdef __METAL__
|
||||||
|
#define MTL_CONST constant
|
||||||
|
#else
|
||||||
|
#define MTL_CONST
|
||||||
|
#endif
|
||||||
|
|
||||||
|
static MTL_CONST constexpr int MAX_BINARY_SPECIALIZED_DIMS = 5;
|
||||||
|
static MTL_CONST constexpr int MAX_COPY_SPECIALIZED_DIMS = 5;
|
||||||
|
static MTL_CONST constexpr int MAX_REDUCE_SPECIALIZED_DIMS = 4;
|
||||||
|
static MTL_CONST constexpr int REDUCE_N_READS = 16;
|
||||||
|
static MTL_CONST constexpr int SOFTMAX_N_READS = 4;
|
||||||
|
static MTL_CONST constexpr int SOFTMAX_LOOPED_LIMIT = 4096;
|
||||||
479
mlx/backend/metal/kernels/gemm/conv.h
Normal file
479
mlx/backend/metal/kernels/gemm/conv.h
Normal file
@@ -0,0 +1,479 @@
|
|||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include <metal_simdgroup>
|
||||||
|
#include <metal_simdgroup_matrix>
|
||||||
|
#include <metal_stdlib>
|
||||||
|
|
||||||
|
#include "mlx/backend/metal/kernels/bf16.h"
|
||||||
|
#include "mlx/backend/metal/kernels/conv_params.h"
|
||||||
|
|
||||||
|
#define MLX_MTL_CONST static constant constexpr const
|
||||||
|
|
||||||
|
using namespace metal;
|
||||||
|
|
||||||
|
///////////////////////////////////////////////////////////////////////////////
|
||||||
|
// Loading helper
|
||||||
|
///////////////////////////////////////////////////////////////////////////////
|
||||||
|
|
||||||
|
template <
|
||||||
|
typename T,
|
||||||
|
int BM,
|
||||||
|
int BN,
|
||||||
|
int BK,
|
||||||
|
int vec_size,
|
||||||
|
int tgp_size,
|
||||||
|
int tgp_padding = 0>
|
||||||
|
struct Conv2DInputBlockLoader {
|
||||||
|
// Destination dimensions
|
||||||
|
MLX_MTL_CONST int dst_fd = BM;
|
||||||
|
MLX_MTL_CONST int dst_ld = BK + tgp_padding;
|
||||||
|
MLX_MTL_CONST int n_vecs = BK / vec_size;
|
||||||
|
|
||||||
|
// Stride along block row within the block
|
||||||
|
MLX_MTL_CONST int bstride = tgp_size / n_vecs;
|
||||||
|
MLX_MTL_CONST int n_rows = dst_fd / bstride;
|
||||||
|
|
||||||
|
// Thread location indices
|
||||||
|
const short thread_idx;
|
||||||
|
const short bi;
|
||||||
|
const short bj;
|
||||||
|
|
||||||
|
// threadgroup and device memory
|
||||||
|
threadgroup T* dst;
|
||||||
|
const device T* src;
|
||||||
|
|
||||||
|
const constant MLXConvParams<2>& params;
|
||||||
|
|
||||||
|
int weight_h;
|
||||||
|
int weight_w;
|
||||||
|
|
||||||
|
int offsets_n[n_rows];
|
||||||
|
int offsets_oh[n_rows];
|
||||||
|
int offsets_ow[n_rows];
|
||||||
|
|
||||||
|
/* Constructor */
|
||||||
|
METAL_FUNC Conv2DInputBlockLoader(
|
||||||
|
const device T* src_,
|
||||||
|
threadgroup T* dst_,
|
||||||
|
const constant MLXConvParams<2>& params_,
|
||||||
|
uint3 tid [[threadgroup_position_in_grid]],
|
||||||
|
uint3 lid [[thread_position_in_threadgroup]],
|
||||||
|
uint simd_group_id [[simdgroup_index_in_threadgroup]],
|
||||||
|
uint simd_lane_id [[thread_index_in_simdgroup]])
|
||||||
|
: thread_idx(simd_group_id * 32 + simd_lane_id),
|
||||||
|
bi(thread_idx / n_vecs),
|
||||||
|
bj(vec_size * (thread_idx % n_vecs)),
|
||||||
|
dst(dst_ + bi * dst_ld + bj),
|
||||||
|
src(src_ + bj),
|
||||||
|
params(params_),
|
||||||
|
weight_h(0),
|
||||||
|
weight_w(0) {
|
||||||
|
int out_n_pixels = params.oS[0] * params.oS[1];
|
||||||
|
|
||||||
|
for (int i = 0; i < n_rows; ++i) {
|
||||||
|
int offset_nhw = tid.y * BM + bi + i * bstride;
|
||||||
|
offsets_n[i] = offset_nhw / out_n_pixels;
|
||||||
|
int hw = offset_nhw % out_n_pixels;
|
||||||
|
offsets_oh[i] = hw / params.oS[1];
|
||||||
|
offsets_ow[i] = hw % params.oS[1];
|
||||||
|
}
|
||||||
|
|
||||||
|
(void)lid;
|
||||||
|
}
|
||||||
|
|
||||||
|
/* Load from device memory into threadgroup memory - without bound checking */
|
||||||
|
METAL_FUNC void load_unsafe() const {
|
||||||
|
#pragma clang loop unroll(full)
|
||||||
|
for (short i = 0, is = 0; i < n_rows; ++i, is += bstride) {
|
||||||
|
int n = offsets_n[i];
|
||||||
|
int oh = offsets_oh[i];
|
||||||
|
int ow = offsets_ow[i];
|
||||||
|
|
||||||
|
int ih = oh * params.str[0] - params.pad[0] + weight_h * params.dil[0];
|
||||||
|
int iw = ow * params.str[1] - params.pad[1] + weight_w * params.dil[1];
|
||||||
|
|
||||||
|
// Read from input if in bounds
|
||||||
|
if (ih >= 0 && ih < params.iS[0] && iw >= 0 && iw < params.iS[1]) {
|
||||||
|
const device T* curr_src = src + n * params.in_strides[0] +
|
||||||
|
ih * params.in_strides[1] + iw * params.in_strides[2];
|
||||||
|
|
||||||
|
#pragma clang loop unroll(full)
|
||||||
|
for (short j = 0; j < vec_size; ++j) {
|
||||||
|
dst[is * dst_ld + j] = curr_src[j];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Zero pad otherwize
|
||||||
|
else {
|
||||||
|
#pragma clang loop unroll(full)
|
||||||
|
for (short j = 0; j < vec_size; ++j) {
|
||||||
|
dst[is * dst_ld + j] = T(0);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/* Iteration helper */
|
||||||
|
METAL_FUNC void next() {
|
||||||
|
if (++weight_w < params.wS[1]) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
weight_w = 0;
|
||||||
|
|
||||||
|
if (++weight_h < params.wS[0]) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
weight_h = 0;
|
||||||
|
|
||||||
|
src += BK;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
template <
|
||||||
|
typename T,
|
||||||
|
int BM,
|
||||||
|
int BN,
|
||||||
|
int BK,
|
||||||
|
int vec_size,
|
||||||
|
int tgp_size,
|
||||||
|
int tgp_padding = 0>
|
||||||
|
struct Conv2DWeightBlockLoader {
|
||||||
|
// Destination dimensions
|
||||||
|
MLX_MTL_CONST int dst_fd = BN;
|
||||||
|
MLX_MTL_CONST int dst_ld = BK + tgp_padding;
|
||||||
|
MLX_MTL_CONST int n_vecs = BK / vec_size;
|
||||||
|
|
||||||
|
// Stride along block row within the block
|
||||||
|
MLX_MTL_CONST int bstride = tgp_size / n_vecs;
|
||||||
|
MLX_MTL_CONST int n_rows = dst_fd / bstride;
|
||||||
|
|
||||||
|
// Leading dimension for src
|
||||||
|
const int src_ld;
|
||||||
|
|
||||||
|
// Thread location indices
|
||||||
|
const short thread_idx;
|
||||||
|
const short bi;
|
||||||
|
const short bj;
|
||||||
|
|
||||||
|
// threadgroup and device memory
|
||||||
|
threadgroup T* dst;
|
||||||
|
const device T* src;
|
||||||
|
|
||||||
|
const constant MLXConvParams<2>& params;
|
||||||
|
|
||||||
|
int weight_h;
|
||||||
|
int weight_w;
|
||||||
|
|
||||||
|
/* Constructor */
|
||||||
|
METAL_FUNC Conv2DWeightBlockLoader(
|
||||||
|
const device T* src_,
|
||||||
|
threadgroup T* dst_,
|
||||||
|
const constant MLXConvParams<2>& params_,
|
||||||
|
uint3 tid [[threadgroup_position_in_grid]],
|
||||||
|
uint3 lid [[thread_position_in_threadgroup]],
|
||||||
|
uint simd_group_id [[simdgroup_index_in_threadgroup]],
|
||||||
|
uint simd_lane_id [[thread_index_in_simdgroup]])
|
||||||
|
: src_ld(params_.wt_strides[0]),
|
||||||
|
thread_idx(simd_group_id * 32 + simd_lane_id),
|
||||||
|
bi(thread_idx / n_vecs),
|
||||||
|
bj(vec_size * (thread_idx % n_vecs)),
|
||||||
|
dst(dst_ + bi * dst_ld + bj),
|
||||||
|
src(src_ + bi * src_ld + bj),
|
||||||
|
params(params_),
|
||||||
|
weight_h(0),
|
||||||
|
weight_w(0) {
|
||||||
|
(void)lid;
|
||||||
|
(void)tid;
|
||||||
|
}
|
||||||
|
|
||||||
|
/* Load from device memory into threadgroup memory - without bound checking */
|
||||||
|
METAL_FUNC void load_unsafe() const {
|
||||||
|
const device T* curr_src =
|
||||||
|
src + weight_h * params.wt_strides[1] + weight_w * params.wt_strides[2];
|
||||||
|
#pragma clang loop unroll(full)
|
||||||
|
for (short i = 0; i < dst_fd; i += bstride) {
|
||||||
|
#pragma clang loop unroll(full)
|
||||||
|
for (short j = 0; j < vec_size; j++) {
|
||||||
|
dst[i * dst_ld + j] = curr_src[i * src_ld + j];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/* Iteration helper */
|
||||||
|
METAL_FUNC void next() {
|
||||||
|
if (++weight_w < params.wS[1]) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
weight_w = 0;
|
||||||
|
|
||||||
|
if (++weight_h < params.wS[0]) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
weight_h = 0;
|
||||||
|
|
||||||
|
src += BK;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
///////////////////////////////////////////////////////////////////////////////
|
||||||
|
// Transforms
|
||||||
|
///////////////////////////////////////////////////////////////////////////////
|
||||||
|
|
||||||
|
template <typename OutT, typename InT>
|
||||||
|
struct TransformNone {
|
||||||
|
static METAL_FUNC OutT apply(InT x) {
|
||||||
|
return static_cast<OutT>(x);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
struct AccumHelper {
|
||||||
|
typedef float accum_type;
|
||||||
|
};
|
||||||
|
|
||||||
|
///////////////////////////////////////////////////////////////////////////////
|
||||||
|
// MMA helper
|
||||||
|
///////////////////////////////////////////////////////////////////////////////
|
||||||
|
|
||||||
|
template <
|
||||||
|
typename T,
|
||||||
|
int BM,
|
||||||
|
int BN,
|
||||||
|
int BK,
|
||||||
|
int WM,
|
||||||
|
int WN,
|
||||||
|
bool transpose_a,
|
||||||
|
bool transpose_b,
|
||||||
|
int tgp_padding_a = 0,
|
||||||
|
int tgp_padding_b = 0,
|
||||||
|
typename AccumType = typename AccumHelper<T>::accum_type,
|
||||||
|
typename Epilogue = TransformNone<T, AccumType>>
|
||||||
|
struct Conv2DBlockMMA {
|
||||||
|
// Warp tile size along M
|
||||||
|
MLX_MTL_CONST int TM = BM / (WM * 8);
|
||||||
|
// Warp tile size along N
|
||||||
|
MLX_MTL_CONST int TN = BN / (WN * 8);
|
||||||
|
|
||||||
|
// Warp tile simdgroup matrix strides along M
|
||||||
|
MLX_MTL_CONST int TM_stride = 8 * WM;
|
||||||
|
// Warp tile simdgroup matrix strides along M
|
||||||
|
MLX_MTL_CONST int TN_stride = 8 * WN;
|
||||||
|
|
||||||
|
// Leading dimensions of threadgroup A, B blocks
|
||||||
|
MLX_MTL_CONST int lda_tgp = (transpose_a ? BM : BK) + tgp_padding_a;
|
||||||
|
MLX_MTL_CONST int ldb_tgp = (transpose_b ? BK : BN) + tgp_padding_b;
|
||||||
|
|
||||||
|
// Strides of A, B along reduction axis
|
||||||
|
MLX_MTL_CONST short simd_stride_a =
|
||||||
|
transpose_a ? TM_stride : TM_stride * lda_tgp;
|
||||||
|
MLX_MTL_CONST short simd_stride_b =
|
||||||
|
transpose_b ? TN_stride * ldb_tgp : TN_stride;
|
||||||
|
|
||||||
|
// Jump between elements
|
||||||
|
MLX_MTL_CONST short jump_a = transpose_a ? lda_tgp : 1;
|
||||||
|
MLX_MTL_CONST short jump_b = transpose_b ? ldb_tgp : 1;
|
||||||
|
|
||||||
|
// Offsets within threadgroup
|
||||||
|
const int tm;
|
||||||
|
const int tn;
|
||||||
|
|
||||||
|
// Simdgroup matrices
|
||||||
|
simdgroup_matrix<AccumType, 8, 8> Asimd[TM];
|
||||||
|
simdgroup_matrix<AccumType, 8, 8> Bsimd[TN];
|
||||||
|
simdgroup_matrix<AccumType, 8, 8> results[TM * TN] = {
|
||||||
|
simdgroup_matrix<AccumType, 8, 8>(0)};
|
||||||
|
|
||||||
|
short sm;
|
||||||
|
short sn;
|
||||||
|
|
||||||
|
/* Constructor */
|
||||||
|
METAL_FUNC Conv2DBlockMMA(
|
||||||
|
uint simd_group_id [[simdgroup_index_in_threadgroup]],
|
||||||
|
uint simd_lane_id [[thread_index_in_simdgroup]])
|
||||||
|
: tm(8 * (simd_group_id / WN)), tn(8 * (simd_group_id % WN)) {
|
||||||
|
short qid = simd_lane_id / 4;
|
||||||
|
sm = (qid & 4) + (simd_lane_id / 2) % 4;
|
||||||
|
sn = (qid & 2) * 2 + (simd_lane_id % 2) * 2;
|
||||||
|
}
|
||||||
|
|
||||||
|
/* (BM, BK) X (BK, BN) multiply accumulate function */
|
||||||
|
METAL_FUNC void mma(const threadgroup T* As, const threadgroup T* Bs) {
|
||||||
|
// Iterate over BK in blocks of 8
|
||||||
|
#pragma clang loop unroll(full)
|
||||||
|
for (short kk = 0; kk < BK; kk += 8) {
|
||||||
|
short2 offset_a =
|
||||||
|
transpose_a ? short2(tm + sm, kk + sn) : short2(kk + sn, tm + sm);
|
||||||
|
short2 offset_b =
|
||||||
|
transpose_b ? short2(kk + sm, tn + sn) : short2(tn + sn, kk + sm);
|
||||||
|
|
||||||
|
const threadgroup T* As__ = As + offset_a.y * lda_tgp + offset_a.x;
|
||||||
|
const threadgroup T* Bs__ = Bs + offset_b.y * ldb_tgp + offset_b.x;
|
||||||
|
|
||||||
|
simdgroup_barrier(mem_flags::mem_none);
|
||||||
|
// Load elements from threadgroup A as simdgroup matrices
|
||||||
|
#pragma clang loop unroll(full)
|
||||||
|
for (short i = 0; i < TM; i++) {
|
||||||
|
Asimd[i].thread_elements()[0] = static_cast<AccumType>(As__[0]);
|
||||||
|
Asimd[i].thread_elements()[1] = static_cast<AccumType>(As__[jump_a]);
|
||||||
|
As__ += simd_stride_a;
|
||||||
|
}
|
||||||
|
|
||||||
|
simdgroup_barrier(mem_flags::mem_none);
|
||||||
|
// Load elements from threadgroup B as simdgroup matrices
|
||||||
|
#pragma clang loop unroll(full)
|
||||||
|
for (short j = 0; j < TN; j++) {
|
||||||
|
Bsimd[j].thread_elements()[0] = static_cast<AccumType>(Bs__[0]);
|
||||||
|
Bsimd[j].thread_elements()[1] = static_cast<AccumType>(Bs__[jump_b]);
|
||||||
|
Bs__ += simd_stride_b;
|
||||||
|
}
|
||||||
|
|
||||||
|
simdgroup_barrier(mem_flags::mem_none);
|
||||||
|
// Multiply and accumulate into resulr simdgroup matrices
|
||||||
|
#pragma clang loop unroll(full)
|
||||||
|
for (short i = 0; i < TM; i++) {
|
||||||
|
#pragma clang loop unroll(full)
|
||||||
|
for (short j = 0; j < TN; j++) {
|
||||||
|
simdgroup_multiply_accumulate(
|
||||||
|
results[i * TN + j], Asimd[i], Bsimd[j], results[i * TN + j]);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/* Store results from simdgroup_matrix results into device memory */
|
||||||
|
METAL_FUNC void store_result(device T* C, const int ldc) const {
|
||||||
|
#pragma clang loop unroll(full)
|
||||||
|
for (int i = 0; i < TM; i++) {
|
||||||
|
#pragma clang loop unroll(full)
|
||||||
|
for (int j = 0; j < TN; j++) {
|
||||||
|
C[(i * TM_stride + sm + tm) * ldc + j * TN_stride + tn + sn] =
|
||||||
|
Epilogue::apply(results[i * TN + j].thread_elements()[0]);
|
||||||
|
C[(i * TM_stride + sm + tm) * ldc + j * TN_stride + tn + sn + 1] =
|
||||||
|
Epilogue::apply(results[i * TN + j].thread_elements()[1]);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
METAL_FUNC void
|
||||||
|
store_result_safe(device T* C, const int ldc, short2 dst_tile_dims) const {
|
||||||
|
#pragma clang loop unroll(full)
|
||||||
|
for (int i = 0; i < TM; i++) {
|
||||||
|
if (tm + i * TM_stride + sm < dst_tile_dims.y) {
|
||||||
|
#pragma clang loop unroll(full)
|
||||||
|
for (int j = 0; j < TN; j++) {
|
||||||
|
if (tn + j * TN_stride + sn < dst_tile_dims.x) {
|
||||||
|
C[(tm + i * TM_stride + sm) * ldc + tn + j * TN_stride + sn] =
|
||||||
|
Epilogue::apply(results[i * TN + j].thread_elements()[0]);
|
||||||
|
}
|
||||||
|
|
||||||
|
if (tn + j * TN_stride + sn + 1 < dst_tile_dims.x) {
|
||||||
|
C[(tm + i * TM_stride + sm) * ldc + tn + j * TN_stride + sn + 1] =
|
||||||
|
Epilogue::apply(results[i * TN + j].thread_elements()[1]);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
///////////////////////////////////////////////////////////////////////////////
|
||||||
|
// GEMM kernels
|
||||||
|
///////////////////////////////////////////////////////////////////////////////
|
||||||
|
|
||||||
|
template <
|
||||||
|
typename T,
|
||||||
|
int BM,
|
||||||
|
int BN,
|
||||||
|
int BK,
|
||||||
|
int WM,
|
||||||
|
int WN,
|
||||||
|
bool transpose_a,
|
||||||
|
bool transpose_b,
|
||||||
|
typename AccumType = typename AccumHelper<T>::accum_type,
|
||||||
|
typename Epilogue = TransformNone<T, AccumType>>
|
||||||
|
struct Conv2DImplicitGEMMKernel {
|
||||||
|
MLX_MTL_CONST short tgp_padding_a = 16 / sizeof(T);
|
||||||
|
MLX_MTL_CONST short tgp_padding_b = 16 / sizeof(T);
|
||||||
|
MLX_MTL_CONST short tgp_mem_size_a =
|
||||||
|
transpose_a ? BK * (BM + tgp_padding_a) : BM * (BK + tgp_padding_a);
|
||||||
|
MLX_MTL_CONST short tgp_mem_size_b =
|
||||||
|
transpose_b ? BN * (BK + tgp_padding_b) : BK * (BN + tgp_padding_b);
|
||||||
|
MLX_MTL_CONST short tgp_mem_size = tgp_mem_size_a + tgp_mem_size_b;
|
||||||
|
|
||||||
|
MLX_MTL_CONST short tgp_size = WM * WN * 32;
|
||||||
|
MLX_MTL_CONST short vec_size = (BM == 64 && BN == 64) ? 8 : 4;
|
||||||
|
|
||||||
|
using loader_a_t =
|
||||||
|
Conv2DInputBlockLoader<T, BM, BN, BK, vec_size, tgp_size, tgp_padding_a>;
|
||||||
|
using loader_b_t =
|
||||||
|
Conv2DWeightBlockLoader<T, BM, BN, BK, vec_size, tgp_size, tgp_padding_b>;
|
||||||
|
using mma_t = Conv2DBlockMMA<
|
||||||
|
T,
|
||||||
|
BM,
|
||||||
|
BN,
|
||||||
|
BK,
|
||||||
|
WM,
|
||||||
|
WN,
|
||||||
|
transpose_a,
|
||||||
|
transpose_b,
|
||||||
|
tgp_padding_a,
|
||||||
|
tgp_padding_b,
|
||||||
|
AccumType,
|
||||||
|
Epilogue>;
|
||||||
|
|
||||||
|
/* Main kernel function */
|
||||||
|
static METAL_FUNC void run(
|
||||||
|
const device T* A [[buffer(0)]],
|
||||||
|
const device T* B [[buffer(1)]],
|
||||||
|
device T* C [[buffer(2)]],
|
||||||
|
const constant MLXConvParams<2>& params [[buffer(3)]],
|
||||||
|
threadgroup T* tgp_memory [[threadgroup(0)]],
|
||||||
|
uint3 tid [[threadgroup_position_in_grid]],
|
||||||
|
uint3 lid [[thread_position_in_threadgroup]],
|
||||||
|
uint simd_gid [[simdgroup_index_in_threadgroup]],
|
||||||
|
uint simd_lid [[thread_index_in_simdgroup]]) {
|
||||||
|
const int c_row = tid.y * BM;
|
||||||
|
const int c_col = tid.x * BN;
|
||||||
|
const int K = params.wt_strides[0];
|
||||||
|
const int N = params.O;
|
||||||
|
|
||||||
|
B += c_col * K;
|
||||||
|
C += c_row * N + c_col;
|
||||||
|
|
||||||
|
// Prepare threadgroup memory for loading
|
||||||
|
threadgroup T* As = tgp_memory;
|
||||||
|
threadgroup T* Bs = tgp_memory + tgp_mem_size_a;
|
||||||
|
|
||||||
|
// Prepare threadgroup loading operations
|
||||||
|
loader_a_t loader_a(A, As, params, tid, lid, simd_gid, simd_lid);
|
||||||
|
loader_b_t loader_b(B, Bs, params, tid, lid, simd_gid, simd_lid);
|
||||||
|
|
||||||
|
// Prepare threadgroup mma operation
|
||||||
|
mma_t mma_op(simd_gid, simd_lid);
|
||||||
|
|
||||||
|
for (int k = 0; k < K; k += BK) {
|
||||||
|
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||||
|
// Load elements into threadgroup
|
||||||
|
loader_a.load_unsafe();
|
||||||
|
loader_b.load_unsafe();
|
||||||
|
|
||||||
|
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||||
|
|
||||||
|
// Multiply and accumulate threadgroup elements
|
||||||
|
mma_op.mma(As, Bs);
|
||||||
|
|
||||||
|
// Prepare for next iteration
|
||||||
|
loader_a.next();
|
||||||
|
loader_b.next();
|
||||||
|
}
|
||||||
|
|
||||||
|
threadgroup_barrier(mem_flags::mem_none);
|
||||||
|
|
||||||
|
// Store results to device memory
|
||||||
|
mma_op.store_result(C, N);
|
||||||
|
}
|
||||||
|
};
|
||||||
536
mlx/backend/metal/kernels/gemm/gemm.h
Normal file
536
mlx/backend/metal/kernels/gemm/gemm.h
Normal file
@@ -0,0 +1,536 @@
|
|||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include <metal_simdgroup>
|
||||||
|
#include <metal_simdgroup_matrix>
|
||||||
|
#include <metal_stdlib>
|
||||||
|
|
||||||
|
#define MLX_MTL_CONST static constant constexpr const
|
||||||
|
|
||||||
|
using namespace metal;
|
||||||
|
|
||||||
|
///////////////////////////////////////////////////////////////////////////////
|
||||||
|
// Loading helper
|
||||||
|
///////////////////////////////////////////////////////////////////////////////
|
||||||
|
|
||||||
|
template <
|
||||||
|
typename T,
|
||||||
|
int BROWS,
|
||||||
|
int BCOLS,
|
||||||
|
int BK,
|
||||||
|
int vec_size,
|
||||||
|
int tgp_size,
|
||||||
|
bool transpose,
|
||||||
|
bool ldK,
|
||||||
|
int tgp_padding = 0>
|
||||||
|
struct BlockLoader {
|
||||||
|
// Destination dimensions
|
||||||
|
MLX_MTL_CONST int dst_fd = transpose ? BCOLS : BROWS;
|
||||||
|
MLX_MTL_CONST int dst_ld = (transpose ? BROWS : BCOLS) + tgp_padding;
|
||||||
|
MLX_MTL_CONST int n_vecs = (transpose ? BROWS : BCOLS) / vec_size;
|
||||||
|
|
||||||
|
// Stride along block row within the block
|
||||||
|
MLX_MTL_CONST int bstride = tgp_size / n_vecs;
|
||||||
|
|
||||||
|
// Leading dimension for src
|
||||||
|
const int src_ld;
|
||||||
|
// Stride along reduction axis between blocks
|
||||||
|
const int tstride;
|
||||||
|
|
||||||
|
// Thread location indices
|
||||||
|
const short thread_idx;
|
||||||
|
const short bi;
|
||||||
|
const short bj;
|
||||||
|
|
||||||
|
// threadgroup and device memory
|
||||||
|
threadgroup T* dst;
|
||||||
|
const device T* src;
|
||||||
|
|
||||||
|
/* Constructor */
|
||||||
|
METAL_FUNC BlockLoader(
|
||||||
|
const device T* src_,
|
||||||
|
const int src_ld_,
|
||||||
|
threadgroup T* dst_,
|
||||||
|
uint simd_group_id [[simdgroup_index_in_threadgroup]],
|
||||||
|
uint simd_lane_id [[thread_index_in_simdgroup]])
|
||||||
|
: src_ld(src_ld_),
|
||||||
|
tstride(
|
||||||
|
BK * ((int)(transpose ^ !ldK) * src_ld + (int)(transpose ^ ldK))),
|
||||||
|
thread_idx(simd_group_id * 32 + simd_lane_id),
|
||||||
|
bi(thread_idx / n_vecs),
|
||||||
|
bj(vec_size * (thread_idx % n_vecs)),
|
||||||
|
dst(dst_ + bi * dst_ld + bj),
|
||||||
|
src(src_ + bi * src_ld + bj) {}
|
||||||
|
|
||||||
|
/* Load from device memory into threadgroup memory - without bound checking */
|
||||||
|
METAL_FUNC void load_unsafe() const {
|
||||||
|
#pragma clang loop unroll(full)
|
||||||
|
for (short i = 0; i < dst_fd; i += bstride) {
|
||||||
|
#pragma clang loop unroll(full)
|
||||||
|
for (short j = 0; j < vec_size; j++) {
|
||||||
|
dst[i * dst_ld + j] = src[i * src_ld + j];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/* Load from device memory into threadgroup memory - with bound checking */
|
||||||
|
METAL_FUNC void load_safe(short2 src_tile_dim) const {
|
||||||
|
src_tile_dim = transpose ? src_tile_dim.yx : src_tile_dim.xy;
|
||||||
|
|
||||||
|
// Iterate over rows of block
|
||||||
|
#pragma clang loop unroll(full)
|
||||||
|
for (short i = 0; i < dst_fd; i += bstride) {
|
||||||
|
// Row is in bounds, we check against column
|
||||||
|
if ((bi + i) < src_tile_dim.y) {
|
||||||
|
// Use fast thread memory for bound checks
|
||||||
|
short tmp_idx[vec_size];
|
||||||
|
T tmp_val[vec_size];
|
||||||
|
|
||||||
|
// Make sure tmp_idx only contains valid indices
|
||||||
|
#pragma clang loop unroll(full)
|
||||||
|
for (short j = 0; j < vec_size; j++) {
|
||||||
|
tmp_idx[j] = bj + j < src_tile_dim.x ? j : 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Read all valid indcies into tmp_val
|
||||||
|
#pragma clang loop unroll(full)
|
||||||
|
for (short j = 0; j < vec_size; j++) {
|
||||||
|
tmp_val[j] = src[i * src_ld + tmp_idx[j]];
|
||||||
|
}
|
||||||
|
|
||||||
|
// Zero out uneeded values
|
||||||
|
#pragma clang loop unroll(full)
|
||||||
|
for (short j = 0; j < vec_size; j++) {
|
||||||
|
tmp_val[j] = bj + j < src_tile_dim.x ? tmp_val[j] : T(0);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Copy values to threadgroup memory
|
||||||
|
#pragma clang loop unroll(full)
|
||||||
|
for (short j = 0; j < vec_size; j++) {
|
||||||
|
dst[i * dst_ld + j] = tmp_val[j];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Row is out of bounds, we just fill tgp memory with zeros
|
||||||
|
else {
|
||||||
|
#pragma clang loop unroll(full)
|
||||||
|
for (short j = 0; j < vec_size; j++) {
|
||||||
|
dst[i * dst_ld + j] = T(0);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/* Iteration helper */
|
||||||
|
METAL_FUNC void next() {
|
||||||
|
src += tstride;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
///////////////////////////////////////////////////////////////////////////////
|
||||||
|
// Transforms
|
||||||
|
///////////////////////////////////////////////////////////////////////////////
|
||||||
|
|
||||||
|
template <typename OutT, typename InT>
|
||||||
|
struct TransformNone {
|
||||||
|
static METAL_FUNC OutT apply(InT x) {
|
||||||
|
return static_cast<OutT>(x);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
struct AccumHelper {
|
||||||
|
typedef float accum_type;
|
||||||
|
};
|
||||||
|
|
||||||
|
///////////////////////////////////////////////////////////////////////////////
|
||||||
|
// MMA helper
|
||||||
|
///////////////////////////////////////////////////////////////////////////////
|
||||||
|
|
||||||
|
template <
|
||||||
|
typename T,
|
||||||
|
int BM,
|
||||||
|
int BN,
|
||||||
|
int BK,
|
||||||
|
int WM,
|
||||||
|
int WN,
|
||||||
|
bool transpose_a,
|
||||||
|
bool transpose_b,
|
||||||
|
int tgp_padding_a = 0,
|
||||||
|
int tgp_padding_b = 0,
|
||||||
|
typename AccumType = typename AccumHelper<T>::accum_type,
|
||||||
|
typename Epilogue = TransformNone<T, AccumType>>
|
||||||
|
struct BlockMMA {
|
||||||
|
// Warp tile size along M
|
||||||
|
MLX_MTL_CONST int TM = BM / (WM * 8);
|
||||||
|
// Warp tile size along N
|
||||||
|
MLX_MTL_CONST int TN = BN / (WN * 8);
|
||||||
|
|
||||||
|
// Warp tile simdgroup matrix strides along M
|
||||||
|
MLX_MTL_CONST int TM_stride = 8 * WM;
|
||||||
|
// Warp tile simdgroup matrix strides along M
|
||||||
|
MLX_MTL_CONST int TN_stride = 8 * WN;
|
||||||
|
|
||||||
|
// Leading dimensions of threadgroup A, B blocks
|
||||||
|
MLX_MTL_CONST int lda_tgp = (transpose_a ? BM : BK) + tgp_padding_a;
|
||||||
|
MLX_MTL_CONST int ldb_tgp = (transpose_b ? BK : BN) + tgp_padding_b;
|
||||||
|
|
||||||
|
// Strides of A, B along reduction axis
|
||||||
|
MLX_MTL_CONST short simd_stride_a =
|
||||||
|
transpose_a ? TM_stride : TM_stride * lda_tgp;
|
||||||
|
MLX_MTL_CONST short simd_stride_b =
|
||||||
|
transpose_b ? TN_stride * ldb_tgp : TN_stride;
|
||||||
|
|
||||||
|
// Jump between elements
|
||||||
|
MLX_MTL_CONST short jump_a = transpose_a ? lda_tgp : 1;
|
||||||
|
MLX_MTL_CONST short jump_b = transpose_b ? ldb_tgp : 1;
|
||||||
|
|
||||||
|
// Offsets within threadgroup
|
||||||
|
const int tm;
|
||||||
|
const int tn;
|
||||||
|
|
||||||
|
// Simdgroup matrices
|
||||||
|
simdgroup_matrix<AccumType, 8, 8> Asimd[TM];
|
||||||
|
simdgroup_matrix<AccumType, 8, 8> Bsimd[TN];
|
||||||
|
simdgroup_matrix<AccumType, 8, 8> results[TM * TN] = {
|
||||||
|
simdgroup_matrix<AccumType, 8, 8>(0)};
|
||||||
|
|
||||||
|
short sm;
|
||||||
|
short sn;
|
||||||
|
|
||||||
|
/* Constructor */
|
||||||
|
METAL_FUNC BlockMMA(
|
||||||
|
uint simd_group_id [[simdgroup_index_in_threadgroup]],
|
||||||
|
uint simd_lane_id [[thread_index_in_simdgroup]])
|
||||||
|
: tm(8 * (simd_group_id / WN)), tn(8 * (simd_group_id % WN)) {
|
||||||
|
short qid = simd_lane_id / 4;
|
||||||
|
sm = (qid & 4) + (simd_lane_id / 2) % 4;
|
||||||
|
sn = (qid & 2) * 2 + (simd_lane_id % 2) * 2;
|
||||||
|
}
|
||||||
|
|
||||||
|
/* (BM, BK) X (BK, BN) multiply accumulate function */
|
||||||
|
METAL_FUNC void mma(const threadgroup T* As, const threadgroup T* Bs) {
|
||||||
|
// Iterate over BK in blocks of 8
|
||||||
|
#pragma clang loop unroll(full)
|
||||||
|
for (short kk = 0; kk < BK; kk += 8) {
|
||||||
|
short2 offset_a =
|
||||||
|
transpose_a ? short2(tm + sm, kk + sn) : short2(kk + sn, tm + sm);
|
||||||
|
short2 offset_b =
|
||||||
|
transpose_b ? short2(kk + sm, tn + sn) : short2(tn + sn, kk + sm);
|
||||||
|
|
||||||
|
const threadgroup T* As__ = As + offset_a.y * lda_tgp + offset_a.x;
|
||||||
|
const threadgroup T* Bs__ = Bs + offset_b.y * ldb_tgp + offset_b.x;
|
||||||
|
|
||||||
|
simdgroup_barrier(mem_flags::mem_none);
|
||||||
|
// Load elements from threadgroup A as simdgroup matrices
|
||||||
|
#pragma clang loop unroll(full)
|
||||||
|
for (short i = 0; i < TM; i++) {
|
||||||
|
Asimd[i].thread_elements()[0] = static_cast<AccumType>(As__[0]);
|
||||||
|
Asimd[i].thread_elements()[1] = static_cast<AccumType>(As__[jump_a]);
|
||||||
|
As__ += simd_stride_a;
|
||||||
|
}
|
||||||
|
|
||||||
|
simdgroup_barrier(mem_flags::mem_none);
|
||||||
|
// Load elements from threadgroup B as simdgroup matrices
|
||||||
|
#pragma clang loop unroll(full)
|
||||||
|
for (short j = 0; j < TN; j++) {
|
||||||
|
Bsimd[j].thread_elements()[0] = static_cast<AccumType>(Bs__[0]);
|
||||||
|
Bsimd[j].thread_elements()[1] = static_cast<AccumType>(Bs__[jump_b]);
|
||||||
|
Bs__ += simd_stride_b;
|
||||||
|
}
|
||||||
|
|
||||||
|
simdgroup_barrier(mem_flags::mem_none);
|
||||||
|
// Multiply and accumulate into resulr simdgroup matrices
|
||||||
|
#pragma clang loop unroll(full)
|
||||||
|
for (short i = 0; i < TM; i++) {
|
||||||
|
#pragma clang loop unroll(full)
|
||||||
|
for (short j = 0; j < TN; j++) {
|
||||||
|
simdgroup_multiply_accumulate(
|
||||||
|
results[i * TN + j], Asimd[i], Bsimd[j], results[i * TN + j]);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/* Store results from simdgroup_matrix results into device memory */
|
||||||
|
METAL_FUNC void store_result(device T* C, const int ldc) const {
|
||||||
|
#pragma clang loop unroll(full)
|
||||||
|
for (int i = 0; i < TM; i++) {
|
||||||
|
#pragma clang loop unroll(full)
|
||||||
|
for (int j = 0; j < TN; j++) {
|
||||||
|
C[(i * TM_stride + sm + tm) * ldc + j * TN_stride + tn + sn] =
|
||||||
|
Epilogue::apply(results[i * TN + j].thread_elements()[0]);
|
||||||
|
C[(i * TM_stride + sm + tm) * ldc + j * TN_stride + tn + sn + 1] =
|
||||||
|
Epilogue::apply(results[i * TN + j].thread_elements()[1]);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
METAL_FUNC void
|
||||||
|
store_result_safe(device T* C, const int ldc, short2 dst_tile_dims) const {
|
||||||
|
#pragma clang loop unroll(full)
|
||||||
|
for (int i = 0; i < TM; i++) {
|
||||||
|
if (tm + i * TM_stride + sm < dst_tile_dims.y) {
|
||||||
|
#pragma clang loop unroll(full)
|
||||||
|
for (int j = 0; j < TN; j++) {
|
||||||
|
if (tn + j * TN_stride + sn < dst_tile_dims.x) {
|
||||||
|
C[(tm + i * TM_stride + sm) * ldc + tn + j * TN_stride + sn] =
|
||||||
|
Epilogue::apply(results[i * TN + j].thread_elements()[0]);
|
||||||
|
}
|
||||||
|
|
||||||
|
if (tn + j * TN_stride + sn + 1 < dst_tile_dims.x) {
|
||||||
|
C[(tm + i * TM_stride + sm) * ldc + tn + j * TN_stride + sn + 1] =
|
||||||
|
Epilogue::apply(results[i * TN + j].thread_elements()[1]);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
///////////////////////////////////////////////////////////////////////////////
|
||||||
|
// GEMM kernels
|
||||||
|
///////////////////////////////////////////////////////////////////////////////
|
||||||
|
|
||||||
|
template <
|
||||||
|
typename T,
|
||||||
|
int BM,
|
||||||
|
int BN,
|
||||||
|
int BK,
|
||||||
|
int WM,
|
||||||
|
int WN,
|
||||||
|
bool transpose_a,
|
||||||
|
bool transpose_b,
|
||||||
|
bool MN_aligned,
|
||||||
|
bool K_aligned,
|
||||||
|
typename AccumType = typename AccumHelper<T>::accum_type,
|
||||||
|
typename Epilogue = TransformNone<T, AccumType>>
|
||||||
|
struct GEMMKernel {
|
||||||
|
MLX_MTL_CONST short tgp_padding_a = 16 / sizeof(T);
|
||||||
|
MLX_MTL_CONST short tgp_padding_b = 16 / sizeof(T);
|
||||||
|
MLX_MTL_CONST short tgp_mem_size_a =
|
||||||
|
transpose_a ? BK * (BM + tgp_padding_a) : BM * (BK + tgp_padding_a);
|
||||||
|
MLX_MTL_CONST short tgp_mem_size_b =
|
||||||
|
transpose_b ? BN * (BK + tgp_padding_b) : BK * (BN + tgp_padding_b);
|
||||||
|
MLX_MTL_CONST short tgp_mem_size = tgp_mem_size_a + tgp_mem_size_b;
|
||||||
|
|
||||||
|
MLX_MTL_CONST short tgp_size = WM * WN * 32;
|
||||||
|
MLX_MTL_CONST short vec_size = (BM == 64 && BN == 64) ? 8 : 4;
|
||||||
|
|
||||||
|
using loader_a_t = BlockLoader<
|
||||||
|
T,
|
||||||
|
BM,
|
||||||
|
BK,
|
||||||
|
BK,
|
||||||
|
vec_size,
|
||||||
|
tgp_size,
|
||||||
|
transpose_a,
|
||||||
|
true,
|
||||||
|
tgp_padding_a>;
|
||||||
|
using loader_b_t = BlockLoader<
|
||||||
|
T,
|
||||||
|
BK,
|
||||||
|
BN,
|
||||||
|
BK,
|
||||||
|
vec_size,
|
||||||
|
tgp_size,
|
||||||
|
transpose_b,
|
||||||
|
false,
|
||||||
|
tgp_padding_b>;
|
||||||
|
using mma_t = BlockMMA<
|
||||||
|
T,
|
||||||
|
BM,
|
||||||
|
BN,
|
||||||
|
BK,
|
||||||
|
WM,
|
||||||
|
WN,
|
||||||
|
transpose_a,
|
||||||
|
transpose_b,
|
||||||
|
tgp_padding_a,
|
||||||
|
tgp_padding_b,
|
||||||
|
AccumType,
|
||||||
|
Epilogue>;
|
||||||
|
|
||||||
|
/* Main kernel function */
|
||||||
|
static METAL_FUNC void run(
|
||||||
|
const device T* A [[buffer(0)]],
|
||||||
|
const device T* B [[buffer(1)]],
|
||||||
|
device T* C [[buffer(2)]],
|
||||||
|
const constant int& M [[buffer(3)]],
|
||||||
|
const constant int& N [[buffer(4)]],
|
||||||
|
const constant int& K [[buffer(5)]],
|
||||||
|
const constant int& batch_stride_a [[buffer(6)]],
|
||||||
|
const constant int& batch_stride_b [[buffer(7)]],
|
||||||
|
const constant int& batch_stride_c [[buffer(8)]],
|
||||||
|
threadgroup T* tgp_memory [[threadgroup(0)]],
|
||||||
|
uint simd_lane_id [[thread_index_in_simdgroup]],
|
||||||
|
uint simd_group_id [[simdgroup_index_in_threadgroup]],
|
||||||
|
uint3 tid [[threadgroup_position_in_grid]],
|
||||||
|
uint3 lid [[thread_position_in_threadgroup]]) {
|
||||||
|
// Pacifying compiler
|
||||||
|
(void)lid;
|
||||||
|
|
||||||
|
// Adjust for batch
|
||||||
|
A += batch_stride_a * tid.z;
|
||||||
|
B += batch_stride_b * tid.z;
|
||||||
|
C += batch_stride_c * tid.z;
|
||||||
|
|
||||||
|
// Adjust for transpose
|
||||||
|
const int lda_dev = transpose_a ? M : K;
|
||||||
|
const int ldb_dev = transpose_b ? K : N;
|
||||||
|
|
||||||
|
// Find block in A, B, C
|
||||||
|
const int c_row = tid.y * BM;
|
||||||
|
const int c_col = tid.x * BN;
|
||||||
|
|
||||||
|
A += transpose_a ? c_row : c_row * K;
|
||||||
|
B += transpose_b ? c_col * K : c_col;
|
||||||
|
C += c_row * N + c_col;
|
||||||
|
|
||||||
|
// Prepare threadgroup memory for loading
|
||||||
|
threadgroup T* As = tgp_memory;
|
||||||
|
threadgroup T* Bs = tgp_memory + tgp_mem_size_a;
|
||||||
|
|
||||||
|
// Prepare threadgroup loading operations
|
||||||
|
loader_a_t loader_a(A, lda_dev, As, simd_group_id, simd_lane_id);
|
||||||
|
loader_b_t loader_b(B, ldb_dev, Bs, simd_group_id, simd_lane_id);
|
||||||
|
|
||||||
|
// Prepare threadgroup mma operation
|
||||||
|
mma_t mma_op(simd_group_id, simd_lane_id);
|
||||||
|
|
||||||
|
///////////////////////////////////////////////////////////////////////////////
|
||||||
|
// MNK aligned loop
|
||||||
|
if (MN_aligned && K_aligned) {
|
||||||
|
for (int k = 0; k < K; k += BK) {
|
||||||
|
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||||
|
// Load elements into threadgroup
|
||||||
|
loader_a.load_unsafe();
|
||||||
|
loader_b.load_unsafe();
|
||||||
|
|
||||||
|
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||||
|
|
||||||
|
// Multiply and accumulate threadgroup elements
|
||||||
|
mma_op.mma(As, Bs);
|
||||||
|
|
||||||
|
// Prepare for next iteration
|
||||||
|
loader_a.next();
|
||||||
|
loader_b.next();
|
||||||
|
}
|
||||||
|
|
||||||
|
threadgroup_barrier(mem_flags::mem_none);
|
||||||
|
|
||||||
|
// Store results to device memory
|
||||||
|
mma_op.store_result(C, N);
|
||||||
|
return;
|
||||||
|
|
||||||
|
}
|
||||||
|
///////////////////////////////////////////////////////////////////////////////
|
||||||
|
// MN aligned, K unaligned loop
|
||||||
|
else if (MN_aligned && !K_aligned) {
|
||||||
|
// Main loop
|
||||||
|
int k = 0;
|
||||||
|
for (; k + BK <= K; k += BK) {
|
||||||
|
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||||
|
// Load elements into threadgroup
|
||||||
|
loader_a.load_unsafe();
|
||||||
|
loader_b.load_unsafe();
|
||||||
|
|
||||||
|
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||||
|
|
||||||
|
// Multiply and accumulate threadgroup elements
|
||||||
|
mma_op.mma(As, Bs);
|
||||||
|
|
||||||
|
// Prepare for next iteration
|
||||||
|
loader_a.next();
|
||||||
|
loader_b.next();
|
||||||
|
}
|
||||||
|
|
||||||
|
// Loop tail
|
||||||
|
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||||
|
|
||||||
|
loader_a.load_safe(short2(K - k, BM));
|
||||||
|
loader_b.load_safe(short2(BN, K - k));
|
||||||
|
|
||||||
|
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||||
|
|
||||||
|
mma_op.mma(As, Bs);
|
||||||
|
|
||||||
|
// Store results to device memory
|
||||||
|
mma_op.store_result(C, N);
|
||||||
|
return;
|
||||||
|
|
||||||
|
}
|
||||||
|
///////////////////////////////////////////////////////////////////////////////
|
||||||
|
// MNK unaligned loop
|
||||||
|
else { // Loop over K - unaligned case
|
||||||
|
|
||||||
|
short2 src_tile_dims(min(BN, N - c_col), min(BM, M - c_row));
|
||||||
|
|
||||||
|
if (src_tile_dims.y == BM && src_tile_dims.x == BN) {
|
||||||
|
int k = 0;
|
||||||
|
for (; k + BK <= K; k += BK) {
|
||||||
|
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||||
|
// Load elements into threadgroup
|
||||||
|
loader_a.load_unsafe();
|
||||||
|
loader_b.load_unsafe();
|
||||||
|
|
||||||
|
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||||
|
|
||||||
|
// Multiply and accumulate threadgroup elements
|
||||||
|
mma_op.mma(As, Bs);
|
||||||
|
|
||||||
|
// Prepare for next iteration
|
||||||
|
loader_a.next();
|
||||||
|
loader_b.next();
|
||||||
|
}
|
||||||
|
|
||||||
|
threadgroup_barrier(mem_flags::mem_none);
|
||||||
|
|
||||||
|
if (k < K) {
|
||||||
|
loader_a.load_safe(short2(K - k, BM));
|
||||||
|
loader_b.load_safe(short2(BN, K - k));
|
||||||
|
|
||||||
|
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||||
|
|
||||||
|
mma_op.mma(As, Bs);
|
||||||
|
}
|
||||||
|
|
||||||
|
mma_op.store_result(C, N);
|
||||||
|
return;
|
||||||
|
|
||||||
|
} else {
|
||||||
|
int k = 0;
|
||||||
|
for (; k + BK <= K; k += BK) {
|
||||||
|
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||||
|
// Load elements into threadgroup
|
||||||
|
loader_a.load_safe(short2(BK, src_tile_dims.y));
|
||||||
|
loader_b.load_safe(short2(src_tile_dims.x, BK));
|
||||||
|
|
||||||
|
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||||
|
|
||||||
|
// Multiply and accumulate threadgroup elements
|
||||||
|
mma_op.mma(As, Bs);
|
||||||
|
|
||||||
|
// Prepare for next iteration
|
||||||
|
loader_a.next();
|
||||||
|
loader_b.next();
|
||||||
|
}
|
||||||
|
|
||||||
|
threadgroup_barrier(mem_flags::mem_none);
|
||||||
|
|
||||||
|
if (k < K) {
|
||||||
|
loader_a.load_safe(short2(K - k, src_tile_dims.y));
|
||||||
|
loader_b.load_safe(short2(src_tile_dims.x, K - k));
|
||||||
|
|
||||||
|
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||||
|
|
||||||
|
mma_op.mma(As, Bs);
|
||||||
|
}
|
||||||
|
|
||||||
|
threadgroup_barrier(mem_flags::mem_none);
|
||||||
|
mma_op.store_result_safe(C, N, src_tile_dims);
|
||||||
|
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
};
|
||||||
302
mlx/backend/metal/kernels/gemv.metal
Normal file
302
mlx/backend/metal/kernels/gemv.metal
Normal file
@@ -0,0 +1,302 @@
|
|||||||
|
#include <metal_stdlib>
|
||||||
|
#include <metal_simdgroup>
|
||||||
|
|
||||||
|
#include "mlx/backend/metal/kernels/defines.h"
|
||||||
|
#include "mlx/backend/metal/kernels/bf16.h"
|
||||||
|
|
||||||
|
using namespace metal;
|
||||||
|
|
||||||
|
///////////////////////////////////////////////////////////////////////////////
|
||||||
|
/// Matrix vector multiplication
|
||||||
|
///////////////////////////////////////////////////////////////////////////////
|
||||||
|
|
||||||
|
static constant constexpr int SIMD_SIZE = 32;
|
||||||
|
|
||||||
|
template <typename T,
|
||||||
|
const int BM, /* Threadgroup rows (in threads) */
|
||||||
|
const int BN, /* Threadgroup cols (in threads) */
|
||||||
|
const int TM, /* Thread rows (in elements) */
|
||||||
|
const int TN> /* Thread cols (in elements) */
|
||||||
|
[[kernel]] void gemv(
|
||||||
|
const device T* mat [[buffer(0)]],
|
||||||
|
const device T* in_vec [[buffer(1)]],
|
||||||
|
device T* out_vec [[buffer(2)]],
|
||||||
|
const constant int& in_vec_size [[buffer(3)]],
|
||||||
|
const constant int& out_vec_size [[buffer(4)]],
|
||||||
|
const constant int& vector_batch_stride [[buffer(5)]],
|
||||||
|
const constant int& matrix_batch_stride [[buffer(6)]],
|
||||||
|
uint3 tid [[threadgroup_position_in_grid]],
|
||||||
|
uint simd_gid [[simdgroup_index_in_threadgroup]],
|
||||||
|
uint simd_lid [[thread_index_in_simdgroup]]) {
|
||||||
|
|
||||||
|
static_assert(BN == SIMD_SIZE, "gemv block must have a width of SIMD_SIZE");
|
||||||
|
|
||||||
|
// - The matrix of size (M = out_vec_size, N = in_vec_size) is divided up
|
||||||
|
// into blocks of (BM * TM, BN * TN) divided amoung threadgroups
|
||||||
|
// - Every thread works on a block of (TM, TN)
|
||||||
|
// - We assume each thead group is launched with (BN, BM, 1) threads
|
||||||
|
//
|
||||||
|
// 1. A thread loads TN elements each from mat along TM contiguous rows
|
||||||
|
// and the corresponding scalar from the vector
|
||||||
|
// 2. The thread then multiplies and adds to accumulate its local result for the block
|
||||||
|
// 3. At the end, each thread has accumulated results over all blocks across the rows
|
||||||
|
// These are then summed up across the threadgroup
|
||||||
|
// 4. Each threadgroup writes its accumulated BN * TN outputs
|
||||||
|
//
|
||||||
|
// Edge case handling:
|
||||||
|
// - The threadgroup with the largest tid will have blocks that exceed the matrix
|
||||||
|
// * The blocks that start outside the matrix are never read (thread results remain zero)
|
||||||
|
// * The last thread that partialy overlaps with the matrix is shifted inwards
|
||||||
|
// such that the thread block fits exactly in the matrix
|
||||||
|
|
||||||
|
// Update batch offsets
|
||||||
|
in_vec += tid.z * vector_batch_stride;
|
||||||
|
mat += tid.z * matrix_batch_stride;
|
||||||
|
out_vec += tid.z * out_vec_size;
|
||||||
|
|
||||||
|
// Threadgroup in_vec cache
|
||||||
|
threadgroup T in_vec_block[BN][TN * 2];
|
||||||
|
|
||||||
|
// Thread local accumulation results
|
||||||
|
thread T result[TM] = {0};
|
||||||
|
thread T inter[TN];
|
||||||
|
thread T v_coeff[TN];
|
||||||
|
|
||||||
|
// Block position
|
||||||
|
int out_row = (tid.x * BM + simd_gid) * TM;
|
||||||
|
|
||||||
|
// Exit simdgroup if rows out of bound
|
||||||
|
if(out_row >= out_vec_size)
|
||||||
|
return;
|
||||||
|
|
||||||
|
// Adjust tail simdgroup to ensure in bound reads
|
||||||
|
out_row = out_row + TM <= out_vec_size ? out_row : out_vec_size - TM;
|
||||||
|
|
||||||
|
// Advance matrix
|
||||||
|
mat += out_row * in_vec_size;
|
||||||
|
|
||||||
|
// Loop over in_vec in blocks of BN * TN
|
||||||
|
for(int bn = simd_lid * TN; bn < in_vec_size; bn += BN * TN) {
|
||||||
|
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||||
|
// Prefetch in_vector for threadgroup use
|
||||||
|
if(simd_gid == 0) {
|
||||||
|
// Main load loop
|
||||||
|
if(bn + TN <= in_vec_size) {
|
||||||
|
#pragma clang loop unroll(full)
|
||||||
|
for(int tn = 0; tn < TN; tn++) {
|
||||||
|
in_vec_block[simd_lid][tn] = in_vec[bn + tn];
|
||||||
|
}
|
||||||
|
} else { // Edgecase
|
||||||
|
#pragma clang loop unroll(full)
|
||||||
|
for(int tn = 0; tn < TN; tn++) {
|
||||||
|
in_vec_block[simd_lid][tn] = bn + tn < in_vec_size ? in_vec[bn + tn] : 0;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||||
|
|
||||||
|
// Load for all rows
|
||||||
|
#pragma clang loop unroll(full)
|
||||||
|
for(int tn = 0; tn < TN; tn++) {
|
||||||
|
v_coeff[tn] = in_vec_block[simd_lid][tn];
|
||||||
|
}
|
||||||
|
|
||||||
|
// Per thread work loop
|
||||||
|
#pragma clang loop unroll(full)
|
||||||
|
for(int tm = 0; tm < TM; tm++) {
|
||||||
|
// Load for the row
|
||||||
|
for(int tn = 0; tn < TN; tn++) {
|
||||||
|
inter[tn] = mat[tm * in_vec_size + bn + tn];
|
||||||
|
}
|
||||||
|
|
||||||
|
// Accumulate results
|
||||||
|
for(int tn = 0; tn < TN; tn++) {
|
||||||
|
result[tm] += inter[tn] * v_coeff[tn];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Simdgroup accumulations
|
||||||
|
#pragma clang loop unroll(full)
|
||||||
|
for(int tm = 0; tm < TM; tm++) {
|
||||||
|
result[tm] = simd_sum(result[tm]);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Write outputs
|
||||||
|
if(simd_lid == 0) {
|
||||||
|
#pragma clang loop unroll(full)
|
||||||
|
for(int tm = 0; tm < TM; tm++) {
|
||||||
|
out_vec[out_row + tm] = result[tm];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
#define instantiate_gemv(name, itype, bm, bn, tm, tn) \
|
||||||
|
template [[host_name("gemv_" #name "_bm" #bm "_bn" #bn "_tm" #tm "_tn" #tn)]] \
|
||||||
|
[[kernel]] void gemv<itype, bm, bn, tm, tn>( \
|
||||||
|
const device itype* mat [[buffer(0)]], \
|
||||||
|
const device itype* vec [[buffer(1)]], \
|
||||||
|
device itype* out [[buffer(2)]], \
|
||||||
|
const constant int& in_vec_size [[buffer(3)]], \
|
||||||
|
const constant int& out_vec_size [[buffer(4)]], \
|
||||||
|
const constant int& vector_batch_stride [[buffer(5)]], \
|
||||||
|
const constant int& matrix_batch_stride [[buffer(6)]], \
|
||||||
|
uint3 tid [[threadgroup_position_in_grid]], \
|
||||||
|
uint simd_gid [[simdgroup_index_in_threadgroup]], \
|
||||||
|
uint simd_lid [[thread_index_in_simdgroup]]);
|
||||||
|
|
||||||
|
#define instantiate_gemv_blocks(name, itype) \
|
||||||
|
instantiate_gemv(name, itype, 4, 32, 1, 4) \
|
||||||
|
instantiate_gemv(name, itype, 4, 32, 4, 4) \
|
||||||
|
instantiate_gemv(name, itype, 8, 32, 4, 4)
|
||||||
|
|
||||||
|
instantiate_gemv_blocks(float32, float)
|
||||||
|
instantiate_gemv_blocks(float16, half)
|
||||||
|
instantiate_gemv_blocks(bfloat16, bfloat16_t)
|
||||||
|
|
||||||
|
///////////////////////////////////////////////////////////////////////////////
|
||||||
|
/// Vector matrix multiplication
|
||||||
|
///////////////////////////////////////////////////////////////////////////////
|
||||||
|
|
||||||
|
template <typename T,
|
||||||
|
const int BM, /* Threadgroup rows (in threads) */
|
||||||
|
const int BN, /* Threadgroup cols (in threads) */
|
||||||
|
const int TM, /* Thread rows (in elements) */
|
||||||
|
const int TN> /* Thread cols (in elements) */
|
||||||
|
[[kernel]] void gemv_t(
|
||||||
|
const device T* mat [[buffer(0)]],
|
||||||
|
const device T* in_vec [[buffer(1)]],
|
||||||
|
device T* out_vec [[buffer(2)]],
|
||||||
|
const constant int& in_vec_size [[buffer(3)]],
|
||||||
|
const constant int& out_vec_size [[buffer(4)]],
|
||||||
|
const constant int& vector_batch_stride [[buffer(5)]],
|
||||||
|
const constant int& matrix_batch_stride [[buffer(6)]],
|
||||||
|
uint3 tid [[threadgroup_position_in_grid]],
|
||||||
|
uint3 lid [[thread_position_in_threadgroup]]) {
|
||||||
|
|
||||||
|
// - The matrix of size (M = in_vec_size, N = out_vec_size) is divided up
|
||||||
|
// into blocks of (BM * TM, BN * TN) divided amoung threadgroups
|
||||||
|
// - Every thread works on a block of (TM, TN)
|
||||||
|
// - We assume each thead group is launched with (BN, BM, 1) threads
|
||||||
|
//
|
||||||
|
// 1. A thread loads TN elements each from mat along TM contiguous rows
|
||||||
|
// and the corresponding scalar from the vector
|
||||||
|
// 2. The thread then multiplies and adds to accumulate its local result for the block
|
||||||
|
// 3. At the end, each thread has accumulated results over all blocks across the rows
|
||||||
|
// These are then summed up across the threadgroup
|
||||||
|
// 4. Each threadgroup writes its accumulated BN * TN outputs
|
||||||
|
//
|
||||||
|
// Edge case handling:
|
||||||
|
// - The threadgroup with the largest tid will have blocks that exceed the matrix
|
||||||
|
// * The blocks that start outside the matrix are never read (thread results remain zero)
|
||||||
|
// * The last thread that partialy overlaps with the matrix is shifted inwards
|
||||||
|
// such that the thread block fits exactly in the matrix
|
||||||
|
|
||||||
|
// Update batch offsets
|
||||||
|
in_vec += tid.z * vector_batch_stride;
|
||||||
|
mat += tid.z * matrix_batch_stride;
|
||||||
|
out_vec += tid.z * out_vec_size;
|
||||||
|
|
||||||
|
// Thread local accumulation results
|
||||||
|
T result[TN] = {0};
|
||||||
|
T inter[TN];
|
||||||
|
T v_coeff[TM];
|
||||||
|
|
||||||
|
// Threadgroup accumulation results
|
||||||
|
threadgroup T tgp_results[BN][BM][TM];
|
||||||
|
|
||||||
|
int out_col = (tid.x * BN + lid.x) * TN;
|
||||||
|
int in_row = lid.y * TM;
|
||||||
|
|
||||||
|
// Edgecase handling
|
||||||
|
if (out_col < out_vec_size) {
|
||||||
|
// Edgecase handling
|
||||||
|
out_col = out_col + TN < out_vec_size ? out_col : out_vec_size - TN;
|
||||||
|
|
||||||
|
// Per thread accumulation main loop
|
||||||
|
int bm = in_row;
|
||||||
|
for(; bm < in_vec_size; bm += BM * TM) {
|
||||||
|
// Adding a threadgroup_barrier improves performance slightly
|
||||||
|
// This is possibly it may help exploit cache better
|
||||||
|
threadgroup_barrier(mem_flags::mem_none);
|
||||||
|
|
||||||
|
if(bm + TM <= in_vec_size) {
|
||||||
|
|
||||||
|
#pragma clang loop unroll(full)
|
||||||
|
for(int tm = 0; tm < TM; tm++) {
|
||||||
|
v_coeff[tm] = in_vec[bm + tm];
|
||||||
|
}
|
||||||
|
|
||||||
|
#pragma clang loop unroll(full)
|
||||||
|
for(int tm = 0; tm < TM; tm++) {
|
||||||
|
for(int tn = 0; tn < TN; tn++) {
|
||||||
|
inter[tn] = mat[(bm + tm) * out_vec_size + out_col + tn];
|
||||||
|
}
|
||||||
|
for(int tn = 0; tn < TN; tn++) {
|
||||||
|
result[tn] += v_coeff[tm] * inter[tn];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
} else { // Edgecase handling
|
||||||
|
for(int tm = 0; bm + tm < in_vec_size; tm++) {
|
||||||
|
v_coeff[tm] = in_vec[bm + tm];
|
||||||
|
|
||||||
|
for(int tn = 0; tn < TN; tn++) {
|
||||||
|
inter[tn] = mat[(bm + tm) * out_vec_size + out_col + tn];
|
||||||
|
}
|
||||||
|
for(int tn = 0; tn < TN; tn++) {
|
||||||
|
result[tn] += v_coeff[tm] * inter[tn];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
// Threadgroup collection
|
||||||
|
for(int i = 0; i < TN; i++) {
|
||||||
|
tgp_results[lid.x][lid.y][i] = result[i];
|
||||||
|
}
|
||||||
|
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||||
|
|
||||||
|
if(lid.y == 0 && out_col < out_vec_size) {
|
||||||
|
// Threadgroup accumulation
|
||||||
|
#pragma clang loop unroll(full)
|
||||||
|
for(int i = 1; i < BM; i++) {
|
||||||
|
for(int j = 0; j < TN; j++) {
|
||||||
|
result[j] += tgp_results[lid.x][i][j];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
for(int j = 0; j < TN; j++) {
|
||||||
|
out_vec[out_col + j] = result[j];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
#define instantiate_gemv_t(name, itype, bm, bn, tm, tn) \
|
||||||
|
template [[host_name("gemv_t_" #name "_bm" #bm "_bn" #bn "_tm" #tm "_tn" #tn)]] \
|
||||||
|
[[kernel]] void gemv_t<itype, bm, bn, tm, tn>( \
|
||||||
|
const device itype* mat [[buffer(0)]], \
|
||||||
|
const device itype* vec [[buffer(1)]], \
|
||||||
|
device itype* out [[buffer(2)]], \
|
||||||
|
const constant int& in_vec_size [[buffer(3)]], \
|
||||||
|
const constant int& out_vec_size [[buffer(4)]], \
|
||||||
|
const constant int& vector_batch_stride [[buffer(5)]], \
|
||||||
|
const constant int& matrix_batch_stride [[buffer(6)]], \
|
||||||
|
uint3 tid [[threadgroup_position_in_grid]], \
|
||||||
|
uint3 lid [[thread_position_in_threadgroup]]);
|
||||||
|
|
||||||
|
#define instantiate_gemv_t_blocks(name, itype) \
|
||||||
|
instantiate_gemv_t(name, itype, 8, 8, 4, 1) \
|
||||||
|
instantiate_gemv_t(name, itype, 8, 8, 4, 4) \
|
||||||
|
instantiate_gemv_t(name, itype, 8, 16, 4, 4) \
|
||||||
|
instantiate_gemv_t(name, itype, 8, 32, 4, 4) \
|
||||||
|
instantiate_gemv_t(name, itype, 8, 64, 4, 4) \
|
||||||
|
instantiate_gemv_t(name, itype, 8, 128, 4, 4)
|
||||||
|
|
||||||
|
instantiate_gemv_t_blocks(float32, float)
|
||||||
|
instantiate_gemv_t_blocks(float16, half)
|
||||||
|
instantiate_gemv_t_blocks(bfloat16, bfloat16_t)
|
||||||
226
mlx/backend/metal/kernels/softmax.metal
Normal file
226
mlx/backend/metal/kernels/softmax.metal
Normal file
@@ -0,0 +1,226 @@
|
|||||||
|
#include <metal_atomic>
|
||||||
|
#include <metal_common>
|
||||||
|
#include <metal_simdgroup>
|
||||||
|
|
||||||
|
#include "mlx/backend/metal/kernels/bf16.h"
|
||||||
|
#include "mlx/backend/metal/kernels/defines.h"
|
||||||
|
#include "mlx/backend/metal/kernels/utils.h"
|
||||||
|
|
||||||
|
using namespace metal;
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
inline T softmax_exp(T x) {
|
||||||
|
// Softmax doesn't need high precision exponential cause it is gonna be x
|
||||||
|
// will be in (-oo, 0] anyway and subsequently it will be divided by
|
||||||
|
// sum(exp(x_i)).
|
||||||
|
return fast::exp(x);
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename T, int N_READS = SOFTMAX_N_READS>
|
||||||
|
[[kernel]] void softmax_single_row(
|
||||||
|
const device T* in,
|
||||||
|
device T* out,
|
||||||
|
constant int& axis_size,
|
||||||
|
threadgroup T* local_max [[threadgroup(0)]],
|
||||||
|
threadgroup T* local_normalizer [[threadgroup(1)]],
|
||||||
|
uint gid [[threadgroup_position_in_grid]],
|
||||||
|
uint _lid [[thread_position_in_threadgroup]],
|
||||||
|
uint simd_lane_id [[thread_index_in_simdgroup]],
|
||||||
|
uint simd_group_id [[simdgroup_index_in_threadgroup]]) {
|
||||||
|
int lid = _lid;
|
||||||
|
|
||||||
|
T ld[N_READS];
|
||||||
|
|
||||||
|
in += gid * axis_size + lid * N_READS;
|
||||||
|
if (lid * N_READS + N_READS <= axis_size) {
|
||||||
|
for (int i=0; i<N_READS; i++) {
|
||||||
|
ld[i] = in[i];
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
for (int i = 0; i < N_READS; i++) {
|
||||||
|
ld[i] =
|
||||||
|
((lid * N_READS + i) < axis_size) ? in[i] : T(Limits<T>::finite_min);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if (simd_group_id == 0) {
|
||||||
|
local_max[simd_lane_id] = Limits<T>::finite_min;
|
||||||
|
local_normalizer[simd_lane_id] = 0;
|
||||||
|
}
|
||||||
|
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||||
|
|
||||||
|
// Get the max
|
||||||
|
T maxval = Limits<T>::finite_min;
|
||||||
|
for (int i = 0; i < N_READS; i++) {
|
||||||
|
maxval = (maxval < ld[i]) ? ld[i] : maxval;
|
||||||
|
}
|
||||||
|
maxval = simd_max(maxval);
|
||||||
|
if (simd_lane_id == 0) {
|
||||||
|
local_max[simd_group_id] = maxval;
|
||||||
|
}
|
||||||
|
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||||
|
if (simd_group_id == 0) {
|
||||||
|
maxval = simd_max(local_max[simd_lane_id]);
|
||||||
|
if (simd_lane_id == 0) {
|
||||||
|
local_max[0] = maxval;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||||
|
maxval = local_max[0];
|
||||||
|
|
||||||
|
// Compute exp(x_i - maxval) and store the partial sums in local_normalizer
|
||||||
|
T normalizer = 0;
|
||||||
|
for (int i = 0; i < N_READS; i++) {
|
||||||
|
T exp_x = softmax_exp(ld[i] - maxval);
|
||||||
|
ld[i] = exp_x;
|
||||||
|
normalizer += exp_x;
|
||||||
|
}
|
||||||
|
normalizer = simd_sum(normalizer);
|
||||||
|
if (simd_lane_id == 0) {
|
||||||
|
local_normalizer[simd_group_id] = normalizer;
|
||||||
|
}
|
||||||
|
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||||
|
if (simd_group_id == 0) {
|
||||||
|
normalizer = simd_sum(local_normalizer[simd_lane_id]);
|
||||||
|
if (simd_lane_id == 0) {
|
||||||
|
local_normalizer[0] = normalizer;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||||
|
normalizer = 1 / local_normalizer[0];
|
||||||
|
|
||||||
|
// Normalize and write to the output
|
||||||
|
out += gid * axis_size + lid * N_READS;
|
||||||
|
if (lid * N_READS + N_READS <= axis_size) {
|
||||||
|
for (int i=0; i<N_READS; i++) {
|
||||||
|
out[i] = ld[i] * normalizer;
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
for (int i = 0; i < N_READS; i++) {
|
||||||
|
if ((lid * N_READS + i) < axis_size) {
|
||||||
|
out[i] = ld[i] * normalizer;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename T, int N_READS = SOFTMAX_N_READS>
|
||||||
|
[[kernel]] void softmax_looped(
|
||||||
|
const device T* in,
|
||||||
|
device T* out,
|
||||||
|
constant int& axis_size,
|
||||||
|
threadgroup T* local_max [[threadgroup(0)]],
|
||||||
|
threadgroup T* local_normalizer [[threadgroup(1)]],
|
||||||
|
uint gid [[threadgroup_position_in_grid]],
|
||||||
|
uint lid [[thread_position_in_threadgroup]],
|
||||||
|
uint lsize [[threads_per_threadgroup]],
|
||||||
|
uint simd_lane_id [[thread_index_in_simdgroup]],
|
||||||
|
uint simd_group_id [[simdgroup_index_in_threadgroup]]) {
|
||||||
|
in += gid * axis_size;
|
||||||
|
|
||||||
|
// Get the max and the normalizer in one go
|
||||||
|
T prevmax;
|
||||||
|
T maxval = Limits<T>::finite_min;
|
||||||
|
T normalizer = 0;
|
||||||
|
for (int r = 0; r < static_cast<int>(ceildiv(axis_size, N_READS * lsize));
|
||||||
|
r++) {
|
||||||
|
int offset = r * lsize * N_READS + lid * N_READS;
|
||||||
|
T vals[N_READS];
|
||||||
|
if (offset + N_READS <= axis_size) {
|
||||||
|
for (int i = 0; i < N_READS; i++) {
|
||||||
|
vals[i] = in[offset + i];
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
for (int i = 0; i < N_READS; i++) {
|
||||||
|
vals[i] =
|
||||||
|
(offset + i < axis_size) ? in[offset + i] : T(Limits<T>::finite_min);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
prevmax = maxval;
|
||||||
|
for (int i = 0; i < N_READS; i++) {
|
||||||
|
maxval = (maxval < vals[i]) ? vals[i] : maxval;
|
||||||
|
}
|
||||||
|
normalizer *= softmax_exp(prevmax - maxval);
|
||||||
|
for (int i = 0; i < N_READS; i++) {
|
||||||
|
normalizer += softmax_exp(vals[i] - maxval);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// Now we got partial normalizer of N_READS * ceildiv(axis_size, N_READS *
|
||||||
|
// lsize) parts. We need to combine them.
|
||||||
|
// 1. We start by finding the max across simd groups
|
||||||
|
// 2. We then change the partial normalizers to account for a possible
|
||||||
|
// change in max
|
||||||
|
// 3. We sum all normalizers
|
||||||
|
prevmax = maxval;
|
||||||
|
maxval = simd_max(maxval);
|
||||||
|
normalizer *= softmax_exp(prevmax - maxval);
|
||||||
|
normalizer = simd_sum(normalizer);
|
||||||
|
|
||||||
|
// Now the normalizer and max value is correct for each simdgroup. We write
|
||||||
|
// them shared memory and combine them.
|
||||||
|
prevmax = maxval;
|
||||||
|
if (simd_lane_id == 0) {
|
||||||
|
local_max[simd_group_id] = maxval;
|
||||||
|
}
|
||||||
|
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||||
|
maxval = simd_max(local_max[simd_lane_id]);
|
||||||
|
normalizer *= softmax_exp(prevmax - maxval);
|
||||||
|
if (simd_lane_id == 0) {
|
||||||
|
local_normalizer[simd_group_id] = normalizer;
|
||||||
|
}
|
||||||
|
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||||
|
normalizer = simd_sum(local_normalizer[simd_lane_id]);
|
||||||
|
normalizer = 1 / normalizer;
|
||||||
|
|
||||||
|
// Finally given the normalizer and max value we can directly write the
|
||||||
|
// softmax output
|
||||||
|
out += gid * axis_size;
|
||||||
|
for (int r = 0; r < static_cast<int>(ceildiv(axis_size, N_READS * lsize));
|
||||||
|
r++) {
|
||||||
|
int offset = r * lsize * N_READS + lid * N_READS;
|
||||||
|
if (offset + N_READS <= axis_size) {
|
||||||
|
for (int i=0; i<N_READS; i++) {
|
||||||
|
out[offset + i] = softmax_exp(in[offset + i] - maxval) * normalizer;
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
for (int i = 0; i < N_READS; i++) {
|
||||||
|
if (offset + i < axis_size) {
|
||||||
|
out[offset + i] = softmax_exp(in[offset + i] - maxval) * normalizer;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#define instantiate_softmax_single_row(name, itype) \
|
||||||
|
template [[host_name("softmax_" #name)]] [[kernel]] void \
|
||||||
|
softmax_single_row<itype>( \
|
||||||
|
const device itype* in, \
|
||||||
|
device itype* out, \
|
||||||
|
constant int& axis_size, \
|
||||||
|
threadgroup itype* local_max [[threadgroup(0)]], \
|
||||||
|
threadgroup itype* local_normalizer [[threadgroup(1)]], \
|
||||||
|
uint gid [[thread_position_in_grid]], \
|
||||||
|
uint _lid [[thread_position_in_threadgroup]], \
|
||||||
|
uint simd_lane_id [[thread_index_in_simdgroup]], \
|
||||||
|
uint simd_group_id [[simdgroup_index_in_threadgroup]]);
|
||||||
|
|
||||||
|
#define instantiate_softmax_looped(name, itype) \
|
||||||
|
template [[host_name("softmax_looped_" #name)]] [[kernel]] void \
|
||||||
|
softmax_looped<itype>( \
|
||||||
|
const device itype* in, \
|
||||||
|
device itype* out, \
|
||||||
|
constant int& axis_size, \
|
||||||
|
threadgroup itype* local_max [[threadgroup(0)]], \
|
||||||
|
threadgroup itype* local_normalizer [[threadgroup(1)]], \
|
||||||
|
uint gid [[threadgroup_position_in_grid]], \
|
||||||
|
uint lid [[thread_position_in_threadgroup]], \
|
||||||
|
uint lsize [[threads_per_threadgroup]], \
|
||||||
|
uint simd_lane_id [[thread_index_in_simdgroup]], \
|
||||||
|
uint simd_group_id [[simdgroup_index_in_threadgroup]]);
|
||||||
|
|
||||||
|
#define instantiate_softmax(name, itype) \
|
||||||
|
instantiate_softmax_single_row(name, itype) \
|
||||||
|
instantiate_softmax_looped(name, itype)
|
||||||
|
|
||||||
|
instantiate_softmax(float32, float) instantiate_softmax(float16, half)
|
||||||
|
instantiate_softmax(bfloat16, bfloat16_t)
|
||||||
818
mlx/backend/metal/kernels/sort.metal
Normal file
818
mlx/backend/metal/kernels/sort.metal
Normal file
@@ -0,0 +1,818 @@
|
|||||||
|
#include <metal_stdlib>
|
||||||
|
|
||||||
|
#include "mlx/backend/metal/kernels/bf16.h"
|
||||||
|
#include "mlx/backend/metal/kernels/defines.h"
|
||||||
|
#include "mlx/backend/metal/kernels/utils.h"
|
||||||
|
|
||||||
|
#define MLX_MTL_CONST static constant constexpr const
|
||||||
|
#define MLX_MTL_LOOP_UNROLL _Pragma("clang loop unroll(full)")
|
||||||
|
|
||||||
|
using namespace metal;\
|
||||||
|
|
||||||
|
// Based on GPU merge sort algorithm at https://github.com/NVIDIA/cccl/tree/main/cub/cub
|
||||||
|
|
||||||
|
///////////////////////////////////////////////////////////////////////////////
|
||||||
|
// Thread-level sort
|
||||||
|
///////////////////////////////////////////////////////////////////////////////
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
METAL_FUNC void thread_swap(thread T& a, thread T& b) {
|
||||||
|
T w = a;
|
||||||
|
a = b;
|
||||||
|
b = w;
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
struct LessThan {
|
||||||
|
static constexpr constant T init = Limits<T>::max;
|
||||||
|
|
||||||
|
METAL_FUNC bool operator()(T a, T b) {
|
||||||
|
return a < b;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
template <
|
||||||
|
typename val_t,
|
||||||
|
typename idx_t,
|
||||||
|
bool ARG_SORT,
|
||||||
|
short N_PER_THREAD,
|
||||||
|
typename CompareOp>
|
||||||
|
struct ThreadSort {
|
||||||
|
static METAL_FUNC void sort(
|
||||||
|
thread val_t (&vals)[N_PER_THREAD],
|
||||||
|
thread idx_t (&idxs)[N_PER_THREAD]) {
|
||||||
|
|
||||||
|
CompareOp op;
|
||||||
|
|
||||||
|
MLX_MTL_LOOP_UNROLL
|
||||||
|
for(short i = 0; i < N_PER_THREAD; ++i) {
|
||||||
|
MLX_MTL_LOOP_UNROLL
|
||||||
|
for(short j = i & 1; j < N_PER_THREAD - 1; j += 2) {
|
||||||
|
if(op(vals[j + 1], vals[j])) {
|
||||||
|
thread_swap(vals[j + 1], vals[j]);
|
||||||
|
thread_swap(idxs[j + 1], idxs[j]);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
///////////////////////////////////////////////////////////////////////////////
|
||||||
|
// Threadgroup-level sort
|
||||||
|
///////////////////////////////////////////////////////////////////////////////
|
||||||
|
|
||||||
|
template <
|
||||||
|
typename val_t,
|
||||||
|
typename idx_t,
|
||||||
|
bool ARG_SORT,
|
||||||
|
short BLOCK_THREADS,
|
||||||
|
short N_PER_THREAD,
|
||||||
|
typename CompareOp>
|
||||||
|
struct BlockMergeSort {
|
||||||
|
using thread_sort_t = ThreadSort<val_t, idx_t, ARG_SORT, N_PER_THREAD, CompareOp>;
|
||||||
|
static METAL_FUNC int merge_partition(
|
||||||
|
const threadgroup val_t* As,
|
||||||
|
const threadgroup val_t* Bs,
|
||||||
|
short A_sz,
|
||||||
|
short B_sz,
|
||||||
|
short sort_md) {
|
||||||
|
|
||||||
|
CompareOp op;
|
||||||
|
|
||||||
|
short A_st = max(0, sort_md - B_sz);
|
||||||
|
short A_ed = min(sort_md, A_sz);
|
||||||
|
|
||||||
|
while(A_st < A_ed) {
|
||||||
|
short md = A_st + (A_ed - A_st) / 2;
|
||||||
|
auto a = As[md];
|
||||||
|
auto b = Bs[sort_md - 1 - md];
|
||||||
|
|
||||||
|
if(op(b, a)) {
|
||||||
|
A_ed = md;
|
||||||
|
} else {
|
||||||
|
A_st = md + 1;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return A_ed;
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
static METAL_FUNC void merge_step(
|
||||||
|
const threadgroup val_t* As,
|
||||||
|
const threadgroup val_t* Bs,
|
||||||
|
const threadgroup idx_t* As_idx,
|
||||||
|
const threadgroup idx_t* Bs_idx,
|
||||||
|
short A_sz,
|
||||||
|
short B_sz,
|
||||||
|
thread val_t (&vals)[N_PER_THREAD],
|
||||||
|
thread idx_t (&idxs)[N_PER_THREAD]) {
|
||||||
|
|
||||||
|
CompareOp op;
|
||||||
|
short a_idx = 0;
|
||||||
|
short b_idx = 0;
|
||||||
|
|
||||||
|
for(int i = 0; i < N_PER_THREAD; ++i) {
|
||||||
|
auto a = As[a_idx];
|
||||||
|
auto b = Bs[b_idx];
|
||||||
|
bool pred = (b_idx < B_sz) && (a_idx >= A_sz || op(b, a));
|
||||||
|
|
||||||
|
vals[i] = pred ? b : a;
|
||||||
|
idxs[i] = pred ? Bs_idx[b_idx] : As_idx[a_idx];
|
||||||
|
|
||||||
|
b_idx += short(pred);
|
||||||
|
a_idx += short(!pred);
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
static METAL_FUNC void sort(
|
||||||
|
threadgroup val_t* tgp_vals [[threadgroup(0)]],
|
||||||
|
threadgroup idx_t* tgp_idxs [[threadgroup(1)]],
|
||||||
|
int size_sorted_axis,
|
||||||
|
uint3 lid [[thread_position_in_threadgroup]]) {
|
||||||
|
|
||||||
|
// Get thread location
|
||||||
|
int idx = lid.x * N_PER_THREAD;
|
||||||
|
|
||||||
|
// Load from shared memory
|
||||||
|
thread val_t thread_vals[N_PER_THREAD];
|
||||||
|
thread idx_t thread_idxs[N_PER_THREAD];
|
||||||
|
for(int i = 0; i < N_PER_THREAD; ++i) {
|
||||||
|
thread_vals[i] = tgp_vals[idx + i];
|
||||||
|
if(ARG_SORT) {
|
||||||
|
thread_idxs[i] = tgp_idxs[idx + i];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Per thread sort
|
||||||
|
if(idx < size_sorted_axis) {
|
||||||
|
thread_sort_t::sort(thread_vals, thread_idxs);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Do merges using threadgroup memory
|
||||||
|
for (int merge_threads = 2; merge_threads <= BLOCK_THREADS; merge_threads *= 2) {
|
||||||
|
// Update threadgroup memory
|
||||||
|
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||||
|
for(int i = 0; i < N_PER_THREAD; ++i) {
|
||||||
|
tgp_vals[idx + i] = thread_vals[i];
|
||||||
|
if(ARG_SORT) {
|
||||||
|
tgp_idxs[idx + i] = thread_idxs[i];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||||
|
|
||||||
|
// Find location in merge step
|
||||||
|
int merge_group = lid.x / merge_threads;
|
||||||
|
int merge_lane = lid.x % merge_threads;
|
||||||
|
|
||||||
|
int sort_sz = N_PER_THREAD * merge_threads;
|
||||||
|
int sort_st = N_PER_THREAD * merge_threads * merge_group;
|
||||||
|
|
||||||
|
// As = tgp_vals[A_st:A_ed] is sorted
|
||||||
|
// Bs = tgp_vals[B_st:B_ed] is sorted
|
||||||
|
int A_st = sort_st;
|
||||||
|
int A_ed = sort_st + sort_sz / 2;
|
||||||
|
int B_st = sort_st + sort_sz / 2;
|
||||||
|
int B_ed = sort_st + sort_sz;
|
||||||
|
|
||||||
|
const threadgroup val_t* As = tgp_vals + A_st;
|
||||||
|
const threadgroup val_t* Bs = tgp_vals + B_st;
|
||||||
|
int A_sz = A_ed - A_st;
|
||||||
|
int B_sz = B_ed - B_st;
|
||||||
|
|
||||||
|
// Find a partition of merge elements
|
||||||
|
// Ci = merge(As[partition:], Bs[sort_md - partition:])
|
||||||
|
// of size N_PER_THREAD for each merge lane i
|
||||||
|
// C = [Ci] is sorted
|
||||||
|
int sort_md = N_PER_THREAD * merge_lane;
|
||||||
|
int partition = merge_partition(
|
||||||
|
As,
|
||||||
|
Bs,
|
||||||
|
A_sz,
|
||||||
|
B_sz,
|
||||||
|
sort_md);
|
||||||
|
|
||||||
|
As += partition;
|
||||||
|
Bs += sort_md - partition;
|
||||||
|
|
||||||
|
A_sz -= partition;
|
||||||
|
B_sz -= sort_md - partition;
|
||||||
|
|
||||||
|
const threadgroup idx_t* As_idx = ARG_SORT ? tgp_idxs + A_st + partition : nullptr;
|
||||||
|
const threadgroup idx_t* Bs_idx = ARG_SORT ? tgp_idxs + B_st + sort_md - partition : nullptr;
|
||||||
|
|
||||||
|
// Merge starting at the partition and store results in thread registers
|
||||||
|
merge_step(
|
||||||
|
As,
|
||||||
|
Bs,
|
||||||
|
As_idx,
|
||||||
|
Bs_idx,
|
||||||
|
A_sz,
|
||||||
|
B_sz,
|
||||||
|
thread_vals,
|
||||||
|
thread_idxs);
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
// Write out to shared memory
|
||||||
|
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||||
|
for(int i = 0; i < N_PER_THREAD; ++i) {
|
||||||
|
tgp_vals[idx + i] = thread_vals[i];
|
||||||
|
if(ARG_SORT) {
|
||||||
|
tgp_idxs[idx + i] = thread_idxs[i];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
///////////////////////////////////////////////////////////////////////////////
|
||||||
|
// Kernel sort
|
||||||
|
///////////////////////////////////////////////////////////////////////////////
|
||||||
|
|
||||||
|
template <
|
||||||
|
typename T,
|
||||||
|
typename U,
|
||||||
|
bool ARG_SORT,
|
||||||
|
short BLOCK_THREADS,
|
||||||
|
short N_PER_THREAD,
|
||||||
|
typename CompareOp = LessThan<T>>
|
||||||
|
struct KernelMergeSort {
|
||||||
|
using val_t = T;
|
||||||
|
using idx_t = uint;
|
||||||
|
using block_merge_sort_t = BlockMergeSort<
|
||||||
|
val_t,
|
||||||
|
idx_t,
|
||||||
|
ARG_SORT,
|
||||||
|
BLOCK_THREADS,
|
||||||
|
N_PER_THREAD,
|
||||||
|
CompareOp>;
|
||||||
|
|
||||||
|
MLX_MTL_CONST short N_PER_BLOCK = BLOCK_THREADS * N_PER_THREAD;
|
||||||
|
|
||||||
|
static METAL_FUNC void block_sort(
|
||||||
|
const device T* inp,
|
||||||
|
device U* out,
|
||||||
|
const constant int& size_sorted_axis,
|
||||||
|
const constant int& stride_sorted_axis,
|
||||||
|
const constant int& stride_segment_axis,
|
||||||
|
threadgroup val_t* tgp_vals,
|
||||||
|
threadgroup idx_t* tgp_idxs,
|
||||||
|
uint3 tid [[threadgroup_position_in_grid]],
|
||||||
|
uint3 lid [[thread_position_in_threadgroup]]) {
|
||||||
|
|
||||||
|
// tid.y tells us the segment index
|
||||||
|
inp += tid.y * stride_segment_axis;
|
||||||
|
out += tid.y * stride_segment_axis;
|
||||||
|
|
||||||
|
// Copy into threadgroup memory
|
||||||
|
for(short i = lid.x; i < N_PER_BLOCK; i+= BLOCK_THREADS) {
|
||||||
|
tgp_vals[i] = i < size_sorted_axis ? inp[i * stride_sorted_axis] : val_t(CompareOp::init);
|
||||||
|
if(ARG_SORT) {
|
||||||
|
tgp_idxs[i] = i;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Sort elements within the block
|
||||||
|
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||||
|
|
||||||
|
block_merge_sort_t::sort(tgp_vals, tgp_idxs, size_sorted_axis, lid);
|
||||||
|
|
||||||
|
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||||
|
|
||||||
|
// Write output
|
||||||
|
for(int i = lid.x; i < size_sorted_axis; i+= BLOCK_THREADS) {
|
||||||
|
if(ARG_SORT) {
|
||||||
|
out[i * stride_sorted_axis] = tgp_idxs[i];
|
||||||
|
} else {
|
||||||
|
out[i * stride_sorted_axis] = tgp_vals[i];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
template <
|
||||||
|
typename T,
|
||||||
|
typename U,
|
||||||
|
bool ARG_SORT,
|
||||||
|
short BLOCK_THREADS,
|
||||||
|
short N_PER_THREAD>
|
||||||
|
[[kernel, max_total_threads_per_threadgroup(BLOCK_THREADS)]] void block_sort(
|
||||||
|
const device T* inp [[buffer(0)]],
|
||||||
|
device U* out [[buffer(1)]],
|
||||||
|
const constant int& size_sorted_axis [[buffer(2)]],
|
||||||
|
const constant int& stride_sorted_axis [[buffer(3)]],
|
||||||
|
const constant int& stride_segment_axis [[buffer(4)]],
|
||||||
|
uint3 tid [[threadgroup_position_in_grid]],
|
||||||
|
uint3 lid [[thread_position_in_threadgroup]]) {
|
||||||
|
|
||||||
|
using sort_kernel = KernelMergeSort<T, U, ARG_SORT, BLOCK_THREADS, N_PER_THREAD>;
|
||||||
|
using val_t = typename sort_kernel::val_t;
|
||||||
|
using idx_t = typename sort_kernel::idx_t;
|
||||||
|
|
||||||
|
if(ARG_SORT) {
|
||||||
|
threadgroup val_t tgp_vals[sort_kernel::N_PER_BLOCK];
|
||||||
|
threadgroup idx_t tgp_idxs[sort_kernel::N_PER_BLOCK];
|
||||||
|
sort_kernel::block_sort(
|
||||||
|
inp,
|
||||||
|
out,
|
||||||
|
size_sorted_axis,
|
||||||
|
stride_sorted_axis,
|
||||||
|
stride_segment_axis,
|
||||||
|
tgp_vals,
|
||||||
|
tgp_idxs,
|
||||||
|
tid,
|
||||||
|
lid);
|
||||||
|
} else {
|
||||||
|
threadgroup val_t tgp_vals[sort_kernel::N_PER_BLOCK];
|
||||||
|
sort_kernel::block_sort(
|
||||||
|
inp,
|
||||||
|
out,
|
||||||
|
size_sorted_axis,
|
||||||
|
stride_sorted_axis,
|
||||||
|
stride_segment_axis,
|
||||||
|
tgp_vals,
|
||||||
|
nullptr,
|
||||||
|
tid,
|
||||||
|
lid);
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
constant constexpr const int zero_helper = 0;
|
||||||
|
|
||||||
|
template <
|
||||||
|
typename T,
|
||||||
|
typename U,
|
||||||
|
bool ARG_SORT,
|
||||||
|
short BLOCK_THREADS,
|
||||||
|
short N_PER_THREAD>
|
||||||
|
[[kernel, max_total_threads_per_threadgroup(BLOCK_THREADS)]] void block_sort_nc(
|
||||||
|
const device T* inp [[buffer(0)]],
|
||||||
|
device U* out [[buffer(1)]],
|
||||||
|
const constant int& size_sorted_axis [[buffer(2)]],
|
||||||
|
const constant int& stride_sorted_axis [[buffer(3)]],
|
||||||
|
const constant int& nc_dim [[buffer(4)]],
|
||||||
|
const device int* nc_shape [[buffer(5)]],
|
||||||
|
const device size_t* nc_strides [[buffer(6)]],
|
||||||
|
uint3 tid [[threadgroup_position_in_grid]],
|
||||||
|
uint3 lid [[thread_position_in_threadgroup]]) {
|
||||||
|
|
||||||
|
using sort_kernel = KernelMergeSort<T, U, ARG_SORT, BLOCK_THREADS, N_PER_THREAD>;
|
||||||
|
using val_t = typename sort_kernel::val_t;
|
||||||
|
using idx_t = typename sort_kernel::idx_t;
|
||||||
|
|
||||||
|
auto block_idx = elem_to_loc(tid.y, nc_shape, nc_strides, nc_dim);
|
||||||
|
inp += block_idx;
|
||||||
|
out += block_idx;
|
||||||
|
|
||||||
|
if(ARG_SORT) {
|
||||||
|
threadgroup val_t tgp_vals[sort_kernel::N_PER_BLOCK];
|
||||||
|
threadgroup idx_t tgp_idxs[sort_kernel::N_PER_BLOCK];
|
||||||
|
sort_kernel::block_sort(
|
||||||
|
inp,
|
||||||
|
out,
|
||||||
|
size_sorted_axis,
|
||||||
|
stride_sorted_axis,
|
||||||
|
zero_helper,
|
||||||
|
tgp_vals,
|
||||||
|
tgp_idxs,
|
||||||
|
tid,
|
||||||
|
lid);
|
||||||
|
} else {
|
||||||
|
threadgroup val_t tgp_vals[sort_kernel::N_PER_BLOCK];
|
||||||
|
sort_kernel::block_sort(
|
||||||
|
inp,
|
||||||
|
out,
|
||||||
|
size_sorted_axis,
|
||||||
|
stride_sorted_axis,
|
||||||
|
zero_helper,
|
||||||
|
tgp_vals,
|
||||||
|
nullptr,
|
||||||
|
tid,
|
||||||
|
lid);
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
///////////////////////////////////////////////////////////////////////////////
|
||||||
|
// Instantiations
|
||||||
|
///////////////////////////////////////////////////////////////////////////////
|
||||||
|
|
||||||
|
|
||||||
|
#define instantiate_block_sort(name, itname, itype, otname, otype, arg_sort, bn, tn) \
|
||||||
|
template [[host_name(#name "_" #itname "_" #otname "_bn" #bn "_tn" #tn)]] \
|
||||||
|
[[kernel]] void block_sort<itype, otype, arg_sort, bn, tn>( \
|
||||||
|
const device itype* inp [[buffer(0)]], \
|
||||||
|
device otype* out [[buffer(1)]], \
|
||||||
|
const constant int& size_sorted_axis [[buffer(2)]], \
|
||||||
|
const constant int& stride_sorted_axis [[buffer(3)]], \
|
||||||
|
const constant int& stride_segment_axis [[buffer(4)]], \
|
||||||
|
uint3 tid [[threadgroup_position_in_grid]], \
|
||||||
|
uint3 lid [[thread_position_in_threadgroup]]); \
|
||||||
|
template [[host_name(#name "_" #itname "_" #otname "_bn" #bn "_tn" #tn "_nc")]] \
|
||||||
|
[[kernel]] void block_sort_nc<itype, otype, arg_sort, bn, tn>( \
|
||||||
|
const device itype* inp [[buffer(0)]], \
|
||||||
|
device otype* out [[buffer(1)]], \
|
||||||
|
const constant int& size_sorted_axis [[buffer(2)]], \
|
||||||
|
const constant int& stride_sorted_axis [[buffer(3)]], \
|
||||||
|
const constant int& nc_dim [[buffer(4)]], \
|
||||||
|
const device int* nc_shape [[buffer(5)]], \
|
||||||
|
const device size_t* nc_strides [[buffer(6)]], \
|
||||||
|
uint3 tid [[threadgroup_position_in_grid]], \
|
||||||
|
uint3 lid [[thread_position_in_threadgroup]]);
|
||||||
|
|
||||||
|
#define instantiate_arg_block_sort_base(itname, itype, bn, tn) \
|
||||||
|
instantiate_block_sort(arg_block_merge_sort, itname, itype, uint32, uint32_t, true, bn, tn)
|
||||||
|
|
||||||
|
#define instantiate_block_sort_base(itname, itype, bn, tn) \
|
||||||
|
instantiate_block_sort(block_merge_sort, itname, itype, itname, itype, false, bn, tn)
|
||||||
|
|
||||||
|
#define instantiate_block_sort_tn(itname, itype, bn) \
|
||||||
|
instantiate_block_sort_base(itname, itype, bn, 8) \
|
||||||
|
instantiate_arg_block_sort_base(itname, itype, bn, 8)
|
||||||
|
|
||||||
|
#define instantiate_block_sort_bn(itname, itype) \
|
||||||
|
instantiate_block_sort_tn(itname, itype, 128) \
|
||||||
|
instantiate_block_sort_tn(itname, itype, 256) \
|
||||||
|
instantiate_block_sort_tn(itname, itype, 512)
|
||||||
|
|
||||||
|
instantiate_block_sort_bn(uint8, uint8_t)
|
||||||
|
instantiate_block_sort_bn(uint16, uint16_t)
|
||||||
|
instantiate_block_sort_bn(uint32, uint32_t)
|
||||||
|
instantiate_block_sort_bn(int8, int8_t)
|
||||||
|
instantiate_block_sort_bn(int16, int16_t)
|
||||||
|
instantiate_block_sort_bn(int32, int32_t)
|
||||||
|
instantiate_block_sort_bn(float16, half)
|
||||||
|
instantiate_block_sort_bn(float32, float)
|
||||||
|
instantiate_block_sort_bn(bfloat16, bfloat16_t)
|
||||||
|
|
||||||
|
#define instantiate_block_sort_long(itname, itype) \
|
||||||
|
instantiate_block_sort_tn(itname, itype, 128) \
|
||||||
|
instantiate_block_sort_tn(itname, itype, 256)
|
||||||
|
|
||||||
|
instantiate_block_sort_long(uint64, uint64_t)
|
||||||
|
instantiate_block_sort_long(int64, int64_t)
|
||||||
|
|
||||||
|
///////////////////////////////////////////////////////////////////////////////
|
||||||
|
// Multi block merge sort
|
||||||
|
///////////////////////////////////////////////////////////////////////////////
|
||||||
|
|
||||||
|
template <
|
||||||
|
typename val_t,
|
||||||
|
typename idx_t,
|
||||||
|
bool ARG_SORT,
|
||||||
|
short BLOCK_THREADS,
|
||||||
|
short N_PER_THREAD,
|
||||||
|
typename CompareOp = LessThan<val_t>>
|
||||||
|
struct KernelMultiBlockMergeSort {
|
||||||
|
using block_merge_sort_t = BlockMergeSort<
|
||||||
|
val_t,
|
||||||
|
idx_t,
|
||||||
|
ARG_SORT,
|
||||||
|
BLOCK_THREADS,
|
||||||
|
N_PER_THREAD,
|
||||||
|
CompareOp>;
|
||||||
|
|
||||||
|
MLX_MTL_CONST short N_PER_BLOCK = BLOCK_THREADS * N_PER_THREAD;
|
||||||
|
|
||||||
|
static METAL_FUNC void block_sort(
|
||||||
|
const device val_t* inp,
|
||||||
|
device val_t* out_vals,
|
||||||
|
device idx_t* out_idxs,
|
||||||
|
const constant int& size_sorted_axis,
|
||||||
|
const constant int& stride_sorted_axis,
|
||||||
|
threadgroup val_t* tgp_vals,
|
||||||
|
threadgroup idx_t* tgp_idxs,
|
||||||
|
uint3 tid [[threadgroup_position_in_grid]],
|
||||||
|
uint3 lid [[thread_position_in_threadgroup]]) {
|
||||||
|
|
||||||
|
// tid.y tells us the segment index
|
||||||
|
int base_idx = tid.x * N_PER_BLOCK;
|
||||||
|
|
||||||
|
// Copy into threadgroup memory
|
||||||
|
for(short i = lid.x; i < N_PER_BLOCK; i+= BLOCK_THREADS) {
|
||||||
|
int idx = base_idx + i;
|
||||||
|
tgp_vals[i] = idx < size_sorted_axis ? inp[idx * stride_sorted_axis] : val_t(CompareOp::init);
|
||||||
|
tgp_idxs[i] = idx;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Sort elements within the block
|
||||||
|
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||||
|
|
||||||
|
block_merge_sort_t::sort(tgp_vals, tgp_idxs, size_sorted_axis, lid);
|
||||||
|
|
||||||
|
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||||
|
|
||||||
|
// Write output
|
||||||
|
for(int i = lid.x; i < N_PER_BLOCK; i+= BLOCK_THREADS) {
|
||||||
|
int idx = base_idx + i;
|
||||||
|
if(idx < size_sorted_axis) {
|
||||||
|
out_vals[idx] = tgp_vals[i];
|
||||||
|
out_idxs[idx] = tgp_idxs[i];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
static METAL_FUNC int merge_partition(
|
||||||
|
const device val_t* As,
|
||||||
|
const device val_t* Bs,
|
||||||
|
int A_sz,
|
||||||
|
int B_sz,
|
||||||
|
int sort_md) {
|
||||||
|
|
||||||
|
CompareOp op;
|
||||||
|
|
||||||
|
int A_st = max(0, sort_md - B_sz);
|
||||||
|
int A_ed = min(sort_md, A_sz);
|
||||||
|
|
||||||
|
while(A_st < A_ed) {
|
||||||
|
int md = A_st + (A_ed - A_st) / 2;
|
||||||
|
auto a = As[md];
|
||||||
|
auto b = Bs[sort_md - 1 - md];
|
||||||
|
|
||||||
|
if(op(b, a)) {
|
||||||
|
A_ed = md;
|
||||||
|
} else {
|
||||||
|
A_st = md + 1;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return A_ed;
|
||||||
|
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
template <
|
||||||
|
typename val_t,
|
||||||
|
typename idx_t,
|
||||||
|
bool ARG_SORT,
|
||||||
|
short BLOCK_THREADS,
|
||||||
|
short N_PER_THREAD>
|
||||||
|
[[kernel, max_total_threads_per_threadgroup(BLOCK_THREADS)]] void mb_block_sort(
|
||||||
|
const device val_t* inp [[buffer(0)]],
|
||||||
|
device val_t* out_vals [[buffer(1)]],
|
||||||
|
device idx_t* out_idxs [[buffer(2)]],
|
||||||
|
const constant int& size_sorted_axis [[buffer(3)]],
|
||||||
|
const constant int& stride_sorted_axis [[buffer(4)]],
|
||||||
|
const constant int& nc_dim [[buffer(5)]],
|
||||||
|
const device int* nc_shape [[buffer(6)]],
|
||||||
|
const device size_t* nc_strides [[buffer(7)]],
|
||||||
|
uint3 tid [[threadgroup_position_in_grid]],
|
||||||
|
uint3 lid [[thread_position_in_threadgroup]]) {
|
||||||
|
|
||||||
|
using sort_kernel = KernelMultiBlockMergeSort<val_t, idx_t, ARG_SORT, BLOCK_THREADS, N_PER_THREAD>;
|
||||||
|
|
||||||
|
auto block_idx = elem_to_loc(tid.y, nc_shape, nc_strides, nc_dim);
|
||||||
|
inp += block_idx;
|
||||||
|
out_vals += tid.y * size_sorted_axis;
|
||||||
|
out_idxs += tid.y * size_sorted_axis;
|
||||||
|
|
||||||
|
threadgroup val_t tgp_vals[sort_kernel::N_PER_BLOCK];
|
||||||
|
threadgroup idx_t tgp_idxs[sort_kernel::N_PER_BLOCK];
|
||||||
|
|
||||||
|
sort_kernel::block_sort(
|
||||||
|
inp,
|
||||||
|
out_vals,
|
||||||
|
out_idxs,
|
||||||
|
size_sorted_axis,
|
||||||
|
stride_sorted_axis,
|
||||||
|
tgp_vals,
|
||||||
|
tgp_idxs,
|
||||||
|
tid,
|
||||||
|
lid);
|
||||||
|
}
|
||||||
|
|
||||||
|
template <
|
||||||
|
typename val_t,
|
||||||
|
typename idx_t,
|
||||||
|
bool ARG_SORT,
|
||||||
|
short BLOCK_THREADS,
|
||||||
|
short N_PER_THREAD>
|
||||||
|
[[kernel, max_total_threads_per_threadgroup(BLOCK_THREADS)]] void mb_block_partiton(
|
||||||
|
device idx_t* block_partitions [[buffer(0)]],
|
||||||
|
const device val_t* dev_vals [[buffer(1)]],
|
||||||
|
const device idx_t* dev_idxs [[buffer(2)]],
|
||||||
|
const constant int& size_sorted_axis [[buffer(3)]],
|
||||||
|
const constant int& merge_tiles [[buffer(4)]],
|
||||||
|
uint3 tid [[threadgroup_position_in_grid]],
|
||||||
|
uint3 lid [[thread_position_in_threadgroup]],
|
||||||
|
uint3 tgp_dims [[threads_per_threadgroup]]) {
|
||||||
|
|
||||||
|
using sort_kernel = KernelMultiBlockMergeSort<
|
||||||
|
val_t,
|
||||||
|
idx_t,
|
||||||
|
ARG_SORT,
|
||||||
|
BLOCK_THREADS,
|
||||||
|
N_PER_THREAD>;
|
||||||
|
|
||||||
|
block_partitions += tid.y * tgp_dims.x;
|
||||||
|
dev_vals += tid.y * size_sorted_axis;
|
||||||
|
dev_idxs += tid.y * size_sorted_axis;
|
||||||
|
|
||||||
|
// Find location in merge step
|
||||||
|
int merge_group = lid.x / merge_tiles;
|
||||||
|
int merge_lane = lid.x % merge_tiles;
|
||||||
|
|
||||||
|
int sort_sz = sort_kernel::N_PER_BLOCK * merge_tiles;
|
||||||
|
int sort_st = sort_kernel::N_PER_BLOCK * merge_tiles * merge_group;
|
||||||
|
|
||||||
|
int A_st = min(size_sorted_axis, sort_st);
|
||||||
|
int A_ed = min(size_sorted_axis, sort_st + sort_sz / 2);
|
||||||
|
int B_st = A_ed;
|
||||||
|
int B_ed = min(size_sorted_axis, B_st + sort_sz / 2);
|
||||||
|
|
||||||
|
int partition_at = min(B_ed - A_st, sort_kernel::N_PER_BLOCK * merge_lane);
|
||||||
|
int partition = sort_kernel::merge_partition(
|
||||||
|
dev_vals + A_st,
|
||||||
|
dev_vals + B_st,
|
||||||
|
A_ed - A_st,
|
||||||
|
B_ed - B_st,
|
||||||
|
partition_at);
|
||||||
|
|
||||||
|
block_partitions[lid.x] = A_st + partition;
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
template <
|
||||||
|
typename val_t,
|
||||||
|
typename idx_t,
|
||||||
|
bool ARG_SORT,
|
||||||
|
short BLOCK_THREADS,
|
||||||
|
short N_PER_THREAD,
|
||||||
|
typename CompareOp = LessThan<val_t>>
|
||||||
|
[[kernel, max_total_threads_per_threadgroup(BLOCK_THREADS)]] void mb_block_merge(
|
||||||
|
const device idx_t* block_partitions [[buffer(0)]],
|
||||||
|
const device val_t* dev_vals_in [[buffer(1)]],
|
||||||
|
const device idx_t* dev_idxs_in [[buffer(2)]],
|
||||||
|
device val_t* dev_vals_out [[buffer(3)]],
|
||||||
|
device idx_t* dev_idxs_out [[buffer(4)]],
|
||||||
|
const constant int& size_sorted_axis [[buffer(5)]],
|
||||||
|
const constant int& merge_tiles [[buffer(6)]],
|
||||||
|
const constant int& num_tiles [[buffer(7)]],
|
||||||
|
uint3 tid [[threadgroup_position_in_grid]],
|
||||||
|
uint3 lid [[thread_position_in_threadgroup]]) {
|
||||||
|
|
||||||
|
using sort_kernel = KernelMultiBlockMergeSort<
|
||||||
|
val_t,
|
||||||
|
idx_t,
|
||||||
|
ARG_SORT,
|
||||||
|
BLOCK_THREADS,
|
||||||
|
N_PER_THREAD,
|
||||||
|
CompareOp>;
|
||||||
|
|
||||||
|
using block_sort_t = typename sort_kernel::block_merge_sort_t;
|
||||||
|
|
||||||
|
block_partitions += tid.y * (num_tiles + 1);
|
||||||
|
dev_vals_in += tid.y * size_sorted_axis;
|
||||||
|
dev_idxs_in += tid.y * size_sorted_axis;
|
||||||
|
dev_vals_out += tid.y * size_sorted_axis;
|
||||||
|
dev_idxs_out += tid.y * size_sorted_axis;
|
||||||
|
|
||||||
|
int block_idx = tid.x;
|
||||||
|
int merge_group = block_idx / merge_tiles;
|
||||||
|
int sort_st = sort_kernel::N_PER_BLOCK * merge_tiles * merge_group;
|
||||||
|
int sort_sz = sort_kernel::N_PER_BLOCK * merge_tiles;
|
||||||
|
int sort_md = sort_kernel::N_PER_BLOCK * block_idx - sort_st;
|
||||||
|
|
||||||
|
int A_st = block_partitions[block_idx + 0];
|
||||||
|
int A_ed = block_partitions[block_idx + 1];
|
||||||
|
int B_st = min(size_sorted_axis, 2 * sort_st + sort_sz/2 + sort_md - A_st);
|
||||||
|
int B_ed = min(size_sorted_axis, 2 * sort_st + sort_sz/2 + sort_md + sort_kernel::N_PER_BLOCK - A_ed);
|
||||||
|
|
||||||
|
if((block_idx % merge_tiles) == merge_tiles - 1) {
|
||||||
|
A_ed = min(size_sorted_axis, sort_st + sort_sz/2);
|
||||||
|
B_ed = min(size_sorted_axis, sort_st + sort_sz);
|
||||||
|
}
|
||||||
|
|
||||||
|
int A_sz = A_ed - A_st;
|
||||||
|
int B_sz = B_ed - B_st;
|
||||||
|
|
||||||
|
// Load from global memory
|
||||||
|
thread val_t thread_vals[N_PER_THREAD];
|
||||||
|
thread idx_t thread_idxs[N_PER_THREAD];
|
||||||
|
for(int i = 0; i < N_PER_THREAD; i++) {
|
||||||
|
int idx = BLOCK_THREADS * i + lid.x;
|
||||||
|
if(idx < (A_sz + B_sz)) {
|
||||||
|
thread_vals[i] = (idx < A_sz) ? dev_vals_in[A_st + idx] : dev_vals_in[B_st + idx - A_sz];
|
||||||
|
thread_idxs[i] = (idx < A_sz) ? dev_idxs_in[A_st + idx] : dev_idxs_in[B_st + idx - A_sz];
|
||||||
|
} else {
|
||||||
|
thread_vals[i] = CompareOp::init;
|
||||||
|
thread_idxs[i] = 0;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Write to shared memory
|
||||||
|
threadgroup val_t tgp_vals[sort_kernel::N_PER_BLOCK];
|
||||||
|
threadgroup idx_t tgp_idxs[sort_kernel::N_PER_BLOCK];
|
||||||
|
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||||
|
for(int i = 0; i < N_PER_THREAD; i++) {
|
||||||
|
int idx = BLOCK_THREADS * i + lid.x;
|
||||||
|
tgp_vals[idx] = thread_vals[i];
|
||||||
|
tgp_idxs[idx] = thread_idxs[i];
|
||||||
|
}
|
||||||
|
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||||
|
|
||||||
|
// Merge
|
||||||
|
int sort_md_local = min(A_sz + B_sz, N_PER_THREAD * int(lid.x));
|
||||||
|
|
||||||
|
int A_st_local = block_sort_t::merge_partition(
|
||||||
|
tgp_vals,
|
||||||
|
tgp_vals + A_sz,
|
||||||
|
A_sz,
|
||||||
|
B_sz,
|
||||||
|
sort_md_local);
|
||||||
|
int A_ed_local = A_sz;
|
||||||
|
|
||||||
|
int B_st_local = sort_md_local - A_st_local;
|
||||||
|
int B_ed_local = B_sz;
|
||||||
|
|
||||||
|
int A_sz_local = A_ed_local - A_st_local;
|
||||||
|
int B_sz_local = B_ed_local - B_st_local;
|
||||||
|
|
||||||
|
// Do merge
|
||||||
|
block_sort_t::merge_step(
|
||||||
|
tgp_vals + A_st_local,
|
||||||
|
tgp_vals + A_ed_local + B_st_local,
|
||||||
|
tgp_idxs + A_st_local,
|
||||||
|
tgp_idxs + A_ed_local + B_st_local,
|
||||||
|
A_sz_local,
|
||||||
|
B_sz_local,
|
||||||
|
thread_vals,
|
||||||
|
thread_idxs);
|
||||||
|
|
||||||
|
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||||
|
for(int i = 0; i < N_PER_THREAD; ++i) {
|
||||||
|
int idx = lid.x * N_PER_THREAD;
|
||||||
|
tgp_vals[idx + i] = thread_vals[i];
|
||||||
|
tgp_idxs[idx + i] = thread_idxs[i];
|
||||||
|
}
|
||||||
|
|
||||||
|
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||||
|
// Write output
|
||||||
|
int base_idx = tid.x * sort_kernel::N_PER_BLOCK;
|
||||||
|
for(int i = lid.x; i < sort_kernel::N_PER_BLOCK; i+= BLOCK_THREADS) {
|
||||||
|
int idx = base_idx + i;
|
||||||
|
if(idx < size_sorted_axis) {
|
||||||
|
dev_vals_out[idx] = tgp_vals[i];
|
||||||
|
dev_idxs_out[idx] = tgp_idxs[i];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
#define instantiate_multi_block_sort(vtname, vtype, itname, itype, arg_sort, bn, tn) \
|
||||||
|
template [[host_name("mb_block_sort_" #vtname "_" #itname "_bn" #bn "_tn" #tn)]] \
|
||||||
|
[[kernel]] void mb_block_sort<vtype, itype, arg_sort, bn, tn>( \
|
||||||
|
const device vtype* inp [[buffer(0)]], \
|
||||||
|
device vtype* out_vals [[buffer(1)]], \
|
||||||
|
device itype* out_idxs [[buffer(2)]], \
|
||||||
|
const constant int& size_sorted_axis [[buffer(3)]], \
|
||||||
|
const constant int& stride_sorted_axis [[buffer(4)]], \
|
||||||
|
const constant int& nc_dim [[buffer(5)]], \
|
||||||
|
const device int* nc_shape [[buffer(6)]], \
|
||||||
|
const device size_t* nc_strides [[buffer(7)]], \
|
||||||
|
uint3 tid [[threadgroup_position_in_grid]], \
|
||||||
|
uint3 lid [[thread_position_in_threadgroup]]); \
|
||||||
|
template [[host_name("mb_block_partiton_" #vtname "_" #itname "_bn" #bn "_tn" #tn)]] \
|
||||||
|
[[kernel]] void mb_block_partiton<vtype, itype, arg_sort, bn, tn>( \
|
||||||
|
device itype* block_partitions [[buffer(0)]], \
|
||||||
|
const device vtype* dev_vals [[buffer(1)]], \
|
||||||
|
const device itype* dev_idxs [[buffer(2)]], \
|
||||||
|
const constant int& size_sorted_axis [[buffer(3)]], \
|
||||||
|
const constant int& merge_tiles [[buffer(4)]], \
|
||||||
|
uint3 tid [[threadgroup_position_in_grid]], \
|
||||||
|
uint3 lid [[thread_position_in_threadgroup]], \
|
||||||
|
uint3 tgp_dims [[threads_per_threadgroup]]); \
|
||||||
|
template [[host_name("mb_block_merge_" #vtname "_" #itname "_bn" #bn "_tn" #tn)]] \
|
||||||
|
[[kernel]] void mb_block_merge<vtype, itype, arg_sort, bn, tn>( \
|
||||||
|
const device itype* block_partitions [[buffer(0)]], \
|
||||||
|
const device vtype* dev_vals_in [[buffer(1)]], \
|
||||||
|
const device itype* dev_idxs_in [[buffer(2)]], \
|
||||||
|
device vtype* dev_vals_out [[buffer(3)]], \
|
||||||
|
device itype* dev_idxs_out [[buffer(4)]], \
|
||||||
|
const constant int& size_sorted_axis [[buffer(5)]], \
|
||||||
|
const constant int& merge_tiles [[buffer(6)]], \
|
||||||
|
const constant int& num_tiles [[buffer(7)]], \
|
||||||
|
uint3 tid [[threadgroup_position_in_grid]], \
|
||||||
|
uint3 lid [[thread_position_in_threadgroup]]);
|
||||||
|
|
||||||
|
#define instantiate_multi_block_sort_base(vtname, vtype) \
|
||||||
|
instantiate_multi_block_sort(vtname, vtype, uint32, uint32_t, true, 512, 8)
|
||||||
|
|
||||||
|
instantiate_multi_block_sort_base(uint8, uint8_t)
|
||||||
|
instantiate_multi_block_sort_base(uint16, uint16_t)
|
||||||
|
instantiate_multi_block_sort_base(uint32, uint32_t)
|
||||||
|
instantiate_multi_block_sort_base(int8, int8_t)
|
||||||
|
instantiate_multi_block_sort_base(int16, int16_t)
|
||||||
|
instantiate_multi_block_sort_base(int32, int32_t)
|
||||||
|
instantiate_multi_block_sort_base(float16, half)
|
||||||
|
instantiate_multi_block_sort_base(float32, float)
|
||||||
|
instantiate_multi_block_sort_base(bfloat16, bfloat16_t)
|
||||||
|
|
||||||
|
#define instantiate_multi_block_sort_long(vtname, vtype) \
|
||||||
|
instantiate_multi_block_sort(vtname, vtype, uint32, uint32_t, true, 256, 8)
|
||||||
|
|
||||||
|
instantiate_multi_block_sort_long(uint64, uint64_t)
|
||||||
|
instantiate_multi_block_sort_long(int64, int64_t)
|
||||||
244
mlx/backend/metal/kernels/utils.h
Normal file
244
mlx/backend/metal/kernels/utils.h
Normal file
@@ -0,0 +1,244 @@
|
|||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include <metal_math>
|
||||||
|
#include "mlx/backend/metal/kernels/bf16.h"
|
||||||
|
#include "mlx/backend/metal/kernels/complex.h"
|
||||||
|
|
||||||
|
///////////////////////////////////////////////////////////////////////////////
|
||||||
|
// Type limits utils
|
||||||
|
///////////////////////////////////////////////////////////////////////////////
|
||||||
|
|
||||||
|
template <typename U>
|
||||||
|
struct Limits {
|
||||||
|
static const constant U max;
|
||||||
|
static const constant U min;
|
||||||
|
static const constant U finite_max;
|
||||||
|
static const constant U finite_min;
|
||||||
|
};
|
||||||
|
|
||||||
|
#define instantiate_default_limit(type) \
|
||||||
|
template <> \
|
||||||
|
struct Limits<type> { \
|
||||||
|
static constexpr constant type max = metal::numeric_limits<type>::max(); \
|
||||||
|
static constexpr constant type min = metal::numeric_limits<type>::min(); \
|
||||||
|
static constexpr constant type finite_max = \
|
||||||
|
metal::numeric_limits<type>::max(); \
|
||||||
|
static constexpr constant type finite_min = \
|
||||||
|
metal::numeric_limits<type>::min(); \
|
||||||
|
};
|
||||||
|
|
||||||
|
instantiate_default_limit(uint8_t);
|
||||||
|
instantiate_default_limit(uint16_t);
|
||||||
|
instantiate_default_limit(uint32_t);
|
||||||
|
instantiate_default_limit(uint64_t);
|
||||||
|
instantiate_default_limit(int8_t);
|
||||||
|
instantiate_default_limit(int16_t);
|
||||||
|
instantiate_default_limit(int32_t);
|
||||||
|
instantiate_default_limit(int64_t);
|
||||||
|
|
||||||
|
#define instantiate_float_limit(type) \
|
||||||
|
template <> \
|
||||||
|
struct Limits<type> { \
|
||||||
|
static constexpr constant type max = \
|
||||||
|
metal::numeric_limits<type>::infinity(); \
|
||||||
|
static constexpr constant type min = \
|
||||||
|
-metal::numeric_limits<type>::infinity(); \
|
||||||
|
static constexpr constant type finite_max = \
|
||||||
|
metal::numeric_limits<type>::max(); \
|
||||||
|
static constexpr constant type finite_min = \
|
||||||
|
-metal::numeric_limits<type>::max(); \
|
||||||
|
};
|
||||||
|
|
||||||
|
instantiate_float_limit(half);
|
||||||
|
instantiate_float_limit(float);
|
||||||
|
instantiate_float_limit(bfloat16_t);
|
||||||
|
|
||||||
|
template <>
|
||||||
|
struct Limits<bool> {
|
||||||
|
static constexpr constant bool max = true;
|
||||||
|
static constexpr constant bool min = false;
|
||||||
|
};
|
||||||
|
|
||||||
|
///////////////////////////////////////////////////////////////////////////////
|
||||||
|
// Indexing utils
|
||||||
|
///////////////////////////////////////////////////////////////////////////////
|
||||||
|
|
||||||
|
inline size_t elem_to_loc(
|
||||||
|
uint elem,
|
||||||
|
device const int* shape,
|
||||||
|
device const size_t* strides,
|
||||||
|
int ndim) {
|
||||||
|
size_t loc = 0;
|
||||||
|
for (int i = ndim - 1; i >= 0; --i) {
|
||||||
|
loc += (elem % shape[i]) * strides[i];
|
||||||
|
elem /= shape[i];
|
||||||
|
}
|
||||||
|
return loc;
|
||||||
|
}
|
||||||
|
|
||||||
|
inline size_t elem_to_loc(
|
||||||
|
uint elem,
|
||||||
|
constant const int* shape,
|
||||||
|
constant const size_t* strides,
|
||||||
|
int ndim) {
|
||||||
|
size_t loc = 0;
|
||||||
|
for (int i = ndim - 1; i >= 0; --i) {
|
||||||
|
loc += (elem % shape[i]) * strides[i];
|
||||||
|
elem /= shape[i];
|
||||||
|
}
|
||||||
|
return loc;
|
||||||
|
}
|
||||||
|
|
||||||
|
template <int NDIM>
|
||||||
|
inline uint2 elem_to_loc_2_nd(
|
||||||
|
uint3 elem,
|
||||||
|
constant const int shape[NDIM],
|
||||||
|
constant const size_t a_strides[NDIM],
|
||||||
|
constant const size_t b_strides[NDIM]) {
|
||||||
|
uint2 loc = {
|
||||||
|
static_cast<uint>(
|
||||||
|
elem.x * a_strides[NDIM - 1] + elem.y * a_strides[NDIM - 2]),
|
||||||
|
static_cast<uint>(
|
||||||
|
elem.x * b_strides[NDIM - 1] + elem.y * b_strides[NDIM - 2])};
|
||||||
|
for (int d = NDIM - 3; d >= 0; --d) {
|
||||||
|
uint l = elem.z % shape[d];
|
||||||
|
loc.x += l * a_strides[d];
|
||||||
|
loc.y += l * b_strides[d];
|
||||||
|
elem.z /= shape[d];
|
||||||
|
}
|
||||||
|
return loc;
|
||||||
|
}
|
||||||
|
|
||||||
|
template <int NDIM>
|
||||||
|
inline size_t elem_to_loc_nd(
|
||||||
|
uint3 elem,
|
||||||
|
constant const int shape[NDIM],
|
||||||
|
constant const size_t strides[NDIM]) {
|
||||||
|
size_t loc = elem.x * strides[NDIM - 1] + elem.y * strides[NDIM - 2];
|
||||||
|
for (int d = NDIM - 3; d >= 0; --d) {
|
||||||
|
loc += (elem.z % shape[d]) * strides[d];
|
||||||
|
elem.z /= shape[d];
|
||||||
|
}
|
||||||
|
return loc;
|
||||||
|
}
|
||||||
|
|
||||||
|
inline size_t elem_to_loc_1(uint elem, constant const size_t& stride) {
|
||||||
|
return elem * stride;
|
||||||
|
}
|
||||||
|
|
||||||
|
inline size_t elem_to_loc_2(uint2 elem, constant const size_t strides[2]) {
|
||||||
|
return elem.x * strides[1] + elem.y * strides[0];
|
||||||
|
}
|
||||||
|
|
||||||
|
inline size_t elem_to_loc_3(uint3 elem, constant const size_t strides[3]) {
|
||||||
|
return elem.x * strides[2] + elem.y * strides[1] + elem.z * strides[0];
|
||||||
|
}
|
||||||
|
|
||||||
|
// Non templated version to handle arbitrary dims
|
||||||
|
inline size_t elem_to_loc(
|
||||||
|
uint3 elem,
|
||||||
|
constant const int* shape,
|
||||||
|
constant const size_t* strides,
|
||||||
|
int ndim) {
|
||||||
|
size_t loc = elem.x * strides[ndim - 1] + elem.y * strides[ndim - 2];
|
||||||
|
for (int d = ndim - 3; d >= 0; --d) {
|
||||||
|
loc += (elem.z % shape[d]) * strides[d];
|
||||||
|
elem.z /= shape[d];
|
||||||
|
}
|
||||||
|
return loc;
|
||||||
|
}
|
||||||
|
|
||||||
|
inline uint2 elem_to_loc_2_nd(
|
||||||
|
uint3 elem,
|
||||||
|
constant const int* shape,
|
||||||
|
constant const size_t* a_strides,
|
||||||
|
constant const size_t* b_strides,
|
||||||
|
int ndim) {
|
||||||
|
uint2 loc = {
|
||||||
|
static_cast<uint>(
|
||||||
|
elem.x * a_strides[ndim - 1] + elem.y * a_strides[ndim - 2]),
|
||||||
|
static_cast<uint>(
|
||||||
|
elem.x * b_strides[ndim - 1] + elem.y * b_strides[ndim - 2])};
|
||||||
|
for (int d = ndim - 3; d >= 0; --d) {
|
||||||
|
uint l = elem.z % shape[d];
|
||||||
|
loc.x += l * a_strides[d];
|
||||||
|
loc.y += l * b_strides[d];
|
||||||
|
elem.z /= shape[d];
|
||||||
|
}
|
||||||
|
return loc;
|
||||||
|
}
|
||||||
|
|
||||||
|
template <int NDIM>
|
||||||
|
inline uint elem_to_loc_nd(
|
||||||
|
uint elem,
|
||||||
|
device const int* shape,
|
||||||
|
device const size_t* strides);
|
||||||
|
|
||||||
|
template <>
|
||||||
|
inline uint elem_to_loc_nd<1>(
|
||||||
|
uint elem,
|
||||||
|
device const int* shape,
|
||||||
|
device const size_t* strides) {
|
||||||
|
return (elem % shape[0]) * strides[0];
|
||||||
|
}
|
||||||
|
|
||||||
|
template <>
|
||||||
|
inline uint elem_to_loc_nd<2>(
|
||||||
|
uint elem,
|
||||||
|
device const int* shape,
|
||||||
|
device const size_t* strides) {
|
||||||
|
uint loc = (elem % shape[1]) * strides[1];
|
||||||
|
elem /= shape[1];
|
||||||
|
loc += (elem % shape[0]) * strides[0];
|
||||||
|
return loc;
|
||||||
|
}
|
||||||
|
|
||||||
|
template <>
|
||||||
|
inline uint elem_to_loc_nd<3>(
|
||||||
|
uint elem,
|
||||||
|
device const int* shape,
|
||||||
|
device const size_t* strides) {
|
||||||
|
uint loc = (elem % shape[2]) * strides[2];
|
||||||
|
elem /= shape[2];
|
||||||
|
loc += (elem % shape[1]) * strides[1];
|
||||||
|
elem /= shape[1];
|
||||||
|
loc += (elem % shape[0]) * strides[0];
|
||||||
|
return loc;
|
||||||
|
}
|
||||||
|
|
||||||
|
template <>
|
||||||
|
inline uint elem_to_loc_nd<4>(
|
||||||
|
uint elem,
|
||||||
|
device const int* shape,
|
||||||
|
device const size_t* strides) {
|
||||||
|
uint loc = (elem % shape[3]) * strides[3];
|
||||||
|
elem /= shape[3];
|
||||||
|
loc += (elem % shape[2]) * strides[2];
|
||||||
|
elem /= shape[2];
|
||||||
|
loc += (elem % shape[1]) * strides[1];
|
||||||
|
elem /= shape[1];
|
||||||
|
loc += (elem % shape[0]) * strides[0];
|
||||||
|
return loc;
|
||||||
|
}
|
||||||
|
|
||||||
|
///////////////////////////////////////////////////////////////////////////////
|
||||||
|
// Calculation utils
|
||||||
|
///////////////////////////////////////////////////////////////////////////////
|
||||||
|
|
||||||
|
/** Compute ceil((float)N/(float)M) */
|
||||||
|
inline size_t ceildiv(size_t N, size_t M) {
|
||||||
|
return (N + M - 1) / M;
|
||||||
|
}
|
||||||
|
|
||||||
|
// https://docs.oracle.com/cd/E19957-01/806-3568/ncg_goldberg.html#1202
|
||||||
|
inline float log1p(float x) {
|
||||||
|
float xp1 = 1.0f + x;
|
||||||
|
return (xp1 == 1.0f) ? x : x * (metal::log(xp1) / (xp1 - 1.0f));
|
||||||
|
}
|
||||||
|
|
||||||
|
inline bfloat16_t log1p(bfloat16_t x) {
|
||||||
|
float xp1 = 1.0f + static_cast<float>(x);
|
||||||
|
bfloat16_t ret =
|
||||||
|
(xp1 == 1.0f) ? x : bfloat16_t(x * (metal::log(xp1) / (xp1 - 1.0f)));
|
||||||
|
return ret;
|
||||||
|
}
|
||||||
446
mlx/backend/metal/matmul.cpp
Normal file
446
mlx/backend/metal/matmul.cpp
Normal file
@@ -0,0 +1,446 @@
|
|||||||
|
#include <algorithm>
|
||||||
|
#include <cassert>
|
||||||
|
#include <numeric>
|
||||||
|
#include <sstream>
|
||||||
|
|
||||||
|
#include "mlx/backend/metal/copy.h"
|
||||||
|
#include "mlx/backend/metal/device.h"
|
||||||
|
#include "mlx/backend/metal/kernels/defines.h"
|
||||||
|
#include "mlx/backend/metal/matmul.h"
|
||||||
|
#include "mlx/backend/metal/mps/gemm.h"
|
||||||
|
#include "mlx/backend/metal/utils.h"
|
||||||
|
#include "mlx/primitives.h"
|
||||||
|
#include "mlx/utils.h"
|
||||||
|
|
||||||
|
namespace mlx::core {
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
|
||||||
|
bool use_mps() {
|
||||||
|
auto get_val = []() {
|
||||||
|
if (const char* buff_str = std::getenv("MLX_USE_MPS")) {
|
||||||
|
return std::string(buff_str) != "OFF";
|
||||||
|
} else {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
static bool use_mps_ = get_val();
|
||||||
|
return use_mps_;
|
||||||
|
}
|
||||||
|
|
||||||
|
#define MAX_OPS_PER_BUFFER max_ops_per_buffer()
|
||||||
|
|
||||||
|
inline void mps_matmul(
|
||||||
|
const Stream& s,
|
||||||
|
metal::Device& d,
|
||||||
|
const array& a,
|
||||||
|
const array& b,
|
||||||
|
array& out,
|
||||||
|
int M,
|
||||||
|
int N,
|
||||||
|
int K,
|
||||||
|
int batch_size_out,
|
||||||
|
int lda,
|
||||||
|
int ldb,
|
||||||
|
bool transpose_a,
|
||||||
|
bool transpose_b,
|
||||||
|
std::vector<array>& copies) {
|
||||||
|
MPS::DataType mps_dtype = MPS::DataTypeFloat32;
|
||||||
|
|
||||||
|
if (out.dtype() == float16) {
|
||||||
|
mps_dtype = MPS::DataTypeFloat16;
|
||||||
|
} else if (out.dtype() == bfloat16) {
|
||||||
|
mps_dtype = MPS::DataTypeBFloat16;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Used batched MPSMatrixMultiplication if batch_size_out > 1
|
||||||
|
// We only accept the following cases:
|
||||||
|
// 1. Both a, b have batch_size_out matrices worth of data
|
||||||
|
// 2. Only one of a or b has batch_size_out matrices worth of data and
|
||||||
|
// the other has matrix worth of data
|
||||||
|
|
||||||
|
// The matrix dimsenisons of a and b are sure to be regularly strided
|
||||||
|
if (batch_size_out > 1) {
|
||||||
|
// No broadcasting defaults
|
||||||
|
auto batch_size_a = a.data_size() / (M * K);
|
||||||
|
auto batch_size_b = b.data_size() / (K * N);
|
||||||
|
|
||||||
|
auto matrix_stride_a = M * K;
|
||||||
|
auto matrix_stride_b = K * N;
|
||||||
|
auto matrix_stride_out = M * N;
|
||||||
|
|
||||||
|
// At this point, batch_size_a, batch_size_b show the number of matrices
|
||||||
|
// in data, no broadcasted strides considered
|
||||||
|
if (batch_size_out == std::max(batch_size_a, batch_size_b)) {
|
||||||
|
// Handle simple broadcasting
|
||||||
|
if (std::min(batch_size_a, batch_size_b) == 1) {
|
||||||
|
matrix_stride_a = (batch_size_a == 1) ? 0 : matrix_stride_a;
|
||||||
|
matrix_stride_b = (batch_size_b == 1) ? 0 : matrix_stride_b;
|
||||||
|
|
||||||
|
batch_size_a = batch_size_out;
|
||||||
|
batch_size_b = batch_size_out;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Only proceed if broadcasting between a and b is simple
|
||||||
|
// At this point, batch_size_a, batch_size_b show the number of matrices
|
||||||
|
// after broadcasting
|
||||||
|
if (batch_size_a == batch_size_b) {
|
||||||
|
auto a_desc = MPS::MatrixDescriptor::matrixDescriptor(
|
||||||
|
(M * K) / lda,
|
||||||
|
lda,
|
||||||
|
batch_size_a,
|
||||||
|
lda * a.itemsize(),
|
||||||
|
(matrix_stride_a * a.itemsize()),
|
||||||
|
mps_dtype);
|
||||||
|
|
||||||
|
auto b_desc = MPS::MatrixDescriptor::matrixDescriptor(
|
||||||
|
(K * N) / ldb,
|
||||||
|
ldb,
|
||||||
|
batch_size_b,
|
||||||
|
ldb * b.itemsize(),
|
||||||
|
(matrix_stride_b * b.itemsize()),
|
||||||
|
mps_dtype);
|
||||||
|
|
||||||
|
auto out_desc = MPS::MatrixDescriptor::matrixDescriptor(
|
||||||
|
M,
|
||||||
|
N,
|
||||||
|
batch_size_out,
|
||||||
|
N * out.itemsize(),
|
||||||
|
matrix_stride_out * out.itemsize(),
|
||||||
|
mps_dtype);
|
||||||
|
|
||||||
|
auto a_buf = static_cast<const MTL::Buffer*>(a.buffer().ptr());
|
||||||
|
auto a_mat = MPS::Matrix::alloc()->init(a_buf, a_desc);
|
||||||
|
|
||||||
|
auto b_buf = static_cast<const MTL::Buffer*>(b.buffer().ptr());
|
||||||
|
auto b_mat = MPS::Matrix::alloc()->init(b_buf, b_desc);
|
||||||
|
|
||||||
|
auto out_buf = static_cast<MTL::Buffer*>(out.buffer().ptr());
|
||||||
|
auto out_mat = MPS::Matrix::alloc()->init(out_buf, out_desc);
|
||||||
|
|
||||||
|
auto kernel = MPS::MatrixMultiplication::alloc()->init(
|
||||||
|
d.mtl_device(), transpose_a, transpose_b, M, N, K, 1.0, 0.0);
|
||||||
|
|
||||||
|
auto command_buffer = d.get_command_buffer(s.index);
|
||||||
|
kernel->setBatchSize(batch_size_out);
|
||||||
|
kernel->setBatchStart(0);
|
||||||
|
kernel->encodeToCommandBuffer(command_buffer, a_mat, b_mat, out_mat);
|
||||||
|
command_buffer->addCompletedHandler(
|
||||||
|
[a_mat, b_mat, out_mat, kernel, copies](
|
||||||
|
MTL::CommandBuffer*) mutable {
|
||||||
|
a_mat->release();
|
||||||
|
b_mat->release();
|
||||||
|
out_mat->release();
|
||||||
|
kernel->release();
|
||||||
|
copies.clear();
|
||||||
|
});
|
||||||
|
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Schedule as many calls to MPSMatrixMultiplication as needed otherwise
|
||||||
|
auto a_desc = MPS::MatrixDescriptor::matrixDescriptor(
|
||||||
|
a.data_size() / lda, lda, lda * a.itemsize(), mps_dtype);
|
||||||
|
|
||||||
|
auto b_desc = MPS::MatrixDescriptor::matrixDescriptor(
|
||||||
|
b.data_size() / ldb, ldb, ldb * b.itemsize(), mps_dtype);
|
||||||
|
|
||||||
|
auto out_desc = MPS::MatrixDescriptor::matrixDescriptor(
|
||||||
|
batch_size_out * M, N, N * out.itemsize(), mps_dtype);
|
||||||
|
|
||||||
|
auto a_buf = static_cast<const MTL::Buffer*>(a.buffer().ptr());
|
||||||
|
auto a_mat = MPS::Matrix::alloc()->init(a_buf, a_desc);
|
||||||
|
|
||||||
|
auto b_buf = static_cast<const MTL::Buffer*>(b.buffer().ptr());
|
||||||
|
auto b_mat = MPS::Matrix::alloc()->init(b_buf, b_desc);
|
||||||
|
|
||||||
|
auto out_buf = static_cast<MTL::Buffer*>(out.buffer().ptr());
|
||||||
|
auto out_mat = MPS::Matrix::alloc()->init(out_buf, out_desc);
|
||||||
|
|
||||||
|
auto kernel = MPS::MatrixMultiplication::alloc()->init(
|
||||||
|
d.mtl_device(), transpose_a, transpose_b, M, N, K, 1.0, 0.0);
|
||||||
|
|
||||||
|
auto command_buffer = d.get_command_buffer(s.index);
|
||||||
|
for (int i = 0; i < batch_size_out; ++i) {
|
||||||
|
auto a_row = elem_to_loc(M * K * i, a.shape(), a.strides()) / lda;
|
||||||
|
auto b_row = elem_to_loc(K * N * i, b.shape(), b.strides()) / ldb;
|
||||||
|
kernel->setLeftMatrixOrigin({a_row, 0, 0});
|
||||||
|
kernel->setRightMatrixOrigin({b_row, 0, 0});
|
||||||
|
kernel->setResultMatrixOrigin({i * static_cast<size_t>(M), 0, 0});
|
||||||
|
kernel->encodeToCommandBuffer(command_buffer, a_mat, b_mat, out_mat);
|
||||||
|
}
|
||||||
|
|
||||||
|
command_buffer->addCompletedHandler(
|
||||||
|
[a_mat, b_mat, out_mat, kernel, copies](MTL::CommandBuffer*) mutable {
|
||||||
|
a_mat->release();
|
||||||
|
b_mat->release();
|
||||||
|
out_mat->release();
|
||||||
|
kernel->release();
|
||||||
|
copies.clear();
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace
|
||||||
|
|
||||||
|
void mlx_matmul(
|
||||||
|
const Stream& s,
|
||||||
|
metal::Device& d,
|
||||||
|
const array& a,
|
||||||
|
const array& b,
|
||||||
|
array& out,
|
||||||
|
int M,
|
||||||
|
int N,
|
||||||
|
int K,
|
||||||
|
int batch_size_out,
|
||||||
|
int lda,
|
||||||
|
int ldb,
|
||||||
|
bool transpose_a,
|
||||||
|
bool transpose_b,
|
||||||
|
std::vector<array>& copies) {
|
||||||
|
// Account for batch sizes and basic broadcasting
|
||||||
|
int batch_size_a = a.data_size() / (M * K);
|
||||||
|
int batch_size_b = b.data_size() / (K * N);
|
||||||
|
|
||||||
|
int matrix_stride_a = (batch_size_a == 1) ? 0 : M * K;
|
||||||
|
int matrix_stride_b = (batch_size_b == 1) ? 0 : K * N;
|
||||||
|
int matrix_stride_out = M * N;
|
||||||
|
|
||||||
|
// Determine dispatch kernel
|
||||||
|
int bm = 32, bn = 32, bk = 16;
|
||||||
|
int wm = 2, wn = 2;
|
||||||
|
|
||||||
|
if ((size_t)batch_size_out * M * N >= 2ul << 20) {
|
||||||
|
if (!transpose_a && transpose_b) {
|
||||||
|
bm = 64;
|
||||||
|
bn = (out.dtype() == float32) ? 64 : 32;
|
||||||
|
bk = (out.dtype() == float32) ? 16 : 32;
|
||||||
|
} else {
|
||||||
|
bm = 64;
|
||||||
|
bn = 64;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
std::ostringstream kname;
|
||||||
|
kname << "gemm_" << (transpose_a ? 't' : 'n') << (transpose_b ? 't' : 'n')
|
||||||
|
<< "_" << type_to_name(a) << "_" << type_to_name(out) << "_bm" << bm
|
||||||
|
<< "_bn" << bn << "_bk" << bk << "_wm" << wm << "_wn" << wn << "_MN_"
|
||||||
|
<< ((M % bm == 0 && N % bn == 0) ? "t" : "n") << "aligned"
|
||||||
|
<< "_K_" << ((K % bk == 0) ? "t" : "n") << "aligned";
|
||||||
|
|
||||||
|
// Encode and dispatch kernel
|
||||||
|
auto compute_encoder = d.get_command_encoder(s.index);
|
||||||
|
auto kernel = d.get_kernel(kname.str());
|
||||||
|
compute_encoder->setComputePipelineState(kernel);
|
||||||
|
|
||||||
|
// Launch only 1 kernel in the case of simple batching / broadcasting
|
||||||
|
if (batch_size_out == std::max(batch_size_a, batch_size_b) &&
|
||||||
|
(batch_size_a == batch_size_b ||
|
||||||
|
std::min(batch_size_a, batch_size_b) == 1)) {
|
||||||
|
MTL::Size group_dims = MTL::Size(32, wn, wm);
|
||||||
|
MTL::Size grid_dims =
|
||||||
|
MTL::Size((N + bn - 1) / bn, (M + bm - 1) / bm, batch_size_out);
|
||||||
|
|
||||||
|
set_array_buffer(compute_encoder, a, 0);
|
||||||
|
set_array_buffer(compute_encoder, b, 1);
|
||||||
|
set_array_buffer(compute_encoder, out, 2);
|
||||||
|
|
||||||
|
compute_encoder->setBytes(&M, sizeof(int), 3);
|
||||||
|
compute_encoder->setBytes(&N, sizeof(int), 4);
|
||||||
|
compute_encoder->setBytes(&K, sizeof(int), 5);
|
||||||
|
compute_encoder->setBytes(&matrix_stride_a, sizeof(int), 6);
|
||||||
|
compute_encoder->setBytes(&matrix_stride_b, sizeof(int), 7);
|
||||||
|
compute_encoder->setBytes(&matrix_stride_out, sizeof(int), 8);
|
||||||
|
compute_encoder->dispatchThreadgroups(grid_dims, group_dims);
|
||||||
|
} else { // Other launch kernels with set offsets
|
||||||
|
|
||||||
|
for (int i = 0; i < batch_size_out; ++i) {
|
||||||
|
auto a_off = elem_to_loc(M * K * i, a.shape(), a.strides());
|
||||||
|
auto b_off = elem_to_loc(K * N * i, b.shape(), b.strides());
|
||||||
|
|
||||||
|
MTL::Size group_dims = MTL::Size(32, wn, wm);
|
||||||
|
MTL::Size grid_dims = MTL::Size((N + bn - 1) / bn, (M + bm - 1) / bm, 1);
|
||||||
|
|
||||||
|
auto a_buf = static_cast<const MTL::Buffer*>(a.buffer().ptr());
|
||||||
|
auto b_buf = static_cast<const MTL::Buffer*>(b.buffer().ptr());
|
||||||
|
auto out_buf = static_cast<const MTL::Buffer*>(out.buffer().ptr());
|
||||||
|
|
||||||
|
compute_encoder->setBuffer(a_buf, a_off * a.itemsize(), 0);
|
||||||
|
compute_encoder->setBuffer(b_buf, b_off * b.itemsize(), 1);
|
||||||
|
compute_encoder->setBuffer(out_buf, i * M * N * out.itemsize(), 2);
|
||||||
|
|
||||||
|
compute_encoder->setBytes(&M, sizeof(int), 3);
|
||||||
|
compute_encoder->setBytes(&N, sizeof(int), 4);
|
||||||
|
compute_encoder->setBytes(&K, sizeof(int), 5);
|
||||||
|
compute_encoder->setBytes(&matrix_stride_a, sizeof(int), 6);
|
||||||
|
compute_encoder->setBytes(&matrix_stride_b, sizeof(int), 7);
|
||||||
|
compute_encoder->setBytes(&matrix_stride_out, sizeof(int), 8);
|
||||||
|
compute_encoder->dispatchThreadgroups(grid_dims, group_dims);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
d.get_command_buffer(s.index)->addCompletedHandler(
|
||||||
|
[copies](MTL::CommandBuffer*) mutable { copies.clear(); });
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
void Matmul::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||||
|
assert(inputs.size() == 2);
|
||||||
|
if (!is_floating_point(out.dtype())) {
|
||||||
|
throw std::runtime_error(
|
||||||
|
"[matmul] Does not yet support non-floating point types.");
|
||||||
|
}
|
||||||
|
out.set_data(allocator::malloc_or_wait(out.nbytes()));
|
||||||
|
auto& s = stream();
|
||||||
|
auto& d = metal::device(s.device);
|
||||||
|
|
||||||
|
auto& a_pre = inputs[0];
|
||||||
|
auto& b_pre = inputs[1];
|
||||||
|
|
||||||
|
// Keep a vector with copies to be cleared in the completed buffer to release
|
||||||
|
// the arrays
|
||||||
|
std::vector<array> copies;
|
||||||
|
auto check_transpose = [&copies, &s](const array& arr) {
|
||||||
|
auto stx = arr.strides()[arr.ndim() - 2];
|
||||||
|
auto sty = arr.strides()[arr.ndim() - 1];
|
||||||
|
if (stx == arr.shape(-1) && sty == 1) {
|
||||||
|
return std::make_tuple(false, stx, arr);
|
||||||
|
} else if (stx == 1 && sty == arr.shape(-2)) {
|
||||||
|
return std::make_tuple(true, sty, arr);
|
||||||
|
} else {
|
||||||
|
array arr_copy(arr.shape(), arr.dtype(), nullptr, {});
|
||||||
|
copy_gpu(arr, arr_copy, CopyType::General, s);
|
||||||
|
copies.push_back(arr_copy);
|
||||||
|
size_t stx = arr.shape(-1);
|
||||||
|
return std::make_tuple(false, stx, arr_copy);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
auto [a_transposed, a_cols, a] = check_transpose(a_pre);
|
||||||
|
auto [b_transposed, b_cols, b] = check_transpose(b_pre);
|
||||||
|
|
||||||
|
int M = a.shape(-2);
|
||||||
|
int N = b.shape(-1);
|
||||||
|
int K = a.shape(-1);
|
||||||
|
|
||||||
|
auto batch_size_out = out.size() / (M * N);
|
||||||
|
|
||||||
|
// Route to gemv if needed
|
||||||
|
if (std::min(M, N) == 1) {
|
||||||
|
// Collect problem info
|
||||||
|
bool is_b_matrix = N != 1;
|
||||||
|
|
||||||
|
auto& mat = is_b_matrix ? b : a;
|
||||||
|
auto& vec = is_b_matrix ? a : b;
|
||||||
|
bool transpose_mat = is_b_matrix ? !b_transposed : a_transposed;
|
||||||
|
int in_vector_len = K;
|
||||||
|
int out_vector_len = is_b_matrix ? N : M;
|
||||||
|
|
||||||
|
int mat_cols = transpose_mat ? out_vector_len : in_vector_len;
|
||||||
|
int mat_rows = transpose_mat ? in_vector_len : out_vector_len;
|
||||||
|
|
||||||
|
int batch_size_mat = mat.data_size() / (mat_cols * mat_rows);
|
||||||
|
int stride_mat = batch_size_mat == batch_size_out ? mat_cols * mat_rows : 0;
|
||||||
|
|
||||||
|
int batch_size_vec = vec.data_size() / in_vector_len;
|
||||||
|
int stride_vec = batch_size_vec == batch_size_out ? in_vector_len : 0;
|
||||||
|
|
||||||
|
// Determine dispatch kernel
|
||||||
|
int tm = 4, tn = 4;
|
||||||
|
int bm, bn, n_out_per_tgp;
|
||||||
|
std::ostringstream kname;
|
||||||
|
|
||||||
|
if (transpose_mat) {
|
||||||
|
bm = 8;
|
||||||
|
bn = 8;
|
||||||
|
if (out_vector_len >= 24576) {
|
||||||
|
bn = 128;
|
||||||
|
} else if (out_vector_len >= 16384) {
|
||||||
|
bn = 64;
|
||||||
|
} else if (out_vector_len >= 8192) {
|
||||||
|
bn = 16;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Specialized kernel for very small outputs
|
||||||
|
tn = out_vector_len < tn ? 1 : tn;
|
||||||
|
|
||||||
|
n_out_per_tgp = bn * tn;
|
||||||
|
kname << "gemv_t_" << type_to_name(out);
|
||||||
|
|
||||||
|
} else {
|
||||||
|
bm = out_vector_len >= 4096 ? 8 : 4;
|
||||||
|
bn = 32;
|
||||||
|
|
||||||
|
// Specialized kernel for very small outputs
|
||||||
|
tm = out_vector_len < tm ? 1 : tm;
|
||||||
|
|
||||||
|
n_out_per_tgp = bm * tm;
|
||||||
|
kname << "gemv_" << type_to_name(out);
|
||||||
|
}
|
||||||
|
|
||||||
|
kname << "_bm" << bm << "_bn" << bn << "_tm" << tm << "_tn" << tn;
|
||||||
|
|
||||||
|
// Encode and dispatch kernel
|
||||||
|
auto compute_encoder = d.get_command_encoder(s.index);
|
||||||
|
auto kernel = d.get_kernel(kname.str());
|
||||||
|
compute_encoder->setComputePipelineState(kernel);
|
||||||
|
|
||||||
|
int n_tgp = (out_vector_len + n_out_per_tgp - 1) / n_out_per_tgp;
|
||||||
|
MTL::Size group_dims = MTL::Size(bn, bm, 1);
|
||||||
|
MTL::Size grid_dims = MTL::Size(n_tgp, 1, batch_size_out);
|
||||||
|
|
||||||
|
set_array_buffer(compute_encoder, mat, 0);
|
||||||
|
set_array_buffer(compute_encoder, vec, 1);
|
||||||
|
set_array_buffer(compute_encoder, out, 2);
|
||||||
|
|
||||||
|
compute_encoder->setBytes(&in_vector_len, sizeof(int), 3);
|
||||||
|
compute_encoder->setBytes(&out_vector_len, sizeof(int), 4);
|
||||||
|
compute_encoder->setBytes(&stride_vec, sizeof(int), 5);
|
||||||
|
compute_encoder->setBytes(&stride_mat, sizeof(int), 6);
|
||||||
|
compute_encoder->dispatchThreadgroups(grid_dims, group_dims);
|
||||||
|
|
||||||
|
d.get_command_buffer(s.index)->addCompletedHandler(
|
||||||
|
[copies](MTL::CommandBuffer*) mutable { copies.clear(); });
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
d.end_encoding(s.index);
|
||||||
|
|
||||||
|
if (use_mps()) {
|
||||||
|
mps_matmul(
|
||||||
|
s,
|
||||||
|
d,
|
||||||
|
a,
|
||||||
|
b,
|
||||||
|
out,
|
||||||
|
M,
|
||||||
|
N,
|
||||||
|
K,
|
||||||
|
batch_size_out,
|
||||||
|
a_cols,
|
||||||
|
b_cols,
|
||||||
|
a_transposed,
|
||||||
|
b_transposed,
|
||||||
|
copies);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
mlx_matmul(
|
||||||
|
s,
|
||||||
|
d,
|
||||||
|
a,
|
||||||
|
b,
|
||||||
|
out,
|
||||||
|
M,
|
||||||
|
N,
|
||||||
|
K,
|
||||||
|
batch_size_out,
|
||||||
|
a_cols,
|
||||||
|
b_cols,
|
||||||
|
a_transposed,
|
||||||
|
b_transposed,
|
||||||
|
copies);
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace mlx::core
|
||||||
29
mlx/backend/metal/matmul.h
Normal file
29
mlx/backend/metal/matmul.h
Normal file
@@ -0,0 +1,29 @@
|
|||||||
|
#include <algorithm>
|
||||||
|
#include <cassert>
|
||||||
|
#include <sstream>
|
||||||
|
|
||||||
|
#include "mlx/backend/metal/copy.h"
|
||||||
|
#include "mlx/backend/metal/device.h"
|
||||||
|
#include "mlx/backend/metal/mps/gemm.h"
|
||||||
|
#include "mlx/backend/metal/utils.h"
|
||||||
|
#include "mlx/utils.h"
|
||||||
|
|
||||||
|
namespace mlx::core {
|
||||||
|
|
||||||
|
void mlx_matmul(
|
||||||
|
const Stream& s,
|
||||||
|
metal::Device& d,
|
||||||
|
const array& a,
|
||||||
|
const array& b,
|
||||||
|
array& out,
|
||||||
|
int M,
|
||||||
|
int N,
|
||||||
|
int K,
|
||||||
|
int batch_size_out,
|
||||||
|
int lda,
|
||||||
|
int ldb,
|
||||||
|
bool transpose_a,
|
||||||
|
bool transpose_b,
|
||||||
|
std::vector<array>& copies);
|
||||||
|
|
||||||
|
} // namespace mlx::core
|
||||||
368
mlx/backend/metal/mps/gemm.h
Normal file
368
mlx/backend/metal/mps/gemm.h
Normal file
@@ -0,0 +1,368 @@
|
|||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include <Metal/Metal.hpp>
|
||||||
|
|
||||||
|
#define _MPS_PRIVATE_CLS(symbol) (MTL::Private::Class::s_k##symbol)
|
||||||
|
#define _MPS_PRIVATE_SEL(accessor) (MTL::Private::Selector::s_k##accessor)
|
||||||
|
|
||||||
|
namespace MTL::Private::Class {
|
||||||
|
_MTL_PRIVATE_DEF_CLS(MPSMatrixDescriptor);
|
||||||
|
_MTL_PRIVATE_DEF_CLS(MPSMatrix);
|
||||||
|
_MTL_PRIVATE_DEF_CLS(MPSVectorDescriptor);
|
||||||
|
_MTL_PRIVATE_DEF_CLS(MPSVector);
|
||||||
|
_MTL_PRIVATE_DEF_CLS(MPSKernel);
|
||||||
|
_MTL_PRIVATE_DEF_CLS(MPSMatrixMultiplication);
|
||||||
|
_MTL_PRIVATE_DEF_CLS(MPSMatrixVectorMultiplication);
|
||||||
|
} // namespace MTL::Private::Class
|
||||||
|
|
||||||
|
namespace MTL::Private::Selector {
|
||||||
|
_MTL_PRIVATE_DEF_SEL(
|
||||||
|
matrixDescriptorWithRows_columns_rowBytes_dataType,
|
||||||
|
"matrixDescriptorWithRows:columns:rowBytes:dataType:");
|
||||||
|
_MTL_PRIVATE_DEF_SEL(
|
||||||
|
matrixDescriptorWithRows_columns_matrices_rowBytes_matrixBytes_dataType,
|
||||||
|
"matrixDescriptorWithRows:columns:matrices:rowBytes:matrixBytes:dataType:");
|
||||||
|
_MTL_PRIVATE_DEF_SEL(rows, "rows");
|
||||||
|
_MTL_PRIVATE_DEF_SEL(initWithBuffer_descriptor, "initWithBuffer:descriptor:");
|
||||||
|
_MTL_PRIVATE_DEF_SEL(
|
||||||
|
initWithDevice_,
|
||||||
|
"initWithDevice:transposeLeft:transposeRight:"
|
||||||
|
"resultRows:resultColumns:interiorColumns:alpha:beta:");
|
||||||
|
_MTL_PRIVATE_DEF_SEL(
|
||||||
|
encodeToCommandBuffer_leftMatrix_rightMatrix_resultMatrix,
|
||||||
|
"encodeToCommandBuffer:leftMatrix:rightMatrix:resultMatrix:");
|
||||||
|
_MTL_PRIVATE_DEF_SEL(setLeftMatrixOrigin_, "setLeftMatrixOrigin:");
|
||||||
|
_MTL_PRIVATE_DEF_SEL(setRightMatrixOrigin_, "setRightMatrixOrigin:");
|
||||||
|
_MTL_PRIVATE_DEF_SEL(setResultMatrixOrigin_, "setResultMatrixOrigin:");
|
||||||
|
_MTL_PRIVATE_DEF_SEL(setBatchStart_, "setBatchStart:");
|
||||||
|
_MTL_PRIVATE_DEF_SEL(setBatchSize_, "setBatchSize:");
|
||||||
|
_MTL_PRIVATE_DEF_SEL(
|
||||||
|
vectorDescriptorWithLength_dataType,
|
||||||
|
"vectorDescriptorWithLength:dataType:");
|
||||||
|
_MTL_PRIVATE_DEF_SEL(
|
||||||
|
vectorDescriptorWithLength_vectors_vectorBytes_dataType,
|
||||||
|
"vectorDescriptorWithLength:vectors:vectorBytes:dataType:");
|
||||||
|
_MTL_PRIVATE_DEF_SEL(
|
||||||
|
initWithDevice_transpose_rows_columns_alpha_beta,
|
||||||
|
"initWithDevice:transpose:rows:columns:alpha:beta:");
|
||||||
|
_MTL_PRIVATE_DEF_SEL(
|
||||||
|
encodeToCommandBuffer_inputMatrix_inputVector_resultVector,
|
||||||
|
"encodeToCommandBuffer:inputMatrix:inputVector:resultVector:");
|
||||||
|
} // namespace MTL::Private::Selector
|
||||||
|
|
||||||
|
namespace MPS {
|
||||||
|
|
||||||
|
typedef enum DataType : uint32_t {
|
||||||
|
DataTypeFloatBit = 0x10000000,
|
||||||
|
DataTypeAlternateEncodingBit = 0x80000000,
|
||||||
|
DataTypeFloat16 = DataTypeFloatBit | 16,
|
||||||
|
DataTypeFloat32 = DataTypeFloatBit | 32,
|
||||||
|
DataTypeBFloat16 = DataTypeAlternateEncodingBit | DataTypeFloat16
|
||||||
|
} DataType;
|
||||||
|
|
||||||
|
class MatrixDescriptor : public NS::Copying<MatrixDescriptor> {
|
||||||
|
public:
|
||||||
|
static class MatrixDescriptor* matrixDescriptor(
|
||||||
|
NS::UInteger rows,
|
||||||
|
NS::UInteger columns,
|
||||||
|
NS::UInteger rowBytes,
|
||||||
|
NS::UInteger dataType);
|
||||||
|
static class MatrixDescriptor* matrixDescriptor(
|
||||||
|
NS::UInteger rows,
|
||||||
|
NS::UInteger columns,
|
||||||
|
NS::UInteger matrices,
|
||||||
|
NS::UInteger rowBytes,
|
||||||
|
NS::UInteger matrixBytes,
|
||||||
|
NS::UInteger dataType);
|
||||||
|
NS::UInteger rows() const;
|
||||||
|
};
|
||||||
|
|
||||||
|
class Matrix : public NS::Referencing<Matrix> {
|
||||||
|
public:
|
||||||
|
static class Matrix* alloc();
|
||||||
|
Matrix* init(MTL::Buffer* buffer, MatrixDescriptor* descriptor);
|
||||||
|
Matrix* init(const MTL::Buffer* buffer, MatrixDescriptor* descriptor);
|
||||||
|
};
|
||||||
|
|
||||||
|
class Kernel : public NS::Referencing<Kernel> {
|
||||||
|
public:
|
||||||
|
NS::String* label() const;
|
||||||
|
MTL::Device* device() const;
|
||||||
|
};
|
||||||
|
|
||||||
|
class MatrixMultiplication
|
||||||
|
: public NS::Referencing<MatrixMultiplication, Kernel> {
|
||||||
|
public:
|
||||||
|
static class MatrixMultiplication* alloc();
|
||||||
|
|
||||||
|
MatrixMultiplication* init(
|
||||||
|
MTL::Device* device,
|
||||||
|
bool transposeLeft,
|
||||||
|
bool transposeRight,
|
||||||
|
NS::UInteger resultRows,
|
||||||
|
NS::UInteger resultColumns,
|
||||||
|
NS::UInteger interiorColumns,
|
||||||
|
double alpha,
|
||||||
|
double beta);
|
||||||
|
|
||||||
|
void encodeToCommandBuffer(
|
||||||
|
MTL::CommandBuffer* commandBuffer,
|
||||||
|
Matrix* leftMatrix,
|
||||||
|
Matrix* rightMatrix,
|
||||||
|
Matrix* resultMatrix);
|
||||||
|
|
||||||
|
void setLeftMatrixOrigin(MTL::Origin origin);
|
||||||
|
void setRightMatrixOrigin(MTL::Origin origin);
|
||||||
|
void setResultMatrixOrigin(MTL::Origin origin);
|
||||||
|
void setBatchStart(NS::UInteger batchStart);
|
||||||
|
void setBatchSize(NS::UInteger batchSize);
|
||||||
|
};
|
||||||
|
|
||||||
|
class VectorDescriptor : public NS::Copying<VectorDescriptor> {
|
||||||
|
public:
|
||||||
|
static class VectorDescriptor* vectorDescriptor(
|
||||||
|
NS::UInteger length,
|
||||||
|
NS::UInteger dataType);
|
||||||
|
static class VectorDescriptor* vectorDescriptor(
|
||||||
|
NS::UInteger length,
|
||||||
|
NS::UInteger vectors,
|
||||||
|
NS::UInteger vectorBytes,
|
||||||
|
NS::UInteger dataType);
|
||||||
|
};
|
||||||
|
|
||||||
|
class Vector : public NS::Referencing<Vector> {
|
||||||
|
public:
|
||||||
|
static class Vector* alloc();
|
||||||
|
Vector* init(MTL::Buffer* buffer, VectorDescriptor* descriptor);
|
||||||
|
Vector* init(const MTL::Buffer* buffer, VectorDescriptor* descriptor);
|
||||||
|
};
|
||||||
|
|
||||||
|
class MatrixVectorMultiplication
|
||||||
|
: public NS::Referencing<MatrixVectorMultiplication, Kernel> {
|
||||||
|
public:
|
||||||
|
static class MatrixVectorMultiplication* alloc();
|
||||||
|
|
||||||
|
MatrixVectorMultiplication* init(
|
||||||
|
MTL::Device* device,
|
||||||
|
bool transpose,
|
||||||
|
NS::UInteger rows,
|
||||||
|
NS::UInteger columns,
|
||||||
|
double alpha,
|
||||||
|
double beta);
|
||||||
|
|
||||||
|
void encodeToCommandBuffer(
|
||||||
|
MTL::CommandBuffer* commandBuffer,
|
||||||
|
Matrix* inputMatrix,
|
||||||
|
Vector* inputVector,
|
||||||
|
Vector* resultVector);
|
||||||
|
};
|
||||||
|
|
||||||
|
_MTL_INLINE MatrixDescriptor* MatrixDescriptor::matrixDescriptor(
|
||||||
|
NS::UInteger rows,
|
||||||
|
NS::UInteger columns,
|
||||||
|
NS::UInteger rowBytes,
|
||||||
|
NS::UInteger dataType) {
|
||||||
|
return Object::sendMessage<MatrixDescriptor*>(
|
||||||
|
_MPS_PRIVATE_CLS(MPSMatrixDescriptor),
|
||||||
|
_MPS_PRIVATE_SEL(matrixDescriptorWithRows_columns_rowBytes_dataType),
|
||||||
|
rows,
|
||||||
|
columns,
|
||||||
|
rowBytes,
|
||||||
|
dataType);
|
||||||
|
}
|
||||||
|
|
||||||
|
_MTL_INLINE MatrixDescriptor* MatrixDescriptor::matrixDescriptor(
|
||||||
|
NS::UInteger rows,
|
||||||
|
NS::UInteger columns,
|
||||||
|
NS::UInteger matrices,
|
||||||
|
NS::UInteger rowBytes,
|
||||||
|
NS::UInteger matrixBytes,
|
||||||
|
NS::UInteger dataType) {
|
||||||
|
return Object::sendMessage<MatrixDescriptor*>(
|
||||||
|
_MPS_PRIVATE_CLS(MPSMatrixDescriptor),
|
||||||
|
_MPS_PRIVATE_SEL(
|
||||||
|
matrixDescriptorWithRows_columns_matrices_rowBytes_matrixBytes_dataType),
|
||||||
|
rows,
|
||||||
|
columns,
|
||||||
|
matrices,
|
||||||
|
rowBytes,
|
||||||
|
matrixBytes,
|
||||||
|
dataType);
|
||||||
|
}
|
||||||
|
|
||||||
|
_MTL_INLINE NS::UInteger MatrixDescriptor::rows() const {
|
||||||
|
return Object::sendMessage<NS::UInteger>(this, _MPS_PRIVATE_SEL(rows));
|
||||||
|
}
|
||||||
|
|
||||||
|
_MTL_INLINE Matrix* Matrix::alloc() {
|
||||||
|
return NS::Object::alloc<Matrix>(_MPS_PRIVATE_CLS(MPSMatrix));
|
||||||
|
}
|
||||||
|
|
||||||
|
_MTL_INLINE Matrix* Matrix::init(
|
||||||
|
MTL::Buffer* buffer,
|
||||||
|
MatrixDescriptor* descriptor) {
|
||||||
|
return Object::sendMessage<Matrix*>(
|
||||||
|
this, _MPS_PRIVATE_SEL(initWithBuffer_descriptor), buffer, descriptor);
|
||||||
|
}
|
||||||
|
|
||||||
|
_MTL_INLINE Matrix* Matrix::init(
|
||||||
|
const MTL::Buffer* buffer,
|
||||||
|
MatrixDescriptor* descriptor) {
|
||||||
|
return init(const_cast<MTL::Buffer*>(buffer), descriptor);
|
||||||
|
}
|
||||||
|
|
||||||
|
_MTL_INLINE NS::String* Kernel::label() const {
|
||||||
|
return Object::sendMessage<NS::String*>(this, _MPS_PRIVATE_SEL(label));
|
||||||
|
}
|
||||||
|
|
||||||
|
_MTL_INLINE MTL::Device* Kernel::device() const {
|
||||||
|
return Object::sendMessage<MTL::Device*>(this, _MPS_PRIVATE_SEL(device));
|
||||||
|
}
|
||||||
|
|
||||||
|
_MTL_INLINE MatrixMultiplication* MatrixMultiplication::alloc() {
|
||||||
|
return NS::Object::alloc<MatrixMultiplication>(
|
||||||
|
_MPS_PRIVATE_CLS(MPSMatrixMultiplication));
|
||||||
|
}
|
||||||
|
|
||||||
|
_MTL_INLINE MatrixMultiplication* MatrixMultiplication::init(
|
||||||
|
MTL::Device* device,
|
||||||
|
bool transposeLeft,
|
||||||
|
bool transposeRight,
|
||||||
|
NS::UInteger resultRows,
|
||||||
|
NS::UInteger resultColumns,
|
||||||
|
NS::UInteger interiorColumns,
|
||||||
|
double alpha,
|
||||||
|
double beta) {
|
||||||
|
return Object::sendMessage<MatrixMultiplication*>(
|
||||||
|
this,
|
||||||
|
_MPS_PRIVATE_SEL(initWithDevice_),
|
||||||
|
device,
|
||||||
|
transposeLeft,
|
||||||
|
transposeRight,
|
||||||
|
resultRows,
|
||||||
|
resultColumns,
|
||||||
|
interiorColumns,
|
||||||
|
alpha,
|
||||||
|
beta);
|
||||||
|
}
|
||||||
|
|
||||||
|
_MTL_INLINE void MatrixMultiplication::encodeToCommandBuffer(
|
||||||
|
MTL::CommandBuffer* commandBuffer,
|
||||||
|
Matrix* leftMatrix,
|
||||||
|
Matrix* rightMatrix,
|
||||||
|
Matrix* resultMatrix) {
|
||||||
|
return Object::sendMessage<void>(
|
||||||
|
this,
|
||||||
|
_MPS_PRIVATE_SEL(
|
||||||
|
encodeToCommandBuffer_leftMatrix_rightMatrix_resultMatrix),
|
||||||
|
commandBuffer,
|
||||||
|
leftMatrix,
|
||||||
|
rightMatrix,
|
||||||
|
resultMatrix);
|
||||||
|
}
|
||||||
|
|
||||||
|
_MTL_INLINE void MatrixMultiplication::setLeftMatrixOrigin(MTL::Origin origin) {
|
||||||
|
Object::sendMessage<void>(
|
||||||
|
this, _MPS_PRIVATE_SEL(setLeftMatrixOrigin_), origin);
|
||||||
|
}
|
||||||
|
|
||||||
|
_MTL_INLINE void MatrixMultiplication::setRightMatrixOrigin(
|
||||||
|
MTL::Origin origin) {
|
||||||
|
Object::sendMessage<void>(
|
||||||
|
this, _MPS_PRIVATE_SEL(setRightMatrixOrigin_), origin);
|
||||||
|
}
|
||||||
|
|
||||||
|
_MTL_INLINE void MatrixMultiplication::setResultMatrixOrigin(
|
||||||
|
MTL::Origin origin) {
|
||||||
|
Object::sendMessage<void>(
|
||||||
|
this, _MPS_PRIVATE_SEL(setResultMatrixOrigin_), origin);
|
||||||
|
}
|
||||||
|
|
||||||
|
_MTL_INLINE void MatrixMultiplication::setBatchStart(NS::UInteger batchStart) {
|
||||||
|
Object::sendMessage<void>(this, _MPS_PRIVATE_SEL(setBatchStart_), batchStart);
|
||||||
|
}
|
||||||
|
|
||||||
|
_MTL_INLINE void MatrixMultiplication::setBatchSize(NS::UInteger batchSize) {
|
||||||
|
Object::sendMessage<void>(this, _MPS_PRIVATE_SEL(setBatchSize_), batchSize);
|
||||||
|
}
|
||||||
|
|
||||||
|
_MTL_INLINE VectorDescriptor* VectorDescriptor::vectorDescriptor(
|
||||||
|
NS::UInteger length,
|
||||||
|
NS::UInteger dataType) {
|
||||||
|
return Object::sendMessage<VectorDescriptor*>(
|
||||||
|
_MPS_PRIVATE_CLS(MPSVectorDescriptor),
|
||||||
|
_MPS_PRIVATE_SEL(vectorDescriptorWithLength_dataType),
|
||||||
|
length,
|
||||||
|
dataType);
|
||||||
|
}
|
||||||
|
|
||||||
|
_MTL_INLINE VectorDescriptor* VectorDescriptor::vectorDescriptor(
|
||||||
|
NS::UInteger length,
|
||||||
|
NS::UInteger vectors,
|
||||||
|
NS::UInteger vectorBytes,
|
||||||
|
NS::UInteger dataType) {
|
||||||
|
return Object::sendMessage<VectorDescriptor*>(
|
||||||
|
_MPS_PRIVATE_CLS(MPSVectorDescriptor),
|
||||||
|
_MPS_PRIVATE_SEL(vectorDescriptorWithLength_vectors_vectorBytes_dataType),
|
||||||
|
length,
|
||||||
|
vectors,
|
||||||
|
vectorBytes,
|
||||||
|
dataType);
|
||||||
|
}
|
||||||
|
|
||||||
|
_MTL_INLINE Vector* Vector::alloc() {
|
||||||
|
return NS::Object::alloc<Vector>(_MPS_PRIVATE_CLS(MPSVector));
|
||||||
|
}
|
||||||
|
|
||||||
|
_MTL_INLINE Vector* Vector::init(
|
||||||
|
MTL::Buffer* buffer,
|
||||||
|
VectorDescriptor* descriptor) {
|
||||||
|
return Object::sendMessage<Vector*>(
|
||||||
|
this, _MPS_PRIVATE_SEL(initWithBuffer_descriptor), buffer, descriptor);
|
||||||
|
}
|
||||||
|
|
||||||
|
_MTL_INLINE Vector* Vector::init(
|
||||||
|
const MTL::Buffer* buffer,
|
||||||
|
VectorDescriptor* descriptor) {
|
||||||
|
return init(const_cast<MTL::Buffer*>(buffer), descriptor);
|
||||||
|
}
|
||||||
|
|
||||||
|
_MTL_INLINE MatrixVectorMultiplication* MatrixVectorMultiplication::alloc() {
|
||||||
|
return NS::Object::alloc<MatrixVectorMultiplication>(
|
||||||
|
_MPS_PRIVATE_CLS(MPSMatrixVectorMultiplication));
|
||||||
|
}
|
||||||
|
|
||||||
|
_MTL_INLINE MatrixVectorMultiplication* MatrixVectorMultiplication::init(
|
||||||
|
MTL::Device* device,
|
||||||
|
bool transpose,
|
||||||
|
NS::UInteger rows,
|
||||||
|
NS::UInteger columns,
|
||||||
|
double alpha,
|
||||||
|
double beta) {
|
||||||
|
return Object::sendMessage<MatrixVectorMultiplication*>(
|
||||||
|
this,
|
||||||
|
_MPS_PRIVATE_SEL(initWithDevice_transpose_rows_columns_alpha_beta),
|
||||||
|
device,
|
||||||
|
transpose,
|
||||||
|
rows,
|
||||||
|
columns,
|
||||||
|
alpha,
|
||||||
|
beta);
|
||||||
|
}
|
||||||
|
|
||||||
|
_MTL_INLINE void MatrixVectorMultiplication::encodeToCommandBuffer(
|
||||||
|
MTL::CommandBuffer* commandBuffer,
|
||||||
|
Matrix* inputMatrix,
|
||||||
|
Vector* inputVector,
|
||||||
|
Vector* resultVector) {
|
||||||
|
return Object::sendMessage<void>(
|
||||||
|
this,
|
||||||
|
_MPS_PRIVATE_SEL(
|
||||||
|
encodeToCommandBuffer_inputMatrix_inputVector_resultVector),
|
||||||
|
commandBuffer,
|
||||||
|
inputMatrix,
|
||||||
|
inputVector,
|
||||||
|
resultVector);
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace MPS
|
||||||
604
mlx/backend/metal/primitives.cpp
Normal file
604
mlx/backend/metal/primitives.cpp
Normal file
@@ -0,0 +1,604 @@
|
|||||||
|
#include <algorithm>
|
||||||
|
#include <cassert>
|
||||||
|
#include <numeric>
|
||||||
|
#include <sstream>
|
||||||
|
|
||||||
|
#include "mlx/backend/common/binary.h"
|
||||||
|
#include "mlx/backend/metal/copy.h"
|
||||||
|
#include "mlx/backend/metal/device.h"
|
||||||
|
#include "mlx/backend/metal/kernels/defines.h"
|
||||||
|
#include "mlx/backend/metal/utils.h"
|
||||||
|
#include "mlx/primitives.h"
|
||||||
|
#include "mlx/utils.h"
|
||||||
|
|
||||||
|
namespace mlx::core {
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
|
||||||
|
static constexpr int METAL_MAX_INDEX_ARRAYS = 10;
|
||||||
|
|
||||||
|
void binary_op(
|
||||||
|
const std::vector<array>& inputs,
|
||||||
|
array& out,
|
||||||
|
const std::string op) {
|
||||||
|
assert(inputs.size() == 2);
|
||||||
|
auto& a = inputs[0];
|
||||||
|
auto& b = inputs[1];
|
||||||
|
auto bopt = get_binary_op_type(a, b);
|
||||||
|
set_binary_op_output_data(a, b, out, bopt);
|
||||||
|
|
||||||
|
// Try to collapse contiguous dims
|
||||||
|
auto [shape, strides] = collapse_contiguous_dims(a, b, out);
|
||||||
|
auto& strides_a = strides[0];
|
||||||
|
auto& strides_b = strides[1];
|
||||||
|
auto& strides_out = strides[2];
|
||||||
|
|
||||||
|
std::ostringstream kname;
|
||||||
|
switch (bopt) {
|
||||||
|
case ScalarScalar:
|
||||||
|
kname << "ss";
|
||||||
|
break;
|
||||||
|
case ScalarVector:
|
||||||
|
kname << "sv";
|
||||||
|
break;
|
||||||
|
case VectorScalar:
|
||||||
|
kname << "vs";
|
||||||
|
break;
|
||||||
|
case VectorVector:
|
||||||
|
kname << "vv";
|
||||||
|
break;
|
||||||
|
case General:
|
||||||
|
kname << "g";
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
kname << op << type_to_name(a);
|
||||||
|
if (bopt == General && out.ndim() <= MAX_BINARY_SPECIALIZED_DIMS) {
|
||||||
|
kname << "_" << shape.size();
|
||||||
|
}
|
||||||
|
|
||||||
|
auto& s = out.primitive().stream();
|
||||||
|
auto& d = metal::device(s.device);
|
||||||
|
auto kernel = d.get_kernel(kname.str());
|
||||||
|
auto compute_encoder = d.get_command_encoder(s.index);
|
||||||
|
compute_encoder->setComputePipelineState(kernel);
|
||||||
|
set_array_buffer(compute_encoder, a, 0);
|
||||||
|
set_array_buffer(compute_encoder, b, 1);
|
||||||
|
set_array_buffer(compute_encoder, out, 2);
|
||||||
|
|
||||||
|
if (bopt == General) {
|
||||||
|
auto ndim = shape.size();
|
||||||
|
if (ndim > 3) {
|
||||||
|
compute_encoder->setBytes(shape.data(), ndim * sizeof(int), 3);
|
||||||
|
compute_encoder->setBytes(strides_a.data(), ndim * sizeof(size_t), 4);
|
||||||
|
compute_encoder->setBytes(strides_b.data(), ndim * sizeof(size_t), 5);
|
||||||
|
} else {
|
||||||
|
// The shape is implicit in the grid for <= 3D
|
||||||
|
compute_encoder->setBytes(strides_a.data(), ndim * sizeof(size_t), 3);
|
||||||
|
compute_encoder->setBytes(strides_b.data(), ndim * sizeof(size_t), 4);
|
||||||
|
}
|
||||||
|
|
||||||
|
if (ndim > MAX_BINARY_SPECIALIZED_DIMS) {
|
||||||
|
compute_encoder->setBytes(&ndim, sizeof(int), 6);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Launch up to 3D grid of threads
|
||||||
|
int dim0 = ndim > 0 ? shape[ndim - 1] : 1;
|
||||||
|
int dim1 = ndim > 1 ? shape[ndim - 2] : 1;
|
||||||
|
int rest = out.size() / (dim0 * dim1);
|
||||||
|
NS::UInteger thread_group_size = kernel->maxTotalThreadsPerThreadgroup();
|
||||||
|
if (thread_group_size != 1024) {
|
||||||
|
throw std::runtime_error("[Metal::binary] Must use 1024 sized block");
|
||||||
|
}
|
||||||
|
auto group_dims = get_block_dims(dim0, dim1, rest);
|
||||||
|
MTL::Size grid_dims = MTL::Size(dim0, dim1, rest);
|
||||||
|
compute_encoder->dispatchThreads(grid_dims, group_dims);
|
||||||
|
} else {
|
||||||
|
// Launch a 1D grid of threads
|
||||||
|
size_t nthreads = bopt == General ? out.size() : out.data_size();
|
||||||
|
MTL::Size grid_dims = MTL::Size(nthreads, 1, 1);
|
||||||
|
NS::UInteger thread_group_size = kernel->maxTotalThreadsPerThreadgroup();
|
||||||
|
if (thread_group_size > nthreads) {
|
||||||
|
thread_group_size = nthreads;
|
||||||
|
}
|
||||||
|
MTL::Size group_dims = MTL::Size(thread_group_size, 1, 1);
|
||||||
|
compute_encoder->dispatchThreads(grid_dims, group_dims);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void unary_op(
|
||||||
|
const std::vector<array>& inputs,
|
||||||
|
array& out,
|
||||||
|
const std::string op) {
|
||||||
|
auto& in = inputs[0];
|
||||||
|
bool contig = in.flags().contiguous;
|
||||||
|
if (contig) {
|
||||||
|
out.set_data(
|
||||||
|
allocator::malloc_or_wait(in.data_size() * out.itemsize()),
|
||||||
|
in.data_size(),
|
||||||
|
in.strides(),
|
||||||
|
in.flags());
|
||||||
|
} else {
|
||||||
|
out.set_data(allocator::malloc_or_wait(out.nbytes()));
|
||||||
|
}
|
||||||
|
|
||||||
|
auto& s = out.primitive().stream();
|
||||||
|
auto& d = metal::device(s.device);
|
||||||
|
std::string tname = type_to_name(in);
|
||||||
|
std::string opt_name = contig ? "v" : "g";
|
||||||
|
auto kernel = d.get_kernel(opt_name + op + tname);
|
||||||
|
|
||||||
|
size_t nthreads = contig ? in.data_size() : in.size();
|
||||||
|
MTL::Size grid_dims = MTL::Size(nthreads, 1, 1);
|
||||||
|
NS::UInteger thread_group_size = kernel->maxTotalThreadsPerThreadgroup();
|
||||||
|
if (thread_group_size > nthreads) {
|
||||||
|
thread_group_size = nthreads;
|
||||||
|
}
|
||||||
|
MTL::Size group_dims = MTL::Size(thread_group_size, 1, 1);
|
||||||
|
|
||||||
|
auto compute_encoder = d.get_command_encoder(s.index);
|
||||||
|
compute_encoder->setComputePipelineState(kernel);
|
||||||
|
set_array_buffer(compute_encoder, in, 0);
|
||||||
|
set_array_buffer(compute_encoder, out, 1);
|
||||||
|
if (!contig) {
|
||||||
|
compute_encoder->setBytes(in.shape().data(), in.ndim() * sizeof(int), 2);
|
||||||
|
compute_encoder->setBytes(
|
||||||
|
in.strides().data(), in.ndim() * sizeof(size_t), 3);
|
||||||
|
int ndim = in.ndim();
|
||||||
|
compute_encoder->setBytes(&ndim, sizeof(int), 4);
|
||||||
|
}
|
||||||
|
compute_encoder->dispatchThreads(grid_dims, group_dims);
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace
|
||||||
|
|
||||||
|
void Abs::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||||
|
unary_op(inputs, out, "abs");
|
||||||
|
}
|
||||||
|
|
||||||
|
void Add::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||||
|
binary_op(inputs, out, "add");
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
void arange_set_scalars(T start, T next, MTL::ComputeCommandEncoder* enc) {
|
||||||
|
enc->setBytes(&start, sizeof(T), 0);
|
||||||
|
T step = next - start;
|
||||||
|
enc->setBytes(&step, sizeof(T), 1);
|
||||||
|
}
|
||||||
|
|
||||||
|
void Arange::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||||
|
assert(inputs.size() == 0);
|
||||||
|
out.set_data(allocator::malloc_or_wait(out.nbytes()));
|
||||||
|
auto& s = stream();
|
||||||
|
auto& d = metal::device(s.device);
|
||||||
|
auto kernel = d.get_kernel("arange" + type_to_name(out));
|
||||||
|
size_t nthreads = out.size();
|
||||||
|
MTL::Size grid_dims = MTL::Size(nthreads, 1, 1);
|
||||||
|
MTL::Size group_dims = MTL::Size(
|
||||||
|
std::min(nthreads, kernel->maxTotalThreadsPerThreadgroup()), 1, 1);
|
||||||
|
auto compute_encoder = d.get_command_encoder(s.index);
|
||||||
|
compute_encoder->setComputePipelineState(kernel);
|
||||||
|
|
||||||
|
switch (out.dtype()) {
|
||||||
|
case bool_: // unsupported
|
||||||
|
throw std::runtime_error("[Arange::eval_gpu] Does not support bool");
|
||||||
|
case uint8:
|
||||||
|
arange_set_scalars<uint8_t>(start_, start_ + step_, compute_encoder);
|
||||||
|
break;
|
||||||
|
case uint16:
|
||||||
|
arange_set_scalars<uint16_t>(start_, start_ + step_, compute_encoder);
|
||||||
|
break;
|
||||||
|
case uint32:
|
||||||
|
arange_set_scalars<uint32_t>(start_, start_ + step_, compute_encoder);
|
||||||
|
break;
|
||||||
|
case uint64:
|
||||||
|
arange_set_scalars<uint64_t>(start_, start_ + step_, compute_encoder);
|
||||||
|
break;
|
||||||
|
case int8:
|
||||||
|
arange_set_scalars<int8_t>(start_, start_ + step_, compute_encoder);
|
||||||
|
break;
|
||||||
|
case int16:
|
||||||
|
arange_set_scalars<int16_t>(start_, start_ + step_, compute_encoder);
|
||||||
|
break;
|
||||||
|
case int32:
|
||||||
|
arange_set_scalars<int32_t>(start_, start_ + step_, compute_encoder);
|
||||||
|
break;
|
||||||
|
case int64:
|
||||||
|
arange_set_scalars<int64_t>(start_, start_ + step_, compute_encoder);
|
||||||
|
break;
|
||||||
|
case float16:
|
||||||
|
arange_set_scalars<float16_t>(start_, start_ + step_, compute_encoder);
|
||||||
|
break;
|
||||||
|
case float32:
|
||||||
|
arange_set_scalars<float>(start_, start_ + step_, compute_encoder);
|
||||||
|
break;
|
||||||
|
case bfloat16:
|
||||||
|
throw std::runtime_error("[Arange::eval_gpu] Does not support bfloat16");
|
||||||
|
case complex64:
|
||||||
|
throw std::runtime_error("[Arange::eval_gpu] Does not support complex64");
|
||||||
|
}
|
||||||
|
|
||||||
|
set_array_buffer(compute_encoder, out, 2);
|
||||||
|
compute_encoder->dispatchThreads(grid_dims, group_dims);
|
||||||
|
}
|
||||||
|
|
||||||
|
void ArcCos::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||||
|
unary_op(inputs, out, "arccos");
|
||||||
|
}
|
||||||
|
|
||||||
|
void ArcCosh::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||||
|
unary_op(inputs, out, "arccosh");
|
||||||
|
}
|
||||||
|
|
||||||
|
void ArcSin::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||||
|
unary_op(inputs, out, "arcsin");
|
||||||
|
}
|
||||||
|
|
||||||
|
void ArcSinh::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||||
|
unary_op(inputs, out, "arcsinh");
|
||||||
|
}
|
||||||
|
|
||||||
|
void ArcTan::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||||
|
unary_op(inputs, out, "arctan");
|
||||||
|
}
|
||||||
|
|
||||||
|
void ArcTanh::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||||
|
unary_op(inputs, out, "arctanh");
|
||||||
|
}
|
||||||
|
|
||||||
|
void ArgReduce::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||||
|
assert(inputs.size() == 1);
|
||||||
|
auto& in = inputs[0];
|
||||||
|
out.set_data(allocator::malloc_or_wait(out.nbytes()));
|
||||||
|
auto& s = stream();
|
||||||
|
auto& d = metal::device(s.device);
|
||||||
|
std::string op_name;
|
||||||
|
switch (reduce_type_) {
|
||||||
|
case ArgReduce::ArgMin:
|
||||||
|
op_name = "argmin_";
|
||||||
|
break;
|
||||||
|
case ArgReduce::ArgMax:
|
||||||
|
op_name = "argmax_";
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Prepare the shapes, strides and axis arguments.
|
||||||
|
std::vector<size_t> in_strides = in.strides();
|
||||||
|
std::vector<int> shape = in.shape();
|
||||||
|
std::vector<size_t> out_strides = out.strides();
|
||||||
|
size_t axis_stride = in_strides[axis_];
|
||||||
|
size_t axis_size = shape[axis_];
|
||||||
|
if (out_strides.size() == in_strides.size()) {
|
||||||
|
out_strides.erase(out_strides.begin() + axis_);
|
||||||
|
}
|
||||||
|
in_strides.erase(in_strides.begin() + axis_);
|
||||||
|
shape.erase(shape.begin() + axis_);
|
||||||
|
size_t ndim = shape.size();
|
||||||
|
|
||||||
|
// ArgReduce
|
||||||
|
int simd_size = 32;
|
||||||
|
int n_reads = 4;
|
||||||
|
auto compute_encoder = d.get_command_encoder(s.index);
|
||||||
|
{
|
||||||
|
auto kernel = d.get_kernel(op_name + type_to_name(in));
|
||||||
|
NS::UInteger thread_group_size = std::min(
|
||||||
|
(axis_size + n_reads - 1) / n_reads,
|
||||||
|
kernel->maxTotalThreadsPerThreadgroup());
|
||||||
|
// round up to the closest number divisible by simd_size
|
||||||
|
thread_group_size =
|
||||||
|
(thread_group_size + simd_size - 1) / simd_size * simd_size;
|
||||||
|
assert(thread_group_size <= kernel->maxTotalThreadsPerThreadgroup());
|
||||||
|
|
||||||
|
size_t n_threads = out.size() * thread_group_size;
|
||||||
|
MTL::Size grid_dims = MTL::Size(n_threads, 1, 1);
|
||||||
|
MTL::Size group_dims = MTL::Size(thread_group_size, 1, 1);
|
||||||
|
compute_encoder->setComputePipelineState(kernel);
|
||||||
|
set_array_buffer(compute_encoder, in, 0);
|
||||||
|
set_array_buffer(compute_encoder, out, 1);
|
||||||
|
compute_encoder->setBytes(shape.data(), ndim * sizeof(int), 2);
|
||||||
|
compute_encoder->setBytes(in_strides.data(), ndim * sizeof(size_t), 3);
|
||||||
|
compute_encoder->setBytes(out_strides.data(), ndim * sizeof(size_t), 4);
|
||||||
|
compute_encoder->setBytes(&ndim, sizeof(size_t), 5);
|
||||||
|
compute_encoder->setBytes(&axis_stride, sizeof(size_t), 6);
|
||||||
|
compute_encoder->setBytes(&axis_size, sizeof(size_t), 7);
|
||||||
|
compute_encoder->setThreadgroupMemoryLength(
|
||||||
|
simd_size * (sizeof(uint32_t) + in.itemsize()), 0);
|
||||||
|
compute_encoder->dispatchThreads(grid_dims, group_dims);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void AsType::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||||
|
CopyType ctype =
|
||||||
|
inputs[0].flags().contiguous ? CopyType::Vector : CopyType::General;
|
||||||
|
copy_gpu(inputs[0], out, ctype);
|
||||||
|
}
|
||||||
|
|
||||||
|
void AsStrided::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||||
|
eval(inputs, out);
|
||||||
|
}
|
||||||
|
|
||||||
|
void Broadcast::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||||
|
eval(inputs, out);
|
||||||
|
}
|
||||||
|
|
||||||
|
void Concatenate::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||||
|
std::vector<int> sizes;
|
||||||
|
sizes.push_back(0);
|
||||||
|
for (auto& p : inputs) {
|
||||||
|
sizes.push_back(p.shape(axis_));
|
||||||
|
}
|
||||||
|
std::partial_sum(sizes.cbegin(), sizes.cend(), sizes.begin());
|
||||||
|
|
||||||
|
out.set_data(allocator::malloc_or_wait(out.nbytes()));
|
||||||
|
|
||||||
|
auto strides = out.strides();
|
||||||
|
auto flags = out.flags();
|
||||||
|
flags.row_contiguous = false;
|
||||||
|
flags.col_contiguous = false;
|
||||||
|
flags.contiguous = false;
|
||||||
|
for (int i = 0; i < inputs.size(); i++) {
|
||||||
|
array out_slice(inputs[i].shape(), out.dtype(), nullptr, {});
|
||||||
|
size_t data_offset = strides[axis_] * sizes[i];
|
||||||
|
out_slice.copy_shared_buffer(
|
||||||
|
out, strides, flags, out_slice.size(), data_offset);
|
||||||
|
copy_gpu_inplace(inputs[i], out_slice, CopyType::GeneralGeneral, stream());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void Copy::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||||
|
eval(inputs, out);
|
||||||
|
}
|
||||||
|
|
||||||
|
void Cos::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||||
|
unary_op(inputs, out, "cos");
|
||||||
|
}
|
||||||
|
|
||||||
|
void Cosh::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||||
|
unary_op(inputs, out, "cosh");
|
||||||
|
}
|
||||||
|
|
||||||
|
void Divide::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||||
|
binary_op(inputs, out, "div");
|
||||||
|
}
|
||||||
|
|
||||||
|
void Equal::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||||
|
binary_op(inputs, out, equal_nan_ ? "naneq" : "eq");
|
||||||
|
}
|
||||||
|
|
||||||
|
void Erf::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||||
|
unary_op(inputs, out, "erf");
|
||||||
|
}
|
||||||
|
|
||||||
|
void ErfInv::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||||
|
unary_op(inputs, out, "erfinv");
|
||||||
|
}
|
||||||
|
|
||||||
|
void Exp::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||||
|
unary_op(inputs, out, "exp");
|
||||||
|
}
|
||||||
|
|
||||||
|
void Full::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||||
|
auto in = inputs[0];
|
||||||
|
CopyType ctype;
|
||||||
|
if (in.data_size() == 1) {
|
||||||
|
ctype = CopyType::Scalar;
|
||||||
|
} else if (in.flags().contiguous) {
|
||||||
|
ctype = CopyType::Vector;
|
||||||
|
} else {
|
||||||
|
ctype = CopyType::General;
|
||||||
|
}
|
||||||
|
copy_gpu(in, out, ctype);
|
||||||
|
}
|
||||||
|
|
||||||
|
void Greater::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||||
|
binary_op(inputs, out, "ge");
|
||||||
|
}
|
||||||
|
|
||||||
|
void GreaterEqual::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||||
|
binary_op(inputs, out, "geq");
|
||||||
|
}
|
||||||
|
|
||||||
|
void Less::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||||
|
binary_op(inputs, out, "le");
|
||||||
|
}
|
||||||
|
|
||||||
|
void LessEqual::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||||
|
binary_op(inputs, out, "leq");
|
||||||
|
}
|
||||||
|
|
||||||
|
void Load::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||||
|
eval(inputs, out);
|
||||||
|
}
|
||||||
|
|
||||||
|
void Log::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||||
|
switch (base_) {
|
||||||
|
case Base::e:
|
||||||
|
unary_op(inputs, out, "log");
|
||||||
|
break;
|
||||||
|
case Base::two:
|
||||||
|
unary_op(inputs, out, "log2");
|
||||||
|
break;
|
||||||
|
case Base::ten:
|
||||||
|
unary_op(inputs, out, "log10");
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void Log1p::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||||
|
unary_op(inputs, out, "log1p");
|
||||||
|
}
|
||||||
|
|
||||||
|
void LogicalNot::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||||
|
unary_op(inputs, out, "lnot");
|
||||||
|
}
|
||||||
|
|
||||||
|
void LogAddExp::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||||
|
binary_op(inputs, out, "lae");
|
||||||
|
}
|
||||||
|
|
||||||
|
void Maximum::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||||
|
binary_op(inputs, out, "max");
|
||||||
|
}
|
||||||
|
|
||||||
|
void Minimum::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||||
|
binary_op(inputs, out, "min");
|
||||||
|
}
|
||||||
|
|
||||||
|
void Multiply::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||||
|
binary_op(inputs, out, "mul");
|
||||||
|
}
|
||||||
|
|
||||||
|
void Negative::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||||
|
unary_op(inputs, out, "neg");
|
||||||
|
}
|
||||||
|
|
||||||
|
void NotEqual::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||||
|
binary_op(inputs, out, "neq");
|
||||||
|
}
|
||||||
|
|
||||||
|
void Pad::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||||
|
// Inputs must be base input array and scalar val array
|
||||||
|
assert(inputs.size() == 2);
|
||||||
|
auto& in = inputs[0];
|
||||||
|
auto& val = inputs[1];
|
||||||
|
|
||||||
|
// Padding value must be a scalar
|
||||||
|
assert(val.size() == 1);
|
||||||
|
|
||||||
|
// Padding value, input and output must be of the same type
|
||||||
|
assert(val.dtype() == in.dtype() && in.dtype() == out.dtype());
|
||||||
|
|
||||||
|
// Fill output with val
|
||||||
|
copy_gpu(val, out, CopyType::Scalar, stream());
|
||||||
|
|
||||||
|
// Find offset for start of input values
|
||||||
|
size_t data_offset = 0;
|
||||||
|
for (int i = 0; i < axes_.size(); i++) {
|
||||||
|
auto ax = axes_[i] < 0 ? out.ndim() + axes_[i] : axes_[i];
|
||||||
|
data_offset += out.strides()[ax] * low_pad_size_[i];
|
||||||
|
}
|
||||||
|
|
||||||
|
// Extract slice from output where input will be pasted
|
||||||
|
array out_slice(in.shape(), out.dtype(), nullptr, {});
|
||||||
|
out_slice.copy_shared_buffer(
|
||||||
|
out, out.strides(), out.flags(), out_slice.size(), data_offset);
|
||||||
|
|
||||||
|
// Copy input values into the slice
|
||||||
|
copy_gpu_inplace(in, out_slice, CopyType::GeneralGeneral, stream());
|
||||||
|
}
|
||||||
|
|
||||||
|
void Power::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||||
|
binary_op(inputs, out, "pow");
|
||||||
|
}
|
||||||
|
|
||||||
|
void RandomBits::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||||
|
assert(inputs.size() == 1);
|
||||||
|
|
||||||
|
// keys has shape (N1, ..., NK, 2)
|
||||||
|
// out has shape (N1, ..., NK, M1, M2, ...)
|
||||||
|
auto& keys = inputs[0];
|
||||||
|
size_t num_keys = keys.size() / 2;
|
||||||
|
|
||||||
|
size_t elems_per_key = out.size() / num_keys;
|
||||||
|
size_t bytes_per_key = out.itemsize() * elems_per_key;
|
||||||
|
out.set_data(allocator::malloc_or_wait(out.nbytes()));
|
||||||
|
|
||||||
|
size_t out_per_key = (bytes_per_key + 4 - 1) / 4;
|
||||||
|
size_t half_size = out_per_key / 2;
|
||||||
|
bool odd = out_per_key % 2;
|
||||||
|
|
||||||
|
auto& s = stream();
|
||||||
|
auto& d = metal::device(s.device);
|
||||||
|
std::string kname = keys.flags().row_contiguous ? "rbitsc" : "rbits";
|
||||||
|
auto kernel = d.get_kernel(kname);
|
||||||
|
|
||||||
|
// organize into grid nkeys x elem_per_key
|
||||||
|
MTL::Size grid_dims = MTL::Size(num_keys, half_size + odd, 1);
|
||||||
|
NS::UInteger thread_group_size = kernel->maxTotalThreadsPerThreadgroup();
|
||||||
|
auto nthreads = std::min(num_keys * (half_size + odd), thread_group_size);
|
||||||
|
MTL::Size group_dims = MTL::Size(thread_group_size, 1, 1);
|
||||||
|
auto compute_encoder = d.get_command_encoder(s.index);
|
||||||
|
compute_encoder->setComputePipelineState(kernel);
|
||||||
|
set_array_buffer(compute_encoder, keys, 0);
|
||||||
|
set_array_buffer(compute_encoder, out, 1);
|
||||||
|
compute_encoder->setBytes(&odd, sizeof(bool), 2);
|
||||||
|
compute_encoder->setBytes(&bytes_per_key, sizeof(size_t), 3);
|
||||||
|
|
||||||
|
if (!keys.flags().row_contiguous) {
|
||||||
|
int ndim = keys.ndim();
|
||||||
|
compute_encoder->setBytes(&ndim, sizeof(int), 4);
|
||||||
|
compute_encoder->setBytes(
|
||||||
|
keys.shape().data(), keys.ndim() * sizeof(int), 5);
|
||||||
|
compute_encoder->setBytes(
|
||||||
|
keys.strides().data(), keys.ndim() * sizeof(size_t), 6);
|
||||||
|
}
|
||||||
|
|
||||||
|
compute_encoder->dispatchThreads(grid_dims, group_dims);
|
||||||
|
}
|
||||||
|
|
||||||
|
void Reshape::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||||
|
assert(inputs.size() == 1);
|
||||||
|
const auto& in = inputs[0];
|
||||||
|
if (in.flags().row_contiguous) {
|
||||||
|
auto flags = in.flags();
|
||||||
|
auto max_dim = std::max_element(out.shape().begin(), out.shape().end());
|
||||||
|
flags.col_contiguous = out.size() <= 1 || out.size() == *max_dim;
|
||||||
|
out.copy_shared_buffer(in, out.strides(), flags, in.data_size());
|
||||||
|
} else {
|
||||||
|
copy_gpu(in, out, CopyType::General);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void Sigmoid::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||||
|
unary_op(inputs, out, "sigmoid");
|
||||||
|
}
|
||||||
|
|
||||||
|
void Sign::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||||
|
unary_op(inputs, out, "sign");
|
||||||
|
}
|
||||||
|
|
||||||
|
void Sin::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||||
|
unary_op(inputs, out, "sin");
|
||||||
|
}
|
||||||
|
|
||||||
|
void Sinh::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||||
|
unary_op(inputs, out, "sinh");
|
||||||
|
}
|
||||||
|
|
||||||
|
void Square::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||||
|
unary_op(inputs, out, "square");
|
||||||
|
}
|
||||||
|
|
||||||
|
void Sqrt::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||||
|
if (recip_) {
|
||||||
|
unary_op(inputs, out, "rsqrt");
|
||||||
|
} else {
|
||||||
|
unary_op(inputs, out, "sqrt");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void Slice::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||||
|
eval(inputs, out);
|
||||||
|
}
|
||||||
|
|
||||||
|
void StopGradient::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||||
|
eval(inputs, out);
|
||||||
|
}
|
||||||
|
|
||||||
|
void Subtract::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||||
|
binary_op(inputs, out, "sub");
|
||||||
|
}
|
||||||
|
|
||||||
|
void Tan::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||||
|
unary_op(inputs, out, "tan");
|
||||||
|
}
|
||||||
|
|
||||||
|
void Tanh::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||||
|
unary_op(inputs, out, "tanh");
|
||||||
|
}
|
||||||
|
|
||||||
|
void Transpose::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||||
|
eval(inputs, out);
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace mlx::core
|
||||||
130
mlx/backend/metal/scan.cpp
Normal file
130
mlx/backend/metal/scan.cpp
Normal file
@@ -0,0 +1,130 @@
|
|||||||
|
#include <cassert>
|
||||||
|
#include <sstream>
|
||||||
|
|
||||||
|
#include "mlx/backend/metal/copy.h"
|
||||||
|
#include "mlx/backend/metal/device.h"
|
||||||
|
#include "mlx/backend/metal/utils.h"
|
||||||
|
#include "mlx/primitives.h"
|
||||||
|
|
||||||
|
namespace mlx::core {
|
||||||
|
|
||||||
|
void Scan::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||||
|
assert(inputs.size() == 1);
|
||||||
|
|
||||||
|
out.set_data(allocator::malloc_or_wait(out.nbytes()));
|
||||||
|
|
||||||
|
auto& s = stream();
|
||||||
|
auto& d = metal::device(s.device);
|
||||||
|
|
||||||
|
// Ensure contiguity
|
||||||
|
std::vector<array> copies;
|
||||||
|
auto in = inputs[0];
|
||||||
|
if (!in.flags().row_contiguous) {
|
||||||
|
array arr_copy(in.shape(), in.dtype(), nullptr, {});
|
||||||
|
copy_gpu(in, arr_copy, CopyType::General, s);
|
||||||
|
copies.push_back(arr_copy);
|
||||||
|
in = arr_copy;
|
||||||
|
}
|
||||||
|
|
||||||
|
std::ostringstream kname;
|
||||||
|
if (in.strides()[axis_] == 1) {
|
||||||
|
kname << "contiguous_scan_";
|
||||||
|
if (reverse_) {
|
||||||
|
kname << "reverse_";
|
||||||
|
}
|
||||||
|
kname << ((inclusive_) ? "inclusive_" : "exclusive_");
|
||||||
|
switch (reduce_type_) {
|
||||||
|
case Scan::Sum:
|
||||||
|
kname << "sum_";
|
||||||
|
break;
|
||||||
|
case Scan::Prod:
|
||||||
|
kname << "prod_";
|
||||||
|
break;
|
||||||
|
case Scan::Max:
|
||||||
|
kname << "max_";
|
||||||
|
break;
|
||||||
|
case Scan::Min:
|
||||||
|
kname << "min_";
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
kname << type_to_name(in) << "_" << type_to_name(out);
|
||||||
|
|
||||||
|
auto kernel = d.get_kernel(kname.str());
|
||||||
|
auto compute_encoder = d.get_command_encoder(s.index);
|
||||||
|
compute_encoder->setComputePipelineState(kernel);
|
||||||
|
set_array_buffer(compute_encoder, in, 0);
|
||||||
|
set_array_buffer(compute_encoder, out, 1);
|
||||||
|
size_t size = in.shape(axis_);
|
||||||
|
compute_encoder->setBytes(&size, sizeof(size_t), 2);
|
||||||
|
|
||||||
|
// Compute the thread grid
|
||||||
|
int n_reads = (in.itemsize() <= 4) ? 4 : 2;
|
||||||
|
int elements_per_simd = n_reads * 32;
|
||||||
|
int thread_groups = in.size() / size;
|
||||||
|
int thread_group_size = kernel->maxTotalThreadsPerThreadgroup();
|
||||||
|
if (size < n_reads * 1024) {
|
||||||
|
thread_group_size = ((size + elements_per_simd - 1) / elements_per_simd) *
|
||||||
|
elements_per_simd;
|
||||||
|
} else if (size < n_reads * 2048) {
|
||||||
|
thread_group_size =
|
||||||
|
((size / 2 + elements_per_simd - 1) / elements_per_simd) *
|
||||||
|
elements_per_simd;
|
||||||
|
}
|
||||||
|
thread_group_size = std::min(
|
||||||
|
thread_group_size,
|
||||||
|
static_cast<int>(kernel->maxTotalThreadsPerThreadgroup()));
|
||||||
|
MTL::Size grid_dims = MTL::Size(thread_groups * thread_group_size, 1, 1);
|
||||||
|
MTL::Size group_dims = MTL::Size(thread_group_size, 1, 1);
|
||||||
|
compute_encoder->dispatchThreads(grid_dims, group_dims);
|
||||||
|
} else {
|
||||||
|
kname << "strided_scan_";
|
||||||
|
if (reverse_) {
|
||||||
|
kname << "reverse_";
|
||||||
|
}
|
||||||
|
kname << ((inclusive_) ? "inclusive_" : "exclusive_");
|
||||||
|
switch (reduce_type_) {
|
||||||
|
case Scan::Sum:
|
||||||
|
kname << "sum_";
|
||||||
|
break;
|
||||||
|
case Scan::Prod:
|
||||||
|
kname << "prod_";
|
||||||
|
break;
|
||||||
|
case Scan::Max:
|
||||||
|
kname << "max_";
|
||||||
|
break;
|
||||||
|
case Scan::Min:
|
||||||
|
kname << "min_";
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
kname << type_to_name(in) << "_" << type_to_name(out);
|
||||||
|
|
||||||
|
auto kernel = d.get_kernel(kname.str());
|
||||||
|
auto compute_encoder = d.get_command_encoder(s.index);
|
||||||
|
compute_encoder->setComputePipelineState(kernel);
|
||||||
|
set_array_buffer(compute_encoder, in, 0);
|
||||||
|
set_array_buffer(compute_encoder, out, 1);
|
||||||
|
size_t size = in.shape(axis_);
|
||||||
|
size_t stride = in.strides()[axis_];
|
||||||
|
compute_encoder->setBytes(&size, sizeof(size_t), 2);
|
||||||
|
compute_encoder->setBytes(&stride, sizeof(size_t), 3);
|
||||||
|
|
||||||
|
// Compute the thread grid
|
||||||
|
int n_reads = (in.itemsize() <= 4) ? 4 : 2;
|
||||||
|
int tile_x = 32;
|
||||||
|
int tile_y = 32;
|
||||||
|
int elements_per_tile_x = tile_x * n_reads;
|
||||||
|
int grid_y = in.size() / size / stride;
|
||||||
|
int grid_x = (stride + elements_per_tile_x - 1) / elements_per_tile_x;
|
||||||
|
MTL::Size grid_dims = MTL::Size(grid_x * tile_x, grid_y * tile_y, 1);
|
||||||
|
MTL::Size group_dims = MTL::Size(tile_x, tile_y, 1);
|
||||||
|
compute_encoder->dispatchThreads(grid_dims, group_dims);
|
||||||
|
}
|
||||||
|
|
||||||
|
if (copies.size() > 0) {
|
||||||
|
auto command_buffer = d.get_command_buffer(s.index);
|
||||||
|
command_buffer->addCompletedHandler(
|
||||||
|
[copies](MTL::CommandBuffer*) mutable { copies.clear(); });
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace mlx::core
|
||||||
167
mlx/backend/metal/utils.h
Normal file
167
mlx/backend/metal/utils.h
Normal file
@@ -0,0 +1,167 @@
|
|||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include "mlx/array.h"
|
||||||
|
#include "mlx/backend/metal/device.h"
|
||||||
|
|
||||||
|
namespace mlx::core {
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
|
||||||
|
void set_array_buffer(
|
||||||
|
MTL::ComputeCommandEncoder* compute_encoder,
|
||||||
|
MTL::ArgumentEncoder* enc,
|
||||||
|
const array& a,
|
||||||
|
int idx) {
|
||||||
|
auto a_buf = static_cast<const MTL::Buffer*>(a.buffer().ptr());
|
||||||
|
auto offset = a.data<char>() -
|
||||||
|
static_cast<char*>(const_cast<MTL::Buffer*>(a_buf)->contents());
|
||||||
|
enc->setBuffer(a_buf, offset, idx);
|
||||||
|
// MTL::Resource usage through argument buffer needs to be explicity
|
||||||
|
// flagged to enable hazard tracking
|
||||||
|
compute_encoder->useResource(a_buf, MTL::ResourceUsageRead);
|
||||||
|
}
|
||||||
|
|
||||||
|
void set_array_buffer(
|
||||||
|
MTL::ComputeCommandEncoder* enc,
|
||||||
|
const array& a,
|
||||||
|
int idx) {
|
||||||
|
auto a_buf = static_cast<const MTL::Buffer*>(a.buffer().ptr());
|
||||||
|
auto offset = a.data<char>() -
|
||||||
|
static_cast<char*>(const_cast<MTL::Buffer*>(a_buf)->contents());
|
||||||
|
enc->setBuffer(a_buf, offset, idx);
|
||||||
|
}
|
||||||
|
|
||||||
|
std::string type_to_name(const array& a) {
|
||||||
|
std::string tname;
|
||||||
|
switch (a.dtype()) {
|
||||||
|
case bool_:
|
||||||
|
tname = "bool_";
|
||||||
|
break;
|
||||||
|
case uint8:
|
||||||
|
tname = "uint8";
|
||||||
|
break;
|
||||||
|
case uint16:
|
||||||
|
tname = "uint16";
|
||||||
|
break;
|
||||||
|
case uint32:
|
||||||
|
tname = "uint32";
|
||||||
|
break;
|
||||||
|
case uint64:
|
||||||
|
tname = "uint64";
|
||||||
|
break;
|
||||||
|
case int8:
|
||||||
|
tname = "int8";
|
||||||
|
break;
|
||||||
|
case int16:
|
||||||
|
tname = "int16";
|
||||||
|
break;
|
||||||
|
case int32:
|
||||||
|
tname = "int32";
|
||||||
|
break;
|
||||||
|
case int64:
|
||||||
|
tname = "int64";
|
||||||
|
break;
|
||||||
|
case float16:
|
||||||
|
tname = "float16";
|
||||||
|
break;
|
||||||
|
case float32:
|
||||||
|
tname = "float32";
|
||||||
|
break;
|
||||||
|
case bfloat16:
|
||||||
|
tname = "bfloat16";
|
||||||
|
break;
|
||||||
|
case complex64:
|
||||||
|
tname = "complex64";
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
return tname;
|
||||||
|
}
|
||||||
|
|
||||||
|
MTL::Size get_block_dims(int dim0, int dim1, int dim2) {
|
||||||
|
int pows[3] = {0, 0, 0};
|
||||||
|
int sum = 0;
|
||||||
|
while (true) {
|
||||||
|
int presum = sum;
|
||||||
|
// Check all the pows
|
||||||
|
if (dim0 >= (1 << (pows[0] + 1))) {
|
||||||
|
pows[0]++;
|
||||||
|
sum++;
|
||||||
|
}
|
||||||
|
if (sum == 10) {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
if (dim1 >= (1 << (pows[1] + 1))) {
|
||||||
|
pows[1]++;
|
||||||
|
sum++;
|
||||||
|
}
|
||||||
|
if (sum == 10) {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
if (dim2 >= (1 << (pows[2] + 1))) {
|
||||||
|
pows[2]++;
|
||||||
|
sum++;
|
||||||
|
}
|
||||||
|
if (sum == presum || sum == 10) {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return MTL::Size{1ul << pows[0], 1ul << pows[1], 1ul << pows[2]};
|
||||||
|
}
|
||||||
|
|
||||||
|
// Collapse dims that are contiguous to possibly route to a better kernel
|
||||||
|
// e.g. for x = transpose(array({0, 1, 2, 3, 4, 5, 6, 7}, {2, 2, 2}), {2, 0, 1})
|
||||||
|
// should return {{2, 4}, {{1, 2}}}.
|
||||||
|
//
|
||||||
|
// When multiple arrays are passed they should all have the same shape. The
|
||||||
|
// collapsed axes are also the same so one shape is returned.
|
||||||
|
std::tuple<std::vector<int>, std::vector<std::vector<size_t>>>
|
||||||
|
collapse_contiguous_dims(const std::vector<array>& xs) {
|
||||||
|
// Make a vector that has axes separated with -1. Collapse all axes between
|
||||||
|
// -1.
|
||||||
|
std::vector<int> to_collapse;
|
||||||
|
if (xs[0].ndim() > 0) {
|
||||||
|
to_collapse.push_back(0);
|
||||||
|
for (int i = 1; i < xs[0].ndim(); i++) {
|
||||||
|
bool contiguous = true;
|
||||||
|
for (auto& x : xs) {
|
||||||
|
if (x.strides()[i] * x.shape()[i] != x.strides()[i - 1]) {
|
||||||
|
contiguous = false;
|
||||||
|
}
|
||||||
|
if (!contiguous) {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if (!contiguous) {
|
||||||
|
to_collapse.push_back(-1);
|
||||||
|
}
|
||||||
|
to_collapse.push_back(i);
|
||||||
|
}
|
||||||
|
to_collapse.push_back(-1);
|
||||||
|
}
|
||||||
|
|
||||||
|
std::vector<int> out_shape;
|
||||||
|
std::vector<std::vector<size_t>> out_strides(xs.size());
|
||||||
|
for (int i = 0; i < to_collapse.size(); i++) {
|
||||||
|
int current_shape = xs[0].shape()[to_collapse[i]];
|
||||||
|
while (to_collapse[++i] != -1) {
|
||||||
|
current_shape *= xs[0].shape()[to_collapse[i]];
|
||||||
|
}
|
||||||
|
out_shape.push_back(current_shape);
|
||||||
|
for (int j = 0; j < xs.size(); j++) {
|
||||||
|
out_strides[j].push_back(xs[j].strides()[to_collapse[i - 1]]);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return std::make_tuple(out_shape, out_strides);
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename... Arrays>
|
||||||
|
std::tuple<std::vector<int>, std::vector<std::vector<size_t>>>
|
||||||
|
collapse_contiguous_dims(Arrays... xs) {
|
||||||
|
return collapse_contiguous_dims(
|
||||||
|
std::vector<array>{std::forward<Arrays>(xs)...});
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace
|
||||||
|
|
||||||
|
} // namespace mlx::core
|
||||||
18
mlx/backend/no_metal/metal.cpp
Normal file
18
mlx/backend/no_metal/metal.cpp
Normal file
@@ -0,0 +1,18 @@
|
|||||||
|
#include <stdexcept>
|
||||||
|
|
||||||
|
#include "mlx/backend/metal/metal.h"
|
||||||
|
|
||||||
|
namespace mlx::core::metal {
|
||||||
|
|
||||||
|
void new_stream(Stream) {}
|
||||||
|
|
||||||
|
std::function<void()> make_task(
|
||||||
|
array& arr,
|
||||||
|
std::vector<std::shared_future<void>> deps,
|
||||||
|
std::shared_ptr<std::promise<void>> p,
|
||||||
|
bool retain_graph) {
|
||||||
|
throw std::runtime_error(
|
||||||
|
"[metal::make_task] Cannot make GPU task without metal backend");
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace mlx::core::metal
|
||||||
27
mlx/device.h
Normal file
27
mlx/device.h
Normal file
@@ -0,0 +1,27 @@
|
|||||||
|
#pragma once
|
||||||
|
|
||||||
|
namespace mlx::core {
|
||||||
|
|
||||||
|
struct Device {
|
||||||
|
enum class DeviceType {
|
||||||
|
cpu,
|
||||||
|
gpu,
|
||||||
|
};
|
||||||
|
|
||||||
|
static constexpr DeviceType cpu = DeviceType::cpu;
|
||||||
|
static constexpr DeviceType gpu = DeviceType::gpu;
|
||||||
|
|
||||||
|
Device(DeviceType type, int index = 0) : type(type), index(index){};
|
||||||
|
|
||||||
|
DeviceType type;
|
||||||
|
int index;
|
||||||
|
};
|
||||||
|
|
||||||
|
const Device& default_device();
|
||||||
|
|
||||||
|
void set_default_device(const Device& d);
|
||||||
|
|
||||||
|
bool operator==(const Device& lhs, const Device& rhs);
|
||||||
|
bool operator!=(const Device& lhs, const Device& rhs);
|
||||||
|
|
||||||
|
} // namespace mlx::core
|
||||||
205
mlx/dtype.cpp
Normal file
205
mlx/dtype.cpp
Normal file
@@ -0,0 +1,205 @@
|
|||||||
|
#include <cstdint>
|
||||||
|
#include <sstream>
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
|
#include "mlx/dtype.h"
|
||||||
|
#include "mlx/utils.h"
|
||||||
|
|
||||||
|
namespace mlx::core {
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
|
||||||
|
static constexpr int num_types = 13;
|
||||||
|
|
||||||
|
static constexpr Dtype::Kind type_kinds[num_types] = {
|
||||||
|
Dtype::Kind::b, // bool_,
|
||||||
|
Dtype::Kind::u, // uint8,
|
||||||
|
Dtype::Kind::u, // uint16,
|
||||||
|
Dtype::Kind::u, // uint32,
|
||||||
|
Dtype::Kind::u, // uint64,
|
||||||
|
Dtype::Kind::i, // int8,
|
||||||
|
Dtype::Kind::i, // int16,
|
||||||
|
Dtype::Kind::i, // int32,
|
||||||
|
Dtype::Kind::i, // int64,
|
||||||
|
Dtype::Kind::f, // float16,
|
||||||
|
Dtype::Kind::f, // float32,
|
||||||
|
Dtype::Kind::V, // bfloat16,
|
||||||
|
Dtype::Kind::c // complex64,
|
||||||
|
};
|
||||||
|
|
||||||
|
// Following Jax type promotion rules:
|
||||||
|
// https://jax.readthedocs.io/en/latest/type_promotion.html
|
||||||
|
// clang-format off
|
||||||
|
static constexpr Dtype type_rules[num_types][num_types] = {
|
||||||
|
// bool uint8 uint16 uint32 uint64 int8 int16 int32 int64 float16 float32 bfloat16 complex64
|
||||||
|
{bool_, uint8, uint16, uint32, uint64, int8, int16, int32, int64, float16, float32, bfloat16, complex64}, // bool
|
||||||
|
{uint8, uint8, uint16, uint32, uint64, int16, int16, int32, int64, float16, float32, bfloat16, complex64}, // uint8
|
||||||
|
{uint16, uint16, uint16, uint32, uint64, int32, int32, int32, int64, float16, float32, bfloat16, complex64}, // uint16
|
||||||
|
{uint32, uint32, uint32, uint32, uint64, int64, int64, int64, int64, float16, float32, bfloat16, complex64}, // uint32
|
||||||
|
{uint64, uint64, uint64, uint64, uint64, float32, float32, float32, float32, float16, float32, bfloat16, complex64}, // uint64
|
||||||
|
{int8, int16, int32, int64, float32, int8, int16, int32, int64, float16, float32, bfloat16, complex64}, // int8
|
||||||
|
{int16, int16, int32, int64, float32, int16, int16, int32, int64, float16, float32, bfloat16, complex64}, // int16
|
||||||
|
{int32, int32, int32, int64, float32, int32, int32, int32, int64, float16, float32, bfloat16, complex64}, // int32
|
||||||
|
{int64, int64, int64, int64, float32, int64, int64, int64, int64, float16, float32, bfloat16, complex64}, // int64
|
||||||
|
{float16, float16, float16, float16, float16, float16, float16, float16, float16, float16, float32, float32, complex64}, // float16
|
||||||
|
{float32, float32, float32, float32, float32, float32, float32, float32, float32, float32, float32, float32, complex64}, // float32
|
||||||
|
{bfloat16, bfloat16, bfloat16, bfloat16, bfloat16, bfloat16, bfloat16, bfloat16, bfloat16, float32, float32, bfloat16, complex64}, // bfloat16
|
||||||
|
{complex64, complex64, complex64, complex64, complex64, complex64, complex64, complex64, complex64, complex64, complex64, complex64, complex64}, // complex64
|
||||||
|
};
|
||||||
|
|
||||||
|
// clang-format on
|
||||||
|
|
||||||
|
inline bool is_big_endian() {
|
||||||
|
union ByteOrder {
|
||||||
|
int32_t i;
|
||||||
|
uint8_t c[4];
|
||||||
|
};
|
||||||
|
ByteOrder b = {0x01234567};
|
||||||
|
|
||||||
|
return b.c[0] == 0x01;
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace
|
||||||
|
|
||||||
|
Dtype promote_types(const Dtype& t1, const Dtype& t2) {
|
||||||
|
return Dtype(type_rules[static_cast<int>(t1.val)][static_cast<int>(t2.val)]);
|
||||||
|
}
|
||||||
|
|
||||||
|
Dtype::Kind kindof(const Dtype& t) {
|
||||||
|
return type_kinds[static_cast<int>(t.val)];
|
||||||
|
}
|
||||||
|
|
||||||
|
template <>
|
||||||
|
TypeToDtype<bool>::operator Dtype() {
|
||||||
|
return bool_;
|
||||||
|
}
|
||||||
|
|
||||||
|
template <>
|
||||||
|
TypeToDtype<uint8_t>::operator Dtype() {
|
||||||
|
return uint8;
|
||||||
|
}
|
||||||
|
|
||||||
|
template <>
|
||||||
|
TypeToDtype<uint16_t>::operator Dtype() {
|
||||||
|
return uint16;
|
||||||
|
}
|
||||||
|
|
||||||
|
template <>
|
||||||
|
TypeToDtype<uint32_t>::operator Dtype() {
|
||||||
|
return uint32;
|
||||||
|
}
|
||||||
|
|
||||||
|
template <>
|
||||||
|
TypeToDtype<uint64_t>::operator Dtype() {
|
||||||
|
return uint64;
|
||||||
|
}
|
||||||
|
|
||||||
|
template <>
|
||||||
|
TypeToDtype<int8_t>::operator Dtype() {
|
||||||
|
return int8;
|
||||||
|
}
|
||||||
|
|
||||||
|
template <>
|
||||||
|
TypeToDtype<int16_t>::operator Dtype() {
|
||||||
|
return int16;
|
||||||
|
}
|
||||||
|
|
||||||
|
template <>
|
||||||
|
TypeToDtype<int32_t>::operator Dtype() {
|
||||||
|
return int32;
|
||||||
|
}
|
||||||
|
|
||||||
|
template <>
|
||||||
|
TypeToDtype<int64_t>::operator Dtype() {
|
||||||
|
return int64;
|
||||||
|
}
|
||||||
|
|
||||||
|
template <>
|
||||||
|
TypeToDtype<float16_t>::operator Dtype() {
|
||||||
|
return float16;
|
||||||
|
}
|
||||||
|
|
||||||
|
template <>
|
||||||
|
TypeToDtype<float>::operator Dtype() {
|
||||||
|
return float32;
|
||||||
|
}
|
||||||
|
|
||||||
|
template <>
|
||||||
|
TypeToDtype<double>::operator Dtype() {
|
||||||
|
return float32;
|
||||||
|
}
|
||||||
|
|
||||||
|
template <>
|
||||||
|
TypeToDtype<bfloat16_t>::operator Dtype() {
|
||||||
|
return bfloat16;
|
||||||
|
}
|
||||||
|
|
||||||
|
template <>
|
||||||
|
TypeToDtype<complex64_t>::operator Dtype() {
|
||||||
|
return complex64;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Array protocol typestring for Dtype
|
||||||
|
std::string dtype_to_array_protocol(const Dtype& t) {
|
||||||
|
std::ostringstream r;
|
||||||
|
if (size_of(t) > 1)
|
||||||
|
r << (is_big_endian() ? ">" : "<");
|
||||||
|
else
|
||||||
|
r << "|";
|
||||||
|
r << kindof(t) << (int)size_of(t);
|
||||||
|
return r.str();
|
||||||
|
}
|
||||||
|
|
||||||
|
// Dtype from array protocol type string
|
||||||
|
Dtype dtype_from_array_protocol(const std::string& t) {
|
||||||
|
if (t.length() == 2 || t.length() == 3) {
|
||||||
|
std::string r = t.length() == 3 ? t.substr(1, 2) : t;
|
||||||
|
|
||||||
|
if (r == "V2") {
|
||||||
|
return bfloat16;
|
||||||
|
}
|
||||||
|
|
||||||
|
uint8_t size = r[1] - '0';
|
||||||
|
|
||||||
|
switch (r[0]) {
|
||||||
|
case 'b': {
|
||||||
|
if (size == 1)
|
||||||
|
return bool_;
|
||||||
|
}
|
||||||
|
case 'i': {
|
||||||
|
if (size == 1)
|
||||||
|
return int8;
|
||||||
|
else if (size == 2)
|
||||||
|
return int16;
|
||||||
|
else if (size == 4)
|
||||||
|
return int32;
|
||||||
|
else if (size == 8)
|
||||||
|
return int64;
|
||||||
|
}
|
||||||
|
case 'u': {
|
||||||
|
if (size == 1)
|
||||||
|
return uint8;
|
||||||
|
else if (size == 2)
|
||||||
|
return uint16;
|
||||||
|
else if (size == 4)
|
||||||
|
return uint32;
|
||||||
|
else if (size == 8)
|
||||||
|
return uint64;
|
||||||
|
}
|
||||||
|
case 'f': {
|
||||||
|
if (size == 2)
|
||||||
|
return float16;
|
||||||
|
else if (size == 4)
|
||||||
|
return float32;
|
||||||
|
}
|
||||||
|
case 'c': {
|
||||||
|
return complex64;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
throw std::invalid_argument(
|
||||||
|
"[from_str] Invalid array protocol type-string: " + t);
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace mlx::core
|
||||||
144
mlx/graph_utils.cpp
Normal file
144
mlx/graph_utils.cpp
Normal file
@@ -0,0 +1,144 @@
|
|||||||
|
#include <functional>
|
||||||
|
#include <optional>
|
||||||
|
#include <sstream>
|
||||||
|
#include <unordered_map>
|
||||||
|
#include <unordered_set>
|
||||||
|
|
||||||
|
#include "mlx/graph_utils.h"
|
||||||
|
#include "mlx/primitives.h"
|
||||||
|
#include "mlx/utils.h"
|
||||||
|
|
||||||
|
namespace mlx::core {
|
||||||
|
|
||||||
|
using OptionalArrayRef = std::optional<std::reference_wrapper<const array>>;
|
||||||
|
|
||||||
|
struct ArrayNames {
|
||||||
|
std::unordered_map<std::uintptr_t, std::string> names;
|
||||||
|
|
||||||
|
std::string get_name(const array& x) {
|
||||||
|
auto it = names.find(x.id());
|
||||||
|
if (it == names.end()) {
|
||||||
|
// Get the next name in the sequence
|
||||||
|
// [A, B, ..., Z, AA, AB, ...]
|
||||||
|
std::vector<char> letters;
|
||||||
|
auto var_num = names.size() + 1;
|
||||||
|
while (var_num > 0) {
|
||||||
|
letters.push_back('A' + (var_num - 1) % 26);
|
||||||
|
var_num = (var_num - 1) / 26;
|
||||||
|
}
|
||||||
|
std::string name(letters.rbegin(), letters.rend());
|
||||||
|
names.insert({x.id(), name});
|
||||||
|
return name;
|
||||||
|
}
|
||||||
|
return it->second;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
void depth_first_traversal(
|
||||||
|
std::function<void(OptionalArrayRef, const array&, int)> callback,
|
||||||
|
const std::vector<array>& outputs) {
|
||||||
|
std::function<void(OptionalArrayRef, const array&, int)> recurse;
|
||||||
|
std::unordered_set<std::uintptr_t> cache;
|
||||||
|
recurse = [&](OptionalArrayRef parent, const array& x, int input_index) {
|
||||||
|
auto id = x.id();
|
||||||
|
if (cache.find(id) != cache.end()) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
cache.insert(id);
|
||||||
|
for (int i = 0; i < x.inputs().size(); i++) {
|
||||||
|
recurse(x, x.inputs()[i], i);
|
||||||
|
}
|
||||||
|
callback(parent, x, input_index);
|
||||||
|
};
|
||||||
|
|
||||||
|
for (auto x : outputs) {
|
||||||
|
recurse(std::nullopt, x, 0);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void depth_first_traversal(
|
||||||
|
std::function<void(const array&)> callback,
|
||||||
|
const std::vector<array>& outputs) {
|
||||||
|
depth_first_traversal(
|
||||||
|
[&callback](OptionalArrayRef p, const array& x, int input_index) {
|
||||||
|
callback(x);
|
||||||
|
},
|
||||||
|
outputs);
|
||||||
|
}
|
||||||
|
|
||||||
|
void print_graph(std::ostream& os, const std::vector<array>& outputs) {
|
||||||
|
std::vector<array> tape;
|
||||||
|
std::vector<array> inputs;
|
||||||
|
|
||||||
|
depth_first_traversal(
|
||||||
|
[&](const array& x) {
|
||||||
|
if (x.has_primitive()) {
|
||||||
|
tape.push_back(x);
|
||||||
|
} else {
|
||||||
|
inputs.push_back(x);
|
||||||
|
}
|
||||||
|
},
|
||||||
|
outputs);
|
||||||
|
|
||||||
|
ArrayNames namer;
|
||||||
|
auto print_arr = [&namer, &os](const array& a) {
|
||||||
|
os << namer.get_name(a);
|
||||||
|
os << " [" << a.shape() << ", " << a.dtype() << "]";
|
||||||
|
};
|
||||||
|
|
||||||
|
auto print_arrs = [&](const std::vector<array>& arrs) {
|
||||||
|
for (auto& arr : arrs) {
|
||||||
|
print_arr(arr);
|
||||||
|
if (&arr != &arrs.back()) {
|
||||||
|
os << ", ";
|
||||||
|
}
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
os << "Inputs: ";
|
||||||
|
print_arrs(inputs);
|
||||||
|
os << "\nOutputs: ";
|
||||||
|
print_arrs(outputs);
|
||||||
|
os << "\n";
|
||||||
|
|
||||||
|
for (auto& arr : tape) {
|
||||||
|
arr.primitive().print(os);
|
||||||
|
os << " ";
|
||||||
|
print_arrs(arr.inputs());
|
||||||
|
os << " -> ";
|
||||||
|
print_arr(arr);
|
||||||
|
os << "\n";
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void export_to_dot(std::ostream& os, const std::vector<array>& outputs) {
|
||||||
|
os << "digraph {" << std::endl;
|
||||||
|
|
||||||
|
ArrayNames namer;
|
||||||
|
depth_first_traversal(
|
||||||
|
[&namer, &os](auto parent, const array& x, int input_index) {
|
||||||
|
os << "{ ";
|
||||||
|
if (!x.has_primitive()) {
|
||||||
|
os << "rank=source; ";
|
||||||
|
}
|
||||||
|
if (!parent) {
|
||||||
|
os << "rank=sink; ";
|
||||||
|
}
|
||||||
|
os << namer.get_name(x);
|
||||||
|
if (x.has_primitive()) {
|
||||||
|
os << " [label =\"";
|
||||||
|
x.primitive().print(os);
|
||||||
|
os << "\"]";
|
||||||
|
}
|
||||||
|
os << "; }" << std::endl;
|
||||||
|
|
||||||
|
for (auto c : x.inputs()) {
|
||||||
|
os << namer.get_name(c) << " -> " << namer.get_name(x) << std::endl;
|
||||||
|
}
|
||||||
|
},
|
||||||
|
outputs);
|
||||||
|
|
||||||
|
os << "}";
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace mlx::core
|
||||||
112
mlx/load.h
Normal file
112
mlx/load.h
Normal file
@@ -0,0 +1,112 @@
|
|||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include <fstream>
|
||||||
|
#include <istream>
|
||||||
|
#include <memory>
|
||||||
|
|
||||||
|
namespace mlx::core {
|
||||||
|
|
||||||
|
namespace io {
|
||||||
|
|
||||||
|
class Reader {
|
||||||
|
public:
|
||||||
|
virtual bool is_open() const = 0;
|
||||||
|
virtual bool good() const = 0;
|
||||||
|
virtual size_t tell() const = 0;
|
||||||
|
virtual void seek(
|
||||||
|
int64_t off,
|
||||||
|
std::ios_base::seekdir way = std::ios_base::beg) = 0;
|
||||||
|
virtual void read(char* data, size_t n) = 0;
|
||||||
|
virtual std::string label() const = 0;
|
||||||
|
};
|
||||||
|
|
||||||
|
class Writer {
|
||||||
|
public:
|
||||||
|
virtual bool is_open() const = 0;
|
||||||
|
virtual bool good() const = 0;
|
||||||
|
virtual size_t tell() const = 0;
|
||||||
|
virtual void seek(
|
||||||
|
int64_t off,
|
||||||
|
std::ios_base::seekdir way = std::ios_base::beg) = 0;
|
||||||
|
virtual void write(const char* data, size_t n) = 0;
|
||||||
|
virtual std::string label() const = 0;
|
||||||
|
};
|
||||||
|
|
||||||
|
class FileReader : public Reader {
|
||||||
|
public:
|
||||||
|
explicit FileReader(const std::shared_ptr<std::ifstream>& is)
|
||||||
|
: is_(is), label_("stream") {}
|
||||||
|
explicit FileReader(const std::string& file_path)
|
||||||
|
: is_(std::make_shared<std::ifstream>(file_path, std::ios::binary)),
|
||||||
|
label_(file_path) {}
|
||||||
|
|
||||||
|
bool is_open() const override {
|
||||||
|
return is_->is_open();
|
||||||
|
}
|
||||||
|
|
||||||
|
bool good() const override {
|
||||||
|
return is_->good();
|
||||||
|
}
|
||||||
|
|
||||||
|
size_t tell() const override {
|
||||||
|
return is_->tellg();
|
||||||
|
}
|
||||||
|
|
||||||
|
void seek(int64_t off, std::ios_base::seekdir way = std::ios_base::beg)
|
||||||
|
override {
|
||||||
|
is_->seekg(off, way);
|
||||||
|
}
|
||||||
|
|
||||||
|
void read(char* data, size_t n) override {
|
||||||
|
is_->read(data, n);
|
||||||
|
}
|
||||||
|
|
||||||
|
std::string label() const override {
|
||||||
|
return "file " + label_;
|
||||||
|
}
|
||||||
|
|
||||||
|
private:
|
||||||
|
std::shared_ptr<std::ifstream> is_;
|
||||||
|
std::string label_;
|
||||||
|
};
|
||||||
|
|
||||||
|
class FileWriter : public Writer {
|
||||||
|
public:
|
||||||
|
explicit FileWriter(const std::shared_ptr<std::ofstream>& is)
|
||||||
|
: os_(is), label_("stream") {}
|
||||||
|
explicit FileWriter(const std::string& file_path)
|
||||||
|
: os_(std::make_shared<std::ofstream>(file_path, std::ios::binary)),
|
||||||
|
label_(file_path) {}
|
||||||
|
|
||||||
|
bool is_open() const override {
|
||||||
|
return os_->is_open();
|
||||||
|
}
|
||||||
|
|
||||||
|
bool good() const override {
|
||||||
|
return os_->good();
|
||||||
|
}
|
||||||
|
|
||||||
|
size_t tell() const override {
|
||||||
|
return os_->tellp();
|
||||||
|
}
|
||||||
|
|
||||||
|
void seek(int64_t off, std::ios_base::seekdir way = std::ios_base::beg)
|
||||||
|
override {
|
||||||
|
os_->seekp(off, way);
|
||||||
|
}
|
||||||
|
|
||||||
|
void write(const char* data, size_t n) override {
|
||||||
|
os_->write(data, n);
|
||||||
|
}
|
||||||
|
|
||||||
|
std::string label() const override {
|
||||||
|
return "file " + label_;
|
||||||
|
}
|
||||||
|
|
||||||
|
private:
|
||||||
|
std::shared_ptr<std::ofstream> os_;
|
||||||
|
std::string label_;
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace io
|
||||||
|
} // namespace mlx::core
|
||||||
11
mlx/mlx.h
Normal file
11
mlx/mlx.h
Normal file
@@ -0,0 +1,11 @@
|
|||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include "mlx/array.h"
|
||||||
|
#include "mlx/backend/metal/metal.h"
|
||||||
|
#include "mlx/device.h"
|
||||||
|
#include "mlx/fft.h"
|
||||||
|
#include "mlx/ops.h"
|
||||||
|
#include "mlx/random.h"
|
||||||
|
#include "mlx/stream.h"
|
||||||
|
#include "mlx/transforms.h"
|
||||||
|
#include "mlx/utils.h"
|
||||||
932
mlx/ops.h
Normal file
932
mlx/ops.h
Normal file
@@ -0,0 +1,932 @@
|
|||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include <variant>
|
||||||
|
|
||||||
|
#include "array.h"
|
||||||
|
#include "device.h"
|
||||||
|
#include "load.h"
|
||||||
|
#include "stream.h"
|
||||||
|
|
||||||
|
namespace mlx::core {
|
||||||
|
|
||||||
|
using StreamOrDevice = std::variant<std::monostate, Stream, Device>;
|
||||||
|
|
||||||
|
Stream to_stream(StreamOrDevice s);
|
||||||
|
|
||||||
|
/** Creation operations */
|
||||||
|
|
||||||
|
/**
|
||||||
|
* A 1D array of numbers starting at `start` (optional),
|
||||||
|
* stopping at stop, stepping by `step` (optional). **/
|
||||||
|
array arange(
|
||||||
|
double start,
|
||||||
|
double stop,
|
||||||
|
double step,
|
||||||
|
Dtype dtype,
|
||||||
|
StreamOrDevice s = {});
|
||||||
|
array arange(double start, double stop, double step, StreamOrDevice s = {});
|
||||||
|
array arange(double start, double stop, Dtype dtype, StreamOrDevice s = {});
|
||||||
|
array arange(double start, double stop, StreamOrDevice s = {});
|
||||||
|
array arange(double stop, Dtype dtype, StreamOrDevice s = {});
|
||||||
|
array arange(double stop, StreamOrDevice s = {});
|
||||||
|
|
||||||
|
array arange(int start, int stop, int step, StreamOrDevice s = {});
|
||||||
|
array arange(int start, int stop, StreamOrDevice s = {});
|
||||||
|
array arange(int stop, StreamOrDevice s = {});
|
||||||
|
|
||||||
|
/** Convert an array to the given data type. */
|
||||||
|
array astype(const array& a, Dtype dtype, StreamOrDevice s = {});
|
||||||
|
|
||||||
|
/** Create a view of an array with the given shape and strides. */
|
||||||
|
array as_strided(
|
||||||
|
const array& a,
|
||||||
|
std::vector<int> shape,
|
||||||
|
std::vector<size_t> strides,
|
||||||
|
size_t offset,
|
||||||
|
StreamOrDevice s = {});
|
||||||
|
|
||||||
|
/** Copy another array. */
|
||||||
|
array copy(const array& a, StreamOrDevice s = {});
|
||||||
|
|
||||||
|
/** Fill an array of the given shape with the given value(s). */
|
||||||
|
array full(
|
||||||
|
const std::vector<int>& shape,
|
||||||
|
const array& vals,
|
||||||
|
Dtype dtype,
|
||||||
|
StreamOrDevice s = {});
|
||||||
|
array full(
|
||||||
|
const std::vector<int>& shape,
|
||||||
|
const array& vals,
|
||||||
|
StreamOrDevice s = {});
|
||||||
|
template <typename T>
|
||||||
|
array full(
|
||||||
|
const std::vector<int>& shape,
|
||||||
|
T val,
|
||||||
|
Dtype dtype,
|
||||||
|
StreamOrDevice s = {}) {
|
||||||
|
return full(shape, array(val, dtype), to_stream(s));
|
||||||
|
}
|
||||||
|
template <typename T>
|
||||||
|
array full(const std::vector<int>& shape, T val, StreamOrDevice s = {}) {
|
||||||
|
return full(shape, array(val), to_stream(s));
|
||||||
|
}
|
||||||
|
|
||||||
|
/** Fill an array of the given shape with zeros. */
|
||||||
|
array zeros(const std::vector<int>& shape, Dtype dtype, StreamOrDevice s = {});
|
||||||
|
inline array zeros(const std::vector<int>& shape, StreamOrDevice s = {}) {
|
||||||
|
return zeros(shape, float32, s);
|
||||||
|
}
|
||||||
|
array zeros_like(const array& a, StreamOrDevice s = {});
|
||||||
|
|
||||||
|
/** Fill an array of the given shape with ones. */
|
||||||
|
array ones(const std::vector<int>& shape, Dtype dtype, StreamOrDevice s = {});
|
||||||
|
inline array ones(const std::vector<int>& shape, StreamOrDevice s = {}) {
|
||||||
|
return ones(shape, float32, s);
|
||||||
|
}
|
||||||
|
array ones_like(const array& a, StreamOrDevice s = {});
|
||||||
|
|
||||||
|
/** array manipulation */
|
||||||
|
|
||||||
|
/** Reshape an array to the given shape. */
|
||||||
|
array reshape(const array& a, std::vector<int> shape, StreamOrDevice s = {});
|
||||||
|
|
||||||
|
/** Remove singleton dimensions at the given axes. */
|
||||||
|
array squeeze(
|
||||||
|
const array& a,
|
||||||
|
const std::vector<int>& axes,
|
||||||
|
StreamOrDevice s = {});
|
||||||
|
|
||||||
|
/** Remove singleton dimensions at the given axis. */
|
||||||
|
inline array squeeze(const array& a, int axis, StreamOrDevice s = {}) {
|
||||||
|
return squeeze(a, std::vector<int>{axis}, s);
|
||||||
|
}
|
||||||
|
|
||||||
|
/** Remove all singleton dimensions. */
|
||||||
|
array squeeze(const array& a, StreamOrDevice s = {});
|
||||||
|
|
||||||
|
/** Add a singleton dimension at the given axes. */
|
||||||
|
array expand_dims(
|
||||||
|
const array& a,
|
||||||
|
const std::vector<int>& axes,
|
||||||
|
StreamOrDevice s = {});
|
||||||
|
|
||||||
|
/** Add a singleton dimension at the given axis. */
|
||||||
|
inline array expand_dims(const array& a, int axis, StreamOrDevice s = {}) {
|
||||||
|
return expand_dims(a, std::vector<int>{axis}, s);
|
||||||
|
}
|
||||||
|
|
||||||
|
/** Slice an array. */
|
||||||
|
array slice(
|
||||||
|
const array& a,
|
||||||
|
std::vector<int> start,
|
||||||
|
std::vector<int> stop,
|
||||||
|
std::vector<int> strides,
|
||||||
|
StreamOrDevice s = {});
|
||||||
|
|
||||||
|
/** Slice an array with a stride of 1 in each dimension. */
|
||||||
|
array slice(
|
||||||
|
const array& a,
|
||||||
|
const std::vector<int>& start,
|
||||||
|
const std::vector<int>& stop,
|
||||||
|
StreamOrDevice s = {});
|
||||||
|
|
||||||
|
/** Split an array into sub-arrays along a given axis. */
|
||||||
|
std::vector<array>
|
||||||
|
split(const array& a, int num_splits, int axis, StreamOrDevice s = {});
|
||||||
|
std::vector<array> split(const array& a, int num_splits, StreamOrDevice s = {});
|
||||||
|
std::vector<array> split(
|
||||||
|
const array& a,
|
||||||
|
const std::vector<int>& indices,
|
||||||
|
int axis,
|
||||||
|
StreamOrDevice s = {});
|
||||||
|
std::vector<array>
|
||||||
|
split(const array& a, const std::vector<int>& indices, StreamOrDevice s = {});
|
||||||
|
|
||||||
|
/** Concatenate arrays along a given axis. */
|
||||||
|
array concatenate(
|
||||||
|
const std::vector<array>& arrays,
|
||||||
|
int axis,
|
||||||
|
StreamOrDevice s = {});
|
||||||
|
array concatenate(const std::vector<array>& arrays, StreamOrDevice s = {});
|
||||||
|
|
||||||
|
/** Permutes the dimensions according to the given axes. */
|
||||||
|
array transpose(const array& a, std::vector<int> axes, StreamOrDevice s = {});
|
||||||
|
inline array transpose(
|
||||||
|
const array& a,
|
||||||
|
std::initializer_list<int> axes,
|
||||||
|
StreamOrDevice s = {}) {
|
||||||
|
return transpose(a, std::vector<int>(axes), s);
|
||||||
|
}
|
||||||
|
|
||||||
|
/** Pad an array with a constant value */
|
||||||
|
array pad(
|
||||||
|
const array& a,
|
||||||
|
const std::vector<int>& axes,
|
||||||
|
const std::vector<int>& low_pad_size,
|
||||||
|
const std::vector<int>& high_pad_size,
|
||||||
|
const array& pad_value = array(0),
|
||||||
|
StreamOrDevice s = {});
|
||||||
|
|
||||||
|
/** Pad an array with a constant value along all axes */
|
||||||
|
array pad(
|
||||||
|
const array& a,
|
||||||
|
const std::vector<std::pair<int, int>>& pad_width,
|
||||||
|
const array& pad_value = array(0),
|
||||||
|
StreamOrDevice s = {});
|
||||||
|
array pad(
|
||||||
|
const array& a,
|
||||||
|
const std::pair<int, int>& pad_width,
|
||||||
|
const array& pad_value = array(0),
|
||||||
|
StreamOrDevice s = {});
|
||||||
|
array pad(
|
||||||
|
const array& a,
|
||||||
|
int pad_width,
|
||||||
|
const array& pad_value = array(0),
|
||||||
|
StreamOrDevice s = {});
|
||||||
|
|
||||||
|
/** Permutes the dimensions in reverse order. */
|
||||||
|
array transpose(const array& a, StreamOrDevice s = {});
|
||||||
|
|
||||||
|
/** Broadcast an array to a given shape. */
|
||||||
|
array broadcast_to(
|
||||||
|
const array& a,
|
||||||
|
const std::vector<int>& shape,
|
||||||
|
StreamOrDevice s = {});
|
||||||
|
|
||||||
|
/** Broadcast a vector of arrays against one another. */
|
||||||
|
std::vector<array> broadcast_arrays(
|
||||||
|
const std::vector<array>& inputs,
|
||||||
|
StreamOrDevice s = {});
|
||||||
|
|
||||||
|
/** Comparison operations */
|
||||||
|
|
||||||
|
/** Returns the bool array with (a == b) element-wise. */
|
||||||
|
array equal(const array& a, const array& b, StreamOrDevice s = {});
|
||||||
|
inline array operator==(const array& a, const array& b) {
|
||||||
|
return equal(a, b);
|
||||||
|
}
|
||||||
|
template <typename T>
|
||||||
|
array operator==(T a, const array& b) {
|
||||||
|
return equal(array(a), b);
|
||||||
|
}
|
||||||
|
template <typename T>
|
||||||
|
array operator==(const array& a, T b) {
|
||||||
|
return equal(a, array(b));
|
||||||
|
}
|
||||||
|
|
||||||
|
/** Returns the bool array with (a != b) element-wise. */
|
||||||
|
array not_equal(const array& a, const array& b, StreamOrDevice s = {});
|
||||||
|
inline array operator!=(const array& a, const array& b) {
|
||||||
|
return not_equal(a, b);
|
||||||
|
}
|
||||||
|
template <typename T>
|
||||||
|
array operator!=(T a, const array& b) {
|
||||||
|
return not_equal(array(a), b);
|
||||||
|
}
|
||||||
|
template <typename T>
|
||||||
|
array operator!=(const array& a, T b) {
|
||||||
|
return not_equal(a, array(b));
|
||||||
|
}
|
||||||
|
|
||||||
|
/** Returns bool array with (a > b) element-wise. */
|
||||||
|
array greater(const array& a, const array& b, StreamOrDevice s = {});
|
||||||
|
inline array operator>(const array& a, const array& b) {
|
||||||
|
return greater(a, b);
|
||||||
|
}
|
||||||
|
template <typename T>
|
||||||
|
array operator>(T a, const array& b) {
|
||||||
|
return greater(array(a), b);
|
||||||
|
}
|
||||||
|
template <typename T>
|
||||||
|
array operator>(const array& a, T b) {
|
||||||
|
return greater(a, array(b));
|
||||||
|
}
|
||||||
|
|
||||||
|
/** Returns bool array with (a >= b) element-wise. */
|
||||||
|
array greater_equal(const array& a, const array& b, StreamOrDevice s = {});
|
||||||
|
inline array operator>=(const array& a, const array& b) {
|
||||||
|
return greater_equal(a, b);
|
||||||
|
}
|
||||||
|
template <typename T>
|
||||||
|
array operator>=(T a, const array& b) {
|
||||||
|
return greater_equal(array(a), b);
|
||||||
|
}
|
||||||
|
template <typename T>
|
||||||
|
array operator>=(const array& a, T b) {
|
||||||
|
return greater_equal(a, array(b));
|
||||||
|
}
|
||||||
|
|
||||||
|
/** Returns bool array with (a < b) element-wise. */
|
||||||
|
array less(const array& a, const array& b, StreamOrDevice s = {});
|
||||||
|
inline array operator<(const array& a, const array& b) {
|
||||||
|
return less(a, b);
|
||||||
|
}
|
||||||
|
template <typename T>
|
||||||
|
array operator<(T a, const array& b) {
|
||||||
|
return less(array(a), b);
|
||||||
|
}
|
||||||
|
template <typename T>
|
||||||
|
array operator<(const array& a, T b) {
|
||||||
|
return less(a, array(b));
|
||||||
|
}
|
||||||
|
|
||||||
|
/** Returns bool array with (a <= b) element-wise. */
|
||||||
|
array less_equal(const array& a, const array& b, StreamOrDevice s = {});
|
||||||
|
inline array operator<=(const array& a, const array& b) {
|
||||||
|
return less_equal(a, b);
|
||||||
|
}
|
||||||
|
template <typename T>
|
||||||
|
array operator<=(T a, const array& b) {
|
||||||
|
return less_equal(array(a), b);
|
||||||
|
}
|
||||||
|
template <typename T>
|
||||||
|
array operator<=(const array& a, T b) {
|
||||||
|
return less_equal(a, array(b));
|
||||||
|
}
|
||||||
|
|
||||||
|
/** True if two arrays have the same shape and elements. */
|
||||||
|
array array_equal(
|
||||||
|
const array& a,
|
||||||
|
const array& b,
|
||||||
|
bool equal_nan,
|
||||||
|
StreamOrDevice s = {});
|
||||||
|
inline array
|
||||||
|
array_equal(const array& a, const array& b, StreamOrDevice s = {}) {
|
||||||
|
return array_equal(a, b, false, s);
|
||||||
|
}
|
||||||
|
|
||||||
|
/** Select from x or y depending on condition. */
|
||||||
|
array where(
|
||||||
|
const array& condition,
|
||||||
|
const array& x,
|
||||||
|
const array& y,
|
||||||
|
StreamOrDevice s = {});
|
||||||
|
|
||||||
|
/** Reduction operations */
|
||||||
|
|
||||||
|
/** True if all elements in the array are true (or non-zero). **/
|
||||||
|
array all(const array& a, bool keepdims, StreamOrDevice s = {});
|
||||||
|
inline array all(const array& a, StreamOrDevice s = {}) {
|
||||||
|
return all(a, false, to_stream(s));
|
||||||
|
}
|
||||||
|
|
||||||
|
/** True if the two arrays are equal within the specified tolerance. */
|
||||||
|
array allclose(
|
||||||
|
const array& a,
|
||||||
|
const array& b,
|
||||||
|
double rtol = 1e-5,
|
||||||
|
double atol = 1e-8,
|
||||||
|
StreamOrDevice s = {});
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Reduces the input along the given axes. An output value is true
|
||||||
|
* if all the corresponding inputs are true.
|
||||||
|
**/
|
||||||
|
array all(
|
||||||
|
const array& a,
|
||||||
|
const std::vector<int>& axes,
|
||||||
|
bool keepdims = false,
|
||||||
|
StreamOrDevice s = {});
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Reduces the input along the given axis. An output value is true
|
||||||
|
* if all the corresponding inputs are true.
|
||||||
|
**/
|
||||||
|
array all(
|
||||||
|
const array& a,
|
||||||
|
int axis,
|
||||||
|
bool keepdims = false,
|
||||||
|
StreamOrDevice s = {});
|
||||||
|
|
||||||
|
/** True if any elements in the array are true (or non-zero). **/
|
||||||
|
array any(const array& a, bool keepdims, StreamOrDevice s = {});
|
||||||
|
inline array any(const array& a, StreamOrDevice s = {}) {
|
||||||
|
return any(a, false, to_stream(s));
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Reduces the input along the given axes. An output value is true
|
||||||
|
* if any of the corresponding inputs are true.
|
||||||
|
**/
|
||||||
|
array any(
|
||||||
|
const array& a,
|
||||||
|
const std::vector<int>& axes,
|
||||||
|
bool keepdims = false,
|
||||||
|
StreamOrDevice s = {});
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Reduces the input along the given axis. An output value is true
|
||||||
|
* if any of the corresponding inputs are true.
|
||||||
|
**/
|
||||||
|
array any(
|
||||||
|
const array& a,
|
||||||
|
int axis,
|
||||||
|
bool keepdims = false,
|
||||||
|
StreamOrDevice s = {});
|
||||||
|
|
||||||
|
/** Sums the elements of an array. */
|
||||||
|
array sum(const array& a, bool keepdims, StreamOrDevice s = {});
|
||||||
|
inline array sum(const array& a, StreamOrDevice s = {}) {
|
||||||
|
return sum(a, false, to_stream(s));
|
||||||
|
}
|
||||||
|
|
||||||
|
/** Sums the elements of an array along the given axes. */
|
||||||
|
array sum(
|
||||||
|
const array& a,
|
||||||
|
const std::vector<int>& axes,
|
||||||
|
bool keepdims = false,
|
||||||
|
StreamOrDevice s = {});
|
||||||
|
|
||||||
|
/** Sums the elements of an array along the given axis. */
|
||||||
|
array sum(
|
||||||
|
const array& a,
|
||||||
|
int axis,
|
||||||
|
bool keepdims = false,
|
||||||
|
StreamOrDevice s = {});
|
||||||
|
|
||||||
|
/** Computes the mean of the elements of an array. */
|
||||||
|
array mean(const array& a, bool keepdims, StreamOrDevice s = {});
|
||||||
|
inline array mean(const array& a, StreamOrDevice s = {}) {
|
||||||
|
return mean(a, false, to_stream(s));
|
||||||
|
}
|
||||||
|
|
||||||
|
/** Computes the mean of the elements of an array along the given axes */
|
||||||
|
array mean(
|
||||||
|
const array& a,
|
||||||
|
const std::vector<int>& axes,
|
||||||
|
bool keepdims = false,
|
||||||
|
StreamOrDevice s = {});
|
||||||
|
|
||||||
|
/** Computes the mean of the elements of an array along the given axis */
|
||||||
|
array mean(
|
||||||
|
const array& a,
|
||||||
|
int axis,
|
||||||
|
bool keepdims = false,
|
||||||
|
StreamOrDevice s = {});
|
||||||
|
|
||||||
|
/** Computes the mean of the elements of an array. */
|
||||||
|
array var(const array& a, bool keepdims, int ddof = 0, StreamOrDevice s = {});
|
||||||
|
inline array var(const array& a, StreamOrDevice s = {}) {
|
||||||
|
return var(a, false, 0, to_stream(s));
|
||||||
|
}
|
||||||
|
|
||||||
|
/** Computes the var of the elements of an array along the given axes */
|
||||||
|
array var(
|
||||||
|
const array& a,
|
||||||
|
const std::vector<int>& axes,
|
||||||
|
bool keepdims = false,
|
||||||
|
int ddof = 0,
|
||||||
|
StreamOrDevice s = {});
|
||||||
|
|
||||||
|
/** Computes the var of the elements of an array along the given axis */
|
||||||
|
array var(
|
||||||
|
const array& a,
|
||||||
|
int axis,
|
||||||
|
bool keepdims = false,
|
||||||
|
int ddof = 0,
|
||||||
|
StreamOrDevice s = {});
|
||||||
|
|
||||||
|
/** The product of all elements of the array. */
|
||||||
|
array prod(const array& a, bool keepdims, StreamOrDevice s = {});
|
||||||
|
inline array prod(const array& a, StreamOrDevice s = {}) {
|
||||||
|
return prod(a, false, to_stream(s));
|
||||||
|
}
|
||||||
|
|
||||||
|
/** The product of the elements of an array along the given axes. */
|
||||||
|
array prod(
|
||||||
|
const array& a,
|
||||||
|
const std::vector<int>& axes,
|
||||||
|
bool keepdims = false,
|
||||||
|
StreamOrDevice s = {});
|
||||||
|
|
||||||
|
/** The product of the elements of an array along the given axis. */
|
||||||
|
array prod(
|
||||||
|
const array& a,
|
||||||
|
int axis,
|
||||||
|
bool keepdims = false,
|
||||||
|
StreamOrDevice s = {});
|
||||||
|
|
||||||
|
/** The maximum of all elements of the array. */
|
||||||
|
array max(const array& a, bool keepdims, StreamOrDevice s = {});
|
||||||
|
inline array max(const array& a, StreamOrDevice s = {}) {
|
||||||
|
return max(a, false, to_stream(s));
|
||||||
|
}
|
||||||
|
|
||||||
|
/** The maximum of the elements of an array along the given axes. */
|
||||||
|
array max(
|
||||||
|
const array& a,
|
||||||
|
const std::vector<int>& axes,
|
||||||
|
bool keepdims = false,
|
||||||
|
StreamOrDevice s = {});
|
||||||
|
|
||||||
|
/** The maximum of the elements of an array along the given axis. */
|
||||||
|
array max(
|
||||||
|
const array& a,
|
||||||
|
int axis,
|
||||||
|
bool keepdims = false,
|
||||||
|
StreamOrDevice s = {});
|
||||||
|
|
||||||
|
/** The minimum of all elements of the array. */
|
||||||
|
array min(const array& a, bool keepdims, StreamOrDevice s = {});
|
||||||
|
inline array min(const array& a, StreamOrDevice s = {}) {
|
||||||
|
return min(a, false, to_stream(s));
|
||||||
|
}
|
||||||
|
|
||||||
|
/** The minimum of the elements of an array along the given axes. */
|
||||||
|
array min(
|
||||||
|
const array& a,
|
||||||
|
const std::vector<int>& axes,
|
||||||
|
bool keepdims = false,
|
||||||
|
StreamOrDevice s = {});
|
||||||
|
|
||||||
|
/** The minimum of the elements of an array along the given axis. */
|
||||||
|
array min(
|
||||||
|
const array& a,
|
||||||
|
int axis,
|
||||||
|
bool keepdims = false,
|
||||||
|
StreamOrDevice s = {});
|
||||||
|
|
||||||
|
/** Returns the index of the minimum value in the array. */
|
||||||
|
array argmin(const array& a, bool keepdims, StreamOrDevice s = {});
|
||||||
|
inline array argmin(const array& a, StreamOrDevice s = {}) {
|
||||||
|
return argmin(a, false, s);
|
||||||
|
}
|
||||||
|
|
||||||
|
/** Returns the indices of the minimum values along a given axis. */
|
||||||
|
array argmin(
|
||||||
|
const array& a,
|
||||||
|
int axis,
|
||||||
|
bool keepdims = false,
|
||||||
|
StreamOrDevice s = {});
|
||||||
|
|
||||||
|
/** Returns the index of the maximum value in the array. */
|
||||||
|
array argmax(const array& a, bool keepdims, StreamOrDevice s = {});
|
||||||
|
inline array argmax(const array& a, StreamOrDevice s = {}) {
|
||||||
|
return argmax(a, false, s);
|
||||||
|
}
|
||||||
|
|
||||||
|
/** Returns the indices of the maximum values along a given axis. */
|
||||||
|
array argmax(
|
||||||
|
const array& a,
|
||||||
|
int axis,
|
||||||
|
bool keepdims = false,
|
||||||
|
StreamOrDevice s = {});
|
||||||
|
|
||||||
|
/** Returns a sorted copy of the flattened array. */
|
||||||
|
array sort(const array& a, StreamOrDevice s = {});
|
||||||
|
|
||||||
|
/** Returns a sorted copy of the array along a given axis. */
|
||||||
|
array sort(const array& a, int axis, StreamOrDevice s = {});
|
||||||
|
|
||||||
|
/** Returns indices that sort the flattened array. */
|
||||||
|
array argsort(const array& a, StreamOrDevice s = {});
|
||||||
|
|
||||||
|
/** Returns indices that sort the array along a given axis. */
|
||||||
|
array argsort(const array& a, int axis, StreamOrDevice s = {});
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Returns a partitioned copy of the flattened array
|
||||||
|
* such that the smaller kth elements are first.
|
||||||
|
**/
|
||||||
|
array partition(const array& a, int kth, StreamOrDevice s = {});
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Returns a partitioned copy of the array along a given axis
|
||||||
|
* such that the smaller kth elements are first.
|
||||||
|
**/
|
||||||
|
array partition(const array& a, int kth, int axis, StreamOrDevice s = {});
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Returns indices that partition the flattened array
|
||||||
|
* such that the smaller kth elements are first.
|
||||||
|
**/
|
||||||
|
array argpartition(const array& a, int kth, StreamOrDevice s = {});
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Returns indices that partition the array along a given axis
|
||||||
|
* such that the smaller kth elements are first.
|
||||||
|
**/
|
||||||
|
array argpartition(const array& a, int kth, int axis, StreamOrDevice s = {});
|
||||||
|
|
||||||
|
/** Returns topk elements of the flattened array. */
|
||||||
|
array topk(const array& a, int k, StreamOrDevice s = {});
|
||||||
|
|
||||||
|
/** Returns topk elements of the array along a given axis. */
|
||||||
|
array topk(const array& a, int k, int axis, StreamOrDevice s = {});
|
||||||
|
|
||||||
|
/** The logsumexp of all elements of the array. */
|
||||||
|
array logsumexp(const array& a, bool keepdims, StreamOrDevice s = {});
|
||||||
|
inline array logsumexp(const array& a, StreamOrDevice s = {}) {
|
||||||
|
return logsumexp(a, false, to_stream(s));
|
||||||
|
}
|
||||||
|
|
||||||
|
/** The logsumexp of the elements of an array along the given axes. */
|
||||||
|
array logsumexp(
|
||||||
|
const array& a,
|
||||||
|
const std::vector<int>& axes,
|
||||||
|
bool keepdims = false,
|
||||||
|
StreamOrDevice s = {});
|
||||||
|
|
||||||
|
/** The logsumexp of the elements of an array along the given axis. */
|
||||||
|
array logsumexp(
|
||||||
|
const array& a,
|
||||||
|
int axis,
|
||||||
|
bool keepdims = false,
|
||||||
|
StreamOrDevice s = {});
|
||||||
|
|
||||||
|
/** Simple arithmetic operations */
|
||||||
|
|
||||||
|
/** Absolute value of elements in an array. */
|
||||||
|
array abs(const array& a, StreamOrDevice s = {});
|
||||||
|
|
||||||
|
/** Negate an array. */
|
||||||
|
array negative(const array& a, StreamOrDevice s = {});
|
||||||
|
array operator-(const array& a);
|
||||||
|
|
||||||
|
/** The sign of the elements in an array. */
|
||||||
|
array sign(const array& a, StreamOrDevice s = {});
|
||||||
|
|
||||||
|
/** Logical not of an array */
|
||||||
|
array logical_not(const array& a, StreamOrDevice s = {});
|
||||||
|
|
||||||
|
/** The reciprocal (1/x) of the elements in an array. */
|
||||||
|
array reciprocal(const array& a, StreamOrDevice s = {});
|
||||||
|
|
||||||
|
/** Add two arrays. */
|
||||||
|
array add(const array& a, const array& b, StreamOrDevice s = {});
|
||||||
|
array operator+(const array& a, const array& b);
|
||||||
|
template <typename T>
|
||||||
|
array operator+(T a, const array& b) {
|
||||||
|
return add(array(a), b);
|
||||||
|
}
|
||||||
|
template <typename T>
|
||||||
|
array operator+(const array& a, T b) {
|
||||||
|
return add(a, array(b));
|
||||||
|
}
|
||||||
|
|
||||||
|
/** Subtract two arrays. */
|
||||||
|
array subtract(const array& a, const array& b, StreamOrDevice s = {});
|
||||||
|
array operator-(const array& a, const array& b);
|
||||||
|
template <typename T>
|
||||||
|
array operator-(T a, const array& b) {
|
||||||
|
return subtract(array(a), b);
|
||||||
|
}
|
||||||
|
template <typename T>
|
||||||
|
array operator-(const array& a, T b) {
|
||||||
|
return subtract(a, array(b));
|
||||||
|
}
|
||||||
|
|
||||||
|
/** Multiply two arrays. */
|
||||||
|
array multiply(const array& a, const array& b, StreamOrDevice s = {});
|
||||||
|
array operator*(const array& a, const array& b);
|
||||||
|
template <typename T>
|
||||||
|
array operator*(T a, const array& b) {
|
||||||
|
return multiply(array(a), b);
|
||||||
|
}
|
||||||
|
template <typename T>
|
||||||
|
array operator*(const array& a, T b) {
|
||||||
|
return multiply(a, array(b));
|
||||||
|
}
|
||||||
|
|
||||||
|
/** Divide two arrays. */
|
||||||
|
array divide(const array& a, const array& b, StreamOrDevice s = {});
|
||||||
|
array operator/(const array& a, const array& b);
|
||||||
|
array operator/(double a, const array& b);
|
||||||
|
array operator/(const array& a, double b);
|
||||||
|
|
||||||
|
/** Element-wise maximum between two arrays. */
|
||||||
|
array maximum(const array& a, const array& b, StreamOrDevice s = {});
|
||||||
|
|
||||||
|
/** Element-wise minimum between two arrays. */
|
||||||
|
array minimum(const array& a, const array& b, StreamOrDevice s = {});
|
||||||
|
|
||||||
|
/** Square the elements of an array. */
|
||||||
|
array square(const array& a, StreamOrDevice s = {});
|
||||||
|
|
||||||
|
/** Exponential of the elements of an array. */
|
||||||
|
array exp(const array& a, StreamOrDevice s = {});
|
||||||
|
|
||||||
|
/** Sine of the elements of an array */
|
||||||
|
array sin(const array& a, StreamOrDevice s = {});
|
||||||
|
|
||||||
|
/** Cosine of the elements of an array */
|
||||||
|
array cos(const array& a, StreamOrDevice s = {});
|
||||||
|
|
||||||
|
/** Tangent of the elements of an array */
|
||||||
|
array tan(const array& a, StreamOrDevice s = {});
|
||||||
|
|
||||||
|
/** Arc Sine of the elements of an array */
|
||||||
|
array arcsin(const array& a, StreamOrDevice s = {});
|
||||||
|
|
||||||
|
/** Arc Cosine of the elements of an array */
|
||||||
|
array arccos(const array& a, StreamOrDevice s = {});
|
||||||
|
|
||||||
|
/** Arc Tangent of the elements of an array */
|
||||||
|
array arctan(const array& a, StreamOrDevice s = {});
|
||||||
|
|
||||||
|
/** Hyperbolic Sine of the elements of an array */
|
||||||
|
array sinh(const array& a, StreamOrDevice s = {});
|
||||||
|
|
||||||
|
/** Hyperbolic Cosine of the elements of an array */
|
||||||
|
array cosh(const array& a, StreamOrDevice s = {});
|
||||||
|
|
||||||
|
/** Hyperbolic Tangent of the elements of an array */
|
||||||
|
array tanh(const array& a, StreamOrDevice s = {});
|
||||||
|
|
||||||
|
/** Inverse Hyperbolic Sine of the elements of an array */
|
||||||
|
array arcsinh(const array& a, StreamOrDevice s = {});
|
||||||
|
|
||||||
|
/** Inverse Hyperbolic Cosine of the elements of an array */
|
||||||
|
array arccosh(const array& a, StreamOrDevice s = {});
|
||||||
|
|
||||||
|
/** Inverse Hyperbolic Tangent of the elements of an array */
|
||||||
|
array arctanh(const array& a, StreamOrDevice s = {});
|
||||||
|
|
||||||
|
/** Natural logarithm of the elements of an array. */
|
||||||
|
array log(const array& a, StreamOrDevice s = {});
|
||||||
|
|
||||||
|
/** Log base 2 of the elements of an array. */
|
||||||
|
array log2(const array& a, StreamOrDevice s = {});
|
||||||
|
|
||||||
|
/** Log base 10 of the elements of an array. */
|
||||||
|
array log10(const array& a, StreamOrDevice s = {});
|
||||||
|
|
||||||
|
/** Natural logarithm of one plus elements in the array: `log(1 + a)`. */
|
||||||
|
array log1p(const array& a, StreamOrDevice s = {});
|
||||||
|
|
||||||
|
/** Log-add-exp of one elements in the array: `log(exp(a) + exp(b))`. */
|
||||||
|
array logaddexp(const array& a, const array& b, StreamOrDevice s = {});
|
||||||
|
|
||||||
|
/** Element-wise logistic sigmoid of the array: `1 / (1 + exp(-x)`. */
|
||||||
|
array sigmoid(const array& a, StreamOrDevice s = {});
|
||||||
|
|
||||||
|
/** Computes the error function of the elements of an array. */
|
||||||
|
array erf(const array& a, StreamOrDevice s = {});
|
||||||
|
|
||||||
|
/** Computes the inverse error function of the elements of an array. */
|
||||||
|
array erfinv(const array& a, StreamOrDevice s = {});
|
||||||
|
|
||||||
|
/** Stop the flow of gradients. */
|
||||||
|
array stop_gradient(const array& a, StreamOrDevice s = {});
|
||||||
|
|
||||||
|
/** Matrix-matrix multiplication. */
|
||||||
|
array matmul(const array& a, const array& b, StreamOrDevice s = {});
|
||||||
|
|
||||||
|
/** Gather array entries given indices and slices */
|
||||||
|
array gather(
|
||||||
|
const array& a,
|
||||||
|
const std::vector<array>& indices,
|
||||||
|
const std::vector<int>& axes,
|
||||||
|
const std::vector<int>& slice_sizes,
|
||||||
|
StreamOrDevice s = {});
|
||||||
|
inline array gather(
|
||||||
|
const array& a,
|
||||||
|
const array& indices,
|
||||||
|
int axis,
|
||||||
|
const std::vector<int>& slice_sizes,
|
||||||
|
StreamOrDevice s = {}) {
|
||||||
|
return gather(a, {indices}, std::vector<int>{axis}, slice_sizes, s);
|
||||||
|
}
|
||||||
|
|
||||||
|
/** Take array slices at the given indices of the specified axis. */
|
||||||
|
array take(
|
||||||
|
const array& a,
|
||||||
|
const array& indices,
|
||||||
|
int axis,
|
||||||
|
StreamOrDevice s = {});
|
||||||
|
|
||||||
|
/** Take array entries at the given indices treating the array as flattened. */
|
||||||
|
array take(const array& a, const array& indices, StreamOrDevice s = {});
|
||||||
|
|
||||||
|
/** Take array entries given indices along the axis */
|
||||||
|
array take_along_axis(
|
||||||
|
const array& a,
|
||||||
|
const array& indices,
|
||||||
|
int axis,
|
||||||
|
StreamOrDevice s = {});
|
||||||
|
|
||||||
|
/** Scatter updates to given linear indices */
|
||||||
|
array scatter(
|
||||||
|
const array& a,
|
||||||
|
const std::vector<array>& indices,
|
||||||
|
const array& updates,
|
||||||
|
const std::vector<int>& axes,
|
||||||
|
StreamOrDevice s = {});
|
||||||
|
inline array scatter(
|
||||||
|
const array& a,
|
||||||
|
const array& indices,
|
||||||
|
const array& updates,
|
||||||
|
int axis,
|
||||||
|
StreamOrDevice s = {}) {
|
||||||
|
return scatter(a, {indices}, updates, std::vector<int>{axis}, s);
|
||||||
|
}
|
||||||
|
|
||||||
|
/** Scatter and add updates to given indices */
|
||||||
|
array scatter_add(
|
||||||
|
const array& a,
|
||||||
|
const std::vector<array>& indices,
|
||||||
|
const array& updates,
|
||||||
|
const std::vector<int>& axes,
|
||||||
|
StreamOrDevice s = {});
|
||||||
|
inline array scatter_add(
|
||||||
|
const array& a,
|
||||||
|
const array& indices,
|
||||||
|
const array& updates,
|
||||||
|
int axis,
|
||||||
|
StreamOrDevice s = {}) {
|
||||||
|
return scatter_add(a, {indices}, updates, std::vector<int>{axis}, s);
|
||||||
|
}
|
||||||
|
|
||||||
|
/** Scatter and prod updates to given indices */
|
||||||
|
array scatter_prod(
|
||||||
|
const array& a,
|
||||||
|
const std::vector<array>& indices,
|
||||||
|
const array& updates,
|
||||||
|
const std::vector<int>& axes,
|
||||||
|
StreamOrDevice s = {});
|
||||||
|
inline array scatter_prod(
|
||||||
|
const array& a,
|
||||||
|
const array& indices,
|
||||||
|
const array& updates,
|
||||||
|
int axis,
|
||||||
|
StreamOrDevice s = {}) {
|
||||||
|
return scatter_prod(a, {indices}, updates, std::vector<int>{axis}, s);
|
||||||
|
}
|
||||||
|
|
||||||
|
/** Scatter and max updates to given linear indices */
|
||||||
|
array scatter_max(
|
||||||
|
const array& a,
|
||||||
|
const std::vector<array>& indices,
|
||||||
|
const array& updates,
|
||||||
|
const std::vector<int>& axes,
|
||||||
|
StreamOrDevice s = {});
|
||||||
|
inline array scatter_max(
|
||||||
|
const array& a,
|
||||||
|
const array& indices,
|
||||||
|
const array& updates,
|
||||||
|
int axis,
|
||||||
|
StreamOrDevice s = {}) {
|
||||||
|
return scatter_max(a, {indices}, updates, std::vector<int>{axis}, s);
|
||||||
|
}
|
||||||
|
/** Scatter and min updates to given linear indices */
|
||||||
|
array scatter_min(
|
||||||
|
const array& a,
|
||||||
|
const std::vector<array>& indices,
|
||||||
|
const array& updates,
|
||||||
|
const std::vector<int>& axes,
|
||||||
|
StreamOrDevice s = {});
|
||||||
|
inline array scatter_min(
|
||||||
|
const array& a,
|
||||||
|
const array& indices,
|
||||||
|
const array& updates,
|
||||||
|
int axis,
|
||||||
|
StreamOrDevice s = {}) {
|
||||||
|
return scatter_min(a, {indices}, updates, std::vector<int>{axis}, s);
|
||||||
|
}
|
||||||
|
|
||||||
|
/** Square root the elements of an array. */
|
||||||
|
array sqrt(const array& a, StreamOrDevice s = {});
|
||||||
|
|
||||||
|
/** Square root and reciprocal the elements of an array. */
|
||||||
|
array rsqrt(const array& a, StreamOrDevice s = {});
|
||||||
|
|
||||||
|
/** Softmax of an array. */
|
||||||
|
array softmax(
|
||||||
|
const array& a,
|
||||||
|
const std::vector<int>& axes,
|
||||||
|
StreamOrDevice s = {});
|
||||||
|
|
||||||
|
/** Softmax of an array. */
|
||||||
|
array softmax(const array& a, StreamOrDevice s = {});
|
||||||
|
|
||||||
|
/** Softmax of an array. */
|
||||||
|
inline array softmax(const array& a, int axis, StreamOrDevice s = {}) {
|
||||||
|
return softmax(a, std::vector<int>{axis}, s);
|
||||||
|
}
|
||||||
|
|
||||||
|
/** Raise elements of a to the power of b element-wise */
|
||||||
|
array power(const array& a, const array& b, StreamOrDevice s = {});
|
||||||
|
inline array operator^(const array& a, const array& b) {
|
||||||
|
return power(a, b);
|
||||||
|
}
|
||||||
|
template <typename T>
|
||||||
|
array operator^(T a, const array& b) {
|
||||||
|
return power(array(a), b);
|
||||||
|
}
|
||||||
|
template <typename T>
|
||||||
|
array operator^(const array& a, T b) {
|
||||||
|
return power(a, array(b));
|
||||||
|
}
|
||||||
|
|
||||||
|
/** Cumulative sum of an array. */
|
||||||
|
array cumsum(
|
||||||
|
const array& a,
|
||||||
|
int axis,
|
||||||
|
bool reverse = false,
|
||||||
|
bool inclusive = true,
|
||||||
|
StreamOrDevice s = {});
|
||||||
|
|
||||||
|
/** Cumulative product of an array. */
|
||||||
|
array cumprod(
|
||||||
|
const array& a,
|
||||||
|
int axis,
|
||||||
|
bool reverse = false,
|
||||||
|
bool inclusive = true,
|
||||||
|
StreamOrDevice s = {});
|
||||||
|
|
||||||
|
/** Cumulative max of an array. */
|
||||||
|
array cummax(
|
||||||
|
const array& a,
|
||||||
|
int axis,
|
||||||
|
bool reverse = false,
|
||||||
|
bool inclusive = true,
|
||||||
|
StreamOrDevice s = {});
|
||||||
|
|
||||||
|
/** Cumulative min of an array. */
|
||||||
|
array cummin(
|
||||||
|
const array& a,
|
||||||
|
int axis,
|
||||||
|
bool reverse = false,
|
||||||
|
bool inclusive = true,
|
||||||
|
StreamOrDevice s = {});
|
||||||
|
|
||||||
|
/** Convolution operations */
|
||||||
|
|
||||||
|
/** 1D convolution with a filter */
|
||||||
|
array conv1d(
|
||||||
|
const array& input,
|
||||||
|
const array& weight,
|
||||||
|
int stride = 1,
|
||||||
|
int padding = 0,
|
||||||
|
int dilation = 1,
|
||||||
|
int groups = 1,
|
||||||
|
StreamOrDevice s = {});
|
||||||
|
|
||||||
|
/** 2D convolution with a filter */
|
||||||
|
array conv2d(
|
||||||
|
const array& input,
|
||||||
|
const array& weight,
|
||||||
|
const std::pair<int, int>& stride = {1, 1},
|
||||||
|
const std::pair<int, int>& padding = {0, 0},
|
||||||
|
const std::pair<int, int>& dilation = {1, 1},
|
||||||
|
int groups = 1,
|
||||||
|
StreamOrDevice s = {});
|
||||||
|
|
||||||
|
/** Serialization operations */
|
||||||
|
|
||||||
|
/** Save array to out stream in .npy format */
|
||||||
|
void save(
|
||||||
|
std::shared_ptr<io::Writer> out_stream,
|
||||||
|
array a,
|
||||||
|
bool retain_graph = true);
|
||||||
|
|
||||||
|
/** Save array to file in .npy format */
|
||||||
|
void save(const std::string& file, array a, bool retain_graph = true);
|
||||||
|
|
||||||
|
/** Load array from reader in .npy format */
|
||||||
|
array load(std::shared_ptr<io::Reader> in_stream, StreamOrDevice s = {});
|
||||||
|
|
||||||
|
/** Load array from file in .npy format */
|
||||||
|
array load(const std::string& file, StreamOrDevice s = {});
|
||||||
|
|
||||||
|
} // namespace mlx::core
|
||||||
1527
mlx/primitives.h
Normal file
1527
mlx/primitives.h
Normal file
File diff suppressed because it is too large
Load Diff
300
mlx/random.cpp
Normal file
300
mlx/random.cpp
Normal file
@@ -0,0 +1,300 @@
|
|||||||
|
#include <cmath>
|
||||||
|
#include <sstream>
|
||||||
|
|
||||||
|
#include "mlx/ops.h"
|
||||||
|
#include "mlx/primitives.h"
|
||||||
|
#include "mlx/random.h"
|
||||||
|
#include "mlx/utils.h"
|
||||||
|
|
||||||
|
namespace mlx::core::random {
|
||||||
|
|
||||||
|
KeySequence::KeySequence(uint64_t seed) : key_(key(seed)) {}
|
||||||
|
|
||||||
|
void KeySequence::seed(uint64_t seed) {
|
||||||
|
key_ = key((seed));
|
||||||
|
}
|
||||||
|
|
||||||
|
array KeySequence::next() {
|
||||||
|
auto out = split(key_);
|
||||||
|
key_ = out.first;
|
||||||
|
return out.second;
|
||||||
|
}
|
||||||
|
|
||||||
|
void seed(uint64_t seed) {
|
||||||
|
KeySequence::default_().seed(seed);
|
||||||
|
}
|
||||||
|
|
||||||
|
array key(uint64_t seed) {
|
||||||
|
uint32_t k1 = static_cast<uint32_t>(seed >> 32);
|
||||||
|
uint32_t k2 = static_cast<uint32_t>(seed);
|
||||||
|
return array({k1, k2});
|
||||||
|
}
|
||||||
|
|
||||||
|
array bits(
|
||||||
|
const std::vector<int>& shape,
|
||||||
|
int width /* 4 */,
|
||||||
|
const std::optional<array>& key_ /*= nullopt */,
|
||||||
|
StreamOrDevice s /* = {} */) {
|
||||||
|
auto key = key_ ? *key_ : KeySequence::default_().next();
|
||||||
|
if (key.dtype() != uint32) {
|
||||||
|
std::ostringstream msg;
|
||||||
|
msg << "Expected key type uint32 but received " << key.dtype() << ".";
|
||||||
|
throw std::invalid_argument(msg.str());
|
||||||
|
}
|
||||||
|
if (key.shape() != std::vector<int>{2}) {
|
||||||
|
std::ostringstream msg;
|
||||||
|
msg << "Expected key shape (2) but received " << key.shape() << ".";
|
||||||
|
throw std::invalid_argument(msg.str());
|
||||||
|
}
|
||||||
|
|
||||||
|
auto get_dtype = [width]() {
|
||||||
|
switch (width) {
|
||||||
|
case 4:
|
||||||
|
return uint32;
|
||||||
|
case 2:
|
||||||
|
return uint16;
|
||||||
|
case 1:
|
||||||
|
return uint8;
|
||||||
|
default:
|
||||||
|
std::ostringstream msg;
|
||||||
|
msg << "[bits] Bit width must be in {1, 2, 4} but got " << width << ".";
|
||||||
|
throw std::invalid_argument(msg.str());
|
||||||
|
}
|
||||||
|
};
|
||||||
|
return array(
|
||||||
|
shape,
|
||||||
|
get_dtype(),
|
||||||
|
std::make_unique<RandomBits>(to_stream(s), shape, width),
|
||||||
|
{key});
|
||||||
|
}
|
||||||
|
|
||||||
|
std::pair<array, array> split(const array& key, StreamOrDevice s /* = {} */) {
|
||||||
|
auto stream = to_stream(s);
|
||||||
|
auto out = mlx::core::split(random::split(key, 2, stream), 2, stream);
|
||||||
|
return {reshape(out[0], {2}, stream), reshape(out[1], {2}, stream)};
|
||||||
|
}
|
||||||
|
|
||||||
|
array split(const array& key, int num, StreamOrDevice s /* = {} */) {
|
||||||
|
return bits({num, 2}, 4, key, s);
|
||||||
|
}
|
||||||
|
|
||||||
|
array uniform(
|
||||||
|
const array& low,
|
||||||
|
const array& high,
|
||||||
|
const std::vector<int>& shape,
|
||||||
|
Dtype dtype /* = float32 */,
|
||||||
|
const std::optional<array>& key /*= nullopt */,
|
||||||
|
StreamOrDevice s /* = {} */) {
|
||||||
|
if (!is_floating_point(dtype)) {
|
||||||
|
throw std::invalid_argument(
|
||||||
|
"Can only generate uniform numbers with floating point type.");
|
||||||
|
}
|
||||||
|
|
||||||
|
auto stream = to_stream(s);
|
||||||
|
auto range = subtract(high, low, stream);
|
||||||
|
auto out_shape = broadcast_shapes(shape, range.shape());
|
||||||
|
if (out_shape != shape) {
|
||||||
|
std::ostringstream msg;
|
||||||
|
msg << "Cannot generate random values of shape " << shape
|
||||||
|
<< " from broadcasted shape " << out_shape << ".";
|
||||||
|
throw std::invalid_argument(msg.str());
|
||||||
|
}
|
||||||
|
// Get random values between [0, nextafter(maxval, 0.0f)] since samples must
|
||||||
|
// be in [low, high)
|
||||||
|
// TODO replace minimum with modulo uint32_t(nextafter(maxval, 0.0f)) to avoid
|
||||||
|
// clipping effects
|
||||||
|
float maxval = std::numeric_limits<uint32_t>::max();
|
||||||
|
auto upper = array(std::nextafter(maxval, 0.0f), dtype);
|
||||||
|
auto out = minimum(bits(shape, size_of(dtype), key, stream), upper, stream);
|
||||||
|
out = divide(out, array(maxval, dtype), stream);
|
||||||
|
return add(multiply(range, out, stream), low, stream);
|
||||||
|
}
|
||||||
|
|
||||||
|
array uniform(
|
||||||
|
const std::vector<int>& shape,
|
||||||
|
Dtype dtype,
|
||||||
|
const std::optional<array>& key /*= nullopt */,
|
||||||
|
StreamOrDevice s /* = {} */) {
|
||||||
|
return uniform(
|
||||||
|
array(0.0, dtype), array(1.0, dtype), shape, dtype, key, to_stream(s));
|
||||||
|
}
|
||||||
|
|
||||||
|
array normal(
|
||||||
|
const std::vector<int>& shape,
|
||||||
|
Dtype dtype,
|
||||||
|
const std::optional<array>& key /*= nullopt */,
|
||||||
|
StreamOrDevice s /* = {} */) {
|
||||||
|
auto stream = to_stream(s);
|
||||||
|
auto low = array(std::nextafter(-1.0f, 0.0f), dtype);
|
||||||
|
auto high = array(1.0f, dtype);
|
||||||
|
auto samples = uniform(low, high, shape, dtype, key, stream);
|
||||||
|
return multiply(
|
||||||
|
array(std::sqrt(2.0), dtype), erfinv(samples, stream), stream);
|
||||||
|
}
|
||||||
|
|
||||||
|
array randint(
|
||||||
|
const array& low,
|
||||||
|
const array& high,
|
||||||
|
const std::vector<int>& shape,
|
||||||
|
Dtype dtype /* = int32 */,
|
||||||
|
const std::optional<array>& key /*= nullopt */,
|
||||||
|
StreamOrDevice s /* = {} */) {
|
||||||
|
if (!is_integral(dtype)) {
|
||||||
|
throw std::invalid_argument(
|
||||||
|
"[randint] randint only accepts integer dtypes and bool.");
|
||||||
|
}
|
||||||
|
auto u = uniform(low, high, shape, float32, key, s);
|
||||||
|
return astype(maximum(u, low, s), dtype, s);
|
||||||
|
}
|
||||||
|
|
||||||
|
array bernoulli(
|
||||||
|
const array& p,
|
||||||
|
const std::vector<int>& shape,
|
||||||
|
const std::optional<array>& key /*= nullopt */,
|
||||||
|
StreamOrDevice s /* = {} */) {
|
||||||
|
if (!is_floating_point(p.dtype())) {
|
||||||
|
throw std::invalid_argument(
|
||||||
|
"[bernoulli] bernoulli probability `p` must be a float type.");
|
||||||
|
}
|
||||||
|
auto res = uniform(shape, p.dtype(), key, s);
|
||||||
|
res = less(res, p, s);
|
||||||
|
if (res.shape() != shape) {
|
||||||
|
throw std::invalid_argument(
|
||||||
|
"[bernoulli] shape of `p` is incompatible with argument `shape`.");
|
||||||
|
}
|
||||||
|
return res;
|
||||||
|
}
|
||||||
|
|
||||||
|
array bernoulli(
|
||||||
|
const array& p,
|
||||||
|
const std::optional<array>& key /*= nullopt */,
|
||||||
|
StreamOrDevice s /* = {} */) {
|
||||||
|
return bernoulli(p, p.shape(), key, s);
|
||||||
|
}
|
||||||
|
|
||||||
|
array bernoulli(
|
||||||
|
const std::optional<array>& key /*= nullopt */,
|
||||||
|
StreamOrDevice s /* = {} */) {
|
||||||
|
return bernoulli(array(0.5f), key, s);
|
||||||
|
}
|
||||||
|
|
||||||
|
array truncated_normal(
|
||||||
|
const array& lower,
|
||||||
|
const array& upper,
|
||||||
|
const std::vector<int>& shape,
|
||||||
|
Dtype dtype /* = float32 */,
|
||||||
|
const std::optional<array>& key /*= nullopt */,
|
||||||
|
StreamOrDevice s /* = {} */) {
|
||||||
|
// Same as
|
||||||
|
// https://jax.readthedocs.io/en/latest/_modules/jax/_src/random.html#truncated_normal
|
||||||
|
|
||||||
|
if (!is_floating_point(dtype)) {
|
||||||
|
throw std::invalid_argument(
|
||||||
|
"[trunc_normal] trunc_normal only accepts floating point dtypes.");
|
||||||
|
}
|
||||||
|
|
||||||
|
auto sqrt2 = array(std::sqrt(2.0), dtype);
|
||||||
|
auto lower_t = astype(lower, dtype, s);
|
||||||
|
auto upper_t = astype(upper, dtype, s);
|
||||||
|
auto a = erf(divide(lower_t, sqrt2, s), s);
|
||||||
|
auto b = erf(divide(upper_t, sqrt2, s), s);
|
||||||
|
auto u = uniform(a, b, shape, dtype, key, s);
|
||||||
|
auto out = multiply(sqrt2, erfinv(u, s), s);
|
||||||
|
|
||||||
|
// Clip in bouds
|
||||||
|
return maximum(minimum(upper_t, out, s), lower_t, s);
|
||||||
|
}
|
||||||
|
|
||||||
|
array truncated_normal(
|
||||||
|
const array& lower,
|
||||||
|
const array& upper,
|
||||||
|
Dtype dtype /* = float32 */,
|
||||||
|
const std::optional<array>& key /*= nullopt */,
|
||||||
|
StreamOrDevice s /* = {} */) {
|
||||||
|
auto shape = broadcast_shapes(lower.shape(), upper.shape());
|
||||||
|
return truncated_normal(lower, upper, shape, dtype, key, s);
|
||||||
|
}
|
||||||
|
|
||||||
|
array gumbel(
|
||||||
|
const std::vector<int>& shape,
|
||||||
|
Dtype dtype /* = float32 */,
|
||||||
|
const std::optional<array>& key /*= nullopt */,
|
||||||
|
StreamOrDevice s /* = {} */) {
|
||||||
|
// -log(-log(uniform(shape)))
|
||||||
|
return negative(
|
||||||
|
log(negative(log(uniform(shape, dtype, key, s), s), s), s), s);
|
||||||
|
}
|
||||||
|
|
||||||
|
int get_valid_axis(int axis, int ndim) {
|
||||||
|
int ax = axis < 0 ? axis + ndim : axis;
|
||||||
|
if (ax < 0 || ax >= ndim) {
|
||||||
|
std::ostringstream msg;
|
||||||
|
msg << "[categorical] Invalid axis " << axis << " for logits with " << ndim
|
||||||
|
<< " dimensions.";
|
||||||
|
throw std::invalid_argument(msg.str());
|
||||||
|
}
|
||||||
|
return ax;
|
||||||
|
}
|
||||||
|
|
||||||
|
array categorical_impl(
|
||||||
|
const array& logits,
|
||||||
|
int axis,
|
||||||
|
const std::vector<int>& shape,
|
||||||
|
const std::optional<array>& key /*= nullopt */,
|
||||||
|
StreamOrDevice s) {
|
||||||
|
auto gumbel_shape = shape;
|
||||||
|
auto offset = axis + shape.size() - logits.ndim() + 1;
|
||||||
|
gumbel_shape.insert(gumbel_shape.begin() + offset, logits.shape(axis));
|
||||||
|
auto g = gumbel(gumbel_shape, float32, key, s);
|
||||||
|
return argmax(add(g, logits, s), offset, false, s);
|
||||||
|
}
|
||||||
|
|
||||||
|
array categorical(
|
||||||
|
const array& logits,
|
||||||
|
int axis,
|
||||||
|
const std::vector<int>& shape,
|
||||||
|
const std::optional<array>& key /*= nullopt */,
|
||||||
|
StreamOrDevice s /* = {} */) {
|
||||||
|
// Validate and normalize axis
|
||||||
|
axis = get_valid_axis(axis, logits.ndim());
|
||||||
|
|
||||||
|
// Check that shape broadcasts with reduce(logits, axis)
|
||||||
|
auto reduced_shape = logits.shape();
|
||||||
|
reduced_shape.erase(reduced_shape.begin() + axis);
|
||||||
|
if (broadcast_shapes(shape, reduced_shape) != shape) {
|
||||||
|
std::ostringstream msg;
|
||||||
|
msg << "[categorical] Requested shape " << shape
|
||||||
|
<< " is not broadcast compatable with reduced logits shape"
|
||||||
|
<< reduced_shape << ".";
|
||||||
|
throw std::invalid_argument(msg.str());
|
||||||
|
}
|
||||||
|
|
||||||
|
return categorical_impl(logits, axis, shape, key, s);
|
||||||
|
}
|
||||||
|
|
||||||
|
array categorical(
|
||||||
|
const array& logits_,
|
||||||
|
int axis,
|
||||||
|
int num_samples,
|
||||||
|
const std::optional<array>& key /*= nullopt */,
|
||||||
|
StreamOrDevice s /* = {} */) {
|
||||||
|
axis = get_valid_axis(axis, logits_.ndim());
|
||||||
|
auto logits = expand_dims(logits_, -1);
|
||||||
|
auto shape = logits.shape();
|
||||||
|
shape.erase(shape.begin() + axis);
|
||||||
|
shape.back() = num_samples;
|
||||||
|
return categorical_impl(logits, axis, shape, key, s);
|
||||||
|
}
|
||||||
|
|
||||||
|
array categorical(
|
||||||
|
const array& logits,
|
||||||
|
int axis /* = -1 */,
|
||||||
|
const std::optional<array>& key /*= nullopt */,
|
||||||
|
StreamOrDevice s /* = {} */) {
|
||||||
|
axis = get_valid_axis(axis, logits.ndim());
|
||||||
|
auto shape = logits.shape();
|
||||||
|
shape.erase(shape.begin() + axis);
|
||||||
|
return categorical_impl(logits, axis, shape, key, s);
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace mlx::core::random
|
||||||
43
mlx/scheduler.cpp
Normal file
43
mlx/scheduler.cpp
Normal file
@@ -0,0 +1,43 @@
|
|||||||
|
#include "mlx/scheduler.h"
|
||||||
|
#include "mlx/backend/metal/metal.h"
|
||||||
|
|
||||||
|
namespace mlx::core {
|
||||||
|
|
||||||
|
Stream default_stream(Device d) {
|
||||||
|
if (!metal::is_available() && d == Device::gpu) {
|
||||||
|
throw std::invalid_argument(
|
||||||
|
"[default_stream] Cannot get gpu stream without gpu backend.");
|
||||||
|
}
|
||||||
|
return scheduler::scheduler().get_default_stream(d);
|
||||||
|
}
|
||||||
|
|
||||||
|
void set_default_stream(Stream s) {
|
||||||
|
if (!metal::is_available() && s.device == Device::gpu) {
|
||||||
|
throw std::invalid_argument(
|
||||||
|
"[set_default_stream] Cannot set gpu stream without gpu backend.");
|
||||||
|
}
|
||||||
|
return scheduler::scheduler().set_default_stream(s);
|
||||||
|
}
|
||||||
|
|
||||||
|
Stream new_stream(Device d) {
|
||||||
|
if (!metal::is_available() && d == Device::gpu) {
|
||||||
|
throw std::invalid_argument(
|
||||||
|
"[new_stream] Cannot make gpu stream without gpu backend.");
|
||||||
|
}
|
||||||
|
return scheduler::scheduler().new_stream(d);
|
||||||
|
}
|
||||||
|
|
||||||
|
Stream new_stream() {
|
||||||
|
return scheduler::scheduler().new_stream(default_device());
|
||||||
|
}
|
||||||
|
|
||||||
|
namespace scheduler {
|
||||||
|
|
||||||
|
/** A singleton scheduler to manage devices, streams, and task execution. */
|
||||||
|
Scheduler& scheduler() {
|
||||||
|
static Scheduler scheduler;
|
||||||
|
return scheduler;
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace scheduler
|
||||||
|
} // namespace mlx::core
|
||||||
170
mlx/scheduler.h
Normal file
170
mlx/scheduler.h
Normal file
@@ -0,0 +1,170 @@
|
|||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include <atomic>
|
||||||
|
#include <future>
|
||||||
|
#include <queue>
|
||||||
|
#include <thread>
|
||||||
|
#include <unordered_map>
|
||||||
|
|
||||||
|
#include "mlx/backend/metal/metal.h"
|
||||||
|
#include "mlx/device.h"
|
||||||
|
#include "mlx/stream.h"
|
||||||
|
|
||||||
|
namespace mlx::core::scheduler {
|
||||||
|
|
||||||
|
struct StreamThread {
|
||||||
|
std::mutex mtx;
|
||||||
|
std::queue<std::function<void()>> q;
|
||||||
|
std::condition_variable cond;
|
||||||
|
bool stop;
|
||||||
|
Stream stream;
|
||||||
|
std::thread thread;
|
||||||
|
|
||||||
|
StreamThread(Stream stream)
|
||||||
|
: stop(false), stream(stream), thread(&StreamThread::thread_fn, this) {}
|
||||||
|
|
||||||
|
~StreamThread() {
|
||||||
|
{
|
||||||
|
std::unique_lock<std::mutex> lk(mtx);
|
||||||
|
stop = true;
|
||||||
|
}
|
||||||
|
cond.notify_one();
|
||||||
|
thread.join();
|
||||||
|
}
|
||||||
|
|
||||||
|
void thread_fn() {
|
||||||
|
metal::new_stream(stream);
|
||||||
|
while (true) {
|
||||||
|
std::function<void()> task;
|
||||||
|
{
|
||||||
|
std::unique_lock<std::mutex> lk(mtx);
|
||||||
|
cond.wait(lk, [this] { return !this->q.empty() || this->stop; });
|
||||||
|
if (q.empty() && stop) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
task = std::move(q.front());
|
||||||
|
q.pop();
|
||||||
|
}
|
||||||
|
task();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename F>
|
||||||
|
void enqueue(F&& f) {
|
||||||
|
{
|
||||||
|
std::unique_lock<std::mutex> lk(mtx);
|
||||||
|
if (stop) {
|
||||||
|
throw std::runtime_error(
|
||||||
|
"Cannot enqueue work after stream is stopped.");
|
||||||
|
}
|
||||||
|
q.emplace(std::forward<F>(f));
|
||||||
|
}
|
||||||
|
cond.notify_one();
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
class Scheduler {
|
||||||
|
public:
|
||||||
|
Scheduler() : n_active_tasks_(0) {
|
||||||
|
if (metal::is_available()) {
|
||||||
|
default_streams_.insert({Device::gpu, new_stream(Device::gpu)});
|
||||||
|
}
|
||||||
|
default_streams_.insert({Device::cpu, new_stream(Device::cpu)});
|
||||||
|
}
|
||||||
|
|
||||||
|
// Not copyable or moveable
|
||||||
|
Scheduler(const Scheduler&) = delete;
|
||||||
|
Scheduler(Scheduler&&) = delete;
|
||||||
|
Scheduler& operator=(const Scheduler&) = delete;
|
||||||
|
Scheduler& operator=(Scheduler&&) = delete;
|
||||||
|
|
||||||
|
Stream new_stream(const Device& d) {
|
||||||
|
auto stream = Stream(streams_.size(), d);
|
||||||
|
streams_.push_back(new StreamThread{stream});
|
||||||
|
return stream;
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename F>
|
||||||
|
void enqueue(const Stream& stream, F&& f);
|
||||||
|
|
||||||
|
Stream get_default_stream(const Device& d) {
|
||||||
|
return default_streams_.at(d.type);
|
||||||
|
}
|
||||||
|
|
||||||
|
void set_default_stream(const Stream& s) {
|
||||||
|
default_streams_.at(s.device.type) = s;
|
||||||
|
}
|
||||||
|
|
||||||
|
void notify_new_task(const Stream& stream) {
|
||||||
|
{
|
||||||
|
std::unique_lock<std::mutex> lk(mtx);
|
||||||
|
n_active_tasks_++;
|
||||||
|
}
|
||||||
|
completion_cv.notify_all();
|
||||||
|
}
|
||||||
|
|
||||||
|
void notify_task_completion(const Stream& stream) {
|
||||||
|
{
|
||||||
|
std::unique_lock<std::mutex> lk(mtx);
|
||||||
|
n_active_tasks_--;
|
||||||
|
}
|
||||||
|
completion_cv.notify_all();
|
||||||
|
}
|
||||||
|
|
||||||
|
int n_active_tasks() const {
|
||||||
|
return n_active_tasks_;
|
||||||
|
}
|
||||||
|
|
||||||
|
void wait_for_one() {
|
||||||
|
std::unique_lock<std::mutex> lk(mtx);
|
||||||
|
int n_tasks_old = n_active_tasks();
|
||||||
|
if (n_tasks_old > 1) {
|
||||||
|
completion_cv.wait(lk, [this, n_tasks_old] {
|
||||||
|
return this->n_active_tasks() != n_tasks_old;
|
||||||
|
});
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
~Scheduler() {
|
||||||
|
for (auto s : streams_) {
|
||||||
|
delete s;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
private:
|
||||||
|
int n_active_tasks_;
|
||||||
|
std::vector<StreamThread*> streams_;
|
||||||
|
std::unordered_map<Device::DeviceType, Stream> default_streams_;
|
||||||
|
std::condition_variable completion_cv;
|
||||||
|
std::mutex mtx;
|
||||||
|
};
|
||||||
|
|
||||||
|
template <typename F>
|
||||||
|
void Scheduler::enqueue(const Stream& stream, F&& f) {
|
||||||
|
streams_[stream.index]->enqueue(std::forward<F>(f));
|
||||||
|
}
|
||||||
|
|
||||||
|
Scheduler& scheduler();
|
||||||
|
|
||||||
|
template <typename F>
|
||||||
|
void enqueue(const Stream& stream, F&& f) {
|
||||||
|
scheduler().enqueue(stream, std::forward<F>(f));
|
||||||
|
}
|
||||||
|
|
||||||
|
inline int n_active_tasks() {
|
||||||
|
return scheduler().n_active_tasks();
|
||||||
|
}
|
||||||
|
|
||||||
|
inline void notify_new_task(const Stream& stream) {
|
||||||
|
scheduler().notify_new_task(stream);
|
||||||
|
}
|
||||||
|
|
||||||
|
inline void notify_task_completion(const Stream& stream) {
|
||||||
|
scheduler().notify_task_completion(stream);
|
||||||
|
}
|
||||||
|
|
||||||
|
inline void wait_for_one() {
|
||||||
|
scheduler().wait_for_one();
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace mlx::core::scheduler
|
||||||
30
mlx/stream.h
Normal file
30
mlx/stream.h
Normal file
@@ -0,0 +1,30 @@
|
|||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include "mlx/device.h"
|
||||||
|
|
||||||
|
namespace mlx::core {
|
||||||
|
|
||||||
|
struct Stream {
|
||||||
|
int index;
|
||||||
|
Device device;
|
||||||
|
explicit Stream(int index, Device device) : index(index), device(device) {}
|
||||||
|
};
|
||||||
|
|
||||||
|
/** Get the default stream for the given device. */
|
||||||
|
Stream default_stream(Device d);
|
||||||
|
|
||||||
|
/** Make the stream the default for its device. */
|
||||||
|
void set_default_stream(Stream s);
|
||||||
|
|
||||||
|
/** Make a new stream on the given device. */
|
||||||
|
Stream new_stream(Device d);
|
||||||
|
|
||||||
|
inline bool operator==(const Stream& lhs, const Stream& rhs) {
|
||||||
|
return lhs.index == rhs.index;
|
||||||
|
}
|
||||||
|
|
||||||
|
inline bool operator!=(const Stream& lhs, const Stream& rhs) {
|
||||||
|
return !(lhs == rhs);
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace mlx::core
|
||||||
54
mlx/types/half_types.h
Normal file
54
mlx/types/half_types.h
Normal file
@@ -0,0 +1,54 @@
|
|||||||
|
#pragma once
|
||||||
|
#ifdef __ARM_FEATURE_FP16_SCALAR_ARITHMETIC
|
||||||
|
|
||||||
|
#include <arm_fp16.h>
|
||||||
|
namespace mlx::core {
|
||||||
|
typedef __fp16 float16_t;
|
||||||
|
} // namespace mlx::core
|
||||||
|
|
||||||
|
#else
|
||||||
|
|
||||||
|
#define ADD_HALF_BINOPS
|
||||||
|
#include "mlx/types/fp16.h"
|
||||||
|
namespace mlx::core {
|
||||||
|
typedef struct _MLX_Float16 float16_t;
|
||||||
|
} // namespace mlx::core
|
||||||
|
|
||||||
|
#endif // __ARM_FEATURE_FP16_SCALAR_ARITHMETIC
|
||||||
|
#ifdef __ARM_FEATURE_BF16
|
||||||
|
|
||||||
|
#include <arm_bf16.h>
|
||||||
|
namespace mlx::core {
|
||||||
|
typedef __bf16 bfloat16_t;
|
||||||
|
} // namespace mlx::core
|
||||||
|
|
||||||
|
#else
|
||||||
|
|
||||||
|
#define ADD_HALF_BINOPS
|
||||||
|
#include "mlx/types/bf16.h"
|
||||||
|
namespace mlx::core {
|
||||||
|
typedef struct _MLX_BFloat16 bfloat16_t;
|
||||||
|
} // namespace mlx::core
|
||||||
|
|
||||||
|
#endif // __ARM_FEATURE_BF16
|
||||||
|
|
||||||
|
#ifdef ADD_HALF_BINOPS
|
||||||
|
namespace mlx::core {
|
||||||
|
|
||||||
|
// clang-format off
|
||||||
|
#define fp16_bf16_binop_helper(__op__, __operator__) \
|
||||||
|
inline float __operator__(float16_t lhs, bfloat16_t rhs) { \
|
||||||
|
return static_cast<float>(lhs) __op__ static_cast<float>(rhs); \
|
||||||
|
} \
|
||||||
|
inline float __operator__(bfloat16_t lhs, float16_t rhs) { \
|
||||||
|
return static_cast<float>(lhs) __op__ static_cast<float>(rhs); \
|
||||||
|
}
|
||||||
|
|
||||||
|
fp16_bf16_binop_helper(+, operator+)
|
||||||
|
fp16_bf16_binop_helper(-, operator-)
|
||||||
|
fp16_bf16_binop_helper(*, operator*)
|
||||||
|
fp16_bf16_binop_helper(/, operator/)
|
||||||
|
// clang-format on
|
||||||
|
|
||||||
|
} // namespace mlx::core
|
||||||
|
#endif
|
||||||
255
mlx/utils.cpp
Normal file
255
mlx/utils.cpp
Normal file
@@ -0,0 +1,255 @@
|
|||||||
|
#include <sstream>
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
|
#include "utils.h"
|
||||||
|
|
||||||
|
namespace mlx::core {
|
||||||
|
|
||||||
|
Dtype result_type(const std::vector<array>& arrays) {
|
||||||
|
std::vector<Dtype> dtypes(1, bool_);
|
||||||
|
for (auto& arr : arrays) {
|
||||||
|
dtypes.push_back(promote_types(dtypes.back(), arr.dtype()));
|
||||||
|
}
|
||||||
|
return dtypes.back();
|
||||||
|
}
|
||||||
|
|
||||||
|
std::vector<int> broadcast_shapes(
|
||||||
|
const std::vector<int>& s1,
|
||||||
|
const std::vector<int>& s2) {
|
||||||
|
// Use the same broadcasting rules as numpy
|
||||||
|
// https://numpy.org/doc/1.20/user/theory.broadcasting.html
|
||||||
|
// "The size of the trailing axes for both arrays in an operation must
|
||||||
|
// either be the same size or one of them must be one."
|
||||||
|
int ndim1 = s1.size();
|
||||||
|
int ndim2 = s2.size();
|
||||||
|
int ndim = std::max(ndim1, ndim2);
|
||||||
|
int diff = std::abs(ndim1 - ndim2);
|
||||||
|
const auto& big = ndim1 > ndim2 ? s1 : s2;
|
||||||
|
const auto& small = ndim1 > ndim2 ? s2 : s1;
|
||||||
|
std::vector<int> out_shape(ndim);
|
||||||
|
for (int i = ndim - 1; i >= diff; --i) {
|
||||||
|
int a = big[i];
|
||||||
|
int b = small[i - diff];
|
||||||
|
if (b == a) {
|
||||||
|
out_shape[i] = a;
|
||||||
|
} else if (a == 1 || b == 1) {
|
||||||
|
// 0 if a or b is 0 otherwise max(a, b)
|
||||||
|
out_shape[i] = a * b;
|
||||||
|
} else {
|
||||||
|
std::ostringstream msg;
|
||||||
|
msg << "Shapes " << s1 << " and " << s2 << " cannot be broadcast.";
|
||||||
|
throw std::invalid_argument(msg.str());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
for (int i = diff - 1; i >= 0; --i) {
|
||||||
|
out_shape[i] = big[i];
|
||||||
|
}
|
||||||
|
return out_shape;
|
||||||
|
}
|
||||||
|
|
||||||
|
std::ostream& operator<<(std::ostream& os, const Device& d) {
|
||||||
|
os << "Device(";
|
||||||
|
switch (d.type) {
|
||||||
|
case Device::cpu:
|
||||||
|
os << "cpu";
|
||||||
|
break;
|
||||||
|
case Device::gpu:
|
||||||
|
os << "gpu";
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
os << ", " << d.index << ")";
|
||||||
|
return os;
|
||||||
|
}
|
||||||
|
|
||||||
|
std::ostream& operator<<(std::ostream& os, const Stream& s) {
|
||||||
|
os << "Stream(";
|
||||||
|
os << s.device;
|
||||||
|
os << ", " << s.index << ")";
|
||||||
|
return os;
|
||||||
|
}
|
||||||
|
|
||||||
|
std::ostream& operator<<(std::ostream& os, int8_t x) {
|
||||||
|
os << static_cast<int>(x);
|
||||||
|
return os;
|
||||||
|
}
|
||||||
|
|
||||||
|
std::ostream& operator<<(std::ostream& os, uint8_t x) {
|
||||||
|
os << static_cast<uint>(x);
|
||||||
|
return os;
|
||||||
|
}
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
|
||||||
|
inline size_t elem_to_loc(
|
||||||
|
int elem,
|
||||||
|
const std::vector<int>& shape,
|
||||||
|
const std::vector<size_t>& strides) {
|
||||||
|
size_t loc = 0;
|
||||||
|
for (int i = shape.size() - 1; i >= 0; --i) {
|
||||||
|
auto q_and_r = ldiv(elem, shape[i]);
|
||||||
|
loc += q_and_r.rem * strides[i];
|
||||||
|
elem = q_and_r.quot;
|
||||||
|
}
|
||||||
|
return loc;
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
void print_subarray(std::ostream& os, const array& a, size_t index, int dim) {
|
||||||
|
int num_print = 3;
|
||||||
|
int n = a.shape(dim);
|
||||||
|
size_t s = a.strides()[dim];
|
||||||
|
bool is_last = dim == a.ndim() - 1;
|
||||||
|
auto prefix = is_last ? "" : std::string(7 + dim, ' ');
|
||||||
|
auto postfix = is_last ? ", " : ",\n";
|
||||||
|
os << "[";
|
||||||
|
for (int i = 0; i < n; ++i) {
|
||||||
|
os << (i == 0 ? "" : prefix);
|
||||||
|
if (i == num_print && n > 2 * num_print) {
|
||||||
|
os << "...";
|
||||||
|
i = n - num_print - 1;
|
||||||
|
index += s * (n - 2 * num_print - 1);
|
||||||
|
} else if (is_last) {
|
||||||
|
os << a.data<T>()[index];
|
||||||
|
} else {
|
||||||
|
print_subarray<T>(os, a, index, dim + 1);
|
||||||
|
}
|
||||||
|
os << (i == n - 1 ? "" : postfix);
|
||||||
|
index += s;
|
||||||
|
}
|
||||||
|
os << "]";
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
void print_array(std::ostream& os, const array& a) {
|
||||||
|
std::vector<int> indices(a.ndim(), 0);
|
||||||
|
os << std::boolalpha;
|
||||||
|
os << "array(";
|
||||||
|
if (a.ndim() == 0) {
|
||||||
|
auto data = a.data<T>();
|
||||||
|
os << data[0];
|
||||||
|
} else {
|
||||||
|
print_subarray<T>(os, a, 0, 0);
|
||||||
|
}
|
||||||
|
os << ", dtype=" << a.dtype() << ")";
|
||||||
|
os << std::noboolalpha;
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace
|
||||||
|
|
||||||
|
std::ostream& operator<<(std::ostream& os, const Dtype& dtype) {
|
||||||
|
switch (dtype) {
|
||||||
|
case bool_:
|
||||||
|
return os << "bool";
|
||||||
|
case uint8:
|
||||||
|
return os << "uint8";
|
||||||
|
case uint16:
|
||||||
|
return os << "uint16";
|
||||||
|
case uint32:
|
||||||
|
return os << "uint32";
|
||||||
|
case uint64:
|
||||||
|
return os << "uint64";
|
||||||
|
case int8:
|
||||||
|
return os << "int8";
|
||||||
|
case int16:
|
||||||
|
return os << "int16";
|
||||||
|
case int32:
|
||||||
|
return os << "int32";
|
||||||
|
case int64:
|
||||||
|
return os << "int64";
|
||||||
|
case float16:
|
||||||
|
return os << "float16";
|
||||||
|
case float32:
|
||||||
|
return os << "float32";
|
||||||
|
case bfloat16:
|
||||||
|
return os << "bfloat16";
|
||||||
|
case complex64:
|
||||||
|
return os << "complex64";
|
||||||
|
}
|
||||||
|
return os;
|
||||||
|
}
|
||||||
|
|
||||||
|
std::ostream& operator<<(std::ostream& os, const Dtype::Kind& k) {
|
||||||
|
switch (k) {
|
||||||
|
case Dtype::Kind::b:
|
||||||
|
return os << "b";
|
||||||
|
case Dtype::Kind::i:
|
||||||
|
return os << "i";
|
||||||
|
case Dtype::Kind::u:
|
||||||
|
return os << "u";
|
||||||
|
case Dtype::Kind::f:
|
||||||
|
return os << "f";
|
||||||
|
case Dtype::Kind::c:
|
||||||
|
return os << "c";
|
||||||
|
case Dtype::Kind::V:
|
||||||
|
return os << "V";
|
||||||
|
}
|
||||||
|
return os;
|
||||||
|
}
|
||||||
|
|
||||||
|
std::ostream& operator<<(std::ostream& os, array a) {
|
||||||
|
if (!a.is_evaled()) {
|
||||||
|
a.eval();
|
||||||
|
}
|
||||||
|
switch (a.dtype()) {
|
||||||
|
case bool_:
|
||||||
|
print_array<bool>(os, a);
|
||||||
|
break;
|
||||||
|
case uint8:
|
||||||
|
print_array<uint8_t>(os, a);
|
||||||
|
break;
|
||||||
|
case uint16:
|
||||||
|
print_array<uint16_t>(os, a);
|
||||||
|
break;
|
||||||
|
case uint32:
|
||||||
|
print_array<uint32_t>(os, a);
|
||||||
|
break;
|
||||||
|
case uint64:
|
||||||
|
print_array<uint64_t>(os, a);
|
||||||
|
break;
|
||||||
|
case int8:
|
||||||
|
print_array<int8_t>(os, a);
|
||||||
|
break;
|
||||||
|
case int16:
|
||||||
|
print_array<int16_t>(os, a);
|
||||||
|
break;
|
||||||
|
case int32:
|
||||||
|
print_array<int32_t>(os, a);
|
||||||
|
break;
|
||||||
|
case int64:
|
||||||
|
print_array<int64_t>(os, a);
|
||||||
|
break;
|
||||||
|
case float16:
|
||||||
|
print_array<float16_t>(os, a);
|
||||||
|
break;
|
||||||
|
case bfloat16:
|
||||||
|
print_array<bfloat16_t>(os, a);
|
||||||
|
break;
|
||||||
|
case float32:
|
||||||
|
print_array<float>(os, a);
|
||||||
|
break;
|
||||||
|
case complex64:
|
||||||
|
print_array<complex64_t>(os, a);
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
return os;
|
||||||
|
}
|
||||||
|
|
||||||
|
std::ostream& operator<<(std::ostream& os, const std::vector<int>& v) {
|
||||||
|
os << "(";
|
||||||
|
for (int i = 0; i < v.size(); ++i) {
|
||||||
|
os << v[i] << ((i == v.size() - 1) ? "" : ",");
|
||||||
|
}
|
||||||
|
os << ")";
|
||||||
|
return os;
|
||||||
|
}
|
||||||
|
|
||||||
|
std::ostream& operator<<(std::ostream& os, const std::vector<size_t>& v) {
|
||||||
|
os << "(";
|
||||||
|
for (int i = 0; i < v.size(); ++i) {
|
||||||
|
os << v[i] << ((i == v.size() - 1) ? "" : ",");
|
||||||
|
}
|
||||||
|
os << ")";
|
||||||
|
return os;
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace mlx::core
|
||||||
33
mlx/utils.h
Normal file
33
mlx/utils.h
Normal file
@@ -0,0 +1,33 @@
|
|||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include "array.h"
|
||||||
|
#include "device.h"
|
||||||
|
#include "dtype.h"
|
||||||
|
#include "stream.h"
|
||||||
|
|
||||||
|
namespace mlx::core {
|
||||||
|
|
||||||
|
/** The type from promoting the arrays' types with one another. */
|
||||||
|
Dtype result_type(const std::vector<array>& arrays);
|
||||||
|
|
||||||
|
std::vector<int> broadcast_shapes(
|
||||||
|
const std::vector<int>& s1,
|
||||||
|
const std::vector<int>& s2);
|
||||||
|
|
||||||
|
std::ostream& operator<<(std::ostream& os, const Device& d);
|
||||||
|
std::ostream& operator<<(std::ostream& os, const Stream& s);
|
||||||
|
std::ostream& operator<<(std::ostream& os, const Dtype& d);
|
||||||
|
std::ostream& operator<<(std::ostream& os, const Dtype::Kind& k);
|
||||||
|
std::ostream& operator<<(std::ostream& os, array a);
|
||||||
|
std::ostream& operator<<(std::ostream& os, const std::vector<int>& v);
|
||||||
|
std::ostream& operator<<(std::ostream& os, const std::vector<size_t>& v);
|
||||||
|
inline std::ostream& operator<<(std::ostream& os, const complex64_t& v) {
|
||||||
|
return os << v.real() << (v.imag() > 0 ? "+" : "") << v.imag() << "j";
|
||||||
|
}
|
||||||
|
inline std::ostream& operator<<(std::ostream& os, const float16_t& v) {
|
||||||
|
return os << static_cast<float>(v);
|
||||||
|
}
|
||||||
|
inline std::ostream& operator<<(std::ostream& os, const bfloat16_t& v) {
|
||||||
|
return os << static_cast<float>(v);
|
||||||
|
}
|
||||||
|
} // namespace mlx::core
|
||||||
37
python/README.md
Normal file
37
python/README.md
Normal file
@@ -0,0 +1,37 @@
|
|||||||
|
### Packaging for PyPI
|
||||||
|
|
||||||
|
Install `build` and `twine`:
|
||||||
|
|
||||||
|
```
|
||||||
|
pip install --user --upgrade build
|
||||||
|
pip install --user --upgrade twine
|
||||||
|
```
|
||||||
|
|
||||||
|
Generate the source distribution and wheel:
|
||||||
|
|
||||||
|
```
|
||||||
|
python -m build
|
||||||
|
```
|
||||||
|
|
||||||
|
*Warning* use a test server first
|
||||||
|
|
||||||
|
#### Test Upload
|
||||||
|
|
||||||
|
Upload to test server:
|
||||||
|
|
||||||
|
```
|
||||||
|
python -m twine upload --repository testpypi dist/*
|
||||||
|
```
|
||||||
|
|
||||||
|
Install from test server and check that it works:
|
||||||
|
|
||||||
|
```
|
||||||
|
python -m pip install --index-url https://test.pypi.org/simple/ --no-deps mlx
|
||||||
|
```
|
||||||
|
|
||||||
|
#### Upload
|
||||||
|
|
||||||
|
```
|
||||||
|
python -m twine upload dist/*
|
||||||
|
```
|
||||||
|
|
||||||
18
python/mlx/_reprlib_fix.py
Normal file
18
python/mlx/_reprlib_fix.py
Normal file
@@ -0,0 +1,18 @@
|
|||||||
|
import array
|
||||||
|
import reprlib
|
||||||
|
|
||||||
|
|
||||||
|
class FixedRepr(reprlib.Repr):
|
||||||
|
"""Only route python array instances to repr_array."""
|
||||||
|
|
||||||
|
def repr_array(self, x, maxlevel):
|
||||||
|
if isinstance(x, array.array):
|
||||||
|
return super().repr_array(x, maxlevel)
|
||||||
|
else:
|
||||||
|
return self.repr_instance(x, maxlevel)
|
||||||
|
|
||||||
|
|
||||||
|
# We need to monkey-patch reprlib so that we can use the debugger without
|
||||||
|
# renaming the array to something else
|
||||||
|
fixed_repr = FixedRepr()
|
||||||
|
reprlib.repr = fixed_repr.repr
|
||||||
94
python/mlx/extension.py
Normal file
94
python/mlx/extension.py
Normal file
@@ -0,0 +1,94 @@
|
|||||||
|
import os
|
||||||
|
import re
|
||||||
|
import subprocess
|
||||||
|
import sys
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
from setuptools import Extension, setup, find_namespace_packages
|
||||||
|
from setuptools.command.build_ext import build_ext
|
||||||
|
|
||||||
|
import mlx
|
||||||
|
|
||||||
|
_MLX_PATH = str(mlx.__path__[0])
|
||||||
|
|
||||||
|
|
||||||
|
# A CMakeExtension needs a sourcedir instead of a file list.
|
||||||
|
class CMakeExtension(Extension):
|
||||||
|
def __init__(self, name: str, sourcedir: str = "") -> None:
|
||||||
|
super().__init__(name, sources=[])
|
||||||
|
self.sourcedir = os.fspath(Path(sourcedir).resolve())
|
||||||
|
|
||||||
|
|
||||||
|
class CMakeBuild(build_ext):
|
||||||
|
def build_extension(self, ext: CMakeExtension) -> None:
|
||||||
|
# Must be in this form due to bug in .resolve() only fixed in Python 3.10+
|
||||||
|
ext_fullpath = Path.cwd() / self.get_ext_fullpath(ext.name) # type: ignore[no-untyped-call]
|
||||||
|
extdir = ext_fullpath.parent.resolve()
|
||||||
|
|
||||||
|
debug = int(os.environ.get("DEBUG", 0)) if self.debug is None else self.debug
|
||||||
|
cfg = "Debug" if debug else "Release"
|
||||||
|
|
||||||
|
# CMake lets you override the generator - we need to check this.
|
||||||
|
# Can be set with Conda-Build, for example.
|
||||||
|
cmake_generator = os.environ.get("CMAKE_GENERATOR", "")
|
||||||
|
|
||||||
|
# Set Python_EXECUTABLE instead if you use PYBIND11_FINDPYTHON
|
||||||
|
# EXAMPLE_VERSION_INFO shows you how to pass a value into the C++ code
|
||||||
|
# from Python.
|
||||||
|
cmake_args = [
|
||||||
|
f"-DCMAKE_LIBRARY_OUTPUT_DIRECTORY={extdir}{os.sep}",
|
||||||
|
f"-DCMAKE_BUILD_TYPE={cfg}",
|
||||||
|
"-DBUILD_SHARED_LIBS=ON",
|
||||||
|
]
|
||||||
|
build_args = []
|
||||||
|
# Adding CMake arguments set as environment variable
|
||||||
|
# (needed e.g. to build for ARM OSx on conda-forge)
|
||||||
|
if "CMAKE_ARGS" in os.environ:
|
||||||
|
cmake_args += [item for item in os.environ["CMAKE_ARGS"].split(" ") if item]
|
||||||
|
|
||||||
|
if sys.platform.startswith("darwin"):
|
||||||
|
# Cross-compile support for macOS - respect ARCHFLAGS if set
|
||||||
|
archs = re.findall(r"-arch (\S+)", os.environ.get("ARCHFLAGS", ""))
|
||||||
|
if archs:
|
||||||
|
cmake_args += ["-DCMAKE_OSX_ARCHITECTURES={}".format(";".join(archs))]
|
||||||
|
|
||||||
|
# Set CMAKE_BUILD_PARALLEL_LEVEL to control the parallel build level
|
||||||
|
# across all generators.
|
||||||
|
if "CMAKE_BUILD_PARALLEL_LEVEL" not in os.environ:
|
||||||
|
# self.parallel is a Python 3 only way to set parallel jobs by hand
|
||||||
|
# using -j in the build_ext call, not supported by pip or PyPA-build.
|
||||||
|
if hasattr(self, "parallel") and self.parallel:
|
||||||
|
# CMake 3.12+ only.
|
||||||
|
build_args += [f"-j{self.parallel}"]
|
||||||
|
|
||||||
|
build_temp = Path(self.build_temp) / ext.name
|
||||||
|
if not build_temp.exists():
|
||||||
|
build_temp.mkdir(parents=True)
|
||||||
|
|
||||||
|
# Make sure cmake can find MLX
|
||||||
|
os.environ["MLX_DIR"] = _MLX_PATH
|
||||||
|
|
||||||
|
subprocess.run(
|
||||||
|
["cmake", ext.sourcedir, *cmake_args], cwd=build_temp, check=True
|
||||||
|
)
|
||||||
|
subprocess.run(
|
||||||
|
["cmake", "--build", ".", *build_args], cwd=build_temp, check=True
|
||||||
|
)
|
||||||
|
|
||||||
|
def run(self):
|
||||||
|
super().run()
|
||||||
|
|
||||||
|
# Based on https://github.com/pypa/setuptools/blob/main/setuptools/command/build_ext.py#L102
|
||||||
|
if self.inplace:
|
||||||
|
for ext in self.extensions:
|
||||||
|
if isinstance(ext, CMakeExtension):
|
||||||
|
# Resolve inplace package dir
|
||||||
|
build_py = self.get_finalized_command("build_py")
|
||||||
|
inplace_file, regular_file = self._get_inplace_equivalent(
|
||||||
|
build_py, ext
|
||||||
|
)
|
||||||
|
|
||||||
|
inplace_dir = str(Path(inplace_file).parent.resolve())
|
||||||
|
regular_dir = str(Path(regular_file).parent.resolve())
|
||||||
|
|
||||||
|
self.copy_tree(regular_dir, inplace_dir)
|
||||||
401
python/mlx/nn/layers/base.py
Normal file
401
python/mlx/nn/layers/base.py
Normal file
@@ -0,0 +1,401 @@
|
|||||||
|
import textwrap
|
||||||
|
from typing import Any, Callable, List, Union, Optional
|
||||||
|
|
||||||
|
import mlx.core as mx
|
||||||
|
from mlx.utils import tree_flatten, tree_unflatten
|
||||||
|
|
||||||
|
|
||||||
|
class Module(dict):
|
||||||
|
"""Base class for building neural networks with MLX.
|
||||||
|
|
||||||
|
All the layers provided in :mod:`mlx.nn.layers` subclass this class and
|
||||||
|
your models should do the same.
|
||||||
|
|
||||||
|
A ``Module`` can contain other ``Module`` instances or :class:`mlx.core.array`
|
||||||
|
instances in arbitrary nesting of python lists or dicts. The ``Module``
|
||||||
|
then allows recursively extracting all the :class:`mlx.core.array` instances
|
||||||
|
using :meth:`mlx.nn.Module.parameters`.
|
||||||
|
|
||||||
|
In addition, the ``Module`` has the concept of trainable and non trainable
|
||||||
|
parameters (called "frozen"). When using :func:`mlx.nn.value_and_grad`
|
||||||
|
the gradients are returned only with respect to the trainable parameters.
|
||||||
|
All arrays in a module are trainable unless they are added in the "frozen"
|
||||||
|
set by calling :meth:`freeze`.
|
||||||
|
|
||||||
|
.. code-block:: python
|
||||||
|
|
||||||
|
import mlx.core as mx
|
||||||
|
import mlx.nn as nn
|
||||||
|
|
||||||
|
class MyMLP(nn.Module):
|
||||||
|
def __init__(self, in_dims: int, out_dims: int, hidden_dims: int = 16):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
self.in_proj = nn.Linear(in_dims, hidden_dims)
|
||||||
|
self.out_proj = nn.Linear(hidden_dims, out_dims)
|
||||||
|
|
||||||
|
def __call__(self, x):
|
||||||
|
x = self.in_proj(x)
|
||||||
|
x = mx.maximum(x, 0)
|
||||||
|
return self.out_proj(x)
|
||||||
|
|
||||||
|
model = MyMLP(2, 1)
|
||||||
|
|
||||||
|
# All the model parameters are created but since MLX is lazy by
|
||||||
|
# default, they are not evaluated yet. Calling `mx.eval` actually
|
||||||
|
# allocates memory and initializes the parameters.
|
||||||
|
mx.eval(model.parameters())
|
||||||
|
|
||||||
|
# Setting a parameter to a new value is as simply as accessing that
|
||||||
|
# parameter and assigning a new array to it.
|
||||||
|
model.in_proj.weight = model.in_proj.weight * 2
|
||||||
|
mx.eval(model.parameters())
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
"""Should be called by the subclasses of ``Module``."""
|
||||||
|
self._no_grad = set()
|
||||||
|
self._training = True
|
||||||
|
|
||||||
|
@property
|
||||||
|
def training(self):
|
||||||
|
return self._training
|
||||||
|
|
||||||
|
def _extra_repr(self):
|
||||||
|
return ""
|
||||||
|
|
||||||
|
def __repr__(self):
|
||||||
|
children = tree_flatten(self.children(), is_leaf=self.is_module)
|
||||||
|
value = f"{type(self).__name__}({self._extra_repr()}"
|
||||||
|
for k, v in children:
|
||||||
|
value += "\n"
|
||||||
|
value += textwrap.indent(f"({k}): {repr(v)}", prefix=" ")
|
||||||
|
if children:
|
||||||
|
value += "\n"
|
||||||
|
value += ")"
|
||||||
|
|
||||||
|
return value
|
||||||
|
|
||||||
|
def __getattr__(self, key: str):
|
||||||
|
if key in self:
|
||||||
|
return self[key]
|
||||||
|
else:
|
||||||
|
raise AttributeError(f"{type(self)!r} has no attribute {key!r}")
|
||||||
|
|
||||||
|
def __setattr__(self, key: str, val: Any):
|
||||||
|
self[key] = val
|
||||||
|
|
||||||
|
def load_weights(self, file: str):
|
||||||
|
"""
|
||||||
|
Load and update the model's weights from a `.npz` file.
|
||||||
|
"""
|
||||||
|
self.update(tree_unflatten(list(mx.load(file).items())))
|
||||||
|
|
||||||
|
def save_weights(self, file: str):
|
||||||
|
"""
|
||||||
|
Save the model's weights to a `.npz` file.
|
||||||
|
"""
|
||||||
|
mx.savez(file, **dict(tree_flatten(self.parameters())))
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def is_module(value):
|
||||||
|
return isinstance(value, Module)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def valid_child_filter(module, key, value):
|
||||||
|
return isinstance(value, (dict, list))
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def valid_parameter_filter(module, key, value):
|
||||||
|
return isinstance(value, (dict, list, mx.array)) and not key.startswith("_")
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def trainable_parameter_filter(module, key, value):
|
||||||
|
return (
|
||||||
|
Module.valid_parameter_filter(module, key, value)
|
||||||
|
and key not in module._no_grad
|
||||||
|
)
|
||||||
|
|
||||||
|
def filter_and_map(
|
||||||
|
self,
|
||||||
|
filter_fn: Callable[["mlx.nn.Module", str, Any], bool],
|
||||||
|
map_fn: Optional[Callable] = None,
|
||||||
|
is_leaf_fn: Optional[Callable[["mlx.nn.Module", str, Any], bool]] = None,
|
||||||
|
):
|
||||||
|
"""Recursively filter the contents of the module using ``filter_fn``,
|
||||||
|
namely only select keys and values where ``filter_fn`` returns true.
|
||||||
|
|
||||||
|
This is used to implement :meth:`parameters` and :meth:`trainable_parameters`
|
||||||
|
but it can also be used to extract any subset of the module's parameters.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
filter_fn (Callable): Given a value, the key in which it is found
|
||||||
|
and the containing module, decide whether to keep the value or
|
||||||
|
drop it.
|
||||||
|
map_fn (Callable, optional): Optionally transform the value before
|
||||||
|
returning it.
|
||||||
|
is_leaf_fn (Callable, optional): Given a value, the key in which it
|
||||||
|
is found and the containing module decide if it is a leaf.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A dictionary containing the contents of the module recursively filtered
|
||||||
|
"""
|
||||||
|
|
||||||
|
map_fn = map_fn or (lambda x: x)
|
||||||
|
is_leaf_fn = is_leaf_fn or (
|
||||||
|
lambda m, k, v: not isinstance(v, (Module, dict, list))
|
||||||
|
)
|
||||||
|
|
||||||
|
def unwrap(vk, v):
|
||||||
|
if is_leaf_fn(self, vk, v):
|
||||||
|
return map_fn(v)
|
||||||
|
|
||||||
|
if isinstance(v, Module):
|
||||||
|
return v.filter_and_map(filter_fn, map_fn, is_leaf_fn)
|
||||||
|
|
||||||
|
if isinstance(v, dict):
|
||||||
|
nd = {}
|
||||||
|
for k, v in v.items():
|
||||||
|
tk = f"{vk}.{k}"
|
||||||
|
nd[k] = unwrap(tk, v) if filter_fn(self, tk, v) else {}
|
||||||
|
return nd
|
||||||
|
|
||||||
|
if isinstance(v, list):
|
||||||
|
nl = []
|
||||||
|
for i, vi in enumerate(v):
|
||||||
|
tk = f"{vk}.{i}"
|
||||||
|
nl.append(unwrap(tk, vi) if filter_fn(self, tk, vi) else {})
|
||||||
|
return nl
|
||||||
|
|
||||||
|
raise RuntimeError("Unexpected leaf found while traversing the module")
|
||||||
|
|
||||||
|
return {k: unwrap(k, v) for k, v in self.items() if filter_fn(self, k, v)}
|
||||||
|
|
||||||
|
def parameters(self):
|
||||||
|
"""Recursively return all the :class:`mlx.core.array` members of this Module
|
||||||
|
as a dict of dicts and lists."""
|
||||||
|
return self.filter_and_map(self.valid_parameter_filter)
|
||||||
|
|
||||||
|
def trainable_parameters(self):
|
||||||
|
"""Recursively return all the non frozen :class:`mlx.core.array` members of
|
||||||
|
this Module as a dict of dicts and lists."""
|
||||||
|
return self.filter_and_map(self.trainable_parameter_filter)
|
||||||
|
|
||||||
|
def children(self):
|
||||||
|
"""Return the direct descendants of this Module instance."""
|
||||||
|
return self.filter_and_map(
|
||||||
|
self.valid_child_filter, is_leaf_fn=lambda m, k, v: isinstance(v, Module)
|
||||||
|
)
|
||||||
|
|
||||||
|
def leaf_modules(self):
|
||||||
|
"""Return the submodules that do not contain other modules."""
|
||||||
|
|
||||||
|
def _is_leaf_module(m, k, v):
|
||||||
|
return isinstance(v, Module) and len(tree_flatten(v.children())) == 0
|
||||||
|
|
||||||
|
return self.filter_and_map(self.valid_child_filter, is_leaf_fn=_is_leaf_module)
|
||||||
|
|
||||||
|
def update(self, parameters: dict):
|
||||||
|
"""Replace the parameters of this Module with the provided ones in the
|
||||||
|
dict of dicts and lists.
|
||||||
|
|
||||||
|
Commonly used by the optimizer to change the model to the updated
|
||||||
|
(optimized) parameters. Also used by the :meth:`mlx.nn.value_and_grad` to set the
|
||||||
|
tracers in the model in order to compute gradients.
|
||||||
|
|
||||||
|
The passed in parameters dictionary need not be a full dictionary
|
||||||
|
similar to :meth:`parameters`. Only the provided locations will be
|
||||||
|
updated.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
parameters (dict): A complete or partial dictionary of the modules
|
||||||
|
parameters.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def apply(dst, parameters):
|
||||||
|
if isinstance(parameters, dict):
|
||||||
|
for k in parameters:
|
||||||
|
if k in dst:
|
||||||
|
current_value = dst[k]
|
||||||
|
new_value = parameters[k]
|
||||||
|
if isinstance(current_value, mx.array):
|
||||||
|
dst[k] = new_value
|
||||||
|
elif isinstance(current_value, Module):
|
||||||
|
current_value.update(new_value)
|
||||||
|
elif isinstance(current_value, (dict, list)):
|
||||||
|
apply(current_value, new_value)
|
||||||
|
elif isinstance(parameters, list):
|
||||||
|
for i in range(len(dst)):
|
||||||
|
current_value = dst[i]
|
||||||
|
new_value = parameters[i]
|
||||||
|
if isinstance(current_value, mx.array):
|
||||||
|
dst[i] = new_value
|
||||||
|
elif isinstance(current_value, Module):
|
||||||
|
current_value.update(new_value)
|
||||||
|
elif isinstance(current_value, (dict, list)):
|
||||||
|
apply(current_value, new_value)
|
||||||
|
|
||||||
|
apply(self, parameters)
|
||||||
|
|
||||||
|
def apply(
|
||||||
|
self,
|
||||||
|
map_fn: Callable[[mx.array], mx.array],
|
||||||
|
filter_fn: Optional[Callable[["mlx.nn.Module", str, Any], bool]] = None,
|
||||||
|
):
|
||||||
|
"""Map all the parameters using the provided ``map_fn`` and immediately
|
||||||
|
update the module with the mapped parameters.
|
||||||
|
|
||||||
|
For instance running ``model.apply(lambda x: x.astype(mx.float16))``
|
||||||
|
casts all parameters to 16 bit floats.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
map_fn (Callable): Maps an array to another array
|
||||||
|
filter_fn (Callable, optional): Filter to select which arrays to
|
||||||
|
map (default: :meth:`Module.valid_parameter_filter`).
|
||||||
|
"""
|
||||||
|
filter_fn = filter_fn or Module.valid_parameter_filter
|
||||||
|
self.update(self.filter_and_map(filter_fn, map_fn))
|
||||||
|
|
||||||
|
def apply_to_modules(self, apply_fn: Callable[[str, "mlx.nn.Module"], Any]):
|
||||||
|
"""Apply a function to all the modules in this instance (including this
|
||||||
|
instance).
|
||||||
|
|
||||||
|
Args:
|
||||||
|
apply_fn (Callable): The function to apply to the modules.
|
||||||
|
"""
|
||||||
|
module_stack = [("", self)]
|
||||||
|
while module_stack:
|
||||||
|
prefix, mod = module_stack.pop()
|
||||||
|
apply_fn(prefix, mod)
|
||||||
|
prefix = "." + prefix if prefix else ""
|
||||||
|
module_stack.extend(
|
||||||
|
tree_flatten(mod.children(), prefix=prefix, is_leaf=self.is_module)
|
||||||
|
)
|
||||||
|
|
||||||
|
def modules(self):
|
||||||
|
"""Return a list with all the modules in this instance.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A list of :class:`mlx.nn.Module` instances.
|
||||||
|
"""
|
||||||
|
modulelist = []
|
||||||
|
self.apply_to_modules(lambda k, m: modulelist.append(m))
|
||||||
|
return modulelist
|
||||||
|
|
||||||
|
def named_modules(self):
|
||||||
|
"""Return a list with all the modules in this instance and their name
|
||||||
|
with dot notation.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A list of tuples (str, :class:`mlx.nn.Module`).
|
||||||
|
"""
|
||||||
|
modulelist = []
|
||||||
|
self.apply_to_modules(lambda k, m: modulelist.append((k, m)))
|
||||||
|
return modulelist
|
||||||
|
|
||||||
|
def _validate_keys(self, keys, strict):
|
||||||
|
keys = keys if isinstance(keys, list) else [keys]
|
||||||
|
if strict:
|
||||||
|
for k in keys:
|
||||||
|
if k not in self:
|
||||||
|
raise KeyError(f"Module doesn't contain member {k}.")
|
||||||
|
return keys
|
||||||
|
|
||||||
|
def freeze(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
recurse: bool = True,
|
||||||
|
keys: Optional[Union[str, List[str]]] = None,
|
||||||
|
strict: bool = False,
|
||||||
|
):
|
||||||
|
"""Freeze the Module's parameters or some of them. Freezing a parameter means not
|
||||||
|
computing gradients for it.
|
||||||
|
|
||||||
|
This function is idempotent ie freezing a frozen model is a noop.
|
||||||
|
|
||||||
|
For instance to only train the attention parameters from a transformer:
|
||||||
|
|
||||||
|
model = ...
|
||||||
|
model.freeze()
|
||||||
|
model.apply_to_modules(lambda k, v: v.unfreeze() if k.endswith("attention") else None)
|
||||||
|
|
||||||
|
Args:
|
||||||
|
recurse (bool, optional): If True then freeze the parameters of the
|
||||||
|
submodules as well (default: True).
|
||||||
|
keys (str or list[str], optional): If provided then only these
|
||||||
|
parameters will be frozen otherwise all the parameters of a
|
||||||
|
module. For instance freeze all biases by calling
|
||||||
|
``module.freeze(keys="bias")``.
|
||||||
|
strict (bool, optional): If set to True validate that the passed keys exist
|
||||||
|
(default: False).
|
||||||
|
"""
|
||||||
|
|
||||||
|
def _freeze_impl(_, m):
|
||||||
|
local_keys = keys
|
||||||
|
if local_keys is None:
|
||||||
|
local_keys = tree_flatten(
|
||||||
|
m.filter_and_map(
|
||||||
|
lambda m, k, v: (not isinstance(v, Module))
|
||||||
|
and m.valid_parameter_filter(m, k, v)
|
||||||
|
)
|
||||||
|
)
|
||||||
|
local_keys = [k for (k, v) in local_keys]
|
||||||
|
|
||||||
|
local_keys = m._validate_keys(local_keys, strict)
|
||||||
|
m._no_grad.update(local_keys)
|
||||||
|
|
||||||
|
if recurse:
|
||||||
|
self.apply_to_modules(_freeze_impl)
|
||||||
|
else:
|
||||||
|
_freeze_impl("", self)
|
||||||
|
|
||||||
|
def unfreeze(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
recurse: bool = True,
|
||||||
|
keys: Optional[Union[str, List[str]]] = None,
|
||||||
|
strict: bool = False,
|
||||||
|
):
|
||||||
|
"""Unfreeze the Module's parameters or some of them.
|
||||||
|
|
||||||
|
This function is idempotent ie unfreezing a model that is not frozen is
|
||||||
|
a noop.
|
||||||
|
|
||||||
|
For instance to only train the biases one can do:
|
||||||
|
|
||||||
|
model = ...
|
||||||
|
model.freeze()
|
||||||
|
model.unfreeze(keys="bias")
|
||||||
|
|
||||||
|
Args:
|
||||||
|
recurse (bool, optional): If True then unfreeze the parameters of the
|
||||||
|
submodules as well (default: True).
|
||||||
|
keys (str or list[str], optional): If provided then only these
|
||||||
|
parameters will be unfrozen otherwise all the parameters of a
|
||||||
|
module. For instance unfreeze all biases by calling
|
||||||
|
``module.unfreeze(keys="bias")``.
|
||||||
|
strict (bool, optional): If set to True validate that the passed keys exist
|
||||||
|
(default: False).
|
||||||
|
"""
|
||||||
|
|
||||||
|
def _unfreeze_impl(_, m):
|
||||||
|
if keys is None:
|
||||||
|
m._no_grad.clear()
|
||||||
|
|
||||||
|
else:
|
||||||
|
local_keys = m._validate_keys(keys, strict)
|
||||||
|
m._no_grad.difference_update(local_keys)
|
||||||
|
|
||||||
|
if recurse:
|
||||||
|
self.apply_to_modules(_unfreeze_impl)
|
||||||
|
else:
|
||||||
|
_unfreeze_impl("", self)
|
||||||
|
|
||||||
|
def train(self, mode: bool = True):
|
||||||
|
def _set_train(_, m):
|
||||||
|
m._training = mode
|
||||||
|
|
||||||
|
self.apply_to_modules(_set_train)
|
||||||
|
|
||||||
|
def eval(self):
|
||||||
|
self.train(False)
|
||||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user