00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013
00014
00015 #ifndef _VIENNACL_DIRECT_SOLVE_HPP_
00016 #define _VIENNACL_DIRECT_SOLVE_HPP_
00017
00022 #include "viennacl/vector.hpp"
00023 #include "viennacl/matrix.hpp"
00024 #include "viennacl/tools/matrix_kernel_class_deducer.hpp"
00025 #include "viennacl/tools/matrix_solve_kernel_class_deducer.hpp"
00026 #include "viennacl/ocl/kernel.hpp"
00027 #include "viennacl/ocl/device.hpp"
00028 #include "viennacl/ocl/handle.hpp"
00029
00030
00031 namespace viennacl
00032 {
00033 namespace linalg
00034 {
00036
00041 template<typename SCALARTYPE, typename F1, typename F2, unsigned int A1, unsigned int A2, typename SOLVERTAG>
00042 void inplace_solve(const matrix<SCALARTYPE, F1, A1> & mat,
00043 matrix<SCALARTYPE, F2, A2> & B,
00044 SOLVERTAG)
00045 {
00046 assert(mat.size1() == mat.size2());
00047 assert(mat.size2() == B.size1());
00048
00049 typedef typename viennacl::tools::MATRIX_SOLVE_KERNEL_CLASS_DEDUCER< matrix<SCALARTYPE, F1, A1>,
00050 matrix<SCALARTYPE, F2, A2> >::ResultType KernelClass;
00051 KernelClass::init();
00052
00053 std::stringstream ss;
00054 ss << SOLVERTAG::name() << "_solve";
00055 viennacl::ocl::kernel & k = viennacl::ocl::get_kernel(KernelClass::program_name(), ss.str());
00056
00057 k.global_work_size(0, B.size2() * k.local_work_size());
00058 viennacl::ocl::enqueue(k(mat, mat.size1(), mat.size2(), mat.internal_size1(), mat.internal_size2(),
00059 B, B.size1(), B.size2(), B.internal_size1(), B.internal_size2()));
00060 }
00061
00067 template<typename SCALARTYPE, typename F1, typename F2, unsigned int A1, unsigned int A2, typename SOLVERTAG>
00068 void inplace_solve(const matrix<SCALARTYPE, F1, A1> & mat,
00069 const matrix_expression< const matrix<SCALARTYPE, F2, A2>,
00070 const matrix<SCALARTYPE, F2, A2>,
00071 op_trans> & B,
00072 SOLVERTAG)
00073 {
00074 assert(mat.size1() == mat.size2());
00075 assert(mat.size2() == B.lhs().size2());
00076
00077 typedef typename viennacl::tools::MATRIX_SOLVE_KERNEL_CLASS_DEDUCER< matrix<SCALARTYPE, F1, A1>,
00078 matrix<SCALARTYPE, F2, A2> >::ResultType KernelClass;
00079 KernelClass::init();
00080
00081 std::stringstream ss;
00082 ss << SOLVERTAG::name() << "_trans_solve";
00083 viennacl::ocl::kernel & k = viennacl::ocl::get_kernel(KernelClass::program_name(), ss.str());
00084
00085 k.global_work_size(0, B.lhs().size1() * k.local_work_size());
00086 viennacl::ocl::enqueue(k(mat, mat.size1(), mat.size2(), mat.internal_size1(), mat.internal_size2(),
00087 B.lhs(), B.lhs().size1(), B.lhs().size2(), B.lhs().internal_size1(), B.lhs().internal_size2()));
00088 }
00089
00090
00096 template<typename SCALARTYPE, typename F1, typename F2, unsigned int A1, unsigned int A2, typename SOLVERTAG>
00097 void inplace_solve(const matrix_expression< const matrix<SCALARTYPE, F1, A1>,
00098 const matrix<SCALARTYPE, F1, A1>,
00099 op_trans> & proxy,
00100 matrix<SCALARTYPE, F2, A2> & B,
00101 SOLVERTAG)
00102 {
00103 assert(proxy.lhs().size1() == proxy.lhs().size2());
00104 assert(proxy.lhs().size2() == B.size1());
00105
00106 typedef typename viennacl::tools::MATRIX_SOLVE_KERNEL_CLASS_DEDUCER< matrix<SCALARTYPE, F1, A1>,
00107 matrix<SCALARTYPE, F2, A2> >::ResultType KernelClass;
00108 KernelClass::init();
00109
00110 std::stringstream ss;
00111 ss << "trans_" << SOLVERTAG::name() << "_solve";
00112 viennacl::ocl::kernel & k = viennacl::ocl::get_kernel(KernelClass::program_name(), ss.str());
00113
00114 k.global_work_size(0, B.size2() * k.local_work_size());
00115 viennacl::ocl::enqueue(k(proxy.lhs(), proxy.lhs().size1(), proxy.lhs().size2(), proxy.lhs().internal_size1(), proxy.lhs().internal_size2(),
00116 B, B.size1(), B.size2(), B.internal_size1(), B.internal_size2()));
00117 }
00118
00124 template<typename SCALARTYPE, typename F1, typename F2, unsigned int A1, unsigned int A2, typename SOLVERTAG>
00125 void inplace_solve(const matrix_expression< const matrix<SCALARTYPE, F1, A1>,
00126 const matrix<SCALARTYPE, F1, A1>,
00127 op_trans> & proxy,
00128 const matrix_expression< const matrix<SCALARTYPE, F2, A2>,
00129 const matrix<SCALARTYPE, F2, A2>,
00130 op_trans> & B,
00131 SOLVERTAG)
00132 {
00133 assert(proxy.lhs().size1() == proxy.lhs().size2());
00134 assert(proxy.lhs().size2() == B.lhs().size2());
00135
00136 typedef typename viennacl::tools::MATRIX_SOLVE_KERNEL_CLASS_DEDUCER< matrix<SCALARTYPE, F1, A1>,
00137 matrix<SCALARTYPE, F2, A2> >::ResultType KernelClass;
00138 KernelClass::init();
00139
00140 std::stringstream ss;
00141 ss << "trans_" << SOLVERTAG::name() << "_trans_solve";
00142 viennacl::ocl::kernel & k = viennacl::ocl::get_kernel(KernelClass::program_name(), ss.str());
00143
00144 k.global_work_size(0, B.lhs().size1() * k.local_work_size());
00145 viennacl::ocl::enqueue(k(proxy.lhs(), proxy.lhs().size1(), proxy.lhs().size2(), proxy.lhs().internal_size1(), proxy.lhs().internal_size2(),
00146 B.lhs(), B.lhs().size1(), B.lhs().size2(), B.lhs().internal_size1(), B.lhs().internal_size2()));
00147 }
00148
00149 template<typename SCALARTYPE, typename F, unsigned int ALIGNMENT, unsigned int VEC_ALIGNMENT, typename SOLVERTAG>
00150 void inplace_solve(const matrix<SCALARTYPE, F, ALIGNMENT> & mat,
00151 vector<SCALARTYPE, VEC_ALIGNMENT> & vec,
00152 SOLVERTAG)
00153 {
00154 assert(mat.size1() == vec.size());
00155 assert(mat.size2() == vec.size());
00156
00157 typedef typename viennacl::tools::MATRIX_KERNEL_CLASS_DEDUCER< matrix<SCALARTYPE, F, ALIGNMENT> >::ResultType KernelClass;
00158
00159 std::stringstream ss;
00160 ss << SOLVERTAG::name() << "_triangular_substitute_inplace";
00161 viennacl::ocl::kernel & k = viennacl::ocl::get_kernel(KernelClass::program_name(), ss.str());
00162
00163 k.global_work_size(0, k.local_work_size());
00164 viennacl::ocl::enqueue(k(mat, mat.size1(), mat.size2(), mat.internal_size1(), mat.internal_size2(), vec));
00165 }
00166
00172 template<typename SCALARTYPE, typename F, unsigned int ALIGNMENT, unsigned int VEC_ALIGNMENT, typename SOLVERTAG>
00173 void inplace_solve(const matrix_expression< const matrix<SCALARTYPE, F, ALIGNMENT>,
00174 const matrix<SCALARTYPE, F, ALIGNMENT>,
00175 op_trans> & proxy,
00176 vector<SCALARTYPE, VEC_ALIGNMENT> & vec,
00177 SOLVERTAG)
00178 {
00179 assert(proxy.lhs().size1() == vec.size());
00180 assert(proxy.lhs().size2() == vec.size());
00181
00182 typedef typename viennacl::tools::MATRIX_KERNEL_CLASS_DEDUCER< matrix<SCALARTYPE, F, ALIGNMENT> >::ResultType KernelClass;
00183
00184 std::stringstream ss;
00185 ss << "trans_" << SOLVERTAG::name() << "_triangular_substitute_inplace";
00186 viennacl::ocl::kernel & k = viennacl::ocl::get_kernel(KernelClass::program_name(), ss.str());
00187
00188 k.global_work_size(0, k.local_work_size());
00189 viennacl::ocl::enqueue(k(proxy.lhs(), proxy.lhs().size1(), proxy.lhs().size2(),
00190 proxy.lhs().internal_size1(), proxy.lhs().internal_size2(), vec));
00191 }
00192
00194
00201 template<typename SCALARTYPE, typename F1, typename F2, unsigned int ALIGNMENT_A, unsigned int ALIGNMENT_B, typename TAG>
00202 matrix<SCALARTYPE, F2, ALIGNMENT_B> solve(const matrix<SCALARTYPE, F1, ALIGNMENT_A> & A,
00203 const matrix<SCALARTYPE, F2, ALIGNMENT_B> & B,
00204 TAG const & tag)
00205 {
00206
00207 matrix<SCALARTYPE, F2, ALIGNMENT_A> result(B.size1(), B.size2());
00208 result = B;
00209
00210 inplace_solve(A, result, tag);
00211
00212 return result;
00213 }
00214
00221 template<typename SCALARTYPE, typename F1, typename F2, unsigned int ALIGNMENT_A, unsigned int ALIGNMENT_B, typename TAG>
00222 matrix<SCALARTYPE, F2, ALIGNMENT_B> solve(const matrix<SCALARTYPE, F1, ALIGNMENT_A> & A,
00223 const matrix_expression< const matrix<SCALARTYPE, F2, ALIGNMENT_B>,
00224 const matrix<SCALARTYPE, F2, ALIGNMENT_B>,
00225 op_trans> & proxy,
00226 TAG const & tag)
00227 {
00228
00229 matrix<SCALARTYPE, F2, ALIGNMENT_B> result(proxy.lhs().size2(), proxy.lhs().size1());
00230 result = proxy;
00231
00232 inplace_solve(A, result, tag);
00233
00234 return result;
00235 }
00236
00243 template<typename SCALARTYPE, typename F, unsigned int ALIGNMENT, unsigned int VEC_ALIGNMENT, typename TAG>
00244 vector<SCALARTYPE, VEC_ALIGNMENT> solve(const matrix<SCALARTYPE, F, ALIGNMENT> & mat,
00245 const vector<SCALARTYPE, VEC_ALIGNMENT> & vec,
00246 TAG const & tag)
00247 {
00248
00249 vector<SCALARTYPE, VEC_ALIGNMENT> result(vec.size());
00250 result = vec;
00251
00252 inplace_solve(mat, result, tag);
00253
00254 return result;
00255 }
00256
00257
00259
00265 template<typename SCALARTYPE, typename F1, typename F2, unsigned int ALIGNMENT_A, unsigned int ALIGNMENT_B, typename TAG>
00266 matrix<SCALARTYPE, F2, ALIGNMENT_B> solve(const matrix_expression< const matrix<SCALARTYPE, F1, ALIGNMENT_A>,
00267 const matrix<SCALARTYPE, F1, ALIGNMENT_A>,
00268 op_trans> & proxy,
00269 const matrix<SCALARTYPE, F2, ALIGNMENT_B> & B,
00270 TAG const & tag)
00271 {
00272
00273 matrix<SCALARTYPE, F2, ALIGNMENT_B> result(B.size1(), B.size2());
00274 result = B;
00275
00276 inplace_solve(proxy, result, tag);
00277
00278 return result;
00279 }
00280
00281
00288 template<typename SCALARTYPE, typename F1, typename F2, unsigned int ALIGNMENT_A, unsigned int ALIGNMENT_B, typename TAG>
00289 matrix<SCALARTYPE, F2, ALIGNMENT_B> solve(const matrix_expression< const matrix<SCALARTYPE, F1, ALIGNMENT_A>,
00290 const matrix<SCALARTYPE, F1, ALIGNMENT_A>,
00291 op_trans> & proxy_A,
00292 const matrix_expression< const matrix<SCALARTYPE, F2, ALIGNMENT_B>,
00293 const matrix<SCALARTYPE, F2, ALIGNMENT_B>,
00294 op_trans> & proxy_B,
00295 TAG const & tag)
00296 {
00297
00298 matrix<SCALARTYPE, F2, ALIGNMENT_B> result(proxy_B.lhs().size2(), proxy_B.lhs().size1());
00299 result = trans(proxy_B.lhs());
00300
00301 inplace_solve(proxy_A, result, tag);
00302
00303 return result;
00304 }
00305
00312 template<typename SCALARTYPE, typename F, unsigned int ALIGNMENT, unsigned int VEC_ALIGNMENT, typename TAG>
00313 vector<SCALARTYPE, VEC_ALIGNMENT> solve(const matrix_expression< const matrix<SCALARTYPE, F, ALIGNMENT>,
00314 const matrix<SCALARTYPE, F, ALIGNMENT>,
00315 op_trans> & proxy,
00316 const vector<SCALARTYPE, VEC_ALIGNMENT> & vec,
00317 TAG const & tag)
00318 {
00319
00320 vector<SCALARTYPE, VEC_ALIGNMENT> result(vec.size());
00321 result = vec;
00322
00323 inplace_solve(proxy, result, tag);
00324
00325 return result;
00326 }
00327
00328
00330
00334 template<typename SCALARTYPE, typename F, unsigned int ALIGNMENT>
00335 void lu_factorize(matrix<SCALARTYPE, F, ALIGNMENT> & mat)
00336 {
00337 assert(mat.size1() == mat.size2());
00338
00339 typedef typename viennacl::tools::MATRIX_KERNEL_CLASS_DEDUCER< matrix<SCALARTYPE, F, ALIGNMENT> >::ResultType KernelClass;
00340
00341 viennacl::ocl::kernel & k = viennacl::ocl::get_kernel(KernelClass::program_name(), "lu_factorize");
00342
00343 k.global_work_size(0, k.local_work_size());
00344 viennacl::ocl::enqueue(k(mat, mat.size1(), mat.size2(), mat.internal_size1(), mat.internal_size2()));
00345 }
00346
00347
00353 template<typename SCALARTYPE, typename F1, typename F2, unsigned int ALIGNMENT_A, unsigned int ALIGNMENT_B>
00354 void lu_substitute(matrix<SCALARTYPE, F1, ALIGNMENT_A> const & A,
00355 matrix<SCALARTYPE, F2, ALIGNMENT_B> & B)
00356 {
00357 assert(A.size1() == A.size2());
00358 assert(A.size1() == A.size2());
00359 inplace_solve(A, B, unit_lower_tag());
00360 inplace_solve(A, B, upper_tag());
00361 }
00362
00368 template<typename SCALARTYPE, typename F, unsigned int ALIGNMENT, unsigned int VEC_ALIGNMENT>
00369 void lu_substitute(matrix<SCALARTYPE, F, ALIGNMENT> const & mat,
00370 vector<SCALARTYPE, VEC_ALIGNMENT> & vec)
00371 {
00372 assert(mat.size1() == mat.size2());
00373 inplace_solve(mat, vec, unit_lower_tag());
00374 inplace_solve(mat, vec, upper_tag());
00375 }
00376
00377 }
00378 }
00379
00380 #endif