51#ifdef DEBUG_FUSED_TENSOR
56#ifdef DEBUG_FUSED_TENSOR
65 : data_(
move(other.data_))
67#ifdef DEBUG_FUSED_TENSOR
72#ifdef DEBUG_FUSED_TENSOR
79 template <
typename Output>
84 if constexpr (is_same_v<remove_cvref_t<Output>,
FusedTensorND>)
86 return this == &output;
94 template <
typename Expr>
97#ifdef DEBUG_FUSED_TENSOR
100 const auto &e = expr.
derived();
102 if (e.may_alias(*
this))
108 if constexpr (
NumDims != Expr::NumDims)
112 if constexpr (!dims_match<NumDims>(
Dim, Expr::Dim))
131 template <
typename T_, my_
size_t Bits,
typename Arch>
135 return K::load(data_.
data() + flat);
146 template <
typename T_, my_
size_t Bits,
typename Arch>
152 if constexpr (K::simdWidth == 1)
154 return K::load(data_.
data() +
160 for (
my_size_t i = 0; i < K::simdWidth; ++i)
162 return K::gather(data_.
data(), idxList);
168#ifdef DEBUG_FUSED_TENSOR
173#ifdef DEBUG_FUSED_TENSOR
187#ifdef DEBUG_FUSED_TENSOR
192#ifdef DEBUG_FUSED_TENSOR
199 data_ =
move(other.data_);
204 template <
typename... Indices>
205 requires(
sizeof...(Indices) ==
NumDims)
213 template <
typename... Indices>
214 requires(
sizeof...(Indices) ==
NumDims)
261 static constexpr my_size_t total_combinations = (1 * ... * Dims);
262 my_size_t combinations[total_combinations][
sizeof...(Dims)];
263 static constexpr my_size_t max_vals[
sizeof...(Dims)] = {Dims...};
264 generate_combinations(max_vals, combinations);
266 for (
my_size_t i = 0; i < total_combinations; ++i)
270 bool isElementDiagonal =
true;
273 if (combinations[i][j] != combinations[i][0])
275 isElementDiagonal =
false;
280 if (isElementDiagonal)
314 static_assert(
sizeof...(Dims) == 2,
"Transpose is only supported for 2D tensors");
336 std::string shape =
"(";
339 shape += std::to_string(
getDim(i));
365 std::mt19937 rng(
static_cast<unsigned int>(std::time(
nullptr)));
367 if constexpr (std::is_floating_point<T>::value)
369 std::uniform_real_distribution<T> dist(_minRand, _maxRand);
375 std::uniform_int_distribution<T> dist(_minRand, _maxRand);
385 static_assert(
sizeof...(Dims) >= 2,
"setDiagonal requires at least 2 dimensions.");
410 static_assert(
sizeof...(Dims) >= 2,
"Identity requires at least 2 dimensions.");
411 static_assert(
all_equal<Dims...>(),
"All dimensions must be equal for an identity tensor");
438 template <my_
size_t DiagonalSize>
441 static_assert(
sizeof...(Dims) >= 2,
"Getting diagonal entries requires at least 2 dimensions.");
481 template <
typename LeftExpr,
typename RightExpr>
490 static constexpr my_size_t Dims1 = LeftExpr::NumDims;
491 static constexpr my_size_t Dims2 = RightExpr::NumDims;
493 static_assert(Dims1 >= 2,
"Tensor 1 must have at least 2 dimensions");
494 static_assert(Dims2 >= 2,
"Tensor 2 must have at least 2 dimensions");
497 if (a >= Dims1 || b >= Dims2)
503 using Layout1 =
typename LeftExpr::Layout;
504 using Layout2 =
typename RightExpr::Layout;
505 using OutputLayout =
Layout;
509 const my_size_t contract_stride1 = Layout1::stride(a);
510 const my_size_t contract_stride2 = Layout2::stride(b);
516 static constexpr my_size_t n_newDims = Dims1 + Dims2 - 2;
522 out_dims[d++] = _tensor1.
derived().getDim(i);
525 out_dims[d++] = _tensor2.
derived().getDim(i);
528 for (
my_size_t i = 0; i < n_newDims; ++i)
530 if (out_dims[i] != _outp.
getDim(i))
538 if constexpr (Dims1 == 2 && Dims2 == 2)
543 auto run_gemm = [&](
const auto &A_ready,
const auto &B_ready)
545 using LayoutA =
typename std::remove_cvref_t<
decltype(A_ready)>
::Layout;
546 using LayoutB =
typename std::remove_cvref_t<
decltype(B_ready)>
::Layout;
552 A_ready.data(), M, K_len, LayoutA::stride(0),
553 B_ready.data(), N, LayoutB::stride(0),
554 _outp.
data(), OutputLayout::stride(0));
557 auto make_transposed = [](
const auto &expr)
560 if constexpr (!
requires { expr.transpose(); })
571 return expr.transpose();
577 auto &base = expr.transpose();
584 auto ensure_materialized = [](
const auto &expr)
587 if constexpr (!
requires { expr.transpose(); })
595 return expr.transpose();
606 if (a == 1 && b == 0)
608 if constexpr (
requires { _tensor1.
derived().transpose(); } ||
requires { _tensor2.
derived().transpose(); })
612 run_gemm(ensure_materialized(_tensor1.
derived()),
613 ensure_materialized(_tensor2.
derived()));
621 else if (a == 0 && b == 0)
624 auto A_t = make_transposed(_tensor1.
derived());
625 run_gemm(A_t, _tensor2.
derived());
627 else if (a == 1 && b == 1)
630 auto B_t = make_transposed(_tensor2.
derived());
631 run_gemm(_tensor1.
derived(), B_t);
636 auto A_t = make_transposed(_tensor1.
derived());
637 auto B_t = make_transposed(_tensor2.
derived());
656 strides1_map[d] = Layout1::stride(i);
666 strides2_map[d] = Layout2::stride(i);
672 for (
my_size_t i = 0; i < n_newDims; ++i)
673 out_strides[i] = OutputLayout::stride(i);
675 T *out_ptr = _outp.
data();
677 static constexpr my_size_t total_elements = (1 * ... * Dims);
679 for (
my_size_t flat = 0; flat < total_elements; ++flat)
685 coords[i] = tmp % out_dims[i];
692 for (
my_size_t i = 0; i < n_newDims; ++i)
694 base1 += coords[i] * strides1_map[i];
695 base2 += coords[i] * strides2_map[i];
696 out_phys += coords[i] * out_strides[i];
699 out_ptr[out_phys] = Kern::dot(
700 _tensor1.
derived(), base1, contract_stride1,
701 _tensor2.
derived(), base2, contract_stride2,
709 void print(
bool with_padding =
false)
const
737 const my_size_t physColDim = Layout::PadPolicyType::PhysicalDims.at(ND - 1);
739 for (
my_size_t s = 0; s < numSlices; ++s)
742 if constexpr (ND > 2)
757 if constexpr (ND >= 2)
770 if (showPadding && physColDim > colDim)
777 for (
my_size_t j = colDim; j < physColDim; ++j)
789 if constexpr (ND > 2)
794 if (coords[d] <
getDim(d))
864 if (this->
getDim(i) != other.getDim(i))
871 template <my_
size_t NumDims, my_
size_t M>
872 [[deprecated]]
static void print_combinations(
const my_size_t (&combinations)[M][
NumDims])
887 template <my_
size_t NumDims, my_
size_t M>
897 combinations[row][i] = combination[i];
913 while (position >= 0)
915 ++combination[position];
916 if (combination[position] < max_values[position])
920 combination[position] = 0;
927 [[deprecated]]
void print1D()
const
938 [[deprecated]]
void print2D(
bool with_padding)
const
942 ? AccessPolicy::PadPolicy::PhysicalDims[1]
959 if (with_padding && j ==
getDim(1) - 1)
968 [[deprecated]]
void print3D()
const
985 [[deprecated]]
void print4D()
const
const Derived & derived() const
Definition BaseExpr.h:17
Dense storage access with padding policy.
Definition dense_access.h:20
static constexpr my_size_t PhysicalSize
Definition dense_access.h:24
FORCE_INLINE constexpr T * data() noexcept
Definition dense_access.h:52
static void log(const T &msg, ErrorLevel level=ErrorLevel::Plain)
Definition error_handler.h:18
static void error(const T &msg)
Definition error_handler.h:30
Definition fused_tensor.h:31
FusedTensorND & setSequencial(void)
Definition fused_tensor.h:428
void print_access_policy_info() const
Definition fused_tensor.h:826
T & operator()(my_size_t(&indices)[NumDims]) TESSERACT_CONDITIONAL_NOEXCEPT
Definition fused_tensor.h:235
const T & operator()(my_size_t(&indices)[NumDims]) const TESSERACT_CONDITIONAL_NOEXCEPT
Definition fused_tensor.h:240
void getDiagonalEntries(FusedTensorND< T, DiagonalSize, 1 > &diagonalEntries) const
Definition fused_tensor.h:439
FusedTensorND(FusedTensorND &&other) noexcept
Definition fused_tensor.h:64
FusedTensorND & setIdentity(void)
Definition fused_tensor.h:408
FORCE_INLINE auto transpose_view() const noexcept
Definition fused_tensor.h:304
FusedTensorND & setDiagonal(T _val)
Definition fused_tensor.h:383
FORCE_INLINE auto transpose_view(void) const noexcept
Definition fused_tensor.h:310
FusedTensorND & setToZero(void) noexcept
Definition fused_tensor.h:347
StridedLayoutConstExpr< typename AccessPolicy::PadPolicy > Layout
Definition fused_tensor.h:1022
FusedTensorND< T, Dims... > Self
Definition fused_tensor.h:38
Microkernel< T_, Bits, Arch >::VecType evalu(my_size_t flat) const noexcept
Definition fused_tensor.h:132
static FORCE_INLINE constexpr my_size_t getNumDims() noexcept
Definition fused_tensor.h:328
std::string getShape() const
Definition fused_tensor.h:334
FusedTensorND(const FusedTensorND &other) noexcept
Definition fused_tensor.h:48
FusedTensorND & setHomogen(T _val) noexcept
Definition fused_tensor.h:355
FusedTensorND & operator=(const BaseExpr< Expr > &expr)
Definition fused_tensor.h:95
const T & operator()(const my_size_t *indices) const TESSERACT_CONDITIONAL_NOEXCEPT
Definition fused_tensor.h:228
static constexpr bool areDimsEqual()
Definition fused_tensor.h:246
FusedTensorND & operator=(const FusedTensorND &other) noexcept
Definition fused_tensor.h:166
void printND(bool showPadding=false) const
Print tensor of arbitrary dimensions.
Definition fused_tensor.h:724
static constexpr my_size_t Dim[]
Definition fused_tensor.h:35
FORCE_INLINE constexpr T * data() noexcept
Definition fused_tensor.h:1020
bool isIdentity() const
Definition fused_tensor.h:251
static constexpr my_size_t NumDims
Definition fused_tensor.h:34
T & operator()(const my_size_t *indices) TESSERACT_CONDITIONAL_NOEXCEPT
Definition fused_tensor.h:222
static FORCE_INLINE constexpr my_size_t getDim(my_size_t i) TESSERACT_CONDITIONAL_NOEXCEPT
Definition fused_tensor.h:804
static FORCE_INLINE constexpr my_size_t getStride(my_size_t i) TESSERACT_CONDITIONAL_NOEXCEPT
Definition fused_tensor.h:809
void print(bool with_padding=false) const
Definition fused_tensor.h:709
FusedTensorND() noexcept=default
T value_type
Definition fused_tensor.h:37
FORCE_INLINE Microkernel< T_, Bits, Arch >::VecType logical_evalu(my_size_t logical_flat) const noexcept
Evaluate at a LOGICAL flat index.
Definition fused_tensor.h:148
static constexpr my_size_t TotalSize
Definition fused_tensor.h:36
FusedTensorND & operator=(FusedTensorND &&other) noexcept
Definition fused_tensor.h:185
static FORCE_INLINE constexpr my_size_t getTotalSize() noexcept
Definition fused_tensor.h:323
FusedTensorND & setRandom(T _maxRand, T _minRand)
Definition fused_tensor.h:363
bool may_alias(const Output &output) const noexcept
Definition fused_tensor.h:80
void printLayoutInfo() const
Definition fused_tensor.h:815
FORCE_INLINE constexpr const T * data() const noexcept
Definition fused_tensor.h:1019
static FusedTensorND einsum(const BaseExpr< LeftExpr > &_tensor1, const BaseExpr< RightExpr > &_tensor2, const my_size_t a, const my_size_t b)
Contract two tensors along specified axes using SIMD dot products.
Definition fused_tensor.h:484
void print_flat_data() const
Definition fused_tensor.h:835
friend class PermutedViewConstExpr
Definition fused_tensor.h:1016
Compile-time permuted view over a tensor.
Definition permuted_view_constexpr.h:36
Definition static_storage.h:9
Global configuration for the tesseract tensor library.
#define my_size_t
Size/index type used throughout the library.
Definition config.h:126
#define TESSERACT_CONDITIONAL_NOEXCEPT
Definition config.h:56
#define PRECISION_TOLERANCE
Tolerance for floating-point comparisons (e.g. symmetry checks, Cholesky).
Definition config.h:117
#define FORCE_INLINE
Hint the compiler to always inline a function.
Definition config.h:26
consteval bool all_equal()
Check if all values in a parameter pack are equal.
Definition helper_traits.h:20
consteval my_size_t min_value()
Compile-time minimum of a non-type parameter pack.
Definition helper_traits.h:68
Façade for higher-level kernel operations built on top of microkernels.
STL-free memory utilities.
SimdPaddingPolicyBase< T, Microkernel< T, BITS, DefaultArch >::simdWidth, Dims... > SimdPaddingPolicy
Definition simd_padding_policy.h:349
typename remove_cvref< T >::type remove_cvref_t
Alias template for remove_cvref.
Definition simple_type_traits.h:169
constexpr remove_reference_t< T > && move(T &&t) noexcept
Cast to rvalue reference (replacement for std::move).
Definition simple_type_traits.h:178
Definition kernel_ops.h:28
static FORCE_INLINE void eval(T *output, const Expr &expr) noexcept
Evaluation: Dispatch: pick contiguous or permuted eval based on expression layout.
Definition kernel_ops.h:41
Definition microkernel_base.h:16
T VecType
Definition microkernel_base.h:18
Compile-time strided layout with optional permutation.
Definition strided_layout_constexpr.h:38
static constexpr my_size_t NumDims
Definition strided_layout_constexpr.h:39
static constexpr my_size_t PhysicalSize
Definition strided_layout_constexpr.h:41
static FORCE_INLINE constexpr my_size_t logical_flat_to_physical_flat(my_size_t logical_flat) TESSERACT_CONDITIONAL_NOEXCEPT
Definition strided_layout_constexpr.h:290
static FORCE_INLINE constexpr my_size_t stride(my_size_t i) TESSERACT_CONDITIONAL_NOEXCEPT
Get physical stride at dimension i (with permutation applied).
Definition strided_layout_constexpr.h:263
static FORCE_INLINE constexpr my_size_t base_stride(my_size_t i) TESSERACT_CONDITIONAL_NOEXCEPT
Get base stride at dimension i (unpermuted, for physical decomposition).
Definition strided_layout_constexpr.h:251
static FORCE_INLINE constexpr my_size_t logical_coords_to_physical_flat(const my_size_t(&indices)[NumDims]) TESSERACT_CONDITIONAL_NOEXCEPT
Logical coordinates (Array multi-index) to physical flat index (bounds-checked).
Definition strided_layout_constexpr.h:328
static FORCE_INLINE constexpr my_size_t logical_dim(my_size_t i) TESSERACT_CONDITIONAL_NOEXCEPT
Get logical dimension at index i (with permutation applied).
Definition strided_layout_constexpr.h:239
static void gemm(const T *A, my_size_t M, my_size_t K_len, my_size_t strideA, const T *B, my_size_t N, my_size_t strideB, T *C, my_size_t strideC) noexcept
Register-blocked GEMM: C[M,N] = A[M,K] × B[K,N].
Definition kernel_gemm.h:147
Definition basic_expr_traits.h:6