Source code for crate_anon.anonymise.subset_db

#!/usr/bin/env python

"""
crate_anon/anonymise/subset_db.py

===============================================================================

    Copyright (C) 2015, University of Cambridge, Department of Psychiatry.
    Created by Rudolf Cardinal (rnc1001@cam.ac.uk).

    This file is part of CRATE.

    CRATE is free software: you can redistribute it and/or modify
    it under the terms of the GNU General Public License as published by
    the Free Software Foundation, either version 3 of the License, or
    (at your option) any later version.

    CRATE is distributed in the hope that it will be useful,
    but WITHOUT ANY WARRANTY; without even the implied warranty of
    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
    GNU General Public License for more details.

    You should have received a copy of the GNU General Public License
    along with CRATE. If not, see <https://www.gnu.org/licenses/>.

===============================================================================

**Create a simple subset of a database.**

"""

import argparse
from dataclasses import dataclass
import logging
from typing import Any, Generator, List, Set, Union

from cardinal_pythonlib.argparse_func import str2bool
from cardinal_pythonlib.file_io import gen_lines_without_comments
from cardinal_pythonlib.logs import main_only_quicksetup_rootlogger
from sqlalchemy.engine.url import make_url

from sqlalchemy.engine import Row
from sqlalchemy.sql.expression import column, select, table
from sqlalchemy.schema import Table

from crate_anon.anonymise.dbholder import DatabaseHolder
from crate_anon.common.argparse_assist import (
    RawDescriptionArgumentDefaultsRichHelpFormatter,
)
from crate_anon.version import CRATE_VERSION_PRETTY

# Module-level logger, named after this module.
log = logging.getLogger(__name__)


# =============================================================================
# Constants
# =============================================================================

# Suffix appended to the help text of the boolean-valued command-line options
# (those parsed with str2bool), describing the accepted spellings.
BOOLHELP = " (Specify as yes/y/true/t/1 or no/n/false/f/0.)"
# Suffix appended to the help text of the "include tables" command-line
# options, describing how include/exclude lists interact.
INCHELP = (
    " (If 'include' tables are given, only tables explicitly named are "
    "included. If no 'include' tables are specified, all tables are included "
    "by default. Explicit excluding always overrides including.)"
)


class SubsetDefaults:
    """
    Default values for subsetting options, used both as defaults for
    ``SubsetConfig`` arguments and for the command-line parser.
    """

    # Include rows whose filter column is NULL?
    INC_IF_FILTERCOL_NULL = False
    # Include tables that do not possess the filter column?
    INC_TABLES_NO_FILTERCOL = True


# =============================================================================
# Helper functions
# =============================================================================


def to_str(x: Any) -> Union[str, None]:
    """
    Return ``str(x)``, except that ``None`` is passed through unchanged.
    """
    if x is None:
        return None
    return str(x)
# =============================================================================
# Data classes
# =============================================================================
@dataclass
class DatabaseFilterSource:
    """
    Describes a single database column from which permitted filter values
    are read.
    """

    name: str  # name given to the DatabaseHolder
    url: str  # URL given to the DatabaseHolder
    table: str  # name of the table to read from
    column: str  # name of the column to read from
    echo: bool = False  # passed to the DatabaseHolder as "echo"

    def gen_values(self) -> Generator[str, None, None]:
        """
        Generate every value found in the configured table/column, each
        passed through :func:`to_str` (so a ``None`` value stays ``None``
        and everything else becomes a string).
        """
        holder = DatabaseHolder(
            name=self.name, url=self.url, with_session=True, echo=self.echo
        )
        source_column = column(self.column)
        source_table = table(self.table)
        stmt = select(source_column).select_from(source_table)
        for record in holder.session.execute(stmt):
            yield to_str(record[0])
# =============================================================================
# Config
# =============================================================================
class SubsetConfig:
    """
    Simple configuration class for subsetting databases.
    """

    def __init__(
        self,
        src_db_url: str,
        dst_db_url: str,
        filter_column: str = None,
        filter_values: List[str] = None,
        filter_value_filenames: List[str] = None,
        filter_value_db_urls: List[str] = None,
        filter_value_tablecols: List[str] = None,
        include_rows_filtercol_null: bool = (
            SubsetDefaults.INC_IF_FILTERCOL_NULL
        ),
        include_tables_without_filtercol: bool = (
            SubsetDefaults.INC_TABLES_NO_FILTERCOL
        ),
        include_tables: List[str] = None,
        include_table_filenames: List[str] = None,
        exclude_tables: List[str] = None,
        exclude_table_filenames: List[str] = None,
        echo: bool = False,
    ) -> None:
        """
        Build the configuration. Filter values and table names are collected
        here, at construction time, from all the sources given (including
        reading any files and querying any filter-value databases).

        Args:
            src_db_url:
                SQLAlchemy URL for the source database.
            dst_db_url:
                SQLAlchemy URL for the destination database.
            filter_column:
                Name of column to filter on (e.g. "patient_id"). If blank,
                might copy everything.
            filter_values:
                Values, treated as strings, to accept.
            filter_value_filenames:
                Filename(s), containing values, treated as strings, to accept.
            filter_value_db_urls:
                SQLAlchemy URL(s) of databases to pull additional filter
                values from. Must be the same length as, and in the same
                order as, ``filter_value_tablecols``.
            filter_value_tablecols:
                Table/column pairs, each expressed as "table.column", naming
                the database columns to pull additional filter values from.
                Must correspond one-to-one with ``filter_value_db_urls``.
            include_rows_filtercol_null:
                Allow the filter column to be NULL as well?
            include_tables_without_filtercol:
                Include tables that don't possess the filter column (e.g.
                system/lookup tables)?
            include_tables:
                Specific named tables to include.
            include_table_filenames:
                Filename(s), containing specific named tables to include.
            exclude_tables:
                Specific named tables to exclude.
            exclude_table_filenames:
                Filename(s), containing specific named tables to exclude.
            echo:
                Echo SQL (debugging only)?

        Raises:
            ValueError:
                if ``filter_value_db_urls`` and ``filter_value_tablecols``
                differ in length, or if any entry of
                ``filter_value_tablecols`` is not of the form "table.column".
        """
        # Treat None (argument not supplied) as an empty list throughout.
        filter_values = filter_values or []
        filter_value_filenames = filter_value_filenames or []
        filter_value_db_urls = filter_value_db_urls or []
        filter_value_tablecols = filter_value_tablecols or []
        include_tables = include_tables or []
        include_table_filenames = include_table_filenames or []
        exclude_tables = exclude_tables or []
        exclude_table_filenames = exclude_table_filenames or []

        self.src_db_url = src_db_url
        self.dst_db_url = dst_db_url
        self.filter_column = filter_column
        self.include_rows_filtercol_null = include_rows_filtercol_null
        self.include_tables_without_filtercol = (
            include_tables_without_filtercol
        )
        self.echo = echo

        # Parse database filter sources
        if len(filter_value_db_urls) != len(filter_value_tablecols):
            raise ValueError(
                f"filter_value_db_urls (length {len(filter_value_db_urls)}) "
                f"must be the same length as filter_value_tablecols (length "
                f"{len(filter_value_tablecols)})"
            )
        dbfilters = []  # type: List[DatabaseFilterSource]
        for i, (url, tablecol) in enumerate(
            zip(filter_value_db_urls, filter_value_tablecols), start=1
        ):
            tcvals = tablecol.split(".")
            if len(tcvals) != 2:
                raise ValueError(
                    f"Arguments to 'filter_value_tablecols' must be in the "
                    f"format 'table.column' but one is: {tablecol!r}"
                )
            dbfilters.append(
                DatabaseFilterSource(
                    name=f"filtersource_{i}",
                    url=url,
                    table=tcvals[0],
                    column=tcvals[1],
                    echo=self.echo,
                )
            )

        # Fetch filter values.
        # - No conversion to string required for command-line arguments, or
        #   file sources, which come to us as strings anyway.
        # - But conversion required for database sources, which is
        #   performed by the DatabaseFilterSource generator.
        self.filter_values = set(
            filter_values
        )  # type: Set[Union[str, None]]
        for filename in filter_value_filenames:
            self.filter_values.update(gen_lines_without_comments(filename))
        for dbfiltersource in dbfilters:
            self.filter_values.update(dbfiltersource.gen_values())
        # Permit NULL?
        if self.include_rows_filtercol_null:
            self.filter_values.add(None)

        # Fetch "include" tables:
        self.include_tables = set(include_tables)  # type: Set[str]
        for filename in include_table_filenames:
            self.include_tables.update(gen_lines_without_comments(filename))

        # Fetch "exclude" tables:
        self.exclude_tables = set(exclude_tables)  # type: Set[str]
        for filename in exclude_table_filenames:
            self.exclude_tables.update(gen_lines_without_comments(filename))

    @staticmethod
    def _safe_url(url: str) -> str:
        """
        Return a version of the SQLAlchemy URL with any password obscured.
        """
        u = make_url(url)
        return repr(u)  # obscures password

    @property
    def safe_src_db_url(self) -> str:
        """
        Password-obscured version of the source database URL.
        """
        return self._safe_url(self.src_db_url)

    @property
    def safe_dst_db_url(self) -> str:
        """
        Password-obscured version of the destination database URL.
        """
        return self._safe_url(self.dst_db_url)

    def permit_table_name(self, table_name: str) -> bool:
        """
        Should this table be permitted (judging only by its name)?
        """
        if self.include_tables:
            # "include_tables" is not empty; therefore, some tables are being
            # explicitly included. In that case, only specifically named
            # tables can be included. (If "include_tables" was empty, we'd
            # include everything by default. See command-line help.)
            if table_name not in self.include_tables:
                # Not specifically included.
                return False
        if table_name in self.exclude_tables:
            # Specifically excluded.
            return False
        # Otherwise, OK.
        return True
# =============================================================================
# Subsetter
# =============================================================================
class Subsetter:
    """
    Class to take a subset of data from one database to another.
    """

    def __init__(self, cfg: SubsetConfig) -> None:
        """
        Open the source and destination databases and check the filter
        settings.

        Args:
            cfg: the subsetting configuration.

        Raises:
            ValueError:
                if no filter column is specified and tables without a filter
                column are not permitted (so there would be nothing to do).
        """
        self.cfg = cfg
        log.info(f"Opening source database: {cfg.safe_src_db_url}")
        self.src_db = DatabaseHolder(
            name="source", url=cfg.src_db_url, with_session=True, echo=cfg.echo
        )
        self.table_names = self.src_db.table_names  # reflects
        log.info(f"Opening destination database: {cfg.safe_dst_db_url}")
        self.dst_db = DatabaseHolder(
            name="destination",
            url=cfg.dst_db_url,
            with_session=True,
            echo=cfg.echo,
        )

        # Any warnings around filters:
        if not cfg.filter_column:
            if cfg.include_tables_without_filtercol:
                log.warning(
                    "No filter column specified. Copying tables UNFILTERED."
                )
            else:
                raise ValueError(
                    "No filter column specified, and tables without a filter "
                    "column not permitted; therefore, nothing to do."
                )
        else:
            if not cfg.filter_values:
                if cfg.include_tables_without_filtercol:
                    log.warning(
                        f"No filter values. Only copying tables without the "
                        f"filter column {cfg.filter_column!r}."
                    )

    # -------------------------------------------------------------------------
    # Information about tables
    # -------------------------------------------------------------------------

    def src_sqla_table(self, table_name: str) -> Table:
        """
        Returns the SQLAlchemy Table from the source database.
        """
        metadata = self.src_db.metadata
        return metadata.tables[table_name]

    def column_names(self, table_name: str) -> List[str]:
        """
        Returns column names for a (source) table.
        """
        t = self.src_sqla_table(table_name)
        # noinspection PyTypeChecker
        return [c.name for c in t.columns]

    def contains_filter_col(self, table_name: str) -> bool:
        """
        Does this table contain our target filter column?
        """
        return self.cfg.filter_column in self.column_names(table_name)

    def permit_table(self, table_name: str) -> bool:
        """
        Is this table name permitted to go through to the destination?
        """
        if table_name not in self.table_names:
            # log.debug(f"... {table_name}: unknown table")
            return False
        if not self.cfg.permit_table_name(table_name):
            # log.debug(f"... {table_name}: prohibited table")
            return False
        if not self.contains_filter_col(table_name):
            # log.debug(f"... {table_name}: system/lookup table")
            return self.cfg.include_tables_without_filtercol
        # log.debug(f"... {table_name}: standard permitted table")
        return True

    def dst_sqla_table(self, table_name: str) -> Table:
        """
        Returns the SQLAlchemy Table from the destination database.
        """
        metadata = self.dst_db.metadata
        return metadata.tables[table_name]

    # -------------------------------------------------------------------------
    # DDL manipulation
    # -------------------------------------------------------------------------

    def drop_dst_table_if_exists(self, table_name: str) -> None:
        """
        Drop a table on the destination side.

        Also remove it from the destination metadata, so we can recreate it
        (if necessary) without complaint.
        """
        log.debug(f"Dropping destination table: {table_name}")
        dst_metadata = self.dst_db.metadata
        t = Table(table_name, dst_metadata)
        t.drop(self.dst_db.engine, checkfirst=True)
        dst_metadata.remove(t)

    def create_dst_table(self, table_name: str) -> None:
        """
        Create a table on the destination side, copying the source table's
        definition into the destination metadata.
        """
        log.debug(f"Creating destination table: {table_name}")
        t = self.src_sqla_table(table_name).to_metadata(self.dst_db.metadata)
        t.create(self.dst_db.engine, checkfirst=True)

    # -------------------------------------------------------------------------
    # Filtering
    # -------------------------------------------------------------------------

    def gen_src_rows(self, table_name: str) -> Generator[Row, None, None]:
        """
        Generate unfiltered source rows from the database.
        """
        # Columns are passed positionally: the list form, select([...]), is
        # the legacy calling style that SQLAlchemy 2.0 no longer accepts.
        query = select("*").select_from(table(table_name))
        result = self.src_db.session.execute(query)
        yield from result

    def gen_filtered_rows(self, table_name: str) -> Generator[Row, None, None]:
        """
        Generate filtered source rows from the database.

        If the table has the filter column, only rows whose filter column
        value (compared as a string, or None) is among the configured filter
        values are generated. Otherwise, all rows are generated.
        """
        srcgen = self.gen_src_rows(table_name)
        if self.contains_filter_col(table_name):
            filtercol = self.cfg.filter_column
            filtervals = self.cfg.filter_values
            for row in srcgen:
                # String-based comparison. Look the column up by name via
                # Row._mapping: plain string indexing on a Row is not
                # supported by SQLAlchemy 2.0.
                if to_str(row._mapping[filtercol]) in filtervals:
                    # Row permitted
                    yield row
        else:
            # All rows permitted; go faster.
            yield from srcgen

    def subset_table(self, table_name: str) -> None:
        """
        Read rows from the source table; filter them as required; store them
        in the destination table.
        """
        n_inserted = 0
        dst_session = self.dst_db.session
        dst_table = self.dst_sqla_table(table_name)
        for row in self.gen_filtered_rows(table_name):
            # Values are supplied positionally (in table column order) via
            # Insert.values(); the "values" keyword of Table.insert() is
            # removed in SQLAlchemy 2.0.
            dst_session.execute(dst_table.insert().values(tuple(row)))
            n_inserted += 1
        log.info(f"Processing table {table_name}: inserted {n_inserted} rows")

    def commit(self) -> None:
        """
        Commit changes to the destination database.
        """
        log.debug("Committing...")
        self.dst_db.session.commit()

    def subset_db(self) -> None:
        """
        Main function -- create a subset of the source database.

        Every source table name is first dropped from the destination (if
        present); permitted tables are then recreated and populated. A single
        commit is made at the end.
        """
        log.info(f"Filtering on column: {self.cfg.filter_column}")
        for table_name in self.table_names:
            self.drop_dst_table_if_exists(table_name)
            if not self.permit_table(table_name):
                log.info(f"SKIPPING table {table_name}")
                continue
            log.info(f"Processing table {table_name}")
            self.create_dst_table(table_name)
            self.subset_table(table_name)
        self.commit()
        log.info("Done.")
# =============================================================================
# Main
# =============================================================================
def main() -> None:
    """
    Command-line entry point.

    Parses the command-line arguments, sets up logging, builds a
    :class:`SubsetConfig` from the arguments, and runs a :class:`Subsetter`.
    """
    parser = argparse.ArgumentParser(
        description=f"Create a simple subset of a database, copying one "
        f"database to another while applying filters. You can filter by a "
        f"standard column (e.g. one representing patient IDs), taking "
        f"permitted filter values from the command line, from file(s), and/or "
        f"from database(s). You can also decide which tables to "
        f"include/exclude. ({CRATE_VERSION_PRETTY})",
        formatter_class=RawDescriptionArgumentDefaultsRichHelpFormatter,
    )

    grp_src = parser.add_argument_group("SOURCE DATABASE")
    grp_src.add_argument(
        "--src_db_url", required=True, help="Source database SQLAlchemy URL"
    )

    grp_dst = parser.add_argument_group("DESTINATION DATABASE")
    grp_dst.add_argument(
        "--dst_db_url",
        required=True,
        help="Destination database SQLAlchemy URL",
    )

    grp_fr = parser.add_argument_group("ROW FILTERING")
    grp_fr.add_argument(
        "--filter_column",
        help="Column on which to filter. Typically the one that defines "
        "individuals (e.g. 'patient_research_id', 'rid', 'brcid'). If "
        "omitted, then the whole database might be copied unfiltered (if you "
        "set --include_tables_without_filtercol).",
    )
    grp_fr.add_argument(
        "--filter_values",
        nargs="*",
        help="Filter values to permit. (Comparison is performed as strings.)",
    )
    grp_fr.add_argument(
        "--filter_value_filenames",
        nargs="*",
        help="Filename(s) of files containing filter values to permit. "
        "('#' denotes comments in the file. "
        "Comparison is performed as strings.)",
    )
    grp_fr.add_argument(
        "--filter_value_db_urls",
        nargs="*",
        help="SQLAlchemy URLs of databases to pull additional filter values "
        "from. Must be in the same order as corresponding arguments to "
        "--filter_value_tablecols.",
    )
    grp_fr.add_argument(
        "--filter_value_tablecols",
        nargs="*",
        help="Table/column pairs, each expressed as 'table.column', of "
        "database columns to pull additional filter values from. Must be in "
        "the same order as corresponding arguments to "
        "--filter_value_db_urls.",
    )
    grp_fr.add_argument(
        "--include_rows_filtercol_null",
        type=str2bool,
        nargs="?",
        # ... if present with no parameter: the bare flag switches the option
        # ON (a const equal to the False default would make the bare flag a
        # silent no-op).
        const=True,
        # ... if argument entirely absent
        default=SubsetDefaults.INC_IF_FILTERCOL_NULL,
        help="Include rows where the filter column is NULL. You can't "
        "otherwise specify NULL as a permitted value (at least, not from the "
        "command line or from files)." + BOOLHELP,
    )

    grp_ft = parser.add_argument_group("TABLE FILTERING")
    grp_ft.add_argument(
        "--include_tables_without_filtercol",
        type=str2bool,
        nargs="?",
        const=SubsetDefaults.INC_TABLES_NO_FILTERCOL,
        # ... if present with no parameter
        default=SubsetDefaults.INC_TABLES_NO_FILTERCOL,
        # ... if argument entirely absent
        help="Include tables that do not possess the filter column (e.g. "
        "system/lookup tables)." + BOOLHELP,
    )
    grp_ft.add_argument(
        "--include_tables",
        nargs="*",
        help="Names of tables to include." + INCHELP,
    )
    grp_ft.add_argument(
        "--include_table_filenames",
        nargs="*",
        help="Filename(s) of files containing names of tables to include."
        + INCHELP,
    )
    grp_ft.add_argument(
        "--exclude_tables",
        nargs="*",
        help="Names of tables to exclude.",
    )
    grp_ft.add_argument(
        "--exclude_table_filenames",
        nargs="*",
        help="Filename(s) of files containing names of tables to exclude.",
    )

    grp_progress = parser.add_argument_group("PROGRESS")
    grp_progress.add_argument(
        "--verbose", "-v", action="store_true", help="Be verbose"
    )
    grp_progress.add_argument(
        "--echo",
        action="store_true",
        help="Echo SQL (slow; for debugging only)",
    )

    args = parser.parse_args()

    # -------------------------------------------------------------------------
    # Verbosity, logging
    # -------------------------------------------------------------------------

    loglevel = logging.DEBUG if args.verbose else logging.INFO
    main_only_quicksetup_rootlogger(level=loglevel)

    # -------------------------------------------------------------------------
    # Onwards
    # -------------------------------------------------------------------------

    subsetcfg = SubsetConfig(
        # source
        src_db_url=args.src_db_url,
        # destination
        dst_db_url=args.dst_db_url,
        # filter
        filter_column=args.filter_column,
        filter_values=args.filter_values,
        filter_value_filenames=args.filter_value_filenames,
        filter_value_db_urls=args.filter_value_db_urls,
        filter_value_tablecols=args.filter_value_tablecols,
        include_rows_filtercol_null=args.include_rows_filtercol_null,
        include_tables_without_filtercol=args.include_tables_without_filtercol,
        include_tables=args.include_tables,
        include_table_filenames=args.include_table_filenames,
        exclude_tables=args.exclude_tables,
        exclude_table_filenames=args.exclude_table_filenames,
        # progress
        echo=args.echo,
    )
    subsetter = Subsetter(subsetcfg)
    subsetter.subset_db()


if __name__ == "__main__":
    main()