use crate::ty::GenericArgsRef;
use crate::ty::{AssocItemContainer, Instance, ParamEnv, Ty, TyCtxt};
use rustc_hir::def_id::DefId;
use rustc_hir::{lang_items, LangItem};
use rustc_span::symbol::Ident;
use rustc_span::{sym, DesugaringKind, Span};
use tracing::debug;
#[derive(Clone, Copy, PartialEq, Eq, Debug)]
pub enum CallDesugaringKind {
    ForLoopIntoIter,
    QuestionBranch,
    QuestionFromResidual,
    TryBlockFromOutput,
    Await,
}
impl CallDesugaringKind {
    pub fn trait_def_id(self, tcx: TyCtxt<'_>) -> DefId {
        match self {
            Self::ForLoopIntoIter => tcx.get_diagnostic_item(sym::IntoIterator).unwrap(),
            Self::QuestionBranch | Self::TryBlockFromOutput => {
                tcx.require_lang_item(LangItem::Try, None)
            }
            Self::QuestionFromResidual => tcx.get_diagnostic_item(sym::FromResidual).unwrap(),
            Self::Await => tcx.get_diagnostic_item(sym::IntoFuture).unwrap(),
        }
    }
}
#[derive(Clone, Copy, PartialEq, Eq, Debug)]
pub enum CallKind<'tcx> {
    Normal {
        self_arg: Option<Ident>,
        desugaring: Option<(CallDesugaringKind, Ty<'tcx>)>,
        method_did: DefId,
        method_args: GenericArgsRef<'tcx>,
    },
    FnCall { fn_trait_id: DefId, self_ty: Ty<'tcx> },
    Operator { self_arg: Option<Ident>, trait_id: DefId, self_ty: Ty<'tcx> },
    DerefCoercion {
        deref_target: Span,
        deref_target_ty: Ty<'tcx>,
        self_ty: Ty<'tcx>,
    },
}
pub fn call_kind<'tcx>(
    tcx: TyCtxt<'tcx>,
    param_env: ParamEnv<'tcx>,
    method_did: DefId,
    method_args: GenericArgsRef<'tcx>,
    fn_call_span: Span,
    from_hir_call: bool,
    self_arg: Option<Ident>,
) -> CallKind<'tcx> {
    let parent = tcx.opt_associated_item(method_did).and_then(|assoc| {
        let container_id = assoc.container_id(tcx);
        match assoc.container {
            AssocItemContainer::ImplContainer => tcx.trait_id_of_impl(container_id),
            AssocItemContainer::TraitContainer => Some(container_id),
        }
    });
    let fn_call = parent.and_then(|p| {
        lang_items::FN_TRAITS.iter().filter_map(|&l| tcx.lang_items().get(l)).find(|&id| id == p)
    });
    let operator = if !from_hir_call && let Some(p) = parent {
        lang_items::OPERATORS.iter().filter_map(|&l| tcx.lang_items().get(l)).find(|&id| id == p)
    } else {
        None
    };
    let is_deref = !from_hir_call && tcx.is_diagnostic_item(sym::deref_method, method_did);
    let kind = if let Some(trait_id) = fn_call {
        Some(CallKind::FnCall { fn_trait_id: trait_id, self_ty: method_args.type_at(0) })
    } else if let Some(trait_id) = operator {
        Some(CallKind::Operator { self_arg, trait_id, self_ty: method_args.type_at(0) })
    } else if is_deref {
        let deref_target = tcx.get_diagnostic_item(sym::deref_target).and_then(|deref_target| {
            Instance::try_resolve(tcx, param_env, deref_target, method_args).transpose()
        });
        if let Some(Ok(instance)) = deref_target {
            let deref_target_ty = instance.ty(tcx, param_env);
            Some(CallKind::DerefCoercion {
                deref_target: tcx.def_span(instance.def_id()),
                deref_target_ty,
                self_ty: method_args.type_at(0),
            })
        } else {
            None
        }
    } else {
        None
    };
    kind.unwrap_or_else(|| {
        debug!(?method_did, ?fn_call_span);
        let desugaring = if tcx.is_lang_item(method_did, LangItem::IntoIterIntoIter)
            && fn_call_span.desugaring_kind() == Some(DesugaringKind::ForLoop)
        {
            Some((CallDesugaringKind::ForLoopIntoIter, method_args.type_at(0)))
        } else if fn_call_span.desugaring_kind() == Some(DesugaringKind::QuestionMark) {
            if tcx.is_lang_item(method_did, LangItem::TryTraitBranch) {
                Some((CallDesugaringKind::QuestionBranch, method_args.type_at(0)))
            } else if tcx.is_lang_item(method_did, LangItem::TryTraitFromResidual) {
                Some((CallDesugaringKind::QuestionFromResidual, method_args.type_at(0)))
            } else {
                None
            }
        } else if tcx.is_lang_item(method_did, LangItem::TryTraitFromOutput)
            && fn_call_span.desugaring_kind() == Some(DesugaringKind::TryBlock)
        {
            Some((CallDesugaringKind::TryBlockFromOutput, method_args.type_at(0)))
        } else if fn_call_span.is_desugaring(DesugaringKind::Await) {
            Some((CallDesugaringKind::Await, method_args.type_at(0)))
        } else {
            None
        };
        CallKind::Normal { self_arg, desugaring, method_did, method_args }
    })
}