SHOGUN
v1.1.0
|
00001 /* 00002 * This program is free software; you can redistribute it and/or modify 00003 * it under the terms of the GNU General Public License as published by 00004 * the Free Software Foundation; either version 3 of the License, or 00005 * (at your option) any later version. 00006 * 00007 * Written (W) 1999-2009 Soeren Sonnenburg 00008 * Copyright (C) 1999-2009 Fraunhofer Institute FIRST and Max-Planck-Society 00009 */ 00010 00011 #include <shogun/lib/common.h> 00012 #include <shogun/io/SGIO.h> 00013 #include <shogun/base/Parallel.h> 00014 #include <shogun/base/Parameter.h> 00015 00016 #include <shogun/classifier/svm/SVM.h> 00017 #include <shogun/classifier/mkl/MKL.h> 00018 00019 #include <string.h> 00020 00021 #ifdef HAVE_PTHREAD 00022 #include <pthread.h> 00023 #endif 00024 00025 using namespace shogun; 00026 00027 CSVM::CSVM(int32_t num_sv) 00028 : CKernelMachine() 00029 { 00030 set_defaults(num_sv); 00031 } 00032 00033 CSVM::CSVM(float64_t C, CKernel* k, CLabels* lab) 00034 : CKernelMachine() 00035 { 00036 set_defaults(); 00037 set_C(C,C); 00038 set_labels(lab); 00039 set_kernel(k); 00040 } 00041 00042 CSVM::~CSVM() 00043 { 00044 m_linear_term.destroy_vector(); 00045 SG_UNREF(mkl); 00046 } 00047 00048 void CSVM::set_defaults(int32_t num_sv) 00049 { 00050 SG_ADD(&C1, "C1", "", MS_AVAILABLE); 00051 SG_ADD(&C2, "C2", "", MS_AVAILABLE); 00052 SG_ADD(&svm_loaded, "svm_loaded", "SVM is loaded.", MS_NOT_AVAILABLE); 00053 SG_ADD(&epsilon, "epsilon", "", MS_NOT_AVAILABLE); 00054 SG_ADD(&tube_epsilon, "tube_epsilon", 00055 "Tube epsilon for support vector regression.", MS_NOT_AVAILABLE); 00056 SG_ADD(&nu, "nu", "", MS_AVAILABLE); 00057 SG_ADD(&objective, "objective", "", MS_NOT_AVAILABLE); 00058 SG_ADD(&qpsize, "qpsize", "", MS_NOT_AVAILABLE); 00059 SG_ADD(&use_shrinking, "use_shrinking", "Shrinking shall be used.", 00060 MS_NOT_AVAILABLE); 00061 SG_ADD((CSGObject**) &mkl, "mkl", "MKL object that svm optimizers need.", 00062 MS_NOT_AVAILABLE); 00063 SG_ADD(&m_linear_term, "linear_term", "Linear term in qp.", 00064 MS_NOT_AVAILABLE); 00065 00066 callback=NULL; 00067 mkl=NULL; 00068 00069 svm_loaded=false; 00070 00071 epsilon=1e-5; 00072 tube_epsilon=1e-2; 00073 00074 nu=0.5; 00075 C1=1; 00076 C2=1; 00077 00078 objective=0; 00079 00080 qpsize=41; 00081 use_bias=true; 00082 use_shrinking=true; 00083 use_batch_computation=true; 00084 use_linadd=true; 00085 00086 if (num_sv>0) 00087 create_new_model(num_sv); 00088 } 00089 00090 bool CSVM::load(FILE* modelfl) 00091 { 00092 bool result=true; 00093 char char_buffer[1024]; 00094 int32_t int_buffer; 00095 float64_t double_buffer; 00096 int32_t line_number=1; 00097 00098 SG_SET_LOCALE_C; 00099 00100 if (fscanf(modelfl,"%4s\n", char_buffer)==EOF) 00101 { 00102 result=false; 00103 SG_ERROR( "error in svm file, line nr:%d\n", line_number); 00104 } 00105 else 00106 { 00107 char_buffer[4]='\0'; 00108 if (strcmp("%SVM", char_buffer)!=0) 00109 { 00110 result=false; 00111 SG_ERROR( "error in svm file, line nr:%d\n", line_number); 00112 } 00113 line_number++; 00114 } 00115 00116 int_buffer=0; 00117 if (fscanf(modelfl," numsv=%d; \n", &int_buffer) != 1) 00118 { 00119 result=false; 00120 SG_ERROR( "error in svm file, line nr:%d\n", line_number); 00121 } 00122 00123 if (!feof(modelfl)) 00124 line_number++; 00125 00126 SG_INFO( "loading %ld support vectors\n",int_buffer); 00127 create_new_model(int_buffer); 00128 00129 if (fscanf(modelfl," kernel='%s'; \n", char_buffer) != 1) 00130 { 00131 result=false; 00132 SG_ERROR( "error in svm file, line nr:%d\n", line_number); 00133 } 00134 00135 if (!feof(modelfl)) 00136 line_number++; 00137 00138 double_buffer=0; 00139 00140 if (fscanf(modelfl," b=%lf; \n", &double_buffer) != 1) 00141 { 00142 result=false; 00143 SG_ERROR( "error in svm file, line nr:%d\n", line_number); 00144 } 00145 00146 if (!feof(modelfl)) 00147 line_number++; 00148 00149 set_bias(double_buffer); 00150 00151 if (fscanf(modelfl,"%8s\n", char_buffer) == EOF) 00152 { 00153 result=false; 00154 SG_ERROR( "error in svm file, line nr:%d\n", line_number); 00155 } 00156 else 00157 { 00158 char_buffer[9]='\0'; 00159 if (strcmp("alphas=[", char_buffer)!=0) 00160 { 00161 result=false; 00162 SG_ERROR( "error in svm file, line nr:%d\n", line_number); 00163 } 00164 line_number++; 00165 } 00166 00167 for (int32_t i=0; i<get_num_support_vectors(); i++) 00168 { 00169 double_buffer=0; 00170 int_buffer=0; 00171 00172 if (fscanf(modelfl," \[%lf,%d]; \n", &double_buffer, &int_buffer) != 2) 00173 { 00174 result=false; 00175 SG_ERROR( "error in svm file, line nr:%d\n", line_number); 00176 } 00177 00178 if (!feof(modelfl)) 00179 line_number++; 00180 00181 set_support_vector(i, int_buffer); 00182 set_alpha(i, double_buffer); 00183 } 00184 00185 if (fscanf(modelfl,"%2s", char_buffer) == EOF) 00186 { 00187 result=false; 00188 SG_ERROR( "error in svm file, line nr:%d\n", line_number); 00189 } 00190 else 00191 { 00192 char_buffer[3]='\0'; 00193 if (strcmp("];", char_buffer)!=0) 00194 { 00195 result=false; 00196 SG_ERROR( "error in svm file, line nr:%d\n", line_number); 00197 } 00198 line_number++; 00199 } 00200 00201 svm_loaded=result; 00202 SG_RESET_LOCALE; 00203 return result; 00204 } 00205 00206 bool CSVM::save(FILE* modelfl) 00207 { 00208 SG_SET_LOCALE_C; 00209 00210 if (!kernel) 00211 SG_ERROR("Kernel not defined!\n"); 00212 00213 SG_INFO( "Writing model file..."); 00214 fprintf(modelfl,"%%SVM\n"); 00215 fprintf(modelfl,"numsv=%d;\n", get_num_support_vectors()); 00216 fprintf(modelfl,"kernel='%s';\n", kernel->get_name()); 00217 fprintf(modelfl,"b=%+10.16e;\n",get_bias()); 00218 00219 fprintf(modelfl, "alphas=\[\n"); 00220 00221 for(int32_t i=0; i<get_num_support_vectors(); i++) 00222 fprintf(modelfl,"\t[%+10.16e,%d];\n", 00223 CSVM::get_alpha(i), get_support_vector(i)); 00224 00225 fprintf(modelfl, "];\n"); 00226 00227 SG_DONE(); 00228 SG_RESET_LOCALE; 00229 return true ; 00230 } 00231 00232 void CSVM::set_callback_function(CMKL* m, bool (*cb) 00233 (CMKL* mkl, const float64_t* sumw, const float64_t suma)) 00234 { 00235 SG_UNREF(mkl); 00236 mkl=m; 00237 SG_REF(mkl); 00238 00239 callback=cb; 00240 } 00241 00242 float64_t CSVM::compute_svm_dual_objective() 00243 { 00244 int32_t n=get_num_support_vectors(); 00245 00246 if (labels && kernel) 00247 { 00248 objective=0; 00249 for (int32_t i=0; i<n; i++) 00250 { 00251 int32_t ii=get_support_vector(i); 00252 objective-=get_alpha(i)*labels->get_label(ii); 00253 00254 for (int32_t j=0; j<n; j++) 00255 { 00256 int32_t jj=get_support_vector(j); 00257 objective+=0.5*get_alpha(i)*get_alpha(j)*kernel->kernel(ii,jj); 00258 } 00259 } 00260 } 00261 else 00262 SG_ERROR( "cannot compute objective, labels or kernel not set\n"); 00263 00264 return objective; 00265 } 00266 00267 float64_t CSVM::compute_svm_primal_objective() 00268 { 00269 int32_t n=get_num_support_vectors(); 00270 float64_t regularizer=0; 00271 float64_t loss=0; 00272 00273 00274 00275 if (labels && kernel) 00276 { 00277 float64_t C2_tmp=C1; 00278 if(C2>0) 00279 { 00280 C2_tmp=C2; 00281 } 00282 00283 for (int32_t i=0; i<n; i++) 00284 { 00285 int32_t ii=get_support_vector(i); 00286 for (int32_t j=0; j<n; j++) 00287 { 00288 int32_t jj=get_support_vector(j); 00289 regularizer-=0.5*get_alpha(i)*get_alpha(j)*kernel->kernel(ii,jj); 00290 } 00291 00292 loss-=(C1*(-get_label(ii)+1)/2.0 + C2_tmp*(get_label(ii)+1)/2.0 )*CMath::max(0.0, 1.0-get_label(ii)*apply(ii)); 00293 } 00294 00295 } 00296 else 00297 SG_ERROR( "cannot compute objective, labels or kernel not set\n"); 00298 00299 return regularizer+loss; 00300 } 00301 00302 float64_t* CSVM::get_linear_term_array() 00303 { 00304 if (m_linear_term.vlen==0) 00305 return NULL; 00306 00307 SGVector<float64_t> a(m_linear_term.vlen); 00308 memcpy(a.vector, m_linear_term.vector, 00309 m_linear_term.vlen*sizeof(float64_t)); 00310 00311 return a.vector; 00312 } 00313 00314 void CSVM::set_linear_term(SGVector<float64_t> linear_term) 00315 { 00316 ASSERT(linear_term.vector); 00317 00318 if (!labels) 00319 SG_ERROR("Please assign labels first!\n"); 00320 00321 int32_t num_labels=labels->get_num_labels(); 00322 00323 if (num_labels != linear_term.vlen) 00324 { 00325 SG_ERROR("Number of labels (%d) does not match number" 00326 "of entries (%d) in linear term \n", num_labels, linear_term.vlen); 00327 } 00328 00329 m_linear_term.destroy_vector(); 00330 00331 m_linear_term.vlen=linear_term.vlen; 00332 m_linear_term=SGVector<float64_t> (linear_term.vlen); 00333 memcpy(m_linear_term.vector, linear_term.vector, 00334 linear_term.vlen*sizeof(float64_t)); 00335 } 00336 00337 SGVector<float64_t> CSVM::get_linear_term() 00338 { 00339 return m_linear_term; 00340 }