Add optional headers to `mx.fast.metal_kernel` (#1358)

This commit is contained in:
Alex Barron 2024-08-26 21:45:45 -07:00 committed by GitHub
parent 5f7d19d1f5
commit 1d94ac3f90
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 50 additions and 10 deletions

View File

@@ -1112,7 +1112,6 @@ std::map<std::string, array> MetalKernel::operator()(
"[metal_kernel] MetalKernel only works on GPU."); "[metal_kernel] MetalKernel only works on GPU.");
} }
std::ostringstream kernel_source;
std::ostringstream func_name; std::ostringstream func_name;
std::string template_def = ""; std::string template_def = "";
@@ -1128,6 +1127,9 @@ std::map<std::string, array> MetalKernel::operator()(
func_name << "custom_kernel_" << name_ << hash_key; func_name << "custom_kernel_" << name_ << hash_key;
std::string kernel_name = func_name.str(); std::string kernel_name = func_name.str();
std::ostringstream kernel_source;
kernel_source << header_ << std::endl;
std::vector<CustomKernelShapeInfo> shape_infos; std::vector<CustomKernelShapeInfo> shape_infos;
write_signature( write_signature(
func_name.str(), func_name.str(),

View File

@@ -71,10 +71,12 @@ class MetalKernel {
MetalKernel( MetalKernel(
const std::string& name, const std::string& name,
const std::string& source, const std::string& source,
bool ensure_row_contiguous, const std::string& header = "",
bool atomic_outputs) bool ensure_row_contiguous = true,
bool atomic_outputs = false)
: name_(name), : name_(name),
source_(source), source_(source),
header_(header),
ensure_row_contiguous_(ensure_row_contiguous), ensure_row_contiguous_(ensure_row_contiguous),
atomic_outputs_(atomic_outputs) {} atomic_outputs_(atomic_outputs) {}
@@ -93,7 +95,8 @@ class MetalKernel {
private: private:
std::string name_; std::string name_;
std::string source_; std::string source_;
bool ensure_row_contiguous_ = true; std::string header_;
bool atomic_outputs_ = false; bool ensure_row_contiguous_;
bool atomic_outputs_;
}; };
} // namespace mlx::core::fast } // namespace mlx::core::fast

View File

@@ -200,9 +200,15 @@ void init_fast(nb::module_& parent_module) {
A jit-compiled custom Metal kernel defined from a source string. A jit-compiled custom Metal kernel defined from a source string.
)pbdoc") )pbdoc")
.def( .def(
nb::init<const std::string&, const std::string&, bool, bool>(), nb::init<
const std::string&,
const std::string&,
const std::string&,
bool,
bool>(),
"name"_a, "name"_a,
"source"_a, "source"_a,
"header"_a = "",
"ensure_row_contiguous"_a = true, "ensure_row_contiguous"_a = true,
"atomic_outputs"_a = false, "atomic_outputs"_a = false,
R"pbdoc( R"pbdoc(
@@ -214,6 +220,8 @@ void init_fast(nb::module_& parent_module) {
the function signature will be generated for you. The names of the inputs/outputs the function signature will be generated for you. The names of the inputs/outputs
are determined by the ``inputs`` and ``output_shapes``/``output_dtypes`` are determined by the ``inputs`` and ``output_shapes``/``output_dtypes``
used when the kernel is called. used when the kernel is called.
header (str): Header source code to include before the main function.
Useful for helper functions or includes that should live outside of the main function body.
ensure_row_contiguous (bool): Whether to ensure the inputs are row contiguous ensure_row_contiguous (bool): Whether to ensure the inputs are row contiguous
before the kernel runs. Default: ``True``. before the kernel runs. Default: ``True``.
atomic_outputs (bool): Whether to use atomic outputs in the function signature atomic_outputs (bool): Whether to use atomic outputs in the function signature

View File

@@ -551,7 +551,7 @@ class TestFast(mlx_tests.MLXTestCase):
@unittest.skipIf(not mx.metal.is_available(), "Metal is not available") @unittest.skipIf(not mx.metal.is_available(), "Metal is not available")
def test_custom_kernel_basic(self): def test_custom_kernel_basic(self):
mx.random.seed(7) mx.random.seed(7)
a = mx.random.normal(shape=(3, 6)) a = mx.random.normal(shape=(2, 2))
kernel = mx.fast.metal_kernel( kernel = mx.fast.metal_kernel(
name="basic", name="basic",
source=""" source="""
@@ -567,7 +567,7 @@ class TestFast(mlx_tests.MLXTestCase):
output_dtypes={"out1": mx.float32}, output_dtypes={"out1": mx.float32},
stream=mx.gpu, stream=mx.gpu,
) )
mx.allclose(out["out1"], a[:2, :2]) self.assertTrue(mx.allclose(out["out1"], a))
@unittest.skipIf(not mx.metal.is_available(), "Metal is not available") @unittest.skipIf(not mx.metal.is_available(), "Metal is not available")
def test_custom_kernel_args(self): def test_custom_kernel_args(self):
@@ -618,12 +618,12 @@ class TestFast(mlx_tests.MLXTestCase):
uint elem = thread_position_in_grid.x; uint elem = thread_position_in_grid.x;
uint loc = elem_to_loc(elem, inp_shape, inp_strides, inp_ndim); uint loc = elem_to_loc(elem, inp_shape, inp_strides, inp_ndim);
T tmp = inp[loc]; T tmp = inp[loc];
out[elem] = metal::exp(tmp) * threads_per_simdgroup; out[elem] = metal::precise::exp(tmp) * threads_per_simdgroup;
""" """
source_contig = """ source_contig = """
uint elem = thread_position_in_grid.x; uint elem = thread_position_in_grid.x;
T tmp = inp[elem]; T tmp = inp[elem];
out[elem] = metal::exp(tmp) * threads_per_simdgroup; out[elem] = metal::precise::exp(tmp) * threads_per_simdgroup;
""" """
# non contiguous # non contiguous
@@ -646,6 +646,33 @@ class TestFast(mlx_tests.MLXTestCase):
) )
self.assertTrue(mx.allclose(mx.exp(a) * 32, outputs["out"])) self.assertTrue(mx.allclose(mx.exp(a) * 32, outputs["out"]))
@unittest.skipIf(not mx.metal.is_available(), "Metal is not available")
def test_custom_kernel_helper(self):
mx.random.seed(7)
a = mx.random.normal(shape=(2, 2))
kernel = mx.fast.metal_kernel(
name="helper",
header="""
template <typename T>
T do_exp(T x) {
return metal::precise::exp(x);
}
""",
source="""
uint elem = thread_position_in_grid.x;
out1[elem] = do_exp(a[elem]);
""",
)
out = kernel(
inputs={"a": a},
grid=(4, 1, 1),
threadgroup=(2, 1, 1),
output_shapes={"out1": (2, 2)},
output_dtypes={"out1": mx.float32},
stream=mx.gpu,
)
self.assertTrue(mx.allclose(out["out1"], mx.exp(a)))
if __name__ == "__main__": if __name__ == "__main__":
unittest.main() unittest.main()