diff --git a/BUILD b/BUILD index d9e3c8b..9a87487 100644 --- a/BUILD +++ b/BUILD @@ -12,12 +12,12 @@ exports_files(["LICENSE"]) pybind_extension( name = "py_cel", srcs = [ - "py_cel.cc", - "py_cel.h", "py_cel_activation.cc", "py_cel_activation.h", "py_cel_arena.cc", "py_cel_arena.h", + "py_cel_env.cc", + "py_cel_env.h", "py_cel_env_internal.cc", "py_cel_env_internal.h", "py_cel_expression.cc", diff --git a/README.md b/README.md index 772e82b..2b9091b 100644 --- a/README.md +++ b/README.md @@ -16,12 +16,12 @@ To create a CEL environment, you need to define variable types that can be used in expressions. ```python -cel = py_cel.Cel(variables={"x": py_cel.Type.INT, "y": py_cel.Type.INT}) +cel_env = py_cel.NewEnv(variables={"x": py_cel.Type.INT, "y": py_cel.Type.INT}) ``` #### Optional configuration parameters -The `py_cel.Cel` constructor also accepts the following optional parameters: +The `py_cel.NewEnv` constructor also accepts the following optional parameters: * `pool` (`descriptor_pool.DescriptorPool`): The descriptor pool used for resolving protobuf message types within CEL expressions. If not provided, @@ -39,7 +39,7 @@ Use the `compile()` method to compile a CEL expression string into a reusable expression object. ```python -expr = cel.compile("x + y > 10") +expr = cel_env.compile("x + y > 10") ``` The `expr` object can be serialized into a binary format for persistence and @@ -48,7 +48,7 @@ later deserialized. ```python serialized_expr = expr.serialize() # ... can be stored or sent over network ... -deserialized_expr = cel.deserialize(serialized_expr) +deserialized_expr = cel_env.deserialize(serialized_expr) ``` The `compile` method can take an optional `disable_check=True` argument, which @@ -62,7 +62,7 @@ provides bindings for variables, and then call `eval()`. ```python # Provide variable values in a dictionary. -activation = cel.Activation({"x": 7, "y": 4}) +activation = cel_env.Activation({"x": 7, "y": 4}) # Evaluate the expression. result = expr.eval(activation) @@ -94,9 +94,9 @@ garbage-collected in Python. ```python arena = py_cel.Arena() -activation1 = cel.Activation({"x": 7, "y": 4}, arena) +activation1 = cel_env.Activation({"x": 7, "y": 4}, arena) # evaluate some expressions -activation2 = cel.Activation({"x": 8, "y": 9}, arena) +activation2 = cel_env.Activation({"x": 8, "y": 9}, arena) # evaluate some more expressions # Process all results. Note: Don't put CelValues in long-lived data structures @@ -112,7 +112,7 @@ You can pass protobuf messages as variables to an activation; CEL expressions can return protobuf messages. First, ensure your proto messages are available in the descriptor pool used by -`py_cel.Cel`, by importing your proto library in Python: +`py_cel.NewEnv`, by importing your proto library in Python: from cel.expr.conformance.proto2 import test_all_types_pb2 as test_pb @@ -121,7 +121,7 @@ qualified name. ```python # Declare 'msg_var' as a message type. -cel = py_cel.Cel( +cel = py_cel.NewEnv( pool, variables={ "msg_var": py_cel.Type("cel.expr.conformance.proto2.TestAllTypes"), @@ -141,7 +141,7 @@ an instance of the Python proto message class. ```python my_msg = test_pb.TestAllTypes(single_int32=42) -activation = cel.Activation({"msg_var": my_msg}) +activation = cel_env.Activation({"msg_var": my_msg}) result = expr.eval(activation) print(f"Result: {result.value()}") ``` @@ -149,7 +149,7 @@ print(f"Result: {result.value()}") An expression can also return a proto message: ```python -msg_expr = cel.compile( +msg_expr = cel_env.compile( "cel.expr.conformance.proto2.TestAllTypes{single_int32: 123}" ) msg_result = msg_expr.eval(activation) @@ -174,8 +174,8 @@ Standard extensions are available under `py_cel.ext`. ```python from py_cel.ext import ext_math -cel = py_cel.Cel(pool, extensions=[ext_math.ExtMath()]) -expr = cel.compile("math.sqrt(4)") +cel = py_cel.NewEnv(pool, extensions=[ext_math.ExtMath()]) +expr = cel_env.compile("math.sqrt(4)") ``` #### Defining a custom extension in Python @@ -203,8 +203,8 @@ my_ext = py_cel.CelExtension( ], ) -cel = py_cel.Cel(pool, extensions=[my_ext]) -expr = cel.compile("my_func(1)") +cel_env = py_cel.NewEnv(pool, extensions=[my_ext]) +expr = cel_env.compile("my_func(1)") ``` #### Defining a custom extension in C++ @@ -304,10 +304,10 @@ Now you can use the extension in PyCel: ```python import translation_cel_ext -cel = py_cel.Cel(variables={}, +cel_env = py_cel.NewEnv(variables={}, extensions=[translation_cel_ext.TranslationCelExtension()]) -expr = cel.compile("'Hello, world!'.translate('en', 'es')") +expr = cel_env.compile("'Hello, world!'.translate('en', 'es')") ``` #### Late-bound extension functions @@ -357,11 +357,11 @@ If the extension is written in C++, use the `RegisterLazyFunction` function: Now you can bind the function at runtime: ```python -cel = py_cel.Cel(variables={}, extensions=[my_ext]) -expr = cel.compile("my_func(42)") +cel_env = py_cel.NewEnv(variables={}, extensions=[my_ext]) +expr = cel_env.compile("my_func(42)") multiplier = 2 -act = cel.Activation({}, functions={"my_func": lambda x: x * multiplier}) +act = cel_env.Activation({}, functions={"my_func": lambda x: x * multiplier}) res = expr.eval(act) # res.value() == 84 ``` diff --git a/conformance/conformance_test.py b/conformance/conformance_test.py index dfff8cc..084391f 100644 --- a/conformance/conformance_test.py +++ b/conformance/conformance_test.py @@ -158,14 +158,14 @@ def _run_conformance_test(self, simple_test: simple_pb.SimpleTest): break self.descriptor_pool = descriptor_pool.Default() - self.cel = cel.Cel( + self.env = cel.NewEnv( self.descriptor_pool, variables=decls, extensions=extensions, container=simple_test.container, ) try: - compiled_expr = self.cel.compile( + compiled_expr = self.env.compile( simple_test.expr, disable_check=simple_test.disable_check ) except Exception as e: # pylint: disable=broad-except @@ -188,7 +188,7 @@ def _run_conformance_test(self, simple_test: simple_pb.SimpleTest): for key, value in simple_test.bindings.items(): values[key] = self._convert_value(value.value) - act = self.cel.Activation(values) + act = self.env.Activation(values) try: res = compiled_expr.eval(act) except Exception as e: # pylint: disable=broad-except diff --git a/custom_ext/custom_ext_test.py b/custom_ext/custom_ext_test.py index 32f2002..2cd340a 100644 --- a/custom_ext/custom_ext_test.py +++ b/custom_ext/custom_ext_test.py @@ -40,16 +40,16 @@ def _compile_expr( ) -> cel.Expression: """Creates a CEL expression for the given extension and compiles the expression.""" self.descriptor_pool = descriptor_pool.Default() - self.cel = cel.Cel( + self.env = cel.NewEnv( self.descriptor_pool, variables={}, extensions=[ext()], ) - return self.cel.compile(expression) + return self.env.compile(expression) def _create_activation(self, impl) -> cel.Activation: """Creates a CEL Activation with a late-bound translate function.""" - return self.cel.Activation( + return self.env.Activation( {}, functions=[ cel.Function( @@ -66,7 +66,7 @@ def test_basic_function(self, ext): compiled_expr = self._compile_expr( ext, "'Hello, world!'.translate('en', 'es')" ) - act = self.cel.Activation({}) + act = self.env.Activation({}) res = compiled_expr.eval(act) self.assertEqual(res.value(), "¡Hola Mundo!") @@ -85,7 +85,7 @@ def test_late_bound_function(self, ext): @parameterized.named_parameters(EXT_IMPLEMENTATIONS) def test_error_no_matching_overload(self, ext): compiled_expr = self._compile_expr(ext, "translate_late('Hello, world!')") - act = self.cel.Activation( + act = self.env.Activation( {}, functions=[ cel.Function( diff --git a/py_cel.cc b/py_cel_env.cc similarity index 55% rename from py_cel.cc rename to py_cel_env.cc index 83f8807..cd99ad6 100644 --- a/py_cel.cc +++ b/py_cel_env.cc @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "py_cel.h" +#include "py_cel_env.h" #include // IWYU pragma: keep - Needed for PyObject @@ -38,48 +38,47 @@ namespace cel_python { namespace py = ::pybind11; -void PyCel::DefinePythonBindings(pybind11::module& m) { - py::class_> cel_class(m, "Cel"); - cel_class - .def(py::init([](py::object descriptor_pool, - std::optional> - variables, - std::optional> extensions, - const std::optional& container) { - PyObject* pool_ptr = nullptr; - if (descriptor_pool.is_none()) { - // Replicates python's `descriptor_pool.Default()` - pool_ptr = py::module::import("google.protobuf.descriptor_pool") - .attr("Default")() - .ptr(); - } else { - pool_ptr = descriptor_pool.ptr(); - } +void PyCelEnv::DefinePythonBindings(pybind11::module& m) { + py::class_> cel_class(m, "Env"); + m.def( + "NewEnv", + [](py::object descriptor_pool, + std::optional> variables, + std::optional> extensions, + const std::optional& container) { + PyObject* pool_ptr = nullptr; + if (descriptor_pool.is_none()) { + // Replicates python's `descriptor_pool.Default()` + pool_ptr = py::module::import("google.protobuf.descriptor_pool") + .attr("Default")() + .ptr(); + } else { + pool_ptr = descriptor_pool.ptr(); + } - std::vector ext_ptrs; - if (extensions) { - ext_ptrs.reserve(extensions->size()); - for (const auto& ext : *extensions) { - ext_ptrs.push_back(ext.ptr()); - } - } + std::vector ext_ptrs; + if (extensions) { + ext_ptrs.reserve(extensions->size()); + for (const auto& ext : *extensions) { + ext_ptrs.push_back(ext.ptr()); + } + } - return std::make_shared( - pool_ptr, - std::move(variables).value_or( - std::unordered_map{}), - ext_ptrs, container.value_or("")); - }), - py::arg("descriptor_pool") = py::none(), - py::arg("variables") = py::none(), - py::arg("extensions") = py::none(), - py::arg("container") = py::none()) - .def("compile", &PyCel::Compile, py::arg("expression"), + return PyCelEnv(pool_ptr, + std::move(variables).value_or( + std::unordered_map{}), + ext_ptrs, container.value_or("")); + }, + py::arg("descriptor_pool") = py::none(), + py::arg("variables") = py::none(), py::arg("extensions") = py::none(), + py::arg("container") = py::none()); + cel_class + .def("compile", &PyCelEnv::Compile, py::arg("expression"), py::arg("disable_check") = false) - .def("deserialize", &PyCel::Deserialize, py::arg("serialized")) + .def("deserialize", &PyCelEnv::Deserialize, py::arg("serialized")) .def( "Activation", - [](PyCel& self, + [](PyCelEnv& self, std::optional> data, const std::optional>>& functions, @@ -103,30 +102,31 @@ void PyCel::DefinePythonBindings(pybind11::module& m) { py::arg("arena") = nullptr); } -PyCel::PyCel(PyObject* descriptor_pool, - std::unordered_map variable_types, - const std::vector& extensions, std::string container) +PyCelEnv::PyCelEnv(PyObject* descriptor_pool, + std::unordered_map variable_types, + const std::vector& extensions, + std::string container) : env_(std::make_unique( descriptor_pool, std::move(variable_types), extensions, std::move(container))) { ABSL_CHECK(PyGILState_Check()); } -PyCel::~PyCel() = default; +PyCelEnv::~PyCelEnv() = default; -std::shared_ptr PyCel::NewActivation( +std::shared_ptr PyCelEnv::NewActivation( const std::unordered_map& data, const std::vector>& functions, const std::shared_ptr& arena) { return std::make_shared(env_, data, functions, arena); } -absl::StatusOr PyCel::Compile(const std::string& cel_expr, - bool disable_check) { +absl::StatusOr PyCelEnv::Compile(const std::string& cel_expr, + bool disable_check) { return PyCelExpression::Compile(env_, cel_expr, disable_check); } -absl::StatusOr PyCel::Deserialize( +absl::StatusOr PyCelEnv::Deserialize( const std::string& serialized_expr) { return PyCelExpression::Deserialize(env_, serialized_expr); } diff --git a/py_cel.h b/py_cel_env.h similarity index 80% rename from py_cel.h rename to py_cel_env.h index 74d2a35..7ba0444 100644 --- a/py_cel.h +++ b/py_cel_env.h @@ -13,8 +13,8 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#ifndef THIRD_PARTY_CEL_PYTHON_PY_CEL_H_ -#define THIRD_PARTY_CEL_PYTHON_PY_CEL_H_ +#ifndef THIRD_PARTY_CEL_PYTHON_PY_CEL_ENV_H_ +#define THIRD_PARTY_CEL_PYTHON_PY_CEL_ENV_H_ #include // IWYU pragma: keep - Needed for PyObject @@ -41,21 +41,16 @@ class RuntimeOptions; // All classes and functions in this namespace are pybind11-wrapped. namespace cel_python { -class PyCel; -class PyCelEnv; +class PyCelEnvInternal; class PyCelFunction; class PyMessageFactory; // CEL environment. Provides access to the CEL compiler. -class PyCel { +class PyCelEnv { public: static void DefinePythonBindings(pybind11::module& m); - explicit PyCel(PyObject* descriptor_pool = nullptr, - std::unordered_map variable_types = {}, - const std::vector& extensions = {}, - std::string container = ""); - ~PyCel(); + ~PyCelEnv(); absl::StatusOr Compile(const std::string& cel_expr, bool disable_check = false); @@ -69,9 +64,13 @@ class PyCel { std::shared_ptr GetEnv() { return env_; } private: + // Private constructor. Use `py_cel.NewEnv()` in python to obtain an instance. + PyCelEnv(PyObject* descriptor_pool, + std::unordered_map variable_types, + const std::vector& extensions, std::string container); std::shared_ptr env_; }; } // namespace cel_python -#endif // THIRD_PARTY_CEL_PYTHON_PY_CEL_H_ +#endif // THIRD_PARTY_CEL_PYTHON_PY_CEL_ENV_H_ diff --git a/py_cel_module.cc b/py_cel_module.cc index 070c881..776ca98 100644 --- a/py_cel_module.cc +++ b/py_cel_module.cc @@ -12,9 +12,9 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "py_cel.h" #include "py_cel_activation.h" #include "py_cel_arena.h" +#include "py_cel_env.h" #include "py_cel_expression.h" #include "py_cel_function.h" #include "py_cel_function_decl.h" @@ -40,7 +40,7 @@ PYBIND11_MODULE(py_cel, m) { PyCelFunctionDecl::DefinePythonBindings(m); PyCelPythonExtension::DefinePythonBindings(m); PyCelFunction::DefinePythonBindings(m); - PyCel::DefinePythonBindings(m); + PyCelEnv::DefinePythonBindings(m); } } // namespace cel_python diff --git a/py_cel_test.py b/py_cel_test.py index da8b8f5..b164abf 100644 --- a/py_cel_test.py +++ b/py_cel_test.py @@ -29,7 +29,7 @@ class PyCelTest(absltest.TestCase): def setUp(self): super().setUp() - self.cel = cel.Cel( + self.env = cel.NewEnv( variables={ "var_bool": cel.Type.BOOL, "var_int": cel.Type.INT, @@ -86,12 +86,12 @@ def _eval( ): if data is None: data = {} - expr = self.cel.compile(expression) + expr = self.env.compile(expression) if expected_return_type is not None: self.assertTrue( expected_return_type.is_assignable_from(expr.return_type()) ) - act = self.cel.Activation(data) + act = self.env.Activation(data) return expr.eval(act) def testUnsetVar(self): @@ -626,7 +626,7 @@ def testTypeType(self): # If the expected type is parameterized, is_assignable_from takes # into account the type's parameterization. - expr = self.cel.compile("type(1)") + expr = self.env.compile("type(1)") self.assertFalse( cel.Type.Type(cel.Type.STRING).is_assignable_from(expr.return_type()) ) @@ -652,38 +652,38 @@ def testTypeType(self): ) # This behavior is counterintuitive but works as implemented. def testCelExpressionPersistence_checkedExpr(self): - expr = self.cel.compile("var_msg.single_string") + expr = self.env.compile("var_msg.single_string") as_bytes = expr.serialize() - expr = self.cel.deserialize(as_bytes) + expr = self.env.deserialize(as_bytes) msg = test_all_types_pb.TestAllTypes() msg.single_string = "Hey!" - res = expr.eval(self.cel.Activation({"var_msg": msg})) + res = expr.eval(self.env.Activation({"var_msg": msg})) self.assertEqual(res.value(), "Hey!") def testCelExpressionPersistence_uncheckedExpr(self): - expr = self.cel.compile("runtimely + 2", disable_check=True) + expr = self.env.compile("runtimely + 2", disable_check=True) as_bytes = expr.serialize() - expr = self.cel.deserialize(as_bytes) + expr = self.env.deserialize(as_bytes) - res = expr.eval(self.cel.Activation({"runtimely": 40})) + res = expr.eval(self.env.Activation({"runtimely": 40})) self.assertEqual(res.value(), 42) def testCelExpressionPersistence_badSerializedFormat(self): with self.assertRaises(Exception) as e: - self.cel.deserialize("b'foo'") + self.env.deserialize("b'foo'") self.assertIn("Cannot parse serialized CEL expression", str(e.exception)) def testCheckedCelExpression_raises(self): with self.assertRaises(Exception) as e: - self.cel.compile("runtimely + 2", disable_check=False) + self.env.compile("runtimely + 2", disable_check=False) self.assertIn("undeclared reference to 'runtimely'", str(e.exception)) def testUncheckedCelExpression(self): - expr = self.cel.compile("runtimely + 2", disable_check=True) - res = expr.eval(self.cel.Activation({"runtimely": 40})) + expr = self.env.compile("runtimely + 2", disable_check=True) + res = expr.eval(self.env.Activation({"runtimely": 40})) self.assertEqual(res.value(), 42) def testActivationWithArena(self): @@ -694,12 +694,12 @@ def testActivationWithArena(self): msg = test_all_types_pb.TestAllTypes() msg.single_string = "Hey" - expr = self.cel.compile( + expr = self.env.compile( "cel.expr.conformance.proto2.TestAllTypes{" "single_string: var_msg.single_string}" ) - res = expr.eval(self.cel.Activation({"var_msg": msg}, arena=arena)) + res = expr.eval(self.env.Activation({"var_msg": msg}, arena=arena)) # Clear out reference to `arena` to test garbage collection. arena = None # pylint: disable=unused-variable @@ -721,7 +721,7 @@ def testActivationWithArena(self): def testCompilationErrorHandling(self): # Check parser error. with self.assertRaises(Exception) as e: - self.cel.compile("'Hello,' # 'World!'", disable_check=True) + self.env.compile("'Hello,' # 'World!'", disable_check=True) self.assertIn( "1:10: Syntax error: token recognition error at: '#'\n " "| 'Hello,' # 'World!'\n " @@ -737,7 +737,7 @@ def testCompilationErrorHandling(self): # Check type-checker error. with self.assertRaises(Exception) as e: - self.cel.compile("'Hello,' - 'World!'") + self.env.compile("'Hello,' - 'World!'") self.assertIn( ":1:10: found no matching overload for '_-_' applied to" " '(string, string)'\n " @@ -747,9 +747,9 @@ def testCompilationErrorHandling(self): ) def testErrorHandling(self): - bad_cel = cel.Cel(_BadDescriptorPool(), variables={}) + bad_env = cel.NewEnv(_BadDescriptorPool(), variables={}) with self.assertRaises(Exception) as e: - bad_cel.compile("cel.expr.conformance.proto2.TestSomeTypes{}") + bad_env.compile("cel.expr.conformance.proto2.TestSomeTypes{}") self.assertRegex( str(e.exception), r"Could not find file containing symbol:.* \[NOT_FOUND\]", diff --git a/py_cel_value.cc b/py_cel_value.cc index 2861f6f..9d3f378 100644 --- a/py_cel_value.cc +++ b/py_cel_value.cc @@ -36,7 +36,7 @@ #include "common/type.h" #include "common/value.h" #include "common/value_kind.h" -#include "py_cel.h" +#include "py_cel_arena.h" #include "py_cel_env_internal.h" #include "py_cel_type.h" #include "py_cel_value_provider.h"