diff --git a/accumulation/analyzer.go b/accumulation/analyzer.go index e112b04c..96237006 100644 --- a/accumulation/analyzer.go +++ b/accumulation/analyzer.go @@ -183,7 +183,16 @@ func checkErrors(triggers []annotation.FullTrigger, annMap annotation.Map, diagn }, ) + // Delete all "always safe" special handlers, since they are not meant to be tested for the no infer case + finalTriggers := make([]annotation.FullTrigger, 0, len(filteredTriggers)) for _, trigger := range filteredTriggers { + if c, ok := trigger.Consumer.Annotation.(*annotation.UseAsReturn); ok && c.IsTrackingAlwaysSafe { + continue + } + finalTriggers = append(finalTriggers, trigger) + } + + for _, trigger := range finalTriggers { // Skip checking any full triggers we created by duplicating from contracted functions // to the caller function. if !trigger.CreatedFromDuplication && trigger.Check(annMap) { diff --git a/annotation/consume_trigger.go b/annotation/consume_trigger.go index 9794c316..680ef897 100644 --- a/annotation/consume_trigger.go +++ b/annotation/consume_trigger.go @@ -1092,8 +1092,9 @@ func DuplicateReturnConsumer(t *ConsumeTrigger, location token.Position) *Consum // used for functions with contracts since we need to duplicate the sites for context sensitivity. type UseAsReturn struct { *TriggerIfNonNil - IsNamedReturn bool - RetStmt *ast.ReturnStmt + IsNamedReturn bool + IsTrackingAlwaysSafe bool + RetStmt *ast.ReturnStmt } // equals returns true if the passed ConsumingAnnotationTrigger is equal to this one @@ -1101,6 +1102,7 @@ func (u *UseAsReturn) equals(other ConsumingAnnotationTrigger) bool { if other, ok := other.(*UseAsReturn); ok { return u.TriggerIfNonNil.equals(other.TriggerIfNonNil) && u.IsNamedReturn == other.IsNamedReturn && + u.IsTrackingAlwaysSafe == other.IsTrackingAlwaysSafe && u.RetStmt == other.RetStmt } return false diff --git a/annotation/produce_trigger.go b/annotation/produce_trigger.go index 9d19e4b4..754e2425 100644 --- a/annotation/produce_trigger.go +++ b/annotation/produce_trigger.go @@ -755,12 +755,15 @@ func (f FldReturnPrestring) String() string { // context sensitivity. type FuncReturn struct { *TriggerIfNilable + + IsFromRichCheckEffectFunc bool } // equals returns true if the passed ProducingAnnotationTrigger is equal to this one func (f *FuncReturn) equals(other ProducingAnnotationTrigger) bool { if other, ok := other.(*FuncReturn); ok { - return f.TriggerIfNilable.equals(other.TriggerIfNilable) + return f.TriggerIfNilable.equals(other.TriggerIfNilable) && + f.IsFromRichCheckEffectFunc == other.IsFromRichCheckEffectFunc } return false } diff --git a/assertion/function/analyzer.go b/assertion/function/analyzer.go index 348b59ef..133035db 100644 --- a/assertion/function/analyzer.go +++ b/assertion/function/analyzer.go @@ -281,9 +281,9 @@ func duplicateFullTriggersFromContractedFunctionsToCallers( for ctrtFunc, calls := range callsByCtrtFunc { r := funcResults[ctrtFunc] if r == nil { - // should not happen since funcResults should contain all the functions including any - // contracted functions. - panic(fmt.Sprintf("Did not find the contracted function %s in funcResults", ctrtFunc.Id())) + // The contracted function is imported from upstream, and the local package analysis + // does not involve it. + continue } for _, trigger := range r.triggers { // If the full trigger has a FuncParam producer or a UseAsReturn consumer, then create @@ -312,9 +312,9 @@ func duplicateFullTriggersFromContractedFunctionsToCallers( for funcObj, triggers := range dupTriggers { r := funcResults[funcObj] if r == nil { - // should not happen since funcResults should contain all the functions including any - // contracted functions. - panic(fmt.Sprintf("Did not find the contracted function %s in funcResults", funcObj.Id())) + // Should not happen since we would not have created the duplicated triggers if the + // contracted function is not involved in the analysis of local package. + panic(fmt.Sprintf("did not find the contracted function %s in funcResults", funcObj.Id())) } funcTriggers[r.index] = append(funcTriggers[r.index], triggers...) } diff --git a/assertion/function/assertiontree/backprop.go b/assertion/function/assertiontree/backprop.go index a2918539..9460fe56 100644 --- a/assertion/function/assertiontree/backprop.go +++ b/assertion/function/assertiontree/backprop.go @@ -207,7 +207,7 @@ func backpropAcrossReturn(rootNode *RootAssertionNode, node *ast.ReturnStmt) err isErrReturning := util.FuncIsErrReturning(funcObj) isOkReturning := util.FuncIsOkReturning(funcObj) - rootNode.AddNewTriggers(annotation.FullTrigger{ + trigger := annotation.FullTrigger{ Producer: &annotation.ProduceTrigger{ // since the value is being returned directly, only its shallow nilability // matters (but deep would matter if we were enforcing correct variance) @@ -230,7 +230,29 @@ func backpropAcrossReturn(rootNode *RootAssertionNode, node *ast.ReturnStmt) err // interpreted as guarded GuardMatched: isErrReturning || isOkReturning, }, - }) + } + + // This is a duplicate trigger for tracking "always safe" paths. The analysis of these triggers + // will be processed at the inference stage. + triggerAlwaysSafe := annotation.FullTrigger{ + Producer: trigger.Producer, + Consumer: &annotation.ConsumeTrigger{ + Annotation: &annotation.UseAsReturn{ + TriggerIfNonNil: &annotation.TriggerIfNonNil{ + Ann: annotation.RetKeyFromRetNum( + rootNode.ObjectOf(rootNode.FuncNameIdent()).(*types.Func), + i, + )}, + RetStmt: node, + IsTrackingAlwaysSafe: true, + }, + Expr: trigger.Consumer.Expr, + Guards: trigger.Consumer.Guards, + GuardMatched: trigger.Consumer.GuardMatched, + }, + } + + rootNode.AddNewTriggers(trigger, triggerAlwaysSafe) } } rootNode.AddComputation(call) diff --git a/assertion/function/assertiontree/backprop_util.go b/assertion/function/assertiontree/backprop_util.go index b3b7617c..01afa882 100644 --- a/assertion/function/assertiontree/backprop_util.go +++ b/assertion/function/assertiontree/backprop_util.go @@ -268,6 +268,9 @@ func handleErrorReturns(rootNode *RootAssertionNode, retStmt *ast.ReturnStmt, re errRetExpr := results[errRetIndex] // n-th expression nonErrRetExpr := results[:errRetIndex] // n-1 expressions + // default tracking to support potential "always safe" cases + createReturnConsumersForAlwaysSafe(rootNode, nonErrRetExpr, retStmt, isNamedReturn) + // check if the error return is at all guarding any nilable returns, such as pointers, maps, and slices if isErrorReturnNil(rootNode, errRetExpr) { // if error is the only return expression in the statement, then create a consumer for it, else create consumers for the non-error return expressions @@ -329,6 +332,9 @@ func handleBooleanReturns(rootNode *RootAssertionNode, retStmt *ast.ReturnStmt, return false } + // default tracking to support potential "always safe" cases + createReturnConsumersForAlwaysSafe(rootNode, nMinusOneRetExpr, retStmt, isNamedReturn) + // If return is "true", then track its n-1 returns. Create return consume triggers for all n-1 return expressions. // If return is "false", then do nothing, since we don't track boolean values. if val { @@ -371,6 +377,32 @@ func createGeneralReturnConsumers(rootNode *RootAssertionNode, results []ast.Exp } } +// createReturnConsumersForAlwaysSafe creates return consumers for the non-return expressions in the return statement +// for tracking potential "always safe" cases +func createReturnConsumersForAlwaysSafe(rootNode *RootAssertionNode, nonErrResults []ast.Expr, retStmt *ast.ReturnStmt, isNamedReturn bool) { + for i := range nonErrResults { + // don't do anything if the expression is a blank identifier ("_") + if util.IsEmptyExpr(nonErrResults[i]) { + continue + } + + rootNode.AddConsumption(&annotation.ConsumeTrigger{ + Annotation: &annotation.UseAsReturn{ + TriggerIfNonNil: &annotation.TriggerIfNonNil{ + Ann: &annotation.RetAnnotationKey{ + FuncDecl: rootNode.FuncObj(), + RetNum: i, + }, + }, + IsNamedReturn: isNamedReturn, + IsTrackingAlwaysSafe: true, + RetStmt: retStmt}, + Expr: nonErrResults[i], + Guards: util.NoGuards(), + }) + } +} + // createSpecialConsumersForAllReturns conservatively creates specially designed consumers for all return expressions, error and non-error func createSpecialConsumersForAllReturns(rootNode *RootAssertionNode, nonErrRetExpr []ast.Expr, errRetExpr ast.Expr, errRetIndex int, retStmt *ast.ReturnStmt, isNamedReturn bool) { for i := range nonErrRetExpr { diff --git a/assertion/function/assertiontree/parse_expr_producer.go b/assertion/function/assertiontree/parse_expr_producer.go index 7f76ff24..c55cabf9 100644 --- a/assertion/function/assertiontree/parse_expr_producer.go +++ b/assertion/function/assertiontree/parse_expr_producer.go @@ -481,6 +481,7 @@ func (r *RootAssertionNode) getFuncReturnProducers(ident *ast.Ident, expr *ast.C // such as "error-nonnil" or "always-nonnil" NeedsGuard: (isErrReturning || isOkReturning) && i != numResults-1, }, + IsFromRichCheckEffectFunc: isErrReturning || isOkReturning, }, Expr: expr, }, diff --git a/assertion/function/assertiontree/root_assertion_node.go b/assertion/function/assertiontree/root_assertion_node.go index ea10f826..3fb61423 100644 --- a/assertion/function/assertiontree/root_assertion_node.go +++ b/assertion/function/assertiontree/root_assertion_node.go @@ -258,7 +258,18 @@ func (r *RootAssertionNode) AddConsumption(consumer *annotation.ConsumeTrigger) path, producers := r.ParseExprAsProducer(consumer.Expr, false) if path == nil { // expr is not trackable if producers == nil { - return // expr is not trackable, but cannot be nil, so do nothing + // Here we can infer that the expression is non-nil by definition. Instead of ignoring creation of a trigger, + // particularly for always safe tracking, we create a trigger with ProduceTriggerNever. + if c, ok := consumer.Annotation.(*annotation.UseAsReturn); ok && c.IsTrackingAlwaysSafe { + r.AddNewTriggers(annotation.FullTrigger{ + Producer: &annotation.ProduceTrigger{ + Annotation: &annotation.ProduceTriggerNever{}, + Expr: consumer.Expr, + }, + Consumer: consumer, + }) + } + return } if len(producers) != 1 { panic("multiply-returning function call was passed to AddConsumption") diff --git a/assertion/function/functioncontracts/analyzer.go b/assertion/function/functioncontracts/analyzer.go index 1aa1e067..f6a7c6db 100644 --- a/assertion/function/functioncontracts/analyzer.go +++ b/assertion/function/functioncontracts/analyzer.go @@ -43,27 +43,78 @@ var Analyzer = &analysis.Analyzer{ Doc: _doc, Run: analysishelper.WrapRun(run), ResultType: reflect.TypeOf((*analysishelper.Result[Map])(nil)), + FactTypes: []analysis.Fact{new(Contracts)}, Requires: []*analysis.Analyzer{config.Analyzer, buildssa.Analyzer}, } +// Contracts represents the list of contracts for a function. +type Contracts []Contract + +// AFact enables use of the facts passing mechanism in Go's analysis framework. +func (*Contracts) AFact() {} + +// Map stores the mappings from *types.Func to associated function contracts. +type Map map[*types.Func]Contracts + func run(pass *analysis.Pass) (Map, error) { conf := pass.ResultOf[config.Analyzer].(*config.Config) - if !conf.IsPkgInScope(pass.Pkg) { - return Map{}, nil + return make(Map), nil } + // Collect contracts from the current package. contracts, err := collectFunctionContracts(pass) if err != nil { return nil, err } + + // The fact mechanism only allows exporting pointer types. However, internally we are using + // `Contract` as a value type because it is an underlying slice type (such that making it a + // pointer type will make the rest of the logic more complicated). Therefore, we strictly + // only convert it from/to a pointer type _here_ during the fact import/exports. Everywhere + // else in NilAway (this sub-analyzer, as well as the other analyzers) we treat `Contract` + // simply as a value type. + + // Import contracts from upstream packages and merge it with the local contract map. + for _, fact := range pass.AllObjectFacts() { + fn, ok := fact.Object.(*types.Func) + if !ok { + continue + } + ctrts, ok := fact.Fact.(*Contracts) + if !ok || ctrts == nil { + continue + } + // The existing contracts are imported from upstream packages about upstream functions, + // therefore there should not be any conflicts with contracts collected from the current package. + if _, ok := contracts[fn]; ok { + return nil, fmt.Errorf("function %s has multiple contracts", fn.Name()) + } + contracts[fn] = *ctrts + } + + // Now, export the contracts for the _exported_ functions in the current package only. + for fn, ctrts := range contracts { + // Check if the function is (1) exported by name (i.e., starts with a capital letter), (2) + // it is directly inside the package scope (such that it is really visible in downstream + // packages). + if fn.Exported() && + // fn.Scope() -> the scope of the function body. + fn.Scope() != nil && + // fn.Scope().Parent() -> the scope of the file. + fn.Scope().Parent() != nil && + // fn.Scope().Parent().Parent() -> the scope of the package. + fn.Scope().Parent().Parent() == pass.Pkg.Scope() { + pass.ExportObjectFact(fn, &ctrts) + } + } return contracts, nil } // functionResult is the struct that is received from the channel for each function. type functionResult struct { funcObj *types.Func - contracts []*FunctionContract + contracts Contracts err error } @@ -72,8 +123,9 @@ type functionResult struct { // the comments at the top of each function. Only when there are no handwritten contracts there, // do we try to automatically infer contracts. func collectFunctionContracts(pass *analysis.Pass) (Map, error) { - // Collect ssa for every function. conf := pass.ResultOf[config.Analyzer].(*config.Config) + + // Collect ssa for every function. ssaInput := pass.ResultOf[buildssa.Analyzer].(*buildssa.SSA) ssaOfFunc := make(map[*types.Func]*ssa.Function, len(ssaInput.SrcFuncs)) for _, fnssa := range ssaInput.SrcFuncs { @@ -155,7 +207,7 @@ func collectFunctionContracts(pass *analysis.Pass) (Map, error) { defer func() { if r := recover(); r != nil { e := fmt.Errorf("INTERNAL PANIC: %s\n%s", r, string(debug.Stack())) - funcChan <- functionResult{err: e, funcObj: funcObj, contracts: []*FunctionContract{}} + funcChan <- functionResult{err: e, funcObj: funcObj} } }() diff --git a/assertion/function/functioncontracts/analyzer_test.go b/assertion/function/functioncontracts/analyzer_test.go index 3e97c87e..c9b67976 100644 --- a/assertion/function/functioncontracts/analyzer_test.go +++ b/assertion/function/functioncontracts/analyzer_test.go @@ -17,6 +17,7 @@ package functioncontracts import ( "fmt" "go/types" + "strings" "testing" "github.com/google/go-cmp/cmp" @@ -42,7 +43,7 @@ func TestContractCollection(t *testing.T) { testdata := analysistest.TestData() - r := analysistest.Run(t, testdata, Analyzer, "go.uber.org/functioncontracts/parse") + r := analysistest.Run(t, testdata, Analyzer, "go.uber.org/parse") require.Equal(t, 1, len(r)) require.NotNil(t, r[0]) @@ -53,31 +54,31 @@ func TestContractCollection(t *testing.T) { require.NotNil(t, funcContractsMap) - actualNameToContracts := map[*types.Func][]*FunctionContract{} + actual := make(Map) for funcObj, contracts := range funcContractsMap { - actualNameToContracts[funcObj] = contracts + actual[funcObj] = contracts } - expectedNameToContracts := map[*types.Func][]*FunctionContract{ + expected := Map{ getFuncObj(pass, "f1"): { - &FunctionContract{Ins: []ContractVal{NonNil}, Outs: []ContractVal{NonNil}}, + Contract{Ins: []ContractVal{NonNil}, Outs: []ContractVal{NonNil}}, }, getFuncObj(pass, "f2"): { - &FunctionContract{Ins: []ContractVal{NonNil}, Outs: []ContractVal{True}}, + Contract{Ins: []ContractVal{NonNil}, Outs: []ContractVal{True}}, }, getFuncObj(pass, "f3"): { - &FunctionContract{Ins: []ContractVal{NonNil}, Outs: []ContractVal{False}}, + Contract{Ins: []ContractVal{NonNil}, Outs: []ContractVal{False}}, }, getFuncObj(pass, "multipleValues"): { - &FunctionContract{Ins: []ContractVal{Any, NonNil}, Outs: []ContractVal{NonNil, True}}, + Contract{Ins: []ContractVal{Any, NonNil}, Outs: []ContractVal{NonNil, True}}, }, getFuncObj(pass, "multipleContracts"): { - &FunctionContract{Ins: []ContractVal{Any, NonNil}, Outs: []ContractVal{NonNil, True}}, - &FunctionContract{Ins: []ContractVal{NonNil, Any}, Outs: []ContractVal{NonNil, True}}, + Contract{Ins: []ContractVal{Any, NonNil}, Outs: []ContractVal{NonNil, True}}, + Contract{Ins: []ContractVal{NonNil, Any}, Outs: []ContractVal{NonNil, True}}, }, // function contractCommentInOtherLine should not exist in the map as it has no contract. } - if diff := cmp.Diff(expectedNameToContracts, actualNameToContracts); diff != "" { + if diff := cmp.Diff(expected, actual); diff != "" { require.Fail(t, fmt.Sprintf("parsed contracts mismatch (-want +got):\n%s", diff)) } } @@ -85,7 +86,7 @@ func TestInfer(t *testing.T) { t.Parallel() testdata := analysistest.TestData() - r := analysistest.Run(t, testdata, Analyzer, "go.uber.org/functioncontracts/infer") + r := analysistest.Run(t, testdata, Analyzer, "go.uber.org/infer") require.Equal(t, 1, len(r)) require.NotNil(t, r[0]) @@ -97,49 +98,101 @@ func TestInfer(t *testing.T) { require.NotNil(t, funcContractsMap) - actualNameToContracts := map[*types.Func][]*FunctionContract{} + actual := make(Map) for funcObj, contracts := range funcContractsMap { - actualNameToContracts[funcObj] = contracts + actual[funcObj] = contracts } - expectedNameToContracts := map[*types.Func][]*FunctionContract{ + expected := Map{ getFuncObj(pass, "onlyLocalVar"): { - &FunctionContract{Ins: []ContractVal{NonNil}, Outs: []ContractVal{NonNil}}, + Contract{Ins: []ContractVal{NonNil}, Outs: []ContractVal{NonNil}}, }, getFuncObj(pass, "unknownCondition"): { - &FunctionContract{Ins: []ContractVal{NonNil}, Outs: []ContractVal{NonNil}}, + Contract{Ins: []ContractVal{NonNil}, Outs: []ContractVal{NonNil}}, }, getFuncObj(pass, "noLocalVar"): { - &FunctionContract{Ins: []ContractVal{NonNil}, Outs: []ContractVal{NonNil}}, + Contract{Ins: []ContractVal{NonNil}, Outs: []ContractVal{NonNil}}, }, getFuncObj(pass, "learnUnderlyingFromOuterMakeInterface"): { - &FunctionContract{Ins: []ContractVal{NonNil}, Outs: []ContractVal{NonNil}}, + Contract{Ins: []ContractVal{NonNil}, Outs: []ContractVal{NonNil}}, }, getFuncObj(pass, "twoCondsMerge"): { - &FunctionContract{Ins: []ContractVal{NonNil}, Outs: []ContractVal{NonNil}}, + Contract{Ins: []ContractVal{NonNil}, Outs: []ContractVal{NonNil}}, }, getFuncObj(pass, "unknownToUnknownButSameValue"): { - &FunctionContract{Ins: []ContractVal{NonNil}, Outs: []ContractVal{NonNil}}, + Contract{Ins: []ContractVal{NonNil}, Outs: []ContractVal{NonNil}}, }, // other functions should not exist in the map as the contract nonnil->nonnil does not hold // for them. // TODO: uncomment this when we support field access when inferring contracts. // getFuncObj(pass, "field"): { - // &FunctionContract{Ins: []ContractVal{NonNil}, Outs: []ContractVal{NonNil}}, + // Contract{Ins: []ContractVal{NonNil}, Outs: []ContractVal{NonNil}}, // }, // TODO: uncomment this when we support nonempty slice to nonnil. // getFuncObj(pass, "nonEmptySliceToNonnil"): { - // &FunctionContract{Ins: []ContractVal{NonNil}, Outs: []ContractVal{NonNil}}, + // Contract{Ins: []ContractVal{NonNil}, Outs: []ContractVal{NonNil}}, // }, } - if diff := cmp.Diff(expectedNameToContracts, actualNameToContracts); diff != "" { + if diff := cmp.Diff(expected, actual); diff != "" { + require.Fail(t, fmt.Sprintf("inferred contracts mismatch (-want +got):\n%s", diff)) + } +} + +func TestFactExport(t *testing.T) { + t.Parallel() + + testdata := analysistest.TestData() + // The exported facts are asserted in the testdata file themselves in "want" strings. + analysistest.Run(t, testdata, Analyzer, "go.uber.org/factexport/upstream") +} + +func TestFactImport(t *testing.T) { + t.Parallel() + + // Now we test the import of the contract facts. The downstream package has a dependency on + // the upstream package (which contains several contracted functions). It should be able to + // import those facts, combine them with its own contracts, and return the combined map. + + testdata := analysistest.TestData() + r := analysistest.Run(t, testdata, Analyzer, "go.uber.org/factexport/downstream") + require.Len(t, r, 1) + pass, result := r[0].Pass, r[0].Result + require.IsType(t, &analysishelper.Result[Map]{}, result) + require.NoError(t, result.(*analysishelper.Result[Map]).Err) + actual := result.(*analysishelper.Result[Map]).Res + + expected := Map{ + getFuncObj(pass, "localManual"): { + Contract{Ins: []ContractVal{NonNil}, Outs: []ContractVal{NonNil}}, + }, + getFuncObj(pass, "upstream.ExportedManual"): { + Contract{Ins: []ContractVal{NonNil}, Outs: []ContractVal{NonNil}}, + }, + getFuncObj(pass, "upstream.ExportedInferred"): { + Contract{Ins: []ContractVal{NonNil}, Outs: []ContractVal{NonNil}}, + }, + } + if diff := cmp.Diff(expected, actual); diff != "" { require.Fail(t, fmt.Sprintf("inferred contracts mismatch (-want +got):\n%s", diff)) } } func getFuncObj(pass *analysis.Pass, name string) *types.Func { - return pass.Pkg.Scope().Lookup(name).(*types.Func) + parts := strings.Split(name, ".") + if len(parts) == 1 { + return pass.Pkg.Scope().Lookup(parts[0]).(*types.Func) + } + if len(parts) > 2 { + panic(fmt.Sprintf("invalid function name to look up, expected name or pkg.name, got %q", name)) + } + for _, imported := range pass.Pkg.Imports() { + if imported.Name() == parts[0] { + return imported.Scope().Lookup(parts[1]).(*types.Func) + } + } + + panic(fmt.Sprintf("cannot find function %q", name)) } func TestMain(m *testing.M) { diff --git a/assertion/function/functioncontracts/function_contracts_map.go b/assertion/function/functioncontracts/contract.go similarity index 78% rename from assertion/function/functioncontracts/function_contracts_map.go rename to assertion/function/functioncontracts/contract.go index 6445ab93..f381228a 100644 --- a/assertion/function/functioncontracts/function_contracts_map.go +++ b/assertion/function/functioncontracts/contract.go @@ -14,10 +14,6 @@ package functioncontracts -import ( - "go/types" -) - // ContractVal represents the possible value appearing in a function contract. type ContractVal string @@ -32,8 +28,8 @@ const ( Any ContractVal = "_" ) -// stringToContractVal converts a keyword string into the corresponding function ContractVal. -func stringToContractVal(keyword string) ContractVal { +// newContractVal converts a keyword string into the corresponding function ContractVal. +func newContractVal(keyword string) ContractVal { switch keyword { case "nonnil": return NonNil @@ -51,11 +47,10 @@ func stringToContractVal(keyword string) ContractVal { } } -// FunctionContract represents a function contract. -type FunctionContract struct { - Ins []ContractVal +// Contract represents a function contract. +type Contract struct { + // Ins is the list of input contract values, where the index is the index of the parameter. + Ins []ContractVal + // Outs is the list of output contract values, where the index is the index of the return. Outs []ContractVal } - -// Map stores the mappings from *types.Func to associated function contracts. -type Map map[*types.Func][]*FunctionContract diff --git a/assertion/function/functioncontracts/infer.go b/assertion/function/functioncontracts/infer.go index c84e139a..036f6f59 100644 --- a/assertion/function/functioncontracts/infer.go +++ b/assertion/function/functioncontracts/infer.go @@ -30,7 +30,7 @@ const _maxNumTablesPerBlock = 1024 // inferContracts infers function contracts for a function if it has no contracts written. It // returns a list of inferred contracts, which may be empty if no contract is inferred but is never // nil. -func inferContracts(fn *ssa.Function) []*FunctionContract { +func inferContracts(fn *ssa.Function) Contracts { nilnessTableSetByBB := make(map[*ssa.BasicBlock]nilnessTableSet) retInstrs := getReturnInstrs(fn) // TODO: Consider *ssa.Panic // No need of an expensive dataflow analysis if we can derive contracts from the return @@ -140,7 +140,7 @@ func inferContracts(fn *ssa.Function) []*FunctionContract { // TODO: nicely handle exponential explosion of tables. if len(nilnessTableSetByBB[b]) >= _maxNumTablesPerBlock { // Too many tables, we should give up inferring contracts for this function. - return []*FunctionContract{} + return nil } // Add successors to queue since the nilness table set of this block has been updated. @@ -204,7 +204,8 @@ func learnNilness(succ *ssa.BasicBlock, pred *ssa.BasicBlock, table nilnessTable func deriveContracts( retInstrs []*ssa.Return, fn *ssa.Function, - nilnessTableSetByBB map[*ssa.BasicBlock]nilnessTableSet) []*FunctionContract { + nilnessTableSetByBB map[*ssa.BasicBlock]nilnessTableSet, +) Contracts { // TODO: verify other or multiple param/return contracts in the future; for now we consider // contract(nonnil->nonnil) only. param := fn.Params[0] @@ -258,7 +259,7 @@ func deriveContracts( continue } // All the remaining cases are counterexamples to contract(nonnil->nonnil) - return []*FunctionContract{} + return nil } } @@ -274,14 +275,14 @@ func deriveContracts( // infer nonnil->nonnil. if (nilParamChoices == totalChoices && nonnilOrUnknownParamChoices == 0) || nonnilRetChoices == totalChoices { - return []*FunctionContract{} + return nil } // totalChoices > nilParamChoices >= 0 && totalChoices >= nonnilOrUnknownParamChoices > 0 && // nonnilRetChoices < totalChoices // nonnil->nonnil is valid at all exit blocks - return []*FunctionContract{ + return Contracts{ {Ins: []ContractVal{NonNil}, Outs: []ContractVal{NonNil}}, } } diff --git a/assertion/function/functioncontracts/parse.go b/assertion/function/functioncontracts/parse.go index 5de07995..d4fd3814 100644 --- a/assertion/function/functioncontracts/parse.go +++ b/assertion/function/functioncontracts/parse.go @@ -37,26 +37,22 @@ var _contractRE = regexp.MustCompile( // parseContracts parses a slice of function contracts from a singe comment group. If no contract // is found from the comment group, an empty slice is returned. -func parseContracts(doc *ast.CommentGroup) []*FunctionContract { - contracts := make([]*FunctionContract, 0) +func parseContracts(doc *ast.CommentGroup) Contracts { if doc == nil { - return contracts + return nil } + + var contracts Contracts for _, lineComment := range doc.List { - res := _contractRE.FindAllStringSubmatch(lineComment.Text, -1) - if res == nil { - continue - } - for _, matching := range res { + for _, matching := range _contractRE.FindAllStringSubmatch(lineComment.Text, -1) { // matching is a slice of three elements; the first is the whole matched string and the // next two are the captured groups of contract values before and after `->`. ins := parseListOfContractValues(matching[1]) outs := parseListOfContractValues(matching[2]) - ctrt := &FunctionContract{ + contracts = append(contracts, Contract{ Ins: ins, Outs: outs, - } - contracts = append(contracts, ctrt) + }) } } return contracts @@ -68,7 +64,7 @@ func parseListOfContractValues(wholeStr string) []ContractVal { valKeywords := strings.Split(wholeStr, _sep) contractVals := make([]ContractVal, len(valKeywords)) for i, v := range valKeywords { - contractVals[i] = stringToContractVal(strings.TrimSpace(v)) + contractVals[i] = newContractVal(strings.TrimSpace(v)) } return contractVals } diff --git a/assertion/function/functioncontracts/testdata/src/go.uber.org/factexport/downstream/downstream.go b/assertion/function/functioncontracts/testdata/src/go.uber.org/factexport/downstream/downstream.go new file mode 100644 index 00000000..100f839b --- /dev/null +++ b/assertion/function/functioncontracts/testdata/src/go.uber.org/factexport/downstream/downstream.go @@ -0,0 +1,37 @@ +// Copyright (c) 2024 Uber Technologies, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package downstream + +import ( + "go.uber.org/factexport/upstream" +) + +func foo() { + // The contract sub-analyzer does not really report potential nil panics, the following + // calls are just to ensure we add the upstream dependency and the sub-analyzer is able to + // import facts about it. + upstream.ExportedManual(nil) + upstream.ExportedInferred(nil) +} + +// This is a local function that has a contract that should be combined with the imported facts. +// contract(nonnil -> nonnil) +func localManual(p *int) *int { + if p != nil { + a := 1 + return &a + } + return nil +} diff --git a/assertion/function/functioncontracts/testdata/src/go.uber.org/factexport/upstream/upstream.go b/assertion/function/functioncontracts/testdata/src/go.uber.org/factexport/upstream/upstream.go new file mode 100644 index 00000000..4d3dd4b9 --- /dev/null +++ b/assertion/function/functioncontracts/testdata/src/go.uber.org/factexport/upstream/upstream.go @@ -0,0 +1,51 @@ +// Copyright (c) 2024 Uber Technologies, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package upstream + +// This tests the export of contracts from the upstream package. + +//contract(nonnil -> nonnil) +func ExportedManual(p *int) *int { //want ExportedManual:"&\\[{\\[nonnil\\] \\[nonnil\\]}\\]" + if p != nil { + a := 1 + return &a + } + return nil +} + +func ExportedInferred(p *int) *int { //want ExportedInferred:"&\\[{\\[nonnil\\] \\[nonnil\\]}\\]" + if p != nil { + a := 1 + return &a + } + return nil +} + +//contract(nonnil -> nonnil) +func unexportedManual(p *int) *int { // Notice here we do not want to export the contracts for it. + if p != nil { + a := 1 + return &a + } + return nil +} + +func unexportedInferred(p *int) *int { // Notice here we do not want to export the contracts for it. + if p != nil { + a := 1 + return &a + } + return nil +} diff --git a/assertion/function/functioncontracts/testdata/src/go.uber.org/functioncontracts/infer/external.go b/assertion/function/functioncontracts/testdata/src/go.uber.org/infer/external.go similarity index 100% rename from assertion/function/functioncontracts/testdata/src/go.uber.org/functioncontracts/infer/external.go rename to assertion/function/functioncontracts/testdata/src/go.uber.org/infer/external.go diff --git a/assertion/function/functioncontracts/testdata/src/go.uber.org/functioncontracts/infer/main.go b/assertion/function/functioncontracts/testdata/src/go.uber.org/infer/main.go similarity index 100% rename from assertion/function/functioncontracts/testdata/src/go.uber.org/functioncontracts/infer/main.go rename to assertion/function/functioncontracts/testdata/src/go.uber.org/infer/main.go diff --git a/assertion/function/functioncontracts/testdata/src/go.uber.org/functioncontracts/parse/main.go b/assertion/function/functioncontracts/testdata/src/go.uber.org/parse/main.go similarity index 100% rename from assertion/function/functioncontracts/testdata/src/go.uber.org/functioncontracts/parse/main.go rename to assertion/function/functioncontracts/testdata/src/go.uber.org/parse/main.go diff --git a/inference/engine.go b/inference/engine.go index 01e7d753..31940506 100644 --- a/inference/engine.go +++ b/inference/engine.go @@ -145,6 +145,32 @@ func (e *Engine) ObserveAnnotations(pkgAnnotations *annotation.ObservedMap, mode }, mode != NoInfer) } +// mapGuardMissingAndReturnToFuncSite returns two maps: +// 1. A map with key being the function return site and value being the list of indices of guard-missing triggers matching the site. +// 2. A map with key being the function return site and value being the list of indices of return triggers matching the site. +func (e *Engine) mapGuardMissingAndReturnToFuncSite(triggers []annotation.FullTrigger) (map[primitiveSite][]int, map[primitiveSite][]int) { + mapSiteGuardMissing := make(map[primitiveSite][]int) + mapSiteReturn := make(map[primitiveSite][]int) + + for i, trigger := range triggers { + if p, ok := trigger.Producer.Annotation.(*annotation.GuardMissing); ok { + if o, ok := p.OldAnnotation.(*annotation.FuncReturn); ok && o.IsFromRichCheckEffectFunc { + site := e.primitive.site(o.UnderlyingSite(), p.Kind() == annotation.DeepConditional) + mapSiteGuardMissing[site] = append(mapSiteGuardMissing[site], i) + } + } + } + + for i, trigger := range triggers { + if c, ok := trigger.Consumer.Annotation.(*annotation.UseAsReturn); ok && c.IsTrackingAlwaysSafe { + site := e.primitive.site(c.UnderlyingSite(), c.Kind() == annotation.DeepConditional) + mapSiteReturn[site] = append(mapSiteReturn[site], i) + } + } + + return mapSiteGuardMissing, mapSiteReturn +} + // ObservePackage observes all the annotations and assertions computed locally about the current // package. The assertions are sorted based on whether they are already known to trigger without // reliance on annotation sites, such as `x` in `x = nil; x.f`, which will generate @@ -154,18 +180,61 @@ func (e *Engine) ObserveAnnotations(pkgAnnotations *annotation.ObservedMap, mode // observeImplication. Before all assertions are sorted and handled thus, the annotations read for // the package are iterated over and observed via calls to observeSiteExplanation as a BecauseAnnotation. func (e *Engine) ObservePackage(pkgFullTriggers []annotation.FullTrigger) { + // As Step 1, we do a pre-analysis of "guard missing" triggers to verify if their dereferences are always nil-safe, + // and hence can be deleted to not report a false positive error. Specifically, this analyis of "always safe" paths + // is focussed on the rich check effect functions, namely error returning functions and ok-returning functions. + // The process is to find all guard missing triggers reaching a function return site, and then check if all the return triggers + // to that function site are non-nil. If so, we can safely delete all the guard-missing triggers for this function site. + triggersToBeDeleted := make(map[int]bool) + mapSiteGuardMissing, mapSiteReturn := e.mapGuardMissingAndReturnToFuncSite(pkgFullTriggers) + for site, guardMissingIndices := range mapSiteGuardMissing { + if returnIndices, ok := mapSiteReturn[site]; ok { + // Check if all the return triggers to this function site are non-nil. + nonnilCnt := 0 + for _, index := range returnIndices { + returnTrigger := pkgFullTriggers[index] + if returnTrigger.Producer.Annotation.Kind() != annotation.Never { + // break early if we find a potentially nilable trigger + break + } + nonnilCnt++ + } + + if nonnilCnt == len(returnIndices) { + // If all return triggers are non-nil, then we can safely delete all the guard-missing triggers + // for this function site. + for _, index := range guardMissingIndices { + triggersToBeDeleted[index] = true + } + } + } + } + // Add all placeholder UseAsReturnForAlwaysSafePath triggers to triggersToBeDeleted + for _, indices := range mapSiteReturn { + for _, index := range indices { + triggersToBeDeleted[index] = true + } + } + + // Filter out the triggers that are to be deleted. + pkgFullTriggers = slices.DeleteFunc(pkgFullTriggers, func(t annotation.FullTrigger) bool { + index := slices.Index(pkgFullTriggers, t) + return triggersToBeDeleted[index] + }) + // Separate out triggers with UseAsNonErrorRetDependentOnErrorRetNilability consumer from other triggers. // This is needed since whether UseAsNonErrorRetDependentOnErrorRetNilability triggers should be fired // is dependent on their corresponding UseAsErrorRetWithNilabilityUnknown triggers. By this separation, // we can process all other triggers, including UseAsErrorRetWithNilabilityUnknown, first, and once // their nilability status is known, then filter out the unnecessary UseAsNonErrorRetDependentOnErrorRetNilability // triggers, and run the pkg inference process again only for the remainder triggers. - // Steps 1--3 below depict this approach in more detail. + // Steps 2--4 below depict this approach in more detail. var ( nonErrRetTriggers []annotation.FullTrigger // In most cases all triggers will be stored in otherTriggers, so we set a proper capacity. otherTriggers = make([]annotation.FullTrigger, 0, len(pkgFullTriggers)) ) + for _, t := range pkgFullTriggers { if _, ok := t.Consumer.Annotation.(*annotation.UseAsNonErrorRetDependentOnErrorRetNilability); ok { nonErrRetTriggers = append(nonErrRetTriggers, t) @@ -174,10 +243,10 @@ func (e *Engine) ObservePackage(pkgFullTriggers []annotation.FullTrigger) { } } - // Step 1: build the inference map based on `otherTriggers` and incorporate those assertions into the `inferredAnnotationMap` + // Step 2: build the inference map based on `otherTriggers` and incorporate those assertions into the `inferredAnnotationMap` e.buildPkgInferenceMap(otherTriggers) - // Step 2: run error return handling procedure to filter out redundant triggers based on the error contract, and + // Step 3: run error return handling procedure to filter out redundant triggers based on the error contract, and // keep only those UseAsNonErrorRetDependentOnErrorRetNilability triggers that are not deleted. // Call FilterTriggersForErrorReturn to filter triggers for error return handling -- inter-procedural and full-inference mode _, delTriggers := assertiontree.FilterTriggersForErrorReturn( @@ -193,22 +262,24 @@ func (e *Engine) ObservePackage(pkgFullTriggers []annotation.FullTrigger) { isDeep := kind == annotation.DeepConditional primitive := e.primitive.site(site, isDeep) if val, ok := e.inferredMap.Load(primitive); ok { - switch vType := val.(type) { - case *DeterminedVal: + if vType, ok := val.(*DeterminedVal); ok { if !vType.Bool.Val() { return assertiontree.ProducerIsNonNil } - case *UndeterminedVal: - // Consider the producer site as non-nil, if the determined value is non-nil, i.e., - // `!vType.Bool.Val()`, or the site is undetermined. Undetermined sites are assumed "non-nil" here based on the following: - // (a) the inference algorithm does not propagate non-nil values forward, and - // (b) the processing of the sites under question, error return sites, are allowed to be processed first in step 1 above - // - // This above assumption works favorably in most of the cases, except the one demonstrated in - // `testdata/errorreturn/inference/downstream.go`, for instance, where it leads to a false negative. - return assertiontree.ProducerIsNonNil + return assertiontree.ProducerIsNil } } + // We reach here if `primitive` site is + // - present in `inferredMap` but UndeterminedVal, or + // - not present in `inferredMap`, implying undetermined. + // + // At this point we consider undetermined sites producer site as non-nil, based on the following: + // (a) the inference algorithm does not propagate non-nil values forward + // (b) the processing of the sites under question (i.e., error return sites) are allowed to be processed first in step 1 above + // + // This above assumption works favorably in most of the cases, except the one demonstrated in + // `testdata/errorreturn/inference/downstream.go`, for instance, where it leads to a false negative. + return assertiontree.ProducerIsNonNil } // In all other cases, return ProducerNilabilityUnknown to indicate that all we @@ -220,7 +291,7 @@ func (e *Engine) ObservePackage(pkgFullTriggers []annotation.FullTrigger) { filteredTriggers := nonErrRetTriggers // Remove deleted triggers from nonErrRetTriggers (if needed). if len(delTriggers) != 0 { - filteredTriggers = make([]annotation.FullTrigger, 0, len(nonErrRetTriggers)-len(delTriggers)) + filteredTriggers = make([]annotation.FullTrigger, 0, len(nonErrRetTriggers)) for _, t := range nonErrRetTriggers { if !delTriggers[t] { filteredTriggers = append(filteredTriggers, t) @@ -228,7 +299,7 @@ func (e *Engine) ObservePackage(pkgFullTriggers []annotation.FullTrigger) { } } - // Step 3: run the inference building process for only the remaining UseAsNonErrorRetDependentOnErrorRetNilability triggers, and collect assertions + // Step 4: run the inference building process for only the remaining UseAsNonErrorRetDependentOnErrorRetNilability triggers, and collect assertions e.buildPkgInferenceMap(filteredTriggers) } diff --git a/nilaway_test.go b/nilaway_test.go index 9d20bc69..33522441 100644 --- a/nilaway_test.go +++ b/nilaway_test.go @@ -38,7 +38,7 @@ func TestNilAway(t *testing.T) { patterns []string }{ {name: "Inference", patterns: []string{"go.uber.org/inference"}}, - {name: "Contracts", patterns: []string{"go.uber.org/contracts", "go.uber.org/contracts/namedtypes"}}, + {name: "Contracts", patterns: []string{"go.uber.org/contracts", "go.uber.org/contracts/namedtypes", "go.uber.org/contracts/inference"}}, {name: "Testing", patterns: []string{"go.uber.org/testing"}}, {name: "ErrorReturn", patterns: []string{"go.uber.org/errorreturn", "go.uber.org/errorreturn/inference"}}, {name: "Maps", patterns: []string{"go.uber.org/maps"}}, @@ -66,6 +66,7 @@ func TestNilAway(t *testing.T) { {name: "Constants", patterns: []string{"go.uber.org/consts"}}, {name: "ErrorMessage", patterns: []string{"go.uber.org/errormessage", "go.uber.org/errormessage/inference"}}, {name: "LoopRange", patterns: []string{"go.uber.org/looprange"}}, + {name: "AbnormalFlow", patterns: []string{"go.uber.org/abnormalflow"}}, } for _, tt := range tests { diff --git a/testdata/integration/contracts/downstream/downstream.go b/testdata/integration/contracts/downstream/downstream.go index 11b089de..16b8beb9 100644 --- a/testdata/integration/contracts/downstream/downstream.go +++ b/testdata/integration/contracts/downstream/downstream.go @@ -28,7 +28,5 @@ func GiveUpstreamNil() { func GiveUpstreamNonnil() { a := 1 r := upstream.NonnilToNonnil(&a) - // TODO: FP: this should be safe due to the contract. We should remove this once contract - // support is fixed in NilAway. - print(*r) //want "result 0 of `NonnilToNonnil\(\)` dereferenced" -} \ No newline at end of file + print(*r) // Safe due to the contract! +} diff --git a/testdata/src/go.uber.org/abnormalflow/abnormalflow.go b/testdata/src/go.uber.org/abnormalflow/abnormalflow.go new file mode 100644 index 00000000..6d618a1a --- /dev/null +++ b/testdata/src/go.uber.org/abnormalflow/abnormalflow.go @@ -0,0 +1,260 @@ +// Copyright (c) 2024 Uber Technologies, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package abnormalflow checks code patterns with abnormal control flows (e.g., panic, log.Fatal +// etc.) that may lead to program terminations (such that any subsequent potential nil panics +// will not happen). +package abnormalflow + +import ( + "errors" + "log" + "os" + "runtime" + "testing" +) + +func testDirectDereference(msg string, t *testing.T, b *testing.B, f *testing.F, tb testing.TB) { + var nilable *int + switch msg { + case "print": + print("123") + print(*nilable) //want "unassigned variable `nilable` dereferenced" + case "panic": + panic("foo") + print(*nilable) + case "log.Fatal": + log.Fatal("foo") + print(*nilable) + case "log.Fatalf": + log.Fatalf("foo") + print(*nilable) + case "os.Exit": + os.Exit(1) + print(*nilable) + case "runtime.Goexit": + runtime.Goexit() + print(*nilable) + case "testing.T.Fatal": + t.Fatal("foo") + print(*nilable) + case "testing.T.Fatalf": + t.Fatalf("foo") + print(*nilable) + case "testing.T.SkipNow": + t.SkipNow() + print(*nilable) + case "testing.T.Skip": + t.Skip() + print(*nilable) + case "testing.T.Skipf": + t.Skipf("msg") + print(*nilable) + case "testing.B.Fatal": + b.Fatal("foo") + print(*nilable) + case "testing.B.Fatalf": + b.Fatalf("foo") + print(*nilable) + case "testing.B.SkipNow": + b.SkipNow() + print(*nilable) + case "testing.B.Skip": + b.Skip() + print(*nilable) + case "testing.B.Skipf": + b.Skipf("msg") + print(*nilable) + case "testing.TB.Fatal": + tb.Fatal("foo") + print(*nilable) + case "testing.TB.Fatalf": + tb.Fatalf("foo") + print(*nilable) + case "testing.TB.SkipNow": + tb.SkipNow() + print(*nilable) + case "testing.TB.Skip": + tb.Skip() + print(*nilable) + case "testing.TB.Skipf": + tb.Skipf("msg") + print(*nilable) + case "testing.F.Fatal": + f.Fatal("foo") + print(*nilable) + case "testing.F.Fatalf": + f.Fatalf("foo") + print(*nilable) + case "testing.F.SkipNow": + f.SkipNow() + print(*nilable) + case "testing.F.Skip": + f.Skip() + print(*nilable) + case "testing.F.Skipf": + f.Skipf("msg") + print(*nilable) + } +} + +func errReturn(a bool) (*int, error) { + i := 42 + if a { + return &i, nil + } + return nil, errors.New("some error") +} + +func testErrReturn(msg string, val bool, t *testing.T, b *testing.B, f *testing.F, tb testing.TB) { + ptr, err := errReturn(val) + switch msg { + case "print": + if err != nil { + print(err) + } + print(*ptr) //want "dereferenced" + case "print_and_return": + if err != nil { + print(err) + return + } + print(*ptr) + case "panic": + if err != nil { + panic(err) + } + print(*ptr) + case "log.Fatal": + if err != nil { + log.Fatal(err) + } + print(*ptr) + case "log.Fatalf": + if err != nil { + log.Fatalf("msg %s", err) + } + print(*ptr) + case "os.Exit": + if err != nil { + os.Exit(1) + } + print(*ptr) + case "runtime.Goexit": + if err != nil { + runtime.Goexit() + } + print(*ptr) + case "testing.T.Fatal": + if err != nil { + t.Fatal(err) + } + print(*ptr) + case "testing.T.Fatalf": + if err != nil { + t.Fatalf("msg %s", err) + } + print(*ptr) + case "testing.T.SkipNow": + if err != nil { + t.SkipNow() + } + print(*ptr) + case "testing.T.Skip": + if err != nil { + t.Skip(err) + } + print(*ptr) + case "testing.T.Skipf": + if err != nil { + t.Skipf("msg %s", err) + } + print(*ptr) + case "testing.B.Fatal": + if err != nil { + b.Fatal(err) + } + print(*ptr) + case "testing.B.Fatalf": + if err != nil { + b.Fatalf("msg %s", err) + } + print(*ptr) + case "testing.B.SkipNow": + if err != nil { + b.SkipNow() + } + print(*ptr) + case "testing.B.Skip": + if err != nil { + b.Skip(err) + } + print(*ptr) + case "testing.B.Skipf": + if err != nil { + b.Skipf("msg %s", err) + } + print(*ptr) + case "testing.F.Fatal": + if err != nil { + f.Fatal(err) + } + print(*ptr) + case "testing.F.Fatalf": + if err != nil { + f.Fatalf("msg %s", err) + } + print(*ptr) + case "testing.F.SkipNow": + if err != nil { + f.SkipNow() + } + print(*ptr) + case "testing.F.Skip": + if err != nil { + f.Skip(err) + } + print(*ptr) + case "testing.F.Skipf": + if err != nil { + f.Skipf("msg %s", err) + } + print(*ptr) + case "testing.TB.Fatal": + if err != nil { + tb.Fatal(err) + } + print(*ptr) + case "testing.TB.Fatalf": + if err != nil { + tb.Fatalf("msg %s", err) + } + print(*ptr) + case "testing.TB.SkipNow": + if err != nil { + tb.SkipNow() + } + print(*ptr) + case "testing.TB.Skip": + if err != nil { + tb.Skip(err) + } + print(*ptr) + case "testing.TB.Skipf": + if err != nil { + tb.Skipf("msg %s", err) + } + print(*ptr) + } +} diff --git a/testdata/src/go.uber.org/contracts/inference/userdefinedfunctions-with-inference.go b/testdata/src/go.uber.org/contracts/inference/userdefinedfunctions-with-inference.go new file mode 100644 index 00000000..8afe0833 --- /dev/null +++ b/testdata/src/go.uber.org/contracts/inference/userdefinedfunctions-with-inference.go @@ -0,0 +1,167 @@ +// Copyright (c) 2023 Uber Technologies, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package contracts: This file tests the contract of "ok" form for user defined and standard library functions in +// full inference mode. + +package inference + +var dummy bool + +const falseVal = false +const trueVal = true + +// ***** below tests check the handling for "always safe" cases and their variants ***** + +func retAlwaysNonnilPtrBool(i int) (*int, bool) { + switch i { + case 0: + return new(int), false + case 1: + return &i, trueVal + case 2: + return new(int), falseVal + } + return new(int), true +} + +func retAlwaysNilPtrBool(i int) (*int, bool) { + switch i { + case 0: + return nil, false + case 1: + return nil, trueVal + case 2: + return nil, falseVal + } + return nil, true +} + +func retSometimesNilPtrBool(i int) (*int, bool) { + switch i { + case 0: + return nil, false + case 1: + return nil, falseVal + case 2: + return new(int), trueVal + } + return new(int), true +} + +func testAlwaysSafe(i int) { + switch i { + // always safe + case 0: + x, _ := retAlwaysNonnilPtrBool(i) + print(*x) + case 1: + if x, ok := retAlwaysNonnilPtrBool(i); ok { + print(*x) + } + case 2: + if x, ok := retAlwaysNonnilPtrBool(i); ok { + print(*x) + } + case 3: + x, _ := retAlwaysNonnilPtrBool(i) + y, _ := retAlwaysNonnilPtrBool(i) + print(*x) + print(*y) + case 4: + x, okx := retAlwaysNonnilPtrBool(i) + y, oky := retAlwaysNonnilPtrBool(i) + + if oky { + print(*x) + } + if okx { + print(*y) + } + + // always unsafe + case 5: + x, _ := retAlwaysNilPtrBool(i) + print(*x) //want "dereferenced" + case 6: + if x, ok := retAlwaysNilPtrBool(i); ok { + print(*x) //want "dereferenced" + } + + // conditionally safe + case 7: + x, _ := retSometimesNilPtrBool(i) + print(*x) //want "dereferenced" + case 8: + if x, ok := retSometimesNilPtrBool(i); ok { + print(*x) + } + } +} + +// Test always safe through multiple hops. Currently, we support only immediate function call for "always safe" tracking. +// Hence, the below cases are expected to report errors. +// TODO: add support for multiple hops to address the false positives + +func m1() (*int, bool) { + return m2() +} + +func m2() (*int, bool) { + v, ok := m3() + if !ok { + // makes non-error return always non-nil + return new(int), false + } + y := *v + 1 + return &y, true +} + +func m3() (*int, bool) { + if dummy { + return nil, false + } + return new(int), true +} + +type S struct { + f *int +} + +func f1(i int) (*int, bool) { + switch i { + case 0: + // direct non-nil non-error return value + return new(int), false + case 1: + s := &S{f: new(int)} + // indirect non-nil non-error return value via a field read + return s.f, true + case 2: + } + // indirect non-nil non-error return value via a function return + return retAlwaysNonnilPtrBool(i) +} + +func testAlwaysSafeMultipleHops() { + // TODO: call to m1() should be reported as always safe. This is a false positive since currently we are limiting the + // "always safe" tracking to only immediate function call, not chained error returning function calls. + v1, _ := m1() + print(*v1) //want "dereferenced" + + // TODO: call to f1() should be reported as always safe. This is a false positive since currently we are limiting the + // analysis of "return statements" to only the directly determinable cases (e.g., new(int), &S{}, NegativeNilCheck), not through multiple hops. + v2, _ := f1(0) + print(*v2) //want "dereferenced" +} diff --git a/testdata/src/go.uber.org/errorreturn/inference/errorreturn-with-inference.go b/testdata/src/go.uber.org/errorreturn/inference/errorreturn-with-inference.go index f134fdb4..1442f058 100644 --- a/testdata/src/go.uber.org/errorreturn/inference/errorreturn-with-inference.go +++ b/testdata/src/go.uber.org/errorreturn/inference/errorreturn-with-inference.go @@ -38,7 +38,7 @@ func retNonNilErr2() error { return &myErr2{} } -// ***** the below test case checks error return via a function and assigned to a vairable ***** +// ***** the below test case checks error return via a function and assigned to a variable ***** func retPtrAndErr2(i int) (*int, error) { if dummy2 { return nil, retNonNilErr2() @@ -46,8 +46,21 @@ func retPtrAndErr2(i int) (*int, error) { return &i, retNilErr2() } -func testFuncRet2(i int) (*int, error) { +// same as retPtrAndErr2 but with the return statements swapped. This is to check that the order of return statements +// does not affect the error return analysis +func retPtrAndErr3() (*int, error) { + if dummy2 { + return new(int), retNilErr2() + } + return nil, retNonNilErr3() +} + +// duplicated from retNonNilErr2 to make a fresh instance of the function for supporting the testing of retPtrAndErr3 +func retNonNilErr3() error { + return &myErr2{} +} +func testFuncRet2(i int) (*int, error) { var errNil = retNilErr2() var errNonNil = retNonNilErr2() switch i { @@ -69,6 +82,8 @@ func testFuncRet2(i int) (*int, error) { return &i, errNonNil case 8: return &i, retNonNilErr2() + case 9: + return retPtrAndErr3() } return &i, nil } @@ -226,3 +241,147 @@ func testAliasedMixedReturns() { } } + +// ***** below tests check the handling for "always safe" cases and their variants ***** + +func retAlwaysNonnilPtrErr(i int) (*int, error) { + switch i { + case 0: + return new(int), &myErr2{} + case 1: + return &i, retNonNilErr2() + case 2: + return new(int), retNilErr2() + } + return new(int), nil +} + +func retAlwaysNilPtrErr(i int) (*int, error) { + switch i { + case 0: + return nil, &myErr2{} + case 1: + return nil, retNonNilErr2() + case 2: + return nil, retNilErr2() + } + return nil, nil +} + +func retSometimesNilPtrErr(i int) (*int, error) { + switch i { + case 0: + return nil, &myErr2{} + case 1: + return nil, retNonNilErr2() + case 2: + return new(int), retNilErr2() + } + return new(int), nil +} + +func testAlwaysSafe(i int) { + switch i { + // always safe + case 0: + x, _ := retAlwaysNonnilPtrErr(i) + print(*x) + case 1: + if x, err := retAlwaysNonnilPtrErr(i); err != nil { + print(*x) + } + case 2: + if x, err := retAlwaysNonnilPtrErr(i); err == nil { + print(*x) + } + case 3: + x, _ := retAlwaysNonnilPtrErr(i) + y, _ := retAlwaysNonnilPtrErr(i) + print(*x) + print(*y) + case 4: + x, errx := retAlwaysNonnilPtrErr(i) + y, erry := retAlwaysNonnilPtrErr(i) + + if erry == nil { + print(*x) + } + if errx == nil { + print(*y) + } + + // always unsafe + case 5: + x, _ := retAlwaysNilPtrErr(i) + print(*x) //want "dereferenced" + case 6: + if x, err := retAlwaysNilPtrErr(i); err == nil { + print(*x) //want "dereferenced" + } + + // conditionally safe + case 7: + x, _ := retSometimesNilPtrErr(i) + print(*x) //want "dereferenced" + case 8: + if x, err := retSometimesNilPtrErr(i); err == nil { + print(*x) + } + } +} + +// Test always safe through multiple hops. Currently, we support only immediate function call for "always safe" tracking. +// Hence, the below cases are expected to report errors. +// TODO: add support for multiple hops to address the false positives + +func m1() (*int, error) { + return m2() +} + +func m2() (*int, error) { + v, err := m3() + if err != nil { + // makes non-error return always non-nil + return new(int), err + } + y := *v + 1 + return &y, nil +} + +func m3() (*int, error) { + if dummy2 { + return nil, &myErr2{} + } + return new(int), nil +} + +type S struct { + f *int +} + +func f1(i int) (*int, error) { + switch i { + case 0: + // direct non-nil non-error return value + return new(int), &myErr2{} + case 1: + s := &S{f: new(int)} + // indirect non-nil non-error return value via a field read + return s.f, nil + case 2: + } + // indirect non-nil non-error return value via a function return + return retAlwaysNonnilPtrErr(i) +} + +func testAlwaysSafeMultipleHops() { + // TODO: call to m1() should be reported as always safe. This is a false positive since currently we are limiting the + // "always safe" tracking to only immediate function call, not chained error returning function calls. + v1, _ := m1() + print(*v1) //want "dereferenced" + + // TODO: call to f1() should be reported as always safe. This is a false positive since currently we are limiting the + // analysis of "return statements" to only the directly determinable cases (e.g., new(int), &S{}, NegativeNilCheck), not through multiple hops. + v2, _ := f1(0) + print(*v2) //want "dereferenced" +}