SHOGUN
v1.1.0
|
00001 /* 00002 * This program is free software; you can redistribute it and/or modify 00003 * it under the terms of the GNU General Public License as published by 00004 * the Free Software Foundation; either version 3 of the License, or 00005 * (at your option) any later version. 00006 * 00007 * Written (W) 1999-2009 Soeren Sonnenburg 00008 * Copyright (C) 1999-2009 Fraunhofer Institute FIRST and Max-Planck-Society 00009 */ 00010 00011 #include <shogun/classifier/svm/LibSVM.h> 00012 #include <shogun/io/SGIO.h> 00013 00014 using namespace shogun; 00015 00016 CLibSVM::CLibSVM(LIBSVM_SOLVER_TYPE st) 00017 : CSVM(), model(NULL), solver_type(st) 00018 { 00019 } 00020 00021 CLibSVM::CLibSVM(float64_t C, CKernel* k, CLabels* lab) 00022 : CSVM(C, k, lab), model(NULL), solver_type(LIBSVM_C_SVC) 00023 { 00024 problem = svm_problem(); 00025 } 00026 00027 CLibSVM::~CLibSVM() 00028 { 00029 } 00030 00031 00032 bool CLibSVM::train_machine(CFeatures* data) 00033 { 00034 struct svm_node* x_space; 00035 00036 ASSERT(labels && labels->get_num_labels()); 00037 ASSERT(labels->is_two_class_labeling()); 00038 00039 if (data) 00040 { 00041 if (labels->get_num_labels() != data->get_num_vectors()) 00042 SG_ERROR("Number of training vectors does not match number of labels\n"); 00043 kernel->init(data, data); 00044 } 00045 00046 problem.l=labels->get_num_labels(); 00047 SG_INFO( "%d trainlabels\n", problem.l); 00048 00049 // set linear term 00050 if (m_linear_term.vlen>0) 00051 { 00052 if (labels->get_num_labels()!=m_linear_term.vlen) 00053 SG_ERROR("Number of training vectors does not match length of linear term\n"); 00054 00055 // set with linear term from base class 00056 problem.pv = get_linear_term_array(); 00057 } 00058 else 00059 { 00060 // fill with minus ones 00061 problem.pv = SG_MALLOC(float64_t, problem.l); 00062 00063 for (int i=0; i!=problem.l; i++) 00064 problem.pv[i] = -1.0; 00065 } 00066 00067 problem.y=SG_MALLOC(float64_t, problem.l); 00068 problem.x=SG_MALLOC(struct svm_node*, problem.l); 00069 problem.C=SG_MALLOC(float64_t, problem.l); 00070 00071 x_space=SG_MALLOC(struct svm_node, 2*problem.l); 00072 00073 for (int32_t i=0; i<problem.l; i++) 00074 { 00075 problem.y[i]=labels->get_label(i); 00076 problem.x[i]=&x_space[2*i]; 00077 x_space[2*i].index=i; 00078 x_space[2*i+1].index=-1; 00079 } 00080 00081 int32_t weights_label[2]={-1,+1}; 00082 float64_t weights[2]={1.0,get_C2()/get_C1()}; 00083 00084 ASSERT(kernel && kernel->has_features()); 00085 ASSERT(kernel->get_num_vec_lhs()==problem.l); 00086 00087 param.svm_type=solver_type; // C SVM or NU_SVM 00088 param.kernel_type = LINEAR; 00089 param.degree = 3; 00090 param.gamma = 0; // 1/k 00091 param.coef0 = 0; 00092 param.nu = get_nu(); 00093 param.kernel=kernel; 00094 param.cache_size = kernel->get_cache_size(); 00095 param.max_train_time = max_train_time; 00096 param.C = get_C1(); 00097 param.eps = epsilon; 00098 param.p = 0.1; 00099 param.shrinking = 1; 00100 param.nr_weight = 2; 00101 param.weight_label = weights_label; 00102 param.weight = weights; 00103 param.use_bias = get_bias_enabled(); 00104 00105 const char* error_msg = svm_check_parameter(&problem, ¶m); 00106 00107 if(error_msg) 00108 SG_ERROR("Error: %s\n",error_msg); 00109 00110 model = svm_train(&problem, ¶m); 00111 00112 if (model) 00113 { 00114 ASSERT(model->nr_class==2); 00115 ASSERT((model->l==0) || (model->l>0 && model->SV && model->sv_coef && model->sv_coef[0])); 00116 00117 int32_t num_sv=model->l; 00118 00119 create_new_model(num_sv); 00120 CSVM::set_objective(model->objective); 00121 00122 float64_t sgn=model->label[0]; 00123 00124 set_bias(-sgn*model->rho[0]); 00125 00126 for (int32_t i=0; i<num_sv; i++) 00127 { 00128 set_support_vector(i, (model->SV[i])->index); 00129 set_alpha(i, sgn*model->sv_coef[0][i]); 00130 } 00131 00132 SG_FREE(problem.x); 00133 SG_FREE(problem.y); 00134 SG_FREE(problem.pv); 00135 SG_FREE(problem.C); 00136 00137 00138 SG_FREE(x_space); 00139 00140 svm_destroy_model(model); 00141 model=NULL; 00142 return true; 00143 } 00144 else 00145 return false; 00146 }