(* Theory Watched_Literals_Watch_List_Restart *)

theory Watched_Literals_Watch_List_Restart
  imports Watched_Literals_Watch_List
    Watched_Literals_List_Simp
begin

(*TODO Move*)
(* A single cdcl_twl_restart step does not change get_all_init_clss.
   Proved by induction over the rules of cdcl_twl_restart; every case is closed by auto. *)
lemma cdcl_twl_restart_get_all_init_clss:
  assumes cdcl_twl_restart S T
  shows get_all_init_clss T = get_all_init_clss S
  using assms by (induction rule: cdcl_twl_restart.induct) auto

(* Extension of cdcl_twl_restart_get_all_init_clss to the reflexive-transitive closure of
   cdcl_twl_restart: any number of restart steps preserves get_all_init_clss.
   Proved by rtranclp_induct, using the single-step lemma in the step case. *)
lemma rtranclp_cdcl_twl_restart_get_all_init_clss:
  assumes cdcl_twl_restart** S T
  shows get_all_init_clss T = get_all_init_clss S
  using assms by (induction rule: rtranclp_induct) (auto simp: cdcl_twl_restart_get_all_init_clss)
(*END Move*)

text ‹As we have a specialised version of \<^term>‹correct_watching›, we define a special version for
the inclusion of the domain:›

(* all_init_lits N NUE: the literals (computed with all_lits_of_mm) of the clauses selected by
   init_clss_lf from the clause map N, each converted with mset, plus the clause multiset NUE.
   Clauses of N that init_clss_lf does not select do not contribute. *)
definition all_init_lits :: (nat, 'v literal list × bool) fmap  'v literal multiset multiset 
   'v literal multiset where
  all_init_lits S NUE = all_lits_of_mm ((λC. mset C) `# init_clss_lf S + NUE)

(* Unfolded forms of all_init_lits when the extra clauses are a sum of three multisets.
   The first form keeps init_clss_lf. The second expresses the clause part through init_clss_l,
   taking mset of the first component of each entry.
   Proved by unfolding the definition modulo associativity/commutativity (ac_simps). *)
lemma all_init_lits_alt_def:
  all_init_lits S (NUE + NUS + N0S) = all_lits_of_mm ((λC. mset C) `# init_clss_lf S + NUE + NUS + N0S)
  all_init_lits b (d + f + g) = all_lits_of_mm ({#mset (fst x). x ∈# init_clss_l b#} + d + f + g)
  by (auto simp: all_init_lits_def ac_simps)

(* abbreviation all_init_lits_of_wl :: ‹'v twl_st_wl ⇒ 'v literal multiset› where
 *   ‹all_init_lits_of_wl S ≡ all_init_lits (get_clauses_wl S)
 *     (get_unit_init_clss_wl S + get_subsumed_init_clauses_wl S + get_init_clauses0_wl S)› *)

(* all_init_atms N NUE: the multiset of atoms (atm_of) of all_init_lits N NUE.
   The declaration below adds the symmetric equation to the simpset, so that
   atm_of `# all_init_lits N NUE is folded back into all_init_atms N NUE. Proofs that need to
   unfold all_init_atms therefore first remove it with simp del: all_init_atms_def[symmetric]. *)
definition all_init_atms :: _  _  'v multiset where
  all_init_atms N NUE = atm_of `# all_init_lits N NUE

declare all_init_atms_def[symmetric, simp]

(* The set of atoms of all_init_atms N NE, described through the atoms of the
   mset-converted clauses init_clss_lf N and the atoms of NE. The two sides are presumably
   combined by a union; that symbol is not visible here.
   Proof: unfold both definitions and reason on membership through the clause map. *)
lemma all_init_atms_alt_def:
  set_mset (all_init_atms N NE) = atms_of_mm (mset `# init_clss_lf N)  atms_of_mm NE
  unfolding all_init_atms_def all_init_lits_def
  by (auto simp: in_all_lits_of_mm_ain_atms_of_iff
      all_lits_of_mm_def atms_of_ms_def image_UN
      atms_of_def
    dest!: multi_member_split[of (_, _) ran_m N]
    dest: multi_member_split atm_of_lit_in_atms_of
    simp del: set_image_mset)

(* Membership in all_init_atms bu bw: y is an atom of the initial clauses of bu or an atom of
   bw. The connectives are presumably iff/or; they are not visible here.
   NOTE(review): the dest! instantiation mentions N, which does not occur in this statement.
   It looks copied from all_init_atms_alt_def and probably contributes nothing. *)
lemma in_set_all_init_atms_iff:
  y ∈# all_init_atms bu bw 
    y  atms_of_mm (mset `# init_clss_lf bu)  y  atms_of_mm bw
  by (auto simp: all_lits_def atm_of_all_lits_of_mm all_init_atms_alt_def atms_of_def
        in_all_lits_of_mm_ain_atms_of_iff all_lits_of_mm_def atms_of_ms_def image_UN
    dest!: multi_member_split[of (_, _) ran_m N]
    dest: multi_member_split atm_of_lit_in_atms_of
    simp del: set_image_mset)

(* all_init_atms is unchanged in two cases:
   - an irredundant clause C (irred baa C) is removed from baa with fmdrop and its mset is added
     to the extra clauses da;
   - a clause C that is not irred baa C is removed from baa with fmdrop.
   Proof: disable the folding simp rule, then unfold all_init_atms and all_init_lits. *)
lemma all_init_atms_fmdrop_add_mset_unit:
  C ∈# dom_m baa  irred baa C 
    all_init_atms (fmdrop C baa) (add_mset (mset (baa  C)) da) =
   all_init_atms baa da
  C ∈# dom_m baa  ¬irred baa C 
    all_init_atms (fmdrop C baa) da =
   all_init_atms baa da
  by (auto simp del: all_init_atms_def[symmetric]
    simp: all_init_atms_def all_init_lits_def
      init_clss_l_fmdrop_irrelev image_mset_remove1_mset_if)

(* Simplification rules for all_init_lits_of_wl. The value does not change when:
   - a non-irred clause C is dropped from N (first rule);
   - an irred clause C is moved from N into NE (fourth rule);
   - an element E is added to UE (fifth rule).
   The other rules rewrite US, UEk, U0 and UE to the empty multiset and the trail M to [],
   so all_init_lits_of_wl does not depend on these components.
   The NO_MATCH premises stop these rules from firing on an already normalised state; without
   them the simplifier would rewrite the right-hand side again and loop. *)
lemma all_init_lits_of_wl_simps[simp]:
  C ∈# dom_m N  ¬irred N C 
  all_init_lits_of_wl (M, fmdrop C N, D, NE, UE, NEk, UEk, NS, US, N0, U0, Q, W) =
    all_init_lits_of_wl (M, N, D, NE, UE, NEk, UEk, NS, US, N0, U0, Q, W)
  NO_MATCH {#} US 
  all_init_lits_of_wl (M, N, D, NE, UE, NEk, UEk, NS, US, N0, U0, Q, W) =
    all_init_lits_of_wl (M, N, D, NE, UE, NEk, UEk, NS, {#}, N0, U0, Q, W)
  NO_MATCH [] M 
  all_init_lits_of_wl (M, N, D, NE, UE, NEk, UEk, NS, US, N0, U0, Q, W) =
    all_init_lits_of_wl ([], N, D, NE, UE, NEk, UEk, NS, US, N0, U0, Q, W)
  C ∈# dom_m N  irred N C 
   all_init_lits_of_wl (M, fmdrop C N, D, add_mset (mset (N  C)) NE, UE, NEk, UEk, NS, US, N0, U0, Q, W) =
  all_init_lits_of_wl (M, N, D, NE, UE, NEk, UEk, NS, US, N0, U0, Q, W)
  all_init_lits_of_wl (M, N, D, NE, add_mset E UE, NEk, UEk, NS, US, N0, U0, Q, W) =
    all_init_lits_of_wl (M, N, D, NE, UE, NEk, UEk, NS, US, N0, U0, Q, W)
  NO_MATCH {#} UEk 
  all_init_lits_of_wl (M, N, D, NE, UE, NEk, UEk, NS, US, N0, U0, Q, W) =
    all_init_lits_of_wl (M, N, D, NE, UE, NEk, {#}, NS, US, N0, U0, Q, W)
  NO_MATCH {#} U0 
  all_init_lits_of_wl (M, N, D, NE, UE, NEk, UEk, NS, US, N0, U0, Q, W) =
    all_init_lits_of_wl (M, N, D, NE, UE, NEk, UEk, NS, US, N0, {#}, Q, W)
  NO_MATCH {#} UE 
  all_init_lits_of_wl (M, N, D, NE, UE, NEk, UEk, NS, US, N0, U0, Q, W) =
    all_init_lits_of_wl (M, N, D, NE, {#}, NEk, UEk, NS, US, N0, U0, Q, W)
  by (auto simp: all_init_lits_of_wl_def all_lits_of_mm_add_mset
    image_mset_remove1_mset_if)

(* Simplification rules for all_learned_lits_of_wl. The value does not change when:
   - an irred clause C is dropped from N;
   - the trail M is replaced by [] (guarded by NO_MATCH so the rule does not loop);
   - an element E is added to NE or to NS;
   - a non-irred clause C is moved from N into UE.
   A normalisation rule for NS is kept commented out inside the statement; it still uses an
   older, shorter state tuple. *)
lemma all_learned_lits_of_wl_simps[simp]:
  C ∈# dom_m N  irred N C 
  all_learned_lits_of_wl (M, fmdrop C N, D, NE, UE, NEk, UEk, NS, US, N0, U0, Q, W) =
    all_learned_lits_of_wl (M, N, D, NE, UE, NEk, UEk, NS, US, N0, U0, Q, W)
  (* ‹NO_MATCH {#} NS ⟹
   * all_learned_lits_of_wl (M, N, D, NE, UE, NS, US, N0, U0, Q, W) =
   *   all_learned_lits_of_wl (M, N, D, NE, UE, {#}, US, N0, U0, Q, W)› *)
  NO_MATCH [] M 
  all_learned_lits_of_wl (M, N, D, NE, UE, NEk, UEk, NS, US, N0, U0, Q, W) =
    all_learned_lits_of_wl ([], N, D, NE, UE, NEk, UEk, NS, US, N0, U0, Q, W)
  all_learned_lits_of_wl (M, N, D, add_mset E NE, UE, NEk, UEk, NS, US, N0, U0, Q, W) =
    all_learned_lits_of_wl (M, N, D, NE, UE, NEk, UEk, NS, US, N0, U0, Q, W)
  all_learned_lits_of_wl (M, N, D, NE, UE, NEk, UEk, add_mset E NS, US, N0, U0, Q, W) =
    all_learned_lits_of_wl (M, N, D, NE, UE, NEk, UEk, NS, US, N0, U0, Q, W)
  C ∈# dom_m N  ¬irred N C 
  all_learned_lits_of_wl (M, fmdrop C N, D, NE, add_mset (mset (N  C)) UE, NEk, UEk, NS, US, N0, U0, Q, W) =
  all_learned_lits_of_wl (M, N, D, NE, UE, NEk, UEk, NS, US, N0, U0, Q, W)
  by (auto simp: all_learned_lits_of_wl_def all_lits_of_mm_add_mset
    image_mset_remove1_mset_if)
  
text ‹To ease the proof, we introduce the following ``alternative'' definitions, which only consider
  variables that are present in the initial clauses (these are never deleted from the set of
  clauses, but only moved to another component).›

(* Variant of correct_watching that only quantifies over the literals in
   all_init_lits_of_wl of the state. For each such literal L, the watch list W L must satisfy:
   - no duplicates (distinct_watched);
   - for each entry (i, K, b): a condition linking i ∈# dom_m N to three facts, namely K occurs
     in the clause of N at index i, K differs from L, and correctly_marked_as_binary N (i, K, b);
   - for each entry (i, K, b): a condition linking the flag b to i ∈# dom_m N;
   - after keeping only the indices still in dom_m N, W L lists exactly clause_to_update L of
     the state, with its last two components (Q and W) replaced by empty multisets.
   The connectives inside these conjuncts are not visible here; each "linking" condition is
   presumably an implication. *)
fun correct_watching' :: 'v twl_st_wl  bool where
  correct_watching' (M, N, D, NE, UE, NEk, UEk, NS, US, N0, U0, Q, W) 
    (L ∈# all_init_lits_of_wl (M, N, D, NE, UE, NEk, UEk, NS, US, N0, U0, Q, W).
       distinct_watched (W L) 
       ((i, K, b)∈#mset (W L).
             i ∈# dom_m N  K  set (N  i)  K  L  correctly_marked_as_binary N (i, K, b)) 
       ((i, K, b)∈#mset (W L).
             b  i ∈# dom_m N) 
        filter_mset (λi. i ∈# dom_m N) (fst `# mset (W L)) =
          clause_to_update L (M, N, D, NE, UE, NEk, UEk, NS, US, N0, U0, {#}, {#}))

(*TODO duplicate of leaking bin*)
(* Same as correct_watching', but without the conjunct that links the flag b of each
   watch-list entry (i, K, b) to i ∈# dom_m N. *)
fun correct_watching'_nobin :: 'v twl_st_wl  bool where
  correct_watching'_nobin (M, N, D, NE, UE, NEk, UEk, NS, US, N0, U0, Q, W) 
    (L ∈# all_init_lits_of_wl (M, N, D, NE, UE, NEk, UEk, NS, US, N0, U0, Q, W).
       distinct_watched (W L) 
       ((i, K, b)∈#mset (W L).
             i ∈# dom_m N  K  set (N  i)  K  L  correctly_marked_as_binary N (i, K, b)) 
        filter_mset (λi. i ∈# dom_m N) (fst `# mset (W L)) =
          clause_to_update L (M, N, D, NE, UE, NEk, UEk, NS, US, N0, U0, {#}, {#}))

(* NOTE(review): as written, the premise and the conclusion are the same proposition,
   correct_watching' S, so the lemma is trivial. Given correct_watching'_nobin just above, the
   intended conclusion was presumably correct_watching'_nobin S; confirm before relying on it.
   Proof: case split on the state tuple, then auto; the function equations are still simp
   rules at this point. *)
lemma correct_watching'_correct_watching': correct_watching' S  correct_watching' S
  by (cases S) auto

(* Remove the defining equations of correct_watching'_nobin and correct_watching' from the
   simpset. Later proofs unfold them explicitly via correct_watching'_nobin.simps and
   correct_watching'.simps. *)
declare correct_watching'_nobin.simps[simp del] correct_watching'.simps[simp del]

text ‹Now comes a weaker version of the invariants on watch lists: instead of knowing that
  the watch lists are correct, we only know that the clauses appear somewhere in the watch lists.
  From a conceptual point of view, this is sufficient to specify all operations, but it is not
  sufficient to derive bounds on the length. Hence, we also add the invariant that each watch list
  does not contain duplicates.›

(* no_lost_clause_in_WL S has two conjuncts:
   - every clause index in dom_m (get_clauses_wl S) is related to clauses_pointed_to over the
     watch lists of the literals in all_init_lits_of_wl S (the relation symbol, presumably a
     subset, is not visible here);
   - the watch list of every such literal is duplicate-free (distinct_watched). *)
definition no_lost_clause_in_WL :: 'v twl_st_wl  bool where
  no_lost_clause_in_WL S 
  set_mset (dom_m (get_clauses_wl S))
     clauses_pointed_to (set_mset (all_init_lits_of_wl S)) (get_watched_wl S) 
  (L∈# all_init_lits_of_wl S. distinct_watched (watched_by S L))

(* no_lost_clause_in_WL0 S: only the first conjunct of no_lost_clause_in_WL. The clause
   indices of S are related to the clauses pointed to by the watch lists of
   all_init_lits_of_wl S; there is no distinctness requirement. *)
definition no_lost_clause_in_WL0 :: 'v twl_st_wl  bool where
  no_lost_clause_in_WL0 S 
  set_mset (dom_m (get_clauses_wl S))
   clauses_pointed_to (set_mset (all_init_lits_of_wl S)) (get_watched_wl S)


(* blits_in_ℒin' S: for every literal L of all_init_lits_of_wl S, each blocking literal K
   stored in an entry (i, K, b) of watched_by S L is itself in all_init_lits_of_wl S. *)
definition blits_in_ℒin' :: 'v twl_st_wl  bool where
  blits_in_ℒin' S 
    (L ∈# all_init_lits_of_wl S.
      (i, K, b)  set (watched_by S L). K ∈# all_init_lits_of_wl S)

(* literals_are_ℒin' S has two conjuncts:
   - the literal set of all_learned_lits_of_wl S is related to that of all_init_lits_of_wl S
     (presumably a subset; the symbol is not visible here);
   - blits_in_ℒin' S holds. *)
definition literals_are_ℒin' :: 'v twl_st_wl  bool where
  literals_are_ℒin' S 
    set_mset (all_learned_lits_of_wl S)  set_mset (all_init_lits_of_wl S) 
     blits_in_ℒin' S

(* all_init_atms_st S: all_init_atms of the clause map of S. The extra clauses passed in are
   the sum of the unit initial clauses, the subsumed initial clauses and the initial clauses0
   component of S. *)
definition all_init_atms_st :: 'v twl_st_wl  'v multiset where
  all_init_atms_st S  all_init_atms (get_clauses_wl S)
    (get_unit_init_clss_wl S + get_subsumed_init_clauses_wl S + get_init_clauses0_wl S)

(* all_init_atms_st S is the multiset of atoms of all_init_lits_of_wl S.
   Proof: unfold the definitions (with the folding simp rule for all_init_atms disabled) and
   normalise the sums with ac_simps. *)
lemma all_init_atms_st_alt_def: all_init_atms_st S = atm_of `# all_init_lits_of_wl S
  by (auto simp: all_atms_def all_lits_st_def all_init_atms_st_def all_init_lits_of_wl_def
    atm_of_all_lits_of_mm all_init_atms_def all_init_lits_def ac_simps
    simp del: all_init_atms_def[symmetric])

(* The literals built by all from all_init_atms N NU (resp. all_init_atms_st S) are, as sets,
   the literals of all_init_lits N NU (resp. all_init_lits_of_wl S). Only set equality is
   stated, not multiset equality. *)
lemma all_all_init_atms:
  set_mset (all (all_init_atms N NU)) = set_mset (all_init_lits N NU)
  set_mset (all (all_init_atms_st S)) = set_mset (all_init_lits_of_wl S)
  by (simp_all add: all_atm_of_all_lits_of_mm all_init_atms_def all_init_lits_def
    all_init_lits_of_wl_def ac_simps all_init_atms_st_def)

(* literals_are_ℒin depends only on the underlying set of its atom multiset: two multisets
   with the same set_mset give the same result. Proved from all_cong. *)
lemma literals_are_ℒin_cong:
  set_mset 𝒜 = set_mset   literals_are_ℒin 𝒜 S = literals_are_ℒin  S
  using all_cong[of 𝒜 ]
  unfolding literals_are_ℒin_def blits_in_ℒin_def is_ℒall_def
  by auto

(* Relates the literal set of all_learned_lits_of_wl S to that of all_lits_st S (presumably
   inclusion). Proof: rewrite the clause map of all_lits_st through all_clss_l_ran_m, split
   the unions, then case split on S. *)
lemma all_learned_lits_of_wl_all_lits_st:
  set_mset (all_learned_lits_of_wl S)  set_mset (all_lits_st S)
  unfolding all_learned_lits_of_wl_def all_lits_st_def all_lits_def
  apply (subst (2) all_clss_l_ran_m[symmetric])
  unfolding image_mset_union
  by (cases S) (auto simp: all_lits_of_mm_union)

(* The literal set of all_lits_st S, expressed through the literal sets of
   all_init_lits_of_wl S and all_learned_lits_of_wl S (presumably their union). *)
lemma all_lits_st_init_learned:
  set_mset (all_lits_st S) = set_mset (all_init_lits_of_wl S)  set_mset (all_learned_lits_of_wl S)
  unfolding all_learned_lits_of_wl_def all_lits_st_def all_lits_def all_init_lits_of_wl_def
  apply (subst (1) all_clss_l_ran_m[symmetric])
  unfolding image_mset_union
  by (cases S) (auto simp: all_lits_of_mm_union)

(* The literals built by all from all_atms_st S are, as a set, the literals of
   all_lits_st S. *)
lemma all_all_atms:
  set_mset (all (all_atms_st S)) = set_mset (all_lits_st S)
  by (metis all_atm_of_all_lits_of_mm all_atms_st_alt_def_sym all_lits_def all_lits_st_def)

(* Setting: S is related to x by state_wl_l None, x to xa by twl_st_l None, and xa satisfies
   twl_struct_invs. Then:
   - ?A and ?B relate literals_are_ℒin' S to literals_are_ℒin (all_atms_st S) S. As written,
     ?B states the same proposition as ?A, and its proof just rewrites with A.
   - ?C: all_init_atms_st S and all_atms_st S have the same set of atoms.
   - ?D: all_init_lits_of_wl S and all_lits_st S have the same set of literals.
   Key step: no_strange_atm, obtained from twl_struct_invs, relates the atoms of the learned
   clauses to those of the initial clauses (fact alien_learned). That yields ?D as fact 1, and
   the other goals follow from 1.
   NOTE(review): the variable 𝒜 bound by "for 𝒜" in "show A" does not occur in that claim. *)
lemma literals_are_ℒin'_literals_are_ℒin_iff:
  assumes
    Sx: (S, x)  state_wl_l None and
    x_xa: (x, xa)  twl_st_l None and
    struct_invs: twl_struct_invs xa
  shows
    literals_are_ℒin' S  literals_are_ℒin (all_atms_st S) S (is ?A)
    literals_are_ℒin' S  literals_are_ℒin (all_atms_st S) S (is ?B)
    set_mset (all_init_atms_st S) = set_mset (all_atms_st S) (is ?C) and
    set_mset (all_init_lits_of_wl S) = set_mset (all_lits_st S) (is ?D)
proof -
  have cdclW_restart_mset.no_strange_atm (stateW_of xa)
    using struct_invs unfolding twl_struct_invs_def cdclW_restart_mset.cdclW_all_struct_inv_def
      pcdcl_all_struct_invs_def stateW_of_def
    by fast+
  then have L. L  atm_of ` lits_of_l (get_trail_wl S)  L  atms_of_ms
      ((λx. mset (fst x)) ` {a. a ∈# ran_m (get_clauses_wl S)  snd a}) 
      atms_of_mm (get_unit_init_clss_wl S) 
      atms_of_mm (get_subsumed_init_clauses_wl S) 
      atms_of_mm (get_init_clauses0_wl S) and
    alien_learned: atms_of_mm (learned_clss (stateW_of xa))
       atms_of_mm (init_clss (stateW_of xa))
    using Sx x_xa unfolding cdclW_restart_mset.no_strange_atm_def
    by (auto 5 2 simp add: twl_st twl_st_l twl_st_wl)

  show 1: set_mset (all_init_lits_of_wl S) = set_mset (all_lits_st S)
    unfolding all_lits_st_def all_lits_def all_init_lits_of_wl_def
    apply (subst (2) all_clss_l_ran_m[symmetric])
    using alien_learned Sx x_xa
    unfolding image_mset_union all_lits_of_mm_union
    by (auto simp : in_all_lits_of_mm_ain_atms_of_iff get_unit_clauses_wl_alt_def
      twl_st twl_st_l twl_st_wl get_learned_clss_wl_def)
  have set_mset (all_init_lits_of_wl S) = set_mset (all (all_init_atms_st S))
    unfolding all_all_init_atms(2) ..
 show A: literals_are_ℒin' S  literals_are_ℒin (all_atms_st S) S for 𝒜
  proof -
    have sub: set_mset (all_lits_st S)  set_mset (all_init_lits_of_wl S) 
      is_ℒall (all_init_atms_st S) (all_lits_st S)
     using all_init_lits_of_wl_all_lits_st[of S]
     unfolding is_ℒall_def all_all_init_atms(2) by auto
   have set_mset (all_learned_lits_of_wl S)  set_mset (all_lits_st S) 
     is_ℒall (all_atms_st S) (all_lits_st S)
     using all_init_lits_of_wl_all_lits_st[of S]
     unfolding all_lits_st_init_learned is_ℒall_def all_all_atms
     by auto
   then show ?thesis
      unfolding literals_are_ℒin'_def
	literals_are_ℒin_def blits_in_ℒin_def blits_in_ℒin'_def sub
	all_init_lits_def[symmetric] all_lits_alt_def2[symmetric]
        all_lits_alt_def[symmetric] all_init_lits_alt_def[symmetric]
        is_ℒall_def[symmetric] all_init_atms_def[symmetric] 1
      by simp
   qed

   show C: ?C
     using 1 unfolding all_atms_st_alt_def all_init_atms_st_alt_def
     apply (simp add: 1 del: all_init_atms_def[symmetric])
     by (metis all_atms_st_alt_def set_image_mset)

  show ?B
    apply (subst A)
    ..
qed

(* Assumptions: correct_watching'_nobin xa and literals_are_ℒin' xa hold, and xa is related to
   an abstract state x, via xb, that satisfies twl_struct_invs. Conclusions:
   - ?G1: every clause index in dom_m of the clauses of xa is related to clauses_pointed_to
     over the watch lists of the Neg and Pos literals of the atoms in all_init_atms_st xa;
   - ?G2: no_lost_clause_in_WL xa.
   Proof of ?G1 for a clause index C:
   - twl_st_inv x yields le0, a fact about the clause at C (presumably that it is non-empty);
   - hence its first literal is among its watched literals (le);
   - that literal lies in all_init_lits_of_wl xa (H);
   - the clause_to_update conjunct of correct_watching'_nobin then places C in its watch list.
   ?G2 combines ?G1, the equality between all_init_lits_of_wl xa and the Neg/Pos literals of
   all_init_atms_st xa, and the distinctness conjunct of correct_watching'_nobin.
   NOTE(review): the abbreviation ?𝒜 introduced by let is never used. *)
lemma correct_watching'_nobin_clauses_pointed_to0:
  assumes
    xa_xb: (xa, xb)  state_wl_l None and
    corr: correct_watching'_nobin xa and
    L: literals_are_ℒin' xa and
    xb_x: (xb, x)  twl_st_l None and
    struct_invs: twl_struct_invs x

  shows set_mset (dom_m (get_clauses_wl xa))
     clauses_pointed_to
    (Neg ` set_mset (all_init_atms_st xa) 
    Pos ` set_mset (all_init_atms_st xa))
    (get_watched_wl xa)
    (is ?G1 is _  ?A) and
    no_lost_clause_in_WL xa (is ?G2)
proof -
  let ?𝒜 = all_init_atms (get_clauses_wl xa) (get_unit_init_clss_wl xa)
  show ?G1
  proof
    fix C
    assume C: C ∈# dom_m (get_clauses_wl xa)
    obtain M N D NE UE NEk UEk NS US N0 U0 Q W where
      xa: xa = (M, N, D, NE, UE, NEk, UEk, NS, US, N0, U0, Q, W)
      by (cases xa)
    have twl_st_inv x
      using xb_x C struct_invs
      by (auto simp: twl_struct_invs_def
        cdclW_restart_mset.cdclW_all_struct_inv_def)
    then have le0: get_clauses_wl xa  C  []
      using xb_x C xa_xb
      by (cases x; cases irred N C)
        (auto simp: twl_struct_invs_def twl_st_inv.simps
          twl_st_l_def state_wl_l_def xa ran_m_def conj_disj_distribR
          Collect_disj_eq Collect_conv_if
        dest!: multi_member_split)
    then have le: N  C ! 0  set (watched_l (N  C))
      by (cases N  C) (auto simp: xa)
    have eq: set_mset (all (all_init_atms N NE)) =
          set_mset (all_lits_of_mm (mset `# init_clss_lf N + NE))
       by (auto simp del: all_init_atms_def[symmetric]
          simp: all_init_atms_def xa all_atm_of_all_lits_of_mm[symmetric]
            all_init_lits_def)

    have H: get_clauses_wl xa  C ! 0 ∈# all_init_lits_of_wl xa
      using L C le0 apply -
      unfolding all_init_atms_def[symmetric] all_init_lits_def[symmetric]
      apply (subst literals_are_ℒin'_literals_are_ℒin_iff(4)[OF xa_xb xb_x struct_invs])
      apply (cases N  C; auto simp: literals_are_ℒin_def all_lits_def ran_m_def eq
            all_lits_of_mm_add_mset is_ℒall_def xa all_lits_of_m_add_mset
            all_all_atms_all_lits all_lits_st_def
          dest!: multi_member_split)
      done
    moreover {
      have {#i ∈# fst `# mset (W (N  C ! 0)). i ∈# dom_m N#} =
             add_mset C {#Ca ∈# remove1_mset C (dom_m N). N  C ! 0  set (watched_l (N  Ca))#}
        using corr H C le unfolding xa
        by (auto simp: clauses_pointed_to_def correct_watching'_nobin.simps xa
          simp flip: all_init_atms_def all_init_lits_def all_init_atms_alt_def
            all_init_lits_alt_def
          simp: clause_to_update_def
          simp del: all_init_atms_def[symmetric]
          dest!: multi_member_split)
      from arg_cong[OF this, of set_mset] have C  fst ` set (W (N  C ! 0))
        using corr H C le unfolding xa
        by (auto simp: clauses_pointed_to_def correct_watching'.simps xa
          simp: all_init_atms_def all_init_lits_def clause_to_update_def
          simp del: all_init_atms_def[symmetric]
          dest!: multi_member_split) }
    ultimately show C  ?A
      by (cases N  C ! 0)
        (auto simp: clauses_pointed_to_def correct_watching'.simps xa
          simp flip: all_init_lits_def all_init_atms_alt_def
            all_init_lits_alt_def 
          simp: clause_to_update_def all_init_atms_st_def all_init_lits_of_wl_def all_init_atms_def
          simp del: all_init_atms_def[symmetric]
        dest!: multi_member_split)
  qed
  moreover have set_mset (all_init_lits_of_wl xa) =
    Neg ` set_mset (all_init_atms_st xa)  Pos ` set_mset (all_init_atms_st xa)
    unfolding all_init_lits_of_wl_def
      all_lits_of_mm_def all_init_atms_st_def all_init_atms_def
    by (auto simp: all_init_atms_def all_init_lits_def all_lits_of_mm_def image_image
      image_Un
      simp del: all_init_atms_def[symmetric]) 
  moreover have distinct_watched (watched_by xa (Pos L))
    distinct_watched (watched_by xa (Neg L))
    if L ∈# all_init_atms_st xa for L
    using that corr
    by (cases xa;
        auto simp: correct_watching'_nobin.simps all_init_lits_of_wl_def all_init_atms_def
        all_lits_of_mm_union all_init_lits_def all_init_atms_st_def literal.atm_of_def
        in_all_lits_of_mm_uminus_iff[symmetric, of Pos _]
        simp del: all_init_atms_def[symmetric]
        split: literal.splits; fail)+
  ultimately show ?G2
    unfolding no_lost_clause_in_WL_def
    by (auto simp del: all_init_atms_def[symmetric]) 
qed

(* Same conclusions as correct_watching'_nobin_clauses_pointed_to0. Here the refinement of xb
   to an abstract state satisfying twl_struct_invs is not assumed directly; it is obtained by
   unfolding mark_to_delete_clauses_l_GC_pre xb, and the result is closed with fast. *)
lemma correct_watching'_clauses_pointed_to2:
  assumes
    xa_xb: (xa, xb)  state_wl_l None and
    corr: correct_watching'_nobin xa and
    pre: mark_to_delete_clauses_l_GC_pre xb and
    L: literals_are_ℒin' xa
  shows set_mset (dom_m (get_clauses_wl xa))
          clauses_pointed_to
            (Neg ` set_mset (all_init_atms_st xa) 
             Pos ` set_mset (all_init_atms_st xa))
            (get_watched_wl xa)
        (is ?G1 is _  ?A) and
    no_lost_clause_in_WL xa (is ?G2)
  using correct_watching'_nobin_clauses_pointed_to0[OF xa_xb corr L] pre
  unfolding mark_to_delete_clauses_l_GC_pre_def
  by fast+


(* restart_abs_wl_pre S last_GC last_Restart brk holds when:
   - S is related by state_wl_l None to some S' satisfying
     restart_abs_l_pre S' last_GC last_Restart brk;
   - S satisfies correct_watching and blits_in_ℒin. *)
definition (in -) restart_abs_wl_pre :: 'v twl_st_wl  nat  nat  bool  bool where
  restart_abs_wl_pre S last_GC last_Restart brk 
    (S'. (S, S')  state_wl_l None  restart_abs_l_pre S' last_GC last_Restart brk
       correct_watching S  blits_in_ℒin S)

(* Local restart on a watch-list state.
   First it asserts restart_abs_wl_pre for some last_GC and last_Restart, with brk = False.
   Then the trail M and the queue Q are chosen nondeterministically:
   - either M becomes some M' such that (Decided K # M', M2) is in
     get_all_ann_decomposition M, and Q becomes empty;
   - or both M and Q stay unchanged.
   All other components are returned unchanged. *)
definition (in -) cdcl_twl_local_restart_wl_spec :: 'v twl_st_wl  'v twl_st_wl nres where
  cdcl_twl_local_restart_wl_spec = (λ(M, N, D, NE, UE, NEk, UEk, NS, US, N0, U0, Q, W). do {
      ASSERT(last_GC last_Restart. restart_abs_wl_pre (M, N, D, NE, UE, NEk, UEk, NS, US, N0, U0, Q, W) last_GC last_Restart False);
      (M, Q)  SPEC(λ(M', Q'). (K M2. (Decided K # M', M2)  set (get_all_ann_decomposition M) 
            Q' = {#})  (M' = M  Q' = Q));
      RETURN (M, N, D, NE, UE, NEk, UEk, NS, US, N0, U0, Q, W)
   })

(* cdcl_twl_local_restart_wl_spec refines cdcl_twl_local_restart_l_spec. Input and output use
   the same relation: state_wl_l None together with correct_watching and blits_in_ℒin.
   Helper facts:
   - the [simp] facts reassociate the sums NE + UE + NS + US + N0 + U0 passed to all_lits;
   - the [refine0] fact relates the two SPECs after refine_vcg has split the state tuples into
     the x1*/x2* components. *)
lemma cdcl_twl_local_restart_wl_spec_cdcl_twl_local_restart_l_spec:
  (cdcl_twl_local_restart_wl_spec, cdcl_twl_local_restart_l_spec)
     {(S, T). (S, T)  state_wl_l None  correct_watching S  blits_in_ℒin S} f
      {(S, T). (S, T)  state_wl_l None  correct_watching S  blits_in_ℒin S}nres_rel
proof -
  have [simp]:
    all_lits N (NE + UE + (NS + US) + (N0 + U0)) = all_lits N (NE + UE + NS + US + N0 + U0)
    all_lits N ((NE + UE) + (NS + US) + (N0 + U0)) = all_lits N (NE + UE + NS + US + N0 + U0)
    for NE UE NS US N N0 U0
    by (auto simp: ac_simps)
  have [refine0]:
    x y x1 x2 x1a x2a x1b x2b x1c x2c x1d x2d x1e x2e x1f x2f x1g x2g x1h x2h x1i x2i x1j x2j x1k x2k x1l x2l x1m x2m x1n
    x2n x1o x2o x1p x2p x1q x2q x1r x2r x1s x2s x1t x2t x1u x2u x1v x2v x1w x2w xa x1x x2x.
    (x, y)  {(S, T). (S, T)  state_wl_l None  correct_watching S  blits_in_ℒin S} 
    x2j = (x1k, x2k) 
    x2i = (x1j, x2j) 
    x2h = (x1i, x2i) 
    x2g = (x1h, x2h) 
    x2f = (x1g, x2g) 
    x2e = (x1f, x2f) 
    x2d = (x1e, x2e) 
    x2c = (x1d, x2d) 
    x2b = (x1c, x2c) 
    x2a = (x1b, x2b) 
    x2 = (x1a, x2a) 
    y = (x1, x2) 
    x2v = (x1w, x2w) 
    x2u = (x1v, x2v) 
    x2t = (x1u, x2u) 
    x2s = (x1t, x2t) 
    x2r = (x1s, x2s) 
    x2q = (x1r, x2r) 
    x2p = (x1q, x2q) 
    x2o = (x1p, x2p) 
    x2n = (x1o, x2o) 
    x2m = (x1n, x2n) 
    x2l = (x1m, x2m) 
    x = (x1l, x2l) 
    case xa of
    (M', Q')  (K M2. (Decided K # M', M2)  set (get_all_ann_decomposition x1l)  Q' = {#})  M' = x1l  Q' = x1w 
    xa = (x1x, x2x) 
    (K M2. (Decided K # x1x, M2)  set (get_all_ann_decomposition x1)  x2x = {#})  x1x = x1  x2x = x2k
    by (auto 5 3 simp: state_wl_l_def)
  show ?thesis
    unfolding cdcl_twl_local_restart_wl_spec_def cdcl_twl_local_restart_l_spec_def
    apply (intro frefI nres_relI)
    apply (refine_vcg)
    subgoal unfolding restart_abs_wl_pre_def by fast
    apply assumption+
    subgoal
      by (fastforce simp: state_wl_l_def correct_watching.simps clause_to_update_def
          blits_in_ℒin_def
        simp flip: all_lits_alt_def2)
    done
qed

(* Restart program on watch-list states. It currently just runs
   cdcl_twl_local_restart_wl_spec. *)
definition cdcl_twl_restart_wl_prog where
cdcl_twl_restart_wl_prog S = do {
   cdcl_twl_local_restart_wl_spec S
  }

(* cdcl_twl_restart_wl_prog refines cdcl_twl_restart_l_prog under the relation
   state_wl_l None plus correct_watching and blits_in_ℒin. The proof reuses the local-restart
   refinement lemma, turned into a Down-refinement with fref_to_Down. *)
lemma cdcl_twl_restart_wl_prog_cdcl_twl_restart_l_prog:
  (cdcl_twl_restart_wl_prog, cdcl_twl_restart_l_prog)
     {(S, T). (S, T)  state_wl_l None  correct_watching S  blits_in_ℒin S} f
      {(S, T). (S, T)  state_wl_l None  correct_watching S  blits_in_ℒin S}nres_rel
  unfolding cdcl_twl_restart_wl_prog_def cdcl_twl_restart_l_prog_def
  apply (intro frefI nres_relI)
  apply (refine_vcg cdcl_twl_local_restart_wl_spec_cdcl_twl_local_restart_l_spec[THEN fref_to_Down])
  done

(* Postcondition of the watch-list GC restart. There exist S' and T', related to S and T by
   state_wl_l None, such that:
   - cdcl_twl_full_restart_l_GC_prog_pre S' holds;
   - cdcl_twl_restart_l_inp** S' T' holds;
   - T satisfies correct_watching';
   - the literal sets of all_init_lits_of_wl T and all_lits_st T coincide;
   - the unkept learned, subsumed learned and learned clauses0 components of T are empty. *)
definition cdcl_twl_full_restart_wl_GC_prog_post :: 'v twl_st_wl  'v twl_st_wl  bool where
cdcl_twl_full_restart_wl_GC_prog_post S T 
  (S' T'. (S, S')  state_wl_l None  (T, T')  state_wl_l None 
    cdcl_twl_full_restart_l_GC_prog_pre S' 
    cdcl_twl_restart_l_inp** S' T'  correct_watching' T 
    set_mset (all_init_lits_of_wl T) =
    set_mset (all_lits_st T) 
    get_unkept_learned_clss_wl T = {#} 
    get_subsumed_learned_clauses_wl T = {#} 
    get_learned_clauses0_wl T = {#}
)

(* Weaker variant of cdcl_twl_full_restart_wl_GC_prog_post. It keeps the same witnesses
   S' and T', the precondition on S' and the cdcl_twl_restart_l_inp** relation, but only
   requires the literal sets of all_init_lits_of_wl T and all_lits_st T to coincide. It does
   not require correct_watching' or that any learned component is empty. *)
definition cdcl_twl_full_restart_wl_GC_prog_post_confl :: 'v twl_st_wl  'v twl_st_wl  bool where
cdcl_twl_full_restart_wl_GC_prog_post_confl  S T 
  (S' T'. (S, S')  state_wl_l None  (T, T')  state_wl_l None 
    cdcl_twl_full_restart_l_GC_prog_pre S' 
    cdcl_twl_restart_l_inp** S' T' 
    set_mset (all_init_lits_of_wl T) =
    set_mset (all_lits_st T))

(* restart_abs_wl_pre2 S brk differs from restart_abs_wl_pre in two ways:
   - last_GC and last_Restart are quantified inside the definition;
   - the watch-list conditions are correct_watching' and literals_are_ℒin' instead of
     correct_watching and blits_in_ℒin. *)
definition (in -) restart_abs_wl_pre2 :: 'v twl_st_wl  bool  bool where
  restart_abs_wl_pre2 S brk 
    (S' last_GC last_Restart. (S, S')  state_wl_l None  restart_abs_l_pre S' last_GC last_Restart brk
       correct_watching' S  literals_are_ℒin' S)

(* Variant of cdcl_twl_local_restart_wl_spec with three differences:
   - it asserts restart_abs_wl_pre2 with brk = False;
   - both branches of the SPEC additionally require count_decided M' = 0 for the new trail M';
   - the returned state has empty US and U0 components. *)
definition (in -) cdcl_twl_local_restart_wl_spec0 :: 'v twl_st_wl  'v twl_st_wl nres where
  cdcl_twl_local_restart_wl_spec0 = (λ(M, N, D, NE, UE, NEk, UEk, NS, US, N0, U0, Q, W). do {
      ASSERT(restart_abs_wl_pre2 (M, N, D, NE, UE, NEk, UEk, NS, US, N0, U0, Q, W) False);
      (M, Q)  SPEC(λ(M', Q'). (K M2. (Decided K # M', M2)  set (get_all_ann_decomposition M) 
            Q' = {#}  count_decided M' = 0)  (M' = M  Q' = Q  count_decided M' = 0));
      RETURN (M, N, D, NE, UE, NEk, UEk, NS, {#}, N0, {#}, Q, W)
   })

(* Precondition of the watch-list GC restart:
   - S is related by state_wl_l None to some T with cdcl_twl_full_restart_l_GC_prog_pre T;
   - S satisfies correct_watching' and literals_are_ℒin'. *)
definition cdcl_twl_full_restart_wl_GC_prog_pre
  :: 'v twl_st_wl  bool
where
  cdcl_twl_full_restart_wl_GC_prog_pre S 
   (T. (S, T)  state_wl_l None  correct_watching' S  literals_are_ℒin' S  cdcl_twl_full_restart_l_GC_prog_pre T)

(* Transfers literals_are_ℒin' to a state that differs in two places: the trail (first
   component, replaced by an arbitrary ah) and the twelfth component f' (the queue), which
   becomes the empty multiset. The NO_MATCH premise makes the rule usable by the simplifier
   without looping. *)
lemma blits_in_ℒin'_restart_wl_spec0:
  NO_MATCH {#} f' 
  literals_are_ℒin' (a, b, c, d, e, NEk, UEk, NS, US, N0, U0, f', g) 
      literals_are_ℒin' (ah, b, c, d, e, NEk, UEk, NS, US, N0, U0, {#}, g)
  by (auto simp: blits_in_ℒin'_def literals_are_ℒin'_def
         all_init_lits_def all_init_lits_of_wl_def all_learned_lits_of_wl_def)

(* With an empty trail and an empty US component, membership in all_init_lits_of_wl does not
   depend on the queue: a literal found with the twelfth component {#} is also found with any
   Q. Used as a dest rule in cdcl_twl_local_restart_wl_spec0_cdcl_twl_local_restart_l_spec0. *)
lemma all_init_lits_of_wl_keepUSD:
  L ∈# all_init_lits_of_wl ([], x1k, x1l, x1m, x1n, NEk, UEk, x1o, {#}, x1q, x1r, {#}, x2s) 
  L ∈# all_init_lits_of_wl ([], x1k, x1l, x1m, x1n, NEk, UEk, x1o, {#}, x1q, x1r, Q, x2s)
  by (auto simp: all_init_lits_of_wl_def all_lits_of_mm_def)

(* Anonymous twl_st/simp rules, stated outside any locale (in -):
   - learned_clss (stateW_of S) is get_all_learned_clss S;
   - init_clss (stateW_of S) is get_all_init_clss S.
   Both are proved by a case split on S. *)
lemma (in -)[twl_st,simp]: learned_clss (stateW_of S) = get_all_learned_clss S
  by (cases S) auto

lemma (in -)[twl_st,simp]: init_clss (stateW_of S) = get_all_init_clss S
  by (cases S) auto
 
(* Transfer rules that replace one component by the empty multiset:
   - literals_are_ℒin' is transferred when the queue (twelfth component, here x2m), US (x2l),
     U0 or UE (b) is emptied;
   - correct_watching' is transferred when US, the queue, U0 or UE is emptied.
   The NO_MATCH-guarded forms are suitable as simp rules.
   Note on naming: Q here is the last component (the watch lists), not the queue.
   correct_watching'.simps appears twice in the simp set; the duplicate is harmless. *)
lemma literals_are_ℒin'_empty:
  NO_MATCH {#} x2m  literals_are_ℒin' (x1h, x1p, x1j, x1k, b, NEk, UEk, x', x2l, N0, U0, x2m, Q) 
     literals_are_ℒin' (x1h, x1p, x1j, x1k, b, NEk, UEk, x', x2l, N0, U0, {#}, Q)
   NO_MATCH {#} x2l  correct_watching' (x1h, x1i, x1j, x1k, b, NEk, UEk, x', x2l, N0, U0, x2m, Q) 
    correct_watching' (x1h, x1i, x1j, x1k, b, NEk, UEk, x', {#}, N0, U0, x2m, Q)
   NO_MATCH {#} x2m  correct_watching' (x1h, x1i, x1j, x1k, b, NEk, UEk, x', x2l, N0, U0, x2m, Q) 
    correct_watching' (x1h, x1i, x1j, x1k, b, NEk, UEk, x', x2l, N0, U0, {#}, Q)
   NO_MATCH {#} U0  correct_watching' (x1h, x1i, x1j, x1k, b, NEk, UEk, x', x2l, N0, U0, x2m, Q) 
    correct_watching' (x1h, x1i, x1j, x1k, b, NEk, UEk, x', x2l, N0, {#}, x2m, Q)
   NO_MATCH {#} b  correct_watching' (x1h, x1i, x1j, x1k, b, NEk, UEk, x', x2l, N0, U0, x2m, Q) 
    correct_watching' (x1h, x1i, x1j, x1k, {#}, NEk, UEk, x', x2l, N0, U0, x2m, Q)
  literals_are_ℒin' (x1h, x1p, x1j, x1k, b, NEk, UEk, x', x2l, N0, U0, x2m, Q) 
     literals_are_ℒin' (x1h, x1p, x1j, x1k, b, NEk, UEk, x', {#}, N0, U0, x2m, Q)
  literals_are_ℒin' (x1h, x1p, x1j, x1k, b, NEk, UEk, x', x2l, N0, U0, x2m, Q) 
     literals_are_ℒin' (x1h, x1p, x1j, x1k, b, NEk, UEk, x', x2l, N0, {#}, x2m, Q)
  literals_are_ℒin' (x1h, x1p, x1j, x1k, b, NEk, UEk, x', x2l, N0, U0, x2m, Q) 
     literals_are_ℒin' (x1h, x1p, x1j, x1k, {#}, NEk, UEk, x', x2l, N0, U0, x2m, Q)
   by (auto 5 3 simp: literals_are_ℒin'_def blits_in_ℒin'_def all_lits_of_mm_union
     correct_watching'.simps correct_watching'.simps clause_to_update_def all_init_lits_of_wl_def
     all_learned_lits_of_wl_def)

(* Transfers literals_are_ℒin' to a state with two changes:
   - the trail x1h is replaced by x1h', where (K # x1h', M2) is in
     get_all_ann_decomposition x1h;
   - the third component changes from x1j' to x1j.
   correct_watching'.simps appears twice in the simp set; the duplicate is harmless. *)
lemma literals_are_ℒin'_decompD:
  (K # x1h', M2)  set (get_all_ann_decomposition x1h) 
  literals_are_ℒin' (x1h, x1p, x1j', x1k, b, NEk, UEk, x', x2l, N0, U0, x2m, Q) 
     literals_are_ℒin' (x1h', x1p, x1j, x1k, b, NEk, UEk, x', x2l, N0, U0, x2m, Q)
   by (auto 5 3 simp: literals_are_ℒin'_def blits_in_ℒin'_def all_lits_of_mm_union
     correct_watching'.simps correct_watching'.simps clause_to_update_def all_init_lits_of_wl_def
     all_learned_lits_of_wl_def
     dest!: get_all_ann_decomposition_exists_prepend)


(* Component independence:
   - all_init_lits_of_wl does not depend on the queue Q or on U0;
   - all_learned_lits_of_wl does not depend on Q.
   Each rule rewrites the component to {#}. The NO_MATCH premises let them be used as simp
   rules without looping. *)
lemma all_init_learned_lits_simps_Q:
  NO_MATCH {#} Q  all_init_lits_of_wl (M, N, D, NE, UE, NEk, UEk, NS, US, N0, U0, Q, W) =
    all_init_lits_of_wl (M, N, D, NE, UE, NEk, UEk, NS, US, N0, U0, {#}, W)
  NO_MATCH {#} U0  all_init_lits_of_wl (M, N, D, NE, UE, NEk, UEk, NS, US, N0, U0, Q, W) =
    all_init_lits_of_wl (M, N, D, NE, UE, NEk, UEk, NS, US, N0, {#}, Q, W)
  NO_MATCH {#} Q  all_learned_lits_of_wl (M, N, D, NE, UE, NEk, UEk, NS, US, N0, U0, Q, W) =
    all_learned_lits_of_wl (M, N, D, NE, UE, NEk, UEk, NS, US, N0, U0, {#}, W)
  by (auto simp: all_init_lits_of_wl_def all_learned_lits_of_wl_def all_lits_of_mm_def)

(* Membership in all_learned_lits_of_wl is monotone in the US (ninth), U0 (eleventh) and
   UE (fifth) components: a literal found with one of these components set to {#} is also
   found with the actual component. Used as a dest rule. *)
lemma in_all_learned_lits_of_wl_addUS:
  x  set_mset (all_learned_lits_of_wl (M, x1k, x1l, x1m, x1n, NEk, UEk, x1o,  {#}, x1q, x1r, x1s, x2s)) 
  x  set_mset (all_learned_lits_of_wl (M, x1k, x1l, x1m, x1n, NEk, UEk, x1o, x1p, x1q, x1r, x1s, x2s))
  x  set_mset (all_learned_lits_of_wl (M, x1k, x1l, x1m, x1n, NEk, UEk, x1o,  x1p, x1q, {#}, x1s, x2s)) 
  x  set_mset (all_learned_lits_of_wl (M, x1k, x1l, x1m, x1n, NEk, UEk, x1o, x1p, x1q, x1r, x1s, x2s))
  x  set_mset (all_learned_lits_of_wl (M, x1k, x1l, x1m, {#}, NEk, UEk, x1o, x1p, x1q, x1r, x1s, x2s)) 
  x  set_mset (all_learned_lits_of_wl (M, x1k, x1l, x1m, x1n, NEk, UEk, x1o, x1p, x1q, x1r, x1s, x2s))
  by (auto simp: all_learned_lits_of_wl_def all_lits_of_mm_union)

(* cdcl_twl_local_restart_wl_spec0 refines cdcl_twl_local_restart_l_spec0. Input and output use
   the same relation: state_wl_l None together with correct_watching' and literals_are_ℒin'.
   Subgoals:
   - the first discharges the assertion restart_abs_wl_pre2, with y as the witness;
   - the remaining two cover the SPEC and the returned state, using the component-independence
     lemmas above. *)
lemma cdcl_twl_local_restart_wl_spec0_cdcl_twl_local_restart_l_spec0:
  (x, y)  {(S, S'). (S, S')  state_wl_l None  correct_watching' S  literals_are_ℒin' S} 
          cdcl_twl_local_restart_wl_spec0 x
            {(S, S'). (S, S')  state_wl_l None  correct_watching' S  literals_are_ℒin' S}
	    (cdcl_twl_local_restart_l_spec0 y)
  unfolding cdcl_twl_local_restart_wl_spec0_def cdcl_twl_local_restart_l_spec0_def curry_def
  apply refine_vcg
  subgoal unfolding restart_abs_wl_pre2_def by (rule exI[of _ y]) fast
  subgoal
    by (auto simp add: literals_are_ℒin'_empty
        state_wl_l_def image_iff correct_watching'.simps clause_to_update_def
      conc_fun_RES RES_RETURN_RES2 blits_in_ℒin'_restart_wl_spec0
      intro: literals_are_ℒin'_decompD literals_are_ℒin'_empty(4))
  subgoal
    by (auto 4 3 simp add: literals_are_ℒin'_empty
        state_wl_l_def image_iff correct_watching'.simps clause_to_update_def
      conc_fun_RES RES_RETURN_RES2 blits_in_ℒin'_restart_wl_spec0 
      literals_are_ℒin'_def all_init_learned_lits_simps_Q blits_in_ℒin'_def
      dest: all_init_lits_of_wl_keepUSD
      in_all_learned_lits_of_wl_addUS)
  done

(* Setting: y satisfies cdcl_twl_full_restart_l_GC_prog_pre, there is a chain
   cdcl_twl_restart_l_inp** y Va, and V is related to Va with correct_watching' and
   literals_are_ℒin'. Then:
   - V also satisfies correct_watching and blits_in_ℒin (same relation to Va);
   - the literal sets of all_init_lits_of_wl V and all_lits_st V coincide.
   Proof outline:
   - the abstract chain cdcl_twl_inp** x V' and rtranclp_cdcl_twl_inp_twl_struct_invs give
     twl_struct_invs V';
   - literals_are_ℒin'_literals_are_ℒin_iff(4) then gives the literal-set equality (eq);
   - eq turns correct_watching' into correct_watching and literals_are_ℒin' into
     blits_in_ℒin. *)
lemma cdcl_twl_full_restart_wl_GC_prog_post_correct_watching:
  assumes
    pre: cdcl_twl_full_restart_l_GC_prog_pre y and
    y_Va: cdcl_twl_restart_l_inp** y Va and
    (V, Va)  {(S, S'). (S, S')  state_wl_l None  correct_watching' S  literals_are_ℒin' S}
  shows (V, Va)  {(S, S'). (S, S')  state_wl_l None  correct_watching S  blits_in_ℒin S} and
    set_mset (all_init_lits_of_wl V) = set_mset (all_lits_st V)
proof -
  obtain x where
    y_x: (y, x)  twl_st_l None and
    struct_invs: twl_struct_invs x and
    list_invs: twl_list_invs y and
    ent: cdclW_restart_mset.cdclW_learned_clauses_entailed_by_init (stateW_of x)
    using pre unfolding cdcl_twl_full_restart_l_GC_prog_pre_def by blast
  obtain V' where cdcl_twl_inp** x V' and Va_V': (Va, V')  twl_st_l None
    using rtranclp_cdcl_twl_restart_l_inp_cdcl_twl_restart_inp[OF y_Va y_x list_invs struct_invs ent]
    by blast
  then have twl_struct_invs V'
    using struct_invs ent rtranclp_cdcl_twl_inp_twl_struct_invs by blast
  then show eq: set_mset (all_init_lits_of_wl V) = set_mset (all_lits_st V)
    using assms(3) Va_V'  twl_struct_invs V' literals_are_ℒin'_literals_are_ℒin_iff(4) by blast
  then have correct_watching' V   correct_watching V
    by (cases V) (auto simp: correct_watching.simps correct_watching'.simps)
  moreover
    have literals_are_ℒin' V  blits_in_ℒin V
    using eq by (cases V)
      (clarsimp simp: blits_in_ℒin_def blits_in_ℒin'_def all_lits_def literals_are_ℒin'_def
         all_init_lits_def ac_simps)
  ultimately show (V, Va)  {(S, S'). (S, S')  state_wl_l None  correct_watching S  blits_in_ℒin S}
    using assms by (auto simp: cdcl_twl_full_restart_wl_GC_prog_post_def)
qed

(*TODO move within this file, seems to be Watched_Literals_Watch_List_Restart.ℒall_all_init_atms(1)*)
(* Same statement as all_all_init_atms(1), kept under its own name. The proof just unfolds
   all_all_init_atms. *)
lemma all_all_init_atms_all_init_lits:
  set_mset (all (all_init_atms N NE)) = set_mset (all_init_lits N NE)
  unfolding all_all_init_atms ..
end