Skip to content

Commit f728bd6

Browse files
fix(firestore): Imropve improper pipeline aliases (#16651)
Currently, AliasedExpressions are treated like regular expressions. You can execute additional expressions off of them (`a.as_("number").add(5)`), or chain them (`a.as_("first").as_("second")`). But the backend doesn't actually support aliases being used in this way This PR raises an exception if an alias is used in a context it doesn't support Go version: googleapis/google-cloud-go#14440
1 parent cde1c0f commit f728bd6

4 files changed

Lines changed: 86 additions & 7 deletions

File tree

packages/google-cloud-firestore/google/cloud/firestore_v1/pipeline_expressions.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,11 @@ class Expression(ABC):
7272
together method calls to create complex expressions.
7373
"""
7474

75+
# Controls whether expression methods (e.g., .add(), .multiply()) can be called on
76+
# instances of this class or its subclasses. Set to False for non-computational
77+
# expressions like AliasedExpression.
78+
_supports_expr_methods = True
79+
7580
def __repr__(self):
7681
return f"{self.__class__.__name__}()"
7782

@@ -113,6 +118,10 @@ def __init__(self, instance_func):
113118
self.instance_func = instance_func
114119

115120
def static_func(self, first_arg, *other_args, **kwargs):
121+
if getattr(first_arg, "_supports_expr_methods", True) is False:
122+
raise TypeError(
123+
f"Cannot call '{self.instance_func.__name__}' on {type(first_arg).__name__}."
124+
)
116125
if not isinstance(first_arg, (Expression, str)):
117126
raise TypeError(
118127
f"'{self.instance_func.__name__}' must be called on an Expression or a string representing a field. got {type(first_arg)}."
@@ -128,6 +137,10 @@ def __get__(self, instance, owner):
128137
if instance is None:
129138
return self.static_func
130139
else:
140+
if getattr(instance, "_supports_expr_methods", True) is False:
141+
raise TypeError(
142+
f"Cannot call '{self.instance_func.__name__}' on {type(instance).__name__}."
143+
)
131144
return self.instance_func.__get__(instance, owner)
132145

133146
@expose_as_static
@@ -2715,10 +2728,21 @@ def _to_value(field_list: Sequence[Selectable]) -> Value:
27152728
class AliasedExpression(Selectable, Generic[T]):
27162729
"""Wraps an expression with an alias."""
27172730

2731+
_supports_expr_methods = False
2732+
27182733
def __init__(self, expr: T, alias: str):
2734+
if isinstance(expr, AliasedExpression):
2735+
raise TypeError(
2736+
"Cannot wrap an AliasedExpression with another alias. An alias can only be applied once."
2737+
)
27192738
self.expr = expr
27202739
self.alias = alias
27212740

2741+
def as_(self, alias: str) -> "AliasedExpression":
2742+
raise TypeError(
2743+
"Cannot call as_() on an AliasedExpression. An alias can only be applied once."
2744+
)
2745+
27222746
def _to_map(self):
27232747
return self.alias, self.expr._to_pb()
27242748

packages/google-cloud-firestore/tests/system/pipeline_e2e/general.yaml

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -896,3 +896,42 @@ tests:
896896
- Pipeline:
897897
- Subcollection: reviews
898898
assert_error: ".*start of a nested pipeline.*"
899+
- description: cannot_call_expression_methods_on_aliased_expression
900+
pipeline:
901+
- Collection: books
902+
- Select:
903+
- FunctionExpression.add:
904+
- AliasedExpression:
905+
- Field: pages
906+
- pages_alias
907+
- 5
908+
assert_error: "Cannot call 'add' on AliasedExpression"
909+
- description: cannot_chain_aliases
910+
pipeline:
911+
- Collection: books
912+
- Select:
913+
- AliasedExpression:
914+
- AliasedExpression:
915+
- Field: pages
916+
- pages_alias
917+
- final_alias
918+
assert_error: "Cannot wrap an AliasedExpression"
919+
- description: valid_aliased_expression_proto
920+
pipeline:
921+
- Collection: books
922+
- Select:
923+
- AliasedExpression:
924+
- Field: pages
925+
- pages_alias
926+
assert_proto:
927+
pipeline:
928+
stages:
929+
- args:
930+
- referenceValue: /books
931+
name: collection
932+
- args:
933+
- mapValue:
934+
fields:
935+
pages_alias:
936+
fieldReferenceValue: pages
937+
name: select

packages/google-cloud-firestore/tests/system/test_pipeline_acceptance.py

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,6 @@
2424

2525
import pytest
2626
import yaml
27-
from google.api_core.exceptions import GoogleAPIError
2827
from google.protobuf.json_format import MessageToDict
2928
from test__helpers import FIRESTORE_EMULATOR, FIRESTORE_ENTERPRISE_DB, system_test_lock
3029

@@ -124,9 +123,9 @@ def test_pipeline_expected_errors(test_dict, client):
124123
Finds assert_error statements in yaml, and ensures the pipeline raises the expected error
125124
"""
126125
error_regex = test_dict["assert_error"]
127-
pipeline = parse_pipeline(client, test_dict["pipeline"])
128-
# check if server responds as expected
129-
with pytest.raises(GoogleAPIError) as err:
126+
127+
with pytest.raises(Exception) as err:
128+
pipeline = parse_pipeline(client, test_dict["pipeline"])
130129
pipeline.execute()
131130
found_error = str(err.value)
132131
match = re.search(error_regex, found_error)
@@ -215,9 +214,8 @@ async def test_pipeline_expected_errors_async(test_dict, async_client):
215214
Finds assert_error statements in yaml, and ensures the pipeline raises the expected error
216215
"""
217216
error_regex = test_dict["assert_error"]
218-
pipeline = parse_pipeline(async_client, test_dict["pipeline"])
219-
# check if server responds as expected
220-
with pytest.raises(GoogleAPIError) as err:
217+
with pytest.raises(Exception) as err:
218+
pipeline = parse_pipeline(async_client, test_dict["pipeline"])
221219
await pipeline.execute()
222220
found_error = str(err.value)
223221
match = re.search(error_regex, found_error)

packages/google-cloud-firestore/tests/unit/v1/test_pipeline_expressions.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -260,6 +260,24 @@ def test_to_map(self):
260260
assert result[0] == "alias1"
261261
assert result[1] == Value(field_reference_value="field1")
262262

263+
def test_chaining_aliases(self):
264+
with pytest.raises(
265+
TypeError, match="Cannot call as_\\(\\) on an AliasedExpression"
266+
):
267+
Field.of("field1").as_("alias1").as_("alias2")
268+
269+
def test_expr_method_on_aliased_raises_error(self):
270+
with pytest.raises(
271+
TypeError, match="Cannot call 'add' on AliasedExpression"
272+
):
273+
Field.of("field1").as_("alias1").add(5)
274+
275+
def test_static_expr_method_on_aliased_raises_error(self):
276+
with pytest.raises(
277+
TypeError, match="Cannot call 'add' on AliasedExpression"
278+
):
279+
expr.Expression.add(Field.of("field1").as_("alias1"), 5)
280+
263281

264282
class TestBooleanExpression:
265283
def test__from_query_filter_pb_composite_filter_or(self, mock_client):

0 commit comments

Comments
 (0)