import csv
import datetime
import logging
import os
import pathlib
from abc import ABC, abstractmethod
from collections.abc import Iterable

import psycopg2
from psycopg2 import sql

from xexe.constants import DWH_SCHEMA, DWH_TABLE
from xexe.exchange_rates import ExchangeRates

logger = logging.getLogger(__name__)


def _iter_exchange_rates(rates: ExchangeRates) -> Iterable:
    """Return the individual rate records held by *rates*.

    Every writer goes through this helper so they all read the container the
    same way. Each record is expected to expose ``from_currency.value``,
    ``to_currency.value``, ``rate.amount`` and ``rate_date``.

    NOTE(review): this reaches into the private ``_rate_index`` mapping of
    ExchangeRates; a public accessor on that class would be preferable.
    """
    return rates._rate_index.values()


class RateWriter(ABC):
    """Interface for sinks that persist a collection of exchange rates."""

    @abstractmethod
    def write_rates(self, rates: ExchangeRates) -> None:
        """Persist every rate contained in *rates*."""


class CSVRateWriter(RateWriter):
    """Write exchange rates to a CSV file, overwriting the file if it exists."""

    def __init__(self, output_file_path: pathlib.Path) -> None:
        super().__init__()
        self.output_file_path = output_file_path

    def write_rates(self, rates: ExchangeRates) -> None:
        """Write a header row followed by one row per rate.

        All rows share a single ``exported_at`` value. It is captured once,
        in local naive time at second precision.
        """
        exported_at = datetime.datetime.now().isoformat(timespec="seconds")
        # newline="" lets the csv module control line endings itself
        # (otherwise Windows output gets an extra blank line per row).
        with open(
            self.output_file_path, mode="w", newline="", encoding="utf-8"
        ) as csv_file:
            csv_writer = csv.writer(csv_file)
            csv_writer.writerow(
                ["from_currency", "to_currency", "rate", "rate_date", "exported_at"]
            )
            for rate in _iter_exchange_rates(rates):
                csv_writer.writerow(
                    [
                        rate.from_currency.value,
                        rate.to_currency.value,
                        rate.rate.amount,
                        rate.rate_date.strftime("%Y-%m-%d"),
                        exported_at,
                    ]
                )


class DWHRateWriter(RateWriter):
    """Upsert exchange rates into the Postgres table ``DWH_SCHEMA.DWH_TABLE``.

    The connection is opened from the ``DWH_*`` environment variables on
    construction. Schema and table permissions are probed immediately, so
    misconfiguration fails fast. Call :meth:`close`, or use the writer as a
    context manager, to release the connection.
    """

    def __init__(self) -> None:
        super().__init__()
        self.connection = self._create_connection()
        try:
            self._verify_prerequisites()
        except Exception:
            # Don't leak the connection when the writer cannot be constructed.
            self.connection.close()
            raise

    def __enter__(self) -> "DWHRateWriter":
        return self

    def __exit__(self, exc_type, exc, tb) -> None:
        self.close()

    def close(self) -> None:
        """Close the underlying database connection if it is still open."""
        if not self.connection.closed:
            self.connection.close()

    @staticmethod
    def _create_connection():
        """Open a psycopg2 connection from the DWH_* environment variables.

        Raises:
            KeyError: if any of DWH_HOST, DWH_PORT, DWH_DB, DWH_USER or
                DWH_PASSWORD is unset.
        """
        host = os.environ["DWH_HOST"]
        port = os.environ["DWH_PORT"]
        database = os.environ["DWH_DB"]
        user = os.environ["DWH_USER"]
        password = os.environ["DWH_PASSWORD"]
        connection = psycopg2.connect(
            host=host, port=port, database=database, user=user, password=password
        )
        return connection

    @staticmethod
    def _qualified_table() -> sql.Composed:
        """Return the safely quoted ``schema.table`` identifier of the rates table."""
        return sql.SQL("{}.{}").format(
            sql.Identifier(DWH_SCHEMA), sql.Identifier(DWH_TABLE)
        )

    def _verify_prerequisites(self) -> None:
        """Check that the schema exists and the table can be written or created.

        The checks run inside one transaction that is always rolled back, so
        nothing is persisted. Afterwards the connection is left in autocommit
        mode.

        Raises:
            Exception: if the schema is missing, if the existing table cannot
                be locked FOR UPDATE, or if a new table cannot be created in
                the schema.
        """
        schema = DWH_SCHEMA
        table = DWH_TABLE
        target = self._qualified_table()
        self.connection.autocommit = False
        try:
            with self.connection.cursor() as cursor:
                cursor.execute(
                    "SELECT schema_name FROM information_schema.schemata "
                    "WHERE schema_name = %s;",
                    [schema],
                )
                if cursor.fetchone() is None:
                    raise Exception(f"Schema '{schema}' does not exist.")

                cursor.execute(
                    "SELECT table_name FROM information_schema.tables "
                    "WHERE table_schema = %s AND table_name = %s;",
                    [schema, table],
                )
                table_exists = cursor.fetchone() is not None

                if table_exists:
                    # Locking a row FOR UPDATE requires write privileges on the table.
                    try:
                        cursor.execute(
                            sql.SQL("SELECT 1 FROM {} LIMIT 1 FOR UPDATE;").format(
                                target
                            )
                        )
                    except psycopg2.Error as e:
                        raise Exception(
                            f"Cannot write to the existing table '{table}' in schema '{schema}': {e}"
                        ) from e
                else:
                    # Probe CREATE privileges with a throwaway table; the
                    # rollback below undoes it even if DROP is never reached.
                    try:
                        cursor.execute(
                            sql.SQL("CREATE TABLE {} (id SERIAL PRIMARY KEY);").format(
                                target
                            )
                        )
                        cursor.execute(sql.SQL("DROP TABLE {};").format(target))
                    except psycopg2.Error as e:
                        raise Exception(
                            f"Cannot create a new table '{table}' in schema '{schema}': {e}"
                        ) from e
        finally:
            # Discard everything the probe did, on success and on failure alike.
            self.connection.rollback()
            # Other methods switch autocommit off only around their own
            # transactions and expect it to be on otherwise.
            self.connection.autocommit = True

    def _create_rates_table_if_not_exists(self) -> None:
        """Create the rates table if it is missing.

        The table's primary key is (from_currency, to_currency, rate_date_utc).

        Raises:
            Exception: wrapping the database error if the DDL fails.
        """
        create_table_query = sql.SQL(
            """
            CREATE TABLE IF NOT EXISTS {} (
                from_currency CHAR(3) NOT NULL,
                to_currency CHAR(3) NOT NULL,
                rate DECIMAL(23, 8) NOT NULL,
                rate_date_utc DATE NOT NULL,
                exported_at_utc TIMESTAMP NOT NULL
                    DEFAULT (CURRENT_TIMESTAMP AT TIME ZONE 'UTC'),
                PRIMARY KEY (from_currency, to_currency, rate_date_utc)
            );
            """
        ).format(self._qualified_table())

        with self.connection.cursor() as cursor:
            try:
                cursor.execute(create_table_query)
                self.connection.commit()
            except Exception as e:
                self.connection.rollback()
                raise Exception(
                    f"Failed to create table '{DWH_TABLE}' in schema '{DWH_SCHEMA}': {e}"
                ) from e

    def write_rates(self, rates: ExchangeRates) -> None:
        """Upsert every rate in *rates* within a single transaction.

        Existing rows with the same (from_currency, to_currency,
        rate_date_utc) get their rate and export timestamp overwritten.

        Raises:
            Exception: wrapping the underlying error. When it is raised, the
                whole batch has been rolled back.
        """
        self._create_rates_table_if_not_exists()

        # exported_at_utc is a TIMESTAMP without time zone; converting
        # explicitly stores UTC regardless of the session's TimeZone setting.
        insert_query = sql.SQL(
            """
            INSERT INTO {} (from_currency, to_currency, rate, rate_date_utc, exported_at_utc)
            VALUES (%s, %s, %s, %s, CURRENT_TIMESTAMP AT TIME ZONE 'UTC')
            ON CONFLICT (from_currency, to_currency, rate_date_utc)
            DO UPDATE SET rate = EXCLUDED.rate, exported_at_utc = EXCLUDED.exported_at_utc;
            """
        ).format(self._qualified_table())

        self.connection.autocommit = False
        try:
            with self.connection.cursor() as cursor:
                for rate in _iter_exchange_rates(rates):
                    cursor.execute(
                        insert_query,
                        (
                            rate.from_currency.value,
                            rate.to_currency.value,
                            rate.rate.amount,
                            rate.rate_date,
                        ),
                    )
            self.connection.commit()
        except Exception as e:
            self.connection.rollback()
            raise Exception(
                f"Failed to write rates to table '{DWH_TABLE}' in schema '{DWH_SCHEMA}': {e}"
            ) from e
        finally:
            self.connection.autocommit = True


def build_rate_writer(output: str) -> RateWriter:
    """Return the writer matching *output*.

    Args:
        output: a path ending in ``.csv`` (any case) for a CSV file, or the
            literal ``"dwh"`` for the data warehouse.

    Raises:
        ValueError: if *output* matches neither form.
    """
    output_path = pathlib.Path(output)
    if output_path.suffix.lower() == ".csv":
        logger.info("Creating CSV Rate Writer.")
        return CSVRateWriter(output_file_path=output_path)

    if output == "dwh":
        logger.info("Creating DWH Rate Writer.")
        return DWHRateWriter()

    raise ValueError(f"Don't know how to handle passed output: {output}")