diff --git a/CHANGELOG.md b/CHANGELOG.md index 525db12..30b1c5a 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -2,6 +2,13 @@ All notable changes to this project will be documented in this file. +## [0.5.0] - 2023-04-26 + +### Added + +- Added tasks `begin_sql_transaction` and `end_sql_transaction` to the `utils`module. These enable the management of SQL + transactions in flows. It also allows for dry running SQL statements. + ## [0.4.0] - 2023-02-08 ### Added diff --git a/README.md b/README.md index 134b28d..e3abcf9 100644 --- a/README.md +++ b/README.md @@ -120,6 +120,31 @@ with Flow(...) as flow: close_ssh_tunnel.run(tunnel=tunnel, upstream_tasks=[mysql_closed]) ``` +**Use SQL transactions and dry running** + +```python +from lolafect.connections import connect_to_mysql, close_mysql_connection +from lolafect.utils import begin_sql_transaction, end_sql_transaction + +with Flow(...) as flow: + connection = connect_to_mysql( + mysql_credentials={...}, # You probably want to get this from TEST_LOLACONFIG.DW_CREDENTIALS + ) + transaction_started = begin_sql_transaction(connection) + task_result = some_task_that_needs_mysql( + connection=connection, + upstream_task=[transaction_started] + ) + transaction_finished = end_sql_transaction( + connection, + dry_run=False, # True means rollback, False means commit changes + upstream_tasks=[task_result] + ) + + close_mysql_connection(connection=connection, upstream_tasks=[transaction_finished]) + +``` + ### Use Great Expectations **Run a Great Expectations validation on a MySQL query** @@ -143,28 +168,6 @@ with Flow(...) as flow: print("The data is bad!!!") ``` -**Run a Great Expectations validation on a Trino query** - -```python -from lolafect.data_testing import run_data_test_on_trino - -with Flow(...) as flow: - - my_query = """SELECT something FROM somewhere""" - my_expectations = {...} # A bunch of things you want to validate on the result of the query - - validation_results = run_data_test_on_trino( - name="my-cool-validation", - trino_credentials={...}, - query=my_query, - expectations=my_expectations - ) - - if not validation_results["success"]: - print("The data is bad!!!") -``` - - ### Slack **Send a warning message to slack if your tasks fails** diff --git a/lolafect/__version__.py b/lolafect/__version__.py index 6a9beea..3d18726 100644 --- a/lolafect/__version__.py +++ b/lolafect/__version__.py @@ -1 +1 @@ -__version__ = "0.4.0" +__version__ = "0.5.0" diff --git a/lolafect/data_testing.py b/lolafect/data_testing.py index 9e3e777..4216d4b 100644 --- a/lolafect/data_testing.py +++ b/lolafect/data_testing.py @@ -60,51 +60,6 @@ def run_data_test_on_mysql( return results -@task() -def run_data_test_on_trino( - name: str, - trino_credentials: dict, - query: str, - expectation_configurations: List[ExpectationConfiguration], - great_expectations_s3_bucket: str = DEFAULT_GREAT_EXPECTATIONS_S3_BUCKET, -) -> dict: - """ - Validate a query and an expectation suite against a given Trino server. - - :param name: a unique name for the data test. - :param trino_credentials: credentials for the Trino cluster. - :param query: the query to test against. - :param expectation_configurations: the expectations on the dataset. - :param great_expectations_s3_bucket: the bucket where Great Expectations - files live. - :return: the result of the data test. - """ - logger = prefect.context.get("logger") - - logger.info("Creating data context.") - data_context = _create_in_memory_data_context_for_trino( - trino_credentials, great_expectations_s3_bucket - ) - logger.info("Data context created.") - logger.info("Creating expectation suite.") - data_context = _create_expectation_suite( - data_context, name, expectation_configurations - ) - logger.info("Expectation suite created.") - logger.info("Creating checkpoint.") - data_context = _create_checkpoint( - data_context, - f"{trino_credentials['host']}:{trino_credentials['port']}", - query, - name, - ) - logger.info("Checkpoint created.") - logger.info("Running checkpoint.") - results = data_context.run_checkpoint(f"{name}_checkpoint") - logger.info("Checkpoint finished.") - logger.info(f"Validation result: {results['success']}") - - return results def _create_in_memory_data_context_for_mysql( mysql_credentials: dict, @@ -156,58 +111,6 @@ def _create_in_memory_data_context_for_mysql( return data_context -def _create_in_memory_data_context_for_trino( - trino_credentials: dict, - great_expectations_s3_bucket: str, -) -> AbstractDataContext: - """ - Create a DataContext without a YAML config file and specify a Trino - datasource. - - :param trino_credentials: the creds to the mysql where the query will be - executed. - :param great_expectations_s3_bucket: the name of the bucket where Great - Exepctations files while be stored. - :return: the data context. - """ - - data_context = BaseDataContext( - project_config=DataContextConfig( - datasources={ - f"{trino_credentials['host']}:{trino_credentials['port']}": DatasourceConfig( - class_name="Datasource", - execution_engine={ - "class_name": "SqlAlchemyExecutionEngine", - "connection_string": f"trino://%s:%s@%s:%s/%s/%s" - % ( - trino_credentials["user"], - urlquote(trino_credentials["password"]), - trino_credentials["host"], - trino_credentials["port"], - trino_credentials["catalog"], - trino_credentials["schema"], - ), - }, - data_connectors={ - "default_runtime_data_connector_name": { - "class_name": "RuntimeDataConnector", - "batch_identifiers": ["default_identifier_name"], - }, - "default_inferred_data_connector_name": { - "class_name": "InferredAssetSqlDataConnector", - "name": "whole_table", - }, - }, - ) - }, - store_backend_defaults=S3StoreBackendDefaults( - default_bucket_name=great_expectations_s3_bucket - ), - ) - ) - - return data_context - def _create_expectation_suite( data_context: AbstractDataContext, diff --git a/lolafect/utils.py b/lolafect/utils.py index 621e59d..2b7e460 100644 --- a/lolafect/utils.py +++ b/lolafect/utils.py @@ -1,5 +1,8 @@ import json +from typing import Any +import prefect +from prefect import task class S3FileReader: """ @@ -22,3 +25,41 @@ class S3FileReader: .read() .decode("utf-8") ) + +@task() +def begin_sql_transaction(connection: Any) -> None: + """ + Start a SQL transaction in the passed connection. The task is agnostic to + the SQL engine being used. As long as the connection object implements a + begin() method, this task will work. + + :param connection: the connection to some database. + :return: None + """ + logger = prefect.context.get("logger") + logger.info(f"Starting SQL transaction with connection: {connection}.") + connection.begin() + + +@task() +def end_sql_transaction(connection: Any, dry_run: bool = False) -> None: + """ + Finish a SQL transaction, either by rolling it back or by committing it. + The task is agnostic to the SQL engine being used. As long as the + connection object implements a `commit` and a `rollback` method, this task + will work. + + :param connection: the connection to some database. + :param dry_run: a flag indicating if persistence is desired. If dry_run + is True, changes will be rolledback. Otherwise, they will be committed. + :return: None + """ + logger = prefect.context.get("logger") + logger.info(f"Using connection: {connection}.") + + if dry_run: + connection.rollback() + logger.info("Dry-run mode activated. Rolling back the transaction.") + else: + logger.info("Committing the transaction.") + connection.commit() diff --git a/tests/test_integration/test_data_testing.py b/tests/test_integration/test_data_testing.py index a6c564a..8b456f2 100644 --- a/tests/test_integration/test_data_testing.py +++ b/tests/test_integration/test_data_testing.py @@ -1,7 +1,7 @@ from great_expectations.core.expectation_configuration import ExpectationConfiguration from lolafect.lolaconfig import build_lolaconfig -from lolafect.data_testing import run_data_test_on_mysql, run_data_test_on_trino +from lolafect.data_testing import run_data_test_on_mysql from lolafect.connections import open_ssh_tunnel_with_s3_pkey, close_ssh_tunnel # __ __ _____ _ _ _____ _ _ _____ _ @@ -18,139 +18,100 @@ from lolafect.connections import open_ssh_tunnel_with_s3_pkey, close_ssh_tunnel # are working properly. TEST_LOLACONFIG = build_lolaconfig(flow_name="testing-suite") -#1 AS a_one, -# 'lol' AS a_string, -# NULL AS a_null -TEST_QUERY = """ -SELECT * -from app_lm_mysql_pl.comprea.market -where id = 1 -""" -TEST_EXPECTATIONS_THAT_FIT_DATA = [ - ExpectationConfiguration( - expectation_type="expect_column_values_to_be_between", - kwargs={"column": "a_one", "min_value": 1, "max_value": 1}, - ), - ExpectationConfiguration( - expectation_type="expect_column_values_to_match_like_pattern", - kwargs={"column": "a_string", "like_pattern": "%lol%"}, - ), - ExpectationConfiguration( - expectation_type="expect_column_values_to_be_null", - kwargs={"column": "a_null"}, - ), -] -TEST_EXPECTATIONS_THAT_DONT_FIT_DATA = [ - ExpectationConfiguration( - expectation_type="expect_column_values_to_be_between", - kwargs={"column": "a_one", "min_value": 2, "max_value": 2}, - ), - ExpectationConfiguration( - expectation_type="expect_column_values_to_match_like_pattern", - kwargs={"column": "a_string", "like_pattern": "%xD%"}, - ), - ExpectationConfiguration( - expectation_type="expect_column_values_to_not_be_null", - kwargs={"column": "a_null"}, - ), -] +def test_validation_on_mysql_succeeds(): -# -# def test_validation_on_mysql_succeeds(): -# ssh_tunnel = open_ssh_tunnel_with_s3_pkey.run( -# s3_bucket_name=TEST_LOLACONFIG.S3_BUCKET_NAME, -# ssh_tunnel_credentials=TEST_LOLACONFIG.SSH_TUNNEL_CREDENTIALS, -# remote_target_host=TEST_LOLACONFIG.DW_CREDENTIALS["host"], -# remote_target_port=TEST_LOLACONFIG.DW_CREDENTIALS["port"], -# ) -# -# validation_result = run_data_test_on_mysql.run( -# name="lolafect-testing-test_validation_on_mysql_succeeds", -# mysql_credentials={ -# "host": ssh_tunnel.local_bind_address[0], -# "port": ssh_tunnel.local_bind_address[1], -# "user": TEST_LOLACONFIG.DW_CREDENTIALS["user"], -# "password": TEST_LOLACONFIG.DW_CREDENTIALS["password"], -# "db": TEST_LOLACONFIG.DW_CREDENTIALS["default_db"], -# }, -# query=TEST_QUERY, -# expectation_configurations=TEST_EXPECTATIONS_THAT_FIT_DATA, -# ) -# -# closed_tunnel = close_ssh_tunnel.run(ssh_tunnel) -# -# data_test_passed = validation_result["success"] == True -# -# assert data_test_passed -# -# -# def test_validation_on_mysql_fails(): -# ssh_tunnel = open_ssh_tunnel_with_s3_pkey.run( -# s3_bucket_name=TEST_LOLACONFIG.S3_BUCKET_NAME, -# ssh_tunnel_credentials=TEST_LOLACONFIG.SSH_TUNNEL_CREDENTIALS, -# remote_target_host=TEST_LOLACONFIG.DW_CREDENTIALS["host"], -# remote_target_port=TEST_LOLACONFIG.DW_CREDENTIALS["port"], -# ) -# -# validation_result = run_data_test_on_mysql.run( -# name="lolafect-testing-test_validation_on_mysql_fails", -# mysql_credentials={ -# "host": ssh_tunnel.local_bind_address[0], -# "port": ssh_tunnel.local_bind_address[1], -# "user": TEST_LOLACONFIG.DW_CREDENTIALS["user"], -# "password": TEST_LOLACONFIG.DW_CREDENTIALS["password"], -# "db": TEST_LOLACONFIG.DW_CREDENTIALS["default_db"], -# }, -# query=TEST_QUERY, -# expectation_configurations=TEST_EXPECTATIONS_THAT_DONT_FIT_DATA, -# ) -# -# closed_tunnel = close_ssh_tunnel.run(ssh_tunnel) -# -# data_test_failed = validation_result["success"] == False -# -# assert data_test_failed -# + test_query = """ + SELECT 1 AS a_one, + "lol" AS a_string, + NULL AS a_null + """ + test_expectations = [ + ExpectationConfiguration( + expectation_type="expect_column_values_to_be_between", + kwargs={"column": "a_one", "min_value": 1, "max_value": 1}, + ), + ExpectationConfiguration( + expectation_type="expect_column_values_to_match_like_pattern", + kwargs={"column": "a_string", "like_pattern": "%lol%"}, + ), + ExpectationConfiguration( + expectation_type="expect_column_values_to_be_null", + kwargs={"column": "a_null"}, + ), + ] -def test_validation_on_trino_succeeds(): - validation_result = run_data_test_on_trino.run( - name="lolafect-testing-test_validation_on_mysql_fails", - trino_credentials={ - "host": TEST_LOLACONFIG.TRINO_CREDENTIALS["host"], - "port": TEST_LOLACONFIG.TRINO_CREDENTIALS["port"], - "user": TEST_LOLACONFIG.TRINO_CREDENTIALS["user"], - "password": TEST_LOLACONFIG.TRINO_CREDENTIALS["password"], - "catalog": "data_dw", - "schema": "sandbox" - }, - query=TEST_QUERY, - expectation_configurations=TEST_EXPECTATIONS_THAT_FIT_DATA, + ssh_tunnel = open_ssh_tunnel_with_s3_pkey.run( + s3_bucket_name=TEST_LOLACONFIG.S3_BUCKET_NAME, + ssh_tunnel_credentials=TEST_LOLACONFIG.SSH_TUNNEL_CREDENTIALS, + remote_target_host=TEST_LOLACONFIG.DW_CREDENTIALS["host"], + remote_target_port=TEST_LOLACONFIG.DW_CREDENTIALS["port"], ) - print("###############\n" * 20) - print(validation_result) + + validation_result = run_data_test_on_mysql.run( + name="lolafect-testing-test_validation_on_mysql_succeeds", + mysql_credentials={ + "host": ssh_tunnel.local_bind_address[0], + "port": ssh_tunnel.local_bind_address[1], + "user": TEST_LOLACONFIG.DW_CREDENTIALS["user"], + "password": TEST_LOLACONFIG.DW_CREDENTIALS["password"], + "db": TEST_LOLACONFIG.DW_CREDENTIALS["default_db"], + }, + query=test_query, + expectation_configurations=test_expectations, + ) + + closed_tunnel = close_ssh_tunnel.run(ssh_tunnel) data_test_passed = validation_result["success"] == True assert data_test_passed -def test_validation_on_trino_fails(): - validation_result = run_data_test_on_trino.run( - name="lolafect-testing-test_validation_on_mysql_fails", - trino_credentials={ - "host": TEST_LOLACONFIG.TRINO_CREDENTIALS["host"], - "port": TEST_LOLACONFIG.TRINO_CREDENTIALS["port"], - "user": TEST_LOLACONFIG.TRINO_CREDENTIALS["user"], - "password": TEST_LOLACONFIG.TRINO_CREDENTIALS["password"], - "catalog": "data_dw", - "schema": "sandbox" - }, - query=TEST_QUERY, - expectation_configurations=TEST_EXPECTATIONS_THAT_DONT_FIT_DATA, +def test_validation_on_mysql_fails(): + test_query = """ + SELECT 1 AS a_one, + "lol" AS a_string, + NULL AS a_null + """ + test_expectations = [ + ExpectationConfiguration( + expectation_type="expect_column_values_to_be_between", + kwargs={"column": "a_one", "min_value": 2, "max_value": 2}, + ), + ExpectationConfiguration( + expectation_type="expect_column_values_to_match_like_pattern", + kwargs={"column": "a_string", "like_pattern": "%xD%"}, + ), + ExpectationConfiguration( + expectation_type="expect_column_values_to_not_be_null", + kwargs={"column": "a_null"}, + ), + ] + + ssh_tunnel = open_ssh_tunnel_with_s3_pkey.run( + s3_bucket_name=TEST_LOLACONFIG.S3_BUCKET_NAME, + ssh_tunnel_credentials=TEST_LOLACONFIG.SSH_TUNNEL_CREDENTIALS, + remote_target_host=TEST_LOLACONFIG.DW_CREDENTIALS["host"], + remote_target_port=TEST_LOLACONFIG.DW_CREDENTIALS["port"], ) + validation_result = run_data_test_on_mysql.run( + name="lolafect-testing-test_validation_on_mysql_fails", + mysql_credentials={ + "host": ssh_tunnel.local_bind_address[0], + "port": ssh_tunnel.local_bind_address[1], + "user": TEST_LOLACONFIG.DW_CREDENTIALS["user"], + "password": TEST_LOLACONFIG.DW_CREDENTIALS["password"], + "db": TEST_LOLACONFIG.DW_CREDENTIALS["default_db"], + }, + query=test_query, + expectation_configurations=test_expectations, + ) + + closed_tunnel = close_ssh_tunnel.run(ssh_tunnel) + data_test_failed = validation_result["success"] == False assert data_test_failed diff --git a/tests/test_integration/test_utils.py b/tests/test_integration/test_utils.py new file mode 100644 index 0000000..1ce5f85 --- /dev/null +++ b/tests/test_integration/test_utils.py @@ -0,0 +1,123 @@ +import pytest + +from lolafect.lolaconfig import build_lolaconfig +from lolafect.connections import ( + open_ssh_tunnel_with_s3_pkey, + get_local_bind_address_from_ssh_tunnel, + close_ssh_tunnel, + connect_to_mysql, + close_mysql_connection, +) +from lolafect.utils import begin_sql_transaction, end_sql_transaction + +# __ __ _____ _ _ _____ _ _ _____ _ +# \ \ / /\ | __ \| \ | |_ _| \ | |/ ____| | +# \ \ /\ / / \ | |__) | \| | | | | \| | | __| | +# \ \/ \/ / /\ \ | _ /| . ` | | | | . ` | | |_ | | +# \ /\ / ____ \| | \ \| |\ |_| |_| |\ | |__| |_| +# \/ \/_/ \_\_| \_\_| \_|_____|_| \_|\_____(_) +# This testing suite requires: +# - The calling shell to have permission in AWS +# - The calling shell to be within the Mercadão network +# - Do not use this tests as part of CI/CD pipelines since they are not idempotent and +# rely external resources. Instead, use them manually to check yourself that things +# are working properly. + + +TEST_LOLACONFIG = build_lolaconfig(flow_name="testing-suite") + + +@pytest.fixture +def connection_with_test_table(): + """ + Connects to DW, creates a test table in the sandbox env, and yields the + connection to the test. + + After the test, the table is dropped and the connection is closed. + """ + test_local_bind_host = "127.0.0.1" + test_local_bind_port = 12345 + + tunnel = open_ssh_tunnel_with_s3_pkey.run( + s3_bucket_name=TEST_LOLACONFIG.S3_BUCKET_NAME, + ssh_tunnel_credentials=TEST_LOLACONFIG.SSH_TUNNEL_CREDENTIALS, + remote_target_host=TEST_LOLACONFIG.DW_CREDENTIALS["host"], + remote_target_port=TEST_LOLACONFIG.DW_CREDENTIALS["port"], + local_bind_host=test_local_bind_host, + local_bind_port=test_local_bind_port, + ) + + connection = connect_to_mysql.run( + mysql_credentials=TEST_LOLACONFIG.DW_CREDENTIALS, + overriding_host_and_port=get_local_bind_address_from_ssh_tunnel.run( + tunnel=tunnel + ), + ) + cursor = connection.cursor() + cursor.execute(""" + CREATE TABLE sandbox.lolafect_transaction_test_table + ( + a_test_column INT + ) + """) + + # Connection and table ready for tests + yield connection # Test happens now + # Test finished, time to remove stuff and close connection + + cursor.execute(""" + DROP TABLE sandbox.lolafect_transaction_test_table + """ + ) + close_mysql_connection.run(connection=connection) + close_ssh_tunnel.run(tunnel=tunnel) + + +def test_sql_transaction_persists_changes_properly(connection_with_test_table): + cursor = connection_with_test_table.cursor() + + cursor.execute(""" + SELECT a_test_column + FROM sandbox.lolafect_transaction_test_table + """) + table_is_empty_at_first = not bool(cursor.fetchall()) # An empty tuple yields False + + begin_sql_transaction.run(connection=connection_with_test_table) + cursor.execute(""" + INSERT INTO sandbox.lolafect_transaction_test_table (a_test_column) + VALUES (1) + """) + end_sql_transaction.run(connection=connection_with_test_table, dry_run=False) + + cursor.execute(""" + SELECT a_test_column + FROM sandbox.lolafect_transaction_test_table + """) + table_has_a_record_after_commit = bool(cursor.fetchall()) # A non-empty tuple yields True + + assert table_is_empty_at_first and table_has_a_record_after_commit + + +def test_sql_transaction_rollbacks_changes_properly(connection_with_test_table): + cursor = connection_with_test_table.cursor() + + cursor.execute(""" + SELECT a_test_column + FROM sandbox.lolafect_transaction_test_table + """) + table_is_empty_at_first = not bool(cursor.fetchall()) # An empty tuple yields False + + begin_sql_transaction.run(connection=connection_with_test_table) + cursor.execute(""" + INSERT INTO sandbox.lolafect_transaction_test_table (a_test_column) + VALUES (1) + """) + end_sql_transaction.run(connection=connection_with_test_table, dry_run=True) + + cursor.execute(""" + SELECT a_test_column + FROM sandbox.lolafect_transaction_test_table + """) + table_is_still_empty_after_rollback = not bool(cursor.fetchall()) # A tuple yields False + + assert table_is_empty_at_first and table_is_still_empty_after_rollback