# Copyright © 2023 Apple Inc.

"""
Code to preprocess the WikiSQL dataset adapted from
https://github.com/salesforce/WikiSQL and
https://huggingface.co/sqllama/sqllama-V0/blob/main/wikisql.ipynb .
"""

import json
import os


def load():
    """
    Load all three splits of the WikiSQL dataset.
    """
    return (WikiSQL(dn) for dn in ["train", "dev", "test"])

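# Example usage (illustrative; mirrors the __main__ block below):
#
#     train, dev, test = load()
#     print(len(train), train[0])
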


class WikiSQL:
    def __init__(self, dataset, save_dir="/tmp"):
        valid_sets = ("train", "dev", "test")
        if dataset not in valid_sets:
            raise ValueError(f"Dataset must be in {valid_sets}, got {dataset}")
        data_dir = os.path.join(save_dir, "wikisql")
        self._maybe_download(data_dir)

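        # The table schemas are parsed first; each query references its table by id.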
        self._parse_tables(os.path.join(data_dir, f"data/{dataset}.tables.jsonl"))
        self._parse_queries(os.path.join(data_dir, f"data/{dataset}.jsonl"))

    def _maybe_download(self, data_dir):
        if not os.path.exists(data_dir):
            import io
            from urllib import request
            import tarfile

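            # Download the archive into memory and extract it to data_dir.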
            url = "https://raw.githubusercontent.com/salesforce/WikiSQL/master/data.tar.bz2"
            r = request.urlopen(url)
            with tarfile.open(fileobj=io.BytesIO(r.read())) as tf:
                tf.extractall(data_dir)

    def _parse_tables(self, tables):
        self._tables = {}
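        # Each line describes one table: keep its column names, types, and a
        # short text description used when formatting prompts.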
        with open(tables) as f:
            for line in f:
                table = json.loads(line)
                self._tables[table["id"]] = {
                    "columns": table["header"],
                    "types": table["types"],
                    "desc": f"table: {table['id']}\ncolumns: {', '.join(table['header'])}",
                }

    def _parse_queries(self, queries):
        self._queries = []
        with open(queries) as f:
            for line in f:
                query = json.loads(line)
                table = self._tables[query["table_id"]]
                question = query["question"]
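                # Render the structured query as SQL and store the full prompt:
                # table description, question, and SQL answer.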
                answer = self.query_to_text(
                    query["sql"], query["table_id"], table["columns"], table["types"]
                )
                self._queries.append(
                    f"<s>{table['desc']}\nQ: {question}\nA: {answer}</s>"
                )

    def query_to_text(self, query, table, columns, types):
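        # WikiSQL stores queries in structured form: "sel" and "agg" index the
        # selected column and aggregation op, and "conds" lists
        # (column index, operator index, value) triples.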
        aggregation_ops = ["", "MAX", "MIN", "COUNT", "SUM", "AVG"]
        condition_ops = ["=", ">", "<", "OP"]
        column = columns[query["sel"]]
        aggregation = (aggregation_ops[query["agg"]] + " ") if query["agg"] > 0 else ""
        sql = f"SELECT {aggregation}{column} FROM {table}"

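        # Append a WHERE clause when the query has conditions, quoting text values.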
        conditions = query["conds"]
        if conditions:
            cs = []
            for i, o, v in conditions:
                column = columns[i]
                op = condition_ops[o]

                if types[i] == "text":
                    value = f"'{v}'"
                else:
                    value = v
                cs.append(f"{column} {op} {value}")

            sql += " WHERE " + " AND ".join(cs)

        return sql

    def __getitem__(self, idx):
        return self._queries[idx]

    def __len__(self):
        return len(self._queries)


if __name__ == "__main__":
    datanames = ["train", "dev", "test"]
    sizes = [56355, 8421, 15878]
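    # Check that each split has the expected number of examples.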
    for dataname, size in zip(datanames, sizes):
        assert len(WikiSQL(dataname)) == size, f"Wrong {dataname} set size."

    # Write the sets to jsonl
    train, dev, test = load()
    datasets = [
        (train, "train", 1000),
        (dev, "valid", 100),
        (test, "test", 100),
    ]
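    # Write only the first `size` examples from each split.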
    for dataset, name, size in datasets:
        with open(f"data/{name}.jsonl", "w") as fid:
            for e, t in zip(range(size), dataset):
                # Strip the <s>, </s> since the tokenizer adds them
                json.dump({"text": t[3:-4]}, fid)
                fid.write("\n")