
###############################################################
#                                                             #           
#      Pivotal Allocation based Relabeling (PAR) algorithm    #
#       for univariate data, optimal Z = MAP of Z             #
#                    Han Li                                   #
#         last revised on Oct 8, 2017                         #
#                                                             # 
###############################################################



library(lpSolve)   #use its "lp.assign" function

# Inputs: the observation data and the MCMC samples of the parameters ----

modelName <- "M1"  # name of the observation data set

# Observations: first column of data/<modelName>.txt
fileName <- paste0("data/", modelName, ".txt")
y <- as.matrix(read.table(fileName))[, 1]

# Every MCMC sample file shares the prefix MCMC_result/<modelName>/<modelName>_
fileStr <- paste0("MCMC_result/", modelName, "/", modelName, "_")

# !!! FORMAT (input and output files alike): each row holds the samples of one
# MCMC iteration, stored component-wise (weight/mu/sigma2: one column per
# component; z: one column per observation). !!!

fileName <- paste0(fileStr, "weight.txt")
postWeight <- as.matrix(read.table(fileName))   # component weights
fileName <- paste0(fileStr, "mu.txt")
postMu <- as.matrix(read.table(fileName))       # component means
fileName <- paste0(fileStr, "sigma2.txt")
postSigma2 <- as.matrix(read.table(fileName))   # component variances
fileName <- paste0(fileStr, "z.txt")
postZ <- as.matrix(read.table(fileName))        # cluster indicators


# Dimensions ----
n <- length(y)           # number of observations
K <- ncol(postWeight)    # number of mixture components (clusters)
IT <- nrow(postWeight)   # number of MCMC samples

# Fail fast if the sample files disagree with each other or with y. Later
# steps index postMu/postSigma2/postWeight by the labels in postZ and pair
# postZ's columns with y, so a mismatch would otherwise yield silently wrong
# relabelings instead of an error.
stopifnot(
  n > 0, K > 0, IT > 0,
  all(dim(postMu) == c(IT, K)),
  all(dim(postSigma2) == c(IT, K)),
  all(dim(postZ) == c(IT, n)),
  all(postZ %in% seq_len(K))
)

# Working storage for the relabeling algorithm ----
Obj <- postZ                     # original (unpermuted) labels of each sample
OptPerm <- matrix(0, IT, K)      # optimal permutation of cluster labels per sample
OptZ <- rep(0, n)                # labels of the sample with the maximum log posterior
mu0 <- mean(y)                   # prior mean of the component means
epsilon <- 0.01                  # convergence tolerance on the transition matrix
maxIter <- 200                   # maximum number of optimization steps
tempPermObj <- Obj               # labels of each sample after permutation
transTable <- matrix(0, K, K)    # counts used to estimate the transition matrix
tempMatch <- rep(0, K)
phi <- matrix(0, K, K)           # transition matrix
cost <- matrix(0, K, K)          # assignment cost matrix for lp.assign
matchMat <- array(0, dim = c(IT, K, K))

# Prior hyperparameters ----
kappa <- 10   # component-mean prior variance is kappa * sigma2
a <- 2        # inverse-gamma shape for the component variances
b <- 1        # inverse-gamma scale for the component variances
delta <- 1    # NOTE(review): not used by the computations in this script — confirm intended


# Unnormalized log posterior of every MCMC sample (used to pick OptZ) ----
# Terms: Gaussian log-likelihood of y given (mu, sigma2, z), a
# N(mu0, kappa * sigma2) term per component mean, an inverse-gamma(a, b) term
# per component variance, and log(weight[z]) for the cluster indicators.
postLogL <- vapply(seq_len(IT), function(it) {
  z <- postZ[it, ]
  mu <- postMu[it, ]
  s2 <- postSigma2[it, ]
  s2z <- s2[z]

  dataTerm <- sum((y - mu[z])^2 / s2z)
  muPriorTerm <- sum((mu - mu0)^2 / (kappa * s2))
  s2PriorTerm <- sum((a + 1) * log(s2) + b / s2)

  logL <- -dataTerm / 2 - sum(log(s2z)) / 2 - muPriorTerm / 2 - sum(log(s2)) / 2
  logL <- logL - s2PriorTerm
  logL + sum(log(postWeight[it, z]))
}, numeric(1))


# Index of the MCMC sample with the largest log posterior.
MAPIt <- which.max(postLogL)

# Initial permutation of every sample: the component with the k-th smallest
# mean receives label k.
for (it in seq_len(IT)) {
  byMean <- order(postMu[it, ])
  for (k in seq_len(K)) {
    hit <- Obj[it, ] == byMean[k]
    tempPermObj[it, hit] <- k
  }
}

# Reference labelling: the MAP sample after the mean-ordering permutation.
OptZ <- tempPermObj[MAPIt, ]


# matchMat[it, i, j]: number of observations whose reference label (OptZ) is i
# and whose original label in sample it is j. Filled one (i, j) cell at a time
# for all samples simultaneously.
for (i in seq_len(K)) {
  inRefCluster <- OptZ == i
  for (j in seq_len(K)) {
    matchMat[, i, j] <- rowSums(Obj[, inRefCluster, drop = FALSE] == j)
  }
}


# Alternate between estimating the transition matrix phi and the per-sample
# label permutations until phi changes by less than epsilon (or maxIter). ----
for (itr in seq_len(maxIter)) {

  # M step for the transition matrix.
  # transTable[i, k] counts (sample, observation) pairs whose reference label
  # OptZ is i and whose current permuted label is k. Every row is rebuilt from
  # scratch each iteration so phi reflects only the current permutations
  # (counts must not accumulate across iterations).
  tempPhi <- phi
  for (i in seq_len(K)) {
    transTable[i, ] <- tabulate(as.vector(tempPermObj[, OptZ == i]), nbins = K)
  }

  # Row-normalize with add-one smoothing, so every entry of phi is positive.
  phi <- (transTable + 1) / (rowSums(transTable) + K)
  if (max(abs(tempPhi - phi)) < epsilon) {
    # Converged: OptPerm keeps the permutations from the previous iteration.
    break
  }
  logPhi <- log(phi + 0.000001)

  # M step for the permutations: solve one assignment problem per sample.
  for (it in seq_len(IT)) {
    # cost[i, j] = sum_a logPhi[a, i] * matchMat[it, a, j], i.e. the log score
    # of giving new label i to original label j of this sample. matrix()
    # keeps the K x K shape even when K == 1.
    cost <- crossprod(logPhi, matrix(matchMat[it, , ], K, K))
    cost <- -cost  # lp.assign minimizes
    # NOTE(review): rounding discards sub-unit differences between candidate
    # permutations; confirm this is intended.
    cost <- round(cost - min(cost))
    # OptPerm[it, k] = original label assigned to new label k.
    OptPerm[it, ] <- apply(lp.assign(cost)$solution, 1, which.max)
    for (k in seq_len(K)) {
      tempPermObj[it, Obj[it, ] == OptPerm[it, k]] <- k
    }
  }
}



######################################################

#output the relabeled results


# Apply the optimal permutation to every MCMC sample: new label k of sample it
# is that sample's original component OptPerm[it, k].
relabeledMu <- relabeledSigma2 <- relabeledWeight <- matrix(0, IT, K)
relabeledZ <- matrix(0, IT, n)
# Scratch row reused across samples; an entry keeps its previous value if no
# label of the permutation matches it.
zRow <- rep(0, n)

for (it in seq_len(IT)) {
  perm <- OptPerm[it, ]
  relabeledMu[it, ] <- postMu[it, perm]
  relabeledSigma2[it, ] <- postSigma2[it, perm]
  relabeledWeight[it, ] <- postWeight[it, perm]
  for (k in seq_len(K)) {
    zRow[postZ[it, ] == perm[k]] <- k
  }
  relabeledZ[it, ] <- zRow
}



# Write the relabeled samples, using the same row/column layout as the inputs.
fileStr <- paste0("relabel_result/", modelName, "/", modelName, "_relabeled_")

# write.table() cannot create missing directories, so make sure
# relabel_result/<modelName>/ exists before writing.
dir.create(dirname(fileStr), recursive = TRUE, showWarnings = FALSE)

relabeledOutputs <- list(
  "weight.txt" = relabeledWeight,
  "mu.txt" = relabeledMu,
  "sigma2.txt" = relabeledSigma2,
  "z.txt" = relabeledZ
)
for (suffix in names(relabeledOutputs)) {
  fileName <- paste0(fileStr, suffix)
  write.table(relabeledOutputs[[suffix]], fileName,
              col.names = FALSE, row.names = FALSE, quote = FALSE, sep = "\t")
}



