Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
115 changes: 89 additions & 26 deletions cassandra/metadata.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,12 +40,62 @@
from cassandra.marshal import varint_unpack
from cassandra.protocol import QueryMessage
from cassandra.query import dict_factory, bind_params
from cassandra.util import OrderedDict, Version
from cassandra.util import Version
from cassandra.pool import HostDistance
from cassandra.connection import EndPoint
from cassandra.tablets import Tablets
from cassandra.util import maybe_add_timeout_to_query


class _RowView(Mapping):
"""
Lightweight read-only view over a row tuple, supporting dict-like access.
Shares a single index map across all rows from the same result set,
avoiding per-row dict allocation overhead.

Implements the :class:`collections.abc.Mapping` protocol, providing
``__getitem__``, ``__iter__``, ``__len__``, ``get``, ``keys``,
``values``, ``items``, and ``__contains__`` for free.
"""

__slots__ = ("_row", "_index_map")

def __init__(self, row, index_map):
self._row = row
self._index_map = index_map

def __getitem__(self, key):
return self._row[self._index_map[key]]

def __iter__(self):
return iter(self._index_map)

def __len__(self):
return len(self._index_map)

def get(self, key, default=None):
idx = self._index_map.get(key)
if idx is not None:
return self._row[idx]
return default

def __contains__(self, key):
return key in self._index_map

def __repr__(self):
return repr({k: self._row[i] for k, i in self._index_map.items()})


def _row_factory(colnames, rows):
"""
Lightweight replacement for dict_factory used internally by schema parsers.
Returns a list of _RowView objects that support row["key"] and row.get("key")
but store data as tuples with a shared column-name-to-index map.
"""
index_map = {name: i for i, name in enumerate(colnames)}
return [_RowView(row, index_map) for row in rows]


log = logging.getLogger(__name__)

cql_keywords = set((
Expand Down Expand Up @@ -1330,11 +1380,11 @@ def __init__(self, keyspace_name, name, partition_key=None, clustering_key=None,
self.name = name
self.partition_key = [] if partition_key is None else partition_key
self.clustering_key = [] if clustering_key is None else clustering_key
self.columns = OrderedDict() if columns is None else columns
self.columns = {} if columns is None else columns
self.indexes = {}
self.options = {} if options is None else options
self.comparator = None
self.triggers = OrderedDict() if triggers is None else triggers
self.triggers = {} if triggers is None else triggers
self.views = {}
self.virtual = virtual

Expand Down Expand Up @@ -2007,7 +2057,7 @@ def get_next_pages():
yield next_result.parsed_rows

result.parsed_rows += itertools.chain(*get_next_pages())
return dict_factory(result.column_names, result.parsed_rows) if result else []
return _row_factory(result.column_names, result.parsed_rows) if result else []
else:
raise result

Expand Down Expand Up @@ -2569,7 +2619,10 @@ class SchemaParserV3(SchemaParserV22):
"""
_SELECT_KEYSPACES = "SELECT * FROM system_schema.keyspaces"
_SELECT_TABLES = "SELECT * FROM system_schema.tables"
_SELECT_COLUMNS = "SELECT * FROM system_schema.columns"
# Only fetch the columns used by _build_column_metadata / _build_table_columns.
# If _build_column_metadata or _build_table_columns needs more columns, this query
# should be updated accordingly.
_SELECT_COLUMNS = "SELECT keyspace_name, table_name, column_name, clustering_order, kind, position, type FROM system_schema.columns"
_SELECT_INDEXES = "SELECT * FROM system_schema.indexes"
_SELECT_TRIGGERS = "SELECT * FROM system_schema.triggers"
_SELECT_TYPES = "SELECT * FROM system_schema.types"
Expand Down Expand Up @@ -2739,31 +2792,40 @@ def _build_table_options(self, row):
return dict((o, row.get(o)) for o in self.recognized_table_options if o in row)

def _build_table_columns(self, meta, col_rows, compact_static=False, is_dense=False, virtual=False):
# partition key
partition_rows = [r for r in col_rows
if r.get('kind', None) == "partition_key"]
# Single-pass classification of column rows by kind
partition_rows = []
clustering_rows = []
other_rows = []
for r in col_rows:
kind = r.get('kind', None)
if kind == "partition_key":
partition_rows.append(r)
elif kind == "clustering":
if not compact_static:
clustering_rows.append(r)
# else: skip clustering rows entirely for compact_static tables
else:
other_rows.append(r)

# partition key - must be inserted first into meta.columns for CQL export ordering
if len(partition_rows) > 1:
partition_rows = sorted(partition_rows, key=lambda row: row.get('position'))
partition_rows.sort(key=lambda row: row.get('position'))
for r in partition_rows:
# we have to add meta here (and not in the later loop) because TableMetadata.columns is an
# insertion-ordered mapping, and CQL export assumes the key columns were inserted first, in order
column_meta = self._build_column_metadata(meta, r)
meta.columns[column_meta.name] = column_meta
meta.partition_key.append(meta.columns[r.get('column_name')])
meta.partition_key.append(column_meta)

# clustering key
if not compact_static:
clustering_rows = [r for r in col_rows
if r.get('kind', None) == "clustering"]
if clustering_rows:
if len(clustering_rows) > 1:
clustering_rows = sorted(clustering_rows, key=lambda row: row.get('position'))
clustering_rows.sort(key=lambda row: row.get('position'))
for r in clustering_rows:
column_meta = self._build_column_metadata(meta, r)
meta.columns[column_meta.name] = column_meta
meta.clustering_key.append(meta.columns[r.get('column_name')])
meta.clustering_key.append(column_meta)

for col_row in (r for r in col_rows
if r.get('kind', None) not in ('partition_key', 'clustering')):
# remaining columns (static, regular, etc.)
for col_row in other_rows:
column_meta = self._build_column_metadata(meta, col_row)
if is_dense and column_meta.cql_type == types.cql_empty_type:
continue
Expand Down Expand Up @@ -3056,11 +3118,12 @@ def get_all_keyspaces(self):

@staticmethod
def _build_keyspace_metadata_internal(row):
    """
    Build a KeyspaceMetadata from one keyspace row without mutating the row.

    ``row`` only needs mapping-style read access (``row[...]`` and
    ``row.get(...)``), so it may be a read-only view.  ``keyspace_name`` is
    required; ``durable_writes`` and ``replication`` are optional and fall
    back to ``None`` and an empty mapping respectively (the original comment
    notes these fields are not present for virtual keyspaces).

    :param row: mapping-like keyspace row
    :return: KeyspaceMetadata(name, durable_writes, replication_class, replication)
             where ``replication`` no longer contains the ``"class"`` entry
    :raises KeyError: if ``keyspace_name`` is missing from the row
    """
    name = row["keyspace_name"]
    durable_writes = row.get("durable_writes", None)
    # Work on a private copy so popping "class" never touches the row's own
    # data; `or {}` covers both an absent column and a present-but-null value.
    replication = dict(row.get("replication") or {})
    replication_class = replication.pop("class", None)
    return KeyspaceMetadata(name, durable_writes, replication_class, replication)


class SchemaParserDSE67(SchemaParserV4):
Expand Down Expand Up @@ -3328,7 +3391,7 @@ def __init__(self, keyspace_name, view_name, base_table_name, include_all_column
self.base_table_name = base_table_name
self.partition_key = []
self.clustering_key = []
self.columns = OrderedDict()
self.columns = {}
self.include_all_columns = include_all_columns
self.where_clause = where_clause
self.options = options or {}
Expand Down Expand Up @@ -3464,7 +3527,7 @@ def get_column_from_system_local(connection, column_name: str, timeout, metadata
, timeout=timeout, fail_on_error=False)
if not success or not local_result.parsed_rows:
return ""
local_rows = dict_factory(local_result.column_names, local_result.parsed_rows)
local_rows = _row_factory(local_result.column_names, local_result.parsed_rows)
local_row = local_rows[0]
return local_row.get(column_name)

Expand Down
66 changes: 65 additions & 1 deletion tests/unit/test_metadata.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,8 @@
_UnknownStrategy, ColumnMetadata, TableMetadata,
IndexMetadata, Function, Aggregate,
Metadata, TokenMap, ReplicationFactor,
SchemaParserDSE68)
SchemaParserDSE68,
_RowView, _row_factory)
from cassandra.policies import SimpleConvictionPolicy
from cassandra.pool import Host
from cassandra.protocol import QueryMessage
Expand Down Expand Up @@ -846,3 +847,66 @@ def test_strip_frozen(self):
for argument, expected_result in argument_to_expected_results:
result = strip_frozen(argument)
assert result == expected_result, "strip_frozen() arg: {}".format(argument)

class RowViewTest(unittest.TestCase):
    """Tests for the internal _RowView and _row_factory helpers."""

    def test_getitem(self):
        rv = _RowView(("a_val", "b_val"), {"a": 0, "b": 1})
        self.assertEqual(rv["a"], "a_val")
        self.assertEqual(rv["b"], "b_val")

    def test_getitem_missing_key(self):
        rv = _RowView(("a_val",), {"a": 0})
        with self.assertRaises(KeyError):
            rv["missing"]

    def test_get_present(self):
        rv = _RowView(("a_val", "b_val"), {"a": 0, "b": 1})
        self.assertEqual(rv.get("a"), "a_val")
        self.assertEqual(rv.get("b"), "b_val")

    def test_get_present_none_value(self):
        """A present column holding None is returned as-is, not swapped for the default."""
        rv = _RowView((None,), {"a": 0})
        self.assertIsNone(rv.get("a", 42))

    def test_get_missing_returns_default(self):
        rv = _RowView(("a_val",), {"a": 0})
        self.assertIsNone(rv.get("missing"))
        self.assertEqual(rv.get("missing", 42), 42)

    def test_contains(self):
        rv = _RowView(("a_val",), {"a": 0})
        self.assertIn("a", rv)
        self.assertNotIn("b", rv)

    def test_iter_and_len(self):
        """Iteration yields the column names and len() counts them."""
        rv = _RowView(("a_val", "b_val"), {"a": 0, "b": 1})
        self.assertEqual(list(rv), ["a", "b"])
        self.assertEqual(len(rv), 2)

    def test_mapping_mixins(self):
        """keys/values/items and dict() conversion come from the Mapping protocol."""
        rv = _RowView(("a_val", "b_val"), {"a": 0, "b": 1})
        self.assertEqual(list(rv.keys()), ["a", "b"])
        self.assertEqual(list(rv.values()), ["a_val", "b_val"])
        self.assertEqual(list(rv.items()), [("a", "a_val"), ("b", "b_val")])
        self.assertEqual(dict(rv), {"a": "a_val", "b": "b_val"})

    def test_repr(self):
        """repr() matches the repr of the equivalent plain dict."""
        rv = _RowView(("a_val", "b_val"), {"a": 0, "b": 1})
        r = repr(rv)
        self.assertIn("'a'", r)
        self.assertIn("'a_val'", r)
        self.assertEqual(r, repr({"a": "a_val", "b": "b_val"}))

    def test_shared_index_map(self):
        """All _RowView objects from the same _row_factory call share one index map."""
        rows = _row_factory(["x", "y"], [("x1", "y1"), ("x2", "y2")])
        self.assertIs(rows[0]._index_map, rows[1]._index_map)

    def test_read_only(self):
        """_RowView must not allow item assignment or deletion."""
        rv = _RowView(("val",), {"col": 0})
        with self.assertRaises(TypeError):
            rv["col"] = "new"
        with self.assertRaises(TypeError):
            del rv["col"]

    def test_row_factory_empty(self):
        result = _row_factory(["a", "b"], [])
        self.assertEqual(result, [])

    def test_row_factory_single_column(self):
        rows = _row_factory(["only"], [("v1",), ("v2",)])
        self.assertEqual(rows[0]["only"], "v1")
        self.assertEqual(rows[1]["only"], "v2")

    def test_row_factory_values(self):
        rows = _row_factory(["id", "name"], [(1, "alice"), (2, "bob")])
        self.assertEqual(rows[0]["id"], 1)
        self.assertEqual(rows[0]["name"], "alice")
        self.assertEqual(rows[1]["id"], 2)
        self.assertEqual(rows[1]["name"], "bob")
Loading