(* Theory IsaSAT_Conflict_Analysis_LLVM *)

theory IsaSAT_Conflict_Analysis_LLVM
imports IsaSAT_Conflict_Analysis_Defs IsaSAT_VMTF_LLVM IsaSAT_Setup_LLVM IsaSAT_LBD_LLVM
begin

(* Synthesises the constant maximum_level_removed_eq_count_dec_fast_code from
   the (uncurried) two-argument function maximum_level_removed_eq_count_dec_heur.
   The arguments are refined through unat_lit_assn and isasat_bounded_assn and
   the result through bool1_assn.
   NOTE(review): in this copy the sub/superscript markup of the signature is
   flattened; the trailing `k` on the argument assertions and the `a` markers
   are presumably the usual Sepref annotations -- confirm against the original
   theory source. *)
sepref_def maximum_level_removed_eq_count_dec_fast_code
  is uncurry (maximum_level_removed_eq_count_dec_heur)
  :: unat_lit_assnk *a isasat_bounded_assnk a bool1_assn
  (* Expose the body of the heuristic function to the synthesis. *)
  unfolding maximum_level_removed_eq_count_dec_heur_def
  (* Annotate numeric constants with the 32-bit type before running sepref. *)
  apply (annot_unat_const TYPE(32))
  by sepref

(* is_decided_trail takes a 5-tuple (M, xs, lvls, reasons, k), looks up the
   entry of `reasons` at index `atm_of (last M)` and returns (in the nres monad)
   whether that entry equals DECISION_REASON. Only M and reasons are used; the
   other components are ignored. The definition itself does not guard against an
   empty M or an out-of-range index; see the precondition given where it is
   refined. *)
definition is_decided_trail where is_decided_trail = (λ(M, xs, lvls, reasons, k).
      let r = reasons ! (atm_of (last M)) in
      RETURN (r = DECISION_REASON))

(* Synthesises is_decided_trail_impl from is_decided_trail. The argument is
   refined through trail_pol_fast_assn and the result through bool1_assn.
   The precondition mentions `fst S`, `[]` and `last_trail_pol_pre S`.
   NOTE(review): the operators between these are missing in this copy;
   presumably it reads "fst S is not [] and last_trail_pol_pre S" -- confirm
   against the original theory source. *)
sepref_def is_decided_trail_impl
  is is_decided_trail
  :: [(λS. fst S  []  last_trail_pol_pre S)]a trail_pol_fast_assnk  bool1_assn
  (* Unfold the function, the trail assertion and the precondition so that
     sepref works on the individual tuple components. *)
  unfolding is_decided_trail_def trail_pol_fast_assn_def last_trail_pol_pre_def
  by sepref

(* State-level code: is_decided_trail_impl wrapped by read_trail_wl_heur_code.
   The argument type is twl_st_wll_trail_fast2; the result type is left to
   type inference. *)
definition is_decided_hd_trail_wl_fast_code :: twl_st_wll_trail_fast2  _ where
  is_decided_hd_trail_wl_fast_code = read_trail_wl_heur_code is_decided_trail_impl
(* Interprets the locale read_trail_param_adder0 for the pair
   is_decided_trail_impl / is_decided_trail, with result assertion bool1_assn
   and the same precondition that is_decided_trail_impl was synthesised under.
   The three `rewrites` equations replace the generic wrapped forms by the
   named constants:
     1. read_trail_wl_heur is_decided_trail      by  RETURN o is_decided_hd_trail_wl_heur
     2. read_trail_wl_heur_code is_decided_trail_impl  by  is_decided_hd_trail_wl_fast_code
     3. the case_isasat_int form of the precondition   by  is_decided_hd_trail_wl_heur_pre
   NOTE(review): operators inside the precondition are missing in this copy
   (presumably "not equal" and "and") -- confirm against the original. *)
global_interpretation is_decided_hd: read_trail_param_adder0 where
  f = is_decided_trail_impl and
  f' = is_decided_trail and
  x_assn = bool1_assn and
  P = (λS. fst S  []  last_trail_pol_pre S)
  rewrites read_trail_wl_heur is_decided_trail = RETURN o is_decided_hd_trail_wl_heur and
  read_trail_wl_heur_code is_decided_trail_impl = is_decided_hd_trail_wl_fast_code and
  case_isasat_int (λM _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _. fst M  []  last_trail_pol_pre M) = is_decided_hd_trail_wl_heur_pre
  apply unfold_locales
  (* Locale assumption: discharged by the refinement theorem from sepref_def. *)
  apply (rule is_decided_trail_impl.refine)
  (* Rewrite 1: abstract function agrees with is_decided_hd_trail_wl_heur. *)
  subgoal
    by (auto simp: read_all_st_def is_decided_hd_trail_wl_heur_def is_decided_trail_def last_trail_pol_def Let_def
      intro!: ext
      split: isasat_int_splits)
  (* Rewrite 2: holds by the definition of the code constant. *)
  subgoal
    by (auto simp: is_decided_hd_trail_wl_fast_code_def)
  (* Rewrite 3: precondition agrees with is_decided_hd_trail_wl_heur_pre. *)
  subgoal by (auto simp: is_decided_hd_trail_wl_heur_pre_def intro!: ext split: isasat_int_splits)
  done


(* lit_and_ann_of_propagated_trail_heur takes a 5-tuple of which only the first
   component M and the fourth component `reasons` are used. After an assertion
   it returns the pair (last M, reasons ! atm_of (last M)), typed as
   nat literal * nat in the nres monad.
   NOTE(review): the operators inside the ASSERT are missing in this copy;
   presumably it states that M is not [] and that atm_of (last M) is below
   length reasons -- confirm against the original theory source. *)
definition lit_and_ann_of_propagated_trail_heur
   :: _  (nat literal × nat) nres
where
  lit_and_ann_of_propagated_trail_heur = (λ(M, _, _, reasons, _) . do {
     ASSERT(M  []  atm_of (last M) < length reasons);
     RETURN (last M, reasons ! (atm_of (last M)))})

(* Synthesises lit_and_ann_of_propagated_trail_heur_impl from
   lit_and_ann_of_propagated_trail_heur. The argument is refined through
   trail_pol_fast_assn; the resulting pair through unat_lit_assn (first
   component) and sint64_nat_assn (second component). No extra precondition is
   stated in the signature; the ASSERT inside the definition is what sepref has
   available. *)
sepref_def lit_and_ann_of_propagated_trail_heur_impl
  is lit_and_ann_of_propagated_trail_heur
  :: trail_pol_fast_assnk a (unat_lit_assn ×a sint64_nat_assn)
  (* Unfold the function and the trail assertion down to tuple components. *)
  unfolding lit_and_ann_of_propagated_trail_heur_def trail_pol_fast_assn_def
  by sepref

(* State-level code: lit_and_ann_of_propagated_trail_heur_impl wrapped by
   read_trail_wl_heur_code. The argument type is twl_st_wll_trail_fast2; the
   result type is left to type inference. *)
definition lit_and_ann_of_propagated_st_heur_fast_code :: twl_st_wll_trail_fast2  _ where
  lit_and_ann_of_propagated_st_heur_fast_code = read_trail_wl_heur_code lit_and_ann_of_propagated_trail_heur_impl

(* Interprets the locale read_trail_param_adder0 for the pair
   lit_and_ann_of_propagated_trail_heur_impl / lit_and_ann_of_propagated_trail_heur
   with a trivially true precondition. The two `rewrites` equations replace
     1. read_trail_wl_heur lit_and_ann_of_propagated_trail_heur
        by lit_and_ann_of_propagated_st_heur, and
     2. read_trail_wl_heur_code lit_and_ann_of_propagated_trail_heur_impl
        by lit_and_ann_of_propagated_st_heur_fast_code. *)
global_interpretation lit_and_of_proped_lit: read_trail_param_adder0 where
  f = lit_and_ann_of_propagated_trail_heur_impl and
  f' = lit_and_ann_of_propagated_trail_heur and
  x_assn = unat_lit_assn ×a sint64_nat_assn and
  P = (λS. True)
  rewrites read_trail_wl_heur lit_and_ann_of_propagated_trail_heur = lit_and_ann_of_propagated_st_heur and
  read_trail_wl_heur_code lit_and_ann_of_propagated_trail_heur_impl = lit_and_ann_of_propagated_st_heur_fast_code
  apply unfold_locales
  (* Locale assumption: discharged by the refinement theorem from sepref_def. *)
  apply (rule lit_and_ann_of_propagated_trail_heur_impl.refine)
  (* Rewrite 1: abstract function agrees with lit_and_ann_of_propagated_st_heur. *)
  subgoal
    by (auto simp: read_all_st_def lit_and_ann_of_propagated_st_heur_def lit_and_ann_of_propagated_trail_heur_def last_trail_pol_def Let_def
      intro!: ext
      split: isasat_int_splits)
  (* Rewrite 2: holds by the definition of the code constant. *)
  subgoal
    by (auto simp: lit_and_ann_of_propagated_st_heur_fast_code_def)
  done


(* atm_is_in_conflict_confl_heur takes a pair (_, D) and a literal L. It
   asserts atm_in_conflict_lookup_pre (atm_of L) D and then returns the
   NEGATION of atm_in_conflict_lookup (atm_of L) D. The first component of the
   pair is ignored. *)
definition atm_is_in_conflict_confl_heur :: _  nat literal bool nres where
  atm_is_in_conflict_confl_heur = (λ(_, D) L. do {
     ASSERT (atm_in_conflict_lookup_pre (atm_of L) D); RETURN (¬atm_in_conflict_lookup (atm_of L) D) })

(* Synthesises atm_is_in_conflict_confl_heur_impl from the uncurried
   atm_is_in_conflict_confl_heur. Arguments are refined through
   conflict_option_rel_assn and unat_lit_assn (in that order), the result
   through bool1_assn. *)
sepref_def atm_is_in_conflict_confl_heur_impl
  is uncurry atm_is_in_conflict_confl_heur
  :: conflict_option_rel_assnk *a unat_lit_assnk  a bool1_assn
  (* Unfold the function and the conflict assertion down to its components. *)
  unfolding atm_is_in_conflict_confl_heur_def conflict_option_rel_assn_def
  by sepref

(* State-level code taking a state N and an extra argument C: it runs
   atm_is_in_conflict_confl_heur_impl on the component M handed over by
   read_conflict_wl_heur_code, with C as the second argument. The right-hand
   side is kept in exactly this lambda form because the interpretation
   atm_in_conflict rewrites the same expression to this constant. *)
definition atm_is_in_conflict_st_heur_fast_code :: twl_st_wll_trail_fast2  _ where
  atm_is_in_conflict_st_heur_fast_code = (λN C. read_conflict_wl_heur_code (λM. atm_is_in_conflict_confl_heur_impl M C) N)


(* atm_is_in_conflict_st_heur' S L applies the same body as
   atm_is_in_conflict_confl_heur (assert atm_in_conflict_lookup_pre, then
   return the negated atm_in_conflict_lookup on atm_of L) to
   get_conflict_wl_heur S. It is the state-level abstract function used as a
   rewrite target in the atm_in_conflict interpretation. *)
definition atm_is_in_conflict_st_heur' :: isasat  nat literal  bool nres where
  atm_is_in_conflict_st_heur' S L = (λ(_, D). do {
     ASSERT (atm_in_conflict_lookup_pre (atm_of L) D); RETURN (¬atm_in_conflict_lookup (atm_of L) D) }) (get_conflict_wl_heur S)

(* Interprets the locale read_conflict_param_adder for
   atm_is_in_conflict_confl_heur_impl / atm_is_in_conflict_confl_heur with a
   trivially true precondition, result assertion bool1_assn and relation
   unat_lit_rel for the additional argument.
   Both f and f' are given with their two arguments swapped (λS L. ... L S).
   The two `rewrites` equations replace the generic wrapped forms by
   atm_is_in_conflict_st_heur' and atm_is_in_conflict_st_heur_fast_code. *)
global_interpretation atm_in_conflict: read_conflict_param_adder where
  f =  λS L. atm_is_in_conflict_confl_heur_impl L S and
  f' = λS L. atm_is_in_conflict_confl_heur L S and
  x_assn = bool1_assn and
  P = (λ_ _. True) and
  R = unat_lit_rel
  rewrites (λN C. read_conflict_wl_heur (λM. atm_is_in_conflict_confl_heur M C) N) = atm_is_in_conflict_st_heur' and
  (λN C. read_conflict_wl_heur_code (λM. atm_is_in_conflict_confl_heur_impl M C) N) = atm_is_in_conflict_st_heur_fast_code
  apply unfold_locales
  (* Locale assumption: rewrite the goal with lambda_comp_true first, then
     apply the refinement theorem from sepref_def. *)
  apply (subst lambda_comp_true,
     rule atm_is_in_conflict_confl_heur_impl.refine)
  (* Rewrite 1: abstract function agrees with atm_is_in_conflict_st_heur'. *)
  subgoal
    by (auto simp: read_all_st_def atm_is_in_conflict_st_heur'_def atm_is_in_conflict_confl_heur_def Let_def
      intro!: ext
      split: isasat_int_splits)
  (* Rewrite 2: holds by the definition of the code constant. *)
  subgoal
    by (auto simp: atm_is_in_conflict_st_heur_fast_code_def)
  done

(* Registers the refinement theorems of the three interpretations above
   (is_decided_hd, lit_and_of_proped_lit, atm_in_conflict) as sepref_fr_rules,
   after unfolding lambda_comp_true in each of them. *)
lemmas [unfolded lambda_comp_true, sepref_fr_rules] = is_decided_hd.refine lit_and_of_proped_lit.refine atm_in_conflict.refine
(* Registers the three state-level code definitions as llvm_code equations.
   Each definition first has read_all_st_code_def unfolded, then
   inline_direct_return_node_case is unfolded in the result. *)
lemmas [unfolded inline_direct_return_node_case, llvm_code] =
  is_decided_hd_trail_wl_fast_code_def[unfolded read_all_st_code_def]
  lit_and_ann_of_propagated_st_heur_fast_code_def[unfolded read_all_st_code_def]
  atm_is_in_conflict_st_heur_fast_code_def[unfolded read_all_st_code_def]

(* Synthesises atm_is_in_conflict_st_heur_fast2_code from the uncurried
   atm_is_in_conflict_st_heur, with a trivially true precondition. Arguments
   are refined through unat_lit_assn and isasat_bounded_assn (in that order),
   the result through bool1_assn. *)
sepref_def atm_is_in_conflict_st_heur_fast2_code
  is uncurry (atm_is_in_conflict_st_heur)
  :: [λ_. True]a unat_lit_assnk *a isasat_bounded_assnk  bool1_assn
  supply [[goals_limit=1]]
  (* Unfold the function, then fold its body back into
     atm_is_in_conflict_st_heur' (the [symmetric] definition); presumably so
     that the rule registered from the atm_in_conflict interpretation applies
     -- TODO confirm. *)
  unfolding atm_is_in_conflict_st_heur_def atm_is_in_conflict_st_heur'_def[symmetric]
  by sepref

(* Two facts, both starting from tl_state_wl_heur_pre S: one about
   fst (get_trail_wl_heur S) and [], one giving
   tl_trailt_tr_pre (get_trail_wl_heur S). The lemma is supplied as a [dest]
   rule when tl_state_wl_heur_fast_code is synthesised.
   NOTE(review): the implication arrows and the comparison with [] are missing
   in this copy; presumably the first conclusion is that the trail list is not
   [] -- confirm against the original theory source. *)
lemma tl_state_wl_heurI: tl_state_wl_heur_pre S  fst (get_trail_wl_heur S)  []
  tl_state_wl_heur_pre S  tl_trailt_tr_pre (get_trail_wl_heur S)
  (* Proved by unfolding the involved preconditions. *)
  by (auto simp: tl_state_wl_heur_pre_def tl_trailt_tr_pre_def Let_def isa_bump_unset_pre_def
    vmtf_unset_pre_def lit_of_last_trail_pol_def)

(* Alternative characterisation of tl_state_wl_heur written with the
   extract_* / update_* state accessors. In order, the right-hand side:
     - asserts tl_state_wl_heur_pre on the input state S0;
     - extracts the trail M and the VMTF component vm from the state;
     - asserts that they coincide with get_trail_wl_heur S0 / get_vmtf_heur S0;
     - reads L = lit_of_last_trail_pol M;
     - stores tl_trailt_tr M as the new trail;
     - asserts isa_bump_unset_pre (atm_of L) vm, runs isa_bump_unset on it and
       stores the result;
     - returns (False, S).
   This form is what tl_state_wl_heur_fast_code is synthesised from.
   NOTE(review): the monadic bind arrow after `vm` is missing in this copy. *)
lemma tl_state_wl_heur_alt_def:
  tl_state_wl_heur = (λS0. do {
       ASSERT(tl_state_wl_heur_pre S0);
       let (M, S) = extract_trail_wl_heur S0; let (vm, S) = extract_vmtf_wl_heur S;
       ASSERT (M = get_trail_wl_heur S0);
       ASSERT (vm = get_vmtf_heur S0);
       let L = lit_of_last_trail_pol M;
       let S = update_trail_wl_heur (tl_trailt_tr M) S;
       ASSERT (isa_bump_unset_pre (atm_of L) vm);
       vm  isa_bump_unset (atm_of L) vm;
       let S = update_vmtf_wl_heur vm S;
       RETURN (False, S)
  })
  (* Proved by unfolding the original definition and the state extractors,
     with a case split on the state record. *)
  by (auto simp: tl_state_wl_heur_def state_extractors Let_def intro!: ext split: isasat_int_splits)

(* Make isa_bump_unset known to Sepref as an operation before it is used in the
   synthesis below. *)
sepref_register isa_bump_unset
(* Synthesises tl_state_wl_heur_fast_code from tl_state_wl_heur with a
   trivially true precondition. The state argument is refined through
   isasat_bounded_assn; the result is a pair refined through bool1_assn and
   isasat_bounded_assn.
   NOTE(review): the trailing `d` on the argument assertion is presumably the
   Sepref "destroyed" annotation whose superscript markup was lost in this
   copy -- confirm against the original theory source. *)
sepref_def tl_state_wl_heur_fast_code
  is tl_state_wl_heur
  :: [λ_. True]a isasat_bounded_assnd  bool1_assn ×a isasat_bounded_assn
  (* tl_state_wl_heurI is supplied as a destruction rule so facts about the
     trail can be derived from tl_state_wl_heur_pre during synthesis. *)
  supply [[goals_limit=1]] if_splits[split] tl_state_wl_heurI[dest]
  (* Synthesis runs on the accessor-based alternative definition. *)
  unfolding vmtf_unset_def bind_ref_tag_def short_circuit_conv tl_state_wl_heur_alt_def
  by sepref


(* Returns the first (boolean) component b of a conflict_option_rel value of
   the shape (b, (_, xs)); the other components are ignored. *)
definition extract_values_of_lookup_conflict :: conflict_option_rel  bool where
extract_values_of_lookup_conflict = (λ(b, (_, xs)). b)


(* Synthesises extract_values_of_lookup_conflict_impl from the pure function
   extract_values_of_lookup_conflict (lifted with RETURN o). The argument is
   refined through conflict_option_rel_assn and the result through
   bool1_assn. *)
sepref_def extract_values_of_lookup_conflict_impl
  is RETURN o extract_values_of_lookup_conflict
  :: conflict_option_rel_assnk a bool1_assn
  (* Unfold the function and both nested assertions so the tuple pattern of
     the definition can be matched against the concrete components. *)
  unfolding extract_values_of_lookup_conflict_def conflict_option_rel_assn_def
    lookup_clause_rel_assn_def
  by sepref

(* Register the operation and its refinement theorem so later syntheses can
   use extract_values_of_lookup_conflict. *)
sepref_register extract_values_of_lookup_conflict
declare extract_values_of_lookup_conflict_impl.refine[sepref_fr_rules]

(* Make the two operations known to Sepref before they appear in syntheses. *)
sepref_register isasat_lookup_merge_eq2 update_confl_tl_wl_heur

(* Alternative characterisation of update_confl_tl_wl_heur written with the
   extract_* / update_* state accessors. For a literal L, a clause index C and
   a state S0, the right-hand side, in order:
     - extracts trail M, arena N, lbd, outl, clvls, vm and the conflict bnxs;
     - recomputes (N, lbd) with calculate_LBD_heur_st;
     - asserts a condition on clvls and 1, and arena_is_valid_clause_idx N C;
     - merges clause C into the conflict, using isasat_lookup_merge_eq2 when
       arena_length N C = 2 and isa_resolve_merge_conflict_gt2 otherwise;
     - splits the conflict into its flag b (extract_values_of_lookup_conflict)
       and its lookup part nxs (the_lookup_conflict), and removes L from nxs
       with lookup_conflict_remove1;
     - updates vm with isa_vmtf_bump_to_rescore_also_reasons_cl (called with
       -L) and then isa_bump_unset on L' = atm_of L;
     - writes back the tail of the trail, the rebuilt conflict
       (None_lookup_conflict b nxs), vm, clvls - 1, outl, N and lbd;
     - returns (False, S).
   This form is what update_confl_tl_wl_fast_code is synthesised from.
   NOTE(review): the monadic bind arrows, the comparison operators in the
   ASSERTs on clvls (presumably "at least 1") and the conjunction in the
   lookup_conflict_remove1_pre ASSERT are missing in this copy -- confirm
   against the original theory source. *)
lemma update_confl_tl_wl_heur_alt_def:
  update_confl_tl_wl_heur = (λL C S0. do {
      let (M, S) = extract_trail_wl_heur S0;
      let (N, S) = extract_arena_wl_heur S;
      let (lbd, S) = extract_lbd_wl_heur S;
      let (outl, S) = extract_outl_wl_heur S;
      let (clvls, S) = extract_clvls_wl_heur S;
      let (vm, S) = extract_vmtf_wl_heur S;
      let (bnxs, S) = extract_conflict_wl_heur S;
      (N, lbd)  calculate_LBD_heur_st M N lbd C;
      ASSERT (clvls  1);
      let L' = atm_of L;
      ASSERT(arena_is_valid_clause_idx N C);
      (bnxs, clvls, outl) 
        if arena_length N C = 2 then isasat_lookup_merge_eq2 L M N C bnxs clvls outl
        else isa_resolve_merge_conflict_gt2 M N C bnxs clvls outl;
      let b = extract_values_of_lookup_conflict bnxs;
      let nxs = the_lookup_conflict bnxs;
      ASSERT(curry lookup_conflict_remove1_pre L (nxs)  clvls  1);
      let (nxs) = lookup_conflict_remove1 L (nxs);
      ASSERT(arena_act_pre N C);
      vm  isa_vmtf_bump_to_rescore_also_reasons_cl M N C (-L) vm;
      ASSERT(isa_bump_unset_pre L' vm);
      ASSERT(tl_trailt_tr_pre M);
      vm  isa_bump_unset L' vm;
      let S = update_trail_wl_heur (tl_trailt_tr M) S;
      let S = update_conflict_wl_heur (None_lookup_conflict b nxs) S;
      let S = update_vmtf_wl_heur vm S;
      let S = update_clvls_wl_heur (clvls - 1) S;
      let S = update_outl_wl_heur outl S;
      let S = update_arena_wl_heur N S;
      let S = update_lbd_wl_heur lbd S;
      RETURN (False, S)
   })
  (* Proved by unfolding the original definition, then auto with congruence on
     bind, the conflict helper definitions and the state extractors. *)
  unfolding update_confl_tl_wl_heur_def
  by (auto intro!: ext bind_cong simp: None_lookup_conflict_def the_lookup_conflict_def
    extract_values_of_lookup_conflict_def Let_def state_extractors split: isasat_int_splits)

(* Synthesises update_confl_tl_wl_fast_code from update_confl_tl_wl_heur
   (uncurried over its three arguments) under the precondition isasat_fast S
   on the state argument. Arguments are refined through unat_lit_assn,
   sint64_nat_assn and isasat_bounded_assn (in that order); the result is a
   pair refined through bool1_assn and isasat_bounded_assn.
   NOTE(review): the sub/superscript markup of the signature and the
   placeholder inside the `rewrite at` pattern are missing in this copy --
   confirm against the original theory source. *)
sepref_def update_confl_tl_wl_fast_code
  is uncurry2 update_confl_tl_wl_heur
  :: [λ((i, L), S). isasat_fast S]a
   unat_lit_assnk *a sint64_nat_assnk *aisasat_bounded_assnd  bool1_assn ×a isasat_bounded_assn
  supply [[goals_limit=1]] isasat_fast_length_leD[intro]
  (* Synthesis runs on the accessor-based alternative definition. *)
  unfolding update_confl_tl_wl_heur_alt_def
    PR_CONST_def
  (* Annotate the constant compared against in the If condition using
     snat_const_fold at 64 bits; the remaining numeric constants are then
     annotated with the 32-bit type. *)
  apply (rewrite at If (_ = ) snat_const_fold[where 'a=64])
  apply (annot_unat_const TYPE (32))
  by sepref

(*TODO create mop_isa_bump_unset*)
(*TODO Move*)
(* Make the two operations known to Sepref before the synthesis below. *)
sepref_register is_in_conflict_st atm_is_in_conflict_st_heur
(* Synthesises skip_and_resolve_loop_wl_D_fast from
   skip_and_resolve_loop_wl_D_heur under the precondition isasat_fast S. Both
   the argument and the result are refined through isasat_bounded_assn.
   NOTE(review): the connective in the `rewrite at` pattern between the two
   negations is missing in this copy (presumably a conjunction, given that
   short_circuit_conv is applied there) -- confirm against the original. *)
sepref_def skip_and_resolve_loop_wl_D_fast
  is skip_and_resolve_loop_wl_D_heur
  :: [λS. isasat_fast S]a isasat_bounded_assnd  isasat_bounded_assn
  (* Two lemmas are supplied as introduction rules for the side conditions
     arising during synthesis. *)
  supply [[goals_limit=1]]
    skip_and_resolve_loop_wl_DI[intro]
    isasat_fast_after_skip_and_resolve_loop_wl_D_heur_inv[intro]
  unfolding skip_and_resolve_loop_wl_D_heur_def
  (* Rewrite the matched boolean expression with short_circuit_conv. *)
  apply (rewrite at ¬_  ¬ _ short_circuit_conv)
  by sepref (* slow *)

(* Runs export_llvm on the code constants listed below, inside a local
   `experiment` context. The list contains the constants synthesised in this
   theory plus several that are not defined in the visible part of this file
   (get_count_max_lvls_heur_impl, is_in_option_lookup_conflict_code,
   lit_of_last_trail_fast_code, None_lookup_conflict_impl), presumably coming
   from the imported theories. No output file is given to export_llvm here. *)
experiment
begin
  export_llvm
    get_count_max_lvls_heur_impl
    maximum_level_removed_eq_count_dec_fast_code
    is_decided_hd_trail_wl_fast_code
    lit_and_ann_of_propagated_st_heur_fast_code
    is_in_option_lookup_conflict_code
    atm_is_in_conflict_st_heur_fast_code
    lit_of_last_trail_fast_code
    tl_state_wl_heur_fast_code
    None_lookup_conflict_impl
    extract_values_of_lookup_conflict_impl
    update_confl_tl_wl_fast_code
    skip_and_resolve_loop_wl_D_fast

end



end