/* Copyright (c) 1993 by The Johns Hopkins University */
 
/* 

 PEBLS:  Parallel Exemplar-Based Learning System

 For more information, contact:				

 Steven Salzberg  (salzberg@cs.jhu.edu)
 Dept. of Computer Science
 Johns Hopkins University
 Baltimore, MD  21210


 Code written by:

 John N. Rachlin
 Dept. of Computer Science
 Johns Hopkins University
 Baltimore, MD  21210

*/



/* PEBLS.C:  Core routines for invoking various system modules */
/*           Provides routines for training and testing.       */



#include <stdio.h>
#include "config.h"
#include "pebls.h"



/* Run configuration.  Routines in this file read fields such as      */
/* classes, features, instances, trials, nearest_neighbor and the      */
/* training / voting / weighting modes from it.                        */
config_type   	CONFIG;   			     
/* Instance table, indexed by instance number.  Fields used in this    */
/* file: class_true, class_nearest, value[], trained, weighted, id.    */
instance_type 	data[INSTANCES_MAX];		     
/* count[c][f][v]: incremented by train_instance() for every trained   */
/* instance of class c having value v for feature f.  The extra row    */
/* at index CONFIG.classes receives the same increment for every       */
/* class, i.e. it accumulates the total over all classes.              */
int	 	count[CLASSES_MAX+1][FEATURES_MAX][VALUES_MAX];
/* Not referenced by the routines visible in this file; presumably the */
/* per-feature value-difference tables filled in by                    */
/* build_distance_tables() -- TODO confirm.                            */
float	 	dtable[FEATURES_MAX][VALUES_MAX][VALUES_MAX];  
/* Not referenced directly in this file; presumably per-class,         */
/* per-trial results maintained by the output routines -- TODO confirm.*/
output_type 	output[CLASSES_MAX+1][TRIALS_MAX];   



/* ------------------------------------------------------------ */
/* NEAREST_THRESHOLD_VOTE: Pick a class using per-slot vote     */
/*    thresholds applied in precedence order.                   */
/* INPUTS: k = # neighbors                                      */
/*         nearest = indices (into data[]) of the neighbors     */
/* OUTPUT: The class CONFIG.precedence[i] of the LAST slot i    */
/*         whose neighbor count reaches CONFIG.threshold[i].    */
/*         If no slot qualifies, the class with the most        */
/*         neighbor votes (lowest class number on ties) is      */
/*         returned instead of an indeterminate value.          */

int nearest_threshold_vote(int k, int nearest[])
{
    int class_count[CLASSES_MAX];
    int classes = CONFIG.classes;
    int class_nearest = 0;
    int i, p;

    for (i = 0; i < classes; i++)
        class_count[i] = 0;

    /* Tally how many of the k neighbors belong to each class. */
    for (i = 0; i < k; i++)
        class_count[data[nearest[i]].class_true]++;

    /* Fallback answer: the plurality class.  Previously the result */
    /* was left uninitialized when no threshold was met.            */
    for (i = 1; i < classes; i++)
    {
        if (class_count[i] > class_count[class_nearest])
            class_nearest = i;
    }

    /* Later precedence slots override earlier ones (no early exit). */
    /* NOTE(review): the threshold is indexed by precedence slot i,  */
    /* not by class p -- confirm this matches how CONFIG.threshold   */
    /* is filled in.                                                 */
    for (i = 0; i < classes; i++)
    {
        p = CONFIG.precedence[i];
        if (class_count[p] >= CONFIG.threshold[i])
            class_nearest = p;
    }

    return (class_nearest);
}








/* ------------------------------------------------------------ */
/* NEAREST_WEIGHTED_DISTANCE_VOTE:  Give each neighbor a vote   */
/*    inversely proportional to its distance.                   */
/* INPUTS: k = # neighbors                                      */
/*         nearest = indices (into data[]) of the neighbors     */
/*         distances = distance of each neighbor, parallel to   */
/*                     nearest[]                                */
/* OUTPUT: The class with the largest non-zero summed vote      */
/*         (lowest class number on ties); class 0 if no class   */
/*         received a non-zero vote.                            */

int nearest_weighted_distance_vote(int k, int nearest[], float distances[])
{
    float class_distance[CLASSES_MAX];
    int i, classes = CONFIG.classes;
    int vote_class = 0;		/* defined result even if nobody votes */
    float vote_size = -1.0;

    for (i = 0; i < classes; i++)
        class_distance[i] = 0.0;

    /* Accumulate over the k NEIGHBORS.  The loop used to run to    */
    /* `classes`, which read too few or too many neighbor entries   */
    /* whenever k differed from the number of classes.              */
    /* NOTE(review): a neighbor at distance 0 divides by zero here; */
    /* with IEEE arithmetic that yields an infinite vote -- confirm */
    /* this is the intended handling of exact matches.              */
    for (i = 0; i < k; i++)
        class_distance[data[nearest[i]].class_true] += 1/distances[i];

    for (i = 0; i < classes; i++)
    {
        if ((class_distance[i] > vote_size) &&
            (class_distance[i] != 0.0))
        {
            vote_size = class_distance[i];
            vote_class = i;
        }
    }

    return (vote_class);
}


/* ------------------------------------------------------------ */
/* NEAREST_MAJORITY_VOTE: Find the majority class in the set    */
/*    of nearest neighbors.                                     */
/* INPUTS: k = # neighbors                                      */
/*         nearest = indices (into data[]) of the neighbors     */
/* OUTPUT: The class holding the most neighbors.  Classes are   */
/*         scanned in increasing order; a class that merely     */
/*         equals the current best replaces it only when        */
/*         f_random(1.0) exceeds 0.5.                           */

int nearest_majority_vote(int k, int nearest[])
{
    int tally[CLASSES_MAX];
    int n_classes = CONFIG.classes;
    int best, best_votes;
    int c, n;

    for (c = 0; c < n_classes; c++)
        tally[c] = 0;

    /* One vote per neighbor, credited to the neighbor's true class. */
    for (n = 0; n < k; n++)
        tally[data[nearest[n]].class_true] += 1;

    best = 0;
    best_votes = tally[0];

    for (c = 1; c < n_classes; c++)
    {
        int take;

        if (tally[c] > best_votes)
            take = 1;
        else if (tally[c] == best_votes)
            take = (f_random(1.0) > 0.500);	/* random tie-break */
        else
            take = 0;

        if (take)
        {
            best_votes = tally[c];
            best = c;
        }
    }

    return (best);
}


/* ------------------------------------------------------------ */
/* NEAREST_VOTE:  Choose a class for the set of nearest         */
/*    neighbors using the scheme in CONFIG.nearest_voting.      */
/* INPUTS: k = # neighbors                                      */
/*         nearest_list = indices of the nearest neighbors      */
/*         distances = their distances (used only by the        */
/*                     WEIGHTED_DISTANCE scheme)                */
/* OUTPUT: The winning class.  An unrecognized voting mode      */
/*         falls back to the majority vote rather than          */
/*         returning an uninitialized value.                    */

int nearest_vote(int k, int nearest_list[], float distances[])
{
    int nearest_class;

    switch (CONFIG.nearest_voting)
    {
        case WEIGHTED_DISTANCE:
            nearest_class = nearest_weighted_distance_vote(k, nearest_list, distances);
            break;

        case THRESHOLD:
            nearest_class = nearest_threshold_vote(k, nearest_list);
            break;

        case MAJORITY:
        default:		/* unknown mode: use the majority vote */
            nearest_class = nearest_majority_vote(k, nearest_list);
            break;
    }

    return (nearest_class);
}







/* ------------------------------------------------------------ */
/* F_MINIMUM: Locate the smallest element of a float array.     */
/* INPUTS: A = array to scan                                    */
/*         n = number of elements                               */
/* OUTPUT: Index of the first element holding the smallest      */
/*         value, or -1 when no element is below INFINITY       */
/*         (which includes n <= 0).                             */

int f_minimum(float A[], int n)
{
    float best = INFINITY;
    int best_at = -1;
    int pos = 0;

    while (pos < n)
    {
        /* Strict comparison keeps the earliest index on ties. */
        if (A[pos] < best)
        {
            best_at = pos;
            best = A[pos];
        }
        pos++;
    }

    return (best_at);
}


/* ------------------------------------------------------------ */
/* PARTITION: Rearrange A[p..r] around the pivot value A[p].    */
/* INPUTS: A = array (modified in place)                        */
/*         p, r = first and last index of the range             */
/* OUTPUT: The index j at which the two scans met.  Every       */
/*         exchange moves a value >= pivot from the low side    */
/*         and a value <= pivot from the high side.             */

int partition(float A[], int p, int r)
{
    float pivot = A[p];
    float swap;
    int lo = p - 1;
    int hi = r + 1;

    for (;;)
    {
        /* Scan down from the top for a value not above the pivot. */
        hi--;
        while (A[hi] > pivot)
            hi--;

        /* Scan up from the bottom for a value not below the pivot. */
        lo++;
        while (A[lo] < pivot)
            lo++;

        if (lo >= hi)
            return (hi);

        swap = A[lo];
        A[lo] = A[hi];
        A[hi] = swap;
    }
}

	

/* ------------------------------------------------------------ */
/* RANDOMIZED_PARTITION: Partition A[p..r] around a randomly    */
/*    chosen pivot.                                             */
/* INPUTS: A = array (modified in place)                        */
/*         p, r = first and last index of the range             */
/* OUTPUT: The split index returned by partition().             */
/* The pivot position is p + i_random(r - p + 1); i_random is   */
/* presumably uniform over 0 .. r-p -- confirm against its      */
/* definition.                                                  */

int randomized_partition(float A[], int p, int r)
{
    int pick = p + i_random(r - p + 1);
    float held = A[pick];

    /* Move the chosen pivot to the front, where partition() expects it. */
    A[pick] = A[p];
    A[p] = held;

    return (partition(A, p, r));
}




/* ------------------------------------------------------------ */
/* RANDOMIZED_SELECT: Return the i-th smallest value (counting  */
/*    from 1) found in A[p..r].                                 */
/* INPUTS: A = array (reordered as a side effect)               */
/*         p, r = first and last index of the range             */
/*         i = rank wanted within the range                     */
/* OUTPUT: The selected value.                                  */
/* Written as a loop: each pass narrows [p..r] to the side of   */
/* the partition that contains the wanted rank.                 */

float randomized_select(float A[], int p, int r, int i)
{
    int split, low_size;

    while (p != r)
    {
        split = randomized_partition(A, p, r);
        low_size = split - p + 1;	/* elements in A[p..split] */

        if (i <= low_size)
        {
            r = split;			/* answer lies in the low part */
        }
        else
        {
            i -= low_size;		/* skip the low part entirely  */
            p = split + 1;
        }
    }

    return A[p];
}
	




/* ------------------------------------------------------------ */
/* ARRAY_COPY: Copy the first `size` floats of A into B.        */
/* INPUTS: A = source array                                     */
/*         B = destination array                                */
/*         size = number of elements; nothing is copied when    */
/*                size <= 0                                     */

void array_copy(float A[], float B[], int size)
{
    /* Walk both arrays front to back, one element at a time. */
    while (size-- > 0)
        *B++ = *A++;
}


/* ------------------------------------------------------------ */
/* FIND_NEAREST: Pick the k smallest entries of distance[].     */
/* INPUTS: k = # of nearest neighbors wanted                    */
/*         distance = distance of every instance                */
/*                    (CONFIG.instances entries)                */
/* OUTPUT: nearest_idx  = indices of the chosen instances       */
/*         nearest_dist = their distances, parallel to          */
/*                        nearest_idx                           */
/* At most k entries are ever written.  Entries strictly below  */
/* the k-th smallest distance are always taken; entries equal   */
/* to it are taken in index order only until k slots are used.  */
/* (The previous `<= d` test could write past the caller's      */
/* buffers when many distances tied at the cut-off.)            */
/* NOTE(review): when fewer than k instances exist, the slots   */
/* beyond CONFIG.instances are left untouched.                  */

void find_nearest(int k, float distance[], int nearest_idx[], 
			 float nearest_dist[])
{
    float d;
    float distance_copy[INSTANCES_MAX];
    int i, j, kth;
    int closer, ties_left;
    int instances = CONFIG.instances;

    if (k == 1)
    {
        kth = f_minimum(distance, instances);

        /* f_minimum() reports -1 when it finds no candidate; fall   */
        /* back to entry 0 instead of indexing distance[-1].         */
        if (kth < 0)
            kth = 0;

        nearest_idx[0] = kth;
        nearest_dist[0] = distance[kth];
        return;
    }

    /* Selection reorders its input, so work on a scratch copy. */
    array_copy(distance, distance_copy, instances);
    d = randomized_select(distance_copy, 0, instances-1, k);

    /* Count the entries strictly inside the cut-off so we know how */
    /* many ties at exactly d may still be admitted.                */
    closer = 0;
    for (i = 0; i < instances; i++)
    {
        if (distance[i] < d)
            closer++;
    }
    ties_left = k - closer;

    for (i = 0, j = 0; (i < instances) && (j < k); i++)
    {
        if (distance[i] < d)
        {
            nearest_idx[j] = i;
            nearest_dist[j] = distance[i];
            j++;
        }
        else if ((distance[i] == d) && (ties_left > 0))
        {
            nearest_idx[j] = i;
            nearest_dist[j] = distance[i];
            j++;
            ties_left--;
        }
    }
}



/* ------------------------------------------------------------ */
/* NEAREST_NEIGHBOR:  Find set of nearest neighbors             */
/* INPUTS:	current      = instance being classified        */
/*		k            = # of nearest neighbors           */
/*		weighting    = passed through to MVDM()         */
/* OUTPUT:	nearest_idx  = indices of the nearest neighbors */
/*		nearest_dist = distances of those neighbors     */
/* An instance is a candidate only if it is trained, is not     */
/* `current`, and -- when CONFIG.exemplar_weighting is          */
/* USED_CORRECT -- has its `weighted` flag set.  All other      */
/* instances get distance INFINITY.                             */

void nearest_neighbor(int current, int k, int nearest_idx[], 
		      float nearest_dist[], int weighting)
{
    float distance[INSTANCES_MAX];
    int n_inst = CONFIG.instances;
    int idx;

    for (idx = 0; idx < n_inst; idx++)
    {
        /* Untrained instances and the query itself never compete. */
        if (!(data[idx].trained) || (idx == current))
        {
            distance[idx] = INFINITY;
            continue;
        }

        /* Under USED_CORRECT, exemplars without the weighted flag */
        /* are excluded as well.                                   */
        if ((CONFIG.exemplar_weighting == USED_CORRECT) &&
            (data[idx].weighted != TRUE))
        {
            distance[idx] = INFINITY;
            continue;
        }

        distance[idx] = MVDM(idx, current, weighting);
    }

    find_nearest(k, distance, nearest_idx, nearest_dist);
}





/* ------------------------------------------------------------ */
/* TRAIN_INSTANCE:  Add instance "i" to exemplar data           */
/* INPUTS:	i = instance to add                             */
/* OUTPUT:	None.  For every feature, the global COUNT      */
/*		entry for the instance's value is incremented   */
/*		both in the row of its true class and in the    */
/*		extra row at index CONFIG.classes; the instance */
/*		is then flagged as trained.                     */

void train_instance(int i)
{
    int n_features = CONFIG.features;
    int total_row = CONFIG.classes;	/* row shared by all classes */
    int true_class = data[i].class_true;
    int feat, val;

    for (feat = 0; feat < n_features; ++feat)
    {
        val = data[i].value[feat];
        count[true_class][feat][val] += 1;
        count[total_row][feat][val] += 1;
    }

    data[i].trained = TRUE;
}







/* ------------------------------------------------------------ */
/*  LEAVE_ONE_OUT:  A training and test method by which         */
/*   each instance is tested after training on all other        */
/*   instances in the data set.                                 */
/*  INPUTS:  None.                                              */
/*  OUTPUT:  None.  Each trial's result is recorded through     */
/*   update_single_output(); when CONFIG.output_mode is         */
/*   COMPLETE one line per trial is printed as well.  Summary   */
/*   output is produced by print_output().                      */

void leave_one_out(void)
{
    int n_trials = CONFIG.trials;
    int n_inst = CONFIG.instances;
    int weight_mode = CONFIG.exemplar_weighting;
    int k = CONFIG.nearest_neighbor;
    int neighbors[K_NEIGHBOR_MAX];
    float neighbor_dist[K_NEIGHBOR_MAX];
    int held_out, other, trial, predicted;

    /* Column headings for the per-trial listing. */
    if (CONFIG.output_mode == COMPLETE)
    {
        printf("\n\n%10s %10s %10s %10s\n", "", "", "TRUE", "PEBLS");
        printf("%10s %10s %10s %10s\n", "TRIAL", "ID", "CLASS", "CLASS");
        printf("%10s %10s %10s %10s\n", "-----", "-----", "-----", "-----");
    }

    for (held_out = 0; held_out < n_inst; held_out++)
    {
        /* Rebuild the exemplar set from scratch without held_out. */
        initialize_training();
        for (other = 0; other < n_inst; other++)
        {
            if (other == held_out)
                continue;
            train_instance(other);
        }
        build_distance_tables();

        for (trial = 0; trial < n_trials; trial++)
        {
            set_exemplar_weights();
            nearest_neighbor(held_out, k, neighbors, neighbor_dist, weight_mode);
            predicted = nearest_vote(k, neighbors, neighbor_dist);
            update_single_output(held_out, predicted, trial);

            if (CONFIG.output_mode != COMPLETE)
                continue;

            printf("%10d %10s %10s %10s \n",
                   trial + 1,
                   data[held_out].id,
                   CONFIG.class_name[data[held_out].class_true],
                   CONFIG.class_name[predicted]);
        }
    }

    print_output();
}



/* ------------------------------------------------------------ */
/* TRAIN_SUBSET:  Train on a random subset of the instance data */
/* INPUT:  None.                                                */
/* OUTPUT: None.                                                */
/* The candidate pool is all CONFIG.instances in SUBSET mode,   */
/* otherwise the first CONFIG.training_instances.  A shuffled   */
/* ordering of the pool is obtained from shuffle() and the      */
/* leading round(pool * CONFIG.training_size) entries are       */
/* trained.                                                     */

void train_subset(void)
{
    float fraction = CONFIG.training_size;
    int order[INSTANCES_MAX];
    int pool, n_train, pos;

    /* How many instances are eligible for selection. */
    pool = (CONFIG.training_mode == SUBSET) ? CONFIG.instances
                                            : CONFIG.training_instances;

    shuffle(order, pool);
    n_train = round (pool * fraction);	/* actual number to train */

    for (pos = 0; pos < n_train; pos++)
        train_instance(order[pos]);
}




/* ------------------------------------------------------------ */
/* TRAIN_SPECIFIED_GROUP:  Train a specified group of instances */
/* INPUTS: None.                                                */
/* OUTPUT: None.                                                */
/* Reports NO_TRAIN_ERR through error() when the group is       */
/* empty.  When CONFIG.training_size is exactly 1.00 the first  */
/* CONFIG.training_instances instances are all trained;         */
/* otherwise the choice is delegated to train_subset().         */

void train_specified_group(void)
{
    int n_train = CONFIG.training_instances;
    float fraction = CONFIG.training_size;
    int idx;

    if (n_train == 0)
        error(NO_TRAIN_ERR, NULL);

    /* Anything other than the full group means a random subset. */
    if (fraction != 1.00)
    {
        train_subset();
        return;
    }

    for (idx = 0; idx < n_train; idx++)
        train_instance(idx);
}




/* ------------------------------------------------------------ */
/* TRAIN:  Train instances. Invoke the proper training routine  */
/*    for CONFIG.training_mode.                                 */
/* INPUTS: None.                                                */
/* OUTPUT: None.  Training state is reset first and             */
/*    build_distance_tables() is called afterwards.  Modes      */
/*    other than SUBSET and SPECIFIED_GROUP train nothing.      */

void train(void)
{
    int mode;

    initialize_training();

    mode = CONFIG.training_mode;
    if (mode == SUBSET)
        train_subset();
    else if (mode == SPECIFIED_GROUP)
        train_specified_group();

    build_distance_tables();
}




/* ------------------------------------------------------------ */
/* TEST:  Test all untrained instances                          */
/* INPUTS:  None.                                               */
/* OUTPUT:  None.  For each instance whose trained flag is      */
/*    FALSE, the voted class of its nearest neighbors is        */
/*    stored in data[].class_nearest; post_process() is called  */
/*    once all instances have been visited.                     */

void test(void)
{
    int neighbors[K_NEIGHBOR_MAX];
    float neighbor_dist[K_NEIGHBOR_MAX];
    int n_inst = CONFIG.instances;
    int k = CONFIG.nearest_neighbor;
    int idx;

    for (idx = 0; idx < n_inst; idx++)
    {
        /* Trained instances are exemplars, not test cases. */
        if (data[idx].trained != FALSE)
            continue;

        nearest_neighbor(idx, k, neighbors, neighbor_dist, CONFIG.exemplar_weighting);
        data[idx].class_nearest = nearest_vote(k, neighbors, neighbor_dist);
    }

    post_process();
}








/* ============================================================ */

main(int argc, char *argv[])
{

    int t, trials, train_once=FALSE, weight_once=FALSE;

    if (argc != 2)  error(USAGE_ERR, NULL);
    else 
    {
      	initialize(argv[1]);
	if (CONFIG.training_mode == LEAVE_ONE_OUT)
	    leave_one_out();
	else
	{
	    trials = CONFIG.trials;

				/* IF TRAINING SET NEVER CHANGES, TRAIN ONCE */

	    if ((CONFIG.training_mode == SPECIFIED_GROUP) &&
		(CONFIG.training_size == 1.00)) 
	    {
	        train();
		train_once = TRUE;
	    }
				/* IF EXEMPLAR WEIGHTS NEVER CHANGE, */
				/* SET WEIGHTS ONCE		     */

	    if (CONFIG.exemplar_weighting == ONE_PASS)
	    {
		set_exemplar_weights();
		weight_once = TRUE;
	    }


	    for (t=0; t<trials; t++)
	    {
		if (train_once == FALSE) train();
		if (weight_once == FALSE) set_exemplar_weights();
	        test();
		update_output(t);
	    }

	    print_output();
	}
    }
}



