"""
021508 TCT.
Really Use bayesian estimation of CC from several scoring approaches.
For each scoring method, get P(score|cc_perfect). We can get this in
bins of 5%, but smoothing by collecting with grid of 10% or 15%
Then combine with bayes rule to get P(cc|scores)
"""

from phenix.autosol.read_scoring_hist import get_bin,relative_cc_est,write_histograms,read_histograms,get_standard_hist,hist_object

from scitbx import lbfgs
from cctbx.array_family import flex
import math,sys,os

all_list=  ['CC_DENMOD','RFACTOR','SKEW','NCS_OVERLAP','NCS_COPIES','NCS_IN_GROUP','TRUNCATE','REGIONS','CONTRAST','FOM','FLATNESS','CORR_RMS']

#skip_list=[]
#skip_list=  ['NCS_OVERLAP','NCS_COPIES','NCS_IN_GROUP','TRUNCATE','REGIONS','SD']

file1=sys.argv[1]
file2=sys.argv[2]
print "FILES: ",file1,file2
args=sys.argv[3:]
skip_list=[]
ok_res_range=[0.,100.]
if len(args)>=1:

  # see if there is a resolution range
  low=None
  high=None
  have_res_range=False
  for arg in args:
    print "trying: ",arg,
    aa=None
    try: aa=float(arg)
    except: pass
    print aa
    if aa is not None:
      if low is None or aa < low: low=aa
      if high is None or aa > high: high=aa
      have_res_range=True
      ok_res_range=[low,high]

  if have_res_range and len(args)>=3 or (not have_res_range and len(args)>=1):
    print args
    for i in all_list:
      if not i in args:
        skip_list.append(i)
print "Resolution range: ",ok_res_range
print "SKIP LIST: ",skip_list

def get_data(file,min_cc=0.0001): 
  target_list=flex.double()
  target_list_sqrt=flex.double()
  id_list=[]
  first=True
  # figure out which datasets we are going to keep if min_cc>0.
  if min_cc:
   print "NOTE: MINIMUM CC: ",min_cc
   all_datasets=[]
   good_datasets=[]
   for line in open(file).readlines():
    if not line: continue
    if first:
      first=False
    else:
      spl=line.split()
      dataset=spl[0] 
      if not dataset in all_datasets: all_datasets.append(dataset)
      target=float(spl[2])
      if float(target)>min_cc and dataset not in good_datasets:
        good_datasets.append(dataset)
   bad_datasets=[]
   for dataset in all_datasets:
     if not dataset in good_datasets: bad_datasets.append(dataset)
   print "bad datasets:",len(bad_datasets),bad_datasets
   print "good datasets:",len(good_datasets),good_datasets
  first=True
  for line in open(file).readlines():
    if not line: continue
    if first:
      first=False
      # run solution other perfect_delta 
      variable_list=line.split()[3:]
      vv=[]
      for var in variable_list:
           if not var in vv: vv.append(var)
      variable_list=vv
      print "Variable list: ",variable_list
      variable_value_list=[]
      for var in variable_list:
        variable_value_list.append(flex.double())
    else:
      spl=line.split()
      target=float(spl[2])
      vars=spl[3:3+len(variable_list)]
      if min_cc:
        dataset=spl[0]
        if not dataset in good_datasets:
          continue
      # decide if we are in resolution range..
      res_range=spl[3+len(variable_list):]
      print "res range: ",dataset,res_range
      if len(res_range)==2:
        if float(res_range[0])<ok_res_range[0]  or \
           float(res_range[0])>ok_res_range[1]:
         continue
      
      for var,variable_value in zip(vars,variable_value_list):
       variable_value.append(float(var))
      target_list.append(target)
      id_list.append(dataset)

  # get rid of anything in skip list
  new_variable_value_list=[]
  new_variable_list=[]
  for value_list,variable in zip (variable_value_list,variable_list):
    if not variable in skip_list:
      new_variable_value_list.append(value_list)
      new_variable_list.append(variable)
  variable_value_list=new_variable_value_list
  variable_list=new_variable_list
  print "Variables: ",len(variable_value_list[0]),"Values",len(target_list)
  return id_list,variable_value_list,target_list,variable_list

# Training data (file1) is used to build the histograms; test data (file2)
# is scored against them.
training_data=get_data(file1)
id_list,variable_value_list,target_list,variable_list=training_data
test_data=get_data(file2)
id_list_test,variable_value_list_test,target_list_test,variable_list_test=test_data

# get p(score|cc) for each scoring type
# NOTE Special case: any value that is 0.0 exactly...is to be ignored
n_range_score=30
n_range_target=30
range_low_score=-0.1
range_high_score=1.1
range_low_target=-0.1
range_high_target=0.8


gauss_smooth=True
gauss_d=3.  # smooth with 1/e at 3 points out in any direction
gauss_n=int(gauss_d*2+0.5)  # how many points out to include anything
print "Using Gaussian smoothing with r=",gauss_d," and ",gauss_n," points"

hist_dict={} 
if gauss_smooth:
 offset_list=[]
 for i in xrange(-gauss_n,gauss_n+1):
   offset_list.append(i)
else:
   offset_list=[0]

range_of_obs_dict={}
for var,value_list in zip(variable_list,variable_value_list):
  print "Length of ",var,":",len(value_list)
  two_d_hist=[]
  for i in xrange(n_range_score):
    hist=n_range_target*[0.0]  # minimum in each bin...
    two_d_hist.append(hist)
    # two_d_hist[score][target]
  hist_dict[var]=two_d_hist

  low_obs=None
  high_obs=None
  for value,target in zip(value_list,target_list):
    if low_obs is None or value<low_obs: low_obs=value
    if high_obs is None or value>high_obs: high_obs=value
    i_score_bin,a_score_bin=get_bin(
         value,n_range_score,range_low_score,range_high_score)
    i_target_bin,a_target_bin=get_bin(
         target,n_range_target,range_low_target,range_high_target)
    for i in offset_list:
      i_bin=i_score_bin+i
      if i_bin<0 or i_bin>n_range_score-1: continue
      for j in offset_list:
        j_bin=i_target_bin+j
        if j_bin<0 or j_bin>n_range_target-1: continue
        if i==0 and j==0:
          two_d_hist[i_bin][j_bin]+=1.
        else:
          dist2=(i_bin-i_score_bin)**2+(j_bin-i_target_bin)**2
          weight=math.exp(-1.*dist2/gauss_d)
          two_d_hist[i_bin][j_bin]+=weight

  range_of_obs_dict[var]=[low_obs,high_obs]

  # normalize this histogram so that sum over i_score_bin=1
  for i_target_bin in xrange(n_range_target):
    sum=0.
    for i_score_bin in xrange(n_range_score): 
      sum=sum+two_d_hist[i_score_bin][i_target_bin]
    if sum<=0.: sum=1.
    for i_score_bin in xrange(n_range_score): 
      two_d_hist[i_score_bin][i_target_bin]=\
      two_d_hist[i_score_bin][i_target_bin]/sum
   
  # print out this histogram:

  for i in xrange(n_range_score):
    print "\n",
    for j in xrange(n_range_target):
       print " %3i" %(int(100.*two_d_hist[i][j])),
  
  print "\n"

hist_object=hist_object(variable_list,
      n_range_score,n_range_target,range_low_score,
      range_low_target,range_high_score,range_high_target,hist_dict,
      range_of_obs_dict)

file='scoring_hist.dat'
write_histograms(file,hist_object)

Facts={}
#Facts['solve_lib_path']='/home/terwill/PHENIX/phenix-1.3b-rc4/phenix-1.3b-rc4/solve_resolve/ext_ref_files/'
Facts['solve_lib_path']=''
hist_object=get_standard_hist(Facts)

# check everything:
for var in variable_list:
  two_d_hist=hist_dict[var]
  new_two_d_hist=hist_object.hist_dict[var]
  print "VAR: ",var,len(two_d_hist),len(new_two_d_hist)
  for vec,new_vec in zip(two_d_hist,new_two_d_hist):
    print "VEC: ",len(vec),len(new_vec)
    for item,new_item in zip(vec,new_vec):
      if item <new_item-0.001 or item> new_item+0.001:
         print "NO: ",item, new_item,type(item),type(new_item)
   


# Now go through and use bayes to get cc for each one. Use:
# relative p(cc|{scores})=product_over_scores[p(score|cc)]
est_list=flex.double()
true_list=flex.double()
ff=open('bayes_5_est.dat','w')
sum_sd=0.
sum_sd_n=0.
for i in xrange(len(target_list_test)):
  cc_perfect=target_list_test[i]
  dataset=id_list_test[i]
  values=[]
  for value_list in variable_value_list_test:
    values.append(value_list[i])
 
  prob_list,est,sd=relative_cc_est(variable_list,values,hist_object)
  print '\nVALUES: ',values
  print 'PROB: ',
  for pp in prob_list:
     print " %6.2f " %(pp,) ,
  print '\nEST: ',est,' CC: ',cc_perfect,dataset
  print '\nSD_est: ',sd,' delta_obs:',cc_perfect-est
  sum_sd+=sd
  sum_sd_n+=1.
  print >>ff, dataset,cc_perfect,est,sd
  est_list.append(est)
  true_list.append(cc_perfect)
c=flex.linear_correlation(est_list,true_list)
cc=c.coefficient()
n=float(len(est_list))
msqr=math.sqrt(flex.sum_sq(est_list-true_list)/n)
print  "OVERALL CC: ",cc," RMSDdev: ",msqr
print "mean estimated SD: ",sum_sd/sum_sd_n

