SHOGUN
v1.1.0
|
00001 /* 00002 * This program is free software; you can redistribute it and/or modify 00003 * it under the terms of the GNU General Public License as published by 00004 * the Free Software Foundation; either version 3 of the License, or 00005 * (at your option) any later version. 00006 * 00007 * Written (W) 2008-2009 Soeren Sonnenburg 00008 * Copyright (C) 2008-2009 Fraunhofer Institute FIRST and Max Planck Society 00009 */ 00010 00011 #include <shogun/lib/config.h> 00012 #include <shogun/base/SGObject.h> 00013 #include <shogun/io/SGIO.h> 00014 #include <shogun/base/Parallel.h> 00015 #include <shogun/base/init.h> 00016 #include <shogun/base/Version.h> 00017 #include <shogun/base/Parameter.h> 00018 00019 #include <stdlib.h> 00020 #include <stdio.h> 00021 00022 00023 namespace shogun 00024 { 00025 class CMath; 00026 class Parallel; 00027 class IO; 00028 class Version; 00029 00030 extern CMath* sg_math; 00031 extern Parallel* sg_parallel; 00032 extern SGIO* sg_io; 00033 extern Version* sg_version; 00034 00035 template<> void CSGObject::set_generic<bool>() 00036 { 00037 m_generic = PT_BOOL; 00038 } 00039 00040 template<> void CSGObject::set_generic<char>() 00041 { 00042 m_generic = PT_CHAR; 00043 } 00044 00045 template<> void CSGObject::set_generic<int8_t>() 00046 { 00047 m_generic = PT_INT8; 00048 } 00049 00050 template<> void CSGObject::set_generic<uint8_t>() 00051 { 00052 m_generic = PT_UINT8; 00053 } 00054 00055 template<> void CSGObject::set_generic<int16_t>() 00056 { 00057 m_generic = PT_INT16; 00058 } 00059 00060 template<> void CSGObject::set_generic<uint16_t>() 00061 { 00062 m_generic = PT_UINT16; 00063 } 00064 00065 template<> void CSGObject::set_generic<int32_t>() 00066 { 00067 m_generic = PT_INT32; 00068 } 00069 00070 template<> void CSGObject::set_generic<uint32_t>() 00071 { 00072 m_generic = PT_UINT32; 00073 } 00074 00075 template<> void CSGObject::set_generic<int64_t>() 00076 { 00077 m_generic = PT_INT64; 00078 } 00079 00080 template<> void CSGObject::set_generic<uint64_t>() 00081 { 00082 m_generic = PT_UINT64; 00083 } 00084 00085 template<> void CSGObject::set_generic<float32_t>() 00086 { 00087 m_generic = PT_FLOAT32; 00088 } 00089 00090 template<> void CSGObject::set_generic<float64_t>() 00091 { 00092 m_generic = PT_FLOAT64; 00093 } 00094 00095 template<> void CSGObject::set_generic<floatmax_t>() 00096 { 00097 m_generic = PT_FLOATMAX; 00098 } 00099 00100 } /* namespace shogun */ 00101 00102 using namespace shogun; 00103 00104 CSGObject::CSGObject() 00105 { 00106 init(); 00107 set_global_objects(); 00108 00109 SG_GCDEBUG("SGObject created (%p)\n", this); 00110 } 00111 00112 CSGObject::CSGObject(const CSGObject& orig) 00113 :io(orig.io), parallel(orig.parallel), version(orig.version) 00114 { 00115 init(); 00116 set_global_objects(); 00117 } 00118 00119 CSGObject::~CSGObject() 00120 { 00121 SG_GCDEBUG("SGObject destroyed (%p)\n", this); 00122 00123 #ifdef HAVE_PTHREAD 00124 PTHREAD_LOCK_DESTROY(&m_ref_lock); 00125 #endif 00126 unset_global_objects(); 00127 delete m_parameters; 00128 delete m_model_selection_parameters; 00129 } 00130 00131 #ifdef USE_REFERENCE_COUNTING 00132 00133 int32_t CSGObject::ref() 00134 { 00135 #ifdef HAVE_PTHREAD 00136 PTHREAD_LOCK(&m_ref_lock); 00137 #endif //HAVE_PTHREAD 00138 ++m_refcount; 00139 int32_t count=m_refcount; 00140 #ifdef HAVE_PTHREAD 00141 PTHREAD_UNLOCK(&m_ref_lock); 00142 #endif //HAVE_PTHREAD 00143 SG_GCDEBUG("ref() refcount %ld obj %s (%p) increased\n", count, this->get_name(), this); 00144 return m_refcount; 00145 } 00146 00147 int32_t CSGObject::ref_count() 00148 { 00149 #ifdef HAVE_PTHREAD 00150 PTHREAD_LOCK(&m_ref_lock); 00151 #endif //HAVE_PTHREAD 00152 int32_t count=m_refcount; 00153 #ifdef HAVE_PTHREAD 00154 PTHREAD_UNLOCK(&m_ref_lock); 00155 #endif //HAVE_PTHREAD 00156 SG_GCDEBUG("ref_count(): refcount %d, obj %s (%p)\n", count, this->get_name(), this); 00157 return count; 00158 } 00159 00160 int32_t CSGObject::unref() 00161 { 00162 #ifdef HAVE_PTHREAD 00163 PTHREAD_LOCK(&m_ref_lock); 00164 #endif //HAVE_PTHREAD 00165 if (m_refcount==0 || --m_refcount==0) 00166 { 00167 SG_GCDEBUG("unref() refcount %ld, obj %s (%p) destroying\n", m_refcount, this->get_name(), this); 00168 #ifdef HAVE_PTHREAD 00169 PTHREAD_UNLOCK(&m_ref_lock); 00170 #endif //HAVE_PTHREAD 00171 delete this; 00172 return 0; 00173 } 00174 else 00175 { 00176 SG_GCDEBUG("unref() refcount %ld obj %s (%p) decreased\n", m_refcount, this->get_name(), this); 00177 #ifdef HAVE_PTHREAD 00178 PTHREAD_UNLOCK(&m_ref_lock); 00179 #endif //HAVE_PTHREAD 00180 return m_refcount; 00181 } 00182 } 00183 #endif //USE_REFERENCE_COUNTING 00184 00185 00186 void CSGObject::set_global_objects() 00187 { 00188 if (!sg_io || !sg_parallel || !sg_version) 00189 { 00190 fprintf(stderr, "call init_shogun() before using the library, dying.\n"); 00191 exit(1); 00192 } 00193 00194 SG_REF(sg_io); 00195 SG_REF(sg_parallel); 00196 SG_REF(sg_version); 00197 00198 io=sg_io; 00199 parallel=sg_parallel; 00200 version=sg_version; 00201 } 00202 00203 void CSGObject::unset_global_objects() 00204 { 00205 SG_UNREF(version); 00206 SG_UNREF(parallel); 00207 SG_UNREF(io); 00208 } 00209 00210 void CSGObject::set_global_io(SGIO* new_io) 00211 { 00212 SG_UNREF(sg_io); 00213 sg_io=new_io; 00214 SG_REF(sg_io); 00215 } 00216 00217 SGIO* CSGObject::get_global_io() 00218 { 00219 SG_REF(sg_io); 00220 return sg_io; 00221 } 00222 00223 void CSGObject::set_global_parallel(Parallel* new_parallel) 00224 { 00225 SG_UNREF(sg_parallel); 00226 sg_parallel=new_parallel; 00227 SG_REF(sg_parallel); 00228 } 00229 00230 Parallel* CSGObject::get_global_parallel() 00231 { 00232 SG_REF(sg_parallel); 00233 return sg_parallel; 00234 } 00235 00236 void CSGObject::set_global_version(Version* new_version) 00237 { 00238 SG_UNREF(sg_version); 00239 sg_version=new_version; 00240 SG_REF(sg_version); 00241 } 00242 00243 Version* CSGObject::get_global_version() 00244 { 00245 SG_REF(sg_version); 00246 return sg_version; 00247 } 00248 00249 bool CSGObject::is_generic(EPrimitiveType* generic) const 00250 { 00251 *generic = m_generic; 00252 00253 return m_generic != PT_NOT_GENERIC; 00254 } 00255 00256 void CSGObject::unset_generic() 00257 { 00258 m_generic = PT_NOT_GENERIC; 00259 } 00260 00261 void CSGObject::print_serializable(const char* prefix) 00262 { 00263 SG_PRINT("\n%s\n================================================================================\n", get_name()); 00264 m_parameters->print(prefix); 00265 } 00266 00267 bool CSGObject::save_serializable(CSerializableFile* file, 00268 const char* prefix) 00269 { 00270 SG_DEBUG("START SAVING CSGObject '%s'\n", get_name()); 00271 try 00272 { 00273 save_serializable_pre(); 00274 } 00275 catch (ShogunException e) 00276 { 00277 SG_SWARNING("%s%s::save_serializable_pre(): ShogunException: " 00278 "%s\n", prefix, get_name(), 00279 e.get_exception_string()); 00280 return false; 00281 } 00282 if (!m_save_pre_called) 00283 { 00284 SG_SWARNING("%s%s::save_serializable_pre(): Implementation " 00285 "error: BASE_CLASS::LOAD_SERIALIZABLE_PRE() not " 00286 "called!\n", prefix, get_name()); 00287 return false; 00288 } 00289 00290 /* save parameter version */ 00291 if (!save_parameter_version(file, prefix)) 00292 return false; 00293 00294 if (!m_parameters->save(file, prefix)) 00295 return false; 00296 00297 try 00298 { 00299 save_serializable_post(); 00300 } 00301 catch (ShogunException e) 00302 { 00303 SG_SWARNING("%s%s::save_serializable_post(): ShogunException: " 00304 "%s\n", prefix, get_name(), 00305 e.get_exception_string()); 00306 return false; 00307 } 00308 00309 if (!m_save_post_called) 00310 { 00311 SG_SWARNING("%s%s::save_serializable_post(): Implementation " 00312 "error: BASE_CLASS::LOAD_SERIALIZABLE_POST() not " 00313 "called!\n", prefix, get_name()); 00314 return false; 00315 } 00316 00317 if (prefix == NULL || *prefix == '\0') 00318 file->close(); 00319 00320 SG_DEBUG("DONE SAVING CSGObject '%s' (%p)\n", get_name(), this); 00321 00322 return true;; 00323 } 00324 00325 bool CSGObject::load_serializable(CSerializableFile* file, 00326 const char* prefix) 00327 { 00328 SG_DEBUG("START LOADING CSGObject '%s'\n", get_name()); 00329 try 00330 { 00331 load_serializable_pre(); 00332 } 00333 catch (ShogunException e) 00334 { 00335 SG_SWARNING("%s%s::load_serializable_pre(): ShogunException: " 00336 "%s\n", prefix, get_name(), 00337 e.get_exception_string()); 00338 return false; 00339 } 00340 if (!m_load_pre_called) 00341 { 00342 SG_SWARNING("%s%s::load_serializable_pre(): Implementation " 00343 "error: BASE_CLASS::LOAD_SERIALIZABLE_PRE() not " 00344 "called!\n", prefix, get_name()); 00345 return false; 00346 } 00347 00348 /* try to load version of parameters */ 00349 int32_t file_version=load_parameter_version(file, prefix); 00350 00351 if (file_version<0) 00352 { 00353 SG_WARNING("%s%s::load_serializable(): File contains no parameter " 00354 "version. Seems like your file is from the days before this " 00355 "was introduced. Ignore warning or serialize with this version " 00356 "of shogun to get rid of above and this warnings.\n", 00357 prefix, get_name()); 00358 } 00359 00360 if (file_version>version->get_version_parameter()) 00361 { 00362 SG_WARNING("%s%s::load_serializable(): parameter version of file " 00363 "larger than the one of shogun. Try with a more recent version " 00364 "of shogun.\n", prefix, get_name()); 00365 return false; 00366 } 00367 00368 if (!m_parameters->load(file, prefix)) 00369 return false; 00370 00371 try 00372 { 00373 load_serializable_post(); 00374 } 00375 catch (ShogunException e) 00376 { 00377 SG_SWARNING("%s%s::load_serializable_post(): ShogunException: " 00378 "%s\n", prefix, get_name(), 00379 e.get_exception_string()); 00380 return false; 00381 } 00382 00383 if (!m_load_post_called) 00384 { 00385 SG_SWARNING("%s%s::load_serializable_post(): Implementation " 00386 "error: BASE_CLASS::LOAD_SERIALIZABLE_POST() not " 00387 "called!\n", prefix, get_name()); 00388 return false; 00389 } 00390 SG_DEBUG("DONE LOADING CSGObject '%s' (%p)\n", get_name(), this); 00391 00392 return true; 00393 } 00394 00395 bool CSGObject::save_parameter_version(CSerializableFile* file, 00396 const char* prefix) 00397 { 00398 TSGDataType t(CT_SCALAR, ST_NONE, PT_INT32); 00399 int32_t v=version->get_version_parameter(); 00400 TParameter p(&t, &v, "version_parameter", 00401 "Version of parameters of this object"); 00402 return p.save(file, prefix); 00403 } 00404 00405 int32_t CSGObject::load_parameter_version(CSerializableFile* file, 00406 const char* prefix) 00407 { 00408 TSGDataType t(CT_SCALAR, ST_NONE, PT_INT32); 00409 int32_t v; 00410 TParameter tp(&t, &v, "version_parameter", ""); 00411 if (tp.load(file, prefix)) 00412 return v; 00413 else 00414 return -1; 00415 } 00416 00417 void CSGObject::load_serializable_pre() throw (ShogunException) 00418 { 00419 m_load_pre_called = true; 00420 } 00421 00422 void CSGObject::load_serializable_post() throw (ShogunException) 00423 { 00424 m_load_post_called = true; 00425 } 00426 00427 void CSGObject::save_serializable_pre() throw (ShogunException) 00428 { 00429 m_save_pre_called = true; 00430 } 00431 00432 void CSGObject::save_serializable_post() throw (ShogunException) 00433 { 00434 m_save_post_called = true; 00435 } 00436 00437 #ifdef TRACE_MEMORY_ALLOCS 00438 #include <shogun/lib/Set.h> 00439 extern CSet<shogun::MemoryBlock>* sg_mallocs; 00440 #endif 00441 00442 void CSGObject::init() 00443 { 00444 #ifdef HAVE_PTHREAD 00445 PTHREAD_LOCK_INIT(&m_ref_lock); 00446 #endif 00447 00448 #ifdef TRACE_MEMORY_ALLOCS 00449 if (sg_mallocs) 00450 { 00451 int32_t idx=sg_mallocs->index_of(MemoryBlock(this)); 00452 if (idx>-1) 00453 { 00454 MemoryBlock* b=sg_mallocs->get_element_ptr(idx); 00455 b->set_sgobject(); 00456 } 00457 } 00458 #endif 00459 00460 m_refcount = 0; 00461 io = NULL; 00462 parallel = NULL; 00463 version = NULL; 00464 m_parameters = new Parameter(); 00465 m_model_selection_parameters = new Parameter(); 00466 m_generic = PT_NOT_GENERIC; 00467 m_load_pre_called = false; 00468 m_load_post_called = false; 00469 } 00470 00471 SGVector<char*> CSGObject::get_modelsel_names() 00472 { 00473 SGVector<char*> result=SGVector<char*>( 00474 m_model_selection_parameters->get_num_parameters()); 00475 00476 for (index_t i=0; i<result.vlen; ++i) 00477 result.vector[i]=m_model_selection_parameters->get_parameter(i)->m_name; 00478 00479 return result; 00480 } 00481 00482 char* CSGObject::get_modsel_param_descr(const char* param_name) 00483 { 00484 index_t index=get_modsel_param_index(param_name); 00485 00486 if (index<0) 00487 { 00488 SG_ERROR("There is no model selection parameter called \"%s\" for %s", 00489 param_name, get_name()); 00490 } 00491 00492 return m_model_selection_parameters->get_parameter(index)->m_description; 00493 } 00494 00495 index_t CSGObject::get_modsel_param_index(const char* param_name) 00496 { 00497 /* use fact that names extracted from below method are in same order than 00498 * in m_model_selection_parameters variable */ 00499 SGVector<char*> names=get_modelsel_names(); 00500 00501 /* search for parameter with provided name */ 00502 index_t index=-1; 00503 for (index_t i=0; i<names.vlen; ++i) 00504 { 00505 TParameter* current=m_model_selection_parameters->get_parameter(i); 00506 if (!strcmp(param_name, current->m_name)) 00507 { 00508 index=i; 00509 break; 00510 } 00511 } 00512 00513 /* clean up */ 00514 names.destroy_vector(); 00515 00516 return index; 00517 }