from __future__ import annotations

import codecs
import functools

import awkward as ak
import pyarrow as pa
import pyarrow.compute as pc

from akimbo.apply_tree import dec
from akimbo.mixin import EagerAccessor, LazyAccessor
from akimbo.utils import match_string
def _encode(layout):
layout._parameters["__array__"] = "bytestring"
layout.content._parameters["__array__"] = "byte"
return layout
def match_bytestring(*layouts):
    """Tell whether the first given layout is a list tagged as a bytestring.

    Used as the ``match`` predicate when building ``_decode_f``. Only the
    first positional layout is inspected; any others are ignored.
    """
    head = layouts[0]
    # ``and`` short-circuits, so ``parameter`` is only queried on list nodes.
    return head.is_list and head.parameter("__array__") == "bytestring"
def _decode(layout):
layout._parameters["__array__"] = "string"
layout.content._parameters["__array__"] = "char"
return layout
# Array-level wrappers: ``dec`` pairs each relabelling function with the
# predicate selecting the nodes it should be applied to, in awkward ("ak")
# layout mode.
_encode_f = dec(_encode, match=match_string, inmode="ak")  # string -> bytestring
_decode_f = dec(_decode, match=match_bytestring, inmode="ak")  # bytestring -> string
_SA_METHODMAPPING = {
"endswith": "ends_with",
"isalnum": "is_alnum",
"isalpha": "is_alpha",
"isascii": "is_ascii",
"isdecimal": "is_decimal",
"isdigit": "is_digit",
"islower": "is_lower",
"isnumeric": "is_numeric",
"isprintable": "is_printable",
"isspace": "is_space",
"istitle": "is_title",
"isupper": "is_upper",
"startswith": "starts_with",
}
# Public ``ak.str`` operations, advertised by ``StringAccessor.__dir__``.
# Private names, ``akstr_*`` implementation submodules and capitalised names
# are skipped. Non-callable module attributes (for example a ``__future__``
# feature object, if the package imports one) are not string operations, so
# only callables are kept.
methods = [
    aname
    for aname in dir(ak.str)
    if not aname.startswith(("_", "akstr_"))
    and not aname[0].isupper()
    and callable(getattr(ak.str, aname))
]
@functools.wraps(pc.strptime)
def strptime(*args, format="%FT%T", unit="us", error_is_null=True, **kw):
    """Parse strings to timestamps via ``pyarrow.compute.strptime``.

    The defaults mirror a plain ``strftime`` round-trip: ``%FT%T`` layout,
    microsecond unit, and ``error_is_null=True`` so unparsable values become
    null rather than raising. All other arguments pass straight through.
    """
    # NOTE: functools.wraps replaces this docstring (and __name__, and sets
    # __wrapped__) with pc.strptime's own metadata.
    # format/unit/error_is_null are named parameters, so ``kw`` can never
    # already hold them and updating it cannot clobber caller input.
    kw.update(format=format, unit=unit, error_is_null=error_is_null)
    return pc.strptime(*args, **kw)
def repeat(arr, count):
    """Repeat every string of *arr* ``count`` times.

    Thin wrapper over ``pyarrow.compute.binary_repeat``; *count* is passed
    through as-is, so it may be anything that function accepts.
    """
    repeated = pc.binary_repeat(arr, count)
    return repeated
def concat(arr, arr2, sep=""):
    """Join *arr* and *arr2* element-wise, placing *sep* between each pair.

    Both inputs are cast to ``pa.string()`` first so that
    ``pyarrow.compute.binary_join_element_wise`` sees matching types.
    """
    target = pa.string()
    left, right = arr.cast(target), arr2.cast(target)
    return pc.binary_join_element_wise(left, right, sep)
def _check_utf8(encoding: str) -> None:
    """Raise ``NotImplementedError`` unless *encoding* names UTF-8.

    Any spelling that Python's codec registry resolves to UTF-8 is accepted
    (``"utf-8"``, ``"UTF8"``, ``"utf_8"``, ``"u8"``, ...). Unknown or
    non-string names, and every other encoding, are rejected.
    """
    try:
        canonical = codecs.lookup(encoding).name
    except (LookupError, TypeError):
        canonical = None
    if canonical != "utf-8":
        raise NotImplementedError(
            f"only UTF-8 is supported, got encoding={encoding!r}"
        )


class StringAccessor:
    """String operations on nested/var-length data.

    Unknown attributes fall through to ``ak.str`` (see ``__getattr__``), with
    Python ``str``-style names such as ``isalnum`` or ``startswith``
    translated via ``_SA_METHODMAPPING``. ``encode``, ``decode``,
    ``strptime``, ``repeat`` and ``join_el`` are defined on the class itself.
    """

    # TODO: implement dunder add (concat strings) and mul (repeat strings)
    # - s.ak.str + "suffix" (and arguments swapped)
    # - s.ak.str + s2.ak.str (with matching schemas)
    # - s.ak.str * N (and arguments swapped)
    # - s.ak.str * s (where each string maps to integers for variable repeats)

    def encode(self, arr, encoding: str = "utf-8"):
        """Encode Series of strings to Series of bytes. Leaves non-strings alone.

        Only UTF-8 is supported. String nodes are relabelled as bytestrings;
        the bytes themselves are not transcoded.

        Raises
        ------
        NotImplementedError
            If *encoding* does not name UTF-8.
        """
        _check_utf8(encoding)
        return _encode_f(arr)

    def decode(self, arr, encoding: str = "utf-8"):
        """Decode Series of bytes to Series of strings. Leaves non-bytestrings alone.

        Validity of UTF8 is *not* checked: bytestring nodes are only
        relabelled as strings.

        Raises
        ------
        NotImplementedError
            If *encoding* does not name UTF-8.
        """
        _check_utf8(encoding)
        return _decode_f(arr)

    @staticmethod
    def method_name(attr: str) -> str:
        """Translate a Python ``str`` method name to its ``ak.str`` spelling.

        Names without an entry in ``_SA_METHODMAPPING`` are returned unchanged.
        """
        return _SA_METHODMAPPING.get(attr, attr)

    def __getattr__(self, attr: str) -> callable:
        # Only reached when normal lookup fails, i.e. for names not defined on
        # this class. getattr on the module raises AttributeError for unknown
        # names, which is what __getattr__ is required to do.
        attr = self.method_name(attr)
        return getattr(ak.str, attr)

    # pyarrow-backed helpers wrapped by ``dec`` with the ``match_string``
    # predicate (presumably so they run only on string nodes — see dec).
    # Inside the class body the right-hand names still resolve to the
    # module-level functions, because the class attributes do not exist yet.
    strptime = staticmethod(dec(strptime, match=match_string, inmode="arrow"))
    repeat = staticmethod(dec(repeat, match=match_string, inmode="arrow"))
    join_el = staticmethod(dec(concat, match=match_string, inmode="arrow"))

    def __add__(self, *_):
        # NOTE(review): the operand is ignored and the dec-wrapped ``concat``
        # itself is returned, not a concatenated result. Presumably the
        # accessor machinery in akimbo.mixin applies it — confirm (see TODO).
        return dec(concat, match=match_string, inmode="arrow")

    def __mul__(self, *_):
        # NOTE(review): as with __add__, returns the dec-wrapped ``repeat``
        # rather than a repeated result — confirm against akimbo.mixin.
        return dec(repeat, match=match_string, inmode="arrow")

    def __dir__(self) -> list[str]:
        # Names reachable through __getattr__ plus the operations defined on
        # this class; the set union drops names present in both.
        extras = {"strptime", "encode", "decode", "repeat", "join_el"}
        return sorted(set(methods) | extras)
# Expose StringAccessor as the ``str`` sub-accessor on the eager and the lazy
# accessor bases, in that order.
for _accessor_base in (EagerAccessor, LazyAccessor):
    _accessor_base.register_accessor("str", StringAccessor)
del _accessor_base