288 lines
9.1 KiB
Python
288 lines
9.1 KiB
Python
import datetime
|
|
import os
|
|
from typing import Tuple
|
|
from contextlib import contextmanager
|
|
|
|
import prefect
|
|
from prefect import task
|
|
from prefect.triggers import all_finished
|
|
from trino.auth import BasicAuthentication
|
|
import trino
|
|
import pymysql
|
|
import boto3
|
|
from sshtunnel import SSHTunnelForwarder
|
|
|
|
from lolafect.defaults import DEFAULT_TRINO_HTTP_SCHEME
|
|
|
|
|
|
@task(log_stdout=True, max_retries=3, retry_delay=datetime.timedelta(minutes=10))
def connect_to_trino(
    trino_credentials: dict, http_schema: str = DEFAULT_TRINO_HTTP_SCHEME
) -> trino.dbapi.Connection:
    """
    Connect to a Trino instance with basic authentication and return the
    resulting connection.

    :param trino_credentials: a dict with the "host", "port", "user" and
     "password" entries of the Trino instance.
    :param http_schema: the http scheme handed over to the Trino client.
    :return: the connection to Trino.
    """
    logger = prefect.context.get("logger")
    logger.info(
        f"Connecting to Trino at {trino_credentials['host']}:{trino_credentials['port']}."
    )

    # The same user is used both as the session user and for authentication.
    basic_auth = BasicAuthentication(
        trino_credentials["user"],
        trino_credentials["password"],
    )
    trino_connection = trino.dbapi.connect(
        host=trino_credentials["host"],
        port=trino_credentials["port"],
        user=trino_credentials["user"],
        http_scheme=http_schema,
        auth=basic_auth,
    )

    logger.info(
        f"Connected to Trino at {trino_credentials['host']}:{trino_credentials['port']}."
    )
    return trino_connection
|
|
|
|
|
|
@task(trigger=all_finished)
def close_trino_connection(trino_connection: trino.dbapi.Connection) -> None:
    """
    Close the received Trino connection.

    If the received object is not a Trino connection, a warning is logged
    and a DeprecationWarning is raised instead of closing anything.

    :param trino_connection: a trino connection.
    :return: None
    :raises DeprecationWarning: if the received object is not a
     trino.dbapi.Connection.
    """
    logger = prefect.context.get("logger")

    # Guard clause: anything that is not a real connection is rejected.
    if not isinstance(trino_connection, trino.dbapi.Connection):
        logger.warning(
            f"Instead of a Trino connection, a {type(trino_connection)} was received."
        )
        raise DeprecationWarning(
            "This method will only accept the type 'trino.dbapi.Connection' in next major release.\n"
            "Please, update your code accordingly."
        )

    trino_connection.close()
    logger.info("Trino connection closed successfully.")
|
|
|
|
|
|
@contextmanager
def _temp_secret_file_from_s3(
    s3_bucket_name: str, s3_file_key: str, local_temp_file_path: str
) -> str:
    """
    Context manager that downloads a file from S3 and removes the local copy
    when the context is exited, whether it exits normally or through an
    exception.

    :param s3_bucket_name: the bucket where the file lives.
    :param s3_file_key: the key of the file within the bucket.
    :param local_temp_file_path: the path where the file should be stored
     temporarily.
    :return: the local file path (yielded to the `with` block).
    """
    s3_client = boto3.client("s3")
    s3_client.download_file(s3_bucket_name, s3_file_key, local_temp_file_path)

    try:
        yield local_temp_file_path
    finally:
        # The temp copy of the secret must never outlive the context, no
        # matter how the context body ended. Exceptions raised in the body
        # keep propagating after the cleanup.
        os.remove(local_temp_file_path)
|
|
|
|
|
|
@task()
def open_ssh_tunnel_with_s3_pkey(
    s3_bucket_name: str,
    ssh_tunnel_credentials: dict,
    remote_target_host: str,
    remote_target_port: int,
    local_bind_host: str = "127.0.0.1",
    local_bind_port: int = 12345,
) -> SSHTunnelForwarder:
    """
    Temporarily fetch a ssh key from S3 and then proceed to open a SSH tunnel
    using it.

    The key is downloaded into a freshly created temporary directory that is
    unique to this call, and both the key file and the directory are removed
    once the tunnel has been started (or if anything fails on the way).

    :param s3_bucket_name: the bucket where the file lives.
    :param ssh_tunnel_credentials: the details of the jumphost SSH
     connection. The "path_to_ssh_pkey" entry is used as the S3 key of the
     private key file.
    :param remote_target_host: the remote host to tunnel to.
    :param remote_target_port: the remote port to tunnel to.
    :param local_bind_host: the host for the local bind address.
    :param local_bind_port: the port for the local bind address.
    :return: the tunnel, already open.
    """
    # Local import: this is the only function of the module that needs it.
    import tempfile

    logger = prefect.context.get("logger")
    logger.info("Going to open an SSH tunnel.")

    # A per-call temporary directory avoids writing the private key to a
    # fixed, predictable path in the working directory, where concurrent runs
    # would clash and an unrelated file with the same name would be
    # overwritten and then deleted.
    with tempfile.TemporaryDirectory() as temp_dir:
        temp_file_path = os.path.join(temp_dir, "ssh_pkey")

        with _temp_secret_file_from_s3(
            s3_bucket_name=s3_bucket_name,
            s3_file_key=ssh_tunnel_credentials["path_to_ssh_pkey"],
            local_temp_file_path=temp_file_path,
        ) as ssh_key_file:
            tunnel = open_ssh_tunnel(
                local_bind_host,
                local_bind_port,
                remote_target_host,
                remote_target_port,
                ssh_key_file,
                ssh_tunnel_credentials,
            )

    logger.info(
        f"SSH tunnel is now open and listening at {local_bind_host}:{local_bind_port}.\n"
        f"Tunnel forwards to {remote_target_host}:{remote_target_port}"
    )

    return tunnel
|
|
|
|
|
|
def open_ssh_tunnel(
    local_bind_host: str,
    local_bind_port: int,
    remote_target_host: str,
    remote_target_port: int,
    ssh_key_file_path: str,
    ssh_tunnel_credentials: dict,
) -> SSHTunnelForwarder:
    """
    Build an SSH tunnel through the jumphost and start it.

    :param local_bind_host: the local host address to bind the tunnel to.
    :param local_bind_port: the local port address to bind the tunnel to.
    :param remote_target_host: the remote host to forward to.
    :param remote_target_port: the remote port to forward to.
    :param ssh_key_file_path: the path to the ssh key.
    :param ssh_tunnel_credentials: the details of the jumphost SSH
     connection. The entries "ssh_jumphost", "ssh_port", "ssh_username" and
     "ssh_pkey_password" are read from it.
    :return: the tunnel, already open.
    """
    jumphost_address = (
        ssh_tunnel_credentials["ssh_jumphost"],
        ssh_tunnel_credentials["ssh_port"],
    )
    remote_address = (remote_target_host, remote_target_port)
    local_address = (local_bind_host, local_bind_port)

    ssh_tunnel = SSHTunnelForwarder(
        ssh_address_or_host=jumphost_address,
        ssh_username=ssh_tunnel_credentials["ssh_username"],
        ssh_pkey=ssh_key_file_path,
        remote_bind_address=remote_address,
        local_bind_address=local_address,
        ssh_private_key_password=ssh_tunnel_credentials["ssh_pkey_password"],
    )
    ssh_tunnel.start()
    return ssh_tunnel
|
|
|
|
|
|
@task(trigger=all_finished)
def close_ssh_tunnel(tunnel: SSHTunnelForwarder) -> None:
    """
    Stop the received SSH tunnel.

    If the received object is not a SSH tunnel, a warning is logged and a
    DeprecationWarning is raised instead of stopping anything.

    :param tunnel: a SSH tunnel.
    :return: None
    :raises DeprecationWarning: if the received object is not a
     SSHTunnelForwarder.
    """
    logger = prefect.context.get("logger")

    # Guard clause: anything that is not a real tunnel is rejected.
    if not isinstance(tunnel, SSHTunnelForwarder):
        logger.warning(f"Instead of a SSH tunnel, a {type(tunnel)} was received.")
        raise DeprecationWarning(
            "This method will only accept the type 'SSHTunnelForwarder' in next major release.\n"
            "Please, update your code accordingly."
        )

    tunnel.stop()
    logger.info("SSH tunnel closed successfully.")
|
|
|
|
|
|
@task()
def get_local_bind_address_from_ssh_tunnel(
    tunnel: SSHTunnelForwarder,
) -> Tuple[str, int]:
    """
    Expose the local bind address of a tunnel as a task result, so that it
    can be consumed by other tasks within a Prefect flow.

    :param tunnel: an SSH tunnel.
    :return: the local bind address of the SSH tunnel, as a tuple with host
     and port.
    """
    local_address = tunnel.local_bind_address
    return local_address
|
|
|
|
|
|
@task()
def connect_to_mysql(
    mysql_credentials: dict, overriding_host_and_port: Tuple[str, int] = None
) -> pymysql.Connection:
    """
    Open a connection to a MySQL server. The host and port found in the
    credentials can optionally be replaced by a different pair.

    :param mysql_credentials: a dict with the connection details to the MySQL
     instance ("host", "port", "user" and "password").
    :param overriding_host_and_port: an optional tuple containing a different
     host and port. Useful to route through an SSH tunnel.
    :return: the connection to the MySQL server.
    """
    logger = prefect.context.get("logger")
    logger.info(
        f"Connecting to MySQL at {mysql_credentials['host']}:{mysql_credentials['port']}."
    )

    # Default to the address of the MySQL server itself.
    mysql_host, mysql_port = mysql_credentials["host"], mysql_credentials["port"]

    if overriding_host_and_port:
        # When a tunnel is in place, the connection must target the local
        # end of the tunnel rather than the MySQL server directly.
        mysql_host, mysql_port = overriding_host_and_port
        logger.info(
            f"Overriding the passed MySQL host and port with {mysql_host}:{mysql_port}."
        )

    connection = pymysql.connect(
        host=mysql_host,
        port=mysql_port,
        user=mysql_credentials["user"],
        password=mysql_credentials["password"],
    )

    # Note: this message reports the host and port from the credentials, even
    # when they have been overridden.
    logger.info(
        f"Connected to MySQL at {mysql_credentials['host']}:{mysql_credentials['port']}."
    )
    return connection
|
|
|
|
|
|
@task(trigger=all_finished)
def close_mysql_connection(connection: pymysql.Connection) -> None:
    """
    Close the received MySQL connection.

    If the received object is not a MySQL connection, a warning is logged
    and a DeprecationWarning is raised instead of closing anything.

    :param connection: a MySQL connection.
    :return: None
    :raises DeprecationWarning: if the received object is not a
     pymysql.Connection.
    """
    logger = prefect.context.get("logger")

    # Guard clause: anything that is not a real connection is rejected.
    if not isinstance(connection, pymysql.Connection):
        logger.warning(
            f"Instead of a MySQL connection, a {type(connection)} was received."
        )
        raise DeprecationWarning(
            "This method will only accept the type 'pymysql.Connection' in next major release.\n"
            "Please, update your code accordingly."
        )

    connection.close()
    logger.info("MySQL connection closed successfully.")
|