Newer
Older
import json
import uuid
from cryptography.fernet import Fernet
from starlette.datastructures import MutableHeaders
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
from starlette.requests import HTTPConnection
from starlette.types import ASGIApp, Message, Receive, Scope, Send
class SessionMiddleware:
    """ASGI middleware that keeps only a session id in a cookie.

    Based on Starlette SessionMiddleware:
    https://github.com/encode/starlette/blob/0.13.2/starlette/middleware/sessions.py
    Updated to store the session id in the cookie and keep the session data
    elsewhere, via the supplied callbacks.

    Usage:
        app.add_middleware(SessionMiddleware, **params)

    Parameters
    ----------
    app: the ASGI application to wrap.
    delete_session_callback(session_id): deletes the stored session data.
        Its result is awaited, so it must return an awaitable
        (e.g. be an ``async def``).
    get_session_callback(session_id): returns the stored session data as a
        string (JSON, Fernet-encrypted when ``encryption`` is set). Its result
        is awaited. Any exception it raises, or an unusable return value, is
        treated as "no session".
    save_session_callback(session_id, data): stores ``data`` (a JSON string,
        Fernet-encrypted when ``encryption`` is set) for ``session_id``.
        Its result is awaited.
    encryption: if provided, session data is encrypted with it before storage
        and decrypted after loading.
    session_cookie: name of the session cookie.
    path: path attribute of the session cookie.
    max_age: how long session cookies last, in seconds.
    same_site: cookie SameSite policy.
    https_only: whether to add the Secure flag (cookie sent over HTTPS only).
    """

    def __init__(
        self,
        app: ASGIApp,
        # NOTE: the callbacks below are awaited, so despite these annotations
        # they must return awaitables.
        delete_session_callback: Callable[[str], None],
        get_session_callback: Callable[[str], str],
        save_session_callback: Callable[[str, str], None],
        encryption: Fernet = None,
        session_cookie: str = "session",
        path: str = "/",
        max_age: int = 14 * 24 * 60 * 60,  # 14 days, in seconds
        same_site: str = "lax",
        https_only: bool = False,
    ) -> None:
        self.app = app
        self.encryption = encryption
        self.delete_session_callback = delete_session_callback
        self.get_session_callback = get_session_callback
        self.save_session_callback = save_session_callback
        self.session_cookie = session_cookie
        self.path = path
        self.max_age = max_age
        # Pre-built cookie attribute suffix shared by every Set-Cookie header.
        self.security_flags = "httponly; samesite=" + same_site
        if https_only:  # Secure flag can be used with HTTPS only
            self.security_flags += "; secure"

    async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
        """Load the session into ``scope["session"]`` and persist it on response."""
        if scope["type"] not in ("http", "websocket"):  # pragma: no cover
            await self.app(scope, receive, send)
            return

        connection = HTTPConnection(scope)
        # True unless a stored session was successfully loaded; decides
        # whether an emptied session must be deleted from storage.
        initial_session_was_empty = True
        session_id = None

        if self.session_cookie in connection.cookies:
            session_id = connection.cookies[self.session_cookie]
            try:
                scope["session"] = await self.get_session(session_id)
                initial_session_was_empty = False
            except Exception:
                # Do not log the raw id: it is a bearer credential.
                logging.exception(
                    "Error loading session from cookie %r", self.session_cookie
                )
                # The client presented an id we could not load. Never adopt a
                # client-chosen id for new data (session fixation); a fresh
                # server-generated id is issued if the app stores anything.
                session_id = None
                scope["session"] = {}
        else:
            scope["session"] = {}

        async def send_wrapper(message: Message) -> None:
            nonlocal session_id
            # Cookies can only be set before the response body starts.
            if message["type"] == "http.response.start":
                if scope["session"]:
                    session_id = session_id or uuid.uuid4().hex
                    # Persist session
                    await self.save_session(session_id, scope["session"])
                    self.set_cookie(message=message, value=session_id)
                elif not initial_session_was_empty:
                    # The app emptied a previously stored session: drop the
                    # stored data and expire the client's cookie.
                    await self.delete_session(session_id)
                    self.set_cookie(message=message, value="null", max_age=-1)
            await send(message)

        await self.app(scope, receive, send_wrapper)

    async def delete_session(self, session_id: str):
        """Delete the stored data for ``session_id`` via the delete callback."""
        await self.delete_session_callback(session_id)

    async def get_session(self, session_id: str) -> Dict:
        """Fetch, optionally decrypt, and JSON-decode the session data.

        Raises whatever the callback, Fernet decryption or ``json.loads``
        raise; ``__call__`` treats any such error as "no session".
        """
        data = await self.get_session_callback(session_id)
        if self.encryption:
            data = self.encryption.decrypt(data.encode("utf8"))
        return json.loads(data)

    async def save_session(self, session_id: str, data: Mapping):
        """JSON-encode, optionally encrypt, and store the session data."""
        data = json.dumps(data)
        if self.encryption:
            data = self.encryption.encrypt(data.encode("utf8")).decode("utf8")
        await self.save_session_callback(session_id, data)

    def set_cookie(
        self,
        message: Message,
        value: str,
        max_age: int = None,
    ):
        """Add the session ``Set-Cookie`` header to an ``http.response.start`` message.

        ``max_age`` overrides the configured cookie lifetime; ``None`` means
        "use ``self.max_age``". Per RFC 6265, a value <= 0 expires the cookie.
        """
        # Compare against None so an explicit 0 is honoured rather than
        # silently replaced by the default lifetime.
        if max_age is None:
            max_age = self.max_age
        headers = MutableHeaders(scope=message)
        # Presumably here so caches don't store responses carrying the
        # session cookie.
        headers.append("Cache-Control", "no-cache")
        headers.append(
            "Set-Cookie",
            f"{self.session_cookie}={value};"
            f" path={self.path};"
            f" Max-Age={max_age};"
            f" {self.security_flags}",
        )