Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion xls/dslx/type_system_v2/inference_table.cc
Original file line number Diff line number Diff line change
Expand Up @@ -467,6 +467,7 @@ class InferenceTableImpl : public InferenceTable {
absl::StatusOr<StructContextResult> GetOrCreateParametricStructContext(
const StructDefBase* struct_def, const AstNode* node,
ParametricEnv parametric_env, const TypeAnnotation* self_type,
std::optional<const ParametricContext*> parent_context,
absl::FunctionRef<absl::StatusOr<TypeInfo*>()> type_info_factory)
override {
std::optional<StructContextResult> cached_result =
Expand All @@ -478,7 +479,7 @@ class InferenceTableImpl : public InferenceTable {
auto context = std::make_unique<ParametricContext>(
parametric_contexts_.size(), node,
ParametricStructDetails{struct_def, parametric_env}, type_info,
/*parent_context=*/std::nullopt, self_type);
parent_context, self_type);
const ParametricContext* result = context.get();
parametric_contexts_.push_back(std::move(context));
auto& contexts = parametric_struct_contexts_[struct_def];
Expand Down
1 change: 1 addition & 0 deletions xls/dslx/type_system_v2/inference_table.h
Original file line number Diff line number Diff line change
Expand Up @@ -475,6 +475,7 @@ class InferenceTable {
GetOrCreateParametricStructContext(
const StructDefBase* struct_def, const AstNode* node,
ParametricEnv parametric_env, const TypeAnnotation* self_type,
std::optional<const ParametricContext*> parent_context,
absl::FunctionRef<absl::StatusOr<TypeInfo*>()> type_info_factory) = 0;

// Returns the expression for the value of the given parametric in the given
Expand Down
21 changes: 8 additions & 13 deletions xls/dslx/type_system_v2/inference_table_converter_impl.cc
Original file line number Diff line number Diff line change
Expand Up @@ -846,27 +846,22 @@ class InferenceTableConverterImpl : public InferenceTableConverter,
// use of proc-level parametrics. Unlike with impl-style member functions,
// we don't have a target struct context for the proc.
if (!function->IsInProc()) {
XLS_RETURN_IF_ERROR(GenerateTypeInfo(caller_or_target_struct_context,
invocation->callee()));
XLS_RETURN_IF_ERROR(
GenerateTypeInfo(caller_context, invocation->callee()));
}

for (int i = 0; i < parametric_free_function_type->param_types().size();
i++) {
const TypeAnnotation* formal_type =
parametric_free_function_type->param_types()[i];
const Expr* actual_param = actual_args[i];
const bool is_self_param =
i == 0 && function_and_target_object.target_object.has_value();
XLS_RETURN_IF_ERROR(
table_.AddTypeAnnotationToVariableForParametricContext(
caller_context, *table_.GetTypeVariable(actual_param),
formal_type));
TypeSystemTrace arg_trace =
tracer_->TraceConvertActualArgument(actual_param);
XLS_RETURN_IF_ERROR(ConvertSubtree(
actual_param, caller,
is_self_param ? function_and_target_object.target_struct_context
: caller_context));
XLS_RETURN_IF_ERROR(ConvertSubtree(actual_param, caller, caller_context));
}

// Convert the actual parametric function in the context of this invocation,
Expand Down Expand Up @@ -1190,11 +1185,11 @@ class InferenceTableConverterImpl : public InferenceTableConverter,
ref.def->owner(), absl::StrCat("struct_", ref.def->identifier()),
struct_base_ti);
};
XLS_ASSIGN_OR_RETURN(
InferenceTable::StructContextResult lookup_result,
table_.GetOrCreateParametricStructContext(
ref.def, node, parametric_env,
CreateStructOrProcAnnotation(module_, ref), type_info_factory));
XLS_ASSIGN_OR_RETURN(InferenceTable::StructContextResult lookup_result,
table_.GetOrCreateParametricStructContext(
ref.def, node, parametric_env,
CreateStructOrProcAnnotation(module_, ref),
parent_context, type_info_factory));
const ParametricContext* struct_context = lookup_result.context;
if (!lookup_result.created_new) {
return struct_context;
Expand Down
62 changes: 62 additions & 0 deletions xls/dslx/type_system_v2/typecheck_module_v2_generics_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -172,6 +172,68 @@ const_assert!(D == 4);
)"));
}

TEST(TypecheckV2GenericsTest,
ResolveAnnotationFromSeparateImplInvocationWithOtherStruct) {
EXPECT_THAT(
R"(
#![feature(generics)]

struct S2<T: type> {
val: T,
}

impl S2<T> {
fn call<U: type>(self, other: S2<U>) -> U {
other.val
}
}

fn structs<N: u32>() -> s32[N+1] {
let my_val = map(0..N, |i| i as s32);
let plus_one = map(0..N+1, |i| (i + 1) as s32);
S2{val: my_val}.call(S2{val: plus_one})
}

const RES = structs<4>();
const_assert!(RES == [s32:1, 2, 3, 4, 5]);
)",
TypecheckSucceeds(HasNodeWithType("RES", "sN[32][5]")));
}

TEST(TypecheckV2GenericsTest, ResolveAnnotationFromSeparateImplInvocation) {
EXPECT_THAT(
R"(
#![feature(generics)]

fn is_odd(i: u32) -> bool {
i % 2 == 1
}

struct S2<odd_map_type: type> {
odd_map: odd_map_type,
}

impl S2 {
fn call(self, i: u32) -> u32 {
if self.odd_map[i] {
i + 2
} else {
i
}
}
}

fn add_two<N: u32>() -> u32 {
let odd_map = map(0..N, is_odd);
S2{odd_map: odd_map}.call(N - 1)
}

const RES = add_two<5>();
const_assert!(RES == u32:4);
)",
TypecheckSucceeds(HasNodeWithType("RES", "uN[32]")));
}

TEST(TypecheckV2GenericsTest, GenericConstantAccess) {
EXPECT_THAT(
R"(
Expand Down
29 changes: 29 additions & 0 deletions xls/dslx/type_system_v2/typecheck_module_v2_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -10059,6 +10059,35 @@ const_assert!(main() == 6);
TypecheckSucceeds(HasNodeWithType("ARR", "uN[32][5]")));
}

TEST(TypecheckV2Test, LambdaUsesMapResult) {
EXPECT_THAT(R"(
fn is_odd(i: u32) -> bool {
i % 2 == 1
}

fn add_two<N: u32>() -> u32[N] {
let odd_map = map(0..N, is_odd);
map(
0..N,
|i| -> u32 {
if odd_map[i] {
i + 2
} else {
i
}
}
)
}

const RES = add_two<5>();
const_assert!(RES == [u32:0, 3, 2, 5, 4]);
const RES2 = add_two<3>();
const_assert!(RES2 == [u32:0, 3, 2]);
)",
TypecheckSucceeds(AllOf(HasNodeWithType("RES", "uN[32][5]"),
HasNodeWithType("RES2", "uN[32][3]"))));
}

TEST(TypecheckV2Test, LambdaCallsImportedFunction) {
constexpr std::string_view kImported = R"(
pub fn add_two(x: u32) -> u32 {
Expand Down
Loading