Skip to content

Commit

Permalink
Fix convergence problem due to unbounded growing of assertion tree (#244
Browse files Browse the repository at this point in the history
)

This PR fixes the convergence issues causes by unbounded growing of
assertion trees with assignments in a loop, such as in the simple
example shown below.
```
func test() {
         a := &A{}
	for {
		a = a.f
	}
}
```
Here, the assertion tree should look like `root -> varAssertionNode (a)
-> fldAssertionNode (f)`. However, the assertion before was growing
unboundedly (`root -> a -> f -> f -> f -> ...`) until the
`stableRoundLimit` was hit. This simple function was taking 7 iterations
to converge, while after the code change it now takes only 3 iterations
to converge.

Taking this opportunity, I have also added tests for testing fixpoint
convergence. For this, I refactored existing infrastructure for
anonymous functions and generalized it.

Note: the goal of this PR is to improve performance only, and should not
have any effect on the reported errors.
  • Loading branch information
sonalmahajan15 authored May 21, 2024
1 parent fe712a8 commit 482b433
Show file tree
Hide file tree
Showing 9 changed files with 220 additions and 74 deletions.
66 changes: 14 additions & 52 deletions assertion/anonymousfunc/analyzer_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,8 @@ import (
"github.com/stretchr/testify/require"
"go.uber.org/goleak"
"go.uber.org/nilaway/config"
"go.uber.org/nilaway/nilawaytest"
"go.uber.org/nilaway/util/analysishelper"
"golang.org/x/tools/go/analysis"
"golang.org/x/tools/go/analysis/analysistest"
)

Expand Down Expand Up @@ -61,10 +61,20 @@ func TestClosureCollection(t *testing.T) {
require.NotZero(t, len(funcLitMap))

// Get the expected closure vars from comments written for each function literal in the test file.
expectedClosure := findExpectedClosure(pass)
require.Equal(t, len(expectedClosure), len(funcLitMap))
// FindExpectedValues inspects test files and gathers comment strings at the same line of the
// *ast.FuncLit nodes, so that we know which *ast.FuncLit node corresponds to which anonymous
// function comment in the source.
expectedValues := nilawaytest.FindExpectedValues(pass, _wantClosurePrefix)
require.Equal(t, len(expectedValues), len(funcLitMap))

funcLitExpectedClosure := make(map[*ast.FuncLit][]string)
for node, closureVars := range expectedValues {
if funcLit, ok := node.(*ast.FuncLit); ok {
funcLitExpectedClosure[funcLit] = closureVars
}
}

for funcLit, expectedVars := range expectedClosure {
for funcLit, expectedVars := range funcLitExpectedClosure {
info, ok := funcLitMap[funcLit]
require.True(t, ok)

Expand Down Expand Up @@ -120,54 +130,6 @@ func TestClosureCollection(t *testing.T) {
}
}

// findExpectedClosure inspects the files and gather the comment strings at the same line of the
// *ast.FuncLit nodes, so that we know which *ast.FuncLit node corresponds to which anonymous
// function comment in the source.
func findExpectedClosure(pass *analysis.Pass) map[*ast.FuncLit][]string {
results := make(map[*ast.FuncLit][]string)

for _, file := range pass.Files {

// Store a mapping between single comment's line number to its text.
comments := make(map[int]string)
for _, group := range file.Comments {
if len(group.List) != 1 {
continue
}
comment := group.List[0]
comments[pass.Fset.Position(comment.Pos()).Line] = comment.Text
}

// Now, find all *ast.FuncLit nodes and find their comment.
ast.Inspect(file, func(node ast.Node) bool {
n, ok := node.(*ast.FuncLit)
if !ok {
return true
}
text, ok := comments[pass.Fset.Position(n.Pos()).Line]
if !ok {
// It is ok to not leave annotations for a func lit node - it simply does not use
// any closure variables. We still need to traverse further since there could be
// comments for nested func lit nodes.
results[n] = nil
return true
}

// Trim the trailing slashes and extra spaces and extract the set of expected values.
text = strings.TrimSpace(strings.TrimPrefix(text, "//"))
text = strings.TrimSpace(strings.TrimPrefix(text, _wantClosurePrefix))
// If no closure variables are written after _wantClosurePrefix, we simply ignore it.
results[n] = nil
if len(text) != 0 {
results[n] = strings.Split(text, " ")
}
return true
})
}

return results
}

func TestMain(m *testing.M) {
// Enable anonymous function flag for tests. It is OK to not unset this flag since Go builds
// tests for each package into separate binaries and execute them in parallel [1]. So the
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ func (a *A) foo() {
func noClosure() {
// For function literals that do not use closure variables, either write no comments or leave
// the list of closure variables empty after "expect_closure".
func() {
func() { // expect_closure:
func() { // expect_closure:
print("test")
}()
Expand Down
2 changes: 1 addition & 1 deletion assertion/function/analyzer.go
Original file line number Diff line number Diff line change
Expand Up @@ -445,7 +445,7 @@ func analyzeFunc(
}()

// Do the actual backpropagation.
funcTriggers, err := assertiontree.BackpropAcrossFunc(ctx, pass, funcDecl, funcContext, graph)
funcTriggers, _, _, err := assertiontree.BackpropAcrossFunc(ctx, pass, funcDecl, funcContext, graph)

// If any error occurs in back-propagating the function, we wrap the error with more information.
if err != nil {
Expand Down
63 changes: 63 additions & 0 deletions assertion/function/analyzer_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,10 @@ package function

import (
"context"
"fmt"
"go/ast"
"go/types"
"strconv"
"sync"
"testing"

Expand All @@ -27,13 +29,18 @@ import (
"go.uber.org/nilaway/assertion/anonymousfunc"
"go.uber.org/nilaway/assertion/function/assertiontree"
"go.uber.org/nilaway/assertion/function/functioncontracts"
"go.uber.org/nilaway/nilawaytest"
"go.uber.org/nilaway/util/analysishelper"
"golang.org/x/tools/go/analysis"
"golang.org/x/tools/go/analysis/analysistest"
"golang.org/x/tools/go/analysis/passes/ctrlflow"
"golang.org/x/tools/go/cfg"
)

// _wantFixpointPrefix is a prefix that we use in the test file to specify the expected fixpoint from BackpropAcrossFunc().
// Format: expect_fixpoint: <roundCount>,<stableRoundCount>,<number of triggers>
const _wantFixpointPrefix = "expect_fixpoint:"

func TestAnalyzer(t *testing.T) {
t.Parallel()

Expand Down Expand Up @@ -137,6 +144,62 @@ func TestAnalyzeFuncPanic(t *testing.T) {
require.ErrorContains(t, res.err, "panic")
}

func TestBackpropFixpointConvergence(t *testing.T) {
t.Parallel()

testdata := analysistest.TestData()

// First do an analysis test run just to get the pass variable.
r := analysistest.Run(t, testdata, Analyzer, "go.uber.org/backprop")
pass := r[0].Pass

// Gather function declaration nodes from test.
var funcs []*ast.FuncDecl
for _, file := range pass.Files {
for _, decl := range file.Decls {
if f, ok := decl.(*ast.FuncDecl); ok {
funcs = append(funcs, f)
}
}
}
require.NotZero(t, len(funcs), "Cannot find any function declaration in test code")

for _, funcDecl := range funcs {
// Prepare the input variables for passing to BackpropAcrossFunc():
funcConfig := assertiontree.FunctionConfig{
EnableStructInitCheck: true,
EnableAnonymousFunc: true,
}
emptyFuncLitMap := make(map[*ast.FuncLit]*anonymousfunc.FuncLitInfo)
emptyPkgFakeIdentMap := make(map[*ast.Ident]types.Object)
emptyFuncContracts := make(functioncontracts.Map)
funcContext := assertiontree.NewFunctionContext(pass, funcDecl, nil, /* funcLit */
funcConfig, emptyFuncLitMap, emptyPkgFakeIdentMap, emptyFuncContracts)
ctrlflowResult := pass.ResultOf[ctrlflow.Analyzer].(*ctrlflow.CFGs)

ctx, cancel := context.WithCancel(context.Background())
defer cancel()

// Run the backpropagation algorithm and collect the results.
funcTriggers, roundCount, stableRoundCount, err := assertiontree.BackpropAcrossFunc(ctx, pass, funcDecl, funcContext, ctrlflowResult.FuncDecl(funcDecl))
require.NoError(t, err, "Backpropagation algorithm should not return an error")

expectedValues := nilawaytest.FindExpectedValues(pass, _wantFixpointPrefix)
expectedVals, ok := expectedValues[funcDecl]
if !ok {
// No expected values written in the test file, so we skip the comparison.
continue
}

require.Equal(t, len(expectedVals), 3, "Expected fixpoint values must have 3 elements: roundCount, stableRoundCount, numTriggers")

// Compare the expected fixpoint values with the actual results.
actualVals := []string{strconv.Itoa(roundCount), strconv.Itoa(stableRoundCount), strconv.Itoa(len(funcTriggers))}
require.EqualValues(t, expectedVals, actualVals, fmt.Sprintf("Fixpoint values mismatch for round count, "+
"stable round count, or number of triggers for func `%s`", funcDecl.Name.Name))
}
}

func TestMain(m *testing.M) {
goleak.VerifyTestMain(m)
}
14 changes: 7 additions & 7 deletions assertion/function/assertiontree/backprop.go
Original file line number Diff line number Diff line change
Expand Up @@ -815,7 +815,7 @@ func computePostOrder(blocks []*cfg.Block) []int {
// with accompanying CFG, and back-propagates a tree of assertions across it to generate, at entry
// to the function, the set of assertions that must hold to avoid possible nil flow errors.
func BackpropAcrossFunc(ctx context.Context, pass *analysis.Pass, decl *ast.FuncDecl,
functionContext FunctionContext, graph *cfg.CFG) ([]annotation.FullTrigger, error) {
functionContext FunctionContext, graph *cfg.CFG) ([]annotation.FullTrigger, int, int, error) {
// We transform the CFG to have it reflect the implicit control flow that happens
// inside short-circuiting boolean expressions.
graph, richCheckBlocks, exprNonceMap := preprocess(graph, functionContext)
Expand Down Expand Up @@ -845,7 +845,7 @@ func BackpropAcrossFunc(ctx context.Context, pass *analysis.Pass, decl *ast.Func

select {
case <-ctx.Done():
return nil, fmt.Errorf("backprop early stop due to context: %w", ctx.Err())
return nil, roundCount, stableRoundCount, fmt.Errorf("backprop early stop due to context: %w", ctx.Err())
default:
}

Expand All @@ -858,7 +858,7 @@ func BackpropAcrossFunc(ctx context.Context, pass *analysis.Pass, decl *ast.Func
}

if len(block.Succs) > 2 {
return nil, errors.New("assumptions about CFG shape violated - a block has >2 successors")
return nil, roundCount, stableRoundCount, errors.New("assumptions about CFG shape violated - a block has >2 successors")
}

// No need to re-process the assertion node for the current block if it does not have
Expand Down Expand Up @@ -909,7 +909,7 @@ func BackpropAcrossFunc(ctx context.Context, pass *analysis.Pass, decl *ast.Func
// No assertion nodes attached with any successors, this should never happen since we
// will only reach here if any of the successors were updated in the current or last round.
if len(succs) == 0 {
return nil, fmt.Errorf("no assertion nodes for successors of block %d", block.Index)
return nil, roundCount, stableRoundCount, fmt.Errorf("no assertion nodes for successors of block %d", block.Index)
}

// Merge the branch successors if they are both available.
Expand All @@ -921,7 +921,7 @@ func BackpropAcrossFunc(ctx context.Context, pass *analysis.Pass, decl *ast.Func
nextAssertions[i] = succs[0]
err := backpropAcrossBlock(nextAssertions[i], blocks[i])
if err != nil {
return nil, err
return nil, roundCount, stableRoundCount, err
}

// Monotonize updates updatedThisRound to reflect whether the assertions changed at a given index.
Expand Down Expand Up @@ -982,7 +982,7 @@ func BackpropAcrossFunc(ctx context.Context, pass *analysis.Pass, decl *ast.Func

// Return the generated full triggers at the entry block; we're done!
if currRootAssertionNode == nil {
return nil, nil
return nil, roundCount, stableRoundCount, nil
}
return currRootAssertionNode.triggers, nil
return currRootAssertionNode.triggers, roundCount, stableRoundCount, nil
}
15 changes: 12 additions & 3 deletions assertion/function/assertiontree/root_assertion_node.go
Original file line number Diff line number Diff line change
Expand Up @@ -963,9 +963,18 @@ func (r *RootAssertionNode) LiftFromPath(path TrackableExpr) (AssertionNode, boo
func (r *RootAssertionNode) LandAtPath(path TrackableExpr, node AssertionNode) {
if path != nil {
newRoot := r.linkPath(path)
newNode := path[len(path)-1]
newNode.SetConsumeTriggers(node.ConsumeTriggers())
newNode.SetChildren(node.Children())
lastNode := path[len(path)-1]
lastNode.SetConsumeTriggers(node.ConsumeTriggers())

// To restrict the assertion tree from growing unboundedly, we add node.children to `newNode` iff
// they are not equal to `newNode` itself.
var childrenToAdd []AssertionNode
for _, child := range node.Children() {
if !r.eqNodes(child, lastNode) {
childrenToAdd = append(childrenToAdd, child)
}
}
lastNode.SetChildren(childrenToAdd)

r.mergeInto(r, newRoot)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,25 +12,34 @@
// See the License for the specific language governing permissions and
// limitations under the License.

package optimization
// This package tests fixpoint convergence of the backpropagation algorithm.
// The format for specifying expected values is as follows:
// expect_fixpoint: <roundCount> <stableRoundCount> <number of triggers>

// This is a simple test to check the effectiveness of the optimization added via the `struct field analyzer` that enables NilAway to
// only create triggers for those fields of the struct that are being actively assigned (implying a potential side effect) in the function.
// This approach creates fewer number of triggers allowing NilAway to converge quicker without losing precision.
package backprop

// Without `struct field analyzer`, m23() in this simple test creates 670 triggers and converges in 31 iterations
// With `struct field analyzer` for assigned fields only, m23() in this simple test creates 70 triggers and converges in 18 iterations (as of Aug 29, 2022)
// With `struct field analyzer` for assigned and accessed fields, m23() in this simple test creates 40 triggers and converges in 18 iterations (as of Aug 31, 2022)
func testSimple() { // expect_fixpoint: 2 1 1
var x *int
_ = *x
}

func testEmptyBody() { // expect_fixpoint: 2 1 0
}

// (NOTE: above numbers are subject to change as NilAway evolves)
func testPanic() { // expect_fixpoint: 1 1 0
panic("some error")
}

type A struct {
ptr *int
aptr *A
newPtr *A
}

func m23() {
// This is a simple test to check the effectiveness of the optimization added via the `struct field analyzer` that enables NilAway to
// only create triggers for those fields of the struct that are being actively assigned (implying a potential side effect) in the function.
// This approach creates fewer number of triggers allowing NilAway to converge quicker without losing precision.
func testStructFieldAnalyzerEffect() { // expect_fixpoint: 4 2 40
a := &A{}
for dummy() {
switch dummy() {
Expand Down Expand Up @@ -72,3 +81,33 @@ func (*A) f10() {}
func dummy() bool {
return true
}

// test assignment in infinite loop

func testInfiniteLoop() { // expect_fixpoint: 3 1 2
a := &A{}
for {
a = a.aptr
}
}

// test repeated map assignment in a finite loop

type mapType map[string]interface{}

func Get(m mapType, key string) interface{} {
return m[key]
}

func testAssignmentInLoop(m mapType, key string) { // expect_fixpoint: 6 2 4
var value interface{}
value = m
for len(key) > 0 {
switch v := value.(type) {
case mapType:
value = Get(v, key)
case map[string]interface{}:
value = v[key]
}
}
}
2 changes: 1 addition & 1 deletion nilaway_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ func TestStructInit(t *testing.T) { //nolint:paralleltest
}()

testdata := analysistest.TestData()
analysistest.Run(t, testdata, Analyzer, "go.uber.org/structinit/funcreturnfields", "go.uber.org/structinit/local", "go.uber.org/structinit/global", "go.uber.org/structinit/paramfield", "go.uber.org/structinit/paramsideeffect", "go.uber.org/structinit/defaultfield", "go.uber.org/structinit/optimization")
analysistest.Run(t, testdata, Analyzer, "go.uber.org/structinit/funcreturnfields", "go.uber.org/structinit/local", "go.uber.org/structinit/global", "go.uber.org/structinit/paramfield", "go.uber.org/structinit/paramsideeffect", "go.uber.org/structinit/defaultfield")
}

func TestAnonymousFunction(t *testing.T) { //nolint:paralleltest
Expand Down
Loading

0 comments on commit 482b433

Please sign in to comment.