diff --git a/t5/t5.py b/t5/t5.py index 4863ba43..01119a75 100644 --- a/t5/t5.py +++ b/t5/t5.py @@ -60,8 +60,8 @@ def _relative_position_bucket( class RelativePositionBias(nn.Module): - def __init__(self, config: ModelArgs, bidirectional: bool): - self.bidirectional = False # bidirectional + def __init__(self, config: T5Config, bidirectional: bool): + self.bidirectional = bidirectional self.num_buckets = config.relative_attention_num_buckets self.max_distance = config.relative_attention_max_distance self.n_heads = config.num_heads