Skip to content

Commit

Permalink
fix missing Optional input in the signature
Browse files Browse the repository at this point in the history
  • Loading branch information
maziyarpanahi committed Sep 2, 2024
1 parent 66d94a4 commit 9285df8
Showing 1 changed file with 10 additions and 20 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -18,22 +18,10 @@ package com.johnsnowlabs.nlp.annotators.classifier.dl

import com.johnsnowlabs.ml.ai.AlbertClassification
import com.johnsnowlabs.ml.onnx.{OnnxWrapper, ReadOnnxModel, WriteOnnxModel}
import com.johnsnowlabs.ml.tensorflow.sentencepiece.{
ReadSentencePieceModel,
SentencePieceWrapper,
WriteSentencePieceModel
}
import com.johnsnowlabs.ml.tensorflow.{
ReadTensorflowModel,
TensorflowWrapper,
WriteTensorflowModel
}
import com.johnsnowlabs.ml.util.LoadExternalModel.{
loadSentencePieceAsset,
loadTextAsset,
modelSanityCheck,
notSupportedEngineError
}
import com.johnsnowlabs.ml.openvino.OpenvinoWrapper
import com.johnsnowlabs.ml.tensorflow.sentencepiece.{ReadSentencePieceModel, SentencePieceWrapper, WriteSentencePieceModel}
import com.johnsnowlabs.ml.tensorflow.{ReadTensorflowModel, TensorflowWrapper, WriteTensorflowModel}
import com.johnsnowlabs.ml.util.LoadExternalModel.{loadSentencePieceAsset, loadTextAsset, modelSanityCheck, notSupportedEngineError}
import com.johnsnowlabs.ml.util.{ONNX, TensorFlow}
import com.johnsnowlabs.nlp._
import com.johnsnowlabs.nlp.annotators.common.{SentenceSplit, TokenizedWithSentence}
Expand Down Expand Up @@ -170,13 +158,15 @@ class AlbertForZeroShotClassification(override val uid: String)
spark: SparkSession,
tensorflowWrapper: Option[TensorflowWrapper],
onnxWrapper: Option[OnnxWrapper],
openvinoWrapper: Option[OpenvinoWrapper],
spp: SentencePieceWrapper): AlbertForZeroShotClassification = {
if (_model.isEmpty) {
_model = Some(
spark.sparkContext.broadcast(
new AlbertClassification(
tensorflowWrapper,
onnxWrapper,
openvinoWrapper,
spp,
configProtoBytes = getConfigProtoBytes,
tags = $$(labels),
Expand Down Expand Up @@ -314,7 +304,7 @@ trait ReadAlbertForZeroShotDLModel
instance.getEngine match {
case TensorFlow.name =>
val tfWrapper = readTensorflowModel(path, spark, "_albert_classification_tf")
instance.setModelIfNotSet(spark, Some(tfWrapper), None, spp)
instance.setModelIfNotSet(spark, Some(tfWrapper), None, None, spp)
case ONNX.name =>
val onnxWrapper =
readOnnxModel(
Expand All @@ -324,7 +314,7 @@ trait ReadAlbertForZeroShotDLModel
zipped = true,
useBundle = false,
None)
instance.setModelIfNotSet(spark, None, Some(onnxWrapper), spp)
instance.setModelIfNotSet(spark, None, Some(onnxWrapper), None, spp)
case _ =>
throw new Exception(notSupportedEngineError)
}
Expand Down Expand Up @@ -381,11 +371,11 @@ trait ReadAlbertForZeroShotDLModel
*/
annotatorModel
.setSignatures(_signatures)
.setModelIfNotSet(spark, Some(wrapper), None, spModel)
.setModelIfNotSet(spark, Some(wrapper), None, None, spModel)
case ONNX.name =>
val onnxWrapper =
OnnxWrapper.read(spark, localModelPath, zipped = false, useBundle = true)
annotatorModel.setModelIfNotSet(spark, None, Some(onnxWrapper), spModel)
annotatorModel.setModelIfNotSet(spark, None, Some(onnxWrapper), None, spModel)
case _ =>
throw new Exception(notSupportedEngineError)
}
Expand Down

0 comments on commit 9285df8

Please sign in to comment.