diff --git a/msal/application.py b/msal/application.py index 260d80e..cb78cf2 100644 --- a/msal/application.py +++ b/msal/application.py @@ -1,3 +1,4 @@ +from __future__ import annotations import functools import json import time @@ -238,6 +239,10 @@ class ClientApplication(object): "You can enable broker by following these instructions. " "https://msal-python.readthedocs.io/en/latest/#publicclientapplication") + _TOKEN_CACHE_DATA: dict[str, str] = { # field_in_data: field_in_cache + "key_id": "key_id", # Some token types (SSH-certs, POP) are bound to a key + } + def __init__( self, client_id, client_credential=None, authority=None, validate_authority=True, @@ -651,6 +656,7 @@ def __init__( self._decide_broker(allow_broker, enable_pii_log) self.token_cache = token_cache or TokenCache() + self.token_cache._set(data_to_at=self._TOKEN_CACHE_DATA) self._region_configured = azure_region self._region_detected = None self.client, self._regional_client = self._build_client( @@ -1528,9 +1534,10 @@ def _acquire_token_silent_from_cache_and_possibly_refresh_it( "realm": authority.tenant, "home_account_id": (account or {}).get("home_account_id"), } - key_id = kwargs.get("data", {}).get("key_id") - if key_id: # Some token types (SSH-certs, POP) are bound to a key - query["key_id"] = key_id + for field_in_data, field_in_cache in self._TOKEN_CACHE_DATA.items(): + value = kwargs.get("data", {}).get(field_in_data) + if value: + query[field_in_cache] = value now = time.time() refresh_reason = msal.telemetry.AT_ABSENT for entry in self.token_cache.search( # A generator allows us to diff --git a/msal/token_cache.py b/msal/token_cache.py index 66be5c9..c16a7a5 100644 --- a/msal/token_cache.py +++ b/msal/token_cache.py @@ -1,6 +1,8 @@ -import json +from __future__ import annotations +import json import threading import time +from typing import Optional # Needed in Python 3.7 & 3.8 import logging import warnings @@ -39,6 +41,25 @@ class AuthorityType: ADFS = "ADFS" MSSTS = "MSSTS" # MSSTS means AAD v2 for both AAD & MSA + _data_to_at: dict[str, str] = { # field_in_data: field_in_cache + # Store extra data which we explicitly allow, + # so that we won't accidentally store a user's password etc. + # It can be used to store for example key_id used in SSH-cert or POP + } + _response_to_at: dict[str, str] = { # field_in_response: field_in_cache + } + + def _set( + self, + *, + data_to_at: Optional[dict[str, str]] = None, + response_to_at: Optional[dict[str, str]] = None, + ) -> None: + # This helper should probably be better in __init__(), + # but there is no easy way for MSAL EX to pick up a kwargs + self._data_to_at = data_to_at or {} + self._response_to_at = response_to_at or {} + def __init__(self): self._lock = threading.RLock() self._cache = {} @@ -267,11 +288,14 @@ def __add(self, event, now=None): "expires_on": str(now + expires_in), # Same here "extended_expires_on": str(now + ext_expires_in) # Same here } - at.update({k: data[k] for k in data if k in { - # Also store extra data which we explicitly allow - # So that we won't accidentally store a user's password etc. - "key_id", # It happens in SSH-cert or POP scenario - }}) + for field_in_resp, field_in_cache in self._response_to_at.items(): + value = response.get(field_in_resp) + if value: + at[field_in_cache] = value + for field_in_data, field_in_cache in self._data_to_at.items(): + value = data.get(field_in_data) + if value: + at[field_in_cache] = value if "refresh_in" in response: refresh_in = response["refresh_in"] # It is an integer at["refresh_on"] = str(now + refresh_in) # Schema wants a string diff --git a/tests/test_token_cache.py b/tests/test_token_cache.py index 494d6da..499ae8b 100644 --- a/tests/test_token_cache.py +++ b/tests/test_token_cache.py @@ -218,6 +218,7 @@ def assertFoundAccessToken(self, *, scopes, query, data=None, now=None): def _test_data_should_be_saved_and_searchable_in_access_token(self, data): scopes = ["s2", "s1", "s3"] # Not in particular order now = 1000 + self.cache._set(data_to_at={"key_id": "key_id"}) self.cache.add({ "data": data, "client_id": "my_client_id",