Theory IsaSAT_Literals_LLVM

theory IsaSAT_Literals_LLVM
  imports WB_More_Word IsaSAT_Literals Watched_Literals.WB_More_IICF_LLVM
    More_Sepref.WB_More_Sepref_LLVM
begin
(*TODO: should move out to More_Sepref.WB_More_Sepref_LLVM*)

hide_const (open) NEMonad.RETURN

(* Length bound for inner lists: if xss lies in the domain (rdomp) of
   aal_assn' at the length types 'l and 'll, and i is a valid index into xss,
   then the i-th inner list is shorter than max_snat LENGTH('ll).
   Follows from aal_assn_boundsD_aux1 instantiated with assumption A. *)
lemma aal_assn_boundsD':
  assumes A: rdomp (aal_assn' TYPE('l::len2) TYPE('ll::len2) A) xss and i < length xss
  shows length (xss ! i) < max_snat LENGTH('ll)
  using aal_assn_boundsD_aux1[OF A] assms
  by auto
(**)
(* Fixed-width shorthands: word_rel and word_assn with the word type
   constrained to 32 and 64 bits respectively. *)
abbreviation word32_rel  word_rel :: (32 word × _) set
abbreviation word64_rel  word_rel :: (64 word × _) set
abbreviation word32_assn  word_assn :: 32 word  _
abbreviation word64_assn  word_assn :: 64 word  _

(* snat_assn / unat_assn between nat and a word of fixed width (64 or 32 bits). *)
abbreviation snat64_assn :: nat  64 word  _ where snat64_assn  snat_assn
abbreviation snat32_assn :: nat  32 word  _ where snat32_assn  snat_assn
abbreviation unat64_assn :: nat  64 word  _ where unat64_assn  unat_assn

abbreviation unat32_assn :: nat  32 word  _ where unat32_assn  unat_assn

(* TODO: Move
  TODO:  Write generic postprocessing for that!
  Maybe just beta contraction of form (λx. f x)$x = f$x
*)
(* Post-processing equations (attribute to_hnr_post): an n-fold composition of
   RETURN with an n-ary function fN, applied with $ to its n arguments, equals
   RETURN applied to the fully applied fN.
   Note: despite the "5_10" in the name, the cases below cover arities 5 to 19.
   The 15th argument is called xo rather than o, o being the composition
   operator used on the left-hand sides. *)
lemma RETURN_comp_5_10_hnr_post[to_hnr_post]:
  (RETURN ooooo f5)$a$b$c$d$e = RETURN$(f5$a$b$c$d$e)
  (RETURN oooooo f6)$a$b$c$d$e$f = RETURN$(f6$a$b$c$d$e$f)
  (RETURN ooooooo f7)$a$b$c$d$e$f$g = RETURN$(f7$a$b$c$d$e$f$g)
  (RETURN oooooooo f8)$a$b$c$d$e$f$g$h = RETURN$(f8$a$b$c$d$e$f$g$h)
  (RETURN ooooooooo f9)$a$b$c$d$e$f$g$h$i = RETURN$(f9$a$b$c$d$e$f$g$h$i)
  (RETURN oooooooooo f10)$a$b$c$d$e$f$g$h$i$j = RETURN$(f10$a$b$c$d$e$f$g$h$i$j)
  (RETURN o11 f11)$a$b$c$d$e$f$g$h$i$j$k = RETURN$(f11$a$b$c$d$e$f$g$h$i$j$k)
  (RETURN o12 f12)$a$b$c$d$e$f$g$h$i$j$k$l = RETURN$(f12$a$b$c$d$e$f$g$h$i$j$k$l)
  (RETURN o13 f13)$a$b$c$d$e$f$g$h$i$j$k$l$m = RETURN$(f13$a$b$c$d$e$f$g$h$i$j$k$l$m)
  (RETURN o14 f14)$a$b$c$d$e$f$g$h$i$j$k$l$m$n = RETURN$(f14$a$b$c$d$e$f$g$h$i$j$k$l$m$n)
  (RETURN o15 f15)$a$b$c$d$e$f$g$h$i$j$k$l$m$n$xo = RETURN$(f15$a$b$c$d$e$f$g$h$i$j$k$l$m$n$xo)
  (RETURN o16 f16)$a$b$c$d$e$f$g$h$i$j$k$l$m$n$xo$p = RETURN$(f16$a$b$c$d$e$f$g$h$i$j$k$l$m$n$xo$p)
  (RETURN o17 f17)$a$b$c$d$e$f$g$h$i$j$k$l$m$n$xo$p$q = RETURN$(f17$a$b$c$d$e$f$g$h$i$j$k$l$m$n$xo$p$q)
  (RETURN o18 f18)$a$b$c$d$e$f$g$h$i$j$k$l$m$n$xo$p$q$r = RETURN$(f18$a$b$c$d$e$f$g$h$i$j$k$l$m$n$xo$p$q$r)
  (RETURN o19 f19)$a$b$c$d$e$f$g$h$i$j$k$l$m$n$xo$p$q$r$s = RETURN$(f19$a$b$c$d$e$f$g$h$i$j$k$l$m$n$xo$p$q$r$s)
  by simp_all

(* Proof method: applies, one or more times, a rule taken from free_thms or
   sepref_frame_free_rules. Fails if no rule applies at the first step. *)
method synthesize_free =
  (rule free_thms sepref_frame_free_rules)+

(*  TODO/FIXME: Ad-hoc optimizations for large tuples *)
(* case_prod_open is defined to be case_prod; the defining equation is
   declared both [simp] and [llvm_inline]. *)
definition [simp,llvm_inline]: case_prod_open  case_prod
(* The definition read right-to-left: rewrites case_prod into case_prod_open. *)
lemmas fold_case_prod_open = case_prod_open_def[symmetric]

(* Arity rule for case_prod_open (attribute sepref_monadify_arity): the
   constant is expanded to take a binary function fp and a pair p, with fp
   applied to its two components via $.
   Proved by simplification with only SP_def, APP_def, PROTECT2_def and
   RCALL_def. *)
lemma case_prod_open_arity[sepref_monadify_arity]:
  case_prod_open  λ2fp p. SP case_prod_open$(λ2a b. fp$a$b)$p
  by (simp_all only: SP_def APP_def PROTECT2_def RCALL_def)

(* Combinator rule (attribute sepref_monadify_comb): case_prod_open$fp$p is
   rewritten to a Refine_Basic.bind that first evaluates p with EVAL and then
   applies the SP-tagged case_prod_open to the bound result. *)
lemma case_prod_open_comb[sepref_monadify_comb]:
  fp p. case_prod_open$fp$p  Refine_Basic.bind$(EVAL$p)$(λ2p. (SP case_prod_open$fp$p))
  by (simp_all)

(* Combinator rule for the plain (EVAL) case (attribute sepref_monadify_comb):
   EVAL of a case_prod_open is rewritten to a bind that first evaluates the
   pair p, then pushes EVAL inside the body of the case function.
   NOTE(review): the proof also passes list.split and option.split although
   the statement only mentions pairs; they look redundant - confirm before
   removing. *)
lemma case_prod_open_plain_comb[sepref_monadify_comb]:
  "EVAL$(case_prod_open$(λ2a b. fp a b)$p) 
    Refine_Basic.bind$(EVAL$p)$(λ2p. case_prod_open$(λ2a b. EVAL$(fp a b))$p)"
  apply (rule eq_reflection, simp split: list.split prod.split option.split)+
  done

(* Refinement rule for case_prod_open (attribute sepref_comb_rules).
   Assumptions:
     FR   - the initial context Γ yields a prod_assn P1 P2 entry for the pair
            p' / p together with a remaining frame Γ1;
     Pair - for p' = (a1', a2'), the body f a1 a2 refines f' a1' a2' starting
            from the two component entries and Γ1, with postcondition
            Γ2 a1 a2 a1' a2', result assertion R and CP a1 a2;
     FR2  - that postcondition yields component entries under P1' and P2'
            together with a frame Γ1'.
   Conclusion: case_prod_open f p refines case_prod_open applied to f' and p',
   where the pair is afterwards described by prod_assn P1' P2'.
   Proof outline: unfold case_prod_open to case_prod, strengthen the
   precondition with FR, case-split both pairs, then apply hn_refine_cons
   with Pair; the remaining obligations are discharged using hn_ctxt_def
   and FR2. *)
lemma hn_case_prod_open'[sepref_comb_rules]:
  assumes FR: Γ  hn_ctxt (prod_assn P1 P2) p' p ** Γ1
  assumes Pair: "a1 a2 a1' a2'. p'=(a1',a2')
     hn_refine (hn_ctxt P1 a1' a1 ∧* hn_ctxt P2 a2' a2 ∧* Γ1) (f a1 a2)
          (Γ2 a1 a2 a1' a2') R (CP a1 a2) (f' a1' a2')"
  assumes FR2: a1 a2 a1' a2'. Γ2 a1 a2 a1' a2'  hn_ctxt P1' a1' a1 ** hn_ctxt P2' a2' a2 ** Γ1'
  shows hn_refine Γ (case_prod_open f p) (hn_ctxt (prod_assn P1' P2') p' p ** Γ1')
                   R (CP_SPLIT CP p) (case_prod_open$(λ2a b. f' a b)$p') (is ?G Γ)
  unfolding case_prod_open_def
  unfolding autoref_tag_defs PROTECT2_def
  apply1 (rule hn_refine_cons_pre[OF FR])
  apply1 (cases p; cases p'; simp add: prod_assn_pair_conv[THEN prod_assn_ctxt])
  unfolding CP_SPLIT_def prod.simps
  apply (rule hn_refine_cons[OF _ Pair _ entails_refl])
  applyS (simp add: hn_ctxt_def)
  applyS simp using FR2
  by (simp add: hn_ctxt_def)


(* Preprocessing equation (attribute sepref_preproc): a case_prod_open whose
   body takes an extra argument x after the two components equals the function
   that takes the pair p and x first and performs the case split on p inside. *)
lemma ho_prod_open_move[sepref_preproc]: case_prod_open (λa b x. f x a b) = (λp x. case_prod_open (f x) p)
  by (auto)

(* Alternative characterisation of op_list_list_upd: updating position j of
   the i-th inner list of xss with x, expressed with nested list updates. *)
lemma op_list_list_upd_alt_def: op_list_list_upd xss i j x = xss[i := (xss ! i)[j := x]]
  unfolding op_list_list_upd_def by auto

(* Tuple constructors as named constants.
   tuple4  builds the 4-tuple (a,b,c,d).
   tuple7  is tuple4 whose last component is itself a tuple4 (4 + 3 = 7 fields).
   tuple13 is tuple7 whose last component is itself a tuple7 (7 + 6 = 13 fields). *)
definition tuple4 a b c d  (a,b,c,d)
definition tuple7 a b c d e f g  tuple4 a b c (tuple4 d e f g)
definition tuple13 a b c d e f g h i j k l m  (tuple7 a b c d e f (tuple7 g h i j k l m))

(* The three definitions read right-to-left, for folding nested tuples back
   into the named constructors. *)
lemmas fold_tuples = tuple4_def[symmetric] tuple7_def[symmetric] tuple13_def[symmetric]

(* Register the constructors as Sepref operations. *)
sepref_register tuple4 tuple7 tuple13

(* Implementations of the tuple constructors, each marked [llvm_inline].
   Every argument is taken under its own assertion A1 ... An and the result is
   described by the product of those assertions. Each proof unfolds the
   corresponding definition and then calls sepref. *)
sepref_def tuple4_impl [llvm_inline] is uncurry3 (RETURN oooo tuple4) ::
  A1d *a A2d *a A3d *a A4d a A1 ×a A2 ×a A3 ×a A4
  unfolding tuple4_def by sepref

sepref_def tuple7_impl [llvm_inline] is uncurry6 (RETURN ooooooo tuple7) ::
  A1d *a A2d *a A3d *a A4d *a A5d *a A6d *a A7d a A1 ×a A2 ×a A3 ×a A4 ×a A5 ×a A6 ×a A7
  unfolding tuple7_def by sepref

sepref_def tuple13_impl [llvm_inline] is uncurry12 (RETURN o13 tuple13) ::
  "A1d *a A2d *a A3d *a A4d *a A5d *a A6d *a A7d *a A8d *a A9d *a A10d *a A11d *a A12d *a A13d
  a A1 ×a A2 ×a A3 ×a A4 ×a A5 ×a A6 ×a A7 ×a A8 ×a A9 ×a A10 ×a A11 ×a A12 ×a A13"
  unfolding tuple13_def by sepref

(* Bundle of the folding equations for the tuple constructors (fold_tuples)
   and for case_prod_open (fold_case_prod_open). *)
lemmas fold_tuple_optimizations = fold_tuples fold_case_prod_open



(* TODO: Move!
  TODO: General max functions!
*)
(* Constant-import rules (attribute sepref_import_param): each hexadecimal
   word literal is related to the corresponding abstract maximum constant.
   Signed variants: 0x7FFF... at 64 and 32 bits, in snat_rel'. *)
lemma snat64_max_refine[sepref_import_param]: (0x7FFFFFFFFFFFFFFF, snat64_max)snat_rel' TYPE(64)
  apply (auto simp: snat_rel_def snat.rel_def in_br_conv snat64_max_def snat_invar_def)
  apply (auto simp: snat_def)
  done

lemma snat32_max_refine[sepref_import_param]: (0x7FFFFFFF, snat32_max)snat_rel' TYPE(32)
  apply (auto simp: snat_rel_def snat.rel_def in_br_conv snat32_max_def snat_invar_def)
  apply (auto simp: snat_def)
  done

(* Unsigned variants: 0xFFFF... at 64 and 32 bits, in unat_rel'. *)
lemma unat64_max_refine[sepref_import_param]: (0xFFFFFFFFFFFFFFFF, unat64_max)unat_rel' TYPE(64)
  apply (auto simp: unat_rel_def unat.rel_def in_br_conv unat64_max_def)
  done

lemma unat32_max_refine[sepref_import_param]: (0xFFFFFFFF, unat32_max)unat_rel' TYPE(32)
  apply (auto simp: unat_rel_def unat.rel_def in_br_conv unat32_max_def)
  done


(* TODO: Move *)

(* Shorthands for unat_assn' at 32 and 64 bits. *)
abbreviation uint32_nat_assn  unat_assn' TYPE(32)
abbreviation uint64_nat_assn  unat_assn' TYPE(64)

(* Shorthands for snat_assn' at 32 and 64 bits. *)
abbreviation sint32_nat_assn  snat_assn' TYPE(32)
abbreviation sint64_nat_assn  snat_assn' TYPE(64)

text ‹It is critical for the performance of auto to calculate the powers once here, instead of
letting auto recompute them every time.›
(* Adds the four *_max definitions to sepref_bounds_simps; the [simplified]
   attribute is applied first, so the stored theorems are the simplified
   forms of the definitions. *)
lemmas [simplified, sepref_bounds_simps] =
  unat32_max_def snat32_max_def
  unat64_max_def snat64_max_def


(* The unsigned cast between 32 and 64 bit words satisfies is_up' (32 to 64),
   and the cast from 64 to 32 bits satisfies is_down'. Both facts are
   declared [simp, intro!]. *)
lemma is_up'_32_64[simp,intro!]: is_up' UCAST(32  64) by (simp add: is_up')
lemma is_down'_64_32[simp,intro!]: is_down' UCAST(64  32)  by (simp add: is_down')

(* Rewrites list update and list lookup at index i into op_list_set /
   op_list_get with the index wrapped in op_unat_snat_upcast TYPE(64).
   Both equations hold by simp_all. *)
lemma ins_idx_upcast64:
  l[i:=y] = op_list_set l (op_unat_snat_upcast TYPE(64) i) y
  l!i = op_list_get l (op_unat_snat_upcast TYPE(64) i)
  by simp_all

(* array_list with the length type fixed to 32 or 64 bits. *)
type_synonym 'a array_list32 = ('a,32)array_list
type_synonym 'a array_list64 = ('a,64)array_list

(* Matching assertions: al_assn' at 32 and 64 bits. *)
abbreviation arl32_assn  al_assn' TYPE(32)
abbreviation arl64_assn  al_assn' TYPE(64)


(* larray with the length type fixed to 32 or 64 bits. *)
type_synonym 'a larray32 = ('a,32) larray
type_synonym 'a larray64 = ('a,64) larray

(* Matching assertions: larray_assn' at 32 and 64 bits. *)
abbreviation larray32_assn  larray_assn' TYPE(32)
abbreviation larray64_assn  larray_assn' TYPE(64)



(* Relation between 32-bit words and literals: the composition of
   unat_rel' TYPE(32) with nat_lit_rel. *)
definition unat_lit_rel == unat_rel' TYPE(32) O nat_lit_rel
(* Let fcomp normalisation fold the composition back into unat_lit_rel. *)
lemmas [fcomp_norm_unfold] = unat_lit_rel_def[symmetric]

(* Pure assertion induced by unat_lit_rel. *)
abbreviation unat_lit_assn :: nat literal  32 word  assn where
  unat_lit_assn  pure unat_lit_rel

subsection Atom-Of

(* Atoms are represented by 32-bit words. *)
type_synonym atom_assn = 32 word

(* atom_rel is unat_rel' TYPE(32) bounded (b_rel) by the predicate x < 2^31;
   atom_assn is the pure assertion induced by it. *)
definition atom_rel  b_rel (unat_rel' TYPE(32)) (λx. x<2^31)
abbreviation atom_assn  pure atom_rel

(* Equivalent formulation as a relational composition with nbn_rel (2^31). *)
lemma atom_rel_alt: atom_rel = unat_rel' TYPE(32) O nbn_rel (2^31)
  by (auto simp: atom_rel_def)

(* Interprets the locale dflt_pure_option_private for atoms, with the word
   2^32-1 as distinguished value, atom_assn as element assertion and
   ll_icmp_eq (2^32-1) as the test for that value.
   NOTE(review): presumably 2^32-1 serves as the encoding of "no atom"; it lies
   outside atom_rel, which only admits values below 2^31 - confirm against the
   locale definition.
   Three proof obligations: the first is settled by unfolding atom_rel and
   word arithmetic (unat_minus_one_word), the second by vcg' after
   interpreting llvm_prim_arith_setup, the third by simp. *)
interpretation atom: dflt_pure_option_private 2^32-1 atom_assn ll_icmp_eq (2^32-1)
  apply unfold_locales
  subgoal
    unfolding atom_rel_def
    apply (simp add: pure_def fun_eq_iff pred_lift_extract_simps)
    apply (auto simp: unat_rel_def unat.rel_def in_br_conv unat_minus_one_word)
    done
  subgoal proof goal_cases
    case 1
      interpret llvm_prim_arith_setup .
      show ?case unfolding bool.assn_def by vcg'
    qed
  subgoal by simp
  done


(* On the natural-number encoding of literals (nat_lit_rel), atm_of is
   refined by division by 2. *)
lemma atm_of_refine: (λx. x div 2 , atm_of)  nat_lit_rel  nat_rel
  by (auto simp: nat_lit_rel_def in_br_conv)


(* Implementation of division by 2 taking a uint32_nat_assn argument and
   producing an atom_assn result. The bound required by atom_rel is
   established separately (hfref_bassn_resI, then sepref_bounds) before the
   numeric constants are annotated as 32-bit and sepref is run. *)
sepref_def atm_of_impl is [] RETURN o (λx::nat. x div 2)
  :: uint32_nat_assnk a atom_assn
  unfolding atom_rel_def b_assn_pure_conv[symmetric]
  apply (rule hfref_bassn_resI)
  subgoal by sepref_bounds
  apply (annot_unat_const TYPE(32))
  by sepref

(* Compose the implementation theorem with atm_of_refine and register the
   result as a Sepref rule for atm_of. *)
lemmas [sepref_fr_rules] = atm_of_impl.refine[FCOMP atm_of_refine]


(* Pos_rel n = 2 * n: the numeric counterpart of the constructor Pos (see
   Pos_refine_aux). Declared [simp]. *)
definition Pos_rel :: nat  nat where
 [simp]: Pos_rel n = 2 * n

(* On the nat encoding of literals, Pos is refined by Pos_rel (2 * n) ... *)
lemma Pos_refine_aux: (Pos_rel,Pos)nat_rel  nat_lit_rel
  by (auto simp: nat_lit_rel_def in_br_conv split: if_splits)

(* ... and Neg by 2 * x + 1. *)
lemma Neg_refine_aux: (λx. 2*x + 1,Neg)nat_rel  nat_lit_rel
  by (auto simp: nat_lit_rel_def in_br_conv split: if_splits)

(* Implementations from atom_assn to uint32_nat_assn. atom_rel is unfolded so
   that its bound x < 2^31 is available; numeric constants are annotated as
   32-bit before sepref is run. *)
sepref_def Pos_impl is [] RETURN o Pos_rel :: atom_assnd a uint32_nat_assn
  unfolding atom_rel_def Pos_rel_def
  apply (annot_unat_const TYPE(32))
  by sepref

sepref_def Neg_impl is [] RETURN o (λx. 2*x+1) :: atom_assnd a uint32_nat_assn
  unfolding atom_rel_def
  apply (annot_unat_const TYPE(32))
  by sepref

(* Compose with the auxiliary refinements and register as Sepref rules for
   Pos and Neg. *)
lemmas [sepref_fr_rules] =
  Pos_impl.refine[FCOMP Pos_refine_aux]
  Neg_impl.refine[FCOMP Neg_refine_aux]

(* Equality test on two atom_assn arguments with a bool1_assn result;
   atom_rel is unfolded before sepref is run. *)
sepref_def atom_eq_impl is uncurry (RETURN oo (=)) :: atom_assnd *a atom_assnd a bool1_assn
  unfolding atom_rel_def
  by sepref


(* value_of_atm is the identity on nat (declared [simp]); it is implemented
   below as a conversion from atom_assn to unat_assn' TYPE(32). *)
definition value_of_atm :: nat  nat where
[simp]: value_of_atm A = A

(* The identity function refines value_of_atm. *)
lemma value_of_atm_rel: (λx. x, value_of_atm)  nat_rel  nat_rel
  by (auto)

(* Implementation of the identity from atom_assn to unat_assn' TYPE(32). *)
sepref_def value_of_atm_impl
  is [] RETURN o (λx. x)
  :: atom_assnd a unat_assn' TYPE(32)
  unfolding value_of_atm_def atom_rel_def
  by sepref

(* Compose with value_of_atm_rel and register as a Sepref rule. *)
lemmas [sepref_fr_rules] = value_of_atm_impl.refine[FCOMP value_of_atm_rel]

(* index_of_atm equals value_of_atm (declared [simp]); it is implemented
   below with a snat_assn' TYPE(64) result instead of a 32-bit unsigned one. *)
definition index_of_atm :: nat  nat where
[simp]: index_of_atm A = value_of_atm A

(* value_of_atm refines index_of_atm. *)
lemma index_of_atm_rel: (λx. value_of_atm x, index_of_atm)  nat_rel  nat_rel
  by (auto)


(* Implementation from atom_assn to snat_assn' TYPE(64): the term is
   eta-expanded and annotated with annot_unat_snat_upcast at 'l = 64 before
   sepref is run. *)
sepref_def index_of_atm_impl
  is [] RETURN o (λx. value_of_atm x)
  :: atom_assnd a snat_assn' TYPE(64)
  unfolding index_of_atm_def
  apply (rewrite at _ eta_expand)
  apply (subst annot_unat_snat_upcast[where 'l=64])
  by sepref

(* Compose with index_of_atm_rel and register as a Sepref rule. *)
lemmas [sepref_fr_rules] = index_of_atm_impl.refine[FCOMP index_of_atm_rel]

(* Annotation equations: a list lookup or list update at index x may be
   rewritten to use index_of_atm x as the index (index_of_atm is the identity
   by its simp rules, so both equations hold by auto). *)
lemma annot_index_of_atm: xs ! x = xs ! index_of_atm x
   xs [x := a] = xs [index_of_atm x := a]
  by auto

(* index_atm_of L is index_of_atm applied to the atom of literal L
   (declared [simp]). *)
definition index_atm_of where
[simp]: index_atm_of L = index_of_atm (atm_of L)

(* TODO: Use at more places! *)
(* Variant of annot_index_of_atm, instantiated at x, stated in a context that
   assumes NO_MATCH (index_of_atm y) x.
   NOTE(review): presumably the NO_MATCH premise keeps the rewrite from firing
   again on an index that already has the form index_of_atm y - confirm. *)
context fixes x y :: nat assumes NO_MATCH (index_of_atm y) x begin
  lemmas annot_index_of_atm' = annot_index_of_atm[where x=x]
end

(* Proof method annot_all_atm_idxs (no arguments): runs simp_tac with a
   simpset built from HOL_basic_ss, extended only by the rules
   annot_index_of_atm' and the NO_MATCH simproc. *)
method_setup annot_all_atm_idxs = Scan.succeed (fn ctxt => SIMPLE_METHOD'
    let
      val ctxt = put_simpset HOL_basic_ss ctxt
      val ctxt = ctxt addsimps @{thms annot_index_of_atm'}
      val ctxt = ctxt addsimprocs [@{simproc NO_MATCH}]
    in
      simp_tac ctxt
    end
  )

(* Pattern rules (attribute def_pat_rules): list lookup and list update whose
   index is atm_of x are rewritten to use index_atm_of x instead. *)
lemma annot_index_atm_of[def_pat_rules]:
  nth$xs$(atm_of$x)  nth$xs$(index_atm_of$x)
  list_update$xs$(atm_of$x)$a  list_update$xs$(index_atm_of$x)$a
  by auto


(* Implementation of index_atm_of from unat_lit_assn to snat_assn' TYPE(64);
   the definition is unfolded so that sepref works on
   index_of_atm (atm_of L). *)
sepref_def index_atm_of_impl
  is RETURN o index_atm_of
  :: unat_lit_assnd a snat_assn' TYPE(64)
  unfolding index_atm_of_def
  by sepref




(* On the nat encoding of literals, nat_of_lit is refined by the identity. *)
lemma nat_of_lit_refine_aux: ((λx. x), nat_of_lit)  nat_lit_rel  nat_rel
  by (auto simp: nat_lit_rel_def in_br_conv)

(* Implementation of the identity from uint32_nat_assn to sint64_nat_assn,
   annotated with annot_unat_snat_upcast at 'l = 64. *)
sepref_def nat_of_lit_rel_impl is [] RETURN o (λx::nat. x) :: uint32_nat_assnk a sint64_nat_assn
  apply (rewrite annot_unat_snat_upcast[where 'l=64])
  by sepref
(* Compose with nat_of_lit_refine_aux and register as a Sepref rule. *)
lemmas [sepref_fr_rules] = nat_of_lit_rel_impl.refine[FCOMP nat_of_lit_refine_aux]

(* On the nat encoding of literals, negation (uminus) is refined by XOR with 1. *)
lemma uminus_refine_aux: (λx. x XOR 1, uminus)  nat_lit_rel  nat_lit_rel
  apply (auto simp: nat_lit_rel_def in_br_conv bitXOR_1_if_mod_2[simplified])
  subgoal by linarith
  subgoal by (metis dvd_minus_mod even_Suc_div_two odd_Suc_minus_one)
  done

(* Implementation of XOR with 1 on uint32_nat_assn; the constant is annotated
   as 32-bit before sepref is run. *)
sepref_def uminus_impl is [] RETURN o (λx::nat. x XOR 1) :: uint32_nat_assnk a uint32_nat_assn
  apply (annot_unat_const TYPE(32))
  by sepref

(* Compose with uminus_refine_aux and register as a Sepref rule. *)
lemmas [sepref_fr_rules] = uminus_impl.refine[FCOMP uminus_refine_aux]

(* Equality on the nat encoding refines equality on literals. *)
lemma lit_eq_refine_aux: ( (=), (=) )  nat_lit_rel  nat_lit_rel  bool_rel
  by (auto simp: nat_lit_rel_def in_br_conv split: if_splits; auto?; presburger)

(* Implementation of equality on two uint32_nat_assn arguments with a
   bool1_assn result. *)
sepref_def lit_eq_impl is [] uncurry (RETURN oo (=)) :: uint32_nat_assnk *a uint32_nat_assnk a bool1_assn
  by sepref

(* Compose with lit_eq_refine_aux and register as a Sepref rule. *)
lemmas [sepref_fr_rules] = lit_eq_impl.refine[FCOMP lit_eq_refine_aux]

(* On the nat encoding of literals, is_pos is refined by testing that the
   lowest bit is zero (x AND 1 = 0). *)
lemma is_pos_refine_aux: (λx. x AND 1 = 0, is_pos)  nat_lit_rel  bool_rel
  by (auto simp: nat_lit_rel_def in_br_conv bitAND_1_mod_2[simplified] split: if_splits)

(* Implementation of the bit test on uint32_nat_assn with a bool1_assn
   result; constants are annotated as 32-bit before sepref is run. *)
sepref_def is_pos_impl is [] RETURN o (λx. x AND 1 = 0) :: uint32_nat_assnk a bool1_assn
  apply (annot_unat_const TYPE(32))
  by sepref

(* Compose with is_pos_refine_aux and register as a Sepref rule. *)
lemmas [sepref_fr_rules] = is_pos_impl.refine[FCOMP is_pos_refine_aux]

(* Declares equality on nat literals as a Sepref operation named nat_lit_eq
   (the lemma nat_lit_rel in this block refers to it as op_nat_lit_eq),
   parametric with respect to Id on both arguments. *)
sepref_decl_op nat_lit_eq: (=) :: nat literal  _  _ ::
  (Id :: (nat literal × _) set)  (Id :: (nat literal × _) set)  bool_rel .

(* Implementation of equality on two uint32_nat_assn arguments. *)
sepref_def nat_lit_eq_impl
  is [] uncurry (RETURN oo (λx y. x = y))
  :: uint32_nat_assnk *a uint32_nat_assnk a bool1_assn
  by sepref

(* Equality on the nat encoding refines op_nat_lit_eq.
   NOTE(review): this fact shares its name with the relation constant
   nat_lit_rel; consider whether a more specific fact name is wanted. *)
lemma nat_lit_rel: ((=), op_nat_lit_eq)  nat_lit_rel  nat_lit_rel  bool_rel
  by (auto simp: nat_lit_rel_def br_def split: if_splits; presburger)

(* Register equality at type nat literal, then install the composed
   refinement theorem as a Sepref rule. *)
sepref_register (=) :: nat literal  _  _
declare nat_lit_eq_impl.refine[FCOMP nat_lit_rel, sepref_fr_rules]

(* Local setup for freeing lists of lists stored under aal_assn.
   The context fixes the two length types 'l and 'll through dummy "itself"
   parameters and introduces local abbreviations: L and LL for their
   bit lengths and AA for raw_aal_assn at these types. *)
context
  fixes l_dummy :: 'l::len2 itself
  fixes ll_dummy :: 'll::len2 itself
  fixes L LL AA
  defines [simp]: L  (LENGTH ('l))
  defines [simp]: LL  (LENGTH ('ll))
  defines [simp]: AA  raw_aal_assn TYPE('l::len2) TYPE('ll::len2)
begin
  (* Folding equation: the hr_comp of AA with the list-of-lists relation is
     aal_assn A. Private to this context. *)
  private lemma n_unf: hr_comp AA (the_pure Alist_rellist_rel) = aal_assn A unfolding aal_assn_def AA_def ..

(* Inner context: use n_unf during fcomp normalisation. *)
context
  notes [fcomp_norm_unfold] = n_unf
begin

(* aal_free is a free operation (MK_FREE) for the raw assertion AA. *)
lemma aal_assn_free[sepref_frame_free_rules]: MK_FREE AA aal_free
  apply rule by vcg
  (* Abstract operation list_list_free: maps any list of lists to unit. *)
  sepref_decl_op list_list_free: λ_::_ list list. () :: Alist_rellist_rel  unit_rel .

(* aal_free implements op_list_list_free on the raw assertion AA. *)
lemma hn_aal_free_raw: (aal_free,RETURN o op_list_list_free)  AAd a unit_assn
    by sepref_to_hoare vcg

  (* Declare the implementation under the name aal_free; the generated theorem
     is referred to as aal_free_hnr in the next command. *)
  sepref_decl_impl aal_free: hn_aal_free_raw
     .

  (* Derive a MK_FREE rule from aal_free_hnr and register it as a frame
     free rule. *)
  lemmas array_mk_free[sepref_frame_free_rules] = hn_MK_FREEI[OF aal_free_hnr]
end
end


(* of_nat on naturals is refined by the identity on words, with the argument
   related through snat_rel' and the result through word_rel. *)
lemma of_nat_snat:
  (id,of_nat)  snat_rel' TYPE('a::len2)  word_rel
  by (auto simp: snat_rel_def snat.rel_def in_br_conv snat_eq_unat)

(* Same statement with the argument related through unat_rel'.
   NOTE(review): the simp set still contains snat_eq_unat although only
   unat_rel appears here; it may be unnecessary - confirm before removing. *)
lemma of_nat_unat:
  (id,of_nat)  unat_rel' TYPE('a::len2)  word_rel
  by (auto simp: unat_rel_def unat.rel_def in_br_conv snat_eq_unat)


(* Three-valued booleans are represented by 8-bit words. *)
type_synonym tri_bool_assn = 8 word

(* Encoding of bool option as nat: 0 for None, 2 for Some True,
   3 for Some False. *)
definition tri_bool_rel_aux  { (0::nat,None), (2,Some True), (3,Some False) }
(* Full relation: 8-bit unsigned words composed with the nat encoding. *)
definition tri_bool_rel  unat_rel' TYPE(8) O tri_bool_rel_aux
abbreviation tri_bool_assn  pure tri_bool_rel
(* Let fcomp normalisation fold the composition back into tri_bool_rel. *)
lemmas [fcomp_norm_unfold] = tri_bool_rel_def[symmetric]

(* Auxiliary refinements on the nat encoding: 0, 2 and 3 refine UNSET,
   SET_TRUE and SET_FALSE respectively, and equality on naturals refines
   tri_bool_eq. All four follow from the definitions of tri_bool_rel_aux and
   tri_bool_eq. *)
lemma tri_bool_UNSET_refine_aux: (0,UNSET)tri_bool_rel_aux
  and tri_bool_SET_TRUE_refine_aux: (2,SET_TRUE)tri_bool_rel_aux
  and tri_bool_SET_FALSE_refine_aux: (3,SET_FALSE)tri_bool_rel_aux
  and tri_bool_eq_refine_aux: ((=),tri_bool_eq)  tri_bool_rel_auxtri_bool_rel_auxbool_rel
  by (auto simp: tri_bool_rel_aux_def tri_bool_eq_def)

(* Implementations of the three constants as 8-bit unsigned values (0, 2, 3);
   each constant is annotated as 8-bit before sepref is run. *)
sepref_def tri_bool_UNSET_impl is [] uncurry0 (RETURN 0) :: unit_assnk a unat_assn' TYPE(8)
  apply (annot_unat_const TYPE(8))
  by sepref

sepref_def tri_bool_SET_TRUE_impl is [] uncurry0 (RETURN 2) :: unit_assnk a unat_assn' TYPE(8)
  apply (annot_unat_const TYPE(8))
  by sepref

sepref_def tri_bool_SET_FALSE_impl is [] uncurry0 (RETURN 3) :: unit_assnk a unat_assn' TYPE(8)
  apply (annot_unat_const TYPE(8))
  by sepref

(* Equality on two 8-bit unsigned values; marked [llvm_inline]. *)
sepref_def tri_bool_eq_impl [llvm_inline] is [] uncurry (RETURN oo (=)) :: (unat_assn' TYPE(8))k *a (unat_assn' TYPE(8))k a bool1_assn
  by sepref

(* Compose each 8-bit implementation theorem with its auxiliary refinement on
   the nat encoding and register the results as Sepref rules for UNSET,
   SET_TRUE, SET_FALSE and tri_bool_eq. *)
lemmas [sepref_fr_rules] =
  tri_bool_UNSET_impl.refine[FCOMP tri_bool_UNSET_refine_aux]
  tri_bool_SET_TRUE_impl.refine[FCOMP tri_bool_SET_TRUE_refine_aux]
  tri_bool_SET_FALSE_impl.refine[FCOMP tri_bool_SET_FALSE_refine_aux]
  tri_bool_eq_impl.refine[FCOMP tri_bool_eq_refine_aux]
(* Hide the short names tuple4 and tuple7 from here on.
   NOTE(review): tuple13 is defined alongside them but is not hidden - confirm
   this is intentional. *)
hide_const (open) tuple4 tuple7
end