Reuse caller's session in ML2 DB methods
This patch changes the get_port_from_device_mac() and get_sg_ids_grouped_by_port() methods in ML2 db.py module so that they do not create a new database session (via get_session()), but instead reuse the session associated with the caller's context. In order to make the session that is associated with the caller's context available to these ML2 DB methods, the get_ports_from_devices plugin API in securitygroups_rps_base.py needs to be modified so that the context can be passed down to the ML2 plugin. (A similar change is made to the get_port_from_device plugin API for consistency.) Change-Id: I3f990895887e156de929bd7ac3732df114dd4a4b Closes-Bug: 1441205
This commit is contained in:
@@ -76,10 +76,10 @@ class SecurityGroupServerRpcCallback(object):
|
|||||||
def plugin(self):
|
def plugin(self):
|
||||||
return manager.NeutronManager.get_plugin()
|
return manager.NeutronManager.get_plugin()
|
||||||
|
|
||||||
def _get_devices_info(self, devices):
|
def _get_devices_info(self, context, devices):
|
||||||
return dict(
|
return dict(
|
||||||
(port['id'], port)
|
(port['id'], port)
|
||||||
for port in self.plugin.get_ports_from_devices(devices)
|
for port in self.plugin.get_ports_from_devices(context, devices)
|
||||||
if port and not port['device_owner'].startswith('network:')
|
if port and not port['device_owner'].startswith('network:')
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -93,7 +93,7 @@ class SecurityGroupServerRpcCallback(object):
|
|||||||
:returns: port correspond to the devices with security group rules
|
:returns: port correspond to the devices with security group rules
|
||||||
"""
|
"""
|
||||||
devices_info = kwargs.get('devices')
|
devices_info = kwargs.get('devices')
|
||||||
ports = self._get_devices_info(devices_info)
|
ports = self._get_devices_info(context, devices_info)
|
||||||
return self.plugin.security_group_rules_for_ports(context, ports)
|
return self.plugin.security_group_rules_for_ports(context, ports)
|
||||||
|
|
||||||
def security_group_info_for_devices(self, context, **kwargs):
|
def security_group_info_for_devices(self, context, **kwargs):
|
||||||
@@ -110,7 +110,7 @@ class SecurityGroupServerRpcCallback(object):
|
|||||||
Note that sets are serialized into lists by rpc code.
|
Note that sets are serialized into lists by rpc code.
|
||||||
"""
|
"""
|
||||||
devices_info = kwargs.get('devices')
|
devices_info = kwargs.get('devices')
|
||||||
ports = self._get_devices_info(devices_info)
|
ports = self._get_devices_info(context, devices_info)
|
||||||
return self.plugin.security_group_info_for_ports(context, ports)
|
return self.plugin.security_group_info_for_ports(context, ports)
|
||||||
|
|
||||||
|
|
||||||
|
@@ -38,7 +38,7 @@ DHCP_RULE_PORT = {4: (67, 68, q_const.IPv4), 6: (547, 546, q_const.IPv6)}
|
|||||||
class SecurityGroupServerRpcMixin(sg_db.SecurityGroupDbMixin):
|
class SecurityGroupServerRpcMixin(sg_db.SecurityGroupDbMixin):
|
||||||
"""Mixin class to add agent-based security group implementation."""
|
"""Mixin class to add agent-based security group implementation."""
|
||||||
|
|
||||||
def get_port_from_device(self, device):
|
def get_port_from_device(self, context, device):
|
||||||
"""Get port dict from device name on an agent.
|
"""Get port dict from device name on an agent.
|
||||||
|
|
||||||
Subclass must provide this method or get_ports_from_devices.
|
Subclass must provide this method or get_ports_from_devices.
|
||||||
@@ -59,13 +59,14 @@ class SecurityGroupServerRpcMixin(sg_db.SecurityGroupDbMixin):
|
|||||||
"or get_ports_from_devices.")
|
"or get_ports_from_devices.")
|
||||||
% self.__class__.__name__)
|
% self.__class__.__name__)
|
||||||
|
|
||||||
def get_ports_from_devices(self, devices):
|
def get_ports_from_devices(self, context, devices):
|
||||||
"""Bulk method of get_port_from_device.
|
"""Bulk method of get_port_from_device.
|
||||||
|
|
||||||
Subclasses may override this to provide better performance for DB
|
Subclasses may override this to provide better performance for DB
|
||||||
queries, backend calls, etc.
|
queries, backend calls, etc.
|
||||||
"""
|
"""
|
||||||
return [self.get_port_from_device(device) for device in devices]
|
return [self.get_port_from_device(context, device)
|
||||||
|
for device in devices]
|
||||||
|
|
||||||
def create_security_group_rule(self, context, security_group_rule):
|
def create_security_group_rule(self, context, security_group_rule):
|
||||||
bulk_rule = {'security_group_rules': [security_group_rule]}
|
bulk_rule = {'security_group_rules': [security_group_rule]}
|
||||||
|
@@ -19,7 +19,6 @@ from sqlalchemy import or_
|
|||||||
from sqlalchemy.orm import exc
|
from sqlalchemy.orm import exc
|
||||||
|
|
||||||
from neutron.common import constants as n_const
|
from neutron.common import constants as n_const
|
||||||
from neutron.db import api as db_api
|
|
||||||
from neutron.db import models_v2
|
from neutron.db import models_v2
|
||||||
from neutron.db import securitygroups_db as sg_db
|
from neutron.db import securitygroups_db as sg_db
|
||||||
from neutron.extensions import portbindings
|
from neutron.extensions import portbindings
|
||||||
@@ -244,14 +243,14 @@ def get_port(session, port_id):
|
|||||||
return
|
return
|
||||||
|
|
||||||
|
|
||||||
def get_port_from_device_mac(device_mac):
|
def get_port_from_device_mac(context, device_mac):
|
||||||
LOG.debug("get_port_from_device_mac() called for mac %s", device_mac)
|
LOG.debug("get_port_from_device_mac() called for mac %s", device_mac)
|
||||||
session = db_api.get_session()
|
qry = context.session.query(models_v2.Port).filter_by(
|
||||||
qry = session.query(models_v2.Port).filter_by(mac_address=device_mac)
|
mac_address=device_mac)
|
||||||
return qry.first()
|
return qry.first()
|
||||||
|
|
||||||
|
|
||||||
def get_ports_and_sgs(port_ids):
|
def get_ports_and_sgs(context, port_ids):
|
||||||
"""Get ports from database with security group info."""
|
"""Get ports from database with security group info."""
|
||||||
|
|
||||||
# break large queries into smaller parts
|
# break large queries into smaller parts
|
||||||
@@ -259,25 +258,24 @@ def get_ports_and_sgs(port_ids):
|
|||||||
LOG.debug("Number of ports %(pcount)s exceeds the maximum per "
|
LOG.debug("Number of ports %(pcount)s exceeds the maximum per "
|
||||||
"query %(maxp)s. Partitioning queries.",
|
"query %(maxp)s. Partitioning queries.",
|
||||||
{'pcount': len(port_ids), 'maxp': MAX_PORTS_PER_QUERY})
|
{'pcount': len(port_ids), 'maxp': MAX_PORTS_PER_QUERY})
|
||||||
return (get_ports_and_sgs(port_ids[:MAX_PORTS_PER_QUERY]) +
|
return (get_ports_and_sgs(context, port_ids[:MAX_PORTS_PER_QUERY]) +
|
||||||
get_ports_and_sgs(port_ids[MAX_PORTS_PER_QUERY:]))
|
get_ports_and_sgs(context, port_ids[MAX_PORTS_PER_QUERY:]))
|
||||||
|
|
||||||
LOG.debug("get_ports_and_sgs() called for port_ids %s", port_ids)
|
LOG.debug("get_ports_and_sgs() called for port_ids %s", port_ids)
|
||||||
|
|
||||||
if not port_ids:
|
if not port_ids:
|
||||||
# if port_ids is empty, avoid querying to DB to ask it for nothing
|
# if port_ids is empty, avoid querying to DB to ask it for nothing
|
||||||
return []
|
return []
|
||||||
ports_to_sg_ids = get_sg_ids_grouped_by_port(port_ids)
|
ports_to_sg_ids = get_sg_ids_grouped_by_port(context, port_ids)
|
||||||
return [make_port_dict_with_security_groups(port, sec_groups)
|
return [make_port_dict_with_security_groups(port, sec_groups)
|
||||||
for port, sec_groups in ports_to_sg_ids.iteritems()]
|
for port, sec_groups in ports_to_sg_ids.iteritems()]
|
||||||
|
|
||||||
|
|
||||||
def get_sg_ids_grouped_by_port(port_ids):
|
def get_sg_ids_grouped_by_port(context, port_ids):
|
||||||
sg_ids_grouped_by_port = {}
|
sg_ids_grouped_by_port = {}
|
||||||
session = db_api.get_session()
|
|
||||||
sg_binding_port = sg_db.SecurityGroupPortBinding.port_id
|
sg_binding_port = sg_db.SecurityGroupPortBinding.port_id
|
||||||
|
|
||||||
with session.begin(subtransactions=True):
|
with context.session.begin(subtransactions=True):
|
||||||
# partial UUIDs must be individually matched with startswith.
|
# partial UUIDs must be individually matched with startswith.
|
||||||
# full UUIDs may be matched directly in an IN statement
|
# full UUIDs may be matched directly in an IN statement
|
||||||
partial_uuids = set(port_id for port_id in port_ids
|
partial_uuids = set(port_id for port_id in port_ids
|
||||||
@@ -288,8 +286,8 @@ def get_sg_ids_grouped_by_port(port_ids):
|
|||||||
if full_uuids:
|
if full_uuids:
|
||||||
or_criteria.append(models_v2.Port.id.in_(full_uuids))
|
or_criteria.append(models_v2.Port.id.in_(full_uuids))
|
||||||
|
|
||||||
query = session.query(models_v2.Port,
|
query = context.session.query(
|
||||||
sg_db.SecurityGroupPortBinding.security_group_id)
|
models_v2.Port, sg_db.SecurityGroupPortBinding.security_group_id)
|
||||||
query = query.outerjoin(sg_db.SecurityGroupPortBinding,
|
query = query.outerjoin(sg_db.SecurityGroupPortBinding,
|
||||||
models_v2.Port.id == sg_binding_port)
|
models_v2.Port.id == sg_binding_port)
|
||||||
query = query.filter(or_(*or_criteria))
|
query = query.filter(or_(*or_criteria))
|
||||||
|
@@ -1451,11 +1451,12 @@ class Ml2Plugin(db_base_plugin_v2.NeutronDbPluginV2,
|
|||||||
port_host = db.get_port_binding_host(context.session, port_id)
|
port_host = db.get_port_binding_host(context.session, port_id)
|
||||||
return (port_host == host)
|
return (port_host == host)
|
||||||
|
|
||||||
def get_ports_from_devices(self, devices):
|
def get_ports_from_devices(self, context, devices):
|
||||||
port_ids_to_devices = dict((self._device_to_port_id(device), device)
|
port_ids_to_devices = dict(
|
||||||
for device in devices)
|
(self._device_to_port_id(context, device), device)
|
||||||
|
for device in devices)
|
||||||
port_ids = port_ids_to_devices.keys()
|
port_ids = port_ids_to_devices.keys()
|
||||||
ports = db.get_ports_and_sgs(port_ids)
|
ports = db.get_ports_and_sgs(context, port_ids)
|
||||||
for port in ports:
|
for port in ports:
|
||||||
# map back to original requested id
|
# map back to original requested id
|
||||||
port_id = next((port_id for port_id in port_ids
|
port_id = next((port_id for port_id in port_ids
|
||||||
@@ -1465,7 +1466,7 @@ class Ml2Plugin(db_base_plugin_v2.NeutronDbPluginV2,
|
|||||||
return ports
|
return ports
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _device_to_port_id(device):
|
def _device_to_port_id(context, device):
|
||||||
# REVISIT(rkukura): Consider calling into MechanismDrivers to
|
# REVISIT(rkukura): Consider calling into MechanismDrivers to
|
||||||
# process device names, or having MechanismDrivers supply list
|
# process device names, or having MechanismDrivers supply list
|
||||||
# of device prefixes to strip.
|
# of device prefixes to strip.
|
||||||
@@ -1475,7 +1476,7 @@ class Ml2Plugin(db_base_plugin_v2.NeutronDbPluginV2,
|
|||||||
# REVISIT(irenab): Consider calling into bound MD to
|
# REVISIT(irenab): Consider calling into bound MD to
|
||||||
# handle the get_device_details RPC
|
# handle the get_device_details RPC
|
||||||
if not uuidutils.is_uuid_like(device):
|
if not uuidutils.is_uuid_like(device):
|
||||||
port = db.get_port_from_device_mac(device)
|
port = db.get_port_from_device_mac(context, device)
|
||||||
if port:
|
if port:
|
||||||
return port.id
|
return port.id
|
||||||
return device
|
return device
|
||||||
|
@@ -67,7 +67,7 @@ class RpcCallbacks(type_tunnel.TunnelRpcCallbackMixin):
|
|||||||
{'device': device, 'agent_id': agent_id, 'host': host})
|
{'device': device, 'agent_id': agent_id, 'host': host})
|
||||||
|
|
||||||
plugin = manager.NeutronManager.get_plugin()
|
plugin = manager.NeutronManager.get_plugin()
|
||||||
port_id = plugin._device_to_port_id(device)
|
port_id = plugin._device_to_port_id(rpc_context, device)
|
||||||
port_context = plugin.get_bound_port_context(rpc_context,
|
port_context = plugin.get_bound_port_context(rpc_context,
|
||||||
port_id,
|
port_id,
|
||||||
host,
|
host,
|
||||||
@@ -144,7 +144,7 @@ class RpcCallbacks(type_tunnel.TunnelRpcCallbackMixin):
|
|||||||
"%(agent_id)s",
|
"%(agent_id)s",
|
||||||
{'device': device, 'agent_id': agent_id})
|
{'device': device, 'agent_id': agent_id})
|
||||||
plugin = manager.NeutronManager.get_plugin()
|
plugin = manager.NeutronManager.get_plugin()
|
||||||
port_id = plugin._device_to_port_id(device)
|
port_id = plugin._device_to_port_id(rpc_context, device)
|
||||||
port_exists = True
|
port_exists = True
|
||||||
if (host and not plugin.port_bound_to_host(rpc_context,
|
if (host and not plugin.port_bound_to_host(rpc_context,
|
||||||
port_id, host)):
|
port_id, host)):
|
||||||
@@ -173,7 +173,7 @@ class RpcCallbacks(type_tunnel.TunnelRpcCallbackMixin):
|
|||||||
LOG.debug("Device %(device)s up at agent %(agent_id)s",
|
LOG.debug("Device %(device)s up at agent %(agent_id)s",
|
||||||
{'device': device, 'agent_id': agent_id})
|
{'device': device, 'agent_id': agent_id})
|
||||||
plugin = manager.NeutronManager.get_plugin()
|
plugin = manager.NeutronManager.get_plugin()
|
||||||
port_id = plugin._device_to_port_id(device)
|
port_id = plugin._device_to_port_id(rpc_context, device)
|
||||||
if (host and not plugin.port_bound_to_host(rpc_context,
|
if (host and not plugin.port_bound_to_host(rpc_context,
|
||||||
port_id, host)):
|
port_id, host)):
|
||||||
LOG.debug("Device %(device)s not bound to the"
|
LOG.debug("Device %(device)s not bound to the"
|
||||||
|
@@ -56,7 +56,7 @@ IPv6 = 6
|
|||||||
class SecurityGroupServerRpcMixin(sg_db_rpc.SecurityGroupServerRpcMixin):
|
class SecurityGroupServerRpcMixin(sg_db_rpc.SecurityGroupServerRpcMixin):
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def get_port_from_device(device):
|
def get_port_from_device(context, device):
|
||||||
port = nvsd_db.get_port_from_device(device)
|
port = nvsd_db.get_port_from_device(device)
|
||||||
if port:
|
if port:
|
||||||
port['device'] = device
|
port['device'] = device
|
||||||
|
@@ -94,7 +94,7 @@ class SecurityGroupRpcTestPlugin(test_sg.SecurityGroupTestPlugin,
|
|||||||
self.notify_security_groups_member_updated(context, port)
|
self.notify_security_groups_member_updated(context, port)
|
||||||
del self.devices[id]
|
del self.devices[id]
|
||||||
|
|
||||||
def get_port_from_device(self, device):
|
def get_port_from_device(self, context, device):
|
||||||
device = self.devices.get(device)
|
device = self.devices.get(device)
|
||||||
if device:
|
if device:
|
||||||
device['security_group_rules'] = []
|
device['security_group_rules'] = []
|
||||||
|
@@ -201,7 +201,8 @@ class Ml2DBTestCase(testlib_api.SqlTestCase):
|
|||||||
self._setup_neutron_network(network_id)
|
self._setup_neutron_network(network_id)
|
||||||
port = self._setup_neutron_port(network_id, port_id)
|
port = self._setup_neutron_port(network_id, port_id)
|
||||||
|
|
||||||
observed_port = ml2_db.get_port_from_device_mac(port['mac_address'])
|
observed_port = ml2_db.get_port_from_device_mac(self.ctx,
|
||||||
|
port['mac_address'])
|
||||||
self.assertEqual(port_id, observed_port.id)
|
self.assertEqual(port_id, observed_port.id)
|
||||||
|
|
||||||
def test_get_locked_port_and_binding(self):
|
def test_get_locked_port_and_binding(self):
|
||||||
|
@@ -614,23 +614,26 @@ class TestMl2PluginOnly(Ml2PluginV2TestCase):
|
|||||||
('qvo567890', '567890')]
|
('qvo567890', '567890')]
|
||||||
for device, expected in input_output:
|
for device, expected in input_output:
|
||||||
self.assertEqual(expected,
|
self.assertEqual(expected,
|
||||||
ml2_plugin.Ml2Plugin._device_to_port_id(device))
|
ml2_plugin.Ml2Plugin._device_to_port_id(
|
||||||
|
self.context, device))
|
||||||
|
|
||||||
def test__device_to_port_id_mac_address(self):
|
def test__device_to_port_id_mac_address(self):
|
||||||
with self.port() as p:
|
with self.port() as p:
|
||||||
mac = p['port']['mac_address']
|
mac = p['port']['mac_address']
|
||||||
port_id = p['port']['id']
|
port_id = p['port']['id']
|
||||||
self.assertEqual(port_id,
|
self.assertEqual(port_id,
|
||||||
ml2_plugin.Ml2Plugin._device_to_port_id(mac))
|
ml2_plugin.Ml2Plugin._device_to_port_id(
|
||||||
|
self.context, mac))
|
||||||
|
|
||||||
def test__device_to_port_id_not_uuid_not_mac(self):
|
def test__device_to_port_id_not_uuid_not_mac(self):
|
||||||
dev = '1234567'
|
dev = '1234567'
|
||||||
self.assertEqual(dev, ml2_plugin.Ml2Plugin._device_to_port_id(dev))
|
self.assertEqual(dev, ml2_plugin.Ml2Plugin._device_to_port_id(
|
||||||
|
self.context, dev))
|
||||||
|
|
||||||
def test__device_to_port_id_UUID(self):
|
def test__device_to_port_id_UUID(self):
|
||||||
port_id = uuidutils.generate_uuid()
|
port_id = uuidutils.generate_uuid()
|
||||||
self.assertEqual(port_id,
|
self.assertEqual(port_id, ml2_plugin.Ml2Plugin._device_to_port_id(
|
||||||
ml2_plugin.Ml2Plugin._device_to_port_id(port_id))
|
self.context, port_id))
|
||||||
|
|
||||||
|
|
||||||
class TestMl2DvrPortsV2(TestMl2PortsV2):
|
class TestMl2DvrPortsV2(TestMl2PortsV2):
|
||||||
|
@@ -75,14 +75,14 @@ class RpcCallbacksTestCase(base.BaseTestCase):
|
|||||||
self.plugin.get_bound_port_context.return_value = None
|
self.plugin.get_bound_port_context.return_value = None
|
||||||
self.assertEqual(
|
self.assertEqual(
|
||||||
{'device': 'fake_device'},
|
{'device': 'fake_device'},
|
||||||
self.callbacks.get_device_details('fake_context',
|
self.callbacks.get_device_details(mock.Mock(),
|
||||||
device='fake_device'))
|
device='fake_device'))
|
||||||
|
|
||||||
def test_get_device_details_port_context_without_bounded_segment(self):
|
def test_get_device_details_port_context_without_bounded_segment(self):
|
||||||
self.plugin.get_bound_port_context().bottom_bound_segment = None
|
self.plugin.get_bound_port_context().bottom_bound_segment = None
|
||||||
self.assertEqual(
|
self.assertEqual(
|
||||||
{'device': 'fake_device'},
|
{'device': 'fake_device'},
|
||||||
self.callbacks.get_device_details('fake_context',
|
self.callbacks.get_device_details(mock.Mock(),
|
||||||
device='fake_device'))
|
device='fake_device'))
|
||||||
|
|
||||||
def test_get_device_details_port_status_equal_new_status(self):
|
def test_get_device_details_port_status_equal_new_status(self):
|
||||||
@@ -99,7 +99,7 @@ class RpcCallbacksTestCase(base.BaseTestCase):
|
|||||||
port['admin_state_up'] = admin_state_up
|
port['admin_state_up'] = admin_state_up
|
||||||
port['status'] = status
|
port['status'] = status
|
||||||
self.plugin.update_port_status.reset_mock()
|
self.plugin.update_port_status.reset_mock()
|
||||||
self.callbacks.get_device_details('fake_context')
|
self.callbacks.get_device_details(mock.Mock())
|
||||||
self.assertEqual(status == new_status,
|
self.assertEqual(status == new_status,
|
||||||
not self.plugin.update_port_status.called)
|
not self.plugin.update_port_status.called)
|
||||||
|
|
||||||
@@ -109,7 +109,7 @@ class RpcCallbacksTestCase(base.BaseTestCase):
|
|||||||
self.plugin.get_bound_port_context().current = port
|
self.plugin.get_bound_port_context().current = port
|
||||||
self.plugin.get_bound_port_context().network.current = (
|
self.plugin.get_bound_port_context().network.current = (
|
||||||
{"id": "fake_network"})
|
{"id": "fake_network"})
|
||||||
self.callbacks.get_device_details('fake_context', host='fake_host',
|
self.callbacks.get_device_details(mock.Mock(), host='fake_host',
|
||||||
cached_networks=cached_networks)
|
cached_networks=cached_networks)
|
||||||
self.assertTrue('fake_port' in cached_networks)
|
self.assertTrue('fake_port' in cached_networks)
|
||||||
|
|
||||||
@@ -119,7 +119,7 @@ class RpcCallbacksTestCase(base.BaseTestCase):
|
|||||||
port_context.current = port
|
port_context.current = port
|
||||||
port_context.host = 'fake'
|
port_context.host = 'fake'
|
||||||
self.plugin.update_port_status.reset_mock()
|
self.plugin.update_port_status.reset_mock()
|
||||||
self.callbacks.get_device_details('fake_context',
|
self.callbacks.get_device_details(mock.Mock(),
|
||||||
host='fake_host')
|
host='fake_host')
|
||||||
self.assertFalse(self.plugin.update_port_status.called)
|
self.assertFalse(self.plugin.update_port_status.called)
|
||||||
|
|
||||||
@@ -128,7 +128,7 @@ class RpcCallbacksTestCase(base.BaseTestCase):
|
|||||||
port_context = self.plugin.get_bound_port_context()
|
port_context = self.plugin.get_bound_port_context()
|
||||||
port_context.current = port
|
port_context.current = port
|
||||||
self.plugin.update_port_status.reset_mock()
|
self.plugin.update_port_status.reset_mock()
|
||||||
self.callbacks.get_device_details('fake_context')
|
self.callbacks.get_device_details(mock.Mock())
|
||||||
self.assertTrue(self.plugin.update_port_status.called)
|
self.assertTrue(self.plugin.update_port_status.called)
|
||||||
|
|
||||||
def test_get_devices_details_list(self):
|
def test_get_devices_details_list(self):
|
||||||
@@ -155,8 +155,8 @@ class RpcCallbacksTestCase(base.BaseTestCase):
|
|||||||
def _test_update_device_not_bound_to_host(self, func):
|
def _test_update_device_not_bound_to_host(self, func):
|
||||||
self.plugin.port_bound_to_host.return_value = False
|
self.plugin.port_bound_to_host.return_value = False
|
||||||
self.plugin._device_to_port_id.return_value = 'fake_port_id'
|
self.plugin._device_to_port_id.return_value = 'fake_port_id'
|
||||||
res = func('fake_context', device='fake_device', host='fake_host')
|
res = func(mock.Mock(), device='fake_device', host='fake_host')
|
||||||
self.plugin.port_bound_to_host.assert_called_once_with('fake_context',
|
self.plugin.port_bound_to_host.assert_called_once_with(mock.ANY,
|
||||||
'fake_port_id',
|
'fake_port_id',
|
||||||
'fake_host')
|
'fake_host')
|
||||||
return res
|
return res
|
||||||
@@ -176,18 +176,18 @@ class RpcCallbacksTestCase(base.BaseTestCase):
|
|||||||
self.plugin._device_to_port_id.return_value = 'fake_port_id'
|
self.plugin._device_to_port_id.return_value = 'fake_port_id'
|
||||||
self.assertEqual(
|
self.assertEqual(
|
||||||
{'device': 'fake_device', 'exists': False},
|
{'device': 'fake_device', 'exists': False},
|
||||||
self.callbacks.update_device_down('fake_context',
|
self.callbacks.update_device_down(mock.Mock(),
|
||||||
device='fake_device',
|
device='fake_device',
|
||||||
host='fake_host'))
|
host='fake_host'))
|
||||||
self.plugin.update_port_status.assert_called_once_with(
|
self.plugin.update_port_status.assert_called_once_with(
|
||||||
'fake_context', 'fake_port_id', constants.PORT_STATUS_DOWN,
|
mock.ANY, 'fake_port_id', constants.PORT_STATUS_DOWN,
|
||||||
'fake_host')
|
'fake_host')
|
||||||
|
|
||||||
def test_update_device_down_call_update_port_status_failed(self):
|
def test_update_device_down_call_update_port_status_failed(self):
|
||||||
self.plugin.update_port_status.side_effect = exc.StaleDataError
|
self.plugin.update_port_status.side_effect = exc.StaleDataError
|
||||||
self.assertEqual({'device': 'fake_device', 'exists': False},
|
self.assertEqual({'device': 'fake_device', 'exists': False},
|
||||||
self.callbacks.update_device_down(
|
self.callbacks.update_device_down(
|
||||||
'fake_context', device='fake_device'))
|
mock.Mock(), device='fake_device'))
|
||||||
|
|
||||||
|
|
||||||
class RpcApiTestCase(base.BaseTestCase):
|
class RpcApiTestCase(base.BaseTestCase):
|
||||||
|
@@ -19,6 +19,7 @@ import math
|
|||||||
import mock
|
import mock
|
||||||
|
|
||||||
from neutron.common import constants as const
|
from neutron.common import constants as const
|
||||||
|
from neutron import context
|
||||||
from neutron.extensions import securitygroup as ext_sg
|
from neutron.extensions import securitygroup as ext_sg
|
||||||
from neutron import manager
|
from neutron import manager
|
||||||
from neutron.tests import tools
|
from neutron.tests import tools
|
||||||
@@ -51,6 +52,7 @@ class TestMl2SecurityGroups(Ml2SecurityGroupsTestCase,
|
|||||||
test_sg_rpc.SGNotificationTestMixin):
|
test_sg_rpc.SGNotificationTestMixin):
|
||||||
def setUp(self):
|
def setUp(self):
|
||||||
super(TestMl2SecurityGroups, self).setUp()
|
super(TestMl2SecurityGroups, self).setUp()
|
||||||
|
self.ctx = context.get_admin_context()
|
||||||
plugin = manager.NeutronManager.get_plugin()
|
plugin = manager.NeutronManager.get_plugin()
|
||||||
plugin.start_rpc_listeners()
|
plugin.start_rpc_listeners()
|
||||||
|
|
||||||
@@ -75,7 +77,7 @@ class TestMl2SecurityGroups(Ml2SecurityGroupsTestCase,
|
|||||||
]
|
]
|
||||||
plugin = manager.NeutronManager.get_plugin()
|
plugin = manager.NeutronManager.get_plugin()
|
||||||
# should match full ID and starting chars
|
# should match full ID and starting chars
|
||||||
ports = plugin.get_ports_from_devices(
|
ports = plugin.get_ports_from_devices(self.ctx,
|
||||||
[orig_ports[0]['id'], orig_ports[1]['id'][0:8],
|
[orig_ports[0]['id'], orig_ports[1]['id'][0:8],
|
||||||
orig_ports[2]['id']])
|
orig_ports[2]['id']])
|
||||||
self.assertEqual(len(orig_ports), len(ports))
|
self.assertEqual(len(orig_ports), len(ports))
|
||||||
@@ -92,7 +94,7 @@ class TestMl2SecurityGroups(Ml2SecurityGroupsTestCase,
|
|||||||
|
|
||||||
def test_security_group_get_ports_from_devices_with_bad_id(self):
|
def test_security_group_get_ports_from_devices_with_bad_id(self):
|
||||||
plugin = manager.NeutronManager.get_plugin()
|
plugin = manager.NeutronManager.get_plugin()
|
||||||
ports = plugin.get_ports_from_devices(['bad_device_id'])
|
ports = plugin.get_ports_from_devices(self.ctx, ['bad_device_id'])
|
||||||
self.assertFalse(ports)
|
self.assertFalse(ports)
|
||||||
|
|
||||||
def test_security_group_no_db_calls_with_no_ports(self):
|
def test_security_group_no_db_calls_with_no_ports(self):
|
||||||
@@ -100,7 +102,7 @@ class TestMl2SecurityGroups(Ml2SecurityGroupsTestCase,
|
|||||||
with mock.patch(
|
with mock.patch(
|
||||||
'neutron.plugins.ml2.db.get_sg_ids_grouped_by_port'
|
'neutron.plugins.ml2.db.get_sg_ids_grouped_by_port'
|
||||||
) as get_mock:
|
) as get_mock:
|
||||||
self.assertFalse(plugin.get_ports_from_devices([]))
|
self.assertFalse(plugin.get_ports_from_devices(self.ctx, []))
|
||||||
self.assertFalse(get_mock.called)
|
self.assertFalse(get_mock.called)
|
||||||
|
|
||||||
def test_large_port_count_broken_into_parts(self):
|
def test_large_port_count_broken_into_parts(self):
|
||||||
@@ -114,10 +116,10 @@ class TestMl2SecurityGroups(Ml2SecurityGroupsTestCase,
|
|||||||
mock.patch('neutron.plugins.ml2.db.get_sg_ids_grouped_by_port',
|
mock.patch('neutron.plugins.ml2.db.get_sg_ids_grouped_by_port',
|
||||||
return_value={}),
|
return_value={}),
|
||||||
) as (max_mock, get_mock):
|
) as (max_mock, get_mock):
|
||||||
plugin.get_ports_from_devices(
|
plugin.get_ports_from_devices(self.ctx,
|
||||||
['%s%s' % (const.TAP_DEVICE_PREFIX, i)
|
['%s%s' % (const.TAP_DEVICE_PREFIX, i)
|
||||||
for i in range(ports_to_query)])
|
for i in range(ports_to_query)])
|
||||||
all_call_args = map(lambda x: x[1][0], get_mock.mock_calls)
|
all_call_args = map(lambda x: x[1][1], get_mock.mock_calls)
|
||||||
last_call_args = all_call_args.pop()
|
last_call_args = all_call_args.pop()
|
||||||
# all but last should be getting MAX_PORTS_PER_QUERY ports
|
# all but last should be getting MAX_PORTS_PER_QUERY ports
|
||||||
self.assertTrue(
|
self.assertTrue(
|
||||||
@@ -139,14 +141,14 @@ class TestMl2SecurityGroups(Ml2SecurityGroupsTestCase,
|
|||||||
# have one matching 'IN' critiera for all of the IDs
|
# have one matching 'IN' critiera for all of the IDs
|
||||||
with contextlib.nested(
|
with contextlib.nested(
|
||||||
mock.patch('neutron.plugins.ml2.db.or_'),
|
mock.patch('neutron.plugins.ml2.db.or_'),
|
||||||
mock.patch('neutron.plugins.ml2.db.db_api.get_session')
|
mock.patch('sqlalchemy.orm.Session.query')
|
||||||
) as (or_mock, sess_mock):
|
) as (or_mock, qmock):
|
||||||
qmock = sess_mock.return_value.query
|
|
||||||
fmock = qmock.return_value.outerjoin.return_value.filter
|
fmock = qmock.return_value.outerjoin.return_value.filter
|
||||||
# return no ports to exit the method early since we are mocking
|
# return no ports to exit the method early since we are mocking
|
||||||
# the query
|
# the query
|
||||||
fmock.return_value = []
|
fmock.return_value = []
|
||||||
plugin.get_ports_from_devices([test_base._uuid(),
|
plugin.get_ports_from_devices(self.ctx,
|
||||||
|
[test_base._uuid(),
|
||||||
test_base._uuid()])
|
test_base._uuid()])
|
||||||
# the or_ function should only have one argument
|
# the or_ function should only have one argument
|
||||||
or_mock.assert_called_once_with(mock.ANY)
|
or_mock.assert_called_once_with(mock.ANY)
|
||||||
|
@@ -89,7 +89,8 @@ class TestOneConvergenceSecurityGroups(OneConvergenceSecurityGroupsTestCase,
|
|||||||
req.get_response(self.api))
|
req.get_response(self.api))
|
||||||
port_id = res['port']['id']
|
port_id = res['port']['id']
|
||||||
plugin = manager.NeutronManager.get_plugin()
|
plugin = manager.NeutronManager.get_plugin()
|
||||||
port_dict = plugin.get_port_from_device(port_id)
|
port_dict = plugin.get_port_from_device(mock.Mock(),
|
||||||
|
port_id)
|
||||||
self.assertEqual(port_id, port_dict['id'])
|
self.assertEqual(port_id, port_dict['id'])
|
||||||
self.assertEqual([security_group_id],
|
self.assertEqual([security_group_id],
|
||||||
port_dict[ext_sg.SECURITYGROUPS])
|
port_dict[ext_sg.SECURITYGROUPS])
|
||||||
@@ -101,5 +102,5 @@ class TestOneConvergenceSecurityGroups(OneConvergenceSecurityGroupsTestCase,
|
|||||||
def test_security_group_get_port_from_device_with_no_port(self):
|
def test_security_group_get_port_from_device_with_no_port(self):
|
||||||
|
|
||||||
plugin = manager.NeutronManager.get_plugin()
|
plugin = manager.NeutronManager.get_plugin()
|
||||||
port_dict = plugin.get_port_from_device('bad_device_id')
|
port_dict = plugin.get_port_from_device(mock.Mock(), 'bad_device_id')
|
||||||
self.assertIsNone(port_dict)
|
self.assertIsNone(port_dict)
|
||||||
|
Reference in New Issue
Block a user