Skip to content

Commit

Permalink
Merge pull request #636 from roboflow/cache-plan-details
Browse files Browse the repository at this point in the history
Cache plan details
  • Loading branch information
PawelPeczek-Roboflow committed Sep 17, 2024
2 parents 9b84c27 + d9862ab commit 6eb31af
Show file tree
Hide file tree
Showing 7 changed files with 515 additions and 82 deletions.
154 changes: 112 additions & 42 deletions inference/core/utils/sqlite_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,11 @@ def __init__(
):
self._db_file_path = db_file_path
self._tbl_name = table_name
self._columns = {**columns, **{"id": "INTEGER PRIMARY KEY"}}

self._columns = columns

self._id_col_name = "id"
self._columns[self._id_col_name] = "INTEGER PRIMARY KEY"

if not connection:
os.makedirs(os.path.dirname(db_file_path), exist_ok=True)
Expand Down Expand Up @@ -56,88 +60,141 @@ def _create_table(self, connection: sqlite3.Connection):

def insert(
self,
values: Dict[ColName, ColValue],
row: Dict[ColName, ColValue],
connection: Optional[sqlite3.Connection] = None,
cursor: Optional[sqlite3.Cursor] = None,
with_exclusive: bool = False,
):
if not connection:
if not connection and not cursor:
try:
connection: sqlite3.Connection = sqlite3.connect(
self._db_file_path, timeout=1
)
self._insert(values=values, connection=connection)
self._insert(
row=row, connection=connection, with_exclusive=with_exclusive
)
connection.close()
except Exception as exc:
logger.debug(
"Failed to store '%s' in %s - %s", values, self._tbl_name, exc
"Failed to store '%s' in %s - %s", row, self._tbl_name, exc
)
raise exc
elif connection and not cursor:
self._insert(row=row, connection=connection, with_exclusive=with_exclusive)
elif connection and not with_exclusive:
self._insert(row=row, connection=connection)
elif cursor and not with_exclusive:
self._insert(row=row, cursor=cursor)
else:
self._insert(values=values, connection=connection)
raise RuntimeError("Unsupported mode")

def _insert(self, values: Dict[ColName, ColValue], connection: sqlite3.Connection):
if not set(values.keys()).issubset(self._columns.keys()):
def _insert(
self,
row: Dict[ColName, ColValue],
connection: Optional[sqlite3.Connection] = None,
cursor: Optional[sqlite3.Cursor] = None,
with_exclusive: bool = False,
):
if not set(row.keys()).issubset(self._columns.keys()):
logger.debug(
"Cannot store '%s' in %s, requested column names do not match with table columns",
values,
row,
self._tbl_name,
)
raise ValueError("Columns mismatch")
cursor = connection.cursor()
values = {k: v for k, v in values.items() if k != "id"}

cursor_needs_closing = False
if not cursor:
cursor = connection.cursor()
cursor_needs_closing = True

if with_exclusive:
try:
cursor.execute("BEGIN EXCLUSIVE")
except Exception as exc:
logger.debug(
"Failed to store '%s' in %s - %s", row, self._tbl_name, exc
)
raise exc

values = {k: v for k, v in row.items() if k != "id"}
sql_insert = f"""INSERT INTO {self._tbl_name} ({', '.join(values.keys())})
VALUES ({', '.join(['?'] * len(values))});
"""

try:
cursor.execute("BEGIN EXCLUSIVE")
except Exception as exc:
logger.debug("Failed to store '%s' in %s - %s", values, self._tbl_name, exc)
raise exc

try:
cursor.execute(sql_insert, list(values.values()))
connection.commit()
if with_exclusive:
connection.commit()
except Exception as exc:
logger.debug("Failed to store '%s' in %s - %s", values, self._tbl_name, exc)
connection.rollback()
raise exc
cursor.close()

def count(self, connection: Optional[sqlite3.Connection] = None) -> int:
if not connection:
if cursor_needs_closing:
cursor.close()

def count(
self,
connection: Optional[sqlite3.Connection] = None,
cursor: Optional[sqlite3.Cursor] = None,
with_exclusive: bool = False,
) -> int:
if not connection and not cursor:
try:
connection: sqlite3.Connection = sqlite3.connect(
self._db_file_path, timeout=1
)
count = self._count(connection=connection)
count = self._count(
connection=connection, with_exclusive=with_exclusive
)
connection.close()
except Exception as exc:
logger.debug("Failed to obtain records count - %s", exc)
raise exc
else:
elif connection and not cursor:
count = self._count(connection=connection, with_exclusive=with_exclusive)
elif connection and not with_exclusive:
count = self._count(connection=connection)
elif cursor and not with_exclusive:
count = self._count(cursor=cursor)
else:
raise RuntimeError("Unsupported mode")
return count

def _count(self, connection: sqlite3.Connection) -> int:
cursor = connection.cursor()
sql_select = f"SELECT COUNT(*) FROM {self._tbl_name}"
def _count(
self,
connection: Optional[sqlite3.Connection] = None,
cursor: Optional[sqlite3.Cursor] = None,
with_exclusive: bool = False,
) -> int:
cursor_needs_closing = False
if not cursor:
cursor = connection.cursor()
cursor_needs_closing = True

try:
cursor.execute("BEGIN EXCLUSIVE")
except Exception as exc:
logger.debug("Failed to obtain records count - %s", exc)
raise exc
if with_exclusive:
try:
cursor.execute("BEGIN EXCLUSIVE")
except Exception as exc:
logger.debug("Failed to obtain records count - %s", exc)
raise exc

sql_select = f"SELECT COUNT(*) FROM {self._tbl_name}"

count = 0
try:
cursor.execute(sql_select)
count = int(cursor.fetchone()[0])
connection.commit()
if with_exclusive:
connection.commit()
except Exception as exc:
logger.debug("Failed to obtain records count - %s", exc)
connection.rollback()
raise exc
cursor.close()

if cursor_needs_closing:
cursor.close()

return count

Expand Down Expand Up @@ -179,20 +236,25 @@ def _select(
with_exclusive: bool = False,
limit: int = 0,
) -> List[Dict[str, Any]]:
cursor_needs_closing = False
if not cursor:
cursor = connection.cursor()
cursor_needs_closing = True

if with_exclusive:
try:
cursor.execute("BEGIN EXCLUSIVE")
except Exception as exc:
logger.debug("Failed to obtain records - %s", exc)
raise exc

sql_select = f"""SELECT id, {', '.join(k for k in self._columns.keys() if k != 'id')}
FROM {self._tbl_name}
ORDER BY id ASC
"""
if limit:
sql_select = sql_select + f" LIMIT {limit}"

try:
cursor.execute(sql_select)
sqlite_rows = cursor.fetchall()
Expand All @@ -212,6 +274,9 @@ def _select(
row["id"] = _id
rows.append(row)

if cursor_needs_closing:
cursor.close()

return rows

def flush(
Expand Down Expand Up @@ -296,8 +361,11 @@ def _delete(
logger.debug("No row with 'id' key found in %s", rows)
return []

cursor_needs_closing = False
if not cursor:
cursor = connection.cursor()
cursor_needs_closing = True

if with_exclusive:
try:
cursor.execute("BEGIN EXCLUSIVE")
Expand Down Expand Up @@ -326,12 +394,14 @@ def _delete(
payloads = cursor.fetchall()
if with_exclusive:
connection.commit()
cursor.close()
except Exception as exc:
logger.debug("Failed to delete records - %s", exc)
connection.rollback()
raise exc

if cursor_needs_closing:
cursor.close()

_ids = set()
for _id, *_ in payloads:
_ids.add(_id)
Expand All @@ -340,25 +410,25 @@ def _delete(

def refresh(
self,
values: List[Dict[ColName, ColValue]],
rows: List[Dict[ColName, ColValue]],
connection: Optional[sqlite3.Connection] = None,
) -> List[Dict[str, Any]]:
if not connection:
try:
connection: sqlite3.Connection = sqlite3.connect(
self._db_file_path, timeout=1
)
payloads = self._refresh(values=values, connection=connection)
payloads = self._refresh(rows=rows, connection=connection)
connection.close()
except Exception as exc:
logger.debug("Failed to flush db - %s", exc)
raise exc
else:
payloads = self._refresh(values=values, connection=connection)
payloads = self._refresh(rows=rows, connection=connection)
return payloads

def _refresh(
self, values: List[Dict[ColName, ColValue]], connection: sqlite3.Connection
self, rows: List[Dict[ColName, ColValue]], connection: sqlite3.Connection
) -> List[Dict[str, Any]]:
cursor = connection.cursor()
try:
Expand All @@ -368,23 +438,23 @@ def _refresh(
raise exc

try:
self.delete(values=values, cursor=cursor)
self.delete(rows=rows, cursor=cursor)
except Exception as exc:
logger.debug("Failed to delete records - %s", exc)
connection.rollback()
raise exc

try:
for v in values:
self.insert(values=v, cursor=cursor)
for r in rows:
self.insert(row=r, cursor=cursor)
connection.commit()
except Exception as exc:
logger.debug("Failed to insert records - %s", exc)
connection.rollback()
raise exc

try:
rows = self.select(cursor=cursor)
connection.commit()
cursor.close()
except Exception as exc:
logger.debug("Failed to delete records - %s", exc)
Expand Down
Loading

0 comments on commit 6eb31af

Please sign in to comment.