LCOV - code coverage report
Current view: top level - src - gensvm_cross_validation.c (source / functions) Hit Total Coverage
Test: coverage.all Lines: 0 18 0.0 %
Date: 2017-02-21 18:44:20 Functions: 0 1 0.0 %

          Line data    Source code
       1             : /**
       2             :  * @file gensvm_cross_validation.c
       3             :  * @author G.J.J. van den Burg
       4             :  * @date 2016-10-24
       5             :  * @brief Function for running cross validation on GenModel
       6             :  *
       7             :  * @copyright
       8             :  Copyright 2016, G.J.J. van den Burg.
       9             : 
      10             :  This file is part of GenSVM.
      11             : 
      12             :  GenSVM is free software: you can redistribute it and/or modify
      13             :  it under the terms of the GNU General Public License as published by
      14             :  the Free Software Foundation, either version 3 of the License, or
      15             :  (at your option) any later version.
      16             : 
      17             :  GenSVM is distributed in the hope that it will be useful,
      18             :  but WITHOUT ANY WARRANTY; without even the implied warranty of
      19             :  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
      20             :  GNU General Public License for more details.
      21             : 
      22             :  You should have received a copy of the GNU General Public License
      23             :  along with GenSVM. If not, see <http://www.gnu.org/licenses/>.
      24             : 
      25             :  */
      26             : 
      27             : #include "gensvm_cross_validation.h"
      28             : 
      29             : extern FILE *GENSVM_OUTPUT_FILE;
      30             : 
      31             : /**
      32             :  * @brief Run cross validation with a given set of train/test folds
      33             :  *
      34             :  * @details
      35             :  * This cross validation function uses predefined train/test splits. Also, it
      36             :  * uses the optimal parameters GenModel::V of a previous fold as initial
      37             :  * conditions for GenModel::V of the next fold.
      38             :  *
      39             :  * @note
      40             :  * This function temporarily sets the output stream GENSVM_OUTPUT_FILE to NULL
      41             :  * so gensvm_optimize() stays silent, and restores it before returning.
      42             :  *
      43             :  * @param[in,out] model         GenModel with the configuration to train
      44             :  * @param[in]   train_folds     array of training datasets
      45             :  * @param[in]   test_folds      array of test datasets
      46             :  * @param[in]   folds           number of folds (length of both fold arrays)
      47             :  * @param[in]   n_total         total number of objects; should equal the sum of
      48             :  *                              test_folds[f]->n over all folds
      49             :  * @return                      performance (hitrate) of the configuration,
      50             :  *                              averaged over the folds weighted by test size
      51             :  */
      52           0 : double gensvm_cross_validation(struct GenModel *model,
      53             :                 struct GenData **train_folds, struct GenData **test_folds,
      54             :                 long folds, long n_total)
      55             : {
      56             :         long f;
      57           0 :         long *predy = NULL;
      58           0 :         double performance, total_perf = 0;
      59             : 
      60             :         // silence gensvm_optimize(); the stream is restored after the loop
      61           0 :         FILE *fid = GENSVM_OUTPUT_FILE;
      62           0 :         GENSVM_OUTPUT_FILE = NULL;
      63             : 
      64             :         // run cross-validation
      65           0 :         for (f=0; f<folds; f++) {
      66             :                 // reallocate model in case its dimensions differ from this fold
      67           0 :                 gensvm_reallocate_model(model, train_folds[f]->n,
      68           0 :                                 train_folds[f]->r);
      69             : 
      70             :                 // initialize object weights
      71           0 :                 gensvm_initialize_weights(train_folds[f], model);
      72             : 
      73             :                 // train the model (output suppressed via GENSVM_OUTPUT_FILE)
      74           0 :                 gensvm_optimize(model, train_folds[f]);
      75             : 
      76             :                 // predict test labels; weight fold performance by test size
      77           0 :                 predy = Calloc(long, test_folds[f]->n);
      78           0 :                 gensvm_predict_labels(test_folds[f], model, predy);
      79           0 :                 performance = gensvm_prediction_perf(test_folds[f], predy);
      80           0 :                 total_perf += performance * test_folds[f]->n;
      81             : 
      82           0 :                 free(predy);
      83             :         }
      84             : 
      85           0 :         total_perf /= ((double) n_total); // NOTE(review): assumes n_total > 0
      86             : 
      87             :         // reset the output stream
      88           0 :         GENSVM_OUTPUT_FILE = fid;
      89             : 
      90           0 :         return total_perf;
      91             : }

Generated by: LCOV version 1.12