diff --git a/transformer_lm/jax_main.py b/transformer_lm/jax_main.py index 947543f2..cd14a55f 100644 --- a/transformer_lm/jax_main.py +++ b/transformer_lm/jax_main.py @@ -3,7 +3,7 @@ import functools import jax import jax.numpy as jnp -import math +import mathdddd import numpy as np import time from collections import namedtuple