"""
crate_anon/crateweb/research/tests/sql_writer_tests.py
===============================================================================
Copyright (C) 2015, University of Cambridge, Department of Psychiatry.
Created by Rudolf Cardinal (rnc1001@cam.ac.uk).
This file is part of CRATE.
CRATE is free software: you can redistribute it and/or modify
it under the terms of the GNU General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.
CRATE is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU General Public License for more details.
You should have received a copy of the GNU General Public License
along with CRATE. If not, see <https://www.gnu.org/licenses/>.
===============================================================================
Unit tests for sql_writer.py.
"""
import logging
from unittest import TestCase
from cardinal_pythonlib.sql.sql_grammar_factory import make_grammar
from cardinal_pythonlib.sqlalchemy.dialect import SqlaDialectName
from crate_anon.common.sql import ColumnId, WhereCondition
from crate_anon.crateweb.research.errors import DatabaseStructureNotUnderstood
from crate_anon.crateweb.research.sql_writer import (
add_to_select,
SelectElement,
)
# Logger for this test module, named after the module itself.
log = logging.getLogger(name=__name__)
class AddToSelectTests(TestCase):
    """
    Unit tests for :func:`add_to_select`, which adds SELECT columns and/or
    WHERE conditions to an existing (possibly empty) SQL query string.
    """

    def setUp(self) -> None:
        super().setUp()
        # Every test parses and regenerates SQL with the MySQL grammar.
        self.grammar = make_grammar(SqlaDialectName.MYSQL)

    @staticmethod
    def _normalize_whitespace(sql: str) -> str:
        """
        Return ``sql`` with each run of whitespace collapsed to a single
        space, leading/trailing whitespace removed, and any space directly
        before a comma dropped.
        """
        # Collapse first: removing " ," before collapsing would miss
        # whitespace runs before a comma that are not exactly one space
        # (e.g. a newline, or several spaces).
        collapsed = " ".join(sql.split())
        return collapsed.replace(" ,", ",")

    def assert_query_equal(self, actual: str, expected: str) -> None:
        """
        Assert that two SQL strings are equal, ignoring whitespace
        differences (including any whitespace before commas).

        Args:
            actual: SQL produced by the code under test.
            expected: the SQL we expect.

        Raises:
            AssertionError: if the normalized strings differ.
        """
        self.assertEqual(
            self._normalize_whitespace(actual),
            self._normalize_whitespace(expected),
        )

    def test_second_table_joined(self) -> None:
        """
        Selecting a column from a table not in the FROM clause (with
        ``magic_join=False``) adds that table via a NATURAL JOIN.
        """
        sql = add_to_select(
            "SELECT t1.a, t1.b FROM t1 WHERE t1.col1 > 5",
            grammar=self.grammar,
            select_elements=[
                SelectElement(column_id=ColumnId(table="t2", column="c"))
            ],
            # magic_join requires database knowledge (hence Django), which
            # these plain unit tests do not set up.
            magic_join=False,
        )
        self.assert_query_equal(
            sql,
            (
                "SELECT t1.a, t1.b, t2.c FROM t1 NATURAL JOIN t2 "
                "WHERE t1.col1 > 5"
            ),
        )

    def test_another_column_added(self) -> None:
        """
        A column from a table already in the query is appended to the
        SELECT list, even if that column is already selected.
        """
        sql = add_to_select(
            "SELECT t1.a, t1.b FROM t1 WHERE t1.col1 > 5",
            grammar=self.grammar,
            select_elements=[
                SelectElement(column_id=ColumnId(table="t1", column="a"))
            ],
        )
        self.assert_query_equal(
            sql, "SELECT t1.a, t1.b, t1.a FROM t1 WHERE t1.col1 > 5"
        )

    def test_select_element_added_to_nothing(self) -> None:
        """
        Starting from an empty query, a SELECT ... FROM is built from the
        new column and its table.
        """
        sql = add_to_select(
            "",
            grammar=self.grammar,
            select_elements=[
                SelectElement(column_id=ColumnId(table="t2", column="c"))
            ],
        )
        self.assert_query_equal(sql, "SELECT t2.c FROM t2")

    def test_first_where_condition_added(self) -> None:
        """A WHERE clause is created when the query has none."""
        sql = add_to_select(
            "SELECT t1.a, t1.b FROM t1",
            grammar=self.grammar,
            where_conditions=[WhereCondition(raw_sql="t1.col1 > 5")],
        )
        self.assert_query_equal(
            sql, "SELECT t1.a, t1.b FROM t1 WHERE t1.col1 > 5"
        )

    def test_second_where_condition_added(self) -> None:
        """A new condition is ANDed onto an existing WHERE clause."""
        sql = add_to_select(
            "SELECT t1.a, t1.b FROM t1 WHERE t1.col1 > 5",
            grammar=self.grammar,
            where_conditions=[WhereCondition(raw_sql="t1.col2 < 3")],
        )
        self.assert_query_equal(
            sql, "SELECT t1.a, t1.b FROM t1 WHERE t1.col1 > 5 AND t1.col2 < 3"
        )

    def test_third_where_condition_added(self) -> None:
        """
        A new condition is ANDed after existing (already ANDed) WHERE
        conditions, preserving their order.
        """
        sql = add_to_select(
            "SELECT t1.a, t1.b FROM t1 WHERE t1.col1 > 5 AND t3.col99 = 100",
            grammar=self.grammar,
            where_conditions=[WhereCondition(raw_sql="t1.col2 < 3")],
        )
        self.assert_query_equal(
            sql,
            (
                "SELECT t1.a, t1.b FROM t1 "
                "WHERE t1.col1 > 5 AND t3.col99 = 100 AND t1.col2 < 3"
            ),
        )

    def test_multiple_wheres_added_to_none(self) -> None:
        """
        Several conditions added to a query with no WHERE clause are joined
        with AND, in the order given.
        """
        sql = add_to_select(
            "SELECT t1.a, t1.b FROM t1",
            grammar=self.grammar,
            where_conditions=[
                WhereCondition(raw_sql="t1.col1 > 99"),
                WhereCondition(raw_sql="t1.col2 < 999"),
            ],
        )
        self.assert_query_equal(
            sql,
            "SELECT t1.a, t1.b FROM t1 WHERE t1.col1 > 99 AND t1.col2 < 999",
        )

    def test_raises_when_table_does_not_exist(self) -> None:
        """
        Adding a column from a schema-qualified table that is not in the
        query raises DatabaseStructureNotUnderstood.

        NOTE(review): presumably this arises because, without magic_join
        disabled, add_to_select needs database metadata to work out the
        join; confirm the exact trigger in sql_writer.
        """
        column_id = ColumnId(
            schema="research", table="blobdoc", column="_src_hash"
        )
        with self.assertRaises(DatabaseStructureNotUnderstood):
            add_to_select(
                "SELECT foo from bar",
                grammar=self.grammar,
                select_elements=[SelectElement(column_id=column_id)],
            )