mirror of
				https://github.com/ml-explore/mlx.git
				synced 2025-10-31 16:21:27 +08:00 
			
		
		
		
	Skip compile when transforming (#635)
* skip compile when transforming * simplify message
This commit is contained in:
		| @@ -587,3 +587,39 @@ TEST_CASE("test compile repeat input") { | ||||
|   auto expected_out = repeat_input_to_compiled({x})[0]; | ||||
|   CHECK(allclose(out, expected_out).item<bool>()); | ||||
| } | ||||
|  | ||||
| auto compile_unary_inner(const std::vector<array>& inputs) { | ||||
|   auto x = inputs[0]; | ||||
|   return std::vector<array>{exp(exp(x))}; | ||||
| } | ||||
|  | ||||
| auto compile_unary_outer(const std::vector<array>& inputs) { | ||||
|   auto cfun = compile(compile_unary_inner); | ||||
|   return cfun(cfun(inputs)); | ||||
| } | ||||
|  | ||||
| TEST_CASE("test compile compiled function") { | ||||
|   auto cfun = compile(compile_unary_outer); | ||||
|   auto x = array({1.0f}); | ||||
|   auto out = cfun({x})[0]; | ||||
|   auto& p = out.primitive(); | ||||
|   CHECK_EQ(typeid(p), typeid(Compiled)); | ||||
|   CHECK_EQ(out.inputs()[0].id(), x.id()); | ||||
| } | ||||
|  | ||||
| auto grad_unary_compiled(const std::vector<array>& inputs) { | ||||
|   auto gradfn = value_and_grad(compile(compile_unary_inner)); | ||||
|   auto [out, grad] = gradfn(inputs); | ||||
|   return std::vector{out[0], grad[0]}; | ||||
| } | ||||
|  | ||||
| TEST_CASE("test transform compiled function") { | ||||
|   auto cfun = compile(grad_unary_compiled); | ||||
|   auto x = array(1.0f); | ||||
|   auto outs = cfun({x}); | ||||
|   auto& p = outs[0].primitive(); | ||||
|   CHECK_EQ(typeid(p), typeid(Compiled)); | ||||
|   CHECK_EQ(outs[0].siblings()[0].id(), outs[1].id()); | ||||
|   CHECK(!outs[0].inputs()[0].has_primitive()); | ||||
|   CHECK(!outs[0].inputs()[1].has_primitive()); | ||||
| } | ||||
|   | ||||
		Reference in New Issue
	
	Block a user
	 Awni Hannun
					Awni Hannun