/* Copyright (c) 1993 by The Johns Hopkins University */


/* METRIC.C:  Routines for defining the distance metric	*/
/*   		between two instances			*/





#include <stdio.h>
#include <math.h>

#include "config.h"
#include "pebls.h"

extern config_type CONFIG;				    
extern int count[CLASSES_MAX+1][FEATURES_MAX][VALUES_MAX];  
extern float dtable[FEATURES_MAX][VALUES_MAX][VALUES_MAX];
extern instance_type data[INSTANCES_MAX];



/* ------------------------------------------------------------ */
/* INITIALIZE_TRAINING: Reset all training statistics.		*/
/*   Zeroes every entry of the global value-count table (all	*/
/*   CLASSES_MAX+1 rows, every feature and value slot), then	*/
/*   resets the bookkeeping fields of the first		*/
/*   CONFIG.instances entries of data[]: trained and weighted	*/
/*   become FALSE, used and correct become 1.			*/

void initialize_training(void)
{
    int row, feat, val;
    int n;
    int last = CONFIG.instances;

    /* Wipe the value counts, including the extra row at index	*/
    /* CLASSES_MAX (hence the inclusive bound).			*/
    for (row = 0; row <= CLASSES_MAX; row++)
    {
	for (feat = 0; feat < FEATURES_MAX; feat++)
	{
	    for (val = 0; val < VALUES_MAX; val++)
	    {
		count[row][feat][val] = 0;
	    }
	}
    }

    /* Return every loaded instance to its untrained state. */
    n = 0;
    while (n < last)
    {
	data[n].trained  = FALSE;
	data[n].weighted = FALSE;
	data[n].used     = 1;
	data[n].correct  = 1;
	n++;
    }
}




/* ------------------------------------------------------------ */
/* PRINT_DISTANCE_TABLES: Dump every distance table to stdout	*/
/*   (for debugging purposes only).				*/
/*   For each feature f in [0, CONFIG.features), prints a	*/
/*   "FEATURE: f" line and then the CONFIG.nvalues[f] square	*/
/*   matrix dtable[f]. Each row goes on its own line, and every	*/
/*   entry is printed with 3 decimals followed by a tab. Two	*/
/*   extra newlines separate one feature from the next.		*/

void print_distance_tables(void)
{
    int feat, row, col;
    int nval;
    int nfeat = CONFIG.features;

    for (feat = 0; feat < nfeat; feat++)
    {
	nval = CONFIG.nvalues[feat];
	printf("FEATURE: %d\n", feat);

	for (row = 0; row < nval; row++)
	{
	    for (col = 0; col < nval; col++)
		printf("%.3f\t", dtable[feat][row][col]);
	    putchar('\n');
	}

	fputs("\n\n", stdout);
    }
}




/* ------------------------------------------------------------ */
/* DTABLE_ENTRY: Compute the distance between two values of one	*/
/*	feature from the accumulated class counts.		*/
/* INPUTS:  f  = feature #					*/
/*          v1 = value 1					*/
/*          v2 = value 2					*/
/* RETURNS: the sum over classes c in [0, CONFIG.classes) of	*/
/*	|count[c][f][v1]/C1 - count[c][f][v2]/C2| ^ CONFIG.K,	*/
/*	where C1 = count[CONFIG.classes][f][v1] and		*/
/*	C2 = count[CONFIG.classes][f][v2].			*/
/*	The row indexed by CONFIG.classes is used as the	*/
/*	normalizing total. Presumably it holds each value's count	*/
/*	over all classes (TODO confirm in the counting code).	*/
/*	If either total is zero, every class contributes 1.0, so	*/
/*	the result equals CONFIG.classes.			*/

float dtable_entry(int f, int v1, int v2)
{
    int c, classes;
    int unseen;
    float C1, C1i, C2, C2i, diff;
    float vdm = 0.0;
    float K;

    classes = CONFIG.classes;
    K = CONFIG.K;

    C1 = (float) count[classes][f][v1];
    C2 = (float) count[classes][f][v2];

    /* A zero total would make the ratios below divide by zero.	*/
    /* Such values are treated as maximally different in every	*/
    /* class. The test does not depend on c, so evaluate it once.	*/
    unseen = (C1 == 0) || (C2 == 0);

    for (c = 0; c < classes; c++)
    {
	if (unseen)
	    diff = 1.0;
	else
	{
	    C1i = (float) count[c][f][v1];
	    C2i = (float) count[c][f][v2];
	    diff = (float) fabs(C1i/C1 - C2i/C2);
	}

	/* Avoid the pow() call when the exponent is exactly 1. */
	if (K == 1) vdm += diff;
	else vdm += (float) pow(diff, K);
    }

    return (vdm);
}







/* ------------------------------------------------------------ */
/* MVDM:  Modified Value Distance Metric between two instances	*/
/* INPUTS:  x = index of instance x in data[]			*/
/* 	    y = index of instance y in data[]			*/
/*          weighting = flag; when not OFF, the result is scaled	*/
/*		by data[x].weight and data[y].weight		*/
/* RETURNS: the sum over features i in [0, CONFIG.features) of	*/
/*	dtable[i][x_i][y_i] ^ CONFIG.R, where x_i and y_i are the	*/
/*	instances' values for feature i. Each term is multiplied	*/
/*	by CONFIG.feature_weights[i] when CONFIG.feature_weighting	*/
/*	is not OFF.						*/

float MVDM(int x, int y, int weighting)
{
    int feat, nfeat;
    int vx, vy;
    int per_feature;
    float term, total, expo;
    float wx, wy;

    nfeat = CONFIG.features;
    expo = CONFIG.R;
    per_feature = (CONFIG.feature_weighting != OFF);
    total = 0.0;

    for (feat = 0; feat < nfeat; feat++)
    {
	vx = data[x].value[feat];
	vy = data[y].value[feat];

	/* Avoid the pow() call when the exponent is exactly 1. */
	term = (expo == 1) ? dtable[feat][vx][vy]
			   : (float) pow(dtable[feat][vx][vy], expo);

	if (per_feature)
	    total = total + CONFIG.feature_weights[feat] * term;
	else
	    total = total + term;
    }

    if (weighting == OFF)
	return (total);

    wx = data[x].weight;
    wy = data[y].weight;
    return (wy == 1.0) ? (wx * total) : (wx * wy * total);
}
    






/* ------------------------------------------------------------ */
/* BUILD_DISTANCE_TABLES:  Fill dtable[f] for every feature f	*/
/*   in [0, CONFIG.features) from the value counts accumulated	*/
/*   during training. Each table is a symmetric square matrix of	*/
/*   size CONFIG.nvalues[f]. The diagonal is zero, entries above	*/
/*   it come from dtable_entry(), and entries below it mirror	*/
/*   the already-computed upper triangle.			*/

void build_distance_tables(void)
{
    int feat, row, col;
    int nval;
    int nfeat = CONFIG.features;

    for (feat = 0; feat < nfeat; feat++)
    {
	nval = CONFIG.nvalues[feat];
	for (row = 0; row < nval; row++)
	{
	    /* Below the diagonal: reuse entries filled by earlier rows. */
	    for (col = 0; col < row; col++)
		dtable[feat][row][col] = dtable[feat][col][row];

	    dtable[feat][row][row] = 0.0;

	    /* Above the diagonal: compute the distance directly. */
	    for (col = row + 1; col < nval; col++)
		dtable[feat][row][col] = dtable_entry(feat, row, col);
	}
    }
}