From 8f1f3d75e1b0772d1ef5f0b7ea5b570d21f37754 Mon Sep 17 00:00:00 2001 From: Pablo Martin Date: Fri, 21 Apr 2023 12:22:32 +0200 Subject: [PATCH] Implemented the tasks. --- lolafect/utils.py | 38 ++++++++++++++++++++++++++++++++++++++ 1 file changed, 38 insertions(+) diff --git a/lolafect/utils.py b/lolafect/utils.py index 621e59d..608f27b 100644 --- a/lolafect/utils.py +++ b/lolafect/utils.py @@ -1,5 +1,8 @@ import json +from typing import Any +import prefect +from prefect import task class S3FileReader: """ @@ -22,3 +25,38 @@ class S3FileReader: .read() .decode("utf-8") ) + +@task() +def begin_sql_transaction(connection: Any) -> None: + """ + Start a SQL transaction in the passed connection. The task is agnostic to + the SQL engine being used. As long as it implements a begin() method, this + will work. + + :param connection: the connection to some database. + :return: None + """ + logger = prefect.context.get("logger") + logger.info(f"Starting SQL transaction with connection: {connection}.") + connection.begin() + + +@task() +def end_sql_transaction(connection: Any, dry_run: bool = False) -> None: + """ + Finish a SQL transaction, either by rolling it back or by committing it. + + :param connection: the connection to some database. + :param dry_run: a flag indicating if persistence is desired. If dry_run + is True, changes will be rollbacked. + :return: None + """ + logger = prefect.context.get("logger") + logger.info(f"Using connection: {connection}.") + + if dry_run: + connection.rollback() + logger.info("Dry-run mode activated. Rolling back the transaction.") + else: + logger.info("Committing the transaction.") + connection.commit()