diff --git a/cassandra/metadata.py b/cassandra/metadata.py index 43399b7152..f6ad7f55b5 100644 --- a/cassandra/metadata.py +++ b/cassandra/metadata.py @@ -40,12 +40,62 @@ from cassandra.marshal import varint_unpack from cassandra.protocol import QueryMessage from cassandra.query import dict_factory, bind_params -from cassandra.util import OrderedDict, Version +from cassandra.util import Version from cassandra.pool import HostDistance from cassandra.connection import EndPoint from cassandra.tablets import Tablets from cassandra.util import maybe_add_timeout_to_query + +class _RowView(Mapping): + """ + Lightweight read-only view over a row tuple, supporting dict-like access. + Shares a single index map across all rows from the same result set, + avoiding per-row dict allocation overhead. + + Implements the :class:`collections.abc.Mapping` protocol, providing + ``__getitem__``, ``__iter__``, ``__len__``, ``get``, ``keys``, + ``values``, ``items``, and ``__contains__`` for free. + """ + + __slots__ = ("_row", "_index_map") + + def __init__(self, row, index_map): + self._row = row + self._index_map = index_map + + def __getitem__(self, key): + return self._row[self._index_map[key]] + + def __iter__(self): + return iter(self._index_map) + + def __len__(self): + return len(self._index_map) + + def get(self, key, default=None): + idx = self._index_map.get(key) + if idx is not None: + return self._row[idx] + return default + + def __contains__(self, key): + return key in self._index_map + + def __repr__(self): + return repr({k: self._row[i] for k, i in self._index_map.items()}) + + +def _row_factory(colnames, rows): + """ + Lightweight replacement for dict_factory used internally by schema parsers. + Returns a list of _RowView objects that support row["key"] and row.get("key") + but store data as tuples with a shared column-name-to-index map. + """ + index_map = {name: i for i, name in enumerate(colnames)} + return [_RowView(row, index_map) for row in rows] + + log = logging.getLogger(__name__) cql_keywords = set(( @@ -1330,11 +1380,11 @@ def __init__(self, keyspace_name, name, partition_key=None, clustering_key=None, self.name = name self.partition_key = [] if partition_key is None else partition_key self.clustering_key = [] if clustering_key is None else clustering_key - self.columns = OrderedDict() if columns is None else columns + self.columns = {} if columns is None else columns self.indexes = {} self.options = {} if options is None else options self.comparator = None - self.triggers = OrderedDict() if triggers is None else triggers + self.triggers = {} if triggers is None else triggers self.views = {} self.virtual = virtual @@ -2007,7 +2057,7 @@ def get_next_pages(): yield next_result.parsed_rows result.parsed_rows += itertools.chain(*get_next_pages()) - return dict_factory(result.column_names, result.parsed_rows) if result else [] + return _row_factory(result.column_names, result.parsed_rows) if result else [] else: raise result @@ -2569,7 +2619,10 @@ class SchemaParserV3(SchemaParserV22): """ _SELECT_KEYSPACES = "SELECT * FROM system_schema.keyspaces" _SELECT_TABLES = "SELECT * FROM system_schema.tables" - _SELECT_COLUMNS = "SELECT * FROM system_schema.columns" + # Only fetch the columns used by _build_column_metadata / _build_table_columns. + # If _build_column_metadata or _build_table_columns needs more columns, this query + # should be updated accordingly. + _SELECT_COLUMNS = "SELECT keyspace_name, table_name, column_name, clustering_order, kind, position, type FROM system_schema.columns" _SELECT_INDEXES = "SELECT * FROM system_schema.indexes" _SELECT_TRIGGERS = "SELECT * FROM system_schema.triggers" _SELECT_TYPES = "SELECT * FROM system_schema.types" @@ -2739,31 +2792,40 @@ def _build_table_options(self, row): return dict((o, row.get(o)) for o in self.recognized_table_options if o in row) def _build_table_columns(self, meta, col_rows, compact_static=False, is_dense=False, virtual=False): - # partition key - partition_rows = [r for r in col_rows - if r.get('kind', None) == "partition_key"] + # Single-pass classification of column rows by kind + partition_rows = [] + clustering_rows = [] + other_rows = [] + for r in col_rows: + kind = r.get('kind', None) + if kind == "partition_key": + partition_rows.append(r) + elif kind == "clustering": + if not compact_static: + clustering_rows.append(r) + # else: skip clustering rows entirely for compact_static tables + else: + other_rows.append(r) + + # partition key - must be inserted first into meta.columns for CQL export ordering if len(partition_rows) > 1: - partition_rows = sorted(partition_rows, key=lambda row: row.get('position')) + partition_rows.sort(key=lambda row: row.get('position')) for r in partition_rows: - # we have to add meta here (and not in the later loop) because TableMetadata.columns is an - # OrderedDict, and it assumes keys are inserted first, in order, when exporting CQL column_meta = self._build_column_metadata(meta, r) meta.columns[column_meta.name] = column_meta - meta.partition_key.append(meta.columns[r.get('column_name')]) + meta.partition_key.append(column_meta) # clustering key - if not compact_static: - clustering_rows = [r for r in col_rows - if r.get('kind', None) == "clustering"] + if clustering_rows: if len(clustering_rows) > 1: - clustering_rows = sorted(clustering_rows, key=lambda row: row.get('position')) + clustering_rows.sort(key=lambda row: row.get('position')) for r in clustering_rows: column_meta = self._build_column_metadata(meta, r) meta.columns[column_meta.name] = column_meta - meta.clustering_key.append(meta.columns[r.get('column_name')]) + meta.clustering_key.append(column_meta) - for col_row in (r for r in col_rows - if r.get('kind', None) not in ('partition_key', 'clustering')): + # remaining columns (static, regular, etc.) + for col_row in other_rows: column_meta = self._build_column_metadata(meta, col_row) if is_dense and column_meta.cql_type == types.cql_empty_type: continue @@ -3056,11 +3118,12 @@ def get_all_keyspaces(self): @staticmethod def _build_keyspace_metadata_internal(row): - # necessary fields that aren't int virtual ks - row["durable_writes"] = row.get("durable_writes", None) - row["replication"] = row.get("replication", {}) - row["replication"]["class"] = row["replication"].get("class", None) - return super(SchemaParserV4, SchemaParserV4)._build_keyspace_metadata_internal(row) + # Read without mutating the row, since _RowView is read-only + name = row["keyspace_name"] + durable_writes = row.get("durable_writes", None) + replication = dict(row.get("replication")) if 'replication' in row else {} + replication_class = replication.pop("class") if 'class' in replication else None + return KeyspaceMetadata(name, durable_writes, replication_class, replication) class SchemaParserDSE67(SchemaParserV4): @@ -3328,7 +3391,7 @@ def __init__(self, keyspace_name, view_name, base_table_name, include_all_column self.base_table_name = base_table_name self.partition_key = [] self.clustering_key = [] - self.columns = OrderedDict() + self.columns = {} self.include_all_columns = include_all_columns self.where_clause = where_clause self.options = options or {} @@ -3464,7 +3527,7 @@ def get_column_from_system_local(connection, column_name: str, timeout, metadata , timeout=timeout, fail_on_error=False) if not success or not local_result.parsed_rows: return "" - local_rows = dict_factory(local_result.column_names, local_result.parsed_rows) + local_rows = _row_factory(local_result.column_names, local_result.parsed_rows) local_row = local_rows[0] return local_row.get(column_name) diff --git a/tests/unit/test_metadata.py b/tests/unit/test_metadata.py index dcbb840447..6d33ccec93 100644 --- a/tests/unit/test_metadata.py +++ b/tests/unit/test_metadata.py @@ -32,7 +32,8 @@ _UnknownStrategy, ColumnMetadata, TableMetadata, IndexMetadata, Function, Aggregate, Metadata, TokenMap, ReplicationFactor, - SchemaParserDSE68) + SchemaParserDSE68, + _RowView, _row_factory) from cassandra.policies import SimpleConvictionPolicy from cassandra.pool import Host from cassandra.protocol import QueryMessage @@ -846,3 +847,66 @@ def test_strip_frozen(self): for argument, expected_result in argument_to_expected_results: result = strip_frozen(argument) assert result == expected_result, "strip_frozen() arg: {}".format(argument) + +class RowViewTest(unittest.TestCase): + """Tests for the internal _RowView and _row_factory helpers.""" + + def test_getitem(self): + rv = _RowView(("a_val", "b_val"), {"a": 0, "b": 1}) + self.assertEqual(rv["a"], "a_val") + self.assertEqual(rv["b"], "b_val") + + def test_getitem_missing_key(self): + rv = _RowView(("a_val",), {"a": 0}) + with self.assertRaises(KeyError): + rv["missing"] + + def test_get_present(self): + rv = _RowView(("a_val", "b_val"), {"a": 0, "b": 1}) + self.assertEqual(rv.get("a"), "a_val") + self.assertEqual(rv.get("b"), "b_val") + + def test_get_missing_returns_default(self): + rv = _RowView(("a_val",), {"a": 0}) + self.assertIsNone(rv.get("missing")) + self.assertEqual(rv.get("missing", 42), 42) + + def test_contains(self): + rv = _RowView(("a_val",), {"a": 0}) + self.assertIn("a", rv) + self.assertNotIn("b", rv) + + def test_repr(self): + rv = _RowView(("a_val", "b_val"), {"a": 0, "b": 1}) + r = repr(rv) + self.assertIn("'a'", r) + self.assertIn("'a_val'", r) + + def test_shared_index_map(self): + """All _RowView objects from the same _row_factory call share one index map.""" + rows = _row_factory(["x", "y"], [("x1", "y1"), ("x2", "y2")]) + self.assertIs(rows[0]._index_map, rows[1]._index_map) + + def test_read_only(self): + """_RowView must not allow item assignment or deletion.""" + rv = _RowView(("val",), {"col": 0}) + with self.assertRaises(TypeError): + rv["col"] = "new" + with self.assertRaises(TypeError): + del rv["col"] + + def test_row_factory_empty(self): + result = _row_factory(["a", "b"], []) + self.assertEqual(result, []) + + def test_row_factory_single_column(self): + rows = _row_factory(["only"], [("v1",), ("v2",)]) + self.assertEqual(rows[0]["only"], "v1") + self.assertEqual(rows[1]["only"], "v2") + + def test_row_factory_values(self): + rows = _row_factory(["id", "name"], [(1, "alice"), (2, "bob")]) + self.assertEqual(rows[0]["id"], 1) + self.assertEqual(rows[0]["name"], "alice") + self.assertEqual(rows[1]["id"], 2) + self.assertEqual(rows[1]["name"], "bob")