use crate::errors::{
    DuplicateValuesFor, PathMustEndInFilename, RequiresAnArgument, UnknownFormatter,
};
use crate::framework::BitSetExt;
use std::ffi::OsString;
use std::path::PathBuf;
use rustc_ast as ast;
use rustc_data_structures::work_queue::WorkQueue;
use rustc_graphviz as dot;
use rustc_hir::def_id::DefId;
use rustc_index::{Idx, IndexVec};
use rustc_middle::mir::{self, traversal, BasicBlock};
use rustc_middle::mir::{create_dump_file, dump_enabled};
use rustc_middle::ty::print::with_no_trimmed_paths;
use rustc_middle::ty::TyCtxt;
use rustc_span::symbol::{sym, Symbol};
use super::fmt::DebugWithContext;
use super::graphviz;
use super::{
    visit_results, Analysis, AnalysisDomain, Direction, GenKill, GenKillAnalysis, GenKillSet,
    JoinSemiLattice, ResultsCursor, ResultsVisitor,
};
pub type EntrySets<'tcx, A> = IndexVec<BasicBlock, <A as AnalysisDomain<'tcx>>::Domain>;
#[derive(Clone)]
pub struct Results<'tcx, A>
where
    A: Analysis<'tcx>,
{
    pub analysis: A,
    pub(super) entry_sets: EntrySets<'tcx, A>,
}
impl<'tcx, A> Results<'tcx, A>
where
    A: Analysis<'tcx>,
{
    pub fn into_results_cursor<'mir>(
        self,
        body: &'mir mir::Body<'tcx>,
    ) -> ResultsCursor<'mir, 'tcx, A> {
        ResultsCursor::new(body, self)
    }
    pub fn entry_set_for_block(&self, block: BasicBlock) -> &A::Domain {
        &self.entry_sets[block]
    }
    pub fn visit_with<'mir>(
        &mut self,
        body: &'mir mir::Body<'tcx>,
        blocks: impl IntoIterator<Item = BasicBlock>,
        vis: &mut impl ResultsVisitor<'mir, 'tcx, Self, FlowState = A::Domain>,
    ) {
        visit_results(body, blocks, self, vis)
    }
    pub fn visit_reachable_with<'mir>(
        &mut self,
        body: &'mir mir::Body<'tcx>,
        vis: &mut impl ResultsVisitor<'mir, 'tcx, Self, FlowState = A::Domain>,
    ) {
        let blocks = mir::traversal::reachable(body);
        visit_results(body, blocks.map(|(bb, _)| bb), self, vis)
    }
}
pub struct Engine<'mir, 'tcx, A>
where
    A: Analysis<'tcx>,
{
    tcx: TyCtxt<'tcx>,
    body: &'mir mir::Body<'tcx>,
    entry_sets: IndexVec<BasicBlock, A::Domain>,
    pass_name: Option<&'static str>,
    analysis: A,
    apply_statement_trans_for_block: Option<Box<dyn Fn(BasicBlock, &mut A::Domain)>>,
}
impl<'mir, 'tcx, A, D, T> Engine<'mir, 'tcx, A>
where
    A: GenKillAnalysis<'tcx, Idx = T, Domain = D>,
    D: Clone + JoinSemiLattice + GenKill<T> + BitSetExt<T>,
    T: Idx,
{
    pub fn new_gen_kill(tcx: TyCtxt<'tcx>, body: &'mir mir::Body<'tcx>, mut analysis: A) -> Self {
        if !body.basic_blocks.is_cfg_cyclic() {
            return Self::new(tcx, body, analysis, None);
        }
        let identity = GenKillSet::identity(analysis.domain_size(body));
        let mut trans_for_block = IndexVec::from_elem(identity, &body.basic_blocks);
        for (block, block_data) in body.basic_blocks.iter_enumerated() {
            let trans = &mut trans_for_block[block];
            A::Direction::gen_kill_statement_effects_in_block(
                &mut analysis,
                trans,
                block,
                block_data,
            );
        }
        let apply_trans = Box::new(move |bb: BasicBlock, state: &mut A::Domain| {
            trans_for_block[bb].apply(state);
        });
        Self::new(tcx, body, analysis, Some(apply_trans as Box<_>))
    }
}
impl<'mir, 'tcx, A, D> Engine<'mir, 'tcx, A>
where
    A: Analysis<'tcx, Domain = D>,
    D: Clone + JoinSemiLattice,
{
    pub fn new_generic(tcx: TyCtxt<'tcx>, body: &'mir mir::Body<'tcx>, analysis: A) -> Self {
        Self::new(tcx, body, analysis, None)
    }
    fn new(
        tcx: TyCtxt<'tcx>,
        body: &'mir mir::Body<'tcx>,
        analysis: A,
        apply_statement_trans_for_block: Option<Box<dyn Fn(BasicBlock, &mut A::Domain)>>,
    ) -> Self {
        let mut entry_sets =
            IndexVec::from_fn_n(|_| analysis.bottom_value(body), body.basic_blocks.len());
        analysis.initialize_start_block(body, &mut entry_sets[mir::START_BLOCK]);
        if A::Direction::IS_BACKWARD && entry_sets[mir::START_BLOCK] != analysis.bottom_value(body)
        {
            bug!("`initialize_start_block` is not yet supported for backward dataflow analyses");
        }
        Engine { analysis, tcx, body, pass_name: None, entry_sets, apply_statement_trans_for_block }
    }
    pub fn pass_name(mut self, name: &'static str) -> Self {
        self.pass_name = Some(name);
        self
    }
    pub fn iterate_to_fixpoint(self) -> Results<'tcx, A>
    where
        A::Domain: DebugWithContext<A>,
    {
        let Engine {
            mut analysis,
            body,
            mut entry_sets,
            tcx,
            apply_statement_trans_for_block,
            pass_name,
        } = self;
        let mut dirty_queue: WorkQueue<BasicBlock> = WorkQueue::with_none(body.basic_blocks.len());
        if A::Direction::IS_FORWARD {
            for (bb, _) in traversal::reverse_postorder(body) {
                dirty_queue.insert(bb);
            }
        } else {
            for (bb, _) in traversal::postorder(body) {
                dirty_queue.insert(bb);
            }
        }
        let mut state = analysis.bottom_value(body);
        while let Some(bb) = dirty_queue.pop() {
            let bb_data = &body[bb];
            state.clone_from(&entry_sets[bb]);
            let edges = A::Direction::apply_effects_in_block(
                &mut analysis,
                &mut state,
                bb,
                bb_data,
                apply_statement_trans_for_block.as_deref(),
            );
            A::Direction::join_state_into_successors_of(
                &mut analysis,
                body,
                &mut state,
                bb,
                edges,
                |target: BasicBlock, state: &A::Domain| {
                    let set_changed = entry_sets[target].join(state);
                    if set_changed {
                        dirty_queue.insert(target);
                    }
                },
            );
        }
        let results = Results { analysis, entry_sets };
        if tcx.sess.opts.unstable_opts.dump_mir_dataflow {
            let (res, results) = write_graphviz_results(tcx, body, results, pass_name);
            if let Err(e) = res {
                error!("Failed to write graphviz dataflow results: {}", e);
            }
            results
        } else {
            results
        }
    }
}
fn write_graphviz_results<'tcx, A>(
    tcx: TyCtxt<'tcx>,
    body: &mir::Body<'tcx>,
    results: Results<'tcx, A>,
    pass_name: Option<&'static str>,
) -> (std::io::Result<()>, Results<'tcx, A>)
where
    A: Analysis<'tcx>,
    A::Domain: DebugWithContext<A>,
{
    use std::fs;
    use std::io::{self, Write};
    let def_id = body.source.def_id();
    let Ok(attrs) = RustcMirAttrs::parse(tcx, def_id) else {
        return (Ok(()), results);
    };
    let file = try {
        match attrs.output_path(A::NAME) {
            Some(path) => {
                debug!("printing dataflow results for {:?} to {}", def_id, path.display());
                if let Some(parent) = path.parent() {
                    fs::create_dir_all(parent)?;
                }
                let f = fs::File::create(&path)?;
                io::BufWriter::new(f)
            }
            None if dump_enabled(tcx, A::NAME, def_id) => {
                create_dump_file(tcx, ".dot", false, A::NAME, &pass_name.unwrap_or("-----"), body)?
            }
            _ => return (Ok(()), results),
        }
    };
    let mut file = match file {
        Ok(f) => f,
        Err(e) => return (Err(e), results),
    };
    let style = match attrs.formatter {
        Some(sym::two_phase) => graphviz::OutputStyle::BeforeAndAfter,
        _ => graphviz::OutputStyle::AfterOnly,
    };
    let mut buf = Vec::new();
    let graphviz = graphviz::Formatter::new(body, results, style);
    let mut render_opts =
        vec![dot::RenderOption::Fontname(tcx.sess.opts.unstable_opts.graphviz_font.clone())];
    if tcx.sess.opts.unstable_opts.graphviz_dark_mode {
        render_opts.push(dot::RenderOption::DarkTheme);
    }
    let r = with_no_trimmed_paths!(dot::render_opts(&graphviz, &mut buf, &render_opts));
    let lhs = try {
        r?;
        file.write_all(&buf)?;
    };
    (lhs, graphviz.into_results())
}
#[derive(Default)]
struct RustcMirAttrs {
    basename_and_suffix: Option<PathBuf>,
    formatter: Option<Symbol>,
}
impl RustcMirAttrs {
    fn parse(tcx: TyCtxt<'_>, def_id: DefId) -> Result<Self, ()> {
        let mut result = Ok(());
        let mut ret = RustcMirAttrs::default();
        let rustc_mir_attrs = tcx
            .get_attrs(def_id, sym::rustc_mir)
            .flat_map(|attr| attr.meta_item_list().into_iter().flat_map(|v| v.into_iter()));
        for attr in rustc_mir_attrs {
            let attr_result = if attr.has_name(sym::borrowck_graphviz_postflow) {
                Self::set_field(&mut ret.basename_and_suffix, tcx, &attr, |s| {
                    let path = PathBuf::from(s.to_string());
                    match path.file_name() {
                        Some(_) => Ok(path),
                        None => {
                            tcx.dcx().emit_err(PathMustEndInFilename { span: attr.span() });
                            Err(())
                        }
                    }
                })
            } else if attr.has_name(sym::borrowck_graphviz_format) {
                Self::set_field(&mut ret.formatter, tcx, &attr, |s| match s {
                    sym::gen_kill | sym::two_phase => Ok(s),
                    _ => {
                        tcx.dcx().emit_err(UnknownFormatter { span: attr.span() });
                        Err(())
                    }
                })
            } else {
                Ok(())
            };
            result = result.and(attr_result);
        }
        result.map(|()| ret)
    }
    fn set_field<T>(
        field: &mut Option<T>,
        tcx: TyCtxt<'_>,
        attr: &ast::NestedMetaItem,
        mapper: impl FnOnce(Symbol) -> Result<T, ()>,
    ) -> Result<(), ()> {
        if field.is_some() {
            tcx.dcx()
                .emit_err(DuplicateValuesFor { span: attr.span(), name: attr.name_or_empty() });
            return Err(());
        }
        if let Some(s) = attr.value_str() {
            *field = Some(mapper(s)?);
            Ok(())
        } else {
            tcx.dcx()
                .emit_err(RequiresAnArgument { span: attr.span(), name: attr.name_or_empty() });
            Err(())
        }
    }
    fn output_path(&self, analysis_name: &str) -> Option<PathBuf> {
        let mut ret = self.basename_and_suffix.as_ref().cloned()?;
        let suffix = ret.file_name().unwrap(); let mut file_name: OsString = analysis_name.into();
        file_name.push("_");
        file_name.push(suffix);
        ret.set_file_name(file_name);
        Some(ret)
    }
}