diff --git a/CHANGELOG.md b/CHANGELOG.md index 32fb418..f44d7a9 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -6,7 +6,9 @@ All notable changes to this project will be documented in this file. ### Added -- Added Trino connection capabilities in `connections` module. +- Added Trino connection capabilities in the `connections` module. +- Added MySQL connection capabilities in the `connections` module. +- Added SSH tunneling capabilities in the `connections` module. ## [0.2.0] - 2023-01-19 diff --git a/README.md b/README.md index 9160eb3..5d4b6f4 100644 --- a/README.md +++ b/README.md @@ -6,6 +6,12 @@ Lolafect is a collection of Python bits that help us build our Prefect flows. You can find below examples of how to leverage `lolafect` in your flows. +**_Note: the code excerpts below are simplified for brevity and won't run +as-is. If you want to see perfect examples, you might want to check the tests +in this repository._** + +### Config + **Let the `LolaConfig` object do the boilerplate env stuff for you** ```python @@ -36,6 +42,8 @@ lolaconfig = build_lolaconfig( ) ``` +### Connections + **Connect to a Trino server** ```python @@ -43,7 +51,7 @@ from lolafect.connections import connect_to_trino, close_trino_connection with Flow(...) as flow: connection = connect_to_trino.run( - trino_credentials=my_trino_credentials #You can probably try to fetch this from lolaconfig.TRINO_CREDENTIALS + trino_credentials=my_trino_credentials # You can probably try to fetch this from lolaconfig.TRINO_CREDENTIALS ) task_result = some_trino_related_task(trino_connection=connection) close_trino_connection.run( @@ -52,6 +60,66 @@ with Flow(...) as flow: ) ``` +**Open an SSH tunnel** + +```python +from lolafect.connections import open_ssh_tunnel_with_s3_pkey, close_ssh_tunnel + +with Flow(...) 
as flow: + # You probably want to fetch these args from lolaconfig.SSH_CREDENTIALS and lolaconfig.DW_CREDENTIALS + tunnel = open_ssh_tunnel_with_s3_pkey( + s3_bucket_name="some-bucket", + ssh_tunnel_credentials={...}, + remote_target_host="some-host-probably-mysql", + remote_target_port=12345, + ) + # Tunnel is now alive. tunnel.is_active == True + close_ssh_tunnel(tunnel=tunnel) +``` + +**Connect to a MySQL instance** +```python +from lolafect.connections import connect_to_mysql, close_mysql_connection + +with Flow(...) as flow: + connection = connect_to_mysql.run( + mysql_credentials={...}, # You probably want to get this from TEST_LOLACONFIG.DW_CREDENTIALS + ) + connection.cursor().execute("SELECT 1") + close_mysql_connection.run(connection=connection) + +# Want to connect through an SSH tunnel? Open the tunnel normally and then +# override the host and port when connecting to MySQL. + +from lolafect.connections import ( + open_ssh_tunnel_with_s3_pkey, + get_local_bind_address_from_ssh_tunnel, + close_ssh_tunnel +) + +with Flow(...) 
as flow: + # You probably want to fetch these args from lolaconfig.SSH_CREDENTIALS and lolaconfig.DW_CREDENTIALS + tunnel = open_ssh_tunnel_with_s3_pkey( + s3_bucket_name="some-bucket", + ssh_tunnel_credentials={...}, + remote_target_host="the-mysql-host", + remote_target_port=3306, + ) + + connection = connect_to_mysql.run( + mysql_credentials={...}, # You probably want to get this from TEST_LOLACONFIG.DW_CREDENTIALS + overriding_host_and_port=get_local_bind_address_from_ssh_tunnel.run( + tunnel=tunnel # This will open the connection through the SSH tunnel instead of straight to MySQL + ), + ) + + connection.cursor().execute("SELECT 1") + + close_mysql_connection.run(connection=connection) + close_ssh_tunnel.run(tunnel=tunnel) +``` + +### Slack **Send a warning message to slack if your tasks fails** diff --git a/lolafect/__version__.py b/lolafect/__version__.py index 0590644..7e876b1 100644 --- a/lolafect/__version__.py +++ b/lolafect/__version__.py @@ -1 +1 @@ -__version__="0.2.0" \ No newline at end of file +__version__="dev" \ No newline at end of file diff --git a/lolafect/connections.py b/lolafect/connections.py index 40c450f..223c9f4 100644 --- a/lolafect/connections.py +++ b/lolafect/connections.py @@ -1,10 +1,16 @@ import datetime +import os +from typing import Tuple +from contextlib import contextmanager import prefect from prefect import task from prefect.triggers import all_finished from trino.auth import BasicAuthentication import trino +import pymysql +import boto3 +from sshtunnel import SSHTunnelForwarder from lolafect.defaults import DEFAULT_TRINO_HTTP_SCHEME @@ -18,7 +24,7 @@ def connect_to_trino( :param trino_credentials: a dict with the host, port, user and password. :param http_schema: which http schema to use in the connection. - :return: + :return: the connection to trino. 
""" logger = prefect.context.get("logger") logger.info( @@ -56,4 +62,232 @@ def close_trino_connection(trino_connection: trino.dbapi.Connection) -> None: trino_connection.close() logger.info("Trino connection closed successfully.") return - logger.info("No connection received.") + logger.warning( + f"Instead of a Trino connection, a {type(trino_connection)} was received." + ) + raise DeprecationWarning( + "This method will only accept the type 'trino.dbapi.Connection' in next major release.\n" + "Please, update your code accordingly." + ) + + +@contextmanager +def _temp_secret_file_from_s3( + s3_bucket_name: str, s3_file_key: str, local_temp_file_path: str +) -> str: + """ + Downloads a file from S3 and ensures that it will be deleted once the + context is exited from, even in the face of an exception. + + :param s3_bucket_name: the bucket where the file lives. + :param s3_file_key: the key of the file within the bucket. + :param local_temp_file_path: the path where the file should be stored + temporarily. + :return: the local file path. + """ + boto3.client("s3").download_file( + s3_bucket_name, + s3_file_key, + local_temp_file_path, + ) + try: + yield local_temp_file_path + except Exception as e: + raise e + finally: + # Regardless of what happens in the context manager, we always delete the temp + # copy of the secret file. + os.remove(local_temp_file_path) + + +@task() +def open_ssh_tunnel_with_s3_pkey( + s3_bucket_name: str, + ssh_tunnel_credentials: dict, + remote_target_host: str, + remote_target_port: int, + local_bind_host: str = "127.0.0.1", + local_bind_port: int = 12345, +) -> SSHTunnelForwarder: + """ + Temporarily fetch a ssh key from S3 and then proceed to open a SSH tunnel + using it. + + :param s3_bucket_name: the bucket where the file lives. + :param ssh_tunnel_credentials: the details of the jumpthost SSH + connection. + :param remote_target_host: the remote host to tunnel to. + :param remote_target_port: the remote port to tunnel to. 
@task()
def open_ssh_tunnel_with_s3_pkey(
    s3_bucket_name: str,
    ssh_tunnel_credentials: dict,
    remote_target_host: str,
    remote_target_port: int,
    local_bind_host: str = "127.0.0.1",
    local_bind_port: int = 12345,
) -> SSHTunnelForwarder:
    """
    Temporarily fetch an SSH private key from S3 and use it to open an SSH
    tunnel. The local copy of the key is deleted before this task returns.

    :param s3_bucket_name: the bucket where the key file lives.
    :param ssh_tunnel_credentials: the details of the jump host SSH
     connection. The key "path_to_ssh_pkey" is read here as the S3 key of the
     private key file; the remaining keys are read by `open_ssh_tunnel`.
    :param remote_target_host: the remote host to tunnel to.
    :param remote_target_port: the remote port to tunnel to.
    :param local_bind_host: the host for the local bind address.
    :param local_bind_port: the port for the local bind address.
    :return: the tunnel, already open.
    """
    # Imported locally because only this task needs it.
    import tempfile

    logger = prefect.context.get("logger")
    logger.info("Going to open an SSH tunnel.")

    # The key is stored in a fresh, per-call temporary directory instead of a
    # fixed file name in the working directory. That way concurrent runs
    # cannot collide and no unrelated file can be overwritten or deleted.
    with tempfile.TemporaryDirectory() as temp_dir:
        temp_file_path = os.path.join(temp_dir, "ssh_pkey")

        with _temp_secret_file_from_s3(
            s3_bucket_name=s3_bucket_name,
            s3_file_key=ssh_tunnel_credentials["path_to_ssh_pkey"],
            local_temp_file_path=temp_file_path,
        ) as ssh_key_file:
            tunnel = open_ssh_tunnel(
                local_bind_host,
                local_bind_port,
                remote_target_host,
                remote_target_port,
                ssh_key_file,
                ssh_tunnel_credentials,
            )
            logger.info(
                f"SSH tunnel is now open and listening at {local_bind_host}:{local_bind_port}.\n"
                f"Tunnel forwards to {remote_target_host}:{remote_target_port}"
            )

    return tunnel


def open_ssh_tunnel(
    local_bind_host: str,
    local_bind_port: int,
    remote_target_host: str,
    remote_target_port: int,
    ssh_key_file_path: str,
    ssh_tunnel_credentials: dict,
) -> SSHTunnelForwarder:
    """
    Configure and start an SSH tunnel.

    :param local_bind_host: the local host address to bind the tunnel to.
    :param local_bind_port: the local port to bind the tunnel to.
    :param remote_target_host: the remote host to forward to.
    :param remote_target_port: the remote port to forward to.
    :param ssh_key_file_path: the path to the ssh key.
    :param ssh_tunnel_credentials: the details of the jump host SSH
     connection. Reads the keys "ssh_jumphost", "ssh_port", "ssh_username"
     and, optionally, "ssh_pkey_password".
    :return: the tunnel, already open.
    """

    tunnel = SSHTunnelForwarder(
        ssh_address_or_host=(
            ssh_tunnel_credentials["ssh_jumphost"],
            ssh_tunnel_credentials["ssh_port"],
        ),
        ssh_username=ssh_tunnel_credentials["ssh_username"],
        ssh_pkey=ssh_key_file_path,
        remote_bind_address=(
            remote_target_host,
            remote_target_port,
        ),
        local_bind_address=(local_bind_host, local_bind_port),
        # A key without a passphrase has no password entry; pass None in that
        # case instead of failing with a KeyError.
        ssh_private_key_password=ssh_tunnel_credentials.get("ssh_pkey_password"),
    )
    tunnel.start()
    return tunnel


@task(trigger=all_finished)
def close_ssh_tunnel(tunnel: SSHTunnelForwarder) -> None:
    """
    Close an SSH tunnel.

    If something that is not an SSHTunnelForwarder is received, a warning is
    logged and a DeprecationWarning is raised.

    :param tunnel: an SSH tunnel.
    :return: None
    :raises DeprecationWarning: if `tunnel` is not an SSHTunnelForwarder.
    """
    logger = prefect.context.get("logger")

    if isinstance(tunnel, SSHTunnelForwarder):
        tunnel.stop()
        logger.info("SSH tunnel closed successfully.")
        return
    logger.warning(f"Instead of a SSH tunnel, a {type(tunnel)} was received.")
    raise DeprecationWarning(
        "This method will only accept the type 'SSHTunnelForwarder' in next major release.\n"
        "Please, update your code accordingly."
    )


@task()
def get_local_bind_address_from_ssh_tunnel(
    tunnel: SSHTunnelForwarder,
) -> Tuple[str, int]:
    """
    A thin wrapper to be able to unpack the local bind address of a tunnel
    within a Prefect flow.

    :param tunnel: an SSH tunnel.
    :return: the local bind address of the SSH tunnel, as a tuple with host
     and port.
    """
    return tunnel.local_bind_address


@task()
def connect_to_mysql(
    mysql_credentials: dict, overriding_host_and_port: Tuple[str, int] = None
) -> pymysql.Connection:
    """
    Create a connection to a MySQL server, optionally using a host and port
    different from the ones where the MySQL server is located.

    :param mysql_credentials: a dict with the connection details to the MySQL
     instance. Reads the keys "host", "port", "user" and "password".
    :param overriding_host_and_port: an optional tuple containing a different
     host and port. Useful to route through an SSH tunnel.
    :return: the connection to the MySQL server.
    """
    logger = prefect.context.get("logger")
    logger.info(
        f"Connecting to MySQL at {mysql_credentials['host']}:{mysql_credentials['port']}."
    )

    mysql_host = mysql_credentials["host"]
    mysql_port = mysql_credentials["port"]

    if overriding_host_and_port:
        # Since there is a tunnel, we actually want to connect to the local
        # address of the tunnel, and not straight into the MySQL server.
        mysql_host, mysql_port = overriding_host_and_port
        logger.info(
            f"Overriding the passed MySQL host and port with {mysql_host}:{mysql_port}."
        )

    db_connection = pymysql.connect(
        host=mysql_host,
        port=mysql_port,
        user=mysql_credentials["user"],
        password=mysql_credentials["password"],
        # NOTE(review): the database name is hard-coded; confirm whether it
        # should come from the credentials instead.
        database="dw_xl",
    )

    # Customizing these attributes to retrieve them later in the flow.
    db_connection.raw_user = mysql_credentials["user"]
    db_connection.raw_password = mysql_credentials["password"]

    # Report the address that was actually used, which differs from the one in
    # the credentials when the connection goes through an override.
    logger.info(f"Connected to MySQL at {mysql_host}:{mysql_port}.")

    return db_connection


@task(trigger=all_finished)
def close_mysql_connection(connection: pymysql.Connection) -> None:
    """
    Close a MySQL connection.

    If something that is not a pymysql.Connection is received, a warning is
    logged and a DeprecationWarning is raised.

    :param connection: a MySQL connection.
    :return: None
    :raises DeprecationWarning: if `connection` is not a pymysql.Connection.
    """
    logger = prefect.context.get("logger")

    if isinstance(connection, pymysql.Connection):
        connection.close()
        logger.info("MySQL connection closed successfully.")
        return
    logger.warning(f"Instead of a MySQL connection, a {type(connection)} was received.")
    raise DeprecationWarning(
        "This method will only accept the type 'pymysql.Connection' in next major release.\n"
        "Please, update your code accordingly."
    )
+ ) diff --git a/requirements-dev.txt b/requirements-dev.txt index f92f5bc..2adb75e 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -3,4 +3,6 @@ requests==2.28.1 boto3==1.26.40 pytest==7.2.0 httpretty==1.1.4 -trino==0.321.0 \ No newline at end of file +trino==0.321.0 +sshtunnel==0.4.0 +PyMySQL==1.0.2 \ No newline at end of file diff --git a/setup.py b/setup.py index ec38921..3ac4e7f 100644 --- a/setup.py +++ b/setup.py @@ -23,5 +23,12 @@ setup( package_dir={"lolafect": "lolafect"}, include_package_data=True, python_requires=">=3.7", - install_requires=["prefect==1.2.2", "requests==2.28.1", "boto3==1.26.40", "trino==0.321.0"], + install_requires=[ + "prefect==1.2.2", + "requests==2.28.1", + "boto3==1.26.40", + "trino==0.321.0", + "sshtunnel==0.4.0", + "PyMySQL==1.0.2" + ], ) diff --git a/tests/test_integration/test_connections.py b/tests/test_integration/test_connections.py index 74944e0..2d03039 100644 --- a/tests/test_integration/test_connections.py +++ b/tests/test_integration/test_connections.py @@ -1,5 +1,18 @@ +import pathlib + +from prefect.tasks.core.function import FunctionTask + from lolafect.lolaconfig import build_lolaconfig -from lolafect.connections import connect_to_trino, close_trino_connection +from lolafect.connections import ( + connect_to_trino, + close_trino_connection, + _temp_secret_file_from_s3, + open_ssh_tunnel_with_s3_pkey, + get_local_bind_address_from_ssh_tunnel, + close_ssh_tunnel, + connect_to_mysql, + close_mysql_connection, +) # __ __ _____ _ _ _____ _ _ _____ _ # \ \ / /\ | __ \| \ | |_ _| \ | |/ ____| | @@ -25,3 +38,115 @@ def test_that_trino_connect_and_disconnect_works_properly(): connection.cursor().execute("SELECT 1") close_trino_connection.run(trino_connection=connection) + + +def test_temporal_download_of_secret_file_works_properly_in_happy_path(): + + temp_file_name = "test_temp_file" + + with _temp_secret_file_from_s3( + TEST_LOLACONFIG.S3_BUCKET_NAME, + s3_file_key="env/env_prd.json", # Not a secret file, 
def test_temporal_download_of_secret_file_works_properly_in_happy_path():
    # The file must exist while inside the context and be gone afterwards.
    temp_file_name = "test_temp_file"

    with _temp_secret_file_from_s3(
        TEST_LOLACONFIG.S3_BUCKET_NAME,
        s3_file_key="env/env_prd.json",  # Not a secret file, but then again, this is a test
        local_temp_file_path=temp_file_name,
    ) as temp:
        temp_file_found_when_in_context_manager = pathlib.Path(temp).exists()

    # Separate asserts so a failure tells us which of the two conditions broke.
    assert temp_file_found_when_in_context_manager
    assert not pathlib.Path(temp_file_name).exists()


def test_temporal_download_of_secret_file_works_properly_even_with_exception():
    # The file must be cleaned up even when the context body raises.

    class _ForcedFailure(Exception):
        """Raised on purpose inside the context manager."""

    temp_file_name = "test_temp_file"
    temp_file_found_when_in_context_manager = False

    try:
        with _temp_secret_file_from_s3(
            TEST_LOLACONFIG.S3_BUCKET_NAME,
            s3_file_key="env/env_prd.json",  # Not a secret file, but then again, this is a test
            local_temp_file_path=temp_file_name,
        ) as temp:
            temp_file_found_when_in_context_manager = pathlib.Path(temp).exists()
            raise _ForcedFailure  # Something nasty happens within the context manager
    except _ForcedFailure:
        # Only the forced exception is ignored; any other error (for example a
        # failed download) still fails the test with its real cause.
        pass

    assert temp_file_found_when_in_context_manager
    assert not pathlib.Path(temp_file_name).exists()


def test_opening_and_closing_ssh_tunnel_works_properly():
    # Opens a tunnel, checks its state and bind address, then closes it.
    test_local_bind_host = "127.0.0.1"
    test_local_bind_port = 12345

    tunnel = open_ssh_tunnel_with_s3_pkey.run(
        s3_bucket_name=TEST_LOLACONFIG.S3_BUCKET_NAME,
        ssh_tunnel_credentials=TEST_LOLACONFIG.SSH_TUNNEL_CREDENTIALS,
        remote_target_host=TEST_LOLACONFIG.DW_CREDENTIALS["host"],
        remote_target_port=TEST_LOLACONFIG.DW_CREDENTIALS["port"],
        local_bind_host=test_local_bind_host,
        local_bind_port=test_local_bind_port,
    )
    try:
        tunnel_was_active = tunnel.is_active
        local_bind_address = get_local_bind_address_from_ssh_tunnel.run(tunnel)
    finally:
        # Always close the tunnel so a failure above does not leave the local
        # port bound for the tests that follow.
        close_ssh_tunnel.run(tunnel)

    assert tunnel_was_active
    assert not tunnel.is_active
    assert local_bind_address[0] == test_local_bind_host
    assert local_bind_address[1] == test_local_bind_port


def test_connect_query_and_disconnect_from_mysql_with_tunnel():
    # Connects to MySQL through an SSH tunnel, runs a query and disconnects.
    test_local_bind_host = "127.0.0.1"
    test_local_bind_port = 12345

    tunnel = open_ssh_tunnel_with_s3_pkey.run(
        s3_bucket_name=TEST_LOLACONFIG.S3_BUCKET_NAME,
        ssh_tunnel_credentials=TEST_LOLACONFIG.SSH_TUNNEL_CREDENTIALS,
        remote_target_host=TEST_LOLACONFIG.DW_CREDENTIALS["host"],
        remote_target_port=TEST_LOLACONFIG.DW_CREDENTIALS["port"],
        local_bind_host=test_local_bind_host,
        local_bind_port=test_local_bind_port,
    )
    try:
        connection = connect_to_mysql.run(
            mysql_credentials=TEST_LOLACONFIG.DW_CREDENTIALS,
            overriding_host_and_port=get_local_bind_address_from_ssh_tunnel.run(
                tunnel=tunnel
            ),
        )

        connection_was_open = connection.open

        connection.cursor().execute("SELECT 1")

        close_mysql_connection.run(connection=connection)
    finally:
        # Always close the tunnel, even if connecting or querying failed.
        close_ssh_tunnel.run(tunnel=tunnel)

    assert connection_was_open
    assert not connection.open