|
|
@@ -111,9 +111,11 @@ class JwtAuthentication(authentication.BaseAuthentication):
|
|
|
def _handle_public_key(self, access_token):
|
|
|
token_metadata = jwt.get_unverified_header(access_token)
|
|
|
kid = token_metadata.get('kid')
|
|
|
+ key_server_url = self._handle_jku_ha()
|
|
|
|
|
|
- if AUTH.JWT.KEY_SERVER_URL.get():
|
|
|
- response = requests.get(AUTH.JWT.KEY_SERVER_URL.get(), verify=False)
|
|
|
+ if key_server_url:
|
|
|
+ LOG.debug('Fetching JWKS from URL: %s' % key_server_url)
|
|
|
+ response = requests.get(key_server_url, verify=False)
|
|
|
jwk = json.loads(response.content)
|
|
|
|
|
|
if jwk.get('keys'):
|
|
|
@@ -125,6 +127,29 @@ class JwtAuthentication(authentication.BaseAuthentication):
|
|
|
|
|
|
return public_key_pem
|
|
|
|
|
|
+ def _handle_jku_ha(self):
|
|
|
+ res = None
|
|
|
+
|
|
|
+ key_server_urls = AUTH.JWT.KEY_SERVER_URL.get()
|
|
|
+ if not key_server_urls:
|
|
|
+ return None
|
|
|
+
|
|
|
+ if "," in key_server_urls:
|
|
|
+ key_server_urls_list = key_server_urls.split(',')
|
|
|
+
|
|
|
+ for jku in key_server_urls_list:
|
|
|
+ try:
|
|
|
+ res = requests.get(jku.rstrip('/'), verify=False)
|
|
|
+ except Exception as e:
|
|
|
+ if 'Failed to establish a new connection' in str(e):
|
|
|
+ LOG.warning('JKU %s is not available.' % jku)
|
|
|
+
|
|
|
+ # Check response for None and if response code is successful (200) or authentication needed (401), use that host URL.
|
|
|
+ if (res is not None) and (res.status_code in (200, 401)):
|
|
|
+ return jku
|
|
|
+ else:
|
|
|
+ # For non-HA, it's normal url string.
|
|
|
+ return key_server_urls
|
|
|
|
|
|
class DummyCustomAuthentication(authentication.BaseAuthentication):
|
|
|
"""
|