diff --git a/xls/dslx/bytecode/bytecode.cc b/xls/dslx/bytecode/bytecode.cc index 8a5a797e50..c6a35b3d75 100644 --- a/xls/dslx/bytecode/bytecode.cc +++ b/xls/dslx/bytecode/bytecode.cc @@ -290,6 +290,10 @@ std::string OpToString(Bytecode::Op op) { return "negate"; case Bytecode::Op::kOr: return "or"; + case Bytecode::Op::kPeek: + return "peek"; + case Bytecode::Op::kPeekNonBlocking: + return "peek_non_blocking"; case Bytecode::Op::kPop: return "pop"; case Bytecode::Op::kRange: @@ -502,6 +506,15 @@ DEF_UNARY_BUILDER(Swap); return Bytecode(span, Op::kMatchArm, std::move(item)); } +/* static */ Bytecode Bytecode::MakePeek(Span span, ChannelData channel_data) { + return Bytecode(span, Op::kPeek, std::move(channel_data)); +} + +/* static */ Bytecode Bytecode::MakePeekNonBlocking( + Span span, ChannelData channel_data) { + return Bytecode(span, Op::kPeekNonBlocking, std::move(channel_data)); +} + /* static */ Bytecode Bytecode::MakeRecv(Span span, ChannelData channel_data) { return Bytecode(span, Op::kRecv, std::move(channel_data)); } diff --git a/xls/dslx/bytecode/bytecode.h b/xls/dslx/bytecode/bytecode.h index 044950e467..ff49bd488e 100644 --- a/xls/dslx/bytecode/bytecode.h +++ b/xls/dslx/bytecode/bytecode.h @@ -136,6 +136,22 @@ class Bytecode { kPop, // Creates an array of values [TOS1, TOS0). kRange, + // Peeks a value from the channel at TOS1 if condition at TOS0 is fulfilled. + // If TOS0 is true, then + // peeks a value from the channel or "blocks" + // if empty: terminates execution at the opcode's PC. The interpreter can + // be resumed/retried if/when a value becomes available. + // else + // a tuple containing a tuple and zero value is pushed on the stack. + kPeek, + // Peeks a value off of the channel at TOS0, but does not block if empty. + // A tuple containing + // 0. A token. + // 1. Peeked value (or a zero value if the channel is empty). + // and + // 2. A valid flag (false if the channel is empty). + // is pushed on the stack. + kPeekNonBlocking, // Pulls TOS0 (a condition) and TOS1 (a channel). // If TOS0 is true, then // pulls a value off of the channel or "blocks" @@ -390,6 +406,8 @@ class Bytecode { static Bytecode MakeLoad(Span span, SlotIndex slot_index); static Bytecode MakeLogicalOr(Span span); static Bytecode MakeMatchArm(Span span, MatchArmItem item); + static Bytecode MakePeek(Span span, ChannelData channel_data); + static Bytecode MakePeekNonBlocking(Span span, ChannelData channel_data); static Bytecode MakePop(Span span); static Bytecode MakeRecv(Span span, ChannelData channel_data); static Bytecode MakeRecvNonBlocking(Span span, ChannelData channel_data); diff --git a/xls/dslx/bytecode/bytecode_emitter.cc b/xls/dslx/bytecode/bytecode_emitter.cc index 0ab0c2656c..22c5065524 100644 --- a/xls/dslx/bytecode/bytecode_emitter.cc +++ b/xls/dslx/bytecode/bytecode_emitter.cc @@ -538,6 +538,79 @@ absl::Status BytecodeEmitter::HandleCast(const Cast* node) { return absl::OkStatus(); } +absl::Status BytecodeEmitter::HandleBuiltinPeek(const Invocation* node) { + Expr* token = node->args()[0]; + Expr* channel = node->args()[1]; + + XLS_RETURN_IF_ERROR(token->AcceptExpr(this)); + XLS_RETURN_IF_ERROR(channel->AcceptExpr(this)); + // All receives need a predicate. Set to true for unconditional receive. + Add(Bytecode::MakeLiteral(node->span(), InterpValue::MakeUBits(1, 1))); + XLS_ASSIGN_OR_RETURN( + Bytecode::ChannelData channel_data, + CreateChannelData(channel, type_info_, options_.format_preference)); + // Default value which is unused because the predicate is always + // true. Required because the `peek` bytecode has a predicate and + // a default value operand. + XLS_ASSIGN_OR_RETURN(InterpValue default_value, + CreateZeroValueFromType(channel_data.payload_type())); + Add(Bytecode::MakeLiteral(node->span(), default_value)); + Add(Bytecode::MakePeek(node->span(), std::move(channel_data))); + return absl::OkStatus(); +} + +absl::Status BytecodeEmitter::HandleBuiltinPeekNonBlocking( + const Invocation* node) { + Expr* token = node->args()[0]; + Expr* channel = node->args()[1]; + Expr* default_value = node->args()[2]; + + XLS_RETURN_IF_ERROR(token->AcceptExpr(this)); + XLS_RETURN_IF_ERROR(channel->AcceptExpr(this)); + Add(Bytecode::MakeLiteral(node->span(), InterpValue::MakeUBits(1, 1))); + XLS_RETURN_IF_ERROR(default_value->AcceptExpr(this)); + XLS_ASSIGN_OR_RETURN( + Bytecode::ChannelData channel_data, + CreateChannelData(channel, type_info_, options_.format_preference)); + Add(Bytecode::MakePeekNonBlocking(node->span(), std::move(channel_data))); + return absl::OkStatus(); +} + +absl::Status BytecodeEmitter::HandleBuiltinPeekIf(const Invocation* node) { + Expr* token = node->args()[0]; + Expr* channel = node->args()[1]; + Expr* condition = node->args()[2]; + Expr* default_value = node->args()[3]; + + XLS_RETURN_IF_ERROR(token->AcceptExpr(this)); + XLS_RETURN_IF_ERROR(channel->AcceptExpr(this)); + XLS_RETURN_IF_ERROR(condition->AcceptExpr(this)); + XLS_RETURN_IF_ERROR(default_value->AcceptExpr(this)); + XLS_ASSIGN_OR_RETURN( + Bytecode::ChannelData channel_data, + CreateChannelData(channel, type_info_, options_.format_preference)); + Add(Bytecode::MakePeek(node->span(), std::move(channel_data))); + return absl::OkStatus(); +} + +absl::Status BytecodeEmitter::HandleBuiltinPeekIfNonBlocking( + const Invocation* node) { + Expr* token = node->args()[0]; + Expr* channel = node->args()[1]; + Expr* condition = node->args()[2]; + Expr* default_value = node->args()[3]; + + XLS_RETURN_IF_ERROR(token->AcceptExpr(this)); + XLS_RETURN_IF_ERROR(channel->AcceptExpr(this)); + XLS_RETURN_IF_ERROR(condition->AcceptExpr(this)); + XLS_RETURN_IF_ERROR(default_value->AcceptExpr(this)); + XLS_ASSIGN_OR_RETURN( + Bytecode::ChannelData channel_data, + CreateChannelData(channel, type_info_, options_.format_preference)); + Add(Bytecode::MakePeekNonBlocking(node->span(), std::move(channel_data))); + return absl::OkStatus(); +} + absl::Status BytecodeEmitter::HandleBuiltinRecv(const Invocation* node) { Expr* token = node->args()[0]; Expr* channel = node->args()[1]; @@ -1142,6 +1215,18 @@ absl::Status BytecodeEmitter::HandleInvocation(const Invocation* node) { if (name_ref->identifier() == "join") { return HandleBuiltinJoin(node); } + if (name_ref->identifier() == "peek") { + return HandleBuiltinPeek(node); + } + if (name_ref->identifier() == "peek_non_blocking") { + return HandleBuiltinPeekNonBlocking(node); + } + if (name_ref->identifier() == "peek_if") { + return HandleBuiltinPeekIf(node); + } + if (name_ref->identifier() == "peek_if_non_blocking") { + return HandleBuiltinPeekIfNonBlocking(node); + } if (name_ref->identifier() == "recv") { return HandleBuiltinRecv(node); } diff --git a/xls/dslx/bytecode/bytecode_emitter.h b/xls/dslx/bytecode/bytecode_emitter.h index a93c266a7e..2a4108ae6f 100644 --- a/xls/dslx/bytecode/bytecode_emitter.h +++ b/xls/dslx/bytecode/bytecode_emitter.h @@ -167,6 +167,10 @@ class BytecodeEmitter : public ExprVisitor { absl::Status HandleBuiltinDecode(const Invocation* node); absl::Status HandleBuiltinElementCount(const Invocation* node); absl::Status HandleBuiltinJoin(const Invocation* node); + absl::Status HandleBuiltinPeek(const Invocation* node); + absl::Status HandleBuiltinPeekIf(const Invocation* node); + absl::Status HandleBuiltinPeekIfNonBlocking(const Invocation* node); + absl::Status HandleBuiltinPeekNonBlocking(const Invocation* node); absl::Status HandleBuiltinRecv(const Invocation* node); absl::Status HandleBuiltinRecvIf(const Invocation* node); absl::Status HandleBuiltinRecvIfNonBlocking(const Invocation* node); diff --git a/xls/dslx/bytecode/bytecode_interpreter.cc b/xls/dslx/bytecode/bytecode_interpreter.cc index c680081d0a..3112223d14 100644 --- a/xls/dslx/bytecode/bytecode_interpreter.cc +++ b/xls/dslx/bytecode/bytecode_interpreter.cc @@ -585,6 +585,14 @@ absl::Status BytecodeInterpreter::EvalNextInstruction() { XLS_RETURN_IF_ERROR(EvalOr(bytecode)); break; } + case Bytecode::Op::kPeek: { + XLS_RETURN_IF_ERROR(EvalPeek(bytecode)); + break; + } + case Bytecode::Op::kPeekNonBlocking: { + XLS_RETURN_IF_ERROR(EvalPeekNonBlocking(bytecode)); + break; + } case Bytecode::Op::kPop: { XLS_RETURN_IF_ERROR(EvalPop(bytecode)); break; @@ -1326,6 +1334,83 @@ absl::Status BytecodeInterpreter::EvalOr(const Bytecode& bytecode) { }); } +absl::Status BytecodeInterpreter::EvalPeek(const Bytecode& bytecode) { + XLS_ASSIGN_OR_RETURN(InterpValue default_value, Pop()); + XLS_ASSIGN_OR_RETURN(InterpValue condition, Pop()); + XLS_ASSIGN_OR_RETURN(InterpValue channel_value, Pop()); + XLS_ASSIGN_OR_RETURN(auto channel_reference, + channel_value.GetChannelReference()); + XLS_ASSIGN_OR_RETURN(const Bytecode::ChannelData* channel_data, + bytecode.channel_data()); + + XLS_RET_CHECK(channel_reference.GetChannelId().has_value()); + int64_t channel_id = channel_reference.GetChannelId().value(); + XLS_RET_CHECK(channel_manager_.has_value()); + InterpValueChannel& channel = (*channel_manager_)->GetChannel(channel_id); + + if (condition.IsTrue()) { + if (channel.IsEmpty()) { + stack_.Push(channel_value); + stack_.Push(condition); + stack_.Push(default_value); + blocked_channel_info_ = BlockedChannelInfo{ + .name = FormatChannelNameForTracing(*channel_data), + .span = bytecode.source_span(), + }; + return absl::UnavailableError("Channel is empty."); + } + + XLS_ASSIGN_OR_RETURN(InterpValue token, Pop()); + InterpValue value = channel.Peek(); + if (options_.trace_channels() && events_.has_value()) { + (*events_)->AddTraceChannelMessage( + import_data_->file_table(), bytecode.source_span(), + FormatChannelNameForTracing(*channel_data), value, + ChannelDirection::kIn, channel_data->value_fmt_desc()); + } + stack_.Push(InterpValue::MakeTuple({token, std::move(value)})); + } else { + XLS_ASSIGN_OR_RETURN(InterpValue token, Pop()); + stack_.Push(InterpValue::MakeTuple({token, default_value})); + } + + return absl::OkStatus(); +} + +absl::Status BytecodeInterpreter::EvalPeekNonBlocking( + const Bytecode& bytecode) { + XLS_ASSIGN_OR_RETURN(InterpValue default_value, Pop()); + XLS_ASSIGN_OR_RETURN(InterpValue condition, Pop()); + XLS_ASSIGN_OR_RETURN(InterpValue channel_value, Pop()); + XLS_ASSIGN_OR_RETURN(InterpValue::ChannelReference channel_reference, + channel_value.GetChannelReference()); + XLS_ASSIGN_OR_RETURN(InterpValue token, Pop()); + + XLS_RET_CHECK(channel_reference.GetChannelId().has_value()); + int64_t channel_id = channel_reference.GetChannelId().value(); + XLS_RET_CHECK(channel_manager_.has_value()); + InterpValueChannel& channel = (*channel_manager_)->GetChannel(channel_id); + + XLS_ASSIGN_OR_RETURN(const Bytecode::ChannelData* channel_data, + bytecode.channel_data()); + if (condition.IsTrue() && !channel.IsEmpty()) { + InterpValue value = channel.Peek(); + if (options_.trace_channels() && events_.has_value()) { + (*events_)->AddTraceChannelMessage( + import_data_->file_table(), bytecode.source_span(), + FormatChannelNameForTracing(*channel_data), value, + ChannelDirection::kIn, channel_data->value_fmt_desc()); + } + stack_.Push(InterpValue::MakeTuple( + {token, std::move(value), InterpValue::MakeBool(true)})); + } else { + stack_.Push(InterpValue::MakeTuple( + {token, default_value, InterpValue::MakeBool(false)})); + } + + return absl::OkStatus(); +} + absl::Status BytecodeInterpreter::EvalPop(const Bytecode& bytecode) { return Pop().status(); } @@ -1751,6 +1836,10 @@ absl::Status BytecodeInterpreter::RunBuiltinFn(const Bytecode& bytecode, case Builtin::kToken: case Builtin::kSend: case Builtin::kSendIf: + case Builtin::kPeek: + case Builtin::kPeekNonBlocking: + case Builtin::kPeekIf: + case Builtin::kPeekIfNonBlocking: case Builtin::kRecv: case Builtin::kRecvIf: case Builtin::kRecvNonBlocking: diff --git a/xls/dslx/bytecode/bytecode_interpreter.h b/xls/dslx/bytecode/bytecode_interpreter.h index 8148c9f39e..df6724c459 100644 --- a/xls/dslx/bytecode/bytecode_interpreter.h +++ b/xls/dslx/bytecode/bytecode_interpreter.h @@ -146,6 +146,7 @@ class InterpValueChannel { queue_.pop_front(); return result; } + InterpValue Peek() { return queue_.front(); } void Write(InterpValue v) { queue_.push_back(std::move(v)); } private: @@ -271,6 +272,8 @@ class BytecodeInterpreter { absl::Status EvalNe(const Bytecode& bytecode); absl::Status EvalNegate(const Bytecode& bytecode); absl::Status EvalOr(const Bytecode& bytecode); + absl::Status EvalPeek(const Bytecode& bytecode); + absl::Status EvalPeekNonBlocking(const Bytecode& bytecode); absl::Status EvalPop(const Bytecode& bytecode); absl::Status EvalRange(const Bytecode& bytecode); absl::Status EvalRecv(const Bytecode& bytecode); diff --git a/xls/dslx/dslx_builtins.h b/xls/dslx/dslx_builtins.h index 52c9b09d21..efbb69313e 100644 --- a/xls/dslx/dslx_builtins.h +++ b/xls/dslx/dslx_builtins.h @@ -69,6 +69,10 @@ namespace xls::dslx { X("token", kToken) \ /* send/recv routines */ \ /* keep-sorted start */ \ + X("peek", kPeek) \ + X("peek_if", kPeekIf) \ + X("peek_if_nonblocking", kPeekIfNonBlocking) \ + X("peek_nonblocking", kPeekNonBlocking) \ X("recv", kRecv) \ X("recv_if", kRecvIf) \ X("recv_if_nonblocking", kRecvIfNonBlocking) \ diff --git a/xls/dslx/frontend/ast.cc b/xls/dslx/frontend/ast.cc index cb2467ead1..477dc45608 100644 --- a/xls/dslx/frontend/ast.cc +++ b/xls/dslx/frontend/ast.cc @@ -248,6 +248,14 @@ std::string_view AstNodeKindToString(AstNodeKind kind) { return "index"; case AstNodeKind::kRange: return "range"; + case AstNodeKind::kPeek: + return "peek"; + case AstNodeKind::kPeekNonBlocking: + return "peek-non-blocking"; + case AstNodeKind::kPeekIf: + return "peek-if"; + case AstNodeKind::kPeekIfNonBlocking: + return "peek-if-non-blocking"; case AstNodeKind::kRecv: return "receive"; case AstNodeKind::kRecvNonBlocking: diff --git a/xls/dslx/frontend/ast_node.h b/xls/dslx/frontend/ast_node.h index 168bf09814..027a043c35 100644 --- a/xls/dslx/frontend/ast_node.h +++ b/xls/dslx/frontend/ast_node.h @@ -75,6 +75,10 @@ enum class AstNodeKind : uint8_t { kProcMember, kQuickCheck, kRange, + kPeek, + kPeekIf, + kPeekIfNonBlocking, + kPeekNonBlocking, kRecv, kRecvIf, kRecvIfNonBlocking, diff --git a/xls/dslx/frontend/builtin_stubs.x b/xls/dslx/frontend/builtin_stubs.x index 676dae8de3..ffff924dae 100644 --- a/xls/dslx/frontend/builtin_stubs.x +++ b/xls/dslx/frontend/builtin_stubs.x @@ -75,6 +75,14 @@ fn one_hot_sel(x: uN[N], y: xN[S][M][N]) -> xN[S][M]; fn or_reduce(x: uN[N]) -> u1; +fn peek_if_non_blocking(tok: token, channel: chan in, predicate: bool, value: T) -> (token, T, bool); + +fn peek_if(tok: token, channel: chan in, predicate: bool, value: T) -> (token, T); + +fn peek_non_blocking(tok: token, channel: chan in, value: T) -> (token, T, bool); + +fn peek(tok: token, channel: chan in) -> (token, T); + fn priority_sel(x: uN[N], y: xN[S][M][N], z: xN[S][M]) -> xN[S][M]; fn read(source: State) -> T; diff --git a/xls/dslx/frontend/builtins_metadata.cc b/xls/dslx/frontend/builtins_metadata.cc index 87d153ba60..459ee5ed91 100644 --- a/xls/dslx/frontend/builtins_metadata.cc +++ b/xls/dslx/frontend/builtins_metadata.cc @@ -111,6 +111,10 @@ const absl::flat_hash_map& GetParametricBuiltins() { // proc scope. {"labeled_read", {}}, {"labeled_write", {}}, + {"peek", {}}, + {"peek_if", {}}, + {"peek_non_blocking", {}}, + {"peek_if_non_blocking", {}}, {"send", {}}, {"send_if", {}}, {"recv", {}}, diff --git a/xls/dslx/frontend/parser_test.cc b/xls/dslx/frontend/parser_test.cc index ec39bab1a6..2ea4247a26 100644 --- a/xls/dslx/frontend/parser_test.cc +++ b/xls/dslx/frontend/parser_test.cc @@ -1460,6 +1460,36 @@ proc producer { RoundTrip(std::string(kModule)); } +TEST_F(ParserTest, ParsePeek) { + constexpr std::string_view kModule = R"(struct Packet { + id: u32, + data: u32, +} +proc PeekPacketFiller { + req_r: chan in; + resp_s: chan out; + config(req_r: chan in, resp_s: chan out) { + (req_r, resp_s) + } + init { + u32:0 + } + next(current_id: u32) { + let (tok, packet) = peek(join(), req_r); + let (packet, next_state) = if current_id < packet.id { + (Packet { id: current_id, data: current_id }, current_id + u32:1) + } else { + let (tok, packet) = recv(tok, req_r); + (packet, packet.id + u32:1) + }; + let tok = send(tok, resp_s, packet); + next_state + } +})"; + + RoundTrip(std::string(kModule)); +} + TEST_F(ParserTest, ParseSendIfAndRecvIf) { constexpr std::string_view kModule = R"(proc producer { c: chan in; diff --git a/xls/dslx/ir_convert/function_converter.cc b/xls/dslx/ir_convert/function_converter.cc index f3c3e1e354..b826e4cded 100644 --- a/xls/dslx/ir_convert/function_converter.cc +++ b/xls/dslx/ir_convert/function_converter.cc @@ -2753,6 +2753,18 @@ absl::Status FunctionConverter::HandleInvocation(const Invocation* node) { if (called_name == "map") { return HandleMap(node).status(); } + if (called_name == "peek") { + return HandleBuiltinPeek(node); + } + if (called_name == "peek_if") { + return HandleBuiltinPeekIf(node); + } + if (called_name == "peek_if_non_blocking") { + return HandleBuiltinPeekIfNonBlocking(node); + } + if (called_name == "peek_non_blocking") { + return HandleBuiltinPeekNonBlocking(node); + } if (called_name == "read") { return HandleBuiltinRead(node); } @@ -3058,6 +3070,182 @@ absl::Status FunctionConverter::HandleBuiltinSendIf(const Invocation* node) { return absl::OkStatus(); } +absl::Status FunctionConverter::HandleBuiltinPeek(const Invocation* node) { + XLS_RETURN_IF_ERROR(ValidateProcState("peek", node)); + ProcBuilder* builder_ptr = + dynamic_cast(function_builder_.get()); + + Expr* token = node->args()[0]; + Expr* channel = node->args()[1]; + + XLS_RETURN_IF_ERROR(Visit(token)); + XLS_RETURN_IF_ERROR(Visit(channel)); + IrValue channel_ir_value = node_to_ir_[channel]; + XLS_RETURN_IF_ERROR(CheckValueIsChannel(channel_ir_value)); + + XLS_ASSIGN_OR_RETURN(ReceiveChannelRef channel_ref, + IrValueToReceiveChannelRef(channel_ir_value)); + XLS_ASSIGN_OR_RETURN(BValue token_value, Use(token)); + BValue value; + if (implicit_token_data_.has_value()) { + XLS_RET_CHECK(implicit_token_data_->create_control_predicate != nullptr); + value = builder_ptr->PeekIf( + channel_ref, token_value, + implicit_token_data_->create_control_predicate()); + } else { + value = builder_ptr->Peek(channel_ref, token_value); + } + BValue new_token_value = builder_ptr->TupleIndex(value, 0); + tokens_.push_back(new_token_value); + node_to_ir_[node] = value; + return absl::OkStatus(); +} + +absl::Status FunctionConverter::HandleBuiltinPeekNonBlocking( + const Invocation* node) { + XLS_RETURN_IF_ERROR(ValidateProcState("peek_non_blocking", node)); + ProcBuilder* builder_ptr = + dynamic_cast(function_builder_.get()); + + Expr* token = node->args()[0]; + Expr* channel = node->args()[1]; + Expr* default_expr = node->args()[2]; + + XLS_RETURN_IF_ERROR(Visit(token)); + XLS_RETURN_IF_ERROR(Visit(channel)); + IrValue channel_ir_value = node_to_ir_[channel]; + XLS_RETURN_IF_ERROR(CheckValueIsChannel(channel_ir_value)); + XLS_RETURN_IF_ERROR(Visit(default_expr)); + + XLS_ASSIGN_OR_RETURN(BValue token_value, Use(token)); + XLS_ASSIGN_OR_RETURN(ReceiveChannelRef channel_ref, + IrValueToReceiveChannelRef(channel_ir_value)); + XLS_ASSIGN_OR_RETURN(BValue default_value, Use(default_expr)); + + BValue recv; + if (implicit_token_data_.has_value()) { + XLS_RET_CHECK(implicit_token_data_->create_control_predicate != nullptr); + recv = builder_ptr->PeekIfNonBlocking( + channel_ref, token_value, + implicit_token_data_->create_control_predicate()); + } else { + recv = builder_ptr->PeekNonBlocking(channel_ref, token_value); + } + BValue new_token_value = builder_ptr->TupleIndex(recv, 0); + BValue received_value = builder_ptr->TupleIndex(recv, 1); + BValue receive_activated = builder_ptr->TupleIndex(recv, 2); + + // IR non-blocking receive has a default value of zero. Mux in the + // default_value specified in DSLX. + BValue value = + builder_ptr->Select(receive_activated, {default_value, received_value}); + BValue repackaged_result = + builder_ptr->Tuple({new_token_value, value, receive_activated}); + + tokens_.push_back(new_token_value); + node_to_ir_[node] = repackaged_result; + + return absl::OkStatus(); +} + +absl::Status FunctionConverter::HandleBuiltinPeekIf(const Invocation* node) { + XLS_RETURN_IF_ERROR(ValidateProcState("peek_if", node)); + ProcBuilder* builder_ptr = + dynamic_cast(function_builder_.get()); + + Expr* token = node->args()[0]; + Expr* channel = node->args()[1]; + Expr* predicate = node->args()[2]; + Expr* default_expr = node->args()[3]; + + XLS_RETURN_IF_ERROR(Visit(token)); + XLS_RETURN_IF_ERROR(Visit(channel)); + IrValue channel_ir_value = node_to_ir_[channel]; + XLS_RETURN_IF_ERROR(CheckValueIsChannel(channel_ir_value)); + XLS_RETURN_IF_ERROR(Visit(predicate)); + XLS_RETURN_IF_ERROR(Visit(default_expr)); + + XLS_ASSIGN_OR_RETURN(BValue token_value, Use(token)); + XLS_ASSIGN_OR_RETURN(ReceiveChannelRef channel_ref, + IrValueToReceiveChannelRef(channel_ir_value)); + XLS_ASSIGN_OR_RETURN(BValue predicate_value, Use(predicate)); + XLS_ASSIGN_OR_RETURN(BValue default_value, Use(default_expr)); + + BValue recv; + if (implicit_token_data_.has_value()) { + XLS_RET_CHECK(implicit_token_data_->create_control_predicate != nullptr); + recv = builder_ptr->PeekIf( + channel_ref, token_value, + builder_ptr->And({implicit_token_data_->create_control_predicate(), + predicate_value})); + } else { + recv = builder_ptr->PeekIf(channel_ref, token_value, predicate_value); + } + BValue new_token_value = builder_ptr->TupleIndex(recv, 0); + BValue received_value = builder_ptr->TupleIndex(recv, 1); + + // IR receive-if has a default value of zero. Mux in the + // default_value specified in DSLX. + BValue value = + builder_ptr->Select(predicate_value, {default_value, received_value}); + BValue repackaged_result = builder_ptr->Tuple({new_token_value, value}); + + tokens_.push_back(new_token_value); + node_to_ir_[node] = repackaged_result; + return absl::OkStatus(); +} + +absl::Status FunctionConverter::HandleBuiltinPeekIfNonBlocking( + const Invocation* node) { + XLS_RETURN_IF_ERROR(ValidateProcState("peek_if_non_blocking", node)); + ProcBuilder* builder_ptr = + dynamic_cast(function_builder_.get()); + + Expr* token = node->args()[0]; + Expr* channel = node->args()[1]; + Expr* predicate = node->args()[2]; + Expr* default_expr = node->args()[3]; + + XLS_RETURN_IF_ERROR(Visit(token)); + XLS_RETURN_IF_ERROR(Visit(channel)); + IrValue channel_ir_value = node_to_ir_[channel]; + XLS_RETURN_IF_ERROR(CheckValueIsChannel(channel_ir_value)); + XLS_RETURN_IF_ERROR(Visit(predicate)); + XLS_RETURN_IF_ERROR(Visit(default_expr)); + + XLS_ASSIGN_OR_RETURN(BValue token_value, Use(token)); + XLS_ASSIGN_OR_RETURN(ReceiveChannelRef channel_ref, + IrValueToReceiveChannelRef(channel_ir_value)); + XLS_ASSIGN_OR_RETURN(BValue predicate_value, Use(predicate)); + XLS_ASSIGN_OR_RETURN(BValue default_value, Use(default_expr)); + + BValue recv; + if (implicit_token_data_.has_value()) { + XLS_RET_CHECK(implicit_token_data_->create_control_predicate != nullptr); + recv = builder_ptr->PeekIfNonBlocking( + channel_ref, token_value, + builder_ptr->And({implicit_token_data_->create_control_predicate(), + predicate_value})); + } else { + recv = builder_ptr->PeekIfNonBlocking(channel_ref, token_value, + predicate_value); + } + BValue new_token_value = builder_ptr->TupleIndex(recv, 0); + BValue received_value = builder_ptr->TupleIndex(recv, 1); + BValue receive_activated = builder_ptr->TupleIndex(recv, 2); + + // IR non-blocking receive-if has a default value of zero. Mux in the + // default_value specified in DSLX. + BValue value = + builder_ptr->Select(receive_activated, {default_value, received_value}); + BValue repackaged_result = + builder_ptr->Tuple({new_token_value, value, receive_activated}); + + tokens_.push_back(new_token_value); + node_to_ir_[node] = repackaged_result; + return absl::OkStatus(); +} + absl::Status FunctionConverter::HandleRange(const Range* node) { // Range must be constexpr, since it implicitly defines a structural type // (array of N elements). diff --git a/xls/dslx/ir_convert/function_converter.h b/xls/dslx/ir_convert/function_converter.h index cb95beb4f1..380329f835 100644 --- a/xls/dslx/ir_convert/function_converter.h +++ b/xls/dslx/ir_convert/function_converter.h @@ -497,6 +497,10 @@ class FunctionConverter { absl::Status HandleBuiltinOneHot(const Invocation* node); absl::Status HandleBuiltinOneHotSel(const Invocation* node); absl::Status HandleBuiltinOrReduce(const Invocation* node); + absl::Status HandleBuiltinPeek(const Invocation* node); + absl::Status HandleBuiltinPeekIf(const Invocation* node); + absl::Status HandleBuiltinPeekIfNonBlocking(const Invocation* node); + absl::Status HandleBuiltinPeekNonBlocking(const Invocation* node); absl::Status HandleBuiltinPrioritySel(const Invocation* node); absl::Status HandleBuiltinRead(const Invocation* node); absl::Status HandleBuiltinRecv(const Invocation* node); diff --git a/xls/dslx/ir_convert/ir_converter_test.cc b/xls/dslx/ir_convert/ir_converter_test.cc index e68c5a3e20..dbb58519c2 100644 --- a/xls/dslx/ir_convert/ir_converter_test.cc +++ b/xls/dslx/ir_convert/ir_converter_test.cc @@ -8312,5 +8312,37 @@ proc main { ExpectIr(converted); } +TEST_F(IrConverterTest, PeekChannelOperation) { + constexpr std::string_view program = R"( +struct Packet { + id: u32, + data: u32, +} +proc main { + req_r: chan in; + resp_s: chan out; + + init { } + + config( + req_r: chan in, + resp_s: chan out + ) { + (req_r, resp_s) + } + + next(state: ()) { + let (tok, packet) = peek(join(), req_r); + let handle_packet = packet.id > u32:4; + let (tok, packet) = recv_if(tok, req_r, handle_packet, zero!()); + send(tok, resp_s, packet); + } +} +)"; + XLS_ASSERT_OK_AND_ASSIGN(std::string converted, + ConvertModuleForTest(program)); + ExpectIr(converted); +} + } // namespace } // namespace xls::dslx diff --git a/xls/dslx/ir_convert/testdata/ir_converter_test_PeekChannelOperation.ir b/xls/dslx/ir_convert/testdata/ir_converter_test_PeekChannelOperation.ir new file mode 100644 index 0000000000..09d132647a --- /dev/null +++ b/xls/dslx/ir_convert/testdata/ir_converter_test_PeekChannelOperation.ir @@ -0,0 +1,32 @@ +package test_module + +file_number 0 "test_module.x" + +top proc __test_module__main_0_next<_req_r: (bits[32], bits[32]) in, _resp_s: (bits[32], bits[32]) out>(__state: (), init={()}) { + chan_interface _req_r(direction=receive, kind=streaming, strictness=proven_mutually_exclusive, flow_control=ready_valid, flop_kind=none) + chan_interface _resp_s(direction=send, kind=streaming, strictness=proven_mutually_exclusive, flow_control=ready_valid, flop_kind=none) + after_all.5: token = after_all(id=5) + literal.3: bits[1] = literal(value=1, id=3) + peek.6: (token, (bits[32], bits[32])) = peek(after_all.5, predicate=literal.3, channel=_req_r, id=6) + packet: (bits[32], bits[32]) = tuple_index(peek.6, index=1, id=9) + packet_id: bits[32] = tuple_index(packet, index=0, id=10) + literal.11: bits[32] = literal(value=4, id=11) + handle_packet: bits[1] = ugt(packet_id, literal.11, id=12) + tok: token = tuple_index(peek.6, index=0, id=8) + and.14: bits[1] = and(literal.3, handle_packet, id=14) + receive.15: (token, (bits[32], bits[32])) = receive(tok, predicate=and.14, channel=_req_r, id=15) + literal.13: (bits[32], bits[32]) = literal(value=(0, 0), id=13) + tuple_index.17: (bits[32], bits[32]) = tuple_index(receive.15, index=1, id=17) + tuple_index.16: token = tuple_index(receive.15, index=0, id=16) + sel.18: (bits[32], bits[32]) = sel(handle_packet, cases=[literal.13, tuple_index.17], id=18) + tuple.19: (token, (bits[32], bits[32])) = tuple(tuple_index.16, sel.18, id=19) + tok__1: token = tuple_index(tuple.19, index=0, id=20) + packet__1: (bits[32], bits[32]) = tuple_index(tuple.19, index=1, id=21) + __state: () = state_read(state_element=__state, id=2) + tuple.23: () = tuple(id=23) + __token: token = literal(value=token, id=1) + tuple.4: () = tuple(id=4) + tuple_index.7: token = tuple_index(peek.6, index=0, id=7) + send.22: token = send(tok__1, packet__1, predicate=literal.3, channel=_resp_s, id=22) + next_value.24: () = next_value(param=__state, value=tuple.23, id=24) +} diff --git a/xls/dslx/type_system/type_info.proto b/xls/dslx/type_system/type_info.proto index 6f3c696908..984192b159 100644 --- a/xls/dslx/type_system/type_info.proto +++ b/xls/dslx/type_system/type_info.proto @@ -97,6 +97,10 @@ enum AstNodeKindProto { AST_NODE_KIND_PROC_ALIAS = 72; AST_NODE_KIND_TRAIT = 73; AST_NODE_KIND_ATTRIBUTE = 74; + AST_NODE_KIND_PEEK = 75; + AST_NODE_KIND_PEEK_IF = 76; + AST_NODE_KIND_PEEK_NON_BLOCKING = 77; + AST_NODE_KIND_PEEK_IF_NON_BLOCKING = 78; } message BitsValueProto { diff --git a/xls/dslx/type_system/type_info_to_proto.cc b/xls/dslx/type_system/type_info_to_proto.cc index 72c3426fc4..4a40c22e1e 100644 --- a/xls/dslx/type_system/type_info_to_proto.cc +++ b/xls/dslx/type_system/type_info_to_proto.cc @@ -108,6 +108,14 @@ AstNodeKindProto ToProto(AstNodeKind kind) { return AST_NODE_KIND_SPLAT_STRUCT_INSTANCE; case AstNodeKind::kIndex: return AST_NODE_KIND_INDEX; + case AstNodeKind::kPeek: + return AST_NODE_KIND_PEEK; + case AstNodeKind::kPeekIf: + return AST_NODE_KIND_PEEK_IF; + case AstNodeKind::kPeekIfNonBlocking: + return AST_NODE_KIND_PEEK_IF_NON_BLOCKING; + case AstNodeKind::kPeekNonBlocking: + return AST_NODE_KIND_PEEK_NON_BLOCKING; case AstNodeKind::kRange: return AST_NODE_KIND_RANGE; case AstNodeKind::kRecv: @@ -722,6 +730,14 @@ absl::StatusOr FromProto(AstNodeKindProto p) { return AstNodeKind::kSplatStructInstance; case AST_NODE_KIND_INDEX: return AstNodeKind::kIndex; + case AST_NODE_KIND_PEEK: + return AstNodeKind::kPeek; + case AST_NODE_KIND_PEEK_NON_BLOCKING: + return AstNodeKind::kPeekNonBlocking; + case AST_NODE_KIND_PEEK_IF: + return AstNodeKind::kPeekIf; + case AST_NODE_KIND_PEEK_IF_NON_BLOCKING: + return AstNodeKind::kPeekIfNonBlocking; case AST_NODE_KIND_RANGE: return AstNodeKind::kRange; case AST_NODE_KIND_RECV: diff --git a/xls/examples/BUILD b/xls/examples/BUILD index 2bd417435f..8d4010d040 100644 --- a/xls/examples/BUILD +++ b/xls/examples/BUILD @@ -1480,3 +1480,47 @@ xls_dslx_verilog( library = "const_for_dslx", verilog_file = "const_for.sv", ) + +xls_dslx_library( + name = "content_based_arbiter_dslx", + srcs = ["content_based_arbiter.x"], +) + +xls_dslx_test( + name = "content_based_arbiter_dslx_test", + srcs = ["content_based_arbiter.x"], +) + +xls_dslx_library( + name = "blocking_peek_dslx", + srcs = [ + "blocking_peek.x", + ], +) + +xls_dslx_test( + name = "blocking_peek_dslx_test", + srcs = ["blocking_peek.x"], + dslx_test_args = {"compare": "jit"}, +) + +xls_dslx_library( + name = "peek_dslx", + srcs = ["peek.x"], +) + +xls_dslx_test( + name = "peek_dslx_test", + size = "small", + srcs = ["peek.x"], +) + +xls_dslx_ir( + name = "peek_ir", + dslx_top = "Peek", + ir_conv_args = { + "lower_to_proc_scoped_channels": "true" + }, + ir_file = "peek.ir", + library = ":peek_dslx", +) diff --git a/xls/examples/blocking_peek.x b/xls/examples/blocking_peek.x new file mode 100644 index 0000000000..acb4c31ccd --- /dev/null +++ b/xls/examples/blocking_peek.x @@ -0,0 +1,127 @@ +// Copyright 2026 The XLS Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#![feature(type_inference_v2)] + +import std; + +struct Packet { + id: u8, + data: uN[1024] +} + +// This proc presents a possible use case for the blocking variant of peek(), +// which may be useful when the decision about which packet to send must be +// made only after data becomes available on the input channel. +// When the input is available, the proc may decide either to forward/use +// that packet or to produce a different output. + +// The proc below implements a packet filler. We assume that packet IDs +// are provided in monotonically increasing order, although some values +// may be missing. When this happens, the filler proc generates artificial +// packets so that all IDs in the sequence are produced. +// +// Implementation that uses recv is provided after the commented-out +// code, so it is possible to verify the expected behaviour of this proc. + +struct PacketFillerState { + packet: Packet, // this is large + packet_valid: bool, + current_id: u8 +} + +struct PeekPacketFillerState { + current_id: u8 +} + +proc PeekPacketFiller { + type State = PeekPacketFillerState; + + req_r: chan in; + resp_s: chan out; + + init { zero!() } + + config( + req_r: chan in, + resp_s: chan out + ) { + (req_r, resp_s) + } + + next(state: State) { + let (tok, packet) = peek(join(), req_r); + let (packet, next_state) = if state.current_id < packet.id { + // we need to generate an artificial packet + ( + Packet { id: state.current_id, ..zero!() }, + State { current_id: state.current_id + u8:1} + ) + } else { + // we can use a packet from the input + let (tok, packet) = recv(tok, req_r); + ( + packet, + State { current_id: packet.id + u8:1 } + ) + }; + + let tok = send(tok, resp_s, packet); + next_state + } +} + +#[test_proc] +proc PacketFillerTest { + terminator: chan out; + + req_s: chan out; + resp_r: chan in; + + init { } + + config(terminator: chan out) { + let (req_s, req_r) = chan("req"); + let (resp_s, resp_r) = chan("resp"); + + spawn PeekPacketFiller(req_r, resp_s); + (terminator, req_s, resp_r) + } + + next(state: ()) { + let tok = join(); + let tok = send(tok, req_s, Packet { id: 3, data: uN[1024]:0xA }); + let tok = send(tok, req_s, Packet { id: 5, data: uN[1024]:0xB }); + let tok = send(tok, req_s, Packet { id: 7, data: uN[1024]:0xC }); + + let (tok, data) = recv(tok, resp_r); + assert_eq(data, Packet { id: 0 , data: uN[1024]:0 }); + let (tok, data) = recv(tok, resp_r); + assert_eq(data, Packet { id: 1 , data: uN[1024]:0 }); + let (tok, data) = recv(tok, resp_r); + assert_eq(data, Packet { id: 2 , data: uN[1024]:0 }); + let (tok, data) = recv(tok, resp_r); + assert_eq(data, Packet { id: 3 , data: uN[1024]:0xA }); + let (tok, data) = recv(tok, resp_r); + assert_eq(data, Packet { id: 4 , data: uN[1024]:0 }); + let (tok, data) = recv(tok, resp_r); + assert_eq(data, Packet { id: 5 , data: uN[1024]:0xB }); + let (tok, data) = recv(tok, resp_r); + assert_eq(data, Packet { id: 6 , data: uN[1024]:0 }); + let (tok, data) = recv(tok, resp_r); + assert_eq(data, Packet { id: 7 , data: uN[1024]:0xC }); + + send(tok, terminator, true); + } +} diff --git a/xls/examples/content_based_arbiter.x b/xls/examples/content_based_arbiter.x new file mode 100644 index 0000000000..41d3d5553d --- /dev/null +++ b/xls/examples/content_based_arbiter.x @@ -0,0 +1,227 @@ +// Copyright 2026 The XLS Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#![feature(type_inference_v2)] + +import std; + +type Priority = u8; + +struct Packet { + priority: Priority, // 0 is the highest prio + data: u32, +} + +struct PeekContentBasedArbiterState { + enabled: bool, +} + +fn highest_priority_packet(storage: Packet[N], storage_valid: bool[N]) -> (Packet, u32) { + let (init_packet, init_idx) = for (i, (packet, idx)): (u32, (Packet, u32)) in 0..N { + let i = N - i - 1; + if storage_valid[i] { (storage[i], i) } else { (packet, idx) } + }((zero!(), u32:0)); + + for (i, (packet, idx)): (u32, (Packet, u32)) in u32:0..N { + let has_priority = storage_valid[i] && (storage[i].priority < packet.priority); + if has_priority { (storage[i], i) } else { (packet, idx) } + }((init_packet, init_idx)) +} + +pub proc PeekContentBasedArbiter { + type State = PeekContentBasedArbiterState; + + enable_r: chan in; + enable_comp_s: chan<()> out; + inputs_r: chan[N] in; + output_s: chan out; + + config( + enable_r: chan in, + enable_comp_s: chan<()> out, + inputs_r: chan[N] in, + output_s: chan out + ) { + (enable_r, enable_comp_s, inputs_r, output_s) + } + + init { zero!() } + + next(state: State) { + let (recv_en_tok, enabled, enabled_valid) = + recv_non_blocking(join(), enable_r, state.enabled); + let send_en_tok = send_if(recv_en_tok, enable_comp_s, enabled_valid, ()); + + if state.enabled { + let (storage, storage_valid, peek_in_tok) = + const for (i, (storage, storage_valid, prev_tok)) in u32:0..N { + let (tok, data, data_valid) = + peek_non_blocking(join(), inputs_r[i], zero!()); + if data_valid { + ( + update(storage, i, data), + update(storage_valid, i, data_valid), + join(prev_tok, tok), + ) + } else { + (storage, storage_valid, join(prev_tok, tok)) + } + }((Packet[N]:[zero!(), ...], bool[N]:[false, ...], join())); + + let (packet, idx) = highest_priority_packet(storage, storage_valid); + + let has_value = or_reduce(std::convert_to_bits_msb0(storage_valid)); + let recv_in_tok = unroll_for!(i, tok): (u32, token) in u32:0..N { + let (tok, _) = + recv_if(peek_in_tok, inputs_r[i], has_value && (idx == i), zero!()); + tok + }(peek_in_tok); + let sent_out_tok = send_if(recv_in_tok, output_s, has_value, packet); + }; + State { enabled } + } +} + +proc PeekContentBasedArbiterInst { + config( + enable_r: chan in, + enable_comp_s: chan<()> out, + inputs_r: chan[3] in, + output_s: chan out + ) { + spawn PeekContentBasedArbiter<3>(enable_r, enable_comp_s, inputs_r, output_s); + } + + init { } + next(state: ()) { } +} + +#[test_proc] +proc PeekContentBasedArbiterTest { + terminator: chan out; + enable_s: chan out; + enable_comp_r: chan<()> in; + inputs_s: chan[3] out; + output_r: chan in; + + config(terminator: chan out) { + let (enable_s, enable_r) = chan("enable"); + let (enable_comp_s, enable_comp_r) = chan<()>("enable_comp"); + let (inputs_s, inputs_r) = chan[3]("inputs"); + let (output_s, output_r) = chan("output"); + + spawn PeekContentBasedArbiter<3>(enable_r, enable_comp_s, inputs_r, output_s); + (terminator, enable_s, enable_comp_r, inputs_s, output_r) + } + + init { } + + next(state: ()) { + let tok = join(); + + let tok = send(tok, enable_s, true); + let (tok, _) = recv(tok, enable_comp_r); + + // Send input data + + let tok = send(tok, inputs_s[0], Packet { priority: 3, data: 2 }); + let tok = send(tok, inputs_s[0], Packet { priority: 5, data: 5 }); + let tok = send(tok, inputs_s[0], Packet { priority: 7, data: 6 }); + + let tok = send(tok, inputs_s[1], Packet { priority: 2, data: 0 }); + let tok = send(tok, inputs_s[1], Packet { priority: 0, data: 1 }); + let tok = send(tok, inputs_s[1], Packet { priority: 3, data: 3 }); + + let tok = send(tok, inputs_s[2], Packet { priority: 4, data: 4 }); + let tok = send(tok, inputs_s[2], Packet { priority: 7, data: 7 }); + let tok = send(tok, inputs_s[2], Packet { priority: 0, data: 8 }); + + // Enable arbiter + // Collect output + + // I0: (p: 3), (p: 5), (p: 7) + // I1: (p: 2), (p: 0), (p: 3) + // I2: (p: 4), (p: 7), (p: 0) + + let (tok, data) = recv(tok, output_r); + trace_fmt!("Received: {}", data); + assert_eq(data, Packet { priority: 2, data: 0 }); + + // I0: (p: 3), (p: 5), (p: 7) + // I1: (p: 0), (p: 3) + // I2: (p: 4), (p: 7), (p: 0) + + let (tok, data) = recv(tok, output_r); + trace_fmt!("Received: {}", data); + assert_eq(data, Packet { priority: 0, data: 1 }); + + // I0: (p: 3), (p: 5), (p: 7) + // I1: (p: 3) + // I2: (p: 4), (p: 7), (p: 0) + + let (tok, data) = recv(tok, output_r); + trace_fmt!("Received: {}", data); + assert_eq(data, Packet { priority: 3, data: 2 }); + + // I0: (p: 5), (p: 7) + // I1: (p: 3) + // I2: (p: 4), (p: 7), (p: 0) + + let (tok, data) = recv(tok, output_r); + trace_fmt!("Received: {}", data); + assert_eq(data, Packet { priority: 3, data: 3 }); + + // I0: (p: 5), (p: 7) + // I1: + // I2: (p: 4), (p: 7), (p: 0) + + let (tok, data) = recv(tok, output_r); + trace_fmt!("Received: {}", data); + assert_eq(data, Packet { priority: 4, data: 4 }); + + // I0: (p: 5), (p: 7) + // I1: + // I2: (p: 7), (p: 0) + + let (tok, data) = recv(tok, output_r); + trace_fmt!("Received: {}", data); + assert_eq(data, Packet { priority: 5, data: 5 }); + + // I0: (p: 7) + // I1: + // I2: (p: 7), (p: 0) + + let (tok, data) = recv(tok, output_r); + trace_fmt!("Received: {}", data); + assert_eq(data, Packet { priority: 7, data: 6 }); + + // I0: + // I1: + // I2: (p: 7), (p: 0) + + let (tok, data) = recv(tok, output_r); + trace_fmt!("Received: {}", data); + assert_eq(data, Packet { priority: 7, data: 7 }); + + // I0: + // I1: + // I2: (p: 0) + + let (tok, data) = recv(tok, output_r); + trace_fmt!("Received: {}", data); + assert_eq(data, Packet { priority: 0, data: 8 }); + + send(tok, terminator, true); + } +} diff --git a/xls/examples/peek.x b/xls/examples/peek.x new file mode 100644 index 0000000000..5b9a81238a --- /dev/null +++ b/xls/examples/peek.x @@ -0,0 +1,264 @@ +// Copyright 2026 The XLS Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Minimal peek examples for development and testing purposes. + +#![feature(type_inference_v2)] + +struct Packet { + id: u32, + data: u32, +} + +proc Peek { + req_r: chan in; + resp_s: chan out; + + init { } + + config( + req_r: chan in, + resp_s: chan out + ) { + (req_r, resp_s) + } + + next(state: ()) { + let (tok, packet) = peek(join(), req_r); + let handle_packet = packet.id > u32:4; + let (tok, packet) = recv_if(tok, req_r, handle_packet, zero!()); + send(tok, resp_s, packet); + } +} + +#[test_proc] +proc PeekTest { + req_s: chan out; + resp_r: chan in; + terminator: chan out; + + config(terminator: chan out) { + let (req_s, req_r) = chan("req"); + let (resp_s, resp_r) = chan("resp"); + spawn Peek(req_r, resp_s); + + (req_s, resp_r, terminator) + } + + init { } + + next(_: ()) { + // First packet + const FIRST_PACKET_ID = u32:5; + const FIRST_PACKET_DATA = u32:4; + let tok = send(join(), req_s, Packet{ + id: FIRST_PACKET_ID, + data: FIRST_PACKET_DATA + }); + let (tok, packet) = recv(tok, resp_r); + trace_fmt!("Received packet: {}", packet); + + // Second packet + const SECOND_PACKET_ID = u32:3; + const SECOND_PACKET_DATA = u32:16; + let tok = send(tok, req_s, Packet{ + id: SECOND_PACKET_ID, + data: SECOND_PACKET_DATA + }); + let (tok, packet) = recv(tok, resp_r); + trace_fmt!("Received packet: {}", packet); + + send(tok, terminator, true); + } +} + +proc PeekIf { + req_r: chan in; + resp_s: chan out; + enable_r: chan in; + + init { false } + + config( + req_r: chan in, + resp_s: chan out, + enable_r: chan in, + ) { + (req_r, resp_s, enable_r) + } + + next(state: bool) { + let (tok, enabled, valid) = recv_non_blocking(join(), enable_r, state); + let state = if valid { enabled } else { state }; + let (tok, packet) = peek_if(join(), req_r, state, zero!()); + let handle_packet = packet.id > u32:4; + let (tok, packet) = recv_if(tok, req_r, state && handle_packet, zero!()); + send_if(tok, resp_s, state && handle_packet, packet); + state + } +} + +#[test_proc] +proc PeekIfTest { + req_s: chan out; + resp_r: chan in; + enable_s: chan out; + terminator: chan out; + + config(terminator: chan out) { + let (req_s, req_r) = chan("req"); + let (resp_s, resp_r) = chan("resp"); + let (enable_s, enable_r) = chan("enable"); + spawn PeekIf(req_r, resp_s, enable_r); + + (req_s, resp_r, enable_s, terminator) + } + + init { } + + next(_: ()) { + // First packet + const FIRST_PACKET_ID = u32:5; + const FIRST_PACKET_DATA = u32:4; + let tok = send(join(), req_s, Packet{ + id: FIRST_PACKET_ID, + data: FIRST_PACKET_DATA + }); + let tok = send(tok, enable_s, true); + let (tok, packet) = recv(tok, resp_r); + trace_fmt!("Received packet: {}", packet); + + send(tok, terminator, true); + } +} + +proc PeekNonBlocking { + req_r: chan in; + resp_s: chan out; + + init { } + + config( + req_r: chan in, + resp_s: chan out + ) { + (req_r, resp_s) + } + + next(state: ()) { + let (tok, packet, valid) = peek_non_blocking(join(), req_r, zero!()); + let (tok, _) = recv_if(tok, req_r, valid, zero!()); + send_if(tok, resp_s, valid, packet); + } +} + +#[test_proc] +proc PeekNonBlockingTest { + req_s: chan out; + resp_r: chan in; + terminator: chan out; + + config(terminator: chan out) { + let (req_s, req_r) = chan("req"); + let (resp_s, resp_r) = chan("resp"); + spawn PeekNonBlocking(req_r, resp_s); + + (req_s, resp_r, terminator) + } + + init { } + + next(_: ()) { + // First packet + const FIRST_PACKET_ID = u32:5; + const FIRST_PACKET_DATA = u32:4; + let tok = send(join(), req_s, Packet{ + id: FIRST_PACKET_ID, + data: FIRST_PACKET_DATA + }); + let (tok, packet) = recv(tok, resp_r); + trace_fmt!("Received packet: {}", packet); + + // Second packet + const SECOND_PACKET_ID = u32:3; + const SECOND_PACKET_DATA = u32:16; + let tok = send(tok, req_s, Packet{ + id: SECOND_PACKET_ID, + data: SECOND_PACKET_DATA + }); + let (tok, packet) = recv(tok, resp_r); + trace_fmt!("Received packet: {}", packet); + + send(tok, terminator, true); + } +} + +proc PeekIfNonBlocking { + req_r: chan in; + resp_s: chan out; + enable_r: chan in; + + init { false } + + config( + req_r: chan in, + resp_s: chan out, + enable_r: chan in, + ) { + (req_r, resp_s, enable_r) + } + + next(state: bool) { + let (tok, enabled, valid) = recv_non_blocking(join(), enable_r, state); + let state = if valid { enabled } else { state }; + let (tok, packet, packet_valid) = peek_if_non_blocking(join(), req_r, state, zero!()); + let (tok, packet) = recv_if(tok, req_r, state && packet_valid, zero!()); + send_if(tok, resp_s, state && packet_valid, packet); + state + } +} + +#[test_proc] +proc PeekIfNonBlockingTest { + req_s: chan out; + resp_r: chan in; + enable_s: chan out; + terminator: chan out; + + config(terminator: chan out) { + let (req_s, req_r) = chan("req"); + let (resp_s, resp_r) = chan("resp"); + let (enable_s, enable_r) = chan("enable"); + spawn PeekIfNonBlocking(req_r, resp_s, enable_r); + + (req_s, resp_r, enable_s, terminator) + } + + init { } + + next(_: ()) { + // First packet + const FIRST_PACKET_ID = u32:5; + const FIRST_PACKET_DATA = u32:4; + let tok = send(join(), req_s, Packet{ + id: FIRST_PACKET_ID, + data: FIRST_PACKET_DATA + }); + let tok = send(tok, enable_s, true); + let (tok, packet) = recv(tok, resp_r); + trace_fmt!("Received packet: {}", packet); + + send(tok, terminator, true); + } +} diff --git a/xls/interpreter/block_interpreter.cc b/xls/interpreter/block_interpreter.cc index 770e90ef24..6861721b37 100644 --- a/xls/interpreter/block_interpreter.cc +++ b/xls/interpreter/block_interpreter.cc @@ -767,6 +767,10 @@ class ElaboratedBlockInterpreter final : public ElaboratedBlockDfsVisitor { XLS_RETURN_IF_ERROR(SetInstance(instance)); return current_interpreter_->HandleNext(next); } + absl::Status HandlePeek(Peek* peek, BlockInstance* instance) override { + XLS_RETURN_IF_ERROR(SetInstance(instance)); + return current_interpreter_->HandlePeek(peek); + } absl::Status HandleReceive(Receive* receive, BlockInstance* instance) override { XLS_RETURN_IF_ERROR(SetInstance(instance)); diff --git a/xls/interpreter/channel_queue.cc b/xls/interpreter/channel_queue.cc index 22ee074361..1f5f80e2e1 100644 --- a/xls/interpreter/channel_queue.cc +++ b/xls/interpreter/channel_queue.cc @@ -90,6 +90,24 @@ void ChannelQueue::WriteInternal(const Value& value) { queue_.push_back(value); } +std::optional ChannelQueue::Peek() { + absl::MutexLock lock(mutex_); + if (generator_.has_value()) { + // Write/ReadInternal are virtual and may have other side-effects so rather + // than directly returning the generated value, write then read it. + std::optional generated_value = (*generator_)(); + if (generated_value.has_value()) { + WriteInternal(generated_value.value()); + } + } + std::optional value = PeekInternal(); + VLOG(4) << absl::StreamFormat( + "Peeking data from channel instance %s: %s", + channel_instance()->ToString(), + value.has_value() ? value->ToString() : "(none)"); + return value; +} + std::optional ChannelQueue::Read() { absl::MutexLock lock(mutex_); if (generator_.has_value()) { @@ -111,6 +129,13 @@ std::optional ChannelQueue::Read() { int64_t ChannelQueue::GetSizeInternal() const { return queue_.size(); } +std::optional ChannelQueue::PeekInternal() { + if (queue_.empty()) { + return std::nullopt; + } + return queue_.front(); +} + std::optional ChannelQueue::ReadInternal() { if (queue_.empty()) { return std::nullopt; diff --git a/xls/interpreter/channel_queue.h b/xls/interpreter/channel_queue.h index 714e1994ac..58138d2eef 100644 --- a/xls/interpreter/channel_queue.h +++ b/xls/interpreter/channel_queue.h @@ -87,6 +87,10 @@ class ChannelQueue { // Writes the given value on to the channel. absl::Status Write(const Value& value); + // Peeks and returns a value from the channel without dropping it + // from the queue. Returns an std::nullopt if the channel is empty. + std::optional Peek(); + // Reads and returns a value from the channel. Returns an std::nullopt if // the channel is empty. std::optional Read(); @@ -118,6 +122,8 @@ class ChannelQueue { virtual int64_t GetSizeInternal() const ABSL_SHARED_LOCKS_REQUIRED(mutex_); virtual void WriteInternal(const Value& value) ABSL_SHARED_LOCKS_REQUIRED(mutex_); + virtual std::optional PeekInternal() + ABSL_SHARED_LOCKS_REQUIRED(mutex_); virtual std::optional ReadInternal() ABSL_SHARED_LOCKS_REQUIRED(mutex_); ChannelInstance* channel_instance_; diff --git a/xls/interpreter/ir_interpreter.cc b/xls/interpreter/ir_interpreter.cc index 7b6cca54f6..0f1d3a8ee2 100644 --- a/xls/interpreter/ir_interpreter.cc +++ b/xls/interpreter/ir_interpreter.cc @@ -777,6 +777,10 @@ absl::Status IrInterpreter::HandleSendChannelEnd(SendChannelEnd* sce) { "SendChannelEnd value not implemented in IrInterpreter"); } +absl::Status IrInterpreter::HandlePeek(Peek* peek) { + return absl::UnimplementedError("Peek not implemented in IrInterpreter"); +} + absl::Status IrInterpreter::HandleReverse(UnOp* reverse) { return SetBitsResult(reverse, bits_ops::Reverse(ResolveAsBits(reverse->operand(0)))); diff --git a/xls/interpreter/ir_interpreter.h b/xls/interpreter/ir_interpreter.h index ce38a9cd52..ccc2222e66 100644 --- a/xls/interpreter/ir_interpreter.h +++ b/xls/interpreter/ir_interpreter.h @@ -152,6 +152,7 @@ class IrInterpreter : public DfsVisitor { absl::Status HandleNewChannel(NewChannel* new_channel) override; absl::Status HandleRecvChannelEnd(RecvChannelEnd* rce) override; absl::Status HandleSendChannelEnd(SendChannelEnd* sce) override; + absl::Status HandlePeek(Peek* peek) override; absl::Status HandleReceive(Receive* receive) override; absl::Status HandleRegisterRead(RegisterRead* reg_read) override; absl::Status HandleRegisterWrite(RegisterWrite* reg_write) override; diff --git a/xls/interpreter/proc_interpreter.cc b/xls/interpreter/proc_interpreter.cc index 7d7227d79b..9eb371f4fe 100644 --- a/xls/interpreter/proc_interpreter.cc +++ b/xls/interpreter/proc_interpreter.cc @@ -142,6 +142,40 @@ class ProcIrInterpreter : public IrInterpreter { queue_manager_(queue_manager), active_next_values_(active_next_values) {} + absl::Status HandlePeek(Peek* peek) override { + XLS_ASSIGN_OR_RETURN(ChannelQueue * queue, + GetChannelQueue(peek->channel_name())); + + if (peek->predicate().has_value()) { + const Bits& pred = ResolveAsBits(peek->predicate().value()); + if (pred.IsZero()) { + // If the predicate is false, nothing is read from the channel. + // Rather the result of the peek is the zero values of the + // respective type. + return SetValueResult(peek, ZeroOfType(peek->GetType())); + } + } + + std::optional value = queue->Peek(); + if (!value.has_value()) { + if (peek->is_blocking()) { + // Record the channel this peek instruction is blocked on and exit. + blocked_channel_instance_ = queue->channel_instance(); + return absl::OkStatus(); + } + // A non-blocking peek returns a zero data value with a zero valid bit + // if the queue is empty. + return SetValueResult(peek, ZeroOfType(peek->GetType())); + } + + if (peek->is_blocking()) { + return SetValueResult(peek, Value::Tuple({Value::Token(), *value})); + } + + return SetValueResult( + peek, Value::Tuple({Value::Token(), *value, Value(UBits(1, 1))})); + } + absl::Status HandleReceive(Receive* receive) override { XLS_ASSIGN_OR_RETURN(ChannelQueue * queue, GetChannelQueue(receive->channel_name())); diff --git a/xls/ir/block_elaboration.cc b/xls/ir/block_elaboration.cc index ff9859edaf..ae93216070 100644 --- a/xls/ir/block_elaboration.cc +++ b/xls/ir/block_elaboration.cc @@ -742,6 +742,9 @@ absl::Status ElaboratedNode::VisitSingleNode( case Op::kReceive: return visitor.HandleReceive(absl::down_cast(node), instance); + case Op::kPeek: + return visitor.HandlePeek(absl::down_cast(node), instance); + case Op::kSend: return visitor.HandleSend(absl::down_cast(node), instance); diff --git a/xls/ir/dfs_visitor.cc b/xls/ir/dfs_visitor.cc index 9fa7ee3d77..c99757ea49 100644 --- a/xls/ir/dfs_visitor.cc +++ b/xls/ir/dfs_visitor.cc @@ -108,6 +108,10 @@ absl::Status DfsVisitorWithDefault::HandleCover(Cover* cover) { return DefaultHandler(cover); } +absl::Status DfsVisitorWithDefault::HandlePeek(Peek* peek) { + return DefaultHandler(peek); +} + absl::Status DfsVisitorWithDefault::HandleReceive(Receive* receive) { return DefaultHandler(receive); } diff --git a/xls/ir/dfs_visitor.h b/xls/ir/dfs_visitor.h index 9012dd26b6..b5c5f38364 100644 --- a/xls/ir/dfs_visitor.h +++ b/xls/ir/dfs_visitor.h @@ -87,6 +87,7 @@ class DfsVisitor { virtual absl::Status HandleNewChannel(NewChannel* new_channel) = 0; virtual absl::Status HandleRecvChannelEnd(RecvChannelEnd* rce) = 0; virtual absl::Status HandleSendChannelEnd(SendChannelEnd* sce) = 0; + virtual absl::Status HandlePeek(Peek* peek) = 0; virtual absl::Status HandleReceive(Receive* receive) = 0; virtual absl::Status HandleRegisterRead(RegisterRead* reg_read) = 0; virtual absl::Status HandleRegisterWrite(RegisterWrite* reg_write) = 0; @@ -214,6 +215,7 @@ class DfsVisitorWithDefault : public DfsVisitor { absl::Status HandleNewChannel(NewChannel* new_channel) override; absl::Status HandleRecvChannelEnd(RecvChannelEnd* rce) override; absl::Status HandleSendChannelEnd(SendChannelEnd* sce) override; + absl::Status HandlePeek(Peek* peek) override; absl::Status HandleReceive(Receive* receive) override; absl::Status HandleRegisterRead(RegisterRead* reg_read) override; absl::Status HandleRegisterWrite(RegisterWrite* reg_write) override; diff --git a/xls/ir/elaborated_block_dfs_visitor.cc b/xls/ir/elaborated_block_dfs_visitor.cc index d198744dee..ae299d763e 100644 --- a/xls/ir/elaborated_block_dfs_visitor.cc +++ b/xls/ir/elaborated_block_dfs_visitor.cc @@ -133,6 +133,11 @@ absl::Status ElaboratedBlockDfsVisitorWithDefault::HandleCover( return DefaultHandler(ElaboratedNode{.node = cover, .instance = instance}); } +absl::Status ElaboratedBlockDfsVisitorWithDefault::HandlePeek( + Peek* peek, BlockInstance* instance) { + return DefaultHandler(ElaboratedNode{.node = peek, .instance = instance}); +} + absl::Status ElaboratedBlockDfsVisitorWithDefault::HandleReceive( Receive* receive, BlockInstance* instance) { return DefaultHandler(ElaboratedNode{.node = receive, .instance = instance}); diff --git a/xls/ir/elaborated_block_dfs_visitor.h b/xls/ir/elaborated_block_dfs_visitor.h index b3ef22bfdc..8d2a9937b6 100644 --- a/xls/ir/elaborated_block_dfs_visitor.h +++ b/xls/ir/elaborated_block_dfs_visitor.h @@ -106,6 +106,8 @@ class ElaboratedBlockDfsVisitor { virtual absl::Status HandleStateRead(StateRead* state_read, BlockInstance* instance) = 0; virtual absl::Status HandleNext(Next* next, BlockInstance* instance) = 0; + virtual absl::Status HandlePeek(Peek* peek, + BlockInstance* instance) = 0; virtual absl::Status HandleReceive(Receive* receive, BlockInstance* instance) = 0; virtual absl::Status HandleRegisterRead(RegisterRead* reg_read, @@ -264,6 +266,8 @@ class ElaboratedBlockDfsVisitorWithDefault : public ElaboratedBlockDfsVisitor { absl::Status HandleStateRead(StateRead* state_read, BlockInstance* instance) override; absl::Status HandleNext(Next* next, BlockInstance* instance) override; + absl::Status HandlePeek(Peek* peek, + BlockInstance* instance) override; absl::Status HandleReceive(Receive* receive, BlockInstance* instance) override; absl::Status HandleRegisterRead(RegisterRead* reg_read, diff --git a/xls/ir/function_builder.cc b/xls/ir/function_builder.cc index 3f7fd3d64d..28ddb485ea 100644 --- a/xls/ir/function_builder.cc +++ b/xls/ir/function_builder.cc @@ -1046,6 +1046,92 @@ BValue BuilderBase::Concat(absl::Span operands, return AddNode(loc, node_operands, name); } +BValue BuilderBase::Peek(ReceiveChannelRef channel, BValue token, + const SourceInfo& loc, std::string_view name) { + if (ErrorPending()) { + return BValue(); + } + if (!token.GetType()->IsToken()) { + return SetError( + absl::StrFormat( + "Token operand of peek must be of token type; is: %s", + token.GetType()->ToString()), + loc); + } + return AddNode(loc, token.node(), /*predicate=*/std::nullopt, + ChannelRefName(channel), /*is_blocking=*/true, + ChannelRefType(channel), name); +} + +BValue BuilderBase::PeekIf(ReceiveChannelRef channel, BValue token, BValue pred, + const SourceInfo& loc, std::string_view name) { + if (ErrorPending()) { + return BValue(); + } + if (!token.GetType()->IsToken()) { + return SetError( + absl::StrFormat( + "Token operand of peek must be of token type; is: %s", + token.GetType()->ToString()), + loc); + } + if (!pred.GetType()->IsBits() || + pred.GetType()->AsBitsOrDie()->bit_count() != 1) { + return SetError( + absl::StrFormat("Predicate operand of peek_if must be of bits " + "type of width 1; is: %s", + pred.GetType()->ToString()), + loc); + } + return AddNode( + loc, token.node(), pred.node(), ChannelRefName(channel), + /*is_blocking=*/true, ChannelRefType(channel), name); +} + +BValue BuilderBase::PeekIfNonBlocking( + ReceiveChannelRef channel, BValue token, BValue pred, + const SourceInfo& loc, std::string_view name) { + if (ErrorPending()) { + return BValue(); + } + if (!token.GetType()->IsToken()) { + return SetError( + absl::StrFormat( + "Token operand of peek must be of token type; is: %s", + token.GetType()->ToString()), + loc); + } + if (!pred.GetType()->IsBits() || + pred.GetType()->AsBitsOrDie()->bit_count() != 1) { + return SetError( + absl::StrFormat("Predicate operand of peek_if must be of bits " + "type of width 1; is: %s", + pred.GetType()->ToString()), + loc); + } + return AddNode( + loc, token.node(), pred.node(), ChannelRefName(channel), + /*is_blocking=*/false, ChannelRefType(channel), name); +} + +BValue BuilderBase::PeekNonBlocking( + ReceiveChannelRef channel, BValue token, + const SourceInfo& loc, std::string_view name) { + if (ErrorPending()) { + return BValue(); + } + if (!token.GetType()->IsToken()) { + return SetError( + absl::StrFormat( + "Token operand of peek must be of token type; is: %s", + token.GetType()->ToString()), + loc); + } + return AddNode(loc, token.node(), /*predicate=*/std::nullopt, + ChannelRefName(channel), /*is_blocking=*/false, + ChannelRefType(channel), name); +} + BValue BuilderBase::Receive(ReceiveChannelRef channel, BValue token, const SourceInfo& loc, std::string_view name) { if (ErrorPending()) { diff --git a/xls/ir/function_builder.h b/xls/ir/function_builder.h index 99db8b7d33..d36f294889 100644 --- a/xls/ir/function_builder.h +++ b/xls/ir/function_builder.h @@ -653,6 +653,32 @@ class BuilderBase { BValue Gate(BValue condition, BValue data, const SourceInfo& loc = SourceInfo(), std::string_view name = ""); + // Add a peek operation. The type of the peeked data value is + // determined by the channel. + BValue Peek(ReceiveChannelRef channel, BValue token, + const SourceInfo& loc = SourceInfo(), + std::string_view name = ""); + + // Add a conditional peek operation. The peek execution is determined + // by the value of predicate `pred`. The type of the peeked data value is + // determined by the channel. + BValue PeekIf(ReceiveChannelRef channel, BValue token, BValue pred, + const SourceInfo& loc = SourceInfo(), + std::string_view name = ""); + + // Add a conditional non-blocking peek operation. The peek execution is + // determined by the value of predicate `pred`. The type of the peeked data + // value is determined by the channel. + BValue PeekIfNonBlocking(ReceiveChannelRef channel, BValue token, + BValue pred, const SourceInfo& loc = SourceInfo(), + std::string_view name = ""); + + // Add a non-blocking peek operation. The type of the peeked data value is + // determined by the channel. + BValue PeekNonBlocking(ReceiveChannelRef channel, BValue token, + const SourceInfo& loc = SourceInfo(), + std::string_view name = ""); + // Add a receive operation. The type of the data value received is // determined by the channel. BValue Receive(ReceiveChannelRef channel, BValue token, diff --git a/xls/ir/ir_parser.cc b/xls/ir/ir_parser.cc index 20e5c6f2af..50b7b4109c 100644 --- a/xls/ir/ir_parser.cc +++ b/xls/ir/ir_parser.cc @@ -1120,6 +1120,68 @@ absl::StatusOr Parser::ParseNode( *loc, node_name); break; } + case Op::kPeek: { + XLS_ASSIGN_OR_RETURN(Proc * proc, + GetEffectiveProcOrError( + fb, "peek operations only supported in procs", + op_token.pos())); + std::optional* predicate = + arg_parser.AddOptionalKeywordArg("predicate"); + IdentifierString* channel_name = + arg_parser.AddKeywordArg("channel"); + bool* is_blocking = arg_parser.AddOptionalKeywordArg( + "blocking", /*default_value=*/true); + XLS_ASSIGN_OR_RETURN(operands, arg_parser.Run(/*arity=*/1)); + // Get the channel from the package. + if (!HasReceiveChannelRef(proc, channel_name->value)) { + if (!HasSendChannelRef(proc, channel_name->value)) { + return absl::InvalidArgumentError( + absl::StrFormat("No such channel `%s`", channel_name->value)); + } + return absl::InvalidArgumentError(absl::StrFormat( + "Cannot receive on channel `%s`", channel_name->value)); + } + ReceiveChannelRef channel_ref; + Type* channel_type; + if (proc->is_new_style_proc()) { + XLS_ASSIGN_OR_RETURN( + channel_ref, proc->GetReceiveChannelInterface(channel_name->value)); + channel_type = std::get(channel_ref)->type(); + } else { + XLS_ASSIGN_OR_RETURN(channel_ref, + package->GetChannel(channel_name->value)); + channel_type = std::get(channel_ref)->type(); + } + + Type* expected_type = + (*is_blocking) + ? package->GetTupleType({package->GetTokenType(), channel_type}) + : package->GetTupleType({package->GetTokenType(), channel_type, + package->GetBitsType(1)}); + + if (expected_type != type) { + return absl::InvalidArgumentError( + absl::StrFormat("peek op type is type: %s. Expected: %s", + type->ToString(), expected_type->ToString())); + } + if (predicate->has_value()) { + if (*is_blocking) { + bvalue = fb->PeekIf(channel_ref, operands[0], predicate->value(), + *loc, node_name); + } else { + bvalue = fb->PeekIfNonBlocking( + channel_ref, operands[0], predicate->value(), *loc, node_name); + } + } else { + if (*is_blocking) { + bvalue = fb->Peek(channel_ref, operands[0], *loc, node_name); + } else { + bvalue = + fb->PeekNonBlocking(channel_ref, operands[0], *loc, node_name); + } + } + break; + } case Op::kReceive: { XLS_ASSIGN_OR_RETURN(Proc * proc, GetEffectiveProcOrError( diff --git a/xls/ir/node.cc b/xls/ir/node.cc index ecb21e8575..a6d322f107 100644 --- a/xls/ir/node.cc +++ b/xls/ir/node.cc @@ -170,6 +170,10 @@ absl::Status Node::VisitSingleNode(DfsVisitor* visitor) { case Op::kTrace: XLS_RETURN_IF_ERROR(visitor->HandleTrace(absl::down_cast(this))); break; + case Op::kPeek: + XLS_RETURN_IF_ERROR( + visitor->HandlePeek(absl::down_cast(this))); + break; case Op::kReceive: XLS_RETURN_IF_ERROR( visitor->HandleReceive(absl::down_cast(this))); @@ -737,6 +741,20 @@ std::string Node::ToStringInternal(bool include_operand_types) const { args.push_back(absl::StrFormat("channel=%s", send->channel_name())); break; } + case Op::kPeek: { + const Peek* peek = As(); + if (peek->predicate().has_value()) { + args = {operand(0)->GetName()}; + args.push_back(absl::StrFormat( + "predicate=%s", peek->predicate().value()->GetName())); + } + args.push_back(absl::StrFormat("channel=%s", peek->channel_name())); + if (peek->is_blocking() == false) { + // Default blocking=true so we only need to push is !is_blocking(). + args.push_back("blocking=false"); + } + break; + } case Op::kReceive: { const Receive* receive = As(); if (receive->predicate().has_value()) { diff --git a/xls/ir/nodes.cc b/xls/ir/nodes.cc index 763b4bfe03..29a6b604a4 100755 --- a/xls/ir/nodes.cc +++ b/xls/ir/nodes.cc @@ -467,6 +467,21 @@ absl::StatusOr ChannelNode::GetChannelRef() const { return package()->GetChannel(channel_name()); } +absl::StatusOr Peek::GetReceiveChannelRef() const { + FunctionBase* fb = this->function_base(); + if (fb->IsBlock()) { + XLS_RET_CHECK(fb->IsScheduled()); + ScheduledBlock* sb = absl::down_cast(fb); + XLS_RET_CHECK_NE(sb->source(), nullptr); + fb = sb->source(); + } + Proc* proc = fb->AsProcOrDie(); + if (proc->is_new_style_proc()) { + return proc->GetReceiveChannelInterface(channel_name()); + } + return package()->GetChannel(channel_name()); +} + absl::StatusOr Receive::GetReceiveChannelRef() const { FunctionBase* fb = this->function_base(); if (fb->IsBlock()) { @@ -546,6 +561,32 @@ absl::Status ChannelNode::ReplaceChannel(std::string_view new_channel_name) { return absl::OkStatus(); } +Peek::Peek(const SourceInfo& loc, Node* token, + std::optional predicate, std::string_view channel_name, + bool is_blocking, Type* payload_type, std::string_view name, + FunctionBase* function) + : ChannelNode( + loc, Op::kPeek, + GetReceiveType(function->package(), is_blocking, payload_type), + channel_name, ChannelDirection::kReceive, predicate.has_value(), name, + function), + is_blocking_(is_blocking) { + CHECK(IsOpClass(op_)) + << "Op `" << op_ << "` is not a valid op for Node class `Peek`."; + AddOperand(token); + // Predicate is expected to be the last operand. + AddOptionalOperand(predicate); +} + +bool Peek::IsDefinitelyEqualTo(const Node* other) const { + if (this == other) { + return true; + } + return Node::IsDefinitelyEqualTo(other) && + channel_name() == other->As()->channel_name() && + is_blocking() == other->As()->is_blocking(); +} + Receive::Receive(const SourceInfo& loc, Node* token, std::optional predicate, std::string_view channel_name, bool is_blocking, Type* payload_type, std::string_view name, @@ -1424,6 +1465,15 @@ absl::StatusOr Trace::CloneInNewFunction( format(), verbosity(), GetNameView()); } +absl::StatusOr Peek::CloneInNewFunction( + absl::Span new_operands, FunctionBase* new_function) const { + return new_function->MakeNodeWithName( + loc(), new_operands[0], + new_operands.size() > 1 ? std::make_optional(new_operands[1]) + : std::nullopt, + channel_name(), is_blocking(), GetPayloadType(), GetNameView()); +} + absl::StatusOr Receive::CloneInNewFunction( absl::Span new_operands, FunctionBase* new_function) const { // TODO(meheff): Choose an appropriate name for the cloned node. diff --git a/xls/ir/nodes.h b/xls/ir/nodes.h index 7cef2f05bd..3d989a781c 100755 --- a/xls/ir/nodes.h +++ b/xls/ir/nodes.h @@ -1164,6 +1164,29 @@ class ChannelNode : public Node { bool has_predicate_; }; +class Peek final : public ChannelNode { + public: + static constexpr std::array kOps = {Op::kPeek}; + static constexpr int64_t kTokenOperand = 0; + + Peek(const SourceInfo& loc, Node* token, std::optional predicate, + std::string_view channel_name, bool is_blocking, Type* payload_type, + std::string_view name, FunctionBase* function); + + absl::StatusOr CloneInNewFunction( + absl::Span new_operands, + FunctionBase* new_function) const final; + + bool is_blocking() const { return is_blocking_; } + + bool IsDefinitelyEqualTo(const Node* other) const final; + + absl::StatusOr GetReceiveChannelRef() const; + + private: + bool is_blocking_; +}; + class Receive final : public ChannelNode { public: static constexpr std::array kOps = {Op::kReceive}; diff --git a/xls/ir/op.proto b/xls/ir/op.proto index 7faa104a8b..c20852e78e 100644 --- a/xls/ir/op.proto +++ b/xls/ir/op.proto @@ -99,4 +99,5 @@ enum OpProto { OP_NEW_CHANNEL = 77; OP_SEND_CHANNEL_END = 78; OP_RECV_CHANNEL_END = 79; + OP_PEEK = 80; } diff --git a/xls/ir/op_list.h b/xls/ir/op_list.h index d4342de6d6..d0de12f79c 100644 --- a/xls/ir/op_list.h +++ b/xls/ir/op_list.h @@ -81,6 +81,7 @@ inline constexpr uint8_t kSideEffecting = 0b00010000; F(kOutputPort, OP_OUTPUT_PORT, "output_port", op_types::kSideEffecting) \ F(kParam, OP_PARAM, "param", op_types::kSideEffecting) \ F(kPrioritySel, OP_PRIORITY_SEL, "priority_sel", op_types::kStandard) \ + F(kPeek, OP_PEEK, "peek", op_types::kSideEffecting) \ F(kReceive, OP_RECEIVE, "receive", op_types::kSideEffecting) \ F(kRecvChannelEnd, OP_RECV_CHANNEL_END, "recv_channel_end", \ op_types::kSideEffecting) \ diff --git a/xls/ir/proc_conversion.cc b/xls/ir/proc_conversion.cc index a4e513fc0e..2a23a0846f 100644 --- a/xls/ir/proc_conversion.cc +++ b/xls/ir/proc_conversion.cc @@ -83,8 +83,9 @@ absl::StatusOr GetChannelProcMap(Package* package) { channel_map.channel_to_procs[channel].insert(proc.get()); channel_map.proc_to_channels[proc.get()].insert(channel); channel_map.directions[{proc.get(), channel}].insert( - node->Is() ? ChannelDirection::kReceive - : ChannelDirection::kSend); + node->Is() || node->Is() + ? ChannelDirection::kReceive + : ChannelDirection::kSend); } } } diff --git a/xls/ir/verify_node.cc b/xls/ir/verify_node.cc index 7931e869fe..90547fdb46 100644 --- a/xls/ir/verify_node.cc +++ b/xls/ir/verify_node.cc @@ -130,6 +130,62 @@ class NodeChecker : public DfsVisitor { return absl::OkStatus(); } + absl::Status HandlePeek(Peek* peek) override { + XLS_RETURN_IF_ERROR(ExpectOperandCountRange(peek, 1, 2)); + XLS_RETURN_IF_ERROR(ExpectOperandHasTokenType(peek, /*operand_no=*/0)); + if (peek->predicate().has_value()) { + XLS_RETURN_IF_ERROR( + ExpectOperandHasBitsType(peek, 1, /*expected_bit_count=*/1)); + } + if (!peek->function_base()->HasEffectiveProc()) { + return absl::InternalError(absl::StrFormat( + "Peek node %s is not in a proc", peek->GetName())); + } + Proc* proc = peek->function_base()->GetEffectiveProcOrDie(); + Type* channel_type; + if (proc->is_new_style_proc()) { + if (!proc->HasChannelInterface(peek->channel_name(), + ChannelDirection::kReceive)) { + return absl::InternalError( + absl::StrFormat("No receivable channel named `%s`, node %s", + peek->channel_name(), peek->GetName())); + } + XLS_ASSIGN_OR_RETURN( + ChannelInterface * channel_ref, + proc->GetChannelInterface(peek->channel_name(), + ChannelDirection::kReceive)); + channel_type = channel_ref->type(); + } else { + if (!peek->package()->HasChannelWithName(peek->channel_name())) { + return absl::InternalError( + absl::StrFormat("%s refers to channel `%s` which does not exist", + peek->GetName(), peek->channel_name())); + } + XLS_ASSIGN_OR_RETURN(Channel * channel, peek->package()->GetChannel( + peek->channel_name())); + + channel_type = channel->type(); + if (!channel->CanReceive()) { + return absl::InternalError(absl::StrFormat( + "Cannot peek over channel `%s`, peek operation: %s", + peek->channel_name(), peek->GetName())); + } + } + Type* expected_type = + peek->is_blocking() + ? peek->package()->GetTupleType( + {peek->package()->GetTokenType(), channel_type}) + : peek->package()->GetTupleType( + {peek->package()->GetTokenType(), channel_type, + peek->package()->GetBitsType(1)}); + if (peek->GetType() != expected_type) { + return absl::InternalError(absl::StrFormat( + "Expected %s to have type %s, has type %s", peek->GetName(), + expected_type->ToString(), peek->GetType()->ToString())); + } + return absl::OkStatus(); + } + absl::Status HandleReceive(Receive* receive) override { XLS_RETURN_IF_ERROR(ExpectOperandCountRange(receive, 1, 2)); XLS_RETURN_IF_ERROR(ExpectOperandHasTokenType(receive, /*operand_no=*/0)); diff --git a/xls/passes/bdd_query_engine.cc b/xls/passes/bdd_query_engine.cc index c54c57db69..d2a6481172 100644 --- a/xls/passes/bdd_query_engine.cc +++ b/xls/passes/bdd_query_engine.cc @@ -170,6 +170,7 @@ bool ShouldEvaluate(Node* node) { case Op::kParam: case Op::kStateRead: case Op::kNext: + case Op::kPeek: case Op::kReceive: case Op::kRecvChannelEnd: case Op::kRegisterRead: diff --git a/xls/passes/cse_pass.cc b/xls/passes/cse_pass.cc index 2af32d352a..5ad3cdc433 100644 --- a/xls/passes/cse_pass.cc +++ b/xls/passes/cse_pass.cc @@ -308,6 +308,7 @@ class CseNodeArena { case Op::kNext: case Op::kOutputPort: case Op::kParam: + case Op::kPeek: case Op::kReceive: case Op::kRecvChannelEnd: case Op::kRegisterRead: diff --git a/xls/passes/range_query_engine.cc b/xls/passes/range_query_engine.cc index fd9cf54d21..3ab63d57bc 100644 --- a/xls/passes/range_query_engine.cc +++ b/xls/passes/range_query_engine.cc @@ -256,6 +256,7 @@ class RangeQueryVisitor : public DfsVisitor { absl::Status HandleNewChannel(NewChannel* new_channel) override; absl::Status HandleRecvChannelEnd(RecvChannelEnd* rce) override; absl::Status HandleSendChannelEnd(SendChannelEnd* sce) override; + absl::Status HandlePeek(Peek* peek) override; absl::Status HandleReceive(Receive* receive) override; absl::Status HandleRegisterRead(RegisterRead* reg_read) override; absl::Status HandleRegisterWrite(RegisterWrite* reg_write) override; @@ -1179,6 +1180,11 @@ absl::Status RangeQueryVisitor::HandleSendChannelEnd(SendChannelEnd* sce) { return absl::OkStatus(); } +absl::Status RangeQueryVisitor::HandlePeek(Peek* peek) { + INITIALIZE_OR_SKIP(peek); + return absl::OkStatus(); +} + absl::Status RangeQueryVisitor::HandleReceive(Receive* receive) { INITIALIZE_OR_SKIP(receive); return absl::OkStatus(); // TODO(taktoa): implement: interprocedural