Docstrings, typing, a bit of renaming and minor stuff.

This commit is contained in:
Pablo Martin 2022-07-21 12:24:35 +02:00
parent 347d3a969d
commit 7dba389eb5
3 changed files with 125 additions and 63 deletions

View file

@ -1,17 +1,23 @@
import time
import traceback
from typing import Union, Callable
from typing import Union
import mysql.connector.connection
import trino.dbapi
from trino.dbapi import connect
from trino.auth import BasicAuthentication
import mysql.connector
from connections import get_connection
def main(config: dict) -> None:
def run_measuring_session(config: dict) -> None:
"""
Complete session flow. Connect, test all queries.
:param config: the full config for the measuring session.
:return: None
"""
print("Starting the measuring session.")
connection = get_connection(config)
connection = get_connection(config["connection_details"])
for query_config in config["queries_to_measure"]:
try:
@ -27,67 +33,30 @@ def main(config: dict) -> None:
class TestableQuery:
"""
Simple object to hold the details of a query that will be measured.
"""
def __init__(self, name: str, query_string: str):
self.name = name
self.query_string = query_string
def measure_query_runtime(connection: trino.dbapi.Connection, query: TestableQuery):
def measure_query_runtime(
connection: Union[trino.dbapi.Connection, mysql.connector.MySQLConnection],
query_to_measure: TestableQuery,
) -> None:
"""
Execute a query against the given connection and print the time it took.
:param connection: a connection object capable of generating cursors.
:param query_to_measure: the query that will be measured.
:return: None
"""
start_time = time.time()
cur = connection.cursor()
cur.execute(query.query_string)
cur.execute(query_to_measure.query_string)
rows = cur.fetchall()
print(f"Query '{query.name}' took {int(time.time() - start_time)} seconds to run.")
def get_connection(config: dict) -> Union[trino.dbapi.Connection]:
connection_builder = pick_connection_builder(config["connection_details"]["engine"])
connection = connection_builder(config)
return connection
def get_possible_connection_builders() -> dict:
return {
"trino": get_connection_to_trino,
"mysql": get_connection_to_mysql,
}
def pick_connection_builder(connection_engine_name: str) -> Callable:
possible_connection_builders = get_possible_connection_builders()
try:
connection_builder = possible_connection_builders[connection_engine_name]
except KeyError:
raise ValueError(
f"Connection type {connection_engine_name} is unknown. Please review config."
)
return connection_builder
def get_connection_to_trino(config):
return connect(
host=config["connection_details"]["host"],
port=config["connection_details"]["port"],
user=config["connection_details"]["user"],
auth=BasicAuthentication(
config["connection_details"]["user"],
config["connection_details"]["password"],
),
http_scheme=config["connection_details"]["http_scheme"],
catalog=config["connection_details"]["catalog"],
schema=config["connection_details"]["schema"],
print(
f"Query '{query_to_measure.name}' took {int(time.time() - start_time)} seconds to run."
)
def get_connection_to_mysql(config) -> mysql.connector.connection.MySQLConnection:
connection = mysql.connector.connect(
host=config["connection_details"]["host"],
port=config["connection_details"]["port"],
user=config["connection_details"]["user"],
password=config["connection_details"]["password"],
database=config["connection_details"]["schema"],
)
return connection