diff --git a/xls/dslx/ir_convert/BUILD b/xls/dslx/ir_convert/BUILD index d90be8b7a2..65181e0eaf 100644 --- a/xls/dslx/ir_convert/BUILD +++ b/xls/dslx/ir_convert/BUILD @@ -158,10 +158,13 @@ cc_library( srcs = ["ir_conversion_utils.cc"], hdrs = ["ir_conversion_utils.h"], deps = [ + "//xls/common/status:ret_check", "//xls/common/status:status_macros", "//xls/dslx:interp_value", + "//xls/dslx/frontend:ast", "//xls/dslx/type_system:parametric_env", "//xls/dslx/type_system:type", + "//xls/dslx/type_system:type_info", "//xls/ir", "//xls/ir:type", "@com_google_absl//absl/log", @@ -352,6 +355,7 @@ cc_library( "//xls/dslx/type_system:parametric_env", "//xls/dslx/type_system:type", "//xls/dslx/type_system:type_info", + "//xls/dslx/type_system_v2:import_utils", "//xls/ir", "//xls/ir:bits", "//xls/ir:channel", @@ -415,6 +419,7 @@ cc_library( ":extract_conversion_order", ":function_converter", ":get_conversion_records", + ":ir_conversion_utils", ":proc_config_ir_converter", "//xls/common/status:ret_check", "//xls/common/status:status_macros", @@ -437,6 +442,7 @@ cc_library( "//xls/dslx/frontend:scanner", "//xls/dslx/type_system:parametric_env", "//xls/dslx/type_system:type_info", + "//xls/dslx/type_system_v2:import_utils", "//xls/ir", "//xls/ir:function_builder", "//xls/ir:ir_scanner", @@ -444,6 +450,7 @@ cc_library( "//xls/ir:verifier", "//xls/ir:xls_ir_interface_cc_proto", "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/base", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/log", @@ -515,6 +522,7 @@ cc_library( hdrs = ["get_conversion_records.h"], deps = [ ":conversion_record", + ":ir_conversion_utils", "//xls/common/status:ret_check", "//xls/common/status:status_macros", "//xls/dslx:interp_value", diff --git a/xls/dslx/ir_convert/extract_conversion_order.cc b/xls/dslx/ir_convert/extract_conversion_order.cc index db244cda7d..2046da4bf8 100644 --- a/xls/dslx/ir_convert/extract_conversion_order.cc +++ b/xls/dslx/ir_convert/extract_conversion_order.cc @@ -819,7 +819,7 @@ absl::StatusOr> GetOrder(Module* module, } absl::StatusOr> GetOrderForEntry( - std::variant entry, TypeInfo* type_info, + std::variant entry, TypeInfo* type_info, std::optional resolved_proc_alias) { std::vector ready; if (std::holds_alternative(entry)) { @@ -835,6 +835,11 @@ absl::StatusOr> GetOrderForEntry( return ready; } + if (std::holds_alternative(entry)) { + return absl::UnimplementedError( + "Impl-style procs can only be compiled with proc-scoped channels."); + } + Proc* p = std::get(entry); XLS_ASSIGN_OR_RETURN(TypeInfo * new_ti, type_info->GetTopLevelProcTypeInfo(p)); diff --git a/xls/dslx/ir_convert/extract_conversion_order.h b/xls/dslx/ir_convert/extract_conversion_order.h index 5ef0470d30..0e1f27dd42 100644 --- a/xls/dslx/ir_convert/extract_conversion_order.h +++ b/xls/dslx/ir_convert/extract_conversion_order.h @@ -89,7 +89,7 @@ absl::StatusOr> GetOrder( // f: The top level function. // type_info: Mapping from node to type. absl::StatusOr> GetOrderForEntry( - std::variant entry, TypeInfo* type_info, + std::variant entry, TypeInfo* type_info, std::optional resolved_proc_alias = std::nullopt); // Top level procs are procs where their config or next function is not invoked diff --git a/xls/dslx/ir_convert/function_converter.cc b/xls/dslx/ir_convert/function_converter.cc index bd90e1d694..8e3268fd6f 100644 --- a/xls/dslx/ir_convert/function_converter.cc +++ b/xls/dslx/ir_convert/function_converter.cc @@ -66,6 +66,7 @@ #include "xls/dslx/type_system/parametric_env.h" #include "xls/dslx/type_system/type.h" #include "xls/dslx/type_system/type_info.h" +#include "xls/dslx/type_system_v2/import_utils.h" #include "xls/dslx/warning_collector.h" #include "xls/ir/bits.h" #include "xls/ir/channel.h" @@ -126,6 +127,16 @@ absl::StatusOr ToTypeDefinition( return type_ref_type_annotation->type_ref()->type_definition(); } +// TODO: https://github.com/google/xls/issues/4125 - This is brittle but +// preserves the pre-existing behavior of explicit state access. We need to +// replace string keys with `StateElement` keys. +std::string GetStateElementName(const Expr* expr) { + if (expr->kind() == AstNodeKind::kAttr) { + return std::string(absl::down_cast(expr)->attr()); + } + return expr->ToString(); +} + } // namespace absl::StatusOr EmitImplicitTokenEntryWrapper( @@ -494,6 +505,10 @@ absl::Status FunctionConverter::Visit(const AstNode* node) { ci); }, [](ChannelArray* ca) { return absl::StrFormat("%p", ca); }, + [](ProcDefInstance* instance) { + return absl::StrCat("instance of ", + instance->proc_def->identifier()); + }, }, value); } @@ -2659,8 +2674,8 @@ absl::Status FunctionConverter::HandleBuiltinRead(const Invocation* node) { Expr* source = node->args()[0]; BValue active = implicit_token_data_->create_control_predicate(); - BValue state_read_called = - state_read_called_by_state_name_[source->ToString()]; + std::string state_name = GetStateElementName(source); + BValue state_read_called = state_read_called_by_state_name_[state_name]; if (options_.emit_assert) { // Assert multiple reads don't happen in same activation. implicit_token_data_->entry_token = function_builder_->Assert( @@ -2671,7 +2686,7 @@ absl::Status FunctionConverter::HandleBuiltinRead(const Invocation* node) { implicit_token_data_->control_tokens.push_back( implicit_token_data_->entry_token); } - state_read_called_by_state_name_[source->ToString()] = + state_read_called_by_state_name_[state_name] = function_builder_->Or(state_read_called, active); XLS_RETURN_IF_ERROR(Visit(source)); XLS_ASSIGN_OR_RETURN(BValue state_read, Use(source)); @@ -2692,10 +2707,9 @@ absl::Status FunctionConverter::HandleBuiltinWrite(const Invocation* node) { Expr* target = node->args()[0]; BValue active = implicit_token_data_->create_control_predicate(); - BValue state_read_called = - state_read_called_by_state_name_[target->ToString()]; - BValue state_write_called = - state_write_called_by_state_name_[target->ToString()]; + std::string state_name = GetStateElementName(target); + BValue state_read_called = state_read_called_by_state_name_[state_name]; + BValue state_write_called = state_write_called_by_state_name_[state_name]; if (options_.emit_assert) { // Assert write doesn't happen before a read implicit_token_data_->entry_token = function_builder_->Assert( @@ -2714,7 +2728,7 @@ absl::Status FunctionConverter::HandleBuiltinWrite(const Invocation* node) { implicit_token_data_->control_tokens.push_back( implicit_token_data_->entry_token); } - state_write_called_by_state_name_[target->ToString()] = + state_write_called_by_state_name_[state_name] = function_builder_->Or(state_write_called, active); XLS_RETURN_IF_ERROR(Visit(target)); XLS_ASSIGN_OR_RETURN(BValue state_read, Use(target)); @@ -2906,6 +2920,10 @@ absl::StatusOr FunctionConverter::IrValueToChannelRef( return absl::InvalidArgumentError( "Unexpected ChannelArray in IrValue."); }, + [](ProcDefInstance* proc) -> absl::StatusOr { + return absl::InvalidArgumentError( + "Unexpected proc instance in IrValue."); + }, }, ir_value); } @@ -3653,6 +3671,194 @@ absl::Status FunctionConverter::HandleChannelDecl(const ChannelDecl* node) { return absl::OkStatus(); } +absl::Status FunctionConverter::HandleProcDefConstructor( + const ProcDef& proc, const Function& constructor, + const ParametricEnv& bindings, ProcBuilder* builder) { + // Generate channel interfaces. + for (const Param* param : constructor.params()) { + VLOG(5) << "Generating channel interface: " << param->ToString(); + XLS_ASSIGN_OR_RETURN(Type * type, + current_type_info_->GetItemOrError(param)); + ChannelOrArray channel_or_array; + const ChannelType* channel_type = nullptr; + + if (ArrayType* array_type = dynamic_cast(type); + array_type != nullptr) { + const Type& innermost_type = + array_type->GetInnermostElementType().element_type; + VLOG(5) << "Lowering to PSC, innermost type: " + << innermost_type.ToString(); + channel_type = dynamic_cast(&innermost_type); + } else { + channel_type = dynamic_cast(type); + } + + if (channel_type != nullptr) { + VLOG(10) << "Param " << param->ToString() << " has channel type " + << channel_type->ToString(); + XLS_ASSIGN_OR_RETURN(channel_or_array, + channel_scope_->DefineBoundaryChannelOrArray( + param, current_type_info_)); + SetNodeToChannelOrArray(param->name_def(), channel_or_array); + XLS_RETURN_IF_ERROR(channel_scope_->AssociateWithExistingChannelOrArray( + *proc_id_, param->name_def(), channel_or_array)); + } + } + + VLOG(5) << "Visiting body of constructor: " << constructor.identifier(); + FunctionTag original_tag = current_fn_tag_; + current_fn_tag_ = FunctionTag::kProcConfig; + last_tuple_.clear(); + XLS_RETURN_IF_ERROR(Visit(constructor.body())); + current_fn_tag_ = original_tag; + + XLS_RET_CHECK_EQ(last_tuple_.size(), proc.members().size()); + for (int i = 0; i < proc.members().size(); i++) { + const StructMemberNode* member = proc.members()[i]; + VLOG(5) << "Handling proc member " << member->ToString(); + + IrValue tuple_entry = last_tuple_[i]; + SetNodeToIr(member->name_def(), tuple_entry); + std::optional channel_or_array = absl::visit( + Visitor{ + [](Channel* chan) -> std::optional { return chan; }, + [](ChannelArray* ca) -> std::optional { + return ca; + }, + [](ChannelInterface* ci) -> std::optional { + return ci; + }, + [](auto v) -> std::optional { + return std::nullopt; + }, + }, + tuple_entry); + + if (channel_or_array.has_value()) { + proc_data_->id_to_members.at( + *proc_id_)[member->name_def()->identifier()] = + ChannelOrArrayToProcConfigValue(*channel_or_array); + XLS_RETURN_IF_ERROR(channel_scope_->AssociateWithExistingChannelOrArray( + *proc_id_, member->name_def(), *channel_or_array)); + + // TODO: https://github.com/google/xls/issues/4125 - Apply channel + // strictness. Currently it is not captured by the parser on a ProcDef + // member. + } + } + + return absl::OkStatus(); +} + +absl::Status FunctionConverter::ConvertProcDef(const ProcDef* proc_def, + const Function* constructor, + ProcId proc_id, + TypeInfo* type_info) { + ScopedTypeInfoSwap stis(this, type_info); + proc_id_ = proc_id; + proc_data_->id_to_members[proc_id] = {}; + + // TODO: https://github.com/google/xls/issues/4125 - Consider the parametrics + // when invoking `MangleDslxName` here, using `HandleProcNextFunction` as a + // rough guide. ProcDef support is a WIP and we don't yet support parametrics. + ParametricEnv env; + XLS_ASSIGN_OR_RETURN(std::string mangled_name, + MangleDslxName(module_->name(), proc_def->identifier(), + CallingConvention::kProcNext, + /*free_keys=*/{}, &env)); + auto unique_builder = + std::make_unique(NewStyleProc{}, mangled_name, package()); + ProcBuilder* builder = unique_builder.get(); + SetFunctionBuilder(std::move(unique_builder)); + if (is_top_) { + XLS_RETURN_IF_ERROR(builder->SetAsTop()); + } + + // We make an implicit token in case any downstream functions need it; if it's + // unused, it'll be optimized out later. + BValue implicit_token = + builder->Literal(Value::Token(), SourceInfo(), "__token"); + implicit_token_data_ = ImplicitTokenData{ + .entry_token = implicit_token, + .activated = builder->Literal(Value::Bool(true)), + .create_control_predicate = + [this]() { return implicit_token_data_->activated; }, + }; + tokens_.push_back(implicit_token); + + XLS_ASSIGN_OR_RETURN(InterpValue init_interp_value, + current_type_info_->GetConstExpr(constructor->body())); + XLS_ASSIGN_OR_RETURN(Value init_value, InterpValueToValue(init_interp_value)); + + XLS_ASSIGN_OR_RETURN( + std::vector state_elements, + GetProcDefStateMembers(proc_def, *import_data_, *current_type_info_)); + VLOG(5) << "Proc has " << state_elements.size() << " state elements."; + + for (int i = 0; i < state_elements.size(); i++) { + StructMemberNode* state_element = state_elements[i]; + VLOG(5) << "Configuring state element: " << state_element->name(); + + state_read_called_by_state_name_[state_element->name()] = + builder->Literal(UBits(0, 1)); + state_write_called_by_state_name_[state_element->name()] = + builder->Literal(UBits(0, 1)); + + PackageInterfaceProto::Proc::StateValue* state_value_proto = + proc_proto_.value()->add_state_values(); + PackageInterfaceProto::NamedValue* state_name_proto = + state_value_proto->mutable_name(); + std::string state_name = absl::StrCat("__", state_element->name()); + Value init = init_value.elements()[i]; + BValue state_element_value = builder->StateElement(state_name, init); + state_name_proto->set_name(state_name); + XLS_ASSIGN_OR_RETURN(auto type, ResolveTypeToIr(state_element->type())); + *state_name_proto->mutable_type() = type->ToProto(); + SetNodeToIr(state_element->name_def(), state_element_value); + } + + // TODO: https://github.com/google/xls/issues/4125 - Deal with the parametric + // bindings here, using `HandleProcNextFunction` as a rough guide. + + VLOG(3) << "Proc has " << constant_deps_.size() << " constant deps"; + for (ConstantDef* dep : constant_deps_) { + VLOG(5) << "Visiting constant dep: " << dep->ToString(); + + // The constant dep may be from a different module than the module for the + // function we're currently converting. + XLS_ASSIGN_OR_RETURN(std::optional stis, + ScopedTypeInfoSwap::ForNode(this, dep)); + XLS_RETURN_IF_ERROR(Visit(dep)); + } + + auto proc_scoped_channel_scope = std::make_unique( + package_data_.conversion_info, import_data_, options_, builder); + proc_scoped_channel_scope->EnterFunctionContext(current_type_info_, env); + channel_scope_ = proc_scoped_channel_scope.get(); + + XLS_RETURN_IF_ERROR( + HandleProcDefConstructor(*proc_def, *constructor, env, builder)); + + VLOG(5) << "Visiting proc next body."; + std::optional next_fn = GetProcNextFunction(proc_def); + XLS_RET_CHECK(next_fn.has_value()); + + // TODO: https://github.com/google/xls/issues/4125 - Tell the ChannelScope to + // enter the function context, when next() has a dedicated + // TypeInfo/ParametricEnv. + + auto proc_def_instance = + std::make_unique(proc_def, last_tuple_); + SetNodeToIr((*next_fn)->params()[0]->name_def(), proc_def_instance.get()); + proc_def_instances_.push_back(std::move(proc_def_instance)); + XLS_RETURN_IF_ERROR(Visit((*next_fn)->body())); + XLS_ASSIGN_OR_RETURN(xls::Proc * p, builder->Build()); + + package_data_.ir_to_dslx[p] = *next_fn; + package_data_.callee_to_ir_proc[{*next_fn, env}] = p; + return absl::OkStatus(); +} + // TODO: davidplass - break this method up. It's too big. absl::Status FunctionConverter::HandleProcNextFunction( const ConversionRecord& record, ImportData* import_data, @@ -4033,10 +4239,10 @@ absl::Status FunctionConverter::HandleSplatStructInstance( XLS_ASSIGN_OR_RETURN(TypeDefinition type_definition, ToTypeDefinition(node->struct_ref())); - XLS_ASSIGN_OR_RETURN(StructDef * struct_def, DerefStruct(type_definition)); + XLS_ASSIGN_OR_RETURN(StructDefBase * def, DerefStructOrProc(type_definition)); std::vector members; - for (int64_t i = 0; i < struct_def->members().size(); ++i) { - const std::string& k = struct_def->GetMemberName(i); + for (int64_t i = 0; i < def->members().size(); ++i) { + const std::string& k = def->GetMemberName(i); if (auto it = updates.find(k); it != updates.end()) { members.push_back(it->second); } else { @@ -4055,9 +4261,50 @@ absl::Status FunctionConverter::HandleStructInstance( std::vector operands; XLS_ASSIGN_OR_RETURN(TypeDefinition type_definition, ToTypeDefinition(node->struct_ref())); - XLS_ASSIGN_OR_RETURN(StructDef * struct_def, DerefStruct(type_definition)); + XLS_ASSIGN_OR_RETURN(StructDefBase * def, DerefStructOrProc(type_definition)); + if (def->kind() == AstNodeKind::kProcDef) { + // Mimic the `HandleXlsTuple` handling for the final tuple in a legacy proc + // `config`, although the mechanics are a bit different. The tuple we create + // here is the IR version of the `ProcDef` instance, which includes both + // channels and state elements. + std::vector ir_operands; + std::vector b_operands; + for (const auto& [name, value] : node->GetOrderedMembers(def)) { + std::optional member = def->GetMemberByName(name); + XLS_RET_CHECK(member.has_value()); + std::optional member_type = + current_type_info_->GetItem(*member); + XLS_RET_CHECK(member_type.has_value()); + std::optional v; + if (IsProcDefStateType(**member_type, *import_data_)) { + // Get the state element by member `NameDef`, i.e. don't treat it like + // it's a value supplied by the struct instance node. + v = GetNodeToIr((*member)->name_def()); + } else { + // If it's not a state element (e.g. it's a channel) then it's a + // "normal" IR value. + XLS_RETURN_IF_ERROR(Visit(value)); + v = GetNodeToIr(value); + } + + XLS_RET_CHECK(v.has_value()) + << "Expected a value for proc member " << name; + if (std::holds_alternative(*v)) { + b_operands.push_back(std::get(*v)); + } + ir_operands.push_back(*v); + } + last_tuple_ = ir_operands; + + Def(node, [this, &b_operands](const SourceInfo& loc) { + return function_builder_->Tuple(b_operands, loc); + }); + return absl::OkStatus(); + } + std::vector const_operands; - for (const auto& [_, member_expr] : node->GetOrderedMembers(struct_def)) { + for (const auto& [_, member_expr] : node->GetOrderedMembers(def)) { + VLOG(10) << "Visiting member expr " << member_expr->ToString(); XLS_RETURN_IF_ERROR(Visit(member_expr)); XLS_ASSIGN_OR_RETURN(BValue operand, Use(member_expr)); operands.push_back(operand); @@ -4276,9 +4523,23 @@ absl::Status FunctionConverter::HandleAttr(const Attr* node) { VLOG(5) << "FunctionConverter::HandleAttr: " << node->ToString() << " @ " << node->span().ToString(file_table()); XLS_RETURN_IF_ERROR(Visit(node->lhs())); + std::optional lhs_type = current_type_info_->GetItem(node->lhs()); XLS_RET_CHECK(lhs_type.has_value()); + + // If it's a reference to `self.something` where `self` is a proc, we need + // `IrValue`-based logic, because it may be a channel or other thing that a + // real IR tuple can't contain. + std::optional value = GetNodeToIr(node->lhs()); + if (value.has_value() && std::holds_alternative(*value)) { + XLS_ASSIGN_OR_RETURN(int64_t index, + (*lhs_type)->AsProc().GetMemberIndex(node->attr())); + SetNodeToIr(node, + std::get(*value)->member_values.at(index)); + return absl::OkStatus(); + } + auto* struct_type = dynamic_cast(lhs_type.value()); std::string_view identifier = node->attr(); XLS_ASSIGN_OR_RETURN(int64_t index, struct_type->GetMemberIndex(identifier)); @@ -4811,6 +5072,10 @@ FunctionConverter::DerefStructOrEnum(TypeDefinition node) { return std::get(node); } + if (std::holds_alternative(node)) { + return std::get(node); + } + XLS_RET_CHECK(std::holds_alternative(node)); auto* colon_ref = std::get(node); std::optional import = colon_ref->ResolveImportSubject(); @@ -4825,10 +5090,12 @@ FunctionConverter::DerefStructOrEnum(TypeDefinition node) { return DerefStructOrEnum(td); } -absl::StatusOr FunctionConverter::DerefStruct(TypeDefinition node) { +absl::StatusOr FunctionConverter::DerefStructOrProc( + TypeDefinition node) { XLS_ASSIGN_OR_RETURN(DerefVariant v, DerefStructOrEnum(node)); - XLS_RET_CHECK(std::holds_alternative(v)); - return std::get(v); + XLS_RET_CHECK(std::holds_alternative(v) || + std::holds_alternative(v)); + return absl::down_cast(ToAstNode(v)); } absl::StatusOr FunctionConverter::DerefEnum(TypeDefinition node) { diff --git a/xls/dslx/ir_convert/function_converter.h b/xls/dslx/ir_convert/function_converter.h index b2313ba2dd..90b329f397 100644 --- a/xls/dslx/ir_convert/function_converter.h +++ b/xls/dslx/ir_convert/function_converter.h @@ -157,6 +157,10 @@ class FunctionConverter { ImportData* import_data, ProcConversionData* proc_data); + absl::Status ConvertProcDef(const ProcDef* proc_def, + const Function* constructor, ProcId proc_id, + TypeInfo* type_info); + // Notes a constant-definition dependency for the function (so it can // participate in the IR conversion). void AddConstantDep(ConstantDef* constant_def); @@ -238,11 +242,21 @@ class FunctionConverter { BValue value; }; + struct ProcDefInstance; + // Every AST node has an "IR value" that is either a function builder value // (BValue) or its IR-conversion-time-constant-decorated cousin (CValue), or // an inter-proc Channel. - using IrValue = - std::variant; + using IrValue = std::variant; + + // The `IrValue` for an instance of an impl-based proc. This is basically a + // tuple, except that the members are typically non-BValues (e.g. channels). + // When these are supported by BValues, we can probably use a native tuple. + struct ProcDefInstance { + const ProcDef* proc_def; + std::vector member_values; + }; // Helper for converting an IR value to its BValue pointer for use in // debugging. @@ -518,8 +532,16 @@ class FunctionConverter { absl::Status HandleBuiltinZip(const Invocation* node); // keep-sorted end - // Derefences the type definition to a struct definition. - absl::StatusOr DerefStruct(TypeDefinition node); + absl::Status HandleProcDef(const ProcDef* proc_def, + const Function* constructor); + + absl::Status HandleProcDefConstructor(const ProcDef& proc, + const Function& constructor, + const ParametricEnv& bindings, + ProcBuilder* builder); + + // Dereferences the type definition to a struct or impl-style proc definition. + absl::StatusOr DerefStructOrProc(TypeDefinition node); // Derefences the type definition to a enum definition. absl::StatusOr DerefEnum(TypeDefinition node); @@ -548,7 +570,7 @@ class FunctionConverter { // Dereferences a type definition to either a struct definition or enum // definition. - using DerefVariant = std::variant; + using DerefVariant = std::variant; absl::StatusOr DerefStructOrEnum(TypeDefinition node); SourceInfo ToSourceInfo(const std::optional& span) { @@ -668,6 +690,8 @@ class FunctionConverter { // every time we emit a state read/write. absl::flat_hash_map state_read_called_by_state_name_; absl::flat_hash_map state_write_called_by_state_name_; + + std::vector> proc_def_instances_; }; } // namespace xls::dslx diff --git a/xls/dslx/ir_convert/get_conversion_records.cc b/xls/dslx/ir_convert/get_conversion_records.cc index 420a4fc67a..43f715c494 100644 --- a/xls/dslx/ir_convert/get_conversion_records.cc +++ b/xls/dslx/ir_convert/get_conversion_records.cc @@ -40,6 +40,7 @@ #include "xls/dslx/frontend/proc_id.h" #include "xls/dslx/interp_value.h" #include "xls/dslx/ir_convert/conversion_record.h" +#include "xls/dslx/ir_convert/ir_conversion_utils.h" #include "xls/dslx/type_system/parametric_env.h" #include "xls/dslx/type_system/type_info.h" #include "xls/public/status_macros.h" @@ -339,6 +340,30 @@ class ConversionRecordVisitor : public AstNodeVisitorWithDefault { return DefaultHandler(qc); } + absl::Status HandleProcDef(const ProcDef* p) override { + VLOG(5) << "HandleProcDef " << p->ToString(); + + XLS_ASSIGN_OR_RETURN(Function * constructor, + GetTopProcConstructor(p, type_info_)); + XLS_ASSIGN_OR_RETURN(InterpValue initial_state, + type_info_->GetConstExpr(constructor->body())); + + VLOG(5) << "Initial state: " << initial_state.ToHumanString(); + + std::optional next_fn = GetProcNextFunction(p); + XLS_RET_CHECK(next_fn.has_value()); + + XLS_ASSIGN_OR_RETURN( + ConversionRecord cr, + MakeConversionRecord(*next_fn, p->owner(), type_info_, + /*bindings=*/ParametricEnv(), + /*proc_id=*/proc_id_factory_.CreateProcId(p), + /*is_top=*/top_ == next_fn, + /*config_record=*/nullptr, initial_state)); + records_.push_back(std::move(cr)); + return absl::OkStatus(); + } + absl::Status HandleProc(const Proc* p) override { VLOG(5) << "HandleProc " << p->ToString(); const Function* next_fn = &p->next(); @@ -456,7 +481,6 @@ class ConversionRecordVisitor : public AstNodeVisitorWithDefault { OK_HANDLER(ConstAssert) OK_HANDLER(EnumDef) OK_HANDLER(ParametricBinding) - OK_HANDLER(ProcDef) OK_HANDLER(StructDef) OK_HANDLER(TypeAlias) // keep-sorted end @@ -536,38 +560,45 @@ absl::StatusOr> GetConversionRecords( } absl::StatusOr> GetConversionRecordsForEntry( - std::variant entry, TypeInfo* type_info, + std::variant entry, TypeInfo* type_info, std::optional resolved_proc_alias) { ProcIdFactory proc_id_factory; + std::vector records; + absl::flat_hash_set processed_invocations; + AstNode* visit_target = ToAstNode(entry); + Module* m = visit_target->owner(); + Function* top_fn = nullptr; + TypeInfo* visitor_ti = type_info; + if (std::holds_alternative(entry)) { XLS_RET_CHECK(!resolved_proc_alias.has_value()); - Function* f = std::get(entry); - Module* m = f->owner(); - std::vector records; - absl::flat_hash_set processed_invocations; - // We are only ever called for tests, so we set include_tests to - // true, and make sure that this function is top. - ConversionRecordVisitor visitor( - m, type_info, /*include_tests=*/true, proc_id_factory, f, - /*resolved_proc_alias=*/std::nullopt, records, processed_invocations); - XLS_RETURN_IF_ERROR(f->Accept(&visitor)); - - return RemoveFunctionDuplicates(records); + top_fn = std::get(entry); + } else if (std::holds_alternative(entry)) { + Proc* p = std::get(entry); + XLS_ASSIGN_OR_RETURN(TypeInfo * new_ti, + type_info->GetTopLevelProcTypeInfo(p)); + visitor_ti = new_ti; + top_fn = &p->next(); + visit_target = p; + } else { + XLS_RET_CHECK(!resolved_proc_alias.has_value()); + ProcDef* p = std::get(entry); + std::optional next_fn = GetProcNextFunction(p); + if (!next_fn.has_value()) { + return absl::InvalidArgumentError( + "A proc with no 'next' function cannot be top."); + } + + top_fn = *next_fn; + visit_target = p; } - Proc* p = std::get(entry); - Module* m = p->owner(); - XLS_ASSIGN_OR_RETURN(TypeInfo * new_ti, - type_info->GetTopLevelProcTypeInfo(p)); - std::vector records; - absl::flat_hash_set processed_invocations; // We are only ever called for tests, so we set include_tests to true, // and make sure that this proc's next function is top. - ConversionRecordVisitor visitor( - m, new_ti, /*include_tests=*/true, proc_id_factory, &p->next(), - resolved_proc_alias, records, processed_invocations); - XLS_RETURN_IF_ERROR(p->Accept(&visitor)); - + ConversionRecordVisitor visitor(m, visitor_ti, /*include_tests=*/true, + proc_id_factory, top_fn, resolved_proc_alias, + records, processed_invocations); + XLS_RETURN_IF_ERROR(visit_target->Accept(&visitor)); return RemoveFunctionDuplicates(records); } } // namespace xls::dslx diff --git a/xls/dslx/ir_convert/get_conversion_records.h b/xls/dslx/ir_convert/get_conversion_records.h index c7343d2ab5..0b62cd8095 100644 --- a/xls/dslx/ir_convert/get_conversion_records.h +++ b/xls/dslx/ir_convert/get_conversion_records.h @@ -40,7 +40,7 @@ absl::StatusOr> GetConversionRecords( // entry: Proc or Function to start from (the top) // type_info: Mapping from node to type. absl::StatusOr> GetConversionRecordsForEntry( - std::variant entry, TypeInfo* type_info, + std::variant entry, TypeInfo* type_info, std::optional resolved_proc_alias); } // namespace xls::dslx diff --git a/xls/dslx/ir_convert/ir_conversion_utils.cc b/xls/dslx/ir_convert/ir_conversion_utils.cc index a1d8494bdd..edd912c102 100644 --- a/xls/dslx/ir_convert/ir_conversion_utils.cc +++ b/xls/dslx/ir_convert/ir_conversion_utils.cc @@ -15,16 +15,22 @@ #include #include +#include +#include #include #include "absl/log/log.h" #include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/strings/str_cat.h" +#include "absl/strings/substitute.h" +#include "xls/common/status/ret_check.h" #include "xls/common/status/status_macros.h" +#include "xls/dslx/frontend/ast.h" #include "xls/dslx/interp_value.h" #include "xls/dslx/type_system/parametric_env.h" #include "xls/dslx/type_system/type.h" +#include "xls/dslx/type_system/type_info.h" #include "xls/ir/package.h" #include "xls/ir/type.h" @@ -143,4 +149,71 @@ absl::StatusOr TypeToIr(Package* package, const Type& type, return v.retval(); } +std::optional GetProcNextFunction(const ProcDef* proc) { + for (ImplMember member : (*proc->impl())->members()) { + if (!std::holds_alternative(member)) { + continue; + } + + Function* fn = std::get(member); + if (fn->identifier() == "next") { + return fn; + } + } + + return std::nullopt; +} + +absl::StatusOr> GetProcConstructors(const ProcDef* p, + const TypeInfo* ti) { + XLS_RET_CHECK(p->impl().has_value()); + const Impl* impl = *p->impl(); + std::vector result; + for (ImplMember member : impl->members()) { + if (std::holds_alternative(member)) { + Function* function = std::get(member); + XLS_ASSIGN_OR_RETURN(const Type* fn_type, ti->GetItemOrError(function)); + XLS_RET_CHECK(fn_type->IsFunction()); + + if (!fn_type->AsFunction().params().empty()) { + const Type& first_param_type = *fn_type->AsFunction().params().front(); + if (first_param_type.IsProc() && + &first_param_type.AsProc().struct_def_base() == p) { + continue; + } + } + + // It's only a constructor if it returns effectively `Self`. + const Type& return_type = fn_type->AsFunction().return_type(); + if (return_type.IsProc() && + &return_type.AsProc().struct_def_base() == p) { + result.push_back(function); + } + } + } + return result; +} + +absl::StatusOr GetTopProcConstructor(const ProcDef* proc, + const TypeInfo* ti) { + XLS_ASSIGN_OR_RETURN(std::vector constructors, + GetProcConstructors(proc, ti)); + if (constructors.empty()) { + return absl::InvalidArgumentError(absl::Substitute( + "Proc '$0' does not have a constructor, i.e., a static function " + "returning Self, so it cannot be used as a top proc.", + proc->identifier())); + } + + if (constructors.size() > 1) { + return absl::InvalidArgumentError(absl::Substitute( + "Proc '$0' has $1 possible constructors, i.e. static functions " + "returning Self. In order to be used as a top proc, there must only be " + "one constructor.", + proc->identifier(), constructors.size())); + } + + return constructors.front(); +} + } // namespace xls::dslx diff --git a/xls/dslx/ir_convert/ir_conversion_utils.h b/xls/dslx/ir_convert/ir_conversion_utils.h index 9050fa6d4c..3e064f5033 100644 --- a/xls/dslx/ir_convert/ir_conversion_utils.h +++ b/xls/dslx/ir_convert/ir_conversion_utils.h @@ -18,10 +18,14 @@ // tree traversal (which is the main concern of ir_converter.h/cc. #include +#include +#include #include "absl/status/statusor.h" +#include "xls/dslx/frontend/ast.h" #include "xls/dslx/type_system/parametric_env.h" #include "xls/dslx/type_system/type.h" +#include "xls/dslx/type_system/type_info.h" #include "xls/ir/package.h" #include "xls/ir/type.h" @@ -39,6 +43,20 @@ absl::StatusOr ResolveDimToInt(const TypeDim& dim, absl::StatusOr TypeToIr(Package* package, const Type& type, const ParametricEnv& bindings); +// Returns the `next` function of the given proc, if it has one. +std::optional GetProcNextFunction(const ProcDef* proc); + +// Returns all functions in `proc` that are constructors by signature, i.e. +// static functions returning `Self`. +absl::StatusOr> GetProcConstructors(const ProcDef* proc, + const TypeInfo* ti); + +// Returns the one function in `proc` that is a constructor by signature. This +// errors unless there is one and only one constructor, because currently there +// is no option or attribute to select one of many. +absl::StatusOr GetTopProcConstructor(const ProcDef* proc, + const TypeInfo* ti); + } // namespace xls::dslx #endif // XLS_DSLX_IR_CONVERT_IR_CONVERSION_UTILS_H_ diff --git a/xls/dslx/ir_convert/ir_converter.cc b/xls/dslx/ir_convert/ir_converter.cc index 0b7b503e17..130198636e 100644 --- a/xls/dslx/ir_convert/ir_converter.cc +++ b/xls/dslx/ir_convert/ir_converter.cc @@ -33,6 +33,7 @@ #include #include "absl/algorithm/container.h" +#include "absl/base/casts.h" #include "absl/container/flat_hash_map.h" #include "absl/container/flat_hash_set.h" #include "absl/log/log.h" @@ -41,6 +42,7 @@ #include "absl/strings/str_cat.h" #include "absl/strings/str_format.h" #include "absl/strings/str_join.h" +#include "absl/strings/substitute.h" #include "absl/types/span.h" #include "cppitertools/filter.hpp" #include "cppitertools/imap.hpp" @@ -66,10 +68,12 @@ #include "xls/dslx/ir_convert/extract_conversion_order.h" #include "xls/dslx/ir_convert/function_converter.h" #include "xls/dslx/ir_convert/get_conversion_records.h" +#include "xls/dslx/ir_convert/ir_conversion_utils.h" #include "xls/dslx/ir_convert/proc_config_ir_converter.h" #include "xls/dslx/parse_and_typecheck.h" #include "xls/dslx/type_system/parametric_env.h" #include "xls/dslx/type_system/type_info.h" +#include "xls/dslx/type_system_v2/import_utils.h" #include "xls/dslx/virtualizable_file_system.h" #include "xls/dslx/warning_collector.h" #include "xls/dslx/warning_kind.h" @@ -242,6 +246,22 @@ absl::Status ConvertOneFunctionInternal(PackageData& package_data, return converter.HandleProcNextFunction(record, import_data, proc_data); } + XLS_ASSIGN_OR_RETURN(bool is_proc_def_next, + IsProcDefNextFunction(f, *import_data)); + if (is_proc_def_next) { + XLS_ASSIGN_OR_RETURN(std::optional def, + GetStructOrProcDef(f, *import_data)); + XLS_RET_CHECK(def.has_value()); + const ProcDef* proc_def = absl::down_cast(*def); + // TODO: https://github.com/google/xls/issues/4125 - Specify the intended + // constructor in the params to `ConvertOneFunctionInternal`. For now we + // assume the proc is top, because we don't yet support spawns. + XLS_ASSIGN_OR_RETURN(Function * constructor, + GetTopProcConstructor(proc_def, record.type_info())); + return converter.ConvertProcDef(proc_def, constructor, *record.proc_id(), + record.type_info()); + } + return converter.HandleFunction(f, record.type_info(), &record.parametric_env()); } @@ -449,6 +469,18 @@ absl::Status CheckAcceptableTopProc(Proc* proc) { return absl::OkStatus(); } +absl::Status CheckAcceptableTopProcDef(const ProcDef* proc, + const TypeInfo* ti) { + if (!proc->impl().has_value()) { + return absl::InvalidArgumentError( + absl::Substitute("Cannot convert proc '$0' because it does not have an " + "impl with a constructor.", + proc->identifier())); + } + + return GetTopProcConstructor(proc, ti).status(); +} + template absl::Status ConvertOneFunctionIntoPackageInternal( BlockT* block, ImportData* import_data, const ConvertOptions& options, @@ -526,6 +558,15 @@ absl::Status ConvertOneFunctionIntoPackage(Module* module, conv); } + absl::StatusOr proc_def = + module->GetMemberOrError(entry_function_name); + if (proc_def.ok()) { + XLS_ASSIGN_OR_RETURN(TypeInfo * ti, import_data->GetRootTypeInfo(module)); + XLS_RETURN_IF_ERROR(CheckAcceptableTopProcDef(*proc_def, ti)); + return ConvertOneFunctionIntoPackageInternal(*proc_def, import_data, + options, conv); + } + absl::StatusOr proc_alias = module->GetMemberOrError(entry_function_name); if (proc_alias.ok()) { diff --git a/xls/dslx/ir_convert/ir_converter_test.cc b/xls/dslx/ir_convert/ir_converter_test.cc index 8d3686078e..454bf5b502 100644 --- a/xls/dslx/ir_convert/ir_converter_test.cc +++ b/xls/dslx/ir_convert/ir_converter_test.cc @@ -3023,6 +3023,117 @@ proc A { ExpectIr(converted); } +TEST_F(IrConverterTest, TopProcDefWithNoSpawns) { + constexpr std::string_view program = R"( +#![feature(explicit_state_access)] + +proc Main { + c_in: chan in, + c_out: chan out, + i: u32, +} + +impl Main { + fn new(c_in: chan in, c_out: chan out) -> Self { + Main { c_in: c_in, c_out: c_out, i: 1 } + } + + fn next(self) { + let i_val = read(self.i); + let (tok, j) = recv(join(), self.c_in); + let tok = send(tok, self.c_out, i_val + j); + write(self.i, i_val + j); + } +} +)"; + + auto import_data = CreateImportDataForTest(); + XLS_ASSERT_OK_AND_ASSIGN( + std::string converted, + ConvertOneFunctionForTest(program, "Main", import_data)); + ExpectIr(converted); +} + +TEST_F(IrConverterTest, TopProcDefWithNoConstructorFails) { + constexpr std::string_view program = R"( +#![feature(explicit_state_access)] + +proc Main { + c_in: chan in, + c_out: chan out, + i: u32, +} + +impl Main { + fn next(self) { + let i_val = read(self.i); + let (tok, j) = recv(join(), self.c_in); + let tok = send(tok, self.c_out, i_val + j); + write(self.i, i_val + j); + } +} +)"; + + auto import_data = CreateImportDataForTest(); + EXPECT_THAT(ConvertOneFunctionForTest(program, "Main", import_data).status(), + StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("Proc 'Main' does not have a constructor"))); +} + +TEST_F(IrConverterTest, TopProcDefWithNonConstructorNewFails) { + constexpr std::string_view program = R"( +#![feature(explicit_state_access)] + +proc Main { + c_in: chan in, + c_out: chan out, + i: u32, +} + +impl Main { + fn new(c_in: chan in, c_out: chan out) { + let x = Main { c_in: c_in, c_out: c_out, i: 1 }; + } + + fn next(self) { + let i_val = read(self.i); + let (tok, j) = recv(join(), self.c_in); + let tok = send(tok, self.c_out, i_val + j); + write(self.i, i_val + j); + } +} +)"; + + auto import_data = CreateImportDataForTest(); + EXPECT_THAT(ConvertOneFunctionForTest(program, "Main", import_data).status(), + StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("Proc 'Main' does not have a constructor"))); +} + +TEST_F(IrConverterTest, TopProcDefWithNoNextFails) { + constexpr std::string_view program = R"( +#![feature(explicit_state_access)] + +proc Main { + c_in: chan in, + c_out: chan out, + i: u32, +} + +impl Main { + fn new(c_in: chan in, c_out: chan out) -> Self { + Main { c_in: c_in, c_out: c_out, i: 1 } + } +} +)"; + + auto import_data = CreateImportDataForTest(); + EXPECT_THAT( + ConvertOneFunctionForTest(program, "Main", import_data).status(), + StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("A proc with no 'next' function cannot be top."))); +} + TEST_F(IrConverterTest, SendIfRecvIf) { constexpr std::string_view program = R"(proc producer { c: chan out; diff --git a/xls/dslx/ir_convert/testdata/ir_converter_test_TopProcDefWithNoSpawns.ir b/xls/dslx/ir_convert/testdata/ir_converter_test_TopProcDefWithNoSpawns.ir new file mode 100644 index 0000000000..6506978c62 --- /dev/null +++ b/xls/dslx/ir_convert/testdata/ir_converter_test_TopProcDefWithNoSpawns.ir @@ -0,0 +1,39 @@ +package test_module + +file_number 0 "test_module.x" + +top proc __test_module__Main_next<_c_in: bits[32] in, _c_out: bits[32] out>(__i: bits[32], init={1}) { + chan_interface _c_in(direction=receive, kind=streaming, strictness=proven_mutually_exclusive, flow_control=ready_valid, flop_kind=none) + chan_interface _c_out(direction=send, kind=streaming, strictness=proven_mutually_exclusive, flow_control=ready_valid, flop_kind=none) + literal.2: bits[1] = literal(value=1, id=2) + literal.3: bits[1] = literal(value=0, id=3) + after_all.13: token = after_all(id=13) + not.7: bits[1] = not(literal.2, id=7) + not.8: bits[1] = not(literal.3, id=8) + __i: bits[32] = state_read(state_element=__i, id=5) + receive.14: (token, bits[32]) = receive(after_all.13, predicate=literal.2, channel=_c_in, id=14) + __token: token = literal(value=token, id=1) + or.9: bits[1] = or(not.7, not.8, id=9) + not.21: bits[1] = not(literal.2, id=21) + or.11: bits[1] = or(literal.3, literal.2, id=11) + literal.4: bits[1] = literal(value=0, id=4) + i_val: bits[32] = identity(__i, id=12) + j: bits[32] = tuple_index(receive.14, index=1, id=17) + assert.10: token = assert(__token, or.9, message="State element read after read in same activation.", id=10) + or.22: bits[1] = or(not.21, or.11, id=22) + not.24: bits[1] = not(literal.2, id=24) + not.25: bits[1] = not(literal.4, id=25) + tok: token = tuple_index(receive.14, index=0, id=16) + add.18: bits[32] = add(i_val, j, id=18) + assert.23: token = assert(assert.10, or.22, message="State element written before read in same activation.", id=23) + or.26: bits[1] = or(not.24, not.25, id=26) + add.20: bits[32] = add(i_val, j, id=20) + tuple.6: (bits[32]) = tuple(__i, id=6) + tuple_index.15: token = tuple_index(receive.14, index=0, id=15) + tok__1: token = send(tok, add.18, predicate=literal.2, channel=_c_out, id=19) + assert.27: token = assert(assert.23, or.26, message="State element written after write in same activation.", id=27) + or.28: bits[1] = or(literal.4, literal.2, id=28) + next_value.29: () = next_value(param=__i, value=add.20, predicate=literal.2, id=29) + tuple.30: () = tuple(id=30) + tuple.31: () = tuple(id=31) +} diff --git a/xls/dslx/type_system/type.cc b/xls/dslx/type_system/type.cc index 21d50b2904..3f97d0fc5a 100644 --- a/xls/dslx/type_system/type.cc +++ b/xls/dslx/type_system/type.cc @@ -246,6 +246,10 @@ bool Type::CompatibleWith(const Type& other) const { return false; } +bool Type::IsChannel() const { + return dynamic_cast(this) != nullptr; +} + bool Type::IsUnit() const { if (auto* t = dynamic_cast(this)) { return t->empty(); diff --git a/xls/dslx/type_system/type.h b/xls/dslx/type_system/type.h index 82bbf70978..d7f96002cf 100644 --- a/xls/dslx/type_system/type.h +++ b/xls/dslx/type_system/type.h @@ -312,6 +312,7 @@ class Type { // Type equality, but ignores tuple member naming discrepancies. bool CompatibleWith(const Type& other) const; + bool IsChannel() const; bool IsUnit() const; bool IsToken() const; bool IsStruct() const;