Simple copy from project functions
This commit is contained in:
parent
fae5888c52
commit
9feda795e9
1 changed files with 85 additions and 0 deletions
|
|
@ -5,6 +5,8 @@ from prefect import task
|
|||
from prefect.triggers import all_finished
|
||||
from trino.auth import BasicAuthentication
|
||||
import trino
|
||||
import pymysql
|
||||
import sshtunnel
|
||||
|
||||
from lolafect.defaults import DEFAULT_TRINO_HTTP_SCHEME
|
||||
|
||||
|
|
@ -57,3 +59,86 @@ def close_trino_connection(trino_connection: trino.dbapi.Connection) -> None:
|
|||
logger.info("Trino connection closed successfully.")
|
||||
return
|
||||
logger.info("No connection received.")
|
||||
|
||||
|
||||
@task(log_stdout=True, nout=2)
|
||||
def connect_to_dw(use_ssh_tunnel):
|
||||
print("Connecting to DW")
|
||||
import pymysql
|
||||
|
||||
mysql_host = LOLACONFIG.DW_CREDENTIALS["host"]
|
||||
tunnel = None
|
||||
if use_ssh_tunnel:
|
||||
print("Going to open an SSH tunnel.")
|
||||
from sshtunnel import SSHTunnelForwarder
|
||||
|
||||
temp_file_path = "temp"
|
||||
try:
|
||||
boto3.client("s3").download_file(
|
||||
LOLACONFIG.S3_BUCKET_NAME,
|
||||
LOLACONFIG.SSH_TUNNEL_CREDENTIALS["path_to_ssh_pkey"],
|
||||
temp_file_path,
|
||||
)
|
||||
tunnel = SSHTunnelForwarder(
|
||||
ssh_host=(
|
||||
LOLACONFIG.SSH_TUNNEL_CREDENTIALS["ssh_jumphost"],
|
||||
LOLACONFIG.SSH_TUNNEL_CREDENTIALS["ssh_port"],
|
||||
),
|
||||
ssh_username=LOLACONFIG.SSH_TUNNEL_CREDENTIALS["ssh_username"],
|
||||
ssh_pkey=temp_file_path,
|
||||
remote_bind_address=(
|
||||
LOLACONFIG.DW_CREDENTIALS["host"],
|
||||
LOLACONFIG.DW_CREDENTIALS["port"],
|
||||
),
|
||||
local_bind_address=("127.0.0.1", LOLACONFIG.DW_CREDENTIALS["port"]),
|
||||
ssh_private_key_password=LOLACONFIG.SSH_TUNNEL_CREDENTIALS[
|
||||
"ssh_pkey_password"
|
||||
],
|
||||
)
|
||||
except Exception as e:
|
||||
raise e
|
||||
finally:
|
||||
# No matter what happens above, we always must delete the temp copy of the key
|
||||
os.remove(temp_file_path)
|
||||
|
||||
tunnel.start()
|
||||
print("SSH tunnel is now open.")
|
||||
mysql_host = "127.0.0.1"
|
||||
|
||||
db_connection = pymysql.connect(
|
||||
host=mysql_host,
|
||||
port=LOLACONFIG.DW_CREDENTIALS["port"],
|
||||
user=LOLACONFIG.DW_CREDENTIALS["user"],
|
||||
password=LOLACONFIG.DW_CREDENTIALS["password"],
|
||||
database="dw_xl",
|
||||
)
|
||||
|
||||
# Customizing this attributes to retrieve them later in the flow
|
||||
db_connection.raw_user = LOLACONFIG.DW_CREDENTIALS["user"]
|
||||
db_connection.raw_password = LOLACONFIG.DW_CREDENTIALS["password"]
|
||||
|
||||
print("Connected to DW.")
|
||||
|
||||
return db_connection, tunnel
|
||||
|
||||
|
||||
@task(log_stdout=True, trigger=all_finished)
|
||||
def close_dw_connection(dw_connection):
|
||||
import pymysql
|
||||
|
||||
if isinstance(dw_connection, pymysql.Connection):
|
||||
dw_connection.close()
|
||||
print("DW connection closed successfully.")
|
||||
return
|
||||
print("No connection received.")
|
||||
|
||||
|
||||
@task(log_stdout=True, trigger=all_finished)
|
||||
def close_ssh_tunnel(ssh_tunnel):
|
||||
from sshtunnel import SSHTunnelForwarder
|
||||
|
||||
if isinstance(ssh_tunnel, SSHTunnelForwarder):
|
||||
ssh_tunnel.stop()
|
||||
print("SSH tunnel closed successfully.")
|
||||
return
|
||||
print("No connection received.")
|
||||
Loading…
Add table
Add a link
Reference in a new issue