Merge pull request #7 from lolamarket/feature/mysql-connection

Feature/mysql connection
This commit is contained in:
pablolola 2023-01-27 12:58:25 +01:00 committed by GitHub
commit 0b648064dd
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
7 changed files with 446 additions and 8 deletions

View file

@ -6,7 +6,9 @@ All notable changes to this project will be documented in this file.
### Added
- Added Trino connection capabilities in `connections` module.
- Added Trino connection capabilities in the `connections` module.
- Added MySQL connection capabilities in the `connections` module.
- Added SSH tunneling capabilities in the `connections` module.
## [0.2.0] - 2023-01-19

View file

@ -6,6 +6,12 @@ Lolafect is a collection of Python bits that help us build our Prefect flows.
You can find below examples of how to leverage `lolafect` in your flows.
**_Note: the code excerpts below are simplified for brevity and won't run
as-is. If you want to see perfect examples, you might want to check the tests
in this repository._**
### Config
**Let the `LolaConfig` object do the boilerplate env stuff for you**
```python
@ -36,6 +42,8 @@ lolaconfig = build_lolaconfig(
)
```
### Connections
**Connect to a Trino server**
```python
@ -52,6 +60,66 @@ with Flow(...) as flow:
)
```
**Open an SSH tunnel**
```python
from lolafect.connections import open_ssh_tunnel_with_s3_pkey, close_ssh_tunnel
with Flow(...) as flow:
# You probably want to fetch these args from lolaconfig.SSH_CREDENTIALS and lolaconfig.DW_CREDENTIALS
tunnel = open_ssh_tunnel_with_s3_pkey(
s3_bucket_name="some-bucket",
ssh_tunnel_credentials={...},
remote_target_host="some-host-probably-mysql",
remote_target_port=12345,
)
# Tunnel is now alive. tunnel.is_active == True
close_ssh_tunnel(tunnel=tunnel)
```
**Connect to a MySQL instance**
```python
from lolafect.connections import connect_to_mysql, close_mysql_connection
with Flow(...) as flow:
connection = connect_to_mysql.run(
mysql_credentials={...}, # You probably want to get this from TEST_LOLACONFIG.DW_CREDENTIALS
)
connection.cursor().execute("SELECT 1")
close_mysql_connection.run(connection=connection)
# Want to connect through an SSH tunnel? Open the tunnel normally and then
# override the host and port when connecting to MySQL.
from lolafect.connections import (
open_ssh_tunnel_with_s3_pkey,
get_local_bind_address_from_ssh_tunnel,
close_ssh_tunnel
)
with Flow(...) as flow:
# You probably want to fetch these args from lolaconfig.SSH_CREDENTIALS and lolaconfig.DW_CREDENTIALS
tunnel = open_ssh_tunnel_with_s3_pkey(
s3_bucket_name="some-bucket",
ssh_tunnel_credentials={...},
remote_target_host="the-mysql-host",
remote_target_port=3306,
)
connection = connect_to_mysql.run(
mysql_credentials={...}, # You probably want to get this from TEST_LOLACONFIG.DW_CREDENTIALS
overriding_host_and_port=get_local_bind_address_from_ssh_tunnel.run(
tunnel=tunnel # This will open the connection through the SSH tunnel instead of straight to MySQL
),
)
connection.cursor().execute("SELECT 1")
close_mysql_connection.run(connection=connection)
close_ssh_tunnel.run(tunnel=tunnel)
```
### Slack
**Send a warning message to slack if your tasks fails**

View file

@ -1 +1 @@
__version__="0.2.0"
__version__="dev"

View file

@ -1,10 +1,16 @@
import datetime
import os
from typing import Tuple
from contextlib import contextmanager
import prefect
from prefect import task
from prefect.triggers import all_finished
from trino.auth import BasicAuthentication
import trino
import pymysql
import boto3
from sshtunnel import SSHTunnelForwarder
from lolafect.defaults import DEFAULT_TRINO_HTTP_SCHEME
@ -18,7 +24,7 @@ def connect_to_trino(
:param trino_credentials: a dict with the host, port, user and password.
:param http_schema: which http schema to use in the connection.
:return:
:return: the connection to trino.
"""
logger = prefect.context.get("logger")
logger.info(
@ -56,4 +62,232 @@ def close_trino_connection(trino_connection: trino.dbapi.Connection) -> None:
trino_connection.close()
logger.info("Trino connection closed successfully.")
return
logger.info("No connection received.")
logger.warning(
f"Instead of a Trino connection, a {type(trino_connection)} was received."
)
raise DeprecationWarning(
"This method will only accept the type 'trino.dbapi.Connection' in next major release.\n"
"Please, update your code accordingly."
)
@contextmanager
def _temp_secret_file_from_s3(
s3_bucket_name: str, s3_file_key: str, local_temp_file_path: str
) -> str:
"""
Downloads a file from S3 and ensures that it will be deleted once the
context is exited from, even in the face of an exception.
:param s3_bucket_name: the bucket where the file lives.
:param s3_file_key: the key of the file within the bucket.
:param local_temp_file_path: the path where the file should be stored
temporarily.
:return: the local file path.
"""
boto3.client("s3").download_file(
s3_bucket_name,
s3_file_key,
local_temp_file_path,
)
try:
yield local_temp_file_path
except Exception as e:
raise e
finally:
# Regardless of what happens in the context manager, we always delete the temp
# copy of the secret file.
os.remove(local_temp_file_path)
@task()
def open_ssh_tunnel_with_s3_pkey(
s3_bucket_name: str,
ssh_tunnel_credentials: dict,
remote_target_host: str,
remote_target_port: int,
local_bind_host: str = "127.0.0.1",
local_bind_port: int = 12345,
) -> SSHTunnelForwarder:
"""
Temporarily fetch a ssh key from S3 and then proceed to open a SSH tunnel
using it.
:param s3_bucket_name: the bucket where the file lives.
:param ssh_tunnel_credentials: the details of the jumpthost SSH
connection.
:param remote_target_host: the remote host to tunnel to.
:param remote_target_port: the remote port to tunnel to.
:param local_bind_host: the host for the local bind address.
:param local_bind_port: the port for the local bind address.
:return: the tunnel, already open.
"""
logger = prefect.context.get("logger")
logger.info("Going to open an SSH tunnel.")
temp_file_path = "temp"
with _temp_secret_file_from_s3(
s3_bucket_name=s3_bucket_name,
s3_file_key=ssh_tunnel_credentials["path_to_ssh_pkey"],
local_temp_file_path=temp_file_path,
) as ssh_key_file:
tunnel = open_ssh_tunnel(
local_bind_host,
local_bind_port,
remote_target_host,
remote_target_port,
ssh_key_file,
ssh_tunnel_credentials,
)
logger.info(
f"SSH tunnel is now open and listening at {local_bind_host}:{local_bind_port}.\n"
f"Tunnel forwards to {remote_target_host}:{remote_target_port}"
)
return tunnel
def open_ssh_tunnel(
local_bind_host: str,
local_bind_port: int,
remote_target_host: str,
remote_target_port: int,
ssh_key_file_path: str,
ssh_tunnel_credentials: dict,
) -> SSHTunnelForwarder:
"""
Configure and start an SSH tunnel.
:param local_bind_host: the local host address to bind the tunnel to.
:param local_bind_port: the local port address to bind the tunnel to.
:param remote_target_host: the remote host to forward to.
:param remote_target_port: the remote port to forward to.
:param ssh_key_file_path: the path to the ssh key.
:param ssh_tunnel_credentials: the details of the jumpthost SSH
connection.
:return: the tunnel, already open.
"""
tunnel = SSHTunnelForwarder(
ssh_address_or_host=(
ssh_tunnel_credentials["ssh_jumphost"],
ssh_tunnel_credentials["ssh_port"],
),
ssh_username=ssh_tunnel_credentials["ssh_username"],
ssh_pkey=ssh_key_file_path,
remote_bind_address=(
remote_target_host,
remote_target_port,
),
local_bind_address=(local_bind_host, local_bind_port),
ssh_private_key_password=ssh_tunnel_credentials["ssh_pkey_password"],
)
tunnel.start()
return tunnel
@task(trigger=all_finished)
def close_ssh_tunnel(tunnel: SSHTunnelForwarder) -> None:
"""
Close a SSH tunnel, or do nothing if something different is passed.
:param tunnel: a SSH tunnel.
:return:
"""
logger = prefect.context.get("logger")
if isinstance(tunnel, SSHTunnelForwarder):
tunnel.stop()
logger.info("SSH tunnel closed successfully.")
return
logger.warning(f"Instead of a SSH tunnel, a {type(tunnel)} was received.")
raise DeprecationWarning(
"This method will only accept the type 'SSHTunnelForwarder' in next major release.\n"
"Please, update your code accordingly."
)
@task()
def get_local_bind_address_from_ssh_tunnel(
tunnel: SSHTunnelForwarder,
) -> Tuple[str, int]:
"""
A silly wrapper to be able to unpack the local bind address of a tunnel
within a Prefect flow.
:param tunnel: an SSH tunnel.
:return: the local bind address of the SSH tunnel, as a tuple with host
and port.
"""
return tunnel.local_bind_address
@task()
def connect_to_mysql(
mysql_credentials: dict, overriding_host_and_port: Tuple[str, int] = None
) -> pymysql.Connection:
"""
Create a connection to a MySQL server, optionally using a host and port
different from the ones where the MySQL server is located.
:param mysql_credentials: a dict with the connection details to the MySQL
instance.
:param overriding_host_and_port: an optional tuple containing a different
host and port. Useful to route through an SSH tunnel.
:return: the connection to the MySQL server.
"""
logger = prefect.context.get("logger")
logger.info(
f"Connecting to MySQL at {mysql_credentials['host']}:{mysql_credentials['port']}."
)
mysql_host = mysql_credentials["host"]
mysql_port = mysql_credentials["port"]
if overriding_host_and_port:
# Since there is a tunnel, we actually want to connect to the local
# address of the tunnel, and not straight into the MySQL server.
mysql_host, mysql_port = overriding_host_and_port
logger.info(
f"Overriding the passed MySQL host and port with {mysql_host}:{mysql_port}."
)
db_connection = pymysql.connect(
host=mysql_host,
port=mysql_port,
user=mysql_credentials["user"],
password=mysql_credentials["password"],
database="dw_xl",
)
# Customizing this attributes to retrieve them later in the flow
db_connection.raw_user = mysql_credentials["user"]
db_connection.raw_password = mysql_credentials["password"]
logger.info(
f"Connected to MySQL at {mysql_credentials['host']}:{mysql_credentials['port']}."
)
return db_connection
@task(trigger=all_finished)
def close_mysql_connection(connection: pymysql.Connection) -> None:
"""
Close a MySQL connection, or do nothing if something different is passed.
:param connection: a MySQL connection.
:return: None
"""
logger = prefect.context.get("logger")
if isinstance(connection, pymysql.Connection):
connection.close()
logger.info("MySQL connection closed successfully.")
return
logger.warning(f"Instead of a MySQL connection, a {type(connection)} was received.")
raise DeprecationWarning(
"This method will only accept the type 'pymysql.Connection' in next major release.\n"
"Please, update your code accordingly."
)

View file

@ -4,3 +4,5 @@ boto3==1.26.40
pytest==7.2.0
httpretty==1.1.4
trino==0.321.0
sshtunnel==0.4.0
PyMySQL==1.0.2

View file

@ -23,5 +23,12 @@ setup(
package_dir={"lolafect": "lolafect"},
include_package_data=True,
python_requires=">=3.7",
install_requires=["prefect==1.2.2", "requests==2.28.1", "boto3==1.26.40", "trino==0.321.0"],
install_requires=[
"prefect==1.2.2",
"requests==2.28.1",
"boto3==1.26.40",
"trino==0.321.0",
"sshtunnel==0.4.0",
"PyMySQL==1.0.2"
],
)

View file

@ -1,5 +1,18 @@
import pathlib
from prefect.tasks.core.function import FunctionTask
from lolafect.lolaconfig import build_lolaconfig
from lolafect.connections import connect_to_trino, close_trino_connection
from lolafect.connections import (
connect_to_trino,
close_trino_connection,
_temp_secret_file_from_s3,
open_ssh_tunnel_with_s3_pkey,
get_local_bind_address_from_ssh_tunnel,
close_ssh_tunnel,
connect_to_mysql,
close_mysql_connection,
)
# __ __ _____ _ _ _____ _ _ _____ _
# \ \ / /\ | __ \| \ | |_ _| \ | |/ ____| |
@ -25,3 +38,115 @@ def test_that_trino_connect_and_disconnect_works_properly():
connection.cursor().execute("SELECT 1")
close_trino_connection.run(trino_connection=connection)
def test_temporal_download_of_secret_file_works_properly_in_happy_path():
temp_file_name = "test_temp_file"
with _temp_secret_file_from_s3(
TEST_LOLACONFIG.S3_BUCKET_NAME,
s3_file_key="env/env_prd.json", # Not a secret file, but then again, this is a test,
local_temp_file_path=temp_file_name,
) as temp:
temp_file_found_when_in_context_manager = pathlib.Path(temp).exists()
temp_file_missing_when_outside_context_manager = not pathlib.Path(
temp_file_name
).exists()
assert (
temp_file_found_when_in_context_manager
and temp_file_missing_when_outside_context_manager
)
def test_temporal_download_of_secret_file_works_properly_even_with_exception():
temp_file_name = "test_temp_file"
try:
with _temp_secret_file_from_s3(
TEST_LOLACONFIG.S3_BUCKET_NAME,
s3_file_key="env/env_prd.json", # Not a secret file, but then again, this is a test,
local_temp_file_path=temp_file_name,
) as temp:
temp_file_found_when_in_context_manager = pathlib.Path(temp).exists()
raise Exception # Something nasty happens within the context manager
except:
pass # We go with the test, ignoring the forced exception
temp_file_missing_when_outside_context_manager = not pathlib.Path(
temp_file_name
).exists()
assert (
temp_file_found_when_in_context_manager
and temp_file_missing_when_outside_context_manager
)
def test_opening_and_closing_ssh_tunnel_works_properly():
test_local_bind_host = "127.0.0.1"
test_local_bind_port = 12345
tunnel = open_ssh_tunnel_with_s3_pkey.run(
s3_bucket_name=TEST_LOLACONFIG.S3_BUCKET_NAME,
ssh_tunnel_credentials=TEST_LOLACONFIG.SSH_TUNNEL_CREDENTIALS,
remote_target_host=TEST_LOLACONFIG.DW_CREDENTIALS["host"],
remote_target_port=TEST_LOLACONFIG.DW_CREDENTIALS["port"],
local_bind_host=test_local_bind_host,
local_bind_port=test_local_bind_port,
)
tunnel_was_active = tunnel.is_active
local_bind_host_matches = (
get_local_bind_address_from_ssh_tunnel.run(tunnel)[0] == test_local_bind_host
)
local_bind_port_matches = (
get_local_bind_address_from_ssh_tunnel.run(tunnel)[1] == test_local_bind_port
)
close_ssh_tunnel.run(tunnel)
tunnel_is_no_longer_active = not tunnel.is_active
assert (
tunnel_was_active
and tunnel_is_no_longer_active
and local_bind_host_matches
and local_bind_port_matches
)
def test_connect_query_and_disconnect_from_mysql_with_tunnel():
test_local_bind_host = "127.0.0.1"
test_local_bind_port = 12345
tunnel = open_ssh_tunnel_with_s3_pkey.run(
s3_bucket_name=TEST_LOLACONFIG.S3_BUCKET_NAME,
ssh_tunnel_credentials=TEST_LOLACONFIG.SSH_TUNNEL_CREDENTIALS,
remote_target_host=TEST_LOLACONFIG.DW_CREDENTIALS["host"],
remote_target_port=TEST_LOLACONFIG.DW_CREDENTIALS["port"],
local_bind_host=test_local_bind_host,
local_bind_port=test_local_bind_port,
)
connection = connect_to_mysql.run(
mysql_credentials=TEST_LOLACONFIG.DW_CREDENTIALS,
overriding_host_and_port=get_local_bind_address_from_ssh_tunnel.run(
tunnel=tunnel
),
)
connection_was_open = connection.open
connection.cursor().execute("SELECT 1")
close_mysql_connection.run(connection=connection)
close_ssh_tunnel.run(tunnel=tunnel)
connection_is_closed = not connection.open
assert connection_was_open and connection_is_closed