GenSVM
test_gensvm_cv_util.c
Go to the documentation of this file.
1 
27 #include "minunit.h"
28 #include "gensvm_cv_util.h"
29 
31 {
32  srand(0);
33  int i, j;
34  long N = 10;
35  long folds = 4;
36  long *cv_idx = Calloc(long, N);
37 
38  // start test code //
39  gensvm_make_cv_split(N, folds, cv_idx);
40  // check if the values are between [0, folds-1]
41  for (i=0; i<N; i++)
42  mu_assert(0 <= cv_idx[i] && cv_idx[i] < folds,
43  "CV range incorrect.");
44 
45  // check there are N % folds big folds of size floor(N/folds) + 1
46  // and the remaining are of size floor(N/folds)
47  int sum;
48  int is_big = 0,
49  is_small = 0;
50  for (i=0; i<folds; i++) {
51  sum = 0;
52  for (j=0; j<N; j++) {
53  if (cv_idx[j] == i) sum += 1;
54  }
55  if (sum == floor(N/folds) + 1)
56  is_big++;
57  else
58  is_small++;
59  }
60  mu_assert(is_big == N % folds, "Incorrect number of big folds");
61  mu_assert(is_small == folds - N % folds,
62  "Incorrect number of small folds");
63 
64  // end test code //
65 
66  free(cv_idx);
67 
68  return NULL;
69 }
70 
72 {
73  srand(0);
74  int i, j;
75  long N = 101;
76  long folds = 7;
77  long *cv_idx = Calloc(long, N);
78 
79  // start test code //
80  gensvm_make_cv_split(N, folds, cv_idx);
81  // check if the values are between [0, folds-1]
82  for (i=0; i<N; i++)
83  mu_assert(0 <= cv_idx[i] && cv_idx[i] < folds,
84  "CV range incorrect.");
85 
86  // check there are N % folds big folds of size floor(N/folds) + 1
87  // and the remaining are of size floor(N/folds)
88  int sum;
89  int is_big = 0,
90  is_small = 0;
91  for (i=0; i<folds; i++) {
92  sum = 0;
93  for (j=0; j<N; j++) {
94  if (cv_idx[j] == i) sum += 1;
95  }
96  if (sum == floor(N/folds) + 1)
97  is_big++;
98  else
99  is_small++;
100  }
101  mu_assert(is_big == N % folds, "Incorrect number of big folds");
102  mu_assert(is_small == folds - N % folds,
103  "Incorrect number of small folds");
104 
105  // end test code //
106 
107  free(cv_idx);
108 
109  return NULL;
110 }
111 
113 {
114  struct GenData *full = gensvm_init_data();
115  full->K = 3;
116  full->n = 10;
117  full->m = 2;
118  full->r = 2;
119 
120  full->y = Calloc(long, full->n);
121  full->y[0] = 1;
122  full->y[1] = 2;
123  full->y[2] = 3;
124  full->y[3] = 1;
125  full->y[4] = 2;
126  full->y[5] = 3;
127  full->y[6] = 1;
128  full->y[7] = 2;
129  full->y[8] = 3;
130  full->y[9] = 1;
131 
132  full->RAW = Calloc(double, full->n * (full->m+1));
133  matrix_set(full->RAW, full->m+1, 0, 1, 1.0);
134  matrix_set(full->RAW, full->m+1, 0, 2, 1.0);
135  matrix_set(full->RAW, full->m+1, 1, 1, 2.0);
136  matrix_set(full->RAW, full->m+1, 1, 2, 2.0);
137  matrix_set(full->RAW, full->m+1, 2, 1, 3.0);
138  matrix_set(full->RAW, full->m+1, 2, 2, 3.0);
139  matrix_set(full->RAW, full->m+1, 3, 1, 4.0);
140  matrix_set(full->RAW, full->m+1, 3, 2, 4.0);
141  matrix_set(full->RAW, full->m+1, 4, 1, 5.0);
142  matrix_set(full->RAW, full->m+1, 4, 2, 5.0);
143  matrix_set(full->RAW, full->m+1, 5, 1, 6.0);
144  matrix_set(full->RAW, full->m+1, 5, 2, 6.0);
145  matrix_set(full->RAW, full->m+1, 6, 1, 7.0);
146  matrix_set(full->RAW, full->m+1, 6, 2, 7.0);
147  matrix_set(full->RAW, full->m+1, 7, 1, 8.0);
148  matrix_set(full->RAW, full->m+1, 7, 2, 8.0);
149  matrix_set(full->RAW, full->m+1, 8, 1, 9.0);
150  matrix_set(full->RAW, full->m+1, 8, 2, 9.0);
151  matrix_set(full->RAW, full->m+1, 9, 1, 10.0);
152  matrix_set(full->RAW, full->m+1, 9, 2, 10.0);
153  full->Z = full->RAW;
154 
155  long *cv_idx = Calloc(long, full->n);
156  cv_idx[0] = 1;
157  cv_idx[1] = 0;
158  cv_idx[2] = 1;
159  cv_idx[3] = 0;
160  cv_idx[4] = 1;
161  cv_idx[5] = 2;
162  cv_idx[6] = 3;
163  cv_idx[7] = 2;
164  cv_idx[8] = 3;
165  cv_idx[9] = 2;
166 
167  struct GenData *train = gensvm_init_data();
168  struct GenData *test = gensvm_init_data();
169 
170  // start test code //
171  gensvm_get_tt_split(full, train, test, cv_idx, 0);
172 
173  mu_assert(train->n == 8, "train_n incorrect.");
174  mu_assert(test->n == 2, "test_n incorrect.");
175 
176  mu_assert(train->m == 2, "train_m incorrect.");
177  mu_assert(test->m == 2, "test_m incorrect.");
178 
179  mu_assert(train->K == 3, "train_K incorrect.");
180  mu_assert(test->K == 3, "test_K incorrect.");
181 
182  mu_assert(train->y[0] == 1, "train y incorrect.");
183  mu_assert(train->y[1] == 3, "train y incorrect.");
184  mu_assert(train->y[2] == 2, "train y incorrect.");
185  mu_assert(train->y[3] == 3, "train y incorrect.");
186  mu_assert(train->y[4] == 1, "train y incorrect.");
187  mu_assert(train->y[5] == 2, "train y incorrect.");
188  mu_assert(train->y[6] == 3, "train y incorrect.");
189  mu_assert(train->y[7] == 1, "train y incorrect.");
190 
191  mu_assert(test->y[0] == 2, "test y incorrect.");
192  mu_assert(test->y[1] == 1, "test y incorrect.");
193 
194  mu_assert(matrix_get(train->RAW, train->m+1, 0, 0) == 0.0,
195  "train RAW 0, 0 incorrect.");
196  mu_assert(matrix_get(train->RAW, train->m+1, 0, 1) == 1.0,
197  "train RAW 0, 1 incorrect.");
198  mu_assert(matrix_get(train->RAW, train->m+1, 0, 2) == 1.0,
199  "train RAW 0, 2 incorrect.");
200  mu_assert(matrix_get(train->RAW, train->m+1, 1, 0) == 0.0,
201  "train RAW 1, 0 incorrect.");
202  mu_assert(matrix_get(train->RAW, train->m+1, 1, 1) == 3.0,
203  "train RAW 1, 1 incorrect.");
204  mu_assert(matrix_get(train->RAW, train->m+1, 1, 2) == 3.0,
205  "train RAW 1, 2 incorrect.");
206  mu_assert(matrix_get(train->RAW, train->m+1, 2, 0) == 0.0,
207  "train RAW 2, 0 incorrect.");
208  mu_assert(matrix_get(train->RAW, train->m+1, 2, 1) == 5.0,
209  "train RAW 2, 1 incorrect.");
210  mu_assert(matrix_get(train->RAW, train->m+1, 2, 2) == 5.0,
211  "train RAW 2, 2 incorrect.");
212  mu_assert(matrix_get(train->RAW, train->m+1, 3, 0) == 0.0,
213  "train RAW 3, 0 incorrect.");
214  mu_assert(matrix_get(train->RAW, train->m+1, 3, 1) == 6.0,
215  "train RAW 3, 1 incorrect.");
216  mu_assert(matrix_get(train->RAW, train->m+1, 3, 2) == 6.0,
217  "train RAW 3, 2 incorrect.");
218  mu_assert(matrix_get(train->RAW, train->m+1, 4, 0) == 0.0,
219  "train RAW 4, 0 incorrect.");
220  mu_assert(matrix_get(train->RAW, train->m+1, 4, 1) == 7.0,
221  "train RAW 4, 1 incorrect.");
222  mu_assert(matrix_get(train->RAW, train->m+1, 4, 2) == 7.0,
223  "train RAW 4, 2 incorrect.");
224  mu_assert(matrix_get(train->RAW, train->m+1, 5, 0) == 0.0,
225  "train RAW 5, 0 incorrect.");
226  mu_assert(matrix_get(train->RAW, train->m+1, 5, 1) == 8.0,
227  "train RAW 5, 1 incorrect.");
228  mu_assert(matrix_get(train->RAW, train->m+1, 5, 2) == 8.0,
229  "train RAW 5, 2 incorrect.");
230  mu_assert(matrix_get(train->RAW, train->m+1, 6, 0) == 0.0,
231  "train RAW 6, 0 incorrect.");
232  mu_assert(matrix_get(train->RAW, train->m+1, 6, 1) == 9.0,
233  "train RAW 6, 1 incorrect.");
234  mu_assert(matrix_get(train->RAW, train->m+1, 6, 2) == 9.0,
235  "train RAW 6, 2 incorrect.");
236  mu_assert(matrix_get(train->RAW, train->m+1, 7, 0) == 0.0,
237  "train RAW 7, 0 incorrect.");
238  mu_assert(matrix_get(train->RAW, train->m+1, 7, 1) == 10.0,
239  "train RAW 7, 1 incorrect.");
240  mu_assert(matrix_get(train->RAW, train->m+1, 7, 2) == 10.0,
241  "train RAW 7, 2 incorrect.");
242 
243  mu_assert(matrix_get(test->RAW, train->m+1, 0, 0) == 0.0,
244  "test RAW 0, 0 incorrect.");
245  mu_assert(matrix_get(test->RAW, train->m+1, 0, 1) == 2.0,
246  "test RAW 0, 1 incorrect.");
247  mu_assert(matrix_get(test->RAW, train->m+1, 0, 2) == 2.0,
248  "test RAW 0, 2 incorrect.");
249  mu_assert(matrix_get(test->RAW, train->m+1, 1, 0) == 0.0,
250  "test RAW 1, 0 incorrect.");
251  mu_assert(matrix_get(test->RAW, train->m+1, 1, 1) == 4.0,
252  "test RAW 1, 1 incorrect.");
253  mu_assert(matrix_get(test->RAW, train->m+1, 1, 2) == 4.0,
254  "test RAW 1, 2 incorrect.");
255 
256  // end test code //
257  gensvm_free_data(full);
258  gensvm_free_data(train);
259  gensvm_free_data(test);
260  free(cv_idx);
261 
262  return NULL;
263 }
264 
266 {
267  struct GenData *full = gensvm_init_data();
268  full->K = 3;
269  full->n = 10;
270  full->m = 2;
271  full->r = 2;
272 
273  full->y = Calloc(long, full->n);
274  full->y[0] = 1;
275  full->y[1] = 2;
276  full->y[2] = 3;
277  full->y[3] = 1;
278  full->y[4] = 2;
279  full->y[5] = 3;
280  full->y[6] = 1;
281  full->y[7] = 2;
282  full->y[8] = 3;
283  full->y[9] = 1;
284 
285  full->RAW = Calloc(double, full->n * (full->m+1));
286  matrix_set(full->RAW, full->m+1, 0, 1, 1.0);
287  matrix_set(full->RAW, full->m+1, 0, 2, 1.0);
288  matrix_set(full->RAW, full->m+1, 1, 1, 2.0);
289  matrix_set(full->RAW, full->m+1, 1, 2, 2.0);
290  matrix_set(full->RAW, full->m+1, 2, 1, 3.0);
291  matrix_set(full->RAW, full->m+1, 2, 2, 3.0);
292  matrix_set(full->RAW, full->m+1, 3, 1, 4.0);
293  matrix_set(full->RAW, full->m+1, 3, 2, 4.0);
294  matrix_set(full->RAW, full->m+1, 4, 1, 5.0);
295  matrix_set(full->RAW, full->m+1, 4, 2, 5.0);
296  matrix_set(full->RAW, full->m+1, 5, 1, 6.0);
297  matrix_set(full->RAW, full->m+1, 5, 2, 6.0);
298  matrix_set(full->RAW, full->m+1, 6, 1, 7.0);
299  matrix_set(full->RAW, full->m+1, 6, 2, 7.0);
300  matrix_set(full->RAW, full->m+1, 7, 1, 8.0);
301  matrix_set(full->RAW, full->m+1, 7, 2, 8.0);
302  matrix_set(full->RAW, full->m+1, 8, 1, 9.0);
303  matrix_set(full->RAW, full->m+1, 8, 2, 9.0);
304  matrix_set(full->RAW, full->m+1, 9, 1, 10.0);
305  matrix_set(full->RAW, full->m+1, 9, 2, 10.0);
306  full->Z = full->RAW;
307 
308  // convert Z to a sparse matrix to test the sparse functions
309  full->spZ = gensvm_dense_to_sparse(full->RAW, full->n, full->m+1);
310  free(full->RAW);
311  full->RAW = NULL;
312  full->Z = NULL;
313 
314  long *cv_idx = Calloc(long, full->n);
315  cv_idx[0] = 1;
316  cv_idx[1] = 0;
317  cv_idx[2] = 1;
318  cv_idx[3] = 0;
319  cv_idx[4] = 1;
320  cv_idx[5] = 2;
321  cv_idx[6] = 3;
322  cv_idx[7] = 2;
323  cv_idx[8] = 3;
324  cv_idx[9] = 2;
325 
326  struct GenData *train = gensvm_init_data();
327  struct GenData *test = gensvm_init_data();
328 
329  // start test code //
330  gensvm_get_tt_split(full, train, test, cv_idx, 0);
331 
332  mu_assert(train->n == 8, "train_n incorrect.");
333  mu_assert(test->n == 2, "test_n incorrect.");
334 
335  mu_assert(train->m == 2, "train_m incorrect.");
336  mu_assert(test->m == 2, "test_m incorrect.");
337 
338  mu_assert(train->K == 3, "train_K incorrect.");
339  mu_assert(test->K == 3, "test_K incorrect.");
340 
341  mu_assert(train->y[0] == 1, "train y incorrect.");
342  mu_assert(train->y[1] == 3, "train y incorrect.");
343  mu_assert(train->y[2] == 2, "train y incorrect.");
344  mu_assert(train->y[3] == 3, "train y incorrect.");
345  mu_assert(train->y[4] == 1, "train y incorrect.");
346  mu_assert(train->y[5] == 2, "train y incorrect.");
347  mu_assert(train->y[6] == 3, "train y incorrect.");
348  mu_assert(train->y[7] == 1, "train y incorrect.");
349 
350  mu_assert(test->y[0] == 2, "test y incorrect.");
351  mu_assert(test->y[1] == 1, "test y incorrect.");
352 
353  // check the train GenSparse struct
354  mu_assert(train->spZ->nnz == 16, "train nnz incorrect");
355  mu_assert(train->spZ->n_row == 8, "train n_row incorrect");
356  mu_assert(train->spZ->n_col == 3, "train n_col incorrect");
357 
358  mu_assert(train->spZ->values[0] == 1.0, "Wrong train value at 0");
359  mu_assert(train->spZ->values[1] == 1.0, "Wrong train value at 1");
360  mu_assert(train->spZ->values[2] == 3.0, "Wrong train value at 2");
361  mu_assert(train->spZ->values[3] == 3.0, "Wrong train value at 3");
362  mu_assert(train->spZ->values[4] == 5.0, "Wrong train value at 4");
363  mu_assert(train->spZ->values[5] == 5.0, "Wrong train value at 5");
364  mu_assert(train->spZ->values[6] == 6.0, "Wrong train value at 6");
365  mu_assert(train->spZ->values[7] == 6.0, "Wrong train value at 7");
366  mu_assert(train->spZ->values[8] == 7.0, "Wrong train value at 8");
367  mu_assert(train->spZ->values[9] == 7.0, "Wrong train value at 9");
368  mu_assert(train->spZ->values[10] == 8.0, "Wrong train value at 10");
369  mu_assert(train->spZ->values[11] == 8.0, "Wrong train value at 11");
370  mu_assert(train->spZ->values[12] == 9.0, "Wrong train value at 12");
371  mu_assert(train->spZ->values[13] == 9.0, "Wrong train value at 13");
372  mu_assert(train->spZ->values[14] == 10.0, "Wrong train value at 14");
373  mu_assert(train->spZ->values[15] == 10.0, "Wrong train value at 15");
374 
375  mu_assert(train->spZ->ia[0] == 0, "Wrong train ia at 0");
376  mu_assert(train->spZ->ia[1] == 2, "Wrong train ia at 1");
377  mu_assert(train->spZ->ia[2] == 4, "Wrong train ia at 2");
378  mu_assert(train->spZ->ia[3] == 6, "Wrong train ia at 3");
379  mu_assert(train->spZ->ia[4] == 8, "Wrong train ia at 4");
380  mu_assert(train->spZ->ia[5] == 10, "Wrong train ia at 5");
381  mu_assert(train->spZ->ia[6] == 12, "Wrong train ia at 6");
382  mu_assert(train->spZ->ia[7] == 14, "Wrong train ia at 7");
383  mu_assert(train->spZ->ia[8] == 16, "Wrong train ia at 8");
384 
385  mu_assert(train->spZ->ja[0] == 1, "Wrong train ja at 0");
386  mu_assert(train->spZ->ja[1] == 2, "Wrong train ja at 1");
387  mu_assert(train->spZ->ja[2] == 1, "Wrong train ja at 2");
388  mu_assert(train->spZ->ja[3] == 2, "Wrong train ja at 3");
389  mu_assert(train->spZ->ja[4] == 1, "Wrong train ja at 4");
390  mu_assert(train->spZ->ja[5] == 2, "Wrong train ja at 5");
391  mu_assert(train->spZ->ja[6] == 1, "Wrong train ja at 6");
392  mu_assert(train->spZ->ja[7] == 2, "Wrong train ja at 7");
393  mu_assert(train->spZ->ja[8] == 1, "Wrong train ja at 8");
394  mu_assert(train->spZ->ja[9] == 2, "Wrong train ja at 9");
395  mu_assert(train->spZ->ja[10] == 1, "Wrong train ja at 10");
396  mu_assert(train->spZ->ja[11] == 2, "Wrong train ja at 11");
397  mu_assert(train->spZ->ja[12] == 1, "Wrong train ja at 12");
398  mu_assert(train->spZ->ja[13] == 2, "Wrong train ja at 13");
399  mu_assert(train->spZ->ja[14] == 1, "Wrong train ja at 14");
400  mu_assert(train->spZ->ja[15] == 2, "Wrong train ja at 15");
401 
402  // check the test GenSparse struct
403  mu_assert(test->spZ->nnz == 4, "test nnz incorrect");
404  mu_assert(test->spZ->n_row == 2, "test n_row incorrect");
405  mu_assert(test->spZ->n_col == 3, "test n_col incorrect");
406 
407  mu_assert(test->spZ->values[0] == 2.0, "Wrong test value at 0");
408  mu_assert(test->spZ->values[1] == 2.0, "Wrong test value at 1");
409  mu_assert(test->spZ->values[2] == 4.0, "Wrong test value at 2");
410  mu_assert(test->spZ->values[3] == 4.0, "Wrong test value at 3");
411 
412  mu_assert(test->spZ->ia[0] == 0, "Wrong test ia at 0");
413  mu_assert(test->spZ->ia[1] == 2, "Wrong test ia at 1");
414  mu_assert(test->spZ->ia[2] == 4, "Wrong test ia at 2");
415 
416  mu_assert(test->spZ->ja[0] == 1, "Wrong test ja at 0");
417  mu_assert(test->spZ->ja[1] == 2, "Wrong test ja at 1");
418  mu_assert(test->spZ->ja[2] == 1, "Wrong test ja at 2");
419  mu_assert(test->spZ->ja[3] == 2, "Wrong test ja at 3");
420 
421  // end test code //
422  gensvm_free_data(full);
423  gensvm_free_data(train);
424  gensvm_free_data(test);
425  free(cv_idx);
426 
427  return NULL;
428 }
429 
430 char *all_tests()
431 {
432  mu_suite_start();
437 
438  return NULL;
439 }
440 
Minimal unit testing framework for C.
#define Calloc(type, size)
Definition: gensvm_memory.h:40
long * ja
column indices, should be of length nnz
Definition: gensvm_sparse.h:67
long n_col
number of columns of the original matrix
Definition: gensvm_sparse.h:60
#define mu_assert(test, message)
Definition: minunit.h:29
long K
number of classes
Definition: gensvm_base.h:58
char * all_tests()
#define matrix_get(M, cols, i, j)
double * Z
Definition: gensvm_base.h:68
long nnz
number of nonzero elements
Definition: gensvm_sparse.h:56
#define mu_run_test(test)
Definition: minunit.h:35
long * y
array of class labels, 1..K
Definition: gensvm_base.h:66
char * test_make_cv_split_1()
A structure to represent the data.
Definition: gensvm_base.h:57
double * values
actual nonzero values, should be of length nnz
Definition: gensvm_sparse.h:63
void gensvm_get_tt_split(struct GenData *full_data, struct GenData *train_data, struct GenData *test_data, long *cv_idx, long fold_idx)
Wrapper around sparse/dense versions of this function.
void gensvm_make_cv_split(long N, long folds, long *cv_idx)
Create a cross validation split vector.
char * test_get_tt_split_dense()
void gensvm_free_data(struct GenData *data)
Free allocated GenData struct.
Definition: gensvm_base.c:73
RUN_TESTS(all_tests)
char * test_make_cv_split_2()
long r
number of eigenvalues (width of Z)
Definition: gensvm_base.h:64
char * test_get_tt_split_sparse()
long m
number of predictors (width of RAW)
Definition: gensvm_base.h:62
struct GenSparse * gensvm_dense_to_sparse(double *A, long rows, long cols)
Convert a dense matrix to a GenSparse structure if advantageous.
#define matrix_set(M, cols, i, j, val)
long n
number of instances
Definition: gensvm_base.h:60
struct GenData * gensvm_init_data(void)
Initialize a GenData structure.
Definition: gensvm_base.c:45
long * ia
cumulative row lengths, should be of length n_row+1
Definition: gensvm_sparse.h:65
#define mu_suite_start()
Definition: minunit.h:24
double * RAW
augmented raw data matrix
Definition: gensvm_base.h:73
struct GenSparse * spZ
sparse representation of the augmented data matrix
Definition: gensvm_base.h:71
Header file for gensvm_cv_util.c.
long n_row
number of rows of the original matrix
Definition: gensvm_sparse.h:58