• Main Page
  • Namespaces
  • Data Structures
  • Files
  • File List
  • Globals

/data/development/ViennaCL/ViennaCL-1.1.2/viennacl/linalg/direct_solve.hpp

Go to the documentation of this file.
00001 /* =======================================================================
00002    Copyright (c) 2010, Institute for Microelectronics, TU Vienna.
00003    http://www.iue.tuwien.ac.at
00004                              -----------------
00005                      ViennaCL - The Vienna Computing Library
00006                              -----------------
00007                             
00008    authors:    Karl Rupp                          rupp@iue.tuwien.ac.at
00009                Florian Rudolf                     flo.rudy+viennacl@gmail.com
00010                Josef Weinbub                      weinbub@iue.tuwien.ac.at
00011 
00012    license:    MIT (X11), see file LICENSE in the ViennaCL base directory
00013 ======================================================================= */
00014 
00015 #ifndef _VIENNACL_DIRECT_SOLVE_HPP_
00016 #define _VIENNACL_DIRECT_SOLVE_HPP_
00017 
00022 #include "viennacl/vector.hpp"
00023 #include "viennacl/matrix.hpp"
00024 #include "viennacl/tools/matrix_kernel_class_deducer.hpp"
00025 #include "viennacl/tools/matrix_solve_kernel_class_deducer.hpp"
00026 #include "viennacl/ocl/kernel.hpp"
00027 #include "viennacl/ocl/device.hpp"
00028 #include "viennacl/ocl/handle.hpp"
00029 
00030 
00031 namespace viennacl
00032 {
00033   namespace linalg
00034   {
00036 
00041     template<typename SCALARTYPE, typename F1, typename F2, unsigned int A1, unsigned int A2, typename SOLVERTAG>
00042     void inplace_solve(const matrix<SCALARTYPE, F1, A1> & mat,
00043                        matrix<SCALARTYPE, F2, A2> & B,
00044                        SOLVERTAG)
00045     {
00046       assert(mat.size1() == mat.size2());
00047       assert(mat.size2() == B.size1());
00048       
00049       typedef typename viennacl::tools::MATRIX_SOLVE_KERNEL_CLASS_DEDUCER< matrix<SCALARTYPE, F1, A1>,
00050                                                                            matrix<SCALARTYPE, F2, A2> >::ResultType    KernelClass;
00051       KernelClass::init();
00052       
00053       std::stringstream ss;
00054       ss << SOLVERTAG::name() << "_solve";
00055       viennacl::ocl::kernel & k = viennacl::ocl::get_kernel(KernelClass::program_name(), ss.str());
00056 
00057       k.global_work_size(0, B.size2() * k.local_work_size());
00058       viennacl::ocl::enqueue(k(mat, mat.size1(), mat.size2(), mat.internal_size1(), mat.internal_size2(),
00059                                                              B,   B.size1(),   B.size2(),   B.internal_size1(),   B.internal_size2()));        
00060     }
00061     
00067     template<typename SCALARTYPE, typename F1, typename F2, unsigned int A1, unsigned int A2, typename SOLVERTAG>
00068     void inplace_solve(const matrix<SCALARTYPE, F1, A1> & mat,
00069                        const matrix_expression< const matrix<SCALARTYPE, F2, A2>,
00070                                                 const matrix<SCALARTYPE, F2, A2>,
00071                                                 op_trans> & B,
00072                        SOLVERTAG)
00073     {
00074       assert(mat.size1() == mat.size2());
00075       assert(mat.size2() == B.lhs().size2());
00076       
00077       typedef typename viennacl::tools::MATRIX_SOLVE_KERNEL_CLASS_DEDUCER< matrix<SCALARTYPE, F1, A1>,
00078                                                                            matrix<SCALARTYPE, F2, A2> >::ResultType    KernelClass;
00079       KernelClass::init();
00080 
00081       std::stringstream ss;
00082       ss << SOLVERTAG::name() << "_trans_solve";
00083       viennacl::ocl::kernel & k = viennacl::ocl::get_kernel(KernelClass::program_name(), ss.str());
00084 
00085       k.global_work_size(0, B.lhs().size1() * k.local_work_size());
00086       viennacl::ocl::enqueue(k(mat, mat.size1(), mat.size2(), mat.internal_size1(), mat.internal_size2(),
00087                                B.lhs(), B.lhs().size1(), B.lhs().size2(), B.lhs().internal_size1(), B.lhs().internal_size2()));     
00088     }
00089     
00090     //upper triangular solver for transposed lower triangular matrices
00096     template<typename SCALARTYPE, typename F1, typename F2, unsigned int A1, unsigned int A2, typename SOLVERTAG>
00097     void inplace_solve(const matrix_expression< const matrix<SCALARTYPE, F1, A1>,
00098                                                 const matrix<SCALARTYPE, F1, A1>,
00099                                                 op_trans> & proxy,
00100                        matrix<SCALARTYPE, F2, A2> & B,
00101                        SOLVERTAG)
00102     {
00103       assert(proxy.lhs().size1() == proxy.lhs().size2());
00104       assert(proxy.lhs().size2() == B.size1());
00105       
00106       typedef typename viennacl::tools::MATRIX_SOLVE_KERNEL_CLASS_DEDUCER< matrix<SCALARTYPE, F1, A1>,
00107                                                                            matrix<SCALARTYPE, F2, A2> >::ResultType    KernelClass;
00108       KernelClass::init();
00109 
00110       std::stringstream ss;
00111       ss << "trans_" << SOLVERTAG::name() << "_solve";
00112       viennacl::ocl::kernel & k = viennacl::ocl::get_kernel(KernelClass::program_name(), ss.str());
00113 
00114       k.global_work_size(0, B.size2() * k.local_work_size());
00115       viennacl::ocl::enqueue(k(proxy.lhs(), proxy.lhs().size1(), proxy.lhs().size2(), proxy.lhs().internal_size1(), proxy.lhs().internal_size2(),
00116                                          B,   B.size1(),   B.size2(),   B.internal_size1(),   B.internal_size2()));        
00117     }
00118 
00124     template<typename SCALARTYPE, typename F1, typename F2, unsigned int A1, unsigned int A2, typename SOLVERTAG>
00125     void inplace_solve(const matrix_expression< const matrix<SCALARTYPE, F1, A1>,
00126                                                 const matrix<SCALARTYPE, F1, A1>,
00127                                                 op_trans> & proxy,
00128                        const matrix_expression< const matrix<SCALARTYPE, F2, A2>,
00129                                                 const matrix<SCALARTYPE, F2, A2>,
00130                                                 op_trans> & B,
00131                        SOLVERTAG)
00132     {
00133       assert(proxy.lhs().size1() == proxy.lhs().size2());
00134       assert(proxy.lhs().size2() == B.lhs().size2());
00135       
00136       typedef typename viennacl::tools::MATRIX_SOLVE_KERNEL_CLASS_DEDUCER< matrix<SCALARTYPE, F1, A1>,
00137                                                                            matrix<SCALARTYPE, F2, A2> >::ResultType    KernelClass;
00138       KernelClass::init();
00139 
00140       std::stringstream ss;
00141       ss << "trans_" << SOLVERTAG::name() << "_trans_solve";
00142       viennacl::ocl::kernel & k = viennacl::ocl::get_kernel(KernelClass::program_name(), ss.str());
00143 
00144       k.global_work_size(0, B.lhs().size1() * k.local_work_size());
00145       viennacl::ocl::enqueue(k(proxy.lhs(), proxy.lhs().size1(), proxy.lhs().size2(), proxy.lhs().internal_size1(), proxy.lhs().internal_size2(),
00146                                B.lhs(), B.lhs().size1(), B.lhs().size2(), B.lhs().internal_size1(), B.lhs().internal_size2()));        
00147     }
00148 
00149     template<typename SCALARTYPE, typename F, unsigned int ALIGNMENT, unsigned int VEC_ALIGNMENT, typename SOLVERTAG>
00150     void inplace_solve(const matrix<SCALARTYPE, F, ALIGNMENT> & mat,
00151                        vector<SCALARTYPE, VEC_ALIGNMENT> & vec,
00152                        SOLVERTAG)
00153     {
00154       assert(mat.size1() == vec.size());
00155       assert(mat.size2() == vec.size());
00156       
00157       typedef typename viennacl::tools::MATRIX_KERNEL_CLASS_DEDUCER< matrix<SCALARTYPE, F, ALIGNMENT> >::ResultType    KernelClass;
00158 
00159       std::stringstream ss;
00160       ss << SOLVERTAG::name() << "_triangular_substitute_inplace";
00161       viennacl::ocl::kernel & k = viennacl::ocl::get_kernel(KernelClass::program_name(), ss.str());
00162 
00163       k.global_work_size(0, k.local_work_size());
00164       viennacl::ocl::enqueue(k(mat, mat.size1(), mat.size2(), mat.internal_size1(), mat.internal_size2(), vec));        
00165     }
00166 
00172     template<typename SCALARTYPE, typename F, unsigned int ALIGNMENT, unsigned int VEC_ALIGNMENT, typename SOLVERTAG>
00173     void inplace_solve(const matrix_expression< const matrix<SCALARTYPE, F, ALIGNMENT>,
00174                                                 const matrix<SCALARTYPE, F, ALIGNMENT>,
00175                                                 op_trans> & proxy,
00176                        vector<SCALARTYPE, VEC_ALIGNMENT> & vec,
00177                        SOLVERTAG)
00178     {
00179       assert(proxy.lhs().size1() == vec.size());
00180       assert(proxy.lhs().size2() == vec.size());
00181 
00182       typedef typename viennacl::tools::MATRIX_KERNEL_CLASS_DEDUCER< matrix<SCALARTYPE, F, ALIGNMENT> >::ResultType    KernelClass;
00183       
00184       std::stringstream ss;
00185       ss << "trans_" << SOLVERTAG::name() << "_triangular_substitute_inplace";
00186       viennacl::ocl::kernel & k = viennacl::ocl::get_kernel(KernelClass::program_name(), ss.str());
00187       
00188       k.global_work_size(0, k.local_work_size());
00189       viennacl::ocl::enqueue(k(proxy.lhs(), proxy.lhs().size1(), proxy.lhs().size2(),
00190                                                            proxy.lhs().internal_size1(), proxy.lhs().internal_size2(), vec));        
00191     }
00192     
00194 
00201     template<typename SCALARTYPE, typename F1, typename F2, unsigned int ALIGNMENT_A, unsigned int ALIGNMENT_B, typename TAG>
00202     matrix<SCALARTYPE, F2, ALIGNMENT_B> solve(const matrix<SCALARTYPE, F1, ALIGNMENT_A> & A,
00203                                         const matrix<SCALARTYPE, F2, ALIGNMENT_B> & B,
00204                                         TAG const & tag)
00205     {
00206       // do an inplace solve on the result vector:
00207       matrix<SCALARTYPE, F2, ALIGNMENT_A> result(B.size1(), B.size2());
00208       result = B;
00209     
00210       inplace_solve(A, result, tag);
00211     
00212       return result;
00213     }
00214 
00221     template<typename SCALARTYPE, typename F1, typename F2, unsigned int ALIGNMENT_A, unsigned int ALIGNMENT_B, typename TAG>
00222     matrix<SCALARTYPE, F2, ALIGNMENT_B> solve(const matrix<SCALARTYPE, F1, ALIGNMENT_A> & A,
00223                                         const matrix_expression< const matrix<SCALARTYPE, F2, ALIGNMENT_B>,
00224                                                                      const matrix<SCALARTYPE, F2, ALIGNMENT_B>,
00225                                                                      op_trans> & proxy,
00226                                         TAG const & tag)
00227     {
00228       // do an inplace solve on the result vector:
00229       matrix<SCALARTYPE, F2, ALIGNMENT_B> result(proxy.lhs().size2(), proxy.lhs().size1());
00230       result = proxy;
00231     
00232       inplace_solve(A, result, tag);
00233     
00234       return result;
00235     }
00236 
00243     template<typename SCALARTYPE, typename F, unsigned int ALIGNMENT, unsigned int VEC_ALIGNMENT, typename TAG>
00244     vector<SCALARTYPE, VEC_ALIGNMENT> solve(const matrix<SCALARTYPE, F, ALIGNMENT> & mat,
00245                                         const vector<SCALARTYPE, VEC_ALIGNMENT> & vec,
00246                                         TAG const & tag)
00247     {
00248       // do an inplace solve on the result vector:
00249       vector<SCALARTYPE, VEC_ALIGNMENT> result(vec.size());
00250       result = vec;
00251     
00252       inplace_solve(mat, result, tag);
00253     
00254       return result;
00255     }
00256     
00257     
00259 
00265     template<typename SCALARTYPE, typename F1, typename F2, unsigned int ALIGNMENT_A, unsigned int ALIGNMENT_B, typename TAG>
00266     matrix<SCALARTYPE, F2, ALIGNMENT_B> solve(const matrix_expression< const matrix<SCALARTYPE, F1, ALIGNMENT_A>,
00267                                                                      const matrix<SCALARTYPE, F1, ALIGNMENT_A>,
00268                                                                      op_trans> & proxy,
00269                                             const matrix<SCALARTYPE, F2, ALIGNMENT_B> & B,
00270                                             TAG const & tag)
00271     {
00272       // do an inplace solve on the result vector:
00273       matrix<SCALARTYPE, F2, ALIGNMENT_B> result(B.size1(), B.size2());
00274       result = B;
00275     
00276       inplace_solve(proxy, result, tag);
00277     
00278       return result;
00279     }
00280     
00281     
00288     template<typename SCALARTYPE, typename F1, typename F2, unsigned int ALIGNMENT_A, unsigned int ALIGNMENT_B, typename TAG>
00289     matrix<SCALARTYPE, F2, ALIGNMENT_B> solve(const matrix_expression< const matrix<SCALARTYPE, F1, ALIGNMENT_A>,
00290                                                                      const matrix<SCALARTYPE, F1, ALIGNMENT_A>,
00291                                                                      op_trans> & proxy_A,
00292                                             const matrix_expression< const matrix<SCALARTYPE, F2, ALIGNMENT_B>,
00293                                                                      const matrix<SCALARTYPE, F2, ALIGNMENT_B>,
00294                                                                      op_trans> & proxy_B,
00295                                             TAG const & tag)
00296     {
00297       // do an inplace solve on the result vector:
00298       matrix<SCALARTYPE, F2, ALIGNMENT_B> result(proxy_B.lhs().size2(), proxy_B.lhs().size1());
00299       result = trans(proxy_B.lhs());
00300     
00301       inplace_solve(proxy_A, result, tag);
00302     
00303       return result;
00304     }
00305     
00312     template<typename SCALARTYPE, typename F, unsigned int ALIGNMENT, unsigned int VEC_ALIGNMENT, typename TAG>
00313     vector<SCALARTYPE, VEC_ALIGNMENT> solve(const matrix_expression< const matrix<SCALARTYPE, F, ALIGNMENT>,
00314                                                                      const matrix<SCALARTYPE, F, ALIGNMENT>,
00315                                                                      op_trans> & proxy,
00316                                             const vector<SCALARTYPE, VEC_ALIGNMENT> & vec,
00317                                             TAG const & tag)
00318     {
00319       // do an inplace solve on the result vector:
00320       vector<SCALARTYPE, VEC_ALIGNMENT> result(vec.size());
00321       result = vec;
00322     
00323       inplace_solve(proxy, result, tag);
00324     
00325       return result;
00326     }
00327     
00328     
00330 
00334     template<typename SCALARTYPE, typename F, unsigned int ALIGNMENT>
00335     void lu_factorize(matrix<SCALARTYPE, F, ALIGNMENT> & mat)
00336     {
00337       assert(mat.size1() == mat.size2());
00338 
00339       typedef typename viennacl::tools::MATRIX_KERNEL_CLASS_DEDUCER< matrix<SCALARTYPE, F, ALIGNMENT> >::ResultType    KernelClass;
00340       
00341       viennacl::ocl::kernel & k = viennacl::ocl::get_kernel(KernelClass::program_name(), "lu_factorize");
00342       
00343       k.global_work_size(0, k.local_work_size());
00344       viennacl::ocl::enqueue(k(mat, mat.size1(), mat.size2(), mat.internal_size1(), mat.internal_size2()));        
00345     }
00346 
00347 
00353     template<typename SCALARTYPE, typename F1, typename F2, unsigned int ALIGNMENT_A, unsigned int ALIGNMENT_B>
00354     void lu_substitute(matrix<SCALARTYPE, F1, ALIGNMENT_A> const & A,
00355                        matrix<SCALARTYPE, F2, ALIGNMENT_B> & B)
00356     {
00357       assert(A.size1() == A.size2());
00358       assert(A.size1() == A.size2());
00359       inplace_solve(A, B, unit_lower_tag());
00360       inplace_solve(A, B, upper_tag());
00361     }
00362 
00368     template<typename SCALARTYPE, typename F, unsigned int ALIGNMENT, unsigned int VEC_ALIGNMENT>
00369     void lu_substitute(matrix<SCALARTYPE, F, ALIGNMENT> const & mat,
00370                        vector<SCALARTYPE, VEC_ALIGNMENT> & vec)
00371     {
00372       assert(mat.size1() == mat.size2());
00373       inplace_solve(mat, vec, unit_lower_tag());
00374       inplace_solve(mat, vec, upper_tag());
00375     }
00376 
00377   }
00378 }
00379 
00380 #endif

Generated on Sat May 21 2011 20:36:50 for ViennaCL - The Vienna Computing Library by  doxygen 1.7.1