From b386b1dbba8b0b725f29d2c47b45823d62a071d0 Mon Sep 17 00:00:00 2001
From: Miguel Molina
Date: Thu, 28 Mar 2019 16:56:30 +0100
Subject: [PATCH] rule: parallelize projects calling UAST not under exchanges

Fixes #766

Since rows are iterated serially (except for exchange nodes), when a
Project node calls UAST, it's processed one at a time, which
underutilizes the resources in the machine.

When a Project calls UAST or UAST_MODE and is not under an Exchange
node, which already parallelizes its children, this rule replaces the
project with a special node called parallelProject, which is
essentially a project that keeps up to N goroutines processing rows,
where N is the parallelism value of gitbase.

Signed-off-by: Miguel Molina
---
Review note: internal/rule/parallel_project.go was amended after the
original commit (workers no longer block forever on a full rows channel,
a pending error is not dropped at end of stream, the child is closed
under the lock). Its "index" line still carries the original blob id,
and the indentation of context lines was reconstructed.

 cmd/gitbase/command/server.go          |   6 +-
 integration_test.go                    |  70 +++++-
 internal/rule/parallel_project.go      | 291 +++++++++++++++++++++++++
 internal/rule/parallel_project_test.go |  52 +++++
 internal/rule/parallelize_uast.go      |  76 +++++++
 internal/rule/parallelize_uast_test.go |  76 +++++++
 6 files changed, 569 insertions(+), 2 deletions(-)
 create mode 100644 internal/rule/parallel_project.go
 create mode 100644 internal/rule/parallel_project_test.go
 create mode 100644 internal/rule/parallelize_uast.go
 create mode 100644 internal/rule/parallelize_uast_test.go

diff --git a/cmd/gitbase/command/server.go b/cmd/gitbase/command/server.go
index e49f6b810..401a9538f 100644
--- a/cmd/gitbase/command/server.go
+++ b/cmd/gitbase/command/server.go
@@ -94,7 +94,11 @@ func NewDatabaseEngine(
 		ab = ab.AddPostAnalyzeRule(rule.SquashJoinsRule, rule.SquashJoins)
 	}
 
-	a := ab.Build()
+	a := ab.AddPostAnalyzeRule(
+		rule.ParallelizeUASTProjectionsRule,
+		rule.ParallelizeUASTProjections,
+	).Build()
+
 	engine := sqle.New(catalog, a, &sqle.Config{
 		VersionPostfix: version,
 		Auth:           userAuth,
diff --git a/integration_test.go b/integration_test.go
index 7313d3bdb..aa0aae05c 100644
--- a/integration_test.go
+++ b/integration_test.go
@@ -303,7 +303,8 @@ func TestIntegration(t *testing.T) {
 			FROM
 				files
 			WHERE
-				language(file_path)="Go"
+				language(file_path) = "Go"
+			ORDER BY file_path ASC
 			LIMIT 1`,
 			[]sql.Row{
 				{"go/example.go", int32(1)},
@@ -431,6 +432,73 @@ func TestUastQueries(t *testing.T) {
 	}
 }
 
+func TestUastParallelQueries(t *testing.T) {
+	engine, pool, cleanup := setup(t)
+	defer cleanup()
+
+	pengine := newBaseEngine(pool)
+
+	pengine.Catalog.RegisterFunctions(sqlfunction.Defaults)
+	pengine.Analyzer = analyzer.NewBuilder(engine.Catalog).
+		AddPostAnalyzeRule(rule.SquashJoinsRule, rule.SquashJoins).
+		AddPostAnalyzeRule(rule.ParallelizeUASTProjectionsRule, rule.ParallelizeUASTProjections).
+		Build()
+
+	testCases := []string{
+		`SELECT uast_xpath(uast(blob_content, language(tree_entry_name, blob_content)), '//Identifier') as uast,
+			tree_entry_name
+		FROM tree_entries te
+		INNER JOIN blobs b
+		ON b.blob_hash = te.blob_hash
+		WHERE te.tree_entry_name = 'example.go'`,
+		`SELECT uast_xpath(uast_mode('semantic', blob_content, language(tree_entry_name, blob_content)), '//Identifier') as uast,
+			tree_entry_name
+		FROM tree_entries te
+		INNER JOIN blobs b
+		ON b.blob_hash = te.blob_hash
+		WHERE te.tree_entry_name = 'example.go'`,
+		`SELECT uast_xpath(uast_mode('annotated', blob_content, language(tree_entry_name, blob_content)), '//*[@roleIdentifier]') as uast,
+			tree_entry_name
+		FROM tree_entries te
+		INNER JOIN blobs b
+		ON b.blob_hash = te.blob_hash
+		WHERE te.tree_entry_name = 'example.go'`,
+		`SELECT uast_xpath(uast_mode('native', blob_content, language(tree_entry_name, blob_content)), '//*[@ast_type=\'FunctionDef\']') as uast,
+			tree_entry_name
+		FROM tree_entries te
+		INNER JOIN blobs b
+		ON b.blob_hash = te.blob_hash
+		WHERE te.tree_entry_name = 'example.go'`,
+	}
+
+	_ = testCases
+
+	var pid uint64
+	for _, query := range testCases {
+		pid++
+		t.Run(query, func(t *testing.T) {
+			require := require.New(t)
+
+			session := gitbase.NewSession(pool)
+			ctx := sql.NewContext(context.TODO(), sql.WithSession(session), sql.WithPid(pid))
+
+			_, iter, err := engine.Query(ctx, query)
+			require.NoError(err)
+
+			rows, err := sql.RowIterToRows(iter)
+			require.NoError(err)
+
+			_, piter, err := pengine.Query(ctx, query)
+			require.NoError(err)
+
+			prows, err := sql.RowIterToRows(piter)
+			require.NoError(err)
+
+			require.ElementsMatch(rows, prows)
+		})
+	}
+}
+
 func TestSquashCorrectness(t *testing.T) {
 	engine, pool, cleanup := setup(t)
 	defer cleanup()
diff --git a/internal/rule/parallel_project.go b/internal/rule/parallel_project.go
new file mode 100644
index 000000000..e7744785c
--- /dev/null
+++ b/internal/rule/parallel_project.go
@@ -0,0 +1,291 @@
+package rule
+
+import (
+	"context"
+	"io"
+	"strings"
+	"sync"
+
+	opentracing "github.com/opentracing/opentracing-go"
+	"gopkg.in/src-d/go-mysql-server.v0/sql"
+	"gopkg.in/src-d/go-mysql-server.v0/sql/plan"
+)
+
+// parallelProject is a plan.Project whose projections are evaluated by up to
+// parallelism goroutines at once. Rows may come out in a different order
+// than the child produced them.
+type parallelProject struct {
+	*plan.Project
+	parallelism int
+}
+
+// newParallelProject returns a parallelProject that evaluates projection over
+// the rows of child using parallelism workers.
+func newParallelProject(
+	projection []sql.Expression,
+	child sql.Node,
+	parallelism int,
+) *parallelProject {
+	return &parallelProject{
+		plan.NewProject(projection, child),
+		parallelism,
+	}
+}
+
+// RowIter opens the child iterator inside a tracing span and wraps it in a
+// parallelIter. The span is finished here only if the child fails to open.
+func (p *parallelProject) RowIter(ctx *sql.Context) (sql.RowIter, error) {
+	span, ctx := ctx.Span(
+		"plan.Project",
+		opentracing.Tag{
+			Key:   "projections",
+			Value: len(p.Projections),
+		},
+		opentracing.Tag{
+			Key:   "parallelism",
+			Value: p.parallelism,
+		},
+	)
+
+	iter, err := p.Child.RowIter(ctx)
+	if err != nil {
+		span.Finish()
+		return nil, err
+	}
+
+	return sql.NewSpanIter(
+		span,
+		newParallelIter(p.Projections, iter, ctx, p.parallelism),
+	), nil
+}
+
+// TransformUp transforms the child first and then applies f to a new
+// parallelProject built on the transformed child.
+func (p *parallelProject) TransformUp(f sql.TransformNodeFunc) (sql.Node, error) {
+	child, err := p.Child.TransformUp(f)
+	if err != nil {
+		return nil, err
+	}
+
+	return f(newParallelProject(p.Projections, child, p.parallelism))
+}
+
+// TransformExpressionsUp applies f to every projection and to the expressions
+// of the child, returning a new parallelProject with the results.
+func (p *parallelProject) TransformExpressionsUp(f sql.TransformExprFunc) (sql.Node, error) {
+	var exprs = make([]sql.Expression, len(p.Projections))
+	for i, e := range p.Projections {
+		expr, err := e.TransformUp(f)
+		if err != nil {
+			return nil, err
+		}
+
+		exprs[i] = expr
+	}
+
+	child, err := p.Child.TransformExpressionsUp(f)
+	if err != nil {
+		return nil, err
+	}
+
+	return newParallelProject(exprs, child, p.parallelism), nil
+}
+
+// String renders the node, its projections and its child as a tree.
+func (p *parallelProject) String() string {
+	pr := sql.NewTreePrinter()
+	var exprs = make([]string, len(p.Projections))
+	for i, expr := range p.Projections {
+		exprs[i] = expr.String()
+	}
+
+	_ = pr.WriteNode(
+		"gitbase.ParallelProject(%s, parallelism=%d)",
+		strings.Join(exprs, ", "),
+		p.parallelism,
+	)
+
+	_ = pr.WriteChildren(p.Child.String())
+	return pr.String()
+}
+
+// parallelIter projects the rows of child using a pool of goroutines that is
+// started lazily on the first call to Next.
+type parallelIter struct {
+	projections []sql.Expression
+	child       sql.RowIter
+	ctx         *sql.Context
+	parallelism int
+
+	cancel context.CancelFunc
+	rows   chan sql.Row
+	errors chan error
+	done   bool
+
+	// mut serializes access to child and guards finished.
+	mut      sync.Mutex
+	finished bool
+}
+
+// newParallelIter creates a parallelIter. The context held by ctx is replaced
+// with a cancellable one so the workers can be stopped when the iterator is
+// closed. The errors channel is buffered with one slot per worker: every
+// worker reports at most one error before exiting, so sends never block.
+func newParallelIter(
+	projections []sql.Expression,
+	child sql.RowIter,
+	ctx *sql.Context,
+	parallelism int,
+) *parallelIter {
+	var cancel context.CancelFunc
+	ctx.Context, cancel = context.WithCancel(ctx.Context)
+
+	return &parallelIter{
+		projections: projections,
+		child:       child,
+		ctx:         ctx,
+		parallelism: parallelism,
+		cancel:      cancel,
+		errors:      make(chan error, parallelism),
+	}
+}
+
+// Next returns the next projected row, the first error reported by a worker,
+// or io.EOF once all rows have been consumed or the iterator was closed.
+func (i *parallelIter) Next() (sql.Row, error) {
+	if i.done {
+		return nil, io.EOF
+	}
+
+	if i.rows == nil {
+		i.rows = make(chan sql.Row, i.parallelism)
+		go i.start()
+	}
+
+	select {
+	case row, ok := <-i.rows:
+		if !ok {
+			i.close()
+			// rows is only closed after every worker has exited, so any
+			// error still buffered must be reported instead of io.EOF.
+			select {
+			case err := <-i.errors:
+				return nil, err
+			default:
+				return nil, io.EOF
+			}
+		}
+		return row, nil
+	case err := <-i.errors:
+		i.close()
+		return nil, err
+	}
+}
+
+// nextRow pulls the next row from the child under the lock. The second
+// result is true when the calling worker must stop: the child is exhausted,
+// it failed (the error is sent to i.errors), or the iterator was closed.
+func (i *parallelIter) nextRow() (sql.Row, bool) {
+	i.mut.Lock()
+	defer i.mut.Unlock()
+
+	if i.finished {
+		return nil, true
+	}
+
+	row, err := i.child.Next()
+	if err != nil {
+		// Stop every worker on failure too, not only on io.EOF, so the
+		// child is never asked for more rows after it returned an error.
+		i.finished = true
+		if err != io.EOF {
+			i.errors <- err
+		}
+		return nil, true
+	}
+
+	return row, false
+}
+
+// start spawns the workers, waits for all of them to exit and then closes
+// i.rows to signal the consumer that no more rows will arrive.
+func (i *parallelIter) start() {
+	var wg sync.WaitGroup
+	wg.Add(i.parallelism)
+	for j := 0; j < i.parallelism; j++ {
+		go func() {
+			defer wg.Done()
+
+			for {
+				select {
+				case <-i.ctx.Done():
+					i.errors <- context.Canceled
+					return
+				default:
+				}
+
+				row, stop := i.nextRow()
+				if stop {
+					return
+				}
+
+				row, err := project(i.ctx, i.projections, row)
+				if err != nil {
+					i.errors <- err
+					return
+				}
+
+				// Do not block forever on a full channel if the consumer
+				// stopped reading: give up as soon as ctx is cancelled.
+				select {
+				case i.rows <- row:
+				case <-i.ctx.Done():
+					i.errors <- context.Canceled
+					return
+				}
+			}
+		}()
+	}
+
+	wg.Wait()
+	close(i.rows)
+}
+
+// close cancels the workers' context and marks the iterator as done. It can
+// be called more than once.
+func (i *parallelIter) close() {
+	if !i.done {
+		i.cancel()
+		i.done = true
+	}
+}
+
+// Close stops the workers and closes the child iterator. The lock is taken
+// first so the child is never closed while a worker is inside child.Next.
+func (i *parallelIter) Close() error {
+	i.close()
+
+	i.mut.Lock()
+	i.finished = true
+	i.mut.Unlock()
+
+	return i.child.Close()
+}
+
+// project evaluates every projection against row and returns the resulting
+// row, or the first evaluation error.
+func project(
+	s *sql.Context,
+	projections []sql.Expression,
+	row sql.Row,
+) (sql.Row, error) {
+	fields := make([]interface{}, 0, len(projections))
+	for _, expr := range projections {
+		f, err := expr.Eval(s, row)
+		if err != nil {
+			return nil, err
+		}
+		fields = append(fields, f)
+	}
+	return sql.NewRow(fields...), nil
+}
diff --git a/internal/rule/parallel_project_test.go b/internal/rule/parallel_project_test.go
new file mode 100644
index 000000000..6bd864a9b
--- /dev/null
+++ b/internal/rule/parallel_project_test.go
@@ -0,0 +1,52 @@
+package rule
+
+import (
+	"fmt"
+	"io"
+	"runtime"
+	"testing"
+
+	"github.com/stretchr/testify/require"
+	"gopkg.in/src-d/go-mysql-server.v0/mem"
+	"gopkg.in/src-d/go-mysql-server.v0/sql"
+	"gopkg.in/src-d/go-mysql-server.v0/sql/expression"
+	"gopkg.in/src-d/go-mysql-server.v0/sql/plan"
+)
+
+func TestParallelProject(t *testing.T) {
+	require := require.New(t)
+	ctx := sql.NewEmptyContext()
+	child := mem.NewTable("test", sql.Schema{
+		{Name: "col1", Type: sql.Text, Nullable: true},
+		{Name: "col2", Type: sql.Text, Nullable: true},
+	})
+
+	var input, expected []sql.Row
+	for i := 1; i < 500; i++ {
+		input = append(input, sql.Row{
+			fmt.Sprintf("col1_%d", i), fmt.Sprintf("col2_%d", i),
+		})
+
+		expected = append(expected, sql.Row{fmt.Sprintf("col2_%d", i)})
+	}
+
+	for _, row := range input {
+		require.NoError(child.Insert(sql.NewEmptyContext(), row))
+	}
+
+	p := newParallelProject(
+		[]sql.Expression{expression.NewGetField(1, sql.Text, "col2", true)},
+		plan.NewResolvedTable(child),
+		runtime.NumCPU(),
+	)
+
+	iter, err := p.RowIter(ctx)
+	require.NoError(err)
+
+	rows, err := sql.RowIterToRows(iter)
+	require.NoError(err)
+	require.ElementsMatch(expected, rows)
+
+	_, err = iter.Next()
+	require.Equal(io.EOF, err)
+}
diff --git a/internal/rule/parallelize_uast.go b/internal/rule/parallelize_uast.go
new file mode 100644
index 000000000..8a6d77fea
--- /dev/null
+++ b/internal/rule/parallelize_uast.go
@@ -0,0 +1,76 @@
+package rule
+
+import (
+	"github.com/src-d/gitbase/internal/function"
+	"gopkg.in/src-d/go-mysql-server.v0/sql"
+	"gopkg.in/src-d/go-mysql-server.v0/sql/analyzer"
+	"gopkg.in/src-d/go-mysql-server.v0/sql/expression"
+	"gopkg.in/src-d/go-mysql-server.v0/sql/plan"
+)
+
+// ParallelizeUASTProjectionsRule is the name of the rule.
+const ParallelizeUASTProjectionsRule = "parallelize_uast_projections"
+
+// ParallelizeUASTProjections is a rule that whenever it finds a projection
+// with a call to any uast function, it replaces it with a parallel version
+// of the project node to execute several bblfsh requests in parallel. It
+// will only do so if the project is not under an exchange node.
+func ParallelizeUASTProjections(
+	ctx *sql.Context,
+	a *analyzer.Analyzer,
+	n sql.Node,
+) (sql.Node, error) {
+	if a.Parallelism <= 1 {
+		return n, nil
+	}
+
+	return n.TransformUp(func(n sql.Node) (sql.Node, error) {
+		switch n := n.(type) {
+		case *plan.Project:
+			if callsUAST(n.Projections) {
+				return newParallelProject(n.Projections, n.Child, a.Parallelism), nil
+			}
+
+			return n, nil
+		case *plan.Exchange:
+			child, err := n.Child.TransformUp(removeParallelProjects)
+			if err != nil {
+				return nil, err
+			}
+
+			return plan.NewExchange(n.Parallelism, child), nil
+		default:
+			return n, nil
+		}
+	})
+}
+
+func removeParallelProjects(n sql.Node) (sql.Node, error) {
+	p, ok := n.(*parallelProject)
+	if !ok {
+		return n, nil
+	}
+
+	return plan.NewProject(p.Projections, p.Child), nil
+}
+
+func callsUAST(exprs []sql.Expression) bool {
+	var seen bool
+	for _, e := range exprs {
+		expression.Inspect(e, func(e sql.Expression) bool {
+			switch e.(type) {
+			case *function.UAST, *function.UASTMode:
+				seen = true
+				return false
+			}
+
+			return true
+		})
+
+		if seen {
+			return true
+		}
+	}
+
+	return false
+}
diff --git a/internal/rule/parallelize_uast_test.go b/internal/rule/parallelize_uast_test.go
new file mode 100644
index 000000000..05c0b1168
--- /dev/null
+++ b/internal/rule/parallelize_uast_test.go
@@ -0,0 +1,76 @@
+package rule
+
+import (
+	"testing"
+
+	"github.com/src-d/gitbase"
+	"github.com/src-d/gitbase/internal/function"
+	"github.com/stretchr/testify/require"
+	"gopkg.in/src-d/go-mysql-server.v0/sql"
+	"gopkg.in/src-d/go-mysql-server.v0/sql/analyzer"
+	"gopkg.in/src-d/go-mysql-server.v0/sql/expression"
+	"gopkg.in/src-d/go-mysql-server.v0/sql/plan"
+)
+
+func TestParallelizeUASTProjections(t *testing.T) {
+	require := require.New(t)
+
+	tables := gitbase.NewDatabase("foo", gitbase.NewRepositoryPool(0)).Tables()
+
+	uastFn, err := function.NewUAST(
+		expression.NewGetFieldWithTable(0, sql.Blob, "files", "blob_content", false),
+	)
+	require.NoError(err)
+
+	uastModeFn := function.NewUASTMode(
+		expression.NewLiteral("semantic", sql.Text),
+		expression.NewGetFieldWithTable(0, sql.Blob, "files", "blob_content", false),
+		expression.NewLiteral("Go", sql.Text),
+	)
+
+	node := plan.NewProject(
+		[]sql.Expression{
+			uastModeFn,
+			uastFn,
+		},
+		plan.NewExchange(
+			5,
+			plan.NewProject(
+				[]sql.Expression{
+					expression.NewAlias(
+						uastFn,
+						"foo",
+					),
+				},
+				plan.NewResolvedTable(tables[gitbase.FilesTableName]),
+			),
+		),
+	)
+
+	a := analyzer.NewBuilder(nil).WithParallelism(4).Build()
+
+	result, err := ParallelizeUASTProjections(sql.NewEmptyContext(), a, node)
+	require.NoError(err)
+
+	expected := newParallelProject(
+		[]sql.Expression{
+			uastModeFn,
+			uastFn,
+		},
+		plan.NewExchange(
+			5,
+			plan.NewProject(
+				[]sql.Expression{
+					expression.NewAlias(
+						uastFn,
+						"foo",
+					),
+				},
+				plan.NewResolvedTable(tables[gitbase.FilesTableName]),
+			),
+		),
+		4,
+	)
+
+	require.Equal(expected, result)
+}