Another pep8 cleanup branch for nova/tests, should be merged after lp:~eday/nova/pep8-fixes-other. After this, the pep8 violation count is 0!

This commit is contained in:
Eric Day
2010-10-23 00:04:02 +00:00
committed by Tarmac
54 changed files with 968 additions and 806 deletions

View File

@@ -47,19 +47,23 @@ class DbDriver(object):
def get_user(self, uid): def get_user(self, uid):
"""Retrieve user by id""" """Retrieve user by id"""
return self._db_user_to_auth_user(db.user_get(context.get_admin_context(), uid)) user = db.user_get(context.get_admin_context(), uid)
return self._db_user_to_auth_user(user)
def get_user_from_access_key(self, access): def get_user_from_access_key(self, access):
"""Retrieve user by access key""" """Retrieve user by access key"""
return self._db_user_to_auth_user(db.user_get_by_access_key(context.get_admin_context(), access)) user = db.user_get_by_access_key(context.get_admin_context(), access)
return self._db_user_to_auth_user(user)
def get_project(self, pid): def get_project(self, pid):
"""Retrieve project by id""" """Retrieve project by id"""
return self._db_project_to_auth_projectuser(db.project_get(context.get_admin_context(), pid)) project = db.project_get(context.get_admin_context(), pid)
return self._db_project_to_auth_projectuser(project)
def get_users(self): def get_users(self):
"""Retrieve list of users""" """Retrieve list of users"""
return [self._db_user_to_auth_user(user) for user in db.user_get_all(context.get_admin_context())] return [self._db_user_to_auth_user(user)
for user in db.user_get_all(context.get_admin_context())]
def get_projects(self, uid=None): def get_projects(self, uid=None):
"""Retrieve list of projects""" """Retrieve list of projects"""
@@ -71,11 +75,10 @@ class DbDriver(object):
def create_user(self, name, access_key, secret_key, is_admin): def create_user(self, name, access_key, secret_key, is_admin):
"""Create a user""" """Create a user"""
values = { 'id' : name, values = {'id': name,
'access_key' : access_key, 'access_key': access_key,
'secret_key' : secret_key, 'secret_key': secret_key,
'is_admin' : is_admin 'is_admin': is_admin}
}
try: try:
user_ref = db.user_create(context.get_admin_context(), values) user_ref = db.user_create(context.get_admin_context(), values)
return self._db_user_to_auth_user(user_ref) return self._db_user_to_auth_user(user_ref)
@@ -83,18 +86,19 @@ class DbDriver(object):
raise exception.Duplicate('User %s already exists' % name) raise exception.Duplicate('User %s already exists' % name)
def _db_user_to_auth_user(self, user_ref): def _db_user_to_auth_user(self, user_ref):
return { 'id' : user_ref['id'], return {'id': user_ref['id'],
'name' : user_ref['id'], 'name': user_ref['id'],
'access' : user_ref['access_key'], 'access': user_ref['access_key'],
'secret' : user_ref['secret_key'], 'secret': user_ref['secret_key'],
'admin' : user_ref['is_admin'] } 'admin': user_ref['is_admin']}
def _db_project_to_auth_projectuser(self, project_ref): def _db_project_to_auth_projectuser(self, project_ref):
return { 'id' : project_ref['id'], member_ids = [member['id'] for member in project_ref['members']]
'name' : project_ref['name'], return {'id': project_ref['id'],
'project_manager_id' : project_ref['project_manager'], 'name': project_ref['name'],
'description' : project_ref['description'], 'project_manager_id': project_ref['project_manager'],
'member_ids' : [member['id'] for member in project_ref['members']] } 'description': project_ref['description'],
'member_ids': member_ids}
def create_project(self, name, manager_uid, def create_project(self, name, manager_uid,
description=None, member_uids=None): description=None, member_uids=None):
@@ -121,10 +125,10 @@ class DbDriver(object):
% member_uid) % member_uid)
members.add(member) members.add(member)
values = { 'id' : name, values = {'id': name,
'name' : name, 'name': name,
'project_manager' : manager['id'], 'project_manager': manager['id'],
'description': description } 'description': description}
try: try:
project = db.project_create(context.get_admin_context(), values) project = db.project_create(context.get_admin_context(), values)
@@ -244,4 +248,3 @@ class DbDriver(object):
if not project: if not project:
raise exception.NotFound('Project "%s" not found' % project_id) raise exception.NotFound('Project "%s" not found' % project_id)
return user, project return user, project

View File

@@ -35,6 +35,7 @@ flags.DEFINE_integer('redis_port', 6379,
'Port that redis is running on.') 'Port that redis is running on.')
flags.DEFINE_integer('redis_db', 0, 'Multiple DB keeps tests away') flags.DEFINE_integer('redis_db', 0, 'Multiple DB keeps tests away')
class Redis(object): class Redis(object):
def __init__(self): def __init__(self):
if hasattr(self.__class__, '_instance'): if hasattr(self.__class__, '_instance'):
@@ -51,19 +52,19 @@ class Redis(object):
SCOPE_BASE = 0 SCOPE_BASE = 0
SCOPE_ONELEVEL = 1 # not implemented SCOPE_ONELEVEL = 1 # Not implemented
SCOPE_SUBTREE = 2 SCOPE_SUBTREE = 2
MOD_ADD = 0 MOD_ADD = 0
MOD_DELETE = 1 MOD_DELETE = 1
MOD_REPLACE = 2 MOD_REPLACE = 2
class NO_SUCH_OBJECT(Exception): # pylint: disable-msg=C0103 class NO_SUCH_OBJECT(Exception): # pylint: disable-msg=C0103
"""Duplicate exception class from real LDAP module.""" """Duplicate exception class from real LDAP module."""
pass pass
class OBJECT_CLASS_VIOLATION(Exception): # pylint: disable-msg=C0103 class OBJECT_CLASS_VIOLATION(Exception): # pylint: disable-msg=C0103
"""Duplicate exception class from real LDAP module.""" """Duplicate exception class from real LDAP module."""
pass pass
@@ -251,8 +252,6 @@ class FakeLDAP(object):
return objects return objects
@property @property
def __redis_prefix(self): # pylint: disable-msg=R0201 def __redis_prefix(self): # pylint: disable-msg=R0201
"""Get the prefix to use for all redis keys.""" """Get the prefix to use for all redis keys."""
return 'ldap:' return 'ldap:'

View File

@@ -294,24 +294,26 @@ class LdapDriver(object):
def __find_dns(self, dn, query=None, scope=None): def __find_dns(self, dn, query=None, scope=None):
"""Find dns by query""" """Find dns by query"""
if scope is None: # one of the flags is 0!! if scope is None:
# One of the flags is 0!
scope = self.ldap.SCOPE_SUBTREE scope = self.ldap.SCOPE_SUBTREE
try: try:
res = self.conn.search_s(dn, scope, query) res = self.conn.search_s(dn, scope, query)
except self.ldap.NO_SUCH_OBJECT: except self.ldap.NO_SUCH_OBJECT:
return [] return []
# just return the DNs # Just return the DNs
return [dn for dn, _attributes in res] return [dn for dn, _attributes in res]
def __find_objects(self, dn, query=None, scope=None): def __find_objects(self, dn, query=None, scope=None):
"""Find objects by query""" """Find objects by query"""
if scope is None: # one of the flags is 0!! if scope is None:
# One of the flags is 0!
scope = self.ldap.SCOPE_SUBTREE scope = self.ldap.SCOPE_SUBTREE
try: try:
res = self.conn.search_s(dn, scope, query) res = self.conn.search_s(dn, scope, query)
except self.ldap.NO_SUCH_OBJECT: except self.ldap.NO_SUCH_OBJECT:
return [] return []
# just return the attributes # Just return the attributes
return [attributes for dn, attributes in res] return [attributes for dn, attributes in res]
def __find_role_dns(self, tree): def __find_role_dns(self, tree):
@@ -480,6 +482,6 @@ class LdapDriver(object):
class FakeLdapDriver(LdapDriver): class FakeLdapDriver(LdapDriver):
"""Fake Ldap Auth driver""" """Fake Ldap Auth driver"""
def __init__(self): # pylint: disable-msg=W0231 def __init__(self): # pylint: disable-msg=W0231
__import__('nova.auth.fakeldap') __import__('nova.auth.fakeldap')
self.ldap = sys.modules['nova.auth.fakeldap'] self.ldap = sys.modules['nova.auth.fakeldap']

View File

@@ -23,7 +23,7 @@ Nova authentication management
import logging import logging
import os import os
import shutil import shutil
import string # pylint: disable-msg=W0402 import string # pylint: disable-msg=W0402
import tempfile import tempfile
import uuid import uuid
import zipfile import zipfile

View File

@@ -49,7 +49,7 @@ class CloudPipe(object):
self.manager = manager.AuthManager() self.manager = manager.AuthManager()
def launch_vpn_instance(self, project_id): def launch_vpn_instance(self, project_id):
logging.debug( "Launching VPN for %s" % (project_id)) logging.debug("Launching VPN for %s" % (project_id))
project = self.manager.get_project(project_id) project = self.manager.get_project(project_id)
# Make a payload.zip # Make a payload.zip
tmpfolder = tempfile.mkdtemp() tmpfolder = tempfile.mkdtemp()
@@ -57,16 +57,18 @@ class CloudPipe(object):
zippath = os.path.join(tmpfolder, filename) zippath = os.path.join(tmpfolder, filename)
z = zipfile.ZipFile(zippath, "w", zipfile.ZIP_DEFLATED) z = zipfile.ZipFile(zippath, "w", zipfile.ZIP_DEFLATED)
z.write(FLAGS.boot_script_template,'autorun.sh') z.write(FLAGS.boot_script_template, 'autorun.sh')
z.close() z.close()
key_name = self.setup_key_pair(project.project_manager_id, project_id) key_name = self.setup_key_pair(project.project_manager_id, project_id)
zippy = open(zippath, "r") zippy = open(zippath, "r")
context = context.RequestContext(user=project.project_manager, project=project) context = context.RequestContext(user=project.project_manager,
project=project)
reservation = self.controller.run_instances(context, reservation = self.controller.run_instances(context,
# run instances expects encoded userdata, it is decoded in the get_metadata_call # Run instances expects encoded userdata, it is decoded in the
# autorun.sh also decodes the zip file, hence the double encoding # get_metadata_call. autorun.sh also decodes the zip file, hence
# the double encoding.
user_data=zippy.read().encode("base64").encode("base64"), user_data=zippy.read().encode("base64").encode("base64"),
max_count=1, max_count=1,
min_count=1, min_count=1,
@@ -79,12 +81,14 @@ class CloudPipe(object):
def setup_key_pair(self, user_id, project_id): def setup_key_pair(self, user_id, project_id):
key_name = '%s%s' % (project_id, FLAGS.vpn_key_suffix) key_name = '%s%s' % (project_id, FLAGS.vpn_key_suffix)
try: try:
private_key, fingerprint = self.manager.generate_key_pair(user_id, key_name) private_key, fingerprint = self.manager.generate_key_pair(user_id,
key_name)
try: try:
key_dir = os.path.join(FLAGS.keys_path, user_id) key_dir = os.path.join(FLAGS.keys_path, user_id)
if not os.path.exists(key_dir): if not os.path.exists(key_dir):
os.makedirs(key_dir) os.makedirs(key_dir)
with open(os.path.join(key_dir, '%s.pem' % key_name),'w') as f: file_name = os.path.join(key_dir, '%s.pem' % key_name)
with open(file_name, 'w') as f:
f.write(private_key) f.write(private_key)
except: except:
pass pass
@@ -95,9 +99,13 @@ class CloudPipe(object):
# def setup_secgroups(self, username): # def setup_secgroups(self, username):
# conn = self.euca.connection_for(username) # conn = self.euca.connection_for(username)
# try: # try:
# secgroup = conn.create_security_group("vpn-secgroup", "vpn-secgroup") # secgroup = conn.create_security_group("vpn-secgroup",
# secgroup.authorize(ip_protocol = "udp", from_port = "1194", to_port = "1194", cidr_ip = "0.0.0.0/0") # "vpn-secgroup")
# secgroup.authorize(ip_protocol = "tcp", from_port = "80", to_port = "80", cidr_ip = "0.0.0.0/0") # secgroup.authorize(ip_protocol = "udp", from_port = "1194",
# secgroup.authorize(ip_protocol = "tcp", from_port = "22", to_port = "22", cidr_ip = "0.0.0.0/0") # to_port = "1194", cidr_ip = "0.0.0.0/0")
# secgroup.authorize(ip_protocol = "tcp", from_port = "80",
# to_port = "80", cidr_ip = "0.0.0.0/0")
# secgroup.authorize(ip_protocol = "tcp", from_port = "22",
# to_port = "22", cidr_ip = "0.0.0.0/0")
# except: # except:
# pass # pass

View File

@@ -74,12 +74,12 @@ def partition(infile, outfile, local_bytes=0, resize=True,
" by sector size: %d / %d", local_bytes, sector_size) " by sector size: %d / %d", local_bytes, sector_size)
local_sectors = local_bytes / sector_size local_sectors = local_bytes / sector_size
mbr_last = 62 # a mbr_last = 62 # a
primary_first = mbr_last + 1 # b primary_first = mbr_last + 1 # b
primary_last = primary_first + primary_sectors - 1 # c primary_last = primary_first + primary_sectors - 1 # c
local_first = primary_last + 1 # d local_first = primary_last + 1 # d
local_last = local_first + local_sectors - 1 # e local_last = local_first + local_sectors - 1 # e
last_sector = local_last # e last_sector = local_last # e
# create an empty file # create an empty file
yield execute('dd if=/dev/zero of=%s count=1 seek=%d bs=%d' yield execute('dd if=/dev/zero of=%s count=1 seek=%d bs=%d'
@@ -162,7 +162,7 @@ def inject_data(image, key=None, net=None, partition=None, execute=None):
@defer.inlineCallbacks @defer.inlineCallbacks
def _inject_key_into_fs(key, fs, execute=None): def _inject_key_into_fs(key, fs, execute=None):
sshdir = os.path.join(os.path.join(fs, 'root'), '.ssh') sshdir = os.path.join(os.path.join(fs, 'root'), '.ssh')
yield execute('sudo mkdir -p %s' % sshdir) # existing dir doesn't matter yield execute('sudo mkdir -p %s' % sshdir) # existing dir doesn't matter
yield execute('sudo chown root %s' % sshdir) yield execute('sudo chown root %s' % sshdir)
yield execute('sudo chmod 700 %s' % sshdir) yield execute('sudo chmod 700 %s' % sshdir)
keyfile = os.path.join(sshdir, 'authorized_keys') keyfile = os.path.join(sshdir, 'authorized_keys')
@@ -174,4 +174,3 @@ def _inject_net_into_fs(net, fs, execute=None):
netfile = os.path.join(os.path.join(os.path.join( netfile = os.path.join(os.path.join(os.path.join(
fs, 'etc'), 'network'), 'interfaces') fs, 'etc'), 'network'), 'interfaces')
yield execute('sudo tee %s' % netfile, net) yield execute('sudo tee %s' % netfile, net)

View File

@@ -85,8 +85,7 @@ RRD_VALUES = {
'RRA:MAX:0.5:6:800', 'RRA:MAX:0.5:6:800',
'RRA:MAX:0.5:24:800', 'RRA:MAX:0.5:24:800',
'RRA:MAX:0.5:444:800', 'RRA:MAX:0.5:444:800',
] ]}
}
utcnow = datetime.datetime.utcnow utcnow = datetime.datetime.utcnow
@@ -97,15 +96,12 @@ def update_rrd(instance, name, data):
Updates the specified RRD file. Updates the specified RRD file.
""" """
filename = os.path.join(instance.get_rrd_path(), '%s.rrd' % name) filename = os.path.join(instance.get_rrd_path(), '%s.rrd' % name)
if not os.path.exists(filename): if not os.path.exists(filename):
init_rrd(instance, name) init_rrd(instance, name)
timestamp = int(time.mktime(utcnow().timetuple())) timestamp = int(time.mktime(utcnow().timetuple()))
rrdtool.update ( rrdtool.update(filename, '%d:%s' % (timestamp, data))
filename,
'%d:%s' % (timestamp, data)
)
def init_rrd(instance, name): def init_rrd(instance, name):
@@ -113,29 +109,28 @@ def init_rrd(instance, name):
Initializes the specified RRD file. Initializes the specified RRD file.
""" """
path = os.path.join(FLAGS.monitoring_rrd_path, instance.instance_id) path = os.path.join(FLAGS.monitoring_rrd_path, instance.instance_id)
if not os.path.exists(path): if not os.path.exists(path):
os.makedirs(path) os.makedirs(path)
filename = os.path.join(path, '%s.rrd' % name) filename = os.path.join(path, '%s.rrd' % name)
if not os.path.exists(filename): if not os.path.exists(filename):
rrdtool.create ( rrdtool.create(
filename, filename,
'--step', '%d' % FLAGS.monitoring_instances_step, '--step', '%d' % FLAGS.monitoring_instances_step,
'--start', '0', '--start', '0',
*RRD_VALUES[name] *RRD_VALUES[name])
)
def graph_cpu(instance, duration): def graph_cpu(instance, duration):
""" """
Creates a graph of cpu usage for the specified instance and duration. Creates a graph of cpu usage for the specified instance and duration.
""" """
path = instance.get_rrd_path() path = instance.get_rrd_path()
filename = os.path.join(path, 'cpu-%s.png' % duration) filename = os.path.join(path, 'cpu-%s.png' % duration)
rrdtool.graph ( rrdtool.graph(
filename, filename,
'--disable-rrdtool-tag', '--disable-rrdtool-tag',
'--imgformat', 'PNG', '--imgformat', 'PNG',
@@ -146,9 +141,8 @@ def graph_cpu(instance, duration):
'-l', '0', '-l', '0',
'-u', '100', '-u', '100',
'DEF:cpu=%s:cpu:AVERAGE' % os.path.join(path, 'cpu.rrd'), 'DEF:cpu=%s:cpu:AVERAGE' % os.path.join(path, 'cpu.rrd'),
'AREA:cpu#eacc00:% CPU', 'AREA:cpu#eacc00:% CPU',)
)
store_graph(instance.instance_id, filename) store_graph(instance.instance_id, filename)
@@ -158,8 +152,8 @@ def graph_net(instance, duration):
""" """
path = instance.get_rrd_path() path = instance.get_rrd_path()
filename = os.path.join(path, 'net-%s.png' % duration) filename = os.path.join(path, 'net-%s.png' % duration)
rrdtool.graph ( rrdtool.graph(
filename, filename,
'--disable-rrdtool-tag', '--disable-rrdtool-tag',
'--imgformat', 'PNG', '--imgformat', 'PNG',
@@ -174,20 +168,19 @@ def graph_net(instance, duration):
'DEF:rx=%s:rx:AVERAGE' % os.path.join(path, 'net.rrd'), 'DEF:rx=%s:rx:AVERAGE' % os.path.join(path, 'net.rrd'),
'DEF:tx=%s:tx:AVERAGE' % os.path.join(path, 'net.rrd'), 'DEF:tx=%s:tx:AVERAGE' % os.path.join(path, 'net.rrd'),
'AREA:rx#00FF00:In traffic', 'AREA:rx#00FF00:In traffic',
'LINE1:tx#0000FF:Out traffic', 'LINE1:tx#0000FF:Out traffic',)
)
store_graph(instance.instance_id, filename) store_graph(instance.instance_id, filename)
def graph_disk(instance, duration): def graph_disk(instance, duration):
""" """
Creates a graph of disk usage for the specified duration. Creates a graph of disk usage for the specified duration.
""" """
path = instance.get_rrd_path() path = instance.get_rrd_path()
filename = os.path.join(path, 'disk-%s.png' % duration) filename = os.path.join(path, 'disk-%s.png' % duration)
rrdtool.graph ( rrdtool.graph(
filename, filename,
'--disable-rrdtool-tag', '--disable-rrdtool-tag',
'--imgformat', 'PNG', '--imgformat', 'PNG',
@@ -202,9 +195,8 @@ def graph_disk(instance, duration):
'DEF:rd=%s:rd:AVERAGE' % os.path.join(path, 'disk.rrd'), 'DEF:rd=%s:rd:AVERAGE' % os.path.join(path, 'disk.rrd'),
'DEF:wr=%s:wr:AVERAGE' % os.path.join(path, 'disk.rrd'), 'DEF:wr=%s:wr:AVERAGE' % os.path.join(path, 'disk.rrd'),
'AREA:rd#00FF00:Read', 'AREA:rd#00FF00:Read',
'LINE1:wr#0000FF:Write', 'LINE1:wr#0000FF:Write',)
)
store_graph(instance.instance_id, filename) store_graph(instance.instance_id, filename)
@@ -224,17 +216,16 @@ def store_graph(instance_id, filename):
is_secure=False, is_secure=False,
calling_format=boto.s3.connection.OrdinaryCallingFormat(), calling_format=boto.s3.connection.OrdinaryCallingFormat(),
port=FLAGS.s3_port, port=FLAGS.s3_port,
host=FLAGS.s3_host host=FLAGS.s3_host)
)
bucket_name = '_%s.monitor' % instance_id bucket_name = '_%s.monitor' % instance_id
# Object store isn't creating the bucket like it should currently # Object store isn't creating the bucket like it should currently
# when it is first requested, so have to catch and create manually. # when it is first requested, so have to catch and create manually.
try: try:
bucket = s3.get_bucket(bucket_name) bucket = s3.get_bucket(bucket_name)
except Exception: except Exception:
bucket = s3.create_bucket(bucket_name) bucket = s3.create_bucket(bucket_name)
key = boto.s3.Key(bucket) key = boto.s3.Key(bucket)
key.key = os.path.basename(filename) key.key = os.path.basename(filename)
key.set_contents_from_filename(filename) key.set_contents_from_filename(filename)
@@ -247,18 +238,18 @@ class Instance(object):
self.last_updated = datetime.datetime.min self.last_updated = datetime.datetime.min
self.cputime = 0 self.cputime = 0
self.cputime_last_updated = None self.cputime_last_updated = None
init_rrd(self, 'cpu') init_rrd(self, 'cpu')
init_rrd(self, 'net') init_rrd(self, 'net')
init_rrd(self, 'disk') init_rrd(self, 'disk')
def needs_update(self): def needs_update(self):
""" """
Indicates whether this instance is due to have its statistics updated. Indicates whether this instance is due to have its statistics updated.
""" """
delta = utcnow() - self.last_updated delta = utcnow() - self.last_updated
return delta.seconds >= FLAGS.monitoring_instances_step return delta.seconds >= FLAGS.monitoring_instances_step
def update(self): def update(self):
""" """
Updates the instances statistics and stores the resulting graphs Updates the instances statistics and stores the resulting graphs
@@ -271,7 +262,7 @@ class Instance(object):
if data != None: if data != None:
logging.debug('CPU: %s', data) logging.debug('CPU: %s', data)
update_rrd(self, 'cpu', data) update_rrd(self, 'cpu', data)
data = self.fetch_net_stats() data = self.fetch_net_stats()
logging.debug('NET: %s', data) logging.debug('NET: %s', data)
update_rrd(self, 'net', data) update_rrd(self, 'net', data)
@@ -279,7 +270,7 @@ class Instance(object):
data = self.fetch_disk_stats() data = self.fetch_disk_stats()
logging.debug('DISK: %s', data) logging.debug('DISK: %s', data)
update_rrd(self, 'disk', data) update_rrd(self, 'disk', data)
# TODO(devcamcar): Turn these into pool.ProcessPool.execute() calls # TODO(devcamcar): Turn these into pool.ProcessPool.execute() calls
# and make the methods @defer.inlineCallbacks. # and make the methods @defer.inlineCallbacks.
graph_cpu(self, '1d') graph_cpu(self, '1d')
@@ -297,13 +288,13 @@ class Instance(object):
logging.exception('unexpected error during update') logging.exception('unexpected error during update')
self.last_updated = utcnow() self.last_updated = utcnow()
def get_rrd_path(self): def get_rrd_path(self):
""" """
Returns the path to where RRD files are stored. Returns the path to where RRD files are stored.
""" """
return os.path.join(FLAGS.monitoring_rrd_path, self.instance_id) return os.path.join(FLAGS.monitoring_rrd_path, self.instance_id)
def fetch_cpu_stats(self): def fetch_cpu_stats(self):
""" """
Returns cpu usage statistics for this instance. Returns cpu usage statistics for this instance.
@@ -327,17 +318,17 @@ class Instance(object):
# Calculate the number of seconds between samples. # Calculate the number of seconds between samples.
d = self.cputime_last_updated - cputime_last_updated d = self.cputime_last_updated - cputime_last_updated
t = d.days * 86400 + d.seconds t = d.days * 86400 + d.seconds
logging.debug('t = %d', t) logging.debug('t = %d', t)
# Calculate change over time in number of nanoseconds of CPU time used. # Calculate change over time in number of nanoseconds of CPU time used.
cputime_delta = self.cputime - cputime_last cputime_delta = self.cputime - cputime_last
logging.debug('cputime_delta = %s', cputime_delta) logging.debug('cputime_delta = %s', cputime_delta)
# Get the number of virtual cpus in this domain. # Get the number of virtual cpus in this domain.
vcpus = int(info['num_cpu']) vcpus = int(info['num_cpu'])
logging.debug('vcpus = %d', vcpus) logging.debug('vcpus = %d', vcpus)
# Calculate CPU % used and cap at 100. # Calculate CPU % used and cap at 100.
@@ -349,9 +340,9 @@ class Instance(object):
""" """
rd = 0 rd = 0
wr = 0 wr = 0
disks = self.conn.get_disks(self.instance_id) disks = self.conn.get_disks(self.instance_id)
# Aggregate the read and write totals. # Aggregate the read and write totals.
for disk in disks: for disk in disks:
try: try:
@@ -363,7 +354,7 @@ class Instance(object):
logging.error('Cannot get blockstats for "%s" on "%s"', logging.error('Cannot get blockstats for "%s" on "%s"',
disk, self.instance_id) disk, self.instance_id)
raise raise
return '%d:%d' % (rd, wr) return '%d:%d' % (rd, wr)
def fetch_net_stats(self): def fetch_net_stats(self):
@@ -372,9 +363,9 @@ class Instance(object):
""" """
rx = 0 rx = 0
tx = 0 tx = 0
interfaces = self.conn.get_interfaces(self.instance_id) interfaces = self.conn.get_interfaces(self.instance_id)
# Aggregate the in and out totals. # Aggregate the in and out totals.
for interface in interfaces: for interface in interfaces:
try: try:
@@ -385,7 +376,7 @@ class Instance(object):
logging.error('Cannot get ifstats for "%s" on "%s"', logging.error('Cannot get ifstats for "%s" on "%s"',
interface, self.instance_id) interface, self.instance_id)
raise raise
return '%d:%d' % (rx, tx) return '%d:%d' % (rx, tx)
@@ -400,16 +391,16 @@ class InstanceMonitor(object, service.Service):
""" """
self._instances = {} self._instances = {}
self._loop = task.LoopingCall(self.updateInstances) self._loop = task.LoopingCall(self.updateInstances)
def startService(self): def startService(self):
self._instances = {} self._instances = {}
self._loop.start(interval=FLAGS.monitoring_instances_delay) self._loop.start(interval=FLAGS.monitoring_instances_delay)
service.Service.startService(self) service.Service.startService(self)
def stopService(self): def stopService(self):
self._loop.stop() self._loop.stop()
service.Service.stopService(self) service.Service.stopService(self)
def updateInstances(self): def updateInstances(self):
""" """
Update resource usage for all running instances. Update resource usage for all running instances.
@@ -420,20 +411,20 @@ class InstanceMonitor(object, service.Service):
logging.exception('unexpected exception getting connection') logging.exception('unexpected exception getting connection')
time.sleep(FLAGS.monitoring_instances_delay) time.sleep(FLAGS.monitoring_instances_delay)
return return
domain_ids = conn.list_instances() domain_ids = conn.list_instances()
try: try:
self.updateInstances_(conn, domain_ids) self.updateInstances_(conn, domain_ids)
except Exception, exn: except Exception, exn:
logging.exception('updateInstances_') logging.exception('updateInstances_')
def updateInstances_(self, conn, domain_ids): def updateInstances_(self, conn, domain_ids):
for domain_id in domain_ids: for domain_id in domain_ids:
if not domain_id in self._instances: if not domain_id in self._instances:
instance = Instance(conn, domain_id) instance = Instance(conn, domain_id)
self._instances[domain_id] = instance self._instances[domain_id] = instance
logging.debug('Found instance: %s', domain_id) logging.debug('Found instance: %s', domain_id)
for key in self._instances.keys(): for key in self._instances.keys():
instance = self._instances[key] instance = self._instances[key]
if instance.needs_update(): if instance.needs_update():

View File

@@ -30,12 +30,11 @@ CRASHED = 0x06
def name(code): def name(code):
d = { d = {
NOSTATE : 'pending', NOSTATE: 'pending',
RUNNING : 'running', RUNNING: 'running',
BLOCKED : 'blocked', BLOCKED: 'blocked',
PAUSED : 'paused', PAUSED: 'paused',
SHUTDOWN: 'shutdown', SHUTDOWN: 'shutdown',
SHUTOFF : 'shutdown', SHUTOFF: 'shutdown',
CRASHED : 'crashed', CRASHED: 'crashed'}
}
return d[code] return d[code]

View File

@@ -256,10 +256,12 @@ def instance_get_all(context):
"""Get all instances.""" """Get all instances."""
return IMPL.instance_get_all(context) return IMPL.instance_get_all(context)
def instance_get_all_by_user(context, user_id): def instance_get_all_by_user(context, user_id):
"""Get all instances.""" """Get all instances."""
return IMPL.instance_get_all_by_user(context, user_id) return IMPL.instance_get_all_by_user(context, user_id)
def instance_get_all_by_project(context, project_id): def instance_get_all_by_project(context, project_id):
"""Get all instance belonging to a project.""" """Get all instance belonging to a project."""
return IMPL.instance_get_all_by_project(context, project_id) return IMPL.instance_get_all_by_project(context, project_id)
@@ -306,7 +308,8 @@ def instance_update(context, instance_id, values):
def instance_add_security_group(context, instance_id, security_group_id): def instance_add_security_group(context, instance_id, security_group_id):
"""Associate the given security group with the given instance""" """Associate the given security group with the given instance"""
return IMPL.instance_add_security_group(context, instance_id, security_group_id) return IMPL.instance_add_security_group(context, instance_id,
security_group_id)
################### ###################
@@ -482,10 +485,12 @@ def auth_destroy_token(context, token):
"""Destroy an auth token""" """Destroy an auth token"""
return IMPL.auth_destroy_token(context, token) return IMPL.auth_destroy_token(context, token)
def auth_get_token(context, token_hash): def auth_get_token(context, token_hash):
"""Retrieves a token given the hash representing it""" """Retrieves a token given the hash representing it"""
return IMPL.auth_get_token(context, token_hash) return IMPL.auth_get_token(context, token_hash)
def auth_create_token(context, token): def auth_create_token(context, token):
"""Creates a new token""" """Creates a new token"""
return IMPL.auth_create_token(context, token) return IMPL.auth_create_token(context, token)
@@ -644,7 +649,9 @@ def security_group_rule_create(context, values):
def security_group_rule_get_by_security_group(context, security_group_id): def security_group_rule_get_by_security_group(context, security_group_id):
"""Get all rules for a a given security group""" """Get all rules for a a given security group"""
return IMPL.security_group_rule_get_by_security_group(context, security_group_id) return IMPL.security_group_rule_get_by_security_group(context,
security_group_id)
def security_group_rule_destroy(context, security_group_rule_id): def security_group_rule_destroy(context, security_group_rule_id):
"""Deletes a security group rule""" """Deletes a security group rule"""
@@ -767,4 +774,3 @@ def host_get_networks(context, host):
network host network host
""" """
return IMPL.host_get_networks(context, host) return IMPL.host_get_networks(context, host)

File diff suppressed because it is too large Load Diff

View File

@@ -134,8 +134,8 @@ class NovaBase(object):
# """Represents a host where services are running""" # """Represents a host where services are running"""
# __tablename__ = 'hosts' # __tablename__ = 'hosts'
# id = Column(String(255), primary_key=True) # id = Column(String(255), primary_key=True)
#
#
class Service(BASE, NovaBase): class Service(BASE, NovaBase):
"""Represents a running service on a host""" """Represents a running service on a host"""
__tablename__ = 'services' __tablename__ = 'services'
@@ -277,7 +277,8 @@ class Quota(BASE, NovaBase):
class ExportDevice(BASE, NovaBase): class ExportDevice(BASE, NovaBase):
"""Represates a shelf and blade that a volume can be exported on""" """Represates a shelf and blade that a volume can be exported on"""
__tablename__ = 'export_devices' __tablename__ = 'export_devices'
__table_args__ = (schema.UniqueConstraint("shelf_id", "blade_id"), {'mysql_engine': 'InnoDB'}) __table_args__ = (schema.UniqueConstraint("shelf_id", "blade_id"),
{'mysql_engine': 'InnoDB'})
id = Column(Integer, primary_key=True) id = Column(Integer, primary_key=True)
shelf_id = Column(Integer) shelf_id = Column(Integer)
blade_id = Column(Integer) blade_id = Column(Integer)
@@ -308,10 +309,13 @@ class SecurityGroup(BASE, NovaBase):
instances = relationship(Instance, instances = relationship(Instance,
secondary="security_group_instance_association", secondary="security_group_instance_association",
primaryjoin="and_(SecurityGroup.id == SecurityGroupInstanceAssociation.security_group_id," primaryjoin='and_('
"SecurityGroup.deleted == False)", 'SecurityGroup.id == '
secondaryjoin="and_(SecurityGroupInstanceAssociation.instance_id == Instance.id," 'SecurityGroupInstanceAssociation.security_group_id,'
"Instance.deleted == False)", 'SecurityGroup.deleted == False)',
secondaryjoin='and_('
'SecurityGroupInstanceAssociation.instance_id == Instance.id,'
'Instance.deleted == False)',
backref='security_groups') backref='security_groups')
@property @property
@@ -330,11 +334,12 @@ class SecurityGroupIngressRule(BASE, NovaBase):
parent_group_id = Column(Integer, ForeignKey('security_groups.id')) parent_group_id = Column(Integer, ForeignKey('security_groups.id'))
parent_group = relationship("SecurityGroup", backref="rules", parent_group = relationship("SecurityGroup", backref="rules",
foreign_keys=parent_group_id, foreign_keys=parent_group_id,
primaryjoin="and_(SecurityGroupIngressRule.parent_group_id == SecurityGroup.id," primaryjoin='and_('
"SecurityGroupIngressRule.deleted == False)") 'SecurityGroupIngressRule.parent_group_id == SecurityGroup.id,'
'SecurityGroupIngressRule.deleted == False)')
protocol = Column(String(5)) # "tcp", "udp", or "icmp" protocol = Column(String(5)) # "tcp", "udp", or "icmp"
from_port = Column(Integer) from_port = Column(Integer)
to_port = Column(Integer) to_port = Column(Integer)
cidr = Column(String(255)) cidr = Column(String(255))
@@ -414,8 +419,9 @@ class FixedIp(BASE, NovaBase):
instance = relationship(Instance, instance = relationship(Instance,
backref=backref('fixed_ip', uselist=False), backref=backref('fixed_ip', uselist=False),
foreign_keys=instance_id, foreign_keys=instance_id,
primaryjoin='and_(FixedIp.instance_id==Instance.id,' primaryjoin='and_('
'FixedIp.deleted==False)') 'FixedIp.instance_id == Instance.id,'
'FixedIp.deleted == False)')
allocated = Column(Boolean, default=False) allocated = Column(Boolean, default=False)
leased = Column(Boolean, default=False) leased = Column(Boolean, default=False)
reserved = Column(Boolean, default=False) reserved = Column(Boolean, default=False)
@@ -455,13 +461,13 @@ class UserProjectRoleAssociation(BASE, NovaBase):
__tablename__ = 'user_project_role_association' __tablename__ = 'user_project_role_association'
user_id = Column(String(255), primary_key=True) user_id = Column(String(255), primary_key=True)
user = relationship(User, user = relationship(User,
primaryjoin=user_id==User.id, primaryjoin=user_id == User.id,
foreign_keys=[User.id], foreign_keys=[User.id],
uselist=False) uselist=False)
project_id = Column(String(255), primary_key=True) project_id = Column(String(255), primary_key=True)
project = relationship(Project, project = relationship(Project,
primaryjoin=project_id==Project.id, primaryjoin=project_id == Project.id,
foreign_keys=[Project.id], foreign_keys=[Project.id],
uselist=False) uselist=False)
@@ -485,7 +491,6 @@ class UserProjectAssociation(BASE, NovaBase):
project_id = Column(String(255), ForeignKey(Project.id), primary_key=True) project_id = Column(String(255), ForeignKey(Project.id), primary_key=True)
class FloatingIp(BASE, NovaBase): class FloatingIp(BASE, NovaBase):
"""Represents a floating ip that dynamically forwards to a fixed ip""" """Represents a floating ip that dynamically forwards to a fixed ip"""
__tablename__ = 'floating_ips' __tablename__ = 'floating_ips'
@@ -495,8 +500,9 @@ class FloatingIp(BASE, NovaBase):
fixed_ip = relationship(FixedIp, fixed_ip = relationship(FixedIp,
backref=backref('floating_ips'), backref=backref('floating_ips'),
foreign_keys=fixed_ip_id, foreign_keys=fixed_ip_id,
primaryjoin='and_(FloatingIp.fixed_ip_id==FixedIp.id,' primaryjoin='and_('
'FloatingIp.deleted==False)') 'FloatingIp.fixed_ip_id == FixedIp.id,'
'FloatingIp.deleted == False)')
project_id = Column(String(255)) project_id = Column(String(255))
host = Column(String(255)) # , ForeignKey('hosts.id')) host = Column(String(255)) # , ForeignKey('hosts.id'))
@@ -507,7 +513,7 @@ def register_models():
models = (Service, Instance, Volume, ExportDevice, FixedIp, models = (Service, Instance, Volume, ExportDevice, FixedIp,
FloatingIp, Network, SecurityGroup, FloatingIp, Network, SecurityGroup,
SecurityGroupIngressRule, SecurityGroupInstanceAssociation, SecurityGroupIngressRule, SecurityGroupInstanceAssociation,
AuthToken, User, Project) # , Image, Host AuthToken, User, Project) # , Image, Host
engine = create_engine(FLAGS.sql_connection, echo=False) engine = create_engine(FLAGS.sql_connection, echo=False)
for model in models: for model in models:
model.metadata.create_all(engine) model.metadata.create_all(engine)

View File

@@ -29,6 +29,7 @@ FLAGS = flags.FLAGS
_ENGINE = None _ENGINE = None
_MAKER = None _MAKER = None
def get_session(autocommit=True, expire_on_commit=False): def get_session(autocommit=True, expire_on_commit=False):
"""Helper method to grab session""" """Helper method to grab session"""
global _ENGINE global _ENGINE
@@ -39,5 +40,5 @@ def get_session(autocommit=True, expire_on_commit=False):
_MAKER = (sessionmaker(bind=_ENGINE, _MAKER = (sessionmaker(bind=_ENGINE,
autocommit=autocommit, autocommit=autocommit,
expire_on_commit=expire_on_commit)) expire_on_commit=expire_on_commit))
session = _MAKER() session = _MAKER()
return session return session

View File

@@ -30,7 +30,8 @@ flags.DEFINE_string('glance_teller_address', 'http://127.0.0.1',
flags.DEFINE_string('glance_teller_port', '9191', flags.DEFINE_string('glance_teller_port', '9191',
'Port for Glance\'s Teller service') 'Port for Glance\'s Teller service')
flags.DEFINE_string('glance_parallax_address', 'http://127.0.0.1', flags.DEFINE_string('glance_parallax_address', 'http://127.0.0.1',
'IP address or URL where Glance\'s Parallax service resides') 'IP address or URL where Glance\'s Parallax service '
'resides')
flags.DEFINE_string('glance_parallax_port', '9292', flags.DEFINE_string('glance_parallax_port', '9292',
'Port for Glance\'s Parallax service') 'Port for Glance\'s Parallax service')
@@ -120,10 +121,10 @@ class BaseImageService(object):
def delete(self, image_id): def delete(self, image_id):
""" """
Delete the given image. Delete the given image.
:raises NotFound if the image does not exist. :raises NotFound if the image does not exist.
""" """
raise NotImplementedError raise NotImplementedError
@@ -131,14 +132,14 @@ class BaseImageService(object):
class LocalImageService(BaseImageService): class LocalImageService(BaseImageService):
"""Image service storing images to local disk. """Image service storing images to local disk.
It assumes that image_ids are integers.""" It assumes that image_ids are integers."""
def __init__(self): def __init__(self):
self._path = "/tmp/nova/images" self._path = "/tmp/nova/images"
try: try:
os.makedirs(self._path) os.makedirs(self._path)
except OSError: # exists except OSError: # Exists
pass pass
def _path_to(self, image_id): def _path_to(self, image_id):
@@ -156,7 +157,7 @@ class LocalImageService(BaseImageService):
def show(self, id): def show(self, id):
try: try:
return pickle.load(open(self._path_to(id))) return pickle.load(open(self._path_to(id)))
except IOError: except IOError:
raise exception.NotFound raise exception.NotFound
@@ -164,7 +165,7 @@ class LocalImageService(BaseImageService):
""" """
Store the image data and return the new image id. Store the image data and return the new image id.
""" """
id = random.randint(0, 2**32-1) id = random.randint(0, 2 ** 32 - 1)
data['id'] = id data['id'] = id
self.update(id, data) self.update(id, data)
return id return id

View File

@@ -30,6 +30,7 @@ import nova.image.service
FLAGS = flags.FLAGS FLAGS = flags.FLAGS
class TellerClient(object): class TellerClient(object):
def __init__(self): def __init__(self):
@@ -153,7 +154,6 @@ class ParallaxClient(object):
class GlanceImageService(nova.image.service.BaseImageService): class GlanceImageService(nova.image.service.BaseImageService):
"""Provides storage and retrieval of disk image objects within Glance.""" """Provides storage and retrieval of disk image objects within Glance."""
def __init__(self): def __init__(self):
@@ -202,10 +202,10 @@ class GlanceImageService(nova.image.service.BaseImageService):
def delete(self, image_id): def delete(self, image_id):
""" """
Delete the given image. Delete the given image.
:raises NotFound if the image does not exist. :raises NotFound if the image does not exist.
""" """
self.parallax.delete_image_metadata(image_id) self.parallax.delete_image_metadata(image_id)

View File

@@ -53,6 +53,7 @@ flags.DEFINE_bool('use_nova_chains', False,
DEFAULT_PORTS = [("tcp", 80), ("tcp", 22), ("udp", 1194), ("tcp", 443)] DEFAULT_PORTS = [("tcp", 80), ("tcp", 22), ("udp", 1194), ("tcp", 443)]
def init_host(): def init_host():
"""Basic networking setup goes here""" """Basic networking setup goes here"""
# NOTE(devcamcar): Cloud public DNAT entries, CloudPipe port # NOTE(devcamcar): Cloud public DNAT entries, CloudPipe port
@@ -72,6 +73,7 @@ def init_host():
_confirm_rule("POSTROUTING", "-t nat -s %(range)s -d %(range)s -j ACCEPT" % _confirm_rule("POSTROUTING", "-t nat -s %(range)s -d %(range)s -j ACCEPT" %
{'range': FLAGS.fixed_range}) {'range': FLAGS.fixed_range})
def bind_floating_ip(floating_ip): def bind_floating_ip(floating_ip):
"""Bind ip to public interface""" """Bind ip to public interface"""
_execute("sudo ip addr add %s dev %s" % (floating_ip, _execute("sudo ip addr add %s dev %s" % (floating_ip,
@@ -103,7 +105,7 @@ def ensure_floating_forward(floating_ip, fixed_ip):
_confirm_rule("FORWARD", "-d %s -p icmp -j ACCEPT" _confirm_rule("FORWARD", "-d %s -p icmp -j ACCEPT"
% (fixed_ip)) % (fixed_ip))
for (protocol, port) in DEFAULT_PORTS: for (protocol, port) in DEFAULT_PORTS:
_confirm_rule("FORWARD","-d %s -p %s --dport %s -j ACCEPT" _confirm_rule("FORWARD", "-d %s -p %s --dport %s -j ACCEPT"
% (fixed_ip, protocol, port)) % (fixed_ip, protocol, port))
@@ -189,7 +191,8 @@ def update_dhcp(context, network_id):
# if dnsmasq is already running, then tell it to reload # if dnsmasq is already running, then tell it to reload
if pid: if pid:
out, _err = _execute('cat /proc/%d/cmdline' % pid, check_exit_code=False) out, _err = _execute('cat /proc/%d/cmdline' % pid,
check_exit_code=False)
if conffile in out: if conffile in out:
try: try:
_execute('sudo kill -HUP %d' % pid) _execute('sudo kill -HUP %d' % pid)
@@ -233,7 +236,8 @@ def _confirm_rule(chain, cmd):
"""Delete and re-add iptables rule""" """Delete and re-add iptables rule"""
if FLAGS.use_nova_chains: if FLAGS.use_nova_chains:
chain = "nova_%s" % chain.lower() chain = "nova_%s" % chain.lower()
_execute("sudo iptables --delete %s %s" % (chain, cmd), check_exit_code=False) _execute("sudo iptables --delete %s %s" % (chain, cmd),
check_exit_code=False)
_execute("sudo iptables -I %s %s" % (chain, cmd)) _execute("sudo iptables -I %s %s" % (chain, cmd))

View File

@@ -49,7 +49,8 @@ flags.DEFINE_string('vpn_ip', utils.get_my_ip(),
flags.DEFINE_integer('vpn_start', 1000, 'First Vpn port for private networks') flags.DEFINE_integer('vpn_start', 1000, 'First Vpn port for private networks')
flags.DEFINE_integer('network_size', 256, flags.DEFINE_integer('network_size', 256,
'Number of addresses in each private subnet') 'Number of addresses in each private subnet')
flags.DEFINE_string('floating_range', '4.4.4.0/24', 'Floating IP address block') flags.DEFINE_string('floating_range', '4.4.4.0/24',
'Floating IP address block')
flags.DEFINE_string('fixed_range', '10.0.0.0/8', 'Fixed IP address block') flags.DEFINE_string('fixed_range', '10.0.0.0/8', 'Fixed IP address block')
flags.DEFINE_integer('cnt_vpn_clients', 5, flags.DEFINE_integer('cnt_vpn_clients', 5,
'Number of addresses reserved for vpn clients') 'Number of addresses reserved for vpn clients')
@@ -287,7 +288,6 @@ class FlatManager(NetworkManager):
self.db.network_update(context, network_id, net) self.db.network_update(context, network_id, net)
class FlatDHCPManager(NetworkManager): class FlatDHCPManager(NetworkManager):
"""Flat networking with dhcp""" """Flat networking with dhcp"""
@@ -432,4 +432,3 @@ class VlanManager(NetworkManager):
"""Number of reserved ips at the top of the range""" """Number of reserved ips at the top of the range"""
parent_reserved = super(VlanManager, self)._top_reserved_ips parent_reserved = super(VlanManager, self)._top_reserved_ips
return parent_reserved + FLAGS.cnt_vpn_clients return parent_reserved + FLAGS.cnt_vpn_clients

View File

@@ -69,7 +69,8 @@ class Bucket(object):
"""Create a new bucket owned by a project. """Create a new bucket owned by a project.
@bucket_name: a string representing the name of the bucket to create @bucket_name: a string representing the name of the bucket to create
@context: a nova.auth.api.ApiContext object representing who owns the bucket. @context: a nova.auth.api.ApiContext object representing who owns the
bucket.
Raises: Raises:
NotAuthorized: if the bucket is already exists or has invalid name NotAuthorized: if the bucket is already exists or has invalid name
@@ -77,12 +78,12 @@ class Bucket(object):
path = os.path.abspath(os.path.join( path = os.path.abspath(os.path.join(
FLAGS.buckets_path, bucket_name)) FLAGS.buckets_path, bucket_name))
if not path.startswith(os.path.abspath(FLAGS.buckets_path)) or \ if not path.startswith(os.path.abspath(FLAGS.buckets_path)) or \
os.path.exists(path): os.path.exists(path):
raise exception.NotAuthorized() raise exception.NotAuthorized()
os.makedirs(path) os.makedirs(path)
with open(path+'.json', 'w') as f: with open(path + '.json', 'w') as f:
json.dump({'ownerId': context.project_id}, f) json.dump({'ownerId': context.project_id}, f)
@property @property
@@ -99,22 +100,25 @@ class Bucket(object):
@property @property
def owner_id(self): def owner_id(self):
try: try:
with open(self.path+'.json') as f: with open(self.path + '.json') as f:
return json.load(f)['ownerId'] return json.load(f)['ownerId']
except: except:
return None return None
def is_authorized(self, context): def is_authorized(self, context):
try: try:
return context.user.is_admin() or self.owner_id == context.project_id return context.user.is_admin() or \
self.owner_id == context.project_id
except Exception, e: except Exception, e:
return False return False
def list_keys(self, prefix='', marker=None, max_keys=1000, terse=False): def list_keys(self, prefix='', marker=None, max_keys=1000, terse=False):
object_names = [] object_names = []
path_length = len(self.path)
for root, dirs, files in os.walk(self.path): for root, dirs, files in os.walk(self.path):
for file_name in files: for file_name in files:
object_names.append(os.path.join(root, file_name)[len(self.path)+1:]) object_name = os.path.join(root, file_name)[path_length + 1:]
object_names.append(object_name)
object_names.sort() object_names.sort()
contents = [] contents = []
@@ -164,7 +168,7 @@ class Bucket(object):
if len(os.listdir(self.path)) > 0: if len(os.listdir(self.path)) > 0:
raise exception.NotEmpty() raise exception.NotEmpty()
os.rmdir(self.path) os.rmdir(self.path)
os.remove(self.path+'.json') os.remove(self.path + '.json')
def __getitem__(self, key): def __getitem__(self, key):
return stored.Object(self, key) return stored.Object(self, key)

View File

@@ -136,6 +136,7 @@ def get_context(request):
logging.debug("Authentication Failure: %s", ex) logging.debug("Authentication Failure: %s", ex)
raise exception.NotAuthorized() raise exception.NotAuthorized()
class ErrorHandlingResource(resource.Resource): class ErrorHandlingResource(resource.Resource):
"""Maps exceptions to 404 / 401 codes. Won't work for """Maps exceptions to 404 / 401 codes. Won't work for
exceptions thrown after NOT_DONE_YET is returned. exceptions thrown after NOT_DONE_YET is returned.
@@ -162,7 +163,7 @@ class S3(ErrorHandlingResource):
def __init__(self): def __init__(self):
ErrorHandlingResource.__init__(self) ErrorHandlingResource.__init__(self)
def getChild(self, name, request): # pylint: disable-msg=C0103 def getChild(self, name, request): # pylint: disable-msg=C0103
"""Returns either the image or bucket resource""" """Returns either the image or bucket resource"""
request.context = get_context(request) request.context = get_context(request)
if name == '': if name == '':
@@ -172,7 +173,7 @@ class S3(ErrorHandlingResource):
else: else:
return BucketResource(name) return BucketResource(name)
def render_GET(self, request): # pylint: disable-msg=R0201 def render_GET(self, request): # pylint: disable-msg=R0201
"""Renders the GET request for a list of buckets as XML""" """Renders the GET request for a list of buckets as XML"""
logging.debug('List of buckets requested') logging.debug('List of buckets requested')
buckets = [b for b in bucket.Bucket.all() \ buckets = [b for b in bucket.Bucket.all() \
@@ -321,11 +322,13 @@ class ImageResource(ErrorHandlingResource):
if not self.img.is_authorized(request.context, True): if not self.img.is_authorized(request.context, True):
raise exception.NotAuthorized() raise exception.NotAuthorized()
return static.File(self.img.image_path, return static.File(self.img.image_path,
defaultType='application/octet-stream' defaultType='application/octet-stream').\
).render_GET(request) render_GET(request)
class ImagesResource(resource.Resource): class ImagesResource(resource.Resource):
"""A web resource representing a list of images""" """A web resource representing a list of images"""
def getChild(self, name, _request): def getChild(self, name, _request):
"""Returns itself or an ImageResource if no name given""" """Returns itself or an ImageResource if no name given"""
if name == '': if name == '':
@@ -333,7 +336,7 @@ class ImagesResource(resource.Resource):
else: else:
return ImageResource(name) return ImageResource(name)
def render_GET(self, request): # pylint: disable-msg=R0201 def render_GET(self, request): # pylint: disable-msg=R0201
""" returns a json listing of all images """ returns a json listing of all images
that a user has permissions to see """ that a user has permissions to see """
@@ -362,7 +365,7 @@ class ImagesResource(resource.Resource):
request.finish() request.finish()
return server.NOT_DONE_YET return server.NOT_DONE_YET
def render_PUT(self, request): # pylint: disable-msg=R0201 def render_PUT(self, request): # pylint: disable-msg=R0201
""" create a new registered image """ """ create a new registered image """
image_id = get_argument(request, 'image_id', u'') image_id = get_argument(request, 'image_id', u'')
@@ -383,7 +386,7 @@ class ImagesResource(resource.Resource):
p.start() p.start()
return '' return ''
def render_POST(self, request): # pylint: disable-msg=R0201 def render_POST(self, request): # pylint: disable-msg=R0201
"""Update image attributes: public/private""" """Update image attributes: public/private"""
# image_id required for all requests # image_id required for all requests
@@ -397,7 +400,7 @@ class ImagesResource(resource.Resource):
if operation: if operation:
# operation implies publicity toggle # operation implies publicity toggle
logging.debug("handling publicity toggle") logging.debug("handling publicity toggle")
image_object.set_public(operation=='add') image_object.set_public(operation == 'add')
else: else:
# other attributes imply update # other attributes imply update
logging.debug("update user fields") logging.debug("update user fields")
@@ -407,7 +410,7 @@ class ImagesResource(resource.Resource):
image_object.update_user_editable_fields(clean_args) image_object.update_user_editable_fields(clean_args)
return '' return ''
def render_DELETE(self, request): # pylint: disable-msg=R0201 def render_DELETE(self, request): # pylint: disable-msg=R0201
"""Delete a registered image""" """Delete a registered image"""
image_id = get_argument(request, "image_id", u"") image_id = get_argument(request, "image_id", u"")
image_object = image.Image(image_id) image_object = image.Image(image_id)

View File

@@ -48,8 +48,8 @@ class Image(object):
self.image_id = image_id self.image_id = image_id
self.path = os.path.abspath(os.path.join(FLAGS.images_path, image_id)) self.path = os.path.abspath(os.path.join(FLAGS.images_path, image_id))
if not self.path.startswith(os.path.abspath(FLAGS.images_path)) or \ if not self.path.startswith(os.path.abspath(FLAGS.images_path)) or \
not os.path.isdir(self.path): not os.path.isdir(self.path):
raise exception.NotFound raise exception.NotFound
@property @property
def image_path(self): def image_path(self):
@@ -127,8 +127,8 @@ class Image(object):
a string of the image id for the kernel a string of the image id for the kernel
@type ramdisk: bool or str @type ramdisk: bool or str
@param ramdisk: either TRUE meaning this partition is a ramdisk image or @param ramdisk: either TRUE meaning this partition is a ramdisk image
a string of the image id for the ramdisk or a string of the image id for the ramdisk
@type public: bool @type public: bool
@@ -160,8 +160,7 @@ class Image(object):
'isPublic': public, 'isPublic': public,
'architecture': 'x86_64', 'architecture': 'x86_64',
'imageType': image_type, 'imageType': image_type,
'state': 'available' 'state': 'available'}
}
if type(kernel) is str and len(kernel) > 0: if type(kernel) is str and len(kernel) > 0:
info['kernelId'] = kernel info['kernelId'] = kernel
@@ -180,7 +179,7 @@ class Image(object):
os.makedirs(image_path) os.makedirs(image_path)
bucket_name = image_location.split("/")[0] bucket_name = image_location.split("/")[0]
manifest_path = image_location[len(bucket_name)+1:] manifest_path = image_location[len(bucket_name) + 1:]
bucket_object = bucket.Bucket(bucket_name) bucket_object = bucket.Bucket(bucket_name)
manifest = ElementTree.fromstring(bucket_object[manifest_path].read()) manifest = ElementTree.fromstring(bucket_object[manifest_path].read())
@@ -204,10 +203,9 @@ class Image(object):
'imageId': image_id, 'imageId': image_id,
'imageLocation': image_location, 'imageLocation': image_location,
'imageOwnerId': context.project_id, 'imageOwnerId': context.project_id,
'isPublic': False, # FIXME: grab public from manifest 'isPublic': False, # FIXME: grab public from manifest
'architecture': 'x86_64', # FIXME: grab architecture from manifest 'architecture': 'x86_64', # FIXME: grab architecture from manifest
'imageType' : image_type 'imageType': image_type}
}
if kernel_id: if kernel_id:
info['kernelId'] = kernel_id info['kernelId'] = kernel_id
@@ -230,24 +228,29 @@ class Image(object):
write_state('decrypting') write_state('decrypting')
# FIXME: grab kernelId and ramdiskId from bundle manifest # FIXME: grab kernelId and ramdiskId from bundle manifest
encrypted_key = binascii.a2b_hex(manifest.find("image/ec2_encrypted_key").text) hex_key = manifest.find("image/ec2_encrypted_key").text
encrypted_iv = binascii.a2b_hex(manifest.find("image/ec2_encrypted_iv").text) encrypted_key = binascii.a2b_hex(hex_key)
hex_iv = manifest.find("image/ec2_encrypted_iv").text
encrypted_iv = binascii.a2b_hex(hex_iv)
cloud_private_key = os.path.join(FLAGS.ca_path, "private/cakey.pem") cloud_private_key = os.path.join(FLAGS.ca_path, "private/cakey.pem")
decrypted_filename = os.path.join(image_path, 'image.tar.gz') decrypted_filename = os.path.join(image_path, 'image.tar.gz')
Image.decrypt_image(encrypted_filename, encrypted_key, encrypted_iv, cloud_private_key, decrypted_filename) Image.decrypt_image(encrypted_filename, encrypted_key, encrypted_iv,
cloud_private_key, decrypted_filename)
write_state('untarring') write_state('untarring')
image_file = Image.untarzip_image(image_path, decrypted_filename) image_file = Image.untarzip_image(image_path, decrypted_filename)
shutil.move(os.path.join(image_path, image_file), os.path.join(image_path, 'image')) shutil.move(os.path.join(image_path, image_file),
os.path.join(image_path, 'image'))
write_state('available') write_state('available')
os.unlink(decrypted_filename) os.unlink(decrypted_filename)
os.unlink(encrypted_filename) os.unlink(encrypted_filename)
@staticmethod @staticmethod
def decrypt_image(encrypted_filename, encrypted_key, encrypted_iv, cloud_private_key, decrypted_filename): def decrypt_image(encrypted_filename, encrypted_key, encrypted_iv,
cloud_private_key, decrypted_filename):
key, err = utils.execute( key, err = utils.execute(
'openssl rsautl -decrypt -inkey %s' % cloud_private_key, 'openssl rsautl -decrypt -inkey %s' % cloud_private_key,
process_input=encrypted_key, process_input=encrypted_key,
@@ -259,13 +262,15 @@ class Image(object):
process_input=encrypted_iv, process_input=encrypted_iv,
check_exit_code=False) check_exit_code=False)
if err: if err:
raise exception.Error("Failed to decrypt initialization vector: %s" % err) raise exception.Error("Failed to decrypt initialization "
"vector: %s" % err)
_out, err = utils.execute( _out, err = utils.execute(
'openssl enc -d -aes-128-cbc -in %s -K %s -iv %s -out %s' 'openssl enc -d -aes-128-cbc -in %s -K %s -iv %s -out %s'
% (encrypted_filename, key, iv, decrypted_filename), % (encrypted_filename, key, iv, decrypted_filename),
check_exit_code=False) check_exit_code=False)
if err: if err:
raise exception.Error("Failed to decrypt image file %s : %s" % (encrypted_filename, err)) raise exception.Error("Failed to decrypt image file %s : %s" %
(encrypted_filename, err))
@staticmethod @staticmethod
def untarzip_image(path, filename): def untarzip_image(path, filename):

View File

@@ -50,8 +50,8 @@ class Object(object):
return os.path.getmtime(self.path) return os.path.getmtime(self.path)
def read(self): def read(self):
""" read all contents of key into memory and return """ """ read all contents of key into memory and return """
return self.file.read() return self.file.read()
@property @property
def file(self): def file(self):

View File

@@ -31,10 +31,12 @@ FLAGS = flags.FLAGS
flags.DEFINE_integer('service_down_time', 60, flags.DEFINE_integer('service_down_time', 60,
'maximum time since last checkin for up service') 'maximum time since last checkin for up service')
class NoValidHost(exception.Error): class NoValidHost(exception.Error):
"""There is no valid host for the command.""" """There is no valid host for the command."""
pass pass
class Scheduler(object): class Scheduler(object):
"""The base class that all Scheduler clases should inherit from.""" """The base class that all Scheduler clases should inherit from."""

View File

@@ -56,7 +56,8 @@ class SchedulerManager(manager.Manager):
driver_method = 'schedule_%s' % method driver_method = 'schedule_%s' % method
elevated = context.elevated() elevated = context.elevated()
try: try:
host = getattr(self.driver, driver_method)(elevated, *args, **kwargs) host = getattr(self.driver, driver_method)(elevated, *args,
**kwargs)
except AttributeError: except AttributeError:
host = self.driver.schedule(elevated, topic, *args, **kwargs) host = self.driver.schedule(elevated, topic, *args, **kwargs)

View File

@@ -36,6 +36,7 @@ flags.DEFINE_integer("max_gigabytes", 10000,
flags.DEFINE_integer("max_networks", 1000, flags.DEFINE_integer("max_networks", 1000,
"maximum number of networks to allow per host") "maximum number of networks to allow per host")
class SimpleScheduler(chance.ChanceScheduler): class SimpleScheduler(chance.ChanceScheduler):
"""Implements Naive Scheduler that tries to find least loaded host.""" """Implements Naive Scheduler that tries to find least loaded host."""

View File

@@ -29,9 +29,12 @@ from nova.auth import manager
FLAGS = flags.FLAGS FLAGS = flags.FLAGS
class Context(object): class Context(object):
pass pass
class AccessTestCase(test.TrialTestCase): class AccessTestCase(test.TrialTestCase):
def setUp(self): def setUp(self):
super(AccessTestCase, self).setUp() super(AccessTestCase, self).setUp()
@@ -56,9 +59,11 @@ class AccessTestCase(test.TrialTestCase):
self.project.add_role(self.testnet, 'netadmin') self.project.add_role(self.testnet, 'netadmin')
self.project.add_role(self.testsys, 'sysadmin') self.project.add_role(self.testsys, 'sysadmin')
#user is set in each test #user is set in each test
def noopWSGIApp(environ, start_response): def noopWSGIApp(environ, start_response):
start_response('200 OK', []) start_response('200 OK', [])
return [''] return ['']
self.mw = ec2.Authorizer(noopWSGIApp) self.mw = ec2.Authorizer(noopWSGIApp)
self.mw.action_roles = {'str': { self.mw.action_roles = {'str': {
'_allow_all': ['all'], '_allow_all': ['all'],
@@ -80,7 +85,7 @@ class AccessTestCase(test.TrialTestCase):
def response_status(self, user, methodName): def response_status(self, user, methodName):
ctxt = context.RequestContext(user, self.project) ctxt = context.RequestContext(user, self.project)
environ = {'ec2.context' : ctxt, environ = {'ec2.context': ctxt,
'ec2.controller': 'some string', 'ec2.controller': 'some string',
'ec2.action': methodName} 'ec2.action': methodName}
req = webob.Request.blank('/', environ) req = webob.Request.blank('/', environ)

View File

@@ -66,8 +66,7 @@ class Test(unittest.TestCase):
def test_metadata(self): def test_metadata(self):
def go(url): def go(url):
result = self._request(url, 'ec2', result = self._request(url, 'ec2', REMOTE_ADDR='128.192.151.2')
REMOTE_ADDR='128.192.151.2')
# Each should get to the ORM layer and fail to find the IP # Each should get to the ORM layer and fail to find the IP
self.assertRaises(nova.exception.NotFound, go, '/latest/') self.assertRaises(nova.exception.NotFound, go, '/latest/')
self.assertRaises(nova.exception.NotFound, go, '/2009-04-04/') self.assertRaises(nova.exception.NotFound, go, '/2009-04-04/')
@@ -78,6 +77,5 @@ class Test(unittest.TestCase):
self.assertTrue('2007-12-15\n' in result.body) self.assertTrue('2007-12-15\n' in result.body)
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()

View File

@@ -1,6 +1,24 @@
# vim: tabstop=4 shiftwidth=4 softtabstop=4
# Copyright 2010 OpenStack LLC.
# All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
import webob.dec import webob.dec
from nova import wsgi from nova import wsgi
class APIStub(object): class APIStub(object):
"""Class to verify request and mark it was called.""" """Class to verify request and mark it was called."""
@webob.dec.wsgify @webob.dec.wsgify

View File

@@ -27,11 +27,13 @@ class RateLimitingMiddlewareTest(unittest.TestCase):
def test_get_action_name(self): def test_get_action_name(self):
middleware = RateLimitingMiddleware(APIStub()) middleware = RateLimitingMiddleware(APIStub())
def verify(method, url, action_name): def verify(method, url, action_name):
req = Request.blank(url) req = Request.blank(url)
req.method = method req.method = method
action = middleware.get_action_name(req) action = middleware.get_action_name(req)
self.assertEqual(action, action_name) self.assertEqual(action, action_name)
verify('PUT', '/servers/4', 'PUT') verify('PUT', '/servers/4', 'PUT')
verify('DELETE', '/servers/4', 'DELETE') verify('DELETE', '/servers/4', 'DELETE')
verify('POST', '/images/4', 'POST') verify('POST', '/images/4', 'POST')
@@ -60,7 +62,7 @@ class RateLimitingMiddlewareTest(unittest.TestCase):
middleware = RateLimitingMiddleware(APIStub()) middleware = RateLimitingMiddleware(APIStub())
self.exhaust(middleware, 'POST', '/servers/4', 'usr1', 10) self.exhaust(middleware, 'POST', '/servers/4', 'usr1', 10)
self.exhaust(middleware, 'POST', '/images/4', 'usr2', 10) self.exhaust(middleware, 'POST', '/images/4', 'usr2', 10)
self.assertTrue(set(middleware.limiter._levels) == self.assertTrue(set(middleware.limiter._levels) ==
set(['usr1:POST', 'usr1:POST servers', 'usr2:POST'])) set(['usr1:POST', 'usr1:POST servers', 'usr2:POST']))
def test_POST_servers_action_correctly_ratelimited(self): def test_POST_servers_action_correctly_ratelimited(self):
@@ -85,19 +87,19 @@ class LimiterTest(unittest.TestCase):
def test_limiter(self): def test_limiter(self):
items = range(2000) items = range(2000)
req = Request.blank('/') req = Request.blank('/')
self.assertEqual(limited(items, req), items[ :1000]) self.assertEqual(limited(items, req), items[:1000])
req = Request.blank('/?offset=0') req = Request.blank('/?offset=0')
self.assertEqual(limited(items, req), items[ :1000]) self.assertEqual(limited(items, req), items[:1000])
req = Request.blank('/?offset=3') req = Request.blank('/?offset=3')
self.assertEqual(limited(items, req), items[3:1003]) self.assertEqual(limited(items, req), items[3:1003])
req = Request.blank('/?offset=2005') req = Request.blank('/?offset=2005')
self.assertEqual(limited(items, req), []) self.assertEqual(limited(items, req), [])
req = Request.blank('/?limit=10') req = Request.blank('/?limit=10')
self.assertEqual(limited(items, req), items[ :10]) self.assertEqual(limited(items, req), items[:10])
req = Request.blank('/?limit=0') req = Request.blank('/?limit=0')
self.assertEqual(limited(items, req), items[ :1000]) self.assertEqual(limited(items, req), items[:1000])
req = Request.blank('/?limit=3000') req = Request.blank('/?limit=3000')
self.assertEqual(limited(items, req), items[ :1000]) self.assertEqual(limited(items, req), items[:1000])
req = Request.blank('/?offset=1&limit=3') req = Request.blank('/?offset=1&limit=3')
self.assertEqual(limited(items, req), items[1:4]) self.assertEqual(limited(items, req), items[1:4])
req = Request.blank('/?offset=3&limit=0') req = Request.blank('/?offset=3&limit=0')

View File

@@ -36,7 +36,7 @@ from nova.wsgi import Router
FLAGS = flags.FLAGS FLAGS = flags.FLAGS
class Context(object): class Context(object):
pass pass
@@ -84,11 +84,11 @@ def stub_out_image_service(stubs):
def stub_out_auth(stubs): def stub_out_auth(stubs):
def fake_auth_init(self, app): def fake_auth_init(self, app):
self.application = app self.application = app
stubs.Set(nova.api.openstack.AuthMiddleware, stubs.Set(nova.api.openstack.AuthMiddleware,
'__init__', fake_auth_init) '__init__', fake_auth_init)
stubs.Set(nova.api.openstack.AuthMiddleware, stubs.Set(nova.api.openstack.AuthMiddleware,
'__call__', fake_wsgi) '__call__', fake_wsgi)
def stub_out_rate_limiting(stubs): def stub_out_rate_limiting(stubs):
@@ -105,7 +105,7 @@ def stub_out_rate_limiting(stubs):
def stub_out_networking(stubs): def stub_out_networking(stubs):
def get_my_ip(): def get_my_ip():
return '127.0.0.1' return '127.0.0.1'
stubs.Set(nova.utils, 'get_my_ip', get_my_ip) stubs.Set(nova.utils, 'get_my_ip', get_my_ip)
FLAGS.FAKE_subdomain = 'api' FLAGS.FAKE_subdomain = 'api'
@@ -137,7 +137,6 @@ def stub_out_glance(stubs, initial_fixtures=[]):
return id return id
def fake_update_image_metadata(self, image_id, image_data): def fake_update_image_metadata(self, image_id, image_data):
f = self.fake_get_image_metadata(image_id) f = self.fake_get_image_metadata(image_id)
if not f: if not f:
raise exc.NotFound raise exc.NotFound
@@ -145,7 +144,6 @@ def stub_out_glance(stubs, initial_fixtures=[]):
f.update(image_data) f.update(image_data)
def fake_delete_image_metadata(self, image_id): def fake_delete_image_metadata(self, image_id):
f = self.fake_get_image_metadata(image_id) f = self.fake_get_image_metadata(image_id)
if not f: if not f:
raise exc.NotFound raise exc.NotFound
@@ -164,9 +162,11 @@ def stub_out_glance(stubs, initial_fixtures=[]):
fake_parallax_client.fake_get_image_metadata) fake_parallax_client.fake_get_image_metadata)
stubs.Set(nova.image.services.glance.ParallaxClient, 'add_image_metadata', stubs.Set(nova.image.services.glance.ParallaxClient, 'add_image_metadata',
fake_parallax_client.fake_add_image_metadata) fake_parallax_client.fake_add_image_metadata)
stubs.Set(nova.image.services.glance.ParallaxClient, 'update_image_metadata', stubs.Set(nova.image.services.glance.ParallaxClient,
'update_image_metadata',
fake_parallax_client.fake_update_image_metadata) fake_parallax_client.fake_update_image_metadata)
stubs.Set(nova.image.services.glance.ParallaxClient, 'delete_image_metadata', stubs.Set(nova.image.services.glance.ParallaxClient,
'delete_image_metadata',
fake_parallax_client.fake_delete_image_metadata) fake_parallax_client.fake_delete_image_metadata)
stubs.Set(nova.image.services.glance.GlanceImageService, 'delete_all', stubs.Set(nova.image.services.glance.GlanceImageService, 'delete_all',
fake_parallax_client.fake_delete_all) fake_parallax_client.fake_delete_all)
@@ -174,7 +174,7 @@ def stub_out_glance(stubs, initial_fixtures=[]):
class FakeToken(object): class FakeToken(object):
def __init__(self, **kwargs): def __init__(self, **kwargs):
for k,v in kwargs.iteritems(): for k, v in kwargs.iteritems():
setattr(self, k, v) setattr(self, k, v)
@@ -200,7 +200,7 @@ class FakeAuthDatabase(object):
class FakeAuthManager(object): class FakeAuthManager(object):
auth_data = {} auth_data = {}
def add_user(self, key, user): def add_user(self, key, user):
FakeAuthManager.auth_data[key] = user FakeAuthManager.auth_data[key] = user
def get_user(self, uid): def get_user(self, uid):

View File

@@ -1,3 +1,20 @@
# vim: tabstop=4 shiftwidth=4 softtabstop=4
# Copyright 2010 OpenStack LLC.
# All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
import datetime import datetime
import unittest import unittest
@@ -11,7 +28,9 @@ import nova.auth.manager
from nova import auth from nova import auth
from nova.tests.api.openstack import fakes from nova.tests.api.openstack import fakes
class Test(unittest.TestCase): class Test(unittest.TestCase):
def setUp(self): def setUp(self):
self.stubs = stubout.StubOutForTesting() self.stubs = stubout.StubOutForTesting()
self.stubs.Set(nova.api.openstack.auth.BasicApiAuthManager, self.stubs.Set(nova.api.openstack.auth.BasicApiAuthManager,
@@ -42,7 +61,7 @@ class Test(unittest.TestCase):
def test_authorize_token(self): def test_authorize_token(self):
f = fakes.FakeAuthManager() f = fakes.FakeAuthManager()
f.add_user('derp', nova.auth.manager.User(1, 'herp', None, None, None)) f.add_user('derp', nova.auth.manager.User(1, 'herp', None, None, None))
req = webob.Request.blank('/v1.0/') req = webob.Request.blank('/v1.0/')
req.headers['X-Auth-User'] = 'herp' req.headers['X-Auth-User'] = 'herp'
req.headers['X-Auth-Key'] = 'derp' req.headers['X-Auth-Key'] = 'derp'
@@ -63,14 +82,14 @@ class Test(unittest.TestCase):
result = req.get_response(nova.api.API()) result = req.get_response(nova.api.API())
self.assertEqual(result.status, '200 OK') self.assertEqual(result.status, '200 OK')
self.assertEqual(result.headers['X-Test-Success'], 'True') self.assertEqual(result.headers['X-Test-Success'], 'True')
def test_token_expiry(self): def test_token_expiry(self):
self.destroy_called = False self.destroy_called = False
token_hash = 'bacon' token_hash = 'bacon'
def destroy_token_mock(meh, context, token): def destroy_token_mock(meh, context, token):
self.destroy_called = True self.destroy_called = True
def bad_token(meh, context, token_hash): def bad_token(meh, context, token_hash):
return fakes.FakeToken( return fakes.FakeToken(
token_hash=token_hash, token_hash=token_hash,

View File

@@ -1,3 +1,20 @@
# vim: tabstop=4 shiftwidth=4 softtabstop=4
# Copyright 2010 OpenStack LLC.
# All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
import unittest import unittest
import webob import webob
import webob.dec import webob.dec
@@ -5,6 +22,7 @@ import webob.exc
from nova.api.openstack import faults from nova.api.openstack import faults
class TestFaults(unittest.TestCase): class TestFaults(unittest.TestCase):
def test_fault_parts(self): def test_fault_parts(self):
@@ -19,7 +37,7 @@ class TestFaults(unittest.TestCase):
def test_retry_header(self): def test_retry_header(self):
req = webob.Request.blank('/.xml') req = webob.Request.blank('/.xml')
exc = webob.exc.HTTPRequestEntityTooLarge(explanation='sorry', exc = webob.exc.HTTPRequestEntityTooLarge(explanation='sorry',
headers={'Retry-After': 4}) headers={'Retry-After': 4})
f = faults.Fault(exc) f = faults.Fault(exc)
resp = req.get_response(f) resp = req.get_response(f)

View File

@@ -90,7 +90,7 @@ class BaseImageServiceTests(object):
id = self.service.create(fixture) id = self.service.create(fixture)
fixture['status'] = 'in progress' fixture['status'] = 'in progress'
self.service.update(id, fixture) self.service.update(id, fixture)
new_image_data = self.service.show(id) new_image_data = self.service.show(id)
self.assertEquals('in progress', new_image_data['status']) self.assertEquals('in progress', new_image_data['status'])
@@ -121,7 +121,7 @@ class BaseImageServiceTests(object):
num_images = len(self.service.index()) num_images = len(self.service.index())
self.assertEquals(2, num_images, str(self.service.index())) self.assertEquals(2, num_images, str(self.service.index()))
self.service.delete(ids[0]) self.service.delete(ids[0])
num_images = len(self.service.index()) num_images = len(self.service.index())
@@ -135,7 +135,8 @@ class LocalImageServiceTest(unittest.TestCase,
def setUp(self): def setUp(self):
self.stubs = stubout.StubOutForTesting() self.stubs = stubout.StubOutForTesting()
self.service = utils.import_object('nova.image.service.LocalImageService') service_class = 'nova.image.service.LocalImageService'
self.service = utils.import_object(service_class)
def tearDown(self): def tearDown(self):
self.service.delete_all() self.service.delete_all()
@@ -150,7 +151,8 @@ class GlanceImageServiceTest(unittest.TestCase,
def setUp(self): def setUp(self):
self.stubs = stubout.StubOutForTesting() self.stubs = stubout.StubOutForTesting()
fakes.stub_out_glance(self.stubs) fakes.stub_out_glance(self.stubs)
self.service = utils.import_object('nova.image.services.glance.GlanceImageService') service_class = 'nova.image.services.glance.GlanceImageService'
self.service = utils.import_object(service_class)
self.service.delete_all() self.service.delete_all()
def tearDown(self): def tearDown(self):
@@ -172,8 +174,7 @@ class ImageControllerWithGlanceServiceTest(unittest.TestCase):
'deleted': False, 'deleted': False,
'is_public': True, 'is_public': True,
'status': 'available', 'status': 'available',
'image_type': 'kernel' 'image_type': 'kernel'},
},
{'id': 'slkduhfas73kkaskgdas', {'id': 'slkduhfas73kkaskgdas',
'name': 'public image #2', 'name': 'public image #2',
'created_at': str(datetime.datetime.utcnow()), 'created_at': str(datetime.datetime.utcnow()),
@@ -182,9 +183,7 @@ class ImageControllerWithGlanceServiceTest(unittest.TestCase):
'deleted': False, 'deleted': False,
'is_public': True, 'is_public': True,
'status': 'available', 'status': 'available',
'image_type': 'ramdisk' 'image_type': 'ramdisk'}]
},
]
def setUp(self): def setUp(self):
self.orig_image_service = FLAGS.image_service self.orig_image_service = FLAGS.image_service
@@ -211,7 +210,8 @@ class ImageControllerWithGlanceServiceTest(unittest.TestCase):
in self.IMAGE_FIXTURES] in self.IMAGE_FIXTURES]
for image in res_dict['images']: for image in res_dict['images']:
self.assertEquals(1, fixture_index.count(image), "image %s not in fixture index!" % str(image)) self.assertEquals(1, fixture_index.count(image),
"image %s not in fixture index!" % str(image))
def test_get_image_details(self): def test_get_image_details(self):
req = webob.Request.blank('/v1.0/images/detail') req = webob.Request.blank('/v1.0/images/detail')
@@ -219,4 +219,5 @@ class ImageControllerWithGlanceServiceTest(unittest.TestCase):
res_dict = json.loads(res.body) res_dict = json.loads(res.body)
for image in res_dict['images']: for image in res_dict['images']:
self.assertEquals(1, self.IMAGE_FIXTURES.count(image), "image %s not in fixtures!" % str(image)) self.assertEquals(1, self.IMAGE_FIXTURES.count(image),
"image %s not in fixtures!" % str(image))

View File

@@ -6,6 +6,7 @@ import webob
import nova.api.openstack.ratelimiting as ratelimiting import nova.api.openstack.ratelimiting as ratelimiting
class LimiterTest(unittest.TestCase): class LimiterTest(unittest.TestCase):
def setUp(self): def setUp(self):
@@ -66,13 +67,16 @@ class LimiterTest(unittest.TestCase):
class FakeLimiter(object): class FakeLimiter(object):
"""Fake Limiter class that you can tell how to behave.""" """Fake Limiter class that you can tell how to behave."""
def __init__(self, test): def __init__(self, test):
self._action = self._username = self._delay = None self._action = self._username = self._delay = None
self.test = test self.test = test
def mock(self, action, username, delay): def mock(self, action, username, delay):
self._action = action self._action = action
self._username = username self._username = username
self._delay = delay self._delay = delay
def perform(self, action, username): def perform(self, action, username):
self.test.assertEqual(action, self._action) self.test.assertEqual(action, self._action)
self.test.assertEqual(username, self._username) self.test.assertEqual(username, self._username)
@@ -88,7 +92,7 @@ class WSGIAppTest(unittest.TestCase):
def test_invalid_methods(self): def test_invalid_methods(self):
requests = [] requests = []
for method in ['GET', 'PUT', 'DELETE']: for method in ['GET', 'PUT', 'DELETE']:
req = webob.Request.blank('/limits/michael/breakdance', req = webob.Request.blank('/limits/michael/breakdance',
dict(REQUEST_METHOD=method)) dict(REQUEST_METHOD=method))
requests.append(req) requests.append(req)
for req in requests: for req in requests:
@@ -180,7 +184,7 @@ def wire_HTTPConnection_to_WSGI(host, app):
the connection object will be a fake. Its requests will be sent directly the connection object will be a fake. Its requests will be sent directly
to the given WSGI app rather than through a socket. to the given WSGI app rather than through a socket.
Code connecting to hosts other than host will not be affected. Code connecting to hosts other than host will not be affected.
This method may be called multiple times to map different hosts to This method may be called multiple times to map different hosts to
@@ -189,13 +193,16 @@ def wire_HTTPConnection_to_WSGI(host, app):
class HTTPConnectionDecorator(object): class HTTPConnectionDecorator(object):
"""Wraps the real HTTPConnection class so that when you instantiate """Wraps the real HTTPConnection class so that when you instantiate
the class you might instead get a fake instance.""" the class you might instead get a fake instance."""
def __init__(self, wrapped): def __init__(self, wrapped):
self.wrapped = wrapped self.wrapped = wrapped
def __call__(self, connection_host, *args, **kwargs): def __call__(self, connection_host, *args, **kwargs):
if connection_host == host: if connection_host == host:
return FakeHttplibConnection(app, host) return FakeHttplibConnection(app, host)
else: else:
return self.wrapped(connection_host, *args, **kwargs) return self.wrapped(connection_host, *args, **kwargs)
httplib.HTTPConnection = HTTPConnectionDecorator(httplib.HTTPConnection) httplib.HTTPConnection = HTTPConnectionDecorator(httplib.HTTPConnection)

View File

@@ -32,9 +32,9 @@ from nova.tests.api.openstack import fakes
FLAGS = flags.FLAGS FLAGS = flags.FLAGS
FLAGS.verbose = True FLAGS.verbose = True
def return_server(context, id): def return_server(context, id):
return stub_instance(id) return stub_instance(id)
@@ -44,10 +44,8 @@ def return_servers(context, user_id=1):
def stub_instance(id, user_id=1): def stub_instance(id, user_id=1):
return Instance( return Instance(id=id, state=0, image_id=10, server_name='server%s' % id,
id=id, state=0, image_id=10, server_name='server%s'%id, user_id=user_id)
user_id=user_id
)
class ServersTest(unittest.TestCase): class ServersTest(unittest.TestCase):
@@ -61,9 +59,10 @@ class ServersTest(unittest.TestCase):
fakes.stub_out_key_pair_funcs(self.stubs) fakes.stub_out_key_pair_funcs(self.stubs)
fakes.stub_out_image_service(self.stubs) fakes.stub_out_image_service(self.stubs)
self.stubs.Set(nova.db.api, 'instance_get_all', return_servers) self.stubs.Set(nova.db.api, 'instance_get_all', return_servers)
self.stubs.Set(nova.db.api, 'instance_get_by_internal_id', return_server) self.stubs.Set(nova.db.api, 'instance_get_by_internal_id',
self.stubs.Set(nova.db.api, 'instance_get_all_by_user', return_server)
return_servers) self.stubs.Set(nova.db.api, 'instance_get_all_by_user',
return_servers)
def tearDown(self): def tearDown(self):
self.stubs.UnsetAll() self.stubs.UnsetAll()
@@ -79,17 +78,17 @@ class ServersTest(unittest.TestCase):
req = webob.Request.blank('/v1.0/servers') req = webob.Request.blank('/v1.0/servers')
res = req.get_response(nova.api.API()) res = req.get_response(nova.api.API())
res_dict = json.loads(res.body) res_dict = json.loads(res.body)
i = 0 i = 0
for s in res_dict['servers']: for s in res_dict['servers']:
self.assertEqual(s['id'], i) self.assertEqual(s['id'], i)
self.assertEqual(s['name'], 'server%d'%i) self.assertEqual(s['name'], 'server%d' % i)
self.assertEqual(s.get('imageId', None), None) self.assertEqual(s.get('imageId', None), None)
i += 1 i += 1
def test_create_instance(self): def test_create_instance(self):
def server_update(context, id, params): def server_update(context, id, params):
pass pass
def instance_create(context, inst): def instance_create(context, inst):
class Foo(object): class Foo(object):
@@ -98,9 +97,9 @@ class ServersTest(unittest.TestCase):
def fake_method(*args, **kwargs): def fake_method(*args, **kwargs):
pass pass
def project_get_network(context, user_id): def project_get_network(context, user_id):
return dict(id='1', host='localhost') return dict(id='1', host='localhost')
def queue_get_for(context, *args): def queue_get_for(context, *args):
return 'network_topic' return 'network_topic'
@@ -114,11 +113,10 @@ class ServersTest(unittest.TestCase):
self.stubs.Set(nova.db.api, 'queue_get_for', queue_get_for) self.stubs.Set(nova.db.api, 'queue_get_for', queue_get_for)
self.stubs.Set(nova.network.manager.VlanManager, 'allocate_fixed_ip', self.stubs.Set(nova.network.manager.VlanManager, 'allocate_fixed_ip',
fake_method) fake_method)
body = dict(server=dict( body = dict(server=dict(
name='server_test', imageId=2, flavorId=2, metadata={}, name='server_test', imageId=2, flavorId=2, metadata={},
personality = {} personality={}))
))
req = webob.Request.blank('/v1.0/servers') req = webob.Request.blank('/v1.0/servers')
req.method = 'POST' req.method = 'POST'
req.body = json.dumps(body) req.body = json.dumps(body)
@@ -188,44 +186,41 @@ class ServersTest(unittest.TestCase):
req = webob.Request.blank('/v1.0/servers/detail') req = webob.Request.blank('/v1.0/servers/detail')
res = req.get_response(nova.api.API()) res = req.get_response(nova.api.API())
res_dict = json.loads(res.body) res_dict = json.loads(res.body)
i = 0 i = 0
for s in res_dict['servers']: for s in res_dict['servers']:
self.assertEqual(s['id'], i) self.assertEqual(s['id'], i)
self.assertEqual(s['name'], 'server%d'%i) self.assertEqual(s['name'], 'server%d' % i)
self.assertEqual(s['imageId'], 10) self.assertEqual(s['imageId'], 10)
i += 1 i += 1
def test_server_reboot(self): def test_server_reboot(self):
body = dict(server=dict( body = dict(server=dict(
name='server_test', imageId=2, flavorId=2, metadata={}, name='server_test', imageId=2, flavorId=2, metadata={},
personality = {} personality={}))
))
req = webob.Request.blank('/v1.0/servers/1/action') req = webob.Request.blank('/v1.0/servers/1/action')
req.method = 'POST' req.method = 'POST'
req.content_type= 'application/json' req.content_type = 'application/json'
req.body = json.dumps(body) req.body = json.dumps(body)
res = req.get_response(nova.api.API()) res = req.get_response(nova.api.API())
def test_server_rebuild(self): def test_server_rebuild(self):
body = dict(server=dict( body = dict(server=dict(
name='server_test', imageId=2, flavorId=2, metadata={}, name='server_test', imageId=2, flavorId=2, metadata={},
personality = {} personality={}))
))
req = webob.Request.blank('/v1.0/servers/1/action') req = webob.Request.blank('/v1.0/servers/1/action')
req.method = 'POST' req.method = 'POST'
req.content_type= 'application/json' req.content_type = 'application/json'
req.body = json.dumps(body) req.body = json.dumps(body)
res = req.get_response(nova.api.API()) res = req.get_response(nova.api.API())
def test_server_resize(self): def test_server_resize(self):
body = dict(server=dict( body = dict(server=dict(
name='server_test', imageId=2, flavorId=2, metadata={}, name='server_test', imageId=2, flavorId=2, metadata={},
personality = {} personality={}))
))
req = webob.Request.blank('/v1.0/servers/1/action') req = webob.Request.blank('/v1.0/servers/1/action')
req.method = 'POST' req.method = 'POST'
req.content_type= 'application/json' req.content_type = 'application/json'
req.body = json.dumps(body) req.body = json.dumps(body)
res = req.get_response(nova.api.API()) res = req.get_response(nova.api.API())
@@ -234,8 +229,9 @@ class ServersTest(unittest.TestCase):
req.method = 'DELETE' req.method = 'DELETE'
self.server_delete_called = False self.server_delete_called = False
def instance_destroy_mock(context, id): def instance_destroy_mock(context, id):
self.server_delete_called = True self.server_delete_called = True
self.stubs.Set(nova.db.api, 'instance_destroy', self.stubs.Set(nova.db.api, 'instance_destroy',
instance_destroy_mock) instance_destroy_mock)

View File

@@ -72,7 +72,7 @@ class Test(unittest.TestCase):
"""Test controller to call from router.""" """Test controller to call from router."""
test = self test = self
def show(self, req, id): # pylint: disable-msg=W0622,C0103 def show(self, req, id): # pylint: disable-msg=W0622,C0103
"""Default action called for requests with an ID.""" """Default action called for requests with an ID."""
self.test.assertEqual(req.path_info, '/tests/123') self.test.assertEqual(req.path_info, '/tests/123')
self.test.assertEqual(id, '123') self.test.assertEqual(id, '123')
@@ -95,7 +95,7 @@ class Test(unittest.TestCase):
class SerializerTest(unittest.TestCase): class SerializerTest(unittest.TestCase):
def match(self, url, accept, expect): def match(self, url, accept, expect):
input_dict = dict(servers=dict(a=(2,3))) input_dict = dict(servers=dict(a=(2, 3)))
expected_xml = '<servers><a>(2,3)</a></servers>' expected_xml = '<servers><a>(2,3)</a></servers>'
expected_json = '{"servers":{"a":[2,3]}}' expected_json = '{"servers":{"a":[2,3]}}'
req = webob.Request.blank(url, headers=dict(Accept=accept)) req = webob.Request.blank(url, headers=dict(Accept=accept))

View File

@@ -28,16 +28,17 @@ CLC_IP = '127.0.0.1'
CLC_PORT = 8773 CLC_PORT = 8773
REGION = 'test' REGION = 'test'
def get_connection(): def get_connection():
return boto.connect_ec2 ( return boto.connect_ec2(
aws_access_key_id=ACCESS_KEY, aws_access_key_id=ACCESS_KEY,
aws_secret_access_key=SECRET_KEY, aws_secret_access_key=SECRET_KEY,
is_secure=False, is_secure=False,
region=RegionInfo(None, REGION, CLC_IP), region=RegionInfo(None, REGION, CLC_IP),
port=CLC_PORT, port=CLC_PORT,
path='/services/Cloud', path='/services/Cloud',
debug=99 debug=99)
)
class APIIntegrationTests(unittest.TestCase): class APIIntegrationTests(unittest.TestCase):
def test_001_get_all_images(self): def test_001_get_all_images(self):
@@ -51,4 +52,3 @@ if __name__ == '__main__':
#print conn.get_all_key_pairs() #print conn.get_all_key_pairs()
#print conn.create_key_pair #print conn.create_key_pair
#print conn.create_security_group('name', 'description') #print conn.create_security_group('name', 'description')

View File

@@ -99,6 +99,7 @@ class XmlConversionTestCase(test.BaseTestCase):
self.assertEqual(conv('-'), '-') self.assertEqual(conv('-'), '-')
self.assertEqual(conv('-0'), 0) self.assertEqual(conv('-0'), 0)
class ApiEc2TestCase(test.BaseTestCase): class ApiEc2TestCase(test.BaseTestCase):
"""Unit test for the cloud controller on an EC2 API""" """Unit test for the cloud controller on an EC2 API"""
def setUp(self): def setUp(self):
@@ -138,7 +139,6 @@ class ApiEc2TestCase(test.BaseTestCase):
self.manager.delete_project(project) self.manager.delete_project(project)
self.manager.delete_user(user) self.manager.delete_user(user)
def test_get_all_key_pairs(self): def test_get_all_key_pairs(self):
"""Test that, after creating a user and project and generating """Test that, after creating a user and project and generating
a key pair, that the API call to list key pairs works properly""" a key pair, that the API call to list key pairs works properly"""
@@ -183,7 +183,7 @@ class ApiEc2TestCase(test.BaseTestCase):
self.manager.add_role('fake', 'netadmin') self.manager.add_role('fake', 'netadmin')
project.add_role('fake', 'netadmin') project.add_role('fake', 'netadmin')
security_group_name = "".join(random.choice("sdiuisudfsdcnpaqwertasd") \ security_group_name = "".join(random.choice("sdiuisudfsdcnpaqwertasd")
for x in range(random.randint(4, 8))) for x in range(random.randint(4, 8)))
self.ec2.create_security_group(security_group_name, 'test group') self.ec2.create_security_group(security_group_name, 'test group')
@@ -217,10 +217,11 @@ class ApiEc2TestCase(test.BaseTestCase):
self.manager.add_role('fake', 'netadmin') self.manager.add_role('fake', 'netadmin')
project.add_role('fake', 'netadmin') project.add_role('fake', 'netadmin')
security_group_name = "".join(random.choice("sdiuisudfsdcnpaqwertasd") \ security_group_name = "".join(random.choice("sdiuisudfsdcnpaqwertasd")
for x in range(random.randint(4, 8))) for x in range(random.randint(4, 8)))
group = self.ec2.create_security_group(security_group_name, 'test group') group = self.ec2.create_security_group(security_group_name,
'test group')
self.expect_http() self.expect_http()
self.mox.ReplayAll() self.mox.ReplayAll()
@@ -282,12 +283,14 @@ class ApiEc2TestCase(test.BaseTestCase):
self.manager.add_role('fake', 'netadmin') self.manager.add_role('fake', 'netadmin')
project.add_role('fake', 'netadmin') project.add_role('fake', 'netadmin')
security_group_name = "".join(random.choice("sdiuisudfsdcnpaqwertasd") \ rand_string = 'sdiuisudfsdcnpaqwertasd'
security_group_name = "".join(random.choice(rand_string)
for x in range(random.randint(4, 8)))
other_security_group_name = "".join(random.choice(rand_string)
for x in range(random.randint(4, 8))) for x in range(random.randint(4, 8)))
other_security_group_name = "".join(random.choice("sdiuisudfsdcnpaqwertasd") \
for x in range(random.randint(4, 8)))
group = self.ec2.create_security_group(security_group_name, 'test group') group = self.ec2.create_security_group(security_group_name,
'test group')
self.expect_http() self.expect_http()
self.mox.ReplayAll() self.mox.ReplayAll()
@@ -313,9 +316,8 @@ class ApiEc2TestCase(test.BaseTestCase):
if group.name == security_group_name: if group.name == security_group_name:
self.assertEquals(len(group.rules), 1) self.assertEquals(len(group.rules), 1)
self.assertEquals(len(group.rules[0].grants), 1) self.assertEquals(len(group.rules[0].grants), 1)
self.assertEquals(str(group.rules[0].grants[0]), self.assertEquals(str(group.rules[0].grants[0]), '%s-%s' %
'%s-%s' % (other_security_group_name, 'fake')) (other_security_group_name, 'fake'))
self.expect_http() self.expect_http()
self.mox.ReplayAll() self.mox.ReplayAll()

View File

@@ -28,6 +28,7 @@ from nova.api.ec2 import cloud
FLAGS = flags.FLAGS FLAGS = flags.FLAGS
class user_generator(object): class user_generator(object):
def __init__(self, manager, **user_state): def __init__(self, manager, **user_state):
if 'name' not in user_state: if 'name' not in user_state:
@@ -41,6 +42,7 @@ class user_generator(object):
def __exit__(self, value, type, trace): def __exit__(self, value, type, trace):
self.manager.delete_user(self.user) self.manager.delete_user(self.user)
class project_generator(object): class project_generator(object):
def __init__(self, manager, **project_state): def __init__(self, manager, **project_state):
if 'name' not in project_state: if 'name' not in project_state:
@@ -56,6 +58,7 @@ class project_generator(object):
def __exit__(self, value, type, trace): def __exit__(self, value, type, trace):
self.manager.delete_project(self.project) self.manager.delete_project(self.project)
class user_and_project_generator(object): class user_and_project_generator(object):
def __init__(self, manager, user_state={}, project_state={}): def __init__(self, manager, user_state={}, project_state={}):
self.manager = manager self.manager = manager
@@ -75,6 +78,7 @@ class user_and_project_generator(object):
self.manager.delete_user(self.user) self.manager.delete_user(self.user)
self.manager.delete_project(self.project) self.manager.delete_project(self.project)
class AuthManagerTestCase(object): class AuthManagerTestCase(object):
def setUp(self): def setUp(self):
FLAGS.auth_driver = self.auth_driver FLAGS.auth_driver = self.auth_driver
@@ -96,7 +100,7 @@ class AuthManagerTestCase(object):
self.assertEqual('private-party', u.access) self.assertEqual('private-party', u.access)
def test_004_signature_is_valid(self): def test_004_signature_is_valid(self):
#self.assertTrue(self.manager.authenticate( **boto.generate_url ... ? ? ? )) #self.assertTrue(self.manager.authenticate(**boto.generate_url ...? ))
pass pass
#raise NotImplementedError #raise NotImplementedError
@@ -127,7 +131,7 @@ class AuthManagerTestCase(object):
self.assertFalse(self.manager.has_role('test1', 'itsec')) self.assertFalse(self.manager.has_role('test1', 'itsec'))
def test_can_create_and_get_project(self): def test_can_create_and_get_project(self):
with user_and_project_generator(self.manager) as (u,p): with user_and_project_generator(self.manager) as (u, p):
self.assert_(self.manager.get_user('test1')) self.assert_(self.manager.get_user('test1'))
self.assert_(self.manager.get_user('test1')) self.assert_(self.manager.get_user('test1'))
self.assert_(self.manager.get_project('testproj')) self.assert_(self.manager.get_project('testproj'))
@@ -321,6 +325,7 @@ class AuthManagerTestCase(object):
self.assertEqual('secret', user.secret) self.assertEqual('secret', user.secret)
self.assertTrue(user.is_admin()) self.assertTrue(user.is_admin())
class AuthManagerLdapTestCase(AuthManagerTestCase, test.TrialTestCase): class AuthManagerLdapTestCase(AuthManagerTestCase, test.TrialTestCase):
auth_driver = 'nova.auth.ldapdriver.FakeLdapDriver' auth_driver = 'nova.auth.ldapdriver.FakeLdapDriver'
@@ -337,6 +342,7 @@ class AuthManagerLdapTestCase(AuthManagerTestCase, test.TrialTestCase):
except: except:
self.skip = True self.skip = True
class AuthManagerDbTestCase(AuthManagerTestCase, test.TrialTestCase): class AuthManagerDbTestCase(AuthManagerTestCase, test.TrialTestCase):
auth_driver = 'nova.auth.dbdriver.DbDriver' auth_driver = 'nova.auth.dbdriver.DbDriver'

View File

@@ -46,13 +46,13 @@ from nova.objectstore import image
FLAGS = flags.FLAGS FLAGS = flags.FLAGS
# Temp dirs for working with image attributes through the cloud controller # Temp dirs for working with image attributes through the cloud controller
# (stole this from objectstore_unittest.py) # (stole this from objectstore_unittest.py)
OSS_TEMPDIR = tempfile.mkdtemp(prefix='test_oss-') OSS_TEMPDIR = tempfile.mkdtemp(prefix='test_oss-')
IMAGES_PATH = os.path.join(OSS_TEMPDIR, 'images') IMAGES_PATH = os.path.join(OSS_TEMPDIR, 'images')
os.makedirs(IMAGES_PATH) os.makedirs(IMAGES_PATH)
class CloudTestCase(test.TrialTestCase): class CloudTestCase(test.TrialTestCase):
def setUp(self): def setUp(self):
super(CloudTestCase, self).setUp() super(CloudTestCase, self).setUp()
@@ -97,17 +97,17 @@ class CloudTestCase(test.TrialTestCase):
max_count = 1 max_count = 1
kwargs = {'image_id': image_id, kwargs = {'image_id': image_id,
'instance_type': instance_type, 'instance_type': instance_type,
'max_count': max_count } 'max_count': max_count}
rv = yield self.cloud.run_instances(self.context, **kwargs) rv = yield self.cloud.run_instances(self.context, **kwargs)
instance_id = rv['instancesSet'][0]['instanceId'] instance_id = rv['instancesSet'][0]['instanceId']
output = yield self.cloud.get_console_output(context=self.context, instance_id=[instance_id]) output = yield self.cloud.get_console_output(context=self.context,
instance_id=[instance_id])
self.assertEquals(b64decode(output['output']), 'FAKE CONSOLE OUTPUT') self.assertEquals(b64decode(output['output']), 'FAKE CONSOLE OUTPUT')
# TODO(soren): We need this until we can stop polling in the rpc code # TODO(soren): We need this until we can stop polling in the rpc code
# for unit tests. # for unit tests.
greenthread.sleep(0.3) greenthread.sleep(0.3)
rv = yield self.cloud.terminate_instances(self.context, [instance_id]) rv = yield self.cloud.terminate_instances(self.context, [instance_id])
def test_key_generation(self): def test_key_generation(self):
result = self._create_key('test') result = self._create_key('test')
private_key = result['private_key'] private_key = result['private_key']
@@ -146,8 +146,10 @@ class CloudTestCase(test.TrialTestCase):
'max_count': max_count} 'max_count': max_count}
rv = yield self.cloud.run_instances(self.context, **kwargs) rv = yield self.cloud.run_instances(self.context, **kwargs)
# TODO: check for proper response # TODO: check for proper response
instance = rv['reservationSet'][0][rv['reservationSet'][0].keys()[0]][0] instance_id = rv['reservationSet'][0].keys()[0]
logging.debug("Need to watch instance %s until it's running..." % instance['instance_id']) instance = rv['reservationSet'][0][instance_id][0]
logging.debug("Need to watch instance %s until it's running..." %
instance['instance_id'])
while True: while True:
rv = yield defer.succeed(time.sleep(1)) rv = yield defer.succeed(time.sleep(1))
info = self.cloud._get_instance(instance['instance_id']) info = self.cloud._get_instance(instance['instance_id'])
@@ -157,14 +159,15 @@ class CloudTestCase(test.TrialTestCase):
self.assert_(rv) self.assert_(rv)
if connection_type != 'fake': if connection_type != 'fake':
time.sleep(45) # Should use boto for polling here time.sleep(45) # Should use boto for polling here
for reservations in rv['reservationSet']: for reservations in rv['reservationSet']:
# for res_id in reservations.keys(): # for res_id in reservations.keys():
# logging.debug(reservations[res_id]) # logging.debug(reservations[res_id])
# for instance in reservations[res_id]: # for instance in reservations[res_id]:
for instance in reservations[reservations.keys()[0]]: for instance in reservations[reservations.keys()[0]]:
logging.debug("Terminating instance %s" % instance['instance_id']) instance_id = instance['instance_id']
rv = yield self.compute.terminate_instance(instance['instance_id']) logging.debug("Terminating instance %s" % instance_id)
rv = yield self.compute.terminate_instance(instance_id)
def test_instance_update_state(self): def test_instance_update_state(self):
def instance(num): def instance(num):
@@ -183,8 +186,7 @@ class CloudTestCase(test.TrialTestCase):
'groups': ['default'], 'groups': ['default'],
'product_codes': None, 'product_codes': None,
'state': 0x01, 'state': 0x01,
'user_data': '' 'user_data': ''}
}
rv = self.cloud._format_describe_instances(self.context) rv = self.cloud._format_describe_instances(self.context)
self.assert_(len(rv['reservationSet']) == 0) self.assert_(len(rv['reservationSet']) == 0)
@@ -199,7 +201,9 @@ class CloudTestCase(test.TrialTestCase):
#self.assert_(len(rv['reservationSet'][0]['instances_set']) == 5) #self.assert_(len(rv['reservationSet'][0]['instances_set']) == 5)
# report 4 nodes each having 1 of the instances # report 4 nodes each having 1 of the instances
#for i in xrange(4): #for i in xrange(4):
# self.cloud.update_state('instances', {('node-%s' % i): {('i-%s' % i): instance(i)}}) # self.cloud.update_state('instances',
# {('node-%s' % i): {('i-%s' % i):
# instance(i)}})
# one instance should be pending still # one instance should be pending still
#self.assert_(len(self.cloud.instances['pending'].keys()) == 1) #self.assert_(len(self.cloud.instances['pending'].keys()) == 1)
@@ -217,8 +221,10 @@ class CloudTestCase(test.TrialTestCase):
@staticmethod @staticmethod
def _fake_set_image_description(ctxt, image_id, description): def _fake_set_image_description(ctxt, image_id, description):
from nova.objectstore import handler from nova.objectstore import handler
class req: class req:
pass pass
request = req() request = req()
request.context = ctxt request.context = ctxt
request.args = {'image_id': [image_id], request.args = {'image_id': [image_id],

View File

@@ -23,7 +23,9 @@ from nova import test
FLAGS = flags.FLAGS FLAGS = flags.FLAGS
flags.DEFINE_string('flags_unittest', 'foo', 'for testing purposes only') flags.DEFINE_string('flags_unittest', 'foo', 'for testing purposes only')
class FlagsTestCase(test.TrialTestCase): class FlagsTestCase(test.TrialTestCase):
def setUp(self): def setUp(self):
super(FlagsTestCase, self).setUp() super(FlagsTestCase, self).setUp()
self.FLAGS = flags.FlagValues() self.FLAGS = flags.FlagValues()
@@ -35,7 +37,8 @@ class FlagsTestCase(test.TrialTestCase):
self.assert_('false' not in self.FLAGS) self.assert_('false' not in self.FLAGS)
self.assert_('true' not in self.FLAGS) self.assert_('true' not in self.FLAGS)
flags.DEFINE_string('string', 'default', 'desc', flag_values=self.FLAGS) flags.DEFINE_string('string', 'default', 'desc',
flag_values=self.FLAGS)
flags.DEFINE_integer('int', 1, 'desc', flag_values=self.FLAGS) flags.DEFINE_integer('int', 1, 'desc', flag_values=self.FLAGS)
flags.DEFINE_bool('false', False, 'desc', flag_values=self.FLAGS) flags.DEFINE_bool('false', False, 'desc', flag_values=self.FLAGS)
flags.DEFINE_bool('true', True, 'desc', flag_values=self.FLAGS) flags.DEFINE_bool('true', True, 'desc', flag_values=self.FLAGS)

View File

@@ -98,7 +98,6 @@ class NetworkTestCase(test.TrialTestCase):
self.context.project_id = self.projects[project_num].id self.context.project_id = self.projects[project_num].id
self.network.deallocate_fixed_ip(self.context, address) self.network.deallocate_fixed_ip(self.context, address)
def test_public_network_association(self): def test_public_network_association(self):
"""Makes sure that we can allocaate a public ip""" """Makes sure that we can allocaate a public ip"""
# TODO(vish): better way of adding floating ips # TODO(vish): better way of adding floating ips
@@ -118,10 +117,12 @@ class NetworkTestCase(test.TrialTestCase):
lease_ip(fix_addr) lease_ip(fix_addr)
self.assertEqual(float_addr, str(pubnet[0])) self.assertEqual(float_addr, str(pubnet[0]))
self.network.associate_floating_ip(self.context, float_addr, fix_addr) self.network.associate_floating_ip(self.context, float_addr, fix_addr)
address = db.instance_get_floating_address(context.get_admin_context(), self.instance_id) address = db.instance_get_floating_address(context.get_admin_context(),
self.instance_id)
self.assertEqual(address, float_addr) self.assertEqual(address, float_addr)
self.network.disassociate_floating_ip(self.context, float_addr) self.network.disassociate_floating_ip(self.context, float_addr)
address = db.instance_get_floating_address(context.get_admin_context(), self.instance_id) address = db.instance_get_floating_address(context.get_admin_context(),
self.instance_id)
self.assertEqual(address, None) self.assertEqual(address, None)
self.network.deallocate_floating_ip(self.context, float_addr) self.network.deallocate_floating_ip(self.context, float_addr)
self.network.deallocate_fixed_ip(self.context, fix_addr) self.network.deallocate_fixed_ip(self.context, fix_addr)
@@ -254,18 +255,24 @@ class NetworkTestCase(test.TrialTestCase):
There are ips reserved at the bottom and top of the range. There are ips reserved at the bottom and top of the range.
services (network, gateway, CloudPipe, broadcast) services (network, gateway, CloudPipe, broadcast)
""" """
network = db.project_get_network(context.get_admin_context(), self.projects[0].id) network = db.project_get_network(context.get_admin_context(),
self.projects[0].id)
net_size = flags.FLAGS.network_size net_size = flags.FLAGS.network_size
total_ips = (db.network_count_available_ips(context.get_admin_context(), network['id']) + admin_context = context.get_admin_context()
db.network_count_reserved_ips(context.get_admin_context(), network['id']) + total_ips = (db.network_count_available_ips(admin_context,
db.network_count_allocated_ips(context.get_admin_context(), network['id'])) network['id']) +
db.network_count_reserved_ips(admin_context,
network['id']) +
db.network_count_allocated_ips(admin_context,
network['id']))
self.assertEqual(total_ips, net_size) self.assertEqual(total_ips, net_size)
def test_too_many_addresses(self): def test_too_many_addresses(self):
"""Test for a NoMoreAddresses exception when all fixed ips are used. """Test for a NoMoreAddresses exception when all fixed ips are used.
""" """
network = db.project_get_network(context.get_admin_context(), self.projects[0].id) admin_context = context.get_admin_context()
num_available_ips = db.network_count_available_ips(context.get_admin_context(), network = db.project_get_network(admin_context, self.projects[0].id)
num_available_ips = db.network_count_available_ips(admin_context,
network['id']) network['id'])
addresses = [] addresses = []
instance_ids = [] instance_ids = []
@@ -276,8 +283,9 @@ class NetworkTestCase(test.TrialTestCase):
addresses.append(address) addresses.append(address)
lease_ip(address) lease_ip(address)
self.assertEqual(db.network_count_available_ips(context.get_admin_context(), ip_count = db.network_count_available_ips(context.get_admin_context(),
network['id']), 0) network['id'])
self.assertEqual(ip_count, 0)
self.assertRaises(db.NoMoreAddresses, self.assertRaises(db.NoMoreAddresses,
self.network.allocate_fixed_ip, self.network.allocate_fixed_ip,
self.context, self.context,
@@ -287,14 +295,15 @@ class NetworkTestCase(test.TrialTestCase):
self.network.deallocate_fixed_ip(self.context, addresses[i]) self.network.deallocate_fixed_ip(self.context, addresses[i])
release_ip(addresses[i]) release_ip(addresses[i])
db.instance_destroy(context.get_admin_context(), instance_ids[i]) db.instance_destroy(context.get_admin_context(), instance_ids[i])
self.assertEqual(db.network_count_available_ips(context.get_admin_context(), ip_count = db.network_count_available_ips(context.get_admin_context(),
network['id']), network['id'])
num_available_ips) self.assertEqual(ip_count, num_available_ips)
def is_allocated_in_project(address, project_id): def is_allocated_in_project(address, project_id):
"""Returns true if address is in specified project""" """Returns true if address is in specified project"""
project_net = db.project_get_network(context.get_admin_context(), project_id) project_net = db.project_get_network(context.get_admin_context(),
project_id)
network = db.fixed_ip_get_network(context.get_admin_context(), address) network = db.fixed_ip_get_network(context.get_admin_context(), address)
instance = db.fixed_ip_get_instance(context.get_admin_context(), address) instance = db.fixed_ip_get_instance(context.get_admin_context(), address)
# instance exists until release # instance exists until release
@@ -308,8 +317,10 @@ def binpath(script):
def lease_ip(private_ip): def lease_ip(private_ip):
"""Run add command on dhcpbridge""" """Run add command on dhcpbridge"""
network_ref = db.fixed_ip_get_network(context.get_admin_context(), private_ip) network_ref = db.fixed_ip_get_network(context.get_admin_context(),
instance_ref = db.fixed_ip_get_instance(context.get_admin_context(), private_ip) private_ip)
instance_ref = db.fixed_ip_get_instance(context.get_admin_context(),
private_ip)
cmd = "%s add %s %s fake" % (binpath('nova-dhcpbridge'), cmd = "%s add %s %s fake" % (binpath('nova-dhcpbridge'),
instance_ref['mac_address'], instance_ref['mac_address'],
private_ip) private_ip)
@@ -322,8 +333,10 @@ def lease_ip(private_ip):
def release_ip(private_ip): def release_ip(private_ip):
"""Run del command on dhcpbridge""" """Run del command on dhcpbridge"""
network_ref = db.fixed_ip_get_network(context.get_admin_context(), private_ip) network_ref = db.fixed_ip_get_network(context.get_admin_context(),
instance_ref = db.fixed_ip_get_instance(context.get_admin_context(), private_ip) private_ip)
instance_ref = db.fixed_ip_get_instance(context.get_admin_context(),
private_ip)
cmd = "%s del %s %s fake" % (binpath('nova-dhcpbridge'), cmd = "%s del %s %s fake" % (binpath('nova-dhcpbridge'),
instance_ref['mac_address'], instance_ref['mac_address'],
private_ip) private_ip)

View File

@@ -181,7 +181,7 @@ class ObjectStoreTestCase(test.TrialTestCase):
class TestHTTPChannel(http.HTTPChannel): class TestHTTPChannel(http.HTTPChannel):
"""Dummy site required for twisted.web""" """Dummy site required for twisted.web"""
def checkPersistence(self, _, __): # pylint: disable-msg=C0103 def checkPersistence(self, _, __): # pylint: disable-msg=C0103
"""Otherwise we end up with an unclean reactor.""" """Otherwise we end up with an unclean reactor."""
return False return False
@@ -217,7 +217,6 @@ class S3APITestCase(test.TrialTestCase):
# pylint: enable-msg=E1101 # pylint: enable-msg=E1101
self.tcp_port = self.listening_port.getHost().port self.tcp_port = self.listening_port.getHost().port
if not boto.config.has_section('Boto'): if not boto.config.has_section('Boto'):
boto.config.add_section('Boto') boto.config.add_section('Boto')
boto.config.set('Boto', 'num_retries', '0') boto.config.set('Boto', 'num_retries', '0')
@@ -234,11 +233,11 @@ class S3APITestCase(test.TrialTestCase):
self.conn.get_http_connection = get_http_connection self.conn.get_http_connection = get_http_connection
def _ensure_no_buckets(self, buckets): # pylint: disable-msg=C0111 def _ensure_no_buckets(self, buckets): # pylint: disable-msg=C0111
self.assertEquals(len(buckets), 0, "Bucket list was not empty") self.assertEquals(len(buckets), 0, "Bucket list was not empty")
return True return True
def _ensure_one_bucket(self, buckets, name): # pylint: disable-msg=C0111 def _ensure_one_bucket(self, buckets, name): # pylint: disable-msg=C0111
self.assertEquals(len(buckets), 1, self.assertEquals(len(buckets), 1,
"Bucket list didn't have exactly one element in it") "Bucket list didn't have exactly one element in it")
self.assertEquals(buckets[0].name, name, "Wrong name") self.assertEquals(buckets[0].name, name, "Wrong name")

View File

@@ -38,6 +38,7 @@ class ProcessTestCase(test.TrialTestCase):
def test_execute_stdout(self): def test_execute_stdout(self):
pool = process.ProcessPool(2) pool = process.ProcessPool(2)
d = pool.simple_execute('echo test') d = pool.simple_execute('echo test')
def _check(rv): def _check(rv):
self.assertEqual(rv[0], 'test\n') self.assertEqual(rv[0], 'test\n')
self.assertEqual(rv[1], '') self.assertEqual(rv[1], '')
@@ -49,6 +50,7 @@ class ProcessTestCase(test.TrialTestCase):
def test_execute_stderr(self): def test_execute_stderr(self):
pool = process.ProcessPool(2) pool = process.ProcessPool(2)
d = pool.simple_execute('cat BAD_FILE', check_exit_code=False) d = pool.simple_execute('cat BAD_FILE', check_exit_code=False)
def _check(rv): def _check(rv):
self.assertEqual(rv[0], '') self.assertEqual(rv[0], '')
self.assert_('No such file' in rv[1]) self.assert_('No such file' in rv[1])
@@ -72,6 +74,7 @@ class ProcessTestCase(test.TrialTestCase):
d4 = pool.simple_execute('sleep 0.005') d4 = pool.simple_execute('sleep 0.005')
called = [] called = []
def _called(rv, name): def _called(rv, name):
called.append(name) called.append(name)

View File

@@ -141,12 +141,13 @@ class QuotaTestCase(test.TrialTestCase):
try: try:
db.floating_ip_get_by_address(context.get_admin_context(), address) db.floating_ip_get_by_address(context.get_admin_context(), address)
except exception.NotFound: except exception.NotFound:
db.floating_ip_create(context.get_admin_context(), {'address': address, db.floating_ip_create(context.get_admin_context(),
'host': FLAGS.host}) {'address': address, 'host': FLAGS.host})
float_addr = self.network.allocate_floating_ip(self.context, float_addr = self.network.allocate_floating_ip(self.context,
self.project.id) self.project.id)
# NOTE(vish): This assert never fails. When cloud attempts to # NOTE(vish): This assert never fails. When cloud attempts to
# make an rpc.call, the test just finishes with OK. It # make an rpc.call, the test just finishes with OK. It
# appears to be something in the magic inline callbacks # appears to be something in the magic inline callbacks
# that is breaking. # that is breaking.
self.assertRaises(cloud.QuotaError, self.cloud.allocate_address, self.context) self.assertRaises(cloud.QuotaError, self.cloud.allocate_address,
self.context)

View File

@@ -41,7 +41,7 @@ class RpcTestCase(test.TrialTestCase):
topic='test', topic='test',
proxy=self.receiver) proxy=self.receiver)
self.consumer.attach_to_twisted() self.consumer.attach_to_twisted()
self.context= context.get_admin_context() self.context = context.get_admin_context()
def test_call_succeed(self): def test_call_succeed(self):
"""Get a value through rpc call""" """Get a value through rpc call"""
@@ -67,9 +67,9 @@ class RpcTestCase(test.TrialTestCase):
to an int in the test. to an int in the test.
""" """
value = 42 value = 42
self.assertFailure(rpc.call_twisted(self.context, self.assertFailure(rpc.call_twisted(self.context, 'test',
'test', {"method": "fail", {"method": "fail",
"args": {"value": value}}), "args": {"value": value}}),
rpc.RemoteError) rpc.RemoteError)
try: try:
yield rpc.call_twisted(self.context, yield rpc.call_twisted(self.context,
@@ -101,4 +101,3 @@ class TestReceiver(object):
def fail(context, value): def fail(context, value):
"""Raises an exception with the value sent in""" """Raises an exception with the value sent in"""
raise Exception(value) raise Exception(value)

View File

@@ -34,6 +34,7 @@ from nova.scheduler import driver
FLAGS = flags.FLAGS FLAGS = flags.FLAGS
flags.DECLARE('max_cores', 'nova.scheduler.simple') flags.DECLARE('max_cores', 'nova.scheduler.simple')
class TestDriver(driver.Scheduler): class TestDriver(driver.Scheduler):
"""Scheduler Driver for Tests""" """Scheduler Driver for Tests"""
def schedule(context, topic, *args, **kwargs): def schedule(context, topic, *args, **kwargs):
@@ -42,6 +43,7 @@ class TestDriver(driver.Scheduler):
def schedule_named_method(context, topic, num): def schedule_named_method(context, topic, num):
return 'named_host' return 'named_host'
class SchedulerTestCase(test.TrialTestCase): class SchedulerTestCase(test.TrialTestCase):
"""Test case for scheduler""" """Test case for scheduler"""
def setUp(self): def setUp(self):

View File

@@ -179,7 +179,8 @@ class ServiceTestCase(test.BaseTestCase):
binary).AndRaise(exception.NotFound()) binary).AndRaise(exception.NotFound())
service.db.service_create(self.context, service.db.service_create(self.context,
service_create).AndReturn(service_ref) service_create).AndReturn(service_ref)
service.db.service_get(self.context, service_ref['id']).AndReturn(service_ref) service.db.service_get(self.context,
service_ref['id']).AndReturn(service_ref)
service.db.service_update(self.context, service_ref['id'], service.db.service_update(self.context, service_ref['id'],
mox.ContainsKeyValue('report_count', 1)) mox.ContainsKeyValue('report_count', 1))
@@ -227,4 +228,3 @@ class ServiceTestCase(test.BaseTestCase):
rv = yield s.report_state(host, binary) rv = yield s.report_state(host, binary)
self.assert_(not s.model_disconnected) self.assert_(not s.model_disconnected)

View File

@@ -35,7 +35,8 @@ class ValidationTestCase(test.TrialTestCase):
self.assertTrue(type_case("foo", 5, 1)) self.assertTrue(type_case("foo", 5, 1))
self.assertRaises(TypeError, type_case, "bar", "5", 1) self.assertRaises(TypeError, type_case, "bar", "5", 1)
self.assertRaises(TypeError, type_case, None, 5, 1) self.assertRaises(TypeError, type_case, None, 5, 1)
@validate.typetest(instanceid=str, size=int, number_of_instances=int) @validate.typetest(instanceid=str, size=int, number_of_instances=int)
def type_case(instanceid, size, number_of_instances): def type_case(instanceid, size, number_of_instances):
return True return True

View File

@@ -29,11 +29,13 @@ from nova.virt import libvirt_conn
FLAGS = flags.FLAGS FLAGS = flags.FLAGS
flags.DECLARE('instances_path', 'nova.compute.manager') flags.DECLARE('instances_path', 'nova.compute.manager')
class LibvirtConnTestCase(test.TrialTestCase): class LibvirtConnTestCase(test.TrialTestCase):
def setUp(self): def setUp(self):
super(LibvirtConnTestCase, self).setUp() super(LibvirtConnTestCase, self).setUp()
self.manager = manager.AuthManager() self.manager = manager.AuthManager()
self.user = self.manager.create_user('fake', 'fake', 'fake', admin=True) self.user = self.manager.create_user('fake', 'fake', 'fake',
admin=True)
self.project = self.manager.create_project('fake', 'fake', 'fake') self.project = self.manager.create_project('fake', 'fake', 'fake')
self.network = utils.import_object(FLAGS.network_manager) self.network = utils.import_object(FLAGS.network_manager)
FLAGS.instances_path = '' FLAGS.instances_path = ''
@@ -41,15 +43,15 @@ class LibvirtConnTestCase(test.TrialTestCase):
def test_get_uri_and_template(self): def test_get_uri_and_template(self):
ip = '10.11.12.13' ip = '10.11.12.13'
instance = { 'internal_id' : 1, instance = {'internal_id': 1,
'memory_kb' : '1024000', 'memory_kb': '1024000',
'basepath' : '/some/path', 'basepath': '/some/path',
'bridge_name' : 'br100', 'bridge_name': 'br100',
'mac_address' : '02:12:34:46:56:67', 'mac_address': '02:12:34:46:56:67',
'vcpus' : 2, 'vcpus': 2,
'project_id' : 'fake', 'project_id': 'fake',
'bridge' : 'br101', 'bridge': 'br101',
'instance_type' : 'm1.small'} 'instance_type': 'm1.small'}
user_context = context.RequestContext(project=self.project, user_context = context.RequestContext(project=self.project,
user=self.user) user=self.user)
@@ -58,36 +60,34 @@ class LibvirtConnTestCase(test.TrialTestCase):
self.network.set_network_host(context.get_admin_context(), self.network.set_network_host(context.get_admin_context(),
network_ref['id']) network_ref['id'])
fixed_ip = { 'address' : ip, fixed_ip = {'address': ip,
'network_id' : network_ref['id'] } 'network_id': network_ref['id']}
ctxt = context.get_admin_context() ctxt = context.get_admin_context()
fixed_ip_ref = db.fixed_ip_create(ctxt, fixed_ip) fixed_ip_ref = db.fixed_ip_create(ctxt, fixed_ip)
db.fixed_ip_update(ctxt, ip, {'allocated': True, db.fixed_ip_update(ctxt, ip, {'allocated': True,
'instance_id': instance_ref['id'] }) 'instance_id': instance_ref['id']})
type_uri_map = { 'qemu' : ('qemu:///system', type_uri_map = {'qemu': ('qemu:///system',
[(lambda t: t.find('.').get('type'), 'qemu'), [(lambda t: t.find('.').get('type'), 'qemu'),
(lambda t: t.find('./os/type').text, 'hvm'), (lambda t: t.find('./os/type').text, 'hvm'),
(lambda t: t.find('./devices/emulator'), None)]), (lambda t: t.find('./devices/emulator'), None)]),
'kvm' : ('qemu:///system', 'kvm': ('qemu:///system',
[(lambda t: t.find('.').get('type'), 'kvm'), [(lambda t: t.find('.').get('type'), 'kvm'),
(lambda t: t.find('./os/type').text, 'hvm'), (lambda t: t.find('./os/type').text, 'hvm'),
(lambda t: t.find('./devices/emulator'), None)]), (lambda t: t.find('./devices/emulator'), None)]),
'uml' : ('uml:///system', 'uml': ('uml:///system',
[(lambda t: t.find('.').get('type'), 'uml'), [(lambda t: t.find('.').get('type'), 'uml'),
(lambda t: t.find('./os/type').text, 'uml')]), (lambda t: t.find('./os/type').text, 'uml')])}
}
common_checks = [(lambda t: t.find('.').tag, 'domain'), common_checks = [
(lambda t: \ (lambda t: t.find('.').tag, 'domain'),
t.find('./devices/interface/filterref/parameter') \ (lambda t: t.find('./devices/interface/filterref/parameter').\
.get('name'), 'IP'), get('name'), 'IP'),
(lambda t: \ (lambda t: t.find('./devices/interface/filterref/parameter').\
t.find('./devices/interface/filterref/parameter') \ get('value'), '10.11.12.13')]
.get('value'), '10.11.12.13')]
for (libvirt_type,(expected_uri, checks)) in type_uri_map.iteritems(): for (libvirt_type, (expected_uri, checks)) in type_uri_map.iteritems():
FLAGS.libvirt_type = libvirt_type FLAGS.libvirt_type = libvirt_type
conn = libvirt_conn.LibvirtConnection(True) conn = libvirt_conn.LibvirtConnection(True)
@@ -111,19 +111,20 @@ class LibvirtConnTestCase(test.TrialTestCase):
# implementation doesn't fiddle around with the FLAGS. # implementation doesn't fiddle around with the FLAGS.
testuri = 'something completely different' testuri = 'something completely different'
FLAGS.libvirt_uri = testuri FLAGS.libvirt_uri = testuri
for (libvirt_type,(expected_uri, checks)) in type_uri_map.iteritems(): for (libvirt_type, (expected_uri, checks)) in type_uri_map.iteritems():
FLAGS.libvirt_type = libvirt_type FLAGS.libvirt_type = libvirt_type
conn = libvirt_conn.LibvirtConnection(True) conn = libvirt_conn.LibvirtConnection(True)
uri, template = conn.get_uri_and_template() uri, template = conn.get_uri_and_template()
self.assertEquals(uri, testuri) self.assertEquals(uri, testuri)
def tearDown(self): def tearDown(self):
super(LibvirtConnTestCase, self).tearDown() super(LibvirtConnTestCase, self).tearDown()
self.manager.delete_project(self.project) self.manager.delete_project(self.project)
self.manager.delete_user(self.user) self.manager.delete_user(self.user)
class NWFilterTestCase(test.TrialTestCase): class NWFilterTestCase(test.TrialTestCase):
def setUp(self): def setUp(self):
super(NWFilterTestCase, self).setUp() super(NWFilterTestCase, self).setUp()
@@ -131,7 +132,8 @@ class NWFilterTestCase(test.TrialTestCase):
pass pass
self.manager = manager.AuthManager() self.manager = manager.AuthManager()
self.user = self.manager.create_user('fake', 'fake', 'fake', admin=True) self.user = self.manager.create_user('fake', 'fake', 'fake',
admin=True)
self.project = self.manager.create_project('fake', 'fake', 'fake') self.project = self.manager.create_project('fake', 'fake', 'fake')
self.context = context.RequestContext(self.user, self.project) self.context = context.RequestContext(self.user, self.project)
@@ -143,7 +145,6 @@ class NWFilterTestCase(test.TrialTestCase):
self.manager.delete_project(self.project) self.manager.delete_project(self.project)
self.manager.delete_user(self.user) self.manager.delete_user(self.user)
def test_cidr_rule_nwfilter_xml(self): def test_cidr_rule_nwfilter_xml(self):
cloud_controller = cloud.CloudController() cloud_controller = cloud.CloudController()
cloud_controller.create_security_group(self.context, cloud_controller.create_security_group(self.context,
@@ -156,7 +157,6 @@ class NWFilterTestCase(test.TrialTestCase):
ip_protocol='tcp', ip_protocol='tcp',
cidr_ip='0.0.0.0/0') cidr_ip='0.0.0.0/0')
security_group = db.security_group_get_by_name(self.context, security_group = db.security_group_get_by_name(self.context,
'fake', 'fake',
'testgroup') 'testgroup')
@@ -182,15 +182,12 @@ class NWFilterTestCase(test.TrialTestCase):
self.assertEqual(ip_conditions[0].getAttribute('srcipmask'), '0.0.0.0') self.assertEqual(ip_conditions[0].getAttribute('srcipmask'), '0.0.0.0')
self.assertEqual(ip_conditions[0].getAttribute('dstportstart'), '80') self.assertEqual(ip_conditions[0].getAttribute('dstportstart'), '80')
self.assertEqual(ip_conditions[0].getAttribute('dstportend'), '81') self.assertEqual(ip_conditions[0].getAttribute('dstportend'), '81')
self.teardown_security_group() self.teardown_security_group()
def teardown_security_group(self): def teardown_security_group(self):
cloud_controller = cloud.CloudController() cloud_controller = cloud.CloudController()
cloud_controller.delete_security_group(self.context, 'testgroup') cloud_controller.delete_security_group(self.context, 'testgroup')
def setup_and_return_security_group(self): def setup_and_return_security_group(self):
cloud_controller = cloud.CloudController() cloud_controller = cloud.CloudController()
cloud_controller.create_security_group(self.context, cloud_controller.create_security_group(self.context,
@@ -244,16 +241,19 @@ class NWFilterTestCase(test.TrialTestCase):
for required in [secgroup_filter, 'allow-dhcp-server', for required in [secgroup_filter, 'allow-dhcp-server',
'no-arp-spoofing', 'no-ip-spoofing', 'no-arp-spoofing', 'no-ip-spoofing',
'no-mac-spoofing']: 'no-mac-spoofing']:
self.assertTrue(required in self.recursive_depends[instance_filter], self.assertTrue(required in
"Instance's filter does not include %s" % required) self.recursive_depends[instance_filter],
"Instance's filter does not include %s" %
required)
self.security_group = self.setup_and_return_security_group() self.security_group = self.setup_and_return_security_group()
db.instance_add_security_group(self.context, inst_id, self.security_group.id) db.instance_add_security_group(self.context, inst_id,
self.security_group.id)
instance = db.instance_get(self.context, inst_id) instance = db.instance_get(self.context, inst_id)
d = self.fw.setup_nwfilters_for_instance(instance) d = self.fw.setup_nwfilters_for_instance(instance)
d.addCallback(_ensure_all_called) d.addCallback(_ensure_all_called)
d.addCallback(lambda _:self.teardown_security_group()) d.addCallback(lambda _: self.teardown_security_group())
return d return d

View File

@@ -59,7 +59,8 @@ class VolumeTestCase(test.TrialTestCase):
"""Test volume can be created and deleted""" """Test volume can be created and deleted"""
volume_id = self._create_volume() volume_id = self._create_volume()
yield self.volume.create_volume(self.context, volume_id) yield self.volume.create_volume(self.context, volume_id)
self.assertEqual(volume_id, db.volume_get(context.get_admin_context(), volume_id).id) self.assertEqual(volume_id, db.volume_get(context.get_admin_context(),
volume_id).id)
yield self.volume.delete_volume(self.context, volume_id) yield self.volume.delete_volume(self.context, volume_id)
self.assertRaises(exception.NotFound, self.assertRaises(exception.NotFound,
@@ -114,7 +115,8 @@ class VolumeTestCase(test.TrialTestCase):
volume_id = self._create_volume() volume_id = self._create_volume()
yield self.volume.create_volume(self.context, volume_id) yield self.volume.create_volume(self.context, volume_id)
if FLAGS.fake_tests: if FLAGS.fake_tests:
db.volume_attached(self.context, volume_id, instance_id, mountpoint) db.volume_attached(self.context, volume_id, instance_id,
mountpoint)
else: else:
yield self.compute.attach_volume(self.context, yield self.compute.attach_volume(self.context,
instance_id, instance_id,
@@ -154,7 +156,8 @@ class VolumeTestCase(test.TrialTestCase):
def _check(volume_id): def _check(volume_id):
"""Make sure blades aren't duplicated""" """Make sure blades aren't duplicated"""
volume_ids.append(volume_id) volume_ids.append(volume_id)
(shelf_id, blade_id) = db.volume_get_shelf_and_blade(context.get_admin_context(), admin_context = context.get_admin_context()
(shelf_id, blade_id) = db.volume_get_shelf_and_blade(admin_context,
volume_id) volume_id)
shelf_blade = '%s.%s' % (shelf_id, blade_id) shelf_blade = '%s.%s' % (shelf_id, blade_id)
self.assert_(shelf_blade not in shelf_blades) self.assert_(shelf_blade not in shelf_blades)

View File

@@ -226,6 +226,7 @@ class FakeConnection(object):
def get_console_output(self, instance): def get_console_output(self, instance):
return 'FAKE CONSOLE OUTPUT' return 'FAKE CONSOLE OUTPUT'
class FakeInstance(object): class FakeInstance(object):
def __init__(self): def __init__(self):
self._state = power_state.NOSTATE self._state = power_state.NOSTATE

View File

@@ -62,8 +62,8 @@ def _fetch_s3_image(image, path, user, project):
headers['Authorization'] = 'AWS %s:%s' % (access, signature) headers['Authorization'] = 'AWS %s:%s' % (access, signature)
cmd = ['/usr/bin/curl', '--fail', '--silent', url] cmd = ['/usr/bin/curl', '--fail', '--silent', url]
for (k,v) in headers.iteritems(): for (k, v) in headers.iteritems():
cmd += ['-H', '%s: %s' % (k,v)] cmd += ['-H', '%s: %s' % (k, v)]
cmd += ['-o', path] cmd += ['-o', path]
return process.SharedPool().execute(executable=cmd[0], args=cmd[1:]) return process.SharedPool().execute(executable=cmd[0], args=cmd[1:])

View File

@@ -62,7 +62,8 @@ flags.DEFINE_string('injected_network_template',
'Template file for injected network') 'Template file for injected network')
flags.DEFINE_string('libvirt_type', flags.DEFINE_string('libvirt_type',
'kvm', 'kvm',
'Libvirt domain type (valid options are: kvm, qemu, uml, xen)') 'Libvirt domain type (valid options are: '
'kvm, qemu, uml, xen)')
flags.DEFINE_string('libvirt_uri', flags.DEFINE_string('libvirt_uri',
'', '',
'Override the default libvirt URI (which is dependent' 'Override the default libvirt URI (which is dependent'
@@ -96,7 +97,8 @@ class LibvirtConnection(object):
def _conn(self): def _conn(self):
if not self._wrapped_conn or not self._test_connection(): if not self._wrapped_conn or not self._test_connection():
logging.debug('Connecting to libvirt: %s' % self.libvirt_uri) logging.debug('Connecting to libvirt: %s' % self.libvirt_uri)
self._wrapped_conn = self._connect(self.libvirt_uri, self.read_only) self._wrapped_conn = self._connect(self.libvirt_uri,
self.read_only)
return self._wrapped_conn return self._wrapped_conn
def _test_connection(self): def _test_connection(self):
@@ -150,6 +152,7 @@ class LibvirtConnection(object):
# WE'LL save this for when we do shutdown, # WE'LL save this for when we do shutdown,
# instead of destroy - but destroy returns immediately # instead of destroy - but destroy returns immediately
timer = task.LoopingCall(f=None) timer = task.LoopingCall(f=None)
def _wait_for_shutdown(): def _wait_for_shutdown():
try: try:
state = self.get_info(instance['name'])['state'] state = self.get_info(instance['name'])['state']
@@ -164,6 +167,7 @@ class LibvirtConnection(object):
power_state.SHUTDOWN) power_state.SHUTDOWN)
timer.stop() timer.stop()
d.callback(None) d.callback(None)
timer.f = _wait_for_shutdown timer.f = _wait_for_shutdown
timer.start(interval=0.5, now=True) timer.start(interval=0.5, now=True)
return d return d
@@ -201,6 +205,7 @@ class LibvirtConnection(object):
d = defer.Deferred() d = defer.Deferred()
timer = task.LoopingCall(f=None) timer = task.LoopingCall(f=None)
def _wait_for_reboot(): def _wait_for_reboot():
try: try:
state = self.get_info(instance['name'])['state'] state = self.get_info(instance['name'])['state']
@@ -217,6 +222,7 @@ class LibvirtConnection(object):
power_state.SHUTDOWN) power_state.SHUTDOWN)
timer.stop() timer.stop()
d.callback(None) d.callback(None)
timer.f = _wait_for_reboot timer.f = _wait_for_reboot
timer.start(interval=0.5, now=True) timer.start(interval=0.5, now=True)
yield d yield d
@@ -229,7 +235,8 @@ class LibvirtConnection(object):
instance['id'], instance['id'],
power_state.NOSTATE, power_state.NOSTATE,
'launching') 'launching')
yield NWFilterFirewall(self._conn).setup_nwfilters_for_instance(instance) yield NWFilterFirewall(self._conn).\
setup_nwfilters_for_instance(instance)
yield self._create_image(instance, xml) yield self._create_image(instance, xml)
yield self._conn.createXML(xml, 0) yield self._conn.createXML(xml, 0)
# TODO(termie): this should actually register # TODO(termie): this should actually register
@@ -238,6 +245,7 @@ class LibvirtConnection(object):
local_d = defer.Deferred() local_d = defer.Deferred()
timer = task.LoopingCall(f=None) timer = task.LoopingCall(f=None)
def _wait_for_boot(): def _wait_for_boot():
try: try:
state = self.get_info(instance['name'])['state'] state = self.get_info(instance['name'])['state']
@@ -265,8 +273,9 @@ class LibvirtConnection(object):
if virsh_output.startswith('/dev/'): if virsh_output.startswith('/dev/'):
logging.info('cool, it\'s a device') logging.info('cool, it\'s a device')
d = process.simple_execute("sudo dd if=%s iflag=nonblock" % virsh_output, check_exit_code=False) d = process.simple_execute("sudo dd if=%s iflag=nonblock" %
d.addCallback(lambda r:r[0]) virsh_output, check_exit_code=False)
d.addCallback(lambda r: r[0])
return d return d
else: else:
return '' return ''
@@ -285,11 +294,15 @@ class LibvirtConnection(object):
@exception.wrap_exception @exception.wrap_exception
def get_console_output(self, instance): def get_console_output(self, instance):
console_log = os.path.join(FLAGS.instances_path, instance['name'], 'console.log') console_log = os.path.join(FLAGS.instances_path, instance['name'],
d = process.simple_execute('sudo chown %d %s' % (os.getuid(), console_log)) 'console.log')
d = process.simple_execute('sudo chown %d %s' % (os.getuid(),
console_log))
if FLAGS.libvirt_type == 'xen': if FLAGS.libvirt_type == 'xen':
# Xen is spethial # Xen is spethial
d.addCallback(lambda _: process.simple_execute("virsh ttyconsole %s" % instance['name'])) d.addCallback(lambda _:
process.simple_execute("virsh ttyconsole %s" %
instance['name']))
d.addCallback(self._flush_xen_console) d.addCallback(self._flush_xen_console)
d.addCallback(self._append_to_file, console_log) d.addCallback(self._append_to_file, console_log)
else: else:
@@ -297,7 +310,6 @@ class LibvirtConnection(object):
d.addCallback(self._dump_file) d.addCallback(self._dump_file)
return d return d
@defer.inlineCallbacks @defer.inlineCallbacks
def _create_image(self, inst, libvirt_xml): def _create_image(self, inst, libvirt_xml):
# syntactic nicety # syntactic nicety
@@ -309,7 +321,6 @@ class LibvirtConnection(object):
yield process.simple_execute('mkdir -p %s' % basepath()) yield process.simple_execute('mkdir -p %s' % basepath())
yield process.simple_execute('chmod 0777 %s' % basepath()) yield process.simple_execute('chmod 0777 %s' % basepath())
# TODO(termie): these are blocking calls, it would be great # TODO(termie): these are blocking calls, it would be great
# if they weren't. # if they weren't.
logging.info('instance %s: Creating image', inst['name']) logging.info('instance %s: Creating image', inst['name'])
@@ -317,17 +328,21 @@ class LibvirtConnection(object):
f.write(libvirt_xml) f.write(libvirt_xml)
f.close() f.close()
os.close(os.open(basepath('console.log'), os.O_CREAT | os.O_WRONLY, 0660)) os.close(os.open(basepath('console.log'), os.O_CREAT | os.O_WRONLY,
0660))
user = manager.AuthManager().get_user(inst['user_id']) user = manager.AuthManager().get_user(inst['user_id'])
project = manager.AuthManager().get_project(inst['project_id']) project = manager.AuthManager().get_project(inst['project_id'])
if not os.path.exists(basepath('disk')): if not os.path.exists(basepath('disk')):
yield images.fetch(inst.image_id, basepath('disk-raw'), user, project) yield images.fetch(inst.image_id, basepath('disk-raw'), user,
project)
if not os.path.exists(basepath('kernel')): if not os.path.exists(basepath('kernel')):
yield images.fetch(inst.kernel_id, basepath('kernel'), user, project) yield images.fetch(inst.kernel_id, basepath('kernel'), user,
project)
if not os.path.exists(basepath('ramdisk')): if not os.path.exists(basepath('ramdisk')):
yield images.fetch(inst.ramdisk_id, basepath('ramdisk'), user, project) yield images.fetch(inst.ramdisk_id, basepath('ramdisk'), user,
project)
execute = lambda cmd, process_input=None, check_exit_code=True: \ execute = lambda cmd, process_input=None, check_exit_code=True: \
process.simple_execute(cmd=cmd, process.simple_execute(cmd=cmd,
@@ -339,8 +354,8 @@ class LibvirtConnection(object):
network_ref = db.network_get_by_instance(context.get_admin_context(), network_ref = db.network_get_by_instance(context.get_admin_context(),
inst['id']) inst['id'])
if network_ref['injected']: if network_ref['injected']:
address = db.instance_get_fixed_address(context.get_admin_context(), admin_context = context.get_admin_context()
inst['id']) address = db.instance_get_fixed_address(admin_context, inst['id'])
with open(FLAGS.injected_network_template) as f: with open(FLAGS.injected_network_template) as f:
net = f.read() % {'address': address, net = f.read() % {'address': address,
'netmask': network_ref['netmask'], 'netmask': network_ref['netmask'],
@@ -354,7 +369,8 @@ class LibvirtConnection(object):
if net: if net:
logging.info('instance %s: injecting net into image %s', logging.info('instance %s: injecting net into image %s',
inst['name'], inst.image_id) inst['name'], inst.image_id)
yield disk.inject_data(basepath('disk-raw'), key, net, execute=execute) yield disk.inject_data(basepath('disk-raw'), key, net,
execute=execute)
if os.path.exists(basepath('disk')): if os.path.exists(basepath('disk')):
yield process.simple_execute('rm -f %s' % basepath('disk')) yield process.simple_execute('rm -f %s' % basepath('disk'))
@@ -377,7 +393,8 @@ class LibvirtConnection(object):
network = db.project_get_network(context.get_admin_context(), network = db.project_get_network(context.get_admin_context(),
instance['project_id']) instance['project_id'])
# FIXME(vish): stick this in db # FIXME(vish): stick this in db
instance_type = instance_types.INSTANCE_TYPES[instance['instance_type']] instance_type = instance['instance_type']
instance_type = instance_types.INSTANCE_TYPES[instance_type]
ip_address = db.instance_get_fixed_address(context.get_admin_context(), ip_address = db.instance_get_fixed_address(context.get_admin_context(),
instance['id']) instance['id'])
# Assume that the gateway also acts as the dhcp server. # Assume that the gateway also acts as the dhcp server.
@@ -391,7 +408,7 @@ class LibvirtConnection(object):
'bridge_name': network['bridge'], 'bridge_name': network['bridge'],
'mac_address': instance['mac_address'], 'mac_address': instance['mac_address'],
'ip_address': ip_address, 'ip_address': ip_address,
'dhcp_server': dhcp_server } 'dhcp_server': dhcp_server}
libvirt_xml = self.libvirt_xml % xml_info libvirt_xml = self.libvirt_xml % xml_info
logging.debug('instance %s: finished toXML method', instance['name']) logging.debug('instance %s: finished toXML method', instance['name'])
@@ -506,7 +523,6 @@ class LibvirtConnection(object):
domain = self._conn.lookupByName(instance_name) domain = self._conn.lookupByName(instance_name)
return domain.interfaceStats(interface) return domain.interfaceStats(interface)
def refresh_security_group(self, security_group_id): def refresh_security_group(self, security_group_id):
fw = NWFilterFirewall(self._conn) fw = NWFilterFirewall(self._conn)
fw.ensure_security_group_filter(security_group_id) fw.ensure_security_group_filter(security_group_id)
@@ -557,7 +573,6 @@ class NWFilterFirewall(object):
def __init__(self, get_connection): def __init__(self, get_connection):
self._conn = get_connection self._conn = get_connection
nova_base_filter = '''<filter name='nova-base' chain='root'> nova_base_filter = '''<filter name='nova-base' chain='root'>
<uuid>26717364-50cf-42d1-8185-29bf893ab110</uuid> <uuid>26717364-50cf-42d1-8185-29bf893ab110</uuid>
<filterref filter='no-mac-spoofing'/> <filterref filter='no-mac-spoofing'/>
@@ -578,7 +593,8 @@ class NWFilterFirewall(object):
srcportstart='68' srcportstart='68'
dstportstart='67'/> dstportstart='67'/>
</rule> </rule>
<rule action='accept' direction='in' priority='100'> <rule action='accept' direction='in'
priority='100'>
<udp srcipaddr='$DHCPSERVER' <udp srcipaddr='$DHCPSERVER'
srcportstart='67' srcportstart='67'
dstportstart='68'/> dstportstart='68'/>
@@ -588,8 +604,8 @@ class NWFilterFirewall(object):
def nova_base_ipv4_filter(self): def nova_base_ipv4_filter(self):
retval = "<filter name='nova-base-ipv4' chain='ipv4'>" retval = "<filter name='nova-base-ipv4' chain='ipv4'>"
for protocol in ['tcp', 'udp', 'icmp']: for protocol in ['tcp', 'udp', 'icmp']:
for direction,action,priority in [('out','accept', 399), for direction, action, priority in [('out', 'accept', 399),
('inout','drop', 400)]: ('inout', 'drop', 400)]:
retval += """<rule action='%s' direction='%s' priority='%d'> retval += """<rule action='%s' direction='%s' priority='%d'>
<%s /> <%s />
</rule>""" % (action, direction, </rule>""" % (action, direction,
@@ -597,12 +613,11 @@ class NWFilterFirewall(object):
retval += '</filter>' retval += '</filter>'
return retval return retval
def nova_base_ipv6_filter(self): def nova_base_ipv6_filter(self):
retval = "<filter name='nova-base-ipv6' chain='ipv6'>" retval = "<filter name='nova-base-ipv6' chain='ipv6'>"
for protocol in ['tcp', 'udp', 'icmp']: for protocol in ['tcp', 'udp', 'icmp']:
for direction,action,priority in [('out','accept',399), for direction, action, priority in [('out', 'accept', 399),
('inout','drop',400)]: ('inout', 'drop', 400)]:
retval += """<rule action='%s' direction='%s' priority='%d'> retval += """<rule action='%s' direction='%s' priority='%d'>
<%s-ipv6 /> <%s-ipv6 />
</rule>""" % (action, direction, </rule>""" % (action, direction,
@@ -610,7 +625,6 @@ class NWFilterFirewall(object):
retval += '</filter>' retval += '</filter>'
return retval return retval
def nova_project_filter(self, project, net, mask): def nova_project_filter(self, project, net, mask):
retval = "<filter name='nova-project-%s' chain='ipv4'>" % project retval = "<filter name='nova-project-%s' chain='ipv4'>" % project
for protocol in ['tcp', 'udp', 'icmp']: for protocol in ['tcp', 'udp', 'icmp']:
@@ -620,14 +634,12 @@ class NWFilterFirewall(object):
retval += '</filter>' retval += '</filter>'
return retval return retval
def _define_filter(self, xml): def _define_filter(self, xml):
if callable(xml): if callable(xml):
xml = xml() xml = xml()
d = threads.deferToThread(self._conn.nwfilterDefineXML, xml) d = threads.deferToThread(self._conn.nwfilterDefineXML, xml)
return d return d
@staticmethod @staticmethod
def _get_net_and_mask(cidr): def _get_net_and_mask(cidr):
net = IPy.IP(cidr) net = IPy.IP(cidr)
@@ -646,9 +658,9 @@ class NWFilterFirewall(object):
yield self._define_filter(self.nova_dhcp_filter) yield self._define_filter(self.nova_dhcp_filter)
yield self._define_filter(self.nova_base_filter) yield self._define_filter(self.nova_base_filter)
nwfilter_xml = ("<filter name='nova-instance-%s' chain='root'>\n" + nwfilter_xml = "<filter name='nova-instance-%s' chain='root'>\n" \
" <filterref filter='nova-base' />\n" " <filterref filter='nova-base' />\n" % \
) % instance['name'] instance['name']
if FLAGS.allow_project_net_traffic: if FLAGS.allow_project_net_traffic:
network_ref = db.project_get_network(context.get_admin_context(), network_ref = db.project_get_network(context.get_admin_context(),
@@ -658,14 +670,14 @@ class NWFilterFirewall(object):
net, mask) net, mask)
yield self._define_filter(project_filter) yield self._define_filter(project_filter)
nwfilter_xml += (" <filterref filter='nova-project-%s' />\n" nwfilter_xml += " <filterref filter='nova-project-%s' />\n" % \
) % instance['project_id'] instance['project_id']
for security_group in instance.security_groups: for security_group in instance.security_groups:
yield self.ensure_security_group_filter(security_group['id']) yield self.ensure_security_group_filter(security_group['id'])
nwfilter_xml += (" <filterref filter='nova-secgroup-%d' />\n" nwfilter_xml += " <filterref filter='nova-secgroup-%d' />\n" % \
) % security_group['id'] security_group['id']
nwfilter_xml += "</filter>" nwfilter_xml += "</filter>"
yield self._define_filter(nwfilter_xml) yield self._define_filter(nwfilter_xml)
@@ -675,7 +687,6 @@ class NWFilterFirewall(object):
return self._define_filter( return self._define_filter(
self.security_group_to_nwfilter_xml(security_group_id)) self.security_group_to_nwfilter_xml(security_group_id))
def security_group_to_nwfilter_xml(self, security_group_id): def security_group_to_nwfilter_xml(self, security_group_id):
security_group = db.security_group_get(context.get_admin_context(), security_group = db.security_group_get(context.get_admin_context(),
security_group_id) security_group_id)
@@ -684,12 +695,15 @@ class NWFilterFirewall(object):
rule_xml += "<rule action='accept' direction='in' priority='300'>" rule_xml += "<rule action='accept' direction='in' priority='300'>"
if rule.cidr: if rule.cidr:
net, mask = self._get_net_and_mask(rule.cidr) net, mask = self._get_net_and_mask(rule.cidr)
rule_xml += "<%s srcipaddr='%s' srcipmask='%s' " % (rule.protocol, net, mask) rule_xml += "<%s srcipaddr='%s' srcipmask='%s' " % \
(rule.protocol, net, mask)
if rule.protocol in ['tcp', 'udp']: if rule.protocol in ['tcp', 'udp']:
rule_xml += "dstportstart='%s' dstportend='%s' " % \ rule_xml += "dstportstart='%s' dstportend='%s' " % \
(rule.from_port, rule.to_port) (rule.from_port, rule.to_port)
elif rule.protocol == 'icmp': elif rule.protocol == 'icmp':
logging.info('rule.protocol: %r, rule.from_port: %r, rule.to_port: %r' % (rule.protocol, rule.from_port, rule.to_port)) logging.info('rule.protocol: %r, rule.from_port: %r, '
'rule.to_port: %r' %
(rule.protocol, rule.from_port, rule.to_port))
if rule.from_port != -1: if rule.from_port != -1:
rule_xml += "type='%s' " % rule.from_port rule_xml += "type='%s' " % rule.from_port
if rule.to_port != -1: if rule.to_port != -1:
@@ -697,5 +711,6 @@ class NWFilterFirewall(object):
rule_xml += '/>\n' rule_xml += '/>\n'
rule_xml += "</rule>\n" rule_xml += "</rule>\n"
xml = '''<filter name='nova-secgroup-%s' chain='ipv4'>%s</filter>''' % (security_group_id, rule_xml,) xml = "<filter name='nova-secgroup-%s' chain='ipv4'>%s</filter>" % \
(security_group_id, rule_xml,)
return xml return xml

View File

@@ -75,12 +75,11 @@ flags.DEFINE_float('xenapi_task_poll_interval',
XENAPI_POWER_STATE = { XENAPI_POWER_STATE = {
'Halted' : power_state.SHUTDOWN, 'Halted': power_state.SHUTDOWN,
'Running' : power_state.RUNNING, 'Running': power_state.RUNNING,
'Paused' : power_state.PAUSED, 'Paused': power_state.PAUSED,
'Suspended': power_state.SHUTDOWN, # FIXME 'Suspended': power_state.SHUTDOWN, # FIXME
'Crashed' : power_state.CRASHED 'Crashed': power_state.CRASHED}
}
def get_connection(_): def get_connection(_):
@@ -90,12 +89,15 @@ def get_connection(_):
# library when not using XenAPI. # library when not using XenAPI.
global XenAPI global XenAPI
if XenAPI is None: if XenAPI is None:
XenAPI = __import__('XenAPI') XenAPI = __import__('XenAPI')
url = FLAGS.xenapi_connection_url url = FLAGS.xenapi_connection_url
username = FLAGS.xenapi_connection_username username = FLAGS.xenapi_connection_username
password = FLAGS.xenapi_connection_password password = FLAGS.xenapi_connection_password
if not url or password is None: if not url or password is None:
raise Exception('Must specify xenapi_connection_url, xenapi_connection_username (optionally), and xenapi_connection_password to use connection_type=xenapi') raise Exception('Must specify xenapi_connection_url, '
'xenapi_connection_username (optionally), and '
'xenapi_connection_password to use '
'connection_type=xenapi')
return XenAPIConnection(url, username, password) return XenAPIConnection(url, username, password)
@@ -141,7 +143,7 @@ class XenAPIConnection(object):
def _create_vm(self, instance, kernel, ramdisk): def _create_vm(self, instance, kernel, ramdisk):
"""Create a VM record. Returns a Deferred that gives the new """Create a VM record. Returns a Deferred that gives the new
VM reference.""" VM reference."""
instance_type = instance_types.INSTANCE_TYPES[instance.instance_type] instance_type = instance_types.INSTANCE_TYPES[instance.instance_type]
mem = str(long(instance_type['memory_mb']) * 1024 * 1024) mem = str(long(instance_type['memory_mb']) * 1024 * 1024)
vcpus = str(instance_type['vcpus']) vcpus = str(instance_type['vcpus'])
@@ -183,7 +185,7 @@ class XenAPIConnection(object):
def _create_vbd(self, vm_ref, vdi_ref, userdevice, bootable): def _create_vbd(self, vm_ref, vdi_ref, userdevice, bootable):
"""Create a VBD record. Returns a Deferred that gives the new """Create a VBD record. Returns a Deferred that gives the new
VBD reference.""" VBD reference."""
vbd_rec = {} vbd_rec = {}
vbd_rec['VM'] = vm_ref vbd_rec['VM'] = vm_ref
vbd_rec['VDI'] = vdi_ref vbd_rec['VDI'] = vdi_ref
@@ -207,10 +209,10 @@ class XenAPIConnection(object):
def _create_vif(self, vm_ref, network_ref, mac_address): def _create_vif(self, vm_ref, network_ref, mac_address):
"""Create a VIF record. Returns a Deferred that gives the new """Create a VIF record. Returns a Deferred that gives the new
VIF reference.""" VIF reference."""
vif_rec = {} vif_rec = {}
vif_rec['device'] = '0' vif_rec['device'] = '0'
vif_rec['network']= network_ref vif_rec['network'] = network_ref
vif_rec['VM'] = vm_ref vif_rec['VM'] = vm_ref
vif_rec['MAC'] = mac_address vif_rec['MAC'] = mac_address
vif_rec['MTU'] = '1500' vif_rec['MTU'] = '1500'
@@ -303,7 +305,7 @@ class XenAPIConnection(object):
def _lookup_blocking(self, i): def _lookup_blocking(self, i):
vms = self._conn.xenapi.VM.get_by_name_label(i) vms = self._conn.xenapi.VM.get_by_name_label(i)
n = len(vms) n = len(vms)
if n == 0: if n == 0:
return None return None
elif n > 1: elif n > 1:

View File

@@ -61,7 +61,6 @@ class AOEDriver(object):
"Try number %s", tries) "Try number %s", tries)
yield self._execute("sleep %s" % tries ** 2) yield self._execute("sleep %s" % tries ** 2)
@defer.inlineCallbacks @defer.inlineCallbacks
def create_volume(self, volume_name, size): def create_volume(self, volume_name, size):
"""Creates a logical volume""" """Creates a logical volume"""