import sys
import time
import math
import random
from optparse import OptionParser, OptionGroup
from samples import samples

def train(patterns, weights):
  """
  Adjust the weights matrix based on the given sample patterns.

  Each pattern will be a sequence of length n, for some n, with
  each value being either +1 or -1.

  Weights is an n x n matrix, with weights[i][j] being the influence of
  neuron i upon neuron j in the ANN.  The weights will all be initially 0.

  Uses the Hebbian rule: for every pattern p and every pair of distinct
  neurons i != j, weights[i][j] += p[i] * p[j].  Starting from zeros this
  yields a symmetric matrix with a zero diagonal (no self-connections),
  the usual setting in which asynchronous Hopfield updates settle.
  The matrix is modified in place; nothing is returned.
  """
  for pattern in patterns:
    size = len(pattern)
    for i in range(size):
      row = weights[i]
      source = pattern[i]
      for j in range(size):
        if j != i:   # leave the diagonal untouched: no neuron reinforces itself
          row[j] += source * pattern[j]


def classify(query, weights, randgen, maxIterations=None):
  """
  Alters the query (in place), until reaching an equilibrium in the Hopfield network.

  The query will be a list of length n, for some n, with
  each value being either +1 or -1.

  Weights is the n x n matrix produced during training.
  weights[i][j] represents the influence of neuron i upon neuron j in the ANN.

  Updates are asynchronous: each sweep visits every neuron once, in an order
  shuffled with `randgen`.  Neuron i is set to +1 when its net input
  sum_{j != i} weights[j][i] * query[j] is positive, to -1 when it is
  negative, and keeps its current value on a tie.  The loop stops when a
  whole sweep changes nothing (equilibrium), or after maxIterations
  single-neuron updates when maxIterations is not None.

  trace.step(query) is called after every single-neuron update step, even
  when the value did not change.  `trace` is the module-level Tracer that
  main() assigns before calling this function.
  """
  n = len(query)
  steps = 0
  changed = True
  while changed:
    changed = False
    order = list(range(n))
    randgen.shuffle(order)
    for i in order:
      if maxIterations is not None and steps >= maxIterations:
        return
      net = 0
      for j in range(n):
        if j != i:
          net += weights[j][i] * query[j]
      if net > 0:
        newValue = 1
      elif net < 0:
        newValue = -1
      else:
        newValue = query[i]   # tie: keep current state to avoid flip-flopping
      if newValue != query[i]:
        query[i] = newValue
        changed = True
      steps += 1
      trace.step(query)

def main():
  """
  Command-line entry point for the Hopfield numeral-recall experiment.

  Parses the options, trains the network on the chosen numeral patterns,
  then runs options.numTests trials.  Each trial:
    - copies a sample numeral;
    - flips each bit with probability options.prob;
    - lets classify() settle the query;
    - records which sample, if any, the query ended on.
  Finally prints the overall success rate and the per-numeral scoreboard.

  After printing a message, exits with status 1 in any of these cases:
    - the training set is empty or larger than `samples`;
    - a training numeral or the forced test numeral (-f) is outside
      `samples`;
    - fewer than one test is requested.
  With -a it only saves a snapshot of the training patterns and exits.
  """
  global trace   # classify() reports every update step through this tracer
  options, _ = parseCommandLine()

  if options.seed is None:
    seed = random.randrange(1000000)
    print("Random seed: " + str(seed))
  else:
    seed = options.seed

  randgen = random.Random(seed)

  numSamples = len(samples)
  if options.training:
    numerals = [int(token) for token in options.training.split(',')]
  elif options.patterns == numSamples:
    numerals = list(range(numSamples))
  else:
    numerals = list(range(1, 1 + options.patterns))

  # Validate before indexing `samples`: an oversized -n or -T value used to
  # raise IndexError before the count check was ever reached, and negative
  # numerals would silently wrap around.
  k = len(numerals)
  if not 1 <= k <= numSamples:
    print("Invalid number of patterns: " + str(k))
    sys.exit(1)
  for numeral in numerals:
    if not 0 <= numeral < numSamples:
      print("Invalid training numeral: " + str(numeral))
      sys.exit(1)

  training = [samples[numeral] for numeral in numerals]

  if options.showAll:
    createSnapshot(options, training)
    sys.exit()

  if options.numeral is not None and not 0 <= options.numeral < numSamples:
    print("Invalid test numeral: " + str(options.numeral))
    sys.exit(1)
  if options.numTests < 1:   # the success rate below divides by numTests
    print("Invalid number of tests: " + str(options.numTests))
    sys.exit(1)

  n = len(samples[0])
  weights = [[0] * n for _ in range(n)]

  trace = Tracer(options)

  # scoreboard[i][j] represents number of times query i mapped to pattern j
  # (with column numSamples, the last one, used as "other")
  scoreboard = [[0] * (1 + numSamples) for _ in range(numSamples)]
  # Map each sample to its first index so a settled query is identified with
  # one lookup; tuple() lets list- or tuple-valued samples match the query.
  sampleIndex = {}
  for index, pattern in enumerate(samples):
    sampleIndex.setdefault(tuple(pattern), index)
  correct = 0

  train(training, weights)
  for trial in range(options.numTests):
    if options.numeral is not None:
      numeral = options.numeral
    else:
      numeral = randgen.choice(numerals)

    query = list(samples[numeral])
    for j in range(len(query)):
      if randgen.random() < options.prob:
        query[j] = -query[j]    # bitflip

    trace.initialize(query, numeral)
    classify(query, weights, randgen, options.iterations)
    trace.final(query)

    answer = sampleIndex.get(tuple(query), numSamples)
    scoreboard[numeral][answer] += 1
    if numeral == answer:
      correct += 1

  print('')
  print("Overall success rate of %6.4f" % (1.0 * correct / options.numTests))
  print('')
  displayResults(scoreboard, numerals)


def prettyQuery(query, width=10):
  """
  Render a pattern as ASCII art: '@' for +1 pixels, '.' for anything else.

  The pattern is laid out row-major with `width` pixels per row.  The
  default of 10 matches the 10x10 numeral samples.  Rows are joined by
  newlines, and a trailing partial row is shown as-is.  `width` must be
  positive.
  """
  rows = []
  for start in range(0, len(query), width):
    rows.append(''.join('@' if pixel == 1 else '.'
                        for pixel in query[start:start + width]))
  return '\n'.join(rows)
  

def displayResults(scoreboard, numerals):
  """
  Print a confusion table of query numerals versus recalled patterns.

  scoreboard[i][j] counts trials whose query began as numeral i and settled
  on sample j.  The last column, scoreboard[i][-1], counts trials that
  settled on no sample ("other").  Only the rows and columns listed in
  `numerals` are shown, and zero counts are rendered as '.'.
  """
  labels = tuple(numerals)
  headerFormat = '   ' + len(labels) * ' %3d'
  lineFormat = '%2d:' + (1 + len(labels)) * ' %3s'

  # print(...) with a single argument behaves identically on Python 2 and 3,
  # matching the print calls used everywhere else in this file.
  print(headerFormat % labels + ' other')
  for num in labels:
    counts = [scoreboard[num][col] for col in labels]
    counts.append(scoreboard[num][-1])
    cells = [str(count) if count > 0 else '.' for count in counts]
    print(lineFormat % ((num,) + tuple(cells)))


def weightDisplay(weights):
  """Return the weights as '(w0,w1,...)', each value formatted with '%5.3f'."""
  formatted = ['%5.3f' % value for value in weights]
  return '(%s)' % ','.join(formatted)

  
def weightDisplayMatrix(weights, row, col):
  """
  Format a flat, row-major weight sequence as a row x col grid.

  Each weight is printed with an explicit sign and four decimals ('%+6.4f').
  Cells within a row are separated by two spaces, and rows by a blank line.
  If `weights` holds exactly one entry beyond row*col, that trailing entry
  is appended as a "Bias weight is ..." line.
  """
  gridSize = row * col
  blocks = []
  for r in range(row):
    offset = r * col
    cells = ['%+6.4f' % weights[offset + c] for c in range(col)]
    blocks.append('  '.join(cells))
  hasBias = len(weights) == gridSize + 1
  if hasBias:
    blocks.append("Bias weight is " + str(weights[gridSize]))
  return '\n\n'.join(blocks)
  

def createSnapshot(options, samples):
  """
  Save a picture of the given patterns, side by side, to 'numbers.ps'.

  Each pattern is drawn as a 10x10 grid of squares on a canvas
  options.visualSize pixels tall.  A square is white where the pixel is -1
  and black otherwise.  Consecutive patterns are 12 cells apart
  horizontally.  If cs1graphics cannot be imported, a message is printed
  and nothing is saved.
  """
  try:
    import cs1graphics
    count = len(samples)
    size = options.visualSize
    scale = size / 10.0
    canvas = cs1graphics.Canvas(count * 1.2 * size, size)
    canvas.setAutoRefresh(False)   # batch all additions into one redraw
    for j, pattern in enumerate(samples):
      for row in range(10):
        for col in range(10):
          cell = cs1graphics.Square(scale)
          cell.move(scale * (col + 0.5 + 12*j), scale * (row + 0.5))
          isOff = pattern[10*row + col] == -1
          cell.setFillColor('white' if isOff else 'black')
          cell.setBorderWidth(scale / 100.0)
          canvas.add(cell)
    canvas.setAutoRefresh(True)
    canvas.saveToFile('numbers.ps')
  except ImportError:
    print("Unable to import cs1graphics")


class Tracer:
  """
  Reports the progress of classify() on one query at a time.

  Depending on the command-line options it can:
    - print queries to the console;
    - mirror them on a 10x10 grid of cs1graphics squares;
    - pace the update steps, sleeping or, when the delay is 0, waiting for
      the return key.
  """

  def __init__(self, options):
    """
    Capture the display settings from the parsed command-line options.

    When options.visualize is set, a size x size canvas of 100 squares is
    built, where size is options.visualSize.  If cs1graphics cannot be
    imported, a message is printed and tracing continues without a canvas.
    """
    self._period = options.echo         # print every `period` steps; 0 disables per-step tracing
    self._delay = options.visualDelay   # seconds to sleep per step; 0 means wait for the return key
    self._canvas = None
    self._quiet = options.quiet
    self._guess = options.guess
    self._steps = 0                     # set here too so step() works even before initialize()
    if options.visualize:
      try:
        import cs1graphics
        self._cs1graphics = cs1graphics
        size = options.visualSize
        scale = size / 10.0
        self._canvas = cs1graphics.Canvas(size, size)
        self._canvas.setAutoRefresh(False)
        self._squares = []
        for row in range(10):
          for col in range(10):
            s = cs1graphics.Square(scale)
            s.move(scale * (col + 0.5), scale * (row + 0.5))
            s.setFillColor('white')
            s.setBorderWidth(scale / 100.0)
            self._canvas.add(s)
            self._squares.append(s)
        self._canvas.setAutoRefresh(True)
      except ImportError:
        # Drop a half-built canvas so _draw() never indexes missing squares.
        self._canvas = None
        print("Unable to import cs1graphics; ignoring visualization")

  @staticmethod
  def _pause():
    """Wait for the user to press return, on either Python 2 or Python 3."""
    try:
      readLine = raw_input   # Python 2 (its `input` would eval the reply)
    except NameError:
      readLine = input       # Python 3
    readLine("Press return to continue...")

  def _draw(self, query):
    """
    Recolor the canvas to show `query`: white for -1, black otherwise.

    Does nothing when there is no canvas.  Presumably query has at most 100
    entries, one per square; verify for non-10x10 patterns.
    """
    if self._canvas:
      self._canvas.setAutoRefresh(False)
      for i in range(len(query)):
        color = 'white' if query[i] == -1 else 'black'
        self._squares[i].setFillColor(color)
      self._canvas.setAutoRefresh(True)

  def initialize(self, query, actual):
    """
    Start tracing a new query whose uncorrupted numeral is `actual`.

    Resets the step counter and draws the query.  Unless quiet, it also
    prints the query.  With -g the numeral is withheld and the user is asked
    to press return; otherwise the numeral is shown as a spoiler.
    """
    self._steps = 0
    self._draw(query)
    if not self._quiet:
      if self._guess:
        print('\nQuery (any guesses?):\n' + prettyQuery(query))
        self._pause()
      else:
        print('\nQuery (spoiler: "%d"):\n' % actual + prettyQuery(query))

  def step(self, query):
    """
    Record one update step of classify().

    When step tracing is enabled (period > 0), every step:
      - redraws the canvas;
      - prints the query, but only every `period` steps and only if not quiet;
      - waits for the configured pause.
    """
    self._steps += 1
    if self._period > 0:
      self._draw(query)
      if self._steps % self._period == 0 and not self._quiet:
        print("\nAfter step %d:\n" % self._steps + prettyQuery(query))

      if self._delay == 0:
        self._pause()
      else:
        time.sleep(self._delay)

  def final(self, query):
    """Draw the settled query and, unless quiet, print it with the step count."""
    self._draw(query)
    if not self._quiet:
      print("\nTotal number of steps: %d" % self._steps)
      print("Final query:\n" + prettyQuery(query))

  
#------------------------------------------------------------------------
# Code for command line options
#------------------------------------------------------------------------
def parseCommandLine():
  """
  Parse sys.argv and return optparse's (options, args) pair.

  Values that would break the experiment are rejected through
  parser.error, which prints the usage and exits with status 2:
    - fewer than one test (-r), since the success rate divides by it;
    - a perturbation probability (-p) outside [0, 1].
  """
  parser = OptionParser(usage='usage: %prog [options]')

  group = OptionGroup(parser, 'Experiment Options')
  group.add_option('-n', dest='patterns', type='int', default=4,
                   help='number of patterns to use in training [default: %default]')
  group.add_option('-T', dest='training', default=None,
                   help='comma separated choice of numerals for training [default: %default]')
  group.add_option('-p', dest='prob', type='float', default=0.1,
                   help='Probability of perturbing each bit in the test pattern [default: %default]')
  group.add_option('-r', dest='numTests', type='int', default=1, metavar='REPS',
                   help='Number of independent tests to perform [default: %default]')
  group.add_option('-f', dest='numeral', type='int', default=None,
                   help='force numeral to choose as basis for test query [default: random]')
  group.add_option('-m', dest='iterations', type='int', default=10000,
                   help='maximum number of iterations to perform per query [default: %default]')
  parser.add_option_group(group)

  group = OptionGroup(parser, 'Display Options')
  group.add_option('-t', dest='echo', type='int', default=0, metavar='STEPS',
                   help='trace status every t steps (no trace if 0) [default: %default]')
  group.add_option('-d', dest='visualDelay', type='float', default=0.001, metavar='DELAY',
                   help='per step delay for trace; manual if 0 [default: %default]')
  group.add_option('-v', dest='visualize', default=False, action='store_true',
                   help='visualize trace [default: %default]')
  group.add_option('-w', dest='visualSize', type='int', default=200, metavar='WIDTH',
                   help='width of window for visualization [default: %default]')
  group.add_option('-q', dest='quiet', default=False, action='store_true',
                   help='no console output (other than statistics) [default: %default]')
  parser.add_option_group(group)

  parser.add_option('-a', dest='showAll', default=False, action='store_true',
                   help='show all test patterns and exit')
  # type='int' (string form) for consistency with every other option above
  parser.add_option('-s', dest='seed', type='int', default=None,
                   help='seed for all randomization [default: clock]')
  parser.add_option('-g', dest='guess', default=False, action='store_true',
                   help='have initial manual pause after displaying initial state so that we can guess')

  options, args = parser.parse_args()
  if options.numTests < 1:
    parser.error('number of tests (-r) must be at least 1')
  if not 0.0 <= options.prob <= 1.0:
    parser.error('probability (-p) must be between 0 and 1')
  return options, args
  

if __name__ == "__main__":
  # Run the experiment only when executed as a script, never on import.
  main()
