diff --git a/CHANGELOG.md b/CHANGELOG.md index 525db12..4dc06d5 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -2,6 +2,13 @@ All notable changes to this project will be documented in this file. +## [Unreleased] + +### Added + +- Added tasks `begin_sql_transaction` and `end_sql_transaction` to the `utils`module. These enable the management of SQL + transactions in flows. It also allows for dry running SQL statements. + ## [0.4.0] - 2023-02-08 ### Added diff --git a/README.md b/README.md index d9cbbf1..e3abcf9 100644 --- a/README.md +++ b/README.md @@ -120,6 +120,31 @@ with Flow(...) as flow: close_ssh_tunnel.run(tunnel=tunnel, upstream_tasks=[mysql_closed]) ``` +**Use SQL transactions and dry running** + +```python +from lolafect.connections import connect_to_mysql, close_mysql_connection +from lolafect.utils import begin_sql_transaction, end_sql_transaction + +with Flow(...) as flow: + connection = connect_to_mysql( + mysql_credentials={...}, # You probably want to get this from TEST_LOLACONFIG.DW_CREDENTIALS + ) + transaction_started = begin_sql_transaction(connection) + task_result = some_task_that_needs_mysql( + connection=connection, + upstream_task=[transaction_started] + ) + transaction_finished = end_sql_transaction( + connection, + dry_run=False, # True means rollback, False means commit changes + upstream_tasks=[task_result] + ) + + close_mysql_connection(connection=connection, upstream_tasks=[transaction_finished]) + +``` + ### Use Great Expectations **Run a Great Expectations validation on a MySQL query** diff --git a/lolafect/utils.py b/lolafect/utils.py index 621e59d..2b7e460 100644 --- a/lolafect/utils.py +++ b/lolafect/utils.py @@ -1,5 +1,8 @@ import json +from typing import Any +import prefect +from prefect import task class S3FileReader: """ @@ -22,3 +25,41 @@ class S3FileReader: .read() .decode("utf-8") ) + +@task() +def begin_sql_transaction(connection: Any) -> None: + """ + Start a SQL transaction in the passed connection. The task is agnostic to + the SQL engine being used. As long as the connection object implements a + begin() method, this task will work. + + :param connection: the connection to some database. + :return: None + """ + logger = prefect.context.get("logger") + logger.info(f"Starting SQL transaction with connection: {connection}.") + connection.begin() + + +@task() +def end_sql_transaction(connection: Any, dry_run: bool = False) -> None: + """ + Finish a SQL transaction, either by rolling it back or by committing it. + The task is agnostic to the SQL engine being used. As long as the + connection object implements a `commit` and a `rollback` method, this task + will work. + + :param connection: the connection to some database. + :param dry_run: a flag indicating if persistence is desired. If dry_run + is True, changes will be rolledback. Otherwise, they will be committed. + :return: None + """ + logger = prefect.context.get("logger") + logger.info(f"Using connection: {connection}.") + + if dry_run: + connection.rollback() + logger.info("Dry-run mode activated. Rolling back the transaction.") + else: + logger.info("Committing the transaction.") + connection.commit() diff --git a/tests/test_integration/test_utils.py b/tests/test_integration/test_utils.py new file mode 100644 index 0000000..1ce5f85 --- /dev/null +++ b/tests/test_integration/test_utils.py @@ -0,0 +1,123 @@ +import pytest + +from lolafect.lolaconfig import build_lolaconfig +from lolafect.connections import ( + open_ssh_tunnel_with_s3_pkey, + get_local_bind_address_from_ssh_tunnel, + close_ssh_tunnel, + connect_to_mysql, + close_mysql_connection, +) +from lolafect.utils import begin_sql_transaction, end_sql_transaction + +# __ __ _____ _ _ _____ _ _ _____ _ +# \ \ / /\ | __ \| \ | |_ _| \ | |/ ____| | +# \ \ /\ / / \ | |__) | \| | | | | \| | | __| | +# \ \/ \/ / /\ \ | _ /| . ` | | | | . ` | | |_ | | +# \ /\ / ____ \| | \ \| |\ |_| |_| |\ | |__| |_| +# \/ \/_/ \_\_| \_\_| \_|_____|_| \_|\_____(_) +# This testing suite requires: +# - The calling shell to have permission in AWS +# - The calling shell to be within the Mercadão network +# - Do not use this tests as part of CI/CD pipelines since they are not idempotent and +# rely external resources. Instead, use them manually to check yourself that things +# are working properly. + + +TEST_LOLACONFIG = build_lolaconfig(flow_name="testing-suite") + + +@pytest.fixture +def connection_with_test_table(): + """ + Connects to DW, creates a test table in the sandbox env, and yields the + connection to the test. + + After the test, the table is dropped and the connection is closed. + """ + test_local_bind_host = "127.0.0.1" + test_local_bind_port = 12345 + + tunnel = open_ssh_tunnel_with_s3_pkey.run( + s3_bucket_name=TEST_LOLACONFIG.S3_BUCKET_NAME, + ssh_tunnel_credentials=TEST_LOLACONFIG.SSH_TUNNEL_CREDENTIALS, + remote_target_host=TEST_LOLACONFIG.DW_CREDENTIALS["host"], + remote_target_port=TEST_LOLACONFIG.DW_CREDENTIALS["port"], + local_bind_host=test_local_bind_host, + local_bind_port=test_local_bind_port, + ) + + connection = connect_to_mysql.run( + mysql_credentials=TEST_LOLACONFIG.DW_CREDENTIALS, + overriding_host_and_port=get_local_bind_address_from_ssh_tunnel.run( + tunnel=tunnel + ), + ) + cursor = connection.cursor() + cursor.execute(""" + CREATE TABLE sandbox.lolafect_transaction_test_table + ( + a_test_column INT + ) + """) + + # Connection and table ready for tests + yield connection # Test happens now + # Test finished, time to remove stuff and close connection + + cursor.execute(""" + DROP TABLE sandbox.lolafect_transaction_test_table + """ + ) + close_mysql_connection.run(connection=connection) + close_ssh_tunnel.run(tunnel=tunnel) + + +def test_sql_transaction_persists_changes_properly(connection_with_test_table): + cursor = connection_with_test_table.cursor() + + cursor.execute(""" + SELECT a_test_column + FROM sandbox.lolafect_transaction_test_table + """) + table_is_empty_at_first = not bool(cursor.fetchall()) # An empty tuple yields False + + begin_sql_transaction.run(connection=connection_with_test_table) + cursor.execute(""" + INSERT INTO sandbox.lolafect_transaction_test_table (a_test_column) + VALUES (1) + """) + end_sql_transaction.run(connection=connection_with_test_table, dry_run=False) + + cursor.execute(""" + SELECT a_test_column + FROM sandbox.lolafect_transaction_test_table + """) + table_has_a_record_after_commit = bool(cursor.fetchall()) # A non-empty tuple yields True + + assert table_is_empty_at_first and table_has_a_record_after_commit + + +def test_sql_transaction_rollbacks_changes_properly(connection_with_test_table): + cursor = connection_with_test_table.cursor() + + cursor.execute(""" + SELECT a_test_column + FROM sandbox.lolafect_transaction_test_table + """) + table_is_empty_at_first = not bool(cursor.fetchall()) # An empty tuple yields False + + begin_sql_transaction.run(connection=connection_with_test_table) + cursor.execute(""" + INSERT INTO sandbox.lolafect_transaction_test_table (a_test_column) + VALUES (1) + """) + end_sql_transaction.run(connection=connection_with_test_table, dry_run=True) + + cursor.execute(""" + SELECT a_test_column + FROM sandbox.lolafect_transaction_test_table + """) + table_is_still_empty_after_rollback = not bool(cursor.fetchall()) # A tuple yields False + + assert table_is_empty_at_first and table_is_still_empty_after_rollback