From ff006049743c4e102bb81904c0ddf978cc00380b Mon Sep 17 00:00:00 2001 From: Tyler Southwick Date: Fri, 24 Sep 2021 14:25:06 -0700 Subject: [PATCH 1/2] add InvokeOption for Names --- dig.go | 47 ++++++++++++++++++++++++++++++++++++++++++++++- dig_test.go | 38 ++++++++++++++++++++++++++++++++++++++ 2 files changed, 84 insertions(+), 1 deletion(-) diff --git a/dig.go b/dig.go index e99cc9eb..ad3bfe51 100644 --- a/dig.go +++ b/dig.go @@ -282,10 +282,28 @@ func LocationForPC(pc uintptr) ProvideOption { }) } +type invokeOptions struct { + Names []string +} + +func (*invokeOptions) Validate() error { + return nil +} + // An InvokeOption modifies the default behavior of Invoke. It's included for // future functionality; currently, there are no concrete implementations. type InvokeOption interface { - unimplemented() + applyInvokeOption(*invokeOptions) +} + +type invokeOptionFunc func(*invokeOptions) + +func (f invokeOptionFunc) applyInvokeOption(opts *invokeOptions) { f(opts) } + +func Names(names ...string) InvokeOption { + return invokeOptionFunc(func(opts *invokeOptions) { + opts.Names = names + }) } // Container is a directed acyclic graph of types and their dependencies. @@ -566,11 +584,38 @@ func (c *Container) Invoke(function interface{}, opts ...InvokeOption) error { return errf("can't invoke non-function %v (type %v)", function, ftype) } + var options invokeOptions + for _, o := range opts { + o.applyInvokeOption(&options) + } + if err := options.Validate(); err != nil { + return err + } + pl, err := newParamList(ftype) if err != nil { return err } + if len(pl.Params) < len(options.Names) { + return errf("can't invoke function with more names=%s than operands=%s", options.Names, ftype) + } + + updatedParams := make([]param, len(pl.Params)) + for i, p := range pl.Params { + if i < len(options.Names) { + if ps, ok := pl.Params[i].(paramSingle); ok { + ps.Name = options.Names[i] + updatedParams[i] = ps + } else { + return errf("can't have a named param (%s) that is not a paramSingle (%s)", options.Names[i], pl.Params[i]) + } + } else { + updatedParams[i] = p + } + } + pl.Params = updatedParams + if err := shallowCheckDependencies(c, pl); err != nil { return errMissingDependencies{ Func: digreflect.InspectFunc(function), diff --git a/dig_test.go b/dig_test.go index a66ae0c6..7e2ff7ad 100644 --- a/dig_test.go +++ b/dig_test.go @@ -536,6 +536,25 @@ func TestEndToEndSuccess(t *testing.T) { }), "invoke should succeed, pulling out two named instances") }) + t.Run("named instances can be invoked Name option", func(t *testing.T) { + c := New() + + type A struct{ idx int } + + buildConstructor := func(idx int) func() A { + return func() A { return A{idx: idx} } + } + + require.NoError(t, c.Provide(buildConstructor(1), Name("first"))) + require.NoError(t, c.Provide(buildConstructor(2), Name("second"))) + require.NoError(t, c.Provide(buildConstructor(3), Name("third"))) + + require.NoError(t, c.Invoke(func(a1 A, a3 A) { + assert.Equal(t, 1, a1.idx) + assert.Equal(t, 3, a3.idx) + }, Names("first", "third")), "invoke should succeed, using two named instances") + }) + t.Run("named and unnamed instances coexist", func(t *testing.T) { c := New() type A struct{ idx int } @@ -561,6 +580,25 @@ func TestEndToEndSuccess(t *testing.T) { })) }) + t.Run("named and unnamed instances can be invoked with Names option", func(t *testing.T) { + c := New() + + type A struct{ idx int } + + buildConstructor := func(idx int) func() A { + return func() A { return A{idx: idx} } + } + + require.NoError(t, c.Provide(buildConstructor(1), Name("first"))) + require.NoError(t, c.Provide(buildConstructor(2), Name("second"))) + require.NoError(t, c.Provide(buildConstructor(3))) + + require.NoError(t, c.Invoke(func(a1 A, a3 A) { + assert.Equal(t, 1, a1.idx) + assert.Equal(t, 3, a3.idx) + }, Names("first")), "invoke should succeed, using two named instances") + }) + t.Run("named instances recurse", func(t *testing.T) { c := New() type A struct{ idx int } From 1b079e3ea67876cf32625ced4734a5aac466fbe8 Mon Sep 17 00:00:00 2001 From: Tyler Southwick Date: Fri, 24 Sep 2021 16:29:36 -0700 Subject: [PATCH 2/2] add support for Names to ProvideOption --- dig.go | 46 +++++++++++++++++++++------------------------- dig_test.go | 20 ++++++++++++++++++++ param.go | 21 ++++++++++++++++----- param_test.go | 4 ++-- 4 files changed, 59 insertions(+), 32 deletions(-) diff --git a/dig.go b/dig.go index ad3bfe51..e51d7bc9 100644 --- a/dig.go +++ b/dig.go @@ -65,6 +65,7 @@ type provideOptions struct { Info *ProvideInfo As []interface{} Location *digreflect.Func + Names []string } func (o *provideOptions) Validate() error { @@ -300,10 +301,22 @@ type invokeOptionFunc func(*invokeOptions) func (f invokeOptionFunc) applyInvokeOption(opts *invokeOptions) { f(opts) } -func Names(names ...string) InvokeOption { - return invokeOptionFunc(func(opts *invokeOptions) { - opts.Names = names - }) +type InvokeAndProvideOption interface { + InvokeOption + ProvideOption +} + +type namesOption []string + +func (n namesOption) applyInvokeOption(opts *invokeOptions) { + opts.Names = n +} +func (n namesOption) applyProvideOption(opts *provideOptions) { + opts.Names = n +} + +func Names(names ...string) InvokeAndProvideOption { + return namesOption(names) } // Container is a directed acyclic graph of types and their dependencies. @@ -592,30 +605,11 @@ func (c *Container) Invoke(function interface{}, opts ...InvokeOption) error { return err } - pl, err := newParamList(ftype) + pl, err := newParamList(ftype, options.Names) if err != nil { return err } - if len(pl.Params) < len(options.Names) { - return errf("can't invoke function with more names=%s than operands=%s", options.Names, ftype) - } - - updatedParams := make([]param, len(pl.Params)) - for i, p := range pl.Params { - if i < len(options.Names) { - if ps, ok := pl.Params[i].(paramSingle); ok { - ps.Name = options.Names[i] - updatedParams[i] = ps - } else { - return errf("can't have a named param (%s) that is not a paramSingle (%s)", options.Names[i], pl.Params[i]) - } - } else { - updatedParams[i] = p - } - } - pl.Params = updatedParams - if err := shallowCheckDependencies(c, pl); err != nil { return errMissingDependencies{ Func: digreflect.InspectFunc(function), @@ -669,6 +663,7 @@ func (c *Container) provide(ctor interface{}, opts provideOptions) error { ResultGroup: opts.Group, ResultAs: opts.As, Location: opts.Location, + ParamNames: opts.Names, }, ) if err != nil { @@ -887,6 +882,7 @@ type nodeOptions struct { ResultGroup string ResultAs []interface{} Location *digreflect.Func + ParamNames []string } func newNode(ctor interface{}, opts nodeOptions) (*node, error) { @@ -894,7 +890,7 @@ func newNode(ctor interface{}, opts nodeOptions) (*node, error) { ctype := cval.Type() cptr := cval.Pointer() - params, err := newParamList(ctype) + params, err := newParamList(ctype, opts.ParamNames) if err != nil { return nil, err } diff --git a/dig_test.go b/dig_test.go index 7e2ff7ad..74511de5 100644 --- a/dig_test.go +++ b/dig_test.go @@ -536,6 +536,26 @@ func TestEndToEndSuccess(t *testing.T) { }), "invoke should succeed, pulling out two named instances") }) + t.Run("named instances can be used to Provide another instance", func(t *testing.T) { + c := New() + + type A struct{ idx int } + + buildConstructor := func(idx int) func() A { + return func() A { return A{idx: idx} } + } + + require.NoError(t, c.Provide(buildConstructor(1), Name("first"))) + require.NoError(t, c.Provide(buildConstructor(2), Name("second"))) + require.NoError(t, c.Provide(func(a A) int { + return a.idx + 5 + }, Names("first"))) + + require.NoError(t, c.Invoke(func(i int) { + assert.Equal(t, 6, i) + }), "invoke should succeed, pulling out one named instances") + }) + t.Run("named instances can be invoked Name option", func(t *testing.T) { c := New() diff --git a/param.go b/param.go index df7868f3..3b35f344 100644 --- a/param.go +++ b/param.go @@ -62,11 +62,14 @@ var ( // newParam builds a param from the given type. If the provided type is a // dig.In struct, an paramObject will be returned. -func newParam(t reflect.Type) (param, error) { +func newParam(t reflect.Type, paramName string) (param, error) { switch { case IsOut(t) || (t.Kind() == reflect.Ptr && IsOut(t.Elem())) || embedsType(t, _outPtrType): return nil, errf("cannot depend on result objects", "%v embeds a dig.Out", t) case IsIn(t): + if paramName != "" { + return nil, errf("cannot have a paramName (%s) with a struct that has dig.In", paramName) + } return newParamObject(t) case embedsType(t, _inPtrType): return nil, errf( @@ -77,7 +80,7 @@ func newParam(t reflect.Type) (param, error) { "cannot depend on a pointer to a parameter object, use a value instead", "%v is a pointer to a struct that embeds dig.In", t) default: - return paramSingle{Type: t}, nil + return paramSingle{Type: t, Name: paramName}, nil } } @@ -158,7 +161,7 @@ func (pl paramList) DotParam() []*dot.Param { // // Variadic arguments of a constructor are ignored and not included as // dependencies. -func newParamList(ctype reflect.Type) (paramList, error) { +func newParamList(ctype reflect.Type, names []string) (paramList, error) { numArgs := ctype.NumIn() if ctype.IsVariadic() { // NOTE: If the function is variadic, we skip the last argument @@ -171,8 +174,16 @@ func newParamList(ctype reflect.Type) (paramList, error) { Params: make([]param, 0, numArgs), } + if numArgs < len(names) { + return pl, errf("can't create a constructor with more names=%s than args=%s", names, ctype) + } + for i := 0; i < numArgs; i++ { - p, err := newParam(ctype.In(i)) + name := "" + if i < len(names) { + name = names[i] + } + p, err := newParam(ctype.In(i), name) if err != nil { return pl, errf("bad argument %d", i+1, err) } @@ -370,7 +381,7 @@ func newParamObjectField(idx int, f reflect.StructField) (paramObjectField, erro default: var err error - p, err = newParam(f.Type) + p, err = newParam(f.Type, "") if err != nil { return pof, err } diff --git a/param_test.go b/param_test.go index 68f0cde2..e17e3ea5 100644 --- a/param_test.go +++ b/param_test.go @@ -30,7 +30,7 @@ import ( ) func TestParamListBuild(t *testing.T) { - p, err := newParamList(reflect.TypeOf(func() io.Writer { return nil })) + p, err := newParamList(reflect.TypeOf(func() io.Writer { return nil }), []string{}) require.NoError(t, err) assert.Panics(t, func() { p.Build(New()) @@ -238,7 +238,7 @@ func TestParamVisitorChecksEverything(t *testing.T) { pl, err := newParamList(reflect.TypeOf(func(io.Reader, params, io.Writer) { t.Fatalf("this function should not be called") - })) + }), []string{}) require.NoError(t, err) idx := 0