Post code-cleanup
This commit is contained in:
@@ -10,6 +10,8 @@ from json import JSONDecodeError
|
||||
import logging, re, yaml, sys, os, stat, platform, getpass, json, numpy as np, pandas as pd
|
||||
from threading import Thread
|
||||
from inspect import getmembers, isfunction, stack
|
||||
from types import GeneratorType
|
||||
|
||||
from dateutil.easter import easter
|
||||
from jinja2 import Environment, FileSystemLoader
|
||||
from logging import handlers
|
||||
@@ -18,7 +20,7 @@ from sqlalchemy.orm import Session
|
||||
from sqlalchemy import create_engine, text, MetaData
|
||||
from pydantic import field_validator, BaseModel, Field
|
||||
from pydantic_settings import BaseSettings, SettingsConfigDict
|
||||
from typing import Any, Tuple, Literal, List
|
||||
from typing import Any, Tuple, Literal, List, Generator
|
||||
from __init__ import project_path
|
||||
from configparser import ConfigParser
|
||||
from tkinter import Tk # NOTE: This is for choosing database path before app is created.
|
||||
@@ -38,7 +40,7 @@ if platform.system() == "Windows":
|
||||
logger.info(f"Got platform Windows, config_dir: {os_config_dir}")
|
||||
else:
|
||||
os_config_dir = ".config"
|
||||
logger.info(f"Got platform other, config_dir: {os_config_dir}")
|
||||
logger.info(f"Got platform {platform.system()}, config_dir: {os_config_dir}")
|
||||
|
||||
main_aux_dir = Path.home().joinpath(f"{os_config_dir}/submissions")
|
||||
|
||||
@@ -58,7 +60,7 @@ main_form_style = '''
|
||||
page_size = 250
|
||||
|
||||
|
||||
def divide_chunks(input_list: list, chunk_count: int):
|
||||
def divide_chunks(input_list: list, chunk_count: int) -> Generator[Any, Any, None]:
|
||||
"""
|
||||
Divides a list into {chunk_count} equal parts
|
||||
|
||||
@@ -179,7 +181,7 @@ def check_not_nan(cell_contents) -> bool:
|
||||
return False
|
||||
|
||||
|
||||
def convert_nans_to_nones(input_str) -> str | None:
|
||||
def convert_nans_to_nones(input_str:str) -> str | None:
|
||||
"""
|
||||
Get rid of various "nan", "NAN", "NaN", etc/
|
||||
|
||||
@@ -289,12 +291,10 @@ class Settings(BaseSettings, extra="allow"):
|
||||
@classmethod
|
||||
def set_schema(cls, value):
|
||||
if value is None:
|
||||
# print("No value for dir path")
|
||||
if check_if_app():
|
||||
alembic_path = Path(sys._MEIPASS).joinpath("files", "alembic.ini")
|
||||
else:
|
||||
alembic_path = project_path.joinpath("alembic.ini")
|
||||
# print(f"Getting alembic path: {alembic_path}")
|
||||
value = cls.get_alembic_db_path(alembic_path=alembic_path, mode='schema')
|
||||
if value is None:
|
||||
value = "sqlite"
|
||||
@@ -321,14 +321,11 @@ class Settings(BaseSettings, extra="allow"):
|
||||
if value is None:
|
||||
match values.data['database_schema']:
|
||||
case "sqlite":
|
||||
# print("No value for dir path")
|
||||
if check_if_app():
|
||||
alembic_path = Path(sys._MEIPASS).joinpath("files", "alembic.ini")
|
||||
else:
|
||||
alembic_path = project_path.joinpath("alembic.ini")
|
||||
# print(f"Getting alembic path: {alembic_path}")
|
||||
value = cls.get_alembic_db_path(alembic_path=alembic_path, mode='path').parent
|
||||
# print(f"Using {value}")
|
||||
case _:
|
||||
Tk().withdraw() # we don't want a full GUI, so keep the root window from appearing
|
||||
value = Path(askdirectory(
|
||||
@@ -340,9 +337,7 @@ class Settings(BaseSettings, extra="allow"):
|
||||
except AttributeError:
|
||||
check = False
|
||||
if not check:
|
||||
# print(f"No directory found, using Documents/submissions")
|
||||
value.mkdir(exist_ok=True)
|
||||
# print(f"Final return of directory_path: {value}")
|
||||
return value
|
||||
|
||||
@field_validator('database_path', mode="before")
|
||||
@@ -360,7 +355,6 @@ class Settings(BaseSettings, extra="allow"):
|
||||
alembic_path = Path(sys._MEIPASS).joinpath("files", "alembic.ini")
|
||||
else:
|
||||
alembic_path = project_path.joinpath("alembic.ini")
|
||||
# print(f"Getting alembic path: {alembic_path}")
|
||||
value = cls.get_alembic_db_path(alembic_path=alembic_path, mode='path').parent
|
||||
return value
|
||||
|
||||
@@ -372,7 +366,6 @@ class Settings(BaseSettings, extra="allow"):
|
||||
alembic_path = Path(sys._MEIPASS).joinpath("files", "alembic.ini")
|
||||
else:
|
||||
alembic_path = project_path.joinpath("alembic.ini")
|
||||
# print(f"Getting alembic path: {alembic_path}")
|
||||
value = cls.get_alembic_db_path(alembic_path=alembic_path, mode='path').stem
|
||||
return value
|
||||
|
||||
@@ -384,9 +377,7 @@ class Settings(BaseSettings, extra="allow"):
|
||||
alembic_path = Path(sys._MEIPASS).joinpath("files", "alembic.ini")
|
||||
else:
|
||||
alembic_path = project_path.joinpath("alembic.ini")
|
||||
# print(f"Getting alembic path: {alembic_path}")
|
||||
value = cls.get_alembic_db_path(alembic_path=alembic_path, mode='user')
|
||||
# print(f"Got {value} for user")
|
||||
return value
|
||||
|
||||
@field_validator("database_password", mode='before')
|
||||
@@ -397,9 +388,7 @@ class Settings(BaseSettings, extra="allow"):
|
||||
alembic_path = Path(sys._MEIPASS).joinpath("files", "alembic.ini")
|
||||
else:
|
||||
alembic_path = project_path.joinpath("alembic.ini")
|
||||
# print(f"Getting alembic path: {alembic_path}")
|
||||
value = cls.get_alembic_db_path(alembic_path=alembic_path, mode='pass')
|
||||
# print(f"Got {value} for pass")
|
||||
return value
|
||||
|
||||
@field_validator('database_session', mode="before")
|
||||
@@ -421,7 +410,6 @@ class Settings(BaseSettings, extra="allow"):
|
||||
"{{ values['database_schema'] }}://{{ value }}/{{ db_name }}?driver=ODBC+Driver+18+for+SQL+Server&TrustServerCertificate=yes&Trusted_Connection=yes"
|
||||
)
|
||||
case _:
|
||||
# print(pprint.pprint(values.data))
|
||||
tmp = jinja_template_loading().from_string(
|
||||
"{% if values['database_user'] %}{{ values['database_user'] }}{% if values['database_password'] %}:{{ values['database_password'] }}{% endif %}{% endif %}@{{ values['database_path'] }}")
|
||||
value = tmp.render(values=values.data)
|
||||
@@ -444,7 +432,6 @@ class Settings(BaseSettings, extra="allow"):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.set_from_db()
|
||||
self.set_scripts()
|
||||
# pprint(f"User settings:\n{self.__dict__}")
|
||||
|
||||
def set_from_db(self):
|
||||
if 'pytest' in sys.modules:
|
||||
@@ -453,11 +440,8 @@ class Settings(BaseSettings, extra="allow"):
|
||||
teardown_scripts=dict(goodbye=None)
|
||||
)
|
||||
else:
|
||||
# print(f"Hello from database settings getter.")
|
||||
# print(self.__dict__)
|
||||
session = self.database_session
|
||||
metadata = MetaData()
|
||||
# print(self.database_session.get_bind())
|
||||
try:
|
||||
metadata.reflect(bind=session.get_bind())
|
||||
except AttributeError as e:
|
||||
@@ -467,7 +451,6 @@ class Settings(BaseSettings, extra="allow"):
|
||||
print(f"Couldn't find _configitems in {metadata.tables.keys()}.")
|
||||
return
|
||||
config_items = session.execute(text("SELECT * FROM _configitem")).all()
|
||||
# print(f"Config: {pprint.pprint(config_items)}")
|
||||
output = {}
|
||||
for item in config_items:
|
||||
try:
|
||||
@@ -488,6 +471,7 @@ class Settings(BaseSettings, extra="allow"):
|
||||
p = Path(__file__).parents[2].joinpath("scripts").absolute()
|
||||
if p.__str__() not in sys.path:
|
||||
sys.path.append(p.__str__())
|
||||
# NOTE: Get all .py files that don't have __ in them.
|
||||
modules = p.glob("[!__]*.py")
|
||||
for module in modules:
|
||||
mod = importlib.import_module(module.stem)
|
||||
@@ -495,6 +479,7 @@ class Settings(BaseSettings, extra="allow"):
|
||||
name = function[0]
|
||||
func = function[1]
|
||||
# NOTE: assign function based on its name being in config: startup/teardown
|
||||
# NOTE: scripts must be registered using {name: Null} in the database
|
||||
if name in self.startup_scripts.keys():
|
||||
self.startup_scripts[name] = func
|
||||
if name in self.teardown_scripts.keys():
|
||||
@@ -543,14 +528,12 @@ class Settings(BaseSettings, extra="allow"):
|
||||
try:
|
||||
return url[:url.index("@")].split(":")[0]
|
||||
except (IndexError, ValueError) as e:
|
||||
# print(f"Error on user: {e}")
|
||||
return None
|
||||
case "pass":
|
||||
url = re.sub(r"^.*//", "", url)
|
||||
try:
|
||||
return url[:url.index("@")].split(":")[1]
|
||||
except (IndexError, ValueError) as e:
|
||||
# print(f"Error on user: {e}")
|
||||
return None
|
||||
|
||||
def save(self, settings_path: Path):
|
||||
@@ -592,7 +575,6 @@ def get_config(settings_path: Path | str | None = None) -> Settings:
|
||||
def join(loader, node):
|
||||
seq = loader.construct_sequence(node)
|
||||
return ''.join([str(i) for i in seq])
|
||||
|
||||
# NOTE: register the tag handler
|
||||
yaml.add_constructor('!join', join)
|
||||
# NOTE: make directories
|
||||
@@ -624,7 +606,6 @@ def get_config(settings_path: Path | str | None = None) -> Settings:
|
||||
# NOTE: copy settings to config directory
|
||||
settings = Settings(**default_settings)
|
||||
settings.save(settings_path=CONFIGDIR.joinpath("config.yml"))
|
||||
# print(f"Default settings: {pprint.pprint(settings.__dict__)}")
|
||||
return settings
|
||||
else:
|
||||
# NOTE: check if user defined path is directory
|
||||
@@ -829,10 +810,23 @@ def setup_lookup(func):
|
||||
elif v is not None:
|
||||
sanitized_kwargs[k] = v
|
||||
return func(*args, **sanitized_kwargs)
|
||||
|
||||
return wrapper
|
||||
|
||||
|
||||
def get_application_from_parent(widget):
|
||||
try:
|
||||
return widget.app
|
||||
except AttributeError:
|
||||
logger.info("Using recursion to get application object.")
|
||||
from frontend.widgets.app import App
|
||||
while not isinstance(widget, App):
|
||||
try:
|
||||
widget = widget.parent()
|
||||
except AttributeError:
|
||||
return widget
|
||||
return widget
|
||||
|
||||
|
||||
class Result(BaseModel, arbitrary_types_allowed=True):
|
||||
owner: str = Field(default="", validate_default=True)
|
||||
code: int = Field(default=0)
|
||||
@@ -937,20 +931,20 @@ def rreplace(s: str, old: str, new: str) -> str:
|
||||
return (s[::-1].replace(old[::-1], new[::-1], 1))[::-1]
|
||||
|
||||
|
||||
def remove_key_from_list_of_dicts(input: list, key: str) -> list:
|
||||
def remove_key_from_list_of_dicts(input_list: list, key: str) -> list:
|
||||
"""
|
||||
Removes a key from all dictionaries in a list of dictionaries
|
||||
|
||||
Args:
|
||||
input (list): Input list of dicts
|
||||
input_list (list): Input list of dicts
|
||||
key (str): Name of key to remove.
|
||||
|
||||
Returns:
|
||||
list: List of updated dictionaries
|
||||
"""
|
||||
for item in input:
|
||||
for item in input_list:
|
||||
del item[key]
|
||||
return input
|
||||
return input_list
|
||||
|
||||
|
||||
def yaml_regex_creator(loader, node):
|
||||
@@ -963,6 +957,7 @@ def yaml_regex_creator(loader, node):
|
||||
|
||||
def super_splitter(ins_str: str, substring: str, idx: int) -> str:
|
||||
"""
|
||||
Splits string on substring at index
|
||||
|
||||
Args:
|
||||
ins_str (str): input string
|
||||
@@ -978,6 +973,20 @@ def super_splitter(ins_str: str, substring: str, idx: int) -> str:
|
||||
return ins_str
|
||||
|
||||
|
||||
def is_developer() -> bool:
|
||||
"""
|
||||
Checks if user is in list of super users
|
||||
|
||||
Returns:
|
||||
bool: True if yes, False if no.
|
||||
"""
|
||||
try:
|
||||
check = getpass.getuser() in ctx.super_users
|
||||
except:
|
||||
check = False
|
||||
return check
|
||||
|
||||
|
||||
def is_power_user() -> bool:
|
||||
"""
|
||||
Checks if user is in list of power users
|
||||
@@ -1000,21 +1009,49 @@ def check_authorization(func):
|
||||
func (function): Function to be used.
|
||||
"""
|
||||
|
||||
@wraps(func)
|
||||
@report_result
|
||||
def wrapper(*args, **kwargs):
|
||||
logger.info(f"Checking authorization")
|
||||
if is_power_user():
|
||||
error_msg = f"User {getpass.getuser()} is not authorized for this function."
|
||||
auth_func = is_power_user
|
||||
if auth_func():
|
||||
return func(*args, **kwargs)
|
||||
else:
|
||||
logger.error(f"User {getpass.getuser()} is not authorized for this function.")
|
||||
logger.error(error_msg)
|
||||
report = Report()
|
||||
report.add_result(
|
||||
Result(owner=func.__str__(), code=1, msg="This user does not have permission for this function.",
|
||||
status="warning"))
|
||||
Result(owner=func.__str__(), code=1, msg=error_msg, status="warning"))
|
||||
return report
|
||||
|
||||
return wrapper
|
||||
|
||||
|
||||
def under_development(func):
|
||||
"""
|
||||
Decorator to check if user is authorized to access function
|
||||
|
||||
Args:
|
||||
func (function): Function to be used.
|
||||
"""
|
||||
|
||||
@wraps(func)
|
||||
@report_result
|
||||
def wrapper(*args, **kwargs):
|
||||
logger.warning(f"This feature is under development")
|
||||
if is_developer():
|
||||
return func(*args, **kwargs)
|
||||
else:
|
||||
error_msg = f"User {getpass.getuser()} is not authorized for this function."
|
||||
logger.error(error_msg)
|
||||
report = Report()
|
||||
report.add_result(
|
||||
Result(owner=func.__str__(), code=1, msg=error_msg,
|
||||
status="warning"))
|
||||
return report
|
||||
return wrapper
|
||||
|
||||
|
||||
def report_result(func):
|
||||
"""
|
||||
Decorator to display any reports returned from a function.
|
||||
@@ -1036,14 +1073,9 @@ def report_result(func):
|
||||
case Report():
|
||||
report = output
|
||||
case tuple():
|
||||
# try:
|
||||
report = next((item for item in output if isinstance(item, Report)), None)
|
||||
# except IndexError:
|
||||
# report = None
|
||||
case _:
|
||||
report = Report()
|
||||
# return report
|
||||
# logger.info(f"Got report: {report}")
|
||||
try:
|
||||
results = report.results
|
||||
except AttributeError:
|
||||
@@ -1058,13 +1090,11 @@ def report_result(func):
|
||||
logger.error(result.msg)
|
||||
if output:
|
||||
true_output = tuple(item for item in output if not isinstance(item, Report))
|
||||
# logger.debug(f"True output: {true_output}")
|
||||
if len(true_output) == 1:
|
||||
true_output = true_output[0]
|
||||
else:
|
||||
true_output = None
|
||||
return true_output
|
||||
|
||||
return wrapper
|
||||
|
||||
|
||||
@@ -1084,20 +1114,19 @@ def create_holidays_for_year(year: int | None = None) -> List[date]:
|
||||
offset = -d.weekday() # weekday == 0 means Monday
|
||||
output = d + timedelta(offset)
|
||||
return output.date()
|
||||
|
||||
if not year:
|
||||
year = date.today().year
|
||||
# Includes New Year's day for next year.
|
||||
# NOTE: Includes New Year's day for next year.
|
||||
holidays = [date(year, 1, 1), date(year, 7, 1), date(year, 9, 30),
|
||||
date(year, 11, 11), date(year, 12, 25), date(year, 12, 26),
|
||||
date(year + 1, 1, 1)]
|
||||
# Labour Day
|
||||
# NOTE: Labour Day
|
||||
holidays.append(find_nth_monday(year, 9))
|
||||
# Thanksgiving
|
||||
# NOTE: Thanksgiving
|
||||
holidays.append(find_nth_monday(year, 10, occurence=2))
|
||||
# Victoria Day
|
||||
# NOTE: Victoria Day
|
||||
holidays.append(find_nth_monday(year, 5, day=25))
|
||||
# Easter, etc
|
||||
# NOTE: Easter, etc
|
||||
holidays.append(easter(year) - timedelta(days=2))
|
||||
holidays.append(easter(year) + timedelta(days=1))
|
||||
return sorted(holidays)
|
||||
@@ -1107,8 +1136,7 @@ class classproperty(property):
|
||||
def __get__(self, owner_self, owner_cls):
|
||||
return self.fget(owner_cls)
|
||||
|
||||
|
||||
# NOTE: Monkey patching... hooray!
|
||||
builtins.classproperty = classproperty
|
||||
|
||||
|
||||
ctx = get_config(None)
|
||||
|
||||
Reference in New Issue
Block a user