diff --git a/CHANGELOG.md b/CHANGELOG.md index 78124e9..008da48 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -2,6 +2,15 @@ All notable changes to this project will be documented in this file. +## [0.3.0] - 2023-01-27 + +### Added + +- Added Trino connection capabilities in the `connections` module. +- Added MySQL connection capabilities in the `connections` module. +- Added SSH tunneling capabilities in the `connections` module. + + ## [0.2.0] - 2023-01-19 ### Added diff --git a/README.md b/README.md index e7b7994..8258e82 100644 --- a/README.md +++ b/README.md @@ -6,6 +6,12 @@ Lolafect is a collection of Python bits that help us build our Prefect flows. You can find below examples of how to leverage `lolafect` in your flows. +**_Note: the code excerpts below are simplified for brevity and won't run +as-is. If you want to see perfect examples, you might want to check the tests +in this repository._** + +### Config + **Let the `LolaConfig` object do the boilerplate env stuff for you** ```python @@ -36,6 +42,86 @@ lolaconfig = build_lolaconfig( ) ``` +### Connections + +**Connect to a Trino server** + +```python +from lolafect.connections import connect_to_trino, close_trino_connection + +with Flow(...) as flow: + connection = connect_to_trino( + trino_credentials=my_trino_credentials # You can probably try to fetch this from lolaconfig.TRINO_CREDENTIALS + ) + task_result = some_trino_related_task(trino_connection=connection) + close_trino_connection( + trino_connection=connection, + upstream_tasks=[task_result] + ) +``` + +**Open an SSH tunnel** + +```python +from lolafect.connections import open_ssh_tunnel_with_s3_pkey, close_ssh_tunnel + +with Flow(...) 
as flow: + # You probably want to fetch these args from lolaconfig.SSH_TUNNEL_CREDENTIALS and lolaconfig.DW_CREDENTIALS + tunnel = open_ssh_tunnel_with_s3_pkey( + s3_bucket_name="some-bucket", + ssh_tunnel_credentials={...}, + remote_target_host="some-host-probably-mysql", + remote_target_port=12345, + ) + task_result = some_task_that_needs_the_tunnel(tunnel=tunnel) + # Tunnel is now alive. tunnel.is_active == True + close_ssh_tunnel(tunnel=tunnel, upstream_tasks=[task_result]) +``` + +**Connect to a MySQL instance** +```python +from lolafect.connections import connect_to_mysql, close_mysql_connection + +with Flow(...) as flow: + connection = connect_to_mysql( + mysql_credentials={...}, # You probably want to get this from lolaconfig.DW_CREDENTIALS + ) + task_result = some_task_that_needs_mysql(connection=connection) + close_mysql_connection(connection=connection, upstream_tasks=[task_result]) + +# Want to connect through an SSH tunnel? Open the tunnel normally and then +# override the host and port when connecting to MySQL. + +from lolafect.connections import ( + open_ssh_tunnel_with_s3_pkey, + get_local_bind_address_from_ssh_tunnel, + close_ssh_tunnel +) + +with Flow(...) 
as flow: + # You probably want to fetch these args from lolaconfig.SSH_TUNNEL_CREDENTIALS and lolaconfig.DW_CREDENTIALS + tunnel = open_ssh_tunnel_with_s3_pkey( + s3_bucket_name="some-bucket", + ssh_tunnel_credentials={...}, + remote_target_host="the-mysql-host", + remote_target_port=3306, + ) + + connection = connect_to_mysql( + mysql_credentials={...}, # You probably want to get this from lolaconfig.DW_CREDENTIALS + overriding_host_and_port=get_local_bind_address_from_ssh_tunnel( + tunnel=tunnel # This will open the connection through the SSH tunnel instead of straight to MySQL + ), + ) + + task_result = some_task_that_needs_mysql(connection=connection) + + mysql_closed = close_mysql_connection(connection=connection, upstream_tasks=[task_result]) + close_ssh_tunnel(tunnel=tunnel, upstream_tasks=[mysql_closed]) +``` + +### Slack + **Send a warning message to slack if your tasks fails** ```python @@ -57,12 +143,30 @@ with Flow(...) as flow: ## How to test +There are two test suites: unit tests and integration tests. Integration tests are designed to plug into some of our +AWS resources, hence they are not fully reliable since they require specific credentials and permissions. The +recommended policy is: + +- Use the unit tests in any CI process you want. +- Use the unit tests frequently as you code. +- Do not use the integration tests in CI processes. +- Use the integration tests as milestone checks when finishing feature branches. +- Make sure the integration tests pass before making a new release. + +When building new tests, please keep this philosophy in mind. + + IDE-agnostic: 1. Set up a virtual environment which contains both `lolafect` and the dependencies listed in `requirements-dev.txt`. -2. Run: `pytests tests` +2. 
Run: + - For all tests: `pytest tests` + - Only unit tests: `pytest tests/test_unit` + - Only integration tests: `pytest tests/test_integration` In Pycharm: - If you configure `pytest` as the project test runner, Pycharm will most probably autodetect the test - folder and allow you to run the test suite within the IDE. \ No newline at end of file + folder and allow you to run the test suite within the IDE. However, Pycharm has trouble running the integration + tests since the shell it runs from does not have the AWS credentials. Hence, for now we recommend that you only use + the Pycharm integrated test runner for the unit tests. You can easily set up a Run Configuration for that. \ No newline at end of file diff --git a/lolafect/__version__.py b/lolafect/__version__.py index 0590644..493f741 100644 --- a/lolafect/__version__.py +++ b/lolafect/__version__.py @@ -1 +1 @@ -__version__="0.2.0" \ No newline at end of file +__version__ = "0.3.0" diff --git a/lolafect/connections.py b/lolafect/connections.py new file mode 100644 index 0000000..33e2e5f --- /dev/null +++ b/lolafect/connections.py @@ -0,0 +1,288 @@ +import datetime +import os +from typing import Tuple +from contextlib import contextmanager + +import prefect +from prefect import task +from prefect.triggers import all_finished +from trino.auth import BasicAuthentication +import trino +import pymysql +import boto3 +from sshtunnel import SSHTunnelForwarder + +from lolafect.defaults import DEFAULT_TRINO_HTTP_SCHEME + + +@task(log_stdout=True, max_retries=3, retry_delay=datetime.timedelta(minutes=10)) +def connect_to_trino( + trino_credentials: dict, http_schema: str = DEFAULT_TRINO_HTTP_SCHEME +) -> trino.dbapi.Connection: + """ + Open a connection to the specified trino instance and return it. + + :param trino_credentials: a dict with the host, port, user and password. + :param http_schema: which HTTP scheme to use in the connection. + :return: the connection to trino. 
+ """ + logger = prefect.context.get("logger") + logger.info( + f"Connecting to Trino at {trino_credentials['host']}:{trino_credentials['port']}." + ) + + connection = trino.dbapi.connect( + host=trino_credentials["host"], + port=trino_credentials["port"], + user=trino_credentials["user"], + http_scheme=http_schema, + auth=BasicAuthentication( + trino_credentials["user"], + trino_credentials["password"], + ), + ) + logger.info( + f"Connected to Trino at {trino_credentials['host']}:{trino_credentials['port']}." + ) + + return connection + + +@task(trigger=all_finished) +def close_trino_connection(trino_connection: trino.dbapi.Connection) -> None: + """ + Close a Trino connection, or do nothing if what has been passed is not a + Trino connection. + + :param trino_connection: a trino connection. + :return: None + """ + logger = prefect.context.get("logger") + if isinstance(trino_connection, trino.dbapi.Connection): + trino_connection.close() + logger.info("Trino connection closed successfully.") + return + logger.warning( + f"Instead of a Trino connection, a {type(trino_connection)} was received." + ) + raise DeprecationWarning( + "This method will only accept the type 'trino.dbapi.Connection' in next major release.\n" + "Please, update your code accordingly." + ) + + +@contextmanager +def _temp_secret_file_from_s3( + s3_bucket_name: str, s3_file_key: str, local_temp_file_path: str +) -> str: + """ + Downloads a file from S3 and ensures that it will be deleted once the + context is exited from, even in the face of an exception. + + :param s3_bucket_name: the bucket where the file lives. + :param s3_file_key: the key of the file within the bucket. + :param local_temp_file_path: the path where the file should be stored + temporarily. + :return: the local file path. 
+ """ + boto3.client("s3").download_file( + s3_bucket_name, + s3_file_key, + local_temp_file_path, + ) + try: + yield local_temp_file_path + except Exception as e: + raise e + finally: + # Regardless of what happens in the context manager, we always delete the temp + # copy of the secret file. + os.remove(local_temp_file_path) + + +@task() +def open_ssh_tunnel_with_s3_pkey( + s3_bucket_name: str, + ssh_tunnel_credentials: dict, + remote_target_host: str, + remote_target_port: int, + local_bind_host: str = "127.0.0.1", + local_bind_port: int = 12345, +) -> SSHTunnelForwarder: + """ + Temporarily fetch a ssh key from S3 and then proceed to open a SSH tunnel + using it. + + :param s3_bucket_name: the bucket where the file lives. + :param ssh_tunnel_credentials: the details of the jumpthost SSH + connection. + :param remote_target_host: the remote host to tunnel to. + :param remote_target_port: the remote port to tunnel to. + :param local_bind_host: the host for the local bind address. + :param local_bind_port: the port for the local bind address. + :return: the tunnel, already open. 
+ """ + logger = prefect.context.get("logger") + logger.info("Going to open an SSH tunnel.") + + temp_file_path = "temp" + + with _temp_secret_file_from_s3( + s3_bucket_name=s3_bucket_name, + s3_file_key=ssh_tunnel_credentials["path_to_ssh_pkey"], + local_temp_file_path=temp_file_path, + ) as ssh_key_file: + tunnel = open_ssh_tunnel( + local_bind_host, + local_bind_port, + remote_target_host, + remote_target_port, + ssh_key_file, + ssh_tunnel_credentials, + ) + logger.info( + f"SSH tunnel is now open and listening at {local_bind_host}:{local_bind_port}.\n" + f"Tunnel forwards to {remote_target_host}:{remote_target_port}" + ) + + return tunnel + + +def open_ssh_tunnel( + local_bind_host: str, + local_bind_port: int, + remote_target_host: str, + remote_target_port: int, + ssh_key_file_path: str, + ssh_tunnel_credentials: dict, +) -> SSHTunnelForwarder: + """ + Configure and start an SSH tunnel. + + :param local_bind_host: the local host address to bind the tunnel to. + :param local_bind_port: the local port address to bind the tunnel to. + :param remote_target_host: the remote host to forward to. + :param remote_target_port: the remote port to forward to. + :param ssh_key_file_path: the path to the ssh key. + :param ssh_tunnel_credentials: the details of the jumpthost SSH + connection. + :return: the tunnel, already open. + """ + + tunnel = SSHTunnelForwarder( + ssh_address_or_host=( + ssh_tunnel_credentials["ssh_jumphost"], + ssh_tunnel_credentials["ssh_port"], + ), + ssh_username=ssh_tunnel_credentials["ssh_username"], + ssh_pkey=ssh_key_file_path, + remote_bind_address=( + remote_target_host, + remote_target_port, + ), + local_bind_address=(local_bind_host, local_bind_port), + ssh_private_key_password=ssh_tunnel_credentials["ssh_pkey_password"], + ) + tunnel.start() + return tunnel + + +@task(trigger=all_finished) +def close_ssh_tunnel(tunnel: SSHTunnelForwarder) -> None: + """ + Close a SSH tunnel, or do nothing if something different is passed. 
+ + :param tunnel: a SSH tunnel. + :return: + """ + logger = prefect.context.get("logger") + + if isinstance(tunnel, SSHTunnelForwarder): + tunnel.stop() + logger.info("SSH tunnel closed successfully.") + return + logger.warning(f"Instead of a SSH tunnel, a {type(tunnel)} was received.") + raise DeprecationWarning( + "This method will only accept the type 'SSHTunnelForwarder' in next major release.\n" + "Please, update your code accordingly." + ) + + +@task() +def get_local_bind_address_from_ssh_tunnel( + tunnel: SSHTunnelForwarder, +) -> Tuple[str, int]: + """ + A silly wrapper to be able to unpack the local bind address of a tunnel + within a Prefect flow. + + :param tunnel: an SSH tunnel. + :return: the local bind address of the SSH tunnel, as a tuple with host + and port. + """ + return tunnel.local_bind_address + + +@task() +def connect_to_mysql( + mysql_credentials: dict, overriding_host_and_port: Tuple[str, int] = None +) -> pymysql.Connection: + """ + Create a connection to a MySQL server, optionally using a host and port + different from the ones where the MySQL server is located. + + :param mysql_credentials: a dict with the connection details to the MySQL + instance. + :param overriding_host_and_port: an optional tuple containing a different + host and port. Useful to route through an SSH tunnel. + :return: the connection to the MySQL server. + """ + logger = prefect.context.get("logger") + logger.info( + f"Connecting to MySQL at {mysql_credentials['host']}:{mysql_credentials['port']}." + ) + + mysql_host = mysql_credentials["host"] + mysql_port = mysql_credentials["port"] + + if overriding_host_and_port: + # Since there is a tunnel, we actually want to connect to the local + # address of the tunnel, and not straight into the MySQL server. + mysql_host, mysql_port = overriding_host_and_port + logger.info( + f"Overriding the passed MySQL host and port with {mysql_host}:{mysql_port}." 
+ ) + + db_connection = pymysql.connect( + host=mysql_host, + port=mysql_port, + user=mysql_credentials["user"], + password=mysql_credentials["password"], + ) + + logger.info( + f"Connected to MySQL at {mysql_host}:{mysql_port}." + ) + + return db_connection + + +@task(trigger=all_finished) +def close_mysql_connection(connection: pymysql.Connection) -> None: + """ + Close a MySQL connection; for anything else, log a warning and raise DeprecationWarning. + + :param connection: a MySQL connection. + :return: None + """ + logger = prefect.context.get("logger") + + if isinstance(connection, pymysql.Connection): + connection.close() + logger.info("MySQL connection closed successfully.") + return + logger.warning(f"Instead of a MySQL connection, a {type(connection)} was received.") + raise DeprecationWarning( + "This method will only accept the type 'pymysql.Connection' in next major release.\n" + "Please, update your code accordingly." + ) diff --git a/lolafect/defaults.py b/lolafect/defaults.py index feb66e4..5c6c15f 100644 --- a/lolafect/defaults.py +++ b/lolafect/defaults.py @@ -1,6 +1,9 @@ -DEFAULT_ENV_S3_BUCKET="pdo-prefect-flows" -DEFAULT_ENV_FILE_PATH="env/env_prd.json" +DEFAULT_ENV_S3_BUCKET = "pdo-prefect-flows" +DEFAULT_ENV_FILE_PATH = "env/env_prd.json" DEFAULT_PATH_TO_SLACK_WEBHOOKS_FILE = "env/slack_webhooks.json" -DEFAULT_KUBERNETES_IMAGE = "373245262072.dkr.ecr.eu-central-1.amazonaws.com/pdo-data-prefect:production" -DEFAULT_KUBERNETES_LABELS = ["k8s"] +DEFAULT_KUBERNETES_IMAGE = ( + "373245262072.dkr.ecr.eu-central-1.amazonaws.com/pdo-data-prefect:production" +) +DEFAULT_KUBERNETES_LABELS = ["k8s"] DEFAULT_FLOWS_PATH_IN_BUCKET = "flows/" +DEFAULT_TRINO_HTTP_SCHEME = "https" diff --git a/requirements-dev.txt b/requirements-dev.txt index 40f6a2a..2adb75e 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -2,4 +2,7 @@ prefect==1.2.2 requests==2.28.1 boto3==1.26.40 pytest==7.2.0 -httpretty==1.1.4 \ No newline at end of file 
+httpretty==1.1.4 +trino==0.321.0 +sshtunnel==0.4.0 +PyMySQL==1.0.2 \ No newline at end of file diff --git a/setup.py b/setup.py index 1b6c89e..3ac4e7f 100644 --- a/setup.py +++ b/setup.py @@ -23,5 +23,12 @@ setup( package_dir={"lolafect": "lolafect"}, include_package_data=True, python_requires=">=3.7", - install_requires=["prefect==1.2.2", "requests==2.28.1", "boto3==1.26.40"], + install_requires=[ + "prefect==1.2.2", + "requests==2.28.1", + "boto3==1.26.40", + "trino==0.321.0", + "sshtunnel==0.4.0", + "PyMySQL==1.0.2" + ], ) diff --git a/tests/test_integration/test_connections.py b/tests/test_integration/test_connections.py new file mode 100644 index 0000000..2d03039 --- /dev/null +++ b/tests/test_integration/test_connections.py @@ -0,0 +1,152 @@ +import pathlib + +from prefect.tasks.core.function import FunctionTask + +from lolafect.lolaconfig import build_lolaconfig +from lolafect.connections import ( + connect_to_trino, + close_trino_connection, + _temp_secret_file_from_s3, + open_ssh_tunnel_with_s3_pkey, + get_local_bind_address_from_ssh_tunnel, + close_ssh_tunnel, + connect_to_mysql, + close_mysql_connection, +) + +# __ __ _____ _ _ _____ _ _ _____ _ +# \ \ / /\ | __ \| \ | |_ _| \ | |/ ____| | +# \ \ /\ / / \ | |__) | \| | | | | \| | | __| | +# \ \/ \/ / /\ \ | _ /| . ` | | | | . ` | | |_ | | +# \ /\ / ____ \| | \ \| |\ |_| |_| |\ | |__| |_| +# \/ \/_/ \_\_| \_\_| \_|_____|_| \_|\_____(_) +# This testing suite requires: +# - The calling shell to have permission in AWS +# - The calling shell to be within the Mercadão network +# - Do not use this tests as part of CI/CD pipelines since they are not idempotent and +# rely external resources. Instead, use them manually to check yourself that things +# are working properly. 
+TEST_LOLACONFIG = build_lolaconfig(flow_name="testing-suite") + + +def test_that_trino_connect_and_disconnect_works_properly(): + + connection = connect_to_trino.run( + trino_credentials=TEST_LOLACONFIG.TRINO_CREDENTIALS + ) + + connection.cursor().execute("SELECT 1") + + close_trino_connection.run(trino_connection=connection) + + +def test_temporal_download_of_secret_file_works_properly_in_happy_path(): + + temp_file_name = "test_temp_file" + + with _temp_secret_file_from_s3( + TEST_LOLACONFIG.S3_BUCKET_NAME, + s3_file_key="env/env_prd.json", # Not a secret file, but then again, this is a test, + local_temp_file_path=temp_file_name, + ) as temp: + temp_file_found_when_in_context_manager = pathlib.Path(temp).exists() + + temp_file_missing_when_outside_context_manager = not pathlib.Path( + temp_file_name + ).exists() + + assert ( + temp_file_found_when_in_context_manager + and temp_file_missing_when_outside_context_manager + ) + + +def test_temporal_download_of_secret_file_works_properly_even_with_exception(): + temp_file_name = "test_temp_file" + + try: + with _temp_secret_file_from_s3( + TEST_LOLACONFIG.S3_BUCKET_NAME, + s3_file_key="env/env_prd.json", # Not a secret file, but then again, this is a test, + local_temp_file_path=temp_file_name, + ) as temp: + temp_file_found_when_in_context_manager = pathlib.Path(temp).exists() + raise Exception # Something nasty happens within the context manager + except: + pass # We go with the test, ignoring the forced exception + + temp_file_missing_when_outside_context_manager = not pathlib.Path( + temp_file_name + ).exists() + + assert ( + temp_file_found_when_in_context_manager + and temp_file_missing_when_outside_context_manager + ) + + +def test_opening_and_closing_ssh_tunnel_works_properly(): + + test_local_bind_host = "127.0.0.1" + test_local_bind_port = 12345 + + tunnel = open_ssh_tunnel_with_s3_pkey.run( + s3_bucket_name=TEST_LOLACONFIG.S3_BUCKET_NAME, + ssh_tunnel_credentials=TEST_LOLACONFIG.SSH_TUNNEL_CREDENTIALS, + 
remote_target_host=TEST_LOLACONFIG.DW_CREDENTIALS["host"], + remote_target_port=TEST_LOLACONFIG.DW_CREDENTIALS["port"], + local_bind_host=test_local_bind_host, + local_bind_port=test_local_bind_port, + ) + tunnel_was_active = tunnel.is_active + + local_bind_host_matches = ( + get_local_bind_address_from_ssh_tunnel.run(tunnel)[0] == test_local_bind_host + ) + local_bind_port_matches = ( + get_local_bind_address_from_ssh_tunnel.run(tunnel)[1] == test_local_bind_port + ) + + close_ssh_tunnel.run(tunnel) + + tunnel_is_no_longer_active = not tunnel.is_active + + assert ( + tunnel_was_active + and tunnel_is_no_longer_active + and local_bind_host_matches + and local_bind_port_matches + ) + + +def test_connect_query_and_disconnect_from_mysql_with_tunnel(): + + test_local_bind_host = "127.0.0.1" + test_local_bind_port = 12345 + + tunnel = open_ssh_tunnel_with_s3_pkey.run( + s3_bucket_name=TEST_LOLACONFIG.S3_BUCKET_NAME, + ssh_tunnel_credentials=TEST_LOLACONFIG.SSH_TUNNEL_CREDENTIALS, + remote_target_host=TEST_LOLACONFIG.DW_CREDENTIALS["host"], + remote_target_port=TEST_LOLACONFIG.DW_CREDENTIALS["port"], + local_bind_host=test_local_bind_host, + local_bind_port=test_local_bind_port, + ) + + connection = connect_to_mysql.run( + mysql_credentials=TEST_LOLACONFIG.DW_CREDENTIALS, + overriding_host_and_port=get_local_bind_address_from_ssh_tunnel.run( + tunnel=tunnel + ), + ) + + connection_was_open = connection.open + + connection.cursor().execute("SELECT 1") + + close_mysql_connection.run(connection=connection) + close_ssh_tunnel.run(tunnel=tunnel) + + connection_is_closed = not connection.open + + assert connection_was_open and connection_is_closed diff --git a/tests/test_lolaconfig.py b/tests/test_unit/test_lolaconfig.py similarity index 100% rename from tests/test_lolaconfig.py rename to tests/test_unit/test_lolaconfig.py diff --git a/tests/test_slack.py b/tests/test_unit/test_slack.py similarity index 100% rename from tests/test_slack.py rename to 
tests/test_unit/test_slack.py