mirror of
				https://github.com/ml-explore/mlx.git
				synced 2025-11-01 00:28:11 +08:00 
			
		
		
		
	Shapeless compilation for some graphs (#687)
* shapeless compilation for some graphs
* update compile benchmark
* default compile a few activations
* buffer donation
* bugfix
* shapeless fix
* update tests to work for cpu and gpu fusion
* test kwargs
* add kwargs to compile
* Recompile when python arguments change
* no compile for tanh
* some constant tests

---------

Co-authored-by: Angelos Katharopoulos <a_katharopoulos@apple.com>
This commit is contained in:
		| @@ -624,31 +624,23 @@ TEST_CASE("test transform compiled function") { | ||||
|   CHECK(!outs[0].inputs()[1].has_primitive()); | ||||
| } | ||||
|  | ||||
| TEST_CASE("test metal fusion kernel reuse") { | ||||
|   if (default_device() != Device::gpu) { | ||||
|     return; | ||||
|   } | ||||
|  | ||||
| TEST_CASE("test fusion kernel reuse") { | ||||
|   auto cfun = compile(gelu_1); | ||||
|   auto x = array({2.0f, -2.0f}); | ||||
|   auto y = cfun({x})[0]; | ||||
|   auto p = std::dynamic_pointer_cast<Compiled>(y.primitive_ptr()); | ||||
|   eval(y); | ||||
|  | ||||
|   std::string lib_name = p->metal_lib_name(); | ||||
|   std::string lib_source = p->metal_lib_source(); | ||||
|   std::string lib_name = p->lib_name(); | ||||
|   CHECK(!lib_name.empty()); | ||||
|   CHECK(!lib_source.empty()); | ||||
|  | ||||
|   x = astype(reshape(arange(10), {2, 5}), float32); | ||||
|   auto z = cfun({x})[0]; | ||||
|   auto pz = std::dynamic_pointer_cast<Compiled>(z.primitive_ptr()); | ||||
|   eval(z); | ||||
|  | ||||
|   std::string lib_name_z = pz->metal_lib_name(); | ||||
|   std::string lib_source_z = pz->metal_lib_source(); | ||||
|   std::string lib_name_z = pz->lib_name(); | ||||
|   CHECK(!lib_name_z.empty()); | ||||
|   CHECK(lib_source_z.empty()); | ||||
|  | ||||
|   CHECK_EQ(lib_name, lib_name_z); | ||||
| } | ||||
| @@ -657,29 +649,57 @@ auto add3(const std::vector<array>& xs) { | ||||
|   return std::vector<array>{xs[0] + xs[0] + xs[0]}; | ||||
| } | ||||
|  | ||||
| TEST_CASE("test metal fusion types") { | ||||
|   if (default_device() != Device::gpu) { | ||||
|     return; | ||||
|   } | ||||
|  | ||||
| TEST_CASE("test fusion types") { | ||||
|   auto cfun = compile(add3); | ||||
|   auto x = array({2.0f, -2.0f}); | ||||
|   auto y = cfun({x})[0]; | ||||
|   auto p = std::dynamic_pointer_cast<Compiled>(y.primitive_ptr()); | ||||
|   eval(y); | ||||
|  | ||||
|   std::string lib_name = p->metal_lib_name(); | ||||
|   std::string lib_source = p->metal_lib_source(); | ||||
|   std::string lib_name = p->lib_name(); | ||||
|   CHECK(!lib_name.empty()); | ||||
|   CHECK(!lib_source.empty()); | ||||
|  | ||||
|   x = array({2, -2}, int32); | ||||
|   auto z = cfun({x})[0]; | ||||
|   auto pz = std::dynamic_pointer_cast<Compiled>(z.primitive_ptr()); | ||||
|   eval(z); | ||||
|  | ||||
|   std::string lib_name_z = pz->metal_lib_name(); | ||||
|   std::string lib_source_z = pz->metal_lib_source(); | ||||
|   std::string lib_name_z = pz->lib_name(); | ||||
|   CHECK(!lib_name_z.empty()); | ||||
|   CHECK(!lib_source_z.empty()); | ||||
| } | ||||
|  | ||||
| // Test helper: reshapes the first input to the fixed shape {2, 2} and | ||||
| // returns it as a single-element vector. The shapeless-compile test | ||||
| // expects calling the shapeless-compiled form of this function to throw. | ||||
| auto compile_shapeless_not_ok(const std::vector<array>& inputs) { | ||||
|   auto x = reshape(inputs[0], {2, 2}); | ||||
|   return std::vector<array>{x}; | ||||
| } | ||||
|  | ||||
| // Test helper: adds the constant array({2}) to the first input and | ||||
| // returns the sum as a single-element vector. The shapeless-compile test | ||||
| // uses it as a function that can be compiled with shapeless = true and | ||||
| // then called with inputs of different lengths. | ||||
| auto compile_shapeless_ok(const std::vector<array>& inputs) { | ||||
|   auto x = inputs[0] + array({2}); | ||||
|   return std::vector<array>{x}; | ||||
| } | ||||
|  | ||||
| // Exercises compile() with the shapeless flag set: one function that is | ||||
| // expected to be rejected, and one that is expected to be reused across | ||||
| // input lengths but rebuilt when the input type or rank changes. | ||||
| TEST_CASE("test shapeless compile") { | ||||
|   { | ||||
|     // The helper reshapes to a hard-coded {2, 2}; the call must throw. | ||||
|     auto cfun = compile(compile_shapeless_not_ok, /* shapeless */ true); | ||||
|     CHECK_THROWS(cfun({array({1, 2, 3, 4})})); | ||||
|   } | ||||
|  | ||||
|   { | ||||
|     auto cfun = compile(compile_shapeless_ok, /* shapeless */ true); | ||||
|     // Same function called with a 2-element and then a 4-element input. | ||||
|     auto out = cfun({array({1, 2})})[0]; | ||||
|     auto out2 = cfun({array({1, 2, 3, 4})})[0]; | ||||
|  | ||||
|     // Not making a new constant array since no recompile, | ||||
|     // hence the ids should be the same. inputs()[1] is the second operand | ||||
|     // of the add, i.e. the array({2}) constant created inside the helper. | ||||
|     CHECK_EQ(out.inputs()[1].id(), out2.inputs()[1].id()); | ||||
|     // The 4-element call must still compute input + 2 element-wise. | ||||
|     CHECK(array_equal(out2, array({3, 4, 5, 6})).item<bool>()); | ||||
|  | ||||
|     // Recompile since the input type changes (integer literals before, | ||||
|     // floating-point literals here): a new constant array, so a new id. | ||||
|     out2 = cfun({array({1.0, 2.0})})[0]; | ||||
|     CHECK_NE(out.inputs()[1].id(), out2.inputs()[1].id()); | ||||
|  | ||||
|     // Recompile since ndim changes (the second argument {1, 2} is | ||||
|     // presumably the shape, making this input 2-D -- confirm against | ||||
|     // the array constructor). | ||||
|     out2 = cfun({array({1.0, 2.0}, {1, 2})})[0]; | ||||
|     CHECK_NE(out.inputs()[1].id(), out2.inputs()[1].id()); | ||||
|   } | ||||
| } | ||||
|   | ||||
		Reference in New Issue
	
	Block a user
	 Awni Hannun
					Awni Hannun