"""
crate_anon/common/sql.py
===============================================================================
Copyright (C) 2015, University of Cambridge, Department of Psychiatry.
Created by Rudolf Cardinal (rnc1001@cam.ac.uk).
This file is part of CRATE.
CRATE is free software: you can redistribute it and/or modify
it under the terms of the GNU General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.
CRATE is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU General Public License for more details.
You should have received a copy of the GNU General Public License
along with CRATE. If not, see <https://www.gnu.org/licenses/>.
===============================================================================
**Low-level SQL manipulation functions.**
These are about the manipulation of SQL as text (e.g. for query building
assistance for researchers, or for interpreting SQL data types in data
dictionaries), not about a higher-level approach like SQLAlchemy.
"""
from collections import OrderedDict
from dataclasses import dataclass
import functools
import logging
import re
from typing import Any, Dict, Iterable, List, Tuple, Union, Optional
from cardinal_pythonlib.json_utils.serialize import (
METHOD_PROVIDES_INIT_KWARGS,
METHOD_STRIP_UNDERSCORE,
register_for_json,
)
from cardinal_pythonlib.lists import unique_list
from cardinal_pythonlib.reprfunc import mapped_repr_stripping_underscores
from cardinal_pythonlib.sizeformatter import sizeof_fmt
from cardinal_pythonlib.sql.literals import (
sql_date_literal,
sql_string_literal,
)
from cardinal_pythonlib.sql.sql_grammar import SqlGrammar, text_from_parsed
from cardinal_pythonlib.sql.sql_grammar_factory import (
make_grammar,
mysql_grammar,
)
from cardinal_pythonlib.sql.validation import (
SQLTYPES_INTEGER,
SQLTYPES_BIT,
SQLTYPES_FLOAT,
SQLTYPES_TEXT,
SQLTYPES_OTHER_NUMERIC,
)
from cardinal_pythonlib.sqlalchemy.core_query import count_star
from cardinal_pythonlib.sqlalchemy.dialect import SqlaDialectName
from cardinal_pythonlib.sqlalchemy.schema import (
column_creation_ddl,
execute_ddl,
)
from cardinal_pythonlib.timing import MultiTimerContext, timer
from pyparsing import ParseResults
from sqlalchemy import inspect
from sqlalchemy.dialects.mssql.base import MS_2012_VERSION
from sqlalchemy.engine.base import Engine
from sqlalchemy.engine.interfaces import Dialect
from sqlalchemy.exc import CompileError
from sqlalchemy.orm.session import Session
from sqlalchemy.schema import Column, Table
from sqlalchemy.sql.sqltypes import TypeEngine
from crate_anon.common.stringfunc import get_spec_match_regex
log = logging.getLogger(__name__)  # module-level logger
# =============================================================================
# Types
# =============================================================================
# A (SQL text, list of arguments) pair; presumably the arguments are query
# parameters to go with the SQL -- confirm against callers.
SqlArgsTupleType = Tuple[str, List[Any]]
# =============================================================================
# Constants
# =============================================================================
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# Generic
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# NOTE(review): presumably a timer name for commit operations (the module
# imports MultiTimerContext/timer); confirm at the point of use.
TIMING_COMMIT = "commit"
# SQL comparison operators that take no right-hand value:
SQL_OPS_VALUE_UNNECESSARY = ["IS NULL", "IS NOT NULL"]
# SQL comparison operators whose right-hand side is a list of values:
SQL_OPS_MULTIPLE_VALUES = ["IN", "NOT IN"]
# Combined collections of SQL type names, built by concatenating the
# collections from cardinal_pythonlib.sql.validation.
SQLTYPES_INTEGER_OR_BIT = SQLTYPES_INTEGER + SQLTYPES_BIT
SQLTYPES_FLOAT_OR_OTHER_NUMERIC = SQLTYPES_FLOAT + SQLTYPES_OTHER_NUMERIC
# Must match querybuilder.js:
QB_DATATYPE_INTEGER = "int"
QB_DATATYPE_FLOAT = "float"
QB_DATATYPE_DATE = "date"
QB_DATATYPE_STRING = "string"
QB_DATATYPE_STRING_FULLTEXT = "string_fulltext"
QB_DATATYPE_UNKNOWN = "unknown"
QB_STRING_TYPES = [QB_DATATYPE_STRING, QB_DATATYPE_STRING_FULLTEXT]
# Matches a column type with exactly one integer parameter, e.g. "VARCHAR(10)":
# start, group(alphabetical type name), literal (,
# group(optional_minus_sign digits), literal ), end.
# NB: the class is [A-Za-z], not [A-z]: the latter also accepts the ASCII
# characters lying between "Z" and "a", namely [ \ ] ^ _ and the backtick.
COLTYPE_WITH_ONE_INTEGER_REGEX = re.compile(r"^([A-Za-z]+)\((-?\d+)\)$")
# Per-dialect dictionaries, keyed by upper-case text column type name, giving
# the type's length (or its default length when a length may be given but has
# been omitted). Types for which the user must always specify a length, such
# as VARCHAR, are deliberately not listed.
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# SQLAlchemy dialects
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
DATABRICKS_COLTYPE_TO_LEN = dict(
    # https://docs.databricks.com/en/sql/language-manual/data-types/string-type.html  # noqa: E501
    STRING=None,  # unbounded: there is no maximum length
)
MSSQL_COLTYPE_TO_LEN = dict(
    # An "N" prefix denotes a Unicode type.
    # https://docs.microsoft.com/en-us/sql/t-sql/data-types/char-and-varchar-transact-sql?view=sql-server-ver15  # noqa: E501
    # https://docs.microsoft.com/en-us/sql/t-sql/data-types/nchar-and-nvarchar-transact-sql?view=sql-server-ver15  # noqa: E501
    # https://docs.microsoft.com/en-us/sql/t-sql/data-types/ntext-text-and-image-transact-sql?view=sql-server-ver15  # noqa: E501
    # NVARCHAR(1)..NVARCHAR(4000) may be specified, or NVARCHAR(MAX):
    NVARCHAR_MAX=(1 << 30) - 1,
    # VARCHAR(1)..VARCHAR(8000) may be specified, or VARCHAR(MAX):
    VARCHAR_MAX=(1 << 31) - 1,
    TEXT=(1 << 31) - 1,
    NTEXT=(1 << 30) - 1,
)
MYSQL_COLTYPE_TO_LEN = dict(
    # https://dev.mysql.com/doc/refman/8.0/en/string-type-overview.html
    CHAR=1,  # CHAR(0)..CHAR(255) allowed; with no length given, it is 1
    TINYTEXT=(1 << 8) - 1,  # 255
    TEXT=(1 << 16) - 1,  # 65535
    MEDIUMTEXT=(1 << 24) - 1,  # 16777215
    LONGTEXT=(1 << 32) - 1,  # 4294967295
)
# Maps an SQLAlchemy dialect name to that dialect's "text column type ->
# length" dictionary (defined immediately above). Only these three dialects
# have such a lookup.
DIALECT_TO_STRING_LEN_LOOKUP = {
    SqlaDialectName.DATABRICKS: DATABRICKS_COLTYPE_TO_LEN,
    SqlaDialectName.MSSQL: MSSQL_COLTYPE_TO_LEN,
    SqlaDialectName.MYSQL: MYSQL_COLTYPE_TO_LEN,
}
# =============================================================================
# Helper classes
# =============================================================================
@dataclass
class IndexCreationInfo:
    """
    Describes one index to be created (as consumed by :func:`add_indexes`).
    """

    index_name: str  #: Name of the index
    column: Union[str, List[str]]  #: Column name(s) to index
    unique: bool = False  #: Make a unique index?

    @property
    def column_names(self) -> str:
        """
        The indexed column(s) as a single string: a lone column name is
        returned unchanged; a sequence of names is joined with ``", "``.
        """
        cols = self.column
        return cols if isinstance(cols, str) else ", ".join(cols)
# =============================================================================
# SQL elements: identifiers
# =============================================================================
@register_for_json(method=METHOD_STRIP_UNDERSCORE)
@functools.total_ordering
class SchemaId:
    """
    Represents a database schema. This is a bit complex:

    - In SQL Server, schemas live within databases. Tables can be referred to
      as ``table``, ``schema.table``, or ``database.schema.table``.

      - https://docs.microsoft.com/en-us/dotnet/framework/data/adonet/sql/ownership-and-user-schema-separation-in-sql-server
      - The default schema is named ``dbo``.

    - In PostgreSQL, schemas live within databases. Tables can be referred to
      as ``table``, ``schema.table``, or ``database.schema.table``.

      - https://www.postgresql.org/docs/current/static/ddl-schemas.html
      - The default schema is named ``public``.

    - In MySQL, "database" and "schema" are synonymous. Tables can be referred
      to as ``table`` or ``database.table`` (= ``schema.table``).

      - https://stackoverflow.com/questions/11618277/difference-between-schema-database-in-mysql

    Equality, ordering and hashing are all based on the ``(db, schema)``
    pair. Comparing with an object that is not a :class:`SchemaId` gives
    ``False`` for ``==`` and :exc:`TypeError` for ``<`` etc.
    """  # noqa: E501

    def __init__(self, db: str = "", schema: str = "") -> None:
        """
        Args:
            db: database name
            schema: schema name

        Neither name may contain ``'.'`` (checked by assertion), so that
        :attr:`schema_tag` can be split unambiguously.
        """
        assert "." not in db, f"Bad database name ({db!r}); can't include '.'"
        assert (
            "." not in schema
        ), f"Bad schema name ({schema!r}); can't include '.'"
        self._db = db
        self._schema = schema

    @property
    def schema_tag(self) -> str:
        """
        String suitable for encoding the SchemaId e.g. in a single HTML form.
        Takes the format ``database.schema``.

        The :func:`__init__` function has already checked the assumption of no
        ``'.'`` characters in either part.
        """
        return f"{self._db}.{self._schema}"

    @classmethod
    def from_schema_tag(cls, tag: str) -> "SchemaId":
        """
        Returns a :class:`SchemaId` from a tag of the form ``db.schema``
        (the inverse of :attr:`schema_tag`).
        """
        parts = tag.split(".")
        assert len(parts) == 2, f"Bad schema tag {tag!r}"
        db, schema = parts
        return cls(db, schema)

    def __bool__(self) -> bool:
        """
        Returns:
            is there a named schema? (A database part alone does not count;
            compare :meth:`is_present`.)
        """
        return bool(self._schema)

    def __eq__(self, other: object) -> bool:
        if not isinstance(other, SchemaId):
            return NotImplemented
        return (  # ordering is for speed
            self._schema == other._schema and self._db == other._db
        )

    def __lt__(self, other: "SchemaId") -> bool:
        if not isinstance(other, SchemaId):
            return NotImplemented
        return (self._db, self._schema) < (other._db, other._schema)

    def __hash__(self) -> int:
        # Hash exactly the fields that __eq__ compares. Don't hash str(self):
        # for a blank SchemaId, str() fails the "no elements" assertion in
        # make_identifier().
        return hash((self._db, self._schema))

    def identifier(self, grammar: SqlGrammar) -> str:
        """
        Returns an SQL identifier for this schema using the specified SQL
        grammar, quoting it if need be.

        Args:
            grammar: :class:`cardinal_pythonlib.sql.sql_grammar.SqlGrammar`
        """
        return make_identifier(grammar, database=self._db, schema=self._schema)

    def table_id(self, table: str) -> "TableId":
        """
        Returns a :class:`TableId` combining this schema and the specified
        table.

        Args:
            table: name of the table
        """
        return TableId(db=self._db, schema=self._schema, table=table)

    def column_id(self, table: str, column: str) -> "ColumnId":
        """
        Returns a :class:`ColumnId` combining this schema and the specified
        table/column.

        Args:
            table: name of the table
            column: name of the column
        """
        return ColumnId(
            db=self._db, schema=self._schema, table=table, column=column
        )

    @property
    def db(self) -> str:
        """
        Returns the database part.
        """
        return self._db

    @property
    def schema(self) -> str:
        """
        Returns the schema part.
        """
        return self._schema

    def __str__(self) -> str:
        # The specific grammar is unimportant here. Note that for a blank
        # SchemaId this fails the assertion in make_identifier().
        return self.identifier(mysql_grammar)

    def __repr__(self) -> str:
        return mapped_repr_stripping_underscores(self, ["_db", "_schema"])

    def is_present(self) -> bool:
        """
        Does this schema have a ``database`` and/or ``schema`` part (i.e. is
        it *not* blank)?
        """
        return bool(self._db or self._schema)

    def is_blank(self) -> bool:
        """
        Is this a blank/nonfunctional schema, with no ``database`` or
        ``schema`` part?
        """
        return not self.is_present()
@register_for_json(method=METHOD_STRIP_UNDERSCORE)
@functools.total_ordering
class TableId:
    """
    Represents a database table, optionally qualified by schema and database.

    Equality, ordering and hashing are all based on the ``(db, schema,
    table)`` triple. Comparing with an object that is not a :class:`TableId`
    gives ``False`` for ``==`` and :exc:`TypeError` for ``<`` etc.
    """

    def __init__(
        self, db: str = "", schema: str = "", table: str = ""
    ) -> None:
        """
        Args:
            db: database name
            schema: schema name
            table: table name
        """
        self._db = db
        self._schema = schema
        self._table = table

    def __bool__(self) -> bool:
        # Truthy only if there is a table name.
        return bool(self._table)

    def __eq__(self, other: object) -> bool:
        if not isinstance(other, TableId):
            return NotImplemented
        return (  # ordering is for speed
            self._table == other._table
            and self._schema == other._schema
            and self._db == other._db
        )

    def __lt__(self, other: "TableId") -> bool:
        if not isinstance(other, TableId):
            return NotImplemented
        return (self._db, self._schema, self._table) < (
            other._db,
            other._schema,
            other._table,
        )

    def __hash__(self) -> int:
        # Hash exactly the fields that __eq__ compares. Don't hash str(self):
        # for a blank TableId, str() fails the "no elements" assertion in
        # make_identifier(), which would break e.g. grouping columns by table.
        return hash((self._db, self._schema, self._table))

    def identifier(self, grammar: SqlGrammar) -> str:
        """
        Returns an SQL identifier for this table using the specified SQL
        grammar, quoting it if need be.

        Args:
            grammar: :class:`cardinal_pythonlib.sql.sql_grammar.SqlGrammar`
        """
        return make_identifier(
            grammar, database=self._db, schema=self._schema, table=self._table
        )

    @property
    def schema_id(self) -> SchemaId:
        """
        Returns a :class:`SchemaId` for the schema of our table.
        """
        return SchemaId(db=self._db, schema=self._schema)

    def column_id(self, column: str) -> "ColumnId":
        """
        Returns a :class:`ColumnId` combining this table and the specified
        column.

        Args:
            column: name of the column
        """
        return ColumnId(
            db=self._db, schema=self._schema, table=self._table, column=column
        )

    def database_schema_part(self, grammar: SqlGrammar) -> str:
        """
        Returns an SQL identifier for this table's database/schema (without the
        table part) using the specified SQL grammar, quoting it if need be.

        Args:
            grammar: :class:`cardinal_pythonlib.sql.sql_grammar.SqlGrammar`
        """
        return make_identifier(grammar, database=self._db, schema=self._schema)

    def table_part(self, grammar: SqlGrammar) -> str:
        """
        Returns an SQL identifier for this table's table name (only) using the
        specified SQL grammar, quoting it if need be.

        Args:
            grammar: :class:`cardinal_pythonlib.sql.sql_grammar.SqlGrammar`
        """
        return make_identifier(grammar, table=self._table)

    @property
    def db(self) -> str:
        """
        Returns the database part.
        """
        return self._db

    @property
    def schema(self) -> str:
        """
        Returns the schema part.
        """
        return self._schema

    @property
    def table(self) -> str:
        """
        Returns the table part.
        """
        return self._table

    def __str__(self) -> str:
        # The specific grammar is unimportant here. Note that for a blank
        # TableId this fails the assertion in make_identifier().
        return self.identifier(mysql_grammar)

    def __repr__(self) -> str:
        return mapped_repr_stripping_underscores(
            self, ["_db", "_schema", "_table"]
        )
@register_for_json(method=METHOD_STRIP_UNDERSCORE)
@functools.total_ordering
class ColumnId:
    """
    Represents a database column, optionally qualified by table, schema and
    database.

    Equality, ordering and hashing are all based on the ``(db, schema, table,
    column)`` tuple. Comparing with an object that is not a :class:`ColumnId`
    gives ``False`` for ``==`` and :exc:`TypeError` for ``<`` etc.
    """

    def __init__(
        self, db: str = "", schema: str = "", table: str = "", column: str = ""
    ) -> None:
        """
        Args:
            db: database name
            schema: schema name
            table: table name
            column: column name
        """
        self._db = db
        self._schema = schema
        self._table = table
        self._column = column

    def __bool__(self) -> bool:
        # Truthy only if there is a column name.
        return bool(self._column)

    def __eq__(self, other: object) -> bool:
        if not isinstance(other, ColumnId):
            return NotImplemented
        return (  # ordering is for speed
            self._column == other._column
            and self._table == other._table
            and self._schema == other._schema
            and self._db == other._db
        )

    def __lt__(self, other: "ColumnId") -> bool:
        if not isinstance(other, ColumnId):
            return NotImplemented
        return (self._db, self._schema, self._table, self._column) < (
            other._db,
            other._schema,
            other._table,
            other._column,
        )

    def __hash__(self) -> int:
        # Defining __eq__ alone would make this class unhashable (unlike
        # SchemaId and TableId); hash exactly the fields __eq__ compares.
        return hash((self._db, self._schema, self._table, self._column))

    @property
    def is_valid(self) -> bool:
        """
        Do we know about a table and a column, at least?
        """
        return bool(self._table and self._column)  # the minimum

    def identifier(self, grammar: SqlGrammar) -> str:
        """
        Returns an SQL identifier for this column using the specified SQL
        grammar, quoting it if need be.

        Args:
            grammar: :class:`cardinal_pythonlib.sql.sql_grammar.SqlGrammar`
        """
        return make_identifier(
            grammar,
            database=self._db,
            schema=self._schema,
            table=self._table,
            column=self._column,
        )

    @property
    def db(self) -> str:
        """
        Returns the database part.
        """
        return self._db

    @property
    def schema(self) -> str:
        """
        Returns the schema part.
        """
        return self._schema

    @property
    def table(self) -> str:
        """
        Returns the table part.
        """
        return self._table

    @property
    def column(self) -> str:
        """
        Returns the column part.
        """
        return self._column

    @property
    def schema_id(self) -> SchemaId:
        """
        Returns a :class:`SchemaId` for the schema of our column.
        """
        return SchemaId(db=self._db, schema=self._schema)

    @property
    def table_id(self) -> TableId:
        """
        Returns a :class:`TableId` for our table.
        """
        return TableId(db=self._db, schema=self._schema, table=self._table)

    @property
    def has_table_and_column(self) -> bool:
        """
        Do we know about a table and a column? (Same test as
        :attr:`is_valid`.)
        """
        return bool(self._table and self._column)

    def __str__(self) -> str:
        # The specific grammar is unimportant here. Note that for a blank
        # ColumnId this fails the assertion in make_identifier().
        return self.identifier(mysql_grammar)

    def __repr__(self) -> str:
        return mapped_repr_stripping_underscores(
            self, ["_db", "_schema", "_table", "_column"]
        )
def split_db_schema_table(db_schema_table: str) -> TableId:
    """
    Parses a dotted, unquoted SQL-style name into a :class:`TableId`.

    Args:
        db_schema_table:
            ``database.schema.table``, ``schema.table``, or just ``table``

    Returns:
        a :class:`TableId`; any parts not given are empty strings

    Raises:
        :exc:`ValueError` if there are more than three dot-separated parts
    """
    parts = db_schema_table.split(".")
    n_parts = len(parts)
    if not 1 <= n_parts <= 3:
        raise ValueError(f"Bad db_schema_table: {db_schema_table}")
    # Left-pad with blanks so we always have (db, schema, table).
    db, schema, table = [""] * (3 - n_parts) + parts
    return TableId(db=db, schema=schema, table=table)
def split_db_schema_table_column(db_schema_table_col: str) -> ColumnId:
    """
    Parses a dotted, unquoted SQL-style name into a :class:`ColumnId`.

    Args:
        db_schema_table_col:
            ``database.schema.table.column``, ``schema.table.column``,
            ``table.column``, or just ``column``

    Returns:
        a :class:`ColumnId`; any parts not given are empty strings

    Raises:
        :exc:`ValueError` if there are more than four dot-separated parts
    """
    parts = db_schema_table_col.split(".")
    n_parts = len(parts)
    if not 1 <= n_parts <= 4:
        raise ValueError(f"Bad db_schema_table_col: {db_schema_table_col}")
    # Left-pad with blanks so we always have (db, schema, table, column).
    db, schema, table, column = [""] * (4 - n_parts) + parts
    return ColumnId(db=db, schema=schema, table=table, column=column)
def columns_to_table_column_hierarchy(
    columns: List[ColumnId], sort: bool = True
) -> List[Tuple[TableId, List[ColumnId]]]:
    """
    Groups a list of column IDs by the table they belong to.

    Args:
        columns: list of :class:`ColumnId` objects
        sort: sort by table, and column within table? If not, tables appear
            in order of first appearance and columns keep their input order.

    Returns:
        a list of tuples, each ``table, columns``, where ``table`` is a
        :class:`TableId` and ``columns`` is a list of :class:`ColumnId`
    """
    # One pass over the columns (previously: one full scan per table). A dict
    # keeps tables in first-seen order and columns in input order, exactly
    # as the former unique_list() + per-table filter did.
    by_table = {}  # type: Dict[TableId, List[ColumnId]]
    for c in columns:
        by_table.setdefault(c.table_id, []).append(c)
    tables = list(by_table)
    if sort:
        tables.sort()
    table_column_map = []  # type: List[Tuple[TableId, List[ColumnId]]]
    for t in tables:
        t_columns = by_table[t]
        if sort:
            t_columns.sort()
        table_column_map.append((t, t_columns))
    return table_column_map
# =============================================================================
# Using SQL grammars (but without reference to Django models, for testing)
# =============================================================================
def make_identifier(
    grammar: SqlGrammar,
    database: str = None,
    schema: str = None,
    table: str = None,
    column: str = None,
) -> str:
    """
    Builds a dotted SQL identifier, quoting each part (if the grammar says it
    needs quoting) before joining.

    Args:
        grammar: :class:`cardinal_pythonlib.sql.sql_grammar.SqlGrammar`
        database: database name
        schema: schema name
        table: table name
        column: column name

    Returns:
        the non-empty parts, quoted as required, joined with ``.`` in the
        order database, schema, table, column

    Raises:
        AssertionError: if every part is empty/``None``
    """
    quoted = []  # type: List[str]
    for part in (database, schema, table, column):
        if not part:
            continue  # absent (None or "") parts are simply omitted
        quoted.append(grammar.quote_identifier_if_required(part))
    assert quoted, "make_identifier(): No elements passed!"
    return ".".join(quoted)
def dumb_make_identifier(
    database: str = None,
    schema: str = None,
    table: str = None,
    column: str = None,
) -> str:
    """
    Makes an SQL-style identifier by joining all the parts with ``.``, without
    bothering to quote them.

    Args:
        database: database name
        schema: schema name
        table: table name
        column: column name

    Returns:
        a string as above in the order "database, schema, table, column", but
        omitting any missing parts

    Raises:
        AssertionError: if every part is empty/``None``
    """
    # Build a real list: a ``filter`` object is always truthy, so asserting
    # on one (as this function used to) could never catch the empty case.
    elements = [x for x in (database, schema, table, column) if x]
    assert elements, "dumb_make_identifier(): No elements passed!"
    return ".".join(elements)
def parser_add_result_column(
    parsed: ParseResults, column: str, grammar: SqlGrammar
) -> ParseResults:
    """
    Adds a result column to a parsed SELECT statement, unless it is already
    there. For example, adding ``d`` to

    .. code-block:: sql

        SELECT a, b, c
        FROM sometable
        WHERE conditions;

    gives

    .. code-block:: sql

        SELECT a, b, c, d
        FROM sometable
        WHERE conditions;

    The SELECT statement must already have at least one result column (a
    comma separator is always inserted).

    Args:
        parsed: a `pyparsing.ParseResults` result (modified in place)
        column: column name
        grammar: :class:`cardinal_pythonlib.sql.sql_grammar.SqlGrammar`

    Returns:
        ``parsed`` (the same object)
    """
    if column in parsed.select_expression.select_columns.asList():
        return parsed  # already present; nothing to do
    parsed_column = grammar.get_result_column().parseString(
        column, parseAll=True
    )
    parsed.select_expression.extend([",", parsed_column])
    return parsed
class JoinInfo:
    """
    Plain description of one SQL join (table, join method and condition), as
    used by :func:`parser_add_from_tables`.
    """

    def __init__(
        self,
        table: str,
        join_type: str = "INNER JOIN",
        join_condition: str = "",
    ) -> None:
        """
        Args:
            table: name of the table to join in
            join_type: how to join it, e.g. ``"INNER JOIN"`` (the default)
            join_condition: the constraint, e.g. ``"ON x = y"``; empty for
                none
        """
        # Attributes are assigned in this order deliberately (it is the
        # order vars() reports them in).
        self.join_type = join_type
        self.table = table
        self.join_condition = join_condition
def parser_add_from_tables(
    parsed: ParseResults, join_info_list: List[JoinInfo], grammar: SqlGrammar
) -> ParseResults:
    """
    Adds joined tables to a parsed SELECT statement. For example, adding
    ``JoinInfo("othertable", "INNER JOIN", "ON table.key = othertable.key")``
    to

    .. code-block:: sql

        SELECT a, b, c
        FROM sometable
        WHERE conditions;

    gives

    .. code-block:: sql

        SELECT a, b, c
        FROM sometable
        INNER JOIN othertable ON table.key = othertable.key
        WHERE conditions;

    Presupposes that there is at least one table already in the FROM clause.
    A table already in the FROM clause is skipped, as is a repeat of a table
    within ``join_info_list`` itself, so no table is joined in twice.

    Args:
        parsed: a `pyparsing.ParseResults` result (modified in place)
        join_info_list: list of :class:`JoinInfo` objects
        grammar: :class:`cardinal_pythonlib.sql.sql_grammar.SqlGrammar`

    Returns:
        ``parsed`` (the same object)
    """
    # Tables already present. asList() may wrap a table name in a one-element
    # list (see get_first_from_table()); unwrap so comparisons with plain
    # table-name strings work.
    present = []  # type: List[Any]
    for t in parsed.join_source.from_tables.asList():
        if isinstance(t, list) and len(t) == 1:
            t = t[0]
        present.append(t)
    for ji in join_info_list:
        if ji.table in present:
            continue  # already there
        parsed_join = grammar.get_join_op().parseString(
            ji.join_type, parseAll=True
        )[0]  # e.g. INNER JOIN
        parsed_table = grammar.get_table_spec().parseString(
            ji.table, parseAll=True
        )[0]
        extrabits = [parsed_join, parsed_table]
        if ji.join_condition:  # e.g. ON x = y
            extrabits.append(
                grammar.get_join_constraint().parseString(
                    ji.join_condition, parseAll=True
                )[0]
            )
        parsed.join_source.extend(extrabits)
        # Record it, so a later duplicate in join_info_list is not re-added.
        present.append(ji.table)
    return parsed
def get_first_from_table(
    parsed: ParseResults,
    match_db: str = "",
    match_schema: str = "",
    match_table: str = "",
) -> TableId:
    """
    Finds the first table named in the FROM clause of a parsed SELECT
    statement, optionally requiring its database, schema and/or table name to
    equal the given ``match_*`` values.

    Args:
        parsed: a `pyparsing.ParseResults` result
        match_db: if non-empty, the required database name
        match_schema: if non-empty, the required schema name
        match_table: if non-empty, the required table name

    Returns:
        a :class:`TableId` for the first qualifying table, or an empty
        :class:`TableId` if there is none
    """
    # (required value, TableId attribute); a blank value means "don't care".
    constraints = (
        (match_db, "db"),
        (match_schema, "schema"),
        (match_table, "table"),
    )
    for entry in parsed.join_source.from_tables.asList():
        if isinstance(entry, list):
            # asList() may give a table name wrapped in a one-element list.
            assert len(entry) == 1
            entry = entry[0]
        candidate = split_db_schema_table(entry)
        if all(
            not wanted or getattr(candidate, attr) == wanted
            for wanted, attr in constraints
        ):
            return candidate
    return TableId()
def set_distinct_within_parsed(p: ParseResults, action: str = "set") -> None:
    """
    Changes, in place, whether a parsed SELECT statement is DISTINCT.

    Args:
        p: a `pyparsing.ParseResults` result
        action: ``"set"`` (make DISTINCT), ``"clear"`` (remove DISTINCT), or
            ``"toggle"`` (flip the current state)

    Raises:
        :exc:`ValueError` for any other ``action``

    Note that clearing empties the whole select specifier, not just the
    ``DISTINCT`` token.
    """
    ss = p.select_specifier  # type: ParseResults
    if action not in ("set", "clear", "toggle"):
        raise ValueError("action must be one of set/clear/toggle")
    is_distinct = "DISTINCT" in ss.asList()
    wanted = action
    if wanted == "toggle":
        wanted = "clear" if is_distinct else "set"
    if wanted == "set" and not is_distinct:
        ss.append("DISTINCT")
    elif wanted == "clear" and is_distinct:
        del ss[:]
def set_distinct(
    sql: str,
    grammar: SqlGrammar,
    action: str = "set",
    formatted: bool = True,
    debug: bool = False,
    debug_verbose: bool = False,
) -> str:
    """
    Parses an SQL SELECT statement, changes whether it is DISTINCT, and
    returns the result as text.

    Args:
        sql: SQL statement as text
        grammar: :class:`cardinal_pythonlib.sql.sql_grammar.SqlGrammar`
        action: ``"set"``, ``"clear"`` or ``"toggle"``; see
            :func:`set_distinct_within_parsed`
        formatted: pretty-format the result?
        debug: log the SQL before and after (at INFO level)
        debug_verbose: when ``debug`` is set, also log full parse dumps (at
            DEBUG level)

    Returns:
        the modified SQL statement, as a string
    """
    parsed = grammar.get_select_statement().parseString(sql, parseAll=True)
    if debug:
        log.info(f"START: {sql}")
        if debug_verbose:
            log.debug("start dump:\n" + parsed.dump())
    set_distinct_within_parsed(parsed, action=action)
    new_sql = text_from_parsed(parsed, formatted=formatted)
    if debug:
        log.info(f"END: {new_sql}")
        if debug_verbose:
            log.debug("end dump:\n" + parsed.dump())
    return new_sql
def toggle_distinct(
    sql: str,
    grammar: SqlGrammar,
    formatted: bool = True,
    debug: bool = False,
    debug_verbose: bool = False,
) -> str:
    """
    Flips the DISTINCT status of an SQL SELECT statement: shorthand for
    :func:`set_distinct` with ``action="toggle"``.

    Args:
        sql: SQL statement as text
        grammar: :class:`cardinal_pythonlib.sql.sql_grammar.SqlGrammar`
        formatted: pretty-format the result?
        debug: log the SQL before and after
        debug_verbose: also log full parse dumps when debugging

    Returns:
        the modified SQL statement, as a string
    """
    options = dict(
        formatted=formatted, debug=debug, debug_verbose=debug_verbose
    )
    return set_distinct(sql, grammar, action="toggle", **options)
# =============================================================================
# SQLAlchemy reflection and DDL
# =============================================================================
# Module-wide switch read by the SQL/DDL helpers below: when True they print
# SQL instead of executing it.
_global_print_not_execute_sql = False


def set_print_not_execute(print_not_execute: bool) -> None:
    """
    Chooses, for the whole module, whether SQL issued through this module's
    DDL helpers is printed rather than executed.

    This is global state: it affects every subsequent call, from any caller.

    Args:
        print_not_execute: ``True`` to print SQL; ``False`` to execute it
    """
    global _global_print_not_execute_sql
    _global_print_not_execute_sql = print_not_execute
def _exec_ddl(engine: Engine, sql: str) -> None:
    """
    Runs SQL as DDL, or just prints it, according to the module-wide setting
    made by :func:`set_print_not_execute`. The SQL is always logged at DEBUG
    level first.

    Args:
        engine: SQLAlchemy database Engine
        sql: raw SQL to execute (or print)
    """
    log.debug(sql)
    if not _global_print_not_execute_sql:
        execute_ddl(engine, sql=sql)
        return
    # The newline before ";" guards against SQL that ends in a comment.
    print(format_sql_for_print(sql) + "\n;")
def execute(engine: Engine, sql: str) -> None:
    """
    Runs plain SQL inside a transaction, or just prints it, according to the
    module-wide setting made by :func:`set_print_not_execute`. The SQL is
    always logged at DEBUG level first.

    Args:
        engine: SQLAlchemy database Engine
        sql: raw SQL to execute (or print)
    """
    log.debug(sql)
    if not _global_print_not_execute_sql:
        # engine.begin() commits on success and rolls back on error.
        # NOTE(review): passing a plain string to Connection.execute() is
        # deprecated in SQLAlchemy 1.4 and unsupported in 2.0 (it would need
        # text() or exec_driver_sql()); confirm the pinned SQLAlchemy version.
        with engine.begin() as connection:
            connection.execute(sql)
        return
    # The newline before ";" guards against SQL that ends in a comment.
    print(format_sql_for_print(sql) + "\n;")
def add_columns(engine: Engine, table: Table, columns: List[Column]) -> None:
    """
    Adds any of the given columns that the table does not already have.

    Whether we act or just print is conditional on previous calls to
    :func:`set_print_not_execute`. Existing columns are detected by
    case-insensitive name comparison against the live database.

    Args:
        engine: SQLAlchemy database Engine
        table: SQLAlchemy Table object
        columns: SQLAlchemy Column objects to add to the table

    Behaviour of different database systems:

    - ANSI SQL: add one column at a time: ``ALTER TABLE ADD [COLUMN] coldef``

      - i.e. "COLUMN" optional, one at a time, no parentheses
      - https://www.contrib.andrew.cmu.edu/~shadow/sql/sql1992.txt

    - MySQL: ``ALTER TABLE ADD [COLUMN] (a INT, b VARCHAR(32));``

      - i.e. "COLUMN" optional, parentheses required for >1, multiple OK
      - https://dev.mysql.com/doc/refman/5.7/en/alter-table.html

    - MS SQL Server: ``ALTER TABLE ADD COLUMN a INT, B VARCHAR(32);``

      - i.e. no "COLUMN", no parentheses, multiple OK
      - https://msdn.microsoft.com/en-us/library/ms190238.aspx
      - https://msdn.microsoft.com/en-us/library/ms190273.aspx
      - https://stackoverflow.com/questions/2523676

    This function therefore issues one ``ALTER TABLE`` per column, the only
    form common to all of them. SQLAlchemy doesn't provide a shortcut.
    """
    already_there = get_column_names(
        engine, tablename=table.name, to_lower=True
    )
    # First decide what to add (logging skips), then issue the DDL.
    pending_defs = []  # type: List[str]
    for col in columns:
        if col.name.lower() in already_there:
            log.debug(
                f"Table {table.name!r}: column {col.name!r} "
                f"already exists; not adding"
            )
            continue
        pending_defs.append(column_creation_ddl(col, engine.dialect))
    for column_def in pending_defs:
        log.info(f"Table {table.name!r}: adding column {column_def!r}")
        _exec_ddl(engine, f"ALTER TABLE {table.name} ADD {column_def}")
def drop_columns(
    engine: Engine, table: Table, column_names: Iterable[str]
) -> None:
    """
    Drops, one at a time, those of the named columns that exist in the table
    (matched case-insensitively against the live database).

    Whether we act or just print is conditional on previous calls to
    :func:`set_print_not_execute`.

    Args:
        engine: SQLAlchemy database Engine
        table: SQLAlchemy Table object
        column_names: names of columns to drop
    """
    already_there = get_column_names(
        engine, tablename=table.name, to_lower=True
    )
    for name in column_names:
        if name.lower() in already_there:
            log.info(f"Table {table.name!r}: dropping column {name!r}")
            # SQL Server:
            # http://www.techonthenet.com/sql_server/tables/alter_table.php
            # MySQL:
            # http://dev.mysql.com/doc/refman/5.7/en/alter-table.html
            _exec_ddl(engine, f"ALTER TABLE {table.name} DROP COLUMN {name}")
        else:
            log.debug(
                f"Table {table.name!r}: column {name!r} "
                f"does not exist; not dropping"
            )
def add_indexes(
    engine: Engine, table: Table, index_info_list: Iterable[IndexCreationInfo]
) -> None:
    """
    Creates each requested index whose name the table does not already have
    (names compared case-insensitively against the live database).

    Whether we act or just print is conditional on previous calls to
    :func:`set_print_not_execute`.

    Args:
        engine:
            SQLAlchemy database Engine
        table:
            SQLAlchemy Table object
        index_info_list:
            Index(es) to create: list of :class:`IndexCreationInfo` objects.
    """
    already_there = get_index_names(
        engine, tablename=table.name, to_lower=True
    )
    for info in index_info_list:
        index_name = info.index_name
        column = info.column_names
        if index_name.lower() in already_there:
            log.debug(
                f"Table {table.name!r}: index {index_name!r} "
                f"already exists; not adding"
            )
            continue
        log.info(
            f"Table {table.name!r}: adding index {index_name!r} on "
            f"column {column!r}"
        )
        unique_kw = " UNIQUE" if info.unique else ""
        _exec_ddl(
            engine,
            f"\nCREATE{unique_kw} INDEX {index_name}\n"
            f"ON {table.name} ({column})\n",
        )
def drop_indexes(
    engine: Engine, table: Table, index_names: Iterable[str]
) -> None:
    """
    Drops those of the named indexes that exist on the table (names compared
    case-insensitively against the live database).

    Whether we act or just print is conditional on previous calls to
    :func:`set_print_not_execute`.

    Args:
        engine: SQLAlchemy database Engine
        table: SQLAlchemy Table object
        index_names: names of indexes to drop

    Raises:
        AssertionError: if an index needs dropping but the engine's dialect
            is not one of mysql, mariadb, mssql, postgresql or sqlite
    """
    already_there = get_index_names(
        engine, tablename=table.name, to_lower=True
    )
    dialect_name = engine.dialect.name
    for index_name in index_names:
        if index_name.lower() not in already_there:
            log.debug(
                f"Table {table.name!r}: index {index_name!r} "
                f"does not exist; not dropping"
            )
            continue
        log.info(f"Table {table.name!r}: dropping index {index_name!r}")
        if dialect_name in ("mysql", "mariadb"):
            sql = f"ALTER TABLE {table.name} DROP INDEX {index_name}"
        elif dialect_name == "mssql":
            sql = f"DROP INDEX {table.name}.{index_name}"
        elif dialect_name in ("postgresql", "sqlite"):
            # Here index names are schema-level, not table-level, objects.
            sql = f"DROP INDEX {index_name}"
        else:
            # An explicit raise (not "assert False"), so that under
            # "python -O" we still fail clearly rather than reaching an
            # unbound "sql" below.
            raise AssertionError(f"Unknown dialect: {dialect_name}")
        _exec_ddl(engine, sql)
def get_table_names(
    engine: Engine, to_lower: bool = False, sort: bool = False
) -> List[str]:
    """
    Lists the database's tables, as reported by SQLAlchemy reflection.

    Args:
        engine: SQLAlchemy database Engine
        to_lower: lower-case the names?
        sort: sort the names (case-insensitively)?

    Returns:
        list of table names
    """
    names = inspect(engine).get_table_names()
    if to_lower:
        names = [name.lower() for name in names]
    return sorted(names, key=lambda name: name.lower()) if sort else names
def get_view_names(
    engine: Engine, to_lower: bool = False, sort: bool = False
) -> List[str]:
    """
    Lists the database's views, as reported by SQLAlchemy reflection.

    Args:
        engine: SQLAlchemy database Engine
        to_lower: lower-case the names?
        sort: sort the names (case-insensitively)?

    Returns:
        list of view names
    """
    names = inspect(engine).get_view_names()
    if to_lower:
        names = [name.lower() for name in names]
    return sorted(names, key=lambda name: name.lower()) if sort else names
def get_column_names(
    engine: Engine, tablename: str, to_lower: bool = False, sort: bool = False
) -> List[str]:
    """
    Reads a table's column names afresh from the database (so stale metadata
    doesn't matter).

    Args:
        engine: SQLAlchemy database Engine
        tablename: name of the table
        to_lower: lower-case the column names?
        sort: sort the column names (case-insensitively)?

    Returns:
        list of column names
    """
    column_info = inspect(engine).get_columns(tablename)
    names = [col["name"] for col in column_info]
    if to_lower:
        names = [name.lower() for name in names]
    return sorted(names, key=lambda name: name.lower()) if sort else names
def get_index_names(
    engine: Engine, tablename: str, to_lower: bool = False, sort: bool = False
) -> List[str]:
    """
    Reads the names of the indexes on one table from the database.

    Args:
        engine: SQLAlchemy database Engine
        tablename: name of the table
        to_lower: lower-case every index name before returning?
        sort: return the names in case-insensitive alphabetical order?

    Returns:
        list of index names
    """
    # http://docs.sqlalchemy.org/en/latest/core/reflection.html
    reflected = inspect(engine).get_indexes(tablename)
    # Skip entries with an empty name: at least for SQL Server, there always
    # seems to be a blank one with {'name': None, ...}.
    names = [idx["name"] for idx in reflected if idx["name"]]
    if to_lower:
        names = [name.lower() for name in names]
    if sort:
        names = sorted(names, key=lambda name: name.lower())
    return names
def ensure_columns_present(
    engine: Engine, tablename: str, column_names: Iterable[str]
) -> None:
    """
    Ensure all these columns are present in a table, or raise an exception.
    Operates in case-insensitive fashion.

    If ``column_names`` is empty (falsy), returns immediately without
    querying the database.

    Args:
        engine: SQLAlchemy database Engine
        tablename: name of the table
        column_names: names of required columns

    Raises:
        :exc:`ValueError` if any are missing
    """
    if not column_names:
        # Nothing to check: skip the database round-trip entirely.
        return
    existing_column_names = get_column_names(
        engine, tablename=tablename, to_lower=True
    )
    existing_lookup = set(existing_column_names)  # O(1) membership tests
    for col in column_names:
        if col.lower() not in existing_lookup:
            raise ValueError(
                f"Column {col!r} missing from table {tablename!r}, "
                f"whose columns are {existing_column_names!r}"
            )
def create_view(engine: Engine, viewname: str, select_sql: str) -> None:
    """
    Creates (or replaces) a view defined by a SELECT statement.

    Whether we act or just print is conditional on previous calls to
    :func:`set_print_not_execute`.

    Args:
        engine: SQLAlchemy database Engine
        viewname: view name
        select_sql: SQL SELECT statement for this view
    """
    if engine.dialect.name != "mysql":
        # SQL Server has no CREATE OR REPLACE VIEW:
        # https://stackoverflow.com/questions/18534919
        # ... so remove any existing view first.
        drop_view(engine, viewname, quiet=True)
        create_clause = "CREATE VIEW"
    else:
        # MySQL supports CREATE OR REPLACE VIEW directly.
        create_clause = "CREATE OR REPLACE VIEW"
    ddl = f"{create_clause} {viewname} AS {select_sql}"
    log.info(f"Creating view: {viewname!r}")
    _exec_ddl(engine, ddl)
def assert_view_has_same_num_rows(
    engine: Engine, basetable: str, viewname: str
) -> None:
    """
    Ensures that a view gives the same number of rows as a table. (For use in
    situations where this should hold; views don't have to do this in general!)

    Args:
        engine: SQLAlchemy database Engine
        basetable: name of the table that this view should have a 1:1
            relationship to
        viewname: view name

    Raises:
        :exc:`AssertionError` if they don't have the same number of rows
    """
    # Note that this relies on the data, i.e. design failures MAY cause this
    # check to fail, but won't necessarily (e.g. if the table is empty).
    n_base = count_star(engine, basetable)
    n_view = count_star(engine, viewname)
    # Raise explicitly rather than via an "assert" statement, so the check
    # still happens when Python runs with optimizations (-O strips asserts).
    if n_view != n_base:
        raise AssertionError(
            f"View bug: view {viewname} has {n_view} records but its base "
            f"table {basetable} has {n_base}; they should be equal"
        )
def drop_view(engine: Engine, viewname: str, quiet: bool = False) -> None:
    """
    Drops a view, if it exists (name compared case-insensitively).

    Whether we act or just print is conditional on previous calls to
    :func:`set_print_not_execute`.

    Args:
        engine: SQLAlchemy database Engine
        viewname: view name
        quiet: don't announce this to the Python log
    """
    # MySQL has DROP VIEW IF EXISTS, but SQL Server only has that from
    # SQL Server 2016 onwards, so we check for existence ourselves.
    # - https://msdn.microsoft.com/en-us/library/ms173492.aspx
    # - http://dev.mysql.com/doc/refman/5.7/en/drop-view.html
    known_views = get_view_names(engine, to_lower=True)
    if viewname.lower() not in known_views:
        log.debug(f"View {viewname} does not exist; not dropping")
        return
    if not quiet:
        log.info(f"Dropping view: {viewname!r}")
    _exec_ddl(engine, f"DROP VIEW {viewname}")
def get_column_fk_description(c: Column) -> str:
    """
    Describes a column's foreign keys in a standard way, e.g.
    ``"FK to othertable.id"``; returns an empty string if it has none.

    Targets are listed in order of (table name, column name).

    Args:
        c:
            SQLAlchemy Column
    """
    targets = sorted(
        (fk.column.table.name, fk.column.name) for fk in c.foreign_keys
    )
    if not targets:
        return ""
    described = ", ".join(f"{tname}.{cname}" for tname, cname in targets)
    return "FK to " + described
@dataclass
class ReflectedColumnInfo:
    """
    Provides information about a column reflected from a database, with
    optional additional information from a CRATE data dictionary, +/- a
    description of values in that column (for researcher reports).
    """

    column: Column
    # Can override the SQLAlchemy-level comment (see the "comment" property).
    override_comment: Optional[str] = None
    # Annotation from a CRATE data dictionary, if available.
    crate_annotation: Optional[str] = None
    # Description of the values in this column, if available.
    values_info: Optional[str] = None

    @property
    def name(self) -> str:
        """
        Synonym for :attr:`columnname`.
        """
        return self.columnname

    @property
    def columnname(self) -> str:
        """
        The column's name, exactly as reflected.
        """
        return self.column.name

    # Do not manipulate the case of SOURCE tables/columns.
    # If you do, they can fail to match the SQLAlchemy
    # introspection and cause a crash.

    @property
    def tablename(self) -> str:
        """
        The name of the table this column belongs to, exactly as reflected.
        """
        return self.column.table.name

    @property
    def tablename_columname(self) -> str:
        """
        ``"table.column"``.
        """
        return f"{self.column.table.name}.{self.column.name}"

    @property
    def sqla_coltype(self) -> TypeEngine:
        """
        The SQLAlchemy type object of the column.
        """
        return self.column.type

    @property
    def sql_type(self) -> str:
        """
        The column's SQL type as a string. If that string cannot be compiled,
        the column is logged and the :exc:`CompileError` re-raised.
        """
        try:
            return str(self.column.type)
        except CompileError:
            log.critical(f"Column that failed was: {self.column!r}")
            raise

    @property
    def datatype_sqltext(self) -> str:
        """
        Synonym for :attr:`sql_type`.
        """
        return self.sql_type

    @property
    def pk(self) -> bool:
        """
        Is the column (part of) the primary key?
        """
        return self.column.primary_key

    @property
    def nullable(self) -> bool:
        """
        Is the column nullable?
        """
        return self.column.nullable

    @property
    def comment(self) -> str:
        """
        The overriding comment if one was supplied, otherwise the database
        comment if present, otherwise an empty string.
        """
        db_comment = getattr(self.column, "comment", "")
        # ... not all dialects support reflecting comments;
        # https://docs.sqlalchemy.org/en/14/core/reflection.html
        return self.override_comment or db_comment or ""

    @property
    def nullable_str(self) -> str:
        """
        Human-oriented nullability marker for reports.
        """
        return "✓" if self.nullable else "NOT NULL"

    @property
    def pk_str(self) -> str:
        """
        ``"PK"`` for primary-key columns, otherwise empty.
        """
        return "PK" if self.pk else ""

    @property
    def fk_str(self) -> str:
        """
        Description of the column's foreign keys (may be empty).
        """
        return get_column_fk_description(self.column)

    def get_column_source_description(self, with_fk: bool = True) -> str:
        """
        Returns a description of where the column is from, used as a suffix for
        data dictionary comment generation.

        Args:
            with_fk:
                Include foreign key descriptions (helpful because CRATE doesn't
                reproduce FK relationships in the destination DDL).
        """
        if with_fk:
            fk_str = self.fk_str
            if fk_str:
                fk_str = "; " + fk_str
        else:
            fk_str = ""
        return f" [from {self.tablename_columname}{fk_str}]"

    @property
    def crate_annotation_str(self) -> str:
        """
        Human-oriented version for report.
        """
        return self.crate_annotation or "?"

    @property
    def values_info_str(self) -> str:
        """
        Human-oriented version for report.
        """
        return self.values_info or "?"
# =============================================================================
# ViewMaker
# =============================================================================
class ViewMaker:
    """
    View-building assistance class.

    Accumulates SELECT, FROM and WHERE elements for a view over a single base
    table, and keeps notes of lookup tables and index requests for the
    benefit of the calling framework.
    """

    def __init__(
        self,
        viewname: str,
        engine: Engine,
        basetable: str,
        existing_to_lower: bool = False,
        rename: Optional[Dict[str, str]] = None,
        userobj: Any = None,
        enforce_same_n_rows_as_base: bool = True,
        insert_basetable_columns: bool = True,
    ) -> None:
        """
        Args:
            viewname: name of the view
            engine: SQLAlchemy database Engine
            basetable: name of the single base table that this view draws from
            existing_to_lower: translate column names to lower case in the
                view?
            rename: optional dictionary mapping ``from_name: to_name`` to
                translate column names in the view. A falsy ``to_name``
                omits that column from the view. Keys are compared against
                the column names as read (i.e. lower-cased if
                ``existing_to_lower`` is set).
            userobj: optional object (e.g. `argparse.Namespace`,
                dictionary...), not used by this class, and purely to store
                information for others' benefit
            enforce_same_n_rows_as_base: ensure that the view produces the
                same number of rows as its base table?
            insert_basetable_columns: start drafting the view by including all
                columns from the base table?
        """
        rename = rename or {}
        assert basetable, "ViewMaker: basetable missing!"
        self.viewname = viewname
        self.engine = engine
        self.basetable = basetable
        self.userobj = userobj  # only for others' benefit
        self.enforce_same_n_rows_as_base = enforce_same_n_rows_as_base
        self.select_elements = []  # type: List[str]
        self.from_elements = [basetable]  # type: List[str]
        self.where_elements = []  # type: List[str]
        self.lookup_tables = []  # type: List[str]
        self.index_requests = OrderedDict()  # type: Dict[str, List[str]]

        if insert_basetable_columns:
            grammar = make_grammar(engine.dialect.name)

            def q(identifier: str) -> str:
                # Quote only where the dialect requires it.
                return grammar.quote_identifier_if_required(identifier)

            for colname in get_column_names(
                engine, tablename=basetable, to_lower=existing_to_lower
            ):
                if colname in rename:
                    rename_to = rename[colname]
                    if not rename_to:
                        # Falsy target name: leave this column out.
                        continue
                    as_clause = f" AS {q(rename_to)}"
                else:
                    as_clause = ""
                self.select_elements.append(
                    f"{q(basetable)}.{q(colname)}{as_clause}"
                )
            assert (
                self.select_elements
            ), "Must have some active SELECT elements from base table"

    def add_select(self, element: str) -> None:
        """
        Add an element to the SELECT clause of the draft view's SQL
        (meaning: add e.g. a result column).
        """
        self.select_elements.append(element)

    def add_from(self, element: str) -> None:
        """
        Add an element to the FROM clause of the draft view's SQL statement.
        """
        self.from_elements.append(element)

    def add_where(self, element: str) -> None:
        """
        Add an element to the WHERE clause of the draft view's SQL statement.
        Multiple elements are combined with AND.
        """
        self.where_elements.append(element)

    def get_sql(self) -> str:
        """
        Returns the view-creation SQL (the SELECT statement defining the
        view).
        """
        assert self.select_elements, "ViewMaker: no SELECT elements!"
        if self.where_elements:
            where = "\n WHERE {}".format(
                "\n AND ".join(self.where_elements)
            )
        else:
            where = ""
        return (
            "\n SELECT {select_elements}"
            "\n FROM {from_elements}{where}".format(
                select_elements=",\n ".join(self.select_elements),
                from_elements="\n ".join(self.from_elements),
                where=where,
            )
        )

    def create_view(self, engine: Engine) -> None:
        """
        Creates the view.

        Whether we act or just print is conditional on previous calls to
        :func:`set_print_not_execute`.

        If ``enforce_same_n_rows_as_base`` is set, check the number of rows
        returned matches the base table.

        Args:
            engine: SQLAlchemy database Engine
        """
        # Module-level create_view(), not this method.
        create_view(engine, self.viewname, self.get_sql())
        if self.enforce_same_n_rows_as_base:
            assert_view_has_same_num_rows(
                engine, self.basetable, self.viewname
            )

    def drop_view(self, engine: Engine) -> None:
        """
        Drops the view.

        Whether we act or just print is conditional on previous calls to
        :func:`set_print_not_execute`.

        Args:
            engine: SQLAlchemy database Engine
        """
        # Module-level drop_view(), not this method.
        drop_view(engine, self.viewname)

    def record_lookup_table(self, table: str) -> None:
        """
        Keep a record of a lookup table. The framework may wish to suppress
        these from a data dictionary later (e.g. create a view, suppress the
        messier raw data). See :meth:`get_lookup_tables`.

        Args:
            table: table name
        """
        if table not in self.lookup_tables:
            self.lookup_tables.append(table)

    def get_lookup_tables(self) -> List[str]:
        """
        Returns all lookup tables that we have recorded, in the order first
        recorded. See :meth:`record_lookup_table`.
        """
        return self.lookup_tables

    def request_index(self, table: str, column: str) -> None:
        """
        Note a request that a specific column be indexed. The framework can use
        the ViewMaker to keep a note of these requests, and then add index
        hints to a data dictionary if it wishes. See
        :meth:`get_index_request_dict`.

        Args:
            table: table name
            column: column name
        """
        columns = self.index_requests.setdefault(table, [])
        if column not in columns:
            columns.append(column)

    def get_index_request_dict(self) -> Dict[str, List[str]]:
        """
        Returns all our recorded index requests, as a dictionary mapping each
        table name to a list of column names to be indexed. See
        :meth:`request_index`.
        """
        return self.index_requests

    def record_lookup_table_keyfield(
        self, table: str, keyfield: Union[str, Iterable[str]]
    ) -> None:
        """
        Makes a note that a table is a lookup table, and its key field(s)
        should be indexed. See :meth:`get_lookup_tables`,
        :meth:`get_index_request_dict`.

        Args:
            table: table name
            keyfield: field name, or iterable (e.g. list) of them
        """
        if isinstance(keyfield, str):
            keyfield = [keyfield]
        self.record_lookup_table(table)
        for kf in keyfield:
            self.request_index(table, kf)

    def record_lookup_table_keyfields(
        self,
        table_keyfield_tuples: Iterable[Tuple[str, Union[str, Iterable[str]]]],
    ) -> None:
        """
        Make a note of a whole set of lookup table / key field groups. See
        :meth:`record_lookup_table_keyfield`.

        Args:
            table_keyfield_tuples:
                iterable (e.g. list) of tuples of the format ``tablename,
                keyfield``. Each will be passed to
                :meth:`record_lookup_table_keyfield`.
        """
        for t, k in table_keyfield_tuples:
            self.record_lookup_table_keyfield(t, k)
# =============================================================================
# TransactionSizeLimiter
# =============================================================================
class TransactionSizeLimiter:
    """
    Keeps running totals of rows and bytes written in the current database
    transaction, and COMMITs once either configured threshold is reached.
    """

    def __init__(
        self,
        session: Session,
        max_rows_before_commit: int = None,
        max_bytes_before_commit: int = None,
    ) -> None:
        """
        Args:
            session: SQLAlchemy database Session
            max_rows_before_commit: row count at which to COMMIT, or ``None``
                for no row-based limit.
            max_bytes_before_commit: byte count at which to COMMIT, or
                ``None`` for no byte-based limit.
        """
        self._session = session
        self._max_rows_before_commit = max_rows_before_commit
        self._max_bytes_before_commit = max_bytes_before_commit
        self._reset_counters()

    def _reset_counters(self) -> None:
        # Start a fresh tally for the next transaction.
        self._bytes_in_transaction = 0
        self._rows_in_transaction = 0

    def commit(self) -> None:
        """
        COMMITs the session (timing the COMMIT) and zeroes the running
        totals.
        """
        with MultiTimerContext(timer, TIMING_COMMIT):
            self._session.commit()
        self._reset_counters()

    def notify(
        self, n_rows: int, n_bytes: int, force_commit: bool = False
    ) -> None:
        """
        Tells the limiter about data just inserted. If the byte total (checked
        first) or the row total reaches its limit, a COMMIT is issued.

        Args:
            n_rows: number of rows inserted
            n_bytes: number of bytes inserted
            force_commit: COMMIT immediately, regardless of totals?
        """
        if force_commit:
            self.commit()
            return

        self._bytes_in_transaction += n_bytes
        self._rows_in_transaction += n_rows

        byte_limit = self._max_bytes_before_commit
        row_limit = self._max_rows_before_commit

        if byte_limit is not None and self._bytes_in_transaction >= byte_limit:
            log.debug(
                f"Triggering early commit based on byte count "
                f"(reached {sizeof_fmt(self._bytes_in_transaction)}, "
                f"limit is {sizeof_fmt(byte_limit)})"
            )
            self.commit()
        elif row_limit is not None and self._rows_in_transaction >= row_limit:
            log.debug(
                f"Triggering early commit based on row count "
                f"(reached {self._rows_in_transaction} rows, "
                f"limit is {row_limit})"
            )
            self.commit()
# =============================================================================
# Specification matching
# =============================================================================
def _matches_tabledef(table: str, tabledef: str) -> bool:
    """
    Tests one table name against a single wildcard table specification.

    Args:
        table: tablename
        tabledef: ``fnmatch``-style pattern (e.g.
            ``"patient_address_table_*"``)

    Returns:
        whether the pattern matches the table name
    """
    return bool(get_spec_match_regex(tabledef).match(table))
def matches_tabledef(table: str, tabledef: Union[str, List[str]]) -> bool:
    """
    Does the table name match any of the wildcard-based table definitions?

    Args:
        table: table name
        tabledef: ``fnmatch``-style pattern (e.g.
            ``"patient_address_table_*"``), or list of them. An empty list
            (or ``None``) matches nothing.
    """
    # Treat a single pattern as a one-element list.
    patterns = [tabledef] if isinstance(tabledef, str) else (tabledef or [])
    return any(_matches_tabledef(table, pattern) for pattern in patterns)
def _matches_fielddef(table: str, field: str, fielddef: str) -> bool:
    """
    Tests a table/field pair against a single wildcard field specification.

    If the specification has no table part, only the field must match;
    otherwise both the table and the field must match.

    Args:
        table: tablename
        field: fieldname
        fielddef: ``fnmatch``-style pattern (e.g. ``"system_table.*"``,
            ``"*.nhs_number"``)
    """
    spec = split_db_schema_table_column(fielddef)
    field_regex = get_spec_match_regex(spec.column)
    if spec.table:
        # A table part was given; it must match before we look at the field.
        table_regex = get_spec_match_regex(spec.table)
        if not table_regex.match(table):
            return False
    return bool(field_regex.match(field))
def matches_fielddef(
    table: str, field: str, fielddef: Union[str, List[str]]
) -> bool:
    """
    Does the table/field pair match any of the wildcard-based field
    definitions?

    Args:
        table: table name
        field: fieldname
        fielddef: ``fnmatch``-style pattern (e.g. ``"system_table.*"`` or
            ``"*.nhs_number"``), or list of them. An empty list (or ``None``)
            matches nothing.
    """
    # Treat a single pattern as a one-element list.
    patterns = [fielddef] if isinstance(fielddef, str) else (fielddef or [])
    return any(_matches_fielddef(table, field, pattern) for pattern in patterns)
# =============================================================================
# More SQL
# =============================================================================
def sql_fragment_cast_to_int(
    expr: str,
    big: bool = True,
    dialect: Dialect = None,
    viewmaker: ViewMaker = None,
) -> str:
    """
    Takes an SQL expression and coerces it to an integer. For Microsoft SQL
    Server.

    Args:
        expr: starting SQL expression
        big: use BIGINT, not INTEGER?
        dialect: optional :class:`sqlalchemy.engine.interfaces.Dialect`. If
            ``None`` and we have a ``viewmaker``, use the viewmaker's dialect.
            Otherwise, assume SQL Server.
        viewmaker: optional :class:`ViewMaker`

    Returns:
        modified SQL expression

    Raises:
        :exc:`ValueError` if the dialect is known and is not SQL Server.

    *Notes*

    Conversion to INT:

    - https://stackoverflow.com/questions/2000045
    - https://stackoverflow.com/questions/14719760 (this one in particular!)
    - https://stackoverflow.com/questions/14692131

      - see LIKE example.
      - see ISNUMERIC();
        https://msdn.microsoft.com/en-us/library/ms186272.aspx;
        but that includes non-integer numerics

    - https://msdn.microsoft.com/en-us/library/ms174214(v=sql.120).aspx;
      relates to the SQL Server Management Studio "Find and Replace"
      dialogue box, not to SQL itself!
    - https://stackoverflow.com/questions/29206404/mssql-regular-expression

    Note that the regex-like expression supported by LIKE is extremely limited.

    - https://msdn.microsoft.com/en-us/library/ms179859.aspx
    - The only things supported are:

      .. code-block:: none

        %       any characters
        _       any single character
        []      single character in range or set, e.g. [a-f], [abcdef]
        [^]     single character NOT in range or set, e.g. [^a-f], [^abcdef]

    SQL Server does not support a REGEXP command directly.

    So the best bet is to have the LIKE clause check for a non-integer:

    .. code-block:: sql

        CASE
            WHEN something LIKE '%[^0-9]%' THEN NULL
            ELSE CAST(something AS BIGINT)
        END

    ... which doesn't deal with spaces properly, but there you go.
    Could also strip whitespace left/right:

    .. code-block:: sql

        CASE
            WHEN LTRIM(RTRIM(something)) LIKE '%[^0-9]%' THEN NULL
            ELSE CAST(something AS BIGINT)
        END

    That only works for positive integers.

    LTRIM/RTRIM are not ANSI SQL.
    Nor are unusual LIKE clauses; see
    https://stackoverflow.com/questions/712580/list-of-special-characters-for-sql-like-clause

    The other, for SQL Server 2012 or higher, is TRY_CAST:

    .. code-block:: sql

        TRY_CAST(something AS BIGINT)

    ... which returns NULL upon failure; see
    https://msdn.microsoft.com/en-us/library/hh974669.aspx

    Therefore, our **method** is as follows:

    - If the database supports TRY_CAST, use that.
    - Otherwise if we're using SQL Server, use a CASE/CAST construct.
    - Otherwise, raise :exc:`ValueError` as we don't know what to do.

    If the dialect's ``server_version_info`` is not available (``None``),
    TRY_CAST support is not assumed.
    """  # noqa: E501
    inttype = "BIGINT" if big else "INTEGER"
    if dialect is None and viewmaker is not None:
        dialect = viewmaker.engine.dialect
    if dialect is None:
        # No dialect information at all: assume (older) SQL Server.
        sql_server = True
        supports_try_cast = False
    else:
        # noinspection PyUnresolvedReferences
        sql_server = dialect.name == "mssql"
        # server_version_info may be None (e.g. a dialect that has not yet
        # been initialized by a connection); comparing None with a version
        # tuple would raise TypeError.
        version = getattr(dialect, "server_version_info", None)
        supports_try_cast = (
            sql_server
            and version is not None
            and version >= MS_2012_VERSION
        )
    if supports_try_cast:
        return f"TRY_CAST({expr} AS {inttype})"
    elif sql_server:
        # Doesn't support negative integers.
        return (
            f"CASE WHEN LTRIM(RTRIM({expr})) LIKE '%[^0-9]%' "
            f"THEN NULL ELSE CAST({expr} AS {inttype}) END"
        )
    else:
        # noinspection PyUnresolvedReferences
        raise ValueError(
            f"Code not yet written for convert-to-int for "
            f"dialect {dialect.name}"
        )
# =============================================================================
# Abstracted SQL WHERE condition
# =============================================================================
@register_for_json(method=METHOD_PROVIDES_INIT_KWARGS)
@functools.total_ordering
class WhereCondition:
    """
    Ancillary class for building SQL WHERE expressions from our web forms.

    The essence of it is ``WHERE column op value_or_values``.
    """

    def __init__(
        self,
        column_id: ColumnId = None,
        op: str = "",
        datatype: str = "",
        value_or_values: Any = None,
        raw_sql: str = "",
        from_table_for_raw_sql: TableId = None,
    ) -> None:
        """
        Args:
            column_id:
                :class:`ColumnId` for the column
            op:
                operation (e.g. ``=``, ``<``, ``<=``, etc.); stored upper-case
            datatype:
                data type string that must match values in our
                ``querybuilder.js``; see source code. We use this to know how
                to build SQL literal values. (Not terribly elegant, but it
                works; SQL injection isn't a particular concern because we
                let our users run any SQL they want and ensure the connection
                is made read-only.)
            value_or_values:
                ``None``, single value, or list of values. Which is appropriate
                depends on the operation. For example, ``IS NULL`` takes no
                value; ``=`` takes one; ``IN`` takes many.
            raw_sql:
                override any thinking we might wish to do, and just return this
                raw SQL
            from_table_for_raw_sql:
                if we are using raw SQL, provide a :class:`TableId` for the
                relevant table here
        """
        self._column_id = column_id
        self._op = op.upper()
        self._datatype = datatype
        self._value = value_or_values
        self._no_value = False
        self._multivalue = False
        self._raw_sql = raw_sql
        self._from_table_for_raw_sql = from_table_for_raw_sql

        if not self._raw_sql:
            # Check that the value(s) supplied suit the operator.
            if self._op in SQL_OPS_VALUE_UNNECESSARY:
                self._no_value = True
                assert value_or_values is None, "Superfluous value passed"
            elif self._op in SQL_OPS_MULTIPLE_VALUES:
                self._multivalue = True
                assert isinstance(value_or_values, list), "Need list"
            else:
                assert not isinstance(
                    value_or_values, list
                ), "Need single value"

    def init_kwargs(self) -> Dict:
        """
        Keyword arguments that would recreate this object via ``__init__``
        (used for JSON serialization).
        """
        return {
            "column_id": self._column_id,
            "op": self._op,
            "datatype": self._datatype,
            "value_or_values": self._value,
            "raw_sql": self._raw_sql,
            "from_table_for_raw_sql": self._from_table_for_raw_sql,
        }

    def __repr__(self) -> str:
        return (
            "{qualname}("
            "column_id={column_id}, "
            "op={op}, "
            "datatype={datatype}, "
            "value_or_values={value_or_values}, "
            "raw_sql={raw_sql}, "
            "from_table_for_raw_sql={from_table_for_raw_sql}"
            ")".format(
                qualname=self.__class__.__qualname__,
                column_id=repr(self._column_id),
                op=repr(self._op),
                datatype=repr(self._datatype),
                value_or_values=repr(self._value),
                raw_sql=repr(self._raw_sql),
                from_table_for_raw_sql=repr(self._from_table_for_raw_sql),
            )
        )

    def __eq__(self, other: "WhereCondition") -> bool:
        # Comparing with a foreign type (e.g. None) must not raise
        # AttributeError; defer to Python's default handling instead.
        if not isinstance(other, WhereCondition):
            return NotImplemented
        return (
            self._raw_sql == other._raw_sql
            and self._column_id == other._column_id
            and self._op == other._op
            and self._value == other._value
        )

    def __lt__(self, other: "WhereCondition") -> bool:
        if not isinstance(other, WhereCondition):
            return NotImplemented
        return (self._raw_sql, self._column_id, self._op, self._value) < (
            other._raw_sql,
            other._column_id,
            other._op,
            other._value,
        )

    @property
    def column_id(self) -> ColumnId:
        """
        Returns the :class:`ColumnId` provided at creation.
        """
        return self._column_id

    @property
    def table_id(self) -> TableId:
        """
        Returns a :class:`TableId`:

        - for raw SQL, our ``from_table_for_raw_sql`` attribute
        - otherwise, the table ID extracted from our ``column_id`` attribute
        """
        if self._raw_sql:
            return self._from_table_for_raw_sql
        return self.column_id.table_id

    def table_str(self, grammar: SqlGrammar) -> str:
        """
        Returns the table identifier in the specified SQL grammar.

        Args:
            grammar: :class:`cardinal_pythonlib.sql.sql_grammar.SqlGrammar`
        """
        return self.table_id.identifier(grammar)

    def sql(self, grammar: SqlGrammar) -> str:
        """
        Returns the WHERE clause (without ``WHERE`` itself!) for our condition,
        in the specified SQL grammar. Some examples might be:

        - ``somecol = 3``
        - ``othercol IN (6, 7, 8)``
        - ``thirdcol IS NOT NULL``
        - ``textcol LIKE '%paracetamol%'``
        - ``MATCH (fulltextcol) AGAINST ('paracetamol')`` (MySQL)
        - ``CONTAINS(fulltextcol, 'paracetamol')`` (SQL Server)

        Args:
            grammar: :class:`cardinal_pythonlib.sql.sql_grammar.SqlGrammar`
        """
        if self._raw_sql:
            return self._raw_sql

        col = self._column_id.identifier(grammar)
        op = self._op

        if self._no_value:
            return f"{col} {op}"

        # Choose how to render each value as an SQL literal.
        if self._datatype in QB_STRING_TYPES:
            element_converter = sql_string_literal
        elif self._datatype == QB_DATATYPE_DATE:
            element_converter = sql_date_literal
        elif self._datatype == QB_DATATYPE_INTEGER:
            element_converter = str
        elif self._datatype == QB_DATATYPE_FLOAT:
            element_converter = str
        else:
            # Safe default
            element_converter = sql_string_literal

        if self._multivalue:
            literal = "({})".format(
                ", ".join(element_converter(v) for v in self._value)
            )
        else:
            literal = element_converter(self._value)

        if self._op == "MATCH":  # MySQL
            return f"MATCH ({col}) AGAINST ({literal})"
        elif self._op == "CONTAINS":  # SQL Server
            return f"CONTAINS({col}, {literal})"
        else:
            return f"{col} {op} {literal}"
# =============================================================================
# SQL formatting
# =============================================================================
# =============================================================================
# Plain SQL types
# =============================================================================
def is_sql_column_type_textual(column_type: str, min_length: int = 1) -> bool:
    """
    Does an SQL column type look textual?

    Args:
        column_type: SQL column type as a string, e.g. ``"VARCHAR(50)"``
        min_length: what's the minimum string length we'll say "yes" to?

    Returns:
        is it a textual column (of the minimum length or more)?

    Note:

    - For SQL Server's NVARCHAR(MAX),
      :meth:`crate_anon.crateweb.research.research_db_info._schema_query_microsoft`
      returns "NVARCHAR(-1)". A negative length, or a literal ``MAX``, is
      treated as unlimited and therefore long enough.
    - Empty or whitespace-only input gives ``False``.
    """  # noqa: E501
    if not column_type:
        return False
    tokens = column_type.upper().split()
    if not tokens:
        # Whitespace only; previously this raised IndexError.
        return False
    column_type = tokens[0]  # e.g. "VARCHAR(50)" from "VARCHAR(50) COLLATE x"
    if column_type in SQLTYPES_TEXT:
        # A text type without a specific length
        return True
    m = COLTYPE_WITH_ONE_INTEGER_REGEX.match(column_type)
    if m is None:
        return False
    basetype = m.group(1)
    length_str = m.group(2)
    if basetype not in SQLTYPES_TEXT:
        return False
    if length_str == "MAX":
        # e.g. NVARCHAR(MAX): unlimited length.
        return True
    try:
        length = int(length_str)
    except (TypeError, ValueError):
        return False
    return length >= min_length or length < 0
def coltype_length_if_text(column_type: str, dialect: str) -> Optional[int]:
    """
    Works out the length of an SQL text column type.

    For a bare text type (no explicit length), the dialect's default length
    is looked up. For a type with an explicit length, e.g. ``"VARCHAR(50)"``,
    that length is returned; ``MAX``/``-1`` lengths are resolved only for SQL
    Server's ``VARCHAR``/``NVARCHAR``.

    Args:
        column_type: SQL column type as a string, e.g. ``"VARCHAR(50)"``
        dialect: the SQL dialect the column type is from

    Returns:
        length of the column or ``None`` if it's not a text column.

    Raises:
        :exc:`ValueError` if a length-less text type is given for a dialect,
        or type, whose default length CRATE doesn't know.
    """
    column_type = column_type.upper()

    if column_type not in SQLTYPES_TEXT:
        # Length specified - get it from the column type.
        m = COLTYPE_WITH_ONE_INTEGER_REGEX.match(column_type)
        if m is None:
            # Not the correct type of column
            return None
        basetype, length = m.group(1), m.group(2)
        if length in ("MAX", "-1"):
            # Unlimited length: only SQL Server's (N)VARCHAR(MAX) has a known
            # maximum.
            if dialect == SqlaDialectName.MSSQL:
                if basetype == "VARCHAR":
                    return MSSQL_COLTYPE_TO_LEN["VARCHAR_MAX"]
                if basetype == "NVARCHAR":
                    return MSSQL_COLTYPE_TO_LEN["NVARCHAR_MAX"]
            return None
        try:
            return int(length)
        except ValueError:
            # Not the correct type of column
            return None

    # No length specified - get the dialect's default.
    try:
        lookup = DIALECT_TO_STRING_LEN_LOOKUP[dialect]
    except KeyError:
        possible = list(DIALECT_TO_STRING_LEN_LOOKUP.keys())
        raise ValueError(
            f"CRATE doesn't properly understand SQL dialect {dialect!r}. "
            f"Supported: {possible}"
        )
    try:
        return lookup[column_type]
    except KeyError:
        raise ValueError(
            f"For SQL dialect {dialect!r}, CRATE doesn't know the length "
            f"for string data type {column_type!r}"
        )
def escape_quote_in_literal(s: str) -> str:
    r"""
    Escapes each single quote (``'``) by prefixing it with a backslash,
    giving ``\'``. (Doubling it, ``''``, would also work; the backslash form
    is used for consistency with percent escaping.)
    """
    return s.translate({ord("'"): "\\'"})
def escape_percent_in_literal(sql: str) -> str:
    r"""
    Escapes every ``%`` as ``\%``, for use inside LIKE clauses.

    - https://dev.mysql.com/doc/refman/5.7/en/string-literals.html
    """
    return sql.translate({ord("%"): "\\%"})
def escape_percent_for_python_dbapi(sql: str) -> str:
    """
    Doubles every ``%`` (``%`` becomes ``%%``), for SQL passed through a
    Python DBAPI driver that uses ``%`` for argument placeholders.
    """
    return sql.translate({ord("%"): "%%"})
def escape_sql_string_literal(s: str) -> str:
    r"""
    Escapes a fragment destined for an SQL string literal: single quotes
    first (to ``\'``), then percent signs (to ``\%``).
    """
    quotes_escaped = escape_quote_in_literal(s)
    return escape_percent_in_literal(quotes_escaped)
def make_string_literal(s: str) -> str:
    """
    Wraps an escaped version of a Python string in single quotes, giving an
    SQL string literal.
    """
    return "'" + escape_sql_string_literal(s) + "'"
def escape_sql_string_or_int_literal(s: Union[str, int]) -> str:
    """
    Renders an integer as its decimal text, or a string as a quoted, escaped
    SQL string literal.
    """
    return str(s) if isinstance(s, int) else make_string_literal(s)
def translate_sql_qmark_to_percent(sql: str) -> str:
    """
    This function translates SQL using ``?`` placeholders to SQL using ``%s``
    placeholders, without breaking literal ``'?'`` or ``'%'``, e.g. inside
    string literals.

    *Notes*

    - MySQL likes ``?`` as a placeholder.

      - https://dev.mysql.com/doc/refman/5.7/en/sql-syntax-prepared-statements.html

    - Python DBAPI allows several: ``%s``, ``?``, ``:1``, ``:name``,
      ``%(name)s``.

      - https://www.python.org/dev/peps/pep-0249/#paramstyle

    - Django uses ``%s``.

      - https://docs.djangoproject.com/en/1.8/topics/db/sql/

    - Microsoft like ``?``, ``@paramname``, and ``:paramname``.

      - https://msdn.microsoft.com/en-us/library/yy6y35y8(v=vs.110).aspx

    - We need to parse SQL with argument placeholders.

      - See :class:`cardinal_pythonlib.sql.sql_grammar.SqlGrammar` classes,
        particularly: ``bind_parameter``

    I prefer ``?``, because ``%`` is used in LIKE clauses, and the databases
    we're using like it.

    So:

    - We use ``%s`` when using ``cursor.execute()`` directly, via Django.
    - We use ``?`` when talking to users, and
      :class:`cardinal_pythonlib.sql.sql_grammar.SqlGrammar` objects, so that
      the visual appearance matches what they expect from their database.

    Only single-quoted literals are recognized as "quoted"; a doubled quote
    (``''``) inside a literal toggles twice and so leaves us inside it.
    """  # noqa: E501
    # 1. Escape % characters, so literal % survives DBAPI substitution.
    sql = escape_percent_for_python_dbapi(sql)
    # 2. Replace ? characters that are not within quotes with %s.
    # Collect pieces and join once, rather than repeated string concatenation
    # (which is potentially quadratic).
    pieces = []  # type: List[str]
    in_quotes = False
    for c in sql:
        if c == "'":
            in_quotes = not in_quotes
        if c == "?" and not in_quotes:
            pieces.append("%s")
        else:
            pieces.append(c)
    return "".join(pieces)
def decorate_index_name(
    idxname: str, tablename: str = None, engine: Engine = None
) -> str:
    """
    Adjusts a database index name where the backend requires it.

    SQLite (which we won't use much, but do use for testing!) won't accept
    two indexes with the same name on different tables, so for SQLite the
    table name is appended. Otherwise, or if ``tablename`` or ``engine`` is
    not supplied, the name is returned unchanged.

    Args:
        idxname:
            The original index name.
        tablename:
            The name of the table.
        engine:
            The SQLAlchemy engine, from which we obtain the dialect.

    Returns:
        The index name, amended if necessary.
    """
    if tablename and engine and engine.dialect.name == "sqlite":
        return f"{idxname}_{tablename}"
    return idxname