diff --git a/include/CppInterOp/CppInterOpTypes.h b/include/CppInterOp/CppInterOpTypes.h
index e7332eecd..45a470e76 100644
--- a/include/CppInterOp/CppInterOpTypes.h
+++ b/include/CppInterOp/CppInterOpTypes.h
@@ -405,6 +405,16 @@ enum class InterpreterLanguageStandard : unsigned char {
   hlsl202y,
   lang_unspecified
 };
+
+enum class AllocType : unsigned char { // result type of GetAllocType
+  None,       ///< nothing classified, e.g. non-pointer return or placement new
+  New,        ///< value comes from a non-array, non-placement new-expression
+  NewArr,     ///< value comes from an array new-expression
+  Malloc,     ///< value comes from a call to the malloc builtin
+  Unknown,    ///< undetermined, e.g. no body, indirect call, paths disagree
+  CustomAlloc ///< NOTE(review): never produced by the analyzer here - confirm
+};
+
 inline QualKind operator|(QualKind a, QualKind b) {
   return static_cast(static_cast(a) |
                                static_cast(b));
diff --git a/lib/CppInterOp/CppInterOp.cpp b/lib/CppInterOp/CppInterOp.cpp
index 311237812..377e69486 100644
--- a/lib/CppInterOp/CppInterOp.cpp
+++ b/lib/CppInterOp/CppInterOp.cpp
@@ -47,12 +47,15 @@
 #include "clang/AST/GlobalDecl.h"
 #include "clang/AST/Mangle.h"
 #include "clang/AST/NestedNameSpecifier.h"
+#include "clang/AST/OperationKinds.h"
 #include "clang/AST/QualTypeNames.h"
 #include "clang/AST/RawCommentList.h"
 #include "clang/AST/RecordLayout.h"
+#include "clang/AST/RecursiveASTVisitor.h"
 #include "clang/AST/Stmt.h"
 #include "clang/AST/Type.h"
 #include "clang/AST/VTableBuilder.h"
+#include "clang/Basic/Builtins.h"
 #include "clang/Basic/Diagnostic.h"
 #include "clang/Basic/DiagnosticSema.h"
 #include "clang/Basic/LangStandard.h"
@@ -108,6 +111,7 @@
 #include
 #include
 #include
+#include
 #include
 #include
 #include
@@ -116,6 +120,7 @@
 #ifndef _WIN32
 #include
 #endif
+#include
 #include
 // Stream redirect.
#ifdef _WIN32
@@ -1557,6 +1562,178 @@ bool IsFunctionProtoType(ConstTypeRef TyRef) {
   return llvm::isa_and_nonnull(T);
 }
 
+/// Allocation kind of the value most recently bound to each local variable.
+using VarAllocMap = std::unordered_map<const clang::VarDecl*, AllocType>;
+
+/// Functions whose bodies are being analyzed right now, innermost last.
+/// Consulted so that the analysis never re-enters a (mutually) recursive
+/// function.
+using ActiveFnStack = std::vector<const clang::FunctionDecl*>;
+
+/// Classifies a new-expression: any placement form yields None, an array
+/// form NewArr, everything else New.
+static AllocType handleNew(const clang::CXXNewExpr* CNE) {
+  if (CNE->getNumPlacementArgs() > 0)
+    return AllocType::None;
+  return CNE->isArray() ? AllocType::NewArr : AllocType::New;
+}
+
+static AllocType AnalyzeAllocType(const clang::FunctionDecl* Fn,
+                                  ActiveFnStack& Active);
+
+/// Classifies a call: malloc is recognised through its builtin ID, any other
+/// statically known callee is analyzed in turn.
+static AllocType handleCall(const clang::CallExpr* CE, ActiveFnStack& Active) {
+  const clang::FunctionDecl* FD = CE->getDirectCallee();
+  // Call through a function pointer: the callee is not known statically.
+  if (!FD)
+    return AllocType::Unknown;
+  if (FD->getBuiltinID() == Builtin::ID::BImalloc)
+    return AllocType::Malloc;
+  return AnalyzeAllocType(FD, Active);
+}
+
+/// Classifies the value produced by \p expr, looking through parentheses and
+/// casts. \p varMap holds the kinds already recorded for local variables.
+static AllocType handleExpr(const clang::Expr* expr, const VarAllocMap& varMap,
+                            ActiveFnStack& Active) {
+  const clang::Expr* finExpr = expr->IgnoreParenCasts();
+  // Case: return new __type__
+  if (const auto* CNE = dyn_cast<clang::CXXNewExpr>(finExpr))
+    return handleNew(CNE);
+
+  // Case: returns a variable
+  if (const auto* DRE = dyn_cast<clang::DeclRefExpr>(finExpr)) {
+    if (const auto* VD = dyn_cast<clang::VarDecl>(DRE->getDecl())) {
+      auto it = varMap.find(VD);
+      if (it != varMap.end())
+        return it->second;
+    }
+    // FIXME: BindingDecl, NonTypeTemplateParmDecl are not handled
+    return AllocType::None;
+  }
+
+  // Case: malloc or another func call
+  if (const auto* CE = dyn_cast<clang::CallExpr>(finExpr))
+    return handleCall(CE, Active);
+  return AllocType::None;
+}
+
+/// Collects the return statements found in the function body \p CS.
+static std::vector<const clang::ReturnStmt*>
+getAllRetStmt(const clang::CompoundStmt* CS) {
+  struct RetStmtVisitor : clang::RecursiveASTVisitor<RetStmtVisitor> {
+    std::vector<const clang::ReturnStmt*> vec;
+    bool VisitReturnStmt(clang::ReturnStmt* RS) {
+      vec.push_back(RS);
+      return true;
+    }
+    // A return inside a lambda body leaves the lambda, not the enclosing
+    // function, so lambdas are not descended into.
+    bool TraverseLambdaExpr(clang::LambdaExpr* /*LE*/) { return true; }
+  };
+  RetStmtVisitor Visitor;
+  Visitor.TraverseStmt(const_cast<clang::CompoundStmt*>(CS));
+  return Visitor.vec;
+}
+
+/// Walks \p CS in traversal order and records, for every variable that is
+/// declared or assigned with '=', the kind of the value bound to it last.
+static VarAllocMap getAllVarDecl(const clang::CompoundStmt* CS,
+                                 ActiveFnStack& Active) {
+  struct VarVisitor : clang::RecursiveASTVisitor<VarVisitor> {
+    ActiveFnStack& Active;
+    VarAllocMap varMap;
+    explicit VarVisitor(ActiveFnStack& A) : Active(A) {}
+
+    bool VisitVarDecl(clang::VarDecl* VD) {
+      const clang::Expr* Init = VD->getInit();
+      const AllocType Kind =
+          Init ? handleExpr(Init, varMap, Active) : AllocType::None;
+      varMap[VD] = Kind;
+      return true;
+    }
+
+    bool VisitBinaryOperator(clang::BinaryOperator* BO) {
+      if (BO->getOpcode() != clang::BO_Assign)
+        return true;
+      const clang::Expr* LHS = BO->getLHS()->IgnoreParenCasts();
+      if (const auto* DRE = dyn_cast<clang::DeclRefExpr>(LHS)) {
+        if (const auto* VD = dyn_cast<clang::VarDecl>(DRE->getDecl())) {
+          const AllocType Kind = handleExpr(BO->getRHS(), varMap, Active);
+          varMap[VD] = Kind;
+        }
+      }
+      return true;
+    }
+  };
+  VarVisitor Visitor(Active);
+  Visitor.TraverseStmt(const_cast<clang::CompoundStmt*>(CS));
+  return Visitor.varMap;
+}
+
+/// Determines how the pointer returned by \p Fn was allocated by inspecting
+/// the function body. \p Active lists the functions already under analysis;
+/// a call back into one of them is reported as Unknown instead of recursing.
+static AllocType AnalyzeAllocType(const clang::FunctionDecl* Fn,
+                                  ActiveFnStack& Active) {
+  if (!Fn->getReturnType()->isPointerType())
+    return AllocType::None;
+  const clang::Stmt* fnBody = Fn->getBody();
+  if (!fnBody)
+    return AllocType::Unknown;
+  const auto* CmpStmt = dyn_cast<clang::CompoundStmt>(fnBody);
+  // FIXME: try catch blocks are not CompoundStmt, only edge case
+  if (!CmpStmt)
+    return AllocType::Unknown;
+
+  // Recursion guard: without it a function that (directly or indirectly)
+  // calls itself would be analyzed without end.
+  const clang::FunctionDecl* Canon = Fn->getCanonicalDecl();
+  for (const clang::FunctionDecl* InProgress : Active)
+    if (InProgress == Canon)
+      return AllocType::Unknown;
+  Active.push_back(Canon);
+
+  const VarAllocMap varMap = getAllVarDecl(CmpStmt, Active);
+  std::optional<AllocType> res;
+  for (const clang::ReturnStmt* retStmt : getAllRetStmt(CmpStmt)) {
+    const clang::Expr* retExpr = retStmt->getRetValue();
+    if (retExpr == nullptr)
+      continue;
+    const AllocType tmp = handleExpr(retExpr, varMap, Active);
+    if (!res.has_value()) {
+      res = tmp;
+    } else if (*res != tmp) {
+      // If function's allocation behaviour differs between different cases,
+      // analyzer returns unknown.
+      res = AllocType::Unknown;
+      break;
+    }
+  }
+  Active.pop_back();
+  return res.value_or(AllocType::None);
+}
+
+/// Starts an analysis of \p Fn with no function in progress.
+static AllocType AnalyzeAllocType(const clang::FunctionDecl* Fn) {
+  ActiveFnStack Active;
+  return AnalyzeAllocType(Fn, Active);
+}
+
+/// Reports how \p Fn allocates its return value; None when \p Fn is null or
+/// does not refer to a function declaration.
+AllocType GetAllocType(ConstFuncRef Fn) {
+  INTEROP_TRACE(Fn);
+  if (Fn) {
+    const auto* D = unwrap(Fn);
+    if (const auto* FD = dyn_cast<clang::FunctionDecl>(D)) {
+      return INTEROP_RETURN(AnalyzeAllocType(FD));
+    }
+  }
+  return INTEROP_RETURN(AllocType::None);
+}
+
 void GetFnTypeSignature(ConstTypeRef fn_type, std::vector& sig) {
   INTEROP_TRACE(fn_type, INTEROP_OUT(sig));
   QualType QT = QualType::getFromOpaquePtr(fn_type.data);
diff --git a/lib/CppInterOp/CppInterOp.td b/lib/CppInterOp/CppInterOp.td
index e8a57a441..a4144a41e 100644
--- a/lib/CppInterOp/CppInterOp.td
+++ b/lib/CppInterOp/CppInterOp.td
@@ -763,6 +763,15 @@ def IsFunctionProtoType : CppInterOpAPI {
   ];
 }
 
+def GetAllocType: CppInterOpAPI {
+  let Doc = "Analyzes function's body to detect memory allocation behaviour";
+
+  let ReturnType = "AllocType";
+  let Args = [
+    Arg<"ConstFuncRef", "Fn">
+  ];
+}
+
 def GetFnTypeSignature : CppInterOpAPI {
   let Doc = [{Pushed the signature type of the given function type, where the first
               item is the return type.}];
diff --git a/unittests/CppInterOp/FunctionReflectionTest.cpp b/unittests/CppInterOp/FunctionReflectionTest.cpp
index e1c5fd374..2c982f542 100644
--- a/unittests/CppInterOp/FunctionReflectionTest.cpp
+++ b/unittests/CppInterOp/FunctionReflectionTest.cpp
@@ -692,6 +692,131 @@ TYPED_TEST(CPPINTEROP_TEST_MODE, FunctionReflection_FunctionTypes) {
   EXPECT_TRUE(Cpp::IsSameType(typ1, typ2));
 }
 
+TYPED_TEST(CPPINTEROP_TEST_MODE, FunctionReflection_GetAllocType) {
+  std::string code = R"(
+    #include <cstdlib>
+    #include <new>
+
+    int* func0(int n){ return new int(n); }
+
+    void* func1(int n){ return (void*) new int(n); }
+
+    void* func2(int n){ return static_cast<void*>(new int(n)); }
+
+    int* func3(int n){ int* x = new int(n); int* y; y=x; return y; }
+
+    int* func4(int n){int* x = new int(n); int* y = x; return y; }
+
+    void* func5(int n){ return malloc(sizeof(int)); }
+
+    void* func6(int n){ void* x = malloc(sizeof(int)); return x; }
+
+    void** func7(int n){ void** arr = new void*[n]; return arr;}
+
+    void** func8(int n){ return new void*[n]; }
+
+    int* func9(int n){ int* x = new int(n); int* y = x; int* z = y; return z; }
+
+    int* func10(int n){ return static_cast<int*>(malloc(sizeof(int(n)))); }
+
+    int* func11(int n){
+      int* x = new int(n);
+      int* y = nullptr;
+      y = x;
+      x = nullptr;
+      return y;
+    }
+
+    int* func12(int n){ int* x = static_cast<int*>(malloc(sizeof(int))); return x;}
+
+    int* func13(int n){ int* x = new int(n); return (((x))); }
+
+    int func14(int n){ return n; }
+
+    int* func15(int n);
+
+    int* func16(int n) try { return new int(n); } catch(...) { return nullptr; }
+
+    int* func17(int* p){ return p; }
+
+    int* func18_t(int n){ return nullptr; }
+    typedef int* (*FnPtr18)(int);
+    FnPtr18 func18(int n){ return func18_t; }
+
+    int* func19_helper(int n){ return nullptr; }
+    int* func19(int n){ return func19_helper(n); }
+
+    int* func20(int* (*fp)(int), int n){ return fp(n); }
+
+    int* func21(int n){ static char buf[16]; return new (&buf) int(n); }
+
+    int* func22(int n){ []{ return; }(); return new int(n); }
+
+    int* func23;
+
+    int* func24(int n){ return func0(n); }
+
+    void** func25(int n){ void** arr = func7(n); return arr; }
+
+    int* func26(bool b){
+      if(b)
+        return (int*)malloc(sizeof(int));
+      return new int;
+    }
+
+    // Self-recursive: the analysis has to terminate.
+    int* func27(int n){ if (n <= 0) return new int(0); return func27(n - 1); }
+
+    // The return inside the lambda is not a return of func28.
+    int* func28(int n){
+      auto mk = [](int v){ return new int(v); };
+      (void)mk;
+      return nullptr;
+    }
+  )";
+  TestFixture::CreateInterpreter();
+  Interp->declare(code);
+
+#define TESTAC(N, EXP)                                                         \
+  EXPECT_EQ(                                                                   \
+      Cpp::GetAllocType(Cpp::ConstFuncRef { Cpp::GetNamed("func" #N).data }),  \
+      Cpp::AllocType::EXP)
+
+  TESTAC(0, New);
+  TESTAC(1, New);
+  TESTAC(2, New);
+  TESTAC(3, New);
+  TESTAC(4, New);
+  TESTAC(5, Malloc);
+  TESTAC(6, Malloc);
+  TESTAC(7, NewArr);
+  TESTAC(8, NewArr);
+  TESTAC(9, New);
+  TESTAC(10, Malloc);
+  TESTAC(11, New);
+  TESTAC(12, Malloc);
+  TESTAC(13, New);
+  TESTAC(14, None);
+  TESTAC(15, Unknown);
+  TESTAC(16, Unknown);
+  TESTAC(17, None);
+  TESTAC(18, None);
+  TESTAC(19, None);
+  TESTAC(20, Unknown);
+  TESTAC(21, None);
+  TESTAC(22, New);
+  TESTAC(23, None);
+  TESTAC(24, New);
+  TESTAC(25, NewArr);
+  TESTAC(26, Unknown);
+  TESTAC(27, Unknown);
+  TESTAC(28, None);
+
+#undef TESTAC
+
+  Cpp::DeleteInterpreter();
+}
+
 TYPED_TEST(CPPINTEROP_TEST_MODE, FunctionReflection_GetFunctionSignature) {
   std::vector Decls;
   std::string code = R"(