From 3564ae99f785462d853fd10bf84f033393ff5746 Mon Sep 17 00:00:00 2001 From: Pablo Martin Date: Tue, 24 Jan 2023 09:57:54 +0100 Subject: [PATCH] Splitting MySQL and SSH tunnel bits. SSH tunnel is now tested. --- lolafect/connections.py | 83 +++++++++++----------- tests/test_integration/test_connections.py | 27 ++++++- 2 files changed, 65 insertions(+), 45 deletions(-) diff --git a/lolafect/connections.py b/lolafect/connections.py index 3e48f96..3d85f17 100644 --- a/lolafect/connections.py +++ b/lolafect/connections.py @@ -100,6 +100,8 @@ def open_ssh_tunnel_with_s3_pkey( ssh_tunnel_credentials: dict, remote_target_host: str, remote_target_port: int, + local_bind_host: str = "127.0.0.1", + local_bind_port: int = 12345 ) -> SSHTunnelForwarder: """ Temporarily fetch a ssh key from S3 and then proceed to open a SSH tunnel @@ -110,14 +112,14 @@ def open_ssh_tunnel_with_s3_pkey( connection. :param remote_target_host: the remote host to tunnel to. :param remote_target_port: the remote port to tunnel to. + :param local_bind_host: the host for the local bind address. + :param local_bind_port: the port for the local bind address. :return: the tunnel, already open. """ logger = prefect.context.get("logger") logger.info("Going to open an SSH tunnel.") temp_file_path = "temp" - local_bind_host = "127.0.0.1" - local_bind_port = 12345 with _temp_secret_file_from_s3( s3_bucket_name=s3_bucket_name, @@ -179,59 +181,54 @@ def open_ssh_tunnel( return tunnel -@task(nout=2) +@task() +def get_local_bind_address_from_ssh_tunnel( + tunnel: SSHTunnelForwarder, +) -> Tuple[str, int]: + """ + A silly wrapper to be able to unpack the local bind address of a tunnel + within a Prefect flow. + + :param tunnel: an SSH tunnel. + :return: the local bind address of the SSH tunnel, as a tuple with host + and port. + """ + return tunnel.local_bind_address + + +@task() def connect_to_mysql( - mysql_credentials: dict, - use_ssh_tunnel: bool, - s3_bucket_name, - ssh_tunnel_credentials, -) -> Tuple[pymysql.Connection, SSHTunnelForwarder]: + mysql_credentials: dict, overriding_host_and_port: Tuple[str, int] = None +) -> pymysql.Connection: + """ + Create a connection to a MySQL server, optionally using a host and port + different from the ones where the MySQL server is located. + + :param mysql_credentials: a dict with the connection details to the MySQL + instance. + :param overriding_host_and_port: an optional tuple containing a different + host and port. Useful to route through an SSH tunnel. + :return: the connection to the MySQL server. + """ logger = prefect.context.get("logger") logger.info( f"Connecting to MySQL at {mysql_credentials['host']}:{mysql_credentials['port']}." ) mysql_host = mysql_credentials["host"] - tunnel = None - if use_ssh_tunnel: - logger.info("Going to open an SSH tunnel.") + mysql_port = mysql_credentials["port"] - temp_file_path = "temp" - try: - boto3.client("s3").download_file( - s3_bucket_name, - ssh_tunnel_credentials["path_to_ssh_pkey"], - temp_file_path, - ) - tunnel = SSHTunnelForwarder( - ssh_host=( - ssh_tunnel_credentials["ssh_jumphost"], - ssh_tunnel_credentials["ssh_port"], - ), - ssh_username=ssh_tunnel_credentials["ssh_username"], - ssh_pkey=temp_file_path, - remote_bind_address=( - mysql_credentials["host"], - mysql_credentials["port"], - ), - local_bind_address=("127.0.0.1", mysql_credentials["port"]), - ssh_private_key_password=ssh_tunnel_credentials["ssh_pkey_password"], - ) - except Exception as e: - raise e - finally: - # No matter what happens above, we always must delete the temp copy of the key - os.remove(temp_file_path) - - tunnel.start() + if overriding_host_and_port: + # Since there is a tunnel, we actually want to connect to the local + # address of the tunnel, and not straight into the MySQL server. + mysql_host, mysql_port = overriding_host_and_port logger.info( - f"SSH tunnel is now open and listening at{mysql_credentials['host']}:{mysql_credentials['port']}." + f"Overriding the passed MySQL host and port with {mysql_host}:{mysql_port}." ) - mysql_host = "127.0.0.1" db_connection = pymysql.connect( host=mysql_host, - port=mysql_credentials["port"], + port=mysql_port, user=mysql_credentials["user"], password=mysql_credentials["password"], database="dw_xl", @@ -245,7 +242,7 @@ def connect_to_mysql( f"Connected to MySQL at {mysql_credentials['host']}:{mysql_credentials['port']}." ) - return db_connection, tunnel + return db_connection @task(trigger=all_finished) diff --git a/tests/test_integration/test_connections.py b/tests/test_integration/test_connections.py index 52c471a..7f647c1 100644 --- a/tests/test_integration/test_connections.py +++ b/tests/test_integration/test_connections.py @@ -1,11 +1,15 @@ import pathlib +from prefect.tasks.core.function import FunctionTask + from lolafect.lolaconfig import build_lolaconfig from lolafect.connections import ( connect_to_trino, close_trino_connection, _temp_secret_file_from_s3, open_ssh_tunnel_with_s3_pkey, + get_local_bind_address_from_ssh_tunnel, + close_ssh_tunnel, ) # __ __ _____ _ _ _____ _ _ _____ _ @@ -80,15 +84,34 @@ def test_temporal_download_of_secret_file_works_properly_even_with_exception(): def test_opening_and_closing_ssh_tunnel_works_properly(): + + test_local_bind_host = "127.0.0.1" + test_local_bind_port = 12345 + tunnel = open_ssh_tunnel_with_s3_pkey.run( s3_bucket_name=TEST_LOLACONFIG.S3_BUCKET_NAME, ssh_tunnel_credentials=TEST_LOLACONFIG.SSH_TUNNEL_CREDENTIALS, remote_target_host=TEST_LOLACONFIG.DW_CREDENTIALS["host"], remote_target_port=TEST_LOLACONFIG.DW_CREDENTIALS["port"], + local_bind_host=test_local_bind_host, + local_bind_port=test_local_bind_port, ) tunnel_was_active = tunnel.is_active - tunnel.close() + + local_bind_host_matches = ( + get_local_bind_address_from_ssh_tunnel.run(tunnel)[0] == test_local_bind_host + ) + local_bind_port_matches = ( + get_local_bind_address_from_ssh_tunnel.run(tunnel)[1] == test_local_bind_port + ) + + close_ssh_tunnel.run(tunnel) tunnel_is_no_longer_active = not tunnel.is_active - assert tunnel_was_active and tunnel_is_no_longer_active + assert ( + tunnel_was_active + and tunnel_is_no_longer_active + and local_bind_host_matches + and local_bind_port_matches + )