SHOGUN
v1.1.0
|
00001 /* 00002 * This program is free software; you can redistribute it and/or modify 00003 * it under the terms of the GNU General Public License as published by 00004 * the Free Software Foundation; either version 3 of the License, or 00005 * (at your option) any later version. 00006 * 00007 * Written (W) 1999-2008 Gunnar Raetsch 00008 * Copyright (C) 1999-2009 Fraunhofer Institute FIRST and Max-Planck-Society 00009 */ 00010 00011 #ifndef __PLIF_H__ 00012 #define __PLIF_H__ 00013 00014 #include <shogun/lib/common.h> 00015 #include <shogun/mathematics/Math.h> 00016 #include <shogun/structure/PlifBase.h> 00017 00018 namespace shogun 00019 { 00020 00022 enum ETransformType 00023 { 00025 T_LINEAR, 00027 T_LOG, 00029 T_LOG_PLUS1, 00031 T_LOG_PLUS3, 00033 T_LINEAR_PLUS3 00034 }; 00035 00037 class CPlif: public CPlifBase 00038 { 00039 public: 00044 CPlif(int32_t len=0); 00045 virtual ~CPlif(); 00046 00048 void init_penalty_struct_cache(); 00049 00056 float64_t lookup_penalty_svm( 00057 float64_t p_value, float64_t *d_values) const; 00058 00065 float64_t lookup_penalty( 00066 float64_t p_value, float64_t* svm_values) const; 00067 00074 float64_t lookup_penalty(int32_t p_value, float64_t* svm_values) const; 00075 00081 inline float64_t lookup(float64_t p_value) 00082 { 00083 ASSERT(use_svm == 0); 00084 return lookup_penalty(p_value, NULL); 00085 } 00086 00088 void penalty_clear_derivative(); 00089 00096 void penalty_add_derivative_svm( 00097 float64_t p_value, float64_t* svm_values, float64_t factor) ; 00098 00105 void penalty_add_derivative(float64_t p_value, float64_t* svm_values, float64_t factor); 00106 00112 const float64_t * get_cum_derivative(int32_t & p_len) const 00113 { 00114 p_len = len; 00115 return cum_derivatives; 00116 } 00117 00123 bool set_transform_type(const char *type_str); 00124 00129 const char* get_transform_type() 00130 { 00131 if (transform== T_LINEAR) 00132 return "linear"; 00133 else if (transform== T_LOG) 00134 return "log"; 00135 else if (transform== T_LOG_PLUS1) 00136 return "log(+1)"; 00137 else if (transform== T_LOG_PLUS3) 00138 return "log(+3)"; 00139 else if (transform== T_LINEAR_PLUS3) 00140 return "(+3)"; 00141 else 00142 SG_ERROR("wrong type"); 00143 return ""; 00144 } 00145 00146 00151 void set_id(int32_t p_id) 00152 { 00153 id=p_id; 00154 } 00155 00160 int32_t get_id() const 00161 { 00162 return id; 00163 } 00164 00169 int32_t get_max_id() const 00170 { 00171 return get_id(); 00172 } 00173 00178 void set_use_svm(int32_t p_use_svm) 00179 { 00180 invalidate_cache(); 00181 use_svm=p_use_svm; 00182 } 00183 00188 int32_t get_use_svm() const 00189 { 00190 return use_svm; 00191 } 00192 00197 virtual bool uses_svm_values() const 00198 { 00199 return (get_use_svm()!=0); 00200 } 00201 00206 void set_use_cache(int32_t p_use_cache) 00207 { 00208 invalidate_cache(); 00209 use_cache=p_use_cache; 00210 } 00211 00214 void invalidate_cache() 00215 { 00216 SG_FREE(cache); 00217 cache=NULL; 00218 } 00219 00224 int32_t get_use_cache() 00225 { 00226 return use_cache; 00227 } 00228 00235 void set_plif( 00236 int32_t p_len, float64_t *p_limits, float64_t* p_penalties) 00237 { 00238 ASSERT(len==p_len); 00239 00240 for (int32_t i=0; i<len; i++) 00241 { 00242 limits[i]=p_limits[i]; 00243 penalties[i]=p_penalties[i]; 00244 } 00245 00246 invalidate_cache(); 00247 penalty_clear_derivative(); 00248 } 00249 00254 void set_plif_limits(SGVector<float64_t> p_limits) 00255 { 00256 ASSERT(len==p_limits.vlen); 00257 00258 for (int32_t i=0; i<len; i++) 00259 limits[i]=p_limits.vector[i]; 00260 00261 invalidate_cache(); 00262 penalty_clear_derivative(); 00263 } 00264 00265 00270 void set_plif_penalty(SGVector<float64_t> p_penalties) 00271 { 00272 ASSERT(len==p_penalties.vlen); 00273 00274 for (int32_t i=0; i<len; i++) 00275 penalties[i]=p_penalties.vector[i]; 00276 00277 invalidate_cache(); 00278 penalty_clear_derivative(); 00279 } 00280 00285 void set_plif_length(int32_t p_len) 00286 { 00287 if (len!=p_len) 00288 { 00289 len=p_len; 00290 SG_FREE(limits); 00291 SG_FREE(penalties); 00292 SG_FREE(cum_derivatives); 00293 00294 SG_DEBUG( "set_plif len=%i\n", p_len); 00295 limits=SG_MALLOC(float64_t, len); 00296 penalties=SG_MALLOC(float64_t, len); 00297 cum_derivatives=SG_MALLOC(float64_t, len); 00298 } 00299 00300 for (int32_t i=0; i<len; i++) 00301 { 00302 limits[i]=0.0; 00303 penalties[i]=0.0; 00304 } 00305 00306 invalidate_cache(); 00307 penalty_clear_derivative(); 00308 } 00309 00314 float64_t* get_plif_limits() 00315 { 00316 return limits; 00317 } 00318 00323 float64_t* get_plif_penalties() 00324 { 00325 return penalties; 00326 } 00331 inline void set_max_value(float64_t p_max_value) 00332 { 00333 max_value=p_max_value; 00334 invalidate_cache(); 00335 } 00336 00341 virtual float64_t get_max_value() const 00342 { 00343 return max_value; 00344 } 00345 00350 inline void set_min_value(float64_t p_min_value) 00351 { 00352 min_value=p_min_value; 00353 invalidate_cache(); 00354 } 00355 00360 virtual float64_t get_min_value() const 00361 { 00362 return min_value; 00363 } 00364 00369 void set_plif_name(char *p_name); 00370 00375 inline char* get_plif_name() const 00376 { 00377 if (name) 00378 return name; 00379 else 00380 { 00381 char buf[20]; 00382 sprintf(buf, "plif%i", id); 00383 //name = strdup(buf); 00384 return strdup(buf); 00385 } 00386 } 00387 00392 bool get_do_calc(); 00393 00398 void set_do_calc(bool b); 00399 00403 void get_used_svms(int32_t* num_svms, int32_t* svm_ids); 00404 00409 inline int32_t get_plif_len() 00410 { 00411 return len; 00412 } 00413 00418 virtual void list_plif() const 00419 { 00420 SG_PRINT("CPlif(min_value=%1.2f, max_value=%1.2f, use_svm=%i)\n", min_value, max_value, use_svm) ; 00421 } 00422 00428 static void delete_penalty_struct(CPlif** PEN, int32_t P); 00429 00431 inline virtual const char* get_name() const { return "Plif"; } 00432 00433 protected: 00435 int32_t len; 00437 float64_t *limits; 00439 float64_t *penalties; 00441 float64_t *cum_derivatives; 00443 float64_t max_value; 00445 float64_t min_value; 00447 float64_t *cache; 00449 enum ETransformType transform; 00451 int32_t id; 00453 char * name; 00455 int32_t use_svm; 00457 bool use_cache; 00461 bool do_calc; 00462 }; 00463 } 00464 #endif