From 8c66a843b9c51ba3c440e6fd027d2f715ac7b00a Mon Sep 17 00:00:00 2001 From: Pablo Martin Date: Mon, 23 Jan 2023 18:38:12 +0100 Subject: [PATCH] Add ssh opening functions and test. --- lolafect/connections.py | 85 ++++++++++++++++++++++ tests/test_integration/test_connections.py | 22 +++++- 2 files changed, 106 insertions(+), 1 deletion(-) diff --git a/lolafect/connections.py b/lolafect/connections.py index d823db9..a1fa4a1 100644 --- a/lolafect/connections.py +++ b/lolafect/connections.py @@ -94,6 +94,91 @@ def _temp_secret_file_from_s3( os.remove(local_temp_file_path) +@task() +def open_ssh_tunnel_with_s3_pkey( + s3_bucket_name: str, + ssh_tunnel_credentials: dict, + remote_target_host: str, + remote_target_port: int, +) -> SSHTunnelForwarder: + """ + Temporarily fetch a ssh key from S3 and then proceed to open a SSH tunnel + using it. + + :param s3_bucket_name: the bucket where the file lives. + :param ssh_tunnel_credentials: the details of the jumpthost SSH + connection. + :param remote_target_host: the remote host to tunnel to. + :param remote_target_port: the remote port to tunnel to. + :return: the tunnel, already open. + """ + logger = prefect.context.get("logger") + logger.info("Going to open an SSH tunnel.") + + temp_file_path = "temp" + local_bind_host = "127.0.0.1" + local_bind_port = 12345 + + with _temp_secret_file_from_s3( + s3_bucket_name=s3_bucket_name, + s3_file_key=ssh_tunnel_credentials["path_to_ssh_pkey"], + local_temp_file_path=temp_file_path, + ) as ssh_key_file: + tunnel = open_ssh_tunnel( + local_bind_host, + local_bind_port, + remote_target_host, + remote_target_port, + ssh_key_file, + ssh_tunnel_credentials, + ) + logger.info( + f"SSH tunnel is now open and listening at {local_bind_host}:{local_bind_port}.\n" + f"Tunnel forwards to {remote_target_host}:{remote_target_port}" + ) + + return tunnel + + +def open_ssh_tunnel( + local_bind_host: str, + local_bind_port: int, + remote_target_host: str, + remote_target_port: int, + ssh_key_file_path: str, + ssh_tunnel_credentials: dict, +) -> SSHTunnelForwarder: + """ + Configure and start an SSH tunnel. + + :param local_bind_host: the local host address to bind the tunnel to. + :param local_bind_port: the local port address to bind the tunnel to. + :param remote_target_host: the remote host to forward to. + :param remote_target_port: the remote port to forward to. + :param ssh_key_file_path: the path to the ssh key. + :param ssh_tunnel_credentials: the details of the jumpthost SSH + connection. + :return: the tunnel, already open. + """ + + tunnel = SSHTunnelForwarder( + ssh_host=( + ssh_tunnel_credentials["ssh_jumphost"], + ssh_tunnel_credentials["ssh_port"], + ), + ssh_username=ssh_tunnel_credentials["ssh_username"], + ssh_pkey=ssh_key_file_path, + remote_bind_address=( + remote_target_host, + remote_target_port, + ), + local_bind_address=(local_bind_host, local_bind_port), + ssh_private_key_password=ssh_tunnel_credentials["ssh_pkey_password"], + ) + tunnel.start() + return tunnel + + @task(nout=2) def connect_to_mysql( mysql_credentials: dict, diff --git a/tests/test_integration/test_connections.py b/tests/test_integration/test_connections.py index 771baa3..52c471a 100644 --- a/tests/test_integration/test_connections.py +++ b/tests/test_integration/test_connections.py @@ -1,7 +1,12 @@ import pathlib from lolafect.lolaconfig import build_lolaconfig -from lolafect.connections import connect_to_trino, close_trino_connection, _temp_secret_file_from_s3 +from lolafect.connections import ( + connect_to_trino, + close_trino_connection, + _temp_secret_file_from_s3, + open_ssh_tunnel_with_s3_pkey, +) # __ __ _____ _ _ _____ _ _ _____ _ # \ \ / /\ | __ \| \ | |_ _| \ | |/ ____| | @@ -72,3 +77,18 @@ def test_temporal_download_of_secret_file_works_properly_even_with_exception(): temp_file_found_when_in_context_manager and temp_file_missing_when_outside_context_manager ) + + +def test_opening_and_closing_ssh_tunnel_works_properly(): + tunnel = open_ssh_tunnel_with_s3_pkey.run( + s3_bucket_name=TEST_LOLACONFIG.S3_BUCKET_NAME, + ssh_tunnel_credentials=TEST_LOLACONFIG.SSH_TUNNEL_CREDENTIALS, + remote_target_host=TEST_LOLACONFIG.DW_CREDENTIALS["host"], + remote_target_port=TEST_LOLACONFIG.DW_CREDENTIALS["port"], + ) + tunnel_was_active = tunnel.is_active + tunnel.close() + + tunnel_is_no_longer_active = not tunnel.is_active + + assert tunnel_was_active and tunnel_is_no_longer_active