(* Theory CDCL_W_MaxSAT *)

theory CDCL_W_MaxSAT
  imports CDCL_W_Optimal_Model
begin


subsection ‹Partial MAX-SAT›

(* Total weight ρ of the soft clauses of NS that are true in the interpretation I.
   Clauses are counted with their multiplicity in NS (sum over a filtered multiset). *)
definition weight_on_clauses where
  ‹weight_on_clauses NS ρ I = (∑C ∈# (filter_mset (λC. I ⊨ C) NS). ρ C)›

(* I is total over the clauses of N and mentions no atom outside of N; for multiset
   interpretations this is equality of the atom sets (see atms_exactly_m_alt_def2). *)
definition atms_exactly_m :: ‹'v partial_interp ⇒ 'v clauses ⇒ bool› where
  ‹atms_exactly_m I N ⟷
  total_over_m I (set_mset N) ∧
  atms_of_s I ⊆ atms_of_mm N›

text ‹‹Partial› in the name refers to the fact that not all clauses are soft clauses (the hard
  clauses must always be satisfied), not to the fact that we consider partial models.›
(* Specification of Partial MaxSAT with hard clauses NH, soft clauses NS and weights ρ:
   - Some I: I is a consistent model of NH whose atoms are exactly those of NH + NS, and no
     other such model satisfies soft clauses of strictly larger total weight;
   - None: the hard clauses NH are unsatisfiable. *)
inductive partial_max_sat :: ‹'v clauses ⇒ 'v clauses ⇒ ('v clause ⇒ nat) ⇒
  'v partial_interp option ⇒ bool› where
  partial_max_sat:
  ‹partial_max_sat NH NS ρ (Some I)›
if
  ‹I ⊨sm NH› and
  ‹atms_exactly_m I ((NH + NS))› and
  ‹consistent_interp I› and
  ‹⋀I'. consistent_interp I' ⟹ atms_exactly_m I' (NH + NS) ⟹ I' ⊨sm NH ⟹
      weight_on_clauses NS ρ I' ≤ weight_on_clauses NS ρ I› |
  partial_max_unsat:
  ‹partial_max_sat NH NS ρ None›
if
  ‹unsatisfiable (set_mset NH)›

(* Dual of partial_max_sat:
   - Some I: I is a consistent model of NH whose atoms are exactly those of NH + NS and
     that minimises the total weight of satisfied soft clauses among all such models;
   - None: the hard clauses NH are unsatisfiable. *)
inductive partial_min_sat :: ‹'v clauses ⇒ 'v clauses ⇒ ('v clause ⇒ nat) ⇒
  'v partial_interp option ⇒ bool› where
  partial_min_sat:
  ‹partial_min_sat NH NS ρ (Some I)›
if
  ‹I ⊨sm NH› and
  ‹atms_exactly_m I (NH + NS)› and
  ‹consistent_interp I› and
  ‹⋀I'. consistent_interp I' ⟹ atms_exactly_m I' (NH + NS) ⟹ I' ⊨sm NH ⟹
      weight_on_clauses NS ρ I' ≥ weight_on_clauses NS ρ I› |
  partial_min_unsat:
  ‹partial_min_sat NH NS ρ None›
if
  ‹unsatisfiable (set_mset NH)›

(* An interpretation whose atoms are exactly those of N is finite: it only contains
   Pos/Neg literals over the atoms of N. *)
lemma atms_exactly_m_finite:
  assumes ‹atms_exactly_m I N›
  shows ‹finite I›
proof -
  have ‹I ⊆ Pos ` (atms_of_mm N) ∪ Neg ` atms_of_mm N›
    using assms by (force simp: total_over_m_def  atms_exactly_m_def lit_in_set_iff_atm
        atms_of_s_def)
  from finite_subset[OF this] show ?thesis by auto
qed


(* When the hard clauses NH are satisfiable, both optimisation problems have a solution:
   the candidate models form a finite, non-empty set, so the minimal and the maximal
   weight of satisfied soft clauses are both attained. *)
lemma
  fixes NH :: ‹'v clauses›
  assumes ‹satisfiable (set_mset NH)›
  shows sat_partial_max_sat: ‹∃I. partial_max_sat NH NS ρ (Some I)› and
    sat_partial_min_sat: ‹∃I. partial_min_sat NH NS ρ (Some I)›
proof -
  (* Candidate models; ?Is' additionally requires finiteness, which is implied (see Is). *)
  let ?Is = ‹{I. atms_exactly_m I ((NH + NS)) ∧  consistent_interp I ∧
     I ⊨sm NH}›
  let ?Is'= ‹{I. atms_exactly_m I ((NH + NS)) ∧ consistent_interp I ∧
    I ⊨sm NH ∧ finite I}›
  have Is: ‹?Is = ?Is'›
    by (auto simp: atms_of_s_def atms_exactly_m_finite)
  have ‹?Is' ⊆ set_mset ` simple_clss (atms_of_mm (NH + NS))›
    apply rule
    unfolding image_iff
    by (rule_tac x= ‹mset_set x› in bexI)
      (auto simp: simple_clss_def atms_exactly_m_def image_iff
        atms_of_s_def atms_of_def distinct_mset_mset_set consistent_interp_tuatology_mset_set)
  from finite_subset[OF this] have fin: ‹finite ?Is› unfolding Is
    by (auto simp: simple_clss_finite)
  then have fin': ‹finite (weight_on_clauses NS ρ ` ?Is)›
    by auto
  define ρI where
    ‹ρI = Min (weight_on_clauses NS ρ ` ?Is)›
  have nempty: ‹?Is ≠ {}›
  proof -
    obtain I where I:
      ‹total_over_m I (set_mset NH)›
      ‹I ⊨sm NH›
      ‹consistent_interp I›
      ‹atms_of_s I ⊆ atms_of_mm NH›
      using assms unfolding satisfiable_def_min atms_exactly_m_def
      by (auto simp: atms_of_s_def atm_of_def total_over_m_def)
    (* Extend the model of NH by making every atom of NS it does not mention true. *)
    let ?I = ‹I ∪ Pos ` {x ∈ atms_of_mm NS. x ∉ atm_of ` I}›
    have ‹?I ∈ ?Is›
      using I
      by (auto simp: atms_exactly_m_def total_over_m_alt_def image_iff
          lit_in_set_iff_atm)
        (auto simp: consistent_interp_def uminus_lit_swap)
    then show ?thesis
      by blast
  qed
  (* Minimisation: pick a candidate whose weight is the minimum. *)
  have ‹ρI ∈ weight_on_clauses NS ρ ` ?Is›
    unfolding ρI_def
    by (rule Min_in[OF fin']) (use nempty in auto)
  then obtain I :: ‹'v partial_interp› where
    ‹weight_on_clauses NS ρ I = ρI› and
    ‹I ∈ ?Is›
    by blast
  then have H: ‹consistent_interp I' ⟹ atms_exactly_m I' (NH + NS) ⟹ I' ⊨sm NH ⟹
      weight_on_clauses NS ρ I' ≥ weight_on_clauses NS ρ I› for I'
    using Min_le[OF fin', of ‹weight_on_clauses NS ρ I'›]
    unfolding ρI_def[symmetric]
    by auto
  then have ‹partial_min_sat NH NS ρ (Some I)›
    apply -
    by (rule partial_min_sat)
      (use fin ‹I ∈ ?Is› in ‹auto simp: atms_exactly_m_finite›)
  then show ‹∃I. partial_min_sat NH NS ρ (Some I)›
    by fast

  (* Maximisation: pick a candidate whose weight is the maximum. *)
  define ρI where
    ‹ρI = Max (weight_on_clauses NS ρ ` ?Is)›
  have ‹ρI ∈ weight_on_clauses NS ρ ` ?Is›
    unfolding ρI_def
    by (rule Max_in[OF fin']) (use nempty in auto)
  then obtain I :: ‹'v partial_interp› where
    ‹weight_on_clauses NS ρ I = ρI› and
    ‹I ∈ ?Is›
    by blast
  then have H: ‹consistent_interp I' ⟹ atms_exactly_m I' (NH + NS) ⟹ I' ⊨m NH ⟹
      weight_on_clauses NS ρ I' ≤ weight_on_clauses NS ρ I› for I'
    using Max_ge[OF fin', of ‹weight_on_clauses NS ρ I'›]
    unfolding ρI_def[symmetric]
    by auto
  then have ‹partial_max_sat NH NS ρ (Some I)›
    apply -
    by (rule partial_max_sat)
      (use fin ‹I ∈ ?Is› in ‹auto simp: atms_exactly_m_finite
        consistent_interp_tuatology_mset_set›)
  then show ‹∃I. partial_max_sat NH NS ρ (Some I)›
    by fast
qed

(* Specification of weighted SAT (minimisation of ρ):
   - Some I: I is a duplicate-free multiset of literals forming a consistent model of N over
     exactly the atoms of N, and every other such model I' has ρ I' ≥ ρ I;
   - None: N is unsatisfiable. *)
inductive weight_sat
  :: ‹'v clauses ⇒ ('v literal multiset ⇒ 'a :: linorder) ⇒
    'v literal multiset option ⇒ bool›
where
  weight_sat:
  ‹weight_sat N ρ (Some I)›
if
  ‹set_mset I ⊨sm N› and
  ‹atms_exactly_m (set_mset I) N› and
  ‹consistent_interp (set_mset I)› and
  ‹distinct_mset I›
  ‹⋀I'. consistent_interp (set_mset I') ⟹ atms_exactly_m (set_mset I') N ⟹ distinct_mset I' ⟹
      set_mset I' ⊨sm N ⟹ ρ I' ≥ ρ I› |
  partial_max_unsat:
  ‹weight_sat N ρ None›
if
  ‹unsatisfiable (set_mset N)›

(* Reduction of Partial MaxSAT to weighted SAT (minimisation of ρ').
   Every soft clause C of NS is extended with the positive literal of a fresh atom
   additional_atm C. The weight ρ' charges count NS C * ρ C for each such fresh literal
   occurring in a model. If I is a ρ'-minimal model of the extended problem, then I
   restricted to the atoms of NH + NS is a Partial MaxSAT solution for NH, NS and ρ. *)
lemma partial_max_sat_is_weight_sat: (* \htmllink{ocdcl-maxsat}*)
  fixes additional_atm :: ‹'v clause ⇒ 'v› and
    ρ :: ‹'v clause ⇒ nat› and
    NS :: ‹'v clauses›
  defines
    ‹ρ' ≡ (λC. sum_mset
       ((λL. if L ∈ Pos ` additional_atm ` set_mset NS
         then count NS (SOME C. L = Pos (additional_atm C) ∧ C ∈# NS)
           * ρ (SOME C. L = Pos (additional_atm C) ∧ C ∈# NS)
         else 0) `# C))›
  assumes
    add: ‹⋀C. C ∈# NS ⟹ additional_atm C ∉ atms_of_mm (NH + NS)›
    ‹⋀C D. C ∈# NS ⟹ D ∈# NS ⟹ additional_atm C = additional_atm D ⟷ C = D› and
    w: ‹weight_sat (NH + (λC. add_mset (Pos (additional_atm C)) C) `# NS) ρ' (Some I)›
  shows
    ‹partial_max_sat NH NS ρ (Some {L ∈ set_mset I. atm_of L ∈ atms_of_mm (NH + NS)})›
proof -
  define N where ‹N ≡ NH + (λC. add_mset (Pos (additional_atm C)) C) `# NS›
  (* cl_of maps the literal of a fresh atom back to the soft clause it was added to. *)
  define cl_of where ‹cl_of L = (SOME C. L = Pos (additional_atm C) ∧ C ∈# NS)› for L
  from w
  have
    ent: ‹set_mset I ⊨sm N› and
    bi: ‹atms_exactly_m (set_mset I) N› and
    cons: ‹consistent_interp (set_mset I)› and
    dist: ‹distinct_mset I› and
    weight: ‹⋀I'. consistent_interp (set_mset I') ⟹ atms_exactly_m (set_mset I') N ⟹
      distinct_mset I' ⟹ set_mset I' ⊨sm N ⟹ ρ' I' ≥ ρ' I›
    unfolding N_def[symmetric]
    by (auto simp: weight_sat.simps)
  (* The candidate solution: I restricted to the atoms of the original problem. *)
  let ?I = ‹{L. L ∈# I ∧ atm_of L ∈ atms_of_mm (NH + NS)}›
  have ent': ‹set_mset I ⊨sm NH›
    using ent unfolding true_clss_restrict
    by (auto simp: N_def)
  then have ent': ‹?I ⊨sm NH›
    apply (subst (asm) true_clss_restrict[symmetric])
    apply (rule true_clss_mono_left, assumption)
    apply auto
    done
  have [simp]: ‹atms_of_ms ((λC. add_mset (Pos (additional_atm C)) C) ` set_mset NS) =
    additional_atm ` set_mset NS ∪ atms_of_ms (set_mset NS)›
    by (auto simp: atms_of_ms_def)
  have bi': ‹atms_exactly_m ?I (NH + NS)›
    using bi
    by (auto simp: atms_exactly_m_def total_over_m_def total_over_set_def
        atms_of_s_def N_def)
  have cons': ‹consistent_interp ?I›
    using cons by (auto simp: consistent_interp_def)
  have [simp]: ‹cl_of (Pos (additional_atm xb)) = xb›
    if ‹xb ∈# NS› for xb
    using someI[of ‹λC. additional_atm xb = additional_atm C› xb] add that
    unfolding cl_of_def
    by auto

  (* Re-set the fresh atoms so that exactly the falsified soft clauses pay their weight;
     by minimality of I, this model has the same ρ' weight as I (fact I_I). *)
  let ?I = ‹{L. L ∈# I ∧ atm_of L ∈ atms_of_mm (NH + NS)} ∪ Pos ` additional_atm ` {C ∈ set_mset NS. ¬set_mset I ⊨ C}
    ∪ Neg ` additional_atm ` {C ∈ set_mset NS. set_mset I ⊨ C}›
  have ‹consistent_interp ?I›
    using cons add by (auto simp: consistent_interp_def
        atms_exactly_m_def uminus_lit_swap
        dest: add)
  moreover have ‹atms_exactly_m ?I N›
    using bi
    by (auto simp: N_def atms_exactly_m_def total_over_m_def
        total_over_set_def image_image)
  moreover have ‹?I ⊨sm N›
    using ent by (auto simp: N_def true_clss_def image_image
          atm_of_lit_in_atms_of true_cls_def
        dest!: multi_member_split)
  moreover have ‹set_mset (mset_set ?I) = ?I› and fin: ‹finite ?I›
    by (auto simp: atms_exactly_m_finite)
  moreover have ‹distinct_mset (mset_set ?I)›
    by (auto simp: distinct_mset_mset_set)
  ultimately have ‹ρ' (mset_set ?I) ≥ ρ' I›
    using weight[of ‹mset_set ?I›]
    by argo
  moreover have ‹ρ' (mset_set ?I) ≤ ρ' I›
    using ent
    by (auto simp: ρ'_def sum_mset_inter_restrict[symmetric] mset_set_subset_iff N_def
        intro!: sum_image_mset_mono
        dest!: multi_member_split)
  ultimately have I_I: ‹ρ' (mset_set ?I) = ρ' I›
    by linarith

  (* Optimality: any other candidate I' is extended with fresh atoms in the same way, and
     the minimality of I bounds its weight of falsified soft clauses. *)
  have min: ‹weight_on_clauses NS ρ I'
      ≤ weight_on_clauses NS ρ {L. L ∈# I ∧ atm_of L ∈ atms_of_mm (NH + NS)}›
    if
      cons: ‹consistent_interp I'› and
      bit: ‹atms_exactly_m I' (NH + NS)› and
      I': ‹I' ⊨sm NH›
    for I'
  proof -
    let ?I' = ‹I' ∪ Pos ` additional_atm ` {C ∈ set_mset NS. ¬I' ⊨ C}
      ∪ Neg ` additional_atm ` {C ∈ set_mset NS. I' ⊨ C}›
    have ‹consistent_interp ?I'›
      using cons bit add by (auto simp: consistent_interp_def
          atms_exactly_m_def uminus_lit_swap
          dest: add)
    moreover have ‹atms_exactly_m ?I' N›
      using bit
      by (auto simp: N_def atms_exactly_m_def total_over_m_def
          total_over_set_def image_image)
    moreover have ‹?I' ⊨sm N›
      using I' by (auto simp: N_def true_clss_def image_image
          dest!: multi_member_split)
    moreover have ‹set_mset (mset_set ?I') = ?I'› and fin: ‹finite ?I'›
      using bit by (auto simp: atms_exactly_m_finite)
    moreover have ‹distinct_mset (mset_set ?I')›
      by (auto simp: distinct_mset_mset_set)
    ultimately have I'_I: ‹ρ' (mset_set ?I') ≥ ρ' I›
      using weight[of ‹mset_set ?I'›]
      by argo
    have inj: ‹inj_on cl_of (I' ∩ (λx. Pos (additional_atm x)) ` set_mset NS)› for I'
      using add by (auto simp: inj_on_def)

    (* Weight of satisfied soft clauses = total weight - weight of falsified ones. *)
    have we: ‹weight_on_clauses NS ρ I' = sum_mset (ρ `# NS) -
      sum_mset (ρ `# filter_mset (Not ∘ (⊨) I') NS)› for I'
      unfolding weight_on_clauses_def
      apply (subst (3) multiset_partition[of _ ‹(⊨) I'›])
      unfolding image_mset_union sum_mset.union
      by (auto simp: comp_def)
    have H: ‹sum_mset
       (ρ `#
        filter_mset (Not ∘ (⊨) {L. L ∈# I ∧ atm_of L ∈ atms_of_mm (NH + NS)})
         NS) = ρ' I›
            unfolding I_I[symmetric] unfolding ρ'_def cl_of_def[symmetric]
              sum_mset_sum_count if_distrib
            apply (auto simp: sum_mset_sum_count image_image simp flip: sum.inter_restrict
                cong: if_cong)
            apply (subst comm_monoid_add_class.sum.reindex_cong[symmetric, of cl_of, OF _ refl])
            apply ((use inj in auto; fail)+)[2]
            apply (rule sum.cong)
            apply auto[]
            using inj[of ‹set_mset I›] ‹set_mset I ⊨sm N› assms(2)
            apply (auto dest!: multi_member_split simp: N_def image_Int
                atm_of_lit_in_atms_of true_cls_def)[]
            using add apply (auto simp: true_cls_def)
            done
    have ‹(∑x∈(I' ∪ (λx. Pos (additional_atm x)) ` {C. C ∈# NS ∧ ¬ I' ⊨ C} ∪
         (λx. Neg (additional_atm x)) ` {C. C ∈# NS ∧ I' ⊨ C}) ∩
        (λx. Pos (additional_atm x)) ` set_mset NS.
       count NS (cl_of x) * ρ (cl_of x))
     ≤ (∑A∈{a. a ∈# NS ∧ ¬ I' ⊨ a}. count NS A * ρ A)›
      apply (subst comm_monoid_add_class.sum.reindex_cong[symmetric, of cl_of, OF _ refl])
      apply ((use inj in auto; fail)+)[2]
      apply (rule ordered_comm_monoid_add_class.sum_mono2)
      using that add by (auto dest:  simp: N_def
          atms_exactly_m_def)
    then have ‹sum_mset (ρ `# filter_mset (Not ∘ (⊨) I') NS) ≥ ρ' (mset_set ?I')›
      using fin unfolding cl_of_def[symmetric] ρ'_def
      by (auto simp: ρ'_def
          simp add: sum_mset_sum_count image_image simp flip: sum.inter_restrict)
    then have ‹ρ' I ≤ sum_mset (ρ `# filter_mset (Not ∘ (⊨) I') NS)›
      using I'_I by auto
    then show ?thesis
      unfolding we H I_I apply -
      by auto
  qed

  show ?thesis
    apply (rule partial_max_sat.intros)
    subgoal using ent' by auto
    subgoal using bi' by fast
    subgoal using cons' by fast
    subgoal for I'
      by (rule min)
    done
qed

(* Congruence for sums over a multiset: summands that agree on every element of A
   give equal sums. *)
lemma sum_mset_cong:
  ‹(⋀a. a ∈# A ⟹ f a = g a) ⟹ (∑a∈#A. f a) = (∑a∈#A. g a)›
  by (induction A) auto

(* Variant of partial_max_sat_is_weight_sat for a duplicate-free NS: every soft clause has
   multiplicity one, so the count factor can be dropped from the weight ρ'. The proof
   rewrites ρ' into the form with count (fact ρ') and reuses the general lemma. *)
lemma partial_max_sat_is_weight_sat_distinct: (* \htmllink{ocdcl-maxsat}*)
  fixes additional_atm :: ‹'v clause ⇒ 'v› and
    ρ :: ‹'v clause ⇒ nat› and
    NS :: ‹'v clauses›
  defines
    ‹ρ' ≡ (λC. sum_mset
       ((λL. if L ∈ Pos ` additional_atm ` set_mset NS
         then ρ (SOME C. L = Pos (additional_atm C) ∧ C ∈# NS)
         else 0) `# C))›
  assumes
    ‹distinct_mset NS› and ―‹This is implicit on paper›
    add: ‹⋀C. C ∈# NS ⟹ additional_atm C ∉ atms_of_mm (NH + NS)›
    ‹⋀C D. C ∈# NS ⟹ D ∈# NS ⟹ additional_atm C = additional_atm D ⟷ C = D› and
    w: ‹weight_sat (NH + (λC. add_mset (Pos (additional_atm C)) C) `# NS) ρ' (Some I)›
  shows
    ‹partial_max_sat NH NS ρ (Some {L ∈ set_mset I. atm_of L ∈ atms_of_mm (NH + NS)})›
proof -
  define cl_of where ‹cl_of L = (SOME C. L = Pos (additional_atm C) ∧ C ∈# NS)› for L
  have [simp]: ‹cl_of (Pos (additional_atm xb)) = xb›
    if ‹xb ∈# NS› for xb
    using someI[of ‹λC. additional_atm xb = additional_atm C› xb] add that
    unfolding cl_of_def
    by auto
  have ρ': ‹ρ' = (λC. ∑L∈#C. if L ∈ Pos ` additional_atm ` set_mset NS
                 then count NS
                       (SOME C. L = Pos (additional_atm C) ∧ C ∈# NS) *
                      ρ (SOME C. L = Pos (additional_atm C) ∧ C ∈# NS)
                 else 0)›
    unfolding cl_of_def[symmetric] ρ'_def
    using assms(2,4) by (auto intro!: ext sum_mset_cong simp: ρ'_def not_in_iff dest!: multi_member_split)
  show ?thesis
    apply (rule partial_max_sat_is_weight_sat[where additional_atm=additional_atm])
    subgoal by (rule assms(3))
    subgoal by (rule assms(4))
    subgoal unfolding ρ'[symmetric] by (rule assms(5))
    done
qed

(* atms_exactly_m for an interpretation given as a multiset, stated with atms_of. *)
lemma atms_exactly_m_alt_def:
  ‹atms_exactly_m (set_mset y) N ⟷ atms_of y ⊆ atms_of_mm N ∧
        total_over_m (set_mset y) (set_mset N)›
  by (auto simp: atms_exactly_m_def atms_of_s_def atms_of_def
      atms_of_ms_def dest!: multi_member_split)

(* For multiset interpretations, atms_exactly_m is exactly equality of the atom sets. *)
lemma atms_exactly_m_alt_def2:
  ‹atms_exactly_m (set_mset y) N ⟷ atms_of y = atms_of_mm N›
  by (metis atms_of_def atms_of_s_def atms_exactly_m_alt_def equalityI order_refl total_over_m_def
      total_over_set_alt_def)

(* In the optimal-weight CDCL locale: a full run of cdcl_bnb_stgy from init_state N, with
   distinct_mset_mset N, yields weight T satisfying the weight_sat specification. The
   case weight T = None is the unsatisfiable case; otherwise the weight_sat introduction
   rule is discharged premise by premise. *)
lemma (in conflict_driven_clause_learningW_optimal_weight) full_cdcl_bnb_stgy_weight_sat:
  ‹full cdcl_bnb_stgy (init_state N) T ⟹ distinct_mset_mset N ⟹ weight_sat N ρ (weight T)›
  using full_cdcl_bnb_stgy_no_conflicting_clause_from_init_state[of N T]
  apply (cases ‹weight T = None›)
  subgoal
    by (auto intro!: weight_sat.intros(2))
  subgoal premises p
    using p(1-4,6)
    apply (clarsimp simp only:)
    apply (rule weight_sat.intros(1))
    subgoal by auto
    subgoal by (auto simp: atms_exactly_m_alt_def)
    subgoal by auto
    subgoal by auto
    subgoal for J I'
      using p(5)[of I'] by (auto simp: atms_exactly_m_alt_def2)
    done
  done

end