SHOGUN
v1.1.0
|
00001 /* 00002 * This program is free software; you can redistribute it and/or modify 00003 * it under the terms of the GNU General Public License as published by 00004 * the Free Software Foundation; either version 3 of the License, or 00005 * (at your option) any later version. 00006 * 00007 * Written (W) 2007-2008 Vojtech Franc 00008 * Written (W) 2007-2009 Soeren Sonnenburg 00009 * Copyright (C) 2007-2009 Fraunhofer Institute FIRST and Max-Planck-Society 00010 */ 00011 00012 #ifndef _WDSVMOCAS_H___ 00013 #define _WDSVMOCAS_H___ 00014 00015 #include <shogun/lib/common.h> 00016 #include <shogun/machine/Machine.h> 00017 #include <shogun/classifier/svm/SVMOcas.h> 00018 #include <shogun/features/StringFeatures.h> 00019 #include <shogun/features/Labels.h> 00020 00021 namespace shogun 00022 { 00023 template <class ST> class CStringFeatures; 00024 00026 class CWDSVMOcas : public CMachine 00027 { 00028 public: 00030 CWDSVMOcas(); 00031 00036 CWDSVMOcas(E_SVM_TYPE type); 00037 00046 CWDSVMOcas( 00047 float64_t C, int32_t d, int32_t from_d, 00048 CStringFeatures<uint8_t>* traindat, CLabels* trainlab); 00049 virtual ~CWDSVMOcas(); 00050 00055 virtual inline EClassifierType get_classifier_type() { return CT_WDSVMOCAS; } 00056 00063 inline void set_C(float64_t c_neg, float64_t c_pos) { C1=c_neg; C2=c_pos; } 00064 00069 inline float64_t get_C1() { return C1; } 00070 00075 inline float64_t get_C2() { return C2; } 00076 00081 inline void set_epsilon(float64_t eps) { epsilon=eps; } 00082 00087 inline float64_t get_epsilon() { return epsilon; } 00088 00093 inline void set_features(CStringFeatures<uint8_t>* feat) 00094 { 00095 SG_UNREF(features); 00096 SG_REF(feat); 00097 features=feat; 00098 } 00099 00104 inline CStringFeatures<uint8_t>* get_features() 00105 { 00106 SG_REF(features); 00107 return features; 00108 } 00109 00114 inline void set_bias_enabled(bool enable_bias) { use_bias=enable_bias; } 00115 00120 inline bool get_bias_enabled() { return use_bias; } 00121 00126 inline void set_bufsize(int32_t sz) { bufsize=sz; } 00127 00132 inline int32_t get_bufsize() { return bufsize; } 00133 00139 inline void set_degree(int32_t d, int32_t from_d) 00140 { 00141 degree=d; 00142 from_degree=from_d; 00143 } 00144 00149 inline int32_t get_degree() { return degree; } 00150 00155 CLabels* apply(); 00156 00162 virtual CLabels* apply(CFeatures* data); 00163 00169 inline virtual float64_t apply(int32_t num) 00170 { 00171 ASSERT(features); 00172 if (!wd_weights) 00173 set_wd_weights(); 00174 00175 int32_t len=0; 00176 float64_t sum=0; 00177 bool free_vec; 00178 uint8_t* vec=features->get_feature_vector(num, len, free_vec); 00179 //SG_INFO("len %d, string_length %d\n", len, string_length); 00180 ASSERT(len==string_length); 00181 00182 for (int32_t j=0; j<string_length; j++) 00183 { 00184 int32_t offs=w_dim_single_char*j; 00185 int32_t val=0; 00186 for (int32_t k=0; (j+k<string_length) && (k<degree); k++) 00187 { 00188 val=val*alphabet_size + vec[j+k]; 00189 sum+=wd_weights[k] * w[offs+val]; 00190 offs+=w_offsets[k]; 00191 } 00192 } 00193 features->free_feature_vector(vec, num, free_vec); 00194 return sum/normalization_const; 00195 } 00196 00198 inline void set_normalization_const() 00199 { 00200 ASSERT(features); 00201 normalization_const=0; 00202 for (int32_t i=0; i<degree; i++) 00203 normalization_const+=(string_length-i)*wd_weights[i]*wd_weights[i]; 00204 00205 normalization_const=CMath::sqrt(normalization_const); 00206 SG_DEBUG("normalization_const:%f\n", normalization_const); 00207 } 00208 00213 inline float64_t get_normalization_const() { return normalization_const; } 00214 00215 00216 protected: 00221 int32_t set_wd_weights(); 00222 00231 static void compute_W( 00232 float64_t *sq_norm_W, float64_t *dp_WoldW, float64_t *alpha, 00233 uint32_t nSel, void* ptr ); 00234 00241 static float64_t update_W(float64_t t, void* ptr ); 00242 00248 static void* add_new_cut_helper(void* ptr); 00249 00258 static int add_new_cut( 00259 float64_t *new_col_H, uint32_t *new_cut, uint32_t cut_length, 00260 uint32_t nSel, void* ptr ); 00261 00267 static void* compute_output_helper(void* ptr); 00268 00274 static int compute_output( float64_t *output, void* ptr ); 00275 00282 static int sort( float64_t* vals, float64_t* data, uint32_t size); 00283 00285 static inline void print(ocas_return_value_T value) 00286 { 00287 return; 00288 } 00289 00290 00292 inline virtual const char* get_name() const { return "WDSVMOcas"; } 00293 00294 protected: 00303 virtual bool train_machine(CFeatures* data=NULL); 00304 00305 protected: 00307 CStringFeatures<uint8_t>* features; 00309 bool use_bias; 00311 int32_t bufsize; 00313 float64_t C1; 00315 float64_t C2; 00317 float64_t epsilon; 00319 E_SVM_TYPE method; 00320 00322 int32_t degree; 00324 int32_t from_degree; 00326 float32_t* wd_weights; 00328 int32_t num_vec; 00330 int32_t string_length; 00332 int32_t alphabet_size; 00333 00335 float64_t normalization_const; 00336 00338 float64_t bias; 00340 float64_t old_bias; 00342 int32_t* w_offsets; 00344 int32_t w_dim; 00346 int32_t w_dim_single_char; 00348 float32_t* w; 00350 float32_t* old_w; 00352 float64_t* lab; 00353 00355 float32_t** cuts; 00357 float64_t* cp_bias; 00358 }; 00359 } 00360 #endif