/* Produce a vector-per-class description of the model data in a barrel */

/* Copyright (C) 1997 Andrew McCallum

   Written by:  Andrew Kachites McCallum <mccallum@cs.cmu.edu>

   This file is part of the Bag-Of-Words Library, `libbow'.

   This library is free software; you can redistribute it and/or
   modify it under the terms of the GNU Library General Public License
   as published by the Free Software Foundation, version 2.
   
   This library is distributed in the hope that it will be useful,
   but WITHOUT ANY WARRANTY; without even the implied warranty of
   MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
   Library General Public License for more details.

   You should have received a copy of the GNU Library General Public
   License along with this library; if not, write to the Free Software
   Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111, USA */

#include <bow/libbow.h>
#include <string.h>


/* Given a barrel of documents, create and return another barrel with
   only one vector per class.  The classes are represented as
   "documents" in the new barrel: class CI becomes document CI.

   Only word indices below MIN (DOC_BARREL's wi2dvf size,
   bow_num_words ()) are considered.  For each such word, the counts
   and weights of every `model' document are summed into the entry of
   that document's class, and the word's IDF is copied unchanged from
   DOC_BARREL.

   The new barrel gets one cdoc per class index from 0 up to the
   highest class index seen among documents (of any type) that contain
   at least one of those words.  Each cdoc has type `model', normalizer
   -1 and, when CLASSNAMES is non-NULL, a strdup()ed copy of
   CLASSNAMES[CI] as its filename (NULL otherwise).  When CLASSNAMES is
   non-NULL it must have a non-NULL entry for every class index.

   Class priors are set by DOC_BARREL's method `vpc_set_priors'
   function when it has one, and to the bogus value -1 otherwise.  */
bow_barrel *
bow_barrel_new_vpc (bow_barrel *doc_barrel, const char **classnames)
{
  bow_barrel *vpc_barrel;	/* The vector per class barrel */
  int max_ci = -1;		/* The highest index of encountered classes */
  int max_wi;
  int wi;
  int ci;

  max_wi = MIN (doc_barrel->wi2dvf->size, bow_num_words ());

  /* Create an empty barrel; we fill it with vector-per-class data and
     return it. */
  vpc_barrel = bow_barrel_new (doc_barrel->wi2dvf->size,
			       doc_barrel->cdocs->length,
			       doc_barrel->cdocs->entry_size,
			       doc_barrel->cdocs->free_func);
  vpc_barrel->method = doc_barrel->method;

  bow_verbosify (bow_progress, "Making vector-per-class... words ::       ");

  /* Initialize the WI2DVF part of the VPC_BARREL.  Sum together the
     counts and weights for individual documents, grabbing only the
     training (`model') documents. */
  for (wi = 0; wi < max_wi; wi++)
    {
      bow_dv *dv;
      bow_dv *vpc_dv;
      int dvi;

      dv = bow_wi2dvf_dv (doc_barrel->wi2dvf, wi);
      if (!dv)
	continue;
      for (dvi = 0; dvi < dv->length; dvi++)
	{
	  bow_cdoc *cdoc;

	  cdoc = bow_array_entry_at_index (doc_barrel->cdocs,
					   dv->entry[dvi].di);
	  ci = cdoc->class;
	  assert (ci >= 0);
	  /* The highest class index is tracked over documents of every
	     type, not only the `model' ones merged below. */
	  if (ci > max_ci)
	    max_ci = ci;
	  if (cdoc->type == model)
	    bow_wi2dvf_add_wi_di_count_weight (&(vpc_barrel->wi2dvf),
					       wi, ci,
					       dv->entry[dvi].count,
					       dv->entry[dvi].weight);
	}
      /* Set the IDF of the class's wi2dvf directly from the doc's
	 wi2dvf.  VPC_DV is presumably NULL when no `model' document
	 contains WI, since nothing was added for it above. */
      vpc_dv = bow_wi2dvf_dv (vpc_barrel->wi2dvf, wi);
      if (vpc_dv)
	vpc_dv->idf = dv->idf;
      if (wi % 100 == 0)
	bow_verbosify (bow_progress, "\b\b\b\b\b\b%6d", max_wi - wi);
    }
  /* Require that some document has a class index of at least 1, i.e.
     that the new barrel will have at least two classes. */
  assert (max_ci > 0);

  /* Initialize the CDOCS part of the VPC_BARREL.  Create a BOW_CDOC
     structure for each class, and append it to the VPC->CDOCS
     array. */
  for (ci = 0; ci <= max_ci; ci++)
    {
      bow_cdoc cdoc;

      /* Zero every field first, so that fields not assigned below are
	 not copied into the array with indeterminate values. */
      memset (&cdoc, 0, sizeof (cdoc));
      cdoc.type = model;
      cdoc.normalizer = -1.0f;
      if (classnames)
	{
	  assert (classnames[ci]);
	  cdoc.filename = strdup (classnames[ci]);
	  if (!cdoc.filename)
	    bow_error ("Memory exhausted.");
	}
      else
	{
	  cdoc.filename = NULL;
	}
      cdoc.class = ci;
      bow_array_append (vpc_barrel->cdocs, &cdoc);
    }

  if (doc_barrel->method->vpc_set_priors)
    {
      /* Set the prior probabilities on classes, if we're doing
	 NaiveBayes or something else that needs them.  */
      (*doc_barrel->method->vpc_set_priors) (vpc_barrel, doc_barrel);
    }
  else
    {
      /* We don't need priors for the other methods.  Set them to
	 obviously bogus values, so we'll notice if they accidentally
	 get used. */
      for (ci = 0; ci <= max_ci; ci++)
	{
	  bow_cdoc *cdoc;
	  cdoc = bow_array_entry_at_index (vpc_barrel->cdocs, ci);
	  cdoc->prior = -1;
	}
    }

  bow_verbosify (bow_progress, "\n");

  return vpc_barrel;
}

/* Build a `vector-per-class' barrel from DOC_BARREL by merging first
   and weighting afterwards.  The documents are merged into one vector
   per class with bow_barrel_new_vpc().  Then the SET_WEIGHTS of
   DOC_BARREL's METHOD runs on that new barrel, and its weights are
   scaled (relative to DOC_BARREL) and normalized.  DOC_BARREL's method
   must have a NAME.  Returns the new barrel. */
bow_barrel *
bow_barrel_new_vpc_merge_then_weight (bow_barrel *doc_barrel,
				      const char **classnames)
{
  bow_barrel *merged;

  assert (doc_barrel->method->name != NULL);

  /* One "document" per class... */
  merged = bow_barrel_new_vpc (doc_barrel, classnames);

  /* ...whose weights are then set, scaled and normalized on the
     merged barrel itself. */
  bow_barrel_set_weights (merged);
  bow_barrel_scale_weights (merged, doc_barrel);
  bow_barrel_normalize_weights (merged);

  return merged;
}

/* Build a `vector-per-class' barrel from DOC_BARREL by weighting first
   and merging afterwards.  The weights are set in DOC_BARREL itself.
   Then the documents are merged with bow_barrel_new_vpc(), which sums
   the document weights into one vector per class.  Finally the class
   weights are scaled (relative to DOC_BARREL) and normalized.  Returns
   the new barrel. */
bow_barrel *
bow_barrel_new_vpc_weight_then_merge (bow_barrel *doc_barrel,
				      const char **classnames)
{
  bow_barrel *merged;

  /* Weight the individual documents in DOC_BARREL... */
  bow_barrel_set_weights (doc_barrel);

  /* ...then sum them into per-class vectors... */
  merged = bow_barrel_new_vpc (doc_barrel, classnames);

  /* ...and finish the class vectors. */
  bow_barrel_scale_weights (merged, doc_barrel);
  bow_barrel_normalize_weights (merged);

  return merged;
}

/* Set the class prior probabilities of the classes in VPC_BARREL from
   the documents in DOC_BARREL.  Each `model' document adds its own
   PRIOR field to the prior of its class; with unit document priors
   this counts the training documents of each class.  The class priors
   are then normalized to sum to one.

   Only `model' documents are counted, matching the documents whose
   words bow_barrel_new_vpc() merges into the class vectors.  Documents
   whose class has no cdoc in VPC_BARREL are skipped.  If no document
   contributes anything, the priors are left at zero instead of being
   divided by zero. */
void
bow_barrel_set_vpc_priors_by_counting (bow_barrel *vpc_barrel,
				       bow_barrel *doc_barrel)
{
  float prior_sum = 0;
  int num_classes = vpc_barrel->cdocs->length;
  int ci;
  int di;

  /* Zero them. */
  for (ci = 0; ci < num_classes; ci++)
    {
      bow_cdoc *cdoc;
      cdoc = bow_array_entry_at_index (vpc_barrel->cdocs, ci);
      cdoc->prior = 0;
    }

  /* Add in the priors of the training documents. */
  for (di = 0; di < doc_barrel->cdocs->length; di++)
    {
      bow_cdoc *doc_cdoc;
      bow_cdoc *vpc_cdoc;

      doc_cdoc = bow_array_entry_at_index (doc_barrel->cdocs, di);

      /* Non-`model' documents contributed nothing to the class
	 vectors, so they must not influence the priors either. */
      if (doc_cdoc->type != model)
	continue;

      if (doc_cdoc->class < 0 || doc_cdoc->class >= num_classes)
	{
	  /* A class index past the end can happen if all of the
	     documents in a certain class contain only words that are
	     not in the vocabulary used when running
	     bow_barrel_new_vpc().  A negative index has no class cdoc
	     at all. */
	  continue;
	}
      vpc_cdoc = bow_array_entry_at_index (vpc_barrel->cdocs,
					   doc_cdoc->class);
      vpc_cdoc->prior += doc_cdoc->prior;
    }

  /* Sum them all. */
  for (ci = 0; ci < num_classes; ci++)
    {
      bow_cdoc *cdoc;
      cdoc = bow_array_entry_at_index (vpc_barrel->cdocs, ci);
      prior_sum += cdoc->prior;
    }

  /* Normalize to set the prior.  Avoid 0/0 (NaN priors) when nothing
     was counted. */
  if (prior_sum == 0)
    return;
  for (ci = 0; ci < num_classes; ci++)
    {
      bow_cdoc *cdoc;
      cdoc = bow_array_entry_at_index (vpc_barrel->cdocs, ci);
      cdoc->prior /= prior_sum;
    }
}