Skip to content

Commit c0ac93b

Browse files
timsaucerclaude
andauthored
feat: expose arrow_field, arrow_try_cast, cast_to_type, with_metadata (#1568)
* feat: expose arrow_field, arrow_try_cast, cast_to_type, with_metadata Adds Python bindings for five scalar functions from datafusion::functions::expr_fn that were not previously surfaced: - arrow_field: returns a struct describing an expression's Arrow field (name, data_type, nullable, metadata). - arrow_try_cast: like arrow_cast but yields NULL on cast failure. - cast_to_type / try_cast_to_type: casts a value to the type of a reference expression. These are exposed as a single Python entry point cast_to_type(value, type_ref, *, try_cast=False); the kwarg switches between the strict and try variants. - with_metadata: attach Arrow field metadata; the inverse of arrow_metadata. Accepts a dict[str, str] for ergonomics. Updates skills/datafusion_python/SKILL.md to list the new functions and documents the cast_to_type kwarg behavior. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com> * refactor: collapse try_cast_to_type into cast_to_type kwarg The previous commit exposed cast_to_type and try_cast_to_type as two separate pyo3 bindings and unified them in the Python wrapper via a try_cast kwarg. That left try_cast_to_type in datafusion._internal without a matching public Python name, breaking test_datafusion_missing_exports. Move the dispatch into the rust binding: cast_to_type now takes a try_cast kwarg and selects between functions::expr_fn::cast_to_type and try_cast_to_type internally. Only one pyo3 binding is registered, so the wrapper-coverage check passes and the Python entrypoint is unchanged. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com> * feat: accept pyarrow DataType in arrow_try_cast Mirrors arrow_cast: arrow_try_cast now accepts `pa.DataType` in addition to `str` and `Expr`. Adds `Expr.try_cast(pa.DataType)` PyO3 binding for the pyarrow-type routing path. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com> * fix: guard with_metadata against empty dict and empty keys Empty `metadata` dict now returns the input expression unchanged (previously bubbled an opaque DataFusion error about minimum arg count). Empty keys raise `ValueError` to match the docstring contract. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com> * docs: assert full struct shape in arrow_field doctest Previous doctest set metadata on the input field but only checked the name — the metadata setup was dead. Now the example asserts the full returned struct (name, data_type, nullable, metadata) so the demo shows what the function actually produces. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com> * test: add unit tests for arrow_try_cast, arrow_field, cast_to_type, with_metadata Mirrors the existing test_arrow_cast pattern. Covers: - arrow_try_cast: string-syntax, pa.DataType, and null-on-failure paths - arrow_field: full returned struct shape (name, data_type, nullable, metadata) - cast_to_type: type-from-expr happy path and try_cast=True null behavior - with_metadata: round-trip through arrow_metadata, empty-dict no-op, and empty-key ValueError Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com> * test: parameterize arrow cast / try_cast tests Folds the previous four cast tests (arrow_cast + arrow_try_cast × str + pyarrow target type) into a single parameterized test that runs both functions across all five target-type variants. Collapses the two cast_to_type tests (happy path + try_cast=True) into one parameterized test, and parameterizes arrow_try_cast null-on-failure over both target-type syntaxes. 7 test functions, 19 cases — net less code, same coverage. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com> * docs: point cast_to_type at arrow_cast for static target types Adds a one-line cross-reference so users with a known target type reach for arrow_cast / arrow_try_cast instead of building a sentinel expression to feed cast_to_type. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com> * refactor: split cast_to_type into cast_to_type and try_cast_to_type Replace the try_cast bool flag with separate cast_to_type and try_cast_to_type functions, matching upstream DataFusion and the arrow_cast / arrow_try_cast pair. Also drop the redundant data_type parametrization on test_arrow_try_cast_null_on_failure, since the str-vs-pyarrow distinction is already covered by test_arrow_cast_variants. Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com> --------- Co-authored-by: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
1 parent fd15b03 commit c0ac93b

6 files changed

Lines changed: 265 additions & 19 deletions

File tree

crates/core/src/expr.rs

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -358,6 +358,11 @@ impl PyExpr {
358358
expr.into()
359359
}
360360

361+
pub fn try_cast(&self, to: PyArrowType<DataType>) -> PyExpr {
362+
let expr = Expr::TryCast(TryCast::new(Box::new(self.expr.clone()), to.0));
363+
expr.into()
364+
}
365+
361366
#[pyo3(signature = (low, high, negated=false))]
362367
pub fn between(&self, low: PyExpr, high: PyExpr, negated: bool) -> PyExpr {
363368
let expr = Expr::Between(Between::new(

crates/core/src/functions.rs

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -607,7 +607,12 @@ expr_fn_vec!(named_struct);
607607
expr_fn!(from_unixtime, unixtime);
608608
expr_fn!(arrow_typeof, arg_1);
609609
expr_fn!(arrow_cast, arg_1 datatype);
610+
expr_fn!(arrow_try_cast, arg_1 datatype);
611+
expr_fn!(arrow_field, arg_1);
612+
expr_fn!(cast_to_type, arg_1 reference);
613+
expr_fn!(try_cast_to_type, arg_1 reference);
610614
expr_fn_vec!(arrow_metadata);
615+
expr_fn_vec!(with_metadata);
611616
expr_fn!(union_tag, arg1);
612617
expr_fn!(random);
613618

@@ -966,7 +971,12 @@ pub(crate) fn init_module(m: &Bound<'_, PyModule>) -> PyResult<()> {
966971
m.add_wrapped(wrap_pyfunction!(array_agg))?;
967972
m.add_wrapped(wrap_pyfunction!(arrow_typeof))?;
968973
m.add_wrapped(wrap_pyfunction!(arrow_cast))?;
974+
m.add_wrapped(wrap_pyfunction!(arrow_try_cast))?;
975+
m.add_wrapped(wrap_pyfunction!(arrow_field))?;
976+
m.add_wrapped(wrap_pyfunction!(cast_to_type))?;
977+
m.add_wrapped(wrap_pyfunction!(try_cast_to_type))?;
969978
m.add_wrapped(wrap_pyfunction!(arrow_metadata))?;
979+
m.add_wrapped(wrap_pyfunction!(with_metadata))?;
970980
m.add_wrapped(wrap_pyfunction!(ascii))?;
971981
m.add_wrapped(wrap_pyfunction!(asin))?;
972982
m.add_wrapped(wrap_pyfunction!(asinh))?;

python/datafusion/expr.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -894,6 +894,28 @@ def cast(self, to: pa.DataType[Any] | type) -> Expr:
894894

895895
return Expr(self.expr.cast(to))
896896

897+
def try_cast(self, to: pa.DataType[Any] | type) -> Expr:
898+
"""Cast to a new data type, returning NULL on failure.
899+
900+
Like :py:meth:`cast` but produces NULL instead of erroring when the
901+
cast cannot be performed for a given row.
902+
903+
Examples:
904+
>>> ctx = dfn.SessionContext()
905+
>>> df = ctx.from_pydict({"a": ["oops"]})
906+
>>> result = df.select(col("a").try_cast(pa.float64()).alias("c"))
907+
>>> result.collect_column("c")[0].as_py() is None
908+
True
909+
"""
910+
if not isinstance(to, pa.DataType):
911+
try:
912+
to = self._to_pyarrow_types[to]
913+
except KeyError as err:
914+
error_msg = "Expected instance of pyarrow.DataType or builtins.type"
915+
raise TypeError(error_msg) from err
916+
917+
return Expr(self.expr.try_cast(to))
918+
897919
def between(self, low: Any, high: Any, negated: bool = False) -> Expr:
898920
"""Returns ``True`` if this expression is between a given range.
899921

python/datafusion/functions.py

Lines changed: 144 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -133,7 +133,9 @@ def _warn_expr_for_literal_arg(function_name: str, arg_name: str) -> None:
133133
"arrays_overlap",
134134
"arrays_zip",
135135
"arrow_cast",
136+
"arrow_field",
136137
"arrow_metadata",
138+
"arrow_try_cast",
137139
"arrow_typeof",
138140
"ascii",
139141
"asin",
@@ -151,6 +153,7 @@ def _warn_expr_for_literal_arg(function_name: str, arg_name: str) -> None:
151153
"btrim",
152154
"cardinality",
153155
"case",
156+
"cast_to_type",
154157
"cbrt",
155158
"ceil",
156159
"char_length",
@@ -375,6 +378,7 @@ def _warn_expr_for_literal_arg(function_name: str, arg_name: str) -> None:
375378
"translate",
376379
"trim",
377380
"trunc",
381+
"try_cast_to_type",
378382
"union_extract",
379383
"union_tag",
380384
"upper",
@@ -386,6 +390,7 @@ def _warn_expr_for_literal_arg(function_name: str, arg_name: str) -> None:
386390
"var_sample",
387391
"version",
388392
"when",
393+
"with_metadata",
389394
]
390395

391396

@@ -2960,6 +2965,110 @@ def arrow_cast(expr: Expr, data_type: Expr | str | pa.DataType) -> Expr:
29602965
return Expr(f.arrow_cast(expr.expr, data_type.expr))
29612966

29622967

2968+
def arrow_try_cast(expr: Expr, data_type: Expr | str | pa.DataType) -> Expr:
2969+
"""Casts an expression to a specified data type, returning NULL on failure.
2970+
2971+
Like :py:func:`arrow_cast` but produces NULL instead of erroring when the
2972+
cast cannot be performed. The ``data_type`` may be a string in DataFusion
2973+
type syntax (for example ``"Float64"``), a ``pyarrow.DataType``, or an
2974+
``Expr`` of string type.
2975+
2976+
Examples:
2977+
>>> ctx = dfn.SessionContext()
2978+
>>> df = ctx.from_pydict({"a": ["oops"]})
2979+
>>> result = df.select(
2980+
... dfn.functions.arrow_try_cast(dfn.col("a"), "Float64").alias("c")
2981+
... )
2982+
>>> result.collect_column("c")[0].as_py() is None
2983+
True
2984+
2985+
>>> result = df.select(
2986+
... dfn.functions.arrow_try_cast(
2987+
... dfn.col("a"), data_type=pa.float64()
2988+
... ).alias("c")
2989+
... )
2990+
>>> result.collect_column("c")[0].as_py() is None
2991+
True
2992+
"""
2993+
if isinstance(data_type, pa.DataType):
2994+
return expr.try_cast(data_type)
2995+
if isinstance(data_type, str):
2996+
data_type = Expr.string_literal(data_type)
2997+
return Expr(f.arrow_try_cast(expr.expr, data_type.expr))
2998+
2999+
3000+
def arrow_field(expr: Expr) -> Expr:
3001+
"""Returns the Arrow field information of an expression as a struct.
3002+
3003+
The returned struct contains the field's name, data type, nullability,
3004+
and metadata.
3005+
3006+
Examples:
3007+
>>> field = pa.field("val", pa.int64(), metadata={"k": "v"})
3008+
>>> schema = pa.schema([field])
3009+
>>> batch = pa.RecordBatch.from_arrays([pa.array([1])], schema=schema)
3010+
>>> ctx = dfn.SessionContext()
3011+
>>> df = ctx.create_dataframe([[batch]])
3012+
>>> result = df.select(
3013+
... dfn.functions.arrow_field(dfn.col("val")).alias("f")
3014+
... )
3015+
>>> out = result.collect_column("f")[0].as_py()
3016+
>>> out["name"], out["data_type"], out["nullable"], out["metadata"]
3017+
('val', 'Int64', True, [('k', 'v')])
3018+
"""
3019+
return Expr(f.arrow_field(expr.expr))
3020+
3021+
3022+
def cast_to_type(value: Expr, type_ref: Expr) -> Expr:
3023+
"""Casts ``value`` to the data type of ``type_ref``.
3024+
3025+
Only the *type* of ``type_ref`` is used; its value is ignored. This is
3026+
useful when the target type comes from another column or expression
3027+
rather than being known up-front. Casts that fail produce an error; use
3028+
:py:func:`try_cast_to_type` for the NULL-on-failure variant.
3029+
3030+
If the target type is known statically, prefer :py:func:`arrow_cast`
3031+
(or :py:func:`arrow_try_cast` for the NULL-on-failure variant) and
3032+
pass a type string or ``pyarrow.DataType`` directly.
3033+
3034+
Examples:
3035+
>>> ctx = dfn.SessionContext()
3036+
>>> df = ctx.from_pydict({"a": [1], "b": [1.0]})
3037+
>>> result = df.select(
3038+
... dfn.functions.cast_to_type(
3039+
... dfn.col("a"), dfn.col("b")
3040+
... ).alias("c")
3041+
... )
3042+
>>> result.collect_column("c")[0].as_py()
3043+
1.0
3044+
"""
3045+
return Expr(f.cast_to_type(value.expr, type_ref.expr))
3046+
3047+
3048+
def try_cast_to_type(value: Expr, type_ref: Expr) -> Expr:
3049+
"""Casts ``value`` to the data type of ``type_ref``, NULL on failure.
3050+
3051+
Like :py:func:`cast_to_type`, but casts that fail produce NULL instead
3052+
of erroring. Only the *type* of ``type_ref`` is used; its value is
3053+
ignored.
3054+
3055+
If the target type is known statically, prefer :py:func:`arrow_try_cast`
3056+
and pass a type string or ``pyarrow.DataType`` directly.
3057+
3058+
Examples:
3059+
>>> ctx = dfn.SessionContext()
3060+
>>> df = ctx.from_pydict({"a": ["oops"], "b": [1.0]})
3061+
>>> result = df.select(
3062+
... dfn.functions.try_cast_to_type(
3063+
... dfn.col("a"), dfn.col("b")
3064+
... ).alias("c")
3065+
... )
3066+
>>> result.collect_column("c")[0].as_py() is None
3067+
True
3068+
"""
3069+
return Expr(f.try_cast_to_type(value.expr, type_ref.expr))
3070+
3071+
29633072
def arrow_metadata(expr: Expr, key: Expr | str | None = None) -> Expr:
29643073
"""Returns the metadata of the input expression.
29653074
@@ -2993,6 +3102,41 @@ def arrow_metadata(expr: Expr, key: Expr | str | None = None) -> Expr:
29933102
return Expr(f.arrow_metadata(expr.expr, key.expr))
29943103

29953104

3105+
def with_metadata(expr: Expr, metadata: dict[str, str]) -> Expr:
3106+
"""Attaches Arrow field metadata (key/value pairs) to the input expression.
3107+
3108+
This is the inverse of :py:func:`arrow_metadata`. Existing metadata on the
3109+
input field is preserved; new keys overwrite on collision. Keys must be
3110+
non-empty strings; empty values are allowed.
3111+
3112+
An empty ``metadata`` dict is a no-op and returns the input expression
3113+
unchanged. Empty keys raise :py:class:`ValueError`.
3114+
3115+
Examples:
3116+
>>> ctx = dfn.SessionContext()
3117+
>>> df = ctx.from_pydict({"a": [1]})
3118+
>>> result = df.select(
3119+
... dfn.functions.with_metadata(
3120+
... dfn.col("a"), {"unit": "ms"}
3121+
... ).alias("a")
3122+
... )
3123+
>>> result.select(
3124+
... dfn.functions.arrow_metadata(dfn.col("a"), "unit").alias("u")
3125+
... ).collect_column("u")[0].as_py()
3126+
'ms'
3127+
"""
3128+
if not metadata:
3129+
return expr
3130+
args = [expr.expr]
3131+
for k, v in metadata.items():
3132+
if not k:
3133+
msg = "with_metadata keys must be non-empty strings"
3134+
raise ValueError(msg)
3135+
args.append(Expr.string_literal(k).expr)
3136+
args.append(Expr.string_literal(v).expr)
3137+
return Expr(f.with_metadata(*args))
3138+
3139+
29963140
def get_field(expr: Expr, *names: Expr | str) -> Expr:
29973141
"""Extracts a (possibly nested) field from a struct or map by name.
29983142

python/tests/test_functions.py

Lines changed: 78 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1333,30 +1333,90 @@ def test_make_time(df):
13331333
assert result.column(0)[0].as_py() == time(12, 30)
13341334

13351335

1336-
def test_arrow_cast(df):
1337-
df = df.select(
1338-
f.arrow_cast(column("b"), "Float64").alias("b_as_float"),
1339-
f.arrow_cast(column("b"), "Int32").alias("b_as_int"),
1336+
@pytest.mark.parametrize("cast_fn", [f.arrow_cast, f.arrow_try_cast])
1337+
@pytest.mark.parametrize(
1338+
("data_type", "expected"),
1339+
[
1340+
("Float64", pa.array([4.0, 5.0, 6.0], type=pa.float64())),
1341+
("Int32", pa.array([4, 5, 6], type=pa.int32())),
1342+
(pa.float64(), pa.array([4.0, 5.0, 6.0], type=pa.float64())),
1343+
(pa.int32(), pa.array([4, 5, 6], type=pa.int32())),
1344+
(pa.string(), pa.array(["4", "5", "6"], type=pa.string())),
1345+
],
1346+
)
1347+
def test_arrow_cast_variants(df, cast_fn, data_type, expected):
1348+
"""arrow_cast / arrow_try_cast accept str and pyarrow target types."""
1349+
result = df.select(cast_fn(column("b"), data_type).alias("c")).collect()[0]
1350+
assert result.column(0) == expected
1351+
1352+
1353+
def test_arrow_try_cast_null_on_failure():
1354+
ctx = SessionContext()
1355+
batch = pa.RecordBatch.from_arrays([pa.array(["1.5", "oops", "3"])], names=["s"])
1356+
df = ctx.create_dataframe([[batch]])
1357+
1358+
result = df.select(f.arrow_try_cast(column("s"), "Float64").alias("c")).collect()[0]
1359+
1360+
assert result.column(0).to_pylist() == [1.5, None, 3.0]
1361+
1362+
1363+
def test_arrow_field():
1364+
ctx = SessionContext()
1365+
field = pa.field("val", pa.int64(), metadata={"k": "v"})
1366+
schema = pa.schema([field])
1367+
batch = pa.RecordBatch.from_arrays([pa.array([1])], schema=schema)
1368+
df = ctx.create_dataframe([[batch]])
1369+
1370+
out = (
1371+
df.select(f.arrow_field(column("val")).alias("f"))
1372+
.collect_column("f")[0]
1373+
.as_py()
13401374
)
1341-
result = df.collect()
1342-
assert len(result) == 1
1343-
result = result[0]
1375+
assert out == {
1376+
"name": "val",
1377+
"data_type": "Int64",
1378+
"nullable": True,
1379+
"metadata": [("k", "v")],
1380+
}
1381+
1382+
1383+
@pytest.mark.parametrize(
1384+
("cast_fn", "values", "expected"),
1385+
[
1386+
(f.cast_to_type, pa.array([4, 5, 6]), [4.0, 5.0, 6.0]),
1387+
(f.try_cast_to_type, pa.array(["oops", "2", "3"]), [None, 2.0, 3.0]),
1388+
],
1389+
)
1390+
def test_cast_to_type(cast_fn, values, expected):
1391+
"""cast_to_type / try_cast_to_type take target type from ``type_ref``."""
1392+
ctx = SessionContext()
1393+
batch = pa.RecordBatch.from_arrays(
1394+
[values, pa.array([1.0, 2.0, 3.0])], names=["v", "fl"]
1395+
)
1396+
df = ctx.create_dataframe([[batch]])
13441397

1345-
assert result.column(0) == pa.array([4.0, 5.0, 6.0], type=pa.float64())
1346-
assert result.column(1) == pa.array([4, 5, 6], type=pa.int32())
1398+
result = df.select(cast_fn(column("v"), column("fl")).alias("c")).collect()[0]
13471399

1400+
assert result.column(0).to_pylist() == expected
1401+
assert result.column(0).type == pa.float64()
13481402

1349-
def test_arrow_cast_with_pyarrow_type(df):
1350-
df = df.select(
1351-
f.arrow_cast(column("b"), pa.float64()).alias("b_as_float"),
1352-
f.arrow_cast(column("b"), pa.int32()).alias("b_as_int"),
1353-
f.arrow_cast(column("b"), pa.string()).alias("b_as_str"),
1403+
1404+
def test_with_metadata_round_trip(df):
1405+
df = df.select(f.with_metadata(column("b"), {"unit": "ms"}).alias("b"))
1406+
result = df.select(f.arrow_metadata(column("b"), "unit").alias("u")).collect_column(
1407+
"u"
13541408
)
1355-
result = df.collect()[0]
1409+
assert result[0].as_py() == "ms"
1410+
1411+
1412+
def test_with_metadata_empty_dict_noop(df):
1413+
out = df.select(f.with_metadata(column("b"), {}).alias("b")).collect()[0]
1414+
assert out.column(0) == pa.array([4, 5, 6])
1415+
13561416

1357-
assert result.column(0) == pa.array([4.0, 5.0, 6.0], type=pa.float64())
1358-
assert result.column(1) == pa.array([4, 5, 6], type=pa.int32())
1359-
assert result.column(2) == pa.array(["4", "5", "6"], type=pa.string())
1417+
def test_with_metadata_empty_key_raises():
1418+
with pytest.raises(ValueError, match="non-empty"):
1419+
f.with_metadata(column("b"), {"": "v"})
13601420

13611421

13621422
def test_case(df):

skills/datafusion_python/SKILL.md

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -758,7 +758,12 @@ F.left(col("c_phone"), lit(2)) # prefix shortcut
758758

759759
**Hash**: `md5`, `sha224`, `sha256`, `sha384`, `sha512`, `digest`
760760

761-
**Type**: `arrow_typeof`, `arrow_cast`, `arrow_metadata`
761+
**Type**: `arrow_typeof`, `arrow_cast`, `arrow_try_cast`, `arrow_field`,
762+
`arrow_metadata`, `cast_to_type`, `with_metadata`
763+
764+
Note: ``cast_to_type(value, type_ref, *, try_cast=False)`` is the single
765+
Python entry point for both upstream ``cast_to_type`` and ``try_cast_to_type``;
766+
pass ``try_cast=True`` for the variant that returns NULL on failure.
762767

763768
**Other**: `in_list`, `order_by`, `alias`, `col`, `encode`, `decode`,
764769
`to_hex`, `to_char`, `uuid`, `version`, `bit_length`, `octet_length`

0 commit comments

Comments
 (0)