Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix graph building to exclude input, output and initializer from value_info #1320

Merged
merged 2 commits into from
Apr 1, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 9 additions & 3 deletions onnxscript/function_libs/torch_lib/graph_building.py
Original file line number Diff line number Diff line change
Expand Up @@ -822,9 +822,15 @@
# nn.Modules exported by dynamo exporter have unique call sites, their function
# op_type name can serve to form the unique identifier for value info.
# Store inside top level GraphProto.
existing_value_info.update(self.generate_subgraphs_value_info_proto())
# Insert value info for nodes in top level graph.
existing_value_info.update(self.generate_maingraph_value_info_proto())
new_value_info = self.generate_maingraph_value_info_proto()
# Do not store input, output or initializer into value_info
for name in onnx_model.graph.input:
new_value_info.pop(name.name, None)
Fixed Show fixed Hide fixed
for name in onnx_model.graph.output:
new_value_info.pop(name.name, None)
Fixed Show fixed Hide fixed
BowenBao marked this conversation as resolved.
Show resolved Hide resolved
for name in self.initializers:
new_value_info.pop(name, None)
Fixed Show fixed Hide fixed
existing_value_info.update(new_value_info)
onnx_model.graph.value_info.extend(existing_value_info.values())

return onnx_model
Expand Down
30 changes: 28 additions & 2 deletions onnxscript/function_libs/torch_lib/graph_building_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
# mypy: disable-error-code="arg-type,type-arg,valid-type"
from __future__ import annotations

import os
import unittest

import torch
Expand Down Expand Up @@ -140,7 +139,6 @@


class TestModelSaving(unittest.TestCase):
@unittest.skipIf(os.getenv("CI") == "true", "CI is not ready to run dyanmo_export.")
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not sure about the context here.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Initially we disabled it to avoid flakiness due to the changing pytorch

def test_save_initializer_to_files_for_large_model(self):
class MLP(torch.nn.Module):
def __init__(self, input_size, hidden_size, output_size):
Expand All @@ -167,6 +165,34 @@
# Assert model is larger than 2GB (~=3GB)
self.assertGreater(model_proto.ByteSize(), 2**31)

def test_input_output_and_initializer_are_not_stored_in_value_info(self):
class MLP(torch.nn.Module):
def __init__(self, input_size, hidden_size, output_size):
super().__init__()
self.fc1 = torch.nn.Linear(input_size, hidden_size)
self.fc2 = torch.nn.Linear(hidden_size, output_size)
self.relu = torch.nn.ReLU()

def forward(self, x):
out = self.fc1(x)
out = self.relu(out)
out = self.fc2(out)
return out

batch_size, input_size, hidden_size, output_size = 1, 4, 5, 10
model = MLP(input_size, hidden_size, output_size)
x = torch.randn(batch_size, input_size)

model_proto = torch.onnx.dynamo_export(model, x).model_proto
v_names = set(v.name for v in model_proto.graph.value_info)
Fixed Show fixed Hide fixed
BowenBao marked this conversation as resolved.
Show resolved Hide resolved
print(v_names)
BowenBao marked this conversation as resolved.
Show resolved Hide resolved
for i in model_proto.graph.input:
self.assertNotIn(i.name, v_names)
for o in model_proto.graph.output:
self.assertNotIn(o.name, v_names)
for i in model_proto.graph.initializer:
self.assertNotIn(i.name, v_names)


if __name__ == "__main__":
unittest.main()
Loading