"""
crate_anon/common/sql.py
===============================================================================
Copyright (C) 2015, University of Cambridge, Department of Psychiatry.
Created by Rudolf Cardinal (rnc1001@cam.ac.uk).
This file is part of CRATE.
CRATE is free software: you can redistribute it and/or modify
it under the terms of the GNU General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.
CRATE is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU General Public License for more details.
You should have received a copy of the GNU General Public License
along with CRATE. If not, see <https://www.gnu.org/licenses/>.
===============================================================================
**Low-level SQL manipulation functions.**
"""
from collections import OrderedDict
from dataclasses import dataclass
import functools
import logging
import re
from typing import Any, Dict, Iterable, List, Tuple, Union, Optional
from cardinal_pythonlib.json.serialize import (
METHOD_PROVIDES_INIT_KWARGS,
METHOD_STRIP_UNDERSCORE,
register_for_json,
)
from cardinal_pythonlib.lists import unique_list
from cardinal_pythonlib.reprfunc import mapped_repr_stripping_underscores
from cardinal_pythonlib.sizeformatter import sizeof_fmt
from cardinal_pythonlib.sql.literals import (
sql_date_literal,
sql_string_literal,
)
from cardinal_pythonlib.sql.sql_grammar import SqlGrammar, text_from_parsed
from cardinal_pythonlib.sql.sql_grammar_factory import (
make_grammar,
mysql_grammar,
)
from cardinal_pythonlib.sqlalchemy.core_query import count_star
from cardinal_pythonlib.sqlalchemy.dialect import SqlaDialectName
from cardinal_pythonlib.sqlalchemy.schema import column_creation_ddl
from cardinal_pythonlib.timing import MultiTimerContext, timer
from pyparsing import ParseResults
from sqlalchemy import inspect
from sqlalchemy.dialects.mssql.base import MS_2012_VERSION
from sqlalchemy.engine.base import Engine
from sqlalchemy.engine.interfaces import Dialect
from sqlalchemy.exc import CompileError
from sqlalchemy.orm.session import Session
from sqlalchemy.schema import Column, Table
from sqlalchemy.sql.sqltypes import TypeEngine
from crate_anon.common.stringfunc import get_spec_match_regex
log = logging.getLogger(__name__)
# =============================================================================
# Types
# =============================================================================
SqlArgsTupleType = Tuple[str, List[Any]]
# =============================================================================
# Constants
# =============================================================================
TIMING_COMMIT = "commit"
SQL_OPS_VALUE_UNNECESSARY = ["IS NULL", "IS NOT NULL"]
SQL_OPS_MULTIPLE_VALUES = ["IN", "NOT IN"]
SQLTYPE_DATE = "DATE"
SQLTYPES_INTEGER = [
"INT",
"INTEGER",
"TINYINT",
"SMALLINT",
"MEDIUMINT",
"BIGINT",
"BIT",
"BOOL",
"BOOLEAN",
]
SQLTYPES_FLOAT = [
"DOUBLE",
"FLOAT",
"DEC",
"DECIMAL",
]
SQLTYPES_TEXT = [
"CHAR",
"VARCHAR",
"NVARCHAR",
"TINYTEXT",
"TEXT",
"NTEXT",
"MEDIUMTEXT",
"LONGTEXT",
]
SQLTYPES_WITH_DATE = [
SQLTYPE_DATE,
"DATETIME",
"TIMESTAMP",
]
# SQLTYPES_BINARY = [
# "BINARY", "BLOB", "IMAGE", "LONGBLOB", "VARBINARY",
# ]
# Must match querybuilder.js:
QB_DATATYPE_INTEGER = "int"
QB_DATATYPE_FLOAT = "float"
QB_DATATYPE_DATE = "date"
QB_DATATYPE_STRING = "string"
QB_DATATYPE_STRING_FULLTEXT = "string_fulltext"
QB_DATATYPE_UNKNOWN = "unknown"
QB_STRING_TYPES = [QB_DATATYPE_STRING, QB_DATATYPE_STRING_FULLTEXT]
COLTYPE_WITH_ONE_INTEGER_REGEX = re.compile(r"^([A-z]+)\((-?\d+)\)$")
# ... start, group(alphabetical), literal (, group(optional_minus_sign digits),
# literal ), end
# Dictionaries for the different dialects mapping text column type to length
# or default length.
# Doesn't include things like VARCHAR which require the user to specify length
MYSQL_COLTYPE_TO_LEN = {
# https://dev.mysql.com/doc/refman/8.0/en/string-type-overview.html
"CHAR": 1, # can specify CHAR(0) to CHAR(255), but if omitted, length is 1
"TINYTEXT": 255, # 2^8 - 1
"TEXT": 65535, # 2^16 - 1
"MEDIUMTEXT": 16777215, # 2^24 - 1
"LONGTEXT": 4294967295, # 2^32 - 1
}
MSSQL_COLTYPE_TO_LEN = {
# The "N" prefix means Unicode.
# https://docs.microsoft.com/en-us/sql/t-sql/data-types/char-and-varchar-transact-sql?view=sql-server-ver15 # noqa
# https://docs.microsoft.com/en-us/sql/t-sql/data-types/nchar-and-nvarchar-transact-sql?view=sql-server-ver15 # noqa
# https://docs.microsoft.com/en-us/sql/t-sql/data-types/ntext-text-and-image-transact-sql?view=sql-server-ver15 # noqa
"NVARCHAR_MAX": 2**30 - 1,
# Can specify NVARCHAR(1) to NVARCHAR(4000), or NVARCHAR(MAX) for 2^30 - 1.
"VARCHAR_MAX": 2**31 - 1,
# Can specify VARCHAR(1) to VARCHAR(8000), or VARCHAR(MAX) for 2^31 - 1.
"TEXT": 2**31 - 1,
"NTEXT": 2**30 - 1,
}
# def combine_db_schema_table(db: Optional[str],
# schema: Optional[str],
# table: str) -> str:
# # ANSI SQL: http://www.contrib.andrew.cmu.edu/~shadow/sql/sql1992.txt
# # <table name>, <qualified name>
# if not table:
# raise ValueError("Missing table supplied to combine_db_schema_table")
# return ".".join(x for x in [db, schema, table] if x)
# =============================================================================
# Helper classes
# =============================================================================
[docs]@dataclass
class IndexCreationInfo:
index_name: str #: Name of the index
column: Union[str, List[str]] #: Column name(s) to index
unique: bool = False #: Make a unique index?
@property
def column_names(self) -> str:
if isinstance(self.column, str):
# Single column
return self.column
else:
# Multiple columns
return ", ".join(self.column)
# =============================================================================
# SQL elements: identifiers
# =============================================================================
[docs]@register_for_json(method=METHOD_STRIP_UNDERSCORE)
@functools.total_ordering
class SchemaId:
"""
Represents a database schema. This is a bit complex:
- In SQL Server, schemas live within databases. Tables can be referred to
as ``table``, ``schema.table``, or ``database.schema.table``.
- https://docs.microsoft.com/en-us/dotnet/framework/data/adonet/sql/ownership-and-user-schema-separation-in-sql-server
- The default schema is named ``dbo``.
- In PostgreSQL, schemas live within databases. Tables can be referred to
as ``table``, ``schema.table``, or ``database.schema.table``.
- https://www.postgresql.org/docs/current/static/ddl-schemas.html
- The default schema is named ``public``.
- In MySQL, "database" and "schema" are synonymous. Tables can be referred
to as ``table`` or ``database.table`` (= ``schema.table``).
- https://stackoverflow.com/questions/11618277/difference-between-schema-database-in-mysql
""" # noqa
[docs] def __init__(self, db: str = "", schema: str = "") -> None:
"""
Args:
db: database name
schema: schema name
"""
assert "." not in db, f"Bad database name ({db!r}); can't include '.'"
assert (
"." not in schema
), f"Bad schema name ({schema!r}); can't include '.'"
self._db = db
self._schema = schema
@property
def schema_tag(self) -> str:
"""
String suitable for encoding the SchemaId e.g. in a single HTML form.
Takes the format ``database.schema``.
The :func:`__init__` function has already checked the assumption of no
``'.'`` characters in either part.
"""
return f"{self._db}.{self._schema}"
[docs] @classmethod
def from_schema_tag(cls, tag: str) -> "SchemaId":
"""
Returns a :class:`SchemaId` from a tag of the form ``db.schema``.
"""
parts = tag.split(".")
assert len(parts) == 2, f"Bad schema tag {tag!r}"
db, schema = parts
return SchemaId(db, schema)
def __bool__(self) -> bool:
"""
Returns:
is there a named schema?
"""
return bool(self._schema)
def __eq__(self, other: "SchemaId") -> bool:
return ( # ordering is for speed
self._schema == other._schema and self._db == other._db
)
def __lt__(self, other: "SchemaId") -> bool:
return (self._db, self._schema) < (other._db, other._schema)
def __hash__(self) -> int:
return hash(str(self))
[docs] def identifier(self, grammar: SqlGrammar) -> str:
"""
Returns an SQL identifier for this schema using the specified SQL
grammar, quoting it if need be.
Args:
grammar: :class:`cardinal_pythonlib.sql.sql_grammar.SqlGrammar`
"""
return make_identifier(grammar, database=self._db, schema=self._schema)
[docs] def table_id(self, table: str) -> "TableId":
"""
Returns a :class:`TableId` combining this schema and the specified
table.
Args:
table: name of the table
"""
return TableId(db=self._db, schema=self._schema, table=table)
[docs] def column_id(self, table: str, column: str) -> "ColumnId":
"""
Returns a :class:`ColumnId` combining this schema and the specified
table/column.
Args:
table: name of the table
column: name of the column
"""
return ColumnId(
db=self._db, schema=self._schema, table=table, column=column
)
@property
def db(self) -> str:
"""
Returns the database part.
"""
return self._db
@property
def schema(self) -> str:
"""
Returns the schema part.
"""
return self._schema
def __str__(self) -> str:
return self.identifier(mysql_grammar) # specific one unimportant
def __repr__(self) -> str:
return mapped_repr_stripping_underscores(self, ["_db", "_schema"])
[docs] def is_present(self) -> bool:
"""
Is this a blank/nonfunctional schema, with no ``database`` or
``schema`` part?
"""
return bool(self._db or self._schema)
[docs] def is_blank(self) -> bool:
"""
Is this a blank/nonfunctional schema, with no ``database`` or
``schema`` part?
"""
return not self.is_present()
[docs]@register_for_json(method=METHOD_STRIP_UNDERSCORE)
@functools.total_ordering
class TableId:
"""
Represents a database table.
"""
[docs] def __init__(
self, db: str = "", schema: str = "", table: str = ""
) -> None:
"""
Args:
db: database name
schema: schema name
table: table name
"""
self._db = db
self._schema = schema
self._table = table
def __bool__(self) -> bool:
return bool(self._table)
def __eq__(self, other: "TableId") -> bool:
return ( # ordering is for speed
self._table == other._table
and self._schema == other._schema
and self._db == other._db
)
def __lt__(self, other: "TableId") -> bool:
return (self._db, self._schema, self._table) < (
other._db,
other._schema,
other._table,
)
def __hash__(self) -> int:
return hash(str(self))
[docs] def identifier(self, grammar: SqlGrammar) -> str:
"""
Returns an SQL identifier for this table using the specified SQL
grammar, quoting it if need be.
Args:
grammar: :class:`cardinal_pythonlib.sql.sql_grammar.SqlGrammar`
"""
return make_identifier(
grammar, database=self._db, schema=self._schema, table=self._table
)
@property
def schema_id(self) -> SchemaId:
"""
Returns a :class:`SchemaId` for the schema of our table.
"""
return SchemaId(db=self._db, schema=self._schema)
[docs] def column_id(self, column: str) -> "ColumnId":
"""
Returns a :class:`ColumnId` combining this table and the specified
column.
Args:
column: name of the column
"""
return ColumnId(
db=self._db, schema=self._schema, table=self._table, column=column
)
[docs] def database_schema_part(self, grammar: SqlGrammar) -> str:
"""
Returns an SQL identifier for this table's database/schema (without the
table part) using the specified SQL grammar, quoting it if need be.
Args:
grammar: :class:`cardinal_pythonlib.sql.sql_grammar.SqlGrammar`
"""
return make_identifier(grammar, database=self._db, schema=self._schema)
[docs] def table_part(self, grammar: SqlGrammar) -> str:
"""
Returns an SQL identifier for this table's table name (only) using the
specified SQL grammar, quoting it if need be.
Args:
grammar: :class:`cardinal_pythonlib.sql.sql_grammar.SqlGrammar`
"""
return make_identifier(grammar, table=self._table)
@property
def db(self) -> str:
"""
Returns the database part.
"""
return self._db
@property
def schema(self) -> str:
"""
Returns the schema part.
"""
return self._schema
@property
def table(self) -> str:
"""
Returns the table part.
"""
return self._table
def __str__(self) -> str:
return self.identifier(mysql_grammar) # specific one unimportant
def __repr__(self) -> str:
return mapped_repr_stripping_underscores(
self, ["_db", "_schema", "_table"]
)
[docs]@register_for_json(method=METHOD_STRIP_UNDERSCORE)
@functools.total_ordering
class ColumnId:
"""
Represents a database column.
"""
[docs] def __init__(
self, db: str = "", schema: str = "", table: str = "", column: str = ""
) -> None:
"""
Args:
db: database name
schema: schema name
table: table name
column: column name
"""
self._db = db
self._schema = schema
self._table = table
self._column = column
def __bool__(self) -> bool:
return bool(self._column)
def __eq__(self, other: "ColumnId") -> bool:
return (
self._column == other._column
and self._table == other._table
and self._schema == other._schema
and self._db == other._db
)
def __lt__(self, other: "ColumnId") -> bool:
return (self._db, self._schema, self._table, self._column) < (
other._db,
other._schema,
other._table,
other._column,
)
@property
def is_valid(self) -> bool:
"""
Do we know about a table and a column, at least?
"""
return bool(self._table and self._column) # the minimum
def identifier(self, grammar: SqlGrammar) -> str:
return make_identifier(
grammar,
database=self._db,
schema=self._schema,
table=self._table,
column=self._column,
)
@property
def db(self) -> str:
"""
Returns the database part.
"""
return self._db
@property
def schema(self) -> str:
"""
Returns the schema part.
"""
return self._schema
@property
def table(self) -> str:
"""
Returns the table part.
"""
return self._table
@property
def column(self) -> str:
"""
Returns the column part.
"""
return self._column
@property
def schema_id(self) -> SchemaId:
"""
Returns a :class:`SchemaId` for the schema of our column.
"""
return SchemaId(db=self._db, schema=self._schema)
@property
def table_id(self) -> TableId:
"""
Returns a :class:`TableId` for our table.
"""
return TableId(db=self._db, schema=self._schema, table=self._table)
@property
def has_table_and_column(self) -> bool:
"""
Do we know about a table and a column?
"""
return bool(self._table and self._column)
def __str__(self) -> str:
return self.identifier(mysql_grammar) # specific one unimportant
def __repr__(self) -> str:
return mapped_repr_stripping_underscores(
self, ["_db", "_schema", "_table", "_column"]
)
# def html(self, grammar: SqlGrammar, bold_column: bool = True) -> str:
# components = [
# html.escape(grammar.quote_identifier_if_required(x))
# for x in [self._db, self._schema, self._table, self._column]
# if x]
# if not components:
# return ''
# if bold_column:
# components[-1] = f"<b>{components[-1]}</b>"
# return ".".join(components)
[docs]def split_db_schema_table(db_schema_table: str) -> TableId:
"""
Converts a simple SQL-style identifier string into a :class:`TableId`.
Args:
db_schema_table:
one of: ``database.schema.table``, ``schema.table``, ``table``
Returns:
a :class:`TableId`
Raises:
:exc:`ValueError` if the input is bad
"""
components = db_schema_table.split(".")
if len(components) == 3: # db.schema.table
d, s, t = components[0], components[1], components[2]
elif len(components) == 2: # schema.table
d, s, t = "", components[0], components[1]
elif len(components) == 1: # table
d, s, t = "", "", components[0]
else:
raise ValueError(f"Bad db_schema_table: {db_schema_table}")
return TableId(db=d, schema=s, table=t)
[docs]def split_db_schema_table_column(db_schema_table_col: str) -> ColumnId:
"""
Converts a simple SQL-style identifier string into a :class:`ColumnId`.
Args:
db_schema_table_col:
one of: ``database.schema.table.column``, ``schema.table.column``,
``table.column``, ``column``
Returns:
a :class:`ColumnId`
Raises:
:exc:`ValueError` if the input is bad
"""
components = db_schema_table_col.split(".")
if len(components) == 4: # db.schema.table.column
d, s, t, c = components[0], components[1], components[2], components[3]
elif len(components) == 3: # schema.table.column
d, s, t, c = "", components[0], components[1], components[2]
elif len(components) == 2: # table.column
d, s, t, c = "", "", components[0], components[1]
elif len(components) == 1: # column
d, s, t, c = "", "", "", components[0]
else:
raise ValueError(f"Bad db_schema_table_col: {db_schema_table_col}")
return ColumnId(db=d, schema=s, table=t, column=c)
[docs]def columns_to_table_column_hierarchy(
columns: List[ColumnId], sort: bool = True
) -> List[Tuple[TableId, List[ColumnId]]]:
"""
Converts a list of column IDs
Args:
columns: list of :class:`ColumnId` objects
sort: sort by table, and column within table?
Returns:
a list of tuples, each ``table, columns``, where ``table`` is a
:class:`TableId` and ``columns`` is a list of :class:`ColumnId`
"""
tables = unique_list(c.table_id for c in columns)
if sort:
tables.sort()
table_column_map = [] # type: List[Tuple[TableId, List[ColumnId]]]
for t in tables:
t_columns = [c for c in columns if c.table_id == t]
if sort:
t_columns.sort()
table_column_map.append((t, t_columns))
return table_column_map
# =============================================================================
# Using SQL grammars (but without reference to Django models, for testing)
# =============================================================================
[docs]def make_identifier(
grammar: SqlGrammar,
database: str = None,
schema: str = None,
table: str = None,
column: str = None,
) -> str:
"""
Makes an SQL identifier by quoting its elements according to the style of
the specific SQL grammar, and then joining them with ``.``.
Args:
grammar: :class:`cardinal_pythonlib.sql.sql_grammar.SqlGrammar`
database: database name
schema: schema name
table: table name
column: column name
Returns:
a string as above in the order "database, schema, table, column", but
omitting any missing parts
"""
elements = [
grammar.quote_identifier_if_required(x)
for x in (database, schema, table, column)
if x
]
assert elements, "make_identifier(): No elements passed!"
return ".".join(elements)
[docs]def dumb_make_identifier(
database: str = None,
schema: str = None,
table: str = None,
column: str = None,
) -> str:
"""
Makes an SQL-style identifier by joining all the parts with ``.``, without
bothering to quote them.
Args:
database: database name
schema: schema name
table: table name
column: column name
Returns:
a string as above in the order "database, schema, table, column", but
omitting any missing parts
"""
elements = filter(None, [database, schema, table, column])
assert elements, "make_identifier(): No elements passed!"
return ".".join(elements)
[docs]def parser_add_result_column(
parsed: ParseResults, column: str, grammar: SqlGrammar
) -> ParseResults:
"""
Takes a parsed SQL statement of the form
.. code-block:: sql
SELECT a, b, c
FROM sometable
WHERE conditions;
and adds a result column, e.g. ``d``, to give
.. code-block:: sql
SELECT a, b, c, d
FROM sometable
WHERE conditions;
Presupposes that there is at least one column already in the SELECT
statement.
Args:
parsed: a `pyparsing.ParseResults` result
column: column name
grammar: :class:`cardinal_pythonlib.sql.sql_grammar.SqlGrammar`
Returns:
a `pyparsing.ParseResults` result
"""
existing_columns = parsed.select_expression.select_columns.asList()
if column not in existing_columns:
# doesn't exist; add it
newcol = grammar.get_result_column().parseString(column, parseAll=True)
parsed.select_expression.extend([",", newcol])
return parsed
[docs]class JoinInfo:
"""
Object to represent a SQL join condition in a simple way.
"""
[docs] def __init__(
self,
table: str,
join_type: str = "INNER JOIN",
join_condition: str = "",
) -> None: # e.g. "ON x = y"
"""
Args:
table: table to be joined in
join_type: join method, e.g. ``"INNER JOIN"``
join_condition: join condition, e.g. ``"ON x = y"``
"""
self.join_type = join_type
self.table = table
self.join_condition = join_condition
[docs]def parser_add_from_tables(
parsed: ParseResults, join_info_list: List[JoinInfo], grammar: SqlGrammar
) -> ParseResults:
"""
Takes a parsed SQL statement of the form
.. code-block:: sql
SELECT a, b, c
FROM sometable
WHERE conditions;
and adds one or more join columns, e.g. ``JoinInfo("othertable", "INNER
JOIN", "ON table.key = othertable.key")``, to give
.. code-block:: sql
SELECT a, b, c
FROM sometable
INNER JOIN othertable ON table.key = othertable.key
WHERE conditions;
Presupposes that there at least one table already in the FROM clause.
Args:
parsed: a `pyparsing.ParseResults` result
join_info_list: list of :class:`JoinInfo` objects
grammar: :class:`cardinal_pythonlib.sql.sql_grammar.SqlGrammar`
Returns:
a `pyparsing.ParseResults` result
"""
# log.critical(parsed.dump())
existing_tables = parsed.join_source.from_tables.asList()
# log.critical(f"existing tables: {existing_tables}")
# log.critical(f"adding table: {table}")
for ji in join_info_list:
if ji.table in existing_tables: # already there
# log.critical("field already present")
continue
parsed_join = grammar.get_join_op().parseString(
ji.join_type, parseAll=True
)[
0
] # e.g. INNER JOIN
parsed_table = grammar.get_table_spec().parseString(
ji.table, parseAll=True
)[0]
extrabits = [parsed_join, parsed_table]
if ji.join_condition: # e.g. ON x = y
extrabits.append(
grammar.get_join_constraint().parseString(
ji.join_condition, parseAll=True
)[0]
)
parsed.join_source.extend(extrabits)
# log.critical(parsed.dump())
return parsed
[docs]def get_first_from_table(
parsed: ParseResults,
match_db: str = "",
match_schema: str = "",
match_table: str = "",
) -> TableId:
"""
Given a set of parsed results from a SELECT statement, returns the ``db,
schema, table`` tuple representing the first table in the FROM clause.
Optionally, the match may be constrained with the ``match*`` parameters.
Args:
parsed: a `pyparsing.ParseResults` result
match_db: optional database name to constrain the result to
match_schema: optional schema name to constrain the result to
match_table: optional table name to constrain the result to
Returns:
a :class:`TableId`, which will be empty in case of failure
"""
existing_tables = parsed.join_source.from_tables.asList()
for t in existing_tables:
if isinstance(t, list):
assert len(t) == 1
t = t[0]
table_id = split_db_schema_table(t)
if match_db and table_id.db != match_db:
continue
if match_schema and table_id.schema != match_schema:
continue
if match_table and table_id.table != match_table:
continue
return table_id
return TableId()
[docs]def set_distinct_within_parsed(p: ParseResults, action: str = "set") -> None:
"""
Modifies (in place) the DISTINCT status of a parsed SQL statement.
Args:
p: a `pyparsing.ParseResults` result
action: ``"set"`` to turn DISTINCT on; ``"clear"`` to turn it off;
or ``"toggle"`` to toggle it.
"""
ss = p.select_specifier # type: ParseResults
if action == "set":
if "DISTINCT" not in ss.asList():
ss.append("DISTINCT")
elif action == "clear":
if "DISTINCT" in ss.asList():
del ss[:]
elif action == "toggle":
if "DISTINCT" in ss.asList():
del ss[:]
else:
ss.append("DISTINCT")
else:
raise ValueError("action must be one of set/clear/toggle")
[docs]def set_distinct(
sql: str,
grammar: SqlGrammar,
action: str = "set",
formatted: bool = True,
debug: bool = False,
debug_verbose: bool = False,
) -> str:
"""
Takes an SQL statement (as a string) and modifies its DISTINCT status.
Args:
sql: SQL statment as text
grammar: :class:`cardinal_pythonlib.sql.sql_grammar.SqlGrammar`
action: one of ``"set"``, ``"clear"``, ``"toggle"``; see
:func:`set_distinct_within_parsed`
formatted: pretty-format the result?
debug: show debugging information to the Python log
debug_verbose: be verbose when debugging
Returns:
the modified SQL statment, as a string
"""
p = grammar.get_select_statement().parseString(sql, parseAll=True)
if debug:
log.info(f"START: {sql}")
if debug_verbose:
log.debug("start dump:\n" + p.dump())
set_distinct_within_parsed(p, action=action)
result = text_from_parsed(p, formatted=formatted)
if debug:
log.info(f"END: {result}")
if debug_verbose:
log.debug("end dump:\n" + p.dump())
return result
[docs]def toggle_distinct(
sql: str,
grammar: SqlGrammar,
formatted: bool = True,
debug: bool = False,
debug_verbose: bool = False,
) -> str:
"""
Takes an SQL statement and toggles its DISTINCT status.
Args:
sql: SQL statment as text
grammar: :class:`cardinal_pythonlib.sql.sql_grammar.SqlGrammar`
formatted: pretty-format the result?
debug: show debugging information to the Python log
debug_verbose: be verbose when debugging
Returns:
the modified SQL statment, as a string
"""
return set_distinct(
sql=sql,
grammar=grammar,
action="toggle",
formatted=formatted,
debug=debug,
debug_verbose=debug_verbose,
)
# =============================================================================
# SQLAlchemy reflection and DDL
# =============================================================================
_global_print_not_execute_sql = False
[docs]def set_print_not_execute(print_not_execute: bool) -> None:
"""
Sets a nasty global flag: should we print DDL, rather than executing it,
when we issue DDL commands from this module?
Args:
print_not_execute: print (not execute)?
"""
global _global_print_not_execute_sql
_global_print_not_execute_sql = print_not_execute
[docs]def execute(engine: Engine, sql: str) -> None:
"""
Executes SQL.
Whether we act or just print is conditional on previous calls to
:func:`set_print_not_execute`.
Args:
engine: SQLAlchemy database Engine
sql: raw SQL to execute (or print)
"""
log.debug(sql)
if _global_print_not_execute_sql:
print(format_sql_for_print(sql) + "\n;")
# extra \n in case the SQL ends in a comment
else:
engine.execute(sql)
[docs]def add_columns(engine: Engine, table: Table, columns: List[Column]) -> None:
"""
Adds columns to a table.
Whether we act or just print is conditional on previous calls to
:func:`set_print_not_execute`.
Args:
engine: SQLAlchemy database Engine
table: SQLAlchemy Table object
columns: SQLAlchemy Column objects to add to the table
Behaviour of different database systems:
- ANSI SQL: add one column at a time: ``ALTER TABLE ADD [COLUMN] coldef``
- i.e. "COLUMN" optional, one at a time, no parentheses
- https://www.contrib.andrew.cmu.edu/~shadow/sql/sql1992.txt
- MySQL: ``ALTER TABLE ADD [COLUMN] (a INT, b VARCHAR(32));``
- i.e. "COLUMN" optional, parentheses required for >1, multiple OK
- https://dev.mysql.com/doc/refman/5.7/en/alter-table.html
- MS SQL Server: ``ALTER TABLE ADD COLUMN a INT, B VARCHAR(32);``
- i.e. no "COLUMN", no parentheses, multiple OK
- https://msdn.microsoft.com/en-us/library/ms190238.aspx
- https://msdn.microsoft.com/en-us/library/ms190273.aspx
- https://stackoverflow.com/questions/2523676
This function therefore operates one at a time.
SQLAlchemy doesn't provide a shortcut for this.
"""
existing_column_names = get_column_names(
engine, tablename=table.name, to_lower=True
)
column_defs = [] # type: List[str]
for column in columns:
if column.name.lower() not in existing_column_names:
column_defs.append(column_creation_ddl(column, engine.dialect))
else:
log.debug(
f"Table {table.name!r}: column {column.name!r} "
f"already exists; not adding"
)
for column_def in column_defs:
log.info(f"Table {table.name!r}: adding column {column_def!r}")
sql = f"ALTER TABLE {table.name} ADD {column_def}"
execute(engine, sql)
[docs]def drop_columns(
engine: Engine, table: Table, column_names: Iterable[str]
) -> None:
"""
Drops columns from a table.
Whether we act or just print is conditional on previous calls to
:func:`set_print_not_execute`.
Args:
engine: SQLAlchemy database Engine
table: SQLAlchemy Table object
column_names: names of columns to drop
Columns are dropped one by one.
"""
existing_column_names = get_column_names(
engine, tablename=table.name, to_lower=True
)
for name in column_names:
if name.lower() not in existing_column_names:
log.debug(
f"Table {table.name!r}: column {name!r} "
f"does not exist; not dropping"
)
else:
log.info(f"Table {table.name!r}: dropping column {name!r}")
sql = f"ALTER TABLE {table.name} DROP COLUMN {name}"
# SQL Server:
# http://www.techonthenet.com/sql_server/tables/alter_table.php
# MySQL:
# http://dev.mysql.com/doc/refman/5.7/en/alter-table.html
execute(engine, sql)
[docs]def add_indexes(
engine: Engine, table: Table, index_info_list: Iterable[IndexCreationInfo]
) -> None:
"""
Adds indexes to a table.
Whether we act or just print is conditional on previous calls to
:func:`set_print_not_execute`.
Args:
engine:
SQLAlchemy database Engine
table:
SQLAlchemy Table object
index_info_list:
Index(es) to create: list of :class:`IndexCreationInfo` objects.
"""
existing_index_names = get_index_names(
engine, tablename=table.name, to_lower=True
)
for i in index_info_list:
index_name = i.index_name
column = i.column_names
if index_name.lower() not in existing_index_names:
log.info(
f"Table {table.name!r}: adding index {index_name!r} on "
f"column {column!r}"
)
execute(
engine,
f"""
CREATE{" UNIQUE" if i.unique else ""} INDEX {index_name}
ON {table.name} ({column})
""",
)
else:
log.debug(
f"Table {table.name!r}: index {index_name!r} "
f"already exists; not adding"
)
[docs]def drop_indexes(
engine: Engine, table: Table, index_names: Iterable[str]
) -> None:
"""
Drops indexes from a table.
Whether we act or just print is conditional on previous calls to
:func:`set_print_not_execute`.
Args:
engine: SQLAlchemy database Engine
table: SQLAlchemy Table object
index_names: names of indexes to drop
"""
existing_index_names = get_index_names(
engine, tablename=table.name, to_lower=True
)
for index_name in index_names:
if index_name.lower() not in existing_index_names:
log.debug(
f"Table {table.name!r}: index {index_name!r} "
f"does not exist; not dropping"
)
else:
log.info(f"Table {table.name!r}: dropping index {index_name!r}")
if engine.dialect.name == "mysql":
sql = f"ALTER TABLE {table.name} DROP INDEX {index_name}"
elif engine.dialect.name == "mssql":
sql = f"DROP INDEX {table.name}.{index_name}"
else:
assert False, f"Unknown dialect: {engine.dialect.name}"
execute(engine, sql)
[docs]def get_table_names(
engine: Engine, to_lower: bool = False, sort: bool = False
) -> List[str]:
"""
Returns all table names for the database.
Args:
engine: SQLAlchemy database Engine
to_lower: convert table names to lower case?
sort: sort table names?
Returns:
list of table names
"""
inspector = inspect(engine)
table_names = inspector.get_table_names()
if to_lower:
table_names = [x.lower() for x in table_names]
if sort:
table_names = sorted(table_names, key=lambda x: x.lower())
return table_names
[docs]def get_view_names(
engine: Engine, to_lower: bool = False, sort: bool = False
) -> List[str]:
"""
Returns all view names for the database.
Args:
engine: SQLAlchemy database Engine
to_lower: convert view names to lower case?
sort: sort view names?
Returns:
list of view names
"""
inspector = inspect(engine)
view_names = inspector.get_view_names()
if to_lower:
view_names = [x.lower() for x in view_names]
if sort:
view_names = sorted(view_names, key=lambda x: x.lower())
return view_names
[docs]def get_column_names(
engine: Engine, tablename: str, to_lower: bool = False, sort: bool = False
) -> List[str]:
"""
Reads columns names afresh from the database, for a specific table (in case
metadata is out of date).
Args:
engine: SQLAlchemy database Engine
tablename: name of the table
to_lower: convert view names to lower case?
sort: sort view names?
Returns:
list of column names
"""
inspector = inspect(engine)
columns = inspector.get_columns(tablename)
column_names = [x["name"] for x in columns]
if to_lower:
column_names = [x.lower() for x in column_names]
if sort:
column_names = sorted(column_names, key=lambda x: x.lower())
return column_names
[docs]def get_index_names(
engine: Engine, tablename: str, to_lower: bool = False, sort: bool = False
) -> List[str]:
"""
Reads index names from the database, for a specific table.
Args:
engine: SQLAlchemy database Engine
tablename: name of the table
to_lower: convert index names to lower case?
sort: sort index names?
Returns:
list of index names
"""
# http://docs.sqlalchemy.org/en/latest/core/reflection.html
inspector = inspect(engine)
indexes = inspector.get_indexes(tablename)
index_names = [x["name"] for x in indexes if x["name"]]
# ... at least for SQL Server, there always seems to be a blank one
# with {'name': None, ...}.
if to_lower:
index_names = [x.lower() for x in index_names]
if sort:
index_names = sorted(index_names, key=lambda x: x.lower())
return index_names
[docs]def ensure_columns_present(
engine: Engine, tablename: str, column_names: Iterable[str]
) -> None:
"""
Ensure all these columns are present in a table, or raise an exception.
Operates in case-insensitive fashion.
Args:
engine: SQLAlchemy database Engine
tablename: name of the table
column_names: names of required columns
Raises:
:exc:`ValueError` if any are missing
"""
existing_column_names = get_column_names(
engine, tablename=tablename, to_lower=True
)
if not column_names:
return
for col in column_names:
if col.lower() not in existing_column_names:
raise ValueError(
f"Column {col!r} missing from table {tablename!r}, "
f"whose columns are {existing_column_names!r}"
)
[docs]def create_view(engine: Engine, viewname: str, select_sql: str) -> None:
"""
Creates a view.
Whether we act or just print is conditional on previous calls to
:func:`set_print_not_execute`.
Args:
engine: SQLAlchemy database Engine
viewname: view name
select_sql: SQL SELECT statement for this view
"""
if engine.dialect.name == "mysql":
# MySQL has CREATE OR REPLACE VIEW.
sql = f"CREATE OR REPLACE VIEW {viewname} AS {select_sql}"
else:
# SQL Server doesn't: https://stackoverflow.com/questions/18534919
drop_view(engine, viewname, quiet=True)
sql = f"CREATE VIEW {viewname} AS {select_sql}"
log.info(f"Creating view: {viewname!r}")
execute(engine, sql)
[docs]def assert_view_has_same_num_rows(
engine: Engine, basetable: str, viewname: str
) -> None:
"""
Ensures that a view gives the same number of rows as a table. (For use in
situations where this should hold; views don't have to do this in general!)
Args:
engine: SQLAlchemy database Engine
basetable: name of the table that this view should have a 1:1
relationship to
viewname: view name
Raises:
:exc:`AssertionError` if they don't have the same number of rows
"""
# Note that this relies on the data, i.e. design failures MAY cause this
# assertion to fail, but won't necessarily (e.g. if the table is empty).
n_base = count_star(engine, basetable)
n_view = count_star(engine, viewname)
assert n_view == n_base, (
f"View bug: view {viewname} has {n_view} records but its base table "
f"{basetable} has {n_base}; they should be equal"
)
[docs]def drop_view(engine: Engine, viewname: str, quiet: bool = False) -> None:
"""
Drops a view.
Whether we act or just print is conditional on previous calls to
:func:`set_print_not_execute`.
Args:
engine: SQLAlchemy database Engine
viewname: view name
quiet: don't announce this to the Python log
"""
# MySQL has DROP VIEW IF EXISTS, but SQL Server only has that from
# SQL Server 2016 onwards.
# - https://msdn.microsoft.com/en-us/library/ms173492.aspx
# - http://dev.mysql.com/doc/refman/5.7/en/drop-view.html
view_names = get_view_names(engine, to_lower=True)
if viewname.lower() not in view_names:
log.debug(f"View {viewname} does not exist; not dropping")
else:
if not quiet:
log.info(f"Dropping view: {viewname!r}")
sql = f"DROP VIEW {viewname}"
execute(engine, sql)
[docs]def get_column_fk_description(c: Column) -> str:
"""
Standardized description of a column's foreign keys.
Args:
c:
SQLAlchemy Column
"""
fkeys = sorted(
c.foreign_keys, key=lambda x: (x.column.table.name, x.column.name)
)
if not fkeys:
return ""
fk_strings = [f"{fk.column.table.name}.{fk.column.name}" for fk in fkeys]
return "FK to " + ", ".join(fk_strings)
[docs]@dataclass
class ReflectedColumnInfo:
"""
Provides information about a column reflected from a database, with
optional additional information from a CRATE data dictionary, +/- a
description of values in that column (for researcher reports).
"""
column: Column
override_comment: str = None # can override SQLAlchemy-level comment
crate_annotation: str = None
values_info: str = None
@property
def name(self) -> str:
return self.columnname
@property
def columnname(self) -> str:
return self.column.name
# Do not manipulate the case of SOURCE tables/columns.
# If you do, they can fail to match the SQLAlchemy
# introspection and cause a crash.
@property
def tablename(self) -> str:
return self.column.table.name
@property
def tablename_columname(self) -> str:
return f"{self.column.table.name}.{self.column.name}"
@property
def sqla_coltype(self) -> TypeEngine:
return self.column.type
@property
def sql_type(self) -> str:
try:
return str(self.column.type)
except CompileError:
log.critical(f"Column that failed was: {self.column!r}")
raise
@property
def datatype_sqltext(self) -> str:
return self.sql_type
@property
def pk(self) -> bool:
return self.column.primary_key
@property
def nullable(self) -> bool:
return self.column.nullable
@property
def comment(self) -> str:
"""
The database comment, if present, or another that has been supplied.
"""
db_comment = getattr(self.column, "comment", "")
# ... not all dialects support reflecting comments;
# https://docs.sqlalchemy.org/en/14/core/reflection.html
return self.override_comment or db_comment or ""
@property
def nullable_str(self) -> str:
return "✓" if self.nullable else "NOT NULL"
@property
def pk_str(self) -> str:
return "PK" if self.pk else ""
@property
def fk_str(self) -> str:
return get_column_fk_description(self.column)
[docs] def get_column_source_description(self, with_fk: bool = True) -> str:
"""
Returns a description of where the column is from, used as a suffix for
data dictionary comment generation.
Args:
with_fk:
Include foreign key descriptions (helpful because CRATE doesn't
reproduce FK relationships in the destination DDL).
"""
if with_fk:
fk_str = self.fk_str
if fk_str:
fk_str = "; " + fk_str
else:
fk_str = ""
return f" [from {self.tablename_columname}{fk_str}]"
@property
def crate_annotation_str(self) -> str:
"""
Human-oriented version for report.
"""
return self.crate_annotation or "?"
@property
def values_info_str(self) -> str:
"""
Human-oriented version for report.
"""
return self.values_info or "?"
# =============================================================================
# ViewMaker
# =============================================================================
[docs]class ViewMaker:
"""
View-building assistance class.
"""
[docs] def __init__(
self,
viewname: str,
engine: Engine,
basetable: str,
existing_to_lower: bool = False,
rename: Dict[str, str] = None,
userobj: Any = None,
enforce_same_n_rows_as_base: bool = True,
insert_basetable_columns: bool = True,
) -> None:
"""
Args:
viewname: name of the view
engine: SQLAlchemy database Engine
basetable: name of the single base table that this view draws from
existing_to_lower: translate column names to lower case in the
view?
rename: optional dictionary mapping ``from_name: to_name`` to
translate column names in the view
userobj: optional object (e.g. `argparse.Namespace`,
dictionary...), not used by this class, and purely to store
information for others' benefit
enforce_same_n_rows_as_base: ensure that the view produces the
same number of rows as its base table?
insert_basetable_columns: start drafting the view by including all
columns from the base table?
"""
rename = rename or {}
assert basetable, "ViewMaker: basetable missing!"
self.viewname = viewname
self.engine = engine
self.basetable = basetable
self.userobj = userobj # only for others' benefit
self.enforce_same_n_rows_as_base = enforce_same_n_rows_as_base
self.select_elements = [] # type: List[str]
self.from_elements = [basetable] # type: List[str]
self.where_elements = [] # type: List[str]
self.lookup_tables = [] # type: List[str]
self.index_requests = OrderedDict() # type: Dict[str, List[str]]
if insert_basetable_columns:
grammar = make_grammar(engine.dialect.name)
def q(identifier: str) -> str:
return grammar.quote_identifier_if_required(identifier)
for colname in get_column_names(
engine, tablename=basetable, to_lower=existing_to_lower
):
if colname in rename:
rename_to = rename[colname]
if not rename_to:
continue
as_clause = f" AS {q(rename_to)}"
else:
as_clause = ""
self.select_elements.append(
f"{q(basetable)}.{q(colname)}{as_clause}"
)
assert self.select_elements, (
"Must have some active SELECT " "elements from base table"
)
[docs] def add_select(self, element: str) -> None:
"""
Add an element to the SELECT clause of the the draft view's SQL
(meaning: add e.g. a result column).
"""
self.select_elements.append(element)
[docs] def add_from(self, element: str) -> None:
"""
Add an element to the FROM clause of the draft view's SQL statement.
"""
self.from_elements.append(element)
[docs] def add_where(self, element: str) -> None:
"""
Add an element to the WHERE clause of the draft view's SQL statement.
"""
self.where_elements.append(element)
[docs] def get_sql(self) -> str:
"""
Returns the view-creation SQL.
"""
assert self.select_elements, "ViewMaker: no SELECT elements!"
if self.where_elements:
where = "\n WHERE {}".format(
"\n AND ".join(self.where_elements)
)
else:
where = ""
return (
"\n SELECT {select_elements}"
"\n FROM {from_elements}{where}".format(
select_elements=",\n ".join(self.select_elements),
from_elements="\n ".join(self.from_elements),
where=where,
)
)
[docs] def create_view(self, engine: Engine) -> None:
"""
Creates the view.
Whether we act or just print is conditional on previous calls to
:func:`set_print_not_execute`.
If ``enforce_same_n_rows_as_base`` is set, check the number of rows
returned matches the base table.
Args:
engine: SQLAlchemy database Engine
"""
create_view(engine, self.viewname, self.get_sql())
if self.enforce_same_n_rows_as_base:
assert_view_has_same_num_rows(
engine, self.basetable, self.viewname
)
[docs] def drop_view(self, engine: Engine) -> None:
"""
Drops the view.
Whether we act or just print is conditional on previous calls to
:func:`set_print_not_execute`.
Args:
engine: SQLAlchemy database Engine
"""
drop_view(engine, self.viewname)
[docs] def record_lookup_table(self, table: str) -> None:
"""
Keep a record of a lookup table. The framework may wish to suppress
these from a data dictionary later (e.g. create a view, suppress the
messier raw data). See :func:`get_lookup_tables`.
Args:
table: table name
"""
if table not in self.lookup_tables:
self.lookup_tables.append(table)
[docs] def get_lookup_tables(self) -> List[str]:
"""
Returns all lookup tables that we have recorded. See
:func:`record_lookup_table`.
"""
return self.lookup_tables
[docs] def request_index(self, table: str, column: str) -> None:
"""
Note a request that a specific column be indexed. The framework can use
the ViewMaker to keep a note of these requests, and then add index
hints to a data dictionary if it wishes. See
:func:`get_index_request_dict`.
Args:
table: table name
column: column name
"""
if table not in self.index_requests:
self.index_requests[table] = [] # type: List[str]
if column not in self.index_requests[table]:
self.index_requests[table].append(column)
[docs] def get_index_request_dict(self) -> Dict[str, List[str]]:
"""
Returns all our recorded index requests, as a dictionary mapping each
table name to a list of column names to be indexed. See
:func:`request_index`.
"""
return self.index_requests
[docs] def record_lookup_table_keyfield(
self, table: str, keyfield: Union[str, Iterable[str]]
) -> None:
"""
Makes a note that a table is a lookup table, and its key field(s)
should be indexed. See :func:`get_lookup_tables`,
:func:`get_index_request_dict`.
Args:
table: table name
keyfield: field name, or iterable (e.g. list) of them
"""
if isinstance(keyfield, str):
keyfield = [keyfield]
self.record_lookup_table(table)
for kf in keyfield:
self.request_index(table, kf)
[docs] def record_lookup_table_keyfields(
self,
table_keyfield_tuples: Iterable[Tuple[str, Union[str, Iterable[str]]]],
) -> None:
"""
Make a note of a whole set of lookup table / key field groups. See
:func:`record_lookup_table_keyfield`.
Args:
table_keyfield_tuples:
iterable (e.g. list) of tuples of the format ``tablename,
keyfield``. Each will be passed to
:func:`record_lookup_table_keyfield`.
"""
for t, k in table_keyfield_tuples:
self.record_lookup_table_keyfield(t, k)
# =============================================================================
# TransactionSizeLimiter
# =============================================================================
[docs]class TransactionSizeLimiter:
"""
Class to allow us to limit the size of database transactions.
"""
[docs] def __init__(
self,
session: Session,
max_rows_before_commit: int = None,
max_bytes_before_commit: int = None,
) -> None:
"""
Args:
session: SQLAlchemy database Session
max_rows_before_commit: how many rows should we insert before
triggering a COMMIT? ``None`` for no limit.
max_bytes_before_commit: how many bytes should we insert before
triggering a COMMIT? ``None`` for no limit.
"""
self._session = session
self._max_rows_before_commit = max_rows_before_commit
self._max_bytes_before_commit = max_bytes_before_commit
self._bytes_in_transaction = 0
self._rows_in_transaction = 0
[docs] def commit(self) -> None:
"""
Performs a database COMMIT and resets our counters.
(Measures some timing information, too.)
"""
with MultiTimerContext(timer, TIMING_COMMIT):
self._session.commit()
self._bytes_in_transaction = 0
self._rows_in_transaction = 0
[docs] def notify(
self, n_rows: int, n_bytes: int, force_commit: bool = False
) -> None:
"""
Use this function to notify the limiter of data that you've inserted
into the database. If the total number of rows or bytes exceeds a limit
that we've set, this will trigger a COMMIT.
Args:
n_rows: number of rows inserted
n_bytes: number of bytes inserted
force_commit: force a COMMIT?
"""
if force_commit:
self.commit()
return
self._bytes_in_transaction += n_bytes
self._rows_in_transaction += n_rows
if (
self._max_bytes_before_commit is not None
and self._bytes_in_transaction >= self._max_bytes_before_commit
):
log.debug(
f"Triggering early commit based on byte count "
f"(reached {sizeof_fmt(self._bytes_in_transaction)}, "
f"limit is {sizeof_fmt(self._max_bytes_before_commit)})"
)
self.commit()
elif (
self._max_rows_before_commit is not None
and self._rows_in_transaction >= self._max_rows_before_commit
):
log.debug(
f"Triggering early commit based on row count "
f"(reached {self._rows_in_transaction} rows, "
f"limit is {self._max_rows_before_commit})"
)
self.commit()
# =============================================================================
# Specification matching
# =============================================================================
def _matches_tabledef(table: str, tabledef: str) -> bool:
"""
Does the table name match the wildcard-based table definition?
Args:
table: tablename
tabledef: ``fnmatch``-style pattern (e.g.
``"patient_address_table_*"``)
"""
tr = get_spec_match_regex(tabledef)
return bool(tr.match(table))
[docs]def matches_tabledef(table: str, tabledef: Union[str, List[str]]) -> bool:
"""
Does the table name match the wildcard-based table definition?
Args:
table: table name
tabledef: ``fnmatch``-style pattern (e.g.
``"patient_address_table_*"``), or list of them
"""
if isinstance(tabledef, str):
return _matches_tabledef(table, tabledef)
elif not tabledef:
return False
else: # list
return any(_matches_tabledef(table, td) for td in tabledef)
def _matches_fielddef(table: str, field: str, fielddef: str) -> bool:
"""
Does the table/field name match the wildcard-based field definition?
Args:
table: tablename
field: fieldname
fielddef: ``fnmatch``-style pattern (e.g. ``"system_table.*"``,
``"*.nhs_number"``)
"""
column_id = split_db_schema_table_column(fielddef)
cr = get_spec_match_regex(column_id.column)
if not column_id.table:
# Table not specified in the wildcard.
# It's a match if the field matches.
return bool(cr.match(field))
# Table specified in the wildcard.
# Both the table and the field parts have to match.
tr = get_spec_match_regex(column_id.table)
return bool(tr.match(table)) and bool(cr.match(field))
[docs]def matches_fielddef(
table: str, field: str, fielddef: Union[str, List[str]]
) -> bool:
"""
Does the table/field name match the wildcard-based field definition?
Args:
table: table name
field: fieldname
fielddef: ``fnmatch``-style pattern (e.g. ``"system_table.*"`` or
``"*.nhs_number"``), or list of them
"""
if isinstance(fielddef, str):
return _matches_fielddef(table, field, fielddef)
elif not fielddef:
return False
else: # list
return any(_matches_fielddef(table, field, fd) for fd in fielddef)
# =============================================================================
# More SQL
# =============================================================================
[docs]def sql_fragment_cast_to_int(
expr: str,
big: bool = True,
dialect: Dialect = None,
viewmaker: ViewMaker = None,
) -> str:
"""
Takes an SQL expression and coerces it to an integer. For Microsoft SQL
Server.
Args:
expr: starting SQL expression
big: use BIGINT, not INTEGER?
dialect: optional :class:`sqlalchemy.engine.interfaces.Dialect`. If
``None`` and we have a ``viewmaker``, use the viewmaker's dialect.
Otherwise, assume SQL Server.
viewmaker: optional :class:`ViewMaker`
Returns:
modified SQL expression
*Notes*
Conversion to INT:
- https://stackoverflow.com/questions/2000045
- https://stackoverflow.com/questions/14719760 (this one in particular!)
- https://stackoverflow.com/questions/14692131
- see LIKE example.
- see ISNUMERIC();
https://msdn.microsoft.com/en-us/library/ms186272.aspx;
but that includes non-integer numerics
- https://msdn.microsoft.com/en-us/library/ms174214(v=sql.120).aspx;
relates to the SQL Server Management Studio "Find and Replace"
dialogue box, not to SQL itself!
- https://stackoverflow.com/questions/29206404/mssql-regular-expression
Note that the regex-like expression supported by LIKE is extremely limited.
- https://msdn.microsoft.com/en-us/library/ms179859.aspx
- The only things supported are:
.. code-block:: none
% any characters
_ any single character
[] single character in range or set, e.g. [a-f], [abcdef]
[^] single character NOT in range or set, e.g. [^a-f], [abcdef]
SQL Server does not support a REGEXP command directly.
So the best bet is to have the LIKE clause check for a non-integer:
.. code-block:: sql
CASE
WHEN something LIKE '%[^0-9]%' THEN NULL
ELSE CAST(something AS BIGINT)
END
... which doesn't deal with spaces properly, but there you go.
Could also strip whitespace left/right:
.. code-block:: sql
CASE
WHEN LTRIM(RTRIM(something)) LIKE '%[^0-9]%' THEN NULL
ELSE CAST(something AS BIGINT)
END
That only works for positive integers.
LTRIM/RTRIM are not ANSI SQL.
Nor are unusual LIKE clauses; see
https://stackoverflow.com/questions/712580/list-of-special-characters-for-sql-like-clause
The other, for SQL Server 2012 or higher, is TRY_CAST:
.. code-block:: sql
TRY_CAST(something AS BIGINT)
... which returns NULL upon failure; see
https://msdn.microsoft.com/en-us/library/hh974669.aspx
Therefore, our **method** is as follows:
- If the database supports TRY_CAST, use that.
- Otherwise if we're using SQL Server, use a CASE/CAST construct.
- Otherwise, raise :exc:`ValueError` as we don't know what to do.
"""
inttype = "BIGINT" if big else "INTEGER"
if dialect is None and viewmaker is not None:
dialect = viewmaker.engine.dialect
if dialect is None:
sql_server = True
supports_try_cast = False
else:
# noinspection PyUnresolvedReferences
sql_server = dialect.name == "mssql"
# noinspection PyUnresolvedReferences
supports_try_cast = (
sql_server and dialect.server_version_info >= MS_2012_VERSION
)
if supports_try_cast:
return f"TRY_CAST({expr} AS {inttype})"
elif sql_server:
return (
f"CASE WHEN LTRIM(RTRIM({expr})) LIKE '%[^0-9]%' "
f"THEN NULL ELSE CAST({expr} AS {inttype}) END"
)
# Doesn't support negative integers.
else:
# noinspection PyUnresolvedReferences
raise ValueError(
f"Code not yet written for convert-to-int for "
f"dialect {dialect.name}"
)
# =============================================================================
# Abstracted SQL WHERE condition
# =============================================================================
[docs]@register_for_json(method=METHOD_PROVIDES_INIT_KWARGS)
@functools.total_ordering
class WhereCondition:
"""
Ancillary class for building SQL WHERE expressions from our web forms.
The essence of it is ``WHERE column op value_or_values``.
"""
[docs] def __init__(
self,
column_id: ColumnId = None,
op: str = "",
datatype: str = "",
value_or_values: Any = None,
raw_sql: str = "",
from_table_for_raw_sql: TableId = None,
) -> None:
"""
Args:
column_id:
:class:`ColumnId` for the column
op:
operation (e.g. ``=``, ``<``, ``<=``, etc.)
datatype:
data type string that must match values in our
``querybuilder.js``; see source code. We use this to know how
to build SQL literal values. (Not terribly elegant, but it
works; SQL injection isn't a particular concern because we
let our users run any SQL they want and ensure the connection
is made read-only.)
value_or_values:
``None``, single value, or list of values. Which is appropriate
depends on the operation. For example, ``IS NULL`` takes no
value; ``=`` takes one; ``IN`` takes many.
raw_sql:
override any thinking we might wish to do, and just return this
raw SQL
from_table_for_raw_sql:
if we are using raw SQL, provide a :class:`TableId` for the
relevant table here
"""
self._column_id = column_id
self._op = op.upper()
self._datatype = datatype
self._value = value_or_values
self._no_value = False
self._multivalue = False
self._raw_sql = raw_sql
self._from_table_for_raw_sql = from_table_for_raw_sql
if not self._raw_sql:
if self._op in SQL_OPS_VALUE_UNNECESSARY:
self._no_value = True
assert value_or_values is None, "Superfluous value passed"
elif self._op in SQL_OPS_MULTIPLE_VALUES:
self._multivalue = True
assert isinstance(value_or_values, list), "Need list"
else:
assert not isinstance(
value_or_values, list
), "Need single value"
def init_kwargs(self) -> Dict:
return {
"column_id": self._column_id,
"op": self._op,
"datatype": self._datatype,
"value_or_values": self._value,
"raw_sql": self._raw_sql,
"from_table_for_raw_sql": self._from_table_for_raw_sql,
}
def __repr__(self) -> str:
return (
"{qualname}("
"column_id={column_id}, "
"op={op}, "
"datatype={datatype}, "
"value_or_values={value_or_values}, "
"raw_sql={raw_sql}, "
"from_table_for_raw_sql={from_table_for_raw_sql}"
")".format(
qualname=self.__class__.__qualname__,
column_id=repr(self._column_id),
op=repr(self._op),
datatype=repr(self._datatype),
value_or_values=repr(self._value),
raw_sql=repr(self._raw_sql),
from_table_for_raw_sql=repr(self._from_table_for_raw_sql),
)
)
def __eq__(self, other: "WhereCondition") -> bool:
return (
self._raw_sql == other._raw_sql
and self._column_id == other._column_id
and self._op == other._op
and self._value == other._value
)
def __lt__(self, other: "WhereCondition") -> bool:
return (self._raw_sql, self._column_id, self._op, self._value) < (
other._raw_sql,
other._column_id,
other._op,
other._value,
)
@property
def column_id(self) -> ColumnId:
"""
Returns the :class:`ColumnId` provided at creation.
"""
return self._column_id
@property
def table_id(self) -> TableId:
"""
Returns a :class:`TableId`:
- for raw SQL, our ``from_table_for_raw_sql`` attribute
- otherwise, the table ID extracted from our ``column_id`` attribute
"""
if self._raw_sql:
return self._from_table_for_raw_sql
return self.column_id.table_id
[docs] def table_str(self, grammar: SqlGrammar) -> str:
"""
Returns the table identifier in the specified SQL grammar.
Args:
grammar: :class:`cardinal_pythonlib.sql.sql_grammar.SqlGrammar`
"""
return self.table_id.identifier(grammar)
[docs] def sql(self, grammar: SqlGrammar) -> str:
"""
Returns the WHERE clause (without ``WHERE`` itself!) for our condition,
in the specified SQL grammar. Some examples might be:
- ``somecol = 3``
- ``othercol IN (6, 7, 8)``
- ``thirdcol IS NOT NULL``
- ``textcol LIKE '%paracetamol%'``
- ``MATCH (fulltextcol AGAINST 'paracetamol')`` (MySQL)
- ``CONTAINS(fulltextcol, 'paracetamol')`` (SQL Server)
Args:
grammar: :class:`cardinal_pythonlib.sql.sql_grammar.SqlGrammar`
"""
if self._raw_sql:
return self._raw_sql
col = self._column_id.identifier(grammar)
op = self._op
if self._no_value:
return f"{col} {op}"
if self._datatype in QB_STRING_TYPES:
element_converter = sql_string_literal
elif self._datatype == QB_DATATYPE_DATE:
element_converter = sql_date_literal
elif self._datatype == QB_DATATYPE_INTEGER:
element_converter = str
elif self._datatype == QB_DATATYPE_FLOAT:
element_converter = str
else:
# Safe default
element_converter = sql_string_literal
if self._multivalue:
literal = "({})".format(
", ".join(element_converter(v) for v in self._value)
)
else:
literal = element_converter(self._value)
if self._op == "MATCH": # MySQL
return f"MATCH ({col}) AGAINST ({literal})"
elif self._op == "CONTAINS": # SQL Server
return f"CONTAINS({col}, {literal})"
else:
return f"{col} {op} {literal}"
# =============================================================================
# SQL formatting
# =============================================================================
# =============================================================================
# Plain SQL types
# =============================================================================
[docs]def is_sql_column_type_textual(column_type: str, min_length: int = 1) -> bool:
"""
Does an SQL column type look textual?
Args:
column_type: SQL column type as a string, e.g. ``"VARCHAR(50)"``
min_length: what's the minimum string length we'll say "yes" to?
Returns:
is it a textual column (of the minimum length or more)?
Note:
- For SQL Server's NVARCHAR(MAX),
:meth:`crate_anon.crateweb.research.research_db_info._schema_query_microsoft`
returns "NVARCHAR(-1)"
"""
if not column_type:
return False
column_type = column_type.upper().split()[0]
if column_type in SQLTYPES_TEXT:
# A text type without a specific length
return True
try:
m = COLTYPE_WITH_ONE_INTEGER_REGEX.match(column_type)
basetype = m.group(1)
length = int(m.group(2))
except (AttributeError, ValueError):
return False
return (length >= min_length or length < 0) and basetype in SQLTYPES_TEXT
[docs]def coltype_length_if_text(column_type: str, dialect: str) -> Optional[int]:
"""
Find the length of an sql text column type.
Args:
column_type: SQL column type as a string, e.g. ``"VARCHAR(50)"``
dialect: the sql dialect the column type is from
Returns:
length of the column or ``None`` if it's not a text column.
"""
column_type = column_type.upper()
if column_type in SQLTYPES_TEXT:
# No length specified - get the default
try:
if dialect == SqlaDialectName.MYSQL:
return MYSQL_COLTYPE_TO_LEN[column_type]
elif dialect == SqlaDialectName.MSSQL:
return MSSQL_COLTYPE_TO_LEN[column_type]
else:
raise ValueError(
f"{dialect} is not a valid SQL dialect. Must "
f"be one of: {SqlaDialectName.MYSQL!r}, "
f"{SqlaDialectName.MSSQL!r}"
)
except KeyError:
log.error(
f"SQL dialect {dialect} has no data type " f"{column_type}"
)
raise
else:
# Length specified - get it from the column type
try:
m = COLTYPE_WITH_ONE_INTEGER_REGEX.match(column_type)
basetype = m.group(1)
length = m.group(2)
if length == "MAX" or length == "-1":
if basetype == "VARCHAR":
return MSSQL_COLTYPE_TO_LEN["VARCHAR_MAX"]
elif basetype == "NVARCHAR":
return MSSQL_COLTYPE_TO_LEN["NVARCHAR_MAX"]
else:
return None
except AttributeError:
# Not the correct type of column
return None
try:
length = int(length)
except ValueError:
# Not the correct type of column
return None
return length
[docs]def escape_quote_in_literal(s: str) -> str:
r"""
Escape ``'``. We could use ``''`` or ``\'``.
Let's use ``\.`` for consistency with percent escaping.
"""
return s.replace("'", r"\'")
[docs]def escape_percent_in_literal(sql: str) -> str:
r"""
Escapes ``%`` by converting it to ``\%``.
Use this for LIKE clauses.
- https://dev.mysql.com/doc/refman/5.7/en/string-literals.html
"""
return sql.replace("%", r"\%")
[docs]def escape_percent_for_python_dbapi(sql: str) -> str:
"""
Escapes ``%`` by converting it to ``%%``.
Use this for SQL within Python where ``%`` characters are used for argument
placeholders.
"""
return sql.replace("%", "%%")
[docs]def escape_sql_string_literal(s: str) -> str:
"""
Escapes SQL string literal fragments against quotes and parameter
substitution.
"""
return escape_percent_in_literal(escape_quote_in_literal(s))
[docs]def make_string_literal(s: str) -> str:
"""
Converts a Python string into an SQL single-quoted (and escaped) string
literal.
"""
return f"'{escape_sql_string_literal(s)}'"
[docs]def escape_sql_string_or_int_literal(s: Union[str, int]) -> str:
"""
Converts an integer or a string into an SQL literal (with single quotes and
escaping in the case of a string).
"""
if isinstance(s, int):
return str(s)
else:
return make_string_literal(s)
[docs]def translate_sql_qmark_to_percent(sql: str) -> str:
"""
This function translates SQL using ``?`` placeholders to SQL using ``%s``
placeholders, without breaking literal ``'?'`` or ``'%'``, e.g. inside
string literals.
*Notes*
- MySQL likes ``?`` as a placeholder.
- https://dev.mysql.com/doc/refman/5.7/en/sql-syntax-prepared-statements.html
- Python DBAPI allows several: ``%s``, ``?``, ``:1``, ``:name``,
``%(name)s``.
- https://www.python.org/dev/peps/pep-0249/#paramstyle
- Django uses ``%s``.
- https://docs.djangoproject.com/en/1.8/topics/db/sql/
- Microsoft like ``?``, ``@paramname``, and ``:paramname``.
- https://msdn.microsoft.com/en-us/library/yy6y35y8(v=vs.110).aspx
- We need to parse SQL with argument placeholders.
- See :class:`cardinal_pythonlib.sql.sql_grammar.SqlGrammar` classes,
particularly: ``bind_parameter``
I prefer ``?``, because ``%`` is used in LIKE clauses, and the databases
we're using like it.
So:
- We use ``%s`` when using ``cursor.execute()`` directly, via Django.
- We use ``?`` when talking to users, and
:class:`cardinal_pythonlib.sql.sql_grammar.SqlGrammar` objects, so that
the visual appearance matches what they expect from their database.
""" # noqa
# 1. Escape % characters
sql = escape_percent_for_python_dbapi(sql)
# 2. Replace ? characters that are not within quotes with %s.
newsql = ""
in_quotes = False
for c in sql:
if c == "'":
in_quotes = not in_quotes
if c == "?" and not in_quotes:
newsql += "%s"
else:
newsql += c
return newsql