Adapting references
This commit is contained in:
parent
9feda795e9
commit
e4e1f42309
1 changed files with 43 additions and 34 deletions
|
|
@ -1,4 +1,6 @@
|
||||||
import datetime
|
import datetime
|
||||||
|
import os
|
||||||
|
from typing import Tuple
|
||||||
|
|
||||||
import prefect
|
import prefect
|
||||||
from prefect import task
|
from prefect import task
|
||||||
|
|
@ -6,7 +8,8 @@ from prefect.triggers import all_finished
|
||||||
from trino.auth import BasicAuthentication
|
from trino.auth import BasicAuthentication
|
||||||
import trino
|
import trino
|
||||||
import pymysql
|
import pymysql
|
||||||
import sshtunnel
|
import boto3
|
||||||
|
from sshtunnel import SSHTunnelForwarder
|
||||||
|
|
||||||
from lolafect.defaults import DEFAULT_TRINO_HTTP_SCHEME
|
from lolafect.defaults import DEFAULT_TRINO_HTTP_SCHEME
|
||||||
|
|
||||||
|
|
@ -62,38 +65,42 @@ def close_trino_connection(trino_connection: trino.dbapi.Connection) -> None:
|
||||||
|
|
||||||
|
|
||||||
@task(log_stdout=True, nout=2)
|
@task(log_stdout=True, nout=2)
|
||||||
def connect_to_dw(use_ssh_tunnel):
|
def connect_to_dw(
|
||||||
print("Connecting to DW")
|
mysql_credentials: dict,
|
||||||
import pymysql
|
use_ssh_tunnel: bool,
|
||||||
|
s3_bucket_name,
|
||||||
|
ssh_tunnel_credentials,
|
||||||
|
) -> Tuple[pymysql.Connection, SSHTunnelForwarder]:
|
||||||
|
logger = prefect.context.get("logger")
|
||||||
|
logger.info(
|
||||||
|
f"Connecting to MySQL at {mysql_credentials['host']}:{mysql_credentials['port']}."
|
||||||
|
)
|
||||||
|
|
||||||
mysql_host = LOLACONFIG.DW_CREDENTIALS["host"]
|
mysql_host = mysql_credentials["host"]
|
||||||
tunnel = None
|
tunnel = None
|
||||||
if use_ssh_tunnel:
|
if use_ssh_tunnel:
|
||||||
print("Going to open an SSH tunnel.")
|
logger.info("Going to open an SSH tunnel.")
|
||||||
from sshtunnel import SSHTunnelForwarder
|
|
||||||
|
|
||||||
temp_file_path = "temp"
|
temp_file_path = "temp"
|
||||||
try:
|
try:
|
||||||
boto3.client("s3").download_file(
|
boto3.client("s3").download_file(
|
||||||
LOLACONFIG.S3_BUCKET_NAME,
|
s3_bucket_name,
|
||||||
LOLACONFIG.SSH_TUNNEL_CREDENTIALS["path_to_ssh_pkey"],
|
ssh_tunnel_credentials["path_to_ssh_pkey"],
|
||||||
temp_file_path,
|
temp_file_path,
|
||||||
)
|
)
|
||||||
tunnel = SSHTunnelForwarder(
|
tunnel = SSHTunnelForwarder(
|
||||||
ssh_host=(
|
ssh_host=(
|
||||||
LOLACONFIG.SSH_TUNNEL_CREDENTIALS["ssh_jumphost"],
|
ssh_tunnel_credentials["ssh_jumphost"],
|
||||||
LOLACONFIG.SSH_TUNNEL_CREDENTIALS["ssh_port"],
|
ssh_tunnel_credentials["ssh_port"],
|
||||||
),
|
),
|
||||||
ssh_username=LOLACONFIG.SSH_TUNNEL_CREDENTIALS["ssh_username"],
|
ssh_username=ssh_tunnel_credentials["ssh_username"],
|
||||||
ssh_pkey=temp_file_path,
|
ssh_pkey=temp_file_path,
|
||||||
remote_bind_address=(
|
remote_bind_address=(
|
||||||
LOLACONFIG.DW_CREDENTIALS["host"],
|
mysql_credentials["host"],
|
||||||
LOLACONFIG.DW_CREDENTIALS["port"],
|
mysql_credentials["port"],
|
||||||
),
|
),
|
||||||
local_bind_address=("127.0.0.1", LOLACONFIG.DW_CREDENTIALS["port"]),
|
local_bind_address=("127.0.0.1", mysql_credentials["port"]),
|
||||||
ssh_private_key_password=LOLACONFIG.SSH_TUNNEL_CREDENTIALS[
|
ssh_private_key_password=mysql_credentials["ssh_pkey_password"],
|
||||||
"ssh_pkey_password"
|
|
||||||
],
|
|
||||||
)
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise e
|
raise e
|
||||||
|
|
@ -102,43 +109,45 @@ def connect_to_dw(use_ssh_tunnel):
|
||||||
os.remove(temp_file_path)
|
os.remove(temp_file_path)
|
||||||
|
|
||||||
tunnel.start()
|
tunnel.start()
|
||||||
print("SSH tunnel is now open.")
|
logger.info(
|
||||||
|
f"SSH tunnel is now open and listening at{mysql_credentials['host']}:{mysql_credentials['port']}."
|
||||||
|
)
|
||||||
mysql_host = "127.0.0.1"
|
mysql_host = "127.0.0.1"
|
||||||
|
|
||||||
db_connection = pymysql.connect(
|
db_connection = pymysql.connect(
|
||||||
host=mysql_host,
|
host=mysql_host,
|
||||||
port=LOLACONFIG.DW_CREDENTIALS["port"],
|
port=mysql_credentials["port"],
|
||||||
user=LOLACONFIG.DW_CREDENTIALS["user"],
|
user=mysql_credentials["user"],
|
||||||
password=LOLACONFIG.DW_CREDENTIALS["password"],
|
password=mysql_credentials["password"],
|
||||||
database="dw_xl",
|
database="dw_xl",
|
||||||
)
|
)
|
||||||
|
|
||||||
# Customizing this attributes to retrieve them later in the flow
|
# Customizing this attributes to retrieve them later in the flow
|
||||||
db_connection.raw_user = LOLACONFIG.DW_CREDENTIALS["user"]
|
db_connection.raw_user = mysql_credentials["user"]
|
||||||
db_connection.raw_password = LOLACONFIG.DW_CREDENTIALS["password"]
|
db_connection.raw_password = mysql_credentials["password"]
|
||||||
|
|
||||||
print("Connected to DW.")
|
logger.info(f"Connected to MySQL at {mysql_credentials['host']}:{mysql_credentials['port']}.")
|
||||||
|
|
||||||
return db_connection, tunnel
|
return db_connection, tunnel
|
||||||
|
|
||||||
|
|
||||||
@task(log_stdout=True, trigger=all_finished)
|
@task(trigger=all_finished)
|
||||||
def close_dw_connection(dw_connection):
|
def close_dw_connection(dw_connection: pymysql.Connection) -> None:
|
||||||
import pymysql
|
logger = prefect.context.get("logger")
|
||||||
|
|
||||||
if isinstance(dw_connection, pymysql.Connection):
|
if isinstance(dw_connection, pymysql.Connection):
|
||||||
dw_connection.close()
|
dw_connection.close()
|
||||||
print("DW connection closed successfully.")
|
logger.info("DW connection closed successfully.")
|
||||||
return
|
return
|
||||||
print("No connection received.")
|
logger.info("No connection received.")
|
||||||
|
|
||||||
|
|
||||||
@task(log_stdout=True, trigger=all_finished)
|
@task(log_stdout=True, trigger=all_finished)
|
||||||
def close_ssh_tunnel(ssh_tunnel):
|
def close_ssh_tunnel(ssh_tunnel: SSHTunnelForwarder) -> None:
|
||||||
from sshtunnel import SSHTunnelForwarder
|
logger = prefect.context.get("logger")
|
||||||
|
|
||||||
if isinstance(ssh_tunnel, SSHTunnelForwarder):
|
if isinstance(ssh_tunnel, SSHTunnelForwarder):
|
||||||
ssh_tunnel.stop()
|
ssh_tunnel.stop()
|
||||||
print("SSH tunnel closed successfully.")
|
logger.info("SSH tunnel closed successfully.")
|
||||||
return
|
return
|
||||||
print("No connection received.")
|
logger.info("No connection received.")
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue