from tqdm import tqdm
from .ContinuousModel import ContinuousModel
from ..utils import to_sparse, multiply, power, ignore_warnings, subtract, check_sparse, add, get_prediction
import numpy as np
from scipy.sparse import lil_matrix
[docs]
class WNMF(ContinuousModel):
    '''Weighted Nonnegative Matrix Factorization.

    Factorizes the training matrix as ``X ~ U @ V.T`` with multiplicative
    updates, where every entry of the reconstruction error is weighted by the
    corresponding entry of ``W``.

    .. topic:: Reference

        Weighted Nonnegative Matrix Factorization and Face Feature Extraction.

        For the scikit-learn implementation:

        Projected Gradient Methods for Non-negative Matrix Factorization.

        https://github.com/scikit-learn/scikit-learn/blob/a95203b249c1cf392f86d001ad999e29b2392739/sklearn/decomposition/nmf.py#L158

    Parameters
    ----------
    k : int
        Forwarded to ``check_params``; presumably the factorization rank
        (number of columns of ``U`` and ``V``) -- TODO confirm against
        ``ContinuousModel``.
    U : ndarray, spmatrix
        Need to be prepared if ``init_method`` is 'custom'.
    V : ndarray, spmatrix
        Need to be prepared if ``init_method`` is 'custom'.
    W : str, ndarray, spmatrix
        Element-wise weights used by ``update`` and ``error``. The default
        ``'mask'`` is presumably resolved to a matrix by the parent class
        before fitting -- TODO confirm.
    beta_loss : str
        Either 'frobenius' or 'kullback-leibler'.
    init_method : str
        One of 'uniform', 'normal' or 'custom'.
    solver : str
        Only 'mu' (multiplicative update) is accepted.
    tol, min_diff, max_iter :
        Stopping parameters. They are forwarded to ``check_params``; the stop
        decision itself is made by ``early_stop``, which is defined outside
        this class. ``max_iter`` is also used as the progress bar total.
    seed :
        Forwarded to ``check_params``; presumably seeds the initialization.
    '''
    def __init__(self, k, U=None, V=None, W='mask', beta_loss='frobenius', init_method='normal', solver='mu', tol=0.0, min_diff=0.0, max_iter=30, seed=None):
        self.check_params(k=k, U=U, V=V, W=W, beta_loss=beta_loss, init_method=init_method, solver=solver, tol=tol, min_diff=min_diff, max_iter=max_iter, seed=seed)

    def check_params(self, **kwargs):
        '''Let the parent class process ``kwargs``, then validate the options.

        Raises
        ------
        AssertionError
            If ``beta_loss``, ``solver`` or ``init_method`` is not supported.
        '''
        super().check_params(**kwargs)

        assert self.beta_loss in ('frobenius', 'kullback-leibler'), \
            f"unsupported beta_loss: {self.beta_loss!r}"
        assert self.solver in ('mu',), \
            f"unsupported solver: {self.solver!r}"
        assert self.init_method in ('uniform', 'normal', 'custom'), \
            f"unsupported init_method: {self.init_method!r}"

    def fit(self, X_train, X_val=None, X_test=None, **kwargs):
        '''Fit the model.

        The data are handed to the parent ``fit``, the factors are optimized
        by ``_fit``, the final prediction is stored in ``self.X_pd`` and
        ``finish`` is called.
        '''
        super().fit(X_train, X_val, X_test, **kwargs)

        self._fit()

        self.X_pd = get_prediction(U=self.U, V=self.V, boolean=False)
        self.finish(show_logs=self.show_logs, save_model=self.save_model, show_result=self.show_result)

    def _fit(self):
        '''The alternative minimization algorithm.

        Repeats ``update`` until ``early_stop`` reports that the model is no
        longer improving, evaluating RMSE and MAE after every iteration.
        '''
        n_iter = 0
        is_improving = True

        # compute error of the initial factors
        error_old = self.error()

        # evaluate the initial factors
        self.X_pd = get_prediction(U=self.U, V=self.V, boolean=False)
        self.evaluate(df_name='updates', head_info={'iter': n_iter, 'error': error_old}, metrics=['RMSE', 'MAE'])

        # the context manager closes the bar even if an iteration raises
        with tqdm(total=self.max_iter, desc="[I] error: -") as pbar:
            while is_improving:
                # update n_iter, U, V
                n_iter += 1
                self.update()

                # compute error, diff
                error_new = self.error()
                diff = abs(error_old - error_new)
                error_old = error_new

                # evaluate
                self.X_pd = get_prediction(U=self.U, V=self.V, boolean=False)
                self.evaluate(df_name='updates', head_info={'iter': n_iter, 'error': error_new}, metrics=['RMSE', 'MAE'])

                # display
                if self.verbose and self.display and n_iter % 10 == 0:
                    self.show_matrix(boolean=False, title=f"iter {n_iter}")

                # update pbar
                pbar.update(1)
                pbar.set_description(f"[I] error: {error_new:.6e}")

                # early stop detection
                is_improving = self.early_stop(error=error_old, diff=diff, n_iter=n_iter)

    @ignore_warnings
    def update(self):
        '''Multiplicative update.

        Updates ``self.V`` first and then ``self.U``, each by an element-wise
        multiplication with ``num / denom``. Zero denominators are replaced by
        machine epsilon so the ratio stays finite.
        '''
        eps = np.finfo(np.float64).eps

        if self.beta_loss == 'frobenius':
            # weighted data, shared by both factor updates
            WX = multiply(self.W, self.X_train)

            # update V
            num = WX.T @ self.U
            denom = multiply(self.W, self.U @ self.V.T).T @ self.U
            denom[denom == 0] = eps
            self.V = multiply(self.V, num / denom)

            # update U
            num = WX @ self.V
            denom = multiply(self.W, self.U @ self.V.T) @ self.V
            denom[denom == 0] = eps
            self.U = multiply(self.U, num / denom)

        elif self.beta_loss == 'kullback-leibler':
            WX = multiply(self.W, self.X_train)

            # update V
            UV = self.U @ self.V.T
            # floor the reconstruction so that 0 / 0 cannot inject NaN
            UV[UV == 0] = eps
            num = (WX / UV).T @ self.U
            # the gradient of the weighted KL objective in ``error`` has the
            # term W.T @ U here; an all-ones matrix would ignore the weights
            denom = self.W.T @ self.U
            denom[denom == 0] = eps
            self.V = multiply(self.V, num / denom)

            # update U
            UV = self.U @ self.V.T
            UV[UV == 0] = eps
            num = (WX / UV) @ self.V
            denom = self.W @ self.V
            denom[denom == 0] = eps
            self.U = multiply(self.U, num / denom)

    @ignore_warnings
    def error(self):
        '''The error function.

        Returns
        -------
        error :
            ``0.5 * sum(W * (X - U @ V.T) ** 2)`` for 'frobenius', or the
            weighted generalized KL divergence
            ``sum(W * (X * log(X / X_pd) - X + X_pd))`` for
            'kullback-leibler'.

        Raises
        ------
        ValueError
            If ``beta_loss`` is neither of the two supported losses.
        '''
        X_pd = self.U @ self.V.T

        if self.beta_loss == 'frobenius':
            return 0.5 * np.sum(multiply(self.W, power(self.X_train - X_pd, 2)))

        if self.beta_loss == 'kullback-leibler':
            eps = np.finfo(np.float64).eps
            # work on a copy: writing eps into self.X_train itself would
            # permanently alter (and densify) the training data
            X_gt = self.X_train.copy()
            # keep the ratio inside the logarithm finite and positive
            X_gt[X_gt == 0] = eps
            X_pd[X_pd == 0] = eps
            return np.sum(multiply(self.W, multiply(X_gt, np.log(X_gt / X_pd)) - X_gt + X_pd))

        raise ValueError(f"unsupported beta_loss: {self.beta_loss!r}")