tesseract++ 0.0.1
N-dimensional tensor library for embedded systems
Loading...
Searching...
No Matches
triangular_solve.h
Go to the documentation of this file.
1#ifndef FUSED_ALGORITHMS_TRIANGULAR_SOLVE_H
2#define FUSED_ALGORITHMS_TRIANGULAR_SOLVE_H
3
4#include "config.h"
6#include "matrix_traits.h"
8#include "math/math_utils.h"
9
49namespace matrix_algorithms
50{
51
53
54 // ========================================================================
55 // Forward substitution: solve Lx = b (L lower-triangular, single RHS)
56 // ========================================================================
57
73 template <bool UnitDiag = false, typename T, my_size_t N>
75 const FusedMatrix<T, N, N> &L,
76 const FusedVector<T, N> &b)
77 {
78 static_assert(is_floating_point_v<T>,
79 "forward_substitute requires a floating-point scalar type");
80
81 FusedVector<T, N> x(T(0));
82
83 // --- Fixed-size fully unrolled paths ---
84 if constexpr (!UnitDiag && N == 3)
85 {
86 if (math::abs(L(0, 0)) <= T(PRECISION_TOLERANCE))
87 return Unexpected{MatrixStatus::Singular};
88 x(0) = b(0) / L(0, 0);
89
90 if (math::abs(L(1, 1)) <= T(PRECISION_TOLERANCE))
91 return Unexpected{MatrixStatus::Singular};
92 x(1) = (b(1) - L(1, 0) * x(0)) / L(1, 1);
93
94 if (math::abs(L(2, 2)) <= T(PRECISION_TOLERANCE))
95 return Unexpected{MatrixStatus::Singular};
96 x(2) = (b(2) - L(2, 0) * x(0) - L(2, 1) * x(1)) / L(2, 2);
97 }
98 else if constexpr (UnitDiag && N == 3)
99 {
100 x(0) = b(0);
101 x(1) = b(1) - L(1, 0) * x(0);
102 x(2) = b(2) - L(2, 0) * x(0) - L(2, 1) * x(1);
103 }
104 else if constexpr (!UnitDiag && N == 4)
105 {
106 if (math::abs(L(0, 0)) <= T(PRECISION_TOLERANCE))
107 return Unexpected{MatrixStatus::Singular};
108 x(0) = b(0) / L(0, 0);
109
110 if (math::abs(L(1, 1)) <= T(PRECISION_TOLERANCE))
111 return Unexpected{MatrixStatus::Singular};
112 x(1) = (b(1) - L(1, 0) * x(0)) / L(1, 1);
113
114 if (math::abs(L(2, 2)) <= T(PRECISION_TOLERANCE))
115 return Unexpected{MatrixStatus::Singular};
116 x(2) = (b(2) - L(2, 0) * x(0) - L(2, 1) * x(1)) / L(2, 2);
117
118 if (math::abs(L(3, 3)) <= T(PRECISION_TOLERANCE))
119 return Unexpected{MatrixStatus::Singular};
120 x(3) = (b(3) - L(3, 0) * x(0) - L(3, 1) * x(1) - L(3, 2) * x(2)) / L(3, 3);
121 }
122 else if constexpr (UnitDiag && N == 4)
123 {
124 x(0) = b(0);
125 x(1) = b(1) - L(1, 0) * x(0);
126 x(2) = b(2) - L(2, 0) * x(0) - L(2, 1) * x(1);
127 x(3) = b(3) - L(3, 0) * x(0) - L(3, 1) * x(1) - L(3, 2) * x(2);
128 }
129 else if constexpr (!UnitDiag && N == 6)
130 {
131 if (math::abs(L(0, 0)) <= T(PRECISION_TOLERANCE))
132 return Unexpected{MatrixStatus::Singular};
133 x(0) = b(0) / L(0, 0);
134
135 if (math::abs(L(1, 1)) <= T(PRECISION_TOLERANCE))
136 return Unexpected{MatrixStatus::Singular};
137 x(1) = (b(1) - L(1, 0) * x(0)) / L(1, 1);
138
139 if (math::abs(L(2, 2)) <= T(PRECISION_TOLERANCE))
140 return Unexpected{MatrixStatus::Singular};
141 x(2) = (b(2) - L(2, 0) * x(0) - L(2, 1) * x(1)) / L(2, 2);
142
143 if (math::abs(L(3, 3)) <= T(PRECISION_TOLERANCE))
144 return Unexpected{MatrixStatus::Singular};
145 x(3) = (b(3) - L(3, 0) * x(0) - L(3, 1) * x(1) - L(3, 2) * x(2)) / L(3, 3);
146
147 if (math::abs(L(4, 4)) <= T(PRECISION_TOLERANCE))
148 return Unexpected{MatrixStatus::Singular};
149 x(4) = (b(4) - L(4, 0) * x(0) - L(4, 1) * x(1) - L(4, 2) * x(2) - L(4, 3) * x(3)) / L(4, 4);
150
151 if (math::abs(L(5, 5)) <= T(PRECISION_TOLERANCE))
152 return Unexpected{MatrixStatus::Singular};
153 x(5) = (b(5) - L(5, 0) * x(0) - L(5, 1) * x(1) - L(5, 2) * x(2) - L(5, 3) * x(3) - L(5, 4) * x(4)) / L(5, 5);
154 }
155 else if constexpr (UnitDiag && N == 6)
156 {
157 x(0) = b(0);
158 x(1) = b(1) - L(1, 0) * x(0);
159 x(2) = b(2) - L(2, 0) * x(0) - L(2, 1) * x(1);
160 x(3) = b(3) - L(3, 0) * x(0) - L(3, 1) * x(1) - L(3, 2) * x(2);
161 x(4) = b(4) - L(4, 0) * x(0) - L(4, 1) * x(1) - L(4, 2) * x(2) - L(4, 3) * x(3);
162 x(5) = b(5) - L(5, 0) * x(0) - L(5, 1) * x(1) - L(5, 2) * x(2) - L(5, 3) * x(3) - L(5, 4) * x(4);
163 }
164 // --- Generic scalar path ---
165 else
166 {
167 for (my_size_t i = 0; i < N; ++i)
168 {
169 T sum = b(i);
170
171 for (my_size_t k = 0; k < i; ++k)
172 {
173 sum -= L(i, k) * x(k);
174 }
175
176 if constexpr (UnitDiag)
177 {
178 x(i) = sum;
179 }
180 else
181 {
182 T diag = L(i, i);
183
184 if (math::abs(diag) <= T(PRECISION_TOLERANCE))
185 {
186 return Unexpected{MatrixStatus::Singular};
187 }
188
189 x(i) = sum / diag;
190 }
191 }
192 }
193
194 return move(x);
195 }
196
197 // ========================================================================
198 // Back substitution: solve Ux = b (U upper-triangular, single RHS)
199 // ========================================================================
200
215 template <bool UnitDiag = false, typename T, my_size_t N>
217 const FusedMatrix<T, N, N> &U,
218 const FusedVector<T, N> &b)
219 {
220 static_assert(is_floating_point_v<T>,
221 "back_substitute requires a floating-point scalar type");
222
223 FusedVector<T, N> x(T(0));
224
225 // --- Fixed-size fully unrolled paths ---
226 if constexpr (!UnitDiag && N == 3)
227 {
228 if (math::abs(U(2, 2)) <= T(PRECISION_TOLERANCE))
229 return Unexpected{MatrixStatus::Singular};
230 x(2) = b(2) / U(2, 2);
231
232 if (math::abs(U(1, 1)) <= T(PRECISION_TOLERANCE))
233 return Unexpected{MatrixStatus::Singular};
234 x(1) = (b(1) - U(1, 2) * x(2)) / U(1, 1);
235
236 if (math::abs(U(0, 0)) <= T(PRECISION_TOLERANCE))
237 return Unexpected{MatrixStatus::Singular};
238 x(0) = (b(0) - U(0, 1) * x(1) - U(0, 2) * x(2)) / U(0, 0);
239 }
240 else if constexpr (UnitDiag && N == 3)
241 {
242 x(2) = b(2);
243 x(1) = b(1) - U(1, 2) * x(2);
244 x(0) = b(0) - U(0, 1) * x(1) - U(0, 2) * x(2);
245 }
246 else if constexpr (!UnitDiag && N == 4)
247 {
248 if (math::abs(U(3, 3)) <= T(PRECISION_TOLERANCE))
249 return Unexpected{MatrixStatus::Singular};
250 x(3) = b(3) / U(3, 3);
251
252 if (math::abs(U(2, 2)) <= T(PRECISION_TOLERANCE))
253 return Unexpected{MatrixStatus::Singular};
254 x(2) = (b(2) - U(2, 3) * x(3)) / U(2, 2);
255
256 if (math::abs(U(1, 1)) <= T(PRECISION_TOLERANCE))
257 return Unexpected{MatrixStatus::Singular};
258 x(1) = (b(1) - U(1, 2) * x(2) - U(1, 3) * x(3)) / U(1, 1);
259
260 if (math::abs(U(0, 0)) <= T(PRECISION_TOLERANCE))
261 return Unexpected{MatrixStatus::Singular};
262 x(0) = (b(0) - U(0, 1) * x(1) - U(0, 2) * x(2) - U(0, 3) * x(3)) / U(0, 0);
263 }
264 else if constexpr (UnitDiag && N == 4)
265 {
266 x(3) = b(3);
267 x(2) = b(2) - U(2, 3) * x(3);
268 x(1) = b(1) - U(1, 2) * x(2) - U(1, 3) * x(3);
269 x(0) = b(0) - U(0, 1) * x(1) - U(0, 2) * x(2) - U(0, 3) * x(3);
270 }
271 else if constexpr (!UnitDiag && N == 6)
272 {
273 if (math::abs(U(5, 5)) <= T(PRECISION_TOLERANCE))
274 return Unexpected{MatrixStatus::Singular};
275 x(5) = b(5) / U(5, 5);
276
277 if (math::abs(U(4, 4)) <= T(PRECISION_TOLERANCE))
278 return Unexpected{MatrixStatus::Singular};
279 x(4) = (b(4) - U(4, 5) * x(5)) / U(4, 4);
280
281 if (math::abs(U(3, 3)) <= T(PRECISION_TOLERANCE))
282 return Unexpected{MatrixStatus::Singular};
283 x(3) = (b(3) - U(3, 4) * x(4) - U(3, 5) * x(5)) / U(3, 3);
284
285 if (math::abs(U(2, 2)) <= T(PRECISION_TOLERANCE))
286 return Unexpected{MatrixStatus::Singular};
287 x(2) = (b(2) - U(2, 3) * x(3) - U(2, 4) * x(4) - U(2, 5) * x(5)) / U(2, 2);
288
289 if (math::abs(U(1, 1)) <= T(PRECISION_TOLERANCE))
290 return Unexpected{MatrixStatus::Singular};
291 x(1) = (b(1) - U(1, 2) * x(2) - U(1, 3) * x(3) - U(1, 4) * x(4) - U(1, 5) * x(5)) / U(1, 1);
292
293 if (math::abs(U(0, 0)) <= T(PRECISION_TOLERANCE))
294 return Unexpected{MatrixStatus::Singular};
295 x(0) = (b(0) - U(0, 1) * x(1) - U(0, 2) * x(2) - U(0, 3) * x(3) - U(0, 4) * x(4) - U(0, 5) * x(5)) / U(0, 0);
296 }
297 else if constexpr (UnitDiag && N == 6)
298 {
299 x(5) = b(5);
300 x(4) = b(4) - U(4, 5) * x(5);
301 x(3) = b(3) - U(3, 4) * x(4) - U(3, 5) * x(5);
302 x(2) = b(2) - U(2, 3) * x(3) - U(2, 4) * x(4) - U(2, 5) * x(5);
303 x(1) = b(1) - U(1, 2) * x(2) - U(1, 3) * x(3) - U(1, 4) * x(4) - U(1, 5) * x(5);
304 x(0) = b(0) - U(0, 1) * x(1) - U(0, 2) * x(2) - U(0, 3) * x(3) - U(0, 4) * x(4) - U(0, 5) * x(5);
305 }
306 // --- Generic scalar path ---
307 else
308 {
309 for (my_size_t i = N; i-- > 0;)
310 {
311 T sum = b(i);
312
313 for (my_size_t k = i + 1; k < N; ++k)
314 {
315 sum -= U(i, k) * x(k);
316 }
317
318 if constexpr (UnitDiag)
319 {
320 x(i) = sum;
321 }
322 else
323 {
324 T diag = U(i, i);
325
326 if (math::abs(diag) <= T(PRECISION_TOLERANCE))
327 {
328 return Unexpected{MatrixStatus::Singular};
329 }
330
331 x(i) = sum / diag;
332 }
333 }
334 }
335
336 return move(x);
337 }
338
339 // ========================================================================
340 // Multi-RHS: solve LX = B (B is FusedMatrix, column-by-column)
341 // ========================================================================
342
// Multi-RHS forward substitution: solves L X = B one column of B at a time.
// With UnitDiag == true the diagonal of L is never read and no singularity
// check is made; otherwise a diagonal entry with
// |L(i,i)| <= PRECISION_TOLERANCE aborts with MatrixStatus::Singular.
// NOTE(review): listing lines 358, 360 and 365 are missing from this extract.
// They presumably hold the return type and function name, the `B` parameter
// and the declaration of the result matrix `X` (an N x Ncols matrix, judging
// by the X(k, j) / B(i, j) accesses below) - confirm against the real header.
357 template <bool UnitDiag = false, typename T, my_size_t N, my_size_t Ncols>
359 const FusedMatrix<T, N, N> &L,
361 {
362 static_assert(is_floating_point_v<T>,
363 "forward_substitute requires a floating-point scalar type");
364
366
// Each column j of B is an independent right-hand side.
367 for (my_size_t j = 0; j < Ncols; ++j)
368 {
// Rows top to bottom: row i only needs X(0..i-1, j), already computed.
369 for (my_size_t i = 0; i < N; ++i)
370 {
371 T sum = B(i, j);
372
// Subtract the contribution of the unknowns solved so far in this column.
373 for (my_size_t k = 0; k < i; ++k)
374 {
375 sum -= L(i, k) * X(k, j);
376 }
377
378 if constexpr (UnitDiag)
379 {
380 X(i, j) = sum;
381 }
382 else
383 {
384 T diag = L(i, i);
385
// A (near-)zero pivot: report singular rather than divide by it.
// The same L(i, i) is re-tested for every column j.
386 if (math::abs(diag) <= T(PRECISION_TOLERANCE))
387 {
388 return Unexpected{MatrixStatus::Singular};
389 }
390
391 X(i, j) = sum / diag;
392 }
393 }
394 }
395
396 return move(X);
397 }
398
399 // ========================================================================
400 // Multi-RHS: solve UX = B (B is FusedMatrix, column-by-column)
401 // ========================================================================
402
// Multi-RHS back substitution: solves U X = B one column of B at a time.
// With UnitDiag == true the diagonal of U is never read and no singularity
// check is made; otherwise a diagonal entry with
// |U(i,i)| <= PRECISION_TOLERANCE aborts with MatrixStatus::Singular.
// NOTE(review): listing lines 418, 420 and 425 are missing from this extract.
// They presumably hold the return type and function name, the `B` parameter
// and the declaration of the result matrix `X` (an N x Ncols matrix, judging
// by the X(k, j) / B(i, j) accesses below) - confirm against the real header.
417 template <bool UnitDiag = false, typename T, my_size_t N, my_size_t Ncols>
419 const FusedMatrix<T, N, N> &U,
421 {
422 static_assert(is_floating_point_v<T>,
423 "back_substitute requires a floating-point scalar type");
424
426
// Each column j of B is an independent right-hand side.
427 for (my_size_t j = 0; j < Ncols; ++j)
428 {
// Rows bottom to top (i = N-1 .. 0): row i only needs X(i+1..N-1, j).
// Testing `i-- > 0` before the body keeps the countdown well-defined even
// for an unsigned index type.
429 for (my_size_t i = N; i-- > 0;)
430 {
431 T sum = B(i, j);
432
// Subtract the contribution of the unknowns solved so far in this column.
433 for (my_size_t k = i + 1; k < N; ++k)
434 {
435 sum -= U(i, k) * X(k, j);
436 }
437
438 if constexpr (UnitDiag)
439 {
440 X(i, j) = sum;
441 }
442 else
443 {
444 T diag = U(i, i);
445
// A (near-)zero pivot: report singular rather than divide by it.
// The same U(i, i) is re-tested for every column j.
446 if (math::abs(diag) <= T(PRECISION_TOLERANCE))
447 {
448 return Unexpected{MatrixStatus::Singular};
449 }
450
451 X(i, j) = sum / diag;
452 }
453 }
454 }
455
456 return move(X);
457 }
458
459} // namespace matrix_algorithms
460
461#endif // FUSED_ALGORITHMS_TRIANGULAR_SOLVE_H
A discriminated union holding either a success value or an error.
Definition expected.h:86
Definition fused_matrix.h:12
Definition fused_vector.h:9
Global configuration for the tesseract tensor library.
#define my_size_t
Size/index type used throughout the library.
Definition config.h:126
#define PRECISION_TOLERANCE
Tolerance for floating-point comparisons (e.g. symmetry checks, Cholesky).
Definition config.h:117
A minimal, STL-free expected/result type for failable operations.
Runtime property descriptors and error codes for matrices.
constexpr T abs(T x) noexcept
Compute the absolute value of a numeric value.
Definition math_utils.h:48
Definition cholesky.h:52
Expected< FusedVector< T, N >, MatrixStatus > back_substitute(const FusedMatrix< T, N, N > &U, const FusedVector< T, N > &b)
Solve the upper-triangular system Ux = b by back substitution.
Definition triangular_solve.h:216
Expected< FusedVector< T, N >, MatrixStatus > forward_substitute(const FusedMatrix< T, N, N > &L, const FusedVector< T, N > &b)
Solve the lower-triangular system Lx = b by forward substitution.
Definition triangular_solve.h:74
MatrixStatus
Error codes for matrix decomposition and solver algorithms.
Definition matrix_traits.h:33
Expr::value_type sum(const BaseExpr< Expr > &expr)
Definition reductions.h:30
constexpr remove_reference_t< T > && move(T &&t) noexcept
Cast to rvalue reference (replacement for std::move).
Definition simple_type_traits.h:178
Tag type for constructing an Expected in the error state.
Definition expected.h:30