diff --git a/ivy/functional/frontends/jax/lax/control_flow_operators.py b/ivy/functional/frontends/jax/lax/control_flow_operators.py index 3fb33fb590749..973f1b41d1518 100644 --- a/ivy/functional/frontends/jax/lax/control_flow_operators.py +++ b/ivy/functional/frontends/jax/lax/control_flow_operators.py @@ -58,3 +58,29 @@ def while_loop(cond_fun, body_fun, init_val): while cond_fun(val): val = body_fun(val) return val + + +@to_ivy_arrays_and_back +def scan(f, init, xs, length=None, reverse=False, unroll=1): + if not (callable(f)): + raise ivy.exceptions.IvyException( + "jax.lax.scan: Argument f should be callable." + ) + if xs is None and length is None: + raise ivy.exceptions.IvyException( + "jax.lax.scan: Either xs or length must be provided." + ) + + if length is not None and (not isinstance(length, int) or length < 0): + raise ivy.exceptions.IvyException( + "jax.lax.scan: length must be a non-negative integer." + ) + if xs is None: + xs = [None] * length + + carry = init + ys = [] + for x in xs: + carry, y = f(carry, x) + ys.append(y) + return carry, ivy.stack(ys) diff --git a/ivy_tests/test_ivy/test_frontends/test_jax/test_lax/test_control_flow_operators.py b/ivy_tests/test_ivy/test_frontends/test_jax/test_lax/test_control_flow_operators.py index d9cf964fb8957..722c5a283446b 100644 --- a/ivy_tests/test_ivy/test_frontends/test_jax/test_lax/test_control_flow_operators.py +++ b/ivy_tests/test_ivy/test_frontends/test_jax/test_lax/test_control_flow_operators.py @@ -213,3 +213,48 @@ def _test_body_fn(x): body_fun=_test_body_fn, init_val=x[0], ) + + +@handle_frontend_test( + fn_tree="jax.lax.scan", + dtype_and_x=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("numeric"), + min_value=-1000, + max_value=1000, + min_num_dims=1, + min_dim_size=1, + ), + length=st.integers(min_value=-10, max_value=10), + init=st.integers(min_value=-10, max_value=10), + test_with_out=st.just(False), +) +def test_jax_scan( + *, + dtype_and_x, + length, + init, + test_flags, + on_device, + fn_tree, + frontend, + backend_fw, +): + if length == 0 or length != len(dtype_and_x[1][0]): + return + + def _test_scan_fn(carry, x): + return carry + x, x * 2 + + input_dtype, x = dtype_and_x + helpers.test_frontend_function( + input_dtypes=input_dtype, + backend_to_test=backend_fw, + test_flags=test_flags, + frontend=frontend, + fn_tree=fn_tree, + on_device=on_device, + f=_test_scan_fn, + init=init, + xs=x[0], + length=length, + )