Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
142 changes: 115 additions & 27 deletions src/analyze/basic_block.rs
Original file line number Diff line number Diff line change
Expand Up @@ -508,12 +508,58 @@ impl<'tcx, 'ctx> Analyzer<'tcx, 'ctx> {
self.type_rvalue(Rvalue::Use(operand), expected);
}

fn type_return(&mut self, expected: &rty::RefinedType<Var>) {
self.type_operand(Operand::Move(mir::RETURN_PLACE.into()), expected);
fn type_return(
&mut self,
expected: &rty::RefinedType<Var>,
expected_fn: &rty::FunctionType,
outer_fn_param_vars: &HashMap<rty::FunctionParamIdx, Var>,
) {
//self.type_operand(Operand::Move(mir::RETURN_PLACE.into()), expected);

let mut builder = self.env.build_clause();
// env.build_clause() で env の Var は既に mapped_var として登録済み

let mut clauses = Vec::new();

// 1. 各 FunctionParamIdx を mapped_var として登録し、
// TermVarIdx 空間で snapshot Var と等値の atom を body に積む
for (&param_idx, &snapshot_var) in outer_fn_param_vars {
let sort = expected_fn.params[param_idx].ty.to_sort();
if sort.is_singleton() {
continue;
}
builder.add_mapped_var(param_idx, sort);
let tv_param = builder.mapped_var(param_idx);
let tv_snapshot = builder.mapped_var(snapshot_var);
builder.add_body(chc::Term::var(tv_param).equal_to(chc::Term::var(tv_snapshot)));
}

// 2. _0 の env-side view を Refinement<Var> として取得
let ret_rty: rty::RefinedType<Var> =
self.operand_refined_type(Operand::Move(mir::RETURN_PLACE.into()));

// 3. RefinementClauseBuilder で subtyping CHC を1本生成
// add_body: Refinement<Var> → Free(env_var) を env-側 mapped_var で TermVarIdx 解決
// head: Refinement<FunctionParamIdx> → Free(param_idx) を step 1 の mapped_var で TermVarIdx 解決
// 両者は同じ builder の TermVarIdx 空間で結ばれる
let cs = builder
.with_value_var(&expected_fn.ret.ty)
.add_body(ret_rty.refinement)
.head(expected_fn.ret.refinement.clone());
clauses.extend(cs);

// 4. 型構造の subtyping (relate_fn_sub_type 末尾と同じ)
clauses.extend(builder.relate_sub_type(&ret_rty.ty, &expected_fn.ret.ty));

self.ctx.extend_clauses(clauses);
}

fn type_goto(&mut self, bb: BasicBlock, expected_ret: &rty::RefinedType<Var>) {
tracing::debug!(bb = ?bb, "type_goto");
fn type_goto(
&mut self,
bb: BasicBlock,
expected_ret: &rty::RefinedType<Var>,
outer_fn_param_vars: &HashMap<rty::FunctionParamIdx, Var>,
) {
let bty = self.basic_block_ty(bb);
let expected_args: IndexVec<_, _> = bty
.as_ref()
Expand All @@ -528,7 +574,23 @@ impl<'tcx, 'ctx> Analyzer<'tcx, 'ctx> {
} else {
rty::RefinedType::unrefined(arg_local_ty.ty)
}
} else if param_idx.index() >= bty.locals.len() {
// snapshot slot: 現 BB の outer_fn_param_vars から該当 Var を取る
let outer_idx =
rty::FunctionParamIdx::from(param_idx.index() - bty.locals.len());
if let Some(snapshot_var) = outer_fn_param_vars.get(&outer_idx) {
let pty = PlaceType::with_ty_and_term(
rty.ty.clone().assert_closed().vacuous(),
chc::Term::var(*snapshot_var),
);
pty.into()
} else {
// some test case fail this assertion. what?
// assert!(param_idx.index() == 0);
rty::RefinedType::unrefined(rty.ty.clone().assert_closed().vacuous())
}
} else {
// 既存: それ以外(ほぼ singleton 系?)
rty::RefinedType::unrefined(rty.ty.clone().assert_closed().vacuous())
}
})
Expand Down Expand Up @@ -565,6 +627,7 @@ impl<'tcx, 'ctx> Analyzer<'tcx, 'ctx> {
discr: Operand<'tcx>,
targets: mir::SwitchTargets,
expected_ret: &rty::RefinedType<Var>,
outer_fn_param_vars: &HashMap<rty::FunctionParamIdx, Var>,
mut callback: F,
) where
F: FnMut(&mut Self, BasicBlock),
Expand All @@ -588,7 +651,7 @@ impl<'tcx, 'ctx> Analyzer<'tcx, 'ctx> {
};
self.with_assumption(pos_assumption, |ecx| {
callback(ecx, bb);
ecx.type_goto(bb, expected_ret);
ecx.type_goto(bb, expected_ret, outer_fn_param_vars);
});
let neg_assumption = {
let mut builder = PlaceTypeBuilder::default();
Expand All @@ -600,7 +663,7 @@ impl<'tcx, 'ctx> Analyzer<'tcx, 'ctx> {
}
self.with_assumptions(negations, |ecx| {
callback(ecx, targets.otherwise());
ecx.type_goto(targets.otherwise(), expected_ret);
ecx.type_goto(targets.otherwise(), expected_ret, outer_fn_param_vars);
});
}

Expand Down Expand Up @@ -877,42 +940,50 @@ impl<'tcx, 'ctx> Analyzer<'tcx, 'ctx> {
}
}

#[tracing::instrument(skip(self, expected_ret), fields(term = ?term.kind))]
#[tracing::instrument(skip(self, expected_ret, expected_fn, outer_fn_param_vars), fields(term = ?term.kind))]
fn analyze_terminator_goto(
&mut self,
term: &mir::Terminator<'tcx>,
expected_ret: &rty::RefinedType<Var>,
expected_fn: &rty::FunctionType,
outer_fn_param_vars: &HashMap<rty::FunctionParamIdx, Var>,
) {
match &term.kind {
TerminatorKind::Return => {
self.type_return(expected_ret);
self.type_return(expected_ret, expected_fn, outer_fn_param_vars);
}
TerminatorKind::Goto { target } => {
self.type_goto(*target, expected_ret);
self.type_goto(*target, expected_ret, outer_fn_param_vars);
}
TerminatorKind::SwitchInt { discr, targets } => {
self.type_switch_int(discr.clone(), targets.clone(), expected_ret, |a, target| {
for local in a.drop_points.after_terminator(&target).iter() {
tracing::info!(?local, ?target, "implicitly dropped for target");
a.drop_local(local);
}
});
self.type_switch_int(
discr.clone(),
targets.clone(),
expected_ret,
outer_fn_param_vars,
|a, target| {
for local in a.drop_points.after_terminator(&target).iter() {
tracing::info!(?local, ?target, "implicitly dropped for target");
a.drop_local(local);
}
},
);
}
TerminatorKind::Call { target, .. } => {
if let Some(target) = target {
for local in self.drop_points.after_terminator(target).iter() {
tracing::info!(?local, "implicitly dropped after call");
self.drop_local(local);
}
self.type_goto(*target, expected_ret);
self.type_goto(*target, expected_ret, outer_fn_param_vars);
}
}
TerminatorKind::Drop { target, .. } => {
for local in self.drop_points.after_terminator(target).iter() {
tracing::info!(?local, "dropped");
self.drop_local(local);
}
self.type_goto(*target, expected_ret);
self.type_goto(*target, expected_ret, outer_fn_param_vars);
}
TerminatorKind::Assert {
cond,
Expand All @@ -931,7 +1002,7 @@ impl<'tcx, 'ctx> Analyzer<'tcx, 'ctx> {
chc::Term::bool(*expected),
),
);
self.type_goto(*target, expected_ret);
self.type_goto(*target, expected_ret, outer_fn_param_vars);
}
TerminatorKind::UnwindResume {} => {}
TerminatorKind::Unreachable {} => {}
Expand Down Expand Up @@ -1085,27 +1156,32 @@ impl<'tcx, 'ctx> Analyzer<'tcx, 'ctx> {
fn bind_locals(
&mut self,
expected_params: &IndexVec<rty::FunctionParamIdx, rty::RefinedType<rty::FunctionParamIdx>>,
) {
) -> HashMap<rty::FunctionParamIdx, Var> {
let mut param_terms = HashMap::<rty::FunctionParamIdx, chc::Term<PlaceTypeVar>>::new();
let mut assumption = Assumption::default();

let mut outer_fn_param_vars = HashMap::new();

let bb_ty = self.basic_block_ty(self.basic_block).clone();
let params = &bb_ty.as_ref().params;
assert!(!params.is_empty());
for (param_idx, param_rty) in params.iter_enumerated() {
let param_ty = &param_rty.ty;
if let Some(local) = bb_ty.local_of_param(param_idx) {
let rty = rty::RefinedType::unrefined(param_ty.clone().subst_var(|v| {
let param_unrefined_rty =
rty::RefinedType::unrefined(param_ty.clone().subst_var(|v| {
param_terms[&v].clone().map_var(|v| match v {
PlaceTypeVar::Var(v) => v,
// TODO
_ => unimplemented!(),
})
}));
if bb_ty.mutbl_of_param(param_idx).unwrap().is_mut() || rty.ty.is_mut() {
self.env.mut_bind(local, rty);
if let Some(local) = bb_ty.local_of_param(param_idx) {
if bb_ty.mutbl_of_param(param_idx).unwrap().is_mut()
|| param_unrefined_rty.ty.is_mut()
{
self.env.mut_bind(local, param_unrefined_rty);
} else {
self.env.immut_bind(local, rty);
self.env.immut_bind(local, param_unrefined_rty);
}
let param_sort = param_ty.to_sort();
if param_sort.is_singleton() {
Expand All @@ -1123,6 +1199,16 @@ impl<'tcx, 'ctx> Analyzer<'tcx, 'ctx> {
.map_var(|v| v.shift_existential(assumption.existentials.len()));
assumption.existentials.extend(local_ty.existentials);
param_terms.insert(param_idx, term);
} else {
if param_idx >= bb_ty.locals.next_index() {
if !param_ty.to_sort().is_singleton() {
let var = self.env.immut_bind_tmp(param_unrefined_rty);
param_terms.insert(param_idx, chc::Term::var(var.into()));
let outer_fn_param_idx =
rty::FunctionParamIdx::from(param_idx.index() - bb_ty.locals.len());
outer_fn_param_vars.insert(outer_fn_param_idx, var);
}
}
}
}

Expand All @@ -1139,6 +1225,8 @@ impl<'tcx, 'ctx> Analyzer<'tcx, 'ctx> {
}

self.env.assume(assumption);

outer_fn_param_vars
}

fn unbind_atoms(&self) -> UnbindAtoms<rty::FunctionParamIdx> {
Expand Down Expand Up @@ -1204,21 +1292,21 @@ impl<'tcx, 'ctx> Analyzer<'tcx, 'ctx> {
self
}

pub fn run(&mut self, expected: &BasicBlockType) {
pub fn run(&mut self, expected: &BasicBlockType, expected_fn: &rty::FunctionType) {
let span = tracing::info_span!("bb", bb = ?self.basic_block);
let _guard = span.enter();
self.register_enum_defs();

let params = expected.as_ref().params.clone();
self.bind_locals(&params);
let outer_fn_param_vars = self.bind_locals(&params);
let unbind_atoms = self.unbind_atoms();
self.alloc_prophecies();
self.analyze_statements();

let term = self.elaborated_terminator();
self.analyze_terminator_binds(&term);
let ret_template = self.ret_template();
self.analyze_terminator_goto(&term, &ret_template);
self.analyze_terminator_goto(&term, &ret_template, expected_fn, &outer_fn_param_vars);

let got_ret_ty = unbind_atoms.unbind(&self.env, ret_template);
let got_ty = rty::FunctionType::new(params, got_ret_ty).into_closed_ty();
Expand Down
80 changes: 74 additions & 6 deletions src/analyze/local_def.rs
Original file line number Diff line number Diff line change
Expand Up @@ -855,20 +855,21 @@ impl<'tcx, 'ctx> Analyzer<'tcx, 'ctx> {
let rty = self
.type_builder
.for_template(&mut self.ctx)
.build_basic_block(live_locals, ret_ty);
.build_basic_block(&self.body, live_locals, ret_ty);
self.ctx.register_basic_block_ty(self.local_def_id, bb, rty);
}
}

fn analyze_basic_blocks(&mut self) {
fn analyze_basic_blocks(&mut self, expected_fn_ty: &rty::RefinedType) {
let expected_fn_ty = expected_fn_ty.ty.as_function().unwrap();
for bb in self.body.basic_blocks.indices() {
let rty = self.ctx.basic_block_ty(self.local_def_id, bb).clone();
let drop_points = self.drop_points[&bb].clone();
self.ctx
.basic_block_analyzer(self.local_def_id, bb)
.body(self.body.clone())
.drop_points(drop_points)
.run(&rty);
.run(&rty, &expected_fn_ty);
}
}

Expand Down Expand Up @@ -954,13 +955,80 @@ impl<'tcx, 'ctx> Analyzer<'tcx, 'ctx> {
rty::FunctionType::new(params, bb_ty.as_ref().ret.clone().map_var(|v| subst[&v]))
}

// Inner function type of BasicBlockType contains a extra parameters that carries original
// function parameter values. `trucate_entry_ty` removes these extra parameters to subtype
// entry_ty against the function type.
//
// before: (_1: int, _2: int, int, { int | p4 ν $0 $1 $2 }) → { int | p5 ν $0 $1 $2 $3 }
// after: (_1: int, _2: { int | p4 v $0 $1 $0 }) → { int | p5 ν $0 $1 _1 _2 }
fn truncate_entry_ty(&self, entry_ty: &mut BasicBlockType) {
let last_param_idx = entry_ty.as_ref().params.last_index().unwrap();
let last_param_ty = entry_ty.as_ref().params.raw.last().unwrap();

let mut mapping = HashMap::new();
for (idx, param_ty) in entry_ty.as_ref().params.iter_enumerated() {
let mapped_idx = if idx >= entry_ty.locals.next_index() {
let outer_fn_param_idx =
rty::FunctionParamIdx::from(idx.index() - entry_ty.locals.len());
let corresponding_local = analyze::local_of_function_param(outer_fn_param_idx);
entry_ty
.param_of_local(corresponding_local)
.unwrap_or_else(|| {
// XXX: if local-param is empty and there are some outer fn param,
// idx == $0, corresponding_local is _1 and param_of_local returns None
assert!(idx.index() == 0);
idx
})
} else {
idx
};
mapping.insert(idx, mapped_idx);

// to be sure
if idx != last_param_idx {
assert!(param_ty.refinement.is_top());
}
}

let last_param_refinement = last_param_ty.refinement.clone().map_refine_var(|v| {
let idx = match v {
rty::RefinedTypeVar::Free(idx) => idx,
rty::RefinedTypeVar::Value => last_param_idx,
v => return v,
};
let mapped_idx = mapping[&idx];
if Some(mapped_idx) == entry_ty.locals.last_index() {
rty::RefinedTypeVar::Value
} else {
rty::RefinedTypeVar::Free(mapped_idx)
}
});

if entry_ty.locals.len() != 0 {
entry_ty.ty.params.truncate(entry_ty.locals.len());
}

entry_ty.ty.params.raw.last_mut().unwrap().refinement = last_param_refinement;
entry_ty.ty.ret.refinement = entry_ty
.ty
.ret
.refinement
.clone()
.map_var(|idx| mapping[&idx]);
}

fn assert_entry(&mut self, expected: &rty::RefinedType) {
let entry_ty = self.ctx.basic_block_ty(self.local_def_id, mir::START_BLOCK);
let mut entry_ty = self
.ctx
.basic_block_ty(self.local_def_id, mir::START_BLOCK)
.clone();
tracing::debug!(expected = %expected.display(), entry = %entry_ty.display(), "assert_entry before");
let mut expected = expected.ty.as_function().cloned().unwrap();
self.elaborate_mut_params(&mut expected);

let entry_ty = self.elaborate_unused_args(entry_ty, &expected);
self.truncate_entry_ty(&mut entry_ty);
let entry_ty = self.elaborate_unused_args(&entry_ty, &expected);
expected.ret.refinement = rty::Refinement::top();

tracing::debug!(expected = %expected.display(), entry = %entry_ty.display(), "assert_entry after");
let clauses = rty::relate_sub_closed_type(&entry_ty.into(), &expected.into());
Expand Down Expand Up @@ -1000,7 +1068,7 @@ impl<'tcx, 'ctx> Analyzer<'tcx, 'ctx> {
self.unelaborate_derefs();
self.reassign_local_mutabilities();
self.refine_basic_blocks();
self.analyze_basic_blocks();
self.analyze_basic_blocks(expected);
self.assert_entry(expected);
}
}
Loading
Loading