Theory CDCL.DPLL_CDCL_W_Implementation

theory DPLL_CDCL_W_Implementation
imports
  Entailment_Definition.Partial_Annotated_Herbrand_Interpretation
  CDCL_W_Level
begin
chapter ‹List-based Implementation of DPLL and CDCL›

text ‹We can now reuse all the theorems to go towards an implementation using 2-watched literals:
  \<^item> @{file CDCL_W_Abstract_State.thy} defines a better-suited state: the operations on it
  are more constrained, allowing simpler proofs and fewer edge cases later.›

section ‹Simple List-Based Implementation of the DPLL and CDCL›

text ‹The idea of the list-based implementation is to test the stack: the theories about the
  calculi, adapting the theorems to a simple implementation and the code exportation. The
  implementations are very simple and simply iterate over and over on lists.›

subsection ‹Common Rules›

subsubsection ‹Propagation›

text ‹The following theorem holds:›

(* Relates the literal-by-literal membership test (left-hand side) to the
   entailment of the negated clause by the trail (right-hand side). *)
lemma lits_of_l_unfold:
  "(∀c ∈ set C. -c ∈ lits_of_l Ms) ⟷ Ms ⊨as CNot (mset C)"
  unfolding true_annots_def Ball_def true_annot_def CNot_def by auto
text ‹The right-hand side is written at a high level, but only the left-hand side is executable.›

(* Specification-level unit test for the clause l under the trail M:
   keep the literals of l whose atom does not occur in M; if exactly one
   literal a remains and M entails the negation of l minus a, return Some a,
   otherwise None. *)
definition is_unit_clause :: "'a literal list ⇒ ('a, 'b) ann_lits ⇒ 'a literal option"
 where
 "is_unit_clause l M =
   (case List.filter (λa. atm_of a ∉ atm_of ` lits_of_l M) l of
     a # [] ⇒ if M ⊨as CNot (mset l - {#a#}) then Some a else None
   | _ ⇒ None)"

(* Variant of is_unit_clause in which the entailment test is replaced by a
   direct check that the negation of every literal of (remove1 a l) is in
   lits_of_l M. *)
definition is_unit_clause_code :: "'a literal list ⇒ ('a, 'b) ann_lits
  ⇒ 'a literal option" where
 "is_unit_clause_code l M =
   (case List.filter (λa. atm_of a ∉ atm_of ` lits_of_l M) l of
     a # [] ⇒ if (∀c ∈ set (remove1 a l). -c ∈ lits_of_l M) then Some a else None
   | _ ⇒ None)"

(* Code equation: is_unit_clause is computed through is_unit_clause_code.
   Fact 1 instantiates lits_of_l_unfold on (remove1 a l) so that the two
   conditions of the definitions can be identified. *)
lemma is_unit_clause_is_unit_clause_code[code]:
  "is_unit_clause l M = is_unit_clause_code l M"
proof -
  have 1: "⋀a. (∀c∈set (remove1 a l). - c ∈ lits_of_l M) ⟷ M ⊨as CNot (mset l - {#a#})"
    using lits_of_l_unfold[of "remove1 _ l", of _ M] by simp
  then show ?thesis
    unfolding is_unit_clause_code_def is_unit_clause_def 1 by blast
qed

(* A literal returned by is_unit_clause is undefined in the trail: it
   survived the filter, so its atom does not occur in lits_of_l M. *)
lemma is_unit_clause_some_undef:
  assumes "is_unit_clause l M = Some a"
  shows "undefined_lit M a"
proof -
  have "(case [a←l . atm_of a ∉ atm_of ` lits_of_l M] of [] ⇒ None
          | [a] ⇒ if M ⊨as CNot (mset l - {#a#}) then Some a else None
          | a # ab # xa ⇒ Map.empty xa) = Some a"
    using assms unfolding is_unit_clause_def .
  then have "a ∈ set [a←l . atm_of a ∉ atm_of ` lits_of_l M]"
    apply (cases "[a←l . atm_of a ∉ atm_of ` lits_of_l M]")
      apply simp
    apply (rename_tac aa list; case_tac list) by (auto split: if_split_asm)
  then have "atm_of a ∉ atm_of ` lits_of_l M" by auto
  then show ?thesis
    by (simp add: Decided_Propagated_in_iff_in_lits_of_l
      atm_of_in_atm_of_set_iff_in_set_or_uminus_in_set )
qed

(* When is_unit_clause returns a, the trail entails the negation of the
   clause with a removed. *)
lemma is_unit_clause_some_CNot: "is_unit_clause l M = Some a ⟹ M ⊨as CNot (mset l - {#a#})"
  unfolding is_unit_clause_def
proof -
  assume "(case [a←l . atm_of a ∉ atm_of ` lits_of_l M] of [] ⇒ None
          | [a] ⇒ if M ⊨as CNot (mset l - {#a#}) then Some a else None
          | a # ab # xa ⇒ Map.empty xa) = Some a"
  then show ?thesis
    apply (cases "[a←l . atm_of a ∉ atm_of ` lits_of_l M]", simp)
      apply simp
    apply (rename_tac aa list, case_tac list) by (auto split: if_split_asm)
qed

(* A literal returned by is_unit_clause is a member of the clause. *)
lemma is_unit_clause_some_in: "is_unit_clause l M = Some a ⟹ a ∈ set l"
  unfolding is_unit_clause_def
proof -
  assume "(case [a←l . atm_of a ∉ atm_of ` lits_of_l M] of [] ⇒ None
         | [a] ⇒ if M ⊨as CNot (mset l - {#a#}) then Some a else None
         | a # ab # xa ⇒ Map.empty xa) = Some a"
  then show "a ∈ set l"
    by (cases "[a←l . atm_of a ∉ atm_of ` lits_of_l M]")
       (fastforce dest: filter_eq_ConsD split: if_split_asm split: list.splits)+
qed

(* The empty clause is never a unit clause; registered as a simp rule. *)
lemma is_unit_clause_Nil[simp]: "is_unit_clause [] M = None"
  unfolding is_unit_clause_def by auto


subsubsection ‹Unit propagation for all clauses›

text ‹Finding the first clause to propagate›
(* Scans the clause list from the front and returns the first clause for
   which is_unit_clause succeeds, paired with the literal it yields;
   None if no clause in the list qualifies. *)
fun find_first_unit_clause :: "'a literal list list ⇒ ('a, 'b) ann_lits
  ⇒ ('a literal × 'a literal list) option" where
"find_first_unit_clause (a # l) M =
  (case is_unit_clause a M of
    None ⇒ find_first_unit_clause l M
  | Some L ⇒ Some (L, a))" |
"find_first_unit_clause [] _ = None"

(* Soundness of find_first_unit_clause: the returned clause c is in the
   list, the returned literal a is in c and undefined in M, and M entails
   the negation of c without a. *)
lemma find_first_unit_clause_some:
  "find_first_unit_clause l M = Some (a, c)
  ⟹ c ∈ set l ∧  M ⊨as CNot (mset c - {#a#}) ∧ undefined_lit M a ∧ a ∈ set c"
  apply (induction l)
    apply simp
  by (auto split: option.splits dest: is_unit_clause_some_in is_unit_clause_some_CNot
         is_unit_clause_some_undef)

(* Completeness of is_unit_clause: if a is an undefined literal of c and M
   entails the negation of the rest of c, then is_unit_clause finds a
   literal. The key step shows that filtering c leaves exactly [a]. *)
lemma propagate_is_unit_clause_not_None:
  assumes
  M: "M ⊨as CNot (mset c - {#a#})" and
  undef: "undefined_lit M a" and
  ac: "a ∈ set c"
  shows "is_unit_clause c M ≠ None"
proof -
  have "[a←c . atm_of a ∉ atm_of ` lits_of_l M] = [a]"
    using assms
    proof (induction c)
      case Nil then show ?case by simp
    next
      case (Cons ac c)
      show ?case
        proof (cases "a = ac")
          case True
          then show ?thesis using Cons
            by (auto simp del: lits_of_l_unfold
                 simp add: lits_of_l_unfold[symmetric] Decided_Propagated_in_iff_in_lits_of_l
                   atm_of_eq_atm_of atm_of_in_atm_of_set_iff_in_set_or_uminus_in_set)
        next
          case False
          then have T: "mset c + {#ac#} - {#a#} = mset c - {#a#} + {#ac#}"
            by (auto simp add: multiset_eq_iff)
          show ?thesis using False Cons
            by (auto simp add: T atm_of_in_atm_of_set_iff_in_set_or_uminus_in_set)
        qed
    qed
  then show ?thesis
    using M unfolding is_unit_clause_def by auto
qed

(* Completeness of find_first_unit_clause: if some clause of the list has an
   undefined literal a and the rest of it is falsified by M, the search
   cannot return None. *)
lemma find_first_unit_clause_none:
  "c ∈ set l ⟹  M ⊨as CNot (mset c - {#a#}) ⟹ undefined_lit M a ⟹ a ∈ set c
  ⟹ find_first_unit_clause l M ≠ None"
  by (induction l)
     (auto split: option.split simp add: propagate_is_unit_clause_not_None)

subsubsection ‹Decide›
(* Returns the first literal, scanning the clauses in order, such that
   neither the literal nor its negation belongs to the set M; None if every
   literal of every clause is covered by M. *)
fun find_first_unused_var :: "'a literal list list ⇒ 'a literal set ⇒ 'a literal option" where
"find_first_unused_var (a # l) M =
  (case List.find (λlit. lit ∉ M ∧ -lit ∉ M) a of
    None ⇒ find_first_unused_var l M
  | Some a ⇒ Some a)" |
"find_first_unused_var [] _ = None"

(* The search inside a single clause fails exactly when every atom of the
   clause is an atom of M. *)
lemma find_none[iff]:
  "List.find (λlit. lit ∉ M ∧ -lit ∉ M) a = None ⟷  atm_of ` set a ⊆ atm_of `  M"
  apply (induct a)
  using atm_of_in_atm_of_set_iff_in_set_or_uminus_in_set
    by (force simp add:  atm_of_in_atm_of_set_iff_in_set_or_uminus_in_set)+

(* A literal found inside a clause belongs to the clause, and neither it
   nor its negation is in M. *)
lemma find_some: "List.find (λlit. lit ∉ M ∧ -lit ∉ M) a = Some b ⟹ b ∈ set a ∧ b ∉ M ∧ -b ∉ M"
  unfolding find_Some_iff by (metis nth_mem)

(* find_first_unused_var fails exactly when the atoms of every clause are
   all atoms of M. *)
lemma find_first_unused_var_None[iff]:
  "find_first_unused_var l M = None ⟷ (∀a ∈ set l. atm_of ` set a ⊆ atm_of `  M)"
  by (induct l)
     (auto split: option.splits dest!: find_some
       simp add: image_subset_iff atm_of_in_atm_of_set_iff_in_set_or_uminus_in_set)

(* If a literal is found, then not every clause has all its atoms in M. *)
lemma find_first_unused_var_Some_not_all_incl:
  assumes "find_first_unused_var l M = Some c"
  shows "¬(∀a ∈ set l. atm_of ` set a ⊆ atm_of `  M)"
proof -
  have "find_first_unused_var l M ≠ None"
    using assms by (cases "find_first_unused_var l M") auto
  then show "¬(∀a ∈ set l. atm_of ` set a ⊆ atm_of `  M)" by auto
qed

(* The literal returned by find_first_unused_var occurs in some clause of
   the list, and neither it nor its negation is in M. *)
lemma find_first_unused_var_Some:
  "find_first_unused_var l M = Some a ⟹ (∃m ∈ set l. a ∈ set m ∧ a ∉ M ∧ -a ∉ M)"
  by (induct l) (auto split: option.splits dest: find_some)

(* Specialisation to a trail: a literal found with respect to lits_of_l Ms
   is undefined in Ms. *)
lemma find_first_unused_var_undefined:
  "find_first_unused_var l (lits_of_l Ms) = Some a ⟹ undefined_lit Ms a"
  using find_first_unused_var_Some[of l "lits_of_l Ms" a] Decided_Propagated_in_iff_in_lits_of_l
  by blast


subsection ‹CDCL specific functions›

subsubsection ‹Level›

(* List-based maximum of get_level M over the literals of a clause;
   0 for the empty clause. *)
fun maximum_level_code:: "'a literal list ⇒ ('a, 'b) ann_lits ⇒ nat"
  where
"maximum_level_code [] _ = 0" |
"maximum_level_code (L # Ls) M = max (get_level M L) (maximum_level_code Ls M)"

(* The list-based function agrees with get_maximum_level on the multiset of
   the clause; as a simp rule it rewrites maximum_level_code away. *)
lemma maximum_level_code_eq_get_maximum_level[simp]:
  "maximum_level_code D M = get_maximum_level M (mset D)"
  by (induction D) (auto simp add: get_maximum_level_add_mset)

(* Code equation: get_maximum_level on (mset D) is computed by
   maximum_level_code on the list D. *)
lemma [code]:
  fixes M :: "('a, 'b) ann_lits"
  shows "get_maximum_level M (mset D) = maximum_level_code D M"
  by simp

subsubsection ‹Backjumping›
(* Searches the list (second argument) for a literal L of level k such that
   the maximum level j of all the other literals is strictly below the level
   of L. D accumulates the literals already visited, so (D @ Ls) is always
   the clause without the current literal. Returns Some (L, j) on success. *)
fun find_level_decomp where
"find_level_decomp M [] D k = None" |
"find_level_decomp M (L # Ls) D k =
  (case (get_level M L, maximum_level_code (D @ Ls) M) of
    (i, j) ⇒ if i = k ∧ j < i then Some (L, j) else find_level_decomp M Ls (L#D) k
  )"

(* Soundness of find_level_decomp: the returned literal L is in the searched
   list, has level k, and j is the maximum level of the remaining literals
   of (Ls @ D). *)
lemma find_level_decomp_some:
  assumes "find_level_decomp M Ls D k = Some (L, j)"
  shows "L ∈ set Ls ∧ get_maximum_level M (mset (remove1 L (Ls @ D))) = j ∧ get_level M L = k"
  using assms
proof (induction Ls arbitrary: D)
  case Nil
  then show ?case by simp
next
  case (Cons L' Ls) note IH = this(1) and H = this(2)
  (* heavily modified sledgehammer proof *)
  (* find is the unfolded right-hand side of find_level_decomp on (L' # Ls) *)
  define find where "find ≡ (if get_level M L' ≠ k ∨ ¬ get_maximum_level M (mset D + mset Ls) < get_level M L'
    then find_level_decomp M Ls (L' # D) k
    else Some (L', get_maximum_level M (mset D + mset Ls)))"
  have a1: "⋀D. find_level_decomp M Ls D k = Some (L, j) ⟹
     L ∈ set Ls ∧ get_maximum_level M (mset Ls + mset D - {#L#}) = j ∧ get_level M L = k"
    using IH by simp
  have a2: "find = Some (L, j)"
    using H unfolding find_def by (auto split: if_split_asm)
  { assume "Some (L', get_maximum_level M (mset D + mset Ls)) ≠ find"
    then have f3: "L ∈ set Ls" and "get_maximum_level M (mset Ls + mset (L' # D) - {#L#}) = j"
      using a1 IH a2 unfolding find_def by meson+
    moreover then have "mset Ls + mset D - {#L#} + {#L'#} = {#L'#} + mset D + (mset Ls - {#L#})"
      by (auto simp: ac_simps multiset_eq_iff Suc_leI)
    ultimately have f4: "get_maximum_level M (mset Ls + mset D - {#L#} + {#L'#}) = j"
      by auto
  } note f4 = this
  have "{#L'#} + (mset Ls + mset D) = mset Ls + (mset D + {#L'#})"
      by (auto simp: ac_simps)
  then have
    "L = L' ⟶ get_maximum_level M (mset Ls + mset D) = j ∧ get_level M L' = k" and
    "L ≠ L' ⟶ L ∈ set Ls ∧ get_maximum_level M (mset Ls + mset D - {#L#} + {#L'#}) = j ∧
      get_level M L = k"
     using a2 a1[of "L' # D"] unfolding find_def
     apply (metis add.commute add_diff_cancel_left' add_mset_add_single mset.simps(2)
         option.inject prod.inject)
    using f4 a2 a1[of "L' # D"] unfolding find_def by (metis option.inject prod.inject)
  then show ?case by simp
qed

(* Completeness of find_level_decomp: if the search returns None, no literal
   L of the searched list has level k while the rest D of the clause has a
   maximum level strictly below k. *)
lemma find_level_decomp_none:
  assumes "find_level_decomp M Ls E k = None" and "mset (L#D) = mset (Ls @ E)"
  shows "¬(L ∈ set Ls ∧ get_maximum_level M (mset D) < k ∧ k = get_level M L)"
  using assms
proof (induction Ls arbitrary: E L D)
  case Nil
  then show ?case by simp
next
  case (Cons L' Ls) note IH = this(1) and find_none = this(2) and LD = this(3)
  have "mset D + {#L'#} = mset E + (mset Ls + {#L'#}) ⟹  mset D = mset E + mset Ls"
    by (metis add_right_imp_eq union_assoc)
  then show ?case
    using find_none IH[of "L' # E" L D] LD by (auto simp add: ac_simps split: if_split_asm)
qed

(* Walks the trail from its head, skipping propagated literals, and returns
   the suffix that starts at the first decision literal whose tail contains
   exactly i decisions (count_decided Ls = i). Returns None if the trail is
   exhausted without finding such a decision. *)
fun bt_cut where
"bt_cut i (Propagated _ _ # Ls) = bt_cut i Ls" |
"bt_cut i (Decided K # Ls) = (if count_decided Ls = i then Some (Decided K # Ls) else bt_cut i Ls)" |
"bt_cut i [] = None"

(* Soundness of bt_cut on a duplicate-free trail: the result is a suffix of
   M that starts with a decision literal K whose level in M is i+1. *)
lemma bt_cut_some_decomp:
  assumes "no_dup M" and "bt_cut i M = Some M'"
  shows "∃K M2 M1. M = M2 @ M' ∧ M' = Decided K # M1 ∧ get_level M K = (i+1)"
  using assms by (induction i M rule: bt_cut.induct) (auto simp: no_dup_def split: if_split_asm)

(* Completeness of bt_cut: if a duplicate-free trail contains a decision K
   of level i+1, then bt_cut i succeeds on it. *)
lemma bt_cut_not_none:
  assumes "no_dup M" and "M = M2 @ Decided K # M'" and "get_level M K = (i+1)"
  shows "bt_cut i M ≠ None"
  using assms by (induction M2 arbitrary: M rule: ann_lit_list_induct)
  (auto simp: no_dup_def atm_lit_of_set_lits_of_l)

(* Every suffix of a trail that starts with a decision literal appears as
   the first component of some pair of get_all_ann_decomposition. *)
lemma get_all_ann_decomposition_ex:
  "∃N. (Decided K # M', N) ∈ set (get_all_ann_decomposition (M2@Decided K # M'))"
  apply (induction M2 rule: ann_lit_list_induct)
    apply auto[2]
  by (rename_tac L m xs,  case_tac "get_all_ann_decomposition (xs @ Decided K # M')")
  auto

(* The trail suffix returned by bt_cut is the first component of some pair
   of get_all_ann_decomposition M. *)
lemma bt_cut_in_get_all_ann_decomposition:
  assumes "no_dup M" and "bt_cut i M = Some M'"
  shows "∃M2. (M', M2) ∈ set (get_all_ann_decomposition M)"
  using bt_cut_some_decomp[OF assms] by (auto simp add: get_all_ann_decomposition_ex)

(* One backtrack step on a state (M, N, U, conflict).
   With a conflict clause D: find_level_decomp looks for a literal L of D at
   level count_decided M together with the maximum level j of the other
   literals; bt_cut j then cuts the trail at a decision literal, which is
   replaced by Propagated L D. D is prepended to U and the conflict is
   cleared. If either search fails, or there is no conflict, the state is
   returned unchanged. *)
fun do_backtrack_step where
"do_backtrack_step (M, N, U, Some D) =
  (case find_level_decomp M D [] (count_decided M) of
    None ⇒ (M, N, U, Some D)
  | Some (L, j) ⇒
    (case bt_cut j M of
      Some (Decided _ # Ls) ⇒ (Propagated L D # Ls, N, D # U, None)
    | _ ⇒ (M, N, U, Some D))
  )" |
"do_backtrack_step S = S"

end