"""
|
||
|
Code to preprocess the WikiSQL dataset adapted from
|
||
|
https://github.com/salesforce/WikiSQL and
|
||
|
https://huggingface.co/sqllama/sqllama-V0/blob/main/wikisql.ipynb .
|
||
|
"""
|
||
|
|
||
|
|
||
|
import json
|
||
|
import os
|
||
|
|
||
|
|
||
|
def load():
    """
    Load all three splits of the WikiSQL dataset.
    """
    return (WikiSQL(dn) for dn in ["train", "dev", "test"])
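
# Illustrative usage (not executed here; assumes the download in
# WikiSQL._maybe_download succeeds):
#
#     train, dev, test = load()
#     print(train[0])  # "<s>table: ...\ncolumns: ...\nQ: ...\nA: SELECT ...</s>"
#
# Note that load() returns a generator, so each split is only constructed
# (and its data only parsed) when the generator is unpacked or iterated.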


class WikiSQL:
    def __init__(self, dataset, save_dir="/tmp"):
        valid_sets = ("train", "dev", "test")
        if dataset not in valid_sets:
            raise ValueError(f"Dataset must be in {valid_sets}, got {dataset}")
        data_dir = os.path.join(save_dir, "wikisql")
        self._maybe_download(data_dir)
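
        # The extracted archive provides, per split, data/{split}.tables.jsonl
        # (table schemas) and data/{split}.jsonl (questions with their target
        # SQL); the two parsers below turn these into prompt strings.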
        self._parse_tables(os.path.join(data_dir, f"data/{dataset}.tables.jsonl"))
        self._parse_queries(os.path.join(data_dir, f"data/{dataset}.jsonl"))

    def _maybe_download(self, data_dir):
        if not os.path.exists(data_dir):
            import io
            import tarfile
            from urllib import request

            url = "https://raw.githubusercontent.com/salesforce/WikiSQL/master/data.tar.bz2"
            r = request.urlopen(url)
            with tarfile.open(fileobj=io.BytesIO(r.read())) as tf:
                tf.extractall(data_dir)
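            # Note: the whole tarball is buffered in memory via io.BytesIO
            # before extraction rather than streamed to disk.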

    def _parse_tables(self, tables):
        self._tables = {}
        with open(tables) as f:
            for line in f:
                table = json.loads(line)
                self._tables[table["id"]] = {
                    "columns": table["header"],
                    "types": table["types"],
                    "desc": f"table: {table['id']}\ncolumns: {', '.join(table['header'])}",
                }
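
        # Each line of the .tables.jsonl file is one JSON object; only its
        # "id", "header", and "types" fields are used. An abridged,
        # illustrative record (values are hypothetical):
        #     {"id": "1-1000181-1", "header": ["State", "Capital"],
        #      "types": ["text", "text"], ...}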

    def _parse_queries(self, queries):
        self._queries = []
        with open(queries) as f:
            for line in f:
                query = json.loads(line)
                table = self._tables[query["table_id"]]
                question = query["question"]
                answer = self.query_to_text(
                    query["sql"], query["table_id"], table["columns"], table["types"]
                )
                self._queries.append(
                    f"<s>{table['desc']}\nQ: {question}\nA: {answer}</s>"
                )
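
        # Each training example is a single prompt string wrapped in
        # <s>...</s> sentinels: the table description, the natural-language
        # question, and the target SQL answer, e.g.
        #     <s>table: ...\ncolumns: ...\nQ: How many ...?\nA: SELECT ...</s>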

    def query_to_text(self, query, table, columns, types):
        aggregation_ops = ["", "MAX", "MIN", "COUNT", "SUM", "AVG"]
        condition_ops = ["=", ">", "<", "OP"]
        column = columns[query["sel"]]
        aggregation = (aggregation_ops[query["agg"]] + " ") if query["agg"] > 0 else ""
        sql = f"SELECT {aggregation}{column} FROM {table}"

        conditions = query["conds"]
        if conditions:
            cs = []
            for i, o, v in conditions:
                column = columns[i]
                op = condition_ops[o]

                if types[i] == "text":
                    value = f"'{v}'"
                else:
                    value = v
                cs.append(f"{column} {op} {value}")

            sql += " WHERE " + " AND ".join(cs)

        return sql

    def __getitem__(self, idx):
        return self._queries[idx]

    def __len__(self):
        return len(self._queries)

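
# Running this module directly performs a smoke test: each split is built and
# its length checked against the expected number of examples (the data is
# downloaded to /tmp/wikisql on first use).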
if __name__ == "__main__":
    datanames = ["train", "dev", "test"]
    sizes = [56355, 8421, 15878]
    for dataname, size in zip(datanames, sizes):
        assert len(WikiSQL(dataname)) == size, f"Wrong {dataname} set size."