From bd9ff7191a282c1a1f1468726252c9d54c7e7645 Mon Sep 17 00:00:00 2001 From: John Ahlroos Date: Tue, 28 Apr 2026 21:04:52 +0200 Subject: [PATCH] Add user specific profiles --- app/pages/stats.py | 5 +- app/queries/connection.py | 2 +- app/queries/crud.py | 53 +++++++++++---- app/queries/daily_stats.py | 9 ++- app/queries/monthly_stats.py | 9 ++- app/queries/user.py | 67 +++++++++++++++++++ app/queries/weekly_stats.py | 9 ++- app/queries/yearly_stats.py | 9 ++- app/streamlit_app.py | 38 ++++++++--- css/theme.css | 18 ++++- migrations/versions/20260427172417_user_id.py | 56 ++++++++++++++++ tests/database/user_db_test.py | 55 +++++++++++++++ 12 files changed, 294 insertions(+), 36 deletions(-) create mode 100644 app/queries/user.py create mode 100644 migrations/versions/20260427172417_user_id.py create mode 100644 tests/database/user_db_test.py diff --git a/app/pages/stats.py b/app/pages/stats.py index 19ff1ce..5534824 100644 --- a/app/pages/stats.py +++ b/app/pages/stats.py @@ -29,7 +29,10 @@ if "counter_id" in st.query_params.keys(): with st.container(horizontal_alignment="right", vertical_alignment="bottom", horizontal=True): st.header('Counter: ' + df['name']) - selection = st.segmented_control("Time Range", options, selection_mode="single", required=True, default=counter_type.name, label_visibility="hidden") + selected = counter_type.name + if selected == CounterType.SIMPLE.name: + selected = CounterType.DAILY.name + selection = st.segmented_control("Time Range", options, selection_mode="single", required=True, default=selected, label_visibility="hidden") match getattr(CounterType, selection): case CounterType.DAILY: diff --git a/app/queries/connection.py b/app/queries/connection.py index 5315ef2..9870a4a 100644 --- a/app/queries/connection.py +++ b/app/queries/connection.py @@ -5,7 +5,7 @@ from sqlalchemy.sql import text from streamlit.connections import BaseConnection def connection() -> BaseConnection: - _connection = st.connection("sql", url=getenv('DATABASE_URL')) + _connection = st.connection("sql", url=getenv('DATABASE_URL'), ttl=0, autocommit=True) with _connection.session as configured_session: configured_session.execute(text('PRAGMA foreign_keys=ON')) return _connection diff --git a/app/queries/crud.py b/app/queries/crud.py index 535b75f..68d934c 100644 --- a/app/queries/crud.py +++ b/app/queries/crud.py @@ -7,48 +7,73 @@ from enums import CounterType logger = logging.getLogger(__name__) def create_counter(title:str, counter_type:CounterType, counter_color) -> None: - logger.info("Adding counter %s", counter_type) + user_id = int(st.session_state.user_id) + logger.info("Adding counter %s for user %d", counter_type, user_id) with connection().session as session: try: - query = text('INSERT INTO counters (name, type, color) VALUES (:title, :type, :color)') - session.execute(query, {'title': title, 'type': counter_type, 'color': counter_color}) - session.commit() + query = text('INSERT INTO counters (user_id, name, type, color) VALUES (:user, :title, :type, :color)') + session.execute(query, { + 'user': user_id, + 'title': title, + 'type': counter_type, + 'color': counter_color + }) except Exception as e: logger.error(e) session.rollback() def get_counters(): + user_id = int(st.session_state.user_id) try: - return connection().query('SELECT id, name, type, color FROM counters', ttl=0) + return connection().query(""" + SELECT id, name, type, color + FROM counters + WHERE user_id = :user + """, params={'user': user_id }) except Exception as e: logger.error(e) return st.dataframe() def increment_counter(counter_id:int) -> None: - logger.info("Incrementing counter %s", counter_id) + user_id = int(st.session_state.user_id) + logger.info("Incrementing counter %d for user %d", counter_id, user_id) with connection().session as session: try: - query = text('INSERT INTO entries (counter_id) VALUES (:id)') - session.execute(query, {'id': counter_id}) - session.commit() + query = text('INSERT INTO entries (counter_id, user_id) VALUES (:id, :user)') + session.execute(query, { + 'id': counter_id, + 'user': user_id + }) except Exception as e: logger.error(e) session.rollback() def remove_counter(counter_id:int) -> None: - logger.info("Removing counter %s", counter_id) + user_id = int(st.session_state.user_id) + logger.info("Removing counter %d from user %d", counter_id, user_id) with connection().session as session: try: - query = text('DELETE FROM counters WHERE id = :id') - session.execute(query, {'id': counter_id}) - session.commit() + query = text('DELETE FROM counters WHERE id = :id AND user_id = :user') + session.execute(query, { + 'id': counter_id, + 'user': user_id + }) except Exception as e: logger.error(e) session.rollback() def get_counter(counter_id:int): + user_id = int(st.session_state.user_id) try: - return connection().query('SELECT * FROM counters WHERE id = :id', params={'id': counter_id}, ttl=0).iloc[0] + counters = connection().query(""" + SELECT * FROM counters + WHERE id = :id AND user_id = :user + """, params={ 'id': counter_id, 'user': user_id} + ) + if counters.empty: + return None + + return counters.iloc[0] except Exception as e: logger.error(e) return None diff --git a/app/queries/daily_stats.py b/app/queries/daily_stats.py index fb164a3..cc0517d 100644 --- a/app/queries/daily_stats.py +++ b/app/queries/daily_stats.py @@ -1,9 +1,11 @@ import logging from queries.connection import connection +import streamlit as st logger = logging.getLogger(__name__) def get_all_daily_analytics(end_date:str = 'now'): + user_id = int(st.session_state.user_id) try: return connection().query(''' WITH RECURSIVE timeseries(d) AS ( @@ -19,6 +21,7 @@ def get_all_daily_analytics(end_date:str = 'now'): counter_id, sum(increment) as count FROM entries + WHERE user_id = :user_id group by counter_id, date(timestamp) ) select @@ -31,13 +34,14 @@ def get_all_daily_analytics(end_date:str = 'now'): left outer join stats t on s.d = t.d left join counters c on t.counter_id = c.id GROUP by s.d - ''', params={"end_date": end_date}, ttl=0) + ''', params={"end_date": end_date, "user_id": user_id }) except Exception as e: logger.error(e) return None def get_daily_analytics(counter_id:int, end_date:str = 'now'): + user_id = int(st.session_state.user_id) try: return connection().query(''' WITH RECURSIVE timeseries(d) AS ( @@ -53,6 +57,7 @@ def get_daily_analytics(counter_id:int, end_date:str = 'now'): sum(increment) as count FROM entries where counter_id = :id + and user_id = :user_id group by date(timestamp) ) SELECT @@ -60,7 +65,7 @@ def get_daily_analytics(counter_id:int, end_date:str = 'now'): coalesce(s.count, 0) as count FROM timeseries as t LEFT JOIN stats as s on s.d = t.d - ''', params={'id': counter_id, "end_date": end_date}, ttl=0) + ''', params={'id': counter_id, "end_date": end_date, "user_id": user_id}) except Exception as e: logger.error(e) return None \ No newline at end of file diff --git a/app/queries/monthly_stats.py b/app/queries/monthly_stats.py index d14c110..f342d84 100644 --- a/app/queries/monthly_stats.py +++ b/app/queries/monthly_stats.py @@ -1,9 +1,11 @@ import logging from queries.connection import connection +import streamlit as st logger = logging.getLogger(__name__) def get_all_monthly_analytics(end_date:str = 'now'): + user_id = int(st.session_state.user_id) try: return connection().query(''' WITH RECURSIVE timeseries(d) AS ( @@ -26,6 +28,7 @@ def get_all_monthly_analytics(end_date:str = 'now'): counter_id, sum(increment) as count FROM entries + WHERE user_id = :user_id group by counter_id, strftime('%m', timestamp), strftime('%Y', timestamp) ) select @@ -38,12 +41,13 @@ def get_all_monthly_analytics(end_date:str = 'now'): left outer join stats t on m.m = t.m and m.y = t.y left join counters c on t.counter_id = c.id GROUP by m.m, m.y - ''', params={"end_date": end_date}, ttl=0) + ''', params={"end_date": end_date, "user_id": user_id}) except Exception as e: logger.error(e) return None def get_monthly_analytics(counter_id:int, end_date:str = 'now'): + user_id = int(st.session_state.user_id) try: return connection().query(''' WITH RECURSIVE timeseries(d) AS ( @@ -66,6 +70,7 @@ def get_monthly_analytics(counter_id:int, end_date:str = 'now'): sum(increment) as count FROM entries where counter_id = :id + and user_id = :user_id group by strftime('%m', timestamp), strftime('%Y', timestamp) ) SELECT @@ -73,7 +78,7 @@ def get_monthly_analytics(counter_id:int, end_date:str = 'now'): coalesce(s.count, 0) as count FROM months as m LEFT JOIN stats as s on s.m = m.m and s.y = m.y - ''', params={'id': counter_id, "end_date": end_date}, ttl=0) + ''', params={'id': counter_id, "end_date": end_date, "user_id": user_id}) except Exception as e: logger.error(e) return None \ No newline at end of file diff --git a/app/queries/user.py b/app/queries/user.py new file mode 100644 index 0000000..8da0d1d --- /dev/null +++ b/app/queries/user.py @@ -0,0 +1,67 @@ +import logging + +import streamlit as st +from sqlalchemy.sql import text +from streamlit.user_info import UserInfoProxy + +from queries.connection import connection + +logger = logging.getLogger(__name__) + +def find_user_by_oidc_id(oidc_user_id): + return connection().query('SELECT * FROM users WHERE oidc_user_id = :id', params={'id': oidc_user_id}) + + +def find_user_by_email(email): + return connection().query('SELECT * FROM users WHERE email = :email', params={'email': email}) + + +def find_default_user(): + return find_user_by_email('default') + + +def update_default_user(email, name, oidc_user_id): + with connection().session as session: + try: + query = text("UPDATE users SET email = :email, name = :name, oidc_user_id = :user_id WHERE email = 'default'") + session.execute(query, {'email': email, 'name': name, 'user_id': oidc_user_id}) + except Exception as e: + session.rollback() + raise e + + +def create_user(email, name, oidc_user_id): + with connection().session as session: + try: + logger.info("Creating new user %s", email) + query = text('INSERT INTO users (email, name, oidc_user_id) VALUES (:email, :name, :user_id)') + session.execute(query, {'email': email, 'name': name, 'user_id': oidc_user_id}) + return connection().query('SELECT * FROM users WHERE oidc_user_id = :id', params={'id': oidc_user_id}) + except Exception as e: + session.rollback() + raise e + + +def set_user_in_session(user: UserInfoProxy): + email = user.email + user_id = user.sub + if hasattr(user, 'name'): + name = user.name + else: + name = None + + user_entity = find_user_by_oidc_id(user_id) + if user_entity.empty: + user_entity = find_user_by_email(email) + if user_entity.empty: + user_entity = find_default_user() + if user_entity.empty: + user_entity = create_user(email, name, user_id) + else: + update_default_user(email, name, user_id) + user_entity = find_user_by_oidc_id(user_id) + + st.session_state.user_id = user_entity["id"][0] + st.session_state.user_name = user_entity["name"][0] + st.session_state.user_email = user_entity["email"][0] + st.session_state.user_external_id = user_entity["oidc_user_id"][0] \ No newline at end of file diff --git a/app/queries/weekly_stats.py b/app/queries/weekly_stats.py index 52b494a..2eea973 100644 --- a/app/queries/weekly_stats.py +++ b/app/queries/weekly_stats.py @@ -1,9 +1,11 @@ import logging +import streamlit as st from queries.connection import connection logger = logging.getLogger(__name__) def get_all_weekly_analytics(end_date:str = 'now'): + user_id = int(st.session_state.user_id) try: return connection().query(''' WITH RECURSIVE timeseries(d) AS ( @@ -23,6 +25,7 @@ def get_all_weekly_analytics(end_date:str = 'now'): counter_id, sum(increment) as count FROM entries + WHERE user_id = :user_id group by counter_id, strftime('%W', timestamp) ) select @@ -35,12 +38,13 @@ def get_all_weekly_analytics(end_date:str = 'now'): left outer join stats t on s.w = t.w left join counters c on t.counter_id = c.id GROUP by s.w - ''', params={"end_date": end_date}, ttl=0) + ''', params={"end_date": end_date, "user_id": user_id}) except Exception as e: logger.error(e) return None def get_weekly_analytics(counter_id:int, end_date:str = 'now'): + user_id = int(st.session_state.user_id) try: return connection().query(''' WITH RECURSIVE timeseries(d) AS ( @@ -60,6 +64,7 @@ def get_weekly_analytics(counter_id:int, end_date:str = 'now'): sum(increment) as count FROM entries where counter_id = :id + and user_id = :user_id group by strftime('%W', timestamp) ) SELECT @@ -67,7 +72,7 @@ def get_weekly_analytics(counter_id:int, end_date:str = 'now'): coalesce(s.count, 0) as count FROM weeks as w LEFT JOIN stats as s on s.w = w.w - ''', params={'id': counter_id, "end_date": end_date}, ttl=0) + ''', params={'id': counter_id, "end_date": end_date, "user_id": user_id}) except Exception as e: logger.error(e) return None \ No newline at end of file diff --git a/app/queries/yearly_stats.py b/app/queries/yearly_stats.py index abb7df1..8421025 100644 --- a/app/queries/yearly_stats.py +++ b/app/queries/yearly_stats.py @@ -1,9 +1,11 @@ import logging +import streamlit as st from queries.connection import connection logger = logging.getLogger(__name__) def get_all_yearly_analytics(end_date:str = 'now'): + user_id = int(st.session_state.user_id) try: return connection().query(''' WITH RECURSIVE timeseries(d) AS ( @@ -23,6 +25,7 @@ def get_all_yearly_analytics(end_date:str = 'now'): counter_id, sum(increment) as count FROM entries + WHERE user_id = :user_id group by counter_id, strftime('%Y', timestamp) ) select @@ -35,12 +38,13 @@ def get_all_yearly_analytics(end_date:str = 'now'): left outer join stats t on y.y = t.y left join counters c on t.counter_id = c.id GROUP by y.y - ''', params={"end_date": end_date}, ttl=0) + ''', params={"end_date": end_date, "user_id": user_id}) except Exception as e: logger.error(e) return None def get_yearly_analytics(counter_id:int, end_date:str = 'now'): + user_id = int(st.session_state.user_id) try: return connection().query(''' WITH RECURSIVE timeseries(d) AS ( @@ -60,6 +64,7 @@ def get_yearly_analytics(counter_id:int, end_date:str = 'now'): sum(increment) as count FROM entries where counter_id = :id + and user_id = :user_id group by strftime('%Y', timestamp) ) SELECT @@ -67,7 +72,7 @@ def get_yearly_analytics(counter_id:int, end_date:str = 'now'): coalesce(s.count, 0) as count FROM years as m LEFT JOIN stats as s on s.y = m.y - ''', params={'id': counter_id, "end_date": end_date}, ttl=0) + ''', params={'id': counter_id, "end_date": end_date, "user_id":user_id}) except Exception as e: logger.error(e) return None \ No newline at end of file diff --git a/app/streamlit_app.py b/app/streamlit_app.py index 12f59d3..6876d1d 100644 --- a/app/streamlit_app.py +++ b/app/streamlit_app.py @@ -1,21 +1,37 @@ import streamlit as st -import logging +from streamlit import dialog +import queries.user as user_queries + from logger import init_logger from styles import init_styles init_logger() init_styles() -if hasattr(st, 'user') and hasattr(st.user, 'is_logged_in'): - if not st.user.is_logged_in: - with st.container(width="stretch", height="stretch", horizontal_alignment="center"): - st.title("Daily Counter", width="stretch", text_alignment="center") - st.text("Please log in to use this app", width="stretch", text_alignment="center") - st.space() - if st.button("Log in"): - st.login() +is_login_enabled = hasattr(st, 'user') +is_logged_in = is_login_enabled and hasattr(st.user, 'is_logged_in') and st.user.is_logged_in + +if is_logged_in: + user_queries.set_user_in_session(st.user) +else: + st.session_state.user_id = 1 # default user + +if is_login_enabled and not is_logged_in: + with st.container(width="stretch", height="stretch", horizontal_alignment="center"): + st.title("Daily Counter", width="stretch", text_alignment="center") + st.text("Please log in to use this app", width="stretch", text_alignment="center") + st.space() + if st.button("Log in"): + st.login() + else: counters = st.Page("pages/counters.py", title="Counters", icon=":material/update:") stats = st.Page("pages/stats.py", title="Statistics", icon=":material/chart_data:") - pg = st.navigation(position="top", pages=[counters, stats]) - pg.run() \ No newline at end of file + logoutPage = st.Page(st.logout, title="Logout", icon=":material/logout:") + + pages = [counters, stats] + if is_login_enabled: + pages = pages + [logoutPage] + + pg = st.navigation(position="top", pages=pages) + pg.run() diff --git a/css/theme.css b/css/theme.css index de05c2f..cf3e648 100644 --- a/css/theme.css +++ b/css/theme.css @@ -1,5 +1,5 @@ #MainMenu { - #display: none; + display: none; } .stApp { min-width: 360px; @@ -54,3 +54,19 @@ visibility: hidden; } +div:has(> .stToolbarActions) { + display: none; +} +.rc-overflow > .rc-overflow-item { + flex: 1; +} +.rc-overflow > .rc-overflow-item:nth-child(3) > div { + margin-left: auto; + width: fit-content; +} +.rc-overflow > .rc-overflow-item:nth-child(3) > div > a { + padding: 0; +} +.rc-overflow > .rc-overflow-item-rest { + display: none; +} diff --git a/migrations/versions/20260427172417_user_id.py b/migrations/versions/20260427172417_user_id.py new file mode 100644 index 0000000..deafd93 --- /dev/null +++ b/migrations/versions/20260427172417_user_id.py @@ -0,0 +1,56 @@ +"""add user id + +Revision ID: d9faf8fb8642 +Revises: 4ee21f978e6c +Create Date: 2026-04-27 17:24:17.892586 + +""" +from typing import Sequence, Union + +from alembic import op +import sqlalchemy as sa + +# revision identifiers, used by Alembic. +revision: str = 'd9faf8fb8642' +down_revision: Union[str, Sequence[str], None] = '4ee21f978e6c' +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + users = op.create_table( + "users", + sa.Column("id", sa.Integer, primary_key=True, autoincrement=True), + sa.Column("email", sa.String, nullable=False), + sa.Column("name", sa.String), + sa.Column("oidc_user_id", sa.Integer) + ) + + op.bulk_insert(users, [ { "email": "default" } ]) + + with op.batch_alter_table("counters") as batch_op: + batch_op.add_column(sa.Column("user_id", sa.Integer, insert_default=1)) + batch_op.create_foreign_key("fk_counters_user_id", + referent_table="users", + remote_cols=['id'], + local_cols=['user_id']) + + with op.batch_alter_table("entries") as batch_op: + batch_op.add_column(sa.Column("user_id", sa.Integer, insert_default=1)) + batch_op.create_foreign_key("fk_entries_user_id", + referent_table="users", + remote_cols=['id'], + local_cols=['user_id']) + + +def downgrade() -> None: + with op.batch_alter_table("counters") as batch_op: + batch_op.drop_constraint("fk_counters_user_id", type_="foreignkey") + batch_op.drop_column('user_id') + + with op.batch_alter_table("entries") as batch_op: + batch_op.drop_constraint('fk_entries_user_id', type_='foreignkey') + batch_op.drop_column('user_id') + + op.drop_table("users") + diff --git a/tests/database/user_db_test.py b/tests/database/user_db_test.py new file mode 100644 index 0000000..a8e9e42 --- /dev/null +++ b/tests/database/user_db_test.py @@ -0,0 +1,55 @@ +import streamlit + +import queries.user as user + +def test_get_default_user(): + users = user.find_default_user() + assert len(users) == 1 + assert users["email"][0] == "default" + +def test_update_default_user_and_find_user(): + user.update_default_user(email="test@testbase.com", name="Test User", oidc_user_id="1111-2222-3333") + + users = user.find_default_user() + assert len(users) == 0 + + users = user.find_user_by_oidc_id("1111-2222-3333") + assert len(users) == 1 + assert users["email"][0] == "test@testbase.com" + assert users["name"][0] == "Test User" + assert users["oidc_user_id"][0] == "1111-2222-3333" + + users = user.find_user_by_email("test@testbase.com") + assert len(users) == 1 + assert users["email"][0] == "test@testbase.com" + assert users["name"][0] == "Test User" + assert users["oidc_user_id"][0] == "1111-2222-3333" + +def test_add_user(): + user.create_user(email="test@testbase.com", name="Test User", oidc_user_id="333-4444-5555") + + users = user.find_user_by_oidc_id("333-4444-5555") + assert len(users) == 1 + assert users["email"][0] == "test@testbase.com" + assert users["name"][0] == "Test User" + assert users["oidc_user_id"][0] == "333-4444-5555" + + users = user.find_user_by_email("test@testbase.com") + assert len(users) == 1 + assert users["email"][0] == "test@testbase.com" + assert users["name"][0] == "Test User" + assert users["oidc_user_id"][0] == "333-4444-5555" + +def test_update_user_in_session(): + userInfo = lambda: None + userInfo.email ="test@testbase.com" + userInfo.name = "Test User" + userInfo.sub = "1111-2222-3333" + + user.set_user_in_session(userInfo) + + state = streamlit.session_state + assert state.user_id == 1 + assert state.user_name == userInfo.name + assert state.user_email == userInfo.email + assert state.user_external_id == userInfo.sub