diff --git a/lib/CppInterOp/CppInterOp.cpp b/lib/CppInterOp/CppInterOp.cpp index 9c6f00432..8ba1c84c5 100644 --- a/lib/CppInterOp/CppInterOp.cpp +++ b/lib/CppInterOp/CppInterOp.cpp @@ -1951,6 +1951,22 @@ static QualType findBuiltinType(llvm::StringRef typeName, ASTContext& Context) { */ return QualType(); } +static std::optional GetTypeInternal(Decl* D) { + if (!D) + return {}; + // Even though typedefs derive from TypeDecl, their getTypeForDecl() + // returns a nullptr. + if (const auto* TND = llvm::dyn_cast_or_null(D)) + return TND->getUnderlyingType(); + + if (auto* VD = dyn_cast(D)) + return VD->getType(); + + if (const auto* TD = llvm::dyn_cast_or_null(D)) + return QualType(TD->getTypeForDecl(), 0); + + return {}; +} } // namespace TCppType_t GetType(const std::string& name) { @@ -1958,12 +1974,7 @@ TCppType_t GetType(const std::string& name) { if (!builtin.isNull()) return builtin.getAsOpaquePtr(); - auto* D = (Decl*)GetNamed(name, /* Within= */ 0); - if (auto* TD = llvm::dyn_cast_or_null(D)) { - return QualType(TD->getTypeForDecl(), 0).getAsOpaquePtr(); - } - - return (TCppType_t)0; + return GetTypeFromScope((Decl*)GetNamed(name, /*Within=*/nullptr)); } TCppType_t GetComplexType(TCppType_t type) { @@ -1974,17 +1985,12 @@ TCppType_t GetComplexType(TCppType_t type) { TCppType_t GetTypeFromScope(TCppScope_t klass) { if (!klass) - return 0; - - auto* D = (Decl*)klass; - - if (auto* VD = dyn_cast(D)) - return VD->getType().getAsOpaquePtr(); + return nullptr; - if (auto* TD = dyn_cast(D)) - return getASTContext().getTypeDeclType(TD).getAsOpaquePtr(); + if (auto QT = GetTypeInternal((Decl*)klass)) + return QT->getAsOpaquePtr(); - return (TCppType_t) nullptr; + return nullptr; } // Internal functions that are not needed outside the library are diff --git a/unittests/CppInterOp/ScopeReflectionTest.cpp b/unittests/CppInterOp/ScopeReflectionTest.cpp index 3aece27b7..b68c8e7ca 100644 --- a/unittests/CppInterOp/ScopeReflectionTest.cpp +++ b/unittests/CppInterOp/ScopeReflectionTest.cpp @@ -25,6 +25,19 @@ using namespace TestUtils; using namespace llvm; using namespace clang; +TYPED_TEST(CPPINTEROP_TEST_MODE, ScopeReflection_GetTypeOfTypedef) { + std::string code = R"( + typedef int Type_t; + )"; + + std::vector Decls; + GetAllTopLevelDecls(code, Decls); + auto Ty = Cpp::GetType("Type_t"); + EXPECT_TRUE(Ty); + EXPECT_TRUE(Ty == Cpp::GetTypeFromScope(Decls[0])); + EXPECT_FALSE(Cpp::GetTypeFromScope(nullptr)); +} + TYPED_TEST(CPPINTEROP_TEST_MODE, ScopeReflection_IsEnumScope) { std::vector Decls; std::vector SubDecls;