data-xexe/xexe/rate_writing.py
2024-06-13 16:26:59 +02:00

212 lines
7.1 KiB
Python

import csv
import datetime
import logging
import os
import pathlib
from abc import ABC, abstractmethod
import psycopg2
from psycopg2 import sql
from xexe.constants import DWH_SCHEMA, DWH_TABLE
from xexe.exchange_rates import ExchangeRates
# Module-level logger. getLogger() is called without a name, so this is the
# root logger rather than a logger named after this module.
logger = logging.getLogger()
class RateWriter(ABC):
    """Abstract interface implemented by every exchange-rate output target."""

    @abstractmethod
    def write_rates(self, rates: ExchangeRates) -> None:
        """Write the given exchange rates to the writer's destination.

        Args:
            rates: The collection of exchange rates to write.
        """
class CSVRateWriter(RateWriter):
    """Rate writer that dumps exchange rates into a CSV file."""

    def __init__(self, output_file_path: pathlib.Path) -> None:
        """Remember the destination file; nothing is opened until writing.

        Args:
            output_file_path: Path of the CSV file to create or overwrite.
        """
        super().__init__()
        self.output_file_path = output_file_path

    def write_rates(self, rates: ExchangeRates) -> None:
        """Write a header row plus one row per rate, overwriting the file.

        Columns: from_currency, to_currency, rate, rate_date (YYYY-MM-DD)
        and exported_at (ISO 8601, second precision).

        Args:
            rates: The exchange rates to export.

        Raises:
            OSError: If the output file cannot be opened for writing.
        """
        # newline="" is required by the csv module: without it the writer's
        # "\r\n" row terminator is translated again on Windows ("\r\r\n").
        with open(
            self.output_file_path, mode="w", newline="", encoding="utf-8"
        ) as csv_file:
            csv_writer = csv.writer(csv_file)
            csv_writer.writerow(
                ["from_currency", "to_currency", "rate", "rate_date", "exported_at"]
            )
            # One timestamp shared by every row of this export. It is a naive
            # datetime taken from the local clock (datetime.now() without tz).
            exported_at = datetime.datetime.now().isoformat(timespec="seconds")
            # NOTE(review): this reaches into the private `_rate_index` and
            # reads `rate.rate.amount`, while DWHRateWriter iterates `rates`
            # directly and reads `rate.amount` - confirm which form matches
            # the ExchangeRates API.
            for rate in rates._rate_index.values():
                csv_writer.writerow(
                    [
                        rate.from_currency.value,
                        rate.to_currency.value,
                        rate.rate.amount,
                        rate.rate_date.strftime("%Y-%m-%d"),
                        exported_at,
                    ]
                )
class DWHRateWriter(RateWriter):
    """Rate writer that upserts exchange rates into the DWH table.

    The target is the table ``DWH_TABLE`` in schema ``DWH_SCHEMA``. Connection
    settings are read from the ``DWH_HOST``, ``DWH_PORT``, ``DWH_DB``,
    ``DWH_USER`` and ``DWH_PASSWORD`` environment variables.
    """

    def __init__(self) -> None:
        """Open the connection and verify that the target can be written.

        Raises:
            KeyError: If one of the ``DWH_*`` environment variables is unset.
            Exception: If connecting fails or a prerequisite check fails.
        """
        super().__init__()
        self.connection = self._create_connection()
        try:
            self._verify_prerequisites()
        except Exception:
            # The writer is unusable; do not leak the open connection.
            self.connection.close()
            raise

    def close(self) -> None:
        """Close the underlying database connection."""
        self.connection.close()

    @staticmethod
    def _create_connection():
        """Return a new psycopg2 connection built from the ``DWH_*`` env vars."""
        return psycopg2.connect(
            host=os.environ["DWH_HOST"],
            port=os.environ["DWH_PORT"],
            database=os.environ["DWH_DB"],
            user=os.environ["DWH_USER"],
            password=os.environ["DWH_PASSWORD"],
        )

    def _verify_prerequisites(self) -> None:
        """Check that the schema exists and the table is usable or creatable.

        If the table already exists it is probed with ``SELECT ... FOR UPDATE``;
        otherwise a throwaway table is created and dropped. Everything runs in
        one transaction that is always rolled back, so nothing is persisted.

        Raises:
            Exception: If the schema is missing, the existing table cannot be
                locked for update, or a new table cannot be created.
        """
        schema = DWH_SCHEMA
        table = DWH_TABLE
        # Run the probes inside an explicit transaction so they can be undone.
        self.connection.autocommit = False
        try:
            with self.connection.cursor() as cursor:
                cursor.execute(
                    sql.SQL(
                        "SELECT schema_name FROM information_schema.schemata WHERE schema_name = %s;"
                    ),
                    [schema],
                )
                if cursor.fetchone() is None:
                    raise Exception(f"Schema '{schema}' does not exist.")

                cursor.execute(
                    sql.SQL(
                        "SELECT table_name FROM information_schema.tables WHERE table_schema = %s AND table_name = %s;"
                    ),
                    [schema, table],
                )
                table_exists = cursor.fetchone() is not None

                if table_exists:
                    # Check if we can write to the table
                    try:
                        cursor.execute(
                            sql.SQL("SELECT 1 FROM {}.{} LIMIT 1 FOR UPDATE;").format(
                                sql.Identifier(schema), sql.Identifier(table)
                            )
                        )
                    except Exception as e:
                        raise Exception(
                            f"Cannot write to the existing table '{table}' in schema '{schema}': {e}"
                        ) from e
                else:
                    # Check if we can create a new table in the schema
                    try:
                        cursor.execute(
                            sql.SQL(
                                "CREATE TABLE {}.{} (id SERIAL PRIMARY KEY);"
                            ).format(sql.Identifier(schema), sql.Identifier(table))
                        )
                        cursor.execute(
                            sql.SQL("DROP TABLE {}.{};").format(
                                sql.Identifier(schema), sql.Identifier(table)
                            )
                        )
                    except Exception as e:
                        raise Exception(
                            f"Cannot create a new table '{table}' in schema '{schema}': {e}"
                        ) from e

                # Roll back the transaction to ensure no changes are persisted
                self.connection.rollback()
        except Exception:
            # Roll back any changes if there was an exception
            self.connection.rollback()
            raise
        finally:
            # Leave the connection in autocommit mode between operations.
            # (This is this class's convention; psycopg2's own default for a
            # new connection is autocommit off.)
            self.connection.autocommit = True

    def _create_rates_table_if_not_exists(self) -> None:
        """Create the rates table when it does not exist yet.

        Raises:
            Exception: If the ``CREATE TABLE`` statement fails.
        """
        # NOTE(review): the columns are named *_utc, but CURRENT_TIMESTAMP
        # stored in a TIMESTAMP column follows the session time zone - confirm
        # the DWH session runs in UTC.
        create_table_query = sql.SQL(
            """
CREATE TABLE IF NOT EXISTS {}.{} (
from_currency CHAR(3) NOT NULL,
to_currency CHAR(3) NOT NULL,
rate DECIMAL(19, 4) NOT NULL,
rate_date_utc DATE NOT NULL,
exported_at_utc TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
PRIMARY KEY (from_currency, to_currency, rate_date_utc)
);
"""
        ).format(sql.Identifier(DWH_SCHEMA), sql.Identifier(DWH_TABLE))
        try:
            with self.connection.cursor() as cursor:
                cursor.execute(create_table_query)
            self.connection.commit()
        except Exception as e:
            self.connection.rollback()
            raise Exception(
                f"Failed to create table '{DWH_TABLE}' in schema '{DWH_SCHEMA}': {e}"
            ) from e

    def write_rates(self, rates: ExchangeRates) -> None:
        """Upsert every rate into the DWH table in a single transaction.

        Rows are keyed on (from_currency, to_currency, rate_date_utc); on a
        conflict the rate and export timestamp are overwritten. Either all
        rows are committed or, on any error, none are.

        Args:
            rates: The exchange rates to write.

        Raises:
            Exception: If the table cannot be created or any insert fails.
        """
        self._create_rates_table_if_not_exists()
        insert_query = sql.SQL(
            """
INSERT INTO {}.{} (from_currency, to_currency, rate, rate_date_utc, exported_at_utc)
VALUES (%s, %s, %s, %s, CURRENT_TIMESTAMP)
ON CONFLICT (from_currency, to_currency, rate_date_utc)
DO UPDATE SET rate = EXCLUDED.rate, exported_at_utc = EXCLUDED.exported_at_utc;
"""
        ).format(sql.Identifier(DWH_SCHEMA), sql.Identifier(DWH_TABLE))
        # Group all inserts into one transaction so a failure leaves no
        # partially written batch behind.
        self.connection.autocommit = False
        try:
            with self.connection.cursor() as cursor:
                # NOTE(review): CSVRateWriter iterates
                # `rates._rate_index.values()` and reads `rate.rate.amount`;
                # here `rates` is iterated directly and `rate.amount` is read.
                # Confirm which form matches the ExchangeRates API.
                for rate in rates:
                    cursor.execute(
                        insert_query,
                        (
                            rate.from_currency.value,
                            rate.to_currency.value,
                            rate.amount,
                            rate.rate_date,
                        ),
                    )
            self.connection.commit()
        except Exception as e:
            self.connection.rollback()
            raise Exception(
                f"Failed to write rates to table '{DWH_TABLE}' in schema '{DWH_SCHEMA}': {e}"
            ) from e
        finally:
            # Back to autocommit mode, matching the state left by __init__.
            self.connection.autocommit = True
def build_rate_writer(output: str) -> RateWriter:
    """Build the rate writer that matches the requested output target.

    Args:
        output: Either a path whose suffix is exactly ``.csv`` (selects the
            CSV writer) or the literal string ``"dwh"`` (selects the DWH
            writer).

    Returns:
        A ``CSVRateWriter`` pointed at ``output`` or a ``DWHRateWriter``.

    Raises:
        ValueError: If ``output`` is neither a ``.csv`` path nor ``"dwh"``.
    """
    output_path = pathlib.Path(output)
    if output_path.suffix == ".csv":
        logger.info("Creating CSV Rate Writer.")
        # CSVRateWriter declares a pathlib.Path parameter, so hand it the
        # Path rather than the raw string.
        return CSVRateWriter(output_file_path=output_path)

    if output == "dwh":
        logger.info("Creating DWH Rate Writer.")
        return DWHRateWriter()

    raise ValueError(f"Don't know how to handle passed output: {output}")