Add ssh opening functions and test.
This commit is contained in:
parent
dd07cd7959
commit
8c66a843b9
2 changed files with 106 additions and 1 deletions
|
|
@ -94,6 +94,91 @@ def _temp_secret_file_from_s3(
|
||||||
os.remove(local_temp_file_path)
|
os.remove(local_temp_file_path)
|
||||||
|
|
||||||
|
|
||||||
|
@task()
|
||||||
|
def open_ssh_tunnel_with_s3_pkey(
|
||||||
|
s3_bucket_name: str,
|
||||||
|
ssh_tunnel_credentials: dict,
|
||||||
|
remote_target_host: str,
|
||||||
|
remote_target_port: int,
|
||||||
|
) -> SSHTunnelForwarder:
|
||||||
|
"""
|
||||||
|
Temporarily fetch a ssh key from S3 and then proceed to open a SSH tunnel
|
||||||
|
using it.
|
||||||
|
|
||||||
|
:param s3_bucket_name: the bucket where the file lives.
|
||||||
|
:param ssh_tunnel_credentials: the details of the jumpthost SSH
|
||||||
|
connection.
|
||||||
|
:param remote_target_host: the remote host to tunnel to.
|
||||||
|
:param remote_target_port: the remote port to tunnel to.
|
||||||
|
:return: the tunnel, already open.
|
||||||
|
"""
|
||||||
|
logger = prefect.context.get("logger")
|
||||||
|
logger.info("Going to open an SSH tunnel.")
|
||||||
|
|
||||||
|
temp_file_path = "temp"
|
||||||
|
local_bind_host = "127.0.0.1"
|
||||||
|
local_bind_port = 12345
|
||||||
|
|
||||||
|
with _temp_secret_file_from_s3(
|
||||||
|
s3_bucket_name=s3_bucket_name,
|
||||||
|
s3_file_key=ssh_tunnel_credentials["path_to_ssh_pkey"],
|
||||||
|
local_temp_file_path=temp_file_path,
|
||||||
|
) as ssh_key_file:
|
||||||
|
tunnel = open_ssh_tunnel(
|
||||||
|
local_bind_host,
|
||||||
|
local_bind_port,
|
||||||
|
remote_target_host,
|
||||||
|
remote_target_port,
|
||||||
|
ssh_key_file,
|
||||||
|
ssh_tunnel_credentials,
|
||||||
|
)
|
||||||
|
logger.info(
|
||||||
|
f"SSH tunnel is now open and listening at {local_bind_host}:{local_bind_port}.\n"
|
||||||
|
f"Tunnel forwards to {remote_target_host}:{remote_target_port}"
|
||||||
|
)
|
||||||
|
|
||||||
|
return tunnel
|
||||||
|
|
||||||
|
|
||||||
|
def open_ssh_tunnel(
|
||||||
|
local_bind_host: str,
|
||||||
|
local_bind_port: int,
|
||||||
|
remote_target_host: str,
|
||||||
|
remote_target_port: int,
|
||||||
|
ssh_key_file_path: str,
|
||||||
|
ssh_tunnel_credentials: dict,
|
||||||
|
) -> SSHTunnelForwarder:
|
||||||
|
"""
|
||||||
|
Configure and start an SSH tunnel.
|
||||||
|
|
||||||
|
:param local_bind_host: the local host address to bind the tunnel to.
|
||||||
|
:param local_bind_port: the local port address to bind the tunnel to.
|
||||||
|
:param remote_target_host: the remote host to forward to.
|
||||||
|
:param remote_target_port: the remote port to forward to.
|
||||||
|
:param ssh_key_file_path: the path to the ssh key.
|
||||||
|
:param ssh_tunnel_credentials: the details of the jumpthost SSH
|
||||||
|
connection.
|
||||||
|
:return: the tunnel, already open.
|
||||||
|
"""
|
||||||
|
|
||||||
|
tunnel = SSHTunnelForwarder(
|
||||||
|
ssh_host=(
|
||||||
|
ssh_tunnel_credentials["ssh_jumphost"],
|
||||||
|
ssh_tunnel_credentials["ssh_port"],
|
||||||
|
),
|
||||||
|
ssh_username=ssh_tunnel_credentials["ssh_username"],
|
||||||
|
ssh_pkey=ssh_key_file_path,
|
||||||
|
remote_bind_address=(
|
||||||
|
remote_target_host,
|
||||||
|
remote_target_port,
|
||||||
|
),
|
||||||
|
local_bind_address=(local_bind_host, local_bind_port),
|
||||||
|
ssh_private_key_password=ssh_tunnel_credentials["ssh_pkey_password"],
|
||||||
|
)
|
||||||
|
tunnel.start()
|
||||||
|
return tunnel
|
||||||
|
|
||||||
|
|
||||||
@task(nout=2)
|
@task(nout=2)
|
||||||
def connect_to_mysql(
|
def connect_to_mysql(
|
||||||
mysql_credentials: dict,
|
mysql_credentials: dict,
|
||||||
|
|
|
||||||
|
|
@ -1,7 +1,12 @@
|
||||||
import pathlib
|
import pathlib
|
||||||
|
|
||||||
from lolafect.lolaconfig import build_lolaconfig
|
from lolafect.lolaconfig import build_lolaconfig
|
||||||
from lolafect.connections import connect_to_trino, close_trino_connection, _temp_secret_file_from_s3
|
from lolafect.connections import (
|
||||||
|
connect_to_trino,
|
||||||
|
close_trino_connection,
|
||||||
|
_temp_secret_file_from_s3,
|
||||||
|
open_ssh_tunnel_with_s3_pkey,
|
||||||
|
)
|
||||||
|
|
||||||
# __ __ _____ _ _ _____ _ _ _____ _
|
# __ __ _____ _ _ _____ _ _ _____ _
|
||||||
# \ \ / /\ | __ \| \ | |_ _| \ | |/ ____| |
|
# \ \ / /\ | __ \| \ | |_ _| \ | |/ ____| |
|
||||||
|
|
@ -72,3 +77,18 @@ def test_temporal_download_of_secret_file_works_properly_even_with_exception():
|
||||||
temp_file_found_when_in_context_manager
|
temp_file_found_when_in_context_manager
|
||||||
and temp_file_missing_when_outside_context_manager
|
and temp_file_missing_when_outside_context_manager
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def test_opening_and_closing_ssh_tunnel_works_properly():
|
||||||
|
tunnel = open_ssh_tunnel_with_s3_pkey.run(
|
||||||
|
s3_bucket_name=TEST_LOLACONFIG.S3_BUCKET_NAME,
|
||||||
|
ssh_tunnel_credentials=TEST_LOLACONFIG.SSH_TUNNEL_CREDENTIALS,
|
||||||
|
remote_target_host=TEST_LOLACONFIG.DW_CREDENTIALS["host"],
|
||||||
|
remote_target_port=TEST_LOLACONFIG.DW_CREDENTIALS["port"],
|
||||||
|
)
|
||||||
|
tunnel_was_active = tunnel.is_active
|
||||||
|
tunnel.close()
|
||||||
|
|
||||||
|
tunnel_is_no_longer_active = not tunnel.is_active
|
||||||
|
|
||||||
|
assert tunnel_was_active and tunnel_is_no_longer_active
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue