diff --git a/src/analyze/basic_block.rs b/src/analyze/basic_block.rs index 15d7003..f55e700 100644 --- a/src/analyze/basic_block.rs +++ b/src/analyze/basic_block.rs @@ -508,12 +508,58 @@ impl<'tcx, 'ctx> Analyzer<'tcx, 'ctx> { self.type_rvalue(Rvalue::Use(operand), expected); } - fn type_return(&mut self, expected: &rty::RefinedType) { - self.type_operand(Operand::Move(mir::RETURN_PLACE.into()), expected); + fn type_return( + &mut self, + expected: &rty::RefinedType, + expected_fn: &rty::FunctionType, + outer_fn_param_vars: &HashMap, + ) { + //self.type_operand(Operand::Move(mir::RETURN_PLACE.into()), expected); + + let mut builder = self.env.build_clause(); + // env.build_clause() で env の Var は既に mapped_var として登録済み + + let mut clauses = Vec::new(); + + // 1. 各 FunctionParamIdx を mapped_var として登録し、 + // TermVarIdx 空間で snapshot Var と等値の atom を body に積む + for (¶m_idx, &snapshot_var) in outer_fn_param_vars { + let sort = expected_fn.params[param_idx].ty.to_sort(); + if sort.is_singleton() { + continue; + } + builder.add_mapped_var(param_idx, sort); + let tv_param = builder.mapped_var(param_idx); + let tv_snapshot = builder.mapped_var(snapshot_var); + builder.add_body(chc::Term::var(tv_param).equal_to(chc::Term::var(tv_snapshot))); + } + + // 2. _0 の env-side view を Refinement として取得 + let ret_rty: rty::RefinedType = + self.operand_refined_type(Operand::Move(mir::RETURN_PLACE.into())); + + // 3. RefinementClauseBuilder で subtyping CHC を1本生成 + // add_body: Refinement → Free(env_var) を env-側 mapped_var で TermVarIdx 解決 + // head: Refinement → Free(param_idx) を step 1 の mapped_var で TermVarIdx 解決 + // 両者は同じ builder の TermVarIdx 空間で結ばれる + let cs = builder + .with_value_var(&expected_fn.ret.ty) + .add_body(ret_rty.refinement) + .head(expected_fn.ret.refinement.clone()); + clauses.extend(cs); + + // 4. 型構造の subtyping (relate_fn_sub_type 末尾と同じ) + clauses.extend(builder.relate_sub_type(&ret_rty.ty, &expected_fn.ret.ty)); + + self.ctx.extend_clauses(clauses); } - fn type_goto(&mut self, bb: BasicBlock, expected_ret: &rty::RefinedType) { - tracing::debug!(bb = ?bb, "type_goto"); + fn type_goto( + &mut self, + bb: BasicBlock, + expected_ret: &rty::RefinedType, + outer_fn_param_vars: &HashMap, + ) { let bty = self.basic_block_ty(bb); let expected_args: IndexVec<_, _> = bty .as_ref() @@ -528,7 +574,23 @@ impl<'tcx, 'ctx> Analyzer<'tcx, 'ctx> { } else { rty::RefinedType::unrefined(arg_local_ty.ty) } + } else if param_idx.index() >= bty.locals.len() { + // snapshot slot: 現 BB の outer_fn_param_vars から該当 Var を取る + let outer_idx = + rty::FunctionParamIdx::from(param_idx.index() - bty.locals.len()); + if let Some(snapshot_var) = outer_fn_param_vars.get(&outer_idx) { + let pty = PlaceType::with_ty_and_term( + rty.ty.clone().assert_closed().vacuous(), + chc::Term::var(*snapshot_var), + ); + pty.into() + } else { + // some test case fail this assertion. what? + // assert!(param_idx.index() == 0); + rty::RefinedType::unrefined(rty.ty.clone().assert_closed().vacuous()) + } } else { + // 既存: それ以外(ほぼ singleton 系?) rty::RefinedType::unrefined(rty.ty.clone().assert_closed().vacuous()) } }) @@ -565,6 +627,7 @@ impl<'tcx, 'ctx> Analyzer<'tcx, 'ctx> { discr: Operand<'tcx>, targets: mir::SwitchTargets, expected_ret: &rty::RefinedType, + outer_fn_param_vars: &HashMap, mut callback: F, ) where F: FnMut(&mut Self, BasicBlock), @@ -588,7 +651,7 @@ impl<'tcx, 'ctx> Analyzer<'tcx, 'ctx> { }; self.with_assumption(pos_assumption, |ecx| { callback(ecx, bb); - ecx.type_goto(bb, expected_ret); + ecx.type_goto(bb, expected_ret, outer_fn_param_vars); }); let neg_assumption = { let mut builder = PlaceTypeBuilder::default(); @@ -600,7 +663,7 @@ impl<'tcx, 'ctx> Analyzer<'tcx, 'ctx> { } self.with_assumptions(negations, |ecx| { callback(ecx, targets.otherwise()); - ecx.type_goto(targets.otherwise(), expected_ret); + ecx.type_goto(targets.otherwise(), expected_ret, outer_fn_param_vars); }); } @@ -877,26 +940,34 @@ impl<'tcx, 'ctx> Analyzer<'tcx, 'ctx> { } } - #[tracing::instrument(skip(self, expected_ret), fields(term = ?term.kind))] + #[tracing::instrument(skip(self, expected_ret, expected_fn, outer_fn_param_vars), fields(term = ?term.kind))] fn analyze_terminator_goto( &mut self, term: &mir::Terminator<'tcx>, expected_ret: &rty::RefinedType, + expected_fn: &rty::FunctionType, + outer_fn_param_vars: &HashMap, ) { match &term.kind { TerminatorKind::Return => { - self.type_return(expected_ret); + self.type_return(expected_ret, expected_fn, outer_fn_param_vars); } TerminatorKind::Goto { target } => { - self.type_goto(*target, expected_ret); + self.type_goto(*target, expected_ret, outer_fn_param_vars); } TerminatorKind::SwitchInt { discr, targets } => { - self.type_switch_int(discr.clone(), targets.clone(), expected_ret, |a, target| { - for local in a.drop_points.after_terminator(&target).iter() { - tracing::info!(?local, ?target, "implicitly dropped for target"); - a.drop_local(local); - } - }); + self.type_switch_int( + discr.clone(), + targets.clone(), + expected_ret, + outer_fn_param_vars, + |a, target| { + for local in a.drop_points.after_terminator(&target).iter() { + tracing::info!(?local, ?target, "implicitly dropped for target"); + a.drop_local(local); + } + }, + ); } TerminatorKind::Call { target, .. } => { if let Some(target) = target { @@ -904,7 +975,7 @@ impl<'tcx, 'ctx> Analyzer<'tcx, 'ctx> { tracing::info!(?local, "implicitly dropped after call"); self.drop_local(local); } - self.type_goto(*target, expected_ret); + self.type_goto(*target, expected_ret, outer_fn_param_vars); } } TerminatorKind::Drop { target, .. } => { @@ -912,7 +983,7 @@ impl<'tcx, 'ctx> Analyzer<'tcx, 'ctx> { tracing::info!(?local, "dropped"); self.drop_local(local); } - self.type_goto(*target, expected_ret); + self.type_goto(*target, expected_ret, outer_fn_param_vars); } TerminatorKind::Assert { cond, @@ -931,7 +1002,7 @@ impl<'tcx, 'ctx> Analyzer<'tcx, 'ctx> { chc::Term::bool(*expected), ), ); - self.type_goto(*target, expected_ret); + self.type_goto(*target, expected_ret, outer_fn_param_vars); } TerminatorKind::UnwindResume {} => {} TerminatorKind::Unreachable {} => {} @@ -1085,27 +1156,32 @@ impl<'tcx, 'ctx> Analyzer<'tcx, 'ctx> { fn bind_locals( &mut self, expected_params: &IndexVec>, - ) { + ) -> HashMap { let mut param_terms = HashMap::>::new(); let mut assumption = Assumption::default(); + let mut outer_fn_param_vars = HashMap::new(); + let bb_ty = self.basic_block_ty(self.basic_block).clone(); let params = &bb_ty.as_ref().params; assert!(!params.is_empty()); for (param_idx, param_rty) in params.iter_enumerated() { let param_ty = ¶m_rty.ty; - if let Some(local) = bb_ty.local_of_param(param_idx) { - let rty = rty::RefinedType::unrefined(param_ty.clone().subst_var(|v| { + let param_unrefined_rty = + rty::RefinedType::unrefined(param_ty.clone().subst_var(|v| { param_terms[&v].clone().map_var(|v| match v { PlaceTypeVar::Var(v) => v, // TODO _ => unimplemented!(), }) })); - if bb_ty.mutbl_of_param(param_idx).unwrap().is_mut() || rty.ty.is_mut() { - self.env.mut_bind(local, rty); + if let Some(local) = bb_ty.local_of_param(param_idx) { + if bb_ty.mutbl_of_param(param_idx).unwrap().is_mut() + || param_unrefined_rty.ty.is_mut() + { + self.env.mut_bind(local, param_unrefined_rty); } else { - self.env.immut_bind(local, rty); + self.env.immut_bind(local, param_unrefined_rty); } let param_sort = param_ty.to_sort(); if param_sort.is_singleton() { @@ -1123,6 +1199,16 @@ impl<'tcx, 'ctx> Analyzer<'tcx, 'ctx> { .map_var(|v| v.shift_existential(assumption.existentials.len())); assumption.existentials.extend(local_ty.existentials); param_terms.insert(param_idx, term); + } else { + if param_idx >= bb_ty.locals.next_index() { + if !param_ty.to_sort().is_singleton() { + let var = self.env.immut_bind_tmp(param_unrefined_rty); + param_terms.insert(param_idx, chc::Term::var(var.into())); + let outer_fn_param_idx = + rty::FunctionParamIdx::from(param_idx.index() - bb_ty.locals.len()); + outer_fn_param_vars.insert(outer_fn_param_idx, var); + } + } } } @@ -1139,6 +1225,8 @@ impl<'tcx, 'ctx> Analyzer<'tcx, 'ctx> { } self.env.assume(assumption); + + outer_fn_param_vars } fn unbind_atoms(&self) -> UnbindAtoms { @@ -1204,13 +1292,13 @@ impl<'tcx, 'ctx> Analyzer<'tcx, 'ctx> { self } - pub fn run(&mut self, expected: &BasicBlockType) { + pub fn run(&mut self, expected: &BasicBlockType, expected_fn: &rty::FunctionType) { let span = tracing::info_span!("bb", bb = ?self.basic_block); let _guard = span.enter(); self.register_enum_defs(); let params = expected.as_ref().params.clone(); - self.bind_locals(¶ms); + let outer_fn_param_vars = self.bind_locals(¶ms); let unbind_atoms = self.unbind_atoms(); self.alloc_prophecies(); self.analyze_statements(); @@ -1218,7 +1306,7 @@ impl<'tcx, 'ctx> Analyzer<'tcx, 'ctx> { let term = self.elaborated_terminator(); self.analyze_terminator_binds(&term); let ret_template = self.ret_template(); - self.analyze_terminator_goto(&term, &ret_template); + self.analyze_terminator_goto(&term, &ret_template, expected_fn, &outer_fn_param_vars); let got_ret_ty = unbind_atoms.unbind(&self.env, ret_template); let got_ty = rty::FunctionType::new(params, got_ret_ty).into_closed_ty(); diff --git a/src/analyze/local_def.rs b/src/analyze/local_def.rs index a29d0e3..4fe72fc 100644 --- a/src/analyze/local_def.rs +++ b/src/analyze/local_def.rs @@ -855,12 +855,13 @@ impl<'tcx, 'ctx> Analyzer<'tcx, 'ctx> { let rty = self .type_builder .for_template(&mut self.ctx) - .build_basic_block(live_locals, ret_ty); + .build_basic_block(&self.body, live_locals, ret_ty); self.ctx.register_basic_block_ty(self.local_def_id, bb, rty); } } - fn analyze_basic_blocks(&mut self) { + fn analyze_basic_blocks(&mut self, expected_fn_ty: &rty::RefinedType) { + let expected_fn_ty = expected_fn_ty.ty.as_function().unwrap(); for bb in self.body.basic_blocks.indices() { let rty = self.ctx.basic_block_ty(self.local_def_id, bb).clone(); let drop_points = self.drop_points[&bb].clone(); @@ -868,7 +869,7 @@ impl<'tcx, 'ctx> Analyzer<'tcx, 'ctx> { .basic_block_analyzer(self.local_def_id, bb) .body(self.body.clone()) .drop_points(drop_points) - .run(&rty); + .run(&rty, &expected_fn_ty); } } @@ -954,13 +955,80 @@ impl<'tcx, 'ctx> Analyzer<'tcx, 'ctx> { rty::FunctionType::new(params, bb_ty.as_ref().ret.clone().map_var(|v| subst[&v])) } + // Inner function type of BasicBlockType contains a extra parameters that carries original + // function parameter values. `trucate_entry_ty` removes these extra parameters to subtype + // entry_ty against the function type. + // + // before: (_1: int, _2: int, int, { int | p4 ν $0 $1 $2 }) → { int | p5 ν $0 $1 $2 $3 } + // after: (_1: int, _2: { int | p4 v $0 $1 $0 }) → { int | p5 ν $0 $1 _1 _2 } + fn truncate_entry_ty(&self, entry_ty: &mut BasicBlockType) { + let last_param_idx = entry_ty.as_ref().params.last_index().unwrap(); + let last_param_ty = entry_ty.as_ref().params.raw.last().unwrap(); + + let mut mapping = HashMap::new(); + for (idx, param_ty) in entry_ty.as_ref().params.iter_enumerated() { + let mapped_idx = if idx >= entry_ty.locals.next_index() { + let outer_fn_param_idx = + rty::FunctionParamIdx::from(idx.index() - entry_ty.locals.len()); + let corresponding_local = analyze::local_of_function_param(outer_fn_param_idx); + entry_ty + .param_of_local(corresponding_local) + .unwrap_or_else(|| { + // XXX: if local-param is empty and there are some outer fn param, + // idx == $0, corresponding_local is _1 and param_of_local returns None + assert!(idx.index() == 0); + idx + }) + } else { + idx + }; + mapping.insert(idx, mapped_idx); + + // to be sure + if idx != last_param_idx { + assert!(param_ty.refinement.is_top()); + } + } + + let last_param_refinement = last_param_ty.refinement.clone().map_refine_var(|v| { + let idx = match v { + rty::RefinedTypeVar::Free(idx) => idx, + rty::RefinedTypeVar::Value => last_param_idx, + v => return v, + }; + let mapped_idx = mapping[&idx]; + if Some(mapped_idx) == entry_ty.locals.last_index() { + rty::RefinedTypeVar::Value + } else { + rty::RefinedTypeVar::Free(mapped_idx) + } + }); + + if entry_ty.locals.len() != 0 { + entry_ty.ty.params.truncate(entry_ty.locals.len()); + } + + entry_ty.ty.params.raw.last_mut().unwrap().refinement = last_param_refinement; + entry_ty.ty.ret.refinement = entry_ty + .ty + .ret + .refinement + .clone() + .map_var(|idx| mapping[&idx]); + } + fn assert_entry(&mut self, expected: &rty::RefinedType) { - let entry_ty = self.ctx.basic_block_ty(self.local_def_id, mir::START_BLOCK); + let mut entry_ty = self + .ctx + .basic_block_ty(self.local_def_id, mir::START_BLOCK) + .clone(); tracing::debug!(expected = %expected.display(), entry = %entry_ty.display(), "assert_entry before"); let mut expected = expected.ty.as_function().cloned().unwrap(); self.elaborate_mut_params(&mut expected); - let entry_ty = self.elaborate_unused_args(entry_ty, &expected); + self.truncate_entry_ty(&mut entry_ty); + let entry_ty = self.elaborate_unused_args(&entry_ty, &expected); + expected.ret.refinement = rty::Refinement::top(); tracing::debug!(expected = %expected.display(), entry = %entry_ty.display(), "assert_entry after"); let clauses = rty::relate_sub_closed_type(&entry_ty.into(), &expected.into()); @@ -1000,7 +1068,7 @@ impl<'tcx, 'ctx> Analyzer<'tcx, 'ctx> { self.unelaborate_derefs(); self.reassign_local_mutabilities(); self.refine_basic_blocks(); - self.analyze_basic_blocks(); + self.analyze_basic_blocks(expected); self.assert_entry(expected); } } diff --git a/src/refine/basic_block.rs b/src/refine/basic_block.rs index 67439c9..b4fc779 100644 --- a/src/refine/basic_block.rs +++ b/src/refine/basic_block.rs @@ -14,9 +14,8 @@ use crate::rty; /// from function parameters to [`Local`]s, along with the underlying function type. #[derive(Debug, Clone)] pub struct BasicBlockType { - // TODO: make this completely private by exposing appropriate ctor - pub(super) ty: rty::FunctionType, - pub(super) locals: IndexVec, + pub ty: rty::FunctionType, + pub locals: IndexVec, } impl<'a, D> Pretty<'a, D, termcolor::ColorSpec> for &BasicBlockType @@ -26,17 +25,16 @@ where { fn pretty(self, allocator: &'a D) -> pretty::DocBuilder<'a, D, termcolor::ColorSpec> { let separator = allocator.text(",").append(allocator.line()); - let params = self - .ty - .params - .iter() - .zip(&self.locals) - .map(|(ty, (local, mutbl))| { + let params = self.ty.params.iter_enumerated().map(|(idx, ty)| { + if let Some((local, mutbl)) = self.locals.get(idx) { allocator .text(format!("{}{:?}:", mutbl.prefix_str(), local)) .append(allocator.space()) .append(ty.pretty(allocator)) - }); + } else { + ty.pretty(allocator) + } + }); allocator .intersperse(params, separator) .parens() @@ -63,6 +61,12 @@ impl BasicBlockType { self.locals.get(idx).map(|(_, mutbl)| *mutbl) } + pub fn param_of_local(&self, local: Local) -> Option { + self.locals + .iter_enumerated() + .find_map(|(idx, (l, _))| if *l == local { Some(idx) } else { None }) + } + pub fn to_function_ty(&self) -> rty::FunctionType { self.ty.clone() } diff --git a/src/refine/env.rs b/src/refine/env.rs index e7762e5..b3b3a92 100644 --- a/src/refine/env.rs +++ b/src/refine/env.rs @@ -872,6 +872,12 @@ where tracing::debug!(local = ?local, rty = %rty_disp.display(), place_type = %self.local_type(local).display(), "immut_bind"); } + pub fn immut_bind_tmp(&mut self, rty: rty::RefinedType) -> Var { + let idx = self.temp_vars.push(TempVarBinding::Type(rty.clone())); + tracing::debug!(temp = ?idx, rty = %rty.display(), "immut_bind_tmp"); + Var::Temp(idx) + } + pub fn assume(&mut self, assumption: impl Into) { let assumption = assumption.into(); tracing::debug!(assumption = %assumption.display(), "assume"); diff --git a/src/refine/template.rs b/src/refine/template.rs index a29d76d..37409f7 100644 --- a/src/refine/template.rs +++ b/src/refine/template.rs @@ -426,6 +426,7 @@ where pub fn build_basic_block( &mut self, + body: &rustc_middle::mir::Body<'tcx>, live_locals: I, ret_ty: mir_ty::Ty<'tcx>, ) -> BasicBlockType @@ -442,6 +443,15 @@ where locals.push((local, ty.mutbl)); tys.push(ty); } + + for arg in body.args_iter() { + let decl = &body.local_decls[arg]; + tys.push(mir_ty::TypeAndMut { + ty: decl.ty, + mutbl: decl.mutability, + }); + } + let ty = FunctionTemplateTypeBuilder { inner: self.inner.clone(), registry: self.registry, @@ -449,7 +459,10 @@ where ret_ty, param_rtys: Default::default(), param_refinement: None, - ret_rty: None, + //ret_rty: None, + ret_rty: Some(rty::RefinedType::unrefined( + self.inner.build(ret_ty).vacuous(), + )), abi: Default::default(), } .build(); diff --git a/src/rty.rs b/src/rty.rs index 5cbe6d2..3a91e0c 100644 --- a/src/rty.rs +++ b/src/rty.rs @@ -1347,6 +1347,16 @@ impl Refinement { } } + pub fn map_refine_var(self, mut f: F) -> Refinement + where + F: FnMut(RefinedTypeVar) -> RefinedTypeVar, + { + Refinement { + existentials: self.existentials, + body: self.body.map_var(&mut f), + } + } + pub fn instantiate(self) -> Instantiator { Instantiator { value_var: None, diff --git a/src/rty/subtyping.rs b/src/rty/subtyping.rs index 20d6f34..1b5eb36 100644 --- a/src/rty/subtyping.rs +++ b/src/rty/subtyping.rs @@ -98,8 +98,9 @@ where } } } - (Type::Function(got), Type::Function(expected)) => { - // TODO: check length is equal + (Type::Function(got), Type::Function(expected)) + if got.params.len() == expected.params.len() => + { let mut builder = chc::ClauseBuilder::default(); for (param_idx, param_rty) in got.params.iter_enumerated() { let param_sort = param_rty.ty.to_sort();