# Copyright 2019-2025 The ASReview Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
__all__ = []
import time
from collections import defaultdict
import numpy as np
import pandas as pd
from tqdm import tqdm
from asreview.database.database import open_db
from asreview.metrics import loss
from asreview.metrics import ndcg
from asreview.models.stoppers import LastRelevant
from asreview.models.stoppers import NLabeled
def _propagate_record_info(record_info, groups, return_only_new=False):
"""Propagate record-level information across groups of records.
Each group defines a set of records that must share the same info. If one
record in a group has associated info, that info is propagated to all
other records in the group. If multiple records in the same group have
conflicting info, a ValueError is raised.
Parameters
----------
record_info : list[tuple]
A list of tuples in the form ``(record_id, *info)``, where ``info`` is
any associated data for the record.
groups: list[tuple[int,int]] | None
A list of tuples in the form `(group_id, record_id)`, where `group_id`
identifies the group a record belongs to.
Returns
-------
list of tuple
A list of tuples in the form ``(record_id, *info)`` where the info has
been propagated to all records in the same group.
Raises
------
ValueError
If records within the same group have conflicting info.
"""
record_to_group = {}
group_to_records = defaultdict(list)
for group_id, record_id in groups:
record_to_group[record_id] = group_id
group_to_records[group_id].append(record_id)
group_to_info = defaultdict(set)
for record_id, *info in record_info:
group_id = record_to_group.get(record_id, record_id)
group_to_info[group_id].add(tuple(info))
multivalued_groups = [
{
"group_id": group_id,
"record_ids": group_to_records[group_id],
"info": info_set,
}
for group_id, info_set in group_to_info.items()
if len(info_set) > 1
]
if multivalued_groups:
raise ValueError(
f"All records in the same group should have the same record info: {multivalued_groups}"
)
if return_only_new:
original_record_ids = set(record_id for record_id, *_ in record_info)
output = []
for group_id, info_set in group_to_info.items():
info = next(iter(info_set))
record_ids = group_to_records.get(group_id, [group_id])
if return_only_new:
output += [
(record_id, *info)
for record_id in record_ids
if record_id not in original_record_ids
]
else:
output += [(record_id, *info) for record_id in record_ids]
return output
def _get_name_from_estimator(estimator):
"""Get the name of the estimator.
Parameters
----------
estimator: object
The estimator to get the name from.
Returns
-------
str
The name of the estimator.
"""
if estimator is None:
return None
return estimator.name
def _assert_no_conflicts_in_groups(labels, groups):
"""Ensures that all records within the same group share the same label.
This function checks for label consistency within groups of records.
For each `(group_id, record_id)` pair in `groups`, it verifies that
all records assigned to the same `group_id` have identical labels
according to the `labels` list. If any group contains conflicting
labels, an `AssertionError` is raised.
Parameters
----------
labels : list
A list of labels, where the entry in spot `i` contains the label of the record
with `record_id = i`.
groups : Iterable[tuple[int, int]]
An iterable of `(group_id, record_id)` pairs representing which records belong
to which groups.
Raises
------
AssertionError
If a group contains records with differing labels.
"""
group_to_label = {}
for group_id, record_id in groups:
label = labels[record_id]
if group_id in group_to_label and group_to_label[group_id] != label:
raise AssertionError(f"Group {group_id} contains conflicting labels.")
group_to_label[group_id] = label
class Simulate:
    """ASReview simulation class.

    The simulation runs the given learning cycles one after another. Within a
    cycle, records are ranked and labeled repeatedly until either the
    simulation stopper or the stopping rule of the cycle itself fires.

    To seed the simulation, provide the seed to the classifier, query strategy,
    feature extraction model, and balance strategy or use a global random seed.

    Parameters
    ----------
    X: numpy.ndarray, pandas.DataFrame
        The input of the simulation. It is passed to ``cycle.transform`` unless
        ``skip_transform`` is True or the cycle has no feature extractor; in
        that case it is used as the feature matrix directly (a DataFrame is
        converted with ``.values``).
    labels: numpy.ndarray, pandas.Series, list
        The labels to use for the simulation, indexed by record id.
    cycles: object, list
        A single learning cycle or a list of learning cycles. A cycle has to
        provide the attributes ``classifier``, ``querier``, ``balancer`` and
        ``feature_extractor`` and the methods ``transform``, ``fit``, ``rank``,
        ``get_n_query`` and ``stop``.
    stopper: int, object with ``stop`` method, None
        The stopping mechanism to use for the simulation. When stopper is None,
        the simulation stops when all relevant records are found. If an integer,
        the simulation stops after n queries. A stopper of -1 stops the
        simulation after all records have been labeled. If it is an object with
        a ``.stop(results, labels)`` method, the simulation stops when that
        method returns True. Default is None.
    skip_transform: bool
        If True, the feature matrix is not computed in the simulation. It is assumed
        that X is the feature matrix or input to the estimator. Default is False.
    print_progress: bool
        If True, progress bars are shown and the loss and NDCG are printed at
        the end of the review. Default is True.
    groups: list[tuple[int, int]] | None
        List of tuples (group_id, record_id). If this is not None, records in the same
        group will be labeled at the same time in the simulation.

    Raises
    ------
    ValueError
        If ``groups`` contains a group whose records have different labels.
    """

    def __init__(
        self,
        X,
        labels,
        cycles,
        stopper=None,
        skip_transform=False,
        print_progress=True,
        groups=None,
    ):
        self.X = X
        self.labels = labels
        self.cycles = cycles
        self.stopper = stopper
        self.skip_transform = skip_transform
        self.print_progress = print_progress

        if groups is not None:
            # Surface label conflicts as a ValueError, the usual exception for
            # invalid arguments.
            try:
                _assert_no_conflicts_in_groups(labels, groups)
            except AssertionError as e:
                raise ValueError(
                    f"Groups should not contain conflicting labels: {e}"
                ) from e
        self.groups = groups

    @property
    def _results(self):
        # Raising AttributeError makes ``hasattr(self, "_results")`` return
        # False until results have been stored.
        if not hasattr(self, "_Simulate__results"):
            raise AttributeError("No results. Label records or call review.")
        return self._Simulate__results

    @_results.setter
    def _results(self, value):
        self._Simulate__results = value

    @property
    def _last_ranking(self):
        # NOTE(review): nothing in this class assigns the last ranking; it is
        # presumably set by other code. Confirm before relying on it.
        if not hasattr(self, "_Simulate__last_ranking"):
            raise AttributeError("No last ranking. Call train or review.")
        return self._Simulate__last_ranking

    @_last_ranking.setter
    def _last_ranking(self, value):
        self._Simulate__last_ranking = value

    def review(self):
        """Start the review process.

        Runs every cycle until the simulation stopper or the cycle's own
        ``stop`` method returns True. Each iteration fits the classifier (if
        any) on the labeled records, ranks the remaining pool and labels the
        top ``n_query`` records of that ranking.

        Raises
        ------
        ValueError
            If a cycle returns a number of records to query that is not an
            integer greater than 0.
        """
        if not hasattr(self, "_results"):
            self._results = pd.DataFrame(
                columns=[
                    "record_id",
                    "label",
                    "classifier",
                    "querier",
                    "balancer",
                    "feature_extractor",
                    "training_set",
                    "time",
                    "note",
                    "tags",
                    "user_id",
                ]
            )

        pbar_rel = tqdm(
            initial=sum(self._results["label"]),
            total=sum(self.labels),
            desc="Relevant records found",
            disable=not self.print_progress,
        )
        pbar_total = tqdm(
            initial=len(self._results),
            total=len(self.labels),
            desc="Records labeled       ",
            disable=not self.print_progress,
        )

        if self.stopper is None:
            stopper = LastRelevant()
        elif isinstance(self.stopper, int):
            stopper = NLabeled(self.stopper)
        else:
            stopper = self.stopper

        cycles = self.cycles if isinstance(self.cycles, list) else [self.cycles]

        # Close the progress bars even when a cycle raises.
        try:
            for cycle in cycles:
                # first run the overall simulation until the default stopper is met
                while not stopper.stop(self._results, self.labels) and not cycle.stop(
                    self._results, self.labels
                ):
                    # compute the feature matrix for the labeled records if not in
                    # _X_features cache
                    if not hasattr(self, "_X_features"):
                        if (
                            not self.skip_transform
                            and cycle.feature_extractor is not None
                        ):
                            self._X_features = cycle.transform(self.X)
                        elif isinstance(self.X, pd.DataFrame):
                            self._X_features = self.X.values
                        else:
                            self._X_features = self.X

                    # fit the estimator to the labeled records
                    if cycle.classifier is not None:
                        cycle.fit(
                            self._X_features[self._results["record_id"].values],
                            self._results["label"].values,
                        )

                    # collect the records in the pool
                    pool_record_ids = np.setdiff1d(
                        np.arange(len(self.labels)), self._results["record_id"].values
                    )

                    # rank the pool and convert the ranked pool to record ids
                    ranked_pool = cycle.rank(self._X_features[pool_record_ids])
                    ranked_pool_record_ids = pool_record_ids[ranked_pool]

                    # label n_query records from the pool
                    n_query = cycle.get_n_query(self._results, self.labels)
                    if not isinstance(n_query, int) or n_query < 1:
                        raise ValueError(
                            f"Number of records to query should be an integer "
                            f"greater than 0, got {n_query}."
                        )

                    labeled = self.label(ranked_pool_record_ids[:n_query], cycle=cycle)

                    pbar_rel.update(labeled["label"].sum())
                    # Count the records that were actually labeled: group
                    # propagation can add records and a nearly empty pool can
                    # hold fewer than n_query records.
                    pbar_total.update(len(labeled))

                # The cached features belong to this cycle's feature extractor;
                # drop them so the next cycle computes its own.
                if hasattr(self, "_X_features"):
                    del self._X_features
        finally:
            pbar_rel.close()
            pbar_total.close()

        # Rows without a training set size are left out; the remainder is
        # padded with zeros for the records that are still unlabeled.
        padded_results = list(
            self._results.dropna(axis=0, subset="training_set")["label"]
        ) + [0] * (len(self.labels) - len(self._results["label"]))

        if self.print_progress:
            try:
                print(
                    f"\nLoss: {loss(padded_results):.3f}\nNDCG: {ndcg(padded_results):.3f}"
                )
            except ValueError:
                print(
                    "Can't compute loss and gain for labels with only relevant or irrelevant records"
                )

    def label(self, record_ids, cycle=None):
        """Label the records with the given record_ids.

        The labels are looked up by position in ``self.labels``. When groups
        are configured, the remaining records of the groups of ``record_ids``
        are labeled in the same call.

        Parameters
        ----------
        record_ids: list
            The record ids to label.
        cycle: object, None
            The learning cycle that selected the records. Its estimator names
            and the current number of results are stored with the new rows. If
            None, those columns are left empty. Default is None.

        Returns
        -------
        pandas.DataFrame
            The rows that were added to the results, including the rows of
            records that were labeled through group propagation.
        """
        if cycle is None:
            classifier = None
            querier = None
            balancer = None
            feature_extractor = None
            training_set = None
        else:
            classifier = _get_name_from_estimator(cycle.classifier)
            querier = _get_name_from_estimator(cycle.querier)
            balancer = _get_name_from_estimator(cycle.balancer)
            feature_extractor = _get_name_from_estimator(cycle.feature_extractor)
            # number of records that were labeled before this query
            training_set = len(self._results)

        new_results = pd.DataFrame(
            {
                "record_id": record_ids,
                "label": pd.Series(self.labels).iloc[record_ids],
                "classifier": classifier,
                "querier": querier,
                "balancer": balancer,
                "feature_extractor": feature_extractor,
                "training_set": training_set,
                "time": time.time(),
                "note": None,
                "tags": None,
                "user_id": None,
            }
        )

        if self.groups is not None:
            # Copy the row of each labeled record to the other records of its
            # group; only the additional rows are returned.
            record_info = list(new_results.itertuples(index=False, name=None))
            group_record_info = _propagate_record_info(
                record_info=record_info,
                groups=self.groups,
                return_only_new=True,
            )
            if group_record_info:
                new_results = pd.concat(
                    [
                        new_results,
                        pd.DataFrame(group_record_info, columns=new_results.columns),
                    ],
                    ignore_index=True,
                )

        if not hasattr(self, "_results") or self._results.empty:
            self._results = new_results
        else:
            self._results = pd.concat(
                [self._results, new_results],
                ignore_index=True,
            )

        return new_results

    def to_sql(self, fp):
        """Write the results to a SQLite file.

        Parameters
        ----------
        fp: str, Path
            The path to the sqlite file to write the results to. If there is no database
            yet at the location a new database will be created.
        """
        with open_db(fp) as db:
            db._replace_results_from_df(self._results)