Add user specific profiles
Some checks failed
Run Tests / run-tests (push) Failing after 58s

This commit is contained in:
2026-04-28 21:04:52 +02:00
parent f750cfa8e1
commit bd9ff7191a
12 changed files with 294 additions and 36 deletions

View File

@@ -29,7 +29,10 @@ if "counter_id" in st.query_params.keys():
with st.container(horizontal_alignment="right", vertical_alignment="bottom", horizontal=True): with st.container(horizontal_alignment="right", vertical_alignment="bottom", horizontal=True):
st.header('Counter: ' + df['name']) st.header('Counter: ' + df['name'])
selection = st.segmented_control("Time Range", options, selection_mode="single", required=True, default=counter_type.name, label_visibility="hidden") selected = counter_type.name
if selected == CounterType.SIMPLE.name:
selected = CounterType.DAILY.name
selection = st.segmented_control("Time Range", options, selection_mode="single", required=True, default=selected, label_visibility="hidden")
match getattr(CounterType, selection): match getattr(CounterType, selection):
case CounterType.DAILY: case CounterType.DAILY:

View File

@@ -5,7 +5,7 @@ from sqlalchemy.sql import text
from streamlit.connections import BaseConnection from streamlit.connections import BaseConnection
def connection() -> BaseConnection: def connection() -> BaseConnection:
_connection = st.connection("sql", url=getenv('DATABASE_URL')) _connection = st.connection("sql", url=getenv('DATABASE_URL'), ttl=0, autocommit=True)
with _connection.session as configured_session: with _connection.session as configured_session:
configured_session.execute(text('PRAGMA foreign_keys=ON')) configured_session.execute(text('PRAGMA foreign_keys=ON'))
return _connection return _connection

View File

@@ -7,48 +7,73 @@ from enums import CounterType
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
def create_counter(title:str, counter_type:CounterType, counter_color) -> None: def create_counter(title:str, counter_type:CounterType, counter_color) -> None:
logger.info("Adding counter %s", counter_type) user_id = int(st.session_state.user_id)
logger.info("Adding counter %s for user %d", counter_type, user_id)
with connection().session as session: with connection().session as session:
try: try:
query = text('INSERT INTO counters (name, type, color) VALUES (:title, :type, :color)') query = text('INSERT INTO counters (user_id, name, type, color) VALUES (:user, :title, :type, :color)')
session.execute(query, {'title': title, 'type': counter_type, 'color': counter_color}) session.execute(query, {
session.commit() 'user': user_id,
'title': title,
'type': counter_type,
'color': counter_color
})
except Exception as e: except Exception as e:
logger.error(e) logger.error(e)
session.rollback() session.rollback()
def get_counters(): def get_counters():
user_id = int(st.session_state.user_id)
try: try:
return connection().query('SELECT id, name, type, color FROM counters', ttl=0) return connection().query("""
SELECT id, name, type, color
FROM counters
WHERE user_id = :user
""", params={'user': user_id })
except Exception as e: except Exception as e:
logger.error(e) logger.error(e)
return st.dataframe() return st.dataframe()
def increment_counter(counter_id:int) -> None: def increment_counter(counter_id:int) -> None:
logger.info("Incrementing counter %s", counter_id) user_id = int(st.session_state.user_id)
logger.info("Incrementing counter %d for user %d", counter_id, user_id)
with connection().session as session: with connection().session as session:
try: try:
query = text('INSERT INTO entries (counter_id) VALUES (:id)') query = text('INSERT INTO entries (counter_id, user_id) VALUES (:id, :user)')
session.execute(query, {'id': counter_id}) session.execute(query, {
session.commit() 'id': counter_id,
'user': user_id
})
except Exception as e: except Exception as e:
logger.error(e) logger.error(e)
session.rollback() session.rollback()
def remove_counter(counter_id:int) -> None: def remove_counter(counter_id:int) -> None:
logger.info("Removing counter %s", counter_id) user_id = int(st.session_state.user_id)
logger.info("Removing counter %d from user %d", counter_id, user_id)
with connection().session as session: with connection().session as session:
try: try:
query = text('DELETE FROM counters WHERE id = :id') query = text('DELETE FROM counters WHERE id = :id AND user_id = :user')
session.execute(query, {'id': counter_id}) session.execute(query, {
session.commit() 'id': counter_id,
'user': user_id
})
except Exception as e: except Exception as e:
logger.error(e) logger.error(e)
session.rollback() session.rollback()
def get_counter(counter_id:int): def get_counter(counter_id:int):
user_id = int(st.session_state.user_id)
try: try:
return connection().query('SELECT * FROM counters WHERE id = :id', params={'id': counter_id}, ttl=0).iloc[0] counters = connection().query("""
SELECT * FROM counters
WHERE id = :id AND user_id = :user
""", params={ 'id': counter_id, 'user': user_id}
)
if counters.empty:
return None
return counters.iloc[0]
except Exception as e: except Exception as e:
logger.error(e) logger.error(e)
return None return None

View File

@@ -1,9 +1,11 @@
import logging import logging
from queries.connection import connection from queries.connection import connection
import streamlit as st
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
def get_all_daily_analytics(end_date:str = 'now'): def get_all_daily_analytics(end_date:str = 'now'):
user_id = int(st.session_state.user_id)
try: try:
return connection().query(''' return connection().query('''
WITH RECURSIVE timeseries(d) AS ( WITH RECURSIVE timeseries(d) AS (
@@ -19,6 +21,7 @@ def get_all_daily_analytics(end_date:str = 'now'):
counter_id, counter_id,
sum(increment) as count sum(increment) as count
FROM entries FROM entries
WHERE user_id = :user_id
group by counter_id, date(timestamp) group by counter_id, date(timestamp)
) )
select select
@@ -31,13 +34,14 @@ def get_all_daily_analytics(end_date:str = 'now'):
left outer join stats t on s.d = t.d left outer join stats t on s.d = t.d
left join counters c on t.counter_id = c.id left join counters c on t.counter_id = c.id
GROUP by s.d GROUP by s.d
''', params={"end_date": end_date}, ttl=0) ''', params={"end_date": end_date, "user_id": user_id })
except Exception as e: except Exception as e:
logger.error(e) logger.error(e)
return None return None
def get_daily_analytics(counter_id:int, end_date:str = 'now'): def get_daily_analytics(counter_id:int, end_date:str = 'now'):
user_id = int(st.session_state.user_id)
try: try:
return connection().query(''' return connection().query('''
WITH RECURSIVE timeseries(d) AS ( WITH RECURSIVE timeseries(d) AS (
@@ -53,6 +57,7 @@ def get_daily_analytics(counter_id:int, end_date:str = 'now'):
sum(increment) as count sum(increment) as count
FROM entries FROM entries
where counter_id = :id where counter_id = :id
and user_id = :user_id
group by date(timestamp) group by date(timestamp)
) )
SELECT SELECT
@@ -60,7 +65,7 @@ def get_daily_analytics(counter_id:int, end_date:str = 'now'):
coalesce(s.count, 0) as count coalesce(s.count, 0) as count
FROM timeseries as t FROM timeseries as t
LEFT JOIN stats as s on s.d = t.d LEFT JOIN stats as s on s.d = t.d
''', params={'id': counter_id, "end_date": end_date}, ttl=0) ''', params={'id': counter_id, "end_date": end_date, "user_id": user_id})
except Exception as e: except Exception as e:
logger.error(e) logger.error(e)
return None return None

View File

@@ -1,9 +1,11 @@
import logging import logging
from queries.connection import connection from queries.connection import connection
import streamlit as st
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
def get_all_monthly_analytics(end_date:str = 'now'): def get_all_monthly_analytics(end_date:str = 'now'):
user_id = int(st.session_state.user_id)
try: try:
return connection().query(''' return connection().query('''
WITH RECURSIVE timeseries(d) AS ( WITH RECURSIVE timeseries(d) AS (
@@ -26,6 +28,7 @@ def get_all_monthly_analytics(end_date:str = 'now'):
counter_id, counter_id,
sum(increment) as count sum(increment) as count
FROM entries FROM entries
WHERE user_id = :user_id
group by counter_id, strftime('%m', timestamp), strftime('%Y', timestamp) group by counter_id, strftime('%m', timestamp), strftime('%Y', timestamp)
) )
select select
@@ -38,12 +41,13 @@ def get_all_monthly_analytics(end_date:str = 'now'):
left outer join stats t on m.m = t.m and m.y = t.y left outer join stats t on m.m = t.m and m.y = t.y
left join counters c on t.counter_id = c.id left join counters c on t.counter_id = c.id
GROUP by m.m, m.y GROUP by m.m, m.y
''', params={"end_date": end_date}, ttl=0) ''', params={"end_date": end_date, "user_id": user_id})
except Exception as e: except Exception as e:
logger.error(e) logger.error(e)
return None return None
def get_monthly_analytics(counter_id:int, end_date:str = 'now'): def get_monthly_analytics(counter_id:int, end_date:str = 'now'):
user_id = int(st.session_state.user_id)
try: try:
return connection().query(''' return connection().query('''
WITH RECURSIVE timeseries(d) AS ( WITH RECURSIVE timeseries(d) AS (
@@ -66,6 +70,7 @@ def get_monthly_analytics(counter_id:int, end_date:str = 'now'):
sum(increment) as count sum(increment) as count
FROM entries FROM entries
where counter_id = :id where counter_id = :id
and user_id = :user_id
group by strftime('%m', timestamp), strftime('%Y', timestamp) group by strftime('%m', timestamp), strftime('%Y', timestamp)
) )
SELECT SELECT
@@ -73,7 +78,7 @@ def get_monthly_analytics(counter_id:int, end_date:str = 'now'):
coalesce(s.count, 0) as count coalesce(s.count, 0) as count
FROM months as m FROM months as m
LEFT JOIN stats as s on s.m = m.m and s.y = m.y LEFT JOIN stats as s on s.m = m.m and s.y = m.y
''', params={'id': counter_id, "end_date": end_date}, ttl=0) ''', params={'id': counter_id, "end_date": end_date, "user_id": user_id})
except Exception as e: except Exception as e:
logger.error(e) logger.error(e)
return None return None

67
app/queries/user.py Normal file
View File

@@ -0,0 +1,67 @@
import logging
import streamlit as st
from sqlalchemy.sql import text
from streamlit.user_info import UserInfoProxy
from queries.connection import connection
logger = logging.getLogger(__name__)
def find_user_by_oidc_id(oidc_user_id):
return connection().query('SELECT * FROM users WHERE oidc_user_id = :id', params={'id': oidc_user_id})
def find_user_by_email(email):
return connection().query('SELECT * FROM users WHERE email = :email', params={'email': email})
def find_default_user():
return find_user_by_email('default')
def update_default_user(email, name, oidc_user_id):
with connection().session as session:
try:
query = text("UPDATE users SET email = :email, name = :name, oidc_user_id = :user_id WHERE email = 'default'")
session.execute(query, {'email': email, 'name': name, 'user_id': oidc_user_id})
except Exception as e:
session.rollback()
raise e
def create_user(email, name, oidc_user_id):
with connection().session as session:
try:
logger.info("Creating new user %s", email)
query = text('INSERT INTO users (email, name, oidc_user_id) VALUES (:email, :name, :user_id)')
session.execute(query, {'email': email, 'name': name, 'user_id': oidc_user_id})
return connection().query('SELECT * FROM users WHERE oidc_user_id = :id', params={'id': oidc_user_id})
except Exception as e:
session.rollback()
raise e
def set_user_in_session(user: UserInfoProxy):
email = user.email
user_id = user.sub
if hasattr(user, 'name'):
name = user.name
else:
name = None
user_entity = find_user_by_oidc_id(user_id)
if user_entity.empty:
user_entity = find_user_by_email(email)
if user_entity.empty:
user_entity = find_default_user()
if user_entity.empty:
user_entity = create_user(email, name, user_id)
else:
update_default_user(email, name, user_id)
user_entity = find_user_by_oidc_id(user_id)
st.session_state.user_id = user_entity["id"][0]
st.session_state.user_name = user_entity["name"][0]
st.session_state.user_email = user_entity["email"][0]
st.session_state.user_external_id = user_entity["oidc_user_id"][0]

View File

@@ -1,9 +1,11 @@
import logging import logging
import streamlit as st
from queries.connection import connection from queries.connection import connection
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
def get_all_weekly_analytics(end_date:str = 'now'): def get_all_weekly_analytics(end_date:str = 'now'):
user_id = int(st.session_state.user_id)
try: try:
return connection().query(''' return connection().query('''
WITH RECURSIVE timeseries(d) AS ( WITH RECURSIVE timeseries(d) AS (
@@ -23,6 +25,7 @@ def get_all_weekly_analytics(end_date:str = 'now'):
counter_id, counter_id,
sum(increment) as count sum(increment) as count
FROM entries FROM entries
WHERE user_id = :user_id
group by counter_id, strftime('%W', timestamp) group by counter_id, strftime('%W', timestamp)
) )
select select
@@ -35,12 +38,13 @@ def get_all_weekly_analytics(end_date:str = 'now'):
left outer join stats t on s.w = t.w left outer join stats t on s.w = t.w
left join counters c on t.counter_id = c.id left join counters c on t.counter_id = c.id
GROUP by s.w GROUP by s.w
''', params={"end_date": end_date}, ttl=0) ''', params={"end_date": end_date, "user_id": user_id})
except Exception as e: except Exception as e:
logger.error(e) logger.error(e)
return None return None
def get_weekly_analytics(counter_id:int, end_date:str = 'now'): def get_weekly_analytics(counter_id:int, end_date:str = 'now'):
user_id = int(st.session_state.user_id)
try: try:
return connection().query(''' return connection().query('''
WITH RECURSIVE timeseries(d) AS ( WITH RECURSIVE timeseries(d) AS (
@@ -60,6 +64,7 @@ def get_weekly_analytics(counter_id:int, end_date:str = 'now'):
sum(increment) as count sum(increment) as count
FROM entries FROM entries
where counter_id = :id where counter_id = :id
and user_id = :user_id
group by strftime('%W', timestamp) group by strftime('%W', timestamp)
) )
SELECT SELECT
@@ -67,7 +72,7 @@ def get_weekly_analytics(counter_id:int, end_date:str = 'now'):
coalesce(s.count, 0) as count coalesce(s.count, 0) as count
FROM weeks as w FROM weeks as w
LEFT JOIN stats as s on s.w = w.w LEFT JOIN stats as s on s.w = w.w
''', params={'id': counter_id, "end_date": end_date}, ttl=0) ''', params={'id': counter_id, "end_date": end_date, "user_id": user_id})
except Exception as e: except Exception as e:
logger.error(e) logger.error(e)
return None return None

View File

@@ -1,9 +1,11 @@
import logging import logging
import streamlit as st
from queries.connection import connection from queries.connection import connection
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
def get_all_yearly_analytics(end_date:str = 'now'): def get_all_yearly_analytics(end_date:str = 'now'):
user_id = int(st.session_state.user_id)
try: try:
return connection().query(''' return connection().query('''
WITH RECURSIVE timeseries(d) AS ( WITH RECURSIVE timeseries(d) AS (
@@ -23,6 +25,7 @@ def get_all_yearly_analytics(end_date:str = 'now'):
counter_id, counter_id,
sum(increment) as count sum(increment) as count
FROM entries FROM entries
WHERE user_id = :user_id
group by counter_id, strftime('%Y', timestamp) group by counter_id, strftime('%Y', timestamp)
) )
select select
@@ -35,12 +38,13 @@ def get_all_yearly_analytics(end_date:str = 'now'):
left outer join stats t on y.y = t.y left outer join stats t on y.y = t.y
left join counters c on t.counter_id = c.id left join counters c on t.counter_id = c.id
GROUP by y.y GROUP by y.y
''', params={"end_date": end_date}, ttl=0) ''', params={"end_date": end_date, "user_id": user_id})
except Exception as e: except Exception as e:
logger.error(e) logger.error(e)
return None return None
def get_yearly_analytics(counter_id:int, end_date:str = 'now'): def get_yearly_analytics(counter_id:int, end_date:str = 'now'):
user_id = int(st.session_state.user_id)
try: try:
return connection().query(''' return connection().query('''
WITH RECURSIVE timeseries(d) AS ( WITH RECURSIVE timeseries(d) AS (
@@ -60,6 +64,7 @@ def get_yearly_analytics(counter_id:int, end_date:str = 'now'):
sum(increment) as count sum(increment) as count
FROM entries FROM entries
where counter_id = :id where counter_id = :id
and user_id = :user_id
group by strftime('%Y', timestamp) group by strftime('%Y', timestamp)
) )
SELECT SELECT
@@ -67,7 +72,7 @@ def get_yearly_analytics(counter_id:int, end_date:str = 'now'):
coalesce(s.count, 0) as count coalesce(s.count, 0) as count
FROM years as m FROM years as m
LEFT JOIN stats as s on s.y = m.y LEFT JOIN stats as s on s.y = m.y
''', params={'id': counter_id, "end_date": end_date}, ttl=0) ''', params={'id': counter_id, "end_date": end_date, "user_id":user_id})
except Exception as e: except Exception as e:
logger.error(e) logger.error(e)
return None return None

View File

@@ -1,21 +1,37 @@
import streamlit as st import streamlit as st
import logging from streamlit import dialog
import queries.user as user_queries
from logger import init_logger from logger import init_logger
from styles import init_styles from styles import init_styles
init_logger() init_logger()
init_styles() init_styles()
if hasattr(st, 'user') and hasattr(st.user, 'is_logged_in'): is_login_enabled = hasattr(st, 'user')
if not st.user.is_logged_in: is_logged_in = is_login_enabled and hasattr(st.user, 'is_logged_in') and st.user.is_logged_in
with st.container(width="stretch", height="stretch", horizontal_alignment="center"):
st.title("Daily Counter", width="stretch", text_alignment="center") if is_logged_in:
st.text("Please log in to use this app", width="stretch", text_alignment="center") user_queries.set_user_in_session(st.user)
st.space() else:
if st.button("Log in"): st.session_state.user_id = 1 # default user
st.login()
if is_login_enabled and not is_logged_in:
with st.container(width="stretch", height="stretch", horizontal_alignment="center"):
st.title("Daily Counter", width="stretch", text_alignment="center")
st.text("Please log in to use this app", width="stretch", text_alignment="center")
st.space()
if st.button("Log in"):
st.login()
else: else:
counters = st.Page("pages/counters.py", title="Counters", icon=":material/update:") counters = st.Page("pages/counters.py", title="Counters", icon=":material/update:")
stats = st.Page("pages/stats.py", title="Statistics", icon=":material/chart_data:") stats = st.Page("pages/stats.py", title="Statistics", icon=":material/chart_data:")
pg = st.navigation(position="top", pages=[counters, stats]) logoutPage = st.Page(st.logout, title="Logout", icon=":material/logout:")
pg.run()
pages = [counters, stats]
if is_login_enabled:
pages = pages + [logoutPage]
pg = st.navigation(position="top", pages=pages)
pg.run()

View File

@@ -1,5 +1,5 @@
#MainMenu { #MainMenu {
#display: none; display: none;
} }
.stApp { .stApp {
min-width: 360px; min-width: 360px;
@@ -54,3 +54,19 @@
visibility: hidden; visibility: hidden;
} }
div:has(> .stToolbarActions) {
display: none;
}
.rc-overflow > .rc-overflow-item {
flex: 1;
}
.rc-overflow > .rc-overflow-item:nth-child(3) > div {
margin-left: auto;
width: fit-content;
}
.rc-overflow > .rc-overflow-item:nth-child(3) > div > a {
padding: 0;
}
.rc-overflow > .rc-overflow-item-rest {
display: none;
}

View File

@@ -0,0 +1,56 @@
"""add user id
Revision ID: d9faf8fb8642
Revises: 4ee21f978e6c
Create Date: 2026-04-27 17:24:17.892586
"""
from typing import Sequence, Union
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision: str = 'd9faf8fb8642'
down_revision: Union[str, Sequence[str], None] = '4ee21f978e6c'
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
users = op.create_table(
"users",
sa.Column("id", sa.Integer, primary_key=True, autoincrement=True),
sa.Column("email", sa.String, nullable=False),
sa.Column("name", sa.String),
sa.Column("oidc_user_id", sa.Integer)
)
op.bulk_insert(users, [ { "email": "default" } ])
with op.batch_alter_table("counters") as batch_op:
batch_op.add_column(sa.Column("user_id", sa.Integer, insert_default=1))
batch_op.create_foreign_key("fk_counters_user_id",
referent_table="users",
remote_cols=['id'],
local_cols=['user_id'])
with op.batch_alter_table("entries") as batch_op:
batch_op.add_column(sa.Column("user_id", sa.Integer, insert_default=1))
batch_op.create_foreign_key("fk_entries_user_id",
referent_table="users",
remote_cols=['id'],
local_cols=['user_id'])
def downgrade() -> None:
with op.batch_alter_table("counters") as batch_op:
batch_op.drop_constraint("fk_counters_user_id", type_="foreignkey")
batch_op.drop_column('user_id')
with op.batch_alter_table("entries") as batch_op:
batch_op.drop_constraint('fk_entries_user_id', type_='foreignkey')
batch_op.drop_column('user_id')
op.drop_table("users")

View File

@@ -0,0 +1,55 @@
import streamlit
import queries.user as user
def test_get_default_user():
users = user.find_default_user()
assert len(users) == 1
assert users["email"][0] == "default"
def test_update_default_user_and_find_user():
user.update_default_user(email="test@testbase.com", name="Test User", oidc_user_id="1111-2222-3333")
users = user.find_default_user()
assert len(users) == 0
users = user.find_user_by_oidc_id("1111-2222-3333")
assert len(users) == 1
assert users["email"][0] == "test@testbase.com"
assert users["name"][0] == "Test User"
assert users["oidc_user_id"][0] == "1111-2222-3333"
users = user.find_user_by_email("test@testbase.com")
assert len(users) == 1
assert users["email"][0] == "test@testbase.com"
assert users["name"][0] == "Test User"
assert users["oidc_user_id"][0] == "1111-2222-3333"
def test_add_user():
user.create_user(email="test@testbase.com", name="Test User", oidc_user_id="333-4444-5555")
users = user.find_user_by_oidc_id("333-4444-5555")
assert len(users) == 1
assert users["email"][0] == "test@testbase.com"
assert users["name"][0] == "Test User"
assert users["oidc_user_id"][0] == "333-4444-5555"
users = user.find_user_by_email("test@testbase.com")
assert len(users) == 1
assert users["email"][0] == "test@testbase.com"
assert users["name"][0] == "Test User"
assert users["oidc_user_id"][0] == "333-4444-5555"
def test_update_user_in_session():
userInfo = lambda: None
userInfo.email ="test@testbase.com"
userInfo.name = "Test User"
userInfo.sub = "1111-2222-3333"
user.set_user_in_session(userInfo)
state = streamlit.session_state
assert state.user_id == 1
assert state.user_name == userInfo.name
assert state.user_email == userInfo.email
assert state.user_external_id == userInfo.sub