62 lines
1.8 KiB
Python
62 lines
1.8 KiB
Python
import json
|
|
from typing import Any
|
|
|
|
import prefect
|
|
from prefect import task
|
|
|
|
class S3FileReader:
    """
    An S3 client along with a few reading utils.

    The client is injected rather than created here, so the caller controls
    its configuration and tests can substitute a fake.
    """

    def __init__(self, s3_client):
        """
        :param s3_client: an object exposing ``get_object(Bucket=..., Key=...)``
            that returns a mapping whose ``"Body"`` has a ``read()`` method
            returning bytes (presumably a boto3 S3 client -- not enforced).
        """
        self.s3_client = s3_client

    def read_json_from_s3_file(self, bucket: str, key: str) -> dict:
        """
        Read a JSON file from an S3 location and return its parsed contents.

        The object body is decoded as UTF-8. A leading UTF-8 byte-order mark,
        if present, is stripped before parsing.

        :param bucket: the name of the bucket where the file is stored.
        :param key: the path to the file within the bucket.
        :return: the parsed file contents. This is a dict for a JSON object;
            a top-level JSON array would come back as a list despite the
            annotation.
        :raises json.JSONDecodeError: if the body is not valid JSON.
        :raises UnicodeDecodeError: if the body is not valid UTF-8.
        """
        response = self.s3_client.get_object(Bucket=bucket, Key=key)
        raw = response["Body"].read()
        # "utf-8-sig" decodes exactly like "utf-8" except that it drops a
        # leading BOM. Plain "utf-8" keeps it as U+FEFF, which json.loads
        # rejects with "Unexpected UTF-8 BOM".
        return json.loads(raw.decode("utf-8-sig"))
|
|
|
|
@task()
def begin_sql_transaction(connection: Any) -> None:
    """
    Open a transaction on the given database connection.

    No particular SQL engine is assumed: any connection object that provides
    a ``begin()`` method works. The connection is logged through the logger
    found in the Prefect context before the transaction is started.

    :param connection: the connection to some database.
    :return: None
    """
    log = prefect.context.get("logger")
    start_message = f"Starting SQL transaction with connection: {connection}."
    log.info(start_message)

    connection.begin()
|
|
|
|
|
|
@task()
def end_sql_transaction(connection: Any, dry_run: bool = False) -> None:
    """
    Finish a SQL transaction, either by rolling it back or by committing it.

    :param connection: the connection to some database. It must provide
        ``rollback()`` and ``commit()`` methods.
    :param dry_run: a flag indicating whether persistence should be skipped.
        If dry_run is True, changes are rolled back; otherwise they are
        committed.
    :return: None
    """
    logger = prefect.context.get("logger")
    logger.info(f"Using connection: {connection}.")

    if dry_run:
        # Log before acting, as the commit branch does, so the intent is
        # recorded even if rollback() raises.
        logger.info("Dry-run mode activated. Rolling back the transaction.")
        connection.rollback()
    else:
        logger.info("Committing the transaction.")
        connection.commit()
|