class AssoOpt(Asso):
    '''The Asso algorithm with exhaustive search over each row of U.

    Starting from an existing model's ``U`` and ``V``, every row of ``U`` is
    replaced by the candidate binary row (out of all ``2**k`` possibilities)
    that maximizes ``coverage_score`` against the matching row of
    ``X_train``.

    This implementation may be slow, but is able to deal with larger number
    of factors `k` or dimension of `X_train`. Note that the per-row search
    enumerates ``2**k`` candidates, so its cost grows exponentially in ``k``.

    Parameters
    ----------
    model : object
        A fitted model whose ``k``, ``U``, ``V`` and ``logs`` attributes are
        imported as the starting point.
    w_fp, w_fn : float
        Weights forwarded to ``coverage_score`` as ``w_fp`` / ``w_fn``
        (presumably the penalties for false positives / false negatives --
        confirm against ``coverage_score``).

    .. topic:: Reference

        The discrete basis problem. Zhang et al. 2007.
    '''

    def __init__(self, model, w_fp=1, w_fn=1):
        self.check_params(model=model, w_fp=w_fp, w_fn=w_fn)

    def check_params(self, **kwargs):
        '''Validate parameters via the parent class and import ``model``.

        If a ``model`` keyword is given, its ``k``, ``U``, ``V`` and ``logs``
        are copied into this instance through ``import_model``.
        '''
        super().check_params(**kwargs)
        # model to import
        if 'model' in kwargs:
            model = kwargs.get('model')
            self.import_model(k=model.k, U=model.U, V=model.V, logs=model.logs)

    def _fit(self):
        '''Using exhaustive search to refine U.

        Each row index is passed to ``set_optimal_row`` (distributed with
        ``p_map``). The returned candidate indices are decoded back into rows
        of ``U``, the boolean reconstruction ``X_pd`` is recomputed, and its
        coverage score is recorded under the ``'refinements'`` evaluation.
        '''
        tic = time.perf_counter()
        # Rows of U are optimized independently of each other given V.
        results = p_map(self.set_optimal_row, range(self.m))
        toc = time.perf_counter()
        print("[I] Exhaustive search finished in {}s.".format(toc - tic))

        for i, idx in enumerate(results):
            self.U[i] = int2bin(idx, self.k)

        self.X_pd = get_prediction(U=self.U, V=self.V, boolean=True)
        # Score with the same false-positive / false-negative weights used by
        # set_optimal_row, so the logged score matches the search objective.
        # (Previously this passed ``w=self.w``, an attribute this class never
        # configures.)
        score = coverage_score(gt=self.X_train, pd=self.X_pd,
                               w_fn=self.w_fn, w_fp=self.w_fp)
        self.evaluate(df_name='refinements', train_info={'score': score})

    def set_optimal_row(self, i):
        '''Find the best replacement for the i-th row of U.

        Every integer ``j`` in ``range(2**k)`` is decoded with ``int2bin``
        into a candidate row, multiplied (boolean) with ``V.T``, and scored
        against row ``i`` of ``X_train`` with ``coverage_score``.

        Despite its name, this method does not modify ``U``; ``_fit`` applies
        the returned value.

        Returns
        -------
        int-like
            Integer encoding of the highest-scoring candidate row (the first
            one on ties, as ``np.argmax`` returns).
        '''
        trials = 2 ** self.k
        scores = np.zeros(trials)
        X_gt = self.X_train[i, :]
        V_T = self.V.T  # identical for every candidate, so compute it once
        for j in range(trials):
            u = int2bin(j, self.k)
            X_pd = matmul(u, V_T, sparse=True, boolean=True)
            scores[j] = coverage_score(gt=X_gt, pd=X_pd,
                                       w_fn=self.w_fn, w_fp=self.w_fp)
        return np.argmax(scores)
def int2bin(i, bits):
    '''Turn `i` into a (1, `bits`) binary sparse matrix.

    Column 0 holds the most significant bit, e.g. ``int2bin(5, 4)`` is the
    row ``[0, 1, 0, 1]``. Only the 1-bits are stored in the sparse matrix.

    Parameters
    ----------
    i : int
        Non-negative integer to encode; must satisfy ``0 <= i < 2**bits``.
    bits : int
        Width of the output row.

    Returns
    -------
    scipy.sparse.csr_matrix
        Integer matrix of shape ``(1, bits)``.

    Raises
    ------
    ValueError
        If `i` is negative or does not fit in `bits` bits.
    '''
    if not 0 <= i < 2 ** bits:
        raise ValueError(
            "int2bin: {} cannot be represented with {} bits.".format(i, bits))
    # Extract the bits numerically. Building the matrix from the characters
    # '0'/'1' would store every entry explicitly, because any non-empty
    # string counts as non-zero during dense -> sparse conversion.
    row = [(i >> shift) & 1 for shift in range(bits - 1, -1, -1)]
    return csr_matrix([row], dtype=int)