Coverage for toardb / auth_user / crud.py: 34%
145 statements
« prev ^ index » next coverage.py v7.14.0, created at 2026-05-12 12:51 +0000
« prev ^ index » next coverage.py v7.14.0, created at 2026-05-12 12:51 +0000
1from sqlalchemy.orm import Session
2from fastapi import Request, Header, HTTPException
3from starlette.datastructures import QueryParams
5from typing import List
6from datetime import datetime
7import requests
8import json
10from . import models
11from toardb.utils.settings import userinfo_endpoint, request_limitations, gridded_product_units
12from toardb.utils.database import get_db, ToarDbSession
def get_user_role(userinfo, ltoken: bool = False):
    """Derive the TOAR access role of a user from their AAI userinfo.

    Possible results:

    * ``"anonymous"``       -- ``ltoken`` is false; ``userinfo`` is not inspected.
    * ``"power-toar-user"`` -- ``ltoken`` is true and the power-user entitlement
      is listed (checked first, so it wins over ``"registered"``).
    * ``"registered"``      -- ``ltoken`` is true and the TOAR-data entitlement
      is listed.
    * ``"logged_in"``       -- ``ltoken`` is true but neither entitlement is listed.

    Args:
        userinfo: mapping from the userinfo endpoint (or a placeholder); only its
            optional ``"eduperson_entitlement"`` entry is consulted.
        ltoken: whether the caller is treated as authenticated.

    Returns:
        The role name.
    """
    # Entitlements only matter for authenticated callers.
    if not ltoken:
        return "anonymous"
    if "eduperson_entitlement" not in userinfo:
        return "logged_in"
    entitlements = userinfo["eduperson_entitlement"]
    if "urn:geant:helmholtz.de:res:toar-data:power-toar-user#login.helmholtz.de" in entitlements:
        return "power-toar-user"
    if "urn:geant:helmholtz.de:res:toar-data#login.helmholtz.de" in entitlements:
        return "registered"
    return "logged_in"
def get_eduperson_and_roles(request: Request, db: Session = None, eduperson_unique_id: str = None, DoIncr: List[int]=[0,0]):
    """Resolve the caller's identity, role and quota counters.

    Thin wrapper around ``_get_eduperson_and_roles``. If no session is passed,
    a fresh ``ToarDbSession`` is opened as a context manager for the duration
    of the call. Otherwise the given session is used as-is.

    ``DoIncr`` is only read downstream, never mutated, so the shared list
    default is not a hazard here.
    """
    if db is not None:
        return _get_eduperson_and_roles(request, db, eduperson_unique_id, DoIncr)
    with ToarDbSession() as session:
        return _get_eduperson_and_roles(request, session, eduperson_unique_id, DoIncr)
def _get_eduperson_and_roles(request: Request, db: Session, eduperson_unique_id: str, DoIncr: List[int]):
    """Identify the caller, determine their role and update their usage counters.

    Identification:

    * If the request carries an ``AccessToken`` header, the token is sent as a
      Bearer token to ``userinfo_endpoint``. Unless that answers 401, the
      returned JSON replaces the placeholder ``userinfo`` and supplies
      ``eduperson_unique_id``, overriding the argument.
    * Without a token, the ``eduperson_unique_id`` argument is used. If it is
      ``None`` as well, the caller is anonymous.

    For non-anonymous callers the ``models.AuthUser`` row is loaded. If no row
    exists, one is created from ``userinfo`` when a token was given. Without a
    token, an ``HTTPException(406)`` is raised instead.

    Args:
        request: incoming request; only its headers are read.
        db: session used for the lookup, insert and counter update.
        eduperson_unique_id: identifier used when no token is present.
        DoIncr: ``[kind, amount]``. Kind 1 adds ``amount`` to
            ``num_timeseries`` and kind 2 adds it to ``num_gridded``; any
            other kind leaves the counters unchanged.

    Returns:
        On a 401 from the userinfo endpoint: a dict containing only
        ``userinfo`` (the placeholder) and ``status_code``.

        Otherwise, a dict containing:

        * ``num_timeseries`` and ``num_gridded`` (after the increment),
        * ``role``,
        * every entry of ``request_limitations[role]``,
        * ``status_code`` and ``userinfo``.

        Users created during this call additionally carry
        ``eduperson_unique_id``, ``email`` and ``username``.

    Raises:
        HTTPException: 406 if no token is given and ``eduperson_unique_id``
            is not found in the database.
    """
    # Do not use underscores; they are not valid in header attributes!
    role = 'unknown'
    status_code = 200
    # Placeholder userinfo, used when no token is presented (and returned on a 401).
    userinfo = {}
    userinfo['name'] = 'fake name'
    userinfo['email'] = 'fakeEmail@fake.com'
    person_dict = {}
    access_token = request.headers.get('AccessToken')
    ltoken = access_token is not None
    if ltoken:
        # NOTE(review): no timeout is passed, so an unresponsive userinfo
        # endpoint blocks this request indefinitely -- consider adding one.
        us_info = requests.get(userinfo_endpoint, headers={'Authorization': f"Bearer {access_token}"})
        status_code = us_info.status_code
        # NOTE(review): every status except 401 (e.g. 403 or 5xx) is treated as
        # success here; .json() or the 'eduperson_unique_id' lookup may then raise.
        if status_code != 401:
            userinfo = us_info.json()
            eduperson_unique_id = userinfo['eduperson_unique_id']
        else:
            # Invalid token: report the status without role or limits.
            person_dict['userinfo'] = userinfo
            person_dict['status_code'] = status_code
            return person_dict
    else:
        if eduperson_unique_id is None:
            role = 'anonymous'
            person_dict["num_timeseries"] = 0
            person_dict["num_gridded"] = 0
    if role != "anonymous":
        db_person = db.query(models.AuthUser).filter(models.AuthUser.eduperson_unique_id == eduperson_unique_id).first()
        if db_person:
            person_dict["num_timeseries"] = db_person.num_timeseries
            person_dict["num_gridded"] = db_person.num_gridded
        else:
            if ltoken:
                # First request with a valid token: register the user with zeroed counters.
                person_dict["eduperson_unique_id"] = userinfo['eduperson_unique_id']
                person_dict["email"] = userinfo['email']
                person_dict["username"] = userinfo['name']
                person_dict["num_timeseries"] = 0
                person_dict["num_gridded"] = 0
                db_person = models.AuthUser(**person_dict)
                db.add(db_person)
                db.commit()
                db.refresh(db_person)
            else:
                raise HTTPException(406, f"Person {eduperson_unique_id} not known")
    # A caller identified by id (from the token or the argument) is treated as
    # authenticated for role derivation. Without a token, userinfo is still the
    # placeholder without entitlements.
    if eduperson_unique_id is not None:
        ltoken = True
    role = get_user_role(userinfo, ltoken)
    person_dict["role"] = role
    person_dict.update(request_limitations[role])
    match DoIncr[0]:
        case 1:
            person_dict["num_timeseries"] += DoIncr[1]
        case 2:
            person_dict["num_gridded"] += DoIncr[1]
    # db_person is bound here: get_user_role yields "anonymous" only when
    # ltoken is False, i.e. exactly when the lookup above was skipped.
    if role != "anonymous":
        match DoIncr[0]:
            case 1:
                db_person.num_timeseries += DoIncr[1]
            case 2:
                db_person.num_gridded += DoIncr[1]
        db.add(db_person)
        db.commit()
    person_dict['status_code'] = status_code
    person_dict['userinfo'] = userinfo
    return person_dict
def count_year_intervals(daterange, interval_length=None):
    """Count how many year-intervals a date range spans, rounding up.

    Args:
        daterange: string ``"<start>,<end>"``. Both parts must be ISO-8601
            dates or datetimes accepted by ``datetime.fromisoformat``;
            surrounding whitespace is ignored.
        interval_length: length of one interval in years. Defaults to
            ``gridded_product_units["years"]`` from the settings, which keeps
            the previous behaviour for existing callers.

    Returns:
        ``int``: the span in years (``days / 365.25``) divided by
        ``interval_length``, with any partial interval counted as a whole one.

    Raises:
        ValueError: if ``daterange`` does not contain exactly one comma, or a
            part is not a valid ISO date.
    """
    if interval_length is None:
        interval_length = gridded_product_units["years"]
    start_str, end_str = daterange.split(",")
    start = datetime.fromisoformat(start_str.strip())
    end = datetime.fromisoformat(end_str.strip())

    # Average Gregorian year length. The end date itself is not counted as an
    # extra day, so an inclusive range such as "2000-01-01,2000-12-31" is just
    # under one year.
    total_years = (end - start).days / 365.25
    # Ceiling division expressed with // and %: a partial interval costs a full one.
    intervals = int(total_years // interval_length)
    if total_years % interval_length > 0:
        intervals += 1
    return intervals
def modify_query_params(query_params,
                        to_remove=("fields", "daterange", "format", "statistics", "flags",
                                   "sampling", "crops", "data_quality_flags", "lat_res", "lon_res",
                                   "statistic", "metadata_scheme", "min_data_capture", "merged",
                                   "data_aggregation_mode", "quantiles", "num_samples", "seasons",
                                   "method", "datetime"),
                        to_add=("limit", "None")):
    """Return a copy of *query_params* reduced to a timeseries-ID search.

    Steps:

    1. Every key listed in ``to_remove`` is dropped (all occurrences).
    2. A single ``fields=id`` pair is appended.
    3. ``to_add`` is appended unless a parameter named ``to_add[0]`` is
       already present. By default this adds ``limit=None`` only when the
       caller did not set a limit.

    Args:
        query_params: object providing ``multi_items()``, e.g. a Starlette
            ``QueryParams``. It is not modified.
        to_remove: keys to drop. This is an immutable tuple (previously a list)
            so the shared default cannot be mutated between calls; any
            container supporting ``in`` is accepted.
        to_add: ``(key, value)`` pair appended when ``key`` is missing.

    Returns:
        A new ``QueryParams`` built from the filtered pairs.
    """
    filtered_items = [(key, value) for key, value in query_params.multi_items()
                      if key not in to_remove]
    # Callers such as count_timeseries_intervals only read each result's 'id'.
    filtered_items.append(("fields", "id"))
    if not any(key == to_add[0] for key, _ in filtered_items):
        filtered_items.append(to_add)
    return QueryParams(filtered_items)
def pop_user_role(query_params, to_remove=("user_role",)):
    """Return a copy of *query_params* without the keys in ``to_remove``.

    Args:
        query_params: object providing ``multi_items()``, e.g. a Starlette
            ``QueryParams``. It is not modified.
        to_remove: keys to drop (all occurrences). The default is an immutable
            tuple rather than a shared mutable list.

    Returns:
        A new ``QueryParams`` built from the remaining pairs, in their original order.
    """
    return QueryParams([(key, value) for key, value in query_params.multi_items()
                        if key not in to_remove])
def count_timeseries_intervals(query_params, db):
    """Count the timeseries matched by a query and group them into intervals.

    The query is first reduced with ``modify_query_params`` to a search that
    returns only timeseries IDs.

    Args:
        query_params: request query parameters (``QueryParams``-like).
        db: database session passed to ``search_all``.

    Returns:
        Tuple ``(total_timeseries, intervals, min_id, max_id)``:

        * ``total_timeseries``: the number of matches.
        * ``intervals``: that number divided by
          ``gridded_product_units["timeseries"]``, rounded up.
        * ``min_id``/``max_id``: the smallest and largest matching ID, both 0
          when nothing matches.

    Raises:
        HTTPException: 406 when ``search_all`` does not return a list
            (presumably an error response for unknown filters).
    """
    # Imported here rather than at module level, presumably to avoid a
    # circular import between the auth_user and timeseries packages.
    from toardb.timeseries.crud import search_all

    interval_length = gridded_product_units["timeseries"]
    filtered = modify_query_params(query_params)
    ret = search_all(db, path_params='', query_params=filtered)
    if not isinstance(ret, list):
        raise HTTPException(406, f"Got unknown filters -- some are not known: {filtered}")

    ids = [entry['id'] for entry in ret]
    min_id = min(ids, default=0)
    max_id = max(ids, default=0)
    total_timeseries = len(ret)
    # Ceiling division: a partially filled interval counts as a whole one.
    intervals = int(total_timeseries // interval_length)
    if total_timeseries % interval_length > 0:
        intervals += 1
    return total_timeseries, intervals, min_id, max_id
def determine_increments(request: Request, db, user_role=None):
    """Work out how much quota a request would consume and whether it is allowed.

    The ``user_role`` query parameter is stripped before the remaining
    parameters are used for the timeseries search.

    Args:
        request: incoming request; only ``request.query_params`` is read.
        db: database session handed to the timeseries search.
        user_role: role of the caller. Only ``"anonymous"`` triggers the
            timeseries-ID range check.

    Returns:
        dict with the following keys:

        * ``num_timeseries``: the number of matching timeseries.
        * ``year_faktor``: 1 if a ``datetime`` parameter is given; otherwise
          the number of year intervals in ``daterange`` (default
          ``"1970-01-01,2025-12-31"``).
        * ``timeseries_faktor``: the number of timeseries intervals.
        * ``num_gridded``: ``year_faktor * timeseries_faktor``.
        * ``in_allowed_tsrange``: whether all matched IDs lie within the
          anonymous ``min_tsid``..``max_tsid`` range.
        * ``access``: 200, or 401 for an anonymous caller outside that range.
        * ``message``: a human-readable explanation of ``access``.
    """
    qps = pop_user_role(request.query_params)
    daterange = qps.get("daterange")
    if daterange is None:
        daterange = "1970-01-01,2025-12-31"
    # Not named ``datetime``: a local of that name would shadow the imported
    # ``datetime`` class for the whole function body.
    datetime_param = qps.get("datetime")
    year_faktor = count_year_intervals(daterange) if datetime_param is None else 1
    num_timeseries, timeseries_faktor, min_id, max_id = count_timeseries_intervals(qps, db)

    anonymous_limits = request_limitations["anonymous"]
    in_allowed_tsrange = (anonymous_limits["min_tsid"] <= min_id
                          and max_id <= anonymous_limits["max_tsid"])

    access = 200
    message = "Authorized access"
    if user_role == "anonymous" and not in_allowed_tsrange:
        access = 401
        message = ("Request contains timeseries whose IDs are not in the range of "
                   f"{anonymous_limits['min_tsid']} to "
                   f"{anonymous_limits['max_tsid']}.")

    return {"num_timeseries": num_timeseries,
            "year_faktor": year_faktor,
            "timeseries_faktor": timeseries_faktor,
            "num_gridded": year_faktor * timeseries_faktor,
            "in_allowed_tsrange": in_allowed_tsrange,
            "access": access,
            "message": message}