Create a model for storing session tokens.
This commit is contained in:
@@ -40,9 +40,11 @@ True
|
||||
True
|
||||
"""
|
||||
|
||||
import datetime
|
||||
import logging
|
||||
import time
|
||||
import redis
|
||||
import uuid
|
||||
|
||||
from nova import datastore
|
||||
from nova import exception
|
||||
@@ -228,6 +230,78 @@ class Daemon(datastore.BasicModel):
|
||||
for x in cls.associated_to("host", hostname):
|
||||
yield x
|
||||
|
||||
class SessionToken(datastore.BasicModel):
|
||||
"""This is a short-lived auth token that is passed through web requests"""
|
||||
|
||||
def __init__(self, session_token):
|
||||
self.token = session_token
|
||||
self.default_ttl = FLAGS.auth_token_ttl
|
||||
super(SessionToken, self).__init__()
|
||||
|
||||
@property
|
||||
def identifier(self):
|
||||
return self.token
|
||||
|
||||
def default_state(self):
|
||||
now = datetime.datetime.utcnow()
|
||||
diff = datetime.timedelta(seconds=self.default_ttl)
|
||||
expires = now + diff
|
||||
return {'user': None, 'session_type': None, 'token': self.token,
|
||||
'expiry': expires.strftime(utils.TIME_FORMAT)}
|
||||
|
||||
def save(self):
|
||||
"""Call into superclass to save object, then save associations"""
|
||||
if not self['user']:
|
||||
raise exception.Invalid("SessionToken requires a User association")
|
||||
success = super(SessionToken, self).save()
|
||||
if success:
|
||||
self.associate_with("user", self['user'])
|
||||
return True
|
||||
|
||||
@classmethod
|
||||
def lookup(cls, key):
|
||||
token = super(SessionToken, cls).lookup(key)
|
||||
if token:
|
||||
expires_at = utils.parse_isotime(token['expiry'])
|
||||
if datetime.datetime.utcnow() >= expires_at:
|
||||
token.destroy()
|
||||
return None
|
||||
return token
|
||||
|
||||
@classmethod
|
||||
def generate(cls, userid, session_type=None):
|
||||
"""make a new token for the given user"""
|
||||
token = str(uuid.uuid4())
|
||||
while cls.lookup(token):
|
||||
token = str(uuid.uuid4())
|
||||
instance = cls(token)
|
||||
instance['user'] = userid
|
||||
instance['session_type'] = session_type
|
||||
instance.save()
|
||||
return instance
|
||||
|
||||
def update_expiry(self, **kwargs):
|
||||
"""updates the expirty attribute, but doesn't save"""
|
||||
if not kwargs:
|
||||
kwargs['seconds'] = self.default_ttl
|
||||
time = datetime.datetime.utcnow()
|
||||
diff = datetime.timedelta(**kwargs)
|
||||
expires = time + diff
|
||||
self['expiry'] = expires.strftime(utils.TIME_FORMAT)
|
||||
|
||||
def is_expired(self):
|
||||
now = datetime.datetime.utcnow()
|
||||
expires = utils.parse_isotime(self['expiry'])
|
||||
return expires <= now
|
||||
|
||||
def ttl(self):
|
||||
"""number of seconds remaining before expiration"""
|
||||
now = datetime.datetime.utcnow()
|
||||
expires = utils.parse_isotime(self['expiry'])
|
||||
delta = expires - now
|
||||
return (delta.seconds + (delta.days * 24 * 3600))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import doctest
|
||||
doctest.testmod()
|
||||
|
@@ -47,6 +47,9 @@ class NotAuthorized(Error):
|
||||
class NotEmpty(Error):
|
||||
pass
|
||||
|
||||
class Invalid(Error):
|
||||
pass
|
||||
|
||||
def wrap_exception(f):
|
||||
def _wrap(*args, **kw):
|
||||
try:
|
||||
|
@@ -76,6 +76,8 @@ DEFINE_string('vpn_key_suffix',
|
||||
'-key',
|
||||
'Suffix to add to project name for vpn key')
|
||||
|
||||
DEFINE_integer('auth_token_ttl', 3600, 'Seconds for auth tokens to linger')
|
||||
|
||||
# UNUSED
|
||||
DEFINE_string('node_availability_zone',
|
||||
'nova',
|
||||
|
@@ -16,6 +16,7 @@
|
||||
# License for the specific language governing permissions and limitations
|
||||
# under the License.
|
||||
|
||||
from datetime import datetime, timedelta
|
||||
import logging
|
||||
import time
|
||||
from twisted.internet import defer
|
||||
@@ -64,6 +65,12 @@ class ModelTestCase(test.TrialTestCase):
|
||||
daemon.save()
|
||||
return daemon
|
||||
|
||||
def create_session_token(self):
|
||||
session_token = model.SessionToken('tk12341234')
|
||||
session_token['user'] = 'testuser'
|
||||
session_token.save()
|
||||
return session_token
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def test_create_instance(self):
|
||||
"""store with create_instace, then test that a load finds it"""
|
||||
@@ -202,3 +209,91 @@ class ModelTestCase(test.TrialTestCase):
|
||||
if x.identifier == 'testhost:nova-testdaemon':
|
||||
found = True
|
||||
self.assertTrue(found)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def test_create_session_token(self):
|
||||
"""create"""
|
||||
d = yield self.create_session_token()
|
||||
d = model.SessionToken(d.token)
|
||||
self.assertFalse(d.is_new_record())
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def test_delete_session_token(self):
|
||||
"""create, then destroy, then make sure loads a new record"""
|
||||
instance = yield self.create_session_token()
|
||||
yield instance.destroy()
|
||||
newinst = yield model.SessionToken(instance.token)
|
||||
self.assertTrue(newinst.is_new_record())
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def test_session_token_added_to_set(self):
|
||||
"""create, then check that it is included in list"""
|
||||
instance = yield self.create_session_token()
|
||||
found = False
|
||||
for x in model.SessionToken.all():
|
||||
if x.identifier == instance.token:
|
||||
found = True
|
||||
self.assert_(found)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def test_session_token_associates_user(self):
|
||||
"""create, then check that it is listed for the user"""
|
||||
instance = yield self.create_session_token()
|
||||
found = False
|
||||
for x in model.SessionToken.associated_to('user', 'testuser'):
|
||||
if x.identifier == instance.identifier:
|
||||
found = True
|
||||
self.assertTrue(found)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def test_session_token_generation(self):
|
||||
instance = yield model.SessionToken.generate('username', 'TokenType')
|
||||
self.assertFalse(instance.is_new_record())
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def test_find_generated_session_token(self):
|
||||
instance = yield model.SessionToken.generate('username', 'TokenType')
|
||||
found = yield model.SessionToken.lookup(instance.identifier)
|
||||
self.assert_(found)
|
||||
|
||||
def test_update_session_token_expiry(self):
|
||||
instance = model.SessionToken('tk12341234')
|
||||
oldtime = datetime.utcnow()
|
||||
instance['expiry'] = oldtime.strftime(utils.TIME_FORMAT)
|
||||
instance.update_expiry()
|
||||
expiry = utils.parse_isotime(instance['expiry'])
|
||||
self.assert_(expiry > datetime.utcnow())
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def test_session_token_lookup_when_expired(self):
|
||||
instance = yield model.SessionToken.generate("testuser")
|
||||
instance['expiry'] = datetime.utcnow().strftime(utils.TIME_FORMAT)
|
||||
instance.save()
|
||||
inst = model.SessionToken.lookup(instance.identifier)
|
||||
self.assertFalse(inst)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def test_session_token_lookup_when_not_expired(self):
|
||||
instance = yield model.SessionToken.generate("testuser")
|
||||
inst = model.SessionToken.lookup(instance.identifier)
|
||||
self.assert_(inst)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def test_session_token_is_expired_when_expired(self):
|
||||
instance = yield model.SessionToken.generate("testuser")
|
||||
instance['expiry'] = datetime.utcnow().strftime(utils.TIME_FORMAT)
|
||||
self.assert_(instance.is_expired())
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def test_session_token_is_expired_when_not_expired(self):
|
||||
instance = yield model.SessionToken.generate("testuser")
|
||||
self.assertFalse(instance.is_expired())
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def test_session_token_ttl(self):
|
||||
instance = yield model.SessionToken.generate("testuser")
|
||||
now = datetime.utcnow()
|
||||
delta = timedelta(hours=1)
|
||||
instance['expiry'] = (now + delta).strftime(utils.TIME_FORMAT)
|
||||
# give 5 seconds of fuzziness
|
||||
self.assert_(abs(instance.ttl() - FLAGS.auth_token_ttl) < 5)
|
||||
|
@@ -20,7 +20,7 @@
|
||||
System-level utilities and helper functions.
|
||||
"""
|
||||
|
||||
from datetime import datetime
|
||||
from datetime import datetime, timedelta
|
||||
import inspect
|
||||
import logging
|
||||
import os
|
||||
@@ -32,7 +32,7 @@ import sys
|
||||
from nova import flags
|
||||
|
||||
FLAGS = flags.FLAGS
|
||||
|
||||
TIME_FORMAT = "%Y-%m-%dT%H:%M:%SZ"
|
||||
|
||||
def fetchfile(url, target):
|
||||
logging.debug("Fetching %s" % url)
|
||||
@@ -118,4 +118,7 @@ def get_my_ip():
|
||||
def isotime(at=None):
|
||||
if not at:
|
||||
at = datetime.utcnow()
|
||||
return at.strftime("%Y-%m-%dT%H:%M:%SZ")
|
||||
return at.strftime(TIME_FORMAT)
|
||||
|
||||
def parse_isotime(timestr):
|
||||
return datetime.strptime(timestr, TIME_FORMAT)
|
||||
|
Reference in New Issue
Block a user