(* Theory IsaSAT_Inner_Propagation_LLVM *)

theory IsaSAT_Inner_Propagation_LLVM
  imports IsaSAT_Setup_LLVM
    IsaSAT_Inner_Propagation_Defs
    IsaSAT_VMTF_LLVM
    IsaSAT_LBD_LLVM
begin
(* Hide the short names of NEMonad.ASSERT and NEMonad.RETURN (they remain reachable
   through their qualified names), so that the unqualified ASSERT/RETURN used in
   this theory refer to a different constant -- presumably the nres monad sepref
   works on; confirm against the imports. *)
hide_const (open) NEMonad.ASSERT NEMonad.RETURN
(* Make the first two abstract operations refined in this theory known to sepref. *)
sepref_register isa_save_pos unit_propagation_update_statistics

(* Reformulation of unit_propagation_update_statistics p q S that takes the statistics
   and the trail out of the state, updates the statistics and writes both back:
   - the propagation counter is increased by q - p;
   - the conflict counter is increased when the state holds a conflict;
   - when count_decided_pol M = 0, incr_uset_by and incr_units_since_last_GC_by
     are both applied with q - p;
   - set_no_conflict_until is called with q.
   NOTE(review): `height` is bound (q, or the trail height before the conflict) but
   is never used afterwards; set_no_conflict_until receives q. Because this lemma is
   proved from unit_propagation_update_statistics_def, the definition behaves the
   same way. Check whether `height` was meant to be passed instead. *)
lemma unit_propagation_update_statistics_alt_def:
  unit_propagation_update_statistics p q S = do {
  let (stats, S) = extract_stats_wl_heur S;
  let (M, S) = extract_trail_wl_heur S;
  let pq = q - p;
  let stats = incr_propagation_by pq stats;
  let stats = (if get_conflict_wl_is_None_heur S then stats else incr_conflict stats);
  let stats = (if count_decided_pol M = 0 then incr_units_since_last_GC_by pq (incr_uset_by pq stats) else stats);
  height  (if get_conflict_wl_is_None_heur S then RETURN q else do {j  trail_height_before_conflict M; RETURN (of_nat j)});
  let stats = set_no_conflict_until q stats;
  RETURN (update_stats_wl_heur stats (update_trail_wl_heur M S))
  }
  by (auto simp: unit_propagation_update_statistics_def state_extractors Let_def get_conflict_wl_is_None_heur_def
    split: isasat_int_splits intro!: ext)

(* LLVM implementation of unit_propagation_update_statistics. The two 64-bit word
   arguments p and q are only read (k annotation); the bounded state is consumed
   (d annotation) and the updated state is returned.
   - of_nat_unat is imported as a parametricity rule for the natural-number to
     word conversion in the alternative definition.
   - annot_unat_const TYPE (32) annotates numeral constants as 32-bit unsigned
     words (presumably the 0 compared with count_decided_pol; confirm).
   - the rewrite inserts annot_unat_unat_upcast (target width 64) below
     RETURN (word_of_nat _), presumably to widen the trail height to 64 bits
     explicitly. *)
sepref_def unit_propagation_update_statistics_impl
  is uncurry2 (unit_propagation_update_statistics)
  :: word64_assnk *a word64_assnk *a isasat_bounded_assnd a isasat_bounded_assn
  supply [[goals_limit=1]] of_nat_unat[sepref_import_param]
  unfolding unit_propagation_update_statistics_alt_def
  apply (annot_unat_const TYPE (32))
  apply (rewrite  at RETURN (word_of_nat ) annot_unat_unat_upcast[where 'l=64])
  by sepref

(* Reformulation of isa_save_pos C i. After asserting that C is a valid clause
   index, it distinguishes two cases:
   - if the clause's arena_length exceeds MAX_LENGTH_SHORT_CLAUSE, its saved
     position is set to i with arena_update_pos (the arena is extracted from the
     state, updated and put back);
   - any other clause leaves the state S0 unchanged. *)
lemma isa_save_pos_alt_def:
  isa_save_pos C i = (λS0. do {
      ASSERT(arena_is_valid_clause_idx (get_clauses_wl_heur S0) C);
      if arena_length (get_clauses_wl_heur S0) C > MAX_LENGTH_SHORT_CLAUSE then do {
        let (N, S) = extract_arena_wl_heur S0;
        ASSERT (N = get_clauses_wl_heur S0);
        ASSERT(isa_update_pos_pre ((C, i), N));
        let N = arena_update_pos C i N;
        RETURN (update_arena_wl_heur N S)
      } else RETURN S0
    })
  
  by (auto simp: isa_save_pos_def state_extractors
    split: isasat_int_splits intro!: ext)

(* LLVM implementation of isa_save_pos. The clause index and the position are
   sint64_nat_assn values that are only read; the bounded state is consumed and
   returned.
   - if_splits[split] lets the goals of the if-then-else be split.
   - Folding access_length_heur_def[symmetric] presumably replaces the arena_length
     lookup of the alternative definition by the access_length_heur operation;
     confirm. *)
sepref_def isa_save_pos_fast_code
  is uncurry2 isa_save_pos
  :: sint64_nat_assnk *a sint64_nat_assnk *a isasat_bounded_assnd a isasat_bounded_assn
  supply
    [[goals_limit=1]]
    if_splits[split]
  unfolding isa_save_pos_alt_def PR_CONST_def access_length_heur_def[symmetric]
  by sepref


(* Register the unwatched-literal search operations and polarity_pol with sepref.
   TODO most of the unfolding should move to the definition *)
sepref_register isa_find_unwatched_wl_st_heur isa_find_unwatched_between isa_find_unset_lit
  polarity_pol

(* Register the numerals 0 and 1 as sepref operations.
   TODO dup: presumably already registered in an imported theory; confirm and remove. *)
sepref_register 0 1

(* Draft lemma, kept disabled: *)
(*lemma found = None  is_None (ASSN_ANNOT (snat_option_assn' TYPE(64)) found)*)

(* LLVM implementation of isa_find_unset_lit (despite "_between_" in its name).
   - Arguments, all only read: the trail (trail_pol_fast_assn), the arena
     (arena_fast_assn) and three sint64_nat_assn indices.
   - Result: an optional 64-bit index (snat_option_assn' TYPE(64)).
   - Precondition: the arena length is bounded by snat64_max.
   The proof unfolds isa_find_unset_lit into isa_find_unwatched_between. It then
   rewrites the condition of the `if` with tri_bool_eq_def[symmetric] (presumably
   folding a tri-bool comparison into tri_bool_eq) and annotates numerals as
   64-bit signed words before running sepref. *)
sepref_def isa_find_unwatched_between_fast_code
  is uncurry4 isa_find_unset_lit
  :: [λ((((M, N), _), _), _). length N  snat64_max]a
     trail_pol_fast_assnk *a arena_fast_assnk *a sint64_nat_assnk *a sint64_nat_assnk *a sint64_nat_assnk 
       snat_option_assn' TYPE(64)
  supply [[goals_limit = 3]]
  unfolding isa_find_unset_lit_def isa_find_unwatched_between_def SET_FALSE_def[symmetric]
    PR_CONST_def
  apply (rewrite in if  then _ else _ tri_bool_eq_def[symmetric])
  apply (annot_snat_const TYPE (64))
  by sepref

(* State-level variant of isa_find_unset_lit. The trail and the arena are read from
   the state S; the remaining arguments of isa_find_unset_lit are left to the
   caller (partial application). *)
definition isa_find_unset_lit_st where
  isa_find_unset_lit_st S = isa_find_unset_lit (get_trail_wl_heur S) (get_clauses_wl_heur S)

(* Implementation counterpart of isa_find_unset_lit_st. read_all_st_code gives
   access to the components of the implementation state. Only the first two are
   used: M (the trail) and N (the arena). They are passed to
   isa_find_unwatched_between_fast_code together with C, D and E.
   Note: the outer N names the whole state, while the inner N is the arena. *)
definition isasat_find_unset_lit_st_impl :: twl_st_wll_trail_fast2  _ where
  isasat_find_unset_lit_st_impl = (λN C D E.
     read_all_st_code
      (λM N _ _ _ _ _ _ _ _ _ _ _ _ _ _ _. isa_find_unwatched_between_fast_code M N C D E) N)

(* Lift isa_find_unwatched_between_fast_code from (trail, arena) to the whole state
   by instantiating read_trail_arena_param_adder2_threeargs:
   - the three extra arguments are related by snat_rel' TYPE(64);
   - the result is a snat_option_assn' TYPE(64);
   - the precondition bounds the arena length by snat64_max.
   The `rewrites` clause makes the resulting facts (e.g. find_unset_lit.refine)
   refer to isa_find_unset_lit_st and isasat_find_unset_lit_st_impl. Its two
   equations are the two `subgoal`s at the end.
   NOTE(review): `subst (9) uncurry_def` depends on the exact number of uncurry
   occurrences in the locale assumption and will break if that shape changes. *)
global_interpretation find_unset_lit: read_trail_arena_param_adder2_threeargs where
  R = snat_rel' (TYPE(64)) and
  R' = snat_rel' (TYPE(64)) and
  R'' = snat_rel' (TYPE(64)) and
  f = λC C' C'' M N. isa_find_unwatched_between_fast_code M N C C' C'' and
  f' = λC C' C'' M N. isa_find_unset_lit M N C C' C'' and
  x_assn = snat_option_assn' TYPE(64) and
  P = λC C' C'' M N. length N  snat64_max
  rewrites
  (λN C D E.
  read_all_st (λM N _ _ _ _ _ _ _ _ _ _ _ _ _ _ _. isa_find_unset_lit M N C D E) N) = isa_find_unset_lit_st and
  (λN C D E.
     read_all_st_code
      (λM N _ _ _ _ _ _ _ _ _ _ _ _ _ _ _. isa_find_unwatched_between_fast_code M N C D E) N) = isasat_find_unset_lit_st_impl
  apply (unfold_locales)
  apply (subst (9) uncurry_def)+
  apply (rule isa_find_unwatched_between_fast_code.refine)
  subgoal by (auto simp: read_all_st_def isa_find_unset_lit_st_def intro!: ext split: isasat_int_splits)
  subgoal by (auto simp: isasat_find_unset_lit_st_impl_def)
  done

(* Make the lifted rule find_unset_lit.refine available as a sepref refinement rule.
   Register isasat_find_unset_lit_st_impl as LLVM code, with read_all_st_code
   unfolded and inline_direct_return_node_case applied. *)
lemmas [sepref_fr_rules] = find_unset_lit.refine
lemmas [unfolded inline_direct_return_node_case, llvm_code] =
  isasat_find_unset_lit_st_impl_def[unfolded read_all_st_code_def]

(* LLVM implementation of mop_arena_swap. Three sint64_nat_assn arguments are read;
   in this theory they are a clause index followed by two positions in that clause.
   The arena is consumed and returned updated. gen_swap is unfolded before sepref
   runs. *)
sepref_def swap_lits_impl is uncurry3 mop_arena_swap
  :: sint64_nat_assnk *a sint64_nat_assnk *a sint64_nat_assnk *a arena_fast_assnd a arena_fast_assn
  unfolding mop_arena_swap_def swap_lits_pre_def
  unfolding gen_swap
  by sepref

(* Make the state-level search isa_find_unset_lit_st available to sepref. *)
sepref_register isa_find_unset_lit_st

(* Turns a case distinction on an optional boolean (None / Some True / Some False)
   into an if-cascade that compares the value with UNSET and SET_TRUE. It is
   unfolded in find_unwatched_wl_st_heur_fast_code. *)
lemma case_tri_bool_If:
  (case a of
       None  f1
     | Some v 
        (if v then f2 else f3)) =
   (let b = a in if b = UNSET
    then f1
    else if b = SET_TRUE then f2 else f3)
  by (auto split: option.splits)

(* LLVM implementation of isa_find_unwatched_wl_st_heur for bounded states (arena
   length bounded by snat64_max).
   - Inputs: the state and the 64-bit clause index, both only read.
   - Result: an optional 64-bit index.
   After unfolding, the two subst steps fold the inlined search back, first into
   isa_find_unset_lit and then into the state-level isa_find_unset_lit_st.
   Presumably this is so that sepref uses the lifted rule find_unset_lit.refine. *)
sepref_def find_unwatched_wl_st_heur_fast_code
  is uncurry isa_find_unwatched_wl_st_heur
  :: [λ(S, C). length (get_clauses_wl_heur S)  snat64_max]a
         isasat_bounded_assnk *a sint64_nat_assnk  snat_option_assn' TYPE(64)
  supply [[goals_limit = 1]] isasat_fast_def[simp]
  unfolding isa_find_unwatched_wl_st_heur_def PR_CONST_def
    fmap_rll_def[symmetric]
    length_uint32_nat_def[symmetric] isa_find_unwatched_def
    case_tri_bool_If
    fmap_rll_u64_def[symmetric]
    mop_arena_length_st_def[symmetric]
    mop_arena_pos_st_def[symmetric]
  apply (subst isa_find_unset_lit_def[symmetric])+
  apply (subst isa_find_unset_lit_st_def[symmetric])+
  apply (annot_snat_const TYPE (64))
  by sepref

(* other_watched_wl_heur depends only on the arena of the state: it coincides with
   arena_other_watched applied to get_clauses_wl_heur S. *)
lemma other_watched_wl_heur_alt_def:
  other_watched_wl_heur = (λS. arena_other_watched (get_clauses_wl_heur S))
  apply (intro ext)
  unfolding other_watched_wl_heur_def
    arena_other_watched_def
    mop_access_lit_in_clauses_heur_def
  by auto argo

(* Implementation-level check on clause C'. It reads the arena of the state through
   read_arena_wl_heur_code and applies not_deleted_code to that arena and C'. *)
definition clause_not_marked_to_delete_heur_code :: twl_st_wll_trail_fast2  _  _ where
  clause_not_marked_to_delete_heur_code S C' = read_arena_wl_heur_code (λN. not_deleted_code N C') S
(*mop_arena_lit2  mop_access_lit_in_clauses_heur*)

(* LLVM implementation of other_watched_wl_heur.
   - Inputs: the state (only read), a literal and two 64-bit indices.
   - Output: a literal.
   Through the alternative definition, only the arena is accessed.
   mop_access_lit_in_clauses_heur is folded back, presumably so that its existing
   implementation is reused; confirm. *)
sepref_def other_watched_wl_heur_impl
  is uncurry3 other_watched_wl_heur
  :: isasat_bounded_assnk *a unat_lit_assnk *a sint64_nat_assnk *a sint64_nat_assnk a
    unat_lit_assn
  supply [[goals_limit=1]]
  unfolding other_watched_wl_heur_alt_def
    arena_other_watched_def
    mop_access_lit_in_clauses_heur_def[symmetric]
  apply (annot_snat_const TYPE (64))
  by sepref

(* Register update_clause_wl_heur with sepref. *)
sepref_register update_clause_wl_heur
(* For the rest of this theory, remove the classical reasoner's "split_all_tac"
   wrapper, which splits variables of product type. Presumably this stops auto
   from taking apart the tuples in the following proofs. *)
setup map_theory_claset (fn ctxt => ctxt delSWrapper "split_all_tac")

(* Bound on a literal position: from arena_lit_pre a i and the bound of the arena
   length by snat64_max, it follows that i < max_snat 64. Supplied as an intro rule
   when refining update_clause_wl_heur. *)
lemma arena_lit_pre_le2: 
       arena_lit_pre a i  length a  snat64_max  i < max_snat 64
   using arena_lifting(7)[of a _ _] unfolding arena_lit_pre_def arena_is_valid_clause_idx_and_access_def snat64_max_def max_snat_def
  by fastforce

(* If a < snat64_max then Suc a < max_snat 64: incrementing an index that is below
   snat64_max stays below max_snat 64. Supplied as an intro rule to the refinements
   that return w+1 or j+1. *)
lemma snat64_max_le_max_snat64: a < snat64_max  Suc a < max_snat 64
  by (auto simp: max_snat_def snat64_max_def)

(* Reformulation of update_clause_wl_heur L L' C b j w i f S0 on extracted state
   components. It:
   1. reads the literal K' at position f of clause C (mop_arena_lit2' with the
      vdom of the state);
   2. swaps positions i and f of C in the arena;
   3. appends (C, L, b) to the watch list of K';
   4. writes arena and watch lists back and returns j, w+1 and the new state.
   The argument L' is not used by this formulation. *)
lemma update_clause_wl_heur_alt_def:
  update_clause_wl_heur = (λ(L::nat literal) L' C b j w i f S0. do {
     let (N, S) = extract_arena_wl_heur S0;
     ASSERT (N = get_clauses_wl_heur S0);
     let (W, S) = extract_watchlist_wl_heur S;
     ASSERT (W = get_watched_wl_heur S0);
     K'  mop_arena_lit2' (set (get_vdom S)) N C f;
     ASSERT(w < length N);
     N'  mop_arena_swap C i f N;
     ASSERT(nat_of_lit K' < length W);
     ASSERT(length (W ! (nat_of_lit K')) < length N);
     let W = W[nat_of_lit K':= W ! (nat_of_lit K') @ [(C, L, b)]];
     let S = update_arena_wl_heur N' S;
     let S = update_watchlist_wl_heur W S;
     RETURN (j, w+1, S)
   })
   by (auto intro!: ext simp: state_extractors update_clause_wl_heur_def
         split: isasat_int_splits)

(* LLVM implementation of update_clause_wl_heur for bounded states (arena length
   bounded by snat64_max).
   - Arguments, in the order of the alternative definition: two literals, the
     clause index, a boolean, four 64-bit indices and the consumed state.
   - Result: (j, w+1, state).
   arena_lit_pre_le2 and snat64_max_le_max_snat64 are supplied as intro rules,
   presumably to discharge the 64-bit range goals for positions and the increment. *)
sepref_def update_clause_wl_fast_code
  is uncurry8 update_clause_wl_heur
  :: [λ((((((((L, L'), C), b), j), w), i), f), S). length (get_clauses_wl_heur S)  snat64_max]a
     unat_lit_assnk *a unat_lit_assnk *a sint64_nat_assnk *a bool1_assnk *a sint64_nat_assnk *a sint64_nat_assnk *a sint64_nat_assnk *a
       sint64_nat_assnk
        *a isasat_bounded_assnd  sint64_nat_assn ×a sint64_nat_assn ×a isasat_bounded_assn
  supply [[goals_limit=1]]  arena_lit_pre_le2[intro] swap_lits_pre_def[simp]
    snat64_max_le_max_snat64[intro]
  unfolding update_clause_wl_heur_alt_def
    fmap_rll_def[symmetric] delete_index_and_swap_update_def[symmetric]
    delete_index_and_swap_ll_def[symmetric] fmap_swap_ll_def[symmetric]
    append_ll_def[symmetric] update_clause_wl_code_pre_def
    fmap_rll_u64_def[symmetric]
    fmap_swap_ll_u64_def[symmetric]
    fmap_swap_ll_def[symmetric]
    PR_CONST_def mop_arena_lit2'_def
  apply (annot_snat_const TYPE (64))
  by sepref

(* Register mop_arena_swap with sepref. *)
sepref_register mop_arena_swap

(* Body of propagate_lit_wl_heur on the 17 unpacked components of the state. It:
   1. asserts the bound on i;
   2. pushes L' onto the trail as propagated with reason C;
   3. swaps positions 0 and 1 - i of clause C in the arena;
   4. saves the phase of L' (atm_of L', is_pos L') in heur;
   5. repacks the components with Tuple17. *)
definition propagate_lit_wl_heur_inner :: _ where
  propagate_lit_wl_heur_inner L' C i =  (λM N D j W ivmtf icount ccach lbd outl stats heur aivdom clss opts arena occs. do {
      ASSERT(i  1);
      M  cons_trail_Propagated_tr L' C M;
      N'  mop_arena_swap C 0 (1 - i) N;
      heur  mop_save_phase_heur (atm_of L') (is_pos L') heur;
      RETURN (Tuple17 M N' D j W ivmtf icount ccach lbd outl stats heur aivdom clss opts arena occs)
  })

(* propagate_lit_wl_heur unpacks the state with case_isasat_int and runs
   propagate_lit_wl_heur_inner on its components. *)
lemma propagate_lit_wl_heur_propagate_lit_wl_heur_inner:
  propagate_lit_wl_heur = (λL' C i (S0::isasat).
  case_isasat_int (propagate_lit_wl_heur_inner L' C i)
   S0)
  by (auto intro!: ext simp: state_extractors propagate_lit_wl_heur_def read_all_st_def
    propagate_lit_wl_heur_inner_def
    split: isasat_int_splits)

(* LLVM implementation of propagate_lit_wl_heur for bounded states (arena length
   bounded by snat64_max). A literal and two 64-bit indices are read; the state is
   consumed and returned updated. The proof works on the unpacked form
   propagate_lit_wl_heur_inner.
   NOTE(review): RETURN_case_tuple16_invers is unfolded although the state is built
   with Tuple17; confirm this is the intended lemma. *)
sepref_def propagate_lit_wl_fast_code
  is uncurry3 propagate_lit_wl_heur
  :: [λ(((L, C), i), S). length (get_clauses_wl_heur S)  snat64_max]a
      unat_lit_assnk *a sint64_nat_assnk *a sint64_nat_assnk *a isasat_bounded_assnd  isasat_bounded_assn
  unfolding PR_CONST_def propagate_lit_wl_heur_propagate_lit_wl_heur_inner
    RETURN_case_tuple16_invers comp_def propagate_lit_wl_heur_inner_def
  unfolding
    fmap_swap_ll_def[symmetric]
    fmap_swap_ll_u64_def[symmetric]
    save_phase_def
  apply (annot_snat_const TYPE (64))
  supply [[goals_limit=1]]
  by sepref
(* Inline Mreturn_comp_IsaSAT_int when generating LLVM code. *)
lemmas [llvm_inline] = Mreturn_comp_IsaSAT_int


(* Register incr_uset and incr_units_since_last_GC with sepref. *)
sepref_register incr_uset incr_units_since_last_GC


(* Unpacked form of propagate_lit_wl_bin_heur. It is like propagate_lit_wl_heur_inner
   but has no assertion and does not touch the arena: it only pushes L' onto the
   trail with reason C and saves the phase of L'. Unfolded by
   propagate_lit_wl_bin_fast_code. *)
lemma propagate_lit_wl_bin_heur_alt2:
  propagate_lit_wl_bin_heur = (λL' C (S0::isasat).
  case_isasat_int (λM N D j W ivmtf icount ccach lbd outl stats heur aivdom clss opts arena occs. do {
      M  cons_trail_Propagated_tr L' C M;
      heur  mop_save_phase_heur (atm_of L') (is_pos L') heur;
      RETURN (Tuple17 M N D j W ivmtf icount ccach lbd outl stats heur aivdom clss opts arena occs)
  })
   S0)
  by (auto intro!: ext simp: state_extractors propagate_lit_wl_bin_heur_def read_all_st_def
    propagate_lit_wl_heur_inner_def
    split: isasat_int_splits)


(* Reformulation of propagate_lit_wl_bin_heur with explicit extraction and update of
   the trail and heuristic components: L' is pushed onto the trail with reason C
   and the phase of L' is saved.
   NOTE(review): propagate_lit_wl_bin_fast_code unfolds propagate_lit_wl_bin_heur_alt2
   instead; this lemma is not used in this theory. *)
lemma propagate_lit_wl_bin_heur_alt_def:
  propagate_lit_wl_bin_heur = (λL' C S0. do {
      let (M, S) = extract_trail_wl_heur S0;
      ASSERT (M = get_trail_wl_heur S0);
      let (heur, S) = extract_heur_wl_heur S;
      ASSERT (heur = get_heur S0);
      M  cons_trail_Propagated_tr L' C M;
      heur  mop_save_phase_heur (atm_of L') (is_pos L') heur;
      let S = update_trail_wl_heur M S;
      let S = update_heur_wl_heur heur S;
      RETURN S
  })
  by (auto intro!: ext simp: state_extractors propagate_lit_wl_bin_heur_def
      split: isasat_int_splits)

(* LLVM implementation of propagate_lit_wl_bin_heur for bounded states (arena length
   bounded by snat64_max). A literal and a 64-bit clause index are read; the state
   is consumed and returned updated. *)
sepref_def propagate_lit_wl_bin_fast_code
  is uncurry2 propagate_lit_wl_bin_heur
  :: [λ((L, C), S). length (get_clauses_wl_heur S)  snat64_max]a
      unat_lit_assnk *a sint64_nat_assnk *a isasat_bounded_assnd 
      isasat_bounded_assn
  unfolding PR_CONST_def propagate_lit_wl_bin_heur_alt2
    RETURN_case_tuple16_invers comp_def
  supply [[goals_limit=1]]  length_ll_def[simp]
  by sepref

(* Reformulation of update_blit_wl_heur L C b j w K S0. It overwrites entry j of the
   watch list of L with (C, K, b) and returns j+1, w+1 and the updated state.
   The assertions bound j and w by the arena length, presumably so that the
   increments stay in 64-bit range (see snat64_max_le_max_snat64). *)
lemma update_blit_wl_heur_alt_def:
  update_blit_wl_heur = (λ(L::nat literal) C b j w K S0. do {
     let (W, S) = extract_watchlist_wl_heur S0;
     ASSERT (W = get_watched_wl_heur S0);
     ASSERT(nat_of_lit L < length W);
     ASSERT(j < length (W ! nat_of_lit L));
     ASSERT(j < length (get_clauses_wl_heur S0));
     ASSERT(w < length (get_clauses_wl_heur S0));
     let W = W[nat_of_lit L := (W!nat_of_lit L)[j:= (C, K, b)]];
     RETURN (j+1, w+1, update_watchlist_wl_heur W S)
  })
  by (auto intro!: ext simp: state_extractors update_blit_wl_heur_def
    split: isasat_int_splits)

(* LLVM implementation of update_blit_wl_heur for bounded states.
   - Arguments: a literal, a 64-bit index, a boolean, two 64-bit indices, a
     literal and the consumed state.
   - Result: (j+1, w+1, state).
   NOTE(review): in the precondition pattern, C and i are bound to the fifth and
   sixth arguments (w and K in the alternative definition). Only S is used, so this
   is harmless, but the names are misleading. *)
sepref_def update_blit_wl_heur_fast_code
  is uncurry6 update_blit_wl_heur
  :: [λ((((((_, _), _), _), C), i), S). length (get_clauses_wl_heur S)  snat64_max]a
      unat_lit_assnk *a sint64_nat_assnk *a bool1_assnk *a sint64_nat_assnk *a
      sint64_nat_assnk *a unat_lit_assnk *a isasat_bounded_assnd 
     sint64_nat_assn ×a sint64_nat_assn ×a isasat_bounded_assn
  supply [[goals_limit=1]] snat64_max_le_max_snat64[intro]
  unfolding update_blit_wl_heur_alt_def append_ll_def[symmetric]
    op_list_list_upd_alt_def[symmetric]
  apply (annot_snat_const TYPE (64))
  by sepref


(* Register keep_watch_heur with sepref. *)
sepref_register keep_watch_heur

(* op_list_list_take xss i l replaces the i-th inner list of xss by its first l
   elements. *)
lemma op_list_list_take_alt_def: op_list_list_take xss i l = xss[i := take l (xss ! i)]
  unfolding op_list_list_take_def by auto

(* Reformulation of keep_watch_heur L i j S. After asserting that both indices are
   within the watch list of L, it copies entry j of that watch list to position i
   and writes the watch lists back into the state. *)
lemma keep_watch_heur_alt_def:
  keep_watch_heur = (λL i j S. do {
     let (W, S) = extract_watchlist_wl_heur S;
     ASSERT(nat_of_lit L < length W);
     ASSERT(i < length (W ! nat_of_lit L));
     ASSERT(j < length (W ! nat_of_lit L));
     let W =  W[nat_of_lit L := (W!(nat_of_lit L))[i := W ! (nat_of_lit L) ! j]];
     RETURN (update_watchlist_wl_heur W S)
   })
  by (auto intro!: ext simp: state_extractors keep_watch_heur_def
    split: isasat_int_splits)

(* LLVM implementation of keep_watch_heur. A literal and two 64-bit indices are
   read; the state is consumed and returned updated. The nested-list update and
   read are folded into op_list_list_upd and nth_rll, presumably so that sepref
   picks their list-of-lists implementations; confirm. *)
sepref_def keep_watch_heur_fast_code
  is uncurry3 keep_watch_heur
  :: unat_lit_assnk *a sint64_nat_assnk *a sint64_nat_assnk *a isasat_bounded_assnd a isasat_bounded_assn
  supply
    [[goals_limit=1]]
  unfolding keep_watch_heur_alt_def PR_CONST_def
  unfolding fmap_rll_def[symmetric]
  unfolding
    op_list_list_upd_alt_def[symmetric]
    nth_rll_def[symmetric]
    SET_FALSE_def[symmetric] SET_TRUE_def[symmetric]
  by sepref


(* Register the inner-loop body with sepref. *)
sepref_register unit_propagation_inner_loop_body_wl_heur

(* Register the conflict-handling operations with sepref. *)
sepref_register isa_set_lookup_conflict_aa set_conflict_wl_heur mark_conflict_to_rescore

(* Reformulation of mark_conflict_to_rescore C S0 on extracted components:
   1. for every position i < n of clause C (n = its arena length), read the
      literal L at i and call isa_vmtf_bump_to_rescore_also_reasons_cl on -L;
   2. call calculate_LBD_heur_st on C, which returns an updated arena and LBD
      component;
   3. write trail, arena, vmtf and LBD components back. *)
lemma mark_conflict_to_rescore_alt_def:
  mark_conflict_to_rescore C  S0 = do {
    let (M, S) = extract_trail_wl_heur S0;
    let (N, S) = extract_arena_wl_heur S;
    let (vm, S) = extract_vmtf_wl_heur S;
    n  mop_arena_length N C;
    ASSERT (n  length (get_clauses_wl_heur S0));
    (_, vm)  WHILET (λ(i, vm). i < n)
      (λ(i, vm). do{
       ASSERT (i < n);
       L  mop_arena_lit2 N C i;
       vm  isa_vmtf_bump_to_rescore_also_reasons_cl M N C (-L) vm;
      RETURN (i+1, vm)
     })
      (0, vm);
    let (lbd, S) = extract_lbd_wl_heur S;
    (N, lbd)  calculate_LBD_heur_st M N lbd C;
    let S = update_trail_wl_heur M S;
    let S = update_arena_wl_heur N S;
    let S = update_vmtf_wl_heur vm S;
    let S = update_lbd_wl_heur lbd S;
    RETURN S }
  by (auto intro!: ext simp: state_extractors mark_conflict_to_rescore_def Let_def
    split: isasat_int_splits)

(*TODO Move*)
sepref_register isa_vmtf_bump_to_rescore_also_reasons_cl
(* LLVM implementation of mark_conflict_to_rescore. The 64-bit clause index is
   read; the state is consumed and returned updated. *)
sepref_def mark_conflict_to_rescore_impl
  is uncurry mark_conflict_to_rescore
  :: sint64_nat_assnk *a isasat_bounded_assnd a isasat_bounded_assn
  unfolding mark_conflict_to_rescore_alt_def
  apply (annot_snat_const TYPE (64))
  by sepref

(* Reformulation of set_conflict_wl_heur C S0. It:
   1. extracts trail, arena, conflict and outl;
   2. builds the conflict from clause C with isa_set_lookup_conflict_aa, starting
      from n = 0;
   3. sets literals_to_update to the current trail length;
   4. writes conflict, outl, clvls, literals_to_update, trail and arena back. *)
lemma set_conflict_wl_heur_alt_def:
  set_conflict_wl_heur = (λC S0. do {
    let n = 0;
    let (M, S) = extract_trail_wl_heur S0;
    let (N, S) = extract_arena_wl_heur S;
    let (D, S) = extract_conflict_wl_heur S;
    let (outl, S) = extract_outl_wl_heur S;
    ASSERT(curry5 isa_set_lookup_conflict_aa_pre M N C D n outl);
    (D, clvls, outl)  isa_set_lookup_conflict_aa M N C D n outl;
    j  mop_isa_length_trail M;
    let S = update_conflict_wl_heur D S;
    let S = update_outl_wl_heur outl S;
    let S = update_clvls_wl_heur clvls S;
    let S = update_literals_to_update_wl_heur j S;
    let S = update_trail_wl_heur M S;
    let S = update_arena_wl_heur N S;
    RETURN S})
    by (auto intro!: ext bind_cong[OF refl] simp: state_extractors set_conflict_wl_heur_def Let_def
    split: isasat_int_splits)

(* LLVM implementation of set_conflict_wl_heur for bounded states (arena length
   bounded by snat64_max). The 64-bit clause index is read; the state is consumed
   and returned updated. annot_unat_const TYPE (32) annotates numeral constants as
   32-bit unsigned words (presumably the initial n = 0; confirm). *)
sepref_def set_conflict_wl_heur_fast_code
  is uncurry set_conflict_wl_heur
  :: [λ(C, S). length (get_clauses_wl_heur S)  snat64_max]a
    sint64_nat_assnk *a isasat_bounded_assnd  isasat_bounded_assn
  supply [[goals_limit=1]]
  unfolding set_conflict_wl_heur_alt_def
  apply (annot_unat_const TYPE (32))
  by sepref



(* Register update_blit_wl_heur and clause_not_marked_to_delete_heur with sepref. *)
sepref_register update_blit_wl_heur clause_not_marked_to_delete_heur

(* From the loop invariant unit_propagation_inner_loop_wl_loop_D_heur_inv0 on
   (j, w, S0), both j and w are bounded by the arena length minus MIN_HEADER_SIZE.
   Supplied as a destruction rule to unit_propagation_inner_loop_body_wl_fast_heur_code. *)
lemma unit_propagation_inner_loop_wl_loop_D_heur_inv0D:
  unit_propagation_inner_loop_wl_loop_D_heur_inv0 L (j, w, S0) 
    j  length (get_clauses_wl_heur S0) - MIN_HEADER_SIZE 
    w  length (get_clauses_wl_heur S0) - MIN_HEADER_SIZE
  unfolding unit_propagation_inner_loop_wl_loop_D_heur_inv0_def prod.case
    unit_propagation_inner_loop_wl_loop_inv_def unit_propagation_inner_loop_l_inv_def
  apply normalize_goal+
  by (simp only: twl_st_l twl_st twl_st_wl
     all_all_atms_all_lits) linarith


(* LLVM implementation of pos_of_watched_heur.
   - Inputs: the state (only read), a 64-bit clause index and a literal.
   - Output: a 64-bit index. *)
sepref_def pos_of_watched_heur_impl
  is uncurry2 pos_of_watched_heur
  :: isasat_bounded_assnk *a sint64_nat_assnk *a unat_lit_assnk a sint64_nat_assn
  supply [[goals_limit=1]]
  unfolding pos_of_watched_heur_def
  apply (annot_snat_const TYPE (64))
  by sepref

(* LLVM implementation of the inner propagation loop body
   (unit_propagation_inner_loop_body_wl_heur) for bounded states.
   - Inputs: a literal and two 64-bit indices (read), and the state (consumed).
   - Result: a pair of 64-bit indices together with the new state.
   NOTE(review): the precondition pattern ((L, w), S) is one level shallower than
   the four uncurried arguments, so L is bound to a pair. This is harmless because
   only S is used, but the names are misleading. *)
sepref_def unit_propagation_inner_loop_body_wl_fast_heur_code
  is uncurry3 unit_propagation_inner_loop_body_wl_heur
  :: [λ((L, w), S). length (get_clauses_wl_heur S)  snat64_max]a
      unat_lit_assnk *a sint64_nat_assnk  *a sint64_nat_assnk *a isasat_bounded_assnd 
       sint64_nat_assn ×a sint64_nat_assn ×a isasat_bounded_assn
  supply [[goals_limit=1]]
    if_splits[split] snat64_max_le_max_snat64[intro] unit_propagation_inner_loop_wl_loop_D_heur_inv0D[dest!]
  unfolding unit_propagation_inner_loop_body_wl_heur_def PR_CONST_def
  unfolding fmap_rll_def[symmetric]
  unfolding option.case_eq_if is_None_alt[symmetric]
    SET_FALSE_def[symmetric] SET_TRUE_def[symmetric] tri_bool_eq_def[symmetric]
  apply (annot_snat_const TYPE (64))
  by sepref

(* Definitions to inline when generating LLVM code.
   NOTE(review): propagate_lit_wl_heur_def is the abstract definition, not an
   _impl/_code definition like the other entries; confirm it is intended here. *)
lemmas [llvm_inline] =
  other_watched_wl_heur_impl_def
  pos_of_watched_heur_impl_def
  propagate_lit_wl_heur_def
  keep_watch_heur_fast_code_def
  nat_of_lit_rel_impl_def

(* Export the LLVM code of the implementations in this theory, plus
   watched_by_app_heur_fast_code and status_neq_impl (which are not defined here).
   The export runs inside an anonymous `experiment` context, so nothing is added to
   the theory; presumably this serves as a check that code export succeeds. *)
experiment begin

export_llvm
  isa_save_pos_fast_code
  watched_by_app_heur_fast_code
  isa_find_unwatched_between_fast_code
  find_unwatched_wl_st_heur_fast_code
  update_clause_wl_fast_code
  propagate_lit_wl_fast_code
  propagate_lit_wl_bin_fast_code
  status_neq_impl
  update_blit_wl_heur_fast_code
  keep_watch_heur_fast_code
  set_conflict_wl_heur_fast_code
  unit_propagation_inner_loop_body_wl_fast_heur_code

end

end