diff --git a/python/sparknlp/annotator/seq2seq/__init__.py b/python/sparknlp/annotator/seq2seq/__init__.py index 9f97473853a9ed..f7591969057e18 100644 --- a/python/sparknlp/annotator/seq2seq/__init__.py +++ b/python/sparknlp/annotator/seq2seq/__init__.py @@ -25,3 +25,4 @@ from sparknlp.annotator.seq2seq.nllb_transformer import * from sparknlp.annotator.seq2seq.cpm_transformer import * from sparknlp.annotator.seq2seq.qwen_transformer import * +from sparknlp.annotator.seq2seq.starcoder_transformer import * \ No newline at end of file diff --git a/python/sparknlp/annotator/seq2seq/starcoder_transformer.py b/python/sparknlp/annotator/seq2seq/starcoder_transformer.py new file mode 100644 index 00000000000000..3c87cb653682cc --- /dev/null +++ b/python/sparknlp/annotator/seq2seq/starcoder_transformer.py @@ -0,0 +1,335 @@ +# Copyright 2017-2022 John Snow Labs +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Contains classes for the StarCoderTransformer.""" + +from sparknlp.common import * + + +class StarCoderTransformer(AnnotatorModel, HasBatchedAnnotate, HasEngine): + """StarCoder2: The Versatile Code Companion. + + StarCoder2 is a Transformer model designed specifically for code generation and understanding. + With 13 billion parameters, it builds upon the advancements of its predecessors and is trained + on a diverse dataset that includes multiple programming languages. This extensive training + allows StarCoder2 to support a wide array of coding tasks, from code completion to generation. + + StarCoder2 was developed to assist developers in writing and understanding code more efficiently, + making it a valuable tool for various software development and data science tasks. + + Pretrained models can be loaded with :meth:`.pretrained` of the companion + object: + + >>> starcoder2 = StarCoder2Transformer.pretrained() \\ + ... .setInputCols(["document"]) \\ + ... .setOutputCol("generation") + + The default model is ``"starcoder2-13b"``, if no name is provided. For available + pretrained models please see the `Models Hub + `__. + + ====================== ====================== + Input Annotation types Output Annotation type + ====================== ====================== + ``DOCUMENT`` ``DOCUMENT`` + ====================== ====================== + + Parameters + ---------- + configProtoBytes + ConfigProto from tensorflow, serialized into byte array. + minOutputLength + Minimum length of the sequence to be generated, by default 0 + maxOutputLength + Maximum length of output text, by default 20 + doSample + Whether or not to use sampling; use greedy decoding otherwise, by default False + temperature + The value used to modulate the next token probabilities, by default 1.0 + topK + The number of highest probability vocabulary tokens to keep for + top-k-filtering, by default 50 + topP + Top cumulative probability for vocabulary tokens, by default 1.0 + + If set to float < 1, only the most probable tokens with probabilities + that add up to ``topP`` or higher are kept for generation. + repetitionPenalty + The parameter for repetition penalty, 1.0 means no penalty. , by default + 1.0 + noRepeatNgramSize + If set to int > 0, all ngrams of that size can only occur once, by + default 0 + ignoreTokenIds + A list of token ids which are ignored in the decoder's output, by + default [] + + Notes + ----- + This is a very computationally expensive module especially on larger + sequence. The use of an accelerator such as GPU is recommended. + + References + ---------- + - `StarCoder2: The Versatile Code Companion. + `__ + - https://github.com/bigcode-project/starcoder + + **Paper Abstract:** + + *The BigCode project, an open-scientific collaboration focused on the responsible + development of Large Language Models for Code (Code LLMs), introduces StarCoder2. In + partnership with Software Heritage (SWH), we build The Stack v2 on top of the digital commons + of their source code archive. Alongside the SWH repositories spanning 619 programming + languages, we carefully select other high-quality data sources, such as GitHub pull requests, + Kaggle notebooks, and code documentation. This results in a training set that is 4× larger + than the first StarCoder dataset. We train StarCoder2 models with 3B, 7B, and 15B parameters + on 3.3 to 4.3 trillion tokens and thoroughly evaluate them on a comprehensive set of Code LLM + benchmarks.* + + *We find that our small model, StarCoder2-3B, outperforms other Code LLMs of similar size on + most benchmarks, and also outperforms StarCoderBase-15B. Our large model, StarCoder2-15B, + significantly outperforms other models of comparable size. In addition, it matches or + outperforms CodeLlama-34B, a model more than twice its size. Although DeepSeekCoder-33B is + the best-performing model at code completion for high-resource languages, we find that + StarCoder2-15B outperforms it on math and code reasoning benchmarks, as well as several + low-resource languages. We make the model weights available under an OpenRAIL license and + ensure full transparency regarding the training data by releasing the Software Heritage + persistent Identifiers (SWHIDs) of the source code data.* + + Examples + -------- + >>> import sparknlp + >>> from sparknlp.base import * + >>> from sparknlp.annotator import * + >>> from pyspark.ml import Pipeline + >>> documentAssembler = DocumentAssembler() \\ + ... .setInputCol("text") \\ + ... .setOutputCol("documents") + >>> starcoder2 = StarCoder2Transformer.pretrained("starcoder2") \\ + ... .setInputCols(["documents"]) \\ + ... .setMaxOutputLength(50) \\ + ... .setOutputCol("generation") + >>> pipeline = Pipeline().setStages([documentAssembler, starcoder2]) + >>> data = spark.createDataFrame([["def add(a, b):"]]).toDF("text") + >>> result = pipeline.fit(data).transform(data) + >>> result.select("generation.result").show(truncate=False) + +----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+ + |result | + +----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+ + |[def add(a, b): return a + b] | + +----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+ + """ + + + + name = "StarCoderTransformer" + + inputAnnotatorTypes = [AnnotatorType.DOCUMENT] + + outputAnnotatorType = AnnotatorType.DOCUMENT + + configProtoBytes = Param(Params._dummy(), "configProtoBytes", + "ConfigProto from tensorflow, serialized into byte array. Get with config_proto.SerializeToString()", + TypeConverters.toListInt) + + minOutputLength = Param(Params._dummy(), "minOutputLength", "Minimum length of the sequence to be generated", + typeConverter=TypeConverters.toInt) + + maxOutputLength = Param(Params._dummy(), "maxOutputLength", "Maximum length of output text", + typeConverter=TypeConverters.toInt) + + doSample = Param(Params._dummy(), "doSample", "Whether or not to use sampling; use greedy decoding otherwise", + typeConverter=TypeConverters.toBoolean) + + temperature = Param(Params._dummy(), "temperature", "The value used to module the next token probabilities", + typeConverter=TypeConverters.toFloat) + + topK = Param(Params._dummy(), "topK", + "The number of highest probability vocabulary tokens to keep for top-k-filtering", + typeConverter=TypeConverters.toInt) + + topP = Param(Params._dummy(), "topP", + "If set to float < 1, only the most probable tokens with probabilities that add up to ``top_p`` or higher are kept for generation", + typeConverter=TypeConverters.toFloat) + + repetitionPenalty = Param(Params._dummy(), "repetitionPenalty", + "The parameter for repetition penalty. 1.0 means no penalty. See `this paper `__ for more details", + typeConverter=TypeConverters.toFloat) + + noRepeatNgramSize = Param(Params._dummy(), "noRepeatNgramSize", + "If set to int > 0, all ngrams of that size can only occur once", + typeConverter=TypeConverters.toInt) + + ignoreTokenIds = Param(Params._dummy(), "ignoreTokenIds", + "A list of token ids which are ignored in the decoder's output", + typeConverter=TypeConverters.toListInt) + + def setIgnoreTokenIds(self, value): + """A list of token ids which are ignored in the decoder's output. + + Parameters + ---------- + value : List[int] + The words to be filtered out + """ + return self._set(ignoreTokenIds=value) + + def setConfigProtoBytes(self, b): + """Sets configProto from tensorflow, serialized into byte array. + + Parameters + ---------- + b : List[int] + ConfigProto from tensorflow, serialized into byte array + """ + return self._set(configProtoBytes=b) + + def setMinOutputLength(self, value): + """Sets minimum length of the sequence to be generated. + + Parameters + ---------- + value : int + Minimum length of the sequence to be generated + """ + return self._set(minOutputLength=value) + + def setMaxOutputLength(self, value): + """Sets maximum length of output text. + + Parameters + ---------- + value : int + Maximum length of output text + """ + return self._set(maxOutputLength=value) + + def setDoSample(self, value): + """Sets whether or not to use sampling, use greedy decoding otherwise. + + Parameters + ---------- + value : bool + Whether or not to use sampling; use greedy decoding otherwise + """ + return self._set(doSample=value) + + def setTemperature(self, value): + """Sets the value used to module the next token probabilities. + + Parameters + ---------- + value : float + The value used to module the next token probabilities + """ + return self._set(temperature=value) + + def setTopK(self, value): + """Sets the number of highest probability vocabulary tokens to keep for + top-k-filtering. + + Parameters + ---------- + value : int + Number of highest probability vocabulary tokens to keep + """ + return self._set(topK=value) + + def setTopP(self, value): + """Sets the top cumulative probability for vocabulary tokens. + + If set to float < 1, only the most probable tokens with probabilities + that add up to ``topP`` or higher are kept for generation. + + Parameters + ---------- + value : float + Cumulative probability for vocabulary tokens + """ + return self._set(topP=value) + + def setRepetitionPenalty(self, value): + """Sets the parameter for repetition penalty. 1.0 means no penalty. + + Parameters + ---------- + value : float + The repetition penalty + + References + ---------- + See `Ctrl: A Conditional Transformer Language Model For Controllable + Generation `__ for more details. + """ + return self._set(repetitionPenalty=value) + + def setNoRepeatNgramSize(self, value): + """Sets size of n-grams that can only occur once. + + If set to int > 0, all ngrams of that size can only occur once. + + Parameters + ---------- + value : int + N-gram size can only occur once + """ + return self._set(noRepeatNgramSize=value) + + @keyword_only + def __init__(self, classname="com.johnsnowlabs.nlp.annotators.seq2seq.StarCoderTransformer", java_model=None): + super(StarCoderTransformer, self).__init__(classname=classname, java_model=java_model) + self._setDefault(minOutputLength=0, maxOutputLength=20, doSample=False, temperature=0.6, topK=50, topP=0.9, + repetitionPenalty=1.0, noRepeatNgramSize=0, ignoreTokenIds=[], batchSize=1) + + @staticmethod + def loadSavedModel(folder, spark_session, use_openvino=False): + """Loads a locally saved model. + + Parameters + ---------- + folder : str + Folder of the saved model + spark_session : pyspark.sql.SparkSession + The current SparkSession + + Returns + ------- + StarCoderTransformer + The restored model + """ + from sparknlp.internal import _StarCoderLoader + jModel = _StarCoderLoader(folder, spark_session._jsparkSession, use_openvino)._java_obj + return StarCoderTransformer(java_model=jModel) + + @staticmethod + def pretrained(name="starcoder", lang="en", remote_loc=None): + """Downloads and loads a pretrained model. + + Parameters + ---------- + name : str, optional + Name of the pretrained model, by default "starcoder" + lang : str, optional + Language of the pretrained model, by default "en" + remote_loc : str, optional + Optional remote address of the resource, by default None. Will use + Spark NLPs repositories otherwise. + + Returns + ------- + StarCoderTransformer + The restored model + """ + from sparknlp.pretrained import ResourceDownloader + return ResourceDownloader.downloadModel(StarCoderTransformer, name, lang, remote_loc) diff --git a/python/sparknlp/internal/__init__.py b/python/sparknlp/internal/__init__.py index c62fde0ed965ac..0d762eabef16fc 100644 --- a/python/sparknlp/internal/__init__.py +++ b/python/sparknlp/internal/__init__.py @@ -394,6 +394,15 @@ def __init__(self, path, jspark): ) +class _StarCoderLoader(ExtendedJavaWrapper): + def __init__(self, path, jspark, use_openvino=False): + super(_StarCoderLoader, self).__init__( + "com.johnsnowlabs.nlp.annotators.seq2seq.StarCoderTransformer.loadSavedModel", + path, + jspark, + use_openvino, + ) + class _T5Loader(ExtendedJavaWrapper): def __init__(self, path, jspark): super(_T5Loader, self).__init__( diff --git a/python/test/annotator/seq2seq/starcoder_transformer_test.py b/python/test/annotator/seq2seq/starcoder_transformer_test.py new file mode 100644 index 00000000000000..ac2d44c879357f --- /dev/null +++ b/python/test/annotator/seq2seq/starcoder_transformer_test.py @@ -0,0 +1,47 @@ +# Copyright 2017-2022 John Snow Labs +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import unittest + +import pytest + +from sparknlp.annotator import * +from sparknlp.base import * +from test.util import SparkContextForTest + + +@pytest.mark.slow +class StarCoderTransformerTextGenerationTestSpec(unittest.TestCase): + def setUp(self): + self.spark = SparkContextForTest.spark + + def runTest(self): + data = self.spark.createDataFrame([ + [1, """def add(a, b):""".strip().replace("\n", " ")]]).toDF("id", "text") + + document_assembler = DocumentAssembler() \ + .setInputCol("text") \ + .setOutputCol("documents") + + starcoder = StarCoderTransformer \ + .pretrained() \ + .setMaxOutputLength(50) \ + .setDoSample(False) \ + .setInputCols(["documents"]) \ + .setOutputCol("generation") + + pipeline = Pipeline().setStages([document_assembler, starcoder]) + results = pipeline.fit(data).transform(data) + + results.select("generation.result").show(truncate=False) + diff --git a/src/main/scala/com/johnsnowlabs/ml/ai/StarCoder.scala b/src/main/scala/com/johnsnowlabs/ml/ai/StarCoder.scala new file mode 100644 index 00000000000000..4a31a2adc0262e --- /dev/null +++ b/src/main/scala/com/johnsnowlabs/ml/ai/StarCoder.scala @@ -0,0 +1,474 @@ +/* + * Copyright 2017 - 2023 John Snow Labs + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.johnsnowlabs.ml.ai + +import ai.onnxruntime.{OnnxTensor, OrtEnvironment, OrtSession} +import com.johnsnowlabs.ml.ai.util.Generation.{Generate, GenerationConfig} +import com.johnsnowlabs.ml.onnx.OnnxSession +import com.johnsnowlabs.ml.onnx.OnnxWrapper.DecoderWrappers +import com.johnsnowlabs.ml.onnx.TensorResources.implicits._ +import com.johnsnowlabs.ml.openvino.OpenvinoWrapper +import com.johnsnowlabs.ml.util.{ONNX, Openvino, TensorFlow} +import com.johnsnowlabs.nlp.Annotation +import com.johnsnowlabs.nlp.AnnotatorType.DOCUMENT +import com.johnsnowlabs.nlp.annotators.common.SentenceSplit +import com.johnsnowlabs.nlp.annotators.tokenizer.bpe.{BpeTokenizer, StarCoderTokenizer} +import org.intel.openvino.InferRequest +import org.tensorflow.{Session, Tensor} + +import scala.collection.JavaConverters._ + +private[johnsnowlabs] class StarCoder( + val onnxWrappers: Option[DecoderWrappers], + val openvinoWrapper: Option[OpenvinoWrapper], + merges: Map[(String, String), Int], + vocabulary: Map[String, Int], + generationConfig: GenerationConfig) + extends Serializable + with Generate { + + private val onnxSessionOptions: Map[String, String] = new OnnxSession().getSessionOptions + val detectedEngine: String = + if (onnxWrappers.isDefined) ONNX.name + else if (openvinoWrapper.isDefined) Openvino.name + else ONNX.name + private var nextPositionId: Option[Array[Long]] = None + val bpeTokenizer: StarCoderTokenizer = BpeTokenizer + .forModel( + "starcoder", + merges = merges, + vocab = vocabulary, + padWithSequenceTokens = false, + addPrefixSpaceToSentence = true) + .asInstanceOf[StarCoderTokenizer] + + private val GenerationConfig( + bosTokenId: Int, + paddingTokenId: Int, + eosTokenId: Int, + vocabSize: Int, + beginSuppressTokens, + suppressTokenIds, + forcedDecoderIds) = + generationConfig + + /** Decode a sequence of sentences + * @param sentences + * Sequence of sentences + * @return + * Sequence of decoded sentences + */ + def decode(sentences: Array[Array[Int]]): Seq[String] = { + sentences.map(s => bpeTokenizer.decodeTokens(s.map(_.toInt))) + } + + /** Encode a sequence of sentences + * @param sentences + * Sequence of sentences + * @return + * Sequence of encoded sentences + */ + def encode(sentences: Seq[Annotation]): Seq[Array[Int]] = { + SentenceSplit + .unpack(sentences) + .map(s => { + val sentWithTask = s + bpeTokenizer + .tokenize(sentWithTask) + .map(bpeTokenizer.encode) + .flatMap(_.map(_.pieceId)) + }) + } + + def tag( + batch: Seq[Array[Int]], + minOutputLength: Int, + maxOutputLength: Int, + doSample: Boolean, + temperature: Double, + topK: Int, + topP: Double, + repetitionPenalty: Double, + noRepeatNgramSize: Int, + randomSeed: Option[Long], + ignoreTokenIds: Array[Int] = Array(), + beamSize: Int, + maxInputLength: Int, + stopTokenIds: Array[Int]): Array[Array[Int]] = { + val ignoreTokenIdsInt = ignoreTokenIds + val expandedDecoderInputsVals = batch + val sequencesLength = expandedDecoderInputsVals.map(x => x.length).toArray + val maxSentenceLength = sequencesLength.max // - curLen + + val numReturn_sequences = 1 + // from config + + var effectiveBatch_size = 1 + var effectiveBatch_mult = 1 + + if (doSample) { + effectiveBatch_size = expandedDecoderInputsVals.length * numReturn_sequences + effectiveBatch_mult = numReturn_sequences + } else { + effectiveBatch_size = expandedDecoderInputsVals.length + effectiveBatch_mult = 1 + } + + // Run the prompt through the decoder and get the past +// val decoderOutputs = +// generateGreedyOnnx( +// expandedDecoderInputsVals.toArray, +// (encoderSession, env), +// maxOutputLength) + val (decoderEncoderStateTensors, encoderAttentionMaskTensors, session) = + detectedEngine match { + case ONNX.name => + // dummy tensors for decoder encode state and attention mask + val (encoderSession, env) = onnxWrappers.get.decoder.getSession(onnxSessionOptions) + ( + Right(OnnxTensor.createTensor(env, Array(0))), + Right(OnnxTensor.createTensor(env, Array(1))), + Right((env, encoderSession))) + case Openvino.name => + // not needed + (null, null, null) + } + val ovInferRequest: Option[InferRequest] = detectedEngine match { + case ONNX.name => None + case Openvino.name => Some(openvinoWrapper.get.getCompiledModel().create_infer_request()) + } + // output with beam search + val modelOutputs = generate( + batch, + decoderEncoderStateTensors, + encoderAttentionMaskTensors, + expandedDecoderInputsVals.toArray, + maxOutputLength + maxSentenceLength, + minOutputLength, + doSample, + beamSize, + 1, + temperature, + topK, + topP, + repetitionPenalty, + noRepeatNgramSize, + this.vocabSize, + this.eosTokenId, + this.paddingTokenId, + randomSeed, + ignoreTokenIdsInt, + session, + applySoftmax = false, + ovInferRequest = ovInferRequest, + stopTokenIds = stopTokenIds) + +// decoderOutputs + modelOutputs + } + + def predict( + sentences: Seq[Annotation], + batchSize: Int, + minOutputLength: Int, + maxOutputLength: Int, + doSample: Boolean, + temperature: Double, + topK: Int, + topP: Double, + repetitionPenalty: Double, + noRepeatNgramSize: Int, + randomSeed: Option[Long] = None, + ignoreTokenIds: Array[Int] = Array(), + beamSize: Int, + maxInputLength: Int, + stopTokenIds: Array[Int]): Seq[Annotation] = { + + val batchDecoder = sentences.grouped(batchSize).toArray.flatMap { batch => + val batchSP = encode(batch) + val spIds = tag( + batchSP, + minOutputLength, + maxOutputLength, + doSample, + temperature, + topK, + topP, + repetitionPenalty, + noRepeatNgramSize, + randomSeed, + ignoreTokenIds, + beamSize, + maxInputLength, + stopTokenIds) + + decode(spIds) + + } + + var sentBegin, nextSentEnd = 0 + val annotations = batchDecoder.zip(sentences).map { case (content, sent) => + nextSentEnd += content.length - 1 + val annots = new Annotation( + annotatorType = DOCUMENT, + begin = sentBegin, + end = nextSentEnd, + result = content, + metadata = sent.metadata) + sentBegin += nextSentEnd + 1 + annots + } + annotations + } + + private def getDecoderOutputsWithPast( + inputIds: Array[Array[Int]], + decoderPast: Map[String, OnnxTensor], + onnxSession: (OrtSession, OrtEnvironment)) + : (Array[Array[Float]], Map[String, OnnxTensor]) = { + val (session, env) = onnxSession + + val lastTokens: Array[Array[Long]] = + inputIds.map { tokenIds => + Array(tokenIds.last.toLong) + } + + val lastTokensTensor: OnnxTensor = + OnnxTensor.createTensor(env, lastTokens) + val decoderAttentionMask: OnnxTensor = + OnnxTensor.createTensor(env, lastTokens.map(_.map(_ => 1L))) + val decoderWithPastInputs: java.util.Map[String, OnnxTensor] = (Map( + OnnxSignatures.decoderInputIDs -> lastTokensTensor, + OnnxSignatures.decoderAttentionMask -> decoderAttentionMask) ++ decoderPast).asJava + val sessionOutput = session.run(decoderWithPastInputs) + val logits = sessionOutput.getFloatArray(OnnxSignatures.decoderOutput) + val decoderPresent = sessionOutput.getOnnxTensors(OnnxSignatures.decoderPresent) + lastTokensTensor.close() + val batchLogits = logits.grouped(vocabSize).toArray + (batchLogits, decoderPresent) + + } + + override def getModelOutput( + encoderInputIds: Seq[Array[Int]], + decoderInputIds: Seq[Array[Int]], + decoderEncoderStateTensors: Either[Tensor, OnnxTensor], + encoderAttentionMaskTensors: Either[Tensor, OnnxTensor], + maxLength: Int, + session: Either[Session, (OrtEnvironment, OrtSession)], + ovInferRequest: Option[InferRequest]): Array[Array[Float]] = { + + detectedEngine match { + case TensorFlow.name => + // not implemented yet + Array() + case ONNX.name => + val (env, decoderSession) = session.right.get + val decoderOutputs = + getDecoderOutputs(decoderInputIds.toArray, onnxSession = (decoderSession, env)) + decoderOutputs + case Openvino.name => + val decoderOutputs = + getDecoderOutputsOv( + encoderInputIds.toArray, + decoderInputIds.toArray, + ovInferRequest.get) + decoderOutputs + } + } + + private def getDecoderOutputsOv( + encoderInputIds: Array[Array[Int]], + decoderInputIds: Array[Array[Int]], + inferRequest: InferRequest): (Array[Array[Float]]) = { + + val (inputIdsLong, inputPositionIDsLong): (Array[Long], Array[Long]) = + if (encoderInputIds.head.length == decoderInputIds.head.length) { + // First pass + val inpIdsLong = decoderInputIds.flatMap { tokenIds => tokenIds.map(_.toLong) } + val posIdsLong = decoderInputIds.flatMap { tokenIds => + tokenIds.zipWithIndex.map { case (_, i) => + i.toLong + } + } + (inpIdsLong, posIdsLong) + } else { + // Subsequent passes + val inpIdsLong = decoderInputIds.map { tokenIds => tokenIds.last.toLong } + val posIdsLong = decoderInputIds.map { tokenIds => + tokenIds.zipWithIndex.map { case (_, i) => + i.toLong + }.last + } + (inpIdsLong, posIdsLong) + } + val attentionMask: Array[Long] = + decoderInputIds.flatMap { tokenIds => tokenIds.map(_ => 1L) } + + val batchSize: Int = decoderInputIds.length + val beamIdx: Array[Int] = new Array[Int](batchSize) + val shape: Array[Int] = Array(batchSize, inputIdsLong.length / batchSize) + + val inputIdsLongTensor: org.intel.openvino.Tensor = + new org.intel.openvino.Tensor(shape, inputIdsLong) + val decoderAttentionMask: org.intel.openvino.Tensor = + new org.intel.openvino.Tensor(Array(batchSize, decoderInputIds.head.length), attentionMask) + val decoderPositionIDs: org.intel.openvino.Tensor = + new org.intel.openvino.Tensor(shape, inputPositionIDsLong) + val beamIdxTensor: org.intel.openvino.Tensor = + new org.intel.openvino.Tensor(Array(batchSize), beamIdx) + + inferRequest.set_tensor(OpenVinoSignatures.decoderInputIDs, inputIdsLongTensor) + inferRequest.set_tensor(OpenVinoSignatures.decoderAttentionMask, decoderAttentionMask) + inferRequest.set_tensor(OpenVinoSignatures.decoderPositionIDs, decoderPositionIDs) + inferRequest.set_tensor(OpenVinoSignatures.decoderBeamIdx, beamIdxTensor) + + inferRequest.infer() + + val result = inferRequest.get_tensor(OpenVinoSignatures.decoderOutput) + val logitsRaw = result.data() + + val sequenceLength = inputIdsLong.length / batchSize + val decoderOutputs = (0 until batchSize).map(i => { + logitsRaw + .slice( + i * sequenceLength * vocabSize + (sequenceLength - 1) * vocabSize, + i * sequenceLength * vocabSize + sequenceLength * vocabSize) + }) + decoderOutputs.toArray + } + + private def getDecoderOutputs( + inputIds: Array[Array[Int]], + onnxSession: (OrtSession, OrtEnvironment)): (Array[Array[Float]]) = { + val (session, env) = onnxSession + + val inputIdsLong: Array[Array[Long]] = + inputIds.map { tokenIds => tokenIds.map(_.toLong) } + + val inputPositionIDsLong: Array[Array[Long]] = + inputIds.map { tokenIds => + tokenIds.zipWithIndex.map { case (_, i) => + i.toLong + } + } + + val inputIdsLongTensor: OnnxTensor = + OnnxTensor.createTensor(env, inputIdsLong) + val decoderAttentionMask: OnnxTensor = + OnnxTensor.createTensor(env, inputIdsLong.map(_.map(_ => 1L))) + val decoderPositionIDs: OnnxTensor = + OnnxTensor.createTensor(env, inputPositionIDsLong) + + val decoderInputs: java.util.Map[String, OnnxTensor] = Map( + OnnxSignatures.decoderInputIDs -> inputIdsLongTensor, + OnnxSignatures.decoderAttentionMask -> decoderAttentionMask, + OnnxSignatures.decoderPositionIDs -> decoderPositionIDs).asJava + val sessionOutput = session.run(decoderInputs) + + val sequenceLength = inputIds.head.length + val batchSize = inputIds.length + +// val logits = sessionOutput.getFloatArray(OnnxSignatures.decoderOutput) +// inputIdsLongTensor.close() +// decoderPositionIDs.close() +// decoderAttentionMask.close() +// val batchLogits = logits.grouped(vocabSize).toArray +// batchLogits + + val logitsRaw = sessionOutput.getFloatArray(OnnxSignatures.decoderOutput) + val decoderOutputs = (0 until batchSize).map(i => { + logitsRaw + .slice( + i * sequenceLength * vocabSize + (sequenceLength - 1) * vocabSize, + i * sequenceLength * vocabSize + sequenceLength * vocabSize) + }) + decoderOutputs.toArray + } + + /** Gets the index with the highest score + * + * @param scores + * Array of Scores to max + * @return + * Index of the highest score + */ + private def argmax(scores: Array[Float]): Int = + scores.zipWithIndex.maxBy { case (score, _) => + score + }._2 + private def greedyGenerationFinished( + decoderIds: Seq[Array[Int]], + eosTokenId: Int, + maxOutputLength: Int): Boolean = + decoderIds.map(_.last).forall(_ == eosTokenId) || decoderIds.head.length == maxOutputLength + + private def generateGreedyOnnx( + inputIds: Array[Array[Int]], + onnxSession: (OrtSession, OrtEnvironment), + maxOutputLength: Int): (Array[Array[Int]]) = { + + val sequencesLength = inputIds.map(x => x.length).toArray + val maxSentenceLength = sequencesLength.max // - curLen + var generatedIds: Array[Array[Int]] = inputIds + while (!greedyGenerationFinished( + generatedIds, + eosTokenId, + maxOutputLength + maxSentenceLength)) { + + val (batchLogits: Array[Array[Float]]) = + Array(getDecoderOutputs(generatedIds, onnxSession).last) + + val nextTokenIds: Array[Int] = batchLogits.map(argmax) + generatedIds = + generatedIds.zip(nextTokenIds).map { case (currentIds: Array[Int], nextId: Int) => + currentIds ++ Array(nextId) + } + } + generatedIds + } + + private object OnnxSignatures { + val decoderInputIDs: String = "input_ids" + val decoderAttentionMask: String = "attention_mask" + val decoderPositionIDs: String = "position_ids" + + // create decoder past for 32 layers of key and value eg. past_key_values.0.key and past_key_values.0.value + val decoderPast: Array[String] = (0 until 32) + .flatMap(i => Seq(s"past_key_values.$i.key", s"past_key_values.$i.value")) + .toArray + val decoderOutput: String = "logits" + val decoderPresent: Array[String] = + (0 until 32).flatMap(i => Seq(s"present.$i.key", s"present.$i.value")).toArray + } + + private object OpenVinoSignatures { + val encoderInputIDs: String = "input_ids" + val encoderAttentionMask: String = "attention_mask" + + val encoderOutput: String = "last_hidden_state" + + val decoderInputIDs: String = "input_ids" + val decoderEncoderAttentionMask: String = "encoder_attention_mask" + val decoderAttentionMask: String = "attention_mask" + val decoderPositionIDs: String = "position_ids" + val decoderBeamIdx: String = "beam_idx" + val decoderEncoderState: String = "encoder_hidden_states" + + val decoderOutput: String = "logits" + } +} diff --git a/src/main/scala/com/johnsnowlabs/nlp/annotators/seq2seq/StarCoderTransformer.scala b/src/main/scala/com/johnsnowlabs/nlp/annotators/seq2seq/StarCoderTransformer.scala new file mode 100644 index 00000000000000..fa96a9603d9873 --- /dev/null +++ b/src/main/scala/com/johnsnowlabs/nlp/annotators/seq2seq/StarCoderTransformer.scala @@ -0,0 +1,470 @@ +/* + * Copyright 2017-2024 John Snow Labs + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.johnsnowlabs.nlp.annotators.seq2seq + +import com.johnsnowlabs.ml.ai.util.Generation.GenerationConfig +import com.johnsnowlabs.ml.ai.StarCoder +import com.johnsnowlabs.ml.onnx.OnnxWrapper.DecoderWrappers +import com.johnsnowlabs.ml.onnx.{OnnxWrapper, ReadOnnxModel, WriteOnnxModel} +import com.johnsnowlabs.ml.openvino.{OpenvinoWrapper, ReadOpenvinoModel, WriteOpenvinoModel} +import com.johnsnowlabs.ml.util.LoadExternalModel.{ + loadJsonStringAsset, + loadSentencePieceAsset, + loadTextAsset, + modelSanityCheck, + notSupportedEngineError +} +import com.johnsnowlabs.ml.util.{ONNX, Openvino} +import com.johnsnowlabs.nlp.AnnotatorType.DOCUMENT +import com.johnsnowlabs.nlp._ +import com.johnsnowlabs.ml.tensorflow.sentencepiece.{ + ReadSentencePieceModel, + SentencePieceWrapper, + WriteSentencePieceModel +} +import com.johnsnowlabs.nlp.serialization.MapFeature +import org.apache.spark.broadcast.Broadcast +import org.apache.spark.ml.param._ +import org.apache.spark.ml.util.Identifiable +import org.apache.spark.sql.SparkSession +import com.johnsnowlabs.nlp.serialization.{MapFeature, StructFeature} +import org.json4s._ +import org.json4s.jackson.JsonMethods._ + +/** StarCoder2: The Versatile Code Companion. + * + * StarCoder2 is a Transformer model designed specifically for code generation and understanding. + * With 13 billion parameters, it builds upon the advancements of its predecessors and is trained + * on a diverse dataset that includes multiple programming languages. This extensive training + * allows StarCoder2 to support a wide array of coding tasks, from code completion to generation. + * + * StarCoder2 was developed to assist developers in writing and understanding code more + * efficiently, making it a valuable tool for various software development and data science + * tasks. + * + * Pretrained models can be loaded with `pretrained` of the companion object: + * {{{ + * val starcoder2 = StarCoder2Transformer.pretrained() + * .setInputCols("document") + * .setOutputCol("generation") + * }}} + * The default model is `"StarCoder2-3B"`, if no name is provided. For available pretrained + * models please see the [[https://sparknlp.org/models?q=StarCoder2 Models Hub]]. + * + * For extended examples of usage, see + * [[https://github.com/JohnSnowLabs/spark-nlp/blob/master/src/test/scala/com/johnsnowlabs/nlp/annotators/seq2seq/StarCoder2TestSpec.scala StarCoder2TestSpec]]. + * + * '''References:''' + * - [[https://huggingface.co/blog/starcoder StarCoder2: The Versatile Code Companion]] + * - [[https://github.com/bigcode-project/starcoder]] + * + * '''Paper Abstract:''' + * + * ''The BigCode project,1 an open-scientific collaboration focused on the responsible + * development of Large Language Models for Code (Code LLMs), introduces StarCoder2. In + * partnership with Software Heritage (SWH),2 we build The Stack v2 on top of the digital commons + * of their source code archive. Alongside the SWH repositories spanning 619 programming + * languages, we carefully select other high-quality data sources, such as GitHub pull requests, + * Kaggle notebooks, and code documentation. This results in a training set that is 4× larger + * than the first StarCoder dataset. We train StarCoder2 models with 3B, 7B, and 15B parameters + * on 3.3 to 4.3 trillion tokens and thoroughly evaluate them on a comprehensive set of Code LLM + * benchmarks.'' + * + * '' We find that our small model, StarCoder2-3B, outperforms other Code LLMs of similar size on + * most benchmarks, and also outperforms StarCoderBase-15B. Our large model, StarCoder2- 15B, + * significantly outperforms other models of comparable size. In addition, it matches or + * outperforms CodeLlama-34B, a model more than twice its size. Although DeepSeekCoder- 33B is + * the best-performing model at code completion for high-resource languages, we find that + * StarCoder2-15B outperforms it on math and code reasoning benchmarks, as well as several + * low-resource languages. We make the model weights available under an OpenRAIL license and + * ensure full transparency regarding the training data by releasing the SoftWare Heritage + * persistent IDentifiers (SWHIDs) of the source code data.'' + * + * '''Note:''' + * + * This is a computationally intensive module, especially for larger code sequences. The use of + * an accelerator such as GPU is recommended. + * + * ==Example== + * {{{ + * import spark.implicits._ + * import com.johnsnowlabs.nlp.base.DocumentAssembler + * import com.johnsnowlabs.nlp.annotators.seq2seq.StarCoder2Transformer + * import org.apache.spark.ml.Pipeline + * + * val documentAssembler = new DocumentAssembler() + * .setInputCol("text") + * .setOutputCol("documents") + * + * val starcoder2 = StarCoder2Transformer.pretrained("starcoder2") + * .setInputCols(Array("documents")) + * .setMinOutputLength(10) + * .setMaxOutputLength(50) + * .setDoSample(false) + * .setTopK(50) + * .setNoRepeatNgramSize(3) + * .setOutputCol("generation") + * + * val pipeline = new Pipeline().setStages(Array(documentAssembler, starcoder2)) + * + * val data = Seq( + * "def add(a, b):" + * ).toDF("text") + * val result = pipeline.fit(data).transform(data) + * + * results.select("generation.result").show(truncate = false) + * +----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+ + * |result | + * +----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+ + * |[def add(a, b): return a + b] | + * +----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+ + * }}} + * + * @param uid + * required uid for storing annotator to disk + * @groupname anno Annotator types + * @groupdesc anno + * Required input and expected output annotator types + * @groupname Ungrouped Members + * @groupname param Parameters + * @groupname setParam Parameter setters + * @groupname getParam Parameter getters + * @groupname Ungrouped Members + * @groupprio param 1 + * @groupprio anno 2 + * @groupprio Ungrouped 3 + * @groupprio setParam 4 + * @groupprio getParam 5 + * @groupdesc param + * A list of (hyper-)parameter keys this annotator can take. Users can set and get the + * parameter values through setters and getters, respectively. + */ +class StarCoderTransformer(override val uid: String) + extends AnnotatorModel[StarCoderTransformer] + with HasBatchedAnnotate[StarCoderTransformer] + with ParamsAndFeaturesWritable + with WriteOnnxModel + with WriteOpenvinoModel + with HasGeneratorProperties + with HasEngine { + + def this() = this(Identifiable.randomUID("StarCoderTRANSFORMER")) + + /** Input annotator type : DOCUMENT + * + * @group param + */ + override val inputAnnotatorTypes: Array[AnnotatorType] = Array(DOCUMENT) + + /** Output annotator type : DOCUMENT + * + * @group param + */ + override val outputAnnotatorType: String = DOCUMENT + + /** @group setParam */ + def setRandomSeed(value: Int): StarCoderTransformer.this.type = { + if (randomSeed.isEmpty) { + this.randomSeed = Some(value) + } + this + } + + /** A list of token ids which are ignored in the decoder's output (Default: `Array()`) + * + * @group param + */ + var ignoreTokenIds = new IntArrayParam( + this, + "ignoreTokenIds", + "A list of token ids which are ignored in the decoder's output") + + /** @group setParam */ + def setIgnoreTokenIds(tokenIds: Array[Int]): StarCoderTransformer.this.type = { + set(ignoreTokenIds, tokenIds) + } + + /** @group getParam */ + def getIgnoreTokenIds: Array[Int] = $(ignoreTokenIds) + + /** Vocabulary used to encode the words to ids with bpeTokenizer.encode + * + * @group param + */ + val vocabulary: MapFeature[String, Int] = new MapFeature(this, "vocabulary").setProtected() + + /** @group setParam */ + def setVocabulary(value: Map[String, Int]): this.type = set(vocabulary, value) + + /** Holding merges.txt coming from RoBERTa model + * + * @group param + */ + val merges: MapFeature[(String, String), Int] = new MapFeature(this, "merges").setProtected() + + /** @group setParam */ + def setMerges(value: Map[(String, String), Int]): this.type = set(merges, value) + + private var _model: Option[Broadcast[StarCoder]] = None + + val generationConfig: StructFeature[GenerationConfig] = + new StructFeature(this, "generationConfig").setProtected() + + def setGenerationConfig(value: GenerationConfig): this.type = + set(generationConfig, value) + + def getGenerationConfig: GenerationConfig = $$(generationConfig) + + /** @group setParam */ + def setModelIfNotSet( + spark: SparkSession, + onnxWrappers: Option[DecoderWrappers], + openvinoWrapper: Option[OpenvinoWrapper]): this.type = { + if (_model.isEmpty) { + _model = Some( + spark.sparkContext.broadcast( + new StarCoder( + onnxWrappers, + openvinoWrapper, + $$(merges), + $$(vocabulary), + generationConfig = getGenerationConfig))) + } + this + } + + /** @group getParam */ + def getModelIfNotSet: StarCoder = _model.get.value + + setDefault( + minOutputLength -> 0, + maxOutputLength -> 20, + doSample -> false, + temperature -> 0.6, + topK -> 50, + topP -> 0.9, + repetitionPenalty -> 1.0, + noRepeatNgramSize -> 3, + ignoreTokenIds -> Array(), + batchSize -> 1, + beamSize -> 1, + maxInputLength -> 4096, + stopTokenIds -> Array()) + + /** takes a document and annotations and produces new annotations of this annotator's annotation + * type + * + * @param batchedAnnotations + * Annotations that correspond to inputAnnotationCols generated by previous annotators if any + * @return + * any number of annotations processed for every input annotation. Not necessary one to one + * relationship + */ + override def batchAnnotate(batchedAnnotations: Seq[Array[Annotation]]): Seq[Seq[Annotation]] = { + + val allAnnotations = batchedAnnotations + .filter(_.nonEmpty) + .zipWithIndex + .flatMap { case (annotations, i) => + annotations.filter(_.result.nonEmpty).map(x => (x, i)) + } + val processedAnnotations = if (allAnnotations.nonEmpty) { + this.getModelIfNotSet.predict( + sentences = allAnnotations.map(_._1), + batchSize = $(batchSize), + minOutputLength = $(minOutputLength), + maxOutputLength = $(maxOutputLength), + doSample = $(doSample), + temperature = $(temperature), + topK = $(topK), + topP = $(topP), + repetitionPenalty = $(repetitionPenalty), + noRepeatNgramSize = $(noRepeatNgramSize), + randomSeed = this.randomSeed, + ignoreTokenIds = $(ignoreTokenIds), + beamSize = $(beamSize), + maxInputLength = $(maxInputLength), + stopTokenIds = $(stopTokenIds)) + } else { + Seq() + } + Seq(processedAnnotations) + } + + override def onWrite(path: String, spark: SparkSession): Unit = { + super.onWrite(path, spark) + getEngine match { + case ONNX.name => + val wrappers = getModelIfNotSet.onnxWrappers + writeOnnxModels( + path, + spark, + Seq((wrappers.get.decoder, "decoder_model.onnx")), + StarCoderTransformer.suffix) + case Openvino.name => + val wrappers = getModelIfNotSet.openvinoWrapper + writeOpenvinoModel( + path, + spark, + wrappers.get, + StarCoderTransformer.suffix, + StarCoderTransformer.openvinoFile) + } + } +} + +trait ReadablePretrainedStarCoderTransformerModel + extends ParamsAndFeaturesReadable[StarCoderTransformer] + with HasPretrained[StarCoderTransformer] { + override val defaultModelName: Some[String] = Some("starcoder") + + /** Java compliant-overrides */ + override def pretrained(): StarCoderTransformer = super.pretrained() + + override def pretrained(name: String): StarCoderTransformer = super.pretrained(name) + + override def pretrained(name: String, lang: String): StarCoderTransformer = + super.pretrained(name, lang) + + override def pretrained(name: String, lang: String, remoteLoc: String): StarCoderTransformer = + super.pretrained(name, lang, remoteLoc) +} + +trait ReadStarCoderTransformerDLModel extends ReadOnnxModel with ReadOpenvinoModel { + this: ParamsAndFeaturesReadable[StarCoderTransformer] => + + override val onnxFile: String = "starcoder_onnx" + val suffix: String = "_starcoder" + override val openvinoFile: String = "starcoder_openvino" + + def readModel(instance: StarCoderTransformer, path: String, spark: SparkSession): Unit = { + instance.getEngine match { + case ONNX.name => + val wrappers = + readOnnxModels(path, spark, Seq("decoder_model.onnx"), suffix) + val onnxWrappers = + DecoderWrappers(decoder = wrappers("decoder_model.onnx")) + instance.setModelIfNotSet(spark, Some(onnxWrappers), None) + case Openvino.name => + val ovWrapper = + readOpenvinoModel(path, spark, "_starcoder_ov") + instance.setModelIfNotSet(spark, None, Some(ovWrapper)) + case _ => + throw new Exception(notSupportedEngineError) + } + } + + addReader(readModel) + + def loadSavedModel( + modelPath: String, + spark: SparkSession, + useOpenvino: Boolean = false): StarCoderTransformer = { + implicit val formats: DefaultFormats.type = DefaultFormats // for json4 + val (localModelPath, detectedEngine) = + modelSanityCheck(modelPath, isDecoder = true) + val modelConfig: JValue = + parse(loadJsonStringAsset(localModelPath, "config.json")) + + val beginSuppressTokens: Array[Int] = + (modelConfig \ "begin_suppress_tokens").extract[Array[Int]] + + val suppressTokenIds: Array[Int] = + (modelConfig \ "suppress_tokens").extract[Array[Int]] + + val forcedDecoderIds: Array[(Int, Int)] = + (modelConfig \ "forced_decoder_ids").extract[Array[Array[Int]]].map { + case idxWithTokenId: Array[Int] if idxWithTokenId.length == 2 => + (idxWithTokenId(0), idxWithTokenId(1)) + case _ => + throw new Exception( + "Could not extract forced_decoder_ids. Should be a list of tuples with 2 entries.") + } + + def arrayOrNone[T](array: Array[T]): Option[Array[T]] = + if (array.nonEmpty) Some(array) else None + + val bosTokenId = (modelConfig \ "bos_token_id").extract[Int] + val eosTokenId = (modelConfig \ "eos_token_id").extract[Int] + val padTokenId = (modelConfig \ "eos_token_id").extract[Int] + val vocabSize = (modelConfig \ "vocab_size").extract[Int] + + val vocabs = loadTextAsset(localModelPath, "vocab.txt").zipWithIndex.toMap + + val bytePairs = loadTextAsset(localModelPath, "merges.txt") + .map(_.split(" ")) + .filter(w => w.length == 2) + .map { case Array(c1, c2) => (c1, c2) } + .zipWithIndex + .toMap + + val annotatorModel = new StarCoderTransformer() + .setGenerationConfig( + GenerationConfig( + bosTokenId, + padTokenId, + eosTokenId, + vocabSize, + arrayOrNone(beginSuppressTokens), + arrayOrNone(suppressTokenIds), + arrayOrNone(forcedDecoderIds))) + .setVocabulary(vocabs) + .setMerges(bytePairs) + + val modelEngine = + if (useOpenvino) + Openvino.name + else + detectedEngine + annotatorModel.set(annotatorModel.engine, modelEngine) + + detectedEngine match { + case ONNX.name => + val onnxWrapperDecoder = + OnnxWrapper.read( + spark, + localModelPath, + zipped = false, + useBundle = true, + modelName = "decoder_model") + + val onnxWrappers = DecoderWrappers(onnxWrapperDecoder) + + annotatorModel + .setModelIfNotSet(spark, Some(onnxWrappers), None) + case Openvino.name => + val openvinoWrapper = + OpenvinoWrapper.read( + spark, + localModelPath, + zipped = false, + useBundle = true, + detectedEngine = detectedEngine) + annotatorModel.setModelIfNotSet(spark, None, Some(openvinoWrapper)) + + case _ => + throw new Exception(notSupportedEngineError) + } + + annotatorModel + } + +} + +object StarCoderTransformer + extends ReadablePretrainedStarCoderTransformerModel + with ReadStarCoderTransformerDLModel diff --git a/src/main/scala/com/johnsnowlabs/nlp/annotators/tokenizer/bpe/BpeSpecialTokens.scala b/src/main/scala/com/johnsnowlabs/nlp/annotators/tokenizer/bpe/BpeSpecialTokens.scala index 5bab008aac220f..eb2769a4ad7458 100644 --- a/src/main/scala/com/johnsnowlabs/nlp/annotators/tokenizer/bpe/BpeSpecialTokens.scala +++ b/src/main/scala/com/johnsnowlabs/nlp/annotators/tokenizer/bpe/BpeSpecialTokens.scala @@ -162,6 +162,14 @@ private[johnsnowlabs] object SpecialTokens { maskTokenString = "<|endoftext|>", padTokenString = "<|endoftext|>") + case "starcoder" => + SpecialTokens( + vocab, + startTokenString = "<|endoftext|>", + endTokenString = "<|endoftext|>", + unkTokenString = "<|endoftext|>", + maskTokenString = "<|endoftext|>", + padTokenString = "<|endoftext|>") } } diff --git a/src/main/scala/com/johnsnowlabs/nlp/annotators/tokenizer/bpe/BpeTokenizer.scala b/src/main/scala/com/johnsnowlabs/nlp/annotators/tokenizer/bpe/BpeTokenizer.scala index d1538ca4d2ac97..13bf1940d46085 100644 --- a/src/main/scala/com/johnsnowlabs/nlp/annotators/tokenizer/bpe/BpeTokenizer.scala +++ b/src/main/scala/com/johnsnowlabs/nlp/annotators/tokenizer/bpe/BpeTokenizer.scala @@ -368,6 +368,13 @@ object BpeTokenizer { modelSpecialTokens(), padWithSequenceTokens, addPrefixSpaceToSentence = addPrefixSpaceToSentence) + case "starcoder" => + new StarCoderTokenizer( + merges, + vocab, + modelSpecialTokens(), + padWithSequenceTokens, + addPrefixSpaceToSentence = addPrefixSpaceToSentence) case _ => throw new IllegalArgumentException("Model type \"" + modelType + "\" not supported yet.") } diff --git a/src/main/scala/com/johnsnowlabs/nlp/annotators/tokenizer/bpe/StarCoderTokenizer.scala b/src/main/scala/com/johnsnowlabs/nlp/annotators/tokenizer/bpe/StarCoderTokenizer.scala new file mode 100644 index 00000000000000..d7eba25661b467 --- /dev/null +++ b/src/main/scala/com/johnsnowlabs/nlp/annotators/tokenizer/bpe/StarCoderTokenizer.scala @@ -0,0 +1,32 @@ +/* + * Copyright 2017-2023 John Snow Labs + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.johnsnowlabs.nlp.annotators.tokenizer.bpe + +class StarCoderTokenizer( + merges: Map[(String, String), Int], + vocab: Map[String, Int], + specialTokens: SpecialTokens, + padWithSequenceTokens: Boolean = false, + addPrefixSpaceToSentence: Boolean = false) + extends Gpt2Tokenizer( + merges, + vocab, + specialTokens, + padWithSequenceTokens, + prependString = "Ġ", + addPrefixSpaceToSentence, + alwaysAddPrefix = false) diff --git a/src/test/scala/com/johnsnowlabs/nlp/annotators/seq2seq/StarCoderTestSpec.scala b/src/test/scala/com/johnsnowlabs/nlp/annotators/seq2seq/StarCoderTestSpec.scala new file mode 100644 index 00000000000000..203af72642bc55 --- /dev/null +++ b/src/test/scala/com/johnsnowlabs/nlp/annotators/seq2seq/StarCoderTestSpec.scala @@ -0,0 +1,66 @@ +/* + * Copyright 2017-2023 John Snow Labs + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.johnsnowlabs.nlp.annotators.seq2seq + +import com.johnsnowlabs.nlp.base.DocumentAssembler +import com.johnsnowlabs.nlp.util.io.ResourceHelper +import com.johnsnowlabs.tags.{FastTest, SlowTest} +import org.apache.spark.ml.Pipeline +import org.scalatest.flatspec.AnyFlatSpec + +class StarCoderTestSpec extends AnyFlatSpec { + + "starcoder" should "should handle temperature=0 correctly and not crash when predicting more than 1 element with doSample=True" taggedAs SlowTest in { + // Even tough the Paper states temperature in interval [0,1), using temperature=0 will result in division by 0 error. + // Also DoSample=True may result in infinities being generated and distFiltered.length==0 which results in exception if we don't return 0 instead internally. + val testData = ResourceHelper.spark + .createDataFrame(Seq((1, "def print_hello_world():"))) + .toDF("id", "text") + .repartition(1) + val documentAssembler = new DocumentAssembler() + .setInputCol("text") + .setOutputCol("documents") + + val bart = StarCoderTransformer + .pretrained() + .setInputCols(Array("documents")) + .setDoSample(false) + .setMaxOutputLength(50) + .setOutputCol("generation") + .setBeamSize(1) + + val pipeline = new Pipeline() + .setStages(Array(documentAssembler, bart)) + + val pipelineModel = pipeline.fit(testData) + + pipelineModel + .transform(testData) + .show(truncate = false) + + pipelineModel + .transform(testData) + .show(truncate = false) + + pipelineModel.stages.last + .asInstanceOf[StarCoderTransformer] + .write + .overwrite() + .save("/tmp/starcoder-3b-4bit-model") + + } +}