Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 11 additions & 4 deletions py_cel_env.cc
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

#include <Python.h> // IWYU pragma: keep - Needed for PyObject

#include <exception>
#include <memory>
#include <optional>
#include <string>
Expand Down Expand Up @@ -45,12 +46,18 @@ void PyCelEnv::DefinePythonBindings(pybind11::module& m) {
std::optional<std::unordered_map<std::string, PyCelType>> variables,
std::optional<std::vector<py::object>> extensions,
const std::optional<std::string>& container) {
PyObject* pool_ptr = nullptr;
PyObject* pool_ptr;
if (descriptor_pool.is_none()) {
// Replicates python's `descriptor_pool.Default()`
pool_ptr = py::module::import("google.protobuf.descriptor_pool")
.attr("Default")()
.ptr();
try {
pool_ptr = py::module::import("google.protobuf.descriptor_pool")
.attr("Default")()
.ptr();
} catch (const std::exception& e) {
// google.protobuf.descriptor_pool is not available.
pool_ptr = nullptr;
PyErr_Clear(); // Clear the Python error state.
}
} else {
pool_ptr = descriptor_pool.ptr();
}
Expand Down
107 changes: 105 additions & 2 deletions py_cel_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,9 @@

import datetime
import gc
import importlib
import importlib.abc
import sys

from google.protobuf import duration_pb2 as duration_pb
from google.protobuf import timestamp_pb2 as timestamp_pb
Expand Down Expand Up @@ -730,8 +733,10 @@ def testActivationAndOtherArgs(self):
self.env.Activation(data={"var_str": "World!"}),
data={"var_str": "World!"},
)
self.assertIn("Cannot provide both activation and any other arguments",
str(e.exception))
self.assertIn(
"Cannot provide both activation and any other arguments",
str(e.exception),
)

def testCompilationErrorHandling(self):
# Check parser error.
Expand Down Expand Up @@ -799,5 +804,103 @@ def FindFileContainingSymbol(self, symbol_name: str): # pylint: disable=invalid
raise LookupError("Could not find file containing symbol: %s" % symbol_name)


class PyCelWithoutProtoSupportTest(absltest.TestCase):
"""Test that the environment can be created without proto support."""

def setUp(self):
super().setUp()
self.msg = test_all_types_pb.TestAllTypes()
self.msg.single_string = "Hey"

# "Unimport" descriptor_pool if it is already imported.
if "google.protobuf.descriptor_pool" in sys.modules:
del sys.modules["google.protobuf.descriptor_pool"]

# Make it impossible to import descriptor_pool.
class UnluckyFinder(importlib.abc.MetaPathFinder):

def find_spec(self, fullname, unused_path, unused_target=None):
if fullname == "google.protobuf.descriptor_pool":
raise ImportError("Not found")
return None

sys.meta_path.insert(0, UnluckyFinder())

def tearDown(self):
# Remove the unlucky finder from the meta path.
sys.meta_path.pop(0)
super().tearDown()

def testEvalWithNonProtoTypes(self):
cel_env = cel.NewEnv(
descriptor_pool=None,
variables={
"var_str": cel.Type.STRING,
"var_map": cel.Type.Map(cel.Type.STRING, cel.Type.STRING),
"var_list": cel.Type.List(cel.Type.STRING),
},
)
data = {
"var_str": "foo",
"var_map": {"key": "bar"},
"var_list": ["foo", "bar", "baz"],
}
res = cel_env.compile("var_str").eval(data=data)
self.assertEqual(res.value(), "foo")

res = cel_env.compile("var_map['key']").eval(data=data)
self.assertEqual(res.value(), "bar")

res = cel_env.compile("var_list[2]").eval(data=data)
self.assertEqual(res.value(), "baz")

def testErrorOnProtoAccess(self):
cel_env = cel.NewEnv(
descriptor_pool=None,
variables={
"var_proto": cel.Type.DYN,
},
)
res = cel_env.compile("var_proto.single_string").eval(
data={"var_proto": self.msg}
)
self.assertEqual(res.type(), cel.Type.ERROR)
self.assertIn(
"Descriptor not found for message type"
" 'cel.expr.conformance.proto2.TestAllTypes'",
str(res.value()),
)

with self.assertRaises(Exception) as e:
cel_env.compile(
"cel.expr.conformance.proto2.TestAllTypes{single_string: 'hello'}"
).eval()
self.assertIn(
"undeclared reference to 'cel.expr.conformance.proto2.TestAllTypes'",
str(e.exception),
)

def testErrorOnProtoCreation(self):
cel_env = cel.NewEnv(
descriptor_pool=None,
variables={
"var_proto": cel.Type.DYN,
},
)
# Disable type checking to allow the compilation to succeed.
expr = cel_env.compile(
"cel.expr.conformance.proto2.TestAllTypes{single_string: 'hello'}",
disable_check=True,
)

with self.assertRaises(Exception) as e:
expr.eval()
self.assertIn(
"Invalid struct creation: missing type info for"
" 'cel.expr.conformance.proto2.TestAllTypes'",
str(e.exception),
)


if __name__ == "__main__":
absltest.main()
16 changes: 14 additions & 2 deletions py_descriptor_database.cc
Original file line number Diff line number Diff line change
Expand Up @@ -32,12 +32,12 @@ PyDescriptorDatabase::PyDescriptorDatabase(PyObject* py_descriptor_pool)
: py_descriptor_pool_(py_descriptor_pool),
standard_pool_(cel::GetMinimalDescriptorPool()) {
ABSL_CHECK(PyGILState_Check());
Py_INCREF(py_descriptor_pool_);
Py_XINCREF(py_descriptor_pool_);
}

PyDescriptorDatabase::~PyDescriptorDatabase() {
auto gil_state = PyGILState_Ensure();
Py_DECREF(py_descriptor_pool_);
Py_XDECREF(py_descriptor_pool_);
PyGILState_Release(gil_state);
}

Expand All @@ -52,6 +52,10 @@ bool PyDescriptorDatabase::FindFileByName(StringViewArg filename,
return true;
}

if (py_descriptor_pool_ == nullptr) {
return false;
}

PyObject* pyfile = PyObject_CallMethod(
py_descriptor_pool_, "FindFileByName", "s#", filename.data(),
static_cast<Py_ssize_t>(filename.size()));
Expand Down Expand Up @@ -98,6 +102,10 @@ bool PyDescriptorDatabase::FindFileContainingSymbol(
return true;
}

if (py_descriptor_pool_ == nullptr) {
return false;
}

PyObject* pyfile = PyObject_CallMethod(
py_descriptor_pool_, "FindFileContainingSymbol", "s#", symbol_name.data(),
static_cast<Py_ssize_t>(symbol_name.size()));
Expand Down Expand Up @@ -137,6 +145,10 @@ bool PyDescriptorDatabase::FindFileContainingSymbol(
bool PyDescriptorDatabase::FindFileContainingExtension(
StringViewArg containing_type, int field_number,
google::protobuf::FileDescriptorProto* output) {
if (py_descriptor_pool_ == nullptr) {
return false;
}

ABSL_CHECK(PyGILState_Check());
PyObject* py_containing_type = PyObject_CallMethod(
py_descriptor_pool_, "FindMessageTypeByName", "s#",
Expand Down
11 changes: 9 additions & 2 deletions py_message_factory.cc
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ namespace cel_python {

PyMessageFactory::PyMessageFactory(PyObject* descriptor_pool) {
py_descriptor_pool_ = descriptor_pool;
Py_INCREF(py_descriptor_pool_);
Py_XINCREF(py_descriptor_pool_);
PyObject* pName =
PyUnicode_DecodeFSDefault("google.protobuf.message_factory");
PyObject* pModule = PyImport_Import(pName);
Expand All @@ -45,7 +45,7 @@ PyMessageFactory::PyMessageFactory(PyObject* descriptor_pool) {

PyMessageFactory::~PyMessageFactory() {
auto gil_state = PyGILState_Ensure();
Py_DECREF(py_descriptor_pool_);
Py_XDECREF(py_descriptor_pool_);
Py_XDECREF(py_func_GetMessageClass_);
Py_XDECREF(py_func_MergeFromString_);
for (auto const& [key, py_obj] : message_classes_) {
Expand All @@ -55,6 +55,13 @@ PyMessageFactory::~PyMessageFactory() {
}

PyObject* PyMessageFactory::GetMessageClass(const std::string& message_type) {
if (py_descriptor_pool_ == nullptr) {
PyErr_Format(PyExc_TypeError,
"Message type not found: %s, descriptor pool is unavailable.",
message_type.c_str());
return nullptr;
}

auto it = message_classes_.find(message_type);
if (it != message_classes_.end()) {
return it->second;
Expand Down