Line data Source code
1 : /**
2 : * @file gensvm_cross_validation.c
3 : * @author G.J.J. van den Burg
4 : * @date 2016-10-24
5 : * @brief Function for running cross validation on GenModel
6 : *
7 : * @copyright
8 : Copyright 2016, G.J.J. van den Burg.
9 :
10 : This file is part of GenSVM.
11 :
12 : GenSVM is free software: you can redistribute it and/or modify
13 : it under the terms of the GNU General Public License as published by
14 : the Free Software Foundation, either version 3 of the License, or
15 : (at your option) any later version.
16 :
17 : GenSVM is distributed in the hope that it will be useful,
18 : but WITHOUT ANY WARRANTY; without even the implied warranty of
19 : MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
20 : GNU General Public License for more details.
21 :
22 : You should have received a copy of the GNU General Public License
23 : along with GenSVM. If not, see <http://www.gnu.org/licenses/>.
24 :
25 : */
26 :
27 : #include "gensvm_cross_validation.h"
28 :
29 : extern FILE *GENSVM_OUTPUT_FILE;
30 :
31 : /**
32 : * @brief Run cross validation with a given set of train/test folds
33 : *
34 : * @details
35 : * This cross validation function uses predefined train/test splits. Also, the
36 : * the optimal parameters GenModel::V of a previous fold as initial conditions
37 : * for GenModel::V of the next fold.
38 : *
39 : * @note
40 : * This function always sets the output stream defined in GENSVM_OUTPUT_FILE
41 : * to NULL, to ensure gensvm_optimize() doesn't print too much.
42 : *
43 : * @param[in] model GenModel with the configuration to train
44 : * @param[in] train_folds array of training datasets
45 : * @param[in] test_folds array of test datasets
46 : * @param[in] folds number of folds
47 : * @param[in] n_total number of objects in the union of the train
48 : * datasets
49 : * @return performance (hitrate) of the configuration on
50 : * cross validation
51 : */
52 0 : double gensvm_cross_validation(struct GenModel *model,
53 : struct GenData **train_folds, struct GenData **test_folds,
54 : long folds, long n_total)
55 : {
56 : long f;
57 0 : long *predy = NULL;
58 0 : double performance, total_perf = 0;
59 :
60 : // make sure that gensvm_optimize() is silent.
61 0 : FILE *fid = GENSVM_OUTPUT_FILE;
62 0 : GENSVM_OUTPUT_FILE = NULL;
63 :
64 : // run cross-validation
65 0 : for (f=0; f<folds; f++) {
66 : // reallocate model in case dimensions differ with data
67 0 : gensvm_reallocate_model(model, train_folds[f]->n,
68 0 : train_folds[f]->r);
69 :
70 : // initialize object weights
71 0 : gensvm_initialize_weights(train_folds[f], model);
72 :
73 : // train the model (surpressing output)
74 0 : gensvm_optimize(model, train_folds[f]);
75 :
76 : // calculate prediction performance on test set
77 0 : predy = Calloc(long, test_folds[f]->n);
78 0 : gensvm_predict_labels(test_folds[f], model, predy);
79 0 : performance = gensvm_prediction_perf(test_folds[f], predy);
80 0 : total_perf += performance * test_folds[f]->n;
81 :
82 0 : free(predy);
83 : }
84 :
85 0 : total_perf /= ((double) n_total);
86 :
87 : // reset the output stream
88 0 : GENSVM_OUTPUT_FILE = fid;
89 :
90 0 : return total_perf;
91 : }
|