#Later Rounds
from math import exp
from math import factorial
from math import floor

# Scale factor used in every Poisson mean in this file: alpha = rap*Mr*nr in X
# and ps, rap*M1 in first_ps. NOTE(review): presumably a balls-to-bins ratio
# ("rapport") -- confirm. It is reassigned to 10 near the end of the file,
# before the script call.
rap = 0.9

#We consider we are at the beginning of round r
def init(v,nb):
    """Install the state of the round about to be analysed.

    v  -- sequence where v[i] counts the bins holding i balls at the start of
          the round (callers pass normalised values, e.g. init([1], 1.)).
          It is stored as the module global Y by reference, not copied.
    nb -- amount of balls still to be placed; stored as the global nr.

    Every other function in this module reads Y and nr, so call this first.
    """
    global Y, nr
    nr = nb
    Y = v

def X(m,Mr):
    """Return the Poisson(alpha) probability of exactly m, with alpha = rap*Mr*nr.

    The original description: the proportion of bins that receive m messages
    when Mr messages are sent. nr is the global created by init(), so init()
    must have been called first.
    """
    lam = rap*Mr*nr
    denominator = exp(lam)*factorial(m)
    return lam**m/denominator


#unranked messages
#compute ps according to the formula if we send Mr messages and if the maximum bin load is Lr
def ps(Mr,Lr):
    """Return ps(Mr, Lr) for the unranked protocol.

    p() and pc() use the result as a per-message probability
    (p = 1 - (1 - ps)**Mr).

    Reads the globals set by init(): Y (Y[l] = bins with load l) and nr, plus
    rap. alpha = rap*Mr*nr is the same Poisson mean as in X(). For each load l,
    with c = Lr-l, the inner loop builds
        c/alpha*exp(alpha) + sum_{m=0}^{c-2} alpha**m/m! - c/alpha*sum_{m=0}^{c-1} alpha**m/m!
    The Y-weighted total is divided by exp(alpha) at the end.

    NOTE(review): when c <= 0 (l >= Lr) the inner loop is empty, and that load
    contributes c/alpha*exp(alpha) <= 0. Confirm that Y never has entries at or
    above Lr, or that this is intended.
    alpha == 0 raises ZeroDivisionError.
    """
    global Y
    global nr
    alpha = rap*Mr*nr
    result = 0.
    #we sum over all the Y[i]
    for l in range(len(Y)):
        #resultl is the part of the sum when l is fixed (c = Lr-l)
        resultl = float(Lr-l)/alpha*exp(alpha)
        for m in range(Lr-l):
            if m == Lr-l-1:
                #last index m = c-1: only the subtracted c/alpha term, no positive alpha**m/m!
                resultl = resultl - float(Lr-l)/alpha*alpha**m/factorial(m)
            else:
                resultl = resultl + alpha**m/factorial(m) - float(Lr-l)/alpha*alpha**m/factorial(m)
        result = result + Y[l]*resultl            
    return(result/exp(alpha))

def p(Mr,Lr):
    """Return the probability that a ball commits (unranked protocol).

    Computed as 1 - (1 - ps(Mr, Lr))**Mr: one minus the chance that all Mr of
    the ball's messages fail, each failing with probability 1 - ps.
    """
    fail_all = (1-ps(Mr,Lr))**Mr
    return 1-fail_all

#Bin Loads
def binom(k,n):
    """Return the binomial coefficient C(n, k) ("n choose k") as an exact int.

    Note the argument order: k first, then n.

    Integral floats (e.g. 3.0) are accepted and converted to int. f() passes
    values derived from float running sums, and math.factorial rejects floats
    on Python 3.10+.

    Raises ValueError for non-integral arguments. Returns 0 when k < 0 or
    k > n, the usual convention, instead of letting factorial() fail on a
    negative argument.
    """
    if k != int(k) or n != int(n):
        raise ValueError("binom() requires integral arguments")
    k = int(k)
    n = int(n)
    if k < 0 or k > n:
        return 0
    # Floor division is exact here and yields an int on both Python 2 and 3.
    # Plain "/" would produce a float on Python 3.
    return factorial(n)//(factorial(n-k)*factorial(k))

#compute pc according to its formula
def pc(Mr,Lr):
    """Return pc = (1 - (1 - s)**Mr) / (s*Mr), where s = ps(Mr, Lr).

    The numerator is the ball-commit probability computed by p(). ps() loops
    over all of Y, so it is evaluated once here rather than twice.

    When s == 0 the formula is 0/0. Its limit as s -> 0 is 1, which is
    returned instead of raising ZeroDivisionError.
    """
    s = ps(Mr,Lr)
    if s == 0:
        return 1.
    return (1-(1-s)**Mr)/(s*Mr)

#computes p^{k}(Mr,Lr)
def pk(k,Mr,Lr):
    """Return p^{k}(Mr, Lr) for the unranked protocol.

    Reads the globals Y and nr set by init(). With q = pc(Mr, Lr), for each
    starting load l (0 <= l <= k, l < len(Y)) it adds Y[l] times:
      * sum over k-l <= m < Lr-l of
        X(m,Mr) * C(m, k-l) * q**(k-l) * (1-q)**(m-k+l)
      * C(Lr-l, k-l) * q**(k-l) * (1-q)**(Lr-k+l) * (1 - sum_{m<Lr-l} X(m,Mr))

    Presumably this is the fraction of bins ending the round with load k:
    m messages arrive, at most Lr-l are kept, and each kept one commits with
    probability q. Confirm against the derivation.
    """
    upper = min(k+1,len(Y))
    if upper <= 0:
        # Nothing to sum. The original returned 0 here without ever calling pc().
        return 0
    # pc() is loop-invariant and costly: each call evaluates ps() twice, and
    # ps() loops over Y. Evaluate it once instead of three times per inner step.
    # The arithmetic below keeps the original operand order, so results are
    # bit-identical.
    q = pc(Mr,Lr)
    result = 0
    for l in range(upper):
        #part of the sum when l is fixed
        resultl = binom(k-l,Lr-l)*q**(k-l)*(1-q)**(Lr-k+l)
        for m in range(Lr-l):
            if m >=k-l:
                resultl = resultl + X(m,Mr)*binom(k-l,m)*q**(k-l)*(1-q)**(m-k+l)
            resultl = resultl - X(m,Mr)*binom(k-l,Lr-l)*q**(k-l)*(1-q)**(Lr-k+l)
        result = result + Y[l]*resultl
    return(result)


# Ranked messages ######################################################################################
#Computes ps for the first round
def first_ps(M1,L1):
    """Return ps for a round in which every bin starts empty.

    This is the same closed form as ps(M1, L1) evaluated with Y = [1] and
    nr = 1, the state that init([1],1.) sets up for the first round: a single
    load class l = 0 with capacity L1 and alpha = rap*M1. Unlike ps() it reads
    no globals other than rap.

    pi() reuses it with M1 = nr, since each remaining ball sends one message of
    the rank under study.

    M1 == 0 (or rap == 0) raises ZeroDivisionError through the float(L1)/M1
    terms.
    """
    # M1 is rebound to the Poisson mean alpha = rap*M1 for the rest of the body
    M1 = rap*M1
    result = float(L1)/M1*exp(M1)
    for m in range(L1):
        if m == L1-1:
            # last index m = L1-1: only the subtracted L1/alpha term
            result = result - float(L1)/M1*M1**m/factorial(m)
        else:
            result = result + M1**m/factorial(m) - float(L1)/M1*M1**m/factorial(m)
    return(result/exp(M1))

def pi(i,Lr):
    """Return pi(i, Lr) for the ranked protocol.

    Presumably this is the probability that a ball's rank-i message succeeds;
    pranked() combines these as 1 - prod(1 - pi).

    For each load l (weighted by Y[l]) it sums X(m, i-1) * first_ps(nr, Lr-l-m)
    over m < Lr-l, where:
      * X(m, i-1) is the Poisson probability of m messages of lower rank, with
        mean rap*(i-1)*nr;
      * first_ps is applied to the Lr-l-m remaining slots. Each of the nr
        remaining balls sends one message of rank i, hence the nr argument.

    Reads the globals Y and nr set by init().
    """
    total = 0.
    for l, bins in enumerate(Y):
        free = Lr-l
        partial = 0.
        for m in range(free):
            partial = partial + X(m,i-1)*first_ps(nr,free-m)
        total = total + bins*partial
    return(total)

def pranked(Mr,Lr):
    """Return 1 - prod_{i=1..Mr} (1 - pi(i, Lr)), the ranked counterpart of p().

    print_n updates the remaining balls as n_r = n_{r-1} * (1 - pranked(...)),
    so this is the probability that a ball commits during the round.
    """
    none_succeed = 1.
    for rank in range(1, Mr+1):
        none_succeed = none_succeed*(1-pi(rank,Lr))
    return(1-none_succeed)

# Bin loads ############################################################################################
def pcrank(i,Lr):
    """Return prod_{j=1..i-1} (1 - pi(j, Lr)); pcrank(1, Lr) == 1.

    Presumably this is the probability that none of the ball's messages of
    rank 1..i-1 succeeded, i.e. that a rank-i message still counts. Confirm
    against the derivation.
    """
    acc = 1.
    for j in range(1, i):
        acc = acc*(1-pi(j,Lr))
    return(acc)

def last_not_null(v):
    """Return the largest index i such that v[i] != 0.

    Returns 0 when v is empty or all zeros. This is indistinguishable from
    "only v[0] is non-zero", which f() relies on.

    Uses != rather than the Python-2-only <> operator, which is a SyntaxError
    on Python 3.
    """
    imax = 0
    for i, value in enumerate(v):
        if value != 0:
            imax = i
    return(imax)

#Rank begins at 1 and ends at Mr (vk and vm are 0-indexed, so rank r is stored at index r-1)
#f computes the part of the big sum in pkranked for a fixed vector vk (one k per rank,
#produced by g) and a fixed initial bin load l.
#vm, the message counts m_1..m_{rank-1} already chosen, is built during the recursion of f.
def f(rank,vk,vm,Mr,Lr,l):
    """Recursive helper for pkranked.

    rank -- 1-based rank currently processed (1..Mr).
    vk   -- per-rank counts from g; vk[j] is k_{j+1}, and they sum to k-l.
    vm   -- m_1..m_{rank-1} fixed by outer calls (len(vm) == rank-1).
    Mr   -- number of ranks.
    Lr   -- maximum bin load.
    l    -- bin load before the round.

    Depends on the init() globals (Y, nr) and on rap through X and pcrank.

    Branches:
      rank == Mr  -> close the product for ranks 1..Mr-1 and sum over m_Mr.
      rank > imax -> every later rank has k = 0 (imax is the last 0-based
                     index with vk != 0). Add the term where m_rank fills the
                     bin, then recurse for m_rank in [k_rank, born1).
      otherwise   -> only enumerate m_rank and recurse.

    NOTE(review): born1 is a float because sum_mrank starts at 0., so
    binom(.., born1) passes a float to math.factorial. That works on Python 2
    but is rejected on Python 3.10+.
    """
    result = 0.
    #imax: 0-based index of the last non-zero entry of vk (0 if none)
    imax = last_not_null(vk)
    if rank == Mr:
        #acc is the product over the ranks i = 1..Mr-1 already fixed in vm of
        #X(m_i,1) * C(r_i,k_i) * pcrank(i)**k_i * (1-pcrank(i))**(r_i-k_i)
        acc = 1.
        #sum_mrank accumulates sum_{j<Mr} m_j; sum_mi is meant to be sum_{j<i} m_j.
        #NOTE(review): range(i-1) below sums only vm[0..i-2], one term short of sum_{j<i}.
        #Every m_i placed in vm is at most the remaining capacity, so min() in ri still
        #returns vm[i]; results appear unaffected. Confirm before changing.
        sum_mrank = 0.
        for i in range(Mr-1):
            sum_mi = 0.
            sum_mrank = sum_mrank + vm[i]
            for j in range(i-1):
                sum_mi = sum_mi + vm[j]
            ki = vk[i]
            ri = min(vm[i],Lr-sum_mi-l)
            acc = acc*X(vm[i],1)*binom(ki,ri)*pcrank(i+1,Lr)**ki*(1-pcrank(i+1,Lr))**(ri-ki)
        #born1 = capacity left for rank Mr. For mMr >= born1 only born1 messages fit (rMr = born1).
        born1 = Lr-sum_mrank-l
        kMr = vk[Mr-1]
        #terms with kMr <= mMr < born1 (all mMr messages fit)
        for mMr in range(int(kMr),int(born1)):
            result = result + X(mMr,1)*binom(kMr,mMr)*pcrank(Mr,Lr)**kMr*(1-pcrank(Mr,Lr))**(mMr-kMr)
        #acc2 = C(born1,kMr)*pcrank(Mr)**kMr*(1-pcrank(Mr))**(born1-kMr) is the common factor of
        #every mMr >= born1. Those terms total acc2*(1 - sum_{m<born1} X(m,1)).
        acc2 = binom(kMr,born1)*pcrank(Mr,Lr)**kMr*(1-pcrank(Mr,Lr))**(born1-kMr)
        for mMr in range(int(born1)):
            result = result - acc2*X(mMr,1)
        result = result + acc2
        result = result*acc
        return(result)
    elif rank > imax:
        #acc is the same product as above, over the ranks 1..rank-1 fixed in vm
        #(same NOTE(review) on the range(i-1) used for sum_mi)
        acc = 1.
        sum_mrank = 0.
        for i in range(rank-1):
            sum_mi = 0.
            sum_mrank = sum_mrank+vm[i]
            for j in range(i-1):
                sum_mi = sum_mi+vm[j]
            ki = vk[i]
            ri = min(vm[i],Lr-sum_mi-l)
            acc = acc*X(vm[i],1)*binom(ki,ri)*pcrank(i+1,Lr)**ki*(1-pcrank(i+1,Lr))**(ri-ki)
        #born1 = capacity left for this rank
        born1 = Lr-sum_mrank-l
        krank = vk[rank-1]
        #acc2 is the common factor of every mrank >= born1 (this rank fills the bin, and all later
        #ranks have k = 0). Those terms total acc2*(1 - sum_{m<born1} X(m,1)).
        acc2 = binom(krank,born1)*pcrank(rank,Lr)**krank*(1-pcrank(rank,Lr))**(born1-krank)
        for mrank in range(int(born1)):
            result = result -acc2*X(mrank,1)
        result = result + acc2
        result = result*acc
        #for krank <= mrank < born1, recurse on the next rank with vm extended by mrank
        #(the recursive call builds its own product over vm)
        for mrank in range(int(krank),int(born1)):
            result = result + f(rank+1,vk,vm+[int(mrank)],Mr,Lr,l)
        return(result)    
    else:
        #rank <= imax: we need krank <= mrank, and we must leave at least k_{rank+1} slots, i.e.
        #Lr-l-sum_{j<=rank} m_j >= k_{rank+1}. Hence mrank runs up to
        #Lr-l-sum_{j<rank} m_j - k_{rank+1} inclusive (note the +1 in the range bound).
        sum_mrank = 0.
        for j in range(rank-1):
            sum_mrank = sum_mrank + vm[j]
        for mrank in range(int(vk[rank-1]),int(Lr-sum_mrank-vk[rank]-l+1)):
            result = result + f(rank+1,vk,vm+[int(mrank)],Mr,Lr,l)
        return(result)


def g(rank,vk,k,Mr,l):
    """Enumerate every vector vk admissible for f, in lexicographic order.

    rank -- 1-based position being filled.
    vk   -- prefix built so far (vk[j] is the count for rank j+1).
    k, l -- the final vectors have Mr entries summing to k-l.
    Mr   -- number of ranks.

    The entries for ranks 1..Mr-1 range over 0..k, subject to their running
    total staying <= k-l. The last entry is forced to k - l - (sum of the rest).
    Returns a list of new lists; vk itself is not modified.
    """
    used = sum(vk[i] for i in range(rank-1))
    if rank == Mr:
        return [vk + [k-used-l]]
    found = []
    for krank in range(k+1):
        if used+krank > k-l:
            continue
        found.extend(g(rank+1, vk+[krank], k, Mr, l))
    return found

def pkranked(k,Mr,Lr):
    """Return the ranked-protocol probability for final load k (uses g and f).

    print_n stores pkranked(k, ...) as entry k of the next round's Y, so this
    is presumably the fraction of bins whose load is k after the round.

    For each starting load l (0 <= l <= k, l < len(Y)), the term is Y[l]
    times the sum of f over every per-rank split vk of the k-l new balls
    produced by g.
    """
    def weight_for_load(l):
        # Sum of f over all admissible vk for starting load l
        acc = 0.
        for vk in g(1,[],k,Mr,l):
            acc = acc + f(1,vk,[],Mr,Lr,l)
        return acc

    prob = 0.
    upper = min(k+1,len(Y))
    for l in range(upper):
        prob = prob + Y[l]*weight_for_load(l)
    return(prob)

#Print_n displays n1,n2,n3 and Y3 for every (M1,M2,M3) taken from LM1 x LM2 x LM3.
#NOTE: a two-round print_n is defined further down in this file and rebinds the name at
#import time, so this three-round version cannot be reached under this name.
def print_n(LM1,LM2,LM3,L1,L2,L3):
    """Run three rounds of the ranked analysis and print the results.

    LM1, LM2, LM3 -- iterables of M values; a round with value M uses M+1
                     messages per ball.
    L1, L2, L3    -- maximum bin load of each round.

    For every combination it prints (M1+1, M2+1, M3+1), then n1, n2, n3 and Y3.
    Rebinds the module globals Y and nr through init().
    """
    for M1 in LM1:
        init([1],1.)
        n1 = (1-pranked(M1+1,L1))
        Y1 = []
        for k1 in range(L1+1):
            Y1.append(pkranked(k1,M1+1,L1))
        for M2 in LM2:
            init(Y1,n1)
            n2 = n1*(1-pranked(M2+1,L2))
            Y2 = []
            for k2 in range(L2+1):
                Y2.append(pkranked(k2,M2+1,L2))
            for M3 in LM3:
                init(Y2,n2)
                n3 = n2*(1-pranked(M3+1,L3))
                Y3 = []
                for k3 in range(L3+1):
                    Y3.append(pkranked(k3,M3+1,L3))
                # Report from inside the M3 loop so that every (M1, M2, M3)
                # combination is printed, and an empty LM3 prints nothing
                # instead of referencing undefined M3/n3/Y3.
                print(M1+1,M2+1,M3+1)
                print(n1)
                print(n2)
                print(n3)
                print(Y3)

#rap=10
#print_n([0],[1],[4],20,20,30)

def print_n(LM1,LM2,L1,L2):
    """Run two rounds of the ranked analysis and print the results.

    LM1, LM2 -- iterables of M values; a round with value M uses M+1 messages
                per ball.
    L1, L2   -- maximum bin load of each round.

    For every (M1, M2) it prints (M1+1, M2+1), then n1, n2 and the load vector
    Y2. Rebinds the module globals Y and nr through init().
    """
    for m1 in LM1:
        init([1],1.)
        left1 = (1-pranked(m1+1,L1))
        loads1 = [pkranked(load,m1+1,L1) for load in range(L1+1)]
        for m2 in LM2:
            init(loads1,left1)
            left2 = left1*(1-pranked(m2+1,L2))
            loads2 = [pkranked(load,m2+1,L2) for load in range(L2+1)]
            print(m1+1,m2+1)
            print(left1)
            print(left2)
            print(loads2)
            


# Experiment setting: overrides the rap = 0.9 defined at the top of the file.
rap = 10
# Run the (slow) two-round experiment only when executed as a script, not on import.
if __name__ == "__main__":
    print_n([0],[2],10,15)





