/* Copyright (c) 1993 by The Johns Hopkins University */



/* WEIGHTS.C:  Feature and Exemplar Weighting Routines */


#include <stdio.h>
#include "config.h"
#include "pebls.h"


/* Global run configuration.  This file reads its instance and  */
/* feature counts, weighting options and genetic parameters,    */
/* and writes its feature_weights array.                        */
extern config_type CONFIG;				    
/* Table of instances.  This file reads and updates the         */
/* per-instance trained, weighted, used, correct and weight     */
/* fields and reads class_true and id.                          */
extern instance_type data[INSTANCES_MAX];
/* Not referenced by the routines in this file.                 */
/* NOTE(review): presumably a per-feature table of distances    */
/* between value pairs -- confirm against its definition.       */
extern float dtable[FEATURES_MAX][VALUES_MAX][VALUES_MAX];


/* ------------------------------------------------------------ */
/* GENETIC_FEATURE_WEIGHTS:  Tune the feature weights by        */
/* repeated random perturbation.  Each trial changes one        */
/* randomly chosen weight, picks a random training instance,    */
/* and undoes the change when that instance and its nearest     */
/* neighbor have different true classes.  The resulting         */
/* weights are printed to stdout.                               */
/* INPUTS:  count  = number of perturbation trials              */
/*          adjust = perturbation scale; each change is         */
/*                   f_random(2 * adjust) - adjust              */

void genetic_feature_weights(int count, float adjust)
{
    int n_train = CONFIG.training_instances;
    int n_feat = CONFIG.features;
    int neighbors[K_NEIGHBOR_MAX];
    float neighbor_dist[K_NEIGHBOR_MAX];
    float delta;
    int trial, idx, feat, probe, closest;

    /* Flag every training instance as trained for the search. */
    idx = 0;
    while (idx < n_train)
    {
        data[idx].trained = TRUE;
        idx++;
    }

    /* Every feature starts from the same weight. */
    for (feat = 0; feat < n_feat; feat++)
    {
        CONFIG.feature_weights[feat] = 1.00;
    }

    for (trial = 0; trial < count; trial++)
    {
        /* The random draws are made in a fixed order: the      */
        /* perturbation, then the feature, then the instance.   */
        delta = f_random(adjust * 2.0) - adjust;
        feat = i_random(n_feat);
        CONFIG.feature_weights[feat] += delta;
        probe = i_random(n_train);

        /* The lookup runs with the perturbed weight in place. */
        nearest_neighbor(probe, 1, neighbors, neighbor_dist, OFF);
        closest = neighbors[0];

        /* A class mismatch rejects the trial: take the change back. */
        if (data[probe].class_true != data[closest].class_true)
        {
            CONFIG.feature_weights[feat] -= delta;
        }
    }

    /* Clear the trained flags again before returning. */
    idx = 0;
    while (idx < n_train)
    {
        data[idx].trained = FALSE;
        idx++;
    }

    for (idx = 0; idx < CONFIG.features; idx++)
    {
        printf("Feature Weight %3d %.2f\n", idx, CONFIG.feature_weights[idx]);
    }
}





/* ------------------------------------------------------------ */
/* USER_FEATURE_1:  Placeholder hook for a user-defined feature */
/* weighting scheme.  Takes no arguments and returns nothing;   */
/* it has no effect until a body is supplied.                   */

void user_feature_1(void)
{
    /* Your function here */
}


/* ------------------------------------------------------------ */
/* USER_FEATURE_2:  Placeholder hook for a user-defined feature */
/* weighting scheme.  Takes no arguments and returns nothing;   */
/* it has no effect until a body is supplied.                   */

void user_feature_2(void)
{
    /* Your function here */
}



/* ------------------------------------------------------------ */
/* USER_FEATURE_3:  Placeholder hook for a user-defined feature */
/* weighting scheme.  Takes no arguments and returns nothing;   */
/* it has no effect until a body is supplied.                   */

void user_feature_3(void)
{
    /* Your function here */
}





/* ------------------------------------------------------------ */
/* SET_FEATURE_WEIGHTS:  Set the shape of the feature weights   */
/* INPUTS:  shape = PEBLS constant specifying the shape         */
/*                  (e.g., TRIANGLE)                            */
/*                                                              */
/* TRIANGLE       weights rise linearly over the first half of  */
/*                the features and fall linearly over the rest. */
/* GENETIC        runs genetic_feature_weights(); reports       */
/*                GENETIC_ERR through error() unless the        */
/*                training mode is SPECIFIED_GROUP.             */
/* USER_FEATURE_n calls the matching user_feature_n() hook and  */
/*                nothing else.                                 */
/* Any other shape leaves the feature weights untouched.        */

void set_feature_weights(int shape)
{
    int f, features = CONFIG.features;
    float df, dfeat;

    dfeat = (float) features;

    switch(shape)
    {
        case TRIANGLE:
            for (f=0; f<features; f++)
            {
                df = (float) f;
                /* Rising edge for the first half, mirrored     */
                /* falling edge for the second half.            */
                if (df < dfeat / 2.0)
                    CONFIG.feature_weights[f] = 2.0/(dfeat+1) * (df + 1);
                else
                    CONFIG.feature_weights[f] = 2.0 - 2.0/(dfeat+1) * (df + 1);
            }
            break;

        case GENETIC:
            if (CONFIG.training_mode != SPECIFIED_GROUP)
                error(GENETIC_ERR, NULL);
            else
                genetic_feature_weights(CONFIG.genetic_count,
                                        CONFIG.genetic_adj);
            break;

        /* Each user hook ends in its own break so that choosing */
        /* one hook does not also run the hooks listed after it. */
        case USER_FEATURE_1:
            user_feature_1();
            break;

        case USER_FEATURE_2:
            user_feature_2();
            break;

        case USER_FEATURE_3:
            user_feature_3();
            break;

        default:
            break;
    }
}




/* ------------------------------------------------------------ */
/* PRINT_EXEMPLAR_WEIGHTS:  Debugging aid.  Writes one line to  */
/* stdout for every instance whose trained flag is set, giving  */
/* its index, id, used count, correct count and weight.         */

void print_exemplar_weights(void)
{
    int total = CONFIG.instances;
    int idx = 0;

    while (idx < total)
    {
        if (data[idx].trained)
        {
            printf("%3d %15s %5d %5d %5.2f\n",
                   idx, data[idx].id, data[idx].used,
                   data[idx].correct, data[idx].weight);
        }
        idx++;
    }
}





/* ------------------------------------------------------------ */
/* EXEMPLAR_WEIGHTS_USED_CORRECT:  Weight exemplars by the      */
/* ratio used/correct, with the counts accumulated while the    */
/* training instances are visited in shuffled order.            */
/*                                                              */
/* The first trained instance visited starts with used and      */
/* correct both 1.  Each later one finds its single nearest     */
/* neighbor, increments that neighbor's used count (and its     */
/* correct count when the true classes agree), then copies the  */
/* neighbor's counts as its own.                                */

void exemplar_weights_used_correct(void)
{
    int select, select_order[INSTANCES_MAX];
    int training_instances = CONFIG.training_instances;
    int i, first = TRUE;
    int nearest, nearest_list[K_NEIGHBOR_MAX];
    float distances[K_NEIGHBOR_MAX];

    /* NOTE(review): shuffle() presumably fills select_order    */
    /* with a random ordering of 0..training_instances-1 --     */
    /* confirm against its definition.                          */
    shuffle(select_order, training_instances);
    for (i=0; i<training_instances; i++)
        data[i].weighted = FALSE;

    for (i=0; i<training_instances; i++)
    {
        select = select_order[i];

        /* The trained flag is read from the instance actually  */
        /* being processed (select), not from loop position i.  */
        if (!data[select].trained)
            continue;

        if (first)
        {
            data[select].used = 1;
            data[select].correct = 1;
            first = FALSE;
        }
        else
        {
            /* NOTE(review): the ON flag presumably restricts   */
            /* the search to already-weighted exemplars --      */
            /* confirm against nearest_neighbor().              */
            nearest_neighbor(select, 1, nearest_list, distances, ON);
            nearest = nearest_list[0];
            data[nearest].used++;
            if (data[select].class_true == data[nearest].class_true)
                data[nearest].correct++;
            data[select].used = data[nearest].used;
            data[select].correct = data[nearest].correct;
        }
        data[select].weighted = TRUE;
    }

    /* An instance whose correct count is zero keeps its        */
    /* current weight, so the ratio never divides by zero.      */
    for (i=0; i<training_instances; i++)
        if (data[i].correct != 0)
            data[i].weight = (float) data[i].used / (float) data[i].correct;

}



/* ------------------------------------------------------------ */
/* COUNT_MATCHING_CLASS:  Count how many of the first k entries */
/* of nearest[] index an instance whose true class equals that  */
/* of the given instance.                                       */
/* INPUTS:  instance = index into data[] of the reference       */
/*          k        = number of entries of nearest[] to check  */
/*          nearest  = array of indices into data[]             */
/* OUTPUT:  number of matching entries (0 when k <= 0)          */

int count_matching_class(int instance, int k, int nearest[])
{
    int hits = 0;
    int pos = 0;

    while (pos < k)
    {
        if (data[instance].class_true == data[nearest[pos]].class_true)
        {
            hits += 1;
        }
        pos++;
    }

    return hits;
}

    

/* ------------------------------------------------------------ */
/* EXEMPLAR_WEIGHTS_ONE_PASS:  Give each trained instance the   */
/* weight k + 1 - m, where k is CONFIG.nearest_neighbor and m   */
/* is the number of its k nearest neighbors that share its true */
/* class.  Sets the weighted flag on every instance it weights. */

void exemplar_weights_one_pass(void)
{
    int neighbors[K_NEIGHBOR_MAX];
    float neighbor_dist[K_NEIGHBOR_MAX];
    int n_train = CONFIG.training_instances;
    int k_near = CONFIG.nearest_neighbor;
    int idx, same_class;

    /* All weighted flags are cleared before any lookup is made. */
    idx = 0;
    while (idx < n_train)
    {
        data[idx].weighted = FALSE;
        idx++;
    }

    for (idx = 0; idx < n_train; idx++)
    {
        if (!data[idx].trained)
            continue;

        nearest_neighbor(idx, k_near, neighbors, neighbor_dist, OFF);
        same_class = count_matching_class(idx, k_near, neighbors);

        /* Fewer same-class neighbors gives a larger weight. */
        data[idx].weight = (float) (k_near + 1 - same_class);
        data[idx].weighted = TRUE;
    }
}

    

/* ------------------------------------------------------------ */
/* EXEMPLAR_WEIGHTS_INCREMENT:  For every trained instance,     */
/* find its single nearest neighbor; when the two have          */
/* different true classes, add 1.0 to the neighbor's weight.    */
/* Weights are only incremented here, never reset.              */

void exemplar_weights_increment(void)
{
    int training_instances = CONFIG.training_instances;
    int i;
    int nearest, nearest_list[K_NEIGHBOR_MAX];
    float distances[K_NEIGHBOR_MAX];

    for (i=0; i<training_instances; i++)
    {
        if (data[i].trained)
        {
            nearest_neighbor(i, 1, nearest_list, distances, OFF);
            nearest = nearest_list[0];

            /* The weight is added to the neighbor, not to i. */
            if (data[i].class_true != data[nearest].class_true)
                data[nearest].weight += 1.00;
        }
    }
}


/* ------------------------------------------------------------ */
/* USER_EXEMPLAR_1:  Placeholder hook for a user-defined        */
/* exemplar weighting scheme.  Takes no arguments and returns   */
/* nothing; it has no effect until a body is supplied.          */

void user_exemplar_1(void)
{
    /* Your function here */
}



/* ------------------------------------------------------------ */
/* USER_EXEMPLAR_2:  Placeholder hook for a user-defined        */
/* exemplar weighting scheme.  Takes no arguments and returns   */
/* nothing; it has no effect until a body is supplied.          */

void user_exemplar_2(void)
{
    /* Your function here */
}



/* ------------------------------------------------------------ */
/* USER_EXEMPLAR_3:  Placeholder hook for a user-defined        */
/* exemplar weighting scheme.  Takes no arguments and returns   */
/* nothing; it has no effect until a body is supplied.          */

void user_exemplar_3(void)
{
    /* Your function here */
}



/* ------------------------------------------------------------ */
void set_exemplar_weights(void)
{
    int method = CONFIG.exemplar_weighting;

    switch (method)
    {
    	case OFF:	    	/* do nothing */ break;
	case USED_CORRECT:  	exemplar_weights_used_correct(); break;
	case ONE_PASS:      	exemplar_weights_one_pass(); break;
	case INCREMENT:	    	exemplar_weights_increment(); break;
	case USER_EXEMPLAR_1: 	user_exemplar_1();
	case USER_EXEMPLAR_2: 	user_exemplar_2();
	case USER_EXEMPLAR_3: 	user_exemplar_3();
	default:	    	error(UNK_EWEIGHT_ERR, NULL);
    }

}