#!/usr/bin/python

import sys
import string
import math

# Status codes returned by Agent.think(). Python has no real constants, so
# these are ordinary module-level names that must simply never be rebound.
SOLUTION_PENDING, SOLUTION_FOUND, NO_SOLUTION_EXISTS = range(3)

def get_user_integer (msg, default = 1):
	"""Prompt with 'msg' and read a positive (non-zero) integer from stdin.

	Any answer that is not a positive integer (empty input, non-numeric
	text, zero or a negative number) yields 'default' instead. 'default'
	is 1 unless given, which matches the previous behaviour."""
	userstring = raw_input(msg)
	try:
		number = int(userstring)
	except ValueError:
		# int() raises ValueError for non-numeric text and for "".
		return default
	if number <= 0:
		return default
	return number

class Position (object):
	"""Representation of a 2-dimensional grid location.

	Positions compare and hash by value, so equal coordinates are
	interchangeable as dictionary keys."""
	
	__slots__ = ("x", "y")

	def __init__ (self, x = 0, y = 0):
		self.x = x
		self.y = y

	def __eq__ (self, cmp_to):
		"""Value equality ('=='): True when both coordinates match.

		A non-Position operand yields NotImplemented, so Python falls back
		to its default handling (False for '==') instead of raising."""
		if not isinstance(cmp_to, Position):
			return NotImplemented
		return self.x == cmp_to.x and self.y == cmp_to.y

	def __ne__ (self, cmp_to):
		"""Value inequality ('!=').

		Python 2 does not derive '!=' from '__eq__'; without this method
		'!=' would compare object identity instead of coordinates."""
		result = self.__eq__(cmp_to)
		if result is NotImplemented:
			return result
		return not result

	def __str__ (self):
		"""'Stringification' method."""
		return "Position: (x, y) = [%d, %d]" % (self.x, self.y)

	def __repr__ (self):
		"""Unambiguous representation for debugging."""
		return "Position(%r, %r)" % (self.x, self.y)
	
	def __add__ (self, rhs):
		"""Component-wise sum ('+')."""
		return Position(self.x + rhs.x, self.y + rhs.y)

	def __sub__ (self, rhs):
		"""Component-wise difference ('-')."""
		return Position(self.x - rhs.x, self.y - rhs.y)
	
	def __mul__ (self, rhs):
		"""Method called for the '*' operator (component-wise product)."""
		return Position(self.x * rhs.x, self.y * rhs.y)

	def __hash__ (self):
		"""Hash consistent with __eq__ (used by dictionaries)."""
		return hash((self.x, self.y))

class Graphnode (object):
	"""A node in the agent's search tree.

	Holds the costs f = g + h, a map location and a reference to its parent
	node (a fresh node is its own parent, which marks the root). Nodes are
	ordered by f() so a list of them can be sorted cheapest-first, while
	'==' / '!=' remain identity comparisons."""

	__slots__ = ("parent", "f", "g", "h", "location")

	# Identity-based hashing, matching the identity-based __eq__ below.
	__hash__ = object.__hash__

	def __init__ (self):
		self.parent = self
		self.f = 0
		self.g = 0
		self.h = 0
		self.location = None

	def __str__ (self):
		return "Graphnode: (f, g, h) = [%f, %f, %f]\n\t%s" % (self.f, self.g, self.h, str(self.location))
	
	def __cmp__ (self, cmp_to):
		"""Python 2 three-way comparison by f(): -1, 0 or 1."""
		if self.f < cmp_to.f:
			return -1
		elif self.f > cmp_to.f:
			return 1
		else:
			return 0

	# Rich ordering by f(). list.sort() uses '<', so nodes sort cheapest
	# first in both Python 2 and Python 3 (which ignores __cmp__).
	def __lt__ (self, cmp_to):
		return self.f < cmp_to.f

	def __le__ (self, cmp_to):
		return self.f <= cmp_to.f

	def __gt__ (self, cmp_to):
		return self.f > cmp_to.f

	def __ge__ (self, cmp_to):
		return self.f >= cmp_to.f

	# Equality is identity. With only __cmp__ defined, Python 2 would route
	# '==' through __cmp__ and treat any two nodes with equal f() as equal,
	# which breaks checks such as 'node.parent == node' for the root.
	def __eq__ (self, other):
		return self is other

	def __ne__ (self, other):
		return self is not other

class Worldmap (object):
	"""Represents a rectangular world of tiles of varying sort.

	The world is loaded from a text file by load_map(). The tiles are kept
	in map_data as a list of rows (index as map_data[y][x]) of one-character
	strings: "X" for a blocked tile, " " for a free tile on a binary map, or
	a digit "1".."9" on a weighted map."""

	# Neighbour offsets (dx, dy) probed by get_adjacent(), in probe order:
	# the four straight moves first, then the four diagonal ones.
	ADJACENT_OFFSETS = ((0, 1), (0, -1), (1, 0), (-1, 0),
		(1, 1), (1, -1), (-1, 1), (-1, -1))

	def __init__ (self):
		self.filename = ""
		self.width = 0
		self.height = 0
		self.case_sensitive = False  # When False, map symbols are upper-cased.
		self.map_data = []
		self.map_type = None  # "binary" or "weighted" once a map is loaded.

	def load_map (self, filename):
		"""Parse a map file and return its start and goal locations.

		Expected layout, read line by line:
			height = <rows>
			width = <columns>
			<rows lines of exactly <columns> map symbols each>
		The two header lines may come in either order. Newline and
		carriage-return characters are dropped from each row, so UNIX and
		DOS line endings are both accepted. The map type is decided by the
		first symbol that is not X, S or E: " " gives a binary map, a digit
		1-9 a weighted map; a map made only of X/S/E counts as binary. On a
		weighted map the start and goal tiles are given the integer average
		of their numeric neighbours.

		Returns a dictionary {"start": Position, "goal": Position}.
		Raises IOError if the file cannot be opened, and ValueError for a
		bad header, an invalid symbol, a size mismatch or a missing start
		or goal."""

		self.filename = filename
		self.map_data = []
		self.map_type = None  # Re-detect the type on every load.
		stuff = { }

		# Valid input symbols and the symbol stored for each of them. Start
		# and goal become free tiles; weighted values are fixed up below.
		table_binary = { " ":" ", "X":"X", "S":" ", "E":" " }
		table_weighted = { "X":"X", "S":" ", "E":" ", "1":"1", "2":"2", "3":"3", "4":"4", "5":"5", "6":"6", "7":"7", "8":"8", "9":"9" }
		table = table_binary

		# 'with' closes the file even when parsing raises.
		with open(self.filename, "r") as inputfile:
			self._read_header(inputfile)
			if self.height <= 0 or self.width <= 0:
				raise ValueError("Map header must give a positive height and width.")

			num_ignored = 0  # Number of dropped line-ending characters.
			for i in range(self.height):
				newline = []

				# 'i' is the current row, 'j' the current column.
				for j, mapsymbol in enumerate(inputfile.readline()):
					if mapsymbol == "\n" or mapsymbol == "\r":
						num_ignored += 1
						continue

					# The internal map is always upper case.
					if not self.case_sensitive:
						mapsymbol = mapsymbol.upper()

					# The first symbol that is not X/S/E decides the map type.
					if self.map_type is None and mapsymbol not in ("X", "S", "E"):
						if mapsymbol in table_binary:
							self.map_type = "binary"
							table = table_binary
						elif mapsymbol in table_weighted:
							self.map_type = "weighted"
							table = table_weighted
						else:
							raise ValueError("Map type could not be detected.")

					# Translate the symbol, rejecting anything the table lacks.
					try:
						newsymbol = table[mapsymbol]
					except KeyError:
						raise ValueError("Invalid symbol \"%s\" found in map!" % mapsymbol)

					if mapsymbol == "S":
						stuff["start"] = Position(j, i)
					elif mapsymbol == "E":
						stuff["goal"] = Position(j, i)

					newline.append(newsymbol)

				self.map_data.append(newline)

		print("Warning: %d symbols in the input stream were ignored (newlines?)." % num_ignored)

		if self.map_type is None:
			# Only X/S/E were seen; every cell went through the binary table.
			self.map_type = "binary"

		# Every row must be complete; a truncated file yields short rows.
		for row in self.map_data:
			if len(row) != self.width:
				raise ValueError("Map width does not match matrix contents.")

		if self.height != len(self.map_data):
			raise ValueError("Map height does not match matrix contents.")

		if "start" not in stuff:
			raise ValueError("Map has no start location ('S').")
		if "goal" not in stuff:
			raise ValueError("Map has no goal location ('E').")

		print("This appears to be a " + self.map_type + " type map.")

		# Start and goal have no value of their own on a weighted map, so
		# give them the average of their neighbours (start first).
		if self.map_type == "weighted":
			value = self._average_adjacent(stuff["start"])
			print("Averaged the value for the start location to: %d" % value)
			value = self._average_adjacent(stuff["goal"])
			print("Averaged the value for the goal location to: %d" % value)

		return stuff

	def _read_header (self, inputfile):
		"""Read the two 'key = value' header lines and set height/width."""
		for _ in range(2):
			parts = [part.strip() for part in inputfile.readline().split("=")]
			if parts[0] == "height":
				self.height = int(parts[1])
			elif parts[0] == "width":
				self.width = int(parts[1])

	def _average_adjacent (self, pos):
		"""Set the tile at 'pos' to the integer average of its numeric neighbours.

		Non-numeric neighbours (e.g. the other endpoint while it is still
		" ") are skipped. Without any numeric neighbour the tile gets 1, the
		lowest weight. Returns the value that was stored."""
		values = [int(sym) for sym in self.get_adjacent(pos).values() if sym.isdigit()]
		if values:
			average = sum(values) // len(values)
		else:
			average = 1
		self.set_data(pos, str(average))
		return average
	
	def valid_map_location (self, pos):
		"""Return True if 'pos' is inside the map bounds and not an 'X' tile."""
		if not (0 <= pos.x < self.width and 0 <= pos.y < self.height):
			return False
		return self.map_data[pos.y][pos.x] != "X"
	
	def get_data (self, pos):
		"""Return the tile symbol at 'pos'.

		Less error-prone than indexing map_data by hand, since the matrix is
		indexed (y, x)."""
		return self.map_data[pos.y][pos.x]
	
	def set_data (self, pos, val):
		"""Replace the tile symbol at 'pos' with 'val'."""
		self.map_data[pos.y][pos.x] = val
		
	def get_adjacent (self, pos):
		"""Return the accessible neighbours of 'pos', diagonals included.

		The result is a dictionary with Position instances as keys and their
		map symbols as values; only valid_map_location() positions are
		included."""
		adjacent = { }
		for dx, dy in self.ADJACENT_OFFSETS:
			test = Position(pos.x + dx, pos.y + dy)
			if self.valid_map_location(test):
				adjacent[test] = self.get_data(test)
		return adjacent

class Scenario (object):
	"""A problem that an agent must solve.

	It has a worldmap plus start and goal locations, all loaded from a map
	file."""
	
	def __init__ (self, filename):
		self.setup(filename)
	
	def setup (self, filename):
		"""(Re)load the scenario from the map file 'filename'.

		Replaces the world map and the start/goal positions. Errors raised
		while loading the map (e.g. IOError, ValueError) propagate to the
		caller."""
		self.world = Worldmap()
		locations = self.world.load_map(filename)
		self.start = locations["start"]
		self.goal = locations["goal"]

class Agent (object):
	"""A path-finding A*-based agent.

	Call setup() with a Scenario, then call think() repeatedly until it
	returns something other than SOLUTION_PENDING. show() prints the map
	with the agent's progress."""
	
	def __init__ (self):
		self.open = []          # Frontier: nodes waiting to be expanded.
		self.closed = []        # Nodes that have already been expanded.
		self.task = None        # The Scenario being solved.
		self.location = None    # Location of the most recently expanded node.
		self.iterations = 0
		self.path = []          # Solution nodes, filled in by think().
		self.cost_function = self.cost_simple
		self.heuristic_function = self.heuristic_manhattan
		self.smooth_path = False     # Add direction_penalty() to g() when set.
		self.dir_penalty_cost = 30   # Penalty for a change of direction.
	
	def setup (self, task):
		"""Initialize the agent to work on 'task' (a Scenario).

		Resets the search lists and statistics, puts the start node on the
		open list, picks the cost function matching the map type and resets
		the heuristic and path smoothing to their defaults."""
		self.open = []
		self.closed = []
		self.task = task
		self.location = task.start
		initial_node = Graphnode()
		initial_node.location = self.location
		initial_node.parent = initial_node  # The root is its own parent.
		self.open.append(initial_node)
		self.iterations = 0
		self.path = []

		if self.task.world.map_type == "weighted":
			self.cost_function = self.cost_weighted
		else:
			self.cost_function = self.cost_simple

		self.heuristic_function = self.heuristic_manhattan
		self.smooth_path = False

	def heuristic_manhattan (self, loc):
		"""Manhattan distance |dx| + |dy| from 'loc' to the goal.

		Note: moves include diagonals, so with cost_simple (1 per step)
		this can overestimate the remaining cost."""
		return math.fabs(self.task.goal.x - loc.x) + math.fabs(self.task.goal.y - loc.y)

	def heuristic_straight_line (self, loc):
		"""Euclidean distance from 'loc' to the goal."""
		return math.hypot(self.task.goal.x - loc.x, self.task.goal.y - loc.y)
	
	def cost_simple (self, node):
		"""Return the number of steps from the root (start) node to 'node'."""
		cost = 0
		tmp = node
		# The root node is its own parent (see setup()).
		while tmp.parent is not tmp:
			tmp = tmp.parent
			cost += 1
		return cost

	def cost_weighted (self, node):
		"""Return the path cost to 'node' on a weighted map.

		The cost is the sum of the values of every tile entered along the
		path; the start tile itself is not counted."""
		world = self.task.world
		cost = 0
		tmp = node
		while tmp.parent is not tmp:
			cost += int(world.get_data(tmp.location))
			tmp = tmp.parent
		return cost

	def direction_penalty (self, cur_node, test_node):
		"""Return the penalty for stepping from 'cur_node' to 'test_node'.

		Compares the step that led to cur_node (parent -> cur_node) with the
		proposed step (cur_node -> test_node). Returns 0 when both go in the
		same direction, or when cur_node is the root (it has no incoming
		step). Otherwise returns self.dir_penalty_cost.

		Raises ValueError if both nodes share the same location."""
		if test_node.location == cur_node.location:
			raise ValueError("Testing a node against itself.")

		# Identity test: the root is its own parent. ('==' on Graphnode may
		# be routed through __cmp__ in Python 2 and then compares f().)
		if cur_node.parent is cur_node:
			return 0

		incoming = cur_node.location - cur_node.parent.location
		outgoing = test_node.location - cur_node.location

		# Only '==' is used: Position '!=' may fall back to identity in Py2.
		if incoming == outgoing:
			return 0
		return self.dir_penalty_cost

	def get_path (self, node):
		"""Return the nodes on the path from 'node' back towards the root.

		The list starts with 'node' itself and ends with the first node after
		the root; the root (start) node is not included."""
		path = []
		tmpnode = node
		while tmpnode.parent is not tmpnode:
			path.append(tmpnode)
			tmpnode = tmpnode.parent
		return path

	def _make_successor (self, cur_node, location):
		"""Create the child of 'cur_node' at 'location' with g, h and f set."""
		node = Graphnode()
		node.parent = cur_node
		node.location = location
		node.g = self.cost_function(node)
		if self.smooth_path:
			node.g = node.g + self.direction_penalty(cur_node, node)
		node.h = self.heuristic_function(location)
		node.f = node.g + node.h
		return node

	def _add_to_open (self, node):
		"""Queue 'node' on the open list unless its location is known.

		Already expanded (closed) locations are skipped. If the location is
		already waiting on the open list and 'node' reaches it more cheaply,
		the waiting node is moved onto the cheaper route; it has not been
		expanded yet, so no other node depends on its old parent."""
		for closed_node in self.closed:
			if closed_node.location == node.location:
				return
		for open_node in self.open:
			if open_node.location == node.location:
				if node.g < open_node.g:
					open_node.parent = node.parent
					open_node.g = node.g
					open_node.h = node.h
					open_node.f = node.f
				return
		self.open.append(node)

	def think (self, verbose = False):
		"""Run one A* iteration on the current task.

		Call this repeatedly to solve the task. Returns SOLUTION_FOUND once
		the goal is expanded (self.path then holds the path),
		NO_SOLUTION_EXISTS when the open list is exhausted, and
		SOLUTION_PENDING otherwise. With 'verbose' the open and closed lists
		are printed."""
		if not self.open:
			return NO_SOLUTION_EXISTS

		# Cheapest f() first; the cheapest node is expanded next.
		self.open.sort()
		if verbose:
			print("The open list:")
			for node in self.open:
				print(node)

		cur_node = self.open.pop(0)

		if verbose:
			print("The closed list:")
			for node in self.closed:
				print(node)

		self.location = cur_node.location
		self.closed.append(cur_node)
		if self.location == self.task.goal:
			self.path = self.get_path(cur_node)
			return SOLUTION_FOUND

		for adj_loc in self.task.world.get_adjacent(cur_node.location):
			self._add_to_open(self._make_successor(cur_node, adj_loc))

		self.iterations = self.iterations + 1
		return SOLUTION_PENDING

	def show (self):
		"""Print the map with the agent's current search state.

		Start, goal and agent are shown as S, E and A; other tiles on the
		final path, closed list or open list as 'o', '-' or '+' (in that
		order of precedence); everything else as its map symbol."""
		print("Iteration #%d" % self.iterations)

		# Later lists override earlier ones: 'o' beats '-' beats '+'.
		markers = {}
		for marker, nodes in (("+", self.open), ("-", self.closed), ("o", self.path)):
			for node in nodes:
				markers[node.location] = marker

		for y, row in enumerate(self.task.world.map_data):
			line = []
			for x, symbol in enumerate(row):
				cur_pos = Position(x, y)
				if cur_pos == self.task.goal:
					line.append("E")
				elif cur_pos == self.task.start:
					line.append("S")
				elif cur_pos == self.location:
					line.append("A")
				else:
					line.append(markers.get(cur_pos, symbol))
			print("".join(line))

		print("+ (open)  - (closed)  o (final path)")


### Main script start ###

# The map file name is the only command line argument.
if len(sys.argv) < 2:
	print("Error: Please supply a map filename as argument to this script.")
	sys.exit(1)

print("Using map file: \"%s\"" % sys.argv[1])

try:
	myscenario = Scenario(sys.argv[1])
except IOError:
	print("Error: Unable to open map: \"%s\"" % sys.argv[1])
	sys.exit(1)
except ValueError as err:
	# The map parser raises ValueError for malformed map files.
	print("Error: Invalid map \"%s\": %s" % (sys.argv[1], err))
	sys.exit(1)

myagent = Agent()
myagent.setup(myscenario)

print("\nSelect simulation speed/visualization:\n")
print("1. Interactive: Redraw and wait after each succession.")
print("2. Automatic: Full speed, no drawing until solution has been found.")

mode = get_user_integer("\nWhat mode do you wish to use? (default: Interactive): ")

# Allow the user to select what heuristic function the agent should use.
print("\nSelect the agent's heuristic function:\n")
print("1. Manhattan distance")
print("2. Straight line distance")
heuristics = { 1 : myagent.heuristic_manhattan, 2 : myagent.heuristic_straight_line }
heur = get_user_integer("\nWhich heuristic do you want to use? (default: Manhattan): ")
# Unknown choices fall back to the advertised default instead of a KeyError.
myagent.heuristic_function = heuristics.get(heur, myagent.heuristic_manhattan)

# Allow user to select optional features.
print("\nSelect extra features to use:\n")
print("1. Default (no extra features enabled)")
print("2. Straight path generation")

feat = get_user_integer("What features do you want to enable? (default: None) ")

if feat == 2:
	myagent.smooth_path = True
	# Allow custom directional weighting.
	myagent.dir_penalty_cost = get_user_integer("Enter the penalty for a nonstraight path (default = 1): ")

print("")
myagent.show()
raw_input("\nPress [enter] to begin simulation...")

# Run the agent until it finds a path or proves that none exists. In
# interactive mode the map is redrawn every 'usercounter' iterations.
ret = SOLUTION_PENDING
usercounter = 0
while ret == SOLUTION_PENDING:
	ret = myagent.think()
	if mode == 1 and usercounter == 0:
		myagent.show()
		usercounter = get_user_integer("\nNumber of iterations to run non-interactively (default = 1)... ")
	usercounter = usercounter - 1

myagent.show()

if ret == NO_SOLUTION_EXISTS:
	print("A solution could not be obtained for this problem.")
elif ret == SOLUTION_FOUND:
	print("A solution was found after %d iterations." % myagent.iterations)
	print("Final path has %d nodes." % len(myagent.path))


