|
ViennaCL - The Vienna Computing Library
1.1.2
|
00001 /* ======================================================================= 00002 Copyright (c) 2010, Institute for Microelectronics, TU Vienna. 00003 http://www.iue.tuwien.ac.at 00004 ----------------- 00005 ViennaCL - The Vienna Computing Library 00006 ----------------- 00007 00008 authors: Karl Rupp rupp@iue.tuwien.ac.at 00009 Florian Rudolf flo.rudy+viennacl@gmail.com 00010 Josef Weinbub weinbub@iue.tuwien.ac.at 00011 00012 license: MIT (X11), see file LICENSE in the ViennaCL base directory 00013 ======================================================================= */ 00014 00015 #ifndef _VIENNACL_DIRECT_SOLVE_HPP_ 00016 #define _VIENNACL_DIRECT_SOLVE_HPP_ 00017 00022 #include "viennacl/vector.hpp" 00023 #include "viennacl/matrix.hpp" 00024 #include "viennacl/tools/matrix_kernel_class_deducer.hpp" 00025 #include "viennacl/tools/matrix_solve_kernel_class_deducer.hpp" 00026 #include "viennacl/ocl/kernel.hpp" 00027 #include "viennacl/ocl/device.hpp" 00028 #include "viennacl/ocl/handle.hpp" 00029 00030 00031 namespace viennacl 00032 { 00033 namespace linalg 00034 { 00036 00041 template<typename SCALARTYPE, typename F1, typename F2, unsigned int A1, unsigned int A2, typename SOLVERTAG> 00042 void inplace_solve(const matrix<SCALARTYPE, F1, A1> & mat, 00043 matrix<SCALARTYPE, F2, A2> & B, 00044 SOLVERTAG) 00045 { 00046 assert(mat.size1() == mat.size2()); 00047 assert(mat.size2() == B.size1()); 00048 00049 typedef typename viennacl::tools::MATRIX_SOLVE_KERNEL_CLASS_DEDUCER< matrix<SCALARTYPE, F1, A1>, 00050 matrix<SCALARTYPE, F2, A2> >::ResultType KernelClass; 00051 KernelClass::init(); 00052 00053 std::stringstream ss; 00054 ss << SOLVERTAG::name() << "_solve"; 00055 viennacl::ocl::kernel & k = viennacl::ocl::get_kernel(KernelClass::program_name(), ss.str()); 00056 00057 k.global_work_size(0, B.size2() * k.local_work_size()); 00058 viennacl::ocl::enqueue(k(mat, mat.size1(), mat.size2(), mat.internal_size1(), mat.internal_size2(), 00059 B, B.size1(), B.size2(), B.internal_size1(), B.internal_size2())); 00060 } 00061 00067 template<typename SCALARTYPE, typename F1, typename F2, unsigned int A1, unsigned int A2, typename SOLVERTAG> 00068 void inplace_solve(const matrix<SCALARTYPE, F1, A1> & mat, 00069 const matrix_expression< const matrix<SCALARTYPE, F2, A2>, 00070 const matrix<SCALARTYPE, F2, A2>, 00071 op_trans> & B, 00072 SOLVERTAG) 00073 { 00074 assert(mat.size1() == mat.size2()); 00075 assert(mat.size2() == B.lhs().size2()); 00076 00077 typedef typename viennacl::tools::MATRIX_SOLVE_KERNEL_CLASS_DEDUCER< matrix<SCALARTYPE, F1, A1>, 00078 matrix<SCALARTYPE, F2, A2> >::ResultType KernelClass; 00079 KernelClass::init(); 00080 00081 std::stringstream ss; 00082 ss << SOLVERTAG::name() << "_trans_solve"; 00083 viennacl::ocl::kernel & k = viennacl::ocl::get_kernel(KernelClass::program_name(), ss.str()); 00084 00085 k.global_work_size(0, B.lhs().size1() * k.local_work_size()); 00086 viennacl::ocl::enqueue(k(mat, mat.size1(), mat.size2(), mat.internal_size1(), mat.internal_size2(), 00087 B.lhs(), B.lhs().size1(), B.lhs().size2(), B.lhs().internal_size1(), B.lhs().internal_size2())); 00088 } 00089 00090 //upper triangular solver for transposed lower triangular matrices 00096 template<typename SCALARTYPE, typename F1, typename F2, unsigned int A1, unsigned int A2, typename SOLVERTAG> 00097 void inplace_solve(const matrix_expression< const matrix<SCALARTYPE, F1, A1>, 00098 const matrix<SCALARTYPE, F1, A1>, 00099 op_trans> & proxy, 00100 matrix<SCALARTYPE, F2, A2> & B, 00101 SOLVERTAG) 00102 { 00103 assert(proxy.lhs().size1() == proxy.lhs().size2()); 00104 assert(proxy.lhs().size2() == B.size1()); 00105 00106 typedef typename viennacl::tools::MATRIX_SOLVE_KERNEL_CLASS_DEDUCER< matrix<SCALARTYPE, F1, A1>, 00107 matrix<SCALARTYPE, F2, A2> >::ResultType KernelClass; 00108 KernelClass::init(); 00109 00110 std::stringstream ss; 00111 ss << "trans_" << SOLVERTAG::name() << "_solve"; 00112 viennacl::ocl::kernel & k = viennacl::ocl::get_kernel(KernelClass::program_name(), ss.str()); 00113 00114 k.global_work_size(0, B.size2() * k.local_work_size()); 00115 viennacl::ocl::enqueue(k(proxy.lhs(), proxy.lhs().size1(), proxy.lhs().size2(), proxy.lhs().internal_size1(), proxy.lhs().internal_size2(), 00116 B, B.size1(), B.size2(), B.internal_size1(), B.internal_size2())); 00117 } 00118 00124 template<typename SCALARTYPE, typename F1, typename F2, unsigned int A1, unsigned int A2, typename SOLVERTAG> 00125 void inplace_solve(const matrix_expression< const matrix<SCALARTYPE, F1, A1>, 00126 const matrix<SCALARTYPE, F1, A1>, 00127 op_trans> & proxy, 00128 const matrix_expression< const matrix<SCALARTYPE, F2, A2>, 00129 const matrix<SCALARTYPE, F2, A2>, 00130 op_trans> & B, 00131 SOLVERTAG) 00132 { 00133 assert(proxy.lhs().size1() == proxy.lhs().size2()); 00134 assert(proxy.lhs().size2() == B.lhs().size2()); 00135 00136 typedef typename viennacl::tools::MATRIX_SOLVE_KERNEL_CLASS_DEDUCER< matrix<SCALARTYPE, F1, A1>, 00137 matrix<SCALARTYPE, F2, A2> >::ResultType KernelClass; 00138 KernelClass::init(); 00139 00140 std::stringstream ss; 00141 ss << "trans_" << SOLVERTAG::name() << "_trans_solve"; 00142 viennacl::ocl::kernel & k = viennacl::ocl::get_kernel(KernelClass::program_name(), ss.str()); 00143 00144 k.global_work_size(0, B.lhs().size1() * k.local_work_size()); 00145 viennacl::ocl::enqueue(k(proxy.lhs(), proxy.lhs().size1(), proxy.lhs().size2(), proxy.lhs().internal_size1(), proxy.lhs().internal_size2(), 00146 B.lhs(), B.lhs().size1(), B.lhs().size2(), B.lhs().internal_size1(), B.lhs().internal_size2())); 00147 } 00148 00149 template<typename SCALARTYPE, typename F, unsigned int ALIGNMENT, unsigned int VEC_ALIGNMENT, typename SOLVERTAG> 00150 void inplace_solve(const matrix<SCALARTYPE, F, ALIGNMENT> & mat, 00151 vector<SCALARTYPE, VEC_ALIGNMENT> & vec, 00152 SOLVERTAG) 00153 { 00154 assert(mat.size1() == vec.size()); 00155 assert(mat.size2() == vec.size()); 00156 00157 typedef typename viennacl::tools::MATRIX_KERNEL_CLASS_DEDUCER< matrix<SCALARTYPE, F, ALIGNMENT> >::ResultType KernelClass; 00158 00159 std::stringstream ss; 00160 ss << SOLVERTAG::name() << "_triangular_substitute_inplace"; 00161 viennacl::ocl::kernel & k = viennacl::ocl::get_kernel(KernelClass::program_name(), ss.str()); 00162 00163 k.global_work_size(0, k.local_work_size()); 00164 viennacl::ocl::enqueue(k(mat, mat.size1(), mat.size2(), mat.internal_size1(), mat.internal_size2(), vec)); 00165 } 00166 00172 template<typename SCALARTYPE, typename F, unsigned int ALIGNMENT, unsigned int VEC_ALIGNMENT, typename SOLVERTAG> 00173 void inplace_solve(const matrix_expression< const matrix<SCALARTYPE, F, ALIGNMENT>, 00174 const matrix<SCALARTYPE, F, ALIGNMENT>, 00175 op_trans> & proxy, 00176 vector<SCALARTYPE, VEC_ALIGNMENT> & vec, 00177 SOLVERTAG) 00178 { 00179 assert(proxy.lhs().size1() == vec.size()); 00180 assert(proxy.lhs().size2() == vec.size()); 00181 00182 typedef typename viennacl::tools::MATRIX_KERNEL_CLASS_DEDUCER< matrix<SCALARTYPE, F, ALIGNMENT> >::ResultType KernelClass; 00183 00184 std::stringstream ss; 00185 ss << "trans_" << SOLVERTAG::name() << "_triangular_substitute_inplace"; 00186 viennacl::ocl::kernel & k = viennacl::ocl::get_kernel(KernelClass::program_name(), ss.str()); 00187 00188 k.global_work_size(0, k.local_work_size()); 00189 viennacl::ocl::enqueue(k(proxy.lhs(), proxy.lhs().size1(), proxy.lhs().size2(), 00190 proxy.lhs().internal_size1(), proxy.lhs().internal_size2(), vec)); 00191 } 00192 00194 00201 template<typename SCALARTYPE, typename F1, typename F2, unsigned int ALIGNMENT_A, unsigned int ALIGNMENT_B, typename TAG> 00202 matrix<SCALARTYPE, F2, ALIGNMENT_B> solve(const matrix<SCALARTYPE, F1, ALIGNMENT_A> & A, 00203 const matrix<SCALARTYPE, F2, ALIGNMENT_B> & B, 00204 TAG const & tag) 00205 { 00206 // do an inplace solve on the result vector: 00207 matrix<SCALARTYPE, F2, ALIGNMENT_A> result(B.size1(), B.size2()); 00208 result = B; 00209 00210 inplace_solve(A, result, tag); 00211 00212 return result; 00213 } 00214 00221 template<typename SCALARTYPE, typename F1, typename F2, unsigned int ALIGNMENT_A, unsigned int ALIGNMENT_B, typename TAG> 00222 matrix<SCALARTYPE, F2, ALIGNMENT_B> solve(const matrix<SCALARTYPE, F1, ALIGNMENT_A> & A, 00223 const matrix_expression< const matrix<SCALARTYPE, F2, ALIGNMENT_B>, 00224 const matrix<SCALARTYPE, F2, ALIGNMENT_B>, 00225 op_trans> & proxy, 00226 TAG const & tag) 00227 { 00228 // do an inplace solve on the result vector: 00229 matrix<SCALARTYPE, F2, ALIGNMENT_B> result(proxy.lhs().size2(), proxy.lhs().size1()); 00230 result = proxy; 00231 00232 inplace_solve(A, result, tag); 00233 00234 return result; 00235 } 00236 00243 template<typename SCALARTYPE, typename F, unsigned int ALIGNMENT, unsigned int VEC_ALIGNMENT, typename TAG> 00244 vector<SCALARTYPE, VEC_ALIGNMENT> solve(const matrix<SCALARTYPE, F, ALIGNMENT> & mat, 00245 const vector<SCALARTYPE, VEC_ALIGNMENT> & vec, 00246 TAG const & tag) 00247 { 00248 // do an inplace solve on the result vector: 00249 vector<SCALARTYPE, VEC_ALIGNMENT> result(vec.size()); 00250 result = vec; 00251 00252 inplace_solve(mat, result, tag); 00253 00254 return result; 00255 } 00256 00257 00259 00265 template<typename SCALARTYPE, typename F1, typename F2, unsigned int ALIGNMENT_A, unsigned int ALIGNMENT_B, typename TAG> 00266 matrix<SCALARTYPE, F2, ALIGNMENT_B> solve(const matrix_expression< const matrix<SCALARTYPE, F1, ALIGNMENT_A>, 00267 const matrix<SCALARTYPE, F1, ALIGNMENT_A>, 00268 op_trans> & proxy, 00269 const matrix<SCALARTYPE, F2, ALIGNMENT_B> & B, 00270 TAG const & tag) 00271 { 00272 // do an inplace solve on the result vector: 00273 matrix<SCALARTYPE, F2, ALIGNMENT_B> result(B.size1(), B.size2()); 00274 result = B; 00275 00276 inplace_solve(proxy, result, tag); 00277 00278 return result; 00279 } 00280 00281 00288 template<typename SCALARTYPE, typename F1, typename F2, unsigned int ALIGNMENT_A, unsigned int ALIGNMENT_B, typename TAG> 00289 matrix<SCALARTYPE, F2, ALIGNMENT_B> solve(const matrix_expression< const matrix<SCALARTYPE, F1, ALIGNMENT_A>, 00290 const matrix<SCALARTYPE, F1, ALIGNMENT_A>, 00291 op_trans> & proxy_A, 00292 const matrix_expression< const matrix<SCALARTYPE, F2, ALIGNMENT_B>, 00293 const matrix<SCALARTYPE, F2, ALIGNMENT_B>, 00294 op_trans> & proxy_B, 00295 TAG const & tag) 00296 { 00297 // do an inplace solve on the result vector: 00298 matrix<SCALARTYPE, F2, ALIGNMENT_B> result(proxy_B.lhs().size2(), proxy_B.lhs().size1()); 00299 result = trans(proxy_B.lhs()); 00300 00301 inplace_solve(proxy_A, result, tag); 00302 00303 return result; 00304 } 00305 00312 template<typename SCALARTYPE, typename F, unsigned int ALIGNMENT, unsigned int VEC_ALIGNMENT, typename TAG> 00313 vector<SCALARTYPE, VEC_ALIGNMENT> solve(const matrix_expression< const matrix<SCALARTYPE, F, ALIGNMENT>, 00314 const matrix<SCALARTYPE, F, ALIGNMENT>, 00315 op_trans> & proxy, 00316 const vector<SCALARTYPE, VEC_ALIGNMENT> & vec, 00317 TAG const & tag) 00318 { 00319 // do an inplace solve on the result vector: 00320 vector<SCALARTYPE, VEC_ALIGNMENT> result(vec.size()); 00321 result = vec; 00322 00323 inplace_solve(proxy, result, tag); 00324 00325 return result; 00326 } 00327 00328 00330 00334 template<typename SCALARTYPE, typename F, unsigned int ALIGNMENT> 00335 void lu_factorize(matrix<SCALARTYPE, F, ALIGNMENT> & mat) 00336 { 00337 assert(mat.size1() == mat.size2()); 00338 00339 typedef typename viennacl::tools::MATRIX_KERNEL_CLASS_DEDUCER< matrix<SCALARTYPE, F, ALIGNMENT> >::ResultType KernelClass; 00340 00341 viennacl::ocl::kernel & k = viennacl::ocl::get_kernel(KernelClass::program_name(), "lu_factorize"); 00342 00343 k.global_work_size(0, k.local_work_size()); 00344 viennacl::ocl::enqueue(k(mat, mat.size1(), mat.size2(), mat.internal_size1(), mat.internal_size2())); 00345 } 00346 00347 00353 template<typename SCALARTYPE, typename F1, typename F2, unsigned int ALIGNMENT_A, unsigned int ALIGNMENT_B> 00354 void lu_substitute(matrix<SCALARTYPE, F1, ALIGNMENT_A> const & A, 00355 matrix<SCALARTYPE, F2, ALIGNMENT_B> & B) 00356 { 00357 assert(A.size1() == A.size2()); 00358 assert(A.size1() == A.size2()); 00359 inplace_solve(A, B, unit_lower_tag()); 00360 inplace_solve(A, B, upper_tag()); 00361 } 00362 00368 template<typename SCALARTYPE, typename F, unsigned int ALIGNMENT, unsigned int VEC_ALIGNMENT> 00369 void lu_substitute(matrix<SCALARTYPE, F, ALIGNMENT> const & mat, 00370 vector<SCALARTYPE, VEC_ALIGNMENT> & vec) 00371 { 00372 assert(mat.size1() == mat.size2()); 00373 inplace_solve(mat, vec, unit_lower_tag()); 00374 inplace_solve(mat, vec, upper_tag()); 00375 } 00376 00377 } 00378 } 00379 00380 #endif
1.7.6.1