From 775fea180b287bdc7854ee2a12a667d22c6cfd89 Mon Sep 17 00:00:00 2001 From: Brandon Hancock Date: Wed, 30 Oct 2024 11:30:45 -0400 Subject: [PATCH 1/8] getting cli and .env to work together for different models --- src/crewai/agent.py | 5 ++ src/crewai/cli/constants.py | 91 +++++++++++++++++++++++---- src/crewai/cli/create_crew.py | 40 +++++------- src/crewai/cli/templates/crew/crew.py | 2 +- src/crewai/cli/templates/crew/main.py | 4 ++ 5 files changed, 105 insertions(+), 37 deletions(-) diff --git a/src/crewai/agent.py b/src/crewai/agent.py index 937710f592..b487c3c5af 100644 --- a/src/crewai/agent.py +++ b/src/crewai/agent.py @@ -122,6 +122,9 @@ class Agent(BaseAgent): def post_init_setup(self): self.agent_ops_agent_name = self.role + print("IN POST INIT SETUP") + print("self.llm:", self.llm) + # Handle different cases for self.llm if isinstance(self.llm, str): # If it's a string, create an LLM instance @@ -130,6 +133,7 @@ def post_init_setup(self): # If it's already an LLM instance, keep it as is pass elif self.llm is None: + print("No LLM provided") # If it's None, use environment variables or default model_name = os.environ.get("OPENAI_MODEL_NAME", "gpt-4o-mini") llm_params = {"model": model_name} @@ -146,6 +150,7 @@ def post_init_setup(self): self.llm = LLM(**llm_params) else: + print("IN ELSE") # For any other type, attempt to extract relevant attributes llm_params = { "model": getattr(self.llm, "model_name", None) diff --git a/src/crewai/cli/constants.py b/src/crewai/cli/constants.py index 9a0b36c396..94932c0c74 100644 --- a/src/crewai/cli/constants.py +++ b/src/crewai/cli/constants.py @@ -1,19 +1,86 @@ ENV_VARS = { - 'openai': ['OPENAI_API_KEY'], - 'anthropic': ['ANTHROPIC_API_KEY'], - 'gemini': ['GEMINI_API_KEY'], - 'groq': ['GROQ_API_KEY'], - 'ollama': ['FAKE_KEY'], + "openai": [ + { + "prompt": "Enter your OPENAI API key (press Enter to skip)", + "key_name": "OPENAI_API_KEY", + } + ], + "anthropic": [ + { + "prompt": "Enter your ANTHROPIC API key (press 
Enter to skip)", + "key_name": "ANTHROPIC_API_KEY", + } + ], + "gemini": [ + { + "prompt": "Enter your GEMINI API key (press Enter to skip)", + "key_name": "GEMINI_API_KEY", + } + ], + "groq": [ + { + "prompt": "Enter your GROQ API key (press Enter to skip)", + "key_name": "GROQ_API_KEY", + } + ], + "watson": [ + { + "prompt": "Enter your WATSONX URL (press Enter to skip)", + "key_name": "WATSONX_URL", + }, + { + "prompt": "Enter your WATSONX API key (press Enter to skip)", + "key_name": "WATSONX_APIKEY", + }, + { + "prompt": "Enter your WATSONX token (press Enter to skip)", + "key_name": "WATSONX_TOKEN", + }, + ], } -PROVIDERS = ['openai', 'anthropic', 'gemini', 'groq', 'ollama'] + +PROVIDERS = ["openai", "anthropic", "gemini", "groq", "ollama", "watson"] MODELS = { - 'openai': ['gpt-4', 'gpt-4o', 'gpt-4o-mini', 'o1-mini', 'o1-preview'], - 'anthropic': ['claude-3-5-sonnet-20240620', 'claude-3-sonnet-20240229', 'claude-3-opus-20240229', 'claude-3-haiku-20240307'], - 'gemini': ['gemini-1.5-flash', 'gemini-1.5-pro', 'gemini-gemma-2-9b-it', 'gemini-gemma-2-27b-it'], - 'groq': ['llama-3.1-8b-instant', 'llama-3.1-70b-versatile', 'llama-3.1-405b-reasoning', 'gemma2-9b-it', 'gemma-7b-it'], - 'ollama': ['llama3.1', 'mixtral'], + "openai": ["gpt-4", "gpt-4o", "gpt-4o-mini", "o1-mini", "o1-preview"], + "anthropic": [ + "claude-3-5-sonnet-20240620", + "claude-3-sonnet-20240229", + "claude-3-opus-20240229", + "claude-3-haiku-20240307", + ], + "gemini": [ + "gemini-1.5-flash", + "gemini-1.5-pro", + "gemini-gemma-2-9b-it", + "gemini-gemma-2-27b-it", + ], + "groq": [ + "llama-3.1-8b-instant", + "llama-3.1-70b-versatile", + "llama-3.1-405b-reasoning", + "gemma2-9b-it", + "gemma-7b-it", + ], + "ollama": ["ollama/llama3.1", "ollama/mixtral"], + "watson": [ + "watsonx/google/flan-t5-xxl", + "watsonx/google/flan-ul2", + "watsonx/bigscience/mt0-xxl", + "watsonx/eleutherai/gpt-neox-20b", + "watsonx/ibm/mpt-7b-instruct2", + "watsonx/bigcode/starcoder", + 
"watsonx/meta-llama/llama-2-70b-chat", + "watsonx/meta-llama/llama-2-13b-chat", + "watsonx/ibm/granite-13b-instruct-v1", + "watsonx/ibm/granite-13b-chat-v1", + "watsonx/google/flan-t5-xl", + "watsonx/ibm/granite-13b-chat-v2", + "watsonx/ibm/granite-13b-instruct-v2", + "watsonx/elyza/elyza-japanese-llama-2-7b-instruct", + "watsonx/ibm-mistralai/mixtral-8x7b-instruct-v01-q", + ], } -JSON_URL = "https://raw.githubusercontent.com/BerriAI/litellm/main/model_prices_and_context_window.json" \ No newline at end of file +JSON_URL = "https://raw.githubusercontent.com/BerriAI/litellm/main/model_prices_and_context_window.json" diff --git a/src/crewai/cli/create_crew.py b/src/crewai/cli/create_crew.py index 5767b82a1f..bbb34c74da 100644 --- a/src/crewai/cli/create_crew.py +++ b/src/crewai/cli/create_crew.py @@ -5,7 +5,6 @@ from crewai.cli.constants import ENV_VARS from crewai.cli.provider import ( - PROVIDERS, get_provider_data, select_model, select_provider, @@ -92,7 +91,7 @@ def create_crew(name, provider=None, skip_provider=False, parent_folder=None): existing_provider = None for provider, env_keys in ENV_VARS.items(): - if any(key in env_vars for key in env_keys): + if any(details["key_name"] in env_vars for details in env_keys): existing_provider = provider break @@ -129,35 +128,28 @@ def create_crew(name, provider=None, skip_provider=False, parent_folder=None): "No model selected. 
Please try again or press 'q' to exit.", fg="red" ) - if selected_provider in PROVIDERS: - api_key_var = ENV_VARS[selected_provider][0] - else: - api_key_var = click.prompt( - f"Enter the environment variable name for your {selected_provider.capitalize()} API key", - type=str, - default="", - ) + # Check if the selected provider requires API keys + if selected_provider in ENV_VARS: + provider_env_vars = ENV_VARS[selected_provider] + for details in provider_env_vars: + prompt = details["prompt"] + key_name = details["key_name"] + api_key_value = click.prompt(prompt, default="", show_default=False) - api_key_value = "" - click.echo( - f"Enter your {selected_provider.capitalize()} API key (press Enter to skip): ", - nl=False, - ) - try: - api_key_value = input() - except (KeyboardInterrupt, EOFError): - api_key_value = "" + if api_key_value.strip(): + env_vars[key_name] = api_key_value - if api_key_value.strip(): - env_vars = {api_key_var: api_key_value} + # Save the selected model to env_vars + env_vars["MODEL"] = selected_model + + if env_vars: write_env_file(folder_path, env_vars) - click.secho("API key saved to .env file", fg="green") + click.secho("API keys and model saved to .env file", fg="green") else: click.secho( - "No API key provided. Skipping .env file creation.", fg="yellow" + "No API keys provided. 
Skipping .env file creation.", fg="yellow" ) - env_vars["MODEL"] = selected_model click.secho(f"Selected model: {selected_model}", fg="green") package_dir = Path(__file__).parent diff --git a/src/crewai/cli/templates/crew/crew.py b/src/crewai/cli/templates/crew/crew.py index f950d13d43..392e29edde 100644 --- a/src/crewai/cli/templates/crew/crew.py +++ b/src/crewai/cli/templates/crew/crew.py @@ -48,4 +48,4 @@ def crew(self) -> Crew: process=Process.sequential, verbose=True, # process=Process.hierarchical, # In case you wanna use that instead https://docs.crewai.com/how-to/Hierarchical/ - ) \ No newline at end of file + ) diff --git a/src/crewai/cli/templates/crew/main.py b/src/crewai/cli/templates/crew/main.py index 88edfcbffc..d441fa0fa3 100644 --- a/src/crewai/cli/templates/crew/main.py +++ b/src/crewai/cli/templates/crew/main.py @@ -1,7 +1,11 @@ #!/usr/bin/env python import sys +import warnings + from {{folder_name}}.crew import {{crew_name}}Crew +warnings.filterwarnings("ignore", category=SyntaxWarning, module="pysbd") + # This main file is intended to be a way for you to run your # crew locally, so refrain from adding unnecessary logic into this file. 
# Replace with inputs you want to test with, it will automatically From 6b3c5d28e2443f96700765eba3746eb608c5ae7b Mon Sep 17 00:00:00 2001 From: Brandon Hancock Date: Wed, 30 Oct 2024 15:12:48 -0400 Subject: [PATCH 2/8] support new models --- src/crewai/agent.py | 36 +++++++++++--- src/crewai/cli/constants.py | 91 +++++++++++++++++++++++++++++++---- src/crewai/cli/create_crew.py | 84 ++++++++++++++++++-------------- 3 files changed, 159 insertions(+), 52 deletions(-) diff --git a/src/crewai/agent.py b/src/crewai/agent.py index b487c3c5af..925ad2e059 100644 --- a/src/crewai/agent.py +++ b/src/crewai/agent.py @@ -8,6 +8,7 @@ from crewai.agents import CacheHandler from crewai.agents.agent_builder.base_agent import BaseAgent from crewai.agents.crew_agent_executor import CrewAgentExecutor +from crewai.cli.constants import ENV_VARS from crewai.llm import LLM from crewai.memory.contextual.contextual_memory import ContextualMemory from crewai.tools.agent_tools import AgentTools @@ -134,8 +135,12 @@ def post_init_setup(self): pass elif self.llm is None: print("No LLM provided") - # If it's None, use environment variables or default - model_name = os.environ.get("OPENAI_MODEL_NAME", "gpt-4o-mini") + # Determine the model name from environment variables or use default + model_name = ( + os.environ.get("OPENAI_MODEL_NAME") + or os.environ.get("MODEL") + or "gpt-4o-mini" + ) llm_params = {"model": model_name} api_base = os.environ.get("OPENAI_API_BASE") or os.environ.get( @@ -144,10 +149,29 @@ def post_init_setup(self): if api_base: llm_params["base_url"] = api_base - api_key = os.environ.get("OPENAI_API_KEY") - if api_key: - llm_params["api_key"] = api_key - + # Iterate over all environment variables to find matching API keys or use defaults + for provider, env_vars in ENV_VARS.items(): + for env_var in env_vars: + # Check if the environment variable is set + if "key_name" in env_var: + env_value = os.environ.get(env_var["key_name"]) + if env_value: + # Map key names containing 
"API_KEY" to "api_key" + key_name = ( + "api_key" + if "API_KEY" in env_var["key_name"] + else env_var["key_name"] + ) + llm_params[key_name] = env_value + # Check for default values if the environment variable is not set + elif env_var.get("default", False): + for key, value in env_var.items(): + if key not in ["prompt", "key_name", "default"]: + # Only add default if the key is already set in os.environ + if key in os.environ: + llm_params[key] = value + + print("LLM PARAMS:", llm_params) self.llm = LLM(**llm_params) else: print("IN ELSE") diff --git a/src/crewai/cli/constants.py b/src/crewai/cli/constants.py index 94932c0c74..39b0247927 100644 --- a/src/crewai/cli/constants.py +++ b/src/crewai/cli/constants.py @@ -37,10 +37,57 @@ "key_name": "WATSONX_TOKEN", }, ], + "ollama": [ + { + "default": True, + "API_BASE": "http://localhost:11434", + } + ], + "bedrock": [ + { + "prompt": "Enter your AWS Access Key ID (press Enter to skip)", + "key_name": "AWS_ACCESS_KEY_ID", + }, + { + "prompt": "Enter your AWS Secret Access Key (press Enter to skip)", + "key_name": "AWS_SECRET_ACCESS_KEY", + }, + { + "prompt": "Enter your AWS Region Name (press Enter to skip)", + "key_name": "AWS_REGION_NAME", + }, + ], + "azure": [ + { + "prompt": "Enter your Azure deployment name (must start with 'azure/')", + "key_name": "model", + }, + { + "prompt": "Enter your AZURE API key (press Enter to skip)", + "key_name": "AZURE_API_KEY", + }, + { + "prompt": "Enter your AZURE API base URL (press Enter to skip)", + "key_name": "AZURE_API_BASE", + }, + { + "prompt": "Enter your AZURE API version (press Enter to skip)", + "key_name": "AZURE_API_VERSION", + }, + ], } -PROVIDERS = ["openai", "anthropic", "gemini", "groq", "ollama", "watson"] +PROVIDERS = [ + "openai", + "anthropic", + "gemini", + "groq", + "ollama", + "watson", + "bedrock", + "azure", +] MODELS = { "openai": ["gpt-4", "gpt-4o", "gpt-4o-mini", "o1-mini", "o1-preview"], @@ -51,17 +98,17 @@ "claude-3-haiku-20240307", ], "gemini": [ 
- "gemini-1.5-flash", - "gemini-1.5-pro", - "gemini-gemma-2-9b-it", - "gemini-gemma-2-27b-it", + "gemini/gemini-1.5-flash", + "gemini/gemini-1.5-pro", + "gemini/gemini-gemma-2-9b-it", + "gemini/gemini-gemma-2-27b-it", ], "groq": [ - "llama-3.1-8b-instant", - "llama-3.1-70b-versatile", - "llama-3.1-405b-reasoning", - "gemma2-9b-it", - "gemma-7b-it", + "groq/llama-3.1-8b-instant", + "groq/llama-3.1-70b-versatile", + "groq/llama-3.1-405b-reasoning", + "groq/gemma2-9b-it", + "groq/gemma-7b-it", ], "ollama": ["ollama/llama3.1", "ollama/mixtral"], "watson": [ @@ -81,6 +128,30 @@ "watsonx/elyza/elyza-japanese-llama-2-7b-instruct", "watsonx/ibm-mistralai/mixtral-8x7b-instruct-v01-q", ], + "bedrock": [ + "bedrock/anthropic.claude-3-5-sonnet-20240620-v1:0", + "bedrock/anthropic.claude-3-sonnet-20240229-v1:0", + "bedrock/anthropic.claude-3-haiku-20240307-v1:0", + "bedrock/anthropic.claude-3-opus-20240229-v1:0", + "bedrock/anthropic.claude-v2:1", + "bedrock/anthropic.claude-v2", + "bedrock/anthropic.claude-instant-v1", + "bedrock/meta.llama3-1-405b-instruct-v1:0", + "bedrock/meta.llama3-1-70b-instruct-v1:0", + "bedrock/meta.llama3-1-8b-instruct-v1:0", + "bedrock/meta.llama3-70b-instruct-v1:0", + "bedrock/meta.llama3-8b-instruct-v1:0", + "bedrock/amazon.titan-text-lite-v1", + "bedrock/amazon.titan-text-express-v1", + "bedrock/cohere.command-text-v14", + "bedrock/ai21.j2-mid-v1", + "bedrock/ai21.j2-ultra-v1", + "bedrock/ai21.jamba-instruct-v1:0", + "bedrock/meta.llama2-13b-chat-v1", + "bedrock/meta.llama2-70b-chat-v1", + "bedrock/mistral.mistral-7b-instruct-v0:2", + "bedrock/mistral.mixtral-8x7b-instruct-v0:1", + ], } JSON_URL = "https://raw.githubusercontent.com/BerriAI/litellm/main/model_prices_and_context_window.json" diff --git a/src/crewai/cli/create_crew.py b/src/crewai/cli/create_crew.py index bbb34c74da..06440d74e9 100644 --- a/src/crewai/cli/create_crew.py +++ b/src/crewai/cli/create_crew.py @@ -1,9 +1,10 @@ +import shutil import sys from pathlib import Path import 
click -from crewai.cli.constants import ENV_VARS +from crewai.cli.constants import ENV_VARS, MODELS from crewai.cli.provider import ( get_provider_data, select_model, @@ -28,20 +29,20 @@ def create_folder_structure(name, parent_folder=None): click.secho("Operation cancelled.", fg="yellow") sys.exit(0) click.secho(f"Overriding folder {folder_name}...", fg="green", bold=True) - else: - click.secho( - f"Creating {'crew' if parent_folder else 'folder'} {folder_name}...", - fg="green", - bold=True, - ) - - if not folder_path.exists(): - folder_path.mkdir(parents=True) - (folder_path / "tests").mkdir(exist_ok=True) - if not parent_folder: - (folder_path / "src" / folder_name).mkdir(parents=True) - (folder_path / "src" / folder_name / "tools").mkdir(parents=True) - (folder_path / "src" / folder_name / "config").mkdir(parents=True) + shutil.rmtree(folder_path) # Delete the existing folder and its contents + + click.secho( + f"Creating {'crew' if parent_folder else 'folder'} {folder_name}...", + fg="green", + bold=True, + ) + + folder_path.mkdir(parents=True) + (folder_path / "tests").mkdir(exist_ok=True) + if not parent_folder: + (folder_path / "src" / folder_name).mkdir(parents=True) + (folder_path / "src" / folder_name / "tools").mkdir(parents=True) + (folder_path / "src" / folder_name / "config").mkdir(parents=True) return folder_path, folder_name, class_name @@ -91,7 +92,10 @@ def create_crew(name, provider=None, skip_provider=False, parent_folder=None): existing_provider = None for provider, env_keys in ENV_VARS.items(): - if any(details["key_name"] in env_vars for details in env_keys): + if any( + "key_name" in details and details["key_name"] in env_vars + for details in env_keys + ): existing_provider = provider break @@ -117,30 +121,38 @@ def create_crew(name, provider=None, skip_provider=False, parent_folder=None): "No provider selected. 
Please try again or press 'q' to exit.", fg="red" ) - while True: - selected_model = select_model(selected_provider, provider_models) - if selected_model is None: # User typed 'q' - click.secho("Exiting...", fg="yellow") - sys.exit(0) - if selected_model: # Valid selection - break - click.secho( - "No model selected. Please try again or press 'q' to exit.", fg="red" - ) + # Check if the selected provider has predefined models + if selected_provider in MODELS and MODELS[selected_provider]: + while True: + selected_model = select_model(selected_provider, provider_models) + if selected_model is None: # User typed 'q' + click.secho("Exiting...", fg="yellow") + sys.exit(0) + if selected_model: # Valid selection + break + click.secho( + "No model selected. Please try again or press 'q' to exit.", + fg="red", + ) + env_vars["MODEL"] = selected_model # Check if the selected provider requires API keys if selected_provider in ENV_VARS: provider_env_vars = ENV_VARS[selected_provider] for details in provider_env_vars: - prompt = details["prompt"] - key_name = details["key_name"] - api_key_value = click.prompt(prompt, default="", show_default=False) - - if api_key_value.strip(): - env_vars[key_name] = api_key_value - - # Save the selected model to env_vars - env_vars["MODEL"] = selected_model + if details.get("default", False): + # Automatically add default key-value pairs + for key, value in details.items(): + if key not in ["prompt", "key_name", "default"]: + env_vars[key] = value + elif "key_name" in details: + # Prompt for non-default key-value pairs + prompt = details["prompt"] + key_name = details["key_name"] + api_key_value = click.prompt(prompt, default="", show_default=False) + + if api_key_value.strip(): + env_vars[key_name] = api_key_value if env_vars: write_env_file(folder_path, env_vars) @@ -150,7 +162,7 @@ def create_crew(name, provider=None, skip_provider=False, parent_folder=None): "No API keys provided. 
Skipping .env file creation.", fg="yellow" ) - click.secho(f"Selected model: {selected_model}", fg="green") + click.secho(f"Selected model: {env_vars.get('MODEL', 'N/A')}", fg="green") package_dir = Path(__file__).parent templates_dir = package_dir / "templates" / "crew" From 7c7c3f88adeb87f77986ea0a9bd013a0ec41a589 Mon Sep 17 00:00:00 2001 From: Brandon Hancock Date: Wed, 30 Oct 2024 16:11:50 -0400 Subject: [PATCH 3/8] clean up prints --- src/crewai/agent.py | 18 ++++++++++++------ 1 file changed, 12 insertions(+), 6 deletions(-) diff --git a/src/crewai/agent.py b/src/crewai/agent.py index 925ad2e059..abe08770df 100644 --- a/src/crewai/agent.py +++ b/src/crewai/agent.py @@ -123,9 +123,6 @@ class Agent(BaseAgent): def post_init_setup(self): self.agent_ops_agent_name = self.role - print("IN POST INIT SETUP") - print("self.llm:", self.llm) - # Handle different cases for self.llm if isinstance(self.llm, str): # If it's a string, create an LLM instance @@ -134,7 +131,6 @@ def post_init_setup(self): # If it's already an LLM instance, keep it as is pass elif self.llm is None: - print("No LLM provided") # Determine the model name from environment variables or use default model_name = ( os.environ.get("OPENAI_MODEL_NAME") @@ -162,6 +158,18 @@ def post_init_setup(self): if "API_KEY" in env_var["key_name"] else env_var["key_name"] ) + # Map key names containing "API_BASE" to "api_base" + key_name = ( + "api_base" + if "API_BASE" in env_var["key_name"] + else key_name + ) + # Map key names containing "API_VERSION" to "api_version" + key_name = ( + "api_version" + if "API_VERSION" in env_var["key_name"] + else key_name + ) llm_params[key_name] = env_value # Check for default values if the environment variable is not set elif env_var.get("default", False): @@ -171,10 +179,8 @@ def post_init_setup(self): if key in os.environ: llm_params[key] = value - print("LLM PARAMS:", llm_params) self.llm = LLM(**llm_params) else: - print("IN ELSE") # For any other type, attempt to extract 
relevant attributes llm_params = { "model": getattr(self.llm, "model_name", None) From a9eaf63c342c8f7e9b33f4d8f50e6158aeab0591 Mon Sep 17 00:00:00 2001 From: Brandon Hancock Date: Thu, 31 Oct 2024 10:20:59 -0400 Subject: [PATCH 4/8] Add support for cerebras --- src/crewai/cli/constants.py | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/src/crewai/cli/constants.py b/src/crewai/cli/constants.py index 39b0247927..a60922eb01 100644 --- a/src/crewai/cli/constants.py +++ b/src/crewai/cli/constants.py @@ -75,6 +75,16 @@ "key_name": "AZURE_API_VERSION", }, ], + "cerebras": [ + { + "prompt": "Enter your Cerebras model name (must start with 'cerebras/')", + "key_name": "model", + }, + { + "prompt": "Enter your Cerebras API key (press Enter to skip)", + "key_name": "CEREBRAS_API_KEY", + }, + ], } @@ -87,6 +97,7 @@ "watson", "bedrock", "azure", + "cerebras", ] MODELS = { From 40f6bdbea1445e8560381d72212a7ff822bc4d9e Mon Sep 17 00:00:00 2001 From: Brandon Hancock Date: Thu, 31 Oct 2024 14:27:29 -0400 Subject: [PATCH 5/8] Fix watson keys --- src/crewai/cli/constants.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/crewai/cli/constants.py b/src/crewai/cli/constants.py index a60922eb01..4be08fa2a3 100644 --- a/src/crewai/cli/constants.py +++ b/src/crewai/cli/constants.py @@ -29,12 +29,12 @@ "key_name": "WATSONX_URL", }, { - "prompt": "Enter your WATSONX API key (press Enter to skip)", + "prompt": "Enter your WATSONX API Key (press Enter to skip)", "key_name": "WATSONX_APIKEY", }, { - "prompt": "Enter your WATSONX token (press Enter to skip)", - "key_name": "WATSONX_TOKEN", + "prompt": "Enter your WATSONX Project Id (press Enter to skip)", + "key_name": "WATSONX_PROJECT_ID", }, ], "ollama": [ From 9790fb54ee58055eae1a6e17cbbe64d163ba3e65 Mon Sep 17 00:00:00 2001 From: Brandon Hancock Date: Fri, 1 Nov 2024 15:45:04 -0400 Subject: [PATCH 6/8] update flows to allow inputs during kickoff --- docs/concepts/flows.mdx | 129 
+++++++++++++++++++--------------------- src/crewai/flow/flow.py | 66 ++++++++++++++++++-- 2 files changed, 123 insertions(+), 72 deletions(-) diff --git a/docs/concepts/flows.mdx b/docs/concepts/flows.mdx index c867769821..306754c87f 100644 --- a/docs/concepts/flows.mdx +++ b/docs/concepts/flows.mdx @@ -18,60 +18,63 @@ Flows allow you to create structured, event-driven workflows. They provide a sea 4. **Flexible Control Flow**: Implement conditional logic, loops, and branching within your workflows. +5. **Input Flexibility**: Flows can accept inputs to initialize or update their state, with different handling for structured and unstructured state management. + ## Getting Started Let's create a simple Flow where you will use OpenAI to generate a random city in one task and then use that city to generate a fun fact in another task. -```python Code +### Passing Inputs to Flows -from crewai.flow.flow import Flow, listen, start -from dotenv import load_dotenv -from litellm import completion +Flows can accept inputs to initialize or update their state before execution. The way inputs are handled depends on whether the flow uses structured or unstructured state management. + +#### Structured State Management +In structured state management, the flow's state is defined using a Pydantic `BaseModel`. Inputs must match the model's schema, and any updates will overwrite the default values. 
+ +```python +from crewai.flow.flow import Flow, listen, start +from pydantic import BaseModel -class ExampleFlow(Flow): - model = "gpt-4o-mini" +class ExampleState(BaseModel): + counter: int = 0 + message: str = "" +class StructuredExampleFlow(Flow[ExampleState]): @start() - def generate_city(self): - print("Starting flow") - - response = completion( - model=self.model, - messages=[ - { - "role": "user", - "content": "Return the name of a random city in the world.", - }, - ], - ) + def first_method(self): + # Implementation - random_city = response["choices"][0]["message"]["content"] - print(f"Random City: {random_city}") - - return random_city - - @listen(generate_city) - def generate_fun_fact(self, random_city): - response = completion( - model=self.model, - messages=[ - { - "role": "user", - "content": f"Tell me a fun fact about {random_city}", - }, - ], - ) +flow = StructuredExampleFlow() +flow.kickoff(inputs={"counter": 10}) +``` + +In this example, the `counter` is initialized to `10`, while `message` retains its default value. + +#### Unstructured State Management + +In unstructured state management, the flow's state is a dictionary. You can pass any dictionary to update the state. + +```python +from crewai.flow.flow import Flow, listen, start + +class UnstructuredExampleFlow(Flow): + @start() + def first_method(self): + # Implementation - fun_fact = response["choices"][0]["message"]["content"] - return fun_fact +flow = UnstructuredExampleFlow() +flow.kickoff(inputs={"counter": 5, "message": "Initial message"}) +``` +Here, both `counter` and `message` are updated based on the provided inputs. +**Note:** Ensure that inputs for structured state management adhere to the defined schema to avoid validation errors. 
-flow = ExampleFlow() -result = flow.kickoff() +### Example Flow -print(f"Generated fun fact: {result}") +```python +# Existing example code ``` In the above example, we have created a simple Flow that generates a random city using OpenAI and then generates a fun fact about that city. The Flow consists of two tasks: `generate_city` and `generate_fun_fact`. The `generate_city` task is the starting point of the Flow, and the `generate_fun_fact` task listens for the output of the `generate_city` task. @@ -94,14 +97,14 @@ The `@listen()` decorator can be used in several ways: 1. **Listening to a Method by Name**: You can pass the name of the method you want to listen to as a string. When that method completes, the listener method will be triggered. - ```python Code + ```python @listen("generate_city") def generate_fun_fact(self, random_city): # Implementation ``` 2. **Listening to a Method Directly**: You can pass the method itself. When that method completes, the listener method will be triggered. 
- ```python Code + ```python @listen(generate_city) def generate_fun_fact(self, random_city): # Implementation @@ -118,7 +121,7 @@ When you run a Flow, the final output is determined by the last method that comp Here's how you can access the final output: -```python Code +```python from crewai.flow.flow import Flow, listen, start class OutputExampleFlow(Flow): @@ -130,18 +133,17 @@ class OutputExampleFlow(Flow): def second_method(self, first_output): return f"Second method received: {first_output}" - flow = OutputExampleFlow() final_output = flow.kickoff() print("---- Final Output ----") print(final_output) -```` +``` -``` text Output +```text ---- Final Output ---- Second method received: Output from first_method -```` +``` @@ -156,7 +158,7 @@ Here's an example of how to update and access the state: -```python Code +```python from crewai.flow.flow import Flow, listen, start from pydantic import BaseModel @@ -184,7 +186,7 @@ print("Final State:") print(flow.state) ``` -```text Output +```text Final Output: Hello from first_method - updated by second_method Final State: counter=2 message='Hello from first_method - updated by second_method' @@ -208,10 +210,10 @@ allowing developers to choose the approach that best fits their application's ne In unstructured state management, all state is stored in the `state` attribute of the `Flow` class. This approach offers flexibility, enabling developers to add or modify state attributes on the fly without defining a strict schema. 
-```python Code +```python from crewai.flow.flow import Flow, listen, start -class UntructuredExampleFlow(Flow): +class UnstructuredExampleFlow(Flow): @start() def first_method(self): @@ -230,8 +232,7 @@ class UntructuredExampleFlow(Flow): print(f"State after third_method: {self.state}") - -flow = UntructuredExampleFlow() +flow = UnstructuredExampleFlow() flow.kickoff() ``` @@ -245,16 +246,14 @@ flow.kickoff() Structured state management leverages predefined schemas to ensure consistency and type safety across the workflow. By using models like Pydantic's `BaseModel`, developers can define the exact shape of the state, enabling better validation and auto-completion in development environments. -```python Code +```python from crewai.flow.flow import Flow, listen, start from pydantic import BaseModel - class ExampleState(BaseModel): counter: int = 0 message: str = "" - class StructuredExampleFlow(Flow[ExampleState]): @start() @@ -273,7 +272,6 @@ class StructuredExampleFlow(Flow[ExampleState]): print(f"State after third_method: {self.state}") - flow = StructuredExampleFlow() flow.kickoff() ``` @@ -307,7 +305,7 @@ The `or_` function in Flows allows you to listen to multiple methods and trigger -```python Code +```python from crewai.flow.flow import Flow, listen, or_, start class OrExampleFlow(Flow): @@ -324,13 +322,11 @@ class OrExampleFlow(Flow): def logger(self, result): print(f"Logger: {result}") - - flow = OrExampleFlow() flow.kickoff() ``` -```text Output +```text Logger: Hello from the start method Logger: Hello from the second method ``` @@ -346,7 +342,7 @@ The `and_` function in Flows allows you to listen to multiple methods and trigge -```python Code +```python from crewai.flow.flow import Flow, and_, listen, start class AndExampleFlow(Flow): @@ -368,7 +364,7 @@ flow = AndExampleFlow() flow.kickoff() ``` -```text Output +```text ---- Logger ---- {'greeting': 'Hello from the start method', 'joke': 'What do computers eat? 
Microchips.'} ``` @@ -385,7 +381,7 @@ You can specify different routes based on the output of the method, allowing you -```python Code +```python import random from crewai.flow.flow import Flow, listen, router, start from pydantic import BaseModel @@ -416,12 +412,11 @@ class RouterFlow(Flow[ExampleState]): def fourth_method(self): print("Fourth method running") - flow = RouterFlow() flow.kickoff() ``` -```text Output +```text Starting the structured flow Third method running Fourth method running @@ -484,7 +479,7 @@ The `main.py` file is where you create your flow and connect the crews together. Here's an example of how you can connect the `poem_crew` in the `main.py` file: -```python Code +```python #!/usr/bin/env python from random import randint @@ -610,7 +605,7 @@ CrewAI provides two convenient methods to generate plots of your flows: If you are working directly with a flow instance, you can generate a plot by calling the `plot()` method on your flow object. This method will create an HTML file containing the interactive plot of your flow. 
-```python Code +```python # Assuming you have a flow instance flow.plot("my_flow_plot") ``` diff --git a/src/crewai/flow/flow.py b/src/crewai/flow/flow.py index e7231e13f5..2f6e82750b 100644 --- a/src/crewai/flow/flow.py +++ b/src/crewai/flow/flow.py @@ -1,8 +1,19 @@ import asyncio import inspect -from typing import Any, Callable, Dict, Generic, List, Set, Type, TypeVar, Union - -from pydantic import BaseModel +from typing import ( + Any, + Callable, + Dict, + Generic, + List, + Optional, + Set, + Type, + TypeVar, + Union, +) + +from pydantic import BaseModel, ValidationError from crewai.flow.flow_visualizer import plot_flow from crewai.flow.utils import get_possible_return_constants @@ -191,10 +202,55 @@ def method_outputs(self) -> List[Any]: """Returns the list of all outputs from executed methods.""" return self._method_outputs - def kickoff(self) -> Any: + def _initialize_state(self, inputs: Dict[str, Any]) -> None: + """ + Initializes or updates the state with the provided inputs. + + Args: + inputs: Dictionary of inputs to initialize or update the state. + + Raises: + ValueError: If inputs do not match the structured state model. + TypeError: If state is neither a BaseModel instance nor a dictionary. + """ + if isinstance(self._state, BaseModel): + # Structured state management + try: + self._state = self._state.model_copy(update=inputs) + except ValidationError as e: + raise ValueError(f"Invalid inputs for structured state: {e}") from e + elif isinstance(self._state, dict): + # Unstructured state management + self._state.update(inputs) + else: + raise TypeError("State must be a BaseModel instance or a dictionary.") + + def kickoff(self, inputs: Optional[Dict[str, Any]] = None) -> Any: + """ + Starts the execution of the flow synchronously. + + Args: + inputs: Optional dictionary of inputs to initialize or update the state. + + Returns: + The final output from the flow execution. 
+ """ + if inputs is not None: + self._initialize_state(inputs) return asyncio.run(self.kickoff_async()) - async def kickoff_async(self) -> Any: + async def kickoff_async(self, inputs: Optional[Dict[str, Any]] = None) -> Any: + """ + Starts the execution of the flow asynchronously. + + Args: + inputs: Optional dictionary of inputs to initialize or update the state. + + Returns: + The final output from the flow execution. + """ + if inputs is not None: + self._initialize_state(inputs) if not self._start_methods: raise ValueError("No start method defined") From 40248aadec10e9a5a318f83c8b956f5a0c41be8b Mon Sep 17 00:00:00 2001 From: Brandon Hancock Date: Fri, 1 Nov 2024 16:14:36 -0400 Subject: [PATCH 7/8] Make sure inputs adhere to state type. --- src/crewai/flow/flow.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/src/crewai/flow/flow.py b/src/crewai/flow/flow.py index 2f6e82750b..85e80c1bdb 100644 --- a/src/crewai/flow/flow.py +++ b/src/crewai/flow/flow.py @@ -216,7 +216,10 @@ def _initialize_state(self, inputs: Dict[str, Any]) -> None: if isinstance(self._state, BaseModel): # Structured state management try: - self._state = self._state.model_copy(update=inputs) + # Create a new instance with updated values to ensure validation + self._state = self._state.__class__( + **{**self._state.model_dump(), **inputs} + ) except ValidationError as e: raise ValueError(f"Invalid inputs for structured state: {e}") from e elif isinstance(self._state, dict): From 84bb5be5f5a0f2423b87d4c75b3060a841d79fcb Mon Sep 17 00:00:00 2001 From: Brandon Hancock Date: Fri, 1 Nov 2024 16:29:14 -0400 Subject: [PATCH 8/8] complete validation --- src/crewai/flow/flow.py | 15 ++++++++++++--- 1 file changed, 12 insertions(+), 3 deletions(-) diff --git a/src/crewai/flow/flow.py b/src/crewai/flow/flow.py index 85e80c1bdb..16f1cf9f0b 100644 --- a/src/crewai/flow/flow.py +++ b/src/crewai/flow/flow.py @@ -11,6 +11,7 @@ Type, TypeVar, Union, + cast, ) from pydantic import 
BaseModel, ValidationError @@ -216,10 +217,18 @@ def _initialize_state(self, inputs: Dict[str, Any]) -> None: if isinstance(self._state, BaseModel): # Structured state management try: - # Create a new instance with updated values to ensure validation - self._state = self._state.__class__( - **{**self._state.model_dump(), **inputs} + M = self._state.__class__ + + # Dynamically create a new model class with 'extra' set to 'forbid' + class ModelWithExtraForbid(M): + model_config = M.model_config.copy() + model_config["extra"] = "forbid" + + # Create a new instance using the combined state and inputs + self._state = cast( + T, ModelWithExtraForbid(**{**self._state.model_dump(), **inputs}) ) + except ValidationError as e: raise ValueError(f"Invalid inputs for structured state: {e}") from e elif isinstance(self._state, dict):