# Copyright 2009 Ben Escoto
#
# This file is part of Explicans.

# Explicans is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.

# Explicans is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.

# You should have received a copy of the GNU General Public License
# along with Explicans.  If not, see <http://www.gnu.org/licenses/>.

import lazyarray, program, objects, funcall, extable, relref

def lookup(las, name, relref_info):
	"""Resolve name (an ExString) to its value in the current context

	A binding supplied by relref_info (relative references) takes
	precedence; when relref_info has none, the name is looked up in las.
	"""
	binding = relref_info.get_name_binding(name.obj)
	if binding is None:
		binding = las[name]
	return binding

def eval(las, ast, relref_info):
	"""Evaluate the abstract syntax tree ast with names bound by las

	ast is a nested tuple of the kind returned by the parser: element 0 is
	the node tag and the remaining elements are children or literal text.
	las is a LazyArrayStack that binds names to values.

	relref_info is a RelRefInfo object used to interpret relative
	references like 'me', 'next', and 'prev'.
	"""
	def sub(i):
		"""Recursively evaluate child i of the current node"""
		return eval(las, ast[i], relref_info)

	tag = ast[0]

	# Tags whose two operands are evaluated and passed to funcall.call.
	# Built at call time because the operator objects are module-level
	# names defined further down in this file.
	binary_ops = {
		'+': op_plus, '-': op_minus, '*': op_mult, '/': op_div,
		'^': op_pow, 'INDEX': index_func,
		'<': comp_lt, '>': comp_gt, '<=': comp_lte, '>=': comp_gte,
		'=': comp_eq, '!=': comp_neq,
		'and': op_and, 'or': op_or,
	}
	operator_func = binary_ops.get(tag)
	if operator_func is not None:
		return funcall.call(operator_func, (sub(1), sub(2)))

	if tag == 'NUM':
		return objects.ExNum(float(''.join(ast[1:])))
	if tag == 'NEG':
		return funcall.call(op_neg, (sub(1),))
	if tag == 'NAME':
		return lookup(las, objects.ExString(ast[1]), relref_info)
	if tag == 'STRING':
		return objects.ExString(ast[1])
	if tag == ':':
		assert ast[2][0] == 'NAME', ast
		return sub(1)[objects.ExString(ast[2][1])]
	if tag == 'FUNCALL':
		callee = sub(1)
		arguments = [sub(i) for i in range(2, len(ast))]
		return funcall.call(callee, arguments)
	if tag == '.':
		assert ast[2][0] == 'NAME', ast
		return sub(1).get_method(ast[2][1])
	if tag == 'where':
		return op_where(sub(1), sub(2))
	if tag == 'ARITHSEQ2':
		return arith_seq(sub(1), sub(2))
	if tag == 'ARITHSEQ3':
		# Child 2 is the second element and child 3 the end bound, while
		# arith_seq takes (from, to, second).
		return arith_seq(sub(1), sub(3), sub(2))
	assert False, "Unknown operation in %s" % (ast,)

def eval_absref(las, prog, ar):
	"""Evaluate the program node at absolute reference ar and return it

	The node's formula is evaluated first. An ExArray result then has the
	column formula applied (non-empty ar only) and is labeled with label();
	an ExTable result is labeled with label_table(). Non-root results get
	their relative-reference info attached before being returned.
	"""
	nested = len(ar) > 0
	if nested:
		relref_info = relref.RelRefCell(las.get_root_la(), ar)
	else:
		relref_info = relref.RelRefInfo()

	value = eval(las, prog[ar].ast, relref_info)
	# One-element list handed to RelRefColumn / label_table; it holds the raw
	# value now and is updated to the final value at the end (presumably so
	# relative references see the finished result).
	me_pointer = [value]

	if isinstance(value, objects.ExArray):
		if nested:
			column_info = relref.RelRefColumn(
				las.get_root_la(), ar, value.len(), me_pointer)
			value = apply_col_formula(las, prog, ar, value, column_info)
		value = label(las, value, prog, ar)
	elif isinstance(value, extable.ExTable):
		value = label_table(las, value, prog, ar, me_pointer)

	if not ar.isroot():
		value.set_relref(relref_info)
	me_pointer[0] = value
	return value

def apply_col_formula(las, prog, ar, orig_array, relref_col):
	"""Return orig_array updated by the column formula at ar, if there is one

	When prog has no column propset for ar, or the propset has no formula,
	orig_array is returned unchanged. An array-valued formula result is
	returned as-is; any other result is repeated at every position of a new
	array with orig_array's length.
	"""
	try:
		col_ps = prog.get_col_ps(ar)
	except KeyError:
		return orig_array
	if col_ps.ast is None:
		return orig_array

	formula_value = eval(las, col_ps.ast, relref_col)
	if formula_value.t != 'array':
		size = orig_array.len()
		repeated = lazyarray.LazyArray(size)
		repeated.set_values_precomputed((formula_value,) * size)
		return objects.ExArray(repeated)
	return formula_value

def thunk_maker(las, prog, ar):
	"""Return a zero-argument callable that runs eval_absref(las, prog, ar)"""
	return lambda: eval_absref(las, prog, ar)

def thunk_maker_recursive_label(las, thunk, prog, ar):
	"""Wrap thunk so that an ExArray it produces is labeled at node ar

	This keeps labeling recursive through lazy arrays even when no formula
	is given at the enclosing level. Non-array results pass through as-is.
	"""
	def labeled_thunk():
		value = thunk()
		if not isinstance(value, objects.ExArray):
			return value
		return label(las, value, prog, ar)
	return labeled_thunk

def label(las, exa, prog, ar):
	"""Return a new labeled ExArray at node ar based on the ExArray exa

	During program evaluation, when a node produces an ExArray, the names
	and expressions the user entered for its children override the default
	keys and values. The result is like exa but incorporates that program
	information; exa's own LazyArray is not modified.
	"""
	assert isinstance(exa, objects.ExArray), exa
	# Work on a copy: the program may set one array equal to another, and
	# labeling in place would then change the other array as well.
	labeled_la = exa.obj.copy()
	child_las = las + labeled_la
	for i in range(len(labeled_la)):
		child_ar = ar.append(i)
		try:
			child_ps = prog[child_ar]
		except KeyError:
			has_formula = False
		else:
			if child_ps.name is not None:
				labeled_la.set_key(i, child_ps.get_exobj_name())
			has_formula = child_ps.ast is not None

		if has_formula:
			new_thunk = thunk_maker(child_las, prog, child_ar)
		else:
			# No formula here: keep the existing value but label it too.
			new_thunk = thunk_maker_recursive_label(
				child_las, labeled_la.get_thunk(i), prog, child_ar)
		labeled_la.set_value(i, new_thunk)
	return objects.ExArray(labeled_la)

def label_table(las, orig_table, prog, ar, me_pointer):
	"""Return a new labeled ExTable at node ar based on the ExTable orig_table

	This is the table counterpart of label, with extra steps: the axis
	names, the row and column names (including axis formulas), the column
	formulas, and the individual cell formulas are applied in that order.
	me_pointer is currently unused.
	"""
	assert isinstance(orig_table, extable.ExTable), orig_table
	row_axis_name, col_axis_name = table_get_axis_names(prog, ar, orig_table)
	row_axis_value = relref.TempTableAxisValue()
	col_axis_value = relref.TempTableAxisValue()
	row_names = table_rowcol_names(
		las, orig_table, prog, ar, row_axis_name, row_axis_value, True)
	col_names = table_rowcol_names(
		las, orig_table, prog, ar, col_axis_name, col_axis_value, False)
	new_table = extable.ExTable(
		row_axis_name, col_axis_name, row_names, col_names)
	for apply_formulas in (table_apply_column_formulas,
			table_apply_value_formulas):
		apply_formulas(las, new_table, prog, ar)
	return new_table

def table_get_axis_names(prog, ar, orig_table):
	"""Return (row axis name, column axis name) for the table at ar

	A name given explicitly in the program's axis propset wins. Otherwise
	each axis keeps the name of the same axis in orig_table.
	"""
	row_propset, col_propset = prog.get_table_axis_propsets(ar)

	if row_propset is None or row_propset.name is None:
		row_axis_name = orig_table.row_axis_name
	else:
		row_axis_name = row_propset.name
	if col_propset is None or col_propset.name is None:
		# Fall back to the column axis's own name, not the row axis name.
		col_axis_name = orig_table.col_axis_name
	else:
		col_axis_name = col_propset.name

	return (row_axis_name, col_axis_name)

def table_rowcol_names(las, orig_table, prog, ar, axis_name, axis, isrow):
	"""Return a python list of the table's row names or column names

	If isrow is true the row names are computed, otherwise the column names.
	The process has three steps:
	  1. Start from orig_table's current names.
	  2. Apply the axis formula from prog, if there is one.
	  3. Let any individually specified names override single entries.
	Finally every entry is forced and the plain values are returned.
	"""
	def get_initial_la():
		"""Return a new LazyArray holding orig_table's current names"""
		if isrow: pyname_list = orig_table.rownames
		else: pyname_list = orig_table.colnames
		la = lazyarray.LazyArray(len(pyname_list))
		la.set_values_precomputed(pyname_list)
		return la

	def apply_axis_formula(initial_la, result_pointer):
		"""Return the LazyArray produced by the axis formula, or initial_la

		The returned LazyArray is always one this function owns, so the
		caller may modify it in place.
		"""
		axis_ps_index = 0 if isrow else 1
		rowcol_ps = prog.get_table_axis_propsets(ar)[axis_ps_index]
		if not rowcol_ps or not rowcol_ps.ast: return initial_la
		relrefinfo = relref.TableAxisNames(las.get_root_la(), ar,
				len(initial_la), axis_name, axis, result_pointer)
		axis_eval = eval(las, rowcol_ps.ast, relrefinfo)
		assert isinstance(axis_eval, objects.ExArray), axis_eval
		# Copy, as label() does: the formula may evaluate to an existing
		# array (e.g. a bare name reference), and apply_specific_names
		# would otherwise overwrite entries of that array as well.
		return axis_eval.obj.copy()

	def apply_specific_names(la):
		"""Overwrite entries of la with names given explicitly in prog"""
		if isrow: lookup_func = prog.get_table_rowname_ps
		else: lookup_func = prog.get_table_colname_ps
		for i in range(len(la)):
			ps = lookup_func(ar, i+1) # prog indices start at 1
			if ps and ps.name is not None:
				# The bound method itself is stored as the thunk; it is
				# called when the names are forced below.
				la.set_value(i, ps.get_exobj_name)

	result_pointer = [None]
	la_result = apply_axis_formula(get_initial_la(), result_pointer)
	apply_specific_names(la_result)
	# Point result_pointer (given to TableAxisNames) at the final names
	# before forcing the thunks, presumably so relative references made by
	# the axis formula resolve against this result.
	result_pointer[0] = objects.ExArray(la_result)
	return [thunk() for thunk in la_result.values]

def table_apply_column_formulas(las, table, prog, ar):
	"""Replace the value thunks of every column that has a column formula

	A scalar formula result fills the whole column. An array result
	supplies one thunk per row, cut off at the table's row count.
	"""
	def column_thunks(value):
		"""Return the per-row thunk list for a column formula's value"""
		if value.isscalar():
			def constant(): return value
			return [constant] * table.get_num_rows()
		assert isinstance(value, objects.ExArray), value
		return value.obj.values[:table.get_num_rows()]

	for col in range(table.get_num_cols()):
		col_ps = prog.get_table_colname_ps(ar, col + 1) # prog is 1-based
		if not (col_ps and col_ps.ast):
			continue
		col_info = relref.TableColumns(las.get_root_la(), ar, table, col)
		thunks = column_thunks(eval(las, col_ps.ast, col_info))
		for row, thunk in enumerate(thunks):
			table.set_value_thunk(row, col, thunk)

def table_apply_value_formulas(las, table, prog, table_ar):
	"""Wrap every cell's value thunk to apply cell formulas and relrefs

	Each cell gets a new thunk. It evaluates the cell's own formula from
	prog if there is one, otherwise it calls the cell's previous thunk.
	Either way, the result has the cell's relative-reference info attached.
	"""
	def make_cell_thunk(ps, relref_info, old_thunk):
		"""Return the replacement thunk for a single cell"""
		def cell_thunk():
			has_formula = ps and ps.ast
			if has_formula:
				value = eval(las, ps.ast, relref_info)
			else:
				value = old_thunk()
			# NOTE(review): set_relref mutates value in place. If old_thunk
			# returns a shared object (table_apply_column_formulas hands
			# every row the same scalar), each cell overwrites that one
			# object's relref -- confirm this is intended.
			value.set_relref(relref_info)
			return value
		return cell_thunk

	for row in range(table.get_num_rows()):
		for col in range(table.get_num_cols()):
			# prog indices start at 1
			cell_ps = prog.get_table_value_ps(table_ar, row + 1, col + 1)
			cell_info = relref.TableCellValue(
				las.get_root_la(), table_ar, table, row, col)
			previous = table.get_value_thunk(row, col)
			table.set_value_thunk(
				row, col, make_cell_thunk(cell_ps, cell_info, previous))

def eval_prog(prog):
	"""Evaluate prog from its root node against the default globals"""
	root_ref = program.AbsoluteReference(())
	return eval_absref(default_las, prog, root_ref)

def op_plus_py(arg1, arg2):
	"""Return arg1 + arg2: a sum of ExNums or a concatenation of ExStrings"""
	kinds = (arg1.t, arg2.t)
	assert kinds in (('num', 'num'), ('string', 'string')), kinds
	combined = arg1.obj + arg2.obj
	if kinds[0] == 'num':
		return objects.ExNum(combined)
	return objects.ExString(combined)
op_plus = objects.ExFunc(op_plus_py, (True, True))

def op_minus_py(arg1, arg2):
	"""Return the ExNum arg1 - arg2; both arguments must be ExNums"""
	assert (arg1.t, arg2.t) == ('num', 'num'), (arg1.t, arg2.t)
	difference = arg1.obj - arg2.obj
	return objects.ExNum(difference)
op_minus = objects.ExFunc(op_minus_py, (True, True))

def op_mult_py(arg1, arg2):
	"""Return the ExNum product arg1 * arg2; both arguments must be ExNums"""
	assert (arg1.t, arg2.t) == ('num', 'num'), (arg1.t, arg2.t)
	product = arg1.obj * arg2.obj
	return objects.ExNum(product)
op_mult = objects.ExFunc(op_mult_py, (True, True))

def op_div_py(arg1, arg2):
	"""Return the ExNum quotient arg1 / arg2; both arguments must be ExNums"""
	assert (arg1.t, arg2.t) == ('num', 'num'), (arg1.t, arg2.t)
	quotient = arg1.obj / arg2.obj
	return objects.ExNum(quotient)
op_div = objects.ExFunc(op_div_py, (True, True))

def op_pow_py(arg1, arg2):
	"""Return the ExNum arg1 raised to the power arg2 (both ExNums)"""
	assert (arg1.t, arg2.t) == ('num', 'num'), (arg1.t, arg2.t)
	power = arg1.obj ** arg2.obj
	return objects.ExNum(power)
op_pow = objects.ExFunc(op_pow_py, (True, True))

def op_neg_py(arg):
	"""Return the ExNum negation of the ExNum arg"""
	assert arg.t == 'num', arg.t
	negated = -arg.obj
	return objects.ExNum(negated)
op_neg = objects.ExFunc(op_neg_py, (True,))

def array_py(length):
	"""Return a new ExArray whose size is the ExNum length, truncated to int"""
	assert isinstance(length, objects.ExNum)
	size = int(length.obj)
	return objects.ExArray(lazyarray.LazyArray(size))
array = objects.ExFunc(array_py, (True,))

# Python's builtin len wrapped as an ExFunc ('len' in get_globals_la).
# NOTE(review): the tuple appears to hold one flag per argument (presumably
# whether that argument is forced before the call); len's flag is False --
# confirm against objects.ExFunc.
ex_len = objects.ExFunc(len, (False,))

# The 'table' builtin: wraps extable.make_blank_table with two argument flags.
make_table = objects.ExFunc(extable.make_blank_table, (True, True))

def index_func_py(exarray, index):
	"""Return the element of exarray at the ExNum index

	Positive indices count from the front starting at 1. Negative indices
	count back from the end (-1 is the last element). Index 0 is invalid.
	"""
	well_typed = (isinstance(exarray, objects.ExArray)
		and isinstance(index, objects.ExNum))
	assert well_typed, (exarray, index)
	la = exarray.obj
	pos = int(index.obj)
	size = len(la)
	assert pos != 0 and -size <= pos <= size, (pos, size)
	offset = pos - 1 if pos > 0 else size + pos
	return la.get_thunk(offset)()
index_func = objects.ExFunc(index_func_py, (False, True))

def comp_lt_py(arg1, arg2):
	"""Return an ExBool: whether arg1's value is less than arg2's"""
	outcome = arg1.obj < arg2.obj
	return objects.ExBool(outcome)
comp_lt = objects.ExFunc(comp_lt_py, (True, True))

def comp_gt_py(arg1, arg2):
	"""Return an ExBool: whether arg1's value is greater than arg2's"""
	outcome = arg1.obj > arg2.obj
	return objects.ExBool(outcome)
comp_gt = objects.ExFunc(comp_gt_py, (True, True))

def comp_lte_py(arg1, arg2):
	"""Return an ExBool: whether arg1's value is at most arg2's"""
	outcome = arg1.obj <= arg2.obj
	return objects.ExBool(outcome)
comp_lte = objects.ExFunc(comp_lte_py, (True, True))

def comp_gte_py(arg1, arg2):
	"""Return an ExBool: whether arg1's value is at least arg2's"""
	outcome = arg1.obj >= arg2.obj
	return objects.ExBool(outcome)
comp_gte = objects.ExFunc(comp_gte_py, (True, True))

def comp_eq_py(arg1, arg2):
	"""Return an ExBool: whether the values are equal (meant for scalars)"""
	outcome = arg1.obj == arg2.obj
	return objects.ExBool(outcome)
comp_eq = objects.ExFunc(comp_eq_py, (True, True))

def comp_neq_py(arg1, arg2):
	"""Return an ExBool: whether the values differ (meant for scalars)"""
	outcome = arg1.obj != arg2.obj
	return objects.ExBool(outcome)
comp_neq = objects.ExFunc(comp_neq_py, (True, True))

def op_or_py(bool1, bool2):
	"""Return an ExBool that is true when at least one ExBool argument is"""
	assert (bool1.t, bool2.t) == ('bool', 'bool'), (bool1, bool2)
	either = bool1.obj or bool2.obj
	return objects.ExBool(either)
op_or = objects.ExFunc(op_or_py, (True, True))

def op_and_py(bool1, bool2):
	"""Return an ExBool that is true only when both ExBool arguments are"""
	assert (bool1.t, bool2.t) == ('bool', 'bool'), (bool1, bool2)
	both = bool1.obj and bool2.obj
	return objects.ExBool(both)
op_and = objects.ExFunc(op_and_py, (True, True))

def op_where(array1, array2):
	"""Return the elements of array1 whose matching array2 entry is true

	Both arguments must be ExArrays of the same length, and every entry of
	array2 must evaluate to an ExBool. All of array2 is evaluated first.
	The selected array1 elements keep their keys and their unevaluated
	thunks, in their original order.
	"""
	both_arrays = (isinstance(array1, objects.ExArray)
		and isinstance(array2, objects.ExArray))
	assert both_arrays, (array1, array2)
	source, mask = array1.obj, array2.obj
	assert len(source) == len(mask), (len(source), len(mask))

	chosen = []
	for pos in range(len(mask)):
		flag = mask.get_thunk(pos)()
		assert flag.t == 'bool', (array2, flag)
		if flag.obj:
			chosen.append(pos)

	selected = lazyarray.LazyArray(len(chosen))
	for dest, src in enumerate(chosen):
		selected.set_key(dest, source.get_key(src))
		selected.set_value(dest, source.get_thunk(src))
	return objects.ExArray(selected)


def arith_seq(fromex, toex, interex = None):
	"""Return an ExArray arithmetic sequence starting at fromex, bounded by toex

	fromex and toex are ExNums giving the first element and the end bound.
	If interex (an ExNum) is given, it is the second element of the
	sequence and so fixes the step; otherwise arith_seq_py's default step
	is used. The elements are precomputed ExNums.
	"""
	# The assert messages name the actual arguments; previously they used
	# undefined names, so a failed check raised NameError instead.
	assert isinstance(fromex, objects.ExNum), fromex
	assert isinstance(toex, objects.ExNum), toex
	# Compare against the None default explicitly rather than relying on
	# the truthiness of an ExObject.
	if interex is not None:
		assert isinstance(interex, objects.ExNum), interex
		list_seq = arith_seq_py(fromex.obj, toex.obj, interex.obj)
	else:
		list_seq = arith_seq_py(fromex.obj, toex.obj)

	la = lazyarray.LazyArray(len(list_seq))
	la.set_values_precomputed([objects.ExNum(i) for i in list_seq])
	return objects.ExArray(la)

def arith_seq_py(start, end, mid = None):
	"""Return a list of numbers from start to end, with mid as second element

	This is meant to replicate Haskell arithmetic sequences like
	[start, mid .. end]. For example, if start, mid, and end are 1, 3, and
	10 respectively, the result is [1, 3, 5, 7, 9].

	Without mid, the step is 1, or -1 when start > end. end is included
	only if the sequence lands on it exactly.

	start must differ from end, and mid (if given) must lie strictly
	between them; both rules are checked with assert.
	"""
	if start < end: descending = False
	elif start == end: assert False, (start, end)
	else: descending = True

	if mid is None:
		increment = -1 if descending else 1
	else:
		assert start < mid < end or start > mid > end, (start, mid, end)
		increment = mid - start

	result_list, cur_elem = [], start
	while ((not descending and cur_elem <= end) or
		   (descending and cur_elem >= end)):
		result_list.append(cur_elem)
		cur_elem += increment
	return result_list

def get_globals_la():
	"""Build the LazyArray of builtin names ('array', 'len', 'table')"""
	bindings = [('array', array), ('len', ex_len), ('table', make_table)]
	pairs = tuple((objects.ExString(key), func) for key, func in bindings)
	globals_la = lazyarray.LazyArray(len(pairs))
	globals_la.set_items_precomputed(pairs)
	return globals_la
# Name stack that eval_prog evaluates every program against.
default_las = lazyarray.LazyArrayStack() + get_globals_la()
