Source code for hottbox.utils.validation.checks
import numpy as np
from hottbox.core.structures import Tensor
import itertools
def is_toeplitz_matrix(mat):
    """ Checks if ``mat`` has a Toeplitz structure

    A matrix is Toeplitz when every descending (top-left to bottom-right)
    diagonal is constant, including the main diagonal.

    Parameters
    ----------
    mat : np.ndarray
        Input two-dimensional array to check

    Returns
    -------
    bool
        ``True`` if every diagonal of ``mat`` is constant, ``False`` otherwise

    Raises
    ------
    ValueError
        If ``mat`` is not two-dimensional
    """
    mat = np.asarray(mat)
    if mat.ndim != 2:
        raise ValueError(
            "Expected a 2-dimensional array, got {} dimension(s)".format(mat.ndim)
        )
    # Every entry must equal its lower-right neighbour. Comparing the matrix
    # without its last row/column against the matrix without its first
    # row/column checks all diagonals at once, the main diagonal included.
    return bool(np.array_equal(mat[:-1, :-1], mat[1:, 1:]))
# Currently recursive, TODO: improve efficiency
def is_toeplitz_tensor(tensor, modes=None):
    """ Checks if ``tensor`` has Toeplitz structure

    The check is recursive: the tensor is sliced along every mode that is not
    listed in ``modes`` until only matrices remain, and each of them is tested
    with ``is_toeplitz_matrix``.

    Parameters
    ----------
    tensor : Tensor
        Input tensor to check
    modes : list[int], optional
        The pair of modes across which the Toeplitz structure is expected.
        Defaults to ``[0, 1]``. Ignored when ``tensor.order <= 2``.

    Returns
    -------
    bool
        ``True`` if every examined slice is a Toeplitz matrix, ``False``
        otherwise. On failure the offending slice is printed.
    """
    if tensor.order <= 2:
        return is_toeplitz_matrix(tensor.data)
    if modes is None:
        modes = [0, 1]
    sz = np.asarray(tensor.shape)
    availmodes = np.setdiff1d(np.arange(len(sz)), modes)
    for idx, mode in enumerate(availmodes):
        # Slicing along ``mode`` is assumed to drop that axis (the recursion
        # relies on the order shrinking), so axes after it shift down by one.
        # Re-index ``modes`` so the sub-tensor is checked on the same axes.
        sub_modes = [m - 1 if m > mode else m for m in modes]
        # Go through each index along this mode
        for i in range(sz[mode]):
            t = Tensor(tensor.access(i, mode))
            if not is_toeplitz_tensor(t, modes=sub_modes):
                print("Wrong slice: \n{}\n{}".format(t, (i, idx)))
                return False
    return True
def is_super_symmetric(tensor):
    """ Checks if ``tensor`` has super-symmetric structure

    A tensor is super-symmetric when its entries are unchanged under every
    permutation of its modes. Entries are compared with ``np.allclose``
    (``atol=1e-4``, NaNs compared as equal).

    Parameters
    ----------
    tensor : Tensor
        Input tensor to check; only its ``data`` array is read

    Returns
    -------
    bool
        ``True`` if the tensor is super-symmetric, ``False`` otherwise. When a
        mismatching permutation is found, both arrays are printed.
    """
    data = tensor.data
    # Modes of different sizes can never be interchanged. Without this guard
    # the comparison below either fails to broadcast or, for size-1 modes,
    # broadcasts silently and may report a false positive.
    if len(set(data.shape)) > 1:
        return False
    for perm in itertools.permutations(range(len(data.shape))):
        permuted = np.transpose(data, perm)
        if not np.allclose(data, permuted, atol=1e-4, equal_nan=True):
            print("{} \n is not the same as \n {}".format(data, permuted))
            return False
    return True