From f4f231d15d712c47286efc49aa6e57d3ecbbb7be Mon Sep 17 00:00:00 2001 From: Pablo Martin Date: Tue, 24 Jan 2023 11:00:39 +0100 Subject: [PATCH] Test for connecting to MySQL --- lolafect/connections.py | 16 +++++----- tests/test_integration/test_connections.py | 35 ++++++++++++++++++++++ 2 files changed, 43 insertions(+), 8 deletions(-) diff --git a/lolafect/connections.py b/lolafect/connections.py index 3d85f17..450c4e4 100644 --- a/lolafect/connections.py +++ b/lolafect/connections.py @@ -101,7 +101,7 @@ def open_ssh_tunnel_with_s3_pkey( remote_target_host: str, remote_target_port: int, local_bind_host: str = "127.0.0.1", - local_bind_port: int = 12345 + local_bind_port: int = 12345, ) -> SSHTunnelForwarder: """ Temporarily fetch a ssh key from S3 and then proceed to open a SSH tunnel @@ -164,7 +164,7 @@ def open_ssh_tunnel( """ tunnel = SSHTunnelForwarder( - ssh_host=( + ssh_address_or_host=( ssh_tunnel_credentials["ssh_jumphost"], ssh_tunnel_credentials["ssh_port"], ), @@ -246,22 +246,22 @@ def connect_to_mysql( @task(trigger=all_finished) -def close_mysql_connection(dw_connection: pymysql.Connection) -> None: +def close_mysql_connection(connection: pymysql.Connection) -> None: logger = prefect.context.get("logger") - if isinstance(dw_connection, pymysql.Connection): - dw_connection.close() + if isinstance(connection, pymysql.Connection): + connection.close() logger.info("DW connection closed successfully.") return logger.info("No connection received.") @task(trigger=all_finished) -def close_ssh_tunnel(ssh_tunnel: SSHTunnelForwarder) -> None: +def close_ssh_tunnel(tunnel: SSHTunnelForwarder) -> None: logger = prefect.context.get("logger") - if isinstance(ssh_tunnel, SSHTunnelForwarder): - ssh_tunnel.stop() + if isinstance(tunnel, SSHTunnelForwarder): + tunnel.stop() logger.info("SSH tunnel closed successfully.") return logger.info("No connection received.") diff --git a/tests/test_integration/test_connections.py b/tests/test_integration/test_connections.py index 7f647c1..2d03039 100644 --- a/tests/test_integration/test_connections.py +++ b/tests/test_integration/test_connections.py @@ -10,6 +10,8 @@ from lolafect.connections import ( open_ssh_tunnel_with_s3_pkey, get_local_bind_address_from_ssh_tunnel, close_ssh_tunnel, + connect_to_mysql, + close_mysql_connection, ) # __ __ _____ _ _ _____ _ _ _____ _ @@ -115,3 +117,36 @@ def test_opening_and_closing_ssh_tunnel_works_properly(): and local_bind_host_matches and local_bind_port_matches ) + + +def test_connect_query_and_disconnect_from_mysql_with_tunnel(): + + test_local_bind_host = "127.0.0.1" + test_local_bind_port = 12345 + + tunnel = open_ssh_tunnel_with_s3_pkey.run( + s3_bucket_name=TEST_LOLACONFIG.S3_BUCKET_NAME, + ssh_tunnel_credentials=TEST_LOLACONFIG.SSH_TUNNEL_CREDENTIALS, + remote_target_host=TEST_LOLACONFIG.DW_CREDENTIALS["host"], + remote_target_port=TEST_LOLACONFIG.DW_CREDENTIALS["port"], + local_bind_host=test_local_bind_host, + local_bind_port=test_local_bind_port, + ) + + connection = connect_to_mysql.run( + mysql_credentials=TEST_LOLACONFIG.DW_CREDENTIALS, + overriding_host_and_port=get_local_bind_address_from_ssh_tunnel.run( + tunnel=tunnel + ), + ) + + connection_was_open = connection.open + + connection.cursor().execute("SELECT 1") + + close_mysql_connection.run(connection=connection) + close_ssh_tunnel.run(tunnel=tunnel) + + connection_is_closed = not connection.open + + assert connection_was_open and connection_is_closed