From f75b832903fa30475de6ff3c7c57852f376c15dc Mon Sep 17 00:00:00 2001 From: Pablo Martin Date: Mon, 23 Jan 2023 17:54:17 +0100 Subject: [PATCH] Create context manager for temp download of ssh key file. Tests. --- lolafect/connections.py | 34 +++++++++++++++++-- tests/test_integration/test_connections.py | 38 +++++++++++++++++++++- 2 files changed, 69 insertions(+), 3 deletions(-) diff --git a/lolafect/connections.py b/lolafect/connections.py index 34cc490..d823db9 100644 --- a/lolafect/connections.py +++ b/lolafect/connections.py @@ -1,6 +1,7 @@ import datetime import os from typing import Tuple +from contextlib import contextmanager import prefect from prefect import task @@ -64,8 +65,37 @@ def close_trino_connection(trino_connection: trino.dbapi.Connection) -> None: logger.info("No connection received.") -@task(log_stdout=True, nout=2) -def connect_to_dw( +@contextmanager +def _temp_secret_file_from_s3( + s3_bucket_name: str, s3_file_key: str, local_temp_file_path: str +) -> str: + """ + Downloads a file from S3 and ensures that it will be deleted once the + context is exited from, even in the face of an exception. + + :param s3_bucket_name: the bucket where the file lives. + :param s3_file_key: the key of the file within the bucket. + :param local_temp_file_path: the path where the file should be stored + temporarily. + :return: the local file path. + """ + boto3.client("s3").download_file( + s3_bucket_name, + s3_file_key, + local_temp_file_path, + ) + try: + yield local_temp_file_path + except Exception as e: + raise e + finally: + # Regardless of what happens in the context manager, we always delete the temp + # copy of the private key. + os.remove(local_temp_file_path) + + +@task(nout=2) +def connect_to_mysql( mysql_credentials: dict, use_ssh_tunnel: bool, s3_bucket_name, diff --git a/tests/test_integration/test_connections.py b/tests/test_integration/test_connections.py index 74944e0..a3349a4 100644 --- a/tests/test_integration/test_connections.py +++ b/tests/test_integration/test_connections.py @@ -1,5 +1,7 @@ +import pathlib + from lolafect.lolaconfig import build_lolaconfig -from lolafect.connections import connect_to_trino, close_trino_connection +from lolafect.connections import connect_to_trino, close_trino_connection, _temp_secret_file_from_s3 # __ __ _____ _ _ _____ _ _ _____ _ # \ \ / /\ | __ \| \ | |_ _| \ | |/ ____| | @@ -25,3 +27,37 @@ def test_that_trino_connect_and_disconnect_works_properly(): connection.cursor().execute("SELECT 1") close_trino_connection.run(trino_connection=connection) + + +def test_temporal_download_of_secret_file_works_properly_in_happy_path(): + + temp_file_name = "test_temp_file" + + with _temp_secret_file_from_s3( + TEST_LOLACONFIG.S3_BUCKET_NAME, + s3_file_key="env/env_prd.json", # Not a secret file, but then again, this is a test, + local_temp_file_path=temp_file_name + ) as temp: + temp_file_found_when_in_context_manager = pathlib.Path(temp).exists() + + temp_file_missing_when_outside_context_manager = not pathlib.Path(temp_file_name).exists() + + assert temp_file_found_when_in_context_manager and temp_file_missing_when_outside_context_manager + +def test_temporal_download_of_secret_file_works_properly_even_with_exception(): + temp_file_name = "test_temp_file" + + try: + with _temp_secret_file_from_s3( + TEST_LOLACONFIG.S3_BUCKET_NAME, + s3_file_key="env/env_prd.json", # Not a secret file, but then again, this is a test, + local_temp_file_path=temp_file_name + ) as temp: + temp_file_found_when_in_context_manager = pathlib.Path(temp).exists() + raise Exception # Something nasty happens within the context manager + except: + pass # We go with the test, ignoring the forced exception + + temp_file_missing_when_outside_context_manager = not pathlib.Path(temp_file_name).exists() + + assert temp_file_found_when_in_context_manager and temp_file_missing_when_outside_context_manager \ No newline at end of file