From e9f7be9f13468661413bba805228434c534e6dbc Mon Sep 17 00:00:00 2001 From: Alan Donovan Date: Wed, 15 Jan 2025 12:29:12 -0500 Subject: [PATCH] internal/astutil/cursor: add Cursor.Child(Node) Cursor This method returns the cursor for a direct child, more efficiently than FindNode. Also, add edge.Kind.Get method. + tests Change-Id: I1176ac55713ef0c06b02a1e3a9c64f530caa9a09 Reviewed-on: https://go-review.googlesource.com/c/tools/+/642936 LUCI-TryBot-Result: Go LUCI Reviewed-by: Robert Findley Commit-Queue: Alan Donovan Auto-Submit: Alan Donovan --- internal/astutil/cursor/cursor.go | 29 ++++++++++++++++++++++++++ internal/astutil/cursor/cursor_test.go | 11 ++++++++++ internal/astutil/edge/edge.go | 25 ++++++++++++++++++---- 3 files changed, 61 insertions(+), 4 deletions(-) diff --git a/internal/astutil/cursor/cursor.go b/internal/astutil/cursor/cursor.go index 38a35f64ce0..1052f65acfb 100644 --- a/internal/astutil/cursor/cursor.go +++ b/internal/astutil/cursor/cursor.go @@ -15,6 +15,7 @@ package cursor import ( + "fmt" "go/ast" "go/token" "iter" @@ -227,6 +228,34 @@ func (c Cursor) Edge() (edge.Kind, int) { return unpackEdgeKindAndIndex(events[pop].parent) } +// Child returns the cursor for n, which must be a direct child of c's Node. +// +// Child must not be called on the Root node (whose [Cursor.Node] returns nil). +func (c Cursor) Child(n ast.Node) Cursor { + if c.index < 0 { + panic("Cursor.Child called on Root node") + } + + if false { + // reference implementation + for child := range c.Children() { + if child.Node() == n { + return child + } + } + + } else { + // optimized implementation + events := c.events() + for i := c.index + 1; events[i].index > i; i = events[i].index + 1 { + if events[i].node == n { + return Cursor{c.in, i} + } + } + } + panic(fmt.Sprintf("Child(%T): not a child of %v", n, c)) +} + // NextSibling returns the cursor for the next sibling node in the same list // (for example, of files, decls, specs, statements, fields, or expressions) as // the current node. It returns (zero, false) if the node is the last node in diff --git a/internal/astutil/cursor/cursor_test.go b/internal/astutil/cursor/cursor_test.go index e04f8c24b89..01f791f2833 100644 --- a/internal/astutil/cursor/cursor_test.go +++ b/internal/astutil/cursor/cursor_test.go @@ -360,6 +360,12 @@ func TestCursor_Edge(t *testing.T) { e.NodeType(), parent.Node()) } + // Check consistency of c.Edge.Get(c.Parent().Node()) == c.Node(). + if got := e.Get(parent.Node(), idx); got != cur.Node() { + t.Errorf("cur=%v@%s: %s.Get(cur.Parent().Node(), %d) = %T@%s, want cur.Node()", + cur, netFset.Position(cur.Node().Pos()), e, idx, got, netFset.Position(got.Pos())) + } + // Check that reflection on the parent finds the current node. fv := reflect.ValueOf(parent.Node()).Elem().FieldByName(e.FieldName()) if idx >= 0 { @@ -373,6 +379,11 @@ func TestCursor_Edge(t *testing.T) { t.Errorf("%v.Edge = (%v, %d); FieldName/Index reflection gave %T@%s, not original node", cur, e, idx, got, netFset.Position(got.Pos())) } + + // Check that Cursor.Child is the reverse of Parent. + if cur.Parent().Child(cur.Node()) != cur { + t.Errorf("Cursor.Parent.Child = %v, want %v", cur.Parent().Child(cur.Node()), cur) + } } } diff --git a/internal/astutil/edge/edge.go b/internal/astutil/edge/edge.go index bf945a8f632..4f6ccfd6e5e 100644 --- a/internal/astutil/edge/edge.go +++ b/internal/astutil/edge/edge.go @@ -21,18 +21,34 @@ func (k Kind) String() string { return "" } info := fieldInfos[k] - return fmt.Sprintf("%v.%s", info.nodeType.Elem().Name(), info.fieldName) + return fmt.Sprintf("%v.%s", info.nodeType.Elem().Name(), info.name) } // NodeType returns the pointer-to-struct type of the ast.Node implementation. func (k Kind) NodeType() reflect.Type { return fieldInfos[k].nodeType } // FieldName returns the name of the field. -func (k Kind) FieldName() string { return fieldInfos[k].fieldName } +func (k Kind) FieldName() string { return fieldInfos[k].name } // FieldType returns the declared type of the field. func (k Kind) FieldType() reflect.Type { return fieldInfos[k].fieldType } +// Get returns the direct child of n identified by (k, idx). +// n's type must match k.NodeType(). +// idx must be a valid slice index, or -1 for a non-slice. +func (k Kind) Get(n ast.Node, idx int) ast.Node { + if k.NodeType() != reflect.TypeOf(n) { + panic(fmt.Sprintf("%v.Get(%T): invalid node type", k, n)) + } + v := reflect.ValueOf(n).Elem().Field(fieldInfos[k].index) + if idx != -1 { + v = v.Index(idx) // asserts valid index + } else { + // (The type assertion below asserts that v is not a slice.) + } + return v.Interface().(ast.Node) // may be nil +} + const ( Invalid Kind = iota // for nodes at the root of the traversal @@ -156,7 +172,8 @@ var _ = [1 << 7]struct{}{}[maxKind] type fieldInfo struct { nodeType reflect.Type // pointer-to-struct type of ast.Node implementation - fieldName string + name string + index int fieldType reflect.Type } @@ -166,7 +183,7 @@ func info[N ast.Node](fieldName string) fieldInfo { if !ok { panic(fieldName) } - return fieldInfo{nodePtrType, fieldName, f.Type} + return fieldInfo{nodePtrType, fieldName, f.Index[0], f.Type} } var fieldInfos = [...]fieldInfo{