Coverage for toardb/test_base.py: 100%

62 statements  

« prev     ^ index     » next       coverage.py v7.11.0, created at 2025-11-03 20:32 +0000

1# SPDX-FileCopyrightText: 2021 Forschungszentrum Jülich GmbH 

2# SPDX-License-Identifier: MIT 

3 

4# taken from: https://github.com/tiangolo/fastapi/issues/831 

5import pytest 

6from starlette.testclient import TestClient 

7from sqlalchemy import create_engine 

8from sqlalchemy.engine import Engine 

9from sqlalchemy.orm import sessionmaker 

10from sqlalchemy_utils import database_exists, create_database, drop_database 

11from fastapi import Request 

12 

13from toardb.base import Base 

14from toardb.toardb import app 

15from toardb.auth_user.models import AuthUser 

16from toardb.utils.database import DATABASE_URL, get_db, get_engine 

17from toardb.utils.utils import ( 

18 get_admin_access_rights, 

19 get_station_md_change_access_rights, 

20 get_timeseries_md_change_access_rights, 

21 get_data_change_access_rights, 

22 get_data_download_access_rights, 

23 get_map_data_download_access_rights 

24) 

25 

26 

url = "postgresql://postgres:postgres@postgres:5432/postgres"
# Engine shared by all test helpers in this module: pooled connections are
# pinged before reuse; up to 32 pooled plus 128 overflow connections.
_db_conn = create_engine(
    url,
    pool_size=32,
    max_overflow=128,
    pool_pre_ping=True,
)
# Session factory for the test database (no autoflush, no autocommit).
sess = sessionmaker(autocommit=False, autoflush=False, bind=_db_conn)

30 

def get_test_engine() -> Engine:
    """Return the module's shared test-database engine (replaces ``get_engine``)."""
    engine = _db_conn
    assert engine is not None
    return engine

34 

35 

def get_test_db():
    """Yield a session on the test database and close it when the caller is done.

    Installed in place of the app's ``get_db`` dependency.
    """
    session = sess()
    try:
        yield session
    finally:
        # runs when the consumer finishes or closes the generator
        session.close()

42 

43 

async def override_dependency(request: Request):
    """Fake the AAI access-rights check for tests.

    Installed in place of all ``get_*_access_rights`` dependencies of the app.
    The caller is identified by the ``email`` request header, which is looked
    up in the ``AuthUser`` table of the test database.

    Returns:
        A dict with the keys ``status_code``, ``user_name``, ``user_email``,
        ``auth_user_id``, ``role`` and ``lfromdashboard``:

        - known user: status 200, a fixed test user name and the user's ``id``;
        - unknown user: status 401 and ``auth_user_id`` -1.
    """
    email = request.headers.get("email")
    # Open and close the session explicitly. Taking it from get_test_db() with
    # next() leaves the generator's cleanup to garbage collection, so the
    # connection used by the query below was never released deterministically.
    db = sess()
    try:
        db_user = db.query(AuthUser).filter(AuthUser.email == email).first()
        if db_user:
            # status_code will be taken from the AAI (here: faked)
            status_code = 200
            user_name = "Sabine Schröder"
            auth_user_id = db_user.id  # read while the session is still open
        else:
            # The user needs to be added to the database: users may already
            # have credentials in the AAI, but they also need a permanent
            # auth_user_id related to the TOAR database.
            status_code = 401
            user_name = "Something from AAI"
            auth_user_id = -1
    finally:
        db.close()
    return {
        "status_code": status_code,
        "user_name": user_name,
        "user_email": email,
        "auth_user_id": auth_user_id,
        "role": "registered",
        "lfromdashboard": True
    }

74 

75 

# Route every access-rights dependency of the app through the fake AAI check
# (same keys, same insertion order as assigning them one by one).
app.dependency_overrides.update(
    dict.fromkeys(
        (
            get_admin_access_rights,
            get_station_md_change_access_rights,
            get_timeseries_md_change_access_rights,
            get_data_change_access_rights,
            get_data_download_access_rights,
            get_map_data_download_access_rights,
        ),
        override_dependency,
    )
)

82 

83 

@pytest.fixture(scope="session", autouse=True)
def create_test_database():
    """
    Prepare the test database once per test session.

    Creates the tables of ``Base.metadata`` on the test engine and points the
    app's ``get_db`` and ``get_engine`` dependencies at the test database.

    The database itself is neither created nor dropped here, and no extensions
    are installed. ``create_all`` fails with an undefined 'Geometry' type
    unless the PostGIS extension (and toar_controlled_vocabulary) is already
    present. The database is therefore assumed to be prepared beforehand
    (database created, ``CREATE EXTENSION postgis`` /
    ``toar_controlled_vocabulary``, and presumably ``TIMEZONE='UTC'``).
    """
    Base.metadata.create_all(_db_conn)  # Create the tables.
    app.dependency_overrides[get_db] = get_test_db  # Mock the Database Dependency
    app.dependency_overrides[get_engine] = get_test_engine  # Mock the Engine Dependency
    yield  # Run the tests.

110 

111 

@pytest.fixture
def test_db_session():
    """Yield an SQLAlchemy session and wipe the test data after the test.

    Teardown, in order:

    1. Close the test's session.
    2. In a single transaction:
       - delete all rows from every ``Base.metadata`` table except the
         ``*_vocabulary`` tables;
       - delete ``staging.data``;
       - reset the default of ``timeseries_changelog.datetime`` to ``now()``.
    """
    from sqlalchemy import text

    session = sess()  # same configuration as the module-level factory
    yield session
    # Close first. This returns the session's connection to the pool and rolls
    # back any uncommitted work whose row locks would otherwise block the
    # DELETEs below.
    session.close()
    # engine.begin() commits on success and always releases its connection,
    # unlike the previous unclosed raw_connection()/cursor.
    with _db_conn.begin() as conn:
        # reversed dependency order: dependent tables are emptied first
        for tbl in reversed(Base.metadata.sorted_tables):
            # otherwise all tables from "toar_controlled_vocabulary" would get lost!
            if not tbl.name.endswith("_vocabulary"):
                conn.execute(tbl.delete())
        conn.execute(text("DELETE FROM staging.data;"))
        conn.execute(text("ALTER TABLE timeseries_changelog ALTER COLUMN datetime SET DEFAULT now();"))

130 

131 

@pytest.fixture()
def client():
    """
    Yield a starlette ``TestClient`` for the app.

    The client is entered as a context manager, so the app's lifespan
    (startup/shutdown) handlers run around the test. This fixture performs no
    database rollback or cleanup itself. Data is wiped only by the teardown of
    the ``test_db_session`` fixture, so tests that need a clean database must
    request that fixture as well.
    """
    with TestClient(app) as client:
        yield client