Post code-cleanup

This commit is contained in:
lwark
2025-01-21 15:18:37 -06:00
parent bf711369c6
commit bc4af61f5f
26 changed files with 546 additions and 377 deletions

View File

@@ -10,6 +10,8 @@ from json import JSONDecodeError
import logging, re, yaml, sys, os, stat, platform, getpass, json, numpy as np, pandas as pd
from threading import Thread
from inspect import getmembers, isfunction, stack
from types import GeneratorType
from dateutil.easter import easter
from jinja2 import Environment, FileSystemLoader
from logging import handlers
@@ -18,7 +20,7 @@ from sqlalchemy.orm import Session
from sqlalchemy import create_engine, text, MetaData
from pydantic import field_validator, BaseModel, Field
from pydantic_settings import BaseSettings, SettingsConfigDict
from typing import Any, Tuple, Literal, List
from typing import Any, Tuple, Literal, List, Generator
from __init__ import project_path
from configparser import ConfigParser
from tkinter import Tk # NOTE: This is for choosing database path before app is created.
@@ -38,7 +40,7 @@ if platform.system() == "Windows":
logger.info(f"Got platform Windows, config_dir: {os_config_dir}")
else:
os_config_dir = ".config"
logger.info(f"Got platform other, config_dir: {os_config_dir}")
logger.info(f"Got platform {platform.system()}, config_dir: {os_config_dir}")
main_aux_dir = Path.home().joinpath(f"{os_config_dir}/submissions")
@@ -58,7 +60,7 @@ main_form_style = '''
page_size = 250
def divide_chunks(input_list: list, chunk_count: int):
def divide_chunks(input_list: list, chunk_count: int) -> Generator[Any, Any, None]:
"""
Divides a list into {chunk_count} equal parts
@@ -179,7 +181,7 @@ def check_not_nan(cell_contents) -> bool:
return False
def convert_nans_to_nones(input_str) -> str | None:
def convert_nans_to_nones(input_str:str) -> str | None:
"""
Get rid of various "nan", "NAN", "NaN", etc/
@@ -289,12 +291,10 @@ class Settings(BaseSettings, extra="allow"):
@classmethod
def set_schema(cls, value):
if value is None:
# print("No value for dir path")
if check_if_app():
alembic_path = Path(sys._MEIPASS).joinpath("files", "alembic.ini")
else:
alembic_path = project_path.joinpath("alembic.ini")
# print(f"Getting alembic path: {alembic_path}")
value = cls.get_alembic_db_path(alembic_path=alembic_path, mode='schema')
if value is None:
value = "sqlite"
@@ -321,14 +321,11 @@ class Settings(BaseSettings, extra="allow"):
if value is None:
match values.data['database_schema']:
case "sqlite":
# print("No value for dir path")
if check_if_app():
alembic_path = Path(sys._MEIPASS).joinpath("files", "alembic.ini")
else:
alembic_path = project_path.joinpath("alembic.ini")
# print(f"Getting alembic path: {alembic_path}")
value = cls.get_alembic_db_path(alembic_path=alembic_path, mode='path').parent
# print(f"Using {value}")
case _:
Tk().withdraw() # we don't want a full GUI, so keep the root window from appearing
value = Path(askdirectory(
@@ -340,9 +337,7 @@ class Settings(BaseSettings, extra="allow"):
except AttributeError:
check = False
if not check:
# print(f"No directory found, using Documents/submissions")
value.mkdir(exist_ok=True)
# print(f"Final return of directory_path: {value}")
return value
@field_validator('database_path', mode="before")
@@ -360,7 +355,6 @@ class Settings(BaseSettings, extra="allow"):
alembic_path = Path(sys._MEIPASS).joinpath("files", "alembic.ini")
else:
alembic_path = project_path.joinpath("alembic.ini")
# print(f"Getting alembic path: {alembic_path}")
value = cls.get_alembic_db_path(alembic_path=alembic_path, mode='path').parent
return value
@@ -372,7 +366,6 @@ class Settings(BaseSettings, extra="allow"):
alembic_path = Path(sys._MEIPASS).joinpath("files", "alembic.ini")
else:
alembic_path = project_path.joinpath("alembic.ini")
# print(f"Getting alembic path: {alembic_path}")
value = cls.get_alembic_db_path(alembic_path=alembic_path, mode='path').stem
return value
@@ -384,9 +377,7 @@ class Settings(BaseSettings, extra="allow"):
alembic_path = Path(sys._MEIPASS).joinpath("files", "alembic.ini")
else:
alembic_path = project_path.joinpath("alembic.ini")
# print(f"Getting alembic path: {alembic_path}")
value = cls.get_alembic_db_path(alembic_path=alembic_path, mode='user')
# print(f"Got {value} for user")
return value
@field_validator("database_password", mode='before')
@@ -397,9 +388,7 @@ class Settings(BaseSettings, extra="allow"):
alembic_path = Path(sys._MEIPASS).joinpath("files", "alembic.ini")
else:
alembic_path = project_path.joinpath("alembic.ini")
# print(f"Getting alembic path: {alembic_path}")
value = cls.get_alembic_db_path(alembic_path=alembic_path, mode='pass')
# print(f"Got {value} for pass")
return value
@field_validator('database_session', mode="before")
@@ -421,7 +410,6 @@ class Settings(BaseSettings, extra="allow"):
"{{ values['database_schema'] }}://{{ value }}/{{ db_name }}?driver=ODBC+Driver+18+for+SQL+Server&TrustServerCertificate=yes&Trusted_Connection=yes"
)
case _:
# print(pprint.pprint(values.data))
tmp = jinja_template_loading().from_string(
"{% if values['database_user'] %}{{ values['database_user'] }}{% if values['database_password'] %}:{{ values['database_password'] }}{% endif %}{% endif %}@{{ values['database_path'] }}")
value = tmp.render(values=values.data)
@@ -444,7 +432,6 @@ class Settings(BaseSettings, extra="allow"):
super().__init__(*args, **kwargs)
self.set_from_db()
self.set_scripts()
# pprint(f"User settings:\n{self.__dict__}")
def set_from_db(self):
if 'pytest' in sys.modules:
@@ -453,11 +440,8 @@ class Settings(BaseSettings, extra="allow"):
teardown_scripts=dict(goodbye=None)
)
else:
# print(f"Hello from database settings getter.")
# print(self.__dict__)
session = self.database_session
metadata = MetaData()
# print(self.database_session.get_bind())
try:
metadata.reflect(bind=session.get_bind())
except AttributeError as e:
@@ -467,7 +451,6 @@ class Settings(BaseSettings, extra="allow"):
print(f"Couldn't find _configitems in {metadata.tables.keys()}.")
return
config_items = session.execute(text("SELECT * FROM _configitem")).all()
# print(f"Config: {pprint.pprint(config_items)}")
output = {}
for item in config_items:
try:
@@ -488,6 +471,7 @@ class Settings(BaseSettings, extra="allow"):
p = Path(__file__).parents[2].joinpath("scripts").absolute()
if p.__str__() not in sys.path:
sys.path.append(p.__str__())
# NOTE: Get all .py files that don't have __ in them.
modules = p.glob("[!__]*.py")
for module in modules:
mod = importlib.import_module(module.stem)
@@ -495,6 +479,7 @@ class Settings(BaseSettings, extra="allow"):
name = function[0]
func = function[1]
# NOTE: assign function based on its name being in config: startup/teardown
# NOTE: scripts must be registered using {name: Null} in the database
if name in self.startup_scripts.keys():
self.startup_scripts[name] = func
if name in self.teardown_scripts.keys():
@@ -543,14 +528,12 @@ class Settings(BaseSettings, extra="allow"):
try:
return url[:url.index("@")].split(":")[0]
except (IndexError, ValueError) as e:
# print(f"Error on user: {e}")
return None
case "pass":
url = re.sub(r"^.*//", "", url)
try:
return url[:url.index("@")].split(":")[1]
except (IndexError, ValueError) as e:
# print(f"Error on user: {e}")
return None
def save(self, settings_path: Path):
@@ -592,7 +575,6 @@ def get_config(settings_path: Path | str | None = None) -> Settings:
def join(loader, node):
seq = loader.construct_sequence(node)
return ''.join([str(i) for i in seq])
# NOTE: register the tag handler
yaml.add_constructor('!join', join)
# NOTE: make directories
@@ -624,7 +606,6 @@ def get_config(settings_path: Path | str | None = None) -> Settings:
# NOTE: copy settings to config directory
settings = Settings(**default_settings)
settings.save(settings_path=CONFIGDIR.joinpath("config.yml"))
# print(f"Default settings: {pprint.pprint(settings.__dict__)}")
return settings
else:
# NOTE: check if user defined path is directory
@@ -829,10 +810,23 @@ def setup_lookup(func):
elif v is not None:
sanitized_kwargs[k] = v
return func(*args, **sanitized_kwargs)
return wrapper
def get_application_from_parent(widget):
try:
return widget.app
except AttributeError:
logger.info("Using recursion to get application object.")
from frontend.widgets.app import App
while not isinstance(widget, App):
try:
widget = widget.parent()
except AttributeError:
return widget
return widget
class Result(BaseModel, arbitrary_types_allowed=True):
owner: str = Field(default="", validate_default=True)
code: int = Field(default=0)
@@ -937,20 +931,20 @@ def rreplace(s: str, old: str, new: str) -> str:
return (s[::-1].replace(old[::-1], new[::-1], 1))[::-1]
def remove_key_from_list_of_dicts(input: list, key: str) -> list:
def remove_key_from_list_of_dicts(input_list: list, key: str) -> list:
"""
Removes a key from all dictionaries in a list of dictionaries
Args:
input (list): Input list of dicts
input_list (list): Input list of dicts
key (str): Name of key to remove.
Returns:
list: List of updated dictionaries
"""
for item in input:
for item in input_list:
del item[key]
return input
return input_list
def yaml_regex_creator(loader, node):
@@ -963,6 +957,7 @@ def yaml_regex_creator(loader, node):
def super_splitter(ins_str: str, substring: str, idx: int) -> str:
"""
Splits string on substring at index
Args:
ins_str (str): input string
@@ -978,6 +973,20 @@ def super_splitter(ins_str: str, substring: str, idx: int) -> str:
return ins_str
def is_developer() -> bool:
"""
Checks if user is in list of super users
Returns:
bool: True if yes, False if no.
"""
try:
check = getpass.getuser() in ctx.super_users
except:
check = False
return check
def is_power_user() -> bool:
"""
Checks if user is in list of power users
@@ -1000,21 +1009,49 @@ def check_authorization(func):
func (function): Function to be used.
"""
@wraps(func)
@report_result
def wrapper(*args, **kwargs):
logger.info(f"Checking authorization")
if is_power_user():
error_msg = f"User {getpass.getuser()} is not authorized for this function."
auth_func = is_power_user
if auth_func():
return func(*args, **kwargs)
else:
logger.error(f"User {getpass.getuser()} is not authorized for this function.")
logger.error(error_msg)
report = Report()
report.add_result(
Result(owner=func.__str__(), code=1, msg="This user does not have permission for this function.",
status="warning"))
Result(owner=func.__str__(), code=1, msg=error_msg, status="warning"))
return report
return wrapper
def under_development(func):
"""
Decorator to check if user is authorized to access function
Args:
func (function): Function to be used.
"""
@wraps(func)
@report_result
def wrapper(*args, **kwargs):
logger.warning(f"This feature is under development")
if is_developer():
return func(*args, **kwargs)
else:
error_msg = f"User {getpass.getuser()} is not authorized for this function."
logger.error(error_msg)
report = Report()
report.add_result(
Result(owner=func.__str__(), code=1, msg=error_msg,
status="warning"))
return report
return wrapper
def report_result(func):
"""
Decorator to display any reports returned from a function.
@@ -1036,14 +1073,9 @@ def report_result(func):
case Report():
report = output
case tuple():
# try:
report = next((item for item in output if isinstance(item, Report)), None)
# except IndexError:
# report = None
case _:
report = Report()
# return report
# logger.info(f"Got report: {report}")
try:
results = report.results
except AttributeError:
@@ -1058,13 +1090,11 @@ def report_result(func):
logger.error(result.msg)
if output:
true_output = tuple(item for item in output if not isinstance(item, Report))
# logger.debug(f"True output: {true_output}")
if len(true_output) == 1:
true_output = true_output[0]
else:
true_output = None
return true_output
return wrapper
@@ -1084,20 +1114,19 @@ def create_holidays_for_year(year: int | None = None) -> List[date]:
offset = -d.weekday() # weekday == 0 means Monday
output = d + timedelta(offset)
return output.date()
if not year:
year = date.today().year
# Includes New Year's day for next year.
# NOTE: Includes New Year's day for next year.
holidays = [date(year, 1, 1), date(year, 7, 1), date(year, 9, 30),
date(year, 11, 11), date(year, 12, 25), date(year, 12, 26),
date(year + 1, 1, 1)]
# Labour Day
# NOTE: Labour Day
holidays.append(find_nth_monday(year, 9))
# Thanksgiving
# NOTE: Thanksgiving
holidays.append(find_nth_monday(year, 10, occurence=2))
# Victoria Day
# NOTE: Victoria Day
holidays.append(find_nth_monday(year, 5, day=25))
# Easter, etc
# NOTE: Easter, etc
holidays.append(easter(year) - timedelta(days=2))
holidays.append(easter(year) + timedelta(days=1))
return sorted(holidays)
@@ -1107,8 +1136,7 @@ class classproperty(property):
def __get__(self, owner_self, owner_cls):
return self.fget(owner_cls)
# NOTE: Monkey patching... hooray!
builtins.classproperty = classproperty
ctx = get_config(None)