GenSVM
test_gensvm_train.c
Go to the documentation of this file.
1 
27 #include "minunit.h"
28 #include "gensvm_train.h"
29 
31 {
32  struct GenModel *model = gensvm_init_model();
33  struct GenModel *seed = gensvm_init_model();
34  struct GenData *data = gensvm_init_data();
35 
36  // note that model n, m, K are set by the train function
37  model->p = 1.2143;
38  model->kappa = 0.90298;
39  model->lambda = 0.00219038;
40  model->epsilon = 1e-15;
41  model->weight_idx = 1;
42  model->kerneltype = K_LINEAR;
43 
44  data->n = 10;
45  data->m = 3;
46  data->K = 4;
47  data->RAW = Calloc(double, data->n * (data->m+1));
48  data->Z = data->RAW;
49  data->y = Calloc(long, data->n);
50 
51  matrix_set(data->Z, data->m+1, 0, 0, 1.0000000000000000);
52  matrix_set(data->Z, data->m+1, 0, 1, 0.8056271362589000);
53  matrix_set(data->Z, data->m+1, 0, 2, 0.4874175854113872);
54  matrix_set(data->Z, data->m+1, 0, 3, 0.4453015882771756);
55  matrix_set(data->Z, data->m+1, 1, 0, 1.0000000000000000);
56  matrix_set(data->Z, data->m+1, 1, 1, 0.7940590105180981);
57  matrix_set(data->Z, data->m+1, 1, 2, 0.1861049005485224);
58  matrix_set(data->Z, data->m+1, 1, 3, 0.8469394287449229);
59  matrix_set(data->Z, data->m+1, 2, 0, 1.0000000000000000);
60  matrix_set(data->Z, data->m+1, 2, 1, 0.0294257611061681);
61  matrix_set(data->Z, data->m+1, 2, 2, 0.0242717976065267);
62  matrix_set(data->Z, data->m+1, 2, 3, 0.5039128672814752);
63  matrix_set(data->Z, data->m+1, 3, 0, 1.0000000000000000);
64  matrix_set(data->Z, data->m+1, 3, 1, 0.1746563833537603);
65  matrix_set(data->Z, data->m+1, 3, 2, 0.9135736087631979);
66  matrix_set(data->Z, data->m+1, 3, 3, 0.5270258081021366);
67  matrix_set(data->Z, data->m+1, 4, 0, 1.0000000000000000);
68  matrix_set(data->Z, data->m+1, 4, 1, 0.0022298761599785);
69  matrix_set(data->Z, data->m+1, 4, 2, 0.3773482059713607);
70  matrix_set(data->Z, data->m+1, 4, 3, 0.8009654729622842);
71  matrix_set(data->Z, data->m+1, 5, 0, 1.0000000000000000);
72  matrix_set(data->Z, data->m+1, 5, 1, 0.6638830667081945);
73  matrix_set(data->Z, data->m+1, 5, 2, 0.6467607601353914);
74  matrix_set(data->Z, data->m+1, 5, 3, 0.0434948735457108);
75  matrix_set(data->Z, data->m+1, 6, 0, 1.0000000000000000);
76  matrix_set(data->Z, data->m+1, 6, 1, 0.0770493004546461);
77  matrix_set(data->Z, data->m+1, 6, 2, 0.3699566427075194);
78  matrix_set(data->Z, data->m+1, 6, 3, 0.7863539761080217);
79  matrix_set(data->Z, data->m+1, 7, 0, 1.0000000000000000);
80  matrix_set(data->Z, data->m+1, 7, 1, 0.2685233952731509);
81  matrix_set(data->Z, data->m+1, 7, 2, 0.8539966432782011);
82  matrix_set(data->Z, data->m+1, 7, 3, 0.0967159557826836);
83  matrix_set(data->Z, data->m+1, 8, 0, 1.0000000000000000);
84  matrix_set(data->Z, data->m+1, 8, 1, 0.1163951898554611);
85  matrix_set(data->Z, data->m+1, 8, 2, 0.7667861436369238);
86  matrix_set(data->Z, data->m+1, 8, 3, 0.5031912600213351);
87  matrix_set(data->Z, data->m+1, 9, 0, 1.0000000000000000);
88  matrix_set(data->Z, data->m+1, 9, 1, 0.2290251898688216);
89  matrix_set(data->Z, data->m+1, 9, 2, 0.4401981048538806);
90  matrix_set(data->Z, data->m+1, 9, 3, 0.0884616753393881);
91 
92  matrix_set(data->y, 1, 0, 0, 2);
93  matrix_set(data->y, 1, 0, 1, 1);
94  matrix_set(data->y, 1, 0, 2, 3);
95  matrix_set(data->y, 1, 0, 3, 2);
96  matrix_set(data->y, 1, 0, 4, 3);
97  matrix_set(data->y, 1, 0, 5, 2);
98  matrix_set(data->y, 1, 0, 6, 4);
99  matrix_set(data->y, 1, 0, 7, 1);
100  matrix_set(data->y, 1, 0, 8, 3);
101  matrix_set(data->y, 1, 0, 9, 4);
102 
103  seed->V = Calloc(double, (data->m+1)*(data->K-1));
104  matrix_set(seed->V, data->K-1, 0, 0, 0.8233234072519983);
105  matrix_set(seed->V, data->K-1, 0, 1, 0.7701104553132680);
106  matrix_set(seed->V, data->K-1, 0, 2, 0.1102697774064020);
107  matrix_set(seed->V, data->K-1, 1, 0, 0.7956168453294307);
108  matrix_set(seed->V, data->K-1, 1, 1, 0.3267543833513200);
109  matrix_set(seed->V, data->K-1, 1, 2, 0.8659836346403005);
110  matrix_set(seed->V, data->K-1, 2, 0, 0.5777227081256917);
111  matrix_set(seed->V, data->K-1, 2, 1, 0.3693175185473680);
112  matrix_set(seed->V, data->K-1, 2, 2, 0.2728942849022845);
113  matrix_set(seed->V, data->K-1, 3, 0, 0.4426030703804438);
114  matrix_set(seed->V, data->K-1, 3, 1, 0.2456426390463990);
115  matrix_set(seed->V, data->K-1, 3, 2, 0.2665038412777220);
116 
117  // start test code //
118  gensvm_train(model, data, seed);
119 
120  mu_assert(model->n == data->n, "Incorrect model n");
121  mu_assert(model->m == data->m, "Incorrect model m");
122  mu_assert(model->K == data->K, "Incorrect model K");
123 
124  double eps = 1e-13;
125  mu_assert(fabs(matrix_get(model->V, model->K-1, 0, 0) -
126  -1.1907736868272805) < eps,
127  "Incorrect model->V at 0, 0");
128  mu_assert(fabs(matrix_get(model->V, model->K-1, 0, 1) -
129  1.8651287814979396) < eps,
130  "Incorrect model->V at 0, 1");
131  mu_assert(fabs(matrix_get(model->V, model->K-1, 0, 2) -
132  1.7250030581662932) < eps,
133  "Incorrect model->V at 0, 2");
134  mu_assert(fabs(matrix_get(model->V, model->K-1, 1, 0) -
135  0.7925100058806183) < eps,
136  "Incorrect model->V at 1, 0");
137  mu_assert(fabs(matrix_get(model->V, model->K-1, 1, 1) -
138  -3.6093428916761665) < eps,
139  "Incorrect model->V at 1, 1");
140  mu_assert(fabs(matrix_get(model->V, model->K-1, 1, 2) -
141  -1.3394018960329377) < eps,
142  "Incorrect model->V at 1, 2");
143  mu_assert(fabs(matrix_get(model->V, model->K-1, 2, 0) -
144  1.5203132433193016) < eps,
145  "Incorrect model->V at 2, 0");
146  mu_assert(fabs(matrix_get(model->V, model->K-1, 2, 1) -
147  -1.9118604362643852) < eps,
148  "Incorrect model->V at 2, 1");
149  mu_assert(fabs(matrix_get(model->V, model->K-1, 2, 2) -
150  -1.7939246097629342) < eps,
151  "Incorrect model->V at 2, 2");
152  mu_assert(fabs(matrix_get(model->V, model->K-1, 3, 0) -
153  0.0658817457370326) < eps,
154  "Incorrect model->V at 3, 0");
155  mu_assert(fabs(matrix_get(model->V, model->K-1, 3, 1) -
156  0.6547924025329720) < eps,
157  "Incorrect model->V at 3, 1");
158  mu_assert(fabs(matrix_get(model->V, model->K-1, 3, 2) -
159  -0.6773346708737853) < eps,
160  "Incorrect model->V at 3, 2");
161 
162  // end test code //
163 
164  gensvm_free_model(model);
165  gensvm_free_model(seed);
166  gensvm_free_data(data);
167 
168  return NULL;
169 }
170 
172 {
173  struct GenModel *model = gensvm_init_model();
174  struct GenData *data = gensvm_init_data();
175 
176  // note that model n, m, K are set by the train function
177  model->p = 1.2143;
178  model->kappa = 0.90298;
179  model->lambda = 0.00219038;
180  model->epsilon = 1e-15;
181  model->weight_idx = 1;
182  model->kerneltype = K_RBF;
183  model->gamma = 0.348;
184  model->kernel_eigen_cutoff = 5e-3;
185 
186  data->n = 10;
187  data->m = 5;
188  data->K = 4;
189  data->RAW = Calloc(double, data->n * (data->m+1));
190  data->Z = data->RAW;
191  matrix_set(data->Z, data->m+1, 0, 0, 1.0000000000000000);
192  matrix_set(data->Z, data->m+1, 0, 1, 0.0657799204744603);
193  matrix_set(data->Z, data->m+1, 0, 2, 0.2576653302581353);
194  matrix_set(data->Z, data->m+1, 0, 3, 0.0221000752651170);
195  matrix_set(data->Z, data->m+1, 0, 4, 0.6666929354133441);
196  matrix_set(data->Z, data->m+1, 0, 5, 0.6178892590244618);
197  matrix_set(data->Z, data->m+1, 1, 0, 1.0000000000000000);
198  matrix_set(data->Z, data->m+1, 1, 1, 0.9797668012781366);
199  matrix_set(data->Z, data->m+1, 1, 2, 0.7636361573939686);
200  matrix_set(data->Z, data->m+1, 1, 3, 0.3195806959299131);
201  matrix_set(data->Z, data->m+1, 1, 4, 0.2947771273705799);
202  matrix_set(data->Z, data->m+1, 1, 5, 0.8358899802514324);
203  matrix_set(data->Z, data->m+1, 2, 0, 1.0000000000000000);
204  matrix_set(data->Z, data->m+1, 2, 1, 0.9473849700145257);
205  matrix_set(data->Z, data->m+1, 2, 2, 0.8682867844262768);
206  matrix_set(data->Z, data->m+1, 2, 3, 0.7116177283612393);
207  matrix_set(data->Z, data->m+1, 2, 4, 0.5092752476335579);
208  matrix_set(data->Z, data->m+1, 2, 5, 0.1046097156193449);
209  matrix_set(data->Z, data->m+1, 3, 0, 1.0000000000000000);
210  matrix_set(data->Z, data->m+1, 3, 1, 0.5846585351601830);
211  matrix_set(data->Z, data->m+1, 3, 2, 0.4076887966131124);
212  matrix_set(data->Z, data->m+1, 3, 3, 0.8661556045821296);
213  matrix_set(data->Z, data->m+1, 3, 4, 0.0904082115920005);
214  matrix_set(data->Z, data->m+1, 3, 5, 0.0799888711622944);
215  matrix_set(data->Z, data->m+1, 4, 0, 1.0000000000000000);
216  matrix_set(data->Z, data->m+1, 4, 1, 0.8112201081242789);
217  matrix_set(data->Z, data->m+1, 4, 2, 0.3112642417912803);
218  matrix_set(data->Z, data->m+1, 4, 3, 0.7902557587124555);
219  matrix_set(data->Z, data->m+1, 4, 4, 0.3001992968661185);
220  matrix_set(data->Z, data->m+1, 4, 5, 0.6030590437920392);
221  matrix_set(data->Z, data->m+1, 5, 0, 1.0000000000000000);
222  matrix_set(data->Z, data->m+1, 5, 1, 0.0098576324913424);
223  matrix_set(data->Z, data->m+1, 5, 2, 0.5686603332895077);
224  matrix_set(data->Z, data->m+1, 5, 3, 0.9933970661175713);
225  matrix_set(data->Z, data->m+1, 5, 4, 0.5215400841900655);
226  matrix_set(data->Z, data->m+1, 5, 5, 0.4307310515440625);
227  matrix_set(data->Z, data->m+1, 6, 0, 1.0000000000000000);
228  matrix_set(data->Z, data->m+1, 6, 1, 0.2773296707204919);
229  matrix_set(data->Z, data->m+1, 6, 2, 0.5114254316901164);
230  matrix_set(data->Z, data->m+1, 6, 3, 0.5057613745592034);
231  matrix_set(data->Z, data->m+1, 6, 4, 0.6411421568717217);
232  matrix_set(data->Z, data->m+1, 6, 5, 0.3114658800558432);
233  matrix_set(data->Z, data->m+1, 7, 0, 1.0000000000000000);
234  matrix_set(data->Z, data->m+1, 7, 1, 0.7195909422652624);
235  matrix_set(data->Z, data->m+1, 7, 2, 0.7754155342547566);
236  matrix_set(data->Z, data->m+1, 7, 3, 0.5955643008534165);
237  matrix_set(data->Z, data->m+1, 7, 4, 0.5920949759391909);
238  matrix_set(data->Z, data->m+1, 7, 5, 0.7029537245575100);
239  matrix_set(data->Z, data->m+1, 8, 0, 1.0000000000000000);
240  matrix_set(data->Z, data->m+1, 8, 1, 0.3792168380438625);
241  matrix_set(data->Z, data->m+1, 8, 2, 0.1920178667928286);
242  matrix_set(data->Z, data->m+1, 8, 3, 0.2742847467912714);
243  matrix_set(data->Z, data->m+1, 8, 4, 0.2337979820454409);
244  matrix_set(data->Z, data->m+1, 8, 5, 0.3978991644742557);
245  matrix_set(data->Z, data->m+1, 9, 0, 1.0000000000000000);
246  matrix_set(data->Z, data->m+1, 9, 1, 0.0797813938980598);
247  matrix_set(data->Z, data->m+1, 9, 2, 0.5863311792537960);
248  matrix_set(data->Z, data->m+1, 9, 3, 0.8565105304166337);
249  matrix_set(data->Z, data->m+1, 9, 4, 0.8266471128109379);
250  matrix_set(data->Z, data->m+1, 9, 5, 0.8070610088865674);
251 
252  data->y = Calloc(long, data->n);
253  matrix_set(data->y, 1, 0, 0, 2);
254  matrix_set(data->y, 1, 0, 1, 1);
255  matrix_set(data->y, 1, 0, 2, 3);
256  matrix_set(data->y, 1, 0, 3, 2);
257  matrix_set(data->y, 1, 0, 4, 3);
258  matrix_set(data->y, 1, 0, 5, 2);
259  matrix_set(data->y, 1, 0, 6, 4);
260  matrix_set(data->y, 1, 0, 7, 1);
261  matrix_set(data->y, 1, 0, 8, 3);
262  matrix_set(data->y, 1, 0, 9, 4);
263 
264  // start test code //
265 
266  // because the kernel eigendecomposition isn't known in advance,
267  // there's no way to seed the model when using kernels. We therefore
268  // use seed == NULL here. Note that due to the Memset in
269  // gensvm_reallocate_model(), V will be a matrix of zeros after
270  // reallocation, so we compare with the V = 0 result from Octave.
271  gensvm_train(model, data, NULL);
272 
273  mu_assert(model->n == data->n, "Incorrect model n");
274  mu_assert(model->m == data->r, "Incorrect model m");
275  mu_assert(model->K == data->K, "Incorrect model K");
276 
277  double eps = 1e-13;
278  mu_assert(fabs(matrix_get(data->Sigma, 1, 0, 0) -
279  7.8302939172918506) < eps,
280  "Incorrect data->Sigma at 0, 0");
281  mu_assert(fabs(matrix_get(data->Sigma, 1, 1, 0) -
282  0.7947913383766066) < eps,
283  "Incorrect data->Sigma at 1, 0");
284  mu_assert(fabs(matrix_get(data->Sigma, 1, 2, 0) -
285  0.5288740088908547) < eps,
286  "Incorrect data->Sigma at 2, 0");
287  mu_assert(fabs(matrix_get(data->Sigma, 1, 3, 0) -
288  0.4537982052555444) < eps,
289  "Incorrect data->Sigma at 3, 0");
290  mu_assert(fabs(matrix_get(data->Sigma, 1, 4, 0) -
291  0.2226012271232192) < eps,
292  "Incorrect data->Sigma at 4, 0");
293  mu_assert(fabs(matrix_get(data->Sigma, 1, 5, 0) -
294  0.0743004417495061) < eps,
295  "Incorrect data->Sigma at 5, 0");
296 
297  // we need a large eps here because there are numerical precision
298  // differences between the C and Octave implementations. We also
299  // compare with absolute values because of variability in the
300  // eigendecomposition.
301  eps = 1e-7;
302  mu_assert(fabs(fabs(matrix_get(model->V, model->K-1, 0, 0)) -
303  fabs(5.0555413160638665)) < eps,
304  "Incorrect model->V at 0, 0");
305  mu_assert(fabs(fabs(matrix_get(model->V, model->K-1, 0, 1)) -
306  fabs(-2.2586632211763198)) < eps,
307  "Incorrect model->V at 0, 1");
308  mu_assert(fabs(fabs(matrix_get(model->V, model->K-1, 0, 2)) -
309  fabs(-4.5572671806963143)) < eps,
310  "Incorrect model->V at 0, 2");
311  mu_assert(fabs(fabs(matrix_get(model->V, model->K-1, 1, 0)) -
312  fabs(-1.9627432869558412)) < eps,
313  "Incorrect model->V at 1, 0");
314  mu_assert(fabs(fabs(matrix_get(model->V, model->K-1, 1, 1)) -
315  fabs(0.9934555242449399)) < eps,
316  "Incorrect model->V at 1, 1");
317  mu_assert(fabs(fabs(matrix_get(model->V, model->K-1, 1, 2)) -
318  fabs(1.7855287218670219)) < eps,
319  "Incorrect model->V at 1, 2");
320  mu_assert(fabs(fabs(matrix_get(model->V, model->K-1, 2, 0)) -
321  fabs(1.9393083227054353)) < eps,
322  "Incorrect model->V at 2, 0");
323  mu_assert(fabs(fabs(matrix_get(model->V, model->K-1, 2, 1)) -
324  fabs(-1.1958487809502740)) < eps,
325  "Incorrect model->V at 2, 1");
326  mu_assert(fabs(fabs(matrix_get(model->V, model->K-1, 2, 2)) -
327  fabs(2.1140967864804359)) < eps,
328  "Incorrect model->V at 2, 2");
329  mu_assert(fabs(fabs(matrix_get(model->V, model->K-1, 3, 0)) -
330  fabs(2.3909204618652535)) < eps,
331  "Incorrect model->V at 3, 0");
332  mu_assert(fabs(fabs(matrix_get(model->V, model->K-1, 3, 1)) -
333  fabs(-0.2834554569573399)) < eps,
334  "Incorrect model->V at 3, 1");
335  mu_assert(fabs(fabs(matrix_get(model->V, model->K-1, 3, 2)) -
336  fabs(1.0926232371314393)) < eps,
337  "Incorrect model->V at 3, 2");
338  mu_assert(fabs(fabs(matrix_get(model->V, model->K-1, 4, 0)) -
339  fabs(3.3374545494113272)) < eps,
340  "Incorrect model->V at 4, 0");
341  mu_assert(fabs(fabs(matrix_get(model->V, model->K-1, 4, 1)) -
342  fabs(1.6699291195221897)) < eps,
343  "Incorrect model->V at 4, 1");
344  mu_assert(fabs(fabs(matrix_get(model->V, model->K-1, 4, 2)) -
345  fabs(-1.4345249893609275)) < eps,
346  "Incorrect model->V at 4, 2");
347  mu_assert(fabs(fabs(matrix_get(model->V, model->K-1, 5, 0)) -
348  fabs(-0.0221825925355533)) < eps,
349  "Incorrect model->V at 5, 0");
350  mu_assert(fabs(fabs(matrix_get(model->V, model->K-1, 5, 1)) -
351  fabs(-0.1216077739550210)) < eps,
352  "Incorrect model->V at 5, 1");
353  mu_assert(fabs(fabs(matrix_get(model->V, model->K-1, 5, 2)) -
354  fabs(-0.7900947982642630)) < eps,
355  "Incorrect model->V at 5, 2");
356  mu_assert(fabs(fabs(matrix_get(model->V, model->K-1, 6, 0)) -
357  fabs(-0.0076471781062262)) < eps,
358  "Incorrect model->V at 6, 0");
359  mu_assert(fabs(fabs(matrix_get(model->V, model->K-1, 6, 1)) -
360  fabs(-0.8781872510019056)) < eps,
361  "Incorrect model->V at 6, 1");
362  mu_assert(fabs(fabs(matrix_get(model->V, model->K-1, 6, 2)) -
363  fabs(-0.2782284589344380)) < eps,
364  "Incorrect model->V at 6, 2");
365  // end test code //
366 
367  gensvm_free_model(model);
368  gensvm_free_data(data);
369 
370  return NULL;
371 }
372 
373 char *all_tests()
374 {
375  mu_suite_start();
376 
379 
380  return NULL;
381 }
382 
Minimal unit testing framework for C.
#define Calloc(type, size)
Definition: gensvm_memory.h:40
double epsilon
stopping criterion for the IM algorithm.
Definition: gensvm_base.h:101
double p
parameter for the L-p norm in the loss function
Definition: gensvm_base.h:103
#define mu_assert(test, message)
Definition: minunit.h:29
long K
number of classes
Definition: gensvm_base.h:58
#define matrix_get(M, cols, i, j)
RUN_TESTS(all_tests)
char * all_tests()
double * Z
Definition: gensvm_base.h:68
void gensvm_free_model(struct GenModel *model)
Free allocated GenModel struct.
Definition: gensvm_base.c:211
int weight_idx
which weights to use (1 = unit, 2 = group)
Definition: gensvm_base.h:93
double * V
augmented weight matrix
Definition: gensvm_base.h:115
#define mu_run_test(test)
Definition: minunit.h:35
long * y
array of class labels, 1..K
Definition: gensvm_base.h:66
struct GenModel * gensvm_init_model(void)
Initialize a GenModel structure.
Definition: gensvm_base.c:102
A structure to represent the data.
Definition: gensvm_base.h:57
void gensvm_train(struct GenModel *model, struct GenData *data, struct GenModel *seed_model)
Utility function for training a GenSVM model.
Definition: gensvm_train.c:44
Header file for gensvm_train.c.
char * test_gensvm_train_seed_kernel()
A structure to represent a single GenSVM model.
Definition: gensvm_base.h:92
double * Sigma
eigenvalues from the reduced eigendecomposition
Definition: gensvm_base.h:75
long n
number of instances in the dataset
Definition: gensvm_base.h:97
void gensvm_free_data(struct GenData *data)
Free allocated GenData struct.
Definition: gensvm_base.c:73
long r
number of eigenvalues (width of Z)
Definition: gensvm_base.h:64
double kappa
parameter for the Huber hinge function
Definition: gensvm_base.h:105
long K
number of classes in the dataset
Definition: gensvm_base.h:95
char * test_gensvm_train_seed_linear()
long m
number of predictors (width of RAW)
Definition: gensvm_base.h:62
#define matrix_set(M, cols, i, j, val)
KernelType kerneltype
type of kernel used in the model
Definition: gensvm_base.h:136
double gamma
kernel parameter for RBF, poly, and sigmoid
Definition: gensvm_base.h:109
long n
number of instances
Definition: gensvm_base.h:60
struct GenData * gensvm_init_data(void)
Initialize a GenData structure.
Definition: gensvm_base.c:45
long m
number of predictor variables in the dataset
Definition: gensvm_base.h:99
#define mu_suite_start()
Definition: minunit.h:24
double * RAW
augmented raw data matrix
Definition: gensvm_base.h:73
double kernel_eigen_cutoff
cutoff value for the ratio of eigenvalues in the reduced eigendecomposition
Definition: gensvm_base.h:138
double lambda
regularization parameter in the loss function
Definition: gensvm_base.h:107