Add tests and fix issues
This commit is contained in:
3
.gitignore
vendored
3
.gitignore
vendored
@@ -3,4 +3,5 @@
|
|||||||
.venv
|
.venv
|
||||||
**/*.db
|
**/*.db
|
||||||
**/__pycache__/**
|
**/__pycache__/**
|
||||||
.streamlit/secrets.toml
|
.streamlit/secrets.toml
|
||||||
|
/testdb.sqlite
|
||||||
|
|||||||
@@ -24,10 +24,7 @@ RUN poetry install --only=main --no-interaction --no-ansi
|
|||||||
COPY . /app
|
COPY . /app
|
||||||
VOLUME /app/data
|
VOLUME /app/data
|
||||||
|
|
||||||
RUN touch .streamlit/secrets.toml \
|
EXPOSE 8501
|
||||||
&& toml add_section --toml-path='.streamlit/secrets.toml' 'connections.sqlite' \
|
|
||||||
&& toml set --toml-path='.streamlit/secrets.toml' 'connections.sqlite.type' 'queries' \
|
|
||||||
&& toml set --toml-path='.streamlit/secrets.toml' 'connections.sqlite.url' 'sqlite:///data/daily-counter.db'
|
|
||||||
|
|
||||||
HEALTHCHECK --interval=60s --retries=5 CMD wget -qO- http://127.0.0.1:8501/_stcore/health || exit 1
|
HEALTHCHECK --interval=60s --retries=5 CMD wget -qO- http://127.0.0.1:8501/_stcore/health || exit 1
|
||||||
ENTRYPOINT ["/sbin/tini", "--"]
|
ENTRYPOINT ["/sbin/tini", "--"]
|
||||||
|
|||||||
@@ -1,32 +1,35 @@
|
|||||||
import streamlit as st
|
import streamlit as st
|
||||||
|
|
||||||
|
from streamlit import dialog
|
||||||
from queries import crud, daily_stats, weekly_stats, monthly_stats, yearly_stats
|
from queries import crud, daily_stats, weekly_stats, monthly_stats, yearly_stats
|
||||||
from enums import CounterType
|
from enums import CounterType
|
||||||
|
|
||||||
@st.dialog("Add New Counter", icon=":material/add_box:")
|
@dialog("Add New Counter", icon=":material/add_box:")
|
||||||
def _add_counter():
|
def _add_counter():
|
||||||
colors = crud.get_colors(1)
|
colors = crud.get_colors(1)
|
||||||
with st.form(key="add_counter", border=False, clear_on_submit=True):
|
with st.form(key="add_counter", border=False, clear_on_submit=True):
|
||||||
title = st.text_input("Title:")
|
title = st.text_input("Title:", key="new_counter_title")
|
||||||
counter_type_name = st.selectbox("Type", options=[e.name for e in CounterType])
|
counter_type_name = st.selectbox("Type", options=[e.name for e in CounterType], key="new_counter_type")
|
||||||
color = st.radio("Color",
|
selected_color = st.radio("Color",
|
||||||
key="color-selector",
|
key="new_counter_color_selector",
|
||||||
width="stretch",
|
width="stretch",
|
||||||
options=[colors[key][0] for key in colors],
|
options=[colors[key][0] for key in colors],
|
||||||
format_func=lambda c: f"#{c}")
|
format_func=lambda c: f"#{c}")
|
||||||
with st.container(horizontal=True, width="stretch", horizontal_alignment="center"):
|
with st.container(horizontal=True, width="stretch", horizontal_alignment="center"):
|
||||||
if st.form_submit_button(label="Create", icon=":material/save:"):
|
if st.form_submit_button(label="Create", icon=":material/save:", key="create_counter_submit_btn"):
|
||||||
crud.create_counter(title, CounterType[counter_type_name], color)
|
if not title:
|
||||||
|
raise ValueError("Title cannot be empty")
|
||||||
|
crud.create_counter(title, CounterType[counter_type_name], selected_color)
|
||||||
st.rerun()
|
st.rerun()
|
||||||
|
|
||||||
|
|
||||||
@st.dialog("Remove Counter", icon=":material/delete:")
|
@dialog("Remove Counter", icon=":material/delete:")
|
||||||
def _remove_counter(counter_id:int):
|
def _remove_counter(remove_counter_id:int):
|
||||||
with st.form(key="remove_counter", border=False, clear_on_submit=True):
|
with st.form(key="remove_counter", border=False, clear_on_submit=True):
|
||||||
st.subheader("Are you sure?")
|
st.subheader("Are you sure?")
|
||||||
with st.container(horizontal=True, width="stretch", horizontal_alignment="center"):
|
with st.container(horizontal=True, width="stretch", horizontal_alignment="center"):
|
||||||
if st.form_submit_button("Confirm", icon=":material/delete:"):
|
if st.form_submit_button("Confirm", icon=":material/delete:", key="remove_counter_submit_btn"):
|
||||||
crud.remove_counter(counter_id)
|
crud.remove_counter(remove_counter_id)
|
||||||
st.rerun()
|
st.rerun()
|
||||||
|
|
||||||
df = crud.get_counters()
|
df = crud.get_counters()
|
||||||
@@ -89,7 +92,7 @@ with st.container(key="counter-table"):
|
|||||||
</style>
|
</style>
|
||||||
""")
|
""")
|
||||||
|
|
||||||
if st.button("Add Counter", width="stretch", icon=":material/add_box:"):
|
if st.button("Add Counter", width="stretch", icon=":material/add_box:", key="new_counter_button"):
|
||||||
_add_counter()
|
_add_counter()
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -1,9 +1,13 @@
|
|||||||
|
from os import getenv
|
||||||
|
|
||||||
import streamlit as st
|
import streamlit as st
|
||||||
from sqlalchemy.sql import text
|
from sqlalchemy.sql import text
|
||||||
from streamlit.connections import BaseConnection
|
from streamlit.connections import BaseConnection
|
||||||
|
|
||||||
connection: BaseConnection = st.connection("sqlite")
|
def connection() -> BaseConnection:
|
||||||
|
_connection = st.connection("sql", url=getenv('DATABASE_URL'))
|
||||||
|
with _connection.session as configured_session:
|
||||||
|
configured_session.execute(text('PRAGMA foreign_keys=ON'))
|
||||||
|
return _connection
|
||||||
|
|
||||||
|
|
||||||
with connection.session as configure_session:
|
|
||||||
configure_session.execute(text('PRAGMA foreign_keys=ON'))
|
|
||||||
@@ -8,7 +8,7 @@ logger = logging.getLogger(__name__)
|
|||||||
|
|
||||||
def create_counter(title:str, counter_type:CounterType, counter_color) -> None:
|
def create_counter(title:str, counter_type:CounterType, counter_color) -> None:
|
||||||
logger.info("Adding counter %s", counter_type)
|
logger.info("Adding counter %s", counter_type)
|
||||||
with connection.session as session:
|
with connection().session as session:
|
||||||
try:
|
try:
|
||||||
query = text('INSERT INTO counters (name, type, color) VALUES (:title, :type, :color)')
|
query = text('INSERT INTO counters (name, type, color) VALUES (:title, :type, :color)')
|
||||||
session.execute(query, {'title': title, 'type': counter_type, 'color': counter_color})
|
session.execute(query, {'title': title, 'type': counter_type, 'color': counter_color})
|
||||||
@@ -19,14 +19,14 @@ def create_counter(title:str, counter_type:CounterType, counter_color) -> None:
|
|||||||
|
|
||||||
def get_counters():
|
def get_counters():
|
||||||
try:
|
try:
|
||||||
return connection.query('SELECT id, name, type, color FROM counters', ttl=0)
|
return connection().query('SELECT id, name, type, color FROM counters', ttl=0)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(e)
|
logger.error(e)
|
||||||
return st.dataframe()
|
return st.dataframe()
|
||||||
|
|
||||||
def increment_counter(counter_id:int) -> None:
|
def increment_counter(counter_id:int) -> None:
|
||||||
logger.info("Incrementing counter %s", counter_id)
|
logger.info("Incrementing counter %s", counter_id)
|
||||||
with connection.session as session:
|
with connection().session as session:
|
||||||
try:
|
try:
|
||||||
query = text('INSERT INTO entries (counter_id) VALUES (:id)')
|
query = text('INSERT INTO entries (counter_id) VALUES (:id)')
|
||||||
session.execute(query, {'id': counter_id})
|
session.execute(query, {'id': counter_id})
|
||||||
@@ -35,10 +35,9 @@ def increment_counter(counter_id:int) -> None:
|
|||||||
logger.error(e)
|
logger.error(e)
|
||||||
session.rollback()
|
session.rollback()
|
||||||
|
|
||||||
|
|
||||||
def remove_counter(counter_id:int) -> None:
|
def remove_counter(counter_id:int) -> None:
|
||||||
logger.info("Removing counter %s", counter_id)
|
logger.info("Removing counter %s", counter_id)
|
||||||
with connection.session as session:
|
with connection().session as session:
|
||||||
try:
|
try:
|
||||||
query = text('DELETE FROM counters WHERE id = :id')
|
query = text('DELETE FROM counters WHERE id = :id')
|
||||||
session.execute(query, {'id': counter_id})
|
session.execute(query, {'id': counter_id})
|
||||||
@@ -47,10 +46,9 @@ def remove_counter(counter_id:int) -> None:
|
|||||||
logger.error(e)
|
logger.error(e)
|
||||||
session.rollback()
|
session.rollback()
|
||||||
|
|
||||||
|
|
||||||
def get_counter(counter_id:int):
|
def get_counter(counter_id:int):
|
||||||
try:
|
try:
|
||||||
return connection.query('SELECT * FROM counters WHERE id = :id', params={'id': counter_id}, ttl=0).iloc[0]
|
return connection().query('SELECT * FROM counters WHERE id = :id', params={'id': counter_id}, ttl=0).iloc[0]
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(e)
|
logger.error(e)
|
||||||
return None
|
return None
|
||||||
@@ -58,7 +56,7 @@ def get_counter(counter_id:int):
|
|||||||
|
|
||||||
def get_colors(palette_id:int):
|
def get_colors(palette_id:int):
|
||||||
try:
|
try:
|
||||||
return connection.query('''SELECT color1,color2,color3,color4,color5 FROM color_palettes WHERE id = :id''', params={'id': palette_id})
|
return connection().query('''SELECT color1,color2,color3,color4,color5 FROM color_palettes WHERE id = :id''', params={'id': palette_id})
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(e)
|
logger.error(e)
|
||||||
return None
|
return None
|
||||||
|
|||||||
@@ -5,7 +5,7 @@ logger = logging.getLogger(__name__)
|
|||||||
|
|
||||||
def get_all_daily_analytics(end_date:str = 'now'):
|
def get_all_daily_analytics(end_date:str = 'now'):
|
||||||
try:
|
try:
|
||||||
return connection.query('''
|
return connection().query('''
|
||||||
WITH RECURSIVE timeseries(d) AS (
|
WITH RECURSIVE timeseries(d) AS (
|
||||||
VALUES(date(:end_date))
|
VALUES(date(:end_date))
|
||||||
UNION ALL
|
UNION ALL
|
||||||
@@ -39,7 +39,7 @@ def get_all_daily_analytics(end_date:str = 'now'):
|
|||||||
|
|
||||||
def get_daily_analytics(counter_id:int, end_date:str = 'now'):
|
def get_daily_analytics(counter_id:int, end_date:str = 'now'):
|
||||||
try:
|
try:
|
||||||
return connection.query('''
|
return connection().query('''
|
||||||
WITH RECURSIVE timeseries(d) AS (
|
WITH RECURSIVE timeseries(d) AS (
|
||||||
VALUES(date(:end_date))
|
VALUES(date(:end_date))
|
||||||
UNION ALL
|
UNION ALL
|
||||||
|
|||||||
@@ -5,7 +5,7 @@ logger = logging.getLogger(__name__)
|
|||||||
|
|
||||||
def get_all_monthly_analytics(end_date:str = 'now'):
|
def get_all_monthly_analytics(end_date:str = 'now'):
|
||||||
try:
|
try:
|
||||||
return connection.query('''
|
return connection().query('''
|
||||||
WITH RECURSIVE timeseries(d) AS (
|
WITH RECURSIVE timeseries(d) AS (
|
||||||
VALUES(date(:end_date,'start of year'))
|
VALUES(date(:end_date,'start of year'))
|
||||||
UNION ALL
|
UNION ALL
|
||||||
@@ -45,7 +45,7 @@ def get_all_monthly_analytics(end_date:str = 'now'):
|
|||||||
|
|
||||||
def get_monthly_analytics(counter_id:int, end_date:str = 'now'):
|
def get_monthly_analytics(counter_id:int, end_date:str = 'now'):
|
||||||
try:
|
try:
|
||||||
return connection.query('''
|
return connection().query('''
|
||||||
WITH RECURSIVE timeseries(d) AS (
|
WITH RECURSIVE timeseries(d) AS (
|
||||||
VALUES( date(:end_date, 'start of year'))
|
VALUES( date(:end_date, 'start of year'))
|
||||||
UNION ALL
|
UNION ALL
|
||||||
|
|||||||
@@ -5,7 +5,7 @@ logger = logging.getLogger(__name__)
|
|||||||
|
|
||||||
def get_all_weekly_analytics(end_date:str = 'now'):
|
def get_all_weekly_analytics(end_date:str = 'now'):
|
||||||
try:
|
try:
|
||||||
return connection.query('''
|
return connection().query('''
|
||||||
WITH RECURSIVE timeseries(d) AS (
|
WITH RECURSIVE timeseries(d) AS (
|
||||||
VALUES(date(:end_date, 'weekday 0'))
|
VALUES(date(:end_date, 'weekday 0'))
|
||||||
UNION ALL
|
UNION ALL
|
||||||
@@ -42,7 +42,7 @@ def get_all_weekly_analytics(end_date:str = 'now'):
|
|||||||
|
|
||||||
def get_weekly_analytics(counter_id:int, end_date:str = 'now'):
|
def get_weekly_analytics(counter_id:int, end_date:str = 'now'):
|
||||||
try:
|
try:
|
||||||
return connection.query('''
|
return connection().query('''
|
||||||
WITH RECURSIVE timeseries(d) AS (
|
WITH RECURSIVE timeseries(d) AS (
|
||||||
VALUES(date(:end_date, 'weekday 0'))
|
VALUES(date(:end_date, 'weekday 0'))
|
||||||
UNION ALL
|
UNION ALL
|
||||||
|
|||||||
@@ -5,7 +5,7 @@ logger = logging.getLogger(__name__)
|
|||||||
|
|
||||||
def get_all_yearly_analytics(end_date:str = 'now'):
|
def get_all_yearly_analytics(end_date:str = 'now'):
|
||||||
try:
|
try:
|
||||||
return connection.query('''
|
return connection().query('''
|
||||||
WITH RECURSIVE timeseries(d) AS (
|
WITH RECURSIVE timeseries(d) AS (
|
||||||
VALUES(date(:end_date,'start of year', '-4 years'))
|
VALUES(date(:end_date,'start of year', '-4 years'))
|
||||||
UNION ALL
|
UNION ALL
|
||||||
@@ -42,7 +42,7 @@ def get_all_yearly_analytics(end_date:str = 'now'):
|
|||||||
|
|
||||||
def get_yearly_analytics(counter_id:int, end_date:str = 'now'):
|
def get_yearly_analytics(counter_id:int, end_date:str = 'now'):
|
||||||
try:
|
try:
|
||||||
return connection.query('''
|
return connection().query('''
|
||||||
WITH RECURSIVE timeseries(d) AS (
|
WITH RECURSIVE timeseries(d) AS (
|
||||||
VALUES( date(:end_date, 'start of year', '-4 years'))
|
VALUES( date(:end_date, 'start of year', '-4 years'))
|
||||||
UNION ALL
|
UNION ALL
|
||||||
|
|||||||
@@ -1,18 +1,19 @@
|
|||||||
import streamlit as st
|
import streamlit as st
|
||||||
|
import logging
|
||||||
from logger import init_logger
|
from logger import init_logger
|
||||||
from styles import init_styles
|
from styles import init_styles
|
||||||
|
|
||||||
init_logger()
|
init_logger()
|
||||||
init_styles()
|
init_styles()
|
||||||
|
|
||||||
if st.user and not st.user.is_logged_in:
|
if hasattr(st, 'user') and hasattr(st.user, 'is_logged_in'):
|
||||||
with st.container(width="stretch", height="stretch", horizontal_alignment="center"):
|
if not st.user.is_logged_in:
|
||||||
st.title("Daily Counter", width="stretch", text_alignment="center")
|
with st.container(width="stretch", height="stretch", horizontal_alignment="center"):
|
||||||
st.text("Please log in to use this app", width="stretch", text_alignment="center")
|
st.title("Daily Counter", width="stretch", text_alignment="center")
|
||||||
st.space()
|
st.text("Please log in to use this app", width="stretch", text_alignment="center")
|
||||||
if st.button("Log in"):
|
st.space()
|
||||||
st.login()
|
if st.button("Log in"):
|
||||||
|
st.login()
|
||||||
else:
|
else:
|
||||||
counters = st.Page("pages/counters.py", title="Counters", icon=":material/update:")
|
counters = st.Page("pages/counters.py", title="Counters", icon=":material/update:")
|
||||||
stats = st.Page("pages/stats.py", title="Statistics", icon=":material/chart_data:")
|
stats = st.Page("pages/stats.py", title="Statistics", icon=":material/chart_data:")
|
||||||
|
|||||||
@@ -13,7 +13,7 @@ def _load_color_selector_styles():
|
|||||||
css_color = '#' + colors[c][0]
|
css_color = '#' + colors[c][0]
|
||||||
st.html(f"""
|
st.html(f"""
|
||||||
<style>
|
<style>
|
||||||
.st-key-color-selector label:has(> input[value='{idx}']) {{
|
.st-key-new_counter_color_selector label:has(> input[value='{idx}']) {{
|
||||||
background-color: {css_color};
|
background-color: {css_color};
|
||||||
}}
|
}}
|
||||||
</style>
|
</style>
|
||||||
|
|||||||
@@ -37,20 +37,20 @@
|
|||||||
background-color: whitesmoke;
|
background-color: whitesmoke;
|
||||||
}
|
}
|
||||||
|
|
||||||
.st-key-color-selector div[role = "radiogroup"] {
|
.st-key-new_counter_color_selector div[role = "radiogroup"] {
|
||||||
display: flex;
|
display: flex;
|
||||||
flex-direction: row;
|
flex-direction: row;
|
||||||
}
|
}
|
||||||
.st-key-color-selector div[role = "radiogroup"] > label {
|
.st-key-new_counter_color_selector div[role = "radiogroup"] > label {
|
||||||
flex: 1
|
flex: 1
|
||||||
}
|
}
|
||||||
.st-key-color-selector div[role = "radiogroup"] > label > div:first-child {
|
.st-key-new_counter_color_selector div[role = "radiogroup"] > label > div:first-child {
|
||||||
display: none;
|
display: none;
|
||||||
}
|
}
|
||||||
.st-key-color-selector div[role = "radiogroup"] > label:has(> input[tabindex="0"]) {
|
.st-key-new_counter_color_selector div[role = "radiogroup"] > label:has(> input[tabindex="0"]) {
|
||||||
outline: 3px solid blue;
|
outline: 3px solid blue;
|
||||||
}
|
}
|
||||||
.st-key-color-selector div[role = "radiogroup"] p {
|
.st-key-new_counter_color_selector div[role = "radiogroup"] p {
|
||||||
visibility: hidden;
|
visibility: hidden;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -1,9 +1,15 @@
|
|||||||
#!/usr/bin/env sh
|
#!/usr/bin/env sh
|
||||||
|
|
||||||
STREAMLIT_SECRETS_LOCATION=".streamlit/secrets.toml"
|
STREAMLIT_SECRETS_LOCATION=".streamlit/secrets.toml"
|
||||||
|
touch STREAMLIT_SECRETS_LOCATION
|
||||||
|
|
||||||
|
SQLITE_DATABASE="/data/daily-counter.db"
|
||||||
|
SQLITE_DATABASE_URL="sqlite://$SQLITE_DATABASE"
|
||||||
|
export DATABASE_URL="$SQLITE_DATABASE_URL"
|
||||||
|
echo "INFO [entrypoint] Using SQLite database at $SQLITE_DATABASE"
|
||||||
|
|
||||||
if [ "$OIDC_ENABLED" = "true" ]; then
|
if [ "$OIDC_ENABLED" = "true" ]; then
|
||||||
echo "INFO [entrypoint] OIDC configuration detected. Configuring app..."
|
echo "INFO [entrypoint] OIDC configuration detected. Configuring authentication..."
|
||||||
toml add_section --toml-path=$STREAMLIT_SECRETS_LOCATION 'auth'
|
toml add_section --toml-path=$STREAMLIT_SECRETS_LOCATION 'auth'
|
||||||
toml set --toml-path=$STREAMLIT_SECRETS_LOCATION 'auth.redirect_uri' "$OIDC_PUBLIC_URL/oauth2callback"
|
toml set --toml-path=$STREAMLIT_SECRETS_LOCATION 'auth.redirect_uri' "$OIDC_PUBLIC_URL/oauth2callback"
|
||||||
toml set --toml-path=$STREAMLIT_SECRETS_LOCATION 'auth.cookie_secret' "$OIDC_COOKIE_SECRET"
|
toml set --toml-path=$STREAMLIT_SECRETS_LOCATION 'auth.cookie_secret' "$OIDC_COOKIE_SECRET"
|
||||||
|
|||||||
1245
poetry.lock
generated
1245
poetry.lock
generated
File diff suppressed because it is too large
Load Diff
@@ -3,12 +3,15 @@ name = "daily-counter"
|
|||||||
description = "A daily counter for any habbit tracking"
|
description = "A daily counter for any habbit tracking"
|
||||||
version = "0.1"
|
version = "0.1"
|
||||||
dynamic = ["version"]
|
dynamic = ["version"]
|
||||||
requires-python = ">= 3.10"
|
requires-python = ">=3.10,<4"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"alembic (==1.18.4)",
|
"alembic (==1.18.4)",
|
||||||
"streamlit (==1.56.0)",
|
"streamlit (==1.56.0)",
|
||||||
"toml-cli (==0.8.2)",
|
"toml-cli (==0.8.2)",
|
||||||
"authlib (==1.6.9)"
|
"authlib (==1.6.9)",
|
||||||
|
"sqlalchemy (>=2.0.49,<3.0.0)",
|
||||||
|
"pytest-alembic (>=0.12.1,<0.13.0)",
|
||||||
|
"pytest-env (>=1.6.0,<2.0.0)"
|
||||||
]
|
]
|
||||||
|
|
||||||
[virtualenvs]
|
[virtualenvs]
|
||||||
@@ -23,4 +26,14 @@ prepend_sys_path = [
|
|||||||
]
|
]
|
||||||
|
|
||||||
[tool.poetry]
|
[tool.poetry]
|
||||||
package-mode = false
|
package-mode = false
|
||||||
|
|
||||||
|
[tool.poetry.group.dev.dependencies]
|
||||||
|
pytest = ">=9.0"
|
||||||
|
|
||||||
|
[tool.poetry.dependencies]
|
||||||
|
python = ">=3.10,<4"
|
||||||
|
|
||||||
|
[build-system]
|
||||||
|
requires = ["poetry-core>=1.0.0"]
|
||||||
|
build-backend = "poetry.core.masonry.api"
|
||||||
34
tests/conftest.py
Normal file
34
tests/conftest.py
Normal file
@@ -0,0 +1,34 @@
|
|||||||
|
import logging
|
||||||
|
import os
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from alembic import config
|
||||||
|
from pytest_alembic.config import Config
|
||||||
|
from streamlit.testing.v1 import AppTest
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
@pytest.fixture(autouse=True)
|
||||||
|
def setup_database(alembic_runner):
|
||||||
|
logger.info("Running database migrations")
|
||||||
|
alembic_runner.migrate_up_to('heads')
|
||||||
|
yield
|
||||||
|
logger.info("Resetting database")
|
||||||
|
alembic_runner.migrate_down_to('base')
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def alembic_config() -> Config:
|
||||||
|
logging.info("Setting up alembic config")
|
||||||
|
alembic_cfg = config.Config(toml_file="pyproject.toml")
|
||||||
|
alembic_cfg.set_main_option("sqlalchemy.url", os.getenv("DATABASE_URL", ""))
|
||||||
|
return Config(alembic_config=alembic_cfg)
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def app() -> AppTest:
|
||||||
|
return AppTest.from_file("app/streamlit_app.py")
|
||||||
|
|
||||||
|
def delete_database():
|
||||||
|
file = os.getenv("DATABASE_FILE")
|
||||||
|
if file and os.path.isfile(file):
|
||||||
|
logger.info(f"Deleting database file {file}")
|
||||||
|
os.remove(file)
|
||||||
9
tests/pytest.ini
Normal file
9
tests/pytest.ini
Normal file
@@ -0,0 +1,9 @@
|
|||||||
|
[pytest]
|
||||||
|
log_cli = 1
|
||||||
|
log_cli_level = INFO
|
||||||
|
log_cli_format = %(asctime)s [%(levelname)8s] %(message)s (%(filename)s:%(lineno)s)
|
||||||
|
log_cli_date_format=%Y-%m-%d %H:%M:%S
|
||||||
|
env =
|
||||||
|
DATABASE_FILE=testdb.sqlite
|
||||||
|
DATABASE_URL=sqlite:///testdb.sqlite?cache=shared
|
||||||
|
|
||||||
59
tests/ui/counters_test.py
Normal file
59
tests/ui/counters_test.py
Normal file
@@ -0,0 +1,59 @@
|
|||||||
|
import queries.crud
|
||||||
|
from enums import CounterType
|
||||||
|
|
||||||
|
def test_initial_state(app):
|
||||||
|
app.run()
|
||||||
|
assert not app.exception
|
||||||
|
assert not app.error
|
||||||
|
assert len(app.header) == 0 # No counter currently present
|
||||||
|
|
||||||
|
def test_add_counter(app):
|
||||||
|
app.run()
|
||||||
|
|
||||||
|
# Open new counter dialog
|
||||||
|
app.button(key="new_counter_button").click().run()
|
||||||
|
|
||||||
|
# Fill in details and submit
|
||||||
|
app.text_input(key="new_counter_title").set_value("Walk")
|
||||||
|
app.selectbox(key='new_counter_type').select(CounterType.DAILY.name)
|
||||||
|
app.radio(key='new_counter_color_selector').set_value("020122")
|
||||||
|
app.button(key="create_counter_submit_btn").click()
|
||||||
|
app.run()
|
||||||
|
|
||||||
|
assert not app.exception
|
||||||
|
assert not app.error
|
||||||
|
assert len(app.text_input) == 0 # dialog closed, back in the main screen
|
||||||
|
|
||||||
|
# Simulate button listener due to bug https://github.com/streamlit/streamlit/issues/9786
|
||||||
|
queries.crud.create_counter("Walk", CounterType.DAILY, "020122")
|
||||||
|
app.run()
|
||||||
|
|
||||||
|
assert len(app.header) == 1 # A new counter was added
|
||||||
|
assert app.header[0].value == ":material/calendar_clock: Walk"
|
||||||
|
|
||||||
|
def test_remove_counter(app):
|
||||||
|
|
||||||
|
# Create a counter to remove
|
||||||
|
queries.crud.create_counter("Remove me", CounterType.SIMPLE, "020122")
|
||||||
|
|
||||||
|
app.run()
|
||||||
|
|
||||||
|
assert not app.exception
|
||||||
|
assert not app.error
|
||||||
|
assert len(app.header) == 1 # One counter exists
|
||||||
|
|
||||||
|
# Remove the counter
|
||||||
|
app.button("remove_counter_1").click().run()
|
||||||
|
|
||||||
|
# Confirmation
|
||||||
|
assert app.subheader[0].value == 'Are you sure?'
|
||||||
|
app.button(key="remove_counter_submit_btn").click()
|
||||||
|
app.run()
|
||||||
|
|
||||||
|
assert len(app.text_input) == 0 # dialog closed, back in the main screen
|
||||||
|
|
||||||
|
# Simulate button listener due to bug https://github.com/streamlit/streamlit/issues/9786
|
||||||
|
queries.crud.remove_counter(1)
|
||||||
|
app.run()
|
||||||
|
|
||||||
|
assert len(app.header) == 0 # No counter exists
|
||||||
Reference in New Issue
Block a user