diff --git a/xls/dslx/type_system_v2/inference_table.cc b/xls/dslx/type_system_v2/inference_table.cc index 8d3aa8cdf8..5cd885c427 100644 --- a/xls/dslx/type_system_v2/inference_table.cc +++ b/xls/dslx/type_system_v2/inference_table.cc @@ -467,6 +467,7 @@ class InferenceTableImpl : public InferenceTable { absl::StatusOr GetOrCreateParametricStructContext( const StructDefBase* struct_def, const AstNode* node, ParametricEnv parametric_env, const TypeAnnotation* self_type, + std::optional parent_context, absl::FunctionRef()> type_info_factory) override { std::optional cached_result = @@ -478,7 +479,7 @@ class InferenceTableImpl : public InferenceTable { auto context = std::make_unique( parametric_contexts_.size(), node, ParametricStructDetails{struct_def, parametric_env}, type_info, - /*parent_context=*/std::nullopt, self_type); + parent_context, self_type); const ParametricContext* result = context.get(); parametric_contexts_.push_back(std::move(context)); auto& contexts = parametric_struct_contexts_[struct_def]; diff --git a/xls/dslx/type_system_v2/inference_table.h b/xls/dslx/type_system_v2/inference_table.h index e70bf0ddbc..3b417c44d2 100644 --- a/xls/dslx/type_system_v2/inference_table.h +++ b/xls/dslx/type_system_v2/inference_table.h @@ -475,6 +475,7 @@ class InferenceTable { GetOrCreateParametricStructContext( const StructDefBase* struct_def, const AstNode* node, ParametricEnv parametric_env, const TypeAnnotation* self_type, + std::optional parent_context, absl::FunctionRef()> type_info_factory) = 0; // Returns the expression for the value of the given parametric in the given diff --git a/xls/dslx/type_system_v2/inference_table_converter_impl.cc b/xls/dslx/type_system_v2/inference_table_converter_impl.cc index 395b413aac..7844057d1d 100644 --- a/xls/dslx/type_system_v2/inference_table_converter_impl.cc +++ b/xls/dslx/type_system_v2/inference_table_converter_impl.cc @@ -846,8 +846,8 @@ class InferenceTableConverterImpl : public InferenceTableConverter, // use of proc-level parametrics. Unlike with impl-style member functions, // we don't have a target struct context for the proc. if (!function->IsInProc()) { - XLS_RETURN_IF_ERROR(GenerateTypeInfo(caller_or_target_struct_context, - invocation->callee())); + XLS_RETURN_IF_ERROR( + GenerateTypeInfo(caller_context, invocation->callee())); } for (int i = 0; i < parametric_free_function_type->param_types().size(); @@ -855,18 +855,13 @@ class InferenceTableConverterImpl : public InferenceTableConverter, const TypeAnnotation* formal_type = parametric_free_function_type->param_types()[i]; const Expr* actual_param = actual_args[i]; - const bool is_self_param = - i == 0 && function_and_target_object.target_object.has_value(); XLS_RETURN_IF_ERROR( table_.AddTypeAnnotationToVariableForParametricContext( caller_context, *table_.GetTypeVariable(actual_param), formal_type)); TypeSystemTrace arg_trace = tracer_->TraceConvertActualArgument(actual_param); - XLS_RETURN_IF_ERROR(ConvertSubtree( - actual_param, caller, - is_self_param ? function_and_target_object.target_struct_context - : caller_context)); + XLS_RETURN_IF_ERROR(ConvertSubtree(actual_param, caller, caller_context)); } // Convert the actual parametric function in the context of this invocation, @@ -1190,11 +1185,11 @@ class InferenceTableConverterImpl : public InferenceTableConverter, ref.def->owner(), absl::StrCat("struct_", ref.def->identifier()), struct_base_ti); }; - XLS_ASSIGN_OR_RETURN( - InferenceTable::StructContextResult lookup_result, - table_.GetOrCreateParametricStructContext( - ref.def, node, parametric_env, - CreateStructOrProcAnnotation(module_, ref), type_info_factory)); + XLS_ASSIGN_OR_RETURN(InferenceTable::StructContextResult lookup_result, + table_.GetOrCreateParametricStructContext( + ref.def, node, parametric_env, + CreateStructOrProcAnnotation(module_, ref), + parent_context, type_info_factory)); const ParametricContext* struct_context = lookup_result.context; if (!lookup_result.created_new) { return struct_context; diff --git a/xls/dslx/type_system_v2/typecheck_module_v2_generics_test.cc b/xls/dslx/type_system_v2/typecheck_module_v2_generics_test.cc index 35b77bdb0a..a6162644b6 100644 --- a/xls/dslx/type_system_v2/typecheck_module_v2_generics_test.cc +++ b/xls/dslx/type_system_v2/typecheck_module_v2_generics_test.cc @@ -172,6 +172,68 @@ const_assert!(D == 4); )")); } +TEST(TypecheckV2GenericsTest, + ResolveAnnotationFromSeparateImplInvocationWithOtherStruct) { + EXPECT_THAT( + R"( +#![feature(generics)] + +struct S2 { + val: T, +} + +impl S2 { + fn call(self, other: S2) -> U { + other.val + } +} + +fn structs() -> s32[N+1] { + let my_val = map(0..N, |i| i as s32); + let plus_one = map(0..N+1, |i| (i + 1) as s32); + S2{val: my_val}.call(S2{val: plus_one}) +} + +const RES = structs<4>(); +const_assert!(RES == [s32:1, 2, 3, 4, 5]); +)", + TypecheckSucceeds(HasNodeWithType("RES", "sN[32][5]"))); +} + +TEST(TypecheckV2GenericsTest, ResolveAnnotationFromSeparateImplInvocation) { + EXPECT_THAT( + R"( +#![feature(generics)] + +fn is_odd(i: u32) -> bool { + i % 2 == 1 +} + +struct S2 { + odd_map: odd_map_type, +} + +impl S2 { + fn call(self, i: u32) -> u32 { + if self.odd_map[i] { + i + 2 + } else { + i + } + } +} + +fn add_two() -> u32 { + let odd_map = map(0..N, is_odd); + S2{odd_map: odd_map}.call(N - 1) +} + +const RES = add_two<5>(); +const_assert!(RES == u32:4); +)", + TypecheckSucceeds(HasNodeWithType("RES", "uN[32]"))); +} + TEST(TypecheckV2GenericsTest, GenericConstantAccess) { EXPECT_THAT( R"( diff --git a/xls/dslx/type_system_v2/typecheck_module_v2_test.cc b/xls/dslx/type_system_v2/typecheck_module_v2_test.cc index 9505866a0d..31188d41b5 100644 --- a/xls/dslx/type_system_v2/typecheck_module_v2_test.cc +++ b/xls/dslx/type_system_v2/typecheck_module_v2_test.cc @@ -10059,6 +10059,35 @@ const_assert!(main() == 6); TypecheckSucceeds(HasNodeWithType("ARR", "uN[32][5]"))); } +TEST(TypecheckV2Test, LambdaUsesMapResult) { + EXPECT_THAT(R"( +fn is_odd(i: u32) -> bool { + i % 2 == 1 +} + +fn add_two() -> u32[N] { + let odd_map = map(0..N, is_odd); + map( + 0..N, + |i| -> u32 { + if odd_map[i] { + i + 2 + } else { + i + } + } + ) +} + +const RES = add_two<5>(); +const_assert!(RES == [u32:0, 3, 2, 5, 4]); +const RES2 = add_two<3>(); +const_assert!(RES2 == [u32:0, 3, 2]); +)", + TypecheckSucceeds(AllOf(HasNodeWithType("RES", "uN[32][5]"), + HasNodeWithType("RES2", "uN[32][3]")))); +} + TEST(TypecheckV2Test, LambdaCallsImportedFunction) { constexpr std::string_view kImported = R"( pub fn add_two(x: u32) -> u32 {