Skip to content

Commit

Permalink
Fix: call apply/3 as intended (#598)
Browse files Browse the repository at this point in the history
* Fix: call apply/3 as intended

* Add tests for Axon.Quantizaiton.weight_only_quantized_dense
  • Loading branch information
preciz authored Oct 13, 2024
1 parent 4cc474b commit ce2e247
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 1 deletion.
2 changes: 1 addition & 1 deletion lib/axon/quantization.ex
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,7 @@ defmodule Axon.Quantization do
fun =
case opts[:kernel_initializer] do
init when is_atom(init) ->
apply(Axon.Initializers, [])
apply(Axon.Initializers, init, [])

fun when is_function(fun) ->
fun
Expand Down
14 changes: 14 additions & 0 deletions test/axon/quantization_test.exs
Original file line number Diff line number Diff line change
Expand Up @@ -42,4 +42,18 @@ defmodule Axon.QuantizationTest do
assert_equal(predict_fn.(quantized_model_state, inp), real_fn.(quantized_model_state, inp))
end
end

describe "weight_only_quantized_dense" do
test "inits and executes properly" do
model =
Axon.input("input")
|> Axon.Quantization.weight_only_quantized_dense(10)

assert {init_fn, _} = Axon.build(model)
assert %ModelState{} = model_state = init_fn.(Nx.template({1, 1}, :f32), ModelState.empty())

assert {_, predict_fn} = Axon.build(model)
assert predict_fn.(model_state, Nx.broadcast(1.0, {1, 1}))
end
end
end

0 comments on commit ce2e247

Please sign in to comment.