Skip to content

Commit

Permalink
internal/astutil/cursor: add Cursor.Child(Node) Cursor
Browse files Browse the repository at this point in the history
This method returns the cursor for a direct child,
more efficiently than FindNode.
Also, add edge.Kind.Get method.

+ tests

Change-Id: I1176ac55713ef0c06b02a1e3a9c64f530caa9a09
Reviewed-on: https://go-review.googlesource.com/c/tools/+/642936
LUCI-TryBot-Result: Go LUCI <[email protected]>
Reviewed-by: Robert Findley <[email protected]>
Commit-Queue: Alan Donovan <[email protected]>
Auto-Submit: Alan Donovan <[email protected]>
  • Loading branch information
adonovan authored and gopherbot committed Feb 4, 2025
1 parent f912a4f commit e9f7be9
Show file tree
Hide file tree
Showing 3 changed files with 61 additions and 4 deletions.
29 changes: 29 additions & 0 deletions internal/astutil/cursor/cursor.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
package cursor

import (
"fmt"
"go/ast"
"go/token"
"iter"
Expand Down Expand Up @@ -227,6 +228,34 @@ func (c Cursor) Edge() (edge.Kind, int) {
return unpackEdgeKindAndIndex(events[pop].parent)
}

// Child returns the cursor for n, which must be a direct child of c's Node.
//
// Child must not be called on the Root node (whose [Cursor.Node] returns nil).
func (c Cursor) Child(n ast.Node) Cursor {
if c.index < 0 {
panic("Cursor.Child called on Root node")
}

if false {
// reference implementation
for child := range c.Children() {
if child.Node() == n {
return child
}
}

} else {
// optimized implementation
events := c.events()
for i := c.index + 1; events[i].index > i; i = events[i].index + 1 {
if events[i].node == n {
return Cursor{c.in, i}
}
}
}
panic(fmt.Sprintf("Child(%T): not a child of %v", n, c))
}

// NextSibling returns the cursor for the next sibling node in the same list
// (for example, of files, decls, specs, statements, fields, or expressions) as
// the current node. It returns (zero, false) if the node is the last node in
Expand Down
11 changes: 11 additions & 0 deletions internal/astutil/cursor/cursor_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -360,6 +360,12 @@ func TestCursor_Edge(t *testing.T) {
e.NodeType(), parent.Node())
}

// Check consistency of c.Edge.Get(c.Parent().Node()) == c.Node().
if got := e.Get(parent.Node(), idx); got != cur.Node() {
t.Errorf("cur=%v@%s: %s.Get(cur.Parent().Node(), %d) = %T@%s, want cur.Node()",
cur, netFset.Position(cur.Node().Pos()), e, idx, got, netFset.Position(got.Pos()))
}

// Check that reflection on the parent finds the current node.
fv := reflect.ValueOf(parent.Node()).Elem().FieldByName(e.FieldName())
if idx >= 0 {
Expand All @@ -373,6 +379,11 @@ func TestCursor_Edge(t *testing.T) {
t.Errorf("%v.Edge = (%v, %d); FieldName/Index reflection gave %T@%s, not original node",
cur, e, idx, got, netFset.Position(got.Pos()))
}

// Check that Cursor.Child is the reverse of Parent.
if cur.Parent().Child(cur.Node()) != cur {
t.Errorf("Cursor.Parent.Child = %v, want %v", cur.Parent().Child(cur.Node()), cur)
}
}
}

Expand Down
25 changes: 21 additions & 4 deletions internal/astutil/edge/edge.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,18 +21,34 @@ func (k Kind) String() string {
return "<invalid>"
}
info := fieldInfos[k]
return fmt.Sprintf("%v.%s", info.nodeType.Elem().Name(), info.fieldName)
return fmt.Sprintf("%v.%s", info.nodeType.Elem().Name(), info.name)
}

// NodeType returns the pointer-to-struct type of the ast.Node implementation.
func (k Kind) NodeType() reflect.Type { return fieldInfos[k].nodeType }

// FieldName returns the name of the field.
func (k Kind) FieldName() string { return fieldInfos[k].fieldName }
func (k Kind) FieldName() string { return fieldInfos[k].name }

// FieldType returns the declared type of the field.
func (k Kind) FieldType() reflect.Type { return fieldInfos[k].fieldType }

// Get returns the direct child of n identified by (k, idx).
// n's type must match k.NodeType().
// idx must be a valid slice index, or -1 for a non-slice.
func (k Kind) Get(n ast.Node, idx int) ast.Node {
if k.NodeType() != reflect.TypeOf(n) {
panic(fmt.Sprintf("%v.Get(%T): invalid node type", k, n))
}
v := reflect.ValueOf(n).Elem().Field(fieldInfos[k].index)
if idx != -1 {
v = v.Index(idx) // asserts valid index
} else {
// (The type assertion below asserts that v is not a slice.)
}
return v.Interface().(ast.Node) // may be nil
}

const (
Invalid Kind = iota // for nodes at the root of the traversal

Expand Down Expand Up @@ -156,7 +172,8 @@ var _ = [1 << 7]struct{}{}[maxKind]

type fieldInfo struct {
nodeType reflect.Type // pointer-to-struct type of ast.Node implementation
fieldName string
name string
index int
fieldType reflect.Type
}

Expand All @@ -166,7 +183,7 @@ func info[N ast.Node](fieldName string) fieldInfo {
if !ok {
panic(fieldName)
}
return fieldInfo{nodePtrType, fieldName, f.Type}
return fieldInfo{nodePtrType, fieldName, f.Index[0], f.Type}
}

var fieldInfos = [...]fieldInfo{
Expand Down

0 comments on commit e9f7be9

Please sign in to comment.