mirror of
				https://github.com/ml-explore/mlx.git
				synced 2025-10-29 23:15:09 +08:00 
			
		
		
		
	Correct types for vjp + tests (#418)
* correct types for vjp + tests * fix build + comment
This commit is contained in:
		| @@ -1202,3 +1202,29 @@ TEST_CASE("test update state") { | ||||
|   CHECK(state.is_evaled()); | ||||
|   CHECK(array_equal(state, array({1.0, 1.0})).item<bool>()); | ||||
| } | ||||
|  | ||||
| TEST_CASE("test grad types") { | ||||
|   { | ||||
|     auto fn = [](array x) { return sum(x); }; | ||||
|  | ||||
|     for (auto t : {float16, bfloat16, float32}) { | ||||
|       auto x = array(1.0, t); | ||||
|       auto dfdx = grad(fn)(x); | ||||
|       CHECK_EQ(dfdx.dtype(), t); | ||||
|     } | ||||
|   } | ||||
|  | ||||
|   { | ||||
|     // Check for multi-input grad | ||||
|     auto fn = [](std::vector<array> inputs) { | ||||
|       return sum(inputs[0] + inputs[1]); | ||||
|     }; | ||||
|  | ||||
|     for (auto t : {float16, bfloat16, float32}) { | ||||
|       auto x = array(1.0, t); | ||||
|       auto y = array(1.0, t); | ||||
|       auto out = grad(fn)({x, y}); | ||||
|       CHECK_EQ(out[0].dtype(), t); | ||||
|     } | ||||
|   } | ||||
| } | ||||
|   | ||||
		Reference in New Issue
	
	Block a user
	 Awni Hannun
					Awni Hannun