use rustc_data_structures::graph::dominators::Dominators;
use rustc_index::bit_set::BitSet;
use rustc_index::{IndexSlice, IndexVec};
use rustc_middle::bug;
use rustc_middle::middle::resolve_bound_vars::Set1;
use rustc_middle::mir::visit::*;
use rustc_middle::mir::*;
use rustc_middle::ty::{ParamEnv, TyCtxt};
use tracing::{debug, instrument, trace};
pub(super) struct SsaLocals {
    assignments: IndexVec<Local, Set1<DefLocation>>,
    assignment_order: Vec<Local>,
    copy_classes: IndexVec<Local, Local>,
    direct_uses: IndexVec<Local, u32>,
    borrowed_locals: BitSet<Local>,
}
pub(super) enum AssignedValue<'a, 'tcx> {
    Arg,
    Rvalue(&'a mut Rvalue<'tcx>),
    Terminator,
}
impl SsaLocals {
    pub(super) fn new<'tcx>(
        tcx: TyCtxt<'tcx>,
        body: &Body<'tcx>,
        param_env: ParamEnv<'tcx>,
    ) -> SsaLocals {
        let assignment_order = Vec::with_capacity(body.local_decls.len());
        let assignments = IndexVec::from_elem(Set1::Empty, &body.local_decls);
        let dominators = body.basic_blocks.dominators();
        let direct_uses = IndexVec::from_elem(0, &body.local_decls);
        let borrowed_locals = BitSet::new_empty(body.local_decls.len());
        let mut visitor = SsaVisitor {
            body,
            assignments,
            assignment_order,
            dominators,
            direct_uses,
            borrowed_locals,
        };
        for local in body.args_iter() {
            visitor.assignments[local] = Set1::One(DefLocation::Argument);
            visitor.assignment_order.push(local);
        }
        for (bb, data) in traversal::reverse_postorder(body) {
            visitor.visit_basic_block_data(bb, data);
        }
        for var_debug_info in &body.var_debug_info {
            visitor.visit_var_debug_info(var_debug_info);
        }
        debug!(?visitor.borrowed_locals);
        for local in visitor.borrowed_locals.iter() {
            if !body.local_decls[local].ty.is_freeze(tcx, param_env) {
                visitor.assignments[local] = Set1::Many;
            }
        }
        debug!(?visitor.assignments);
        debug!(?visitor.direct_uses);
        visitor
            .assignment_order
            .retain(|&local| matches!(visitor.assignments[local], Set1::One(_)));
        debug!(?visitor.assignment_order);
        let mut ssa = SsaLocals {
            assignments: visitor.assignments,
            assignment_order: visitor.assignment_order,
            direct_uses: visitor.direct_uses,
            borrowed_locals: visitor.borrowed_locals,
            copy_classes: IndexVec::default(),
        };
        compute_copy_classes(&mut ssa, body);
        ssa
    }
    pub(super) fn num_locals(&self) -> usize {
        self.assignments.len()
    }
    pub(super) fn locals(&self) -> impl Iterator<Item = Local> {
        self.assignments.indices()
    }
    pub(super) fn is_ssa(&self, local: Local) -> bool {
        matches!(self.assignments[local], Set1::One(_))
    }
    pub(super) fn num_direct_uses(&self, local: Local) -> u32 {
        self.direct_uses[local]
    }
    #[inline]
    pub(super) fn assignment_dominates(
        &self,
        dominators: &Dominators<BasicBlock>,
        local: Local,
        location: Location,
    ) -> bool {
        match self.assignments[local] {
            Set1::One(def) => def.dominates(location, dominators),
            _ => false,
        }
    }
    pub(super) fn assignments<'a, 'tcx>(
        &'a self,
        body: &'a Body<'tcx>,
    ) -> impl Iterator<Item = (Local, &'a Rvalue<'tcx>, Location)> + 'a {
        self.assignment_order.iter().filter_map(|&local| {
            if let Set1::One(DefLocation::Assignment(loc)) = self.assignments[local] {
                let stmt = body.stmt_at(loc).left()?;
                let Some((target, rvalue)) = stmt.kind.as_assign() else { bug!() };
                assert_eq!(target.as_local(), Some(local));
                Some((local, rvalue, loc))
            } else {
                None
            }
        })
    }
    pub(super) fn for_each_assignment_mut<'tcx>(
        &self,
        basic_blocks: &mut IndexSlice<BasicBlock, BasicBlockData<'tcx>>,
        mut f: impl FnMut(Local, AssignedValue<'_, 'tcx>, Location),
    ) {
        for &local in &self.assignment_order {
            match self.assignments[local] {
                Set1::One(DefLocation::Argument) => f(local, AssignedValue::Arg, Location {
                    block: START_BLOCK,
                    statement_index: 0,
                }),
                Set1::One(DefLocation::Assignment(loc)) => {
                    let bb = &mut basic_blocks[loc.block];
                    let stmt = &mut bb.statements[loc.statement_index];
                    let StatementKind::Assign(box (target, ref mut rvalue)) = stmt.kind else {
                        bug!()
                    };
                    assert_eq!(target.as_local(), Some(local));
                    f(local, AssignedValue::Rvalue(rvalue), loc)
                }
                Set1::One(DefLocation::CallReturn { call, .. }) => {
                    let bb = &mut basic_blocks[call];
                    let loc = Location { block: call, statement_index: bb.statements.len() };
                    f(local, AssignedValue::Terminator, loc)
                }
                _ => {}
            }
        }
    }
    pub(super) fn copy_classes(&self) -> &IndexSlice<Local, Local> {
        &self.copy_classes
    }
    pub(super) fn borrowed_locals(&self) -> &BitSet<Local> {
        &self.borrowed_locals
    }
    pub(super) fn meet_copy_equivalence(&self, property: &mut BitSet<Local>) {
        for (local, &head) in self.copy_classes.iter_enumerated() {
            if !property.contains(local) {
                property.remove(head);
            }
        }
        for (local, &head) in self.copy_classes.iter_enumerated() {
            if !property.contains(head) {
                property.remove(local);
            }
        }
        #[cfg(debug_assertions)]
        for (local, &head) in self.copy_classes.iter_enumerated() {
            assert_eq!(property.contains(local), property.contains(head));
        }
    }
}
struct SsaVisitor<'a, 'tcx> {
    body: &'a Body<'tcx>,
    dominators: &'a Dominators<BasicBlock>,
    assignments: IndexVec<Local, Set1<DefLocation>>,
    assignment_order: Vec<Local>,
    direct_uses: IndexVec<Local, u32>,
    borrowed_locals: BitSet<Local>,
}
impl SsaVisitor<'_, '_> {
    fn check_dominates(&mut self, local: Local, loc: Location) {
        let set = &mut self.assignments[local];
        let assign_dominates = match *set {
            Set1::Empty | Set1::Many => false,
            Set1::One(def) => def.dominates(loc, self.dominators),
        };
        if !assign_dominates {
            *set = Set1::Many;
        }
    }
}
impl<'tcx> Visitor<'tcx> for SsaVisitor<'_, 'tcx> {
    fn visit_local(&mut self, local: Local, ctxt: PlaceContext, loc: Location) {
        match ctxt {
            PlaceContext::MutatingUse(MutatingUseContext::Projection)
            | PlaceContext::NonMutatingUse(NonMutatingUseContext::Projection) => bug!(),
            PlaceContext::NonMutatingUse(NonMutatingUseContext::RawBorrow)
            | PlaceContext::MutatingUse(_) => {
                self.assignments[local] = Set1::Many;
            }
            PlaceContext::NonMutatingUse(
                NonMutatingUseContext::SharedBorrow | NonMutatingUseContext::FakeBorrow,
            ) => {
                self.borrowed_locals.insert(local);
                self.check_dominates(local, loc);
                self.direct_uses[local] += 1;
            }
            PlaceContext::NonMutatingUse(_) => {
                self.check_dominates(local, loc);
                self.direct_uses[local] += 1;
            }
            PlaceContext::NonUse(_) => {}
        }
    }
    fn visit_place(&mut self, place: &Place<'tcx>, ctxt: PlaceContext, loc: Location) {
        let location = match ctxt {
            PlaceContext::MutatingUse(MutatingUseContext::Store) => {
                Some(DefLocation::Assignment(loc))
            }
            PlaceContext::MutatingUse(MutatingUseContext::Call) => {
                let call = loc.block;
                let TerminatorKind::Call { target, .. } =
                    self.body.basic_blocks[call].terminator().kind
                else {
                    bug!()
                };
                Some(DefLocation::CallReturn { call, target })
            }
            _ => None,
        };
        if let Some(location) = location
            && let Some(local) = place.as_local()
        {
            self.assignments[local].insert(location);
            if let Set1::One(_) = self.assignments[local] {
                self.assignment_order.push(local);
            }
        } else if place.projection.first() == Some(&PlaceElem::Deref) {
            if ctxt.is_use() {
                let new_ctxt = PlaceContext::NonMutatingUse(NonMutatingUseContext::Copy);
                self.visit_projection(place.as_ref(), new_ctxt, loc);
                self.check_dominates(place.local, loc);
            }
        } else {
            self.visit_projection(place.as_ref(), ctxt, loc);
            self.visit_local(place.local, ctxt, loc);
        }
    }
}
#[instrument(level = "trace", skip(ssa, body))]
fn compute_copy_classes(ssa: &mut SsaLocals, body: &Body<'_>) {
    let mut direct_uses = std::mem::take(&mut ssa.direct_uses);
    let mut copies = IndexVec::from_fn_n(|l| l, body.local_decls.len());
    for (local, rvalue, _) in ssa.assignments(body) {
        let (Rvalue::Use(Operand::Copy(place) | Operand::Move(place))
        | Rvalue::CopyForDeref(place)) = rvalue
        else {
            continue;
        };
        let Some(rhs) = place.as_local() else { continue };
        let local_ty = body.local_decls()[local].ty;
        let rhs_ty = body.local_decls()[rhs].ty;
        if local_ty != rhs_ty {
            trace!("skipped `{local:?} = {rhs:?}` due to subtyping: {local_ty} != {rhs_ty}");
            continue;
        }
        if !ssa.is_ssa(rhs) {
            continue;
        }
        let head = copies[rhs];
        if local == RETURN_PLACE {
            if body.local_kind(head) != LocalKind::Temp {
                continue;
            }
            for h in copies.iter_mut() {
                if *h == head {
                    *h = RETURN_PLACE;
                }
            }
        } else {
            copies[local] = head;
        }
        direct_uses[rhs] -= 1;
    }
    debug!(?copies);
    debug!(?direct_uses);
    #[cfg(debug_assertions)]
    for &head in copies.iter() {
        assert_eq!(copies[head], head);
    }
    debug_assert_eq!(copies[RETURN_PLACE], RETURN_PLACE);
    ssa.direct_uses = direct_uses;
    ssa.copy_classes = copies;
}
#[derive(Debug)]
pub(crate) struct StorageLiveLocals {
    storage_live: IndexVec<Local, Set1<DefLocation>>,
}
impl StorageLiveLocals {
    pub(crate) fn new(
        body: &Body<'_>,
        always_storage_live_locals: &BitSet<Local>,
    ) -> StorageLiveLocals {
        let mut storage_live = IndexVec::from_elem(Set1::Empty, &body.local_decls);
        for local in always_storage_live_locals.iter() {
            storage_live[local] = Set1::One(DefLocation::Argument);
        }
        for (block, bbdata) in body.basic_blocks.iter_enumerated() {
            for (statement_index, statement) in bbdata.statements.iter().enumerate() {
                if let StatementKind::StorageLive(local) = statement.kind {
                    storage_live[local]
                        .insert(DefLocation::Assignment(Location { block, statement_index }));
                }
            }
        }
        debug!(?storage_live);
        StorageLiveLocals { storage_live }
    }
    #[inline]
    pub(crate) fn has_single_storage(&self, local: Local) -> bool {
        matches!(self.storage_live[local], Set1::One(_))
    }
}