Create context manager for temp download of ssh key file. Tests.

This commit is contained in:
Pablo Martin 2023-01-23 17:54:17 +01:00
parent e4e1f42309
commit f75b832903
2 changed files with 69 additions and 3 deletions

View file

@ -1,6 +1,7 @@
import datetime import datetime
import os import os
from typing import Tuple from typing import Tuple
from contextlib import contextmanager
import prefect import prefect
from prefect import task from prefect import task
@ -64,8 +65,37 @@ def close_trino_connection(trino_connection: trino.dbapi.Connection) -> None:
logger.info("No connection received.") logger.info("No connection received.")
@task(log_stdout=True, nout=2) @contextmanager
def connect_to_dw( def _temp_secret_file_from_s3(
s3_bucket_name: str, s3_file_key: str, local_temp_file_path: str
) -> str:
"""
Downloads a file from S3 and ensures that it will be deleted once the
context is exited from, even in the face of an exception.
:param s3_bucket_name: the bucket where the file lives.
:param s3_file_key: the key of the file within the bucket.
:param local_temp_file_path: the path where the file should be stored
temporarily.
:return: the local file path.
"""
boto3.client("s3").download_file(
s3_bucket_name,
s3_file_key,
local_temp_file_path,
)
try:
yield local_temp_file_path
except Exception as e:
raise e
finally:
# Regardless of what happens in the context manager, we always delete the temp
# copy of the private key.
os.remove(local_temp_file_path)
@task(nout=2)
def connect_to_mysql(
mysql_credentials: dict, mysql_credentials: dict,
use_ssh_tunnel: bool, use_ssh_tunnel: bool,
s3_bucket_name, s3_bucket_name,

View file

@ -1,5 +1,7 @@
import pathlib
from lolafect.lolaconfig import build_lolaconfig from lolafect.lolaconfig import build_lolaconfig
from lolafect.connections import connect_to_trino, close_trino_connection from lolafect.connections import connect_to_trino, close_trino_connection, _temp_secret_file_from_s3
# __ __ _____ _ _ _____ _ _ _____ _ # __ __ _____ _ _ _____ _ _ _____ _
# \ \ / /\ | __ \| \ | |_ _| \ | |/ ____| | # \ \ / /\ | __ \| \ | |_ _| \ | |/ ____| |
@ -25,3 +27,37 @@ def test_that_trino_connect_and_disconnect_works_properly():
connection.cursor().execute("SELECT 1") connection.cursor().execute("SELECT 1")
close_trino_connection.run(trino_connection=connection) close_trino_connection.run(trino_connection=connection)
def test_temporal_download_of_secret_file_works_properly_in_happy_path():
temp_file_name = "test_temp_file"
with _temp_secret_file_from_s3(
TEST_LOLACONFIG.S3_BUCKET_NAME,
s3_file_key="env/env_prd.json", # Not a secret file, but then again, this is a test,
local_temp_file_path=temp_file_name
) as temp:
temp_file_found_when_in_context_manager = pathlib.Path(temp).exists()
temp_file_missing_when_outside_context_manager = not pathlib.Path(temp_file_name).exists()
assert temp_file_found_when_in_context_manager and temp_file_missing_when_outside_context_manager
def test_temporal_download_of_secret_file_works_properly_even_with_exception():
temp_file_name = "test_temp_file"
try:
with _temp_secret_file_from_s3(
TEST_LOLACONFIG.S3_BUCKET_NAME,
s3_file_key="env/env_prd.json", # Not a secret file, but then again, this is a test,
local_temp_file_path=temp_file_name
) as temp:
temp_file_found_when_in_context_manager = pathlib.Path(temp).exists()
raise Exception # Something nasty happens within the context manager
except:
pass # We go with the test, ignoring the forced exception
temp_file_missing_when_outside_context_manager = not pathlib.Path(temp_file_name).exists()
assert temp_file_found_when_in_context_manager and temp_file_missing_when_outside_context_manager