tests passing
This commit is contained in:
parent
de8828a9cb
commit
515a4527af
4 changed files with 79 additions and 69 deletions
|
|
@ -78,50 +78,50 @@ def test_get_rates_dry_run_always_returns_42_as_rates():
|
|||
|
||||
runner = CliRunner()
|
||||
|
||||
with runner.isolated_filesystem():
|
||||
run_result = runner.invoke(
|
||||
get_rates,
|
||||
[
|
||||
"--start-date",
|
||||
some_random_date.strftime("%Y-%m-%d"),
|
||||
"--end-date",
|
||||
(some_random_date + datetime.timedelta(days=3)).strftime("%Y-%m-%d"),
|
||||
"--currencies",
|
||||
",".join(some_random_currencies),
|
||||
"--output",
|
||||
"test_output.csv",
|
||||
],
|
||||
)
|
||||
# with runner.isolated_filesystem():
|
||||
run_result = runner.invoke(
|
||||
get_rates,
|
||||
[
|
||||
"--start-date",
|
||||
some_random_date.strftime("%Y-%m-%d"),
|
||||
"--end-date",
|
||||
(some_random_date + datetime.timedelta(days=3)).strftime("%Y-%m-%d"),
|
||||
"--currencies",
|
||||
",".join(some_random_currencies),
|
||||
"--output",
|
||||
"test_output.csv",
|
||||
],
|
||||
)
|
||||
|
||||
assert run_result.exit_code == 0
|
||||
assert run_result.exit_code == 0
|
||||
|
||||
with open("test_output.csv", newline="") as csv_file:
|
||||
reader = csv.DictReader(csv_file)
|
||||
rows = list(reader)
|
||||
with open("test_output.csv", newline="") as csv_file:
|
||||
reader = csv.DictReader(csv_file)
|
||||
rows = list(reader)
|
||||
|
||||
# Ensure that the output contains the correct number of rows
|
||||
expected_num_rows = 36
|
||||
assert (
|
||||
len(rows) == expected_num_rows
|
||||
), f"Expected {expected_num_rows} rows, but got {len(rows)}"
|
||||
|
||||
# Check that all rows have the expected rate of 42, 1/42 or 1 and the correct dates
|
||||
for row in rows:
|
||||
assert row["rate"] in (
|
||||
"42",
|
||||
"0.024",
|
||||
"0.02",
|
||||
"0",
|
||||
"1",
|
||||
), f"Expected rate to be 42, 1/42 or 1, but got {row['rate']}"
|
||||
assert row["rate_date"] in [
|
||||
(some_random_date + datetime.timedelta(days=i)).strftime("%Y-%m-%d")
|
||||
for i in range(4)
|
||||
], f"Unexpected rate_date {row['rate_date']}"
|
||||
|
||||
# Ensure that the output contains the correct number of rows
|
||||
expected_num_rows = 36
|
||||
assert (
|
||||
len(rows) == expected_num_rows
|
||||
), f"Expected {expected_num_rows} rows, but got {len(rows)}"
|
||||
|
||||
# Check that all rows have the expected rate of 42, 1/42 or 1 and the correct dates
|
||||
for row in rows:
|
||||
assert row["rate"] in (
|
||||
"42",
|
||||
"0.024",
|
||||
"0.02",
|
||||
"0",
|
||||
"1",
|
||||
), f"Expected rate to be 42, 1/42 or 1, but got {row['rate']}"
|
||||
assert row["rate_date"] in [
|
||||
(some_random_date + datetime.timedelta(days=i)).strftime("%Y-%m-%d")
|
||||
for i in range(4)
|
||||
], f"Unexpected rate_date {row['rate_date']}"
|
||||
|
||||
assert (
|
||||
row["from_currency"] in some_random_currencies
|
||||
), f"Unexpected from_currency {row['from_currency']}"
|
||||
assert (
|
||||
row["to_currency"] in some_random_currencies
|
||||
), f"Unexpected to_currency {row['to_currency']}"
|
||||
row["from_currency"] in some_random_currencies
|
||||
), f"Unexpected from_currency {row['from_currency']}"
|
||||
assert (
|
||||
row["to_currency"] in some_random_currencies
|
||||
), f"Unexpected to_currency {row['to_currency']}"
|
||||
|
|
|
|||
|
|
@ -6,6 +6,7 @@ from typing import List
|
|||
from money.currency import Currency
|
||||
from xecd_rates_client import XecdClient
|
||||
|
||||
from xexe.constants import RATES_SOURCES
|
||||
from xexe.exchange_rates import ExchangeRates, add_equal_rates, add_inverse_rates
|
||||
from xexe.rate_fetching import (
|
||||
MockRateFetcher,
|
||||
|
|
@ -13,7 +14,7 @@ from xexe.rate_fetching import (
|
|||
XERateFetcher,
|
||||
build_rate_fetcher,
|
||||
)
|
||||
from xexe.rate_writing import CSVRateWriter, RateWriter
|
||||
from xexe.rate_writing import CSVRateWriter, RateWriter, build_rate_writer
|
||||
from xexe.utils import DateRange, generate_currency_and_dates_combinations
|
||||
|
||||
logger = logging.getLogger()
|
||||
|
|
@ -67,27 +68,32 @@ def run_get_rates(
|
|||
) -> None:
|
||||
logger.info("Getting rates")
|
||||
|
||||
process_state = GetRatesProcessState(output=output, ignore_warnings=ignore_warnings)
|
||||
process_state = GetRatesProcessState(ignore_warnings=ignore_warnings)
|
||||
|
||||
rates = obtain_rates_from_source(
|
||||
process_state,
|
||||
rates_source=rates_source,
|
||||
date_range=date_range,
|
||||
currencies=currencies,
|
||||
ignore_warnings=ignore_warnings,
|
||||
)
|
||||
logger.info("Rates obtained.")
|
||||
|
||||
if dry_run:
|
||||
logger.info("Dry run mode active. Not writing rates to output.")
|
||||
return
|
||||
write_rates_to_output(process_state, rates)
|
||||
write_rates_to_output(rates, output)
|
||||
logger.info("Rates written to output.")
|
||||
|
||||
|
||||
def obtain_rates_from_source(
|
||||
process_state, rates_source: str, date_range: DateRange, currencies: List[Currency]
|
||||
rates_source: str,
|
||||
date_range: DateRange,
|
||||
currencies: List[Currency],
|
||||
ignore_warnings: bool,
|
||||
) -> ExchangeRates:
|
||||
rates_fetcher = build_rate_fetcher(rates_source)
|
||||
rates_fetcher = build_rate_fetcher(
|
||||
rates_source=rates_source, rate_sources_mapping=RATES_SOURCES
|
||||
)
|
||||
|
||||
currency_and_date_combinations = generate_currency_and_dates_combinations(
|
||||
date_range=date_range, currencies=currencies
|
||||
|
|
@ -96,7 +102,7 @@ def obtain_rates_from_source(
|
|||
large_api_call_planned = (
|
||||
rates_fetcher.is_production_grade and len(currency_and_date_combinations) > 100
|
||||
)
|
||||
if large_api_call_planned and not process_state.ignore_warnings:
|
||||
if large_api_call_planned and not ignore_warnings:
|
||||
user_confirmation_string = "i understand"
|
||||
user_response = input(
|
||||
f"WARNING: you are about to execute a large call {len(currency_and_date_combinations)} to a metered API. Type '{user_confirmation_string}' to move forward: "
|
||||
|
|
@ -130,25 +136,12 @@ def obtain_rates_from_source(
|
|||
return rates
|
||||
|
||||
|
||||
def write_rates_to_output(process_state, rates):
|
||||
rates_writer = process_state.get_writer()
|
||||
def write_rates_to_output(rates, output) -> None:
|
||||
rates_writer = build_rate_writer(output)
|
||||
logger.info("Attempting writing rates to output.")
|
||||
rates_writer.write_rates(rates)
|
||||
|
||||
|
||||
class GetRatesProcessState:
|
||||
def __init__(self, output: str, ignore_warnings: bool) -> None:
|
||||
self.writer = self._select_writer(output)
|
||||
def __init__(self, ignore_warnings: bool) -> None:
|
||||
self.ignore_warnings = ignore_warnings
|
||||
|
||||
@staticmethod
|
||||
def _select_writer(output: str) -> CSVRateWriter:
|
||||
output_is_csv_file_path = bool(pathlib.Path(output).suffix == ".csv")
|
||||
|
||||
if output_is_csv_file_path:
|
||||
return CSVRateWriter(output_file_path=output)
|
||||
|
||||
raise ValueError(f"Don't know how to handle passed output: {output}")
|
||||
|
||||
def get_writer(self) -> RateWriter:
|
||||
return self.writer
|
||||
|
|
|
|||
|
|
@ -7,7 +7,6 @@ from money.currency import Currency, CurrencyHelper
|
|||
from money.money import Money
|
||||
from xecd_rates_client import XecdClient
|
||||
|
||||
from xexe.constants import RATES_SOURCES
|
||||
from xexe.exchange_rates import ExchangeRate
|
||||
|
||||
|
||||
|
|
@ -82,5 +81,5 @@ class XERateFetcher(RateFetcher):
|
|||
)
|
||||
|
||||
|
||||
def build_rate_fetcher(rates_source: str):
|
||||
return RATES_SOURCES[rates_source]()
|
||||
def build_rate_fetcher(rates_source: str, rate_sources_mapping):
|
||||
return rate_sources_mapping[rates_source]()
|
||||
|
|
|
|||
|
|
@ -1,5 +1,6 @@
|
|||
import csv
|
||||
import datetime
|
||||
import logging
|
||||
import os
|
||||
import pathlib
|
||||
from abc import ABC, abstractmethod
|
||||
|
|
@ -10,6 +11,8 @@ from psycopg2 import sql
|
|||
from xexe.constants import DWH_SCHEMA, DWH_TABLE
|
||||
from xexe.exchange_rates import ExchangeRates
|
||||
|
||||
logger = logging.getLogger()
|
||||
|
||||
|
||||
class RateWriter(ABC):
|
||||
@abstractmethod
|
||||
|
|
@ -130,3 +133,18 @@ class DWHRateWriter(RateWriter):
|
|||
finally:
|
||||
cursor.close()
|
||||
self.connection.autocommit = True # Reset autocommit to its default state
|
||||
|
||||
|
||||
def build_rate_writer(output: str) -> RateWriter:
|
||||
output_is_csv_file_path = bool(pathlib.Path(output).suffix == ".csv")
|
||||
|
||||
if output_is_csv_file_path:
|
||||
logger.info("Creating CSV Rate Writer.")
|
||||
return CSVRateWriter(output_file_path=output)
|
||||
|
||||
output_is_dwh = bool(output == "dwh")
|
||||
if output_is_dwh:
|
||||
logger.info("Creating DWH Rate Writer.")
|
||||
return DWHRateWriter()
|
||||
|
||||
raise ValueError(f"Don't know how to handle passed output: {output}")
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue