from phylip import parse
from chronogram import draw,reset,savePostscript
from inspect import signature

#--------------------------------------------------------
# Summary of data structures (created later in this script).
# All indexing is ZERO-based (i.e. names[0] is the first leaf)
#
# names[k] is the original name of the k-th leaf of the input
#
# clades[k] is a list of integers that ID the leaves of the clade
#
# trees[k] is a tuple representation of the subtree for clades[k]
#
# active is a set of integer IDs for the clades that are still
# active (not yet merged into a larger clade) during the
# clustering process
#
# dist is a list of dictionaries such that dist[a][b] is the
# distance between clades a and b
#
#--------------------------------------------------------


#--------------------------------------------------------
# Utility functions (to be completed by student)
#--------------------------------------------------------
def find_nearest_active():
    """Return IDs for the two nearest active clusters.

    Scans every unordered pair of clade IDs in the global ``active`` set
    and returns the pair ``(a, b)`` (with ``a < b``) that minimizes
    ``dist[a][b]``.  The IDs are sorted first and ``min`` keeps the first
    minimal pair it sees, so ties are broken in favor of the smallest pair
    and the result does not depend on set iteration order.

    Returns:
        tuple: ``(a, b)``, the IDs of the two closest active clades.

    Raises:
        ValueError: if fewer than two clades are active.
    """
    from itertools import combinations

    ids = sorted(active)
    if len(ids) < 2:
        raise ValueError('need at least two active clusters to merge')
    # combinations() yields pairs in lexicographic order with a < b.
    return min(combinations(ids, 2), key=lambda pair: dist[pair[0]][pair[1]])


def combine_trees(a,b):
    """Return tuple that represents tree for newly merged clusters a and b.

    Leaves are stored as ``(name, (), ())``.  An internal node is stored as
    ``(height, left, right)``, where ``left``/``right`` are the existing
    tuples ``trees[a]`` and ``trees[b]``.  The drawing code at the end of
    the script reads slot 0 of the root as a number (``400/result[0]``),
    so ``height`` must be numeric.

    The height is half the merge distance ``dist[a][b]``.  This is the usual
    ultrametric convention: two leaves at distance d diverged d/2 below
    their common ancestor.
    NOTE(review): confirm this is the convention chronogram.draw expects.
    The drawing scale is relative to the root height, so the picture's
    proportions are the same either way.

    Args:
        a, b: IDs of the two clades being merged (indices into ``trees``
            and ``dist``).

    Returns:
        tuple: ``(dist[a][b] / 2, trees[a], trees[b])``.
    """
    height = dist[a][b] / 2
    return (height, trees[a], trees[b])


def single_linkage_cost(a,b):
    """Return distance between clusters a and b.

    For single linkage, distance should be based on the
    NEAREST neighbors across the clusters: the minimum of
    ``dist[i][j]`` over every original leaf ``i`` in ``clades[a]`` and
    every original leaf ``j`` in ``clades[b]``.

    Args:
        a, b: clade IDs (indices into ``clades``).

    Returns:
        The smallest leaf-to-leaf distance across the two clades.
    """
    return min(dist[i][j] for i in clades[a] for j in clades[b])


def single_linkage_cost_efficient(a, b1, b2):
    """Return distance between clusters a and a new cluster b1+b2.

    For single linkage, distance should be based on the
    NEAREST neighbors across the clusters.  The nearest leaf of b1+b2 to
    clade a lies in b1 or in b2, so the answer is the smaller of the
    already-known distances ``dist[a][b1]`` and ``dist[a][b2]``.  This takes
    constant time instead of rescanning every leaf pair.

    Args:
        a: ID of an active clade.
        b1, b2: IDs of the two clades that were just merged.

    Returns:
        ``min(dist[a][b1], dist[a][b2])``.
    """
    return min(dist[a][b1], dist[a][b2])


def complete_linkage_cost(a, b):
    """Return distance between clusters a and b.

    For complete linkage, distance should be based on the
    FURTHEST neighbors across the clusters: the maximum of
    ``dist[i][j]`` over every original leaf ``i`` in ``clades[a]`` and
    every original leaf ``j`` in ``clades[b]``.

    Args:
        a, b: clade IDs (indices into ``clades``).

    Returns:
        The largest leaf-to-leaf distance across the two clades.
    """
    return max(dist[i][j] for i in clades[a] for j in clades[b])


def complete_linkage_cost_efficient(a, b1, b2):
    """Return distance between clusters a and a new cluster b1+b2.

    For complete linkage, distance should be based on the
    FURTHEST neighbors across the clusters.  The furthest leaf of b1+b2
    from clade a lies in b1 or in b2, so the answer is the larger of the
    already-known distances ``dist[a][b1]`` and ``dist[a][b2]``.

    Args:
        a: ID of an active clade.
        b1, b2: IDs of the two clades that were just merged.

    Returns:
        ``max(dist[a][b1], dist[a][b2])``.
    """
    return max(dist[a][b1], dist[a][b2])


def upgma_linkage_cost(a, b):
    """Return distance between clusters a and b.

    For UPGMA, distance should be the unweighted average of
    distances for all pairs across the clusters: the sum of
    ``dist[i][j]`` over every leaf ``i`` in ``clades[a]`` and every leaf
    ``j`` in ``clades[b]``, divided by the number of such pairs.

    Args:
        a, b: clade IDs (indices into ``clades``).

    Returns:
        The mean leaf-to-leaf distance across the two clades (a float).
    """
    total = sum(dist[i][j] for i in clades[a] for j in clades[b])
    return total / (len(clades[a]) * len(clades[b]))


def upgma_linkage_cost_efficient(a, b1, b2):
    """Return distance between clusters a and a new cluster b1+b2.

    For UPGMA, distance should be the unweighted average of
    distances for all pairs across the clusters.  ``dist[a][b1]`` is
    already the mean over the ``|a| * |b1|`` pairs involving b1, and
    similarly for b2.  The mean over all pairs with b1+b2 is therefore
    the average of those two values weighted by clade size
    (``len(clades[b1])`` and ``len(clades[b2])``).  A plain (equally
    weighted) average would be WPGMA, not UPGMA.

    Args:
        a: ID of an active clade.
        b1, b2: IDs of the two clades that were just merged.

    Returns:
        The mean leaf-to-leaf distance between clade a and b1+b2.
    """
    n1 = len(clades[b1])
    n2 = len(clades[b2])
    return (n1 * dist[a][b1] + n2 * dist[a][b2]) / (n1 + n2)


#--------------------------------------------------------
# Gather user input
#--------------------------------------------------------

# read raw data
filename = input('Distance file: ')
names, dist = parse(filename)
N = len(names)               # number of leaves

# Candidate linkage cost functions, offered to the user as a 1-based menu.
algorithms = [single_linkage_cost,
              single_linkage_cost_efficient,
              complete_linkage_cost,
              complete_linkage_cost_efficient,
              upgma_linkage_cost,
              upgma_linkage_cost_efficient]

# Re-show the menu until the user types an integer in 1..len(algorithms);
# the chosen function is left in `alg`.
done = False
while not done:
    for j, option in enumerate(algorithms):
        print('%d) %s' % (j + 1, option.__name__))
    try:
        k = int(input('Choose a linkage algorithm: '))
    except ValueError:
        # input was not an integer at all
        print('invalid choice')
        continue
    if not 1 <= k <= len(algorithms):
        # an integer, but not one of the menu numbers
        print('invalid choice')
        continue
    alg = algorithms[k - 1]
    done = True


# Adapt the chosen cost function to the single call shape used by the main
# loop, compute_distance(c, newID, a, b).  Cost functions that take exactly
# two parameters rebuild the distance from the clade IDs (c, newID); any
# other arity is treated as an "efficient" variant that only needs the two
# clades that were just merged.
if len(signature(alg).parameters) != 2:
    def compute_distance(a, b, b1, b2):
        return alg(a, b1, b2)
else:
    def compute_distance(a, b, b1, b2):
        return alg(a, b)

#--------------------------------------------------------
# Initialize Data Structures
#--------------------------------------------------------

# clades[k] lists the original leaf IDs contained in clade k.  Every leaf
# starts out as its own singleton clade, so initially clades[k] == [k].
clades = [[leaf] for leaf in range(N)]

# trees[k] is the tuple representation of clade k; a leaf is stored as
# (name, (), ()) with empty left and right subtrees.
trees = [(names[leaf], (), ()) for leaf in range(N)]

# Before any merging, all N singleton clades (IDs 0..N-1) are active.
active = set(range(N))

#--------------------------------------------------------
# Main clustering algorithm
#--------------------------------------------------------

# Agglomerative clustering: each pass removes the two nearest active clades
# and adds one merged clade, so the loop ends when a single active clade
# (the root of the full tree) remains.
while len(active) > 1:
    # IDs of the closest pair of currently active clades
    a,b = find_nearest_active()

    # remove a and b from active; done before the distance loop below so
    # that neither merged part is given a distance to the new clade
    active.remove(a)
    active.remove(b)

    print()
    print('About to merge clusters at distance %.3f:' % dist[a][b])
    print('  ' + str(trees[a]))
    print('  ' + str(trees[b]))

    # new clade has all leaves of clades[a] and clades[b]
    clades.append(clades[a] + clades[b])  # combination of both subtrees

    # create new tree representation for this clade
    trees.append( combine_trees(a,b) )

    print('New tree: ', trees[-1])

    # compute distances to the new clade.  Its ID is its index in clades;
    # clades.append above must come first so cost functions that look up
    # clades[newID] see the merged leaf list.
    newID = len(clades)-1
    # NOTE(review): assumes len(dist) == len(clades) - 1 at this point
    # (parse() returned one dict per leaf), so dist[newID] is this new dict.
    dist.append( {} )    # dictionary for new clade's distances
    for c in active:
        # c is a surviving clade; a and b are passed as the merged parts
        val = compute_distance(c, newID, a, b)
        dist[newID][c] = val   # keep the table symmetric
        dist[c][newID] = val

    # officially add new clade to active set (after the loop above, so it
    # never computes a distance to itself)
    active.add(newID)


# The last tree built is the root of the whole hierarchy (for a
# single-leaf input it is just that leaf).
result = trees[-1]

print()
print('Final tree:',result)

print()
response = input('Would you like me to draw it? [y/n] ').strip().lower()
if response.startswith('y'):
    # Drawing needs an internal root with a positive numeric height.  A lone
    # leaf keeps its name (a string) in slot 0, and N-1 == 0 would divide by
    # zero below.  An all-zero distance matrix yields height 0, which would
    # make the horizontal scale 400/result[0] divide by zero.
    if N < 2 or not result[0] > 0:
        print('Nothing to draw: the tree has no positive height.')
    else:
        x = 400/result[0]          # scale so the root height spans 400 units
        y = max(400//(N-1), 20)    # presumably per-leaf vertical spacing (>= 20)
        draw(result,x,y,result[0]*1.05)
