Skip to content

Commit 8c24a14

Browse files
NL02copybara-github
authored andcommitted
[Explicit State Access] Allow next_values to use state_element instead of state_read.
PiperOrigin-RevId: 885757168
1 parent 5cc4900 commit 8c24a14

9 files changed

Lines changed: 136 additions & 79 deletions

File tree

xls/ir/function.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -329,7 +329,7 @@ absl::Status Function::InternalRebuildSideTables() {
329329
// only held in the side table. We can still check for correctness at least.
330330
// TODO(allight): We should ideally be able to do this.
331331
XLS_RET_CHECK(next_values_.empty());
332-
XLS_RET_CHECK(next_values_by_state_read_.empty());
332+
XLS_RET_CHECK(next_values_by_state_element_.empty());
333333

334334
for (Param* p : params_) {
335335
XLS_RET_CHECK(p->function_base() == this)

xls/ir/function_base.cc

Lines changed: 26 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929

3030
#include "absl/algorithm/container.h"
3131
#include "absl/base/casts.h"
32+
#include "absl/base/no_destructor.h"
3233
#include "absl/container/btree_set.h"
3334
#include "absl/container/flat_hash_map.h"
3435
#include "absl/container/flat_hash_set.h"
@@ -245,7 +246,8 @@ void FunctionBase::TakeOwnershipOfNode(std::unique_ptr<Node>&& node) {
245246
FunctionBase* old_owner = node->function_base();
246247

247248
if (node->Is<StateRead>()) {
248-
old_owner->next_values_by_state_read_.erase(node->As<StateRead>());
249+
old_owner->next_values_by_state_element_.erase(
250+
node->As<StateRead>()->state_element());
249251
}
250252

251253
old_owner->node_iterators_.erase(node.get());
@@ -442,6 +444,22 @@ absl::StatusOr<Node*> FunctionBase::GetNode(
442444
absl::StrFormat("GetNode(%s) failed.", standard_node_name));
443445
}
444446

447+
const absl::btree_set<Next*, Node::NodeIdLessThan>& FunctionBase::next_values(
448+
StateRead* state_read) const {
449+
return next_values(state_read->state_element());
450+
}
451+
452+
const absl::btree_set<Next*, Node::NodeIdLessThan>& FunctionBase::next_values(
453+
StateElement* state_element) const {
454+
if (!next_values_by_state_element_.contains(state_element)) {
455+
static const absl::NoDestructor<
456+
absl::btree_set<Next*, Node::NodeIdLessThan>>
457+
kEmptySet;
458+
return *kEmptySet;
459+
}
460+
return next_values_by_state_element_.at(state_element);
461+
}
462+
445463
absl::Status FunctionBase::RemoveNode(Node* node) {
446464
XLS_RET_CHECK(node->users().empty()) << node->GetName();
447465
XLS_RET_CHECK(!HasImplicitUse(node)) << node->GetName();
@@ -462,13 +480,12 @@ absl::Status FunctionBase::RemoveNode(Node* node) {
462480
params_.end());
463481
}
464482
if (node->Is<StateRead>()) {
465-
next_values_by_state_read_.erase(node->As<StateRead>());
483+
next_values_by_state_element_.erase(node->As<StateRead>()->state_element());
466484
}
467485
if (node->Is<Next>()) {
468486
Next* next = node->As<Next>();
469-
if (next->state_read()->Is<StateRead>()) { // Could've been replaced.
470-
StateRead* state_read = next->state_read()->As<StateRead>();
471-
next_values_by_state_read_.at(state_read).erase(next);
487+
if (next_values_by_state_element_.contains(next->state_element())) {
488+
next_values_by_state_element_.at(next->state_element()).erase(next);
472489
}
473490
std::erase(next_values_, next);
474491
}
@@ -559,13 +576,12 @@ Node* FunctionBase::AddNodeInternal(std::unique_ptr<Node> node) {
559576
params_.push_back(node->As<Param>());
560577
}
561578
if (node->Is<StateRead>()) {
562-
next_values_by_state_read_[node->As<StateRead>()];
579+
next_values_by_state_element_[node->As<StateRead>()->state_element()];
563580
}
564581
if (node->Is<Next>()) {
565582
Next* next = node->As<Next>();
566-
StateRead* state_read = next->state_read()->As<StateRead>();
567-
next_values_.push_back(node->As<Next>());
568-
next_values_by_state_read_[state_read].insert(next);
583+
next_values_.push_back(next);
584+
next_values_by_state_element_[next->state_element()].insert(next);
569585
}
570586
Node* ptr = node.get();
571587
node_iterators_[ptr] = nodes_.insert(nodes_.end(), std::move(node));
@@ -683,7 +699,7 @@ absl::Status FunctionBase::RebuildSideTables() {
683699
// TODO(allight): The fact that there is so much crap in the function_base
684700
// itself is a problem. Having next's and params' in the function base doesn't
685701
// make a ton of sense.
686-
// NB Because of above the next-values/next_values_by_state_read_ and params
702+
// NB Because of above the next-values/next_values_by_state_element_ and
687703
// lists are updated in proc and function respectively.
688704
// NB We assume that node_iterators_ never gets invalidated.
689705
XLS_RETURN_IF_ERROR(InternalRebuildSideTables());

xls/ir/function_base.h

Lines changed: 8 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -26,8 +26,6 @@
2626
#include <utility>
2727
#include <vector>
2828

29-
#include "absl/algorithm/container.h"
30-
#include "absl/base/no_destructor.h"
3129
#include "absl/base/optimization.h"
3230
#include "absl/container/btree_set.h"
3331
#include "absl/container/flat_hash_map.h"
@@ -249,26 +247,10 @@ class FunctionBase {
249247
absl::Span<Next* const> next_values() const { return next_values_; }
250248

251249
const absl::btree_set<Next*, Node::NodeIdLessThan>& next_values(
252-
StateRead* state_read) const {
253-
if (!next_values_by_state_read_.contains(state_read)) {
254-
// This should be pretty rare. Basically this should only happen in the
255-
// short time before the non-updated state element is replaced. Just
256-
// returning what is actually there is nicer than crashing however. Do
257-
// check that this is not just some sort of weird corruption however.
258-
static const absl::NoDestructor<
259-
absl::btree_set<Next*, Node::NodeIdLessThan>>
260-
kEmptySet;
261-
CHECK(absl::c_none_of(nodes(),
262-
[state_read](Node* n) {
263-
return n->Is<Next>() &&
264-
n->As<Next>()->state_read() == state_read;
265-
}))
266-
<< "Invalid side table for next values. Missing " << state_read
267-
<< " in " << this;
268-
return *kEmptySet;
269-
}
270-
return next_values_by_state_read_.at(state_read);
271-
}
250+
StateRead* state_read) const;
251+
252+
const absl::btree_set<Next*, Node::NodeIdLessThan>& next_values(
253+
StateElement* state_element) const;
272254

273255
// Moves the given param to the given index in the parameter list.
274256
absl::Status MoveParamToIndex(Param* param, int64_t index);
@@ -497,7 +479,7 @@ class FunctionBase {
497479
// together.
498480
return !n->Is<Param>() && !n->Is<StateRead>();
499481
});
500-
other.next_values_by_state_read_.clear();
482+
other.next_values_by_state_element_.clear();
501483
other.next_values_.clear();
502484
}
503485

@@ -553,8 +535,9 @@ class FunctionBase {
553535

554536
std::vector<Param*> params_;
555537
std::vector<Next*> next_values_;
556-
absl::flat_hash_map<StateRead*, absl::btree_set<Next*, Node::NodeIdLessThan>>
557-
next_values_by_state_read_;
538+
absl::flat_hash_map<StateElement*,
539+
absl::btree_set<Next*, Node::NodeIdLessThan>>
540+
next_values_by_state_element_;
558541

559542
NameUniquer node_name_uniquer_ =
560543
NameUniquer(/*separator=*/"__", GetIrReservedWords());

xls/ir/node.cc

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -640,8 +640,14 @@ std::string Node::ToStringInternal(bool include_operand_types) const {
640640
}
641641
case Op::kNext: {
642642
const Next* next = As<Next>();
643-
args = {absl::StrFormat("param=%s", next->state_read()->GetName()),
644-
absl::StrFormat("value=%s", next->value()->GetName())};
643+
if (next->has_state_read_operand()) {
644+
args = {absl::StrFormat("param=%s", next->state_read()->GetName()),
645+
absl::StrFormat("value=%s", next->value()->GetName())};
646+
} else {
647+
args = {
648+
absl::StrFormat("state_element=%s", next->state_element()->name()),
649+
absl::StrFormat("value=%s", next->value()->GetName())};
650+
}
645651
std::optional<Node*> predicate = next->predicate();
646652
if (predicate.has_value()) {
647653
args.push_back(

xls/ir/nodes.cc

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -988,6 +988,7 @@ Next::Next(const SourceInfo& loc, StateElement* state_element, Node* value,
988988
: Node(Op::kNext, function->package()->GetTupleType({}), loc, name,
989989
function),
990990
state_element_(state_element),
991+
has_state_read_operand_(false),
991992
has_predicate_(predicate.has_value()),
992993
predicate_operand_index_(1),
993994
label_(std::move(label)) {
@@ -1002,6 +1003,8 @@ Next::Next(const SourceInfo& loc, Node* state_read, Node* value,
10021003
std::string_view name, FunctionBase* function)
10031004
: Node(Op::kNext, function->package()->GetTupleType({}), loc, name,
10041005
function),
1006+
state_element_(state_read->As<StateRead>()->state_element()),
1007+
has_state_read_operand_(true),
10051008
has_predicate_(predicate.has_value()),
10061009
predicate_operand_index_(2),
10071010
label_(std::move(label)) {
@@ -1466,17 +1469,17 @@ absl::StatusOr<Node*> StateRead::CloneInNewFunction(
14661469

14671470
absl::StatusOr<Node*> Next::CloneInNewFunction(
14681471
absl::Span<Node* const> new_operands, FunctionBase* new_function) const {
1469-
if (state_element_ != nullptr) {
1472+
if (has_state_read_operand_) {
14701473
return new_function->MakeNodeWithName<Next>(
1471-
loc(), state_element_, new_operands[0],
1472-
new_operands.size() > 1 ? std::make_optional(new_operands[1])
1474+
loc(), new_operands[0], new_operands[1],
1475+
new_operands.size() > 2 ? std::make_optional(new_operands[2])
14731476
: std::nullopt,
14741477
label(), GetNameView());
14751478
}
14761479
// TODO(meheff): Choose an appropriate name for the cloned node.
14771480
return new_function->MakeNodeWithName<Next>(
1478-
loc(), new_operands[0], new_operands[1],
1479-
new_operands.size() > 2 ? std::make_optional(new_operands[2])
1481+
loc(), state_element_, new_operands[0],
1482+
new_operands.size() > 1 ? std::make_optional(new_operands[1])
14801483
: std::nullopt,
14811484
label(), GetNameView());
14821485
}

xls/ir/nodes.h

Lines changed: 8 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -831,17 +831,12 @@ class Next final : public Node {
831831
absl::Span<Node* const> new_operands,
832832
FunctionBase* new_function) const final;
833833
Node* state_read() const {
834-
CHECK(state_element_ == nullptr) << "StateElement is set";
834+
CHECK(has_state_read_operand_)
835+
<< "Next node does not have a state_read operand";
835836
return operand(0);
836837
}
837838

838-
Node* value() const {
839-
if (state_element_ != nullptr) {
840-
return operand(0);
841-
} else {
842-
return operand(1);
843-
}
844-
}
839+
Node* value() const { return operand(has_state_read_operand_ ? 1 : 0); }
845840

846841
const std::optional<std::string>& label() const { return label_; }
847842

@@ -885,15 +880,13 @@ class Next final : public Node {
885880

886881
bool IsDefinitelyEqualTo(const Node* other) const final;
887882

888-
StateElement* state_element() const {
889-
if (state_element_ != nullptr) {
890-
return state_element_;
891-
}
892-
return state_read()->As<StateRead>()->state_element();
893-
}
883+
bool has_state_read_operand() const { return has_state_read_operand_; }
884+
885+
StateElement* state_element() const { return state_element_; }
894886

895887
private:
896-
StateElement* state_element_ = nullptr;
888+
StateElement* state_element_;
889+
bool has_state_read_operand_;
897890
bool has_predicate_;
898891
const int64_t predicate_operand_index_;
899892
std::optional<std::string> label_;

xls/ir/proc.cc

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -972,14 +972,20 @@ absl::StatusOr<StateRead*> Proc::TransformStateElement(
972972
nt.old_next->GetName()));
973973
to_replace.push_back({nt.old_next, nxt});
974974
// Identity-ify the old next.
975-
XLS_RETURN_IF_ERROR(nt.old_next->ReplaceOperandNumber(
976-
Next::kValueOperand, nt.old_next->state_read()));
975+
if (nt.old_next->has_state_read_operand()) {
976+
XLS_RETURN_IF_ERROR(nt.old_next->ReplaceOperandNumber(
977+
Next::kValueOperand, nt.old_next->state_read()));
978+
} else {
979+
XLS_RETURN_IF_ERROR(
980+
nt.old_next->ReplaceOperandNumber(/*operand_no=*/0, old_state_read));
981+
}
977982
}
978983
for (const auto& [old_n, new_n] : to_replace) {
979984
XLS_RETURN_IF_ERROR(old_n->ReplaceUsesWith(
980985
new_n,
981986
[&](Node* n) {
982-
if (n->Is<Next>() && n->As<Next>()->state_read() == old_n) {
987+
if (n->Is<Next>() && n->As<Next>()->has_state_read_operand() &&
988+
n->As<Next>()->state_read() == old_n) {
983989
return false;
984990
}
985991
return true;
@@ -993,7 +999,7 @@ absl::Status Proc::InternalRebuildSideTables() {
993999
XLS_RET_CHECK(params_.empty());
9941000
// Why is next-values in base but not elements?
9951001
next_values_.clear();
996-
next_values_by_state_read_.clear();
1002+
next_values_by_state_element_.clear();
9971003
state_reads_.clear();
9981004
for (Node* n : nodes()) {
9991005
if (n->Is<StateRead>()) {
@@ -1003,8 +1009,8 @@ absl::Status Proc::InternalRebuildSideTables() {
10031009
state_reads_[n->As<StateRead>()->state_element()] = n->As<StateRead>();
10041010
} else if (n->Is<Next>()) {
10051011
next_values_.push_back(n->As<Next>());
1006-
next_values_by_state_read_[n->As<Next>()->state_read()->As<StateRead>()]
1007-
.insert(n->As<Next>());
1012+
next_values_by_state_element_[n->As<Next>()->state_element()].insert(
1013+
n->As<Next>());
10081014
}
10091015
}
10101016
// TODO(allight): We should make it so we can recover channel/proc-inst things

xls/ir/proc_test.cc

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -205,6 +205,45 @@ TEST_F(ProcTest, StatelessProc) {
205205
EXPECT_EQ(proc->DumpIr(), "proc p() {\n}\n");
206206
}
207207

208+
TEST_F(ProcTest, NextValuesByStateElement) {
209+
auto p = CreatePackage();
210+
ProcBuilder pb("p", p.get());
211+
BValue state = pb.StateElement("st", Value(UBits(42, 32)));
212+
BValue add = pb.Add(pb.Literal(UBits(1, 32)), state);
213+
XLS_ASSERT_OK_AND_ASSIGN(Proc * proc, pb.Build({add}));
214+
215+
StateElement* st_elt = proc->GetStateElement(0);
216+
StateRead* st_read = proc->GetStateRead(st_elt);
217+
218+
EXPECT_THAT(proc->next_values(st_elt),
219+
ElementsAre(m::Next(m::StateRead("st"), m::Add())));
220+
EXPECT_THAT(proc->next_values(st_read),
221+
ElementsAre(m::Next(m::StateRead("st"), m::Add())));
222+
223+
// Add another next value for the same state element using the StateElement
224+
// constructor.
225+
XLS_ASSERT_OK_AND_ASSIGN(
226+
Node * literal_10,
227+
proc->MakeNode<Literal>(SourceInfo(), Value(UBits(10, 32))));
228+
XLS_ASSERT_OK_AND_ASSIGN(
229+
Next * next2,
230+
proc->MakeNode<Next>(SourceInfo(), st_elt, literal_10,
231+
/*predicate=*/std::nullopt, /*label=*/std::nullopt));
232+
233+
EXPECT_THAT(
234+
proc->next_values(st_elt),
235+
UnorderedElementsAre(m::Next(m::StateRead("st"), m::Add()), next2));
236+
EXPECT_THAT(
237+
proc->next_values(st_read),
238+
UnorderedElementsAre(m::Next(m::StateRead("st"), m::Add()), next2));
239+
240+
XLS_ASSERT_OK(proc->RemoveNode(next2));
241+
EXPECT_THAT(proc->next_values(st_elt),
242+
ElementsAre(m::Next(m::StateRead("st"), m::Add())));
243+
EXPECT_THAT(proc->next_values(st_read),
244+
ElementsAre(m::Next(m::StateRead("st"), m::Add())));
245+
}
246+
208247
TEST_F(ProcTest, RemoveStateThatStillHasUse) {
209248
// Don't call CreatePackage which creates a VerifiedPackage because we
210249
// intentionally create a malformed proc.

xls/ir/verify_node.cc

Lines changed: 26 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -978,29 +978,40 @@ class NodeChecker : public DfsVisitor {
978978
}
979979

980980
absl::Status HandleNext(Next* next) override {
981-
XLS_RETURN_IF_ERROR(ExpectOperandCountRange(next, 2, 3));
982-
if (!next->state_read()->Is<StateRead>()) {
983-
return absl::InternalError(
984-
absl::StrFormat("Next node %s expects a state read for param; is: %v",
985-
next->GetName(), *next->state_read()));
981+
if (next->has_state_read_operand()) {
982+
XLS_RETURN_IF_ERROR(ExpectOperandCountRange(next, 2, 3));
983+
if (!next->state_read()->Is<StateRead>()) {
984+
return absl::InternalError(absl::StrFormat(
985+
"Next node %s expects a state read for param; is: %v",
986+
next->GetName(), *next->state_read()));
987+
}
988+
} else {
989+
XLS_RETURN_IF_ERROR(ExpectOperandCountRange(next, 1, 2));
986990
}
991+
987992
if (next->predicate().has_value()) {
988-
XLS_RETURN_IF_ERROR(ExpectOperandHasBitsType(next, /*operand_no=*/2,
993+
XLS_ASSIGN_OR_RETURN(int64_t pred_idx, next->predicate_operand_number());
994+
XLS_RETURN_IF_ERROR(ExpectOperandHasBitsType(next, pred_idx,
989995
/*expected_bit_count=*/1));
990996
}
997+
991998
if (!next->function_base()->HasEffectiveProc()) {
992999
return absl::InternalError(absl::StrFormat(
993-
"Next node %s (for param %s) is not in a proc", next->GetName(),
994-
next->state_read()->As<StateRead>()->state_element()->name()));
1000+
"Next node %s (for state element %s) is not in a proc",
1001+
next->GetName(), next->state_element()->name()));
9951002
}
1003+
9961004
Proc* proc = next->function_base()->GetEffectiveProcOrDie();
997-
XLS_ASSIGN_OR_RETURN(
998-
int64_t index,
999-
proc->GetStateElementIndex(
1000-
next->state_read()->As<StateRead>()->state_element()));
1001-
XLS_RETURN_IF_ERROR(ExpectOperandHasType(next, /*operand_no=*/0,
1002-
proc->GetStateElementType(index)));
1003-
return ExpectOperandHasType(next, /*operand_no=*/1, // value is operand 1
1005+
XLS_ASSIGN_OR_RETURN(int64_t index,
1006+
proc->GetStateElementIndex(next->state_element()));
1007+
1008+
if (next->has_state_read_operand()) {
1009+
XLS_RETURN_IF_ERROR(ExpectOperandHasType(
1010+
next, /*operand_no=*/0, proc->GetStateElementType(index)));
1011+
}
1012+
1013+
int64_t value_idx = next->has_state_read_operand() ? 1 : 0;
1014+
return ExpectOperandHasType(next, value_idx,
10041015
proc->GetStateElementType(index));
10051016
}
10061017

0 commit comments

Comments
 (0)