diff --git a/ivy/functional/frontends/paddle/tensor/manipulation.py b/ivy/functional/frontends/paddle/tensor/manipulation.py index ae2259b53edac..f4f355275dce6 100644 --- a/ivy/functional/frontends/paddle/tensor/manipulation.py +++ b/ivy/functional/frontends/paddle/tensor/manipulation.py @@ -77,6 +77,15 @@ def gather(params, indices, axis=-1, batch_dims=0, name=None): return ivy.gather(params, indices, axis=axis, batch_dims=batch_dims) +@with_supported_dtypes( + {"2.5.1 and below": ("int32", "int64", "float32", "float64")}, + "paddle", +) +@to_ivy_arrays_and_back +def repeat_interleave(x, repeats, axis=None, name=None): + return ivy.repeat(x, repeats, axis=axis) + + @to_ivy_arrays_and_back def reshape(x, shape): return ivy.reshape(x, shape) diff --git a/ivy_tests/test_ivy/test_frontends/test_paddle/test_tensor/test_manipulation.py b/ivy_tests/test_ivy/test_frontends/test_paddle/test_tensor/test_manipulation.py index be36b2a116f1c..95faa889aef59 100644 --- a/ivy_tests/test_ivy/test_frontends/test_paddle/test_tensor/test_manipulation.py +++ b/ivy_tests/test_ivy/test_frontends/test_paddle/test_tensor/test_manipulation.py @@ -8,6 +8,9 @@ from ivy_tests.test_ivy.test_functional.test_experimental.test_core.test_manipulation import ( # noqa _get_dtype_values_k_axes_for_rot90, ) +from ivy_tests.test_ivy.test_frontends.test_torch.test_miscellaneous_ops import ( + _get_repeat_interleaves_args, +) # --- Helpers --- # @@ -445,6 +448,40 @@ def test_paddle_gather( ) +# repeat_interleave +@handle_frontend_test( + fn_tree="paddle.repeat_interleave", + dtype_values_repeats_axis_output_size=_get_repeat_interleaves_args( + available_dtypes=helpers.get_dtypes("numeric"), + valid_axis=True, + max_num_dims=4, + max_dim_size=4, + ), +) +def test_paddle_repeat_interleave( + *, + dtype_values_repeats_axis_output_size, + on_device, + fn_tree, + frontend, + test_flags, + backend_fw, +): + dtype, values, repeats, axis, _ = dtype_values_repeats_axis_output_size + + helpers.test_frontend_function( + input_dtypes=[dtype[0][0], dtype[1][0]], + backend_to_test=backend_fw, + frontend=frontend, + test_flags=test_flags, + fn_tree=fn_tree, + on_device=on_device, + x=values[0], + repeats=repeats[0], + axis=axis, + ) + + # Tests # # ----- #