Theory IsaSAT_Simplify_Binaries_LLVM

theory IsaSAT_Simplify_Binaries_LLVM
  imports
    IsaSAT_Simplify_Clause_Units_LLVM
    IsaSAT_Simplify_Binaries_Defs
    IsaSAT_Proofs_LLVM
begin


(* Concrete Sepref assertion for the array hash map: a larray64_assn array of
   signed 64-bit integers paired with an al_assn' array list of 64-bit naturals.
   NOTE(review): presumably the array holds the (encoded) value stored for each
   key and the list records which keys are currently set - confirm against
   IsaSAT_Simplify_Binaries_Defs. *)
abbreviation ahm_assn :: _ where
  ahm_assn  larray64_assn (sint_assn' TYPE(64)) ×a al_assn' TYPE(64) (snat_assn' TYPE(64))

(* LLVM code for ahm_create.  The size argument is a 64-bit natural.  The
   replicate / empty-list folds let sepref use the larray and array-list
   constructors directly.  Integer constants are annotated as signed 64-bit. *)
sepref_def ahm_create_code
  is ahm_create
  :: (snat_assn' TYPE(64))k a ahm_assn
  unfolding ahm_create_def larray_fold_custom_replicate al_fold_custom_empty[where 'l=64]
  apply (annot_sint_const TYPE(64))
  by sepref

(* Encoding of a pair (index, flag) :: nat × bool as a single signed integer a:
   - both a and -a are bounded by snat64_max;
   - the flag (snd) holds exactly when a > 0;
   - the index (fst) is the absolute value of a;
   - the index must be nonzero.
   Consequently the integer 0 encodes no pair at all.
   NOTE(review): the names suggest the index is a clause index and the flag
   marks irredundant clauses - confirm. *)
definition encoded_irred_indices where
  encoded_irred_indices = {(a, b::nat × bool). a  int snat64_max  -a  int snat64_max  (snd b  a > 0)  fst b = (if a < 0 then nat (-a) else nat a)  fst b  0}

(* LLVM code for ahm_is_marked: reads the map (kept) and a literal key and
   returns a boolean. *)
sepref_def ahm_is_marked_code
  is uncurry ahm_is_marked
  :: (ahm_assn)k *a unat_lit_assnk a bool1_assn
  unfolding ahm_is_marked_def
  apply (annot_sint_const TYPE(64))
  by sepref


(* LLVM code for ahm_get_marked: reads the map (kept) and a literal key and
   returns the signed 64-bit value stored for it. *)
sepref_def ahm_get_marked_code
  is uncurry ahm_get_marked
  :: (ahm_assn)k *a unat_lit_assnk a sint_assn' TYPE(64)
  unfolding ahm_get_marked_def
  by sepref

(* LLVM code for ahm_empty: consumes the map (destructive update) and returns
   it.  Both signed and unsigned constants in ahm_empty_def are annotated as
   64-bit. *)
sepref_def ahm_empty_code
  is ahm_empty
  :: (ahm_assn)d a ahm_assn
  unfolding ahm_empty_def
  apply (annot_sint_const TYPE(64))
  apply (annot_snat_const TYPE(64))
  by sepref


(* Abstract accessor: the boolean flag (second component) of an encoded pair. *)
definition encoded_irred_index_irred where
  encoded_irred_index_irred a = snd a

(* Concrete accessor on the integer encoding: the flag is the sign test a > 0. *)
definition encoded_irred_index_irred_int where
  encoded_irred_index_irred_int a = (a > 0)

(* The sign test refines the flag projection for every pair related by
   encoded_irred_indices. *)
lemma encoded_irred_index_irred:
  (encoded_irred_index_irred_int, encoded_irred_index_irred)  encoded_irred_indices  bool_rel
  by (auto simp: encoded_irred_indices_def encoded_irred_index_irred_int_def
    encoded_irred_index_irred_def)

(* Abstract accessor: the natural-number index (first component) of an encoded pair. *)
definition encoded_irred_index_get where
  encoded_irred_index_get a = fst a

(* Concrete accessor on the integer encoding: the absolute value of a.  It
   asserts that both a and -a stay within snat64_max. *)
definition encoded_irred_index_get_int where
  encoded_irred_index_get_int a = do {ASSERT (a  int snat64_max  -a  int snat64_max); RETURN (if a > 0 then nat a else nat (-a))}

(* The absolute-value decoding refines the index projection.  The assertion is
   discharged by the bounds contained in encoded_irred_indices. *)
lemma encoded_irred_index_get:
  (encoded_irred_index_get_int, RETURN o encoded_irred_index_get)  encoded_irred_indices  nat_relnres_rel
  by (auto simp: encoded_irred_indices_def encoded_irred_index_get_int_def
    encoded_irred_index_get_def intro!: nres_relI)

(* LLVM code for the sign test on a signed 64-bit integer.  The constant 0 in
   the definition is annotated as signed 64-bit. *)
sepref_def encoded_irred_index_irred_int_impl
  is RETURN o encoded_irred_index_irred_int
  :: (sint_assn' TYPE(64))k a bool1_assn
  unfolding encoded_irred_index_irred_int_def
  apply (annot_sint_const TYPE(64))
  by sepref

(* For a word whose signed value is non-negative, converting that signed value
   to nat agrees with snat. *)
lemma nat_sint_snat: 0  sint xi  (nat (sint xi) = snat xi)
  by (auto simp: snat_def)

(* Sepref rule for nat.  On a non-negative signed 64-bit integer it is
   implemented by Mreturn (no operation), and the same word is then read as a
   64-bit natural (sint64_nat_assn). *)
lemma [sepref_fr_rules]:
  (Mreturn, RETURN o nat)  [λa. a  0]a (sint_assn' TYPE(64))k  sint64_nat_assn
  apply sepref_to_hoare
  apply vcg
  apply (auto simp: sint_rel_def ENTAILS_def snat_rel_def snat.rel_def br_def sint.rel_def
    pure_true_conv Exists_eq_simp snat_invar_def word_msb_sint nat_sint_snat)
  done
(* Sepref rule for unary minus on signed 64-bit integers, implemented by word
   negation.  The precondition bounds both a and -a by snat64_max.
   NOTE(review): this presumably excludes the single overflowing input (the
   minimal 64-bit value) - confirm. *)
lemma [sepref_fr_rules]:
  (Mreturn o (λx. -x), RETURN o uminus)  [λa. a  int snat64_max  -a  int snat64_max]a (sint_assn' TYPE(64))k  (sint_assn' TYPE(64))
  apply sepref_to_hoare
  apply vcg
  subgoal for x xi asf s
    (* presumably brings signed-range bounds on xi into the goal - confirm *)
    using sdiv_word_min'[of xi 1] sdiv_word_max'[of xi 1]
  apply (auto simp: sint_rel_def ENTAILS_def snat_rel_def snat.rel_def br_def sint.rel_def
    pure_true_conv Exists_eq_simp snat_invar_def word_msb_sint nat_sint_snat
    signed_arith_ineq_checks_to_eq word_size snat64_max_def word_size)
  apply (subst signed_arith_ineq_checks_to_eq[symmetric])
  apply (auto simp: word_size pure_true_conv)
  done
  done

(* Same as encoded_irred_index_get_int, but with the negation written as 0 - a.
   This is the form that is unfolded for the sepref_def below. *)
lemma encoded_irred_index_get_int_alt_def:
  encoded_irred_index_get_int a = do {ASSERT (a  int snat64_max  -a  int snat64_max); RETURN (if a > 0 then nat a else nat (0-a))}
  unfolding encoded_irred_index_get_int_def by auto
(* LLVM code decoding the index: signed 64-bit input, 64-bit natural output. *)
sepref_def encoded_irred_index_irred_get_impl
  is encoded_irred_index_get_int
  :: (sint_assn' TYPE(64))k a sint64_nat_assn
  unfolding encoded_irred_index_get_int_alt_def
  apply (annot_sint_const TYPE(64))
  by sepref

(* Compose both implementations with their refinement lemmas.  Sepref can then
   implement the abstract encoded_irred_index_get / encoded_irred_index_irred
   directly on values related by encoded_irred_indices. *)
lemmas [sepref_fr_rules] =
  encoded_irred_index_irred_get_impl.refine[FCOMP encoded_irred_index_get]
  encoded_irred_index_irred_int_impl.refine[FCOMP encoded_irred_index_irred]


(* Abstract constructor: pair an index with its flag. *)
definition encoded_irred_index_set where
  encoded_irred_index_set a b = (a::nat, b::bool)

(* Concrete constructor: the index as an integer, negated when the flag is False. *)
definition encoded_irred_index_set_int where
  encoded_irred_index_set_int a b = do { (if b then RETURN (int a) else RETURN (- int a))}

(* The signed encoding refines the pair constructor, provided the index is
   nonzero and at most snat64_max.  These are exactly the side conditions that
   encoded_irred_indices imposes. *)
lemma encoded_irred_index_set:
  (uncurry encoded_irred_index_set_int, uncurry (RETURN oo encoded_irred_index_set))  [λ(a,b). a  0  a  snat64_max]f nat_rel ×r bool_rel  encoded_irred_indicesnres_rel
  by (clarsimp simp: encoded_irred_indices_def encoded_irred_index_set_int_def
    encoded_irred_index_set_def  intro!: nres_relI frefI)


(* For a 64-bit word with non-negative signed value, int of its snat value
   equals its sint value. *)
lemma int_snat_sint: ¬ sint xi < 0  int (snat (xi::64 word)) = sint xi
  by (auto simp: snat_def)

(* Sepref rule for int.  A 64-bit natural is reinterpreted as a signed 64-bit
   integer without any computation (Mreturn). *)
lemma [sepref_fr_rules]:
  (Mreturn, RETURN o int)  (snat_assn' TYPE(64))k a (sint_assn' TYPE(64))
  apply sepref_to_hoare
  apply vcg
  apply (auto simp: sint_rel_def ENTAILS_def snat_rel_def snat.rel_def br_def sint.rel_def
    pure_true_conv Exists_eq_simp snat_invar_def word_msb_sint nat_sint_snat int_snat_sint)
  done

(* Register integer negation and the nat-to-int conversion as sepref operations. *)
sepref_register "uminus :: int  int" int
(* Same as encoded_irred_index_set_int, with the negation written as 0 - int a.
   This form is unfolded for the sepref_def below. *)
lemma encoded_irred_index_set_int_alt_def:
  encoded_irred_index_set_int a b = do { (if b then RETURN (int a) else RETURN (0 - int a))}
  by (auto simp: encoded_irred_index_set_int_def)

(* LLVM code for the signed encoding: 64-bit natural index and boolean flag in,
   signed 64-bit integer out. *)
sepref_def encoded_irred_index_set_int_impl
  is uncurry encoded_irred_index_set_int
  :: sint64_nat_assnk *a bool1_assnk a (sint_assn' TYPE(64))
  unfolding encoded_irred_index_set_int_alt_def
  apply (annot_sint_const TYPE(64))
  by sepref

(* Compose with the refinement lemma to obtain a rule for the abstract
   encoded_irred_index_set (under its nonzero / bounded-index precondition). *)
lemmas [sepref_fr_rules] =
  encoded_irred_index_set_int_impl.refine[FCOMP encoded_irred_index_set]

(* Register the abstract map operations for sepref. *)
sepref_register is_marked set_marked update_marked

(* LLVM code for ahm_set_marked.  It consumes the map and takes a literal key and
   a signed 64-bit value; the updated map is returned. *)
sepref_def ahm_set_marked_code
  is uncurry2 ahm_set_marked
  :: ahm_assnd *a unat_lit_assnk *a (sint_assn' TYPE(64))k a ahm_assn
  unfolding ahm_set_marked_def
  by sepref

(* LLVM code for ahm_update_marked.  Same signature as ahm_set_marked_code. *)
sepref_def ahm_update_marked_code
  is uncurry2 ahm_update_marked
  :: ahm_assnd *a unat_lit_assnk *a (sint_assn' TYPE(64))k a ahm_assn
  unfolding ahm_update_marked_def
  by sepref

(* Assertion for the abstract hash map.  It composes the concrete representation
   with array_hash_map_rel, where values are related by encoded_irred_indices.
   NOTE(review): the second component is written Size_Ordering_it.arr_assn here
   but al_assn' TYPE(64) (snat_assn' TYPE(64)) in ahm_assn.  Folding the
   composed rules into ahm_full_assn presumably relies on the two coinciding -
   confirm. *)
definition ahm_full_assn :: _ where
  ahm_full_assn =  hr_comp (larray64_assn (sint_assn' TYPE(64)) ×a Size_Ordering_it.arr_assn)
                 (array_hash_map_rel encoded_irred_indices)

(* Synthesized deallocation rule so that sepref can free ahm_full_assn values. *)
schematic_goal ahm_full_assn_assn[sepref_frame_free_rules]: MK_FREE ahm_full_assn ?a
  unfolding ahm_full_assn_def by synthesize_free

(* Specialisation of the generic ahm_set_marked_set_marked to
   R = encoded_irred_indices.  Inside the proof, the name still refers to the
   imported generic lemma, because this one is not yet bound.
   H states the relationship between the integer 0 and encoded_irred_indices.
   For a = 0 the relation would force the index to be 0, yet it also requires
   the index to be nonzero, so 0 encodes no pair.
   NOTE(review): presumably H discharges the generic lemma's side condition
   that 0 is the reserved "unset" value - confirm. *)
lemma ahm_set_marked_set_marked:
 (uncurry2 ahm_set_marked, uncurry2 set_marked)
      (array_hash_map_rel encoded_irred_indices) ×f nat_lit_lit_rel ×f encoded_irred_indices  array_hash_map_rel encoded_irred_indicesnres_rel
proof -
  have H: (0, a)  encoded_irred_indices for a
    by (auto simp: encoded_irred_indices_def)
  show ?thesis
    unfolding fref_param1
    apply (rule ahm_set_marked_set_marked)
    apply (rule H)
    done
qed

(* Same specialisation for update_marked, with the same side condition H. *)
lemma ahm_update_marked_update_marked:
 (uncurry2 ahm_update_marked, uncurry2 update_marked)
      (array_hash_map_rel encoded_irred_indices) ×f nat_lit_lit_rel ×f encoded_irred_indices  array_hash_map_rel encoded_irred_indicesnres_rel
proof -
  have H: (0, a)  encoded_irred_indices for a
    by (auto simp: encoded_irred_indices_def)
  show ?thesis
    unfolding fref_param1
    apply (rule ahm_update_marked_update_marked)
    apply (rule H)
    done
qed

(* Compose each LLVM implementation of the hash-map operations with its
   refinement lemma.  The generic create/empty/is_marked/get_marked lemmas are
   instantiated at R = encoded_irred_indices, and the set/update lemmas are the
   specialised ones proved above.  The resulting hr_comp assertion is folded
   into ahm_full_assn, and the rules are registered for sepref.  Each operation
   is registered exactly once. *)
lemmas [unfolded ahm_full_assn_def[symmetric], sepref_fr_rules] =
  ahm_create_code.refine[FCOMP ahm_create_create[where R = encoded_irred_indices]]
  ahm_empty_code.refine[FCOMP ahm_empty_empty[where R = encoded_irred_indices]]
  ahm_is_marked_code.refine[FCOMP ahm_is_marked_is_marked[where R = encoded_irred_indices]]
  ahm_get_marked_code.refine[FCOMP ahm_get_marked_get_marked[where R = encoded_irred_indices]]
  ahm_set_marked_code.refine[FCOMP ahm_set_marked_set_marked]
  ahm_update_marked_code.refine[FCOMP ahm_update_marked_update_marked]


(* Register the map constructor, the pair encoding accessors and literal
   negation (under the name uminus_lit) as sepref operations. *)
sepref_register create encoded_irred_index_set encoded_irred_index_get
sepref_register uminus_lit:  "uminus :: nat literal  _"

(* Restatement of isa_clause_remove_duplicate_clause_wl using explicit
   extract_*_wl_heur / update_*_wl_heur steps, so that sepref handles each
   state component separately.  The steps are:
   - log the deletion of C (log_del_clause_heur);
   - mark C for deletion in the arena (extra_information_mark_to_delete);
   - apply clss_size_decr_lcount to the learned count when C's status is not IRRED;
   - apply incr_binary_red_removed to the statistics, preceded by
     decr_irred_clss when the status is IRRED. *)
lemma isa_clause_remove_duplicate_clause_wl_alt_def:
  isa_clause_remove_duplicate_clause_wl C S = (do{
    _  log_del_clause_heur S C;
    let (N', S) = extract_arena_wl_heur S;
    st  mop_arena_status N' C;
    let st = st = IRRED;
    ASSERT (mark_garbage_pre (N', C)  arena_is_valid_clause_vdom (N') C);
    let N' = extra_information_mark_to_delete (N') C;
    let (lcount, S) = extract_lcount_wl_heur S;
    ASSERT(¬st  clss_size_lcount lcount  1);
    let lcount = (if st then lcount else (clss_size_decr_lcount lcount));
    let (stats, S) = extract_stats_wl_heur S;
    let stats = incr_binary_red_removed (if st then decr_irred_clss stats else stats);
    let S = update_arena_wl_heur N' S;
    let S = update_lcount_wl_heur lcount S;
    let S = update_stats_wl_heur stats S;
    RETURN S
   })
    by (auto simp: isa_clause_remove_duplicate_clause_wl_def
        state_extractors split: isasat_int_splits)


(* LLVM code.  The clause index is a 64-bit natural and the state is consumed.
   The precondition bounds the arena length by snat64_max and the learned
   clause count by unat64_max. *)
sepref_def isa_clause_remove_duplicate_clause_wl_impl
  is uncurry isa_clause_remove_duplicate_clause_wl
  :: [λ(L, S). length (get_clauses_wl_heur S)  snat64_max  learned_clss_count S  unat64_max]a
  sint64_nat_assnk *a isasat_bounded_assnd  isasat_bounded_assn
  supply [[goals_limit=1]]
  unfolding isa_clause_remove_duplicate_clause_wl_alt_def
  by sepref

sepref_register isa_binary_clause_subres_wl

sepref_register incr_binary_unit_derived
(* Restatement of isa_binary_clause_subres_wl with explicit state extraction.
   The steps are:
   - L is added to the trail via cons_trail_Propagated_tr L 0;
   - clause C is marked for deletion in the arena;
   - when C's status is not IRRED, the learned count goes through
     clss_size_decr_lcount followed by clss_size_incr_lcountUEk;
   - the statistics record incr_binary_unit_derived (after decr_irred_clss
     when IRRED), incr_uset and incr_units_since_last_GC;
   - L is logged as a unit clause.
   The proof first unfolds the Let around log_unit_clause L, then simplifies
   with the state extractors. *)
lemma isa_binary_clause_subres_wl_alt_def:
  isa_binary_clause_subres_wl C L L' S0 = do {
      ASSERT (isa_binary_clause_subres_lits_wl_pre C L L' S0);
      let (M, S) = extract_trail_wl_heur S0;
      M  cons_trail_Propagated_tr L 0 M;
      let (lcount, S) = extract_lcount_wl_heur S;
      ASSERT (lcount = get_learned_count S0);
      let (N', S) = extract_arena_wl_heur S;
      st  mop_arena_status N' C;
      let st = st = IRRED;
      ASSERT (mark_garbage_pre (N', C)  arena_is_valid_clause_vdom (N') C);
      let N' = extra_information_mark_to_delete (N') C;
      ASSERT(¬st  (clss_size_lcount lcount  1  clss_size_lcountUEk (clss_size_decr_lcount lcount) < learned_clss_count S0));
      let lcount = (if st then lcount else (clss_size_incr_lcountUEk (clss_size_decr_lcount lcount)));
      let (stats, S) = extract_stats_wl_heur S;
      let stats = incr_binary_unit_derived (if st then decr_irred_clss stats else stats);
      let stats = incr_units_since_last_GC (incr_uset stats);
      let S = update_trail_wl_heur M S;
      let S = update_arena_wl_heur N' S;
      let S = update_lcount_wl_heur lcount S;
      let S = update_stats_wl_heur stats S;
      let _ = log_unit_clause L;
      RETURN S
  }
  by (subst Let_def[of log_unit_clause L])
   (simp add: isa_binary_clause_subres_wl_def learned_clss_count_def
        state_extractors split: isasat_int_splits)

(* LLVM code.  Inputs are a 64-bit natural clause index, two literals and the
   consumed state; the bounds are the same as for the duplicate-removal code.
   Natural-number constants (presumably including the 0 passed to
   cons_trail_Propagated_tr) are annotated as 64-bit. *)
sepref_def isa_binary_clause_subres_wl_impl
  is uncurry3 isa_binary_clause_subres_wl
  :: [λ(((C,L), L'), S). length (get_clauses_wl_heur S)  snat64_max  learned_clss_count S  unat64_max]a
  sint64_nat_assnk *a unat_lit_assnk *a unat_lit_assnk *a isasat_bounded_assnd  isasat_bounded_assn
  supply [[goals_limit=1]]
  unfolding isa_binary_clause_subres_wl_alt_def[abs_def]
  apply (annot_snat_const TYPE(64))
  by sepref

sepref_register should_eliminate_pure_st

(* LLVM code for should_eliminate_pure_st.  It reads the state without consuming
   it and returns a boolean. *)
sepref_def should_eliminate_pure_st_impl
  is RETURN o should_eliminate_pure_st
  :: isasat_bounded_assnk a bool1_assn
  unfolding should_eliminate_pure_st_def
  by sepref

(* LLVM code for isa_deduplicate_binary_clauses_wl on a literal L.  Both the map
   CS (ahm_full_assn) and the state are consumed and returned as a pair.
   The [symmetric] unfoldings fold raw terms back into named operations
   (mop_polarity_st_heur, mop_arena_status_st, tri_bool_eq and the
   encoded_irred_index_* accessors), presumably so that sepref uses their
   registered rules - confirm. *)
sepref_def isa_deduplicate_binary_clauses_wl_code
  is uncurry2 isa_deduplicate_binary_clauses_wl
  :: [λ((L, CS), S). length (get_clauses_wl_heur S)  snat64_max  learned_clss_count S  unat64_max]a
  unat_lit_assnk *a ahm_full_assnd *a isasat_bounded_assnd 
  ahm_full_assn ×a isasat_bounded_assn
  supply [[goals_limit=1]]
  unfolding isa_deduplicate_binary_clauses_wl_def
    mop_polarity_st_heur_def[symmetric]
    mop_arena_status_st_def[symmetric]
    tri_bool_eq_def[symmetric]
    encoded_irred_index_set_def[symmetric]
    encoded_irred_index_irred_def[symmetric]
    encoded_irred_index_get_def[symmetric]
  apply (annot_snat_const TYPE(64))
  by sepref


(* Register the heuristic accessors and the deduplication operation for sepref. *)
sepref_register get_bump_heur_array_nth get_vmtf_heur_fst
  isa_deduplicate_binary_clauses_wl

(* Binding a monadic result to one variable is the same as binding it through a
   pair pattern.  It is used to split the (CS, S) results in the alt_def proof
   that follows. *)
lemma Massign_split: do{ x  (M :: _ nres); f x} = do{(a,b)  M; f (a,b)}
  by auto

(* Restatement of isa_mark_duplicated_binary_clauses_as_garbage_wl2 in the
   shape used by the sepref_def that follows.
   - A map CS is created with size length (get_watched_wl_heur S0).
   - The loop walks atoms via get_next, starting at get_vmtf_heur_fst S0, while
     the current atom is not None and get_conflict_wl_is_None_heur S holds.
   - For each atom A, when should_eliminate_pure_st S0 holds, the binary clauses
     of Pos A and then Neg A are deduplicated; otherwise the loop only advances.
   - ns is only used to assert that the atom array is unchanged.
   NOTE(review): the local "skip" holds should_eliminate_pure_st S0, and the
   deduplication runs when it is True, so the name reads inverted.  The result
   "added" of mop_is_marked_added_heur_st is not used in the loop body.  Both
   must stay as they are for the statement to match the original definition.
   The proof normalises lets and binds, then splits the two pair-valued
   isa_deduplicate_binary_clauses_wl binds with Massign_split. *)
lemma isa_mark_duplicated_binary_clauses_as_garbage_wl2_alt_def:
  isa_mark_duplicated_binary_clauses_as_garbage_wl2 S0 = (do {
     let ns = get_vmtf_heur_array S0;
    ASSERT (mark_duplicated_binary_clauses_as_garbage_pre_wl_heur S0);
    let skip = should_eliminate_pure_st  S0;
    CS  create (length (get_watched_wl_heur S0));
    (_, CS, S)  WHILETλ(n,CS, S). get_vmtf_heur_array S0 = (get_vmtf_heur_array S)(λ(n, CS, S). n  None  get_conflict_wl_is_None_heur S)
      (λ(n, CS, S). do {
        ASSERT (n  None);
        let A = the n;
        ASSERT (A < length (get_vmtf_heur_array S));
        ASSERT (A  unat32_max div 2);
        added  mop_is_marked_added_heur_st S A;
        if ¬skip then RETURN (get_next (get_vmtf_heur_array S ! A), CS, S)
        else do {
          ASSERT (length (get_clauses_wl_heur S)  length (get_clauses_wl_heur S0)  learned_clss_count S  learned_clss_count S0);
          (CS, S)  isa_deduplicate_binary_clauses_wl (Pos A) CS S;
          ASSERT (length (get_clauses_wl_heur S)  length (get_clauses_wl_heur S0)  learned_clss_count S  learned_clss_count S0);
          (CS, S)  isa_deduplicate_binary_clauses_wl (Neg A) CS S;
          ASSERT (ns = get_vmtf_heur_array S);
         RETURN (get_next (get_vmtf_heur_array S ! A), CS, S)
       }
     })
     (Some (get_vmtf_heur_fst S0), CS, S0);
    RETURN S
  })
    unfolding isa_mark_duplicated_binary_clauses_as_garbage_wl2_def bind_to_let_conv
      nres_monad3
   apply (simp add: case_prod_beta cong: if_cong)
   unfolding bind_to_let_conv Let_def prod.simps
   (* split the first (Pos A) pair-valued bind *)
   apply (subst Massign_split[of isa_deduplicate_binary_clauses_wl (Pos _) _ _])
   unfolding prod.simps nres_monad3
   (* split the second pair-valued bind *)
   apply (subst (2) Massign_split[of _ :: (_ × isasat) nres])
   unfolding prod.simps nres_monad3
   apply (auto intro!: bind_cong[OF refl] cong: if_cong)
   done

(* LLVM code for the whole binary-clause deduplication pass over the consumed
   state.  The unfoldings fold raw accessors into named operations
   (get_bump_heur_array_nth, length_watchlist, length_watchlist_raw) and apply
   atom.fold_option.  The rewrite inlines the let-binding of
   ns = get_vmtf_heur_array S0, which the loop body only uses inside an ASSERT. *)
sepref_def isa_deduplicate_binary_clauses_code
  is isa_mark_duplicated_binary_clauses_as_garbage_wl2
  :: [λS. length (get_clauses_wl_heur S)  snat64_max  learned_clss_count S  unat64_max]a
     isasat_bounded_assnd  isasat_bounded_assn
  unfolding isa_mark_duplicated_binary_clauses_as_garbage_wl2_alt_def
    get_bump_heur_array_nth_def[symmetric] atom.fold_option nres_monad3
    length_watchlist_def[unfolded length_ll_def, symmetric]
    length_watchlist_raw_def[symmetric]
  apply (rewrite at let _ = get_vmtf_heur_array _ in _ Let_def)
  unfolding if_False nres_monad3
  supply [[goals_limit=1]]
  by sepref

end