Test for connecting to MySQL
This commit is contained in:
parent
3564ae99f7
commit
f4f231d15d
2 changed files with 43 additions and 8 deletions
|
|
@ -101,7 +101,7 @@ def open_ssh_tunnel_with_s3_pkey(
|
||||||
remote_target_host: str,
|
remote_target_host: str,
|
||||||
remote_target_port: int,
|
remote_target_port: int,
|
||||||
local_bind_host: str = "127.0.0.1",
|
local_bind_host: str = "127.0.0.1",
|
||||||
local_bind_port: int = 12345
|
local_bind_port: int = 12345,
|
||||||
) -> SSHTunnelForwarder:
|
) -> SSHTunnelForwarder:
|
||||||
"""
|
"""
|
||||||
Temporarily fetch a ssh key from S3 and then proceed to open a SSH tunnel
|
Temporarily fetch a ssh key from S3 and then proceed to open a SSH tunnel
|
||||||
|
|
@ -164,7 +164,7 @@ def open_ssh_tunnel(
|
||||||
"""
|
"""
|
||||||
|
|
||||||
tunnel = SSHTunnelForwarder(
|
tunnel = SSHTunnelForwarder(
|
||||||
ssh_host=(
|
ssh_address_or_host=(
|
||||||
ssh_tunnel_credentials["ssh_jumphost"],
|
ssh_tunnel_credentials["ssh_jumphost"],
|
||||||
ssh_tunnel_credentials["ssh_port"],
|
ssh_tunnel_credentials["ssh_port"],
|
||||||
),
|
),
|
||||||
|
|
@ -246,22 +246,22 @@ def connect_to_mysql(
|
||||||
|
|
||||||
|
|
||||||
@task(trigger=all_finished)
|
@task(trigger=all_finished)
|
||||||
def close_mysql_connection(dw_connection: pymysql.Connection) -> None:
|
def close_mysql_connection(connection: pymysql.Connection) -> None:
|
||||||
logger = prefect.context.get("logger")
|
logger = prefect.context.get("logger")
|
||||||
|
|
||||||
if isinstance(dw_connection, pymysql.Connection):
|
if isinstance(connection, pymysql.Connection):
|
||||||
dw_connection.close()
|
connection.close()
|
||||||
logger.info("DW connection closed successfully.")
|
logger.info("DW connection closed successfully.")
|
||||||
return
|
return
|
||||||
logger.info("No connection received.")
|
logger.info("No connection received.")
|
||||||
|
|
||||||
|
|
||||||
@task(trigger=all_finished)
|
@task(trigger=all_finished)
|
||||||
def close_ssh_tunnel(ssh_tunnel: SSHTunnelForwarder) -> None:
|
def close_ssh_tunnel(tunnel: SSHTunnelForwarder) -> None:
|
||||||
logger = prefect.context.get("logger")
|
logger = prefect.context.get("logger")
|
||||||
|
|
||||||
if isinstance(ssh_tunnel, SSHTunnelForwarder):
|
if isinstance(tunnel, SSHTunnelForwarder):
|
||||||
ssh_tunnel.stop()
|
tunnel.stop()
|
||||||
logger.info("SSH tunnel closed successfully.")
|
logger.info("SSH tunnel closed successfully.")
|
||||||
return
|
return
|
||||||
logger.info("No connection received.")
|
logger.info("No connection received.")
|
||||||
|
|
|
||||||
|
|
@ -10,6 +10,8 @@ from lolafect.connections import (
|
||||||
open_ssh_tunnel_with_s3_pkey,
|
open_ssh_tunnel_with_s3_pkey,
|
||||||
get_local_bind_address_from_ssh_tunnel,
|
get_local_bind_address_from_ssh_tunnel,
|
||||||
close_ssh_tunnel,
|
close_ssh_tunnel,
|
||||||
|
connect_to_mysql,
|
||||||
|
close_mysql_connection,
|
||||||
)
|
)
|
||||||
|
|
||||||
# __ __ _____ _ _ _____ _ _ _____ _
|
# __ __ _____ _ _ _____ _ _ _____ _
|
||||||
|
|
@ -115,3 +117,36 @@ def test_opening_and_closing_ssh_tunnel_works_properly():
|
||||||
and local_bind_host_matches
|
and local_bind_host_matches
|
||||||
and local_bind_port_matches
|
and local_bind_port_matches
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def test_connect_query_and_disconnect_from_mysql_with_tunnel():
|
||||||
|
|
||||||
|
test_local_bind_host = "127.0.0.1"
|
||||||
|
test_local_bind_port = 12345
|
||||||
|
|
||||||
|
tunnel = open_ssh_tunnel_with_s3_pkey.run(
|
||||||
|
s3_bucket_name=TEST_LOLACONFIG.S3_BUCKET_NAME,
|
||||||
|
ssh_tunnel_credentials=TEST_LOLACONFIG.SSH_TUNNEL_CREDENTIALS,
|
||||||
|
remote_target_host=TEST_LOLACONFIG.DW_CREDENTIALS["host"],
|
||||||
|
remote_target_port=TEST_LOLACONFIG.DW_CREDENTIALS["port"],
|
||||||
|
local_bind_host=test_local_bind_host,
|
||||||
|
local_bind_port=test_local_bind_port,
|
||||||
|
)
|
||||||
|
|
||||||
|
connection = connect_to_mysql.run(
|
||||||
|
mysql_credentials=TEST_LOLACONFIG.DW_CREDENTIALS,
|
||||||
|
overriding_host_and_port=get_local_bind_address_from_ssh_tunnel.run(
|
||||||
|
tunnel=tunnel
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
connection_was_open = connection.open
|
||||||
|
|
||||||
|
connection.cursor().execute("SELECT 1")
|
||||||
|
|
||||||
|
close_mysql_connection.run(connection=connection)
|
||||||
|
close_ssh_tunnel.run(tunnel=tunnel)
|
||||||
|
|
||||||
|
connection_is_closed = not connection.open
|
||||||
|
|
||||||
|
assert connection_was_open and connection_is_closed
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue