From d2e447d16c68ee8d8090634ae27e02d44e7ee781 Mon Sep 17 00:00:00 2001 From: "Grisha K." Date: Mon, 22 Jul 2024 13:13:50 +0300 Subject: [PATCH] fix(sqlite): allow using fts5 table name in the where clause --- .../virtual_table/sqlite/go/models.go | 8 +- .../virtual_table/sqlite/go/query.sql.go | 81 +++++++++++++++++-- .../testdata/virtual_table/sqlite/query.sql | 8 ++ internal/engine/sqlite/convert.go | 6 ++ 4 files changed, 93 insertions(+), 10 deletions(-) diff --git a/internal/endtoend/testdata/virtual_table/sqlite/go/models.go b/internal/endtoend/testdata/virtual_table/sqlite/go/models.go index b7dfb02d32..fcf1bfa2fe 100644 --- a/internal/endtoend/testdata/virtual_table/sqlite/go/models.go +++ b/internal/endtoend/testdata/virtual_table/sqlite/go/models.go @@ -9,7 +9,8 @@ import ( ) type Ft struct { - B string + B string + Ft string } type Tbl struct { @@ -21,6 +22,7 @@ type Tbl struct { } type TblFt struct { - B string - C string + B string + C string + TblFt string } diff --git a/internal/endtoend/testdata/virtual_table/sqlite/go/query.sql.go b/internal/endtoend/testdata/virtual_table/sqlite/go/query.sql.go index fbcec9a174..b20869ea29 100644 --- a/internal/endtoend/testdata/virtual_table/sqlite/go/query.sql.go +++ b/internal/endtoend/testdata/virtual_table/sqlite/go/query.sql.go @@ -65,16 +65,77 @@ SELECT b, c FROM tbl_ft WHERE b MATCH ? ` -func (q *Queries) SelectAllColsTblFt(ctx context.Context, b string) ([]TblFt, error) { +type SelectAllColsTblFtRow struct { + B string + C string +} + +func (q *Queries) SelectAllColsTblFt(ctx context.Context, b string) ([]SelectAllColsTblFtRow, error) { rows, err := q.db.QueryContext(ctx, selectAllColsTblFt, b) if err != nil { return nil, err } defer rows.Close() + var items []SelectAllColsTblFtRow + for rows.Next() { + var i SelectAllColsTblFtRow + if err := rows.Scan(&i.B, &i.C); err != nil { + return nil, err + } + items = append(items, i) + } + if err := rows.Close(); err != nil { + return nil, err + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + +const selectAllColsTblFtEqualByTableName = `-- name: SelectAllColsTblFtEqualByTableName :many +SELECT b, c, tbl_ft FROM tbl_ft +WHERE tbl_ft = ? +` + +func (q *Queries) SelectAllColsTblFtEqualByTableName(ctx context.Context, tblFt string) ([]TblFt, error) { + rows, err := q.db.QueryContext(ctx, selectAllColsTblFtEqualByTableName, tblFt) + if err != nil { + return nil, err + } + defer rows.Close() var items []TblFt for rows.Next() { var i TblFt - if err := rows.Scan(&i.B, &i.C); err != nil { + if err := rows.Scan(&i.B, &i.C, &i.TblFt); err != nil { + return nil, err + } + items = append(items, i) + } + if err := rows.Close(); err != nil { + return nil, err + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + +const selectAllColsTblFtMatchByTableName = `-- name: SelectAllColsTblFtMatchByTableName :many +SELECT b, c, tbl_ft FROM tbl_ft +WHERE tbl_ft MATCH ? +` + +func (q *Queries) SelectAllColsTblFtMatchByTableName(ctx context.Context, tblFt string) ([]TblFt, error) { + rows, err := q.db.QueryContext(ctx, selectAllColsTblFtMatchByTableName, tblFt) + if err != nil { + return nil, err + } + defer rows.Close() + var items []TblFt + for rows.Next() { + var i TblFt + if err := rows.Scan(&i.B, &i.C, &i.TblFt); err != nil { return nil, err } items = append(items, i) @@ -89,14 +150,15 @@ func (q *Queries) SelectAllColsTblFt(ctx context.Context, b string) ([]TblFt, er } const selectBm25Func = `-- name: SelectBm25Func :many -SELECT b, c, bm25(tbl_ft, 2.0) FROM tbl_ft +SELECT b, c, tbl_ft, bm25(tbl_ft, 2.0) FROM tbl_ft WHERE b MATCH ? ORDER BY bm25(tbl_ft) ` type SelectBm25FuncRow struct { - B string - C string - Bm25 float64 + B string + C string + TblFt string + Bm25 float64 } func (q *Queries) SelectBm25Func(ctx context.Context, b string) ([]SelectBm25FuncRow, error) { @@ -108,7 +170,12 @@ func (q *Queries) SelectBm25Func(ctx context.Context, b string) ([]SelectBm25Fun var items []SelectBm25FuncRow for rows.Next() { var i SelectBm25FuncRow - if err := rows.Scan(&i.B, &i.C, &i.Bm25); err != nil { + if err := rows.Scan( + &i.B, + &i.C, + &i.TblFt, + &i.Bm25, + ); err != nil { return nil, err } items = append(items, i) diff --git a/internal/endtoend/testdata/virtual_table/sqlite/query.sql b/internal/endtoend/testdata/virtual_table/sqlite/query.sql index ad8eeeae40..3dba327d4e 100644 --- a/internal/endtoend/testdata/virtual_table/sqlite/query.sql +++ b/internal/endtoend/testdata/virtual_table/sqlite/query.sql @@ -14,6 +14,14 @@ WHERE b = ?; SELECT c FROM tbl_ft WHERE b = ?; +-- name: SelectAllColsTblFtEqualByTableName :many +SELECT * FROM tbl_ft +WHERE tbl_ft = ?; + +-- name: SelectAllColsTblFtMatchByTableName :many +SELECT * FROM tbl_ft +WHERE tbl_ft MATCH ?; + -- name: SelectHightlighFunc :many SELECT highlight(tbl_ft, 0, '', '') FROM tbl_ft WHERE b MATCH ?; diff --git a/internal/engine/sqlite/convert.go b/internal/engine/sqlite/convert.go index 02d80bc48c..4d8c357f26 100644 --- a/internal/engine/sqlite/convert.go +++ b/internal/engine/sqlite/convert.go @@ -168,6 +168,12 @@ func (c *cc) convertCreate_virtual_table_fts5(n *parser.Create_virtual_table_stm } } + stmt.Cols = append(stmt.Cols, &ast.ColumnDef{ + Colname: identifier(stmt.Name.Name), + IsNotNull: true, + TypeName: &ast.TypeName{Name: "text"}, + }) + return stmt }