Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement query complexity analysis #324 #401

Closed
wants to merge 2 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 20 additions & 2 deletions graphql.go
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@ type Schema struct {
res *resolvable.Schema

maxDepth int
complexityEstimators []validation.ComplexityEstimator
maxParallelism int
tracer trace.Tracer
validationTracer trace.ValidationTracer
Expand Down Expand Up @@ -100,6 +101,23 @@ func MaxDepth(n int) SchemaOpt {
}
}

// MaxQueryComplexity limits the maximum allowed complexity of a query,
// counted as the total number of selected fields (see
// validation.SimpleEstimator). A value of 0 disables the check.
func MaxQueryComplexity(n int) SchemaOpt {
	// Keyed literal: go vet's composites check flags unkeyed fields in
	// cross-package struct literals.
	return ComplexityEstimator(validation.SimpleEstimator{MaxComplexity: n})
}

// MaxQueryRecursion limits how many times a query may recursively re-enter
// the same type along a single selection path (see
// validation.RecursionEstimator). A value of 0 disables the check.
func MaxQueryRecursion(n int) SchemaOpt {
	// Keyed literal: go vet's composites check flags unkeyed fields in
	// cross-package struct literals.
	return ComplexityEstimator(validation.RecursionEstimator{MaxDepth: n})
}

// ComplexityEstimator registers an estimator that is run during query
// validation; all registered estimators are passed along to
// validation.Validate. Multiple estimators may be registered on one schema.
func ComplexityEstimator(estimator validation.ComplexityEstimator) SchemaOpt {
	return func(s *Schema) {
		s.complexityEstimators = append(s.complexityEstimators, estimator)
	}
}

// MaxParallelism specifies the maximum number of resolvers per request allowed to run in parallel. The default is 10.
func MaxParallelism(n int) SchemaOpt {
return func(s *Schema) {
Expand Down Expand Up @@ -151,7 +169,7 @@ func (s *Schema) Validate(queryString string) []*errors.QueryError {
return []*errors.QueryError{qErr}
}

return validation.Validate(s.schema, doc, nil, s.maxDepth)
return validation.Validate(s.schema, doc, nil, s.maxDepth, s.complexityEstimators)
}

// Exec executes the given query with the schema's resolver. It panics if the schema was created
Expand All @@ -171,7 +189,7 @@ func (s *Schema) exec(ctx context.Context, queryString string, operationName str
}

validationFinish := s.validationTracer.TraceValidation()
errs := validation.Validate(s.schema, doc, variables, s.maxDepth)
errs := validation.Validate(s.schema, doc, variables, s.maxDepth, s.complexityEstimators)
validationFinish(errs)
if len(errs) != 0 {
return &Response{Errors: errs}
Expand Down
136 changes: 136 additions & 0 deletions internal/validation/complexity.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,136 @@
package validation

import (
"github.com/graph-gophers/graphql-go/errors"
"github.com/graph-gophers/graphql-go/internal/query"
"github.com/graph-gophers/graphql-go/internal/schema"
)

// ComplexityEstimator decides whether a parsed query exceeds some
// complexity bound. DoEstimate reports true when the limit is exceeded;
// implementations record the details as errors on the opContext.
type ComplexityEstimator interface {
	DoEstimate(c *opContext, sels []query.Selection) bool
}

// SimpleEstimator scores a query by counting selected fields: each field
// costs 1, summed recursively through inline fragments and fragment
// spreads. A MaxComplexity of 0 disables the check.
type SimpleEstimator struct {
	MaxComplexity int
}

// DoEstimate reports whether the total field count of sels exceeds
// e.MaxComplexity. A limit of 0 disables the check entirely, in which
// case no estimation (and no error reporting) takes place.
func (e SimpleEstimator) DoEstimate(c *opContext, sels []query.Selection) bool {
	if e.MaxComplexity == 0 {
		return false
	}

	return e.doSimpleEstimate(c, sels) > e.MaxComplexity
}

// doSimpleEstimate returns the field-count complexity of sels, recursing
// through sub-selections and fragments. As soon as the running total
// exceeds e.MaxComplexity it records a MaxComplexityExceeded error at the
// offending location and returns early with the partial total.
func (e SimpleEstimator) doSimpleEstimate(c *opContext, sels []query.Selection) int {
	complexity := 0

	for _, sel := range sels {
		var loc errors.Location
		switch sel := sel.(type) {
		case *query.Field:
			loc = sel.Alias.Loc
			// A field costs 1 plus the cost of its sub-selections.
			complexity += e.doSimpleEstimate(c, sel.Selections) + 1
		case *query.InlineFragment:
			loc = sel.Loc
			// Fragments add no cost of their own; only their contents count.
			complexity += e.doSimpleEstimate(c, sel.Selections)
		case *query.FragmentSpread:
			frag := c.doc.Fragments.Get(sel.Name.Name)
			if frag == nil {
				// Unknown fragment: report and skip rather than abort the
				// whole estimation.
				c.addErr(sel.Loc, "MaxComplexityEvaluationError", "Unknown fragment %q. Unable to evaluate complexity.", sel.Name.Name)
				continue
			}
			loc = frag.Loc
			complexity += e.doSimpleEstimate(c, frag.Selections)
		}

		// Bail out as soon as the limit is crossed so huge queries are not
		// walked in full. loc is always set here: the continue above skips
		// this check for unresolved fragment spreads.
		if complexity > e.MaxComplexity {
			c.addErr(loc, "MaxComplexityExceeded",
				"The query exceeds the maximum complexity of %d. Actual complexity is %d.", e.MaxComplexity, complexity)

			return complexity
		}
	}

	return complexity
}

// RecursionEstimator rejects queries that re-enter the same type more than
// MaxDepth times along a single selection path. A MaxDepth of 0 disables
// the check.
type RecursionEstimator struct {
	MaxDepth int
}

// DoEstimate reports whether sels recurses into any type more than
// e.MaxDepth times. A limit of 0 disables the check entirely.
func (e RecursionEstimator) DoEstimate(c *opContext, sels []query.Selection) bool {
	if e.MaxDepth == 0 {
		return false
	}

	// NOTE(review): assumes at least one operation is present in c.ops —
	// confirm callers guarantee this before reaching validation.
	entry := getEntryPoint(c.schema, c.ops[0])

	return e.doRecursivelyVisitSelections(c, sels, visitedSels{}, entry)
}

// visitedSels counts, per named type (keyed by its string form), how many
// times that type has been entered on the current selection path.
type visitedSels map[string]int

// copy returns an independent clone of s so sibling branches of the
// selection tree do not share visit counts.
func (s visitedSels) copy() visitedSels {
	clone := make(visitedSels, len(s))
	for typeName, count := range s {
		clone[typeName] = count
	}

	return clone
}

func (e RecursionEstimator) doRecursivelyVisitSelections(
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

How is this different than the normal MaxDepth check?

c *opContext, sels []query.Selection, visited visitedSels, t schema.NamedType) bool {

fields := fields(t)

exceeded := false

for _, sel := range sels {
switch sel := sel.(type) {
case *query.Field:
fieldName := sel.Name.Name
switch fieldName {
case "__typename", "__schema", "__type":
continue
default:
if sel.Selections == nil {
continue
}

if f := fields.Get(fieldName); f != nil {
v := visited.copy()

if depth, ok := v[f.Type.String()]; ok {
v[f.Type.String()] = depth + 1
} else {
v[f.Type.String()] = 1
}

currentDepth := v[f.Type.String()]
if currentDepth > e.MaxDepth {
c.addErr(sel.Alias.Loc, "MaxDepthRecursionExceeded",
"The query exceeds the maximum depth recursion of %d. Actual is %d.",
e.MaxDepth, currentDepth)

return true
}

exceeded = e.doRecursivelyVisitSelections(c, sel.Selections, v, unwrapType(f.Type))
}
}
case *query.InlineFragment:
exceeded = e.doRecursivelyVisitSelections(c, sel.Selections, visited, unwrapType(t))
case *query.FragmentSpread:
if frag := c.doc.Fragments.Get(sel.Name.Name); frag != nil {
exceeded = e.doRecursivelyVisitSelections(c, frag.Selections, visited, c.schema.Types[frag.On.Name])
}
}
}

return exceeded
}
Loading