Create context manager for temp download of ssh key file. Tests.
This commit is contained in:
parent
e4e1f42309
commit
f75b832903
2 changed files with 69 additions and 3 deletions
|
|
@ -1,6 +1,7 @@
|
||||||
import datetime
|
import datetime
|
||||||
import os
|
import os
|
||||||
from typing import Tuple
|
from typing import Tuple
|
||||||
|
from contextlib import contextmanager
|
||||||
|
|
||||||
import prefect
|
import prefect
|
||||||
from prefect import task
|
from prefect import task
|
||||||
|
|
@ -64,8 +65,37 @@ def close_trino_connection(trino_connection: trino.dbapi.Connection) -> None:
|
||||||
logger.info("No connection received.")
|
logger.info("No connection received.")
|
||||||
|
|
||||||
|
|
||||||
@task(log_stdout=True, nout=2)
|
@contextmanager
|
||||||
def connect_to_dw(
|
def _temp_secret_file_from_s3(
|
||||||
|
s3_bucket_name: str, s3_file_key: str, local_temp_file_path: str
|
||||||
|
) -> str:
|
||||||
|
"""
|
||||||
|
Downloads a file from S3 and ensures that it will be deleted once the
|
||||||
|
context is exited from, even in the face of an exception.
|
||||||
|
|
||||||
|
:param s3_bucket_name: the bucket where the file lives.
|
||||||
|
:param s3_file_key: the key of the file within the bucket.
|
||||||
|
:param local_temp_file_path: the path where the file should be stored
|
||||||
|
temporarily.
|
||||||
|
:return: the local file path.
|
||||||
|
"""
|
||||||
|
boto3.client("s3").download_file(
|
||||||
|
s3_bucket_name,
|
||||||
|
s3_file_key,
|
||||||
|
local_temp_file_path,
|
||||||
|
)
|
||||||
|
try:
|
||||||
|
yield local_temp_file_path
|
||||||
|
except Exception as e:
|
||||||
|
raise e
|
||||||
|
finally:
|
||||||
|
# Regardless of what happens in the context manager, we always delete the temp
|
||||||
|
# copy of the private key.
|
||||||
|
os.remove(local_temp_file_path)
|
||||||
|
|
||||||
|
|
||||||
|
@task(nout=2)
|
||||||
|
def connect_to_mysql(
|
||||||
mysql_credentials: dict,
|
mysql_credentials: dict,
|
||||||
use_ssh_tunnel: bool,
|
use_ssh_tunnel: bool,
|
||||||
s3_bucket_name,
|
s3_bucket_name,
|
||||||
|
|
|
||||||
|
|
@ -1,5 +1,7 @@
|
||||||
|
import pathlib
|
||||||
|
|
||||||
from lolafect.lolaconfig import build_lolaconfig
|
from lolafect.lolaconfig import build_lolaconfig
|
||||||
from lolafect.connections import connect_to_trino, close_trino_connection
|
from lolafect.connections import connect_to_trino, close_trino_connection, _temp_secret_file_from_s3
|
||||||
|
|
||||||
# __ __ _____ _ _ _____ _ _ _____ _
|
# __ __ _____ _ _ _____ _ _ _____ _
|
||||||
# \ \ / /\ | __ \| \ | |_ _| \ | |/ ____| |
|
# \ \ / /\ | __ \| \ | |_ _| \ | |/ ____| |
|
||||||
|
|
@ -25,3 +27,37 @@ def test_that_trino_connect_and_disconnect_works_properly():
|
||||||
connection.cursor().execute("SELECT 1")
|
connection.cursor().execute("SELECT 1")
|
||||||
|
|
||||||
close_trino_connection.run(trino_connection=connection)
|
close_trino_connection.run(trino_connection=connection)
|
||||||
|
|
||||||
|
|
||||||
|
def test_temporal_download_of_secret_file_works_properly_in_happy_path():
|
||||||
|
|
||||||
|
temp_file_name = "test_temp_file"
|
||||||
|
|
||||||
|
with _temp_secret_file_from_s3(
|
||||||
|
TEST_LOLACONFIG.S3_BUCKET_NAME,
|
||||||
|
s3_file_key="env/env_prd.json", # Not a secret file, but then again, this is a test,
|
||||||
|
local_temp_file_path=temp_file_name
|
||||||
|
) as temp:
|
||||||
|
temp_file_found_when_in_context_manager = pathlib.Path(temp).exists()
|
||||||
|
|
||||||
|
temp_file_missing_when_outside_context_manager = not pathlib.Path(temp_file_name).exists()
|
||||||
|
|
||||||
|
assert temp_file_found_when_in_context_manager and temp_file_missing_when_outside_context_manager
|
||||||
|
|
||||||
|
def test_temporal_download_of_secret_file_works_properly_even_with_exception():
|
||||||
|
temp_file_name = "test_temp_file"
|
||||||
|
|
||||||
|
try:
|
||||||
|
with _temp_secret_file_from_s3(
|
||||||
|
TEST_LOLACONFIG.S3_BUCKET_NAME,
|
||||||
|
s3_file_key="env/env_prd.json", # Not a secret file, but then again, this is a test,
|
||||||
|
local_temp_file_path=temp_file_name
|
||||||
|
) as temp:
|
||||||
|
temp_file_found_when_in_context_manager = pathlib.Path(temp).exists()
|
||||||
|
raise Exception # Something nasty happens within the context manager
|
||||||
|
except:
|
||||||
|
pass # We go with the test, ignoring the forced exception
|
||||||
|
|
||||||
|
temp_file_missing_when_outside_context_manager = not pathlib.Path(temp_file_name).exists()
|
||||||
|
|
||||||
|
assert temp_file_found_when_in_context_manager and temp_file_missing_when_outside_context_manager
|
||||||
Loading…
Add table
Add a link
Reference in a new issue