Coverage for toardb / auth_user / crud.py: 34%

145 statements  

« prev     ^ index     » next       coverage.py v7.14.0, created at 2026-05-12 12:51 +0000

1from sqlalchemy.orm import Session 

2from fastapi import Request, Header, HTTPException 

3from starlette.datastructures import QueryParams 

4 

5from typing import List 

6from datetime import datetime 

7import requests 

8import json 

9 

10from . import models 

11from toardb.utils.settings import userinfo_endpoint, request_limitations, gridded_product_units 

12from toardb.utils.database import get_db, ToarDbSession 

13 

def get_user_role(userinfo, ltoken: bool = False):
    """Derive the role name of a user from the AAI userinfo record.

    Args:
        userinfo: Mapping with the user's attributes; only the optional
            ``"eduperson_entitlement"`` entry is inspected.
        ltoken: Whether the user has authorized themselves to the AAI.

    Returns:
        ``"anonymous"`` when ``ltoken`` is false; otherwise
        ``"power-toar-user"``, ``"registered"`` or ``"logged_in"``,
        depending on which entitlement (if any) is present.
    """
    # Without an authorization nothing else is looked at.
    if not ltoken:
        return "anonymous"

    if "eduperson_entitlement" in userinfo:
        entitlements = userinfo["eduperson_entitlement"]
    else:
        entitlements = ()

    # The power-user entitlement takes precedence over plain registration.
    if "urn:geant:helmholtz.de:res:toar-data:power-toar-user#login.helmholtz.de" in entitlements:
        return "power-toar-user"
    if "urn:geant:helmholtz.de:res:toar-data#login.helmholtz.de" in entitlements:
        return "registered"
    return "logged_in"

30 

def get_eduperson_and_roles(request: Request, db: Session = None, eduperson_unique_id: str = None, DoIncr: List[int] = None):
    """Resolve the requesting person and their role, opening a session if needed.

    Thin wrapper around ``_get_eduperson_and_roles``: when no database
    session is supplied, a ``ToarDbSession`` is opened for the duration of
    the call.

    Args:
        request: Incoming request, passed through unchanged.
        db: Database session to use; ``None`` opens a temporary one.
        eduperson_unique_id: Optional identifier, passed through unchanged.
        DoIncr: Two-element list, passed through unchanged. ``None`` (the
            default) stands for ``[0, 0]``.

    Returns:
        Whatever ``_get_eduperson_and_roles`` returns.
    """
    # A fresh list per call: a list literal as default value would be one
    # object shared by every invocation.
    if DoIncr is None:
        DoIncr = [0, 0]

    if db is not None:
        return _get_eduperson_and_roles(request, db, eduperson_unique_id, DoIncr)

    with ToarDbSession() as db_session:
        return _get_eduperson_and_roles(request, db_session, eduperson_unique_id, DoIncr)

37 

38 

def _get_eduperson_and_roles(request: Request, db: Session, eduperson_unique_id: str, DoIncr: List[int]):
    """Identify the requesting person, determine their role and update counters.

    The ``AccessToken`` request header, when present, is sent as bearer token
    to ``userinfo_endpoint``; the answer supplies the userinfo record and the
    ``eduperson_unique_id``. Without a header the ``eduperson_unique_id``
    argument is used, and when that is ``None`` as well the caller is treated
    as anonymous.

    Args:
        request: Incoming request whose headers are inspected.
        db: Database session used to look up, create and update the
            ``models.AuthUser`` row.
        eduperson_unique_id: Identifier used when no access token is sent.
        DoIncr: ``DoIncr[0]`` selects the counter (``1``: ``num_timeseries``,
            ``2``: ``num_gridded``, anything else: none) and ``DoIncr[1]`` is
            the amount added to it.

    Returns:
        A dict. If the userinfo endpoint answers 401 it only holds
        ``"userinfo"`` and ``"status_code"``. Otherwise it holds the counters,
        ``"role"``, the entries of ``request_limitations[role]``,
        ``"status_code"`` and ``"userinfo"`` (plus the identity fields when a
        new person was created).

    Raises:
        HTTPException: 406 if the person is not in the database and no access
            token was sent.
    """
    # Do not use underscores; they are not valid in header attributes!
    token = request.headers.get('AccessToken')
    has_token = token is not None

    role = 'unknown'
    http_status = 200
    # Placeholder identity; replaced by the endpoint's answer when a token is sent.
    userinfo = {'name': 'fake name', 'email': 'fakeEmail@fake.com'}
    person = {}

    if has_token:
        # NOTE(review): no timeout is passed to requests.get, and every status
        # other than 401 is parsed as JSON -- confirm both are intended.
        response = requests.get(userinfo_endpoint, headers={'Authorization': f"Bearer {token}"})
        http_status = response.status_code
        if http_status == 401:
            # Rejected token: hand back the placeholder userinfo and the status only.
            person['userinfo'] = userinfo
            person['status_code'] = http_status
            return person
        userinfo = response.json()
        eduperson_unique_id = userinfo['eduperson_unique_id']
    elif eduperson_unique_id is None:
        role = 'anonymous'
        person["num_timeseries"] = 0
        person["num_gridded"] = 0

    if role != "anonymous":
        db_person = (
            db.query(models.AuthUser)
            .filter(models.AuthUser.eduperson_unique_id == eduperson_unique_id)
            .first()
        )
        if db_person:
            person["num_timeseries"] = db_person.num_timeseries
            person["num_gridded"] = db_person.num_gridded
        elif has_token:
            # First contact of an authorized person: create the database row.
            person["eduperson_unique_id"] = userinfo['eduperson_unique_id']
            person["email"] = userinfo['email']
            person["username"] = userinfo['name']
            person["num_timeseries"] = 0
            person["num_gridded"] = 0
            db_person = models.AuthUser(**person)
            db.add(db_person)
            db.commit()
            db.refresh(db_person)
        else:
            raise HTTPException(406, f"Person {eduperson_unique_id} not known")

    # A known identifier counts like an authorization when deriving the role.
    if eduperson_unique_id is not None:
        has_token = True
    role = get_user_role(userinfo, has_token)
    person["role"] = role
    person.update(request_limitations[role])

    # Counters in the returned dict ...
    if DoIncr[0] == 1:
        person["num_timeseries"] += DoIncr[1]
    elif DoIncr[0] == 2:
        person["num_gridded"] += DoIncr[1]

    # ... and, for everybody but anonymous callers, in the database row.
    if role != "anonymous":
        if DoIncr[0] == 1:
            db_person.num_timeseries += DoIncr[1]
        elif DoIncr[0] == 2:
            db_person.num_gridded += DoIncr[1]
        db.add(db_person)
        db.commit()

    person['status_code'] = http_status
    person['userinfo'] = userinfo
    return person

103 

104 

def count_year_intervals(daterange):
    """Count how many year intervals are needed to cover a date range.

    Args:
        daterange: Two ISO-format date strings separated by a comma
            (``"<start>,<end>"``); surrounding whitespace is ignored.

    Returns:
        The span in years (days / 365.25) divided by
        ``gridded_product_units["years"]``, rounded up to a whole number.
    """
    years_per_interval = gridded_product_units["years"]

    begin_text, finish_text = daterange.split(",")
    begin = datetime.fromisoformat(begin_text.strip())
    finish = datetime.fromisoformat(finish_text.strip())

    span_in_years = (finish - begin).days / 365.25
    # A partially filled last interval still counts as a full one.
    full_intervals, leftover = divmod(span_in_years, years_per_interval)
    return int(full_intervals) + (1 if leftover > 0 else 0)

116 

117 

118 

def modify_query_params(query_params,
                        to_remove=("fields", "daterange", "format", "statistics", "flags",
                                   "sampling", "crops", "data_quality_flags", "lat_res", "lon_res",
                                   "statistic", "metadata_scheme", "min_data_capture", "merged",
                                   "data_aggregation_mode", "quantiles", "num_samples", "seasons",
                                   "method", "datetime"),
                        to_add=("limit", "None")):
    """Build reduced query parameters that only ask for the ``id`` field.

    Args:
        query_params: Object offering ``multi_items()`` which yields
            ``(key, value)`` pairs.
        to_remove: Keys to drop. Any container supporting ``in`` works; the
            default is an immutable tuple so that it cannot be altered
            between calls.
        to_add: ``(key, value)`` pair appended unless its key is already
            present after filtering.

    Returns:
        A new ``QueryParams`` holding the remaining pairs, followed by
        ``("fields", "id")`` and, if its key was absent, ``to_add``.
    """
    kept = [(key, value) for key, value in query_params.multi_items() if key not in to_remove]
    kept.append(("fields", "id"))
    if all(key != to_add[0] for key, _ in kept):
        kept.append(to_add)
    return QueryParams(kept)

130 

131 

def pop_user_role(query_params, to_remove=("user_role",)):
    """Return a copy of the query parameters without the ``user_role`` key.

    Args:
        query_params: Object offering ``multi_items()`` which yields
            ``(key, value)`` pairs.
        to_remove: Keys to drop. Any container supporting ``in`` works; the
            default is an immutable tuple so that it cannot be altered
            between calls.

    Returns:
        A new ``QueryParams`` with all pairs whose key is not in
        ``to_remove``.
    """
    return QueryParams(
        [(key, value) for key, value in query_params.multi_items() if key not in to_remove]
    )

136 

137 

def count_timeseries_intervals(query_params, db):
    """Count the timeseries matching a query and the intervals they occupy.

    The query is reduced with ``modify_query_params`` and run through
    ``search_all``.

    Args:
        query_params: Query parameters of the request.
        db: Database session handed to ``search_all``.

    Returns:
        Tuple ``(total_timeseries, intervals, min_id, max_id)``:
        the number of hits, that number divided by
        ``gridded_product_units["timeseries"]`` rounded up, and the smallest
        and largest ``id`` among the hits (both ``0`` when there are none).

    Raises:
        HTTPException: 406 if ``search_all`` does not return a list.
    """
    # Imported inside the function, not at module level
    # (presumably to avoid a circular import -- TODO confirm).
    from toardb.timeseries.crud import search_all

    interval_length = gridded_product_units["timeseries"]
    filtered = modify_query_params(query_params)
    ret = search_all(db, path_params='', query_params=filtered)
    if not isinstance(ret, list):
        raise HTTPException(406, f"Got unknown filters -- some are not known: {filtered}")

    ids = [row['id'] for row in ret]
    min_id = min(ids, default=0)
    max_id = max(ids, default=0)

    total_timeseries = len(ret)
    intervals = int(total_timeseries // interval_length)
    # A partially filled last interval still counts as a full one.
    if total_timeseries % interval_length > 0:
        intervals += 1
    return total_timeseries, intervals, min_id, max_id

160 

161 

def determine_increments(request: Request, db, user_role=None):
    """Work out by how much a request increases the usage counters.

    Args:
        request: Incoming request; its query parameters (minus ``user_role``)
            drive the computation.
        db: Database session handed to ``count_timeseries_intervals``.
        user_role: Role of the caller; only ``"anonymous"`` is restricted to
            the timeseries ID range of ``request_limitations["anonymous"]``.

    Returns:
        Dict with ``num_timeseries``, ``year_faktor``, ``timeseries_faktor``,
        ``num_gridded`` (product of both factors), ``in_allowed_tsrange``,
        ``access`` (``200``, or ``401`` for an anonymous caller outside the
        allowed ID range) and ``message``.
    """
    qps = pop_user_role(request.query_params)

    daterange = qps.get("daterange")
    if daterange is None:
        # NOTE(review): the fallback range has a fixed end date -- confirm it
        # should not follow the current date.
        daterange = "1970-01-01,2025-12-31"

    # Named so that it does not shadow the imported ``datetime`` class.
    # With a "datetime" parameter the year factor is fixed to 1.
    datetime_param = qps.get("datetime")
    if datetime_param is None:
        year_faktor = count_year_intervals(daterange)
    else:
        year_faktor = 1

    num_timeseries, timeseries_faktor, min_id, max_id = count_timeseries_intervals(qps, db)

    anonymous_limits = request_limitations["anonymous"]
    in_allowed_tsrange = not (
        min_id < anonymous_limits["min_tsid"] or max_id > anonymous_limits["max_tsid"]
    )

    access = 200
    message = "Authorized access"
    if user_role == "anonymous" and not in_allowed_tsrange:
        access = 401
        message = ("Request contains timeseries whose IDs are not in the range of "
                   f"{request_limitations['anonymous']['min_tsid']} to "
                   f"{request_limitations['anonymous']['max_tsid']}.")

    return {
        "num_timeseries": num_timeseries,
        "year_faktor": year_faktor,
        "timeseries_faktor": timeseries_faktor,
        "num_gridded": year_faktor * timeseries_faktor,
        "in_allowed_tsrange": in_allowed_tsrange,
        "access": access,
        "message": message,
    }