use crate::traits::query::evaluate_obligation::InferCtxtExt as _;
use crate::traits::{self, DefiningAnchor, ObligationCtxt, SelectionContext};
use crate::traits::TraitEngineExt as _;
use rustc_hir::def_id::DefId;
use rustc_hir::lang_items::LangItem;
use rustc_infer::traits::{Obligation, TraitEngine, TraitEngineExt as _};
use rustc_middle::arena::ArenaAllocatable;
use rustc_middle::infer::canonical::{Canonical, CanonicalQueryResponse, QueryResponse};
use rustc_middle::traits::query::NoSolution;
use rustc_middle::traits::ObligationCause;
use rustc_middle::ty::{self, Ty, TyCtxt, TypeFoldable, TypeVisitableExt};
use rustc_middle::ty::{GenericArg, ToPredicate};
use rustc_span::DUMMY_SP;
use std::fmt::Debug;
pub use rustc_infer::infer::*;
#[extension(pub trait InferCtxtExt<'tcx>)]
impl<'tcx> InferCtxt<'tcx> {
    fn type_is_copy_modulo_regions(&self, param_env: ty::ParamEnv<'tcx>, ty: Ty<'tcx>) -> bool {
        let ty = self.resolve_vars_if_possible(ty);
        if !(param_env, ty).has_infer() {
            return ty.is_copy_modulo_regions(self.tcx, param_env);
        }
        let copy_def_id = self.tcx.require_lang_item(LangItem::Copy, None);
        traits::type_known_to_meet_bound_modulo_regions(self, param_env, ty, copy_def_id)
    }
    fn type_is_sized_modulo_regions(&self, param_env: ty::ParamEnv<'tcx>, ty: Ty<'tcx>) -> bool {
        let lang_item = self.tcx.require_lang_item(LangItem::Sized, None);
        traits::type_known_to_meet_bound_modulo_regions(self, param_env, ty, lang_item)
    }
    #[instrument(level = "debug", skip(self, params), ret)]
    fn type_implements_trait(
        &self,
        trait_def_id: DefId,
        params: impl IntoIterator<Item: Into<GenericArg<'tcx>>>,
        param_env: ty::ParamEnv<'tcx>,
    ) -> traits::EvaluationResult {
        let trait_ref = ty::TraitRef::new(self.tcx, trait_def_id, params);
        let obligation = traits::Obligation {
            cause: traits::ObligationCause::dummy(),
            param_env,
            recursion_depth: 0,
            predicate: ty::Binder::dummy(trait_ref).to_predicate(self.tcx),
        };
        self.evaluate_obligation(&obligation).unwrap_or(traits::EvaluationResult::EvaluatedToErr)
    }
    fn type_implements_trait_shallow(
        &self,
        trait_def_id: DefId,
        ty: Ty<'tcx>,
        param_env: ty::ParamEnv<'tcx>,
    ) -> Option<Vec<traits::FulfillmentError<'tcx>>> {
        self.probe(|_snapshot| {
            let mut selcx = SelectionContext::new(self);
            match selcx.select(&Obligation::new(
                self.tcx,
                ObligationCause::dummy(),
                param_env,
                ty::TraitRef::new(self.tcx, trait_def_id, [ty]),
            )) {
                Ok(Some(selection)) => {
                    let mut fulfill_cx = <dyn TraitEngine<'tcx>>::new(self);
                    fulfill_cx.register_predicate_obligations(self, selection.nested_obligations());
                    Some(fulfill_cx.select_all_or_error(self))
                }
                Ok(None) | Err(_) => None,
            }
        })
    }
}
#[extension(pub trait InferCtxtBuilderExt<'tcx>)]
impl<'tcx> InferCtxtBuilder<'tcx> {
    fn enter_canonical_trait_query<K, R>(
        self,
        canonical_key: &Canonical<'tcx, K>,
        operation: impl FnOnce(&ObligationCtxt<'_, 'tcx>, K) -> Result<R, NoSolution>,
    ) -> Result<CanonicalQueryResponse<'tcx, R>, NoSolution>
    where
        K: TypeFoldable<TyCtxt<'tcx>>,
        R: Debug + TypeFoldable<TyCtxt<'tcx>>,
        Canonical<'tcx, QueryResponse<'tcx, R>>: ArenaAllocatable<'tcx>,
    {
        let (infcx, key, canonical_inference_vars) = self
            .with_opaque_type_inference(DefiningAnchor::Bubble)
            .build_with_canonical(DUMMY_SP, canonical_key);
        let ocx = ObligationCtxt::new(&infcx);
        let value = operation(&ocx, key)?;
        ocx.make_canonicalized_query_response(canonical_inference_vars, value)
    }
}