# Source code for the BaseReview base class (the module name is missing from the extracted page header).

import logging
from abc import ABC

import numpy as np
import pandas as pd
from tqdm import tqdm

from asreview.config import DEFAULT_N_INSTANCES
from asreview.config import LABEL_NA
from asreview.models.balance.simple import SimpleBalance
from asreview.models.classifiers import NaiveBayesClassifier
from asreview.models.feature_extraction.tfidf import Tfidf
from asreview.models.query.max import MaxQuery
from asreview.project import open_state
from asreview.settings import ASReviewSettings

class BaseReview(ABC):
    """Base class for Systematic Review.

    Arguments
    ---------
    as_data: asreview.ASReviewData
        The data object which contains the text, labels, etc.
    project:
        Project object handed to ``open_state``. It must also provide
        ``get_feature_matrix`` and ``add_feature_matrix``.
    model: BaseTrainClassifier
        Initialized model to fit the data during active learning.
        See ``asreview.models.classifiers`` for possible models.
        Defaults to a new ``NaiveBayesClassifier``.
    query_model: BaseQueryStrategy
        Initialized model to query new instances for review, such as random
        sampling or max sampling.
        See ``asreview.models.query`` for query models.
        Defaults to a new ``MaxQuery``.
    balance_model: BaseBalance
        Initialized model to redistribute the training data during the
        active learning process. They might either resample or undersample
        specific papers. Defaults to a new ``SimpleBalance``.
    feature_model: BaseFeatureExtraction
        Feature extraction model that converts texts and keywords to
        feature matrices. Defaults to a new ``Tfidf``.
    n_papers: int
        Deprecated. A warning is logged and the value is ignored.
    n_instances: int
        Number of papers to query at each step in the active learning
        process.
    stop_if: int
        Number of steps/queries to perform. Set to None for no limit.
        The string ``'min'`` stops as soon as all records left in the pool
        are labeled 0.
    start_idx: numpy.ndarray
        Start the simulation/review with these indices. They are assumed to
        be already labeled. Failing to do so might result in bad behaviour.
    """

    def __init__(
        self,
        as_data,
        project,
        model=None,
        query_model=None,
        balance_model=None,
        feature_model=None,
        n_papers=None,
        n_instances=DEFAULT_N_INSTANCES,
        stop_if=None,
        start_idx=None,
    ):
        """Initialize the reviewer base class, so that everything is ready to
        train a new model."""
        super().__init__()

        # Set the model. The defaults are created per instance instead of in
        # the signature, so that two reviews never share one model object.
        self.classifier = NaiveBayesClassifier() if model is None else model
        self.balance_model = (
            SimpleBalance() if balance_model is None else balance_model
        )
        self.query_strategy = MaxQuery() if query_model is None else query_model
        self.feature_extraction = (
            Tfidf() if feature_model is None else feature_model
        )

        # Set the settings.
        self.as_data = as_data
        self.project = project
        self.n_instances = n_instances
        self.stop_if = stop_if
        # `is None` (not truthiness) because start_idx may be a numpy array.
        self.prior_indices = [] if start_idx is None else start_idx

        if n_papers is not None:
            logging.warning("Argument n_papers is deprecated, ignoring n_papers.")

        # Get the known labels.
        self.data_labels = as_data.labels
        if self.data_labels is None:
            self.data_labels = np.full(len(as_data), LABEL_NA)

        with open_state(self.project, read_only=False) as state:
            # If the state is empty, add the settings.
            if state.is_empty():
                state.settings = self.settings

            # Add the record table to the state if it is not already there.
            self.record_table = state.get_record_table()
            if self.record_table.empty:
                state.add_record_table(as_data.record_ids)
                self.record_table = state.get_record_table()

            # Retrieve feature matrix from the project file or create
            # one from scratch.
            # NOTE(review): the argument of get_feature_matrix and the second
            # argument of add_feature_matrix were lost in the source extract;
            # `self.feature_extraction.name` is a reconstruction - confirm.
            try:
                self.X = self.project.get_feature_matrix(
                    self.feature_extraction.name
                )
            except FileNotFoundError:
                self.X = self.feature_extraction.fit_transform(
                    as_data.texts,
                    as_data.headings,
                    as_data.bodies,
                    as_data.keywords,
                )

                # check if the number of records after the transform equals
                # the number of records in the dataset
                if self.X.shape[0] != len(as_data):
                    raise ValueError(
                        "Dataset has {} records while feature "
                        "extractor returns {} records".format(
                            len(as_data), self.X.shape[0]
                        )
                    )

                self.project.add_feature_matrix(
                    self.X, self.feature_extraction.name
                )

            # Check if the number of records in the feature matrix matches
            # the length of the dataset.
            if self.X.shape[0] != len(self.data_labels):
                raise ValueError(
                    "The state file does not correspond to the "
                    "given data file, please use another state "
                    "file or dataset."
                )

            # Make sure the priors are labeled.
            self._label_priors()

    @property
    def settings(self):
        """Get an ASReview settings object"""
        # The prior counts are only forwarded when the instance defines them.
        extra_kwargs = {}
        if hasattr(self, 'n_prior_included'):
            extra_kwargs['n_prior_included'] = self.n_prior_included
        if hasattr(self, 'n_prior_excluded'):
            extra_kwargs['n_prior_excluded'] = self.n_prior_excluded

        # NOTE(review): the first four arguments were lost in the source
        # extract; the keyword names and `.name` values are a
        # reconstruction - confirm against ASReviewSettings.
        return ASReviewSettings(
            model=self.classifier.name,
            query_strategy=self.query_strategy.name,
            balance_strategy=self.balance_model.name,
            feature_extraction=self.feature_extraction.name,
            n_instances=self.n_instances,
            stop_if=self.stop_if,
            model_param=self.classifier.param,
            query_param=self.query_strategy.param,
            balance_param=self.balance_model.param,
            feature_param=self.feature_extraction.param,
            **extra_kwargs
        )

    def review(self):
        """Do a full review."""
        # Label any pending records.
        with open_state(self.project, read_only=False) as s:
            pending = s.get_pending()
            if not pending.empty:
                self._label(pending)

            labels_prior = s.get_labels()

        # progress bars
        pbar_rel = tqdm(
            initial=sum(labels_prior),
            total=sum(self.as_data.labels),
            desc="Relevant records found"
        )
        pbar_total = tqdm(
            initial=len(labels_prior),
            total=len(self.as_data),
            desc="Records labeled "
        )

        try:
            # While the stopping condition has not been met:
            while not self._stop_review():
                # Train a new model.
                self.train()

                # Query for new records to label.
                record_ids = self._query(self.n_instances)

                # Label the records.
                labels = self._label(record_ids)

                # monitor progress here
                pbar_rel.update(sum(labels))
                pbar_total.update(len(labels))
        finally:
            # Close the bars even when a step raises.
            pbar_rel.close()
            pbar_total.close()

        # write to state when stopped
        # (_write_to_state is not defined in this class body.)
        self._write_to_state()

    def _label_priors(self):
        """Make sure the prior records are labeled."""
        with open_state(self.project, read_only=False) as state:
            labeled = state.get_labeled()

        # Build the lookup once instead of once per prior index.
        labeled_ids = set(labeled['record_id'].to_list())
        unlabeled_priors = [
            x for x in self.prior_indices if x not in labeled_ids
        ]

        self._label(unlabeled_priors, prior=True)

    def _stop_review(self):
        """Check if the review should be stopped according to stopping rule
        obtained from the settings.

        Returns
        -------
        bool
            If True, the stopping criteria have been met.
        """
        # Get the pool. There never should be pending papers here.
        with open_state(self.project) as state:
            pool, _, _ = state.get_pool_labeled_pending()

        # if the pool is empty, always stop
        if pool.empty:
            return True

        # No stopping rule: continue until the pool is empty.
        if self.stop_if is None:
            return False

        # If stop_if is set to min, stop when all papers in the pool are
        # irrelevant. Handled on its own so that 'min' is never compared
        # with the integer query count below.
        if self.stop_if == 'min':
            return bool((self.data_labels[pool] == 0).all())

        # Otherwise, stop when reaching stop_if.
        with open_state(self.project) as state:
            training_sets = state.get_training_sets()

        # There is one query per trained model. We subtract 1
        # for the priors.
        n_queries = len(set(training_sets)) - 1
        return n_queries >= self.stop_if

    def _query(self, n):
        """Query new records to label.

        Arguments
        ---------
        n: int
            Number of records to query.

        Returns
        -------
        list
            List of record_ids of the n top ranked records according to the
            last ranking saved in the state.
        """
        with open_state(self.project, read_only=False) as s:
            top_n_records = s.query_top_ranked(n)

        return top_n_records

    def _label(self, record_ids, prior=False):
        """Label queried records using the known labels in a simulated review.

        Arguments
        ---------
        record_ids: list
            List of record_ids that will be labeled.
        prior: bool
            Whether the records are priors or not.

        Returns
        -------
        numpy.ndarray
            The labels that were stored for ``record_ids``; ``review`` uses
            them to update its progress bars.
        """
        labels = self.data_labels[record_ids]

        with open_state(self.project, read_only=False) as s:
            s.add_labeling_data(record_ids, labels, prior=prior)

        return labels

    def train(self):
        """Train a new model on the labeled data.

        Raises
        ------
        ValueError
            If the labeled records do not contain both a 0 and a 1 label.
        """
        # Check if both labels are available.
        with open_state(self.project) as state:
            labeled = state.get_labeled()
            labels = labeled['label'].to_list()
            training_set = len(labeled)
            if not (0 in labels and 1 in labels):
                raise ValueError('Not both labels available. '
                                 'Stopped training the model')

        # TODO: Simplify balance model input.
        # Use the balance model to sample the trainings data.
        # Left-merge onto the record table: records without a label get
        # LABEL_NA.
        y_sample_input = (
            pd.DataFrame(self.record_table)
            .merge(labeled, how='left', on='record_id')
            .loc[:, 'label']
            .fillna(LABEL_NA)
            .to_numpy()
        )
        train_idx = np.where(y_sample_input != LABEL_NA)[0]

        X_train, y_train = self.balance_model.sample(
            self.X, y_sample_input, train_idx
        )

        # Fit the classifier on the trainings data.
        # NOTE(review): the start of this call was lost in the source
        # extract; `self.classifier.fit(X_train, ...)` is a
        # reconstruction - confirm.
        self.classifier.fit(X_train, y_train)

        # Use the query strategy to produce a ranking.
        ranked_record_ids = self.query_strategy.query(
            self.X, classifier=self.classifier
        )

        # TODO: Also log the probabilities.
        # Log the ranking in the state.
        # NOTE(review): the four middle arguments were lost in the source
        # extract; the model names below are a reconstruction - confirm
        # against the state's add_last_ranking signature.
        with open_state(self.project, read_only=False) as state:
            state.add_last_ranking(
                ranked_record_ids,
                self.classifier.name,
                self.query_strategy.name,
                self.balance_model.name,
                self.feature_extraction.name,
                training_set,
            )