Merge pull request #8 from lolamarket/release/0.3.0

Release/0.3.0
This commit is contained in:
pablolola 2023-01-27 15:30:41 +01:00 committed by GitHub
commit 74184c9a44
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
10 changed files with 575 additions and 9 deletions

View file

@ -2,6 +2,15 @@
All notable changes to this project will be documented in this file. All notable changes to this project will be documented in this file.
## [0.3.0] - 2023-01-27
### Added
- Added Trino connection capabilities in the `connections` module.
- Added MySQL connection capabilities in the `connections` module.
- Added SSH tunneling capabilities in the `connections` module.
## [0.2.0] - 2023-01-19 ## [0.2.0] - 2023-01-19
### Added ### Added

108
README.md
View file

@ -6,6 +6,12 @@ Lolafect is a collection of Python bits that help us build our Prefect flows.
You can find below examples of how to leverage `lolafect` in your flows. You can find below examples of how to leverage `lolafect` in your flows.
**_Note: the code excerpts below are simplified for brevity and won't run
as-is. If you want to see perfect examples, you might want to check the tests
in this repository._**
### Config
**Let the `LolaConfig` object do the boilerplate env stuff for you** **Let the `LolaConfig` object do the boilerplate env stuff for you**
```python ```python
@ -36,6 +42,86 @@ lolaconfig = build_lolaconfig(
) )
``` ```
### Connections
**Connect to a Trino server**
```python
from lolafect.connections import connect_to_trino, close_trino_connection
with Flow(...) as flow:
connection = connect_to_trino.run(
trino_credentials=my_trino_credentials # You can probably try to fetch this from lolaconfig.TRINO_CREDENTIALS
)
task_result = some_trino_related_task(trino_connection=connection)
close_trino_connection.run(
trino_connection=connection,
upstream_tasks=[task_result]
)
```
**Open an SSH tunnel**
```python
from lolafect.connections import open_ssh_tunnel_with_s3_pkey, close_ssh_tunnel
with Flow(...) as flow:
# You probably want to fetch these args from lolaconfig.SSH_CREDENTIALS and lolaconfig.DW_CREDENTIALS
tunnel = open_ssh_tunnel_with_s3_pkey(
s3_bucket_name="some-bucket",
ssh_tunnel_credentials={...},
remote_target_host="some-host-probably-mysql",
remote_target_port=12345,
)
task_result = some_task_that_needs_the_tunnel(tunnel=tunnel)
# Tunnel is now alive. tunnel.is_active == True
close_ssh_tunnel(tunnel=tunnel, upstream_tasks=[task_result])
```
**Connect to a MySQL instance**
```python
from lolafect.connections import connect_to_mysql, close_mysql_connection
with Flow(...) as flow:
connection = connect_to_mysql(
mysql_credentials={...}, # You probably want to get this from TEST_LOLACONFIG.DW_CREDENTIALS
)
task_result = some_task_that_needs_mysql(connection=connection)
close_mysql_connection(connection=connection, upstream_tasks=[task_result])
# Want to connect through an SSH tunnel? Open the tunnel normally and then
# override the host and port when connecting to MySQL.
from lolafect.connections import (
open_ssh_tunnel_with_s3_pkey,
get_local_bind_address_from_ssh_tunnel,
close_ssh_tunnel
)
with Flow(...) as flow:
# You probably want to fetch these args from lolaconfig.SSH_CREDENTIALS and lolaconfig.DW_CREDENTIALS
tunnel = open_ssh_tunnel_with_s3_pkey(
s3_bucket_name="some-bucket",
ssh_tunnel_credentials={...},
remote_target_host="the-mysql-host",
remote_target_port=3306,
)
connection = connect_to_mysql(
mysql_credentials={...}, # You probably want to get this from TEST_LOLACONFIG.DW_CREDENTIALS
overriding_host_and_port=get_local_bind_address_from_ssh_tunnel.run(
tunnel=tunnel # This will open the connection through the SSH tunnel instead of straight to MySQL
),
)
task_result = some_task_that_needs_mysql(connection=connection)
mysql_closed = close_mysql_connection(connection=connection, upstream_tasks=[task_result])
close_ssh_tunnel.run(tunnel=tunnel, upstream_tasks=[mysql_closed])
```
### Slack
**Send a warning message to slack if your tasks fails** **Send a warning message to slack if your tasks fails**
```python ```python
@ -57,12 +143,30 @@ with Flow(...) as flow:
## How to test ## How to test
There are two test suites: unit tests and integration tests. Integration tests are prepared to plug to some of our
AWS resources, hence they are not fully reliable since they require specific credentials and permissions. The
recommended policy is:
- Use the unit tests in any CI process you want.
- Use the unit tests frequently as you code.
- Do not use the integration tests in CI processes.
- Use the integration tests as milestone checks when finishing feature branches.
- Make sure to ensure integration tests are working before making a new release.
When building new tests, please keep this philosophy in mind.
IDE-agnostic: IDE-agnostic:
1. Set up a virtual environment which contains both `lolafect` and the dependencies listed in `requirements-dev.txt`. 1. Set up a virtual environment which contains both `lolafect` and the dependencies listed in `requirements-dev.txt`.
2. Run: `pytests tests` 2. Run:
- For all tests: `pytests tests`
- Only unit tests: `pytest tests/test_unit`
- Only integration tests: `pytest tests/test_integration`
In Pycharm: In Pycharm:
- If you configure `pytest` as the project test runner, Pycharm will most probably autodetect the test - If you configure `pytest` as the project test runner, Pycharm will most probably autodetect the test
folder and allow you to run the test suite within the IDE. folder and allow you to run the test suite within the IDE. However, Pycharm has troubles running the integration
tests since the shell it runs from does not have the AWS credentials. Hence, for now we recommend you to only use
the Pycharm integrated test runner for the unit tests. You can easily set up a Run Configuration for that.

View file

@ -1 +1 @@
__version__="0.2.0" __version__ = "0.3.0"

288
lolafect/connections.py Normal file
View file

@ -0,0 +1,288 @@
import datetime
import os
from typing import Tuple
from contextlib import contextmanager
import prefect
from prefect import task
from prefect.triggers import all_finished
from trino.auth import BasicAuthentication
import trino
import pymysql
import boto3
from sshtunnel import SSHTunnelForwarder
from lolafect.defaults import DEFAULT_TRINO_HTTP_SCHEME
@task(log_stdout=True, max_retries=3, retry_delay=datetime.timedelta(minutes=10))
def connect_to_trino(
trino_credentials: dict, http_schema: str = DEFAULT_TRINO_HTTP_SCHEME
) -> trino.dbapi.Connection:
"""
Open a connection to the specified trino instance and return it.
:param trino_credentials: a dict with the host, port, user and password.
:param http_schema: which http schema to use in the connection.
:return: the connection to trino.
"""
logger = prefect.context.get("logger")
logger.info(
f"Connecting to Trino at {trino_credentials['host']}:{trino_credentials['port']}."
)
connection = trino.dbapi.connect(
host=trino_credentials["host"],
port=trino_credentials["port"],
user=trino_credentials["user"],
http_scheme=http_schema,
auth=BasicAuthentication(
trino_credentials["user"],
trino_credentials["password"],
),
)
logger.info(
f"Connected to Trino at {trino_credentials['host']}:{trino_credentials['port']}."
)
return connection
@task(trigger=all_finished)
def close_trino_connection(trino_connection: trino.dbapi.Connection) -> None:
"""
Close a Trino connection, or do nothing if what has been passed is not a
Trino connection.
:param trino_connection: a trino connection.
:return: None
"""
logger = prefect.context.get("logger")
if isinstance(trino_connection, trino.dbapi.Connection):
trino_connection.close()
logger.info("Trino connection closed successfully.")
return
logger.warning(
f"Instead of a Trino connection, a {type(trino_connection)} was received."
)
raise DeprecationWarning(
"This method will only accept the type 'trino.dbapi.Connection' in next major release.\n"
"Please, update your code accordingly."
)
@contextmanager
def _temp_secret_file_from_s3(
s3_bucket_name: str, s3_file_key: str, local_temp_file_path: str
) -> str:
"""
Downloads a file from S3 and ensures that it will be deleted once the
context is exited from, even in the face of an exception.
:param s3_bucket_name: the bucket where the file lives.
:param s3_file_key: the key of the file within the bucket.
:param local_temp_file_path: the path where the file should be stored
temporarily.
:return: the local file path.
"""
boto3.client("s3").download_file(
s3_bucket_name,
s3_file_key,
local_temp_file_path,
)
try:
yield local_temp_file_path
except Exception as e:
raise e
finally:
# Regardless of what happens in the context manager, we always delete the temp
# copy of the secret file.
os.remove(local_temp_file_path)
@task()
def open_ssh_tunnel_with_s3_pkey(
s3_bucket_name: str,
ssh_tunnel_credentials: dict,
remote_target_host: str,
remote_target_port: int,
local_bind_host: str = "127.0.0.1",
local_bind_port: int = 12345,
) -> SSHTunnelForwarder:
"""
Temporarily fetch a ssh key from S3 and then proceed to open a SSH tunnel
using it.
:param s3_bucket_name: the bucket where the file lives.
:param ssh_tunnel_credentials: the details of the jumpthost SSH
connection.
:param remote_target_host: the remote host to tunnel to.
:param remote_target_port: the remote port to tunnel to.
:param local_bind_host: the host for the local bind address.
:param local_bind_port: the port for the local bind address.
:return: the tunnel, already open.
"""
logger = prefect.context.get("logger")
logger.info("Going to open an SSH tunnel.")
temp_file_path = "temp"
with _temp_secret_file_from_s3(
s3_bucket_name=s3_bucket_name,
s3_file_key=ssh_tunnel_credentials["path_to_ssh_pkey"],
local_temp_file_path=temp_file_path,
) as ssh_key_file:
tunnel = open_ssh_tunnel(
local_bind_host,
local_bind_port,
remote_target_host,
remote_target_port,
ssh_key_file,
ssh_tunnel_credentials,
)
logger.info(
f"SSH tunnel is now open and listening at {local_bind_host}:{local_bind_port}.\n"
f"Tunnel forwards to {remote_target_host}:{remote_target_port}"
)
return tunnel
def open_ssh_tunnel(
local_bind_host: str,
local_bind_port: int,
remote_target_host: str,
remote_target_port: int,
ssh_key_file_path: str,
ssh_tunnel_credentials: dict,
) -> SSHTunnelForwarder:
"""
Configure and start an SSH tunnel.
:param local_bind_host: the local host address to bind the tunnel to.
:param local_bind_port: the local port address to bind the tunnel to.
:param remote_target_host: the remote host to forward to.
:param remote_target_port: the remote port to forward to.
:param ssh_key_file_path: the path to the ssh key.
:param ssh_tunnel_credentials: the details of the jumpthost SSH
connection.
:return: the tunnel, already open.
"""
tunnel = SSHTunnelForwarder(
ssh_address_or_host=(
ssh_tunnel_credentials["ssh_jumphost"],
ssh_tunnel_credentials["ssh_port"],
),
ssh_username=ssh_tunnel_credentials["ssh_username"],
ssh_pkey=ssh_key_file_path,
remote_bind_address=(
remote_target_host,
remote_target_port,
),
local_bind_address=(local_bind_host, local_bind_port),
ssh_private_key_password=ssh_tunnel_credentials["ssh_pkey_password"],
)
tunnel.start()
return tunnel
@task(trigger=all_finished)
def close_ssh_tunnel(tunnel: SSHTunnelForwarder) -> None:
"""
Close a SSH tunnel, or do nothing if something different is passed.
:param tunnel: a SSH tunnel.
:return:
"""
logger = prefect.context.get("logger")
if isinstance(tunnel, SSHTunnelForwarder):
tunnel.stop()
logger.info("SSH tunnel closed successfully.")
return
logger.warning(f"Instead of a SSH tunnel, a {type(tunnel)} was received.")
raise DeprecationWarning(
"This method will only accept the type 'SSHTunnelForwarder' in next major release.\n"
"Please, update your code accordingly."
)
@task()
def get_local_bind_address_from_ssh_tunnel(
tunnel: SSHTunnelForwarder,
) -> Tuple[str, int]:
"""
A silly wrapper to be able to unpack the local bind address of a tunnel
within a Prefect flow.
:param tunnel: an SSH tunnel.
:return: the local bind address of the SSH tunnel, as a tuple with host
and port.
"""
return tunnel.local_bind_address
@task()
def connect_to_mysql(
mysql_credentials: dict, overriding_host_and_port: Tuple[str, int] = None
) -> pymysql.Connection:
"""
Create a connection to a MySQL server, optionally using a host and port
different from the ones where the MySQL server is located.
:param mysql_credentials: a dict with the connection details to the MySQL
instance.
:param overriding_host_and_port: an optional tuple containing a different
host and port. Useful to route through an SSH tunnel.
:return: the connection to the MySQL server.
"""
logger = prefect.context.get("logger")
logger.info(
f"Connecting to MySQL at {mysql_credentials['host']}:{mysql_credentials['port']}."
)
mysql_host = mysql_credentials["host"]
mysql_port = mysql_credentials["port"]
if overriding_host_and_port:
# Since there is a tunnel, we actually want to connect to the local
# address of the tunnel, and not straight into the MySQL server.
mysql_host, mysql_port = overriding_host_and_port
logger.info(
f"Overriding the passed MySQL host and port with {mysql_host}:{mysql_port}."
)
db_connection = pymysql.connect(
host=mysql_host,
port=mysql_port,
user=mysql_credentials["user"],
password=mysql_credentials["password"],
)
logger.info(
f"Connected to MySQL at {mysql_credentials['host']}:{mysql_credentials['port']}."
)
return db_connection
@task(trigger=all_finished)
def close_mysql_connection(connection: pymysql.Connection) -> None:
"""
Close a MySQL connection, or do nothing if something different is passed.
:param connection: a MySQL connection.
:return: None
"""
logger = prefect.context.get("logger")
if isinstance(connection, pymysql.Connection):
connection.close()
logger.info("MySQL connection closed successfully.")
return
logger.warning(f"Instead of a MySQL connection, a {type(connection)} was received.")
raise DeprecationWarning(
"This method will only accept the type 'pymysql.Connection' in next major release.\n"
"Please, update your code accordingly."
)

View file

@ -1,6 +1,9 @@
DEFAULT_ENV_S3_BUCKET="pdo-prefect-flows" DEFAULT_ENV_S3_BUCKET = "pdo-prefect-flows"
DEFAULT_ENV_FILE_PATH="env/env_prd.json" DEFAULT_ENV_FILE_PATH = "env/env_prd.json"
DEFAULT_PATH_TO_SLACK_WEBHOOKS_FILE = "env/slack_webhooks.json" DEFAULT_PATH_TO_SLACK_WEBHOOKS_FILE = "env/slack_webhooks.json"
DEFAULT_KUBERNETES_IMAGE = "373245262072.dkr.ecr.eu-central-1.amazonaws.com/pdo-data-prefect:production" DEFAULT_KUBERNETES_IMAGE = (
"373245262072.dkr.ecr.eu-central-1.amazonaws.com/pdo-data-prefect:production"
)
DEFAULT_KUBERNETES_LABELS = ["k8s"] DEFAULT_KUBERNETES_LABELS = ["k8s"]
DEFAULT_FLOWS_PATH_IN_BUCKET = "flows/" DEFAULT_FLOWS_PATH_IN_BUCKET = "flows/"
DEFAULT_TRINO_HTTP_SCHEME = "https"

View file

@ -3,3 +3,6 @@ requests==2.28.1
boto3==1.26.40 boto3==1.26.40
pytest==7.2.0 pytest==7.2.0
httpretty==1.1.4 httpretty==1.1.4
trino==0.321.0
sshtunnel==0.4.0
PyMySQL==1.0.2

View file

@ -23,5 +23,12 @@ setup(
package_dir={"lolafect": "lolafect"}, package_dir={"lolafect": "lolafect"},
include_package_data=True, include_package_data=True,
python_requires=">=3.7", python_requires=">=3.7",
install_requires=["prefect==1.2.2", "requests==2.28.1", "boto3==1.26.40"], install_requires=[
"prefect==1.2.2",
"requests==2.28.1",
"boto3==1.26.40",
"trino==0.321.0",
"sshtunnel==0.4.0",
"PyMySQL==1.0.2"
],
) )

View file

@ -0,0 +1,152 @@
import pathlib
from prefect.tasks.core.function import FunctionTask
from lolafect.lolaconfig import build_lolaconfig
from lolafect.connections import (
connect_to_trino,
close_trino_connection,
_temp_secret_file_from_s3,
open_ssh_tunnel_with_s3_pkey,
get_local_bind_address_from_ssh_tunnel,
close_ssh_tunnel,
connect_to_mysql,
close_mysql_connection,
)
# __ __ _____ _ _ _____ _ _ _____ _
# \ \ / /\ | __ \| \ | |_ _| \ | |/ ____| |
# \ \ /\ / / \ | |__) | \| | | | | \| | | __| |
# \ \/ \/ / /\ \ | _ /| . ` | | | | . ` | | |_ | |
# \ /\ / ____ \| | \ \| |\ |_| |_| |\ | |__| |_|
# \/ \/_/ \_\_| \_\_| \_|_____|_| \_|\_____(_)
# This testing suite requires:
# - The calling shell to have permission in AWS
# - The calling shell to be within the Mercadão network
# - Do not use this tests as part of CI/CD pipelines since they are not idempotent and
# rely external resources. Instead, use them manually to check yourself that things
# are working properly.
TEST_LOLACONFIG = build_lolaconfig(flow_name="testing-suite")
def test_that_trino_connect_and_disconnect_works_properly():
connection = connect_to_trino.run(
trino_credentials=TEST_LOLACONFIG.TRINO_CREDENTIALS
)
connection.cursor().execute("SELECT 1")
close_trino_connection.run(trino_connection=connection)
def test_temporal_download_of_secret_file_works_properly_in_happy_path():
temp_file_name = "test_temp_file"
with _temp_secret_file_from_s3(
TEST_LOLACONFIG.S3_BUCKET_NAME,
s3_file_key="env/env_prd.json", # Not a secret file, but then again, this is a test,
local_temp_file_path=temp_file_name,
) as temp:
temp_file_found_when_in_context_manager = pathlib.Path(temp).exists()
temp_file_missing_when_outside_context_manager = not pathlib.Path(
temp_file_name
).exists()
assert (
temp_file_found_when_in_context_manager
and temp_file_missing_when_outside_context_manager
)
def test_temporal_download_of_secret_file_works_properly_even_with_exception():
temp_file_name = "test_temp_file"
try:
with _temp_secret_file_from_s3(
TEST_LOLACONFIG.S3_BUCKET_NAME,
s3_file_key="env/env_prd.json", # Not a secret file, but then again, this is a test,
local_temp_file_path=temp_file_name,
) as temp:
temp_file_found_when_in_context_manager = pathlib.Path(temp).exists()
raise Exception # Something nasty happens within the context manager
except:
pass # We go with the test, ignoring the forced exception
temp_file_missing_when_outside_context_manager = not pathlib.Path(
temp_file_name
).exists()
assert (
temp_file_found_when_in_context_manager
and temp_file_missing_when_outside_context_manager
)
def test_opening_and_closing_ssh_tunnel_works_properly():
test_local_bind_host = "127.0.0.1"
test_local_bind_port = 12345
tunnel = open_ssh_tunnel_with_s3_pkey.run(
s3_bucket_name=TEST_LOLACONFIG.S3_BUCKET_NAME,
ssh_tunnel_credentials=TEST_LOLACONFIG.SSH_TUNNEL_CREDENTIALS,
remote_target_host=TEST_LOLACONFIG.DW_CREDENTIALS["host"],
remote_target_port=TEST_LOLACONFIG.DW_CREDENTIALS["port"],
local_bind_host=test_local_bind_host,
local_bind_port=test_local_bind_port,
)
tunnel_was_active = tunnel.is_active
local_bind_host_matches = (
get_local_bind_address_from_ssh_tunnel.run(tunnel)[0] == test_local_bind_host
)
local_bind_port_matches = (
get_local_bind_address_from_ssh_tunnel.run(tunnel)[1] == test_local_bind_port
)
close_ssh_tunnel.run(tunnel)
tunnel_is_no_longer_active = not tunnel.is_active
assert (
tunnel_was_active
and tunnel_is_no_longer_active
and local_bind_host_matches
and local_bind_port_matches
)
def test_connect_query_and_disconnect_from_mysql_with_tunnel():
test_local_bind_host = "127.0.0.1"
test_local_bind_port = 12345
tunnel = open_ssh_tunnel_with_s3_pkey.run(
s3_bucket_name=TEST_LOLACONFIG.S3_BUCKET_NAME,
ssh_tunnel_credentials=TEST_LOLACONFIG.SSH_TUNNEL_CREDENTIALS,
remote_target_host=TEST_LOLACONFIG.DW_CREDENTIALS["host"],
remote_target_port=TEST_LOLACONFIG.DW_CREDENTIALS["port"],
local_bind_host=test_local_bind_host,
local_bind_port=test_local_bind_port,
)
connection = connect_to_mysql.run(
mysql_credentials=TEST_LOLACONFIG.DW_CREDENTIALS,
overriding_host_and_port=get_local_bind_address_from_ssh_tunnel.run(
tunnel=tunnel
),
)
connection_was_open = connection.open
connection.cursor().execute("SELECT 1")
close_mysql_connection.run(connection=connection)
close_ssh_tunnel.run(tunnel=tunnel)
connection_is_closed = not connection.open
assert connection_was_open and connection_is_closed