diff --git a/haystack/nodes/__init__.py b/haystack/nodes/__init__.py index b88f1f53a5..ec3e6126d7 100644 --- a/haystack/nodes/__init__.py +++ b/haystack/nodes/__init__.py @@ -23,6 +23,7 @@ from haystack.nodes.label_generator import PseudoLabelGenerator from haystack.nodes.other import Docs2Answers, JoinDocuments, RouteDocuments, JoinAnswers, DocumentMerger from haystack.nodes.preprocessor import BasePreProcessor, PreProcessor +from haystack.nodes.prompt import PromptNode, PromptTemplate, PromptModel from haystack.nodes.query_classifier import SklearnQueryClassifier, TransformersQueryClassifier from haystack.nodes.question_generator import QuestionGenerator from haystack.nodes.ranker import BaseRanker, SentenceTransformersRanker diff --git a/haystack/nodes/prompt/__init__.py b/haystack/nodes/prompt/__init__.py new file mode 100644 index 0000000000..0fef7cda02 --- /dev/null +++ b/haystack/nodes/prompt/__init__.py @@ -0,0 +1 @@ +from haystack.nodes.prompt.prompt_node import PromptNode, PromptTemplate, PromptModel diff --git a/haystack/nodes/prompt/prompt_node.py b/haystack/nodes/prompt/prompt_node.py new file mode 100644 index 0000000000..d666775406 --- /dev/null +++ b/haystack/nodes/prompt/prompt_node.py @@ -0,0 +1,822 @@ +import json +import logging +import re +from abc import ABC, abstractmethod +from string import Template +from typing import Dict, List, Optional, Tuple, Union, Any, Type + +import requests +import torch +from transformers import pipeline, AutoModelForSeq2SeqLM + +from haystack import MultiLabel +from haystack.errors import OpenAIError, OpenAIRateLimitError +from haystack.modeling.utils import initialize_device_settings +from haystack.nodes.base import BaseComponent +from haystack.schema import Document + +logger = logging.getLogger(__name__) + + +class BasePromptTemplate(BaseComponent): + + outgoing_edges = 1 + + def run( + self, + query: Optional[str] = None, + file_paths: Optional[List[str]] = None, + labels: Optional[MultiLabel] = None, + documents: Optional[List[Document]] = None, + meta: Optional[dict] = None, + ) -> Tuple[Dict, str]: + raise NotImplementedError("This method should never be implemented in the derived class") + + def run_batch( + self, + queries: Optional[Union[str, List[str]]] = None, + file_paths: Optional[List[str]] = None, + labels: Optional[Union[MultiLabel, List[MultiLabel]]] = None, + documents: Optional[Union[List[Document], List[List[Document]]]] = None, + meta: Optional[Union[Dict[str, Any], List[Dict[str, Any]]]] = None, + params: Optional[dict] = None, + debug: Optional[bool] = None, + ): + raise NotImplementedError("This method should never be implemented in the derived class") + + +class PromptTemplate(BasePromptTemplate, ABC): + """ + PromptTemplate represents a template for a prompt. For example, a prompt template for the sentiment + analysis task might look like this: + + ```python + PromptTemplate(name="sentiment-analysis", + prompt_text="Please give a sentiment for this context. Answer with positive, negative + or neutral. Context: $documents; Answer:", + prompt_params=["documents"]) + ``` + + PromptTemplate declares prompt_params, which are the input parameters that need to be filled in the prompt_text. + For example, in the above example, the prompt_params are ["documents"] and the prompt_text is + "Please give a sentiment..." + + The prompt_text contains a placeholder $documents. This variable will be filled in runtime with the non-keyword + or keyword argument `documents` passed to this PromptTemplate's fill() method. 
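The `fill()` method returns a dictionary mapping each prompt parameter to its values, plus the raw template text under the `prompt_template` key. A minimal sketch of filling the sentiment-analysis template above (the document values are illustrative):

```python
template = PromptTemplate(
    name="sentiment-analysis",
    prompt_text="Please give a sentiment for this context. Answer with positive, "
    "negative or neutral. Context: $documents; Answer:",
    prompt_params=["documents"],
)
# The keyword argument fills the $documents placeholder.
prepared = template.fill(documents=["I am in love and I feel great!"])
# prepared == {"documents": ["I am in love and I feel great!"], "prompt_template": template.prompt_text}
```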
+ """ + + def __init__(self, name: str, prompt_text: str, prompt_params: Optional[List[str]] = None): + super().__init__() + if not prompt_params: + # Define the regex pattern to match the strings after the $ character + pattern = r"\$([a-zA-Z0-9_]+)" + prompt_params = re.findall(pattern, prompt_text) + + if prompt_text.count("$") != len(prompt_params): + raise ValueError( + f"Number of parameters in prompt text {prompt_text} for prompt template {name} " + f"does not match number of specified parameters {prompt_params}" + ) + + # use case when PromptTemplate is loaded from a YAML file, we need to start and end the prompt text with quotes + prompt_text = prompt_text.strip("'").strip('"') + + t = Template(prompt_text) + try: + t.substitute(**{param: "" for param in prompt_params}) + except KeyError as e: + raise ValueError( + f"Invalid parameter {e} in prompt text " + f"{prompt_text} for prompt template {name}, specified parameters are {prompt_params}" + ) + + self.name = name + self.prompt_text = prompt_text + self.prompt_params = prompt_params + + def fill(self, *args, **kwargs) -> Dict[str, Any]: + """ + Prepares the prompt text parameters from non-keyword and keyword arguments. + + In the case of non-keyword arguments, the order of the arguments should match the left-to-right + order of appearance of the parameters in the prompt text. For example, if the prompt text is: + `Please come up with a question for the given context and the answer. Context: $documents; + Answer: $answers; Question:` then the first non-keyword argument will fill the $documents placeholder + and the second non-keyword argument will fill the $answers placeholder. + + In the case of keyword arguments, the order of the arguments does not matter. Placeholders in the + prompt text are filled with the corresponding keyword argument. + + :param args: non-keyword arguments to use for filling the prompt text + :param kwargs: keyword arguments to use for filling the prompt text + :return: a dictionary with the prompt text and the prompt parameters + """ + template_dict = {} + # attempt to resolve args first + if args: + if len(args) != len(self.prompt_params): + logger.warning( + f"For {self.name}, expected {self.prompt_params} arguments, instead " + f"got {len(args)} arguments {args}" + ) + for prompt_param, arg in zip(self.prompt_params, args): + template_dict[prompt_param] = [arg] if isinstance(arg, str) else arg + # then attempt to resolve kwargs + if kwargs: + for param in self.prompt_params: + if param in kwargs: + template_dict[param] = kwargs[param] + + if set(template_dict.keys()) != set(self.prompt_params): + available_params = set(list(template_dict.keys()) + list(set(kwargs.keys()))) + raise ValueError(f"Expected prompt params {self.prompt_params} but got {list(available_params)}") + + template_dict["prompt_template"] = self.prompt_text + return template_dict + + +PREDEFINED_PROMPT_TEMPLATES = [ + PromptTemplate( + name="question-answering", + prompt_text="Given the context please answer the question. Context: $documents; Question: $questions; Answer:", + prompt_params=["documents", "questions"], + ), + PromptTemplate( + name="question-generation", + prompt_text="Given the context please generate a question. Context: $documents; Question:", + prompt_params=["documents"], + ), + PromptTemplate( + name="conditioned-question-generation", + prompt_text="Please come up with a question for the given context and the answer. 
" + "Context: $documents; Answer: $answers; Question:", + prompt_params=["documents", "answers"], + ), + PromptTemplate( + name="summarization", prompt_text="Summarize this document: $documents Summary:", prompt_params=["documents"] + ), + PromptTemplate( + name="question-answering-check", + prompt_text="Does the following context contain the answer to the question. " + "Context: $documents; Question: $questions; Please answer yes or no! Answer:", + prompt_params=["documents", "questions"], + ), + PromptTemplate( + name="sentiment-analysis", + prompt_text="Please give a sentiment for this context. Answer with positive, " + "negative or neutral. Context: $documents; Answer:", + prompt_params=["documents"], + ), + PromptTemplate( + name="multiple-choice-question-answering", + prompt_text="Question:$questions ; Choose the most suitable option to answer the above question. " + "Options: $options; Answer:", + prompt_params=["questions", "options"], + ), + PromptTemplate( + name="topic-classification", + prompt_text="Categories: $options; What category best describes: $documents; Answer:", + prompt_params=["documents", "options"], + ), + PromptTemplate( + name="language-detection", + prompt_text="Detect the language in the following context and answer with the " + "name of the language. Context: $documents; Answer:", + ), + PromptTemplate( + name="translation", + prompt_text="Translate the following context to $target_language. Context: $documents; Translation:", + ), +] + + +class PromptModelInvocationLayer: + """ + PromptModelInvocationLayer implementations execute a prompt on an underlying model. + + The implementation can be a simple invocation on the underlying model running in a local runtime, or + could be even remote, for example, a call to a remote API endpoint. + """ + + def __init__(self, model_name_or_path: str, max_length: Optional[int] = 100, **kwargs): + if model_name_or_path is None or len(model_name_or_path) == 0: + raise ValueError("model_name_or_path cannot be None or empty string") + + self.model_name_or_path = model_name_or_path + self.max_length: Optional[int] = max_length + + @abstractmethod + def invoke(self, *args, **kwargs): + pass + + @classmethod + def supports(cls, model_name_or_path: str) -> bool: + return False + + +class HFLocalInvocationLayer(PromptModelInvocationLayer): + """ + A subclass of the PromptModelInvocationLayer class. It loads a pre-trained model from Hugging Face and + passes a prepared prompt into that model. + + Note: kwargs other than init parameter names are ignored to enable reflective construction of the class + as many variants of PromptModelInvocationLayer are possible and they may have different parameters + """ + + def __init__( + self, + model_name_or_path: str = "google/flan-t5-base", + max_length: Optional[int] = 100, + use_auth_token: Optional[Union[str, bool]] = None, + use_gpu: Optional[bool] = True, + devices: Optional[List[Union[str, torch.device]]] = None, + **kwargs, + ): + super().__init__(model_name_or_path, max_length) + self.use_auth_token = use_auth_token + + self.devices, _ = initialize_device_settings(devices=devices, use_cuda=use_gpu, multi_gpu=False) + if len(self.devices) > 1: + logger.warning( + f"Multiple devices are not supported in {self.__class__.__name__} inference, " + f"using the first device {self.devices[0]}." + ) + + # Due to reflective construction of all invocation layers we might receive some + # unknown kwargs, so we need to take only the relevant. 
+ # For more details refer to Hugging Face pipeline documentation + # Do not use `device_map` AND `device` at the same time as they will conflict + model_input_kwargs = { + key: kwargs[key] + for key in [ + "model_kwargs", + "trust_remote_code", + "revision", + "feature_extractor", + "tokenizer", + "config", + "use_fast", + "torch_dtype", + "device_map", + ] + if key in kwargs + } + # flatten model_kwargs one level + if "model_kwargs" in model_input_kwargs: + mkwargs = model_input_kwargs.pop("model_kwargs") + model_input_kwargs.update(mkwargs) + + torch_dtype = model_input_kwargs.get("torch_dtype") + if torch_dtype is not None: + if isinstance(torch_dtype, str): + if "torch." not in torch_dtype: + raise ValueError( + f"torch_dtype should be a torch.dtype or a string with 'torch.' prefix, got {torch_dtype}" + ) + torch_dtype_resolved = getattr(torch, torch_dtype.strip("torch.")) + elif isinstance(torch_dtype, torch.dtype): + torch_dtype_resolved = torch_dtype + else: + raise ValueError(f"Invalid torch_dtype value {torch_dtype}") + model_input_kwargs["torch_dtype"] = torch_dtype_resolved + + if len(model_input_kwargs) > 0: + logger.info("Using model input kwargs %s in %s", model_input_kwargs, self.__class__.__name__) + + self.pipe = pipeline( + "text2text-generation", + model=model_name_or_path, + device=self.devices[0] if "device_map" not in model_input_kwargs else None, + use_auth_token=self.use_auth_token, + model_kwargs=model_input_kwargs, + ) + + def invoke(self, *args, **kwargs): + """ + It takes a prompt and returns a list of generated text using the local Hugging Face transformers model + :return: A list of generated text. + """ + output = [] + if kwargs and "prompt" in kwargs: + prompt = kwargs.pop("prompt") + + # We might have some uncleaned kwargs, so we need to take only the relevant. + # For more details refer to Hugging Face Text2TextGenerationPipeline documentation + model_input_kwargs = { + key: kwargs[key] + for key in ["return_tensors", "return_text", "clean_up_tokenization_spaces", "truncation"] + if key in kwargs + } + output = self.pipe(prompt, max_length=self.max_length, **model_input_kwargs) + return [o["generated_text"] for o in output] + + @classmethod + def supports(cls, model_name_or_path: str) -> bool: + if not all(m in model_name_or_path for m in ["google", "flan", "t5"]): + return False + + try: + # if it is google flan t5, load it, we'll use it anyway and also check if model loads correctly + AutoModelForSeq2SeqLM.from_pretrained(model_name_or_path) + except EnvironmentError: + return False + return True + + +class OpenAIInvocationLayer(PromptModelInvocationLayer): + """ + PromptModelInvocationLayer implementation for OpenAI's GPT-3 InstructGPT models. Invocations are made via REST API. + See [OpenAI GPT-3](https://beta.openai.com/docs/models/gpt-3) for more details. + + Note: kwargs other than init parameter names are ignored to enable reflective construction of the class + as many variants of PromptModelInvocationLayer are possible and they may have different parameters + """ + + def __init__( + self, api_key: str, model_name_or_path: str = "text-davinci-003", max_length: Optional[int] = 100, **kwargs + ): + super().__init__(model_name_or_path, max_length) + if not isinstance(api_key, str) or len(api_key) == 0: + raise OpenAIError( + f"api_key {api_key} has to be a valid OpenAI key. Please visit https://beta.openai.com/ to get one." 
+ ) + self.api_key = api_key + self.url = "https://api.openai.com/v1/completions" + + # Due to reflective construction of all invocation layers we might receive some + # unknown kwargs, so we need to take only the relevant. + # For more details refer to OpenAI documentation + self.model_input_kwargs = { + key: kwargs[key] + for key in [ + "suffix", + "max_tokens", + "temperature", + "top_p", + "n", + "logprobs", + "echo", + "stop", + "presence_penalty", + "frequency_penalty", + "best_of", + "logit_bias", + ] + if key in kwargs + } + + def invoke(self, *args, **kwargs): + """ + Invokes a prompt on the model. It takes in a prompt, and returns a list of responses using a REST invocation. + + :return: The responses are being returned. + """ + prompt = kwargs.get("prompt") + if not prompt: + raise ValueError( + f"No prompt provided. Model {self.model_name_or_path} requires prompt" + f"Make sure to provide prompt in kwargs" + ) + + kwargs_with_defaults = self.model_input_kwargs + if kwargs: + kwargs_with_defaults.update(kwargs) + payload = { + "model": self.model_name_or_path, + "prompt": prompt, + "suffix": kwargs_with_defaults.get("suffix", None), + "max_tokens": kwargs_with_defaults.get("max_tokens", self.max_length), + "temperature": kwargs_with_defaults.get("temperature", 0.7), + "top_p": kwargs_with_defaults.get("top_p", 1), + "n": kwargs_with_defaults.get("n", 1), + "stream": False, # no support for streaming + "logprobs": kwargs_with_defaults.get("logprobs", None), + "echo": kwargs_with_defaults.get("echo", False), + "stop": kwargs_with_defaults.get("stop", None), + "presence_penalty": kwargs_with_defaults.get("presence_penalty", 0), + "frequency_penalty": kwargs_with_defaults.get("frequency_penalty", 0), + "best_of": kwargs.get("best_of", 1), + "logit_bias": kwargs.get("logit_bias", {}), + } + headers = {"Authorization": f"Bearer {self.api_key}", "Content-Type": "application/json"} + response = requests.request("POST", self.url, headers=headers, data=json.dumps(payload), timeout=30) + res = json.loads(response.text) + + if response.status_code != 200: + openai_error: OpenAIError + if response.status_code == 429: + openai_error = OpenAIRateLimitError(f"API rate limit exceeded: {response.text}") + else: + openai_error = OpenAIError( + f"OpenAI returned an error.\n" + f"Status code: {response.status_code}\n" + f"Response body: {response.text}", + status_code=response.status_code, + ) + raise openai_error + + responses = [ans["text"].strip() for ans in res["choices"]] + return responses + + @classmethod + def supports(cls, model_name_or_path: str) -> bool: + return any(m for m in ["ada", "babbage", "davinci", "curie"] if m in model_name_or_path) + + +class PromptModel(BaseComponent): + """ + The PromptModel class is a component that uses a pre-trained model to generate text based on a prompt. Out of + the box, it supports two model invocation layers: Hugging Face transformers and OpenAI, with the ability to + register additional custom invocation layers. + + Although it is possible to use PromptModel to make prompt invocations on the underlying model, please use + PromptNode for interactions with the model. PromptModel instances are the practical approach for multiple + PromptNode instances to use a single PromptNode and thus save computational resources. 
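For example, a single PromptModel can be shared by several PromptNode instances so the underlying model is loaded only once (a minimal sketch; the model name and template names are the defaults used elsewhere in this module):

```python
model = PromptModel(model_name_or_path="google/flan-t5-base")

question_generator = PromptNode(model_name_or_path=model, default_prompt_template="question-generation")
question_answerer = PromptNode(model_name_or_path=model, default_prompt_template="question-answering")
```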
+ """ + + outgoing_edges = 1 + + def __init__( + self, + model_name_or_path: str = "google/flan-t5-base", + max_length: Optional[int] = 100, + api_key: Optional[str] = None, + use_auth_token: Optional[Union[str, bool]] = None, + use_gpu: Optional[bool] = None, + devices: Optional[List[Union[str, torch.device]]] = None, + model_kwargs: Optional[Dict] = None, + ): + super().__init__() + self.model_name_or_path = model_name_or_path + self.max_length = max_length + self.api_key = api_key + self.use_auth_token = use_auth_token + self.use_gpu = use_gpu + self.devices = devices + + self.model_kwargs = model_kwargs if model_kwargs else {} + + self.invocation_layers: List[Type[PromptModelInvocationLayer]] = [] + + self.register(HFLocalInvocationLayer) # pylint: disable=W0108 + self.register(OpenAIInvocationLayer) # pylint: disable=W0108 + + self.model_invocation_layer = self.create_invocation_layer() + + def create_invocation_layer(self) -> PromptModelInvocationLayer: + kwargs = { + "api_key": self.api_key, + "use_auth_token": self.use_auth_token, + "use_gpu": self.use_gpu, + "devices": self.devices, + } + all_kwargs = {**self.model_kwargs, **kwargs} + + for invocation_layer in self.invocation_layers: + if invocation_layer.supports(self.model_name_or_path): + return invocation_layer( + model_name_or_path=self.model_name_or_path, max_length=self.max_length, **all_kwargs + ) + raise ValueError( + f"Model {self.model_name_or_path} is not supported - no invocation layer found." + f"Currently supported models are: {self.invocation_layers}" + f"Register new invocation layer for {self.model_name_or_path} using the register method." + ) + + def register(self, invocation_layer: Type[PromptModelInvocationLayer]): + """ + Registers additional prompt model invocation layer. It takes a function that returns a boolean as a + matching condition on `model_name_or_path` and a class that implements `PromptModelInvocationLayer` interface. + """ + self.invocation_layers.append(invocation_layer) + + def invoke(self, prompt: Union[str, List[str]], **kwargs) -> List[str]: + """ + It takes in a prompt, and returns a list of responses using the underlying invocation layer. + + :param prompt: The prompt to use for the invocation, it could be a single prompt or a list of prompts + :param kwargs: Additional keyword arguments to pass to the invocation layer + :return: A list of model generated responses for the prompt or prompts + """ + output = self.model_invocation_layer.invoke(prompt=prompt, **kwargs) + return output + + def run( + self, + query: Optional[str] = None, + file_paths: Optional[List[str]] = None, + labels: Optional[MultiLabel] = None, + documents: Optional[List[Document]] = None, + meta: Optional[dict] = None, + ) -> Tuple[Dict, str]: + raise NotImplementedError("This method should never be implemented in the derived class") + + def run_batch( + self, + queries: Optional[Union[str, List[str]]] = None, + file_paths: Optional[List[str]] = None, + labels: Optional[Union[MultiLabel, List[MultiLabel]]] = None, + documents: Optional[Union[List[Document], List[List[Document]]]] = None, + meta: Optional[Union[Dict[str, Any], List[Dict[str, Any]]]] = None, + params: Optional[dict] = None, + debug: Optional[bool] = None, + ): + raise NotImplementedError("This method should never be implemented in the derived class") + + +class PromptNode(BaseComponent): + """ + The PromptNode class is the central abstraction in Haystack's large language model (LLM) support. PromptNode + supports multiple NLP tasks out of the box. 
PromptNode allows users to perform multiple tasks, such as + summarization, question answering, question generation etc., using a single, unified model within the Haystack + framework. + + One of the benefits of PromptNode is that it allows users to define and add additional prompt templates + that the model supports. Defining additional prompt templates enables users to extend the model's capabilities + and use it for a broader range of NLP tasks within the Haystack ecosystem. Prompt engineers define templates + for each NLP task and register them with PromptNode. The burden of defining templates for each task rests on + the prompt engineers, not the users. + + Using an instance of PromptModel class, we can create multiple PromptNodes that share the same model, saving + the memory and time required to load the model multiple times. + + PromptNode also supports multiple model invocation layers: Hugging Face transformers and OpenAI with an + ability to register additional custom invocation layers. + + """ + + outgoing_edges: int = 1 + prompt_templates: Dict[str, PromptTemplate] = { + prompt_template.name: prompt_template for prompt_template in PREDEFINED_PROMPT_TEMPLATES # type: ignore + } + + def __init__( + self, + model_name_or_path: Union[str, PromptModel] = "google/flan-t5-base", + default_prompt_template: Optional[Union[str, PromptTemplate]] = None, + output_variable: Optional[str] = None, + max_length: Optional[int] = 100, + api_key: Optional[str] = None, + use_auth_token: Optional[Union[str, bool]] = None, + use_gpu: Optional[bool] = None, + devices: Optional[List[Union[str, torch.device]]] = None, + ): + super().__init__() + self.default_prompt_template: Union[str, PromptTemplate, None] = default_prompt_template + self.output_variable: Optional[str] = output_variable + self.model_name_or_path: Union[str, PromptModel] = model_name_or_path + self.prompt_model: PromptModel + if isinstance(self.default_prompt_template, str) and not self.is_supported_template( + self.default_prompt_template + ): + raise ValueError( + f"Prompt template {self.default_prompt_template} is not supported. " + f"Select one of: {self.get_prompt_template_names()} " + f"or first register a new prompt template using the add_prompt_template method." + ) + + if isinstance(model_name_or_path, str): + self.prompt_model = PromptModel( + model_name_or_path=model_name_or_path, + max_length=max_length, + api_key=api_key, + use_auth_token=use_auth_token, + use_gpu=use_gpu, + devices=devices, + ) + elif isinstance(model_name_or_path, PromptModel): + self.prompt_model = model_name_or_path + else: + raise ValueError(f"model_name_or_path must be either a string or a PromptModel object") + + def __call__(self, *args, **kwargs) -> List[str]: + """ + This method is invoked when the component is called directly, for example: + ```python + PromptNode pn = ... + sa = pn.set_default_prompt_template("sentiment-analysis") + sa(documents=[Document("I am in love and I feel great!")]) + ``` + """ + if "prompt_template_name" in kwargs: + prompt_template_name = kwargs["prompt_template_name"] + kwargs.pop("prompt_template_name") + return self.prompt(prompt_template_name, *args, **kwargs) + else: + return self.prompt(self.default_prompt_template, *args, **kwargs) + + def prompt(self, prompt_template: Optional[Union[str, PromptTemplate]], *args, **kwargs) -> List[str]: + """ + Prompts the model and represents the central API for the PromptNode. 
It takes a prompt template, + a list of non-keyword and keyword arguments, and returns a list of strings - the responses from + the underlying model. + + The optional prompt_template parameter, if specified, takes precedence over the default prompt + template for this PromptNode. + + :param prompt_template: The name of the optional prompt template to use + :return: A list of strings as model responses + """ + results = [] + prompt_prepared: Dict[str, Any] = {} + if isinstance(prompt_template, str) and not self.is_supported_template(prompt_template): + raise ValueError( + f"{prompt_template} not supported, please select one of: {self.get_prompt_template_names()} " + f"or pass a PromptTemplate instance for prompting." + ) + + invoke_template = self.default_prompt_template if prompt_template is None else prompt_template + if args and invoke_template is None: + # create straightforward prompt on the input, no templates used + prompt_prepared["prompt"] = list(args) + else: + template_to_fill: PromptTemplate + if isinstance(prompt_template, PromptTemplate): + template_to_fill = prompt_template + elif isinstance(prompt_template, str): + template_to_fill = self.get_prompt_template(prompt_template) + else: + raise ValueError(f"{prompt_template} with args {args} , and kwargs {kwargs} not supported") + # we have potentially args and kwargs; task selected, so templating is needed + prompt_prepared = template_to_fill.fill(*args, **kwargs) + + # straightforward prompt, no templates used + if "prompt" in prompt_prepared: + for prompt in prompt_prepared["prompt"]: + output = self.prompt_model.invoke(prompt) + for item in output: + results.append(item) + # templated prompt + # we have a prompt dictionary with prompt_template text and key/value pairs for template variables + # where key is the variable name and value is a list of variable values + # we invoke the model iterating through a list of prompt variable values replacing the variables + # in the prompt template + elif "prompt_template" in prompt_prepared: + template = Template(prompt_prepared["prompt_template"]) + prompt_context_copy = prompt_prepared.copy() + prompt_context_copy.pop("prompt_template") + for prompt_context_values in zip(*prompt_context_copy.values()): + template_input = {key: prompt_context_values[idx] for idx, key in enumerate(prompt_context_copy.keys())} + template_prepared: str = template.substitute(template_input) + # remove template keys from kwargs so we don't pass them to the model + removed_keys = [kwargs.pop(key) for key in template_input.keys() if key in kwargs] + output = self.prompt_model.invoke(template_prepared, **kwargs) + for item in output: + results.append(item) + return results + + @classmethod + def add_prompt_template(cls, prompt_template: PromptTemplate) -> None: + """ + Adds a prompt template to the list of supported prompt templates. + :param prompt_template: PromptTemplate object to be added. + :return: None + """ + if prompt_template.name in cls.prompt_templates: + raise ValueError( + f"Prompt template {prompt_template.name} already exists " + f"Please select a different name to add this prompt template." + ) + + cls.prompt_templates[prompt_template.name] = prompt_template # type: ignore + + @classmethod + def remove_prompt_template(cls, prompt_template: str) -> PromptTemplate: + """ + Removes a prompt template from the list of supported prompt templates. + :param prompt_template: Name of the prompt template to be removed. + :return: PromptTemplate object that was removed. 
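A short sketch of registering and later removing a custom template (the template name and parameter shown here are illustrative only):

```python
PromptNode.add_prompt_template(
    PromptTemplate(name="custom-task", prompt_text="Custom task: $param1", prompt_params=["param1"])
)
removed_template = PromptNode.remove_prompt_template("custom-task")
```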
+ """ + if prompt_template in [template.name for template in PREDEFINED_PROMPT_TEMPLATES]: + raise ValueError(f"Cannot remove predefined prompt template {prompt_template}") + if prompt_template not in cls.prompt_templates: + raise ValueError(f"Prompt template {prompt_template} does not exist") + + return cls.prompt_templates.pop(prompt_template) + + def set_default_prompt_template(self, prompt_template: Union[str, PromptTemplate]) -> "PromptNode": + """ + Sets the default prompt template for the node. + :param prompt_template: the prompt template to be set as default. + :return: the current PromptNode object + """ + if not self.is_supported_template(prompt_template): + raise ValueError( + f"{prompt_template} not supported, please select one of: {self.get_prompt_template_names()}" + ) + + self.default_prompt_template = prompt_template + return self + + @classmethod + def get_prompt_templates(cls) -> List[PromptTemplate]: + """ + Returns the list of supported prompt templates. + :return: List of supported prompt templates. + """ + return list(cls.prompt_templates.values()) + + @classmethod + def get_prompt_template_names(cls) -> List[str]: + """ + Returns the list of supported prompt template names. + :return: List of supported prompt template names. + """ + return list(cls.prompt_templates.keys()) + + @classmethod + def is_supported_template(cls, prompt_template: Union[str, PromptTemplate]) -> bool: + """ + Checks if a prompt template is supported. + :param prompt_template: the prompt template to be checked. + :return: True if the prompt template is supported, False otherwise. + """ + template_name = prompt_template if isinstance(prompt_template, str) else prompt_template.name + return template_name in cls.prompt_templates + + @classmethod + def get_prompt_template(cls, prompt_template_name: str) -> PromptTemplate: + """ + Returns a prompt template by name. + :param prompt_template_name: the name of the prompt template to be returned. + :return: the prompt template object. + """ + if prompt_template_name not in cls.prompt_templates: + raise ValueError(f"Prompt template {prompt_template_name} not supported") + return cls.prompt_templates[prompt_template_name] + + @classmethod + def prompt_template_params(cls, prompt_template: str) -> List[str]: + """ + Returns the list of parameters for a prompt template. + :param prompt_template: the name of the prompt template. + :return: the list of parameters for the prompt template. + """ + if not cls.is_supported_template(prompt_template): + raise ValueError( + f"{prompt_template} not supported, please select one of: {cls.get_prompt_template_names()}" + ) + + return list(cls.prompt_templates[prompt_template].prompt_params) + + def __eq__(self, other): + if isinstance(other, PromptNode): + if self.default_prompt_template != other.default_prompt_template: + return False + return self.model_name_or_path == other.model_name_or_path + return False + + def __hash__(self): + return hash((self.default_prompt_template, self.model_name_or_path)) + + def run( + self, + query: Optional[str] = None, + file_paths: Optional[List[str]] = None, + labels: Optional[MultiLabel] = None, + documents: Optional[List[Document]] = None, + meta: Optional[dict] = None, + ) -> Tuple[Dict, str]: + """ + Runs the prompt node on these inputs parameters. Returns the output of the prompt model + Parameters file_paths, labels, and meta are usually ignored. + + :param query: the query is usually ignored by the prompt node unless it is used as a parameter in the + prompt template. 
+ :param file_paths: the file paths are usually ignored by the prompt node unless it is used as a parameter + in the prompt template. + :param labels: the labels are usually ignored by the prompt node unless it is used as a parameter in the + prompt template. + :param documents: the documents to be used for the prompt. + :param meta: the meta to be used for the prompt. Usually not used. + """ + + if not meta: + meta = {} + # invocation_context is a dictionary that is passed from a pipeline node to a pipeline node and can be used + # to pass results from a pipeline node to any other downstream pipeline node. + if "invocation_context" not in meta: + meta["invocation_context"] = {} + + results = self( + query=query, + labels=labels, + documents=[doc.content for doc in documents if isinstance(doc.content, str)] if documents else [], + **meta["invocation_context"], + ) + + if self.output_variable: + meta["invocation_context"][self.output_variable] = results + return {"results": results, "meta": {**meta}}, "output_1" + + def run_batch( + self, + queries: Optional[Union[str, List[str]]] = None, + file_paths: Optional[List[str]] = None, + labels: Optional[Union[MultiLabel, List[MultiLabel]]] = None, + documents: Optional[Union[List[Document], List[List[Document]]]] = None, + meta: Optional[Union[Dict[str, Any], List[Dict[str, Any]]]] = None, + params: Optional[dict] = None, + debug: Optional[bool] = None, + ): + pass diff --git a/haystack/pipelines/config.py b/haystack/pipelines/config.py index 364219f237..6f6d74985a 100644 --- a/haystack/pipelines/config.py +++ b/haystack/pipelines/config.py @@ -98,7 +98,8 @@ def read_pipeline_config_from_yaml(path: Path) -> Dict[str, Any]: return yaml.safe_load(stream) -JSON_FIELDS = ["custom_query"] # ElasticsearchDocumentStore.custom_query +JSON_FIELDS = ["custom_query"] +SKIP_VALIDATION_KEYS = ["prompt_text"] # PromptTemplate, PromptNode def validate_config_strings(pipeline_config: Any, is_value: bool = False): @@ -123,6 +124,8 @@ def validate_config_strings(pipeline_config: Any, is_value: bool = False): json.loads(value) except json.decoder.JSONDecodeError as e: raise PipelineConfigError(f"'{pipeline_config}' does not contain valid JSON.") + elif key in SKIP_VALIDATION_KEYS: + continue else: validate_config_strings(key) validate_config_strings(value, is_value=True) diff --git a/test/conftest.py b/test/conftest.py index fb452d61ef..253869483a 100644 --- a/test/conftest.py +++ b/test/conftest.py @@ -64,6 +64,7 @@ QuestionGenerator, ) from haystack.modeling.infer import Inferencer, QAInferencer +from haystack.nodes.prompt import PromptNode, PromptModel from haystack.schema import Document from haystack.utils.import_utils import _optional_component_not_installed @@ -1048,3 +1049,19 @@ def bert_base_squad2(request): use_fast=True, # TODO parametrize this to test slow as well ) return model + + +@pytest.fixture +def prompt_node(): + return PromptNode("google/flan-t5-small", devices=["cpu"]) + + +@pytest.fixture +def prompt_model(request): + if request.param == "openai": + api_key = os.environ.get("OPENAI_API_KEY", "KEY_NOT_FOUND") + if api_key is None or api_key == "": + api_key = "KEY_NOT_FOUND" + return PromptModel("text-davinci-003", api_key=api_key) + else: + return PromptModel("google/flan-t5-base", devices=["cpu"]) diff --git a/test/nodes/test_prompt_node.py b/test/nodes/test_prompt_node.py new file mode 100644 index 0000000000..02d95e6b0f --- /dev/null +++ b/test/nodes/test_prompt_node.py @@ -0,0 +1,477 @@ +import os + +import pytest +import torch + +from 
haystack import Document, Pipeline +from haystack.errors import OpenAIError +from haystack.nodes.prompt import PromptTemplate, PromptNode, PromptModel + + +def is_openai_api_key_set(api_key: str): + return len(api_key) > 0 and api_key != "KEY_NOT_FOUND" + + +def test_prompt_templates(): + p = PromptTemplate("t1", "Here is some fake template with variable $foo", ["foo"]) + + with pytest.raises(ValueError, match="Number of parameters in"): + PromptTemplate("t2", "Here is some fake template with variable $foo and $bar", ["foo"]) + + with pytest.raises(ValueError, match="Invalid parameter"): + PromptTemplate("t2", "Here is some fake template with variable $footur", ["foo"]) + + with pytest.raises(ValueError, match="Number of parameters in"): + PromptTemplate("t2", "Here is some fake template with variable $foo and $bar", ["foo", "bar", "baz"]) + + p = PromptTemplate("t3", "Here is some fake template with variable $for and $bar", ["for", "bar"]) + + # last parameter: "prompt_params" can be omitted + p = PromptTemplate("t4", "Here is some fake template with variable $foo and $bar") + assert p.prompt_params == ["foo", "bar"] + + p = PromptTemplate("t4", "Here is some fake template with variable $foo1 and $bar2") + assert p.prompt_params == ["foo1", "bar2"] + + p = PromptTemplate("t4", "Here is some fake template with variable $foo_1 and $bar_2") + assert p.prompt_params == ["foo_1", "bar_2"] + + p = PromptTemplate("t4", "Here is some fake template with variable $Foo_1 and $Bar_2") + assert p.prompt_params == ["Foo_1", "Bar_2"] + + p = PromptTemplate("t4", "'Here is some fake template with variable $baz'") + assert p.prompt_params == ["baz"] + # strip single quotes, happens in YAML as we need to use single quotes for the template string + assert p.prompt_text == "Here is some fake template with variable $baz" + + p = PromptTemplate("t4", '"Here is some fake template with variable $baz"') + assert p.prompt_params == ["baz"] + # strip double quotes, happens in YAML as we need to use single quotes for the template string + assert p.prompt_text == "Here is some fake template with variable $baz" + + +def test_create_prompt_model(): + model = PromptModel("google/flan-t5-small") + assert model.model_name_or_path == "google/flan-t5-small" + + model = PromptModel() + assert model.model_name_or_path == "google/flan-t5-base" + + with pytest.raises(OpenAIError): + # davinci selected but no API key provided + model = PromptModel("text-davinci-003") + + model = PromptModel("text-davinci-003", api_key="no need to provide a real key") + assert model.model_name_or_path == "text-davinci-003" + + with pytest.raises(ValueError, match="Model some-random-model is not supported"): + PromptModel("some-random-model") + + # we can also pass model kwargs to the PromptModel + model = PromptModel("google/flan-t5-small", model_kwargs={"model_kwargs": {"torch_dtype": torch.bfloat16}}) + assert model.model_name_or_path == "google/flan-t5-small" + + # we can also pass kwargs directly, see HF Pipeline constructor + model = PromptModel("google/flan-t5-small", model_kwargs={"torch_dtype": torch.bfloat16}) + assert model.model_name_or_path == "google/flan-t5-small" + + # we can't use device_map auto without accelerate library installed + with pytest.raises(ImportError, match="requires Accelerate: `pip install accelerate`"): + model = PromptModel("google/flan-t5-small", model_kwargs={"device_map": "auto"}) + assert model.model_name_or_path == "google/flan-t5-small" + + +def test_create_prompt_node(): + prompt_node = PromptNode() + 
assert prompt_node is not None + assert prompt_node.prompt_model is not None + + prompt_node = PromptNode("google/flan-t5-small") + assert prompt_node is not None + assert prompt_node.model_name_or_path == "google/flan-t5-small" + assert prompt_node.prompt_model is not None + + with pytest.raises(OpenAIError): + # davinci selected but no API key provided + prompt_node = PromptNode("text-davinci-003") + + prompt_node = PromptNode("text-davinci-003", api_key="no need to provide a real key") + assert prompt_node is not None + assert prompt_node.model_name_or_path == "text-davinci-003" + assert prompt_node.prompt_model is not None + + with pytest.raises(ValueError, match="Model vblagoje/bart_lfqa is not supported"): + # yes vblagoje/bart_lfqa is AutoModelForSeq2SeqLM, can be downloaded, however it is useless for prompting + # currently support only T5-Flan models + prompt_node = PromptNode("vblagoje/bart_lfqa") + + with pytest.raises(ValueError, match="Model valhalla/t5-base-e2e-qg is not supported"): + # yes valhalla/t5-base-e2e-qg is AutoModelForSeq2SeqLM, can be downloaded, however it is useless for prompting + # currently support only T5-Flan models + prompt_node = PromptNode("valhalla/t5-base-e2e-qg") + + with pytest.raises(ValueError, match="Model some-random-model is not supported"): + PromptNode("some-random-model") + + +def test_add_and_remove_template(prompt_node): + num_default_tasks = len(prompt_node.get_prompt_template_names()) + custom_task = PromptTemplate( + name="custom-task", prompt_text="Custom task: $param1, $param2", prompt_params=["param1", "param2"] + ) + prompt_node.add_prompt_template(custom_task) + assert len(prompt_node.get_prompt_template_names()) == num_default_tasks + 1 + assert "custom-task" in prompt_node.get_prompt_template_names() + + assert prompt_node.remove_prompt_template("custom-task") is not None + assert "custom-task" not in prompt_node.get_prompt_template_names() + + +def test_invalid_template(prompt_node): + with pytest.raises(ValueError, match="Invalid parameter"): + PromptTemplate( + name="custom-task", prompt_text="Custom task: $pram1 $param2", prompt_params=["param1", "param2"] + ) + + with pytest.raises(ValueError, match="Number of parameters"): + PromptTemplate(name="custom-task", prompt_text="Custom task: $param1", prompt_params=["param1", "param2"]) + + +def test_add_template_and_invoke(prompt_node): + tt = PromptTemplate( + name="sentiment-analysis-new", + prompt_text="Please give a sentiment for this context. Answer with positive, " + "negative or neutral. Context: $documents; Answer:", + prompt_params=["documents"], + ) + prompt_node.add_prompt_template(tt) + + r = prompt_node.prompt("sentiment-analysis-new", documents=["Berlin is an amazing city."]) + assert r[0].casefold() == "positive" + + +def test_on_the_fly_prompt(prompt_node): + tt = PromptTemplate( + name="sentiment-analysis-temp", + prompt_text="Please give a sentiment for this context. Answer with positive, " + "negative or neutral. 
Context: $documents; Answer:", + prompt_params=["documents"], + ) + r = prompt_node.prompt(tt, documents=["Berlin is an amazing city."]) + assert r[0].casefold() == "positive" + + +def test_direct_prompting(prompt_node): + r = prompt_node("What is the capital of Germany?") + assert r[0].casefold() == "berlin" + + r = prompt_node("What is the capital of Germany?", "What is the secret of universe?") + assert r[0].casefold() == "berlin" + assert len(r[1]) > 0 + + r = prompt_node("Capital of Germany is Berlin", task="question-generation") + assert len(r[0]) > 10 and "Germany" in r[0] + + r = prompt_node(["Capital of Germany is Berlin", "Capital of France is Paris"], task="question-generation") + assert len(r) == 2 + + +def test_question_generation(prompt_node): + r = prompt_node.prompt("question-generation", documents=["Berlin is the capital of Germany."]) + assert len(r) == 1 and len(r[0]) > 0 + + +def test_template_selection(prompt_node): + qa = prompt_node.set_default_prompt_template("question-answering") + r = qa( + ["Berlin is the capital of Germany.", "Paris is the capital of France."], + ["What is the capital of Germany?", "What is the capital of France"], + ) + assert r[0].casefold() == "berlin" and r[1].casefold() == "paris" + + +def test_has_supported_template_names(prompt_node): + assert len(prompt_node.get_prompt_template_names()) > 0 + + +def test_invalid_template_params(prompt_node): + with pytest.raises(ValueError, match="Expected prompt params"): + prompt_node.prompt("question-answering", {"some_crazy_key": "Berlin is the capital of Germany."}) + + +def test_wrong_template_params(prompt_node): + with pytest.raises(ValueError, match="Expected prompt params"): + # with don't have options param, multiple choice QA has + prompt_node.prompt("question-answering", options=["Berlin is the capital of Germany."]) + + +def test_run_invalid_template(prompt_node): + with pytest.raises(ValueError, match="invalid-task not supported"): + prompt_node.prompt("invalid-task", {}) + + +def test_invalid_prompting(prompt_node): + with pytest.raises(ValueError, match="Hey there, what is the best city in the worl"): + prompt_node.prompt( + "Hey there, what is the best city in the world?" "Hey there, what is the best city in the world?" 
+ ) + + with pytest.raises(ValueError, match="Hey there, what is the best city in the"): + prompt_node.prompt(["Hey there, what is the best city in the world?", "Hey, answer me!"]) + + +def test_invalid_state_ops(prompt_node): + with pytest.raises(ValueError, match="Prompt template no_such_task_exists"): + prompt_node.remove_prompt_template("no_such_task_exists") + # remove default task + prompt_node.remove_prompt_template("question-answering") + + +@pytest.mark.skipif( + not os.environ.get("OPENAI_API_KEY", None), + reason="Please export an env var called OPENAI_API_KEY containing the OpenAI API key to run this test.", +) +def test_open_ai_prompt_with_params(): + pm = PromptModel("text-davinci-003", api_key=os.environ["OPENAI_API_KEY"]) + pn = PromptNode(pm) + optional_davinci_params = {"temperature": 0.5, "max_tokens": 10, "top_p": 1, "frequency_penalty": 0.5} + r = pn.prompt("question-generation", documents=["Berlin is the capital of Germany."], **optional_davinci_params) + assert len(r) == 1 and len(r[0]) > 0 + + +@pytest.mark.parametrize("prompt_model", ["hf", "openai"], indirect=True) +def test_simple_pipeline(prompt_model): + if prompt_model.api_key is not None and not is_openai_api_key_set(prompt_model.api_key): + pytest.skip("No API key found for OpenAI, skipping test") + + node = PromptNode(prompt_model, default_prompt_template="sentiment-analysis") + + pipe = Pipeline() + pipe.add_node(component=node, name="prompt_node", inputs=["Query"]) + result = pipe.run(query="not relevant", documents=[Document("Berlin is an amazing city.")]) + assert result["results"][0].casefold() == "positive" + + +@pytest.mark.parametrize("prompt_model", ["hf", "openai"], indirect=True) +def test_complex_pipeline(prompt_model): + if prompt_model.api_key is not None and not is_openai_api_key_set(prompt_model.api_key): + pytest.skip("No API key found for OpenAI, skipping test") + + node = PromptNode(prompt_model, default_prompt_template="question-generation", output_variable="questions") + node2 = PromptNode(prompt_model, default_prompt_template="question-answering") + + pipe = Pipeline() + pipe.add_node(component=node, name="prompt_node", inputs=["Query"]) + pipe.add_node(component=node2, name="prompt_node_2", inputs=["prompt_node"]) + result = pipe.run(query="not relevant", documents=[Document("Berlin is the capital of Germany")]) + + assert "berlin" in result["results"][0].casefold() + + +def test_complex_pipeline_with_shared_model(): + model = PromptModel() + node = PromptNode( + model_name_or_path=model, default_prompt_template="question-generation", output_variable="questions" + ) + node2 = PromptNode(model_name_or_path=model, default_prompt_template="question-answering") + + pipe = Pipeline() + pipe.add_node(component=node, name="prompt_node", inputs=["Query"]) + pipe.add_node(component=node2, name="prompt_node_2", inputs=["prompt_node"]) + result = pipe.run(query="not relevant", documents=[Document("Berlin is the capital of Germany")]) + + assert result["results"][0] == "Berlin" + + +def test_simple_pipeline_yaml(tmp_path): + with open(tmp_path / "tmp_config.yml", "w") as tmp_file: + tmp_file.write( + f""" + version: ignore + components: + - name: p1 + params: + default_prompt_template: sentiment-analysis + type: PromptNode + pipelines: + - name: query + nodes: + - name: p1 + inputs: + - Query + """ + ) + pipeline = Pipeline.load_from_yaml(path=tmp_path / "tmp_config.yml") + result = pipeline.run(query="not relevant", documents=[Document("Berlin is an amazing city.")]) + assert 
result["results"][0] == "positive" + + +def test_complex_pipeline_yaml(tmp_path): + with open(tmp_path / "tmp_config.yml", "w") as tmp_file: + tmp_file.write( + f""" + version: ignore + components: + - name: p1 + params: + default_prompt_template: question-generation + output_variable: questions + type: PromptNode + - name: p2 + params: + default_prompt_template: question-answering + type: PromptNode + pipelines: + - name: query + nodes: + - name: p1 + inputs: + - Query + - name: p2 + inputs: + - p1 + """ + ) + pipeline = Pipeline.load_from_yaml(path=tmp_path / "tmp_config.yml") + result = pipeline.run(query="not relevant", documents=[Document("Berlin is an amazing city.")]) + assert result["results"][0] == "Berlin" + assert len(result["meta"]["invocation_context"]) > 0 + + +def test_complex_pipeline_with_shared_prompt_model_yaml(tmp_path): + with open(tmp_path / "tmp_config.yml", "w") as tmp_file: + tmp_file.write( + f""" + version: ignore + components: + - name: pmodel + type: PromptModel + - name: p1 + params: + model_name_or_path: pmodel + default_prompt_template: question-generation + output_variable: questions + type: PromptNode + - name: p2 + params: + model_name_or_path: pmodel + default_prompt_template: question-answering + type: PromptNode + pipelines: + - name: query + nodes: + - name: p1 + inputs: + - Query + - name: p2 + inputs: + - p1 + """ + ) + pipeline = Pipeline.load_from_yaml(path=tmp_path / "tmp_config.yml") + result = pipeline.run(query="not relevant", documents=[Document("Berlin is an amazing city.")]) + assert "Berlin" in result["results"][0] + assert len(result["meta"]["invocation_context"]) > 0 + + +def test_complex_pipeline_with_shared_prompt_model_and_prompt_template_yaml(tmp_path): + with open(tmp_path / "tmp_config_with_prompt_template.yml", "w") as tmp_file: + tmp_file.write( + f""" + version: ignore + components: + - name: pmodel + type: PromptModel + params: + model_name_or_path: google/flan-t5-small + model_kwargs: + torch_dtype: torch.bfloat16 + - name: question_generation_template + type: PromptTemplate + params: + name: question-generation-new + prompt_text: "Given the context please generate a question. 
Context: $documents; Question:" + - name: p1 + params: + model_name_or_path: pmodel + default_prompt_template: question_generation_template + output_variable: questions + type: PromptNode + - name: p2 + params: + model_name_or_path: pmodel + default_prompt_template: question-answering + type: PromptNode + pipelines: + - name: query + nodes: + - name: p1 + inputs: + - Query + - name: p2 + inputs: + - p1 + """ + ) + pipeline = Pipeline.load_from_yaml(path=tmp_path / "tmp_config_with_prompt_template.yml") + result = pipeline.run(query="not relevant", documents=[Document("Berlin is an amazing city.")]) + assert "Berlin" in result["results"][0] + assert len(result["meta"]["invocation_context"]) > 0 + + +@pytest.mark.skipif( + not os.environ.get("OPENAI_API_KEY", None), + reason="Please export an env var called OPENAI_API_KEY containing the OpenAI API key to run this test.", +) +def test_complex_pipeline_with_all_features(tmp_path): + api_key = os.environ.get("OPENAI_API_KEY", None) + with open(tmp_path / "tmp_config_with_prompt_template.yml", "w") as tmp_file: + tmp_file.write( + f""" + version: ignore + components: + - name: pmodel + type: PromptModel + params: + model_name_or_path: google/flan-t5-small + model_kwargs: + torch_dtype: torch.bfloat16 + - name: pmodel_openai + type: PromptModel + params: + model_name_or_path: text-davinci-003 + model_kwargs: + temperature: 0.9 + max_tokens: 64 + api_key: {api_key} + - name: question_generation_template + type: PromptTemplate + params: + name: question-generation-new + prompt_text: "Given the context please generate a question. Context: $documents; Question:" + - name: p1 + params: + model_name_or_path: pmodel_openai + default_prompt_template: question_generation_template + output_variable: questions + type: PromptNode + - name: p2 + params: + model_name_or_path: pmodel + default_prompt_template: question-answering + type: PromptNode + pipelines: + - name: query + nodes: + - name: p1 + inputs: + - Query + - name: p2 + inputs: + - p1 + """ + ) + pipeline = Pipeline.load_from_yaml(path=tmp_path / "tmp_config_with_prompt_template.yml") + result = pipeline.run(query="not relevant", documents=[Document("Berlin is a city in Germany.")]) + assert "Berlin" in result["results"][0] or "Germany" in result["results"][0] + assert len(result["meta"]["invocation_context"]) > 0