TT Cross
Implementation of the (DMRG) TT-cross algorithm, similar to the one in the MATLAB TT-Toolbox.
The (DMRG) TT-cross algorithm is an efficient method for approximating black-box functions with tensor trains. In this library, its main use is to initialize the tensor train used in the TTML estimator. However, the algorithm can be used for many other purposes as well.
For example, below we use TT-DMRG cross to approximate the sine function in 5 dimensions:
>>> from ttml.tt_cross import estimator_to_tt_cross
... import numpy as np
...
... def f(x):
...     return np.sin(np.sum(x, axis=1))
...
... # Same thresholds for every feature: [0, 0.1, 0.2, ..., 1.0]
... thresholds = [np.linspace(0, 1, 11)] * 5
...
... tt = estimator_to_tt_cross(f, thresholds)
... tt.gather(np.array([[2, 2, 2, 2, 2]])) - np.sin(1)
array([3.33066907e-16])
Note here that the index (2, 2, 2, 2, 2) corresponds to the point
(0.2, 0.2, 0.2, 0.2, 0.2). It turns out that this function admits a low-rank
tensor train decomposition. In fact, we can look at the singular values of this
tensor train:
>>> tt.sing_vals()
[array([2.62706082e+02, 1.10057236e+02, 4.55274516e-14, 3.68166789e-14, 3.26467439e-14]),
array([2.50489642e+02, 1.35580309e+02, 4.13492807e-14, 1.73287140e-14, 6.94636648e-15]),
array([2.50489642e+02, 1.35580309e+02, 1.39498787e-14, 1.02576584e-14, 6.98685871e-15]),
array([2.62706082e+02, 1.10057236e+02, 2.02285923e-14, 1.07720967e-14, 9.56667996e-15])]
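We see that the tensor train effectively has rank 2. This is expected: since
sin(a + b) = sin(a)cos(b) + cos(a)sin(b), splitting the five variables into any
two groups writes f as a sum of only two products of functions of each group.
We can control the rank by setting the max_rank keyword argument in
estimator_to_tt_cross():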
>>> tt2 = estimator_to_tt_cross(f, thresholds, max_rank=2)
... tt2
<TensorTrain of order 5 with outer dimensions (11, 11, 11, 11, 11),
TT-rank (2, 2, 2, 2), and orthogonalized at mode 4>
And indeed, tt2 is very close to tt:
>>> (tt - tt2).norm()
4.731228127074942e-13
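As a quick check that does not depend on ttml, the following NumPy sketch builds the full 11^5 tensor of f on the same grid and counts the numerically nonzero singular values of each matrix unfolding (first k modes versus the remaining ones). Each count comes out as 2, consistent with the TT-rank (2, 2, 2, 2) above. The helper `unfolding_ranks` and its relative tolerance are our own choices for illustration.

```python
import numpy as np


def unfolding_ranks(tensor, rel_tol=1e-10):
    """Numerical rank of each unfolding (modes 1..k versus k+1..d), k = 1..d-1."""
    ranks = []
    for k in range(1, tensor.ndim):
        rows = int(np.prod(tensor.shape[:k]))
        # Group the first k modes into rows and the remaining modes into columns
        mat = tensor.reshape(rows, -1)
        s = np.linalg.svd(mat, compute_uv=False)
        # Count singular values that are not negligible relative to the largest
        ranks.append(int(np.sum(s > rel_tol * s[0])))
    return ranks


# Grid values for each of the 5 features: 0.0, 0.1, ..., 1.0
grid = np.linspace(0, 1, 11)
# x_1 + ... + x_5 evaluated on the full 11^5 grid
total = sum(np.meshgrid(*([grid] * 5), indexing="ij"))
tensor = np.sin(total)

print(unfolding_ranks(tensor))  # [2, 2, 2, 2]
```

Forming the full tensor is only feasible here because the example is tiny; avoiding exactly this is what TT-cross is for.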
We can do this for any function, and indeed for TTML we use a machine
learning estimator instead of the function f above. For example we can use
this to obtain a tensor train approximating a random forest’s predict()
method.
>>> from sklearn.ensemble import RandomForestRegressor
...
... X = np.random.normal(size=(1000, 5))
... y = np.exp(np.sum(X, axis=1))
...
... forest = RandomForestRegressor()
... forest.fit(X, y)
...
... thresholds = [np.linspace(0, 1, 11)] * 5
...
... tt = estimator_to_tt_cross(forest.predict, thresholds, max_rank=2)
... tt
<TensorTrain of order 5 with outer dimensions (11, 11, 11, 11, 11),
TT-rank (2, 2, 2, 2), and orthogonalized at mode 4>
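Conceptually, TT-cross only sees the estimator through a function of integer multi-indices: each index selects one threshold per feature, as with the index (2, 2, 2, 2, 2) and the point (0.2, 0.2, 0.2, 0.2, 0.2) in the first example. A minimal sketch of such a mapping follows, assuming that index i of feature k maps to thresholds[k][i]. The helper `predict_on_grid` is hypothetical and not part of ttml, whose internal mapping may differ.

```python
import numpy as np


def predict_on_grid(predict, thresholds, indices):
    """Evaluate `predict` at the grid points selected by integer multi-indices.

    Assumption (illustrative only): index i of feature k maps to thresholds[k][i].
    `indices` has shape (n, d); the result has shape (n,).
    """
    indices = np.asarray(indices)
    # Look up the grid value of each feature separately, then stack into points
    points = np.column_stack(
        [np.asarray(thresholds[k])[indices[:, k]] for k in range(indices.shape[1])]
    )
    return predict(points)


def f(x):
    return np.sin(np.sum(x, axis=1))


thresholds = [np.linspace(0, 1, 11)] * 5
print(predict_on_grid(f, thresholds, [[2, 2, 2, 2, 2]]))  # close to np.sin(1)
```

Any callable with the same calling convention as f, such as forest.predict, can be plugged in for `predict`.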
We implemented two versions of the TT-cross approximation algorithm: a ‘regular’ and a ‘dmrg’ version. The ‘regular’ version optimizes the TT one core at a time in alternating left-to-right and right-to-left sweeps, whereas the ‘dmrg’ version optimizes two neighboring cores at the same time. The latter approach is numerically more costly, but it can in principle estimate the rank of the underlying TT automatically (although this is not an implemented feature). The DMRG algorithm also converges faster and tends to result in a lower final test error. Therefore DMRG is the default, despite being slower. We can control which version is used through the method keyword argument:
>>> tt = estimator_to_tt_cross(forest.predict, thresholds, method='dmrg')
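To illustrate why optimizing two cores at once can reveal the rank, here is a minimal NumPy sketch of the step a DMRG-style sweep typically performs once it has a two-core "supercore": reshape it into a matrix, take a truncated SVD, and split it back into two cores. The function name `split_supercore`, the core layout (left rank, mode size, mode size, right rank) and the truncation rule are assumptions for illustration, not ttml's internals.

```python
import numpy as np


def split_supercore(supercore, max_rank, rel_tol=1e-12):
    """Split a supercore of shape (r0, n1, n2, r2) into two TT cores.

    The new middle rank is the number of singular values above rel_tol * s[0],
    capped at max_rank.
    """
    r0, n1, n2, r2 = supercore.shape
    mat = supercore.reshape(r0 * n1, n2 * r2)
    u, s, vt = np.linalg.svd(mat, full_matrices=False)
    rank = min(max_rank, int(np.sum(s > rel_tol * s[0])))
    # Left core is left-orthogonal; singular values are absorbed into the right core
    left = u[:, :rank].reshape(r0, n1, rank)
    right = (s[:rank, None] * vt[:rank]).reshape(rank, n2, r2)
    return left, right


rng = np.random.default_rng(0)
# Supercore built from two random cores whose shared (middle) rank is 3
core1 = rng.standard_normal((2, 11, 3))
core2 = rng.standard_normal((3, 11, 2))
supercore = np.einsum("anr,rmb->anmb", core1, core2)

left, right = split_supercore(supercore, max_rank=5)
print(left.shape, right.shape)  # (2, 11, 3) (3, 11, 2)
```

Even though max_rank=5 is allowed, the SVD detects that only 3 singular values are significant; a one-core update has no such step, because the rank between cores is fixed beforehand.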
- ttml.tt_cross.estimator_to_tt_cross(predict_method, thresholds, max_rank=5, tol=0.01, max_its=5, method='regular', use_cache=False, verbose=False)
Use TT-cross to convert an estimator into a TT
- Parameters
predict_method (function) – Function mapping data X of shape (n_samples, n_features) to predicted values y, for example the predict method of a trained estimator.
thresholds (list[np.ndarray]) – List of thresholds to use for each feature. Should be a list of arrays, one array per feature. The last element of each array is expected to be np.inf.
max_rank (int (default: 5)) – Maximum rank for the tensor train
tol (float (default: 0.01)) – Tolerance for checking convergence of the DMRG algorithm. If the maximum local error over an entire sweep is smaller than tol, the algorithm stops early.
max_its (int (default: 5)) – Maximum number of (half-)sweeps to perform.
method (str (default: "dmrg")) – Whether to use the “regular” or the “dmrg” TT-cross algorithm.
use_cache (bool (default: False)) – Whether to cache function calls (Experimental, needs better implementation)
verbose (bool (default: False)) – If True, print convergence information after every half sweep.
- ttml.tt_cross.index_function_wrapper(fun)
Modify a multi-index function to accept multi-dimensional arrays of multi-indices.
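A minimal sketch of what such a wrapper can look like, assuming the wrapped fun takes a 2D array of shape (n, d) and returns n values. The name `wrap_index_function` is ours, and ttml's own index_function_wrapper may be implemented differently.

```python
import numpy as np


def wrap_index_function(fun):
    """Let `fun`, defined on (n, d) arrays of multi-indices, accept any (..., d) array."""

    def wrapped(indices):
        indices = np.asarray(indices)
        batch_shape, d = indices.shape[:-1], indices.shape[-1]
        # Flatten all leading axes into one batch axis, evaluate, then restore them
        values = np.asarray(fun(indices.reshape(-1, d)))
        return values.reshape(batch_shape)

    return wrapped


g = wrap_index_function(lambda idx: idx.sum(axis=1))
print(g(np.zeros((4, 3, 5), dtype=int)).shape)  # (4, 3)
```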
- ttml.tt_cross.index_function_wrapper_with_cache(fun)
Same as index_function_wrapper, but additionally caches the results of calls to fun.
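One way to add caching is sketched below, assuming that the same multi-indices are requested repeatedly across sweeps: store each evaluated multi-index in a dictionary and only call fun, in one batch, on indices not seen before. The name `wrap_index_function_with_cache` is hypothetical; this is not necessarily how ttml caches.

```python
import numpy as np


def wrap_index_function_with_cache(fun):
    """Accept any (..., d) index array, but only evaluate multi-indices not seen before."""
    cache = {}

    def wrapped(indices):
        indices = np.asarray(indices)
        batch_shape, d = indices.shape[:-1], indices.shape[-1]
        flat = indices.reshape(-1, d)
        keys = [tuple(row) for row in flat.tolist()]
        # Unique new multi-indices, in order of first appearance
        new_keys = list(dict.fromkeys(k for k in keys if k not in cache))
        if new_keys:
            new_values = fun(np.array(new_keys, dtype=flat.dtype))
            cache.update(zip(new_keys, np.asarray(new_values).tolist()))
        return np.array([cache[k] for k in keys]).reshape(batch_shape)

    return wrapped


calls = []


def f(idx):
    calls.append(len(idx))  # record how many multi-indices each call evaluates
    return idx.sum(axis=1)


g = wrap_index_function_with_cache(f)
first = g(np.array([[0, 1], [2, 3], [0, 1]]))
second = g(np.array([[2, 3], [4, 5]]))
print(calls)  # [2, 1]
```

A dictionary keyed on tuples is simple but adds Python-level overhead per index, which may be part of why the caching option is marked as experimental.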
- ttml.tt_cross.maxvol(A, eps=0.01, niters=100)
Find a quasi-maximum-volume submatrix of A.
Initializes with the pivot indices of an LU decomposition, then greedily interchanges rows.
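The procedure described above can be sketched in NumPy as follows. This is an illustrative reimplementation, not ttml's code, and it assumes A is tall with full column rank. The initial rows are the row pivots of Gaussian elimination with partial pivoting (i.e. of an LU decomposition); each greedy step then swaps in the row whose coefficient in A A[idx]^{-1} has the largest magnitude, which multiplies |det A[idx]| by that magnitude.

```python
import numpy as np


def lu_pivot_rows(A):
    """Row indices chosen by Gaussian elimination with partial pivoting."""
    B = np.array(A, dtype=float)
    n, r = B.shape
    perm = np.arange(n)
    for k in range(r):
        # Bring the largest remaining entry of column k to row k
        p = k + int(np.argmax(np.abs(B[k:, k])))
        B[[k, p]] = B[[p, k]]
        perm[[k, p]] = perm[[p, k]]
        # Eliminate column k below the pivot
        B[k + 1:, k:] -= np.outer(B[k + 1:, k] / B[k, k], B[k, k:])
    return perm[:r].copy()


def maxvol_sketch(A, eps=0.01, niters=100):
    """Greedy quasi-maximum-volume row selection for a tall n x r matrix A."""
    idx = lu_pivot_rows(A)
    for _ in range(niters):
        # Coefficients expressing every row of A in terms of the selected rows
        coeffs = np.linalg.solve(A[idx].T, A.T).T
        i, j = np.unravel_index(np.argmax(np.abs(coeffs)), coeffs.shape)
        if abs(coeffs[i, j]) <= 1 + eps:
            break  # no single swap increases the volume by more than 1 + eps
        # Swapping row idx[j] for row i multiplies |det A[idx]| by |coeffs[i, j]|
        idx[j] = i
    return idx


rng = np.random.default_rng(0)
A = rng.standard_normal((100, 5))
idx = maxvol_sketch(A)
coeffs = A @ np.linalg.inv(A[idx])  # entries bounded by about 1 + eps after convergence
```

The tolerance eps trades optimality for speed: only swaps that grow the volume by more than a factor 1 + eps are accepted, so the loop ends at a quasi-optimal, not necessarily maximal, submatrix.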
- ttml.tt_cross.tt_cross_dmrg(tt, index_fun, tol=0.001, max_its=10, verbose=False, inplace=True)
Implements the DMRG TT-cross algorithm.
Recovers a tensor train from a function mapping indices to numbers. The function index_fun should accept arbitrary multidimensional arrays of indices whose last axis has length equal to the number of dimensions. You can use index_function_wrapper to convert a function to this form.
- Parameters
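tt (TensorTrain) – Initial tensor train approximation.
index_fun (function) – Function mapping arrays of multi-indices to values, as described above.
tol (float (default: 1e-3)) – Tolerance for convergence. The algorithm stops if, after a half-sweep, the maximum difference between any cross-sampled supercore and the corresponding supercore of the TT is less than tol.
max_its (int (default: 10)) – Maximum number of (half-)sweeps to perform.
verbose (bool (default: False)) – If True, print convergence information after every half-sweep.
inplace (bool (default: True)) – If True, modify tt in place; otherwise operate on a copy.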
- Returns
tt
- Return type
TensorTrain
- ttml.tt_cross.tt_cross_regular(tt, index_fun, tol=0.01, max_its=10, verbose=False, inplace=True)
Implements the regular TT-cross algorithm, which optimizes one core at a time.
Recovers a tensor train from a function mapping indices to numbers. The function index_fun should accept arbitrary multidimensional arrays of indices whose last axis has length equal to the number of dimensions. You can use index_function_wrapper to convert a function to this form.
- Parameters
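tt (TensorTrain) – Initial tensor train approximation.
index_fun (function) – Function mapping arrays of multi-indices to values, as described above.
tol (float (default: 0.01)) – Tolerance for convergence. The algorithm stops if, after a half-sweep, the maximum difference between any cross-sampled supercore and the corresponding supercore of the TT is less than tol.
max_its (int (default: 10)) – Maximum number of (half-)sweeps to perform.
verbose (bool (default: False)) – If True, print convergence information after every half-sweep.
inplace (bool (default: True)) – If True, modify tt in place; otherwise operate on a copy.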
- Returns
tt
- Return type
TensorTrain
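For intuition about the one-core-at-a-time ‘regular’ variant, consider d = 2, where TT-cross resembles the classical alternating cross approximation of a matrix: fix some columns J, pick well-conditioned rows I of A[:, J] with maxvol, then fix those rows and pick columns, and repeat. The result is the skeleton A ≈ A[:, J] A[I, J]^{-1} A[I, :], built from sampled entries only. This is a simplified sketch, not ttml's tt_cross_regular: `matrix_cross` and the compact `maxvol` below are hypothetical helpers, and the number of sweeps is fixed rather than controlled by a tolerance.

```python
import numpy as np


def maxvol(A, eps=0.01, niters=100):
    """Quasi-max-volume rows of a tall full-rank matrix (LU pivots + greedy swaps)."""
    B = np.array(A, dtype=float)
    n, r = B.shape
    perm = np.arange(n)
    for k in range(r):  # partial-pivoting elimination gives the initial rows
        p = k + int(np.argmax(np.abs(B[k:, k])))
        B[[k, p]] = B[[p, k]]
        perm[[k, p]] = perm[[p, k]]
        B[k + 1:, k:] -= np.outer(B[k + 1:, k] / B[k, k], B[k, k:])
    idx = perm[:r].copy()
    for _ in range(niters):  # greedy row swaps that increase the volume
        coeffs = np.linalg.solve(A[idx].T, A.T).T
        i, j = np.unravel_index(np.argmax(np.abs(coeffs)), coeffs.shape)
        if abs(coeffs[i, j]) <= 1 + eps:
            break
        idx[j] = i
    return idx


def matrix_cross(entries, shape, rank, sweeps=4, seed=0):
    """Alternating cross approximation from sampled entries.

    entries(rows, cols) must return the submatrix A[np.ix_(rows, cols)].
    Returns (C, core, R) with A ~ C @ inv(core) @ R.
    """
    m, n = shape
    rng = np.random.default_rng(seed)
    J = rng.choice(n, size=rank, replace=False)
    for _ in range(sweeps):
        # Fix the columns J; choose rows I (QR only improves conditioning)
        I = maxvol(np.linalg.qr(entries(np.arange(m), J))[0])
        # Fix the rows I; choose columns J
        J = maxvol(np.linalg.qr(entries(I, np.arange(n)).T)[0])
    return entries(np.arange(m), J), entries(I, J), entries(I, np.arange(n))


x = np.linspace(0, 1, 50)
y = np.linspace(0, 2, 60)


def entries(rows, cols):
    return np.sin(x[rows][:, None] + y[cols][None, :])


C, core, R = matrix_cross(entries, (50, 60), rank=2)
approx = C @ np.linalg.solve(core, R)
full = np.sin(x[:, None] + y[None, :])
print(np.abs(approx - full).max())  # expected to be tiny: the matrix has exact rank 2
```

Only rows and columns through the selected indices are ever evaluated; the full matrix is formed here solely to measure the error.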