Skip to content

Commit

Permalink
feat: Adding DALLE image generator (#8448)
Browse files Browse the repository at this point in the history
* First pass at adding DALLE image generator

* Add missing header

* Fix tests

* Add tests

* Fix mypy

* Make mypy happy

* More unit tests

* Adding release notes

* Add a test for run

* Update haystack/components/generators/openai_dalle.py

Co-authored-by: Silvano Cerza <[email protected]>

* Fix pylint

* Update haystack/components/generators/openai_dalle.py

Co-authored-by: Amna Mubashar <[email protected]>

* Update haystack/components/generators/openai_dalle.py

Co-authored-by: Daria Fokina <[email protected]>

* Update haystack/components/generators/openai_dalle.py

Co-authored-by: Daria Fokina <[email protected]>

* Update haystack/components/generators/openai_dalle.py

Co-authored-by: Daria Fokina <[email protected]>

* Update haystack/components/generators/openai_dalle.py

Co-authored-by: Daria Fokina <[email protected]>

* Update haystack/components/generators/openai_dalle.py

Co-authored-by: Daria Fokina <[email protected]>

---------

Co-authored-by: Silvano Cerza <[email protected]>
Co-authored-by: Amna Mubashar <[email protected]>
Co-authored-by: Daria Fokina <[email protected]>
  • Loading branch information
4 people authored Nov 14, 2024
1 parent a045c0e commit e45d332
Show file tree
Hide file tree
Showing 5 changed files with 327 additions and 1 deletion.
1 change: 1 addition & 0 deletions docs/pydoc/config/generators_api.yml
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ loaders:
"hugging_face_local",
"hugging_face_api",
"openai",
"openai_dalle",
"chat/azure",
"chat/hugging_face_local",
"chat/hugging_face_api",
Expand Down
9 changes: 8 additions & 1 deletion haystack/components/generators/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,5 +8,12 @@
from haystack.components.generators.azure import AzureOpenAIGenerator
from haystack.components.generators.hugging_face_local import HuggingFaceLocalGenerator
from haystack.components.generators.hugging_face_api import HuggingFaceAPIGenerator
from haystack.components.generators.openai_dalle import DALLEImageGenerator

__all__ = ["HuggingFaceLocalGenerator", "HuggingFaceAPIGenerator", "OpenAIGenerator", "AzureOpenAIGenerator"]
__all__ = [
"HuggingFaceLocalGenerator",
"HuggingFaceAPIGenerator",
"OpenAIGenerator",
"AzureOpenAIGenerator",
"DALLEImageGenerator",
]
159 changes: 159 additions & 0 deletions haystack/components/generators/openai_dalle.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,159 @@
# SPDX-FileCopyrightText: 2022-present deepset GmbH <[email protected]>
#
# SPDX-License-Identifier: Apache-2.0

import os
from typing import Any, Dict, List, Literal, Optional

from openai import OpenAI
from openai.types.image import Image

from haystack import component, default_from_dict, default_to_dict, logging
from haystack.utils import Secret, deserialize_secrets_inplace

logger = logging.getLogger(__name__)


@component
class DALLEImageGenerator:
"""
Generates images using OpenAI's DALL-E model.
For details on OpenAI API parameters, see
[OpenAI documentation](https://platform.openai.com/docs/api-reference/images/create).
### Usage example
```python
from haystack.components.generators import DALLEImageGenerator
image_generator = DALLEImageGenerator()
response = image_generator.run("Show me a picture of a black cat.")
print(response)
```
"""

def __init__( # pylint: disable=too-many-positional-arguments
self,
model: str = "dall-e-3",
quality: Literal["standard", "hd"] = "standard",
size: Literal["256x256", "512x512", "1024x1024", "1792x1024", "1024x1792"] = "1024x1024",
response_format: Literal["url", "b64_json"] = "url",
api_key: Secret = Secret.from_env_var("OPENAI_API_KEY"),
api_base_url: Optional[str] = None,
organization: Optional[str] = None,
timeout: Optional[float] = None,
max_retries: Optional[int] = None,
):
"""
Creates an instance of DALLEImageGenerator. Unless specified otherwise in `model`, uses OpenAI's dall-e-3.
:param model: The model to use for image generation. Can be "dall-e-2" or "dall-e-3".
:param quality: The quality of the generated image. Can be "standard" or "hd".
:param size: The size of the generated images.
Must be one of 256x256, 512x512, or 1024x1024 for dall-e-2.
Must be one of 1024x1024, 1792x1024, or 1024x1792 for dall-e-3 models.
:param response_format: The format of the response. Can be "url" or "b64_json".
:param api_key: The OpenAI API key to connect to OpenAI.
:param api_base_url: An optional base URL.
:param organization: The Organization ID, defaults to `None`.
:param timeout:
Timeout for OpenAI Client calls. If not set, it is inferred from the `OPENAI_TIMEOUT` environment variable
or set to 30.
:param max_retries:
Maximum retries to establish contact with OpenAI if it returns an internal error. If not set, it is inferred
from the `OPENAI_MAX_RETRIES` environment variable or set to 5.
"""
self.model = model
self.quality = quality
self.size = size
self.response_format = response_format
self.api_key = api_key
self.api_base_url = api_base_url
self.organization = organization

self.timeout = timeout or float(os.environ.get("OPENAI_TIMEOUT", 30.0))
self.max_retries = max_retries or int(os.environ.get("OPENAI_MAX_RETRIES", 5))

self.client: Optional[OpenAI] = None

def warm_up(self) -> None:
"""
Warm up the OpenAI client.
"""
if self.client is None:
self.client = OpenAI(
api_key=self.api_key.resolve_value(),
organization=self.organization,
base_url=self.api_base_url,
timeout=self.timeout,
max_retries=self.max_retries,
)

@component.output_types(images=List[str], revised_prompt=str)
def run(
self,
prompt: str,
size: Optional[Literal["256x256", "512x512", "1024x1024", "1792x1024", "1024x1792"]] = None,
quality: Optional[Literal["standard", "hd"]] = None,
response_format: Optional[Optional[Literal["url", "b64_json"]]] = None,
):
"""
Invokes the image generation inference based on the provided prompt and generation parameters.
:param prompt: The prompt to generate the image.
:param size: If provided, overrides the size provided during initialization.
:param quality: If provided, overrides the quality provided during initialization.
:param response_format: If provided, overrides the response format provided during initialization.
:returns:
A dictionary containing the generated list of images and the revised prompt.
Depending on the `response_format` parameter, the list of images can be URLs or base64 encoded JSON strings.
The revised prompt is the prompt that was used to generate the image, if there was any revision
to the prompt made by OpenAI.
"""
if self.client is None:
raise RuntimeError(
"The component DALLEImageGenerator wasn't warmed up. Run 'warm_up()' before calling 'run()'."
)

size = size or self.size
quality = quality or self.quality
response_format = response_format or self.response_format
response = self.client.images.generate(
model=self.model, prompt=prompt, size=size, quality=quality, response_format=response_format, n=1
)
image: Image = response.data[0]
image_str = image.url or image.b64_json or ""
return {"images": [image_str], "revised_prompt": image.revised_prompt or ""}

def to_dict(self) -> Dict[str, Any]:
"""
Serialize this component to a dictionary.
:returns:
The serialized component as a dictionary.
"""
return default_to_dict( # type: ignore
self,
model=self.model,
quality=self.quality,
size=self.size,
response_format=self.response_format,
api_key=self.api_key.to_dict(),
api_base_url=self.api_base_url,
organization=self.organization,
)

@classmethod
def from_dict(cls, data: Dict[str, Any]) -> "DALLEImageGenerator":
"""
Deserialize this component from a dictionary.
:param data:
The dictionary representation of this component.
:returns:
The deserialized component instance.
"""
init_params = data.get("init_parameters", {})
deserialize_secrets_inplace(init_params, keys=["api_key"])
return default_from_dict(cls, data) # type: ignore
12 changes: 12 additions & 0 deletions releasenotes/notes/add-dalle-image-generator-495aa11823e11a60.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
---
features:
- |
We've added a new **DALLEImageGenerator** component, bringing image generation with OpenAI's DALL-E to the Haystack
- **Easy to Use**: Just a few lines of code to get started:
```python
from haystack.components.generators import DALLEImageGenerator
image_generator = DALLEImageGenerator()
response = image_generator.run("Show me a picture of a black cat.")
print(response)
```
147 changes: 147 additions & 0 deletions test/components/generators/test_openai_dalle.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,147 @@
# SPDX-FileCopyrightText: 2022-present deepset GmbH <[email protected]>
#
# SPDX-License-Identifier: Apache-2.0

import pytest
from unittest.mock import Mock, patch
from haystack.utils import Secret

from openai.types.image import Image
from openai.types import ImagesResponse
from haystack.components.generators.openai_dalle import DALLEImageGenerator


@pytest.fixture
def mock_image_response():
with patch("openai.resources.images.Images.generate") as mock_image_generate:
image_response = ImagesResponse(created=1630000000, data=[Image(url="test-url", revised_prompt="test-prompt")])
mock_image_generate.return_value = image_response
yield mock_image_generate


class TestDALLEImageGenerator:
def test_init_default(self, monkeypatch):
component = DALLEImageGenerator()
assert component.model == "dall-e-3"
assert component.quality == "standard"
assert component.size == "1024x1024"
assert component.response_format == "url"
assert component.api_key == Secret.from_env_var("OPENAI_API_KEY")
assert component.api_base_url is None
assert component.organization is None
assert pytest.approx(component.timeout) == 30.0
assert component.max_retries is 5

def test_init_with_params(self, monkeypatch):
component = DALLEImageGenerator(
model="dall-e-2",
quality="hd",
size="256x256",
response_format="b64_json",
api_key=Secret.from_env_var("EXAMPLE_API_KEY"),
api_base_url="https://api.openai.com",
organization="test-org",
timeout=60,
max_retries=10,
)
assert component.model == "dall-e-2"
assert component.quality == "hd"
assert component.size == "256x256"
assert component.response_format == "b64_json"
assert component.api_key == Secret.from_env_var("EXAMPLE_API_KEY")
assert component.api_base_url == "https://api.openai.com"
assert component.organization == "test-org"
assert pytest.approx(component.timeout) == 60.0
assert component.max_retries == 10

def test_warm_up(self, monkeypatch):
monkeypatch.setenv("OPENAI_API_KEY", "test-api-key")
component = DALLEImageGenerator()
component.warm_up()
assert component.client.api_key == "test-api-key"
assert component.client.timeout == 30
assert component.client.max_retries == 5

def test_to_dict(self):
generator = DALLEImageGenerator()
data = generator.to_dict()
assert data == {
"type": "haystack.components.generators.openai_dalle.DALLEImageGenerator",
"init_parameters": {
"model": "dall-e-3",
"quality": "standard",
"size": "1024x1024",
"response_format": "url",
"api_key": {"type": "env_var", "env_vars": ["OPENAI_API_KEY"], "strict": True},
"api_base_url": None,
"organization": None,
},
}

def test_to_dict_with_params(self):
generator = DALLEImageGenerator(
model="dall-e-2",
quality="hd",
size="256x256",
response_format="b64_json",
api_key=Secret.from_env_var("EXAMPLE_API_KEY"),
api_base_url="https://api.openai.com",
organization="test-org",
timeout=60,
max_retries=10,
)
data = generator.to_dict()
assert data == {
"type": "haystack.components.generators.openai_dalle.DALLEImageGenerator",
"init_parameters": {
"model": "dall-e-2",
"quality": "hd",
"size": "256x256",
"response_format": "b64_json",
"api_key": {"type": "env_var", "env_vars": ["EXAMPLE_API_KEY"], "strict": True},
"api_base_url": "https://api.openai.com",
"organization": "test-org",
},
}

def test_from_dict(self):
data = {
"type": "haystack.components.generators.openai_dalle.DALLEImageGenerator",
"init_parameters": {
"model": "dall-e-3",
"quality": "standard",
"size": "1024x1024",
"response_format": "url",
"api_key": {"type": "env_var", "env_vars": ["OPENAI_API_KEY"], "strict": True},
"api_base_url": None,
"organization": None,
},
}
generator = DALLEImageGenerator.from_dict(data)
assert generator.model == "dall-e-3"
assert generator.quality == "standard"
assert generator.size == "1024x1024"
assert generator.response_format == "url"
assert generator.api_key.to_dict() == {"type": "env_var", "env_vars": ["OPENAI_API_KEY"], "strict": True}

def test_from_dict_default_params(self):
data = {"type": "haystack.components.generators.openai_dalle.DALLEImageGenerator", "init_parameters": {}}
generator = DALLEImageGenerator.from_dict(data)
assert generator.model == "dall-e-3"
assert generator.quality == "standard"
assert generator.size == "1024x1024"
assert generator.response_format == "url"
assert generator.api_key.to_dict() == {"type": "env_var", "env_vars": ["OPENAI_API_KEY"], "strict": True}
assert generator.api_base_url is None
assert generator.organization is None
assert pytest.approx(generator.timeout) == 30.0
assert generator.max_retries == 5

def test_run(self, mock_image_response):
generator = DALLEImageGenerator(api_key=Secret.from_token("test-api-key"))
generator.warm_up()
response = generator.run("Show me a picture of a black cat.")
assert isinstance(response, dict)
assert "images" in response and "revised_prompt" in response
assert response["images"] == ["test-url"]
assert response["revised_prompt"] == "test-prompt"

0 comments on commit e45d332

Please sign in to comment.