from __future__ import division
from libtbx import easy_run
import time
from libtbx.test_utils import approx_equal
import iotbx.pdb
from iotbx import reflection_file_reader
from cctbx import miller
from cctbx import maptbx
from scitbx.array_family import flex
import mmtbx.maps.correlation

def run_polder(pdb_code, selection):
  """
  Run phenix.polder on <pdb_code>.pdb and <pdb_code>.mtz.

  The polder mask uses sphere_radius=5 around the atoms in `selection`.
  Output files are prefixed with `pdb_code`, and stdout is redirected to
  <pdb_code>_polder.log. The command line is echoed before it is executed.
  """
  model_file = "%s.pdb" % pdb_code
  data_file = "%s.mtz" % pdb_code
  log_redirect = "> %s_polder.log" % pdb_code
  # NOTE: the trailing spaces inside the two quoted arguments are kept
  # exactly as in the historical command line.
  command = " ".join((
    "phenix.polder",
    model_file,
    data_file,
    "sphere_radius=5",
    'output_file_name_prefix="%s" ' % pdb_code,
    'solvent_exclusion_mask_selection="%s" ' % selection,
    log_redirect,
  ))
  print(command)
  easy_run.call(command)

def get_map(cg, mc):
  """
  Synthesize a sigma-scaled real-space map from Fourier coefficients.

  cg : crystal gridding to sample the map on.
  mc : miller array of Fourier coefficients.
  Returns the unpadded real map after apply_sigma_scaling().
  """
  synthesized = miller.fft_map(crystal_gridding=cg, fourier_coefficients=mc)
  synthesized.apply_sigma_scaling()
  return synthesized.real_map_unpadded()

def get_map_stats(map, sites_frac):
  """
  Return a flex.double of map values interpolated at each fractional site.

  Despite the name, this returns the raw per-site values (one per entry in
  `sites_frac`, in order); callers derive statistics from them.
  """
  interpolated = [map.eight_point_interpolation(site) for site in sites_frac]
  return flex.double(interpolated)

# atom radii are inspired by
# modules/cctbx_project/mmtbx/real_space_correlation.py
def set_atom_radius(d_min):
  """
  Return the sphere radius used around each atom for a given resolution.

  d_min < 1.0       -> 1.0
  1.0 <= d_min < 2.0 -> 1.5
  2.0 <= d_min < 4.0 -> 2.0
  otherwise          -> 2.5
  """
  # Walk the upper bounds in increasing order; the first one d_min falls
  # below determines the radius. Anything not below 4.0 gets the largest.
  for upper_bound, radius in ((1.0, 1.0), (2.0, 1.5), (4.0, 2.0)):
    if d_min < upper_bound:
      return radius
  return 2.5

def _read_polder_and_omit_coeffs(file_name):
  """
  Read polder and omit map coefficients from a phenix.polder MTZ file.

  Arrays are identified by their exact label strings. Returns a
  (mc_polder, mc_omit) pair of deep copies. Raises RuntimeError if either
  array is missing.
  """
  mc_polder, mc_omit = None, None
  miller_arrays = reflection_file_reader.any_reflection_file(
    file_name = file_name).as_miller_arrays()
  for ma in miller_arrays:
    lbl = ma.info().label_string()
    if(lbl == "mFo-DFc_polder,PHImFo-DFc_polder"):
      mc_polder = ma.deep_copy()
    elif(lbl == "mFo-DFc_omit,PHImFo-DFc_omit"):
      mc_omit = ma.deep_copy()
  if(mc_polder is None or mc_omit is None):
    raise RuntimeError(
      "%s: polder and/or omit map coefficients not found" % file_name)
  return mc_polder, mc_omit

def _cc_in_selection(map_obs, map_calc, grid_sel):
  """Return the linear correlation of map_obs and map_calc at grid_sel."""
  return flex.linear_correlation(
    x = map_obs.select(grid_sel),
    y = map_calc.select(grid_sel)).coefficient()

def compute_cc_and_map_values(pdb_code, selection, run_polder_first=False):
  """
  Report ligand map values and map-model CCs for the polder and omit maps.

  Reads <pdb_code>_polder_map_coeffs.mtz and <pdb_code>.pdb, selects the
  atoms matching `selection`, and prints:
    - min/max/mean of the sigma-scaled polder and omit maps interpolated
      at the selected atom positions;
    - the correlation of each map with a map calculated from the selected
      atoms only, over grid points within set_atom_radius(d_min) of them.

  Parameters
  ----------
  pdb_code : str
    Prefix of the input files.
  selection : str
    Atom selection string identifying the ligand.
  run_polder_first : bool, optional
    If True, run phenix.polder (via run_polder) to regenerate the map
    coefficients first. The default, False, uses existing files.

  Returns
  -------
  tuple of (cc_polder, cc_omit)

  Raises
  ------
  RuntimeError
    If the map coefficients are missing or `selection` matches no atoms.
  """
  print('*'*79)
  if(run_polder_first):
    run_polder(pdb_code = pdb_code, selection = selection)
  mc_polder, mc_omit = _read_polder_and_omit_coeffs(
    file_name = pdb_code + '_polder_map_coeffs.mtz')
  d_min = mc_polder.d_min()
  cs = mc_polder.crystal_symmetry()
  cg = maptbx.crystal_gridding(
    unit_cell         = mc_polder.unit_cell(),
    d_min             = d_min,
    resolution_factor = 0.25,
    space_group_info  = mc_polder.space_group_info())
  map_polder = get_map(cg = cg, mc = mc_polder)
  map_omit   = get_map(cg = cg, mc = mc_omit)
  # Model and ligand selection.
  pdb_hierarchy = iotbx.pdb.input(
    file_name = pdb_code + '.pdb').construct_hierarchy()
  xrs = pdb_hierarchy.extract_xray_structure(crystal_symmetry = cs)
  sel = pdb_hierarchy.atom_selection_cache().selection(string = selection)
  xrs_ligand = xrs.select(sel)
  sites_cart_lig = xrs_ligand.sites_cart()
  # Fail clearly here instead of inside min_max_mean / correlation later.
  if(sites_cart_lig.size() == 0):
    raise RuntimeError(
      "%s: selection '%s' matches no atoms" % (pdb_code, selection))
  sites_frac_lig = xrs_ligand.sites_frac()
  # Map values at the ligand atom positions.
  mmm_polder = get_map_stats(
    map = map_polder, sites_frac = sites_frac_lig).min_max_mean().as_tuple()
  mmm_omit = get_map_stats(
    map = map_omit, sites_frac = sites_frac_lig).min_max_mean().as_tuple()
  print("PDB code %s:" % pdb_code)
  print("Map min/max/mean")
  print("Polder map : %7.3f %7.3f %7.3f" % mmm_polder)
  print("Omit       : %7.3f %7.3f %7.3f" % mmm_omit)
  # Calculated map from the selected atoms only, on the polder reflections.
  f_calc = mc_polder.structure_factors_from_scatterers(
    xray_structure = xrs_ligand).f_calc()
  f_calc = mc_polder.common_sets(f_calc)[1]
  map_calc = get_map(cg = cg, mc = f_calc)
  atom_radius = set_atom_radius(d_min)
  print('resolution: %7.2f' % d_min)
  print('Atom radius: %s' % atom_radius)
  # All three maps come from the same gridding cg, so a single set of grid
  # indices around the ligand is valid for each of them.
  assert map_polder.all() == map_omit.all() == map_calc.all()
  ligand_sel = maptbx.grid_indices_around_sites(
    unit_cell  = mc_polder.unit_cell(),
    fft_n_real = map_polder.focus(),
    fft_m_real = map_polder.all(),
    sites_cart = sites_cart_lig,
    site_radii = flex.double(sites_cart_lig.size(), atom_radius))
  cc_polder = _cc_in_selection(map_polder, map_calc, ligand_sel)
  cc_omit   = _cc_in_selection(map_omit,   map_calc, ligand_sel)
  print("CC polder map ligand: %5.3f" % cc_polder)
  print("CC omit map ligand %5.3f" % cc_omit)
  return cc_polder, cc_omit


def exercise():
  """Report polder/omit map statistics for each reference ligand case."""
  # (pdb_code, ligand selection) pairs, processed in this order.
  cases = (
    ('1aba', 'chain A and resseq 88 and not element H'),
    ('1c2k', 'chain A and resseq 246 and not element H'),
    ('4opi', 'chain A and resseq 502 and not element H'),
    ('1f8t', 'chain H and resseq 105 and not element H'),
  )
  for code, ligand_selection in cases:
    compute_cc_and_map_values(
      pdb_code  = code,
      selection = ligand_selection)

if __name__ == "__main__":
  # Run every case and report the total wall-clock time.
  start_time = time.time()
  exercise()
  elapsed = time.time() - start_time
  print('*'*79)
  print("OK. Time: %8.3f" % elapsed)
