diff --git a/graphql.go b/graphql.go index 9ccc3592..d55d3293 100644 --- a/graphql.go +++ b/graphql.go @@ -65,6 +65,7 @@ type Schema struct { res *resolvable.Schema maxDepth int + complexityEstimators []validation.ComplexityEstimator maxParallelism int tracer trace.Tracer validationTracer trace.ValidationTracer @@ -100,6 +101,23 @@ func MaxDepth(n int) SchemaOpt { } } +// MaxQueryComplexity specifies the maximum allowed complexity of a query. +func MaxQueryComplexity(n int) SchemaOpt { + return ComplexityEstimator(validation.SimpleEstimator{n}) +} + +// MaxQueryRecursion specifies the maximum recursion depth allowed in a query. +func MaxQueryRecursion(n int) SchemaOpt { + return ComplexityEstimator(validation.RecursionEstimator{n}) +} + +// ComplexityEstimator adds an estimator used to validate the complexity of queries. +func ComplexityEstimator(estimator validation.ComplexityEstimator) SchemaOpt { + return func(s *Schema) { + s.complexityEstimators = append(s.complexityEstimators, estimator) + } +} + // MaxParallelism specifies the maximum number of resolvers per request allowed to run in parallel. The default is 10. func MaxParallelism(n int) SchemaOpt { return func(s *Schema) { @@ -151,7 +169,7 @@ func (s *Schema) Validate(queryString string) []*errors.QueryError { return []*errors.QueryError{qErr} } - return validation.Validate(s.schema, doc, nil, s.maxDepth) + return validation.Validate(s.schema, doc, nil, s.maxDepth, s.complexityEstimators) } // Exec executes the given query with the schema's resolver. 
It panics if the schema was created @@ -171,7 +189,7 @@ func (s *Schema) exec(ctx context.Context, queryString string, operationName str } validationFinish := s.validationTracer.TraceValidation() - errs := validation.Validate(s.schema, doc, variables, s.maxDepth) + errs := validation.Validate(s.schema, doc, variables, s.maxDepth, s.complexityEstimators) validationFinish(errs) if len(errs) != 0 { return &Response{Errors: errs} diff --git a/internal/validation/complexity.go b/internal/validation/complexity.go new file mode 100644 index 00000000..06463de4 --- /dev/null +++ b/internal/validation/complexity.go @@ -0,0 +1,136 @@ +package validation + +import ( + "github.com/graph-gophers/graphql-go/errors" + "github.com/graph-gophers/graphql-go/internal/query" + "github.com/graph-gophers/graphql-go/internal/schema" +) + +type ComplexityEstimator interface { + DoEstimate(c *opContext, sels []query.Selection) bool +} + +type SimpleEstimator struct { + MaxComplexity int +} + +func (e SimpleEstimator) DoEstimate(c *opContext, sels []query.Selection) bool { + if e.MaxComplexity == 0 { + return false + } + + complexity := e.doSimpleEstimate(c, sels) + if complexity > e.MaxComplexity { + return true + } + + return false +} + +func (e SimpleEstimator) doSimpleEstimate(c *opContext, sels []query.Selection) int { + complexity := 0 + + for _, sel := range sels { + var loc errors.Location + switch sel := sel.(type) { + case *query.Field: + loc = sel.Alias.Loc + complexity += e.doSimpleEstimate(c, sel.Selections) + 1 + case *query.InlineFragment: + loc = sel.Loc + complexity += e.doSimpleEstimate(c, sel.Selections) + case *query.FragmentSpread: + frag := c.doc.Fragments.Get(sel.Name.Name) + if frag == nil { + c.addErr(sel.Loc, "MaxComplexityEvaluationError", "Unknown fragment %q. 
Unable to evaluate complexity.", sel.Name.Name) + continue + } + loc = frag.Loc + complexity += e.doSimpleEstimate(c, frag.Selections) + } + + if complexity > e.MaxComplexity { + c.addErr(loc, "MaxComplexityExceeded", + "The query exceeds the maximum complexity of %d. Actual complexity is %d.", e.MaxComplexity, complexity) + + return complexity + } + } + + return complexity +} + +type RecursionEstimator struct { + MaxDepth int +} + +func (e RecursionEstimator) DoEstimate(c *opContext, sels []query.Selection) bool { + if e.MaxDepth == 0 { + return false + } + + return e.doRecursivelyVisitSelections(c, sels, map[string]int{}, getEntryPoint(c.schema, c.ops[0])) +} + +type visitedSels map[string]int + +func (s visitedSels) copy() visitedSels { + newSels := visitedSels{} + for index, value := range s { + newSels[index] = value + } + + return newSels +} + +func (e RecursionEstimator) doRecursivelyVisitSelections( + c *opContext, sels []query.Selection, visited visitedSels, t schema.NamedType) bool { + + fields := fields(t) + + exceeded := false + + for _, sel := range sels { + switch sel := sel.(type) { + case *query.Field: + fieldName := sel.Name.Name + switch fieldName { + case "__typename", "__schema", "__type": + continue + default: + if sel.Selections == nil { + continue + } + + if f := fields.Get(fieldName); f != nil { + v := visited.copy() + + if depth, ok := v[f.Type.String()]; ok { + v[f.Type.String()] = depth + 1 + } else { + v[f.Type.String()] = 1 + } + + currentDepth := v[f.Type.String()] + if currentDepth > e.MaxDepth { + c.addErr(sel.Alias.Loc, "MaxDepthRecursionExceeded", + "The query exceeds the maximum depth recursion of %d. 
Actual is %d.", + e.MaxDepth, currentDepth) + + return true + } + + exceeded = exceeded || e.doRecursivelyVisitSelections(c, sel.Selections, v, unwrapType(f.Type)) + } + } + case *query.InlineFragment: + exceeded = exceeded || e.doRecursivelyVisitSelections(c, sel.Selections, visited, unwrapType(t)) + case *query.FragmentSpread: + if frag := c.doc.Fragments.Get(sel.Name.Name); frag != nil { + exceeded = exceeded || e.doRecursivelyVisitSelections(c, frag.Selections, visited, c.schema.Types[frag.On.Name]) + } + } + } + + return exceeded +} diff --git a/internal/validation/validate_max_complexity_test.go b/internal/validation/validate_max_complexity_test.go new file mode 100644 index 00000000..54938427 --- /dev/null +++ b/internal/validation/validate_max_complexity_test.go @@ -0,0 +1,674 @@ +package validation + +import ( + "testing" + + "github.com/graph-gophers/graphql-go/internal/query" + "github.com/graph-gophers/graphql-go/internal/schema" +) + +type maxComplexityTestCase struct { + name string + query string + estimator ComplexityEstimator + failure bool + expectedErrors []string +} + +func (tc maxComplexityTestCase) Run(t *testing.T, s *schema.Schema) { + t.Run(tc.name, func(t *testing.T) { + doc, qErr := query.Parse(tc.query) + if qErr != nil { + t.Fatal(qErr) + } + + errs := Validate(s, doc, nil, 0, []ComplexityEstimator{tc.estimator}) + if len(tc.expectedErrors) > 0 { + if len(errs) > 0 { + for _, expected := range tc.expectedErrors { + found := false + for _, err := range errs { + if err.Rule == expected { + found = true + break + } + } + if !found { + t.Errorf("expected error %v is missing", expected) + } + } + } else { + t.Errorf("expected errors [%v] are missing", tc.expectedErrors) + } + } + if (len(errs) > 0) != tc.failure { + t.Errorf("expected failure: %t, actual errors (%d): %v", tc.failure, len(errs), errs) + } + }) +} + +func TestMaxComplexity(t *testing.T) { + s := schema.New() + + err := s.Parse(simpleSchema, false) + if err != nil { + t.Fatal(err) + } + + for _, tc := range 
[]maxComplexityTestCase{ + { + name: "off", + query: `query Okay { # complexity 0 + characters { # complexity 1 + id # complexity 2 + name # complexity 3 + friends { # complexity 4 + friends { # complexity 5 + friends { # complexity 6 + id # complexity 7 + name # complexity 8 + } + } + } + } + }`, + estimator: SimpleEstimator{0}, + }, + { + name: "maxComplexity-1", + query: `query Fine { # complexity 0 + characters { # complexity 1 + id # complexity 2 + name # complexity 3 + friends { # complexity 4 + id # complexity 5 + name # complexity 6 + } + } + }`, + estimator: SimpleEstimator{7}, + }, + { + name: "maxComplexity", + query: `query Equals { # complexity 0 + characters { # complexity 1 + id # complexity 2 + name # complexity 3 + friends { # complexity 4 + id # complexity 5 + name # complexity 6 + } + } + }`, + estimator: SimpleEstimator{6}, + }, + { + name: "maxComplexity+1", + query: `query Equals { # complexity 0 + characters { # complexity 1 + id # complexity 2 + name # complexity 3 + friends { # complexity 4 + id # complexity 5 + name # complexity 6 + } + } + }`, + failure: true, + estimator: SimpleEstimator{5}, + }, + } { + tc.Run(t, s) + } +} + +func TestMaxComplexityRecursion(t *testing.T) { + s := schema.New() + + err := s.Parse(simpleSchema, false) + if err != nil { + t.Fatal(err) + } + + for _, tc := range []maxComplexityTestCase{ + { + name: "off", + query: `query Fine { + characters { # complexity 1 + id + name + friends { # complexity 2 + friends { # complexity 3 + friends { # complexity 4 + id + name + } + } + } + } + }`, + estimator: RecursionEstimator{0}, + }, + { + name: "maxComplexity", + query: `query Fine { + characters { # complexity 1 + id + name + friends { # complexity 2 + friends { # complexity 3 + friends { # complexity 4 + id + name + } + } + } + } + }`, + estimator: RecursionEstimator{4}, + }, + { + name: "maxComplexity + 1", + query: `query Fine { + characters { # complexity 1 + id + name + friends { # complexity 2 + friends { # 
complexity 3 + friends { # complexity 4 + id + name + } + } + } + } + }`, + estimator: RecursionEstimator{5}, + }, + { + name: "number aliases greater then max complexity", + query: `query Fine { + characters { # complexity 1 + id + name + friends { # complexity 2 + friends { # complexity 3 + id + name + } + } + favorite: friends { # complexity 2 + friends { # complexity 3 + id + name + } + } + colleagues: friends { # complexity 2 + friends { # complexity 3 + id + name + } + } + works: friends { # complexity 2 + friends { # complexity 3 + id + name + } + } + } + }`, + estimator: RecursionEstimator{3}, + }, + { + name: "maxComplexity - 1", + query: `query Fine { + characters { # complexity 1 + id + name + friends { # complexity 2 + friends { # complexity 3 + friends { # complexity 4 + id + name + } + } + } + } + }`, + failure: true, + estimator: RecursionEstimator{3}, + }, + } { + tc.Run(t, s) + } +} + +func TestMaxComplexityInlineFragments(t *testing.T) { + s := schema.New() + + err := s.Parse(interfaceSimple, false) + if err != nil { + t.Fatal(err) + } + + for _, tc := range []maxComplexityTestCase{ + { + name: "maxComplexity-1", + query: `query { # complexity 0 + characters { # complexity 1 + name # complexity 2 + ... on Human { # complexity 3 + totalCredits # complexity 4 + } + } + }`, + estimator: SimpleEstimator{5}, + }, + { + name: "maxComplexity", + query: `query { # complexity 0 + characters { # complexity 1 + ... on Droid { # complexity 2 + primaryFunction # complexity 3 + } + } + }`, + estimator: SimpleEstimator{3}, + }, + { + name: "maxComplexity+1", + query: `query { # complexity 0 + characters { # complexity 1 + name # complexity 2 + ... 
on Human { # complexity 2 + totalCredits # complexity 3 + } + } + }`, + failure: true, + estimator: SimpleEstimator{2}, + }, + } { + tc.Run(t, s) + } +} +func TestMaxComplexityRecursionInlineFragments(t *testing.T) { + s := schema.New() + + err := s.Parse(interfaceSimple, false) + if err != nil { + t.Fatal(err) + } + + for _, tc := range []maxComplexityTestCase{ + { + name: "maxComplexity-1", + query: `query { + characters { # depth 1 + name + ... on Human { + totalCredits + friends { # depth 2 + name + friends { # depth 3 + name + } + } + } + } + }`, + estimator: RecursionEstimator{4}, + }, + { + name: "maxComplexity", + query: `query { + characters { # depth 1 + name + ... on Human { + totalCredits + friends { # depth 2 + name + friends { # depth 3 + name + } + } + } + } + }`, + estimator: RecursionEstimator{3}, + }, + { + name: "maxComplexity + 1", + query: `query { + characters { # depth 1 + name + ... on Human { + totalCredits + friends { # depth 2 + name + friends { # depth 3 + name + } + } + } + } + }`, + failure: true, + estimator: RecursionEstimator{2}, + }, + } { + tc.Run(t, s) + } +} + +func TestMaxComplexityFragmentSpreads(t *testing.T) { + s := schema.New() + + err := s.Parse(interfaceSimple, false) + if err != nil { + t.Fatal(err) + } + + for _, tc := range []maxComplexityTestCase{ + { + name: "maxComplexity-1", + query: `fragment friend on Character { + id # complexity 7 + name # complexity 8 + friends { # complexity 9 + name # complexity 10 + } + } + + query { # complexity 0 + characters { # complexity 1 + id # complexity 2 + name # complexity 3 + friends { # complexity 4 + friends { # complexity 5 + friends { # complexity 6 + ...friend # complexity 6 + } + } + } + } + }`, + estimator: SimpleEstimator{11}, + }, + { + name: "maxComplexity", + query: `fragment friend on Character { + id # complexity 7 + name # complexity 8 + } + query { # depth 0 + characters { # depth 1 + id # depth 2 + name # depth 3 + friends { # depth 4 + friends { # depth 5 + 
friends { # depth 6 + ...friend # depth 6 + } + } + } + } + }`, + estimator: SimpleEstimator{8}, + }, + { + name: "maxComplexity+1", + query: `fragment friend on Character { + id # complexity 8 + name # complexity 9 + friends { # complexity 10 + name # complexity 11 + } + } + query { # depth 0 + characters { # depth 1 + id # depth 2 + name # depth 3 + friends { # depth 4 + friends { # depth 5 + friends { # depth 6 + friends { # depth 7 + ...friend # depth 7 + } + } + } + } + } + }`, + failure: true, + estimator: SimpleEstimator{10}, + }, + } { + tc.Run(t, s) + } +} + +func TestMaxComplexityRecursionFragmentSpreads(t *testing.T) { + s := schema.New() + + err := s.Parse(interfaceSimple, false) + if err != nil { + t.Fatal(err) + } + + for _, tc := range []maxComplexityTestCase{ + { + name: "maxComplexity-1", + query: `fragment friend on Character { + id + name + friends { # complexity 5 + name + } + } + + query { # + characters { # complexity 1 + id # + name # + friends { # complexity 2 + friends { # complexity 3 + friends { # complexity 4 + ...friend # + } + } + } + } + }`, + estimator: RecursionEstimator{6}, + }, + { + name: "maxComplexity", + query: `fragment friend on Character { + id + name + } + query { # + characters { # depth 1 + id # + name # + friends { # depth 2 + friends { # depth 3 + friends { # depth 4 + ...friend # + } + } + } + } + }`, + estimator: RecursionEstimator{4}, + }, + { + name: "maxComplexity+1", + query: `fragment friend on Character { + id # + name # + friends { # depth 6 + name # + } + } + query { # + characters { # depth 1 + id # + name # + friends { # depth 2 + friends { # depth 3 + friends { # depth 4 + friends { # depth 5 + ...friend # + } + } + } + } + } + }`, + failure: true, + //expectedErrors: []string{"MaxComplexityExceeded"}, + estimator: RecursionEstimator{5}, + }, + } { + tc.Run(t, s) + } +} + +func TestMaxComplexityUnknownFragmentSpreads(t *testing.T) { + s := schema.New() + + err := s.Parse(interfaceSimple, false) + if err != 
nil { + t.Fatal(err) + } + + for _, tc := range []maxComplexityTestCase{ + { + name: "maxComplexityUnknownFragment", + query: `query { # complexity 0 + characters { # complexity 1 + id # complexity 2 + name # complexity 3 + friends { # complexity 4 + friends { # complexity 5 + friends { # complexity 6 + friends { # complexity 7 + ...unknownFragment # complexity 0 + } + } + } + } + } + }`, + estimator: SimpleEstimator{6}, + failure: true, + expectedErrors: []string{"MaxComplexityEvaluationError"}, + }, + } { + tc.Run(t, s) + } +} + +func TestMaxComplexityValidation(t *testing.T) { + s := schema.New() + + err := s.Parse(interfaceSimple, false) + if err != nil { + t.Fatal(err) + } + + for _, tc := range []struct { + name string + query string + estimator ComplexityEstimator + expected bool + }{ + { + name: "off", + query: `query Fine { # complexity 0 + characters { # complexity 1 + id # complexity 2 + name # complexity 3 + friends { # complexity 4 + id # complexity 5 + name # complexity 6 + } + } + }`, + estimator: SimpleEstimator{}, + }, + { + name: "fields", + query: `query Fine { # complexity 0 + characters { # complexity 1 + id # complexity 2 + name # complexity 3 + friends { # complexity 4 + id # complexity 5 + name # complexity 6 + } + } + }`, + expected: true, + estimator: SimpleEstimator{5}, + }, + { + name: "fragment", + query: `fragment friend on Character { + id # complexity 8 + name # complexity 9 + friends { # complexity 10 + name # complexity 11 + } + } + query { # complexity 0 + characters { # complexity 1 + id # complexity 2 + name # complexity 3 + friends { # complexity 4 + friends { # complexity 5 + friends { # complexity 6 + friends { # complexity 7 + ...friend # complexity 7 + } + } + } + } + } + }`, + expected: true, + estimator: SimpleEstimator{10}, + }, + { + name: "inlinefragment", + query: `query { # complexity 0 + characters { # complexity 1 + ... 
on Droid { # complexity 1 + primaryFunction # complexity 2 + } + } + }`, + expected: true, + estimator: SimpleEstimator{1}, + }, + } { + t.Run(tc.name, func(t *testing.T) { + doc, err := query.Parse(tc.query) + if err != nil { + t.Fatal(err) + } + + context := newContext(s, doc, 0, []ComplexityEstimator{tc.estimator}) + op := doc.Operations[0] + + opc := &opContext{context: context, ops: doc.Operations} + + actual := validateMaxComplexity(opc, op.Selections) + if actual != tc.expected { + t.Errorf("expected %t, actual %t", tc.expected, actual) + } + }) + } +} diff --git a/internal/validation/validate_max_depth_test.go b/internal/validation/validate_max_depth_test.go index abc337cb..7feaccb5 100644 --- a/internal/validation/validate_max_depth_test.go +++ b/internal/validation/validate_max_depth_test.go @@ -77,7 +77,7 @@ func (tc maxDepthTestCase) Run(t *testing.T, s *schema.Schema) { t.Fatal(qErr) } - errs := Validate(s, doc, nil, tc.depth) + errs := Validate(s, doc, nil, tc.depth, []ComplexityEstimator{}) if len(tc.expectedErrors) > 0 { if len(errs) > 0 { for _, expected := range tc.expectedErrors { @@ -435,7 +435,7 @@ func TestMaxDepthValidation(t *testing.T) { t.Fatal(err) } - context := newContext(s, doc, tc.maxDepth) + context := newContext(s, doc, tc.maxDepth, []ComplexityEstimator{}) op := doc.Operations[0] opc := &opContext{context: context, ops: doc.Operations} diff --git a/internal/validation/validation.go b/internal/validation/validation.go index c8be7354..938930fc 100644 --- a/internal/validation/validation.go +++ b/internal/validation/validation.go @@ -32,6 +32,7 @@ type context struct { fieldMap map[*query.Field]fieldInfo overlapValidated map[selectionPair]struct{} maxDepth int + estimators []ComplexityEstimator } func (c *context) addErr(loc errors.Location, rule string, format string, a ...interface{}) { @@ -51,7 +52,7 @@ type opContext struct { ops []*query.Operation } -func newContext(s *schema.Schema, doc *query.Document, maxDepth int) *context { 
+func newContext(s *schema.Schema, doc *query.Document, maxDepth int, estimators []ComplexityEstimator) *context { return &context{ schema: s, doc: doc, @@ -60,11 +61,12 @@ func newContext(s *schema.Schema, doc *query.Document, maxDepth int) *context { fieldMap: make(map[*query.Field]fieldInfo), overlapValidated: make(map[selectionPair]struct{}), maxDepth: maxDepth, + estimators: estimators, } } -func Validate(s *schema.Schema, doc *query.Document, variables map[string]interface{}, maxDepth int) []*errors.QueryError { - c := newContext(s, doc, maxDepth) +func Validate(s *schema.Schema, doc *query.Document, variables map[string]interface{}, maxDepth int, estimators []ComplexityEstimator) []*errors.QueryError { + c := newContext(s, doc, maxDepth, estimators) opNames := make(nameSet) fragUsedBy := make(map[*query.FragmentDecl][]*query.Operation) @@ -72,9 +74,12 @@ func Validate(s *schema.Schema, doc *query.Document, variables map[string]interf c.usedVars[op] = make(varSet) opc := &opContext{c, []*query.Operation{op}} - // Check if max depth is exceeded, if it's set. If max depth is exceeded, + var entryPoint = getEntryPoint(s, op) + + // Check if max depth or complexity is exceeded, if it's set. If exceeded, // don't continue to validate the document and exit early. 
- if validateMaxDepth(opc, op.Selections, 1) { + if validateMaxDepth(opc, op.Selections, 1) || + validateMaxComplexity(opc, op.Selections) { return c.errs } @@ -112,18 +117,6 @@ func Validate(s *schema.Schema, doc *query.Document, variables map[string]interf } } - var entryPoint schema.NamedType - switch op.Type { - case query.Query: - entryPoint = s.EntryPoints["query"] - case query.Mutation: - entryPoint = s.EntryPoints["mutation"] - case query.Subscription: - entryPoint = s.EntryPoints["subscription"] - default: - panic("unreachable") - } - validateSelectionSet(opc, op.Selections, entryPoint) fragUsed := make(map[*query.FragmentDecl]struct{}) @@ -231,6 +224,16 @@ func validateValue(c *opContext, v *common.InputValue, val interface{}, t common } } +func validateMaxComplexity(c *opContext, sels []query.Selection) bool { + for _, estimator := range c.estimators { + if complexityExceeded := estimator.DoEstimate(c, sels); complexityExceeded { + return complexityExceeded + } + } + + return false +} + // validates the query doesn't go deeper than maxDepth (if set). Returns whether // or not query validated max depth to avoid excessive recursion. 
func validateMaxDepth(c *opContext, sels []query.Selection, depth int) bool { @@ -597,6 +600,22 @@ func argumentsConflict(a, b common.ArgumentList) bool { return false } +func getEntryPoint(s *schema.Schema, op *query.Operation) schema.NamedType { + var entryPoint schema.NamedType + switch op.Type { + case query.Query: + entryPoint = s.EntryPoints["query"] + case query.Mutation: + entryPoint = s.EntryPoints["mutation"] + case query.Subscription: + entryPoint = s.EntryPoints["subscription"] + default: + panic("unreachable") + } + + return entryPoint +} + func fields(t common.Type) schema.FieldList { switch t := t.(type) { case *schema.Object: @@ -849,7 +868,7 @@ func validateBasicLit(v *common.BasicLit, t common.Type) bool { case "ID": return v.Type == scanner.Int || v.Type == scanner.String default: - //TODO: Type-check against expected type by Unmarshalling + // TODO: Type-check against expected type by Unmarshalling return true } diff --git a/internal/validation/validation_test.go b/internal/validation/validation_test.go index e287a526..50150f8d 100644 --- a/internal/validation/validation_test.go +++ b/internal/validation/validation_test.go @@ -51,7 +51,7 @@ func TestValidate(t *testing.T) { if err != nil { t.Fatal(err) } - errs := validation.Validate(schemas[test.Schema], d, test.Vars, 0) + errs := validation.Validate(schemas[test.Schema], d, test.Vars, 0, []validation.ComplexityEstimator{}) got := []*errors.QueryError{} for _, err := range errs { if err.Rule == test.Rule { diff --git a/subscriptions.go b/subscriptions.go index 4199c06d..c5d4bca9 100644 --- a/subscriptions.go +++ b/subscriptions.go @@ -37,7 +37,7 @@ func (s *Schema) subscribe(ctx context.Context, queryString string, operationNam } validationFinish := s.validationTracer.TraceValidation() - errs := validation.Validate(s.schema, doc, variables, s.maxDepth) + errs := validation.Validate(s.schema, doc, variables, s.maxDepth, s.complexityEstimators) validationFinish(errs) if len(errs) != 0 { return 
sendAndReturnClosed(&Response{Errors: errs})