This commit is contained in:
@@ -29,7 +29,10 @@ if "counter_id" in st.query_params.keys():
|
|||||||
|
|
||||||
with st.container(horizontal_alignment="right", vertical_alignment="bottom", horizontal=True):
|
with st.container(horizontal_alignment="right", vertical_alignment="bottom", horizontal=True):
|
||||||
st.header('Counter: ' + df['name'])
|
st.header('Counter: ' + df['name'])
|
||||||
selection = st.segmented_control("Time Range", options, selection_mode="single", required=True, default=counter_type.name, label_visibility="hidden")
|
selected = counter_type.name
|
||||||
|
if selected == CounterType.SIMPLE.name:
|
||||||
|
selected = CounterType.DAILY.name
|
||||||
|
selection = st.segmented_control("Time Range", options, selection_mode="single", required=True, default=selected, label_visibility="hidden")
|
||||||
|
|
||||||
match getattr(CounterType, selection):
|
match getattr(CounterType, selection):
|
||||||
case CounterType.DAILY:
|
case CounterType.DAILY:
|
||||||
|
|||||||
@@ -5,7 +5,7 @@ from sqlalchemy.sql import text
|
|||||||
from streamlit.connections import BaseConnection
|
from streamlit.connections import BaseConnection
|
||||||
|
|
||||||
def connection() -> BaseConnection:
|
def connection() -> BaseConnection:
|
||||||
_connection = st.connection("sql", url=getenv('DATABASE_URL'))
|
_connection = st.connection("sql", url=getenv('DATABASE_URL'), ttl=0, autocommit=True)
|
||||||
with _connection.session as configured_session:
|
with _connection.session as configured_session:
|
||||||
configured_session.execute(text('PRAGMA foreign_keys=ON'))
|
configured_session.execute(text('PRAGMA foreign_keys=ON'))
|
||||||
return _connection
|
return _connection
|
||||||
|
|||||||
@@ -7,48 +7,73 @@ from enums import CounterType
|
|||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
def create_counter(title:str, counter_type:CounterType, counter_color) -> None:
|
def create_counter(title:str, counter_type:CounterType, counter_color) -> None:
|
||||||
logger.info("Adding counter %s", counter_type)
|
user_id = int(st.session_state.user_id)
|
||||||
|
logger.info("Adding counter %s for user %d", counter_type, user_id)
|
||||||
with connection().session as session:
|
with connection().session as session:
|
||||||
try:
|
try:
|
||||||
query = text('INSERT INTO counters (name, type, color) VALUES (:title, :type, :color)')
|
query = text('INSERT INTO counters (user_id, name, type, color) VALUES (:user, :title, :type, :color)')
|
||||||
session.execute(query, {'title': title, 'type': counter_type, 'color': counter_color})
|
session.execute(query, {
|
||||||
session.commit()
|
'user': user_id,
|
||||||
|
'title': title,
|
||||||
|
'type': counter_type,
|
||||||
|
'color': counter_color
|
||||||
|
})
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(e)
|
logger.error(e)
|
||||||
session.rollback()
|
session.rollback()
|
||||||
|
|
||||||
def get_counters():
|
def get_counters():
|
||||||
|
user_id = int(st.session_state.user_id)
|
||||||
try:
|
try:
|
||||||
return connection().query('SELECT id, name, type, color FROM counters', ttl=0)
|
return connection().query("""
|
||||||
|
SELECT id, name, type, color
|
||||||
|
FROM counters
|
||||||
|
WHERE user_id = :user
|
||||||
|
""", params={'user': user_id })
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(e)
|
logger.error(e)
|
||||||
return st.dataframe()
|
return st.dataframe()
|
||||||
|
|
||||||
def increment_counter(counter_id:int) -> None:
|
def increment_counter(counter_id:int) -> None:
|
||||||
logger.info("Incrementing counter %s", counter_id)
|
user_id = int(st.session_state.user_id)
|
||||||
|
logger.info("Incrementing counter %d for user %d", counter_id, user_id)
|
||||||
with connection().session as session:
|
with connection().session as session:
|
||||||
try:
|
try:
|
||||||
query = text('INSERT INTO entries (counter_id) VALUES (:id)')
|
query = text('INSERT INTO entries (counter_id, user_id) VALUES (:id, :user)')
|
||||||
session.execute(query, {'id': counter_id})
|
session.execute(query, {
|
||||||
session.commit()
|
'id': counter_id,
|
||||||
|
'user': user_id
|
||||||
|
})
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(e)
|
logger.error(e)
|
||||||
session.rollback()
|
session.rollback()
|
||||||
|
|
||||||
def remove_counter(counter_id:int) -> None:
|
def remove_counter(counter_id:int) -> None:
|
||||||
logger.info("Removing counter %s", counter_id)
|
user_id = int(st.session_state.user_id)
|
||||||
|
logger.info("Removing counter %d from user %d", counter_id, user_id)
|
||||||
with connection().session as session:
|
with connection().session as session:
|
||||||
try:
|
try:
|
||||||
query = text('DELETE FROM counters WHERE id = :id')
|
query = text('DELETE FROM counters WHERE id = :id AND user_id = :user')
|
||||||
session.execute(query, {'id': counter_id})
|
session.execute(query, {
|
||||||
session.commit()
|
'id': counter_id,
|
||||||
|
'user': user_id
|
||||||
|
})
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(e)
|
logger.error(e)
|
||||||
session.rollback()
|
session.rollback()
|
||||||
|
|
||||||
def get_counter(counter_id:int):
|
def get_counter(counter_id:int):
|
||||||
|
user_id = int(st.session_state.user_id)
|
||||||
try:
|
try:
|
||||||
return connection().query('SELECT * FROM counters WHERE id = :id', params={'id': counter_id}, ttl=0).iloc[0]
|
counters = connection().query("""
|
||||||
|
SELECT * FROM counters
|
||||||
|
WHERE id = :id AND user_id = :user
|
||||||
|
""", params={ 'id': counter_id, 'user': user_id}
|
||||||
|
)
|
||||||
|
if counters.empty:
|
||||||
|
return None
|
||||||
|
|
||||||
|
return counters.iloc[0]
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(e)
|
logger.error(e)
|
||||||
return None
|
return None
|
||||||
|
|||||||
@@ -1,9 +1,11 @@
|
|||||||
import logging
|
import logging
|
||||||
from queries.connection import connection
|
from queries.connection import connection
|
||||||
|
import streamlit as st
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
def get_all_daily_analytics(end_date:str = 'now'):
|
def get_all_daily_analytics(end_date:str = 'now'):
|
||||||
|
user_id = int(st.session_state.user_id)
|
||||||
try:
|
try:
|
||||||
return connection().query('''
|
return connection().query('''
|
||||||
WITH RECURSIVE timeseries(d) AS (
|
WITH RECURSIVE timeseries(d) AS (
|
||||||
@@ -19,6 +21,7 @@ def get_all_daily_analytics(end_date:str = 'now'):
|
|||||||
counter_id,
|
counter_id,
|
||||||
sum(increment) as count
|
sum(increment) as count
|
||||||
FROM entries
|
FROM entries
|
||||||
|
WHERE user_id = :user_id
|
||||||
group by counter_id, date(timestamp)
|
group by counter_id, date(timestamp)
|
||||||
)
|
)
|
||||||
select
|
select
|
||||||
@@ -31,13 +34,14 @@ def get_all_daily_analytics(end_date:str = 'now'):
|
|||||||
left outer join stats t on s.d = t.d
|
left outer join stats t on s.d = t.d
|
||||||
left join counters c on t.counter_id = c.id
|
left join counters c on t.counter_id = c.id
|
||||||
GROUP by s.d
|
GROUP by s.d
|
||||||
''', params={"end_date": end_date}, ttl=0)
|
''', params={"end_date": end_date, "user_id": user_id })
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(e)
|
logger.error(e)
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
def get_daily_analytics(counter_id:int, end_date:str = 'now'):
|
def get_daily_analytics(counter_id:int, end_date:str = 'now'):
|
||||||
|
user_id = int(st.session_state.user_id)
|
||||||
try:
|
try:
|
||||||
return connection().query('''
|
return connection().query('''
|
||||||
WITH RECURSIVE timeseries(d) AS (
|
WITH RECURSIVE timeseries(d) AS (
|
||||||
@@ -53,6 +57,7 @@ def get_daily_analytics(counter_id:int, end_date:str = 'now'):
|
|||||||
sum(increment) as count
|
sum(increment) as count
|
||||||
FROM entries
|
FROM entries
|
||||||
where counter_id = :id
|
where counter_id = :id
|
||||||
|
and user_id = :user_id
|
||||||
group by date(timestamp)
|
group by date(timestamp)
|
||||||
)
|
)
|
||||||
SELECT
|
SELECT
|
||||||
@@ -60,7 +65,7 @@ def get_daily_analytics(counter_id:int, end_date:str = 'now'):
|
|||||||
coalesce(s.count, 0) as count
|
coalesce(s.count, 0) as count
|
||||||
FROM timeseries as t
|
FROM timeseries as t
|
||||||
LEFT JOIN stats as s on s.d = t.d
|
LEFT JOIN stats as s on s.d = t.d
|
||||||
''', params={'id': counter_id, "end_date": end_date}, ttl=0)
|
''', params={'id': counter_id, "end_date": end_date, "user_id": user_id})
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(e)
|
logger.error(e)
|
||||||
return None
|
return None
|
||||||
@@ -1,9 +1,11 @@
|
|||||||
import logging
|
import logging
|
||||||
from queries.connection import connection
|
from queries.connection import connection
|
||||||
|
import streamlit as st
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
def get_all_monthly_analytics(end_date:str = 'now'):
|
def get_all_monthly_analytics(end_date:str = 'now'):
|
||||||
|
user_id = int(st.session_state.user_id)
|
||||||
try:
|
try:
|
||||||
return connection().query('''
|
return connection().query('''
|
||||||
WITH RECURSIVE timeseries(d) AS (
|
WITH RECURSIVE timeseries(d) AS (
|
||||||
@@ -26,6 +28,7 @@ def get_all_monthly_analytics(end_date:str = 'now'):
|
|||||||
counter_id,
|
counter_id,
|
||||||
sum(increment) as count
|
sum(increment) as count
|
||||||
FROM entries
|
FROM entries
|
||||||
|
WHERE user_id = :user_id
|
||||||
group by counter_id, strftime('%m', timestamp), strftime('%Y', timestamp)
|
group by counter_id, strftime('%m', timestamp), strftime('%Y', timestamp)
|
||||||
)
|
)
|
||||||
select
|
select
|
||||||
@@ -38,12 +41,13 @@ def get_all_monthly_analytics(end_date:str = 'now'):
|
|||||||
left outer join stats t on m.m = t.m and m.y = t.y
|
left outer join stats t on m.m = t.m and m.y = t.y
|
||||||
left join counters c on t.counter_id = c.id
|
left join counters c on t.counter_id = c.id
|
||||||
GROUP by m.m, m.y
|
GROUP by m.m, m.y
|
||||||
''', params={"end_date": end_date}, ttl=0)
|
''', params={"end_date": end_date, "user_id": user_id})
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(e)
|
logger.error(e)
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def get_monthly_analytics(counter_id:int, end_date:str = 'now'):
|
def get_monthly_analytics(counter_id:int, end_date:str = 'now'):
|
||||||
|
user_id = int(st.session_state.user_id)
|
||||||
try:
|
try:
|
||||||
return connection().query('''
|
return connection().query('''
|
||||||
WITH RECURSIVE timeseries(d) AS (
|
WITH RECURSIVE timeseries(d) AS (
|
||||||
@@ -66,6 +70,7 @@ def get_monthly_analytics(counter_id:int, end_date:str = 'now'):
|
|||||||
sum(increment) as count
|
sum(increment) as count
|
||||||
FROM entries
|
FROM entries
|
||||||
where counter_id = :id
|
where counter_id = :id
|
||||||
|
and user_id = :user_id
|
||||||
group by strftime('%m', timestamp), strftime('%Y', timestamp)
|
group by strftime('%m', timestamp), strftime('%Y', timestamp)
|
||||||
)
|
)
|
||||||
SELECT
|
SELECT
|
||||||
@@ -73,7 +78,7 @@ def get_monthly_analytics(counter_id:int, end_date:str = 'now'):
|
|||||||
coalesce(s.count, 0) as count
|
coalesce(s.count, 0) as count
|
||||||
FROM months as m
|
FROM months as m
|
||||||
LEFT JOIN stats as s on s.m = m.m and s.y = m.y
|
LEFT JOIN stats as s on s.m = m.m and s.y = m.y
|
||||||
''', params={'id': counter_id, "end_date": end_date}, ttl=0)
|
''', params={'id': counter_id, "end_date": end_date, "user_id": user_id})
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(e)
|
logger.error(e)
|
||||||
return None
|
return None
|
||||||
67
app/queries/user.py
Normal file
67
app/queries/user.py
Normal file
@@ -0,0 +1,67 @@
|
|||||||
|
import logging
|
||||||
|
|
||||||
|
import streamlit as st
|
||||||
|
from sqlalchemy.sql import text
|
||||||
|
from streamlit.user_info import UserInfoProxy
|
||||||
|
|
||||||
|
from queries.connection import connection
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
def find_user_by_oidc_id(oidc_user_id):
|
||||||
|
return connection().query('SELECT * FROM users WHERE oidc_user_id = :id', params={'id': oidc_user_id})
|
||||||
|
|
||||||
|
|
||||||
|
def find_user_by_email(email):
|
||||||
|
return connection().query('SELECT * FROM users WHERE email = :email', params={'email': email})
|
||||||
|
|
||||||
|
|
||||||
|
def find_default_user():
|
||||||
|
return find_user_by_email('default')
|
||||||
|
|
||||||
|
|
||||||
|
def update_default_user(email, name, oidc_user_id):
|
||||||
|
with connection().session as session:
|
||||||
|
try:
|
||||||
|
query = text("UPDATE users SET email = :email, name = :name, oidc_user_id = :user_id WHERE email = 'default'")
|
||||||
|
session.execute(query, {'email': email, 'name': name, 'user_id': oidc_user_id})
|
||||||
|
except Exception as e:
|
||||||
|
session.rollback()
|
||||||
|
raise e
|
||||||
|
|
||||||
|
|
||||||
|
def create_user(email, name, oidc_user_id):
|
||||||
|
with connection().session as session:
|
||||||
|
try:
|
||||||
|
logger.info("Creating new user %s", email)
|
||||||
|
query = text('INSERT INTO users (email, name, oidc_user_id) VALUES (:email, :name, :user_id)')
|
||||||
|
session.execute(query, {'email': email, 'name': name, 'user_id': oidc_user_id})
|
||||||
|
return connection().query('SELECT * FROM users WHERE oidc_user_id = :id', params={'id': oidc_user_id})
|
||||||
|
except Exception as e:
|
||||||
|
session.rollback()
|
||||||
|
raise e
|
||||||
|
|
||||||
|
|
||||||
|
def set_user_in_session(user: UserInfoProxy):
|
||||||
|
email = user.email
|
||||||
|
user_id = user.sub
|
||||||
|
if hasattr(user, 'name'):
|
||||||
|
name = user.name
|
||||||
|
else:
|
||||||
|
name = None
|
||||||
|
|
||||||
|
user_entity = find_user_by_oidc_id(user_id)
|
||||||
|
if user_entity.empty:
|
||||||
|
user_entity = find_user_by_email(email)
|
||||||
|
if user_entity.empty:
|
||||||
|
user_entity = find_default_user()
|
||||||
|
if user_entity.empty:
|
||||||
|
user_entity = create_user(email, name, user_id)
|
||||||
|
else:
|
||||||
|
update_default_user(email, name, user_id)
|
||||||
|
user_entity = find_user_by_oidc_id(user_id)
|
||||||
|
|
||||||
|
st.session_state.user_id = user_entity["id"][0]
|
||||||
|
st.session_state.user_name = user_entity["name"][0]
|
||||||
|
st.session_state.user_email = user_entity["email"][0]
|
||||||
|
st.session_state.user_external_id = user_entity["oidc_user_id"][0]
|
||||||
@@ -1,9 +1,11 @@
|
|||||||
import logging
|
import logging
|
||||||
|
import streamlit as st
|
||||||
from queries.connection import connection
|
from queries.connection import connection
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
def get_all_weekly_analytics(end_date:str = 'now'):
|
def get_all_weekly_analytics(end_date:str = 'now'):
|
||||||
|
user_id = int(st.session_state.user_id)
|
||||||
try:
|
try:
|
||||||
return connection().query('''
|
return connection().query('''
|
||||||
WITH RECURSIVE timeseries(d) AS (
|
WITH RECURSIVE timeseries(d) AS (
|
||||||
@@ -23,6 +25,7 @@ def get_all_weekly_analytics(end_date:str = 'now'):
|
|||||||
counter_id,
|
counter_id,
|
||||||
sum(increment) as count
|
sum(increment) as count
|
||||||
FROM entries
|
FROM entries
|
||||||
|
WHERE user_id = :user_id
|
||||||
group by counter_id, strftime('%W', timestamp)
|
group by counter_id, strftime('%W', timestamp)
|
||||||
)
|
)
|
||||||
select
|
select
|
||||||
@@ -35,12 +38,13 @@ def get_all_weekly_analytics(end_date:str = 'now'):
|
|||||||
left outer join stats t on s.w = t.w
|
left outer join stats t on s.w = t.w
|
||||||
left join counters c on t.counter_id = c.id
|
left join counters c on t.counter_id = c.id
|
||||||
GROUP by s.w
|
GROUP by s.w
|
||||||
''', params={"end_date": end_date}, ttl=0)
|
''', params={"end_date": end_date, "user_id": user_id})
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(e)
|
logger.error(e)
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def get_weekly_analytics(counter_id:int, end_date:str = 'now'):
|
def get_weekly_analytics(counter_id:int, end_date:str = 'now'):
|
||||||
|
user_id = int(st.session_state.user_id)
|
||||||
try:
|
try:
|
||||||
return connection().query('''
|
return connection().query('''
|
||||||
WITH RECURSIVE timeseries(d) AS (
|
WITH RECURSIVE timeseries(d) AS (
|
||||||
@@ -60,6 +64,7 @@ def get_weekly_analytics(counter_id:int, end_date:str = 'now'):
|
|||||||
sum(increment) as count
|
sum(increment) as count
|
||||||
FROM entries
|
FROM entries
|
||||||
where counter_id = :id
|
where counter_id = :id
|
||||||
|
and user_id = :user_id
|
||||||
group by strftime('%W', timestamp)
|
group by strftime('%W', timestamp)
|
||||||
)
|
)
|
||||||
SELECT
|
SELECT
|
||||||
@@ -67,7 +72,7 @@ def get_weekly_analytics(counter_id:int, end_date:str = 'now'):
|
|||||||
coalesce(s.count, 0) as count
|
coalesce(s.count, 0) as count
|
||||||
FROM weeks as w
|
FROM weeks as w
|
||||||
LEFT JOIN stats as s on s.w = w.w
|
LEFT JOIN stats as s on s.w = w.w
|
||||||
''', params={'id': counter_id, "end_date": end_date}, ttl=0)
|
''', params={'id': counter_id, "end_date": end_date, "user_id": user_id})
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(e)
|
logger.error(e)
|
||||||
return None
|
return None
|
||||||
@@ -1,9 +1,11 @@
|
|||||||
import logging
|
import logging
|
||||||
|
import streamlit as st
|
||||||
from queries.connection import connection
|
from queries.connection import connection
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
def get_all_yearly_analytics(end_date:str = 'now'):
|
def get_all_yearly_analytics(end_date:str = 'now'):
|
||||||
|
user_id = int(st.session_state.user_id)
|
||||||
try:
|
try:
|
||||||
return connection().query('''
|
return connection().query('''
|
||||||
WITH RECURSIVE timeseries(d) AS (
|
WITH RECURSIVE timeseries(d) AS (
|
||||||
@@ -23,6 +25,7 @@ def get_all_yearly_analytics(end_date:str = 'now'):
|
|||||||
counter_id,
|
counter_id,
|
||||||
sum(increment) as count
|
sum(increment) as count
|
||||||
FROM entries
|
FROM entries
|
||||||
|
WHERE user_id = :user_id
|
||||||
group by counter_id, strftime('%Y', timestamp)
|
group by counter_id, strftime('%Y', timestamp)
|
||||||
)
|
)
|
||||||
select
|
select
|
||||||
@@ -35,12 +38,13 @@ def get_all_yearly_analytics(end_date:str = 'now'):
|
|||||||
left outer join stats t on y.y = t.y
|
left outer join stats t on y.y = t.y
|
||||||
left join counters c on t.counter_id = c.id
|
left join counters c on t.counter_id = c.id
|
||||||
GROUP by y.y
|
GROUP by y.y
|
||||||
''', params={"end_date": end_date}, ttl=0)
|
''', params={"end_date": end_date, "user_id": user_id})
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(e)
|
logger.error(e)
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def get_yearly_analytics(counter_id:int, end_date:str = 'now'):
|
def get_yearly_analytics(counter_id:int, end_date:str = 'now'):
|
||||||
|
user_id = int(st.session_state.user_id)
|
||||||
try:
|
try:
|
||||||
return connection().query('''
|
return connection().query('''
|
||||||
WITH RECURSIVE timeseries(d) AS (
|
WITH RECURSIVE timeseries(d) AS (
|
||||||
@@ -60,6 +64,7 @@ def get_yearly_analytics(counter_id:int, end_date:str = 'now'):
|
|||||||
sum(increment) as count
|
sum(increment) as count
|
||||||
FROM entries
|
FROM entries
|
||||||
where counter_id = :id
|
where counter_id = :id
|
||||||
|
and user_id = :user_id
|
||||||
group by strftime('%Y', timestamp)
|
group by strftime('%Y', timestamp)
|
||||||
)
|
)
|
||||||
SELECT
|
SELECT
|
||||||
@@ -67,7 +72,7 @@ def get_yearly_analytics(counter_id:int, end_date:str = 'now'):
|
|||||||
coalesce(s.count, 0) as count
|
coalesce(s.count, 0) as count
|
||||||
FROM years as m
|
FROM years as m
|
||||||
LEFT JOIN stats as s on s.y = m.y
|
LEFT JOIN stats as s on s.y = m.y
|
||||||
''', params={'id': counter_id, "end_date": end_date}, ttl=0)
|
''', params={'id': counter_id, "end_date": end_date, "user_id":user_id})
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(e)
|
logger.error(e)
|
||||||
return None
|
return None
|
||||||
@@ -1,21 +1,37 @@
|
|||||||
import streamlit as st
|
import streamlit as st
|
||||||
import logging
|
from streamlit import dialog
|
||||||
|
import queries.user as user_queries
|
||||||
|
|
||||||
from logger import init_logger
|
from logger import init_logger
|
||||||
from styles import init_styles
|
from styles import init_styles
|
||||||
|
|
||||||
init_logger()
|
init_logger()
|
||||||
init_styles()
|
init_styles()
|
||||||
|
|
||||||
if hasattr(st, 'user') and hasattr(st.user, 'is_logged_in'):
|
is_login_enabled = hasattr(st, 'user')
|
||||||
if not st.user.is_logged_in:
|
is_logged_in = is_login_enabled and hasattr(st.user, 'is_logged_in') and st.user.is_logged_in
|
||||||
|
|
||||||
|
if is_logged_in:
|
||||||
|
user_queries.set_user_in_session(st.user)
|
||||||
|
else:
|
||||||
|
st.session_state.user_id = 1 # default user
|
||||||
|
|
||||||
|
if is_login_enabled and not is_logged_in:
|
||||||
with st.container(width="stretch", height="stretch", horizontal_alignment="center"):
|
with st.container(width="stretch", height="stretch", horizontal_alignment="center"):
|
||||||
st.title("Daily Counter", width="stretch", text_alignment="center")
|
st.title("Daily Counter", width="stretch", text_alignment="center")
|
||||||
st.text("Please log in to use this app", width="stretch", text_alignment="center")
|
st.text("Please log in to use this app", width="stretch", text_alignment="center")
|
||||||
st.space()
|
st.space()
|
||||||
if st.button("Log in"):
|
if st.button("Log in"):
|
||||||
st.login()
|
st.login()
|
||||||
|
|
||||||
else:
|
else:
|
||||||
counters = st.Page("pages/counters.py", title="Counters", icon=":material/update:")
|
counters = st.Page("pages/counters.py", title="Counters", icon=":material/update:")
|
||||||
stats = st.Page("pages/stats.py", title="Statistics", icon=":material/chart_data:")
|
stats = st.Page("pages/stats.py", title="Statistics", icon=":material/chart_data:")
|
||||||
pg = st.navigation(position="top", pages=[counters, stats])
|
logoutPage = st.Page(st.logout, title="Logout", icon=":material/logout:")
|
||||||
|
|
||||||
|
pages = [counters, stats]
|
||||||
|
if is_login_enabled:
|
||||||
|
pages = pages + [logoutPage]
|
||||||
|
|
||||||
|
pg = st.navigation(position="top", pages=pages)
|
||||||
pg.run()
|
pg.run()
|
||||||
@@ -1,5 +1,5 @@
|
|||||||
#MainMenu {
|
#MainMenu {
|
||||||
#display: none;
|
display: none;
|
||||||
}
|
}
|
||||||
.stApp {
|
.stApp {
|
||||||
min-width: 360px;
|
min-width: 360px;
|
||||||
@@ -54,3 +54,19 @@
|
|||||||
visibility: hidden;
|
visibility: hidden;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
div:has(> .stToolbarActions) {
|
||||||
|
display: none;
|
||||||
|
}
|
||||||
|
.rc-overflow > .rc-overflow-item {
|
||||||
|
flex: 1;
|
||||||
|
}
|
||||||
|
.rc-overflow > .rc-overflow-item:nth-child(3) > div {
|
||||||
|
margin-left: auto;
|
||||||
|
width: fit-content;
|
||||||
|
}
|
||||||
|
.rc-overflow > .rc-overflow-item:nth-child(3) > div > a {
|
||||||
|
padding: 0;
|
||||||
|
}
|
||||||
|
.rc-overflow > .rc-overflow-item-rest {
|
||||||
|
display: none;
|
||||||
|
}
|
||||||
|
|||||||
56
migrations/versions/20260427172417_user_id.py
Normal file
56
migrations/versions/20260427172417_user_id.py
Normal file
@@ -0,0 +1,56 @@
|
|||||||
|
"""add user id
|
||||||
|
|
||||||
|
Revision ID: d9faf8fb8642
|
||||||
|
Revises: 4ee21f978e6c
|
||||||
|
Create Date: 2026-04-27 17:24:17.892586
|
||||||
|
|
||||||
|
"""
|
||||||
|
from typing import Sequence, Union
|
||||||
|
|
||||||
|
from alembic import op
|
||||||
|
import sqlalchemy as sa
|
||||||
|
|
||||||
|
# revision identifiers, used by Alembic.
|
||||||
|
revision: str = 'd9faf8fb8642'
|
||||||
|
down_revision: Union[str, Sequence[str], None] = '4ee21f978e6c'
|
||||||
|
branch_labels: Union[str, Sequence[str], None] = None
|
||||||
|
depends_on: Union[str, Sequence[str], None] = None
|
||||||
|
|
||||||
|
|
||||||
|
def upgrade() -> None:
|
||||||
|
users = op.create_table(
|
||||||
|
"users",
|
||||||
|
sa.Column("id", sa.Integer, primary_key=True, autoincrement=True),
|
||||||
|
sa.Column("email", sa.String, nullable=False),
|
||||||
|
sa.Column("name", sa.String),
|
||||||
|
sa.Column("oidc_user_id", sa.Integer)
|
||||||
|
)
|
||||||
|
|
||||||
|
op.bulk_insert(users, [ { "email": "default" } ])
|
||||||
|
|
||||||
|
with op.batch_alter_table("counters") as batch_op:
|
||||||
|
batch_op.add_column(sa.Column("user_id", sa.Integer, insert_default=1))
|
||||||
|
batch_op.create_foreign_key("fk_counters_user_id",
|
||||||
|
referent_table="users",
|
||||||
|
remote_cols=['id'],
|
||||||
|
local_cols=['user_id'])
|
||||||
|
|
||||||
|
with op.batch_alter_table("entries") as batch_op:
|
||||||
|
batch_op.add_column(sa.Column("user_id", sa.Integer, insert_default=1))
|
||||||
|
batch_op.create_foreign_key("fk_entries_user_id",
|
||||||
|
referent_table="users",
|
||||||
|
remote_cols=['id'],
|
||||||
|
local_cols=['user_id'])
|
||||||
|
|
||||||
|
|
||||||
|
def downgrade() -> None:
|
||||||
|
with op.batch_alter_table("counters") as batch_op:
|
||||||
|
batch_op.drop_constraint("fk_counters_user_id", type_="foreignkey")
|
||||||
|
batch_op.drop_column('user_id')
|
||||||
|
|
||||||
|
with op.batch_alter_table("entries") as batch_op:
|
||||||
|
batch_op.drop_constraint('fk_entries_user_id', type_='foreignkey')
|
||||||
|
batch_op.drop_column('user_id')
|
||||||
|
|
||||||
|
op.drop_table("users")
|
||||||
|
|
||||||
55
tests/database/user_db_test.py
Normal file
55
tests/database/user_db_test.py
Normal file
@@ -0,0 +1,55 @@
|
|||||||
|
import streamlit
|
||||||
|
|
||||||
|
import queries.user as user
|
||||||
|
|
||||||
|
def test_get_default_user():
|
||||||
|
users = user.find_default_user()
|
||||||
|
assert len(users) == 1
|
||||||
|
assert users["email"][0] == "default"
|
||||||
|
|
||||||
|
def test_update_default_user_and_find_user():
|
||||||
|
user.update_default_user(email="test@testbase.com", name="Test User", oidc_user_id="1111-2222-3333")
|
||||||
|
|
||||||
|
users = user.find_default_user()
|
||||||
|
assert len(users) == 0
|
||||||
|
|
||||||
|
users = user.find_user_by_oidc_id("1111-2222-3333")
|
||||||
|
assert len(users) == 1
|
||||||
|
assert users["email"][0] == "test@testbase.com"
|
||||||
|
assert users["name"][0] == "Test User"
|
||||||
|
assert users["oidc_user_id"][0] == "1111-2222-3333"
|
||||||
|
|
||||||
|
users = user.find_user_by_email("test@testbase.com")
|
||||||
|
assert len(users) == 1
|
||||||
|
assert users["email"][0] == "test@testbase.com"
|
||||||
|
assert users["name"][0] == "Test User"
|
||||||
|
assert users["oidc_user_id"][0] == "1111-2222-3333"
|
||||||
|
|
||||||
|
def test_add_user():
|
||||||
|
user.create_user(email="test@testbase.com", name="Test User", oidc_user_id="333-4444-5555")
|
||||||
|
|
||||||
|
users = user.find_user_by_oidc_id("333-4444-5555")
|
||||||
|
assert len(users) == 1
|
||||||
|
assert users["email"][0] == "test@testbase.com"
|
||||||
|
assert users["name"][0] == "Test User"
|
||||||
|
assert users["oidc_user_id"][0] == "333-4444-5555"
|
||||||
|
|
||||||
|
users = user.find_user_by_email("test@testbase.com")
|
||||||
|
assert len(users) == 1
|
||||||
|
assert users["email"][0] == "test@testbase.com"
|
||||||
|
assert users["name"][0] == "Test User"
|
||||||
|
assert users["oidc_user_id"][0] == "333-4444-5555"
|
||||||
|
|
||||||
|
def test_update_user_in_session():
|
||||||
|
userInfo = lambda: None
|
||||||
|
userInfo.email ="test@testbase.com"
|
||||||
|
userInfo.name = "Test User"
|
||||||
|
userInfo.sub = "1111-2222-3333"
|
||||||
|
|
||||||
|
user.set_user_in_session(userInfo)
|
||||||
|
|
||||||
|
state = streamlit.session_state
|
||||||
|
assert state.user_id == 1
|
||||||
|
assert state.user_name == userInfo.name
|
||||||
|
assert state.user_email == userInfo.email
|
||||||
|
assert state.user_external_id == userInfo.sub
|
||||||
Reference in New Issue
Block a user