From 54c4f6eaeae9bfc2468cb38a4d5587135a66653e Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Thu, 23 Oct 2025 14:14:20 -0700 Subject: [PATCH 01/14] fixed tests --- tests/unit/v1/test_pipeline_expressions.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/tests/unit/v1/test_pipeline_expressions.py b/tests/unit/v1/test_pipeline_expressions.py index 9f06c47b8..0e96d06e6 100644 --- a/tests/unit/v1/test_pipeline_expressions.py +++ b/tests/unit/v1/test_pipeline_expressions.py @@ -370,7 +370,7 @@ def test__from_query_filter_pb_composite_filter_or(self, mock_client): field1 = Field.of("field1") field2 = Field.of("field2") expected_cond1 = expr.And(field1.exists(), field1.equal(Constant("val1"))) - expected_cond2 = expr.And(field2.exists(), field2.equal(Constant(None))) + expected_cond2 = expr.And(field2.exists(), field2.is_null()) expected = expr.Or(expected_cond1, expected_cond2) assert repr(result) == repr(expected) @@ -458,7 +458,7 @@ def test__from_query_filter_pb_composite_filter_nested(self, mock_client): expected_cond1 = expr.And(field1.exists(), field1.equal(Constant("val1"))) expected_cond2 = expr.And(field2.exists(), field2.greater_than(Constant(10))) expected_cond3 = expr.And( - field3.exists(), expr.Not(field3.equal(Constant(None))) + field3.exists(), expr.Not(field3.is_null()) ) expected_inner_and = expr.And(expected_cond2, expected_cond3) expected_outer_or = expr.Or(expected_cond1, expected_inner_and) @@ -495,11 +495,11 @@ def test__from_query_filter_pb_composite_filter_unknown_op(self, mock_client): ), ( query_pb.StructuredQuery.UnaryFilter.Operator.IS_NULL, - lambda f: f.equal(None), + lambda f: f.is_null(), ), ( query_pb.StructuredQuery.UnaryFilter.Operator.IS_NOT_NULL, - lambda f: expr.Not(f.equal(None)), + lambda f: expr.Not(f.is_null()), ), ], ) @@ -836,7 +836,7 @@ def test_is_nan(self): def test_is_null(self): arg1 = self._make_arg("Value") - instance = Expr.is_ull(arg1) + instance = Expr.is_null(arg1) assert 
instance.name == "is_null" assert instance.params == [arg1] assert repr(instance) == "Value.is_null()" From 107e412cdeddeb3b0b53507fac7e6d8c89d917d6 Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Thu, 23 Oct 2025 16:23:50 -0700 Subject: [PATCH 02/14] added vector expressions --- google/cloud/firestore_v1/_pipeline_stages.py | 6 +- .../firestore_v1/pipeline_expressions.py | 62 ++++++++ tests/system/pipeline_e2e.yaml | 148 ++++++++++++++++++ tests/system/test_pipeline_acceptance.py | 3 +- tests/unit/v1/test_pipeline_expressions.py | 47 ++++++ 5 files changed, 263 insertions(+), 3 deletions(-) diff --git a/google/cloud/firestore_v1/_pipeline_stages.py b/google/cloud/firestore_v1/_pipeline_stages.py index 7233a8eec..2ada1c90f 100644 --- a/google/cloud/firestore_v1/_pipeline_stages.py +++ b/google/cloud/firestore_v1/_pipeline_stages.py @@ -274,13 +274,15 @@ def __init__( self, field: str | Expr, vector: Sequence[float] | Vector, - distance_measure: "DistanceMeasure", + distance_measure: "DistanceMeasure" | str, options: Optional["FindNearestOptions"] = None, ): super().__init__("find_nearest") self.field: Expr = Field(field) if isinstance(field, str) else field self.vector: Vector = vector if isinstance(vector, Vector) else Vector(vector) - self.distance_measure = distance_measure + self.distance_measure = distance_measure if isinstance( + distance_measure, DistanceMeasure + ) else DistanceMeasure[distance_measure.upper()] self.options = options or FindNearestOptions() def _pb_args(self): diff --git a/google/cloud/firestore_v1/pipeline_expressions.py b/google/cloud/firestore_v1/pipeline_expressions.py index 4639e0f7d..9b510a7e4 100644 --- a/google/cloud/firestore_v1/pipeline_expressions.py +++ b/google/cloud/firestore_v1/pipeline_expressions.py @@ -116,6 +116,14 @@ def _to_pb(self) -> Value: def _cast_to_expr_or_convert_to_constant(o: Any) -> "Expr": return o if isinstance(o, Expr) else Constant(o) + @staticmethod + def _convert_to_vector_expr(o: list[float] | 
Vector | Expr) -> Expr: + if isinstance(o, list): + o = Vector(o) + elif isinstance(o, Constant) and isinstance(o.value, list): + o = Vector(o.value) + return Expr._cast_to_expr_or_convert_to_constant(o) + class expose_as_static: """ Decorator to mark instance methods to be exposed as static methods as well as instance @@ -864,6 +872,60 @@ def map_get(self, key: str | Constant[str]) -> "Expr": "map_get", [self, Constant.of(key) if isinstance(key, str) else key] ) + @expose_as_static + def cosine_distance(self, other: Expr | list[float] | Vector) -> "Expr": + """Calculates the cosine distance between two vectors. + + Example: + >>> # Calculate the cosine distance between the 'userVector' field and the 'itemVector' field + >>> Field.of("userVector").cosine_distance(Field.of("itemVector")) + >>> # Calculate the Cosine distance between the 'location' field and a target location + >>> Field.of("location").cosine_distance([37.7749, -122.4194]) + + Args: + other: The other vector (represented as an Expr, list of floats, or Vector) to compare against. + + Returns: + A new `Expr` representing the cosine distance between the two vectors. + """ + return Function("cosine_distance", [self, self._convert_to_vector_expr(other)]) + + @expose_as_static + def euclidean_distance(self, other: Expr | list[float] | Vector) -> "Expr": + """Calculates the Euclidean distance between two vectors. + + Example: + >>> # Calculate the Euclidean distance between the 'location' field and a target location + >>> Field.of("location").euclidean_distance([37.7749, -122.4194]) + >>> # Calculate the Euclidean distance between two vector fields: 'pointA' and 'pointB' + >>> Field.of("pointA").euclidean_distance(Field.of("pointB")) + + Args: + other: The other vector (represented as an Expr, list of floats, or Vector) to compare against. + + Returns: + A new `Expr` representing the Euclidean distance between the two vectors. 
+ """ + return Function("euclidean_distance", [self, self._convert_to_vector_expr(other)]) + + @expose_as_static + def dot_product(self, other: Expr | list[float] | Vector) -> "Expr": + """Calculates the dot product between two vectors. + + Example: + >>> # Calculate the dot product between a feature vector and a target vector + >>> Field.of("features").dot_product([0.5, 0.8, 0.2]) + >>> # Calculate the dot product between two document vectors: 'docVector1' and 'docVector2' + >>> Field.of("docVector1").dot_product(Field.of("docVector2")) + + Args: + other: The other vector (represented as an Expr, list of floats, or Vector) to calculate dot product with. + + Returns: + A new `Expr` representing the dot product between the two vectors. + """ + return Function("dot_product", [self, self._convert_to_vector_expr(other)]) + @expose_as_static def vector_length(self) -> "Expr": """Creates an expression that calculates the length (dimension) of a Firestore Vector. diff --git a/tests/system/pipeline_e2e.yaml b/tests/system/pipeline_e2e.yaml index 50cc7c29d..50fc8e6ab 100644 --- a/tests/system/pipeline_e2e.yaml +++ b/tests/system/pipeline_e2e.yaml @@ -136,6 +136,10 @@ data: embedding: [1.0, 2.0, 3.0] vec2: embedding: [4.0, 5.0, 6.0, 7.0] + vec3: + embedding: [5.0, 6.0, 7.0] + vec4: + embedding: [1.0, 2.0, 4.0] tests: - description: "testAggregates - count" pipeline: @@ -1782,6 +1786,8 @@ tests: - Field: embedding_length - ASCENDING assert_results: + - embedding_length: 3 + - embedding_length: 3 - embedding_length: 3 - embedding_length: 4 - description: testTimestampFunctions @@ -1956,3 +1962,145 @@ tests: conditional_field: "Dystopian" - title: "Dune" conditional_field: "Frank Herbert" + - description: testFindNearestEuclidean + pipeline: + - Collection: vectors + - FindNearest: + field: embedding + vector: [1.0, 2.0, 3.0] + distance_measure: EUCLIDEAN + options: + FindNearestOptions: + limit: 2 + distance_field: + Field: distance + - Select: + - distance + assert_results: + 
- distance: 0.0 + - distance: 1.0 + assert_proto: + pipeline: + stages: + - name: collection + args: + - referenceValue: /vectors + - name: find_nearest + args: + - fieldReferenceValue: embedding + - mapValue: + fields: + __type__: + stringValue: __vector__ + value: + arrayValue: + values: + - doubleValue: 1.0 + - doubleValue: 2.0 + - doubleValue: 3.0 + - stringValue: euclidean + options: + limit: + integerValue: '2' + distance_field: + fieldReferenceValue: distance + - name: select + args: + - mapValue: + fields: + distance: + fieldReferenceValue: distance + - description: testFindNearestDotProduct + pipeline: + - Collection: vectors + - FindNearest: + field: embedding + vector: [1.0, 2.0, 3.0] + distance_measure: DOT_PRODUCT + options: + FindNearestOptions: + limit: 3 + distance_field: + Field: distance + - Select: + - distance + assert_results: + - distance: 38.0 + - distance: 17.0 + - distance: 14.0 + assert_proto: + pipeline: + stages: + - name: collection + args: + - referenceValue: /vectors + - name: find_nearest + args: + - fieldReferenceValue: embedding + - mapValue: + fields: + __type__: + stringValue: __vector__ + value: + arrayValue: + values: + - doubleValue: 1.0 + - doubleValue: 2.0 + - doubleValue: 3.0 + - stringValue: dot_product + options: + limit: + integerValue: '3' + distance_field: + fieldReferenceValue: distance + - name: select + args: + - mapValue: + fields: + distance: + fieldReferenceValue: distance + - description: testDotProductWithConstant + pipeline: + - Collection: vectors + - Where: + - Expr.equal: + - Field: embedding + - Vector: [1.0, 2.0, 3.0] + - Select: + - AliasedExpr: + - Expr.dot_product: + - Field: embedding + - Vector: [1.0, 1.0, 1.0] + - "dot_product_result" + assert_results: + - dot_product_result: 6.0 + - description: testEuclideanDistanceWithConstant + pipeline: + - Collection: vectors + - Where: + - Expr.equal: + - Field: embedding + - Vector: [1.0, 2.0, 3.0] + - Select: + - AliasedExpr: + - Expr.euclidean_distance: + 
- Field: embedding + - Vector: [1.0, 2.0, 3.0] + - "euclidean_distance_result" + assert_results: + - euclidean_distance_result: 0.0 + - description: testCosineDistanceWithConstant + pipeline: + - Collection: vectors + - Where: + - Expr.equal: + - Field: embedding + - Vector: [1.0, 2.0, 3.0] + - Select: + - AliasedExpr: + - Expr.cosine_distance: + - Field: embedding + - Vector: [1.0, 2.0, 3.0] + - "cosine_distance_result" + assert_results: + - cosine_distance_result: 0.0 \ No newline at end of file diff --git a/tests/system/test_pipeline_acceptance.py b/tests/system/test_pipeline_acceptance.py index d4c654e63..fe1d7659b 100644 --- a/tests/system/test_pipeline_acceptance.py +++ b/tests/system/test_pipeline_acceptance.py @@ -28,6 +28,7 @@ from google.cloud.firestore_v1 import _pipeline_stages as stages from google.cloud.firestore_v1 import pipeline_expressions from google.cloud.firestore_v1.vector import Vector +from google.cloud.firestore_v1 import pipeline_expressions as expr from google.api_core.exceptions import GoogleAPIError from google.cloud.firestore import Client, AsyncClient @@ -218,7 +219,7 @@ def _apply_yaml_args_to_callable(callable_obj, client, yaml_args): """ if isinstance(yaml_args, dict): return callable_obj(**_parse_expressions(client, yaml_args)) - elif isinstance(yaml_args, list): + elif isinstance(yaml_args, list) and not (callable_obj == expr.Constant or callable_obj == Vector): # yaml has an array of arguments. 
Treat as args return callable_obj(*_parse_expressions(client, yaml_args)) else: diff --git a/tests/unit/v1/test_pipeline_expressions.py b/tests/unit/v1/test_pipeline_expressions.py index 0e96d06e6..4fa3dd8e8 100644 --- a/tests/unit/v1/test_pipeline_expressions.py +++ b/tests/unit/v1/test_pipeline_expressions.py @@ -1097,6 +1097,53 @@ def test_unix_seconds_to_timestamp(self): infix_instance = arg1.unix_seconds_to_timestamp() assert infix_instance == instance + def test_euclidean_distance(self): + arg1 = self._make_arg("Vector1") + arg2 = self._make_arg("Vector2") + instance = Expr.euclidean_distance(arg1, arg2) + assert instance.name == "euclidean_distance" + assert instance.params == [arg1, arg2] + assert repr(instance) == "Vector1.euclidean_distance(Vector2)" + infix_instance = arg1.euclidean_distance(arg2) + assert infix_instance == instance + + def test_cosine_distance(self): + arg1 = self._make_arg("Vector1") + arg2 = self._make_arg("Vector2") + instance = Expr.cosine_distance(arg1, arg2) + assert instance.name == "cosine_distance" + assert instance.params == [arg1, arg2] + assert repr(instance) == "Vector1.cosine_distance(Vector2)" + infix_instance = arg1.cosine_distance(arg2) + assert infix_instance == instance + + def test_dot_product(self): + arg1 = self._make_arg("Vector1") + arg2 = self._make_arg("Vector2") + instance = Expr.dot_product(arg1, arg2) + assert instance.name == "dot_product" + assert instance.params == [arg1, arg2] + assert repr(instance) == "Vector1.dot_product(Vector2)" + infix_instance = arg1.dot_product(arg2) + assert infix_instance == instance + + @pytest.mark.parametrize("method", ["euclidean_distance", "cosine_distance", "dot_product"]) + @pytest.mark.parametrize( + "input", [Vector([1.0, 2.0]), [1, 2], Constant.of(Vector([1.0, 2.0])), Constant.of([1, 2]), []] + ) + def test_vector_ctor(self, method, input): + """ + test constructing various vector expressions with + different inputs + """ + arg1 = self._make_arg("Vector") + instance = 
getattr(arg1, method)(input) + assert instance.name == method + got_second_param = instance.params[1] + assert isinstance(got_second_param, Constant) + assert isinstance(got_second_param.value, Vector) + + def test_vector_length(self): arg1 = self._make_arg("Array") instance = Expr.vector_length(arg1) From a0c36ca4f5b744e4f8cf9cea9dbf9a54d097e7b4 Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Fri, 24 Oct 2025 10:44:58 -0700 Subject: [PATCH 03/14] added new math expressions --- .../firestore_v1/pipeline_expressions.py | 129 +++++++++++++ tests/system/pipeline_e2e.yaml | 172 ++++++++++++++++++ tests/system/test_pipeline_acceptance.py | 6 + tests/unit/v1/test_pipeline_expressions.py | 83 +++++++++ 4 files changed, 390 insertions(+) diff --git a/google/cloud/firestore_v1/pipeline_expressions.py b/google/cloud/firestore_v1/pipeline_expressions.py index 9b510a7e4..ecb51f2d0 100644 --- a/google/cloud/firestore_v1/pipeline_expressions.py +++ b/google/cloud/firestore_v1/pipeline_expressions.py @@ -247,6 +247,135 @@ def mod(self, other: Expr | float) -> "Expr": """ return Function("mod", [self, self._cast_to_expr_or_convert_to_constant(other)]) + @expose_as_static + def abs(self) -> "Expr": + """Creates an expression that calculates the absolute value of this expression. + + Example: + >>> # Get the absolute value of the 'change' field. + >>> Field.of("change").abs() + + Returns: + A new `Expr` representing the absolute value. + """ + return Function("abs", [self]) + + @expose_as_static + def ceil(self) -> "Expr": + """Creates an expression that calculates the ceiling of this expression. + + Example: + >>> # Get the ceiling of the 'value' field. + >>> Field.of("value").ceil() + + Returns: + A new `Expr` representing the ceiling value. + """ + return Function("ceil", [self]) + + @expose_as_static + def exp(self) -> "Expr": + """Creates an expression that computes e to the power of this expression. 
+ + Example: + >>> # Compute e to the power of the 'value' field + >>> Field.of("value").exp() + + Returns: + A new `Expr` representing the exponential value. + """ + return Function("exp", [self]) + + @expose_as_static + def floor(self) -> "Expr": + """Creates an expression that calculates the floor of this expression. + + Example: + >>> # Get the floor of the 'value' field. + >>> Field.of("value").floor() + + Returns: + A new `Expr` representing the floor value. + """ + return Function("floor", [self]) + + @expose_as_static + def ln(self) -> "Expr": + """Creates an expression that calculates the natural logarithm of this expression. + + Example: + >>> # Get the natural logarithm of the 'value' field. + >>> Field.of("value").ln() + + Returns: + A new `Expr` representing the natural logarithm. + """ + return Function("ln", [self]) + + @expose_as_static + def log(self, base: Expr | float) -> "Expr": + """Creates an expression that calculates the logarithm of this expression with a given base. + + Example: + >>> # Get the logarithm of 'value' with base 2. + >>> Field.of("value").log(2) + >>> # Get the logarithm of 'value' with base from 'base_field'. + >>> Field.of("value").log(Field.of("base_field")) + + Args: + base: The base of the logarithm. + + Returns: + A new `Expr` representing the logarithm. + """ + return Function( + "log", [self, self._cast_to_expr_or_convert_to_constant(base)] + ) + + @expose_as_static + def pow(self, exponent: Expr | float) -> "Expr": + """Creates an expression that calculates this expression raised to the power of the exponent. + + Example: + >>> # Raise 'base_val' to the power of 2. + >>> Field.of("base_val").pow(2) + >>> # Raise 'base_val' to the power of 'exponent_val'. + >>> Field.of("base_val").pow(Field.of("exponent_val")) + + Args: + exponent: The exponent. + + Returns: + A new `Expr` representing the power operation. 
+ """ + return Function("pow", [self, self._cast_to_expr_or_convert_to_constant(exponent)]) + + @expose_as_static + def round(self) -> "Expr": + """Creates an expression that rounds this expression to the nearest integer. + + Example: + >>> # Round the 'value' field. + >>> Field.of("value").round() + + Returns: + A new `Expr` representing the rounded value. + """ + return Function("round", [self]) + + @expose_as_static + def sqrt(self) -> "Expr": + """Creates an expression that calculates the square root of this expression. + + Example: + >>> # Get the square root of the 'area' field. + >>> Field.of("area").sqrt() + + Returns: + A new `Expr` representing the square root. + """ + return Function("sqrt", [self]) + @expose_as_static def logical_maximum(self, other: Expr | CONSTANT_TYPE) -> "Expr": """Creates an expression that returns the larger value between this expression diff --git a/tests/system/pipeline_e2e.yaml b/tests/system/pipeline_e2e.yaml index 50fc8e6ab..f50efbc23 100644 --- a/tests/system/pipeline_e2e.yaml +++ b/tests/system/pipeline_e2e.yaml @@ -2010,6 +2010,178 @@ tests: fields: distance: fieldReferenceValue: distance + - description: testMathExpressions + pipeline: + - Collection: books + - Where: + - Expr.equal: + - Field: title + - Constant: To Kill a Mockingbird + - Select: + - AliasedExpr: + - Expr.abs: + - Field: rating + - "abs_rating" + - AliasedExpr: + - Expr.ceil: + - Field: rating + - "ceil_rating" + - AliasedExpr: + - Expr.exp: + - Field: rating + - "exp_rating" + - AliasedExpr: + - Expr.floor: + - Field: rating + - "floor_rating" + - AliasedExpr: + - Expr.ln: + - Field: rating + - "ln_rating" + - AliasedExpr: + - Expr.log: + - Field: rating + - Constant: 10 + - "log_rating_base10" + - AliasedExpr: + - Expr.pow: + - Field: rating + - Constant: 2 + - "pow_rating" + - AliasedExpr: + - Expr.sqrt: + - Field: rating + - "sqrt_rating" + assert_results_approximate: + - abs_rating: 4.2 + ceil_rating: 5.0 + exp_rating: 66.686331 + floor_rating: 4.0 
+ ln_rating: 1.4350845 + log_rating_base10: 0.623249 + pow_rating: 17.64 + sqrt_rating: 2.049390 + assert_proto: + pipeline: + stages: + - args: + - referenceValue: /books + name: collection + - args: + - functionValue: + args: + - fieldReferenceValue: title + - stringValue: To Kill a Mockingbird + name: equal + name: where + - args: + - mapValue: + fields: + abs_rating: + functionValue: + args: + - fieldReferenceValue: rating + name: abs + ceil_rating: + functionValue: + args: + - fieldReferenceValue: rating + name: ceil + exp_rating: + functionValue: + args: + - fieldReferenceValue: rating + name: exp + floor_rating: + functionValue: + args: + - fieldReferenceValue: rating + name: floor + ln_rating: + functionValue: + args: + - fieldReferenceValue: rating + name: ln + log_rating_base10: + functionValue: + args: + - fieldReferenceValue: rating + - integerValue: '10' + name: log + pow_rating: + functionValue: + args: + - fieldReferenceValue: rating + - integerValue: '2' + name: pow + sqrt_rating: + functionValue: + args: + - fieldReferenceValue: rating + name: sqrt + name: select + - description: testRoundExpressions + pipeline: + - Collection: books + - Where: + - Expr.equal_any: + - Field: title + - - Constant: "To Kill a Mockingbird" # rating 4.2 + - Constant: "Pride and Prejudice" # rating 4.5 + - Constant: "The Lord of the Rings" # rating 4.7 + - Select: + - title + - AliasedExpr: + - Expr.round: + - Field: rating + - "round_rating" + - Sort: + - Ordering: + - Field: title + - ASCENDING + assert_results: + - title: "Pride and Prejudice" + round_rating: 5.0 + - title: "The Lord of the Rings" + round_rating: 5.0 + - title: "To Kill a Mockingbird" + round_rating: 4.0 + assert_proto: + pipeline: + stages: + - args: + - referenceValue: /books + name: collection + - args: + - functionValue: + args: + - fieldReferenceValue: title + - arrayValue: + values: + - stringValue: "To Kill a Mockingbird" + - stringValue: "Pride and Prejudice" + - stringValue: "The Lord of the 
Rings" + name: equal_any + name: where + - args: + - mapValue: + fields: + title: + fieldReferenceValue: title + round_rating: + functionValue: + args: + - fieldReferenceValue: rating + name: round + name: select + - args: + - mapValue: + fields: + direction: + stringValue: ascending + expression: + fieldReferenceValue: title + name: sort - description: testFindNearestDotProduct pipeline: - Collection: vectors diff --git a/tests/system/test_pipeline_acceptance.py b/tests/system/test_pipeline_acceptance.py index fe1d7659b..ab63afd37 100644 --- a/tests/system/test_pipeline_acceptance.py +++ b/tests/system/test_pipeline_acceptance.py @@ -95,12 +95,15 @@ def test_pipeline_results(test_dict, client): Ensure pipeline returns expected results """ expected_results = _parse_yaml_types(test_dict.get("assert_results", None)) + expected_approximate_results = _parse_yaml_types(test_dict.get("assert_results_approximate", None)) expected_count = test_dict.get("assert_count", None) pipeline = parse_pipeline(client, test_dict["pipeline"]) # check if server responds as expected got_results = [snapshot.data() for snapshot in pipeline.stream()] if expected_results: assert got_results == expected_results + if expected_approximate_results: + assert got_results == pytest.approximate(expected_approximate_results) if expected_count is not None: assert len(got_results) == expected_count @@ -136,12 +139,15 @@ async def test_pipeline_results_async(test_dict, async_client): Ensure pipeline returns expected results """ expected_results = _parse_yaml_types(test_dict.get("assert_results", None)) + expected_approximate_results = _parse_yaml_types(test_dict.get("assert_results_approximate", None)) expected_count = test_dict.get("assert_count", None) pipeline = parse_pipeline(async_client, test_dict["pipeline"]) # check if server responds as expected got_results = [snapshot.data() async for snapshot in pipeline.stream()] if expected_results: assert got_results == expected_results + if 
expected_approximate_results: + assert got_results == pytest.approximate(expected_approximate_results) if expected_count is not None: assert len(got_results) == expected_count diff --git a/tests/unit/v1/test_pipeline_expressions.py b/tests/unit/v1/test_pipeline_expressions.py index 4fa3dd8e8..d49dcea13 100644 --- a/tests/unit/v1/test_pipeline_expressions.py +++ b/tests/unit/v1/test_pipeline_expressions.py @@ -1163,6 +1163,89 @@ def test_add(self): infix_instance = arg1.add(arg2) assert infix_instance == instance + def test_abs(self): + arg1 = self._make_arg("Value") + instance = Expr.abs(arg1) + assert instance.name == "abs" + assert instance.params == [arg1] + assert repr(instance) == "Value.abs()" + infix_instance = arg1.abs() + assert infix_instance == instance + + def test_ceil(self): + arg1 = self._make_arg("Value") + instance = Expr.ceil(arg1) + assert instance.name == "ceil" + assert instance.params == [arg1] + assert repr(instance) == "Value.ceil()" + infix_instance = arg1.ceil() + assert infix_instance == instance + + def test_exp(self): + arg1 = self._make_arg("Value") + instance = Expr.exp(arg1) + assert instance.name == "exp" + assert instance.params == [arg1] + assert repr(instance) == "Value.exp()" + infix_instance = arg1.exp() + assert infix_instance == instance + + def test_floor(self): + arg1 = self._make_arg("Value") + instance = Expr.floor(arg1) + assert instance.name == "floor" + assert instance.params == [arg1] + assert repr(instance) == "Value.floor()" + infix_instance = arg1.floor() + assert infix_instance == instance + + def test_ln(self): + arg1 = self._make_arg("Value") + instance = Expr.ln(arg1) + assert instance.name == "ln" + assert instance.params == [arg1] + assert repr(instance) == "Value.ln()" + infix_instance = arg1.ln() + assert infix_instance == instance + + def test_log(self): + arg1 = self._make_arg("Value") + arg2 = self._make_arg("Base") + instance = Expr.log(arg1, arg2) + assert instance.name == "log" + assert 
instance.params == [arg1, arg2] + assert repr(instance) == "Value.log(Base)" + infix_instance = arg1.log(arg2) + assert infix_instance == instance + + def test_pow(self): + arg1 = self._make_arg("Value") + arg2 = self._make_arg("Exponent") + instance = Expr.pow(arg1, arg2) + assert instance.name == "pow" + assert instance.params == [arg1, arg2] + assert repr(instance) == "Value.pow(Exponent)" + infix_instance = arg1.pow(arg2) + assert infix_instance == instance + + def test_round(self): + arg1 = self._make_arg("Value") + instance = Expr.round(arg1) + assert instance.name == "round" + assert instance.params == [arg1] + assert repr(instance) == "Value.round()" + infix_instance = arg1.round() + assert infix_instance == instance + + def test_sqrt(self): + arg1 = self._make_arg("Value") + instance = Expr.sqrt(arg1) + assert instance.name == "sqrt" + assert instance.params == [arg1] + assert repr(instance) == "Value.sqrt()" + infix_instance = arg1.sqrt() + assert infix_instance == instance + def test_array_length(self): arg1 = self._make_arg("Array") instance = Expr.array_length(arg1) From dcd6af1c69b54884de23e901c874d5a820692e5b Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Fri, 24 Oct 2025 13:24:40 -0700 Subject: [PATCH 04/14] added string manipulation expressions --- .../firestore_v1/pipeline_expressions.py | 93 ++++++ tests/system/pipeline_e2e.yaml | 270 +++++++++++++++++- tests/unit/v1/test_pipeline_expressions.py | 67 +++++ 3 files changed, 429 insertions(+), 1 deletion(-) diff --git a/google/cloud/firestore_v1/pipeline_expressions.py b/google/cloud/firestore_v1/pipeline_expressions.py index ecb51f2d0..d83d27e88 100644 --- a/google/cloud/firestore_v1/pipeline_expressions.py +++ b/google/cloud/firestore_v1/pipeline_expressions.py @@ -983,6 +983,99 @@ def string_concat(self, *elements: Expr | CONSTANT_TYPE) -> "Expr": [self] + [self._cast_to_expr_or_convert_to_constant(el) for el in elements], ) + @expose_as_static + def to_lower(self) -> "Expr": + """Creates an 
expression that converts a string to lowercase. + + Example: + >>> # Convert the 'name' field to lowercase + >>> Field.of("name").to_lower() + + Returns: + A new `Expr` representing the lowercase string. + """ + return Function("to_lower", [self]) + + @expose_as_static + def to_upper(self) -> "Expr": + """Creates an expression that converts a string to uppercase. + + Example: + >>> # Convert the 'title' field to uppercase + >>> Field.of("title").to_upper() + + Returns: + A new `Expr` representing the uppercase string. + """ + return Function("to_upper", [self]) + + @expose_as_static + def trim(self) -> "Expr": + """Creates an expression that removes leading and trailing whitespace from a string. + + Example: + >>> # Trim whitespace from the 'userInput' field + >>> Field.of("userInput").trim() + + Returns: + A new `Expr` representing the trimmed string. + """ + return Function("trim", [self]) + + @expose_as_static + def string_reverse(self) -> "Expr": + """Creates an expression that reverses a string. + + Example: + >>> # Reverse the 'userInput' field + >>> Field.of("userInput").string_reverse() + + Returns: + A new `Expr` representing the reversed string. + """ + return Function("string_reverse", [self]) + + @expose_as_static + def substring(self, position: Expr | int, length: Expr | int | None=None) -> "Expr": + """Creates an expression that returns a substring of the results of this expression. + + + Example: + >>> Field.of("description").substring(5, 10) + >>> Field.of("description").substring(5) + + Args: + position: the index of the first character of the substring. + length: the length of the substring. If not provided the substring + will end at the end of the input. + + Returns: + A new `Expr` representing the extracted substring. 
+ """ + args = [self, self._cast_to_expr_or_convert_to_constant(position)] + if length is not None: + args.append(self._cast_to_expr_or_convert_to_constant(length)) + return Function("substring", args) + + @expose_as_static + def join(self, delimeter: Expr | str) -> "Expr": + """Creates an expression that joins the elements of an array into a string + + + Example: + >>> Field.of("tags").join(", ") + + Args: + delimiter: The delimiter to add between the elements of the array. + + Returns: + A new `Expr` representing the joined string. + """ + return Function( + "join", [self, self._cast_to_expr_or_convert_to_constant(delimeter)] + ) + + @expose_as_static def map_get(self, key: str | Constant[str]) -> "Expr": """Accesses a value from the map produced by evaluating this expression. diff --git a/tests/system/pipeline_e2e.yaml b/tests/system/pipeline_e2e.yaml index f50efbc23..10a72476c 100644 --- a/tests/system/pipeline_e2e.yaml +++ b/tests/system/pipeline_e2e.yaml @@ -2275,4 +2275,272 @@ tests: - Vector: [1.0, 2.0, 3.0] - "cosine_distance_result" assert_results: - - cosine_distance_result: 0.0 \ No newline at end of file + - cosine_distance_result: 0.0 + - description: testStringFunctions - ToLower + pipeline: + - Collection: books + - Where: + - Expr.equal: + - Field: author + - Constant: "Douglas Adams" + - Select: + - AliasedExpr: + - Expr.to_lower: + - Field: title + - "lower_title" + assert_results: + - lower_title: "the hitchhiker's guide to the galaxy" + assert_proto: + pipeline: + stages: + - args: + - referenceValue: /books + name: collection + - args: + - functionValue: + args: + - fieldReferenceValue: author + - stringValue: Douglas Adams + name: equal + name: where + - args: + - mapValue: + fields: + lower_title: + functionValue: + args: + - fieldReferenceValue: title + name: to_lower + name: select + - description: testStringFunctions - ToUpper + pipeline: + - Collection: books + - Where: + - Expr.equal: + - Field: author + - Constant: "Douglas Adams" + - 
Select: + - AliasedExpr: + - Expr.to_upper: + - Field: title + - "upper_title" + assert_results: + - upper_title: "THE HITCHHIKER'S GUIDE TO THE GALAXY" + assert_proto: + pipeline: + stages: + - args: + - referenceValue: /books + name: collection + - args: + - functionValue: + args: + - fieldReferenceValue: author + - stringValue: Douglas Adams + name: equal + name: where + - args: + - mapValue: + fields: + upper_title: + functionValue: + args: + - fieldReferenceValue: title + name: to_upper + name: select + - description: testStringFunctions - Trim + pipeline: + - Collection: books + - Where: + - Expr.equal: + - Field: author + - Constant: "Douglas Adams" + - Select: + - AliasedExpr: + - Expr.trim: + - Expr.string_concat: + - Constant: " " + - Field: title + - Constant: " " + - "trimmed_title" + assert_results: + - trimmed_title: "The Hitchhiker's Guide to the Galaxy" + assert_proto: + pipeline: + stages: + - args: + - referenceValue: /books + name: collection + - args: + - functionValue: + args: + - fieldReferenceValue: author + - stringValue: Douglas Adams + name: equal + name: where + - args: + - mapValue: + fields: + trimmed_title: + functionValue: + args: + - functionValue: + args: + - stringValue: " " + - fieldReferenceValue: title + - stringValue: " " + name: string_concat + name: trim + name: select + - description: testStringFunctions - StringReverse + pipeline: + - Collection: books + - Where: + - Expr.equal: + - Field: author + - Constant: "Jane Austen" + - Select: + - AliasedExpr: + - Expr.string_reverse: + - Field: title + - "reversed_title" + assert_results: + - reversed_title: "ecidujerP dna edirP" + assert_proto: + pipeline: + stages: + - args: + - referenceValue: /books + name: collection + - args: + - functionValue: + args: + - fieldReferenceValue: author + - stringValue: "Jane Austen" + name: equal + name: where + - args: + - mapValue: + fields: + reversed_title: + functionValue: + args: + - fieldReferenceValue: title + name: string_reverse + 
name: select + - description: testStringFunctions - Substring + pipeline: + - Collection: books + - Where: + - Expr.equal: + - Field: author + - Constant: "Douglas Adams" + - Select: + - AliasedExpr: + - Expr.substring: + - Field: title + - Constant: 4 + - Constant: 11 + - "substring_title" + assert_results: + - substring_title: "Hitchhiker'" + assert_proto: + pipeline: + stages: + - args: + - referenceValue: /books + name: collection + - args: + - functionValue: + args: + - fieldReferenceValue: author + - stringValue: "Douglas Adams" + name: equal + name: where + - args: + - mapValue: + fields: + substring_title: + functionValue: + args: + - fieldReferenceValue: title + - integerValue: '4' + - integerValue: '11' + name: substring + name: select + - description: testStringFunctions - Substring without length + pipeline: + - Collection: books + - Where: + - Expr.equal: + - Field: author + - Constant: "Fyodor Dostoevsky" + - Select: + - AliasedExpr: + - Expr.substring: + - Field: title + - Constant: 10 + - "substring_title" + assert_results: + - substring_title: "Punishment" + assert_proto: + pipeline: + stages: + - args: + - referenceValue: /books + name: collection + - args: + - functionValue: + args: + - fieldReferenceValue: author + - stringValue: "Fyodor Dostoevsky" + name: equal + name: where + - args: + - mapValue: + fields: + substring_title: + functionValue: + args: + - fieldReferenceValue: title + - integerValue: '10' + name: substring + name: select + - description: testStringFunctions - Join + pipeline: + - Collection: books + - Where: + - Expr.equal: + - Field: author + - Constant: "Douglas Adams" + - Select: + - AliasedExpr: + - Expr.join: + - Field: tags + - Constant: ", " + - "joined_tags" + assert_results: + - joined_tags: "comedy, space, adventure" + assert_proto: + pipeline: + stages: + - args: + - referenceValue: /books + name: collection + - args: + - functionValue: + args: + - fieldReferenceValue: author + - stringValue: "Douglas Adams" + name: 
equal + name: where + - args: + - mapValue: + fields: + joined_tags: + functionValue: + args: + - fieldReferenceValue: tags + - stringValue: ", " + name: join + name: select \ No newline at end of file diff --git a/tests/unit/v1/test_pipeline_expressions.py b/tests/unit/v1/test_pipeline_expressions.py index d49dcea13..15cb86eb0 100644 --- a/tests/unit/v1/test_pipeline_expressions.py +++ b/tests/unit/v1/test_pipeline_expressions.py @@ -970,6 +970,73 @@ def test_logical_minimum(self): infix_instance = arg1.logical_minimum(arg2) assert infix_instance == instance + def test_to_lower(self): + arg1 = self._make_arg("Input") + instance = Expr.to_lower(arg1) + assert instance.name == "to_lower" + assert instance.params == [arg1] + assert repr(instance) == "Input.to_lower()" + infix_instance = arg1.to_lower() + assert infix_instance == instance + + def test_to_upper(self): + arg1 = self._make_arg("Input") + instance = Expr.to_upper(arg1) + assert instance.name == "to_upper" + assert instance.params == [arg1] + assert repr(instance) == "Input.to_upper()" + infix_instance = arg1.to_upper() + assert infix_instance == instance + + def test_trim(self): + arg1 = self._make_arg("Input") + instance = Expr.trim(arg1) + assert instance.name == "trim" + assert instance.params == [arg1] + assert repr(instance) == "Input.trim()" + infix_instance = arg1.trim() + assert infix_instance == instance + + def test_string_reverse(self): + arg1 = self._make_arg("Input") + instance = Expr.string_reverse(arg1) + assert instance.name == "string_reverse" + assert instance.params == [arg1] + assert repr(instance) == "Input.string_reverse()" + infix_instance = arg1.string_reverse() + assert infix_instance == instance + + def test_substring(self): + arg1 = self._make_arg("Input") + arg2 = self._make_arg("Position") + instance = Expr.substring(arg1, arg2) + assert instance.name == "substring" + assert instance.params == [arg1, arg2] + assert repr(instance) == "Input.substring(Position)" + infix_instance 
= arg1.substring(arg2) + assert infix_instance == instance + + def test_substring_w_length(self): + arg1 = self._make_arg("Input") + arg2 = self._make_arg("Position") + arg3 = self._make_arg("Length") + instance = Expr.substring(arg1, arg2, arg3) + assert instance.name == "substring" + assert instance.params == [arg1, arg2, arg3] + assert repr(instance) == "Input.substring(Position, Length)" + infix_instance = arg1.substring(arg2, arg3) + assert infix_instance == instance + + def test_join(self): + arg1 = self._make_arg("Array") + arg2 = self._make_arg("Separator") + instance = Expr.join(arg1, arg2) + assert instance.name == "join" + assert instance.params == [arg1, arg2] + assert repr(instance) == "Array.join(Separator)" + infix_instance = arg1.join(arg2) + assert infix_instance == instance + def test_map_get(self): arg1 = self._make_arg("Map") arg2 = "key" From 71308dc030483eb1a5061626682cb22dac11242f Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Fri, 24 Oct 2025 13:41:55 -0700 Subject: [PATCH 05/14] added not_nan, not_null, and is_absent --- .../firestore_v1/pipeline_expressions.py | 45 ++++++++++++++- tests/system/pipeline_e2e.yaml | 57 +++++++++++++++++++ tests/unit/v1/test_pipeline_expressions.py | 35 ++++++++++-- 3 files changed, 130 insertions(+), 7 deletions(-) diff --git a/google/cloud/firestore_v1/pipeline_expressions.py b/google/cloud/firestore_v1/pipeline_expressions.py index d83d27e88..bd0e17547 100644 --- a/google/cloud/firestore_v1/pipeline_expressions.py +++ b/google/cloud/firestore_v1/pipeline_expressions.py @@ -711,6 +711,20 @@ def array_reverse(self) -> "Expr": """ return Function("array_reverse", [self]) + @expose_as_static + def is_absent(self) -> "BooleanExpr": + """Creates an expression that returns true if a value is absent. Otherwise, returns false even if + the value is null. + + Example: + >>> # Check if the 'email' field is absent. 
+ >>> Field.of("email").is_absent() + + Returns: + A new `BooleanExpression` representing the isAbsent operation. + """ + return BooleanExpr("is_absent", [self]) + @expose_as_static def is_nan(self) -> "BooleanExpr": """Creates an expression that checks if this expression evaluates to 'NaN' (Not a Number). @@ -724,9 +738,22 @@ def is_nan(self) -> "BooleanExpr": """ return BooleanExpr("is_nan", [self]) + @expose_as_static + def is_not_nan(self) -> "BooleanExpr": + """Creates an expression that checks if this expression evaluates to a non-'NaN' (Not a Number) value. + + Example: + >>> # Check if the result of a calculation is not NaN + >>> Field.of("value").divide(1).is_not_nan() + + Returns: + A new `Expr` representing the 'is not NaN' check. + """ + return BooleanExpr("is_not_nan", [self]) + @expose_as_static def is_null(self) -> "BooleanExpr": - """Creates an expression that checks if this expression evaluates to 'Null'. + """Creates an expression that checks if the value of a field is 'Null'. Example: >>> Field.of("value").is_null() @@ -736,6 +763,18 @@ def is_null(self) -> "BooleanExpr": """ return BooleanExpr("is_null", [self]) + @expose_as_static + def is_not_null(self) -> "BooleanExpr": + """Creates an expression that checks if the value of a field is not 'Null'. + + Example: + >>> Field.of("value").is_not_null() + + Returns: + A new `Expr` representing the 'isNotNull' check. + """ + return BooleanExpr("is_not_null", [self]) + @expose_as_static def exists(self) -> "BooleanExpr": """Creates an expression that checks if a field exists in the document. 
@@ -1607,11 +1646,11 @@ def _from_query_filter_pb(filter_pb, client): if filter_pb.op == Query_pb.UnaryFilter.Operator.IS_NAN: return And(field.exists(), field.is_nan()) elif filter_pb.op == Query_pb.UnaryFilter.Operator.IS_NOT_NAN: - return And(field.exists(), Not(field.is_nan())) + return And(field.exists(), field.is_not_nan()) elif filter_pb.op == Query_pb.UnaryFilter.Operator.IS_NULL: return And(field.exists(), field.is_null()) elif filter_pb.op == Query_pb.UnaryFilter.Operator.IS_NOT_NULL: - return And(field.exists(), Not(field.is_null())) + return And(field.exists(), field.is_not_null()) else: raise TypeError(f"Unexpected UnaryFilter operator type: {filter_pb.op}") elif isinstance(filter_pb, Query_pb.FieldFilter): diff --git a/tests/system/pipeline_e2e.yaml b/tests/system/pipeline_e2e.yaml index 10a72476c..83429fb70 100644 --- a/tests/system/pipeline_e2e.yaml +++ b/tests/system/pipeline_e2e.yaml @@ -1418,6 +1418,63 @@ tests: - args: - integerValue: '1' name: limit + - description: testIsNotNull + pipeline: + - Collection: books + - Where: + - Expr.is_not_null: + - Field: rating + assert_count: 10 + assert_proto: + pipeline: + stages: + - args: + - referenceValue: /books + name: collection + - args: + - functionValue: + args: + - fieldReferenceValue: rating + name: is_not_null + name: where + - description: testIsNotNaN + pipeline: + - Collection: books + - Where: + - Expr.is_not_nan: + - Field: rating + assert_count: 10 + assert_proto: + pipeline: + stages: + - args: + - referenceValue: /books + name: collection + - args: + - functionValue: + args: + - fieldReferenceValue: rating + name: is_not_nan + name: where + - description: testIsAbsent + pipeline: + - Collection: books + - Where: + - Expr.is_absent: + - Field: awards.pulitzer + assert_count: 9 + assert_proto: + pipeline: + stages: + - args: + - referenceValue: /books + name: collection + - args: + - functionValue: + args: + - fieldReferenceValue: awards.pulitzer + name: is_absent + name: where - 
description: testLogicalMinMax pipeline: - Collection: books diff --git a/tests/unit/v1/test_pipeline_expressions.py b/tests/unit/v1/test_pipeline_expressions.py index 15cb86eb0..567960808 100644 --- a/tests/unit/v1/test_pipeline_expressions.py +++ b/tests/unit/v1/test_pipeline_expressions.py @@ -458,7 +458,7 @@ def test__from_query_filter_pb_composite_filter_nested(self, mock_client): expected_cond1 = expr.And(field1.exists(), field1.equal(Constant("val1"))) expected_cond2 = expr.And(field2.exists(), field2.greater_than(Constant(10))) expected_cond3 = expr.And( - field3.exists(), expr.Not(field3.is_null()) + field3.exists(), field3.is_not_null() ) expected_inner_and = expr.And(expected_cond2, expected_cond3) expected_outer_or = expr.Or(expected_cond1, expected_inner_and) @@ -491,15 +491,15 @@ def test__from_query_filter_pb_composite_filter_unknown_op(self, mock_client): (query_pb.StructuredQuery.UnaryFilter.Operator.IS_NAN, Expr.is_nan), ( query_pb.StructuredQuery.UnaryFilter.Operator.IS_NOT_NAN, - lambda f: expr.Not(f.is_nan()), + Expr.is_not_nan, ), ( query_pb.StructuredQuery.UnaryFilter.Operator.IS_NULL, - lambda f: f.is_null(), + Expr.is_null, ), ( query_pb.StructuredQuery.UnaryFilter.Operator.IS_NOT_NULL, - lambda f: expr.Not(f.is_null()), + Expr.is_not_null, ), ], ) @@ -825,6 +825,15 @@ def test_not_equal_any(self): infix_instance = arg1.not_equal_any([arg2, arg3]) assert infix_instance == instance + def test_is_absent(self): + arg1 = self._make_arg("Field") + instance = Expr.is_absent(arg1) + assert instance.name == "is_absent" + assert instance.params == [arg1] + assert repr(instance) == "Field.is_absent()" + infix_instance = arg1.is_absent() + assert infix_instance == instance + def test_is_nan(self): arg1 = self._make_arg("Value") instance = Expr.is_nan(arg1) @@ -834,6 +843,15 @@ def test_is_nan(self): infix_instance = arg1.is_nan() assert infix_instance == instance + def test_is_not_nan(self): + arg1 = self._make_arg("Value") + instance = 
Expr.is_not_nan(arg1) + assert instance.name == "is_not_nan" + assert instance.params == [arg1] + assert repr(instance) == "Value.is_not_nan()" + infix_instance = arg1.is_not_nan() + assert infix_instance == instance + def test_is_null(self): arg1 = self._make_arg("Value") instance = Expr.is_null(arg1) @@ -843,6 +861,15 @@ def test_is_null(self): infix_instance = arg1.is_null() assert infix_instance == instance + def test_is_not_null(self): + arg1 = self._make_arg("Value") + instance = Expr.is_not_null(arg1) + assert instance.name == "is_not_null" + assert instance.params == [arg1] + assert repr(instance) == "Value.is_not_null()" + infix_instance = arg1.is_not_null() + assert infix_instance == instance + def test_not(self): arg1 = self._make_arg("Condition") instance = expr.Not(arg1) From 026001447e6a5d86a4089d77cb4cf4a37b2e482e Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Fri, 24 Oct 2025 15:53:41 -0700 Subject: [PATCH 06/14] added new Array type --- .../firestore_v1/pipeline_expressions.py | 112 +++++---- tests/system/pipeline_e2e.yaml | 227 +++++++++++++++++- tests/system/test_pipeline_acceptance.py | 2 +- tests/unit/v1/test_pipeline_expressions.py | 125 +++++----- 4 files changed, 346 insertions(+), 120 deletions(-) diff --git a/google/cloud/firestore_v1/pipeline_expressions.py b/google/cloud/firestore_v1/pipeline_expressions.py index bd0e17547..25544dfb5 100644 --- a/google/cloud/firestore_v1/pipeline_expressions.py +++ b/google/cloud/firestore_v1/pipeline_expressions.py @@ -113,16 +113,18 @@ def _to_pb(self) -> Value: raise NotImplementedError @staticmethod - def _cast_to_expr_or_convert_to_constant(o: Any) -> "Expr": - return o if isinstance(o, Expr) else Constant(o) - - @staticmethod - def _convert_to_vector_expr(o: list[float] | Vector | Expr) -> Expr: + def _cast_to_expr_or_convert_to_constant(o: Any, include_vector=False, include_array=False) -> "Expr": + """Convert arbitrary object to an Expr.""" + if isinstance(o, Constant) and 
isinstance(o.value, list): + o = o.value + if isinstance(o, Expr): + return o if isinstance(o, list): - o = Vector(o) - elif isinstance(o, Constant) and isinstance(o.value, list): - o = Vector(o.value) - return Expr._cast_to_expr_or_convert_to_constant(o) + if include_vector and all([isinstance(i, (float, int)) for i in o]): + return Constant(Vector(o)) + elif include_array: + return Array(o) + return Constant(o) class expose_as_static: """ @@ -140,6 +142,10 @@ def __init__(self, instance_func): self.instance_func = instance_func def static_func(self, first_arg, *other_args, **kwargs): + if not isinstance(first_arg, (Expr, str)): + raise TypeError( + f"`expressions must be called on an Expr or a string representing a field name. got {type(first_arg)}." + ) first_expr = ( Field.of(first_arg) if not isinstance(first_arg, Expr) else first_arg ) @@ -557,7 +563,7 @@ def less_than_or_equal(self, other: Expr | CONSTANT_TYPE) -> "BooleanExpr": ) @expose_as_static - def equal_any(self, array: Sequence[Expr | CONSTANT_TYPE]) -> "BooleanExpr": + def equal_any(self, array: Array | Sequence[Expr | CONSTANT_TYPE] | Expr) -> "BooleanExpr": """Creates an expression that checks if this expression is equal to any of the provided values or expressions. @@ -575,14 +581,12 @@ def equal_any(self, array: Sequence[Expr | CONSTANT_TYPE]) -> "BooleanExpr": "equal_any", [ self, - _ListOfExprs( - [self._cast_to_expr_or_convert_to_constant(v) for v in array] - ), + self._cast_to_expr_or_convert_to_constant(array, include_array=True), ], ) @expose_as_static - def not_equal_any(self, array: Sequence[Expr | CONSTANT_TYPE]) -> "BooleanExpr": + def not_equal_any(self, array: Array | list[Expr | CONSTANT_TYPE] | Expr) -> "BooleanExpr": """Creates an expression that checks if this expression is not equal to any of the provided values or expressions. 
@@ -600,9 +604,7 @@ def not_equal_any(self, array: Sequence[Expr | CONSTANT_TYPE]) -> "BooleanExpr": "not_equal_any", [ self, - _ListOfExprs( - [self._cast_to_expr_or_convert_to_constant(v) for v in array] - ), + self._cast_to_expr_or_convert_to_constant(array, include_array=True), ], ) @@ -629,7 +631,7 @@ def array_contains(self, element: Expr | CONSTANT_TYPE) -> "BooleanExpr": @expose_as_static def array_contains_all( self, - elements: Sequence[Expr | CONSTANT_TYPE], + elements: Array | list[Expr | CONSTANT_TYPE] | Expr, ) -> "BooleanExpr": """Creates an expression that checks if an array contains all the specified elements. @@ -649,16 +651,14 @@ def array_contains_all( "array_contains_all", [ self, - _ListOfExprs( - [self._cast_to_expr_or_convert_to_constant(e) for e in elements] - ), + self._cast_to_expr_or_convert_to_constant(elements, include_array=True), ], ) @expose_as_static def array_contains_any( self, - elements: Sequence[Expr | CONSTANT_TYPE], + elements: Array | list[Expr | CONSTANT_TYPE] | Expr, ) -> "BooleanExpr": """Creates an expression that checks if an array contains any of the specified elements. @@ -679,9 +679,7 @@ def array_contains_any( "array_contains_any", [ self, - _ListOfExprs( - [self._cast_to_expr_or_convert_to_constant(e) for e in elements] - ), + self._cast_to_expr_or_convert_to_constant(elements, include_array=True), ], ) @@ -711,6 +709,26 @@ def array_reverse(self) -> "Expr": """ return Function("array_reverse", [self]) + @expose_as_static + def array_concat(self, *other_arrays: Array | list[Expr | CONSTANT_TYPE] | Expr) -> "Expr": + """Creates an expression that concatenates an array expression with another array. + + Example: + >>> # Combine the 'tags' array with a new array and an array field + >>> Field.of("tags").array_concat(["newTag1", "newTag2", Field.of("otherTag")]) + + Args: + array: The list of constants or expressions to concat with. + + Returns: + A new `Expr` representing the concatenated array. 
+ """ + return Function( + "array_concat", [self] + [ + self._cast_to_expr_or_convert_to_constant(arr, include_array=True) for arr in other_arrays + ], + ) + @expose_as_static def is_absent(self) -> "BooleanExpr": """Creates an expression that returns true if a value is absent. Otherwise, returns false even if @@ -1149,7 +1167,7 @@ def cosine_distance(self, other: Expr | list[float] | Vector) -> "Expr": Returns: A new `Expr` representing the cosine distance between the two vectors. """ - return Function("cosine_distance", [self, self._convert_to_vector_expr(other)]) + return Function("cosine_distance", [self, self._cast_to_expr_or_convert_to_constant(other, include_vector=True)]) @expose_as_static def euclidean_distance(self, other: Expr | list[float] | Vector) -> "Expr": @@ -1167,7 +1185,7 @@ def euclidean_distance(self, other: Expr | list[float] | Vector) -> "Expr": Returns: A new `Expr` representing the Euclidean distance between the two vectors. """ - return Function("euclidean_distance", [self, self._convert_to_vector_expr(other)]) + return Function("euclidean_distance", [self, self._cast_to_expr_or_convert_to_constant(other, include_vector=True)]) @expose_as_static def dot_product(self, other: Expr | list[float] | Vector) -> "Expr": @@ -1185,7 +1203,7 @@ def dot_product(self, other: Expr | list[float] | Vector) -> "Expr": Returns: A new `Expr` representing the dot product between the two vectors. 
""" - return Function("dot_product", [self, self._convert_to_vector_expr(other)]) + return Function("dot_product", [self, self._cast_to_expr_or_convert_to_constant(other, include_vector=True)]) @expose_as_static def vector_length(self) -> "Expr": @@ -1430,25 +1448,6 @@ def _to_pb(self) -> Value: return encode_value(self.value) -class _ListOfExprs(Expr): - """Represents a list of expressions, typically used as an argument to functions like 'in' or array functions.""" - - def __init__(self, exprs: Sequence[Expr]): - self.exprs: list[Expr] = list(exprs) - - def __eq__(self, other): - if not isinstance(other, _ListOfExprs): - return False - else: - return other.exprs == self.exprs - - def __repr__(self): - return repr(self.exprs) - - def _to_pb(self): - return Value(array_value={"values": [e._to_pb() for e in self.exprs]}) - - class Function(Expr): """A base class for expressions that represent function calls.""" @@ -1689,6 +1688,25 @@ def _from_query_filter_pb(filter_pb, client): else: raise TypeError(f"Unexpected filter type: {type(filter_pb)}") +class Array(Function): + """ + Creates an expression that creates a Firestore array value from an input list. 
+ + Example: + >>> Expr.array(["bar", Field.of("baz")]) + + Args: + elements: THe input list to evaluate in the expression + """ + + def __init__(self, elements: list[Expr | CONSTANT_TYPE]): + if not isinstance(elements, list): + raise TypeError("Array must be constructed with a list") + converted_elements = [self._cast_to_expr_or_convert_to_constant(el) for el in elements] + super().__init__("array", converted_elements) + + def __repr__(self): + return f"Array({self.params})" class And(BooleanExpr): """ diff --git a/tests/system/pipeline_e2e.yaml b/tests/system/pipeline_e2e.yaml index 83429fb70..2efd95c50 100644 --- a/tests/system/pipeline_e2e.yaml +++ b/tests/system/pipeline_e2e.yaml @@ -701,10 +701,11 @@ tests: - functionValue: args: - fieldReferenceValue: tags - - arrayValue: - values: + - functionValue: + args: - stringValue: comedy - stringValue: classic + name: array name: array_contains_any name: where - args: @@ -743,10 +744,11 @@ tests: - functionValue: args: - fieldReferenceValue: tags - - arrayValue: - values: + - functionValue: + args: - stringValue: adventure - stringValue: magic + name: array name: array_contains_all name: where - args: @@ -1791,6 +1793,218 @@ tests: - adventure - space - comedy + - description: testArrayConcat + pipeline: + - Collection: books + - Where: + - Expr.equal: + - Field: title + - Constant: "The Hitchhiker's Guide to the Galaxy" + - Select: + - AliasedExpr: + - Expr.array_concat: + - Field: tags + - Constant: ["new_tag", "another_tag"] + - "concatenatedTags" + assert_results: + - concatenatedTags: + - comedy + - space + - adventure + - new_tag + - another_tag + assert_proto: + pipeline: + stages: + - args: + - referenceValue: /books + name: collection + - args: + - functionValue: + args: + - fieldReferenceValue: title + - stringValue: "The Hitchhiker's Guide to the Galaxy" + name: equal + name: where + - args: + - mapValue: + fields: + concatenatedTags: + functionValue: + args: + - fieldReferenceValue: tags + - 
functionValue: + args: + - stringValue: "new_tag" + - stringValue: "another_tag" + name: array + name: array_concat + name: select + - description: testArrayConcatMultiple + pipeline: + - Collection: books + - Where: + - Expr.equal: + - Field: title + - Constant: "Dune" + - Select: + - AliasedExpr: + - Expr.array_concat: + - Field: tags + - Constant: ["sci-fi"] + - Constant: ["classic", "epic"] + - "concatenatedTags" + assert_results: + - concatenatedTags: + - politics + - desert + - ecology + - sci-fi + - classic + - epic + assert_proto: + pipeline: + stages: + - args: + - referenceValue: /books + name: collection + - args: + - functionValue: + args: + - fieldReferenceValue: title + - stringValue: "Dune" + name: equal + name: where + - args: + - mapValue: + fields: + concatenatedTags: + functionValue: + args: + - fieldReferenceValue: tags + - functionValue: + args: + - stringValue: "sci-fi" + name: array + - functionValue: + args: + - stringValue: "classic" + - stringValue: "epic" + name: array + name: array_concat + name: select + - description: testArrayContainsAnyWithField + pipeline: + - Collection: books + - AddFields: + - AliasedExpr: + - Expr.array_concat: + - Field: tags + - Constant: ["Dystopian"] + - "new_tags" + - Where: + - Expr.array_contains_any: + - Field: new_tags + - - Constant: non_existent_tag + - Field: genre + - Select: + - title + - genre + - Sort: + - Ordering: + - Field: title + - ASCENDING + assert_results: + - title: "1984" + genre: "Dystopian" + - title: "The Handmaid's Tale" + genre: "Dystopian" + assert_proto: + pipeline: + stages: + - args: + - referenceValue: /books + name: collection + - args: + - mapValue: + fields: + new_tags: + functionValue: + args: + - fieldReferenceValue: tags + - functionValue: + args: + - stringValue: "Dystopian" + name: array + name: array_concat + name: add_fields + - args: + - functionValue: + args: + - fieldReferenceValue: new_tags + - functionValue: + args: + - stringValue: "non_existent_tag" + - 
fieldReferenceValue: genre + name: array + name: array_contains_any + name: where + - args: + - mapValue: + fields: + title: + fieldReferenceValue: title + genre: + fieldReferenceValue: genre + name: select + - args: + - mapValue: + fields: + direction: + stringValue: ascending + expression: + fieldReferenceValue: title + name: sort + - description: testArrayConcatLiterals + pipeline: + - Collection: books + - Limit: 1 + - Select: + - AliasedExpr: + - Expr.array_concat: + - Array: [1, 2, 3] + - Array: [4, 5] + - "concatenated" + assert_results: + - concatenated: [1, 2, 3, 4, 5] + assert_proto: + pipeline: + stages: + - args: + - referenceValue: /books + name: collection + - args: + - integerValue: '1' + name: limit + - args: + - mapValue: + fields: + concatenated: + functionValue: + args: + - functionValue: + args: + - integerValue: '1' + - integerValue: '2' + - integerValue: '3' + name: array + - functionValue: + args: + - integerValue: '4' + - integerValue: '5' + name: array + name: array_concat + name: select - description: testExists pipeline: - Collection: books @@ -2213,11 +2427,12 @@ tests: - functionValue: args: - fieldReferenceValue: title - - arrayValue: - values: + - functionValue: + args: - stringValue: "To Kill a Mockingbird" - stringValue: "Pride and Prejudice" - stringValue: "The Lord of the Rings" + name: array name: equal_any name: where - args: diff --git a/tests/system/test_pipeline_acceptance.py b/tests/system/test_pipeline_acceptance.py index ab63afd37..3c41c0227 100644 --- a/tests/system/test_pipeline_acceptance.py +++ b/tests/system/test_pipeline_acceptance.py @@ -225,7 +225,7 @@ def _apply_yaml_args_to_callable(callable_obj, client, yaml_args): """ if isinstance(yaml_args, dict): return callable_obj(**_parse_expressions(client, yaml_args)) - elif isinstance(yaml_args, list) and not (callable_obj == expr.Constant or callable_obj == Vector): + elif isinstance(yaml_args, list) and not (callable_obj == expr.Constant or callable_obj == Vector or 
callable_obj == expr.Array): # yaml has an array of arguments. Treat as args return callable_obj(*_parse_expressions(client, yaml_args)) else: diff --git a/tests/unit/v1/test_pipeline_expressions.py b/tests/unit/v1/test_pipeline_expressions.py index 567960808..21dec6e4c 100644 --- a/tests/unit/v1/test_pipeline_expressions.py +++ b/tests/unit/v1/test_pipeline_expressions.py @@ -24,7 +24,6 @@ from google.cloud.firestore_v1._helpers import GeoPoint import google.cloud.firestore_v1.pipeline_expressions as expr from google.cloud.firestore_v1.pipeline_expressions import BooleanExpr -from google.cloud.firestore_v1.pipeline_expressions import _ListOfExprs from google.cloud.firestore_v1.pipeline_expressions import Expr from google.cloud.firestore_v1.pipeline_expressions import Constant from google.cloud.firestore_v1.pipeline_expressions import Field @@ -173,57 +172,6 @@ def test_equality(self, first, second, expected): assert (first == second) is expected -class TestListOfExprs: - def test_to_pb(self): - instance = _ListOfExprs([Constant(1), Constant(2)]) - result = instance._to_pb() - assert len(result.array_value.values) == 2 - assert result.array_value.values[0].integer_value == 1 - assert result.array_value.values[1].integer_value == 2 - - def test_empty_to_pb(self): - instance = _ListOfExprs([]) - result = instance._to_pb() - assert len(result.array_value.values) == 0 - - def test_repr(self): - instance = _ListOfExprs([Constant(1), Constant(2)]) - repr_string = repr(instance) - assert repr_string == "[Constant.of(1), Constant.of(2)]" - empty_instance = _ListOfExprs([]) - empty_repr_string = repr(empty_instance) - assert empty_repr_string == "[]" - - @pytest.mark.parametrize( - "first,second,expected", - [ - (_ListOfExprs([]), _ListOfExprs([]), True), - (_ListOfExprs([]), _ListOfExprs([Constant(1)]), False), - (_ListOfExprs([Constant(1)]), _ListOfExprs([]), False), - ( - _ListOfExprs([Constant(1)]), - _ListOfExprs([Constant(1)]), - True, - ), - ( - 
_ListOfExprs([Constant(1)]), - _ListOfExprs([Constant(2)]), - False, - ), - ( - _ListOfExprs([Constant(1), Constant(2)]), - _ListOfExprs([Constant(1), Constant(2)]), - True, - ), - (_ListOfExprs([Constant(1)]), [Constant(1)], False), - (_ListOfExprs([Constant(1)]), [1], False), - (_ListOfExprs([Constant(1)]), object(), False), - ], - ) - def test_equality(self, first, second, expected): - assert (first == second) is expected - - class TestSelectable: """ contains tests for each Expr class that derives from Selectable @@ -643,6 +591,31 @@ def test__from_query_filter_pb_unknown_filter_type(self, mock_client): BooleanExpr._from_query_filter_pb(document_pb.Value(), mock_client) +class TestArray: + """Tests for the array class""" + def test_array(self): + arg1 = Field.of("field1") + instance = expr.Array([arg1]) + assert instance.name == "array" + assert instance.params == [arg1] + assert repr(instance) == "Array([Field.of('field1')])" + + def test_empty_array(self): + instance = expr.Array([]) + assert instance.name == "array" + assert instance.params == [] + assert repr(instance) == "Array([])" + + def test_array_w_primitives(self): + a = expr.Array([1, Constant.of(2), "3"]) + assert a.name == "array" + assert a.params == [Constant.of(1), Constant.of(2), Constant.of("3")] + assert repr(a) == "Array([Constant.of(1), Constant.of(2), Constant.of('3')])" + + def test_array_w_non_list(self): + with pytest.raises(TypeError): + expr.Array(1) + class TestExpressionMethods: """ contains test methods for each Expr method @@ -723,10 +696,10 @@ def test_array_contains_any(self): arg3 = self._make_arg("Element2") instance = Expr.array_contains_any(arg1, [arg2, arg3]) assert instance.name == "array_contains_any" - assert isinstance(instance.params[1], _ListOfExprs) + assert isinstance(instance.params[1], expr.Array) assert instance.params[0] == arg1 - assert instance.params[1].exprs == [arg2, arg3] - assert repr(instance) == "ArrayField.array_contains_any([Element1, Element2])" + 
assert instance.params[1].params == [arg2, arg3] + assert repr(instance) == "ArrayField.array_contains_any(Array([Element1, Element2]))" infix_instance = arg1.array_contains_any([arg2, arg3]) assert infix_instance == instance @@ -805,10 +778,10 @@ def test_equal_any(self): arg3 = self._make_arg("Value2") instance = Expr.equal_any(arg1, [arg2, arg3]) assert instance.name == "equal_any" - assert isinstance(instance.params[1], _ListOfExprs) + assert isinstance(instance.params[1], expr.Array) assert instance.params[0] == arg1 - assert instance.params[1].exprs == [arg2, arg3] - assert repr(instance) == "Field.equal_any([Value1, Value2])" + assert instance.params[1].params == [arg2, arg3] + assert repr(instance) == "Field.equal_any(Array([Value1, Value2]))" infix_instance = arg1.equal_any([arg2, arg3]) assert infix_instance == instance @@ -818,10 +791,10 @@ def test_not_equal_any(self): arg3 = self._make_arg("Value2") instance = Expr.not_equal_any(arg1, [arg2, arg3]) assert instance.name == "not_equal_any" - assert isinstance(instance.params[1], _ListOfExprs) + assert isinstance(instance.params[1], expr.Array) assert instance.params[0] == arg1 - assert instance.params[1].exprs == [arg2, arg3] - assert repr(instance) == "Field.not_equal_any([Value1, Value2])" + assert instance.params[1].params == [arg2, arg3] + assert repr(instance) == "Field.not_equal_any(Array([Value1, Value2]))" infix_instance = arg1.not_equal_any([arg2, arg3]) assert infix_instance == instance @@ -883,10 +856,10 @@ def test_array_contains_all(self): arg3 = self._make_arg("Element2") instance = Expr.array_contains_all(arg1, [arg2, arg3]) assert instance.name == "array_contains_all" - assert isinstance(instance.params[1], _ListOfExprs) + assert isinstance(instance.params[1], expr.Array) assert instance.params[0] == arg1 - assert instance.params[1].exprs == [arg2, arg3] - assert repr(instance) == "ArrayField.array_contains_all([Element1, Element2])" + assert instance.params[1].params == [arg2, arg3] + 
assert repr(instance) == "ArrayField.array_contains_all(Array([Element1, Element2]))" infix_instance = arg1.array_contains_all([arg2, arg3]) assert infix_instance == instance @@ -1223,14 +1196,14 @@ def test_dot_product(self): @pytest.mark.parametrize("method", ["euclidean_distance", "cosine_distance", "dot_product"]) @pytest.mark.parametrize( - "input", [Vector([1.0, 2.0]), [1, 2], Constant.of(Vector([1.0, 2.0])), Constant.of([1, 2]), []] + "input", [Vector([1.0, 2.0]), [1, 2], Constant.of(Vector([1.0, 2.0])), []] ) def test_vector_ctor(self, method, input): """ test constructing various vector expressions with different inputs """ - arg1 = self._make_arg("Vector") + arg1 = self._make_arg("VectorRef") instance = getattr(arg1, method)(input) assert instance.name == method got_second_param = instance.params[1] @@ -1358,6 +1331,26 @@ def test_array_reverse(self): infix_instance = arg1.array_reverse() assert infix_instance == instance + def test_array_concat(self): + arg1 = self._make_arg("ArrayRef1") + arg2 = self._make_arg("ArrayRef2") + instance = Expr.array_concat(arg1, arg2) + assert instance.name == "array_concat" + assert instance.params == [arg1, arg2] + assert repr(instance) == "ArrayRef1.array_concat(ArrayRef2)" + infix_instance = arg1.array_concat(arg2) + assert infix_instance == instance + + def test_array_concat_multiple(self): + arg1 = expr.Array([Constant.of(0)]) + arg2 = Field.of("ArrayRef2") + arg3 = Field.of("ArrayRef3") + arg4 = [self._make_arg("Constant")] + instance = arg1.array_concat(arg2, arg3, arg4) + assert instance.name == "array_concat" + assert instance.params == [arg1, arg2, arg3, expr.Array(arg4)] + assert repr(instance) == "Array([Constant.of(0)]).array_concat(Field.of('ArrayRef2'), Field.of('ArrayRef3'), Array([Constant]))" + def test_byte_length(self): arg1 = self._make_arg("Expr") instance = Expr.byte_length(arg1) From 1b69435e07ab61d20d5df0d59e33f325c03aceea Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Fri, 24 Oct 2025 
16:43:44 -0700 Subject: [PATCH 07/14] added map and related expressions --- .../firestore_v1/pipeline_expressions.py | 67 ++++++- tests/system/pipeline_e2e.yaml | 164 ++++++++++++++++++ tests/unit/v1/test_pipeline_expressions.py | 51 ++++++ 3 files changed, 279 insertions(+), 3 deletions(-) diff --git a/google/cloud/firestore_v1/pipeline_expressions.py b/google/cloud/firestore_v1/pipeline_expressions.py index 25544dfb5..f9cfd8fe0 100644 --- a/google/cloud/firestore_v1/pipeline_expressions.py +++ b/google/cloud/firestore_v1/pipeline_expressions.py @@ -119,6 +119,8 @@ def _cast_to_expr_or_convert_to_constant(o: Any, include_vector=False, include_a o = o.value if isinstance(o, Expr): return o + if isinstance(o, dict): + return Map(o) if isinstance(o, list): if include_vector and all([isinstance(i, (float, int)) for i in o]): return Constant(Vector(o)) @@ -1132,13 +1134,12 @@ def join(self, delimeter: Expr | str) -> "Expr": "join", [self, self._cast_to_expr_or_convert_to_constant(delimeter)] ) - @expose_as_static def map_get(self, key: str | Constant[str]) -> "Expr": """Accesses a value from the map produced by evaluating this expression. Example: - >>> Expr.map({"city": "London"}).map_get("city") + >>> Map({"city": "London"}).map_get("city") >>> Field.of("address").map_get("city") Args: @@ -1148,7 +1149,42 @@ def map_get(self, key: str | Constant[str]) -> "Expr": A new `Expr` representing the value associated with the given key in the map. """ return Function( - "map_get", [self, Constant.of(key) if isinstance(key, str) else key] + "map_get", [self, self._cast_to_expr_or_convert_to_constant(key)] + ) + + @expose_as_static + def map_remove(self, key: str | Constant[str]) -> "Expr": + """Remove a key from a the map produced by evaluating this expression. + + Example: + >>> Map({"city": "London"}).map_remove("city") + >>> Field.of("address").map_remove("city") + + Args: + key: The key to ewmove in the map. 
+ + Returns: + A new `Expr` representing the map_remove operation. + """ + return Function("map_remove", [self, self._cast_to_expr_or_convert_to_constant(key)]) + + @expose_as_static + def map_merge(self, *other_maps: Map | dict[str | Constant[str], Expr | CONSTANT_TYPE] | Expr)-> "Expr": + """Creates an expression that merges one or more dicts into a single map. + + Example: + >>> Map({"city": "London"}).map_merge({"country": "UK"}, {"isCapital": True}) + >>> Field.of("settings").map_merge({"enabled":True}, Function.conditional(Field.of('isAdmin'), {"admin":True}, {}}) + + Args: + *other_maps: Sequence of maps to merge into the resulting map. + + Returns: + A new `Expr` representing the value associated with the given key in the map. + """ + return Function("map_merge", [self] + [ + self._cast_to_expr_or_convert_to_constant(m) for m in other_maps + ], ) @expose_as_static @@ -1688,6 +1724,7 @@ def _from_query_filter_pb(filter_pb, client): else: raise TypeError(f"Unexpected filter type: {type(filter_pb)}") + class Array(Function): """ Creates an expression that creates a Firestore array value from an input list. @@ -1708,6 +1745,30 @@ def __init__(self, elements: list[Expr | CONSTANT_TYPE]): def __repr__(self): return f"Array({self.params})" + +class Map(Function): + """ + Creates an expression that creates a Firestore map value from an input dict. 
+ + Example: + >>> Expr.map({"foo": "bar", "baz": Field.of("baz")}) + + Args: + elements: THe input dict to evaluate in the expression + """ + + def __init__(self, elements: dict[str | Constant[str], Expr | CONSTANT_TYPE]): + element_list = [] + for k,v in elements.items(): + element_list.append(self._cast_to_expr_or_convert_to_constant(k)) + element_list.append(self._cast_to_expr_or_convert_to_constant(v)) + super().__init__("map", element_list) + + def __repr__(self): + d = {a.value : b for a, b in zip(self.params[::2], self.params[1::2])} + return f"Map({d})" + + class And(BooleanExpr): """ Represents an expression that performs a logical 'AND' operation on multiple filter conditions. diff --git a/tests/system/pipeline_e2e.yaml b/tests/system/pipeline_e2e.yaml index 2efd95c50..4540a281e 100644 --- a/tests/system/pipeline_e2e.yaml +++ b/tests/system/pipeline_e2e.yaml @@ -1583,6 +1583,153 @@ tests: - booleanValue: true name: equal name: where + - description: testMapGetWithField + pipeline: + - Collection: books + - Where: + - Expr.equal: + - Field: title + - Constant: "Dune" + - AddFields: + - AliasedExpr: + - Constant: "hugo" + - "award_name" + - Select: + - AliasedExpr: + - Expr.map_get: + - Field: awards + - Field: award_name + - "hugoAward" + - Field: title + assert_results: + - hugoAward: true + title: Dune + assert_proto: + pipeline: + stages: + - args: + - referenceValue: /books + name: collection + - args: + - functionValue: + args: + - fieldReferenceValue: title + - stringValue: "Dune" + name: equal + name: where + - args: + - mapValue: + fields: + award_name: + stringValue: "hugo" + name: add_fields + - args: + - mapValue: + fields: + hugoAward: + functionValue: + name: map_get + args: + - fieldReferenceValue: awards + - fieldReferenceValue: award_name + title: + fieldReferenceValue: title + name: select + - description: testMapRemove + pipeline: + - Collection: books + - Where: + - Expr.equal: + - Field: title + - Constant: "Dune" + - Select: + - 
AliasedExpr: + - Expr.map_remove: + - Field: awards + - "nebula" + - "awards_removed" + assert_results: + - awards_removed: + hugo: true + assert_proto: + pipeline: + stages: + - args: + - referenceValue: /books + name: collection + - args: + - functionValue: + args: + - fieldReferenceValue: title + - stringValue: "Dune" + name: equal + name: where + - args: + - mapValue: + fields: + awards_removed: + functionValue: + name: map_remove + args: + - fieldReferenceValue: awards + - stringValue: "nebula" + name: select + - description: testMapMerge + pipeline: + - Collection: books + - Where: + - Expr.equal: + - Field: title + - Constant: "Dune" + - Select: + - AliasedExpr: + - Expr.map_merge: + - Field: awards + - Map: + elements: {"new_award": true, "hugo": false} + - Map: + elements: {"another_award": "yes"} + - "awards_merged" + assert_results: + - awards_merged: + hugo: false + nebula: true + new_award: true + another_award: "yes" + assert_proto: + pipeline: + stages: + - args: + - referenceValue: /books + name: collection + - args: + - functionValue: + args: + - fieldReferenceValue: title + - stringValue: "Dune" + name: equal + name: where + - args: + - mapValue: + fields: + awards_merged: + functionValue: + name: map_merge + args: + - fieldReferenceValue: awards + - functionValue: + name: map + args: + - stringValue: "new_award" + - booleanValue: true + - stringValue: "hugo" + - booleanValue: false + - functionValue: + name: map + args: + - stringValue: "another_award" + - stringValue: "yes" + name: select - description: testNestedFields pipeline: - Collection: books @@ -1893,6 +2040,23 @@ tests: name: array name: array_concat name: select + - description: testMapMergeLiterals + pipeline: + - Collection: books + - Limit: 1 + - Select: + - AliasedExpr: + - Expr.map_merge: + - Map: + elements: {"a": "orig", "b": "orig"} + - Map: + elements: {"b": "new", "c": "new"} + - "merged" + assert_results: + - merged: + a: "orig" + b: "new" + c: "new" - description: 
testArrayContainsAnyWithField pipeline: - Collection: books diff --git a/tests/unit/v1/test_pipeline_expressions.py b/tests/unit/v1/test_pipeline_expressions.py index 21dec6e4c..5472fbeae 100644 --- a/tests/unit/v1/test_pipeline_expressions.py +++ b/tests/unit/v1/test_pipeline_expressions.py @@ -616,6 +616,36 @@ def test_array_w_non_list(self): with pytest.raises(TypeError): expr.Array(1) + +class TestMap: + """Tests for the map class""" + def test_map(self): + instance = expr.Map({Constant.of("a"): Constant.of("b")}) + assert instance.name == "map" + assert instance.params == [Constant.of("a"), Constant.of("b")] + assert repr(instance) == "Map({'a': Constant.of('b')})" + + def test_map_w_primitives(self): + instance = expr.Map({"a": "b", "0": 0, "bool": True}) + assert instance.params == [ + Constant.of("a"), Constant.of("b"), + Constant.of("0"), Constant.of(0), + Constant.of("bool"), Constant.of(True) + ] + assert repr(instance) == "Map({'a': Constant.of('b'), '0': Constant.of(0), 'bool': Constant.of(True)})" + + def test_empty_map(self): + instance = expr.Map({}) + assert instance.name == "map" + assert instance.params == [] + assert repr(instance) == "Map({})" + + def test_w_exprs(self): + instance = expr.Map({Constant.of("a"): expr.Array([1,2,3])}) + assert instance.params == [Constant.of("a"), expr.Array([1,2,3])] + assert repr(instance) == "Map({'a': Array([Constant.of(1), Constant.of(2), Constant.of(3)])})" + + class TestExpressionMethods: """ contains test methods for each Expr method @@ -1047,6 +1077,27 @@ def test_map_get(self): infix_instance = arg1.map_get(Constant.of(arg2)) assert infix_instance == instance + def test_map_remove(self): + arg1 = self._make_arg("Map") + arg2 = "key" + instance = Expr.map_remove(arg1, arg2) + assert instance.name == "map_remove" + assert instance.params == [arg1, Constant.of(arg2)] + assert repr(instance) == "Map.map_remove(Constant.of('key'))" + infix_instance = arg1.map_remove(Constant.of(arg2)) + assert infix_instance 
== instance + + def test_map_merge(self): + arg1 = expr.Map({"a": 1}) + arg2 = expr.Map({"b": 2}) + arg3 = {"c": 3} + instance = Expr.map_merge(arg1, arg2, arg3) + assert instance.name == "map_merge" + assert instance.params == [arg1, arg2, expr.Map(arg3)] + assert repr(instance) == "Map({'a': Constant.of(1)}).map_merge(Map({'b': Constant.of(2)}), Map({'c': Constant.of(3)}))" + infix_instance = arg1.map_merge(arg2, arg3) + assert infix_instance == instance + def test_mod(self): arg1 = self._make_arg("Left") arg2 = self._make_arg("Right") From be749e0c79fd4632fcf0fde97fc69af729560240 Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Fri, 24 Oct 2025 16:53:18 -0700 Subject: [PATCH 08/14] remove dict and list from constant types --- .../cloud/firestore_v1/pipeline_expressions.py | 16 +++++++--------- tests/system/pipeline_e2e.yaml | 2 +- tests/unit/v1/test_pipeline_expressions.py | 7 ------- 3 files changed, 8 insertions(+), 17 deletions(-) diff --git a/google/cloud/firestore_v1/pipeline_expressions.py b/google/cloud/firestore_v1/pipeline_expressions.py index f9cfd8fe0..0e43b879b 100644 --- a/google/cloud/firestore_v1/pipeline_expressions.py +++ b/google/cloud/firestore_v1/pipeline_expressions.py @@ -41,8 +41,6 @@ bytes, GeoPoint, Vector, - list, - Dict[str, Any], None, ) @@ -113,7 +111,7 @@ def _to_pb(self) -> Value: raise NotImplementedError @staticmethod - def _cast_to_expr_or_convert_to_constant(o: Any, include_vector=False, include_array=False) -> "Expr": + def _cast_to_expr_or_convert_to_constant(o: Any, include_vector=False) -> "Expr": """Convert arbitrary object to an Expr.""" if isinstance(o, Constant) and isinstance(o.value, list): o = o.value @@ -124,7 +122,7 @@ def _cast_to_expr_or_convert_to_constant(o: Any, include_vector=False, include_a if isinstance(o, list): if include_vector and all([isinstance(i, (float, int)) for i in o]): return Constant(Vector(o)) - elif include_array: + else: return Array(o) return Constant(o) @@ -583,7 +581,7 @@ def 
equal_any(self, array: Array | Sequence[Expr | CONSTANT_TYPE] | Expr) -> "Bo "equal_any", [ self, - self._cast_to_expr_or_convert_to_constant(array, include_array=True), + self._cast_to_expr_or_convert_to_constant(array), ], ) @@ -606,7 +604,7 @@ def not_equal_any(self, array: Array | list[Expr | CONSTANT_TYPE] | Expr) -> "Bo "not_equal_any", [ self, - self._cast_to_expr_or_convert_to_constant(array, include_array=True), + self._cast_to_expr_or_convert_to_constant(array), ], ) @@ -653,7 +651,7 @@ def array_contains_all( "array_contains_all", [ self, - self._cast_to_expr_or_convert_to_constant(elements, include_array=True), + self._cast_to_expr_or_convert_to_constant(elements), ], ) @@ -681,7 +679,7 @@ def array_contains_any( "array_contains_any", [ self, - self._cast_to_expr_or_convert_to_constant(elements, include_array=True), + self._cast_to_expr_or_convert_to_constant(elements), ], ) @@ -727,7 +725,7 @@ def array_concat(self, *other_arrays: Array | list[Expr | CONSTANT_TYPE] | Expr) """ return Function( "array_concat", [self] + [ - self._cast_to_expr_or_convert_to_constant(arr, include_array=True) for arr in other_arrays + self._cast_to_expr_or_convert_to_constant(arr) for arr in other_arrays ], ) diff --git a/tests/system/pipeline_e2e.yaml b/tests/system/pipeline_e2e.yaml index 4540a281e..82247e02b 100644 --- a/tests/system/pipeline_e2e.yaml +++ b/tests/system/pipeline_e2e.yaml @@ -2064,7 +2064,7 @@ tests: - AliasedExpr: - Expr.array_concat: - Field: tags - - Constant: ["Dystopian"] + - Array: ["Dystopian"] - "new_tags" - Where: - Expr.array_contains_any: diff --git a/tests/unit/v1/test_pipeline_expressions.py b/tests/unit/v1/test_pipeline_expressions.py index 5472fbeae..36fdf6c8c 100644 --- a/tests/unit/v1/test_pipeline_expressions.py +++ b/tests/unit/v1/test_pipeline_expressions.py @@ -96,13 +96,6 @@ class TestConstant: Value(timestamp_value={"seconds": 1747008000}), ), (GeoPoint(1, 2), Value(geo_point_value={"latitude": 1, "longitude": 2})), - ( - [0.0, 1.0, 
2.0], - Value( - array_value={"values": [Value(double_value=i) for i in range(3)]} - ), - ), - ({"a": "b"}, Value(map_value={"fields": {"a": Value(string_value="b")}})), ( Vector([1.0, 2.0]), Value( From 789f29cdf90eadb299d215151ab8aabe4612182f Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Fri, 24 Oct 2025 17:10:17 -0700 Subject: [PATCH 09/14] Fixed lint --- google/cloud/firestore_v1/_pipeline_stages.py | 8 +- .../firestore_v1/pipeline_expressions.py | 85 +++++++++++++------ tests/system/test_pipeline_acceptance.py | 14 ++- tests/system/test_system.py | 9 +- tests/system/test_system_async.py | 5 +- tests/unit/v1/test_aggregation.py | 2 +- tests/unit/v1/test_async_aggregation.py | 2 +- tests/unit/v1/test_pipeline_expressions.py | 50 +++++++---- 8 files changed, 119 insertions(+), 56 deletions(-) diff --git a/google/cloud/firestore_v1/_pipeline_stages.py b/google/cloud/firestore_v1/_pipeline_stages.py index 2ada1c90f..c63b748ac 100644 --- a/google/cloud/firestore_v1/_pipeline_stages.py +++ b/google/cloud/firestore_v1/_pipeline_stages.py @@ -280,9 +280,11 @@ def __init__( super().__init__("find_nearest") self.field: Expr = Field(field) if isinstance(field, str) else field self.vector: Vector = vector if isinstance(vector, Vector) else Vector(vector) - self.distance_measure = distance_measure if isinstance( - distance_measure, DistanceMeasure - ) else DistanceMeasure[distance_measure.upper()] + self.distance_measure = ( + distance_measure + if isinstance(distance_measure, DistanceMeasure) + else DistanceMeasure[distance_measure.upper()] + ) self.options = options or FindNearestOptions() def _pb_args(self): diff --git a/google/cloud/firestore_v1/pipeline_expressions.py b/google/cloud/firestore_v1/pipeline_expressions.py index 0e43b879b..3fae95494 100644 --- a/google/cloud/firestore_v1/pipeline_expressions.py +++ b/google/cloud/firestore_v1/pipeline_expressions.py @@ -17,7 +17,6 @@ Any, Generic, TypeVar, - Dict, Sequence, ) from abc import ABC @@ -334,9 +333,7 @@ 
def log(self, base: Expr | float) -> "Expr": Returns: A new `Expr` representing the logarithm. """ - return Function( - "log", [self, self._cast_to_expr_or_convert_to_constant(base)] - ) + return Function("log", [self, self._cast_to_expr_or_convert_to_constant(base)]) @expose_as_static def pow(self, exponent: Expr | float) -> "Expr": @@ -354,7 +351,9 @@ def pow(self, exponent: Expr | float) -> "Expr": Returns: A new `Expr` representing the power operation. """ - return Function("pow", [self, self._cast_to_expr_or_convert_to_constant(exponent)]) + return Function( + "pow", [self, self._cast_to_expr_or_convert_to_constant(exponent)] + ) @expose_as_static def round(self) -> "Expr": @@ -563,7 +562,9 @@ def less_than_or_equal(self, other: Expr | CONSTANT_TYPE) -> "BooleanExpr": ) @expose_as_static - def equal_any(self, array: Array | Sequence[Expr | CONSTANT_TYPE] | Expr) -> "BooleanExpr": + def equal_any( + self, array: Array | Sequence[Expr | CONSTANT_TYPE] | Expr + ) -> "BooleanExpr": """Creates an expression that checks if this expression is equal to any of the provided values or expressions. @@ -586,7 +587,9 @@ def equal_any(self, array: Array | Sequence[Expr | CONSTANT_TYPE] | Expr) -> "Bo ) @expose_as_static - def not_equal_any(self, array: Array | list[Expr | CONSTANT_TYPE] | Expr) -> "BooleanExpr": + def not_equal_any( + self, array: Array | list[Expr | CONSTANT_TYPE] | Expr + ) -> "BooleanExpr": """Creates an expression that checks if this expression is not equal to any of the provided values or expressions. @@ -710,7 +713,9 @@ def array_reverse(self) -> "Expr": return Function("array_reverse", [self]) @expose_as_static - def array_concat(self, *other_arrays: Array | list[Expr | CONSTANT_TYPE] | Expr) -> "Expr": + def array_concat( + self, *other_arrays: Array | list[Expr | CONSTANT_TYPE] | Expr + ) -> "Expr": """Creates an expression that concatenates an array expression with another array. 
Example: @@ -724,9 +729,9 @@ def array_concat(self, *other_arrays: Array | list[Expr | CONSTANT_TYPE] | Expr) A new `Expr` representing the concatenated array. """ return Function( - "array_concat", [self] + [ - self._cast_to_expr_or_convert_to_constant(arr) for arr in other_arrays - ], + "array_concat", + [self] + + [self._cast_to_expr_or_convert_to_constant(arr) for arr in other_arrays], ) @expose_as_static @@ -1093,7 +1098,9 @@ def string_reverse(self) -> "Expr": return Function("string_reverse", [self]) @expose_as_static - def substring(self, position: Expr | int, length: Expr | int | None=None) -> "Expr": + def substring( + self, position: Expr | int, length: Expr | int | None = None + ) -> "Expr": """Creates an expression that returns a substring of the results of this expression. @@ -1164,10 +1171,14 @@ def map_remove(self, key: str | Constant[str]) -> "Expr": Returns: A new `Expr` representing the map_remove operation. """ - return Function("map_remove", [self, self._cast_to_expr_or_convert_to_constant(key)]) + return Function( + "map_remove", [self, self._cast_to_expr_or_convert_to_constant(key)] + ) @expose_as_static - def map_merge(self, *other_maps: Map | dict[str | Constant[str], Expr | CONSTANT_TYPE] | Expr)-> "Expr": + def map_merge( + self, *other_maps: Map | dict[str | Constant[str], Expr | CONSTANT_TYPE] | Expr + ) -> "Expr": """Creates an expression that merges one or more dicts into a single map. Example: @@ -1180,9 +1191,9 @@ def map_merge(self, *other_maps: Map | dict[str | Constant[str], Expr | CONSTANT Returns: A new `Expr` representing the value associated with the given key in the map. 
""" - return Function("map_merge", [self] + [ - self._cast_to_expr_or_convert_to_constant(m) for m in other_maps - ], + return Function( + "map_merge", + [self] + [self._cast_to_expr_or_convert_to_constant(m) for m in other_maps], ) @expose_as_static @@ -1201,7 +1212,13 @@ def cosine_distance(self, other: Expr | list[float] | Vector) -> "Expr": Returns: A new `Expr` representing the cosine distance between the two vectors. """ - return Function("cosine_distance", [self, self._cast_to_expr_or_convert_to_constant(other, include_vector=True)]) + return Function( + "cosine_distance", + [ + self, + self._cast_to_expr_or_convert_to_constant(other, include_vector=True), + ], + ) @expose_as_static def euclidean_distance(self, other: Expr | list[float] | Vector) -> "Expr": @@ -1219,7 +1236,13 @@ def euclidean_distance(self, other: Expr | list[float] | Vector) -> "Expr": Returns: A new `Expr` representing the Euclidean distance between the two vectors. """ - return Function("euclidean_distance", [self, self._cast_to_expr_or_convert_to_constant(other, include_vector=True)]) + return Function( + "euclidean_distance", + [ + self, + self._cast_to_expr_or_convert_to_constant(other, include_vector=True), + ], + ) @expose_as_static def dot_product(self, other: Expr | list[float] | Vector) -> "Expr": @@ -1237,7 +1260,13 @@ def dot_product(self, other: Expr | list[float] | Vector) -> "Expr": Returns: A new `Expr` representing the dot product between the two vectors. 
""" - return Function("dot_product", [self, self._cast_to_expr_or_convert_to_constant(other, include_vector=True)]) + return Function( + "dot_product", + [ + self, + self._cast_to_expr_or_convert_to_constant(other, include_vector=True), + ], + ) @expose_as_static def vector_length(self) -> "Expr": @@ -1737,7 +1766,9 @@ class Array(Function): def __init__(self, elements: list[Expr | CONSTANT_TYPE]): if not isinstance(elements, list): raise TypeError("Array must be constructed with a list") - converted_elements = [self._cast_to_expr_or_convert_to_constant(el) for el in elements] + converted_elements = [ + self._cast_to_expr_or_convert_to_constant(el) for el in elements + ] super().__init__("array", converted_elements) def __repr__(self): @@ -1757,13 +1788,16 @@ class Map(Function): def __init__(self, elements: dict[str | Constant[str], Expr | CONSTANT_TYPE]): element_list = [] - for k,v in elements.items(): + for k, v in elements.items(): element_list.append(self._cast_to_expr_or_convert_to_constant(k)) element_list.append(self._cast_to_expr_or_convert_to_constant(v)) super().__init__("map", element_list) def __repr__(self): - d = {a.value : b for a, b in zip(self.params[::2], self.params[1::2])} + formatted_params = [ + a.value if isinstance(a, Constant) else a for a in self.params + ] + d = {a: b for a, b in zip(formatted_params[::2], formatted_params[1::2])} return f"Map({d})" @@ -1854,6 +1888,7 @@ def __init__(self, condition: BooleanExpr, then_expr: Expr, else_expr: Expr): "conditional", [condition, then_expr, else_expr], use_infix_repr=False ) + class Count(AggregateFunction): """ Represents an aggregation that counts the number of stage inputs with valid evaluations of the @@ -1871,6 +1906,4 @@ class Count(AggregateFunction): def __init__(self, expression: Expr | None = None): expression_list = [expression] if expression else [] - super().__init__( - "count", expression_list, use_infix_repr=bool(expression_list) - ) + super().__init__("count", expression_list, 
use_infix_repr=bool(expression_list)) diff --git a/tests/system/test_pipeline_acceptance.py b/tests/system/test_pipeline_acceptance.py index 3c41c0227..313b9d673 100644 --- a/tests/system/test_pipeline_acceptance.py +++ b/tests/system/test_pipeline_acceptance.py @@ -95,7 +95,9 @@ def test_pipeline_results(test_dict, client): Ensure pipeline returns expected results """ expected_results = _parse_yaml_types(test_dict.get("assert_results", None)) - expected_approximate_results = _parse_yaml_types(test_dict.get("assert_results_approximate", None)) + expected_approximate_results = _parse_yaml_types( + test_dict.get("assert_results_approximate", None) + ) expected_count = test_dict.get("assert_count", None) pipeline = parse_pipeline(client, test_dict["pipeline"]) # check if server responds as expected @@ -139,7 +141,9 @@ async def test_pipeline_results_async(test_dict, async_client): Ensure pipeline returns expected results """ expected_results = _parse_yaml_types(test_dict.get("assert_results", None)) - expected_approximate_results = _parse_yaml_types(test_dict.get("assert_results_approximate", None)) + expected_approximate_results = _parse_yaml_types( + test_dict.get("assert_results_approximate", None) + ) expected_count = test_dict.get("assert_count", None) pipeline = parse_pipeline(async_client, test_dict["pipeline"]) # check if server responds as expected @@ -225,7 +229,11 @@ def _apply_yaml_args_to_callable(callable_obj, client, yaml_args): """ if isinstance(yaml_args, dict): return callable_obj(**_parse_expressions(client, yaml_args)) - elif isinstance(yaml_args, list) and not (callable_obj == expr.Constant or callable_obj == Vector or callable_obj == expr.Array): + elif isinstance(yaml_args, list) and not ( + callable_obj == expr.Constant + or callable_obj == Vector + or callable_obj == expr.Array + ): # yaml has an array of arguments. 
Treat as args return callable_obj(*_parse_expressions(client, yaml_args)) else: diff --git a/tests/system/test_system.py b/tests/system/test_system.py index a8f94e2ba..c2bd93ef8 100644 --- a/tests/system/test_system.py +++ b/tests/system/test_system.py @@ -109,9 +109,11 @@ def _clean_results(results): if isinstance(query, BaseAggregationQuery): # aggregation queries return a list of lists of aggregation results query_results = _clean_results( - list(itertools.chain.from_iterable( - [[a._to_dict() for a in s] for s in query.get()] - )) + list( + itertools.chain.from_iterable( + [[a._to_dict() for a in s] for s in query.get()] + ) + ) ) else: # other qureies return a simple list of results @@ -1531,6 +1533,7 @@ def test_query_stream_or_get_w_no_explain_options(query_docs, database, method): results.get_explain_metrics() verify_pipeline(query) + @pytest.mark.skipif( FIRESTORE_EMULATOR, reason="Query profile not supported in emulator." ) diff --git a/tests/system/test_system_async.py b/tests/system/test_system_async.py index b78a77786..d053cbd7a 100644 --- a/tests/system/test_system_async.py +++ b/tests/system/test_system_async.py @@ -208,7 +208,9 @@ def _clean_results(results): await pipeline.execute() else: # ensure results match query - pipeline_results = _clean_results([s.data() async for s in pipeline.stream()]) + pipeline_results = _clean_results( + [s.data() async for s in pipeline.stream()] + ) assert query_results == pipeline_results except FailedPrecondition as e: # if testing against a non-enterprise db, skip this check @@ -216,7 +218,6 @@ def _clean_results(results): raise e - @pytest.fixture(scope="module") def event_loop(): """Change event_loop fixture to module level.""" diff --git a/tests/unit/v1/test_aggregation.py b/tests/unit/v1/test_aggregation.py index 5064e87ae..9a20fd386 100644 --- a/tests/unit/v1/test_aggregation.py +++ b/tests/unit/v1/test_aggregation.py @@ -1136,7 +1136,7 @@ def test_aggreation_to_pipeline_count_increment(): assert 
len(aggregate_stage.accumulators) == n for i in range(n): assert isinstance(aggregate_stage.accumulators[i].expr, Count) - assert aggregate_stage.accumulators[i].alias == f"field_{i+1}" + assert aggregate_stage.accumulators[i].alias == f"field_{i + 1}" def test_aggreation_to_pipeline_complex(): diff --git a/tests/unit/v1/test_async_aggregation.py b/tests/unit/v1/test_async_aggregation.py index fdd4a1450..701feab5b 100644 --- a/tests/unit/v1/test_async_aggregation.py +++ b/tests/unit/v1/test_async_aggregation.py @@ -810,7 +810,7 @@ def test_aggreation_to_pipeline_count_increment(): assert len(aggregate_stage.accumulators) == n for i in range(n): assert isinstance(aggregate_stage.accumulators[i].expr, Count) - assert aggregate_stage.accumulators[i].alias == f"field_{i+1}" + assert aggregate_stage.accumulators[i].alias == f"field_{i + 1}" def test_async_aggreation_to_pipeline_complex(): diff --git a/tests/unit/v1/test_pipeline_expressions.py b/tests/unit/v1/test_pipeline_expressions.py index 36fdf6c8c..bfd8a8270 100644 --- a/tests/unit/v1/test_pipeline_expressions.py +++ b/tests/unit/v1/test_pipeline_expressions.py @@ -398,9 +398,7 @@ def test__from_query_filter_pb_composite_filter_nested(self, mock_client): field3 = Field.of("field3") expected_cond1 = expr.And(field1.exists(), field1.equal(Constant("val1"))) expected_cond2 = expr.And(field2.exists(), field2.greater_than(Constant(10))) - expected_cond3 = expr.And( - field3.exists(), field3.is_not_null() - ) + expected_cond3 = expr.And(field3.exists(), field3.is_not_null()) expected_inner_and = expr.And(expected_cond2, expected_cond3) expected_outer_or = expr.Or(expected_cond1, expected_inner_and) @@ -586,6 +584,7 @@ def test__from_query_filter_pb_unknown_filter_type(self, mock_client): class TestArray: """Tests for the array class""" + def test_array(self): arg1 = Field.of("field1") instance = expr.Array([arg1]) @@ -612,20 +611,24 @@ def test_array_w_non_list(self): class TestMap: """Tests for the map class""" + def 
test_map(self): instance = expr.Map({Constant.of("a"): Constant.of("b")}) assert instance.name == "map" assert instance.params == [Constant.of("a"), Constant.of("b")] - assert repr(instance) == "Map({'a': Constant.of('b')})" + assert repr(instance) == "Map({'a': 'b'})" def test_map_w_primitives(self): instance = expr.Map({"a": "b", "0": 0, "bool": True}) assert instance.params == [ - Constant.of("a"), Constant.of("b"), - Constant.of("0"), Constant.of(0), - Constant.of("bool"), Constant.of(True) + Constant.of("a"), + Constant.of("b"), + Constant.of("0"), + Constant.of(0), + Constant.of("bool"), + Constant.of(True), ] - assert repr(instance) == "Map({'a': Constant.of('b'), '0': Constant.of(0), 'bool': Constant.of(True)})" + assert repr(instance) == "Map({'a': 'b', '0': 0, 'bool': True})" def test_empty_map(self): instance = expr.Map({}) @@ -634,9 +637,12 @@ def test_empty_map(self): assert repr(instance) == "Map({})" def test_w_exprs(self): - instance = expr.Map({Constant.of("a"): expr.Array([1,2,3])}) - assert instance.params == [Constant.of("a"), expr.Array([1,2,3])] - assert repr(instance) == "Map({'a': Array([Constant.of(1), Constant.of(2), Constant.of(3)])})" + instance = expr.Map({Constant.of("a"): expr.Array([1, 2, 3])}) + assert instance.params == [Constant.of("a"), expr.Array([1, 2, 3])] + assert ( + repr(instance) + == "Map({'a': Array([Constant.of(1), Constant.of(2), Constant.of(3)])})" + ) class TestExpressionMethods: @@ -722,7 +728,10 @@ def test_array_contains_any(self): assert isinstance(instance.params[1], expr.Array) assert instance.params[0] == arg1 assert instance.params[1].params == [arg2, arg3] - assert repr(instance) == "ArrayField.array_contains_any(Array([Element1, Element2]))" + assert ( + repr(instance) + == "ArrayField.array_contains_any(Array([Element1, Element2]))" + ) infix_instance = arg1.array_contains_any([arg2, arg3]) assert infix_instance == instance @@ -882,7 +891,10 @@ def test_array_contains_all(self): assert 
isinstance(instance.params[1], expr.Array) assert instance.params[0] == arg1 assert instance.params[1].params == [arg2, arg3] - assert repr(instance) == "ArrayField.array_contains_all(Array([Element1, Element2]))" + assert ( + repr(instance) + == "ArrayField.array_contains_all(Array([Element1, Element2]))" + ) infix_instance = arg1.array_contains_all([arg2, arg3]) assert infix_instance == instance @@ -1087,7 +1099,7 @@ def test_map_merge(self): instance = Expr.map_merge(arg1, arg2, arg3) assert instance.name == "map_merge" assert instance.params == [arg1, arg2, expr.Map(arg3)] - assert repr(instance) == "Map({'a': Constant.of(1)}).map_merge(Map({'b': Constant.of(2)}), Map({'c': Constant.of(3)}))" + assert repr(instance) == "Map({'a': 1}).map_merge(Map({'b': 2}), Map({'c': 3}))" infix_instance = arg1.map_merge(arg2, arg3) assert infix_instance == instance @@ -1238,7 +1250,9 @@ def test_dot_product(self): infix_instance = arg1.dot_product(arg2) assert infix_instance == instance - @pytest.mark.parametrize("method", ["euclidean_distance", "cosine_distance", "dot_product"]) + @pytest.mark.parametrize( + "method", ["euclidean_distance", "cosine_distance", "dot_product"] + ) @pytest.mark.parametrize( "input", [Vector([1.0, 2.0]), [1, 2], Constant.of(Vector([1.0, 2.0])), []] ) @@ -1254,7 +1268,6 @@ def test_vector_ctor(self, method, input): assert isinstance(got_second_param, Constant) assert isinstance(got_second_param.value, Vector) - def test_vector_length(self): arg1 = self._make_arg("Array") instance = Expr.vector_length(arg1) @@ -1393,7 +1406,10 @@ def test_array_concat_multiple(self): instance = arg1.array_concat(arg2, arg3, arg4) assert instance.name == "array_concat" assert instance.params == [arg1, arg2, arg3, expr.Array(arg4)] - assert repr(instance) == "Array([Constant.of(0)]).array_concat(Field.of('ArrayRef2'), Field.of('ArrayRef3'), Array([Constant]))" + assert ( + repr(instance) + == "Array([Constant.of(0)]).array_concat(Field.of('ArrayRef2'), 
Field.of('ArrayRef3'), Array([Constant]))" + ) def test_byte_length(self): arg1 = self._make_arg("Expr") From 64be10db7fa558a87bf096fad49f8d920c43b384 Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Mon, 27 Oct 2025 14:27:38 -0700 Subject: [PATCH 10/14] added count_if and count_distinct --- .../firestore_v1/pipeline_expressions.py | 29 ++++++++++ tests/system/pipeline_e2e.yaml | 58 +++++++++++++++++++ tests/unit/v1/test_pipeline_expressions.py | 18 ++++++ 3 files changed, 105 insertions(+) diff --git a/google/cloud/firestore_v1/pipeline_expressions.py b/google/cloud/firestore_v1/pipeline_expressions.py index 3fae95494..516349041 100644 --- a/google/cloud/firestore_v1/pipeline_expressions.py +++ b/google/cloud/firestore_v1/pipeline_expressions.py @@ -852,6 +852,35 @@ def count(self) -> "Expr": """ return AggregateFunction("count", [self]) + @expose_as_static + def count_if(self) -> "Expr": + """Creates an aggregation that counts the number of values of the provided field or expression + that evaluate to True. + + Example: + >>> # Count the number of adults + >>> Field.of("age").greater_than(18).count_if().as_("totalAdults") + + + Returns: + A new `AggregateFunction` representing the 'count_if' aggregation. + """ + return AggregateFunction("count_if", [self]) + + @expose_as_static + def count_distinct(self) -> "Expr": + """Creates an aggregation that counts the number of distinct values of the + provided field or expression. + + Example: + >>> # Count the total number of countries in the data + >>> Field.of("country").count_distinct().as_("totalCountries") + + Returns: + A new `AggregateFunction` representing the 'count_distinct' aggregation. + """ + return AggregateFunction("count_distinct", [self]) + @expose_as_static def minimum(self) -> "Expr": """Creates an aggregation that finds the minimum value of a field across multiple stage inputs. 
diff --git a/tests/system/pipeline_e2e.yaml b/tests/system/pipeline_e2e.yaml index 82247e02b..93e02f3a2 100644 --- a/tests/system/pipeline_e2e.yaml +++ b/tests/system/pipeline_e2e.yaml @@ -167,6 +167,64 @@ tests: - fieldReferenceValue: rating - mapValue: {} name: aggregate + - description: "testAggregates - count_if" + pipeline: + - Collection: books + - Aggregate: + - AliasedExpr: + - Expr.count_if: + - Expr.greater_than: + - Field: rating + - Constant: 4.2 + - "count_if_rating_gt_4_2" + assert_results: + - count_if_rating_gt_4_2: 5 + assert_proto: + pipeline: + stages: + - args: + - referenceValue: /books + name: collection + - args: + - mapValue: + fields: + count_if_rating_gt_4_2: + functionValue: + name: count_if + args: + - functionValue: + name: greater_than + args: + - fieldReferenceValue: rating + - doubleValue: 4.2 + - mapValue: {} + name: aggregate + - description: "testAggregates - count_distinct" + pipeline: + - Collection: books + - Aggregate: + - AliasedExpr: + - Expr.count_distinct: + - Field: genre + - "distinct_genres" + assert_results: + - distinct_genres: 8 + assert_proto: + pipeline: + stages: + - args: + - referenceValue: /books + name: collection + - args: + - mapValue: + fields: + distinct_genres: + functionValue: + name: count_distinct + args: + - fieldReferenceValue: genre + - mapValue: {} + name: aggregate - description: "testAggregates - avg, count, max" pipeline: - Collection: books diff --git a/tests/unit/v1/test_pipeline_expressions.py b/tests/unit/v1/test_pipeline_expressions.py index bfd8a8270..45d26dca6 100644 --- a/tests/unit/v1/test_pipeline_expressions.py +++ b/tests/unit/v1/test_pipeline_expressions.py @@ -1471,6 +1471,24 @@ def test_base_count(self): assert instance.params == [] assert repr(instance) == "Count()" + def test_count_if(self): + arg1 = self._make_arg("Value") + instance = Expr.count_if(arg1) + assert instance.name == "count_if" + assert instance.params == [arg1] + assert repr(instance) == "Value.count_if()" + 
infix_instance = arg1.count_if() + assert infix_instance == instance + + def test_count_distinct(self): + arg1 = self._make_arg("Value") + instance = Expr.count_distinct(arg1) + assert instance.name == "count_distinct" + assert instance.params == [arg1] + assert repr(instance) == "Value.count_distinct()" + infix_instance = arg1.count_distinct() + assert infix_instance == instance + def test_minimum(self): arg1 = self._make_arg("Value") instance = Expr.minimum(arg1) From 5d4f8783473095408c735fa8a7d1b55bbf296965 Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Mon, 27 Oct 2025 15:34:22 -0700 Subject: [PATCH 11/14] added misc expressions --- .../firestore_v1/pipeline_expressions.py | 61 +++++++ tests/system/pipeline_e2e.yaml | 157 +++++++++++++++++- tests/system/test_pipeline_acceptance.py | 12 +- tests/unit/v1/test_pipeline_expressions.py | 46 ++++- 4 files changed, 267 insertions(+), 9 deletions(-) diff --git a/google/cloud/firestore_v1/pipeline_expressions.py b/google/cloud/firestore_v1/pipeline_expressions.py index 516349041..7a582f2f2 100644 --- a/google/cloud/firestore_v1/pipeline_expressions.py +++ b/google/cloud/firestore_v1/pipeline_expressions.py @@ -335,6 +335,18 @@ def log(self, base: Expr | float) -> "Expr": """ return Function("log", [self, self._cast_to_expr_or_convert_to_constant(base)]) + @expose_as_static + def log10(self) -> "Expr": + """Creates an expression that calculates the base 10 logarithm of this expression. + + Example: + >>> Field.of("value").log10() + + Returns: + A new `Expr` representing the logarithm. + """ + return Function("log10", [self]) + @expose_as_static def pow(self, exponent: Expr | float) -> "Expr": """Creates an expression that calculates this expression raised to the power of the exponent. 
@@ -734,6 +746,32 @@ def array_concat( + [self._cast_to_expr_or_convert_to_constant(arr) for arr in other_arrays], ) + @expose_as_static + def concat(self, *others: Expr | CONSTANT_TYPE) -> "Expr": + """Creates an expression that concatenates expressions together + + Args: + *others: The expressions to concatenate. + + Returns: + A new `Expr` representing the concatenated value. + """ + return Function("concat", [self] + [self._cast_to_expr_or_convert_to_constant(o) for o in others]) + + @expose_as_static + def length(self) -> "Expr": + """ + Creates an expression that calculates the length of the expression if it is a string, array, map, or blob. + + Example: + >>> # Get the length of the 'name' field. + >>> Field.of("name").length() + + Returns: + A new `Expr` representing the length of the expression. + """ + return Function("length", [self]) + @expose_as_static def is_absent(self) -> "BooleanExpr": """Creates an expression that returns true if a value is absent. Otherwise, returns false even if @@ -1467,6 +1505,19 @@ def collection_id(self): """ return Function("collection_id", [self]) + @expose_as_static + def document_id(self): + """Creates an expression that returns the document ID from a path. + + Example: + >>> # Get the document ID from a path. + >>> Field.of("__name__").document_id() + + Returns: + A new `Expr` representing the document ID. + """ + return Function("document_id", [self]) + def ascending(self) -> Ordering: """Creates an `Ordering` that sorts documents in ascending order based on this expression. @@ -1936,3 +1987,13 @@ class Count(AggregateFunction): def __init__(self, expression: Expr | None = None): expression_list = [expression] if expression else [] super().__init__("count", expression_list, use_infix_repr=bool(expression_list)) + +class CurrentTimestamp(Function): + """Creates an expression that returns the current timestamp + + Returns: + A new `Expr` representing the current timestamp. 
+ """ + + def __init__(self): + super().__init__("current_timestamp", [], use_infix_repr=False) \ No newline at end of file diff --git a/tests/system/pipeline_e2e.yaml b/tests/system/pipeline_e2e.yaml index 93e02f3a2..215f761fd 100644 --- a/tests/system/pipeline_e2e.yaml +++ b/tests/system/pipeline_e2e.yaml @@ -993,7 +993,57 @@ tests: expression: fieldReferenceValue: title name: sort + - description: testConcat + pipeline: + - Collection: books + - Where: + - Expr.equal: + - Field: title + - Constant: "The Hitchhiker's Guide to the Galaxy" + - Select: + - AliasedExpr: + - Expr.concat: + - Field: author + - Constant: ": " + - Field: title + - "author_title" + - AliasedExpr: + - Expr.concat: + - Field: tags + - - Constant: "new_tag" + - "concatenatedTags" + assert_results: + - author_title: "Douglas Adams: The Hitchhiker's Guide to the Galaxy" + concatenatedTags: + - comedy + - space + - adventure + - new_tag - description: testLength + pipeline: + - Collection: books + - Where: + - Expr.equal: + - Field: title + - Constant: "The Hitchhiker's Guide to the Galaxy" + - Select: + - AliasedExpr: + - Expr.length: + - Field: title + - "titleLength" + - AliasedExpr: + - Expr.length: + - Field: tags + - "tagsLength" + - AliasedExpr: + - Expr.length: + - Field: awards + - "awardsLength" + assert_results: + - titleLength: 36 + tagsLength: 3 + awardsLength: 2 + - description: testCharLength pipeline: - Collection: books - Select: @@ -1998,6 +2048,95 @@ tests: - adventure - space - comedy + - description: testDocumentId + pipeline: + - Collection: books + - Where: + - Expr.equal: + - Field: title + - Constant: "The Hitchhiker's Guide to the Galaxy" + - Select: + - AliasedExpr: + - Expr.document_id: + - Field: __name__ + - "doc_id" + assert_results: + - doc_id: "book1" + assert_proto: + pipeline: + stages: + - args: + - referenceValue: /books + name: collection + - args: + - functionValue: + args: + - fieldReferenceValue: title + - stringValue: "The Hitchhiker's Guide to the 
Galaxy" + name: equal + name: where + - args: + - mapValue: + fields: + doc_id: + functionValue: + name: document_id + args: + - fieldReferenceValue: __name__ + name: select + - description: testCurrentTimestamp + pipeline: + - Collection: books + - Limit: 1 + - Select: + - AliasedExpr: + - And: + - Expr.greater_than_or_equal: + - CurrentTimestamp: [] + - Expr.unix_seconds_to_timestamp: + - Constant: 1735689600 # 2025-01-01 + - Expr.less_than: + - CurrentTimestamp: [] + - Expr.unix_seconds_to_timestamp: + - Constant: 4892438400 # 2125-01-01 + - "is_between_2025_and_2125" + assert_results: + - is_between_2025_and_2125: true + assert_proto: + pipeline: + stages: + - args: + - referenceValue: /books + name: collection + - args: + - integerValue: '1' + name: limit + - args: + - mapValue: + fields: + is_between_2025_and_2125: + functionValue: + name: and + args: + - functionValue: + name: greater_than_or_equal + args: + - functionValue: + name: current_timestamp + - functionValue: + name: unix_seconds_to_timestamp + args: + - integerValue: '1735689600' + - functionValue: + name: less_than + args: + - functionValue: + name: current_timestamp + - functionValue: + name: unix_seconds_to_timestamp + args: + - integerValue: '4892438400' + name: select - description: testArrayConcat pipeline: - Collection: books @@ -2532,10 +2671,14 @@ tests: - Field: rating - "ln_rating" - AliasedExpr: - - Expr.log: + - Expr.log10: - Field: rating - - Constant: 10 - "log_rating_base10" + - AliasedExpr: + - Expr.log: + - Field: rating + - Constant: 2 + - "log_rating_base2" - AliasedExpr: - Expr.pow: - Field: rating @@ -2552,6 +2695,7 @@ tests: floor_rating: 4.0 ln_rating: 1.4350845 log_rating_base10: 0.623249 + log_rating_base2: 2.0704 pow_rating: 17.64 sqrt_rating: 2.049390 assert_proto: @@ -2599,7 +2743,12 @@ tests: functionValue: args: - fieldReferenceValue: rating - - integerValue: '10' + name: log10 + log_rating_base2: + functionValue: + args: + - fieldReferenceValue: rating + - 
integerValue: '2' name: log pow_rating: functionValue: @@ -3037,4 +3186,4 @@ tests: - fieldReferenceValue: tags - stringValue: ", " name: join - name: select \ No newline at end of file + name: select diff --git a/tests/system/test_pipeline_acceptance.py b/tests/system/test_pipeline_acceptance.py index 313b9d673..3b3e6189d 100644 --- a/tests/system/test_pipeline_acceptance.py +++ b/tests/system/test_pipeline_acceptance.py @@ -87,7 +87,7 @@ def test_pipeline_expected_errors(test_dict, client): @pytest.mark.parametrize( "test_dict", - [t for t in yaml_loader() if "assert_results" in t or "assert_count" in t], + [t for t in yaml_loader() if "assert_results" in t or "assert_count" in t or "assert_results_approximate" in t], ids=lambda x: f"{x.get('description', '')}", ) def test_pipeline_results(test_dict, client): @@ -105,7 +105,9 @@ def test_pipeline_results(test_dict, client): if expected_results: assert got_results == expected_results if expected_approximate_results: - assert got_results == pytest.approximate(expected_approximate_results) + assert len(got_results) == len(expected_approximate_results), "got unexpected result count" + for idx in range(len(got_results)): + assert got_results[idx] == pytest.approx(expected_approximate_results[idx], abs=1e-4) if expected_count is not None: assert len(got_results) == expected_count @@ -132,7 +134,7 @@ async def test_pipeline_expected_errors_async(test_dict, async_client): @pytest.mark.parametrize( "test_dict", - [t for t in yaml_loader() if "assert_results" in t or "assert_count" in t], + [t for t in yaml_loader() if "assert_results" in t or "assert_count" in t or "assert_results_approximate" in t], ids=lambda x: f"{x.get('description', '')}", ) @pytest.mark.asyncio @@ -151,7 +153,9 @@ async def test_pipeline_results_async(test_dict, async_client): if expected_results: assert got_results == expected_results if expected_approximate_results: - assert got_results == pytest.approximate(expected_approximate_results) + assert 
len(got_results) == len(expected_approximate_results), "got unexpected result count" + for idx in range(len(got_results)): + assert got_results[idx] == pytest.approx(expected_approximate_results[idx], abs=1e-4) if expected_count is not None: assert len(got_results) == expected_count diff --git a/tests/unit/v1/test_pipeline_expressions.py b/tests/unit/v1/test_pipeline_expressions.py index 45d26dca6..28a5973de 100644 --- a/tests/unit/v1/test_pipeline_expressions.py +++ b/tests/unit/v1/test_pipeline_expressions.py @@ -1144,6 +1144,12 @@ def test_subtract(self): infix_instance = arg1.subtract(arg2) assert infix_instance == instance + def test_current_timestamp(self): + instance = expr.CurrentTimestamp() + assert instance.name == "current_timestamp" + assert instance.params == [] + assert repr(instance) == "CurrentTimestamp()" + def test_timestamp_add(self): arg1 = self._make_arg("Timestamp") arg2 = self._make_arg("Unit") @@ -1342,6 +1348,15 @@ def test_log(self): infix_instance = arg1.log(arg2) assert infix_instance == instance + def test_log10(self): + arg1 = self._make_arg("Value") + instance = Expr.log10(arg1) + assert instance.name == "log10" + assert instance.params == [arg1] + assert repr(instance) == "Value.log10()" + infix_instance = arg1.log10() + assert infix_instance == instance + def test_pow(self): arg1 = self._make_arg("Value") arg2 = self._make_arg("Exponent") @@ -1429,6 +1444,26 @@ def test_char_length(self): infix_instance = arg1.char_length() assert infix_instance == instance + def test_concat(self): + arg1 = self._make_arg("First") + arg2 = self._make_arg("Second") + arg3 = "Third" + instance = Expr.concat(arg1, arg2, arg3) + assert instance.name == "concat" + assert instance.params == [arg1, arg2, Constant.of(arg3)] + assert repr(instance) == "First.concat(Second, Constant.of('Third'))" + infix_instance = arg1.concat(arg2, arg3) + assert infix_instance == instance + + def test_length(self): + arg1 = self._make_arg("Expr") + instance = 
Expr.length(arg1) + assert instance.name == "length" + assert instance.params == [arg1] + assert repr(instance) == "Expr.length()" + infix_instance = arg1.length() + assert infix_instance == instance + def test_collection_id(self): arg1 = self._make_arg("Value") instance = Expr.collection_id(arg1) @@ -1438,6 +1473,15 @@ def test_collection_id(self): infix_instance = arg1.collection_id() assert infix_instance == instance + def test_document_id(self): + arg1 = self._make_arg("Value") + instance = Expr.document_id(arg1) + assert instance.name == "document_id" + assert instance.params == [arg1] + assert repr(instance) == "Value.document_id()" + infix_instance = arg1.document_id() + assert infix_instance == instance + def test_sum(self): arg1 = self._make_arg("Value") instance = Expr.sum(arg1) @@ -1505,4 +1549,4 @@ def test_maximum(self): assert instance.params == [arg1] assert repr(instance) == "Value.maximum()" infix_instance = arg1.maximum() - assert infix_instance == instance + assert infix_instance == instance \ No newline at end of file From 6d6c57f651cd2c5d9509ec8a25a0d874328e42e3 Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Mon, 27 Oct 2025 16:05:23 -0700 Subject: [PATCH 12/14] added error functions --- .../firestore_v1/pipeline_expressions.py | 51 ++++++++ tests/system/pipeline_e2e.yaml | 114 ++++++++++++++++++ tests/unit/v1/test_pipeline_expressions.py | 30 +++++ 3 files changed, 195 insertions(+) diff --git a/google/cloud/firestore_v1/pipeline_expressions.py b/google/cloud/firestore_v1/pipeline_expressions.py index 7a582f2f2..5fc425642 100644 --- a/google/cloud/firestore_v1/pipeline_expressions.py +++ b/google/cloud/firestore_v1/pipeline_expressions.py @@ -786,6 +786,24 @@ def is_absent(self) -> "BooleanExpr": """ return BooleanExpr("is_absent", [self]) + @expose_as_static + def if_absent(self, default_value: Expr | CONSTANT_TYPE) -> "Expr": + """Creates an expression that returns a default value if an expression evaluates to an absent value. 
+ + Example: + >>> # Return the value of the 'email' field, or "N/A" if it's absent. + >>> Field.of("email").if_absent("N/A") + + Args: + default_value: The expression or constant value to return if this expression is absent. + + Returns: + A new `Expr` representing the ifAbsent operation. + """ + return Function( + "if_absent", [self, self._cast_to_expr_or_convert_to_constant(default_value)] + ) + @expose_as_static def is_nan(self) -> "BooleanExpr": """Creates an expression that checks if this expression evaluates to 'NaN' (Not a Number). @@ -836,6 +854,38 @@ def is_not_null(self) -> "BooleanExpr": """ return BooleanExpr("is_not_null", [self]) + @expose_as_static + def is_error(self): + """Creates an expression that checks if a given expression produces an error + + Example: + >>> # Resolves to True if an expression produces an error + >>> Field.of("value").divide("string").is_error() + + Returns: + A new `Expr` representing the isError operation. + """ + return Function("is_error", [self]) + + @expose_as_static + def if_error(self, then_value: Expr | CONSTANT_TYPE) -> "Expr": + """Creates an expression that returns ``then_value`` if this expression evaluates to an error. + Otherwise, returns the value of this expression. + + Example: + >>> # Resolves to 0 if an expression produces an error + >>> Field.of("value").divide("string").if_error(0) + + Args: + then_value: The value to return if this expression evaluates to an error. + + Returns: + A new `Expr` representing the ifError operation. + """ + return Function( + "if_error", [self, self._cast_to_expr_or_convert_to_constant(then_value)] + ) + @expose_as_static def exists(self) -> "BooleanExpr": """Creates an expression that checks if a field exists in the document. 
@@ -1988,6 +2038,7 @@ def __init__(self, expression: Expr | None = None): expression_list = [expression] if expression else [] super().__init__("count", expression_list, use_infix_repr=bool(expression_list)) + class CurrentTimestamp(Function): """Creates an expression that returns the current timestamp diff --git a/tests/system/pipeline_e2e.yaml b/tests/system/pipeline_e2e.yaml index 215f761fd..38595224a 100644 --- a/tests/system/pipeline_e2e.yaml +++ b/tests/system/pipeline_e2e.yaml @@ -1585,6 +1585,120 @@ tests: - fieldReferenceValue: awards.pulitzer name: is_absent name: where + - description: testIfAbsent + pipeline: + - Collection: books + - Select: + - AliasedExpr: + - Expr.if_absent: + - Field: awards.pulitzer + - Constant: false + - "pulitzer_award" + - title + - Where: + - Expr.equal: + - Field: pulitzer_award + - Constant: true + assert_results: + - pulitzer_award: true + title: To Kill a Mockingbird + assert_proto: + pipeline: + stages: + - args: + - referenceValue: /books + name: collection + - args: + - mapValue: + fields: + pulitzer_award: + functionValue: + name: if_absent + args: + - fieldReferenceValue: awards.pulitzer + - booleanValue: false + title: + fieldReferenceValue: title + name: select + - args: + - functionValue: + args: + - fieldReferenceValue: pulitzer_award + - booleanValue: true + name: equal + name: where + - description: testIsError + pipeline: + - Collection: books + - Select: + - AliasedExpr: + - Expr.is_error: + - Expr.divide: + - Field: rating + - Constant: "string" + - "is_error_result" + - Limit: 1 + assert_results: + - is_error_result: true + assert_proto: + pipeline: + stages: + - args: + - referenceValue: /books + name: collection + - args: + - mapValue: + fields: + is_error_result: + functionValue: + name: is_error + args: + - functionValue: + name: divide + args: + - fieldReferenceValue: rating + - stringValue: "string" + name: select + - args: + - integerValue: '1' + name: limit + - description: testIfError + pipeline: + 
- Collection: books + - Select: + - AliasedExpr: + - Expr.if_error: + - Expr.divide: + - Field: rating + - Field: genre + - Constant: "An error occurred" + - "if_error_result" + - Limit: 1 + assert_results: + - if_error_result: "An error occurred" + assert_proto: + pipeline: + stages: + - args: + - referenceValue: /books + name: collection + - args: + - mapValue: + fields: + if_error_result: + functionValue: + name: if_error + args: + - functionValue: + name: divide + args: + - fieldReferenceValue: rating + - fieldReferenceValue: genre + - stringValue: "An error occurred" + name: select + - args: + - integerValue: '1' + name: limit - description: testLogicalMinMax pipeline: - Collection: books diff --git a/tests/unit/v1/test_pipeline_expressions.py b/tests/unit/v1/test_pipeline_expressions.py index 28a5973de..2c3b97259 100644 --- a/tests/unit/v1/test_pipeline_expressions.py +++ b/tests/unit/v1/test_pipeline_expressions.py @@ -839,6 +839,17 @@ def test_is_absent(self): infix_instance = arg1.is_absent() assert infix_instance == instance + + def test_if_absent(self): + arg1 = self._make_arg("Field") + arg2 = self._make_arg("ThenExpr") + instance = Expr.if_absent(arg1, arg2) + assert instance.name == "if_absent" + assert instance.params == [arg1, arg2] + assert repr(instance) == "Field.if_absent(ThenExpr)" + infix_instance = arg1.if_absent(arg2) + assert infix_instance == instance + def test_is_nan(self): arg1 = self._make_arg("Value") instance = Expr.is_nan(arg1) @@ -875,6 +886,25 @@ def test_is_not_null(self): infix_instance = arg1.is_not_null() assert infix_instance == instance + def test_is_error(self): + arg1 = self._make_arg("Value") + instance = Expr.is_error(arg1) + assert instance.name == "is_error" + assert instance.params == [arg1] + assert repr(instance) == "Value.is_error()" + infix_instance = arg1.is_error() + assert infix_instance == instance + + def test_if_error(self): + arg1 = self._make_arg("Value") + arg2 = self._make_arg("ThenExpr") + instance = 
Expr.if_error(arg1, arg2) + assert instance.name == "if_error" + assert instance.params == [arg1, arg2] + assert repr(instance) == "Value.if_error(ThenExpr)" + infix_instance = arg1.if_error(arg2) + assert infix_instance == instance + def test_not(self): arg1 = self._make_arg("Condition") instance = expr.Not(arg1) From f1690d84df4dd596cc873c0035e4f8a1ac78fe24 Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Mon, 27 Oct 2025 16:09:21 -0700 Subject: [PATCH 13/14] fixed lint --- .../firestore_v1/pipeline_expressions.py | 12 ++++--- tests/system/test_pipeline_acceptance.py | 32 +++++++++++++++---- tests/unit/v1/test_pipeline_expressions.py | 3 +- 3 files changed, 35 insertions(+), 12 deletions(-) diff --git a/google/cloud/firestore_v1/pipeline_expressions.py b/google/cloud/firestore_v1/pipeline_expressions.py index 5fc425642..93ceca265 100644 --- a/google/cloud/firestore_v1/pipeline_expressions.py +++ b/google/cloud/firestore_v1/pipeline_expressions.py @@ -756,7 +756,10 @@ def concat(self, *others: Expr | CONSTANT_TYPE) -> "Expr": Returns: A new `Expr` representing the concatenated value. """ - return Function("concat", [self] + [self._cast_to_expr_or_convert_to_constant(o) for o in others]) + return Function( + "concat", + [self] + [self._cast_to_expr_or_convert_to_constant(o) for o in others], + ) @expose_as_static def length(self) -> "Expr": @@ -801,7 +804,8 @@ def if_absent(self, default_value: Expr | CONSTANT_TYPE) -> "Expr": A new `Expr` representing the ifAbsent operation. """ return Function( - "if_absent", [self, self._cast_to_expr_or_convert_to_constant(default_value)] + "if_absent", + [self, self._cast_to_expr_or_convert_to_constant(default_value)], ) @expose_as_static @@ -957,7 +961,7 @@ def count_if(self) -> "Expr": @expose_as_static def count_distinct(self) -> "Expr": - """Creates an aggregation that counts the number of distinct values of the + """Creates an aggregation that counts the number of distinct values of the provided field or expression. 
Example: @@ -2047,4 +2051,4 @@ class CurrentTimestamp(Function): """ def __init__(self): - super().__init__("current_timestamp", [], use_infix_repr=False) \ No newline at end of file + super().__init__("current_timestamp", [], use_infix_repr=False) diff --git a/tests/system/test_pipeline_acceptance.py b/tests/system/test_pipeline_acceptance.py index 3b3e6189d..682fe5e23 100644 --- a/tests/system/test_pipeline_acceptance.py +++ b/tests/system/test_pipeline_acceptance.py @@ -87,7 +87,13 @@ def test_pipeline_expected_errors(test_dict, client): @pytest.mark.parametrize( "test_dict", - [t for t in yaml_loader() if "assert_results" in t or "assert_count" in t or "assert_results_approximate" in t], + [ + t + for t in yaml_loader() + if "assert_results" in t + or "assert_count" in t + or "assert_results_approximate" in t + ], ids=lambda x: f"{x.get('description', '')}", ) def test_pipeline_results(test_dict, client): @@ -105,9 +111,13 @@ def test_pipeline_results(test_dict, client): if expected_results: assert got_results == expected_results if expected_approximate_results: - assert len(got_results) == len(expected_approximate_results), "got unexpected result count" + assert len(got_results) == len( + expected_approximate_results + ), "got unexpected result count" for idx in range(len(got_results)): - assert got_results[idx] == pytest.approx(expected_approximate_results[idx], abs=1e-4) + assert got_results[idx] == pytest.approx( + expected_approximate_results[idx], abs=1e-4 + ) if expected_count is not None: assert len(got_results) == expected_count @@ -134,7 +144,13 @@ async def test_pipeline_expected_errors_async(test_dict, async_client): @pytest.mark.parametrize( "test_dict", - [t for t in yaml_loader() if "assert_results" in t or "assert_count" in t or "assert_results_approximate" in t], + [ + t + for t in yaml_loader() + if "assert_results" in t + or "assert_count" in t + or "assert_results_approximate" in t + ], ids=lambda x: f"{x.get('description', '')}", ) 
@pytest.mark.asyncio @@ -153,9 +169,13 @@ async def test_pipeline_results_async(test_dict, async_client): if expected_results: assert got_results == expected_results if expected_approximate_results: - assert len(got_results) == len(expected_approximate_results), "got unexpected result count" + assert len(got_results) == len( + expected_approximate_results + ), "got unexpected result count" for idx in range(len(got_results)): - assert got_results[idx] == pytest.approx(expected_approximate_results[idx], abs=1e-4) + assert got_results[idx] == pytest.approx( + expected_approximate_results[idx], abs=1e-4 + ) if expected_count is not None: assert len(got_results) == expected_count diff --git a/tests/unit/v1/test_pipeline_expressions.py b/tests/unit/v1/test_pipeline_expressions.py index 2c3b97259..aec721e7d 100644 --- a/tests/unit/v1/test_pipeline_expressions.py +++ b/tests/unit/v1/test_pipeline_expressions.py @@ -839,7 +839,6 @@ def test_is_absent(self): infix_instance = arg1.is_absent() assert infix_instance == instance - def test_if_absent(self): arg1 = self._make_arg("Field") arg2 = self._make_arg("ThenExpr") @@ -1579,4 +1578,4 @@ def test_maximum(self): assert instance.params == [arg1] assert repr(instance) == "Value.maximum()" infix_instance = arg1.maximum() - assert infix_instance == instance \ No newline at end of file + assert infix_instance == instance From 78cccb48477f9c7d5fc1cac9d8d9b2b4b5c4041d Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Wed, 29 Oct 2025 14:05:11 -0700 Subject: [PATCH 14/14] fixed typos --- google/cloud/firestore_v1/pipeline_expressions.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/google/cloud/firestore_v1/pipeline_expressions.py b/google/cloud/firestore_v1/pipeline_expressions.py index 93ceca265..b113e2874 100644 --- a/google/cloud/firestore_v1/pipeline_expressions.py +++ b/google/cloud/firestore_v1/pipeline_expressions.py @@ -1280,14 +1280,14 @@ def map_get(self, key: str | Constant[str]) -> "Expr": 
@expose_as_static def map_remove(self, key: str | Constant[str]) -> "Expr": - """Remove a key from a the map produced by evaluating this expression. + """Remove a key from the map produced by evaluating this expression. Example: >>> Map({"city": "London"}).map_remove("city") >>> Field.of("address").map_remove("city") Args: - key: The key to ewmove in the map. + key: The key to remove in the map. Returns: A new `Expr` representing the map_remove operation. @@ -1894,7 +1894,7 @@ class Array(Function): >>> Expr.array(["bar", Field.of("baz")]) Args: - elements: THe input list to evaluate in the expression + elements: The input list to evaluate in the expression """ def __init__(self, elements: list[Expr | CONSTANT_TYPE]): @@ -1917,7 +1917,7 @@ class Map(Function): >>> Expr.map({"foo": "bar", "baz": Field.of("baz")}) Args: - elements: THe input dict to evaluate in the expression + elements: The input dict to evaluate in the expression """ def __init__(self, elements: dict[str | Constant[str], Expr | CONSTANT_TYPE]):