
Spark expr replace strict #2254


Open · wants to merge 9 commits into base: main

Conversation

lucas-nelson-uiuc
Contributor

What type of PR is this? (check all applicable)

  • πŸ’Ύ Refactor
  • ✨ Feature
  • πŸ› Bug Fix
  • πŸ”§ Optimization
  • πŸ“ Documentation
  • βœ… Test
  • 🐳 Other

Related issues

Checklist

  • Code follows style guide (ruff)
  • Tests added
  • Documented the changes

If you have comments or can explain your changes, please do so below

Currently this supports replacing one value with another and specifying return_dtype. See the comments below for the issue I'm running into.

    mapping_expr = self._F.create_map(old, new)
    return mapping_expr[_input]

    result = self._from_call(_replace_strict, "replace_strict", old=old, new=new)
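For context, the strict semantics being targeted can be sketched backend-agnostically: each element of old maps to the corresponding element of new, and a value with no mapping is an error rather than a passthrough. A minimal pure-Python sketch (the function name and signature here are illustrative, not the narwhals API):

```python
def replace_strict(values, old, new):
    """Map each element of `values` via old[i] -> new[i]; raise on unmapped values."""
    mapping = dict(zip(old, new))
    out = []
    for v in values:
        if v not in mapping:
            # "strict" means an unmapped value is an error, not a passthrough
            raise ValueError(f"value {v!r} not found in replacement mapping")
        out.append(mapping[v])
    return out

replace_strict([1, 2, 1], old=[1, 2], new=["a", "b"])  # -> ["a", "b", "a"]
```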
Contributor Author

@lucas-nelson-uiuc commented Mar 20, 2025


Most likely I'm confusing myself with this self._from_call...

The approach I'm taking is to use pyspark.sql.functions.create_map to create a mapping expression for us. Currently, it only accepts one constant for old and another constant for new (and a return_dtype).

I'm running into two issues when trying to expand the functionality:

  • if the user passes a sequence of values to replace, self._from_call will convert old and new into array columns; instead, I want old and new to be a sequence of columns by the time we access them in _replace_strict, so that we can use them with create_map
  • this approach doesn't allow us to pass expressions to old or new
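One detail worth noting for the first point: pyspark.sql.functions.create_map expects its column arguments interleaved as key1, value1, key2, value2, and so on, so separate old and new sequences would have to be zipped together before being passed in. The interleaving itself is plain Python (shown here on ordinary values rather than Columns):

```python
from itertools import chain

def interleave(old, new):
    """Flatten pairs [(k1, v1), (k2, v2), ...] into [k1, v1, k2, v2, ...],
    the argument order that create_map(*args) expects."""
    return list(chain.from_iterable(zip(old, new)))

interleave([1, 2], ["a", "b"])  # -> [1, "a", 2, "b"]
```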

Member


Thanks for your PR!

I don't think we currently accept expressions here for other backends anyway, so we don't need to support that just yet.

@dangotbanned
Member

dangotbanned commented Mar 20, 2025

@lucas-nelson-uiuc I was a bit confused when I saw (#2253)

So right now replace_strict is declared as not_implemented for all lazy backends here:

class LazyExpr(
    CompliantExpr[CompliantLazyFrameT, NativeExprT_co],
    Protocol38[CompliantLazyFrameT, NativeExprT_co],
):
    arg_min: not_implemented = not_implemented()
    arg_max: not_implemented = not_implemented()
    arg_true: not_implemented = not_implemented()
    head: not_implemented = not_implemented()
    tail: not_implemented = not_implemented()
    mode: not_implemented = not_implemented()
    sort: not_implemented = not_implemented()
    rank: not_implemented = not_implemented()
    sample: not_implemented = not_implemented()
    map_batches: not_implemented = not_implemented()
    ewm_mean: not_implemented = not_implemented()
    rolling_mean: not_implemented = not_implemented()
    rolling_var: not_implemented = not_implemented()
    rolling_std: not_implemented = not_implemented()
    gather_every: not_implemented = not_implemented()
    replace_strict: not_implemented = not_implemented()
    cat: not_implemented = not_implemented()  # pyright: ignore[reportAssignmentType]

That is why you're seeing mypy yelling in (https://github.com/narwhals-dev/narwhals/actions/runs/13960816458/job/39081638109?pr=2254)

However, if replace_strict is possible for SparkLikeExpr, then you'll need to move this:

replace_strict: not_implemented = not_implemented()

Into each of these sections:

list = not_implemented() # pyright: ignore[reportAssignmentType]
struct = not_implemented() # pyright: ignore[reportAssignmentType]

drop_nulls = not_implemented()
diff = not_implemented()
unique = not_implemented()
shift = not_implemented()
is_unique = not_implemented()
is_first_distinct = not_implemented()
is_last_distinct = not_implemented()
cum_sum = not_implemented()
cum_count = not_implemented()
cum_min = not_implemented()
cum_max = not_implemented()
cum_prod = not_implemented()
over = not_implemented()
rolling_sum = not_implemented()
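For readers unfamiliar with the pattern, not_implemented acts as a descriptor that raises as soon as the attribute is accessed on a backend that has not overridden it. A simplified sketch of how such a descriptor might work (the backend class name and error message are illustrative; the actual narwhals implementation differs):

```python
class not_implemented:
    """Descriptor that raises when an un-overridden attribute is accessed."""

    def __set_name__(self, owner, name):
        # Record where the descriptor was assigned, for the error message.
        self._name = f"{owner.__name__}.{name}"

    def __get__(self, instance, owner=None):
        raise NotImplementedError(f"{self._name} is not implemented for this backend")


class SomeLazyExpr:  # hypothetical backend class
    replace_strict: not_implemented = not_implemented()
```

Overriding the attribute in a subclass (or, as suggested above, simply not declaring it on backends that do support the method) replaces the descriptor, so the error only fires for backends that genuinely lack the feature.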

I hope that all makes sense, but let me know if you have any issues πŸ™‚
