Source code for akimbo.polars

from typing import Dict

import awkward as ak
import polars as pl
import pyarrow as pa

from akimbo.mixin import EagerAccessor, LazyAccessor


@pl.api.register_series_namespace("ak")
@pl.api.register_dataframe_namespace("ak")
class PolarsAwkwardAccessor(EagerAccessor):
    """Eager ``.ak`` accessor for polars ``Series`` and ``DataFrame`` objects.

    The class is registered under the ``ak`` namespace for both container
    types. Lazy frames are served by ``LazyPolarsAwkwardAccessor`` instead.
    """

    series_type = pl.Series
    dataframe_type = pl.DataFrame

    @classmethod
    def to_arrow(cls, data):
        """Return ``data`` converted to arrow via its own ``to_arrow`` method."""
        return data.to_arrow()

    def pack(self):
        """Return the wrapped object as a struct.

        Polars provides this directly through ``to_struct``.
        """
        return self._obj.to_struct()

    def to_output(self, data=None):
        """Convert an awkward-compatible object back into a polars object.

        ``data`` defaults to the wrapped object when it is ``None``. The
        conversion goes through arrow without awkward's extension-array types.
        """
        source = self._obj if data is None else data
        return pl.from_arrow(ak.to_arrow(source, extensionarray=False))
@pl.api.register_lazyframe_namespace("ak")
class LazyPolarsAwkwardAccessor(LazyAccessor):
    """``.ak`` accessor for polars ``LazyFrame`` objects.

    Operations are deferred: calling an operation obtained from this accessor
    adds a ``map_batches`` step to the lazy query and returns the resulting
    ``LazyFrame``. Use ``to_output`` to collect the result.
    """

    dataframe_type = pl.LazyFrame
    series_type = None  # lazy is never series

    def to_output(self, data=None):
        """Collect the wrapped lazy frame into an eager polars object.

        A result whose only column is ``"_ak_series_"`` is returned as that
        column rather than as a one-column frame.

        NOTE(review): ``data`` is accepted but never used here - confirm
        whether callers expect it to be honoured.
        """
        out = self._obj.collect()
        if out.columns == ["_ak_series_"]:
            out = out["_ak_series_"]
        return out

    def __getattr__(self, item: str, **flags) -> callable:
        """Resolve ``item`` to a sub-accessor or to a deferred operation.

        If ``item`` names a registered sub-accessor, a new accessor bound to
        that sub-accessor is returned. Otherwise a function ``select`` is
        returned; calling it schedules the operation on the lazy frame.

        ``item`` may be a string (looked up on the active sub-accessor, or
        failing that on the ``awkward`` module) or a callable that is applied
        directly. ``flags`` are forwarded to the ``pl.DataFrame`` constructor
        that wraps each batch result.
        """
        if isinstance(item, str) and item in self.subaccessors:
            return LazyPolarsAwkwardAccessor(
                self._obj, subaccessor=item, behavior=self._behavior
            )

        def select(*inargs, subaccessor=self.subaccessor, where=None, **kwargs):
            """Schedule the operation and return the new ``LazyFrame``.

            Any argument that is an accessor of this type or a ``LazyFrame``
            is treated as a second ("other") input; at most one is supported.
            ``where`` selects the field(s) the operation is applied to; the
            result is written back into the full record under ``where``.

            Raises ``NotImplementedError`` for more than one "other" frame and
            ``TypeError`` when ``item`` cannot be resolved to a callable.
            """
            if subaccessor and isinstance(item, str):
                func0 = getattr(self.subaccessors[subaccessor](), item)
            elif callable(item):
                func0 = item
            elif isinstance(item, str):
                # Plain name with no sub-accessor: previously this left
                # func0 as None and the call below always failed. Fall back
                # to the awkward function of the same name, if there is one.
                func0 = getattr(ak, item, None)
            else:
                func0 = None

            def f(batch):
                # NOTE: ``inargs`` is read from the enclosing scope at call
                # time, i.e. after the "_ak_other_" substitution below.
                arr = ak.from_arrow(batch.to_arrow())
                if any(isinstance(_, str) and _ == "_ak_other_" for _ in inargs):
                    # binary input: columns of the second frame carry the
                    # "_df2_" prefix added by concat()
                    other = arr[[_ for _ in arr.fields if _.startswith("_df2_")]]
                    # 5 == len("_df2_"); rename to original fields
                    other.layout._fields[:] = [k[5:] for k in other.fields]
                    arr = arr[[_ for _ in arr.fields if not _.startswith("_df2_")]]
                    if other.fields == ["_ak_series_"]:
                        other = other["_ak_series_"]
                    if where is not None:
                        other = other[where]
                    # same marker test as above; avoids str() on arbitrary
                    # (possibly large) arguments
                    inargs0 = [
                        other if isinstance(_, str) and _ == "_ak_other_" else _
                        for _ in inargs
                    ]
                else:
                    inargs0 = inargs
                # NOTE(review): truthiness test here vs. ``is not None`` for
                # ``other`` above - kept as in the original.
                if where:
                    arr0 = arr
                    arr = arr[where]
                if arr.fields == ["_ak_series_"]:
                    arr = arr["_ak_series_"]

                if func0 is None:
                    raise TypeError(
                        f"{item!r} is neither callable nor the name of a "
                        "known awkward operation"
                    )
                out = func0(arr, *inargs0, **kwargs)
                if where:
                    out = ak.with_field(arr0, out, where)
                if not out.layout.fields:
                    # non-record result: wrap it so it can become a frame
                    out = ak.Array({"_ak_series_": out})
                arr = ak.to_arrow_table(out, extensionarray=False)
                return pl.DataFrame(arr, **flags)

            # unwrap accessor arguments to their underlying lazy frames
            inargs = [_._obj if isinstance(_, type(self)) else _ for _ in inargs]
            n_others = sum(isinstance(_, self.dataframe_type) for _ in inargs)
            if n_others == 1:
                other = next(_ for _ in inargs if isinstance(_, self.dataframe_type))
                inargs = [
                    "_ak_other_" if isinstance(_, self.dataframe_type) else _
                    for _ in inargs
                ]
                obj = concat(self._obj, other)
            elif n_others > 1:
                raise NotImplementedError
            else:
                obj = self._obj

            # Run f once on an empty table with the input schema so that the
            # output schema can be handed to map_batches.
            arrow_type = polars_to_arrow_schema(obj.collect_schema())
            arr = pa.table([[]] * len(arrow_type), schema=arrow_type)
            out1 = f(pl.from_arrow(arr))
            return obj.map_batches(f, schema=out1.schema)

        return select

    def pack(self):
        """Combine all columns into one struct column named ``"_ak_series_"``."""
        return self._obj.select(
            pl.struct(*self._obj.collect_schema().names()).alias("_ak_series_")
        )

    def unpack(self):
        """Expand the single struct column into one column per struct field.

        The wrapped frame must have exactly one column (asserted).
        """
        cols = self._obj.collect_schema().names()
        assert len(cols) == 1
        return self._obj.select(pl.col(cols[0]).struct.unnest())


def concat(*series: pl.LazyFrame) -> pl.LazyFrame:
    """Attach the columns of the other frames to the first one.

    Columns of the i-th additional frame are renamed with the prefix
    ``_df{i + 2}_`` (so the first "other" gets ``_df2_``) before being added.

    NOTE(review): the renamed frames are passed straight to ``with_columns``;
    confirm that the installed polars version accepts frames there.
    """
    this, *others = series
    # don't actually expect more than one "others"
    return this.with_columns(
        [
            o.rename({c: f"_df{i + 2}_{c}" for c in o.collect_schema().names()})
            for i, o in enumerate(others)
        ]
    )


def polars_to_arrow_type(polars_type: pl.DataType) -> pa.DataType:
    """Translate a polars data type into the corresponding arrow data type.

    Handles the fixed-width numeric types, ``String``, ``Boolean``, ``Date``,
    ``Datetime``, ``Decimal`` and, recursively, ``List`` and ``Struct``.

    Raises ``ValueError`` for any other polars type.
    """
    type_mapping = {
        pl.Int8: pa.int8(),
        pl.Int16: pa.int16(),
        pl.Int32: pa.int32(),
        pl.Int64: pa.int64(),
        pl.UInt8: pa.uint8(),
        pl.UInt16: pa.uint16(),
        pl.UInt32: pa.uint32(),
        pl.UInt64: pa.uint64(),
        pl.Float32: pa.float32(),
        pl.Float64: pa.float64(),
        pl.String: pa.string(),
        pl.Boolean: pa.bool_(),
        pl.Date: pa.date32(),
    }
    if polars_type in type_mapping:
        return type_mapping[polars_type]

    # parametrised types
    if isinstance(polars_type, pl.Datetime):
        # polars names the attribute ``time_unit`` (there is no ``unit``)
        return pa.timestamp(polars_type.time_unit, polars_type.time_zone)
    if isinstance(polars_type, pl.Decimal):
        precision = polars_type.precision
        if precision is None:
            # unspecified precision: use the decimal128 maximum
            precision = 38
        return pa.decimal128(precision, polars_type.scale)

    # Handle list type
    if isinstance(polars_type, pl.List):
        value_type = polars_to_arrow_type(polars_type.inner)
        return pa.list_(value_type)

    # Handle struct type
    if isinstance(polars_type, pl.Struct):
        fields = [
            pa.field(name, polars_to_arrow_type(dtype))
            for name, dtype in dict(polars_type).items()
        ]
        return pa.struct(fields)

    raise ValueError(f"Unsupported Polars type: {polars_type}")


def polars_to_arrow_schema(polars_schema: Dict[str, pl.DataType]) -> pa.Schema:
    """Build an arrow schema from a mapping of column name to polars type.

    Raises ``ValueError`` (from ``polars_to_arrow_type``) if a column has an
    unsupported type.
    """
    fields = [
        pa.field(name, polars_to_arrow_type(dtype))
        for name, dtype in polars_schema.items()
    ]
    return pa.schema(fields)