-
-
Notifications
You must be signed in to change notification settings - Fork 1k
compiler: implement method-set based AssignableTo and Implements #5304
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 12 commits
9dd4964
a9d5c29
1cb67a2
1b13c9f
9fb3536
d73f9b8
2767f51
a4eb3f8
bcd81c0
39f7e8e
40ac70e
60e1f18
06c884c
84579fc
9929018
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -10,6 +10,7 @@ import ( | |
| "fmt" | ||
| "go/token" | ||
| "go/types" | ||
| "sort" | ||
| "strconv" | ||
| "strings" | ||
|
|
||
|
|
@@ -183,6 +184,16 @@ func (c *compilerContext) getTypeCode(typ types.Type) llvm.Value { | |
| typeFieldTypes := []*types.Var{ | ||
| types.NewVar(token.NoPos, nil, "kind", types.Typ[types.Int8]), | ||
| } | ||
| // Compute the method set value for types that support methods. | ||
| var methods []*types.Func | ||
| for i := 0; i < ms.Len(); i++ { | ||
| methods = append(methods, ms.At(i).Obj().(*types.Func)) | ||
| } | ||
| methodSetType := types.NewStruct([]*types.Var{ | ||
| types.NewVar(token.NoPos, nil, "length", types.Typ[types.Uintptr]), | ||
| types.NewVar(token.NoPos, nil, "methods", types.NewArray(types.Typ[types.UnsafePointer], int64(len(methods)))), | ||
| }, nil) | ||
| methodSetValue := c.getMethodSetValue(methods) | ||
| switch typ := typ.(type) { | ||
| case *types.Basic: | ||
| typeFieldTypes = append(typeFieldTypes, | ||
|
|
@@ -199,6 +210,13 @@ func (c *compilerContext) getTypeCode(typ types.Type) llvm.Value { | |
| types.NewVar(token.NoPos, nil, "ptrTo", types.Typ[types.UnsafePointer]), | ||
| types.NewVar(token.NoPos, nil, "underlying", types.Typ[types.UnsafePointer]), | ||
| types.NewVar(token.NoPos, nil, "pkgpath", types.Typ[types.UnsafePointer]), | ||
| ) | ||
| if len(methods) > 0 { | ||
| typeFieldTypes = append(typeFieldTypes, | ||
| types.NewVar(token.NoPos, nil, "methods", methodSetType), | ||
| ) | ||
| } | ||
| typeFieldTypes = append(typeFieldTypes, | ||
| types.NewVar(token.NoPos, nil, "name", types.NewArray(types.Typ[types.Int8], int64(len(pkgname)+1+len(name)+1))), | ||
| ) | ||
| case *types.Chan: | ||
|
|
@@ -218,6 +236,11 @@ func (c *compilerContext) getTypeCode(typ types.Type) llvm.Value { | |
| types.NewVar(token.NoPos, nil, "numMethods", types.Typ[types.Uint16]), | ||
| types.NewVar(token.NoPos, nil, "elementType", types.Typ[types.UnsafePointer]), | ||
| ) | ||
| if len(methods) > 0 { | ||
| typeFieldTypes = append(typeFieldTypes, | ||
| types.NewVar(token.NoPos, nil, "methods", methodSetType), | ||
| ) | ||
| } | ||
| case *types.Array: | ||
| typeFieldTypes = append(typeFieldTypes, | ||
| types.NewVar(token.NoPos, nil, "numMethods", types.Typ[types.Uint16]), | ||
|
|
@@ -242,11 +265,16 @@ func (c *compilerContext) getTypeCode(typ types.Type) llvm.Value { | |
| types.NewVar(token.NoPos, nil, "numFields", types.Typ[types.Uint16]), | ||
| types.NewVar(token.NoPos, nil, "fields", types.NewArray(c.getRuntimeType("structField"), int64(typ.NumFields()))), | ||
| ) | ||
| if len(methods) > 0 { | ||
| typeFieldTypes = append(typeFieldTypes, | ||
| types.NewVar(token.NoPos, nil, "methods", methodSetType), | ||
| ) | ||
| } | ||
| case *types.Interface: | ||
| typeFieldTypes = append(typeFieldTypes, | ||
| types.NewVar(token.NoPos, nil, "ptrTo", types.Typ[types.UnsafePointer]), | ||
| types.NewVar(token.NoPos, nil, "methods", methodSetType), | ||
| ) | ||
| // TODO: methods | ||
| case *types.Signature: | ||
| typeFieldTypes = append(typeFieldTypes, | ||
| types.NewVar(token.NoPos, nil, "ptrTo", types.Typ[types.UnsafePointer]), | ||
|
|
@@ -292,14 +320,21 @@ func (c *compilerContext) getTypeCode(typ types.Type) llvm.Value { | |
| pkgname = pkg.Name() | ||
| } | ||
| pkgPathPtr := c.pkgPathPtr(pkgpath) | ||
| namedNumMethods := uint64(numMethods) | ||
| if len(methods) > 0 { | ||
| namedNumMethods |= 0x8000 // numMethodHasMethodSet flag | ||
|
deadprogram marked this conversation as resolved.
Outdated
|
||
| } | ||
| typeFields = []llvm.Value{ | ||
| llvm.ConstInt(c.ctx.Int16Type(), uint64(numMethods), false), // numMethods | ||
| c.getTypeCode(types.NewPointer(typ)), // ptrTo | ||
| c.getTypeCode(typ.Underlying()), // underlying | ||
| pkgPathPtr, // pkgpath pointer | ||
| c.ctx.ConstString(pkgname+"."+name+"\x00", false), // name | ||
| llvm.ConstInt(c.ctx.Int16Type(), namedNumMethods, false), // numMethods | ||
| c.getTypeCode(types.NewPointer(typ)), // ptrTo | ||
| c.getTypeCode(typ.Underlying()), // underlying | ||
| pkgPathPtr, // pkgpath pointer | ||
| } | ||
| if len(methods) > 0 { | ||
| typeFields = append(typeFields, methodSetValue) // methods | ||
| } | ||
| metabyte |= 1 << 5 // "named" flag | ||
| typeFields = append(typeFields, c.ctx.ConstString(pkgname+"."+name+"\x00", false)) // name | ||
| metabyte |= 1 << 5 // "named" flag | ||
| case *types.Chan: | ||
| var dir reflectChanDir | ||
| switch typ.Dir() { | ||
|
|
@@ -323,10 +358,17 @@ func (c *compilerContext) getTypeCode(typ types.Type) llvm.Value { | |
| c.getTypeCode(typ.Elem()), // elementType | ||
| } | ||
| case *types.Pointer: | ||
| ptrNumMethods := uint64(numMethods) | ||
| if len(methods) > 0 { | ||
| ptrNumMethods |= 0x8000 // numMethodHasMethodSet flag | ||
| } | ||
| typeFields = []llvm.Value{ | ||
| llvm.ConstInt(c.ctx.Int16Type(), uint64(numMethods), false), // numMethods | ||
| llvm.ConstInt(c.ctx.Int16Type(), ptrNumMethods, false), // numMethods | ||
| c.getTypeCode(typ.Elem()), | ||
| } | ||
| if len(methods) > 0 { | ||
| typeFields = append(typeFields, methodSetValue) | ||
| } | ||
| case *types.Array: | ||
| typeFields = []llvm.Value{ | ||
| llvm.ConstInt(c.ctx.Int16Type(), 0, false), // numMethods | ||
|
|
@@ -353,9 +395,13 @@ func (c *compilerContext) getTypeCode(typ types.Type) llvm.Value { | |
|
|
||
| llvmStructType := c.getLLVMType(typ) | ||
| size := c.targetData.TypeStoreSize(llvmStructType) | ||
| structNumMethods := uint64(numMethods) | ||
| if len(methods) > 0 { | ||
| structNumMethods |= 0x8000 // numMethodHasMethodSet flag | ||
| } | ||
| typeFields = []llvm.Value{ | ||
| llvm.ConstInt(c.ctx.Int16Type(), uint64(numMethods), false), // numMethods | ||
| c.getTypeCode(types.NewPointer(typ)), // ptrTo | ||
| llvm.ConstInt(c.ctx.Int16Type(), structNumMethods, false), // numMethods | ||
| c.getTypeCode(types.NewPointer(typ)), // ptrTo | ||
| pkgPathPtr, | ||
| llvm.ConstInt(c.ctx.Int32Type(), uint64(size), false), // size | ||
| llvm.ConstInt(c.ctx.Int16Type(), uint64(typ.NumFields()), false), // numFields | ||
|
|
@@ -407,9 +453,14 @@ func (c *compilerContext) getTypeCode(typ types.Type) llvm.Value { | |
| })) | ||
| } | ||
| typeFields = append(typeFields, llvm.ConstArray(structFieldType, fields)) | ||
| if len(methods) > 0 { | ||
| typeFields = append(typeFields, methodSetValue) | ||
| } | ||
| case *types.Interface: | ||
| typeFields = []llvm.Value{c.getTypeCode(types.NewPointer(typ))} | ||
| // TODO: methods | ||
| typeFields = []llvm.Value{ | ||
| c.getTypeCode(types.NewPointer(typ)), | ||
| methodSetValue, | ||
| } | ||
| case *types.Signature: | ||
| typeFields = []llvm.Value{c.getTypeCode(types.NewPointer(typ))} | ||
| // TODO: params, return values, etc | ||
|
|
@@ -696,17 +747,11 @@ func (b *builder) createTypeAssert(expr *ssa.TypeAssert) llvm.Value { | |
| // This type assertion always succeeds, so we can just set commaOk to true. | ||
| commaOk = llvm.ConstInt(b.ctx.Int1Type(), 1, true) | ||
| } else { | ||
| // Type assert on interface type with methods. | ||
| // This is a call to an interface type assert function. | ||
| // The interface lowering pass will define this function by filling it | ||
| // with a type switch over all concrete types that implement this | ||
| // interface, and returning whether it's one of the matched types. | ||
| // This is very different from how interface asserts are implemented in | ||
| // the main Go compiler, where the runtime checks whether the type | ||
| // implements each method of the interface. See: | ||
| // https://research.swtch.com/interfaces | ||
| fn := b.getInterfaceImplementsFunc(expr.AssertedType) | ||
| commaOk = b.CreateCall(fn.GlobalValueType(), fn, []llvm.Value{actualTypeNum}, "") | ||
| // Type assert on an interface type with methods. | ||
| // Create a call to a declared-but-not-defined function that will | ||
| // be lowered by the interface lowering pass into a type-ID | ||
| // comparison chain. | ||
| commaOk = b.createInterfaceTypeAssert(intf, actualTypeNum) | ||
| } | ||
| } else { | ||
| name, _ := getTypeCodeName(expr.AssertedType) | ||
|
|
@@ -783,20 +828,74 @@ func (c *compilerContext) getMethodsString(itf *types.Interface) string { | |
| return strings.Join(methods, "; ") | ||
| } | ||
|
|
||
| // getInterfaceImplementsFunc returns a declared function that works as a type | ||
| // switch. The interface lowering pass will define this function. | ||
| func (c *compilerContext) getInterfaceImplementsFunc(assertedType types.Type) llvm.Value { | ||
| s, _ := getTypeCodeName(assertedType.Underlying()) | ||
| fnName := s + ".$typeassert" | ||
| llvmFn := c.mod.NamedFunction(fnName) | ||
| if llvmFn.IsNil() { | ||
| llvmFnType := llvm.FunctionType(c.ctx.Int1Type(), []llvm.Type{c.dataPtrType}, false) | ||
| llvmFn = llvm.AddFunction(c.mod, fnName, llvmFnType) | ||
| c.addStandardDeclaredAttributes(llvmFn) | ||
| methods := c.getMethodsString(assertedType.Underlying().(*types.Interface)) | ||
| llvmFn.AddFunctionAttr(c.ctx.CreateStringAttribute("tinygo-methods", methods)) | ||
| // getInterfaceMethodSet returns a global that contains the method set for an | ||
| // interface type, creating it if needed. | ||
| func (c *compilerContext) getInterfaceMethodSet(t *types.Interface) llvm.Value { | ||
| s, _ := getTypeCodeName(t) | ||
| methodSetName := s + "$itfmethods" | ||
| methodSet := c.mod.NamedGlobal(methodSetName) | ||
| if !methodSet.IsNil() { | ||
| return methodSet | ||
| } | ||
| return llvmFn | ||
|
|
||
| var methods []*types.Func | ||
| for i := 0; i < t.NumMethods(); i++ { | ||
| methods = append(methods, t.Method(i)) | ||
| } | ||
| if len(methods) == 0 { | ||
| panic("unreachable: getInterfaceMethodSet called on empty interface") | ||
| } | ||
|
|
||
| methodSetValue := c.getMethodSetValue(methods) | ||
| methodSet = llvm.AddGlobal(c.mod, methodSetValue.Type(), methodSetName) | ||
| methodSet.SetInitializer(methodSetValue) | ||
| methodSet.SetGlobalConstant(true) | ||
| methodSet.SetLinkage(llvm.LinkOnceODRLinkage) | ||
| methodSet.SetAlignment(c.targetData.ABITypeAlignment(methodSetValue.Type())) | ||
| methodSet.SetUnnamedAddr(true) | ||
|
|
||
| return methodSet | ||
| } | ||
|
|
||
| // getMethodSetValue creates the method set struct value for a list of methods. | ||
| // The struct contains a length and a sorted array of method signature pointers. | ||
| func (c *compilerContext) getMethodSetValue(methods []*types.Func) llvm.Value { | ||
| // Create a sorted list of method signature global names. | ||
| type methodRef struct { | ||
| name string | ||
| value llvm.Value | ||
| } | ||
| var refs []methodRef | ||
| for _, method := range methods { | ||
| name := method.Name() | ||
| if !token.IsExported(name) { | ||
| name = method.Pkg().Path() + "." + name | ||
| } | ||
| s, _ := getTypeCodeName(method.Type()) | ||
| globalName := "reflect/types.signature:" + name + ":" + s | ||
| value := c.mod.NamedGlobal(globalName) | ||
| if value.IsNil() { | ||
| value = llvm.AddGlobal(c.mod, c.ctx.Int8Type(), globalName) | ||
| value.SetInitializer(llvm.ConstNull(c.ctx.Int8Type())) | ||
| value.SetGlobalConstant(true) | ||
| value.SetLinkage(llvm.LinkOnceODRLinkage) | ||
| value.SetAlignment(1) | ||
|
Comment on lines
+864
to
+868
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. There might be ways to optimize this, since all it really needs is unique IDs. Anyway, just ideas for the future it looks good enough for now.
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yeah, so I did think about using IDs, but there were some challenges that made me not do that. Namely that I wasn't sure that it would be easy to DCE because with the pointers, at least I think the LLVM stack knows when something is unused, but with the IDs, not so much? For the debug info, I think I could try and do that quick, but if you don't mind it later I'm happy to wait (not sure if there's any rebasing or something needed for this PR or if it's going to just get squashed).
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Oh, adding the debug info is acutally very easy, it's effectively just copy-paste from
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Thank you! |
||
| } | ||
| refs = append(refs, methodRef{globalName, value}) | ||
| } | ||
| sort.Slice(refs, func(i, j int) bool { | ||
| return refs[i].name < refs[j].name | ||
| }) | ||
|
|
||
| var values []llvm.Value | ||
| for _, ref := range refs { | ||
| values = append(values, ref.value) | ||
| } | ||
|
|
||
| return c.ctx.ConstStruct([]llvm.Value{ | ||
| llvm.ConstInt(c.uintptrType, uint64(len(values)), false), | ||
| llvm.ConstArray(c.dataPtrType, values), | ||
| }, false) | ||
| } | ||
|
|
||
| // getInvokeFunction returns the thunk to call the given interface method. The | ||
|
|
@@ -823,6 +922,24 @@ func (c *compilerContext) getInvokeFunction(instr *ssa.CallCommon) llvm.Value { | |
| return llvmFn | ||
| } | ||
|
|
||
| // createInterfaceTypeAssert creates a call to a declared-but-not-defined | ||
| // $typeassert function for the given interface. This function will be defined | ||
| // by the interface lowering pass as a type-ID comparison chain, avoiding the | ||
| // need for runtime.typeImplementsMethodSet at compile time. | ||
| func (b *builder) createInterfaceTypeAssert(intf *types.Interface, actualType llvm.Value) llvm.Value { | ||
| s, _ := getTypeCodeName(intf) | ||
| fnName := s + ".$typeassert" | ||
| llvmFn := b.mod.NamedFunction(fnName) | ||
| if llvmFn.IsNil() { | ||
| llvmFnType := llvm.FunctionType(b.ctx.Int1Type(), []llvm.Type{b.dataPtrType}, false) | ||
| llvmFn = llvm.AddFunction(b.mod, fnName, llvmFnType) | ||
| b.addStandardDeclaredAttributes(llvmFn) | ||
| methods := b.getMethodsString(intf) | ||
| llvmFn.AddFunctionAttr(b.ctx.CreateStringAttribute("tinygo-methods", methods)) | ||
| } | ||
| return b.CreateCall(llvmFn.GlobalValueType(), llvmFn, []llvm.Value{actualType}, "") | ||
| } | ||
|
|
||
| // getInterfaceInvokeWrapper returns a wrapper for the given method so it can be | ||
| // invoked from an interface. The wrapper takes in a pointer to the underlying | ||
| // value, dereferences or unpacks it if necessary, and calls the real method. | ||
|
|
||
Uh oh!
There was an error while loading. Please reload this page.