diff --git a/xls/ir/function_builder.cc b/xls/ir/function_builder.cc index 6f12a80710..cbcbdbfe12 100644 --- a/xls/ir/function_builder.cc +++ b/xls/ir/function_builder.cc @@ -1227,6 +1227,35 @@ LeafTypeTree BuilderBase::MakeLeafTypeTree(BValue v) { res->AsView(), [&](Node* n) -> BValue { return BValue(n, this); }); } +BValue BuilderBase::Next(class StateElement* state_element, BValue value, + std::optional pred, + std::optional label, + const SourceInfo& loc, std::string_view name) { + if (ErrorPending()) { + return BValue(); + } + if (!value.GetType()->IsEqualTo(state_element->type())) { + return SetError( + absl::StrFormat( + "next value for state element '%s' must be of type %s; is: %s", + state_element->name(), state_element->type()->ToString(), + value.GetType()->ToString()), + loc); + } + if (pred.has_value() && (!pred->GetType()->IsBits() || + pred->GetType()->AsBitsOrDie()->bit_count() != 1)) { + return SetError(absl::StrFormat("Predicate operand of next must be of bits " + "type of width 1; is: %s", + pred->GetType()->ToString()), + loc); + } + return AddNode( + loc, /*state_element=*/state_element, /*value=*/value.node(), + /*predicate=*/pred.has_value() ? std::make_optional(pred->node()) + : std::nullopt, + /*label=*/label, name); +} + FunctionBuilder::FunctionBuilder(std::string_view name, Package* package, bool should_verify) : BuilderBase(std::make_unique(std::string(name), package), @@ -1506,6 +1535,34 @@ BValue ProcBuilder::StateElement(std::string_view name, loc); } +absl::StatusOr ProcBuilder::UnreadStateElement( + std::string_view name, const Value& initial_value, const SourceInfo& loc) { + if (ErrorPending()) { + return GetError(); + } + return proc()->AppendUnreadStateElement(name, initial_value, loc); +} + +BValue ProcBuilder::StateRead(class StateElement* state_element, + std::optional predicate, + std::optional label, + const SourceInfo& loc) { + if (ErrorPending()) { + return BValue(); + } + absl::StatusOr state_read = proc()->AddStateRead( + state_element, + predicate.has_value() ? std::make_optional(predicate->node()) + : std::nullopt, + label, loc); + if (!state_read.ok()) { + return SetError(absl::StrFormat("Unable to add state read: %s", + state_read.status().message()), + loc); + } + return CreateBValue(*state_read, loc); +} + BValue ProcBuilder::Param(std::string_view name, Type* type, const SourceInfo& loc) { if (ErrorPending()) { diff --git a/xls/ir/function_builder.h b/xls/ir/function_builder.h index 00092c5ae6..f6737a6408 100644 --- a/xls/ir/function_builder.h +++ b/xls/ir/function_builder.h @@ -61,6 +61,7 @@ class Function; class FunctionBase; class Node; class Proc; +class StateElement; class Type; // Represents a value for use in the function-definition building process, @@ -695,6 +696,10 @@ class BuilderBase { std::optional pred = std::nullopt, std::optional label = std::nullopt, const SourceInfo& loc = SourceInfo(), std::string_view name = ""); + BValue Next(class StateElement* state_element, BValue value, + std::optional pred = std::nullopt, + std::optional label = std::nullopt, + const SourceInfo& loc = SourceInfo(), std::string_view name = ""); // Converts a BValue to a LeafTypeTree of BValues. LeafTypeTree MakeLeafTypeTree(BValue v); @@ -913,6 +918,17 @@ class ProcBuilder : public BuilderBase { /*read_predicate=*/std::nullopt, loc); } + // Adds a state element to the proc without creating a state read. + absl::StatusOr UnreadStateElement( + std::string_view name, const Value& initial_value, + const SourceInfo& loc = SourceInfo()); + + // Adds a state read node for an existing state element. + BValue StateRead(class StateElement* state_element, + std::optional predicate = std::nullopt, + std::optional label = std::nullopt, + const SourceInfo& loc = SourceInfo()); + // Overriden Param method is explicitly disabled (returns an error). Use // StateElement method to add state elements. BValue Param(std::string_view name, Type* type, diff --git a/xls/ir/function_builder_test.cc b/xls/ir/function_builder_test.cc index 26ac3e87af..aa7af16b4a 100644 --- a/xls/ir/function_builder_test.cc +++ b/xls/ir/function_builder_test.cc @@ -504,6 +504,31 @@ TEST(FunctionBuilderTest, BuildTwiceFails) { HasSubstr("multiple times"))); } +TEST(FunctionBuilderTest, UnreadStateElementAndStateRead) { + Package p("p"); + ProcBuilder b("unread_state_test", &p); + + XLS_ASSERT_OK_AND_ASSIGN( + StateElement * state_element, + b.UnreadStateElement("my_state", Value(UBits(42, 32)))); + + BValue cond = b.Literal(UBits(1, 1)); + BValue not_cond = b.Not(cond); + + BValue read0 = b.StateRead(state_element, cond); + BValue read1 = b.StateRead(state_element, not_cond, "labeled_read"); + + BValue selected = b.Select(cond, {read1, read0}); + + b.Next(state_element, selected); + XLS_ASSERT_OK_AND_ASSIGN(Proc * proc, b.Build()); + + EXPECT_THAT(proc->StateElements(), + ElementsAre(m::StateElement("my_state", Value(UBits(42, 32))))); + + EXPECT_EQ(proc->GetStateReadsByStateElement(state_element).size(), 2); +} + TEST(FunctionBuilderTest, SendAndReceive) { Package p("p"); XLS_ASSERT_OK_AND_ASSIGN( diff --git a/xls/ir/node.cc b/xls/ir/node.cc index 95c2c31b82..f652fe1a5e 100644 --- a/xls/ir/node.cc +++ b/xls/ir/node.cc @@ -641,7 +641,10 @@ std::string Node::ToStringInternal(bool include_operand_types) const { } case Op::kNext: { const Next* next = As(); - args = {absl::StrFormat("param=%s", next->state_read()->GetName()), + std::string param_name = next->has_state_read() + ? next->state_read()->GetName() + : next->state_element()->name(); + args = {absl::StrFormat("param=%s", param_name), absl::StrFormat("value=%s", next->value()->GetName())}; std::optional predicate = next->predicate(); if (predicate.has_value()) { diff --git a/xls/ir/nodes.h b/xls/ir/nodes.h index 6790f970dd..a2151db377 100755 --- a/xls/ir/nodes.h +++ b/xls/ir/nodes.h @@ -834,12 +834,15 @@ class Next final : public Node { absl::StatusOr CloneInNewFunction( absl::Span new_operands, FunctionBase* new_function) const final; + Node* state_read() const { CHECK(state_read_ != nullptr) << "state_read() called on a Next node with only StateElement set"; return state_read_; } + bool has_state_read() const { return state_read_ != nullptr; } + Node* value() const { if (state_read_ == nullptr) { return operand(0); diff --git a/xls/ir/proc.cc b/xls/ir/proc.cc index 8f8dabd8f8..78c5d8970f 100644 --- a/xls/ir/proc.cc +++ b/xls/ir/proc.cc @@ -272,6 +272,18 @@ absl::StatusOr Proc::InsertUnreadStateElement( return state_element; } +absl::StatusOr Proc::AddStateRead(StateElement* state_element, + std::optional predicate, + std::optional label, + const SourceInfo& loc) { + XLS_ASSIGN_OR_RETURN( + StateRead * state_read, + MakeNodeWithName(loc, state_element, predicate, label, + state_element->name())); + state_reads_[state_element].push_back(state_read); + return state_read; +} + absl::StatusOr Proc::InsertStateElement( int64_t index, std::string_view requested_state_name, const Value& init_value, std::optional read_predicate, diff --git a/xls/ir/proc.h b/xls/ir/proc.h index 8385d96758..a6593cba42 100644 --- a/xls/ir/proc.h +++ b/xls/ir/proc.h @@ -228,6 +228,13 @@ class Proc : public FunctionBase { /*next_state=*/std::nullopt, loc); } + // Adds a state read node for an existing state element. + absl::StatusOr AddStateRead( + StateElement* state_element, + std::optional predicate = std::nullopt, + std::optional label = std::nullopt, + const SourceInfo& loc = SourceInfo()); + // Add a new state element (at index) without any reads or nexts. These must // be added separately before verification. absl::StatusOr InsertUnreadStateElement( diff --git a/xls/ir/verify_node.cc b/xls/ir/verify_node.cc index 7931e869fe..91720dc735 100644 --- a/xls/ir/verify_node.cc +++ b/xls/ir/verify_node.cc @@ -978,30 +978,40 @@ class NodeChecker : public DfsVisitor { } absl::Status HandleNext(Next* next) override { - XLS_RETURN_IF_ERROR(ExpectOperandCountRange(next, 2, 3)); - if (!next->state_read()->Is()) { - return absl::InternalError( - absl::StrFormat("Next node %s expects a state read for param; is: %v", - next->GetName(), *next->state_read())); + if (next->has_state_read()) { + XLS_RETURN_IF_ERROR(ExpectOperandCountRange(next, 2, 3)); + if (!next->state_read()->Is()) { + return absl::InternalError(absl::StrFormat( + "Next node %s expects a state read for param; is: %v", + next->GetName(), *next->state_read())); + } + } else { + XLS_RETURN_IF_ERROR(ExpectOperandCountRange(next, 1, 2)); } + if (next->predicate().has_value()) { - XLS_RETURN_IF_ERROR(ExpectOperandHasBitsType(next, /*operand_no=*/2, - /*expected_bit_count=*/1)); + XLS_ASSIGN_OR_RETURN(int64_t pred_idx, next->predicate_operand_number()); + XLS_RETURN_IF_ERROR(ExpectOperandHasBitsType( + next, /*operand_no=*/pred_idx, /*expected_bit_count=*/1)); } if (!next->function_base()->HasEffectiveProc()) { - return absl::InternalError(absl::StrFormat( - "Next node %s (for param %s) is not in a proc", next->GetName(), - next->state_read()->As()->state_element()->name())); + return absl::InternalError( + absl::StrFormat("Next node %s (for param %s) is not in a proc", + next->GetName(), next->state_element()->name())); } Proc* proc = next->function_base()->GetEffectiveProcOrDie(); - XLS_ASSIGN_OR_RETURN( - int64_t index, - proc->GetStateElementIndex( - next->state_read()->As()->state_element())); - XLS_RETURN_IF_ERROR(ExpectOperandHasType(next, /*operand_no=*/0, - proc->GetStateElementType(index))); - return ExpectOperandHasType(next, /*operand_no=*/1, // value is operand 1 - proc->GetStateElementType(index)); + XLS_ASSIGN_OR_RETURN(int64_t index, + proc->GetStateElementIndex(next->state_element())); + + if (next->has_state_read()) { + XLS_RETURN_IF_ERROR(ExpectOperandHasType( + next, /*operand_no=*/0, proc->GetStateElementType(index))); + return ExpectOperandHasType(next, /*operand_no=*/1, // value is operand 1 + proc->GetStateElementType(index)); + } else { + return ExpectOperandHasType(next, /*operand_no=*/0, // value is operand 0 + proc->GetStateElementType(index)); + } } absl::Status HandleNewChannel(NewChannel* new_channel) override {