2929
3030#include " absl/algorithm/container.h"
3131#include " absl/base/casts.h"
32+ #include " absl/base/no_destructor.h"
3233#include " absl/container/btree_set.h"
3334#include " absl/container/flat_hash_map.h"
3435#include " absl/container/flat_hash_set.h"
@@ -245,7 +246,8 @@ void FunctionBase::TakeOwnershipOfNode(std::unique_ptr<Node>&& node) {
245246 FunctionBase* old_owner = node->function_base ();
246247
247248 if (node->Is <StateRead>()) {
248- old_owner->next_values_by_state_read_ .erase (node->As <StateRead>());
249+ old_owner->next_values_by_state_element_ .erase (
250+ node->As <StateRead>()->state_element ());
249251 }
250252
251253 old_owner->node_iterators_ .erase (node.get ());
@@ -442,6 +444,22 @@ absl::StatusOr<Node*> FunctionBase::GetNode(
442444 absl::StrFormat (" GetNode(%s) failed." , standard_node_name));
443445}
444446
447+ const absl::btree_set<Next*, Node::NodeIdLessThan>& FunctionBase::next_values (
448+ StateRead* state_read) const {
449+ return next_values (state_read->state_element ());
450+ }
451+
452+ const absl::btree_set<Next*, Node::NodeIdLessThan>& FunctionBase::next_values (
453+ StateElement* state_element) const {
454+ if (!next_values_by_state_element_.contains (state_element)) {
455+ static const absl::NoDestructor<
456+ absl::btree_set<Next*, Node::NodeIdLessThan>>
457+ kEmptySet ;
458+ return *kEmptySet ;
459+ }
460+ return next_values_by_state_element_.at (state_element);
461+ }
462+
445463absl::Status FunctionBase::RemoveNode (Node* node) {
446464 XLS_RET_CHECK (node->users ().empty ()) << node->GetName ();
447465 XLS_RET_CHECK (!HasImplicitUse (node)) << node->GetName ();
@@ -462,13 +480,12 @@ absl::Status FunctionBase::RemoveNode(Node* node) {
462480 params_.end ());
463481 }
464482 if (node->Is <StateRead>()) {
465- next_values_by_state_read_ .erase (node->As <StateRead>());
483+ next_values_by_state_element_ .erase (node->As <StateRead>()-> state_element ());
466484 }
467485 if (node->Is <Next>()) {
468486 Next* next = node->As <Next>();
469- if (next->state_read ()->Is <StateRead>()) { // Could've been replaced.
470- StateRead* state_read = next->state_read ()->As <StateRead>();
471- next_values_by_state_read_.at (state_read).erase (next);
487+ if (next_values_by_state_element_.contains (next->state_element ())) {
488+ next_values_by_state_element_.at (next->state_element ()).erase (next);
472489 }
473490 std::erase (next_values_, next);
474491 }
@@ -559,13 +576,12 @@ Node* FunctionBase::AddNodeInternal(std::unique_ptr<Node> node) {
559576 params_.push_back (node->As <Param>());
560577 }
561578 if (node->Is <StateRead>()) {
562- next_values_by_state_read_ [node->As <StateRead>()];
579+ next_values_by_state_element_ [node->As <StateRead>()-> state_element ()];
563580 }
564581 if (node->Is <Next>()) {
565582 Next* next = node->As <Next>();
566- StateRead* state_read = next->state_read ()->As <StateRead>();
567- next_values_.push_back (node->As <Next>());
568- next_values_by_state_read_[state_read].insert (next);
583+ next_values_.push_back (next);
584+ next_values_by_state_element_[next->state_element ()].insert (next);
569585 }
570586 Node* ptr = node.get ();
571587 node_iterators_[ptr] = nodes_.insert (nodes_.end (), std::move (node));
@@ -683,7 +699,7 @@ absl::Status FunctionBase::RebuildSideTables() {
683699 // TODO(allight): The fact that there is so much crap in the function_base
684700 // itself is a problem. Having next's and params' in the function base doesn't
685701 // make a ton of sense.
686- // NB Because of above the next-values/next_values_by_state_read_ and params
702+ // NB Because of above the next-values/next_values_by_state_element_ and
687703 // lists are updated in proc and function respectively.
688704 // NB We assume that node_iterators_ never gets invalidated.
689705 XLS_RETURN_IF_ERROR (InternalRebuildSideTables ());
0 commit comments