diff --git a/py_cel_env.cc b/py_cel_env.cc index 871ac64..ebe0092 100644 --- a/py_cel_env.cc +++ b/py_cel_env.cc @@ -16,6 +16,7 @@ #include // IWYU pragma: keep - Needed for PyObject +#include #include #include #include @@ -45,12 +46,18 @@ void PyCelEnv::DefinePythonBindings(pybind11::module& m) { std::optional> variables, std::optional> extensions, const std::optional& container) { - PyObject* pool_ptr = nullptr; + PyObject* pool_ptr; if (descriptor_pool.is_none()) { // Replicates python's `descriptor_pool.Default()` - pool_ptr = py::module::import("google.protobuf.descriptor_pool") - .attr("Default")() - .ptr(); + try { + pool_ptr = py::module::import("google.protobuf.descriptor_pool") + .attr("Default")() + .ptr(); + } catch (const std::exception& e) { + // google.protobuf.descriptor_pool is not available. + pool_ptr = nullptr; + PyErr_Clear(); // Clear the Python error state. + } } else { pool_ptr = descriptor_pool.ptr(); } diff --git a/py_cel_test.py b/py_cel_test.py index 86a861a..36924a0 100644 --- a/py_cel_test.py +++ b/py_cel_test.py @@ -16,6 +16,9 @@ import datetime import gc +import importlib +import importlib.abc +import sys from google.protobuf import duration_pb2 as duration_pb from google.protobuf import timestamp_pb2 as timestamp_pb @@ -730,8 +733,10 @@ def testActivationAndOtherArgs(self): self.env.Activation(data={"var_str": "World!"}), data={"var_str": "World!"}, ) - self.assertIn("Cannot provide both activation and any other arguments", - str(e.exception)) + self.assertIn( + "Cannot provide both activation and any other arguments", + str(e.exception), + ) def testCompilationErrorHandling(self): # Check parser error. @@ -799,5 +804,103 @@ def FindFileContainingSymbol(self, symbol_name: str): # pylint: disable=invalid raise LookupError("Could not find file containing symbol: %s" % symbol_name) +class PyCelWithoutProtoSupportTest(absltest.TestCase): + """Test that the environment can be created without proto support.""" + + def setUp(self): + super().setUp() + self.msg = test_all_types_pb.TestAllTypes() + self.msg.single_string = "Hey" + + # "Unimport" descriptor_pool if it is already imported. + if "google.protobuf.descriptor_pool" in sys.modules: + del sys.modules["google.protobuf.descriptor_pool"] + + # Make it impossible to import descriptor_pool. + class UnluckyFinder(importlib.abc.MetaPathFinder): + + def find_spec(self, fullname, unused_path, unused_target=None): + if fullname == "google.protobuf.descriptor_pool": + raise ImportError("Not found") + return None + + sys.meta_path.insert(0, UnluckyFinder()) + + def tearDown(self): + # Remove the unlucky finder from the meta path. + sys.meta_path.pop(0) + super().tearDown() + + def testEvalWithNonProtoTypes(self): + cel_env = cel.NewEnv( + descriptor_pool=None, + variables={ + "var_str": cel.Type.STRING, + "var_map": cel.Type.Map(cel.Type.STRING, cel.Type.STRING), + "var_list": cel.Type.List(cel.Type.STRING), + }, + ) + data = { + "var_str": "foo", + "var_map": {"key": "bar"}, + "var_list": ["foo", "bar", "baz"], + } + res = cel_env.compile("var_str").eval(data=data) + self.assertEqual(res.value(), "foo") + + res = cel_env.compile("var_map['key']").eval(data=data) + self.assertEqual(res.value(), "bar") + + res = cel_env.compile("var_list[2]").eval(data=data) + self.assertEqual(res.value(), "baz") + + def testErrorOnProtoAccess(self): + cel_env = cel.NewEnv( + descriptor_pool=None, + variables={ + "var_proto": cel.Type.DYN, + }, + ) + res = cel_env.compile("var_proto.single_string").eval( + data={"var_proto": self.msg} + ) + self.assertEqual(res.type(), cel.Type.ERROR) + self.assertIn( + "Descriptor not found for message type" + " 'cel.expr.conformance.proto2.TestAllTypes'", + str(res.value()), + ) + + with self.assertRaises(Exception) as e: + cel_env.compile( + "cel.expr.conformance.proto2.TestAllTypes{single_string: 'hello'}" + ).eval() + self.assertIn( + "undeclared reference to 'cel.expr.conformance.proto2.TestAllTypes'", + str(e.exception), + ) + + def testErrorOnProtoCreation(self): + cel_env = cel.NewEnv( + descriptor_pool=None, + variables={ + "var_proto": cel.Type.DYN, + }, + ) + # Disable type checking to allow the compilation to succeed. + expr = cel_env.compile( + "cel.expr.conformance.proto2.TestAllTypes{single_string: 'hello'}", + disable_check=True, + ) + + with self.assertRaises(Exception) as e: + expr.eval() + self.assertIn( + "Invalid struct creation: missing type info for" + " 'cel.expr.conformance.proto2.TestAllTypes'", + str(e.exception), + ) + + if __name__ == "__main__": absltest.main() diff --git a/py_descriptor_database.cc b/py_descriptor_database.cc index df33ab8..8f7df2b 100644 --- a/py_descriptor_database.cc +++ b/py_descriptor_database.cc @@ -32,12 +32,12 @@ PyDescriptorDatabase::PyDescriptorDatabase(PyObject* py_descriptor_pool) : py_descriptor_pool_(py_descriptor_pool), standard_pool_(cel::GetMinimalDescriptorPool()) { ABSL_CHECK(PyGILState_Check()); - Py_INCREF(py_descriptor_pool_); + Py_XINCREF(py_descriptor_pool_); } PyDescriptorDatabase::~PyDescriptorDatabase() { auto gil_state = PyGILState_Ensure(); - Py_DECREF(py_descriptor_pool_); + Py_XDECREF(py_descriptor_pool_); PyGILState_Release(gil_state); } @@ -52,6 +52,10 @@ bool PyDescriptorDatabase::FindFileByName(StringViewArg filename, return true; } + if (py_descriptor_pool_ == nullptr) { + return false; + } + PyObject* pyfile = PyObject_CallMethod( py_descriptor_pool_, "FindFileByName", "s#", filename.data(), static_cast(filename.size())); @@ -98,6 +102,10 @@ bool PyDescriptorDatabase::FindFileContainingSymbol( return true; } + if (py_descriptor_pool_ == nullptr) { + return false; + } + PyObject* pyfile = PyObject_CallMethod( py_descriptor_pool_, "FindFileContainingSymbol", "s#", symbol_name.data(), static_cast(symbol_name.size())); @@ -137,6 +145,10 @@ bool PyDescriptorDatabase::FindFileContainingSymbol( bool PyDescriptorDatabase::FindFileContainingExtension( StringViewArg containing_type, int field_number, google::protobuf::FileDescriptorProto* output) { + if (py_descriptor_pool_ == nullptr) { + return false; + } + ABSL_CHECK(PyGILState_Check()); PyObject* py_containing_type = PyObject_CallMethod( py_descriptor_pool_, "FindMessageTypeByName", "s#", diff --git a/py_message_factory.cc b/py_message_factory.cc index c0ee5cf..b2ff6de 100644 --- a/py_message_factory.cc +++ b/py_message_factory.cc @@ -26,7 +26,7 @@ namespace cel_python { PyMessageFactory::PyMessageFactory(PyObject* descriptor_pool) { py_descriptor_pool_ = descriptor_pool; - Py_INCREF(py_descriptor_pool_); + Py_XINCREF(py_descriptor_pool_); PyObject* pName = PyUnicode_DecodeFSDefault("google.protobuf.message_factory"); PyObject* pModule = PyImport_Import(pName); @@ -45,7 +45,7 @@ PyMessageFactory::PyMessageFactory(PyObject* descriptor_pool) { PyMessageFactory::~PyMessageFactory() { auto gil_state = PyGILState_Ensure(); - Py_DECREF(py_descriptor_pool_); + Py_XDECREF(py_descriptor_pool_); Py_XDECREF(py_func_GetMessageClass_); Py_XDECREF(py_func_MergeFromString_); for (auto const& [key, py_obj] : message_classes_) { @@ -55,6 +55,13 @@ PyMessageFactory::~PyMessageFactory() { } PyObject* PyMessageFactory::GetMessageClass(const std::string& message_type) { + if (py_descriptor_pool_ == nullptr) { + PyErr_Format(PyExc_TypeError, + "Message type not found: %s, descriptor pool is unavailable.", + message_type.c_str()); + return nullptr; + } + auto it = message_classes_.find(message_type); if (it != message_classes_.end()) { return it->second;