From f2b617de68a30dda538d4607009f537191d46f0b Mon Sep 17 00:00:00 2001 From: Rich McKeever Date: Mon, 20 Apr 2026 15:27:40 -0700 Subject: [PATCH] Type inference precursor to IR conversion support for impl-style procs. * Do things for ProcDefs that have thus far only been done for StructDefs. * Add utils for dealing with ProcDefs. * Compute proc init tuples in CollectConstants and disallow the use of non-constant initializers. PiperOrigin-RevId: 902850513 --- xls/dslx/frontend/ast.cc | 5 +- xls/dslx/frontend/ast.h | 4 +- .../type_info_to_proto_test_ProcWithImpl.txt | 3 + xls/dslx/type_system_v2/BUILD | 5 ++ xls/dslx/type_system_v2/constant_collector.cc | 62 +++++++++++++++++++ .../type_system_v2/flatten_in_type_order.cc | 38 +++++++----- xls/dslx/type_system_v2/import_utils.cc | 29 +++++++++ xls/dslx/type_system_v2/import_utils.h | 14 +++++ .../type_system_v2/populate_table_visitor.cc | 28 ++++++--- .../typecheck_module_v2_test.cc | 45 ++++++++++---- .../type_system_v2/validate_concrete_type.cc | 28 ++++++--- 11 files changed, 214 insertions(+), 47 deletions(-) diff --git a/xls/dslx/frontend/ast.cc b/xls/dslx/frontend/ast.cc index e14ee1c355..8adcb8e584 100644 --- a/xls/dslx/frontend/ast.cc +++ b/xls/dslx/frontend/ast.cc @@ -1956,9 +1956,10 @@ StructInstanceBase::StructInstanceBase( members_(std::move(members)) {} std::vector> -StructInstanceBase::GetOrderedMembers(const StructDef* struct_def) const { +StructInstanceBase::GetOrderedMembers( + const StructDefBase* struct_or_proc_def) const { std::vector> result; - for (const std::string& name : struct_def->GetMemberNames()) { + for (const std::string& name : struct_or_proc_def->GetMemberNames()) { absl::StatusOr expr = GetExpr(name); if (absl::IsNotFound(expr.status()) && !requires_all_members()) { continue; diff --git a/xls/dslx/frontend/ast.h b/xls/dslx/frontend/ast.h index 929bad0fdd..b2287d0e28 100644 --- a/xls/dslx/frontend/ast.h +++ b/xls/dslx/frontend/ast.h @@ -3618,9 +3618,9 @@ class StructInstanceBase : public Expr { } // Returns the members for the struct instance, ordered by the (resolved) - // struct definition "struct_def". + // struct definition "struct_or_proc_def". std::vector> GetOrderedMembers( - const StructDef* struct_def) const; + const StructDefBase* struct_or_proc_def) const; const std::vector>& members() const { return members_; diff --git a/xls/dslx/type_system/testdata/type_info_to_proto_test_ProcWithImpl.txt b/xls/dslx/type_system/testdata/type_info_to_proto_test_ProcWithImpl.txt index d09fd39a88..562b569f6e 100644 --- a/xls/dslx/type_system/testdata/type_info_to_proto_test_ProcWithImpl.txt +++ b/xls/dslx/type_system/testdata/type_info_to_proto_test_ProcWithImpl.txt @@ -1 +1,4 @@ +1:0-1:19: PROC_DEF :: `proc Foo { + a: u32, +}` :: typeof(Foo { a: uN[32] }) 1:14-1:17: TYPE_ANNOTATION :: `u32` :: typeof(uN[32]) \ No newline at end of file diff --git a/xls/dslx/type_system_v2/BUILD b/xls/dslx/type_system_v2/BUILD index 42be4dae57..16baaa5397 100644 --- a/xls/dslx/type_system_v2/BUILD +++ b/xls/dslx/type_system_v2/BUILD @@ -60,6 +60,7 @@ cc_library( "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:str_format", "@com_google_absl//absl/types:span", ], ) @@ -524,6 +525,7 @@ cc_library( ":inference_table", ":inference_table_utils", ":type_annotation_utils", + "//xls/common:attribute_data", "//xls/common:visitor", "//xls/common/status:ret_check", "//xls/common/status:status_macros", @@ -695,7 +697,10 @@ cc_library( "//xls/dslx/frontend:ast_node_visitor_with_default", "//xls/dslx/frontend:module", "//xls/dslx/frontend:pos", + "//xls/dslx/type_system:type", + "//xls/dslx/type_system:type_info", "@com_google_absl//absl/base", + "@com_google_absl//absl/functional:function_ref", "@com_google_absl//absl/log:check", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", diff --git a/xls/dslx/type_system_v2/constant_collector.cc b/xls/dslx/type_system_v2/constant_collector.cc index 6e97b69927..c30e9f5e74 100644 --- a/xls/dslx/type_system_v2/constant_collector.cc +++ b/xls/dslx/type_system_v2/constant_collector.cc @@ -29,6 +29,7 @@ #include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/strings/str_cat.h" +#include "absl/strings/str_format.h" #include "absl/strings/substitute.h" #include "absl/types/span.h" #include "xls/common/status/ret_check.h" @@ -682,6 +683,67 @@ class Visitor : public AstNodeVisitorWithDefault { return absl::OkStatus(); } + absl::Status HandleStructInstance(const StructInstance* node) { + if (!type_.IsProc()) { + return absl::OkStatus(); + } + + // For a StructInstance node that is creating an impl-style proc, we store a + // tuple of the initial state values as the constexpr value in TypeInfo. + // This is equivalent to the result of the 'init' block in legacy procs. + const ProcType& type = type_.AsProc(); + std::vector state_init_values; + for (const auto& [member_name, initializer] : + node->GetOrderedMembers(&type.AsProc().struct_def_base())) { + std::optional maybe_member_type = + type.GetMemberTypeByName(member_name); + XLS_RET_CHECK(maybe_member_type.has_value()); + + // The inferred types of the state members are State where T is the + // type specified in the DSLX source. These are the only members we need + // to evaluate. + const Type& member_type = **maybe_member_type; + if (!IsProcDefStateType(member_type, import_data_)) { + continue; + } + + absl::StatusOr value = ConstexprEvaluator::EvaluateToValue( + &import_data_, ti_, &warning_collector_, + table_.GetParametricEnv(parametric_context_), initializer); + if (!value.ok()) { + return TypeInferenceErrorStatus( + initializer->span(), &type_, + absl::Substitute("Initializer for member `$0` of proc `$1` must be " + "possible to evaluate at compile time.", + member_name, + type.AsProc().struct_def_base().identifier()), + file_table_); + } + + state_init_values.push_back(std::move(*value)); + } + + InterpValue value = InterpValue::MakeTuple(std::move(state_init_values)); + ti_->NoteConstExpr(node, value); + VLOG(6) << "Storing value " << value.ToHumanString() + << " for proc initializer " << node->ToString(); + + // Propagate the proc "value" through the statement and/or statement block + // that yields it. This way IR conversion can easily say "give me the + // constant value for the body of the proc constructor." + for (AstNode* parent = node->parent(); + parent != nullptr && (parent->kind() == AstNodeKind::kStatement || + parent->kind() == AstNodeKind::kStatementBlock); + parent = parent->parent()) { + VLOG(6) << "Propagating proc initializer value for expr `" + << node->ToString() << "` to ancestor of kind " + << AstNodeKindToString(parent->kind()); + ti_->NoteConstExpr(parent, value); + } + + return absl::OkStatus(); + } + private: InferenceTable& table_; Module& module_; diff --git a/xls/dslx/type_system_v2/flatten_in_type_order.cc b/xls/dslx/type_system_v2/flatten_in_type_order.cc index 2dffaa501b..8fb7f4202c 100644 --- a/xls/dslx/type_system_v2/flatten_in_type_order.cc +++ b/xls/dslx/type_system_v2/flatten_in_type_order.cc @@ -191,21 +191,11 @@ class Flattener : public AstNodeVisitorWithDefault { } absl::Status HandleStructDef(const StructDef* node) override { - if (node->IsParametric() && node != root_) { - return absl::OkStatus(); - } - // StructDefBase::GetChildren does not return StructMemberNodes, this is - // blocked by https://github.com/google/xls/issues/1756. - nodes_.push_back(node->name_def()); - for (const ParametricBinding* parametric_binding : - node->parametric_bindings()) { - XLS_RETURN_IF_ERROR(parametric_binding->Accept(this)); - } - for (const StructMemberNode* member : node->members()) { - XLS_RETURN_IF_ERROR(member->Accept(this)); - } - nodes_.push_back(node); - return absl::OkStatus(); + return HandleStructDefBaseInternal(node); + } + + absl::Status HandleProcDef(const ProcDef* node) override { + return HandleStructDefBaseInternal(node); } absl::Status HandleTypeRef(const TypeRef* node) override { @@ -257,6 +247,24 @@ class Flattener : public AstNodeVisitorWithDefault { const std::vector& nodes() const { return nodes_; } private: + absl::Status HandleStructDefBaseInternal(const StructDefBase* node) { + if (node->IsParametric() && node != root_) { + return absl::OkStatus(); + } + // StructDefBase::GetChildren does not return StructMemberNodes, this is + // blocked by https://github.com/google/xls/issues/1756. + nodes_.push_back(node->name_def()); + for (const ParametricBinding* parametric_binding : + node->parametric_bindings()) { + XLS_RETURN_IF_ERROR(parametric_binding->Accept(this)); + } + for (const StructMemberNode* member : node->members()) { + XLS_RETURN_IF_ERROR(member->Accept(this)); + } + nodes_.push_back(node); + return absl::OkStatus(); + } + const ImportData& import_data_; const AstNode* const root_; const bool include_parametric_entities_; diff --git a/xls/dslx/type_system_v2/import_utils.cc b/xls/dslx/type_system_v2/import_utils.cc index 0d4e23dbe3..7f86c1918c 100644 --- a/xls/dslx/type_system_v2/import_utils.cc +++ b/xls/dslx/type_system_v2/import_utils.cc @@ -30,6 +30,8 @@ #include "xls/dslx/frontend/module.h" #include "xls/dslx/frontend/pos.h" #include "xls/dslx/import_data.h" +#include "xls/dslx/type_system/type.h" +#include "xls/dslx/type_system/type_info.h" #include "xls/dslx/type_system_v2/type_annotation_utils.h" namespace xls::dslx { @@ -247,4 +249,31 @@ absl::StatusOr IsProcDefNextFunction(const Function* f, return (*def)->kind() == AstNodeKind::kProcDef; } +bool IsProcDefStateType(const Type& type, const ImportData& import_data) { + // A state element in a `ProcDef` always uses explicit state access, so should + // have been made into a State by semantics analysis, i.e. it is always a + // struct in type inference. + if (!type.IsStruct()) { + return false; + } + + Module* builtin_stubs = *import_data.GetBuiltinStubsModule(); + const StructDef* state_struct_def = + *builtin_stubs->GetMember("State"); + return &type.AsStruct().struct_def_base() == state_struct_def; +} + +absl::StatusOr> GetProcDefStateMembers( + const ProcDef* proc_def, const ImportData& import_data, + const TypeInfo& type_info) { + std::vector result; + for (StructMemberNode* member : proc_def->members()) { + XLS_ASSIGN_OR_RETURN(const Type* type, type_info.GetItemOrError(member)); + if (IsProcDefStateType(*type, import_data)) { + result.push_back(member); + } + } + return result; +} + } // namespace xls::dslx diff --git a/xls/dslx/type_system_v2/import_utils.h b/xls/dslx/type_system_v2/import_utils.h index 6c5819c9df..d91a833f38 100644 --- a/xls/dslx/type_system_v2/import_utils.h +++ b/xls/dslx/type_system_v2/import_utils.h @@ -16,12 +16,15 @@ #define XLS_DSLX_TYPE_SYSTEM_V2_IMPORT_UTILS_H_ #include +#include #include "absl/status/statusor.h" #include "xls/dslx/frontend/ast.h" #include "xls/dslx/frontend/module.h" #include "xls/dslx/frontend/pos.h" #include "xls/dslx/import_data.h" +#include "xls/dslx/type_system/type.h" +#include "xls/dslx/type_system/type_info.h" #include "xls/dslx/type_system_v2/type_annotation_utils.h" namespace xls::dslx { @@ -74,6 +77,17 @@ absl::StatusOr IsProcDefNextFunction(const Function* f, // Returns whether `colon_ref` is imported from a different module. bool IsImport(const ColonRef* colon_ref); +// Returns whether the given `type` of a member of a `ProcDef` indicates that +// the member is a state element. All state members of a `ProcDef` have the +// inferred type `State` where `T` is the type written by the programmer, so +// this amounts to checking if `type` is a "State-wrapped" type. +bool IsProcDefStateType(const Type& type, const ImportData& import_data); + +// Returns all the members of `proc_def` that are state elements. +absl::StatusOr> GetProcDefStateMembers( + const ProcDef* proc_def, const ImportData& import_data, + const TypeInfo& type_info); + } // namespace xls::dslx #endif // XLS_DSLX_TYPE_SYSTEM_V2_IMPORT_UTILS_H_ diff --git a/xls/dslx/type_system_v2/populate_table_visitor.cc b/xls/dslx/type_system_v2/populate_table_visitor.cc index c768f44fa4..8406dcc494 100644 --- a/xls/dslx/type_system_v2/populate_table_visitor.cc +++ b/xls/dslx/type_system_v2/populate_table_visitor.cc @@ -39,6 +39,7 @@ #include "absl/strings/str_join.h" #include "absl/strings/substitute.h" #include "absl/types/variant.h" +#include "xls/common/attribute_data.h" #include "xls/common/status/ret_check.h" #include "xls/common/status/status_macros.h" #include "xls/common/visitor.h" @@ -1413,15 +1414,11 @@ class PopulateInferenceTableVisitor : public PopulateTableVisitor, // annotation for each member and cache it for use by type unification of // instances of this struct. absl::Status HandleStructDef(const StructDef* node) override { - if (!node->IsParametric()) { - for (const StructMemberNode* member : node->members()) { - XLS_RETURN_IF_ERROR(table_.SetTypeAnnotation(member, member->type())); - } - XLS_RETURN_IF_ERROR(table_.SetTypeAnnotation( - node, CreateStructOrProcAnnotation( - module_, const_cast(node), {}, std::nullopt))); - } - return DefaultHandler(node); + return HandleStructDefBaseInternal(node); + } + + absl::Status HandleProcDef(const ProcDef* node) override { + return HandleStructDefBaseInternal(node); } absl::Status HandleTupleIndex(const TupleIndex* node) override { @@ -2042,6 +2039,19 @@ class PopulateInferenceTableVisitor : public PopulateTableVisitor, return true; } + absl::Status HandleStructDefBaseInternal(const StructDefBase* node) { + if (!node->IsParametric()) { + for (const StructMemberNode* member : node->members()) { + XLS_RETURN_IF_ERROR(table_.SetTypeAnnotation(member, member->type())); + } + XLS_RETURN_IF_ERROR(table_.SetTypeAnnotation( + node, + CreateStructOrProcAnnotation( + module_, const_cast(node), {}, std::nullopt))); + } + return DefaultHandler(node); + } + // Helper that creates an internal type variable for a `ConstantDef`, `Param`, // or similar type of node that contains a `NameDef` and optional // `TypeAnnotation`. diff --git a/xls/dslx/type_system_v2/typecheck_module_v2_test.cc b/xls/dslx/type_system_v2/typecheck_module_v2_test.cc index f160090db6..9505866a0d 100644 --- a/xls/dslx/type_system_v2/typecheck_module_v2_test.cc +++ b/xls/dslx/type_system_v2/typecheck_module_v2_test.cc @@ -7789,8 +7789,8 @@ proc P { } impl P { - fn new(input: chan in, output: chan out, init_val: u32) -> Self { - P { input, output, state: init_val } + fn new(input: chan in, output: chan out) -> Self { + P { input, output, state: 5 } } fn next(self) { @@ -7805,10 +7805,9 @@ impl P { TypecheckSucceeds(AllOf( HasNodeWithType("P", absl::Substitute("typeof($0)", kProcType)), HasNodeWithType( - "new", - absl::Substitute("(chan(uN[32], dir=in), chan(uN[32], dir=out), " - "uN[32]) -> $0", - kProcType)), + "new", absl::Substitute( + "(chan(uN[32], dir=in), chan(uN[32], dir=out)) -> $0", + kProcType)), HasNodeWithType("next", absl::Substitute("($0) -> ()", kProcType))))); } @@ -7822,8 +7821,8 @@ proc P { } impl P { - fn new(init_val: u32) -> Self { - P { state: init_val } + fn new() -> Self { + P { state: 5 } } fn next(self, a: u32) {} @@ -7843,8 +7842,8 @@ proc P { } impl P { - fn new(init_val: u32) -> Self { - P { state: init_val } + fn new() -> Self { + P { state: 5 } } fn next() {} @@ -7864,8 +7863,8 @@ proc P { } impl P { - fn new(init_val: u32) -> Self { - P { state: init_val } + fn new() -> Self { + P { state: 5 } } fn next() -> u32 { 0 } @@ -7875,6 +7874,28 @@ impl P { "must not return anything"))); } +TEST(TypecheckV2Test, ProcWithImplIntegerParamInTopProcNewFails) { + EXPECT_THAT( + R"( +#![feature(explicit_state_access)] + +proc P { + state: u32, +} + +impl P { + fn new(init_val: u32) -> Self { + P { state: init_val } + } + + fn next() -> u32 { 0 } +} +)", + TypecheckFails( + HasSubstr("Initializer for member `state` of proc `P` must be " + "possible to evaluate at compile time."))); +} + TEST(TypecheckV2Test, SpawnProcWithImpl) { std::string_view kProgram = R"( #![feature(explicit_state_access)] diff --git a/xls/dslx/type_system_v2/validate_concrete_type.cc b/xls/dslx/type_system_v2/validate_concrete_type.cc index 2c3c57d303..82a59ef362 100644 --- a/xls/dslx/type_system_v2/validate_concrete_type.cc +++ b/xls/dslx/type_system_v2/validate_concrete_type.cc @@ -246,13 +246,12 @@ class TypeValidator : public AstNodeVisitorWithDefault { return DefaultHandler(invocation); } - absl::Status HandleStructMemberNode(const StructMemberNode* member) override { - if (type_->IsProc()) { - return TypeInferenceErrorStatus( - member->span(), type_, "Structs cannot contain procs as members.", - file_table_); - } - return absl::OkStatus(); + absl::Status HandleStructDef(const StructDef* node) override { + return HandleStructDefBaseInternal(node); + } + + absl::Status HandleProcDef(const ProcDef* node) override { + return HandleStructDefBaseInternal(node); } absl::Status HandleFormatMacro(const FormatMacro* macro) override { @@ -530,6 +529,21 @@ class TypeValidator : public AstNodeVisitorWithDefault { domain->ToString(), param->ToString()); } + absl::Status HandleStructDefBaseInternal(const StructDefBase* def) { + if (type_->IsProc()) { + return absl::OkStatus(); + } + for (const StructMemberNode* member : def->members()) { + XLS_ASSIGN_OR_RETURN(const Type* member_type, ti_.GetItemOrError(member)); + if (member_type->IsProc()) { + return TypeInferenceErrorStatus( + member->span(), type_, "Structs cannot contain procs as members.", + file_table_); + } + } + return absl::OkStatus(); + } + absl::Status ValidateBinopShift(const Binop& binop) { XLS_ASSIGN_OR_RETURN(Type * rhs_type, ti_.GetItemOrError(binop.rhs())); XLS_ASSIGN_OR_RETURN(