#!/usr/bin/python
# compares sequenced data to reference (obscure & obsolete)
import sys,Bio.Seq,Bio.SeqIO,os,progressbar
import dbcnf
# Set to a true value to print per-exon details instead of showing a progress bar.
verbose = 0

# Database connection supplied by the local dbcnf module, plus the single
# cursor that every query in this script goes through.
db = dbcnf.db
dbc = db.cursor()

# Command-line arguments:
#   1: FASTA file with the sequenced data
#   2: FASTA file with the reference sequences
#   3: tab-separated coverage file (gene name, comma-separated coverage list)
fn_pat = sys.argv[1]
fn_ref = sys.argv[2]
fn_cov = sys.argv[3]

# Tables consulted, in this order, when looking up a known variation.
snp_list = [
    'deep.ceu_sites',
    'deep.yri_sites',
    'deep.jptchb_sites',
    'deep.ensvar54',
]

# Name of the result table: 'mut_' + input file name without directory and extension.
tn = 'mut_%s' % os.path.splitext(os.path.basename(fn_pat))[0]

default_widgets = [
    progressbar.Percentage(),
    ' ',
    progressbar.Bar(),
    progressbar.ETA(),
]

def table_exists(t):
    """Count the relations named *t* in the PostgreSQL catalog.

    The name is compared case-insensitively but literally. The previous
    ILIKE match treated '_' and '%' inside *t* as wildcards, so a name such
    as 'mut_a' also matched unrelated relations like 'mutxa'.

    The lookup goes against pg_class without any schema or relation-kind
    filter, so a relation of that name in any schema counts.

    Returns the cursor's row count, which is truthy when a match was found.
    """
    dbc.execute('SELECT relname FROM pg_class WHERE lower(relname) = lower(%s)', (t,))
    return dbc.rowcount

# Start from an empty result table. DROP TABLE IF EXISTS looks at exactly
# deep.<tn>, so it neither fails when the table is missing nor depends on a
# catalog lookup that could be satisfied by a relation in another schema.
# NOTE: tn is derived from the input file name and is interpolated as an
# SQL identifier; file names containing characters that are not valid in an
# unquoted identifier will make these statements fail.
dbc.execute('DROP TABLE IF EXISTS deep.%s' % tn)
dbc.execute("""CREATE TABLE deep.%s (id SERIAL PRIMARY KEY, gene TEXT, transcript TEXT, exon TEXT,
strand INTEGER, phase INTEGER, nt_seq INTEGER, nt_ref INTEGER, cov FLOAT, nt_mut INTEGER,
nt_var INTEGER, aa_mut INTEGER, aa_var INTEGER, nt1 TEXT, nt2 TEXT, 
aa1 TEXT, aa2 TEXT, covlist TEXT, dna_mutlist TEXT, dna_freq TEXT, dna_coh TEXT,
aa_mutlist TEXT, aa_freq TEXT, aa_coh TEXT, dna_coding TEXT)""" % tn)

def alignment(ref,seq,dust=False):
    """Align *ref* and *seq* pairwise with the external program muscle.

    Both sequences are written to a two-record FASTA file (ids 'ref' and
    'seq'), optionally passed through mdust (-v 28) first when *dust* is
    true, and then aligned with muscle.

    All intermediate files live in a private temporary directory that is
    removed again even when one of the tools fails. The aligned records are
    picked by their id, not by their position in muscle's output.

    Returns a tuple (aligned reference sequence, aligned sample sequence).
    Raises ValueError when the muscle output lacks one of the two records.
    """
    import os
    import shutil
    import subprocess
    import tempfile

    import Bio.SeqIO

    workdir = tempfile.mkdtemp(prefix='seqaln')
    try:
        fn_in = os.path.join(workdir, 'seq.fasta')
        fn_aln = os.path.join(workdir, 'aln.txt')
        with open(fn_in, 'w') as handle:
            handle.write('>ref\n%s\n>seq\n%s\n' % (str(ref), str(seq)))

        if dust:
            # mdust writes its result to stdout; capture it in a second file.
            fn_dust = os.path.join(workdir, 'seq_dust.fasta')
            with open(fn_dust, 'w') as dusted:
                subprocess.call(['mdust', fn_in, '-v', '28'], stdout=dusted)
            fn_in = fn_dust

        with open(os.devnull, 'w') as devnull:
            subprocess.call(['muscle', '-in', fn_in, '-out', fn_aln, '-quiet'],
                            stdout=devnull, stderr=devnull)

        with open(fn_aln) as handle:
            records = dict((rec.id, rec) for rec in Bio.SeqIO.parse(handle, 'fasta'))
    finally:
        shutil.rmtree(workdir, ignore_errors=True)

    if 'ref' not in records or 'seq' not in records:
        raise ValueError('muscle output does not contain both the ref and the seq record')
    return records['ref'].seq, records['seq'].seq

def exrev(exon_cov):
    """Return the coverage entries of an exon rewritten for the opposite strand.

    Every entry is split on '/'. In each non-empty piece the first character
    is replaced by its reverse complement and the rest of the piece is kept;
    empty pieces stay empty. The resulting entries are returned in reverse
    order as a new list.
    """
    import Bio.Seq

    def flip_piece(piece):
        # Nothing to complement in an empty piece.
        if not piece:
            return ''
        head = str(Bio.Seq.Seq(piece[0]).reverse_complement())
        return head + piece[1:]

    flipped = ['/'.join([flip_piece(piece) for piece in entry.split('/')])
               for entry in exon_cov]
    return flipped[::-1]

def match_seq(ref,seq):
    """Return (ref, seq) in a form that can be compared position by position.

    When both sequences have the same string form they are handed back
    untouched; otherwise the pair produced by alignment() is returned.
    """
    identical = str(ref) == str(seq)
    if not identical:
        return alignment(ref, seq)
    return ref, seq

# Lookup statistics maintained by check_poly():
#   glc1 / glc2 - lookups that returned rows / returned no rows
#   glc3 / glc4 - returned rows whose reference base matched / did not match
glc1 = glc2 = glc3 = glc4 = 0

def check_poly(snp,chro,pos,base1,base2):
    """Look up the change base1 -> base2 at chro:pos in the table *snp*.

    Returns (True, frequency) for the first row whose reference base equals
    base1, whose frequency is truthy and whose mutated base equals base2;
    otherwise (False, False).

    Updates the module-level counters glc1..glc4 as a side effect.
    """
    global glc1,glc2,glc3,glc4
    query = 'SELECT reference,mutated,frequency FROM %s WHERE chromo=%%s AND position=%%s' % snp
    dbc.execute(query, (chro, pos))
    if dbc.rowcount:
        glc1 += 1
    else:
        glc2 += 1
    for ref_base, mut_base, freq in dbc.fetchall():
        if ref_base == base1:
            glc3 += 1
            # in ensvar54, some frequencies are zero or NULL; such rows never count as a hit
            if freq and mut_base == base2:
                return True, freq
        else:
            glc4 += 1
    return False, False

def all_poly(chro,startpos,base1,base2):
    """Search every table in snp_list for the change base1 -> base2.

    Returns (True, frequency, label) for the first table that reports a hit,
    where label is the table name without schema, cut at the first '_'.
    Returns (False, 'new', '-') when no table knows the change.
    """
    for table in snp_list:
        hit, freq = check_poly(table, chro, startpos, base1, base2)
        if not hit:
            continue
        label = table.split('.')[1].split('_')[0]
        return True, freq, label
    return False, 'new', '-'

def analyze_transcript(gene,tid,ref,si,cov_list):
 """Compare all exons of one transcript with the reference and store the results.

 The exons of transcript `tid` are read from deep.exon_pos4. Each exon that
 has no row in the result table yet is cut out of `ref` and `si`, aligned if
 the two pieces differ, and walked position by position. Differing bases are
 looked up with all_poly(); coding positions are additionally collected into
 codons and translated. One row per analysed exon is inserted into the
 result table.

 Arguments:
  gene     -- gene name, stored in the result row
  tid      -- transcript id used to select the exons
  ref      -- reference sequence, sliced with rel_start/rel_end
  si       -- sequenced sequence, sliced the same way
  cov_list -- list of coverage entries, sliced the same way

 Returns the amino-acid mismatches that were not flagged as known variation,
 summed over the exons stored by this call.
 """
 dbc.execute('SELECT strand FROM deep.exon_pos4 WHERE transcript_id=%s',(tid,))
 # NOTE(review): fetchone() returns None when the transcript has no rows, which
 # makes the unpacking below fail.
 strand,=dbc.fetchone()
 # exons are visited in ascending rel_start order for strand 1, descending otherwise
 if strand==1: direction='ASC'
 else: direction='DESC'
 dbc.execute("""SELECT exon_id, chr, rel_start, rel_end, abs_start, phase, coding_start, 
 coding_end FROM deep.exon_pos4 WHERE transcript_id=%%s ORDER BY rel_start %s""" % direction,(tid,))
 # dna_pos and aa_pos run across all exons of the transcript; they are not reset per exon
 tot_aamis,dna_pos,aa_pos=0,0,0
 for eid,chrom,start,end,abs_start,phase,cstart,cend in dbc.fetchall():
  dbc.execute('SELECT COUNT(*) FROM deep.%s WHERE exon=%%s' % tn,(eid,))
  if dbc.fetchone()[0]: continue # no need to analyze an exon more than once
  # per-exon counters: aa/nt mismatches (mis) and aa/nt known variations (var)
  aamis,ntmis,aavar,ntvar=0,0,0,0
  dna_freq,aa_freq,dna_coh,dna_coding,aa_coh=[],[],[],[],[]
  # exons with a negative phase are skipped entirely
  if int(phase)<0: continue
  # rel_end is inclusive, hence the +1 in the slices
  exon_ref = Bio.Seq.Seq(str(ref[start:end+1]))
  exon_seq = Bio.Seq.Seq(str(si[start:end+1]))
  exon_cov = cov_list[start:end+1]
  if str(exon_seq).count('N')==len(str(exon_seq)):
   # NOTE(review): the slices above cover end-start+1 bases, this adds end-start;
   # confirm the difference is intended.
   dna_pos+=end-start # also count unsequenced exon
   continue
  # bring everything into transcript orientation
  if int(strand)==-1:
   exon_ref=exon_ref.reverse_complement()
   exon_seq=exon_seq.reverse_complement()
   exon_cov=exrev(exon_cov)
  # for a positive phase, drop the first 3-phase bases of the exon
  if int(phase)>0:
   pp=3-int(phase)
   exon_ref=exon_ref[pp:]
   exon_seq=exon_seq[pp:]
   exon_cov=exon_cov[pp:]
  dna_ref,dna_seq=match_seq(exon_ref,exon_seq)
  # match coverage to alignment
  # gap columns in the aligned reference get an empty coverage entry
  n,dna_cov=0,[] 
  for nt in str(dna_ref):
   if nt!='-' and n<len(exon_cov):
    dna_cov.append(exon_cov[n])
    n+=1
   else: dna_cov.append('')
  dna_mutlist,aa_mutlist=[],[]
  # trip1/trip2 collect the current codon of reference/sample; nt_poly keeps the
  # result of the latest all_poly() call until the next codon is completed
  trip1,trip2,nt_poly='','',False
  aas_ref,aas_seq='',''
  # analyse DNA sequence for mutations
  for ii in range(len(dna_ref)):
   # alignment columns with a gap on either side are ignored (indels are not reported)
   if dna_ref[ii]=='-' or dna_seq[ii]=='-': continue
   dna_pos+=1
   in_coding=dna_pos>=cstart and dna_pos<=cend
   if dna_ref[ii]!=dna_seq[ii] and dna_seq[ii]!='N' and dna_ref[ii]!='N': # DNA mutation
    if strand==1: 
     pos=abs_start+ii # position on chromosome
     # add back the bases trimmed above for the phase (3-phase)
     if phase==2: pos+=1
     if phase==1: pos+=2
     base1,base2=dna_ref[ii],dna_seq[ii]
    else: 
     # NOTE(review): ii and len(dna_ref) are alignment coordinates and include gap
     # columns; confirm the position is still right when the alignment has gaps.
     pos=abs_start+len(dna_ref)-ii-1  # position on chromosome
     # the lookup tables are queried with the complemented bases for this strand
     base1=str(Bio.Seq.Seq(dna_ref[ii]).reverse_complement())
     base2=str(Bio.Seq.Seq(dna_seq[ii]).reverse_complement())
    nt_poly,freq,cohort=all_poly(chrom,pos,base1,base2)
    if nt_poly: ntvar+=1 # DNA variation
    else: ntmis+=1 # DNA mutation
    dna_mutlist.append('%s%d%s' % (dna_ref[ii],dna_pos,dna_seq[ii]))
    dna_freq.append(str(freq))
    dna_coh.append(cohort)
    dna_coding.append(in_coding) # store if it was coding or UTR
   if in_coding: # only translate coding sequence
    trip1,trip2=trip1+dna_ref[ii],trip2+dna_seq[ii]
    if len(trip1)==3:
     aa1,aa2=str(Bio.Seq.Seq(trip1).translate()),str(Bio.Seq.Seq(trip2).translate())
     aas_ref,aas_seq=aas_ref+aa1,aas_seq+aa2
     aa_pos+=1
     if aa1!=aa2 and aa1!='X' and aa2!='X': # AA mutation
      if nt_poly: aavar+=1 # AA variation
      else: aamis+=1 # AA missmatch mutation
      aa_mutlist.append('%s%d%s' % (aa1,aa_pos,aa2))
      # the amino-acid change inherits frequency and cohort of the latest DNA change
      # NOTE(review): this raises IndexError if no DNA change was recorded in this exon
      aa_freq.append(dna_freq[-1])
      aa_coh.append(dna_coh[-1])
     trip1,trip2,nt_poly='','',False
     # a stop codon on either side ends the walk through this exon only
     if aa1=='*' or aa2=='*': break
  if verbose and (ntmis or ntvar):
   print gene,tid,eid,chrom,abs_start,start,end,strand,phase
   print 'dna_pos = %d, aa_pos = %d' % (dna_pos,aa_pos)
   print 'refDNA (%d): %s' % (len(dna_ref),dna_ref)
   print 'seqDNA (%d): %s' % (len(dna_seq),dna_seq)
   print 'refAA (%d): %s' % (len(aas_ref),aas_ref)
   print 'seqAA (%d): %s' % (len(aas_seq),aas_seq)
   print 'ntmis = %d, aamis = %d, ntvar = %d, aavar = %d' % (ntmis,aamis,ntvar,aavar)
   print 'DNA mutlist: %s' % dna_mutlist
   print 'DNA freq: %s' % dna_freq
   print 'DNA coh: %s' % dna_coh
   print 'AA mutlist: %s' % aa_mutlist
   print 'AA freq: %s' % aa_freq
   print 'AA coh: %s' % aa_coh
   print
  # store results in table
  sql="""INSERT INTO deep.%s (gene, transcript, exon, strand, phase, nt_seq, nt_ref, cov, 
  nt_mut, nt_var, aa_mut, aa_var, nt1, nt2, aa1, aa2, covlist, dna_mutlist, dna_freq, dna_coh,
  aa_mutlist, aa_freq, aa_coh, dna_coding) VALUES  (%%s, %%s, %%s, %%s, %%s, %%s, %%s, %%s, %%s, %%s, %%s, 
  %%s, %%s, %%s, %%s, %%s, %%s, %%s, %%s, %%s, %%s, %%s, %%s, %%s)""" % tn
  # nt_seq: aligned sample length minus N characters (gap characters are included);
  # nt_ref: length of the phase-trimmed reference exon; cov is their ratio
  nt_seq=len(dna_seq)-dna_seq.count('N')
  nt_ref=len(exon_ref)
  if nt_ref: cov=float(nt_seq)/float(nt_ref)
  else: cov=0
  # list-valued results are stored as comma-separated text
  tp=( gene, tid, eid, strand, phase, nt_seq, nt_ref, cov, ntmis, ntvar, aamis, aavar, 
  str(dna_ref), str(dna_seq), str(aas_ref), str(aas_seq), ','.join(dna_cov), ','.join(dna_mutlist),
  ','.join(dna_freq), ','.join(dna_coh), ','.join(aa_mutlist), ','.join(aa_freq), 
  ','.join(aa_coh), ','.join(map(str,dna_coding)) )
  dbc.execute(sql,tp)
  tot_aamis+=aamis
 return tot_aamis

def analyze_gene(gene,ref,si,cov_list):
    """Analyse every transcript of *gene* and return the largest mismatch count.

    The transcripts are read from deep.exon_pos4 and each one is passed to
    analyze_transcript(). Genes on chromosome X or Y are not analysed; the
    function returns 0 as soon as such a row is seen.

    Returns the maximum of the per-transcript results, or 0 when the gene
    has no rows in deep.exon_pos4 (previously max() of an empty list raised
    ValueError in that case).
    """
    dbc.execute('SELECT transcript_id,chr FROM deep.exon_pos4 WHERE gene=%s',(gene,))
    worst = 0
    seen = set()
    for tid, chrom in dbc.fetchall():
        if chrom == 'X' or chrom == 'Y':
            return 0
        # The query yields one row per row of exon_pos4, so the same
        # transcript id can show up many times. A repeated call would find
        # every exon already stored (or skipped again) and return 0, so it is
        # left out without changing the result.
        if tid in seen:
            continue
        seen.add(tid)
        worst = max(worst, analyze_transcript(gene, tid, ref, si, cov_list))
    return worst

def gene_coverage(readfile,gene):
    """Return the coverage entries stored for *gene* in the file *readfile*.

    Each line of the file is expected to hold a gene name and a
    comma-separated coverage list, separated by a tab. The first line whose
    gene name equals *gene* wins.

    Returns the coverage list split on ',', an empty list when the matching
    line carries no coverage column, or None when the gene is not listed.
    The file is closed before returning.
    """
    with open(readfile) as handle:
        for line in handle:
            fields = line.strip().split('\t')
            if fields[0] != gene:
                continue
            if len(fields) < 2:
                # gene is listed but has no coverage data
                return []
            return fields[1].split(',')
    return None

ref_seqs=Bio.SeqIO.parse(open(fn_ref, "rU"), "fasta")
refd={}
for s in ref_seqs: refd[s.id]=s.seq
total=0
for i in open(fn_pat,"rU"):
 if i[0]=='>': total+=1
print 'Total Genes:',total

pat_seqs=Bio.SeqIO.parse(open(fn_pat, "rU"), "fasta")
aatotal=0
if not verbose: pbar = progressbar.ProgressBar(widgets=['Matching to reference: '] + default_widgets).start()
for n,s in enumerate(pat_seqs):
 if not verbose: pbar.update(float(n)/float(total)*100)
 if s.id=='TTN': continue # titin is too large to wait for
 ref=refd[s.id]
 si=s.seq
 cov_list=gene_coverage(fn_cov,s.id)
 aamis=analyze_gene(s.id,ref,si,cov_list)
 aatotal+=aamis
 if verbose:
  print 'found = %d, not found = %d' % (glc1,glc2)
  print 'reference match = %d, reference missmatch = %d' % (glc3,glc4)
if not verbose: pbar.finish()
print 'found = %d, not found = %d' % (glc1,glc2)
print 'reference match = %d, reference missmatch = %d' % (glc3,glc4)
print 'aatotal:',aatotal
db.commit()