Reuse caller's session in ML2 DB methods

This patch changes the get_port_from_device_mac() and
get_sg_ids_grouped_by_port() methods in ML2 db.py module so that
they do not create a new database session (via get_session()), but
instead reuse the session associated with the caller's context.

In order to make the session that is associated with the caller's
context available to these ML2 DB methods, the
get_ports_from_devices plugin API in securitygroups_rps_base.py
needs to be modified so that the context can be passed down to the
ML2 plugin. (A similar change is made to the get_port_from_device
plugin API for consistency.)

Change-Id: I3f990895887e156de929bd7ac3732df114dd4a4b
Closes-Bug: 1441205
This commit is contained in:
Dane LeBlanc
2015-04-14 09:18:18 -04:00
parent 94f3a14f54
commit 47dd65cf98
12 changed files with 66 additions and 59 deletions

View File

@@ -76,10 +76,10 @@ class SecurityGroupServerRpcCallback(object):
def plugin(self): def plugin(self):
return manager.NeutronManager.get_plugin() return manager.NeutronManager.get_plugin()
def _get_devices_info(self, devices): def _get_devices_info(self, context, devices):
return dict( return dict(
(port['id'], port) (port['id'], port)
for port in self.plugin.get_ports_from_devices(devices) for port in self.plugin.get_ports_from_devices(context, devices)
if port and not port['device_owner'].startswith('network:') if port and not port['device_owner'].startswith('network:')
) )
@@ -93,7 +93,7 @@ class SecurityGroupServerRpcCallback(object):
:returns: port correspond to the devices with security group rules :returns: port correspond to the devices with security group rules
""" """
devices_info = kwargs.get('devices') devices_info = kwargs.get('devices')
ports = self._get_devices_info(devices_info) ports = self._get_devices_info(context, devices_info)
return self.plugin.security_group_rules_for_ports(context, ports) return self.plugin.security_group_rules_for_ports(context, ports)
def security_group_info_for_devices(self, context, **kwargs): def security_group_info_for_devices(self, context, **kwargs):
@@ -110,7 +110,7 @@ class SecurityGroupServerRpcCallback(object):
Note that sets are serialized into lists by rpc code. Note that sets are serialized into lists by rpc code.
""" """
devices_info = kwargs.get('devices') devices_info = kwargs.get('devices')
ports = self._get_devices_info(devices_info) ports = self._get_devices_info(context, devices_info)
return self.plugin.security_group_info_for_ports(context, ports) return self.plugin.security_group_info_for_ports(context, ports)

View File

@@ -38,7 +38,7 @@ DHCP_RULE_PORT = {4: (67, 68, q_const.IPv4), 6: (547, 546, q_const.IPv6)}
class SecurityGroupServerRpcMixin(sg_db.SecurityGroupDbMixin): class SecurityGroupServerRpcMixin(sg_db.SecurityGroupDbMixin):
"""Mixin class to add agent-based security group implementation.""" """Mixin class to add agent-based security group implementation."""
def get_port_from_device(self, device): def get_port_from_device(self, context, device):
"""Get port dict from device name on an agent. """Get port dict from device name on an agent.
Subclass must provide this method or get_ports_from_devices. Subclass must provide this method or get_ports_from_devices.
@@ -59,13 +59,14 @@ class SecurityGroupServerRpcMixin(sg_db.SecurityGroupDbMixin):
"or get_ports_from_devices.") "or get_ports_from_devices.")
% self.__class__.__name__) % self.__class__.__name__)
def get_ports_from_devices(self, devices): def get_ports_from_devices(self, context, devices):
"""Bulk method of get_port_from_device. """Bulk method of get_port_from_device.
Subclasses may override this to provide better performance for DB Subclasses may override this to provide better performance for DB
queries, backend calls, etc. queries, backend calls, etc.
""" """
return [self.get_port_from_device(device) for device in devices] return [self.get_port_from_device(context, device)
for device in devices]
def create_security_group_rule(self, context, security_group_rule): def create_security_group_rule(self, context, security_group_rule):
bulk_rule = {'security_group_rules': [security_group_rule]} bulk_rule = {'security_group_rules': [security_group_rule]}

View File

@@ -19,7 +19,6 @@ from sqlalchemy import or_
from sqlalchemy.orm import exc from sqlalchemy.orm import exc
from neutron.common import constants as n_const from neutron.common import constants as n_const
from neutron.db import api as db_api
from neutron.db import models_v2 from neutron.db import models_v2
from neutron.db import securitygroups_db as sg_db from neutron.db import securitygroups_db as sg_db
from neutron.extensions import portbindings from neutron.extensions import portbindings
@@ -244,14 +243,14 @@ def get_port(session, port_id):
return return
def get_port_from_device_mac(device_mac): def get_port_from_device_mac(context, device_mac):
LOG.debug("get_port_from_device_mac() called for mac %s", device_mac) LOG.debug("get_port_from_device_mac() called for mac %s", device_mac)
session = db_api.get_session() qry = context.session.query(models_v2.Port).filter_by(
qry = session.query(models_v2.Port).filter_by(mac_address=device_mac) mac_address=device_mac)
return qry.first() return qry.first()
def get_ports_and_sgs(port_ids): def get_ports_and_sgs(context, port_ids):
"""Get ports from database with security group info.""" """Get ports from database with security group info."""
# break large queries into smaller parts # break large queries into smaller parts
@@ -259,25 +258,24 @@ def get_ports_and_sgs(port_ids):
LOG.debug("Number of ports %(pcount)s exceeds the maximum per " LOG.debug("Number of ports %(pcount)s exceeds the maximum per "
"query %(maxp)s. Partitioning queries.", "query %(maxp)s. Partitioning queries.",
{'pcount': len(port_ids), 'maxp': MAX_PORTS_PER_QUERY}) {'pcount': len(port_ids), 'maxp': MAX_PORTS_PER_QUERY})
return (get_ports_and_sgs(port_ids[:MAX_PORTS_PER_QUERY]) + return (get_ports_and_sgs(context, port_ids[:MAX_PORTS_PER_QUERY]) +
get_ports_and_sgs(port_ids[MAX_PORTS_PER_QUERY:])) get_ports_and_sgs(context, port_ids[MAX_PORTS_PER_QUERY:]))
LOG.debug("get_ports_and_sgs() called for port_ids %s", port_ids) LOG.debug("get_ports_and_sgs() called for port_ids %s", port_ids)
if not port_ids: if not port_ids:
# if port_ids is empty, avoid querying to DB to ask it for nothing # if port_ids is empty, avoid querying to DB to ask it for nothing
return [] return []
ports_to_sg_ids = get_sg_ids_grouped_by_port(port_ids) ports_to_sg_ids = get_sg_ids_grouped_by_port(context, port_ids)
return [make_port_dict_with_security_groups(port, sec_groups) return [make_port_dict_with_security_groups(port, sec_groups)
for port, sec_groups in ports_to_sg_ids.iteritems()] for port, sec_groups in ports_to_sg_ids.iteritems()]
def get_sg_ids_grouped_by_port(port_ids): def get_sg_ids_grouped_by_port(context, port_ids):
sg_ids_grouped_by_port = {} sg_ids_grouped_by_port = {}
session = db_api.get_session()
sg_binding_port = sg_db.SecurityGroupPortBinding.port_id sg_binding_port = sg_db.SecurityGroupPortBinding.port_id
with session.begin(subtransactions=True): with context.session.begin(subtransactions=True):
# partial UUIDs must be individually matched with startswith. # partial UUIDs must be individually matched with startswith.
# full UUIDs may be matched directly in an IN statement # full UUIDs may be matched directly in an IN statement
partial_uuids = set(port_id for port_id in port_ids partial_uuids = set(port_id for port_id in port_ids
@@ -288,8 +286,8 @@ def get_sg_ids_grouped_by_port(port_ids):
if full_uuids: if full_uuids:
or_criteria.append(models_v2.Port.id.in_(full_uuids)) or_criteria.append(models_v2.Port.id.in_(full_uuids))
query = session.query(models_v2.Port, query = context.session.query(
sg_db.SecurityGroupPortBinding.security_group_id) models_v2.Port, sg_db.SecurityGroupPortBinding.security_group_id)
query = query.outerjoin(sg_db.SecurityGroupPortBinding, query = query.outerjoin(sg_db.SecurityGroupPortBinding,
models_v2.Port.id == sg_binding_port) models_v2.Port.id == sg_binding_port)
query = query.filter(or_(*or_criteria)) query = query.filter(or_(*or_criteria))

View File

@@ -1451,11 +1451,12 @@ class Ml2Plugin(db_base_plugin_v2.NeutronDbPluginV2,
port_host = db.get_port_binding_host(context.session, port_id) port_host = db.get_port_binding_host(context.session, port_id)
return (port_host == host) return (port_host == host)
def get_ports_from_devices(self, devices): def get_ports_from_devices(self, context, devices):
port_ids_to_devices = dict((self._device_to_port_id(device), device) port_ids_to_devices = dict(
for device in devices) (self._device_to_port_id(context, device), device)
for device in devices)
port_ids = port_ids_to_devices.keys() port_ids = port_ids_to_devices.keys()
ports = db.get_ports_and_sgs(port_ids) ports = db.get_ports_and_sgs(context, port_ids)
for port in ports: for port in ports:
# map back to original requested id # map back to original requested id
port_id = next((port_id for port_id in port_ids port_id = next((port_id for port_id in port_ids
@@ -1465,7 +1466,7 @@ class Ml2Plugin(db_base_plugin_v2.NeutronDbPluginV2,
return ports return ports
@staticmethod @staticmethod
def _device_to_port_id(device): def _device_to_port_id(context, device):
# REVISIT(rkukura): Consider calling into MechanismDrivers to # REVISIT(rkukura): Consider calling into MechanismDrivers to
# process device names, or having MechanismDrivers supply list # process device names, or having MechanismDrivers supply list
# of device prefixes to strip. # of device prefixes to strip.
@@ -1475,7 +1476,7 @@ class Ml2Plugin(db_base_plugin_v2.NeutronDbPluginV2,
# REVISIT(irenab): Consider calling into bound MD to # REVISIT(irenab): Consider calling into bound MD to
# handle the get_device_details RPC # handle the get_device_details RPC
if not uuidutils.is_uuid_like(device): if not uuidutils.is_uuid_like(device):
port = db.get_port_from_device_mac(device) port = db.get_port_from_device_mac(context, device)
if port: if port:
return port.id return port.id
return device return device

View File

@@ -67,7 +67,7 @@ class RpcCallbacks(type_tunnel.TunnelRpcCallbackMixin):
{'device': device, 'agent_id': agent_id, 'host': host}) {'device': device, 'agent_id': agent_id, 'host': host})
plugin = manager.NeutronManager.get_plugin() plugin = manager.NeutronManager.get_plugin()
port_id = plugin._device_to_port_id(device) port_id = plugin._device_to_port_id(rpc_context, device)
port_context = plugin.get_bound_port_context(rpc_context, port_context = plugin.get_bound_port_context(rpc_context,
port_id, port_id,
host, host,
@@ -144,7 +144,7 @@ class RpcCallbacks(type_tunnel.TunnelRpcCallbackMixin):
"%(agent_id)s", "%(agent_id)s",
{'device': device, 'agent_id': agent_id}) {'device': device, 'agent_id': agent_id})
plugin = manager.NeutronManager.get_plugin() plugin = manager.NeutronManager.get_plugin()
port_id = plugin._device_to_port_id(device) port_id = plugin._device_to_port_id(rpc_context, device)
port_exists = True port_exists = True
if (host and not plugin.port_bound_to_host(rpc_context, if (host and not plugin.port_bound_to_host(rpc_context,
port_id, host)): port_id, host)):
@@ -173,7 +173,7 @@ class RpcCallbacks(type_tunnel.TunnelRpcCallbackMixin):
LOG.debug("Device %(device)s up at agent %(agent_id)s", LOG.debug("Device %(device)s up at agent %(agent_id)s",
{'device': device, 'agent_id': agent_id}) {'device': device, 'agent_id': agent_id})
plugin = manager.NeutronManager.get_plugin() plugin = manager.NeutronManager.get_plugin()
port_id = plugin._device_to_port_id(device) port_id = plugin._device_to_port_id(rpc_context, device)
if (host and not plugin.port_bound_to_host(rpc_context, if (host and not plugin.port_bound_to_host(rpc_context,
port_id, host)): port_id, host)):
LOG.debug("Device %(device)s not bound to the" LOG.debug("Device %(device)s not bound to the"

View File

@@ -56,7 +56,7 @@ IPv6 = 6
class SecurityGroupServerRpcMixin(sg_db_rpc.SecurityGroupServerRpcMixin): class SecurityGroupServerRpcMixin(sg_db_rpc.SecurityGroupServerRpcMixin):
@staticmethod @staticmethod
def get_port_from_device(device): def get_port_from_device(context, device):
port = nvsd_db.get_port_from_device(device) port = nvsd_db.get_port_from_device(device)
if port: if port:
port['device'] = device port['device'] = device

View File

@@ -94,7 +94,7 @@ class SecurityGroupRpcTestPlugin(test_sg.SecurityGroupTestPlugin,
self.notify_security_groups_member_updated(context, port) self.notify_security_groups_member_updated(context, port)
del self.devices[id] del self.devices[id]
def get_port_from_device(self, device): def get_port_from_device(self, context, device):
device = self.devices.get(device) device = self.devices.get(device)
if device: if device:
device['security_group_rules'] = [] device['security_group_rules'] = []

View File

@@ -201,7 +201,8 @@ class Ml2DBTestCase(testlib_api.SqlTestCase):
self._setup_neutron_network(network_id) self._setup_neutron_network(network_id)
port = self._setup_neutron_port(network_id, port_id) port = self._setup_neutron_port(network_id, port_id)
observed_port = ml2_db.get_port_from_device_mac(port['mac_address']) observed_port = ml2_db.get_port_from_device_mac(self.ctx,
port['mac_address'])
self.assertEqual(port_id, observed_port.id) self.assertEqual(port_id, observed_port.id)
def test_get_locked_port_and_binding(self): def test_get_locked_port_and_binding(self):

View File

@@ -614,23 +614,26 @@ class TestMl2PluginOnly(Ml2PluginV2TestCase):
('qvo567890', '567890')] ('qvo567890', '567890')]
for device, expected in input_output: for device, expected in input_output:
self.assertEqual(expected, self.assertEqual(expected,
ml2_plugin.Ml2Plugin._device_to_port_id(device)) ml2_plugin.Ml2Plugin._device_to_port_id(
self.context, device))
def test__device_to_port_id_mac_address(self): def test__device_to_port_id_mac_address(self):
with self.port() as p: with self.port() as p:
mac = p['port']['mac_address'] mac = p['port']['mac_address']
port_id = p['port']['id'] port_id = p['port']['id']
self.assertEqual(port_id, self.assertEqual(port_id,
ml2_plugin.Ml2Plugin._device_to_port_id(mac)) ml2_plugin.Ml2Plugin._device_to_port_id(
self.context, mac))
def test__device_to_port_id_not_uuid_not_mac(self): def test__device_to_port_id_not_uuid_not_mac(self):
dev = '1234567' dev = '1234567'
self.assertEqual(dev, ml2_plugin.Ml2Plugin._device_to_port_id(dev)) self.assertEqual(dev, ml2_plugin.Ml2Plugin._device_to_port_id(
self.context, dev))
def test__device_to_port_id_UUID(self): def test__device_to_port_id_UUID(self):
port_id = uuidutils.generate_uuid() port_id = uuidutils.generate_uuid()
self.assertEqual(port_id, self.assertEqual(port_id, ml2_plugin.Ml2Plugin._device_to_port_id(
ml2_plugin.Ml2Plugin._device_to_port_id(port_id)) self.context, port_id))
class TestMl2DvrPortsV2(TestMl2PortsV2): class TestMl2DvrPortsV2(TestMl2PortsV2):

View File

@@ -75,14 +75,14 @@ class RpcCallbacksTestCase(base.BaseTestCase):
self.plugin.get_bound_port_context.return_value = None self.plugin.get_bound_port_context.return_value = None
self.assertEqual( self.assertEqual(
{'device': 'fake_device'}, {'device': 'fake_device'},
self.callbacks.get_device_details('fake_context', self.callbacks.get_device_details(mock.Mock(),
device='fake_device')) device='fake_device'))
def test_get_device_details_port_context_without_bounded_segment(self): def test_get_device_details_port_context_without_bounded_segment(self):
self.plugin.get_bound_port_context().bottom_bound_segment = None self.plugin.get_bound_port_context().bottom_bound_segment = None
self.assertEqual( self.assertEqual(
{'device': 'fake_device'}, {'device': 'fake_device'},
self.callbacks.get_device_details('fake_context', self.callbacks.get_device_details(mock.Mock(),
device='fake_device')) device='fake_device'))
def test_get_device_details_port_status_equal_new_status(self): def test_get_device_details_port_status_equal_new_status(self):
@@ -99,7 +99,7 @@ class RpcCallbacksTestCase(base.BaseTestCase):
port['admin_state_up'] = admin_state_up port['admin_state_up'] = admin_state_up
port['status'] = status port['status'] = status
self.plugin.update_port_status.reset_mock() self.plugin.update_port_status.reset_mock()
self.callbacks.get_device_details('fake_context') self.callbacks.get_device_details(mock.Mock())
self.assertEqual(status == new_status, self.assertEqual(status == new_status,
not self.plugin.update_port_status.called) not self.plugin.update_port_status.called)
@@ -109,7 +109,7 @@ class RpcCallbacksTestCase(base.BaseTestCase):
self.plugin.get_bound_port_context().current = port self.plugin.get_bound_port_context().current = port
self.plugin.get_bound_port_context().network.current = ( self.plugin.get_bound_port_context().network.current = (
{"id": "fake_network"}) {"id": "fake_network"})
self.callbacks.get_device_details('fake_context', host='fake_host', self.callbacks.get_device_details(mock.Mock(), host='fake_host',
cached_networks=cached_networks) cached_networks=cached_networks)
self.assertTrue('fake_port' in cached_networks) self.assertTrue('fake_port' in cached_networks)
@@ -119,7 +119,7 @@ class RpcCallbacksTestCase(base.BaseTestCase):
port_context.current = port port_context.current = port
port_context.host = 'fake' port_context.host = 'fake'
self.plugin.update_port_status.reset_mock() self.plugin.update_port_status.reset_mock()
self.callbacks.get_device_details('fake_context', self.callbacks.get_device_details(mock.Mock(),
host='fake_host') host='fake_host')
self.assertFalse(self.plugin.update_port_status.called) self.assertFalse(self.plugin.update_port_status.called)
@@ -128,7 +128,7 @@ class RpcCallbacksTestCase(base.BaseTestCase):
port_context = self.plugin.get_bound_port_context() port_context = self.plugin.get_bound_port_context()
port_context.current = port port_context.current = port
self.plugin.update_port_status.reset_mock() self.plugin.update_port_status.reset_mock()
self.callbacks.get_device_details('fake_context') self.callbacks.get_device_details(mock.Mock())
self.assertTrue(self.plugin.update_port_status.called) self.assertTrue(self.plugin.update_port_status.called)
def test_get_devices_details_list(self): def test_get_devices_details_list(self):
@@ -155,8 +155,8 @@ class RpcCallbacksTestCase(base.BaseTestCase):
def _test_update_device_not_bound_to_host(self, func): def _test_update_device_not_bound_to_host(self, func):
self.plugin.port_bound_to_host.return_value = False self.plugin.port_bound_to_host.return_value = False
self.plugin._device_to_port_id.return_value = 'fake_port_id' self.plugin._device_to_port_id.return_value = 'fake_port_id'
res = func('fake_context', device='fake_device', host='fake_host') res = func(mock.Mock(), device='fake_device', host='fake_host')
self.plugin.port_bound_to_host.assert_called_once_with('fake_context', self.plugin.port_bound_to_host.assert_called_once_with(mock.ANY,
'fake_port_id', 'fake_port_id',
'fake_host') 'fake_host')
return res return res
@@ -176,18 +176,18 @@ class RpcCallbacksTestCase(base.BaseTestCase):
self.plugin._device_to_port_id.return_value = 'fake_port_id' self.plugin._device_to_port_id.return_value = 'fake_port_id'
self.assertEqual( self.assertEqual(
{'device': 'fake_device', 'exists': False}, {'device': 'fake_device', 'exists': False},
self.callbacks.update_device_down('fake_context', self.callbacks.update_device_down(mock.Mock(),
device='fake_device', device='fake_device',
host='fake_host')) host='fake_host'))
self.plugin.update_port_status.assert_called_once_with( self.plugin.update_port_status.assert_called_once_with(
'fake_context', 'fake_port_id', constants.PORT_STATUS_DOWN, mock.ANY, 'fake_port_id', constants.PORT_STATUS_DOWN,
'fake_host') 'fake_host')
def test_update_device_down_call_update_port_status_failed(self): def test_update_device_down_call_update_port_status_failed(self):
self.plugin.update_port_status.side_effect = exc.StaleDataError self.plugin.update_port_status.side_effect = exc.StaleDataError
self.assertEqual({'device': 'fake_device', 'exists': False}, self.assertEqual({'device': 'fake_device', 'exists': False},
self.callbacks.update_device_down( self.callbacks.update_device_down(
'fake_context', device='fake_device')) mock.Mock(), device='fake_device'))
class RpcApiTestCase(base.BaseTestCase): class RpcApiTestCase(base.BaseTestCase):

View File

@@ -19,6 +19,7 @@ import math
import mock import mock
from neutron.common import constants as const from neutron.common import constants as const
from neutron import context
from neutron.extensions import securitygroup as ext_sg from neutron.extensions import securitygroup as ext_sg
from neutron import manager from neutron import manager
from neutron.tests import tools from neutron.tests import tools
@@ -51,6 +52,7 @@ class TestMl2SecurityGroups(Ml2SecurityGroupsTestCase,
test_sg_rpc.SGNotificationTestMixin): test_sg_rpc.SGNotificationTestMixin):
def setUp(self): def setUp(self):
super(TestMl2SecurityGroups, self).setUp() super(TestMl2SecurityGroups, self).setUp()
self.ctx = context.get_admin_context()
plugin = manager.NeutronManager.get_plugin() plugin = manager.NeutronManager.get_plugin()
plugin.start_rpc_listeners() plugin.start_rpc_listeners()
@@ -75,7 +77,7 @@ class TestMl2SecurityGroups(Ml2SecurityGroupsTestCase,
] ]
plugin = manager.NeutronManager.get_plugin() plugin = manager.NeutronManager.get_plugin()
# should match full ID and starting chars # should match full ID and starting chars
ports = plugin.get_ports_from_devices( ports = plugin.get_ports_from_devices(self.ctx,
[orig_ports[0]['id'], orig_ports[1]['id'][0:8], [orig_ports[0]['id'], orig_ports[1]['id'][0:8],
orig_ports[2]['id']]) orig_ports[2]['id']])
self.assertEqual(len(orig_ports), len(ports)) self.assertEqual(len(orig_ports), len(ports))
@@ -92,7 +94,7 @@ class TestMl2SecurityGroups(Ml2SecurityGroupsTestCase,
def test_security_group_get_ports_from_devices_with_bad_id(self): def test_security_group_get_ports_from_devices_with_bad_id(self):
plugin = manager.NeutronManager.get_plugin() plugin = manager.NeutronManager.get_plugin()
ports = plugin.get_ports_from_devices(['bad_device_id']) ports = plugin.get_ports_from_devices(self.ctx, ['bad_device_id'])
self.assertFalse(ports) self.assertFalse(ports)
def test_security_group_no_db_calls_with_no_ports(self): def test_security_group_no_db_calls_with_no_ports(self):
@@ -100,7 +102,7 @@ class TestMl2SecurityGroups(Ml2SecurityGroupsTestCase,
with mock.patch( with mock.patch(
'neutron.plugins.ml2.db.get_sg_ids_grouped_by_port' 'neutron.plugins.ml2.db.get_sg_ids_grouped_by_port'
) as get_mock: ) as get_mock:
self.assertFalse(plugin.get_ports_from_devices([])) self.assertFalse(plugin.get_ports_from_devices(self.ctx, []))
self.assertFalse(get_mock.called) self.assertFalse(get_mock.called)
def test_large_port_count_broken_into_parts(self): def test_large_port_count_broken_into_parts(self):
@@ -114,10 +116,10 @@ class TestMl2SecurityGroups(Ml2SecurityGroupsTestCase,
mock.patch('neutron.plugins.ml2.db.get_sg_ids_grouped_by_port', mock.patch('neutron.plugins.ml2.db.get_sg_ids_grouped_by_port',
return_value={}), return_value={}),
) as (max_mock, get_mock): ) as (max_mock, get_mock):
plugin.get_ports_from_devices( plugin.get_ports_from_devices(self.ctx,
['%s%s' % (const.TAP_DEVICE_PREFIX, i) ['%s%s' % (const.TAP_DEVICE_PREFIX, i)
for i in range(ports_to_query)]) for i in range(ports_to_query)])
all_call_args = map(lambda x: x[1][0], get_mock.mock_calls) all_call_args = map(lambda x: x[1][1], get_mock.mock_calls)
last_call_args = all_call_args.pop() last_call_args = all_call_args.pop()
# all but last should be getting MAX_PORTS_PER_QUERY ports # all but last should be getting MAX_PORTS_PER_QUERY ports
self.assertTrue( self.assertTrue(
@@ -139,14 +141,14 @@ class TestMl2SecurityGroups(Ml2SecurityGroupsTestCase,
# have one matching 'IN' critiera for all of the IDs # have one matching 'IN' critiera for all of the IDs
with contextlib.nested( with contextlib.nested(
mock.patch('neutron.plugins.ml2.db.or_'), mock.patch('neutron.plugins.ml2.db.or_'),
mock.patch('neutron.plugins.ml2.db.db_api.get_session') mock.patch('sqlalchemy.orm.Session.query')
) as (or_mock, sess_mock): ) as (or_mock, qmock):
qmock = sess_mock.return_value.query
fmock = qmock.return_value.outerjoin.return_value.filter fmock = qmock.return_value.outerjoin.return_value.filter
# return no ports to exit the method early since we are mocking # return no ports to exit the method early since we are mocking
# the query # the query
fmock.return_value = [] fmock.return_value = []
plugin.get_ports_from_devices([test_base._uuid(), plugin.get_ports_from_devices(self.ctx,
[test_base._uuid(),
test_base._uuid()]) test_base._uuid()])
# the or_ function should only have one argument # the or_ function should only have one argument
or_mock.assert_called_once_with(mock.ANY) or_mock.assert_called_once_with(mock.ANY)

View File

@@ -89,7 +89,8 @@ class TestOneConvergenceSecurityGroups(OneConvergenceSecurityGroupsTestCase,
req.get_response(self.api)) req.get_response(self.api))
port_id = res['port']['id'] port_id = res['port']['id']
plugin = manager.NeutronManager.get_plugin() plugin = manager.NeutronManager.get_plugin()
port_dict = plugin.get_port_from_device(port_id) port_dict = plugin.get_port_from_device(mock.Mock(),
port_id)
self.assertEqual(port_id, port_dict['id']) self.assertEqual(port_id, port_dict['id'])
self.assertEqual([security_group_id], self.assertEqual([security_group_id],
port_dict[ext_sg.SECURITYGROUPS]) port_dict[ext_sg.SECURITYGROUPS])
@@ -101,5 +102,5 @@ class TestOneConvergenceSecurityGroups(OneConvergenceSecurityGroupsTestCase,
def test_security_group_get_port_from_device_with_no_port(self): def test_security_group_get_port_from_device_with_no_port(self):
plugin = manager.NeutronManager.get_plugin() plugin = manager.NeutronManager.get_plugin()
port_dict = plugin.get_port_from_device('bad_device_id') port_dict = plugin.get_port_from_device(mock.Mock(), 'bad_device_id')
self.assertIsNone(port_dict) self.assertIsNone(port_dict)