Create context manager for temp download of ssh key file. Tests.

This commit is contained in:
Pablo Martin 2023-01-23 17:54:17 +01:00
parent e4e1f42309
commit f75b832903
2 changed files with 69 additions and 3 deletions

View file

@ -1,6 +1,7 @@
import datetime
import os
from typing import Tuple
from contextlib import contextmanager
import prefect
from prefect import task
@ -64,8 +65,37 @@ def close_trino_connection(trino_connection: trino.dbapi.Connection) -> None:
logger.info("No connection received.")
@task(log_stdout=True, nout=2)
def connect_to_dw(
@contextmanager
def _temp_secret_file_from_s3(
s3_bucket_name: str, s3_file_key: str, local_temp_file_path: str
) -> str:
"""
Downloads a file from S3 and ensures that it will be deleted once the
context is exited from, even in the face of an exception.
:param s3_bucket_name: the bucket where the file lives.
:param s3_file_key: the key of the file within the bucket.
:param local_temp_file_path: the path where the file should be stored
temporarily.
:return: the local file path.
"""
boto3.client("s3").download_file(
s3_bucket_name,
s3_file_key,
local_temp_file_path,
)
try:
yield local_temp_file_path
except Exception as e:
raise e
finally:
# Regardless of what happens in the context manager, we always delete the temp
# copy of the private key.
os.remove(local_temp_file_path)
@task(nout=2)
def connect_to_mysql(
mysql_credentials: dict,
use_ssh_tunnel: bool,
s3_bucket_name,

View file

@ -1,5 +1,7 @@
import pathlib
from lolafect.lolaconfig import build_lolaconfig
from lolafect.connections import connect_to_trino, close_trino_connection
from lolafect.connections import connect_to_trino, close_trino_connection, _temp_secret_file_from_s3
# __ __ _____ _ _ _____ _ _ _____ _
# \ \ / /\ | __ \| \ | |_ _| \ | |/ ____| |
@ -25,3 +27,37 @@ def test_that_trino_connect_and_disconnect_works_properly():
connection.cursor().execute("SELECT 1")
close_trino_connection.run(trino_connection=connection)
def test_temporal_download_of_secret_file_works_properly_in_happy_path():
temp_file_name = "test_temp_file"
with _temp_secret_file_from_s3(
TEST_LOLACONFIG.S3_BUCKET_NAME,
s3_file_key="env/env_prd.json", # Not a secret file, but then again, this is a test,
local_temp_file_path=temp_file_name
) as temp:
temp_file_found_when_in_context_manager = pathlib.Path(temp).exists()
temp_file_missing_when_outside_context_manager = not pathlib.Path(temp_file_name).exists()
assert temp_file_found_when_in_context_manager and temp_file_missing_when_outside_context_manager
def test_temporal_download_of_secret_file_works_properly_even_with_exception():
temp_file_name = "test_temp_file"
try:
with _temp_secret_file_from_s3(
TEST_LOLACONFIG.S3_BUCKET_NAME,
s3_file_key="env/env_prd.json", # Not a secret file, but then again, this is a test,
local_temp_file_path=temp_file_name
) as temp:
temp_file_found_when_in_context_manager = pathlib.Path(temp).exists()
raise Exception # Something nasty happens within the context manager
except:
pass # We go with the test, ignoring the forced exception
temp_file_missing_when_outside_context_manager = not pathlib.Path(temp_file_name).exists()
assert temp_file_found_when_in_context_manager and temp_file_missing_when_outside_context_manager