diff --git a/docs/pydoc/config/generators_api.yml b/docs/pydoc/config/generators_api.yml index 4fcfa74879..fa1f638e99 100644 --- a/docs/pydoc/config/generators_api.yml +++ b/docs/pydoc/config/generators_api.yml @@ -7,6 +7,7 @@ loaders: "hugging_face_local", "hugging_face_api", "openai", + "openai_dalle", "chat/azure", "chat/hugging_face_local", "chat/hugging_face_api", diff --git a/haystack/components/generators/__init__.py b/haystack/components/generators/__init__.py index b93270d82f..952c2dadd2 100644 --- a/haystack/components/generators/__init__.py +++ b/haystack/components/generators/__init__.py @@ -8,5 +8,12 @@ from haystack.components.generators.azure import AzureOpenAIGenerator from haystack.components.generators.hugging_face_local import HuggingFaceLocalGenerator from haystack.components.generators.hugging_face_api import HuggingFaceAPIGenerator +from haystack.components.generators.openai_dalle import DALLEImageGenerator -__all__ = ["HuggingFaceLocalGenerator", "HuggingFaceAPIGenerator", "OpenAIGenerator", "AzureOpenAIGenerator"] +__all__ = [ + "HuggingFaceLocalGenerator", + "HuggingFaceAPIGenerator", + "OpenAIGenerator", + "AzureOpenAIGenerator", + "DALLEImageGenerator", +] diff --git a/haystack/components/generators/openai_dalle.py b/haystack/components/generators/openai_dalle.py new file mode 100644 index 0000000000..8b60c13c8b --- /dev/null +++ b/haystack/components/generators/openai_dalle.py @@ -0,0 +1,159 @@ +# SPDX-FileCopyrightText: 2022-present deepset GmbH +# +# SPDX-License-Identifier: Apache-2.0 + +import os +from typing import Any, Dict, List, Literal, Optional + +from openai import OpenAI +from openai.types.image import Image + +from haystack import component, default_from_dict, default_to_dict, logging +from haystack.utils import Secret, deserialize_secrets_inplace + +logger = logging.getLogger(__name__) + + +@component +class DALLEImageGenerator: + """ + Generates images using OpenAI's DALL-E model. + + For details on OpenAI API parameters, see + [OpenAI documentation](https://platform.openai.com/docs/api-reference/images/create). + + ### Usage example + + ```python + from haystack.components.generators import DALLEImageGenerator + image_generator = DALLEImageGenerator() + response = image_generator.run("Show me a picture of a black cat.") + print(response) + ``` + """ + + def __init__( # pylint: disable=too-many-positional-arguments + self, + model: str = "dall-e-3", + quality: Literal["standard", "hd"] = "standard", + size: Literal["256x256", "512x512", "1024x1024", "1792x1024", "1024x1792"] = "1024x1024", + response_format: Literal["url", "b64_json"] = "url", + api_key: Secret = Secret.from_env_var("OPENAI_API_KEY"), + api_base_url: Optional[str] = None, + organization: Optional[str] = None, + timeout: Optional[float] = None, + max_retries: Optional[int] = None, + ): + """ + Creates an instance of DALLEImageGenerator. Unless specified otherwise in `model`, uses OpenAI's dall-e-3. + + :param model: The model to use for image generation. Can be "dall-e-2" or "dall-e-3". + :param quality: The quality of the generated image. Can be "standard" or "hd". + :param size: The size of the generated images. + Must be one of 256x256, 512x512, or 1024x1024 for dall-e-2. + Must be one of 1024x1024, 1792x1024, or 1024x1792 for dall-e-3 models. + :param response_format: The format of the response. Can be "url" or "b64_json". + :param api_key: The OpenAI API key to connect to OpenAI. + :param api_base_url: An optional base URL. + :param organization: The Organization ID, defaults to `None`. + :param timeout: + Timeout for OpenAI Client calls. If not set, it is inferred from the `OPENAI_TIMEOUT` environment variable + or set to 30. + :param max_retries: + Maximum retries to establish contact with OpenAI if it returns an internal error. If not set, it is inferred + from the `OPENAI_MAX_RETRIES` environment variable or set to 5. + """ + self.model = model + self.quality = quality + self.size = size + self.response_format = response_format + self.api_key = api_key + self.api_base_url = api_base_url + self.organization = organization + + self.timeout = timeout or float(os.environ.get("OPENAI_TIMEOUT", 30.0)) + self.max_retries = max_retries or int(os.environ.get("OPENAI_MAX_RETRIES", 5)) + + self.client: Optional[OpenAI] = None + + def warm_up(self) -> None: + """ + Warm up the OpenAI client. + """ + if self.client is None: + self.client = OpenAI( + api_key=self.api_key.resolve_value(), + organization=self.organization, + base_url=self.api_base_url, + timeout=self.timeout, + max_retries=self.max_retries, + ) + + @component.output_types(images=List[str], revised_prompt=str) + def run( + self, + prompt: str, + size: Optional[Literal["256x256", "512x512", "1024x1024", "1792x1024", "1024x1792"]] = None, + quality: Optional[Literal["standard", "hd"]] = None, + response_format: Optional[Optional[Literal["url", "b64_json"]]] = None, + ): + """ + Invokes the image generation inference based on the provided prompt and generation parameters. + + :param prompt: The prompt to generate the image. + :param size: If provided, overrides the size provided during initialization. + :param quality: If provided, overrides the quality provided during initialization. + :param response_format: If provided, overrides the response format provided during initialization. + + :returns: + A dictionary containing the generated list of images and the revised prompt. + Depending on the `response_format` parameter, the list of images can be URLs or base64 encoded JSON strings. + The revised prompt is the prompt that was used to generate the image, if there was any revision + to the prompt made by OpenAI. + """ + if self.client is None: + raise RuntimeError( + "The component DALLEImageGenerator wasn't warmed up. Run 'warm_up()' before calling 'run()'." + ) + + size = size or self.size + quality = quality or self.quality + response_format = response_format or self.response_format + response = self.client.images.generate( + model=self.model, prompt=prompt, size=size, quality=quality, response_format=response_format, n=1 + ) + image: Image = response.data[0] + image_str = image.url or image.b64_json or "" + return {"images": [image_str], "revised_prompt": image.revised_prompt or ""} + + def to_dict(self) -> Dict[str, Any]: + """ + Serialize this component to a dictionary. + + :returns: + The serialized component as a dictionary. + """ + return default_to_dict( # type: ignore + self, + model=self.model, + quality=self.quality, + size=self.size, + response_format=self.response_format, + api_key=self.api_key.to_dict(), + api_base_url=self.api_base_url, + organization=self.organization, + ) + + @classmethod + def from_dict(cls, data: Dict[str, Any]) -> "DALLEImageGenerator": + """ + Deserialize this component from a dictionary. + + :param data: + The dictionary representation of this component. + :returns: + The deserialized component instance. + """ + init_params = data.get("init_parameters", {}) + deserialize_secrets_inplace(init_params, keys=["api_key"]) + return default_from_dict(cls, data) # type: ignore diff --git a/releasenotes/notes/add-dalle-image-generator-495aa11823e11a60.yaml b/releasenotes/notes/add-dalle-image-generator-495aa11823e11a60.yaml new file mode 100644 index 0000000000..71e93d5c30 --- /dev/null +++ b/releasenotes/notes/add-dalle-image-generator-495aa11823e11a60.yaml @@ -0,0 +1,12 @@ +--- +features: + - | + We've added a new **DALLEImageGenerator** component, bringing image generation with OpenAI's DALL-E to the Haystack + + - **Easy to Use**: Just a few lines of code to get started: + ```python + from haystack.components.generators import DALLEImageGenerator + image_generator = DALLEImageGenerator() + response = image_generator.run("Show me a picture of a black cat.") + print(response) + ``` diff --git a/test/components/generators/test_openai_dalle.py b/test/components/generators/test_openai_dalle.py new file mode 100644 index 0000000000..0c319d020a --- /dev/null +++ b/test/components/generators/test_openai_dalle.py @@ -0,0 +1,147 @@ +# SPDX-FileCopyrightText: 2022-present deepset GmbH +# +# SPDX-License-Identifier: Apache-2.0 + +import pytest +from unittest.mock import Mock, patch +from haystack.utils import Secret + +from openai.types.image import Image +from openai.types import ImagesResponse +from haystack.components.generators.openai_dalle import DALLEImageGenerator + + +@pytest.fixture +def mock_image_response(): + with patch("openai.resources.images.Images.generate") as mock_image_generate: + image_response = ImagesResponse(created=1630000000, data=[Image(url="test-url", revised_prompt="test-prompt")]) + mock_image_generate.return_value = image_response + yield mock_image_generate + + +class TestDALLEImageGenerator: + def test_init_default(self, monkeypatch): + component = DALLEImageGenerator() + assert component.model == "dall-e-3" + assert component.quality == "standard" + assert component.size == "1024x1024" + assert component.response_format == "url" + assert component.api_key == Secret.from_env_var("OPENAI_API_KEY") + assert component.api_base_url is None + assert component.organization is None + assert pytest.approx(component.timeout) == 30.0 + assert component.max_retries is 5 + + def test_init_with_params(self, monkeypatch): + component = DALLEImageGenerator( + model="dall-e-2", + quality="hd", + size="256x256", + response_format="b64_json", + api_key=Secret.from_env_var("EXAMPLE_API_KEY"), + api_base_url="https://api.openai.com", + organization="test-org", + timeout=60, + max_retries=10, + ) + assert component.model == "dall-e-2" + assert component.quality == "hd" + assert component.size == "256x256" + assert component.response_format == "b64_json" + assert component.api_key == Secret.from_env_var("EXAMPLE_API_KEY") + assert component.api_base_url == "https://api.openai.com" + assert component.organization == "test-org" + assert pytest.approx(component.timeout) == 60.0 + assert component.max_retries == 10 + + def test_warm_up(self, monkeypatch): + monkeypatch.setenv("OPENAI_API_KEY", "test-api-key") + component = DALLEImageGenerator() + component.warm_up() + assert component.client.api_key == "test-api-key" + assert component.client.timeout == 30 + assert component.client.max_retries == 5 + + def test_to_dict(self): + generator = DALLEImageGenerator() + data = generator.to_dict() + assert data == { + "type": "haystack.components.generators.openai_dalle.DALLEImageGenerator", + "init_parameters": { + "model": "dall-e-3", + "quality": "standard", + "size": "1024x1024", + "response_format": "url", + "api_key": {"type": "env_var", "env_vars": ["OPENAI_API_KEY"], "strict": True}, + "api_base_url": None, + "organization": None, + }, + } + + def test_to_dict_with_params(self): + generator = DALLEImageGenerator( + model="dall-e-2", + quality="hd", + size="256x256", + response_format="b64_json", + api_key=Secret.from_env_var("EXAMPLE_API_KEY"), + api_base_url="https://api.openai.com", + organization="test-org", + timeout=60, + max_retries=10, + ) + data = generator.to_dict() + assert data == { + "type": "haystack.components.generators.openai_dalle.DALLEImageGenerator", + "init_parameters": { + "model": "dall-e-2", + "quality": "hd", + "size": "256x256", + "response_format": "b64_json", + "api_key": {"type": "env_var", "env_vars": ["EXAMPLE_API_KEY"], "strict": True}, + "api_base_url": "https://api.openai.com", + "organization": "test-org", + }, + } + + def test_from_dict(self): + data = { + "type": "haystack.components.generators.openai_dalle.DALLEImageGenerator", + "init_parameters": { + "model": "dall-e-3", + "quality": "standard", + "size": "1024x1024", + "response_format": "url", + "api_key": {"type": "env_var", "env_vars": ["OPENAI_API_KEY"], "strict": True}, + "api_base_url": None, + "organization": None, + }, + } + generator = DALLEImageGenerator.from_dict(data) + assert generator.model == "dall-e-3" + assert generator.quality == "standard" + assert generator.size == "1024x1024" + assert generator.response_format == "url" + assert generator.api_key.to_dict() == {"type": "env_var", "env_vars": ["OPENAI_API_KEY"], "strict": True} + + def test_from_dict_default_params(self): + data = {"type": "haystack.components.generators.openai_dalle.DALLEImageGenerator", "init_parameters": {}} + generator = DALLEImageGenerator.from_dict(data) + assert generator.model == "dall-e-3" + assert generator.quality == "standard" + assert generator.size == "1024x1024" + assert generator.response_format == "url" + assert generator.api_key.to_dict() == {"type": "env_var", "env_vars": ["OPENAI_API_KEY"], "strict": True} + assert generator.api_base_url is None + assert generator.organization is None + assert pytest.approx(generator.timeout) == 30.0 + assert generator.max_retries == 5 + + def test_run(self, mock_image_response): + generator = DALLEImageGenerator(api_key=Secret.from_token("test-api-key")) + generator.warm_up() + response = generator.run("Show me a picture of a black cat.") + assert isinstance(response, dict) + assert "images" in response and "revised_prompt" in response + assert response["images"] == ["test-url"] + assert response["revised_prompt"] == "test-prompt"