# Source code for asreview.database.database

import json
import sqlite3
import time
from functools import cached_property

import pandas as pd

from asreview.data.record import Record
from asreview.database.store import DataStore
from asreview.database.store import _build_conn_uri

# Public API of this module: `from asreview.database.database import *` only
# exports the `Database` class (`open_db` must be imported explicitly).
__all__ = ["Database"]

# Schema version stored in SQLite's `PRAGMA user_version` when the tables are
# created; databases carrying any other version are rejected as unsupported.
CURRENT_DATABASE_VERSION = 3

# Columns describing the model that produced a ranking. They appear, with the
# same pandas dtypes, in both the `results` and the `last_ranking` tables.
_MODEL_COLUMN_DTYPES = {
    "classifier": "object",
    "querier": "object",
    "balancer": "object",
    "feature_extractor": "object",
    "training_set": "Int64",
}

# Model columns in table order; copied into `results` when a record is labeled
# or queried.
MODEL_COLUMNS = list(_MODEL_COLUMN_DTYPES)

# Tables (besides the input record table) a valid database must contain.
REQUIRED_TABLES = ["results", "last_ranking", "decision_changes"]

# Expected columns of the `results` table, in table order, mapped to the
# (nullable) pandas dtypes used when reading them back.
RESULTS_TABLE_COLUMNS_PANDAS_DTYPES = {
    "record_id": "Int64",
    "label": "Int64",
    **_MODEL_COLUMN_DTYPES,
    "time": "Float64",
    "note": "object",
    "tags": "object",
    "user_id": "Int64",
}

# Expected columns of the `last_ranking` table, in table order, mapped to the
# pandas dtypes used when reading them back.
RANKING_TABLE_COLUMNS_PANDAS_DTYPES = {
    "record_id": "Int64",
    "ranking": "Int64",
    **_MODEL_COLUMN_DTYPES,
    "time": "Float64",
}

# Idempotent DDL for `last_ranking` (IF NOT EXISTS), so it can be re-run both
# when creating a database and to recreate a table lost by an interrupted write.
CREATE_LAST_RANKING_TABLE_SQL = """CREATE TABLE IF NOT EXISTS last_ranking
    (record_id INTEGER UNIQUE,
    ranking INT,
    classifier TEXT,
    querier TEXT,
    balancer TEXT,
    feature_extractor TEXT,
    training_set INTEGER,
    time FLOAT)"""


def open_db(fp, read_only=False):
    """Open a database.

    Parameters
    ----------
    fp : path-like
        File path to the database. Both `str` and `pathlib.Path` are
        accepted.
    read_only : bool, optional
        Whether to create a new database if one doesn't exist yet and whether
        the opened database will be in read only mode or not.

    Returns
    -------
    Database
        ASReview database.

    Raises
    ------
    FileNotFoundError
        If `read_only` and there is no file at `fp`.
    ValueError
        If `read_only` and there is no valid database at `fp`.
    """
    from pathlib import Path

    # Only the filesystem checks need a Path; `fp` itself is handed to
    # `Database` unchanged.
    path = Path(fp)
    if not path.is_file():
        if read_only:
            raise FileNotFoundError(
                f"File path {fp} is not a file and 'read_only' is 'True'"
            )
        path.parent.mkdir(parents=True, exist_ok=True)

    db = Database(fp, read_only=read_only)
    try:
        try:
            db._is_valid()
        except ValueError as e:
            if read_only:
                raise ValueError(
                    f"There is no valid database at {fp} and the database is opened in"
                    " read-only mode"
                ) from e
            # Writable and not (yet) a valid database: initialise the schema.
            db.create_tables()
    except Exception:
        # The caller never receives `db` on failure, so release its
        # connections here instead of leaking them until garbage collection.
        db.close()
        raise
    return db
class Database:
    """Database containing the input data and results.

    Database contains two parts: the input and the results. For more
    information on the input, see `asreview.database.store.py`. For more
    information on the results, see `asreview.database.sqlstate.py`.

    The input part is handled by a `DataStore` (attribute ``input``). The
    results part lives in the ``results``, ``last_ranking`` and
    ``decision_changes`` tables. It is accessed through a sqlite3 connection
    opened on the same connection URI as the `DataStore`.

    The object is a context manager: leaving a ``with`` block calls `close`.

    Attributes
    ----------
    user_version: int
        Version number of the database (SQLite ``PRAGMA user_version``).
    """

    def __init__(self, fp=":memory:", record_cls=Record, read_only=False):
        """Initialize the Database.

        Parameters
        ----------
        fp : str | Path
            Path of the database file. Use `":memory:"` for an in-memory
            database.
        record_cls : type[asreview.data.record.Base], optional
            Type to use for the input records, see `DataStore` for more
            information.
        read_only : bool, optional
            Whether to open the database in read only mode. If the database is
            opened in read only mode and an attempt to write to the database is
            made, an `sqlite3.OperationalError` will be raised.

        Raises
        ------
        ValueError
            If an in-memory database is requested in read only mode.
        """
        if fp == ":memory:" and read_only:
            raise ValueError("Can't open an in-memory database in read only mode")

        self.fp = fp
        self.record_cls = record_cls
        self.read_only = read_only
        self._in_memory = fp == ":memory:"
        self._closed = False
        self._conn_uri = _build_conn_uri(fp, read_only)
        self.input = DataStore(
            conn_uri=self._conn_uri, record_cls=record_cls, read_only=read_only
        )

        if self._in_memory:
            # Eagerly open the sqlite3 connection. For named in-memory
            # databases, the database is destroyed when the last connection to
            # it closes. This connection acts as an anchor that keeps the
            # database alive for the lifetime of this object.
            self._conn

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.close()

    def __del__(self):
        # Best-effort cleanup. `close` tolerates instances whose `__init__`
        # raised part-way.
        self.close()

    @cached_property
    def _conn(self):
        """Get a connection to the SQLite database.

        The connection is opened lazily on first access and then cached on
        the instance until `close` is called.

        Returns
        -------
        sqlite3.Connection
            Connection to the SQLite database.
        """
        return sqlite3.connect(self._conn_uri, uri=True)

    def close(self):
        """Close the database and release all resources.

        For in-memory databases this will destroy the database. Safe to call
        multiple times, and safe to call on an instance whose ``__init__``
        raised before finishing.
        """
        # `_closed` is missing when `__init__` raised before assigning it
        # (e.g. the in-memory + read-only check). There is nothing to release.
        if getattr(self, "_closed", True):
            return
        self._closed = True

        try:
            store = getattr(self, "input", None)
            if store is not None:
                store.engine.dispose()
        finally:
            # `_conn` is a cached_property, so a connection only exists in the
            # instance dict if it was ever opened.
            conn = self.__dict__.pop("_conn", None)
            if conn is not None:
                conn.close()

    @property
    def user_version(self):
        """Version number of the state (SQLite ``PRAGMA user_version``)."""
        cur = self._conn.cursor()
        try:
            version = cur.execute("PRAGMA user_version").fetchone()[0]
        finally:
            cur.close()
        return int(version)

    @user_version.setter
    def user_version(self, version):
        # PRAGMA statements cannot take bound parameters, so the value is
        # interpolated. Coercing to int keeps arbitrary text out of the SQL.
        version = int(version)
        cur = self._conn.cursor()
        try:
            cur.execute(f"PRAGMA user_version = {version}")
            self._conn.commit()
        finally:
            cur.close()

    def create_tables(self):
        """Create the input tables, the results tables and their triggers.

        The schema version is written to ``PRAGMA user_version`` only after
        every table and trigger was created. If creation fails part-way (for
        example because ``results`` already exists in an older database), the
        database is therefore not stamped with the current version.

        Raises
        ------
        sqlite3.OperationalError
            If one of the results tables already exists.
        """
        self.input.create_tables()

        cur = self._conn.cursor()
        try:
            cur.execute(
                """CREATE TABLE results
                (record_id INTEGER UNIQUE,
                label INTEGER,
                classifier TEXT,
                querier TEXT,
                balancer TEXT,
                feature_extractor TEXT,
                training_set INTEGER,
                time FLOAT,
                note TEXT,
                tags JSON,
                user_id INTEGER)"""
            )
            cur.execute(CREATE_LAST_RANKING_TABLE_SQL)
            cur.execute(
                """CREATE TABLE decision_changes
                (record_id INTEGER,
                label INTEGER,
                time FLOAT,
                user_id INTEGER)"""
            )
            self._conn.commit()
        finally:
            cur.close()

        self._set_results_changes_triggers()
        self.user_version = CURRENT_DATABASE_VERSION

    def _is_valid(self):
        """Check that the database has the supported version and schema.

        When the database is writable, a legacy ``decision_changes`` schema
        is also upgraded in place.

        Raises
        ------
        ValueError
            If the version differs from `CURRENT_DATABASE_VERSION`, a required
            table is missing, or the ``results`` table lacks expected columns.
        """
        if self.user_version != CURRENT_DATABASE_VERSION:
            raise ValueError(
                f"Database version {self.user_version} is not supported. "
                "See migration guide."
            )

        cur = self._conn.cursor()
        column_names = cur.execute("PRAGMA table_info(results)").fetchall()
        table_names = cur.execute(
            "SELECT name FROM sqlite_master WHERE type='table';"
        ).fetchall()
        table_names = [tup[0] for tup in table_names]

        missing_tables = [
            table
            for table in REQUIRED_TABLES + [self.record_table_name]
            if table not in table_names
        ]
        if missing_tables:
            raise ValueError(
                f"The SQL file should contain tables named "
                f"'{' '.join(missing_tables)}'."
            )

        # `PRAGMA table_info` rows are (cid, name, type, ...).
        column_names = [tup[1] for tup in column_names]
        missing_columns = [
            col
            for col in RESULTS_TABLE_COLUMNS_PANDAS_DTYPES.keys()
            if col not in column_names
        ]
        if missing_columns:
            raise ValueError(
                f"The results table does not contain the columns "
                f"{' '.join(missing_columns)}."
            )

        if not self.read_only:
            self._fix_decision_changes_schema(cur)

    def _fix_decision_changes_schema(self, cur):
        """Fix decision_changes schema for projects migrated from old v2 format.

        Old v2 projects had (record_id, new_label, time) instead of
        (record_id, label, time, user_id). Projects already migrated to v3
        may still carry the old schema.
        """
        columns = [row[1] for row in cur.execute("PRAGMA table_info(decision_changes)")]
        if "new_label" in columns and "label" not in columns:
            cur.execute("ALTER TABLE decision_changes RENAME COLUMN new_label TO label")
        # `columns` still reflects the pre-rename schema, which is fine here:
        # the rename never adds `user_id`.
        if "user_id" not in columns:
            cur.execute("ALTER TABLE decision_changes ADD COLUMN user_id INTEGER")
        self._conn.commit()

    def _set_results_changes_triggers(self):
        """Install triggers that log removed or relabeled results.

        Deleting a row from ``results``, or changing its label, copies the old
        (record_id, label, time, user_id) into ``decision_changes``.
        """
        con = self._conn
        cur = con.cursor()
        cur.execute("""
            CREATE TRIGGER IF NOT EXISTS trg_results_delete
            AFTER DELETE ON results
            FOR EACH ROW
            BEGIN
                INSERT INTO decision_changes (record_id, label, time, user_id)
                VALUES (OLD.record_id, OLD.label, OLD.time, OLD.user_id);
            END
        """)
        cur.execute("""
            CREATE TRIGGER IF NOT EXISTS trg_results_label_update
            AFTER UPDATE OF label ON results
            FOR EACH ROW
            WHEN OLD.label IS NOT NEW.label
            BEGIN
                INSERT INTO decision_changes (record_id, label, time, user_id)
                VALUES (OLD.record_id, OLD.label, OLD.time, OLD.user_id);
            END
        """)
        con.commit()

    @property
    def record_table_name(self):
        """Name of the SQL table holding the input records."""
        return self.input.record_cls.__tablename__

    @property
    def exist_new_labeled_records(self):
        """Return True if there are new labeled records.

        Return True if there are any record labels added since the last time
        the model ranking was added to the state. Also returns True if no
        model was trained yet, but priors have been added.
        """
        labeled = self.get_results_table("label")
        last_training_set = self.get_last_ranking_table()["training_set"]
        if last_training_set.empty or pd.isna(last_training_set.max()):
            return len(labeled) > 0
        else:
            return len(labeled) > last_training_set.max()

    def _replace_results_from_df(self, results):
        """Replace the whole ``results`` table with the rows of `results`.

        The delete and the insert happen in one transaction. If inserting
        fails, the previous results are restored instead of being lost. When
        the ``trg_results_delete`` trigger is installed, every deleted row is
        also logged in ``decision_changes``.

        Raises
        ------
        ValueError
            If the dataframe columns differ from the results table columns.
        """
        if not set(results.columns) == set(RESULTS_TABLE_COLUMNS_PANDAS_DTYPES):
            raise ValueError(
                f"Columns of the results dataframe should be "
                f"{list(RESULTS_TABLE_COLUMNS_PANDAS_DTYPES.keys())}."
            )

        con = self._conn
        try:
            con.execute("delete from results")
            results.to_sql("results", con, if_exists="append", index=False)
            con.commit()
        except Exception:
            con.rollback()
            raise

    def _replace_last_ranking_from_df(self, last_ranking):
        """Replace the ``last_ranking`` table with the rows of `last_ranking`.

        Raises
        ------
        ValueError
            If the dataframe columns differ from the last ranking columns.
        """
        if not set(last_ranking.columns) == set(RANKING_TABLE_COLUMNS_PANDAS_DTYPES):
            raise ValueError(
                f"Columns of the last ranking dataframe should be "
                f"{list(RANKING_TABLE_COLUMNS_PANDAS_DTYPES.keys())}."
            )
        self._write_last_ranking(last_ranking)

    def _write_last_ranking(self, df):
        """Atomically replace the contents of the `last_ranking` table.

        The table is never dropped. The rows are replaced inside a single
        `BEGIN IMMEDIATE` transaction, so a "database is locked" error rolls
        back cleanly and leaves the previous ranking intact instead of a
        missing table. The `CREATE TABLE IF NOT EXISTS` also self-heals a
        database whose `last_ranking` table was lost by an earlier interrupted
        write.
        """
        columns = list(RANKING_TABLE_COLUMNS_PANDAS_DTYPES)
        # itertuples yields native Python types (avoids numpy adapter issues
        # with sqlite3), and the explicit column order guarantees value
        # alignment.
        rows = list(df[columns].itertuples(index=False, name=None))
        col_list = ", ".join(columns)
        placeholders = ", ".join(["?"] * len(columns))
        insert_sql = f"INSERT INTO last_ranking ({col_list}) VALUES ({placeholders})"

        con = self._conn
        cur = con.cursor()
        try:
            cur.execute("BEGIN IMMEDIATE")
            cur.execute(CREATE_LAST_RANKING_TABLE_SQL)
            cur.execute("DELETE FROM last_ranking")
            cur.executemany(insert_sql, rows)
            con.commit()
        except Exception:
            con.rollback()
            raise
        finally:
            cur.close()

    def add_last_ranking(
        self,
        ranked_record_ids,
        classifier,
        querier,
        balancer,
        feature_extractor,
        training_set=None,
    ):
        """Save the ranking of the last iteration of the model.

        Save the ranking of the last iteration of the model, in the ranking
        order, so the record on row 0 is ranked first by the model. Any
        previously stored ranking is replaced.

        Parameters
        ----------
        ranked_record_ids: list, numpy.ndarray
            A list of records ids in the order that they were ranked.
        classifier: str
            Name of the classifier of the model.
        querier: str
            Name of the query strategy of the model.
        balancer: str
            Name of the balance strategy of the model.
        feature_extractor: str
            Name of the feature extraction method of the model.
        training_set: int
            Number of labeled records available at the time of training.
        """
        self._write_last_ranking(
            pd.DataFrame(
                {
                    "record_id": ranked_record_ids,
                    "ranking": range(len(ranked_record_ids)),
                    "classifier": classifier,
                    "querier": querier,
                    "balancer": balancer,
                    "feature_extractor": feature_extractor,
                    "training_set": training_set,
                    "time": time.time(),
                }
            )
        )

    def get_last_ranking_table(self):
        """Get the ranking from the state.

        Returns
        -------
        pd.DataFrame
            Dataframe with columns 'record_id', 'ranking', 'classifier',
            'querier', 'balancer', 'feature_extractor', 'training_set' and
            'time'. It has one row for each record in the dataset, and is
            ordered by ranking.
        """
        # Self-heal a `last_ranking` table that was lost by an earlier
        # interrupted write. Reads (and `exist_new_labeled_records`, which
        # runs before the write path in `run_model`) then keep working on an
        # already-broken database.
        cur = self._conn.cursor()
        cur.execute(CREATE_LAST_RANKING_TABLE_SQL)
        self._conn.commit()
        cur.close()

        return pd.read_sql_query(
            "SELECT * FROM last_ranking",
            self._conn,
            dtype=RANKING_TABLE_COLUMNS_PANDAS_DTYPES,
        )

    def label_record(self, record_id, label, tags=None, user_id=None):
        """Label a record together with all records in its group.

        Every record sharing the ``group_id`` of `record_id` gets a results
        row (inserted or updated) with the given label, tags, user and the
        current time. The model columns are copied from the existing results
        row of `record_id`; they are NULL if that row does not exist.

        Parameters
        ----------
        record_id: int
            Id of the record to label.
        label: int
            Label to assign.
        tags: object, optional
            JSON-serializable tags; stored as a JSON string.
        user_id: int, optional
            Id of the user who labeled the record.
        """
        if tags is not None:
            tags = json.dumps(tags)
        labeling_time = time.time()

        con = self._conn
        cur = con.cursor()

        model_string = ", ".join(MODEL_COLUMNS)
        target_result_string = ", ".join(
            f"target_result.{col}" for col in MODEL_COLUMNS
        )
        upsert_columns = ["label", "time", "tags", "user_id"] + MODEL_COLUMNS
        upsert_string = ", ".join(f"{col} = excluded.{col}" for col in upsert_columns)

        cur.execute(
            f"""
            WITH target_group AS (
                SELECT record_id
                FROM {self.record_table_name}
                WHERE group_id = (
                    SELECT group_id
                    FROM {self.record_table_name}
                    WHERE record_id=:record_id
                )
            ),
            target_result AS (
                SELECT {model_string}
                FROM results
                WHERE record_id = :record_id
            )
            INSERT INTO results(record_id, label, time, tags, user_id, {model_string})
            SELECT target_group.record_id, :label, :time, :tags, :user_id,
                {target_result_string}
            FROM target_group
            LEFT JOIN target_result ON 1
            ON CONFLICT(record_id) DO UPDATE SET {upsert_string};
            """,
            {
                "record_id": record_id,
                "label": label,
                "time": labeling_time,
                "tags": tags,
                "user_id": user_id,
            },
        )

        con.commit()

    def query_top_ranked(self, user_id=None):
        """Mark the highest ranked record without a result as pending.

        A results row with a NULL label is inserted for the top ranked record
        of ``last_ranking`` that has no results row yet, and for every record
        in its group. The model columns are copied from ``last_ranking``.

        Parameters
        ----------
        user_id: int, optional
            Id of the user the pending records are assigned to.

        Returns
        -------
        pd.DataFrame
            The pending records of `user_id`, see `get_pending`.

        Raises
        ------
        ValueError
            If no record could be queried.
        """
        model_string = ", ".join(MODEL_COLUMNS)
        top_record_string = ", ".join(f"top_record.{col}" for col in MODEL_COLUMNS)
        record_table = self.record_table_name

        con = self._conn
        cur = con.cursor()
        cur.execute(
            f"""INSERT INTO results (record_id, user_id, {model_string})
            WITH top_record AS (
                SELECT last_ranking.*
                FROM last_ranking
                LEFT JOIN results
                USING (record_id)
                WHERE results.record_id IS NULL
                ORDER BY ranking
                LIMIT 1
            ),
            group_records AS (
                SELECT rec.record_id
                FROM {record_table} AS rec
                WHERE rec.group_id = (
                    SELECT grp.group_id
                    FROM {record_table} AS grp
                    WHERE grp.record_id = (SELECT record_id FROM top_record)
                )
            )
            SELECT group_records.record_id, :user_id, {top_record_string}
            FROM group_records
            CROSS JOIN top_record;""",
            {"user_id": user_id},
        )
        con.commit()

        if not cur.rowcount:
            raise ValueError("Failed to query top ranked record")

        return self.get_pending(user_id=user_id)

    def update_result(self, record_id, label=None, tags=None, user_id=None):
        """Update the label and/or tags of a record and all records in its group.

        Parameters
        ----------
        record_id: int
            Id of the record whose result should be updated.
        label: int, optional
            New label.
        tags: object, optional
            New JSON-serializable tags; stored as a JSON string.
        user_id: int, optional
            Id of the user; only written when a new `label` is given.

        Raises
        ------
        ValueError
            If neither `label` nor `tags` is given.
        """
        if label is None and tags is None:
            raise ValueError("At least one of 'label' or 'tags' must be provided.")

        fields = []
        values = {"record_id": record_id}

        if label is not None:
            fields.append("label = :label")
            values["label"] = label
            if user_id is not None:
                # The user_id is only updated together with a new label.
                fields.append("user_id = :user_id")
                values["user_id"] = user_id

        if tags is not None:
            fields.append("tags = :tags")
            values["tags"] = json.dumps(tags)

        set_string = ", ".join(fields)

        con = self._conn
        cur = con.cursor()
        cur.execute(
            f"""
            WITH target_group AS (
                SELECT record_id
                FROM {self.record_table_name}
                WHERE group_id = (
                    SELECT group_id
                    FROM {self.record_table_name}
                    WHERE record_id=:record_id
                )
            )
            UPDATE results
            SET {set_string}
            WHERE record_id IN (SELECT record_id FROM target_group)
            """,
            values,
        )
        con.commit()

    def update_note(self, record_id, note=None):
        """Change the note of an already labeled or pending record.

        The note is set on every record in the group of `record_id`.

        Parameters
        ----------
        record_id: int
            Id of the record whose label should be changed.
        note: str
            Note to add to the record.

        Raises
        ------
        ValueError
            If no results row exists for the record's group.
        """
        cur = self._conn.cursor()
        cur.execute(
            f"""
            WITH target_group AS (
                SELECT record_id
                FROM {self.record_table_name}
                WHERE group_id = (
                    SELECT group_id
                    FROM {self.record_table_name}
                    WHERE record_id=:record_id
                )
            )
            UPDATE results
            SET note = :note
            WHERE record_id IN (
                SELECT record_id FROM target_group
            )""",
            {"note": note, "record_id": record_id},
        )

        if cur.rowcount == 0:
            # In sqlite3's default (legacy) transaction mode, the UPDATE
            # implicitly opened a transaction. End it so the connection does
            # not keep holding the write lock after we raise.
            self._conn.rollback()
            raise ValueError(f"Record with id {record_id} not found.")

        self._conn.commit()

    def delete_result(self, record_id):
        """Delete the results rows of a record and all records in its group.

        Parameters
        ----------
        record_id: int
            Id of the record whose result should be deleted.
        """
        con = self._conn
        cur = con.cursor()
        cur.execute(
            f"""
            WITH target_group AS (
                SELECT record_id
                FROM {self.record_table_name}
                WHERE group_id = (
                    SELECT group_id
                    FROM {self.record_table_name}
                    WHERE record_id=:record_id
                )
            )
            DELETE FROM results
            WHERE record_id IN (SELECT record_id FROM target_group)
            """,
            {"record_id": record_id},
        )
        con.commit()

    def get_results_record(self, record_id):
        """Get the data of a specific query from the results table.

        Parameters
        ----------
        record_id: int
            Record id of which you want the data.

        Returns
        -------
        pd.DataFrame
            Dataframe containing the data from the results table with the
            given record_id and columns.
        """
        # Bind the id as a parameter instead of formatting it into the SQL.
        # int() also converts numpy integers, which sqlite3 cannot bind.
        result = pd.read_sql_query(
            "SELECT * FROM results WHERE record_id=?",
            self._conn,
            params=(int(record_id),),
            dtype=RESULTS_TABLE_COLUMNS_PANDAS_DTYPES,
        )
        result["tags"] = result["tags"].map(json.loads, na_action="ignore")
        return result

    def get_results_table(self, columns=None, priors=True, pending=False, groups=False):
        """Get a subset from the results table.

        Can be used to get any column subset from the results table. Most
        other get functions use this one, except some that use a direct SQL
        query for efficiency.

        Parameters
        ----------
        columns: list, str
            List of columns names of the results table, or a string containing
            one column name.
        priors: bool
            Whether to keep the records containing the prior knowledge.
        pending: bool
            Whether to keep the records which are pending a labeling decision.
        groups: bool
            Return all the records of a group of records. By default only
            returns the base record of each group.

        Returns
        -------
        pd.DataFrame:
            Dataframe containing the data of the specified columns of the
            results table, in insertion order.
        """
        if isinstance(columns, str):
            columns = [columns]

        if (not priors) or (not pending) or (not groups):
            sql_where = []
            if not priors:
                sql_where.append("querier is not NULL")
            if not pending:
                sql_where.append("label is not NULL")
            if not groups:
                sql_where.append(
                    f"record_id IN ( SELECT group_id FROM {self.record_table_name})"
                )

            sql_where_str = "WHERE " + " AND ".join(sql_where)
        else:
            sql_where_str = ""

        if columns is None:
            col_dtype = RESULTS_TABLE_COLUMNS_PANDAS_DTYPES
        else:
            col_dtype = {
                k: v
                for k, v in RESULTS_TABLE_COLUMNS_PANDAS_DTYPES.items()
                if columns and k in columns
            }

        query_string = "*" if columns is None else ",".join(columns)

        df_results = pd.read_sql_query(
            f"SELECT {query_string} FROM results {sql_where_str} ORDER BY rowid",
            self._conn,
            dtype=col_dtype,
        )

        if columns is None or "tags" in columns:
            df_results["tags"] = df_results["tags"].map(json.loads, na_action="ignore")

        return df_results

    def get_priors(self):
        """Get the record ids of the priors.

        Returns
        -------
        pd.DataFrame:
            The result records of the priors in the order they were added. If
            multiple records are in the same group, only the base record of
            the group is returned.
        """
        df_results = pd.read_sql_query(
            f"""
            SELECT * FROM results
            WHERE results.querier is NULL
                AND results.label is not NULL
                AND record_id IN (
                    SELECT group_id FROM {self.record_table_name}
                )
            ORDER BY rowid
            """,
            self._conn,
            dtype=RESULTS_TABLE_COLUMNS_PANDAS_DTYPES,
        )
        df_results["tags"] = df_results["tags"].map(json.loads, na_action="ignore")
        return df_results

    def get_pool(self):
        """Get the unlabeled, not-pending records in ranking order.

        Returns
        -------
        pd.Series
            Series containing the record_ids of the unlabeled, not pending
            records, in the order of the last available ranking. If the state
            does not yet contain a last ranking, the returned Series is empty.
            If multiple records are in the same group, only the base record of
            the group is returned.
        """
        return pd.read_sql_query(
            f"""SELECT record_id, last_ranking.ranking
            FROM last_ranking
            LEFT JOIN results
            USING (record_id)
            WHERE results.record_id is null
                AND last_ranking.record_id IN (
                    SELECT group_id FROM {self.record_table_name}
                )
            ORDER BY ranking
            """,
            self._conn,
        )["record_id"]

    def get_unlabeled(self, groups=False):
        """Get the unlabeled record ids in ranking order.

        Records that have no label or no entry in the results table are
        considered unlabeled.

        Parameters
        ----------
        groups : bool
            If True, return all records in each unlabeled group. If False,
            return only group representatives (record_id == group_id).

        Returns
        -------
        pd.Series
            Series of record_ids of unlabeled records ordered by ranking.
        """
        if groups:
            sql_group_filter = ""
        else:
            sql_group_filter = (
                f"AND record_id IN (SELECT group_id FROM {self.record_table_name})"
            )

        return pd.read_sql_query(
            f"""SELECT record_id, last_ranking.ranking
            FROM last_ranking
            JOIN {self.record_table_name} USING (record_id)
            LEFT JOIN results USING (record_id)
            WHERE (results.record_id IS NULL OR results.label IS NULL)
            {sql_group_filter}
            ORDER BY ranking
            """,
            self._conn,
        )["record_id"]

    def get_pending(self, user_id=None):
        """Get pending records from the results table.

        Only the base record of each group is returned, in insertion order.

        Parameters
        ----------
        user_id: int
            User id of the user who labeled the records. If None, pending
            records of all users are returned.

        Returns
        -------
        pd.DataFrame
            DataFrame with pending results records.
        """
        query = f"""SELECT * FROM results
            WHERE label is null
            AND record_id IN (
                SELECT group_id FROM {self.record_table_name}
            )"""
        params = None

        if user_id is not None:
            query += " AND user_id=?"
            params = (user_id,)

        query += " ORDER BY rowid"

        return pd.read_sql_query(
            query,
            self._conn,
            params=params,
            dtype=RESULTS_TABLE_COLUMNS_PANDAS_DTYPES,
        )

    def get_decision_changes(self):
        """Get the record ids for any decision changes.

        Get the record ids of the records whose labels have been changed
        after the original labeling action.

        Returns
        -------
        pd.DataFrame
            Dataframe with columns 'record_id', 'label', 'time', and 'user_id'
            for each record of which the labeling decision was changed.
        """
        return pd.read_sql_query("SELECT * FROM decision_changes", self._conn)