Create context manager for temp download of ssh key file. Tests.
This commit is contained in:
parent
e4e1f42309
commit
f75b832903
2 changed files with 69 additions and 3 deletions
|
|
@ -1,6 +1,7 @@
|
|||
import datetime
|
||||
import os
|
||||
from typing import Tuple
|
||||
from contextlib import contextmanager
|
||||
|
||||
import prefect
|
||||
from prefect import task
|
||||
|
|
@ -64,8 +65,37 @@ def close_trino_connection(trino_connection: trino.dbapi.Connection) -> None:
|
|||
logger.info("No connection received.")
|
||||
|
||||
|
||||
@task(log_stdout=True, nout=2)
|
||||
def connect_to_dw(
|
||||
@contextmanager
|
||||
def _temp_secret_file_from_s3(
|
||||
s3_bucket_name: str, s3_file_key: str, local_temp_file_path: str
|
||||
) -> str:
|
||||
"""
|
||||
Downloads a file from S3 and ensures that it will be deleted once the
|
||||
context is exited from, even in the face of an exception.
|
||||
|
||||
:param s3_bucket_name: the bucket where the file lives.
|
||||
:param s3_file_key: the key of the file within the bucket.
|
||||
:param local_temp_file_path: the path where the file should be stored
|
||||
temporarily.
|
||||
:return: the local file path.
|
||||
"""
|
||||
boto3.client("s3").download_file(
|
||||
s3_bucket_name,
|
||||
s3_file_key,
|
||||
local_temp_file_path,
|
||||
)
|
||||
try:
|
||||
yield local_temp_file_path
|
||||
except Exception as e:
|
||||
raise e
|
||||
finally:
|
||||
# Regardless of what happens in the context manager, we always delete the temp
|
||||
# copy of the private key.
|
||||
os.remove(local_temp_file_path)
|
||||
|
||||
|
||||
@task(nout=2)
|
||||
def connect_to_mysql(
|
||||
mysql_credentials: dict,
|
||||
use_ssh_tunnel: bool,
|
||||
s3_bucket_name,
|
||||
|
|
|
|||
|
|
@ -1,5 +1,7 @@
|
|||
import pathlib
|
||||
|
||||
from lolafect.lolaconfig import build_lolaconfig
|
||||
from lolafect.connections import connect_to_trino, close_trino_connection
|
||||
from lolafect.connections import connect_to_trino, close_trino_connection, _temp_secret_file_from_s3
|
||||
|
||||
# __ __ _____ _ _ _____ _ _ _____ _
|
||||
# \ \ / /\ | __ \| \ | |_ _| \ | |/ ____| |
|
||||
|
|
@ -25,3 +27,37 @@ def test_that_trino_connect_and_disconnect_works_properly():
|
|||
connection.cursor().execute("SELECT 1")
|
||||
|
||||
close_trino_connection.run(trino_connection=connection)
|
||||
|
||||
|
||||
def test_temporal_download_of_secret_file_works_properly_in_happy_path():
|
||||
|
||||
temp_file_name = "test_temp_file"
|
||||
|
||||
with _temp_secret_file_from_s3(
|
||||
TEST_LOLACONFIG.S3_BUCKET_NAME,
|
||||
s3_file_key="env/env_prd.json", # Not a secret file, but then again, this is a test,
|
||||
local_temp_file_path=temp_file_name
|
||||
) as temp:
|
||||
temp_file_found_when_in_context_manager = pathlib.Path(temp).exists()
|
||||
|
||||
temp_file_missing_when_outside_context_manager = not pathlib.Path(temp_file_name).exists()
|
||||
|
||||
assert temp_file_found_when_in_context_manager and temp_file_missing_when_outside_context_manager
|
||||
|
||||
def test_temporal_download_of_secret_file_works_properly_even_with_exception():
|
||||
temp_file_name = "test_temp_file"
|
||||
|
||||
try:
|
||||
with _temp_secret_file_from_s3(
|
||||
TEST_LOLACONFIG.S3_BUCKET_NAME,
|
||||
s3_file_key="env/env_prd.json", # Not a secret file, but then again, this is a test,
|
||||
local_temp_file_path=temp_file_name
|
||||
) as temp:
|
||||
temp_file_found_when_in_context_manager = pathlib.Path(temp).exists()
|
||||
raise Exception # Something nasty happens within the context manager
|
||||
except:
|
||||
pass # We go with the test, ignoring the forced exception
|
||||
|
||||
temp_file_missing_when_outside_context_manager = not pathlib.Path(temp_file_name).exists()
|
||||
|
||||
assert temp_file_found_when_in_context_manager and temp_file_missing_when_outside_context_manager
|
||||
Loading…
Add table
Add a link
Reference in a new issue