Skip to content

Commit

Permalink
- added name appending function to the Model exporters
Browse files Browse the repository at this point in the history
  • Loading branch information
amkrajewski committed Jan 30, 2024
1 parent e33b716 commit f503dc0
Showing 1 changed file with 28 additions and 14 deletions.
42 changes: 28 additions & 14 deletions pysipfenn/core/modelExporters.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,11 +124,13 @@ def toFP16All(self):
self.toFP16(model)
print('***** Done converting all models to FP16! *****')

def export(self, model: str):
def export(self, model: str, append: str = '') -> None:
"""Export a loaded model to ONNX format.
Args:
model: The name of the model to export (must be loaded in the Calculator).
append: A string to append to the exported model name after the model name, simplification marker, and
FP16 marker. Useful for adding a version number or other information to the exported model name.
Returns:
None
Expand All @@ -141,14 +143,18 @@ def export(self, model: str):
name += '_simplified'
if self.fp16Dict[model]:
name += '_fp16'
if append:
name += f'_{append}'
name += '.onnx'
onnx.save(loadedModel, name)
print(f'--> Exported as {name}', flush=True)

def exportAll(self):
"""Export all loaded models to ONNX format with the export function."""
def exportAll(self, append: str = '') -> None:
"""Export all loaded models to ONNX format with the export function. `append` can be passed to the export
function.
"""
for model in tqdm(self.calculator.loadedModels):
self.export(model)
self.export(model, append=append)
print('***** Done exporting all models! *****')


Expand All @@ -167,14 +173,16 @@ def __init__(self, calculator: Calculator):
assert len(self.calculator.loadedModels) > 0, 'No models loaded in calculator. Nothing to export.'
print(f'Initialized TorchExporter with models: {list(self.calculator.loadedModels.keys())}')

def export(self, model: str):
def export(self, model: str, append: str = '') -> None:
"""Export a loaded model to PyTorch PT format. Models are exported in eval mode (no dropout) and saved in the
current working directory.
Args:
model: The name of the model to export (must be loaded in the Calculator) and it must have a descriptor
(Ward2017 or KS2022) defined in the calculator.models dictionary created when the Calculator was
initialized.
append: A string to append to the exported model name after the model name. Useful for adding a version
number or other information to the exported model name.
Returns:
None
Expand All @@ -200,14 +208,16 @@ def export(self, model: str):

tracedModel = torch.jit.trace(loadedModel, inputs_tracer)

name = f"{model}.pt"
name = f"{model}{f'_{append}' if append else ''}.pt"
tracedModel.save(name)
print(f'--> Exported as {name}', flush=True)

def exportAll(self):
"""Export all loaded models to PyTorch PT format with the export function."""
def exportAll(self, append: str = '') -> None:
"""Exports all loaded models to PyTorch PT format with the export function. `append` can be passed to the export
function
"""
for model in tqdm(self.calculator.loadedModels):
self.export(model)
self.export(model, append=append)
print('***** Done exporting all models! *****')


Expand All @@ -227,7 +237,7 @@ def __init__(self, calculator: Calculator):
assert len(self.calculator.loadedModels)>0, 'No models loaded in calculator. Nothing to export.'
print(f'Initialized CoreMLExporter with models: {list(self.calculator.loadedModels.keys())}')

def export(self, model: str):
def export(self, model: str, append: str = '') -> None:
"""Export a loaded model to CoreML format. Models will be saved as {model}.mlpackage in the current working
directory. Models will be annotated with the feature vector name (Ward2017 or KS2022) and the output will be
named "property". The latter behavior will be adjusted in the future when model output name and unit will be
Expand All @@ -237,6 +247,8 @@ def export(self, model: str):
model: The name of the model to export (must be loaded in the Calculator) and it must have a descriptor
(Ward2017 or KS2022) defined in the calculator.models dictionary created when the Calculator was
initialized.
append: A string to append to the exported model name after the model name. Useful for adding a version
number or other information to the exported model name.
Returns:
None
Expand Down Expand Up @@ -270,12 +282,14 @@ def export(self, model: str):
inputs=inputs_converter,
outputs=[ct.TensorType(name='property')]
)
name = f"{model}.mlpackage"
name = f"{model}{f'_{append}' if append else ''}.mlpackage"
coreml_model.save(name)
print(f'--> Exported as {name}', flush=True)

def exportAll(self):
"""Export all loaded models to CoreML format with the export function."""
def exportAll(self, append: str = '') -> None:
"""Export all loaded models to CoreML format with the export function. `append` can be passed to the export
function.
"""
for model in tqdm(self.calculator.loadedModels):
self.export(model)
self.export(model, append=append)
print('***** Done exporting all models! *****')

0 comments on commit f503dc0

Please sign in to comment.