(* Theory DPLL_W_Optimal_Model *)

theory DPLL_W_Optimal_Model
imports
  DPLL_W_BnB
begin

(* Locale for DPLL_W states that carry extra information next to trail and clauses.
   It extends dpllW_state (instantiated with trail, clauses, tl_trail, cons_trail,
   state_eq and state) and ocdcl_weight (for the weight function ρ into a linorder).
   The state tuple has four components: the annotated trail, the clauses, a
   'v clause option, and a further component of type 'b.
   One new operation is fixed, update_additional_info, together with one assumption
   of the same name: it relates state S = (M, N, K) to
   state (update_additional_info K' S) = (M, N, K'), i.e. only the part after the
   trail and the clauses is replaced.
   NOTE(review): the function arrows, the infix symbol of state_eq and the connective
   inside the assumption are not visible in this copy of the file -- restore them
   from the original theory before processing it with Isabelle. *)
locale dpllW_state_optimal_weight =
  dpllW_state trail clauses
    tl_trail cons_trail state_eq state +
  ocdcl_weight ρ
  for
    trail :: 'st  'v  dpllW_ann_lits and
    clauses :: 'st  'v clauses and
    tl_trail :: 'st  'st and
    cons_trail :: 'v  dpllW_ann_lit  'st  'st and
    state_eq  :: 'st  'st  bool (infix  50) and
    state :: 'st  'v  dpllW_ann_lits × 'v clauses × 'v clause option × 'b and
    ρ :: 'v clause  'a :: {linorder} +
  fixes
    update_additional_info :: 'v clause option × 'b  'st  'st
  assumes
    update_additional_info:
      state S = (M, N, K)  state (update_additional_info K' S) = (M, N, K')
begin

(* update_weight_information M S replaces the first component of the additional
   info of S by Some (lit_of `# mset M), the multiset of the literals of the
   annotated list M, and keeps the second component, snd (additional_info S). *)
definition update_weight_information :: ('v literal, 'v literal, unit) annotated_lits  'st  'st where
  update_weight_information M S =
    update_additional_info (Some (lit_of `# mset M), snd (additional_info S)) S

(* Simp rules for the two update operations:
   - update_weight_information leaves trail and clauses unchanged;
   - update_additional_info leaves clauses unchanged;
   - additional_info reads back the pair written by update_additional_info.
   All four are derived from the locale assumption update_additional_info after
   unfolding update_weight_information_def. *)
lemma [simp]:
  trail (update_weight_information M' S) = trail S
  clauses (update_weight_information M' S) = clauses S
  clauses (update_additional_info c S) = clauses S
  additional_info (update_additional_info (w, oth) S) = (w, oth)
  using update_additional_info[of S] unfolding update_weight_information_def
  by (auto simp: state)

(* Shape of the state after update_weight_information: starting from
   state S = (M, N, w, oth), the updated state is (M, N, w', oth) for some w',
   so trail, clauses and the last component oth are preserved.
   NOTE(review): the connective after the premise and the binder in front of w'
   are not visible in this copy; presumably an implication with an existential
   quantifier -- confirm against the original theory. *)
lemma state_update_weight_information: state S = (M, N, w, oth) 
       w'. state (update_weight_information M' S) = (M, N, w', oth)
  apply (auto simp: state)
  apply (auto simp: update_weight_information_def)
  done

(* The weight of a state is the first component of its additional info. *)
definition weight where
  weight S = fst (additional_info S)

(* Simp rule: the weight read right after update_weight_information M' is
   Some (lit_of `# mset M'), i.e. the literals of M' that were just stored. *)
lemma [simp]: (weight (update_weight_information M' S)) = Some (lit_of `# mset M')
  unfolding weight_def by (auto simp: update_weight_information_def)

text ‹

  We test here a slightly different design decision. In the CDCL version, we renamed
  \<^term>‹additional_info› from the BNB version to avoid name collisions. Here, instead of
  renaming, we add the prefix \<^text>‹bnb.› to every name.

›

(* Interpret the bnb_ops locale with the operations of this locale.  Everything
   exported by bnb_ops becomes available under the qualifier bnb.
   No proof obligations beyond what unfold_locales solves on its own. *)
sublocale bnb: bnb_ops where
  trail = trail and
  clauses = clauses and
  tl_trail = tl_trail and
  cons_trail = cons_trail and
  state_eq = state_eq and
  state = state and
  weight = weight and
  conflicting_clauses = conflicting_clauses and
  is_improving_int = is_improving_int and
  update_weight_information = update_weight_information
  by unfold_locales


(* Relates the atoms of bnb.conflicting_clss S to the atoms of clauses S.
   Obtained from conflicting_clss_incl_init_clauses, instantiated with clauses S
   and weight S, after unfolding bnb.conflicting_clss_def.
   NOTE(review): the relation symbol between the two atom sets is not visible in
   this copy; the lemma name suggests set inclusion -- confirm. *)
lemma atms_of_mm_conflicting_clss_incl_init_clauses:
  atms_of_mm (bnb.conflicting_clss S)  atms_of_mm (clauses S)
  using conflicting_clss_incl_init_clauses[of clauses S weight S]
  unfolding bnb.conflicting_clss_def
  by auto

(* Under bnb.is_improving M M' S, the conflicting clauses of S form a sub-multiset
   of the conflicting clauses of update_weight_information M' S.
   Obtained from the already available lemma of the same name, instantiated with
   clauses S and weight S.
   NOTE(review): the connective after the premise is not visible in this copy;
   presumably an implication -- confirm. *)
lemma is_improving_conflicting_clss_update_weight_information: bnb.is_improving M M' S 
       bnb.conflicting_clss S ⊆# bnb.conflicting_clss (update_weight_information M' S)
  using is_improving_conflicting_clss_update_weight_information[of M M' clauses S weight S]
  unfolding bnb.conflicting_clss_def
  by (auto simp: update_weight_information_def weight_def)

(* After an improving step (bnb.is_improving M M' S), the clause negate_ann_lits M'
   is a member of the conflicting clauses of update_weight_information M' S.
   Proof: instantiate the already available lemma of the same name with clauses S
   and weight S, unfold bnb.conflicting_clss_def once, and finish with auto. *)
lemma conflicting_clss_update_weight_information_in2:
  assumes bnb.is_improving M M' S
  shows negate_ann_lits M' ∈# bnb.conflicting_clss (update_weight_information M' S)
  using conflicting_clss_update_weight_information_in2[of M M' clauses S weight S] assms
  unfolding bnb.conflicting_clss_def
  by (auto simp: update_weight_information_def weight_def)


(* Decomposes state S into its four components: trail S, clauses S, weight S and
   bnb.additional_info S.  Proved by case analysis on state S, unfolding weight_def
   and bnb.additional_info_def. *)
lemma state_additional_info':
  state S = (trail S, clauses S, weight S, bnb.additional_info S)
  unfolding additional_info_def by (cases state S; auto simp: state weight_def bnb.additional_info_def)

(* Establish the full bnb locale for the same instantiation as the bnb_ops
   interpretation.  unfold_locales yields eleven subgoals, discharged in order:
   the first by auto, then facts available in this context (state_eq_sym,
   state_eq_trans, state_eq_state, cons_trail, tl_trail), then the lemmas proved
   above in this locale. *)
sublocale bnb: bnb where
  trail = trail and
  clauses = clauses and
  tl_trail = tl_trail and
  cons_trail = cons_trail and
  state_eq = state_eq and
  state = state and
  weight = weight and
  conflicting_clauses = conflicting_clauses and
  is_improving_int = is_improving_int and
  update_weight_information = update_weight_information
  apply unfold_locales
  subgoal by auto
  subgoal by (rule state_eq_sym)
  subgoal by (rule state_eq_trans)
  subgoal by (auto dest!: state_eq_state)
  subgoal by (rule cons_trail)
  subgoal by (rule tl_trail)
  subgoal by (rule state_update_weight_information)
  subgoal by (rule is_improving_conflicting_clss_update_weight_information)
  subgoal by (rule conflicting_clss_update_weight_information_in2; assumption)
  subgoal by (rule atms_of_mm_conflicting_clss_incl_init_clauses)
  subgoal by (rule state_additional_info')
  done

(* A bound-improving step keeps every strictly better model.
   Given a step bnb.dpllW_bound S T and an interpretation I that satisfies
   clauses S and bnb.conflicting_clss S, is distinct, consistent, and covers exactly
   the atoms of clauses S: if Found (ρ I) is strictly below ρ' (weight T), then I
   also satisfies clauses T and bnb.conflicting_clss T.
   The proof does a case analysis on the step; the only case is update_info, where
   entails_conflicting_clauses_if_le provides the entailment for the updated weight.
   NOTE(review): the connectives between the two entailments in `ent` and in the
   conclusion are not visible in this copy; presumably conjunctions -- confirm. *)
lemma improve_model_still_model:
  assumes
    bnb.dpllW_bound S T and
    all_struct: dpllW_all_inv (bnb.abs_state S) and
    ent: set_mset I ⊨sm clauses S  set_mset I ⊨sm bnb.conflicting_clss S and
    dist: distinct_mset I and
    cons: consistent_interp (set_mset I) and
    tot: atms_of I = atms_of_mm (clauses S) and
    le: Found (ρ I) < ρ' (weight T)
  shows
    set_mset I ⊨sm clauses T  set_mset I ⊨sm bnb.conflicting_clss T
  using assms(1)
proof (cases rule: bnb.dpllW_bound.cases)
  case (update_info M M') note imp = this(1) and T = this(2)
  (* Facts about the current trail, taken from the structural invariant. *)
  have atm_trail: atms_of (lit_of `# mset (trail S))  atms_of_mm (clauses S) and
       dist2: distinct_mset (lit_of `# mset (trail S)) and
      taut2: ¬ tautology (lit_of `# mset (trail S))
    using all_struct unfolding dpllW_all_inv_def by (auto simp: lits_of_def atms_of_def
      dest: no_dup_distinct no_dup_not_tautology)

  have tot2: total_over_m (set_mset I) (set_mset (clauses S))
    using tot[symmetric]
    by (auto simp: total_over_m_def total_over_set_def atm_iff_pos_or_neg_lit)
  (* The same three facts for the new model M'; these re-bind the names
     atm_trail, dist2 and taut2 introduced above. *)
  have atm_trail: atms_of (lit_of `# mset M')  atms_of_mm (clauses S) and
    dist2: distinct_mset (lit_of `# mset M') and
    taut2: ¬ tautology (lit_of `# mset M')
    using imp by (auto simp: lits_of_def atms_of_def is_improving_int_def
      simple_clss_def)

  (* I entails the conflicting clauses computed for the updated weight. *)
  have
    set_mset I ⊨m conflicting_clauses (clauses S) (weight (update_weight_information M' S))
    using entails_conflicting_clauses_if_le[of I clauses S M' M weight S]
    using T dist cons tot le imp by auto
  then have set_mset I ⊨m bnb.conflicting_clss (update_weight_information M' S)
    by (auto simp: update_weight_information_def bnb.conflicting_clss_def)
  then show ?thesis
    using ent T by (auto simp: bnb.conflicting_clss_def state)
qed

(* One step of bnb.dpllW_bnb, seen from a fixed interpretation I that satisfies
   clauses S and bnb.conflicting_clss S, is distinct, consistent, and covers exactly
   the atoms of clauses S.  The conclusion combines "I satisfies clauses T and
   bnb.conflicting_clss T" with a comparison between Found (ρ I) and ρ' (weight T).
   Proof by induction over the step rule: the dpll case is solved by auto after
   eliminating bnb.dpllW_coreE; the bnb case uses improve_model_still_model.
   NOTE(review): the connectives and the order symbol in `ent` and in the conclusion
   are not visible in this copy -- confirm against the original theory. *)
lemma cdcl_bnb_still_model:
  assumes
    bnb.dpllW_bnb S T and
    all_struct: dpllW_all_inv (bnb.abs_state S) and
    ent: set_mset I ⊨sm clauses S set_mset I ⊨sm bnb.conflicting_clss S and
    dist: distinct_mset I and
    cons: consistent_interp (set_mset I) and
    tot: atms_of I = atms_of_mm (clauses S)
  shows
    (set_mset I ⊨sm clauses T  set_mset I ⊨sm bnb.conflicting_clss T)  Found (ρ I)  ρ' (weight T)
  using assms
proof (induction rule: bnb.dpllW_bnb.induct)
  case (dpll S T)
  then show ?case using ent by (auto elim!: bnb.dpllW_coreE simp: bnb.state'_def
       dpll_decide.simps dpll_backtrack.simps bnb.backtrack_opt.simps
       dpll_propagate.simps)
next
  case (bnb S T)
  then show ?case
    using improve_model_still_model[of S T I] using assms(2-) by auto
qed

(* Compares ρ' (weight S) with ρ' (weight T) across a single bnb.dpllW_bnb step.
   Proved by case analysis on the step, using bnb.dpllW_core_same_weight for core
   steps and the definition of is_improving_int for bound steps.
   NOTE(review): the order symbol between the two sides is not visible in this
   copy -- confirm its direction against the original theory. *)
lemma cdcl_bnb_larger_still_larger:
  assumes
    bnb.dpllW_bnb S T
  shows ρ' (weight S)  ρ' (weight T)
  using assms apply (cases rule: bnb.dpllW_bnb.cases)
  by (auto simp: bnb.dpllW_bound.simps is_improving_int_def bnb.dpllW_core_same_weight)

(* Reflexive-transitive version of cdcl_bnb_still_model: the property relating
   "I satisfies the clauses and the conflicting clauses" and the comparison of
   Found (ρ I) with ρ' (weight _) is carried from S to T along bnb.dpllW_bnb** S T.
   Proof by rtranclp_induct.  In the step case the invariant and the atom equality
   are first transported to the intermediate state (facts 1 and 3), then the
   single-step lemmas are combined with order.trans.
   NOTE(review): the connectives and the order symbol in `ent` and in the conclusion
   are not visible in this copy -- confirm against the original theory. *)
lemma rtranclp_cdcl_bnb_still_model:
  assumes
    st: bnb.dpllW_bnb** S T and
    all_struct: dpllW_all_inv (bnb.abs_state S) and
    ent: (set_mset I ⊨sm clauses S  set_mset I ⊨sm bnb.conflicting_clss S)  Found (ρ I)  ρ' (weight S) and
    dist: distinct_mset I and
    cons: consistent_interp (set_mset I) and
    tot: atms_of I = atms_of_mm (clauses S)
  shows
    (set_mset I ⊨sm clauses T  set_mset I ⊨sm bnb.conflicting_clss T)  Found (ρ I)  ρ' (weight T)
  using st
proof (induction rule: rtranclp_induct)
  case base
  then show ?case
    using ent by auto
next
  case (step T U) note star = this(1) and st = this(2) and IH = this(3)
  (* The structural invariant holds at the intermediate state T. *)
  have 1: dpllW_all_inv (bnb.abs_state T)
    using bnb.rtranclp_dpllW_bnb_abs_state_all_inv[OF star all_struct] .
  (* I still covers exactly the atoms of the clauses of T. *)
  have 3: atms_of I = atms_of_mm (clauses T)
    using bnb.rtranclp_dpllW_bnb_clauses[OF star] tot by auto
  show ?case
    using cdcl_bnb_still_model[OF st 1 _ _ dist cons 3] IH
      cdcl_bnb_larger_still_larger[OF st]
      order.trans by blast
qed

lemma simple_clss_entailed_by_too_heavy_in_conflicting:
   C ∈# mset_set (simple_clss (atms_of_mm (clauses S))) 
    too_heavy_clauses (clauses S) (weight S) ⊨pm
     (C)  C ∈# bnb.conflicting_clss S
  by (auto simp: conflicting_clauses_def bnb.conflicting_clss_def)

(* If the trail of S satisfies all clauses, is total over them, falsifies no
   conflicting clause, and the structural invariant holds, then some
   bnb.dpllW_bound step is possible from S (Ex (bnb.dpllW_bound S)).
   Outline: the literals of the trail form a simple clause (H); the weight of the
   trail is strictly below the current bound (le); any total extension of the
   trail that is a simple clause has the same weight (M'); hence the trail improves
   on itself and the bound rule applies.
   NOTE(review): the quantifier in `n_s` and some relation symbols are not visible
   in this copy -- confirm against the original theory. *)
lemma can_always_improve:
  assumes
    ent: trail S ⊨asm clauses S and
    total: total_over_m (lits_of_l (trail S)) (set_mset (clauses S)) and
    n_s: (C ∈# bnb.conflicting_clss S. ¬ trail S ⊨as CNot C) and
    all_struct: dpllW_all_inv (bnb.abs_state S)
   shows Ex (bnb.dpllW_bound S)
proof -
  (* Three facts derived from the structural invariant: two about the literals of
     the trail and simple_clss, and the absence of duplicates in the trail. *)
  have H: (lit_of `# mset (trail S)) ∈# mset_set (simple_clss (atms_of_mm (clauses S)))
    (lit_of `# mset (trail S))  simple_clss (atms_of_mm (clauses S))
    no_dup (trail S)
    apply (subst finite_set_mset_mset_set[OF simple_clss_finite])
    using all_struct by (auto simp: simple_clss_def
        dpllW_all_inv_def atms_of_def lits_of_def image_image clauses_def
      dest: no_dup_not_tautology no_dup_distinct)
  (* The trail falsifies the clause made of the negations of its own literals. *)
  moreover have trail S ⊨as CNot (pNeg (lit_of `# mset (trail S)))
    by (auto simp: pNeg_def true_annots_true_cls_def_iff_negation_in_model lits_of_def)

  (* Hence the weight of the trail is strictly below the current bound; the case
     split is on whether the too-heavy clauses entail the negated trail. *)
  ultimately have le: Found (ρ (lit_of `# mset (trail S))) < ρ' (weight S)
    using n_s total not_entailed_too_heavy_clauses_ge[of lit_of `# mset (trail S) clauses S weight S]
     simple_clss_entailed_by_too_heavy_in_conflicting[of pNeg (lit_of `# mset (trail S)) S]
    by (cases ¬ too_heavy_clauses (clauses S) (weight S) ⊨pm
       pNeg (lit_of `# mset (trail S)))
     (auto simp:  lits_of_def
         conflicting_clauses_def clauses_def negate_ann_lits_pNeg_lit_of image_iff
         simple_clss_finite subset_iff
       dest: bspec[of _ _ (lit_of `# mset (trail S))] dest: total_over_m_atms_incl
          true_clss_cls_in too_heavy_clauses_contains_itself
          dest!: multi_member_split)
  (* tr and tot' restate the assumptions ent and total under local names. *)
  have tr: trail S ⊨asm clauses S
    using ent by (auto simp: clauses_def)
  have tot': total_over_m (lits_of_l (trail S)) (set_mset (clauses S))
    using total all_struct by (auto simp: total_over_m_def total_over_set_def)
  (* Any M' that is total over the clauses, contains the trail as a multiset and is
     a simple clause has the same ρ-value as the trail: its multiset is shown to be
     equal to that of the trail. *)
  have M': ρ (lit_of `# mset M') = ρ (lit_of `# mset (trail S))
    if total_over_m (lits_of_l M') (set_mset (clauses S)) and
      incl: mset (trail S) ⊆# mset M' and
      lit_of `# mset M'  simple_clss (atms_of_mm (clauses S))
      for M'
    proof -
      have [simp]: lits_of_l M' = set_mset (lit_of `# mset M')
        by (auto simp: lits_of_def)
      (* Split M' into the trail plus a remainder A. *)
      obtain A where A: mset M' = A + mset (trail S)
        using incl by (auto simp: mset_subset_eq_exists_conv)
      have M': lits_of_l M' = lit_of ` set_mset A  lits_of_l (trail S)
        unfolding lits_of_def
        by (metis A image_Un set_mset_mset set_mset_union)
      have mset M' = mset (trail S)
        using that tot' total unfolding A total_over_m_alt_def
          apply (case_tac A)
        apply (auto simp: A simple_clss_def distinct_mset_add M' image_Un
            tautology_union mset_inter_empty_set_mset atms_of_def atms_of_s_def
            atm_of_in_atm_of_set_iff_in_set_or_uminus_in_set image_image
            tautology_add_mset)
          by (metis (no_types, lifting) atm_of_in_atm_of_set_iff_in_set_or_uminus_in_set
          lits_of_def subsetCE)
      then show ?thesis
        using total by auto
    qed
  (* With the strict weight bound, the trail is an improving model of itself. *)
  have bnb.is_improving (trail S) (trail S) S
    if Found (ρ (lit_of `# mset (trail S))) < ρ' (weight S)
    using that total H tr tot' M' unfolding is_improving_int_def lits_of_def
    by fast
  (* Apply the introduction rule of bnb.dpllW_bound with the updated state. *)
  then show ?thesis
    using bnb.dpllW_bound.intros[of trail S _ S update_weight_information (trail S) S] total H le
    by fast
qed


(* Properties of a state S from which bnb.dpllW_bnb cannot take a step, under the
   structural invariant.  Three conclusions are stated: one about a clause C of
   clauses S + bnb.conflicting_clss S being falsified by the trail (?A), that the
   trail contains no decision (count_decided (trail S) = 0), and that
   clauses S + bnb.conflicting_clss S is unsatisfiable.
   Each conclusion is discharged the same way: apply
   bnb.no_step_dpllW_bnb_conflict, whose remaining premise is provided by
   can_always_improve.
   NOTE(review): the binder in front of C in the first conclusion is not visible
   in this copy -- confirm against the original theory. *)
lemma no_step_dpllW_bnb_conflict:
  assumes
    ns: no_step bnb.dpllW_bnb S and
    invs: dpllW_all_inv (bnb.abs_state S)
  shows C ∈# clauses S + bnb.conflicting_clss S. trail S ⊨as CNot C (is ?A) and
      count_decided (trail S) = 0 and
     unsatisfiable (set_mset (clauses S + bnb.conflicting_clss S))
  apply (rule bnb.no_step_dpllW_bnb_conflict[OF _ assms])
  subgoal using can_always_improve by blast
  apply (rule bnb.no_step_dpllW_bnb_conflict[OF _ assms])
  subgoal using can_always_improve by blast
  apply (rule bnb.no_step_dpllW_bnb_conflict[OF _ assms])
  subgoal using can_always_improve by blast
  done

(* Result for a full run (full bnb.dpllW_bnb S T) seen from a fixed interpretation I
   that is distinct, consistent, covers exactly the atoms of clauses S and satisfies
   the hypothesis `ent` at S.  Two conclusions: a comparison between Found (ρ I) and
   ρ' (weight T), and unsatisfiability of clauses T + bnb.conflicting_clss T.
   Outline: split the full run into "no further step from T" and the
   reflexive-transitive run; transport the invariant to T; get unsatisfiability
   from no_step_dpllW_bnb_conflict; then derive a contradiction from
   Found (ρ I) < ρ' (weight T) using rtranclp_cdcl_bnb_still_model.
   NOTE(review): the connectives and order symbols in `ent` and in the first
   conclusion are not visible in this copy -- confirm against the original theory. *)
lemma full_cdcl_bnb_stgy_larger_or_equal_weight:
  assumes
    st: full bnb.dpllW_bnb S T and
    all_struct: dpllW_all_inv (bnb.abs_state S) and
    ent: (set_mset I ⊨sm clauses S  set_mset I ⊨sm bnb.conflicting_clss S)  Found (ρ I)  ρ' (weight S) and
    dist: distinct_mset I and
    cons: consistent_interp (set_mset I) and
    tot: atms_of I = atms_of_mm (clauses S)
  shows
    Found (ρ I)  ρ' (weight T) and
    unsatisfiable (set_mset (clauses T + bnb.conflicting_clss T))
proof -
  (* From here on, st names the reflexive-transitive run and shadows the
     assumption about the full run. *)
  have ns: no_step bnb.dpllW_bnb T and
    st: bnb.dpllW_bnb** S T
    using st unfolding full_def by auto
  have struct_T: dpllW_all_inv (bnb.abs_state T)
    using bnb.rtranclp_dpllW_bnb_abs_state_all_inv[OF st all_struct] .

  have atms_eq: atms_of I  atms_of_mm (bnb.conflicting_clss T) = atms_of_mm (clauses T)
    using atms_of_mm_conflicting_clss_incl_init_clauses[of T]
      bnb.rtranclp_dpllW_bnb_clauses[OF st] tot
    by auto

  show unsatisfiable (set_mset (clauses T + bnb.conflicting_clss T))
    using no_step_dpllW_bnb_conflict[of T] ns struct_T
    by fast
  (* In particular I itself cannot satisfy the final clause set. *)
  then have ¬set_mset I ⊨sm clauses T + bnb.conflicting_clss T
    using dist cons by auto
  then have False if Found (ρ I) < ρ' (weight T)
    using ent that rtranclp_cdcl_bnb_still_model[OF st assms(2-)]
      bnb.rtranclp_dpllW_bnb_clauses[OF st]
    apply simp
    using leD by blast

  then show Found (ρ I)  ρ' (weight T)
    by force
qed


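(*TODO:
full_cdcl_bnb_stgy_larger_or_equal_weight
  (NOTE(review): a lemma of this name is already proved above -- confirm whether this entry is stale)
full_cdcl_bnb_stgy_no_conflicting_clause_from_init_state
*)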

end

end