SHOGUN
v1.1.0
|
00001 /* 00002 SVM with stochastic gradient 00003 Copyright (C) 2007- Leon Bottou 00004 00005 This program is free software; you can redistribute it and/or 00006 modify it under the terms of the GNU Lesser General Public 00007 License as published by the Free Software Foundation; either 00008 version 2.1 of the License, or (at your option) any later version. 00009 00010 This program is distributed in the hope that it will be useful, 00011 but WITHOUT ANY WARRANTY; without even the implied warranty of 00012 MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 00013 GNU General Public License for more details. 00014 00015 You should have received a copy of the GNU General Public License 00016 along with this program; if not, write to the Free Software 00017 Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111, USA 00018 $Id: svmsgd.cpp,v 1.13 2007/10/02 20:40:06 cvs Exp $ 00019 00020 Shogun adjustments (w) 2008-2009 Soeren Sonnenburg 00021 */ 00022 00023 #include <shogun/classifier/svm/OnlineSVMSGD.h> 00024 #include <shogun/base/Parameter.h> 00025 #include <shogun/lib/Signal.h> 00026 #include <shogun/loss/HingeLoss.h> 00027 00028 using namespace shogun; 00029 00030 COnlineSVMSGD::COnlineSVMSGD() 00031 : COnlineLinearMachine() 00032 { 00033 init(); 00034 } 00035 00036 COnlineSVMSGD::COnlineSVMSGD(float64_t C) 00037 : COnlineLinearMachine() 00038 { 00039 init(); 00040 00041 C1=C; 00042 C2=C; 00043 } 00044 00045 COnlineSVMSGD::COnlineSVMSGD(float64_t C, CStreamingDotFeatures* traindat) 00046 : COnlineLinearMachine() 00047 { 00048 init(); 00049 C1=C; 00050 C2=C; 00051 00052 set_features(traindat); 00053 } 00054 00055 COnlineSVMSGD::~COnlineSVMSGD() 00056 { 00057 SG_UNREF(loss); 00058 } 00059 00060 void COnlineSVMSGD::set_loss_function(CLossFunction* loss_func) 00061 { 00062 if (loss) 00063 SG_UNREF(loss); 00064 loss=loss_func; 00065 SG_REF(loss); 00066 } 00067 00068 bool COnlineSVMSGD::train(CFeatures* data) 00069 { 00070 if (data) 00071 { 00072 if (!data->has_property(FP_STREAMING_DOT)) 00073 SG_ERROR("Specified features are not of type CStreamingDotFeatures\n"); 00074 set_features((CStreamingDotFeatures*) data); 00075 } 00076 00077 features->start_parser(); 00078 00079 // allocate memory for w and initialize everyting w and bias with 0 00080 ASSERT(features); 00081 ASSERT(features->get_has_labels()); 00082 if (w) 00083 SG_FREE(w); 00084 w_dim=1; 00085 w=new float32_t; 00086 bias=0; 00087 00088 // Shift t in order to have a 00089 // reasonable initial learning rate. 00090 // This assumes |x| \approx 1. 00091 float64_t maxw = 1.0 / sqrt(lambda); 00092 float64_t typw = sqrt(maxw); 00093 float64_t eta0 = typw / CMath::max(1.0,-loss->first_derivative(-typw,1)); 00094 t = 1 / (eta0 * lambda); 00095 00096 SG_INFO("lambda=%f, epochs=%d, eta0=%f\n", lambda, epochs, eta0); 00097 00098 //do the sgd 00099 calibrate(); 00100 if (features->is_seekable()) 00101 features->reset_stream(); 00102 00103 CSignal::clear_cancel(); 00104 00105 ELossType loss_type = loss->get_loss_type(); 00106 bool is_log_loss = false; 00107 if ((loss_type == L_LOGLOSS) || (loss_type == L_LOGLOSSMARGIN)) 00108 is_log_loss = true; 00109 00110 int32_t vec_count; 00111 for(int32_t e=0; e<epochs && (!CSignal::cancel_computations()); e++) 00112 { 00113 vec_count=0; 00114 count = skip; 00115 while (features->get_next_example()) 00116 { 00117 vec_count++; 00118 // Expand w vector if more features are seen in this example 00119 features->expand_if_required(w, w_dim); 00120 00121 float64_t eta = 1.0 / (lambda * t); 00122 float64_t y = features->get_label(); 00123 float64_t z = y * (features->dense_dot(w, w_dim) + bias); 00124 00125 if (z < 1 || is_log_loss) 00126 { 00127 float64_t etd = -eta * loss->first_derivative(z,1); 00128 features->add_to_dense_vec(etd * y / wscale, w, w_dim); 00129 00130 if (use_bias) 00131 { 00132 if (use_regularized_bias) 00133 bias *= 1 - eta * lambda * bscale; 00134 bias += etd * y * bscale; 00135 } 00136 } 00137 00138 if (--count <= 0) 00139 { 00140 float32_t r = 1 - eta * lambda * skip; 00141 if (r < 0.8) 00142 r = pow(1 - eta * lambda, skip); 00143 CMath::scale_vector(r, w, w_dim); 00144 count = skip; 00145 } 00146 t++; 00147 00148 features->release_example(); 00149 } 00150 00151 // If the stream is seekable, reset the stream to the first 00152 // example (for epochs > 1) 00153 if (features->is_seekable() && e < epochs-1) 00154 features->reset_stream(); 00155 else 00156 break; 00157 00158 } 00159 00160 features->end_parser(); 00161 float64_t wnorm = CMath::dot(w,w, w_dim); 00162 SG_INFO("Norm: %.6f, Bias: %.6f\n", wnorm, bias); 00163 00164 return true; 00165 } 00166 00167 void COnlineSVMSGD::calibrate(int32_t max_vec_num) 00168 { 00169 int32_t c_dim=1; 00170 float32_t* c=new float32_t; 00171 00172 // compute average gradient size 00173 int32_t n = 0; 00174 float64_t m = 0; 00175 float64_t r = 0; 00176 00177 while (features->get_next_example()) 00178 { 00179 //Expand c if more features are seen in this example 00180 features->expand_if_required(c, c_dim); 00181 00182 r += features->get_nnz_features_for_vector(); 00183 features->add_to_dense_vec(1, c, c_dim, true); 00184 00185 //waste cpu cycles for readability 00186 //(only changed dims need checking) 00187 m=CMath::max(c, c_dim); 00188 n++; 00189 00190 features->release_example(); 00191 if (n>=max_vec_num || m > 1000) 00192 break; 00193 } 00194 00195 SG_PRINT("Online SGD calibrated using %d vectors.\n", n); 00196 00197 // bias update scaling 00198 bscale = 0.5*m/n; 00199 00200 // compute weight decay skip 00201 skip = (int32_t) ((16 * n * c_dim) / r); 00202 00203 SG_INFO("using %d examples. skip=%d bscale=%.6f\n", n, skip, bscale); 00204 00205 SG_FREE(c); 00206 } 00207 00208 void COnlineSVMSGD::init() 00209 { 00210 t=1; 00211 C1=1; 00212 C2=1; 00213 lambda=1e-4; 00214 wscale=1; 00215 bscale=1; 00216 epochs=1; 00217 skip=1000; 00218 count=1000; 00219 use_bias=true; 00220 00221 use_regularized_bias=false; 00222 00223 loss=new CHingeLoss(); 00224 SG_REF(loss); 00225 00226 m_parameters->add(&C1, "C1", "Cost constant 1."); 00227 m_parameters->add(&C2, "C2", "Cost constant 2."); 00228 m_parameters->add(&lambda, "lambda", "Regularization parameter."); 00229 m_parameters->add(&wscale, "wscale", "W scale"); 00230 m_parameters->add(&bscale, "bscale", "b scale"); 00231 m_parameters->add(&epochs, "epochs", "epochs"); 00232 m_parameters->add(&skip, "skip", "skip"); 00233 m_parameters->add(&count, "count", "count"); 00234 m_parameters->add(&use_bias, "use_bias", "Indicates if bias is used."); 00235 m_parameters->add(&use_regularized_bias, "use_regularized_bias", "Indicates if bias is regularized."); 00236 }