import functools
import sqlite3
from collections import defaultdict
from uuid import uuid4
import numpy as np
import pandas as pd
from sqlalchemy import NullPool
from sqlalchemy import create_engine
from sqlalchemy import event
from sqlalchemy import select
from sqlalchemy.exc import OperationalError
from sqlalchemy.orm import Session
from sqlalchemy.orm import sessionmaker
from asreview.data.record import Base
from asreview.data.record import Record
# Version number of the data store format.
# NOTE(review): not referenced in this part of the file; presumably it is checked
# when an existing store is opened -- confirm against the callers.
CURRENT_DATASTORE_VERSION = 0

# SQLite max SQL variables limit (since 3.32.0, 2020). A single statement cannot
# bind more parameters than this, so large `IN (...)` queries are split into
# batches of at most this many values.
# See: https://www.sqlite.org/limits.html#max_variable_number
SQLITE_MAX_VARIABLE_NUMBER = 32766
def _batched_in(query_fn, values, batch_size=SQLITE_MAX_VARIABLE_NUMBER):
    """Run an ``IN``-style query in slices so SQLite's variable limit is respected.

    Parameters
    ----------
    query_fn : callable
        Called once per slice of `values`; must return an iterable of results.
        For example: ``lambda batch: session.query(Record).filter(
        Record.record_id.in_(batch)).all()``.
    values : list
        All values that should end up in the IN clause. Must support ``len`` and
        slicing.
    batch_size : int, optional
        Upper bound on the number of values handed to `query_fn` per call.

    Returns
    -------
    list
        The results of every call to `query_fn`, concatenated in call order.
    """
    collected = []
    for start in range(0, len(values), batch_size):
        stop = start + batch_size
        # `+=` on a list accepts any iterable, exactly like `list.extend`.
        collected += query_fn(values[start:stop])
    return collected
def normalize_duplicate_chain(session, record: Record):
    """Flatten the duplicate chain that starts at `record`.

    Records form a group when they point to each other through the `duplicate_of`
    column. To keep groups easy to query, every record of a group should point
    directly to one shared root record. The sequence of records obtained by
    following `duplicate_of` is called the duplicate chain of the record. Chains
    longer than 2 records and circular chains are rewritten by this function.

    For example, if `r1`, `r2` and `r3` are in a group we want to have:
    ```
    r1.duplicate_of = r3
    r2.duplicate_of = r3
    r3.duplicate_of = None
    ```
    and not
    ```
    r1.duplicate_of = r2
    r2.duplicate_of = r3
    r3.duplicate_of = None
    ```
    or even `r3.duplicate_of = r1`. Cycles such as `r1.duplicate_of = r2` together
    with `r2.duplicate_of = r1`, or `r1.duplicate_of = r1`, are removed as well.

    Parameters
    ----------
    session : sqlalchemy.Session
        Database session used to look up the records in the chain.
    record : Record
        Record whose duplicate chain is normalized. The records in the chain are
        modified in place.

    Raises
    ------
    ValueError
        If a `duplicate_of` value in the chain refers to a record id that the
        session cannot find.
    """
    chain = [record]
    target_id = record.duplicate_of
    while target_id is not None:
        target = session.get(Record, target_id)
        if target is None:
            raise ValueError(f"Invalid duplicate_of reference: {target_id}")

        if target in chain:
            # The chain loops back onto itself. Break the cycle by promoting the
            # record with the smallest record_id to root of everything seen so far.
            smallest_id = min(member.record_id for member in chain)
            for member in chain:
                if member.record_id == smallest_id:
                    member.duplicate_of = None
                else:
                    member.duplicate_of = smallest_id
            return

        chain.append(target)
        target_id = target.duplicate_of

    # A chain of one or two records is already flat.
    if len(chain) <= 2:
        return

    # The last record has `duplicate_of is None`, so it is the root of the group.
    *followers, root = chain
    for follower in followers:
        follower.duplicate_of = root.record_id
# Hook that ensures that record duplicate chains get normalized before flushing the
# record to the database.
@event.listens_for(Session, "before_flush")
def flatten_duplicate_of(session, flush_context, instances):
    """Normalize the duplicate chains of all new and modified records.

    Registered as a `before_flush` listener on every `Session`. Nothing is done
    unless at least one new or dirty record has `duplicate_of` set.

    Parameters
    ----------
    session : sqlalchemy.orm.Session
        The session that is about to flush.
    flush_context
        Passed by SQLAlchemy; unused.
    instances
        Passed by SQLAlchemy; unused.

    Raises
    ------
    ValueError
        If a `duplicate_of` value refers to a record id that cannot be found.
    """
    record_mutations = [
        obj for obj in session.new.union(session.dirty) if isinstance(obj, Base)
    ]
    if not any(record.duplicate_of is not None for record in record_mutations):
        return

    def _sort_key(record):
        # Records that have no record_id yet (presumably pending inserts that rely
        # on the autoincremented key) cannot be compared with integers: sorting on
        # a bare `None` raises TypeError. Order them after the records with an id;
        # records that do have an id keep their ascending order.
        record_id = record.record_id
        return (record_id is None, 0 if record_id is None else record_id)

    for record in sorted(record_mutations, key=_sort_key):
        normalize_duplicate_chain(session, record)
def unwrap_operational_errors(func):
    """Decorator that re-raises SQLAlchemy operational errors as their DBAPI cause.

    Writing to a database opened in read only mode makes SQLAlchemy raise
    `sqlalchemy.exc.OperationalError`, which wraps the original
    `sqlite3.OperationalError`. Because the Database object mixes direct sqlite3
    access with SQLAlchemy, the wrapped error is unpacked here so that all read only
    errors have the same type.

    Parameters
    ----------
    func : callable
        Function to wrap.

    Returns
    -------
    callable
        Wrapper with the same signature and metadata as `func`. It raises
        `error.orig` (chained from the SQLAlchemy error) whenever `func` raises
        `sqlalchemy.exc.OperationalError`.
    """

    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        try:
            result = func(*args, **kwargs)
        except OperationalError as wrapped_error:
            raise wrapped_error.orig from wrapped_error
        return result

    return wrapper
def _build_conn_uri(fp, read_only=False):
"""Build a SQLite connection URI.
For in-memory databases a unique name is generated so that multiple connections
can share the same database via `cache=shared`, while separate calls to this
function produce isolated databases.
Parameters
----------
fp : str | Path
File path or `":memory:"`.
read_only : bool, optional
Open in read-only mode.
Returns
-------
str
A SQLite URI suitable for `sqlite3.connect(uri, uri=True)`.
"""
params = {}
if fp == ":memory:":
fp = uuid4().hex
params["cache"] = "shared"
params["mode"] = "memory"
if read_only:
params["mode"] = "ro"
param_str = "&".join(f"{k}={v}" for k, v in params.items())
uri = f"file:{fp}"
if param_str:
uri = uri + "?" + param_str
return uri
class DataStore:
    """Data store to hold user input data.

    Data input always happens via the record class. This means that if you want to
    add data to the data store, you will first need to clean it, make sure it has
    the correct columns and make sure it passes the validations defined in the
    record class.

    Getting data from the store can happen in rows or in columns. If you read rows,
    you will get record objects as response. If you read columns, you will get
    pandas objects. If you ask for a single column you get a pandas Series, and if
    you ask for multiple columns you get a pandas DataFrame.

    DataStore uses an SQLite database in the backend and SQLAlchemy ORM to interact
    with the database.
    """

    def __init__(
        self, fp=":memory:", record_cls=Record, read_only=False, conn_uri=None
    ):
        """Initialize the data store.

        Parameters
        ----------
        fp : str | Path
            Location of the database file. If `fp == ":memory:"`, the data store
            will be in memory.
        record_cls : asreview.data.record.Base, optional
            The record class to use. The record class specifies which fields each
            record can have, field validation and more properties of the database.
            See `asreview.data.record`. By default uses
            `asreview.data.record.Record`.
        read_only : bool, optional
            Whether to open the database in read only mode. If the database is
            opened in read only mode and an attempt to write to the database is
            made, an `sqlite3.OperationalError` will be raised.
        conn_uri : str | None, optional
            A pre-built SQLite connection URI. When provided, `fp` and `read_only`
            are ignored for URI construction and this URI is used directly. This is
            useful when embedding the DataStore inside a `Database` that already
            owns the connection URI.

        Raises
        ------
        ValueError
            If an in-memory database is requested in read only mode (and no
            `conn_uri` is given).
        """
        if conn_uri is None and fp == ":memory:" and read_only:
            raise ValueError("Can't open an in-memory database in read only mode")

        self.fp = fp
        self.read_only = read_only
        self._conn_uri = (
            conn_uri if conn_uri is not None else _build_conn_uri(fp, read_only)
        )
        # NOTE(review): every explicitly passed `conn_uri` is flagged as in-memory
        # here, also when it points to a file -- confirm this is intended.
        self._in_memory = conn_uri is not None or fp == ":memory:"

        # I'm using NullPool here, indicating that the engine should not use a
        # connection pool, but just create and dispose of a connection every time a
        # request comes. This makes it very easy to dispose of the engine, but is
        # less efficient.
        self.engine = create_engine(
            "sqlite://",
            creator=lambda: sqlite3.connect(self._conn_uri, uri=True),
            poolclass=NullPool,
        )
        # I put expire_on_commit=False, so that after you put records in the
        # database, you can still use them in your code without having access to
        # the database. The downside is that if you use the record after committing
        # it to the database and another mutation happens to the database, your
        # record might be out of date.
        # See https://docs.sqlalchemy.org/en/20/orm/session_api.html#sqlalchemy.orm.Session.params.expire_on_commit
        self.Session = sessionmaker(self.engine, expire_on_commit=False)

        self.record_cls = record_cls
        self._columns = self.record_cls.get_columns()
        self._pandas_dtype_mapping = self.record_cls.get_pandas_dtype_mapping()

    @property
    def columns(self):
        """Columns of the record class, as returned by `record_cls.get_columns()`."""
        return self._columns

    @property
    def pandas_dtype_mapping(self):
        """Mapping {column name: pandas data type}"""
        return self._pandas_dtype_mapping

    @unwrap_operational_errors
    def create_tables(self):
        """Initialize the tables containing the data.

        If you are creating a new data store, you will need to call this method
        before adding data to the data store.
        """
        Base.metadata.create_all(self.engine)

    @unwrap_operational_errors
    def add_records(self, records):
        """Add records to the data store.

        Parameters
        ----------
        records : list[self.record_cls]
            List of records to add to the store. An empty list is a no-op.

        Raises
        ------
        ValueError
            If some `record.duplicate_of` points to a non-existing record_id.
        """
        if len(records) == 0:
            # Nothing to insert. Returning early also avoids indexing into an empty
            # list below.
            return

        # SQLite makes an autoincremented primary key column start at 1. We want it
        # to start at 0, so that the record_id is equal to the row number of the
        # record in feature matrix. By making sure that the first record has
        # record_id 0, we force the autoincremented column to start at 0.
        if self.is_empty():
            records[0].record_id = 0

        with self.Session() as session, session.begin():
            session.add_all(records)

    @unwrap_operational_errors
    def delete_record(self, record_id):
        """Delete a record from the store.

        WARNING: This method is purely here for completeness, it should not be used
        in any production setting. Deleting records can lead to undefined behavior
        because we make assumptions about the record_id in other parts of the code.

        Parameters
        ----------
        record_id : int
            Identifier of the record to delete.

        Raises
        ------
        ValueError
            If the store has no record with the given `record_id`.
        """
        with self.Session() as session, session.begin():
            record = session.get(self.record_cls, record_id)
            if record is None:
                raise ValueError(
                    f"DataStore does not contain a record with record_id {record_id}"
                )
            session.delete(record)

    def __len__(self):
        """Return the number of records in the store."""
        with self.Session() as session:
            return session.query(self.record_cls).count()

    def __getitem__(self, item):
        """Read one or more columns from the store.

        Parameters
        ----------
        item : str | list[str]
            Column name or list of column names.

        Returns
        -------
        pd.Series | pd.DataFrame
            A Series if `item` is a string, otherwise a DataFrame with the requested
            columns. Rows are ordered by `record_id`.
        """
        # We allow a string or a list of strings as input. If the input is a string
        # we return that column as a pandas series. If the input is a list of
        # strings we return a pandas DataFrame containing those columns. This way
        # the output you get is the same if you do __getitem__ on a DataStore
        # instance or on a pandas DataFrame containing the same data.
        if isinstance(item, str):
            columns = [item]
        else:
            columns = item

        # Always order by record_id. Without ORDER BY, SQLite can return rows in
        # different orders across separate single-column queries (depending on
        # which index the planner picks). Code that positionally joins two such
        # queries, e.g. `record_id` and `included`, would then line up labels with
        # the wrong records.
        table = self.record_cls.__tablename__
        # The column names are interpolated into the SQL text, so they are expected
        # to be trusted identifiers and not arbitrary user input.
        select_cols = ", ".join(f'"{c}"' for c in columns)
        dtype = {c: t for c, t in self.pandas_dtype_mapping.items() if c in columns}
        with self.engine.connect() as con:
            df = pd.read_sql_query(
                f"SELECT {select_cols} FROM {table} ORDER BY record_id",
                con,
                dtype=dtype,
            )

        if isinstance(item, str):
            return df[item]
        return df

    def __contains__(self, item):
        """Return whether `item` is one of the columns of the store."""
        return item in self.columns

    def is_empty(self):
        """Return True if the store contains no records."""
        with self.Session() as session:
            return session.query(self.record_cls).first() is None

    def get_records(self, record_id=None):
        """Get the records with the given record identifiers.

        Parameters
        ----------
        record_id : int | list[int] | None
            Record identifier or list record identifiers. If None, get all records.

        Returns
        -------
        asreview.data.record.Record | list[asreview.data.record.Record] | None
            For a single identifier: the record, or None if it does not exist. For
            a list of identifiers: the records that were found, ordered by the
            position of their identifier in `record_id`. For None: all records.
        """
        # Unwrap numpy integers so that they take the single-record branch below.
        if isinstance(record_id, np.integer):
            record_id = record_id.item()

        with self.Session() as session:
            if record_id is None:
                return session.query(self.record_cls).all()

            if isinstance(record_id, int):
                return (
                    session.query(self.record_cls)
                    .filter(self.record_cls.record_id == record_id)
                    .first()
                )

            # Query in batches so the number of bound parameters stays below the
            # SQLite variable limit.
            records = _batched_in(
                lambda batch: (
                    session.query(self.record_cls)
                    .filter(self.record_cls.record_id.in_(batch))
                    .all()
                ),
                record_id,
            )
            # The database returns the records in its own order; restore the order
            # of the requested identifiers.
            position_by_id = {rid: pos for pos, rid in enumerate(record_id)}
            return sorted(records, key=lambda r: position_by_id[r.record_id])

    def set_groups(self, groups):
        """Add record group information to the data store.

        Parameters
        ----------
        groups : list[tuple[int,int]]
            List of tuples (group_id, record_id). This data is added to the record
            as the `duplicate_of` attribute. The data store will normalize these
            values: One record is chosen as the root, satisfying
            `root.duplicate_of = None`. All other records in the group will get
            `record.duplicate_of = root.record_id`.
        """
        group_to_records = defaultdict(set)
        for group_id, record_id in groups:
            group_to_records[group_id].add(record_id)

        # Every record of a group is pointed at the smallest record_id of that
        # group, including the record with that id itself.
        record_to_group = {}
        for group in group_to_records.values():
            group_id = min(group)
            for record_id in group:
                record_to_group[record_id] = group_id

        # NOTE(review): this queries `Record` directly rather than
        # `self.record_cls` -- confirm this is intended for custom record classes.
        with self.Session() as session, session.begin():
            record_ids = list(record_to_group.keys())
            records = _batched_in(
                lambda batch: session.scalars(
                    select(Record).where(Record.record_id.in_(batch))
                ).all(),
                record_ids,
            )
            for record in records:
                record.duplicate_of = record_to_group[record.record_id]

    def get_groups(self, record_id=None):
        """Get the record groups.

        Parameters
        ----------
        record_id : int | None
            Get only the group containing the record with this record_id.

        Returns
        -------
        list[tuple[int, int]]
            List of tuples (group_id, record_id) ordered by group id. The tuples
            values are also accessible by the attribute names (so `tuple.group_id`
            and `tuple.record_id`).
        """
        # NOTE(review): this queries `Record` directly rather than
        # `self.record_cls` -- confirm this is intended for custom record classes.
        stmt = select(
            Record.group_id,
            Record.record_id,
        )
        if record_id is not None:
            # Get the records in the group of the record with the given record_id.
            target_group_subq = (
                select(Record.group_id)
                .where(Record.record_id == record_id)
                .scalar_subquery()
            )
            stmt = stmt.where(Record.group_id == target_group_subq)

        with self.Session() as session:
            return session.execute(stmt.order_by("group_id")).all()

    def get_df(self):
        """Get all data from the data store as a pandas DataFrame.

        Returns
        -------
        pd.DataFrame
            All columns of the record table, with rows ordered by `record_id`.
        """
        with self.engine.connect() as con:
            return pd.read_sql_query(
                f"SELECT * FROM {self.record_cls.__tablename__} ORDER BY record_id",
                con,
                dtype=self.pandas_dtype_mapping,
            )