from cctbx.array_family import flex
from cctbx import adptbx
from cctbx import crystal
from cctbx import xray
from cctbx import maptbx

def get_xray_structure(atom, q, b, site_frac=(0.5, 0.5, 0.5)):
  """Build a single-scatterer structure in a 10x10x10, 90-degree P1 cell.

  Parameters:
    atom: scattering type of the lone scatterer (e.g. "H", "C", "S").
    q: occupancy of the scatterer.
    b: isotropic B; stored on the scatterer as U via adptbx.b_as_u.
    site_frac: position of the scatterer, passed as the scatterer's site
      (default is the cell centre (0.5, 0.5, 0.5)).

  Returns:
    xray.structure whose scattering-type registry uses the "n_gaussian" table.
  """
  symmetry = crystal.symmetry((10, 10, 10, 90, 90, 90), "P 1")
  settings = crystal.special_position_settings(symmetry)
  u_iso = adptbx.b_as_u(b)
  atom_scatterer = xray.scatterer(
    scattering_type=atom,
    site=site_frac,
    u=u_iso,
    occupancy=q)
  atom_list = flex.xray_scatterer()
  atom_list.append(atom_scatterer)
  structure = xray.structure(settings, atom_list)
  # Select the form-factor table before anyone queries the registry.
  structure.scattering_type_registry(table="n_gaussian")
  return structure
  
def get_exact_ed(xray_structure, r_max=5.0, r_step=0.01):
  """Sample the analytic electron density of a one-atom structure.

  The density comes from the Gaussian returned by the structure's
  scattering-type registry for the scatterer's type. It is evaluated with
  electron_density(r, B) at radii 0, r_step, 2*r_step, ... strictly below
  r_max, and multiplied by the scatterer's occupancy.

  Parameters:
    xray_structure: structure holding exactly one scatterer.
    r_max: exclusive upper bound of the sampled radii (default 5.0).
    r_step: spacing between consecutive radii (default 0.01); must be > 0.

  Returns:
    flex.double of occupancy-weighted density values, one per radius.

  Raises:
    ValueError: if r_step is not positive (sampling would never terminate).
  """
  assert xray_structure.scatterers().size()==1 # xrs contains one scatterer!
  if r_step <= 0:
    raise ValueError("r_step must be positive, got %r" % (r_step,))
  sc = xray_structure.scatterers()[0]
  ed = xray_structure.scattering_type_registry().gaussian(sc.scattering_type)
  # B is fixed for the scatterer, so convert U -> B once, not per radius.
  b_iso = adptbx.u_as_b(sc.u_iso)
  rho_1d_calc = flex.double()
  # Derive each radius as i*r_step instead of accumulating r += r_step:
  # repeated float addition drifts, which moves the sample positions and can
  # change how many points end up below r_max.
  i = 0
  while i * r_step < r_max:
    r = i * r_step
    rho_1d_calc.append(sc.occupancy * ed.electron_density(r, b_iso))
    i += 1
  return rho_1d_calc
  
def get_synthesis(xray_structure, d_min):
  """Compute a resolution-limited Fcalc map and extract the profile at the atom.

  Structure factors are calculated to d_min, and an FFT map is built on a
  0.2 grid step with volume scaling applied. maptbx.map_peak_3d_as_2d is then
  sampled around the single atom's Cartesian site (radius 3.5, step 0.1,
  30-degree angular sampling).

  Returns:
    The second element of map_peak_3d_as_2d's result. It is used by callers
    as the 1D density profile; presumably the first element holds the
    matching distances (NOTE(review): confirm against maptbx).
  """
  f_calc = xray_structure.structure_factors(d_min=d_min).f_calc()
  synthesis = f_calc.fft_map(grid_step=0.2)
  synthesis.apply_volume_scaling()
  density = synthesis.real_map_unpadded()
  centers = xray_structure.sites_cart()
  assert centers.size() == 1 # one atom!
  _distances, profile = maptbx.map_peak_3d_as_2d(
    map_data=density,
    unit_cell=xray_structure.unit_cell(),
    center_cart=centers[0],
    radius=3.5,
    step=0.1,
    s_angle_sampling_step=30,
    t_angle_sampling_step=30)
  return profile
  
def get_map(xray_structure, d_min):
  """Return a 1D density profile for a one-atom structure.

  When d_min is None the exact analytic density (get_exact_ed) is returned;
  otherwise a Fourier synthesis truncated at d_min (get_synthesis) is used.
  """
  if d_min is not None:
    return get_synthesis(xray_structure=xray_structure, d_min=d_min)
  return get_exact_ed(xray_structure=xray_structure)
  
def ls_target(x, y):
  """Return the least-squares target: sum of squared element-wise differences."""
  residual = x - y
  squared = residual * residual
  return flex.sum(squared)
  
def r_factor(x, y):
  """Return the scaled R-factor between x and y.

  y is first scaled by the least-squares factor k = sum(x*y) / sum(y*y).
  The result is sum(|x - k*y|) / sum(|x|).
  """
  k = flex.sum(x * y) / flex.sum(y * y)
  numerator = flex.sum(flex.abs(x - k * y))
  return numerator / flex.sum(flex.abs(x))
  
def find_optimal_b(xrs, rho_ref, d_min, b_start=1.0, b_end=100.0, b_step=1.0):
  """Grid-search the isotropic B that best reproduces a reference map.

  Trial values are B = b_start, b_start + b_step, ... strictly below b_end.
  The trials are applied to xrs.deep_copy_scatterers(), so the caller's
  scatterers are left untouched. For each trial, the map from get_map(d_min)
  is compared with rho_ref using ls_target. The B with the smallest target
  wins; on ties, the first such B wins.

  Parameters:
    xrs: one-atom structure whose B is varied.
    rho_ref: reference profile; must match get_map's output for d_min.
    d_min: None for the exact density, else the synthesis resolution.
    b_start, b_end, b_step: B search grid (defaults reproduce 1, 2, ..., 99).

  Returns:
    The best trial B.

  Raises:
    ValueError: if b_step is not positive or the B range is empty.
  """
  if b_step <= 0:
    raise ValueError("b_step must be positive, got %r" % (b_step,))
  if b_start >= b_end:
    raise ValueError("empty B range: b_start=%r >= b_end=%r" % (b_start, b_end))
  xrs_ = xrs.deep_copy_scatterers()
  b_best = None
  # No arbitrary ceiling (was 1.e+9): the first trial seeds the incumbent,
  # so a large target can never leave b_best unset.
  ls_best = None
  i = 0
  b = b_start
  while b < b_end:
    xrs_ = xrs_.set_b_iso(value=b)
    rho_ = get_map(xray_structure=xrs_, d_min=d_min)
    ls_ = ls_target(x=rho_ref, y=rho_)
    if ls_best is None or ls_ < ls_best:
      ls_best = ls_
      b_best = b
    i += 1
    # Index-based grid avoids drift from repeated float addition.
    b = b_start + i * b_step
  return b_best

def _format_values(values):
  """Render numbers as space-separated '%5.2f' fields for the report rows."""
  return " ".join(["%5.2f" % v for v in values])

def run():
  """Report how well occupancy and B can substitute for one another.

  For each resolution (None = exact density, 2.0 = synthesis), each atom type
  and each reference B, a full-occupancy reference profile is computed. Then,
  for each trial occupancy q, the best-fitting B is found. The report lists:
  the trial q values, the fitted B values, CC and R(%) between the reference
  and fitted profiles, and CC(q, B_opt).

  Every print takes a single pre-formatted string. The output is therefore
  identical under Python 2 (print statement) and Python 3 (print function).
  """
  trial_occupancies = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0]
  for d_min in [None, 2.0]:
    print("Resolution: %s %s" % (d_min, "-"*70))
    for atom in ["H", "C", "S"]:
      print("  atom: %s" % atom)
      for b in [10, 30, 50, 80]:
        xrs = get_xray_structure(atom=atom, q=1, b=b)
        print("    B: %2.0f" % b)
        rho_ref = get_map(xray_structure=xrs, d_min=d_min)
        qs = flex.double()
        bs = flex.double()
        cc = flex.double()
        r  = flex.double()
        for q_ in trial_occupancies:
          qs.append(q_)
          xrs_ = get_xray_structure(atom=atom, q=q_, b=b)
          b_opt = find_optimal_b(xrs=xrs_, rho_ref=rho_ref, d_min=d_min)
          xrs_opt = get_xray_structure(atom=atom, q=q_, b=b_opt)
          rho_opt = get_map(xray_structure=xrs_opt, d_min=d_min)
          cc.append(flex.linear_correlation(x=rho_ref, y=rho_opt).coefficient())
          r.append(r_factor(x=rho_ref, y=rho_opt))
          bs.append(b_opt)
        r = r*100.
        print("      trial q            : " + _format_values(qs))
        print("      B_opt              : " + _format_values(bs))
        print("      CC(rho_ref,rho_opt): " + _format_values(cc))
        print("      R(%)               : " + _format_values(r))
        print("      CC(q,B): %6.2f" % flex.linear_correlation(
          x=qs, y=bs).coefficient())
  
# Script entry point: run the occupancy/B correlation report.
if __name__ == "__main__":
  run()
