%------------------------------------------------------------------------------%
% Copyright (C) 2001 INRIA/IFSIC.
% This file may only be copied under the terms of the GNU Library General
% Public License - see the file License in the Morphine distribution.
% 
% Author : Erwan Jahier <jahier@irisa.fr>
%
% Computes the proof tree.
%
% Note that it only works if the internal events have been generated in 
% the trace. Also, the Morphine parameter collect_arg should be set to `yes'.
%
% Bug: not/1 is not handled (I should do something at `neg_*' events to fix 
% that).
%
 
:- import_module list, string, require, map, stack, term, std_util.

:- type accumulator_type --->  
	ct(
		   % Stack of call numbers. update_call_stack pushes the
		   % call number at call and redo events and pops it at
		   % exit and fail events; filter reads its top to get
		   % the call number of the caller.
		   %
	   stack(call_number),

		   % The current goal path, followed by the two stacks
		   % that update_goal_path uses to restore it at external
		   % events: the first one is pushed at call events and
		   % popped at exit and fail events, the second one is
		   % pushed at the exit of nondet or multi procedures and
		   % popped at redo events.
		   %
	   goal_path_string,
	   stack(goal_path_string),
	   stack(goal_path_string),

		   % Each goal is indexed by its call number. We use
		   % that call number to build a table that
		   % contains the list of the direct successors of the
		   % goal that succeeded. Each call number is paired with
		   % a goal path so that we know in which branch each
		   % success node was produced; this is necessary to be
		   % able to remove success nodes that do not
		   % contribute to the final solution.
		   %
	   goal_successors_table, 

		   % List of the call numbers of the calls to all solutions
		   % predicates (see is_a_all_solutions_predicate). Such
		   % predicates need to be handled differently because the
		   % goal they search the solutions of ends up failing
		   % although it has produced solutions.
	   list(call_number),

		   % Table of proof trees indexed by call numbers.
	   proof_tree_table
	   ).

% Result built by post_process: either an error message (pb) or the
% proof tree that was computed (ct).
:- type collected_type ---> pb(string) ; ct(proof_tree).

% A call number paired with a goal path; filter builds such a pair at
% exit events from the call number of the exiting goal and the current
% goal path.
:- type call_path == pair(call_number, goal_path_string).

% Maps the call number of a goal to the call_paths of its direct
% successors that succeeded.
:- type goal_successors_table == map(call_number, list(call_path)).

% A called procedure: its name, its arity and its arguments.
:- type predicate ---> p(proc_name, arity, list(univ)).

% A node of the proof tree: a goal that succeeded without any recorded
% successor (leaf), or a goal together with the proof trees of its
% successors (tree).
:- type proof_tree ---> 
	leaf(call_number, predicate)
    ;   tree(call_number, predicate, list(proof_tree)).

% Maps the call number of a goal to the proof trees built at its exit
% events.
:- type proof_tree_table == map(call_number, list(proof_tree)).

%-----------------------------------------------------------------------%

% initialize(Acc): Acc is the accumulator the monitor starts from.
%
% The call stack and the successors table both start with an entry for
% call number 0, so that a father entry is available for the outermost
% traced call (traced call numbers presumably start at 1 -- confirm).
% The goal path is empty, and so are the two goal path stacks, the list
% of all solutions calls and the proof tree table.
initialize(Acc) :-
	stack__init(EmptyCallStack),
	stack__push(EmptyCallStack, 0, CallStack),
	map__init(EmptySuccTable),
	map__det_insert(EmptySuccTable, 0, [], SuccTable),
	stack__init(DetPaths),
	stack__init(NonDetPaths),
	map__init(ProofTrees),
	Acc = ct(CallStack, "", DetPaths, NonDetPaths, SuccTable, [],
		ProofTrees).

% filter(Event, Acc0, Acc): folds one trace event into the accumulator.
%
% The goal path and the call stack are first brought up to date for
% Event (update_goal_path, update_call_stack). The successors table,
% the list of all solutions calls and the proof tree table are then
% updated according to the port of Event:
%
%   call      an empty successors list is registered for the new call,
%             which is also recorded if it calls an all solutions
%             predicate;
%   exit      the call is added to the successors of its father, and
%             its proof tree is built from the proof trees of its own
%             successors;
%   fail      the proof trees of the successors of the call are
%             removed, and the call is removed from the successors of
%             its father;
%   disj,
%   ite_else  the successors recorded in the branch that has just been
%             left are removed, together with their proof trees;
%   redo      the successors of the father recorded from the current
%             call onwards are removed, together with their proof
%             trees.
%
% At fail, disj, ite_else and redo events nothing is removed when
% FatherCN is the call number of an all solutions predicate. At all
% the other ports, only the goal path is updated.
filter(Event, Acc0, Acc) :-
	Acc0 = ct(CallStack0, GoalPath0, GpDetStack0, GpNonDetStack0,
		GoalSuccT0, AllSolList0, ProofT0),
	Port = port(Event),
	CurrentCN = call(Event),
	Arg = arguments(Event),
	CurrentPred = p(proc_name(Event), proc_arity(Event), Arg),
	update_goal_path(Event, GoalPath0, GoalPath, GpDetStack0, GpDetStack,
		GpNonDetStack0, GpNonDetStack),
	update_call_stack(Port, CurrentCN, CallStack0, CallStack),
	(
		Port = call
	->
		map__det_insert(GoalSuccT0, CurrentCN, [], GoalSuccT),
		( is_a_all_solutions_predicate(proc_name(Event)) ->
		      AllSolList = [CurrentCN | AllSolList0]
		;
		      AllSolList = AllSolList0
		),
		ProofT = ProofT0
	;
		Port = exit
	->
		% CallStack has already been popped, so its top is the
		% caller of the exiting goal.
		stack__top_det(CallStack, FatherCN),
		map__det_remove(GoalSuccT0, FatherCN, CNListF0, GoalSuccT1),
		list__merge_and_remove_dups([CurrentCN - GoalPath], CNListF0,
			CNListF),
		map__det_insert(GoalSuccT1, FatherCN, CNListF, GoalSuccT),
		%
		% Update the proof tree table. To do that, we retrieve the
		% list of successors of the current goal, we retrieve their
		% proof trees, and then we build the proof tree of the
		% current goal.
		%
		map__lookup(GoalSuccT, CurrentCN, ListCP),
		%
		% We remove the goal path from that list of call_path.
		%
		list__map(remove_goal_path, ListCP, ListCN),

		(  ListCN = [] ->
			% The current goal is a success leaf
			% (i.e., it has no child)
			ProofTreeCGoal = leaf(CurrentCN, CurrentPred)
		;
			list__map(map__lookup(ProofT0), ListCN, ListListTrees),
			% If the goal is an all solutions predicate,
			% ListListTrees should contain the list of the
			% solutions of the goal it calls; otherwise, it
			% should contain the list of the solutions of its
			% children.
			list__condense(ListListTrees, ListTrees),
			ProofTreeCGoal =
				tree(CurrentCN, CurrentPred, ListTrees)
		),
		( map__search(ProofT0, CurrentCN, ProofTL0) ->
			( member(FatherCN, AllSolList0) ->
				append([ProofTreeCGoal], ProofTL0, ProofTL),
				map__set(ProofT0, CurrentCN, ProofTL, ProofT)
			;
				% If a goal is not called from an all solutions
				% predicate, we only keep the current solution.
				map__set(ProofT0, CurrentCN,
					[ProofTreeCGoal], ProofT)
			)
		;
			map__set(ProofT0, CurrentCN, [ProofTreeCGoal], ProofT)
		),
		AllSolList = AllSolList0
	;
		Port = fail
	->
		% As at exit events, the caller is the top of the stack
		% once the failing goal has been popped (CallStack).
		% Before that pop (CallStack0), the top of the stack is the
		% failing goal itself: the all solutions test below would
		% then be applied to the wrong call, and the current goal
		% would be filtered out of its own successors list instead
		% of the one of its father.
		stack__top_det(CallStack, FatherCN),
		( member(FatherCN, AllSolList0) ->
			% For all solutions procs (such as
			% std_util__solutions/2) we must not remove the
			% previous solutions of the goal. Indeed, all
			% solutions predicates call the goal they search all
			% the solutions of until it fails; removing here
			% would therefore always remove those solutions.
			GoalSuccT = GoalSuccT0,
			ProofT = ProofT0
		;
			map__lookup(GoalSuccT0, CurrentCN, SuccList),
			remove_proof_trees(ProofT0, SuccList, ProofT),
			map__lookup(GoalSuccT0, FatherCN, SuccListFather0),
			list__filter(remove_current_cn(CurrentCN),
				SuccListFather0, SuccListFather),
			map__det_update(GoalSuccT0, FatherCN, SuccListFather,
				GoalSuccT)
		),
		AllSolList = AllSolList0
	;
		( Port = disj ; Port = ite_else )
	->
		%
		% We need to remove the solutions of the previous (failing)
		% branch of the disjunction, i.e., we remove from the
		% successors table all the branches from the father goal of
		% the previous disjunction. We also have to remove the
		% proof trees of those goals.
		%
		% We also need to remove the solutions produced in the if
		% branch for the same reasons as above (actually, it is
		% natural since `if A then B else C' <=> `A, !, B ; C'.).
		%
		% The call stack is left untouched at internal events, so
		% its top is the call the event belongs to.
		stack__top_det(CallStack, FatherCN),
		( member(FatherCN, AllSolList0) ->
		    GoalSuccT = GoalSuccT0,
		    ProofT = ProofT0
		;
		    ( Port = disj ->
			remove_sol_disj_branch(GoalPath, CurrentCN,
				GoalSuccT0, GoalSuccT, ProofT0, ProofT)
		    ;
			remove_sol_then_branch(GoalPath, CurrentCN,
				GoalSuccT0, GoalSuccT, ProofT0, ProofT)
		    )
		),
		AllSolList = AllSolList0
	;
		Port = redo
	->
		% Deterministic procs won't produce any fail event on
		% backtracking; therefore we need to remove solutions here
		% too. CallStack0 is the stack before the current goal is
		% pushed again, so its top is the caller.
		stack__top_det(CallStack0, FatherCN),
		( member(FatherCN, AllSolList0) ->
		    GoalSuccT = GoalSuccT0,
		    ProofT = ProofT0
		;
		    remove_sol_at_redo_port(CurrentCN, FatherCN,
			    GoalSuccT0, GoalSuccT, ProofT0, ProofT)
		),
		AllSolList = AllSolList0
	;
		% Bug XXX I probably should do something at ports
		% neg_enter, neg_success, and neg_failure.
		%
		GoalSuccT = GoalSuccT0,
		ProofT = ProofT0,
		AllSolList = AllSolList0
	),
	Acc = ct(CallStack, GoalPath, GpDetStack, GpNonDetStack,
		GoalSuccT, AllSolList, ProofT).


:- pred is_a_all_solutions_predicate(string::in) is semidet.
is_a_all_solutions_predicate("solutions").
is_a_all_solutions_predicate("std_util__solutions").
is_a_all_solutions_predicate("solutions_set").
is_a_all_solutions_predicate("std_util__solutions_set").
is_a_all_solutions_predicate("unsorted_solutions").
is_a_all_solutions_predicate("std_util__unsorted_solutions").
is_a_all_solutions_predicate("aggregate").
is_a_all_solutions_predicate("std_util__aggregate").
is_a_all_solutions_predicate("aggregate2").
is_a_all_solutions_predicate("std_util__aggregate2").
is_a_all_solutions_predicate("unsorted_aggregate").
is_a_all_solutions_predicate("std_util__unsorted_aggregate").
is_a_all_solutions_predicate("do_while").
is_a_all_solutions_predicate("std_util__do_while").


:- pred remove_goal_path(call_path, call_number).
:- mode remove_goal_path(in, out) is det.
remove_goal_path(CallNumber - _GoalPath, CallNumber).
	

:- pred remove_sol_at_redo_port(call_number, call_number,
	goal_successors_table, goal_successors_table,
	proof_tree_table, proof_tree_table).
:- mode remove_sol_at_redo_port(in, in, in, out, in, out) is det.
remove_sol_at_redo_port(CurrentCN, FatherCN, GoalSuccT0, GoalSuccT, 
    ProofT0, ProofT) :-
	% Get the successors list of the father of the current goal.
	map__lookup(GoalSuccT0, FatherCN, SuccList),
	list__filter(call_number_is_bigger(CurrentCN), SuccList, NewSuccList),
	map__det_update(GoalSuccT0, FatherCN, NewSuccList, GoalSuccT),
	% We also remove the corresponding proof trees
	list__delete_elems(SuccList, NewSuccList, FailList),
	remove_proof_trees(ProofT0, FailList, ProofT).

:- pred call_number_is_bigger(call_number, call_path).
:- mode call_number_is_bigger(in, in) is semidet.
call_number_is_bigger(CN1, CN2 - _GoalPath2) :-
	CN1 > CN2.


:- pred remove_sol_disj_branch(goal_path_string, call_number, 
	goal_successors_table, goal_successors_table,
	proof_tree_table, proof_tree_table).
:- mode remove_sol_disj_branch(in, in, in, out, in, out) is det.
remove_sol_disj_branch(GoalPath, CurrentCN, GoalSuccT0, GoalSuccT, 
	ProofT0, ProofT) :-
	( if get_goal_path_previous_disj(GoalPath, GoalPathPrev) then 
		% Get the successors list of the father of the current goal.
		map__lookup(GoalSuccT0, CurrentCN, SuccList),
		% From this list of successors, we want to remove all the 
		% call_path that were in the previous disjunction, as well as 
		% the solutions that were produced after; to do 
		% that, we use the goal path information.
		list__filter(goal_path_is_after(GoalPathPrev), 
			SuccList, NewSuccList),
		map__det_update(GoalSuccT0, CurrentCN, NewSuccList, GoalSuccT),
		% We also remove the corresponding proof trees
		list__delete_elems(SuccList, NewSuccList, FailList),
		remove_proof_trees(ProofT0, FailList, ProofT)
	  else
		GoalSuccT = GoalSuccT0,
		ProofT = ProofT0
	).

:- pred remove_sol_then_branch(goal_path_string, call_number, 
	goal_successors_table, goal_successors_table,
	proof_tree_table, proof_tree_table).
:- mode remove_sol_then_branch(in, in, in, out, in, out) is det.
remove_sol_then_branch(GoalPath, CurrentCN, GoalSuccT0, GoalSuccT, 
	ProofT0, ProofT) :-
	( if string__remove_suffix(GoalPath, "e;", GoalPathIfBranch) then
		append(GoalPathIfBranch, "t;", GoalPathThenBranch),
		% Get the succesors list of the current of the current goal.
		map__lookup(GoalSuccT0, CurrentCN, SuccList),
		% From this list of successors, we want to remove all the 
		% call_path that were in the then branch, as well as 
		% the solutions that were produced after; to do 
		% that, we use the goal path information.
		list__filter(goal_path_is_after(GoalPathThenBranch), 
			SuccList, NewSuccList),
		map__det_update(GoalSuccT0, CurrentCN, NewSuccList, GoalSuccT),
		% We also remove the corresponding proof trees
		list__delete_elems(SuccList, NewSuccList, FailList),
		remove_proof_trees(ProofT0, FailList, ProofT)
	  else
		error("Not an ite_else event")
	).


:- pred remove_proof_trees(proof_tree_table, list(call_path), 
	proof_tree_table).
:- mode remove_proof_trees(in, in, out) is det.
remove_proof_trees(ProofT0, [], ProofT0).
remove_proof_trees(ProofT0, [CN-_|Tail], ProofT) :-
	map__det_remove(ProofT0, CN, _, ProofT1),
	remove_proof_trees(ProofT1, Tail, ProofT).


% goal_path_is_after(GoalPath1, _ - GoalPath2) succeeds iff GoalPath1
% occurs <<after>> GoalPath2 in the goal. Both paths are split into
% their ";"-separated components, and the decision is left to
% goal_path_is_after2.
%
% It is used as a filter when a new branch is entered: the successors
% for which it succeeds against the path of the branch that was left
% are the ones that are kept.
%
% Ex: goal_path_is_after("d3;c4;d3;d3", _ - "d3;c4;d2;t") succeeds,
%     since the paths first differ on "d3" versus "d2".
:- pred goal_path_is_after(goal_path_string::in, call_path::in) is semidet.
goal_path_is_after(GoalPath1, _CN2 - GoalPath2) :-
	goal_path_is_after2(
		string__words(is_semi_column, GoalPath1),
		string__words(is_semi_column, GoalPath2)).
	       
:- pred is_semi_column(char::in) is semidet.
is_semi_column(';').

:- pred goal_path_is_after2(list(string)::in, list(string)::in) is semidet.
goal_path_is_after2([Path1|Tail1], [Path2|Tail2]) :-
	( Path1 = Path2 ->
		goal_path_is_after2(Tail1, Tail2)
	    ;
		atomic_path_is_after(Path1, Path2)
	).

:- pred atomic_path_is_after(string::in, string::in) is semidet.
% atomic_path_is_after(Path1, Path2) compares two goal path components
% (the strings found between the ";" of a goal path). It succeeds iff
% Path1 comes after Path2, i.e., iff one of the following holds:
%   - Path2 is empty;
%   - both components are "c" (or both are "d") followed by an integer,
%     and the integer of Path1 is strictly greater than the one of
%     Path2;
%   - Path1 is "e" or "t" and Path2 is "?".
atomic_path_is_after(Path1, Path2) :- 
	(
		Path2 = "" 
	->
		% "s2;d2;d4;" is after "s2;d2;"
		% NOTE(review): goal_path_is_after builds the components
		% with string__words, which presumably never returns empty
		% strings; if so, this branch is not reached from there --
		% confirm.
		true
	;
		% Look for a common prefix L, restricted below to "c" or
		% "d"; what follows it in each component is then read as an
		% integer.
		append(L, Int1Str, Path1),
		append(L, Int2Str, Path2),
		( L = "c" ; L = "d"),
		% NOTE(review): det_to_int presumably aborts, rather than
		% fails, if the suffix is not an integer -- confirm.
		string__det_to_int(Int1Str) = Int1,
		string__det_to_int(Int2Str) = Int2
	->
		Int1 > Int2
	;
		% The then and the else branches come after the condition.
		( Path1 = "e" ; Path1 = "t"), Path2 = "?"
	).


:- pred remove_current_cn(call_number,  call_path).
:- mode remove_current_cn(in, in) is semidet.
remove_current_cn(N1, N2 - _) :-
	not (N1 = N2).


:- pred get_goal_path_previous_disj(goal_path_string, goal_path_string).
:- mode get_goal_path_previous_disj(in, out) is semidet.
% Takes a goal path of the form "<path>;di;" (where i>0), and outputs the 
% goal path "<path>;di-1;" if i>1, fails otherwise
get_goal_path_previous_disj(GoalPath, GoalPathPrev) :-
	reverse_string(GoalPath, GoalPathRev),
	( if 
		string__sub_string_search(GoalPathRev, "d", Index1)
	  then
		Index = Index1 
	  else 
		error("No disjunction is found in this string")
	),
	string__split(GoalPathRev, Index, NRevStr0, GoalPathTailRev),
	% remove the ";" 
	string__append(";", NRevStr, NRevStr0),
	% check that it is not the first disjunction
	not (NRevStr = "1"),
	reverse_string(GoalPathTailRev, GoalPathTail),
	reverse_string(NRevStr, NStr),
	( if string__to_int(NStr, N0) then
		NPrev is N0 - 1
	  else
		error("Should be an integer")
	),
	string__int_to_string(NPrev, NPrevStr),
	string__append_list([GoalPathTail, NPrevStr, ";"], GoalPathPrev).

:- pred reverse_string(string, string).
:- mode reverse_string(in, out) is det.
reverse_string(String, StringRev) :-
	% Is this the most efficient way to do that?
	string__to_char_list(String, Char),
	string__from_rev_char_list(Char, StringRev). 
	

:- pred update_goal_path(event::in, 
	goal_path_string::in, goal_path_string::out, 
						% To get rigth goal path
						% at call ports.
	stack(goal_path_string)::in, stack(goal_path_string)::out,
						% To get rigth goal path
						% at exit and fail ports.
	stack(goal_path_string)::in, stack(goal_path_string)::out
						% To get rigth goal path
						% at redo ports.
    ) is det.
	% For efficiency, this should be interwined inside filter. I did not
	% do that for the sake of clarity and reusability.
	% Hopefully, the deforestration optimisation should be able to
	% work here.
update_goal_path(Event, GoalPath0, GoalPath, DetStack0, DetStack, 
    NonDetStack0, NonDetStack) :-
	Port = port(Event),
	Det = determinism(Event),
	(
		Port = call
	->
		stack__push(DetStack0, GoalPath0, DetStack),
		NonDetStack = NonDetStack0,
		GoalPath = GoalPath0
	;
		Port = exit
	->
		stack__pop_det(DetStack0, GoalPath, DetStack),
		( nondet_or_multi(Det) ->
			% Only nondet and multi proc have redo ports
			stack__push(NonDetStack0, GoalPath, NonDetStack)
		;
			NonDetStack = NonDetStack0
		)
	;
		Port = redo
	->
		stack__pop_det(NonDetStack0, GoalPath, NonDetStack),
		stack__push(DetStack0, GoalPath, DetStack)
	;
		Port = fail
	->
		stack__pop_det(DetStack0, GoalPath, DetStack),
		NonDetStack = NonDetStack0
	;
		% At internal events, the goal path is available
		GoalPath = goal_path(Event),
		NonDetStack = NonDetStack0,
		DetStack = DetStack0
	).

:- pred nondet_or_multi(determinism::in) is semidet.
nondet_or_multi(3). % nondet
nondet_or_multi(7). % multi


:- pred update_call_stack(trace_port_type, call_number, stack(call_number), 
	stack(call_number)).
:- mode update_call_stack(in, in, in, out) is det.

update_call_stack(Port, Current, CallStack0, CallStack) :-
	( 
		( Port = call ; Port = redo )
	->
		stack__push(CallStack0, Current, CallStack)
	;
		( Port = exit ; Port = fail )
	->
		stack__pop_det(CallStack0, _, CallStack)
	;
		CallStack = CallStack0
	).


% post_process(Acc, Result): extracts the result from the final
% accumulator. If the first entry of the proof tree table is for call
% number 1 and holds exactly one proof tree, Result is that tree
% (wrapped in ct); otherwise Result is an error message (wrapped in pb).
post_process(Acc, Result) :-
	Acc = ct(_, _, _, _, _, _, ProofTreeTable),
	map__to_assoc_list(ProofTreeTable, Entries),
	( if Entries = [1 - [RootTree] | _] then
		Result = ct(RootTree)
	else
		% Should not occur
		Result = pb("map__to_assoc_list failed \n")
	).
