-
Notifications
You must be signed in to change notification settings - Fork 1.9k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat: Adding DALLE image generator (#8448)
* First pass at adding DALLE image generator * Add missing header * Fix tests * Add tests * Fix mypy * Make mypy happy * More unit tests * Adding release notes * Add a test for run * Update haystack/components/generators/openai_dalle.py Co-authored-by: Silvano Cerza <[email protected]> * Fix pylint * Update haystack/components/generators/openai_dalle.py Co-authored-by: Amna Mubashar <[email protected]> * Update haystack/components/generators/openai_dalle.py Co-authored-by: Daria Fokina <[email protected]> * Update haystack/components/generators/openai_dalle.py Co-authored-by: Daria Fokina <[email protected]> * Update haystack/components/generators/openai_dalle.py Co-authored-by: Daria Fokina <[email protected]> * Update haystack/components/generators/openai_dalle.py Co-authored-by: Daria Fokina <[email protected]> * Update haystack/components/generators/openai_dalle.py Co-authored-by: Daria Fokina <[email protected]> --------- Co-authored-by: Silvano Cerza <[email protected]> Co-authored-by: Amna Mubashar <[email protected]> Co-authored-by: Daria Fokina <[email protected]>
- Loading branch information
1 parent
a045c0e
commit e45d332
Showing
5 changed files
with
327 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,159 @@ | ||
# SPDX-FileCopyrightText: 2022-present deepset GmbH <[email protected]> | ||
# | ||
# SPDX-License-Identifier: Apache-2.0 | ||
|
||
import os | ||
from typing import Any, Dict, List, Literal, Optional | ||
|
||
from openai import OpenAI | ||
from openai.types.image import Image | ||
|
||
from haystack import component, default_from_dict, default_to_dict, logging | ||
from haystack.utils import Secret, deserialize_secrets_inplace | ||
|
||
logger = logging.getLogger(__name__) | ||
|
||
|
||
@component | ||
class DALLEImageGenerator: | ||
""" | ||
Generates images using OpenAI's DALL-E model. | ||
For details on OpenAI API parameters, see | ||
[OpenAI documentation](https://platform.openai.com/docs/api-reference/images/create). | ||
### Usage example | ||
```python | ||
from haystack.components.generators import DALLEImageGenerator | ||
image_generator = DALLEImageGenerator() | ||
response = image_generator.run("Show me a picture of a black cat.") | ||
print(response) | ||
``` | ||
""" | ||
|
||
def __init__( # pylint: disable=too-many-positional-arguments | ||
self, | ||
model: str = "dall-e-3", | ||
quality: Literal["standard", "hd"] = "standard", | ||
size: Literal["256x256", "512x512", "1024x1024", "1792x1024", "1024x1792"] = "1024x1024", | ||
response_format: Literal["url", "b64_json"] = "url", | ||
api_key: Secret = Secret.from_env_var("OPENAI_API_KEY"), | ||
api_base_url: Optional[str] = None, | ||
organization: Optional[str] = None, | ||
timeout: Optional[float] = None, | ||
max_retries: Optional[int] = None, | ||
): | ||
""" | ||
Creates an instance of DALLEImageGenerator. Unless specified otherwise in `model`, uses OpenAI's dall-e-3. | ||
:param model: The model to use for image generation. Can be "dall-e-2" or "dall-e-3". | ||
:param quality: The quality of the generated image. Can be "standard" or "hd". | ||
:param size: The size of the generated images. | ||
Must be one of 256x256, 512x512, or 1024x1024 for dall-e-2. | ||
Must be one of 1024x1024, 1792x1024, or 1024x1792 for dall-e-3 models. | ||
:param response_format: The format of the response. Can be "url" or "b64_json". | ||
:param api_key: The OpenAI API key to connect to OpenAI. | ||
:param api_base_url: An optional base URL. | ||
:param organization: The Organization ID, defaults to `None`. | ||
:param timeout: | ||
Timeout for OpenAI Client calls. If not set, it is inferred from the `OPENAI_TIMEOUT` environment variable | ||
or set to 30. | ||
:param max_retries: | ||
Maximum retries to establish contact with OpenAI if it returns an internal error. If not set, it is inferred | ||
from the `OPENAI_MAX_RETRIES` environment variable or set to 5. | ||
""" | ||
self.model = model | ||
self.quality = quality | ||
self.size = size | ||
self.response_format = response_format | ||
self.api_key = api_key | ||
self.api_base_url = api_base_url | ||
self.organization = organization | ||
|
||
self.timeout = timeout or float(os.environ.get("OPENAI_TIMEOUT", 30.0)) | ||
self.max_retries = max_retries or int(os.environ.get("OPENAI_MAX_RETRIES", 5)) | ||
|
||
self.client: Optional[OpenAI] = None | ||
|
||
def warm_up(self) -> None: | ||
""" | ||
Warm up the OpenAI client. | ||
""" | ||
if self.client is None: | ||
self.client = OpenAI( | ||
api_key=self.api_key.resolve_value(), | ||
organization=self.organization, | ||
base_url=self.api_base_url, | ||
timeout=self.timeout, | ||
max_retries=self.max_retries, | ||
) | ||
|
||
@component.output_types(images=List[str], revised_prompt=str) | ||
def run( | ||
self, | ||
prompt: str, | ||
size: Optional[Literal["256x256", "512x512", "1024x1024", "1792x1024", "1024x1792"]] = None, | ||
quality: Optional[Literal["standard", "hd"]] = None, | ||
response_format: Optional[Optional[Literal["url", "b64_json"]]] = None, | ||
): | ||
""" | ||
Invokes the image generation inference based on the provided prompt and generation parameters. | ||
:param prompt: The prompt to generate the image. | ||
:param size: If provided, overrides the size provided during initialization. | ||
:param quality: If provided, overrides the quality provided during initialization. | ||
:param response_format: If provided, overrides the response format provided during initialization. | ||
:returns: | ||
A dictionary containing the generated list of images and the revised prompt. | ||
Depending on the `response_format` parameter, the list of images can be URLs or base64 encoded JSON strings. | ||
The revised prompt is the prompt that was used to generate the image, if there was any revision | ||
to the prompt made by OpenAI. | ||
""" | ||
if self.client is None: | ||
raise RuntimeError( | ||
"The component DALLEImageGenerator wasn't warmed up. Run 'warm_up()' before calling 'run()'." | ||
) | ||
|
||
size = size or self.size | ||
quality = quality or self.quality | ||
response_format = response_format or self.response_format | ||
response = self.client.images.generate( | ||
model=self.model, prompt=prompt, size=size, quality=quality, response_format=response_format, n=1 | ||
) | ||
image: Image = response.data[0] | ||
image_str = image.url or image.b64_json or "" | ||
return {"images": [image_str], "revised_prompt": image.revised_prompt or ""} | ||
|
||
def to_dict(self) -> Dict[str, Any]: | ||
""" | ||
Serialize this component to a dictionary. | ||
:returns: | ||
The serialized component as a dictionary. | ||
""" | ||
return default_to_dict( # type: ignore | ||
self, | ||
model=self.model, | ||
quality=self.quality, | ||
size=self.size, | ||
response_format=self.response_format, | ||
api_key=self.api_key.to_dict(), | ||
api_base_url=self.api_base_url, | ||
organization=self.organization, | ||
) | ||
|
||
@classmethod | ||
def from_dict(cls, data: Dict[str, Any]) -> "DALLEImageGenerator": | ||
""" | ||
Deserialize this component from a dictionary. | ||
:param data: | ||
The dictionary representation of this component. | ||
:returns: | ||
The deserialized component instance. | ||
""" | ||
init_params = data.get("init_parameters", {}) | ||
deserialize_secrets_inplace(init_params, keys=["api_key"]) | ||
return default_from_dict(cls, data) # type: ignore |
12 changes: 12 additions & 0 deletions
12
releasenotes/notes/add-dalle-image-generator-495aa11823e11a60.yaml
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,12 @@ | ||
--- | ||
features: | ||
- | | ||
We've added a new **DALLEImageGenerator** component, bringing image generation with OpenAI's DALL-E to the Haystack | ||
- **Easy to Use**: Just a few lines of code to get started: | ||
```python | ||
from haystack.components.generators import DALLEImageGenerator | ||
image_generator = DALLEImageGenerator() | ||
response = image_generator.run("Show me a picture of a black cat.") | ||
print(response) | ||
``` |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,147 @@ | ||
# SPDX-FileCopyrightText: 2022-present deepset GmbH <[email protected]> | ||
# | ||
# SPDX-License-Identifier: Apache-2.0 | ||
|
||
import pytest | ||
from unittest.mock import Mock, patch | ||
from haystack.utils import Secret | ||
|
||
from openai.types.image import Image | ||
from openai.types import ImagesResponse | ||
from haystack.components.generators.openai_dalle import DALLEImageGenerator | ||
|
||
|
||
@pytest.fixture | ||
def mock_image_response(): | ||
with patch("openai.resources.images.Images.generate") as mock_image_generate: | ||
image_response = ImagesResponse(created=1630000000, data=[Image(url="test-url", revised_prompt="test-prompt")]) | ||
mock_image_generate.return_value = image_response | ||
yield mock_image_generate | ||
|
||
|
||
class TestDALLEImageGenerator: | ||
def test_init_default(self, monkeypatch): | ||
component = DALLEImageGenerator() | ||
assert component.model == "dall-e-3" | ||
assert component.quality == "standard" | ||
assert component.size == "1024x1024" | ||
assert component.response_format == "url" | ||
assert component.api_key == Secret.from_env_var("OPENAI_API_KEY") | ||
assert component.api_base_url is None | ||
assert component.organization is None | ||
assert pytest.approx(component.timeout) == 30.0 | ||
assert component.max_retries is 5 | ||
|
||
def test_init_with_params(self, monkeypatch): | ||
component = DALLEImageGenerator( | ||
model="dall-e-2", | ||
quality="hd", | ||
size="256x256", | ||
response_format="b64_json", | ||
api_key=Secret.from_env_var("EXAMPLE_API_KEY"), | ||
api_base_url="https://api.openai.com", | ||
organization="test-org", | ||
timeout=60, | ||
max_retries=10, | ||
) | ||
assert component.model == "dall-e-2" | ||
assert component.quality == "hd" | ||
assert component.size == "256x256" | ||
assert component.response_format == "b64_json" | ||
assert component.api_key == Secret.from_env_var("EXAMPLE_API_KEY") | ||
assert component.api_base_url == "https://api.openai.com" | ||
assert component.organization == "test-org" | ||
assert pytest.approx(component.timeout) == 60.0 | ||
assert component.max_retries == 10 | ||
|
||
def test_warm_up(self, monkeypatch): | ||
monkeypatch.setenv("OPENAI_API_KEY", "test-api-key") | ||
component = DALLEImageGenerator() | ||
component.warm_up() | ||
assert component.client.api_key == "test-api-key" | ||
assert component.client.timeout == 30 | ||
assert component.client.max_retries == 5 | ||
|
||
def test_to_dict(self): | ||
generator = DALLEImageGenerator() | ||
data = generator.to_dict() | ||
assert data == { | ||
"type": "haystack.components.generators.openai_dalle.DALLEImageGenerator", | ||
"init_parameters": { | ||
"model": "dall-e-3", | ||
"quality": "standard", | ||
"size": "1024x1024", | ||
"response_format": "url", | ||
"api_key": {"type": "env_var", "env_vars": ["OPENAI_API_KEY"], "strict": True}, | ||
"api_base_url": None, | ||
"organization": None, | ||
}, | ||
} | ||
|
||
def test_to_dict_with_params(self): | ||
generator = DALLEImageGenerator( | ||
model="dall-e-2", | ||
quality="hd", | ||
size="256x256", | ||
response_format="b64_json", | ||
api_key=Secret.from_env_var("EXAMPLE_API_KEY"), | ||
api_base_url="https://api.openai.com", | ||
organization="test-org", | ||
timeout=60, | ||
max_retries=10, | ||
) | ||
data = generator.to_dict() | ||
assert data == { | ||
"type": "haystack.components.generators.openai_dalle.DALLEImageGenerator", | ||
"init_parameters": { | ||
"model": "dall-e-2", | ||
"quality": "hd", | ||
"size": "256x256", | ||
"response_format": "b64_json", | ||
"api_key": {"type": "env_var", "env_vars": ["EXAMPLE_API_KEY"], "strict": True}, | ||
"api_base_url": "https://api.openai.com", | ||
"organization": "test-org", | ||
}, | ||
} | ||
|
||
def test_from_dict(self): | ||
data = { | ||
"type": "haystack.components.generators.openai_dalle.DALLEImageGenerator", | ||
"init_parameters": { | ||
"model": "dall-e-3", | ||
"quality": "standard", | ||
"size": "1024x1024", | ||
"response_format": "url", | ||
"api_key": {"type": "env_var", "env_vars": ["OPENAI_API_KEY"], "strict": True}, | ||
"api_base_url": None, | ||
"organization": None, | ||
}, | ||
} | ||
generator = DALLEImageGenerator.from_dict(data) | ||
assert generator.model == "dall-e-3" | ||
assert generator.quality == "standard" | ||
assert generator.size == "1024x1024" | ||
assert generator.response_format == "url" | ||
assert generator.api_key.to_dict() == {"type": "env_var", "env_vars": ["OPENAI_API_KEY"], "strict": True} | ||
|
||
def test_from_dict_default_params(self): | ||
data = {"type": "haystack.components.generators.openai_dalle.DALLEImageGenerator", "init_parameters": {}} | ||
generator = DALLEImageGenerator.from_dict(data) | ||
assert generator.model == "dall-e-3" | ||
assert generator.quality == "standard" | ||
assert generator.size == "1024x1024" | ||
assert generator.response_format == "url" | ||
assert generator.api_key.to_dict() == {"type": "env_var", "env_vars": ["OPENAI_API_KEY"], "strict": True} | ||
assert generator.api_base_url is None | ||
assert generator.organization is None | ||
assert pytest.approx(generator.timeout) == 30.0 | ||
assert generator.max_retries == 5 | ||
|
||
def test_run(self, mock_image_response): | ||
generator = DALLEImageGenerator(api_key=Secret.from_token("test-api-key")) | ||
generator.warm_up() | ||
response = generator.run("Show me a picture of a black cat.") | ||
assert isinstance(response, dict) | ||
assert "images" in response and "revised_prompt" in response | ||
assert response["images"] == ["test-url"] | ||
assert response["revised_prompt"] == "test-prompt" |