mirror of
				https://github.com/ml-explore/mlx.git
				synced 2025-11-01 00:28:11 +08:00 
			
		
		
		
	Implement sampling from laplace distribution. (#1279)
This commit is contained in:
		| @@ -640,3 +640,74 @@ TEST_CASE("test categorical") { | ||||
|   CHECK_EQ(categorical(logits, -2, 7).shape(), std::vector<int>{5, 3, 7}); | ||||
|   CHECK_EQ(categorical(logits, -3, 7).shape(), std::vector<int>{4, 3, 7}); | ||||
| } | ||||
|  | ||||
| TEST_CASE("test laplace") { | ||||
|   // Test shapes, types, and sizes | ||||
|   { | ||||
|     auto x = random::laplace({}); | ||||
|     CHECK_EQ(x.size(), 1); | ||||
|     CHECK_EQ(x.dtype(), float32); | ||||
|  | ||||
|     // Non float type throws | ||||
|     CHECK_THROWS_AS(random::laplace({}, int32), std::invalid_argument); | ||||
|  | ||||
|     // Check wrong key type or shape | ||||
|     auto key = array({0, 0}); | ||||
|     CHECK_THROWS_AS(random::laplace({}, key), std::invalid_argument); | ||||
|     key = array({0, 0}, {1, 2}); | ||||
|     CHECK_THROWS_AS(random::laplace({}, key), std::invalid_argument); | ||||
|     key = array({0u, 0u, 0u}, {3, 1}); | ||||
|     CHECK_THROWS_AS(random::laplace({}, key), std::invalid_argument); | ||||
|     key = array({0u, 0u}, {2, 1}); | ||||
|     CHECK_THROWS_AS(random::laplace({}, key), std::invalid_argument); | ||||
|   } | ||||
|  | ||||
|   { | ||||
|     constexpr float inf = std::numeric_limits<float>::infinity(); | ||||
|     auto key = random::key(128291); | ||||
|     auto out = random::laplace({1000000}, key); | ||||
|     float sample_mean = mean(out).item<float>(); | ||||
|     float sample_variance = var(out).item<float>(); | ||||
|  | ||||
|     CHECK(all(less(abs(out), array(inf))).item<bool>()); | ||||
|     CHECK(abs(sample_mean) < 0.1); | ||||
|  | ||||
|     // Chebyshev's inequality. | ||||
|     for (int k = 1; k <= 5; ++k) { | ||||
|       float prob_above = | ||||
|           mean(greater_equal(out, array(k * std::sqrt(sample_variance)))) | ||||
|               .item<float>(); | ||||
|       float bound = 1 / std::pow(k, 2); | ||||
|       CHECK(prob_above < bound); | ||||
|     } | ||||
|  | ||||
|     // Expected variance for Laplace distribution is 2*scale^2. | ||||
|     float expected_variance = 2.0; | ||||
|     CHECK(std::abs(sample_variance - expected_variance) < 0.01); | ||||
|  | ||||
|     // Expected kurtosis of Laplace distribution is 3. | ||||
|     array fourth_pows = power(out - sample_mean, {4}); | ||||
|     float sample_kurtosis = | ||||
|         mean(fourth_pows).item<float>() / std::pow(sample_variance, 2) - 3; | ||||
|     float expected_kurtosis = 3.0; | ||||
|     CHECK(std::abs(sample_kurtosis - expected_kurtosis) < 0.1); | ||||
|   } | ||||
|  | ||||
|   { | ||||
|     constexpr float inf = std::numeric_limits<float>::infinity(); | ||||
|     auto key = random::key(128291); | ||||
|     auto out = random::laplace({10000}, float16, key); | ||||
|     CHECK_EQ(out.dtype(), float16); | ||||
|     CHECK(all(less(abs(out), array(inf))).item<bool>()); | ||||
|     CHECK(abs(float(mean(out).item<float16_t>())) < 0.1); | ||||
|   } | ||||
|  | ||||
|   { | ||||
|     constexpr float inf = std::numeric_limits<float>::infinity(); | ||||
|     auto key = random::key(128291); | ||||
|     auto out = random::laplace({10000}, bfloat16, key); | ||||
|     CHECK_EQ(out.dtype(), bfloat16); | ||||
|     CHECK(all(less(abs(out), array(inf))).item<bool>()); | ||||
|     CHECK(abs(float(mean(out).item<bfloat16_t>())) < 0.1); | ||||
|   } | ||||
| } | ||||
|   | ||||
		Reference in New Issue
	
	Block a user
	 fgranqvist
					fgranqvist