Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions xls/dslx/frontend/ast.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1956,9 +1956,10 @@ StructInstanceBase::StructInstanceBase(
members_(std::move(members)) {}

std::vector<std::pair<std::string, Expr*>>
StructInstanceBase::GetOrderedMembers(const StructDef* struct_def) const {
StructInstanceBase::GetOrderedMembers(
const StructDefBase* struct_or_proc_def) const {
std::vector<std::pair<std::string, Expr*>> result;
for (const std::string& name : struct_def->GetMemberNames()) {
for (const std::string& name : struct_or_proc_def->GetMemberNames()) {
absl::StatusOr<Expr*> expr = GetExpr(name);
if (absl::IsNotFound(expr.status()) && !requires_all_members()) {
continue;
Expand Down
4 changes: 2 additions & 2 deletions xls/dslx/frontend/ast.h
Original file line number Diff line number Diff line change
Expand Up @@ -3618,9 +3618,9 @@ class StructInstanceBase : public Expr {
}

// Returns the members for the struct instance, ordered by the (resolved)
// struct definition "struct_def".
// struct definition "struct_or_proc_def".
std::vector<std::pair<std::string, Expr*>> GetOrderedMembers(
const StructDef* struct_def) const;
const StructDefBase* struct_or_proc_def) const;

const std::vector<std::pair<std::string, Expr*>>& members() const {
return members_;
Expand Down
Original file line number Diff line number Diff line change
@@ -1 +1,4 @@
1:0-1:19: PROC_DEF :: `proc Foo {
a: u32,
}` :: typeof(Foo { a: uN[32] })
1:14-1:17: TYPE_ANNOTATION :: `u32` :: typeof(uN[32])
5 changes: 5 additions & 0 deletions xls/dslx/type_system_v2/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@ cc_library(
"@com_google_absl//absl/status",
"@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/strings:str_format",
"@com_google_absl//absl/types:span",
],
)
Expand Down Expand Up @@ -524,6 +525,7 @@ cc_library(
":inference_table",
":inference_table_utils",
":type_annotation_utils",
"//xls/common:attribute_data",
"//xls/common:visitor",
"//xls/common/status:ret_check",
"//xls/common/status:status_macros",
Expand Down Expand Up @@ -695,7 +697,10 @@ cc_library(
"//xls/dslx/frontend:ast_node_visitor_with_default",
"//xls/dslx/frontend:module",
"//xls/dslx/frontend:pos",
"//xls/dslx/type_system:type",
"//xls/dslx/type_system:type_info",
"@com_google_absl//absl/base",
"@com_google_absl//absl/functional:function_ref",
"@com_google_absl//absl/log:check",
"@com_google_absl//absl/status",
"@com_google_absl//absl/status:statusor",
Expand Down
62 changes: 62 additions & 0 deletions xls/dslx/type_system_v2/constant_collector.cc
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/str_format.h"
#include "absl/strings/substitute.h"
#include "absl/types/span.h"
#include "xls/common/status/ret_check.h"
Expand Down Expand Up @@ -682,6 +683,67 @@ class Visitor : public AstNodeVisitorWithDefault {
return absl::OkStatus();
}

absl::Status HandleStructInstance(const StructInstance* node) {
if (!type_.IsProc()) {
return absl::OkStatus();
}

// For a StructInstance node that is creating an impl-style proc, we store a
// tuple of the initial state values as the constexpr value in TypeInfo.
// This is equivalent to the result of the 'init' block in legacy procs.
const ProcType& type = type_.AsProc();
std::vector<InterpValue> state_init_values;
for (const auto& [member_name, initializer] :
node->GetOrderedMembers(&type.AsProc().struct_def_base())) {
std::optional<const Type*> maybe_member_type =
type.GetMemberTypeByName(member_name);
XLS_RET_CHECK(maybe_member_type.has_value());

// The inferred types of the state members are State<T> where T is the
// type specified in the DSLX source. These are the only members we need
// to evaluate.
const Type& member_type = **maybe_member_type;
if (!IsProcDefStateType(member_type, import_data_)) {
continue;
}

absl::StatusOr<InterpValue> value = ConstexprEvaluator::EvaluateToValue(
&import_data_, ti_, &warning_collector_,
table_.GetParametricEnv(parametric_context_), initializer);
if (!value.ok()) {
return TypeInferenceErrorStatus(
initializer->span(), &type_,
absl::Substitute("Initializer for member `$0` of proc `$1` must be "
"possible to evaluate at compile time.",
member_name,
type.AsProc().struct_def_base().identifier()),
file_table_);
}

state_init_values.push_back(std::move(*value));
}

InterpValue value = InterpValue::MakeTuple(std::move(state_init_values));
ti_->NoteConstExpr(node, value);
VLOG(6) << "Storing value " << value.ToHumanString()
<< " for proc initializer " << node->ToString();

// Propagate the proc "value" through the statement and/or statement block
// that yields it. This way IR conversion can easily say "give me the
// constant value for the body of the proc constructor."
for (AstNode* parent = node->parent();
parent != nullptr && (parent->kind() == AstNodeKind::kStatement ||
parent->kind() == AstNodeKind::kStatementBlock);
parent = parent->parent()) {
VLOG(6) << "Propagating proc initializer value for expr `"
<< node->ToString() << "` to ancestor of kind "
<< AstNodeKindToString(parent->kind());
ti_->NoteConstExpr(parent, value);
}

return absl::OkStatus();
}

private:
InferenceTable& table_;
Module& module_;
Expand Down
38 changes: 23 additions & 15 deletions xls/dslx/type_system_v2/flatten_in_type_order.cc
Original file line number Diff line number Diff line change
Expand Up @@ -191,21 +191,11 @@ class Flattener : public AstNodeVisitorWithDefault {
}

absl::Status HandleStructDef(const StructDef* node) override {
if (node->IsParametric() && node != root_) {
return absl::OkStatus();
}
// StructDefBase::GetChildren does not return StructMemberNodes, this is
// blocked by https://github.com/google/xls/issues/1756.
nodes_.push_back(node->name_def());
for (const ParametricBinding* parametric_binding :
node->parametric_bindings()) {
XLS_RETURN_IF_ERROR(parametric_binding->Accept(this));
}
for (const StructMemberNode* member : node->members()) {
XLS_RETURN_IF_ERROR(member->Accept(this));
}
nodes_.push_back(node);
return absl::OkStatus();
return HandleStructDefBaseInternal(node);
}

absl::Status HandleProcDef(const ProcDef* node) override {
return HandleStructDefBaseInternal(node);
}

absl::Status HandleTypeRef(const TypeRef* node) override {
Expand Down Expand Up @@ -257,6 +247,24 @@ class Flattener : public AstNodeVisitorWithDefault {
const std::vector<const AstNode*>& nodes() const { return nodes_; }

private:
absl::Status HandleStructDefBaseInternal(const StructDefBase* node) {
if (node->IsParametric() && node != root_) {
return absl::OkStatus();
}
// StructDefBase::GetChildren does not return StructMemberNodes, this is
// blocked by https://github.com/google/xls/issues/1756.
nodes_.push_back(node->name_def());
for (const ParametricBinding* parametric_binding :
node->parametric_bindings()) {
XLS_RETURN_IF_ERROR(parametric_binding->Accept(this));
}
for (const StructMemberNode* member : node->members()) {
XLS_RETURN_IF_ERROR(member->Accept(this));
}
nodes_.push_back(node);
return absl::OkStatus();
}

const ImportData& import_data_;
const AstNode* const root_;
const bool include_parametric_entities_;
Expand Down
29 changes: 29 additions & 0 deletions xls/dslx/type_system_v2/import_utils.cc
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,8 @@
#include "xls/dslx/frontend/module.h"
#include "xls/dslx/frontend/pos.h"
#include "xls/dslx/import_data.h"
#include "xls/dslx/type_system/type.h"
#include "xls/dslx/type_system/type_info.h"
#include "xls/dslx/type_system_v2/type_annotation_utils.h"

namespace xls::dslx {
Expand Down Expand Up @@ -247,4 +249,31 @@ absl::StatusOr<bool> IsProcDefNextFunction(const Function* f,
return (*def)->kind() == AstNodeKind::kProcDef;
}

bool IsProcDefStateType(const Type& type, const ImportData& import_data) {
// A state element in a `ProcDef` always uses explicit state access, so should
// have been made into a State<T> by semantics analysis, i.e. it is always a
// struct in type inference.
if (!type.IsStruct()) {
return false;
}

Module* builtin_stubs = *import_data.GetBuiltinStubsModule();
const StructDef* state_struct_def =
*builtin_stubs->GetMember<StructDef>("State");
return &type.AsStruct().struct_def_base() == state_struct_def;
}

absl::StatusOr<std::vector<StructMemberNode*>> GetProcDefStateMembers(
const ProcDef* proc_def, const ImportData& import_data,
const TypeInfo& type_info) {
std::vector<StructMemberNode*> result;
for (StructMemberNode* member : proc_def->members()) {
XLS_ASSIGN_OR_RETURN(const Type* type, type_info.GetItemOrError(member));
if (IsProcDefStateType(*type, import_data)) {
result.push_back(member);
}
}
return result;
}

} // namespace xls::dslx
14 changes: 14 additions & 0 deletions xls/dslx/type_system_v2/import_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,15 @@
#define XLS_DSLX_TYPE_SYSTEM_V2_IMPORT_UTILS_H_

#include <optional>
#include <vector>

#include "absl/status/statusor.h"
#include "xls/dslx/frontend/ast.h"
#include "xls/dslx/frontend/module.h"
#include "xls/dslx/frontend/pos.h"
#include "xls/dslx/import_data.h"
#include "xls/dslx/type_system/type.h"
#include "xls/dslx/type_system/type_info.h"
#include "xls/dslx/type_system_v2/type_annotation_utils.h"

namespace xls::dslx {
Expand Down Expand Up @@ -74,6 +77,17 @@ absl::StatusOr<bool> IsProcDefNextFunction(const Function* f,
// Returns whether `colon_ref` is imported from a different module.
bool IsImport(const ColonRef* colon_ref);

// Returns whether the given `type` of a member of a `ProcDef` indicates that
// the member is a state element. All state members of a `ProcDef` have the
// inferred type `State<T>` where `T` is the type written by the programmer, so
// this amounts to checking if `type` is a "State-wrapped" type.
bool IsProcDefStateType(const Type& type, const ImportData& import_data);

// Returns all the members of `proc_def` that are state elements.
absl::StatusOr<std::vector<StructMemberNode*>> GetProcDefStateMembers(
const ProcDef* proc_def, const ImportData& import_data,
const TypeInfo& type_info);

} // namespace xls::dslx

#endif // XLS_DSLX_TYPE_SYSTEM_V2_IMPORT_UTILS_H_
28 changes: 19 additions & 9 deletions xls/dslx/type_system_v2/populate_table_visitor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@
#include "absl/strings/str_join.h"
#include "absl/strings/substitute.h"
#include "absl/types/variant.h"
#include "xls/common/attribute_data.h"
#include "xls/common/status/ret_check.h"
#include "xls/common/status/status_macros.h"
#include "xls/common/visitor.h"
Expand Down Expand Up @@ -1413,15 +1414,11 @@ class PopulateInferenceTableVisitor : public PopulateTableVisitor,
// annotation for each member and cache it for use by type unification of
// instances of this struct.
absl::Status HandleStructDef(const StructDef* node) override {
if (!node->IsParametric()) {
for (const StructMemberNode* member : node->members()) {
XLS_RETURN_IF_ERROR(table_.SetTypeAnnotation(member, member->type()));
}
XLS_RETURN_IF_ERROR(table_.SetTypeAnnotation(
node, CreateStructOrProcAnnotation(
module_, const_cast<StructDef*>(node), {}, std::nullopt)));
}
return DefaultHandler(node);
return HandleStructDefBaseInternal(node);
}

absl::Status HandleProcDef(const ProcDef* node) override {
return HandleStructDefBaseInternal(node);
}

absl::Status HandleTupleIndex(const TupleIndex* node) override {
Expand Down Expand Up @@ -2042,6 +2039,19 @@ class PopulateInferenceTableVisitor : public PopulateTableVisitor,
return true;
}

absl::Status HandleStructDefBaseInternal(const StructDefBase* node) {
if (!node->IsParametric()) {
for (const StructMemberNode* member : node->members()) {
XLS_RETURN_IF_ERROR(table_.SetTypeAnnotation(member, member->type()));
}
XLS_RETURN_IF_ERROR(table_.SetTypeAnnotation(
node,
CreateStructOrProcAnnotation(
module_, const_cast<StructDefBase*>(node), {}, std::nullopt)));
}
return DefaultHandler(node);
}

// Helper that creates an internal type variable for a `ConstantDef`, `Param`,
// or similar type of node that contains a `NameDef` and optional
// `TypeAnnotation`.
Expand Down
45 changes: 33 additions & 12 deletions xls/dslx/type_system_v2/typecheck_module_v2_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -7789,8 +7789,8 @@ proc P {
}

impl P {
fn new(input: chan<u32> in, output: chan<u32> out, init_val: u32) -> Self {
P { input, output, state: init_val }
fn new(input: chan<u32> in, output: chan<u32> out) -> Self {
P { input, output, state: 5 }
}

fn next(self) {
Expand All @@ -7805,10 +7805,9 @@ impl P {
TypecheckSucceeds(AllOf(
HasNodeWithType("P", absl::Substitute("typeof($0)", kProcType)),
HasNodeWithType(
"new",
absl::Substitute("(chan(uN[32], dir=in), chan(uN[32], dir=out), "
"uN[32]) -> $0",
kProcType)),
"new", absl::Substitute(
"(chan(uN[32], dir=in), chan(uN[32], dir=out)) -> $0",
kProcType)),
HasNodeWithType("next", absl::Substitute("($0) -> ()", kProcType)))));
}

Expand All @@ -7822,8 +7821,8 @@ proc P {
}

impl P {
fn new(init_val: u32) -> Self {
P { state: init_val }
fn new() -> Self {
P { state: 5 }
}

fn next(self, a: u32) {}
Expand All @@ -7843,8 +7842,8 @@ proc P {
}

impl P {
fn new(init_val: u32) -> Self {
P { state: init_val }
fn new() -> Self {
P { state: 5 }
}

fn next() {}
Expand All @@ -7864,8 +7863,8 @@ proc P {
}

impl P {
fn new(init_val: u32) -> Self {
P { state: init_val }
fn new() -> Self {
P { state: 5 }
}

fn next() -> u32 { 0 }
Expand All @@ -7875,6 +7874,28 @@ impl P {
"must not return anything")));
}

TEST(TypecheckV2Test, ProcWithImplIntegerParamInTopProcNewFails) {
EXPECT_THAT(
R"(
#![feature(explicit_state_access)]

proc P {
state: u32,
}

impl P {
fn new(init_val: u32) -> Self {
P { state: init_val }
}

fn next() -> u32 { 0 }
}
)",
TypecheckFails(
HasSubstr("Initializer for member `state` of proc `P` must be "
"possible to evaluate at compile time.")));
}

TEST(TypecheckV2Test, SpawnProcWithImpl) {
std::string_view kProgram = R"(
#![feature(explicit_state_access)]
Expand Down
Loading
Loading