Skip to content

Commit

Permalink
code working
Browse files Browse the repository at this point in the history
  • Loading branch information
coilysiren committed Oct 22, 2023
1 parent 0b47c1c commit c158a9e
Show file tree
Hide file tree
Showing 7 changed files with 152 additions and 83 deletions.
21 changes: 21 additions & 0 deletions data/sql_input_3.sql
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
-- https://cratedb.com/docs/sql-99/en/latest/chapters/01.html
-- https://www.postgresql.org/docs/16/sql-createtable.html
-- https://www.postgresql.org/docs/16/sql-insert.html
-- https://www.postgresql.org/docs/16/sql-select.html
CREATE TABLE city (
name VARCHAR,
population INT,
timezone INT
);

INSERT INTO city (name, timezone)
VALUES ('San Francisco', -8);

INSERT INTO city (name, population)
VALUES ('New York', 8405837);

SELECT
name,
population,
timezone
FROM city;
4 changes: 1 addition & 3 deletions data/sql_output_0.json
Original file line number Diff line number Diff line change
@@ -1,3 +1 @@
{
"table_name": ["city"]
}
[{ "table_name": "city" }]
4 changes: 1 addition & 3 deletions data/sql_output_1.json
Original file line number Diff line number Diff line change
@@ -1,3 +1 @@
{
"table_name": ["city", "town"]
}
[{ "table_name": "city" }, { "table_name": "town" }]
9 changes: 4 additions & 5 deletions data/sql_output_2.json
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
{
"name": ["San Francisco", "New York"],
"population": [852469, 8405837],
"timezone": [-8, -5]
}
[
{ "name": "San Francisco", "population": 852469, "timezone": -8 },
{ "name": "New York", "population": 8405837, "timezone": -5 }
]
4 changes: 4 additions & 0 deletions data/sql_output_3.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
[
{ "name": "San Francisco", "population": null, "timezone": -8 },
{ "name": "New York", "population": 8405837, "timezone": null }
]
191 changes: 120 additions & 71 deletions src/python/sql_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,31 +9,68 @@
########################


import dataclasses
import json
import typing


class SQL:
__data: dict = {}
@dataclasses.dataclass(frozen=True)
class SQLState:
state: dict

def __init__(self) -> None:
self.clear_data()
def read_table_meta(self, table_name: str) -> dict:
return self.state.get(table_name, {}).get("metadata", {})

def clear_data(self):
self.__data = {}
def read_table_rows(self, table_name: str) -> dict:
return self.state.get(table_name, {}).get("rows", {})

def read_data_table(self, table_name: str) -> dict:
return self.__data.get(table_name, {})

def read_information_schema_tables(self) -> list[dict]:
return [data["metadata"] for data in self.__data.values()]
def read_information_schema(self) -> list[dict]:
return [data["metadata"] for data in self.state.values()]

def write_table_meta(self, table_name: str, data: dict):
self.__data[table_name] = data

def write_table_data(self, table_name: str, data: dict):
self.__data[table_name]["data"] = data

def create_table(self, *args, table_schema="public") -> dict:
state = self.state
table = state.get(table_name, {})
metadata = table.get("metadata", {})
metadata.update(data)
table["metadata"] = metadata
state[table_name] = table
return self.__class__(state)

def write_table_rows(self, table_name: str, data: dict):
state = self.state
table = state.get(table_name, {})
rows = table.get("rows", [])
rows.append(data)
table["rows"] = rows
state[table_name] = table
return self.__class__(state)


class SQLType:
    """Converters that turn a raw SQL value into the matching Python value.

    Every converter is a static method that takes one raw value (normally the
    text of a literal taken from a statement) and returns the converted
    result.
    """

    @staticmethod
    def varchar(data) -> str:
        """Return *data* as text without surrounding whitespace or quotes.

        At most one leading and one trailing quote character (``'`` or ``"``)
        is removed. The two ends are checked independently, so the quotes do
        not have to match each other.
        """
        text = str(data).strip()
        if text.startswith(("'", '"')):
            text = text[1:]
        if text.endswith(("'", '"')):
            text = text[:-1]
        return text

    @staticmethod
    def int(data) -> int:
        """Return *data* converted to an integer.

        Surrounding whitespace is ignored.

        Raises:
            ValueError: if the text is not a valid integer literal.
        """
        # Coerce through str() first, exactly as varchar does, so a value that
        # is already numeric does not fail for lacking a .strip() method.
        # Inside this body `int` is the builtin: class-level names are not
        # visible from within a method's scope.
        return int(str(data).strip())


# Converter lookup keyed by SQL column type name. insert_into indexes this
# with the type recorded for a column in the table metadata and calls the
# result on the raw value text to obtain the stored Python value.
sql_type_map = {
"VARCHAR": SQLType.varchar,
"INT": SQLType.int,
}


class SQLFunctions:
@staticmethod
def create_table(state: SQLState, *args, table_schema="public") -> typing.Tuple[list, SQLState]:
output: list[dict] = []
table_name = args[2]

# get columns
Expand All @@ -47,29 +84,44 @@ def create_table(self, *args, table_schema="public") -> dict:
}
# fmt: on

if self.read_data_table(table_name):
self.write_table_meta(
if not state.read_table_meta(table_name):
state = state.write_table_meta(
table_name,
{
"metadata": {
"table_name": table_name,
"table_schema": table_schema,
"colums": columns,
}
"table_name": table_name,
"table_schema": table_schema,
"colums": columns,
},
)
return {}
return (output, state)

@staticmethod
def insert_into(state: SQLState, *args) -> typing.Tuple[list, SQLState]:
output: list[dict] = []
table_name = args[2]

create_table.sql = "CREATE TABLE"
values_index = None
for i, arg in enumerate(args):
if arg == "VALUES":
values_index = i

keys = " ".join(args[3:values_index]).replace("(", "").replace(")", "").split(",")
keys = [key.strip() for key in keys]
values = " ".join(args[values_index + 1 :]).replace("(", "").replace(")", "").split(",")
values = [value.strip() for value in values]
key_value_map = dict(zip(keys, values))

def insert_into(self, *args) -> dict:
print(f"args: {args}")
pass
data = {}
if metadata := state.read_table_meta(table_name):
for key, value in key_value_map.items():
data[key] = sql_type_map[metadata["colums"][key]](value)
state = state.write_table_rows(table_name, data)

insert_into.sql = "INSERT INTO"
return (output, state)

def select(self, *args) -> dict:
output = {}
@staticmethod
def select(state: SQLState, *args) -> typing.Tuple[list, SQLState]:
output: list[dict] = []

from_index = None
where_index = None
Expand All @@ -81,64 +133,61 @@ def select(self, *args) -> dict:

# get select keys by getting the slice of args before FROM
select_keys = " ".join(args[1:from_index]).split(",")
select_keys = [key.strip() for key in select_keys]

# get where keys by getting the slice of args after WHERE
from_value = args[from_index + 1]

# consider "information_schema.tables" a special case until
# we figure out why its so different from the others
# `information_schema.tables` is a special case
if from_value == "information_schema.tables":
target = self.read_information_schema_tables()
data = state.read_information_schema()
else:
target = self.read_data_table(from_value)
data = state.read_table_rows(from_value)

output = []
for datum in data:
# fmt: off
output.append({
key: datum.get(key)
for key in select_keys
})
# fmt: on

# fmt: off
output = {
key: [
value for data in target
for key, value in data.items()
if key in select_keys
]
for key in select_keys
}
# fmt: on
return (output, state)

return output

select.sql = "SELECT"
# Dispatch table from the leading keyword(s) of a statement to its handler.
# run_sql tries progressively shorter word prefixes of each statement as the
# key, then calls the handler as func(state, *words) and expects an
# (output, state) pair back.
sql_function_map: dict[str, typing.Callable] = {
"CREATE TABLE": SQLFunctions.create_table,
"SELECT": SQLFunctions.select,
"INSERT INTO": SQLFunctions.insert_into,
}

sql_map = {
create_table.sql: create_table,
select.sql: select,
insert_into.sql: insert_into,
}

def run(self, input_sql: list[str]) -> list[str]:
output = {}
def run_sql(input_sql: list[str]) -> list[str]:
output = []
state = SQLState(state={})

# remove comments
input_sql = [line.strip() for line in input_sql if not line.startswith("--")]
# remove comments
input_sql = [line.strip() for line in input_sql if not line.startswith("--")]

# re-split on semi-colons
input_sql = " ".join(input_sql).split(";")
# re-split on semi-colons
input_sql = " ".join(input_sql).split(";")

# iterate over each line of sql
for line in input_sql:
words = line.split(" ")
for i in reversed(range(len(words) + 1)):
key = " ".join(words[:i]).strip()
# print(f'key: "{key}"')
if func := self.sql_map.get(key):
# print(f'running "{func.__name__}" with {words}')
output = func(self, *[word for word in words if word])
break
# iterate over each line of sql
for line in input_sql:
words = line.split(" ")
for i in reversed(range(len(words) + 1)):
key = " ".join(words[:i]).strip()
if func := sql_function_map.get(key):
output, state = func(state, *[word for word in words if word])
break

return [json.dumps(output)]
return [json.dumps(output)]


######################
# business logic end #
######################

if __name__ == "__main__":
helpers.run(SQL().run)
helpers.run(run_sql)
2 changes: 1 addition & 1 deletion tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -300,7 +300,7 @@ def run_tests(self, input_script):
prepared_file_data = json.load(reader)
with open(ctx.script_output_file_path, "r", encoding="utf-8") as reader:
script_output_file_data = json.load(reader)
unittest.TestCase().assertDictEqual(prepared_file_data, script_output_file_data)
unittest.TestCase().assertListEqual(prepared_file_data, script_output_file_data)
self.set_success_status(True)
print(f"\t🟢 {ctx.script_relative_path} on {ctx.input_file_path} succeeded")
continue
Expand Down

0 comments on commit c158a9e

Please sign in to comment.