Splitting MySQL and SSH tunnel bits. SSH tunnel is now tested.
This commit is contained in:
parent
e2e7f8fb93
commit
3564ae99f7
2 changed files with 65 additions and 45 deletions
|
|
@ -100,6 +100,8 @@ def open_ssh_tunnel_with_s3_pkey(
|
||||||
ssh_tunnel_credentials: dict,
|
ssh_tunnel_credentials: dict,
|
||||||
remote_target_host: str,
|
remote_target_host: str,
|
||||||
remote_target_port: int,
|
remote_target_port: int,
|
||||||
|
local_bind_host: str = "127.0.0.1",
|
||||||
|
local_bind_port: int = 12345
|
||||||
) -> SSHTunnelForwarder:
|
) -> SSHTunnelForwarder:
|
||||||
"""
|
"""
|
||||||
Temporarily fetch a ssh key from S3 and then proceed to open a SSH tunnel
|
Temporarily fetch a ssh key from S3 and then proceed to open a SSH tunnel
|
||||||
|
|
@ -110,14 +112,14 @@ def open_ssh_tunnel_with_s3_pkey(
|
||||||
connection.
|
connection.
|
||||||
:param remote_target_host: the remote host to tunnel to.
|
:param remote_target_host: the remote host to tunnel to.
|
||||||
:param remote_target_port: the remote port to tunnel to.
|
:param remote_target_port: the remote port to tunnel to.
|
||||||
|
:param local_bind_host: the host for the local bind address.
|
||||||
|
:param local_bind_port: the port for the local bind address.
|
||||||
:return: the tunnel, already open.
|
:return: the tunnel, already open.
|
||||||
"""
|
"""
|
||||||
logger = prefect.context.get("logger")
|
logger = prefect.context.get("logger")
|
||||||
logger.info("Going to open an SSH tunnel.")
|
logger.info("Going to open an SSH tunnel.")
|
||||||
|
|
||||||
temp_file_path = "temp"
|
temp_file_path = "temp"
|
||||||
local_bind_host = "127.0.0.1"
|
|
||||||
local_bind_port = 12345
|
|
||||||
|
|
||||||
with _temp_secret_file_from_s3(
|
with _temp_secret_file_from_s3(
|
||||||
s3_bucket_name=s3_bucket_name,
|
s3_bucket_name=s3_bucket_name,
|
||||||
|
|
@ -179,59 +181,54 @@ def open_ssh_tunnel(
|
||||||
return tunnel
|
return tunnel
|
||||||
|
|
||||||
|
|
||||||
@task(nout=2)
|
@task()
|
||||||
|
def get_local_bind_address_from_ssh_tunnel(
|
||||||
|
tunnel: SSHTunnelForwarder,
|
||||||
|
) -> Tuple[str, int]:
|
||||||
|
"""
|
||||||
|
A silly wrapper to be able to unpack the local bind address of a tunnel
|
||||||
|
within a Prefect flow.
|
||||||
|
|
||||||
|
:param tunnel: an SSH tunnel.
|
||||||
|
:return: the local bind address of the SSH tunnel, as a tuple with host
|
||||||
|
and port.
|
||||||
|
"""
|
||||||
|
return tunnel.local_bind_address
|
||||||
|
|
||||||
|
|
||||||
|
@task()
|
||||||
def connect_to_mysql(
|
def connect_to_mysql(
|
||||||
mysql_credentials: dict,
|
mysql_credentials: dict, overriding_host_and_port: Tuple[str, int] = None
|
||||||
use_ssh_tunnel: bool,
|
) -> pymysql.Connection:
|
||||||
s3_bucket_name,
|
"""
|
||||||
ssh_tunnel_credentials,
|
Create a connection to a MySQL server, optionally using a host and port
|
||||||
) -> Tuple[pymysql.Connection, SSHTunnelForwarder]:
|
different from the ones where the MySQL server is located.
|
||||||
|
|
||||||
|
:param mysql_credentials: a dict with the connection details to the MySQL
|
||||||
|
instance.
|
||||||
|
:param overriding_host_and_port: an optional tuple containing a different
|
||||||
|
host and port. Useful to route through an SSH tunnel.
|
||||||
|
:return: the connection to the MySQL server.
|
||||||
|
"""
|
||||||
logger = prefect.context.get("logger")
|
logger = prefect.context.get("logger")
|
||||||
logger.info(
|
logger.info(
|
||||||
f"Connecting to MySQL at {mysql_credentials['host']}:{mysql_credentials['port']}."
|
f"Connecting to MySQL at {mysql_credentials['host']}:{mysql_credentials['port']}."
|
||||||
)
|
)
|
||||||
|
|
||||||
mysql_host = mysql_credentials["host"]
|
mysql_host = mysql_credentials["host"]
|
||||||
tunnel = None
|
mysql_port = mysql_credentials["port"]
|
||||||
if use_ssh_tunnel:
|
|
||||||
logger.info("Going to open an SSH tunnel.")
|
|
||||||
|
|
||||||
temp_file_path = "temp"
|
if overriding_host_and_port:
|
||||||
try:
|
# Since there is a tunnel, we actually want to connect to the local
|
||||||
boto3.client("s3").download_file(
|
# address of the tunnel, and not straight into the MySQL server.
|
||||||
s3_bucket_name,
|
mysql_host, mysql_port = overriding_host_and_port
|
||||||
ssh_tunnel_credentials["path_to_ssh_pkey"],
|
|
||||||
temp_file_path,
|
|
||||||
)
|
|
||||||
tunnel = SSHTunnelForwarder(
|
|
||||||
ssh_host=(
|
|
||||||
ssh_tunnel_credentials["ssh_jumphost"],
|
|
||||||
ssh_tunnel_credentials["ssh_port"],
|
|
||||||
),
|
|
||||||
ssh_username=ssh_tunnel_credentials["ssh_username"],
|
|
||||||
ssh_pkey=temp_file_path,
|
|
||||||
remote_bind_address=(
|
|
||||||
mysql_credentials["host"],
|
|
||||||
mysql_credentials["port"],
|
|
||||||
),
|
|
||||||
local_bind_address=("127.0.0.1", mysql_credentials["port"]),
|
|
||||||
ssh_private_key_password=ssh_tunnel_credentials["ssh_pkey_password"],
|
|
||||||
)
|
|
||||||
except Exception as e:
|
|
||||||
raise e
|
|
||||||
finally:
|
|
||||||
# No matter what happens above, we always must delete the temp copy of the key
|
|
||||||
os.remove(temp_file_path)
|
|
||||||
|
|
||||||
tunnel.start()
|
|
||||||
logger.info(
|
logger.info(
|
||||||
f"SSH tunnel is now open and listening at{mysql_credentials['host']}:{mysql_credentials['port']}."
|
f"Overriding the passed MySQL host and port with {mysql_host}:{mysql_port}."
|
||||||
)
|
)
|
||||||
mysql_host = "127.0.0.1"
|
|
||||||
|
|
||||||
db_connection = pymysql.connect(
|
db_connection = pymysql.connect(
|
||||||
host=mysql_host,
|
host=mysql_host,
|
||||||
port=mysql_credentials["port"],
|
port=mysql_port,
|
||||||
user=mysql_credentials["user"],
|
user=mysql_credentials["user"],
|
||||||
password=mysql_credentials["password"],
|
password=mysql_credentials["password"],
|
||||||
database="dw_xl",
|
database="dw_xl",
|
||||||
|
|
@ -245,7 +242,7 @@ def connect_to_mysql(
|
||||||
f"Connected to MySQL at {mysql_credentials['host']}:{mysql_credentials['port']}."
|
f"Connected to MySQL at {mysql_credentials['host']}:{mysql_credentials['port']}."
|
||||||
)
|
)
|
||||||
|
|
||||||
return db_connection, tunnel
|
return db_connection
|
||||||
|
|
||||||
|
|
||||||
@task(trigger=all_finished)
|
@task(trigger=all_finished)
|
||||||
|
|
|
||||||
|
|
@ -1,11 +1,15 @@
|
||||||
import pathlib
|
import pathlib
|
||||||
|
|
||||||
|
from prefect.tasks.core.function import FunctionTask
|
||||||
|
|
||||||
from lolafect.lolaconfig import build_lolaconfig
|
from lolafect.lolaconfig import build_lolaconfig
|
||||||
from lolafect.connections import (
|
from lolafect.connections import (
|
||||||
connect_to_trino,
|
connect_to_trino,
|
||||||
close_trino_connection,
|
close_trino_connection,
|
||||||
_temp_secret_file_from_s3,
|
_temp_secret_file_from_s3,
|
||||||
open_ssh_tunnel_with_s3_pkey,
|
open_ssh_tunnel_with_s3_pkey,
|
||||||
|
get_local_bind_address_from_ssh_tunnel,
|
||||||
|
close_ssh_tunnel,
|
||||||
)
|
)
|
||||||
|
|
||||||
# __ __ _____ _ _ _____ _ _ _____ _
|
# __ __ _____ _ _ _____ _ _ _____ _
|
||||||
|
|
@ -80,15 +84,34 @@ def test_temporal_download_of_secret_file_works_properly_even_with_exception():
|
||||||
|
|
||||||
|
|
||||||
def test_opening_and_closing_ssh_tunnel_works_properly():
|
def test_opening_and_closing_ssh_tunnel_works_properly():
|
||||||
|
|
||||||
|
test_local_bind_host = "127.0.0.1"
|
||||||
|
test_local_bind_port = 12345
|
||||||
|
|
||||||
tunnel = open_ssh_tunnel_with_s3_pkey.run(
|
tunnel = open_ssh_tunnel_with_s3_pkey.run(
|
||||||
s3_bucket_name=TEST_LOLACONFIG.S3_BUCKET_NAME,
|
s3_bucket_name=TEST_LOLACONFIG.S3_BUCKET_NAME,
|
||||||
ssh_tunnel_credentials=TEST_LOLACONFIG.SSH_TUNNEL_CREDENTIALS,
|
ssh_tunnel_credentials=TEST_LOLACONFIG.SSH_TUNNEL_CREDENTIALS,
|
||||||
remote_target_host=TEST_LOLACONFIG.DW_CREDENTIALS["host"],
|
remote_target_host=TEST_LOLACONFIG.DW_CREDENTIALS["host"],
|
||||||
remote_target_port=TEST_LOLACONFIG.DW_CREDENTIALS["port"],
|
remote_target_port=TEST_LOLACONFIG.DW_CREDENTIALS["port"],
|
||||||
|
local_bind_host=test_local_bind_host,
|
||||||
|
local_bind_port=test_local_bind_port,
|
||||||
)
|
)
|
||||||
tunnel_was_active = tunnel.is_active
|
tunnel_was_active = tunnel.is_active
|
||||||
tunnel.close()
|
|
||||||
|
local_bind_host_matches = (
|
||||||
|
get_local_bind_address_from_ssh_tunnel.run(tunnel)[0] == test_local_bind_host
|
||||||
|
)
|
||||||
|
local_bind_port_matches = (
|
||||||
|
get_local_bind_address_from_ssh_tunnel.run(tunnel)[1] == test_local_bind_port
|
||||||
|
)
|
||||||
|
|
||||||
|
close_ssh_tunnel.run(tunnel)
|
||||||
|
|
||||||
tunnel_is_no_longer_active = not tunnel.is_active
|
tunnel_is_no_longer_active = not tunnel.is_active
|
||||||
|
|
||||||
assert tunnel_was_active and tunnel_is_no_longer_active
|
assert (
|
||||||
|
tunnel_was_active
|
||||||
|
and tunnel_is_no_longer_active
|
||||||
|
and local_bind_host_matches
|
||||||
|
and local_bind_port_matches
|
||||||
|
)
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue