/* Copyright (c) 1993 by The Johns Hopkins University */
 


 
/* NEIGHBOR.C:  Routines for finding the k nearest neighbors, and */
/*   for voting based on those neighbors			*/



/* ------------------------------------------------------------ */
/* NEAREST_THRESHOLD_VOTE					*/


#include "config.h"
#include "pebls.h"

extern config_type CONFIG;
extern instance_type data[INSTANCES_MAX];

/* Threshold voting over the classes of the k nearest neighbors.      */
/* The class counts are tallied, then CONFIG.precedence is scanned   */
/* from its last entry to its first; the first class                 */
/* p = CONFIG.precedence[i] whose count reaches CONFIG.threshold[i]  */
/* is returned, so later precedence entries take priority.           */
/* INPUTS:  k       = # neighbors                                    */
/*          nearest = indices (into data[]) of the nearest neighbors */
/* OUTPUT:  the selected class.  If no class reaches its threshold,  */
/*          CONFIG.precedence[0] is returned so the result is always */
/*          defined.  NOTE(review): confirm this default is intended.*/

int nearest_threshold_vote(int k, int nearest[])
{
    int class_count[CLASSES_MAX];
    int i, p, classes = CONFIG.classes;

    for (i=0; i<classes; i++)
	class_count[i] = 0;

    for (i=0; i<k; i++)
	class_count[data[nearest[i]].class_true]++;

    /* Thresholds are indexed by precedence position i, counts by class p. */
    for (i=classes-1; i>=0; i--)
    {
	p = CONFIG.precedence[i];
	if (class_count[p] >= CONFIG.threshold[i])
	    return (p);
    }

    /* No class met its threshold: return a defined default instead of */
    /* an uninitialized value.                                          */
    return (CONFIG.precedence[0]);
}








/* ------------------------------------------------------------ */
/* NEAREST_WEIGHTED_DISTANCE_VOTE_K:  Give each of the k nearest */
/* neighbors a vote of 1/distance for its class and return the   */
/* class with the largest total.                                 */
/* INPUTS: k         = # neighbors                               */
/*         nearest   = indices (into data[]) of the neighbors    */
/*         distances = distance to each neighbor                 */
/* OUTPUT: The winning class (the lowest-numbered one on an exact */
/*         tie), or class 0 if no class has a non-zero total      */
/*         (e.g. k == 0).                                         */
/* NOTE(review): a distance of exactly 0 evaluates 1/0; this      */
/* assumes IEEE-754 floats, where that yields +infinity so an     */
/* exact match dominates the vote.                                */

static int nearest_weighted_distance_vote_k(int k, int nearest[],
					    float distances[])
{
    float class_weight[CLASSES_MAX];
    int i, classes = CONFIG.classes;
    int vote_class = 0;
    float vote_size = -1.0;

    for (i=0; i<classes; i++)
	class_weight[i] = 0.0;

    for (i=0; i<k; i++)
	class_weight[data[nearest[i]].class_true] += 1/distances[i];

    /* Classes whose total is zero are skipped. */
    for (i=0; i<classes; i++)
    {
	if ((class_weight[i] > vote_size) &&
	    (class_weight[i] != 0.0))
	{
	    vote_size = class_weight[i];
	    vote_class = i;
	}
    }

    return (vote_class);
}


/* ------------------------------------------------------------ */
/* NEAREST_WEIGHTED_DISTANCE_VOTE:  Compatibility entry point    */
/* that takes no neighbor count.  It tallies the first           */
/* CONFIG.classes entries of nearest[]/distances[], which is only */
/* the intended neighbor set when k == CONFIG.classes;           */
/* nearest_vote() uses the k-aware version above.                */

int nearest_weighted_distance_vote(int nearest[], float distances[])
{
    return (nearest_weighted_distance_vote_k(CONFIG.classes,
					     nearest, distances));
}


/* ------------------------------------------------------------ */
/* NEAREST_MAJORITY_VOTE: Find the majority class in the set    */
/*    of nearest neighbors.					*/
/* INPUTS: k = # neighbors					*/
/*	   nearest = set of nearest neighbors			*/
/* OUTPUT: The majority class of the k nearest neighbors.	*/
/*	   Ties between equally common classes are broken at	*/
/*	   random, each tied class being equally likely.	*/

int nearest_majority_vote(int k, int nearest[])
{
    int class_count[CLASSES_MAX];
    int i, classes = CONFIG.classes;
    int majority_class, majority_size, ties;

    for (i=0; i<classes; i++)
	class_count[i] = 0;

    for (i=0; i<k; i++)
	class_count[data[nearest[i]].class_true]++;

    majority_size = class_count[0];
    majority_class = 0;
    ties = 1;		/* # classes seen so far with count == majority_size */

    for (i=1; i<classes; i++)
    {
	if (class_count[i] > majority_size)
	{
	    majority_size = class_count[i];
	    majority_class = i;
	    ties = 1;
	}
	else if (class_count[i] == majority_size)
	{
	    /* Reservoir sampling: the t-th tied class replaces the   */
	    /* current pick with probability 1/t, so every tied class */
	    /* wins with equal probability.  Assumes f_random(1.0) is */
	    /* uniform on [0, 1) -- NOTE(review): confirm.            */
	    ties++;
	    if (f_random(1.0) * ties < 1.0)
		majority_class = i;
	}
    }

    return (majority_class);
}


/* ------------------------------------------------------------ */
/* NEAREST_VOTE:  Combine the classes of the k nearest neighbors */
/* into one predicted class using the scheme selected by         */
/* CONFIG.nearest_voting.  An unrecognized setting falls back to */
/* majority voting so the result is always defined.              */

int nearest_vote(int k, int nearest_list[], float distances[])
{
    int nearest_class;

    switch (CONFIG.nearest_voting)
    {
	case WEIGHTED_DISTANCE:
	    nearest_class = nearest_weighted_distance_vote_k(k, nearest_list,
							     distances);
	    break;

	case THRESHOLD:
	    nearest_class = nearest_threshold_vote(k, nearest_list);
	    break;

	case MAJORITY:
	default:
	    nearest_class = nearest_majority_vote(k, nearest_list);
	    break;
    }

    return (nearest_class);
}






/* ------------------------------------------------------------ */
/* PARTITION:  Hoare partition of A[p..r] around the pivot value */
/* A[p].  Swaps elements so that every element of A[p..q] is <=  */
/* the pivot and every element of A[q+1..r] is >= the pivot, and */
/* returns the split index q.                                    */
int partition(float A[], int p, int r)
{
    float pivot = A[p];
    float held;
    int lo = p - 1;
    int hi = r + 1;

    for (;;)
    {
	/* Step hi left past elements that belong on the right side. */
	do {
	    --hi;
	} while (A[hi] > pivot);

	/* Step lo right past elements that belong on the left side. */
	do {
	    ++lo;
	} while (A[lo] < pivot);

	if (lo >= hi)
	    return (hi);

	held = A[lo];
	A[lo] = A[hi];
	A[hi] = held;
    }
}

	

/* ------------------------------------------------------------ */
/* RANDOMIZED_PARTITION:  Swap a randomly chosen element of      */
/* A[p..r] into position p, then partition around it.  Returns   */
/* the split index from partition().  NOTE(review): assumes      */
/* i_random(n) returns an integer in [0, n).                     */
int randomized_partition(float A[], int p, int r)
{
    int pick = p + i_random(r - p + 1);
    float held = A[pick];

    A[pick] = A[p];
    A[p] = held;

    return (partition(A, p, r));
}




/* ------------------------------------------------------------ */
/* RANDOMIZED_SELECT:  Return the i-th smallest (1-based) value  */
/* in A[p..r].  A is reordered in place while narrowing the      */
/* search range.                                                 */
float randomized_select(float A[], int p, int r, int i)
{
    int q, left_size;

    while (p != r)
    {
	q = randomized_partition(A, p, r);
	left_size = q - p + 1;

	if (i > left_size)
	{
	    /* Answer lies in the right part; skip the left elements. */
	    i -= left_size;
	    p = q + 1;
	}
	else
	{
	    r = q;
	}
    }

    return (A[p]);
}
	


/* ------------------------------------------------------------ */
/* FIND_NEAREST:  Select the k instances with the smallest       */
/* distance.                                                     */
/* INPUTS:  k            = # of neighbors wanted                */
/*          distance     = distance of each of the              */
/*                         CONFIG.instances instances           */
/* OUTPUTS: nearest_idx  = indices of the selected instances    */
/*          nearest_dist = distance of each selected instance   */
/* Exactly min(k, CONFIG.instances) entries are written (none   */
/* when k <= 0).  For k > 1, instances strictly closer than the */
/* k-th smallest distance come first, followed by instances     */
/* tied at that distance, each group in index order.  Ties      */
/* beyond k are dropped so the output arrays are never written  */
/* past k entries.                                              */
void find_nearest(int k, float distance[], int nearest_idx[], 
			 float nearest_dist[])
{
    float d;
    float distance_copy[INSTANCES_MAX];
    int i, j, kth;
    int instances = CONFIG.instances;

    if ((k <= 0) || (instances <= 0))
	return;
    if (k > instances)
	k = instances;

    if (k == 1)
    {
	kth = fminimum(distance, instances);
	nearest_idx[0] = kth;
	nearest_dist[0] = distance[kth];
	return;
    }

    /* randomized_select reorders its input, so select on a copy. */
    array_copy(distance, distance_copy, instances);
    d = randomized_select(distance_copy, 0, instances-1, k);

    j = 0;

    /* Pass 1: everything strictly closer than the k-th distance. */
    for (i=0; (i<instances) && (j<k); i++)
    {
	if (distance[i] < d)
	{
	    nearest_idx[j] = i;
	    nearest_dist[j] = distance[i];
	    j++;
	}
    }

    /* Pass 2: fill the remaining slots with instances tied at d. */
    for (i=0; (i<instances) && (j<k); i++)
    {
	if (distance[i] == d)
	{
	    nearest_idx[j] = i;
	    nearest_dist[j] = distance[i];
	    j++;
	}
    }
}



/* ------------------------------------------------------------ */
/* NEAREST_NEIGHBOR:  Find set of nearest neighbors		*/
/* INPUTS:	current    = current instance			*/
/*		k	   = # of nearest neighbors		*/
/*		nearest_idx = index of nearest neighbors	    	*/
/*		nearest_dist = distances of nearest neighbors   */
/*		weighting  = weighting is ON or OFF (passed to MVDM) */
/* An instance is a candidate only if it is trained, is not     */
/* `current`, and (when CONFIG.exemplar_weighting is            */
/* USED_CORRECT) has its `weighted` flag set; all other         */
/* instances, and any instance when CONFIG.classify_mode is not */
/* a recognized metric, get distance INFINITY.                  */

void nearest_neighbor(int current, int k, int nearest_idx[], 
		      float nearest_dist[], int weighting)
{
    float distance[INSTANCES_MAX];
    int i, candidate, instances = CONFIG.instances;

    for (i=0; i<instances; i++)
    {
	candidate = (data[i].trained) && (i != current) &&
		    ((CONFIG.exemplar_weighting != USED_CORRECT) ||
		     (data[i].weighted == TRUE));

	if (!candidate)
	{
	    distance[i] = INFINITY;
	    continue;
	}

	switch (CONFIG.classify_mode)
	{
	    case PEBLS:     distance[i] = MVDM(i, current, weighting); break;
	    case OVERLAP:   distance[i] = overlap(i, current); break;
	    case EUCLIDEAN: distance[i] = euclidean(i, current); break;
	    case MANHATTAN: distance[i] = manhattan(i, current); break;
	    default:
		/* Unrecognized metric: exclude the instance rather than */
		/* leave distance[i] uninitialized.                      */
		distance[i] = INFINITY;
		break;
	}
    }

    find_nearest(k, distance, nearest_idx, nearest_dist);
}




