mirror of
https://github.com/ml-explore/mlx.git
synced 2025-07-25 20:11:19 +08:00
* Fixing random.normal for half-precision dtype #642 * Update python/tests/test_random.py Co-authored-by: Awni Hannun <awni.hannun@gmail.com> --------- Co-authored-by: Awni Hannun <awni.hannun@gmail.com>
This commit is contained in:
parent
28fcd2b519
commit
a3ee03da01
@ -90,6 +90,16 @@ T below_one() {
|
|||||||
return f;
|
return f;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Get the next representable value above -1.0 for half precision
|
||||||
|
// floating point types (fp16, bf16)
|
||||||
|
template <typename T>
|
||||||
|
T above_minus_one() {
|
||||||
|
T f = T(-1.0);
|
||||||
|
uint16_t* m = (uint16_t*)&f;
|
||||||
|
*m -= 1;
|
||||||
|
return f;
|
||||||
|
}
|
||||||
|
|
||||||
array uniform(
|
array uniform(
|
||||||
const array& low,
|
const array& low,
|
||||||
const array& high,
|
const array& high,
|
||||||
@ -158,7 +168,17 @@ array normal(
|
|||||||
const std::optional<array>& key /*= nullopt */,
|
const std::optional<array>& key /*= nullopt */,
|
||||||
StreamOrDevice s /* = {} */) {
|
StreamOrDevice s /* = {} */) {
|
||||||
auto stream = to_stream(s);
|
auto stream = to_stream(s);
|
||||||
auto low = array(std::nextafter(-1.0f, 0.0f), dtype);
|
auto get_low = [&dtype]() {
|
||||||
|
switch (dtype) {
|
||||||
|
case float16:
|
||||||
|
return array(above_minus_one<float16_t>(), dtype);
|
||||||
|
case bfloat16:
|
||||||
|
return array(above_minus_one<bfloat16_t>(), dtype);
|
||||||
|
default:
|
||||||
|
return array(std::nextafter(-1.0f, 0.0f), dtype);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
auto low = get_low();
|
||||||
auto high = array(1.0f, dtype);
|
auto high = array(1.0f, dtype);
|
||||||
auto samples = uniform(low, high, shape, dtype, key, stream);
|
auto samples = uniform(low, high, shape, dtype, key, stream);
|
||||||
samples =
|
samples =
|
||||||
|
@ -96,6 +96,11 @@ class TestRandom(mlx_tests.MLXTestCase):
|
|||||||
|
|
||||||
self.assertEqual(mx.random.normal().dtype, mx.random.normal(dtype=None).dtype)
|
self.assertEqual(mx.random.normal().dtype, mx.random.normal(dtype=None).dtype)
|
||||||
|
|
||||||
|
# Test not getting -inf or inf with half precison
|
||||||
|
for hp in [mx.float16, mx.bfloat16]:
|
||||||
|
a = abs(mx.random.normal(shape=(10000,), loc=0, scale=1, dtype=hp))
|
||||||
|
self.assertTrue(mx.all(a < mx.inf))
|
||||||
|
|
||||||
def test_randint(self):
|
def test_randint(self):
|
||||||
a = mx.random.randint(0, 1, [])
|
a = mx.random.randint(0, 1, [])
|
||||||
self.assertEqual(a.shape, ())
|
self.assertEqual(a.shape, ())
|
||||||
|
Loading…
Reference in New Issue
Block a user