Raise on ambiguous inputs #599

Merged: 1 commit, Oct 16, 2024
33 changes: 25 additions & 8 deletions lib/axon/compiler.ex
@@ -486,15 +486,16 @@ defmodule Axon.Compiler do
           name: name_fn,
           opts: [shape: _input_shape, optional: optional?]
         },
-        _nodes,
+        nodes,
         {cache, op_counts, block_cache, model_state_meta},
         %{mode: mode, print_values: print_values}
       ) do
     name = name_fn.(:input, op_counts)
     op_counts = Map.update(op_counts, :input, 1, fn x -> x + 1 end)
+    all_inputs = get_all_inputs(nodes)

     predict_fun = fn _params, inputs, state, _cache, result_cache, _fn_stacktrace ->
-      value = get_input(inputs, name, optional?)
+      value = get_input(all_inputs, inputs, name, optional?)

       # TODO: Add this back in
       # validate_input_shape!(value, shape)
@@ -509,7 +510,7 @@
     end

     init_fun = fn template, _cache, result_cache, _fn_stacktrace, _keys ->
-      input = get_input(template, name, optional?)
+      input = get_input(all_inputs, template, name, optional?)
       {Nx.to_template(input), {%{}, result_cache}}
     end

@@ -889,16 +890,32 @@
     {id, model_funs, cache, op_counts, block_cache, model_state_meta}
   end

-  defp get_input(inputs, name, optional?) do
+  defp get_all_inputs(nodes) do
+    nodes
+    |> Enum.filter(fn {_, %{op: op}} -> op == :input end)
+    |> Enum.map(fn {_, %{name: name_fn}} ->
+      # inputs require a name, so we can just ignore op counts
+      name_fn.(:input, %{})
+    end)
+    |> Enum.uniq()
+  end
+
+  defp get_input(all_input_names, inputs, name, optional?) do
     res =
-      case inputs do
-        %Nx.Tensor{} = inputs ->
+      case {all_input_names, inputs} do
+        {[^name], %Nx.Tensor{} = inputs} ->
           inputs

-        %{} = inputs ->
+        {_, %Nx.Tensor{}} ->
+          raise ArgumentError,
+                "ambiguous input given to the model," <>
+                  " expected inputs with names #{inspect(all_input_names)}" <>
+                  " but received a single tensor as input"
+
+        {_, %{} = inputs} ->
           inputs[name]

-        inputs when is_tuple(inputs) ->
+        {[^name], inputs} when is_tuple(inputs) ->
           inputs

         _ ->
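For context, a minimal usage sketch of the behavior the new get_input/4 clause enforces (not part of the diff; the input names "x" and "y" mirror the test added below):

# Two-input model, as in the new compiler test.
x = Axon.input("x")
y = Axon.input("y")
model = Axon.add(x, y)

{_init_fn, predict_fn} = Axon.build(model)

# A bare tensor is now rejected, since it is unclear whether it should bind
# to "x" or to "y"; the call below raises:
#   predict_fn.(Axon.ModelState.empty(), Nx.tensor([1]))
#   ** (ArgumentError) ambiguous input given to the model, expected inputs
#      with names ["x", "y"] but received a single tensor as input

# Naming each input removes the ambiguity:
predict_fn.(Axon.ModelState.empty(), %{"x" => Nx.tensor([1]), "y" => Nx.tensor([2])})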
14 changes: 14 additions & 0 deletions test/axon/compiler_test.exs
@@ -128,6 +128,20 @@
       assert message =~ "exception found when compiling layer Axon.Layers.add/2 named add_0"
       assert message =~ "cannot broadcast tensor of dimensions {1, 32} to {1, 64}"
     end
+
+    test "raises if inputs are ambiguous" do
+      x = Axon.input("x")
+      y = Axon.input("y")
+      model = Axon.add(x, y)
+
+      {_, predict_fn} = Axon.build(model)
+
+      exception = assert_raise ArgumentError, fn ->
+        predict_fn.(ModelState.empty(), Nx.tensor([1]))
+      end
+
+      assert Exception.message(exception) =~ "ambiguous"
+    end
   end

   describe "optional" do
@@ -246,7 +260,7 @@
       assert_equal(predict_fn.(ModelState.empty(), {}), Nx.tensor(3.0))
     end

     test "computes forward pass with output policy" do
[CI annotation: check failure on line 263 in test/axon/compiler_test.exs, GitHub Actions / main (25.3.2.6, 1.14.5, USE_EXLA=true), test "constant computes forward pass with output policy" (CompilerTest)]
       model = Axon.constant(Nx.tensor(1.0))
       policy = AMP.create_policy(output: {:bf, 16})
       mp_model = AMP.apply_policy(model, policy)
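A complementary observation (assumed behavior, not a test in this PR): a model with exactly one declared input still accepts a bare tensor, because the new clause only raises when several input names are registered. A hedged sketch:

# Single-input model: a bare tensor is unambiguous and is still accepted.
model = Axon.input("x") |> Axon.relu()
{_init_fn, predict_fn} = Axon.build(model)

predict_fn.(Axon.ModelState.empty(), Nx.tensor([1.0, -1.0]))

# The named form is equivalent and always unambiguous:
predict_fn.(Axon.ModelState.empty(), %{"x" => Nx.tensor([1.0, -1.0])})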
6 changes: 5 additions & 1 deletion test/axon/loop_test.exs
@@ -132,7 +132,11 @@ defmodule Axon.LoopTest do
         Loop.trainer(model, [mean_squared_error: 0.5, mean_absolute_error: 0.5], :adam)

       assert %{model_state: %{}} =
-               pstate = init_fn.({Nx.tensor([[2]]), Nx.tensor([[2]])}, Axon.ModelState.empty())
+               pstate =
+                 init_fn.(
+                   {%{"input_0" => Nx.tensor([[2]]), "input_1" => Nx.tensor([[2]])}, Nx.tensor(0)},
+                   Axon.ModelState.empty()
+                 )

       state = %State{step_state: pstate}
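For reference, a hedged sketch of the calling convention the updated loop test reflects: once a model declares more than one input, every batch fed to Axon.Loop must key its features by input name. The batch contents, loss, optimizer, and iteration count below are illustrative only:

model = Axon.add(Axon.input("input_0"), Axon.input("input_1"))
loop = Axon.Loop.trainer(model, :mean_squared_error, :adam)

# Each batch is {features, targets}; features are a map keyed by input name.
data =
  Stream.repeatedly(fn ->
    {%{"input_0" => Nx.tensor([[2.0]]), "input_1" => Nx.tensor([[2.0]])}, Nx.tensor([[4.0]])}
  end)

Axon.Loop.run(loop, data, %{}, iterations: 5)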