Source code for asreview.database.store

import functools
import sqlite3
from collections import defaultdict
from uuid import uuid4

import numpy as np
import pandas as pd
from sqlalchemy import NullPool
from sqlalchemy import create_engine
from sqlalchemy import event
from sqlalchemy import select
from sqlalchemy.exc import OperationalError
from sqlalchemy.orm import Session
from sqlalchemy.orm import sessionmaker

from asreview.data.record import Base
from asreview.data.record import Record

# Presumably the format/schema version of the data store — TODO confirm; it is
# not referenced in this part of the file.
CURRENT_DATASTORE_VERSION = 0

# Maximum number of host parameters SQLite accepts in a single statement. The
# compile-time default is 32766 since SQLite 3.32.0 (2020); before that it was 999.
# See: https://www.sqlite.org/limits.html#max_variable_number
SQLITE_MAX_VARIABLE_NUMBER = 32_766


def _batched_in(query_fn, values, batch_size=SQLITE_MAX_VARIABLE_NUMBER):
    """Execute a query in batches to avoid exceeding SQLite's variable limit.

    Parameters
    ----------
    query_fn : callable
        A function that takes a batch of values and returns a list of results.
        For example: ``lambda batch: session.query(Record).filter(
        Record.record_id.in_(batch)).all()``.
    values : iterable
        The full collection of values to pass to the IN clause. Any iterable is
        accepted (list, tuple, set, generator, numpy array). It is materialized
        into a list once, and every batch handed to `query_fn` is a list.
    batch_size : int, optional
        Maximum number of values per batch. Must be at least 1.

    Returns
    -------
    list
        Concatenated results from all batches, in batch order. `query_fn` is not
        called at all when `values` is empty.

    Raises
    ------
    ValueError
        If `batch_size` is smaller than 1.
    """
    if batch_size < 1:
        # A non-positive step would make the loop below run zero times and
        # silently return no results instead of running the query.
        raise ValueError(f"batch_size must be at least 1, got {batch_size}")
    # Materialize once so that sets, generators and other non-sliceable iterables
    # can be split into batches.
    values = list(values)
    results = []
    for start in range(0, len(values), batch_size):
        results.extend(query_fn(values[start : start + batch_size]))
    return results


def normalize_duplicate_chain(session, record: Record):
    """Normalize the duplicate chain of a record.

    We consider records to be in a group when they point to each other with the
    `duplicate_of` column. We want to make it easy to query the groups by making sure
    that all records in a group point to the same root record. We call the records
    pointing to each other with the `duplicate_of` field the duplicate chain of the
    record. So we want to avoid duplicate chains of length more than 2 and we want to
    avoid circular duplicate chains.

    For example, if `r1`, `r2` and `r3` are in a group we want to have:
    ```
    r1.duplicate_of = r3
    r2.duplicate_of = r3
    r3.duplicate_of = None
    ```
    and not
    ```
    r1.duplicate_of = r2
    r2.duplicate_of = r3
    r3.duplicate_of = None
    ```
    or even `r3.duplicate_of = r1`. We also avoid things like `r1.duplicate_of = r2` and
    `r2.duplicate_of = r1` or even `r1.duplicate_of = r1`.

    Only the `duplicate_of` attribute of the records along the chain is modified,
    in place; the function returns nothing.

    Parameters
    ----------
    session : sqlalchemy.Session
        Database session, used to look up the records referenced by `duplicate_of`.
    record : Record
        Record for which to normalize the duplicate chain.

    Raises
    ------
    ValueError
        If `record.duplicate_of`, or the `duplicate_of` of any record further along
        its chain, contains a non-existent record id.
    """
    # Follow the duplicate_of pointers starting at `record`, collecting every record
    # visited, until a root (duplicate_of is None) is reached or a cycle is found.
    current = record
    record_chain = [current]

    while current.duplicate_of is not None:
        next_record = session.get(Record, current.duplicate_of)
        if next_record is None:
            raise ValueError(f"Invalid duplicate_of reference: {current.duplicate_of}")
        if next_record in record_chain:
            # cycle detected, set the record with the minimal record_id as root.
            # Every record visited so far belongs to the same group, so all of them
            # are pointed directly at that root.
            min_id = min(r.record_id for r in record_chain)
            for r in record_chain:
                r.duplicate_of = min_id if r.record_id != min_id else None
            return
        record_chain.append(next_record)
        current = next_record

    # A chain of length 1 (record is itself a root) or 2 (record points directly at
    # the root) is already normalized. Longer chains are flattened so that every
    # record points straight at the last record, which is the root.
    if len(record_chain) > 2:
        root = record_chain[-1]
        for r in record_chain[:-1]:
            r.duplicate_of = root.record_id


# Hook that ensures that record duplicate chains get normalized before flushing the
# record to the database.
@event.listens_for(Session, "before_flush")
def flatten_duplicate_of(session, flush_context, instances):
    """Normalize the duplicate chains of new and modified records before a flush.

    The listener is registered on the `Session` class, so it runs for every session.
    It only does work when at least one new or dirty `Base` instance has
    `duplicate_of` set; in that case every such instance is passed to
    `normalize_duplicate_chain`.

    Parameters
    ----------
    session : sqlalchemy.orm.Session
        The session being flushed.
    flush_context : object
        Unused; part of the `before_flush` event signature.
    instances : object
        Unused; part of the `before_flush` event signature.
    """
    record_mutations = [
        obj for obj in session.new.union(session.dirty) if isinstance(obj, Base)
    ]
    if not any(record.duplicate_of is not None for record in record_mutations):
        return

    def _sort_key(record):
        # Records that were not flushed yet may have record_id None (SQLite assigns
        # the autoincremented id on insert). Comparing None with an int raises a
        # TypeError, so records without an id are ordered after those with one.
        return (record.record_id is None, record.record_id or 0)

    for record in sorted(record_mutations, key=_sort_key):
        normalize_duplicate_chain(session, record)


def unwrap_operational_errors(func):
    """Decorator to unwrap SQLAlchemy OperationalError to sqlite3.OperationalError.

    A write attempt on a database opened in read only mode makes SQLAlchemy raise
    `sqlalchemy.exc.OperationalError`, which wraps the original
    `sqlite3.OperationalError` in its `orig` attribute. Because the Database object
    mixes plain sqlite3 and SQLAlchemy access, the decorated function re-raises that
    original sqlite3 error (chained to the SQLAlchemy one), so that all readonly
    errors have the same type. Return values and other exceptions pass through
    unchanged.
    """

    @functools.wraps(func)
    def _unwrapped(*args, **kwargs):
        try:
            outcome = func(*args, **kwargs)
        except OperationalError as sa_error:
            raise sa_error.orig from sa_error
        return outcome

    return _unwrapped


def _build_conn_uri(fp, read_only=False):
    """Build a SQLite connection URI.

    For in-memory databases a unique name is generated so that multiple connections
    can share the same database via `cache=shared`, while separate calls to this
    function produce isolated databases.

    Parameters
    ----------
    fp : str | Path
        File path or `":memory:"`.
    read_only : bool, optional
        Open in read-only mode.

    Returns
    -------
    str
        A SQLite URI suitable for `sqlite3.connect(uri, uri=True)`.
    """
    params = {}
    if fp == ":memory:":
        fp = uuid4().hex
        params["cache"] = "shared"
        params["mode"] = "memory"
    if read_only:
        params["mode"] = "ro"
    param_str = "&".join(f"{k}={v}" for k, v in params.items())
    uri = f"file:{fp}"
    if param_str:
        uri = uri + "?" + param_str
    return uri


class DataStore:
    """Data store to hold user input data.

    Data input always happens via the record class. This means that if you want to
    add data to the data store, you will first need to clean it, make sure it has
    the correct columns and make sure it passes the validations defined in the
    record class.

    Getting data from the store can happen in rows or in columns. If you read rows,
    you will get record objects as response. If you read columns, you will get
    pandas objects. If you ask for a single column you get a pandas Series, and if
    you ask for multiple columns you get a pandas DataFrame.

    DataStore uses an SQLite database in the backend and SQLAlchemy ORM to interact
    with the database."""

    def __init__(
        self, fp=":memory:", record_cls=Record, read_only=False, conn_uri=None
    ):
        """Initialize the data store.

        Parameters
        ----------
        fp : str | Path
            Location of the database file. If `fp == ":memory:"`, the data store
            will be in memory.
        record_cls : asreview.data.record.Base, optional
            The record class to use. The record class specifies which fields each
            record can have, field validation and more properties of the database.
            See `asreview.data.record`. By default uses
            `asreview.data.record.Record`.
        read_only : bool, optional
            Whether to open the database in read only mode. If the database is
            opened in read only mode and an attempt to write to the database is
            made, an `sqlite3.OperationalError` will be raised.
        conn_uri : str | None, optional
            A pre-built SQLite connection URI. When provided, `fp` and `read_only`
            are ignored for URI construction and this URI is used directly. This is
            useful when embedding the DataStore inside a `Database` that already
            owns the connection URI.

        Raises
        ------
        ValueError
            If an in-memory database is requested in read only mode without a
            `conn_uri`.
        """
        if conn_uri is None and fp == ":memory:" and read_only:
            raise ValueError("Can't open an in-memory database in read only mode")
        self.fp = fp
        self.read_only = read_only
        self._conn_uri = (
            conn_uri if conn_uri is not None else _build_conn_uri(fp, read_only)
        )
        # NOTE(review): any externally supplied `conn_uri` is flagged as in-memory
        # here, even if it points at a file; confirm this is intended.
        self._in_memory = conn_uri is not None or fp == ":memory:"
        # I'm using NullPool here, indicating that the engine should not use a
        # connection pool, but just create and dispose of a connection every time a
        # request comes. This makes it very easy dispose of the engine, but is less
        # efficient.
        # NOTE(review): SQLite deletes a shared-cache in-memory database when its
        # last connection closes. With NullPool each request closes its own
        # connection, so an in-memory store only keeps its data between calls if
        # another connection to the same URI stays open; confirm the owner does so.
        self.engine = create_engine(
            "sqlite://",
            creator=lambda: sqlite3.connect(self._conn_uri, uri=True),
            poolclass=NullPool,
        )
        # I put expire_on_commit=False, so that after you put records in the
        # database, you can still use them in your code without having access to
        # the database. The downside is that if you use the record after committing
        # it to the database and another mutation happens to the database, your
        # record might be out of date. See
        # https://docs.sqlalchemy.org/en/20/orm/session_api.html#sqlalchemy.orm.Session.params.expire_on_commit
        self.Session = sessionmaker(self.engine, expire_on_commit=False)
        self.record_cls = record_cls
        self._columns = self.record_cls.get_columns()
        self._pandas_dtype_mapping = self.record_cls.get_pandas_dtype_mapping()

    @property
    def columns(self):
        """Columns of the record class, as returned by `record_cls.get_columns()`."""
        return self._columns

    @property
    def pandas_dtype_mapping(self):
        """Mapping {column name: pandas data type}"""
        return self._pandas_dtype_mapping

    @unwrap_operational_errors
    def create_tables(self):
        """Initialize the tables containing the data.

        If you are creating a new data store, you will need to call this method
        before adding data to the data store. Every table registered on
        `Base.metadata` is created, not only the table of `record_cls`.
        """
        Base.metadata.create_all(self.engine)

    @unwrap_operational_errors
    def add_records(self, records):
        """Add records to the data store.

        Parameters
        ----------
        records : list[self.record_cls]
            List of records to add to the store. Any iterable of records is
            accepted; an empty iterable is a no-op.

        Raises
        ------
        ValueError
            If some `record.duplicate_of` points to a non-existing record_id.
        """
        records = list(records)
        if not records:
            # Nothing to add. Without this guard `records[0]` below raises an
            # IndexError when the store is still empty.
            return
        # SQLite makes an autoincremented primary key column start at 1. We want it
        # to start at 0, so that the record_id is equal to the row number of the
        # record in feature matrix. By making sure that the first record has
        # record_id 0, we force the autoincremented column to start at 0.
        if self.is_empty():
            records[0].record_id = 0
        with self.Session() as session, session.begin():
            session.add_all(records)

    @unwrap_operational_errors
    def delete_record(self, record_id):
        """Delete a record from the store.

        WARNING: This method is purely here for completeness, it should not be used
        in any production setting. Deleting records can lead to undefined behavior
        because we make assumptions about the record_id in other parts of the code.

        Raises
        ------
        ValueError
            If the store has no record with the given record_id.
        """
        with self.Session() as session, session.begin():
            record = session.get(self.record_cls, record_id)
            if record is None:
                raise ValueError(
                    f"DataStore does not contain a record with record_id {record_id}"
                )
            session.delete(record)

    def __len__(self):
        """Return the number of records in the store."""
        with self.Session() as session:
            return session.query(self.record_cls).count()

    def __getitem__(self, item):
        """Read one column as a pandas Series or several as a pandas DataFrame.

        Rows are always ordered by record_id.
        """
        # We allow a string or a list of strings as input. If the input is a string
        # we return that column as a pandas series. If the input is a list of
        # strings we return a pandas DataFrame containing those columns. This way
        # the output you get is the same if you do __getitem__ on a DataStore
        # instance or on a pandas DataFrame containing the same data.
        columns = [item] if isinstance(item, str) else list(item)

        # Always order by record_id. Without ORDER BY, SQLite can return rows in
        # different orders across separate single-column queries (depending on
        # which index the planner picks). Code that positionally joins two such
        # queries, e.g. `record_id` and `included`, would then line up labels with
        # the wrong records.
        table = self.record_cls.__tablename__
        # Identifiers cannot be bound as parameters, so quote each column name and
        # double any embedded double quote so it cannot break out of the quoting.
        select_cols = ", ".join(
            '"' + str(c).replace('"', '""') + '"' for c in columns
        )
        dtype = {c: t for c, t in self.pandas_dtype_mapping.items() if c in columns}
        with self.engine.connect() as con:
            df = pd.read_sql_query(
                f"SELECT {select_cols} FROM {table} ORDER BY record_id",
                con,
                dtype=dtype,
            )
        if isinstance(item, str):
            return df[item]
        else:
            return df

    def __contains__(self, item):
        """Return whether `item` is one of the store's columns."""
        return item in self.columns

    def is_empty(self):
        """Return True if the store contains no records."""
        with self.Session() as session:
            return session.query(self.record_cls).first() is None

    def get_records(self, record_id=None):
        """Get the records with the given record identifiers.

        Parameters
        ----------
        record_id : int | list[int] | None
            Record identifier or list record identifiers. If None, get all records.
            Numpy integers are accepted both as a single identifier and as elements
            of the list (e.g. a numpy integer array).

        Returns
        -------
        asreview.data.record.Record | list[asreview.data.record.Record] | None
            For a single identifier, the record or None if it does not exist. For a
            list, the existing records in the order in which they were requested.
        """
        if isinstance(record_id, np.integer):
            record_id = record_id.item()
        with self.Session() as session:
            if record_id is None:
                return session.query(self.record_cls).all()
            if isinstance(record_id, int):
                return (
                    session.query(self.record_cls)
                    .filter(self.record_cls.record_id == record_id)
                    .first()
                )
            # sqlite3 cannot bind numpy integer scalars, so convert them to Python
            # ints, mirroring the single-identifier case above.
            record_ids = [
                rid.item() if isinstance(rid, np.integer) else rid for rid in record_id
            ]
            records = _batched_in(
                lambda batch: (
                    session.query(self.record_cls)
                    .filter(self.record_cls.record_id.in_(batch))
                    .all()
                ),
                record_ids,
            )
            # Restore the order in which the identifiers were requested.
            position = {rid: i for i, rid in enumerate(record_ids)}
            return sorted(records, key=lambda r: position[r.record_id])

    @unwrap_operational_errors
    def set_groups(self, groups):
        """Add record group information to the data store.

        Parameters
        ----------
        groups : list[tuple[int,int]]
            List of tuples (group_id, record_id). This data is added to the record
            as the `duplicate_of` attribute. The data store will normalize these
            values: One record is chosen as the root, satisfying
            `root.duplicate_of = None`. All other records in the group will get
            `record.duplicate_of = root.record_id`.

        Raises
        ------
        sqlite3.OperationalError
            If the store was opened in read only mode.
        """
        group_to_records = defaultdict(set)
        for group_id, record_id in groups:
            group_to_records[group_id].add(record_id)

        # Every record in a group points at the smallest record_id of that group.
        record_to_group = {}
        for group in group_to_records.values():
            group_id = min(group)
            for record_id in group:
                record_to_group[record_id] = group_id

        with self.Session() as session, session.begin():
            record_ids = list(record_to_group.keys())
            records = _batched_in(
                lambda batch: session.scalars(
                    select(Record).where(Record.record_id.in_(batch))
                ).all(),
                record_ids,
            )
            for record in records:
                record.duplicate_of = record_to_group[record.record_id]

    def get_groups(self, record_id=None):
        """Get the record groups.

        Parameters
        ----------
        record_id : int | None
            Get only the group containing the record with this record_id.

        Returns
        -------
        list[tuple[int, int]]
            List of tuples (group_id, record_id) ordered by group id, and by
            record_id within a group. The tuples values are also accessible by the
            attribute names (so `tuple.group_id` and `tuple.record_id`).
        """
        stmt = select(
            Record.group_id,
            Record.record_id,
        )

        if record_id is not None:
            # Get the records in the group of the record with the given record_id.
            target_group_subq = (
                select(Record.group_id)
                .where(Record.record_id == record_id)
                .scalar_subquery()
            )
            stmt = stmt.where(Record.group_id == target_group_subq)

        with self.Session() as session:
            # record_id as secondary key makes the order within a group
            # deterministic.
            return session.execute(
                stmt.order_by("group_id", Record.record_id)
            ).all()

    def get_df(self):
        """Get all data from the data store as a pandas DataFrame.

        Rows are ordered by record_id.

        Returns
        -------
        pd.DataFrame
        """
        with self.engine.connect() as con:
            return pd.read_sql_query(
                f"SELECT * FROM {self.record_cls.__tablename__} ORDER BY record_id",
                con,
                dtype=self.pandas_dtype_mapping,
            )