Skip to content

Commit

Permalink
Merge branch 'apple:main' into add_istft
Browse files Browse the repository at this point in the history
  • Loading branch information
alealv authored Jan 4, 2024
2 parents 44a98ab + ba399c8 commit f08cf10
Show file tree
Hide file tree
Showing 45 changed files with 1,659 additions and 1,123 deletions.
11 changes: 9 additions & 2 deletions coremltools/converters/mil/frontend/torch/converter.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@

from .._utils import get_output_names
from .internal_graph import InternalTorchIRGraph, InternalTorchIRNode
from .ops import convert_nodes
from .ops import TorchFrontend, convert_nodes
from .quantization_ops import _dequantized_weight
from .torch_op_registry import _TORCH_OPS_REGISTRY
from .torchir_passes import (
Expand Down Expand Up @@ -194,8 +194,13 @@ class TranscriptionContext:
context when stepping out.
"""

def __init__(self, name: Optional[str] = None) -> None:
def __init__(
self,
name: Optional[str] = None,
frontend: TorchFrontend = TorchFrontend.TORCHSCRIPT,
) -> None:
self.name = name if name else ""
self.frontend = frontend
self._current_graph = [{}]
self._torch_graph = None
self._quant_context = QuantizationContext(self)
Expand Down Expand Up @@ -346,6 +351,7 @@ def __init__(
self._prog = Program()

if isinstance(loaded_model, torch.jit.ScriptModule):
self.context.frontend = TorchFrontend.TORCHSCRIPT
self.graph, self.params_dict, self.buffer_dict = InternalTorchIRGraph.from_torchscript(
torchscript=loaded_model, input_values=self.inputs, cut_at_symbols=cut_at_symbols
)
Expand All @@ -363,6 +369,7 @@ def __init__(
p(self.graph)

elif _HAS_TORCH_EXPORT_API and isinstance(loaded_model, ExportedProgram):
self.context.frontend = TorchFrontend.EDGEIR
self.graph = InternalTorchIRGraph.from_edgeir(edgeir=loaded_model)
self.params_dict, self.buffer_dict = None, None
else:
Expand Down
Loading

0 comments on commit f08cf10

Please sign in to comment.