Theory IsaSAT_Backtrack_LLVM

theory IsaSAT_Backtrack_LLVM
  imports IsaSAT_Backtrack_Defs IsaSAT_VMTF_State_LLVM IsaSAT_Lookup_Conflict_LLVM
    IsaSAT_Rephase_State_LLVM IsaSAT_LBD_LLVM IsaSAT_Proofs_LLVM
    IsaSAT_Stats_LLVM
begin

(* Hide the NEMonad versions of ASSERT and RETURN: they must now be written
   qualified, so the unqualified names presumably refer to the nres-monad
   constants used in the rest of this theory. *)
hide_const (open) NEMonad.ASSERT NEMonad.RETURN

(* isa_empty_conflict_and_extract_clause_heur restated as an explicit WHILET
   loop; this is the form unfolded by empty_conflict_and_extract_clause_heur_fast_code.
   C starts as (length outl) copies of outl!0. For i = 1, 2, ... while i is below
   the length of outl, the body:
     - removes outl!i from the lookup conflict D,
     - stores outl!i at C!i,
     - swaps positions 1 and i when C!i has a higher level than C!1.
   Hence C!1 holds a literal of maximal level among positions 1..i processed so far.
   The result is ((True, D), C, lvl), where lvl is 0 when outl has length 1 and
   the level of C!1 otherwise. Proved by unfolding the definition. *)
lemma isa_empty_conflict_and_extract_clause_heur_alt_def:
    isa_empty_conflict_and_extract_clause_heur M D outl = do {
     let C = replicate (length outl) (outl!0);
     (D, C, _)  WHILET
         (λ(D, C, i). i < length_uint32_nat outl)
         (λ(D, C, i). do {
           ASSERT(i < length outl);
           ASSERT(i < length C);
           ASSERT(lookup_conflict_remove1_pre (outl ! i, D));
           let D = lookup_conflict_remove1 (outl ! i) D;
           let C = C[i := outl ! i];
	   ASSERT(get_level_pol_pre (M, C!i));
	   ASSERT(get_level_pol_pre (M, C!1));
	   ASSERT(1 < length C);
           let L1 = C!i;
           let L2 = C!1;
           let C = (if get_level_pol M L1 > get_level_pol M L2 then swap C 1 i else C);
           ASSERT(i+1  unat32_max);
           RETURN (D, C, i+1)
         })
        (D, C, 1);
     ASSERT(length outl  1  length C > 1);
     ASSERT(length outl  1   get_level_pol_pre (M, C!1));
     RETURN ((True, D), C, if length outl = 1 then 0 else get_level_pol M (C!1))
  }
  unfolding isa_empty_conflict_and_extract_clause_heur_def
  by auto

(* Bounded LLVM implementation of isa_empty_conflict_and_extract_clause_heur.
   The precondition restricts outl: non-empty, with length bounded by unat32_max.
   The trail and the learned-literal buffer are kept; the lookup conflict is
   consumed and returned inside the result. *)
sepref_def empty_conflict_and_extract_clause_heur_fast_code
  is uncurry2 (isa_empty_conflict_and_extract_clause_heur)
  :: [λ((M, D), outl). outl  []  length outl  unat32_max]a
      trail_pol_fast_assnk *a lookup_clause_rel_assnd *a out_learned_assnk 
       (conflict_option_rel_assn) ×a clause_ll_assn ×a uint32_nat_assn
  supply [[goals_limit=1]] image_image[simp]
  supply [simp] = max_snat_def unat32_max_def
  unfolding isa_empty_conflict_and_extract_clause_heur_alt_def
    larray_fold_custom_replicate length_uint32_nat_def conflict_option_rel_assn_def
  (* Fold these numerals as 64-bit snat constants: the index constants (list
     accesses, the swap index, the loop counter start and increment) and the
     length comparison. Remaining numerals are then annotated as 32-bit unat. *)
  apply (rewrite at  in _ !1 snat_const_fold[where 'a=64])+
  apply (rewrite at  in _ !0 snat_const_fold[where 'a=64])
  apply (rewrite at swap _  _ snat_const_fold[where 'a=64])
  apply (rewrite at  in (_, _, _ + 1) snat_const_fold[where 'a=64])
  apply (rewrite at  in (_, _, 1) snat_const_fold[where 'a=64])
  apply (rewrite at  in If (length _ = ) snat_const_fold[where 'a=64])
  apply (annot_unat_const TYPE(32))
  (* Rewrite swap with gen_swap / convert_swap before running sepref. *)
  unfolding gen_swap convert_swap
  by sepref


(* emptied_list xs restated as take 0 xs; unfolded by empty_cach_code when
   refining empty_cach_ref_set. *)
lemma emptied_list_alt_def: emptied_list xs = take 0 xs
  by (auto simp: emptied_list_def)

(* LLVM implementation of empty_cach_ref_set: consumes the cache and returns the
   emptied one. Constants are annotated as 64-bit snat. The updates writing
   SEEN_UNKNOWN are rewritten by folding value_of_atm_def / index_of_atm_def back
   before sepref runs. *)
sepref_def empty_cach_code
  is empty_cach_ref_set
  :: cach_refinement_l_assnd a cach_refinement_l_assn
  supply [[goals_limit=1]]
  unfolding empty_cach_ref_set_def comp_def cach_refinement_l_assn_def emptied_list_alt_def
  apply (annot_snat_const TYPE(64))
  apply (rewrite at _[ := SEEN_UNKNOWN] value_of_atm_def[symmetric])
  apply (rewrite at _[ := SEEN_UNKNOWN] index_of_atm_def[symmetric])
  by sepref



(* Sepref rule: empty_cach_code refines RETURN o empty_cach_ref under
   empty_cach_ref_pre. The proof proceeds in three steps:
     - compose the refinement of empty_cach_code (w.r.t. empty_cach_ref_set)
       with empty_cach_ref_set_empty_cach_ref, giving H;
     - show that empty_cach_ref_pre implies the composed precondition (pre);
     - weaken the precondition of H accordingly. *)
theorem empty_cach_code_empty_cach_ref[sepref_fr_rules]:
  (empty_cach_code, RETURN  empty_cach_ref)
     [empty_cach_ref_pre]a
    cach_refinement_l_assnd  cach_refinement_l_assn
  (is ?c  [?pre]a ?im  ?f)
proof -
  (* Composed refinement with its generated precondition ?pre'. *)
  have H: ?c
    [comp_PRE Id
     (λ(cach, supp).
         (Lset supp. L < length cach) 
         length supp  Suc (unat32_max div 2) 
         (L<length cach. cach ! L  SEEN_UNKNOWN  L  set supp))
     (λx y. True)
     (λx. nofail ((RETURN  empty_cach_ref) x))]a
      hrp_comp (cach_refinement_l_assnd)
                     Id  hr_comp cach_refinement_l_assn Id
    (is _  [?pre']a ?im'  ?f')
    using hfref_compI_PRE[OF empty_cach_code.refine[unfolded PR_CONST_def]
        empty_cach_ref_set_empty_cach_ref] by simp
  (* The user-facing precondition implies the composed one. *)
  have pre: ?pre' h x if ?pre x for x h
    using that by (auto simp: comp_PRE_def trail_pol_def
        ann_lits_split_reasons_def empty_cach_ref_pre_def)
  (* Compositions with Id collapse to the original assertions. *)
  have im: ?im' = ?im
    by simp
  have f: ?f' = ?f
    by auto
  show ?thesis
    apply (rule hfref_weaken_pre[OF ])
     defer
    using H unfolding im f apply assumption
    using pre ..
qed

(* fm_add_new_fast is folded in place of fm_add_new in propagate_bt_wl_D_fast_codeXX. *)
sepref_register fm_add_new_fast

(* In the fast setting, the arena length plus one stays below max_snat 64. *)
lemma isasat_fast_length_leD: isasat_fast S  Suc (length (get_clauses_wl_heur S)) < max_snat 64
  by (cases S) (auto simp: isasat_fast_def max_snat_def snat64_max_def)

sepref_register update_propagation_heuristics
(* Implementation of update_propagation_heuristics_stats on the inner heuristic
   representation (heuristic_int_assn). The uint32 argument is kept; the heuristic
   is consumed and the updated one is returned. *)
sepref_def update_heuristics_stats_impl
  is uncurry (RETURN oo update_propagation_heuristics_stats)
  :: uint32_nat_assnk *a heuristic_int_assnd a heuristic_int_assn
  unfolding update_propagation_heuristics_stats_def heuristic_int_assn_def
  by sepref

(* Implementation of update_propagation_heuristics on heuristic_assn. The uint32
   argument is kept; the heuristic is consumed and the updated one is returned. *)
sepref_def update_heuristics_impl
  is uncurry (RETURN oo update_propagation_heuristics)
  :: uint32_nat_assnk *a heuristic_assnd a heuristic_assn
  unfolding update_propagation_heuristics_def
  by sepref

(*TODO Move to isasat_fast_countD*)
(* In the fast setting, the UEk learned-clause counter is below unat64_max.
   Used as a dest rule by propagate_unit_bt_wl_D_fast_code. *)
lemma isasat_fast_countD_tmp:
  isasat_fast S  clss_size_lcountUEk (get_learned_count S) < unat64_max
  by (auto simp: isasat_fast_def learned_clss_count_def)

(* propagate_unit_bt_wl_D_int restated on separately extracted state components.
   The body:
     - flushes the bump heuristic;
     - computes the glue and resets the LBD;
     - records the trail length j;
     - pushes -L on the trail with reason 0;
     - increments the unit counters in the stats;
     - writes all components back, setting literals_to_update to j, updating the
       propagation heuristics with the glue and incrementing the learned count
       via clss_size_incr_lcountUEk.
   -L is also passed to log_unit_clause. *)
lemma propagate_unit_bt_wl_D_int_alt_def:
    propagate_unit_bt_wl_D_int = (λL S0. do {
      let (M, S) = extract_trail_wl_heur S0;
      let (N, S) = extract_arena_wl_heur S;
      ASSERT (N = get_clauses_wl_heur S0);
      let (lcount, S) = extract_lcount_wl_heur S;
      ASSERT (lcount = get_learned_count S0);
      let (heur, S) = extract_heur_wl_heur S;
      let (stats, S) = extract_stats_wl_heur S;
      let (lbd, S) = extract_lbd_wl_heur S;
      let (vm0, S) = extract_vmtf_wl_heur S;
      vm  isa_bump_heur_flush M vm0;
      glue  get_LBD lbd;
      lbd  lbd_empty lbd;
      j  mop_isa_length_trail M;
      ASSERT(0  DECISION_REASON);
      ASSERT(cons_trail_Propagated_tr_pre ((- L, 0::nat), M));
      M  cons_trail_Propagated_tr (- L) 0 M;
      let stats = incr_units_since_last_GC (incr_uset stats);
      let S = update_stats_wl_heur stats S;
      let S = update_trail_wl_heur M S;
      let S = update_lbd_wl_heur lbd S;
      let S = update_literals_to_update_wl_heur j S;
      let S = update_heur_wl_heur (unset_fully_propagated_heur (heuristic_reluctant_tick (update_propagation_heuristics glue heur))) S;
      let S = update_lcount_wl_heur (clss_size_incr_lcountUEk lcount) S;
      let S = update_arena_wl_heur N S;
      let S = update_vmtf_wl_heur vm S;
      let _ = log_unit_clause (-L);
        RETURN S})
  by (auto simp: propagate_unit_bt_wl_D_int_def state_extractors log_unit_clause_def intro!: ext split: isasat_int_splits)

sepref_register cons_trail_Propagated_tr update_heur_wl_heur
(* Bounded implementation of propagate_unit_bt_wl_D_int, available when
   isasat_fast S holds. The literal is kept; the state is consumed and the new
   state is returned. Numeric constants are annotated as 64-bit snat. *)
sepref_def propagate_unit_bt_wl_D_fast_code
  is uncurry propagate_unit_bt_wl_D_int
  :: [λ(L, S). isasat_fast S]a unat_lit_assnk *a isasat_bounded_assnd  isasat_bounded_assn
  supply [[goals_limit = 1]]
    isasat_fast_countD[dest] isasat_fast_countD_tmp[dest]
  unfolding propagate_unit_bt_wl_D_int_alt_def
    PR_CONST_def
  apply (annot_snat_const TYPE(64))
  by sepref


(* Splits the state S0 into the components used by propagate_bt_wl_D_heur_alt_def.
   Returns them in this order: trail, vdom, arena, watch lists, learned counts,
   heuristics, stats, LBD, bump heuristic, followed by the remaining state. *)
definition propagate_bt_wl_D_heur_extract where
  propagate_bt_wl_D_heur_extract = (λS0. do {
      let (M,S) = extract_trail_wl_heur S0;
      let (vdom, S) = extract_vdom_wl_heur S;
      let (N0, S) = extract_arena_wl_heur S;
      let (W0, S) = extract_watchlist_wl_heur S;
      let (lcount, S) = extract_lcount_wl_heur S;
      let (heur, S) = extract_heur_wl_heur S;
      let (stats, S) = extract_stats_wl_heur S;
      let (lbd, S) = extract_lbd_wl_heur S;
      let (vm0, S) = extract_vmtf_wl_heur S;
      RETURN (M, vdom, N0, W0, lcount, heur, stats, lbd, vm0, S)})

(* Implementation of propagate_bt_wl_D_heur_extract: consumes the bounded state
   and returns the extracted components together with the remaining state. *)
sepref_def propagate_bt_wl_D_heur_extract_impl
  is propagate_bt_wl_D_heur_extract
  :: isasat_bounded_assnd a trail_pol_fast_assn ×a aivdom_assn ×a arena_fast_assn ×a
  watchlist_fast_assn ×a lcount_assn ×a heuristic_assn ×a isasat_stats_assn ×a lbd_assn ×a
  heuristic_bump_assn ×a isasat_bounded_assn
  unfolding propagate_bt_wl_D_heur_extract_def
  by sepref

(* Writes the given components (the ones returned by
   propagate_bt_wl_D_heur_extract) back into S0. It also resets clvls to 0 and
   sets literals_to_update to j. *)
definition propagate_bt_wl_D_heur_update where
  propagate_bt_wl_D_heur_update = (λS0 M vdom N0 W0 lcount heur stats lbd vm0 j. do {
      let (S) = update_trail_wl_heur M S0;
      let (S) = update_vdom_wl_heur vdom S;
      let (S) = update_arena_wl_heur N0 S;
      let (S) = update_watchlist_wl_heur W0 S;
      let (S) = update_lcount_wl_heur lcount S;
      let (S) = update_heur_wl_heur heur S;
      let (S) = update_stats_wl_heur stats S;
      let S = update_lbd_wl_heur lbd S;
      let S = update_vmtf_wl_heur vm0 S;
      let S = update_clvls_wl_heur 0 S;
      let S = update_literals_to_update_wl_heur j S;
      RETURN (S)})

(* Implementation of propagate_bt_wl_D_heur_update. Every state component is
   consumed; only the last argument j (a sint64) is kept. The 0 passed to
   update_clvls_wl_heur is folded as a 32-bit unat constant. *)
sepref_def propagate_bt_wl_D_heur_update_impl
  is uncurry10 propagate_bt_wl_D_heur_update
  :: isasat_bounded_assnd *a trail_pol_fast_assnd *a aivdom_assnd *a arena_fast_assnd *a
  watchlist_fast_assnd *a lcount_assnd *a heuristic_assnd *a isasat_stats_assnd *a lbd_assnd *a
  heuristic_bump_assnd *a sint64_nat_assnk  a  isasat_bounded_assn
  supply [[goals_limit = 1]]
  unfolding propagate_bt_wl_D_heur_update_def
  apply (rewrite at update_clvls_wl_heur  _ unat_const_fold[where 'a=32])
  by sepref

(* propagate_bt_wl_D_heur restated through propagate_bt_wl_D_heur_extract and
   propagate_bt_wl_D_heur_update (presumably so that sepref can work on the
   separated components). The body:
     - bumps/rescores with C and computes the glue;
     - adds C as new clause i with fm_add_new (b = False) and sets its LBD;
     - appends watcher (i, C!1, b') to the list of -L and (i, -L, b') to the
       list of C!1, where b' holds iff C has length 2;
     - resets the LBD and pushes -L on the trail with reason i;
     - flushes the bump heuristic and saves the phase of C!1;
     - writes everything back: i is added to the vdom, the learned count is
       incremented, and the heuristics and stats are updated with the glue;
     - logs clause i and calls maybe_mark_added_clause_heur2. *)
lemma propagate_bt_wl_D_heur_alt_def:
  propagate_bt_wl_D_heur = (λL C S0. do {
      (M, vdom, N0, W0, lcount, heur, stats, lbd, vm0, S)  propagate_bt_wl_D_heur_extract S0;
      ASSERT (N0 = get_clauses_wl_heur S0);
      ASSERT (vdom = get_aivdom S0);
      ASSERT(length (get_vdom_aivdom vdom)  length N0);
      ASSERT(length (get_avdom_aivdom vdom)  length N0);
      ASSERT(nat_of_lit (C!1) < length W0  nat_of_lit (-L) < length W0);
      ASSERT(length C > 1);
      let L' = C!1;
      ASSERT(length C  unat32_max div 2 + 1);
      vm  isa_bump_rescore C M vm0;
      glue  get_LBD lbd;
      let b = False;
      let l = 2;
      let b' = (length C = l);
      ASSERT(isasat_fast S0  append_and_length_fast_code_pre ((b, C), N0));
      ASSERT(isasat_fast S0  clss_size_lcount lcount < snat64_max);
      (N, i)  fm_add_new b C N0;
      ASSERT(update_lbd_pre ((i, glue), N));
      let N = update_lbd_and_mark_used i glue N;
      ASSERT(isasat_fast S0  length_ll W0 (nat_of_lit (-L)) < snat64_max);
      let W = W0[nat_of_lit (- L) := W0 ! nat_of_lit (- L) @ [(i, L', b')]];
      ASSERT(isasat_fast S0  length_ll W (nat_of_lit L') < snat64_max);
      let W = W[nat_of_lit L' := W!nat_of_lit L' @ [(i, -L, b')]];
      lbd  lbd_empty lbd;
      j  mop_isa_length_trail M;
      ASSERT(i  DECISION_REASON);
      ASSERT(cons_trail_Propagated_tr_pre ((-L, i), M));
      M  cons_trail_Propagated_tr (- L) i M;
      vm  isa_bump_heur_flush M vm;
      heur  mop_save_phase_heur (atm_of L') (is_neg L') heur;
      S  propagate_bt_wl_D_heur_update S M (add_learned_clause_aivdom i vdom) N
          W (clss_size_incr_lcount lcount) (unset_fully_propagated_heur (heuristic_reluctant_tick (update_propagation_heuristics glue heur))) (add_lbd (of_nat glue) stats) lbd vm j;
        _  log_new_clause_heur S i;
      S  maybe_mark_added_clause_heur2 S i;
      RETURN (S)
  })
  unfolding propagate_bt_wl_D_heur_def Let_def propagate_bt_wl_D_heur_update_def
          propagate_bt_wl_D_heur_extract_def nres_monad3
  by (auto simp: propagate_bt_wl_D_heur_def Let_def state_extractors propagate_bt_wl_D_heur_update_def
          propagate_bt_wl_D_heur_extract_def intro!: ext bind_cong[OF refl]
          split: isasat_int_splits)

(* Add the simplified values of max_snat 64 and max_unat 64 to the
   sepref_bounds_simps rule set. *)
lemmas [sepref_bounds_simps] =
  max_snat_def[of 64, simplified]
  max_unat_def[of 64, simplified]

(* The natural number 2 as a named constant (also a simp rule). The literal 2
   can be folded into it so that the rule below implements it as a 64-bit
   signed integer. *)
definition two_sint64 :: nat where [simp]: two_sint64 = 2
(* Sepref rule: two_sint64 is implemented by Mreturn 2 under sint64_nat_assn. *)
lemma [sepref_fr_rules]:
   (uncurry0 (Mreturn 2), uncurry0 (RETURN two_sint64))  unit_assnk a sint64_nat_assn
  apply sepref_to_hoare
  apply (vcg, auto simp: snat_rel_def snat.rel_def br_def snat_invar_def ENTAILS_def
      snat_numeral max_snat_def exists_eq_star_conv Exists_eq_simp
      sep_conj_commute pure_true_conv)
  done

section Backtrack with direct extraction of the literal of highest level

(* Bound transfer: a value bounded by unat32_max div 2 + 1 is bounded by unat32_max. *)
lemma le_unat32_max_div_2_le_unat32_max: a  unat32_max div 2 + 1  a  unat32_max
  by (auto simp: unat32_max_def snat64_max_def)

(* In the fast setting, an index strictly below the arena length is bounded by
   snat64_max. Used as an intro rule in propagate_bt_wl_D_fast_codeXX. *)
lemma propagate_bt_wl_D_fast_code_isasat_fastI2: isasat_fast b 
       a < length (get_clauses_wl_heur b)  a  snat64_max
  by (cases b) (auto simp: isasat_fast_def)

(* In the fast setting, a value bounded by the arena length is strictly below
   snat64_max. Used as an intro rule in propagate_bt_wl_D_fast_codeXX. *)
lemma propagate_bt_wl_D_fast_code_isasat_fastI3: isasat_fast b 
       a  length (get_clauses_wl_heur b)  a < snat64_max
  by (cases b) (auto simp: isasat_fast_def snat64_max_def unat32_max_def)

sepref_register propagate_bt_wl_D_heur_update propagate_bt_wl_D_heur_extract two_sint64
(* Bounded implementation of propagate_bt_wl_D_heur, available when isasat_fast S
   holds. It adds the learned clause C, watches it, and propagates -L with the
   new clause as reason. The literal and the clause are kept; the state is
   consumed and the new state is returned. Before sepref runs:
     - fm_add_new is folded into fm_add_new_fast;
     - the literal 2 is folded into two_sint64;
     - remaining constants are annotated as 64-bit snat.
   append_ll_def[symmetric] is listed only once; folding it twice was redundant. *)
sepref_def propagate_bt_wl_D_fast_codeXX
  is uncurry2 propagate_bt_wl_D_heur
  :: [λ((L, C), S). isasat_fast S]a
      unat_lit_assnk *a clause_ll_assnk *a isasat_bounded_assnd  isasat_bounded_assn
  supply [[goals_limit = 1]] append_ll_def[simp] isasat_fast_length_leD[dest]
     propagate_bt_wl_D_fast_code_isasat_fastI2[intro] length_ll_def[simp]
     propagate_bt_wl_D_fast_code_isasat_fastI3[intro]
     isasat_fast_countD[dest]
  unfolding propagate_bt_wl_D_heur_alt_def
    fm_add_new_fast_def[symmetric]
  unfolding delete_index_and_swap_update_def[symmetric] append_update_def[symmetric]
    append_ll_def[symmetric] two_sint64_def[symmetric]
    PR_CONST_def save_phase_def fold_tuple_optimizations
  apply (annot_snat_const TYPE(64))
  by sepref

(* extract_shorter_conflict_list_heur_st restated on separately extracted state
   components. With K the last trail literal, the body:
     - marks the LBD from outl;
     - removes -K from the lookup conflict and stores -K at outl!0;
     - marks literals for rescoring (isa_vmtf_mark_to_rescore_also_reasons);
     - minimizes via isa_minimize_and_extract_highest_lookup_conflict;
     - empties the cache ccach;
     - builds the clause C and the level n with
       isa_empty_conflict_and_extract_clause_heur.
   Components are written back with outl truncated to its first element.
   Returns (S, n, C). *)
lemma extract_shorter_conflict_list_heur_st_alt_def:
    extract_shorter_conflict_list_heur_st = (λS0. do {
     let (M,S) = extract_trail_wl_heur S0;
     let (N, S) = extract_arena_wl_heur S;
     ASSERT (N = get_clauses_wl_heur S0);
     let (lbd, S) = extract_lbd_wl_heur S;
     let (vm0, S) = extract_vmtf_wl_heur S;
     let (outl, S) = extract_outl_wl_heur S;
     let (bD, S) = extract_conflict_wl_heur S;
     let (ccach, S) = extract_ccach_wl_heur S;
     lbd  mark_lbd_from_list_heur M outl lbd;
     let D =  the_lookup_conflict bD;
     ASSERT(fst M  []);
     let K = lit_of_last_trail_pol M;
     ASSERT(0 < length outl);
     ASSERT(lookup_conflict_remove1_pre (-K, D));
     let D = lookup_conflict_remove1 (-K) D;
     let outl = outl[0 := -K];
     vm  isa_vmtf_mark_to_rescore_also_reasons M N outl (-K) vm0;
     (D, ccach, outl)  isa_minimize_and_extract_highest_lookup_conflict M N D ccach lbd outl;
     ASSERT(empty_cach_ref_pre ccach);
     let ccach = empty_cach_ref ccach;
     ASSERT(outl  []  length outl  unat32_max);
     (D, C, n)  isa_empty_conflict_and_extract_clause_heur M D outl;
      let S = update_trail_wl_heur M S;
      let S = update_arena_wl_heur N S;
      let S = update_vmtf_wl_heur vm S;
      let S = update_lbd_wl_heur lbd S;
      let S = update_outl_wl_heur (take 1 outl) S;
      let S = update_ccach_wl_heur ccach S;
      let S = update_conflict_wl_heur D S;
     RETURN (S, n, C)
  })
  unfolding extract_shorter_conflict_list_heur_st_def
  by (auto simp: the_lookup_conflict_def Let_def state_extractors intro!: ext bind_cong[OF refl]
    split: isasat_int_splits)

sepref_register isa_minimize_and_extract_highest_lookup_conflict
    isa_vmtf_mark_to_rescore_also_reasons

(* Bounded implementation of extract_shorter_conflict_list_heur_st. The
   precondition bounds the arena length by snat64_max. The state is consumed;
   the result is the new state, the level (uint32) and the extracted clause. *)
sepref_def extract_shorter_conflict_list_heur_st_fast
  is extract_shorter_conflict_list_heur_st
  :: [λS. length (get_clauses_wl_heur S)  snat64_max]a
        isasat_bounded_assnd  isasat_bounded_assn ×a uint32_nat_assn ×a clause_ll_assn
  supply [[goals_limit=1]]
  unfolding extract_shorter_conflict_list_heur_st_alt_def PR_CONST_def
  unfolding delete_index_and_swap_update_def[symmetric] append_update_def[symmetric]
  apply (annot_snat_const TYPE(64))
  by sepref

(* Register the operations called by backtrack_wl_D_nlit_heur so that sepref
   treats them as opaque constants with their own refinement rules. *)
sepref_register find_lit_of_max_level_wl
  extract_shorter_conflict_list_heur_st lit_of_hd_trail_st_heur propagate_bt_wl_D_heur
  propagate_unit_bt_wl_D_int
sepref_register backtrack_wl

(* If two states have the same learned counts, their learned_clss_count values
   are related (the relation symbol is not legible in this copy; presumably
   <= -- confirm). Used as a dest rule in backtrack_wl_D_nlit_heurI. *)
lemma get_learned_count_learned_clss_countD2:
  get_learned_count S = (get_learned_count T) 
       learned_clss_count S  learned_clss_count T
  by (cases S; cases T) (auto simp: learned_clss_count_def)

(* isasat_fast carries over from x to xc when both have the same clause arena
   and the same learned counts. *)
lemma backtrack_wl_D_nlit_heurI:
  isasat_fast x 
       get_clauses_wl_heur xc = get_clauses_wl_heur x 
       get_learned_count xc = get_learned_count x  isasat_fast xc
  by (auto simp: isasat_fast_def dest: get_learned_count_learned_clss_countD2)

sepref_register save_phase_st
(* Bounded implementation of backtrack_wl_D_nlit_heur, available when
   isasat_fast holds. The state is consumed and the new state is returned.
   Numeric constants are annotated as 64-bit snat. *)
sepref_def backtrack_wl_D_fast_code
  is backtrack_wl_D_nlit_heur
  :: [isasat_fast]a isasat_bounded_assnd  isasat_bounded_assn
  supply [[goals_limit=1]]
    size_conflict_wl_def[simp] isasat_fast_length_leD[intro] backtrack_wl_D_nlit_heurI[intro]
    isasat_fast_countD[dest] IsaSAT_Setup.isasat_fast_length_leD[dest]
  unfolding backtrack_wl_D_nlit_heur_def PR_CONST_def
  unfolding delete_index_and_swap_update_def[symmetric] append_update_def[symmetric]
    append_ll_def[symmetric]
    size_conflict_wl_def[symmetric]
  apply (annot_snat_const TYPE(64))
  by sepref

(* TODO: Move *)
(* Mark add_lbd_def for inlining during LLVM code generation. *)
lemmas [llvm_inline] = add_lbd_def

experiment
begin
  (* Export check for the LLVM code of the implementations used in backtracking.
     Each constant is listed once; the earlier list repeated
     update_heuristics_impl three times. *)
  export_llvm
     empty_conflict_and_extract_clause_heur_fast_code
     empty_cach_code
     update_heuristics_impl
     isa_vmtf_flush_fast_code
     get_LBD_code
     mop_isa_length_trail_fast_code
     cons_trail_Propagated_tr_fast_code
     append_and_length_fast_code
     update_lbd_impl
     reluctant_tick_impl
     propagate_bt_wl_D_fast_codeXX
end


end