Simplify
This commit is contained in:
parent
52e67df244
commit
3afe124758
|
@ -11,241 +11,16 @@ hub:
|
|||
storageClassName: csi-sc-cinderplugin
|
||||
extraConfig:
|
||||
oauthCode: |
|
||||
import json
|
||||
import os
|
||||
import socket
|
||||
from collections.abc import Mapping
|
||||
from functools import lru_cache
|
||||
from urllib.parse import urlencode, urlparse
|
||||
|
||||
import requests
|
||||
import tornado.options
|
||||
import yaml
|
||||
from jupyterhub.services.auth import HubAuthenticated
|
||||
from jupyterhub.utils import url_path_join
|
||||
from oauthenticator.generic import GenericOAuthenticator
|
||||
from tornado.httpclient import AsyncHTTPClient, HTTPRequest
|
||||
from tornado.httpserver import HTTPServer
|
||||
from tornado.ioloop import IOLoop
|
||||
from tornado.log import app_log
|
||||
from tornado.web import Application, HTTPError, RequestHandler, authenticated
|
||||
|
||||
|
||||
def post_auth_hook(authenticator, handler, authentication):
|
||||
user = authentication['auth_state']['oauth_user']['ocs']['data']['id']
|
||||
auth_state = authentication['auth_state']
|
||||
authenticator.user_dict[user] = auth_state
|
||||
return authentication
|
||||
|
||||
|
||||
class NextcloudOAuthenticator(GenericOAuthenticator):
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.user_dict = {}
|
||||
|
||||
def pre_spawn_start(self, user, spawner):
|
||||
super().pre_spawn_start(user, spawner)
|
||||
access_token = self.user_dict[user.name]['access_token']
|
||||
# refresh_token = self.user_dict[user.name]['refresh_token']
|
||||
spawner.environment['NEXTCLOUD_ACCESS_TOKEN'] = access_token
|
||||
|
||||
|
||||
c.JupyterHub.authenticator_class = NextcloudOAuthenticator
|
||||
c.NextcloudOAuthenticator.client_id = os.environ['NEXTCLOUD_CLIENT_ID']
|
||||
c.NextcloudOAuthenticator.client_secret = os.environ['NEXTCLOUD_CLIENT_SECRET']
|
||||
c.NextcloudOAuthenticator.login_service = 'Sunet Drive'
|
||||
c.NextcloudOAuthenticator.username_key = lambda r: r.get('ocs', {}).get(
|
||||
'data', {}).get('id')
|
||||
c.NextcloudOAuthenticator.userdata_url = 'https://' + os.environ[
|
||||
'NEXTCLOUD_HOST'] + '/ocs/v2.php/cloud/user?format=json'
|
||||
c.NextcloudOAuthenticator.authorize_url = 'https://' + os.environ[
|
||||
'NEXTCLOUD_HOST'] + '/index.php/apps/oauth2/authorize'
|
||||
c.NextcloudOAuthenticator.token_url = 'https://' + os.environ[
|
||||
'NEXTCLOUD_HOST'] + '/index.php/apps/oauth2/api/v1/token'
|
||||
c.NextcloudOAuthenticator.oauth_callback_url = 'https://' + os.environ[
|
||||
'JUPYTER_HOST'] + '/hub/oauth_callback'
|
||||
c.NextcloudOAuthenticator.refresh_pre_spawn = True
|
||||
c.NextcloudOAuthenticator.enable_auth_state = True
|
||||
c.NextcloudOAuthenticator.post_auth_hook = post_auth_hook
|
||||
|
||||
|
||||
# memoize so we only load config once
|
||||
@lru_cache()
|
||||
def _load_config():
|
||||
"""Load configuration from disk
|
||||
Memoized to only load once
|
||||
"""
|
||||
cfg = {}
|
||||
for source in ('config', 'secret'):
|
||||
path = f"/etc/jupyterhub/{source}/values.yaml"
|
||||
if os.path.exists(path):
|
||||
print(f"Loading {path}")
|
||||
with open(path) as f:
|
||||
values = yaml.safe_load(f)
|
||||
cfg = _merge_dictionaries(cfg, values)
|
||||
else:
|
||||
print(f"No config at {path}")
|
||||
return cfg
|
||||
|
||||
|
||||
def _merge_dictionaries(a, b):
|
||||
"""Merge two dictionaries recursively.
|
||||
Simplified From https://stackoverflow.com/a/7205107
|
||||
"""
|
||||
merged = a.copy()
|
||||
for key in b:
|
||||
if key in a:
|
||||
if isinstance(a[key], Mapping) and isinstance(b[key], Mapping):
|
||||
merged[key] = _merge_dictionaries(a[key], b[key])
|
||||
else:
|
||||
merged[key] = b[key]
|
||||
else:
|
||||
merged[key] = b[key]
|
||||
return merged
|
||||
|
||||
|
||||
def get_config(key, default=None):
|
||||
"""
|
||||
Find a config item of a given name & return it
|
||||
Parses everything as YAML, so lists and dicts are available too
|
||||
get_config("a.b.c") returns config['a']['b']['c']
|
||||
"""
|
||||
value = _load_config()
|
||||
# resolve path in yaml
|
||||
for level in key.split('.'):
|
||||
if not isinstance(value, dict):
|
||||
# a parent is a scalar or null,
|
||||
# can't resolve full path
|
||||
return default
|
||||
if level not in value:
|
||||
return default
|
||||
else:
|
||||
value = value[level]
|
||||
return value
|
||||
|
||||
|
||||
async def fetch_new_token(token_url, client_id, client_secret, refresh_token):
|
||||
params = {
|
||||
"grant_type": "refresh_token",
|
||||
"client_id": client_id,
|
||||
"client_secret": client_secret,
|
||||
"refresh_token": refresh_token,
|
||||
}
|
||||
body = urlencode(params)
|
||||
req = HTTPRequest(token_url, 'POST', body=body)
|
||||
app_log.error("URL: %s body: %s", token_url, body)
|
||||
|
||||
client = AsyncHTTPClient()
|
||||
resp = await client.fetch(req)
|
||||
|
||||
resp_json = json.loads(resp.body.decode('utf8', 'replace'))
|
||||
return resp_json
|
||||
|
||||
|
||||
class TokenHandler(HubAuthenticated, RequestHandler):
|
||||
|
||||
def api_request(self, method, url, **kwargs):
|
||||
"""Make an API request"""
|
||||
url = url_path_join(self.hub_auth.api_url, url)
|
||||
allow_404 = kwargs.pop('allow_404', False)
|
||||
headers = kwargs.setdefault('headers', {})
|
||||
headers.setdefault('Authorization',
|
||||
'token %s' % self.hub_auth.api_token)
|
||||
try:
|
||||
r = requests.request(method, url, **kwargs)
|
||||
except requests.ConnectionError as e:
|
||||
app_log.error("Error connecting to %s: %s", url, e)
|
||||
msg = "Failed to connect to Hub API at %r." % url
|
||||
msg += " Is the Hub accessible at this URL (from host: %s)?" % socket.gethostname(
|
||||
)
|
||||
if '127.0.0.1' in url:
|
||||
msg += " Make sure to set c.JupyterHub.hub_ip to an IP accessible to" + \
|
||||
" single-user servers if the servers are not on the same host as the Hub."
|
||||
raise HTTPError(500, msg)
|
||||
|
||||
data = None
|
||||
if r.status_code == 404 and allow_404:
|
||||
pass
|
||||
elif r.status_code == 403:
|
||||
app_log.error(
|
||||
"I don't have permission to check authorization with JupyterHub, my auth token may have expired: [%i] %s",
|
||||
r.status_code, r.reason)
|
||||
app_log.error(r.text)
|
||||
raise HTTPError(
|
||||
500,
|
||||
"Permission failure checking authorization, I may need a new token"
|
||||
)
|
||||
elif r.status_code >= 500:
|
||||
app_log.error("Upstream failure verifying auth token: [%i] %s",
|
||||
r.status_code, r.reason)
|
||||
app_log.error(r.text)
|
||||
raise HTTPError(
|
||||
502, "Failed to check authorization (upstream problem)")
|
||||
elif r.status_code >= 400:
|
||||
app_log.warning("Failed to check authorization: [%i] %s",
|
||||
r.status_code, r.reason)
|
||||
app_log.warning(r.text)
|
||||
raise HTTPError(500, "Failed to check authorization")
|
||||
else:
|
||||
data = r.json()
|
||||
|
||||
return data
|
||||
|
||||
@authenticated
|
||||
async def get(self):
|
||||
oauth_config = get_config('auth.custom.config')
|
||||
|
||||
client_id = oauth_config['client_id']
|
||||
client_secret = oauth_config['client_secret']
|
||||
token_url = oauth_config['token_url']
|
||||
user_model = self.get_current_user()
|
||||
|
||||
# Fetch current auth state
|
||||
u = self.api_request('GET', url_path_join('users', user_model['name']))
|
||||
app_log.error("User: %s", u)
|
||||
auth_state = u['auth_state']
|
||||
|
||||
new_tokens = await fetch_new_token(token_url, client_id, client_secret,
|
||||
auth_state.get('refresh_token'))
|
||||
|
||||
# update auth state in the hub
|
||||
auth_state['access_token'] = new_tokens['access_token']
|
||||
auth_state['refresh_token'] = new_tokens['refresh_token']
|
||||
self.api_request('PATCH',
|
||||
url_path_join('users', user_model['name']),
|
||||
data=json.dumps({'auth_state': auth_state}))
|
||||
|
||||
# send new token to the user
|
||||
tokens = {'access_token': auth_state.get('access_token')}
|
||||
self.set_header('content-type', 'application/json')
|
||||
self.write(json.dumps(tokens, indent=1, sort_keys=True))
|
||||
|
||||
|
||||
class PingHandler(RequestHandler):
|
||||
|
||||
def get(self):
|
||||
self.set_header('content-type', 'application/json')
|
||||
self.write(json.dumps({'ping': 1}))
|
||||
|
||||
|
||||
def main():
|
||||
tornado.options.parse_command_line()
|
||||
app = Application([
|
||||
(os.environ['JUPYTERHUB_SERVICE_PREFIX'] + 'tokens', TokenHandler),
|
||||
(os.environ['JUPYTERHUB_SERVICE_PREFIX'] + '/?', PingHandler),
|
||||
])
|
||||
|
||||
http_server = HTTPServer(app)
|
||||
url = urlparse(os.environ['JUPYTERHUB_SERVICE_URL'])
|
||||
|
||||
http_server.listen(url.port)
|
||||
|
||||
IOLoop.current().start()
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
||||
c.JupyterHub.authenticator_class = GenericOAuthenticator
|
||||
c.GenericOAuthenticator.client_id = os.environ['NEXTCLOUD_CLIENT_ID']
|
||||
c.GenericOAuthenticator.client_secret = os.environ['NEXTCLOUD_CLIENT_ID']
|
||||
c.GenericOAuthenticator.login_service = 'Sunet Drive'
|
||||
c.GenericOAuthenticator.username_key = lambda r: r.get('ocs', {}).get('data', {}).get('id')
|
||||
c.GenericOAuthenticator.userdata_url = 'https://' + os.environ['NEXTCLOUD_HOST'] + '/ocs/v2.php/cloud/user?format=json'
|
||||
c.GenericOAuthenticator.authorize_url = 'https://' + os.environ['NEXTCLOUD_HOST'] + '/index.php/apps/oauth2/authorize'
|
||||
c.GenericOAuthenticator.token_url = 'https://' + os.environ['NEXTCLOUD_HOST'] + '/index.php/apps/oauth2/api/v1/token'
|
||||
c.GenericOAuthenticator.oauth_callback_url = 'https://' + os.environ['JUPYTER_HOST'] + '/hub/oauth_callback'
|
||||
extraEnv:
|
||||
NEXTCLOUD_HOST: sunet.drive.test.sunet.se
|
||||
JUPYTER_HOST: jupyter.drive.test.sunet.se
|
||||
|
@ -259,6 +34,7 @@ hub:
|
|||
secretKeyRef:
|
||||
name: nextcloud-oauth-secrets
|
||||
key: client-secret
|
||||
|
||||
singleuser:
|
||||
image:
|
||||
name: docker.sunet.se/drive/jupyter-custom
|
||||
|
|
Loading…
Reference in a new issue