tests passing
This commit is contained in:
parent
de8828a9cb
commit
515a4527af
4 changed files with 79 additions and 69 deletions
|
|
@ -78,50 +78,50 @@ def test_get_rates_dry_run_always_returns_42_as_rates():
|
||||||
|
|
||||||
runner = CliRunner()
|
runner = CliRunner()
|
||||||
|
|
||||||
with runner.isolated_filesystem():
|
# with runner.isolated_filesystem():
|
||||||
run_result = runner.invoke(
|
run_result = runner.invoke(
|
||||||
get_rates,
|
get_rates,
|
||||||
[
|
[
|
||||||
"--start-date",
|
"--start-date",
|
||||||
some_random_date.strftime("%Y-%m-%d"),
|
some_random_date.strftime("%Y-%m-%d"),
|
||||||
"--end-date",
|
"--end-date",
|
||||||
(some_random_date + datetime.timedelta(days=3)).strftime("%Y-%m-%d"),
|
(some_random_date + datetime.timedelta(days=3)).strftime("%Y-%m-%d"),
|
||||||
"--currencies",
|
"--currencies",
|
||||||
",".join(some_random_currencies),
|
",".join(some_random_currencies),
|
||||||
"--output",
|
"--output",
|
||||||
"test_output.csv",
|
"test_output.csv",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
assert run_result.exit_code == 0
|
assert run_result.exit_code == 0
|
||||||
|
|
||||||
with open("test_output.csv", newline="") as csv_file:
|
with open("test_output.csv", newline="") as csv_file:
|
||||||
reader = csv.DictReader(csv_file)
|
reader = csv.DictReader(csv_file)
|
||||||
rows = list(reader)
|
rows = list(reader)
|
||||||
|
|
||||||
|
# Ensure that the output contains the correct number of rows
|
||||||
|
expected_num_rows = 36
|
||||||
|
assert (
|
||||||
|
len(rows) == expected_num_rows
|
||||||
|
), f"Expected {expected_num_rows} rows, but got {len(rows)}"
|
||||||
|
|
||||||
|
# Check that all rows have the expected rate of 42, 1/42 or 1 and the correct dates
|
||||||
|
for row in rows:
|
||||||
|
assert row["rate"] in (
|
||||||
|
"42",
|
||||||
|
"0.024",
|
||||||
|
"0.02",
|
||||||
|
"0",
|
||||||
|
"1",
|
||||||
|
), f"Expected rate to be 42, 1/42 or 1, but got {row['rate']}"
|
||||||
|
assert row["rate_date"] in [
|
||||||
|
(some_random_date + datetime.timedelta(days=i)).strftime("%Y-%m-%d")
|
||||||
|
for i in range(4)
|
||||||
|
], f"Unexpected rate_date {row['rate_date']}"
|
||||||
|
|
||||||
# Ensure that the output contains the correct number of rows
|
|
||||||
expected_num_rows = 36
|
|
||||||
assert (
|
assert (
|
||||||
len(rows) == expected_num_rows
|
row["from_currency"] in some_random_currencies
|
||||||
), f"Expected {expected_num_rows} rows, but got {len(rows)}"
|
), f"Unexpected from_currency {row['from_currency']}"
|
||||||
|
assert (
|
||||||
# Check that all rows have the expected rate of 42, 1/42 or 1 and the correct dates
|
row["to_currency"] in some_random_currencies
|
||||||
for row in rows:
|
), f"Unexpected to_currency {row['to_currency']}"
|
||||||
assert row["rate"] in (
|
|
||||||
"42",
|
|
||||||
"0.024",
|
|
||||||
"0.02",
|
|
||||||
"0",
|
|
||||||
"1",
|
|
||||||
), f"Expected rate to be 42, 1/42 or 1, but got {row['rate']}"
|
|
||||||
assert row["rate_date"] in [
|
|
||||||
(some_random_date + datetime.timedelta(days=i)).strftime("%Y-%m-%d")
|
|
||||||
for i in range(4)
|
|
||||||
], f"Unexpected rate_date {row['rate_date']}"
|
|
||||||
|
|
||||||
assert (
|
|
||||||
row["from_currency"] in some_random_currencies
|
|
||||||
), f"Unexpected from_currency {row['from_currency']}"
|
|
||||||
assert (
|
|
||||||
row["to_currency"] in some_random_currencies
|
|
||||||
), f"Unexpected to_currency {row['to_currency']}"
|
|
||||||
|
|
|
||||||
|
|
@ -6,6 +6,7 @@ from typing import List
|
||||||
from money.currency import Currency
|
from money.currency import Currency
|
||||||
from xecd_rates_client import XecdClient
|
from xecd_rates_client import XecdClient
|
||||||
|
|
||||||
|
from xexe.constants import RATES_SOURCES
|
||||||
from xexe.exchange_rates import ExchangeRates, add_equal_rates, add_inverse_rates
|
from xexe.exchange_rates import ExchangeRates, add_equal_rates, add_inverse_rates
|
||||||
from xexe.rate_fetching import (
|
from xexe.rate_fetching import (
|
||||||
MockRateFetcher,
|
MockRateFetcher,
|
||||||
|
|
@ -13,7 +14,7 @@ from xexe.rate_fetching import (
|
||||||
XERateFetcher,
|
XERateFetcher,
|
||||||
build_rate_fetcher,
|
build_rate_fetcher,
|
||||||
)
|
)
|
||||||
from xexe.rate_writing import CSVRateWriter, RateWriter
|
from xexe.rate_writing import CSVRateWriter, RateWriter, build_rate_writer
|
||||||
from xexe.utils import DateRange, generate_currency_and_dates_combinations
|
from xexe.utils import DateRange, generate_currency_and_dates_combinations
|
||||||
|
|
||||||
logger = logging.getLogger()
|
logger = logging.getLogger()
|
||||||
|
|
@ -67,27 +68,32 @@ def run_get_rates(
|
||||||
) -> None:
|
) -> None:
|
||||||
logger.info("Getting rates")
|
logger.info("Getting rates")
|
||||||
|
|
||||||
process_state = GetRatesProcessState(output=output, ignore_warnings=ignore_warnings)
|
process_state = GetRatesProcessState(ignore_warnings=ignore_warnings)
|
||||||
|
|
||||||
rates = obtain_rates_from_source(
|
rates = obtain_rates_from_source(
|
||||||
process_state,
|
|
||||||
rates_source=rates_source,
|
rates_source=rates_source,
|
||||||
date_range=date_range,
|
date_range=date_range,
|
||||||
currencies=currencies,
|
currencies=currencies,
|
||||||
|
ignore_warnings=ignore_warnings,
|
||||||
)
|
)
|
||||||
logger.info("Rates obtained.")
|
logger.info("Rates obtained.")
|
||||||
|
|
||||||
if dry_run:
|
if dry_run:
|
||||||
logger.info("Dry run mode active. Not writing rates to output.")
|
logger.info("Dry run mode active. Not writing rates to output.")
|
||||||
return
|
return
|
||||||
write_rates_to_output(process_state, rates)
|
write_rates_to_output(rates, output)
|
||||||
logger.info("Rates written to output.")
|
logger.info("Rates written to output.")
|
||||||
|
|
||||||
|
|
||||||
def obtain_rates_from_source(
|
def obtain_rates_from_source(
|
||||||
process_state, rates_source: str, date_range: DateRange, currencies: List[Currency]
|
rates_source: str,
|
||||||
|
date_range: DateRange,
|
||||||
|
currencies: List[Currency],
|
||||||
|
ignore_warnings: bool,
|
||||||
) -> ExchangeRates:
|
) -> ExchangeRates:
|
||||||
rates_fetcher = build_rate_fetcher(rates_source)
|
rates_fetcher = build_rate_fetcher(
|
||||||
|
rates_source=rates_source, rate_sources_mapping=RATES_SOURCES
|
||||||
|
)
|
||||||
|
|
||||||
currency_and_date_combinations = generate_currency_and_dates_combinations(
|
currency_and_date_combinations = generate_currency_and_dates_combinations(
|
||||||
date_range=date_range, currencies=currencies
|
date_range=date_range, currencies=currencies
|
||||||
|
|
@ -96,7 +102,7 @@ def obtain_rates_from_source(
|
||||||
large_api_call_planned = (
|
large_api_call_planned = (
|
||||||
rates_fetcher.is_production_grade and len(currency_and_date_combinations) > 100
|
rates_fetcher.is_production_grade and len(currency_and_date_combinations) > 100
|
||||||
)
|
)
|
||||||
if large_api_call_planned and not process_state.ignore_warnings:
|
if large_api_call_planned and not ignore_warnings:
|
||||||
user_confirmation_string = "i understand"
|
user_confirmation_string = "i understand"
|
||||||
user_response = input(
|
user_response = input(
|
||||||
f"WARNING: you are about to execute a large call {len(currency_and_date_combinations)} to a metered API. Type '{user_confirmation_string}' to move forward: "
|
f"WARNING: you are about to execute a large call {len(currency_and_date_combinations)} to a metered API. Type '{user_confirmation_string}' to move forward: "
|
||||||
|
|
@ -130,25 +136,12 @@ def obtain_rates_from_source(
|
||||||
return rates
|
return rates
|
||||||
|
|
||||||
|
|
||||||
def write_rates_to_output(process_state, rates):
|
def write_rates_to_output(rates, output) -> None:
|
||||||
rates_writer = process_state.get_writer()
|
rates_writer = build_rate_writer(output)
|
||||||
logger.info("Attempting writing rates to output.")
|
logger.info("Attempting writing rates to output.")
|
||||||
rates_writer.write_rates(rates)
|
rates_writer.write_rates(rates)
|
||||||
|
|
||||||
|
|
||||||
class GetRatesProcessState:
|
class GetRatesProcessState:
|
||||||
def __init__(self, output: str, ignore_warnings: bool) -> None:
|
def __init__(self, ignore_warnings: bool) -> None:
|
||||||
self.writer = self._select_writer(output)
|
|
||||||
self.ignore_warnings = ignore_warnings
|
self.ignore_warnings = ignore_warnings
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def _select_writer(output: str) -> CSVRateWriter:
|
|
||||||
output_is_csv_file_path = bool(pathlib.Path(output).suffix == ".csv")
|
|
||||||
|
|
||||||
if output_is_csv_file_path:
|
|
||||||
return CSVRateWriter(output_file_path=output)
|
|
||||||
|
|
||||||
raise ValueError(f"Don't know how to handle passed output: {output}")
|
|
||||||
|
|
||||||
def get_writer(self) -> RateWriter:
|
|
||||||
return self.writer
|
|
||||||
|
|
|
||||||
|
|
@ -7,7 +7,6 @@ from money.currency import Currency, CurrencyHelper
|
||||||
from money.money import Money
|
from money.money import Money
|
||||||
from xecd_rates_client import XecdClient
|
from xecd_rates_client import XecdClient
|
||||||
|
|
||||||
from xexe.constants import RATES_SOURCES
|
|
||||||
from xexe.exchange_rates import ExchangeRate
|
from xexe.exchange_rates import ExchangeRate
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -82,5 +81,5 @@ class XERateFetcher(RateFetcher):
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def build_rate_fetcher(rates_source: str):
|
def build_rate_fetcher(rates_source: str, rate_sources_mapping):
|
||||||
return RATES_SOURCES[rates_source]()
|
return rate_sources_mapping[rates_source]()
|
||||||
|
|
|
||||||
|
|
@ -1,5 +1,6 @@
|
||||||
import csv
|
import csv
|
||||||
import datetime
|
import datetime
|
||||||
|
import logging
|
||||||
import os
|
import os
|
||||||
import pathlib
|
import pathlib
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
|
|
@ -10,6 +11,8 @@ from psycopg2 import sql
|
||||||
from xexe.constants import DWH_SCHEMA, DWH_TABLE
|
from xexe.constants import DWH_SCHEMA, DWH_TABLE
|
||||||
from xexe.exchange_rates import ExchangeRates
|
from xexe.exchange_rates import ExchangeRates
|
||||||
|
|
||||||
|
logger = logging.getLogger()
|
||||||
|
|
||||||
|
|
||||||
class RateWriter(ABC):
|
class RateWriter(ABC):
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
|
|
@ -130,3 +133,18 @@ class DWHRateWriter(RateWriter):
|
||||||
finally:
|
finally:
|
||||||
cursor.close()
|
cursor.close()
|
||||||
self.connection.autocommit = True # Reset autocommit to its default state
|
self.connection.autocommit = True # Reset autocommit to its default state
|
||||||
|
|
||||||
|
|
||||||
|
def build_rate_writer(output: str) -> RateWriter:
|
||||||
|
output_is_csv_file_path = bool(pathlib.Path(output).suffix == ".csv")
|
||||||
|
|
||||||
|
if output_is_csv_file_path:
|
||||||
|
logger.info("Creating CSV Rate Writer.")
|
||||||
|
return CSVRateWriter(output_file_path=output)
|
||||||
|
|
||||||
|
output_is_dwh = bool(output == "dwh")
|
||||||
|
if output_is_dwh:
|
||||||
|
logger.info("Creating DWH Rate Writer.")
|
||||||
|
return DWHRateWriter()
|
||||||
|
|
||||||
|
raise ValueError(f"Don't know how to handle passed output: {output}")
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue