import sys
from scitbx.math import dihedral_angle
import iotbx.pdb
import os.path

def basics():
  print "Basics:"
  angle = dihedral_angle(sites=[(1,0,0), (0,0,0), (0,1,0), (1,1,0)], deg=True)
  print "Angle:", angle
  angle = dihedral_angle(sites=[(2,1,1), (0,0,0), (0,1,0), (1,1,0)], deg=True)
  print "Angle in degrees:", angle
  angle = dihedral_angle(sites=[(2,1,1), (0,0,0), (0,1,0), (1,1,0)], deg=False)
  print "Angle in radians:", angle
  print "="*50

def advanced(args):
  print "Advanced phi-psi:"
  if len(args) != 1 or not os.path.isfile(args[0]):
    print "Bad argument"
    return

  # phi-psi angles
  from mmtbx.conformation_dependent_library import generate_protein_threes
  pdb_h = iotbx.pdb.hierarchy.input(file_name=args[0]).hierarchy
  sites_cart = pdb_h.atoms().extract_xyz()
  n_proxies = 0
  # Note that this is a generator
  for three in generate_protein_threes(
          hierarchy=pdb_h,
          geometry=None):
    proxies = three.get_dummy_dihedral_proxies()
    for p in proxies:
      for i_seq in p.i_seqs:
        print pdb_h.atoms()[i_seq].id_str(),
      print "angle=", dihedral_angle(
          sites=[sites_cart[x] for x in p.i_seqs], 
          deg=True)
      n_proxies += 1
  print "Total angles:", n_proxies


def all_dihedral_proxies(args):
  print "Advanced all:"
  if len(args) != 1 or not os.path.isfile(args[0]):
    print "Bad argument"
    return
  # all angles:
  from mmtbx.geometry_restraints.torsion_restraints.utils import \
      get_complete_dihedral_proxies

  pdb_h = iotbx.pdb.hierarchy.input(file_name=args[0]).hierarchy
  sites_cart = pdb_h.atoms().extract_xyz()
  dih_proxies = get_complete_dihedral_proxies(
      pdb_hierarchy=pdb_h)
  for p in dih_proxies:
    for i_seq in p.i_seqs:
      print pdb_h.atoms()[i_seq].id_str(),
    print "angle=", dihedral_angle(
        sites=[sites_cart[x] for x in p.i_seqs], 
        deg=True)
  print "Total angles:", dih_proxies.size()

def run(args):
  """Run the coordinate-only demo, then both file-based demos on args."""
  basics()
  for demo in (advanced, all_dihedral_proxies):
    demo(args)

if __name__ == "__main__":
  # Everything after the script name is handed to the demos.
  run(args=sys.argv[1:])