Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
57 changes: 57 additions & 0 deletions xls/ir/function_builder.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1227,6 +1227,35 @@ LeafTypeTree<BValue> BuilderBase::MakeLeafTypeTree(BValue v) {
res->AsView(), [&](Node* n) -> BValue { return BValue(n, this); });
}

BValue BuilderBase::Next(class StateElement* state_element, BValue value,
std::optional<BValue> pred,
std::optional<std::string> label,
const SourceInfo& loc, std::string_view name) {
if (ErrorPending()) {
return BValue();
}
if (!value.GetType()->IsEqualTo(state_element->type())) {
return SetError(
absl::StrFormat(
"next value for state element '%s' must be of type %s; is: %s",
state_element->name(), state_element->type()->ToString(),
value.GetType()->ToString()),
loc);
}
if (pred.has_value() && (!pred->GetType()->IsBits() ||
pred->GetType()->AsBitsOrDie()->bit_count() != 1)) {
return SetError(absl::StrFormat("Predicate operand of next must be of bits "
"type of width 1; is: %s",
pred->GetType()->ToString()),
loc);
}
return AddNode<xls::Next>(
loc, /*state_element=*/state_element, /*value=*/value.node(),
/*predicate=*/pred.has_value() ? std::make_optional(pred->node())
: std::nullopt,
/*label=*/label, name);
}

FunctionBuilder::FunctionBuilder(std::string_view name, Package* package,
bool should_verify)
: BuilderBase(std::make_unique<Function>(std::string(name), package),
Expand Down Expand Up @@ -1506,6 +1535,34 @@ BValue ProcBuilder::StateElement(std::string_view name,
loc);
}

absl::StatusOr<class StateElement*> ProcBuilder::UnreadStateElement(
std::string_view name, const Value& initial_value, const SourceInfo& loc) {
if (ErrorPending()) {
return GetError();
}
return proc()->AppendUnreadStateElement(name, initial_value, loc);
}

BValue ProcBuilder::StateRead(class StateElement* state_element,
std::optional<BValue> predicate,
std::optional<std::string> label,
const SourceInfo& loc) {
if (ErrorPending()) {
return BValue();
}
absl::StatusOr<xls::StateRead*> state_read = proc()->AddStateRead(
state_element,
predicate.has_value() ? std::make_optional(predicate->node())
: std::nullopt,
label, loc);
if (!state_read.ok()) {
return SetError(absl::StrFormat("Unable to add state read: %s",
state_read.status().message()),
loc);
}
return CreateBValue(*state_read, loc);
}

BValue ProcBuilder::Param(std::string_view name, Type* type,
const SourceInfo& loc) {
if (ErrorPending()) {
Expand Down
16 changes: 16 additions & 0 deletions xls/ir/function_builder.h
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@ class Function;
class FunctionBase;
class Node;
class Proc;
class StateElement;
class Type;

// Represents a value for use in the function-definition building process,
Expand Down Expand Up @@ -695,6 +696,10 @@ class BuilderBase {
std::optional<BValue> pred = std::nullopt,
std::optional<std::string> label = std::nullopt,
const SourceInfo& loc = SourceInfo(), std::string_view name = "");
BValue Next(class StateElement* state_element, BValue value,
std::optional<BValue> pred = std::nullopt,
std::optional<std::string> label = std::nullopt,
const SourceInfo& loc = SourceInfo(), std::string_view name = "");

// Converts a BValue to a LeafTypeTree of BValues.
LeafTypeTree<BValue> MakeLeafTypeTree(BValue v);
Expand Down Expand Up @@ -913,6 +918,17 @@ class ProcBuilder : public BuilderBase {
/*read_predicate=*/std::nullopt, loc);
}

// Adds a state element to the proc without creating a state read.
absl::StatusOr<class StateElement*> UnreadStateElement(
std::string_view name, const Value& initial_value,
const SourceInfo& loc = SourceInfo());

// Adds a state read node for an existing state element.
BValue StateRead(class StateElement* state_element,
std::optional<BValue> predicate = std::nullopt,
std::optional<std::string> label = std::nullopt,
const SourceInfo& loc = SourceInfo());

// Overriden Param method is explicitly disabled (returns an error). Use
// StateElement method to add state elements.
BValue Param(std::string_view name, Type* type,
Expand Down
25 changes: 25 additions & 0 deletions xls/ir/function_builder_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -504,6 +504,31 @@ TEST(FunctionBuilderTest, BuildTwiceFails) {
HasSubstr("multiple times")));
}

TEST(FunctionBuilderTest, UnreadStateElementAndStateRead) {
Package p("p");
ProcBuilder b("unread_state_test", &p);

XLS_ASSERT_OK_AND_ASSIGN(
StateElement * state_element,
b.UnreadStateElement("my_state", Value(UBits(42, 32))));

BValue cond = b.Literal(UBits(1, 1));
BValue not_cond = b.Not(cond);

BValue read0 = b.StateRead(state_element, cond);
BValue read1 = b.StateRead(state_element, not_cond, "labeled_read");

BValue selected = b.Select(cond, {read1, read0});

b.Next(state_element, selected);
XLS_ASSERT_OK_AND_ASSIGN(Proc * proc, b.Build());

EXPECT_THAT(proc->StateElements(),
ElementsAre(m::StateElement("my_state", Value(UBits(42, 32)))));

EXPECT_EQ(proc->GetStateReadsByStateElement(state_element).size(), 2);
}

TEST(FunctionBuilderTest, SendAndReceive) {
Package p("p");
XLS_ASSERT_OK_AND_ASSIGN(
Expand Down
5 changes: 4 additions & 1 deletion xls/ir/node.cc
Original file line number Diff line number Diff line change
Expand Up @@ -641,7 +641,10 @@ std::string Node::ToStringInternal(bool include_operand_types) const {
}
case Op::kNext: {
const Next* next = As<Next>();
args = {absl::StrFormat("param=%s", next->state_read()->GetName()),
std::string param_name = next->has_state_read()
? next->state_read()->GetName()
: next->state_element()->name();
args = {absl::StrFormat("param=%s", param_name),
absl::StrFormat("value=%s", next->value()->GetName())};
std::optional<Node*> predicate = next->predicate();
if (predicate.has_value()) {
Expand Down
3 changes: 3 additions & 0 deletions xls/ir/nodes.h
Original file line number Diff line number Diff line change
Expand Up @@ -834,12 +834,15 @@ class Next final : public Node {
absl::StatusOr<Node*> CloneInNewFunction(
absl::Span<Node* const> new_operands,
FunctionBase* new_function) const final;

Node* state_read() const {
CHECK(state_read_ != nullptr)
<< "state_read() called on a Next node with only StateElement set";
return state_read_;
}

bool has_state_read() const { return state_read_ != nullptr; }

Node* value() const {
if (state_read_ == nullptr) {
return operand(0);
Expand Down
12 changes: 12 additions & 0 deletions xls/ir/proc.cc
Original file line number Diff line number Diff line change
Expand Up @@ -272,6 +272,18 @@ absl::StatusOr<StateElement*> Proc::InsertUnreadStateElement(
return state_element;
}

absl::StatusOr<StateRead*> Proc::AddStateRead(StateElement* state_element,
std::optional<Node*> predicate,
std::optional<std::string> label,
const SourceInfo& loc) {
XLS_ASSIGN_OR_RETURN(
StateRead * state_read,
MakeNodeWithName<StateRead>(loc, state_element, predicate, label,
state_element->name()));
state_reads_[state_element].push_back(state_read);
return state_read;
}

absl::StatusOr<StateRead*> Proc::InsertStateElement(
int64_t index, std::string_view requested_state_name,
const Value& init_value, std::optional<Node*> read_predicate,
Expand Down
7 changes: 7 additions & 0 deletions xls/ir/proc.h
Original file line number Diff line number Diff line change
Expand Up @@ -228,6 +228,13 @@ class Proc : public FunctionBase {
/*next_state=*/std::nullopt, loc);
}

// Adds a state read node for an existing state element.
absl::StatusOr<StateRead*> AddStateRead(
StateElement* state_element,
std::optional<Node*> predicate = std::nullopt,
std::optional<std::string> label = std::nullopt,
const SourceInfo& loc = SourceInfo());

// Add a new state element (at index) without any reads or nexts. These must
// be added separately before verification.
absl::StatusOr<StateElement*> InsertUnreadStateElement(
Expand Down
46 changes: 28 additions & 18 deletions xls/ir/verify_node.cc
Original file line number Diff line number Diff line change
Expand Up @@ -978,30 +978,40 @@ class NodeChecker : public DfsVisitor {
}

absl::Status HandleNext(Next* next) override {
XLS_RETURN_IF_ERROR(ExpectOperandCountRange(next, 2, 3));
if (!next->state_read()->Is<StateRead>()) {
return absl::InternalError(
absl::StrFormat("Next node %s expects a state read for param; is: %v",
next->GetName(), *next->state_read()));
if (next->has_state_read()) {
XLS_RETURN_IF_ERROR(ExpectOperandCountRange(next, 2, 3));
if (!next->state_read()->Is<StateRead>()) {
return absl::InternalError(absl::StrFormat(
"Next node %s expects a state read for param; is: %v",
next->GetName(), *next->state_read()));
}
} else {
XLS_RETURN_IF_ERROR(ExpectOperandCountRange(next, 1, 2));
}

if (next->predicate().has_value()) {
XLS_RETURN_IF_ERROR(ExpectOperandHasBitsType(next, /*operand_no=*/2,
/*expected_bit_count=*/1));
XLS_ASSIGN_OR_RETURN(int64_t pred_idx, next->predicate_operand_number());
XLS_RETURN_IF_ERROR(ExpectOperandHasBitsType(
next, /*operand_no=*/pred_idx, /*expected_bit_count=*/1));
}
if (!next->function_base()->HasEffectiveProc()) {
return absl::InternalError(absl::StrFormat(
"Next node %s (for param %s) is not in a proc", next->GetName(),
next->state_read()->As<StateRead>()->state_element()->name()));
return absl::InternalError(
absl::StrFormat("Next node %s (for param %s) is not in a proc",
next->GetName(), next->state_element()->name()));
}
Proc* proc = next->function_base()->GetEffectiveProcOrDie();
XLS_ASSIGN_OR_RETURN(
int64_t index,
proc->GetStateElementIndex(
next->state_read()->As<StateRead>()->state_element()));
XLS_RETURN_IF_ERROR(ExpectOperandHasType(next, /*operand_no=*/0,
proc->GetStateElementType(index)));
return ExpectOperandHasType(next, /*operand_no=*/1, // value is operand 1
proc->GetStateElementType(index));
XLS_ASSIGN_OR_RETURN(int64_t index,
proc->GetStateElementIndex(next->state_element()));

if (next->has_state_read()) {
XLS_RETURN_IF_ERROR(ExpectOperandHasType(
next, /*operand_no=*/0, proc->GetStateElementType(index)));
return ExpectOperandHasType(next, /*operand_no=*/1, // value is operand 1
proc->GetStateElementType(index));
} else {
return ExpectOperandHasType(next, /*operand_no=*/0, // value is operand 0
proc->GetStateElementType(index));
}
}

absl::Status HandleNewChannel(NewChannel* new_channel) override {
Expand Down
Loading