37 #ifndef VIGRA_LINEAR_SOLVE_HXX
38 #define VIGRA_LINEAR_SOLVE_HXX
42 #include "mathutil.hxx"
44 #include "singular_value_decomposition.hxx"
55 template <
class T,
class C1>
56 T determinantByLUDecomposition(MultiArrayView<2, T, C1>
const & a)
61 vigra_precondition(n == m,
62 "determinant(): square matrix required.");
74 LU(i,j) = LU(i,j) -= s;
99 T givensCoefficients(T a, T b, T & c, T & s)
127 bool givensRotationMatrix(T a, T b, Matrix<T> & gTranspose)
131 givensCoefficients(a, b, gTranspose(0,0), gTranspose(0,1));
132 gTranspose(1,1) = gTranspose(0,0);
133 gTranspose(1,0) = -gTranspose(0,1);
141 givensReflectionMatrix(T a, T b, Matrix<T> & g)
145 givensCoefficients(a, b, g(0,0), g(0,1));
152 template <
class T,
class C1,
class C2>
154 qrGivensStepImpl(
MultiArrayIndex i, MultiArrayView<2, T, C1> r, MultiArrayView<2, T, C2> rhs)
156 typedef typename Matrix<T>::difference_type Shape;
161 vigra_precondition(m ==
rowCount(rhs),
162 "qrGivensStepImpl(): Matrix shape mismatch.");
164 Matrix<T> givens(2,2);
165 for(
int k=m-1; k>
static_cast<int>(i); --k)
167 if(!givensReflectionMatrix(r(k-1,i), r(k,i), givens))
170 r(k-1,i) = givens(0,0)*r(k-1,i) + givens(0,1)*r(k,i);
173 r.subarray(Shape(k-1,i+1), Shape(k+1,n)) = givens*r.subarray(Shape(k-1,i+1), Shape(k+1,n));
174 rhs.subarray(Shape(k-1,0), Shape(k+1,rhsCount)) = givens*rhs.subarray(Shape(k-1,0), Shape(k+1,rhsCount));
176 return r(i,i) != 0.0;
180 template <
class T,
class C1,
class C2,
class Permutation>
183 MultiArrayView<2, T, C1> &r, MultiArrayView<2, T, C2> &rhs, Permutation & permutation)
185 typedef typename Matrix<T>::difference_type Shape;
190 vigra_precondition(i < n && j < n,
191 "upperTriangularCyclicShiftColumns(): Shift indices out of range.");
192 vigra_precondition(m ==
rowCount(rhs),
193 "upperTriangularCyclicShiftColumns(): Matrix shape mismatch.");
205 permutation[k] = permutation[k+1];
210 Matrix<T> givens(2,2);
213 if(!givensReflectionMatrix(r(k,k), r(k+1,k), givens))
216 r(k,k) = givens(0,0)*r(k,k) + givens(0,1)*r(k+1,k);
219 r.subarray(Shape(k,k+1), Shape(k+2,n)) = givens*r.subarray(Shape(k,k+1), Shape(k+2,n));
220 rhs.subarray(Shape(k,0), Shape(k+2,rhsCount)) = givens*rhs.subarray(Shape(k,0), Shape(k+2,rhsCount));
225 template <
class T,
class C1,
class C2,
class Permutation>
228 MultiArrayView<2, T, C1> &r, MultiArrayView<2, T, C2> &rhs, Permutation & permutation)
230 typedef typename Matrix<T>::difference_type Shape;
235 vigra_precondition(i < n && j < n,
236 "upperTriangularSwapColumns(): Swap indices out of range.");
237 vigra_precondition(m ==
rowCount(rhs),
238 "upperTriangularSwapColumns(): Matrix shape mismatch.");
246 std::swap(permutation[i], permutation[j]);
248 Matrix<T> givens(2,2);
249 for(
int k=m-1; k>
static_cast<int>(i); --k)
251 if(!givensReflectionMatrix(r(k-1,i), r(k,i), givens))
254 r(k-1,i) = givens(0,0)*r(k-1,i) + givens(0,1)*r(k,i);
257 r.subarray(Shape(k-1,i+1), Shape(k+1,n)) = givens*r.subarray(Shape(k-1,i+1), Shape(k+1,n));
258 rhs.subarray(Shape(k-1,0), Shape(k+1,rhsCount)) = givens*rhs.subarray(Shape(k-1,0), Shape(k+1,rhsCount));
263 if(!givensReflectionMatrix(r(k,k), r(k+1,k), givens))
266 r(k,k) = givens(0,0)*r(k,k) + givens(0,1)*r(k+1,k);
269 r.subarray(Shape(k,k+1), Shape(k+2,n)) = givens*r.subarray(Shape(k,k+1), Shape(k+2,n));
270 rhs.subarray(Shape(k,0), Shape(k+2,rhsCount)) = givens*rhs.subarray(Shape(k,0), Shape(k+2,rhsCount));
275 template <
class T,
class C1,
class C2,
class U>
276 bool householderVector(MultiArrayView<2, T, C1>
const & v, MultiArrayView<2, T, C2> & u, U & vnorm)
278 vnorm = (v(0,0) > 0.0)
283 if(f == NumericTraits<U>::zero())
285 u.init(NumericTraits<T>::zero());
290 u(0,0) = (v(0,0) - vnorm) / f;
298 template <
class T,
class C1,
class C2,
class C3>
301 MultiArrayView<2, T, C2> & rhs, MultiArrayView<2, T, C3> & householderMatrix)
303 typedef typename Matrix<T>::difference_type Shape;
309 vigra_precondition(i < n && i < m,
310 "qrHouseholderStepImpl(): Index i out of range.");
314 bool nontrivial = householderVector(
columnVector(r, Shape(i,i), m), u, vnorm);
317 columnVector(r, Shape(i+1,i), m).init(NumericTraits<T>::zero());
329 return r(i,i) != 0.0;
332 template <
class T,
class C1,
class C2>
334 qrColumnHouseholderStep(
MultiArrayIndex i, MultiArrayView<2, T, C1> &r, MultiArrayView<2, T, C2> &rhs)
336 Matrix<T> dontStoreHouseholderVectors;
337 return qrHouseholderStepImpl(i, r, rhs, dontStoreHouseholderVectors);
340 template <
class T,
class C1,
class C2>
342 qrRowHouseholderStep(
MultiArrayIndex i, MultiArrayView<2, T, C1> &r, MultiArrayView<2, T, C2> & householderMatrix)
344 Matrix<T> dontTransformRHS;
345 MultiArrayView<2, T, StridedArrayTag> rt =
transpose(r),
347 return qrHouseholderStepImpl(i, rt, dontTransformRHS, ht);
351 template <
class T,
class C1,
class C2,
class SNType>
353 incrementalMaxSingularValueApproximation(MultiArrayView<2, T, C1>
const & newColumn,
354 MultiArrayView<2, T, C2> & z, SNType & v)
356 typedef typename Matrix<T>::difference_type Shape;
367 z(n,0) = s*newColumn(n,0);
371 template <
class T,
class C1,
class C2,
class SNType>
373 incrementalMinSingularValueApproximation(MultiArrayView<2, T, C1>
const & newColumn,
374 MultiArrayView<2, T, C2> & z, SNType & v,
double tolerance)
376 typedef typename Matrix<T>::difference_type Shape;
386 T
gamma = newColumn(n,0);
400 z(n,0) = (s - c*yv) / gamma;
401 v *=
norm(gamma) /
hypot(c*gamma, v*(s - c*yv));
405 template <
class T,
class C1,
class C2,
class C3>
407 qrTransformToTriangularImpl(MultiArrayView<2, T, C1> & r, MultiArrayView<2, T, C2> & rhs, MultiArrayView<2, T, C3> & householder,
408 ArrayVector<MultiArrayIndex> & permutation,
double epsilon)
410 typedef typename Matrix<T>::difference_type Shape;
411 typedef typename NormTraits<MultiArrayView<2, T, C1> >::NormType NormType;
412 typedef typename NormTraits<MultiArrayView<2, T, C1> >::SquaredNormType SNType;
418 vigra_precondition(m >= n,
419 "qrTransformToTriangularImpl(): Coefficient matrix with at least as many rows as columns required.");
422 bool transformRHS = rhsCount > 0;
423 vigra_precondition(!transformRHS || m ==
rowCount(rhs),
424 "qrTransformToTriangularImpl(): RHS matrix shape mismatch.");
426 bool storeHouseholderSteps =
columnCount(householder) > 0;
427 vigra_precondition(!storeHouseholderSteps || r.shape() == householder.shape(),
428 "qrTransformToTriangularImpl(): Householder matrix shape mismatch.");
430 bool pivoting = permutation.size() > 0;
431 vigra_precondition(!pivoting || n == static_cast<MultiArrayIndex>(permutation.size()),
432 "qrTransformToTriangularImpl(): Permutation array size mismatch.");
437 Matrix<SNType> columnSquaredNorms;
440 columnSquaredNorms.reshape(Shape(1,n));
444 int pivot =
argMax(columnSquaredNorms);
448 std::swap(columnSquaredNorms[0], columnSquaredNorms[pivot]);
449 std::swap(permutation[0], permutation[pivot]);
453 qrHouseholderStepImpl(0, r, rhs, householder);
456 NormType maxApproxSingularValue =
norm(r(0,0)),
457 minApproxSingularValue = maxApproxSingularValue;
459 double tolerance = (epsilon == 0.0)
460 ? m*maxApproxSingularValue*NumericTraits<T>::epsilon()
463 bool simpleSingularValueApproximation = (n < 4);
464 Matrix<T> zmax, zmin;
465 if(minApproxSingularValue <= tolerance)
469 simpleSingularValueApproximation =
true;
471 if(!simpleSingularValueApproximation)
473 zmax.reshape(Shape(m,1));
474 zmin.reshape(Shape(m,1));
476 zmin(0,0) = 1.0 / r(0,0);
486 if(pivot != static_cast<int>(k))
489 std::swap(columnSquaredNorms[k], columnSquaredNorms[pivot]);
490 std::swap(permutation[k], permutation[pivot]);
494 qrHouseholderStepImpl(k, r, rhs, householder);
496 if(simpleSingularValueApproximation)
498 NormType nv =
norm(r(k,k));
499 maxApproxSingularValue = std::max(nv, maxApproxSingularValue);
500 minApproxSingularValue = std::min(nv, minApproxSingularValue);
504 incrementalMaxSingularValueApproximation(
columnVector(r, Shape(0,k),k+1), zmax, maxApproxSingularValue);
505 incrementalMinSingularValueApproximation(
columnVector(r, Shape(0,k),k+1), zmin, minApproxSingularValue, tolerance);
509 Matrix<T> u(k+1,k+1), s(k+1, 1), v(k+1,k+1);
511 std::cerr <<
"estimate, svd " << k <<
": " << minApproxSingularValue <<
" " << s(k,0) <<
"\n";
515 tolerance = m*maxApproxSingularValue*NumericTraits<T>::epsilon();
517 if(minApproxSingularValue > tolerance)
522 return static_cast<unsigned int>(rank);
525 template <
class T,
class C1,
class C2>
527 qrTransformToUpperTriangular(MultiArrayView<2, T, C1> & r, MultiArrayView<2, T, C2> & rhs,
528 ArrayVector<MultiArrayIndex> & permutation,
double epsilon = 0.0)
530 Matrix<T> dontStoreHouseholderVectors;
531 return qrTransformToTriangularImpl(r, rhs, dontStoreHouseholderVectors, permutation, epsilon);
535 template <
class T,
class C1,
class C2,
class C3>
537 qrTransformToLowerTriangular(MultiArrayView<2, T, C1> & r, MultiArrayView<2, T, C2> & rhs, MultiArrayView<2, T, C3> & householderMatrix,
538 double epsilon = 0.0)
540 ArrayVector<MultiArrayIndex> permutation(static_cast<unsigned int>(
rowCount(rhs)));
541 for(
MultiArrayIndex k=0; k<static_cast<MultiArrayIndex>(permutation.size()); ++k)
543 Matrix<T> dontTransformRHS;
544 MultiArrayView<2, T, StridedArrayTag> rt =
transpose(r),
546 unsigned int rank = qrTransformToTriangularImpl(rt, dontTransformRHS, ht, permutation, epsilon);
549 Matrix<T> tempRHS(rhs);
550 for(
MultiArrayIndex k=0; k<static_cast<MultiArrayIndex>(permutation.size()); ++k)
556 template <
class T,
class C1,
class C2>
558 qrTransformToUpperTriangular(MultiArrayView<2, T, C1> & r, MultiArrayView<2, T, C2> & rhs,
559 double epsilon = 0.0)
561 ArrayVector<MultiArrayIndex> noPivoting;
563 return (qrTransformToUpperTriangular(r, rhs, noPivoting, epsilon) ==
568 template <
class T,
class C1,
class C2>
570 qrTransformToLowerTriangular(MultiArrayView<2, T, C1> & r, MultiArrayView<2, T, C2> & householder,
571 double epsilon = 0.0)
573 Matrix<T> noPivoting;
575 return (qrTransformToLowerTriangular(r, noPivoting, householder, epsilon) ==
576 static_cast<unsigned int>(
rowCount(r)));
580 template <
class T,
class C1,
class C2,
class Permutation>
581 void inverseRowPermutation(MultiArrayView<2, T, C1> &permuted, MultiArrayView<2, T, C2> &res,
582 Permutation
const & permutation)
586 res(permutation[l], k) = permuted(l,k);
589 template <
class T,
class C1,
class C2>
590 void applyHouseholderColumnReflections(MultiArrayView<2, T, C1>
const &householder, MultiArrayView<2, T, C2> &res)
592 typedef typename Matrix<T>::difference_type Shape;
597 for(
int k = m-1; k >= 0; --k)
599 MultiArrayView<2, T, C1> u =
columnVector(householder, Shape(k,k), n);
607 template <
class T,
class C1,
class C2,
class C3>
609 linearSolveQRReplace(MultiArrayView<2, T, C1> &A, MultiArrayView<2, T, C2> &b,
610 MultiArrayView<2, T, C3> & res,
611 double epsilon = 0.0)
613 typedef typename Matrix<T>::difference_type Shape;
623 "linearSolveQR(): RHS and solution must have the same number of columns.");
624 vigra_precondition(m ==
rowCount(b),
625 "linearSolveQR(): Coefficient matrix and RHS must have the same number of rows.");
626 vigra_precondition(n ==
rowCount(res),
627 "linearSolveQR(): Mismatch between column count of coefficient matrix and row count of solution.");
628 vigra_precondition(epsilon >= 0.0,
629 "linearSolveQR(): 'epsilon' must be non-negative.");
634 Matrix<T> householderMatrix(n, m);
635 MultiArrayView<2, T, StridedArrayTag> ht =
transpose(householderMatrix);
636 rank =
static_cast<MultiArrayIndex>(detail::qrTransformToLowerTriangular(A, b, ht, epsilon));
637 res.subarray(Shape(rank,0), Shape(n, rhsCount)).init(NumericTraits<T>::zero());
641 MultiArrayView<2, T, C1> Asub = A.subarray(ul, Shape(m,rank));
642 detail::qrTransformToUpperTriangular(Asub, b, epsilon);
644 b.subarray(ul, Shape(rank,rhsCount)),
645 res.subarray(ul, Shape(rank, rhsCount)));
651 b.subarray(ul, Shape(rank, rhsCount)),
652 res.subarray(ul, Shape(rank, rhsCount)));
654 detail::applyHouseholderColumnReflections(householderMatrix.subarray(ul, Shape(n, rank)), res);
659 ArrayVector<MultiArrayIndex> permutation(static_cast<unsigned int>(n));
663 rank =
static_cast<MultiArrayIndex>(detail::qrTransformToUpperTriangular(A, b, permutation, epsilon));
665 Matrix<T> permutedSolution(n, rhsCount);
669 Matrix<T> householderMatrix(n, rank);
670 MultiArrayView<2, T, StridedArrayTag> ht =
transpose(householderMatrix);
671 MultiArrayView<2, T, C1> Asub = A.subarray(ul, Shape(rank,n));
672 detail::qrTransformToLowerTriangular(Asub, ht, epsilon);
674 b.subarray(ul, Shape(rank, rhsCount)),
675 permutedSolution.subarray(ul, Shape(rank, rhsCount)));
676 detail::applyHouseholderColumnReflections(householderMatrix, permutedSolution);
682 b.subarray(ul, Shape(rank,rhsCount)),
685 detail::inverseRowPermutation(permutedSolution, res, permutation);
687 return static_cast<unsigned int>(rank);
690 template <
class T,
class C1,
class C2,
class C3>
691 unsigned int linearSolveQR(MultiArrayView<2, T, C1>
const & A, MultiArrayView<2, T, C2>
const & b,
692 MultiArrayView<2, T, C3> & res)
694 Matrix<T> r(A), rhs(b);
695 return linearSolveQRReplace(r, rhs, res);
719 template <
class T,
class C1,
class C2>
727 "inverse(): shape of output matrix must be the transpose of the input matrix' shape.");
736 transpose(q).subarray(Shape(0,0), Shape(m,n)),
745 transpose(q).subarray(Shape(0,0), Shape(n,m)),
773 template <
class T,
class C>
777 vigra_precondition(
inverse(v, ret),
778 "inverse(): matrix is not invertible.");
783 TemporaryMatrix<T>
inverse(
const TemporaryMatrix<T> &v)
787 vigra_precondition(
inverse(v,
const_cast<TemporaryMatrix<T> &
>(v)),
788 "inverse(): matrix is not invertible.");
794 vigra_precondition(
inverse(v, ret),
795 "inverse(): matrix is not invertible.");
816 template <
class T,
class C1>
820 vigra_precondition(
rowCount(a) == n,
821 "determinant(): Square matrix required.");
828 return a(0,0)*a(1,1) - a(0,1)*a(1,0);
831 return detail::determinantByLUDecomposition(a);
833 else if(method ==
"cholesky")
837 "determinant(): Cholesky method requires symmetric positive definite matrix.");
845 vigra_precondition(
false,
"determinant(): Unknown solution method.");
859 template <
class T,
class C1>
863 vigra_precondition(
rowCount(a) == n,
864 "logDeterminant(): Square matrix required.");
867 vigra_precondition(a(0,0) > 0.0,
868 "logDeterminant(): Matrix not positive definite.");
873 T det = a(0,0)*a(1,1) - a(0,1)*a(1,0);
874 vigra_precondition(det > 0.0,
875 "logDeterminant(): Matrix not positive definite.");
882 "logDeterminant(): Matrix not positive definite.");
907 template <
class T,
class C1,
class C2>
912 vigra_precondition(
rowCount(A) == n,
913 "choleskyDecomposition(): Input matrix must be square.");
915 "choleskyDecomposition(): Output matrix must have same shape as input matrix.");
917 "choleskyDecomposition(): Input matrix must be symmetric.");
927 s += L(k, i)*L(j, i);
929 L(j, k) = s = (A(j, k) - s)/L(k, k);
962 template <
class T,
class C1,
class C2,
class C3>
965 double epsilon = 0.0)
971 "qrDecomposition(): Matrix shape mismatch.");
973 q = identityMatrix<T>(m);
977 return (static_cast<MultiArrayIndex>(detail::qrTransformToUpperTriangular(r, tq, noPivoting, epsilon) == std::min(m,n)));
982 template <
class T,
class C1,
class C2,
class C3>
1014 template <
class T,
class C1,
class C2,
class C3>
1021 "linearSolveUpperTriangular(): square coefficient matrix required.");
1023 "linearSolveUpperTriangular(): matrix shape mismatch.");
1027 for(
int i=m-1; i>=0; --i)
1029 if(r(i,i) == NumericTraits<T>::zero())
1033 sum -= r(i, j) * x(j, k);
1034 x(i, k) = sum / r(i, i);
1064 template <
class T,
class C1,
class C2,
class C3>
1070 vigra_precondition(m ==
rowCount(l),
1071 "linearSolveLowerTriangular(): square coefficient matrix required.");
1073 "linearSolveLowerTriangular(): matrix shape mismatch.");
1079 if(l(i,i) == NumericTraits<T>::zero())
1083 sum -= l(i, j) * x(j, k);
1084 x(i, k) = sum / l(i, i);
1113 template <
class T,
class C1,
class C2,
class C3>
1194 template <
class T,
class C1,
class C2,
class C3>
1198 std::string method =
"QR")
1203 vigra_precondition(n <= m,
1204 "linearSolve(): Coefficient matrix A must have at least as many rows as columns.");
1205 vigra_precondition(n ==
rowCount(res) &&
1207 "linearSolve(): matrix shape mismatch.");
1210 if(method ==
"cholesky")
1213 "linearSolve(): Cholesky method requires square coefficient matrix.");
1214 Matrix<T> L(A.
shape());
1219 else if(method ==
"qr")
1223 else if(method ==
"ne")
1227 else if(method ==
"svd")
1230 Matrix<T> u(A.
shape()), s(n, 1), v(n, n);
1240 t(k,l) = NumericTraits<T>::zero();
1248 vigra_precondition(
false,
"linearSolve(): Unknown solution method.");
1253 template <
class T,
class C1,
int N>
1254 bool linearSolve(MultiArrayView<2, T, C1>
const & A,
1255 TinyVector<T, N>
const & b,
1256 TinyVector<T, N> & res,
1257 std::string method =
"QR")
1260 return linearSolve(A, MultiArrayView<2, T>(shape, b.data()), MultiArrayView<2, T>(shape, res.data()), method);
1280 #endif // VIGRA_LINEAR_SOLVE_HXX