diff --git a/go/callgraph/cha/cha.go b/go/callgraph/cha/cha.go index 3040f3d8bbc..67a03563602 100644 --- a/go/callgraph/cha/cha.go +++ b/go/callgraph/cha/cha.go @@ -25,12 +25,10 @@ package cha // import "golang.org/x/tools/go/callgraph/cha" // TODO(zpavlinovic): update CHA for how it handles generic function bodies. import ( - "go/types" - "golang.org/x/tools/go/callgraph" + "golang.org/x/tools/go/callgraph/internal/chautil" "golang.org/x/tools/go/ssa" "golang.org/x/tools/go/ssa/ssautil" - "golang.org/x/tools/go/types/typeutil" ) // CallGraph computes the call graph of the specified program using the @@ -53,13 +51,6 @@ func CallGraph(prog *ssa.Program) *callgraph.Graph { // (io.Writer).Write is assumed to call every concrete // Write method in the program, the call graph can // contain a lot of duplication. - // - // TODO(taking): opt: consider making lazyCallees public. - // Using the same benchmarks as callgraph_test.go, removing just - // the explicit callgraph.Graph construction is 4x less memory - // and is 37% faster. - // CHA 86 ms/op 16 MB/op - // lazyCallees 63 ms/op 4 MB/op for _, g := range callees { addEdge(fnode, site, g) } @@ -83,82 +74,4 @@ func CallGraph(prog *ssa.Program) *callgraph.Graph { return cg } -// lazyCallees returns a function that maps a call site (in a function in fns) -// to its callees within fns. -// -// The resulting function is not concurrency safe. -func lazyCallees(fns map[*ssa.Function]bool) func(site ssa.CallInstruction) []*ssa.Function { - // funcsBySig contains all functions, keyed by signature. It is - // the effective set of address-taken functions used to resolve - // a dynamic call of a particular signature. - var funcsBySig typeutil.Map // value is []*ssa.Function - - // methodsByID contains all methods, grouped by ID for efficient - // lookup. - // - // We must key by ID, not name, for correct resolution of interface - // calls to a type with two (unexported) methods spelled the same but - // from different packages. 
The fact that the concrete type implements - // the interface does not mean the call dispatches to both methods. - methodsByID := make(map[string][]*ssa.Function) - - // An imethod represents an interface method I.m. - // (There's no go/types object for it; - // a *types.Func may be shared by many interfaces due to interface embedding.) - type imethod struct { - I *types.Interface - id string - } - // methodsMemo records, for every abstract method call I.m on - // interface type I, the set of concrete methods C.m of all - // types C that satisfy interface I. - // - // Abstract methods may be shared by several interfaces, - // hence we must pass I explicitly, not guess from m. - // - // methodsMemo is just a cache, so it needn't be a typeutil.Map. - methodsMemo := make(map[imethod][]*ssa.Function) - lookupMethods := func(I *types.Interface, m *types.Func) []*ssa.Function { - id := m.Id() - methods, ok := methodsMemo[imethod{I, id}] - if !ok { - for _, f := range methodsByID[id] { - C := f.Signature.Recv().Type() // named or *named - if types.Implements(C, I) { - methods = append(methods, f) - } - } - methodsMemo[imethod{I, id}] = methods - } - return methods - } - - for f := range fns { - if f.Signature.Recv() == nil { - // Package initializers can never be address-taken. 
- if f.Name() == "init" && f.Synthetic == "package initializer" { - continue - } - funcs, _ := funcsBySig.At(f.Signature).([]*ssa.Function) - funcs = append(funcs, f) - funcsBySig.Set(f.Signature, funcs) - } else if obj := f.Object(); obj != nil { - id := obj.(*types.Func).Id() - methodsByID[id] = append(methodsByID[id], f) - } - } - - return func(site ssa.CallInstruction) []*ssa.Function { - call := site.Common() - if call.IsInvoke() { - tiface := call.Value.Type().Underlying().(*types.Interface) - return lookupMethods(tiface, call.Method) - } else if g := call.StaticCallee(); g != nil { - return []*ssa.Function{g} - } else if _, ok := call.Value.(*ssa.Builtin); !ok { - fns, _ := funcsBySig.At(call.Signature()).([]*ssa.Function) - return fns - } - return nil - } -} +var lazyCallees = chautil.LazyCallees diff --git a/go/callgraph/internal/chautil/lazy.go b/go/callgraph/internal/chautil/lazy.go new file mode 100644 index 00000000000..430bfea4564 --- /dev/null +++ b/go/callgraph/internal/chautil/lazy.go @@ -0,0 +1,96 @@ +// Copyright 2024 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// Package chautil provides helper functions related to +// class hierarchy analysis (CHA) for use in x/tools. +package chautil + +import ( + "go/types" + + "golang.org/x/tools/go/ssa" + "golang.org/x/tools/go/types/typeutil" +) + +// LazyCallees returns a function that maps a call site (in a function in fns) +// to its callees within fns. The set of callees is computed using the CHA algorithm, +// i.e., on the entire implements relation between interfaces and concrete types +// in fns. Please see golang.org/x/tools/go/callgraph/cha for more information. +// +// The resulting function is not concurrency safe. +func LazyCallees(fns map[*ssa.Function]bool) func(site ssa.CallInstruction) []*ssa.Function { + // funcsBySig contains all functions, keyed by signature. 
It is + // the effective set of address-taken functions used to resolve + // a dynamic call of a particular signature. + var funcsBySig typeutil.Map // value is []*ssa.Function + + // methodsByID contains all methods, grouped by ID for efficient + // lookup. + // + // We must key by ID, not name, for correct resolution of interface + // calls to a type with two (unexported) methods spelled the same but + // from different packages. The fact that the concrete type implements + // the interface does not mean the call dispatches to both methods. + methodsByID := make(map[string][]*ssa.Function) + + // An imethod represents an interface method I.m. + // (There's no go/types object for it; + // a *types.Func may be shared by many interfaces due to interface embedding.) + type imethod struct { + I *types.Interface + id string + } + // methodsMemo records, for every abstract method call I.m on + // interface type I, the set of concrete methods C.m of all + // types C that satisfy interface I. + // + // Abstract methods may be shared by several interfaces, + // hence we must pass I explicitly, not guess from m. + // + // methodsMemo is just a cache, so it needn't be a typeutil.Map. + methodsMemo := make(map[imethod][]*ssa.Function) + lookupMethods := func(I *types.Interface, m *types.Func) []*ssa.Function { + id := m.Id() + methods, ok := methodsMemo[imethod{I, id}] + if !ok { + for _, f := range methodsByID[id] { + C := f.Signature.Recv().Type() // named or *named + if types.Implements(C, I) { + methods = append(methods, f) + } + } + methodsMemo[imethod{I, id}] = methods + } + return methods + } + + for f := range fns { + if f.Signature.Recv() == nil { + // Package initializers can never be address-taken. 
+ if f.Name() == "init" && f.Synthetic == "package initializer" { + continue + } + funcs, _ := funcsBySig.At(f.Signature).([]*ssa.Function) + funcs = append(funcs, f) + funcsBySig.Set(f.Signature, funcs) + } else if obj := f.Object(); obj != nil { + id := obj.(*types.Func).Id() + methodsByID[id] = append(methodsByID[id], f) + } + } + + return func(site ssa.CallInstruction) []*ssa.Function { + call := site.Common() + if call.IsInvoke() { + tiface := call.Value.Type().Underlying().(*types.Interface) + return lookupMethods(tiface, call.Method) + } else if g := call.StaticCallee(); g != nil { + return []*ssa.Function{g} + } else if _, ok := call.Value.(*ssa.Builtin); !ok { + fns, _ := funcsBySig.At(call.Signature()).([]*ssa.Function) + return fns + } + return nil + } +} diff --git a/go/callgraph/vta/graph.go b/go/callgraph/vta/graph.go index 1eea423999e..1a9ed7cb321 100644 --- a/go/callgraph/vta/graph.go +++ b/go/callgraph/vta/graph.go @@ -9,7 +9,6 @@ import ( "go/token" "go/types" - "golang.org/x/tools/go/callgraph" "golang.org/x/tools/go/ssa" "golang.org/x/tools/go/types/typeutil" "golang.org/x/tools/internal/aliases" @@ -274,8 +273,8 @@ func (g vtaGraph) addEdge(x, y node) { // typePropGraph builds a VTA graph for a set of `funcs` and initial // `callgraph` needed to establish interprocedural edges. Returns the // graph and a map for unique type representatives. -func typePropGraph(funcs map[*ssa.Function]bool, callgraph *callgraph.Graph) (vtaGraph, *typeutil.Map) { - b := builder{graph: make(vtaGraph), callGraph: callgraph} +func typePropGraph(funcs map[*ssa.Function]bool, callees calleesFunc) (vtaGraph, *typeutil.Map) { + b := builder{graph: make(vtaGraph), callees: callees} b.visit(funcs) return b.graph, &b.canon } @@ -283,8 +282,8 @@ func typePropGraph(funcs map[*ssa.Function]bool, callgraph *callgraph.Graph) (vt // Data structure responsible for linearly traversing the // code and building a VTA graph. 
type builder struct { - graph vtaGraph - callGraph *callgraph.Graph // initial call graph for creating flows at unresolved call sites. + graph vtaGraph + callees calleesFunc // initial call graph for creating flows at unresolved call sites. // Specialized type map for canonicalization of types.Type. // Semantically equivalent types can have different implementations, @@ -598,7 +597,7 @@ func (b *builder) call(c ssa.CallInstruction) { return } - siteCallees(c, b.callGraph)(func(f *ssa.Function) bool { + siteCallees(c, b.callees)(func(f *ssa.Function) bool { addArgumentFlows(b, c, f) site, ok := c.(ssa.Value) diff --git a/go/callgraph/vta/graph_test.go b/go/callgraph/vta/graph_test.go index 8ce4079c693..b32da4f54a6 100644 --- a/go/callgraph/vta/graph_test.go +++ b/go/callgraph/vta/graph_test.go @@ -205,11 +205,21 @@ func TestVTAGraphConstruction(t *testing.T) { t.Fatalf("couldn't find want in `%s`", file) } - g, _ := typePropGraph(ssautil.AllFunctions(prog), cha.CallGraph(prog)) + fs := ssautil.AllFunctions(prog) + + // First test propagation with lazy-CHA initial call graph. + g, _ := typePropGraph(fs, makeCalleesFunc(fs, nil)) got := vtaGraphStr(g) if diff := setdiff(want, got); len(diff) > 0 { t.Errorf("`%s`: want superset of %v;\n got %v\ndiff: %v", file, want, got, diff) } + + // Repeat the test with explicit CHA initial call graph. + g, _ = typePropGraph(fs, makeCalleesFunc(fs, cha.CallGraph(prog))) + got = vtaGraphStr(g) + if diff := setdiff(want, got); len(diff) > 0 { + t.Errorf("`%s`: want superset of %v;\n got %v\ndiff: %v", file, want, got, diff) + } }) } } diff --git a/go/callgraph/vta/initial.go b/go/callgraph/vta/initial.go new file mode 100644 index 00000000000..4dddc4eee6d --- /dev/null +++ b/go/callgraph/vta/initial.go @@ -0,0 +1,37 @@ +// Copyright 2024 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. 
+ +package vta + +import ( + "golang.org/x/tools/go/callgraph" + "golang.org/x/tools/go/callgraph/internal/chautil" + "golang.org/x/tools/go/ssa" +) + +// calleesFunc abstracts call graph in one direction, +// from call sites to callees. +type calleesFunc func(ssa.CallInstruction) []*ssa.Function + +// makeCalleesFunc returns an initial call graph for vta as a +// calleesFunc. If c is not nil, returns callees as given by c. +// Otherwise, it returns chautil.LazyCallees over fs. +func makeCalleesFunc(fs map[*ssa.Function]bool, c *callgraph.Graph) calleesFunc { + if c == nil { + return chautil.LazyCallees(fs) + } + return func(call ssa.CallInstruction) []*ssa.Function { + node := c.Nodes[call.Parent()] + if node == nil { + return nil + } + var cs []*ssa.Function + for _, edge := range node.Out { + if edge.Site == call { + cs = append(cs, edge.Callee.Func) + } + } + return cs + } +} diff --git a/go/callgraph/vta/utils.go b/go/callgraph/vta/utils.go index 27923362f1a..141eb077f9c 100644 --- a/go/callgraph/vta/utils.go +++ b/go/callgraph/vta/utils.go @@ -7,7 +7,6 @@ package vta import ( "go/types" - "golang.org/x/tools/go/callgraph" "golang.org/x/tools/go/ssa" "golang.org/x/tools/internal/aliases" "golang.org/x/tools/internal/typeparams" @@ -149,22 +148,14 @@ func sliceArrayElem(t types.Type) types.Type { } } -// siteCallees returns a go1.23 iterator for the callees for call site `c` -// given program `callgraph`. -func siteCallees(c ssa.CallInstruction, callgraph *callgraph.Graph) func(yield func(*ssa.Function) bool) { +// siteCallees returns a go1.23 iterator for the callees for call site `c`. +func siteCallees(c ssa.CallInstruction, callees calleesFunc) func(yield func(*ssa.Function) bool) { // TODO: when x/tools uses go1.23, change callers to use range-over-func // (https://go.dev/issue/65237). 
- node := callgraph.Nodes[c.Parent()] return func(yield func(*ssa.Function) bool) { - if node == nil { - return - } - - for _, edge := range node.Out { - if edge.Site == c { - if !yield(edge.Callee.Func) { - return - } + for _, callee := range callees(c) { + if !yield(callee) { + return } } } diff --git a/go/callgraph/vta/vta.go b/go/callgraph/vta/vta.go index 226f261d79c..72bd4a4d8b0 100644 --- a/go/callgraph/vta/vta.go +++ b/go/callgraph/vta/vta.go @@ -65,17 +65,20 @@ import ( // CallGraph uses the VTA algorithm to compute call graph for all functions // f:true in funcs. VTA refines the results of initial call graph and uses it -// to establish interprocedural type flow. The resulting graph does not have -// a root node. +// to establish interprocedural type flow. If initial is nil, VTA uses a more +// efficient approach to construct a CHA call graph. +// +// The resulting graph does not have a root node. // // CallGraph does not make any assumptions on initial types global variables // and function/method inputs can have. CallGraph is then sound, modulo use of // reflection and unsafe, if the initial call graph is sound. 
func CallGraph(funcs map[*ssa.Function]bool, initial *callgraph.Graph) *callgraph.Graph { - vtaG, canon := typePropGraph(funcs, initial) + callees := makeCalleesFunc(funcs, initial) + vtaG, canon := typePropGraph(funcs, callees) types := propagate(vtaG, canon) - c := &constructor{types: types, initial: initial, cache: make(methodCache)} + c := &constructor{types: types, callees: callees, cache: make(methodCache)} return c.construct(funcs) } @@ -85,7 +88,7 @@ func CallGraph(funcs map[*ssa.Function]bool, initial *callgraph.Graph) *callgrap type constructor struct { types propTypeMap cache methodCache - initial *callgraph.Graph + callees calleesFunc } func (c *constructor) construct(funcs map[*ssa.Function]bool) *callgraph.Graph { @@ -101,15 +104,15 @@ func (c *constructor) construct(funcs map[*ssa.Function]bool) *callgraph.Graph { func (c *constructor) constrct(g *callgraph.Graph, f *ssa.Function) { caller := g.CreateNode(f) for _, call := range calls(f) { - for _, c := range c.callees(call) { + for _, c := range c.resolves(call) { callgraph.AddEdge(caller, call, g.CreateNode(c)) } } } -// callees computes the set of functions to which VTA resolves `c`. The resolved -// functions are intersected with functions to which `initial` resolves `c`. -func (c *constructor) callees(call ssa.CallInstruction) []*ssa.Function { +// resolves computes the set of functions to which VTA resolves `call`. The resolved +// functions are intersected with functions to which `c.callees` resolves `call`. +func (c *constructor) resolves(call ssa.CallInstruction) []*ssa.Function { cc := call.Common() if cc.StaticCallee() != nil { return []*ssa.Function{cc.StaticCallee()} @@ -123,7 +126,7 @@ func (c *constructor) callees(call ssa.CallInstruction) []*ssa.Function { // Cover the case of dynamic higher-order and interface calls. 
var res []*ssa.Function resolved := resolve(call, c.types, c.cache) - siteCallees(call, c.initial)(func(f *ssa.Function) bool { + siteCallees(call, c.callees)(func(f *ssa.Function) bool { if _, ok := resolved[f]; ok { res = append(res, f) } diff --git a/go/callgraph/vta/vta_test.go b/go/callgraph/vta/vta_test.go index 67db1302afd..a6f2dcde03e 100644 --- a/go/callgraph/vta/vta_test.go +++ b/go/callgraph/vta/vta_test.go @@ -19,6 +19,14 @@ import ( ) func TestVTACallGraph(t *testing.T) { + errDiff := func(want, got, missing []string) { + t.Errorf("got:\n%s\n\nwant:\n%s\n\nmissing:\n%s\n\ndiff:\n%s", + strings.Join(got, "\n"), + strings.Join(want, "\n"), + strings.Join(missing, "\n"), + cmp.Diff(got, want)) // to aid debugging + } + for _, file := range []string{ "testdata/src/callgraph_static.go", "testdata/src/callgraph_ho.go", @@ -46,14 +54,18 @@ func TestVTACallGraph(t *testing.T) { t.Fatalf("couldn't find want in `%s`", file) } - g := CallGraph(ssautil.AllFunctions(prog), cha.CallGraph(prog)) + // First test VTA with lazy-CHA initial call graph. + g := CallGraph(ssautil.AllFunctions(prog), nil) got := callGraphStr(g) if missing := setdiff(want, got); len(missing) > 0 { - t.Errorf("got:\n%s\n\nwant:\n%s\n\nmissing:\n%s\n\ndiff:\n%s", - strings.Join(got, "\n"), - strings.Join(want, "\n"), - strings.Join(missing, "\n"), - cmp.Diff(got, want)) // to aid debugging + errDiff(want, got, missing) + } + + // Repeat the test with explicit CHA initial call graph. 
+ g = CallGraph(ssautil.AllFunctions(prog), cha.CallGraph(prog)) + got = callGraphStr(g) + if missing := setdiff(want, got); len(missing) > 0 { + errDiff(want, got, missing) } }) } @@ -168,7 +180,7 @@ func TestVTACallGraphGo117(t *testing.T) { t.Fatalf("couldn't find want in `%s`", file) } - g, _ := typePropGraph(ssautil.AllFunctions(prog), cha.CallGraph(prog)) + g, _ := typePropGraph(ssautil.AllFunctions(prog), makeCalleesFunc(nil, cha.CallGraph(prog))) got := vtaGraphStr(g) if diff := setdiff(want, got); len(diff) != 0 { t.Errorf("`%s`: want superset of %v;\n got %v", file, want, got)