Skip to content

Commit

Permalink
starlark: tweak how UnpackArgs enforces optional args
Browse files Browse the repository at this point in the history
Fixes #544

Signed-off-by: Nick Santos <[email protected]>
  • Loading branch information
nicks committed May 6, 2024
1 parent 9b43f0a commit 3622e9b
Show file tree
Hide file tree
Showing 2 changed files with 70 additions and 27 deletions.
41 changes: 41 additions & 0 deletions starlark/eval_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1109,3 +1109,44 @@ f(1)
t.Errorf("env() returned %s, want %s", got, want)
}
}

func TestUnpackArgsOptionalInference(t *testing.T) {
// success
kwargs := []starlark.Tuple{
{starlark.String("x"), starlark.MakeInt(1)},
{starlark.String("y"), starlark.MakeInt(2)},
}
var x, y, z int
if err := starlark.UnpackArgs("unpack", nil, kwargs,
"x", &x, "y?", &y, "z", &z); err != nil {
t.Errorf("UnpackArgs failed: %v", err)
}
if x != 1 {
t.Errorf("for x, got %v, want %v", x, 1)
}
if y != 2 {
t.Errorf("for y, got %v, want %v", y, 2)
}
if z != 0 {
t.Errorf("for z, got %v, want %v", z, 0)
}

// success
args := starlark.Tuple{starlark.MakeInt(1), starlark.MakeInt(2)}
x = 0
y = 0
z = 0
if err := starlark.UnpackArgs("unpack", args, nil,
"x", &x, "y?", &y, "z", &z); err != nil {
t.Errorf("UnpackArgs failed: %v", err)
}
if x != 1 {
t.Errorf("for x, got %v, want %v", x, 1)
}
if y != 2 {
t.Errorf("for y, got %v, want %v", y, 2)
}
if z != 0 {
t.Errorf("for z, got %v, want %v", z, 0)
}
}
56 changes: 29 additions & 27 deletions starlark/unpack.go
Original file line number Diff line number Diff line change
Expand Up @@ -44,23 +44,23 @@ type Unpacker interface {
//
// Examples:
//
// var (
// a Value
// b = MakeInt(42)
// c Value = starlark.None
// )
// var (
// a Value
// b = MakeInt(42)
// c Value = starlark.None
// )
//
// // 1. mixed parameters, like def f(a, b=42, c=None).
// err := UnpackArgs("f", args, kwargs, "a", &a, "b?", &b, "c?", &c)
// // 1. mixed parameters, like def f(a, b=42, c=None).
// err := UnpackArgs("f", args, kwargs, "a", &a, "b?", &b, "c?", &c)
//
// // 2. keyword parameters only, like def f(*, a, b, c=None).
// if len(args) > 0 {
// return fmt.Errorf("f: unexpected positional arguments")
// }
// err := UnpackArgs("f", args, kwargs, "a", &a, "b?", &b, "c?", &c)
// // 2. keyword parameters only, like def f(*, a, b, c=None).
// if len(args) > 0 {
// return fmt.Errorf("f: unexpected positional arguments")
// }
// err := UnpackArgs("f", args, kwargs, "a", &a, "b?", &b, "c?", &c)
//
// // 3. positional parameters only, like def f(a, b=42, c=None, /) in Python 3.8.
// err := UnpackPositionalArgs("f", args, kwargs, 1, &a, &b, &c)
// // 3. positional parameters only, like def f(a, b=42, c=None, /) in Python 3.8.
// err := UnpackPositionalArgs("f", args, kwargs, 1, &a, &b, &c)
//
// More complex forms such as def f(a, b=42, *args, c, d=123, **kwargs)
// require additional logic, but their need in built-ins is exceedingly rare.
Expand All @@ -79,17 +79,16 @@ type Unpacker interface {
// for the zero values of variables of type *List, *Dict, Callable, or
// Iterable. For example:
//
// // def myfunc(d=None, e=[], f={})
// var (
// d Value
// e *List
// f *Dict
// )
// err := UnpackArgs("myfunc", args, kwargs, "d?", &d, "e?", &e, "f?", &f)
// if d == nil { d = None; }
// if e == nil { e = new(List); }
// if f == nil { f = new(Dict); }
//
// // def myfunc(d=None, e=[], f={})
// var (
// d Value
// e *List
// f *Dict
// )
// err := UnpackArgs("myfunc", args, kwargs, "d?", &d, "e?", &e, "f?", &f)
// if d == nil { d = None; }
// if e == nil { e = new(List); }
// if f == nil { f = new(Dict); }
func UnpackArgs(fnname string, args Tuple, kwargs []Tuple, pairs ...interface{}) error {
nparams := len(pairs) / 2
var defined intset
Expand Down Expand Up @@ -164,12 +163,15 @@ kwloop:
}

// Check that all non-optional parameters are defined.
// (We needn't check the first len(args).)
for i := len(args); i < nparams; i++ {
for i := 0; i < nparams; i++ {
name := pairs[2*i].(string)
if strings.HasSuffix(name, "?") {
break // optional
}
// (We needn't check the first len(args).)
if i < len(args) {
continue
}
if !defined.get(i) {
return fmt.Errorf("%s: missing argument for %s", fnname, name)
}
Expand Down

0 comments on commit 3622e9b

Please sign in to comment.